Compare commits

44 Commits

Author SHA1 Message Date
Boian Petkantchin
20114deea0 In MiniLM JAX example verify MLIR result against JAX 2023-05-16 09:54:07 -07:00
Boian Petkantchin
9acf519078 Add option to skip venv creation in setup script 2023-05-16 09:54:07 -07:00
Boian Petkantchin
bdf37b5311 If device/backend is unknown pass it to IREE verbatim 2023-05-16 09:54:07 -07:00
powderluv
8ee2ac89f8 Rename sharded_vicuna_fp32_web.py to vicuna_web.py 2023-05-16 09:41:35 -07:00
powderluv
60cb48be2e Rename sharded_vicuna_fp32.py to vicuna.py 2023-05-16 09:40:51 -07:00
powderluv
86a215b063 Delete sharded_vicunia.py 2023-05-16 09:37:39 -07:00
powderluv
d6e3a9a236 Delete standalone_vicuna.py 2023-05-16 09:37:26 -07:00
Chi_Liu
a0097a1ead Add mlir_type for torch_model_list.csv (#1428)
- Enable stablehlo/tosa mlir output for torch model
- Add BERT stablehlo support
2023-05-15 10:23:54 -07:00
Ean Garvey
a9bae00606 Fix vulkan device selection at compile time and adapt to IREE python changes. (#1407)
* Add support for vulkan device selection at compile time.

* Don't convert device ID to int and fix .exe imports
2023-05-12 23:31:50 -07:00
Daniel Garvey
4731c1a835 prevent loading tokenizer on import (#1432)
also adds sentencepiece dep for exe
moved vicuna imports to after an if statement
in general we should avoid importing files that load whole models as
global variables
2023-05-12 19:11:45 -07:00
Ean Garvey
4c07e47e8c Specify a few models for expected failure on CUDA CI. (#1430) 2023-05-12 17:03:37 -05:00
Gaurav Shukla
e0cc2871bb [SD] Yield 2 tokens at a time in vicuna
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-11 23:49:01 +05:30
Gaurav Shukla
649f39408b [SD] Fix vicuna response
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-11 18:06:21 +05:30
Gaurav Shukla
c142297d73 [SD] Fix gradio to 3.22.0 version
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-11 18:05:55 +05:30
Gaurav Shukla
9e07360b00 [SD] Standalone vicuna with web
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-11 17:23:44 +05:30
Gaurav Shukla
7b74c86e42 [SD] Fix SAMPLE_INPUT_LEN import issue
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-11 15:41:43 +05:30
Eliasj42
fa833f8366 fixed spacing issue with chat-bot (#1417)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-05-10 16:07:50 -07:00
Gaurav Shukla
fcb059aa38 [SD] Integrate vicuna in the web (#1410) 2023-05-10 11:30:22 -07:00
PhaneeshB
517c670f82 vicuna chat cli 2023-05-10 22:55:06 +05:30
Eliasj42
59df14f18b added vicuna demo (#1408)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-05-09 21:18:20 -07:00
Ean Garvey
6c95ac0f37 Revert dialect registration in model annotator (#1406)
Matches https://github.com/nod-ai/SHARK-Runtime/pull/58
2023-05-09 11:50:19 -07:00
Daniel Garvey
7a4a51ae73 vulkan vic f16 (#1404)
Co-authored-by: dan <dan@nod-labs.com>
2023-05-08 16:46:53 -07:00
powderluv
d816cc015e Revert "added standalone vicuna script (#1399)" (#1402)
This reverts commit 0e4a8ca240.
2023-05-05 16:08:05 -07:00
Eliasj42
54ce3d48ca added standalone vicuna script (#1401)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-05-05 18:05:52 -05:00
Eliasj42
0e4a8ca240 added standalone vicuna script (#1399)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-05-05 15:46:05 -07:00
Daniel Garvey
6ca1298675 maximizes window size for webview launch (#1394) 2023-05-04 20:43:06 -07:00
jinchen62
bbef7a6464 Redesign model manager webui (#1391) 2023-05-04 20:41:29 -07:00
Ean Garvey
cdf2d61d53 Remove imports from iree.compiler.transforms from model annotator. (#1392) 2023-05-04 20:40:19 -07:00
Ean Garvey
6c14847d1f xfail some large tests on macOS builder and switch to hash updates. (#1341)
* Update test-models.yml

* Disable large tests on macOS builder
2023-05-04 19:47:03 -05:00
Gaurav Shukla
68ecdd2a73 [SD] Add LoRA as experimental tab
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-04 22:30:25 +05:30
Gaurav Shukla
3f4d444d18 [SD] Fix stable LM chatbot
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-04 22:30:25 +05:30
m68k-fr
e473d0375b [Web] Models folders cleanup (#1365) 2023-05-03 16:13:20 -05:00
Ean Garvey
e38d96850f Fix input image loading in img2img rest API (#1388) 2023-05-03 15:51:00 -05:00
Gaurav Shukla
fed63dfd4b [SD] Add stableLM chatbot (#1383)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-05-03 15:37:20 -05:00
Boian Petkantchin
eba4d06405 In MiniLM JAX example do not hardcode device (#1385)
* In MiniLM JAX example do not hardcode device

* In MiniLM JAX example don't use bytecode MLIR

---------

Co-authored-by: Boian Petkantchin <boian@nod-labs.com>
2023-05-03 10:34:42 -07:00
Boian Petkantchin
4cfba153d2 Add example JAX MiniLM inference (#1380)
* Do not hardcode the name of the VM module in get_iree_module

* Add example JAX MiniLM inference

---------

Co-authored-by: Boian Petkantchin <boian@nod-labs.com>
2023-05-02 15:03:54 -07:00
jinchen62
307c05f38d Convert original vae to diffusers (#1382) 2023-05-02 01:27:28 -07:00
jinchen62
696df349cb Fix curl issue (#1369) 2023-04-28 09:31:14 -07:00
jinchen62
cb54cb1348 Add model manager tab for SD webui (#1368) 2023-04-28 02:43:40 -07:00
Daniel Garvey
9bdb86637d add tkinter launch for webui (#1364) 2023-04-27 19:17:55 -05:00
jinchen62
fb6f26517f Fix webui note (#1367) 2023-04-27 16:14:43 -07:00
Chi_Liu
aa8ada9da9 Add support for torch to stablehlo and tosa in shark_importer (#1360) 2023-04-27 08:09:45 -07:00
powderluv
1db906a373 Revert "Add model manager tab for webui (#1359)" (#1362)
This reverts commit 9d1d1617d8.
2023-04-26 22:25:26 -07:00
jinchen62
9d1d1617d8 Add model manager tab for webui (#1359) 2023-04-26 13:38:18 -07:00
33 changed files with 2503 additions and 310 deletions

@@ -137,7 +137,7 @@ jobs:
export DYLD_LIBRARY_PATH=/usr/local/lib/
echo $PATH
pip list | grep -E "torch|iree"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k vulkan --update_tank
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k vulkan
- name: Validate Vulkan Models (a100)
if: matrix.suite == 'vulkan' && matrix.os == 'a100'

@@ -1,26 +1,27 @@
import torch
import shark
from shark.shark_importer import import_with_fx
from shark.shark_inference import SharkInference
import torch_mlir
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
pipeline,
StoppingCriteria,
StoppingCriteriaList,
TextIteratorStreamer,
)
import torch_mlir
from apps.stable_diffusion.src.utils import (
base_models,
get_opt_flags,
get_vmfb_path_name,
)
from apps.stable_diffusion.src.models.model_wrappers import replace_shape_str
import time
import numpy as np
from torch.nn import functional as F
import os
from threading import Thread
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from typing import List
from io import BytesIO
from pathlib import Path
from shark.shark_downloader import download_public_file
tokenizer = AutoTokenizer.from_pretrained(
"stabilityai/stablelm-tuned-alpha-7b"
)
from shark.shark_inference import SharkInference
from pathlib import Path
class StopOnTokens(StoppingCriteria):
@@ -34,6 +35,203 @@ class StopOnTokens(StoppingCriteria):
return False
def shouldStop(tokens):
stop_ids = [50278, 50279, 50277, 1, 0]
for stop_id in stop_ids:
if tokens[0][-1] == stop_id:
return True
return False
MAX_SEQUENCE_LENGTH = 256
def user(message, history):
# Append the user's message to the conversation history
return "", history + [[message, ""]]
def get_torch_mlir_module_bytecode(model, model_inputs):
fx_g = make_fx(
model,
decomposition_table=get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
]
),
# tracing_mode='symbolic',
)(*model_inputs)
print("Got FX_G")
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
removed_indexes = []
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, (list, tuple)):
node_arg = list(node_arg)
node_args_len = len(node_arg)
for i in range(node_args_len):
curr_index = node_args_len - (i + 1)
if node_arg[curr_index] is None:
removed_indexes.append(curr_index)
node_arg.pop(curr_index)
node.args = (tuple(node_arg),)
break
if len(removed_indexes) > 0:
fx_g.graph.lint()
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
removed_indexes.sort()
return removed_indexes
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""
Replace tuple with tuple element in functions that return one-element tuples.
Returns true if an unwrapping took place, and false otherwise.
"""
unwrapped_tuple = False
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
if len(node_arg) == 1:
node.args = (node_arg[0],)
unwrapped_tuple = True
break
if unwrapped_tuple:
fx_g.graph.lint()
fx_g.recompile()
return unwrapped_tuple
def transform_fx(fx_g):
for node in fx_g.graph.nodes:
if node.op == "call_function":
if node.target in [
torch.ops.aten.empty,
]:
# aten.empty should be filled with zeros.
if node.target in [torch.ops.aten.empty]:
with fx_g.graph.inserting_after(node):
new_node = fx_g.graph.call_function(
torch.ops.aten.zero_,
args=(node,),
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
fx_g.graph.lint()
transform_fx(fx_g)
fx_g.recompile()
removed_none_indexes = _remove_nones(fx_g)
was_unwrapped = _unwrap_single_tuple_return(fx_g)
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
print("FX_G recompile")
def strip_overloads(gm):
"""
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
Args:
gm(fx.GraphModule): The input Fx graph module to be modified
"""
for node in gm.graph.nodes:
if isinstance(node.target, torch._ops.OpOverload):
node.target = node.target.overloadpacket
gm.recompile()
strip_overloads(fx_g)
ts_g = torch.jit.script(fx_g)
print("Got TS_G")
return ts_g
def compile_stableLM(model, model_inputs, model_name, model_vmfb_name):
# ADD Device Arg
from shark.shark_inference import SharkInference
vmfb_path = Path(model_vmfb_name + ".vmfb")
if vmfb_path.exists():
print("Loading ", vmfb_path)
shark_module = SharkInference(
None, device="cuda", mlir_dialect="tm_tensor"
)
shark_module.load_module(vmfb_path)
print("Successfully loaded vmfb")
return shark_module
mlir_path = Path(model_name + ".mlir")
print(
f"[DEBUG] mlir path { mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
)
if mlir_path.exists():
with open(mlir_path, "rb") as f:
bytecode = f.read()
else:
ts_graph = get_torch_mlir_module_bytecode(model, model_inputs)
module = torch_mlir.compile(
ts_graph,
[*model_inputs],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
f_ = open(model_name + ".mlir", "wb")
f_.write(bytecode)
print("Saved mlir")
f_.close()
shark_module = SharkInference(
mlir_module=bytecode, device="cuda", mlir_dialect="tm_tensor"
)
shark_module.compile()
import os
path = shark_module.save_module(os.getcwd(), model_vmfb_name, [])
print("Saved vmfb at ", str(path))
return shark_module
class StableLMModel(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, input_ids, attention_mask):
combine_input_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
output = self.model(**combine_input_dict)
return output.logits
# Initialize a StopOnTokens object
system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
@@ -41,167 +239,65 @@ system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM will refuse to participate in anything that could harm a human.
"""
prompt = f"{system_prompt}<|USER|>What's your mood today?<|ASSISTANT|>"
inputs = tokenizer(prompt, return_tensors="pt")
def get_tokenizer():
model_path = "stabilityai/stablelm-tuned-alpha-3b"
tok = AutoTokenizer.from_pretrained(model_path)
tok.add_special_tokens({"pad_token": "<PAD>"})
print("Successfully loaded the tokenizer into memory")
return tok
class SLM(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = AutoModelForCausalLM.from_pretrained(
"stabilityai/stablelm-tuned-alpha-7b"
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
def generate(
new_text,
streamer,
max_new_tokens,
do_sample,
top_p,
top_k,
temperature,
num_beams,
stopping_criteria,
sharkStableLM,
tok=None,
input_ids=torch.randint(3, (1, 256)),
attention_mask=torch.randint(3, (1, 256)),
):
if tok is None:
tok = get_tokenizer()
# Construct the input message string for the model by concatenating the current system message and conversation history
# Tokenize the messages string
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
words_list = []
for i in range(max_new_tokens):
numWords = len(new_text.split())
# if(numWords>220):
# break
model_inputs = tok(
[new_text],
padding="max_length",
max_length=MAX_SEQUENCE_LENGTH,
truncation=True,
return_tensors="pt",
)
def forward(self, input_ids, attention_mask):
return self.model(input_ids, attention_mask)[0]
slm_model = SLM()
res_pytorch = slm_model(inputs["input_ids"], inputs["attention_mask"])
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from typing import List
fx_g = make_fx(
slm_model,
decomposition_table=get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
]
),
)(inputs["input_ids"], inputs["attention_mask"])
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
removed_indexes = []
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, (list, tuple)):
node_arg = list(node_arg)
node_args_len = len(node_arg)
for i in range(node_args_len):
curr_index = node_args_len - (i + 1)
if node_arg[curr_index] is None:
removed_indexes.append(curr_index)
node_arg.pop(curr_index)
node.args = (tuple(node_arg),)
break
if len(removed_indexes) > 0:
fx_g.graph.lint()
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
removed_indexes.sort()
return removed_indexes
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""
Replace tuple with tuple element in functions that return one-element tuples.
Returns true if an unwrapping took place, and false otherwise.
"""
unwrapped_tuple = False
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
if len(node_arg) == 1:
node.args = (node_arg[0],)
unwrapped_tuple = True
break
if unwrapped_tuple:
fx_g.graph.lint()
fx_g.recompile()
return unwrapped_tuple
def transform_fx(fx_g):
for node in fx_g.graph.nodes:
if node.op == "call_function":
if node.target in [
torch.ops.aten.empty,
]:
# aten.empty should be filled with zeros.
if node.target in [torch.ops.aten.empty]:
with fx_g.graph.inserting_after(node):
new_node = fx_g.graph.call_function(
torch.ops.aten.zero_,
args=(node,),
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
fx_g.graph.lint()
transform_fx(fx_g)
fx_g.recompile()
removed_none_indexes = _remove_nones(fx_g)
was_unwrapped = _unwrap_single_tuple_return(fx_g)
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
def strip_overloads(gm):
"""
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
Args:
gm(fx.GraphModule): The input Fx graph module to be modified
"""
for node in gm.graph.nodes:
if isinstance(node.target, torch._ops.OpOverload):
node.target = node.target.overloadpacket
gm.recompile()
strip_overloads(fx_g)
ts_g = torch.jit.script(fx_g)
module = torch_mlir.compile(
ts_g,
[inputs["input_ids"], inputs["attention_mask"]],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
shark_module = SharkInference(
mlir_module=bytecode, device="cuda", mlir_dialect="tm_tensor"
)
shark_module.compile()
result_shark = shark_module(
"forward", [inputs["input_ids"], inputs["attention_mask"]]
)
print("Result PyTorch")
print(res_pytorch)
print("Result SHARK")
print(result_shark)
sum_attentionmask = torch.sum(model_inputs.attention_mask)
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
output = sharkStableLM(
"forward", [model_inputs.input_ids, model_inputs.attention_mask]
)
output = torch.from_numpy(output)
next_toks = torch.topk(output, 1)
if shouldStop(next_toks.indices):
break
# streamer.put(next_toks.indices[0][int(sum_attentionmask)-1])
new_word = tok.decode(
next_toks.indices[0][int(sum_attentionmask) - 1],
skip_special_tokens=True,
)
print(new_word, end="", flush=True)
words_list.append(new_word)
if new_word == "":
break
new_text = new_text + new_word
return words_list
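
The `generate` loop above pads the prompt to a fixed length, runs the compiled module, and reads the top-1 token at the last attended position (`sum(attention_mask) - 1`). The sketch below isolates that logic with plain Python lists. It assumes right-side padding, and `toy_step`, `greedy_decode`, and `next_token` are hypothetical names introduced only for illustration; the stop ids are copied from `shouldStop` in the diff.

```python
STOP_IDS = [50278, 50279, 50277, 1, 0]  # copied from shouldStop in the diff


def next_token(logits, attention_mask):
    """Return the greedy (arg-max) token id at the last non-padded position.

    logits: one list of vocabulary scores per sequence position
    attention_mask: 1 for real tokens, 0 for padding (padding on the right)
    """
    last_pos = sum(attention_mask) - 1
    scores = logits[last_pos]
    return max(range(len(scores)), key=scores.__getitem__)


def greedy_decode(step_fn, prompt_ids, max_len, max_new_tokens, pad_id=0):
    """Append one greedy token per step until a stop id or a limit is hit."""
    ids = list(prompt_ids)
    generated = []
    for _ in range(max_new_tokens):
        if len(ids) >= max_len:
            break
        pad = max_len - len(ids)
        mask = [1] * len(ids) + [0] * pad
        token = next_token(step_fn(ids + [pad_id] * pad, mask), mask)
        if token in STOP_IDS:
            break
        ids.append(token)
        generated.append(token)
    return generated


def toy_step(padded_ids, mask):
    # Toy "model" over 8 ids: each position votes for (id + 1) % 8.
    return [[1.0 if v == (x + 1) % 8 else 0.0 for v in range(8)]
            for x in padded_ids]
```

With the toy model, `greedy_decode(toy_step, [3], max_len=16, max_new_tokens=10)` yields `[4, 5, 6, 7]` and then stops, because the next prediction is id 0, which is in `STOP_IDS`.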

@@ -0,0 +1,656 @@
import sys
import warnings
warnings.filterwarnings("ignore")
sys.path.insert(0, r"D:\S\SB\I\python_packages\iree_compiler")
sys.path.insert(0, r"D:\S\SB\I\python_packages\iree_runtime")
import torch
import torch_mlir
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from typing import List
from io import BytesIO
from pathlib import Path
from shark.shark_downloader import download_public_file
from shark.shark_importer import transform_fx as transform_fx_
import re
from shark.shark_inference import SharkInference
from tqdm import tqdm
from torch_mlir import TensorPlaceholder
class FirstVicunaLayer(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, hidden_states, attention_mask, position_ids):
outputs = self.model(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=True,
)
next_hidden_states = outputs[0]
past_key_value_out0, past_key_value_out1 = (
outputs[-1][0],
outputs[-1][1],
)
return (
next_hidden_states,
past_key_value_out0,
past_key_value_out1,
)
class SecondVicunaLayer(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value0,
past_key_value1,
):
outputs = self.model(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=(
past_key_value0,
past_key_value1,
),
use_cache=True,
)
next_hidden_states = outputs[0]
past_key_value_out0, past_key_value_out1 = (
outputs[-1][0],
outputs[-1][1],
)
return (
next_hidden_states,
past_key_value_out0,
past_key_value_out1,
)
class CompiledFirstVicunaLayer(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value=None,
output_attentions=False,
use_cache=True,
):
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
output = self.model(
"forward",
(
hidden_states,
attention_mask,
position_ids,
),
)
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
return (
output0,
(
output1,
output2,
),
)
class CompiledSecondVicunaLayer(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions=False,
use_cache=True,
):
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
pkv0 = past_key_value[0].detach()
pkv1 = past_key_value[1].detach()
output = self.model(
"forward",
(
hidden_states,
attention_mask,
position_ids,
pkv0,
pkv1,
),
)
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
return (
output0,
(
output1,
output2,
),
)
class ShardedVicunaModel(torch.nn.Module):
def __init__(self, model, layers0, layers1):
super().__init__()
self.model = model
assert len(layers0) == len(model.model.layers)
# self.model.model.layers = torch.nn.modules.container.ModuleList(layers0)
self.model.model.config.use_cache = True
self.model.model.config.output_attentions = False
self.layers0 = layers0
self.layers1 = layers1
def forward(
self,
input_ids,
is_first=True,
past_key_values=None,
attention_mask=None,
):
if is_first:
self.model.model.layers = torch.nn.modules.container.ModuleList(
self.layers0
)
return self.model.forward(input_ids, attention_mask=attention_mask)
else:
self.model.model.layers = torch.nn.modules.container.ModuleList(
self.layers1
)
return self.model.forward(
input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
)
def write_in_dynamic_inputs0(module, dynamic_input_size):
new_lines = []
for line in module.splitlines():
line = re.sub(f"{dynamic_input_size}x", "?x", line)
if "?x" in line:
line = re.sub(r"tensor.empty\(\)", "tensor.empty(%dim)", line)
line = re.sub(f" {dynamic_input_size},", " %dim,", line)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
r"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
)
if "arith.cmpi" in line:
line = re.sub(f"c{dynamic_input_size}", "dim", line)
new_lines.append(line)
new_module = "\n".join(new_lines)
return new_module
def write_in_dynamic_inputs1(module, dynamic_input_size):
new_lines = []
for line in module.splitlines():
if "dim_42 =" in line:
continue
if f"%c{dynamic_input_size}_i64 =" in line:
new_lines.append(
"%dim_42 = tensor.dim %arg1, %c3 : tensor<1x1x1x?xf32>"
)
new_lines.append(
f"%dim_42_i64 = arith.index_cast %dim_42 : index to i64"
)
continue
line = re.sub(f"{dynamic_input_size}x", "?x", line)
if "?x" in line:
line = re.sub(r"tensor.empty\(\)", "tensor.empty(%dim_42)", line)
line = re.sub(f" {dynamic_input_size},", " %dim_42,", line)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
r"tensor.empty\(%dim_42\)",
"tensor.empty(%dim_42, %dim_42)",
line,
)
if "arith.cmpi" in line:
line = re.sub(f"c{dynamic_input_size}", "dim_42", line)
new_lines.append(line)
new_module = "\n".join(new_lines)
return new_module
def compile_vicuna_layer(
vicuna_layer,
hidden_states,
attention_mask,
position_ids,
past_key_value0=None,
past_key_value1=None,
):
hidden_states_placeholder = TensorPlaceholder.like(
hidden_states, dynamic_axes=[1]
)
attention_mask_placeholder = TensorPlaceholder.like(
attention_mask, dynamic_axes=[2, 3]
)
position_ids_placeholder = TensorPlaceholder.like(
position_ids, dynamic_axes=[1]
)
if past_key_value0 is None and past_key_value1 is None:
fx_g = make_fx(
vicuna_layer,
decomposition_table=get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
]
),
)(hidden_states, attention_mask, position_ids)
else:
fx_g = make_fx(
vicuna_layer,
decomposition_table=get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
]
),
)(
hidden_states,
attention_mask,
position_ids,
past_key_value0,
past_key_value1,
)
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
removed_indexes = []
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, (list, tuple)):
node_arg = list(node_arg)
node_args_len = len(node_arg)
for i in range(node_args_len):
curr_index = node_args_len - (i + 1)
if node_arg[curr_index] is None:
removed_indexes.append(curr_index)
node_arg.pop(curr_index)
node.args = (tuple(node_arg),)
break
if len(removed_indexes) > 0:
fx_g.graph.lint()
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
removed_indexes.sort()
return removed_indexes
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""
Replace tuple with tuple element in functions that return one-element tuples.
Returns true if an unwrapping took place, and false otherwise.
"""
unwrapped_tuple = False
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
if len(node_arg) == 1:
node.args = (node_arg[0],)
unwrapped_tuple = True
break
if unwrapped_tuple:
fx_g.graph.lint()
fx_g.recompile()
return unwrapped_tuple
def transform_fx(fx_g):
for node in fx_g.graph.nodes:
if node.op == "call_function":
if node.target in [
torch.ops.aten.empty,
]:
# aten.empty should be filled with zeros.
if node.target in [torch.ops.aten.empty]:
with fx_g.graph.inserting_after(node):
new_node = fx_g.graph.call_function(
torch.ops.aten.zero_,
args=(node,),
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
fx_g.graph.lint()
transform_fx(fx_g)
fx_g.recompile()
removed_none_indexes = _remove_nones(fx_g)
was_unwrapped = _unwrap_single_tuple_return(fx_g)
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
print("FX_G recompile")
def strip_overloads(gm):
"""
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
Args:
gm(fx.GraphModule): The input Fx graph module to be modified
"""
for node in gm.graph.nodes:
if isinstance(node.target, torch._ops.OpOverload):
node.target = node.target.overloadpacket
gm.recompile()
strip_overloads(fx_g)
ts_g = torch.jit.script(fx_g)
return ts_g
def get_model_and_tokenizer(path="TheBloke/vicuna-7B-1.1-HF"):
kwargs = {"torch_dtype": torch.float}
vicuna_model = AutoModelForCausalLM.from_pretrained(path, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
return vicuna_model, tokenizer
def compile_to_vmfb(inputs, layers, is_first=True):
mlirs, modules = [], []
for idx, layer in tqdm(enumerate(layers), desc="Getting mlirs"):
if is_first:
mlir_path = Path(f"{idx}_0.mlir")
vmfb_path = Path(f"{idx}_0.vmfb")
else:
mlir_path = Path(f"{idx}_1.mlir")
vmfb_path = Path(f"{idx}_1.vmfb")
if vmfb_path.exists():
continue
if mlir_path.exists():
# print(f"Found layer {idx} mlir")
f_ = open(mlir_path, "rb")
bytecode = f_.read()
f_.close()
else:
hidden_states_placeholder = TensorPlaceholder.like(
inputs[0], dynamic_axes=[1]
)
attention_mask_placeholder = TensorPlaceholder.like(
inputs[1], dynamic_axes=[3]
)
position_ids_placeholder = TensorPlaceholder.like(
inputs[2], dynamic_axes=[1]
)
if not is_first:
pkv0_placeholder = TensorPlaceholder.like(
inputs[3], dynamic_axes=[2]
)
pkv1_placeholder = TensorPlaceholder.like(
inputs[4], dynamic_axes=[2]
)
print(f"Compiling layer {idx} mlir")
if is_first:
ts_g = compile_vicuna_layer(
layer, inputs[0], inputs[1], inputs[2]
)
module = torch_mlir.compile(
ts_g,
(
hidden_states_placeholder,
inputs[1],
inputs[2],
),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
else:
ts_g = compile_vicuna_layer(
layer,
inputs[0],
inputs[1],
inputs[2],
inputs[3],
inputs[4],
)
module = torch_mlir.compile(
ts_g,
(
inputs[0],
attention_mask_placeholder,
inputs[2],
pkv0_placeholder,
pkv1_placeholder,
),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
# bytecode_stream = BytesIO()
# module.operation.write_bytecode(bytecode_stream)
# bytecode = bytecode_stream.getvalue()
if is_first:
module = write_in_dynamic_inputs0(str(module), 137)
bytecode = module.encode("UTF-8")
bytecode_stream = BytesIO(bytecode)
bytecode = bytecode_stream.read()
else:
module = write_in_dynamic_inputs1(str(module), 138)
if idx in [0, 5, 6, 7]:
module_str = module
module_str = module_str.splitlines()
new_lines = []
for line in module_str:
if len(line) < 1000:
new_lines.append(line)
else:
new_lines.append(line[:999])
module_str = "\n".join(new_lines)
f1_ = open(f"{idx}_1_test.mlir", "w+")
f1_.write(module_str)
f1_.close()
bytecode = module.encode("UTF-8")
bytecode_stream = BytesIO(bytecode)
bytecode = bytecode_stream.read()
f_ = open(mlir_path, "wb")
f_.write(bytecode)
f_.close()
mlirs.append(bytecode)
for idx, layer in tqdm(enumerate(layers), desc="compiling modules"):
if is_first:
vmfb_path = Path(f"{idx}_0.vmfb")
if idx < 25:
device = "cpu"
else:
device = "cpu"
if vmfb_path.exists():
# print(f"Found layer {idx} vmfb")
module = SharkInference(
None, device=device, mlir_dialect="tm_tensor"
)
module.load_module(vmfb_path)
else:
print(f"Compiling layer {idx} vmfb")
module = SharkInference(
mlirs[idx], device=device, mlir_dialect="tm_tensor"
)
module.save_module(
module_name=f"{idx}_0",
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
module.load_module(vmfb_path)
modules.append(module)
else:
vmfb_path = Path(f"{idx}_1.vmfb")
if idx < 25:
device = "vulkan"
else:
device = "cpu"
if vmfb_path.exists():
# print(f"Found layer {idx} vmfb")
module = SharkInference(
None, device=device, mlir_dialect="tm_tensor"
)
module.load_module(vmfb_path)
else:
print(f"Compiling layer {idx} vmfb")
module = SharkInference(
mlirs[idx], device=device, mlir_dialect="tm_tensor"
)
module.save_module(
module_name=f"{idx}_1",
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
module.load_module(vmfb_path)
modules.append(module)
return mlirs, modules
def get_sharded_model():
# SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an incredibly hacky process
# please don't change it
SAMPLE_INPUT_LEN = 137
vicuna_model = get_model_and_tokenizer()[0]
placeholder_input0 = (
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]),
torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64),
)
placeholder_input1 = (
torch.zeros([1, 1, 4096]),
torch.zeros([1, 1, 1, SAMPLE_INPUT_LEN + 1]),
torch.zeros([1, 1], dtype=torch.int64),
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
)
layers0 = [FirstVicunaLayer(layer) for layer in vicuna_model.model.layers]
_, modules0 = compile_to_vmfb(placeholder_input0, layers0, is_first=True)
shark_layers0 = [CompiledFirstVicunaLayer(m) for m in modules0]
layers1 = [SecondVicunaLayer(layer) for layer in vicuna_model.model.layers]
_, modules1 = compile_to_vmfb(placeholder_input1, layers1, is_first=False)
shark_layers1 = [CompiledSecondVicunaLayer(m) for m in modules1]
sharded_model = ShardedVicunaModel(
vicuna_model, shark_layers0, shark_layers1
)
return sharded_model
if __name__ == "__main__":
prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
prologue_prompt = "ASSISTANT:\n"
sharded_model = get_sharded_model()
tokenizer = get_model_and_tokenizer()[1]
past_key_values = None
while True:
print("\n\n")
user_prompt = input("User: ")
prompt_history = (
prompt_history + "USER:\n" + user_prompt + prologue_prompt
)
prompt = prompt_history.strip()
input_ids = tokenizer(prompt).input_ids
tokens = input_ids
print("Robot:", end=" ")
new_sentence = []
max_response_len = 1000
for iteration in range(max_response_len):
original_input_ids = input_ids
input_id_len = len(input_ids)
input_ids = torch.tensor(input_ids)
input_ids = input_ids.reshape([1, input_id_len])
if iteration == 0:
output = sharded_model.forward(input_ids, is_first=True)
else:
output = sharded_model.forward(
input_ids, past_key_values=past_key_values, is_first=False
)
logits = output["logits"]
past_key_values = output["past_key_values"]
new_token = int(torch.argmax(logits[:, -1, :], dim=1)[0])
if new_token == 2:
break
new_sentence += [new_token]
tokens.append(new_token)
original_input_ids.append(new_token)
input_ids = [new_token]
for i in range(len(tokens)):
if not isinstance(tokens[i], int):
tokens[i] = int(tokens[i][0])
new_sentence_str = tokenizer.decode(new_sentence)
print(new_sentence_str)
prompt_history += f"\n{new_sentence_str}\n"
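The loop above is a greedy decoding loop: the full prompt is fed once, then only the newest token is fed back together with the cached key/value state, and generation stops on token id 2. A minimal, dependency-free sketch of that control flow follows. `greedy_decode`, `toy_step`, and the `StepFn` signature are hypothetical stand-ins for illustration, not part of the SHARK code, and the toy model simply counts upward.

```python
from typing import Callable, List, Optional, Tuple

EOS_TOKEN = 2  # the loop above breaks when the model emits token id 2

# Hypothetical stand-in for sharded_model.forward: takes the new input ids and
# the cache from the previous step, returns (logits for the last position, cache).
StepFn = Callable[[List[int], Optional[list]], Tuple[List[float], list]]


def greedy_decode(
    step: StepFn, prompt_ids: List[int], max_new_tokens: int = 1000
) -> List[int]:
    """Feed the prompt once, then feed back one token at a time."""
    generated: List[int] = []
    cache = None
    input_ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        logits, cache = step(input_ids, cache)
        # Greedy choice: index of the largest logit.
        new_token = max(range(len(logits)), key=logits.__getitem__)
        if new_token == EOS_TOKEN:
            break
        generated.append(new_token)
        # After the first step only the newest token is passed in.
        input_ids = [new_token]
    return generated


def toy_step(input_ids: List[int], cache: Optional[list]):
    """Toy model: predicts last token + 1, and EOS once the last token is >= 6."""
    seen = (cache or []) + list(input_ids)
    last = seen[-1]
    next_token = EOS_TOKEN if last >= 6 else last + 1
    logits = [0.0] * 8
    logits[next_token] = 1.0
    return logits, seen


print(greedy_decode(toy_step, [3]))
```

Keeping the cache as an opaque value returned by the step function mirrors how `past_key_values` is threaded through the loop without being inspected.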

View File

@@ -0,0 +1,777 @@
import sys
import warnings
import gradio as gr
import time
warnings.filterwarnings("ignore")
sys.path.insert(0, r"D:\S\SB\I\python_packages\iree_compiler")
sys.path.insert(0, r"D:\S\SB\I\python_packages\iree_runtime")
import torch
import torch_mlir
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from typing import List
from io import BytesIO
from pathlib import Path
from shark.shark_downloader import download_public_file
from shark.shark_importer import transform_fx as transform_fx_
import re
from shark.shark_inference import SharkInference
from tqdm import tqdm
from torch_mlir import TensorPlaceholder
from apps.stable_diffusion.web.ui.utils import available_devices
class FirstVicunaLayer(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, hidden_states, attention_mask, position_ids):
outputs = self.model(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=True,
)
next_hidden_states = outputs[0]
past_key_value_out0, past_key_value_out1 = (
outputs[-1][0],
outputs[-1][1],
)
return (
next_hidden_states,
past_key_value_out0,
past_key_value_out1,
)
class SecondVicunaLayer(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value0,
past_key_value1,
):
outputs = self.model(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=(
past_key_value0,
past_key_value1,
),
use_cache=True,
)
next_hidden_states = outputs[0]
past_key_value_out0, past_key_value_out1 = (
outputs[-1][0],
outputs[-1][1],
)
return (
next_hidden_states,
past_key_value_out0,
past_key_value_out1,
)
class CompiledFirstVicunaLayer(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value=None,
output_attentions=False,
use_cache=True,
):
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
output = self.model(
"forward",
(
hidden_states,
attention_mask,
position_ids,
),
)
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
return (
output0,
(
output1,
output2,
),
)
class CompiledSecondVicunaLayer(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions=False,
use_cache=True,
):
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
pkv0 = past_key_value[0].detach()
pkv1 = past_key_value[1].detach()
output = self.model(
"forward",
(
hidden_states,
attention_mask,
position_ids,
pkv0,
pkv1,
),
)
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
return (
output0,
(
output1,
output2,
),
)
class ShardedVicunaModel(torch.nn.Module):
def __init__(self, model, layers0, layers1):
super().__init__()
self.model = model
assert len(layers0) == len(model.model.layers)
# self.model.model.layers = torch.nn.modules.container.ModuleList(layers0)
self.model.model.config.use_cache = True
self.model.model.config.output_attentions = False
self.layers0 = layers0
self.layers1 = layers1
def forward(
self,
input_ids,
is_first=True,
past_key_values=None,
attention_mask=None,
):
if is_first:
self.model.model.layers = torch.nn.modules.container.ModuleList(
self.layers0
)
return self.model.forward(input_ids, attention_mask=attention_mask)
else:
self.model.model.layers = torch.nn.modules.container.ModuleList(
self.layers1
)
return self.model.forward(
input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
)
def write_in_dynamic_inputs0(module, dynamic_input_size):
new_lines = []
for line in module.splitlines():
line = re.sub(f"{dynamic_input_size}x", "?x", line)
if "?x" in line:
line = re.sub(r"tensor.empty\(\)", "tensor.empty(%dim)", line)
line = re.sub(f" {dynamic_input_size},", " %dim,", line)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
r"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
)
if "arith.cmpi" in line:
line = re.sub(f"c{dynamic_input_size}", "dim", line)
new_lines.append(line)
new_module = "\n".join(new_lines)
return new_module
def write_in_dynamic_inputs1(module, dynamic_input_size):
new_lines = []
for line in module.splitlines():
if "dim_42 =" in line:
continue
if f"%c{dynamic_input_size}_i64 =" in line:
new_lines.append(
"%dim_42 = tensor.dim %arg1, %c3 : tensor<1x1x1x?xf32>"
)
new_lines.append(
"%dim_42_i64 = arith.index_cast %dim_42 : index to i64"
)
continue
line = re.sub(f"{dynamic_input_size}x", "?x", line)
if "?x" in line:
line = re.sub(r"tensor.empty\(\)", "tensor.empty(%dim_42)", line)
line = re.sub(f" {dynamic_input_size},", " %dim_42,", line)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
r"tensor.empty\(%dim_42\)",
"tensor.empty(%dim_42, %dim_42)",
line,
)
if "arith.cmpi" in line:
line = re.sub(f"c{dynamic_input_size}", "dim_42", line)
new_lines.append(line)
new_module = "\n".join(new_lines)
return new_module
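The two `write_in_dynamic_inputs*` helpers above patch the textual MLIR line by line: the sample length baked into tensor types is replaced with `?`, and `tensor.empty()` is given a dimension operand. The sketch below applies just those two substitutions to a single line so the effect is easy to see. `make_dim_dynamic` and the sample line are illustrative assumptions, not code from the repository.

```python
import re


def make_dim_dynamic(line: str, static_size: int, dim_name: str = "%dim") -> str:
    """Replace a static size in tensor types with '?' and patch tensor.empty()."""
    # "137x" -> "?x" inside types such as tensor<1x137x4096xf32>.
    # Note: this is a plain substring match, so "1137x" would also be rewritten.
    line = re.sub(f"{static_size}x", "?x", line)
    if "?x" in line:
        # A dynamically shaped tensor.empty needs the size passed as an operand.
        line = re.sub(r"tensor\.empty\(\)", f"tensor.empty({dim_name})", line)
    return line


sample = "%0 = tensor.empty() : tensor<1x137x4096xf32>"
print(make_dim_dynamic(sample, 137))
```

Because the match is purely textual, the sample length has to be a value that does not otherwise occur in the module, which is presumably why the source asks that `SAMPLE_INPUT_LEN` not be changed.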
def compile_vicuna_layer(
vicuna_layer,
hidden_states,
attention_mask,
position_ids,
past_key_value0=None,
past_key_value1=None,
):
hidden_states_placeholder = TensorPlaceholder.like(
hidden_states, dynamic_axes=[1]
)
attention_mask_placeholder = TensorPlaceholder.like(
attention_mask, dynamic_axes=[2, 3]
)
position_ids_placeholder = TensorPlaceholder.like(
position_ids, dynamic_axes=[1]
)
if past_key_value0 is None and past_key_value1 is None:
fx_g = make_fx(
vicuna_layer,
decomposition_table=get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
]
),
)(hidden_states, attention_mask, position_ids)
else:
fx_g = make_fx(
vicuna_layer,
decomposition_table=get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
]
),
)(
hidden_states,
attention_mask,
position_ids,
past_key_value0,
past_key_value1,
)
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
removed_indexes = []
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, (list, tuple)):
node_arg = list(node_arg)
node_args_len = len(node_arg)
for i in range(node_args_len):
curr_index = node_args_len - (i + 1)
if node_arg[curr_index] is None:
removed_indexes.append(curr_index)
node_arg.pop(curr_index)
node.args = (tuple(node_arg),)
break
if len(removed_indexes) > 0:
fx_g.graph.lint()
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
removed_indexes.sort()
return removed_indexes
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""
Replace tuple with tuple element in functions that return one-element tuples.
Returns true if an unwrapping took place, and false otherwise.
"""
unwrapped_tuple = False
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
if len(node_arg) == 1:
node.args = (node_arg[0],)
unwrapped_tuple = True
break
if unwrapped_tuple:
fx_g.graph.lint()
fx_g.recompile()
return unwrapped_tuple
def transform_fx(fx_g):
for node in fx_g.graph.nodes:
if node.op == "call_function":
if node.target in [
torch.ops.aten.empty,
]:
# aten.empty should be filled with zeros.
if node.target in [torch.ops.aten.empty]:
with fx_g.graph.inserting_after(node):
new_node = fx_g.graph.call_function(
torch.ops.aten.zero_,
args=(node,),
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
fx_g.graph.lint()
transform_fx(fx_g)
fx_g.recompile()
removed_none_indexes = _remove_nones(fx_g)
was_unwrapped = _unwrap_single_tuple_return(fx_g)
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
print("FX_G recompile")
def strip_overloads(gm):
"""
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
Args:
gm(fx.GraphModule): The input Fx graph module to be modified
"""
for node in gm.graph.nodes:
if isinstance(node.target, torch._ops.OpOverload):
node.target = node.target.overloadpacket
gm.recompile()
strip_overloads(fx_g)
ts_g = torch.jit.script(fx_g)
return ts_g
path = "TheBloke/vicuna-7B-1.1-HF"
kwargs = {"torch_dtype": torch.float}
vicuna_model = AutoModelForCausalLM.from_pretrained(path, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
def compile_to_vmfb(inputs, layers, is_first=True):
mlirs, modules = [], []
for idx, layer in tqdm(enumerate(layers), desc="Getting mlirs"):
if is_first:
mlir_path = Path(f"{idx}_0.mlir")
vmfb_path = Path(f"{idx}_0.vmfb")
else:
mlir_path = Path(f"{idx}_1.mlir")
vmfb_path = Path(f"{idx}_1.vmfb")
if vmfb_path.exists():
continue
if mlir_path.exists():
# print(f"Found layer {idx} mlir")
f_ = open(mlir_path, "rb")
bytecode = f_.read()
f_.close()
else:
hidden_states_placeholder = TensorPlaceholder.like(
inputs[0], dynamic_axes=[1]
)
attention_mask_placeholder = TensorPlaceholder.like(
inputs[1], dynamic_axes=[3]
)
position_ids_placeholder = TensorPlaceholder.like(
inputs[2], dynamic_axes=[1]
)
if not is_first:
pkv0_placeholder = TensorPlaceholder.like(
inputs[3], dynamic_axes=[2]
)
pkv1_placeholder = TensorPlaceholder.like(
inputs[4], dynamic_axes=[2]
)
print(f"Compiling layer {idx} mlir")
if is_first:
ts_g = compile_vicuna_layer(
layer, inputs[0], inputs[1], inputs[2]
)
module = torch_mlir.compile(
ts_g,
(
hidden_states_placeholder,
inputs[1],
inputs[2],
),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
else:
ts_g = compile_vicuna_layer(
layer,
inputs[0],
inputs[1],
inputs[2],
inputs[3],
inputs[4],
)
module = torch_mlir.compile(
ts_g,
(
inputs[0],
attention_mask_placeholder,
inputs[2],
pkv0_placeholder,
pkv1_placeholder,
),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
# bytecode_stream = BytesIO()
# module.operation.write_bytecode(bytecode_stream)
# bytecode = bytecode_stream.getvalue()
if is_first:
module = write_in_dynamic_inputs0(str(module), 137)
bytecode = module.encode("UTF-8")
bytecode_stream = BytesIO(bytecode)
bytecode = bytecode_stream.read()
else:
module = write_in_dynamic_inputs1(str(module), 138)
if idx in [0, 5, 6, 7]:
module_str = module
module_str = module_str.splitlines()
new_lines = []
for line in module_str:
if len(line) < 1000:
new_lines.append(line)
else:
new_lines.append(line[:999])
module_str = "\n".join(new_lines)
f1_ = open(f"{idx}_1_test.mlir", "w+")
f1_.write(module_str)
f1_.close()
bytecode = module.encode("UTF-8")
bytecode_stream = BytesIO(bytecode)
bytecode = bytecode_stream.read()
f_ = open(mlir_path, "wb")
f_.write(bytecode)
f_.close()
mlirs.append(bytecode)
for idx, layer in tqdm(enumerate(layers), desc="compiling modules"):
if is_first:
vmfb_path = Path(f"{idx}_0.vmfb")
if idx < 25:
device = "cpu"
else:
device = "cpu"
if vmfb_path.exists():
# print(f"Found layer {idx} vmfb")
module = SharkInference(
None, device=device, mlir_dialect="tm_tensor"
)
module.load_module(vmfb_path)
else:
print(f"Compiling layer {idx} vmfb")
module = SharkInference(
mlirs[idx], device=device, mlir_dialect="tm_tensor"
)
module.save_module(
module_name=f"{idx}_0",
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
module.load_module(vmfb_path)
modules.append(module)
else:
vmfb_path = Path(f"{idx}_1.vmfb")
if idx < 25:
device = "vulkan"
else:
device = "cpu"
if vmfb_path.exists():
# print(f"Found layer {idx} vmfb")
module = SharkInference(
None, device=device, mlir_dialect="tm_tensor"
)
module.load_module(vmfb_path)
else:
print(f"Compiling layer {idx} vmfb")
module = SharkInference(
mlirs[idx], device=device, mlir_dialect="tm_tensor"
)
module.save_module(
module_name=f"{idx}_1",
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
module.load_module(vmfb_path)
modules.append(module)
return mlirs, modules
def get_sharded_model():
# SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an incredibly hacky process
# please don't change it
SAMPLE_INPUT_LEN = 137
global vicuna_model
placeholder_input0 = (
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]),
torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64),
)
placeholder_input1 = (
torch.zeros([1, 1, 4096]),
torch.zeros([1, 1, 1, SAMPLE_INPUT_LEN + 1]),
torch.zeros([1, 1], dtype=torch.int64),
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
)
layers0 = [FirstVicunaLayer(layer) for layer in vicuna_model.model.layers]
_, modules0 = compile_to_vmfb(placeholder_input0, layers0, is_first=True)
shark_layers0 = [CompiledFirstVicunaLayer(m) for m in modules0]
layers1 = [SecondVicunaLayer(layer) for layer in vicuna_model.model.layers]
_, modules1 = compile_to_vmfb(placeholder_input1, layers1, is_first=False)
shark_layers1 = [CompiledSecondVicunaLayer(m) for m in modules1]
sharded_model = ShardedVicunaModel(
vicuna_model, shark_layers0, shark_layers1
)
return sharded_model
sharded_model = get_sharded_model()
def user(message, history):
print("msg=", message)
print("history=", history)
# Append the user's message to the conversation history
return "", history + [[message, ""]]
def chat(curr_system_message, history):
global sharded_model
past_key_values = None
messages = curr_system_message + "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
print(messages)
prompt = messages.strip()
input_ids = tokenizer(prompt).input_ids
tokens = input_ids
new_sentence = []
max_response_len = 1000
partial_sentence = []
partial_text = ""
start_time = time.time()
for iteration in range(max_response_len):
original_input_ids = input_ids
input_id_len = len(input_ids)
input_ids = torch.tensor(input_ids)
input_ids = input_ids.reshape([1, input_id_len])
if iteration == 0:
output = sharded_model.forward(input_ids, is_first=True)
else:
output = sharded_model.forward(
input_ids, past_key_values=past_key_values, is_first=False
)
logits = output["logits"]
past_key_values = output["past_key_values"]
new_token = int(torch.argmax(logits[:, -1, :], dim=1)[0])
if new_token == 2:
break
new_sentence += [new_token]
partial_sentence += [new_token]
if iteration > 0 and iteration % 2 == 0:
new_text = tokenizer.decode(partial_sentence)
partial_sentence = []
print(new_text, " ")
partial_text += new_text + " "
history[-1][1] = partial_text
yield history
tokens.append(new_token)
original_input_ids.append(new_token)
input_ids = [new_token]
end_time = time.time()
print(
f"Total time taken to generate the response is {end_time-start_time} seconds"
)
for i in range(len(tokens)):
if not isinstance(tokens[i], int):
tokens[i] = int(tokens[i][0])
new_sentence_str = tokenizer.decode(new_sentence)
print(new_sentence_str)
history[-1][1] = new_sentence_str
yield history
system_msg = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
# history_eg = [["hi hello how are you", ""]]
# print(chat(system_msg, history_eg))
with gr.Blocks(title="Chatbot") as vicuna_chat:
with gr.Row():
model = gr.Dropdown(
label="Select Model",
value="TheBloke/vicuna-7B-1.1-HF",
choices=[
"TheBloke/vicuna-7B-1.1-HF",
],
)
device_value = None
for d in available_devices:
if "vulkan" in d:
device_value = d
break
device = gr.Dropdown(
label="Device",
value=device_value if device_value else available_devices[0],
interactive=False,
choices=available_devices,
)
chatbot = gr.Chatbot().style(height=500)
with gr.Row():
with gr.Column():
msg = gr.Textbox(
label="Chat Message Box",
placeholder="Chat Message Box",
show_label=False,
).style(container=False)
with gr.Column():
with gr.Row():
submit = gr.Button("Submit")
stop = gr.Button("Stop")
clear = gr.Button("Clear")
system_msg = gr.Textbox(
system_msg, label="System Message", interactive=False, visible=False
)
submit_event = msg.submit(
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
inputs=[system_msg, chatbot],
outputs=[chatbot],
queue=True,
)
submit_click_event = submit.click(
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
inputs=[system_msg, chatbot],
outputs=[chatbot],
queue=True,
)
stop.click(
fn=None,
inputs=None,
outputs=None,
cancels=[submit_event, submit_click_event],
queue=False,
)
clear.click(lambda: None, None, [chatbot], queue=False)
import argparse
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
p.add_argument(
"--share",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for generating a public URL",
)
p.add_argument(
"--server_port",
type=int,
default=8080,
help="flag for setting server port",
)
args, unknown = p.parse_known_args()
vicuna_chat.queue()
vicuna_chat.launch(
share=args.share,
inbrowser=True,
server_name="0.0.0.0",
server_port=args.server_port,
)
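The `chat` handler above is a generator: it yields the updated history every two tokens so the UI can show a partial response while decoding continues. A minimal sketch of that streaming pattern, without Gradio or a tokenizer, is given below. `stream_text` and the toy `decode` function are hypothetical. Unlike the handler above, this variant re-decodes the whole token list on each flush instead of joining separately decoded chunks with spaces.

```python
from typing import Callable, Iterable, Iterator, List


def stream_text(
    token_ids: Iterable[int],
    decode: Callable[[List[int]], str],
    chunk_size: int = 2,
) -> Iterator[str]:
    """Yield the decoded text so far after every `chunk_size` new tokens."""
    all_tokens: List[int] = []
    pending = 0
    for token in token_ids:
        all_tokens.append(token)
        pending += 1
        if pending == chunk_size:
            pending = 0
            yield decode(all_tokens)
    # Flush whatever is left so the final text is always emitted.
    if pending:
        yield decode(all_tokens)


def toy_decode(ids: List[int]) -> str:
    """Toy decoder: maps 0 -> 'a', 1 -> 'b', and so on."""
    return "".join(chr(ord("a") + i) for i in ids)


for partial in stream_text([0, 1, 2, 3, 4], toy_decode):
    print(partial)
```

The trailing flush matters in a generator-based handler: a value passed to `return` is not delivered to a consumer that only iterates, so the last update has to be yielded.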

View File

@@ -29,6 +29,9 @@ datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
datas += collect_data_files('tkinter')
datas += collect_data_files('webview')
datas += collect_data_files('sentencepiece')
datas += [
( 'src/utils/resources/prompts.json', 'resources' ),
( 'src/utils/resources/model_db.json', 'resources' ),

View File

@@ -14,6 +14,7 @@ from apps.stable_diffusion.src.utils import (
base_models,
args,
preprocessCKPT,
convert_original_vae,
get_path_to_diffusers_checkpoint,
fetch_and_update_base_model_id,
get_path_stem,
@@ -571,8 +572,12 @@ class SharkifyStableDiffusionModel:
vae_checkpoint = safetensors.torch.load_file(self.custom_vae, device="cpu")
if "state_dict" in vae_checkpoint:
vae_checkpoint = vae_checkpoint["state_dict"]
vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
return vae_dict
try:
vae_checkpoint = convert_original_vae(vae_checkpoint)
finally:
vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
return vae_dict
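The indentation of this hunk is lost, but if the `return` sits inside the `finally` block, any exception raised by `convert_original_vae` is silently discarded and the unconverted checkpoint is filtered and returned instead. A sketch that makes this fallback explicit is shown below. `convert_or_fallback` and its `convert` parameter are hypothetical names, and the `vae_ignore_keys` filter from the source is omitted for brevity.

```python
from typing import Callable, Dict


def convert_or_fallback(
    checkpoint: Dict[str, object],
    convert: Callable[[Dict[str, object]], Dict[str, object]],
) -> Dict[str, object]:
    """Try to convert the checkpoint; on failure keep it as-is, then filter keys."""
    try:
        checkpoint = convert(checkpoint)
    except Exception as exc:  # assumption: any conversion error means "already converted"
        print(f"VAE conversion skipped, using checkpoint as-is: {exc}")
    # Same filter as the source: drop keys whose first four characters are "loss".
    return {k: v for k, v in checkpoint.items() if not k.startswith("loss")}


def failing_convert(_checkpoint):
    raise ValueError("unexpected key layout")


print(convert_or_fallback({"decoder.w": 1, "loss.logvar": 2}, failing_convert))
```

Catching the exception explicitly keeps the same fallback behaviour while leaving a trace of why the conversion did not happen.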
def compile_unet_variants(self, model):
if model == "unet":

View File

@@ -24,6 +24,7 @@ from apps.stable_diffusion.src.utils.utils import (
get_available_devices,
get_opt_flags,
preprocessCKPT,
convert_original_vae,
fetch_and_update_base_model_id,
get_path_to_diffusers_checkpoint,
sanitize_seed,

View File

@@ -493,7 +493,13 @@ p.add_argument(
default="",
help="Path to directory where all .ckpts are stored in order to populate them in the web UI",
)
# TODO: replace API flag when these can be run together
p.add_argument(
"--ui",
type=str,
default="app" if os.name == "nt" else "web",
help="one of: [api, app, web]",
)
p.add_argument(
"--share",

View File

@@ -25,7 +25,12 @@ from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
import sys
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
download_from_original_stable_diffusion_ckpt,
create_vae_diffusers_config,
convert_ldm_vae_checkpoint,
)
import requests
from io import BytesIO
from omegaconf import OmegaConf
def get_extended_name(model_name):
@@ -464,7 +469,7 @@ def get_path_stem(path):
def get_path_to_diffusers_checkpoint(custom_weights):
path = Path(custom_weights)
diffusers_path = path.parent.absolute()
diffusers_directory_name = path.stem
diffusers_directory_name = os.path.join("diffusers", path.stem)
complete_path_to_diffusers = diffusers_path / diffusers_directory_name
complete_path_to_diffusers.mkdir(parents=True, exist_ok=True)
path_to_diffusers = complete_path_to_diffusers.as_posix()
@@ -503,6 +508,22 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
print("Loading complete")
def convert_original_vae(vae_checkpoint):
vae_state_dict = {}
for key in list(vae_checkpoint.keys()):
vae_state_dict["first_stage_model." + key] = vae_checkpoint.get(key)
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
original_config_file = BytesIO(requests.get(config_url).content)
original_config = OmegaConf.load(original_config_file)
vae_config = create_vae_diffusers_config(original_config, image_size=512)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
vae_state_dict, vae_config
)
return converted_vae_checkpoint
def processLoRA(model, use_lora, splitting_prefix):
state_dict = ""
if ".safetensors" in use_lora:

View File

@@ -1,3 +1,4 @@
from multiprocessing import Process, freeze_support
import os
import sys
import transformers
@@ -10,8 +11,26 @@ if sys.platform == "darwin":
if args.clear_all:
clear_all()
def launch_app(address):
from tkinter import Tk
import webview
window = Tk()
# getting screen width and height of display
width = window.winfo_screenwidth()
height = window.winfo_screenheight()
webview.create_window(
"SHARK AI Studio", url=address, width=width, height=height
)
webview.start(private_mode=False)
if __name__ == "__main__":
if args.api:
# required to do multiprocessing in a pyinstaller freeze
freeze_support()
if args.api or "api" in args.ui.split(","):
from apps.stable_diffusion.web.ui import (
txt2img_api,
img2img_api,
@@ -40,14 +59,12 @@ if __name__ == "__main__":
from apps.stable_diffusion.web.utils.gradio_configs import (
clear_gradio_tmp_imgs_folder,
)
from apps.stable_diffusion.web.ui.utils import get_custom_model_path
from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
# Clear all gradio tmp images from the last session
clear_gradio_tmp_imgs_folder()
# Create the custom model folder if it doesn't already exist
dir = ["models", "vae", "lora"]
for root in dir:
get_custom_model_path(root).mkdir(parents=True, exist_ok=True)
# Create custom models folders if they don't exist
create_custom_models_folders()
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
@@ -60,36 +77,54 @@ if __name__ == "__main__":
from apps.stable_diffusion.web.ui import (
txt2img_web,
txt2img_custom_model,
txt2img_hf_model_id,
txt2img_gallery,
txt2img_sendto_img2img,
txt2img_sendto_inpaint,
txt2img_sendto_outpaint,
txt2img_sendto_upscaler,
img2img_web,
img2img_custom_model,
img2img_hf_model_id,
img2img_gallery,
img2img_init_image,
img2img_sendto_inpaint,
img2img_sendto_outpaint,
img2img_sendto_upscaler,
inpaint_web,
inpaint_custom_model,
inpaint_hf_model_id,
inpaint_gallery,
inpaint_init_image,
inpaint_sendto_img2img,
inpaint_sendto_outpaint,
inpaint_sendto_upscaler,
outpaint_web,
outpaint_custom_model,
outpaint_hf_model_id,
outpaint_gallery,
outpaint_init_image,
outpaint_sendto_img2img,
outpaint_sendto_inpaint,
outpaint_sendto_upscaler,
upscaler_web,
upscaler_custom_model,
upscaler_hf_model_id,
upscaler_gallery,
upscaler_init_image,
upscaler_sendto_img2img,
upscaler_sendto_inpaint,
upscaler_sendto_outpaint,
lora_train_web,
model_web,
hf_models,
modelmanager_sendto_txt2img,
modelmanager_sendto_img2img,
modelmanager_sendto_inpaint,
modelmanager_sendto_outpaint,
modelmanager_sendto_upscaler,
stablelm_chat,
)
# init global sd pipeline and config
@@ -105,6 +140,17 @@ if __name__ == "__main__":
outputs,
)
def register_modelmanager_button(button, selectedid, inputs, outputs):
button.click(
lambda x: (
"None",
x,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
) as sd_web:
@@ -119,9 +165,11 @@ if __name__ == "__main__":
outpaint_web.render()
with gr.TabItem(label="Upscaler", id=4):
upscaler_web.render()
with gr.Tabs(visible=False) as experimental_tabs:
with gr.TabItem(label="LoRA Training", id=5):
with gr.TabItem(label="Model Manager", id=5):
model_web.render()
with gr.TabItem(label="Chat Bot(Experimental)", id=6):
stablelm_chat.render()
with gr.TabItem(label="LoRA Training(Experimental)", id=7):
lora_train_web.render()
register_button_click(
@@ -220,10 +268,46 @@ if __name__ == "__main__":
[upscaler_gallery],
[outpaint_init_image, tabs],
)
register_modelmanager_button(
modelmanager_sendto_txt2img,
0,
[hf_models],
[txt2img_custom_model, txt2img_hf_model_id, tabs],
)
register_modelmanager_button(
modelmanager_sendto_img2img,
1,
[hf_models],
[img2img_custom_model, img2img_hf_model_id, tabs],
)
register_modelmanager_button(
modelmanager_sendto_inpaint,
2,
[hf_models],
[inpaint_custom_model, inpaint_hf_model_id, tabs],
)
register_modelmanager_button(
modelmanager_sendto_outpaint,
3,
[hf_models],
[outpaint_custom_model, outpaint_hf_model_id, tabs],
)
register_modelmanager_button(
modelmanager_sendto_upscaler,
4,
[hf_models],
[upscaler_custom_model, upscaler_hf_model_id, tabs],
)
sd_web.queue()
if args.ui == "app":
t = Process(
target=launch_app, args=[f"http://localhost:{args.server_port}"]
)
t.start()
sd_web.launch(
share=args.share,
inbrowser=True,
inbrowser=args.ui == "web",
server_name="0.0.0.0",
server_port=args.server_port,
)
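The launcher above reads `--ui` in two ways: as a comma-separated list (`"api" in args.ui.split(",")`) and as a single value (`args.ui == "app"`, `args.ui == "web"`). A small sketch of parsing the flag into a set of modes is shown below, so that both checks could use the same representation. `parse_ui_modes` is a hypothetical helper; only the flag name, default, and help text are taken from the diff.

```python
import argparse
import os
from typing import List, Set


def parse_ui_modes(argv: List[str]) -> Set[str]:
    """Parse --ui into a set of modes, e.g. 'api,web' -> {'api', 'web'}."""
    p = argparse.ArgumentParser()
    p.add_argument(
        "--ui",
        type=str,
        default="app" if os.name == "nt" else "web",
        help="one of: [api, app, web]",
    )
    # parse_known_args ignores flags this sketch does not define.
    args, _unknown = p.parse_known_args(argv)
    return set(args.ui.split(","))


modes = parse_ui_modes(["--ui", "api,web"])
print("api" in modes, "app" in modes)
```

With a set, `"app" in modes` and `"web" in modes` behave consistently with the `api` check when several modes are passed at once.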

View File

@@ -2,6 +2,8 @@ from apps.stable_diffusion.web.ui.txt2img_ui import (
txt2img_inf,
txt2img_api,
txt2img_web,
txt2img_custom_model,
txt2img_hf_model_id,
txt2img_gallery,
txt2img_sendto_img2img,
txt2img_sendto_inpaint,
@@ -12,6 +14,8 @@ from apps.stable_diffusion.web.ui.img2img_ui import (
img2img_inf,
img2img_api,
img2img_web,
img2img_custom_model,
img2img_hf_model_id,
img2img_gallery,
img2img_init_image,
img2img_sendto_inpaint,
@@ -22,6 +26,8 @@ from apps.stable_diffusion.web.ui.inpaint_ui import (
inpaint_inf,
inpaint_api,
inpaint_web,
inpaint_custom_model,
inpaint_hf_model_id,
inpaint_gallery,
inpaint_init_image,
inpaint_sendto_img2img,
@@ -32,6 +38,8 @@ from apps.stable_diffusion.web.ui.outpaint_ui import (
outpaint_inf,
outpaint_api,
outpaint_web,
outpaint_custom_model,
outpaint_hf_model_id,
outpaint_gallery,
outpaint_init_image,
outpaint_sendto_img2img,
@@ -42,10 +50,22 @@ from apps.stable_diffusion.web.ui.upscaler_ui import (
upscaler_inf,
upscaler_api,
upscaler_web,
upscaler_custom_model,
upscaler_hf_model_id,
upscaler_gallery,
upscaler_init_image,
upscaler_sendto_img2img,
upscaler_sendto_inpaint,
upscaler_sendto_outpaint,
)
from apps.stable_diffusion.web.ui.model_manager import (
model_web,
hf_models,
modelmanager_sendto_txt2img,
modelmanager_sendto_img2img,
modelmanager_sendto_inpaint,
modelmanager_sendto_outpaint,
modelmanager_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.lora_train_ui import lora_train_web
from apps.stable_diffusion.web.ui.stablelm_ui import stablelm_chat

View File

@@ -4,6 +4,7 @@ import torch
import time
import sys
import gradio as gr
import PIL
from PIL import Image
import base64
from io import BytesIO
@@ -89,6 +90,8 @@ def img2img_inf(
return None, "An Initial Image is required"
if use_stencil == "scribble":
image = image_dict["mask"].convert("RGB")
elif isinstance(image_dict, PIL.Image.Image):
image = image_dict.convert("RGB")
else:
image = image_dict["image"].convert("RGB")
@@ -299,7 +302,7 @@ def img2img_api(
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["image"])
init_image = decode_base64_to_image(InputData["init_images"][0])
res = img2img_inf(
InputData["prompt"],
InputData["negative_prompt"],
@@ -352,21 +355,21 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
custom_model = gr.Dropdown(
img2img_custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "None",
else "stabilityai/stable-diffusion-2-1-base",
choices=["None"]
+ get_custom_model_files()
+ predefined_models,
)
hf_model_id = gr.Textbox(
img2img_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://a88802436301955b3a.gradio.live",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID or Civitai model download url",
label="HuggingFace Model ID or Civitai model download URL",
lines=3,
)
custom_vae = gr.Dropdown(
@@ -624,8 +627,8 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
img2img_custom_model,
img2img_hf_model_id,
custom_vae,
precision,
device,

View File

@@ -306,21 +306,23 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
custom_model = gr.Dropdown(
inpaint_custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "None",
else "stabilityai/stable-diffusion-2-inpainting",
choices=["None"]
+ get_custom_model_files()
+ get_custom_model_files(
custom_checkpoint_type="inpainting"
)
+ predefined_paint_models,
)
hf_model_id = gr.Textbox(
inpaint_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting, https://civitai.com/api/download/models/3433",
value="",
label="HuggingFace Model ID or Civitai model download url",
label="HuggingFace Model ID or Civitai model download URL",
lines=3,
)
custom_vae = gr.Dropdown(
@@ -527,8 +529,8 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
inpaint_custom_model,
inpaint_hf_model_id,
custom_vae,
precision,
device,

View File

@@ -0,0 +1,157 @@
import os
import gradio as gr
import requests
from io import BytesIO
from PIL import Image
def get_hf_list(num_of_models=20):
path = "https://huggingface.co/api/models"
params = {
"search": "stable-diffusion",
"sort": "downloads",
"direction": "-1",
"limit": {num_of_models},
"full": "true",
}
response = requests.get(path, params=params)
return response.json()
def get_civit_list(num_of_models=50):
path = f"https://civitai.com/api/v1/models?limit={num_of_models}&types=Checkpoint"
headers = {"Content-Type": "application/json"}
raw_json = requests.get(path, headers=headers).json()
models = list(raw_json.items())[0][1]
safe_models = [
safe_model for safe_model in models if not safe_model["nsfw"]
]
version_id = 0 # Currently just using the first version.
safe_models = [
safe_model
for safe_model in safe_models
if safe_model["modelVersions"][version_id]["files"][0]["metadata"][
"format"
]
== "SafeTensor"
]
first_version_models = []
for model_iter in safe_models:
# The modelVersion would only keep the version name.
if (
model_iter["modelVersions"][version_id]["images"][0]["nsfw"]
!= "None"
):
continue
model_iter["modelVersions"][version_id]["modelName"] = model_iter[
"name"
]
model_iter["modelVersions"][version_id]["rating"] = model_iter[
"stats"
]["rating"]
model_iter["modelVersions"][version_id]["favoriteCount"] = model_iter[
"stats"
]["favoriteCount"]
model_iter["modelVersions"][version_id]["downloadCount"] = model_iter[
"stats"
]["downloadCount"]
first_version_models.append(model_iter["modelVersions"][version_id])
return first_version_models
def get_image_from_model(model_json):
model_id = model_json["modelId"]
image = None
for img_info in model_json["images"]:
if img_info["nsfw"] == "None":
image_url = model_json["images"][0]["url"]
response = requests.get(image_url)
image = BytesIO(response.content)
break
return image
with gr.Blocks() as model_web:
with gr.Row():
model_source = gr.Radio(
value=None,
choices=["Hugging Face", "Civitai"],
type="value",
label="Model Source",
)
model_numebr = gr.Slider(
1,
100,
value=10,
step=1,
label="Number of models",
interactive=True,
)
# TODO: add more filters
get_model_btn = gr.Button(value="Get Models")
hf_models = gr.Dropdown(
label="Hugging Face Model List",
choices=None,
value=None,
visible=False,
)
# TODO: select and SendTo
civit_models = gr.Gallery(
label="Civitai Model Gallery",
value=None,
interactive=True,
visible=False,
)
with gr.Row(visible=False) as sendto_btns:
modelmanager_sendto_txt2img = gr.Button(value="SendTo Txt2Img")
modelmanager_sendto_img2img = gr.Button(value="SendTo Img2Img")
modelmanager_sendto_inpaint = gr.Button(value="SendTo Inpaint")
modelmanager_sendto_outpaint = gr.Button(value="SendTo Outpaint")
modelmanager_sendto_upscaler = gr.Button(value="SendTo Upscaler")
def get_model_list(model_source, model_numebr):
if model_source == "Hugging Face":
hf_model_list = get_hf_list(model_numebr)
models = []
for model in hf_model_list:
# TODO: add model info
models.append(f'{model["modelId"]}')
return (
gr.Dropdown.update(choices=models, visible=True),
gr.Gallery.update(value=None, visible=False),
gr.Row.update(visible=True),
)
elif model_source == "Civitai":
civit_model_list = get_civit_list(model_numebr)
models = []
for model in civit_model_list:
image = get_image_from_model(model)
if image is None:
continue
# TODO: add model info
models.append(
(Image.open(image), f'{model["files"][0]["downloadUrl"]}')
)
return (
gr.Dropdown.update(value=None, choices=None, visible=False),
gr.Gallery.update(value=models, visible=True),
gr.Row.update(visible=False),
)
else:
return (
gr.Dropdown.update(value=None, choices=None, visible=False),
gr.Gallery.update(value=None, visible=False),
gr.Row.update(visible=False),
)
get_model_btn.click(
fn=get_model_list,
inputs=[model_source, model_numebr],
outputs=[
hf_models,
civit_models,
sendto_btns,
],
)
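
Two details in the new model-manager file above may be worth a second look. In `get_hf_list`, the `limit` value is written as `{num_of_models}`, which is a set literal in Python, not an integer. In `get_image_from_model`, the loop checks each image's `nsfw` field but then reads the URL of `images[0]`. Whether either is intentional is not clear from the diff. The sketch below shows one possible tightening; the helper names `build_hf_query` and `first_safe_image_url` are hypothetical and do not appear in the commit.

```python
from typing import Optional


def build_hf_query(num_of_models: int = 20) -> dict:
    """Query parameters for the Hugging Face model listing, with a plain int limit."""
    return {
        "search": "stable-diffusion",
        "sort": "downloads",
        "direction": "-1",
        "limit": int(num_of_models),  # int(), not a one-element set
        "full": "true",
    }


def first_safe_image_url(model_json: dict) -> Optional[str]:
    """Return the URL of the first image whose nsfw field is "None", else None."""
    for img_info in model_json.get("images", []):
        if img_info.get("nsfw") == "None":
            return img_info["url"]  # URL of the image that passed the check
    return None
```

The returned dictionary can be passed as `params=` to `requests.get`, as the original function does.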

View File

@@ -317,21 +317,23 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
custom_model = gr.Dropdown(
outpaint_custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "None",
else "stabilityai/stable-diffusion-2-inpainting",
choices=["None"]
+ get_custom_model_files()
+ get_custom_model_files(
custom_checkpoint_type="inpainting"
)
+ predefined_paint_models,
)
hf_model_id = gr.Textbox(
outpaint_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting, https://civitai.com/api/download/models/3433",
value="",
label="HuggingFace Model ID",
label="HuggingFace Model ID or Civitai model download URL",
lines=3,
)
custom_vae = gr.Dropdown(
@@ -371,9 +373,9 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4, https://civitai.com/api/download/models/3433",
placeholder="Select 'None' in the Standalone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID or Civitai model download url",
label="HuggingFace Model ID",
lines=3,
)
with gr.Accordion(label="Advanced Options", open=False):
@@ -558,8 +560,8 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
outpaint_custom_model,
outpaint_hf_model_id,
custom_vae,
precision,
device,

View File

@@ -0,0 +1,217 @@
import gradio as gr
import torch
import os
from apps.language_models.scripts.stablelm import (
compile_stableLM,
StopOnTokens,
generate,
get_tokenizer,
StableLMModel,
)
from transformers import (
AutoModelForCausalLM,
TextIteratorStreamer,
StoppingCriteriaList,
)
from apps.stable_diffusion.web.ui.utils import available_devices
start_message = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""
def user(message, history):
# Append the user's message to the conversation history
return "", history + [[message, ""]]
input_ids = torch.randint(3, (1, 256))
attention_mask = torch.randint(3, (1, 256))
sharkModel = 0
sharded_model = 0
start_message_vicuna = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
past_key_values = None
def chat(curr_system_message, history, model):
global sharded_model
global past_key_values
if "vicuna" in model:
from apps.language_models.scripts.sharded_vicuna_fp32 import (
tokenizer,
get_sharded_model,
)
SAMPLE_INPUT_LEN = 137
curr_system_message = start_message_vicuna
if sharded_model == 0:
sharded_model = get_sharded_model()
messages = curr_system_message + "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
prompt = messages.strip()
print("prompt = ", prompt)
input_ids = tokenizer(prompt).input_ids
new_sentence = ""
for _ in range(200):
original_input_ids = input_ids
input_id_len = len(input_ids)
pad_len = SAMPLE_INPUT_LEN - input_id_len
attention_mask = torch.ones([1, input_id_len], dtype=torch.int64)
input_ids = torch.tensor(input_ids)
input_ids = input_ids.reshape([1, input_id_len])
attention_mask = torch.nn.functional.pad(
torch.tensor(attention_mask),
(0, pad_len),
mode="constant",
value=0,
)
if _ == 0:
output = sharded_model.forward(input_ids, is_first=True)
else:
output = sharded_model.forward(
input_ids, past_key_values=past_key_values, is_first=False
)
logits = output["logits"]
past_key_values = output["past_key_values"]
new_word = tokenizer.decode(torch.argmax(logits[:, -1, :], dim=1))
if new_word == "</s>":
break
new_sentence += " " + new_word
history[-1][1] = new_sentence
yield history
next_token = torch.argmax(logits[:, input_id_len - 1, :], dim=1)
original_input_ids.append(next_token)
input_ids = [next_token]
print(new_sentence)
return history
global sharkModel
print("In chat")
if sharkModel == 0:
tok = get_tokenizer()
# sharkModel = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/disk/phaneesh/stablelm_3b_f32_cuda_2048_newflags.vmfb")
m = AutoModelForCausalLM.from_pretrained(
"stabilityai/stablelm-tuned-alpha-3b", torch_dtype=torch.float32
)
stableLMModel = StableLMModel(m)
sharkModel = compile_stableLM(
stableLMModel,
tuple([input_ids, attention_mask]),
"stableLM_linalg_f32_seqLen256",
os.getcwd(),
)
# Initialize a StopOnTokens object
stop = StopOnTokens()
# Construct the input message string for the model by concatenating the current system message and conversation history
if len(curr_system_message.split()) > 160:
print("clearing context")
curr_system_message = start_message
messages = curr_system_message + "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
# print(messages)
# Tokenize the messages string
streamer = TextIteratorStreamer(
tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
)
generate_kwargs = dict(
new_text=messages,
streamer=streamer,
max_new_tokens=512,
do_sample=True,
top_p=0.95,
top_k=1000,
temperature=1.0,
num_beams=1,
stopping_criteria=StoppingCriteriaList([stop]),
sharkStableLM=sharkModel,
)
words_list = generate(**generate_kwargs)
partial_text = ""
for new_text in words_list:
# print(new_text)
partial_text += new_text
history[-1][1] = partial_text
# Yield an empty string to cleanup the message textbox and the updated conversation history
yield history
return words_list
with gr.Blocks(title="Chatbot") as stablelm_chat:
with gr.Row():
model = gr.Dropdown(
label="Select Model",
value="TheBloke/vicuna-7B-1.1-HF",
choices=[
"stabilityai/stablelm-tuned-alpha-3b",
"TheBloke/vicuna-7B-1.1-HF",
],
)
device_value = None
for d in available_devices:
if "vulkan" in d:
device_value = d
break
device = gr.Dropdown(
label="Device",
value=device_value if device_value else available_devices[0],
interactive=False,
choices=available_devices,
)
chatbot = gr.Chatbot().style(height=500)
with gr.Row():
with gr.Column():
msg = gr.Textbox(
label="Chat Message Box",
placeholder="Chat Message Box",
show_label=False,
).style(container=False)
with gr.Column():
with gr.Row():
submit = gr.Button("Submit")
stop = gr.Button("Stop")
clear = gr.Button("Clear")
system_msg = gr.Textbox(
start_message, label="System Message", interactive=False, visible=False
)
submit_event = msg.submit(
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
inputs=[system_msg, chatbot, model],
outputs=[chatbot],
queue=True,
)
submit_click_event = submit.click(
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
inputs=[system_msg, chatbot, model],
outputs=[chatbot],
queue=True,
)
stop.click(
fn=None,
inputs=None,
outputs=None,
cancels=[submit_event, submit_click_event],
queue=False,
)
clear.click(lambda: None, None, [chatbot], queue=False)

View File

@@ -281,21 +281,21 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
with gr.Row():
with gr.Column(scale=10):
with gr.Row():
custom_model = gr.Dropdown(
txt2img_custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "None",
else "stabilityai/stable-diffusion-2-1-base",
choices=["None"]
+ get_custom_model_files()
+ predefined_models,
)
hf_model_id = gr.Textbox(
txt2img_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID or Civitai model download url",
label="HuggingFace Model ID or Civitai model download URL",
lines=3,
)
custom_vae = gr.Dropdown(
@@ -502,8 +502,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
txt2img_custom_model,
txt2img_hf_model_id,
custom_vae,
precision,
device,
@@ -538,8 +538,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
seed,
width,
height,
custom_model,
hf_model_id,
txt2img_custom_model,
txt2img_hf_model_id,
],
outputs=[
png_info_img,
@@ -551,7 +551,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
seed,
width,
height,
custom_model,
hf_model_id,
txt2img_custom_model,
txt2img_hf_model_id,
],
)

View File

@@ -309,21 +309,23 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
custom_model = gr.Dropdown(
upscaler_custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "None",
else "stabilityai/stable-diffusion-x4-upscaler",
choices=["None"]
+ get_custom_model_files()
+ get_custom_model_files(
custom_checkpoint_type="upscaler"
)
+ predefined_upscaler_models,
)
hf_model_id = gr.Textbox(
upscaler_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID or Civitai model download url",
label="HuggingFace Model ID or Civitai model download URL",
lines=3,
)
custom_vae = gr.Dropdown(
@@ -525,8 +527,8 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
upscaler_custom_model,
upscaler_hf_model_id,
custom_vae,
precision,
device,

View File

@@ -72,30 +72,36 @@ def resource_path(relative_path):
return os.path.join(base_path, relative_path)
def create_custom_models_folders():
dir = ["vae", "lora"]
if not args.ckpt_dir:
dir.insert(0, "models")
else:
if not os.path.isdir(args.ckpt_dir):
sys.exit(
f"Invalid --ckpt_dir argument, {args.ckpt_dir} folder does not exist."
)
for root in dir:
get_custom_model_path(root).mkdir(parents=True, exist_ok=True)
def get_custom_model_path(model="models"):
# If `--ckpt_dir` is provided it'd override the hierarchical folder
# structure in WebUI :-
# model
# models or args.ckpt_dir
# |___lora
# |___vae
sub_folder = "" if model == "models" else model
if args.ckpt_dir:
return Path(args.ckpt_dir)
match model:
case "models":
return Path(Path.cwd(), "models")
case "vae":
return Path(Path.cwd(), "models/vae")
case "lora":
return Path(Path.cwd(), "models/lora")
case _:
return ""
return Path(Path(args.ckpt_dir), sub_folder)
else:
return Path(Path.cwd(), "models/" + sub_folder)
def get_custom_model_pathfile(custom_model_name, model="models"):
return os.path.join(get_custom_model_path(model), custom_model_name)
def get_custom_model_files(model="models"):
def get_custom_model_files(model="models", custom_checkpoint_type=""):
ckpt_files = []
file_types = custom_model_filetypes
if model == "lora":
@@ -107,6 +113,28 @@ def get_custom_model_files(model="models"):
os.path.join(get_custom_model_path(model), extn)
)
]
match custom_checkpoint_type:
case "inpainting":
files = [
val
for val in files
if val.endswith("inpainting" + extn.removeprefix("*"))
]
case "upscaler":
files = [
val
for val in files
if val.endswith("upscaler" + extn.removeprefix("*"))
]
case _:
files = [
val
for val in files
if not (
val.endswith("inpainting" + extn.removeprefix("*"))
or val.endswith("upscaler" + extn.removeprefix("*"))
)
]
ckpt_files.extend(files)
return sorted(ckpt_files, key=str.casefold)
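
The `match custom_checkpoint_type` block above filters checkpoint files by a filename suffix: inpainting and upscaler tabs only see files ending in `inpainting<ext>` or `upscaler<ext>`, and every other tab sees the remainder. A minimal standalone sketch of that rule follows; the function name `filter_by_checkpoint_type` is hypothetical, and the logic is condensed from the three `case` branches rather than copied from the repository.

```python
def filter_by_checkpoint_type(files, extn, custom_checkpoint_type=""):
    """Keep only the files whose name suffix matches the requested type.

    `extn` is a glob pattern such as "*.safetensors".
    """
    suffix = extn.removeprefix("*")  # "*.ckpt" -> ".ckpt" (Python 3.9+)
    special = ("inpainting", "upscaler")
    if custom_checkpoint_type in special:
        return [f for f in files if f.endswith(custom_checkpoint_type + suffix)]
    # Default: everything that is neither an inpainting nor an upscaler checkpoint.
    return [
        f
        for f in files
        if not any(f.endswith(kind + suffix) for kind in special)
    ]
```

For example, with `["a.ckpt", "b-inpainting.ckpt", "c-upscaler.ckpt"]` and `extn="*.ckpt"`, the inpainting tab would list only `b-inpainting.ckpt`, while the default call would list only `a.ckpt`.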

View File

@@ -19,13 +19,16 @@ transformers
diffusers @ git+https://github.com/huggingface/diffusers@e47459c80f6f6a5a1c19d32c3fd74edf94f47aa2
scipy
ftfy
gradio
gradio==3.22.0
altair
omegaconf
safetensors
opencv-python
scikit-image
pytorch_lightning # for runwayml models
tk
pywebview
sentencepiece
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile

View File

@@ -2,9 +2,10 @@
# Sets up a venv suitable for running samples.
# e.g:
# ./setup_venv.sh #setup a default $PYTHON3 shark.venv
# Environment Variables by the script.
# Environment variables used by the script.
# PYTHON=$PYTHON3.10 ./setup_venv.sh #pass a version of $PYTHON to use
# VENV_DIR=myshark.venv #create a venv called myshark.venv
# SKIP_VENV=1 #Don't create and activate a Python venv. Use the current environment.
# USE_IREE=1 #use stock IREE instead of Nod.ai's SHARK build
# IMPORTER=1 #Install importer deps
# BENCHMARK=1 #Install benchmark deps
@@ -26,15 +27,17 @@ PYTHON_VERSION_X_Y=`${PYTHON} -c 'import sys; version=sys.version_info[:2]; prin
echo "Python: $PYTHON"
echo "Python version: $PYTHON_VERSION_X_Y"
if [[ -z "${CONDA_PREFIX}" ]]; then
# Not a conda env. So create a new VENV dir
VENV_DIR=${VENV_DIR:-shark.venv}
echo "Using pip venv.. Setting up venv dir: $VENV_DIR"
$PYTHON -m venv "$VENV_DIR" || die "Could not create venv."
source "$VENV_DIR/bin/activate" || die "Could not activate venv"
PYTHON="$(which python3)"
else
echo "Found conda env $CONDA_DEFAULT_ENV. Running pip install inside the conda env"
if [[ "$SKIP_VENV" != "1" ]]; then
if [[ -z "${CONDA_PREFIX}" ]]; then
# Not a conda env. So create a new VENV dir
VENV_DIR=${VENV_DIR:-shark.venv}
echo "Using pip venv.. Setting up venv dir: $VENV_DIR"
$PYTHON -m venv "$VENV_DIR" || die "Could not create venv."
source "$VENV_DIR/bin/activate" || die "Could not activate venv"
PYTHON="$(which python3)"
else
echo "Found conda env $CONDA_DEFAULT_ENV. Running pip install inside the conda env"
fi
fi
Red=`tput setaf 1`
@@ -147,8 +150,7 @@ if [[ ! -z "${ONNX}" ]]; then
fi
fi
if [[ -z "${CONDA_PREFIX}" ]]; then
if [[ -z "${CONDA_PREFIX}" && "$SKIP_VENV" != "1" ]]; then
echo "${Green}Before running examples activate venv with:"
echo " ${Green}source $VENV_DIR/bin/activate"
fi

View File

@@ -0,0 +1,73 @@
from transformers import AutoTokenizer, FlaxAutoModel
import torch
import jax
from typing import Union, Dict, List, Any
import numpy as np
from shark.shark_inference import SharkInference
import io
NumpyTree = Union[np.ndarray, Dict[str, np.ndarray], List[np.ndarray]]
def convert_torch_tensor_tree_to_numpy(
tree: Union[torch.tensor, Dict[str, torch.tensor], List[torch.tensor]]
) -> NumpyTree:
return jax.tree_util.tree_map(
lambda torch_tensor: torch_tensor.cpu().detach().numpy(), tree
)
def convert_int64_to_int32(tree: NumpyTree) -> NumpyTree:
return jax.tree_util.tree_map(
lambda tensor: np.array(tensor, dtype=np.int32)
if tensor.dtype == np.int64
else tensor,
tree,
)
def get_sample_input():
tokenizer = AutoTokenizer.from_pretrained(
"microsoft/MiniLM-L12-H384-uncased"
)
inputs_torch = tokenizer("Hello, World!", return_tensors="pt")
return convert_int64_to_int32(
convert_torch_tensor_tree_to_numpy(inputs_torch.data)
)
def get_jax_model():
return FlaxAutoModel.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
def export_jax_to_mlir(jax_model: Any, sample_input: NumpyTree):
model_mlir = jax.jit(jax_model).lower(**sample_input).compiler_ir()
byte_stream = io.BytesIO()
model_mlir.operation.write_bytecode(file=byte_stream)
return byte_stream.getvalue()
def assert_array_list_allclose(x, y, *args, **kwargs):
assert len(x) == len(y)
for a, b in zip(x, y):
np.testing.assert_allclose(
np.asarray(a), np.asarray(b), *args, **kwargs
)
sample_input = get_sample_input()
jax_model = get_jax_model()
mlir = export_jax_to_mlir(jax_model, sample_input)
# Compile and load module.
shark_inference = SharkInference(mlir_module=mlir, mlir_dialect="mhlo")
shark_inference.compile()
# Run main function.
result = shark_inference("main", jax.tree_util.tree_flatten(sample_input)[0])
# Run JAX model.
reference_result = jax.tree_util.tree_flatten(jax_model(**sample_input))[0]
# Verify result.
assert_array_list_allclose(result, reference_result, atol=1e-5)

View File

@@ -0,0 +1,6 @@
flax
jax[cpu]
nodai-SHARK
orbax
transformers
torch

View File

@@ -45,10 +45,15 @@ def run_cmd(cmd, debug=False):
def iree_device_map(device):
uri_parts = device.split("://", 2)
iree_driver = (
_IREE_DEVICE_MAP[uri_parts[0]]
if uri_parts[0] in _IREE_DEVICE_MAP
else uri_parts[0]
)
if len(uri_parts) == 1:
return _IREE_DEVICE_MAP[uri_parts[0]]
return iree_driver
else:
return f"{_IREE_DEVICE_MAP[uri_parts[0]]}://{uri_parts[1]}"
return f"{iree_driver}://{uri_parts[1]}"
def get_supported_device_list():
@@ -68,7 +73,7 @@ _IREE_DEVICE_MAP = {
def iree_target_map(device):
if "://" in device:
device = device.split("://")[0]
return _IREE_TARGET_MAP[device]
return _IREE_TARGET_MAP[device] if device in _IREE_TARGET_MAP else device
_IREE_TARGET_MAP = {
@@ -110,10 +115,8 @@ def check_device_drivers(device):
subprocess.check_output("rocminfo")
except Exception:
return True
# Unknown device.
else:
return True
# Unknown device. We assume drivers are installed.
return False
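
The change to `iree_device_map` above makes unknown device names pass through to IREE verbatim instead of raising `KeyError`. The same behaviour can be sketched with `dict.get` and `str.partition`. The table below is illustrative only, since the real `_IREE_DEVICE_MAP` contents are not shown in this diff, and `map_device` is a hypothetical name.

```python
# Illustrative table only; the real _IREE_DEVICE_MAP is not shown in this diff.
EXAMPLE_DEVICE_MAP = {"cpu": "local-task", "metal": "vulkan"}


def map_device(device: str, table=EXAMPLE_DEVICE_MAP) -> str:
    """Translate the scheme of a device URI, leaving unknown schemes untouched."""
    scheme, sep, path = device.partition("://")
    driver = table.get(scheme, scheme)  # unknown names pass through unchanged
    return f"{driver}://{path}" if sep else driver
```

With this table, `map_device("metal://1")` gives `"vulkan://1"`, and an unrecognised name such as `"foo://0"` is returned as is.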

View File

@@ -23,6 +23,7 @@ import re
# Get the iree-compile arguments given device.
def get_iree_device_args(device, extra_args=[]):
print("Configuring for device:" + device)
device_uri = device.split("://")
if len(device_uri) > 1:
if device_uri[0] not in ["vulkan"]:
@@ -30,6 +31,9 @@ def get_iree_device_args(device, extra_args=[]):
f"Specific device selection only supported for vulkan now."
f"Proceeding with {device} as device."
)
device_num = device_uri[1]
else:
device_num = 0
if device_uri[0] == "cpu":
from shark.iree_utils.cpu_utils import get_iree_cpu_args
@@ -42,7 +46,9 @@ def get_iree_device_args(device, extra_args=[]):
if device_uri[0] in ["metal", "vulkan"]:
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
return get_iree_vulkan_args(extra_args=extra_args)
return get_iree_vulkan_args(
device_num=device_num, extra_args=extra_args
)
if device_uri[0] == "rocm":
from shark.iree_utils.gpu_utils import get_iree_rocm_args
@@ -307,7 +313,7 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
)
ctx = ireert.SystemContext(config=config)
ctx.add_vm_module(vm_module)
ModuleCompiled = ctx.modules.module
ModuleCompiled = getattr(ctx.modules, vm_module.name)
return ModuleCompiled, config

View File

@@ -21,7 +21,7 @@ from sys import platform
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
def get_vulkan_device_name():
def get_vulkan_device_name(device_num=0):
vulkaninfo_dump, _ = run_cmd("vulkaninfo")
vulkaninfo_dump = vulkaninfo_dump.split(linesep)
vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s]
@@ -31,8 +31,8 @@ def get_vulkan_device_name():
print("Following devices found:")
for i, dname in enumerate(vulkaninfo_list):
print(f"{i}. {dname}")
print(f"Choosing first one: {vulkaninfo_list[0]}")
return vulkaninfo_list[0]
print(f"Choosing device: {vulkaninfo_list[device_num]}")
return vulkaninfo_list[device_num]
def get_os_name():
@@ -119,14 +119,14 @@ def get_vulkan_target_triple(device_name):
return triple
def get_vulkan_triple_flag(device_name="", extra_args=[]):
def get_vulkan_triple_flag(device_name="", device_num=0, extra_args=[]):
for flag in extra_args:
if "-iree-vulkan-target-triple=" in flag:
print(f"Using target triple {flag.split('=')[1]}")
return None
if device_name == "" or device_name == [] or device_name is None:
vulkan_device = get_vulkan_device_name()
vulkan_device = get_vulkan_device_name(device_num=device_num)
else:
vulkan_device = device_name
triple = get_vulkan_target_triple(vulkan_device)
@@ -144,7 +144,7 @@ def get_vulkan_triple_flag(device_name="", extra_args=[]):
return None
def get_iree_vulkan_args(extra_args=[]):
def get_iree_vulkan_args(device_num=0, extra_args=[]):
# res_vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
res_vulkan_flag = []
@@ -156,7 +156,9 @@ def get_iree_vulkan_args(extra_args=[]):
break
if vulkan_triple_flag is None:
vulkan_triple_flag = get_vulkan_triple_flag(extra_args=extra_args)
vulkan_triple_flag = get_vulkan_triple_flag(
device_num=device_num, extra_args=extra_args
)
if vulkan_triple_flag is not None:
vulkan_target_env = get_vulkan_target_env_flag(vulkan_triple_flag)
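
In the hunks above, `device_num` is taken from the part of the device URI after `://`, so it may arrive as a string, and it is later used to index `vulkaninfo_list`. Whether a conversion happens elsewhere is not visible in this diff. A defensive sketch, assuming the index can be either an `int` or a numeric string, might look like this; `pick_device` is a hypothetical helper, not part of the commit.

```python
def pick_device(device_names, device_num=0):
    """Select a device name by index, accepting an int or a numeric string."""
    try:
        index = int(device_num)
    except (TypeError, ValueError):
        raise ValueError(
            f"device index must be an integer, got {device_num!r}"
        ) from None
    if not 0 <= index < len(device_names):
        raise IndexError(
            f"device index {index} out of range for {len(device_names)} device(s)"
        )
    return device_names[index]
```

An out-of-range index then fails with a message that names the number of devices found, instead of a bare list-index error.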

View File

@@ -30,8 +30,8 @@ import os
import sys
from typing import Dict, List
import iree.compiler._mlir_libs
from iree.compiler import ir
from iree.compiler.transforms import ireec as ireec_trans
def model_annotation(
@@ -409,7 +409,6 @@ def shape_list_to_string(input):
def create_context() -> ir.Context:
context = ir.Context()
ireec_trans.register_all_dialects(context)
context.allow_unregistered_dialects = True
return context

View File

@@ -196,7 +196,7 @@ def download_model(
tank_url=None,
frontend=None,
tuned=None,
import_args=None,
import_args={"batch_size": 1},
):
model_name = model_name.replace("/", "_")
dyn_str = "_dynamic" if dynamic else ""

View File

@@ -81,7 +81,7 @@ class SharkImporter:
# NOTE: The default function for torch is "forward" and tf-lite is "main".
def _torch_mlir(self, is_dynamic, tracing_required):
def _torch_mlir(self, is_dynamic, tracing_required, mlir_type):
from shark.torch_mlir_utils import get_torch_mlir_module
return get_torch_mlir_module(
@@ -90,6 +90,7 @@ class SharkImporter:
is_dynamic,
tracing_required,
self.return_str,
mlir_type,
)
def _tf_mlir(self, func_name, save_dir="."):
@@ -120,6 +121,7 @@ class SharkImporter:
tracing_required=False,
func_name="forward",
save_dir="./shark_tmp/",
mlir_type="linalg",
):
if self.frontend in ["torch", "pytorch"]:
if self.inputs == None:
@@ -127,7 +129,10 @@ class SharkImporter:
"Please pass in the inputs, the inputs are required to determine the shape of the mlir_module"
)
sys.exit(1)
return self._torch_mlir(is_dynamic, tracing_required), func_name
return (
self._torch_mlir(is_dynamic, tracing_required, mlir_type),
func_name,
)
if self.frontend in ["tf", "tensorflow"]:
return self._tf_mlir(func_name, save_dir), func_name
if self.frontend in ["tflite", "tf-lite"]:
@@ -186,6 +191,7 @@ class SharkImporter:
dir=tempfile.gettempdir(),
model_name="model",
golden_values=None,
mlir_type="linalg",
):
if self.inputs == None:
print(
@@ -199,6 +205,7 @@ class SharkImporter:
tracing_required,
func_name,
save_dir=artifact_path,
mlir_type=mlir_type,
)
# TODO: Make sure that any generic function name is accepted. Currently takes in the default function names.
# TODO: Check for multiple outputs.

View File

@@ -19,6 +19,12 @@ import tempfile
from shark.parser import shark_args
import io
mlir_type_mapping_dict = {
"linalg": torch_mlir.OutputType.LINALG_ON_TENSORS,
"stablehlo": torch_mlir.OutputType.STABLEHLO,
"tosa": torch_mlir.OutputType.TOSA,
}
def get_module_name_for_asm_dump(module):
"""Gets a name suitable for an assembly dump.
@@ -57,6 +63,7 @@ def get_torch_mlir_module(
dynamic: bool,
jit_trace: bool,
return_str: bool = False,
mlir_type: str = "linalg",
):
"""Get the MLIR's linalg-on-tensors module from the torchscript module."""
ignore_traced_shapes = False
@@ -70,10 +77,11 @@ def get_torch_mlir_module(
mlir_module = torch_mlir.compile(
module,
input,
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
output_type=mlir_type_mapping_dict[mlir_type],
use_tracing=jit_trace,
ignore_traced_shapes=ignore_traced_shapes,
)
if return_str:
return mlir_module.operation.get_asm()
bytecode_stream = io.BytesIO()

View File

@@ -15,16 +15,16 @@ microsoft/layoutlm-base-uncased,mhlo,tf,1e-2,1e-3,default,None,False,False,False
microsoft/mpnet-base,mhlo,tf,1e-2,1e-2,default,None,True,True,True,"",""
albert-base-v2,linalg,torch,1e-2,1e-3,default,None,True,True,True,"issue with aten.tanh in torch-mlir",""
alexnet,linalg,torch,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/879",""
bert-base-cased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,False,True,"",""
bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
bert-base-cased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,True,True,"",""
bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
bert-large-uncased,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"Fails during iree-compile.",""
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/311",""
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390","macos"
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"https://github.com/nod-ai/SHARK/issues/344",""
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"https://github.com/nod-ai/SHARK/issues/344",""
mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/388","macos"
nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/343","macos"
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos"
@@ -36,12 +36,12 @@ wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,Fal
efficientnet-v2-s,mhlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,False,"https://github.com/nod-ai/SHARK/issues/1243",""
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails on MacOS builder, VK device lost","macos"
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"Fails on MacOS builder, VK device lost","macos"
efficientnet_b0,mhlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""
efficientnet_b7,mhlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails on MacOS builder, VK device lost","macos"
gpt2,mhlo,tf,1e-2,1e-3,default,None,True,False,False,"",""
gpt2,mhlo,tf,1e-2,1e-3,default,None,True,False,False,"","macos"
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.",""
t5-base,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported",""
t5-large,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported","macos"
t5-large,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"","macos"
stabilityai/stable-diffusion-2-1-base,linalg,torch,1e-3,1e-3,default,None,True,False,False,"",""

@@ -50,6 +50,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
tracing_required = row[1]
model_type = row[2]
is_dynamic = row[3]
mlir_type = row[4]
tracing_required = False if tracing_required == "False" else True
is_dynamic = False if is_dynamic == "False" else True
@@ -121,6 +122,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
tracing_required=tracing_required,
dir=torch_model_dir,
model_name=torch_model_name,
mlir_type=mlir_type,
)
# Generate torch dynamic models.
if is_dynamic:
@@ -129,6 +131,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
tracing_required=tracing_required,
dir=torch_model_dir,
model_name=torch_model_name + "_dynamic",
mlir_type=mlir_type,
)
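
The hunks above read `mlir_type` from column index 4 of each CSV row and pass it through to the model-saving call. A minimal sketch of that row parsing follows, assuming the column order shown in the `torch_model_list.csv` header; `ModelRow` and `parse_model_row` are hypothetical names used only for illustration and are not part of the repository.

```python
import csv
import io
from typing import NamedTuple


class ModelRow(NamedTuple):
    """Hypothetical container for the fields the diff reads from each row."""
    model_name: str
    tracing_required: bool
    model_type: str
    is_dynamic: bool
    mlir_type: str


def parse_model_row(row):
    # Same semantics as the diff: only the literal string "False" is false.
    return ModelRow(
        model_name=row[0],
        tracing_required=row[1] != "False",
        model_type=row[2],
        is_dynamic=row[3] != "False",
        mlir_type=row[4],  # new column introduced by this change
    )


# Two rows copied from the updated CSV, plus its header line.
SAMPLE = '''model_name, use_tracing, model_type, dynamic, mlir_type, param_count, tags, notes
bert-base-uncased,True,hf,True,linalg,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-base-uncased,True,hf,False,stablehlo,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
'''

reader = csv.reader(io.StringIO(SAMPLE))
next(reader)  # skip the header row
rows = [parse_model_row(r) for r in reader]
for r in rows:
    print(r.model_name, r.mlir_type, r.is_dynamic)
```

Because `mlir_type` is inserted before `param_count`, any code that indexes the later columns by position would need its indices shifted by one.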

@@ -1,23 +1,24 @@
model_name, use_tracing, model_type, dynamic, param_count, tags, notes
efficientnet_b0,True,vision,False,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input"
efficientnet_b7,True,vision,False,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"
microsoft/MiniLM-L12-H384-uncased,True,hf,True,66M,"nlp;bert-variant;transformer-encoder","Large version has 12 layers; 384 hidden size; Smaller than BERTbase (66M params vs 109M params)"
bert-base-uncased,True,hf,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-base-cased,True,hf,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
google/mobilebert-uncased,True,hf,True,25M,"nlp,bert-variant,transformer-encoder,mobile","24 layers, 512 hidden size, 128 embedding"
alexnet,False,vision,True,61M,"cnn,parallel-layers","The CNN that revolutionized computer vision (move away from hand-crafted features to neural networks),10 years old now and probably no longer used in prod."
resnet18,False,vision,True,11M,"cnn,image-classification,residuals,resnet-variant","1 7x7 conv2d and the rest are 3x3 conv2d"
resnet50,False,vision,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
resnet101,False,vision,True,29M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
squeezenet1_0,False,vision,True,1.25M,"cnn,image-classification,mobile,parallel-layers","Parallel conv2d (1x1 conv to compress -> (3x3 expand | 1x1 expand) -> concat)"
wide_resnet50_2,False,vision,True,69M,"cnn,image-classification,residuals,resnet-variant","Resnet variant where model depth is decreased and width is increased."
mobilenet_v3_small,False,vision,True,2.5M,"image-classification,cnn,mobile",N/A
google/vit-base-patch16-224,True,hf_img_cls,False,86M,"image-classification,vision-transformer,transformer-encoder",N/A
microsoft/resnet-50,True,hf_img_cls,False,23M,"image-classification,cnn,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
facebook/deit-small-distilled-patch16-224,True,hf_img_cls,False,22M,"image-classification,vision-transformer,cnn",N/A
microsoft/beit-base-patch16-224-pt22k-ft22k,True,hf_img_cls,False,86M,"image-classification,transformer-encoder,bert-variant,vision-transformer",N/A
nvidia/mit-b0,True,hf_img_cls,False,3.7M,"image-classification,transformer-encoder",SegFormer
mnasnet1_0,False,vision,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
resnet50_fp16,False,vision,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
bert-base-uncased_fp16,True,fp16,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
model_name, use_tracing, model_type, dynamic, mlir_type, param_count, tags, notes
efficientnet_b0,True,vision,False,linalg,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input"
efficientnet_b7,True,vision,False,linalg,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"
microsoft/MiniLM-L12-H384-uncased,True,hf,True,linalg,66M,"nlp;bert-variant;transformer-encoder","Large version has 12 layers; 384 hidden size; Smaller than BERTbase (66M params vs 109M params)"
bert-base-uncased,True,hf,True,linalg,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-base-cased,True,hf,True,linalg,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
google/mobilebert-uncased,True,hf,True,linalg,25M,"nlp,bert-variant,transformer-encoder,mobile","24 layers, 512 hidden size, 128 embedding"
alexnet,False,vision,True,linalg,61M,"cnn,parallel-layers","The CNN that revolutionized computer vision (move away from hand-crafted features to neural networks),10 years old now and probably no longer used in prod."
resnet18,False,vision,True,linalg,11M,"cnn,image-classification,residuals,resnet-variant","1 7x7 conv2d and the rest are 3x3 conv2d"
resnet50,False,vision,True,linalg,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
resnet101,False,vision,True,linalg,29M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
squeezenet1_0,False,vision,True,linalg,1.25M,"cnn,image-classification,mobile,parallel-layers","Parallel conv2d (1x1 conv to compress -> (3x3 expand | 1x1 expand) -> concat)"
wide_resnet50_2,False,vision,True,linalg,69M,"cnn,image-classification,residuals,resnet-variant","Resnet variant where model depth is decreased and width is increased."
mobilenet_v3_small,False,vision,True,linalg,2.5M,"image-classification,cnn,mobile",N/A
google/vit-base-patch16-224,True,hf_img_cls,False,linalg,86M,"image-classification,vision-transformer,transformer-encoder",N/A
microsoft/resnet-50,True,hf_img_cls,False,linalg,23M,"image-classification,cnn,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
facebook/deit-small-distilled-patch16-224,True,hf_img_cls,False,linalg,22M,"image-classification,vision-transformer,cnn",N/A
microsoft/beit-base-patch16-224-pt22k-ft22k,True,hf_img_cls,False,linalg,86M,"image-classification,transformer-encoder,bert-variant,vision-transformer",N/A
nvidia/mit-b0,True,hf_img_cls,False,linalg,3.7M,"image-classification,transformer-encoder",SegFormer
mnasnet1_0,False,vision,True,linalg,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
resnet50_fp16,False,vision,True,linalg,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
bert-base-uncased_fp16,True,fp16,False,linalg,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-large-uncased,True,hf,True,linalg,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
bert-base-uncased,True,hf,False,stablehlo,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
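
Since every row now carries an `mlir_type` value, a quick sanity check over the file can catch typos in that column. The sketch below is an assumption-laden illustration: the set of accepted values is inferred from the commit message (linalg, plus the newly enabled stablehlo and tosa), and `unknown_mlir_types` is a hypothetical helper. `skipinitialspace=True` is used because the header separates column names with a comma followed by a space.

```python
import csv
import io

# Assumed set of dialects, inferred from the commit message; not confirmed by the code.
KNOWN_MLIR_TYPES = {"linalg", "stablehlo", "tosa"}

# Header and two rows copied from the updated CSV.
SAMPLE = '''model_name, use_tracing, model_type, dynamic, mlir_type, param_count, tags, notes
resnet50,False,vision,True,linalg,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
bert-base-uncased,True,hf,False,stablehlo,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
'''


def unknown_mlir_types(csv_text):
    """Return (model_name, mlir_type) pairs whose mlir_type is not recognized."""
    reader = csv.DictReader(io.StringIO(csv_text), skipinitialspace=True)
    return [
        (row["model_name"], row["mlir_type"])
        for row in reader
        if row["mlir_type"] not in KNOWN_MLIR_TYPES
    ]


bad = unknown_mlir_types(SAMPLE)
print(bad)
```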