mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-04-20 03:00:34 -04:00
Compare commits
33 Commits
20230621.7
...
20230629.7
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
534de05791 | ||
|
|
5779e8c039 | ||
|
|
d496053590 | ||
|
|
6274a813c9 | ||
|
|
1d6a1f9f8a | ||
|
|
75672c0e28 | ||
|
|
74a7202173 | ||
|
|
27a08735db | ||
|
|
eaa49cce17 | ||
|
|
10657d6fb1 | ||
|
|
e3ab844cd1 | ||
|
|
5ce6001b41 | ||
|
|
501d0ca52e | ||
|
|
b444528715 | ||
|
|
6e6c90f62b | ||
|
|
8cdb38496e | ||
|
|
726d73d6ba | ||
|
|
4d55e51d46 | ||
|
|
6ef78ee7ba | ||
|
|
4002da7161 | ||
|
|
ecb5e8e5d8 | ||
|
|
28e0919321 | ||
|
|
28f4d44a6b | ||
|
|
97f7e79391 | ||
|
|
44a8f2f8db | ||
|
|
8822b9acd7 | ||
|
|
0ca3b9fce3 | ||
|
|
045f2bb147 | ||
|
|
a811b867b9 | ||
|
|
cdd505e2dd | ||
|
|
1b0f39107c | ||
|
|
b9b8955f74 | ||
|
|
6f7a85eee3 |
9
.github/workflows/test-models.yml
vendored
9
.github/workflows/test-models.yml
vendored
@@ -35,6 +35,8 @@ jobs:
|
||||
include:
|
||||
- os: ubuntu-latest
|
||||
suite: lint
|
||||
- os: MacStudio
|
||||
suite: metal
|
||||
exclude:
|
||||
- os: ubuntu-latest
|
||||
suite: vulkan
|
||||
@@ -46,6 +48,8 @@ jobs:
|
||||
suite: cuda
|
||||
- os: MacStudio
|
||||
suite: cpu
|
||||
- os: MacStudio
|
||||
suite: vulkan
|
||||
- os: icelake
|
||||
suite: vulkan
|
||||
- os: icelake
|
||||
@@ -125,15 +129,14 @@ jobs:
|
||||
# python build_tools/stable_diffusion_testing.py --device=cuda
|
||||
|
||||
- name: Validate Vulkan Models (MacOS)
|
||||
if: matrix.suite == 'vulkan' && matrix.os == 'MacStudio'
|
||||
if: matrix.suite == 'metal' && matrix.os == 'MacStudio'
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
export DYLD_LIBRARY_PATH=/usr/local/lib/
|
||||
echo $PATH
|
||||
pip list | grep -E "torch|iree"
|
||||
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k vulkan
|
||||
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k metal
|
||||
|
||||
- name: Validate Vulkan Models (a100)
|
||||
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -159,7 +159,7 @@ cython_debug/
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
.idea/
|
||||
|
||||
# vscode related
|
||||
.vscode
|
||||
|
||||
@@ -3,6 +3,10 @@ from pathlib import Path
|
||||
from apps.language_models.src.pipelines import vicuna_pipeline as vp
|
||||
from apps.language_models.src.pipelines import vicuna_sharded_pipeline as vsp
|
||||
import torch
|
||||
import json
|
||||
|
||||
if __name__ == "__main__":
|
||||
import gc
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
@@ -55,35 +59,38 @@ parser.add_argument(
|
||||
help="Run model in cli mode",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
default=None,
|
||||
help="configuration file",
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
vic = None
|
||||
if not args.sharded:
|
||||
first_vic_mlir_path = (
|
||||
Path(f"first_vicuna_{args.precision}.mlir")
|
||||
None
|
||||
if args.first_vicuna_mlir_path is None
|
||||
else Path(args.first_vicuna_mlir_path)
|
||||
)
|
||||
second_vic_mlir_path = (
|
||||
Path(f"second_vicuna_{args.precision}.mlir")
|
||||
None
|
||||
if args.second_vicuna_mlir_path is None
|
||||
else Path(args.second_vicuna_mlir_path)
|
||||
)
|
||||
first_vic_vmfb_path = (
|
||||
Path(
|
||||
f"first_vicuna_{args.precision}_{args.device.replace('://', '_')}.vmfb"
|
||||
)
|
||||
None
|
||||
if args.first_vicuna_vmfb_path is None
|
||||
else Path(args.first_vicuna_vmfb_path)
|
||||
)
|
||||
second_vic_vmfb_path = (
|
||||
Path(
|
||||
f"second_vicuna_{args.precision}_{args.device.replace('://', '_')}.vmfb"
|
||||
)
|
||||
None
|
||||
if args.second_vicuna_vmfb_path is None
|
||||
else Path(args.second_vicuna_vmfb_path)
|
||||
)
|
||||
|
||||
vic = vp.Vicuna(
|
||||
"vicuna",
|
||||
device=args.device,
|
||||
@@ -95,16 +102,21 @@ if __name__ == "__main__":
|
||||
load_mlir_from_shark_tank=args.load_mlir_from_shark_tank,
|
||||
)
|
||||
else:
|
||||
if args.config is not None:
|
||||
config_file = open(args.config)
|
||||
config_json = json.load(config_file)
|
||||
config_file.close()
|
||||
else:
|
||||
config_json = None
|
||||
vic = vsp.Vicuna(
|
||||
"vicuna",
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
config_json=config_json,
|
||||
)
|
||||
prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
|
||||
prologue_prompt = "ASSISTANT:\n"
|
||||
|
||||
import gc
|
||||
|
||||
while True:
|
||||
# TODO: Add break condition from user input
|
||||
user_prompt = input("User: ")
|
||||
|
||||
@@ -145,7 +145,7 @@ class CompiledSecondVicunaLayer(torch.nn.Module):
|
||||
|
||||
|
||||
class ShardedVicunaModel(torch.nn.Module):
|
||||
def __init__(self, model, layers0, layers1):
|
||||
def __init__(self, model, layers0, layers1, lmhead, embedding, norm):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
assert len(layers0) == len(model.model.layers)
|
||||
@@ -154,6 +154,12 @@ class ShardedVicunaModel(torch.nn.Module):
|
||||
self.model.model.config.output_attentions = False
|
||||
self.layers0 = layers0
|
||||
self.layers1 = layers1
|
||||
self.norm = norm
|
||||
self.embedding = embedding
|
||||
self.lmhead = lmhead
|
||||
self.model.model.norm = self.norm
|
||||
self.model.model.embed_tokens = self.embedding
|
||||
self.model.lm_head = self.lmhead
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -176,3 +182,69 @@ class ShardedVicunaModel(torch.nn.Module):
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
|
||||
class LMHead(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, hidden_states):
|
||||
output = self.model(hidden_states)
|
||||
return output
|
||||
|
||||
|
||||
class LMHeadCompiled(torch.nn.Module):
|
||||
def __init__(self, shark_module):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = hidden_states.detach()
|
||||
output = self.model("forward", (hidden_states,))
|
||||
output = torch.tensor(output)
|
||||
return output
|
||||
|
||||
|
||||
class VicunaNorm(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, hidden_states):
|
||||
output = self.model(hidden_states)
|
||||
return output
|
||||
|
||||
|
||||
class VicunaNormCompiled(torch.nn.Module):
|
||||
def __init__(self, shark_module):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states.detach()
|
||||
output = self.model("forward", (hidden_states,))
|
||||
output = torch.tensor(output)
|
||||
return output
|
||||
|
||||
|
||||
class VicunaEmbedding(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, input_ids):
|
||||
output = self.model(input_ids)
|
||||
return output
|
||||
|
||||
|
||||
class VicunaEmbeddingCompiled(torch.nn.Module):
|
||||
def __init__(self, shark_module):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
|
||||
def forward(self, input_ids):
|
||||
input_ids.detach()
|
||||
output = self.model("forward", (input_ids,))
|
||||
output = torch.tensor(output)
|
||||
return output
|
||||
|
||||
@@ -33,16 +33,25 @@ class Vicuna(SharkLLMBase):
|
||||
first_vicuna_vmfb_path=None,
|
||||
second_vicuna_vmfb_path=None,
|
||||
load_mlir_from_shark_tank=True,
|
||||
low_device_memory=False,
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
self.max_sequence_length = 256
|
||||
self.device = device
|
||||
if not load_mlir_from_shark_tank and precision in ["int4", "int8"]:
|
||||
print(
|
||||
"int4 and int8 are only available from SHARK tank, please set --load_mlir_from_shark_tank, using fp32 now"
|
||||
)
|
||||
precision = "fp32"
|
||||
self.precision = precision
|
||||
self.first_vicuna_vmfb_path = first_vicuna_vmfb_path
|
||||
self.second_vicuna_vmfb_path = second_vicuna_vmfb_path
|
||||
self.first_vicuna_mlir_path = first_vicuna_mlir_path
|
||||
self.second_vicuna_mlir_path = second_vicuna_mlir_path
|
||||
self.load_mlir_from_shark_tank = load_mlir_from_shark_tank
|
||||
self.low_device_memory = low_device_memory
|
||||
self.first_vic = None
|
||||
self.second_vic = None
|
||||
if self.first_vicuna_mlir_path == None:
|
||||
self.first_vicuna_mlir_path = self.get_model_path()
|
||||
if self.second_vicuna_mlir_path == None:
|
||||
@@ -61,7 +70,7 @@ class Vicuna(SharkLLMBase):
|
||||
if suffix == "mlir":
|
||||
return Path(f"{model_number}_vicuna_{self.precision}.{suffix}")
|
||||
return Path(
|
||||
f"{model_number}_vicuna_{safe_device}_{self.precision}.{suffix}"
|
||||
f"{model_number}_vicuna_{self.precision}_{safe_device}.{suffix}"
|
||||
)
|
||||
|
||||
def get_tokenizer(self):
|
||||
@@ -87,7 +96,7 @@ class Vicuna(SharkLLMBase):
|
||||
# Compilation path needs some more work before it is functional
|
||||
|
||||
print(
|
||||
f"[DEBUG] vmfb not found at {self.first_vicuna_vmfb_path.absolute()}. Trying to work with"
|
||||
f"[DEBUG] vmfb not found at {self.first_vicuna_vmfb_path.absolute()}. Trying to work with\n"
|
||||
f"[DEBUG] mlir path { self.first_vicuna_mlir_path} {'exists' if self.first_vicuna_mlir_path.exists() else 'does not exist'}"
|
||||
)
|
||||
if self.first_vicuna_mlir_path.exists():
|
||||
@@ -96,8 +105,8 @@ class Vicuna(SharkLLMBase):
|
||||
else:
|
||||
mlir_generated = False
|
||||
if self.load_mlir_from_shark_tank:
|
||||
if self.precision in ["fp32", "fp16"]:
|
||||
# download MLIR from shark_tank for fp32/fp16
|
||||
if self.precision in ["fp32", "fp16", "int8", "int4"]:
|
||||
# download MLIR from shark_tank
|
||||
download_public_file(
|
||||
f"gs://shark_tank/vicuna/unsharded/mlir/{self.first_vicuna_mlir_path.name}",
|
||||
self.first_vicuna_mlir_path.absolute(),
|
||||
@@ -114,7 +123,7 @@ class Vicuna(SharkLLMBase):
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"Only fp32 and fp16 mlir added to tank, generating {self.precision} mlir on device."
|
||||
f"Only fp32/fp16/int8/int4 mlir added to tank, generating {self.precision} mlir on device."
|
||||
)
|
||||
|
||||
if not mlir_generated:
|
||||
@@ -238,8 +247,8 @@ class Vicuna(SharkLLMBase):
|
||||
else:
|
||||
mlir_generated = False
|
||||
if self.load_mlir_from_shark_tank:
|
||||
if self.precision in ["fp32", "fp16"]:
|
||||
# download MLIR from shark_tank for fp32/fp16
|
||||
if self.precision in ["fp32", "fp16", "int8", "int4"]:
|
||||
# download MLIR from shark_tank
|
||||
download_public_file(
|
||||
f"gs://shark_tank/vicuna/unsharded/mlir/{self.second_vicuna_mlir_path.name}",
|
||||
self.second_vicuna_mlir_path.absolute(),
|
||||
@@ -256,7 +265,7 @@ class Vicuna(SharkLLMBase):
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"Only fp32 mlir added to tank, generating mlir on device."
|
||||
"Only fp32/fp16/int8/int4 mlir added to tank, generating mlir on device."
|
||||
)
|
||||
|
||||
if not mlir_generated:
|
||||
@@ -432,16 +441,30 @@ class Vicuna(SharkLLMBase):
|
||||
# return tuple of shark_modules once mem is supported
|
||||
# return fvic_shark_model, svic_shark_model
|
||||
|
||||
def decode_tokens(self, res_tokens):
|
||||
for i in range(len(res_tokens)):
|
||||
if type(res_tokens[i]) != int:
|
||||
res_tokens[i] = int(res_tokens[i][0])
|
||||
|
||||
res_str = self.tokenizer.decode(res_tokens)
|
||||
return res_str
|
||||
|
||||
def generate(self, prompt, cli=False):
|
||||
# TODO: refactor for cleaner integration
|
||||
import gc
|
||||
|
||||
res = []
|
||||
if not self.low_device_memory:
|
||||
if self.first_vic == None:
|
||||
self.first_vic = self.compile_first_vicuna()
|
||||
if self.second_vic == None:
|
||||
self.second_vic = self.compile_second_vicuna()
|
||||
res_tokens = []
|
||||
params = {
|
||||
"prompt": prompt,
|
||||
"is_first": True,
|
||||
"fv": self.compile_first_vicuna(),
|
||||
"fv": self.compile_first_vicuna()
|
||||
if self.first_vic == None
|
||||
else self.first_vic,
|
||||
}
|
||||
|
||||
generated_token_op = self.generate_new_token(params=params)
|
||||
@@ -450,25 +473,27 @@ class Vicuna(SharkLLMBase):
|
||||
logits = generated_token_op["logits"]
|
||||
pkv = generated_token_op["pkv"]
|
||||
detok = generated_token_op["detok"]
|
||||
yield detok
|
||||
|
||||
res.append(detok)
|
||||
res_tokens.append(token)
|
||||
if cli:
|
||||
print(f"Assistant: {detok}", end=" ", flush=True)
|
||||
|
||||
# Clear First Vic from Memory (main and cuda)
|
||||
del params
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
if self.low_device_memory:
|
||||
del params
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
sec_vic = self.compile_second_vicuna()
|
||||
for _ in range(self.max_num_tokens - 2):
|
||||
params = {
|
||||
"prompt": None,
|
||||
"is_first": False,
|
||||
"logits": logits,
|
||||
"pkv": pkv,
|
||||
"sv": sec_vic,
|
||||
"sv": self.compile_second_vicuna()
|
||||
if self.second_vic == None
|
||||
else self.second_vic,
|
||||
}
|
||||
|
||||
generated_token_op = self.generate_new_token(params=params)
|
||||
@@ -482,24 +507,24 @@ class Vicuna(SharkLLMBase):
|
||||
break
|
||||
res_tokens.append(token)
|
||||
if detok == "<0x0A>":
|
||||
res.append("\n")
|
||||
if cli:
|
||||
print("\n", end="", flush=True)
|
||||
else:
|
||||
res.append(detok)
|
||||
if cli:
|
||||
print(f"{detok}", end=" ", flush=True)
|
||||
del sec_vic, pkv, logits
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
for i in range(len(res_tokens)):
|
||||
if type(res_tokens[i]) != int:
|
||||
res_tokens[i] = int(res_tokens[i][0])
|
||||
if len(res_tokens) % 3 == 0:
|
||||
part_str = self.decode_tokens(res_tokens)
|
||||
yield part_str
|
||||
|
||||
res_str = self.tokenizer.decode(res_tokens)
|
||||
if self.device == "cuda":
|
||||
del sec_vic, pkv, logits
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
res_str = self.decode_tokens(res_tokens)
|
||||
# print(f"[DEBUG] final output : \n{res_str}")
|
||||
return res_str
|
||||
yield res_str
|
||||
|
||||
def generate_new_token(self, params, debug=False):
|
||||
def forward_first(first_vic, prompt, cache_outputs=False):
|
||||
|
||||
@@ -4,6 +4,12 @@ from apps.language_models.src.model_wrappers.vicuna_sharded_model import (
|
||||
CompiledFirstVicunaLayer,
|
||||
CompiledSecondVicunaLayer,
|
||||
ShardedVicunaModel,
|
||||
LMHead,
|
||||
LMHeadCompiled,
|
||||
VicunaEmbedding,
|
||||
VicunaEmbeddingCompiled,
|
||||
VicunaNorm,
|
||||
VicunaNormCompiled,
|
||||
)
|
||||
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
|
||||
from shark.shark_importer import import_with_fx
|
||||
@@ -19,9 +25,11 @@ import re
|
||||
import torch
|
||||
import torch_mlir
|
||||
import os
|
||||
import json
|
||||
|
||||
|
||||
class Vicuna(SharkLLMBase):
|
||||
# Class representing Sharded Vicuna Model
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
@@ -29,21 +37,25 @@ class Vicuna(SharkLLMBase):
|
||||
max_num_tokens=512,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
config_json=None,
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
self.max_sequence_length = 256
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.config = config_json
|
||||
self.shark_model = self.compile(device=device)
|
||||
|
||||
def get_tokenizer(self):
|
||||
# Retrieve the tokenizer from Huggingface
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hf_model_path, use_fast=False
|
||||
)
|
||||
return tokenizer
|
||||
|
||||
def get_src_model(self):
|
||||
# Retrieve the torch model from Huggingface
|
||||
kwargs = {"torch_dtype": torch.float}
|
||||
vicuna_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, **kwargs
|
||||
@@ -51,6 +63,8 @@ class Vicuna(SharkLLMBase):
|
||||
return vicuna_model
|
||||
|
||||
def write_in_dynamic_inputs0(self, module, dynamic_input_size):
|
||||
# Current solution for ensuring mlir files support dynamic inputs
|
||||
# TODO find a more elegant way to implement this
|
||||
new_lines = []
|
||||
for line in module.splitlines():
|
||||
line = re.sub(f"{dynamic_input_size}x", "?x", line)
|
||||
@@ -107,6 +121,7 @@ class Vicuna(SharkLLMBase):
|
||||
past_key_value0=None,
|
||||
past_key_value1=None,
|
||||
):
|
||||
# Compile a hidden decoder layer of vicuna
|
||||
if past_key_value0 is None and past_key_value1 is None:
|
||||
model_inputs = (hidden_states, attention_mask, position_ids)
|
||||
else:
|
||||
@@ -126,7 +141,154 @@ class Vicuna(SharkLLMBase):
|
||||
)
|
||||
return mlir_bytecode
|
||||
|
||||
def get_device_index(self, layer_string):
|
||||
# Get the device index from the config file
|
||||
# In the event that different device indices are assigned to
|
||||
# different parts of a layer, a majority vote will be taken and
|
||||
# everything will be run on the most commonly used device
|
||||
if self.config is None:
|
||||
return None
|
||||
idx_votes = {}
|
||||
for key in self.config.keys():
|
||||
if re.search(layer_string, key):
|
||||
if int(self.config[key]["gpu"]) in idx_votes.keys():
|
||||
idx_votes[int(self.config[key]["gpu"])] += 1
|
||||
else:
|
||||
idx_votes[int(self.config[key]["gpu"])] = 1
|
||||
device_idx = max(idx_votes, key=idx_votes.get)
|
||||
return device_idx
|
||||
|
||||
def compile_lmhead(
|
||||
self, lmh, hidden_states, device="cpu", device_idx=None
|
||||
):
|
||||
# compile the lm head of the vicuna model
|
||||
# This can be used for both first and second vicuna, so only needs to be run once
|
||||
mlir_path = Path(f"lmhead.mlir")
|
||||
vmfb_path = Path(f"lmhead.vmfb")
|
||||
if mlir_path.exists():
|
||||
f_ = open(mlir_path, "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
else:
|
||||
hidden_states = torch_mlir.TensorPlaceholder.like(
|
||||
hidden_states, dynamic_axes=[1]
|
||||
)
|
||||
|
||||
module = torch_mlir.compile(
|
||||
lmh,
|
||||
(hidden_states,),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
f_ = open(mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
f_.close()
|
||||
|
||||
shark_module = SharkInference(
|
||||
bytecode,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=device_idx,
|
||||
)
|
||||
if vmfb_path.exists():
|
||||
shark_module.load_module(vmfb_path)
|
||||
else:
|
||||
shark_module.save_module(module_name="lmhead")
|
||||
shark_module.load_module(vmfb_path)
|
||||
compiled_module = LMHeadCompiled(shark_module)
|
||||
return compiled_module
|
||||
|
||||
def compile_norm(self, fvn, hidden_states, device="cpu", device_idx=None):
|
||||
# compile the normalization layer of the vicuna model
|
||||
# This can be used for both first and second vicuna, so only needs to be run once
|
||||
mlir_path = Path(f"norm.mlir")
|
||||
vmfb_path = Path(f"norm.vmfb")
|
||||
if mlir_path.exists():
|
||||
f_ = open(mlir_path, "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
else:
|
||||
hidden_states = torch_mlir.TensorPlaceholder.like(
|
||||
hidden_states, dynamic_axes=[1]
|
||||
)
|
||||
|
||||
module = torch_mlir.compile(
|
||||
fvn,
|
||||
(hidden_states,),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
f_ = open(mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
f_.close()
|
||||
|
||||
shark_module = SharkInference(
|
||||
bytecode,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=device_idx,
|
||||
)
|
||||
if vmfb_path.exists():
|
||||
shark_module.load_module(vmfb_path)
|
||||
else:
|
||||
shark_module.save_module(module_name="norm")
|
||||
shark_module.load_module(vmfb_path)
|
||||
compiled_module = VicunaNormCompiled(shark_module)
|
||||
return compiled_module
|
||||
|
||||
def compile_embedding(self, fve, input_ids, device="cpu", device_idx=None):
|
||||
# compile the embedding layer of the vicuna model
|
||||
# This can be used for both first and second vicuna, so only needs to be run once
|
||||
mlir_path = Path(f"embedding.mlir")
|
||||
vmfb_path = Path(f"embedding.vmfb")
|
||||
if mlir_path.exists():
|
||||
f_ = open(mlir_path, "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
else:
|
||||
input_ids = torch_mlir.TensorPlaceholder.like(
|
||||
input_ids, dynamic_axes=[1]
|
||||
)
|
||||
module = torch_mlir.compile(
|
||||
fve,
|
||||
(input_ids,),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
f_ = open(mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
f_.close()
|
||||
|
||||
shark_module = SharkInference(
|
||||
bytecode,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=device_idx,
|
||||
)
|
||||
if vmfb_path.exists():
|
||||
shark_module.load_module(vmfb_path)
|
||||
else:
|
||||
shark_module.save_module(module_name="embedding")
|
||||
shark_module.load_module(vmfb_path)
|
||||
compiled_module = VicunaEmbeddingCompiled(shark_module)
|
||||
|
||||
return compiled_module
|
||||
|
||||
def compile_to_vmfb(self, inputs, layers, device="cpu", is_first=True):
|
||||
# compile all layers for vmfb
|
||||
# this needs to be run seperatley for first and second vicuna
|
||||
mlirs, modules = [], []
|
||||
for idx, layer in tqdm(enumerate(layers), desc="Getting mlirs"):
|
||||
if is_first:
|
||||
@@ -198,10 +360,6 @@ class Vicuna(SharkLLMBase):
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# bytecode_stream = BytesIO()
|
||||
# module.operation.write_bytecode(bytecode_stream)
|
||||
# bytecode = bytecode_stream.getvalue()
|
||||
|
||||
if is_first:
|
||||
module = self.write_in_dynamic_inputs0(str(module), 137)
|
||||
bytecode = module.encode("UTF-8")
|
||||
@@ -224,20 +382,25 @@ class Vicuna(SharkLLMBase):
|
||||
if is_first:
|
||||
vmfb_path = Path(f"{idx}_0.vmfb")
|
||||
if vmfb_path.exists():
|
||||
# print(f"Found layer {idx} vmfb")
|
||||
device_idx = self.get_device_index(
|
||||
f"first_vicuna.model.model.layers.{idx}[\s.$]"
|
||||
)
|
||||
module = SharkInference(
|
||||
None,
|
||||
device=device,
|
||||
device_idx=idx % 1,
|
||||
device_idx=device_idx,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
else:
|
||||
print(f"Compiling layer {idx} vmfb")
|
||||
device_idx = self.get_device_index(
|
||||
f"first_vicuna.model.model.layers.{idx}[\s.$]"
|
||||
)
|
||||
module = SharkInference(
|
||||
mlirs[idx],
|
||||
device=device,
|
||||
device_idx=idx % 1,
|
||||
device_idx=device_idx,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
module.save_module(
|
||||
@@ -255,19 +418,25 @@ class Vicuna(SharkLLMBase):
|
||||
vmfb_path = Path(f"{idx}_1.vmfb")
|
||||
if vmfb_path.exists():
|
||||
# print(f"Found layer {idx} vmfb")
|
||||
device_idx = self.get_device_index(
|
||||
f"second_vicuna.model.model.layers.{idx}[\s.$]"
|
||||
)
|
||||
module = SharkInference(
|
||||
None,
|
||||
device=device,
|
||||
device_idx=idx % 1,
|
||||
device_idx=device_idx,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
else:
|
||||
print(f"Compiling layer {idx} vmfb")
|
||||
device_idx = self.get_device_index(
|
||||
f"second_vicuna.model.model.layers.{idx}[\s.$]"
|
||||
)
|
||||
module = SharkInference(
|
||||
mlirs[idx],
|
||||
device=device,
|
||||
device_idx=idx % 1,
|
||||
device_idx=device_idx,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
module.save_module(
|
||||
@@ -303,6 +472,42 @@ class Vicuna(SharkLLMBase):
|
||||
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
|
||||
)
|
||||
|
||||
norm = VicunaNorm(vicuna_model.model.norm)
|
||||
device_idx = self.get_device_index(
|
||||
r"vicuna\.model\.model\.norm(?:\.|\s|$)"
|
||||
)
|
||||
print(device_idx)
|
||||
norm = self.compile_norm(
|
||||
norm,
|
||||
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
|
||||
device=self.device,
|
||||
device_idx=device_idx,
|
||||
)
|
||||
|
||||
embeddings = VicunaEmbedding(vicuna_model.model.embed_tokens)
|
||||
device_idx = self.get_device_index(
|
||||
r"vicuna\.model\.model\.embed_tokens(?:\.|\s|$)"
|
||||
)
|
||||
print(device_idx)
|
||||
embeddings = self.compile_embedding(
|
||||
embeddings,
|
||||
(torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64)),
|
||||
device=self.device,
|
||||
device_idx=device_idx,
|
||||
)
|
||||
|
||||
lmhead = LMHead(vicuna_model.lm_head)
|
||||
device_idx = self.get_device_index(
|
||||
r"vicuna\.model\.lm_head(?:\.|\s|$)"
|
||||
)
|
||||
print(device_idx)
|
||||
lmhead = self.compile_lmhead(
|
||||
lmhead,
|
||||
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
|
||||
device=self.device,
|
||||
device_idx=device_idx,
|
||||
)
|
||||
|
||||
layers0 = [
|
||||
FirstVicunaLayer(layer) for layer in vicuna_model.model.layers
|
||||
]
|
||||
@@ -323,7 +528,12 @@ class Vicuna(SharkLLMBase):
|
||||
shark_layers1 = [CompiledSecondVicunaLayer(m) for m in modules1]
|
||||
|
||||
sharded_model = ShardedVicunaModel(
|
||||
vicuna_model, shark_layers0, shark_layers1
|
||||
vicuna_model,
|
||||
shark_layers0,
|
||||
shark_layers1,
|
||||
lmhead,
|
||||
embeddings,
|
||||
norm,
|
||||
)
|
||||
return sharded_model
|
||||
|
||||
|
||||
@@ -103,6 +103,7 @@ def main():
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
use_stencil=use_stencil,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
|
||||
@@ -81,6 +81,7 @@ def main():
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
|
||||
@@ -79,6 +79,7 @@ def main():
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
|
||||
@@ -73,6 +73,7 @@ if __name__ == "__main__":
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
|
||||
@@ -78,7 +78,7 @@ exe = EXE(
|
||||
debug=False,
|
||||
bootloader_ignore_signals=False,
|
||||
strip=False,
|
||||
upx=True,
|
||||
upx=False,
|
||||
upx_exclude=[],
|
||||
runtime_tmpdir=None,
|
||||
console=True,
|
||||
|
||||
@@ -520,16 +520,17 @@ class SharkifyStableDiffusionModel:
|
||||
torch.nn.functional.pad(inputs[2], pad),
|
||||
inputs[3])
|
||||
input_mask = [True, True, True, False]
|
||||
model_name = "unet512" if use_large else "unet"
|
||||
shark_unet, unet_mlir = compile_through_fx(
|
||||
unet,
|
||||
inputs,
|
||||
extended_model_name=self.model_name["unet"],
|
||||
extended_model_name=self.model_name[model_name],
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
model_name="unet",
|
||||
model_name=model_name,
|
||||
precision=self.precision,
|
||||
return_mlir=self.return_mlir,
|
||||
)
|
||||
|
||||
@@ -135,6 +135,7 @@ class Image2ImagePipeline(StableDiffusionPipeline):
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
use_stencil,
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
@@ -156,7 +157,10 @@ class Image2ImagePipeline(StableDiffusionPipeline):
|
||||
|
||||
# Get text embeddings with weight emphasis from prompts
|
||||
text_embeddings = self.encode_prompts_weight(
|
||||
prompts, neg_prompts, max_length
|
||||
prompts,
|
||||
neg_prompts,
|
||||
max_length,
|
||||
max_embeddings_multiples=max_embeddings_multiples,
|
||||
)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
|
||||
@@ -378,6 +378,7 @@ class InpaintPipeline(StableDiffusionPipeline):
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
@@ -408,7 +409,10 @@ class InpaintPipeline(StableDiffusionPipeline):
|
||||
|
||||
# Get text embeddings with weight emphasis from prompts
|
||||
text_embeddings = self.encode_prompts_weight(
|
||||
prompts, neg_prompts, max_length
|
||||
prompts,
|
||||
neg_prompts,
|
||||
max_length,
|
||||
max_embeddings_multiples=max_embeddings_multiples,
|
||||
)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
|
||||
@@ -379,6 +379,7 @@ class OutpaintPipeline(StableDiffusionPipeline):
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
@@ -409,7 +410,10 @@ class OutpaintPipeline(StableDiffusionPipeline):
|
||||
|
||||
# Get text embeddings with weight emphasis from prompts
|
||||
text_embeddings = self.encode_prompts_weight(
|
||||
prompts, neg_prompts, max_length
|
||||
prompts,
|
||||
neg_prompts,
|
||||
max_length,
|
||||
max_embeddings_multiples=max_embeddings_multiples,
|
||||
)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
|
||||
@@ -204,6 +204,7 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
use_stencil,
|
||||
):
|
||||
# Control Embedding check & conversion
|
||||
@@ -230,7 +231,10 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
|
||||
# Get text embeddings with weight emphasis from prompts
|
||||
text_embeddings = self.encode_prompts_weight(
|
||||
prompts, neg_prompts, max_length
|
||||
prompts,
|
||||
neg_prompts,
|
||||
max_length,
|
||||
max_embeddings_multiples=max_embeddings_multiples,
|
||||
)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
|
||||
@@ -168,7 +168,10 @@ class UpscalerPipeline(StableDiffusionPipeline):
|
||||
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
self.status = SD_STATE_IDLE
|
||||
self.load_unet()
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
self.load_unet()
|
||||
else:
|
||||
self.load_unet_512()
|
||||
for i, t in tqdm(enumerate(total_timesteps)):
|
||||
step_start_time = time.time()
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
@@ -182,15 +185,26 @@ class UpscalerPipeline(StableDiffusionPipeline):
|
||||
|
||||
# Profiling Unet.
|
||||
profile_device = start_profiling(file_path="unet.rdc")
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
noise_level,
|
||||
),
|
||||
)
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
noise_level,
|
||||
),
|
||||
)
|
||||
else:
|
||||
noise_pred = self.unet_512(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
noise_level,
|
||||
),
|
||||
)
|
||||
end_profiling(profile_device)
|
||||
noise_pred = torch.from_numpy(noise_pred)
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
@@ -219,6 +233,7 @@ class UpscalerPipeline(StableDiffusionPipeline):
|
||||
|
||||
if self.ondemand:
|
||||
self.unload_unet()
|
||||
self.unload_unet_512()
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
self.log += f"\nAverage step time: {avg_step_time}ms/it"
|
||||
|
||||
@@ -243,6 +258,7 @@ class UpscalerPipeline(StableDiffusionPipeline):
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
@@ -264,7 +280,10 @@ class UpscalerPipeline(StableDiffusionPipeline):
|
||||
|
||||
# Get text embeddings with weight emphasis from prompts
|
||||
text_embeddings = self.encode_prompts_weight(
|
||||
prompts, neg_prompts, max_length
|
||||
prompts,
|
||||
neg_prompts,
|
||||
max_length,
|
||||
max_embeddings_multiples=max_embeddings_multiples,
|
||||
)
|
||||
|
||||
# 4. Preprocess image
|
||||
|
||||
@@ -757,6 +757,14 @@ def save_output_img(output_img, img_seed, extra_info={}):
|
||||
if args.ckpt_loc:
|
||||
img_model = Path(os.path.basename(args.ckpt_loc)).stem
|
||||
|
||||
img_vae = None
|
||||
if args.custom_vae:
|
||||
img_vae = Path(os.path.basename(args.custom_vae)).stem
|
||||
|
||||
img_lora = None
|
||||
if args.use_lora:
|
||||
img_lora = Path(os.path.basename(args.use_lora)).stem
|
||||
|
||||
if args.output_img_format == "jpg":
|
||||
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
|
||||
output_img.save(out_img_path, quality=95, subsampling=0)
|
||||
@@ -767,7 +775,9 @@ def save_output_img(output_img, img_seed, extra_info={}):
|
||||
if args.write_metadata_to_png:
|
||||
pngInfo.add_text(
|
||||
"parameters",
|
||||
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed}, Size: {args.width}x{args.height}, Model: {img_model}",
|
||||
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps: {args.steps},"
|
||||
f"Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed},"
|
||||
f"Size: {args.width}x{args.height}, Model: {img_model}, VAE: {img_vae}, LoRA: {img_lora}",
|
||||
)
|
||||
|
||||
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
|
||||
@@ -778,6 +788,9 @@ def save_output_img(output_img, img_seed, extra_info={}):
|
||||
"Image saved as png instead. Supported formats: png / jpg"
|
||||
)
|
||||
|
||||
# To be as low-impact as possible to the existing CSV format, we append
|
||||
# "VAE" and "LORA" to the end. However, it does not fit the hierarchy of
|
||||
# importance for each data point. Something to consider.
|
||||
new_entry = {
|
||||
"VARIANT": img_model,
|
||||
"SCHEDULER": args.scheduler,
|
||||
@@ -791,12 +804,17 @@ def save_output_img(output_img, img_seed, extra_info={}):
|
||||
"WIDTH": args.width,
|
||||
"MAX_LENGTH": args.max_length,
|
||||
"OUTPUT": out_img_path,
|
||||
"VAE": img_vae,
|
||||
"LORA": img_lora,
|
||||
}
|
||||
|
||||
new_entry.update(extra_info)
|
||||
|
||||
with open(csv_path, "a", encoding="utf-8") as csv_obj:
|
||||
csv_mode = "a" if os.path.isfile(csv_path) else "w"
|
||||
with open(csv_path, csv_mode, encoding="utf-8") as csv_obj:
|
||||
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
|
||||
if csv_mode == "w":
|
||||
dictwriter_obj.writeheader()
|
||||
dictwriter_obj.writerow(new_entry)
|
||||
csv_obj.close()
|
||||
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
from multiprocessing import Process, freeze_support
|
||||
import os
|
||||
import sys
|
||||
|
||||
if sys.platform == "darwin":
|
||||
# import before IREE to avoid torch-MLIR library issues
|
||||
import torch_mlir
|
||||
|
||||
import shutil
|
||||
import PIL, transformers # ensures inclusion in pysintaller exe generation
|
||||
import PIL, sentencepiece, transformers # ensures inclusion in pysintaller exe generation
|
||||
from apps.stable_diffusion.src import args, clear_all
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
|
||||
@@ -25,7 +30,11 @@ def launch_app(address):
|
||||
width = window.winfo_screenwidth()
|
||||
height = window.winfo_screenheight()
|
||||
webview.create_window(
|
||||
"SHARK AI Studio", url=address, width=width, height=height
|
||||
"SHARK AI Studio",
|
||||
url=address,
|
||||
width=width,
|
||||
height=height,
|
||||
text_select=True,
|
||||
)
|
||||
webview.start(private_mode=False)
|
||||
|
||||
@@ -39,6 +48,7 @@ if __name__ == "__main__":
|
||||
img2img_api,
|
||||
upscaler_api,
|
||||
inpaint_api,
|
||||
outpaint_api,
|
||||
)
|
||||
from fastapi import FastAPI, APIRouter
|
||||
import uvicorn
|
||||
@@ -50,9 +60,7 @@ if __name__ == "__main__":
|
||||
app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"])
|
||||
app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"])
|
||||
app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
|
||||
# app.add_api_route(
|
||||
# "/sdapi/v1/outpaint", outpaint_api, methods=["post"]
|
||||
# )
|
||||
app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"])
|
||||
app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
|
||||
app.include_router(APIRouter())
|
||||
uvicorn.run(app, host="127.0.0.1", port=args.server_port)
|
||||
|
||||
@@ -249,6 +249,7 @@ def img2img_inf(
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
use_stencil=use_stencil,
|
||||
)
|
||||
seeds.append(img_seed)
|
||||
@@ -342,7 +343,7 @@ def img2img_api(
|
||||
)
|
||||
|
||||
# Converts generator type to subscriptable
|
||||
res = list(res)[0]
|
||||
res = next(res)
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
|
||||
@@ -204,6 +204,7 @@ def inpaint_inf(
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
seeds.append(img_seed)
|
||||
total_time = time.time() - start_time
|
||||
@@ -278,7 +279,7 @@ def inpaint_api(
|
||||
custom_model="None",
|
||||
hf_model_id=InputData["hf_model_id"]
|
||||
if "hf_model_id" in InputData.keys()
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
else "stabilityai/stable-diffusion-2-inpainting",
|
||||
custom_vae="None",
|
||||
precision="fp16",
|
||||
device=available_devices[0],
|
||||
@@ -289,6 +290,10 @@ def inpaint_api(
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
)
|
||||
|
||||
# Converts generator type to subscriptable
|
||||
res = next(res)
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
|
||||
@@ -211,6 +211,7 @@ def outpaint_inf(
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
seeds.append(img_seed)
|
||||
total_time = time.time() - start_time
|
||||
@@ -287,7 +288,7 @@ def outpaint_api(
|
||||
custom_model="None",
|
||||
hf_model_id=InputData["hf_model_id"]
|
||||
if "hf_model_id" in InputData.keys()
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
else "stabilityai/stable-diffusion-2-inpainting",
|
||||
custom_vae="None",
|
||||
precision="fp16",
|
||||
device=available_devices[0],
|
||||
@@ -298,6 +299,10 @@ def outpaint_api(
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
)
|
||||
|
||||
# Convert Generator to Subscriptable
|
||||
res = next(res)
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
|
||||
@@ -65,16 +65,11 @@ def chat(curr_system_message, history, model, device, precision):
|
||||
)
|
||||
prompt = messages.strip()
|
||||
print("prompt = ", prompt)
|
||||
sentence = vicuna_model.generate(prompt)
|
||||
|
||||
partial_text = ""
|
||||
for new_text in sentence.split(" "):
|
||||
# print(new_text)
|
||||
partial_text += new_text + " "
|
||||
for partial_text in vicuna_model.generate(prompt):
|
||||
history[-1][1] = partial_text
|
||||
# Yield an empty string to cleanup the message textbox and the updated conversation history
|
||||
yield history
|
||||
history[-1][1] = sentence
|
||||
|
||||
return history
|
||||
|
||||
# else Model is StableLM
|
||||
@@ -136,8 +131,10 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value="fp32",
|
||||
value="fp16",
|
||||
choices=[
|
||||
"int4",
|
||||
"int8",
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
|
||||
@@ -265,6 +265,10 @@ def txt2img_api(
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
)
|
||||
|
||||
# Convert Generator to Subscriptable
|
||||
res = next(res)
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
@@ -301,7 +305,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
)
|
||||
txt2img_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
|
||||
placeholder="Select 'None' in the dropdown on the left and enter model ID here",
|
||||
value="",
|
||||
label="HuggingFace Model ID or Civitai model download URL",
|
||||
lines=3,
|
||||
@@ -553,6 +557,9 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
height,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
custom_vae,
|
||||
],
|
||||
outputs=[
|
||||
txt2img_png_info_img,
|
||||
@@ -566,5 +573,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
height,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
custom_vae,
|
||||
],
|
||||
)
|
||||
|
||||
@@ -202,6 +202,7 @@ def upscaler_inf(
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
@@ -300,7 +301,7 @@ def upscaler_api(
|
||||
ondemand=False,
|
||||
)
|
||||
# Converts generator type to subscriptable
|
||||
res = list(res)[0]
|
||||
res = next(res)
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
|
||||
@@ -11,21 +11,35 @@ def has_csv(image_filename: str) -> bool:
|
||||
return os.path.exists(csv_path(image_filename))
|
||||
|
||||
|
||||
def parse_csv(image_filename: str):
|
||||
# We use a reader instead of a DictReader here for images_details.csv files due to the lack of
|
||||
# headers, and then match up the return list for each row with our guess at which column format
|
||||
# the file is using.
|
||||
|
||||
def matching_filename(image_filename: str, row):
|
||||
# we assume the final column of the csv has the original filename with full path and match that
|
||||
# against the image_filename. We then exclude the filename from the output, hence the -1's.
|
||||
# against the image_filename if we are given a list. Otherwise we assume a dict and and take
|
||||
# the value of the OUTPUT key
|
||||
return os.path.basename(image_filename) in (
|
||||
row[-1] if isinstance(row, list) else row["OUTPUT"]
|
||||
)
|
||||
|
||||
|
||||
def parse_csv(image_filename: str):
|
||||
csv_filename = csv_path(image_filename)
|
||||
|
||||
matches = [
|
||||
humanize(row)
|
||||
for row in csv.reader(open(csv_filename, "r", newline=""))
|
||||
if row
|
||||
and humanizable(row)
|
||||
and os.path.basename(image_filename) in row[-1]
|
||||
]
|
||||
with open(csv_filename, "r", newline="") as csv_file:
|
||||
# We use a reader or DictReader here for images_details.csv depending on whether we think it
|
||||
# has headers or not. Having headers means less guessing of the format.
|
||||
has_header = csv.Sniffer().has_header(csv_file.read(2048))
|
||||
csv_file.seek(0)
|
||||
|
||||
reader = (
|
||||
csv.DictReader(csv_file) if has_header else csv.reader(csv_file)
|
||||
)
|
||||
|
||||
matches = [
|
||||
# we rely on humanize and humanizable to work out the parsing of the individual .csv rows
|
||||
humanize(row)
|
||||
for row in reader
|
||||
if row
|
||||
and (has_header or humanizable(row))
|
||||
and matching_filename(image_filename, row)
|
||||
]
|
||||
|
||||
return matches[0] if matches else {}
|
||||
|
||||
@@ -50,7 +50,22 @@ PARAMS_FORMATS = {
|
||||
},
|
||||
}
|
||||
|
||||
PARAMS_FORMAT_LONGEST = PARAMS_FORMATS[max(PARAMS_FORMATS.keys())]
|
||||
PARAMS_FORMAT_CURRENT = {
|
||||
"VARIANT": "Model",
|
||||
"VAE": "VAE",
|
||||
"LORA": "LoRA",
|
||||
"SCHEDULER": "Sampler",
|
||||
"PROMPT": "Prompt",
|
||||
"NEG_PROMPT": "Negative prompt",
|
||||
"SEED": "Seed",
|
||||
"CFG_SCALE": "CFG scale",
|
||||
"PRECISION": "Precision",
|
||||
"STEPS": "Steps",
|
||||
"HEIGHT": "Height",
|
||||
"WIDTH": "Width",
|
||||
"MAX_LENGTH": "Max Length",
|
||||
"OUTPUT": "Filename",
|
||||
}
|
||||
|
||||
|
||||
def compact(metadata: dict) -> dict:
|
||||
@@ -97,19 +112,20 @@ def humanize(metadata: dict | list[str], includes_filename=True) -> dict:
|
||||
)
|
||||
|
||||
# For dictionaries we try to use the matching length parameter format if
|
||||
# available, otherwise we use the longest. Then we swap keys in the
|
||||
# metadata that match keys in the format for the friendlier name that we
|
||||
# have set in the format value
|
||||
# available, otherwise we just use the current format which is assumed to
|
||||
# have everything currently known about. Then we swap keys in the metadata
|
||||
# that match keys in the format for the friendlier name that we have set
|
||||
# in the format value
|
||||
if isinstance(metadata, dict):
|
||||
if humanizable(metadata, includes_filename):
|
||||
format = PARAMS_FORMATS[lookup_key]
|
||||
else:
|
||||
format = PARAMS_FORMAT_LONGEST
|
||||
format = PARAMS_FORMAT_CURRENT
|
||||
|
||||
return {
|
||||
format[key]: value
|
||||
for (key, value) in metadata.items()
|
||||
if key in format.keys()
|
||||
format[key]: metadata[key]
|
||||
for key in format.keys()
|
||||
if key in metadata.keys() and metadata[key]
|
||||
}
|
||||
|
||||
raise TypeError("Can only humanize parameter lists or dictionaries")
|
||||
|
||||
@@ -62,6 +62,82 @@ def parse_generation_parameters(x: str):
|
||||
return res
|
||||
|
||||
|
||||
def try_find_model_base_from_png_metadata(
|
||||
file: str, folder: str = "models"
|
||||
) -> str:
|
||||
custom = ""
|
||||
|
||||
# Remove extension from file info
|
||||
if file.endswith(".safetensors") or file.endswith(".ckpt"):
|
||||
file = Path(file).stem
|
||||
# Check for the file name match with one of the local ckpt or safetensors files
|
||||
if Path(get_custom_model_pathfile(file + ".ckpt", folder)).is_file():
|
||||
custom = file + ".ckpt"
|
||||
if Path(
|
||||
get_custom_model_pathfile(file + ".safetensors", folder)
|
||||
).is_file():
|
||||
custom = file + ".safetensors"
|
||||
|
||||
return custom
|
||||
|
||||
|
||||
def find_model_from_png_metadata(
|
||||
key: str, metadata: dict[str, str | int]
|
||||
) -> tuple[str, str]:
|
||||
png_hf_id = ""
|
||||
png_custom = ""
|
||||
|
||||
if key in metadata:
|
||||
model_file = metadata[key]
|
||||
png_custom = try_find_model_base_from_png_metadata(model_file)
|
||||
# Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0")
|
||||
if model_file in predefined_models:
|
||||
png_custom = model_file
|
||||
# If nothing had matched, check vendor/hf_model_id
|
||||
if not png_custom and model_file.count("/"):
|
||||
png_hf_id = model_file
|
||||
# No matching model was found
|
||||
if not png_custom and not png_hf_id:
|
||||
print(
|
||||
"Import PNG info: Unable to find a matching model for %s"
|
||||
% model_file
|
||||
)
|
||||
|
||||
return png_custom, png_hf_id
|
||||
|
||||
|
||||
def find_vae_from_png_metadata(
|
||||
key: str, metadata: dict[str, str | int]
|
||||
) -> str:
|
||||
vae_custom = ""
|
||||
|
||||
if key in metadata:
|
||||
vae_file = metadata[key]
|
||||
vae_custom = try_find_model_base_from_png_metadata(vae_file, "vae")
|
||||
|
||||
# VAE input is optional, should not print or throw an error if missing
|
||||
|
||||
return vae_custom
|
||||
|
||||
|
||||
def find_lora_from_png_metadata(
|
||||
key: str, metadata: dict[str, str | int]
|
||||
) -> tuple[str, str]:
|
||||
lora_hf_id = ""
|
||||
lora_custom = ""
|
||||
|
||||
if key in metadata:
|
||||
lora_file = metadata[key]
|
||||
lora_custom = try_find_model_base_from_png_metadata(lora_file, "lora")
|
||||
# If nothing had matched, check vendor/hf_model_id
|
||||
if not lora_custom and lora_file.count("/"):
|
||||
lora_hf_id = lora_file
|
||||
|
||||
# LoRA input is optional, should not print or throw an error if missing
|
||||
|
||||
return lora_custom, lora_hf_id
|
||||
|
||||
|
||||
def import_png_metadata(
|
||||
pil_data,
|
||||
prompt,
|
||||
@@ -74,40 +150,21 @@ def import_png_metadata(
|
||||
height,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
custom_lora,
|
||||
hf_lora_id,
|
||||
custom_vae,
|
||||
):
|
||||
try:
|
||||
png_info = pil_data.info["parameters"]
|
||||
metadata = parse_generation_parameters(png_info)
|
||||
png_hf_model_id = ""
|
||||
png_custom_model = ""
|
||||
|
||||
if "Model" in metadata:
|
||||
# Remove extension from model info
|
||||
if metadata["Model"].endswith(".safetensors") or metadata[
|
||||
"Model"
|
||||
].endswith(".ckpt"):
|
||||
metadata["Model"] = Path(metadata["Model"]).stem
|
||||
# Check for the model name match with one of the local ckpt or safetensors files
|
||||
if Path(
|
||||
get_custom_model_pathfile(metadata["Model"] + ".ckpt")
|
||||
).is_file():
|
||||
png_custom_model = metadata["Model"] + ".ckpt"
|
||||
if Path(
|
||||
get_custom_model_pathfile(metadata["Model"] + ".safetensors")
|
||||
).is_file():
|
||||
png_custom_model = metadata["Model"] + ".safetensors"
|
||||
# Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0")
|
||||
if metadata["Model"] in predefined_models:
|
||||
png_custom_model = metadata["Model"]
|
||||
# If nothing had matched, check vendor/hf_model_id
|
||||
if not png_custom_model and metadata["Model"].count("/"):
|
||||
png_hf_model_id = metadata["Model"]
|
||||
# No matching model was found
|
||||
if not png_custom_model and not png_hf_model_id:
|
||||
print(
|
||||
"Import PNG info: Unable to find a matching model for %s"
|
||||
% metadata["Model"]
|
||||
)
|
||||
(png_custom_model, png_hf_model_id) = find_model_from_png_metadata(
|
||||
"Model", metadata
|
||||
)
|
||||
(lora_custom_model, lora_hf_model_id) = find_lora_from_png_metadata(
|
||||
"LoRA", metadata
|
||||
)
|
||||
vae_custom_model = find_vae_from_png_metadata("VAE", metadata)
|
||||
|
||||
negative_prompt = metadata["Negative prompt"]
|
||||
steps = int(metadata["Steps"])
|
||||
@@ -115,12 +172,24 @@ def import_png_metadata(
|
||||
seed = int(metadata["Seed"])
|
||||
width = float(metadata["Size-1"])
|
||||
height = float(metadata["Size-2"])
|
||||
|
||||
if "Model" in metadata and png_custom_model:
|
||||
custom_model = png_custom_model
|
||||
hf_model_id = ""
|
||||
if "Model" in metadata and png_hf_model_id:
|
||||
custom_model = "None"
|
||||
hf_model_id = png_hf_model_id
|
||||
|
||||
if "LoRA" in metadata and lora_custom_model:
|
||||
custom_lora = lora_custom_model
|
||||
hf_lora_id = ""
|
||||
if "LoRA" in metadata and lora_hf_model_id:
|
||||
custom_lora = "None"
|
||||
hf_lora_id = lora_hf_model_id
|
||||
|
||||
if "VAE" in metadata and vae_custom_model:
|
||||
custom_vae = vae_custom_model
|
||||
|
||||
if "Prompt" in metadata:
|
||||
prompt = metadata["Prompt"]
|
||||
if "Sampler" in metadata:
|
||||
@@ -149,4 +218,7 @@ def import_png_metadata(
|
||||
height,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
custom_lora,
|
||||
hf_lora_id,
|
||||
custom_vae,
|
||||
)
|
||||
|
||||
@@ -89,21 +89,155 @@ def img2img_test():
|
||||
|
||||
print(f"response from server was : {res.status_code}")
|
||||
|
||||
print("Extracting response object")
|
||||
# NOTE Uncomment below to save the picture
|
||||
|
||||
# Uncomment below to save the picture
|
||||
# print("Extracting response object")
|
||||
# response_obj = res.json()
|
||||
# img_b64 = response_obj.get("images", [False])[0] or response_obj.get(
|
||||
# "image"
|
||||
# )
|
||||
# img_b2 = base64.b64decode(img_b64.replace("data:image/png;base64,", ""))
|
||||
# im_file = BytesIO(img_b2)
|
||||
# response_img = Image.open(im_file)
|
||||
# print("Saving Response Image to: response_img")
|
||||
# response_img.save(r"rest_api_tests/response_img.png")
|
||||
|
||||
response_obj = res.json()
|
||||
img_b64 = response_obj.get("images", [False])[0] or response_obj.get(
|
||||
"image"
|
||||
|
||||
def inpainting_test():
|
||||
prompt = "Paint a rabbit riding on the dog"
|
||||
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
|
||||
seed = 2121991605
|
||||
height = 512
|
||||
width = 512
|
||||
steps = 50
|
||||
noise_level = 10
|
||||
cfg_scale = 7
|
||||
is_full_res = False
|
||||
full_res_padding = 32
|
||||
image_path = r"./rest_api_tests/dog.png"
|
||||
|
||||
img_file = open(image_path, "rb")
|
||||
image = (
|
||||
"data:image/png;base64," + base64.b64encode(img_file.read()).decode()
|
||||
)
|
||||
img_b2 = base64.b64decode(img_b64.replace("data:image/png;base64,", ""))
|
||||
im_file = BytesIO(img_b2)
|
||||
response_img = Image.open(im_file)
|
||||
print("Saving Response Image to: response_img")
|
||||
response_img.save(r"rest_api_tests/response_img.png")
|
||||
img_file = open(image_path, "rb")
|
||||
mask = (
|
||||
"data:image/png;base64," + base64.b64encode(img_file.read()).decode()
|
||||
)
|
||||
|
||||
url = "http://127.0.0.1:8080/sdapi/v1/inpaint"
|
||||
|
||||
headers = {
|
||||
"User-Agent": "PythonTest",
|
||||
"Accept": "*/*",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
}
|
||||
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"image": image,
|
||||
"mask": mask,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"steps": steps,
|
||||
"noise_level": noise_level,
|
||||
"cfg_scale": cfg_scale,
|
||||
"seed": seed,
|
||||
"is_full_res": is_full_res,
|
||||
"full_res_padding": full_res_padding,
|
||||
}
|
||||
|
||||
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
|
||||
|
||||
print(f"[Inpainting] response from server was : {res.status_code}")
|
||||
|
||||
|
||||
def outpainting_test():
|
||||
prompt = "Paint a rabbit riding on the dog"
|
||||
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
|
||||
seed = 2121991605
|
||||
height = 512
|
||||
width = 512
|
||||
steps = 50
|
||||
cfg_scale = 7
|
||||
color_variation = 0.2
|
||||
noise_q = 0.2
|
||||
directions = ["up", "down", "right", "left"]
|
||||
pixels = 32
|
||||
mask_blur = 64
|
||||
image_path = r"./rest_api_tests/dog.png"
|
||||
|
||||
# Converting Image to Base64
|
||||
img_file = open(image_path, "rb")
|
||||
init_images = [
|
||||
"data:image/png;base64," + base64.b64encode(img_file.read()).decode()
|
||||
]
|
||||
|
||||
url = "http://127.0.0.1:8080/sdapi/v1/outpaint"
|
||||
|
||||
headers = {
|
||||
"User-Agent": "PythonTest",
|
||||
"Accept": "*/*",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
}
|
||||
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"seed": seed,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"steps": steps,
|
||||
"cfg_scale": cfg_scale,
|
||||
"color_variation": color_variation,
|
||||
"noise_q": noise_q,
|
||||
"directions": directions,
|
||||
"pixels": pixels,
|
||||
"mask_blur": mask_blur,
|
||||
"init_images": init_images,
|
||||
}
|
||||
|
||||
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
|
||||
|
||||
print(f"[Outpaint] response from server was : {res.status_code}")
|
||||
|
||||
|
||||
def txt2img_test():
|
||||
prompt = "Paint a rabbit in a top hate"
|
||||
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
|
||||
seed = 2121991605
|
||||
height = 512
|
||||
width = 512
|
||||
steps = 50
|
||||
cfg_scale = 7
|
||||
|
||||
url = "http://127.0.0.1:8080/sdapi/v1/txt2img"
|
||||
|
||||
headers = {
|
||||
"User-Agent": "PythonTest",
|
||||
"Accept": "*/*",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
}
|
||||
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"seed": seed,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"steps": steps,
|
||||
"cfg_scale": cfg_scale,
|
||||
}
|
||||
|
||||
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
|
||||
|
||||
print(f"[txt2img] response from server was : {res.status_code}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
txt2img_test()
|
||||
img2img_test()
|
||||
upscaler_test()
|
||||
inpainting_test()
|
||||
outpainting_test()
|
||||
|
||||
2
setup.py
2
setup.py
@@ -39,7 +39,7 @@ setup(
|
||||
install_requires=[
|
||||
"numpy",
|
||||
"PyYAML",
|
||||
"torch-mlir>=20221021.633",
|
||||
"torch-mlir==20230620.875",
|
||||
]
|
||||
+ backend_deps,
|
||||
)
|
||||
|
||||
@@ -89,7 +89,7 @@ else {python -m venv .\shark.venv\}
|
||||
python -m pip install --upgrade pip
|
||||
pip install wheel
|
||||
pip install -r requirements.txt
|
||||
pip install --pre torch-mlir torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
|
||||
pip install --pre torch-mlir==20230620.875 torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
|
||||
pip install --upgrade -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html iree-compiler iree-runtime
|
||||
Write-Host "Building SHARK..."
|
||||
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
|
||||
@@ -88,7 +88,7 @@ if [ "$torch_mlir_bin" = true ]; then
|
||||
echo "MacOS detected. Installing torch-mlir from .whl, to avoid dependency problems with torch."
|
||||
$PYTHON -m pip install --pre --no-cache-dir torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/
|
||||
else
|
||||
$PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
|
||||
$PYTHON -m pip install --pre torch-mlir==20230620.875 -f https://llvm.github.io/torch-mlir/package-index/
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully Installed torch-mlir"
|
||||
else
|
||||
|
||||
154
shark/dynamo_backend/utils.py
Normal file
154
shark/dynamo_backend/utils.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import functools
|
||||
from typing import List, Optional
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._functorch.compile_utils import strip_overloads
|
||||
from shark.shark_inference import SharkInference
|
||||
from torch._decomp import get_decompositions
|
||||
from torch.func import functionalize
|
||||
import io
|
||||
import torch_mlir
|
||||
|
||||
|
||||
# TODO: Control decompositions.
|
||||
def default_decompositions():
|
||||
return get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
torch.ops.aten.native_layer_norm,
|
||||
torch.ops.aten.masked_fill.Tensor,
|
||||
torch.ops.aten.masked_fill.Scalar,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
|
||||
removed_indexes = []
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, (list, tuple)):
|
||||
node_arg = list(node_arg)
|
||||
node_args_len = len(node_arg)
|
||||
for i in range(node_args_len):
|
||||
curr_index = node_args_len - (i + 1)
|
||||
if node_arg[curr_index] is None:
|
||||
removed_indexes.append(curr_index)
|
||||
node_arg.pop(curr_index)
|
||||
node.args = (tuple(node_arg),)
|
||||
break
|
||||
|
||||
if len(removed_indexes) > 0:
|
||||
fx_g.graph.lint()
|
||||
fx_g.graph.eliminate_dead_code()
|
||||
fx_g.recompile()
|
||||
removed_indexes.sort()
|
||||
return removed_indexes
|
||||
|
||||
|
||||
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
return len(node_arg) == 0
|
||||
return False
|
||||
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""
|
||||
Replace tuple with tuple element in functions that return one-element tuples.
|
||||
Returns true if an unwrapping took place, and false otherwise.
|
||||
"""
|
||||
unwrapped_tuple = False
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
unwrapped_tuple = True
|
||||
break
|
||||
|
||||
if unwrapped_tuple:
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return unwrapped_tuple
|
||||
|
||||
|
||||
class SharkBackend:
|
||||
def __init__(
|
||||
self, fx_g: torch.fx.GraphModule, inputs: tuple, options: dict
|
||||
):
|
||||
self.fx_g = fx_g
|
||||
self.inputs = inputs
|
||||
self.shark_module = None
|
||||
self.device: str = options.get("device", "cpu")
|
||||
self.was_unwrapped: bool = False
|
||||
self.none_indices: list = []
|
||||
self._modify_fx_g()
|
||||
self.compile()
|
||||
|
||||
def _modify_fx_g(self):
|
||||
self.none_indices = _remove_nones(self.fx_g)
|
||||
self.was_unwrapped = _unwrap_single_tuple_return(self.fx_g)
|
||||
|
||||
def compile(self):
|
||||
gm = make_fx(
|
||||
functionalize(self.fx_g),
|
||||
decomposition_table=default_decompositions(),
|
||||
)(*self.inputs)
|
||||
gm.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
gm.recompile()
|
||||
strip_overloads(gm)
|
||||
ts_g = torch.jit.script(gm)
|
||||
mlir_module = torch_mlir.compile(
|
||||
ts_g, self.inputs, output_type="linalg-on-tensors"
|
||||
)
|
||||
bytecode_stream = io.BytesIO()
|
||||
mlir_module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode,
|
||||
device=self.device,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
shark_module.compile(extra_args=[])
|
||||
self.shark_module = shark_module
|
||||
|
||||
def __call__(self, *inputs):
|
||||
np_inputs = [x.contiguous().detach().cpu().numpy() for x in inputs]
|
||||
np_outs = self.shark_module("forward", np_inputs)
|
||||
if self.was_unwrapped:
|
||||
np_outs = [
|
||||
np_outs,
|
||||
]
|
||||
|
||||
if not isinstance(np_outs, list):
|
||||
res = torch.from_numpy(np_outs)
|
||||
return res
|
||||
|
||||
result = [torch.from_numpy(x) for x in np_outs]
|
||||
for r_in in self.none_indices:
|
||||
result.insert(r_in, None)
|
||||
result = tuple(result)
|
||||
return result
|
||||
@@ -1,10 +1,7 @@
|
||||
import torch
|
||||
import torch_mlir
|
||||
from shark.shark_inference import SharkInference
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
compile_through_fx,
|
||||
args,
|
||||
)
|
||||
from shark.shark_compile import shark_compile_through_fx
|
||||
from MEGABYTE_pytorch import MEGABYTE
|
||||
|
||||
import os
|
||||
@@ -37,23 +34,22 @@ class MegaModel(torch.nn.Module):
|
||||
|
||||
|
||||
megaModel = MegaModel()
|
||||
input = [torch.randint(0, 16000, (1, 1024, 4))]
|
||||
inputs = [torch.randint(0, 16000, (1, 1024, 4))]
|
||||
|
||||
# CURRENTLY IT BAILS OUT HERE BECAUSE OF MISSING OP LOWERINGS :-
|
||||
# 1. aten.alias
|
||||
shark_module, _ = compile_through_fx(
|
||||
megaModel,
|
||||
inputs=input,
|
||||
shark_module, _ = shark_compile_through_fx(
|
||||
model=megaModel,
|
||||
inputs=inputs,
|
||||
extended_model_name="mega_shark",
|
||||
debug=False,
|
||||
generate_vmfb=True,
|
||||
is_f16=False,
|
||||
f16_input_mask=None,
|
||||
save_dir=os.getcwd(),
|
||||
debug=False,
|
||||
generate_or_load_vmfb=True,
|
||||
extra_args=[],
|
||||
base_model_id=None,
|
||||
model_name="mega_shark",
|
||||
precision=None,
|
||||
return_mlir=True,
|
||||
device="cuda",
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
# logits = model(x)
|
||||
|
||||
@@ -63,10 +59,10 @@ def print_output_info(output, msg):
|
||||
print("\n\t", output.shape)
|
||||
|
||||
|
||||
ans = shark_module("forward", input)
|
||||
ans = shark_module("forward", inputs)
|
||||
print_output_info(torch.from_numpy(ans), "SHARK's output")
|
||||
|
||||
ans = megaModel.forward(*input)
|
||||
ans = megaModel.forward(*inputs)
|
||||
print_output_info(ans, "ORIGINAL Model's output")
|
||||
|
||||
# and sample from the logits accordingly
|
||||
|
||||
@@ -63,6 +63,7 @@ def get_supported_device_list():
|
||||
_IREE_DEVICE_MAP = {
|
||||
"cpu": "local-task",
|
||||
"cpu-task": "local-task",
|
||||
"AMD-AIE": "local-task",
|
||||
"cpu-sync": "local-sync",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan",
|
||||
@@ -81,6 +82,7 @@ def iree_target_map(device):
|
||||
_IREE_TARGET_MAP = {
|
||||
"cpu": "llvm-cpu",
|
||||
"cpu-task": "llvm-cpu",
|
||||
"AMD-AIE": "llvm-cpu",
|
||||
"cpu-sync": "llvm-cpu",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan",
|
||||
|
||||
@@ -14,11 +14,14 @@
|
||||
import iree.runtime as ireert
|
||||
import iree.compiler as ireec
|
||||
from shark.iree_utils._common import iree_device_map, iree_target_map
|
||||
from shark.iree_utils.cpu_utils import get_iree_cpu_rt_args
|
||||
from shark.iree_utils.benchmark_utils import *
|
||||
from shark.parser import shark_args
|
||||
import numpy as np
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# Get the iree-compile arguments given device.
|
||||
@@ -38,7 +41,10 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
if device_uri[0] == "cpu":
|
||||
from shark.iree_utils.cpu_utils import get_iree_cpu_args
|
||||
|
||||
return get_iree_cpu_args()
|
||||
data_tiling_flag = ["--iree-flow-enable-data-tiling"]
|
||||
u_kernel_flag = ["--iree-llvmcpu-enable-microkernels"]
|
||||
|
||||
return get_iree_cpu_args() + data_tiling_flag + u_kernel_flag
|
||||
if device_uri[0] == "cuda":
|
||||
from shark.iree_utils.gpu_utils import get_iree_gpu_args
|
||||
|
||||
@@ -181,8 +187,10 @@ def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
|
||||
vmfb_file.close()
|
||||
|
||||
config = get_iree_runtime_config(device)
|
||||
vm_module = ireert.VmModule.from_flatbuffer(
|
||||
config.vm_instance, flatbuffer_blob
|
||||
vm_module = ireert.VmModule.from_buffer(
|
||||
config.vm_instance,
|
||||
flatbuffer_blob,
|
||||
warn_if_copy=False,
|
||||
)
|
||||
|
||||
benchmark_cl = build_benchmark_args_non_tensor_input(
|
||||
@@ -313,8 +321,8 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
|
||||
config = ireert.Config(device=haldevice)
|
||||
else:
|
||||
config = get_iree_runtime_config(device)
|
||||
vm_module = ireert.VmModule.from_flatbuffer(
|
||||
config.vm_instance, flatbuffer_blob
|
||||
vm_module = ireert.VmModule.from_buffer(
|
||||
config.vm_instance, flatbuffer_blob, warn_if_copy=False
|
||||
)
|
||||
ctx = ireert.SystemContext(config=config)
|
||||
ctx.add_vm_module(vm_module)
|
||||
@@ -322,6 +330,63 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
|
||||
return ModuleCompiled, config
|
||||
|
||||
|
||||
def load_vmfb_using_mmap(
|
||||
flatbuffer_blob_or_path, device: str, device_idx: int = None
|
||||
):
|
||||
instance = ireert.VmInstance()
|
||||
device = iree_device_map(device)
|
||||
haldriver = ireert.get_driver(device)
|
||||
haldevice = haldriver.create_device_by_uri(
|
||||
device,
|
||||
allocators=[],
|
||||
)
|
||||
# First get configs.
|
||||
if device_idx is not None:
|
||||
device = iree_device_map(device)
|
||||
print("registering device id: ", device_idx)
|
||||
haldriver = ireert.get_driver(device)
|
||||
|
||||
haldevice = haldriver.create_device(
|
||||
haldriver.query_available_devices()[device_idx]["device_id"],
|
||||
allocators=shark_args.device_allocator,
|
||||
)
|
||||
config = ireert.Config(device=haldevice)
|
||||
else:
|
||||
config = get_iree_runtime_config(device)
|
||||
if "task" in device:
|
||||
print(
|
||||
f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
|
||||
)
|
||||
for flag in get_iree_cpu_rt_args():
|
||||
ireert.flags.parse_flags(flag)
|
||||
# Now load vmfb.
|
||||
# Two scenarios we have here :-
|
||||
# 1. We either have the vmfb already saved and therefore pass the path of it.
|
||||
# (This would arise if we're invoking `load_module` from a SharkInference obj)
|
||||
# OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
|
||||
# (This would arise if we're invoking `compile` from a SharkInference obj)
|
||||
temp_file_to_unlink = None
|
||||
if isinstance(flatbuffer_blob_or_path, Path):
|
||||
flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__()
|
||||
if (
|
||||
isinstance(flatbuffer_blob_or_path, str)
|
||||
and ".vmfb" in flatbuffer_blob_or_path
|
||||
):
|
||||
vmfb_file_path = flatbuffer_blob_or_path
|
||||
mmaped_vmfb = ireert.VmModule.mmap(instance, flatbuffer_blob_or_path)
|
||||
ctx = ireert.SystemContext(config=config)
|
||||
ctx.add_vm_module(mmaped_vmfb)
|
||||
mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
|
||||
else:
|
||||
with tempfile.NamedTemporaryFile(delete=False) as tf:
|
||||
tf.write(flatbuffer_blob_or_path)
|
||||
tf.flush()
|
||||
vmfb_file_path = tf.name
|
||||
temp_file_to_unlink = vmfb_file_path
|
||||
mmaped_vmfb = ireert.VmModule.mmap(instance, vmfb_file_path)
|
||||
return mmaped_vmfb, config, temp_file_to_unlink
|
||||
|
||||
|
||||
def get_iree_compiled_module(
|
||||
module,
|
||||
device: str,
|
||||
@@ -329,19 +394,58 @@ def get_iree_compiled_module(
|
||||
model_config_path: str = None,
|
||||
extra_args: list = [],
|
||||
device_idx: int = None,
|
||||
mmap: bool = False,
|
||||
):
|
||||
"""Given a module returns the compiled .vmfb and configs"""
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(
|
||||
module, device, frontend, model_config_path, extra_args
|
||||
)
|
||||
return get_iree_module(flatbuffer_blob, device, device_idx=device_idx)
|
||||
temp_file_to_unlink = None
|
||||
# TODO: Currently mmap=True control flow path has been switched off for mmap.
|
||||
# Got to find a cleaner way to unlink/delete the temporary file since
|
||||
# we're setting delete=False when creating NamedTemporaryFile. That's why
|
||||
# I'm getting hold of the name of the temporary file in `temp_file_to_unlink`.
|
||||
if mmap:
|
||||
print(f"Will load the compiled module as a mmapped temporary file")
|
||||
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
|
||||
flatbuffer_blob, device, device_idx
|
||||
)
|
||||
else:
|
||||
vmfb, config = get_iree_module(
|
||||
flatbuffer_blob, device, device_idx=device_idx
|
||||
)
|
||||
ret_params = {
|
||||
"vmfb": vmfb,
|
||||
"config": config,
|
||||
"temp_file_to_unlink": temp_file_to_unlink,
|
||||
}
|
||||
return ret_params
|
||||
|
||||
|
||||
def load_flatbuffer(flatbuffer_path: str, device: str, device_idx: int = None):
|
||||
with open(os.path.join(flatbuffer_path), "rb") as f:
|
||||
flatbuffer_blob = f.read()
|
||||
|
||||
return get_iree_module(flatbuffer_blob, device, device_idx=device_idx)
|
||||
def load_flatbuffer(
|
||||
flatbuffer_path: str,
|
||||
device: str,
|
||||
device_idx: int = None,
|
||||
mmap: bool = False,
|
||||
):
|
||||
temp_file_to_unlink = None
|
||||
if mmap:
|
||||
print(f"Loading flatbuffer at {flatbuffer_path} as a mmapped file")
|
||||
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
|
||||
flatbuffer_path, device, device_idx
|
||||
)
|
||||
else:
|
||||
with open(os.path.join(flatbuffer_path), "rb") as f:
|
||||
flatbuffer_blob = f.read()
|
||||
vmfb, config = get_iree_module(
|
||||
flatbuffer_blob, device, device_idx=device_idx
|
||||
)
|
||||
ret_params = {
|
||||
"vmfb": vmfb,
|
||||
"config": config,
|
||||
"temp_file_to_unlink": temp_file_to_unlink,
|
||||
}
|
||||
return ret_params
|
||||
|
||||
|
||||
def export_iree_module_to_vmfb(
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
import subprocess
|
||||
import platform
|
||||
from shark.parser import shark_args
|
||||
|
||||
|
||||
def get_cpu_count():
|
||||
@@ -44,4 +45,18 @@ def get_iree_cpu_args():
|
||||
error_message = f"OS Type f{os_name} not supported and triple can't be determined, open issue to dSHARK team please :)"
|
||||
raise Exception(error_message)
|
||||
print(f"Target triple found:{target_triple}")
|
||||
return [f"--iree-llvmcpu-target-triple={target_triple}"]
|
||||
return [
|
||||
f"--iree-llvmcpu-target-triple={target_triple}",
|
||||
]
|
||||
|
||||
|
||||
# Get iree runtime flags for cpu
|
||||
def get_iree_cpu_rt_args():
|
||||
default = get_cpu_count()
|
||||
default = default if default <= 8 else default - 2
|
||||
cpu_count = (
|
||||
default
|
||||
if shark_args.task_topology_max_group_count is None
|
||||
else shark_args.task_topology_max_group_count
|
||||
)
|
||||
return [f"--task_topology_max_group_count={cpu_count}"]
|
||||
|
||||
@@ -119,5 +119,11 @@ parser.add_argument(
|
||||
"to augment the base device allocator",
|
||||
choices=["debug", "caching"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task_topology_max_group_count",
|
||||
type=str,
|
||||
default=None,
|
||||
help="passthrough flag for the iree flag of the same name. If None, defaults to cpu-count",
|
||||
)
|
||||
|
||||
shark_args, unknown = parser.parse_known_args()
|
||||
|
||||
99
shark/shark_compile.py
Normal file
99
shark/shark_compile.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import os
|
||||
import tempfile
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
|
||||
|
||||
def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
|
||||
vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
|
||||
shark_module = None
|
||||
if os.path.isfile(vmfb_path):
|
||||
shark_module = SharkInference(
|
||||
None,
|
||||
device=device,
|
||||
mlir_dialect=mlir_dialect,
|
||||
)
|
||||
print(f"loading existing vmfb from: {vmfb_path}")
|
||||
shark_module.load_module(vmfb_path, extra_args=extra_args)
|
||||
return shark_module
|
||||
|
||||
|
||||
def compile_module(
|
||||
shark_module, extended_model_name, generate_vmfb, extra_args=[]
|
||||
):
|
||||
if generate_vmfb:
|
||||
vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
|
||||
if os.path.isfile(vmfb_path):
|
||||
print(f"loading existing vmfb from: {vmfb_path}")
|
||||
shark_module.load_module(vmfb_path, extra_args=extra_args)
|
||||
else:
|
||||
print(
|
||||
"No vmfb found. Compiling and saving to {}".format(vmfb_path)
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
os.getcwd(), extended_model_name, extra_args
|
||||
)
|
||||
shark_module.load_module(path, extra_args=extra_args)
|
||||
else:
|
||||
shark_module.compile(extra_args)
|
||||
return shark_module
|
||||
|
||||
|
||||
def shark_compile_through_fx(
|
||||
model,
|
||||
inputs,
|
||||
extended_model_name,
|
||||
is_f16=False,
|
||||
f16_input_mask=None,
|
||||
save_dir=tempfile.gettempdir(),
|
||||
debug=False,
|
||||
generate_or_load_vmfb=True,
|
||||
extra_args=[],
|
||||
device=None,
|
||||
mlir_dialect="tm_tensor",
|
||||
):
|
||||
if generate_or_load_vmfb:
|
||||
shark_module = load_vmfb(
|
||||
extended_model_name=extended_model_name,
|
||||
device=device,
|
||||
mlir_dialect=mlir_dialect,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
if shark_module:
|
||||
return (
|
||||
shark_module,
|
||||
None,
|
||||
)
|
||||
|
||||
from shark.parser import shark_args
|
||||
|
||||
if "cuda" in device:
|
||||
shark_args.enable_tf32 = True
|
||||
|
||||
(
|
||||
mlir_module,
|
||||
_,
|
||||
) = import_with_fx(
|
||||
model=model,
|
||||
inputs=inputs,
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=f16_input_mask,
|
||||
debug=debug,
|
||||
model_name=extended_model_name,
|
||||
save_dir=save_dir,
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
device=device,
|
||||
mlir_dialect=mlir_dialect,
|
||||
)
|
||||
return (
|
||||
compile_module(
|
||||
shark_module,
|
||||
extended_model_name,
|
||||
generate_vmfb=generate_or_load_vmfb,
|
||||
extra_args=extra_args,
|
||||
),
|
||||
mlir_module,
|
||||
)
|
||||
@@ -60,12 +60,15 @@ def download_public_file(
|
||||
else:
|
||||
continue
|
||||
|
||||
destination_filename = os.path.join(destination_folder_name, blob_name)
|
||||
if os.path.isdir(destination_filename):
|
||||
continue
|
||||
with open(destination_filename, "wb") as f:
|
||||
with tqdm.wrapattr(f, "write", total=blob.size) as file_obj:
|
||||
storage_client.download_blob_to_file(blob, file_obj)
|
||||
else:
|
||||
destination_filename = os.path.join(
|
||||
destination_folder_name, blob_name
|
||||
)
|
||||
if os.path.isdir(destination_filename):
|
||||
continue
|
||||
with open(destination_filename, "wb") as f:
|
||||
with tqdm.wrapattr(f, "write", total=blob.size) as file_obj:
|
||||
storage_client.download_blob_to_file(blob, file_obj)
|
||||
|
||||
|
||||
input_type_to_np_dtype = {
|
||||
|
||||
206
shark/shark_eager/shark_eager.py
Normal file
206
shark/shark_eager/shark_eager.py
Normal file
@@ -0,0 +1,206 @@
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from collections import defaultdict
|
||||
from shark.shark_importer import import_with_fx
|
||||
import torchvision.models as models
|
||||
import copy
|
||||
import io
|
||||
import numpy as np
|
||||
import sys
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch.fx.node import Node
|
||||
from typing import Dict
|
||||
import torch_mlir
|
||||
|
||||
|
||||
def shark_backend(fx_g: torch.fx.GraphModule, inputs, device: str = "cpu"):
|
||||
mlir_module = torch_mlir.compile(
|
||||
fx_g, inputs, output_type="linalg-on-tensors"
|
||||
)
|
||||
bytecode_stream = io.BytesIO()
|
||||
mlir_module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
shark_module.compile(extra_args=[])
|
||||
return shark_module
|
||||
|
||||
|
||||
def _make_single_op_gm(node, captured_val, compiled_graph):
|
||||
"""Make a GraphModule that just executes the given node."""
|
||||
g = torch.fx.Graph()
|
||||
env = {}
|
||||
inputs = []
|
||||
for arg in node.args:
|
||||
if arg and hasattr(arg, "name"):
|
||||
env[arg.name] = g.placeholder(arg.name)
|
||||
if isinstance(captured_val[arg.name], (list, tuple)):
|
||||
for val in captured_val[arg.name]:
|
||||
inputs.append(val)
|
||||
else:
|
||||
inputs.append(captured_val[arg.name])
|
||||
|
||||
call = g.node_copy(node, lambda n: env[n.name])
|
||||
g.output(call)
|
||||
g.lint()
|
||||
single_node = torch.fx.GraphModule(torch.nn.Module(), g)
|
||||
compiled_module = shark_backend(single_node, inputs)
|
||||
compiled_graph[node.name] = {
|
||||
"module": compiled_module,
|
||||
"inputs": [i for i in env],
|
||||
"result": None,
|
||||
}
|
||||
return
|
||||
|
||||
|
||||
def compiled_graph(gm: torch.fx.GraphModule, attr_info):
|
||||
compiled_graph = {}
|
||||
g = gm.graph
|
||||
for node in g.nodes:
|
||||
if node.op == "call_function":
|
||||
if not (
|
||||
node.target in [torch.ops.aten.empty]
|
||||
or node.name.startswith("getitem")
|
||||
):
|
||||
_make_single_op_gm(node, attr_info, compiled_graph)
|
||||
|
||||
# Currently torch.aten.empty has an compilation issue, so running natively.
|
||||
elif node.target in [torch.ops.aten.empty]:
|
||||
compiled_graph[node.name] = {
|
||||
"target": node.target,
|
||||
"args": node.args,
|
||||
"kwargs": node.kwargs,
|
||||
"result": None,
|
||||
}
|
||||
# Get item is a simple case takes a tuple and return the tensor at a particular index.
|
||||
elif node.name.startswith("getitem"):
|
||||
compiled_graph[node.name] = {
|
||||
"input": node.args[0].name,
|
||||
"pos": node.args[1],
|
||||
"result": None,
|
||||
}
|
||||
|
||||
return compiled_graph
|
||||
|
||||
|
||||
class ShapeProp:
|
||||
"""
|
||||
Shape propagation. This class takes a `GraphModule`.
|
||||
Then, its `propagate` method executes the `GraphModule`
|
||||
node-by-node with the given arguments. As each operation
|
||||
executes, the ShapeProp class stores away the shape and
|
||||
element type for the output values of each operation on
|
||||
the `shape` and `dtype` attributes of the operation's
|
||||
`Node`.
|
||||
"""
|
||||
|
||||
def __init__(self, mod):
|
||||
self.mod = mod
|
||||
self.graph = mod.graph
|
||||
self.modules = dict(self.mod.named_modules())
|
||||
|
||||
def propagate(self, *args):
|
||||
args_iter = iter(args)
|
||||
env: Dict[str, Node] = {}
|
||||
|
||||
def load_arg(a):
|
||||
return torch.fx.graph.map_arg(a, lambda n: env[n.name])
|
||||
|
||||
def fetch_attr(target: str):
|
||||
target_atoms = target.split(".")
|
||||
attr_itr = self.mod
|
||||
for i, atom in enumerate(target_atoms):
|
||||
if not hasattr(attr_itr, atom):
|
||||
raise RuntimeError(
|
||||
f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}"
|
||||
)
|
||||
attr_itr = getattr(attr_itr, atom)
|
||||
return attr_itr
|
||||
|
||||
for node in self.graph.nodes:
|
||||
if node.op == "placeholder":
|
||||
result = next(args_iter)
|
||||
elif node.op == "get_attr":
|
||||
result = fetch_attr(node.target)
|
||||
elif node.op == "call_function":
|
||||
result = node.target(
|
||||
*load_arg(node.args), **load_arg(node.kwargs)
|
||||
)
|
||||
elif node.op == "call_method":
|
||||
self_obj, *args = load_arg(node.args)
|
||||
kwargs = load_arg(node.kwargs)
|
||||
result = getattr(self_obj, node.target)(*args, **kwargs)
|
||||
elif node.op == "call_module":
|
||||
result = self.modules[node.target](
|
||||
*load_arg(node.args), **load_arg(node.kwargs)
|
||||
)
|
||||
|
||||
# This is the only code specific to shape propagation.
|
||||
# you can delete this `if` branch and this becomes
|
||||
# a generic GraphModule interpreter.
|
||||
if isinstance(result, torch.Tensor):
|
||||
node.shape = result.shape
|
||||
node.dtype = result.dtype
|
||||
|
||||
env[node.name] = result
|
||||
|
||||
return env
|
||||
|
||||
# return load_arg(self.graph.result)
|
||||
|
||||
|
||||
resnet18 = models.resnet18(pretrained=True)
|
||||
resnet18.train(False)
|
||||
input = (torch.randn(1, 3, 224, 224),)
|
||||
|
||||
print(resnet18(input[0]))
|
||||
|
||||
fx_graph = import_with_fx(resnet18, input, mlir_type="fx")
|
||||
|
||||
shape_prop = ShapeProp(fx_graph)
|
||||
|
||||
x = shape_prop.propagate(input[0])
|
||||
|
||||
shark_graph = compiled_graph(fx_graph, x)
|
||||
|
||||
|
||||
for key in shark_graph:
|
||||
if key.startswith("getitem"):
|
||||
input_val = shark_graph[key]["input"]
|
||||
pos = shark_graph[key]["pos"]
|
||||
if input_val not in shark_graph:
|
||||
shark_graph[key]["result"] = x[input_val][pos].detach()
|
||||
else:
|
||||
shark_graph[key]["result"] = shark_graph[input_val]["result"][
|
||||
pos
|
||||
].detach()
|
||||
elif key.startswith("empty"):
|
||||
operator = shark_graph[key]["target"]
|
||||
args = shark_graph[key]["args"]
|
||||
kwargs = shark_graph[key]["kwargs"]
|
||||
shark_graph[key]["result"] = operator(*args, **kwargs).detach()
|
||||
else:
|
||||
input_val = shark_graph[key]["inputs"]
|
||||
input_tensors = []
|
||||
for input in input_val:
|
||||
if input not in shark_graph:
|
||||
input_tensors.append(x[input].detach())
|
||||
else:
|
||||
input_tensors.append(shark_graph[input]["result"])
|
||||
|
||||
val = shark_graph[key]["module"]("forward", input_tensors)
|
||||
if isinstance(val, (tuple, list)):
|
||||
list_val = []
|
||||
for v in val:
|
||||
list_val.append(torch.from_numpy(v))
|
||||
shark_graph[key]["result"] = list_val
|
||||
else:
|
||||
shark_graph[key]["result"] = torch.from_numpy(val)
|
||||
|
||||
|
||||
print(shark_graph)
|
||||
@@ -1,5 +1,8 @@
|
||||
import re
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
import torch_mlir
|
||||
from iree.compiler import compile_str
|
||||
from shark.shark_importer import import_with_fx, get_f16_inputs
|
||||
|
||||
|
||||
class GenerateConfigFile:
|
||||
@@ -7,7 +10,9 @@ class GenerateConfigFile:
|
||||
self,
|
||||
model,
|
||||
num_sharding_stages: int,
|
||||
sharding_stages_id: list[str] = None,
|
||||
sharding_stages_id: list[str],
|
||||
model_input=None,
|
||||
config_file_path="model_config.json",
|
||||
):
|
||||
self.model = model
|
||||
self.num_sharding_stages = num_sharding_stages
|
||||
@@ -15,8 +20,67 @@ class GenerateConfigFile:
|
||||
assert self.num_sharding_stages == len(
|
||||
self.sharding_stages_id
|
||||
), "Number of sharding stages should be equal to the list of their ID"
|
||||
self.model_input = model_input
|
||||
self.config_file_path = config_file_path
|
||||
|
||||
def generate_json(self):
|
||||
def split_into_dispatches(
|
||||
self,
|
||||
backend,
|
||||
fx_tracing_required=True,
|
||||
f16_model=False,
|
||||
torch_mlir_tracing=False,
|
||||
):
|
||||
graph_for_compilation = self.model
|
||||
if fx_tracing_required:
|
||||
graph_for_compilation = import_with_fx(
|
||||
self.model,
|
||||
self.model_input,
|
||||
is_f16=f16_model,
|
||||
f16_input_mask=[False, False],
|
||||
mlir_type="torchscript",
|
||||
)
|
||||
|
||||
module = torch_mlir.compile(
|
||||
graph_for_compilation,
|
||||
(self.model_input),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=torch_mlir_tracing,
|
||||
verbose=False,
|
||||
)
|
||||
module = module.operation.get_asm(large_elements_limit=4)
|
||||
compiled_module_str = str(
|
||||
compile_str(
|
||||
str(module),
|
||||
target_backends=[backend],
|
||||
extra_args=[
|
||||
"--compile-to=flow",
|
||||
"--mlir-elide-elementsattrs-if-larger=4",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
substring_start_idx = [
|
||||
m.start()
|
||||
for m in re.finditer("flow.dispatch @", compiled_module_str)
|
||||
]
|
||||
dispatch_list = dict()
|
||||
|
||||
# dispatch_no is the 'i'th index of a dispatch out of n total dispatches of a model
|
||||
# dispatch_id is the unique id of a dispatch, multiple instances of the same dispatch
|
||||
# can occur in a model
|
||||
for dispatch_no, substring_idx in enumerate(substring_start_idx):
|
||||
dispatch_idx = (
|
||||
compiled_module_str[substring_idx:]
|
||||
.split(":")[0]
|
||||
.split("@")[-1]
|
||||
)
|
||||
key = "dispatch_no_" + str(dispatch_no)
|
||||
dispatch_list[key] = {n: "None" for n in self.sharding_stages_id}
|
||||
dispatch_list[key]["dispatch_id"] = dispatch_idx
|
||||
|
||||
self.generate_json(dispatch_list)
|
||||
|
||||
def split_into_layers(self):
|
||||
model_dictionary = dict()
|
||||
|
||||
for name, m in self.model.named_modules():
|
||||
@@ -34,5 +98,8 @@ class GenerateConfigFile:
|
||||
layer_dict = {n: "None" for n in self.sharding_stages_id}
|
||||
model_dictionary[name] = layer_dict
|
||||
|
||||
with open("model_config.json", "w") as outfile:
|
||||
json.dump(model_dictionary, outfile)
|
||||
self.generate_json(model_dictionary)
|
||||
|
||||
def generate_json(self, artifacts):
|
||||
with open(self.config_file_path, "w") as outfile:
|
||||
json.dump(artifacts, outfile)
|
||||
|
||||
@@ -555,6 +555,9 @@ def import_with_fx(
|
||||
add_upcast(fx_g)
|
||||
fx_g.recompile()
|
||||
|
||||
if mlir_type == "fx":
|
||||
return fx_g
|
||||
|
||||
if training:
|
||||
change_fx_graph_return_to_tuple(fx_g)
|
||||
inputs = flatten_training_input(inputs)
|
||||
|
||||
@@ -48,6 +48,8 @@ class SharkInference:
|
||||
Refer to {https://mlir.llvm.org/docs/Dialects/}
|
||||
is_benchmark: bool
|
||||
Whether this SharkInference module should be benchmark-enabled.
|
||||
mmap: bool
|
||||
Whether to load/run vmfb using mmap. It's `True` by default.
|
||||
|
||||
Methods
|
||||
-------
|
||||
@@ -70,6 +72,7 @@ class SharkInference:
|
||||
dispatch_benchmark: str = None,
|
||||
dispatch_benchmark_dir: str = "temp_dispatch_benchmarks",
|
||||
device_idx: int = None,
|
||||
mmap: bool = True,
|
||||
):
|
||||
self.mlir_module = mlir_module
|
||||
self.device = shark_args.device if device == "none" else device
|
||||
@@ -88,6 +91,7 @@ class SharkInference:
|
||||
)
|
||||
|
||||
self.shark_runner = None
|
||||
self.mmap = mmap
|
||||
|
||||
def compile(self, extra_args=[]):
|
||||
if self.dispatch_benchmarks is not None:
|
||||
@@ -201,12 +205,14 @@ class SharkInference:
|
||||
compile_vmfb=False,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
(
|
||||
self.shark_runner.iree_compilation_module,
|
||||
self.shark_runner.iree_config,
|
||||
) = load_flatbuffer(
|
||||
params = load_flatbuffer(
|
||||
path,
|
||||
self.device,
|
||||
self.device_idx,
|
||||
mmap=self.mmap,
|
||||
)
|
||||
self.shark_runner.iree_compilation_module = params["vmfb"]
|
||||
self.shark_runner.iree_config = params["config"]
|
||||
self.shark_runner.temp_file_to_unlink = params["temp_file_to_unlink"]
|
||||
del params
|
||||
return
|
||||
|
||||
@@ -85,16 +85,17 @@ class SharkRunner:
|
||||
|
||||
if compile_vmfb == True:
|
||||
# Compile the module to get the .vmfb.
|
||||
(
|
||||
self.iree_compilation_module,
|
||||
self.iree_config,
|
||||
) = get_iree_compiled_module(
|
||||
params = get_iree_compiled_module(
|
||||
self.mlir_module,
|
||||
self.device,
|
||||
self.mlir_dialect,
|
||||
extra_args=self.extra_args,
|
||||
device_idx=self.device_idx,
|
||||
)
|
||||
self.iree_compilation_module = params["vmfb"]
|
||||
self.iree_config = params["config"]
|
||||
self.temp_file_to_unlink = params["temp_file_to_unlink"]
|
||||
del params
|
||||
|
||||
def run(self, function_name, inputs: tuple, send_to_host=False):
|
||||
return get_results(
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
1. Install torchdynamo
|
||||
- `git clone https://github.com/pytorch/torchdynamo.git`
|
||||
- `cd torchdynamo`
|
||||
- `python -m pip install -r requirements.txt`
|
||||
- `python setup.py develop`
|
||||
|
||||
2. Install functorch
|
||||
- `python -m pip install -v "git+https://github.com/pytorch/pytorch.git@$(python -c "import torch.version; print(torch.version.git_version)")#subdirectory=functorch"`
|
||||
|
||||
3. Run examples.
|
||||
- `python shark/examples/shark_dynamo/basic_examples.py`
|
||||
@@ -1,163 +0,0 @@
|
||||
import functools
|
||||
import time
|
||||
from typing import List, Optional
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._functorch.compile_utils import strip_overloads
|
||||
from shark.shark_inference import SharkInference
|
||||
from torch._decomp import get_decompositions
|
||||
|
||||
import torch_mlir
|
||||
|
||||
|
||||
# TODO: Control decompositions.
|
||||
def default_decompositions():
|
||||
return get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def timeit(*, append_time_to: Optional[List] = None):
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
start_time = time.time_ns()
|
||||
result = func(*args, **kwargs)
|
||||
end_time = time.time_ns()
|
||||
|
||||
if append_time_to is not None:
|
||||
append_time_to.append(end_time - start_time)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
return len(node_arg) == 0
|
||||
return False
|
||||
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""
|
||||
Replace tuple with tuple element in functions that return one-element tuples.
|
||||
Returns true if an unwrapping took place, and false otherwise.
|
||||
"""
|
||||
unwrapped_tuple = False
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
unwrapped_tuple = True
|
||||
break
|
||||
|
||||
if unwrapped_tuple:
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return unwrapped_tuple
|
||||
|
||||
|
||||
def make_shark_compiler(use_tracing: bool, device: str, verbose=False):
|
||||
def compiler(
|
||||
fx_graph: torch.fx.GraphModule,
|
||||
example_inputs: List[torch.Tensor],
|
||||
):
|
||||
"""Compile GraphModule using torch-mlir + SHARK."""
|
||||
if verbose:
|
||||
print("Compiling graph...")
|
||||
|
||||
if _returns_nothing(fx_graph):
|
||||
return fx_graph
|
||||
|
||||
was_unwrapped = _unwrap_single_tuple_return(fx_graph)
|
||||
fx_graph = make_fx(
|
||||
fx_graph, decomposition_table=default_decompositions()
|
||||
)(*example_inputs)
|
||||
strip_overloads(fx_graph)
|
||||
|
||||
if verbose:
|
||||
print("torch.fx graph:")
|
||||
print(fx_graph.graph)
|
||||
|
||||
ts_compiler = torch.jit.trace if use_tracing else torch.jit.script
|
||||
ts_graph = ts_compiler(fx_graph, example_inputs)
|
||||
|
||||
if verbose:
|
||||
torch_mlir_module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
example_inputs,
|
||||
output_type=torch_mlir.OutputType.TORCH,
|
||||
)
|
||||
print("\n\ntorch-mlir backend contract graph:")
|
||||
print(torch_mlir_module)
|
||||
|
||||
linalg_module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
example_inputs,
|
||||
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
)
|
||||
import io
|
||||
|
||||
bytecode_stream = io.BytesIO()
|
||||
linalg_module.operation.write_bytecode(bytecode_stream)
|
||||
mlir_module = bytecode_stream.getvalue()
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module, mlir_dialect="linalg", device=device
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
def forward(*inputs):
|
||||
result = shark_module("forward", inputs)
|
||||
result = tuple() if result is None else result
|
||||
return (result,) if was_unwrapped else result
|
||||
|
||||
return forward
|
||||
|
||||
return compiler
|
||||
|
||||
|
||||
def check_results(compiled_results, eager_results):
|
||||
for compiled_result, eager_result in zip(compiled_results, eager_results):
|
||||
if not torch.allclose(
|
||||
compiled_result.to("cpu"), eager_result.to("cpu"), atol=1e-5
|
||||
):
|
||||
print("Compiled result does not match eager result")
|
||||
return
|
||||
print("Compiled result matches eager result!")
|
||||
|
||||
|
||||
def print_time_stats(times):
|
||||
times_tensor = torch.tensor(times)
|
||||
|
||||
def quantile_ms(q):
|
||||
return torch.quantile(times_tensor.to(float), q).item() / 1e6
|
||||
|
||||
print(f"Median: {quantile_ms(0.5)} ms")
|
||||
print(f"10%ile: {quantile_ms(0.1)} ms")
|
||||
print(f"90%ile: {quantile_ms(0.9)} ms")
|
||||
print(f"Total: {torch.sum(times_tensor) / 1e6} ms")
|
||||
print()
|
||||
@@ -12,8 +12,8 @@ from transformers import AutoTokenizer, OPTForCausalLM
|
||||
|
||||
OPT_MODEL = "opt-1.3b"
|
||||
OPT_FS_NAME = "opt-1_3b"
|
||||
MAX_SEQUENCE_LENGTH = 30
|
||||
MAX_NEW_TOKENS = 20
|
||||
MAX_SEQUENCE_LENGTH = 128
|
||||
MAX_NEW_TOKENS = 60
|
||||
|
||||
|
||||
def create_module(model_name, tokenizer, device):
|
||||
@@ -110,13 +110,13 @@ if __name__ == "__main__":
|
||||
"facebook/" + OPT_MODEL, use_fast=False
|
||||
)
|
||||
vmfb_path = (
|
||||
f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_cpu-sync.vmfb"
|
||||
f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_cpu-task.vmfb"
|
||||
)
|
||||
opt_shark_module = SharkInference(mlir_module=None, device="cpu-sync")
|
||||
opt_shark_module = SharkInference(mlir_module=None, device="cpu-task")
|
||||
if os.path.isfile(vmfb_path):
|
||||
opt_shark_module.load_module(vmfb_path)
|
||||
else:
|
||||
vmfb_path = create_module(OPT_MODEL, tokenizer, "cpu-sync")
|
||||
vmfb_path = create_module(OPT_MODEL, tokenizer, "cpu-task")
|
||||
opt_shark_module.load_module(vmfb_path)
|
||||
while True:
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user