Mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-04-20 03:00:34 -04:00)

Compare commits: 37 commits, 20230430.7 ... 20230516.7
| Author | SHA1 | Date |
|---|---|---|
|  | 20114deea0 |  |
|  | 9acf519078 |  |
|  | bdf37b5311 |  |
|  | 8ee2ac89f8 |  |
|  | 60cb48be2e |  |
|  | 86a215b063 |  |
|  | d6e3a9a236 |  |
|  | a0097a1ead |  |
|  | a9bae00606 |  |
|  | 4731c1a835 |  |
|  | 4c07e47e8c |  |
|  | e0cc2871bb |  |
|  | 649f39408b |  |
|  | c142297d73 |  |
|  | 9e07360b00 |  |
|  | 7b74c86e42 |  |
|  | fa833f8366 |  |
|  | fcb059aa38 |  |
|  | 517c670f82 |  |
|  | 59df14f18b |  |
|  | 6c95ac0f37 |  |
|  | 7a4a51ae73 |  |
|  | d816cc015e |  |
|  | 54ce3d48ca |  |
|  | 0e4a8ca240 |  |
|  | 6ca1298675 |  |
|  | bbef7a6464 |  |
|  | cdf2d61d53 |  |
|  | 6c14847d1f |  |
|  | 68ecdd2a73 |  |
|  | 3f4d444d18 |  |
|  | e473d0375b |  |
|  | e38d96850f |  |
|  | fed63dfd4b |  |
|  | eba4d06405 |  |
|  | 4cfba153d2 |  |
|  | 307c05f38d |  |
.github/workflows/test-models.yml (vendored, 2 changes)
@@ -137,7 +137,7 @@ jobs:
export DYLD_LIBRARY_PATH=/usr/local/lib/
echo $PATH
pip list | grep -E "torch|iree"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k vulkan --update_tank
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k vulkan

- name: Validate Vulkan Models (a100)
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
@@ -1,26 +1,27 @@
|
||||
import torch
|
||||
import shark
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_inference import SharkInference
|
||||
import torch_mlir
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
pipeline,
|
||||
StoppingCriteria,
|
||||
StoppingCriteriaList,
|
||||
TextIteratorStreamer,
|
||||
)
|
||||
import torch_mlir
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
base_models,
|
||||
get_opt_flags,
|
||||
get_vmfb_path_name,
|
||||
)
|
||||
from apps.stable_diffusion.src.models.model_wrappers import replace_shape_str
|
||||
import time
|
||||
import numpy as np
|
||||
from torch.nn import functional as F
|
||||
import os
|
||||
from threading import Thread
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from typing import List
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from shark.shark_downloader import download_public_file
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"stabilityai/stablelm-tuned-alpha-7b"
|
||||
)
|
||||
from shark.shark_inference import SharkInference
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class StopOnTokens(StoppingCriteria):
|
||||
@@ -34,6 +35,203 @@ class StopOnTokens(StoppingCriteria):
|
||||
return False
|
||||
|
||||
|
||||
def shouldStop(tokens):
|
||||
stop_ids = [50278, 50279, 50277, 1, 0]
|
||||
for stop_id in stop_ids:
|
||||
if tokens[0][-1] == stop_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
MAX_SEQUENCE_LENGTH = 256
|
||||
|
||||
|
||||
def user(message, history):
|
||||
# Append the user's message to the conversation history
|
||||
return "", history + [[message, ""]]
|
||||
|
||||
|
||||
def get_torch_mlir_module_bytecode(model, model_inputs):
|
||||
fx_g = make_fx(
|
||||
model,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
# tracing_mode='symbolic',
|
||||
)(*model_inputs)
|
||||
print("Got FX_G")
|
||||
|
||||
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
|
||||
removed_indexes = []
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, (list, tuple)):
|
||||
node_arg = list(node_arg)
|
||||
node_args_len = len(node_arg)
|
||||
for i in range(node_args_len):
|
||||
curr_index = node_args_len - (i + 1)
|
||||
if node_arg[curr_index] is None:
|
||||
removed_indexes.append(curr_index)
|
||||
node_arg.pop(curr_index)
|
||||
node.args = (tuple(node_arg),)
|
||||
break
|
||||
|
||||
if len(removed_indexes) > 0:
|
||||
fx_g.graph.lint()
|
||||
fx_g.graph.eliminate_dead_code()
|
||||
fx_g.recompile()
|
||||
removed_indexes.sort()
|
||||
return removed_indexes
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""
|
||||
Replace tuple with tuple element in functions that return one-element tuples.
|
||||
Returns true if an unwrapping took place, and false otherwise.
|
||||
"""
|
||||
unwrapped_tuple = False
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
unwrapped_tuple = True
|
||||
break
|
||||
|
||||
if unwrapped_tuple:
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return unwrapped_tuple
|
||||
|
||||
def transform_fx(fx_g):
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
if node.target in [
|
||||
torch.ops.aten.empty,
|
||||
]:
|
||||
# aten.empty should be filled with zeros.
|
||||
if node.target in [torch.ops.aten.empty]:
|
||||
with fx_g.graph.inserting_after(node):
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten.zero_,
|
||||
args=(node,),
|
||||
)
|
||||
node.append(new_node)
|
||||
node.replace_all_uses_with(new_node)
|
||||
new_node.args = (node,)
|
||||
|
||||
fx_g.graph.lint()
|
||||
|
||||
transform_fx(fx_g)
|
||||
fx_g.recompile()
|
||||
removed_none_indexes = _remove_nones(fx_g)
|
||||
was_unwrapped = _unwrap_single_tuple_return(fx_g)
|
||||
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
|
||||
print("FX_G recompile")
|
||||
|
||||
def strip_overloads(gm):
|
||||
"""
|
||||
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
|
||||
Args:
|
||||
gm(fx.GraphModule): The input Fx graph module to be modified
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if isinstance(node.target, torch._ops.OpOverload):
|
||||
node.target = node.target.overloadpacket
|
||||
gm.recompile()
|
||||
|
||||
strip_overloads(fx_g)
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
print("Got TS_G")
|
||||
return ts_g
|
||||
|
||||
|
||||
def compile_stableLM(model, model_inputs, model_name, model_vmfb_name):
|
||||
# ADD Device Arg
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
vmfb_path = Path(model_vmfb_name + ".vmfb")
|
||||
if vmfb_path.exists():
|
||||
print("Loading ", vmfb_path)
|
||||
shark_module = SharkInference(
|
||||
None, device="cuda", mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.load_module(vmfb_path)
|
||||
print("Successfully loaded vmfb")
|
||||
return shark_module
|
||||
|
||||
mlir_path = Path(model_name + ".mlir")
|
||||
print(
|
||||
f"[DEBUG] mlir path { mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
|
||||
)
|
||||
if mlir_path.exists():
|
||||
with open(mlir_path, "rb") as f:
bytecode = f.read()
|
||||
else:
|
||||
ts_graph = get_torch_mlir_module_bytecode(model, model_inputs)
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*model_inputs],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
f_ = open(model_name + ".mlir", "wb")
|
||||
f_.write(bytecode)
|
||||
print("Saved mlir")
|
||||
f_.close()
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device="cuda", mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
import os
|
||||
|
||||
path = shark_module.save_module(os.getcwd(), model_vmfb_name, [])
|
||||
print("Saved vmfb at ", str(path))
|
||||
|
||||
return shark_module
|
||||
|
||||
|
||||
class StableLMModel(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
combine_input_dict = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
output = self.model(**combine_input_dict)
|
||||
return output.logits
|
||||
|
||||
|
||||
# Initialize a StopOnTokens object
|
||||
system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
|
||||
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
|
||||
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
|
||||
@@ -41,167 +239,65 @@ system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
|
||||
- StableLM will refuse to participate in anything that could harm a human.
|
||||
"""
|
||||
|
||||
prompt = f"{system_prompt}<|USER|>What's your mood today?<|ASSISTANT|>"
|
||||
|
||||
inputs = tokenizer(prompt, return_tensors="pt")
|
||||
def get_tokenizer():
|
||||
model_path = "stabilityai/stablelm-tuned-alpha-3b"
|
||||
tok = AutoTokenizer.from_pretrained(model_path)
|
||||
tok.add_special_tokens({"pad_token": "<PAD>"})
|
||||
print(f"Sucessfully loaded the tokenizer to the memory")
|
||||
return tok
|
||||
|
||||
|
||||
class SLM(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
"stabilityai/stablelm-tuned-alpha-7b"
|
||||
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
|
||||
def generate(
|
||||
new_text,
|
||||
streamer,
|
||||
max_new_tokens,
|
||||
do_sample,
|
||||
top_p,
|
||||
top_k,
|
||||
temperature,
|
||||
num_beams,
|
||||
stopping_criteria,
|
||||
sharkStableLM,
|
||||
tok=None,
|
||||
input_ids=torch.randint(3, (1, 256)),
|
||||
attention_mask=torch.randint(3, (1, 256)),
|
||||
):
|
||||
if tok is None:
tok = get_tokenizer()
|
||||
# Construct the input message string for the model by concatenating the current system message and conversation history
|
||||
# Tokenize the messages string
|
||||
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
|
||||
words_list = []
|
||||
for i in range(max_new_tokens):
|
||||
numWords = len(new_text.split())
|
||||
# if(numWords>220):
|
||||
# break
|
||||
model_inputs = tok(
|
||||
[new_text],
|
||||
padding="max_length",
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
return self.model(input_ids, attention_mask)[0]
|
||||
|
||||
|
||||
slm_model = SLM()
|
||||
|
||||
res_pytorch = slm_model(inputs["input_ids"], inputs["attention_mask"])
|
||||
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from typing import List
|
||||
|
||||
fx_g = make_fx(
|
||||
slm_model,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
)(inputs["input_ids"], inputs["attention_mask"])
|
||||
|
||||
|
||||
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
|
||||
removed_indexes = []
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, (list, tuple)):
|
||||
node_arg = list(node_arg)
|
||||
node_args_len = len(node_arg)
|
||||
for i in range(node_args_len):
|
||||
curr_index = node_args_len - (i + 1)
|
||||
if node_arg[curr_index] is None:
|
||||
removed_indexes.append(curr_index)
|
||||
node_arg.pop(curr_index)
|
||||
node.args = (tuple(node_arg),)
|
||||
break
|
||||
|
||||
if len(removed_indexes) > 0:
|
||||
fx_g.graph.lint()
|
||||
fx_g.graph.eliminate_dead_code()
|
||||
fx_g.recompile()
|
||||
removed_indexes.sort()
|
||||
return removed_indexes
|
||||
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""
|
||||
Replace tuple with tuple element in functions that return one-element tuples.
|
||||
Returns true if an unwrapping took place, and false otherwise.
|
||||
"""
|
||||
unwrapped_tuple = False
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
unwrapped_tuple = True
|
||||
break
|
||||
|
||||
if unwrapped_tuple:
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return unwrapped_tuple
|
||||
|
||||
|
||||
def transform_fx(fx_g):
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
if node.target in [
|
||||
torch.ops.aten.empty,
|
||||
]:
|
||||
# aten.empty should be filled with zeros.
|
||||
if node.target in [torch.ops.aten.empty]:
|
||||
with fx_g.graph.inserting_after(node):
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten.zero_,
|
||||
args=(node,),
|
||||
)
|
||||
node.append(new_node)
|
||||
node.replace_all_uses_with(new_node)
|
||||
new_node.args = (node,)
|
||||
|
||||
fx_g.graph.lint()
|
||||
|
||||
|
||||
transform_fx(fx_g)
|
||||
fx_g.recompile()
|
||||
removed_none_indexes = _remove_nones(fx_g)
|
||||
was_unwrapped = _unwrap_single_tuple_return(fx_g)
|
||||
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
|
||||
|
||||
def strip_overloads(gm):
|
||||
"""
|
||||
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
|
||||
Args:
|
||||
gm(fx.GraphModule): The input Fx graph module to be modified
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if isinstance(node.target, torch._ops.OpOverload):
|
||||
node.target = node.target.overloadpacket
|
||||
gm.recompile()
|
||||
|
||||
|
||||
strip_overloads(fx_g)
|
||||
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
[inputs["input_ids"], inputs["attention_mask"]],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device="cuda", mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
result_shark = shark_module(
|
||||
"forward", [inputs["input_ids"], inputs["attention_mask"]]
|
||||
)
|
||||
|
||||
print("Result PyTorch")
|
||||
print(res_pytorch)
|
||||
print("Result SHARK")
|
||||
print(result_shark)
|
||||
sum_attentionmask = torch.sum(model_inputs.attention_mask)
|
||||
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
|
||||
output = sharkStableLM(
|
||||
"forward", [model_inputs.input_ids, model_inputs.attention_mask]
|
||||
)
|
||||
output = torch.from_numpy(output)
|
||||
next_toks = torch.topk(output, 1)
|
||||
if shouldStop(next_toks.indices):
|
||||
break
|
||||
# streamer.put(next_toks.indices[0][int(sum_attentionmask)-1])
|
||||
new_word = tok.decode(
|
||||
next_toks.indices[0][int(sum_attentionmask) - 1],
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
print(new_word, end="", flush=True)
|
||||
words_list.append(new_word)
|
||||
if new_word == "":
|
||||
break
|
||||
new_text = new_text + new_word
|
||||
return words_list
|
||||
|
||||
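For orientation, the pieces added to this script fit together roughly as follows. This is a minimal sketch, not part of the commit: it reuses get_tokenizer, compile_stableLM, generate, system_prompt, and MAX_SEQUENCE_LENGTH exactly as defined above, assumes a cached .mlir or .vmfb already exists for the names given (otherwise a real model object must be passed instead of None), and the generation settings are illustrative.

tok = get_tokenizer()
prompt = f"{system_prompt}<|USER|>What's your mood today?<|ASSISTANT|>"
# Compile the StableLM module once (or load the cached vmfb if present)...
shark_stablelm = compile_stableLM(
    None,
    (torch.randint(3, (1, MAX_SEQUENCE_LENGTH)), torch.randint(3, (1, MAX_SEQUENCE_LENGTH))),
    "stableLM_linalg_f32_seqLen256",
    "stableLM_shark_f32_seqLen256",
)
# ...then decode token by token through the compiled module.
words = generate(
    prompt, None, max_new_tokens=64, do_sample=False, top_p=1.0, top_k=0,
    temperature=1.0, num_beams=1, stopping_criteria=None,
    sharkStableLM=shark_stablelm, tok=tok,
)
print("".join(words))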
apps/language_models/scripts/vicuna.py (new file, 656 lines)
@@ -0,0 +1,656 @@
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
sys.path.insert(0, "D:\S\SB\I\python_packages\iree_compiler")
|
||||
sys.path.insert(0, "D:\S\SB\I\python_packages\iree_runtime")
|
||||
import torch
|
||||
import torch_mlir
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from typing import List
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import transform_fx as transform_fx_
|
||||
import re
|
||||
from shark.shark_inference import SharkInference
|
||||
from tqdm import tqdm
|
||||
from torch_mlir import TensorPlaceholder
|
||||
|
||||
|
||||
class FirstVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, hidden_states, attention_mask, position_ids):
|
||||
outputs = self.model(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
use_cache=True,
|
||||
)
|
||||
next_hidden_states = outputs[0]
|
||||
past_key_value_out0, past_key_value_out1 = (
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
|
||||
return (
|
||||
next_hidden_states,
|
||||
past_key_value_out0,
|
||||
past_key_value_out1,
|
||||
)
|
||||
|
||||
|
||||
class SecondVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value0,
|
||||
past_key_value1,
|
||||
):
|
||||
outputs = self.model(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=(
|
||||
past_key_value0,
|
||||
past_key_value1,
|
||||
),
|
||||
use_cache=True,
|
||||
)
|
||||
next_hidden_states = outputs[0]
|
||||
past_key_value_out0, past_key_value_out1 = (
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
|
||||
return (
|
||||
next_hidden_states,
|
||||
past_key_value_out0,
|
||||
past_key_value_out1,
|
||||
)
|
||||
|
||||
|
||||
class CompiledFirstVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, shark_module):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
):
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
output = self.model(
|
||||
"forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
),
|
||||
)
|
||||
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
|
||||
return (
|
||||
output0,
|
||||
(
|
||||
output1,
|
||||
output2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class CompiledSecondVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, shark_module):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
):
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
pkv0 = past_key_value[0].detach()
|
||||
pkv1 = past_key_value[1].detach()
|
||||
output = self.model(
|
||||
"forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
pkv0,
|
||||
pkv1,
|
||||
),
|
||||
)
|
||||
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
|
||||
return (
|
||||
output0,
|
||||
(
|
||||
output1,
|
||||
output2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ShardedVicunaModel(torch.nn.Module):
|
||||
def __init__(self, model, layers0, layers1):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
assert len(layers0) == len(model.model.layers)
|
||||
# self.model.model.layers = torch.nn.modules.container.ModuleList(layers0)
|
||||
self.model.model.config.use_cache = True
|
||||
self.model.model.config.output_attentions = False
|
||||
self.layers0 = layers0
|
||||
self.layers1 = layers1
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
is_first=True,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
):
|
||||
if is_first:
|
||||
self.model.model.layers = torch.nn.modules.container.ModuleList(
|
||||
self.layers0
|
||||
)
|
||||
return self.model.forward(input_ids, attention_mask=attention_mask)
|
||||
else:
|
||||
self.model.model.layers = torch.nn.modules.container.ModuleList(
|
||||
self.layers1
|
||||
)
|
||||
return self.model.forward(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
|
||||
def write_in_dynamic_inputs0(module, dynamic_input_size):
|
||||
new_lines = []
|
||||
for line in module.splitlines():
|
||||
line = re.sub(f"{dynamic_input_size}x", "?x", line)
|
||||
if "?x" in line:
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
|
||||
line = re.sub(f" {dynamic_input_size},", " %dim,", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub(f"c{dynamic_input_size}", "dim", line)
|
||||
new_lines.append(line)
|
||||
new_module = "\n".join(new_lines)
|
||||
return new_module
|
||||
|
||||
|
||||
def write_in_dynamic_inputs1(module, dynamic_input_size):
|
||||
new_lines = []
|
||||
for line in module.splitlines():
|
||||
if "dim_42 =" in line:
|
||||
continue
|
||||
if f"%c{dynamic_input_size}_i64 =" in line:
|
||||
new_lines.append(
|
||||
"%dim_42 = tensor.dim %arg1, %c3 : tensor<1x1x1x?xf32>"
|
||||
)
|
||||
new_lines.append(
|
||||
f"%dim_42_i64 = arith.index_cast %dim_42 : index to i64"
|
||||
)
|
||||
continue
|
||||
line = re.sub(f"{dynamic_input_size}x", "?x", line)
|
||||
if "?x" in line:
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim_42)", line)
|
||||
line = re.sub(f" {dynamic_input_size},", " %dim_42,", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim_42\)",
|
||||
"tensor.empty(%dim_42, %dim_42)",
|
||||
line,
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub(f"c{dynamic_input_size}", "dim_42", line)
|
||||
new_lines.append(line)
|
||||
new_module = "\n".join(new_lines)
|
||||
return new_module
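# Note (illustrative, not part of the commit): the two rewrite helpers above
# post-process the textual MLIR emitted for a fixed sample sequence length,
# turning the hard-coded size into dynamic "?" dimensions and threading a %dim
# value into tensor.empty. A minimal sketch of the effect on one made-up MLIR
# line, assuming a sample length of 137 (the real text comes from torch_mlir.compile):
sample = "%0 = tensor.empty() : tensor<1x137x4096xf32>"
print(write_in_dynamic_inputs0(sample, 137))
# -> %0 = tensor.empty(%dim) : tensor<1x?x4096xf32>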
|
||||
|
||||
|
||||
def compile_vicuna_layer(
|
||||
vicuna_layer,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value0=None,
|
||||
past_key_value1=None,
|
||||
):
|
||||
hidden_states_placeholder = TensorPlaceholder.like(
|
||||
hidden_states, dynamic_axes=[1]
|
||||
)
|
||||
attention_mask_placeholder = TensorPlaceholder.like(
|
||||
attention_mask, dynamic_axes=[2, 3]
|
||||
)
|
||||
position_ids_placeholder = TensorPlaceholder.like(
|
||||
position_ids, dynamic_axes=[1]
|
||||
)
|
||||
|
||||
if past_key_value0 is None and past_key_value1 is None:
|
||||
fx_g = make_fx(
|
||||
vicuna_layer,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
)(hidden_states, attention_mask, position_ids)
|
||||
|
||||
else:
|
||||
fx_g = make_fx(
|
||||
vicuna_layer,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
)(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value0,
|
||||
past_key_value1,
|
||||
)
|
||||
|
||||
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
|
||||
removed_indexes = []
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, (list, tuple)):
|
||||
node_arg = list(node_arg)
|
||||
node_args_len = len(node_arg)
|
||||
for i in range(node_args_len):
|
||||
curr_index = node_args_len - (i + 1)
|
||||
if node_arg[curr_index] is None:
|
||||
removed_indexes.append(curr_index)
|
||||
node_arg.pop(curr_index)
|
||||
node.args = (tuple(node_arg),)
|
||||
break
|
||||
|
||||
if len(removed_indexes) > 0:
|
||||
fx_g.graph.lint()
|
||||
fx_g.graph.eliminate_dead_code()
|
||||
fx_g.recompile()
|
||||
removed_indexes.sort()
|
||||
return removed_indexes
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""
|
||||
Replace tuple with tuple element in functions that return one-element tuples.
|
||||
Returns true if an unwrapping took place, and false otherwise.
|
||||
"""
|
||||
unwrapped_tuple = False
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
unwrapped_tuple = True
|
||||
break
|
||||
|
||||
if unwrapped_tuple:
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return unwrapped_tuple
|
||||
|
||||
def transform_fx(fx_g):
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
if node.target in [
|
||||
torch.ops.aten.empty,
|
||||
]:
|
||||
# aten.empty should be filled with zeros.
|
||||
if node.target in [torch.ops.aten.empty]:
|
||||
with fx_g.graph.inserting_after(node):
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten.zero_,
|
||||
args=(node,),
|
||||
)
|
||||
node.append(new_node)
|
||||
node.replace_all_uses_with(new_node)
|
||||
new_node.args = (node,)
|
||||
|
||||
fx_g.graph.lint()
|
||||
|
||||
transform_fx(fx_g)
|
||||
fx_g.recompile()
|
||||
removed_none_indexes = _remove_nones(fx_g)
|
||||
was_unwrapped = _unwrap_single_tuple_return(fx_g)
|
||||
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
|
||||
print("FX_G recompile")
|
||||
|
||||
def strip_overloads(gm):
|
||||
"""
|
||||
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
|
||||
Args:
|
||||
gm(fx.GraphModule): The input Fx graph module to be modified
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if isinstance(node.target, torch._ops.OpOverload):
|
||||
node.target = node.target.overloadpacket
|
||||
gm.recompile()
|
||||
|
||||
strip_overloads(fx_g)
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
return ts_g
|
||||
|
||||
|
||||
def get_model_and_tokenizer(path="TheBloke/vicuna-7B-1.1-HF"):
|
||||
kwargs = {"torch_dtype": torch.float}
|
||||
vicuna_model = AutoModelForCausalLM.from_pretrained(path, **kwargs)
|
||||
tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
|
||||
return vicuna_model, tokenizer
|
||||
|
||||
|
||||
def compile_to_vmfb(inputs, layers, is_first=True):
|
||||
mlirs, modules = [], []
|
||||
for idx, layer in tqdm(enumerate(layers), desc="Getting mlirs"):
|
||||
if is_first:
|
||||
mlir_path = Path(f"{idx}_0.mlir")
|
||||
vmfb_path = Path(f"{idx}_0.vmfb")
|
||||
else:
|
||||
mlir_path = Path(f"{idx}_1.mlir")
|
||||
vmfb_path = Path(f"{idx}_1.vmfb")
|
||||
if vmfb_path.exists():
|
||||
continue
|
||||
if mlir_path.exists():
|
||||
# print(f"Found layer {idx} mlir")
|
||||
f_ = open(mlir_path, "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
else:
|
||||
hidden_states_placeholder = TensorPlaceholder.like(
|
||||
inputs[0], dynamic_axes=[1]
|
||||
)
|
||||
attention_mask_placeholder = TensorPlaceholder.like(
|
||||
inputs[1], dynamic_axes=[3]
|
||||
)
|
||||
position_ids_placeholder = TensorPlaceholder.like(
|
||||
inputs[2], dynamic_axes=[1]
|
||||
)
|
||||
if not is_first:
|
||||
pkv0_placeholder = TensorPlaceholder.like(
|
||||
inputs[3], dynamic_axes=[2]
|
||||
)
|
||||
pkv1_placeholder = TensorPlaceholder.like(
|
||||
inputs[4], dynamic_axes=[2]
|
||||
)
|
||||
print(f"Compiling layer {idx} mlir")
|
||||
if is_first:
|
||||
ts_g = compile_vicuna_layer(
|
||||
layer, inputs[0], inputs[1], inputs[2]
|
||||
)
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
(
|
||||
hidden_states_placeholder,
|
||||
inputs[1],
|
||||
inputs[2],
|
||||
),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
else:
|
||||
ts_g = compile_vicuna_layer(
|
||||
layer,
|
||||
inputs[0],
|
||||
inputs[1],
|
||||
inputs[2],
|
||||
inputs[3],
|
||||
inputs[4],
|
||||
)
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
(
|
||||
inputs[0],
|
||||
attention_mask_placeholder,
|
||||
inputs[2],
|
||||
pkv0_placeholder,
|
||||
pkv1_placeholder,
|
||||
),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# bytecode_stream = BytesIO()
|
||||
# module.operation.write_bytecode(bytecode_stream)
|
||||
# bytecode = bytecode_stream.getvalue()
|
||||
|
||||
if is_first:
|
||||
module = write_in_dynamic_inputs0(str(module), 137)
|
||||
bytecode = module.encode("UTF-8")
|
||||
bytecode_stream = BytesIO(bytecode)
|
||||
bytecode = bytecode_stream.read()
|
||||
|
||||
else:
|
||||
module = write_in_dynamic_inputs1(str(module), 138)
|
||||
if idx in [0, 5, 6, 7]:
|
||||
module_str = module
|
||||
module_str = module_str.splitlines()
|
||||
new_lines = []
|
||||
for line in module_str:
|
||||
if len(line) < 1000:
|
||||
new_lines.append(line)
|
||||
else:
|
||||
new_lines.append(line[:999])
|
||||
module_str = "\n".join(new_lines)
|
||||
f1_ = open(f"{idx}_1_test.mlir", "w+")
|
||||
f1_.write(module_str)
|
||||
f1_.close()
|
||||
|
||||
bytecode = module.encode("UTF-8")
|
||||
bytecode_stream = BytesIO(bytecode)
|
||||
bytecode = bytecode_stream.read()
|
||||
|
||||
f_ = open(mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
f_.close()
|
||||
mlirs.append(bytecode)
|
||||
|
||||
for idx, layer in tqdm(enumerate(layers), desc="compiling modules"):
|
||||
if is_first:
|
||||
vmfb_path = Path(f"{idx}_0.vmfb")
|
||||
if idx < 25:
|
||||
device = "cpu"
|
||||
else:
|
||||
device = "cpu"
|
||||
if vmfb_path.exists():
|
||||
# print(f"Found layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
None, device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
else:
|
||||
print(f"Compiling layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
mlirs[idx], device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
module.save_module(
|
||||
module_name=f"{idx}_0",
|
||||
extra_args=[
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
],
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
modules.append(module)
|
||||
else:
|
||||
vmfb_path = Path(f"{idx}_1.vmfb")
|
||||
if idx < 25:
|
||||
device = "vulkan"
|
||||
else:
|
||||
device = "cpu"
|
||||
if vmfb_path.exists():
|
||||
# print(f"Found layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
None, device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
else:
|
||||
print(f"Compiling layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
mlirs[idx], device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
module.save_module(
|
||||
module_name=f"{idx}_1",
|
||||
extra_args=[
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
],
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
modules.append(module)
|
||||
|
||||
return mlirs, modules
|
||||
|
||||
|
||||
def get_sharded_model():
|
||||
# SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an incredibly hacky process
|
||||
# please don't change it
|
||||
SAMPLE_INPUT_LEN = 137
|
||||
vicuna_model = get_model_and_tokenizer()[0]
|
||||
|
||||
placeholder_input0 = (
|
||||
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
|
||||
torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]),
|
||||
torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64),
|
||||
)
|
||||
|
||||
placeholder_input1 = (
|
||||
torch.zeros([1, 1, 4096]),
|
||||
torch.zeros([1, 1, 1, SAMPLE_INPUT_LEN + 1]),
|
||||
torch.zeros([1, 1], dtype=torch.int64),
|
||||
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
|
||||
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
|
||||
)
|
||||
|
||||
layers0 = [FirstVicunaLayer(layer) for layer in vicuna_model.model.layers]
|
||||
_, modules0 = compile_to_vmfb(placeholder_input0, layers0, is_first=True)
|
||||
shark_layers0 = [CompiledFirstVicunaLayer(m) for m in modules0]
|
||||
|
||||
layers1 = [SecondVicunaLayer(layer) for layer in vicuna_model.model.layers]
|
||||
_, modules1 = compile_to_vmfb(placeholder_input1, layers1, is_first=False)
|
||||
shark_layers1 = [CompiledSecondVicunaLayer(m) for m in modules1]
|
||||
|
||||
sharded_model = ShardedVicunaModel(
|
||||
vicuna_model, shark_layers0, shark_layers1
|
||||
)
|
||||
return sharded_model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
|
||||
prologue_prompt = "ASSISTANT:\n"
|
||||
sharded_model = get_sharded_model()
|
||||
tokenizer = get_model_and_tokenizer()[1]
|
||||
past_key_values = None
|
||||
while True:
|
||||
print("\n\n")
|
||||
user_prompt = input("User: ")
|
||||
prompt_history = (
|
||||
prompt_history + "USER:\n" + user_prompt + prologue_prompt
|
||||
)
|
||||
prompt = prompt_history.strip()
|
||||
input_ids = tokenizer(prompt).input_ids
|
||||
tokens = input_ids
|
||||
print("Robot:", end=" ")
|
||||
new_sentence = []
|
||||
max_response_len = 1000
|
||||
for iteration in range(max_response_len):
|
||||
original_input_ids = input_ids
|
||||
input_id_len = len(input_ids)
|
||||
input_ids = torch.tensor(input_ids)
|
||||
input_ids = input_ids.reshape([1, input_id_len])
|
||||
|
||||
if iteration == 0:
|
||||
output = sharded_model.forward(input_ids, is_first=True)
|
||||
else:
|
||||
output = sharded_model.forward(
|
||||
input_ids, past_key_values=past_key_values, is_first=False
|
||||
)
|
||||
logits = output["logits"]
|
||||
past_key_values = output["past_key_values"]
|
||||
new_token = int(torch.argmax(logits[:, -1, :], dim=1)[0])
|
||||
if new_token == 2:
|
||||
break
|
||||
new_sentence += [new_token]
|
||||
tokens.append(new_token)
|
||||
original_input_ids.append(new_token)
|
||||
input_ids = [new_token]
|
||||
|
||||
for i in range(len(tokens)):
|
||||
if type(tokens[i]) != int:
|
||||
tokens[i] = int(tokens[i][0])
|
||||
new_sentence_str = tokenizer.decode(new_sentence)
|
||||
print(new_sentence_str)
|
||||
prompt_history += f"\n{new_sentence_str}\n"
|
||||
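The interactive loop above relies on a two-phase contract in ShardedVicunaModel.forward: the first call runs the whole prompt with is_first=True, and every later call passes only the newest token together with the cached past_key_values. A compressed sketch of that contract (the prompt and token values are illustrative, not from the commit):

ids = torch.tensor(tokenizer("Hello").input_ids).reshape([1, -1])
out = sharded_model.forward(ids, is_first=True)          # prefill over the prompt
pkv = out["past_key_values"]
new_token = int(torch.argmax(out["logits"][:, -1, :], dim=1)[0])
step = torch.tensor([[new_token]])                        # one token per step afterwards
out = sharded_model.forward(step, past_key_values=pkv, is_first=False)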
apps/language_models/scripts/vicuna_web.py (new file, 777 lines)
@@ -0,0 +1,777 @@
|
||||
import sys
|
||||
import warnings
|
||||
import gradio as gr
|
||||
import time
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
sys.path.insert(0, "D:\S\SB\I\python_packages\iree_compiler")
|
||||
sys.path.insert(0, "D:\S\SB\I\python_packages\iree_runtime")
|
||||
import torch
|
||||
import torch_mlir
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from typing import List
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import transform_fx as transform_fx_
|
||||
import re
|
||||
from shark.shark_inference import SharkInference
|
||||
from tqdm import tqdm
|
||||
from torch_mlir import TensorPlaceholder
|
||||
from apps.stable_diffusion.web.ui.utils import available_devices
|
||||
|
||||
|
||||
class FirstVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, hidden_states, attention_mask, position_ids):
|
||||
outputs = self.model(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
use_cache=True,
|
||||
)
|
||||
next_hidden_states = outputs[0]
|
||||
past_key_value_out0, past_key_value_out1 = (
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
|
||||
return (
|
||||
next_hidden_states,
|
||||
past_key_value_out0,
|
||||
past_key_value_out1,
|
||||
)
|
||||
|
||||
|
||||
class SecondVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value0,
|
||||
past_key_value1,
|
||||
):
|
||||
outputs = self.model(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=(
|
||||
past_key_value0,
|
||||
past_key_value1,
|
||||
),
|
||||
use_cache=True,
|
||||
)
|
||||
next_hidden_states = outputs[0]
|
||||
past_key_value_out0, past_key_value_out1 = (
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
|
||||
return (
|
||||
next_hidden_states,
|
||||
past_key_value_out0,
|
||||
past_key_value_out1,
|
||||
)
|
||||
|
||||
|
||||
class CompiledFirstVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, shark_module):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
):
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
output = self.model(
|
||||
"forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
),
|
||||
)
|
||||
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
|
||||
return (
|
||||
output0,
|
||||
(
|
||||
output1,
|
||||
output2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class CompiledSecondVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, shark_module):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
):
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
pkv0 = past_key_value[0].detach()
|
||||
pkv1 = past_key_value[1].detach()
|
||||
output = self.model(
|
||||
"forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
pkv0,
|
||||
pkv1,
|
||||
),
|
||||
)
|
||||
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
|
||||
return (
|
||||
output0,
|
||||
(
|
||||
output1,
|
||||
output2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ShardedVicunaModel(torch.nn.Module):
|
||||
def __init__(self, model, layers0, layers1):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
assert len(layers0) == len(model.model.layers)
|
||||
# self.model.model.layers = torch.nn.modules.container.ModuleList(layers0)
|
||||
self.model.model.config.use_cache = True
|
||||
self.model.model.config.output_attentions = False
|
||||
self.layers0 = layers0
|
||||
self.layers1 = layers1
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
is_first=True,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
):
|
||||
if is_first:
|
||||
self.model.model.layers = torch.nn.modules.container.ModuleList(
|
||||
self.layers0
|
||||
)
|
||||
return self.model.forward(input_ids, attention_mask=attention_mask)
|
||||
else:
|
||||
self.model.model.layers = torch.nn.modules.container.ModuleList(
|
||||
self.layers1
|
||||
)
|
||||
return self.model.forward(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
|
||||
def write_in_dynamic_inputs0(module, dynamic_input_size):
|
||||
new_lines = []
|
||||
for line in module.splitlines():
|
||||
line = re.sub(f"{dynamic_input_size}x", "?x", line)
|
||||
if "?x" in line:
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
|
||||
line = re.sub(f" {dynamic_input_size},", " %dim,", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub(f"c{dynamic_input_size}", "dim", line)
|
||||
new_lines.append(line)
|
||||
new_module = "\n".join(new_lines)
|
||||
return new_module
|
||||
|
||||
|
||||
def write_in_dynamic_inputs1(module, dynamic_input_size):
|
||||
new_lines = []
|
||||
for line in module.splitlines():
|
||||
if "dim_42 =" in line:
|
||||
continue
|
||||
if f"%c{dynamic_input_size}_i64 =" in line:
|
||||
new_lines.append(
|
||||
"%dim_42 = tensor.dim %arg1, %c3 : tensor<1x1x1x?xf32>"
|
||||
)
|
||||
new_lines.append(
|
||||
f"%dim_42_i64 = arith.index_cast %dim_42 : index to i64"
|
||||
)
|
||||
continue
|
||||
line = re.sub(f"{dynamic_input_size}x", "?x", line)
|
||||
if "?x" in line:
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim_42)", line)
|
||||
line = re.sub(f" {dynamic_input_size},", " %dim_42,", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim_42\)",
|
||||
"tensor.empty(%dim_42, %dim_42)",
|
||||
line,
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub(f"c{dynamic_input_size}", "dim_42", line)
|
||||
new_lines.append(line)
|
||||
new_module = "\n".join(new_lines)
|
||||
return new_module
|
||||
|
||||
|
||||
def compile_vicuna_layer(
|
||||
vicuna_layer,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value0=None,
|
||||
past_key_value1=None,
|
||||
):
|
||||
hidden_states_placeholder = TensorPlaceholder.like(
|
||||
hidden_states, dynamic_axes=[1]
|
||||
)
|
||||
attention_mask_placeholder = TensorPlaceholder.like(
|
||||
attention_mask, dynamic_axes=[2, 3]
|
||||
)
|
||||
position_ids_placeholder = TensorPlaceholder.like(
|
||||
position_ids, dynamic_axes=[1]
|
||||
)
|
||||
|
||||
if past_key_value0 is None and past_key_value1 is None:
|
||||
fx_g = make_fx(
|
||||
vicuna_layer,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
)(hidden_states, attention_mask, position_ids)
|
||||
|
||||
else:
|
||||
fx_g = make_fx(
|
||||
vicuna_layer,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
)(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value0,
|
||||
past_key_value1,
|
||||
)
|
||||
|
||||
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
|
||||
removed_indexes = []
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, (list, tuple)):
|
||||
node_arg = list(node_arg)
|
||||
node_args_len = len(node_arg)
|
||||
for i in range(node_args_len):
|
||||
curr_index = node_args_len - (i + 1)
|
||||
if node_arg[curr_index] is None:
|
||||
removed_indexes.append(curr_index)
|
||||
node_arg.pop(curr_index)
|
||||
node.args = (tuple(node_arg),)
|
||||
break
|
||||
|
||||
if len(removed_indexes) > 0:
|
||||
fx_g.graph.lint()
|
||||
fx_g.graph.eliminate_dead_code()
|
||||
fx_g.recompile()
|
||||
removed_indexes.sort()
|
||||
return removed_indexes
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""
|
||||
Replace tuple with tuple element in functions that return one-element tuples.
|
||||
Returns true if an unwrapping took place, and false otherwise.
|
||||
"""
|
||||
unwrapped_tuple = False
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
unwrapped_tuple = True
|
||||
break
|
||||
|
||||
if unwrapped_tuple:
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return unwrapped_tuple
|
||||
|
||||
def transform_fx(fx_g):
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
if node.target in [
|
||||
torch.ops.aten.empty,
|
||||
]:
|
||||
# aten.empty should be filled with zeros.
|
||||
if node.target in [torch.ops.aten.empty]:
|
||||
with fx_g.graph.inserting_after(node):
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten.zero_,
|
||||
args=(node,),
|
||||
)
|
||||
node.append(new_node)
|
||||
node.replace_all_uses_with(new_node)
|
||||
new_node.args = (node,)
|
||||
|
||||
fx_g.graph.lint()
|
||||
|
||||
transform_fx(fx_g)
|
||||
fx_g.recompile()
|
||||
removed_none_indexes = _remove_nones(fx_g)
|
||||
was_unwrapped = _unwrap_single_tuple_return(fx_g)
|
||||
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
|
||||
print("FX_G recompile")
|
||||
|
||||
def strip_overloads(gm):
|
||||
"""
|
||||
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
|
||||
Args:
|
||||
gm(fx.GraphModule): The input Fx graph module to be modified
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if isinstance(node.target, torch._ops.OpOverload):
|
||||
node.target = node.target.overloadpacket
|
||||
gm.recompile()
|
||||
|
||||
strip_overloads(fx_g)
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
return ts_g
|
||||
|
||||
|
||||
path = "TheBloke/vicuna-7B-1.1-HF"
|
||||
kwargs = {"torch_dtype": torch.float}
|
||||
vicuna_model = AutoModelForCausalLM.from_pretrained(path, **kwargs)
|
||||
tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
|
||||
|
||||
|
||||
def compile_to_vmfb(inputs, layers, is_first=True):
|
||||
mlirs, modules = [], []
|
||||
for idx, layer in tqdm(enumerate(layers), desc="Getting mlirs"):
|
||||
if is_first:
|
||||
mlir_path = Path(f"{idx}_0.mlir")
|
||||
vmfb_path = Path(f"{idx}_0.vmfb")
|
||||
else:
|
||||
mlir_path = Path(f"{idx}_1.mlir")
|
||||
vmfb_path = Path(f"{idx}_1.vmfb")
|
||||
if vmfb_path.exists():
|
||||
continue
|
||||
if mlir_path.exists():
|
||||
# print(f"Found layer {idx} mlir")
|
||||
f_ = open(mlir_path, "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
else:
|
||||
hidden_states_placeholder = TensorPlaceholder.like(
|
||||
inputs[0], dynamic_axes=[1]
|
||||
)
|
||||
attention_mask_placeholder = TensorPlaceholder.like(
|
||||
inputs[1], dynamic_axes=[3]
|
||||
)
|
||||
position_ids_placeholder = TensorPlaceholder.like(
|
||||
inputs[2], dynamic_axes=[1]
|
||||
)
|
||||
if not is_first:
|
||||
pkv0_placeholder = TensorPlaceholder.like(
|
||||
inputs[3], dynamic_axes=[2]
|
||||
)
|
||||
pkv1_placeholder = TensorPlaceholder.like(
|
||||
inputs[4], dynamic_axes=[2]
|
||||
)
|
||||
print(f"Compiling layer {idx} mlir")
|
||||
if is_first:
|
||||
ts_g = compile_vicuna_layer(
|
||||
layer, inputs[0], inputs[1], inputs[2]
|
||||
)
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
(
|
||||
hidden_states_placeholder,
|
||||
inputs[1],
|
||||
inputs[2],
|
||||
),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
else:
|
||||
ts_g = compile_vicuna_layer(
|
||||
layer,
|
||||
inputs[0],
|
||||
inputs[1],
|
||||
inputs[2],
|
||||
inputs[3],
|
||||
inputs[4],
|
||||
)
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
(
|
||||
inputs[0],
|
||||
attention_mask_placeholder,
|
||||
inputs[2],
|
||||
pkv0_placeholder,
|
||||
pkv1_placeholder,
|
||||
),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# bytecode_stream = BytesIO()
|
||||
# module.operation.write_bytecode(bytecode_stream)
|
||||
# bytecode = bytecode_stream.getvalue()
|
||||
|
||||
if is_first:
|
||||
module = write_in_dynamic_inputs0(str(module), 137)
|
||||
bytecode = module.encode("UTF-8")
|
||||
bytecode_stream = BytesIO(bytecode)
|
||||
bytecode = bytecode_stream.read()
|
||||
|
||||
else:
|
||||
module = write_in_dynamic_inputs1(str(module), 138)
|
||||
if idx in [0, 5, 6, 7]:
|
||||
module_str = module
|
||||
module_str = module_str.splitlines()
|
||||
new_lines = []
|
||||
for line in module_str:
|
||||
if len(line) < 1000:
|
||||
new_lines.append(line)
|
||||
else:
|
||||
new_lines.append(line[:999])
|
||||
module_str = "\n".join(new_lines)
|
||||
f1_ = open(f"{idx}_1_test.mlir", "w+")
|
||||
f1_.write(module_str)
|
||||
f1_.close()
|
||||
|
||||
bytecode = module.encode("UTF-8")
|
||||
bytecode_stream = BytesIO(bytecode)
|
||||
bytecode = bytecode_stream.read()
|
||||
|
||||
f_ = open(mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
f_.close()
|
||||
mlirs.append(bytecode)
|
||||
|
||||
for idx, layer in tqdm(enumerate(layers), desc="compiling modules"):
|
||||
if is_first:
|
||||
vmfb_path = Path(f"{idx}_0.vmfb")
|
||||
if idx < 25:
|
||||
device = "cpu"
|
||||
else:
|
||||
device = "cpu"
|
||||
if vmfb_path.exists():
|
||||
# print(f"Found layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
None, device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
else:
|
||||
print(f"Compiling layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
mlirs[idx], device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
module.save_module(
|
||||
module_name=f"{idx}_0",
|
||||
extra_args=[
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
],
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
modules.append(module)
|
||||
else:
|
||||
vmfb_path = Path(f"{idx}_1.vmfb")
|
||||
if idx < 25:
|
||||
device = "vulkan"
|
||||
else:
|
||||
device = "cpu"
|
||||
if vmfb_path.exists():
|
||||
# print(f"Found layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
None, device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
else:
|
||||
print(f"Compiling layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
mlirs[idx], device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
module.save_module(
|
||||
module_name=f"{idx}_1",
|
||||
extra_args=[
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
],
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
modules.append(module)
|
||||
|
||||
return mlirs, modules
|
||||
|
||||
|
||||
def get_sharded_model():
|
||||
# SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an incredibly hacky process
|
||||
# please don't change it
|
||||
SAMPLE_INPUT_LEN = 137
|
||||
global vicuna_model
|
||||
|
||||
placeholder_input0 = (
|
||||
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
|
||||
torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]),
|
||||
torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64),
|
||||
)
|
||||
|
||||
placeholder_input1 = (
|
||||
torch.zeros([1, 1, 4096]),
|
||||
torch.zeros([1, 1, 1, SAMPLE_INPUT_LEN + 1]),
|
||||
torch.zeros([1, 1], dtype=torch.int64),
|
||||
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
|
||||
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
|
||||
)
|
||||
|
||||
layers0 = [FirstVicunaLayer(layer) for layer in vicuna_model.model.layers]
|
||||
_, modules0 = compile_to_vmfb(placeholder_input0, layers0, is_first=True)
|
||||
shark_layers0 = [CompiledFirstVicunaLayer(m) for m in modules0]
|
||||
|
||||
layers1 = [SecondVicunaLayer(layer) for layer in vicuna_model.model.layers]
|
||||
_, modules1 = compile_to_vmfb(placeholder_input1, layers1, is_first=False)
|
||||
shark_layers1 = [CompiledSecondVicunaLayer(m) for m in modules1]
|
||||
|
||||
sharded_model = ShardedVicunaModel(
|
||||
vicuna_model, shark_layers0, shark_layers1
|
||||
)
|
||||
return sharded_model
|
||||
|
||||
|
||||
sharded_model = get_sharded_model()
|
||||
|
||||
|
||||
def user(message, history):
    print("msg=", message)
    print("history=", history)
    # Append the user's message to the conversation history
    return "", history + [[message, ""]]


def chat(curr_system_message, history):
    global sharded_model
    past_key_values = None
    messages = curr_system_message + "".join(
        [
            "".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
            for item in history
        ]
    )
    print(messages)
    prompt = messages.strip()
    input_ids = tokenizer(prompt).input_ids
    tokens = input_ids
    new_sentence = []
    max_response_len = 1000
    partial_sentence = []
    partial_text = ""
    start_time = time.time()
    for iteration in range(max_response_len):
        original_input_ids = input_ids
        input_id_len = len(input_ids)
        input_ids = torch.tensor(input_ids)
        input_ids = input_ids.reshape([1, input_id_len])

        if iteration == 0:
            output = sharded_model.forward(input_ids, is_first=True)
        else:
            output = sharded_model.forward(
                input_ids, past_key_values=past_key_values, is_first=False
            )
        logits = output["logits"]
        past_key_values = output["past_key_values"]
        new_token = int(torch.argmax(logits[:, -1, :], dim=1)[0])
        if new_token == 2:
            break
        new_sentence += [new_token]
        partial_sentence += [new_token]
        if iteration > 0 and iteration % 2 == 0:
            new_text = tokenizer.decode(partial_sentence)
            partial_sentence = []
            print(new_text, " ")
            partial_text += new_text + " "
            history[-1][1] = partial_text
            yield history

        tokens.append(new_token)
        original_input_ids.append(new_token)
        input_ids = [new_token]
    end_time = time.time()
    print(
        f"Total time taken to generate the response is {end_time-start_time} seconds"
    )

    for i in range(len(tokens)):
        if type(tokens[i]) != int:
            tokens[i] = int(tokens[i][0])
    new_sentence_str = tokenizer.decode(new_sentence)
    print(new_sentence_str)
    history[-1][1] = new_sentence_str
    return history

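
# NOTE (editorial sketch, not part of the original file): chat() is a
# generator that streams the growing conversation back to Gradio. Driven
# directly, it looks roughly like this:
#
#     history = [["What is the capital of France?", ""]]
#     for updated_history in chat(system_msg, history):
#         print(updated_history[-1][1])  # assistant reply decoded so far
#
# Each yield carries the partially decoded reply; the final return value is
# the completed history.
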
system_msg = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
# history_eg = [["hi hello how are you", ""]]
# print(chat(system_msg, history_eg))

with gr.Blocks(title="Chatbot") as vicuna_chat:
    with gr.Row():
        model = gr.Dropdown(
            label="Select Model",
            value="TheBloke/vicuna-7B-1.1-HF",
            choices=[
                "TheBloke/vicuna-7B-1.1-HF",
            ],
        )
        device_value = None
        for d in available_devices:
            if "vulkan" in d:
                device_value = d
                break

        device = gr.Dropdown(
            label="Device",
            value=device_value if device_value else available_devices[0],
            interactive=False,
            choices=available_devices,
        )
    chatbot = gr.Chatbot().style(height=500)
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Chat Message Box",
                placeholder="Chat Message Box",
                show_label=False,
            ).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Submit")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear")
    system_msg = gr.Textbox(
        system_msg, label="System Message", interactive=False, visible=False
    )

    submit_event = msg.submit(
        fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
    ).then(
        fn=chat,
        inputs=[system_msg, chatbot],
        outputs=[chatbot],
        queue=True,
    )
    submit_click_event = submit.click(
        fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
    ).then(
        fn=chat,
        inputs=[system_msg, chatbot],
        outputs=[chatbot],
        queue=True,
    )
    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[submit_event, submit_click_event],
        queue=False,
    )
    clear.click(lambda: None, None, [chatbot], queue=False)

import argparse

p = argparse.ArgumentParser(
    description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
p.add_argument(
    "--share",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="flag for generating a public URL",
)
p.add_argument(
    "--server_port",
    type=int,
    default=8080,
    help="flag for setting server port",
)
args, unknown = p.parse_known_args()

vicuna_chat.queue()
vicuna_chat.launch(
    share=args.share,
    inbrowser=True,
    server_name="0.0.0.0",
    server_port=args.server_port,
)
@@ -31,6 +31,7 @@ datas += collect_data_files('google-cloud-storage')
|
||||
datas += collect_data_files('shark')
|
||||
datas += collect_data_files('tkinter')
|
||||
datas += collect_data_files('webview')
|
||||
datas += collect_data_files('sentencepiece')
|
||||
datas += [
|
||||
( 'src/utils/resources/prompts.json', 'resources' ),
|
||||
( 'src/utils/resources/model_db.json', 'resources' ),
|
||||
|
||||
@@ -14,6 +14,7 @@ from apps.stable_diffusion.src.utils import (
|
||||
base_models,
|
||||
args,
|
||||
preprocessCKPT,
|
||||
convert_original_vae,
|
||||
get_path_to_diffusers_checkpoint,
|
||||
fetch_and_update_base_model_id,
|
||||
get_path_stem,
|
||||
@@ -571,8 +572,12 @@ class SharkifyStableDiffusionModel:
vae_checkpoint = safetensors.torch.load_file(self.custom_vae, device="cpu")
if "state_dict" in vae_checkpoint:
vae_checkpoint = vae_checkpoint["state_dict"]
vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
return vae_dict

try:
vae_checkpoint = convert_original_vae(vae_checkpoint)
finally:
vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
return vae_dict

def compile_unet_variants(self, model):
if model == "unet":

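For context, the custom-VAE path above reduces to: load the checkpoint, unwrap an optional state_dict, drop training-only keys, and (with this change) try routing LDM-format weights through convert_original_vae first. A minimal standalone sketch of that filtering, assuming a local .safetensors path and the vae_ignore_keys list from the surrounding class:

import safetensors.torch

def load_custom_vae_dict(path, vae_ignore_keys=()):
    # Load the raw checkpoint and unwrap a {"state_dict": ...} wrapper if present.
    ckpt = safetensors.torch.load_file(path, device="cpu")
    if "state_dict" in ckpt:
        ckpt = ckpt["state_dict"]
    # Keep only inference weights: drop "loss.*" tensors and ignored keys.
    return {
        k: v
        for k, v in ckpt.items()
        if not k.startswith("loss") and k not in vae_ignore_keys
    }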
@@ -24,6 +24,7 @@ from apps.stable_diffusion.src.utils.utils import (
|
||||
get_available_devices,
|
||||
get_opt_flags,
|
||||
preprocessCKPT,
|
||||
convert_original_vae,
|
||||
fetch_and_update_base_model_id,
|
||||
get_path_to_diffusers_checkpoint,
|
||||
sanitize_seed,
|
||||
|
||||
@@ -495,13 +495,12 @@ p.add_argument(
|
||||
)
|
||||
# TODO: replace API flag when these can be run together
|
||||
p.add_argument(
|
||||
"--web_mode",
|
||||
"--ui",
|
||||
type=str,
|
||||
default="app",
|
||||
help="any number of: [api, app, webui]. Currently api can't be run with others.",
|
||||
default="app" if os.name == "nt" else "web",
|
||||
help="one of: [api, app, web]",
|
||||
)
|
||||
|
||||
|
||||
p.add_argument(
|
||||
"--share",
|
||||
default=False,
|
||||
|
||||
@@ -25,7 +25,12 @@ from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
|
||||
import sys
|
||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
download_from_original_stable_diffusion_ckpt,
|
||||
create_vae_diffusers_config,
|
||||
convert_ldm_vae_checkpoint,
|
||||
)
|
||||
import requests
|
||||
from io import BytesIO
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
def get_extended_name(model_name):
|
||||
@@ -464,7 +469,7 @@ def get_path_stem(path):
|
||||
def get_path_to_diffusers_checkpoint(custom_weights):
|
||||
path = Path(custom_weights)
|
||||
diffusers_path = path.parent.absolute()
|
||||
diffusers_directory_name = path.stem
|
||||
diffusers_directory_name = os.path.join("diffusers", path.stem)
|
||||
complete_path_to_diffusers = diffusers_path / diffusers_directory_name
|
||||
complete_path_to_diffusers.mkdir(parents=True, exist_ok=True)
|
||||
path_to_diffusers = complete_path_to_diffusers.as_posix()
|
||||
@@ -503,6 +508,22 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
    print("Loading complete")


def convert_original_vae(vae_checkpoint):
    vae_state_dict = {}
    for key in list(vae_checkpoint.keys()):
        vae_state_dict["first_stage_model." + key] = vae_checkpoint.get(key)

    config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
    original_config_file = BytesIO(requests.get(config_url).content)
    original_config = OmegaConf.load(original_config_file)
    vae_config = create_vae_diffusers_config(original_config, image_size=512)

    converted_vae_checkpoint = convert_ldm_vae_checkpoint(
        vae_state_dict, vae_config
    )
    return converted_vae_checkpoint


def processLoRA(model, use_lora, splitting_prefix):
    state_dict = ""
    if ".safetensors" in use_lora:

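The new convert_original_vae helper above moves a standalone VAE from the original LDM key layout to the diffusers layout by prefixing keys with first_stage_model. and reusing the diffusers conversion utilities. A hedged usage sketch (the checkpoint path is hypothetical):

import torch

vae_ckpt = torch.load("my_vae.ckpt", map_location="cpu")
if "state_dict" in vae_ckpt:
    vae_ckpt = vae_ckpt["state_dict"]

# Yields a diffusers-layout state dict suitable for AutoencoderKL weights.
diffusers_vae_sd = convert_original_vae(vae_ckpt)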
@@ -16,17 +16,21 @@ def launch_app(address):
|
||||
from tkinter import Tk
|
||||
import webview
|
||||
|
||||
tk = Tk()
|
||||
# size of the window where we show our website
|
||||
tk.geometry("1280x720")
|
||||
webview.create_window("SHARK", address)
|
||||
window = Tk()
|
||||
|
||||
# getting screen width and height of display
|
||||
width = window.winfo_screenwidth()
|
||||
height = window.winfo_screenheight()
|
||||
webview.create_window(
|
||||
"SHARK AI Studio", url=address, width=width, height=height
|
||||
)
|
||||
webview.start(private_mode=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# required to do multiprocessing in a pyinstaller freeze
|
||||
freeze_support()
|
||||
if args.api or "api" in args.web_mode.split(","):
|
||||
if args.api or "api" in args.ui.split(","):
|
||||
from apps.stable_diffusion.web.ui import (
|
||||
txt2img_api,
|
||||
img2img_api,
|
||||
@@ -55,14 +59,12 @@ if __name__ == "__main__":
|
||||
from apps.stable_diffusion.web.utils.gradio_configs import (
|
||||
clear_gradio_tmp_imgs_folder,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import get_custom_model_path
|
||||
from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
|
||||
|
||||
# Clear all gradio tmp images from the last session
|
||||
clear_gradio_tmp_imgs_folder()
|
||||
# Create the custom model folder if it doesn't already exist
|
||||
dir = ["models", "vae", "lora"]
|
||||
for root in dir:
|
||||
get_custom_model_path(root).mkdir(parents=True, exist_ok=True)
|
||||
# Create custom models folders if they don't exist
|
||||
create_custom_models_folders()
|
||||
|
||||
def resource_path(relative_path):
|
||||
"""Get absolute path to resource, works for dev and for PyInstaller"""
|
||||
@@ -75,30 +77,40 @@ if __name__ == "__main__":
|
||||
|
||||
from apps.stable_diffusion.web.ui import (
|
||||
txt2img_web,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
txt2img_gallery,
|
||||
txt2img_sendto_img2img,
|
||||
txt2img_sendto_inpaint,
|
||||
txt2img_sendto_outpaint,
|
||||
txt2img_sendto_upscaler,
|
||||
img2img_web,
|
||||
img2img_custom_model,
|
||||
img2img_hf_model_id,
|
||||
img2img_gallery,
|
||||
img2img_init_image,
|
||||
img2img_sendto_inpaint,
|
||||
img2img_sendto_outpaint,
|
||||
img2img_sendto_upscaler,
|
||||
inpaint_web,
|
||||
inpaint_custom_model,
|
||||
inpaint_hf_model_id,
|
||||
inpaint_gallery,
|
||||
inpaint_init_image,
|
||||
inpaint_sendto_img2img,
|
||||
inpaint_sendto_outpaint,
|
||||
inpaint_sendto_upscaler,
|
||||
outpaint_web,
|
||||
outpaint_custom_model,
|
||||
outpaint_hf_model_id,
|
||||
outpaint_gallery,
|
||||
outpaint_init_image,
|
||||
outpaint_sendto_img2img,
|
||||
outpaint_sendto_inpaint,
|
||||
outpaint_sendto_upscaler,
|
||||
upscaler_web,
|
||||
upscaler_custom_model,
|
||||
upscaler_hf_model_id,
|
||||
upscaler_gallery,
|
||||
upscaler_init_image,
|
||||
upscaler_sendto_img2img,
|
||||
@@ -106,6 +118,13 @@ if __name__ == "__main__":
|
||||
upscaler_sendto_outpaint,
|
||||
lora_train_web,
|
||||
model_web,
|
||||
hf_models,
|
||||
modelmanager_sendto_txt2img,
|
||||
modelmanager_sendto_img2img,
|
||||
modelmanager_sendto_inpaint,
|
||||
modelmanager_sendto_outpaint,
|
||||
modelmanager_sendto_upscaler,
|
||||
stablelm_chat,
|
||||
)
|
||||
|
||||
# init global sd pipeline and config
|
||||
@@ -121,6 +140,17 @@ if __name__ == "__main__":
|
||||
outputs,
|
||||
)
|
||||
|
||||
def register_modelmanager_button(button, selectedid, inputs, outputs):
|
||||
button.click(
|
||||
lambda x: (
|
||||
"None",
|
||||
x,
|
||||
gr.Tabs.update(selected=selectedid),
|
||||
),
|
||||
inputs,
|
||||
outputs,
|
||||
)
|
||||
|
||||
with gr.Blocks(
|
||||
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
|
||||
) as sd_web:
|
||||
@@ -137,9 +167,9 @@ if __name__ == "__main__":
|
||||
upscaler_web.render()
|
||||
with gr.TabItem(label="Model Manager", id=5):
|
||||
model_web.render()
|
||||
|
||||
with gr.Tabs(visible=False) as experimental_tabs:
|
||||
with gr.TabItem(label="LoRA Training", id=5):
|
||||
with gr.TabItem(label="Chat Bot(Experimental)", id=6):
|
||||
stablelm_chat.render()
|
||||
with gr.TabItem(label="LoRA Training(Experimental)", id=7):
|
||||
lora_train_web.render()
|
||||
|
||||
register_button_click(
|
||||
@@ -238,15 +268,46 @@ if __name__ == "__main__":
|
||||
[upscaler_gallery],
|
||||
[outpaint_init_image, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_txt2img,
|
||||
0,
|
||||
[hf_models],
|
||||
[txt2img_custom_model, txt2img_hf_model_id, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_img2img,
|
||||
1,
|
||||
[hf_models],
|
||||
[img2img_custom_model, img2img_hf_model_id, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_inpaint,
|
||||
2,
|
||||
[hf_models],
|
||||
[inpaint_custom_model, inpaint_hf_model_id, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_outpaint,
|
||||
3,
|
||||
[hf_models],
|
||||
[outpaint_custom_model, outpaint_hf_model_id, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_upscaler,
|
||||
4,
|
||||
[hf_models],
|
||||
[upscaler_custom_model, upscaler_hf_model_id, tabs],
|
||||
)
|
||||
|
||||
sd_web.queue()
|
||||
if "app" in args.web_mode.split(","):
|
||||
if args.ui == "app":
|
||||
t = Process(
|
||||
target=launch_app, args=[f"http://localhost:{args.server_port}"]
|
||||
)
|
||||
t.start()
|
||||
sd_web.launch(
|
||||
share=args.share,
|
||||
inbrowser="webui" in args.web_mode.split(","),
|
||||
inbrowser=args.ui == "web",
|
||||
server_name="0.0.0.0",
|
||||
server_port=args.server_port,
|
||||
)
|
||||
|
||||
@@ -2,6 +2,8 @@ from apps.stable_diffusion.web.ui.txt2img_ui import (
|
||||
txt2img_inf,
|
||||
txt2img_api,
|
||||
txt2img_web,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
txt2img_gallery,
|
||||
txt2img_sendto_img2img,
|
||||
txt2img_sendto_inpaint,
|
||||
@@ -12,6 +14,8 @@ from apps.stable_diffusion.web.ui.img2img_ui import (
|
||||
img2img_inf,
|
||||
img2img_api,
|
||||
img2img_web,
|
||||
img2img_custom_model,
|
||||
img2img_hf_model_id,
|
||||
img2img_gallery,
|
||||
img2img_init_image,
|
||||
img2img_sendto_inpaint,
|
||||
@@ -22,6 +26,8 @@ from apps.stable_diffusion.web.ui.inpaint_ui import (
|
||||
inpaint_inf,
|
||||
inpaint_api,
|
||||
inpaint_web,
|
||||
inpaint_custom_model,
|
||||
inpaint_hf_model_id,
|
||||
inpaint_gallery,
|
||||
inpaint_init_image,
|
||||
inpaint_sendto_img2img,
|
||||
@@ -32,6 +38,8 @@ from apps.stable_diffusion.web.ui.outpaint_ui import (
|
||||
outpaint_inf,
|
||||
outpaint_api,
|
||||
outpaint_web,
|
||||
outpaint_custom_model,
|
||||
outpaint_hf_model_id,
|
||||
outpaint_gallery,
|
||||
outpaint_init_image,
|
||||
outpaint_sendto_img2img,
|
||||
@@ -42,11 +50,22 @@ from apps.stable_diffusion.web.ui.upscaler_ui import (
|
||||
upscaler_inf,
|
||||
upscaler_api,
|
||||
upscaler_web,
|
||||
upscaler_custom_model,
|
||||
upscaler_hf_model_id,
|
||||
upscaler_gallery,
|
||||
upscaler_init_image,
|
||||
upscaler_sendto_img2img,
|
||||
upscaler_sendto_inpaint,
|
||||
upscaler_sendto_outpaint,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.model_manager import (
|
||||
model_web,
|
||||
hf_models,
|
||||
modelmanager_sendto_txt2img,
|
||||
modelmanager_sendto_img2img,
|
||||
modelmanager_sendto_inpaint,
|
||||
modelmanager_sendto_outpaint,
|
||||
modelmanager_sendto_upscaler,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.lora_train_ui import lora_train_web
|
||||
from apps.stable_diffusion.web.ui.model_manager import model_web
|
||||
from apps.stable_diffusion.web.ui.stablelm_ui import stablelm_chat
|
||||
|
||||
@@ -4,6 +4,7 @@ import torch
|
||||
import time
|
||||
import sys
|
||||
import gradio as gr
|
||||
import PIL
|
||||
from PIL import Image
|
||||
import base64
|
||||
from io import BytesIO
|
||||
@@ -89,6 +90,8 @@ def img2img_inf(
|
||||
return None, "An Initial Image is required"
|
||||
if use_stencil == "scribble":
|
||||
image = image_dict["mask"].convert("RGB")
|
||||
elif isinstance(image_dict, PIL.Image.Image):
|
||||
image = image_dict.convert("RGB")
|
||||
else:
|
||||
image = image_dict["image"].convert("RGB")
|
||||
|
||||
@@ -299,7 +302,7 @@ def img2img_api(
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["image"])
init_image = decode_base64_to_image(InputData["init_images"][0])
res = img2img_inf(
InputData["prompt"],
InputData["negative_prompt"],
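The change above reads the init image from InputData["init_images"][0] instead of InputData["image"], which matches the request shape used by common Stable Diffusion web API clients. A hedged sketch of a matching request (the endpoint path and the field values are assumptions, not taken from this diff):

import base64
import requests

with open("init.png", "rb") as f:
    b64_image = base64.b64encode(f.read()).decode()

payload = {
    "prompt": "a lighthouse at dusk",
    "negative_prompt": "blurry",
    "seed": 42,
    "init_images": [b64_image],  # list of base64 images; only [0] is used here
}
requests.post("http://127.0.0.1:8080/img2img", json=payload)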
@@ -352,17 +355,17 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
custom_model = gr.Dropdown(
|
||||
img2img_custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {get_custom_model_path()})",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "None",
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_models,
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
img2img_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
|
||||
value="",
|
||||
@@ -624,8 +627,8 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
img2img_custom_model,
|
||||
img2img_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
|
||||
@@ -306,17 +306,19 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
custom_model = gr.Dropdown(
|
||||
inpaint_custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {get_custom_model_path()})",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "None",
|
||||
else "stabilityai/stable-diffusion-2-inpainting",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ get_custom_model_files(
|
||||
custom_checkpoint_type="inpainting"
|
||||
)
|
||||
+ predefined_paint_models,
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
inpaint_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting, https://civitai.com/api/download/models/3433",
|
||||
value="",
|
||||
@@ -527,8 +529,8 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
inpaint_custom_model,
|
||||
inpaint_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
|
||||
@@ -3,16 +3,15 @@ import gradio as gr
|
||||
import requests
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from shark.iree_utils._common import run_cmd
|
||||
|
||||
|
||||
def get_hf_list(limit=20):
|
||||
def get_hf_list(num_of_models=20):
|
||||
path = "https://huggingface.co/api/models"
|
||||
params = {
|
||||
"search": "stable-diffusion",
|
||||
"sort": "downloads",
|
||||
"direction": "-1",
|
||||
"limit": {limit},
|
||||
"limit": {num_of_models},
|
||||
"full": "true",
|
||||
}
|
||||
response = requests.get(path, params=params)
|
||||
@@ -72,65 +71,87 @@ def get_image_from_model(model_json):
|
||||
return image
|
||||
|
||||
|
||||
hf_model_list = get_hf_list()
|
||||
civit_model_list = get_civit_list()
|
||||
|
||||
|
||||
with gr.Blocks() as model_web:
|
||||
model_source = gr.Radio(
|
||||
choices=["Hugging Face", "Civitai"],
|
||||
type="index",
|
||||
value="Hugging Face",
|
||||
label="Model Source",
|
||||
with gr.Row():
|
||||
model_source = gr.Radio(
|
||||
value=None,
|
||||
choices=["Hugging Face", "Civitai"],
|
||||
type="value",
|
||||
label="Model Source",
|
||||
)
|
||||
model_numebr = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=10,
|
||||
step=1,
|
||||
label="Number of models",
|
||||
interactive=True,
|
||||
)
|
||||
# TODO: add more filters
|
||||
get_model_btn = gr.Button(value="Get Models")
|
||||
|
||||
hf_models = gr.Dropdown(
|
||||
label="Hugging Face Model List",
|
||||
choices=None,
|
||||
value=None,
|
||||
visible=False,
|
||||
)
|
||||
with gr.Column(visible=True) as hf_block:
|
||||
for model in hf_model_list:
|
||||
with gr.Row():
|
||||
model_url = gr.Textbox(
|
||||
label="Model ID:",
|
||||
value=model["modelId"],
|
||||
lines=1,
|
||||
interactive=False,
|
||||
)
|
||||
model_info = gr.Textbox(
|
||||
value=f'Download Count: {model["downloads"]}{os.linesep}Favorite Count: {model["likes"]}',
|
||||
lines=2,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
)
|
||||
with gr.Column(visible=False) as civit_block:
|
||||
for model in civit_model_list:
|
||||
with gr.Row():
|
||||
# TODO: select and SendTo
|
||||
civit_models = gr.Gallery(
|
||||
label="Civitai Model Gallery",
|
||||
value=None,
|
||||
interactive=True,
|
||||
visible=False,
|
||||
)
|
||||
|
||||
with gr.Row(visible=False) as sendto_btns:
|
||||
modelmanager_sendto_txt2img = gr.Button(value="SendTo Txt2Img")
|
||||
modelmanager_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
modelmanager_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
modelmanager_sendto_outpaint = gr.Button(value="SendTo Outpaint")
|
||||
modelmanager_sendto_upscaler = gr.Button(value="SendTo Upscaler")
|
||||
|
||||
def get_model_list(model_source, model_numebr):
|
||||
if model_source == "Hugging Face":
|
||||
hf_model_list = get_hf_list(model_numebr)
|
||||
models = []
|
||||
for model in hf_model_list:
|
||||
# TODO: add model info
|
||||
models.append(f'{model["modelId"]}')
|
||||
return (
|
||||
gr.Dropdown.update(choices=models, visible=True),
|
||||
gr.Gallery.update(value=None, visible=False),
|
||||
gr.Row.update(visible=True),
|
||||
)
|
||||
elif model_source == "Civitai":
|
||||
civit_model_list = get_civit_list(model_numebr)
|
||||
models = []
|
||||
for model in civit_model_list:
|
||||
image = get_image_from_model(model)
|
||||
if image is None:
|
||||
continue
|
||||
model_img = Image.open(image)
|
||||
gr.Image(
|
||||
value=model_img,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
).style(width=300, height=300)
|
||||
with gr.Column():
|
||||
gr.Textbox(
|
||||
label=f'{model["modelName"]}',
|
||||
value=f'Rating: {model["rating"]}{os.linesep}Favorite Count: {model["favoriteCount"]}{os.linesep}Download Count: {model["downloadCount"]}{os.linesep}File Format: {model["files"][0]["metadata"]["format"]}',
|
||||
lines=4,
|
||||
)
|
||||
gr.Textbox(
|
||||
label="Download URL:",
|
||||
value=f'{model["files"][0]["downloadUrl"]}',
|
||||
lines=1,
|
||||
)
|
||||
|
||||
def update_model_list(model_source):
|
||||
if model_source:
|
||||
return gr.update(visible=False), gr.update(visible=True)
|
||||
# TODO: add model info
|
||||
models.append(
|
||||
(Image.open(image), f'{model["files"][0]["downloadUrl"]}')
|
||||
)
|
||||
return (
|
||||
gr.Dropdown.update(value=None, choices=None, visible=False),
|
||||
gr.Gallery.update(value=models, visible=True),
|
||||
gr.Row.update(visible=False),
|
||||
)
|
||||
else:
|
||||
return gr.update(visible=True), gr.update(visible=False)
|
||||
return (
|
||||
gr.Dropdown.update(value=None, choices=None, visible=False),
|
||||
gr.Gallery.update(value=None, visible=False),
|
||||
gr.Row.update(visible=False),
|
||||
)
|
||||
|
||||
model_source.change(
|
||||
fn=update_model_list,
|
||||
inputs=model_source,
|
||||
outputs=[hf_block, civit_block],
|
||||
get_model_btn.click(
|
||||
fn=get_model_list,
|
||||
inputs=[model_source, model_numebr],
|
||||
outputs=[
|
||||
hf_models,
|
||||
civit_models,
|
||||
sendto_btns,
|
||||
],
|
||||
)
|
||||
|
||||
@@ -317,17 +317,19 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
custom_model = gr.Dropdown(
|
||||
outpaint_custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {get_custom_model_path()})",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "None",
|
||||
else "stabilityai/stable-diffusion-2-inpainting",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ get_custom_model_files(
|
||||
custom_checkpoint_type="inpainting"
|
||||
)
|
||||
+ predefined_paint_models,
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
outpaint_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting, https://civitai.com/api/download/models/3433",
|
||||
value="",
|
||||
@@ -558,8 +560,8 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
outpaint_custom_model,
|
||||
outpaint_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
|
||||
217
apps/stable_diffusion/web/ui/stablelm_ui.py
Normal file
217
apps/stable_diffusion/web/ui/stablelm_ui.py
Normal file
@@ -0,0 +1,217 @@
|
||||
import gradio as gr
|
||||
import torch
|
||||
import os
|
||||
from apps.language_models.scripts.stablelm import (
|
||||
compile_stableLM,
|
||||
StopOnTokens,
|
||||
generate,
|
||||
get_tokenizer,
|
||||
StableLMModel,
|
||||
)
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
TextIteratorStreamer,
|
||||
StoppingCriteriaList,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import available_devices
|
||||
|
||||
start_message = """<|SYSTEM|># StableLM Tuned (Alpha version)
|
||||
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
|
||||
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
|
||||
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
|
||||
- StableLM will refuse to participate in anything that could harm a human.
|
||||
"""
|
||||
|
||||
|
||||
def user(message, history):
|
||||
# Append the user's message to the conversation history
|
||||
return "", history + [[message, ""]]
|
||||
|
||||
|
||||
input_ids = torch.randint(3, (1, 256))
|
||||
attention_mask = torch.randint(3, (1, 256))
|
||||
|
||||
|
||||
sharkModel = 0
|
||||
sharded_model = 0
|
||||
|
||||
|
||||
start_message_vicuna = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
|
||||
past_key_values = None
|
||||
|
||||
|
||||
def chat(curr_system_message, history, model):
|
||||
global sharded_model
|
||||
global past_key_values
|
||||
if "vicuna" in model:
|
||||
from apps.language_models.scripts.sharded_vicuna_fp32 import (
|
||||
tokenizer,
|
||||
get_sharded_model,
|
||||
)
|
||||
|
||||
SAMPLE_INPUT_LEN = 137
|
||||
curr_system_message = start_message_vicuna
|
||||
if sharded_model == 0:
|
||||
sharded_model = get_sharded_model()
|
||||
messages = curr_system_message + "".join(
|
||||
[
|
||||
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
|
||||
for item in history
|
||||
]
|
||||
)
|
||||
prompt = messages.strip()
|
||||
print("prompt = ", prompt)
|
||||
input_ids = tokenizer(prompt).input_ids
|
||||
new_sentence = ""
|
||||
for _ in range(200):
|
||||
original_input_ids = input_ids
|
||||
input_id_len = len(input_ids)
|
||||
pad_len = SAMPLE_INPUT_LEN - input_id_len
|
||||
attention_mask = torch.ones([1, input_id_len], dtype=torch.int64)
|
||||
input_ids = torch.tensor(input_ids)
|
||||
input_ids = input_ids.reshape([1, input_id_len])
|
||||
attention_mask = torch.nn.functional.pad(
|
||||
torch.tensor(attention_mask),
|
||||
(0, pad_len),
|
||||
mode="constant",
|
||||
value=0,
|
||||
)
|
||||
|
||||
if _ == 0:
|
||||
output = sharded_model.forward(input_ids, is_first=True)
|
||||
else:
|
||||
output = sharded_model.forward(
|
||||
input_ids, past_key_values=past_key_values, is_first=False
|
||||
)
|
||||
logits = output["logits"]
|
||||
past_key_values = output["past_key_values"]
|
||||
new_word = tokenizer.decode(torch.argmax(logits[:, -1, :], dim=1))
|
||||
if new_word == "</s>":
|
||||
break
|
||||
new_sentence += " " + new_word
|
||||
history[-1][1] = new_sentence
|
||||
yield history
|
||||
next_token = torch.argmax(logits[:, input_id_len - 1, :], dim=1)
|
||||
original_input_ids.append(next_token)
|
||||
input_ids = [next_token]
|
||||
print(new_sentence)
|
||||
return history
|
||||
|
||||
global sharkModel
|
||||
print("In chat")
|
||||
if sharkModel == 0:
|
||||
tok = get_tokenizer()
|
||||
# sharkModel = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/disk/phaneesh/stablelm_3b_f32_cuda_2048_newflags.vmfb")
|
||||
m = AutoModelForCausalLM.from_pretrained(
|
||||
"stabilityai/stablelm-tuned-alpha-3b", torch_dtype=torch.float32
|
||||
)
|
||||
stableLMModel = StableLMModel(m)
|
||||
sharkModel = compile_stableLM(
|
||||
stableLMModel,
|
||||
tuple([input_ids, attention_mask]),
|
||||
"stableLM_linalg_f32_seqLen256",
|
||||
os.getcwd(),
|
||||
)
|
||||
# Initialize a StopOnTokens object
|
||||
stop = StopOnTokens()
|
||||
# Construct the input message string for the model by concatenating the current system message and conversation history
|
||||
if len(curr_system_message.split()) > 160:
|
||||
print("clearing context")
|
||||
curr_system_message = start_message
|
||||
messages = curr_system_message + "".join(
|
||||
[
|
||||
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
|
||||
for item in history
|
||||
]
|
||||
)
|
||||
# print(messages)
|
||||
# Tokenize the messages string
|
||||
streamer = TextIteratorStreamer(
|
||||
tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
|
||||
)
|
||||
generate_kwargs = dict(
|
||||
new_text=messages,
|
||||
streamer=streamer,
|
||||
max_new_tokens=512,
|
||||
do_sample=True,
|
||||
top_p=0.95,
|
||||
top_k=1000,
|
||||
temperature=1.0,
|
||||
num_beams=1,
|
||||
stopping_criteria=StoppingCriteriaList([stop]),
|
||||
sharkStableLM=sharkModel,
|
||||
)
|
||||
words_list = generate(**generate_kwargs)
|
||||
partial_text = ""
|
||||
for new_text in words_list:
|
||||
# print(new_text)
|
||||
partial_text += new_text
|
||||
history[-1][1] = partial_text
|
||||
# Yield an empty string to cleanup the message textbox and the updated conversation history
|
||||
yield history
|
||||
return words_list
|
||||
|
||||
|
||||
with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
with gr.Row():
|
||||
model = gr.Dropdown(
|
||||
label="Select Model",
|
||||
value="TheBloke/vicuna-7B-1.1-HF",
|
||||
choices=[
|
||||
"stabilityai/stablelm-tuned-alpha-3b",
|
||||
"TheBloke/vicuna-7B-1.1-HF",
|
||||
],
|
||||
)
|
||||
device_value = None
|
||||
for d in available_devices:
|
||||
if "vulkan" in d:
|
||||
device_value = d
|
||||
break
|
||||
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=device_value if device_value else available_devices[0],
|
||||
interactive=False,
|
||||
choices=available_devices,
|
||||
)
|
||||
chatbot = gr.Chatbot().style(height=500)
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
msg = gr.Textbox(
|
||||
label="Chat Message Box",
|
||||
placeholder="Chat Message Box",
|
||||
show_label=False,
|
||||
).style(container=False)
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
submit = gr.Button("Submit")
|
||||
stop = gr.Button("Stop")
|
||||
clear = gr.Button("Clear")
|
||||
system_msg = gr.Textbox(
|
||||
start_message, label="System Message", interactive=False, visible=False
|
||||
)
|
||||
|
||||
submit_event = msg.submit(
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, model],
|
||||
outputs=[chatbot],
|
||||
queue=True,
|
||||
)
|
||||
submit_click_event = submit.click(
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, model],
|
||||
outputs=[chatbot],
|
||||
queue=True,
|
||||
)
|
||||
stop.click(
|
||||
fn=None,
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
cancels=[submit_event, submit_click_event],
|
||||
queue=False,
|
||||
)
|
||||
clear.click(lambda: None, None, [chatbot], queue=False)
|
||||
@@ -281,17 +281,17 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=10):
|
||||
with gr.Row():
|
||||
custom_model = gr.Dropdown(
|
||||
txt2img_custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {get_custom_model_path()})",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "None",
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_models,
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
txt2img_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
|
||||
value="",
|
||||
@@ -502,8 +502,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
@@ -538,8 +538,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
seed,
|
||||
width,
|
||||
height,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
],
|
||||
outputs=[
|
||||
png_info_img,
|
||||
@@ -551,7 +551,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
seed,
|
||||
width,
|
||||
height,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
],
|
||||
)
|
||||
|
||||
@@ -309,17 +309,19 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
custom_model = gr.Dropdown(
|
||||
upscaler_custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {get_custom_model_path()})",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "None",
|
||||
else "stabilityai/stable-diffusion-x4-upscaler",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ get_custom_model_files(
|
||||
custom_checkpoint_type="upscaler"
|
||||
)
|
||||
+ predefined_upscaler_models,
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
upscaler_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
|
||||
value="",
|
||||
@@ -525,8 +527,8 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
upscaler_custom_model,
|
||||
upscaler_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
|
||||
@@ -72,30 +72,36 @@ def resource_path(relative_path):
return os.path.join(base_path, relative_path)


def create_custom_models_folders():
dir = ["vae", "lora"]
if not args.ckpt_dir:
dir.insert(0, "models")
else:
if not os.path.isdir(args.ckpt_dir):
sys.exit(
f"Invalid --ckpt_dir argument, {args.ckpt_dir} folder does not exist."
)
for root in dir:
get_custom_model_path(root).mkdir(parents=True, exist_ok=True)


def get_custom_model_path(model="models"):
# If `--ckpt_dir` is provided it'd override the hierarchical folder
# structure in WebUI :-
# model
# models or args.ckpt_dir
# |___lora
# |___vae
sub_folder = "" if model == "models" else model
if args.ckpt_dir:
return Path(args.ckpt_dir)
match model:
case "models":
return Path(Path.cwd(), "models")
case "vae":
return Path(Path.cwd(), "models/vae")
case "lora":
return Path(Path.cwd(), "models/lora")
case _:
return ""
return Path(Path(args.ckpt_dir), sub_folder)
else:
return Path(Path.cwd(), "models/" + sub_folder)

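
# NOTE (editorial sketch, not part of the original file): with this change,
# path resolution collapses to a single rule -- use args.ckpt_dir as the root
# when it is set, otherwise <cwd>/models -- with "vae" and "lora" as
# sub-folders:
#
#     get_custom_model_path("models") -> <ckpt_dir or cwd/models>
#     get_custom_model_path("vae")    -> <ckpt_dir or cwd/models>/vae
#     get_custom_model_path("lora")   -> <ckpt_dir or cwd/models>/lora
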
def get_custom_model_pathfile(custom_model_name, model="models"):
|
||||
return os.path.join(get_custom_model_path(model), custom_model_name)
|
||||
|
||||
|
||||
def get_custom_model_files(model="models"):
|
||||
def get_custom_model_files(model="models", custom_checkpoint_type=""):
|
||||
ckpt_files = []
|
||||
file_types = custom_model_filetypes
|
||||
if model == "lora":
|
||||
@@ -107,6 +113,28 @@ def get_custom_model_files(model="models"):
|
||||
os.path.join(get_custom_model_path(model), extn)
|
||||
)
|
||||
]
|
||||
match custom_checkpoint_type:
|
||||
case "inpainting":
|
||||
files = [
|
||||
val
|
||||
for val in files
|
||||
if val.endswith("inpainting" + extn.removeprefix("*"))
|
||||
]
|
||||
case "upscaler":
|
||||
files = [
|
||||
val
|
||||
for val in files
|
||||
if val.endswith("upscaler" + extn.removeprefix("*"))
|
||||
]
|
||||
case _:
|
||||
files = [
|
||||
val
|
||||
for val in files
|
||||
if not (
|
||||
val.endswith("inpainting" + extn.removeprefix("*"))
|
||||
or val.endswith("upscaler" + extn.removeprefix("*"))
|
||||
)
|
||||
]
|
||||
ckpt_files.extend(files)
|
||||
return sorted(ckpt_files, key=str.casefold)
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ transformers
|
||||
diffusers @ git+https://github.com/huggingface/diffusers@e47459c80f6f6a5a1c19d32c3fd74edf94f47aa2
|
||||
scipy
|
||||
ftfy
|
||||
gradio
|
||||
gradio==3.22.0
|
||||
altair
|
||||
omegaconf
|
||||
safetensors
|
||||
@@ -28,6 +28,7 @@ scikit-image
|
||||
pytorch_lightning # for runwayml models
|
||||
tk
|
||||
pywebview
|
||||
sentencepiece
|
||||
|
||||
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
|
||||
pefile
|
||||
|
||||
@@ -2,9 +2,10 @@
|
||||
# Sets up a venv suitable for running samples.
|
||||
# e.g:
|
||||
# ./setup_venv.sh #setup a default $PYTHON3 shark.venv
|
||||
# Environment Variables by the script.
|
||||
# Environment variables used by the script.
|
||||
# PYTHON=$PYTHON3.10 ./setup_venv.sh #pass a version of $PYTHON to use
|
||||
# VENV_DIR=myshark.venv #create a venv called myshark.venv
|
||||
# SKIP_VENV=1 #Don't create and activate a Python venv. Use the current environment.
|
||||
# USE_IREE=1 #use stock IREE instead of Nod.ai's SHARK build
|
||||
# IMPORTER=1 #Install importer deps
|
||||
# BENCHMARK=1 #Install benchmark deps
|
||||
@@ -26,15 +27,17 @@ PYTHON_VERSION_X_Y=`${PYTHON} -c 'import sys; version=sys.version_info[:2]; prin
|
||||
echo "Python: $PYTHON"
|
||||
echo "Python version: $PYTHON_VERSION_X_Y"
|
||||
|
||||
if [[ -z "${CONDA_PREFIX}" ]]; then
|
||||
# Not a conda env. So create a new VENV dir
|
||||
VENV_DIR=${VENV_DIR:-shark.venv}
|
||||
echo "Using pip venv.. Setting up venv dir: $VENV_DIR"
|
||||
$PYTHON -m venv "$VENV_DIR" || die "Could not create venv."
|
||||
source "$VENV_DIR/bin/activate" || die "Could not activate venv"
|
||||
PYTHON="$(which python3)"
|
||||
else
|
||||
echo "Found conda env $CONDA_DEFAULT_ENV. Running pip install inside the conda env"
|
||||
if [[ "$SKIP_VENV" != "1" ]]; then
|
||||
if [[ -z "${CONDA_PREFIX}" ]]; then
|
||||
# Not a conda env. So create a new VENV dir
|
||||
VENV_DIR=${VENV_DIR:-shark.venv}
|
||||
echo "Using pip venv.. Setting up venv dir: $VENV_DIR"
|
||||
$PYTHON -m venv "$VENV_DIR" || die "Could not create venv."
|
||||
source "$VENV_DIR/bin/activate" || die "Could not activate venv"
|
||||
PYTHON="$(which python3)"
|
||||
else
|
||||
echo "Found conda env $CONDA_DEFAULT_ENV. Running pip install inside the conda env"
|
||||
fi
|
||||
fi
|
||||
|
||||
Red=`tput setaf 1`
|
||||
@@ -147,8 +150,7 @@ if [[ ! -z "${ONNX}" ]]; then
|
||||
fi
|
||||
fi
|
||||
|
||||
if [[ -z "${CONDA_PREFIX}" ]]; then
|
||||
if [[ -z "${CONDA_PREFIX}" && "$SKIP_VENV" != "1" ]]; then
|
||||
echo "${Green}Before running examples activate venv with:"
|
||||
echo " ${Green}source $VENV_DIR/bin/activate"
|
||||
fi
|
||||
|
||||
|
||||
73
shark/examples/shark_inference/minilm_jax.py
Normal file
73
shark/examples/shark_inference/minilm_jax.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from transformers import AutoTokenizer, FlaxAutoModel
|
||||
import torch
|
||||
import jax
|
||||
from typing import Union, Dict, List, Any
|
||||
import numpy as np
|
||||
from shark.shark_inference import SharkInference
|
||||
import io
|
||||
|
||||
NumpyTree = Union[np.ndarray, Dict[str, np.ndarray], List[np.ndarray]]
|
||||
|
||||
|
||||
def convert_torch_tensor_tree_to_numpy(
|
||||
tree: Union[torch.tensor, Dict[str, torch.tensor], List[torch.tensor]]
|
||||
) -> NumpyTree:
|
||||
return jax.tree_util.tree_map(
|
||||
lambda torch_tensor: torch_tensor.cpu().detach().numpy(), tree
|
||||
)
|
||||
|
||||
|
||||
def convert_int64_to_int32(tree: NumpyTree) -> NumpyTree:
|
||||
return jax.tree_util.tree_map(
|
||||
lambda tensor: np.array(tensor, dtype=np.int32)
|
||||
if tensor.dtype == np.int64
|
||||
else tensor,
|
||||
tree,
|
||||
)
|
||||
|
||||
|
||||
def get_sample_input():
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"microsoft/MiniLM-L12-H384-uncased"
|
||||
)
|
||||
inputs_torch = tokenizer("Hello, World!", return_tensors="pt")
|
||||
return convert_int64_to_int32(
|
||||
convert_torch_tensor_tree_to_numpy(inputs_torch.data)
|
||||
)
|
||||
|
||||
|
||||
def get_jax_model():
|
||||
return FlaxAutoModel.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
|
||||
|
||||
|
||||
def export_jax_to_mlir(jax_model: Any, sample_input: NumpyTree):
|
||||
model_mlir = jax.jit(jax_model).lower(**sample_input).compiler_ir()
|
||||
byte_stream = io.BytesIO()
|
||||
model_mlir.operation.write_bytecode(file=byte_stream)
|
||||
return byte_stream.getvalue()
|
||||
|
||||
|
||||
def assert_array_list_allclose(x, y, *args, **kwargs):
|
||||
assert len(x) == len(y)
|
||||
for a, b in zip(x, y):
|
||||
np.testing.assert_allclose(
|
||||
np.asarray(a), np.asarray(b), *args, **kwargs
|
||||
)
|
||||
|
||||
|
||||
sample_input = get_sample_input()
|
||||
jax_model = get_jax_model()
|
||||
mlir = export_jax_to_mlir(jax_model, sample_input)
|
||||
|
||||
# Compile and load module.
|
||||
shark_inference = SharkInference(mlir_module=mlir, mlir_dialect="mhlo")
|
||||
shark_inference.compile()
|
||||
|
||||
# Run main function.
|
||||
result = shark_inference("main", jax.tree_util.tree_flatten(sample_input)[0])
|
||||
|
||||
# Run JAX model.
|
||||
reference_result = jax.tree_util.tree_flatten(jax_model(**sample_input))[0]
|
||||
|
||||
# Verify result.
|
||||
assert_array_list_allclose(result, reference_result, atol=1e-5)
|
||||
@@ -0,0 +1,6 @@
|
||||
flax
|
||||
jax[cpu]
|
||||
nodai-SHARK
|
||||
orbax
|
||||
transformers
|
||||
torch
|
||||
@@ -45,10 +45,15 @@ def run_cmd(cmd, debug=False):

def iree_device_map(device):
uri_parts = device.split("://", 2)
iree_driver = (
_IREE_DEVICE_MAP[uri_parts[0]]
if uri_parts[0] in _IREE_DEVICE_MAP
else uri_parts[0]
)
if len(uri_parts) == 1:
return _IREE_DEVICE_MAP[uri_parts[0]]
return iree_driver
else:
return f"{_IREE_DEVICE_MAP[uri_parts[0]]}://{uri_parts[1]}"
return f"{iree_driver}://{uri_parts[1]}"


def get_supported_device_list():
@@ -68,7 +73,7 @@ _IREE_DEVICE_MAP = {
def iree_target_map(device):
if "://" in device:
device = device.split("://")[0]
return _IREE_TARGET_MAP[device]
return _IREE_TARGET_MAP[device] if device in _IREE_TARGET_MAP else device


_IREE_TARGET_MAP = {
@@ -110,10 +115,8 @@ def check_device_drivers(device):
subprocess.check_output("rocminfo")
except Exception:
return True
# Unknown device.
else:
return True

# Unknown device. We assume drivers are installed.
return False


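Net effect of the iree_device_map change above: unknown driver names now pass through instead of raising a KeyError, known names are still remapped, and an optional ://N device index is preserved. A quick sketch of the expected behavior (the concrete mappings live in _IREE_DEVICE_MAP, so the first result is indicative only):

iree_device_map("vulkan")         # -> the mapped driver name, e.g. "vulkan"
iree_device_map("vulkan://1")     # -> "vulkan://1", device index preserved
iree_device_map("my-new-driver")  # -> "my-new-driver", unknown names pass through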
@@ -23,6 +23,7 @@ import re
|
||||
|
||||
# Get the iree-compile arguments given device.
|
||||
def get_iree_device_args(device, extra_args=[]):
|
||||
print("Configuring for device:" + device)
|
||||
device_uri = device.split("://")
|
||||
if len(device_uri) > 1:
|
||||
if device_uri[0] not in ["vulkan"]:
|
||||
@@ -30,6 +31,9 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
f"Specific device selection only supported for vulkan now."
|
||||
f"Proceeding with {device} as device."
|
||||
)
|
||||
device_num = device_uri[1]
|
||||
else:
|
||||
device_num = 0
|
||||
|
||||
if device_uri[0] == "cpu":
|
||||
from shark.iree_utils.cpu_utils import get_iree_cpu_args
|
||||
@@ -42,7 +46,9 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
if device_uri[0] in ["metal", "vulkan"]:
|
||||
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
|
||||
|
||||
return get_iree_vulkan_args(extra_args=extra_args)
|
||||
return get_iree_vulkan_args(
|
||||
device_num=device_num, extra_args=extra_args
|
||||
)
|
||||
if device_uri[0] == "rocm":
|
||||
from shark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
|
||||
@@ -307,7 +313,7 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
|
||||
)
|
||||
ctx = ireert.SystemContext(config=config)
|
||||
ctx.add_vm_module(vm_module)
|
||||
ModuleCompiled = ctx.modules.module
|
||||
ModuleCompiled = getattr(ctx.modules, vm_module.name)
|
||||
return ModuleCompiled, config
|
||||
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ from sys import platform
|
||||
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
|
||||
|
||||
|
||||
def get_vulkan_device_name():
|
||||
def get_vulkan_device_name(device_num=0):
|
||||
vulkaninfo_dump, _ = run_cmd("vulkaninfo")
|
||||
vulkaninfo_dump = vulkaninfo_dump.split(linesep)
|
||||
vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s]
|
||||
@@ -31,8 +31,8 @@ def get_vulkan_device_name():
|
||||
print("Following devices found:")
|
||||
for i, dname in enumerate(vulkaninfo_list):
|
||||
print(f"{i}. {dname}")
|
||||
print(f"Choosing first one: {vulkaninfo_list[0]}")
|
||||
return vulkaninfo_list[0]
|
||||
print(f"Choosing device: {vulkaninfo_list[device_num]}")
|
||||
return vulkaninfo_list[device_num]
|
||||
|
||||
|
||||
def get_os_name():
|
||||
@@ -119,14 +119,14 @@ def get_vulkan_target_triple(device_name):
|
||||
return triple
|
||||
|
||||
|
||||
def get_vulkan_triple_flag(device_name="", extra_args=[]):
|
||||
def get_vulkan_triple_flag(device_name="", device_num=0, extra_args=[]):
|
||||
for flag in extra_args:
|
||||
if "-iree-vulkan-target-triple=" in flag:
|
||||
print(f"Using target triple {flag.split('=')[1]}")
|
||||
return None
|
||||
|
||||
if device_name == "" or device_name == [] or device_name is None:
|
||||
vulkan_device = get_vulkan_device_name()
|
||||
vulkan_device = get_vulkan_device_name(device_num=device_num)
|
||||
else:
|
||||
vulkan_device = device_name
|
||||
triple = get_vulkan_target_triple(vulkan_device)
|
||||
@@ -144,7 +144,7 @@ def get_vulkan_triple_flag(device_name="", extra_args=[]):
return None


def get_iree_vulkan_args(extra_args=[]):
def get_iree_vulkan_args(device_num=0, extra_args=[]):
# res_vulkan_flag = ["--iree-flow-demote-i64-to-i32"]

res_vulkan_flag = []
@@ -156,7 +156,9 @@ def get_iree_vulkan_args(extra_args=[]):
break

if vulkan_triple_flag is None:
vulkan_triple_flag = get_vulkan_triple_flag(extra_args=extra_args)
vulkan_triple_flag = get_vulkan_triple_flag(
device_num=device_num, extra_args=extra_args
)

if vulkan_triple_flag is not None:
vulkan_target_env = get_vulkan_target_env_flag(vulkan_triple_flag)

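Together with the iree_device_map change, the device_num plumbing above lets a URI such as vulkan://1 select the second adapter reported by vulkaninfo. A hedged sketch of the flow through the functions shown in this diff:

# "vulkan://1" -> get_iree_device_args() splits the URI and sets device_num=1
#              -> get_iree_vulkan_args(device_num=1, extra_args=[])
#              -> get_vulkan_device_name(device_num=1) picks entry 1 from vulkaninfo
flags = get_iree_device_args("vulkan://1")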
@@ -30,8 +30,8 @@ import os
|
||||
import sys
|
||||
from typing import Dict, List
|
||||
|
||||
import iree.compiler._mlir_libs
|
||||
from iree.compiler import ir
|
||||
from iree.compiler.transforms import ireec as ireec_trans
|
||||
|
||||
|
||||
def model_annotation(
|
||||
@@ -409,7 +409,6 @@ def shape_list_to_string(input):
|
||||
|
||||
def create_context() -> ir.Context:
|
||||
context = ir.Context()
|
||||
ireec_trans.register_all_dialects(context)
|
||||
context.allow_unregistered_dialects = True
|
||||
return context
|
||||
|
||||
|
||||
@@ -196,7 +196,7 @@ def download_model(
|
||||
tank_url=None,
|
||||
frontend=None,
|
||||
tuned=None,
|
||||
import_args=None,
|
||||
import_args={"batch_size": 1},
|
||||
):
|
||||
model_name = model_name.replace("/", "_")
|
||||
dyn_str = "_dynamic" if dynamic else ""
|
||||
|
||||
@@ -191,6 +191,7 @@ class SharkImporter:
|
||||
dir=tempfile.gettempdir(),
|
||||
model_name="model",
|
||||
golden_values=None,
|
||||
mlir_type="linalg",
|
||||
):
|
||||
if self.inputs == None:
|
||||
print(
|
||||
@@ -204,6 +205,7 @@ class SharkImporter:
|
||||
tracing_required,
|
||||
func_name,
|
||||
save_dir=artifact_path,
|
||||
mlir_type=mlir_type,
|
||||
)
|
||||
# TODO: Make sure that any generic function name is accepted. Currently takes in the default function names.
|
||||
# TODO: Check for multiple outputs.
|
||||
|
||||
@@ -21,7 +21,7 @@ import io
|
||||
|
||||
mlir_type_mapping_dict = {
|
||||
"linalg": torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
"mhlo": torch_mlir.OutputType.STABLEHLO,
|
||||
"stablehlo": torch_mlir.OutputType.STABLEHLO,
|
||||
"tosa": torch_mlir.OutputType.TOSA,
|
||||
}
|
||||
|
||||
|
||||
@@ -15,16 +15,16 @@ microsoft/layoutlm-base-uncased,mhlo,tf,1e-2,1e-3,default,None,False,False,False
|
||||
microsoft/mpnet-base,mhlo,tf,1e-2,1e-2,default,None,True,True,True,"",""
|
||||
albert-base-v2,linalg,torch,1e-2,1e-3,default,None,True,True,True,"issue with aten.tanh in torch-mlir",""
|
||||
alexnet,linalg,torch,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/879",""
|
||||
bert-base-cased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,False,True,"",""
|
||||
bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
bert-base-cased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,True,True,"",""
|
||||
bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
bert-large-uncased,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"Fails during iree-compile.",""
|
||||
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/311",""
|
||||
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390","macos"
|
||||
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"https://github.com/nod-ai/SHARK/issues/344",""
|
||||
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"https://github.com/nod-ai/SHARK/issues/344",""
|
||||
mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/388","macos"
|
||||
nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/343","macos"
|
||||
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos"
|
||||
@@ -36,12 +36,12 @@ wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,Fal
|
||||
efficientnet-v2-s,mhlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
|
||||
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,False,"https://github.com/nod-ai/SHARK/issues/1243",""
|
||||
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails on MacOS builder, VK device lost","macos"
|
||||
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"Fails on MacOS builder, VK device lost","macos"
|
||||
efficientnet_b0,mhlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""
|
||||
efficientnet_b7,mhlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails on MacOS builder, VK device lost","macos"
|
||||
gpt2,mhlo,tf,1e-2,1e-3,default,None,True,False,False,"",""
|
||||
gpt2,mhlo,tf,1e-2,1e-3,default,None,True,False,False,"","macos"
|
||||
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.",""
|
||||
t5-base,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported",""
|
||||
t5-large,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported","macos"
|
||||
t5-large,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"","macos"
|
||||
stabilityai/stable-diffusion-2-1-base,linalg,torch,1e-3,1e-3,default,None,True,False,False,"",""
|
||||
|
||||
|
@@ -50,6 +50,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
|
||||
tracing_required = row[1]
|
||||
model_type = row[2]
|
||||
is_dynamic = row[3]
|
||||
mlir_type = row[4]
|
||||
|
||||
tracing_required = False if tracing_required == "False" else True
|
||||
is_dynamic = False if is_dynamic == "False" else True
|
||||
@@ -121,6 +122,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
|
||||
tracing_required=tracing_required,
|
||||
dir=torch_model_dir,
|
||||
model_name=torch_model_name,
|
||||
mlir_type=mlir_type,
|
||||
)
|
||||
# Generate torch dynamic models.
|
||||
if is_dynamic:
|
||||
@@ -129,6 +131,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
|
||||
tracing_required=tracing_required,
|
||||
dir=torch_model_dir,
|
||||
model_name=torch_model_name + "_dynamic",
|
||||
mlir_type=mlir_type,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,23 +1,24 @@
model_name, use_tracing, model_type, dynamic, param_count, tags, notes
efficientnet_b0,True,vision,False,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input"
efficientnet_b7,True,vision,False,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"
microsoft/MiniLM-L12-H384-uncased,True,hf,True,66M,"nlp;bert-variant;transformer-encoder","Large version has 12 layers; 384 hidden size; Smaller than BERTbase (66M params vs 109M params)"
bert-base-uncased,True,hf,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-base-cased,True,hf,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
google/mobilebert-uncased,True,hf,True,25M,"nlp,bert-variant,transformer-encoder,mobile","24 layers, 512 hidden size, 128 embedding"
alexnet,False,vision,True,61M,"cnn,parallel-layers","The CNN that revolutionized computer vision (move away from hand-crafted features to neural networks),10 years old now and probably no longer used in prod."
resnet18,False,vision,True,11M,"cnn,image-classification,residuals,resnet-variant","1 7x7 conv2d and the rest are 3x3 conv2d"
resnet50,False,vision,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
resnet101,False,vision,True,29M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
squeezenet1_0,False,vision,True,1.25M,"cnn,image-classification,mobile,parallel-layers","Parallel conv2d (1x1 conv to compress -> (3x3 expand | 1x1 expand) -> concat)"
wide_resnet50_2,False,vision,True,69M,"cnn,image-classification,residuals,resnet-variant","Resnet variant where model depth is decreased and width is increased."
mobilenet_v3_small,False,vision,True,2.5M,"image-classification,cnn,mobile",N/A
google/vit-base-patch16-224,True,hf_img_cls,False,86M,"image-classification,vision-transformer,transformer-encoder",N/A
microsoft/resnet-50,True,hf_img_cls,False,23M,"image-classification,cnn,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
facebook/deit-small-distilled-patch16-224,True,hf_img_cls,False,22M,"image-classification,vision-transformer,cnn",N/A
microsoft/beit-base-patch16-224-pt22k-ft22k,True,hf_img_cls,False,86M,"image-classification,transformer-encoder,bert-variant,vision-transformer",N/A
nvidia/mit-b0,True,hf_img_cls,False,3.7M,"image-classification,transformer-encoder",SegFormer
mnasnet1_0,False,vision,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
resnet50_fp16,False,vision,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
bert-base-uncased_fp16,True,fp16,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
model_name, use_tracing, model_type, dynamic, mlir_type, param_count, tags, notes
efficientnet_b0,True,vision,False,linalg,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input"
efficientnet_b7,True,vision,False,linalg,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"
microsoft/MiniLM-L12-H384-uncased,True,hf,True,linalg,66M,"nlp;bert-variant;transformer-encoder","Large version has 12 layers; 384 hidden size; Smaller than BERTbase (66M params vs 109M params)"
bert-base-uncased,True,hf,True,linalg,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-base-cased,True,hf,True,linalg,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
google/mobilebert-uncased,True,hf,True,linalg,25M,"nlp,bert-variant,transformer-encoder,mobile","24 layers, 512 hidden size, 128 embedding"
alexnet,False,vision,True,linalg,61M,"cnn,parallel-layers","The CNN that revolutionized computer vision (move away from hand-crafted features to neural networks),10 years old now and probably no longer used in prod."
resnet18,False,vision,True,linalg,11M,"cnn,image-classification,residuals,resnet-variant","1 7x7 conv2d and the rest are 3x3 conv2d"
resnet50,False,vision,True,linalg,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
resnet101,False,vision,True,linalg,29M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
squeezenet1_0,False,vision,True,linalg,1.25M,"cnn,image-classification,mobile,parallel-layers","Parallel conv2d (1x1 conv to compress -> (3x3 expand | 1x1 expand) -> concat)"
wide_resnet50_2,False,vision,True,linalg,69M,"cnn,image-classification,residuals,resnet-variant","Resnet variant where model depth is decreased and width is increased."
mobilenet_v3_small,False,vision,True,linalg,2.5M,"image-classification,cnn,mobile",N/A
google/vit-base-patch16-224,True,hf_img_cls,False,linalg,86M,"image-classification,vision-transformer,transformer-encoder",N/A
microsoft/resnet-50,True,hf_img_cls,False,linalg,23M,"image-classification,cnn,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
facebook/deit-small-distilled-patch16-224,True,hf_img_cls,False,linalg,22M,"image-classification,vision-transformer,cnn",N/A
microsoft/beit-base-patch16-224-pt22k-ft22k,True,hf_img_cls,False,linalg,86M,"image-classification,transformer-encoder,bert-variant,vision-transformer",N/A
nvidia/mit-b0,True,hf_img_cls,False,linalg,3.7M,"image-classification,transformer-encoder",SegFormer
mnasnet1_0,False,vision,True,linalg,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
resnet50_fp16,False,vision,True,linalg,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
bert-base-uncased_fp16,True,fp16,False,linalg,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-large-uncased,True,hf,True,linalg,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
bert-base-uncased,True,hf,False,stablehlo,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
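As a small usage sketch (not repository code): because the reshaped list adds an mlir_type column to the header, rows can be read by name with csv.DictReader instead of by position. The file name and the grouping helper below are assumptions for illustration only.

# Illustrative only: group model entries by their mlir_type using the new header.
import csv
from collections import defaultdict


def models_by_mlir_type(path):
    groups = defaultdict(list)
    with open(path, newline="") as f:
        # skipinitialspace handles the ", " separators used in the header row.
        for row in csv.DictReader(f, skipinitialspace=True):
            groups[row["mlir_type"]].append(row["model_name"])
    return groups


# e.g. models_by_mlir_type("torch_model_list.csv")["stablehlo"] would contain
# "bert-base-uncased" given the rows above (file name assumed).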