Mirror of https://github.com/nod-ai/SHARK-Studio.git, synced 2026-04-20 03:00:34 -04:00
Compare commits: 46 commits, 20230616.7 ... 20230623.7
Commits in this range:
b444528715, 6e6c90f62b, 8cdb38496e, 726d73d6ba, 4d55e51d46, 6ef78ee7ba, 4002da7161, ecb5e8e5d8, 28e0919321, 28f4d44a6b, 97f7e79391, 44a8f2f8db, 8822b9acd7, 0ca3b9fce3, 045f2bb147, a811b867b9, cdd505e2dd, 1b0f39107c, b9b8955f74, 6f7a85eee3, 18c8e9e51e, a202bb466a, 07c1e1d712, 18daec78c8, 1a8e2024d6, d61b6641fb, 88cc2423cc, ccf944c1bd, 0def74f520, 3fb72e192e, 855435ee24, 6f9f868fc0, fb865f1b99, 3e5c50f07b, a544f30a8f, 1fe56d460a, fafd713141, 015d0132c3, 20ddd96ef7, ee33cfd2d1, a3cba21d5b, a7b6ec4095, d80b087d95, 297a209608, b204113563, f60ab1f4fa
.github/workflows/test-models.yml (13 lines changed, vendored)

@@ -35,6 +35,8 @@ jobs:
include:
- os: ubuntu-latest
suite: lint
- os: MacStudio
suite: metal
exclude:
- os: ubuntu-latest
suite: vulkan
@@ -46,6 +48,8 @@ jobs:
suite: cuda
- os: MacStudio
suite: cpu
- os: MacStudio
suite: vulkan
- os: icelake
suite: vulkan
- os: icelake
@@ -61,7 +65,6 @@ jobs:

steps:
- uses: actions/checkout@v3
if: matrix.os != '7950x'

- name: Set Environment Variables
if: matrix.os != '7950x'
@@ -84,9 +87,6 @@ jobs:
#cache-dependency-path: |
# **/requirements-importer.txt
# **/requirements.txt

- uses: actions/checkout@v2
if: matrix.os == '7950x'

- name: Install dependencies
if: matrix.suite == 'lint'
@@ -129,15 +129,14 @@ jobs:
# python build_tools/stable_diffusion_testing.py --device=cuda

- name: Validate Vulkan Models (MacOS)
if: matrix.suite == 'vulkan' && matrix.os == 'MacStudio'
if: matrix.suite == 'metal' && matrix.os == 'MacStudio'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
export DYLD_LIBRARY_PATH=/usr/local/lib/
echo $PATH
pip list | grep -E "torch|iree"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k vulkan
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k metal

- name: Validate Vulkan Models (a100)
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
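The matrix change above adds a metal suite that only runs on MacStudio and switches the macOS validation step from the vulkan suite to it. As a rough illustration of how include/exclude resolve into concrete jobs (a sketch of the idea only, not GitHub Actions' actual implementation; the os and suite lists are abbreviated):

from itertools import product

oses = ["ubuntu-latest", "MacStudio", "a100", "icelake"]
suites = ["cpu", "cuda", "vulkan"]
include = [("ubuntu-latest", "lint"), ("MacStudio", "metal")]
exclude = {
    ("ubuntu-latest", "vulkan"),
    ("MacStudio", "cuda"),
    ("MacStudio", "cpu"),
    ("MacStudio", "vulkan"),
    ("icelake", "vulkan"),
}

# cross product minus the excluded pairs, plus the extra include entries
jobs = [c for c in product(oses, suites) if c not in exclude]
jobs += [c for c in include if c not in jobs]
for os_name, suite in jobs:
    print(f"{os_name}: {suite}")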
.gitignore (4 lines changed, vendored)

@@ -2,6 +2,8 @@
__pycache__/
*.py[cod]
*$py.class
*.mlir
*.vmfb

# C extensions
*.so
@@ -157,7 +159,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/

# vscode related
.vscode
@@ -3,6 +3,10 @@ from pathlib import Path
from apps.language_models.src.pipelines import vicuna_pipeline as vp
from apps.language_models.src.pipelines import vicuna_sharded_pipeline as vsp
import torch
import json

if __name__ == "__main__":
import gc


parser = argparse.ArgumentParser(
@@ -55,35 +59,38 @@ parser.add_argument(
help="Run model in cli mode",
)

parser.add_argument(
"--config",
default=None,
help="configuration file",
)

if __name__ == "__main__":
args, unknown = parser.parse_known_args()

vic = None
if not args.sharded:
first_vic_mlir_path = (
Path(f"first_vicuna_{args.precision}.mlir")
None
if args.first_vicuna_mlir_path is None
else Path(args.first_vicuna_mlir_path)
)
second_vic_mlir_path = (
Path(f"second_vicuna_{args.precision}.mlir")
None
if args.second_vicuna_mlir_path is None
else Path(args.second_vicuna_mlir_path)
)
first_vic_vmfb_path = (
Path(
f"first_vicuna_{args.precision}_{args.device.replace('://', '_')}.vmfb"
)
None
if args.first_vicuna_vmfb_path is None
else Path(args.first_vicuna_vmfb_path)
)
second_vic_vmfb_path = (
Path(
f"second_vicuna_{args.precision}_{args.device.replace('://', '_')}.vmfb"
)
None
if args.second_vicuna_vmfb_path is None
else Path(args.second_vicuna_vmfb_path)
)

vic = vp.Vicuna(
"vicuna",
device=args.device,
@@ -95,16 +102,21 @@ if __name__ == "__main__":
load_mlir_from_shark_tank=args.load_mlir_from_shark_tank,
)
else:
if args.config is not None:
config_file = open(args.config)
config_json = json.load(config_file)
config_file.close()
else:
config_json = None
vic = vsp.Vicuna(
"vicuna",
device=args.device,
precision=args.precision,
config_json=config_json,
)
prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
prologue_prompt = "ASSISTANT:\n"

import gc

while True:
# TODO: Add break condition from user input
user_prompt = input("User: ")
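With the defaults above, the unsharded path now derives artifact names from precision and device instead of hard-coding first_vicuna.mlir and first_vicuna.vmfb. A small sketch of the naming scheme those f-strings imply (the helper name is illustrative, not part of the repo):

from pathlib import Path

def default_vicuna_path(which: str, precision: str, device: str, suffix: str) -> Path:
    # e.g. first_vicuna_fp16.mlir or second_vicuna_fp16_vulkan_0.vmfb;
    # "://" in device URIs such as "vulkan://0" is replaced to keep the name filesystem-safe.
    if suffix == "mlir":
        return Path(f"{which}_vicuna_{precision}.{suffix}")
    return Path(f"{which}_vicuna_{precision}_{device.replace('://', '_')}.{suffix}")

print(default_vicuna_path("first", "fp16", "vulkan://0", "vmfb"))  # first_vicuna_fp16_vulkan_0.vmfb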
@@ -145,7 +145,7 @@ class CompiledSecondVicunaLayer(torch.nn.Module):


class ShardedVicunaModel(torch.nn.Module):
def __init__(self, model, layers0, layers1):
def __init__(self, model, layers0, layers1, lmhead, embedding, norm):
super().__init__()
self.model = model
assert len(layers0) == len(model.model.layers)
@@ -154,6 +154,12 @@ class ShardedVicunaModel(torch.nn.Module):
self.model.model.config.output_attentions = False
self.layers0 = layers0
self.layers1 = layers1
self.norm = norm
self.embedding = embedding
self.lmhead = lmhead
self.model.model.norm = self.norm
self.model.model.embed_tokens = self.embedding
self.model.lm_head = self.lmhead

def forward(
self,
@@ -176,3 +182,69 @@ class ShardedVicunaModel(torch.nn.Module):
attention_mask=attention_mask,
past_key_values=past_key_values,
)


class LMHead(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model

def forward(self, hidden_states):
output = self.model(hidden_states)
return output


class LMHeadCompiled(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module

def forward(self, hidden_states):
hidden_states = hidden_states.detach()
output = self.model("forward", (hidden_states,))
output = torch.tensor(output)
return output


class VicunaNorm(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model

def forward(self, hidden_states):
output = self.model(hidden_states)
return output


class VicunaNormCompiled(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module

def forward(self, hidden_states):
hidden_states.detach()
output = self.model("forward", (hidden_states,))
output = torch.tensor(output)
return output


class VicunaEmbedding(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model

def forward(self, input_ids):
output = self.model(input_ids)
return output


class VicunaEmbeddingCompiled(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module

def forward(self, input_ids):
input_ids.detach()
output = self.model("forward", (input_ids,))
output = torch.tensor(output)
return output
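Each of the new wrapper pairs above (LMHead/LMHeadCompiled, VicunaNorm/VicunaNormCompiled, VicunaEmbedding/VicunaEmbeddingCompiled) keeps the same forward() signature, so a SHARK-compiled module can replace the eager torch submodule inside ShardedVicunaModel without touching the caller. A minimal runnable sketch of that pattern, using a stand-in for the compiled module (the stand-in and its shapes are assumptions, not SHARK code):

import torch

class FakeSharkModule:
    # Stand-in for a SHARK module, which is invoked as module("forward", (inputs,))
    # and is assumed here to return plain arrays rather than torch tensors.
    def __init__(self, linear):
        self.linear = linear

    def __call__(self, func_name, inputs):
        (hidden_states,) = inputs
        with torch.no_grad():
            return self.linear(hidden_states).numpy()

class LMHeadCompiled(torch.nn.Module):
    def __init__(self, shark_module):
        super().__init__()
        self.model = shark_module

    def forward(self, hidden_states):
        hidden_states = hidden_states.detach()
        output = self.model("forward", (hidden_states,))
        return torch.tensor(output)

head = LMHeadCompiled(FakeSharkModule(torch.nn.Linear(4096, 32000)))
print(head(torch.randn(1, 7, 4096)).shape)  # torch.Size([1, 7, 32000])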
@@ -28,8 +28,9 @@ parser = argparse.ArgumentParser(
description="runs a falcon model",
)

parser.add_argument("--falcon_variant_to_use", default="7b", help="7b, 40b")
parser.add_argument(
"--precision", "-p", default="fp32", help="fp32, fp16, int8, int4"
"--precision", "-p", default="fp16", help="fp32, fp16, int8, int4"
)
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
parser.add_argument(
@@ -40,7 +41,12 @@ parser.add_argument(
default=None,
help="path to falcon's mlir file",
)

parser.add_argument(
"--use_precompiled_model",
default=True,
action=argparse.BooleanOptionalAction,
help="use the precompiled vmfb",
)
parser.add_argument(
"--load_mlir_from_shark_tank",
default=False,
@@ -59,12 +65,12 @@ class Falcon(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path="tiiuae/falcon-7b-instruct",
hf_model_path,
max_num_tokens=150,
device="cuda",
precision="fp32",
falcon_mlir_path=Path("falcon.mlir"),
falcon_vmfb_path=Path("falcon.vmfb"),
falcon_mlir_path=None,
falcon_vmfb_path=None,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_padding_length = 100
@@ -85,7 +91,7 @@ class Falcon(SharkLLMBase):
return tokenizer

def get_src_model(self):
print("Loading src model")
print("Loading src model: ", self.model_name)
kwargs = {"torch_dtype": torch.float, "trust_remote_code": True}
falcon_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
@@ -93,9 +99,26 @@ class Falcon(SharkLLMBase):
return falcon_model

def compile_falcon(self):
vmfb = get_vmfb_from_path(self.falcon_vmfb_path, self.device, "linalg")
if vmfb is not None:
return vmfb
if args.use_precompiled_model:
if not self.falcon_vmfb_path.exists():
# Downloading VMFB from shark_tank
download_public_file(
"gs://shark_tank/falcon/"
+ "falcon_"
+ args.falcon_variant_to_use
+ "_"
+ self.precision
+ "_"
+ self.device
+ ".vmfb",
self.falcon_vmfb_path.absolute(),
single_file=True,
)
vmfb = get_vmfb_from_path(
self.falcon_vmfb_path, self.device, "linalg"
)
if vmfb is not None:
return vmfb

print(
f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}. Trying to work with"
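compile_falcon() now tries, in order: a vmfb already on disk, a precompiled vmfb downloaded from shark_tank under falcon_<variant>_<precision>_<device>.vmfb, and only then compilation from MLIR. A condensed sketch of that fallback (load_vmfb and download are placeholders standing in for get_vmfb_from_path and download_public_file):

from pathlib import Path

def load_or_fetch_vmfb(vmfb_path: Path, url: str, load_vmfb, download):
    module = load_vmfb(vmfb_path)          # 1) reuse a local vmfb if it loads
    if module is not None:
        return module
    if not vmfb_path.exists():
        download(url, vmfb_path.absolute(), single_file=True)  # 2) fetch the precompiled artifact
    return load_vmfb(vmfb_path)            # 3) None here means: fall back to MLIR compilation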
@@ -106,27 +129,26 @@ class Falcon(SharkLLMBase):
bytecode = f.read()
else:
mlir_generated = False
if args.load_mlir_from_shark_tank:
if self.precision == "fp32":
# download MLIR from shark_tank for fp32
download_public_file(
"gs://shark_tank/falcon/7b/cuda/falcon.mlir",
self.falcon_mlir_path.absolute(),
single_file=True,
)
if self.falcon_mlir_path.exists():
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
f"MLIR not found at {self.falcon_mlir_path.absolute()}"
" after downloading! Please check path and try again"
)
else:
print(
"Only fp32 mlir added to tank, generating mlir on device."
)
# Downloading MLIR from shark_tank
download_public_file(
"gs://shark_tank/falcon/"
+ "falcon_"
+ args.falcon_variant_to_use
+ "_"
+ self.precision
+ ".mlir",
self.falcon_mlir_path.absolute(),
single_file=True,
)
if self.falcon_mlir_path.exists():
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
f"MLIR not found at {self.falcon_mlir_path.absolute()}"
" after downloading! Please check path and try again"
)

if not mlir_generated:
compilation_input_ids = torch.randint(
@@ -184,6 +206,7 @@ class Falcon(SharkLLMBase):
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
"--iree-spirv-index-bits=64",
],
)
print("Saved falcon vmfb at ", str(path))
@@ -192,17 +215,6 @@ class Falcon(SharkLLMBase):
return shark_module

def compile(self):
if (
not self.falcon_vmfb_path.exists()
and self.device == "cuda"
and self.precision == "fp32"
):
download_public_file(
"gs://shark_tank/falcon/7b/cuda/falcon.vmfb",
self.falcon_vmfb_path.absolute(),
single_file=True,
)

falcon_shark_model = self.compile_falcon()
return falcon_shark_model
@@ -375,6 +387,8 @@ class Falcon(SharkLLMBase):
(model_inputs["input_ids"], model_inputs["attention_mask"]),
)
)
if self.precision == "fp16":
outputs = outputs.to(dtype=torch.float32)
next_token_logits = outputs

# pre-process distribution
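The new fp16 branch above widens the compiled model's logits back to float32 before the distribution is processed, so the softmax and sampling math is not done in half precision. The same cast in isolation (shapes and values are illustrative):

import torch

logits_fp16 = torch.randn(1, 65024, dtype=torch.float16)  # vocabulary-sized logits
logits = logits_fp16.to(dtype=torch.float32)
next_token = torch.argmax(torch.softmax(logits, dim=-1), dim=-1)
print(next_token.item())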
@@ -428,18 +442,35 @@ if __name__ == "__main__":
args = parser.parse_args()

falcon_mlir_path = (
Path("falcon.mlir")
Path(
"falcon_"
+ args.falcon_variant_to_use
+ "_"
+ args.precision
+ ".mlir"
)
if args.falcon_mlir_path is None
else Path(args.falcon_mlir_path)
)
falcon_vmfb_path = (
Path("falcon.vmfb")
Path(
"falcon_"
+ args.falcon_variant_to_use
+ "_"
+ args.precision
+ "_"
+ args.device
+ ".vmfb"
)
if args.falcon_vmfb_path is None
else Path(args.falcon_vmfb_path)
)

falcon = Falcon(
"falcon",
"falcon_" + args.falcon_variant_to_use,
hf_model_path="tiiuae/falcon-"
+ args.falcon_variant_to_use
+ "-instruct",
device=args.device,
precision=args.precision,
falcon_mlir_path=falcon_mlir_path,
@@ -451,11 +482,16 @@ if __name__ == "__main__":
default_prompt_text = "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:"
continue_execution = True

print("\n-----\nScript executing for the following config: \n")
print("Falcon Model: ", falcon.model_name)
print("Precision: ", args.precision)
print("Device: ", args.device)

while continue_execution:
use_default_prompt = input(
"\nDo you wish to use the default prompt text? True or False?: "
"\nDo you wish to use the default prompt text? Y/N ?: "
)
if use_default_prompt:
if use_default_prompt in ["Y", "y"]:
prompt = default_prompt_text
else:
prompt = input("Please enter the prompt text: ")
@@ -469,5 +505,8 @@ if __name__ == "__main__":
res_str,
)
continue_execution = input(
"\nDo you wish to run script one more time? True or False?: "
"\nDo you wish to run script one more time? Y/N ?: "
)
continue_execution = (
True if continue_execution in ["Y", "y"] else False
)
@@ -10,7 +10,7 @@ from apps.language_models.utils import (
from io import BytesIO
from pathlib import Path
from shark.shark_downloader import download_public_file
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, get_f16_inputs
from shark.shark_inference import SharkInference
from transformers import AutoTokenizer, AutoModelForCausalLM

@@ -28,23 +28,48 @@ class Vicuna(SharkLLMBase):
max_num_tokens=512,
device="cuda",
precision="fp32",
first_vicuna_mlir_path=Path("first_vicuna.mlir"),
second_vicuna_mlir_path=Path("second_vicuna.mlir"),
first_vicuna_vmfb_path=Path("first_vicuna.vmfb"),
second_vicuna_vmfb_path=Path("second_vicuna.vmfb"),
first_vicuna_mlir_path=None,
second_vicuna_mlir_path=None,
first_vicuna_vmfb_path=None,
second_vicuna_vmfb_path=None,
load_mlir_from_shark_tank=True,
low_device_memory=False,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_length = 256
self.device = device
if precision in ["int4", "int8"]:
print("int4 and int8 are not supported yet, using fp32")
precision = "fp32"
self.precision = precision
self.first_vicuna_vmfb_path = first_vicuna_vmfb_path
self.second_vicuna_vmfb_path = second_vicuna_vmfb_path
self.first_vicuna_mlir_path = first_vicuna_mlir_path
self.second_vicuna_mlir_path = second_vicuna_mlir_path
self.load_mlir_from_shark_tank = load_mlir_from_shark_tank
self.low_device_memory = low_device_memory
self.first_vic = None
self.second_vic = None
if self.first_vicuna_mlir_path == None:
self.first_vicuna_mlir_path = self.get_model_path()
if self.second_vicuna_mlir_path == None:
self.second_vicuna_mlir_path = self.get_model_path("second")
if self.first_vicuna_vmfb_path == None:
self.first_vicuna_vmfb_path = self.get_model_path(suffix="vmfb")
if self.second_vicuna_vmfb_path == None:
self.second_vicuna_vmfb_path = self.get_model_path(
"second", "vmfb"
)
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
self.load_mlir_from_shark_tank = load_mlir_from_shark_tank

def get_model_path(self, model_number="first", suffix="mlir"):
safe_device = "_".join(self.device.split("-"))
if suffix == "mlir":
return Path(f"{model_number}_vicuna_{self.precision}.{suffix}")
return Path(
f"{model_number}_vicuna_{self.precision}_{safe_device}.{suffix}"
)

def get_tokenizer(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
@@ -69,7 +94,7 @@ class Vicuna(SharkLLMBase):
|
||||
# Compilation path needs some more work before it is functional
|
||||
|
||||
print(
|
||||
f"[DEBUG] vmfb not found at {self.first_vicuna_vmfb_path.absolute()}. Trying to work with"
|
||||
f"[DEBUG] vmfb not found at {self.first_vicuna_vmfb_path.absolute()}. Trying to work with\n"
|
||||
f"[DEBUG] mlir path { self.first_vicuna_mlir_path} {'exists' if self.first_vicuna_mlir_path.exists() else 'does not exist'}"
|
||||
)
|
||||
if self.first_vicuna_mlir_path.exists():
|
||||
@@ -78,10 +103,10 @@ class Vicuna(SharkLLMBase):
|
||||
else:
|
||||
mlir_generated = False
|
||||
if self.load_mlir_from_shark_tank:
|
||||
if self.precision == "fp32":
|
||||
# download MLIR from shark_tank for fp32
|
||||
if self.precision in ["fp32", "fp16"]:
|
||||
# download MLIR from shark_tank for fp32/fp16
|
||||
download_public_file(
|
||||
"gs://shark_tank/vicuna/unsharded/mlir/first_vicuna.mlir",
|
||||
f"gs://shark_tank/vicuna/unsharded/mlir/{self.first_vicuna_mlir_path.name}",
|
||||
self.first_vicuna_mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
@@ -96,7 +121,7 @@ class Vicuna(SharkLLMBase):
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"Only fp32 mlir added to tank, generating mlir on device."
|
||||
f"Only fp32 and fp16 mlir added to tank, generating {self.precision} mlir on device."
|
||||
)
|
||||
|
||||
if not mlir_generated:
|
||||
@@ -220,10 +245,10 @@ class Vicuna(SharkLLMBase):
|
||||
else:
|
||||
mlir_generated = False
|
||||
if self.load_mlir_from_shark_tank:
|
||||
if self.precision == "fp32":
|
||||
# download MLIR from shark_tank for fp32
|
||||
if self.precision in ["fp32", "fp16"]:
|
||||
# download MLIR from shark_tank for fp32/fp16
|
||||
download_public_file(
|
||||
"gs://shark_tank/vicuna/unsharded/mlir/second_vicuna.mlir",
|
||||
f"gs://shark_tank/vicuna/unsharded/mlir/{self.second_vicuna_mlir_path.name}",
|
||||
self.second_vicuna_mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
@@ -253,9 +278,15 @@ class Vicuna(SharkLLMBase):
|
||||
model,
|
||||
secondVicunaCompileInput,
|
||||
is_f16=self.precision == "fp16",
|
||||
f16_input_mask=[False, False],
|
||||
f16_input_mask=[False] + [True] * 64,
|
||||
mlir_type="torchscript",
|
||||
)
|
||||
if self.precision == "fp16":
|
||||
secondVicunaCompileInput = get_f16_inputs(
|
||||
secondVicunaCompileInput,
|
||||
True,
|
||||
f16_input_mask=[False] + [True] * 64,
|
||||
)
|
||||
secondVicunaCompileInput = list(secondVicunaCompileInput)
|
||||
for i in range(len(secondVicunaCompileInput)):
|
||||
if i != 0:
|
||||
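Together with the fp16 import, get_f16_inputs is applied with f16_input_mask=[False] + [True] * 64: the first compile input (token ids) is left untouched and the 64 past_key_value tensors are cast to half precision. A rough stand-alone equivalent of that masked cast (assumed behaviour of the helper, which actually lives in shark.shark_importer; shapes match the 1x32x?x128 key/value tensors seen in the MLIR):

import torch

def cast_masked_to_f16(inputs, f16_input_mask):
    # cast only the entries whose mask is True, leaving e.g. integer token ids alone
    return tuple(
        t.to(torch.float16) if use_f16 and t.is_floating_point() else t
        for t, use_f16 in zip(inputs, f16_input_mask)
    )

example = (torch.zeros(1, 1, dtype=torch.int64),) + tuple(
    torch.zeros(1, 32, 19, 128) for _ in range(64)
)
casted = cast_masked_to_f16(example, [False] + [True] * 64)
print(casted[0].dtype, casted[1].dtype)  # torch.int64 torch.float16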
@@ -307,7 +338,7 @@ class Vicuna(SharkLLMBase):
|
||||
if "%c19_i64 = arith.constant 19 : i64" in line:
|
||||
new_lines.append("%c2 = arith.constant 2 : index")
|
||||
new_lines.append(
|
||||
"%dim_4_int = tensor.dim %arg1, %c2 : tensor<1x32x?x128xf32>"
|
||||
f"%dim_4_int = tensor.dim %arg1, %c2 : tensor<1x32x?x128x{'f16' if self.precision == 'fp16' else 'f32'}>"
|
||||
)
|
||||
new_lines.append(
|
||||
"%dim_i64 = arith.index_cast %dim_4_int : index to i64"
|
||||
@@ -365,14 +396,19 @@ class Vicuna(SharkLLMBase):
|
||||
# download vmfbs for A100
|
||||
if (
|
||||
not self.first_vicuna_vmfb_path.exists()
|
||||
and self.device == "cuda"
|
||||
and self.precision == "fp32"
|
||||
and self.device in ["cuda", "cpu"]
|
||||
and self.precision in ["fp32", "fp16"]
|
||||
):
|
||||
download_public_file(
|
||||
"gs://shark_tank/vicuna/unsharded/first_vicuna.vmfb",
|
||||
self.first_vicuna_vmfb_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
# combinations that are still in the works
|
||||
if not (self.device == "cuda" and self.precision == "fp16"):
|
||||
# Will generate vmfb on device
|
||||
pass
|
||||
else:
|
||||
download_public_file(
|
||||
f"gs://shark_tank/vicuna/unsharded/vmfb/{self.first_vicuna_vmfb_path.name}",
|
||||
self.first_vicuna_vmfb_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
else:
|
||||
# get first vic
|
||||
# TODO: Remove after testing to avoid memory overload
|
||||
@@ -380,26 +416,25 @@ class Vicuna(SharkLLMBase):
|
||||
pass
|
||||
if (
|
||||
not self.second_vicuna_vmfb_path.exists()
|
||||
and self.device == "cuda"
|
||||
and self.precision == "fp32"
|
||||
and self.device in ["cuda", "cpu"]
|
||||
and self.precision in ["fp32", "fp16"]
|
||||
):
|
||||
download_public_file(
|
||||
"gs://shark_tank/vicuna/unsharded/second_vicuna.vmfb",
|
||||
self.second_vicuna_vmfb_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
# combinations that are still in the works
|
||||
if not (self.device == "cuda" and self.precision == "fp16"):
|
||||
# Will generate vmfb on device
|
||||
pass
|
||||
else:
|
||||
download_public_file(
|
||||
f"gs://shark_tank/vicuna/unsharded/vmfb/{self.second_vicuna_vmfb_path.name}",
|
||||
self.second_vicuna_vmfb_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
else:
|
||||
# get second vic
|
||||
# TODO: Remove after testing to avoid memory overload
|
||||
# svic_shark_model = self.compile_second_vicuna()
|
||||
pass
|
||||
|
||||
# get first vic
|
||||
# fvic_shark_model = self.compile_first_vicuna()
|
||||
# get second vic
|
||||
# svic_shark_model = self.compile_second_vicuna()
|
||||
# return tuple of shark_modules
|
||||
# return fvic_shark_model, svic_shark_model
|
||||
return None
|
||||
# return tuple of shark_modules once mem is supported
|
||||
# return fvic_shark_model, svic_shark_model
|
||||
@@ -408,12 +443,19 @@ class Vicuna(SharkLLMBase):
|
||||
# TODO: refactor for cleaner integration
|
||||
import gc
|
||||
|
||||
if not self.low_device_memory:
|
||||
if self.first_vic == None:
|
||||
self.first_vic = self.compile_first_vicuna()
|
||||
if self.second_vic == None:
|
||||
self.second_vic = self.compile_second_vicuna()
|
||||
res = []
|
||||
res_tokens = []
|
||||
params = {
|
||||
"prompt": prompt,
|
||||
"is_first": True,
|
||||
"fv": self.compile_first_vicuna(),
|
||||
"fv": self.compile_first_vicuna()
|
||||
if self.first_vic == None
|
||||
else self.first_vic,
|
||||
}
|
||||
|
||||
generated_token_op = self.generate_new_token(params=params)
|
||||
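The generate() changes above keep the compiled first/second Vicuna modules on self.first_vic and self.second_vic unless low_device_memory is set, so later calls do not recompile or re-download them. The caching pattern in isolation (compile_fn is a placeholder for compile_first_vicuna or compile_second_vicuna):

class CachedCompile:
    def __init__(self, compile_fn, low_device_memory=False):
        self.compile_fn = compile_fn
        self.low_device_memory = low_device_memory
        self._module = None

    def get(self):
        if self.low_device_memory:
            return self.compile_fn()          # rebuild every time so it can be freed afterwards
        if self._module is None:
            self._module = self.compile_fn()  # compile once, then reuse
        return self._module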
@@ -429,18 +471,20 @@ class Vicuna(SharkLLMBase):
|
||||
print(f"Assistant: {detok}", end=" ", flush=True)
|
||||
|
||||
# Clear First Vic from Memory (main and cuda)
|
||||
del params
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
if self.low_device_memory:
|
||||
del params
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
sec_vic = self.compile_second_vicuna()
|
||||
for _ in range(self.max_num_tokens - 2):
|
||||
params = {
|
||||
"prompt": None,
|
||||
"is_first": False,
|
||||
"logits": logits,
|
||||
"pkv": pkv,
|
||||
"sv": sec_vic,
|
||||
"sv": self.compile_second_vicuna()
|
||||
if self.second_vic == None
|
||||
else self.second_vic,
|
||||
}
|
||||
|
||||
generated_token_op = self.generate_new_token(params=params)
|
||||
@@ -461,9 +505,10 @@ class Vicuna(SharkLLMBase):
|
||||
res.append(detok)
|
||||
if cli:
|
||||
print(f"{detok}", end=" ", flush=True)
|
||||
del sec_vic, pkv, logits
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
if self.device == "cuda":
|
||||
del sec_vic, pkv, logits
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
for i in range(len(res_tokens)):
|
||||
if type(res_tokens[i]) != int:
|
||||
|
||||
@@ -4,6 +4,12 @@ from apps.language_models.src.model_wrappers.vicuna_sharded_model import (
|
||||
CompiledFirstVicunaLayer,
|
||||
CompiledSecondVicunaLayer,
|
||||
ShardedVicunaModel,
|
||||
LMHead,
|
||||
LMHeadCompiled,
|
||||
VicunaEmbedding,
|
||||
VicunaEmbeddingCompiled,
|
||||
VicunaNorm,
|
||||
VicunaNormCompiled,
|
||||
)
|
||||
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
|
||||
from shark.shark_importer import import_with_fx
|
||||
@@ -19,9 +25,11 @@ import re
|
||||
import torch
|
||||
import torch_mlir
|
||||
import os
|
||||
import json
|
||||
|
||||
|
||||
class Vicuna(SharkLLMBase):
|
||||
# Class representing Sharded Vicuna Model
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
@@ -29,21 +37,25 @@ class Vicuna(SharkLLMBase):
|
||||
max_num_tokens=512,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
config_json=None,
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
self.max_sequence_length = 256
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.shark_model = self.compile()
|
||||
self.config = config_json
|
||||
self.shark_model = self.compile(device=device)
|
||||
|
||||
def get_tokenizer(self):
|
||||
# Retrieve the tokenizer from Huggingface
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hf_model_path, use_fast=False
|
||||
)
|
||||
return tokenizer
|
||||
|
||||
def get_src_model(self):
|
||||
# Retrieve the torch model from Huggingface
|
||||
kwargs = {"torch_dtype": torch.float}
|
||||
vicuna_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, **kwargs
|
||||
@@ -51,6 +63,8 @@ class Vicuna(SharkLLMBase):
|
||||
return vicuna_model
|
||||
|
||||
def write_in_dynamic_inputs0(self, module, dynamic_input_size):
|
||||
# Current solution for ensuring mlir files support dynamic inputs
|
||||
# TODO find a more elegant way to implement this
|
||||
new_lines = []
|
||||
for line in module.splitlines():
|
||||
line = re.sub(f"{dynamic_input_size}x", "?x", line)
|
||||
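write_in_dynamic_inputs0/1 patch the emitted MLIR text so a module traced at a fixed sample length works for any length: every occurrence of the sample size in a tensor type is rewritten to a dynamic dimension. A one-line illustration with a made-up IR line:

import re

line = "%0 = tensor.empty() : tensor<1x137x4096xf32>"  # traced with SAMPLE_INPUT_LEN = 137
print(re.sub("137x", "?x", line))                      # tensor<1x?x4096xf32>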
@@ -107,6 +121,7 @@ class Vicuna(SharkLLMBase):
|
||||
past_key_value0=None,
|
||||
past_key_value1=None,
|
||||
):
|
||||
# Compile a hidden decoder layer of vicuna
|
||||
if past_key_value0 is None and past_key_value1 is None:
|
||||
model_inputs = (hidden_states, attention_mask, position_ids)
|
||||
else:
|
||||
@@ -126,7 +141,154 @@ class Vicuna(SharkLLMBase):
|
||||
)
|
||||
return mlir_bytecode
|
||||
|
||||
def compile_to_vmfb(self, inputs, layers, is_first=True):
|
||||
def get_device_index(self, layer_string):
|
||||
# Get the device index from the config file
|
||||
# In the event that different device indices are assigned to
|
||||
# different parts of a layer, a majority vote will be taken and
|
||||
# everything will be run on the most commonly used device
|
||||
if self.config is None:
|
||||
return None
|
||||
idx_votes = {}
|
||||
for key in self.config.keys():
|
||||
if re.search(layer_string, key):
|
||||
if int(self.config[key]["gpu"]) in idx_votes.keys():
|
||||
idx_votes[int(self.config[key]["gpu"])] += 1
|
||||
else:
|
||||
idx_votes[int(self.config[key]["gpu"])] = 1
|
||||
device_idx = max(idx_votes, key=idx_votes.get)
|
||||
return device_idx
|
||||
|
||||
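get_device_index() expects the sharding config to map component names (as matched by the regexes used below, e.g. first_vicuna.model.model.layers.<idx>...) to an entry with a "gpu" index, and picks the most common index per layer. A plausible config fragment and the same majority vote in isolation (the key names are illustrative, not a documented schema):

import json, re

config = json.loads("""
{
  "first_vicuna.model.model.layers.0.self_attn": {"gpu": "0"},
  "first_vicuna.model.model.layers.0.mlp": {"gpu": "0"},
  "first_vicuna.model.model.layers.1.self_attn": {"gpu": "1"}
}
""")

def majority_device(cfg, layer_pattern):
    votes = {}
    for key in cfg:
        if re.search(layer_pattern, key):
            idx = int(cfg[key]["gpu"])
            votes[idx] = votes.get(idx, 0) + 1
    return max(votes, key=votes.get) if votes else None

print(majority_device(config, r"first_vicuna\.model\.model\.layers\.0[\s.$]"))  # 0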
def compile_lmhead(
|
||||
self, lmh, hidden_states, device="cpu", device_idx=None
|
||||
):
|
||||
# compile the lm head of the vicuna model
|
||||
# This can be used for both first and second vicuna, so only needs to be run once
|
||||
mlir_path = Path(f"lmhead.mlir")
|
||||
vmfb_path = Path(f"lmhead.vmfb")
|
||||
if mlir_path.exists():
|
||||
f_ = open(mlir_path, "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
else:
|
||||
hidden_states = torch_mlir.TensorPlaceholder.like(
|
||||
hidden_states, dynamic_axes=[1]
|
||||
)
|
||||
|
||||
module = torch_mlir.compile(
|
||||
lmh,
|
||||
(hidden_states,),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
f_ = open(mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
f_.close()
|
||||
|
||||
shark_module = SharkInference(
|
||||
bytecode,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=device_idx,
|
||||
)
|
||||
if vmfb_path.exists():
|
||||
shark_module.load_module(vmfb_path)
|
||||
else:
|
||||
shark_module.save_module(module_name="lmhead")
|
||||
shark_module.load_module(vmfb_path)
|
||||
compiled_module = LMHeadCompiled(shark_module)
|
||||
return compiled_module
|
||||
|
||||
def compile_norm(self, fvn, hidden_states, device="cpu", device_idx=None):
|
||||
# compile the normalization layer of the vicuna model
|
||||
# This can be used for both first and second vicuna, so only needs to be run once
|
||||
mlir_path = Path(f"norm.mlir")
|
||||
vmfb_path = Path(f"norm.vmfb")
|
||||
if mlir_path.exists():
|
||||
f_ = open(mlir_path, "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
else:
|
||||
hidden_states = torch_mlir.TensorPlaceholder.like(
|
||||
hidden_states, dynamic_axes=[1]
|
||||
)
|
||||
|
||||
module = torch_mlir.compile(
|
||||
fvn,
|
||||
(hidden_states,),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
f_ = open(mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
f_.close()
|
||||
|
||||
shark_module = SharkInference(
|
||||
bytecode,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=device_idx,
|
||||
)
|
||||
if vmfb_path.exists():
|
||||
shark_module.load_module(vmfb_path)
|
||||
else:
|
||||
shark_module.save_module(module_name="norm")
|
||||
shark_module.load_module(vmfb_path)
|
||||
compiled_module = VicunaNormCompiled(shark_module)
|
||||
return compiled_module
|
||||
|
||||
def compile_embedding(self, fve, input_ids, device="cpu", device_idx=None):
|
||||
# compile the embedding layer of the vicuna model
|
||||
# This can be used for both first and second vicuna, so only needs to be run once
|
||||
mlir_path = Path(f"embedding.mlir")
|
||||
vmfb_path = Path(f"embedding.vmfb")
|
||||
if mlir_path.exists():
|
||||
f_ = open(mlir_path, "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
else:
|
||||
input_ids = torch_mlir.TensorPlaceholder.like(
|
||||
input_ids, dynamic_axes=[1]
|
||||
)
|
||||
module = torch_mlir.compile(
|
||||
fve,
|
||||
(input_ids,),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
f_ = open(mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
f_.close()
|
||||
|
||||
shark_module = SharkInference(
|
||||
bytecode,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=device_idx,
|
||||
)
|
||||
if vmfb_path.exists():
|
||||
shark_module.load_module(vmfb_path)
|
||||
else:
|
||||
shark_module.save_module(module_name="embedding")
|
||||
shark_module.load_module(vmfb_path)
|
||||
compiled_module = VicunaEmbeddingCompiled(shark_module)
|
||||
|
||||
return compiled_module
|
||||
|
||||
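compile_lmhead, compile_norm, and compile_embedding above all follow the same on-disk caching recipe: reuse <name>.mlir if it exists, otherwise lower with torch_mlir.compile and write it out, then wrap the bytecode in SharkInference and either load or save-and-load <name>.vmfb. A condensed sketch of that shared pattern (a simplification of the methods above, not a drop-in replacement):

from io import BytesIO
from pathlib import Path
import torch_mlir
from shark.shark_inference import SharkInference

def compile_and_cache(module, example_inputs, name, device="cpu", device_idx=None):
    mlir_path, vmfb_path = Path(f"{name}.mlir"), Path(f"{name}.vmfb")
    if mlir_path.exists():
        bytecode = mlir_path.read_bytes()                 # reuse cached MLIR
    else:
        compiled = torch_mlir.compile(
            module, example_inputs, torch_mlir.OutputType.LINALG_ON_TENSORS
        )
        stream = BytesIO()
        compiled.operation.write_bytecode(stream)
        bytecode = stream.getvalue()
        mlir_path.write_bytes(bytecode)                   # cache for the next run
    shark_module = SharkInference(
        bytecode, device=device, mlir_dialect="tm_tensor", device_idx=device_idx
    )
    if vmfb_path.exists():
        shark_module.load_module(vmfb_path)               # reuse cached vmfb
    else:
        shark_module.save_module(module_name=name)
        shark_module.load_module(vmfb_path)
    return shark_module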
def compile_to_vmfb(self, inputs, layers, device="cpu", is_first=True):
|
||||
# compile all layers for vmfb
|
||||
# this needs to be run seperatley for first and second vicuna
|
||||
mlirs, modules = [], []
|
||||
for idx, layer in tqdm(enumerate(layers), desc="Getting mlirs"):
|
||||
if is_first:
|
||||
@@ -198,10 +360,6 @@ class Vicuna(SharkLLMBase):
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# bytecode_stream = BytesIO()
|
||||
# module.operation.write_bytecode(bytecode_stream)
|
||||
# bytecode = bytecode_stream.getvalue()
|
||||
|
||||
if is_first:
|
||||
module = self.write_in_dynamic_inputs0(str(module), 137)
|
||||
bytecode = module.encode("UTF-8")
|
||||
@@ -210,19 +368,6 @@ class Vicuna(SharkLLMBase):
|
||||
|
||||
else:
|
||||
module = self.write_in_dynamic_inputs1(str(module), 138)
|
||||
if idx in [0, 5, 6, 7]:
|
||||
module_str = module
|
||||
module_str = module_str.splitlines()
|
||||
new_lines = []
|
||||
for line in module_str:
|
||||
if len(line) < 1000:
|
||||
new_lines.append(line)
|
||||
else:
|
||||
new_lines.append(line[:999])
|
||||
module_str = "\n".join(new_lines)
|
||||
f1_ = open(f"{idx}_1_test.mlir", "w+")
|
||||
f1_.write(module_str)
|
||||
f1_.close()
|
||||
|
||||
bytecode = module.encode("UTF-8")
|
||||
bytecode_stream = BytesIO(bytecode)
|
||||
@@ -236,20 +381,27 @@ class Vicuna(SharkLLMBase):
|
||||
for idx, layer in tqdm(enumerate(layers), desc="compiling modules"):
|
||||
if is_first:
|
||||
vmfb_path = Path(f"{idx}_0.vmfb")
|
||||
if idx < 25:
|
||||
device = "cpu"
|
||||
else:
|
||||
device = "cpu"
|
||||
if vmfb_path.exists():
|
||||
# print(f"Found layer {idx} vmfb")
|
||||
device_idx = self.get_device_index(
|
||||
f"first_vicuna.model.model.layers.{idx}[\s.$]"
|
||||
)
|
||||
module = SharkInference(
|
||||
None, device=device, mlir_dialect="tm_tensor"
|
||||
None,
|
||||
device=device,
|
||||
device_idx=device_idx,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
else:
|
||||
print(f"Compiling layer {idx} vmfb")
|
||||
device_idx = self.get_device_index(
|
||||
f"first_vicuna.model.model.layers.{idx}[\s.$]"
|
||||
)
|
||||
module = SharkInference(
|
||||
mlirs[idx], device=device, mlir_dialect="tm_tensor"
|
||||
mlirs[idx],
|
||||
device=device,
|
||||
device_idx=device_idx,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
module.save_module(
|
||||
module_name=f"{idx}_0",
|
||||
@@ -264,20 +416,28 @@ class Vicuna(SharkLLMBase):
|
||||
modules.append(module)
|
||||
else:
|
||||
vmfb_path = Path(f"{idx}_1.vmfb")
|
||||
if idx < 25:
|
||||
device = "cpu"
|
||||
else:
|
||||
device = "cpu"
|
||||
if vmfb_path.exists():
|
||||
# print(f"Found layer {idx} vmfb")
|
||||
device_idx = self.get_device_index(
|
||||
f"second_vicuna.model.model.layers.{idx}[\s.$]"
|
||||
)
|
||||
module = SharkInference(
|
||||
None, device=device, mlir_dialect="tm_tensor"
|
||||
None,
|
||||
device=device,
|
||||
device_idx=device_idx,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
else:
|
||||
print(f"Compiling layer {idx} vmfb")
|
||||
device_idx = self.get_device_index(
|
||||
f"second_vicuna.model.model.layers.{idx}[\s.$]"
|
||||
)
|
||||
module = SharkInference(
|
||||
mlirs[idx], device=device, mlir_dialect="tm_tensor"
|
||||
mlirs[idx],
|
||||
device=device,
|
||||
device_idx=device_idx,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
module.save_module(
|
||||
module_name=f"{idx}_1",
|
||||
@@ -293,7 +453,7 @@ class Vicuna(SharkLLMBase):
|
||||
|
||||
return mlirs, modules
|
||||
|
||||
def get_sharded_model(self):
|
||||
def get_sharded_model(self, device="cpu"):
|
||||
# SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an increadibly hacky proccess
|
||||
# please don't change it
|
||||
SAMPLE_INPUT_LEN = 137
|
||||
@@ -312,11 +472,50 @@ class Vicuna(SharkLLMBase):
|
||||
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
|
||||
)
|
||||
|
||||
norm = VicunaNorm(vicuna_model.model.norm)
|
||||
device_idx = self.get_device_index(
|
||||
r"vicuna\.model\.model\.norm(?:\.|\s|$)"
|
||||
)
|
||||
print(device_idx)
|
||||
norm = self.compile_norm(
|
||||
norm,
|
||||
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
|
||||
device=self.device,
|
||||
device_idx=device_idx,
|
||||
)
|
||||
|
||||
embeddings = VicunaEmbedding(vicuna_model.model.embed_tokens)
|
||||
device_idx = self.get_device_index(
|
||||
r"vicuna\.model\.model\.embed_tokens(?:\.|\s|$)"
|
||||
)
|
||||
print(device_idx)
|
||||
embeddings = self.compile_embedding(
|
||||
embeddings,
|
||||
(torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64)),
|
||||
device=self.device,
|
||||
device_idx=device_idx,
|
||||
)
|
||||
|
||||
lmhead = LMHead(vicuna_model.lm_head)
|
||||
device_idx = self.get_device_index(
|
||||
r"vicuna\.model\.lm_head(?:\.|\s|$)"
|
||||
)
|
||||
print(device_idx)
|
||||
lmhead = self.compile_lmhead(
|
||||
lmhead,
|
||||
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
|
||||
device=self.device,
|
||||
device_idx=device_idx,
|
||||
)
|
||||
|
||||
layers0 = [
|
||||
FirstVicunaLayer(layer) for layer in vicuna_model.model.layers
|
||||
]
|
||||
_, modules0 = self.compile_to_vmfb(
|
||||
placeholder_input0, layers0, is_first=True
|
||||
placeholder_input0,
|
||||
layers0,
|
||||
is_first=True,
|
||||
device=device,
|
||||
)
|
||||
shark_layers0 = [CompiledFirstVicunaLayer(m) for m in modules0]
|
||||
|
||||
@@ -324,17 +523,22 @@ class Vicuna(SharkLLMBase):
|
||||
SecondVicunaLayer(layer) for layer in vicuna_model.model.layers
|
||||
]
|
||||
_, modules1 = self.compile_to_vmfb(
|
||||
placeholder_input1, layers1, is_first=False
|
||||
placeholder_input1, layers1, is_first=False, device=device
|
||||
)
|
||||
shark_layers1 = [CompiledSecondVicunaLayer(m) for m in modules1]
|
||||
|
||||
sharded_model = ShardedVicunaModel(
|
||||
vicuna_model, shark_layers0, shark_layers1
|
||||
vicuna_model,
|
||||
shark_layers0,
|
||||
shark_layers1,
|
||||
lmhead,
|
||||
embeddings,
|
||||
norm,
|
||||
)
|
||||
return sharded_model
|
||||
|
||||
def compile(self):
|
||||
return self.get_sharded_model()
|
||||
def compile(self, device="cpu"):
|
||||
return self.get_sharded_model(device=device)
|
||||
|
||||
def generate(self, prompt, cli=False):
|
||||
# TODO: refactor for cleaner integration
|
||||
|
||||
@@ -17,6 +17,10 @@ from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
|
||||
|
||||
|
||||
def load_mlir_module():
|
||||
if "upscaler" in args.hf_model_id:
|
||||
is_upscaler = True
|
||||
else:
|
||||
is_upscaler = False
|
||||
sd_model = SharkifyStableDiffusionModel(
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
@@ -27,6 +31,7 @@ def load_mlir_module():
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
use_base_vae=args.use_base_vae,
|
||||
is_upscaler=is_upscaler,
|
||||
use_tuned=False,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
return_mlir=True,
|
||||
|
||||
@@ -61,6 +61,7 @@ def main():
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
|
||||
@@ -19,6 +19,7 @@ datas += copy_metadata('importlib_metadata')
|
||||
datas += copy_metadata('torch-mlir')
|
||||
datas += copy_metadata('omegaconf')
|
||||
datas += copy_metadata('safetensors')
|
||||
datas += copy_metadata('Pillow')
|
||||
datas += collect_data_files('diffusers')
|
||||
datas += collect_data_files('transformers')
|
||||
datas += collect_data_files('pytorch_lightning')
|
||||
|
||||
@@ -163,7 +163,7 @@ class SharkifyStableDiffusionModel:
|
||||
|
||||
def get_extended_name_for_all_model(self):
|
||||
model_name = {}
|
||||
sub_model_list = ["clip", "unet", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
|
||||
sub_model_list = ["clip", "unet", "unet512", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
|
||||
index = 0
|
||||
for model in sub_model_list:
|
||||
sub_model = model
|
||||
@@ -415,7 +415,7 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
return shark_cnet, cnet_mlir
|
||||
|
||||
def get_unet(self):
|
||||
def get_unet(self, use_large=False):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
|
||||
super().__init__()
|
||||
@@ -426,7 +426,7 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
if use_lora != "":
|
||||
update_lora_weight(self.unet, use_lora, "unet")
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.in_channels = self.unet.config.in_channels
|
||||
self.train(False)
|
||||
if(args.attention_slicing is not None and args.attention_slicing != "none"):
|
||||
if(args.attention_slicing.isdigit()):
|
||||
@@ -452,17 +452,27 @@ class SharkifyStableDiffusionModel:
|
||||
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
inputs = tuple(self.inputs["unet"])
|
||||
if(use_large):
|
||||
pad = (0, 0) * (len(inputs[2].shape) - 2)
|
||||
pad = pad + (0, 512 - inputs[2].shape[1])
|
||||
inputs = (inputs[0],
|
||||
inputs[1],
|
||||
torch.nn.functional.pad(inputs[2], pad),
|
||||
inputs[3])
|
||||
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet512"])
|
||||
else:
|
||||
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
|
||||
input_mask = [True, True, True, False]
|
||||
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
|
||||
if self.debug:
|
||||
os.makedirs(
|
||||
save_dir,
|
||||
exist_ok=True,
|
||||
)
|
||||
model_name = "unet512" if use_large else "unet"
|
||||
shark_unet, unet_mlir = compile_through_fx(
|
||||
unet,
|
||||
inputs,
|
||||
extended_model_name=self.model_name["unet"],
|
||||
extended_model_name=self.model_name[model_name],
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
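When use_large is set, the UNet's text-embedding input is right-padded along the token dimension to 512 before compilation, producing the separate unet512 artifact. The pad tuple is built for torch.nn.functional.pad, which consumes (before, after) pairs starting from the last dimension, so only the second-to-last (token) dimension is widened:

import torch
import torch.nn.functional as F

text_embeddings = torch.randn(2, 77, 768)        # (batch, tokens, channels); sizes are illustrative
pad = (0, 0) * (len(text_embeddings.shape) - 2)   # leave the trailing dims untouched
pad = pad + (0, 512 - text_embeddings.shape[1])   # right-pad the token dim up to 512
print(F.pad(text_embeddings, pad).shape)          # torch.Size([2, 512, 768])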
@@ -471,13 +481,13 @@ class SharkifyStableDiffusionModel:
|
||||
save_dir=save_dir,
|
||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
model_name="unet",
|
||||
model_name=model_name,
|
||||
precision=self.precision,
|
||||
return_mlir=self.return_mlir,
|
||||
)
|
||||
return shark_unet, unet_mlir
|
||||
|
||||
def get_unet_upscaler(self):
|
||||
def get_unet_upscaler(self, use_large=False):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
|
||||
super().__init__()
|
||||
@@ -502,6 +512,13 @@ class SharkifyStableDiffusionModel:
|
||||
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
inputs = tuple(self.inputs["unet"])
|
||||
if(use_large):
|
||||
pad = (0, 0) * (len(inputs[2].shape) - 2)
|
||||
pad = pad + (0, 512 - inputs[2].shape[1])
|
||||
inputs = (inputs[0],
|
||||
inputs[1],
|
||||
torch.nn.functional.pad(inputs[2], pad),
|
||||
inputs[3])
|
||||
input_mask = [True, True, True, False]
|
||||
shark_unet, unet_mlir = compile_through_fx(
|
||||
unet,
|
||||
@@ -579,16 +596,16 @@ class SharkifyStableDiffusionModel:
|
||||
vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
||||
return vae_dict
|
||||
|
||||
def compile_unet_variants(self, model):
|
||||
def compile_unet_variants(self, model, use_large=False):
|
||||
if model == "unet":
|
||||
if self.is_upscaler:
|
||||
return self.get_unet_upscaler()
|
||||
return self.get_unet_upscaler(use_large=use_large)
|
||||
# TODO: Plug the experimental "int8" support at right place.
|
||||
elif self.use_quantize == "int8":
|
||||
from apps.stable_diffusion.src.models.opt_params import get_unet
|
||||
return get_unet()
|
||||
else:
|
||||
return self.get_unet()
|
||||
return self.get_unet(use_large=use_large)
|
||||
else:
|
||||
return self.get_controlled_unet()
|
||||
|
||||
@@ -616,7 +633,7 @@ class SharkifyStableDiffusionModel:
|
||||
except Exception as e:
|
||||
sys.exit(e)
|
||||
|
||||
def unet(self):
|
||||
def unet(self, use_large=False):
|
||||
try:
|
||||
model = "stencil_unet" if self.use_stencil is not None else "unet"
|
||||
compiled_unet = None
|
||||
@@ -624,14 +641,14 @@ class SharkifyStableDiffusionModel:
|
||||
|
||||
if self.base_model_id != "":
|
||||
self.inputs["unet"] = self.get_input_info_for(unet_inputs[self.base_model_id])
|
||||
compiled_unet, unet_mlir = self.compile_unet_variants(model)
|
||||
compiled_unet, unet_mlir = self.compile_unet_variants(model, use_large=use_large)
|
||||
else:
|
||||
for model_id in unet_inputs:
|
||||
self.base_model_id = model_id
|
||||
self.inputs["unet"] = self.get_input_info_for(unet_inputs[model_id])
|
||||
|
||||
try:
|
||||
compiled_unet, unet_mlir = self.compile_unet_variants(model)
|
||||
compiled_unet, unet_mlir = self.compile_unet_variants(model, use_large=use_large)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("Retrying with a different base model configuration")
|
||||
|
||||
@@ -81,6 +81,7 @@ class Text2ImagePipeline(StableDiffusionPipeline):
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
@@ -112,7 +113,10 @@ class Text2ImagePipeline(StableDiffusionPipeline):
|
||||
|
||||
# Get text embeddings with weight emphasis from prompts
|
||||
text_embeddings = self.encode_prompts_weight(
|
||||
prompts, neg_prompts, max_length
|
||||
prompts,
|
||||
neg_prompts,
|
||||
max_length,
|
||||
max_embeddings_multiples=max_embeddings_multiples,
|
||||
)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
|
||||
@@ -57,6 +57,7 @@ class StableDiffusionPipeline:
|
||||
self.vae = None
|
||||
self.text_encoder = None
|
||||
self.unet = None
|
||||
self.unet_512 = None
|
||||
self.model_max_length = 77
|
||||
self.scheduler = scheduler
|
||||
# TODO: Implement using logging python utility.
|
||||
@@ -114,6 +115,24 @@ class StableDiffusionPipeline:
|
||||
del self.unet
|
||||
self.unet = None
|
||||
|
||||
def load_unet_512(self):
|
||||
if self.unet_512 is not None:
|
||||
return
|
||||
|
||||
if self.import_mlir or self.use_lora:
|
||||
self.unet_512 = self.sd_model.unet(use_large=True)
|
||||
else:
|
||||
try:
|
||||
self.unet_512 = get_unet(use_large=True)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("download pipeline failed, falling back to import_mlir")
|
||||
self.unet_512 = self.sd_model.unet(use_large=True)
|
||||
|
||||
def unload_unet_512(self):
|
||||
del self.unet_512
|
||||
self.unet_512 = None
|
||||
|
||||
def load_vae(self):
|
||||
if self.vae is not None:
|
||||
return
|
||||
@@ -203,7 +222,10 @@ class StableDiffusionPipeline:
|
||||
latent_history = [latents]
|
||||
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
self.load_unet()
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
self.load_unet()
|
||||
else:
|
||||
self.load_unet_512()
|
||||
for i, t in tqdm(enumerate(total_timesteps)):
|
||||
step_start_time = time.time()
|
||||
timestep = torch.tensor([t]).to(dtype).detach().numpy()
|
||||
@@ -222,16 +244,28 @@ class StableDiffusionPipeline:
|
||||
|
||||
# Profiling Unet.
|
||||
profile_device = start_profiling(file_path="unet.rdc")
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
else:
|
||||
noise_pred = self.unet_512(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
end_profiling(profile_device)
|
||||
|
||||
if cpu_scheduling:
|
||||
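At denoising time the pipeline now picks which compiled UNet to run from the embedding length: prompts that fit within model_max_length (77) use the regular UNet, while longer, padded embeddings go through unet_512. The dispatch in isolation (the pipeline attributes are assumed from the class above):

def select_unet(pipeline, text_embeddings, model_max_length=77):
    if text_embeddings.shape[1] <= model_max_length:
        pipeline.load_unet()
        return pipeline.unet
    pipeline.load_unet_512()
    return pipeline.unet_512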
@@ -254,6 +288,7 @@ class StableDiffusionPipeline:
|
||||
|
||||
if self.ondemand:
|
||||
self.unload_unet()
|
||||
self.unload_unet_512()
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
self.log += f"\nAverage step time: {avg_step_time}ms/it"
|
||||
|
||||
@@ -412,6 +447,11 @@ class StableDiffusionPipeline:
|
||||
# uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
if text_embeddings.shape[1] > model_max_length:
|
||||
pad = (0, 0) * (len(text_embeddings.shape) - 2)
|
||||
pad = pad + (0, 512 - text_embeddings.shape[1])
|
||||
text_embeddings = torch.nn.functional.pad(text_embeddings, pad)
|
||||
|
||||
# SHARK: Report clip inference time
|
||||
clip_inf_time = (time.time() - clip_inf_start) * 1000
|
||||
if self.ondemand:
|
||||
|
||||
@@ -37,4 +37,5 @@ from apps.stable_diffusion.src.utils.utils import (
|
||||
get_generation_text_info,
|
||||
update_lora_weight,
|
||||
resize_stencil,
|
||||
_compile_module,
|
||||
)
|
||||
|
||||
@@ -116,7 +116,7 @@ def load_lower_configs(base_model_id=None):
|
||||
else:
|
||||
config_name = f"{args.annotation_model}_{args.precision}_{device}_{spec}.json"
|
||||
else:
|
||||
if not spec or spec in ["rdna3", "sm_80"]:
|
||||
if not spec or spec in ["sm_80"]:
|
||||
if (
|
||||
version in ["v2_1", "v2_1base"]
|
||||
and args.height == 768
|
||||
@@ -125,6 +125,13 @@ def load_lower_configs(base_model_id=None):
|
||||
config_name = f"{args.annotation_model}_v2_1_768_{args.precision}_{device}.json"
|
||||
else:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
|
||||
elif spec in ["rdna3"] and version in [
|
||||
"v2_1",
|
||||
"v2_1base",
|
||||
"v1_4",
|
||||
"v1_5",
|
||||
]:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.max_length}_{args.precision}_{device}_{spec}_{args.width}x{args.height}.json"
|
||||
elif spec in ["rdna2"] and version in ["v2_1", "v2_1base", "v1_4"]:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}_{args.width}x{args.height}.json"
|
||||
else:
|
||||
|
||||
@@ -108,6 +108,13 @@ p.add_argument(
|
||||
help="max length of the tokenizer output, options are 64 and 77.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--max_embeddings_multiples",
|
||||
type=int,
|
||||
default=5,
|
||||
help="The max multiple length of prompt embeddings compared to the max output length of text encoder.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--strength",
|
||||
type=float,
|
||||
@@ -372,6 +379,13 @@ p.add_argument(
|
||||
help="Specify target triple for vulkan",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--iree_metal_target_platform",
|
||||
type=str,
|
||||
default="",
|
||||
help="Specify target triple for metal",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--vulkan_debug_utils",
|
||||
default=False,
|
||||
|
||||
@@ -18,6 +18,7 @@ from shark.iree_utils.vulkan_utils import (
|
||||
set_iree_vulkan_runtime_flags,
|
||||
get_vulkan_target_triple,
|
||||
)
|
||||
from shark.iree_utils.metal_utils import get_metal_target_triple
|
||||
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
|
||||
from apps.stable_diffusion.src.utils.stable_args import args
|
||||
from apps.stable_diffusion.src.utils.resources import opt_flags
|
||||
@@ -47,6 +48,7 @@ def get_vmfb_path_name(model_name):
|
||||
def _load_vmfb(shark_module, vmfb_path, model, precision):
|
||||
model = "vae" if "base_vae" in model or "vae_encode" in model else model
|
||||
model = "unet" if "stencil" in model else model
|
||||
model = "unet" if "unet512" in model else model
|
||||
precision = "fp32" if "clip" in model else precision
|
||||
extra_args = get_opt_flags(model, precision)
|
||||
shark_module.load_module(vmfb_path, extra_args=extra_args)
|
||||
@@ -115,6 +117,7 @@ def compile_through_fx(
|
||||
model_name=None,
|
||||
precision=None,
|
||||
return_mlir=False,
|
||||
device=None,
|
||||
):
|
||||
if not return_mlir and model_name is not None:
|
||||
vmfb_path = get_vmfb_path_name(extended_model_name)
|
||||
@@ -145,7 +148,10 @@ def compile_through_fx(
|
||||
if use_tuned:
|
||||
if "vae" in extended_model_name.split("_")[0]:
|
||||
args.annotation_model = "vae"
|
||||
if "unet" in model_name.split("_")[0]:
|
||||
if (
|
||||
"unet" in model_name.split("_")[0]
|
||||
or "unet_512" in model_name.split("_")[0]
|
||||
):
|
||||
args.annotation_model = "unet"
|
||||
mlir_module = sd_model_annotation(
|
||||
mlir_module, extended_model_name, base_model_id
|
||||
@@ -153,7 +159,7 @@ def compile_through_fx(
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
device=args.device,
|
||||
device=args.device if device is None else device,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
if generate_vmfb:
|
||||
@@ -269,6 +275,15 @@ def set_init_device_flags():
|
||||
)
|
||||
elif "cuda" in args.device:
|
||||
args.device = "cuda"
|
||||
elif "metal" in args.device:
|
||||
device_name, args.device = map_device_to_name_path(args.device)
|
||||
if not args.iree_metal_target_platform:
|
||||
triple = get_metal_target_triple(device_name)
|
||||
if triple is not None:
|
||||
args.iree_metal_target_platform = triple
|
||||
print(
|
||||
f"Found device {device_name}. Using target triple {args.iree_metal_target_platform}."
|
||||
)
|
||||
elif "cpu" in args.device:
|
||||
args.device = "cpu"
|
||||
|
||||
@@ -293,13 +308,18 @@ def set_init_device_flags():
|
||||
if (
|
||||
args.precision != "fp16"
|
||||
or args.height not in [512, 768]
|
||||
or (args.height == 512 and args.width != 512)
|
||||
or (args.height == 768 and args.width != 768)
|
||||
or (args.height == 512 and args.width not in [512, 768])
|
||||
or (args.height == 768 and args.width not in [512, 768])
|
||||
or args.batch_size != 1
|
||||
or ("vulkan" not in args.device and "cuda" not in args.device)
|
||||
):
|
||||
args.use_tuned = False
|
||||
|
||||
elif (
|
||||
args.height != args.width and "rdna2" in args.iree_vulkan_target_triple
|
||||
):
|
||||
args.use_tuned = False
|
||||
|
||||
elif base_model_id not in [
|
||||
"Linaqruf/anything-v3.0",
|
||||
"dreamlike-art/dreamlike-diffusion-1.0",
|
||||
@@ -421,9 +441,14 @@ def get_available_devices():
|
||||
available_devices = []
|
||||
vulkan_devices = get_devices_by_name("vulkan")
|
||||
available_devices.extend(vulkan_devices)
|
||||
metal_devices = get_devices_by_name("metal")
|
||||
available_devices.extend(metal_devices)
|
||||
cuda_devices = get_devices_by_name("cuda")
|
||||
available_devices.extend(cuda_devices)
|
||||
available_devices.append("device => cpu")
|
||||
cpu_device = get_devices_by_name("cpu-sync")
|
||||
available_devices.extend(cpu_device)
|
||||
cpu_device = get_devices_by_name("cpu-task")
|
||||
available_devices.extend(cpu_device)
|
||||
return available_devices
|
||||
|
||||
|
||||
@@ -732,6 +757,14 @@ def save_output_img(output_img, img_seed, extra_info={}):
|
||||
if args.ckpt_loc:
|
||||
img_model = Path(os.path.basename(args.ckpt_loc)).stem
|
||||
|
||||
img_vae = None
|
||||
if args.custom_vae:
|
||||
img_vae = Path(os.path.basename(args.custom_vae)).stem
|
||||
|
||||
img_lora = None
|
||||
if args.use_lora:
|
||||
img_lora = Path(os.path.basename(args.use_lora)).stem
|
||||
|
||||
if args.output_img_format == "jpg":
|
||||
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
|
||||
output_img.save(out_img_path, quality=95, subsampling=0)
|
||||
@@ -742,7 +775,9 @@ def save_output_img(output_img, img_seed, extra_info={}):
|
||||
if args.write_metadata_to_png:
|
||||
pngInfo.add_text(
|
||||
"parameters",
|
||||
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed}, Size: {args.width}x{args.height}, Model: {img_model}",
|
||||
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps: {args.steps},"
|
||||
f"Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed},"
|
||||
f"Size: {args.width}x{args.height}, Model: {img_model}, VAE: {img_vae}, LoRA: {img_lora}",
|
||||
)
|
||||
|
||||
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
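For context, the "parameters" text chunk written here can be read back with PIL when re-importing settings; a minimal sketch, assuming a hypothetical output file name:

from PIL import Image

# "sample_0.png" is a hypothetical file produced by save_output_img above.
img = Image.open("sample_0.png")
parameters_text = img.info.get("parameters", "")
print(parameters_text)  # prompt, negative prompt, steps, sampler, seed, size, model, VAE, LoRA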
|
||||
@@ -753,6 +788,9 @@ def save_output_img(output_img, img_seed, extra_info={}):
|
||||
"Image saved as png instead. Supported formats: png / jpg"
|
||||
)
|
||||
|
||||
# To be as low-impact as possible to the existing CSV format, we append
|
||||
# "VAE" and "LORA" to the end. However, it does not fit the hierarchy of
|
||||
# importance for each data point. Something to consider.
|
||||
new_entry = {
|
||||
"VARIANT": img_model,
|
||||
"SCHEDULER": args.scheduler,
|
||||
@@ -766,6 +804,8 @@ def save_output_img(output_img, img_seed, extra_info={}):
|
||||
"WIDTH": args.width,
|
||||
"MAX_LENGTH": args.max_length,
|
||||
"OUTPUT": out_img_path,
|
||||
"VAE": img_vae,
|
||||
"LORA": img_lora,
|
||||
}
|
||||
|
||||
new_entry.update(extra_info)
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
from multiprocessing import Process, freeze_support
|
||||
import os
|
||||
import sys
|
||||
import transformers # ensures inclusion in pyinstaller exe generation
|
||||
|
||||
if sys.platform == "darwin":
|
||||
# import before IREE to avoid torch-MLIR library issues
|
||||
import torch_mlir
|
||||
|
||||
import shutil
|
||||
import PIL, transformers # ensures inclusion in pyinstaller exe generation
|
||||
from apps.stable_diffusion.src import args, clear_all
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
|
||||
@@ -38,6 +44,7 @@ if __name__ == "__main__":
|
||||
img2img_api,
|
||||
upscaler_api,
|
||||
inpaint_api,
|
||||
outpaint_api,
|
||||
)
|
||||
from fastapi import FastAPI, APIRouter
|
||||
import uvicorn
|
||||
@@ -49,23 +56,25 @@ if __name__ == "__main__":
|
||||
app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"])
|
||||
app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"])
|
||||
app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
|
||||
# app.add_api_route(
|
||||
# "/sdapi/v1/outpaint", outpaint_api, methods=["post"]
|
||||
# )
|
||||
app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"])
|
||||
app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
|
||||
app.include_router(APIRouter())
|
||||
uvicorn.run(app, host="127.0.0.1", port=args.server_port)
|
||||
sys.exit(0)
|
||||
|
||||
import gradio as gr
|
||||
# Setup to use shark_tmp for gradio's temporary image files and clear any
|
||||
# existing temporary images there if they exist. Then we can import gradio.
|
||||
# It has to be in this order or gradio ignores what we've set up.
|
||||
from apps.stable_diffusion.web.utils.gradio_configs import (
|
||||
clear_gradio_tmp_imgs_folder,
|
||||
config_gradio_tmp_imgs_folder,
|
||||
)
|
||||
|
||||
config_gradio_tmp_imgs_folder()
|
||||
import gradio as gr
|
||||
|
||||
# Create custom models folders if they don't exist
|
||||
from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
|
||||
|
||||
# Clear all gradio tmp images from the last session
|
||||
clear_gradio_tmp_imgs_folder()
|
||||
# Create custom models folders if they don't exist
|
||||
create_custom_models_folders()
|
||||
|
||||
def resource_path(relative_path):
|
||||
|
||||
@@ -340,6 +340,10 @@ def img2img_api(
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
)
|
||||
|
||||
# Converts generator type to subscriptable
|
||||
res = next(res)
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
|
||||
@@ -278,7 +278,7 @@ def inpaint_api(
|
||||
custom_model="None",
|
||||
hf_model_id=InputData["hf_model_id"]
|
||||
if "hf_model_id" in InputData.keys()
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
else "stabilityai/stable-diffusion-2-inpainting",
|
||||
custom_vae="None",
|
||||
precision="fp16",
|
||||
device=available_devices[0],
|
||||
@@ -289,6 +289,10 @@ def inpaint_api(
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
)
|
||||
|
||||
# Converts generator type to subscriptable
|
||||
res = next(res)
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
|
||||
@@ -287,7 +287,7 @@ def outpaint_api(
|
||||
custom_model="None",
|
||||
hf_model_id=InputData["hf_model_id"]
|
||||
if "hf_model_id" in InputData.keys()
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
else "stabilityai/stable-diffusion-2-inpainting",
|
||||
custom_vae="None",
|
||||
precision="fp16",
|
||||
device=available_devices[0],
|
||||
@@ -298,6 +298,10 @@ def outpaint_api(
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
)
|
||||
|
||||
# Convert Generator to Subscriptable
|
||||
res = next(res)
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
|
||||
@@ -9,9 +9,6 @@ from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_todays_subdir,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import nodlogo_loc
|
||||
from apps.stable_diffusion.web.utils.gradio_configs import (
|
||||
gradio_tmp_galleries_folder,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.metadata import displayable_metadata
|
||||
|
||||
# -- Functions for file, directory and image info querying
|
||||
@@ -63,19 +60,6 @@ def output_subdirs() -> list[str]:
|
||||
return result_paths
|
||||
|
||||
|
||||
# clear zero length temporary files that gradio 3.22.0 buggily creates
|
||||
# TODO: remove once gradio is upgraded to or past 3.32.0
|
||||
def clear_zero_length_temps():
|
||||
zero_length_temps = [
|
||||
os.path.join(root, file)
|
||||
for root, dirs, files in os.walk(gradio_tmp_galleries_folder)
|
||||
for file in files
|
||||
if os.path.getsize(os.path.join(root, file)) == 0
|
||||
]
|
||||
for file in zero_length_temps:
|
||||
os.remove(file)
|
||||
|
||||
|
||||
# --- Define UI layout for Gradio
|
||||
|
||||
with gr.Blocks() as outputgallery_web:
|
||||
@@ -105,7 +89,6 @@ with gr.Blocks() as outputgallery_web:
|
||||
visible=False,
|
||||
show_label=True,
|
||||
).style(columns=4)
|
||||
gallery.DEFAULT_TEMP_DIR = gradio_tmp_galleries_folder
|
||||
|
||||
with gr.Column(scale=4):
|
||||
with gr.Box():
|
||||
@@ -179,7 +162,6 @@ with gr.Blocks() as outputgallery_web:
|
||||
# --- Event handlers
|
||||
|
||||
def on_clear_gallery():
|
||||
clear_zero_length_temps()
|
||||
return [
|
||||
gr.Gallery.update(
|
||||
value=[],
|
||||
@@ -247,7 +229,6 @@ with gr.Blocks() as outputgallery_web:
|
||||
|
||||
# only update if the current subdir is the most recent one as new images only go there
|
||||
if subdir_paths[0] == subdir:
|
||||
clear_zero_length_temps()
|
||||
new_images = outputgallery_filenames(subdir)
|
||||
new_label = f"{len(new_images)} images in {os.path.join(output_dir, subdir)} - {status}"
|
||||
|
||||
|
||||
@@ -41,17 +41,21 @@ def chat(curr_system_message, history, model, device, precision):
|
||||
|
||||
curr_system_message = start_message_vicuna
|
||||
if vicuna_model == 0:
|
||||
first_vic_vmfb_path = Path("first_vicuna.vmfb")
|
||||
second_vic_vmfb_path = Path("second_vicuna.vmfb")
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
elif "sync" in device:
|
||||
device = "cpu-sync"
|
||||
elif "task" in device:
|
||||
device = "cpu-task"
|
||||
elif "vulkan" in device:
|
||||
device = "vulkan"
|
||||
else:
|
||||
print("unrecognized device")
|
||||
vicuna_model = Vicuna(
|
||||
"vicuna",
|
||||
hf_model_path=model,
|
||||
device=device,
|
||||
precision=precision,
|
||||
first_vicuna_vmfb_path=first_vic_vmfb_path,
|
||||
second_vicuna_vmfb_path=second_vic_vmfb_path,
|
||||
)
|
||||
messages = curr_system_message + "".join(
|
||||
[
|
||||
@@ -120,9 +124,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
"TheBloke/vicuna-7B-1.1-HF",
|
||||
],
|
||||
)
|
||||
supported_devices = [
|
||||
device for device in available_devices if "cuda" in device
|
||||
]
|
||||
supported_devices = available_devices
|
||||
enabled = len(supported_devices) > 0
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
@@ -138,6 +140,8 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
"int4",
|
||||
"int8",
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
|
||||
@@ -34,6 +34,7 @@ from apps.stable_diffusion.src.utils import (
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, iree_metal_target_platform, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
init_iree_metal_target_platform = args.iree_metal_target_platform
|
||||
init_use_tuned = args.use_tuned
|
||||
init_import_mlir = args.import_mlir
|
||||
|
||||
@@ -137,6 +138,7 @@ def txt2img_inf(
|
||||
args.width = width
|
||||
args.device = device.split("=>", 1)[1].strip()
|
||||
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
|
||||
args.iree_metal_target_platform = init_iree_metal_target_platform
|
||||
args.use_tuned = init_use_tuned
|
||||
args.import_mlir = init_import_mlir
|
||||
args.img_path = None
|
||||
@@ -193,6 +195,7 @@ def txt2img_inf(
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
seeds.append(img_seed)
|
||||
total_time = time.time() - start_time
|
||||
@@ -262,6 +265,10 @@ def txt2img_api(
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
)
|
||||
|
||||
# Convert Generator to Subscriptable
|
||||
res = next(res)
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
@@ -298,7 +305,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
)
|
||||
txt2img_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
|
||||
placeholder="Select 'None' in the dropdown on the left and enter model ID here",
|
||||
value="",
|
||||
label="HuggingFace Model ID or Civitai model download URL",
|
||||
lines=3,
|
||||
@@ -550,6 +557,9 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
height,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
custom_vae,
|
||||
],
|
||||
outputs=[
|
||||
txt2img_png_info_img,
|
||||
@@ -563,5 +573,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
height,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
custom_vae,
|
||||
],
|
||||
)
|
||||
|
||||
@@ -299,6 +299,9 @@ def upscaler_api(
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
)
|
||||
# Converts generator type to subscriptable
|
||||
res = next(res)
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
|
||||
@@ -1,60 +1,54 @@
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import gradio
|
||||
from time import time
|
||||
|
||||
gradio_tmp_imgs_folder = os.path.join(os.getcwd(), "shark_tmp/")
|
||||
gradio_tmp_galleries_folder = os.path.join(gradio_tmp_imgs_folder, "galleries")
|
||||
shark_tmp = os.path.join(os.getcwd(), "shark_tmp/")
|
||||
|
||||
|
||||
# Clear all gradio tmp images
|
||||
def clear_gradio_tmp_imgs_folder():
|
||||
if not os.path.exists(gradio_tmp_imgs_folder):
|
||||
return
|
||||
def config_gradio_tmp_imgs_folder():
|
||||
# create shark_tmp if it does not exist
|
||||
if not os.path.exists(shark_tmp):
|
||||
os.mkdir(shark_tmp)
|
||||
|
||||
# tell gradio to use a directory under shark_tmp for its temporary
|
||||
# image files unless somewhere else has been set
|
||||
if "GRADIO_TEMP_DIR" not in os.environ:
|
||||
os.environ["GRADIO_TEMP_DIR"] = os.path.join(shark_tmp, "gradio")
|
||||
|
||||
# clear all gradio tmp files created by generation galleries
|
||||
print(
|
||||
"Clearing gradio temporary image files from a prior run. This may take some time..."
|
||||
f"gradio temporary image cache located at {os.environ['GRADIO_TEMP_DIR']}. "
|
||||
+ "You may change this by setting the GRADIO_TEMP_DIR environment variable."
|
||||
)
|
||||
image_files = [
|
||||
filename
|
||||
for filename in os.listdir(gradio_tmp_imgs_folder)
|
||||
if os.path.isfile(os.path.join(gradio_tmp_imgs_folder, filename))
|
||||
and filename.startswith("tmp")
|
||||
and filename.endswith(".png")
|
||||
]
|
||||
if len(image_files) > 0:
|
||||
|
||||
# Clear all gradio tmp images from the last session
|
||||
if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
|
||||
cleanup_start = time()
|
||||
for filename in image_files:
|
||||
os.remove(gradio_tmp_imgs_folder + filename)
|
||||
print(
|
||||
f"Clearing generation temporary image files took {time() - cleanup_start:4f} seconds"
|
||||
"Clearing gradio UI temporary image files from a prior run. This may take some time..."
|
||||
)
|
||||
else:
|
||||
print("no generation temporary files to clear")
|
||||
|
||||
# Clear all gradio tmp files created by output galleries
|
||||
if os.path.exists(gradio_tmp_galleries_folder):
|
||||
cleanup_start = time()
|
||||
shutil.rmtree(gradio_tmp_galleries_folder, ignore_errors=True)
|
||||
shutil.rmtree(os.environ["GRADIO_TEMP_DIR"], ignore_errors=True)
|
||||
print(
|
||||
f"Clearing output gallery temporary image files took {time() - cleanup_start:4f} seconds"
|
||||
f"Clearing gradio UI temporary image files took {time() - cleanup_start:.4f} seconds."
|
||||
)
|
||||
|
||||
# older SHARK versions had to workaround gradio bugs and stored things differently
|
||||
else:
|
||||
print("no output gallery temporary files to clear")
|
||||
|
||||
|
||||
# Overwrite save_pil_to_file from gradio to save tmp images generated by gradio into our own tmp folder
|
||||
def save_pil_to_file(pil_image, dir=None):
|
||||
if not os.path.exists(gradio_tmp_imgs_folder):
|
||||
os.mkdir(gradio_tmp_imgs_folder)
|
||||
file_obj = tempfile.NamedTemporaryFile(
|
||||
delete=False, suffix=".png", dir=gradio_tmp_imgs_folder
|
||||
)
|
||||
pil_image.save(file_obj)
|
||||
return file_obj
|
||||
|
||||
|
||||
# Register save_pil_to_file override
|
||||
gradio.processing_utils.save_pil_to_file = save_pil_to_file
|
||||
image_files = [
|
||||
filename
|
||||
for filename in os.listdir(shark_tmp)
|
||||
if os.path.isfile(os.path.join(shark_tmp, filename))
|
||||
and filename.startswith("tmp")
|
||||
and filename.endswith(".png")
|
||||
]
|
||||
if len(image_files) > 0:
|
||||
print(
|
||||
"Clearing temporary image files of a prior run of a previous SHARK version. This may take some time..."
|
||||
)
|
||||
cleanup_start = time()
|
||||
for filename in image_files:
|
||||
os.remove(shark_tmp + filename)
|
||||
print(
|
||||
f"Clearing temporary image files took {time() - cleanup_start:.4f} seconds."
|
||||
)
|
||||
else:
|
||||
print("No temporary images files to clear.")
|
||||
|
||||
@@ -62,6 +62,82 @@ def parse_generation_parameters(x: str):
|
||||
return res
|
||||
|
||||
|
||||
def try_find_model_base_from_png_metadata(
|
||||
file: str, folder: str = "models"
|
||||
) -> str:
|
||||
custom = ""
|
||||
|
||||
# Remove extension from file info
|
||||
if file.endswith(".safetensors") or file.endswith(".ckpt"):
|
||||
file = Path(file).stem
|
||||
# Check for the file name match with one of the local ckpt or safetensors files
|
||||
if Path(get_custom_model_pathfile(file + ".ckpt", folder)).is_file():
|
||||
custom = file + ".ckpt"
|
||||
if Path(
|
||||
get_custom_model_pathfile(file + ".safetensors", folder)
|
||||
).is_file():
|
||||
custom = file + ".safetensors"
|
||||
|
||||
return custom
|
||||
|
||||
|
||||
def find_model_from_png_metadata(
|
||||
key: str, metadata: dict[str, str | int]
|
||||
) -> tuple[str, str]:
|
||||
png_hf_id = ""
|
||||
png_custom = ""
|
||||
|
||||
if key in metadata:
|
||||
model_file = metadata[key]
|
||||
png_custom = try_find_model_base_from_png_metadata(model_file)
|
||||
# Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0")
|
||||
if model_file in predefined_models:
|
||||
png_custom = model_file
|
||||
# If nothing had matched, check vendor/hf_model_id
|
||||
if not png_custom and model_file.count("/"):
|
||||
png_hf_id = model_file
|
||||
# No matching model was found
|
||||
if not png_custom and not png_hf_id:
|
||||
print(
|
||||
"Import PNG info: Unable to find a matching model for %s"
|
||||
% model_file
|
||||
)
|
||||
|
||||
return png_custom, png_hf_id
|
||||
|
||||
|
||||
def find_vae_from_png_metadata(
|
||||
key: str, metadata: dict[str, str | int]
|
||||
) -> str:
|
||||
vae_custom = ""
|
||||
|
||||
if key in metadata:
|
||||
vae_file = metadata[key]
|
||||
vae_custom = try_find_model_base_from_png_metadata(vae_file, "vae")
|
||||
|
||||
# VAE input is optional, should not print or throw an error if missing
|
||||
|
||||
return vae_custom
|
||||
|
||||
|
||||
def find_lora_from_png_metadata(
|
||||
key: str, metadata: dict[str, str | int]
|
||||
) -> tuple[str, str]:
|
||||
lora_hf_id = ""
|
||||
lora_custom = ""
|
||||
|
||||
if key in metadata:
|
||||
lora_file = metadata[key]
|
||||
lora_custom = try_find_model_base_from_png_metadata(lora_file, "lora")
|
||||
# If nothing had matched, check vendor/hf_model_id
|
||||
if not lora_custom and lora_file.count("/"):
|
||||
lora_hf_id = lora_file
|
||||
|
||||
# LoRA input is optional, should not print or throw an error if missing
|
||||
|
||||
return lora_custom, lora_hf_id
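A usage sketch of the three helpers above, with an illustrative metadata dict (all values are made up):

# Illustrative only: this metadata mimics what parse_generation_parameters returns.
metadata = {
    "Model": "Linaqruf/anything-v3.0",
    "LoRA": "my_lora.safetensors",  # hypothetical local file under models/lora
    "VAE": "my_vae.ckpt",           # hypothetical local file under models/vae
}

png_custom_model, png_hf_model_id = find_model_from_png_metadata("Model", metadata)
lora_custom_model, lora_hf_model_id = find_lora_from_png_metadata("LoRA", metadata)
vae_custom_model = find_vae_from_png_metadata("VAE", metadata)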
|
||||
|
||||
|
||||
def import_png_metadata(
|
||||
pil_data,
|
||||
prompt,
|
||||
@@ -74,40 +150,21 @@ def import_png_metadata(
|
||||
height,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
custom_lora,
|
||||
hf_lora_id,
|
||||
custom_vae,
|
||||
):
|
||||
try:
|
||||
png_info = pil_data.info["parameters"]
|
||||
metadata = parse_generation_parameters(png_info)
|
||||
png_hf_model_id = ""
|
||||
png_custom_model = ""
|
||||
|
||||
if "Model" in metadata:
|
||||
# Remove extension from model info
|
||||
if metadata["Model"].endswith(".safetensors") or metadata[
|
||||
"Model"
|
||||
].endswith(".ckpt"):
|
||||
metadata["Model"] = Path(metadata["Model"]).stem
|
||||
# Check for the model name match with one of the local ckpt or safetensors files
|
||||
if Path(
|
||||
get_custom_model_pathfile(metadata["Model"] + ".ckpt")
|
||||
).is_file():
|
||||
png_custom_model = metadata["Model"] + ".ckpt"
|
||||
if Path(
|
||||
get_custom_model_pathfile(metadata["Model"] + ".safetensors")
|
||||
).is_file():
|
||||
png_custom_model = metadata["Model"] + ".safetensors"
|
||||
# Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0")
|
||||
if metadata["Model"] in predefined_models:
|
||||
png_custom_model = metadata["Model"]
|
||||
# If nothing had matched, check vendor/hf_model_id
|
||||
if not png_custom_model and metadata["Model"].count("/"):
|
||||
png_hf_model_id = metadata["Model"]
|
||||
# No matching model was found
|
||||
if not png_custom_model and not png_hf_model_id:
|
||||
print(
|
||||
"Import PNG info: Unable to find a matching model for %s"
|
||||
% metadata["Model"]
|
||||
)
|
||||
(png_custom_model, png_hf_model_id) = find_model_from_png_metadata(
|
||||
"Model", metadata
|
||||
)
|
||||
(lora_custom_model, lora_hf_model_id) = find_lora_from_png_metadata(
|
||||
"LoRA", metadata
|
||||
)
|
||||
vae_custom_model = find_vae_from_png_metadata("VAE", metadata)
|
||||
|
||||
negative_prompt = metadata["Negative prompt"]
|
||||
steps = int(metadata["Steps"])
|
||||
@@ -115,12 +172,24 @@ def import_png_metadata(
|
||||
seed = int(metadata["Seed"])
|
||||
width = float(metadata["Size-1"])
|
||||
height = float(metadata["Size-2"])
|
||||
|
||||
if "Model" in metadata and png_custom_model:
|
||||
custom_model = png_custom_model
|
||||
hf_model_id = ""
|
||||
if "Model" in metadata and png_hf_model_id:
|
||||
custom_model = "None"
|
||||
hf_model_id = png_hf_model_id
|
||||
|
||||
if "LoRA" in metadata and lora_custom_model:
|
||||
custom_lora = lora_custom_model
|
||||
hf_lora_id = ""
|
||||
if "LoRA" in metadata and lora_hf_model_id:
|
||||
custom_lora = "None"
|
||||
hf_lora_id = lora_hf_model_id
|
||||
|
||||
if "VAE" in metadata and vae_custom_model:
|
||||
custom_vae = vae_custom_model
|
||||
|
||||
if "Prompt" in metadata:
|
||||
prompt = metadata["Prompt"]
|
||||
if "Sampler" in metadata:
|
||||
@@ -149,4 +218,7 @@ def import_png_metadata(
|
||||
height,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
custom_lora,
|
||||
hf_lora_id,
|
||||
custom_vae,
|
||||
)
|
||||
|
||||
@@ -16,7 +16,7 @@ parameterized
|
||||
|
||||
# Add transformers, diffusers and scipy since they are most commonly used
|
||||
transformers
|
||||
diffusers @ git+https://github.com/huggingface/diffusers@e47459c80f6f6a5a1c19d32c3fd74edf94f47aa2
|
||||
diffusers
|
||||
scipy
|
||||
ftfy
|
||||
gradio==3.34.0
|
||||
|
||||
243
rest_api_tests/api_test.py
Normal file
@@ -0,0 +1,243 @@
|
||||
import requests
|
||||
from PIL import Image
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
def upscaler_test():
|
||||
# Define values here
|
||||
prompt = ""
|
||||
negative_prompt = ""
|
||||
seed = 2121991605
|
||||
height = 512
|
||||
width = 512
|
||||
steps = 50
|
||||
noise_level = 10
|
||||
cfg_scale = 7
|
||||
image_path = r"./rest_api_tests/dog.png"
|
||||
|
||||
# Converting Image to base64
|
||||
img_file = open(image_path, "rb")
|
||||
init_images = [
|
||||
"data:image/png;base64," + base64.b64encode(img_file.read()).decode()
|
||||
]
|
||||
|
||||
url = "http://127.0.0.1:8080/sdapi/v1/upscaler"
|
||||
|
||||
headers = {
|
||||
"User-Agent": "PythonTest",
|
||||
"Accept": "*/*",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
}
|
||||
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"seed": seed,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"steps": steps,
|
||||
"noise_level": noise_level,
|
||||
"cfg_scale": cfg_scale,
|
||||
"init_images": init_images,
|
||||
}
|
||||
|
||||
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
|
||||
|
||||
print(f"response from server was : {res.status_code}")
|
||||
|
||||
|
||||
def img2img_test():
|
||||
# Define values here
|
||||
prompt = "Paint a rabbit riding on the dog"
|
||||
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
|
||||
seed = 2121991605
|
||||
height = 512
|
||||
width = 512
|
||||
steps = 50
|
||||
denoising_strength = 0.75
|
||||
cfg_scale = 7
|
||||
image_path = r"./rest_api_tests/dog.png"
|
||||
|
||||
# Converting Image to Base64
|
||||
img_file = open(image_path, "rb")
|
||||
init_images = [
|
||||
"data:image/png;base64," + base64.b64encode(img_file.read()).decode()
|
||||
]
|
||||
|
||||
url = "http://127.0.0.1:8080/sdapi/v1/img2img"
|
||||
|
||||
headers = {
|
||||
"User-Agent": "PythonTest",
|
||||
"Accept": "*/*",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
}
|
||||
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"init_images": init_images,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"steps": steps,
|
||||
"denoising_strength": denoising_strength,
|
||||
"cfg_scale": cfg_scale,
|
||||
"seed": seed,
|
||||
}
|
||||
|
||||
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
|
||||
|
||||
print(f"response from server was : {res.status_code}")
|
||||
|
||||
# NOTE Uncomment below to save the picture
|
||||
|
||||
# print("Extracting response object")
|
||||
# response_obj = res.json()
|
||||
# img_b64 = response_obj.get("images", [False])[0] or response_obj.get(
|
||||
# "image"
|
||||
# )
|
||||
# img_b2 = base64.b64decode(img_b64.replace("data:image/png;base64,", ""))
|
||||
# im_file = BytesIO(img_b2)
|
||||
# response_img = Image.open(im_file)
|
||||
# print("Saving Response Image to: response_img")
|
||||
# response_img.save(r"rest_api_tests/response_img.png")
|
||||
|
||||
|
||||
def inpainting_test():
|
||||
prompt = "Paint a rabbit riding on the dog"
|
||||
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
|
||||
seed = 2121991605
|
||||
height = 512
|
||||
width = 512
|
||||
steps = 50
|
||||
noise_level = 10
|
||||
cfg_scale = 7
|
||||
is_full_res = False
|
||||
full_res_padding = 32
|
||||
image_path = r"./rest_api_tests/dog.png"
|
||||
|
||||
img_file = open(image_path, "rb")
|
||||
image = (
|
||||
"data:image/png;base64," + base64.b64encode(img_file.read()).decode()
|
||||
)
|
||||
img_file = open(image_path, "rb")
|
||||
mask = (
|
||||
"data:image/png;base64," + base64.b64encode(img_file.read()).decode()
|
||||
)
|
||||
|
||||
url = "http://127.0.0.1:8080/sdapi/v1/inpaint"
|
||||
|
||||
headers = {
|
||||
"User-Agent": "PythonTest",
|
||||
"Accept": "*/*",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
}
|
||||
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"image": image,
|
||||
"mask": mask,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"steps": steps,
|
||||
"noise_level": noise_level,
|
||||
"cfg_scale": cfg_scale,
|
||||
"seed": seed,
|
||||
"is_full_res": is_full_res,
|
||||
"full_res_padding": full_res_padding,
|
||||
}
|
||||
|
||||
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
|
||||
|
||||
print(f"[Inpainting] response from server was : {res.status_code}")
|
||||
|
||||
|
||||
def outpainting_test():
|
||||
prompt = "Paint a rabbit riding on the dog"
|
||||
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
|
||||
seed = 2121991605
|
||||
height = 512
|
||||
width = 512
|
||||
steps = 50
|
||||
cfg_scale = 7
|
||||
color_variation = 0.2
|
||||
noise_q = 0.2
|
||||
directions = ["up", "down", "right", "left"]
|
||||
pixels = 32
|
||||
mask_blur = 64
|
||||
image_path = r"./rest_api_tests/dog.png"
|
||||
|
||||
# Converting Image to Base64
|
||||
img_file = open(image_path, "rb")
|
||||
init_images = [
|
||||
"data:image/png;base64," + base64.b64encode(img_file.read()).decode()
|
||||
]
|
||||
|
||||
url = "http://127.0.0.1:8080/sdapi/v1/outpaint"
|
||||
|
||||
headers = {
|
||||
"User-Agent": "PythonTest",
|
||||
"Accept": "*/*",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
}
|
||||
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"seed": seed,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"steps": steps,
|
||||
"cfg_scale": cfg_scale,
|
||||
"color_variation": color_variation,
|
||||
"noise_q": noise_q,
|
||||
"directions": directions,
|
||||
"pixels": pixels,
|
||||
"mask_blur": mask_blur,
|
||||
"init_images": init_images,
|
||||
}
|
||||
|
||||
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
|
||||
|
||||
print(f"[Outpaint] response from server was : {res.status_code}")
|
||||
|
||||
|
||||
def txt2img_test():
|
||||
prompt = "Paint a rabbit in a top hat"
|
||||
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
|
||||
seed = 2121991605
|
||||
height = 512
|
||||
width = 512
|
||||
steps = 50
|
||||
cfg_scale = 7
|
||||
|
||||
url = "http://127.0.0.1:8080/sdapi/v1/txt2img"
|
||||
|
||||
headers = {
|
||||
"User-Agent": "PythonTest",
|
||||
"Accept": "*/*",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
}
|
||||
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"seed": seed,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"steps": steps,
|
||||
"cfg_scale": cfg_scale,
|
||||
}
|
||||
|
||||
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
|
||||
|
||||
print(f"[txt2img] response from server was : {res.status_code}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
txt2img_test()
|
||||
img2img_test()
|
||||
upscaler_test()
|
||||
inpainting_test()
|
||||
outpainting_test()
|
||||
BIN
rest_api_tests/dog.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 4.5 KiB |
2
setup.py
@@ -39,7 +39,7 @@ setup(
|
||||
install_requires=[
|
||||
"numpy",
|
||||
"PyYAML",
|
||||
"torch-mlir>=20221021.633",
|
||||
"torch-mlir==20230620.875",
|
||||
]
|
||||
+ backend_deps,
|
||||
)
|
||||
|
||||
@@ -89,7 +89,7 @@ else {python -m venv .\shark.venv\}
|
||||
python -m pip install --upgrade pip
|
||||
pip install wheel
|
||||
pip install -r requirements.txt
|
||||
pip install --pre torch-mlir torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
|
||||
pip install --pre torch-mlir==20230620.875 torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
|
||||
pip install --upgrade -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html iree-compiler iree-runtime
|
||||
Write-Host "Building SHARK..."
|
||||
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
|
||||
@@ -27,6 +27,11 @@ PYTHON_VERSION_X_Y=`${PYTHON} -c 'import sys; version=sys.version_info[:2]; prin
|
||||
echo "Python: $PYTHON"
|
||||
echo "Python version: $PYTHON_VERSION_X_Y"
|
||||
|
||||
if [ "$PYTHON_VERSION_X_Y" != "3.11" ]; then
|
||||
echo "Error: Python version 3.11 is required."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ "$SKIP_VENV" != "1" ]]; then
|
||||
if [[ -z "${CONDA_PREFIX}" ]]; then
|
||||
# Not a conda env. So create a new VENV dir
|
||||
@@ -83,7 +88,7 @@ if [ "$torch_mlir_bin" = true ]; then
|
||||
echo "MacOS detected. Installing torch-mlir from .whl, to avoid dependency problems with torch."
|
||||
$PYTHON -m pip install --pre --no-cache-dir torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/
|
||||
else
|
||||
$PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
|
||||
$PYTHON -m pip install --pre torch-mlir==20230620.875 -f https://llvm.github.io/torch-mlir/package-index/
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully Installed torch-mlir"
|
||||
else
|
||||
|
||||
76
shark/examples/shark_inference/mega_test.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import torch
|
||||
import torch_mlir
|
||||
from shark.shark_inference import SharkInference
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
compile_through_fx,
|
||||
args,
|
||||
)
|
||||
from MEGABYTE_pytorch import MEGABYTE
|
||||
|
||||
import os
|
||||
|
||||
|
||||
class MegaModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = MEGABYTE(
|
||||
num_tokens=16000, # number of tokens
|
||||
dim=(
|
||||
512,
|
||||
256,
|
||||
), # transformer model dimension (512 for coarsest, 256 for fine in this example)
|
||||
max_seq_len=(
|
||||
1024,
|
||||
4,
|
||||
), # sequence length for global and then local. this can be more than 2
|
||||
depth=(
|
||||
6,
|
||||
4,
|
||||
), # number of layers for global and then local. this can be more than 2, but length must match the max_seq_len's
|
||||
dim_head=64, # dimension per head
|
||||
heads=8, # number of attention heads
|
||||
flash_attn=True, # use flash attention
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
return self.model(input)
|
||||
|
||||
|
||||
megaModel = MegaModel()
|
||||
input = [torch.randint(0, 16000, (1, 1024, 4))]
|
||||
|
||||
# CURRENTLY IT BAILS OUT HERE BECAUSE OF MISSING OP LOWERINGS :-
|
||||
# 1. aten.alias
|
||||
shark_module, _ = compile_through_fx(
|
||||
megaModel,
|
||||
inputs=input,
|
||||
extended_model_name="mega_shark",
|
||||
debug=False,
|
||||
generate_vmfb=True,
|
||||
save_dir=os.getcwd(),
|
||||
extra_args=[],
|
||||
base_model_id=None,
|
||||
model_name="mega_shark",
|
||||
precision=None,
|
||||
return_mlir=True,
|
||||
device="cuda",
|
||||
)
|
||||
# logits = model(x)
|
||||
|
||||
|
||||
def print_output_info(output, msg):
|
||||
print("\n", msg)
|
||||
print("\n\t", output.shape)
|
||||
|
||||
|
||||
ans = shark_module("forward", input)
|
||||
print_output_info(torch.from_numpy(ans), "SHARK's output")
|
||||
|
||||
ans = megaModel.forward(*input)
|
||||
print_output_info(ans, "ORIGINAL Model's output")
|
||||
|
||||
# and sample from the logits accordingly
|
||||
# or you can use the generate function
|
||||
|
||||
# NEED TO LOOK AT THIS LATER IF REQUIRED IN SHARK.
|
||||
# sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 1024, 4)
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
# Copyright 2023 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -63,10 +63,11 @@ def get_supported_device_list():
|
||||
_IREE_DEVICE_MAP = {
|
||||
"cpu": "local-task",
|
||||
"cpu-task": "local-task",
|
||||
"AMD-AIE": "local-task",
|
||||
"cpu-sync": "local-sync",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan",
|
||||
"metal": "vulkan",
|
||||
"metal": "metal",
|
||||
"rocm": "rocm",
|
||||
"intel-gpu": "level_zero",
|
||||
}
|
||||
@@ -81,10 +82,11 @@ def iree_target_map(device):
|
||||
_IREE_TARGET_MAP = {
|
||||
"cpu": "llvm-cpu",
|
||||
"cpu-task": "llvm-cpu",
|
||||
"AMD-AIE": "llvm-cpu",
|
||||
"cpu-sync": "llvm-cpu",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan",
|
||||
"metal": "vulkan",
|
||||
"metal": "metal",
|
||||
"rocm": "rocm",
|
||||
"intel-gpu": "opencl-spirv",
|
||||
}
|
||||
@@ -101,11 +103,13 @@ def check_device_drivers(device):
|
||||
subprocess.check_output("nvidia-smi")
|
||||
except Exception:
|
||||
return True
|
||||
elif device in ["metal", "vulkan"]:
|
||||
elif device in ["vulkan"]:
|
||||
try:
|
||||
subprocess.check_output("vulkaninfo")
|
||||
except Exception:
|
||||
return True
|
||||
elif device == "metal":
|
||||
return False
|
||||
elif device in ["intel-gpu"]:
|
||||
try:
|
||||
subprocess.check_output(["dpkg", "-L", "intel-level-zero-gpu"])
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
# Copyright 2023 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -19,6 +19,8 @@ from shark.parser import shark_args
|
||||
import numpy as np
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# Get the iree-compile arguments given device.
|
||||
@@ -38,17 +40,26 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
if device_uri[0] == "cpu":
|
||||
from shark.iree_utils.cpu_utils import get_iree_cpu_args
|
||||
|
||||
return get_iree_cpu_args()
|
||||
data_tiling_flag = ["--iree-flow-enable-data-tiling"]
|
||||
u_kernel_flag = ["--iree-llvmcpu-enable-microkernels"]
|
||||
|
||||
return get_iree_cpu_args() + data_tiling_flag + u_kernel_flag
|
||||
if device_uri[0] == "cuda":
|
||||
from shark.iree_utils.gpu_utils import get_iree_gpu_args
|
||||
|
||||
return get_iree_gpu_args()
|
||||
if device_uri[0] in ["metal", "vulkan"]:
|
||||
if device_uri[0] == "vulkan":
|
||||
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
|
||||
|
||||
return get_iree_vulkan_args(
|
||||
device_num=device_num, extra_args=extra_args
|
||||
)
|
||||
if device_uri[0] == "metal":
|
||||
from shark.iree_utils.metal_utils import get_iree_metal_args
|
||||
|
||||
return get_iree_metal_args(
|
||||
device_num=device_num, extra_args=extra_args
|
||||
)
|
||||
if device_uri[0] == "rocm":
|
||||
from shark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
|
||||
@@ -175,8 +186,10 @@ def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
|
||||
vmfb_file.close()
|
||||
|
||||
config = get_iree_runtime_config(device)
|
||||
vm_module = ireert.VmModule.from_flatbuffer(
|
||||
config.vm_instance, flatbuffer_blob
|
||||
vm_module = ireert.VmModule.from_buffer(
|
||||
config.vm_instance,
|
||||
flatbuffer_blob,
|
||||
warn_if_copy=False,
|
||||
)
|
||||
|
||||
benchmark_cl = build_benchmark_args_non_tensor_input(
|
||||
@@ -307,8 +320,8 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
|
||||
config = ireert.Config(device=haldevice)
|
||||
else:
|
||||
config = get_iree_runtime_config(device)
|
||||
vm_module = ireert.VmModule.from_flatbuffer(
|
||||
config.vm_instance, flatbuffer_blob
|
||||
vm_module = ireert.VmModule.from_buffer(
|
||||
config.vm_instance, flatbuffer_blob, warn_if_copy=False
|
||||
)
|
||||
ctx = ireert.SystemContext(config=config)
|
||||
ctx.add_vm_module(vm_module)
|
||||
@@ -316,6 +329,58 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
|
||||
return ModuleCompiled, config
|
||||
|
||||
|
||||
def load_vmfb_using_mmap(
|
||||
flatbuffer_blob_or_path, device: str, device_idx: int = None
|
||||
):
|
||||
instance = ireert.VmInstance()
|
||||
device = iree_device_map(device)
|
||||
haldriver = ireert.get_driver(device)
|
||||
haldevice = haldriver.create_device_by_uri(
|
||||
device,
|
||||
allocators=[],
|
||||
)
|
||||
# First get configs.
|
||||
if device_idx is not None:
|
||||
device = iree_device_map(device)
|
||||
print("registering device id: ", device_idx)
|
||||
haldriver = ireert.get_driver(device)
|
||||
|
||||
haldevice = haldriver.create_device(
|
||||
haldriver.query_available_devices()[device_idx]["device_id"],
|
||||
allocators=shark_args.device_allocator,
|
||||
)
|
||||
config = ireert.Config(device=haldevice)
|
||||
else:
|
||||
config = get_iree_runtime_config(device)
|
||||
# Now load vmfb.
|
||||
# Two scenarios we have here :-
|
||||
# 1. We either have the vmfb already saved and therefore pass the path of it.
|
||||
# (This would arise if we're invoking `load_module` from a SharkInference obj)
|
||||
# OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
|
||||
# (This would arise if we're invoking `compile` from a SharkInference obj)
|
||||
temp_file_to_unlink = None
|
||||
|
||||
if isinstance(flatbuffer_blob_or_path, Path):
|
||||
flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__()
|
||||
if (
|
||||
isinstance(flatbuffer_blob_or_path, str)
|
||||
and ".vmfb" in flatbuffer_blob_or_path
|
||||
):
|
||||
vmfb_file_path = flatbuffer_blob_or_path
|
||||
mmaped_vmfb = ireert.VmModule.mmap(instance, flatbuffer_blob_or_path)
|
||||
ctx = ireert.SystemContext(config=config)
|
||||
ctx.add_vm_module(mmaped_vmfb)
|
||||
mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
|
||||
else:
|
||||
with tempfile.NamedTemporaryFile(delete=False) as tf:
|
||||
tf.write(flatbuffer_blob_or_path)
|
||||
tf.flush()
|
||||
vmfb_file_path = tf.name
|
||||
temp_file_to_unlink = vmfb_file_path
|
||||
mmaped_vmfb = ireert.VmModule.mmap(instance, vmfb_file_path)
|
||||
return mmaped_vmfb, config, temp_file_to_unlink
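A minimal sketch of the two call paths described in the comments above (file name and device are assumptions):

# Scenario 1: a .vmfb already saved to disk (the path is hypothetical).
vmfb, config, tmp = load_vmfb_using_mmap("unet_fp16.vmfb", device="vulkan")

# Scenario 2: a flatbuffer blob produced on the fly by compile_module_to_flatbuffer;
# a temporary file is created so it can be mmapped, and its path is returned for cleanup.
# vmfb, config, tmp = load_vmfb_using_mmap(flatbuffer_blob, device="vulkan")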
|
||||
|
||||
|
||||
def get_iree_compiled_module(
|
||||
module,
|
||||
device: str,
|
||||
@@ -323,19 +388,58 @@ def get_iree_compiled_module(
|
||||
model_config_path: str = None,
|
||||
extra_args: list = [],
|
||||
device_idx: int = None,
|
||||
mmap: bool = False,
|
||||
):
|
||||
"""Given a module returns the compiled .vmfb and configs"""
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(
|
||||
module, device, frontend, model_config_path, extra_args
|
||||
)
|
||||
return get_iree_module(flatbuffer_blob, device, device_idx=device_idx)
|
||||
temp_file_to_unlink = None
|
||||
# TODO: Currently mmap=True control flow path has been switched off for mmap.
|
||||
# Got to find a cleaner way to unlink/delete the temporary file since
|
||||
# we're setting delete=False when creating NamedTemporaryFile. That's why
|
||||
# I'm getting hold of the name of the temporary file in `temp_file_to_unlink`.
|
||||
if mmap:
|
||||
print(f"Will load the compiled module as a mmapped temporary file")
|
||||
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
|
||||
flatbuffer_blob, device, device_idx
|
||||
)
|
||||
else:
|
||||
vmfb, config = get_iree_module(
|
||||
flatbuffer_blob, device, device_idx=device_idx
|
||||
)
|
||||
ret_params = {
|
||||
"vmfb": vmfb,
|
||||
"config": config,
|
||||
"temp_file_to_unlink": temp_file_to_unlink,
|
||||
}
|
||||
return ret_params
|
||||
|
||||
|
||||
def load_flatbuffer(flatbuffer_path: str, device: str, device_idx: int = None):
|
||||
with open(os.path.join(flatbuffer_path), "rb") as f:
|
||||
flatbuffer_blob = f.read()
|
||||
|
||||
return get_iree_module(flatbuffer_blob, device, device_idx=device_idx)
|
||||
def load_flatbuffer(
|
||||
flatbuffer_path: str,
|
||||
device: str,
|
||||
device_idx: int = None,
|
||||
mmap: bool = False,
|
||||
):
|
||||
temp_file_to_unlink = None
|
||||
if mmap:
|
||||
print(f"Loading flatbuffer at {flatbuffer_path} as a mmapped file")
|
||||
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
|
||||
flatbuffer_path, device, device_idx
|
||||
)
|
||||
else:
|
||||
with open(os.path.join(flatbuffer_path), "rb") as f:
|
||||
flatbuffer_blob = f.read()
|
||||
vmfb, config = get_iree_module(
|
||||
flatbuffer_blob, device, device_idx=device_idx
|
||||
)
|
||||
ret_params = {
|
||||
"vmfb": vmfb,
|
||||
"config": config,
|
||||
"temp_file_to_unlink": temp_file_to_unlink,
|
||||
}
|
||||
return ret_params
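Callers of load_flatbuffer now receive a dict rather than a tuple; a hedged sketch of consuming it (path and device are assumptions):

params = load_flatbuffer("unet_fp16.vmfb", device="vulkan", mmap=True)
vmfb = params["vmfb"]
config = params["config"]
tmp = params["temp_file_to_unlink"]  # None unless a temporary copy had to be created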
|
||||
|
||||
|
||||
def export_iree_module_to_vmfb(
|
||||
|
||||
121
shark/iree_utils/metal_utils.py
Normal file
@@ -0,0 +1,121 @@
|
||||
# Copyright 2023 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# All the iree_metal related functionalities go here.
|
||||
|
||||
from shark.iree_utils._common import run_cmd
|
||||
import iree.runtime as ireert
|
||||
from sys import platform
|
||||
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
|
||||
|
||||
|
||||
def get_metal_device_name(device_num=0):
|
||||
iree_device_dump = run_cmd("iree-run-module --dump_devices")
|
||||
iree_device_dump = iree_device_dump[0].split("\n\n")
|
||||
metal_device_list = [
|
||||
s.split("\n#")[2] for s in iree_device_dump if "--device=metal" in s
|
||||
]
|
||||
if len(metal_device_list) == 0:
|
||||
raise ValueError("No device name found in device dump!")
|
||||
if len(metal_device_list) > 1:
|
||||
print("Following devices found:")
|
||||
for i, dname in enumerate(metal_device_list):
|
||||
print(f"{i}. {dname}")
|
||||
print(f"Choosing device: {metal_device_list[device_num]}")
|
||||
return metal_device_list[device_num]
|
||||
|
||||
|
||||
def get_os_name():
|
||||
if platform.startswith("linux"):
|
||||
return "linux"
|
||||
elif platform == "darwin":
|
||||
return "macos"
|
||||
elif platform == "win32":
|
||||
return "windows"
|
||||
else:
|
||||
print("Cannot detect OS type, defaulting to linux.")
|
||||
return "linux"
|
||||
|
||||
|
||||
def get_metal_target_triple(device_name):
"""This method provides a target triple str for the specified metal device.

Args:
device_name (str): name of the hardware device to be used with metal

Returns:
str or None: target triple or None if no match found for given name
"""
|
||||
# Apple Targets
|
||||
if all(x in device_name for x in ("Apple", "M1")):
|
||||
triple = "m1-moltenvk-macos"
|
||||
elif all(x in device_name for x in ("Apple", "M2")):
|
||||
triple = "m1-moltenvk-macos"
|
||||
|
||||
else:
|
||||
triple = None
|
||||
return triple
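Illustrative behaviour of the mapping above (device names are examples):

# Both M1 and M2 family names currently resolve to the same MoltenVK triple.
assert get_metal_target_triple("Apple M2 Max") == "m1-moltenvk-macos"
assert get_metal_target_triple("Apple M1") == "m1-moltenvk-macos"
assert get_metal_target_triple("Unknown GPU") is None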
|
||||
|
||||
|
||||
def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
|
||||
for flag in extra_args:
|
||||
if "-iree-metal-target-platform=" in flag:
|
||||
print(f"Using target triple {flag.split('=')[1]}")
|
||||
return None
|
||||
|
||||
if device_name == "" or device_name == [] or device_name is None:
|
||||
metal_device = get_metal_device_name(device_num=device_num)
|
||||
else:
|
||||
metal_device = device_name
|
||||
triple = get_metal_target_triple(metal_device)
|
||||
if triple is not None:
|
||||
print(
|
||||
f"Found metal device {metal_device}. Using metal target triple {triple}"
|
||||
)
|
||||
return f"-iree-metal-target-platform={triple}"
|
||||
print(
|
||||
"""Optimized kernel for your target device is not added yet.
|
||||
Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]
|
||||
or pull up an issue."""
|
||||
)
|
||||
print(f"Target : {metal_device}")
|
||||
return None
|
||||
|
||||
|
||||
def get_iree_metal_args(device_num=0, extra_args=[]):
|
||||
# res_metal_flag = ["--iree-flow-demote-i64-to-i32"]
|
||||
|
||||
res_metal_flag = []
|
||||
metal_triple_flag = None
|
||||
for arg in extra_args:
|
||||
if "-iree-metal-target-platform=" in arg:
|
||||
print(f"Using target triple {arg} from command line args")
|
||||
metal_triple_flag = arg
|
||||
break
|
||||
|
||||
if metal_triple_flag is None:
|
||||
metal_triple_flag = get_metal_triple_flag(
|
||||
device_num=device_num, extra_args=extra_args
|
||||
)
|
||||
|
||||
if metal_triple_flag is not None:
|
||||
vulkan_target_env = get_vulkan_target_env_flag(metal_triple_flag)
|
||||
res_metal_flag.append(vulkan_target_env)
|
||||
return res_metal_flag
|
||||
|
||||
|
||||
def set_iree_metal_runtime_flags(flags):
|
||||
for flag in flags:
|
||||
ireert.flags.parse_flags(flag)
|
||||
return
|
||||
@@ -60,12 +60,15 @@ def download_public_file(
|
||||
else:
|
||||
continue
|
||||
|
||||
destination_filename = os.path.join(destination_folder_name, blob_name)
|
||||
if os.path.isdir(destination_filename):
|
||||
continue
|
||||
with open(destination_filename, "wb") as f:
|
||||
with tqdm.wrapattr(f, "write", total=blob.size) as file_obj:
|
||||
storage_client.download_blob_to_file(blob, file_obj)
|
||||
else:
|
||||
destination_filename = os.path.join(
|
||||
destination_folder_name, blob_name
|
||||
)
|
||||
if os.path.isdir(destination_filename):
|
||||
continue
|
||||
with open(destination_filename, "wb") as f:
|
||||
with tqdm.wrapattr(f, "write", total=blob.size) as file_obj:
|
||||
storage_client.download_blob_to_file(blob, file_obj)
|
||||
|
||||
|
||||
input_type_to_np_dtype = {
|
||||
|
||||
206
shark/shark_eager/shark_eager.py
Normal file
@@ -0,0 +1,206 @@
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from collections import defaultdict
|
||||
from shark.shark_importer import import_with_fx
|
||||
import torchvision.models as models
|
||||
import copy
|
||||
import io
|
||||
import numpy as np
|
||||
import sys
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch.fx.node import Node
|
||||
from typing import Dict
|
||||
import torch_mlir
|
||||
|
||||
|
||||
def shark_backend(fx_g: torch.fx.GraphModule, inputs, device: str = "cpu"):
|
||||
mlir_module = torch_mlir.compile(
|
||||
fx_g, inputs, output_type="linalg-on-tensors"
|
||||
)
|
||||
bytecode_stream = io.BytesIO()
|
||||
mlir_module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
shark_module.compile(extra_args=[])
|
||||
return shark_module
|
||||
|
||||
|
||||
def _make_single_op_gm(node, captured_val, compiled_graph):
|
||||
"""Make a GraphModule that just executes the given node."""
|
||||
g = torch.fx.Graph()
|
||||
env = {}
|
||||
inputs = []
|
||||
for arg in node.args:
|
||||
if arg and hasattr(arg, "name"):
|
||||
env[arg.name] = g.placeholder(arg.name)
|
||||
if isinstance(captured_val[arg.name], (list, tuple)):
|
||||
for val in captured_val[arg.name]:
|
||||
inputs.append(val)
|
||||
else:
|
||||
inputs.append(captured_val[arg.name])
|
||||
|
||||
call = g.node_copy(node, lambda n: env[n.name])
|
||||
g.output(call)
|
||||
g.lint()
|
||||
single_node = torch.fx.GraphModule(torch.nn.Module(), g)
|
||||
compiled_module = shark_backend(single_node, inputs)
|
||||
compiled_graph[node.name] = {
|
||||
"module": compiled_module,
|
||||
"inputs": [i for i in env],
|
||||
"result": None,
|
||||
}
|
||||
return
|
||||
|
||||
|
||||
def compiled_graph(gm: torch.fx.GraphModule, attr_info):
|
||||
compiled_graph = {}
|
||||
g = gm.graph
|
||||
for node in g.nodes:
|
||||
if node.op == "call_function":
|
||||
if not (
|
||||
node.target in [torch.ops.aten.empty]
|
||||
or node.name.startswith("getitem")
|
||||
):
|
||||
_make_single_op_gm(node, attr_info, compiled_graph)
|
||||
|
||||
# Currently torch.aten.empty has an compilation issue, so running natively.
|
||||
elif node.target in [torch.ops.aten.empty]:
|
||||
compiled_graph[node.name] = {
|
||||
"target": node.target,
|
||||
"args": node.args,
|
||||
"kwargs": node.kwargs,
|
||||
"result": None,
|
||||
}
|
||||
# Get item is a simple case takes a tuple and return the tensor at a particular index.
|
||||
elif node.name.startswith("getitem"):
|
||||
compiled_graph[node.name] = {
|
||||
"input": node.args[0].name,
|
||||
"pos": node.args[1],
|
||||
"result": None,
|
||||
}
|
||||
|
||||
return compiled_graph
|
||||
|
||||
|
||||
class ShapeProp:
|
||||
"""
|
||||
Shape propagation. This class takes a `GraphModule`.
|
||||
Then, its `propagate` method executes the `GraphModule`
|
||||
node-by-node with the given arguments. As each operation
|
||||
executes, the ShapeProp class stores away the shape and
|
||||
element type for the output values of each operation on
|
||||
the `shape` and `dtype` attributes of the operation's
|
||||
`Node`.
|
||||
"""
|
||||
|
||||
def __init__(self, mod):
|
||||
self.mod = mod
|
||||
self.graph = mod.graph
|
||||
self.modules = dict(self.mod.named_modules())
|
||||
|
||||
def propagate(self, *args):
|
||||
args_iter = iter(args)
|
||||
env: Dict[str, Node] = {}
|
||||
|
||||
def load_arg(a):
|
||||
return torch.fx.graph.map_arg(a, lambda n: env[n.name])
|
||||
|
||||
def fetch_attr(target: str):
|
||||
target_atoms = target.split(".")
|
||||
attr_itr = self.mod
|
||||
for i, atom in enumerate(target_atoms):
|
||||
if not hasattr(attr_itr, atom):
|
||||
raise RuntimeError(
f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
|
||||
)
|
||||
attr_itr = getattr(attr_itr, atom)
|
||||
return attr_itr
|
||||
|
||||
for node in self.graph.nodes:
|
||||
if node.op == "placeholder":
|
||||
result = next(args_iter)
|
||||
elif node.op == "get_attr":
|
||||
result = fetch_attr(node.target)
|
||||
elif node.op == "call_function":
|
||||
result = node.target(
|
||||
*load_arg(node.args), **load_arg(node.kwargs)
|
||||
)
|
||||
elif node.op == "call_method":
|
||||
self_obj, *args = load_arg(node.args)
|
||||
kwargs = load_arg(node.kwargs)
|
||||
result = getattr(self_obj, node.target)(*args, **kwargs)
|
||||
elif node.op == "call_module":
|
||||
result = self.modules[node.target](
|
||||
*load_arg(node.args), **load_arg(node.kwargs)
|
||||
)
|
||||
|
||||
# This is the only code specific to shape propagation.
|
||||
# you can delete this `if` branch and this becomes
|
||||
# a generic GraphModule interpreter.
|
||||
if isinstance(result, torch.Tensor):
|
||||
node.shape = result.shape
|
||||
node.dtype = result.dtype
|
||||
|
||||
env[node.name] = result
|
||||
|
||||
return env
|
||||
|
||||
# return load_arg(self.graph.result)
|
||||
|
||||
|
||||
resnet18 = models.resnet18(pretrained=True)
|
||||
resnet18.train(False)
|
||||
input = (torch.randn(1, 3, 224, 224),)
|
||||
|
||||
print(resnet18(input[0]))
|
||||
|
||||
fx_graph = import_with_fx(resnet18, input, mlir_type="fx")
|
||||
|
||||
shape_prop = ShapeProp(fx_graph)
|
||||
|
||||
x = shape_prop.propagate(input[0])
|
||||
|
||||
shark_graph = compiled_graph(fx_graph, x)
|
||||
|
||||
|
||||
for key in shark_graph:
|
||||
if key.startswith("getitem"):
|
||||
input_val = shark_graph[key]["input"]
|
||||
pos = shark_graph[key]["pos"]
|
||||
if input_val not in shark_graph:
|
||||
shark_graph[key]["result"] = x[input_val][pos].detach()
|
||||
else:
|
||||
shark_graph[key]["result"] = shark_graph[input_val]["result"][
|
||||
pos
|
||||
].detach()
|
||||
elif key.startswith("empty"):
|
||||
operator = shark_graph[key]["target"]
|
||||
args = shark_graph[key]["args"]
|
||||
kwargs = shark_graph[key]["kwargs"]
|
||||
shark_graph[key]["result"] = operator(*args, **kwargs).detach()
|
||||
else:
|
||||
input_val = shark_graph[key]["inputs"]
|
||||
input_tensors = []
|
||||
for input in input_val:
|
||||
if input not in shark_graph:
|
||||
input_tensors.append(x[input].detach())
|
||||
else:
|
||||
input_tensors.append(shark_graph[input]["result"])
|
||||
|
||||
val = shark_graph[key]["module"]("forward", input_tensors)
|
||||
if isinstance(val, (tuple, list)):
|
||||
list_val = []
|
||||
for v in val:
|
||||
list_val.append(torch.from_numpy(v))
|
||||
shark_graph[key]["result"] = list_val
|
||||
else:
|
||||
shark_graph[key]["result"] = torch.from_numpy(val)
|
||||
|
||||
|
||||
print(shark_graph)
|
||||
@@ -1,5 +1,8 @@
import re
import json
from collections import OrderedDict
import torch_mlir
from iree.compiler import compile_str
from shark.shark_importer import import_with_fx, get_f16_inputs


class GenerateConfigFile:
@@ -7,7 +10,9 @@ class GenerateConfigFile:
        self,
        model,
        num_sharding_stages: int,
        sharding_stages_id: list[str] = None,
        sharding_stages_id: list[str],
        model_input=None,
        config_file_path="model_config.json",
    ):
        self.model = model
        self.num_sharding_stages = num_sharding_stages
@@ -15,8 +20,67 @@ class GenerateConfigFile:
        assert self.num_sharding_stages == len(
            self.sharding_stages_id
        ), "Number of sharding stages should be equal to the list of their ID"
        self.model_input = model_input
        self.config_file_path = config_file_path

    def generate_json(self):
    def split_into_dispatches(
        self,
        backend,
        fx_tracing_required=True,
        f16_model=False,
        torch_mlir_tracing=False,
    ):
        graph_for_compilation = self.model
        if fx_tracing_required:
            graph_for_compilation = import_with_fx(
                self.model,
                self.model_input,
                is_f16=f16_model,
                f16_input_mask=[False, False],
                mlir_type="torchscript",
            )

        module = torch_mlir.compile(
            graph_for_compilation,
            (self.model_input),
            torch_mlir.OutputType.LINALG_ON_TENSORS,
            use_tracing=torch_mlir_tracing,
            verbose=False,
        )
        module = module.operation.get_asm(large_elements_limit=4)
        compiled_module_str = str(
            compile_str(
                str(module),
                target_backends=[backend],
                extra_args=[
                    "--compile-to=flow",
                    "--mlir-elide-elementsattrs-if-larger=4",
                ],
            )
        )

        substring_start_idx = [
            m.start()
            for m in re.finditer("flow.dispatch @", compiled_module_str)
        ]
        dispatch_list = dict()

        # dispatch_no is the 'i'th index of a dispatch out of n total dispatches of a model
        # dispatch_id is the unique id of a dispatch, multiple instances of the same dispatch
        # can occur in a model
        for dispatch_no, substring_idx in enumerate(substring_start_idx):
            dispatch_idx = (
                compiled_module_str[substring_idx:]
                .split(":")[0]
                .split("@")[-1]
            )
            key = "dispatch_no_" + str(dispatch_no)
            dispatch_list[key] = {n: "None" for n in self.sharding_stages_id}
            dispatch_list[key]["dispatch_id"] = dispatch_idx

        self.generate_json(dispatch_list)

    def split_into_layers(self):
        model_dictionary = dict()

        for name, m in self.model.named_modules():
@@ -34,5 +98,8 @@ class GenerateConfigFile:
            layer_dict = {n: "None" for n in self.sharding_stages_id}
            model_dictionary[name] = layer_dict

        with open("model_config.json", "w") as outfile:
            json.dump(model_dictionary, outfile)
        self.generate_json(model_dictionary)

    def generate_json(self, artifacts):
        with open(self.config_file_path, "w") as outfile:
            json.dump(artifacts, outfile)
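A brief usage sketch of the class as it stands after this change. The model, stage IDs, backend string, and input shape below are illustrative assumptions, not values taken from the repository:

import torch
import torchvision.models as models

model = models.resnet18(pretrained=True).eval()
example_input = (torch.randn(1, 3, 224, 224),)  # assumed example input

config_gen = GenerateConfigFile(
    model,
    num_sharding_stages=2,
    sharding_stages_id=["stage_0", "stage_1"],  # assumed stage names
    model_input=example_input,
    config_file_path="model_config.json",
)
# Either one config entry per IREE flow dispatch ...
config_gen.split_into_dispatches("llvm-cpu")
# ... or one config entry per named submodule of the model.
config_gen.split_into_layers()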
@@ -370,6 +370,7 @@ def transform_fx(fx_g):
                torch.ops.aten.arange,
                torch.ops.aten.empty,
                torch.ops.aten.zeros,
                torch.ops.aten.zeros_like,
            ]:
                if node.kwargs.get("dtype") == torch.float32:
                    node.kwargs = kwargs_dict
@@ -525,6 +526,8 @@ def import_with_fx(
                torch.ops.aten.split.Tensor,
                torch.ops.aten.split_with_sizes,
                torch.ops.aten.native_layer_norm,
                torch.ops.aten.masked_fill.Tensor,
                torch.ops.aten.masked_fill.Scalar,
            ]
        ),
    )(*inputs)
@@ -552,6 +555,9 @@ def import_with_fx(
        add_upcast(fx_g)
        fx_g.recompile()

    if mlir_type == "fx":
        return fx_g

    if training:
        change_fx_graph_return_to_tuple(fx_g)
        inputs = flatten_training_input(inputs)
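The `mlir_type == "fx"` early return added above lets a caller stop after FX tracing and get the raw `torch.fx.GraphModule` back instead of MLIR, which is what the ShapeProp example earlier relies on. A minimal sketch, assuming a torchvision model as the workload:

import torch
import torchvision.models as models
from shark.shark_importer import import_with_fx

model = models.resnet18(pretrained=True).eval()
inputs = (torch.randn(1, 3, 224, 224),)

# Returns the traced FX GraphModule directly; no torch-mlir lowering happens.
fx_g = import_with_fx(model, inputs, mlir_type="fx")
print(fx_g.graph)        # inspect the traced graph
print(fx_g(*inputs))     # the GraphModule is still runnable eagerly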
@@ -48,6 +48,8 @@ class SharkInference:
        Refer to {https://mlir.llvm.org/docs/Dialects/}
    is_benchmark: bool
        Whether this SharkInference module should be benchmark-enabled.
    mmap: bool
        Whether to load/run vmfb using mmap. It's `True` by default.

    Methods
    -------
@@ -70,6 +72,7 @@ class SharkInference:
        dispatch_benchmark: str = None,
        dispatch_benchmark_dir: str = "temp_dispatch_benchmarks",
        device_idx: int = None,
        mmap: bool = True,
    ):
        self.mlir_module = mlir_module
        self.device = shark_args.device if device == "none" else device
@@ -88,6 +91,7 @@ class SharkInference:
        )

        self.shark_runner = None
        self.mmap = mmap

    def compile(self, extra_args=[]):
        if self.dispatch_benchmarks is not None:
@@ -201,12 +205,14 @@ class SharkInference:
            compile_vmfb=False,
            extra_args=extra_args,
        )
        (
            self.shark_runner.iree_compilation_module,
            self.shark_runner.iree_config,
        ) = load_flatbuffer(
        params = load_flatbuffer(
            path,
            self.device,
            self.device_idx,
            mmap=self.mmap,
        )
        self.shark_runner.iree_compilation_module = params["vmfb"]
        self.shark_runner.iree_config = params["config"]
        self.shark_runner.temp_file_to_unlink = params["temp_file_to_unlink"]
        del params
        return
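With the new constructor argument, a caller can opt out of mmap-based loading when reusing a cached .vmfb. A short sketch; the import path follows the SHARK examples and the file path is an assumption:

from shark.shark_inference import SharkInference

shark_module = SharkInference(
    mlir_module=None,
    device="cpu-task",
    mmap=False,  # new flag from this change; defaults to True
)
shark_module.load_module("./resnet18_cpu-task.vmfb")  # assumed path to an existing vmfb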
@@ -85,16 +85,17 @@ class SharkRunner:

        if compile_vmfb == True:
            # Compile the module to get the .vmfb.
            (
                self.iree_compilation_module,
                self.iree_config,
            ) = get_iree_compiled_module(
            params = get_iree_compiled_module(
                self.mlir_module,
                self.device,
                self.mlir_dialect,
                extra_args=self.extra_args,
                device_idx=self.device_idx,
            )
            self.iree_compilation_module = params["vmfb"]
            self.iree_config = params["config"]
            self.temp_file_to_unlink = params["temp_file_to_unlink"]
            del params

    def run(self, function_name, inputs: tuple, send_to_host=False):
        return get_results(
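Both helpers now hand back a dict instead of a tuple, so any external caller that still tuple-unpacks needs the same migration shown above. A minimal sketch of the new pattern; the import path is an assumption, and only the dict keys come from this change:

from shark.iree_utils.compile_utils import get_iree_compiled_module  # assumed module path

def compile_and_unpack(mlir_module, device="cpu-task", mlir_dialect="linalg"):
    # Old callers looked like:
    #   module, config = get_iree_compiled_module(mlir_module, device, mlir_dialect)
    params = get_iree_compiled_module(mlir_module, device, mlir_dialect)
    # Dict keys introduced by this change:
    return params["vmfb"], params["config"], params["temp_file_to_unlink"]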
@@ -12,8 +12,8 @@ from transformers import AutoTokenizer, OPTForCausalLM

OPT_MODEL = "opt-1.3b"
OPT_FS_NAME = "opt-1_3b"
MAX_SEQUENCE_LENGTH = 30
MAX_NEW_TOKENS = 20
MAX_SEQUENCE_LENGTH = 128
MAX_NEW_TOKENS = 60


def create_module(model_name, tokenizer, device):
@@ -110,13 +110,13 @@ if __name__ == "__main__":
        "facebook/" + OPT_MODEL, use_fast=False
    )
    vmfb_path = (
        f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_cpu-sync.vmfb"
        f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_cpu-task.vmfb"
    )
    opt_shark_module = SharkInference(mlir_module=None, device="cpu-sync")
    opt_shark_module = SharkInference(mlir_module=None, device="cpu-task")
    if os.path.isfile(vmfb_path):
        opt_shark_module.load_module(vmfb_path)
    else:
        vmfb_path = create_module(OPT_MODEL, tokenizer, "cpu-sync")
        vmfb_path = create_module(OPT_MODEL, tokenizer, "cpu-task")
        opt_shark_module.load_module(vmfb_path)
    while True:
        try:
@@ -24,4 +24,5 @@ bert-large-uncased,True,hf,True,linalg,False,330M,"nlp;bert-variant;transformer-
bert-base-uncased,True,hf,False,stablehlo,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
gpt2,True,hf_causallm,False,stablehlo,True,125M,"nlp;transformer-encoder","-"
facebook/opt-125m,True,hf,False,stablehlo,True,125M,"nlp;transformer-encoder","-"
distilgpt2,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
microsoft/deberta-v3-base,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"