Compare commits

...

46 Commits

Author SHA1 Message Date
powderluv
b444528715 Pin torch-mlir for windows too 2023-06-23 19:19:28 -07:00
Ean Garvey
6e6c90f62b Pin torch-mlir and use local-task in OPT. (#1592) 2023-06-23 19:17:05 -07:00
AyaanShah2204
8cdb38496e Final REST API Fixes (#1590)
* fixed outpaint api and added tests

* fixed text2img api

* more elegant generator to subscriptable conversion

* final fixes
2023-06-23 16:46:47 -07:00
powderluv
726d73d6ba Revert "[vicuna] Add streaming of tokens (#1587)" (#1588)
This reverts commit 4d55e51d46.
2023-06-23 10:29:00 -07:00
Gaurav Shukla
4d55e51d46 [vicuna] Add streaming of tokens (#1587)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-06-23 08:20:46 -07:00
Prashant Kumar
6ef78ee7ba Add cpu compile time flags. (#1585) 2023-06-23 07:23:26 -07:00
jinchen62
4002da7161 Add int4/int8 options to chatbot webui (#1586) 2023-06-23 07:18:34 -07:00
powderluv
ecb5e8e5d8 Update txt2img_ui.py 2023-06-23 06:42:12 -07:00
PhaneeshB
28e0919321 Add AMD cpu device 2023-06-23 18:47:04 +05:30
Daniel Garvey
28f4d44a6b downloader was double downloading (#1580) 2023-06-22 18:30:27 -07:00
AyaanShah2204
97f7e79391 [Blender Integration] Fixed Inpainting REST API (#1577)
* fixed inpaint api

* added inpainting test

* fixed linter errors

---------

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-06-22 16:08:26 -07:00
Nelson Sharpe
44a8f2f8db Include VAE & LoRA data into PNG metadata (#1573)
* include custom lora and vae data in png metadata

* include pycharm settings

* lint with black
2023-06-22 16:05:54 -07:00
Eliasj42
8822b9acd7 added ability to use config file to shard vicuna (#1565)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-06-22 17:40:35 -05:00
Daniel Garvey
0ca3b9fce3 fix some mmap and vicuna bugs (#1576) 2023-06-22 17:39:55 -05:00
Nithin Meganathan
045f2bb147 Add dispatch-level config file generator for manual annotation (#1566) 2023-06-22 15:11:41 -07:00
Prashant Kumar
a811b867b9 Add shark_eager mode.
-- Eager mode with step by step op compilation and execution.
2023-06-22 22:59:14 +05:30
Abhishek Varma
cdd505e2dd [SharkInference-SharkRuntime] Adds capability to mmap vmfbs
-- This commit is based on [VmModule.mmap() API](https://github.com/openxla/iree/pull/14124).
-- It thereby adds capability to mmap vmfbs in SHARK.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-06-22 20:43:40 +05:30
powderluv
1b0f39107c Move torch_mlir import to the top (#1574) 2023-06-21 22:31:35 -07:00
powderluv
b9b8955f74 exclude vulkan on macos 2023-06-21 22:22:27 -07:00
powderluv
6f7a85eee3 switch to metal backend for CI 2023-06-21 22:17:11 -07:00
Ranvir Singh Virk
18c8e9e51e Metal typo fix (#1572)
* fixing typos for metal changes

* black formating
2023-06-21 21:56:11 -07:00
Daniel Garvey
a202bb466a fp16 fixes for webui (#1571) 2023-06-21 20:24:02 -07:00
Ranvir Singh Virk
07c1e1d712 Adding metal_utils for iree_utils (#1561)
* Adding metal_utils for iree_utils

* Add patch for making compile API work for both MEGABYTE and MiniGPT4 (#1559)

-- It also modifies the mega_test.py script

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

* [SD] Update unet in_channels API and add PIL metadata to spec. (#1560)

* Fix deprecation warning for unet config.

* Include PIL metadata instead of hidden imports in SD spec.

* Fixing iree-metal-target-platform

* adding metal to txt2img pipeline

* Fixing Copyright date

* removing debug prints

* black lint formating

* fixing device dump

---------

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <avarma094@gmail.com>
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-06-21 19:09:03 -07:00
Ranvir Singh Virk
18daec78c8 Added check for python version (#1570)
* Added check for python version

* Update for PYTHON_VERSION_X_Y
2023-06-21 18:56:47 -07:00
Ean Garvey
1a8e2024d6 Exclude non-square sizes from use_tuned on rdna2 (#1568) 2023-06-21 11:36:55 -05:00
AyaanShah2204
d61b6641fb Rest API: Resolved Generator Object not Subscripatable error (#1556) 2023-06-20 19:27:41 -07:00
Phaneesh Barwaria
88cc2423cc Enable Vicuna fp16 cpu (#1562)
* fix second vic mlir gen

* fp16 mlir/vmfb download from shark_tank
2023-06-20 13:43:21 -05:00
Ean Garvey
ccf944c1bd Enable tuner for upscaler unet. (#1563) 2023-06-20 13:40:13 -05:00
Ean Garvey
0def74f520 [SD] Update unet in_channels API and add PIL metadata to spec. (#1560)
* Fix deprecation warning for unet config.

* Include PIL metadata instead of hidden imports in SD spec.
2023-06-20 10:26:36 -07:00
Abhishek Varma
3fb72e192e Add patch for making compile API work for both MEGABYTE and MiniGPT4 (#1559)
-- It also modifies the mega_test.py script

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-06-20 10:04:17 -07:00
Vivek Khandelwal
855435ee24 Fix for the user input for Falcon pipeline 2023-06-20 18:09:32 +05:30
Elias Joseph
6f9f868fc0 fixed a bug where designating device for vicuna didn't work 2023-06-20 17:09:32 +05:30
powderluv
fb865f1b99 Move to checkout@v3
This will break Windows again but we have to fix it up since the old node.js is now deprecated.
2023-06-19 18:44:36 -07:00
rprasad2
3e5c50f07b changes for tuning (#1542)
* Add tuning sizes for rdna3
2023-06-19 15:29:08 -05:00
powderluv
a544f30a8f Move mega to the shark examples (#1555) 2023-06-19 11:10:51 -07:00
Abhishek Varma
1fe56d460a [MEGABYTE] Add script to compile MEGABYTE through SHARK (#1553)
-- Usage: `python mega_test.py`.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-06-19 11:00:35 -07:00
Vivek Khandelwal
fafd713141 Minor change to falcon pipeline 2023-06-19 22:36:32 +05:30
Vivek Khandelwal
015d0132c3 Modify falcon pipeline to add fp16 support (#1551) 2023-06-19 09:57:13 -07:00
powderluv
20ddd96ef7 unpin diffusers (#1550) 2023-06-18 13:45:55 -07:00
powderluv
ee33cfd2d1 Add PIL in main index.py (#1549)
* Add PIL in main index.py

This is to ensure pyinstaller picks it up

* Update index.py
2023-06-18 11:51:44 -07:00
Stefan Kapusniak
a3cba21d5b Fix load of unet512 vmfb fail on get of iree opts (#1546)
* Change retrieval of Iree options used when loading an existing
unet512 vmfb to look up the "unet" options rather than attempt to
find a non-existent set of options for "unet512"

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-06-18 06:42:20 -07:00
Stefan Kapusniak
a7b6ec4095 Fix unet512 always being used when --max_length=77 (#1547)
* Switches a few places in the SD pipeline where an assumption of
max_length=64 was being made, to using the actual max_length
as passed into the pipeline. This prevents unet512 always being
used and producing different images than previously when
--max_length=77
2023-06-18 06:41:25 -07:00
Ean Garvey
d80b087d95 Add PIL hidden imports to sd spec. (#1544)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-06-18 06:39:08 -07:00
Stefan Kapusniak
297a209608 Remove workarounds for gradio tempfile bugs (#1548) 2023-06-17 19:50:36 -07:00
gpetters94
b204113563 Add UNet512 (#1504)
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
2023-06-17 03:46:25 -04:00
Chi_Liu
f60ab1f4fa Add Deberta to stablehlo in shark tank (#1545) 2023-06-16 13:24:44 -07:00
45 changed files with 1773 additions and 339 deletions

View File

@@ -35,6 +35,8 @@ jobs:
include:
- os: ubuntu-latest
suite: lint
- os: MacStudio
suite: metal
exclude:
- os: ubuntu-latest
suite: vulkan
@@ -46,6 +48,8 @@ jobs:
suite: cuda
- os: MacStudio
suite: cpu
- os: MacStudio
suite: vulkan
- os: icelake
suite: vulkan
- os: icelake
@@ -61,7 +65,6 @@ jobs:
steps:
- uses: actions/checkout@v3
if: matrix.os != '7950x'
- name: Set Environment Variables
if: matrix.os != '7950x'
@@ -84,9 +87,6 @@ jobs:
#cache-dependency-path: |
# **/requirements-importer.txt
# **/requirements.txt
- uses: actions/checkout@v2
if: matrix.os == '7950x'
- name: Install dependencies
if: matrix.suite == 'lint'
@@ -129,15 +129,14 @@ jobs:
# python build_tools/stable_diffusion_testing.py --device=cuda
- name: Validate Vulkan Models (MacOS)
if: matrix.suite == 'vulkan' && matrix.os == 'MacStudio'
if: matrix.suite == 'metal' && matrix.os == 'MacStudio'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
export DYLD_LIBRARY_PATH=/usr/local/lib/
echo $PATH
pip list | grep -E "torch|iree"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k vulkan
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k metal
- name: Validate Vulkan Models (a100)
if: matrix.suite == 'vulkan' && matrix.os == 'a100'

4
.gitignore vendored
View File

@@ -2,6 +2,8 @@
__pycache__/
*.py[cod]
*$py.class
*.mlir
*.vmfb
# C extensions
*.so
@@ -157,7 +159,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/
# vscode related
.vscode

View File

@@ -3,6 +3,10 @@ from pathlib import Path
from apps.language_models.src.pipelines import vicuna_pipeline as vp
from apps.language_models.src.pipelines import vicuna_sharded_pipeline as vsp
import torch
import json
if __name__ == "__main__":
import gc
parser = argparse.ArgumentParser(
@@ -55,35 +59,38 @@ parser.add_argument(
help="Run model in cli mode",
)
parser.add_argument(
"--config",
default=None,
help="configuration file",
)
if __name__ == "__main__":
args, unknown = parser.parse_known_args()
vic = None
if not args.sharded:
first_vic_mlir_path = (
Path(f"first_vicuna_{args.precision}.mlir")
None
if args.first_vicuna_mlir_path is None
else Path(args.first_vicuna_mlir_path)
)
second_vic_mlir_path = (
Path(f"second_vicuna_{args.precision}.mlir")
None
if args.second_vicuna_mlir_path is None
else Path(args.second_vicuna_mlir_path)
)
first_vic_vmfb_path = (
Path(
f"first_vicuna_{args.precision}_{args.device.replace('://', '_')}.vmfb"
)
None
if args.first_vicuna_vmfb_path is None
else Path(args.first_vicuna_vmfb_path)
)
second_vic_vmfb_path = (
Path(
f"second_vicuna_{args.precision}_{args.device.replace('://', '_')}.vmfb"
)
None
if args.second_vicuna_vmfb_path is None
else Path(args.second_vicuna_vmfb_path)
)
vic = vp.Vicuna(
"vicuna",
device=args.device,
@@ -95,16 +102,21 @@ if __name__ == "__main__":
load_mlir_from_shark_tank=args.load_mlir_from_shark_tank,
)
else:
if args.config is not None:
config_file = open(args.config)
config_json = json.load(config_file)
config_file.close()
else:
config_json = None
vic = vsp.Vicuna(
"vicuna",
device=args.device,
precision=args.precision,
config_json=config_json,
)
prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
prologue_prompt = "ASSISTANT:\n"
import gc
while True:
# TODO: Add break condition from user input
user_prompt = input("User: ")

View File

@@ -145,7 +145,7 @@ class CompiledSecondVicunaLayer(torch.nn.Module):
class ShardedVicunaModel(torch.nn.Module):
def __init__(self, model, layers0, layers1):
def __init__(self, model, layers0, layers1, lmhead, embedding, norm):
super().__init__()
self.model = model
assert len(layers0) == len(model.model.layers)
@@ -154,6 +154,12 @@ class ShardedVicunaModel(torch.nn.Module):
self.model.model.config.output_attentions = False
self.layers0 = layers0
self.layers1 = layers1
self.norm = norm
self.embedding = embedding
self.lmhead = lmhead
self.model.model.norm = self.norm
self.model.model.embed_tokens = self.embedding
self.model.lm_head = self.lmhead
def forward(
self,
@@ -176,3 +182,69 @@ class ShardedVicunaModel(torch.nn.Module):
attention_mask=attention_mask,
past_key_values=past_key_values,
)
class LMHead(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, hidden_states):
output = self.model(hidden_states)
return output
class LMHeadCompiled(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(self, hidden_states):
hidden_states = hidden_states.detach()
output = self.model("forward", (hidden_states,))
output = torch.tensor(output)
return output
class VicunaNorm(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, hidden_states):
output = self.model(hidden_states)
return output
class VicunaNormCompiled(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(self, hidden_states):
hidden_states.detach()
output = self.model("forward", (hidden_states,))
output = torch.tensor(output)
return output
class VicunaEmbedding(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, input_ids):
output = self.model(input_ids)
return output
class VicunaEmbeddingCompiled(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(self, input_ids):
input_ids.detach()
output = self.model("forward", (input_ids,))
output = torch.tensor(output)
return output

View File

@@ -28,8 +28,9 @@ parser = argparse.ArgumentParser(
description="runs a falcon model",
)
parser.add_argument("--falcon_variant_to_use", default="7b", help="7b, 40b")
parser.add_argument(
"--precision", "-p", default="fp32", help="fp32, fp16, int8, int4"
"--precision", "-p", default="fp16", help="fp32, fp16, int8, int4"
)
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
parser.add_argument(
@@ -40,7 +41,12 @@ parser.add_argument(
default=None,
help="path to falcon's mlir file",
)
parser.add_argument(
"--use_precompiled_model",
default=True,
action=argparse.BooleanOptionalAction,
help="use the precompiled vmfb",
)
parser.add_argument(
"--load_mlir_from_shark_tank",
default=False,
@@ -59,12 +65,12 @@ class Falcon(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path="tiiuae/falcon-7b-instruct",
hf_model_path,
max_num_tokens=150,
device="cuda",
precision="fp32",
falcon_mlir_path=Path("falcon.mlir"),
falcon_vmfb_path=Path("falcon.vmfb"),
falcon_mlir_path=None,
falcon_vmfb_path=None,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_padding_length = 100
@@ -85,7 +91,7 @@ class Falcon(SharkLLMBase):
return tokenizer
def get_src_model(self):
print("Loading src model")
print("Loading src model: ", self.model_name)
kwargs = {"torch_dtype": torch.float, "trust_remote_code": True}
falcon_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
@@ -93,9 +99,26 @@ class Falcon(SharkLLMBase):
return falcon_model
def compile_falcon(self):
vmfb = get_vmfb_from_path(self.falcon_vmfb_path, self.device, "linalg")
if vmfb is not None:
return vmfb
if args.use_precompiled_model:
if not self.falcon_vmfb_path.exists():
# Downloading VMFB from shark_tank
download_public_file(
"gs://shark_tank/falcon/"
+ "falcon_"
+ args.falcon_variant_to_use
+ "_"
+ self.precision
+ "_"
+ self.device
+ ".vmfb",
self.falcon_vmfb_path.absolute(),
single_file=True,
)
vmfb = get_vmfb_from_path(
self.falcon_vmfb_path, self.device, "linalg"
)
if vmfb is not None:
return vmfb
print(
f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}. Trying to work with"
@@ -106,27 +129,26 @@ class Falcon(SharkLLMBase):
bytecode = f.read()
else:
mlir_generated = False
if args.load_mlir_from_shark_tank:
if self.precision == "fp32":
# download MLIR from shark_tank for fp32
download_public_file(
"gs://shark_tank/falcon/7b/cuda/falcon.mlir",
self.falcon_mlir_path.absolute(),
single_file=True,
)
if self.falcon_mlir_path.exists():
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
f"MLIR not found at {self.falcon_mlir_path.absolute()}"
" after downloading! Please check path and try again"
)
else:
print(
"Only fp32 mlir added to tank, generating mlir on device."
)
# Downloading MLIR from shark_tank
download_public_file(
"gs://shark_tank/falcon/"
+ "falcon_"
+ args.falcon_variant_to_use
+ "_"
+ self.precision
+ ".mlir",
self.falcon_mlir_path.absolute(),
single_file=True,
)
if self.falcon_mlir_path.exists():
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
f"MLIR not found at {self.falcon_mlir_path.absolute()}"
" after downloading! Please check path and try again"
)
if not mlir_generated:
compilation_input_ids = torch.randint(
@@ -184,6 +206,7 @@ class Falcon(SharkLLMBase):
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
"--iree-spirv-index-bits=64",
],
)
print("Saved falcon vmfb at ", str(path))
@@ -192,17 +215,6 @@ class Falcon(SharkLLMBase):
return shark_module
def compile(self):
if (
not self.falcon_vmfb_path.exists()
and self.device == "cuda"
and self.precision == "fp32"
):
download_public_file(
"gs://shark_tank/falcon/7b/cuda/falcon.vmfb",
self.falcon_vmfb_path.absolute(),
single_file=True,
)
falcon_shark_model = self.compile_falcon()
return falcon_shark_model
@@ -375,6 +387,8 @@ class Falcon(SharkLLMBase):
(model_inputs["input_ids"], model_inputs["attention_mask"]),
)
)
if self.precision == "fp16":
outputs = outputs.to(dtype=torch.float32)
next_token_logits = outputs
# pre-process distribution
@@ -428,18 +442,35 @@ if __name__ == "__main__":
args = parser.parse_args()
falcon_mlir_path = (
Path("falcon.mlir")
Path(
"falcon_"
+ args.falcon_variant_to_use
+ "_"
+ args.precision
+ ".mlir"
)
if args.falcon_mlir_path is None
else Path(args.falcon_mlir_path)
)
falcon_vmfb_path = (
Path("falcon.vmfb")
Path(
"falcon_"
+ args.falcon_variant_to_use
+ "_"
+ args.precision
+ "_"
+ args.device
+ ".vmfb"
)
if args.falcon_vmfb_path is None
else Path(args.falcon_vmfb_path)
)
falcon = Falcon(
"falcon",
"falcon_" + args.falcon_variant_to_use,
hf_model_path="tiiuae/falcon-"
+ args.falcon_variant_to_use
+ "-instruct",
device=args.device,
precision=args.precision,
falcon_mlir_path=falcon_mlir_path,
@@ -451,11 +482,16 @@ if __name__ == "__main__":
default_prompt_text = "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:"
continue_execution = True
print("\n-----\nScript executing for the following config: \n")
print("Falcon Model: ", falcon.model_name)
print("Precision: ", args.precision)
print("Device: ", args.device)
while continue_execution:
use_default_prompt = input(
"\nDo you wish to use the default prompt text? True or False?: "
"\nDo you wish to use the default prompt text? Y/N ?: "
)
if use_default_prompt:
if use_default_prompt in ["Y", "y"]:
prompt = default_prompt_text
else:
prompt = input("Please enter the prompt text: ")
@@ -469,5 +505,8 @@ if __name__ == "__main__":
res_str,
)
continue_execution = input(
"\nDo you wish to run script one more time? True or False?: "
"\nDo you wish to run script one more time? Y/N ?: "
)
continue_execution = (
True if continue_execution in ["Y", "y"] else False
)

View File

@@ -10,7 +10,7 @@ from apps.language_models.utils import (
from io import BytesIO
from pathlib import Path
from shark.shark_downloader import download_public_file
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, get_f16_inputs
from shark.shark_inference import SharkInference
from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -28,23 +28,48 @@ class Vicuna(SharkLLMBase):
max_num_tokens=512,
device="cuda",
precision="fp32",
first_vicuna_mlir_path=Path("first_vicuna.mlir"),
second_vicuna_mlir_path=Path("second_vicuna.mlir"),
first_vicuna_vmfb_path=Path("first_vicuna.vmfb"),
second_vicuna_vmfb_path=Path("second_vicuna.vmfb"),
first_vicuna_mlir_path=None,
second_vicuna_mlir_path=None,
first_vicuna_vmfb_path=None,
second_vicuna_vmfb_path=None,
load_mlir_from_shark_tank=True,
low_device_memory=False,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_length = 256
self.device = device
if precision in ["int4", "int8"]:
print("int4 and int8 are not supported yet, using fp32")
precision = "fp32"
self.precision = precision
self.first_vicuna_vmfb_path = first_vicuna_vmfb_path
self.second_vicuna_vmfb_path = second_vicuna_vmfb_path
self.first_vicuna_mlir_path = first_vicuna_mlir_path
self.second_vicuna_mlir_path = second_vicuna_mlir_path
self.load_mlir_from_shark_tank = load_mlir_from_shark_tank
self.low_device_memory = low_device_memory
self.first_vic = None
self.second_vic = None
if self.first_vicuna_mlir_path == None:
self.first_vicuna_mlir_path = self.get_model_path()
if self.second_vicuna_mlir_path == None:
self.second_vicuna_mlir_path = self.get_model_path("second")
if self.first_vicuna_vmfb_path == None:
self.first_vicuna_vmfb_path = self.get_model_path(suffix="vmfb")
if self.second_vicuna_vmfb_path == None:
self.second_vicuna_vmfb_path = self.get_model_path(
"second", "vmfb"
)
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
self.load_mlir_from_shark_tank = load_mlir_from_shark_tank
def get_model_path(self, model_number="first", suffix="mlir"):
safe_device = "_".join(self.device.split("-"))
if suffix == "mlir":
return Path(f"{model_number}_vicuna_{self.precision}.{suffix}")
return Path(
f"{model_number}_vicuna_{self.precision}_{safe_device}.{suffix}"
)
def get_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(
@@ -69,7 +94,7 @@ class Vicuna(SharkLLMBase):
# Compilation path needs some more work before it is functional
print(
f"[DEBUG] vmfb not found at {self.first_vicuna_vmfb_path.absolute()}. Trying to work with"
f"[DEBUG] vmfb not found at {self.first_vicuna_vmfb_path.absolute()}. Trying to work with\n"
f"[DEBUG] mlir path { self.first_vicuna_mlir_path} {'exists' if self.first_vicuna_mlir_path.exists() else 'does not exist'}"
)
if self.first_vicuna_mlir_path.exists():
@@ -78,10 +103,10 @@ class Vicuna(SharkLLMBase):
else:
mlir_generated = False
if self.load_mlir_from_shark_tank:
if self.precision == "fp32":
# download MLIR from shark_tank for fp32
if self.precision in ["fp32", "fp16"]:
# download MLIR from shark_tank for fp32/fp16
download_public_file(
"gs://shark_tank/vicuna/unsharded/mlir/first_vicuna.mlir",
f"gs://shark_tank/vicuna/unsharded/mlir/{self.first_vicuna_mlir_path.name}",
self.first_vicuna_mlir_path.absolute(),
single_file=True,
)
@@ -96,7 +121,7 @@ class Vicuna(SharkLLMBase):
)
else:
print(
"Only fp32 mlir added to tank, generating mlir on device."
f"Only fp32 and fp16 mlir added to tank, generating {self.precision} mlir on device."
)
if not mlir_generated:
@@ -220,10 +245,10 @@ class Vicuna(SharkLLMBase):
else:
mlir_generated = False
if self.load_mlir_from_shark_tank:
if self.precision == "fp32":
# download MLIR from shark_tank for fp32
if self.precision in ["fp32", "fp16"]:
# download MLIR from shark_tank for fp32/fp16
download_public_file(
"gs://shark_tank/vicuna/unsharded/mlir/second_vicuna.mlir",
f"gs://shark_tank/vicuna/unsharded/mlir/{self.second_vicuna_mlir_path.name}",
self.second_vicuna_mlir_path.absolute(),
single_file=True,
)
@@ -253,9 +278,15 @@ class Vicuna(SharkLLMBase):
model,
secondVicunaCompileInput,
is_f16=self.precision == "fp16",
f16_input_mask=[False, False],
f16_input_mask=[False] + [True] * 64,
mlir_type="torchscript",
)
if self.precision == "fp16":
secondVicunaCompileInput = get_f16_inputs(
secondVicunaCompileInput,
True,
f16_input_mask=[False] + [True] * 64,
)
secondVicunaCompileInput = list(secondVicunaCompileInput)
for i in range(len(secondVicunaCompileInput)):
if i != 0:
@@ -307,7 +338,7 @@ class Vicuna(SharkLLMBase):
if "%c19_i64 = arith.constant 19 : i64" in line:
new_lines.append("%c2 = arith.constant 2 : index")
new_lines.append(
"%dim_4_int = tensor.dim %arg1, %c2 : tensor<1x32x?x128xf32>"
f"%dim_4_int = tensor.dim %arg1, %c2 : tensor<1x32x?x128x{'f16' if self.precision == 'fp16' else 'f32'}>"
)
new_lines.append(
"%dim_i64 = arith.index_cast %dim_4_int : index to i64"
@@ -365,14 +396,19 @@ class Vicuna(SharkLLMBase):
# download vmfbs for A100
if (
not self.first_vicuna_vmfb_path.exists()
and self.device == "cuda"
and self.precision == "fp32"
and self.device in ["cuda", "cpu"]
and self.precision in ["fp32", "fp16"]
):
download_public_file(
"gs://shark_tank/vicuna/unsharded/first_vicuna.vmfb",
self.first_vicuna_vmfb_path.absolute(),
single_file=True,
)
# combinations that are still in the works
if not (self.device == "cuda" and self.precision == "fp16"):
# Will generate vmfb on device
pass
else:
download_public_file(
f"gs://shark_tank/vicuna/unsharded/vmfb/{self.first_vicuna_vmfb_path.name}",
self.first_vicuna_vmfb_path.absolute(),
single_file=True,
)
else:
# get first vic
# TODO: Remove after testing to avoid memory overload
@@ -380,26 +416,25 @@ class Vicuna(SharkLLMBase):
pass
if (
not self.second_vicuna_vmfb_path.exists()
and self.device == "cuda"
and self.precision == "fp32"
and self.device in ["cuda", "cpu"]
and self.precision in ["fp32", "fp16"]
):
download_public_file(
"gs://shark_tank/vicuna/unsharded/second_vicuna.vmfb",
self.second_vicuna_vmfb_path.absolute(),
single_file=True,
)
# combinations that are still in the works
if not (self.device == "cuda" and self.precision == "fp16"):
# Will generate vmfb on device
pass
else:
download_public_file(
f"gs://shark_tank/vicuna/unsharded/vmfb/{self.second_vicuna_vmfb_path.name}",
self.second_vicuna_vmfb_path.absolute(),
single_file=True,
)
else:
# get second vic
# TODO: Remove after testing to avoid memory overload
# svic_shark_model = self.compile_second_vicuna()
pass
# get first vic
# fvic_shark_model = self.compile_first_vicuna()
# get second vic
# svic_shark_model = self.compile_second_vicuna()
# return tuple of shark_modules
# return fvic_shark_model, svic_shark_model
return None
# return tuple of shark_modules once mem is supported
# return fvic_shark_model, svic_shark_model
@@ -408,12 +443,19 @@ class Vicuna(SharkLLMBase):
# TODO: refactor for cleaner integration
import gc
if not self.low_device_memory:
if self.first_vic == None:
self.first_vic = self.compile_first_vicuna()
if self.second_vic == None:
self.second_vic = self.compile_second_vicuna()
res = []
res_tokens = []
params = {
"prompt": prompt,
"is_first": True,
"fv": self.compile_first_vicuna(),
"fv": self.compile_first_vicuna()
if self.first_vic == None
else self.first_vic,
}
generated_token_op = self.generate_new_token(params=params)
@@ -429,18 +471,20 @@ class Vicuna(SharkLLMBase):
print(f"Assistant: {detok}", end=" ", flush=True)
# Clear First Vic from Memory (main and cuda)
del params
torch.cuda.empty_cache()
gc.collect()
if self.low_device_memory:
del params
torch.cuda.empty_cache()
gc.collect()
sec_vic = self.compile_second_vicuna()
for _ in range(self.max_num_tokens - 2):
params = {
"prompt": None,
"is_first": False,
"logits": logits,
"pkv": pkv,
"sv": sec_vic,
"sv": self.compile_second_vicuna()
if self.second_vic == None
else self.second_vic,
}
generated_token_op = self.generate_new_token(params=params)
@@ -461,9 +505,10 @@ class Vicuna(SharkLLMBase):
res.append(detok)
if cli:
print(f"{detok}", end=" ", flush=True)
del sec_vic, pkv, logits
torch.cuda.empty_cache()
gc.collect()
if self.device == "cuda":
del sec_vic, pkv, logits
torch.cuda.empty_cache()
gc.collect()
for i in range(len(res_tokens)):
if type(res_tokens[i]) != int:

View File

@@ -4,6 +4,12 @@ from apps.language_models.src.model_wrappers.vicuna_sharded_model import (
CompiledFirstVicunaLayer,
CompiledSecondVicunaLayer,
ShardedVicunaModel,
LMHead,
LMHeadCompiled,
VicunaEmbedding,
VicunaEmbeddingCompiled,
VicunaNorm,
VicunaNormCompiled,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from shark.shark_importer import import_with_fx
@@ -19,9 +25,11 @@ import re
import torch
import torch_mlir
import os
import json
class Vicuna(SharkLLMBase):
# Class representing Sharded Vicuna Model
def __init__(
self,
model_name,
@@ -29,21 +37,25 @@ class Vicuna(SharkLLMBase):
max_num_tokens=512,
device="cuda",
precision="fp32",
config_json=None,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_length = 256
self.device = device
self.precision = precision
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
self.config = config_json
self.shark_model = self.compile(device=device)
def get_tokenizer(self):
# Retrieve the tokenizer from Huggingface
tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path, use_fast=False
)
return tokenizer
def get_src_model(self):
# Retrieve the torch model from Huggingface
kwargs = {"torch_dtype": torch.float}
vicuna_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
@@ -51,6 +63,8 @@ class Vicuna(SharkLLMBase):
return vicuna_model
def write_in_dynamic_inputs0(self, module, dynamic_input_size):
# Current solution for ensuring mlir files support dynamic inputs
# TODO find a more elegant way to implement this
new_lines = []
for line in module.splitlines():
line = re.sub(f"{dynamic_input_size}x", "?x", line)
@@ -107,6 +121,7 @@ class Vicuna(SharkLLMBase):
past_key_value0=None,
past_key_value1=None,
):
# Compile a hidden decoder layer of vicuna
if past_key_value0 is None and past_key_value1 is None:
model_inputs = (hidden_states, attention_mask, position_ids)
else:
@@ -126,7 +141,154 @@ class Vicuna(SharkLLMBase):
)
return mlir_bytecode
def compile_to_vmfb(self, inputs, layers, is_first=True):
def get_device_index(self, layer_string):
# Get the device index from the config file
# In the event that different device indices are assigned to
# different parts of a layer, a majority vote will be taken and
# everything will be run on the most commonly used device
if self.config is None:
return None
idx_votes = {}
for key in self.config.keys():
if re.search(layer_string, key):
if int(self.config[key]["gpu"]) in idx_votes.keys():
idx_votes[int(self.config[key]["gpu"])] += 1
else:
idx_votes[int(self.config[key]["gpu"])] = 1
device_idx = max(idx_votes, key=idx_votes.get)
return device_idx
def compile_lmhead(
self, lmh, hidden_states, device="cpu", device_idx=None
):
# compile the lm head of the vicuna model
# This can be used for both first and second vicuna, so only needs to be run once
mlir_path = Path(f"lmhead.mlir")
vmfb_path = Path(f"lmhead.vmfb")
if mlir_path.exists():
f_ = open(mlir_path, "rb")
bytecode = f_.read()
f_.close()
else:
hidden_states = torch_mlir.TensorPlaceholder.like(
hidden_states, dynamic_axes=[1]
)
module = torch_mlir.compile(
lmh,
(hidden_states,),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
f_ = open(mlir_path, "wb")
f_.write(bytecode)
f_.close()
shark_module = SharkInference(
bytecode,
device=device,
mlir_dialect="tm_tensor",
device_idx=device_idx,
)
if vmfb_path.exists():
shark_module.load_module(vmfb_path)
else:
shark_module.save_module(module_name="lmhead")
shark_module.load_module(vmfb_path)
compiled_module = LMHeadCompiled(shark_module)
return compiled_module
def compile_norm(self, fvn, hidden_states, device="cpu", device_idx=None):
# compile the normalization layer of the vicuna model
# This can be used for both first and second vicuna, so only needs to be run once
mlir_path = Path(f"norm.mlir")
vmfb_path = Path(f"norm.vmfb")
if mlir_path.exists():
f_ = open(mlir_path, "rb")
bytecode = f_.read()
f_.close()
else:
hidden_states = torch_mlir.TensorPlaceholder.like(
hidden_states, dynamic_axes=[1]
)
module = torch_mlir.compile(
fvn,
(hidden_states,),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
f_ = open(mlir_path, "wb")
f_.write(bytecode)
f_.close()
shark_module = SharkInference(
bytecode,
device=device,
mlir_dialect="tm_tensor",
device_idx=device_idx,
)
if vmfb_path.exists():
shark_module.load_module(vmfb_path)
else:
shark_module.save_module(module_name="norm")
shark_module.load_module(vmfb_path)
compiled_module = VicunaNormCompiled(shark_module)
return compiled_module
def compile_embedding(self, fve, input_ids, device="cpu", device_idx=None):
# compile the embedding layer of the vicuna model
# This can be used for both first and second vicuna, so only needs to be run once
mlir_path = Path(f"embedding.mlir")
vmfb_path = Path(f"embedding.vmfb")
if mlir_path.exists():
f_ = open(mlir_path, "rb")
bytecode = f_.read()
f_.close()
else:
input_ids = torch_mlir.TensorPlaceholder.like(
input_ids, dynamic_axes=[1]
)
module = torch_mlir.compile(
fve,
(input_ids,),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
f_ = open(mlir_path, "wb")
f_.write(bytecode)
f_.close()
shark_module = SharkInference(
bytecode,
device=device,
mlir_dialect="tm_tensor",
device_idx=device_idx,
)
if vmfb_path.exists():
shark_module.load_module(vmfb_path)
else:
shark_module.save_module(module_name="embedding")
shark_module.load_module(vmfb_path)
compiled_module = VicunaEmbeddingCompiled(shark_module)
return compiled_module
def compile_to_vmfb(self, inputs, layers, device="cpu", is_first=True):
# compile all layers for vmfb
# this needs to be run seperatley for first and second vicuna
mlirs, modules = [], []
for idx, layer in tqdm(enumerate(layers), desc="Getting mlirs"):
if is_first:
@@ -198,10 +360,6 @@ class Vicuna(SharkLLMBase):
verbose=False,
)
# bytecode_stream = BytesIO()
# module.operation.write_bytecode(bytecode_stream)
# bytecode = bytecode_stream.getvalue()
if is_first:
module = self.write_in_dynamic_inputs0(str(module), 137)
bytecode = module.encode("UTF-8")
@@ -210,19 +368,6 @@ class Vicuna(SharkLLMBase):
else:
module = self.write_in_dynamic_inputs1(str(module), 138)
if idx in [0, 5, 6, 7]:
module_str = module
module_str = module_str.splitlines()
new_lines = []
for line in module_str:
if len(line) < 1000:
new_lines.append(line)
else:
new_lines.append(line[:999])
module_str = "\n".join(new_lines)
f1_ = open(f"{idx}_1_test.mlir", "w+")
f1_.write(module_str)
f1_.close()
bytecode = module.encode("UTF-8")
bytecode_stream = BytesIO(bytecode)
@@ -236,20 +381,27 @@ class Vicuna(SharkLLMBase):
for idx, layer in tqdm(enumerate(layers), desc="compiling modules"):
if is_first:
vmfb_path = Path(f"{idx}_0.vmfb")
if idx < 25:
device = "cpu"
else:
device = "cpu"
if vmfb_path.exists():
# print(f"Found layer {idx} vmfb")
device_idx = self.get_device_index(
f"first_vicuna.model.model.layers.{idx}[\s.$]"
)
module = SharkInference(
None, device=device, mlir_dialect="tm_tensor"
None,
device=device,
device_idx=device_idx,
mlir_dialect="tm_tensor",
)
module.load_module(vmfb_path)
else:
print(f"Compiling layer {idx} vmfb")
device_idx = self.get_device_index(
f"first_vicuna.model.model.layers.{idx}[\s.$]"
)
module = SharkInference(
mlirs[idx], device=device, mlir_dialect="tm_tensor"
mlirs[idx],
device=device,
device_idx=device_idx,
mlir_dialect="tm_tensor",
)
module.save_module(
module_name=f"{idx}_0",
@@ -264,20 +416,28 @@ class Vicuna(SharkLLMBase):
modules.append(module)
else:
vmfb_path = Path(f"{idx}_1.vmfb")
if idx < 25:
device = "cpu"
else:
device = "cpu"
if vmfb_path.exists():
# print(f"Found layer {idx} vmfb")
device_idx = self.get_device_index(
f"second_vicuna.model.model.layers.{idx}[\s.$]"
)
module = SharkInference(
None, device=device, mlir_dialect="tm_tensor"
None,
device=device,
device_idx=device_idx,
mlir_dialect="tm_tensor",
)
module.load_module(vmfb_path)
else:
print(f"Compiling layer {idx} vmfb")
device_idx = self.get_device_index(
f"second_vicuna.model.model.layers.{idx}[\s.$]"
)
module = SharkInference(
mlirs[idx], device=device, mlir_dialect="tm_tensor"
mlirs[idx],
device=device,
device_idx=device_idx,
mlir_dialect="tm_tensor",
)
module.save_module(
module_name=f"{idx}_1",
@@ -293,7 +453,7 @@ class Vicuna(SharkLLMBase):
return mlirs, modules
def get_sharded_model(self):
def get_sharded_model(self, device="cpu"):
# SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an increadibly hacky proccess
# please don't change it
SAMPLE_INPUT_LEN = 137
@@ -312,11 +472,50 @@ class Vicuna(SharkLLMBase):
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
)
norm = VicunaNorm(vicuna_model.model.norm)
device_idx = self.get_device_index(
r"vicuna\.model\.model\.norm(?:\.|\s|$)"
)
print(device_idx)
norm = self.compile_norm(
norm,
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
device=self.device,
device_idx=device_idx,
)
embeddings = VicunaEmbedding(vicuna_model.model.embed_tokens)
device_idx = self.get_device_index(
r"vicuna\.model\.model\.embed_tokens(?:\.|\s|$)"
)
print(device_idx)
embeddings = self.compile_embedding(
embeddings,
(torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64)),
device=self.device,
device_idx=device_idx,
)
lmhead = LMHead(vicuna_model.lm_head)
device_idx = self.get_device_index(
r"vicuna\.model\.lm_head(?:\.|\s|$)"
)
print(device_idx)
lmhead = self.compile_lmhead(
lmhead,
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
device=self.device,
device_idx=device_idx,
)
layers0 = [
FirstVicunaLayer(layer) for layer in vicuna_model.model.layers
]
_, modules0 = self.compile_to_vmfb(
placeholder_input0, layers0, is_first=True
placeholder_input0,
layers0,
is_first=True,
device=device,
)
shark_layers0 = [CompiledFirstVicunaLayer(m) for m in modules0]
@@ -324,17 +523,22 @@ class Vicuna(SharkLLMBase):
SecondVicunaLayer(layer) for layer in vicuna_model.model.layers
]
_, modules1 = self.compile_to_vmfb(
placeholder_input1, layers1, is_first=False
placeholder_input1, layers1, is_first=False, device=device
)
shark_layers1 = [CompiledSecondVicunaLayer(m) for m in modules1]
sharded_model = ShardedVicunaModel(
vicuna_model, shark_layers0, shark_layers1
vicuna_model,
shark_layers0,
shark_layers1,
lmhead,
embeddings,
norm,
)
return sharded_model
def compile(self):
return self.get_sharded_model()
def compile(self, device="cpu"):
return self.get_sharded_model(device=device)
def generate(self, prompt, cli=False):
# TODO: refactor for cleaner integration

View File

@@ -17,6 +17,10 @@ from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
def load_mlir_module():
if "upscaler" in args.hf_model_id:
is_upscaler = True
else:
is_upscaler = False
sd_model = SharkifyStableDiffusionModel(
args.hf_model_id,
args.ckpt_loc,
@@ -27,6 +31,7 @@ def load_mlir_module():
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
is_upscaler=is_upscaler,
use_tuned=False,
low_cpu_mem_usage=args.low_cpu_mem_usage,
return_mlir=True,

View File

@@ -61,6 +61,7 @@ def main():
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"

View File

@@ -19,6 +19,7 @@ datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += copy_metadata('Pillow')
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('pytorch_lightning')

View File

@@ -163,7 +163,7 @@ class SharkifyStableDiffusionModel:
def get_extended_name_for_all_model(self):
model_name = {}
sub_model_list = ["clip", "unet", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
sub_model_list = ["clip", "unet", "unet512", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
index = 0
for model in sub_model_list:
sub_model = model
@@ -415,7 +415,7 @@ class SharkifyStableDiffusionModel:
)
return shark_cnet, cnet_mlir
def get_unet(self):
def get_unet(self, use_large=False):
class UnetModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
super().__init__()
@@ -426,7 +426,7 @@ class SharkifyStableDiffusionModel:
)
if use_lora != "":
update_lora_weight(self.unet, use_lora, "unet")
self.in_channels = self.unet.in_channels
self.in_channels = self.unet.config.in_channels
self.train(False)
if(args.attention_slicing is not None and args.attention_slicing != "none"):
if(args.attention_slicing.isdigit()):
@@ -452,17 +452,27 @@ class SharkifyStableDiffusionModel:
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
if(use_large):
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
inputs = (inputs[0],
inputs[1],
torch.nn.functional.pad(inputs[2], pad),
inputs[3])
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet512"])
else:
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
input_mask = [True, True, True, False]
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
model_name = "unet512" if use_large else "unet"
shark_unet, unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name["unet"],
extended_model_name=self.model_name[model_name],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
@@ -471,13 +481,13 @@ class SharkifyStableDiffusionModel:
save_dir=save_dir,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="unet",
model_name=model_name,
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_unet, unet_mlir
def get_unet_upscaler(self):
def get_unet_upscaler(self, use_large=False):
class UnetModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
super().__init__()
@@ -502,6 +512,13 @@ class SharkifyStableDiffusionModel:
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
if(use_large):
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
inputs = (inputs[0],
inputs[1],
torch.nn.functional.pad(inputs[2], pad),
inputs[3])
input_mask = [True, True, True, False]
shark_unet, unet_mlir = compile_through_fx(
unet,
@@ -579,16 +596,16 @@ class SharkifyStableDiffusionModel:
vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
return vae_dict
def compile_unet_variants(self, model):
def compile_unet_variants(self, model, use_large=False):
if model == "unet":
if self.is_upscaler:
return self.get_unet_upscaler()
return self.get_unet_upscaler(use_large=use_large)
# TODO: Plug the experimental "int8" support at right place.
elif self.use_quantize == "int8":
from apps.stable_diffusion.src.models.opt_params import get_unet
return get_unet()
else:
return self.get_unet()
return self.get_unet(use_large=use_large)
else:
return self.get_controlled_unet()
@@ -616,7 +633,7 @@ class SharkifyStableDiffusionModel:
except Exception as e:
sys.exit(e)
def unet(self):
def unet(self, use_large=False):
try:
model = "stencil_unet" if self.use_stencil is not None else "unet"
compiled_unet = None
@@ -624,14 +641,14 @@ class SharkifyStableDiffusionModel:
if self.base_model_id != "":
self.inputs["unet"] = self.get_input_info_for(unet_inputs[self.base_model_id])
compiled_unet, unet_mlir = self.compile_unet_variants(model)
compiled_unet, unet_mlir = self.compile_unet_variants(model, use_large=use_large)
else:
for model_id in unet_inputs:
self.base_model_id = model_id
self.inputs["unet"] = self.get_input_info_for(unet_inputs[model_id])
try:
compiled_unet, unet_mlir = self.compile_unet_variants(model)
compiled_unet, unet_mlir = self.compile_unet_variants(model, use_large=use_large)
except Exception as e:
print(e)
print("Retrying with a different base model configuration")

View File

@@ -81,6 +81,7 @@ class Text2ImagePipeline(StableDiffusionPipeline):
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
@@ -112,7 +113,10 @@ class Text2ImagePipeline(StableDiffusionPipeline):
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# guidance scale as a float32 tensor.

View File

@@ -57,6 +57,7 @@ class StableDiffusionPipeline:
self.vae = None
self.text_encoder = None
self.unet = None
self.unet_512 = None
self.model_max_length = 77
self.scheduler = scheduler
# TODO: Implement using logging python utility.
@@ -114,6 +115,24 @@ class StableDiffusionPipeline:
del self.unet
self.unet = None
def load_unet_512(self):
if self.unet_512 is not None:
return
if self.import_mlir or self.use_lora:
self.unet_512 = self.sd_model.unet(use_large=True)
else:
try:
self.unet_512 = get_unet(use_large=True)
except Exception as e:
print(e)
print("download pipeline failed, falling back to import_mlir")
self.unet_512 = self.sd_model.unet(use_large=True)
def unload_unet_512(self):
del self.unet_512
self.unet_512 = None
def load_vae(self):
if self.vae is not None:
return
@@ -203,7 +222,10 @@ class StableDiffusionPipeline:
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.load_unet()
if text_embeddings.shape[1] <= self.model_max_length:
self.load_unet()
else:
self.load_unet_512()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype).detach().numpy()
@@ -222,16 +244,28 @@ class StableDiffusionPipeline:
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
),
send_to_host=False,
)
if text_embeddings.shape[1] <= self.model_max_length:
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
),
send_to_host=False,
)
else:
noise_pred = self.unet_512(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
@@ -254,6 +288,7 @@ class StableDiffusionPipeline:
if self.ondemand:
self.unload_unet()
self.unload_unet_512()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -412,6 +447,11 @@ class StableDiffusionPipeline:
# uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
if text_embeddings.shape[1] > model_max_length:
pad = (0, 0) * (len(text_embeddings.shape) - 2)
pad = pad + (0, 512 - text_embeddings.shape[1])
text_embeddings = torch.nn.functional.pad(text_embeddings, pad)
# SHARK: Report clip inference time
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:

View File

@@ -37,4 +37,5 @@ from apps.stable_diffusion.src.utils.utils import (
get_generation_text_info,
update_lora_weight,
resize_stencil,
_compile_module,
)

View File

@@ -116,7 +116,7 @@ def load_lower_configs(base_model_id=None):
else:
config_name = f"{args.annotation_model}_{args.precision}_{device}_{spec}.json"
else:
if not spec or spec in ["rdna3", "sm_80"]:
if not spec or spec in ["sm_80"]:
if (
version in ["v2_1", "v2_1base"]
and args.height == 768
@@ -125,6 +125,13 @@ def load_lower_configs(base_model_id=None):
config_name = f"{args.annotation_model}_v2_1_768_{args.precision}_{device}.json"
else:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
elif spec in ["rdna3"] and version in [
"v2_1",
"v2_1base",
"v1_4",
"v1_5",
]:
config_name = f"{args.annotation_model}_{version}_{args.max_length}_{args.precision}_{device}_{spec}_{args.width}x{args.height}.json"
elif spec in ["rdna2"] and version in ["v2_1", "v2_1base", "v1_4"]:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}_{args.width}x{args.height}.json"
else:

View File

@@ -108,6 +108,13 @@ p.add_argument(
help="max length of the tokenizer output, options are 64 and 77.",
)
p.add_argument(
"--max_embeddings_multiples",
type=int,
default=5,
help="The max multiple length of prompt embeddings compared to the max output length of text encoder.",
)
p.add_argument(
"--strength",
type=float,
@@ -372,6 +379,13 @@ p.add_argument(
help="Specify target triple for vulkan",
)
p.add_argument(
"--iree_metal_target_platform",
type=str,
default="",
help="Specify target triple for metal",
)
p.add_argument(
"--vulkan_debug_utils",
default=False,

View File

@@ -18,6 +18,7 @@ from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
)
from shark.iree_utils.metal_utils import get_metal_target_triple
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.resources import opt_flags
@@ -47,6 +48,7 @@ def get_vmfb_path_name(model_name):
def _load_vmfb(shark_module, vmfb_path, model, precision):
model = "vae" if "base_vae" in model or "vae_encode" in model else model
model = "unet" if "stencil" in model else model
model = "unet" if "unet512" in model else model
precision = "fp32" if "clip" in model else precision
extra_args = get_opt_flags(model, precision)
shark_module.load_module(vmfb_path, extra_args=extra_args)
@@ -115,6 +117,7 @@ def compile_through_fx(
model_name=None,
precision=None,
return_mlir=False,
device=None,
):
if not return_mlir and model_name is not None:
vmfb_path = get_vmfb_path_name(extended_model_name)
@@ -145,7 +148,10 @@ def compile_through_fx(
if use_tuned:
if "vae" in extended_model_name.split("_")[0]:
args.annotation_model = "vae"
if "unet" in model_name.split("_")[0]:
if (
"unet" in model_name.split("_")[0]
or "unet_512" in model_name.split("_")[0]
):
args.annotation_model = "unet"
mlir_module = sd_model_annotation(
mlir_module, extended_model_name, base_model_id
@@ -153,7 +159,7 @@ def compile_through_fx(
shark_module = SharkInference(
mlir_module,
device=args.device,
device=args.device if device is None else device,
mlir_dialect="tm_tensor",
)
if generate_vmfb:
@@ -269,6 +275,15 @@ def set_init_device_flags():
)
elif "cuda" in args.device:
args.device = "cuda"
elif "metal" in args.device:
device_name, args.device = map_device_to_name_path(args.device)
if not args.iree_metal_target_platform:
triple = get_metal_target_triple(device_name)
if triple is not None:
args.iree_metal_target_platform = triple
print(
f"Found device {device_name}. Using target triple {args.iree_metal_target_platform}."
)
elif "cpu" in args.device:
args.device = "cpu"
@@ -293,13 +308,18 @@ def set_init_device_flags():
if (
args.precision != "fp16"
or args.height not in [512, 768]
or (args.height == 512 and args.width != 512)
or (args.height == 768 and args.width != 768)
or (args.height == 512 and args.width not in [512, 768])
or (args.height == 768 and args.width not in [512, 768])
or args.batch_size != 1
or ("vulkan" not in args.device and "cuda" not in args.device)
):
args.use_tuned = False
elif (
args.height != args.width and "rdna2" in args.iree_vulkan_target_triple
):
args.use_tuned = False
elif base_model_id not in [
"Linaqruf/anything-v3.0",
"dreamlike-art/dreamlike-diffusion-1.0",
@@ -421,9 +441,14 @@ def get_available_devices():
available_devices = []
vulkan_devices = get_devices_by_name("vulkan")
available_devices.extend(vulkan_devices)
metal_devices = get_devices_by_name("metal")
available_devices.extend(metal_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
available_devices.append("device => cpu")
cpu_device = get_devices_by_name("cpu-sync")
available_devices.extend(cpu_device)
cpu_device = get_devices_by_name("cpu-task")
available_devices.extend(cpu_device)
return available_devices
@@ -732,6 +757,14 @@ def save_output_img(output_img, img_seed, extra_info={}):
if args.ckpt_loc:
img_model = Path(os.path.basename(args.ckpt_loc)).stem
img_vae = None
if args.custom_vae:
img_vae = Path(os.path.basename(args.custom_vae)).stem
img_lora = None
if args.use_lora:
img_lora = Path(os.path.basename(args.use_lora)).stem
if args.output_img_format == "jpg":
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
output_img.save(out_img_path, quality=95, subsampling=0)
@@ -742,7 +775,9 @@ def save_output_img(output_img, img_seed, extra_info={}):
if args.write_metadata_to_png:
pngInfo.add_text(
"parameters",
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed}, Size: {args.width}x{args.height}, Model: {img_model}",
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps: {args.steps},"
f"Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed},"
f"Size: {args.width}x{args.height}, Model: {img_model}, VAE: {img_vae}, LoRA: {img_lora}",
)
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
@@ -753,6 +788,9 @@ def save_output_img(output_img, img_seed, extra_info={}):
"Image saved as png instead. Supported formats: png / jpg"
)
# To be as low-impact as possible to the existing CSV format, we append
# "VAE" and "LORA" to the end. However, it does not fit the hierarchy of
# importance for each data point. Something to consider.
new_entry = {
"VARIANT": img_model,
"SCHEDULER": args.scheduler,
@@ -766,6 +804,8 @@ def save_output_img(output_img, img_seed, extra_info={}):
"WIDTH": args.width,
"MAX_LENGTH": args.max_length,
"OUTPUT": out_img_path,
"VAE": img_vae,
"LORA": img_lora,
}
new_entry.update(extra_info)

View File

@@ -1,7 +1,13 @@
from multiprocessing import Process, freeze_support
import os
import sys
import transformers # ensures inclusion in pysintaller exe generation
if sys.platform == "darwin":
# import before IREE to avoid torch-MLIR library issues
import torch_mlir
import shutil
import PIL, transformers # ensures inclusion in pysintaller exe generation
from apps.stable_diffusion.src import args, clear_all
import apps.stable_diffusion.web.utils.global_obj as global_obj
@@ -38,6 +44,7 @@ if __name__ == "__main__":
img2img_api,
upscaler_api,
inpaint_api,
outpaint_api,
)
from fastapi import FastAPI, APIRouter
import uvicorn
@@ -49,23 +56,25 @@ if __name__ == "__main__":
app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"])
app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"])
app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
# app.add_api_route(
# "/sdapi/v1/outpaint", outpaint_api, methods=["post"]
# )
app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"])
app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
app.include_router(APIRouter())
uvicorn.run(app, host="127.0.0.1", port=args.server_port)
sys.exit(0)
import gradio as gr
# Setup to use shark_tmp for gradio's temporary image files and clear any
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.
from apps.stable_diffusion.web.utils.gradio_configs import (
clear_gradio_tmp_imgs_folder,
config_gradio_tmp_imgs_folder,
)
config_gradio_tmp_imgs_folder()
import gradio as gr
# Create custom models folders if they don't exist
from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
# Clear all gradio tmp images from the last session
clear_gradio_tmp_imgs_folder()
# Create custom models folders if they don't exist
create_custom_models_folders()
def resource_path(relative_path):

View File

@@ -340,6 +340,10 @@ def img2img_api(
lora_hf_id="",
ondemand=False,
)
# Converts generator type to subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},

View File

@@ -278,7 +278,7 @@ def inpaint_api(
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
else "stabilityai/stable-diffusion-2-inpainting",
custom_vae="None",
precision="fp16",
device=available_devices[0],
@@ -289,6 +289,10 @@ def inpaint_api(
lora_hf_id="",
ondemand=False,
)
# Converts generator type to subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},

View File

@@ -287,7 +287,7 @@ def outpaint_api(
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
else "stabilityai/stable-diffusion-2-inpainting",
custom_vae="None",
precision="fp16",
device=available_devices[0],
@@ -298,6 +298,10 @@ def outpaint_api(
lora_hf_id="",
ondemand=False,
)
# Convert Generator to Subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},

View File

@@ -9,9 +9,6 @@ from apps.stable_diffusion.src.utils import (
get_generated_imgs_todays_subdir,
)
from apps.stable_diffusion.web.ui.utils import nodlogo_loc
from apps.stable_diffusion.web.utils.gradio_configs import (
gradio_tmp_galleries_folder,
)
from apps.stable_diffusion.web.utils.metadata import displayable_metadata
# -- Functions for file, directory and image info querying
@@ -63,19 +60,6 @@ def output_subdirs() -> list[str]:
return result_paths
# clear zero length temporary files that gradio 3.22.0 buggily creates
# TODO: remove once gradio is upgraded to or past 3.32.0
def clear_zero_length_temps():
zero_length_temps = [
os.path.join(root, file)
for root, dirs, files in os.walk(gradio_tmp_galleries_folder)
for file in files
if os.path.getsize(os.path.join(root, file)) == 0
]
for file in zero_length_temps:
os.remove(file)
# --- Define UI layout for Gradio
with gr.Blocks() as outputgallery_web:
@@ -105,7 +89,6 @@ with gr.Blocks() as outputgallery_web:
visible=False,
show_label=True,
).style(columns=4)
gallery.DEFAULT_TEMP_DIR = gradio_tmp_galleries_folder
with gr.Column(scale=4):
with gr.Box():
@@ -179,7 +162,6 @@ with gr.Blocks() as outputgallery_web:
# --- Event handlers
def on_clear_gallery():
clear_zero_length_temps()
return [
gr.Gallery.update(
value=[],
@@ -247,7 +229,6 @@ with gr.Blocks() as outputgallery_web:
# only update if the current subdir is the most recent one as new images only go there
if subdir_paths[0] == subdir:
clear_zero_length_temps()
new_images = outputgallery_filenames(subdir)
new_label = f"{len(new_images)} images in {os.path.join(output_dir, subdir)} - {status}"

View File

@@ -41,17 +41,21 @@ def chat(curr_system_message, history, model, device, precision):
curr_system_message = start_message_vicuna
if vicuna_model == 0:
first_vic_vmfb_path = Path("first_vicuna.vmfb")
second_vic_vmfb_path = Path("second_vicuna.vmfb")
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device = "vulkan"
else:
print("unrecognized device")
vicuna_model = Vicuna(
"vicuna",
hf_model_path=model,
device=device,
precision=precision,
first_vicuna_vmfb_path=first_vic_vmfb_path,
second_vicuna_vmfb_path=second_vic_vmfb_path,
)
messages = curr_system_message + "".join(
[
@@ -120,9 +124,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
"TheBloke/vicuna-7B-1.1-HF",
],
)
supported_devices = [
device for device in available_devices if "cuda" in device
]
supported_devices = available_devices
enabled = len(supported_devices) > 0
device = gr.Dropdown(
label="Device",
@@ -138,6 +140,8 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
choices=[
"fp16",
"fp32",
"int4",
"int8",
],
visible=True,
)

View File

@@ -34,6 +34,7 @@ from apps.stable_diffusion.src.utils import (
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_iree_metal_target_platform = args.iree_metal_target_platform
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
@@ -137,6 +138,7 @@ def txt2img_inf(
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.iree_metal_target_platform = init_iree_metal_target_platform
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
args.img_path = None
@@ -193,6 +195,7 @@ def txt2img_inf(
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
seeds.append(img_seed)
total_time = time.time() - start_time
@@ -262,6 +265,10 @@ def txt2img_api(
lora_hf_id="",
ondemand=False,
)
# Convert Generator to Subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
@@ -298,7 +305,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
)
txt2img_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
placeholder="Select 'None' in the dropdown on the left and enter model ID here",
value="",
label="HuggingFace Model ID or Civitai model download URL",
lines=3,
@@ -550,6 +557,9 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
height,
txt2img_custom_model,
txt2img_hf_model_id,
lora_weights,
lora_hf_id,
custom_vae,
],
outputs=[
txt2img_png_info_img,
@@ -563,5 +573,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
height,
txt2img_custom_model,
txt2img_hf_model_id,
lora_weights,
lora_hf_id,
custom_vae,
],
)

View File

@@ -299,6 +299,9 @@ def upscaler_api(
lora_hf_id="",
ondemand=False,
)
# Converts generator type to subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},

View File

@@ -1,60 +1,54 @@
import os
import shutil
import tempfile
import gradio
from time import time
gradio_tmp_imgs_folder = os.path.join(os.getcwd(), "shark_tmp/")
gradio_tmp_galleries_folder = os.path.join(gradio_tmp_imgs_folder, "galleries")
shark_tmp = os.path.join(os.getcwd(), "shark_tmp/")
# Clear all gradio tmp images
def clear_gradio_tmp_imgs_folder():
if not os.path.exists(gradio_tmp_imgs_folder):
return
def config_gradio_tmp_imgs_folder():
# create shark_tmp if it does not exist
if not os.path.exists(shark_tmp):
os.mkdir(shark_tmp)
# tell gradio to use a directory under shark_tmp for its temporary
# image files unless somewhere else has been set
if "GRADIO_TEMP_DIR" not in os.environ:
os.environ["GRADIO_TEMP_DIR"] = os.path.join(shark_tmp, "gradio")
# clear all gradio tmp files created by generation galleries
print(
"Clearing gradio temporary image files from a prior run. This may take some time..."
f"gradio temporary image cache located at {os.environ['GRADIO_TEMP_DIR']}. "
+ "You may change this by setting the GRADIO_TEMP_DIR environment variable."
)
image_files = [
filename
for filename in os.listdir(gradio_tmp_imgs_folder)
if os.path.isfile(os.path.join(gradio_tmp_imgs_folder, filename))
and filename.startswith("tmp")
and filename.endswith(".png")
]
if len(image_files) > 0:
# Clear all gradio tmp images from the last session
if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
cleanup_start = time()
for filename in image_files:
os.remove(gradio_tmp_imgs_folder + filename)
print(
f"Clearing generation temporary image files took {time() - cleanup_start:4f} seconds"
"Clearing gradio UI temporary image files from a prior run. This may take some time..."
)
else:
print("no generation temporary files to clear")
# Clear all gradio tmp files created by output galleries
if os.path.exists(gradio_tmp_galleries_folder):
cleanup_start = time()
shutil.rmtree(gradio_tmp_galleries_folder, ignore_errors=True)
shutil.rmtree(os.environ["GRADIO_TEMP_DIR"], ignore_errors=True)
print(
f"Clearing output gallery temporary image files took {time() - cleanup_start:4f} seconds"
f"Clearing gradio UI temporary image files took {time() - cleanup_start:.4f} seconds."
)
# older SHARK versions had to workaround gradio bugs and stored things differently
else:
print("no output gallery temporary files to clear")
# Overwrite save_pil_to_file from gradio to save tmp images generated by gradio into our own tmp folder
def save_pil_to_file(pil_image, dir=None):
if not os.path.exists(gradio_tmp_imgs_folder):
os.mkdir(gradio_tmp_imgs_folder)
file_obj = tempfile.NamedTemporaryFile(
delete=False, suffix=".png", dir=gradio_tmp_imgs_folder
)
pil_image.save(file_obj)
return file_obj
# Register save_pil_to_file override
gradio.processing_utils.save_pil_to_file = save_pil_to_file
image_files = [
filename
for filename in os.listdir(shark_tmp)
if os.path.isfile(os.path.join(shark_tmp, filename))
and filename.startswith("tmp")
and filename.endswith(".png")
]
if len(image_files) > 0:
print(
"Clearing temporary image files of a prior run of a previous SHARK version. This may take some time..."
)
cleanup_start = time()
for filename in image_files:
os.remove(shark_tmp + filename)
print(
f"Clearing temporary image files took {time() - cleanup_start:.4f} seconds."
)
else:
print("No temporary images files to clear.")

View File

@@ -62,6 +62,82 @@ def parse_generation_parameters(x: str):
return res
def try_find_model_base_from_png_metadata(
file: str, folder: str = "models"
) -> str:
custom = ""
# Remove extension from file info
if file.endswith(".safetensors") or file.endswith(".ckpt"):
file = Path(file).stem
# Check for the file name match with one of the local ckpt or safetensors files
if Path(get_custom_model_pathfile(file + ".ckpt", folder)).is_file():
custom = file + ".ckpt"
if Path(
get_custom_model_pathfile(file + ".safetensors", folder)
).is_file():
custom = file + ".safetensors"
return custom
def find_model_from_png_metadata(
key: str, metadata: dict[str, str | int]
) -> tuple[str, str]:
png_hf_id = ""
png_custom = ""
if key in metadata:
model_file = metadata[key]
png_custom = try_find_model_base_from_png_metadata(model_file)
# Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0")
if model_file in predefined_models:
png_custom = model_file
# If nothing had matched, check vendor/hf_model_id
if not png_custom and model_file.count("/"):
png_hf_id = model_file
# No matching model was found
if not png_custom and not png_hf_id:
print(
"Import PNG info: Unable to find a matching model for %s"
% model_file
)
return png_custom, png_hf_id
def find_vae_from_png_metadata(
key: str, metadata: dict[str, str | int]
) -> str:
vae_custom = ""
if key in metadata:
vae_file = metadata[key]
vae_custom = try_find_model_base_from_png_metadata(vae_file, "vae")
# VAE input is optional, should not print or throw an error if missing
return vae_custom
def find_lora_from_png_metadata(
key: str, metadata: dict[str, str | int]
) -> tuple[str, str]:
lora_hf_id = ""
lora_custom = ""
if key in metadata:
lora_file = metadata[key]
lora_custom = try_find_model_base_from_png_metadata(lora_file, "lora")
# If nothing had matched, check vendor/hf_model_id
if not lora_custom and lora_file.count("/"):
lora_hf_id = lora_file
# LoRA input is optional, should not print or throw an error if missing
return lora_custom, lora_hf_id
def import_png_metadata(
pil_data,
prompt,
@@ -74,40 +150,21 @@ def import_png_metadata(
height,
custom_model,
hf_model_id,
custom_lora,
hf_lora_id,
custom_vae,
):
try:
png_info = pil_data.info["parameters"]
metadata = parse_generation_parameters(png_info)
png_hf_model_id = ""
png_custom_model = ""
if "Model" in metadata:
# Remove extension from model info
if metadata["Model"].endswith(".safetensors") or metadata[
"Model"
].endswith(".ckpt"):
metadata["Model"] = Path(metadata["Model"]).stem
# Check for the model name match with one of the local ckpt or safetensors files
if Path(
get_custom_model_pathfile(metadata["Model"] + ".ckpt")
).is_file():
png_custom_model = metadata["Model"] + ".ckpt"
if Path(
get_custom_model_pathfile(metadata["Model"] + ".safetensors")
).is_file():
png_custom_model = metadata["Model"] + ".safetensors"
# Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0")
if metadata["Model"] in predefined_models:
png_custom_model = metadata["Model"]
# If nothing had matched, check vendor/hf_model_id
if not png_custom_model and metadata["Model"].count("/"):
png_hf_model_id = metadata["Model"]
# No matching model was found
if not png_custom_model and not png_hf_model_id:
print(
"Import PNG info: Unable to find a matching model for %s"
% metadata["Model"]
)
(png_custom_model, png_hf_model_id) = find_model_from_png_metadata(
"Model", metadata
)
(lora_custom_model, lora_hf_model_id) = find_lora_from_png_metadata(
"LoRA", metadata
)
vae_custom_model = find_vae_from_png_metadata("VAE", metadata)
negative_prompt = metadata["Negative prompt"]
steps = int(metadata["Steps"])
@@ -115,12 +172,24 @@ def import_png_metadata(
seed = int(metadata["Seed"])
width = float(metadata["Size-1"])
height = float(metadata["Size-2"])
if "Model" in metadata and png_custom_model:
custom_model = png_custom_model
hf_model_id = ""
if "Model" in metadata and png_hf_model_id:
custom_model = "None"
hf_model_id = png_hf_model_id
if "LoRA" in metadata and lora_custom_model:
custom_lora = lora_custom_model
hf_lora_id = ""
if "LoRA" in metadata and lora_hf_model_id:
custom_lora = "None"
hf_lora_id = lora_hf_model_id
if "VAE" in metadata and vae_custom_model:
custom_vae = vae_custom_model
if "Prompt" in metadata:
prompt = metadata["Prompt"]
if "Sampler" in metadata:
@@ -149,4 +218,7 @@ def import_png_metadata(
height,
custom_model,
hf_model_id,
custom_lora,
hf_lora_id,
custom_vae,
)

View File

@@ -16,7 +16,7 @@ parameterized
# Add transformers, diffusers and scipy since it most commonly used
transformers
diffusers @ git+https://github.com/huggingface/diffusers@e47459c80f6f6a5a1c19d32c3fd74edf94f47aa2
diffusers
scipy
ftfy
gradio==3.34.0

243
rest_api_tests/api_test.py Normal file
View File

@@ -0,0 +1,243 @@
import requests
from PIL import Image
import base64
from io import BytesIO
def upscaler_test():
# Define values here
prompt = ""
negative_prompt = ""
seed = 2121991605
height = 512
width = 512
steps = 50
noise_level = 10
cfg_scale = 7
image_path = r"./rest_api_tests/dog.png"
# Converting Image to base64
img_file = open(image_path, "rb")
init_images = [
"data:image/png;base64," + base64.b64encode(img_file.read()).decode()
]
url = "http://127.0.0.1:8080/sdapi/v1/upscaler"
headers = {
"User-Agent": "PythonTest",
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate, br",
}
data = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"seed": seed,
"height": height,
"width": width,
"steps": steps,
"noise_level": noise_level,
"cfg_scale": cfg_scale,
"init_images": init_images,
}
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
print(f"response from server was : {res.status_code}")
def img2img_test():
# Define values here
prompt = "Paint a rabbit riding on the dog"
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
seed = 2121991605
height = 512
width = 512
steps = 50
denoising_strength = 0.75
cfg_scale = 7
image_path = r"./rest_api_tests/dog.png"
# Converting Image to Base64
img_file = open(image_path, "rb")
init_images = [
"data:image/png;base64," + base64.b64encode(img_file.read()).decode()
]
url = "http://127.0.0.1:8080/sdapi/v1/img2img"
headers = {
"User-Agent": "PythonTest",
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate, br",
}
data = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"init_images": init_images,
"height": height,
"width": width,
"steps": steps,
"denoising_strength": denoising_strength,
"cfg_scale": cfg_scale,
"seed": seed,
}
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
print(f"response from server was : {res.status_code}")
# NOTE Uncomment below to save the picture
# print("Extracting response object")
# response_obj = res.json()
# img_b64 = response_obj.get("images", [False])[0] or response_obj.get(
# "image"
# )
# img_b2 = base64.b64decode(img_b64.replace("data:image/png;base64,", ""))
# im_file = BytesIO(img_b2)
# response_img = Image.open(im_file)
# print("Saving Response Image to: response_img")
# response_img.save(r"rest_api_tests/response_img.png")
def inpainting_test():
prompt = "Paint a rabbit riding on the dog"
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
seed = 2121991605
height = 512
width = 512
steps = 50
noise_level = 10
cfg_scale = 7
is_full_res = False
full_res_padding = 32
image_path = r"./rest_api_tests/dog.png"
img_file = open(image_path, "rb")
image = (
"data:image/png;base64," + base64.b64encode(img_file.read()).decode()
)
img_file = open(image_path, "rb")
mask = (
"data:image/png;base64," + base64.b64encode(img_file.read()).decode()
)
url = "http://127.0.0.1:8080/sdapi/v1/inpaint"
headers = {
"User-Agent": "PythonTest",
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate, br",
}
data = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"image": image,
"mask": mask,
"height": height,
"width": width,
"steps": steps,
"noise_level": noise_level,
"cfg_scale": cfg_scale,
"seed": seed,
"is_full_res": is_full_res,
"full_res_padding": full_res_padding,
}
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
print(f"[Inpainting] response from server was : {res.status_code}")
def outpainting_test():
prompt = "Paint a rabbit riding on the dog"
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
seed = 2121991605
height = 512
width = 512
steps = 50
cfg_scale = 7
color_variation = 0.2
noise_q = 0.2
directions = ["up", "down", "right", "left"]
pixels = 32
mask_blur = 64
image_path = r"./rest_api_tests/dog.png"
# Converting Image to Base64
img_file = open(image_path, "rb")
init_images = [
"data:image/png;base64," + base64.b64encode(img_file.read()).decode()
]
url = "http://127.0.0.1:8080/sdapi/v1/outpaint"
headers = {
"User-Agent": "PythonTest",
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate, br",
}
data = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"seed": seed,
"height": height,
"width": width,
"steps": steps,
"cfg_scale": cfg_scale,
"color_variation": color_variation,
"noise_q": noise_q,
"directions": directions,
"pixels": pixels,
"mask_blur": mask_blur,
"init_images": init_images,
}
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
print(f"[Outpaint] response from server was : {res.status_code}")
def txt2img_test():
prompt = "Paint a rabbit in a top hate"
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
seed = 2121991605
height = 512
width = 512
steps = 50
cfg_scale = 7
url = "http://127.0.0.1:8080/sdapi/v1/txt2img"
headers = {
"User-Agent": "PythonTest",
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate, br",
}
data = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"seed": seed,
"height": height,
"width": width,
"steps": steps,
"cfg_scale": cfg_scale,
}
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
print(f"[txt2img] response from server was : {res.status_code}")
if __name__ == "__main__":
txt2img_test()
img2img_test()
upscaler_test()
inpainting_test()
outpainting_test()

BIN
rest_api_tests/dog.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.5 KiB

View File

@@ -39,7 +39,7 @@ setup(
install_requires=[
"numpy",
"PyYAML",
"torch-mlir>=20221021.633",
"torch-mlir==20230620.875",
]
+ backend_deps,
)

View File

@@ -89,7 +89,7 @@ else {python -m venv .\shark.venv\}
python -m pip install --upgrade pip
pip install wheel
pip install -r requirements.txt
pip install --pre torch-mlir torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
pip install --pre torch-mlir==20230620.875 torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
pip install --upgrade -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html iree-compiler iree-runtime
Write-Host "Building SHARK..."
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html

View File

@@ -27,6 +27,11 @@ PYTHON_VERSION_X_Y=`${PYTHON} -c 'import sys; version=sys.version_info[:2]; prin
echo "Python: $PYTHON"
echo "Python version: $PYTHON_VERSION_X_Y"
if [ "$PYTHON_VERSION_X_Y" != "3.11" ]; then
echo "Error: Python version 3.11 is required."
exit 1
fi
if [[ "$SKIP_VENV" != "1" ]]; then
if [[ -z "${CONDA_PREFIX}" ]]; then
# Not a conda env. So create a new VENV dir
@@ -83,7 +88,7 @@ if [ "$torch_mlir_bin" = true ]; then
echo "MacOS detected. Installing torch-mlir from .whl, to avoid dependency problems with torch."
$PYTHON -m pip install --pre --no-cache-dir torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/
else
$PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
$PYTHON -m pip install --pre torch-mlir==20230620.875 -f https://llvm.github.io/torch-mlir/package-index/
if [ $? -eq 0 ];then
echo "Successfully Installed torch-mlir"
else

View File

@@ -0,0 +1,76 @@
import torch
import torch_mlir
from shark.shark_inference import SharkInference
from apps.stable_diffusion.src.utils import (
compile_through_fx,
args,
)
from MEGABYTE_pytorch import MEGABYTE
import os
class MegaModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = MEGABYTE(
num_tokens=16000, # number of tokens
dim=(
512,
256,
), # transformer model dimension (512 for coarsest, 256 for fine in this example)
max_seq_len=(
1024,
4,
), # sequence length for global and then local. this can be more than 2
depth=(
6,
4,
), # number of layers for global and then local. this can be more than 2, but length must match the max_seq_len's
dim_head=64, # dimension per head
heads=8, # number of attention heads
flash_attn=True, # use flash attention
)
def forward(self, input):
return self.model(input)
megaModel = MegaModel()
input = [torch.randint(0, 16000, (1, 1024, 4))]
# CURRENTLY IT BAILS OUT HERE BECAUSE OF MISSING OP LOWERINGS :-
# 1. aten.alias
shark_module, _ = compile_through_fx(
megaModel,
inputs=input,
extended_model_name="mega_shark",
debug=False,
generate_vmfb=True,
save_dir=os.getcwd(),
extra_args=[],
base_model_id=None,
model_name="mega_shark",
precision=None,
return_mlir=True,
device="cuda",
)
# logits = model(x)
def print_output_info(output, msg):
print("\n", msg)
print("\n\t", output.shape)
ans = shark_module("forward", input)
print_output_info(torch.from_numpy(ans), "SHARK's output")
ans = megaModel.forward(*input)
print_output_info(ans, "ORIGINAL Model's output")
# and sample from the logits accordingly
# or you can use the generate function
# NEED TO LOOK AT THIS LATER IF REQUIRED IN SHARK.
# sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 1024, 4)

View File

@@ -1,4 +1,4 @@
# Copyright 2020 The Nod Team. All rights reserved.
# Copyright 2023 The Nod Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -63,10 +63,11 @@ def get_supported_device_list():
_IREE_DEVICE_MAP = {
"cpu": "local-task",
"cpu-task": "local-task",
"AMD-AIE": "local-task",
"cpu-sync": "local-sync",
"cuda": "cuda",
"vulkan": "vulkan",
"metal": "vulkan",
"metal": "metal",
"rocm": "rocm",
"intel-gpu": "level_zero",
}
@@ -81,10 +82,11 @@ def iree_target_map(device):
_IREE_TARGET_MAP = {
"cpu": "llvm-cpu",
"cpu-task": "llvm-cpu",
"AMD-AIE": "llvm-cpu",
"cpu-sync": "llvm-cpu",
"cuda": "cuda",
"vulkan": "vulkan",
"metal": "vulkan",
"metal": "metal",
"rocm": "rocm",
"intel-gpu": "opencl-spirv",
}
@@ -101,11 +103,13 @@ def check_device_drivers(device):
subprocess.check_output("nvidia-smi")
except Exception:
return True
elif device in ["metal", "vulkan"]:
elif device in ["vulkan"]:
try:
subprocess.check_output("vulkaninfo")
except Exception:
return True
elif device == "metal":
return False
elif device in ["intel-gpu"]:
try:
subprocess.check_output(["dpkg", "-L", "intel-level-zero-gpu"])

View File

@@ -1,4 +1,4 @@
# Copyright 2020 The Nod Team. All rights reserved.
# Copyright 2023 The Nod Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,6 +19,8 @@ from shark.parser import shark_args
import numpy as np
import os
import re
import tempfile
from pathlib import Path
# Get the iree-compile arguments given device.
@@ -38,17 +40,26 @@ def get_iree_device_args(device, extra_args=[]):
if device_uri[0] == "cpu":
from shark.iree_utils.cpu_utils import get_iree_cpu_args
return get_iree_cpu_args()
data_tiling_flag = ["--iree-flow-enable-data-tiling"]
u_kernel_flag = ["--iree-llvmcpu-enable-microkernels"]
return get_iree_cpu_args() + data_tiling_flag + u_kernel_flag
if device_uri[0] == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args
return get_iree_gpu_args()
if device_uri[0] in ["metal", "vulkan"]:
if device_uri[0] == "vulkan":
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
return get_iree_vulkan_args(
device_num=device_num, extra_args=extra_args
)
if device_uri[0] == "metal":
from shark.iree_utils.metal_utils import get_iree_metal_args
return get_iree_metal_args(
device_num=device_num, extra_args=extra_args
)
if device_uri[0] == "rocm":
from shark.iree_utils.gpu_utils import get_iree_rocm_args
@@ -175,8 +186,10 @@ def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
vmfb_file.close()
config = get_iree_runtime_config(device)
vm_module = ireert.VmModule.from_flatbuffer(
config.vm_instance, flatbuffer_blob
vm_module = ireert.VmModule.from_buffer(
config.vm_instance,
flatbuffer_blob,
warn_if_copy=False,
)
benchmark_cl = build_benchmark_args_non_tensor_input(
@@ -307,8 +320,8 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
config = ireert.Config(device=haldevice)
else:
config = get_iree_runtime_config(device)
vm_module = ireert.VmModule.from_flatbuffer(
config.vm_instance, flatbuffer_blob
vm_module = ireert.VmModule.from_buffer(
config.vm_instance, flatbuffer_blob, warn_if_copy=False
)
ctx = ireert.SystemContext(config=config)
ctx.add_vm_module(vm_module)
@@ -316,6 +329,58 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
return ModuleCompiled, config
def load_vmfb_using_mmap(
flatbuffer_blob_or_path, device: str, device_idx: int = None
):
instance = ireert.VmInstance()
device = iree_device_map(device)
haldriver = ireert.get_driver(device)
haldevice = haldriver.create_device_by_uri(
device,
allocators=[],
)
# First get configs.
if device_idx is not None:
device = iree_device_map(device)
print("registering device id: ", device_idx)
haldriver = ireert.get_driver(device)
haldevice = haldriver.create_device(
haldriver.query_available_devices()[device_idx]["device_id"],
allocators=shark_args.device_allocator,
)
config = ireert.Config(device=haldevice)
else:
config = get_iree_runtime_config(device)
# Now load vmfb.
# Two scenarios we have here :-
# 1. We either have the vmfb already saved and therefore pass the path of it.
# (This would arise if we're invoking `load_module` from a SharkInference obj)
# OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
# (This would arise if we're invoking `compile` from a SharkInference obj)
temp_file_to_unlink = None
if isinstance(flatbuffer_blob_or_path, Path):
flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__()
if (
isinstance(flatbuffer_blob_or_path, str)
and ".vmfb" in flatbuffer_blob_or_path
):
vmfb_file_path = flatbuffer_blob_or_path
mmaped_vmfb = ireert.VmModule.mmap(instance, flatbuffer_blob_or_path)
ctx = ireert.SystemContext(config=config)
ctx.add_vm_module(mmaped_vmfb)
mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
else:
with tempfile.NamedTemporaryFile(delete=False) as tf:
tf.write(flatbuffer_blob_or_path)
tf.flush()
vmfb_file_path = tf.name
temp_file_to_unlink = vmfb_file_path
mmaped_vmfb = ireert.VmModule.mmap(instance, vmfb_file_path)
return mmaped_vmfb, config, temp_file_to_unlink
def get_iree_compiled_module(
module,
device: str,
@@ -323,19 +388,58 @@ def get_iree_compiled_module(
model_config_path: str = None,
extra_args: list = [],
device_idx: int = None,
mmap: bool = False,
):
"""Given a module returns the compiled .vmfb and configs"""
flatbuffer_blob = compile_module_to_flatbuffer(
module, device, frontend, model_config_path, extra_args
)
return get_iree_module(flatbuffer_blob, device, device_idx=device_idx)
temp_file_to_unlink = None
# TODO: Currently mmap=True control flow path has been switched off for mmap.
# Got to find a cleaner way to unlink/delete the temporary file since
# we're setting delete=False when creating NamedTemporaryFile. That's why
# I'm getting hold of the name of the temporary file in `temp_file_to_unlink`.
if mmap:
print(f"Will load the compiled module as a mmapped temporary file")
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
flatbuffer_blob, device, device_idx
)
else:
vmfb, config = get_iree_module(
flatbuffer_blob, device, device_idx=device_idx
)
ret_params = {
"vmfb": vmfb,
"config": config,
"temp_file_to_unlink": temp_file_to_unlink,
}
return ret_params
def load_flatbuffer(flatbuffer_path: str, device: str, device_idx: int = None):
with open(os.path.join(flatbuffer_path), "rb") as f:
flatbuffer_blob = f.read()
return get_iree_module(flatbuffer_blob, device, device_idx=device_idx)
def load_flatbuffer(
flatbuffer_path: str,
device: str,
device_idx: int = None,
mmap: bool = False,
):
temp_file_to_unlink = None
if mmap:
print(f"Loading flatbuffer at {flatbuffer_path} as a mmapped file")
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
flatbuffer_path, device, device_idx
)
else:
with open(os.path.join(flatbuffer_path), "rb") as f:
flatbuffer_blob = f.read()
vmfb, config = get_iree_module(
flatbuffer_blob, device, device_idx=device_idx
)
ret_params = {
"vmfb": vmfb,
"config": config,
"temp_file_to_unlink": temp_file_to_unlink,
}
return ret_params
def export_iree_module_to_vmfb(

View File

@@ -0,0 +1,121 @@
# Copyright 2023 The Nod Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# All the iree_vulkan related functionalities go here.
from shark.iree_utils._common import run_cmd
import iree.runtime as ireert
from sys import platform
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
def get_metal_device_name(device_num=0):
iree_device_dump = run_cmd("iree-run-module --dump_devices")
iree_device_dump = iree_device_dump[0].split("\n\n")
metal_device_list = [
s.split("\n#")[2] for s in iree_device_dump if "--device=metal" in s
]
if len(metal_device_list) == 0:
raise ValueError("No device name found in device dump!")
if len(metal_device_list) > 1:
print("Following devices found:")
for i, dname in enumerate(metal_device_list):
print(f"{i}. {dname}")
print(f"Choosing device: {metal_device_list[device_num]}")
return metal_device_list[device_num]
def get_os_name():
if platform.startswith("linux"):
return "linux"
elif platform == "darwin":
return "macos"
elif platform == "win32":
return "windows"
else:
print("Cannot detect OS type, defaulting to linux.")
return "linux"
def get_metal_target_triple(device_name):
"""This method provides a target triple str for specified vulkan device.
Args:
device_name (str): name of the hardware device to be used with vulkan
Returns:
str or None: target triple or None if no match found for given name
"""
# Apple Targets
if all(x in device_name for x in ("Apple", "M1")):
triple = "m1-moltenvk-macos"
elif all(x in device_name for x in ("Apple", "M2")):
triple = "m1-moltenvk-macos"
else:
triple = None
return triple
def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
for flag in extra_args:
if "-iree-metal-target-platform=" in flag:
print(f"Using target triple {flag.split('=')[1]}")
return None
if device_name == "" or device_name == [] or device_name is None:
metal_device = get_metal_device_name(device_num=device_num)
else:
metal_device = device_name
triple = get_metal_target_triple(metal_device)
if triple is not None:
print(
f"Found metal device {metal_device}. Using metal target triple {triple}"
)
return f"-iree-metal-target-platform={triple}"
print(
"""Optimized kernel for your target device is not added yet.
Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]
or pull up an issue."""
)
print(f"Target : {metal_device}")
return None
def get_iree_metal_args(device_num=0, extra_args=[]):
# res_metal_flag = ["--iree-flow-demote-i64-to-i32"]
res_metal_flag = []
metal_triple_flag = None
for arg in extra_args:
if "-iree-metal-target-platform=" in arg:
print(f"Using target triple {arg} from command line args")
metal_triple_flag = arg
break
if metal_triple_flag is None:
metal_triple_flag = get_metal_triple_flag(
device_num=device_num, extra_args=extra_args
)
if metal_triple_flag is not None:
vulkan_target_env = get_vulkan_target_env_flag(metal_triple_flag)
res_metal_flag.append(vulkan_target_env)
return res_metal_flag
def set_iree_metal_runtime_flags(flags):
for flag in flags:
ireert.flags.parse_flags(flag)
return

View File

@@ -60,12 +60,15 @@ def download_public_file(
else:
continue
destination_filename = os.path.join(destination_folder_name, blob_name)
if os.path.isdir(destination_filename):
continue
with open(destination_filename, "wb") as f:
with tqdm.wrapattr(f, "write", total=blob.size) as file_obj:
storage_client.download_blob_to_file(blob, file_obj)
else:
destination_filename = os.path.join(
destination_folder_name, blob_name
)
if os.path.isdir(destination_filename):
continue
with open(destination_filename, "wb") as f:
with tqdm.wrapattr(f, "write", total=blob.size) as file_obj:
storage_client.download_blob_to_file(blob, file_obj)
input_type_to_np_dtype = {

View File

@@ -0,0 +1,206 @@
from typing import Any, Dict, List, Tuple
from collections import defaultdict
from shark.shark_importer import import_with_fx
import torchvision.models as models
import copy
import io
import numpy as np
import sys
import torch
import torch.fx
from torch.fx.node import Node
from typing import Dict
import torch_mlir
def shark_backend(fx_g: torch.fx.GraphModule, inputs, device: str = "cpu"):
mlir_module = torch_mlir.compile(
fx_g, inputs, output_type="linalg-on-tensors"
)
bytecode_stream = io.BytesIO()
mlir_module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
from shark.shark_inference import SharkInference
shark_module = SharkInference(
mlir_module=bytecode,
device=device,
mlir_dialect="tm_tensor",
)
shark_module.compile(extra_args=[])
return shark_module
def _make_single_op_gm(node, captured_val, compiled_graph):
"""Make a GraphModule that just executes the given node."""
g = torch.fx.Graph()
env = {}
inputs = []
for arg in node.args:
if arg and hasattr(arg, "name"):
env[arg.name] = g.placeholder(arg.name)
if isinstance(captured_val[arg.name], (list, tuple)):
for val in captured_val[arg.name]:
inputs.append(val)
else:
inputs.append(captured_val[arg.name])
call = g.node_copy(node, lambda n: env[n.name])
g.output(call)
g.lint()
single_node = torch.fx.GraphModule(torch.nn.Module(), g)
compiled_module = shark_backend(single_node, inputs)
compiled_graph[node.name] = {
"module": compiled_module,
"inputs": [i for i in env],
"result": None,
}
return
def compiled_graph(gm: torch.fx.GraphModule, attr_info):
compiled_graph = {}
g = gm.graph
for node in g.nodes:
if node.op == "call_function":
if not (
node.target in [torch.ops.aten.empty]
or node.name.startswith("getitem")
):
_make_single_op_gm(node, attr_info, compiled_graph)
# Currently torch.aten.empty has an compilation issue, so running natively.
elif node.target in [torch.ops.aten.empty]:
compiled_graph[node.name] = {
"target": node.target,
"args": node.args,
"kwargs": node.kwargs,
"result": None,
}
# Get item is a simple case takes a tuple and return the tensor at a particular index.
elif node.name.startswith("getitem"):
compiled_graph[node.name] = {
"input": node.args[0].name,
"pos": node.args[1],
"result": None,
}
return compiled_graph
class ShapeProp:
"""
Shape propagation. This class takes a `GraphModule`.
Then, its `propagate` method executes the `GraphModule`
node-by-node with the given arguments. As each operation
executes, the ShapeProp class stores away the shape and
element type for the output values of each operation on
the `shape` and `dtype` attributes of the operation's
`Node`.
"""
def __init__(self, mod):
self.mod = mod
self.graph = mod.graph
self.modules = dict(self.mod.named_modules())
def propagate(self, *args):
args_iter = iter(args)
env: Dict[str, Node] = {}
def load_arg(a):
return torch.fx.graph.map_arg(a, lambda n: env[n.name])
def fetch_attr(target: str):
target_atoms = target.split(".")
attr_itr = self.mod
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(
f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}"
)
attr_itr = getattr(attr_itr, atom)
return attr_itr
for node in self.graph.nodes:
if node.op == "placeholder":
result = next(args_iter)
elif node.op == "get_attr":
result = fetch_attr(node.target)
elif node.op == "call_function":
result = node.target(
*load_arg(node.args), **load_arg(node.kwargs)
)
elif node.op == "call_method":
self_obj, *args = load_arg(node.args)
kwargs = load_arg(node.kwargs)
result = getattr(self_obj, node.target)(*args, **kwargs)
elif node.op == "call_module":
result = self.modules[node.target](
*load_arg(node.args), **load_arg(node.kwargs)
)
# This is the only code specific to shape propagation.
# you can delete this `if` branch and this becomes
# a generic GraphModule interpreter.
if isinstance(result, torch.Tensor):
node.shape = result.shape
node.dtype = result.dtype
env[node.name] = result
return env
# return load_arg(self.graph.result)
resnet18 = models.resnet18(pretrained=True)
resnet18.train(False)
input = (torch.randn(1, 3, 224, 224),)
print(resnet18(input[0]))
fx_graph = import_with_fx(resnet18, input, mlir_type="fx")
shape_prop = ShapeProp(fx_graph)
x = shape_prop.propagate(input[0])
shark_graph = compiled_graph(fx_graph, x)
for key in shark_graph:
if key.startswith("getitem"):
input_val = shark_graph[key]["input"]
pos = shark_graph[key]["pos"]
if input_val not in shark_graph:
shark_graph[key]["result"] = x[input_val][pos].detach()
else:
shark_graph[key]["result"] = shark_graph[input_val]["result"][
pos
].detach()
elif key.startswith("empty"):
operator = shark_graph[key]["target"]
args = shark_graph[key]["args"]
kwargs = shark_graph[key]["kwargs"]
shark_graph[key]["result"] = operator(*args, **kwargs).detach()
else:
input_val = shark_graph[key]["inputs"]
input_tensors = []
for input in input_val:
if input not in shark_graph:
input_tensors.append(x[input].detach())
else:
input_tensors.append(shark_graph[input]["result"])
val = shark_graph[key]["module"]("forward", input_tensors)
if isinstance(val, (tuple, list)):
list_val = []
for v in val:
list_val.append(torch.from_numpy(v))
shark_graph[key]["result"] = list_val
else:
shark_graph[key]["result"] = torch.from_numpy(val)
print(shark_graph)

View File

@@ -1,5 +1,8 @@
import re
import json
from collections import OrderedDict
import torch_mlir
from iree.compiler import compile_str
from shark.shark_importer import import_with_fx, get_f16_inputs
class GenerateConfigFile:
@@ -7,7 +10,9 @@ class GenerateConfigFile:
self,
model,
num_sharding_stages: int,
sharding_stages_id: list[str] = None,
sharding_stages_id: list[str],
model_input=None,
config_file_path="model_config.json",
):
self.model = model
self.num_sharding_stages = num_sharding_stages
@@ -15,8 +20,67 @@ class GenerateConfigFile:
assert self.num_sharding_stages == len(
self.sharding_stages_id
), "Number of sharding stages should be equal to the list of their ID"
self.model_input = model_input
self.config_file_path = config_file_path
def generate_json(self):
def split_into_dispatches(
self,
backend,
fx_tracing_required=True,
f16_model=False,
torch_mlir_tracing=False,
):
graph_for_compilation = self.model
if fx_tracing_required:
graph_for_compilation = import_with_fx(
self.model,
self.model_input,
is_f16=f16_model,
f16_input_mask=[False, False],
mlir_type="torchscript",
)
module = torch_mlir.compile(
graph_for_compilation,
(self.model_input),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=torch_mlir_tracing,
verbose=False,
)
module = module.operation.get_asm(large_elements_limit=4)
compiled_module_str = str(
compile_str(
str(module),
target_backends=[backend],
extra_args=[
"--compile-to=flow",
"--mlir-elide-elementsattrs-if-larger=4",
],
)
)
substring_start_idx = [
m.start()
for m in re.finditer("flow.dispatch @", compiled_module_str)
]
dispatch_list = dict()
# dispatch_no is the 'i'th index of a dispatch out of n total dispatches of a model
# dispatch_id is the unique id of a dispatch, multiple instances of the same dispatch
# can occur in a model
for dispatch_no, substring_idx in enumerate(substring_start_idx):
dispatch_idx = (
compiled_module_str[substring_idx:]
.split(":")[0]
.split("@")[-1]
)
key = "dispatch_no_" + str(dispatch_no)
dispatch_list[key] = {n: "None" for n in self.sharding_stages_id}
dispatch_list[key]["dispatch_id"] = dispatch_idx
self.generate_json(dispatch_list)
def split_into_layers(self):
model_dictionary = dict()
for name, m in self.model.named_modules():
@@ -34,5 +98,8 @@ class GenerateConfigFile:
layer_dict = {n: "None" for n in self.sharding_stages_id}
model_dictionary[name] = layer_dict
with open("model_config.json", "w") as outfile:
json.dump(model_dictionary, outfile)
self.generate_json(model_dictionary)
def generate_json(self, artifacts):
with open(self.config_file_path, "w") as outfile:
json.dump(artifacts, outfile)

View File

@@ -370,6 +370,7 @@ def transform_fx(fx_g):
torch.ops.aten.arange,
torch.ops.aten.empty,
torch.ops.aten.zeros,
torch.ops.aten.zeros_like,
]:
if node.kwargs.get("dtype") == torch.float32:
node.kwargs = kwargs_dict
@@ -525,6 +526,8 @@ def import_with_fx(
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
torch.ops.aten.native_layer_norm,
torch.ops.aten.masked_fill.Tensor,
torch.ops.aten.masked_fill.Scalar,
]
),
)(*inputs)
@@ -552,6 +555,9 @@ def import_with_fx(
add_upcast(fx_g)
fx_g.recompile()
if mlir_type == "fx":
return fx_g
if training:
change_fx_graph_return_to_tuple(fx_g)
inputs = flatten_training_input(inputs)

View File

@@ -48,6 +48,8 @@ class SharkInference:
Refer to {https://mlir.llvm.org/docs/Dialects/}
is_benchmark: bool
Whether this SharkInference module should be benchmark-enabled.
mmap: bool
Whether to load/run vmfb using mmap. It's `True` by default.
Methods
-------
@@ -70,6 +72,7 @@ class SharkInference:
dispatch_benchmark: str = None,
dispatch_benchmark_dir: str = "temp_dispatch_benchmarks",
device_idx: int = None,
mmap: bool = True,
):
self.mlir_module = mlir_module
self.device = shark_args.device if device == "none" else device
@@ -88,6 +91,7 @@ class SharkInference:
)
self.shark_runner = None
self.mmap = mmap
def compile(self, extra_args=[]):
if self.dispatch_benchmarks is not None:
@@ -201,12 +205,14 @@ class SharkInference:
compile_vmfb=False,
extra_args=extra_args,
)
(
self.shark_runner.iree_compilation_module,
self.shark_runner.iree_config,
) = load_flatbuffer(
params = load_flatbuffer(
path,
self.device,
self.device_idx,
mmap=self.mmap,
)
self.shark_runner.iree_compilation_module = params["vmfb"]
self.shark_runner.iree_config = params["config"]
self.shark_runner.temp_file_to_unlink = params["temp_file_to_unlink"]
del params
return

View File

@@ -85,16 +85,17 @@ class SharkRunner:
if compile_vmfb == True:
# Compile the module to get the .vmfb.
(
self.iree_compilation_module,
self.iree_config,
) = get_iree_compiled_module(
params = get_iree_compiled_module(
self.mlir_module,
self.device,
self.mlir_dialect,
extra_args=self.extra_args,
device_idx=self.device_idx,
)
self.iree_compilation_module = params["vmfb"]
self.iree_config = params["config"]
self.temp_file_to_unlink = params["temp_file_to_unlink"]
del params
def run(self, function_name, inputs: tuple, send_to_host=False):
return get_results(

View File

@@ -12,8 +12,8 @@ from transformers import AutoTokenizer, OPTForCausalLM
OPT_MODEL = "opt-1.3b"
OPT_FS_NAME = "opt-1_3b"
MAX_SEQUENCE_LENGTH = 30
MAX_NEW_TOKENS = 20
MAX_SEQUENCE_LENGTH = 128
MAX_NEW_TOKENS = 60
def create_module(model_name, tokenizer, device):
@@ -110,13 +110,13 @@ if __name__ == "__main__":
"facebook/" + OPT_MODEL, use_fast=False
)
vmfb_path = (
f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_cpu-sync.vmfb"
f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_cpu-task.vmfb"
)
opt_shark_module = SharkInference(mlir_module=None, device="cpu-sync")
opt_shark_module = SharkInference(mlir_module=None, device="cpu-task")
if os.path.isfile(vmfb_path):
opt_shark_module.load_module(vmfb_path)
else:
vmfb_path = create_module(OPT_MODEL, tokenizer, "cpu-sync")
vmfb_path = create_module(OPT_MODEL, tokenizer, "cpu-task")
opt_shark_module.load_module(vmfb_path)
while True:
try:

View File

@@ -24,4 +24,5 @@ bert-large-uncased,True,hf,True,linalg,False,330M,"nlp;bert-variant;transformer-
bert-base-uncased,True,hf,False,stablehlo,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
gpt2,True,hf_causallm,False,stablehlo,True,125M,"nlp;transformer-encoder","-"
facebook/opt-125m,True,hf,False,stablehlo,True,125M,"nlp;transformer-encoder","-"
distilgpt2,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
distilgpt2,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
microsoft/deberta-v3-base,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
1 model_name use_tracing model_type dynamic mlir_type decompose param_count tags notes
24 bert-base-uncased True hf False stablehlo False 109M nlp;bert-variant;transformer-encoder 12 layers; 768 hidden; 12 attention heads
25 gpt2 True hf_causallm False stablehlo True 125M nlp;transformer-encoder -
26 facebook/opt-125m True hf False stablehlo True 125M nlp;transformer-encoder -
27 distilgpt2 True hf False stablehlo True 88M nlp;transformer-encoder -
28 microsoft/deberta-v3-base True hf False stablehlo True 88M nlp;transformer-encoder -