Mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-04-20 03:00:34 -04:00)

Compare commits: 20230428.7...20230613.7 (93 commits)
Commits included in this comparison (abbreviated SHA1):

0a4c8fcb3e, 2fec3c8169, 5e7d5930dd, b6dbd20250, 34f1295349, 1980d7b2c3, 2cfacc5051, 436f58ddc4,
6b29bd17c8, 2c3485ca3e, f206ecc635, a187e05ae6, 8c21960486, be62fce676, f23b778a6c, 436edf900d,
ed58c2553f, f2ca58e844, 1dbcc736eb, a83808ddc5, a07fe80530, d0ba3ef8fa, 8400529c2c, 7eaee9c242,
8230eebce5, 6296ea4be9, 4151ec3a8f, a2467e8d43, e677178bcc, 7ef1bea953, ad89bb1413, 218ed78c40,
6046f36ab6, 5915bf7de3, f0a4e59758, 1ddef26af5, ba8eddb12f, 47b346d428, 1b4f4f5f4d, 73cd7e8320,
19c0ae3702, 54e57f7771, 6d64b8e273, a8ea0326f5, 58e9194553, eb360e255d, a6f88d7f72, 8e571d165f,
3cddd01b10, 64c2b2d96b, f5ce121988, 991f144598, 09bea17e59, aefcf80b48, 512235892e, 6602a2f5ba,
20114deea0, 9acf519078, bdf37b5311, 8ee2ac89f8, 60cb48be2e, 86a215b063, d6e3a9a236, a0097a1ead,
a9bae00606, 4731c1a835, 4c07e47e8c, e0cc2871bb, 649f39408b, c142297d73, 9e07360b00, 7b74c86e42,
fa833f8366, fcb059aa38, 517c670f82, 59df14f18b, 6c95ac0f37, 7a4a51ae73, d816cc015e, 54ce3d48ca,
0e4a8ca240, 6ca1298675, bbef7a6464, cdf2d61d53, 6c14847d1f, 68ecdd2a73, 3f4d444d18, e473d0375b,
e38d96850f, fed63dfd4b, eba4d06405, 4cfba153d2, 307c05f38d
.github/workflows/nightly.yml (vendored, 26 changed lines)

@@ -50,27 +50,13 @@ jobs:
shell: powershell
run: |
./setup_venv.ps1
$env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
python process_skipfiles.py
pyinstaller .\apps\stable_diffusion\shark_sd.spec
mv ./dist/shark_sd.exe ./dist/shark_sd_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_${{ env.package_version_ }}.exe
pyinstaller .\apps\stable_diffusion\shark_sd_cli.spec
python process_skipfiles.py
mv ./dist/shark_sd_cli.exe ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_cli_${{ env.package_version_ }}.exe

# GHA windows VM OOMs so disable for now
#- name: Build and validate the SHARK Runtime package
# shell: powershell
# run: |
# $env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
# pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html

#- uses: actions/upload-artifact@v2
# with:
# path: dist/*

mv ./dist/shark_sd.exe ./dist/nodai_shark_sd_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_sd_${{ env.package_version_ }}.exe

- name: Upload Release Assets
id: upload-release-assets
uses: dwenegar/upload-release-assets@v1
@@ -78,7 +64,7 @@ jobs:
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
with:
release_id: ${{ steps.create_release.outputs.id }}
assets_path: ./dist/*
assets_path: ./dist/nodai*
#asset_content_type: application/vnd.microsoft.portable-executable

- name: Publish Release
.github/workflows/test-models.yml (vendored, 2 changed lines)

@@ -137,7 +137,7 @@ jobs:
export DYLD_LIBRARY_PATH=/usr/local/lib/
echo $PATH
pip list | grep -E "torch|iree"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k vulkan --update_tank
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k vulkan

- name: Validate Vulkan Models (a100)
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
@@ -1,25 +1,14 @@
|
||||
import torch
|
||||
import shark
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_inference import SharkInference
|
||||
import torch_mlir
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
StoppingCriteria,
|
||||
StoppingCriteriaList,
|
||||
)
|
||||
import torch_mlir
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
base_models,
|
||||
get_opt_flags,
|
||||
get_vmfb_path_name,
|
||||
)
|
||||
from apps.stable_diffusion.src.models.model_wrappers import replace_shape_str
|
||||
import os
|
||||
from io import BytesIO
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"stabilityai/stablelm-tuned-alpha-7b"
|
||||
from pathlib import Path
|
||||
from apps.language_models.utils import (
|
||||
get_torch_mlir_module_bytecode,
|
||||
get_vmfb_from_path,
|
||||
)
|
||||
|
||||
|
||||
@@ -34,6 +23,97 @@ class StopOnTokens(StoppingCriteria):
|
||||
return False
|
||||
|
||||
|
||||
def shouldStop(tokens):
|
||||
stop_ids = [50278, 50279, 50277, 1, 0]
|
||||
for stop_id in stop_ids:
|
||||
if tokens[0][-1] == stop_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
MAX_SEQUENCE_LENGTH = 256
|
||||
|
||||
|
||||
def user(message, history):
|
||||
# Append the user's message to the conversation history
|
||||
return "", history + [[message, ""]]
|
||||
|
||||
|
||||
def compile_stableLM(
|
||||
model,
|
||||
model_inputs,
|
||||
model_name,
|
||||
model_vmfb_name,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
):
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
# device = "cuda" # "cpu"
|
||||
# TODO: vmfb and mlir name should include precision and device
|
||||
vmfb_path = (
|
||||
Path(model_name + f"_{device}.vmfb")
|
||||
if model_vmfb_name is None
|
||||
else Path(model_vmfb_name)
|
||||
)
|
||||
shark_module = get_vmfb_from_path(
|
||||
vmfb_path, device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
if shark_module is not None:
|
||||
return shark_module
|
||||
|
||||
mlir_path = Path(model_name + ".mlir")
|
||||
print(
|
||||
f"[DEBUG] mlir path {mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
|
||||
)
|
||||
if mlir_path.exists():
|
||||
with open(mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
ts_graph = get_torch_mlir_module_bytecode(model, model_inputs)
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*model_inputs],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
f_ = open(model_name + ".mlir", "wb")
|
||||
f_.write(bytecode)
|
||||
print("Saved mlir")
|
||||
f_.close()
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
path = shark_module.save_module(
|
||||
vmfb_path.parent.absolute(), vmfb_path.stem
|
||||
)
|
||||
print("Saved vmfb at ", str(path))
|
||||
|
||||
return shark_module
|
||||
|
||||
|
||||
class StableLMModel(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
combine_input_dict = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
output = self.model(**combine_input_dict)
|
||||
return output.logits
|
||||
|
||||
|
||||
# Initialize a StopOnTokens object
|
||||
system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
|
||||
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
|
||||
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
|
||||
@@ -41,167 +121,90 @@ system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
|
||||
- StableLM will refuse to participate in anything that could harm a human.
|
||||
"""
|
||||
|
||||
prompt = f"{system_prompt}<|USER|>What's your mood today?<|ASSISTANT|>"
|
||||
|
||||
inputs = tokenizer(prompt, return_tensors="pt")
|
||||
def get_tokenizer():
|
||||
model_path = "stabilityai/stablelm-tuned-alpha-3b"
|
||||
tok = AutoTokenizer.from_pretrained(model_path)
|
||||
tok.add_special_tokens({"pad_token": "<PAD>"})
|
||||
print("Sucessfully loaded the tokenizer to the memory")
|
||||
return tok
|
||||
|
||||
|
||||
class SLM(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
"stabilityai/stablelm-tuned-alpha-7b"
|
||||
# sharkStableLM = compile_stableLM
|
||||
# (
|
||||
# None,
|
||||
# tuple([input_ids, attention_mask]),
|
||||
# "stableLM_linalg_f32_seqLen256",
|
||||
# "/home/shark/vivek/stableLM_shark_f32_seqLen256"
|
||||
# )
|
||||
def generate(
|
||||
new_text,
|
||||
max_new_tokens,
|
||||
sharkStableLM,
|
||||
tokenizer=None,
|
||||
):
|
||||
if tokenizer is None:
|
||||
tokenizer = get_tokenizer()
|
||||
# Construct the input message string for the model by
|
||||
# concatenating the current system message and conversation history
|
||||
# Tokenize the messages string
|
||||
# sharkStableLM = compile_stableLM
|
||||
# (
|
||||
# None,
|
||||
# tuple([input_ids, attention_mask]),
|
||||
# "stableLM_linalg_f32_seqLen256",
|
||||
# "/home/shark/vivek/stableLM_shark_f32_seqLen256"
|
||||
# )
|
||||
words_list = []
|
||||
for i in range(max_new_tokens):
|
||||
# numWords = len(new_text.split())
|
||||
# if(numWords>220):
|
||||
# break
|
||||
params = {
|
||||
"new_text": new_text,
|
||||
}
|
||||
generated_token_op = generate_new_token(
|
||||
sharkStableLM, tokenizer, params
|
||||
)
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
return self.model(input_ids, attention_mask)[0]
|
||||
detok = generated_token_op["detok"]
|
||||
stop_generation = generated_token_op["stop_generation"]
|
||||
if stop_generation:
|
||||
break
|
||||
print(detok, end="", flush=True)
|
||||
words_list.append(detok)
|
||||
if detok == "":
|
||||
break
|
||||
new_text = new_text + detok
|
||||
return words_list
|
||||
|
||||
|
||||
slm_model = SLM()
|
||||
|
||||
res_pytorch = slm_model(inputs["input_ids"], inputs["attention_mask"])
|
||||
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from typing import List
|
||||
|
||||
fx_g = make_fx(
|
||||
slm_model,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
)(inputs["input_ids"], inputs["attention_mask"])
|
||||
|
||||
|
||||
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
|
||||
removed_indexes = []
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, (list, tuple)):
|
||||
node_arg = list(node_arg)
|
||||
node_args_len = len(node_arg)
|
||||
for i in range(node_args_len):
|
||||
curr_index = node_args_len - (i + 1)
|
||||
if node_arg[curr_index] is None:
|
||||
removed_indexes.append(curr_index)
|
||||
node_arg.pop(curr_index)
|
||||
node.args = (tuple(node_arg),)
|
||||
break
|
||||
|
||||
if len(removed_indexes) > 0:
|
||||
fx_g.graph.lint()
|
||||
fx_g.graph.eliminate_dead_code()
|
||||
fx_g.recompile()
|
||||
removed_indexes.sort()
|
||||
return removed_indexes
|
||||
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""
|
||||
Replace tuple with tuple element in functions that return one-element tuples.
|
||||
Returns true if an unwrapping took place, and false otherwise.
|
||||
"""
|
||||
unwrapped_tuple = False
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
unwrapped_tuple = True
|
||||
break
|
||||
|
||||
if unwrapped_tuple:
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return unwrapped_tuple
|
||||
|
||||
|
||||
def transform_fx(fx_g):
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
if node.target in [
|
||||
torch.ops.aten.empty,
|
||||
]:
|
||||
# aten.empty should be filled with zeros.
|
||||
if node.target in [torch.ops.aten.empty]:
|
||||
with fx_g.graph.inserting_after(node):
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten.zero_,
|
||||
args=(node,),
|
||||
)
|
||||
node.append(new_node)
|
||||
node.replace_all_uses_with(new_node)
|
||||
new_node.args = (node,)
|
||||
|
||||
fx_g.graph.lint()
|
||||
|
||||
|
||||
transform_fx(fx_g)
|
||||
fx_g.recompile()
|
||||
removed_none_indexes = _remove_nones(fx_g)
|
||||
was_unwrapped = _unwrap_single_tuple_return(fx_g)
|
||||
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
|
||||
|
||||
def strip_overloads(gm):
|
||||
"""
|
||||
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
|
||||
Args:
|
||||
gm(fx.GraphModule): The input Fx graph module to be modified
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if isinstance(node.target, torch._ops.OpOverload):
|
||||
node.target = node.target.overloadpacket
|
||||
gm.recompile()
|
||||
|
||||
|
||||
strip_overloads(fx_g)
|
||||
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
[inputs["input_ids"], inputs["attention_mask"]],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device="cuda", mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
result_shark = shark_module(
|
||||
"forward", [inputs["input_ids"], inputs["attention_mask"]]
|
||||
)
|
||||
|
||||
print("Result PyTorch")
|
||||
print(res_pytorch)
|
||||
print("Result SHARK")
|
||||
print(result_shark)
|
||||
def generate_new_token(shark_model, tokenizer, params):
|
||||
new_text = params["new_text"]
|
||||
model_inputs = tokenizer(
|
||||
[new_text],
|
||||
padding="max_length",
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
sum_attentionmask = torch.sum(model_inputs.attention_mask)
|
||||
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
|
||||
output = shark_model(
|
||||
"forward", [model_inputs.input_ids, model_inputs.attention_mask]
|
||||
)
|
||||
output = torch.from_numpy(output)
|
||||
next_toks = torch.topk(output, 1)
|
||||
stop_generation = False
|
||||
if shouldStop(next_toks.indices):
|
||||
stop_generation = True
|
||||
new_token = next_toks.indices[0][int(sum_attentionmask) - 1]
|
||||
detok = tokenizer.decode(
|
||||
new_token,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
ret_dict = {
|
||||
"new_token": new_token,
|
||||
"detok": detok,
|
||||
"stop_generation": stop_generation,
|
||||
}
|
||||
return ret_dict
|
||||
|
||||
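The FX clean-up helpers above (transform_fx, _remove_nones, _unwrap_single_tuple_return, strip_overloads) are easiest to see on a toy graph. A minimal sketch that replays the same sequence the script uses before handing the graph to torch_mlir; it assumes the helpers are in scope, and the Toy module and its one-element tuple return are made up for illustration:

import torch
from torch.fx.experimental.proxy_tensor import make_fx


class Toy(torch.nn.Module):
    def forward(self, x):
        # one-element tuple return, like many traced HF graphs
        return (x + 1,)


fx_toy = make_fx(Toy())(torch.zeros(2))
transform_fx(fx_toy)                 # fills any aten.empty results with zeros
fx_toy.recompile()
_remove_nones(fx_toy)                # drops None outputs, if any
_unwrap_single_tuple_return(fx_toy)  # unwraps a one-element tuple return, if present
fx_toy.graph.set_codegen(torch.fx.graph.CodeGen())
fx_toy.recompile()
strip_overloads(fx_toy)              # aten.add.Tensor -> aten.add, so TorchScript accepts it
ts_toy = torch.jit.script(fx_toy)    # ready for torch_mlir.compile(..., LINALG_ON_TENSORS)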
apps/language_models/src/model_wrappers/stablelm_model.py (new file, 15 lines)

@@ -0,0 +1,15 @@
import torch


class StableLMModel(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        combine_input_dict = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        output = self.model(**combine_input_dict)
        return output.logits
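A minimal sketch of exercising the wrapper above eagerly, before any compilation; the checkpoint name mirrors the stablelm-tuned-alpha models used elsewhere in this change, and the short prompt is just an example:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

hf_path = "stabilityai/stablelm-tuned-alpha-3b"
tok = AutoTokenizer.from_pretrained(hf_path)
lm = AutoModelForCausalLM.from_pretrained(hf_path, torch_dtype=torch.float32)

wrapped = StableLMModel(lm)  # returns logits only, which keeps the traced graph simple
enc = tok("Hello there", return_tensors="pt")
logits = wrapped(enc.input_ids, enc.attention_mask)
print(logits.shape)  # (1, sequence_length, vocab_size)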
apps/language_models/src/model_wrappers/vicuna_model.py (new file, 261 lines)

@@ -0,0 +1,261 @@
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
|
||||
class FirstVicuna(torch.nn.Module):
|
||||
def __init__(self, model_path):
|
||||
super().__init__()
|
||||
kwargs = {"torch_dtype": torch.float32}
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
|
||||
def forward(self, input_ids):
|
||||
op = self.model(input_ids=input_ids, use_cache=True)
|
||||
return_vals = []
|
||||
return_vals.append(op.logits)
|
||||
temp_past_key_values = op.past_key_values
|
||||
for item in temp_past_key_values:
|
||||
return_vals.append(item[0])
|
||||
return_vals.append(item[1])
|
||||
return tuple(return_vals)
|
||||
|
||||
|
||||
class SecondVicuna(torch.nn.Module):
|
||||
def __init__(self, model_path):
|
||||
super().__init__()
|
||||
kwargs = {"torch_dtype": torch.float32}
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
i0,
|
||||
i1,
|
||||
i2,
|
||||
i3,
|
||||
i4,
|
||||
i5,
|
||||
i6,
|
||||
i7,
|
||||
i8,
|
||||
i9,
|
||||
i10,
|
||||
i11,
|
||||
i12,
|
||||
i13,
|
||||
i14,
|
||||
i15,
|
||||
i16,
|
||||
i17,
|
||||
i18,
|
||||
i19,
|
||||
i20,
|
||||
i21,
|
||||
i22,
|
||||
i23,
|
||||
i24,
|
||||
i25,
|
||||
i26,
|
||||
i27,
|
||||
i28,
|
||||
i29,
|
||||
i30,
|
||||
i31,
|
||||
i32,
|
||||
i33,
|
||||
i34,
|
||||
i35,
|
||||
i36,
|
||||
i37,
|
||||
i38,
|
||||
i39,
|
||||
i40,
|
||||
i41,
|
||||
i42,
|
||||
i43,
|
||||
i44,
|
||||
i45,
|
||||
i46,
|
||||
i47,
|
||||
i48,
|
||||
i49,
|
||||
i50,
|
||||
i51,
|
||||
i52,
|
||||
i53,
|
||||
i54,
|
||||
i55,
|
||||
i56,
|
||||
i57,
|
||||
i58,
|
||||
i59,
|
||||
i60,
|
||||
i61,
|
||||
i62,
|
||||
i63,
|
||||
i64,
|
||||
):
|
||||
# input_ids = input_tuple[0]
|
||||
# input_tuple = torch.unbind(pkv, dim=0)
|
||||
token = i0
|
||||
past_key_values = (
|
||||
(i1, i2),
|
||||
(
|
||||
i3,
|
||||
i4,
|
||||
),
|
||||
(
|
||||
i5,
|
||||
i6,
|
||||
),
|
||||
(
|
||||
i7,
|
||||
i8,
|
||||
),
|
||||
(
|
||||
i9,
|
||||
i10,
|
||||
),
|
||||
(
|
||||
i11,
|
||||
i12,
|
||||
),
|
||||
(
|
||||
i13,
|
||||
i14,
|
||||
),
|
||||
(
|
||||
i15,
|
||||
i16,
|
||||
),
|
||||
(
|
||||
i17,
|
||||
i18,
|
||||
),
|
||||
(
|
||||
i19,
|
||||
i20,
|
||||
),
|
||||
(
|
||||
i21,
|
||||
i22,
|
||||
),
|
||||
(
|
||||
i23,
|
||||
i24,
|
||||
),
|
||||
(
|
||||
i25,
|
||||
i26,
|
||||
),
|
||||
(
|
||||
i27,
|
||||
i28,
|
||||
),
|
||||
(
|
||||
i29,
|
||||
i30,
|
||||
),
|
||||
(
|
||||
i31,
|
||||
i32,
|
||||
),
|
||||
(
|
||||
i33,
|
||||
i34,
|
||||
),
|
||||
(
|
||||
i35,
|
||||
i36,
|
||||
),
|
||||
(
|
||||
i37,
|
||||
i38,
|
||||
),
|
||||
(
|
||||
i39,
|
||||
i40,
|
||||
),
|
||||
(
|
||||
i41,
|
||||
i42,
|
||||
),
|
||||
(
|
||||
i43,
|
||||
i44,
|
||||
),
|
||||
(
|
||||
i45,
|
||||
i46,
|
||||
),
|
||||
(
|
||||
i47,
|
||||
i48,
|
||||
),
|
||||
(
|
||||
i49,
|
||||
i50,
|
||||
),
|
||||
(
|
||||
i51,
|
||||
i52,
|
||||
),
|
||||
(
|
||||
i53,
|
||||
i54,
|
||||
),
|
||||
(
|
||||
i55,
|
||||
i56,
|
||||
),
|
||||
(
|
||||
i57,
|
||||
i58,
|
||||
),
|
||||
(
|
||||
i59,
|
||||
i60,
|
||||
),
|
||||
(
|
||||
i61,
|
||||
i62,
|
||||
),
|
||||
(
|
||||
i63,
|
||||
i64,
|
||||
),
|
||||
)
|
||||
op = self.model(
|
||||
input_ids=token, use_cache=True, past_key_values=past_key_values
|
||||
)
|
||||
return_vals = []
|
||||
return_vals.append(op.logits)
|
||||
temp_past_key_values = op.past_key_values
|
||||
for item in temp_past_key_values:
|
||||
return_vals.append(item[0])
|
||||
return_vals.append(item[1])
|
||||
return tuple(return_vals)
|
||||
|
||||
|
||||
class CombinedModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
first_vicuna_model_path="TheBloke/vicuna-7B-1.1-HF",
|
||||
second_vicuna_model_path="TheBloke/vicuna-7B-1.1-HF",
|
||||
):
|
||||
super().__init__()
|
||||
self.first_vicuna = FirstVicuna(first_vicuna_model_path)
|
||||
self.second_vicuna = SecondVicuna(second_vicuna_model_path)
|
||||
|
||||
def forward(self, input_ids):
|
||||
first_output = self.first_vicuna(input_ids=input_ids)
|
||||
logits = first_output[0]
|
||||
pkv = first_output[1:]
|
||||
|
||||
token = torch.argmax(torch.tensor(logits)[:, -1, :], dim=1)
|
||||
token = token.to(torch.int64).reshape([1, 1])
|
||||
secondVicunaInput = (token,) + tuple(pkv)
|
||||
second_output = self.second_vicuna(*secondVicunaInput)
|
||||
return second_output
|
||||
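A sketch of how the two wrappers above hand off state in eager PyTorch, before any compilation. It is illustrative only: loading two full copies of the 7B checkpoint is memory-heavy (the pipelines below compile and load them separately for exactly that reason), and the toy token ids are arbitrary.

import torch

model_path = "TheBloke/vicuna-7B-1.1-HF"  # default path used above
first = FirstVicuna(model_path)
second = SecondVicuna(model_path)

input_ids = torch.tensor([[1, 319, 13563]])  # toy prompt ids
outs = first(input_ids)                      # (logits, k0, v0, ..., k31, v31): 65 tensors
logits, pkv = outs[0], outs[1:]

token = torch.argmax(logits[:, -1, :], dim=1).to(torch.int64).reshape([1, 1])
outs = second(token, *pkv)                   # next-step logits plus the updated 64-tensor cache
next_token = torch.argmax(outs[0][:, -1, :], dim=1)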
apps/language_models/src/model_wrappers/vicuna_sharded_model.py (new file, 178 lines)

@@ -0,0 +1,178 @@
|
||||
import torch
|
||||
|
||||
|
||||
class FirstVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, hidden_states, attention_mask, position_ids):
|
||||
outputs = self.model(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
use_cache=True,
|
||||
)
|
||||
next_hidden_states = outputs[0]
|
||||
past_key_value_out0, past_key_value_out1 = (
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
|
||||
return (
|
||||
next_hidden_states,
|
||||
past_key_value_out0,
|
||||
past_key_value_out1,
|
||||
)
|
||||
|
||||
|
||||
class SecondVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value0,
|
||||
past_key_value1,
|
||||
):
|
||||
outputs = self.model(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=(
|
||||
past_key_value0,
|
||||
past_key_value1,
|
||||
),
|
||||
use_cache=True,
|
||||
)
|
||||
next_hidden_states = outputs[0]
|
||||
past_key_value_out0, past_key_value_out1 = (
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
|
||||
return (
|
||||
next_hidden_states,
|
||||
past_key_value_out0,
|
||||
past_key_value_out1,
|
||||
)
|
||||
|
||||
|
||||
class CompiledFirstVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, shark_module):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
):
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
output = self.model(
|
||||
"forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
),
|
||||
)
|
||||
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
|
||||
return (
|
||||
output0,
|
||||
(
|
||||
output1,
|
||||
output2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class CompiledSecondVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, shark_module):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
):
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
pkv0 = past_key_value[0].detach()
|
||||
pkv1 = past_key_value[1].detach()
|
||||
output = self.model(
|
||||
"forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
pkv0,
|
||||
pkv1,
|
||||
),
|
||||
)
|
||||
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
|
||||
return (
|
||||
output0,
|
||||
(
|
||||
output1,
|
||||
output2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ShardedVicunaModel(torch.nn.Module):
|
||||
def __init__(self, model, layers0, layers1):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
assert len(layers0) == len(model.model.layers)
|
||||
# self.model.model.layers = torch.nn.modules.container.ModuleList(layers0)
|
||||
self.model.model.config.use_cache = True
|
||||
self.model.model.config.output_attentions = False
|
||||
self.layers0 = layers0
|
||||
self.layers1 = layers1
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
is_first=True,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
):
|
||||
if is_first:
|
||||
self.model.model.layers = torch.nn.modules.container.ModuleList(
|
||||
self.layers0
|
||||
)
|
||||
return self.model.forward(input_ids, attention_mask=attention_mask)
|
||||
else:
|
||||
self.model.model.layers = torch.nn.modules.container.ModuleList(
|
||||
self.layers1
|
||||
)
|
||||
return self.model.forward(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
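The Compiled*VicunaLayer wrappers above assume a SharkInference-style callable that is invoked as module("forward", tuple_of_inputs) and returns numpy arrays. A minimal sketch with a stand-in module; FakeSharkModule and every shape constant here are assumptions, with 1x32x?x128 mirroring the 7B key/value cache shapes used in the pipelines below:

import numpy as np
import torch


class FakeSharkModule:
    def __call__(self, function_name, inputs):
        hidden_states = inputs[0]
        seq_len = hidden_states.shape[1]
        # pretend the compiled layer echoes hidden_states and emits one k/v cache pair
        return (
            np.asarray(hidden_states),
            np.zeros((1, 32, seq_len, 128), dtype=np.float32),
            np.zeros((1, 32, seq_len, 128), dtype=np.float32),
        )


layer = CompiledFirstVicunaLayer(FakeSharkModule())
hidden = torch.zeros(1, 19, 4096)
mask = torch.zeros(1, 1, 19, 19)
pos = torch.arange(19).reshape(1, 19)
out_hidden, (k, v) = layer(hidden, mask, pos)
print(out_hidden.shape, k.shape, v.shape)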
apps/language_models/src/pipelines/SharkLLMBase.py (new file, 41 lines)

@@ -0,0 +1,41 @@
from abc import ABC, abstractmethod


class SharkLLMBase(ABC):
    def __init__(
        self, model_name, hf_model_path=None, max_num_tokens=512
    ) -> None:
        self.model_name = model_name
        self.hf_model_path = hf_model_path
        self.max_num_tokens = max_num_tokens
        self.shark_model = None
        self.device = "cpu"
        self.precision = "fp32"

    @classmethod
    @abstractmethod
    def compile(self):
        pass

    @classmethod
    @abstractmethod
    def generate(self, prompt):
        pass

    @classmethod
    @abstractmethod
    def generate_new_token(self, params):
        pass

    @classmethod
    @abstractmethod
    def get_tokenizer(self):
        pass

    @classmethod
    @abstractmethod
    def get_src_model(self):
        pass

    def load_init_from_config(self):
        pass
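A hypothetical subclass skeleton to make the contract concrete; EchoLLM and its trivial bodies are illustration only, while the method names and constructor signature come from the base class above:

class EchoLLM(SharkLLMBase):
    def __init__(self, model_name, hf_model_path=None, max_num_tokens=512):
        super().__init__(model_name, hf_model_path, max_num_tokens)

    def compile(self):
        # a real pipeline returns a compiled SharkInference module here
        return None

    def get_tokenizer(self):
        return None

    def get_src_model(self):
        return None

    def generate_new_token(self, params):
        # echo the last character of the running text and stop immediately
        return {"detok": params["new_text"][-1:], "stop_generation": True}

    def generate(self, prompt):
        return [self.generate_new_token({"new_text": prompt})["detok"]]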
apps/language_models/src/pipelines/stablelm_pipeline.py (new file, 185 lines)

@@ -0,0 +1,185 @@
|
||||
import torch
|
||||
import torch_mlir
|
||||
from transformers import AutoTokenizer, StoppingCriteria, AutoModelForCausalLM
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from apps.language_models.utils import (
|
||||
get_torch_mlir_module_bytecode,
|
||||
get_vmfb_from_path,
|
||||
)
|
||||
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
|
||||
from apps.language_models.src.model_wrappers.stablelm_model import (
|
||||
StableLMModel,
|
||||
)
|
||||
|
||||
|
||||
class StopOnTokens(StoppingCriteria):
|
||||
def __call__(
|
||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
||||
) -> bool:
|
||||
stop_ids = [50278, 50279, 50277, 1, 0]
|
||||
for stop_id in stop_ids:
|
||||
if input_ids[0][-1] == stop_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class SharkStableLM(SharkLLMBase):
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
hf_model_path="stabilityai/stablelm-tuned-alpha-3b",
|
||||
max_num_tokens=512,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
self.max_sequence_len = 256
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.shark_model = self.compile()
|
||||
|
||||
def shouldStop(self, tokens):
|
||||
stop_ids = [50278, 50279, 50277, 1, 0]
|
||||
for stop_id in stop_ids:
|
||||
if tokens[0][-1] == stop_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_src_model(self):
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, torch_dtype=torch.float32
|
||||
)
|
||||
return model
|
||||
|
||||
def get_model_inputs(self):
|
||||
input_ids = torch.randint(3, (1, self.max_sequence_len))
|
||||
attention_mask = torch.randint(3, (1, self.max_sequence_len))
|
||||
return input_ids, attention_mask
|
||||
|
||||
def compile(self):
|
||||
tmp_model_name = (
|
||||
f"stableLM_linalg_{self.precision}_seqLen{self.max_sequence_len}"
|
||||
)
|
||||
|
||||
# device = "cuda" # "cpu"
|
||||
# TODO: vmfb and mlir name should include precision and device
|
||||
model_vmfb_name = None
|
||||
vmfb_path = (
|
||||
Path(tmp_model_name + f"_{self.device}.vmfb")
|
||||
if model_vmfb_name is None
|
||||
else Path(model_vmfb_name)
|
||||
)
|
||||
shark_module = get_vmfb_from_path(
|
||||
vmfb_path, self.device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
if shark_module is not None:
|
||||
return shark_module
|
||||
|
||||
mlir_path = Path(tmp_model_name + ".mlir")
|
||||
print(
|
||||
f"[DEBUG] mlir path {mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
|
||||
)
|
||||
if mlir_path.exists():
|
||||
with open(mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
model = StableLMModel(self.get_src_model())
|
||||
model_inputs = self.get_model_inputs()
|
||||
ts_graph = get_torch_mlir_module_bytecode(model, model_inputs)
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*model_inputs],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
f_ = open(tmp_model_name + ".mlir", "wb")
|
||||
f_.write(bytecode)
|
||||
print("Saved mlir")
|
||||
f_.close()
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
path = shark_module.save_module(
|
||||
vmfb_path.parent.absolute(), vmfb_path.stem
|
||||
)
|
||||
print("Saved vmfb at ", str(path))
|
||||
|
||||
return shark_module
|
||||
|
||||
def get_tokenizer(self):
|
||||
tok = AutoTokenizer.from_pretrained(self.hf_model_path)
|
||||
tok.add_special_tokens({"pad_token": "<PAD>"})
|
||||
# print("[DEBUG] Sucessfully loaded the tokenizer to the memory")
|
||||
return tok
|
||||
|
||||
def generate(self, prompt):
|
||||
words_list = []
|
||||
for i in range(self.max_num_tokens):
|
||||
params = {
|
||||
"new_text": prompt,
|
||||
}
|
||||
|
||||
generated_token_op = self.generate_new_token(params)
|
||||
|
||||
detok = generated_token_op["detok"]
|
||||
stop_generation = generated_token_op["stop_generation"]
|
||||
|
||||
if stop_generation:
|
||||
break
|
||||
|
||||
print(detok, end="", flush=True) # this is for CLI and DEBUG
|
||||
words_list.append(detok)
|
||||
if detok == "":
|
||||
break
|
||||
prompt = prompt + detok
|
||||
return words_list
|
||||
|
||||
def generate_new_token(self, params):
|
||||
new_text = params["new_text"]
|
||||
model_inputs = self.tokenizer(
|
||||
[new_text],
|
||||
padding="max_length",
|
||||
max_length=self.max_sequence_len,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
sum_attentionmask = torch.sum(model_inputs.attention_mask)
|
||||
output = self.shark_model(
|
||||
"forward", [model_inputs.input_ids, model_inputs.attention_mask]
|
||||
)
|
||||
output = torch.from_numpy(output)
|
||||
next_toks = torch.topk(output, 1)
|
||||
stop_generation = False
|
||||
if self.shouldStop(next_toks.indices):
|
||||
stop_generation = True
|
||||
new_token = next_toks.indices[0][int(sum_attentionmask) - 1]
|
||||
detok = self.tokenizer.decode(
|
||||
new_token,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
ret_dict = {
|
||||
"new_token": new_token,
|
||||
"detok": detok,
|
||||
"stop_generation": stop_generation,
|
||||
}
|
||||
return ret_dict
|
||||
|
||||
|
||||
# Initialize a StopOnTokens object
|
||||
system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
|
||||
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
|
||||
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
|
||||
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
|
||||
- StableLM will refuse to participate in anything that could harm a human.
|
||||
"""
|
||||
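A hedged usage sketch of the pipeline above; the device and user question are example values, and the <|SYSTEM|>/<|USER|>/<|ASSISTANT|> framing follows the system_prompt defined at the end of the file:

from apps.language_models.src.pipelines.stablelm_pipeline import (
    SharkStableLM,
    system_prompt,
)

shark_stablelm = SharkStableLM("stablelm", device="cpu")  # compiles, or loads a cached vmfb
prompt = f"{system_prompt}<|USER|>What's your mood today?<|ASSISTANT|>"
words = shark_stablelm.generate(prompt)
print("".join(words))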
apps/language_models/src/pipelines/vicuna_pipeline.py (new file, 646 lines)

@@ -0,0 +1,646 @@
|
||||
from apps.language_models.src.model_wrappers.vicuna_model import (
|
||||
FirstVicuna,
|
||||
SecondVicuna,
|
||||
)
|
||||
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
|
||||
from apps.language_models.utils import (
|
||||
get_torch_mlir_module_bytecode,
|
||||
get_vmfb_from_path,
|
||||
)
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_inference import SharkInference
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
import re
|
||||
import torch
|
||||
import torch_mlir
|
||||
import os
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="vicuna runner",
|
||||
description="runs a vicuna model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--precision", "-p", default="fp32", help="fp32, fp16, int8, int4"
|
||||
)
|
||||
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
|
||||
parser.add_argument(
|
||||
"--first_vicuna_vmfb_path", default=None, help="path to first vicuna vmfb"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--second_vicuna_vmfb_path",
|
||||
default=None,
|
||||
help="path to second vicuna vmfb",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--first_vicuna_mlir_path",
|
||||
default=None,
|
||||
help="path to first vicuna mlir file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--second_vicuna_mlir_path",
|
||||
default=None,
|
||||
help="path to second vicuna mlir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--load_mlir_from_shark_tank",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="download precompile mlir from shark tank",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cli",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Run model in cli mode",
|
||||
)
|
||||
|
||||
|
||||
class Vicuna(SharkLLMBase):
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
|
||||
max_num_tokens=512,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
first_vicuna_mlir_path=Path("first_vicuna.mlir"),
|
||||
second_vicuna_mlir_path=Path("second_vicuna.mlir"),
|
||||
first_vicuna_vmfb_path=Path("first_vicuna.vmfb"),
|
||||
second_vicuna_vmfb_path=Path("second_vicuna.vmfb"),
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
self.max_sequence_length = 256
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.first_vicuna_vmfb_path = first_vicuna_vmfb_path
|
||||
self.second_vicuna_vmfb_path = second_vicuna_vmfb_path
|
||||
self.first_vicuna_mlir_path = first_vicuna_mlir_path
|
||||
self.second_vicuna_mlir_path = second_vicuna_mlir_path
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.shark_model = self.compile()
|
||||
|
||||
def get_tokenizer(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hf_model_path, use_fast=False
|
||||
)
|
||||
return tokenizer
|
||||
|
||||
def get_src_model(self):
|
||||
kwargs = {"torch_dtype": torch.float}
|
||||
vicuna_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, **kwargs
|
||||
)
|
||||
return vicuna_model
|
||||
|
||||
def compile_first_vicuna(self):
|
||||
vmfb = get_vmfb_from_path(
|
||||
self.first_vicuna_vmfb_path, self.device, "tm_tensor"
|
||||
)
|
||||
if vmfb is not None:
|
||||
return vmfb
|
||||
|
||||
# Compilation path needs some more work before it is functional
|
||||
|
||||
print(
|
||||
f"[DEBUG] vmfb not found at {self.first_vicuna_vmfb_path.absolute()}. Trying to work with"
|
||||
f"[DEBUG] mlir path { self.first_vicuna_mlir_path} {'exists' if self.first_vicuna_mlir_path.exists() else 'does not exist'}"
|
||||
)
|
||||
if self.first_vicuna_mlir_path.exists():
|
||||
with open(self.first_vicuna_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
mlir_generated = False
|
||||
if args.load_mlir_from_shark_tank:
|
||||
if self.precision == "fp32":
|
||||
# download MLIR from shark_tank for fp32
|
||||
download_public_file(
|
||||
"gs://shark_tank/vicuna/unsharded/mlir/second_vicuna.mlir",
|
||||
self.first_vicuna_mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if self.first_vicuna_mlir_path.exists():
|
||||
with open(self.first_vicuna_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
mlir_generated = True
|
||||
else:
|
||||
raise ValueError(
|
||||
f"MLIR not found at {self.first_vicuna_mlir_path.absolute()}"
|
||||
" after downloading! Please check path and try again"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"Only fp32 mlir added to tank, generating mlir on device."
|
||||
)
|
||||
|
||||
if not mlir_generated:
|
||||
compilation_prompt = "".join(["0" for _ in range(17)])
|
||||
compilation_input_ids = self.tokenizer(
|
||||
compilation_prompt
|
||||
).input_ids
|
||||
compilation_input_ids = torch.tensor(
|
||||
compilation_input_ids
|
||||
).reshape([1, 19])
|
||||
firstVicunaCompileInput = (compilation_input_ids,)
|
||||
model = FirstVicuna(self.hf_model_path)
|
||||
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
ts_graph = get_torch_mlir_module_bytecode(
|
||||
model, firstVicunaCompileInput
|
||||
)
|
||||
del model
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
|
||||
firstVicunaCompileInput = list(firstVicunaCompileInput)
|
||||
firstVicunaCompileInput[0] = torch_mlir.TensorPlaceholder.like(
|
||||
firstVicunaCompileInput[0], dynamic_axes=[1]
|
||||
)
|
||||
firstVicunaCompileInput = tuple(firstVicunaCompileInput)
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*firstVicunaCompileInput],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
del ts_graph
|
||||
|
||||
def remove_constant_dim(line):
|
||||
if "19x" in line:
|
||||
line = re.sub("19x", "?x", line)
|
||||
line = re.sub(
|
||||
"tensor.empty\(\)", "tensor.empty(%dim)", line
|
||||
)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim\)",
|
||||
"tensor.empty(%dim, %dim)",
|
||||
line,
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub("c19", "dim", line)
|
||||
if " 19," in line:
|
||||
line = re.sub(" 19,", " %dim,", line)
|
||||
return line
|
||||
|
||||
module = str(module)
|
||||
new_lines = []
|
||||
|
||||
print(f"[DEBUG] rewriting torch_mlir file")
|
||||
for line in module.splitlines():
|
||||
line = remove_constant_dim(line)
|
||||
if "%0 = tensor.empty(%dim) : tensor<?xi64>" in line:
|
||||
new_lines.append(
|
||||
"%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>"
|
||||
)
|
||||
if (
|
||||
"%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>"
|
||||
in line
|
||||
):
|
||||
continue
|
||||
|
||||
new_lines.append(line)
|
||||
|
||||
module = "\n".join(new_lines)
|
||||
|
||||
print(f"[DEBUG] converting to bytecode")
|
||||
del new_lines
|
||||
module = module.encode("UTF-8")
|
||||
module = BytesIO(module)
|
||||
bytecode = module.read()
|
||||
del module
|
||||
|
||||
print(f"[DEBUG] writing mlir to file")
|
||||
f_ = open(f"{self.model_name}.mlir", "wb")
|
||||
f_.write(bytecode)
|
||||
f_.close()
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
self.first_vicuna_vmfb_path.parent.absolute(),
|
||||
self.first_vicuna_vmfb_path.stem,
|
||||
extra_args=[
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
],
|
||||
)
|
||||
print("Saved first vic vmfb at vmfb at ", str(path))
|
||||
shark_module.load_module(path)
|
||||
|
||||
return shark_module
|
||||
|
||||
def compile_second_vicuna(self):
|
||||
vmfb = get_vmfb_from_path(
|
||||
self.second_vicuna_vmfb_path, self.device, "tm_tensor"
|
||||
)
|
||||
if vmfb is not None:
|
||||
return vmfb
|
||||
|
||||
# Compilation path needs some more work before it is functional
|
||||
print(
|
||||
f"[DEBUG] mlir path {self.second_vicuna_mlir_path} {'exists' if self.second_vicuna_mlir_path.exists() else 'does not exist'}"
|
||||
)
|
||||
if self.second_vicuna_mlir_path.exists():
|
||||
with open(self.second_vicuna_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
mlir_generated = False
|
||||
if args.load_mlir_from_shark_tank:
|
||||
if self.precision == "fp32":
|
||||
# download MLIR from shark_tank for fp32
|
||||
download_public_file(
|
||||
"gs://shark_tank/vicuna/unsharded/mlir/second_vicuna.mlir",
|
||||
self.second_vicuna_mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if self.second_vicuna_mlir_path.exists():
|
||||
with open(self.second_vicuna_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
mlir_generated = True
|
||||
else:
|
||||
raise ValueError(
|
||||
f"MLIR not found at {self.second_vicuna_mlir_path.absolute()}"
|
||||
" after downloading! Please check path and try again"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"Only fp32 mlir added to tank, generating mlir on device."
|
||||
)
|
||||
|
||||
if not mlir_generated:
|
||||
compilation_input_ids = torch.zeros([1, 1], dtype=torch.int64)
|
||||
pkv = tuple(
|
||||
(torch.zeros([1, 32, 19, 128], dtype=torch.float32))
|
||||
for _ in range(64)
|
||||
)
|
||||
secondVicunaCompileInput = (compilation_input_ids,) + pkv
|
||||
model = SecondVicuna(self.hf_model_path)
|
||||
ts_graph = get_torch_mlir_module_bytecode(
|
||||
model, secondVicunaCompileInput
|
||||
)
|
||||
secondVicunaCompileInput = list(secondVicunaCompileInput)
|
||||
for i in range(len(secondVicunaCompileInput)):
|
||||
if i != 0:
|
||||
secondVicunaCompileInput[
|
||||
i
|
||||
] = torch_mlir.TensorPlaceholder.like(
|
||||
secondVicunaCompileInput[i], dynamic_axes=[2]
|
||||
)
|
||||
secondVicunaCompileInput = tuple(secondVicunaCompileInput)
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*secondVicunaCompileInput],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
def remove_constant_dim(line):
|
||||
if "c19_i64" in line:
|
||||
line = re.sub("c19_i64", "dim_i64", line)
|
||||
if "19x" in line:
|
||||
line = re.sub("19x", "?x", line)
|
||||
line = re.sub(
|
||||
"tensor.empty\(\)", "tensor.empty(%dim)", line
|
||||
)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim\)",
|
||||
"tensor.empty(%dim, %dim)",
|
||||
line,
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub("c19", "dim", line)
|
||||
if " 19," in line:
|
||||
line = re.sub(" 19,", " %dim,", line)
|
||||
if "20x" in line:
|
||||
line = re.sub("20x", "?x", line)
|
||||
line = re.sub(
|
||||
"tensor.empty\(\)", "tensor.empty(%dimp1)", line
|
||||
)
|
||||
if " 20," in line:
|
||||
line = re.sub(" 20,", " %dimp1,", line)
|
||||
return line
|
||||
|
||||
module_str = str(module)
|
||||
new_lines = []
|
||||
|
||||
for line in module_str.splitlines():
|
||||
if "%c19_i64 = arith.constant 19 : i64" in line:
|
||||
new_lines.append("%c2 = arith.constant 2 : index")
|
||||
new_lines.append(
|
||||
"%dim_4_int = tensor.dim %arg1, %c2 : tensor<1x32x?x128xf32>"
|
||||
)
|
||||
new_lines.append(
|
||||
"%dim_i64 = arith.index_cast %dim_4_int : index to i64"
|
||||
)
|
||||
continue
|
||||
if "%c2 = arith.constant 2 : index" in line:
|
||||
continue
|
||||
if "%c20_i64 = arith.constant 20 : i64" in line:
|
||||
new_lines.append("%c1_i64 = arith.constant 1 : i64")
|
||||
new_lines.append(
|
||||
"%c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
|
||||
)
|
||||
new_lines.append(
|
||||
"%dimp1 = arith.index_cast %c20_i64 : i64 to index"
|
||||
)
|
||||
continue
|
||||
line = remove_constant_dim(line)
|
||||
new_lines.append(line)
|
||||
|
||||
module_str = "\n".join(new_lines)
|
||||
bytecode = module_str.encode("UTF-8")
|
||||
bytecode_stream = BytesIO(bytecode)
|
||||
bytecode = bytecode_stream.read()
|
||||
f_ = open(self.second_vicuna_mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
f_.close()
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
|
||||
path = shark_module.save_module(
|
||||
self.second_vicuna_vmfb_path.parent.absolute(),
|
||||
self.second_vicuna_vmfb_path.stem,
|
||||
extra_args=[
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
],
|
||||
)
|
||||
print("Saved vmfb at ", str(path))
|
||||
shark_module.load_module(self.second_vicuna_vmfb_path)
|
||||
|
||||
# self.shark_module = shark_module
|
||||
|
||||
return shark_module
|
||||
|
||||
def compile(self):
|
||||
# Cannot load both the models in the memory at once
|
||||
# due to memory constraints, hence on demand compilation
|
||||
# is being used until the space is enough for both models
|
||||
|
||||
# Testing : DO NOT Download Vmfbs if not found. Modify later
|
||||
# download vmfbs for A100
|
||||
if (
|
||||
not self.first_vicuna_vmfb_path.exists()
|
||||
and self.device == "cuda"
|
||||
and self.precision == "fp32"
|
||||
):
|
||||
download_public_file(
|
||||
"gs://shark_tank/vicuna/unsharded/first_vicuna.vmfb",
|
||||
self.first_vicuna_vmfb_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
else:
|
||||
# get first vic
|
||||
# TODO: Remove after testing to avoid memory overload
|
||||
# fvic_shark_model = self.compile_first_vicuna()
|
||||
pass
|
||||
if (
|
||||
not self.second_vicuna_vmfb_path.exists()
|
||||
and self.device == "cuda"
|
||||
and self.precision == "fp32"
|
||||
):
|
||||
download_public_file(
|
||||
"gs://shark_tank/vicuna/unsharded/second_vicuna.vmfb",
|
||||
self.second_vicuna_vmfb_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
else:
|
||||
# get second vic
|
||||
# TODO: Remove after testing to avoid memory overload
|
||||
# svic_shark_model = self.compile_second_vicuna()
|
||||
pass
|
||||
|
||||
# get first vic
|
||||
# fvic_shark_model = self.compile_first_vicuna()
|
||||
# get second vic
|
||||
# svic_shark_model = self.compile_second_vicuna()
|
||||
# return tuple of shark_modules
|
||||
# return fvic_shark_model, svic_shark_model
|
||||
return None
|
||||
# return tuple of shark_modules once mem is supported
|
||||
# return fvic_shark_model, svic_shark_model
|
||||
|
||||
def generate(self, prompt):
|
||||
# TODO: refactor for cleaner integration
|
||||
import gc
|
||||
|
||||
res = []
|
||||
res_tokens = []
|
||||
params = {
|
||||
"prompt": prompt,
|
||||
"is_first": True,
|
||||
"fv": self.compile_first_vicuna(),
|
||||
}
|
||||
|
||||
generated_token_op = self.generate_new_token(params=params)
|
||||
|
||||
token = generated_token_op["token"]
|
||||
logits = generated_token_op["logits"]
|
||||
pkv = generated_token_op["pkv"]
|
||||
detok = generated_token_op["detok"]
|
||||
|
||||
res.append(detok)
|
||||
res_tokens.append(token)
|
||||
if args.cli:
|
||||
print(f"Assistant: {detok}", end=" ", flush=True)
|
||||
|
||||
# Clear First Vic from Memory (main and cuda)
|
||||
del params
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
sec_vic = self.compile_second_vicuna()
|
||||
for _ in range(self.max_num_tokens - 2):
|
||||
params = {
|
||||
"prompt": None,
|
||||
"is_first": False,
|
||||
"logits": logits,
|
||||
"pkv": pkv,
|
||||
"sv": sec_vic,
|
||||
}
|
||||
|
||||
generated_token_op = self.generate_new_token(params=params)
|
||||
|
||||
token = generated_token_op["token"]
|
||||
logits = generated_token_op["logits"]
|
||||
pkv = generated_token_op["pkv"]
|
||||
detok = generated_token_op["detok"]
|
||||
|
||||
if token == 2:
|
||||
break
|
||||
res_tokens.append(token)
|
||||
if detok == "<0x0A>":
|
||||
res.append("\n")
|
||||
if args.cli:
|
||||
print("\n", end="", flush=True)
|
||||
else:
|
||||
res.append(detok)
|
||||
if args.cli:
|
||||
print(f"{detok}", end=" ", flush=True)
|
||||
del sec_vic, pkv, logits
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
for i in range(len(res_tokens)):
|
||||
if type(res_tokens[i]) != int:
|
||||
res_tokens[i] = int(res_tokens[i][0])
|
||||
|
||||
res_str = self.tokenizer.decode(res_tokens)
|
||||
# print(f"[DEBUG] final output : \n{res_str}")
|
||||
return res_str
|
||||
|
||||
def generate_new_token(self, params):
|
||||
def forward_first(first_vic, prompt, cache_outputs=False):
|
||||
input_ids = self.tokenizer(prompt).input_ids
|
||||
input_id_len = len(input_ids)
|
||||
input_ids = torch.tensor(input_ids)
|
||||
input_ids = input_ids.reshape([1, input_id_len])
|
||||
firstVicunaInput = (input_ids,)
|
||||
assert first_vic is not None
|
||||
output_first_vicuna = first_vic("forward", firstVicunaInput)
|
||||
output_first_vicuna_tensor = torch.tensor(output_first_vicuna[1:])
|
||||
logits_first_vicuna = torch.tensor(output_first_vicuna[0])
|
||||
if cache_outputs:
|
||||
torch.save(
|
||||
logits_first_vicuna, "logits_first_vicuna_tensor.pt"
|
||||
)
|
||||
torch.save(
|
||||
output_first_vicuna_tensor, "output_first_vicuna_tensor.pt"
|
||||
)
|
||||
token = torch.argmax(
|
||||
torch.tensor(logits_first_vicuna)[:, -1, :], dim=1
|
||||
)
|
||||
return token, logits_first_vicuna, output_first_vicuna_tensor
|
||||
|
||||
def forward_second(sec_vic, inputs=None, load_inputs=False):
|
||||
if inputs is not None:
|
||||
logits = inputs[0]
|
||||
pkv = inputs[1:]
|
||||
elif load_inputs:
|
||||
pkv = torch.load("output_first_vicuna_tensor.pt")
|
||||
pkv = tuple(torch.tensor(x) for x in pkv)
|
||||
logits = torch.load("logits_first_vicuna_tensor.pt")
|
||||
else:
|
||||
print(
|
||||
"Either inputs must be given, or load_inputs must be true"
|
||||
)
|
||||
return None
|
||||
token = torch.argmax(torch.tensor(logits)[:, -1, :], dim=1)
|
||||
token = token.to(torch.int64).reshape([1, 1])
|
||||
secondVicunaInput = (token,) + tuple(pkv)
|
||||
|
||||
secondVicunaOutput = sec_vic("forward", secondVicunaInput)
|
||||
new_pkv = secondVicunaOutput[1:]
|
||||
new_logits = secondVicunaOutput[0]
|
||||
new_token = torch.argmax(torch.tensor(new_logits)[:, -1, :], dim=1)
|
||||
return new_token, new_logits, new_pkv
|
||||
|
||||
is_first = params["is_first"]
|
||||
|
||||
if is_first:
|
||||
prompt = params["prompt"]
|
||||
fv = params["fv"]
|
||||
token, logits, pkv = forward_first(
|
||||
fv, # self.shark_model[0],
|
||||
prompt=prompt,
|
||||
cache_outputs=False,
|
||||
)
|
||||
else:
|
||||
_logits = params["logits"]
|
||||
_pkv = params["pkv"]
|
||||
inputs = (_logits,) + tuple(_pkv)
|
||||
sv = params["sv"]
|
||||
token, logits, pkv = forward_second(
|
||||
sv, # self.shark_model[1],
|
||||
inputs=inputs,
|
||||
load_inputs=False,
|
||||
)
|
||||
|
||||
detok = self.tokenizer.decode(token)
|
||||
if not args.cli:
|
||||
print(
|
||||
f"[DEBUG] is_first: {is_first} |"
|
||||
f" token : {token} | detok : {detok}"
|
||||
)
|
||||
ret_dict = {
|
||||
"token": token,
|
||||
"logits": logits,
|
||||
"pkv": pkv,
|
||||
"detok": detok,
|
||||
}
|
||||
return ret_dict
|
||||
|
||||
def autocomplete(self, prompt):
|
||||
# use First vic alone to complete a story / prompt / sentence.
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
first_vic_mlir_path = (
|
||||
Path("first_vicuna.mlir")
|
||||
if args.first_vicuna_mlir_path is None
|
||||
else Path(args.first_vicuna_mlir_path)
|
||||
)
|
||||
second_vic_mlir_path = (
|
||||
Path("second_vicuna.mlir")
|
||||
if args.second_vicuna_mlir_path is None
|
||||
else Path(args.second_vicuna_mlir_path)
|
||||
)
|
||||
first_vic_vmfb_path = (
|
||||
Path("first_vicuna.vmfb")
|
||||
if args.first_vicuna_vmfb_path is None
|
||||
else Path(args.first_vicuna_vmfb_path)
|
||||
)
|
||||
second_vic_vmfb_path = (
|
||||
Path("second_vicuna.vmfb")
|
||||
if args.second_vicuna_vmfb_path is None
|
||||
else Path(args.second_vicuna_vmfb_path)
|
||||
)
|
||||
|
||||
vic = Vicuna(
|
||||
"vicuna",
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
first_vicuna_mlir_path=first_vic_mlir_path,
|
||||
second_vicuna_mlir_path=second_vic_mlir_path,
|
||||
first_vicuna_vmfb_path=first_vic_vmfb_path,
|
||||
second_vicuna_vmfb_path=second_vic_vmfb_path,
|
||||
)
|
||||
|
||||
prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
|
||||
prologue_prompt = "ASSISTANT:\n"
|
||||
|
||||
import gc
|
||||
|
||||
while True:
|
||||
# TODO: Add break condition from user input
|
||||
user_prompt = input("User: ")
|
||||
prompt_history = (
|
||||
prompt_history + "USER:\n" + user_prompt + prologue_prompt
|
||||
)
|
||||
prompt = prompt_history.strip()
|
||||
res_str = vic.generate(prompt)
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
print(
|
||||
"\n-----\nAssistant: Here's the complete formatted reply:\n",
|
||||
res_str,
|
||||
)
|
||||
prompt_history += f"\n{res_str}\n"
|
||||
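The file above is meant to be run directly; the argparse block at its top defines the knobs. An example invocation (device, precision, and paths are placeholders):

python apps/language_models/src/pipelines/vicuna_pipeline.py --device cuda --precision fp32 --cli --load_mlir_from_shark_tank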
apps/language_models/src/pipelines/vicuna_sharded_pipeline.py (new file, 416 lines)

@@ -0,0 +1,416 @@
|
||||
from apps.language_models.src.model_wrappers.vicuna_sharded_model import (
|
||||
FirstVicunaLayer,
|
||||
SecondVicunaLayer,
|
||||
CompiledFirstVicunaLayer,
|
||||
CompiledSecondVicunaLayer,
|
||||
ShardedVicunaModel,
|
||||
)
|
||||
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
|
||||
from apps.language_models.utils import get_torch_mlir_module_bytecode
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from shark.shark_inference import SharkInference
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from tqdm import tqdm
|
||||
from torch_mlir import TensorPlaceholder
|
||||
|
||||
|
||||
import re
|
||||
import torch
|
||||
import torch_mlir
|
||||
import os
|
||||
|
||||
|
||||
class Vicuna(SharkLLMBase):
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
|
||||
max_num_tokens=512,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
self.max_sequence_length = 256
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.shark_model = self.compile()
|
||||
|
||||
def get_tokenizer(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hf_model_path, use_fast=False
|
||||
)
|
||||
return tokenizer
|
||||
|
||||
def get_src_model(self):
|
||||
kwargs = {"torch_dtype": torch.float}
|
||||
vicuna_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, **kwargs
|
||||
)
|
||||
return vicuna_model
|
||||
|
||||
def write_in_dynamic_inputs0(self, module, dynamic_input_size):
|
||||
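# Rewrites the serialized MLIR so the sampled static length (dynamic_input_size) becomes a dynamic '?' dimension tied to %dim.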
new_lines = []
|
||||
for line in module.splitlines():
|
||||
line = re.sub(f"{dynamic_input_size}x", "?x", line)
|
||||
if "?x" in line:
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
|
||||
line = re.sub(f" {dynamic_input_size},", " %dim,", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub(f"c{dynamic_input_size}", "dim", line)
|
||||
new_lines.append(line)
|
||||
new_module = "\n".join(new_lines)
|
||||
return new_module
|
||||
|
||||
def write_in_dynamic_inputs1(self, module, dynamic_input_size):
|
||||
new_lines = []
|
||||
for line in module.splitlines():
|
||||
if "dim_42 =" in line:
|
||||
continue
|
||||
if f"%c{dynamic_input_size}_i64 =" in line:
|
||||
new_lines.append(
|
||||
"%dim_42 = tensor.dim %arg1, %c3 : tensor<1x1x1x?xf32>"
|
||||
)
|
||||
new_lines.append(
|
||||
f"%dim_42_i64 = arith.index_cast %dim_42 : index to i64"
|
||||
)
|
||||
continue
|
||||
line = re.sub(f"{dynamic_input_size}x", "?x", line)
|
||||
if "?x" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(\)", "tensor.empty(%dim_42)", line
|
||||
)
|
||||
line = re.sub(f" {dynamic_input_size},", " %dim_42,", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim_42\)",
|
||||
"tensor.empty(%dim_42, %dim_42)",
|
||||
line,
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub(f"c{dynamic_input_size}", "dim_42", line)
|
||||
new_lines.append(line)
|
||||
new_module = "\n".join(new_lines)
|
||||
return new_module
|
||||
|
||||
def compile_vicuna_layer(
|
||||
self,
|
||||
vicuna_layer,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value0=None,
|
||||
past_key_value1=None,
|
||||
):
|
||||
if past_key_value0 is None and past_key_value1 is None:
|
||||
model_inputs = (hidden_states, attention_mask, position_ids)
|
||||
else:
|
||||
model_inputs = (
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value0,
|
||||
past_key_value1,
|
||||
)
|
||||
mlir_bytecode = get_torch_mlir_module_bytecode(
|
||||
vicuna_layer, model_inputs
|
||||
)
|
||||
return mlir_bytecode
|
||||
|
||||
def compile_to_vmfb(self, inputs, layers, is_first=True):
|
||||
mlirs, modules = [], []
|
||||
for idx, layer in tqdm(enumerate(layers), desc="Getting mlirs"):
|
||||
if is_first:
|
||||
mlir_path = Path(f"{idx}_0.mlir")
|
||||
vmfb_path = Path(f"{idx}_0.vmfb")
|
||||
else:
|
||||
mlir_path = Path(f"{idx}_1.mlir")
|
||||
vmfb_path = Path(f"{idx}_1.vmfb")
|
||||
if vmfb_path.exists():
|
||||
continue
|
||||
if mlir_path.exists():
|
||||
# print(f"Found layer {idx} mlir")
|
||||
f_ = open(mlir_path, "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
else:
|
||||
hidden_states_placeholder = TensorPlaceholder.like(
|
||||
inputs[0], dynamic_axes=[1]
|
||||
)
|
||||
attention_mask_placeholder = TensorPlaceholder.like(
|
||||
inputs[1], dynamic_axes=[3]
|
||||
)
|
||||
position_ids_placeholder = TensorPlaceholder.like(
|
||||
inputs[2], dynamic_axes=[1]
|
||||
)
|
||||
if not is_first:
|
||||
pkv0_placeholder = TensorPlaceholder.like(
|
||||
inputs[3], dynamic_axes=[2]
|
||||
)
|
||||
pkv1_placeholder = TensorPlaceholder.like(
|
||||
inputs[4], dynamic_axes=[2]
|
||||
)
|
||||
print(f"Compiling layer {idx} mlir")
|
||||
if is_first:
|
||||
ts_g = self.compile_vicuna_layer(
|
||||
layer, inputs[0], inputs[1], inputs[2]
|
||||
)
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
(
|
||||
hidden_states_placeholder,
|
||||
inputs[1],
|
||||
inputs[2],
|
||||
),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
else:
|
||||
ts_g = self.compile_vicuna_layer(
|
||||
layer,
|
||||
inputs[0],
|
||||
inputs[1],
|
||||
inputs[2],
|
||||
inputs[3],
|
||||
inputs[4],
|
||||
)
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
(
|
||||
inputs[0],
|
||||
attention_mask_placeholder,
|
||||
inputs[2],
|
||||
pkv0_placeholder,
|
||||
pkv1_placeholder,
|
||||
),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# bytecode_stream = BytesIO()
|
||||
# module.operation.write_bytecode(bytecode_stream)
|
||||
# bytecode = bytecode_stream.getvalue()
|
||||
|
||||
if is_first:
|
||||
module = self.write_in_dynamic_inputs0(str(module), 137)
|
||||
bytecode = module.encode("UTF-8")
|
||||
bytecode_stream = BytesIO(bytecode)
|
||||
bytecode = bytecode_stream.read()
|
||||
|
||||
else:
|
||||
module = self.write_in_dynamic_inputs1(str(module), 138)
|
||||
if idx in [0, 5, 6, 7]:
|
||||
module_str = module
|
||||
module_str = module_str.splitlines()
|
||||
new_lines = []
|
||||
for line in module_str:
|
||||
if len(line) < 1000:
|
||||
new_lines.append(line)
|
||||
else:
|
||||
new_lines.append(line[:999])
|
||||
module_str = "\n".join(new_lines)
|
||||
f1_ = open(f"{idx}_1_test.mlir", "w+")
|
||||
f1_.write(module_str)
|
||||
f1_.close()
|
||||
|
||||
bytecode = module.encode("UTF-8")
|
||||
bytecode_stream = BytesIO(bytecode)
|
||||
bytecode = bytecode_stream.read()
|
||||
|
||||
f_ = open(mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
f_.close()
|
||||
mlirs.append(bytecode)
|
||||
|
||||
for idx, layer in tqdm(enumerate(layers), desc="compiling modules"):
|
||||
if is_first:
|
||||
vmfb_path = Path(f"{idx}_0.vmfb")
|
||||
if idx < 25:
|
||||
device = "cpu"
|
||||
else:
|
||||
device = "cpu"
|
||||
if vmfb_path.exists():
|
||||
# print(f"Found layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
None, device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
else:
|
||||
print(f"Compiling layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
mlirs[idx], device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
module.save_module(
|
||||
module_name=f"{idx}_0",
|
||||
extra_args=[
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
],
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
modules.append(module)
|
||||
else:
|
||||
vmfb_path = Path(f"{idx}_1.vmfb")
|
||||
if idx < 25:
|
||||
device = "cpu"
|
||||
else:
|
||||
device = "cpu"
|
||||
if vmfb_path.exists():
|
||||
# print(f"Found layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
None, device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
else:
|
||||
print(f"Compiling layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
mlirs[idx], device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
module.save_module(
|
||||
module_name=f"{idx}_1",
|
||||
extra_args=[
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
],
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
modules.append(module)
|
||||
|
||||
return mlirs, modules
|
||||
|
||||
def get_sharded_model(self):
|
||||
# SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an incredibly hacky process
|
||||
# please don't change it
|
||||
SAMPLE_INPUT_LEN = 137
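# Keep in sync with the hard-coded 137 / 138 passed to write_in_dynamic_inputs0 / write_in_dynamic_inputs1 in compile_to_vmfb above.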
|
||||
vicuna_model = self.get_src_model()
|
||||
placeholder_input0 = (
|
||||
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
|
||||
torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]),
|
||||
torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64),
|
||||
)
|
||||
|
||||
placeholder_input1 = (
|
||||
torch.zeros([1, 1, 4096]),
|
||||
torch.zeros([1, 1, 1, SAMPLE_INPUT_LEN + 1]),
|
||||
torch.zeros([1, 1], dtype=torch.int64),
|
||||
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
|
||||
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
|
||||
)
|
||||
|
||||
layers0 = [
|
||||
FirstVicunaLayer(layer) for layer in vicuna_model.model.layers
|
||||
]
|
||||
_, modules0 = self.compile_to_vmfb(
|
||||
placeholder_input0, layers0, is_first=True
|
||||
)
|
||||
shark_layers0 = [CompiledFirstVicunaLayer(m) for m in modules0]
|
||||
|
||||
layers1 = [
|
||||
SecondVicunaLayer(layer) for layer in vicuna_model.model.layers
|
||||
]
|
||||
_, modules1 = self.compile_to_vmfb(
|
||||
placeholder_input1, layers1, is_first=False
|
||||
)
|
||||
shark_layers1 = [CompiledSecondVicunaLayer(m) for m in modules1]
|
||||
|
||||
sharded_model = ShardedVicunaModel(
|
||||
vicuna_model, shark_layers0, shark_layers1
|
||||
)
|
||||
return sharded_model
|
||||
|
||||
def compile(self):
|
||||
return self.get_sharded_model()
|
||||
|
||||
def generate(self, prompt):
|
||||
# TODO: refactor for cleaner integration
|
||||
|
||||
tokens_generated = []
|
||||
_past_key_values = None
|
||||
_token = None
|
||||
detoks_generated = []
|
||||
for iteration in range(self.max_num_tokens):
|
||||
params = {
|
||||
"prompt": prompt,
|
||||
"is_first": iteration == 0,
|
||||
"token": _token,
|
||||
"past_key_values": _past_key_values,
|
||||
}
|
||||
|
||||
generated_token_op = self.generate_new_token(params=params)
|
||||
|
||||
_token = generated_token_op["token"]
|
||||
_past_key_values = generated_token_op["past_key_values"]
|
||||
_detok = generated_token_op["detok"]
|
||||
|
||||
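# token id 2 is the EOS token for the LLaMA/Vicuna tokenizer, so generation stops here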
if _token == 2:
|
||||
break
|
||||
detoks_generated.append(_detok)
|
||||
tokens_generated.append(_token)
|
||||
|
||||
for i in range(len(tokens_generated)):
|
||||
if type(tokens_generated[i]) != int:
|
||||
tokens_generated[i] = int(tokens_generated[i][0])
|
||||
result_output = self.tokenizer.decode(tokens_generated)
|
||||
return result_output
|
||||
|
||||
def generate_new_token(self, params):
|
||||
is_first = params["is_first"]
|
||||
if is_first:
|
||||
prompt = params["prompt"]
|
||||
input_ids = self.tokenizer(prompt).input_ids
|
||||
input_id_len = len(input_ids)
|
||||
input_ids = torch.tensor(input_ids)
|
||||
input_ids = input_ids.reshape([1, input_id_len])
|
||||
output = self.shark_model.forward(input_ids, is_first=is_first)
|
||||
else:
|
||||
token = params["token"]
|
||||
past_key_values = params["past_key_values"]
|
||||
input_ids = [token]
|
||||
input_id_len = len(input_ids)
|
||||
input_ids = torch.tensor(input_ids)
|
||||
input_ids = input_ids.reshape([1, input_id_len])
|
||||
output = self.shark_model.forward(
|
||||
input_ids, past_key_values=past_key_values, is_first=is_first
|
||||
)
|
||||
|
||||
_logits = output["logits"]
|
||||
_past_key_values = output["past_key_values"]
|
||||
_token = int(torch.argmax(_logits[:, -1, :], dim=1)[0])
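# greedy decoding: pick the highest-logit token at the final position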
|
||||
_detok = self.tokenizer.decode(_token)
|
||||
|
||||
ret_dict = {
|
||||
"token": _token,
|
||||
"detok": _detok,
|
||||
"past_key_values": _past_key_values,
|
||||
}
|
||||
|
||||
print(f" token : {_token} | detok : {_detok}")
|
||||
|
||||
return ret_dict
|
||||
|
||||
def autocomplete(self, prompt):
|
||||
# use First vic alone to complete a story / prompt / sentence.
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
vic = Vicuna("vicuna")
|
||||
prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
|
||||
prologue_prompt = "ASSISTANT:\n"
|
||||
user_prompt = input("User: ")
|
||||
prompt_history = prompt_history + "USER:\n" + user_prompt + prologue_prompt
|
||||
prompt = prompt_history.strip()
|
||||
|
||||
res = vic.generate(prompt)
|
||||
print(prompt + res)
|
||||
140
apps/language_models/utils.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def get_torch_mlir_module_bytecode(model, model_inputs):
|
||||
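# Traces `model` with make_fx (decomposing several aten ops), then cleans the FX graph below
# (zero-fills aten.empty, drops None outputs, unwraps single-element tuple returns, strips op
# overloads) before scripting it to TorchScript for torch-mlir.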
fx_g = make_fx(
|
||||
model,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
# tracing_mode='symbolic',
|
||||
)(*model_inputs)
|
||||
print("Got FX_G")
|
||||
|
||||
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
|
||||
removed_indexes = []
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, (list, tuple)):
|
||||
node_arg = list(node_arg)
|
||||
node_args_len = len(node_arg)
|
||||
for i in range(node_args_len):
|
||||
curr_index = node_args_len - (i + 1)
|
||||
if node_arg[curr_index] is None:
|
||||
removed_indexes.append(curr_index)
|
||||
node_arg.pop(curr_index)
|
||||
node.args = (tuple(node_arg),)
|
||||
break
|
||||
|
||||
if len(removed_indexes) > 0:
|
||||
fx_g.graph.lint()
|
||||
fx_g.graph.eliminate_dead_code()
|
||||
fx_g.recompile()
|
||||
removed_indexes.sort()
|
||||
return removed_indexes
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""
|
||||
Replace tuple with tuple element in functions that return one-element tuples.
|
||||
Returns true if an unwrapping took place, and false otherwise.
|
||||
"""
|
||||
unwrapped_tuple = False
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
unwrapped_tuple = True
|
||||
break
|
||||
|
||||
if unwrapped_tuple:
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return unwrapped_tuple
|
||||
|
||||
def transform_fx(fx_g):
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
if node.target in [
|
||||
torch.ops.aten.empty,
|
||||
]:
|
||||
# aten.empty should be filled with zeros.
|
||||
if node.target in [torch.ops.aten.empty]:
|
||||
with fx_g.graph.inserting_after(node):
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten.zero_,
|
||||
args=(node,),
|
||||
)
|
||||
node.append(new_node)
|
||||
node.replace_all_uses_with(new_node)
|
||||
new_node.args = (node,)
|
||||
|
||||
fx_g.graph.lint()
|
||||
|
||||
transform_fx(fx_g)
|
||||
fx_g.recompile()
|
||||
removed_none_indexes = _remove_nones(fx_g)
|
||||
was_unwrapped = _unwrap_single_tuple_return(fx_g)
|
||||
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
|
||||
print("FX_G recompile")
|
||||
|
||||
def strip_overloads(gm):
|
||||
"""
|
||||
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
|
||||
Args:
|
||||
gm(fx.GraphModule): The input Fx graph module to be modified
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if isinstance(node.target, torch._ops.OpOverload):
|
||||
node.target = node.target.overloadpacket
|
||||
gm.recompile()
|
||||
|
||||
strip_overloads(fx_g)
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
print("Got TS_G")
|
||||
return ts_g
|
||||
|
||||
|
||||
# expects a Path / str as arg
|
||||
# returns None if path not found or SharkInference module
|
||||
def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
|
||||
if not isinstance(vmfb_path, Path):
|
||||
vmfb_path = Path(vmfb_path)
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
if not vmfb_path.exists():
|
||||
return None
|
||||
|
||||
print("Loading vmfb from: ", vmfb_path)
|
||||
shark_module = SharkInference(
|
||||
None, device=device, mlir_dialect=mlir_dialect
|
||||
)
|
||||
shark_module.load_module(vmfb_path)
|
||||
print("Successfully loaded vmfb")
|
||||
return shark_module
|
||||
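A rough usage sketch of this helper (the file name mirrors the default vmfb paths used by the pipelines above; the device and dialect values match how SharkInference is configured elsewhere in this change):

    from apps.language_models.utils import get_vmfb_from_path

    shark_module = get_vmfb_from_path("first_vicuna.vmfb", "cpu", "tm_tensor")
    if shark_module is None:
        # no cached vmfb yet; compile the mlir first, then retry
        ...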
@@ -10,7 +10,7 @@ Vulkan AMD:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
|
||||
|
||||
# add --mlir-print-debuginfo --mlir-print-op-on-diagnostic=true for debug
|
||||
# use --iree-input-type=mhlo for tf models
|
||||
# use --iree-input-type=auto or "mhlo_legacy" or "stablehlo" for TF models
|
||||
|
||||
CUDA NVIDIA:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
|
||||
|
||||
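For instance, with the placeholder paths filled in (model.mlir / model.vmfb are illustrative names; all flags are copied verbatim from the Vulkan line above):

    iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 model.mlir -o model.vmfb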
@@ -31,6 +31,7 @@ datas += collect_data_files('google-cloud-storage')
|
||||
datas += collect_data_files('shark')
|
||||
datas += collect_data_files('tkinter')
|
||||
datas += collect_data_files('webview')
|
||||
datas += collect_data_files('sentencepiece')
|
||||
datas += [
|
||||
( 'src/utils/resources/prompts.json', 'resources' ),
|
||||
( 'src/utils/resources/model_db.json', 'resources' ),
|
||||
@@ -46,6 +47,7 @@ block_cipher = None
|
||||
|
||||
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
|
||||
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
|
||||
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
|
||||
|
||||
a = Analysis(
|
||||
['web/index.py'],
|
||||
|
||||
@@ -42,6 +42,7 @@ block_cipher = None
|
||||
|
||||
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
|
||||
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
|
||||
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
|
||||
|
||||
a = Analysis(
|
||||
['scripts/main.py'],
|
||||
|
||||
@@ -14,6 +14,7 @@ from apps.stable_diffusion.src.utils import (
|
||||
base_models,
|
||||
args,
|
||||
preprocessCKPT,
|
||||
convert_original_vae,
|
||||
get_path_to_diffusers_checkpoint,
|
||||
fetch_and_update_base_model_id,
|
||||
get_path_stem,
|
||||
@@ -571,8 +572,12 @@ class SharkifyStableDiffusionModel:
|
||||
vae_checkpoint = safetensors.torch.load_file(self.custom_vae, device="cpu")
|
||||
if "state_dict" in vae_checkpoint:
|
||||
vae_checkpoint = vae_checkpoint["state_dict"]
|
||||
vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
||||
return vae_dict
|
||||
|
||||
try:
|
||||
vae_checkpoint = convert_original_vae(vae_checkpoint)
|
||||
finally:
|
||||
vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
||||
return vae_dict
|
||||
|
||||
def compile_unet_variants(self, model):
|
||||
if model == "unet":
|
||||
|
||||
@@ -20,6 +20,8 @@ from diffusers import (
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_IDLE,
|
||||
SD_STATE_CANCEL,
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
@@ -84,6 +86,7 @@ class UpscalerPipeline(StableDiffusionPipeline):
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
self.low_res_scheduler = low_res_scheduler
|
||||
self.status = SD_STATE_IDLE
|
||||
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
accepts_eta = "eta" in set(
|
||||
@@ -164,6 +167,7 @@ class UpscalerPipeline(StableDiffusionPipeline):
|
||||
latent_history = [latents]
|
||||
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
self.status = SD_STATE_IDLE
|
||||
self.load_unet()
|
||||
for i, t in tqdm(enumerate(total_timesteps)):
|
||||
step_start_time = time.time()
|
||||
@@ -210,6 +214,9 @@ class UpscalerPipeline(StableDiffusionPipeline):
|
||||
# )
|
||||
step_time_sum += step_time
|
||||
|
||||
if self.status == SD_STATE_CANCEL:
|
||||
break
|
||||
|
||||
if self.ondemand:
|
||||
self.unload_unet()
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
|
||||
@@ -87,7 +87,8 @@ class StableDiffusionPipeline:
|
||||
else:
|
||||
try:
|
||||
self.text_encoder = get_clip()
|
||||
except:
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("download pipeline failed, falling back to import_mlir")
|
||||
self.text_encoder = self.sd_model.clip()
|
||||
|
||||
@@ -104,7 +105,8 @@ class StableDiffusionPipeline:
|
||||
else:
|
||||
try:
|
||||
self.unet = get_unet()
|
||||
except:
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("download pipeline failed, falling back to import_mlir")
|
||||
self.unet = self.sd_model.unet()
|
||||
|
||||
@@ -121,7 +123,8 @@ class StableDiffusionPipeline:
|
||||
else:
|
||||
try:
|
||||
self.vae = get_vae()
|
||||
except:
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("download pipeline failed, falling back to import_mlir")
|
||||
self.vae = self.sd_model.vae()
|
||||
|
||||
|
||||
@@ -24,11 +24,14 @@ from apps.stable_diffusion.src.utils.utils import (
|
||||
get_available_devices,
|
||||
get_opt_flags,
|
||||
preprocessCKPT,
|
||||
convert_original_vae,
|
||||
fetch_and_update_base_model_id,
|
||||
get_path_to_diffusers_checkpoint,
|
||||
sanitize_seed,
|
||||
get_path_stem,
|
||||
get_extended_name,
|
||||
get_generated_imgs_path,
|
||||
get_generated_imgs_todays_subdir,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
get_generation_text_info,
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
"stablediffusion/untuned":"gs://shark_tank/nightly"
|
||||
},
|
||||
{
|
||||
"stablediffusion/v1_4/unet/fp16/length_64/untuned":"unet_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
|
||||
"stablediffusion/v1_4/unet/fp16/length_64/untuned":"unet_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
|
||||
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
|
||||
"stablediffusion/v1_4/vae/fp16/length_64/untuned":"vae_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
|
||||
"stablediffusion/v1_4/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
|
||||
|
||||
@@ -125,6 +125,8 @@ def load_lower_configs(base_model_id=None):
|
||||
config_name = f"{args.annotation_model}_v2_1_768_{args.precision}_{device}.json"
|
||||
else:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
|
||||
elif spec in ["rdna2"] and version in ["v2_1", "v2_1base", "v1_4"]:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}_{args.width}x{args.height}.json"
|
||||
else:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}.json"
|
||||
|
||||
|
||||
@@ -495,13 +495,12 @@ p.add_argument(
|
||||
)
|
||||
# TODO: replace API flag when these can be run together
|
||||
p.add_argument(
|
||||
"--web_mode",
|
||||
"--ui",
|
||||
type=str,
|
||||
default="app",
|
||||
help="any number of: [api, app, webui]. Currently api can't be run with others.",
|
||||
default="app" if os.name == "nt" else "web",
|
||||
help="one of: [api, app, web]",
|
||||
)
|
||||
|
||||
|
||||
p.add_argument(
|
||||
"--share",
|
||||
default=False,
|
||||
@@ -522,6 +521,22 @@ p.add_argument(
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for enabling rest API",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_gallery",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for removing the output gallery tab, and avoid exposing images under --output_dir in the UI",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_gallery_followlinks",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for whether the output gallery tab in the UI should follow symlinks when listing subdirectorys under --output_dir",
|
||||
)
|
||||
|
||||
|
||||
##############################################################################
|
||||
### SD model auto-annotation flags
|
||||
##############################################################################
|
||||
|
||||
@@ -25,7 +25,12 @@ from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
|
||||
import sys
|
||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
download_from_original_stable_diffusion_ckpt,
|
||||
create_vae_diffusers_config,
|
||||
convert_ldm_vae_checkpoint,
|
||||
)
|
||||
import requests
|
||||
from io import BytesIO
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
def get_extended_name(model_name):
|
||||
@@ -78,7 +83,6 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
|
||||
|
||||
# Set local shark_tank cache directory.
|
||||
shark_args.local_tank_cache = args.local_tank_cache
|
||||
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
if "cuda" in args.device:
|
||||
@@ -333,13 +337,25 @@ def set_init_device_flags():
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
]
|
||||
or "rdna3" not in args.iree_vulkan_target_triple
|
||||
or "rdna" not in args.iree_vulkan_target_triple
|
||||
)
|
||||
):
|
||||
args.use_tuned = False
|
||||
|
||||
elif "rdna2" in args.iree_vulkan_target_triple and (
|
||||
base_model_id
|
||||
not in [
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
]
|
||||
):
|
||||
args.use_tuned = False
|
||||
|
||||
if args.use_tuned:
|
||||
print(f"Using tuned models for {base_model_id}/fp16/{args.device}.")
|
||||
print(
|
||||
f"Using tuned models for {base_model_id}(fp16) on device {args.device}."
|
||||
)
|
||||
else:
|
||||
print("Tuned models are currently not supported for this setting.")
|
||||
|
||||
@@ -464,7 +480,7 @@ def get_path_stem(path):
|
||||
def get_path_to_diffusers_checkpoint(custom_weights):
|
||||
path = Path(custom_weights)
|
||||
diffusers_path = path.parent.absolute()
|
||||
diffusers_directory_name = path.stem
|
||||
diffusers_directory_name = os.path.join("diffusers", path.stem)
|
||||
complete_path_to_diffusers = diffusers_path / diffusers_directory_name
|
||||
complete_path_to_diffusers.mkdir(parents=True, exist_ok=True)
|
||||
path_to_diffusers = complete_path_to_diffusers.as_posix()
|
||||
@@ -503,6 +519,22 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
|
||||
print("Loading complete")
|
||||
|
||||
|
||||
def convert_original_vae(vae_checkpoint):
|
||||
vae_state_dict = {}
|
||||
for key in list(vae_checkpoint.keys()):
|
||||
vae_state_dict["first_stage_model." + key] = vae_checkpoint.get(key)
|
||||
|
||||
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||
original_config_file = BytesIO(requests.get(config_url).content)
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
vae_config = create_vae_diffusers_config(original_config, image_size=512)
|
||||
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
|
||||
vae_state_dict, vae_config
|
||||
)
|
||||
return converted_vae_checkpoint
|
||||
|
||||
|
||||
def processLoRA(model, use_lora, splitting_prefix):
|
||||
state_dict = ""
|
||||
if ".safetensors" in use_lora:
|
||||
@@ -673,11 +705,20 @@ def clear_all():
|
||||
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
|
||||
|
||||
|
||||
def get_generated_imgs_path() -> Path:
|
||||
return Path(
|
||||
args.output_dir if args.output_dir else Path.cwd(), "generated_imgs"
|
||||
)
|
||||
|
||||
|
||||
def get_generated_imgs_todays_subdir() -> str:
|
||||
return dt.now().strftime("%Y%m%d")
|
||||
|
||||
|
||||
# save output images and the inputs corresponding to it.
|
||||
def save_output_img(output_img, img_seed, extra_info={}):
|
||||
output_path = args.output_dir if args.output_dir else Path.cwd()
|
||||
generated_imgs_path = Path(
|
||||
output_path, "generated_imgs", dt.now().strftime("%Y%m%d")
|
||||
get_generated_imgs_path(), get_generated_imgs_todays_subdir()
|
||||
)
|
||||
generated_imgs_path.mkdir(parents=True, exist_ok=True)
|
||||
csv_path = Path(generated_imgs_path, "imgs_details.csv")
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
from multiprocessing import Process, freeze_support
|
||||
import os
|
||||
import sys
|
||||
import transformers
|
||||
import transformers # ensures inclusion in pyinstaller exe generation
|
||||
from apps.stable_diffusion.src import args, clear_all
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
|
||||
if sys.platform == "darwin":
|
||||
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
|
||||
# import before IREE to avoid MLIR library issues
|
||||
import torch_mlir
|
||||
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
@@ -16,17 +18,21 @@ def launch_app(address):
|
||||
from tkinter import Tk
|
||||
import webview
|
||||
|
||||
tk = Tk()
|
||||
# size of the window where we show our website
|
||||
tk.geometry("1280x720")
|
||||
webview.create_window("SHARK", address)
|
||||
window = Tk()
|
||||
|
||||
# getting screen width and height of display
|
||||
width = window.winfo_screenwidth()
|
||||
height = window.winfo_screenheight()
|
||||
webview.create_window(
|
||||
"SHARK AI Studio", url=address, width=width, height=height
|
||||
)
|
||||
webview.start(private_mode=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# required to do multiprocessing in a pyinstaller freeze
|
||||
freeze_support()
|
||||
if args.api or "api" in args.web_mode.split(","):
|
||||
if args.api or "api" in args.ui.split(","):
|
||||
from apps.stable_diffusion.web.ui import (
|
||||
txt2img_api,
|
||||
img2img_api,
|
||||
@@ -55,14 +61,12 @@ if __name__ == "__main__":
|
||||
from apps.stable_diffusion.web.utils.gradio_configs import (
|
||||
clear_gradio_tmp_imgs_folder,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import get_custom_model_path
|
||||
from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
|
||||
|
||||
# Clear all gradio tmp images from the last session
|
||||
clear_gradio_tmp_imgs_folder()
|
||||
# Create the custom model folder if it doesn't already exist
|
||||
dir = ["models", "vae", "lora"]
|
||||
for root in dir:
|
||||
get_custom_model_path(root).mkdir(parents=True, exist_ok=True)
|
||||
# Create custom models folders if they don't exist
|
||||
create_custom_models_folders()
|
||||
|
||||
def resource_path(relative_path):
|
||||
"""Get absolute path to resource, works for dev and for PyInstaller"""
|
||||
@@ -75,37 +79,69 @@ if __name__ == "__main__":
|
||||
|
||||
from apps.stable_diffusion.web.ui import (
|
||||
txt2img_web,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
txt2img_gallery,
|
||||
txt2img_png_info_img,
|
||||
txt2img_status,
|
||||
txt2img_sendto_img2img,
|
||||
txt2img_sendto_inpaint,
|
||||
txt2img_sendto_outpaint,
|
||||
txt2img_sendto_upscaler,
|
||||
img2img_web,
|
||||
img2img_custom_model,
|
||||
img2img_hf_model_id,
|
||||
img2img_gallery,
|
||||
img2img_init_image,
|
||||
img2img_status,
|
||||
img2img_sendto_inpaint,
|
||||
img2img_sendto_outpaint,
|
||||
img2img_sendto_upscaler,
|
||||
inpaint_web,
|
||||
inpaint_custom_model,
|
||||
inpaint_hf_model_id,
|
||||
inpaint_gallery,
|
||||
inpaint_init_image,
|
||||
inpaint_status,
|
||||
inpaint_sendto_img2img,
|
||||
inpaint_sendto_outpaint,
|
||||
inpaint_sendto_upscaler,
|
||||
outpaint_web,
|
||||
outpaint_custom_model,
|
||||
outpaint_hf_model_id,
|
||||
outpaint_gallery,
|
||||
outpaint_init_image,
|
||||
outpaint_status,
|
||||
outpaint_sendto_img2img,
|
||||
outpaint_sendto_inpaint,
|
||||
outpaint_sendto_upscaler,
|
||||
upscaler_web,
|
||||
upscaler_custom_model,
|
||||
upscaler_hf_model_id,
|
||||
upscaler_gallery,
|
||||
upscaler_init_image,
|
||||
upscaler_status,
|
||||
upscaler_sendto_img2img,
|
||||
upscaler_sendto_inpaint,
|
||||
upscaler_sendto_outpaint,
|
||||
lora_train_web,
|
||||
model_web,
|
||||
hf_models,
|
||||
modelmanager_sendto_txt2img,
|
||||
modelmanager_sendto_img2img,
|
||||
modelmanager_sendto_inpaint,
|
||||
modelmanager_sendto_outpaint,
|
||||
modelmanager_sendto_upscaler,
|
||||
stablelm_chat,
|
||||
outputgallery_web,
|
||||
outputgallery_tab_select,
|
||||
outputgallery_watch,
|
||||
outputgallery_filename,
|
||||
outputgallery_sendto_txt2img,
|
||||
outputgallery_sendto_img2img,
|
||||
outputgallery_sendto_inpaint,
|
||||
outputgallery_sendto_outpaint,
|
||||
outputgallery_sendto_upscaler,
|
||||
)
|
||||
|
||||
# init global sd pipeline and config
|
||||
@@ -121,6 +157,27 @@ if __name__ == "__main__":
|
||||
outputs,
|
||||
)
|
||||
|
||||
def register_modelmanager_button(button, selectedid, inputs, outputs):
|
||||
button.click(
|
||||
lambda x: (
|
||||
"None",
|
||||
x,
|
||||
gr.Tabs.update(selected=selectedid),
|
||||
),
|
||||
inputs,
|
||||
outputs,
|
||||
)
|
||||
|
||||
def register_outputgallery_button(button, selectedid, inputs, outputs):
|
||||
button.click(
|
||||
lambda x: (
|
||||
x,
|
||||
gr.Tabs.update(selected=selectedid),
|
||||
),
|
||||
inputs,
|
||||
outputs,
|
||||
)
|
||||
|
||||
with gr.Blocks(
|
||||
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
|
||||
) as sd_web:
|
||||
@@ -137,11 +194,27 @@ if __name__ == "__main__":
|
||||
upscaler_web.render()
|
||||
with gr.TabItem(label="Model Manager", id=5):
|
||||
model_web.render()
|
||||
|
||||
with gr.Tabs(visible=False) as experimental_tabs:
|
||||
with gr.TabItem(label="LoRA Training", id=5):
|
||||
with gr.TabItem(label="Chat Bot(Experimental)", id=6):
|
||||
stablelm_chat.render()
|
||||
with gr.TabItem(label="LoRA Training(Experimental)", id=7):
|
||||
lora_train_web.render()
|
||||
if args.output_gallery:
|
||||
with gr.TabItem(label="Output Gallery", id=8) as og_tab:
|
||||
outputgallery_web.render()
|
||||
|
||||
# extra output gallery configuration
|
||||
outputgallery_tab_select(og_tab.select)
|
||||
outputgallery_watch(
|
||||
[
|
||||
txt2img_status,
|
||||
img2img_status,
|
||||
inpaint_status,
|
||||
outpaint_status,
|
||||
upscaler_status,
|
||||
]
|
||||
)
|
||||
|
||||
# send to buttons
|
||||
register_button_click(
|
||||
txt2img_sendto_img2img,
|
||||
1,
|
||||
@@ -238,15 +311,77 @@ if __name__ == "__main__":
|
||||
[upscaler_gallery],
|
||||
[outpaint_init_image, tabs],
|
||||
)
|
||||
if args.output_gallery:
|
||||
register_outputgallery_button(
|
||||
outputgallery_sendto_txt2img,
|
||||
0,
|
||||
[outputgallery_filename],
|
||||
[txt2img_png_info_img, tabs],
|
||||
)
|
||||
register_outputgallery_button(
|
||||
outputgallery_sendto_img2img,
|
||||
1,
|
||||
[outputgallery_filename],
|
||||
[img2img_init_image, tabs],
|
||||
)
|
||||
register_outputgallery_button(
|
||||
outputgallery_sendto_inpaint,
|
||||
2,
|
||||
[outputgallery_filename],
|
||||
[inpaint_init_image, tabs],
|
||||
)
|
||||
register_outputgallery_button(
|
||||
outputgallery_sendto_outpaint,
|
||||
3,
|
||||
[outputgallery_filename],
|
||||
[outpaint_init_image, tabs],
|
||||
)
|
||||
register_outputgallery_button(
|
||||
outputgallery_sendto_upscaler,
|
||||
4,
|
||||
[outputgallery_filename],
|
||||
[upscaler_init_image, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_txt2img,
|
||||
0,
|
||||
[hf_models],
|
||||
[txt2img_custom_model, txt2img_hf_model_id, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_img2img,
|
||||
1,
|
||||
[hf_models],
|
||||
[img2img_custom_model, img2img_hf_model_id, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_inpaint,
|
||||
2,
|
||||
[hf_models],
|
||||
[inpaint_custom_model, inpaint_hf_model_id, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_outpaint,
|
||||
3,
|
||||
[hf_models],
|
||||
[outpaint_custom_model, outpaint_hf_model_id, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_upscaler,
|
||||
4,
|
||||
[hf_models],
|
||||
[upscaler_custom_model, upscaler_hf_model_id, tabs],
|
||||
)
|
||||
|
||||
sd_web.queue()
|
||||
if "app" in args.web_mode.split(","):
|
||||
if args.ui == "app":
|
||||
t = Process(
|
||||
target=launch_app, args=[f"http://localhost:{args.server_port}"]
|
||||
)
|
||||
t.start()
|
||||
sd_web.launch(
|
||||
share=args.share,
|
||||
inbrowser="webui" in args.web_mode.split(","),
|
||||
inbrowser=args.ui == "web",
|
||||
server_name="0.0.0.0",
|
||||
server_port=args.server_port,
|
||||
)
|
||||
|
||||
@@ -2,7 +2,11 @@ from apps.stable_diffusion.web.ui.txt2img_ui import (
|
||||
txt2img_inf,
|
||||
txt2img_api,
|
||||
txt2img_web,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
txt2img_gallery,
|
||||
txt2img_png_info_img,
|
||||
txt2img_status,
|
||||
txt2img_sendto_img2img,
|
||||
txt2img_sendto_inpaint,
|
||||
txt2img_sendto_outpaint,
|
||||
@@ -12,8 +16,11 @@ from apps.stable_diffusion.web.ui.img2img_ui import (
|
||||
img2img_inf,
|
||||
img2img_api,
|
||||
img2img_web,
|
||||
img2img_custom_model,
|
||||
img2img_hf_model_id,
|
||||
img2img_gallery,
|
||||
img2img_init_image,
|
||||
img2img_status,
|
||||
img2img_sendto_inpaint,
|
||||
img2img_sendto_outpaint,
|
||||
img2img_sendto_upscaler,
|
||||
@@ -22,8 +29,11 @@ from apps.stable_diffusion.web.ui.inpaint_ui import (
|
||||
inpaint_inf,
|
||||
inpaint_api,
|
||||
inpaint_web,
|
||||
inpaint_custom_model,
|
||||
inpaint_hf_model_id,
|
||||
inpaint_gallery,
|
||||
inpaint_init_image,
|
||||
inpaint_status,
|
||||
inpaint_sendto_img2img,
|
||||
inpaint_sendto_outpaint,
|
||||
inpaint_sendto_upscaler,
|
||||
@@ -32,8 +42,11 @@ from apps.stable_diffusion.web.ui.outpaint_ui import (
|
||||
outpaint_inf,
|
||||
outpaint_api,
|
||||
outpaint_web,
|
||||
outpaint_custom_model,
|
||||
outpaint_hf_model_id,
|
||||
outpaint_gallery,
|
||||
outpaint_init_image,
|
||||
outpaint_status,
|
||||
outpaint_sendto_img2img,
|
||||
outpaint_sendto_inpaint,
|
||||
outpaint_sendto_upscaler,
|
||||
@@ -42,11 +55,34 @@ from apps.stable_diffusion.web.ui.upscaler_ui import (
|
||||
upscaler_inf,
|
||||
upscaler_api,
|
||||
upscaler_web,
|
||||
upscaler_custom_model,
|
||||
upscaler_hf_model_id,
|
||||
upscaler_gallery,
|
||||
upscaler_init_image,
|
||||
upscaler_status,
|
||||
upscaler_sendto_img2img,
|
||||
upscaler_sendto_inpaint,
|
||||
upscaler_sendto_outpaint,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.model_manager import (
|
||||
model_web,
|
||||
hf_models,
|
||||
modelmanager_sendto_txt2img,
|
||||
modelmanager_sendto_img2img,
|
||||
modelmanager_sendto_inpaint,
|
||||
modelmanager_sendto_outpaint,
|
||||
modelmanager_sendto_upscaler,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.lora_train_ui import lora_train_web
|
||||
from apps.stable_diffusion.web.ui.model_manager import model_web
|
||||
from apps.stable_diffusion.web.ui.stablelm_ui import stablelm_chat
|
||||
from apps.stable_diffusion.web.ui.outputgallery_ui import (
|
||||
outputgallery_web,
|
||||
outputgallery_tab_select,
|
||||
outputgallery_watch,
|
||||
outputgallery_filename,
|
||||
outputgallery_sendto_txt2img,
|
||||
outputgallery_sendto_img2img,
|
||||
outputgallery_sendto_inpaint,
|
||||
outputgallery_sendto_outpaint,
|
||||
outputgallery_sendto_upscaler,
|
||||
)
|
||||
|
||||
@@ -230,3 +230,44 @@ footer {
|
||||
#top_logo .download {
|
||||
display: none;
|
||||
}
|
||||
|
||||
/* output gallery tab */
|
||||
.output_parameters_dataframe tbody td {
|
||||
font-size: small;
|
||||
line-height: var(--line-xs)
|
||||
}
|
||||
|
||||
#output_refresh_button {
|
||||
max-width: 30px;
|
||||
align-self: end;
|
||||
padding-bottom: 8px;
|
||||
}
|
||||
|
||||
.outputgallery_sendto {
|
||||
min-width: 7em !important;
|
||||
}
|
||||
|
||||
/* output gallery should take up most of the viewport height regardless of image size/number */
|
||||
#outputgallery_gallery .fixed-height {
|
||||
min-height: 89vh !important;
|
||||
}
|
||||
|
||||
/* don't stretch non-square images to be square, breaking their aspect ratio */
|
||||
#outputgallery_gallery .thumbnail-item.thumbnail-lg > img {
|
||||
object-fit: contain !important;
|
||||
}
|
||||
|
||||
/* centered logo for when there are no images */
|
||||
#top_logo.logo_centered {
|
||||
height: 100%;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
#top_logo.logo_centered img{
|
||||
object-fit: scale-down;
|
||||
position: absolute;
|
||||
width: 80%;
|
||||
top: 50%;
|
||||
left: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
}
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from pathlib import Path
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
import sys
|
||||
import gradio as gr
|
||||
import PIL
|
||||
from PIL import Image
|
||||
import base64
|
||||
from io import BytesIO
|
||||
@@ -25,10 +24,13 @@ from apps.stable_diffusion.src import (
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generation_text_info,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -89,6 +91,8 @@ def img2img_inf(
|
||||
return None, "An Initial Image is required"
|
||||
if use_stencil == "scribble":
|
||||
image = image_dict["mask"].convert("RGB")
|
||||
elif isinstance(image_dict, PIL.Image.Image):
|
||||
image = image_dict.convert("RGB")
|
||||
else:
|
||||
image = image_dict["image"].convert("RGB")
|
||||
|
||||
@@ -256,11 +260,17 @@ def img2img_inf(
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed, extra_info)
|
||||
save_output_img(
|
||||
out_imgs[0],
|
||||
img_seed,
|
||||
extra_info,
|
||||
)
|
||||
generated_imgs.extend(out_imgs)
|
||||
# yield generated_imgs, text_output
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Image-to-Image", current_batch + 1, batch_count, batch_size
|
||||
)
|
||||
|
||||
return generated_imgs, text_output
|
||||
return generated_imgs, text_output, ""
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
@@ -299,7 +309,7 @@ def img2img_api(
|
||||
print(
|
||||
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
|
||||
)
|
||||
init_image = decode_base64_to_image(InputData["image"])
|
||||
init_image = decode_base64_to_image(InputData["init_images"][0])
|
||||
res = img2img_inf(
|
||||
InputData["prompt"],
|
||||
InputData["negative_prompt"],
|
||||
@@ -352,17 +362,17 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
custom_model = gr.Dropdown(
|
||||
img2img_custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {get_custom_model_path()})",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "None",
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_models,
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
img2img_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
|
||||
value="",
|
||||
@@ -590,16 +600,13 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
).style(columns=[2], object_fit="contain")
|
||||
output_dir = (
|
||||
args.output_dir if args.output_dir else Path.cwd()
|
||||
)
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at {output_dir}",
|
||||
value=f"Images will be saved at {get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
img2img_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
img2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
img2img_sendto_outpaint = gr.Button(
|
||||
@@ -624,8 +631,8 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
img2img_custom_model,
|
||||
img2img_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
@@ -637,13 +644,21 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
],
|
||||
outputs=[img2img_gallery, std_output],
|
||||
outputs=[img2img_gallery, std_output, img2img_status],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
status_kwargs = dict(
|
||||
fn=lambda bc, bs: status_label("Image-to-Image", 0, bc, bs),
|
||||
inputs=[batch_count, batch_size],
|
||||
outputs=img2img_status,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
|
||||
**kwargs
|
||||
)
|
||||
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from pathlib import Path
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
@@ -26,7 +25,11 @@ from apps.stable_diffusion.src import (
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generation_text_info,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
@@ -213,7 +216,9 @@ def inpaint_inf(
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Inpaint", i + 1, batch_count, batch_size
|
||||
)
|
||||
|
||||
return generated_imgs, text_output
|
||||
|
||||
@@ -306,17 +311,19 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
custom_model = gr.Dropdown(
|
||||
inpaint_custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {get_custom_model_path()})",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "None",
|
||||
else "stabilityai/stable-diffusion-2-inpainting",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ get_custom_model_files(
|
||||
custom_checkpoint_type="inpainting"
|
||||
)
|
||||
+ predefined_paint_models,
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
inpaint_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting, https://civitai.com/api/download/models/3433",
|
||||
value="",
|
||||
@@ -492,16 +499,14 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
).style(columns=[2], object_fit="contain")
|
||||
output_dir = (
|
||||
args.output_dir if args.output_dir else Path.cwd()
|
||||
)
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at {output_dir}",
|
||||
value=f"Images will be saved at {get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
inpaint_status = gr.Textbox(visible=False)
|
||||
|
||||
with gr.Row():
|
||||
inpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
inpaint_sendto_outpaint = gr.Button(
|
||||
@@ -527,8 +532,8 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
inpaint_custom_model,
|
||||
inpaint_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
@@ -539,13 +544,20 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
],
|
||||
outputs=[inpaint_gallery, std_output],
|
||||
outputs=[inpaint_gallery, std_output, inpaint_status],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
status_kwargs = dict(
|
||||
fn=lambda bc, bs: status_label("Inpaint", 0, bc, bs),
|
||||
inputs=[batch_count, batch_size],
|
||||
outputs=inpaint_status,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
|
||||
**kwargs
|
||||
)
|
||||
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
|
||||
@@ -3,16 +3,15 @@ import gradio as gr
|
||||
import requests
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from shark.iree_utils._common import run_cmd
|
||||
|
||||
|
||||
def get_hf_list(limit=20):
|
||||
def get_hf_list(num_of_models=20):
|
||||
path = "https://huggingface.co/api/models"
|
||||
params = {
|
||||
"search": "stable-diffusion",
|
||||
"sort": "downloads",
|
||||
"direction": "-1",
|
||||
"limit": {limit},
|
||||
"limit": {num_of_models},
|
||||
"full": "true",
|
||||
}
|
||||
response = requests.get(path, params=params)
|
||||
@@ -72,65 +71,87 @@ def get_image_from_model(model_json):
|
||||
return image
|
||||
|
||||
|
||||
hf_model_list = get_hf_list()
|
||||
civit_model_list = get_civit_list()
|
||||
|
||||
|
||||
with gr.Blocks() as model_web:
|
||||
model_source = gr.Radio(
|
||||
choices=["Hugging Face", "Civitai"],
|
||||
type="index",
|
||||
value="Hugging Face",
|
||||
label="Model Source",
|
||||
with gr.Row():
|
||||
model_source = gr.Radio(
|
||||
value=None,
|
||||
choices=["Hugging Face", "Civitai"],
|
||||
type="value",
|
||||
label="Model Source",
|
||||
)
|
||||
model_numebr = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=10,
|
||||
step=1,
|
||||
label="Number of models",
|
||||
interactive=True,
|
||||
)
|
||||
# TODO: add more filters
|
||||
get_model_btn = gr.Button(value="Get Models")
|
||||
|
||||
hf_models = gr.Dropdown(
|
||||
label="Hugging Face Model List",
|
||||
choices=None,
|
||||
value=None,
|
||||
visible=False,
|
||||
)
|
||||
with gr.Column(visible=True) as hf_block:
|
||||
for model in hf_model_list:
|
||||
with gr.Row():
|
||||
model_url = gr.Textbox(
|
||||
label="Model ID:",
|
||||
value=model["modelId"],
|
||||
lines=1,
|
||||
interactive=False,
|
||||
)
|
||||
model_info = gr.Textbox(
|
||||
value=f'Download Count: {model["downloads"]}{os.linesep}Favorite Count: {model["likes"]}',
|
||||
lines=2,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
)
|
||||
with gr.Column(visible=False) as civit_block:
|
||||
for model in civit_model_list:
|
||||
with gr.Row():
|
||||
# TODO: select and SendTo
|
||||
civit_models = gr.Gallery(
|
||||
label="Civitai Model Gallery",
|
||||
value=None,
|
||||
interactive=True,
|
||||
visible=False,
|
||||
)
|
||||
|
||||
with gr.Row(visible=False) as sendto_btns:
|
||||
modelmanager_sendto_txt2img = gr.Button(value="SendTo Txt2Img")
|
||||
modelmanager_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
modelmanager_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
modelmanager_sendto_outpaint = gr.Button(value="SendTo Outpaint")
|
||||
modelmanager_sendto_upscaler = gr.Button(value="SendTo Upscaler")
|
||||
|
||||
def get_model_list(model_source, model_numebr):
|
||||
if model_source == "Hugging Face":
|
||||
hf_model_list = get_hf_list(model_numebr)
|
||||
models = []
|
||||
for model in hf_model_list:
|
||||
# TODO: add model info
|
||||
models.append(f'{model["modelId"]}')
|
||||
return (
|
||||
gr.Dropdown.update(choices=models, visible=True),
|
||||
gr.Gallery.update(value=None, visible=False),
|
||||
gr.Row.update(visible=True),
|
||||
)
|
||||
elif model_source == "Civitai":
|
||||
civit_model_list = get_civit_list(model_numebr)
|
||||
models = []
|
||||
for model in civit_model_list:
|
||||
image = get_image_from_model(model)
|
||||
if image is None:
|
||||
continue
|
||||
model_img = Image.open(image)
|
||||
gr.Image(
|
||||
value=model_img,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
).style(width=300, height=300)
|
||||
with gr.Column():
|
||||
gr.Textbox(
|
||||
label=f'{model["modelName"]}',
|
||||
value=f'Rating: {model["rating"]}{os.linesep}Favorite Count: {model["favoriteCount"]}{os.linesep}Download Count: {model["downloadCount"]}{os.linesep}File Format: {model["files"][0]["metadata"]["format"]}',
|
||||
lines=4,
|
||||
)
|
||||
gr.Textbox(
|
||||
label="Download URL:",
|
||||
value=f'{model["files"][0]["downloadUrl"]}',
|
||||
lines=1,
|
||||
)
|
||||
|
||||
def update_model_list(model_source):
|
||||
if model_source:
|
||||
return gr.update(visible=False), gr.update(visible=True)
|
||||
# TODO: add model info
|
||||
models.append(
|
||||
(Image.open(image), f'{model["files"][0]["downloadUrl"]}')
|
||||
)
|
||||
return (
|
||||
gr.Dropdown.update(value=None, choices=None, visible=False),
|
||||
gr.Gallery.update(value=models, visible=True),
|
||||
gr.Row.update(visible=False),
|
||||
)
|
||||
else:
|
||||
return gr.update(visible=True), gr.update(visible=False)
|
||||
return (
|
||||
gr.Dropdown.update(value=None, choices=None, visible=False),
|
||||
gr.Gallery.update(value=None, visible=False),
|
||||
gr.Row.update(visible=False),
|
||||
)
|
||||
|
||||
model_source.change(
|
||||
fn=update_model_list,
|
||||
inputs=model_source,
|
||||
outputs=[hf_block, civit_block],
|
||||
get_model_btn.click(
|
||||
fn=get_model_list,
|
||||
inputs=[model_source, model_numebr],
|
||||
outputs=[
|
||||
hf_models,
|
||||
civit_models,
|
||||
sendto_btns,
|
||||
],
|
||||
)
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from pathlib import Path
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
import sys
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
import base64
|
||||
@@ -23,10 +21,13 @@ from apps.stable_diffusion.src import (
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generation_text_info,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
@@ -222,9 +223,11 @@ def outpaint_inf(
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Outpaint", i + 1, batch_count, batch_size
|
||||
)
|
||||
|
||||
return generated_imgs, text_output
|
||||
return generated_imgs, text_output, ""
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
@@ -317,17 +320,19 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
custom_model = gr.Dropdown(
|
||||
outpaint_custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {get_custom_model_path()})",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "None",
|
||||
else "stabilityai/stable-diffusion-2-inpainting",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ get_custom_model_files(
|
||||
custom_checkpoint_type="inpainting"
|
||||
)
|
||||
+ predefined_paint_models,
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
outpaint_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting, https://civitai.com/api/download/models/3433",
|
||||
value="",
|
||||
@@ -522,16 +527,13 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
).style(columns=[2], object_fit="contain")
|
||||
output_dir = (
|
||||
args.output_dir if args.output_dir else Path.cwd()
|
||||
)
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at {output_dir}",
|
||||
value=f"Images will be saved at {get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
outpaint_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
outpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
outpaint_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
@@ -558,8 +560,8 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
outpaint_custom_model,
|
||||
outpaint_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
@@ -570,13 +572,20 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
],
|
||||
outputs=[outpaint_gallery, std_output],
|
||||
outputs=[outpaint_gallery, std_output, outpaint_status],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
status_kwargs = dict(
|
||||
fn=lambda bc, bs: status_label("Outpaint", 0, bc, bs),
|
||||
inputs=[batch_count, batch_size],
|
||||
outputs=outpaint_status,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
|
||||
**kwargs
|
||||
)
|
||||
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
|
||||
419
apps/stable_diffusion/web/ui/outputgallery_ui.py
Normal file
@@ -0,0 +1,419 @@
|
||||
import glob
|
||||
import gradio as gr
|
||||
import os
|
||||
from PIL import Image
|
||||
|
||||
from apps.stable_diffusion.src import args
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generated_imgs_todays_subdir,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import nodlogo_loc
|
||||
from apps.stable_diffusion.web.utils.gradio_configs import (
|
||||
gradio_tmp_galleries_folder,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.metadata import displayable_metadata
|
||||
|
||||
# -- Functions for file, directory and image info querying
|
||||
|
||||
output_dir = get_generated_imgs_path()
|
||||
|
||||
|
||||
def outputgallery_filenames(subdir) -> list[str]:
|
||||
new_dir_path = os.path.join(output_dir, subdir)
|
||||
if os.path.exists(new_dir_path):
|
||||
filenames = [
|
||||
glob.glob(new_dir_path + "/" + ext)
|
||||
for ext in ("*.png", "*.jpg", "*.jpeg")
|
||||
]
|
||||
|
||||
return sorted(sum(filenames, []), key=os.path.getmtime, reverse=True)
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
def output_subdirs() -> list[str]:
|
||||
# Gets a list of subdirectories of output_dir and below, as relative paths.
|
||||
relative_paths = [
|
||||
os.path.relpath(entry[0], output_dir)
|
||||
for entry in os.walk(
|
||||
output_dir, followlinks=args.output_gallery_followlinks
|
||||
)
|
||||
]
|
||||
|
||||
# It is less confusing to always include the subdir that will take any images generated
|
||||
# today even if it doesn't exist yet
|
||||
if get_generated_imgs_todays_subdir() not in relative_paths:
|
||||
relative_paths.append(get_generated_imgs_todays_subdir())
|
||||
|
||||
# sort subdirectories so that the date-named ones we probably created in this or
|
||||
# previous sessions come first, sorted with the most recent first. Other subdirs are listed
|
||||
# after.
|
||||
generated_paths = sorted(
|
||||
[path for path in relative_paths if path.isnumeric()], reverse=True
|
||||
)
|
||||
result_paths = generated_paths + sorted(
|
||||
[
|
||||
path
|
||||
for path in relative_paths
|
||||
if (not path.isnumeric()) and path != "."
|
||||
]
|
||||
)
|
||||
|
||||
return result_paths
|
||||
|
||||
|
||||
# --- Define UI layout for Gradio
|
||||
|
||||
with gr.Blocks() as outputgallery_web:
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
|
||||
with gr.Row(elem_id="outputgallery_gallery"):
|
||||
# needed to work around gradio issue: https://github.com/gradio-app/gradio/issues/2907
|
||||
dev_null = gr.Textbox("", visible=False)
|
||||
|
||||
gallery_files = gr.State(value=[])
|
||||
subdirectory_paths = gr.State(value=[])
|
||||
|
||||
with gr.Column(scale=6):
|
||||
logo = gr.Image(
|
||||
label="Getting subdirectories...",
|
||||
value=nod_logo,
|
||||
interactive=False,
|
||||
visible=True,
|
||||
show_label=True,
|
||||
elem_id="top_logo",
|
||||
elem_classes="logo_centered",
|
||||
)
|
||||
|
||||
gallery = gr.Gallery(
|
||||
label="",
|
||||
value=gallery_files.value,
|
||||
visible=False,
|
||||
show_label=True,
|
||||
).style(grid=4)
|
||||
gallery.DEFAULT_TEMP_DIR = gradio_tmp_galleries_folder
|
||||
|
||||
with gr.Column(scale=4):
|
||||
with gr.Box():
|
||||
with gr.Row():
|
||||
with gr.Column(scale=16, min_width=160):
|
||||
subdirectories = gr.Dropdown(
|
||||
label=f"Subdirectories of {output_dir}",
|
||||
type="value",
|
||||
choices=subdirectory_paths.value,
|
||||
value="",
|
||||
interactive=True,
|
||||
).style(container=False)
|
||||
with gr.Column(
|
||||
scale=1, min_width=32, elem_id="output_refresh_button"
|
||||
):
|
||||
refresh = gr.Button(
|
||||
variant="secondary",
|
||||
value="\u21BB", # unicode clockwise arrow circle
|
||||
).style(size="sm")
|
||||
|
||||
image_columns = gr.Slider(
|
||||
label="Columns shown", value=4, minimum=1, maximum=16, step=1
|
||||
)
|
||||
outputgallery_filename = gr.Textbox(
|
||||
label="Filename", value="None", interactive=False
|
||||
).style(show_copy_button=True)
|
||||
|
||||
with gr.Accordion(
|
||||
label="Parameter Information", open=False
|
||||
) as parameters_accordian:
|
||||
image_parameters = gr.DataFrame(
|
||||
headers=["Parameter", "Value"],
|
||||
col_count=2,
|
||||
wrap=True,
|
||||
elem_classes="output_parameters_dataframe",
|
||||
value=[["Status", "No image selected"]],
|
||||
)
|
||||
|
||||
with gr.Accordion(label="Send To", open=True):
|
||||
with gr.Row():
|
||||
outputgallery_sendto_txt2img = gr.Button(
|
||||
value="Txt2Img",
|
||||
interactive=False,
|
||||
elem_classes="outputgallery_sendto",
|
||||
).style(size="sm")
|
||||
|
||||
outputgallery_sendto_img2img = gr.Button(
|
||||
value="Img2Img",
|
||||
interactive=False,
|
||||
elem_classes="outputgallery_sendto",
|
||||
).style(size="sm")
|
||||
|
||||
outputgallery_sendto_inpaint = gr.Button(
|
||||
value="Inpaint",
|
||||
interactive=False,
|
||||
elem_classes="outputgallery_sendto",
|
||||
).style(size="sm")
|
||||
|
||||
outputgallery_sendto_outpaint = gr.Button(
|
||||
value="Outpaint",
|
||||
interactive=False,
|
||||
elem_classes="outputgallery_sendto",
|
||||
).style(size="sm")
|
||||
|
||||
outputgallery_sendto_upscaler = gr.Button(
|
||||
value="Upscaler",
|
||||
interactive=False,
|
||||
elem_classes="outputgallery_sendto",
|
||||
).style(size="sm")
|
||||
|
||||
# --- Event handlers
|
||||
|
||||
def on_clear_gallery():
|
||||
return [
|
||||
gr.Gallery.update(
|
||||
value=[],
|
||||
visible=False,
|
||||
),
|
||||
gr.Image.update(
|
||||
visible=True,
|
||||
),
|
||||
]
|
||||
|
||||
def on_select_subdir(subdir) -> list:
|
||||
# evt.value is the subdirectory name
|
||||
new_images = outputgallery_filenames(subdir)
|
||||
new_label = (
|
||||
f"{len(new_images)} images in {os.path.join(output_dir, subdir)}"
|
||||
)
|
||||
return [
|
||||
new_images,
|
||||
gr.Gallery.update(
|
||||
value=new_images,
|
||||
label=new_label,
|
||||
visible=len(new_images) > 0,
|
||||
),
|
||||
gr.Image.update(
|
||||
label=new_label,
|
||||
visible=len(new_images) == 0,
|
||||
),
|
||||
]
|
||||
|
||||
def on_refresh(current_subdir: str) -> list:
|
||||
# get an up to date subdirectory list
|
||||
refreshed_subdirs = output_subdirs()
|
||||
# get the images using either the current subdirectory or the most recent valid one
|
||||
new_subdir = (
|
||||
current_subdir
|
||||
if current_subdir in refreshed_subdirs
|
||||
else refreshed_subdirs[0]
|
||||
)
|
||||
new_images = outputgallery_filenames(new_subdir)
|
||||
new_label = f"{len(new_images)} images in {os.path.join(output_dir, new_subdir)}"
|
||||
|
||||
return [
|
||||
gr.Dropdown.update(
|
||||
choices=refreshed_subdirs,
|
||||
value=new_subdir,
|
||||
),
|
||||
refreshed_subdirs,
|
||||
new_images,
|
||||
gr.Gallery.update(
|
||||
value=new_images, label=new_label, visible=len(new_images) > 0
|
||||
),
|
||||
gr.Image.update(
|
||||
label=new_label,
|
||||
visible=len(new_images) == 0,
|
||||
),
|
||||
]
|
||||
|
||||
def on_new_image(subdir, subdir_paths, status) -> list:
|
||||
# prevent error triggered when an image generates before the tab has even been selected
|
||||
subdir_paths = (
|
||||
subdir_paths
|
||||
if len(subdir_paths) > 0
|
||||
else [get_generated_imgs_todays_subdir()]
|
||||
)
|
||||
|
||||
# only update if the current subdir is the most recent one as new images only go there
|
||||
if subdir_paths[0] == subdir:
|
||||
new_images = outputgallery_filenames(subdir)
|
||||
new_label = f"{len(new_images)} images in {os.path.join(output_dir, subdir)} - {status}"
|
||||
|
||||
return [
|
||||
new_images,
|
||||
gr.Gallery.update(
|
||||
value=new_images,
|
||||
label=new_label,
|
||||
visible=len(new_images) > 0,
|
||||
),
|
||||
gr.Image.update(
|
||||
label=new_label,
|
||||
visible=len(new_images) == 0,
|
||||
),
|
||||
]
|
||||
else:
|
||||
# otherwise change nothing, (only untyped gradio gr.update() does this)
|
||||
return [gr.update(), gr.update(), gr.update()]
|
||||
|
||||
def on_select_image(images: list[str], evt: gr.SelectData) -> list:
|
||||
# evt.index is an index into the full list of filenames for the current subdirectory
|
||||
filename = images[evt.index]
|
||||
params = displayable_metadata(filename)
|
||||
|
||||
if params:
|
||||
return [
|
||||
filename,
|
||||
list(map(list, params["parameters"].items())),
|
||||
]
|
||||
|
||||
return [
|
||||
filename,
|
||||
[["Status", "No parameters found"]],
|
||||
]
|
||||
|
||||
def on_outputgallery_filename_change(filename: str) -> list:
|
||||
exists = filename != "None" and os.path.exists(filename)
|
||||
return [
|
||||
# disable or enable each of the sendto buttons based on whether an image is selected
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
]
|
||||
|
||||
# The first time our tab is selected we need to do an initial refresh to populate
|
||||
# the subdirectory select box and the images from the most recent subdirectory.
|
||||
#
|
||||
# We do it at this point rather than setting this up in the controls' definitions
|
||||
# as when you refresh the browser you always get what was *initially* set, which
|
||||
# won't include any new subdirectories or images that might have been created since
|
||||
# the application was started. Doing it this way means a browser refresh/reload
|
||||
# always gets the most up to date data.
|
||||
def on_select_tab(subdir_paths):
|
||||
if len(subdir_paths) == 0:
|
||||
return on_refresh("")
|
||||
else:
|
||||
return (
|
||||
# Change nothing, (only untyped gr.update() does this)
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
)
|
||||
|
||||
# Unfortunately as of gradio 3.22.0 gr.update against Galleries doesn't support
|
||||
# things set with .style, nor the elem_classes kwarg so we have to directly set
|
||||
# things up via JavaScript if we want the client to take notice of any of our
|
||||
# changes to the number of columns after it decides to put them back to the
|
||||
# original number when we change something
|
||||
def js_set_columns_in_browser(timeout_length):
|
||||
return f"""
|
||||
(new_cols) => {{
|
||||
setTimeout(() => {{
|
||||
required_style = "auto ".repeat(new_cols).trim();
|
||||
gallery = document.querySelector('#outputgallery_gallery .grid-container');
|
||||
if (gallery) {{
|
||||
gallery.style.gridTemplateColumns = required_style
|
||||
}}
|
||||
}}, {timeout_length});
|
||||
return []; // prevents console error from gradio
|
||||
}}
|
||||
"""
|
||||
|
||||
# --- Wire handlers up to the actions
|
||||
|
||||
# - Many actions reset the number of columns shown in the gallery on the browser end,
|
||||
# so we have to set them back to what we think they should be after the initial
|
||||
# action.
|
||||
# - None of the actions on this tab trigger inference, and we want the user to be able
|
||||
# to do them whilst other tabs have ongoing inference running. Waiting in the queue
|
||||
# behind inference jobs would mean the UI can't fully respond until the inference tasks
|
||||
# complete, hence queue=False on all of these.
|
||||
set_gallery_columns_immediate = dict(
|
||||
fn=None,
|
||||
inputs=[image_columns],
|
||||
# gradio blanks the UI on Chrome on Linux on gallery select if I don't put an output here
|
||||
outputs=[dev_null],
|
||||
_js=js_set_columns_in_browser(0),
|
||||
queue=False,
|
||||
)
|
||||
|
||||
# setting columns after selecting a gallery item needs a real timeout length for the
|
||||
# number of columns to actually be applied. Not really sure why, maybe something has
|
||||
# to finish animating?
|
||||
set_gallery_columns_delayed = dict(
|
||||
set_gallery_columns_immediate, _js=js_set_columns_in_browser(250)
|
||||
)
|
||||
|
||||
# clearing images when we need to completely change what's in the gallery avoids current
|
||||
# images being replaced piecemeal on screen, and prevents weirdness and errors if the user
|
||||
# selects an image during the replacement phase.
|
||||
clear_gallery = dict(
|
||||
fn=on_clear_gallery,
|
||||
inputs=None,
|
||||
outputs=[gallery, logo],
|
||||
queue=False,
|
||||
)
|
||||
|
||||
image_columns.change(**set_gallery_columns_immediate)
|
||||
|
||||
subdirectories.select(**clear_gallery).then(
|
||||
on_select_subdir,
|
||||
[subdirectories],
|
||||
[gallery_files, gallery, logo],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_immediate)
|
||||
|
||||
refresh.click(**clear_gallery).then(
|
||||
on_refresh,
|
||||
[subdirectories],
|
||||
[subdirectories, subdirectory_paths, gallery_files, gallery, logo],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_immediate)
|
||||
|
||||
gallery.select(
|
||||
on_select_image,
|
||||
[gallery_files],
|
||||
[outputgallery_filename, image_parameters],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_delayed)
|
||||
|
||||
outputgallery_filename.change(
|
||||
on_outputgallery_filename_change,
|
||||
[outputgallery_filename],
|
||||
[
|
||||
outputgallery_sendto_txt2img,
|
||||
outputgallery_sendto_img2img,
|
||||
outputgallery_sendto_inpaint,
|
||||
outputgallery_sendto_outpaint,
|
||||
outputgallery_sendto_upscaler,
|
||||
],
|
||||
queue=False,
|
||||
)
|
||||
|
||||
# We should have been given the .select function for our tab, so set it up
|
||||
def outputgallery_tab_select(select):
|
||||
select(
|
||||
fn=on_select_tab,
|
||||
inputs=[subdirectory_paths],
|
||||
outputs=[
|
||||
subdirectories,
|
||||
subdirectory_paths,
|
||||
gallery_files,
|
||||
gallery,
|
||||
logo,
|
||||
],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_immediate)
|
||||
|
||||
# We should have been passed a list of components on other tabs that update
|
||||
# when a new image has been generated on that tab, so set things up so the user
|
||||
# will see that new image if they are looking at today's subdirectory
|
||||
def outputgallery_watch(components: gr.Textbox):
|
||||
for component in components:
|
||||
component.change(
|
||||
on_new_image,
|
||||
inputs=[subdirectories, subdirectory_paths, component],
|
||||
outputs=[gallery_files, gallery, logo],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_immediate)
|
||||
185
apps/stable_diffusion/web/ui/stablelm_ui.py
Normal file
@@ -0,0 +1,185 @@
|
||||
import gradio as gr
|
||||
import torch
|
||||
import os
|
||||
from pathlib import Path
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import available_devices
|
||||
|
||||
start_message = """<|SYSTEM|># StableLM Tuned (Alpha version)
|
||||
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
|
||||
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
|
||||
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
|
||||
- StableLM will refuse to participate in anything that could harm a human.
|
||||
"""
|
||||
|
||||
|
||||
def user(message, history):
|
||||
# Append the user's message to the conversation history
|
||||
return "", history + [[message, ""]]
|
||||
|
||||
|
||||
sharkModel = 0
|
||||
sharded_model = 0
|
||||
vicuna_model = 0
|
||||
|
||||
|
||||
start_message_vicuna = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
|
||||
past_key_values = None
|
||||
|
||||
|
||||
def chat(curr_system_message, history, model, device, precision):
|
||||
print(f"In chat for {model}")
|
||||
global sharded_model
|
||||
global past_key_values
|
||||
global vicuna_model
|
||||
if "vicuna" in model:
|
||||
from apps.language_models.src.pipelines.vicuna_pipeline import (
|
||||
Vicuna,
|
||||
)
|
||||
|
||||
curr_system_message = start_message_vicuna
|
||||
if vicuna_model == 0:
|
||||
first_vic_vmfb_path = Path("first_vicuna.vmfb")
|
||||
second_vic_vmfb_path = Path("second_vicuna.vmfb")
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
vicuna_model = Vicuna(
|
||||
"vicuna",
|
||||
hf_model_path=model,
|
||||
device=device,
|
||||
precision=precision,
|
||||
first_vicuna_vmfb_path=first_vic_vmfb_path,
|
||||
second_vicuna_vmfb_path=second_vic_vmfb_path,
|
||||
)
|
||||
messages = curr_system_message + "".join(
|
||||
[
|
||||
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
|
||||
for item in history
|
||||
]
|
||||
)
|
||||
prompt = messages.strip()
|
||||
print("prompt = ", prompt)
|
||||
sentence = vicuna_model.generate(prompt)
|
||||
|
||||
partial_text = ""
|
||||
for new_text in sentence.split(" "):
|
||||
# print(new_text)
|
||||
partial_text += new_text + " "
|
||||
history[-1][1] = partial_text
|
||||
# Yield the updated conversation history so the chatbot display refreshes incrementally
|
||||
yield history
|
||||
history[-1][1] = sentence
|
||||
return history
|
||||
|
||||
# else Model is StableLM
|
||||
global sharkModel
|
||||
from apps.language_models.src.pipelines.stablelm_pipeline import (
|
||||
SharkStableLM,
|
||||
)
|
||||
|
||||
if sharkModel == 0:
|
||||
# max_new_tokens=512
|
||||
shark_slm = SharkStableLM(
|
||||
"StableLM"
|
||||
) # pass elements from UI as required
|
||||
|
||||
# Construct the input message string for the model by concatenating the current system message and conversation history
|
||||
if len(curr_system_message.split()) > 160:
|
||||
print("clearing context")
|
||||
curr_system_message = start_message
|
||||
messages = curr_system_message + "".join(
|
||||
[
|
||||
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
|
||||
for item in history
|
||||
]
|
||||
)
|
||||
|
||||
generate_kwargs = dict(prompt=messages)
|
||||
|
||||
words_list = shark_slm.generate(**generate_kwargs)
|
||||
|
||||
partial_text = ""
|
||||
for new_text in words_list:
|
||||
# print(new_text)
|
||||
partial_text += new_text
|
||||
history[-1][1] = partial_text
|
||||
# Yield the updated conversation history so the chatbot display refreshes incrementally
|
||||
yield history
|
||||
return words_list
|
||||
|
||||
|
||||
with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
with gr.Row():
|
||||
model = gr.Dropdown(
|
||||
label="Select Model",
|
||||
value="TheBloke/vicuna-7B-1.1-HF",
|
||||
choices=[
|
||||
"stabilityai/stablelm-tuned-alpha-3b",
|
||||
"TheBloke/vicuna-7B-1.1-HF",
|
||||
],
|
||||
)
|
||||
supported_devices = [
|
||||
device for device in available_devices if "cuda" in device
|
||||
]
|
||||
enabled = len(supported_devices) > 0
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=supported_devices[0]
|
||||
if enabled
|
||||
else "Only CUDA Supported for now",
|
||||
choices=supported_devices,
|
||||
interactive=enabled,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value="fp32",
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
chatbot = gr.Chatbot().style(height=500)
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
msg = gr.Textbox(
|
||||
label="Chat Message Box",
|
||||
placeholder="Chat Message Box",
|
||||
show_label=False,
|
||||
interactive=enabled,
|
||||
).style(container=False)
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
submit = gr.Button("Submit", interactive=enabled)
|
||||
stop = gr.Button("Stop", interactive=enabled)
|
||||
clear = gr.Button("Clear", interactive=enabled)
|
||||
system_msg = gr.Textbox(
|
||||
start_message, label="System Message", interactive=False, visible=False
|
||||
)
|
||||
|
||||
submit_event = msg.submit(
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, model, device, precision],
|
||||
outputs=[chatbot],
|
||||
queue=True,
|
||||
)
|
||||
submit_click_event = submit.click(
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, model, device, precision],
|
||||
outputs=[chatbot],
|
||||
queue=True,
|
||||
)
|
||||
stop.click(
|
||||
fn=None,
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
cancels=[submit_event, submit_click_event],
|
||||
queue=False,
|
||||
)
|
||||
clear.click(lambda: None, None, [chatbot], queue=False)
|
||||
@@ -1,4 +1,3 @@
|
||||
from pathlib import Path
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
@@ -17,7 +16,8 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
predefined_models,
|
||||
cancel_sd,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.png_metadata import import_png_metadata
|
||||
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
Text2ImagePipeline,
|
||||
@@ -27,7 +27,10 @@ from apps.stable_diffusion.src import (
|
||||
save_output_img,
|
||||
prompt_examples,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generation_text_info,
|
||||
)
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
@@ -202,9 +205,11 @@ def txt2img_inf(
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Text-to-Image", i + 1, batch_count, batch_size
|
||||
)
|
||||
|
||||
return generated_imgs, text_output
|
||||
return generated_imgs, text_output, ""
|
||||
|
||||
|
||||
def encode_pil_to_base64(images):
|
||||
@@ -281,17 +286,17 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=10):
|
||||
with gr.Row():
|
||||
custom_model = gr.Dropdown(
|
||||
txt2img_custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {get_custom_model_path()})",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "None",
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_models,
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
txt2img_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
|
||||
value="",
|
||||
@@ -308,7 +313,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
+ get_custom_model_files("vae"),
|
||||
)
|
||||
with gr.Column(scale=1, min_width=170):
|
||||
png_info_img = gr.Image(
|
||||
txt2img_png_info_img = gr.Image(
|
||||
label="Import PNG info",
|
||||
elem_id="txt2img_prompt_image",
|
||||
type="pil",
|
||||
@@ -469,16 +474,13 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
).style(columns=[2], object_fit="contain")
|
||||
output_dir = (
|
||||
args.output_dir if args.output_dir else Path.cwd()
|
||||
)
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at {output_dir}",
|
||||
value=f"Images will be saved at {get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
txt2img_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
txt2img_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
txt2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
@@ -502,8 +504,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
@@ -514,22 +516,30 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
],
|
||||
outputs=[txt2img_gallery, std_output],
|
||||
outputs=[txt2img_gallery, std_output, txt2img_status],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
status_kwargs = dict(
|
||||
fn=lambda bc, bs: status_label("Text-to-Image", 0, bc, bs),
|
||||
inputs=[batch_count, batch_size],
|
||||
outputs=txt2img_status,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
|
||||
**kwargs
|
||||
)
|
||||
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
png_info_img.change(
|
||||
txt2img_png_info_img.change(
|
||||
fn=import_png_metadata,
|
||||
inputs=[
|
||||
png_info_img,
|
||||
txt2img_png_info_img,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
steps,
|
||||
@@ -538,11 +548,11 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
seed,
|
||||
width,
|
||||
height,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
],
|
||||
outputs=[
|
||||
png_info_img,
|
||||
txt2img_png_info_img,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
steps,
|
||||
@@ -551,7 +561,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
seed,
|
||||
width,
|
||||
height,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
],
|
||||
)
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from pathlib import Path
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
import sys
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
import base64
|
||||
@@ -17,16 +15,16 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
predefined_upscaler_models,
|
||||
cancel_sd,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
UpscalerPipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
|
||||
from apps.stable_diffusion.src.utils import get_generated_imgs_path
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
@@ -66,6 +64,9 @@ def upscaler_inf(
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
@@ -202,13 +203,24 @@ def upscaler_inf(
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
high_res_img.paste(upscaled_image[0], (j * 4, i * 4))
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
high_res_img.paste(upscaled_image[0], (j * 4, i * 4))
|
||||
|
||||
save_output_img(high_res_img, img_seed, extra_info)
|
||||
generated_imgs.append(high_res_img)
|
||||
seeds.append(img_seed)
|
||||
global_obj.get_sd_obj().log += "\n"
|
||||
yield generated_imgs, global_obj.get_sd_obj().log
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(high_res_img, img_seed, extra_info)
|
||||
generated_imgs.append(high_res_img)
|
||||
seeds.append(img_seed)
|
||||
global_obj.get_sd_obj().log += "\n"
|
||||
yield generated_imgs, global_obj.get_sd_obj().log, status_label(
|
||||
"Upscaler", current_batch + 1, batch_count, batch_size
|
||||
)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
@@ -220,7 +232,7 @@ def upscaler_inf(
|
||||
text_output += global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
yield generated_imgs, text_output
|
||||
yield generated_imgs, text_output, ""
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
@@ -309,17 +321,19 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
custom_model = gr.Dropdown(
|
||||
upscaler_custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {get_custom_model_path()})",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "None",
|
||||
else "stabilityai/stable-diffusion-x4-upscaler",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ get_custom_model_files(
|
||||
custom_checkpoint_type="upscaler"
|
||||
)
|
||||
+ predefined_upscaler_models,
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
upscaler_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
|
||||
value="",
|
||||
@@ -493,16 +507,14 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
).style(columns=[2], object_fit="contain")
|
||||
output_dir = (
|
||||
args.output_dir if args.output_dir else Path.cwd()
|
||||
)
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at {output_dir}",
|
||||
value=f"Images will be saved at {get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
upscaler_status = gr.Textbox(visible=False)
|
||||
|
||||
with gr.Row():
|
||||
upscaler_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
upscaler_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
@@ -525,8 +537,8 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
upscaler_custom_model,
|
||||
upscaler_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
@@ -537,13 +549,21 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
],
|
||||
outputs=[upscaler_gallery, std_output],
|
||||
outputs=[upscaler_gallery, std_output, upscaler_status],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
|
||||
status_kwargs = dict(
|
||||
fn=lambda bc, bs: status_label("Upscaler", 0, bc, bs),
|
||||
inputs=[batch_count, batch_size],
|
||||
outputs=upscaler_status,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
|
||||
**kwargs
|
||||
)
|
||||
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
@@ -72,30 +72,36 @@ def resource_path(relative_path):
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
|
||||
def create_custom_models_folders():
|
||||
dir = ["vae", "lora"]
|
||||
if not args.ckpt_dir:
|
||||
dir.insert(0, "models")
|
||||
else:
|
||||
if not os.path.isdir(args.ckpt_dir):
|
||||
sys.exit(
|
||||
f"Invalid --ckpt_dir argument, {args.ckpt_dir} folder does not exist."
|
||||
)
|
||||
for root in dir:
|
||||
get_custom_model_path(root).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def get_custom_model_path(model="models"):
|
||||
# If `--ckpt_dir` is provided it'd override the hierarchical folder
|
||||
# structure in WebUI :-
|
||||
# model
|
||||
# models or args.ckpt_dir
|
||||
# |___lora
|
||||
# |___vae
|
||||
sub_folder = "" if model == "models" else model
|
||||
if args.ckpt_dir:
|
||||
return Path(args.ckpt_dir)
|
||||
match model:
|
||||
case "models":
|
||||
return Path(Path.cwd(), "models")
|
||||
case "vae":
|
||||
return Path(Path.cwd(), "models/vae")
|
||||
case "lora":
|
||||
return Path(Path.cwd(), "models/lora")
|
||||
case _:
|
||||
return ""
|
||||
return Path(Path(args.ckpt_dir), sub_folder)
|
||||
else:
|
||||
return Path(Path.cwd(), "models/" + sub_folder)
|
||||
|
||||
|
||||
def get_custom_model_pathfile(custom_model_name, model="models"):
|
||||
return os.path.join(get_custom_model_path(model), custom_model_name)
|
||||
|
||||
|
||||
def get_custom_model_files(model="models"):
|
||||
def get_custom_model_files(model="models", custom_checkpoint_type=""):
|
||||
ckpt_files = []
|
||||
file_types = custom_model_filetypes
|
||||
if model == "lora":
|
||||
@@ -107,6 +113,28 @@ def get_custom_model_files(model="models"):
|
||||
os.path.join(get_custom_model_path(model), extn)
|
||||
)
|
||||
]
|
||||
match custom_checkpoint_type:
|
||||
case "inpainting":
|
||||
files = [
|
||||
val
|
||||
for val in files
|
||||
if val.endswith("inpainting" + extn.removeprefix("*"))
|
||||
]
|
||||
case "upscaler":
|
||||
files = [
|
||||
val
|
||||
for val in files
|
||||
if val.endswith("upscaler" + extn.removeprefix("*"))
|
||||
]
|
||||
case _:
|
||||
files = [
|
||||
val
|
||||
for val in files
|
||||
if not (
|
||||
val.endswith("inpainting" + extn.removeprefix("*"))
|
||||
or val.endswith("upscaler" + extn.removeprefix("*"))
|
||||
)
|
||||
]
|
||||
ckpt_files.extend(files)
|
||||
return sorted(ckpt_files, key=str.casefold)
|
||||
|
||||
|
||||
9
apps/stable_diffusion/web/utils/common_label_calc.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# functions for generating labels used in common by tabs across the UI
|
||||
|
||||
|
||||
def status_label(tab_name, batch_index=0, batch_count=1, batch_size=1):
|
||||
if batch_index < batch_count:
|
||||
bs = f"x{batch_size}" if batch_size > 1 else ""
|
||||
return f"{tab_name} generating {batch_index+1}/{batch_count}{bs}"
|
||||
else:
|
||||
return f"{tab_name} complete"
|
||||
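For reference, a minimal sketch of the labels this helper produces; the tab names and batch settings below are example values, not anything read from the UI:

```python
from apps.stable_diffusion.web.utils.common_label_calc import status_label

# mid-run: second image of a 4-image run with batch size 2
print(status_label("Text-to-Image", batch_index=1, batch_count=4, batch_size=2))
# -> "Text-to-Image generating 2/4x2"

# finished: batch_index has caught up with batch_count
print(status_label("Outpaint", batch_index=3, batch_count=3))
# -> "Outpaint complete"
```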
@@ -1,9 +1,11 @@
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import gradio
|
||||
from os import listdir
|
||||
|
||||
gradio_tmp_imgs_folder = os.path.join(os.getcwd(), "shark_tmp/")
|
||||
gradio_tmp_galleries_folder = os.path.join(gradio_tmp_imgs_folder, "galleries")
|
||||
|
||||
|
||||
# Clear all gradio tmp images
|
||||
@@ -15,6 +17,10 @@ def clear_gradio_tmp_imgs_folder():
|
||||
if fileName.startswith("tmp") and fileName.endswith(".png"):
|
||||
os.remove(gradio_tmp_imgs_folder + fileName)
|
||||
|
||||
# Clear all gradio tmp files created by galleries
|
||||
if os.path.exists(gradio_tmp_galleries_folder):
|
||||
shutil.rmtree(gradio_tmp_galleries_folder)
|
||||
|
||||
|
||||
# Overwrite save_pil_to_file from gradio to save tmp images generated by gradio into our own tmp folder
|
||||
def save_pil_to_file(pil_image, dir=None):
|
||||
|
||||
6
apps/stable_diffusion/web/utils/metadata/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .png_metadata import (
|
||||
import_png_metadata,
|
||||
)
|
||||
from .display import (
|
||||
displayable_metadata,
|
||||
)
|
||||
31
apps/stable_diffusion/web/utils/metadata/csv_metadata.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import csv
|
||||
import os
|
||||
from .format import humanize, humanizable
|
||||
|
||||
|
||||
def csv_path(image_filename: str):
|
||||
return os.path.join(os.path.dirname(image_filename), "imgs_details.csv")
|
||||
|
||||
|
||||
def has_csv(image_filename: str) -> bool:
|
||||
return os.path.exists(csv_path(image_filename))
|
||||
|
||||
|
||||
def parse_csv(image_filename: str):
|
||||
# We use a reader instead of a DictReader here for imgs_details.csv files due to the lack of
|
||||
# headers, and then match up the return list for each row with our guess at which column format
|
||||
# the file is using.
|
||||
|
||||
# we assume the final column of the csv has the original filename with full path and match that
|
||||
# against the image_filename. We then exclude the filename from the output, hence the -1's.
|
||||
csv_filename = csv_path(image_filename)
|
||||
|
||||
matches = [
|
||||
humanize(row)
|
||||
for row in csv.reader(open(csv_filename, "r", newline=""))
|
||||
if row
|
||||
and humanizable(row)
|
||||
and os.path.basename(image_filename) in row[-1]
|
||||
]
|
||||
|
||||
return matches[0] if matches else {}
|
||||
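A minimal sketch of the headerless rows this parser expects, assuming the 9-column layout defined in format.py; every value below is made up for illustration:

```python
import csv
import os
import tempfile

from apps.stable_diffusion.web.utils.metadata.csv_metadata import parse_csv

tmp_dir = tempfile.mkdtemp()
image_path = os.path.join(tmp_dir, "astronaut_42.png")

# model, sampler, prompt, negative prompt, seed, cfg scale, precision, steps, output path
row = ["sd-2.1-base", "EulerDiscrete", "an astronaut", "blurry",
       "42", "7.5", "fp16", "50", image_path]
with open(os.path.join(tmp_dir, "imgs_details.csv"), "w", newline="") as f:
    csv.writer(f).writerow(row)

# the row is matched by length and by the image filename in its last column, then
# humanized, e.g. {"Model": "sd-2.1-base", "Sampler": "EulerDiscrete", ...}
print(parse_csv(image_path))
```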
50
apps/stable_diffusion/web/utils/metadata/display.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import json
|
||||
import os
|
||||
from PIL import Image
|
||||
from .png_metadata import parse_generation_parameters
|
||||
from .exif_metadata import has_exif, parse_exif
|
||||
from .csv_metadata import has_csv, parse_csv
|
||||
from .format import compact, humanize
|
||||
|
||||
|
||||
def displayable_metadata(image_filename: str) -> dict:
|
||||
pil_image = Image.open(image_filename)
|
||||
|
||||
# we have PNG generation parameters (preferred, as it's what the txt2img dropzone reads,
|
||||
# and we go via that for SendTo, and it is directly tied to the image)
|
||||
if "parameters" in pil_image.info:
|
||||
return {
|
||||
"source": "png",
|
||||
"parameters": compact(
|
||||
parse_generation_parameters(pil_image.info["parameters"])
|
||||
),
|
||||
}
|
||||
|
||||
# we have a matching json file (next most likely to be accurate when it's there)
|
||||
json_path = os.path.splitext(image_filename)[0] + ".json"
|
||||
if os.path.isfile(json_path):
|
||||
with open(json_path) as params_file:
|
||||
return {
|
||||
"source": "json",
|
||||
"parameters": compact(
|
||||
humanize(json.load(params_file), includes_filename=False)
|
||||
),
|
||||
}
|
||||
|
||||
# we have a CSV file so try that (can be different shapes, and it usually has no
|
||||
# headers/param names so of the things we *know* have parameters, it's the
|
||||
# last resort)
|
||||
if has_csv(image_filename):
|
||||
params = parse_csv(image_filename)
|
||||
if params: # we might not have found the filename in the csv
|
||||
return {
|
||||
"source": "csv",
|
||||
"parameters": compact(params), # already humanized
|
||||
}
|
||||
|
||||
# EXIF data, probably a .jpeg, may well not include parameters, but at least it's *something*
|
||||
if has_exif(image_filename):
|
||||
return {"source": "exif", "parameters": parse_exif(pil_image)}
|
||||
|
||||
# we've got nothing
|
||||
return None
|
||||
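A minimal usage sketch (the image path is hypothetical); this is the lookup the output gallery's "Parameter Information" panel relies on:

```python
from apps.stable_diffusion.web.utils.metadata import displayable_metadata

info = displayable_metadata("generated_imgs/20230613/astronaut_42.png")
if info is None:
    print("No parameters found")   # nothing embedded in or stored beside the image
else:
    print(info["source"])          # "png", "json", "csv" or "exif"
    for name, value in info["parameters"].items():
        print(f"{name}: {value}")
```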
52
apps/stable_diffusion/web/utils/metadata/exif_metadata.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from PIL import Image
|
||||
from PIL.ExifTags import Base as EXIFKeys, TAGS, IFD, GPSTAGS
|
||||
|
||||
|
||||
def has_exif(image_filename: str) -> bool:
|
||||
return True if Image.open(image_filename).getexif() else False
|
||||
|
||||
|
||||
def parse_exif(pil_image: Image) -> dict:
|
||||
img_exif = pil_image.getexif()
|
||||
|
||||
# See this stackoverflow answer for where most this comes from: https://stackoverflow.com/a/75357594
|
||||
# I did try to use the exif library but it broke just as much as my initial attempt at this (albeit I
|
||||
# I was probably using it wrong) so I reverted back to using PIL with more filtering and saved a
|
||||
# dependency
|
||||
exif_tags = {
|
||||
TAGS.get(key, key): str(val)
|
||||
for (key, val) in img_exif.items()
|
||||
if key in TAGS
|
||||
and key not in (EXIFKeys.ExifOffset, EXIFKeys.GPSInfo)
|
||||
and val
|
||||
and (not isinstance(val, bytes))
|
||||
and (not str(val).isspace())
|
||||
}
|
||||
|
||||
def try_get_ifd(ifd_id):
|
||||
try:
|
||||
return img_exif.get_ifd(ifd_id).items()
|
||||
except KeyError:
|
||||
return {}
|
||||
|
||||
ifd_tags = {
|
||||
TAGS.get(key, key): str(val)
|
||||
for ifd_id in IFD
|
||||
for (key, val) in try_get_ifd(ifd_id)
|
||||
if ifd_id != IFD.GPSInfo
|
||||
and key in TAGS
|
||||
and val
|
||||
and (not isinstance(val, bytes))
|
||||
and (not str(val).isspace())
|
||||
}
|
||||
|
||||
gps_tags = {
|
||||
GPSTAGS.get(key, key): str(val)
|
||||
for (key, val) in try_get_ifd(IFD.GPSInfo)
|
||||
if key in GPSTAGS
|
||||
and val
|
||||
and (not isinstance(val, bytes))
|
||||
and (not str(val).isspace())
|
||||
}
|
||||
|
||||
return {**exif_tags, **ifd_tags, **gps_tags}
|
||||
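A minimal sketch of the EXIF fallback used on its own ("photo.jpg" is a hypothetical file):

```python
from PIL import Image

from apps.stable_diffusion.web.utils.metadata.exif_metadata import has_exif, parse_exif

if has_exif("photo.jpg"):
    # flat {tag name: string} dict drawn from the main EXIF tags, the sub-IFDs and
    # the GPS IFD, with binary payloads and blank values filtered out
    for tag, value in parse_exif(Image.open("photo.jpg")).items():
        print(f"{tag}: {value}")
```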
115
apps/stable_diffusion/web/utils/metadata/format.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# As SHARK has evolved, more columns have been added to imgs_details.csv. However, since
|
||||
# no version of the CSV has any headers (yet) we don't actually have anything within the
|
||||
# file that tells us which parameter each column is for. So this is a list of known patterns
|
||||
# indexed by length which is what we're going to have to use to guess which columns are the
|
||||
# right ones for the file we're looking at.
|
||||
|
||||
# The same ordering is used for JSON, but these do have key names, however they are not very
|
||||
# human friendly, nor do they match up with what is written to the .png headers
|
||||
|
||||
# So these are functions to try and get something consistent out the raw input from all
|
||||
# these sources
|
||||
|
||||
PARAMS_FORMATS = {
|
||||
9: {
|
||||
"VARIANT": "Model",
|
||||
"SCHEDULER": "Sampler",
|
||||
"PROMPT": "Prompt",
|
||||
"NEG_PROMPT": "Negative prompt",
|
||||
"SEED": "Seed",
|
||||
"CFG_SCALE": "CFG scale",
|
||||
"PRECISION": "Precision",
|
||||
"STEPS": "Steps",
|
||||
"OUTPUT": "Filename",
|
||||
},
|
||||
10: {
|
||||
"MODEL": "Model",
|
||||
"VARIANT": "Variant",
|
||||
"SCHEDULER": "Sampler",
|
||||
"PROMPT": "Prompt",
|
||||
"NEG_PROMPT": "Negative prompt",
|
||||
"SEED": "Seed",
|
||||
"CFG_SCALE": "CFG scale",
|
||||
"PRECISION": "Precision",
|
||||
"STEPS": "Steps",
|
||||
"OUTPUT": "Filename",
|
||||
},
|
||||
12: {
|
||||
"VARIANT": "Model",
|
||||
"SCHEDULER": "Sampler",
|
||||
"PROMPT": "Prompt",
|
||||
"NEG_PROMPT": "Negative prompt",
|
||||
"SEED": "Seed",
|
||||
"CFG_SCALE": "CFG scale",
|
||||
"PRECISION": "Precision",
|
||||
"STEPS": "Steps",
|
||||
"HEIGHT": "Height",
|
||||
"WIDTH": "Width",
|
||||
"MAX_LENGTH": "Max Length",
|
||||
"OUTPUT": "Filename",
|
||||
},
|
||||
}
|
||||
|
||||
PARAMS_FORMAT_LONGEST = PARAMS_FORMATS[max(PARAMS_FORMATS.keys())]
|
||||
|
||||
|
||||
def compact(metadata: dict) -> dict:
|
||||
# we don't want to alter the original dictionary
|
||||
result = dict(metadata)
|
||||
|
||||
# discard the filename because we should already have it
|
||||
if result.keys() & {"Filename"}:
|
||||
result.pop("Filename")
|
||||
|
||||
# make showing the sizes more compact by using only one line each
|
||||
if result.keys() & {"Size-1", "Size-2"}:
|
||||
result["Size"] = f"{result.pop('Size-1')}x{result.pop('Size-2')}"
|
||||
elif result.keys() & {"Height", "Width"}:
|
||||
result["Size"] = f"{result.pop('Height')}x{result.pop('Width')}"
|
||||
|
||||
if result.keys() & {"Hires resize-1", "Hires resize-2"}:  # need both hires dimensions before popping them
|
||||
hires_y = result.pop("Hires resize-1")
|
||||
hires_x = result.pop("Hires resize-2")
|
||||
|
||||
if hires_x == 0 and hires_y == 0:
|
||||
result["Hires resize"] = "None"
|
||||
else:
|
||||
result["Hires resize"] = f"{hires_y}x{hires_x}"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def humanizable(metadata: dict | list[str], includes_filename=True) -> dict:
|
||||
lookup_key = len(metadata) + (0 if includes_filename else 1)
|
||||
return lookup_key in PARAMS_FORMATS.keys()
|
||||
|
||||
|
||||
def humanize(metadata: dict | list[str], includes_filename=True) -> dict:
|
||||
lookup_key = len(metadata) + (0 if includes_filename else 1)
|
||||
|
||||
# For lists we can only work based on the length, we have no other information
|
||||
if isinstance(metadata, list):
|
||||
if humanizable(metadata, includes_filename):
|
||||
return dict(zip(PARAMS_FORMATS[lookup_key].values(), metadata))
|
||||
else:
|
||||
raise KeyError(
|
||||
f"Humanize could not find the format for a parameter list of length {len(metadata)}"
|
||||
)
|
||||
|
||||
# For dictionaries we try to use the matching length parameter format if
|
||||
# available, otherwise we use the longest. Then we swap keys in the
|
||||
# metadata that match keys in the format for the friendlier name that we
|
||||
# have set in the format value
|
||||
if isinstance(metadata, dict):
|
||||
if humanizable(metadata, includes_filename):
|
||||
format = PARAMS_FORMATS[lookup_key]
|
||||
else:
|
||||
format = PARAMS_FORMAT_LONGEST
|
||||
|
||||
return {
|
||||
format[key]: value
|
||||
for (key, value) in metadata.items()
|
||||
if key in format.keys()
|
||||
}
|
||||
|
||||
raise TypeError("Can only humanize parameter lists or dictionaries")
|
||||
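A minimal sketch of humanize() and compact() applied to a JSON-style parameter dict; the keys follow PARAMS_FORMATS above and the values are invented:

```python
from apps.stable_diffusion.web.utils.metadata.format import compact, humanize

raw = {
    "MODEL": "stabilityai/stable-diffusion-2-1-base",
    "VARIANT": "base",
    "SCHEDULER": "EulerDiscrete",
    "PROMPT": "an astronaut riding a horse",
    "NEG_PROMPT": "blurry",
    "SEED": 42,
    "CFG_SCALE": 7.5,
    "PRECISION": "fp16",
    "STEPS": 50,
}

# JSON sidecars carry no filename column, so includes_filename=False bumps the
# lookup length from 9 to 10 and selects the 10-column format above.
friendly = humanize(raw, includes_filename=False)
print(compact(friendly))
# -> {"Model": "stabilityai/...", "Variant": "base", "Sampler": "EulerDiscrete", ...}
```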
@@ -40,7 +40,7 @@ cmake --build build/
|
||||
*Prepare the model*
|
||||
```bash
|
||||
wget https://storage.googleapis.com/shark_tank/latest/resnet50_tf/resnet50_tf.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 resnet50_tf.mlir -o resnet50_tf.vmfb
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 resnet50_tf.mlir -o resnet50_tf.vmfb
|
||||
```
|
||||
*Prepare the input*
|
||||
|
||||
@@ -65,18 +65,18 @@ A tool for benchmarking other models is built and can be invoked with a command
|
||||
see `./build/vulkan_gui/iree-vulkan-gui --help` for an explanation on the function input. For example, stable diffusion unet can be tested with the following commands:
|
||||
```bash
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/stable_diff_tf.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 stable_diff_tf.mlir -o stable_diff_tf.vmfb
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 stable_diff_tf.mlir -o stable_diff_tf.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=2x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32
|
||||
```
|
||||
VAE and Autoencoder are also available
|
||||
```bash
|
||||
# VAE
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/vae_tf/vae.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 vae.mlir -o vae.vmfb
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 vae.mlir -o vae.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x4x64x64xf32
|
||||
|
||||
# CLIP Autoencoder
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/clip_tf/clip_autoencoder.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 clip_autoencoder.mlir -o clip_autoencoder.vmfb
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 clip_autoencoder.mlir -o clip_autoencoder.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x77xi32 --function_input=1x77xi32
|
||||
```
|
||||
|
||||
@@ -21,7 +21,7 @@ endif()
|
||||
# Compile mnist.mlir to mnist.vmfb.
|
||||
set(_COMPILE_TOOL_EXECUTABLE $<TARGET_FILE:iree-compile>)
|
||||
set(_COMPILE_ARGS)
|
||||
list(APPEND _COMPILE_ARGS "--iree-input-type=mhlo")
|
||||
list(APPEND _COMPILE_ARGS "--iree-input-type=auto")
|
||||
list(APPEND _COMPILE_ARGS "--iree-hal-target-backends=llvm-cpu")
|
||||
list(APPEND _COMPILE_ARGS "${IREE_SOURCE_DIR}/samples/models/mnist.mlir")
|
||||
list(APPEND _COMPILE_ARGS "-o")
|
||||
|
||||
@@ -19,7 +19,7 @@ transformers
|
||||
diffusers @ git+https://github.com/huggingface/diffusers@e47459c80f6f6a5a1c19d32c3fd74edf94f47aa2
|
||||
scipy
|
||||
ftfy
|
||||
gradio
|
||||
gradio==3.22.0
|
||||
altair
|
||||
omegaconf
|
||||
safetensors
|
||||
@@ -28,6 +28,7 @@ scikit-image
|
||||
pytorch_lightning # for runwayml models
|
||||
tk
|
||||
pywebview
|
||||
sentencepiece
|
||||
|
||||
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
|
||||
pefile
|
||||
|
||||
@@ -2,9 +2,10 @@
# Sets up a venv suitable for running samples.
# e.g:
# ./setup_venv.sh #setup a default $PYTHON3 shark.venv
# Environment Variables by the script.
# Environment variables used by the script.
# PYTHON=$PYTHON3.10 ./setup_venv.sh #pass a version of $PYTHON to use
# VENV_DIR=myshark.venv #create a venv called myshark.venv
# SKIP_VENV=1 #Don't create and activate a Python venv. Use the current environment.
# USE_IREE=1 #use stock IREE instead of Nod.ai's SHARK build
# IMPORTER=1 #Install importer deps
# BENCHMARK=1 #Install benchmark deps
@@ -26,15 +27,17 @@ PYTHON_VERSION_X_Y=`${PYTHON} -c 'import sys; version=sys.version_info[:2]; prin
echo "Python: $PYTHON"
echo "Python version: $PYTHON_VERSION_X_Y"

if [[ -z "${CONDA_PREFIX}" ]]; then
# Not a conda env. So create a new VENV dir
VENV_DIR=${VENV_DIR:-shark.venv}
echo "Using pip venv.. Setting up venv dir: $VENV_DIR"
$PYTHON -m venv "$VENV_DIR" || die "Could not create venv."
source "$VENV_DIR/bin/activate" || die "Could not activate venv"
PYTHON="$(which python3)"
else
echo "Found conda env $CONDA_DEFAULT_ENV. Running pip install inside the conda env"
if [[ "$SKIP_VENV" != "1" ]]; then
if [[ -z "${CONDA_PREFIX}" ]]; then
# Not a conda env. So create a new VENV dir
VENV_DIR=${VENV_DIR:-shark.venv}
echo "Using pip venv.. Setting up venv dir: $VENV_DIR"
$PYTHON -m venv "$VENV_DIR" || die "Could not create venv."
source "$VENV_DIR/bin/activate" || die "Could not activate venv"
PYTHON="$(which python3)"
else
echo "Found conda env $CONDA_DEFAULT_ENV. Running pip install inside the conda env"
fi
fi

Red=`tput setaf 1`
@@ -147,8 +150,7 @@ if [[ ! -z "${ONNX}" ]]; then
fi
fi

if [[ -z "${CONDA_PREFIX}" ]]; then
if [[ -z "${CONDA_PREFIX}" && "$SKIP_VENV" != "1" ]]; then
echo "${Green}Before running examples activate venv with:"
echo " ${Green}source $VENV_DIR/bin/activate"
fi
shark/examples/shark_inference/minilm_jax.py (new file, 73 lines)
@@ -0,0 +1,73 @@
from transformers import AutoTokenizer, FlaxAutoModel
import torch
import jax
from typing import Union, Dict, List, Any
import numpy as np
from shark.shark_inference import SharkInference
import io

NumpyTree = Union[np.ndarray, Dict[str, np.ndarray], List[np.ndarray]]


def convert_torch_tensor_tree_to_numpy(
    tree: Union[torch.tensor, Dict[str, torch.tensor], List[torch.tensor]]
) -> NumpyTree:
    return jax.tree_util.tree_map(
        lambda torch_tensor: torch_tensor.cpu().detach().numpy(), tree
    )


def convert_int64_to_int32(tree: NumpyTree) -> NumpyTree:
    return jax.tree_util.tree_map(
        lambda tensor: np.array(tensor, dtype=np.int32)
        if tensor.dtype == np.int64
        else tensor,
        tree,
    )


def get_sample_input():
    tokenizer = AutoTokenizer.from_pretrained(
        "microsoft/MiniLM-L12-H384-uncased"
    )
    inputs_torch = tokenizer("Hello, World!", return_tensors="pt")
    return convert_int64_to_int32(
        convert_torch_tensor_tree_to_numpy(inputs_torch.data)
    )


def get_jax_model():
    return FlaxAutoModel.from_pretrained("microsoft/MiniLM-L12-H384-uncased")


def export_jax_to_mlir(jax_model: Any, sample_input: NumpyTree):
    model_mlir = jax.jit(jax_model).lower(**sample_input).compiler_ir()
    byte_stream = io.BytesIO()
    model_mlir.operation.write_bytecode(file=byte_stream)
    return byte_stream.getvalue()


def assert_array_list_allclose(x, y, *args, **kwargs):
    assert len(x) == len(y)
    for a, b in zip(x, y):
        np.testing.assert_allclose(
            np.asarray(a), np.asarray(b), *args, **kwargs
        )


sample_input = get_sample_input()
jax_model = get_jax_model()
mlir = export_jax_to_mlir(jax_model, sample_input)

# Compile and load module.
shark_inference = SharkInference(mlir_module=mlir, mlir_dialect="mhlo")
shark_inference.compile()

# Run main function.
result = shark_inference("main", jax.tree_util.tree_flatten(sample_input)[0])

# Run JAX model.
reference_result = jax.tree_util.tree_flatten(jax_model(**sample_input))[0]

# Verify result.
assert_array_list_allclose(result, reference_result, atol=1e-5)
@@ -0,0 +1,6 @@
flax
jax[cpu]
nodai-SHARK
orbax
transformers
torch
@@ -70,11 +70,11 @@ mlir_model, func_name, inputs, golden_out = download_model(
"resnet50", frontend="torch"
)

shark_module = SharkInference(mlir_model, func_name, mlir_dialect="linalg")
shark_module = SharkInference(mlir_model, mlir_dialect="linalg")
shark_module.compile()
path = shark_module.save_module()
shark_module.load_module(path)
result = shark_module.forward((img.detach().numpy(),))
result = shark_module("forward", (img.detach().numpy(),))

print("The top 3 results obtained via shark_runner is:")
print(top3_possibilities(torch.from_numpy(result)))
@@ -45,10 +45,15 @@ def run_cmd(cmd, debug=False):

def iree_device_map(device):
uri_parts = device.split("://", 2)
iree_driver = (
_IREE_DEVICE_MAP[uri_parts[0]]
if uri_parts[0] in _IREE_DEVICE_MAP
else uri_parts[0]
)
if len(uri_parts) == 1:
return _IREE_DEVICE_MAP[uri_parts[0]]
return iree_driver
else:
return f"{_IREE_DEVICE_MAP[uri_parts[0]]}://{uri_parts[1]}"
return f"{iree_driver}://{uri_parts[1]}"


def get_supported_device_list():
@@ -57,6 +62,8 @@ def get_supported_device_list():

_IREE_DEVICE_MAP = {
"cpu": "local-task",
"cpu-task": "local-task",
"cpu-sync": "local-sync",
"cuda": "cuda",
"vulkan": "vulkan",
"metal": "vulkan",
@@ -68,11 +75,13 @@ _IREE_DEVICE_MAP = {
def iree_target_map(device):
if "://" in device:
device = device.split("://")[0]
return _IREE_TARGET_MAP[device]
return _IREE_TARGET_MAP[device] if device in _IREE_TARGET_MAP else device


_IREE_TARGET_MAP = {
"cpu": "llvm-cpu",
"cpu-task": "llvm-cpu",
"cpu-sync": "llvm-cpu",
"cuda": "cuda",
"vulkan": "vulkan",
"metal": "vulkan",
@@ -110,10 +119,8 @@ def check_device_drivers(device):
subprocess.check_output("rocminfo")
except Exception:
return True
# Unknown device.
else:
return True

# Unknown device. We assume drivers are installed.
return False
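For reference, a minimal sketch of how the reworked device mapping behaves; the import path is taken from the test files elsewhere in this diff, and the asserted values simply restate the tables above.

```python
from shark.iree_utils._common import iree_device_map, iree_target_map

# Plain device names resolve through the tables; unknown names pass through.
assert iree_device_map("cpu") == "local-task"
assert iree_target_map("cpu-sync") == "llvm-cpu"

# A "://N" suffix selects a device index and is preserved by the driver map,
# while the target map only ever looks at the part before "://".
assert iree_device_map("vulkan://1") == "vulkan://1"
assert iree_device_map("metal://0") == "vulkan://0"
assert iree_target_map("vulkan://1") == "vulkan"
```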
@@ -23,6 +23,7 @@ import re
|
||||
|
||||
# Get the iree-compile arguments given device.
|
||||
def get_iree_device_args(device, extra_args=[]):
|
||||
print("Configuring for device:" + device)
|
||||
device_uri = device.split("://")
|
||||
if len(device_uri) > 1:
|
||||
if device_uri[0] not in ["vulkan"]:
|
||||
@@ -30,6 +31,9 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
f"Specific device selection only supported for vulkan now."
|
||||
f"Proceeding with {device} as device."
|
||||
)
|
||||
device_num = device_uri[1]
|
||||
else:
|
||||
device_num = 0
|
||||
|
||||
if device_uri[0] == "cpu":
|
||||
from shark.iree_utils.cpu_utils import get_iree_cpu_args
|
||||
@@ -42,7 +46,9 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
if device_uri[0] in ["metal", "vulkan"]:
|
||||
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
|
||||
|
||||
return get_iree_vulkan_args(extra_args=extra_args)
|
||||
return get_iree_vulkan_args(
|
||||
device_num=device_num, extra_args=extra_args
|
||||
)
|
||||
if device_uri[0] == "rocm":
|
||||
from shark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
|
||||
@@ -54,7 +60,7 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
def get_iree_frontend_args(frontend):
|
||||
if frontend in ["torch", "pytorch", "linalg", "tm_tensor"]:
|
||||
return ["--iree-llvmcpu-target-cpu-features=host"]
|
||||
elif frontend in ["tensorflow", "tf", "mhlo"]:
|
||||
elif frontend in ["tensorflow", "tf", "mhlo", "stablehlo"]:
|
||||
return [
|
||||
"--iree-llvmcpu-target-cpu-features=host",
|
||||
"--iree-mhlo-demote-i64-to-i32=false",
|
||||
@@ -259,8 +265,8 @@ def compile_module_to_flatbuffer(
|
||||
args += extra_args
|
||||
|
||||
if frontend in ["tensorflow", "tf"]:
|
||||
input_type = "mhlo"
|
||||
elif frontend in ["mhlo", "tosa"]:
|
||||
input_type = "auto"
|
||||
elif frontend in ["stablehlo", "tosa"]:
|
||||
input_type = frontend
|
||||
elif frontend in ["tflite", "tflite-tosa"]:
|
||||
input_type = "tosa"
|
||||
@@ -307,7 +313,7 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
|
||||
)
|
||||
ctx = ireert.SystemContext(config=config)
|
||||
ctx.add_vm_module(vm_module)
|
||||
ModuleCompiled = ctx.modules.module
|
||||
ModuleCompiled = getattr(ctx.modules, vm_module.name)
|
||||
return ModuleCompiled, config
|
||||
|
||||
|
||||
@@ -361,7 +367,7 @@ def export_iree_module_to_vmfb(
|
||||
def export_module_to_mlir_file(module, frontend, directory: str):
|
||||
# TODO: write proper documentation.
|
||||
mlir_str = module
|
||||
if frontend in ["tensorflow", "tf", "mhlo", "tflite"]:
|
||||
if frontend in ["tensorflow", "tf", "mhlo", "stablehlo", "tflite"]:
|
||||
mlir_str = module.decode("utf-8")
|
||||
elif frontend in ["pytorch", "torch"]:
|
||||
mlir_str = module.operation.get_asm()
|
||||
|
||||
@@ -117,7 +117,8 @@ def get_extensions(triple):
|
||||
|
||||
if get_vendor(triple) == "NVIDIA" or arch == "rdna3":
|
||||
ext.append("VK_NV_cooperative_matrix")
|
||||
|
||||
if get_vendor(triple) in ["NVIDIA", "AMD", "Intel"]:
|
||||
ext.append("VK_KHR_shader_integer_dot_product")
|
||||
return make_ext_list(ext_list=ext)
|
||||
|
||||
|
||||
@@ -228,6 +229,7 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
cap["shaderIntegerDotProduct"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
@@ -236,12 +238,12 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
if arch == "rdna3":
|
||||
# TODO: Get scope value
|
||||
cap["coopmatCases"] = [
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>"
|
||||
]
|
||||
|
||||
if product == "rx5700xt":
|
||||
cap["storagePushConstant16"] = False
|
||||
cap["storagePushConstant8"] = False
|
||||
@@ -274,7 +276,7 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
|
||||
cap["shaderIntegerDotProduct"] = True
|
||||
cap["storagePushConstant16"] = False
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
cap["storageBuffer8BitAccess"] = True
|
||||
@@ -305,6 +307,7 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
cap["shaderIntegerDotProduct"] = False
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
@@ -367,6 +370,7 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = False
|
||||
cap["shaderIntegerDotProduct"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
@@ -408,11 +412,12 @@ def get_vulkan_target_capabilities(triple):
|
||||
"Quad",
|
||||
]
|
||||
|
||||
cap["shaderFloat16"] = True
|
||||
cap["shaderFloat16"] = False
|
||||
cap["shaderFloat64"] = True
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
cap["shaderIntegerDotProduct"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
@@ -446,6 +451,7 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
cap["shaderIntegerDotProduct"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
|
||||
@@ -21,7 +21,7 @@ from sys import platform
|
||||
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
|
||||
|
||||
|
||||
def get_vulkan_device_name():
|
||||
def get_vulkan_device_name(device_num=0):
|
||||
vulkaninfo_dump, _ = run_cmd("vulkaninfo")
|
||||
vulkaninfo_dump = vulkaninfo_dump.split(linesep)
|
||||
vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s]
|
||||
@@ -31,8 +31,8 @@ def get_vulkan_device_name():
|
||||
print("Following devices found:")
|
||||
for i, dname in enumerate(vulkaninfo_list):
|
||||
print(f"{i}. {dname}")
|
||||
print(f"Choosing first one: {vulkaninfo_list[0]}")
|
||||
return vulkaninfo_list[0]
|
||||
print(f"Choosing device: {vulkaninfo_list[device_num]}")
|
||||
return vulkaninfo_list[device_num]
|
||||
|
||||
|
||||
def get_os_name():
|
||||
@@ -119,14 +119,14 @@ def get_vulkan_target_triple(device_name):
|
||||
return triple
|
||||
|
||||
|
||||
def get_vulkan_triple_flag(device_name="", extra_args=[]):
|
||||
def get_vulkan_triple_flag(device_name="", device_num=0, extra_args=[]):
|
||||
for flag in extra_args:
|
||||
if "-iree-vulkan-target-triple=" in flag:
|
||||
print(f"Using target triple {flag.split('=')[1]}")
|
||||
return None
|
||||
|
||||
if device_name == "" or device_name == [] or device_name is None:
|
||||
vulkan_device = get_vulkan_device_name()
|
||||
vulkan_device = get_vulkan_device_name(device_num=device_num)
|
||||
else:
|
||||
vulkan_device = device_name
|
||||
triple = get_vulkan_target_triple(vulkan_device)
|
||||
@@ -144,7 +144,7 @@ def get_vulkan_triple_flag(device_name="", extra_args=[]):
|
||||
return None
|
||||
|
||||
|
||||
def get_iree_vulkan_args(extra_args=[]):
|
||||
def get_iree_vulkan_args(device_num=0, extra_args=[]):
|
||||
# res_vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
|
||||
|
||||
res_vulkan_flag = []
|
||||
@@ -156,7 +156,9 @@ def get_iree_vulkan_args(extra_args=[]):
|
||||
break
|
||||
|
||||
if vulkan_triple_flag is None:
|
||||
vulkan_triple_flag = get_vulkan_triple_flag(extra_args=extra_args)
|
||||
vulkan_triple_flag = get_vulkan_triple_flag(
|
||||
device_num=device_num, extra_args=extra_args
|
||||
)
|
||||
|
||||
if vulkan_triple_flag is not None:
|
||||
vulkan_target_env = get_vulkan_target_env_flag(vulkan_triple_flag)
|
||||
|
||||
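A small usage sketch of the new `device_num` plumbing, assuming a machine where `vulkaninfo` reports more than one device; the returned list is whatever target-triple and target-env flags the helpers above derive.

```python
from shark.iree_utils.vulkan_utils import (
    get_iree_vulkan_args,
    get_vulkan_device_name,
)

# Ask for the second Vulkan device reported by vulkaninfo (index 1) and
# collect the iree-compile flags that target it.
print(get_vulkan_device_name(device_num=1))
print(get_iree_vulkan_args(device_num=1))
```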
@@ -30,8 +30,8 @@ import os
|
||||
import sys
|
||||
from typing import Dict, List
|
||||
|
||||
import iree.compiler._mlir_libs
|
||||
from iree.compiler import ir
|
||||
from iree.compiler.transforms import ireec as ireec_trans
|
||||
|
||||
|
||||
def model_annotation(
|
||||
@@ -311,11 +311,18 @@ def add_attributes(op: ir.Operation, config: List[Dict]):
|
||||
split_k = config["split_k"]
|
||||
elif "SPIRV" in config["pipeline"]:
|
||||
pipeline = config["pipeline"]
|
||||
tile_sizes = [
|
||||
config["work_group_tile_sizes"],
|
||||
config["parallel_tile_sizes"],
|
||||
config["reduction_tile_sizes"],
|
||||
]
|
||||
if pipeline == "SPIRVMatmulPromoteVectorize":
|
||||
tile_sizes = [
|
||||
config["work_group_tile_sizes"]
|
||||
+ [config["reduction_tile_sizes"][-1]],
|
||||
]
|
||||
else:
|
||||
tile_sizes = [
|
||||
config["work_group_tile_sizes"],
|
||||
config["parallel_tile_sizes"],
|
||||
config["reduction_tile_sizes"],
|
||||
]
|
||||
|
||||
workgroup_size = config["work_group_sizes"]
|
||||
if "vector_tile_sizes" in config.keys():
|
||||
tile_sizes += [config["vector_tile_sizes"]]
|
||||
@@ -409,7 +416,6 @@ def shape_list_to_string(input):
|
||||
|
||||
def create_context() -> ir.Context:
|
||||
context = ir.Context()
|
||||
ireec_trans.register_all_dialects(context)
|
||||
context.allow_unregistered_dialects = True
|
||||
return context
|
||||
|
||||
|
||||
@@ -61,6 +61,8 @@ def download_public_file(
|
||||
continue
|
||||
|
||||
destination_filename = os.path.join(destination_folder_name, blob_name)
|
||||
if os.path.isdir(destination_filename):
|
||||
continue
|
||||
with open(destination_filename, "wb") as f:
|
||||
with tqdm.wrapattr(f, "write", total=blob.size) as file_obj:
|
||||
storage_client.download_blob_to_file(blob, file_obj)
|
||||
@@ -196,7 +198,7 @@ def download_model(
|
||||
tank_url=None,
|
||||
frontend=None,
|
||||
tuned=None,
|
||||
import_args=None,
|
||||
import_args={"batch_size": 1},
|
||||
):
|
||||
model_name = model_name.replace("/", "_")
|
||||
dyn_str = "_dynamic" if dynamic else ""
|
||||
@@ -210,6 +212,9 @@ def download_model(
|
||||
+ "_BS"
|
||||
+ str(import_args["batch_size"])
|
||||
)
|
||||
elif any(model in model_name for model in ["clip", "unet", "vae"]):
|
||||
# TODO(Ean Garvey): rework extended naming such that device is only included in model_name after .vmfb compilation.
|
||||
model_dir_name = model_name
|
||||
else:
|
||||
model_dir_name = model_name + "_" + frontend
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
@@ -270,6 +275,9 @@ def download_model(
|
||||
tuned_str = "" if tuned is None else "_" + tuned
|
||||
suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
|
||||
filename = os.path.join(model_dir, model_name + suffix)
|
||||
print(
|
||||
f"Verifying that model artifacts were downloaded successfully to {filename}..."
|
||||
)
|
||||
if not os.path.exists(filename):
|
||||
from tank.generate_sharktank import gen_shark_files
|
||||
|
||||
|
||||
shark/shark_generate_model_config.py (new file, 38 lines)
@@ -0,0 +1,38 @@
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class GenerateConfigFile:
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
num_sharding_stages: int,
|
||||
sharding_stages_id: list[str] = None,
|
||||
):
|
||||
self.model = model
|
||||
self.num_sharding_stages = num_sharding_stages
|
||||
self.sharding_stages_id = sharding_stages_id
|
||||
assert self.num_sharding_stages == len(
|
||||
self.sharding_stages_id
|
||||
), "Number of sharding stages should be equal to the list of their ID"
|
||||
|
||||
def generate_json(self):
|
||||
model_dictionary = dict()
|
||||
|
||||
for name, m in self.model.named_modules():
|
||||
if name == "":
|
||||
continue
|
||||
|
||||
# Remove non-leaf nodes from the config as they aren't an operation
|
||||
substring_before_final_period = name.split(".")[:-1]
|
||||
substring_before_final_period = ".".join(
|
||||
substring_before_final_period
|
||||
)
|
||||
if substring_before_final_period in model_dictionary:
|
||||
del model_dictionary[substring_before_final_period]
|
||||
|
||||
layer_dict = {n: "None" for n in self.sharding_stages_id}
|
||||
model_dictionary[name] = layer_dict
|
||||
|
||||
with open("model_config.json", "w") as outfile:
|
||||
json.dump(model_dictionary, outfile)
|
||||
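A hedged usage sketch for the new config generator; the toy model and the "gpu0"/"gpu1" stage IDs are placeholders, not names from the repo.

```python
import torch

from shark.shark_generate_model_config import GenerateConfigFile

# Any torch.nn.Module works: generate_json() only walks named_modules() and
# records the leaf layers, keyed by the sharding-stage IDs supplied here.
toy_model = torch.nn.Sequential(
    torch.nn.Linear(8, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 4),
)
config = GenerateConfigFile(toy_model, 2, ["gpu0", "gpu1"])
config.generate_json()  # writes model_config.json in the working directory
```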
@@ -148,14 +148,23 @@ class SharkImporter:
|
||||
|
||||
# Saves `function_name.npy`, `inputs.npz`, `golden_out.npz` and `model_name.mlir` in the directory `dir`.
|
||||
def save_data(
|
||||
self, dir, model_name, mlir_data, func_name, inputs, outputs
|
||||
self,
|
||||
dir,
|
||||
model_name,
|
||||
mlir_data,
|
||||
func_name,
|
||||
inputs,
|
||||
outputs,
|
||||
mlir_type="linalg",
|
||||
):
|
||||
import numpy as np
|
||||
|
||||
inputs_name = "inputs.npz"
|
||||
outputs_name = "golden_out.npz"
|
||||
func_file_name = "function_name"
|
||||
model_name_mlir = model_name + "_" + self.frontend + ".mlir"
|
||||
model_name_mlir = (
|
||||
model_name + "_" + self.frontend + "_" + mlir_type + ".mlir"
|
||||
)
|
||||
print(f"saving {model_name_mlir} to {dir}")
|
||||
try:
|
||||
inputs = [x.cpu().detach() for x in inputs]
|
||||
@@ -191,19 +200,23 @@ class SharkImporter:
|
||||
dir=tempfile.gettempdir(),
|
||||
model_name="model",
|
||||
golden_values=None,
|
||||
mlir_type="linalg",
|
||||
):
|
||||
if self.inputs == None:
|
||||
print(
|
||||
f"There is no input provided: {self.inputs}, please provide inputs or simply run import_mlir."
|
||||
)
|
||||
sys.exit(1)
|
||||
model_name_mlir = model_name + "_" + self.frontend + ".mlir"
|
||||
model_name_mlir = (
|
||||
model_name + "_" + self.frontend + "_" + mlir_type + ".mlir"
|
||||
)
|
||||
artifact_path = os.path.join(dir, model_name_mlir)
|
||||
imported_mlir = self.import_mlir(
|
||||
is_dynamic,
|
||||
tracing_required,
|
||||
func_name,
|
||||
save_dir=artifact_path,
|
||||
mlir_type=mlir_type,
|
||||
)
|
||||
# TODO: Make sure that any generic function name is accepted. Currently takes in the default function names.
|
||||
# TODO: Check for multiple outputs.
|
||||
@@ -229,6 +242,7 @@ class SharkImporter:
|
||||
imported_mlir[1],
|
||||
self.inputs,
|
||||
golden_out,
|
||||
mlir_type,
|
||||
)
|
||||
return (
|
||||
imported_mlir,
|
||||
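In short, saved artifacts now carry the IR type in their name. A quick illustration of the convention (the model name is only an example):

```python
# Before: <model>_<frontend>.mlir          e.g. "bert-base-uncased_torch.mlir"
# After:  <model>_<frontend>_<mlir_type>.mlir
model_name, frontend, mlir_type = "bert-base-uncased", "torch", "linalg"
artifact = f"{model_name}_{frontend}_{mlir_type}.mlir"
assert artifact == "bert-base-uncased_torch_linalg.mlir"
```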
@@ -298,6 +312,46 @@ def get_f16_inputs(inputs, is_f16, f16_input_mask):
|
||||
return tuple(f16_masked_inputs)
|
||||
|
||||
|
||||
# Upcasts the block/list of ops.
|
||||
def add_upcast(fx_g):
|
||||
import torch
|
||||
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.target in [torch.ops.aten.mul]:
|
||||
# This is a very strict check.
|
||||
if (
|
||||
node.args[1].target in [torch.ops.aten.rsqrt]
|
||||
and node.args[1].args[0].target in [torch.ops.aten.add]
|
||||
and node.args[1].args[0].args[0].target
|
||||
in [torch.ops.aten.mean]
|
||||
and node.args[1].args[0].args[0].args[0].target
|
||||
in [torch.ops.aten.pow]
|
||||
):
|
||||
print("found an upcasting block let's upcast it.")
|
||||
pow_node = node.args[1].args[0].args[0].args[0]
|
||||
mul_node = node
|
||||
with fx_g.graph.inserting_before(pow_node):
|
||||
lhs = pow_node.args[0]
|
||||
upcast_lhs = fx_g.graph.call_function(
|
||||
torch.ops.aten._to_copy,
|
||||
args=(lhs,),
|
||||
kwargs={"dtype": torch.float32},
|
||||
)
|
||||
pow_node.args = (upcast_lhs, pow_node.args[1])
|
||||
with fx_g.graph.inserting_before(mul_node):
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten._to_copy,
|
||||
args=(mul_node,),
|
||||
kwargs={"dtype": torch.float16},
|
||||
)
|
||||
mul_node.append(new_node)
|
||||
mul_node.replace_all_uses_with(new_node)
|
||||
new_node.args = (mul_node,)
|
||||
new_node.kwargs = {"dtype": torch.float16}
|
||||
|
||||
fx_g.graph.lint()
|
||||
|
||||
|
||||
def transform_fx(fx_g):
|
||||
import torch
|
||||
|
||||
@@ -306,6 +360,9 @@ def transform_fx(fx_g):
|
||||
"device": torch.device(type="cpu"),
|
||||
"pin_memory": False,
|
||||
}
|
||||
kwargs_dict1 = {
|
||||
"dtype": torch.float16,
|
||||
}
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
if node.target in [
|
||||
@@ -313,7 +370,38 @@ def transform_fx(fx_g):
|
||||
torch.ops.aten.empty,
|
||||
torch.ops.aten.zeros,
|
||||
]:
|
||||
node.kwargs = kwargs_dict
|
||||
if node.kwargs.get("dtype") == torch.float32:
|
||||
node.kwargs = kwargs_dict
|
||||
|
||||
# Vicuna
|
||||
if node.target in [
|
||||
torch.ops.aten._to_copy,
|
||||
]:
|
||||
if node.kwargs.get("dtype") == torch.float32:
|
||||
node.kwargs = kwargs_dict1
|
||||
|
||||
if node.target in [
|
||||
torch.ops.aten.masked_fill,
|
||||
]:
|
||||
if node.args[2] > torch.finfo(torch.half).max:
|
||||
max_val = torch.finfo(torch.half).max
|
||||
node.args = (node.args[0], node.args[1], max_val)
|
||||
elif node.args[2] < torch.finfo(torch.half).min:
|
||||
min_val = torch.finfo(torch.half).min
|
||||
node.args = (node.args[0], node.args[1], min_val)
|
||||
|
||||
if node.target in [
|
||||
torch.ops.aten.full,
|
||||
]:
|
||||
if node.args[1] > torch.finfo(torch.half).max:
|
||||
max_val = torch.finfo(torch.half).max
|
||||
node.args = (node.args[0], max_val)
|
||||
node.kwargs = kwargs_dict
|
||||
elif node.args[1] < torch.finfo(torch.half).min:
|
||||
min_val = torch.finfo(torch.half).min
|
||||
node.args = (node.args[0], min_val)
|
||||
node.kwargs = kwargs_dict
|
||||
|
||||
# Inputs and outputs of aten.var.mean should be upcasted to fp32.
|
||||
if node.target in [torch.ops.aten.var_mean]:
|
||||
with fx_g.graph.inserting_before(node):
|
||||
@@ -323,6 +411,7 @@ def transform_fx(fx_g):
|
||||
kwargs={},
|
||||
)
|
||||
node.args = (new_node, node.args[1])
|
||||
|
||||
if node.name.startswith("getitem"):
|
||||
with fx_g.graph.inserting_before(node):
|
||||
if node.args[0].target in [torch.ops.aten.var_mean]:
|
||||
@@ -335,6 +424,7 @@ def transform_fx(fx_g):
|
||||
node.replace_all_uses_with(new_node)
|
||||
new_node.args = (node,)
|
||||
new_node.kwargs = {"dtype": torch.float16}
|
||||
|
||||
# aten.empty should be filled with zeros.
|
||||
if node.target in [torch.ops.aten.empty]:
|
||||
with fx_g.graph.inserting_after(node):
|
||||
@@ -346,6 +436,14 @@ def transform_fx(fx_g):
|
||||
node.replace_all_uses_with(new_node)
|
||||
new_node.args = (node,)
|
||||
|
||||
# Required for cuda debugging.
|
||||
# for node in fx_g.graph.nodes:
|
||||
# if node.op == "call_function":
|
||||
# if node.kwargs.get("device") == torch.device(type="cpu"):
|
||||
# new_kwargs = node.kwargs.copy()
|
||||
# new_kwargs["device"] = torch.device(type="cuda")
|
||||
# node.kwargs = new_kwargs
|
||||
|
||||
fx_g.graph.lint()
|
||||
|
||||
|
||||
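The masked_fill/full handling above exists because float32 sentinel fill values overflow fp16. A tiny illustration of the clamping it performs, using the standard torch.finfo limits rather than any repo constant:

```python
import torch

fill = -3.4028234663852886e38           # a float32 "-inf"-style mask fill
half_min = torch.finfo(torch.half).min  # -65504.0
half_max = torch.finfo(torch.half).max  # 65504.0

# transform_fx rewrites out-of-range fill constants to the nearest
# representable half value before the graph runs in fp16.
clamped = min(max(fill, half_min), half_max)
print(clamped)  # -65504.0
```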
@@ -397,6 +495,9 @@ def import_with_fx(
|
||||
return_str=False,
|
||||
save_dir=tempfile.gettempdir(),
|
||||
model_name="model",
|
||||
mlir_type="linalg",
|
||||
is_dynamic=False,
|
||||
tracing_required=False,
|
||||
):
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
@@ -446,6 +547,8 @@ def import_with_fx(
|
||||
if is_f16:
|
||||
fx_g = fx_g.half()
|
||||
transform_fx(fx_g)
|
||||
# TODO: Have to make it more generic.
|
||||
add_upcast(fx_g)
|
||||
fx_g.recompile()
|
||||
|
||||
if training:
|
||||
@@ -463,7 +566,12 @@ def import_with_fx(
|
||||
|
||||
if debug: # and not is_f16:
|
||||
(mlir_module, func_name), _, _ = mlir_importer.import_debug(
|
||||
dir=save_dir, model_name=model_name, golden_values=golden_values
|
||||
dir=save_dir,
|
||||
model_name=model_name,
|
||||
golden_values=golden_values,
|
||||
mlir_type=mlir_type,
|
||||
is_dynamic=is_dynamic,
|
||||
tracing_required=tracing_required,
|
||||
)
|
||||
return mlir_module, func_name
|
||||
|
||||
|
||||
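A hedged sketch of calling the extended entry point with the new keyword arguments; `model`, `inputs`, and the output paths are placeholders, and the keyword order mirrors the generate_sharktank call site further down this diff.

```python
from shark.shark_importer import import_with_fx

# model: a torch.nn.Module; inputs: a tuple of example tensors (placeholders).
mlir_module, func_name = import_with_fx(
    model,
    inputs,
    is_f16=False,
    f16_input_mask=None,
    debug=True,                # debug=True also saves inputs/golden outputs
    training=False,
    return_str=False,
    save_dir="./artifacts",
    model_name="my_model",
    mlir_type="linalg",
    is_dynamic=False,
    tracing_required=True,
)
```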
@@ -25,7 +25,14 @@ import sys
|
||||
|
||||
|
||||
# supported dialects by the shark-runtime.
|
||||
supported_dialects = {"linalg", "mhlo", "tosa", "tf-lite", "tm_tensor"}
|
||||
supported_dialects = {
|
||||
"linalg",
|
||||
"mhlo",
|
||||
"stablehlo",
|
||||
"tosa",
|
||||
"tf-lite",
|
||||
"tm_tensor",
|
||||
}
|
||||
|
||||
|
||||
class SharkRunner:
|
||||
|
||||
@@ -59,6 +59,7 @@ class SharkTrainer:
|
||||
"torch",
|
||||
"tensorflow",
|
||||
"tf",
|
||||
"stablehlo",
|
||||
"mhlo",
|
||||
"linalg",
|
||||
"tosa",
|
||||
@@ -84,7 +85,7 @@ class SharkTrainer:
|
||||
"tm_tensor",
|
||||
extra_args=extra_args,
|
||||
)
|
||||
elif self.frontend in ["tensorflow", "tf", "mhlo"]:
|
||||
elif self.frontend in ["tensorflow", "tf", "mhlo", "stablehlo"]:
|
||||
self.shark_runner = SharkRunner(
|
||||
self.model,
|
||||
self.input,
|
||||
|
||||
@@ -21,7 +21,7 @@ import io
|
||||
|
||||
mlir_type_mapping_dict = {
|
||||
"linalg": torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
"mhlo": torch_mlir.OutputType.STABLEHLO,
|
||||
"stablehlo": torch_mlir.OutputType.STABLEHLO,
|
||||
"tosa": torch_mlir.OutputType.TOSA,
|
||||
}
|
||||
|
||||
|
||||
@@ -8,25 +8,25 @@ distilbert-base-uncased,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
facebook/convnext-tiny-224,mhlo,tf,1e-2,1e-3,tf_vit,nhcw-nhwc,True,True,False,"https://github.com/nod-ai/SHARK/issues/311 & https://github.com/nod-ai/SHARK/issues/342","macos"
|
||||
funnel-transformer/small,mhlo,tf,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/201",""
|
||||
google/electra-small-discriminator,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
google/mobilebert-uncased,mhlo,tf,1e-2,1e-3,default,None,True,False,False,"Fails during iree-compile",""
|
||||
google/mobilebert-uncased,mhlo,tf,1e-2,1e-3,default,None,True,False,False,"Fails during iree-compile","macos"
|
||||
google/vit-base-patch16-224,mhlo,tf,1e-2,1e-3,tf_vit,nhcw-nhwc,False,False,False,"",""
|
||||
microsoft/MiniLM-L12-H384-uncased,mhlo,tf,1e-2,1e-3,tf_hf,None,True,False,False,"Fails during iree-compile.",""
|
||||
microsoft/layoutlm-base-uncased,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
microsoft/mpnet-base,mhlo,tf,1e-2,1e-2,default,None,True,True,True,"",""
|
||||
albert-base-v2,linalg,torch,1e-2,1e-3,default,None,True,True,True,"issue with aten.tanh in torch-mlir",""
|
||||
alexnet,linalg,torch,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/879",""
|
||||
bert-base-cased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,False,True,"",""
|
||||
bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
bert-base-cased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,True,True,"",""
|
||||
bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
bert-large-uncased,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"Fails during iree-compile.",""
|
||||
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/311",""
|
||||
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390","macos"
|
||||
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"https://github.com/nod-ai/SHARK/issues/344",""
|
||||
mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/388","macos"
|
||||
nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/343","macos"
|
||||
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"https://github.com/nod-ai/SHARK/issues/344","macos"
|
||||
mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,False,True,True,"https://github.com/nod-ai/SHARK/issues/388, https://github.com/nod-ai/SHARK/issues/1487","macos"
|
||||
nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,True,True,True,"https://github.com/nod-ai/SHARK/issues/343,https://github.com/nod-ai/SHARK/issues/1487","macos"
|
||||
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos"
|
||||
resnet18,linalg,torch,1e-2,1e-3,default,None,True,True,False,"","macos"
|
||||
resnet50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
@@ -35,13 +35,13 @@ squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","mac
|
||||
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos"
|
||||
efficientnet-v2-s,mhlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
|
||||
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,False,"https://github.com/nod-ai/SHARK/issues/1243",""
|
||||
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails on MacOS builder, VK device lost","macos"
|
||||
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
|
||||
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
|
||||
efficientnet_b0,mhlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""
|
||||
efficientnet_b7,mhlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails on MacOS builder, VK device lost","macos"
|
||||
gpt2,mhlo,tf,1e-2,1e-3,default,None,True,False,False,"",""
|
||||
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.",""
|
||||
t5-base,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported",""
|
||||
t5-large,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
stabilityai/stable-diffusion-2-1-base,linalg,torch,1e-3,1e-3,default,None,True,False,False,"",""
|
||||
gpt2,mhlo,tf,1e-2,1e-3,default,None,True,False,False,"","macos"
|
||||
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.","macos"
|
||||
t5-base,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"","macos"
|
||||
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported","macos"
|
||||
t5-large,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"","macos"
|
||||
stabilityai/stable-diffusion-2-1-base,linalg,torch,1e-3,1e-3,default,None,True,False,False,"","macos"
|
||||
|
||||
|
@@ -75,7 +75,7 @@ if __name__ == "__main__":
|
||||
compiler_module,
|
||||
target_backends=[backend],
|
||||
extra_args=args,
|
||||
input_type="mhlo",
|
||||
input_type="auto",
|
||||
)
|
||||
# flatbuffer_blob = compile_str(compiler_module, target_backends=["dylib-llvm-aot"])
|
||||
|
||||
|
||||
@@ -153,7 +153,7 @@ if __name__ == "__main__":
|
||||
compiler_module,
|
||||
target_backends=[backend],
|
||||
extra_args=args,
|
||||
input_type="mhlo",
|
||||
input_type="auto",
|
||||
)
|
||||
|
||||
# Save module as MLIR file in a directory
|
||||
|
||||
@@ -96,7 +96,7 @@ if __name__ == "__main__":
|
||||
compiler_module,
|
||||
target_backends=[backend],
|
||||
extra_args=args,
|
||||
input_type="mhlo",
|
||||
input_type="auto",
|
||||
)
|
||||
# flatbuffer_blob = compile_str(compiler_module, target_backends=["dylib-llvm-aot"])
|
||||
|
||||
|
||||
@@ -91,7 +91,7 @@ if __name__ == "__main__":
|
||||
compiler_module,
|
||||
target_backends=[backend],
|
||||
extra_args=args,
|
||||
input_type="mhlo",
|
||||
input_type="auto",
|
||||
)
|
||||
# flatbuffer_blob = compile_str(compiler_module, target_backends=["dylib-llvm-aot"])
|
||||
|
||||
|
||||
tank/examples/opt/opt_causallm.py (new file, 159 lines)
@@ -0,0 +1,159 @@
|
||||
import unittest
|
||||
|
||||
import os
|
||||
import pytest
|
||||
import torch_mlir
|
||||
import torch
|
||||
import numpy as np
|
||||
from shark_hf_opt import OPTForCausalLM
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from tank.model_utils import compare_tensors
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
OPT_MODEL = "opt-350m"
|
||||
OPT_MODEL_66B = "facebook/opt-66b"
|
||||
MAX_SEQUENCE_LENGTH = 256
|
||||
MAX_NEW_TOKENS = 200
|
||||
|
||||
|
||||
def create_module(model_name, tokenizer, device):
|
||||
opt_model = OPTForCausalLM.from_pretrained(
|
||||
"facebook/" + model_name, return_dict=False
|
||||
)
|
||||
opt_model.eval()
|
||||
|
||||
encoded_inputs = tokenizer(
|
||||
"This is a sample input for generating the model.",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = (
|
||||
encoded_inputs["input_ids"],
|
||||
encoded_inputs["attention_mask"],
|
||||
)
|
||||
mlir_path = f"./{OPT_MODEL}_causallm_{MAX_SEQUENCE_LENGTH}_torch.mlir"
|
||||
if os.path.isfile(mlir_path):
|
||||
with open(mlir_path, "r") as f:
|
||||
model_mlir = f.read()
|
||||
print(f"Loaded .mlir from {mlir_path}")
|
||||
else:
|
||||
module = torch_mlir.compile(
|
||||
opt_model,
|
||||
inputs,
|
||||
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=True,
|
||||
)
|
||||
|
||||
model_mlir = module.operation.get_asm(
|
||||
large_elements_limit=None, enable_debug_info=True
|
||||
)
|
||||
|
||||
with open(mlir_path, "w") as f:
|
||||
f.write(model_mlir)
|
||||
print(f"Saved mlir at {mlir_path}")
|
||||
|
||||
func_name = "forward"
|
||||
act_out = opt_model(inputs[0], attention_mask=inputs[1], return_dict=False)
|
||||
|
||||
shark_module = SharkInference(
|
||||
model_mlir,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
is_benchmark=False,
|
||||
)
|
||||
vmfb_name = f"{OPT_MODEL}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{device}"
|
||||
shark_module.save_module(module_name=vmfb_name)
|
||||
shark_module.load_module(vmfb_name + ".vmfb")
|
||||
|
||||
results = shark_module("forward", inputs)
|
||||
print(
|
||||
"SHARK logits have shape: ",
|
||||
str(results[0].shape) + " : " + str(results[0]),
|
||||
)
|
||||
print(
|
||||
"PyTorch logits have shape: "
|
||||
+ str(act_out[0].shape)
|
||||
+ " : "
|
||||
+ str(act_out[0])
|
||||
)
|
||||
# exp_out = tokenizer.decode(act_out[0][0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
# shark_out = tokenizer.decode(results[0][0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
return shark_module
|
||||
|
||||
|
||||
def shouldStop(tokens):
|
||||
stop_ids = [50278, 50279, 50277, 0]
|
||||
for stop_id in stop_ids:
|
||||
if tokens[0][-1] == stop_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def generate_new_token(shark_model, tokenizer, new_text):
|
||||
model_inputs = tokenizer(
|
||||
new_text,
|
||||
padding="max_length",
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = (
|
||||
model_inputs["input_ids"],
|
||||
model_inputs["attention_mask"],
|
||||
)
|
||||
sum_attentionmask = torch.sum(model_inputs.attention_mask)
|
||||
output = shark_model("forward", inputs)
|
||||
output = torch.FloatTensor(output[0])
|
||||
next_toks = torch.topk(output, 1)
|
||||
stop_generation = False
|
||||
if shouldStop(next_toks.indices):
|
||||
stop_generation = True
|
||||
new_token = next_toks.indices[int(sum_attentionmask) - 1]
|
||||
detok = tokenizer.decode(
|
||||
new_token,
|
||||
skip_special_tokens=False,
|
||||
clean_up_tokenization_spaces=False,
|
||||
)
|
||||
ret_dict = {
|
||||
"new_token": new_token,
|
||||
"detok": detok,
|
||||
"stop_generation": stop_generation,
|
||||
}
|
||||
return ret_dict
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"facebook/" + OPT_MODEL, use_fast=False
|
||||
)
|
||||
vmfb_path = f"./{OPT_MODEL}_causallm_{MAX_SEQUENCE_LENGTH}_torch_cpu.vmfb"
|
||||
if os.path.isfile(vmfb_path):
|
||||
opt_shark_module = SharkInference(mlir_module=None, device="cpu")
|
||||
opt_shark_module.load_module(vmfb_path)
|
||||
else:
|
||||
opt_shark_module = create_module(OPT_MODEL, tokenizer, "cpu")
|
||||
while True:
|
||||
try:
|
||||
new_text = input("Give me a sentence to complete:")
|
||||
words_list = []
|
||||
|
||||
for i in range(MAX_NEW_TOKENS):
|
||||
generated_token_op = generate_new_token(
|
||||
opt_shark_module, tokenizer, new_text
|
||||
)
|
||||
detok = generated_token_op["detok"]
|
||||
stop_generation = generated_token_op["stop_generation"]
|
||||
if stop_generation:
|
||||
break
|
||||
print(detok, end="", flush=True)
|
||||
words_list.append(detok)
|
||||
if detok == "":
|
||||
break
|
||||
new_text = new_text + detok
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("Exiting program.")
|
||||
break
|
||||
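For reference, a non-interactive variant of the REPL loop at the bottom of this file, sketched against the helpers defined above (the prompt string and token cap are arbitrary):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m", use_fast=False)
shark_opt = create_module("opt-350m", tokenizer, "cpu")

text = "The ocean is"
for _ in range(20):  # cap at 20 new tokens instead of MAX_NEW_TOKENS
    step = generate_new_token(shark_opt, tokenizer, text)
    if step["stop_generation"] or step["detok"] == "":
        break
    text += step["detok"]
print(text)
```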
tank/examples/opt/opt_causallm_torch_test.py (new file, 211 lines)
@@ -0,0 +1,211 @@
|
||||
import unittest
|
||||
|
||||
import os
|
||||
import pytest
|
||||
import torch_mlir
|
||||
import torch
|
||||
import numpy as np
|
||||
from shark_hf_opt import OPTForCausalLM
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from tank.model_utils import compare_tensors
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
OPT_MODEL = "facebook/opt-1.3B"
|
||||
OPT_MODEL_66B = "facebook/opt-66b"
|
||||
|
||||
|
||||
class OPTModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
benchmark=False,
|
||||
):
|
||||
self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device, model_name):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
|
||||
# config = OPTConfig()
|
||||
# opt_model = OPTModel(config)
|
||||
opt_model = OPTForCausalLM.from_pretrained(
|
||||
model_name, return_dict=False
|
||||
)
|
||||
opt_model.eval()
|
||||
|
||||
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||
input_ids, attention_mask = (
|
||||
inputs.data["input_ids"],
|
||||
inputs.data["attention_mask"],
|
||||
)
|
||||
np.save("opt_inputs.npy", input_ids.detach())
|
||||
mlir_path = "./OPT1-3b_causallm_torch.mlir"
|
||||
if os.path.isfile(mlir_path):
|
||||
with open(mlir_path, "r") as f:
|
||||
model_mlir = f.read()
|
||||
print(f"Loaded .mlir from {mlir_path}")
|
||||
else:
|
||||
module = torch_mlir.compile(
|
||||
opt_model,
|
||||
input_ids,
|
||||
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=True,
|
||||
)
|
||||
|
||||
model_mlir = module.operation.get_asm(
|
||||
large_elements_limit=None, enable_debug_info=True
|
||||
)
|
||||
|
||||
with open(mlir_path, "w") as f:
|
||||
f.write(model_mlir)
|
||||
print(f"Saved mlir at {mlir_path}")
|
||||
|
||||
func_name = "forward"
|
||||
act_out = opt_model(input_ids, return_dict=False)
|
||||
|
||||
# mlir_importer = SharkImporter(
|
||||
# model,
|
||||
# (input,),
|
||||
# frontend="torch",
|
||||
# )
|
||||
# minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
# is_dynamic=dynamic, tracing_required=True
|
||||
# )
|
||||
|
||||
shark_module = SharkInference(
|
||||
model_mlir,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
is_benchmark=self.benchmark,
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module("forward", (input_ids,))
|
||||
print(
|
||||
"SHARK logits have shape: ",
|
||||
str(results[0].shape) + " : " + str(results[0]),
|
||||
)
|
||||
print(
|
||||
"PyTorch logits have shape: "
|
||||
+ str(act_out[0].shape)
|
||||
+ " : "
|
||||
+ str(act_out[0])
|
||||
)
|
||||
# exp_out = tokenizer.decode(act_out[0][0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
# shark_out = tokenizer.decode(results[0][0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
assert compare_tensors(act_out[0].detach(), results[0])
|
||||
|
||||
if self.benchmark:
|
||||
shark_module.shark_runner.benchmark_all_csv(
|
||||
(input_ids, attention_mask),
|
||||
"opt",
|
||||
dynamic,
|
||||
device,
|
||||
"torch",
|
||||
)
|
||||
|
||||
|
||||
class OPTModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = OPTModuleTester(self)
|
||||
self.module_tester.save_mlir = False
|
||||
self.module_tester.save_vmfb = False
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
def test_1_3b_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device, OPT_MODEL)
|
||||
|
||||
def test_1_3b_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device, OPT_MODEL)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("cuda"), reason=device_driver_info("cuda")
|
||||
)
|
||||
def test_1_3b_static_cuda(self):
|
||||
dynamic = False
|
||||
device = "cuda"
|
||||
self.module_tester.create_and_check_module(dynamic, device, OPT_MODEL)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("cuda"), reason=device_driver_info("cuda")
|
||||
)
|
||||
def test_1_3b_dynamic_cuda(self):
|
||||
dynamic = True
|
||||
device = "cuda"
|
||||
self.module_tester.create_and_check_module(dynamic, device, OPT_MODEL)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_1_3b_static_vulkan(self):
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device, OPT_MODEL)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_1_3b_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device, OPT_MODEL)
|
||||
|
||||
# def test_66b_static_cpu(self):
|
||||
# dynamic = False
|
||||
# device = "cpu"
|
||||
# self.module_tester.create_and_check_module(
|
||||
# dynamic, device, OPT_MODEL_66B
|
||||
# )
|
||||
|
||||
# def test_66b_dynamic_cpu(self):
|
||||
# dynamic = True
|
||||
# device = "cpu"
|
||||
# self.module_tester.create_and_check_module(
|
||||
# dynamic, device, OPT_MODEL_66B
|
||||
# )
|
||||
|
||||
# @pytest.mark.skipif(
|
||||
# check_device_drivers("cuda"), reason=device_driver_info("cuda")
|
||||
# )
|
||||
# def test_66b_static_cuda(self):
|
||||
# dynamic = False
|
||||
# device = "cuda"
|
||||
# self.module_tester.create_and_check_module(
|
||||
# dynamic, device, OPT_MODEL_66B
|
||||
# )
|
||||
|
||||
# @pytest.mark.skipif(
|
||||
# check_device_drivers("cuda"), reason=device_driver_info("cuda")
|
||||
# )
|
||||
# def test_66b_dynamic_cuda(self):
|
||||
# dynamic = True
|
||||
# device = "cuda"
|
||||
# self.module_tester.create_and_check_module(
|
||||
# dynamic, device, OPT_MODEL_66B
|
||||
# )
|
||||
|
||||
# @pytest.mark.skipif(
|
||||
# check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
# )
|
||||
# def test_66b_static_vulkan(self):
|
||||
# dynamic = False
|
||||
# device = "vulkan"
|
||||
# self.module_tester.create_and_check_module(
|
||||
# dynamic, device, OPT_MODEL_66B
|
||||
# )
|
||||
|
||||
# @pytest.mark.skipif(
|
||||
# check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
# )
|
||||
# def test_66b_dynamic_vulkan(self):
|
||||
# dynamic = True
|
||||
# device = "vulkan"
|
||||
# self.module_tester.create_and_check_module(
|
||||
# dynamic, device, OPT_MODEL_66B
|
||||
# )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -2,11 +2,11 @@ import unittest
|
||||
|
||||
import pytest
|
||||
import torch_mlir
|
||||
from hacked_hf_opt import OPTModel
|
||||
from shark_hf_opt import OPTModel
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from tank.model_utils import compare_tensors
|
||||
from transformers import GPT2Tokenizer
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
OPT_MODEL = "facebook/opt-350m"
|
||||
OPT_MODEL_66B = "facebook/opt-66b"
|
||||
@@ -20,7 +20,7 @@ class OPTModuleTester:
|
||||
self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device, model_name):
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
|
||||
# config = OPTConfig()
|
||||
# opt_model = OPTModel(config)
|
||||
opt_model = OPTModel.from_pretrained(model_name)
|
||||
@@ -56,13 +56,12 @@ class OPTModuleTester:
|
||||
|
||||
shark_module = SharkInference(
|
||||
model_mlir,
|
||||
func_name,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
is_benchmark=self.benchmark,
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input_ids, attention_mask))
|
||||
results = shark_module("forward", (input_ids, attention_mask))
|
||||
assert compare_tensors(act_out, results)
|
||||
|
||||
if self.benchmark:
|
||||
|
||||
@@ -279,7 +279,6 @@ class OPTAttention(nn.Module):
|
||||
)
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (
|
||||
bsz * self.num_heads,
|
||||
tgt_len,
|
||||
@@ -832,7 +831,10 @@ class OPTForCausalLM(OPTPreTrainedModel):
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
if isinstance(outputs[1:], tuple):
|
||||
output = (logits,) + outputs[1:]
|
||||
else:
|
||||
output = (logits, outputs[1:])
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
@@ -37,10 +37,12 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
|
||||
from tank.model_utils import (
|
||||
get_hf_model,
|
||||
get_hf_seq2seq_model,
|
||||
get_hf_causallm_model,
|
||||
get_vision_model,
|
||||
get_hf_img_cls_model,
|
||||
get_fp16_model,
|
||||
)
|
||||
from shark.shark_importer import import_with_fx
|
||||
|
||||
with open(torch_model_list) as csvfile:
|
||||
torch_reader = csv.reader(csvfile, delimiter=",")
|
||||
@@ -50,6 +52,8 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
|
||||
tracing_required = row[1]
|
||||
model_type = row[2]
|
||||
is_dynamic = row[3]
|
||||
mlir_type = row[4]
|
||||
is_decompose = row[5]
|
||||
|
||||
tracing_required = False if tracing_required == "False" else True
|
||||
is_dynamic = False if is_dynamic == "False" else True
|
||||
@@ -91,6 +95,10 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
|
||||
model, input, _ = get_hf_seq2seq_model(
|
||||
torch_model_name, import_args
|
||||
)
|
||||
elif model_type == "hf_causallm":
|
||||
model, input, _ = get_hf_causallm_model(
|
||||
torch_model_name, import_args
|
||||
)
|
||||
elif model_type == "hf_img_cls":
|
||||
model, input, _ = get_hf_img_cls_model(
|
||||
torch_model_name, import_args
|
||||
@@ -111,25 +119,45 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
|
||||
)
|
||||
os.makedirs(torch_model_dir, exist_ok=True)
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
frontend="torch",
|
||||
)
|
||||
mlir_importer.import_debug(
|
||||
is_dynamic=False,
|
||||
tracing_required=tracing_required,
|
||||
dir=torch_model_dir,
|
||||
model_name=torch_model_name,
|
||||
)
|
||||
# Generate torch dynamic models.
|
||||
if is_dynamic:
|
||||
if is_decompose:
|
||||
# Add decomposition to some torch ops
|
||||
# TODO add op whitelist/blacklist
|
||||
import_with_fx(
|
||||
model,
|
||||
(input,),
|
||||
is_f16=False,
|
||||
f16_input_mask=None,
|
||||
debug=True,
|
||||
training=False,
|
||||
return_str=False,
|
||||
save_dir=torch_model_dir,
|
||||
model_name=torch_model_name,
|
||||
mlir_type=mlir_type,
|
||||
is_dynamic=False,
|
||||
tracing_required=tracing_required,
|
||||
)
|
||||
else:
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
frontend="torch",
|
||||
)
|
||||
mlir_importer.import_debug(
|
||||
is_dynamic=True,
|
||||
is_dynamic=False,
|
||||
tracing_required=tracing_required,
|
||||
dir=torch_model_dir,
|
||||
model_name=torch_model_name + "_dynamic",
|
||||
model_name=torch_model_name,
|
||||
mlir_type=mlir_type,
|
||||
)
|
||||
# Generate torch dynamic models.
|
||||
if is_dynamic:
|
||||
mlir_importer.import_debug(
|
||||
is_dynamic=True,
|
||||
tracing_required=tracing_required,
|
||||
dir=torch_model_dir,
|
||||
model_name=torch_model_name + "_dynamic",
|
||||
mlir_type=mlir_type,
|
||||
)
|
||||
|
||||
|
||||
def save_tf_model(tf_model_list, local_tank_cache, import_args):
|
||||
|
||||
@@ -176,6 +176,43 @@ def get_hf_seq2seq_model(name, import_args):
|
||||
return m, test_input, actual_out
|
||||
|
||||
|
||||
##################### Hugging Face CausalLM Models ###################################
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
|
||||
def prepare_sentence_tokens(hf_model: str, sentence: str):
|
||||
tokenizer = AutoTokenizer.from_pretrained(hf_model)
|
||||
return torch.tensor([tokenizer.encode(sentence)])
|
||||
|
||||
|
||||
class HFCausalLM(torch.nn.Module):
|
||||
def __init__(self, model_name: str):
|
||||
super().__init__()
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name, # The pretrained model name.
|
||||
# The number of output labels--2 for binary classification.
|
||||
num_labels=2,
|
||||
# Whether the model returns attentions weights.
|
||||
output_attentions=False,
|
||||
# Whether the model returns all hidden-states.
|
||||
output_hidden_states=False,
|
||||
torchscript=True,
|
||||
)
|
||||
self.model.eval()
|
||||
|
||||
def forward(self, tokens):
|
||||
return self.model.forward(tokens)[0]
|
||||
|
||||
|
||||
def get_hf_causallm_model(name, import_args):
|
||||
m = HFCausalLM(name)
|
||||
test_input = prepare_sentence_tokens(
|
||||
name, "this project is very interesting"
|
||||
)
|
||||
actual_out = m.forward(*test_input)
|
||||
return m, test_input, actual_out
|
||||
|
||||
|
||||
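A hedged sketch of the new helper in use; "gpt2" is chosen only because it appears in the updated torch_model_list.csv below, and import_args is currently unused by this helper.

```python
from tank.model_utils import get_hf_causallm_model

model, test_input, golden_out = get_hf_causallm_model("gpt2", import_args={})
print(test_input.shape)  # (1, sequence_length) token IDs for the sample text
print(golden_out.shape)  # logits from the eager PyTorch reference run
```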
################################################################################
|
||||
|
||||
##################### Torch Vision Models ###################################
|
||||
|
||||
@@ -1,23 +1,27 @@
|
||||
model_name, use_tracing, model_type, dynamic, param_count, tags, notes
|
||||
efficientnet_b0,True,vision,False,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input"
|
||||
efficientnet_b7,True,vision,False,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"
|
||||
microsoft/MiniLM-L12-H384-uncased,True,hf,True,66M,"nlp;bert-variant;transformer-encoder","Large version has 12 layers; 384 hidden size; Smaller than BERTbase (66M params vs 109M params)"
|
||||
bert-base-uncased,True,hf,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
bert-base-cased,True,hf,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
google/mobilebert-uncased,True,hf,True,25M,"nlp,bert-variant,transformer-encoder,mobile","24 layers, 512 hidden size, 128 embedding"
|
||||
alexnet,False,vision,True,61M,"cnn,parallel-layers","The CNN that revolutionized computer vision (move away from hand-crafted features to neural networks),10 years old now and probably no longer used in prod."
|
||||
resnet18,False,vision,True,11M,"cnn,image-classification,residuals,resnet-variant","1 7x7 conv2d and the rest are 3x3 conv2d"
|
||||
resnet50,False,vision,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
resnet101,False,vision,True,29M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
squeezenet1_0,False,vision,True,1.25M,"cnn,image-classification,mobile,parallel-layers","Parallel conv2d (1x1 conv to compress -> (3x3 expand | 1x1 expand) -> concat)"
|
||||
wide_resnet50_2,False,vision,True,69M,"cnn,image-classification,residuals,resnet-variant","Resnet variant where model depth is decreased and width is increased."
|
||||
mobilenet_v3_small,False,vision,True,2.5M,"image-classification,cnn,mobile",N/A
|
||||
google/vit-base-patch16-224,True,hf_img_cls,False,86M,"image-classification,vision-transformer,transformer-encoder",N/A
|
||||
microsoft/resnet-50,True,hf_img_cls,False,23M,"image-classification,cnn,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
facebook/deit-small-distilled-patch16-224,True,hf_img_cls,False,22M,"image-classification,vision-transformer,cnn",N/A
|
||||
microsoft/beit-base-patch16-224-pt22k-ft22k,True,hf_img_cls,False,86M,"image-classification,transformer-encoder,bert-variant,vision-transformer",N/A
|
||||
nvidia/mit-b0,True,hf_img_cls,False,3.7M,"image-classification,transformer-encoder",SegFormer
|
||||
mnasnet1_0,False,vision,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
|
||||
resnet50_fp16,False,vision,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
bert-base-uncased_fp16,True,fp16,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
|
||||
model_name, use_tracing, model_type, dynamic, mlir_type, decompose, param_count, tags, notes
|
||||
efficientnet_b0,True,vision,False,linalg,False,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input"
|
||||
efficientnet_b7,True,vision,False,linalg,False,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"
|
||||
microsoft/MiniLM-L12-H384-uncased,True,hf,True,linalg,False,66M,"nlp;bert-variant;transformer-encoder","Large version has 12 layers; 384 hidden size; Smaller than BERTbase (66M params vs 109M params)"
|
||||
bert-base-uncased,True,hf,True,linalg,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
bert-base-cased,True,hf,True,linalg,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
google/mobilebert-uncased,True,hf,True,linalg,False,25M,"nlp,bert-variant,transformer-encoder,mobile","24 layers, 512 hidden size, 128 embedding"
|
||||
alexnet,False,vision,True,linalg,False,61M,"cnn,parallel-layers","The CNN that revolutionized computer vision (move away from hand-crafted features to neural networks),10 years old now and probably no longer used in prod."
|
||||
resnet18,False,vision,True,linalg,False,11M,"cnn,image-classification,residuals,resnet-variant","1 7x7 conv2d and the rest are 3x3 conv2d"
|
||||
resnet50,False,vision,True,linalg,False,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
resnet101,False,vision,True,linalg,False,29M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
squeezenet1_0,False,vision,True,linalg,False,1.25M,"cnn,image-classification,mobile,parallel-layers","Parallel conv2d (1x1 conv to compress -> (3x3 expand | 1x1 expand) -> concat)"
|
||||
wide_resnet50_2,False,vision,True,linalg,False,69M,"cnn,image-classification,residuals,resnet-variant","Resnet variant where model depth is decreased and width is increased."
|
||||
mobilenet_v3_small,False,vision,True,linalg,False,2.5M,"image-classification,cnn,mobile",N/A
|
||||
google/vit-base-patch16-224,True,hf_img_cls,False,linalg,False,86M,"image-classification,vision-transformer,transformer-encoder",N/A
|
||||
microsoft/resnet-50,True,hf_img_cls,False,linalg,False,23M,"image-classification,cnn,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
facebook/deit-small-distilled-patch16-224,True,hf_img_cls,False,linalg,False,22M,"image-classification,vision-transformer,cnn",N/A
|
||||
microsoft/beit-base-patch16-224-pt22k-ft22k,True,hf_img_cls,False,linalg,False,86M,"image-classification,transformer-encoder,bert-variant,vision-transformer",N/A
|
||||
nvidia/mit-b0,True,hf_img_cls,False,linalg,False,3.7M,"image-classification,transformer-encoder",SegFormer
|
||||
mnasnet1_0,False,vision,True,linalg,False,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
|
||||
resnet50_fp16,False,vision,True,linalg,False,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
bert-base-uncased_fp16,True,fp16,False,linalg,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
bert-large-uncased,True,hf,True,linalg,False,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
|
||||
bert-base-uncased,True,hf,False,stablehlo,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
gpt2,True,hf_causallm,False,stablehlo,True,125M,"nlp;transformer-encoder","-"
|
||||
facebook/opt-125m,True,hf,False,stablehlo,True,125M,"nlp;transformer-encoder","-"
|
||||
distilgpt2,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
|
||||
|
@@ -1,3 +1,3 @@
{
"version": "2023-03-31_02d52bb"
"version": "nightly"
}