mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-04-20 03:00:34 -04:00
Compare commits
42 Commits
20230522.7
...
20230608.7
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
436f58ddc4 | ||
|
|
6b29bd17c8 | ||
|
|
2c3485ca3e | ||
|
|
f206ecc635 | ||
|
|
a187e05ae6 | ||
|
|
8c21960486 | ||
|
|
be62fce676 | ||
|
|
f23b778a6c | ||
|
|
436edf900d | ||
|
|
ed58c2553f | ||
|
|
f2ca58e844 | ||
|
|
1dbcc736eb | ||
|
|
a83808ddc5 | ||
|
|
a07fe80530 | ||
|
|
d0ba3ef8fa | ||
|
|
8400529c2c | ||
|
|
7eaee9c242 | ||
|
|
8230eebce5 | ||
|
|
6296ea4be9 | ||
|
|
4151ec3a8f | ||
|
|
a2467e8d43 | ||
|
|
e677178bcc | ||
|
|
7ef1bea953 | ||
|
|
ad89bb1413 | ||
|
|
218ed78c40 | ||
|
|
6046f36ab6 | ||
|
|
5915bf7de3 | ||
|
|
f0a4e59758 | ||
|
|
1ddef26af5 | ||
|
|
ba8eddb12f | ||
|
|
47b346d428 | ||
|
|
1b4f4f5f4d | ||
|
|
73cd7e8320 | ||
|
|
19c0ae3702 | ||
|
|
54e57f7771 | ||
|
|
6d64b8e273 | ||
|
|
a8ea0326f5 | ||
|
|
58e9194553 | ||
|
|
eb360e255d | ||
|
|
a6f88d7f72 | ||
|
|
8e571d165f | ||
|
|
3cddd01b10 |
26
.github/workflows/nightly.yml
vendored
26
.github/workflows/nightly.yml
vendored
@@ -50,27 +50,13 @@ jobs:
|
||||
shell: powershell
|
||||
run: |
|
||||
./setup_venv.ps1
|
||||
$env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
|
||||
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
python process_skipfiles.py
|
||||
pyinstaller .\apps\stable_diffusion\shark_sd.spec
|
||||
mv ./dist/shark_sd.exe ./dist/shark_sd_${{ env.package_version_ }}.exe
|
||||
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_${{ env.package_version_ }}.exe
|
||||
pyinstaller .\apps\stable_diffusion\shark_sd_cli.spec
|
||||
python process_skipfiles.py
|
||||
mv ./dist/shark_sd_cli.exe ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
|
||||
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
|
||||
|
||||
|
||||
# GHA windows VM OOMs so disable for now
|
||||
#- name: Build and validate the SHARK Runtime package
|
||||
# shell: powershell
|
||||
# run: |
|
||||
# $env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
|
||||
# pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
|
||||
#- uses: actions/upload-artifact@v2
|
||||
# with:
|
||||
# path: dist/*
|
||||
|
||||
mv ./dist/shark_sd.exe ./dist/nodai_shark_sd_${{ env.package_version_ }}.exe
|
||||
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_sd_${{ env.package_version_ }}.exe
|
||||
|
||||
- name: Upload Release Assets
|
||||
id: upload-release-assets
|
||||
uses: dwenegar/upload-release-assets@v1
|
||||
@@ -78,7 +64,7 @@ jobs:
|
||||
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
with:
|
||||
release_id: ${{ steps.create_release.outputs.id }}
|
||||
assets_path: ./dist/*
|
||||
assets_path: ./dist/nodai*
|
||||
#asset_content_type: application/vnd.microsoft.portable-executable
|
||||
|
||||
- name: Publish Release
|
||||
|
||||
@@ -1,27 +1,15 @@
|
||||
import torch
|
||||
import torch_mlir
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
pipeline,
|
||||
StoppingCriteria,
|
||||
StoppingCriteriaList,
|
||||
TextIteratorStreamer,
|
||||
)
|
||||
import time
|
||||
import numpy as np
|
||||
from torch.nn import functional as F
|
||||
import os
|
||||
from threading import Thread
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from typing import List
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from shark.shark_downloader import download_public_file
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
from pathlib import Path
|
||||
from apps.language_models.utils import (
|
||||
get_torch_mlir_module_bytecode,
|
||||
get_vmfb_from_path,
|
||||
)
|
||||
|
||||
|
||||
class StopOnTokens(StoppingCriteria):
|
||||
@@ -51,143 +39,32 @@ def user(message, history):
|
||||
return "", history + [[message, ""]]
|
||||
|
||||
|
||||
def get_torch_mlir_module_bytecode(model, model_inputs):
|
||||
fx_g = make_fx(
|
||||
model,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
# tracing_mode='symbolic',
|
||||
)(*model_inputs)
|
||||
print("Got FX_G")
|
||||
|
||||
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
|
||||
removed_indexes = []
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, (list, tuple)):
|
||||
node_arg = list(node_arg)
|
||||
node_args_len = len(node_arg)
|
||||
for i in range(node_args_len):
|
||||
curr_index = node_args_len - (i + 1)
|
||||
if node_arg[curr_index] is None:
|
||||
removed_indexes.append(curr_index)
|
||||
node_arg.pop(curr_index)
|
||||
node.args = (tuple(node_arg),)
|
||||
break
|
||||
|
||||
if len(removed_indexes) > 0:
|
||||
fx_g.graph.lint()
|
||||
fx_g.graph.eliminate_dead_code()
|
||||
fx_g.recompile()
|
||||
removed_indexes.sort()
|
||||
return removed_indexes
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""
|
||||
Replace tuple with tuple element in functions that return one-element tuples.
|
||||
Returns true if an unwrapping took place, and false otherwise.
|
||||
"""
|
||||
unwrapped_tuple = False
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
unwrapped_tuple = True
|
||||
break
|
||||
|
||||
if unwrapped_tuple:
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return unwrapped_tuple
|
||||
|
||||
def transform_fx(fx_g):
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
if node.target in [
|
||||
torch.ops.aten.empty,
|
||||
]:
|
||||
# aten.empty should be filled with zeros.
|
||||
if node.target in [torch.ops.aten.empty]:
|
||||
with fx_g.graph.inserting_after(node):
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten.zero_,
|
||||
args=(node,),
|
||||
)
|
||||
node.append(new_node)
|
||||
node.replace_all_uses_with(new_node)
|
||||
new_node.args = (node,)
|
||||
|
||||
fx_g.graph.lint()
|
||||
|
||||
transform_fx(fx_g)
|
||||
fx_g.recompile()
|
||||
removed_none_indexes = _remove_nones(fx_g)
|
||||
was_unwrapped = _unwrap_single_tuple_return(fx_g)
|
||||
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
|
||||
print("FX_G recompile")
|
||||
|
||||
def strip_overloads(gm):
|
||||
"""
|
||||
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
|
||||
Args:
|
||||
gm(fx.GraphModule): The input Fx graph module to be modified
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if isinstance(node.target, torch._ops.OpOverload):
|
||||
node.target = node.target.overloadpacket
|
||||
gm.recompile()
|
||||
|
||||
strip_overloads(fx_g)
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
print("Got TS_G")
|
||||
return ts_g
|
||||
|
||||
|
||||
def compile_stableLM(model, model_inputs, model_name, model_vmfb_name):
|
||||
# ADD Device Arg
|
||||
def compile_stableLM(
|
||||
model,
|
||||
model_inputs,
|
||||
model_name,
|
||||
model_vmfb_name,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
):
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
device = "cuda" # 'cpu'
|
||||
# device = "cuda" # "cpu"
|
||||
# TODO: vmfb and mlir name should include precision and device
|
||||
vmfb_path = (
|
||||
Path(model_name + f"_{device}.vmfb")
|
||||
if model_vmfb_name is None
|
||||
else Path(model_vmfb_name)
|
||||
)
|
||||
if vmfb_path.exists():
|
||||
print("Loading vmfb from: ", vmfb_path)
|
||||
shark_module = SharkInference(
|
||||
None, device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.load_module(vmfb_path)
|
||||
print("Successfully loaded vmfb")
|
||||
shark_module = get_vmfb_from_path(
|
||||
vmfb_path, device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
if shark_module is not None:
|
||||
return shark_module
|
||||
|
||||
mlir_path = Path(model_name + ".mlir")
|
||||
print(
|
||||
f"[DEBUG] mlir path { mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
|
||||
f"[DEBUG] mlir path {mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
|
||||
)
|
||||
if mlir_path.exists():
|
||||
with open(mlir_path, "rb") as f:
|
||||
@@ -214,9 +91,9 @@ def compile_stableLM(model, model_inputs, model_name, model_vmfb_name):
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
import os
|
||||
|
||||
path = shark_module.save_module(os.getcwd(), model_vmfb_name, [])
|
||||
path = shark_module.save_module(
|
||||
vmfb_path.parent.absolute(), vmfb_path.stem
|
||||
)
|
||||
print("Saved vmfb at ", str(path))
|
||||
|
||||
return shark_module
|
||||
@@ -249,60 +126,85 @@ def get_tokenizer():
|
||||
model_path = "stabilityai/stablelm-tuned-alpha-3b"
|
||||
tok = AutoTokenizer.from_pretrained(model_path)
|
||||
tok.add_special_tokens({"pad_token": "<PAD>"})
|
||||
print(f"Sucessfully loaded the tokenizer to the memory")
|
||||
print("Sucessfully loaded the tokenizer to the memory")
|
||||
return tok
|
||||
|
||||
|
||||
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
|
||||
# sharkStableLM = compile_stableLM
|
||||
# (
|
||||
# None,
|
||||
# tuple([input_ids, attention_mask]),
|
||||
# "stableLM_linalg_f32_seqLen256",
|
||||
# "/home/shark/vivek/stableLM_shark_f32_seqLen256"
|
||||
# )
|
||||
def generate(
|
||||
new_text,
|
||||
# streamer,
|
||||
max_new_tokens,
|
||||
do_sample,
|
||||
top_p,
|
||||
top_k,
|
||||
temperature,
|
||||
num_beams,
|
||||
stopping_criteria,
|
||||
sharkStableLM,
|
||||
tok=None,
|
||||
input_ids=torch.randint(3, (1, 256)),
|
||||
attention_mask=torch.randint(3, (1, 256)),
|
||||
tokenizer=None,
|
||||
):
|
||||
if tok == None:
|
||||
tok = get_tokenizer()
|
||||
# Construct the input message string for the model by concatenating the current system message and conversation history
|
||||
if tokenizer is None:
|
||||
tokenizer = get_tokenizer()
|
||||
# Construct the input message string for the model by
|
||||
# concatenating the current system message and conversation history
|
||||
# Tokenize the messages string
|
||||
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
|
||||
# sharkStableLM = compile_stableLM
|
||||
# (
|
||||
# None,
|
||||
# tuple([input_ids, attention_mask]),
|
||||
# "stableLM_linalg_f32_seqLen256",
|
||||
# "/home/shark/vivek/stableLM_shark_f32_seqLen256"
|
||||
# )
|
||||
words_list = []
|
||||
for i in range(max_new_tokens):
|
||||
numWords = len(new_text.split())
|
||||
# numWords = len(new_text.split())
|
||||
# if(numWords>220):
|
||||
# break
|
||||
model_inputs = tok(
|
||||
[new_text],
|
||||
padding="max_length",
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
params = {
|
||||
"new_text": new_text,
|
||||
}
|
||||
generated_token_op = generate_new_token(
|
||||
sharkStableLM, tokenizer, params
|
||||
)
|
||||
sum_attentionmask = torch.sum(model_inputs.attention_mask)
|
||||
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
|
||||
output = sharkStableLM(
|
||||
"forward", [model_inputs.input_ids, model_inputs.attention_mask]
|
||||
)
|
||||
output = torch.from_numpy(output)
|
||||
next_toks = torch.topk(output, 1)
|
||||
if shouldStop(next_toks.indices):
|
||||
detok = generated_token_op["detok"]
|
||||
stop_generation = generated_token_op["stop_generation"]
|
||||
if stop_generation:
|
||||
break
|
||||
# streamer.put(next_toks.indices[0][int(sum_attentionmask)-1])
|
||||
new_word = tok.decode(
|
||||
next_toks.indices[0][int(sum_attentionmask) - 1],
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
print(new_word, end="", flush=True)
|
||||
words_list.append(new_word)
|
||||
if new_word == "":
|
||||
print(detok, end="", flush=True)
|
||||
words_list.append(detok)
|
||||
if detok == "":
|
||||
break
|
||||
new_text = new_text + new_word
|
||||
new_text = new_text + detok
|
||||
return words_list
|
||||
|
||||
|
||||
def generate_new_token(shark_model, tokenizer, params):
|
||||
new_text = params["new_text"]
|
||||
model_inputs = tokenizer(
|
||||
[new_text],
|
||||
padding="max_length",
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
sum_attentionmask = torch.sum(model_inputs.attention_mask)
|
||||
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
|
||||
output = shark_model(
|
||||
"forward", [model_inputs.input_ids, model_inputs.attention_mask]
|
||||
)
|
||||
output = torch.from_numpy(output)
|
||||
next_toks = torch.topk(output, 1)
|
||||
stop_generation = False
|
||||
if shouldStop(next_toks.indices):
|
||||
stop_generation = True
|
||||
new_token = next_toks.indices[0][int(sum_attentionmask) - 1]
|
||||
detok = tokenizer.decode(
|
||||
new_token,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
ret_dict = {
|
||||
"new_token": new_token,
|
||||
"detok": detok,
|
||||
"stop_generation": stop_generation,
|
||||
}
|
||||
return ret_dict
|
||||
|
||||
@@ -1,665 +0,0 @@
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
import torch
|
||||
import torch_mlir
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from typing import List
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import transform_fx as transform_fx_
|
||||
import re
|
||||
from shark.shark_inference import SharkInference
|
||||
from tqdm import tqdm
|
||||
from torch_mlir import TensorPlaceholder
|
||||
|
||||
|
||||
import argparse
|
||||
|
||||
|
||||
class FirstVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, hidden_states, attention_mask, position_ids):
|
||||
outputs = self.model(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
use_cache=True,
|
||||
)
|
||||
next_hidden_states = outputs[0]
|
||||
past_key_value_out0, past_key_value_out1 = (
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
|
||||
return (
|
||||
next_hidden_states,
|
||||
past_key_value_out0,
|
||||
past_key_value_out1,
|
||||
)
|
||||
|
||||
|
||||
class SecondVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value0,
|
||||
past_key_value1,
|
||||
):
|
||||
outputs = self.model(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=(
|
||||
past_key_value0,
|
||||
past_key_value1,
|
||||
),
|
||||
use_cache=True,
|
||||
)
|
||||
next_hidden_states = outputs[0]
|
||||
past_key_value_out0, past_key_value_out1 = (
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
|
||||
return (
|
||||
next_hidden_states,
|
||||
past_key_value_out0,
|
||||
past_key_value_out1,
|
||||
)
|
||||
|
||||
|
||||
class CompiledFirstVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, shark_module):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
):
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
output = self.model(
|
||||
"forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
),
|
||||
)
|
||||
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
|
||||
return (
|
||||
output0,
|
||||
(
|
||||
output1,
|
||||
output2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class CompiledSecondVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, shark_module):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
):
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
pkv0 = past_key_value[0].detach()
|
||||
pkv1 = past_key_value[1].detach()
|
||||
output = self.model(
|
||||
"forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
pkv0,
|
||||
pkv1,
|
||||
),
|
||||
)
|
||||
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
|
||||
return (
|
||||
output0,
|
||||
(
|
||||
output1,
|
||||
output2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ShardedVicunaModel(torch.nn.Module):
|
||||
def __init__(self, model, layers0, layers1):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
assert len(layers0) == len(model.model.layers)
|
||||
# self.model.model.layers = torch.nn.modules.container.ModuleList(layers0)
|
||||
self.model.model.config.use_cache = True
|
||||
self.model.model.config.output_attentions = False
|
||||
self.layers0 = layers0
|
||||
self.layers1 = layers1
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
is_first=True,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
):
|
||||
if is_first:
|
||||
self.model.model.layers = torch.nn.modules.container.ModuleList(
|
||||
self.layers0
|
||||
)
|
||||
return self.model.forward(input_ids, attention_mask=attention_mask)
|
||||
else:
|
||||
self.model.model.layers = torch.nn.modules.container.ModuleList(
|
||||
self.layers1
|
||||
)
|
||||
return self.model.forward(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
|
||||
def write_in_dynamic_inputs0(module, dynamic_input_size):
|
||||
new_lines = []
|
||||
for line in module.splitlines():
|
||||
line = re.sub(f"{dynamic_input_size}x", "?x", line)
|
||||
if "?x" in line:
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
|
||||
line = re.sub(f" {dynamic_input_size},", " %dim,", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub(f"c{dynamic_input_size}", "dim", line)
|
||||
new_lines.append(line)
|
||||
new_module = "\n".join(new_lines)
|
||||
return new_module
|
||||
|
||||
|
||||
def write_in_dynamic_inputs1(module, dynamic_input_size):
|
||||
new_lines = []
|
||||
for line in module.splitlines():
|
||||
if "dim_42 =" in line:
|
||||
continue
|
||||
if f"%c{dynamic_input_size}_i64 =" in line:
|
||||
new_lines.append(
|
||||
"%dim_42 = tensor.dim %arg1, %c3 : tensor<1x1x1x?xf32>"
|
||||
)
|
||||
new_lines.append(
|
||||
f"%dim_42_i64 = arith.index_cast %dim_42 : index to i64"
|
||||
)
|
||||
continue
|
||||
line = re.sub(f"{dynamic_input_size}x", "?x", line)
|
||||
if "?x" in line:
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim_42)", line)
|
||||
line = re.sub(f" {dynamic_input_size},", " %dim_42,", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim_42\)",
|
||||
"tensor.empty(%dim_42, %dim_42)",
|
||||
line,
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub(f"c{dynamic_input_size}", "dim_42", line)
|
||||
new_lines.append(line)
|
||||
new_module = "\n".join(new_lines)
|
||||
return new_module
|
||||
|
||||
|
||||
def compile_vicuna_layer(
|
||||
vicuna_layer,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value0=None,
|
||||
past_key_value1=None,
|
||||
):
|
||||
hidden_states_placeholder = TensorPlaceholder.like(
|
||||
hidden_states, dynamic_axes=[1]
|
||||
)
|
||||
attention_mask_placeholder = TensorPlaceholder.like(
|
||||
attention_mask, dynamic_axes=[2, 3]
|
||||
)
|
||||
position_ids_placeholder = TensorPlaceholder.like(
|
||||
position_ids, dynamic_axes=[1]
|
||||
)
|
||||
|
||||
if past_key_value0 is None and past_key_value1 is None:
|
||||
fx_g = make_fx(
|
||||
vicuna_layer,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
)(hidden_states, attention_mask, position_ids)
|
||||
|
||||
else:
|
||||
fx_g = make_fx(
|
||||
vicuna_layer,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
)(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value0,
|
||||
past_key_value1,
|
||||
)
|
||||
|
||||
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
|
||||
removed_indexes = []
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, (list, tuple)):
|
||||
node_arg = list(node_arg)
|
||||
node_args_len = len(node_arg)
|
||||
for i in range(node_args_len):
|
||||
curr_index = node_args_len - (i + 1)
|
||||
if node_arg[curr_index] is None:
|
||||
removed_indexes.append(curr_index)
|
||||
node_arg.pop(curr_index)
|
||||
node.args = (tuple(node_arg),)
|
||||
break
|
||||
|
||||
if len(removed_indexes) > 0:
|
||||
fx_g.graph.lint()
|
||||
fx_g.graph.eliminate_dead_code()
|
||||
fx_g.recompile()
|
||||
removed_indexes.sort()
|
||||
return removed_indexes
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""
|
||||
Replace tuple with tuple element in functions that return one-element tuples.
|
||||
Returns true if an unwrapping took place, and false otherwise.
|
||||
"""
|
||||
unwrapped_tuple = False
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
unwrapped_tuple = True
|
||||
break
|
||||
|
||||
if unwrapped_tuple:
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return unwrapped_tuple
|
||||
|
||||
def transform_fx(fx_g):
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
if node.target in [
|
||||
torch.ops.aten.empty,
|
||||
]:
|
||||
# aten.empty should be filled with zeros.
|
||||
if node.target in [torch.ops.aten.empty]:
|
||||
with fx_g.graph.inserting_after(node):
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten.zero_,
|
||||
args=(node,),
|
||||
)
|
||||
node.append(new_node)
|
||||
node.replace_all_uses_with(new_node)
|
||||
new_node.args = (node,)
|
||||
|
||||
fx_g.graph.lint()
|
||||
|
||||
transform_fx(fx_g)
|
||||
fx_g.recompile()
|
||||
removed_none_indexes = _remove_nones(fx_g)
|
||||
was_unwrapped = _unwrap_single_tuple_return(fx_g)
|
||||
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
|
||||
print("FX_G recompile")
|
||||
|
||||
def strip_overloads(gm):
|
||||
"""
|
||||
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
|
||||
Args:
|
||||
gm(fx.GraphModule): The input Fx graph module to be modified
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if isinstance(node.target, torch._ops.OpOverload):
|
||||
node.target = node.target.overloadpacket
|
||||
gm.recompile()
|
||||
|
||||
strip_overloads(fx_g)
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
return ts_g
|
||||
|
||||
|
||||
def get_model_and_tokenizer(path="TheBloke/vicuna-7B-1.1-HF"):
|
||||
kwargs = {"torch_dtype": torch.float}
|
||||
vicuna_model = AutoModelForCausalLM.from_pretrained(path, **kwargs)
|
||||
tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
|
||||
return vicuna_model, tokenizer
|
||||
|
||||
|
||||
def compile_to_vmfb(inputs, layers, is_first=True):
|
||||
mlirs, modules = [], []
|
||||
for idx, layer in tqdm(enumerate(layers), desc="Getting mlirs"):
|
||||
if is_first:
|
||||
mlir_path = Path(f"{idx}_0.mlir")
|
||||
vmfb_path = Path(f"{idx}_0.vmfb")
|
||||
else:
|
||||
mlir_path = Path(f"{idx}_1.mlir")
|
||||
vmfb_path = Path(f"{idx}_1.vmfb")
|
||||
if vmfb_path.exists():
|
||||
continue
|
||||
if mlir_path.exists():
|
||||
# print(f"Found layer {idx} mlir")
|
||||
f_ = open(mlir_path, "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
else:
|
||||
hidden_states_placeholder = TensorPlaceholder.like(
|
||||
inputs[0], dynamic_axes=[1]
|
||||
)
|
||||
attention_mask_placeholder = TensorPlaceholder.like(
|
||||
inputs[1], dynamic_axes=[3]
|
||||
)
|
||||
position_ids_placeholder = TensorPlaceholder.like(
|
||||
inputs[2], dynamic_axes=[1]
|
||||
)
|
||||
if not is_first:
|
||||
pkv0_placeholder = TensorPlaceholder.like(
|
||||
inputs[3], dynamic_axes=[2]
|
||||
)
|
||||
pkv1_placeholder = TensorPlaceholder.like(
|
||||
inputs[4], dynamic_axes=[2]
|
||||
)
|
||||
print(f"Compiling layer {idx} mlir")
|
||||
if is_first:
|
||||
ts_g = compile_vicuna_layer(
|
||||
layer, inputs[0], inputs[1], inputs[2]
|
||||
)
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
(
|
||||
hidden_states_placeholder,
|
||||
inputs[1],
|
||||
inputs[2],
|
||||
),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
else:
|
||||
ts_g = compile_vicuna_layer(
|
||||
layer,
|
||||
inputs[0],
|
||||
inputs[1],
|
||||
inputs[2],
|
||||
inputs[3],
|
||||
inputs[4],
|
||||
)
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
(
|
||||
inputs[0],
|
||||
attention_mask_placeholder,
|
||||
inputs[2],
|
||||
pkv0_placeholder,
|
||||
pkv1_placeholder,
|
||||
),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# bytecode_stream = BytesIO()
|
||||
# module.operation.write_bytecode(bytecode_stream)
|
||||
# bytecode = bytecode_stream.getvalue()
|
||||
|
||||
if is_first:
|
||||
module = write_in_dynamic_inputs0(str(module), 137)
|
||||
bytecode = module.encode("UTF-8")
|
||||
bytecode_stream = BytesIO(bytecode)
|
||||
bytecode = bytecode_stream.read()
|
||||
|
||||
else:
|
||||
module = write_in_dynamic_inputs1(str(module), 138)
|
||||
if idx in [0, 5, 6, 7]:
|
||||
module_str = module
|
||||
module_str = module_str.splitlines()
|
||||
new_lines = []
|
||||
for line in module_str:
|
||||
if len(line) < 1000:
|
||||
new_lines.append(line)
|
||||
else:
|
||||
new_lines.append(line[:999])
|
||||
module_str = "\n".join(new_lines)
|
||||
f1_ = open(f"{idx}_1_test.mlir", "w+")
|
||||
f1_.write(module_str)
|
||||
f1_.close()
|
||||
|
||||
bytecode = module.encode("UTF-8")
|
||||
bytecode_stream = BytesIO(bytecode)
|
||||
bytecode = bytecode_stream.read()
|
||||
|
||||
f_ = open(mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
f_.close()
|
||||
mlirs.append(bytecode)
|
||||
|
||||
for idx, layer in tqdm(enumerate(layers), desc="compiling modules"):
|
||||
if is_first:
|
||||
vmfb_path = Path(f"{idx}_0.vmfb")
|
||||
if idx < 25:
|
||||
device = "cpu"
|
||||
else:
|
||||
device = "cpu"
|
||||
if vmfb_path.exists():
|
||||
# print(f"Found layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
None, device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
else:
|
||||
print(f"Compiling layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
mlirs[idx], device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
module.save_module(
|
||||
module_name=f"{idx}_0",
|
||||
extra_args=[
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
],
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
modules.append(module)
|
||||
else:
|
||||
vmfb_path = Path(f"{idx}_1.vmfb")
|
||||
if idx < 25:
|
||||
device = "cpu"
|
||||
else:
|
||||
device = "cpu"
|
||||
if vmfb_path.exists():
|
||||
# print(f"Found layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
None, device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
else:
|
||||
print(f"Compiling layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
mlirs[idx], device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
module.save_module(
|
||||
module_name=f"{idx}_1",
|
||||
extra_args=[
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
],
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
modules.append(module)
|
||||
|
||||
return mlirs, modules
|
||||
|
||||
|
||||
def get_sharded_model():
|
||||
# SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an increadibly hacky proccess
|
||||
# please don't change it
|
||||
SAMPLE_INPUT_LEN = 137
|
||||
vicuna_model = get_model_and_tokenizer()[0]
|
||||
|
||||
placeholder_input0 = (
|
||||
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
|
||||
torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]),
|
||||
torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64),
|
||||
)
|
||||
|
||||
placeholder_input1 = (
|
||||
torch.zeros([1, 1, 4096]),
|
||||
torch.zeros([1, 1, 1, SAMPLE_INPUT_LEN + 1]),
|
||||
torch.zeros([1, 1], dtype=torch.int64),
|
||||
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
|
||||
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
|
||||
)
|
||||
|
||||
layers0 = [FirstVicunaLayer(layer) for layer in vicuna_model.model.layers]
|
||||
_, modules0 = compile_to_vmfb(placeholder_input0, layers0, is_first=True)
|
||||
shark_layers0 = [CompiledFirstVicunaLayer(m) for m in modules0]
|
||||
|
||||
layers1 = [SecondVicunaLayer(layer) for layer in vicuna_model.model.layers]
|
||||
_, modules1 = compile_to_vmfb(placeholder_input1, layers1, is_first=False)
|
||||
shark_layers1 = [CompiledSecondVicunaLayer(m) for m in modules1]
|
||||
|
||||
sharded_model = ShardedVicunaModel(
|
||||
vicuna_model, shark_layers0, shark_layers1
|
||||
)
|
||||
return sharded_model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
|
||||
prologue_prompt = "ASSISTANT:\n"
|
||||
sharded_model = get_sharded_model()
|
||||
tokenizer = get_model_and_tokenizer()[1]
|
||||
past_key_values = None
|
||||
while True:
|
||||
print("\n\n")
|
||||
user_prompt = input("User: ")
|
||||
prompt_history = (
|
||||
prompt_history + "USER:\n" + user_prompt + prologue_prompt
|
||||
)
|
||||
prompt = prompt_history.strip()
|
||||
input_ids = tokenizer(prompt).input_ids
|
||||
tokens = input_ids
|
||||
prompt = print("Robot:", end=" ")
|
||||
new_sentence = []
|
||||
max_response_len = 1000
|
||||
for iteration in range(max_response_len):
|
||||
original_input_ids = input_ids
|
||||
input_id_len = len(input_ids)
|
||||
input_ids = torch.tensor(input_ids)
|
||||
input_ids = input_ids.reshape([1, input_id_len])
|
||||
|
||||
if iteration == 0:
|
||||
output = sharded_model.forward(input_ids, is_first=True)
|
||||
else:
|
||||
output = sharded_model.forward(
|
||||
input_ids, past_key_values=past_key_values, is_first=False
|
||||
)
|
||||
logits = output["logits"]
|
||||
past_key_values = output["past_key_values"]
|
||||
new_token = int(torch.argmax(logits[:, -1, :], dim=1)[0])
|
||||
if new_token == 2:
|
||||
break
|
||||
new_sentence += [new_token]
|
||||
tokens.append(new_token)
|
||||
detok = tokenizer.decode(new_token)
|
||||
if detok == "<0x0A>":
|
||||
print("\n", end="", flush=True)
|
||||
else:
|
||||
print(f"{detok}", end=" ", flush=True)
|
||||
original_input_ids.append(new_token)
|
||||
input_ids = [new_token]
|
||||
|
||||
for i in range(len(tokens)):
|
||||
if type(tokens[i]) != int:
|
||||
tokens[i] = int(tokens[i][0])
|
||||
new_sentence_str = tokenizer.decode(new_sentence)
|
||||
print(
|
||||
"\n-----\nRobot: Here's the complete formatted reply:\n",
|
||||
new_sentence_str,
|
||||
)
|
||||
prompt_history += f"\n{new_sentence_str}\n"
|
||||
@@ -1,777 +0,0 @@
|
||||
import sys
|
||||
import warnings
|
||||
import gradio as gr
|
||||
import time
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
sys.path.insert(0, "D:\S\SB\I\python_packages\iree_compiler")
|
||||
sys.path.insert(0, "D:\S\SB\I\python_packages\iree_runtime")
|
||||
import torch
|
||||
import torch_mlir
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from typing import List
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import transform_fx as transform_fx_
|
||||
import re
|
||||
from shark.shark_inference import SharkInference
|
||||
from tqdm import tqdm
|
||||
from torch_mlir import TensorPlaceholder
|
||||
from apps.stable_diffusion.web.ui.utils import available_devices
|
||||
|
||||
|
||||
class FirstVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, hidden_states, attention_mask, position_ids):
|
||||
outputs = self.model(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
use_cache=True,
|
||||
)
|
||||
next_hidden_states = outputs[0]
|
||||
past_key_value_out0, past_key_value_out1 = (
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
|
||||
return (
|
||||
next_hidden_states,
|
||||
past_key_value_out0,
|
||||
past_key_value_out1,
|
||||
)
|
||||
|
||||
|
||||
class SecondVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value0,
|
||||
past_key_value1,
|
||||
):
|
||||
outputs = self.model(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=(
|
||||
past_key_value0,
|
||||
past_key_value1,
|
||||
),
|
||||
use_cache=True,
|
||||
)
|
||||
next_hidden_states = outputs[0]
|
||||
past_key_value_out0, past_key_value_out1 = (
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
|
||||
return (
|
||||
next_hidden_states,
|
||||
past_key_value_out0,
|
||||
past_key_value_out1,
|
||||
)
|
||||
|
||||
|
||||
class CompiledFirstVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, shark_module):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
):
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
output = self.model(
|
||||
"forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
),
|
||||
)
|
||||
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
|
||||
return (
|
||||
output0,
|
||||
(
|
||||
output1,
|
||||
output2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class CompiledSecondVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, shark_module):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
):
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
pkv0 = past_key_value[0].detach()
|
||||
pkv1 = past_key_value[1].detach()
|
||||
output = self.model(
|
||||
"forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
pkv0,
|
||||
pkv1,
|
||||
),
|
||||
)
|
||||
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
|
||||
return (
|
||||
output0,
|
||||
(
|
||||
output1,
|
||||
output2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ShardedVicunaModel(torch.nn.Module):
|
||||
def __init__(self, model, layers0, layers1):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
assert len(layers0) == len(model.model.layers)
|
||||
# self.model.model.layers = torch.nn.modules.container.ModuleList(layers0)
|
||||
self.model.model.config.use_cache = True
|
||||
self.model.model.config.output_attentions = False
|
||||
self.layers0 = layers0
|
||||
self.layers1 = layers1
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
is_first=True,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
):
|
||||
if is_first:
|
||||
self.model.model.layers = torch.nn.modules.container.ModuleList(
|
||||
self.layers0
|
||||
)
|
||||
return self.model.forward(input_ids, attention_mask=attention_mask)
|
||||
else:
|
||||
self.model.model.layers = torch.nn.modules.container.ModuleList(
|
||||
self.layers1
|
||||
)
|
||||
return self.model.forward(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
|
||||
def write_in_dynamic_inputs0(module, dynamic_input_size):
|
||||
new_lines = []
|
||||
for line in module.splitlines():
|
||||
line = re.sub(f"{dynamic_input_size}x", "?x", line)
|
||||
if "?x" in line:
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
|
||||
line = re.sub(f" {dynamic_input_size},", " %dim,", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub(f"c{dynamic_input_size}", "dim", line)
|
||||
new_lines.append(line)
|
||||
new_module = "\n".join(new_lines)
|
||||
return new_module
|
||||
|
||||
|
||||
def write_in_dynamic_inputs1(module, dynamic_input_size):
|
||||
new_lines = []
|
||||
for line in module.splitlines():
|
||||
if "dim_42 =" in line:
|
||||
continue
|
||||
if f"%c{dynamic_input_size}_i64 =" in line:
|
||||
new_lines.append(
|
||||
"%dim_42 = tensor.dim %arg1, %c3 : tensor<1x1x1x?xf32>"
|
||||
)
|
||||
new_lines.append(
|
||||
f"%dim_42_i64 = arith.index_cast %dim_42 : index to i64"
|
||||
)
|
||||
continue
|
||||
line = re.sub(f"{dynamic_input_size}x", "?x", line)
|
||||
if "?x" in line:
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim_42)", line)
|
||||
line = re.sub(f" {dynamic_input_size},", " %dim_42,", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim_42\)",
|
||||
"tensor.empty(%dim_42, %dim_42)",
|
||||
line,
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub(f"c{dynamic_input_size}", "dim_42", line)
|
||||
new_lines.append(line)
|
||||
new_module = "\n".join(new_lines)
|
||||
return new_module
|
||||
|
||||
|
||||
def compile_vicuna_layer(
|
||||
vicuna_layer,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value0=None,
|
||||
past_key_value1=None,
|
||||
):
|
||||
hidden_states_placeholder = TensorPlaceholder.like(
|
||||
hidden_states, dynamic_axes=[1]
|
||||
)
|
||||
attention_mask_placeholder = TensorPlaceholder.like(
|
||||
attention_mask, dynamic_axes=[2, 3]
|
||||
)
|
||||
position_ids_placeholder = TensorPlaceholder.like(
|
||||
position_ids, dynamic_axes=[1]
|
||||
)
|
||||
|
||||
if past_key_value0 is None and past_key_value1 is None:
|
||||
fx_g = make_fx(
|
||||
vicuna_layer,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
)(hidden_states, attention_mask, position_ids)
|
||||
|
||||
else:
|
||||
fx_g = make_fx(
|
||||
vicuna_layer,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
)(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value0,
|
||||
past_key_value1,
|
||||
)
|
||||
|
||||
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
|
||||
removed_indexes = []
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, (list, tuple)):
|
||||
node_arg = list(node_arg)
|
||||
node_args_len = len(node_arg)
|
||||
for i in range(node_args_len):
|
||||
curr_index = node_args_len - (i + 1)
|
||||
if node_arg[curr_index] is None:
|
||||
removed_indexes.append(curr_index)
|
||||
node_arg.pop(curr_index)
|
||||
node.args = (tuple(node_arg),)
|
||||
break
|
||||
|
||||
if len(removed_indexes) > 0:
|
||||
fx_g.graph.lint()
|
||||
fx_g.graph.eliminate_dead_code()
|
||||
fx_g.recompile()
|
||||
removed_indexes.sort()
|
||||
return removed_indexes
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""
|
||||
Replace tuple with tuple element in functions that return one-element tuples.
|
||||
Returns true if an unwrapping took place, and false otherwise.
|
||||
"""
|
||||
unwrapped_tuple = False
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
unwrapped_tuple = True
|
||||
break
|
||||
|
||||
if unwrapped_tuple:
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return unwrapped_tuple
|
||||
|
||||
def transform_fx(fx_g):
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
if node.target in [
|
||||
torch.ops.aten.empty,
|
||||
]:
|
||||
# aten.empty should be filled with zeros.
|
||||
if node.target in [torch.ops.aten.empty]:
|
||||
with fx_g.graph.inserting_after(node):
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten.zero_,
|
||||
args=(node,),
|
||||
)
|
||||
node.append(new_node)
|
||||
node.replace_all_uses_with(new_node)
|
||||
new_node.args = (node,)
|
||||
|
||||
fx_g.graph.lint()
|
||||
|
||||
transform_fx(fx_g)
|
||||
fx_g.recompile()
|
||||
removed_none_indexes = _remove_nones(fx_g)
|
||||
was_unwrapped = _unwrap_single_tuple_return(fx_g)
|
||||
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
|
||||
print("FX_G recompile")
|
||||
|
||||
def strip_overloads(gm):
|
||||
"""
|
||||
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
|
||||
Args:
|
||||
gm(fx.GraphModule): The input Fx graph module to be modified
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if isinstance(node.target, torch._ops.OpOverload):
|
||||
node.target = node.target.overloadpacket
|
||||
gm.recompile()
|
||||
|
||||
strip_overloads(fx_g)
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
return ts_g
|
||||
|
||||
|
||||
path = "TheBloke/vicuna-7B-1.1-HF"
|
||||
kwargs = {"torch_dtype": torch.float}
|
||||
vicuna_model = AutoModelForCausalLM.from_pretrained(path, **kwargs)
|
||||
tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
|
||||
|
||||
|
||||
def compile_to_vmfb(inputs, layers, is_first=True):
|
||||
mlirs, modules = [], []
|
||||
for idx, layer in tqdm(enumerate(layers), desc="Getting mlirs"):
|
||||
if is_first:
|
||||
mlir_path = Path(f"{idx}_0.mlir")
|
||||
vmfb_path = Path(f"{idx}_0.vmfb")
|
||||
else:
|
||||
mlir_path = Path(f"{idx}_1.mlir")
|
||||
vmfb_path = Path(f"{idx}_1.vmfb")
|
||||
if vmfb_path.exists():
|
||||
continue
|
||||
if mlir_path.exists():
|
||||
# print(f"Found layer {idx} mlir")
|
||||
f_ = open(mlir_path, "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
else:
|
||||
hidden_states_placeholder = TensorPlaceholder.like(
|
||||
inputs[0], dynamic_axes=[1]
|
||||
)
|
||||
attention_mask_placeholder = TensorPlaceholder.like(
|
||||
inputs[1], dynamic_axes=[3]
|
||||
)
|
||||
position_ids_placeholder = TensorPlaceholder.like(
|
||||
inputs[2], dynamic_axes=[1]
|
||||
)
|
||||
if not is_first:
|
||||
pkv0_placeholder = TensorPlaceholder.like(
|
||||
inputs[3], dynamic_axes=[2]
|
||||
)
|
||||
pkv1_placeholder = TensorPlaceholder.like(
|
||||
inputs[4], dynamic_axes=[2]
|
||||
)
|
||||
print(f"Compiling layer {idx} mlir")
|
||||
if is_first:
|
||||
ts_g = compile_vicuna_layer(
|
||||
layer, inputs[0], inputs[1], inputs[2]
|
||||
)
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
(
|
||||
hidden_states_placeholder,
|
||||
inputs[1],
|
||||
inputs[2],
|
||||
),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
else:
|
||||
ts_g = compile_vicuna_layer(
|
||||
layer,
|
||||
inputs[0],
|
||||
inputs[1],
|
||||
inputs[2],
|
||||
inputs[3],
|
||||
inputs[4],
|
||||
)
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
(
|
||||
inputs[0],
|
||||
attention_mask_placeholder,
|
||||
inputs[2],
|
||||
pkv0_placeholder,
|
||||
pkv1_placeholder,
|
||||
),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# bytecode_stream = BytesIO()
|
||||
# module.operation.write_bytecode(bytecode_stream)
|
||||
# bytecode = bytecode_stream.getvalue()
|
||||
|
||||
if is_first:
|
||||
module = write_in_dynamic_inputs0(str(module), 137)
|
||||
bytecode = module.encode("UTF-8")
|
||||
bytecode_stream = BytesIO(bytecode)
|
||||
bytecode = bytecode_stream.read()
|
||||
|
||||
else:
|
||||
module = write_in_dynamic_inputs1(str(module), 138)
|
||||
if idx in [0, 5, 6, 7]:
|
||||
module_str = module
|
||||
module_str = module_str.splitlines()
|
||||
new_lines = []
|
||||
for line in module_str:
|
||||
if len(line) < 1000:
|
||||
new_lines.append(line)
|
||||
else:
|
||||
new_lines.append(line[:999])
|
||||
module_str = "\n".join(new_lines)
|
||||
f1_ = open(f"{idx}_1_test.mlir", "w+")
|
||||
f1_.write(module_str)
|
||||
f1_.close()
|
||||
|
||||
bytecode = module.encode("UTF-8")
|
||||
bytecode_stream = BytesIO(bytecode)
|
||||
bytecode = bytecode_stream.read()
|
||||
|
||||
f_ = open(mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
f_.close()
|
||||
mlirs.append(bytecode)
|
||||
|
||||
for idx, layer in tqdm(enumerate(layers), desc="compiling modules"):
|
||||
if is_first:
|
||||
vmfb_path = Path(f"{idx}_0.vmfb")
|
||||
if idx < 25:
|
||||
device = "cpu"
|
||||
else:
|
||||
device = "cpu"
|
||||
if vmfb_path.exists():
|
||||
# print(f"Found layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
None, device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
else:
|
||||
print(f"Compiling layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
mlirs[idx], device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
module.save_module(
|
||||
module_name=f"{idx}_0",
|
||||
extra_args=[
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
],
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
modules.append(module)
|
||||
else:
|
||||
vmfb_path = Path(f"{idx}_1.vmfb")
|
||||
if idx < 25:
|
||||
device = "vulkan"
|
||||
else:
|
||||
device = "cpu"
|
||||
if vmfb_path.exists():
|
||||
# print(f"Found layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
None, device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
else:
|
||||
print(f"Compiling layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
mlirs[idx], device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
module.save_module(
|
||||
module_name=f"{idx}_1",
|
||||
extra_args=[
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
],
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
modules.append(module)
|
||||
|
||||
return mlirs, modules
|
||||
|
||||
|
||||
def get_sharded_model():
|
||||
# SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an increadibly hacky proccess
|
||||
# please don't change it
|
||||
SAMPLE_INPUT_LEN = 137
|
||||
global vicuna_model
|
||||
|
||||
placeholder_input0 = (
|
||||
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
|
||||
torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]),
|
||||
torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64),
|
||||
)
|
||||
|
||||
placeholder_input1 = (
|
||||
torch.zeros([1, 1, 4096]),
|
||||
torch.zeros([1, 1, 1, SAMPLE_INPUT_LEN + 1]),
|
||||
torch.zeros([1, 1], dtype=torch.int64),
|
||||
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
|
||||
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
|
||||
)
|
||||
|
||||
layers0 = [FirstVicunaLayer(layer) for layer in vicuna_model.model.layers]
|
||||
_, modules0 = compile_to_vmfb(placeholder_input0, layers0, is_first=True)
|
||||
shark_layers0 = [CompiledFirstVicunaLayer(m) for m in modules0]
|
||||
|
||||
layers1 = [SecondVicunaLayer(layer) for layer in vicuna_model.model.layers]
|
||||
_, modules1 = compile_to_vmfb(placeholder_input1, layers1, is_first=False)
|
||||
shark_layers1 = [CompiledSecondVicunaLayer(m) for m in modules1]
|
||||
|
||||
sharded_model = ShardedVicunaModel(
|
||||
vicuna_model, shark_layers0, shark_layers1
|
||||
)
|
||||
return sharded_model
|
||||
|
||||
|
||||
sharded_model = get_sharded_model()
|
||||
|
||||
|
||||
def user(message, history):
|
||||
print("msg=", message)
|
||||
print("history=", history)
|
||||
# Append the user's message to the conversation history
|
||||
return "", history + [[message, ""]]
|
||||
|
||||
|
||||
def chat(curr_system_message, history):
|
||||
global sharded_model
|
||||
past_key_values = None
|
||||
messages = curr_system_message + "".join(
|
||||
[
|
||||
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
|
||||
for item in history
|
||||
]
|
||||
)
|
||||
print(messages)
|
||||
prompt = messages.strip()
|
||||
input_ids = tokenizer(prompt).input_ids
|
||||
tokens = input_ids
|
||||
new_sentence = []
|
||||
max_response_len = 1000
|
||||
partial_sentence = []
|
||||
partial_text = ""
|
||||
start_time = time.time()
|
||||
for iteration in range(max_response_len):
|
||||
original_input_ids = input_ids
|
||||
input_id_len = len(input_ids)
|
||||
input_ids = torch.tensor(input_ids)
|
||||
input_ids = input_ids.reshape([1, input_id_len])
|
||||
|
||||
if iteration == 0:
|
||||
output = sharded_model.forward(input_ids, is_first=True)
|
||||
else:
|
||||
output = sharded_model.forward(
|
||||
input_ids, past_key_values=past_key_values, is_first=False
|
||||
)
|
||||
logits = output["logits"]
|
||||
past_key_values = output["past_key_values"]
|
||||
new_token = int(torch.argmax(logits[:, -1, :], dim=1)[0])
|
||||
if new_token == 2:
|
||||
break
|
||||
new_sentence += [new_token]
|
||||
partial_sentence += [new_token]
|
||||
if iteration > 0 and iteration % 2 == 0:
|
||||
new_text = tokenizer.decode(partial_sentence)
|
||||
partial_sentence = []
|
||||
print(new_text, " ")
|
||||
partial_text += new_text + " "
|
||||
history[-1][1] = partial_text
|
||||
yield history
|
||||
|
||||
tokens.append(new_token)
|
||||
original_input_ids.append(new_token)
|
||||
input_ids = [new_token]
|
||||
end_time = time.time()
|
||||
print(
|
||||
f"Total time taken to generated response is {end_time-start_time} seconds"
|
||||
)
|
||||
|
||||
for i in range(len(tokens)):
|
||||
if type(tokens[i]) != int:
|
||||
tokens[i] = int(tokens[i][0])
|
||||
new_sentence_str = tokenizer.decode(new_sentence)
|
||||
print(new_sentence_str)
|
||||
history[-1][1] = new_sentence_str
|
||||
return history
|
||||
|
||||
|
||||
system_msg = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
|
||||
# history_eg = [["hi hello how are you", ""]]
|
||||
# print(chat(system_msg, history_eg))
|
||||
|
||||
with gr.Blocks(title="Chatbot") as vicuna_chat:
|
||||
with gr.Row():
|
||||
model = gr.Dropdown(
|
||||
label="Select Model",
|
||||
value="TheBloke/vicuna-7B-1.1-HF",
|
||||
choices=[
|
||||
"TheBloke/vicuna-7B-1.1-HF",
|
||||
],
|
||||
)
|
||||
device_value = None
|
||||
for d in available_devices:
|
||||
if "vulkan" in d:
|
||||
device_value = d
|
||||
break
|
||||
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=device_value if device_value else available_devices[0],
|
||||
interactive=False,
|
||||
choices=available_devices,
|
||||
)
|
||||
chatbot = gr.Chatbot().style(height=500)
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
msg = gr.Textbox(
|
||||
label="Chat Message Box",
|
||||
placeholder="Chat Message Box",
|
||||
show_label=False,
|
||||
).style(container=False)
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
submit = gr.Button("Submit")
|
||||
stop = gr.Button("Stop")
|
||||
clear = gr.Button("Clear")
|
||||
system_msg = gr.Textbox(
|
||||
system_msg, label="System Message", interactive=False, visible=False
|
||||
)
|
||||
|
||||
submit_event = msg.submit(
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot],
|
||||
outputs=[chatbot],
|
||||
queue=True,
|
||||
)
|
||||
submit_click_event = submit.click(
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot],
|
||||
outputs=[chatbot],
|
||||
queue=True,
|
||||
)
|
||||
stop.click(
|
||||
fn=None,
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
cancels=[submit_event, submit_click_event],
|
||||
queue=False,
|
||||
)
|
||||
clear.click(lambda: None, None, [chatbot], queue=False)
|
||||
|
||||
import argparse
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
p.add_argument(
|
||||
"--share",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for generating a public URL",
|
||||
)
|
||||
p.add_argument(
|
||||
"--server_port",
|
||||
type=int,
|
||||
default=8080,
|
||||
help="flag for setting server port",
|
||||
)
|
||||
args, unknown = p.parse_known_args()
|
||||
|
||||
vicuna_chat.queue()
|
||||
vicuna_chat.launch(
|
||||
share=args.share,
|
||||
inbrowser=True,
|
||||
server_name="0.0.0.0",
|
||||
server_port=args.server_port,
|
||||
)
|
||||
15
apps/language_models/src/model_wrappers/stablelm_model.py
Normal file
15
apps/language_models/src/model_wrappers/stablelm_model.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import torch
|
||||
|
||||
|
||||
class StableLMModel(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
combine_input_dict = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
output = self.model(**combine_input_dict)
|
||||
return output.logits
|
||||
239
apps/language_models/src/model_wrappers/vicuna_model.py
Normal file
239
apps/language_models/src/model_wrappers/vicuna_model.py
Normal file
@@ -0,0 +1,239 @@
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
|
||||
class FirstVicuna(torch.nn.Module):
|
||||
def __init__(self, model_path):
|
||||
super().__init__()
|
||||
kwargs = {"torch_dtype": torch.float32}
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
|
||||
def forward(self, input_ids):
|
||||
op = self.model(input_ids=input_ids, use_cache=True)
|
||||
return_vals = []
|
||||
return_vals.append(op.logits)
|
||||
temp_past_key_values = op.past_key_values
|
||||
for item in temp_past_key_values:
|
||||
return_vals.append(item[0])
|
||||
return_vals.append(item[1])
|
||||
return tuple(return_vals)
|
||||
|
||||
|
||||
class SecondVicuna(torch.nn.Module):
|
||||
def __init__(self, model_path):
|
||||
super().__init__()
|
||||
kwargs = {"torch_dtype": torch.float32}
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
i0,
|
||||
i1,
|
||||
i2,
|
||||
i3,
|
||||
i4,
|
||||
i5,
|
||||
i6,
|
||||
i7,
|
||||
i8,
|
||||
i9,
|
||||
i10,
|
||||
i11,
|
||||
i12,
|
||||
i13,
|
||||
i14,
|
||||
i15,
|
||||
i16,
|
||||
i17,
|
||||
i18,
|
||||
i19,
|
||||
i20,
|
||||
i21,
|
||||
i22,
|
||||
i23,
|
||||
i24,
|
||||
i25,
|
||||
i26,
|
||||
i27,
|
||||
i28,
|
||||
i29,
|
||||
i30,
|
||||
i31,
|
||||
i32,
|
||||
i33,
|
||||
i34,
|
||||
i35,
|
||||
i36,
|
||||
i37,
|
||||
i38,
|
||||
i39,
|
||||
i40,
|
||||
i41,
|
||||
i42,
|
||||
i43,
|
||||
i44,
|
||||
i45,
|
||||
i46,
|
||||
i47,
|
||||
i48,
|
||||
i49,
|
||||
i50,
|
||||
i51,
|
||||
i52,
|
||||
i53,
|
||||
i54,
|
||||
i55,
|
||||
i56,
|
||||
i57,
|
||||
i58,
|
||||
i59,
|
||||
i60,
|
||||
i61,
|
||||
i62,
|
||||
i63,
|
||||
i64,
|
||||
):
|
||||
# input_ids = input_tuple[0]
|
||||
# input_tuple = torch.unbind(pkv, dim=0)
|
||||
token = i0
|
||||
past_key_values = (
|
||||
(i1, i2),
|
||||
(
|
||||
i3,
|
||||
i4,
|
||||
),
|
||||
(
|
||||
i5,
|
||||
i6,
|
||||
),
|
||||
(
|
||||
i7,
|
||||
i8,
|
||||
),
|
||||
(
|
||||
i9,
|
||||
i10,
|
||||
),
|
||||
(
|
||||
i11,
|
||||
i12,
|
||||
),
|
||||
(
|
||||
i13,
|
||||
i14,
|
||||
),
|
||||
(
|
||||
i15,
|
||||
i16,
|
||||
),
|
||||
(
|
||||
i17,
|
||||
i18,
|
||||
),
|
||||
(
|
||||
i19,
|
||||
i20,
|
||||
),
|
||||
(
|
||||
i21,
|
||||
i22,
|
||||
),
|
||||
(
|
||||
i23,
|
||||
i24,
|
||||
),
|
||||
(
|
||||
i25,
|
||||
i26,
|
||||
),
|
||||
(
|
||||
i27,
|
||||
i28,
|
||||
),
|
||||
(
|
||||
i29,
|
||||
i30,
|
||||
),
|
||||
(
|
||||
i31,
|
||||
i32,
|
||||
),
|
||||
(
|
||||
i33,
|
||||
i34,
|
||||
),
|
||||
(
|
||||
i35,
|
||||
i36,
|
||||
),
|
||||
(
|
||||
i37,
|
||||
i38,
|
||||
),
|
||||
(
|
||||
i39,
|
||||
i40,
|
||||
),
|
||||
(
|
||||
i41,
|
||||
i42,
|
||||
),
|
||||
(
|
||||
i43,
|
||||
i44,
|
||||
),
|
||||
(
|
||||
i45,
|
||||
i46,
|
||||
),
|
||||
(
|
||||
i47,
|
||||
i48,
|
||||
),
|
||||
(
|
||||
i49,
|
||||
i50,
|
||||
),
|
||||
(
|
||||
i51,
|
||||
i52,
|
||||
),
|
||||
(
|
||||
i53,
|
||||
i54,
|
||||
),
|
||||
(
|
||||
i55,
|
||||
i56,
|
||||
),
|
||||
(
|
||||
i57,
|
||||
i58,
|
||||
),
|
||||
(
|
||||
i59,
|
||||
i60,
|
||||
),
|
||||
(
|
||||
i61,
|
||||
i62,
|
||||
),
|
||||
(
|
||||
i63,
|
||||
i64,
|
||||
),
|
||||
)
|
||||
op = self.model(
|
||||
input_ids=token, use_cache=True, past_key_values=past_key_values
|
||||
)
|
||||
return_vals = []
|
||||
return_vals.append(op.logits)
|
||||
temp_past_key_values = op.past_key_values
|
||||
for item in temp_past_key_values:
|
||||
return_vals.append(item[0])
|
||||
return_vals.append(item[1])
|
||||
return tuple(return_vals)
|
||||
178
apps/language_models/src/model_wrappers/vicuna_sharded_model.py
Normal file
178
apps/language_models/src/model_wrappers/vicuna_sharded_model.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import torch
|
||||
|
||||
|
||||
class FirstVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, hidden_states, attention_mask, position_ids):
|
||||
outputs = self.model(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
use_cache=True,
|
||||
)
|
||||
next_hidden_states = outputs[0]
|
||||
past_key_value_out0, past_key_value_out1 = (
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
|
||||
return (
|
||||
next_hidden_states,
|
||||
past_key_value_out0,
|
||||
past_key_value_out1,
|
||||
)
|
||||
|
||||
|
||||
class SecondVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value0,
|
||||
past_key_value1,
|
||||
):
|
||||
outputs = self.model(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=(
|
||||
past_key_value0,
|
||||
past_key_value1,
|
||||
),
|
||||
use_cache=True,
|
||||
)
|
||||
next_hidden_states = outputs[0]
|
||||
past_key_value_out0, past_key_value_out1 = (
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
|
||||
return (
|
||||
next_hidden_states,
|
||||
past_key_value_out0,
|
||||
past_key_value_out1,
|
||||
)
|
||||
|
||||
|
||||
class CompiledFirstVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, shark_module):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
):
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
output = self.model(
|
||||
"forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
),
|
||||
)
|
||||
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
|
||||
return (
|
||||
output0,
|
||||
(
|
||||
output1,
|
||||
output2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class CompiledSecondVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, shark_module):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
):
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
pkv0 = past_key_value[0].detach()
|
||||
pkv1 = past_key_value[1].detach()
|
||||
output = self.model(
|
||||
"forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
pkv0,
|
||||
pkv1,
|
||||
),
|
||||
)
|
||||
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
|
||||
return (
|
||||
output0,
|
||||
(
|
||||
output1,
|
||||
output2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ShardedVicunaModel(torch.nn.Module):
|
||||
def __init__(self, model, layers0, layers1):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
assert len(layers0) == len(model.model.layers)
|
||||
# self.model.model.layers = torch.nn.modules.container.ModuleList(layers0)
|
||||
self.model.model.config.use_cache = True
|
||||
self.model.model.config.output_attentions = False
|
||||
self.layers0 = layers0
|
||||
self.layers1 = layers1
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
is_first=True,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
):
|
||||
if is_first:
|
||||
self.model.model.layers = torch.nn.modules.container.ModuleList(
|
||||
self.layers0
|
||||
)
|
||||
return self.model.forward(input_ids, attention_mask=attention_mask)
|
||||
else:
|
||||
self.model.model.layers = torch.nn.modules.container.ModuleList(
|
||||
self.layers1
|
||||
)
|
||||
return self.model.forward(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
41
apps/language_models/src/pipelines/SharkLLMBase.py
Normal file
41
apps/language_models/src/pipelines/SharkLLMBase.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class SharkLLMBase(ABC):
|
||||
def __init__(
|
||||
self, model_name, hf_model_path=None, max_num_tokens=512
|
||||
) -> None:
|
||||
self.model_name = model_name
|
||||
self.hf_model_path = hf_model_path
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.shark_model = None
|
||||
self.device = "cpu"
|
||||
self.precision = "fp32"
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def compile(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def generate(self, prompt):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def generate_new_token(self, params):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_tokenizer(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_src_model(self):
|
||||
pass
|
||||
|
||||
def load_init_from_config(self):
|
||||
pass
|
||||
185
apps/language_models/src/pipelines/stablelm_pipeline.py
Normal file
185
apps/language_models/src/pipelines/stablelm_pipeline.py
Normal file
@@ -0,0 +1,185 @@
|
||||
import torch
|
||||
import torch_mlir
|
||||
from transformers import AutoTokenizer, StoppingCriteria, AutoModelForCausalLM
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from apps.language_models.utils import (
|
||||
get_torch_mlir_module_bytecode,
|
||||
get_vmfb_from_path,
|
||||
)
|
||||
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
|
||||
from apps.language_models.src.model_wrappers.stablelm_model import (
|
||||
StableLMModel,
|
||||
)
|
||||
|
||||
|
||||
class StopOnTokens(StoppingCriteria):
|
||||
def __call__(
|
||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
||||
) -> bool:
|
||||
stop_ids = [50278, 50279, 50277, 1, 0]
|
||||
for stop_id in stop_ids:
|
||||
if input_ids[0][-1] == stop_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class SharkStableLM(SharkLLMBase):
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
hf_model_path="stabilityai/stablelm-tuned-alpha-3b",
|
||||
max_num_tokens=512,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
self.max_sequence_len = 256
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.shark_model = self.compile()
|
||||
|
||||
def shouldStop(self, tokens):
|
||||
stop_ids = [50278, 50279, 50277, 1, 0]
|
||||
for stop_id in stop_ids:
|
||||
if tokens[0][-1] == stop_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_src_model(self):
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, torch_dtype=torch.float32
|
||||
)
|
||||
return model
|
||||
|
||||
def get_model_inputs(self):
|
||||
input_ids = torch.randint(3, (1, self.max_sequence_len))
|
||||
attention_mask = torch.randint(3, (1, self.max_sequence_len))
|
||||
return input_ids, attention_mask
|
||||
|
||||
def compile(self):
|
||||
tmp_model_name = (
|
||||
f"stableLM_linalg_{self.precision}_seqLen{self.max_sequence_len}"
|
||||
)
|
||||
|
||||
# device = "cuda" # "cpu"
|
||||
# TODO: vmfb and mlir name should include precision and device
|
||||
model_vmfb_name = None
|
||||
vmfb_path = (
|
||||
Path(tmp_model_name + f"_{self.device}.vmfb")
|
||||
if model_vmfb_name is None
|
||||
else Path(model_vmfb_name)
|
||||
)
|
||||
shark_module = get_vmfb_from_path(
|
||||
vmfb_path, self.device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
if shark_module is not None:
|
||||
return shark_module
|
||||
|
||||
mlir_path = Path(tmp_model_name + ".mlir")
|
||||
print(
|
||||
f"[DEBUG] mlir path {mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
|
||||
)
|
||||
if mlir_path.exists():
|
||||
with open(mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
model = StableLMModel(self.get_src_model())
|
||||
model_inputs = self.get_model_inputs()
|
||||
ts_graph = get_torch_mlir_module_bytecode(model, model_inputs)
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*model_inputs],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
f_ = open(tmp_model_name + ".mlir", "wb")
|
||||
f_.write(bytecode)
|
||||
print("Saved mlir")
|
||||
f_.close()
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
path = shark_module.save_module(
|
||||
vmfb_path.parent.absolute(), vmfb_path.stem
|
||||
)
|
||||
print("Saved vmfb at ", str(path))
|
||||
|
||||
return shark_module
|
||||
|
||||
def get_tokenizer(self):
|
||||
tok = AutoTokenizer.from_pretrained(self.hf_model_path)
|
||||
tok.add_special_tokens({"pad_token": "<PAD>"})
|
||||
# print("[DEBUG] Sucessfully loaded the tokenizer to the memory")
|
||||
return tok
|
||||
|
||||
def generate(self, prompt):
|
||||
words_list = []
|
||||
for i in range(self.max_num_tokens):
|
||||
params = {
|
||||
"new_text": prompt,
|
||||
}
|
||||
|
||||
generated_token_op = self.generate_new_token(params)
|
||||
|
||||
detok = generated_token_op["detok"]
|
||||
stop_generation = generated_token_op["stop_generation"]
|
||||
|
||||
if stop_generation:
|
||||
break
|
||||
|
||||
print(detok, end="", flush=True) # this is for CLI and DEBUG
|
||||
words_list.append(detok)
|
||||
if detok == "":
|
||||
break
|
||||
prompt = prompt + detok
|
||||
return words_list
|
||||
|
||||
def generate_new_token(self, params):
|
||||
new_text = params["new_text"]
|
||||
model_inputs = self.tokenizer(
|
||||
[new_text],
|
||||
padding="max_length",
|
||||
max_length=self.max_sequence_len,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
sum_attentionmask = torch.sum(model_inputs.attention_mask)
|
||||
output = self.shark_model(
|
||||
"forward", [model_inputs.input_ids, model_inputs.attention_mask]
|
||||
)
|
||||
output = torch.from_numpy(output)
|
||||
next_toks = torch.topk(output, 1)
|
||||
stop_generation = False
|
||||
if self.shouldStop(next_toks.indices):
|
||||
stop_generation = True
|
||||
new_token = next_toks.indices[0][int(sum_attentionmask) - 1]
|
||||
detok = self.tokenizer.decode(
|
||||
new_token,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
ret_dict = {
|
||||
"new_token": new_token,
|
||||
"detok": detok,
|
||||
"stop_generation": stop_generation,
|
||||
}
|
||||
return ret_dict
|
||||
|
||||
|
||||
# Initialize a StopOnTokens object
|
||||
system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
|
||||
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
|
||||
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
|
||||
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
|
||||
- StableLM will refuse to participate in anything that could harm a human.
|
||||
"""
|
||||
647
apps/language_models/src/pipelines/vicuna_pipeline.py
Normal file
647
apps/language_models/src/pipelines/vicuna_pipeline.py
Normal file
@@ -0,0 +1,647 @@
|
||||
from apps.language_models.src.model_wrappers.vicuna_model import (
|
||||
FirstVicuna,
|
||||
SecondVicuna,
|
||||
)
|
||||
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
|
||||
from apps.language_models.utils import (
|
||||
get_torch_mlir_module_bytecode,
|
||||
get_vmfb_from_path,
|
||||
)
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_inference import SharkInference
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
import re
|
||||
import torch
|
||||
import torch_mlir
|
||||
import os
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="vicuna runner",
|
||||
description="runs a vicuna model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--precision", "-p", default="fp32", help="fp32, fp16, int8, int4"
|
||||
)
|
||||
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
|
||||
parser.add_argument(
|
||||
"--first_vicuna_vmfb_path", default=None, help="path to first vicuna vmfb"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--second_vicuna_vmfb_path",
|
||||
default=None,
|
||||
help="path to second vicuna vmfb",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--first_vicuna_mlir_path",
|
||||
default=None,
|
||||
help="path to first vicuna mlir file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--second_vicuna_mlir_path",
|
||||
default=None,
|
||||
help="path to second vicuna mlir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--load_mlir_from_shark_tank",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="download precompile mlir from shark tank",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cli",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Run model in cli mode",
|
||||
)
|
||||
|
||||
|
||||
class Vicuna(SharkLLMBase):
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
|
||||
max_num_tokens=512,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
first_vicuna_mlir_path=Path("first_vicuna.mlir"),
|
||||
second_vicuna_mlir_path=Path("second_vicuna.mlir"),
|
||||
first_vicuna_vmfb_path=Path("first_vicuna.vmfb"),
|
||||
second_vicuna_vmfb_path=Path("second_vicuna.vmfb"),
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
self.max_sequence_length = 256
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.first_vicuna_vmfb_path = first_vicuna_vmfb_path
|
||||
self.second_vicuna_vmfb_path = second_vicuna_vmfb_path
|
||||
self.first_vicuna_mlir_path = first_vicuna_mlir_path
|
||||
self.second_vicuna_mlir_path = second_vicuna_mlir_path
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.shark_model = self.compile()
|
||||
|
||||
def get_tokenizer(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hf_model_path, use_fast=False
|
||||
)
|
||||
return tokenizer
|
||||
|
||||
def get_src_model(self):
|
||||
kwargs = {"torch_dtype": torch.float}
|
||||
vicuna_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, **kwargs
|
||||
)
|
||||
return vicuna_model
|
||||
|
||||
def compile_first_vicuna(self):
|
||||
vmfb = get_vmfb_from_path(
|
||||
self.first_vicuna_vmfb_path, self.device, "tm_tensor"
|
||||
)
|
||||
if vmfb is not None:
|
||||
return vmfb
|
||||
|
||||
# Compilation path needs some more work before it is functional
|
||||
|
||||
print(
|
||||
f"[DEBUG] vmfb not found at {self.first_vicuna_vmfb_path.absolute()}. Trying to work with"
|
||||
f"[DEBUG] mlir path { self.first_vicuna_mlir_path} {'exists' if self.first_vicuna_mlir_path.exists() else 'does not exist'}"
|
||||
)
|
||||
if self.first_vicuna_mlir_path.exists():
|
||||
with open(self.first_vicuna_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
mlir_generated = False
|
||||
if args.load_mlir_from_shark_tank:
|
||||
if self.precision == "fp32":
|
||||
# download MLIR from shark_tank for fp32
|
||||
download_public_file(
|
||||
"gs://shark_tank/vicuna/unsharded/mlir/second_vicuna.mlir",
|
||||
self.first_vicuna_mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if self.first_vicuna_mlir_path.exists():
|
||||
with open(self.first_vicuna_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
mlir_generated = True
|
||||
else:
|
||||
raise ValueError(
|
||||
f"MLIR not found at {self.first_vicuna_mlir_path.absolute()}"
|
||||
" after downloading! Please check path and try again"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"Only fp32 mlir added to tank, generating mlir on device."
|
||||
)
|
||||
|
||||
if not mlir_generated:
|
||||
compilation_prompt = "".join(["0" for _ in range(17)])
|
||||
compilation_input_ids = self.tokenizer(
|
||||
compilation_prompt
|
||||
).input_ids
|
||||
compilation_input_ids = torch.tensor(
|
||||
compilation_input_ids
|
||||
).reshape([1, 19])
|
||||
firstVicunaCompileInput = (compilation_input_ids,)
|
||||
model = FirstVicuna(self.hf_model_path)
|
||||
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
ts_graph = get_torch_mlir_module_bytecode(
|
||||
model, firstVicunaCompileInput
|
||||
)
|
||||
del model
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
|
||||
firstVicunaCompileInput = list(firstVicunaCompileInput)
|
||||
firstVicunaCompileInput[0] = torch_mlir.TensorPlaceholder.like(
|
||||
firstVicunaCompileInput[0], dynamic_axes=[1]
|
||||
)
|
||||
firstVicunaCompileInput = tuple(firstVicunaCompileInput)
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*firstVicunaCompileInput],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
del ts_graph
|
||||
|
||||
def remove_constant_dim(line):
|
||||
if "19x" in line:
|
||||
line = re.sub("19x", "?x", line)
|
||||
line = re.sub(
|
||||
"tensor.empty\(\)", "tensor.empty(%dim)", line
|
||||
)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim\)",
|
||||
"tensor.empty(%dim, %dim)",
|
||||
line,
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub("c19", "dim", line)
|
||||
if " 19," in line:
|
||||
line = re.sub(" 19,", " %dim,", line)
|
||||
return line
|
||||
|
||||
module = str(module)
|
||||
new_lines = []
|
||||
|
||||
print(f"[DEBUG] rewriting torch_mlir file")
|
||||
for line in module.splitlines():
|
||||
line = remove_constant_dim(line)
|
||||
if "%0 = tensor.empty(%dim) : tensor<?xi64>" in line:
|
||||
new_lines.append(
|
||||
"%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>"
|
||||
)
|
||||
if (
|
||||
"%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>"
|
||||
in line
|
||||
):
|
||||
continue
|
||||
|
||||
new_lines.append(line)
|
||||
|
||||
module = "\n".join(new_lines)
|
||||
|
||||
print(f"[DEBUG] converting to bytecode")
|
||||
del new_lines
|
||||
module = module.encode("UTF-8")
|
||||
module = BytesIO(module)
|
||||
bytecode = module.read()
|
||||
del module
|
||||
|
||||
print(f"[DEBUG] writing mlir to file")
|
||||
f_ = open(f"{self.model_name}.mlir", "wb")
|
||||
f_.write(bytecode)
|
||||
f_.close()
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
vmfb_name = "first_" + self.model_name
|
||||
path = shark_module.save_module(
|
||||
os.getcwd(),
|
||||
vmfb_name,
|
||||
extra_args=[
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
],
|
||||
)
|
||||
print("Saved first vic vmfb at vmfb at ", str(path))
|
||||
shark_module.load_module(path)
|
||||
|
||||
return shark_module
|
||||
|
||||
def compile_second_vicuna(self):
|
||||
vmfb = get_vmfb_from_path(
|
||||
self.second_vicuna_vmfb_path, self.device, "tm_tensor"
|
||||
)
|
||||
if vmfb is not None:
|
||||
return vmfb
|
||||
|
||||
# Compilation path needs some more work before it is functional
|
||||
print(
|
||||
f"[DEBUG] mlir path {self.second_vicuna_mlir_path} {'exists' if self.second_vicuna_mlir_path.exists() else 'does not exist'}"
|
||||
)
|
||||
if self.second_vicuna_mlir_path.exists():
|
||||
with open(self.second_vicuna_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
mlir_generated = False
|
||||
if args.load_mlir_from_shark_tank:
|
||||
if self.precision == "fp32":
|
||||
# download MLIR from shark_tank for fp32
|
||||
download_public_file(
|
||||
"gs://shark_tank/vicuna/unsharded/mlir/second_vicuna.mlir",
|
||||
self.second_vicuna_mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if self.second_vicuna_mlir_path.exists():
|
||||
with open(self.second_vicuna_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
mlir_generated = True
|
||||
else:
|
||||
raise ValueError(
|
||||
f"MLIR not found at {self.second_vicuna_mlir_path.absolute()}"
|
||||
" after downloading! Please check path and try again"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"Only fp32 mlir added to tank, generating mlir on device."
|
||||
)
|
||||
|
||||
if not mlir_generated:
|
||||
compilation_input_ids = torch.zeros([1, 1], dtype=torch.int64)
|
||||
pkv = tuple(
|
||||
(torch.zeros([1, 32, 19, 128], dtype=torch.float32))
|
||||
for _ in range(64)
|
||||
)
|
||||
secondVicunaCompileInput = (compilation_input_ids,) + pkv
|
||||
model = SecondVicuna(self.hf_model_path)
|
||||
ts_graph = get_torch_mlir_module_bytecode(
|
||||
model, secondVicunaCompileInput
|
||||
)
|
||||
secondVicunaCompileInput = list(secondVicunaCompileInput)
|
||||
for i in range(len(secondVicunaCompileInput)):
|
||||
if i != 0:
|
||||
secondVicunaCompileInput[
|
||||
i
|
||||
] = torch_mlir.TensorPlaceholder.like(
|
||||
secondVicunaCompileInput[i], dynamic_axes=[2]
|
||||
)
|
||||
secondVicunaCompileInput = tuple(secondVicunaCompileInput)
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*secondVicunaCompileInput],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
def remove_constant_dim(line):
|
||||
if "c19_i64" in line:
|
||||
line = re.sub("c19_i64", "dim_i64", line)
|
||||
if "19x" in line:
|
||||
line = re.sub("19x", "?x", line)
|
||||
line = re.sub(
|
||||
"tensor.empty\(\)", "tensor.empty(%dim)", line
|
||||
)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim\)",
|
||||
"tensor.empty(%dim, %dim)",
|
||||
line,
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub("c19", "dim", line)
|
||||
if " 19," in line:
|
||||
line = re.sub(" 19,", " %dim,", line)
|
||||
if "20x" in line:
|
||||
line = re.sub("20x", "?x", line)
|
||||
line = re.sub(
|
||||
"tensor.empty\(\)", "tensor.empty(%dimp1)", line
|
||||
)
|
||||
if " 20," in line:
|
||||
line = re.sub(" 20,", " %dimp1,", line)
|
||||
return line
|
||||
|
||||
module_str = str(module)
|
||||
new_lines = []
|
||||
|
||||
for line in module_str.splitlines():
|
||||
if "%c19_i64 = arith.constant 19 : i64" in line:
|
||||
new_lines.append("%c2 = arith.constant 2 : index")
|
||||
new_lines.append(
|
||||
"%dim_4_int = tensor.dim %arg1, %c2 : tensor<1x32x?x128xf32>"
|
||||
)
|
||||
new_lines.append(
|
||||
"%dim_i64 = arith.index_cast %dim_4_int : index to i64"
|
||||
)
|
||||
continue
|
||||
if "%c2 = arith.constant 2 : index" in line:
|
||||
continue
|
||||
if "%c20_i64 = arith.constant 20 : i64" in line:
|
||||
new_lines.append("%c1_i64 = arith.constant 1 : i64")
|
||||
new_lines.append(
|
||||
"%c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
|
||||
)
|
||||
new_lines.append(
|
||||
"%dimp1 = arith.index_cast %c20_i64 : i64 to index"
|
||||
)
|
||||
continue
|
||||
line = remove_constant_dim(line)
|
||||
new_lines.append(line)
|
||||
|
||||
module_str = "\n".join(new_lines)
|
||||
bytecode = module_str.encode("UTF-8")
|
||||
bytecode_stream = BytesIO(bytecode)
|
||||
bytecode = bytecode_stream.read()
|
||||
f_ = open(f"{self.model_name}.mlir", "wb")
|
||||
f_.write(bytecode)
|
||||
f_.close()
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
|
||||
path = shark_module.save_module(
|
||||
os.getcwd(),
|
||||
self.model_name,
|
||||
extra_args=[
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
],
|
||||
)
|
||||
print("Saved vmfb at ", str(path))
|
||||
shark_module.load_module(self.second_vicuna_vmfb_path)
|
||||
|
||||
# self.shark_module = shark_module
|
||||
|
||||
return shark_module
|
||||
|
||||
def compile(self):
|
||||
# Cannot load both the models in the memory at once
|
||||
# due to memory constraints, hence on demand compilation
|
||||
# is being used until the space is enough for both models
|
||||
|
||||
# Testing : DO NOT Download Vmfbs if not found. Modify later
|
||||
# download vmfbs for A100
|
||||
if (
|
||||
not self.first_vicuna_vmfb_path.exists()
|
||||
and self.device == "cuda"
|
||||
and self.precision == "fp32"
|
||||
):
|
||||
download_public_file(
|
||||
"gs://shark_tank/vicuna/unsharded/first_vicuna.vmfb",
|
||||
self.first_vicuna_vmfb_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
else:
|
||||
# get first vic
|
||||
# TODO: Remove after testing to avoid memory overload
|
||||
# fvic_shark_model = self.compile_first_vicuna()
|
||||
pass
|
||||
if (
|
||||
not self.second_vicuna_vmfb_path.exists()
|
||||
and self.device == "cuda"
|
||||
and self.precision == "fp32"
|
||||
):
|
||||
download_public_file(
|
||||
"gs://shark_tank/vicuna/unsharded/second_vicuna.vmfb",
|
||||
self.second_vicuna_vmfb_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
else:
|
||||
# get second vic
|
||||
# TODO: Remove after testing to avoid memory overload
|
||||
# svic_shark_model = self.compile_second_vicuna()
|
||||
pass
|
||||
|
||||
# get first vic
|
||||
# fvic_shark_model = self.compile_first_vicuna()
|
||||
# get second vic
|
||||
# svic_shark_model = self.compile_second_vicuna()
|
||||
# return tuple of shark_modules
|
||||
# return fvic_shark_model, svic_shark_model
|
||||
return None
|
||||
# return tuple of shark_modules once mem is supported
|
||||
# return fvic_shark_model, svic_shark_model
|
||||
|
||||
def generate(self, prompt):
|
||||
# TODO: refactor for cleaner integration
|
||||
import gc
|
||||
|
||||
res = []
|
||||
res_tokens = []
|
||||
params = {
|
||||
"prompt": prompt,
|
||||
"is_first": True,
|
||||
"fv": self.compile_first_vicuna(),
|
||||
}
|
||||
|
||||
generated_token_op = self.generate_new_token(params=params)
|
||||
|
||||
token = generated_token_op["token"]
|
||||
logits = generated_token_op["logits"]
|
||||
pkv = generated_token_op["pkv"]
|
||||
detok = generated_token_op["detok"]
|
||||
|
||||
res.append(detok)
|
||||
res_tokens.append(token)
|
||||
if args.cli:
|
||||
print(f"Assistant: {detok}", end=" ", flush=True)
|
||||
|
||||
# Clear First Vic from Memory (main and cuda)
|
||||
del params
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
sec_vic = self.compile_second_vicuna()
|
||||
for _ in range(self.max_num_tokens - 2):
|
||||
params = {
|
||||
"prompt": None,
|
||||
"is_first": False,
|
||||
"logits": logits,
|
||||
"pkv": pkv,
|
||||
"sv": sec_vic,
|
||||
}
|
||||
|
||||
generated_token_op = self.generate_new_token(params=params)
|
||||
|
||||
token = generated_token_op["token"]
|
||||
logits = generated_token_op["logits"]
|
||||
pkv = generated_token_op["pkv"]
|
||||
detok = generated_token_op["detok"]
|
||||
|
||||
if token == 2:
|
||||
break
|
||||
res_tokens.append(token)
|
||||
if detok == "<0x0A>":
|
||||
res.append("\n")
|
||||
if args.cli:
|
||||
print("\n", end="", flush=True)
|
||||
else:
|
||||
res.append(detok)
|
||||
if args.cli:
|
||||
print(f"{detok}", end=" ", flush=True)
|
||||
del sec_vic, pkv, logits
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
for i in range(len(res_tokens)):
|
||||
if type(res_tokens[i]) != int:
|
||||
res_tokens[i] = int(res_tokens[i][0])
|
||||
|
||||
res_str = self.tokenizer.decode(res_tokens)
|
||||
# print(f"[DEBUG] final output : \n{res_str}")
|
||||
return res_str
|
||||
|
||||
def generate_new_token(self, params):
|
||||
def forward_first(first_vic, prompt, cache_outputs=False):
|
||||
input_ids = self.tokenizer(prompt).input_ids
|
||||
input_id_len = len(input_ids)
|
||||
input_ids = torch.tensor(input_ids)
|
||||
input_ids = input_ids.reshape([1, input_id_len])
|
||||
firstVicunaInput = (input_ids,)
|
||||
assert first_vic is not None
|
||||
output_first_vicuna = first_vic("forward", firstVicunaInput)
|
||||
output_first_vicuna_tensor = torch.tensor(output_first_vicuna[1:])
|
||||
logits_first_vicuna = torch.tensor(output_first_vicuna[0])
|
||||
if cache_outputs:
|
||||
torch.save(
|
||||
logits_first_vicuna, "logits_first_vicuna_tensor.pt"
|
||||
)
|
||||
torch.save(
|
||||
output_first_vicuna_tensor, "output_first_vicuna_tensor.pt"
|
||||
)
|
||||
token = torch.argmax(
|
||||
torch.tensor(logits_first_vicuna)[:, -1, :], dim=1
|
||||
)
|
||||
return token, logits_first_vicuna, output_first_vicuna_tensor
|
||||
|
||||
def forward_second(sec_vic, inputs=None, load_inputs=False):
|
||||
if inputs is not None:
|
||||
logits = inputs[0]
|
||||
pkv = inputs[1:]
|
||||
elif load_inputs:
|
||||
pkv = torch.load("output_first_vicuna_tensor.pt")
|
||||
pkv = tuple(torch.tensor(x) for x in pkv)
|
||||
logits = torch.load("logits_first_vicuna_tensor.pt")
|
||||
else:
|
||||
print(
|
||||
"Either inputs must be given, or load_inputs must be true"
|
||||
)
|
||||
return None
|
||||
token = torch.argmax(torch.tensor(logits)[:, -1, :], dim=1)
|
||||
token = token.to(torch.int64).reshape([1, 1])
|
||||
secondVicunaInput = (token,) + tuple(pkv)
|
||||
|
||||
secondVicunaOutput = sec_vic("forward", secondVicunaInput)
|
||||
new_pkv = secondVicunaOutput[1:]
|
||||
new_logits = secondVicunaOutput[0]
|
||||
new_token = torch.argmax(torch.tensor(new_logits)[:, -1, :], dim=1)
|
||||
return new_token, new_logits, new_pkv
|
||||
|
||||
is_first = params["is_first"]
|
||||
|
||||
if is_first:
|
||||
prompt = params["prompt"]
|
||||
fv = params["fv"]
|
||||
token, logits, pkv = forward_first(
|
||||
fv, # self.shark_model[0],
|
||||
prompt=prompt,
|
||||
cache_outputs=False,
|
||||
)
|
||||
else:
|
||||
_logits = params["logits"]
|
||||
_pkv = params["pkv"]
|
||||
inputs = (_logits,) + tuple(_pkv)
|
||||
sv = params["sv"]
|
||||
token, logits, pkv = forward_second(
|
||||
sv, # self.shark_model[1],
|
||||
inputs=inputs,
|
||||
load_inputs=False,
|
||||
)
|
||||
|
||||
detok = self.tokenizer.decode(token)
|
||||
if not args.cli:
|
||||
print(
|
||||
f"[DEBUG] is_first: {is_first} |"
|
||||
f" token : {token} | detok : {detok}"
|
||||
)
|
||||
ret_dict = {
|
||||
"token": token,
|
||||
"logits": logits,
|
||||
"pkv": pkv,
|
||||
"detok": detok,
|
||||
}
|
||||
return ret_dict
|
||||
|
||||
def autocomplete(self, prompt):
|
||||
# use First vic alone to complete a story / prompt / sentence.
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
first_vic_mlir_path = (
|
||||
Path("first_vicuna.mlir")
|
||||
if args.first_vicuna_mlir_path is None
|
||||
else Path(args.first_vicuna_mlir_path)
|
||||
)
|
||||
second_vic_mlir_path = (
|
||||
Path("second_vicuna.mlir")
|
||||
if args.second_vicuna_mlir_path is None
|
||||
else Path(args.second_vicuna_mlir_path)
|
||||
)
|
||||
first_vic_vmfb_path = (
|
||||
Path("first_vicuna.vmfb")
|
||||
if args.first_vicuna_vmfb_path is None
|
||||
else Path(args.first_vicuna_vmfb_path)
|
||||
)
|
||||
second_vic_vmfb_path = (
|
||||
Path("second_vicuna.vmfb")
|
||||
if args.second_vicuna_vmfb_path is None
|
||||
else Path(args.second_vicuna_vmfb_path)
|
||||
)
|
||||
|
||||
vic = Vicuna(
|
||||
"vicuna",
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
first_vicuna_mlir_path=first_vic_mlir_path,
|
||||
second_vicuna_mlir_path=second_vic_mlir_path,
|
||||
first_vicuna_vmfb_path=first_vic_vmfb_path,
|
||||
second_vicuna_vmfb_path=second_vic_vmfb_path,
|
||||
)
|
||||
|
||||
prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
|
||||
prologue_prompt = "ASSISTANT:\n"
|
||||
|
||||
import gc
|
||||
|
||||
while True:
|
||||
# TODO: Add break condition from user input
|
||||
user_prompt = input("User: ")
|
||||
prompt_history = (
|
||||
prompt_history + "USER:\n" + user_prompt + prologue_prompt
|
||||
)
|
||||
prompt = prompt_history.strip()
|
||||
res_str = vic.generate(prompt)
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
print(
|
||||
"\n-----\nAssistant: Here's the complete formatted reply:\n",
|
||||
res_str,
|
||||
)
|
||||
prompt_history += f"\n{res_str}\n"
|
||||
416
apps/language_models/src/pipelines/vicuna_sharded_pipeline.py
Normal file
416
apps/language_models/src/pipelines/vicuna_sharded_pipeline.py
Normal file
@@ -0,0 +1,416 @@
|
||||
from apps.language_models.src.model_wrappers.vicuna_sharded_model import (
|
||||
FirstVicunaLayer,
|
||||
SecondVicunaLayer,
|
||||
CompiledFirstVicunaLayer,
|
||||
CompiledSecondVicunaLayer,
|
||||
ShardedVicunaModel,
|
||||
)
|
||||
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
|
||||
from apps.language_models.utils import get_torch_mlir_module_bytecode
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from shark.shark_inference import SharkInference
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from tqdm import tqdm
|
||||
from torch_mlir import TensorPlaceholder
|
||||
|
||||
|
||||
import re
|
||||
import torch
|
||||
import torch_mlir
|
||||
import os
|
||||
|
||||
|
||||
class Vicuna(SharkLLMBase):
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
|
||||
max_num_tokens=512,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
self.max_sequence_length = 256
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.shark_model = self.compile()
|
||||
|
||||
def get_tokenizer(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hf_model_path, use_fast=False
|
||||
)
|
||||
return tokenizer
|
||||
|
||||
def get_src_model(self):
|
||||
kwargs = {"torch_dtype": torch.float}
|
||||
vicuna_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, **kwargs
|
||||
)
|
||||
return vicuna_model
|
||||
|
||||
def write_in_dynamic_inputs0(self, module, dynamic_input_size):
|
||||
new_lines = []
|
||||
for line in module.splitlines():
|
||||
line = re.sub(f"{dynamic_input_size}x", "?x", line)
|
||||
if "?x" in line:
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
|
||||
line = re.sub(f" {dynamic_input_size},", " %dim,", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub(f"c{dynamic_input_size}", "dim", line)
|
||||
new_lines.append(line)
|
||||
new_module = "\n".join(new_lines)
|
||||
return new_module
|
||||
|
||||
def write_in_dynamic_inputs1(self, module, dynamic_input_size):
|
||||
new_lines = []
|
||||
for line in module.splitlines():
|
||||
if "dim_42 =" in line:
|
||||
continue
|
||||
if f"%c{dynamic_input_size}_i64 =" in line:
|
||||
new_lines.append(
|
||||
"%dim_42 = tensor.dim %arg1, %c3 : tensor<1x1x1x?xf32>"
|
||||
)
|
||||
new_lines.append(
|
||||
f"%dim_42_i64 = arith.index_cast %dim_42 : index to i64"
|
||||
)
|
||||
continue
|
||||
line = re.sub(f"{dynamic_input_size}x", "?x", line)
|
||||
if "?x" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(\)", "tensor.empty(%dim_42)", line
|
||||
)
|
||||
line = re.sub(f" {dynamic_input_size},", " %dim_42,", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim_42\)",
|
||||
"tensor.empty(%dim_42, %dim_42)",
|
||||
line,
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub(f"c{dynamic_input_size}", "dim_42", line)
|
||||
new_lines.append(line)
|
||||
new_module = "\n".join(new_lines)
|
||||
return new_module
|
||||
|
||||
def compile_vicuna_layer(
|
||||
self,
|
||||
vicuna_layer,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value0=None,
|
||||
past_key_value1=None,
|
||||
):
|
||||
if past_key_value0 is None and past_key_value1 is None:
|
||||
model_inputs = (hidden_states, attention_mask, position_ids)
|
||||
else:
|
||||
model_inputs = (
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value0,
|
||||
past_key_value1,
|
||||
)
|
||||
mlir_bytecode = get_torch_mlir_module_bytecode(
|
||||
vicuna_layer, model_inputs
|
||||
)
|
||||
return mlir_bytecode
|
||||
|
||||
def compile_to_vmfb(self, inputs, layers, is_first=True):
|
||||
mlirs, modules = [], []
|
||||
for idx, layer in tqdm(enumerate(layers), desc="Getting mlirs"):
|
||||
if is_first:
|
||||
mlir_path = Path(f"{idx}_0.mlir")
|
||||
vmfb_path = Path(f"{idx}_0.vmfb")
|
||||
else:
|
||||
mlir_path = Path(f"{idx}_1.mlir")
|
||||
vmfb_path = Path(f"{idx}_1.vmfb")
|
||||
if vmfb_path.exists():
|
||||
continue
|
||||
if mlir_path.exists():
|
||||
# print(f"Found layer {idx} mlir")
|
||||
f_ = open(mlir_path, "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
else:
|
||||
hidden_states_placeholder = TensorPlaceholder.like(
|
||||
inputs[0], dynamic_axes=[1]
|
||||
)
|
||||
attention_mask_placeholder = TensorPlaceholder.like(
|
||||
inputs[1], dynamic_axes=[3]
|
||||
)
|
||||
position_ids_placeholder = TensorPlaceholder.like(
|
||||
inputs[2], dynamic_axes=[1]
|
||||
)
|
||||
if not is_first:
|
||||
pkv0_placeholder = TensorPlaceholder.like(
|
||||
inputs[3], dynamic_axes=[2]
|
||||
)
|
||||
pkv1_placeholder = TensorPlaceholder.like(
|
||||
inputs[4], dynamic_axes=[2]
|
||||
)
|
||||
print(f"Compiling layer {idx} mlir")
|
||||
if is_first:
|
||||
ts_g = self.compile_vicuna_layer(
|
||||
layer, inputs[0], inputs[1], inputs[2]
|
||||
)
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
(
|
||||
hidden_states_placeholder,
|
||||
inputs[1],
|
||||
inputs[2],
|
||||
),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
else:
|
||||
ts_g = self.compile_vicuna_layer(
|
||||
layer,
|
||||
inputs[0],
|
||||
inputs[1],
|
||||
inputs[2],
|
||||
inputs[3],
|
||||
inputs[4],
|
||||
)
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
(
|
||||
inputs[0],
|
||||
attention_mask_placeholder,
|
||||
inputs[2],
|
||||
pkv0_placeholder,
|
||||
pkv1_placeholder,
|
||||
),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# bytecode_stream = BytesIO()
|
||||
# module.operation.write_bytecode(bytecode_stream)
|
||||
# bytecode = bytecode_stream.getvalue()
|
||||
|
||||
if is_first:
|
||||
module = self.write_in_dynamic_inputs0(str(module), 137)
|
||||
bytecode = module.encode("UTF-8")
|
||||
bytecode_stream = BytesIO(bytecode)
|
||||
bytecode = bytecode_stream.read()
|
||||
|
||||
else:
|
||||
module = self.write_in_dynamic_inputs1(str(module), 138)
|
||||
if idx in [0, 5, 6, 7]:
|
||||
module_str = module
|
||||
module_str = module_str.splitlines()
|
||||
new_lines = []
|
||||
for line in module_str:
|
||||
if len(line) < 1000:
|
||||
new_lines.append(line)
|
||||
else:
|
||||
new_lines.append(line[:999])
|
||||
module_str = "\n".join(new_lines)
|
||||
f1_ = open(f"{idx}_1_test.mlir", "w+")
|
||||
f1_.write(module_str)
|
||||
f1_.close()
|
||||
|
||||
bytecode = module.encode("UTF-8")
|
||||
bytecode_stream = BytesIO(bytecode)
|
||||
bytecode = bytecode_stream.read()
|
||||
|
||||
f_ = open(mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
f_.close()
|
||||
mlirs.append(bytecode)
|
||||
|
||||
for idx, layer in tqdm(enumerate(layers), desc="compiling modules"):
|
||||
if is_first:
|
||||
vmfb_path = Path(f"{idx}_0.vmfb")
|
||||
if idx < 25:
|
||||
device = "cpu"
|
||||
else:
|
||||
device = "cpu"
|
||||
if vmfb_path.exists():
|
||||
# print(f"Found layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
None, device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
else:
|
||||
print(f"Compiling layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
mlirs[idx], device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
module.save_module(
|
||||
module_name=f"{idx}_0",
|
||||
extra_args=[
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
],
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
modules.append(module)
|
||||
else:
|
||||
vmfb_path = Path(f"{idx}_1.vmfb")
|
||||
if idx < 25:
|
||||
device = "cpu"
|
||||
else:
|
||||
device = "cpu"
|
||||
if vmfb_path.exists():
|
||||
# print(f"Found layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
None, device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
else:
|
||||
print(f"Compiling layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
mlirs[idx], device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
module.save_module(
|
||||
module_name=f"{idx}_1",
|
||||
extra_args=[
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
],
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
modules.append(module)
|
||||
|
||||
return mlirs, modules
|
||||
|
||||
def get_sharded_model(self):
|
||||
# SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an increadibly hacky proccess
|
||||
# please don't change it
|
||||
SAMPLE_INPUT_LEN = 137
|
||||
vicuna_model = self.get_src_model()
|
||||
placeholder_input0 = (
|
||||
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
|
||||
torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]),
|
||||
torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64),
|
||||
)
|
||||
|
||||
placeholder_input1 = (
|
||||
torch.zeros([1, 1, 4096]),
|
||||
torch.zeros([1, 1, 1, SAMPLE_INPUT_LEN + 1]),
|
||||
torch.zeros([1, 1], dtype=torch.int64),
|
||||
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
|
||||
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
|
||||
)
|
||||
|
||||
layers0 = [
|
||||
FirstVicunaLayer(layer) for layer in vicuna_model.model.layers
|
||||
]
|
||||
_, modules0 = self.compile_to_vmfb(
|
||||
placeholder_input0, layers0, is_first=True
|
||||
)
|
||||
shark_layers0 = [CompiledFirstVicunaLayer(m) for m in modules0]
|
||||
|
||||
layers1 = [
|
||||
SecondVicunaLayer(layer) for layer in vicuna_model.model.layers
|
||||
]
|
||||
_, modules1 = self.compile_to_vmfb(
|
||||
placeholder_input1, layers1, is_first=False
|
||||
)
|
||||
shark_layers1 = [CompiledSecondVicunaLayer(m) for m in modules1]
|
||||
|
||||
sharded_model = ShardedVicunaModel(
|
||||
vicuna_model, shark_layers0, shark_layers1
|
||||
)
|
||||
return sharded_model
|
||||
|
||||
def compile(self):
|
||||
return self.get_sharded_model()
|
||||
|
||||
def generate(self, prompt):
|
||||
# TODO: refactor for cleaner integration
|
||||
|
||||
tokens_generated = []
|
||||
_past_key_values = None
|
||||
_token = None
|
||||
detoks_generated = []
|
||||
for iteration in range(self.max_num_tokens):
|
||||
params = {
|
||||
"prompt": prompt,
|
||||
"is_first": iteration == 0,
|
||||
"token": _token,
|
||||
"past_key_values": _past_key_values,
|
||||
}
|
||||
|
||||
generated_token_op = self.generate_new_token(params=params)
|
||||
|
||||
_token = generated_token_op["token"]
|
||||
_past_key_values = generated_token_op["past_key_values"]
|
||||
_detok = generated_token_op["detok"]
|
||||
|
||||
if _token == 2:
|
||||
break
|
||||
detoks_generated.append(_detok)
|
||||
tokens_generated.append(_token)
|
||||
|
||||
for i in range(len(tokens_generated)):
|
||||
if type(tokens_generated[i]) != int:
|
||||
tokens_generated[i] = int(tokens_generated[i][0])
|
||||
result_output = self.tokenizer.decode(tokens_generated)
|
||||
return result_output
|
||||
|
||||
def generate_new_token(self, params):
|
||||
is_first = params["is_first"]
|
||||
if is_first:
|
||||
prompt = params["prompt"]
|
||||
input_ids = self.tokenizer(prompt).input_ids
|
||||
input_id_len = len(input_ids)
|
||||
input_ids = torch.tensor(input_ids)
|
||||
input_ids = input_ids.reshape([1, input_id_len])
|
||||
output = self.shark_model.forward(input_ids, is_first=is_first)
|
||||
else:
|
||||
token = params["token"]
|
||||
past_key_values = params["past_key_values"]
|
||||
input_ids = [token]
|
||||
input_id_len = len(input_ids)
|
||||
input_ids = torch.tensor(input_ids)
|
||||
input_ids = input_ids.reshape([1, input_id_len])
|
||||
output = self.shark_model.forward(
|
||||
input_ids, past_key_values=past_key_values, is_first=is_first
|
||||
)
|
||||
|
||||
_logits = output["logits"]
|
||||
_past_key_values = output["past_key_values"]
|
||||
_token = int(torch.argmax(_logits[:, -1, :], dim=1)[0])
|
||||
_detok = self.tokenizer.decode(_token)
|
||||
|
||||
ret_dict = {
|
||||
"token": _token,
|
||||
"detok": _detok,
|
||||
"past_key_values": _past_key_values,
|
||||
}
|
||||
|
||||
print(f" token : {_token} | detok : {_detok}")
|
||||
|
||||
return ret_dict
|
||||
|
||||
def autocomplete(self, prompt):
|
||||
# use First vic alone to complete a story / prompt / sentence.
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
vic = Vicuna("vicuna")
|
||||
prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
|
||||
prologue_prompt = "ASSISTANT:\n"
|
||||
user_prompt = input("User: ")
|
||||
prompt_history = prompt_history + "USER:\n" + user_prompt + prologue_prompt
|
||||
prompt = prompt_history.strip()
|
||||
|
||||
res = vic.generate(prompt)
|
||||
print(prompt + res)
|
||||
140
apps/language_models/utils.py
Normal file
140
apps/language_models/utils.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def get_torch_mlir_module_bytecode(model, model_inputs):
|
||||
fx_g = make_fx(
|
||||
model,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
# tracing_mode='symbolic',
|
||||
)(*model_inputs)
|
||||
print("Got FX_G")
|
||||
|
||||
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
|
||||
removed_indexes = []
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, (list, tuple)):
|
||||
node_arg = list(node_arg)
|
||||
node_args_len = len(node_arg)
|
||||
for i in range(node_args_len):
|
||||
curr_index = node_args_len - (i + 1)
|
||||
if node_arg[curr_index] is None:
|
||||
removed_indexes.append(curr_index)
|
||||
node_arg.pop(curr_index)
|
||||
node.args = (tuple(node_arg),)
|
||||
break
|
||||
|
||||
if len(removed_indexes) > 0:
|
||||
fx_g.graph.lint()
|
||||
fx_g.graph.eliminate_dead_code()
|
||||
fx_g.recompile()
|
||||
removed_indexes.sort()
|
||||
return removed_indexes
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""
|
||||
Replace tuple with tuple element in functions that return one-element tuples.
|
||||
Returns true if an unwrapping took place, and false otherwise.
|
||||
"""
|
||||
unwrapped_tuple = False
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
unwrapped_tuple = True
|
||||
break
|
||||
|
||||
if unwrapped_tuple:
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return unwrapped_tuple
|
||||
|
||||
def transform_fx(fx_g):
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
if node.target in [
|
||||
torch.ops.aten.empty,
|
||||
]:
|
||||
# aten.empty should be filled with zeros.
|
||||
if node.target in [torch.ops.aten.empty]:
|
||||
with fx_g.graph.inserting_after(node):
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten.zero_,
|
||||
args=(node,),
|
||||
)
|
||||
node.append(new_node)
|
||||
node.replace_all_uses_with(new_node)
|
||||
new_node.args = (node,)
|
||||
|
||||
fx_g.graph.lint()
|
||||
|
||||
transform_fx(fx_g)
|
||||
fx_g.recompile()
|
||||
removed_none_indexes = _remove_nones(fx_g)
|
||||
was_unwrapped = _unwrap_single_tuple_return(fx_g)
|
||||
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
|
||||
print("FX_G recompile")
|
||||
|
||||
def strip_overloads(gm):
|
||||
"""
|
||||
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
|
||||
Args:
|
||||
gm(fx.GraphModule): The input Fx graph module to be modified
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if isinstance(node.target, torch._ops.OpOverload):
|
||||
node.target = node.target.overloadpacket
|
||||
gm.recompile()
|
||||
|
||||
strip_overloads(fx_g)
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
print("Got TS_G")
|
||||
return ts_g
|
||||
|
||||
|
||||
# expects a Path / str as arg
|
||||
# returns None if path not found or SharkInference module
|
||||
def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
|
||||
if not isinstance(vmfb_path, Path):
|
||||
vmfb_path = Path(vmfb_path)
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
if not vmfb_path.exists():
|
||||
return None
|
||||
|
||||
print("Loading vmfb from: ", vmfb_path)
|
||||
shark_module = SharkInference(
|
||||
None, device=device, mlir_dialect=mlir_dialect
|
||||
)
|
||||
shark_module.load_module(vmfb_path)
|
||||
print("Successfully loaded vmfb")
|
||||
return shark_module
|
||||
@@ -10,7 +10,7 @@ Vulkan AMD:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
|
||||
|
||||
# add --mlir-print-debuginfo --mlir-print-op-on-diagnostic=true for debug
|
||||
# use –iree-input-type=mhlo for tf models
|
||||
# use –iree-input-type=auto or "mhlo_legacy" or "stablehlo" for TF models
|
||||
|
||||
CUDA NVIDIA:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
|
||||
|
||||
@@ -20,6 +20,8 @@ from diffusers import (
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_IDLE,
|
||||
SD_STATE_CANCEL,
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
@@ -84,6 +86,7 @@ class UpscalerPipeline(StableDiffusionPipeline):
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
self.low_res_scheduler = low_res_scheduler
|
||||
self.status = SD_STATE_IDLE
|
||||
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
accepts_eta = "eta" in set(
|
||||
@@ -164,6 +167,7 @@ class UpscalerPipeline(StableDiffusionPipeline):
|
||||
latent_history = [latents]
|
||||
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
self.status = SD_STATE_IDLE
|
||||
self.load_unet()
|
||||
for i, t in tqdm(enumerate(total_timesteps)):
|
||||
step_start_time = time.time()
|
||||
@@ -210,6 +214,9 @@ class UpscalerPipeline(StableDiffusionPipeline):
|
||||
# )
|
||||
step_time_sum += step_time
|
||||
|
||||
if self.status == SD_STATE_CANCEL:
|
||||
break
|
||||
|
||||
if self.ondemand:
|
||||
self.unload_unet()
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
|
||||
@@ -87,7 +87,8 @@ class StableDiffusionPipeline:
|
||||
else:
|
||||
try:
|
||||
self.text_encoder = get_clip()
|
||||
except:
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("download pipeline failed, falling back to import_mlir")
|
||||
self.text_encoder = self.sd_model.clip()
|
||||
|
||||
@@ -104,7 +105,8 @@ class StableDiffusionPipeline:
|
||||
else:
|
||||
try:
|
||||
self.unet = get_unet()
|
||||
except:
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("download pipeline failed, falling back to import_mlir")
|
||||
self.unet = self.sd_model.unet()
|
||||
|
||||
@@ -121,7 +123,8 @@ class StableDiffusionPipeline:
|
||||
else:
|
||||
try:
|
||||
self.vae = get_vae()
|
||||
except:
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("download pipeline failed, falling back to import_mlir")
|
||||
self.vae = self.sd_model.vae()
|
||||
|
||||
|
||||
@@ -30,6 +30,8 @@ from apps.stable_diffusion.src.utils.utils import (
|
||||
sanitize_seed,
|
||||
get_path_stem,
|
||||
get_extended_name,
|
||||
get_generated_imgs_path,
|
||||
get_generated_imgs_todays_subdir,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
get_generation_text_info,
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
"stablediffusion/untuned":"gs://shark_tank/nightly"
|
||||
},
|
||||
{
|
||||
"stablediffusion/v1_4/unet/fp16/length_64/untuned":"unet_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
|
||||
"stablediffusion/v1_4/unet/fp16/length_64/untuned":"unet_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
|
||||
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
|
||||
"stablediffusion/v1_4/vae/fp16/length_64/untuned":"vae_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
|
||||
"stablediffusion/v1_4/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
|
||||
|
||||
@@ -125,6 +125,8 @@ def load_lower_configs(base_model_id=None):
|
||||
config_name = f"{args.annotation_model}_v2_1_768_{args.precision}_{device}.json"
|
||||
else:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
|
||||
elif spec in ["rdna2"] and version in ["v2_1", "v2_1base", "v1_4"]:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}_{args.width}x{args.height}.json"
|
||||
else:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}.json"
|
||||
|
||||
|
||||
@@ -521,6 +521,22 @@ p.add_argument(
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for enabling rest API",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_gallery",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for removing the output gallery tab, and avoid exposing images under --output_dir in the UI",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_gallery_followlinks",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for whether the output gallery tab in the UI should follow symlinks when listing subdirectorys under --output_dir",
|
||||
)
|
||||
|
||||
|
||||
##############################################################################
|
||||
### SD model auto-annotation flags
|
||||
##############################################################################
|
||||
|
||||
@@ -83,7 +83,6 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
|
||||
|
||||
# Set local shark_tank cache directory.
|
||||
shark_args.local_tank_cache = args.local_tank_cache
|
||||
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
if "cuda" in args.device:
|
||||
@@ -338,13 +337,25 @@ def set_init_device_flags():
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
]
|
||||
or "rdna3" not in args.iree_vulkan_target_triple
|
||||
or "rdna" not in args.iree_vulkan_target_triple
|
||||
)
|
||||
):
|
||||
args.use_tuned = False
|
||||
|
||||
elif "rdna2" in args.iree_vulkan_target_triple and (
|
||||
base_model_id
|
||||
not in [
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
]
|
||||
):
|
||||
args.use_tuned = False
|
||||
|
||||
if args.use_tuned:
|
||||
print(f"Using tuned models for {base_model_id}/fp16/{args.device}.")
|
||||
print(
|
||||
f"Using tuned models for {base_model_id}(fp16) on device {args.device}."
|
||||
)
|
||||
else:
|
||||
print("Tuned models are currently not supported for this setting.")
|
||||
|
||||
@@ -694,11 +705,20 @@ def clear_all():
|
||||
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
|
||||
|
||||
|
||||
def get_generated_imgs_path() -> Path:
|
||||
return Path(
|
||||
args.output_dir if args.output_dir else Path.cwd(), "generated_imgs"
|
||||
)
|
||||
|
||||
|
||||
def get_generated_imgs_todays_subdir() -> str:
|
||||
return dt.now().strftime("%Y%m%d")
|
||||
|
||||
|
||||
# save output images and the inputs corresponding to it.
|
||||
def save_output_img(output_img, img_seed, extra_info={}):
|
||||
output_path = args.output_dir if args.output_dir else Path.cwd()
|
||||
generated_imgs_path = Path(
|
||||
output_path, "generated_imgs", dt.now().strftime("%Y%m%d")
|
||||
get_generated_imgs_path(), get_generated_imgs_todays_subdir()
|
||||
)
|
||||
generated_imgs_path.mkdir(parents=True, exist_ok=True)
|
||||
csv_path = Path(generated_imgs_path, "imgs_details.csv")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from multiprocessing import Process, freeze_support
|
||||
import os
|
||||
import sys
|
||||
import transformers
|
||||
import transformers # ensures inclusion in pysintaller exe generation
|
||||
from apps.stable_diffusion.src import args, clear_all
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
|
||||
@@ -80,6 +80,8 @@ if __name__ == "__main__":
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
txt2img_gallery,
|
||||
txt2img_png_info_img,
|
||||
txt2img_status,
|
||||
txt2img_sendto_img2img,
|
||||
txt2img_sendto_inpaint,
|
||||
txt2img_sendto_outpaint,
|
||||
@@ -89,6 +91,7 @@ if __name__ == "__main__":
|
||||
img2img_hf_model_id,
|
||||
img2img_gallery,
|
||||
img2img_init_image,
|
||||
img2img_status,
|
||||
img2img_sendto_inpaint,
|
||||
img2img_sendto_outpaint,
|
||||
img2img_sendto_upscaler,
|
||||
@@ -97,6 +100,7 @@ if __name__ == "__main__":
|
||||
inpaint_hf_model_id,
|
||||
inpaint_gallery,
|
||||
inpaint_init_image,
|
||||
inpaint_status,
|
||||
inpaint_sendto_img2img,
|
||||
inpaint_sendto_outpaint,
|
||||
inpaint_sendto_upscaler,
|
||||
@@ -105,6 +109,7 @@ if __name__ == "__main__":
|
||||
outpaint_hf_model_id,
|
||||
outpaint_gallery,
|
||||
outpaint_init_image,
|
||||
outpaint_status,
|
||||
outpaint_sendto_img2img,
|
||||
outpaint_sendto_inpaint,
|
||||
outpaint_sendto_upscaler,
|
||||
@@ -113,6 +118,7 @@ if __name__ == "__main__":
|
||||
upscaler_hf_model_id,
|
||||
upscaler_gallery,
|
||||
upscaler_init_image,
|
||||
upscaler_status,
|
||||
upscaler_sendto_img2img,
|
||||
upscaler_sendto_inpaint,
|
||||
upscaler_sendto_outpaint,
|
||||
@@ -125,6 +131,15 @@ if __name__ == "__main__":
|
||||
modelmanager_sendto_outpaint,
|
||||
modelmanager_sendto_upscaler,
|
||||
stablelm_chat,
|
||||
outputgallery_web,
|
||||
outputgallery_tab_select,
|
||||
outputgallery_watch,
|
||||
outputgallery_filename,
|
||||
outputgallery_sendto_txt2img,
|
||||
outputgallery_sendto_img2img,
|
||||
outputgallery_sendto_inpaint,
|
||||
outputgallery_sendto_outpaint,
|
||||
outputgallery_sendto_upscaler,
|
||||
)
|
||||
|
||||
# init global sd pipeline and config
|
||||
@@ -151,6 +166,16 @@ if __name__ == "__main__":
|
||||
outputs,
|
||||
)
|
||||
|
||||
def register_outputgallery_button(button, selectedid, inputs, outputs):
|
||||
button.click(
|
||||
lambda x: (
|
||||
x,
|
||||
gr.Tabs.update(selected=selectedid),
|
||||
),
|
||||
inputs,
|
||||
outputs,
|
||||
)
|
||||
|
||||
with gr.Blocks(
|
||||
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
|
||||
) as sd_web:
|
||||
@@ -171,7 +196,23 @@ if __name__ == "__main__":
|
||||
stablelm_chat.render()
|
||||
with gr.TabItem(label="LoRA Training(Experimental)", id=7):
|
||||
lora_train_web.render()
|
||||
if args.output_gallery:
|
||||
with gr.TabItem(label="Output Gallery", id=8) as og_tab:
|
||||
outputgallery_web.render()
|
||||
|
||||
# extra output gallery configuration
|
||||
outputgallery_tab_select(og_tab.select)
|
||||
outputgallery_watch(
|
||||
[
|
||||
txt2img_status,
|
||||
img2img_status,
|
||||
inpaint_status,
|
||||
outpaint_status,
|
||||
upscaler_status,
|
||||
]
|
||||
)
|
||||
|
||||
# send to buttons
|
||||
register_button_click(
|
||||
txt2img_sendto_img2img,
|
||||
1,
|
||||
@@ -268,6 +309,37 @@ if __name__ == "__main__":
|
||||
[upscaler_gallery],
|
||||
[outpaint_init_image, tabs],
|
||||
)
|
||||
if args.output_gallery:
|
||||
register_outputgallery_button(
|
||||
outputgallery_sendto_txt2img,
|
||||
0,
|
||||
[outputgallery_filename],
|
||||
[txt2img_png_info_img, tabs],
|
||||
)
|
||||
register_outputgallery_button(
|
||||
outputgallery_sendto_img2img,
|
||||
1,
|
||||
[outputgallery_filename],
|
||||
[img2img_init_image, tabs],
|
||||
)
|
||||
register_outputgallery_button(
|
||||
outputgallery_sendto_inpaint,
|
||||
2,
|
||||
[outputgallery_filename],
|
||||
[inpaint_init_image, tabs],
|
||||
)
|
||||
register_outputgallery_button(
|
||||
outputgallery_sendto_outpaint,
|
||||
3,
|
||||
[outputgallery_filename],
|
||||
[outpaint_init_image, tabs],
|
||||
)
|
||||
register_outputgallery_button(
|
||||
outputgallery_sendto_upscaler,
|
||||
4,
|
||||
[outputgallery_filename],
|
||||
[upscaler_init_image, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_txt2img,
|
||||
0,
|
||||
|
||||
@@ -5,6 +5,8 @@ from apps.stable_diffusion.web.ui.txt2img_ui import (
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
txt2img_gallery,
|
||||
txt2img_png_info_img,
|
||||
txt2img_status,
|
||||
txt2img_sendto_img2img,
|
||||
txt2img_sendto_inpaint,
|
||||
txt2img_sendto_outpaint,
|
||||
@@ -18,6 +20,7 @@ from apps.stable_diffusion.web.ui.img2img_ui import (
|
||||
img2img_hf_model_id,
|
||||
img2img_gallery,
|
||||
img2img_init_image,
|
||||
img2img_status,
|
||||
img2img_sendto_inpaint,
|
||||
img2img_sendto_outpaint,
|
||||
img2img_sendto_upscaler,
|
||||
@@ -30,6 +33,7 @@ from apps.stable_diffusion.web.ui.inpaint_ui import (
|
||||
inpaint_hf_model_id,
|
||||
inpaint_gallery,
|
||||
inpaint_init_image,
|
||||
inpaint_status,
|
||||
inpaint_sendto_img2img,
|
||||
inpaint_sendto_outpaint,
|
||||
inpaint_sendto_upscaler,
|
||||
@@ -42,6 +46,7 @@ from apps.stable_diffusion.web.ui.outpaint_ui import (
|
||||
outpaint_hf_model_id,
|
||||
outpaint_gallery,
|
||||
outpaint_init_image,
|
||||
outpaint_status,
|
||||
outpaint_sendto_img2img,
|
||||
outpaint_sendto_inpaint,
|
||||
outpaint_sendto_upscaler,
|
||||
@@ -54,6 +59,7 @@ from apps.stable_diffusion.web.ui.upscaler_ui import (
|
||||
upscaler_hf_model_id,
|
||||
upscaler_gallery,
|
||||
upscaler_init_image,
|
||||
upscaler_status,
|
||||
upscaler_sendto_img2img,
|
||||
upscaler_sendto_inpaint,
|
||||
upscaler_sendto_outpaint,
|
||||
@@ -69,3 +75,14 @@ from apps.stable_diffusion.web.ui.model_manager import (
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.lora_train_ui import lora_train_web
|
||||
from apps.stable_diffusion.web.ui.stablelm_ui import stablelm_chat
|
||||
from apps.stable_diffusion.web.ui.outputgallery_ui import (
|
||||
outputgallery_web,
|
||||
outputgallery_tab_select,
|
||||
outputgallery_watch,
|
||||
outputgallery_filename,
|
||||
outputgallery_sendto_txt2img,
|
||||
outputgallery_sendto_img2img,
|
||||
outputgallery_sendto_inpaint,
|
||||
outputgallery_sendto_outpaint,
|
||||
outputgallery_sendto_upscaler,
|
||||
)
|
||||
|
||||
@@ -230,3 +230,44 @@ footer {
|
||||
#top_logo .download {
|
||||
display: none;
|
||||
}
|
||||
|
||||
/* output gallery tab */
|
||||
.output_parameters_dataframe tbody td {
|
||||
font-size: small;
|
||||
line-height: var(--line-xs)
|
||||
}
|
||||
|
||||
#output_refresh_button {
|
||||
max-width: 30px;
|
||||
align-self: end;
|
||||
padding-bottom: 8px;
|
||||
}
|
||||
|
||||
.outputgallery_sendto {
|
||||
min-width: 7em !important;
|
||||
}
|
||||
|
||||
/* output gallery should take up most of the viewport height regardless of image size/number */
|
||||
#outputgallery_gallery .fixed-height {
|
||||
min-height: 89vh !important;
|
||||
}
|
||||
|
||||
/* don't stretch non-square images to be square, breaking their aspect ratio */
|
||||
#outputgallery_gallery .thumbnail-item.thumbnail-lg > img {
|
||||
object-fit: contain !important;
|
||||
}
|
||||
|
||||
/* centered logo for when there are no images */
|
||||
#top_logo.logo_centered {
|
||||
height: 100%;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
#top_logo.logo_centered img{
|
||||
object-fit: scale-down;
|
||||
position: absolute;
|
||||
width: 80%;
|
||||
top: 50%;
|
||||
left: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
}
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from pathlib import Path
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
import sys
|
||||
import gradio as gr
|
||||
import PIL
|
||||
from PIL import Image
|
||||
@@ -26,10 +24,13 @@ from apps.stable_diffusion.src import (
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generation_text_info,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -259,11 +260,17 @@ def img2img_inf(
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed, extra_info)
|
||||
save_output_img(
|
||||
out_imgs[0],
|
||||
img_seed,
|
||||
extra_info,
|
||||
)
|
||||
generated_imgs.extend(out_imgs)
|
||||
# yield generated_imgs, text_output
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Image-to-Image", current_batch + 1, batch_count, batch_size
|
||||
)
|
||||
|
||||
return generated_imgs, text_output
|
||||
return generated_imgs, text_output, ""
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
@@ -593,16 +600,13 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
).style(columns=[2], object_fit="contain")
|
||||
output_dir = (
|
||||
args.output_dir if args.output_dir else Path.cwd()
|
||||
)
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at {output_dir}",
|
||||
value=f"Images will be saved at {get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
img2img_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
img2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
img2img_sendto_outpaint = gr.Button(
|
||||
@@ -640,13 +644,21 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
],
|
||||
outputs=[img2img_gallery, std_output],
|
||||
outputs=[img2img_gallery, std_output, img2img_status],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
status_kwargs = dict(
|
||||
fn=lambda bc, bs: status_label("Image-to-Image", 0, bc, bs),
|
||||
inputs=[batch_count, batch_size],
|
||||
outputs=img2img_status,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
|
||||
**kwargs
|
||||
)
|
||||
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from pathlib import Path
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
@@ -26,7 +25,11 @@ from apps.stable_diffusion.src import (
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generation_text_info,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
@@ -213,7 +216,9 @@ def inpaint_inf(
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Inpaint", i + 1, batch_count, batch_size
|
||||
)
|
||||
|
||||
return generated_imgs, text_output
|
||||
|
||||
@@ -494,16 +499,14 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
).style(columns=[2], object_fit="contain")
|
||||
output_dir = (
|
||||
args.output_dir if args.output_dir else Path.cwd()
|
||||
)
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at {output_dir}",
|
||||
value=f"Images will be saved at {get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
inpaint_status = gr.Textbox(visible=False)
|
||||
|
||||
with gr.Row():
|
||||
inpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
inpaint_sendto_outpaint = gr.Button(
|
||||
@@ -541,13 +544,20 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
],
|
||||
outputs=[inpaint_gallery, std_output],
|
||||
outputs=[inpaint_gallery, std_output, inpaint_status],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
status_kwargs = dict(
|
||||
fn=lambda bc, bs: status_label("Inpaint", 0, bc, bs),
|
||||
inputs=[batch_count, batch_size],
|
||||
outputs=inpaint_status,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
|
||||
**kwargs
|
||||
)
|
||||
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from pathlib import Path
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
import sys
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
import base64
|
||||
@@ -23,10 +21,13 @@ from apps.stable_diffusion.src import (
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generation_text_info,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
@@ -222,9 +223,11 @@ def outpaint_inf(
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Outpaint", i + 1, batch_count, batch_size
|
||||
)
|
||||
|
||||
return generated_imgs, text_output
|
||||
return generated_imgs, text_output, ""
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
@@ -524,16 +527,13 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
).style(columns=[2], object_fit="contain")
|
||||
output_dir = (
|
||||
args.output_dir if args.output_dir else Path.cwd()
|
||||
)
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at {output_dir}",
|
||||
value=f"Images will be saved at {get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
outpaint_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
outpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
outpaint_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
@@ -572,13 +572,20 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
],
|
||||
outputs=[outpaint_gallery, std_output],
|
||||
outputs=[outpaint_gallery, std_output, outpaint_status],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
status_kwargs = dict(
|
||||
fn=lambda bc, bs: status_label("Outpaint", 0, bc, bs),
|
||||
inputs=[batch_count, batch_size],
|
||||
outputs=outpaint_status,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
|
||||
**kwargs
|
||||
)
|
||||
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
|
||||
419
apps/stable_diffusion/web/ui/outputgallery_ui.py
Normal file
419
apps/stable_diffusion/web/ui/outputgallery_ui.py
Normal file
@@ -0,0 +1,419 @@
|
||||
import glob
|
||||
import gradio as gr
|
||||
import os
|
||||
from PIL import Image
|
||||
|
||||
from apps.stable_diffusion.src import args
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generated_imgs_todays_subdir,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import nodlogo_loc
|
||||
from apps.stable_diffusion.web.utils.gradio_configs import (
|
||||
gradio_tmp_galleries_folder,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.metadata import displayable_metadata
|
||||
|
||||
# -- Functions for file, directory and image info querying
|
||||
|
||||
output_dir = get_generated_imgs_path()
|
||||
|
||||
|
||||
def outputgallery_filenames(subdir) -> list[str]:
|
||||
new_dir_path = os.path.join(output_dir, subdir)
|
||||
if os.path.exists(new_dir_path):
|
||||
filenames = [
|
||||
glob.glob(new_dir_path + "/" + ext)
|
||||
for ext in ("*.png", "*.jpg", "*.jpeg")
|
||||
]
|
||||
|
||||
return sorted(sum(filenames, []), key=os.path.getmtime, reverse=True)
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
def output_subdirs() -> list[str]:
|
||||
# Gets a list of subdirectories of output_dir and below, as relative paths.
|
||||
relative_paths = [
|
||||
os.path.relpath(entry[0], output_dir)
|
||||
for entry in os.walk(
|
||||
output_dir, followlinks=args.output_gallery_followlinks
|
||||
)
|
||||
]
|
||||
|
||||
# It is less confusing to always including the subdir that will take any images generated
|
||||
# today even if it doesn't exist yet
|
||||
if get_generated_imgs_todays_subdir() not in relative_paths:
|
||||
relative_paths.append(get_generated_imgs_todays_subdir())
|
||||
|
||||
# sort subdirectories so that that the date named ones we probably created in this or
|
||||
# previous sessions come first, sorted with the most recent first. Other subdirs are listed
|
||||
# after.
|
||||
generated_paths = sorted(
|
||||
[path for path in relative_paths if path.isnumeric()], reverse=True
|
||||
)
|
||||
result_paths = generated_paths + sorted(
|
||||
[
|
||||
path
|
||||
for path in relative_paths
|
||||
if (not path.isnumeric()) and path != "."
|
||||
]
|
||||
)
|
||||
|
||||
return result_paths
|
||||
|
||||
|
||||
# --- Define UI layout for Gradio
|
||||
|
||||
with gr.Blocks() as outputgallery_web:
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
|
||||
with gr.Row(elem_id="outputgallery_gallery"):
|
||||
# needed to workaround gradio issue: https://github.com/gradio-app/gradio/issues/2907
|
||||
dev_null = gr.Textbox("", visible=False)
|
||||
|
||||
gallery_files = gr.State(value=[])
|
||||
subdirectory_paths = gr.State(value=[])
|
||||
|
||||
with gr.Column(scale=6):
|
||||
logo = gr.Image(
|
||||
label="Getting subdirectories...",
|
||||
value=nod_logo,
|
||||
interactive=False,
|
||||
visible=True,
|
||||
show_label=True,
|
||||
elem_id="top_logo",
|
||||
elem_classes="logo_centered",
|
||||
)
|
||||
|
||||
gallery = gr.Gallery(
|
||||
label="",
|
||||
value=gallery_files.value,
|
||||
visible=False,
|
||||
show_label=True,
|
||||
).style(grid=4)
|
||||
gallery.DEFAULT_TEMP_DIR = gradio_tmp_galleries_folder
|
||||
|
||||
with gr.Column(scale=4):
|
||||
with gr.Box():
|
||||
with gr.Row():
|
||||
with gr.Column(scale=16, min_width=160):
|
||||
subdirectories = gr.Dropdown(
|
||||
label=f"Subdirectories of {output_dir}",
|
||||
type="value",
|
||||
choices=subdirectory_paths.value,
|
||||
value="",
|
||||
interactive=True,
|
||||
).style(container=False)
|
||||
with gr.Column(
|
||||
scale=1, min_width=32, elem_id="output_refresh_button"
|
||||
):
|
||||
refresh = gr.Button(
|
||||
variant="secondary",
|
||||
value="\u21BB", # unicode clockwise arrow circle
|
||||
).style(size="sm")
|
||||
|
||||
image_columns = gr.Slider(
|
||||
label="Columns shown", value=4, minimum=1, maximum=16, step=1
|
||||
)
|
||||
outputgallery_filename = gr.Textbox(
|
||||
label="Filename", value="None", interactive=False
|
||||
).style(show_copy_button=True)
|
||||
|
||||
with gr.Accordion(
|
||||
label="Parameter Information", open=False
|
||||
) as parameters_accordian:
|
||||
image_parameters = gr.DataFrame(
|
||||
headers=["Parameter", "Value"],
|
||||
col_count=2,
|
||||
wrap=True,
|
||||
elem_classes="output_parameters_dataframe",
|
||||
value=[["Status", "No image selected"]],
|
||||
)
|
||||
|
||||
with gr.Accordion(label="Send To", open=True):
|
||||
with gr.Row():
|
||||
outputgallery_sendto_txt2img = gr.Button(
|
||||
value="Txt2Img",
|
||||
interactive=False,
|
||||
elem_classes="outputgallery_sendto",
|
||||
).style(size="sm")
|
||||
|
||||
outputgallery_sendto_img2img = gr.Button(
|
||||
value="Img2Img",
|
||||
interactive=False,
|
||||
elem_classes="outputgallery_sendto",
|
||||
).style(size="sm")
|
||||
|
||||
outputgallery_sendto_inpaint = gr.Button(
|
||||
value="Inpaint",
|
||||
interactive=False,
|
||||
elem_classes="outputgallery_sendto",
|
||||
).style(size="sm")
|
||||
|
||||
outputgallery_sendto_outpaint = gr.Button(
|
||||
value="Outpaint",
|
||||
interactive=False,
|
||||
elem_classes="outputgallery_sendto",
|
||||
).style(size="sm")
|
||||
|
||||
outputgallery_sendto_upscaler = gr.Button(
|
||||
value="Upscaler",
|
||||
interactive=False,
|
||||
elem_classes="outputgallery_sendto",
|
||||
).style(size="sm")
|
||||
|
||||
# --- Event handlers
|
||||
|
||||
def on_clear_gallery():
|
||||
return [
|
||||
gr.Gallery.update(
|
||||
value=[],
|
||||
visible=False,
|
||||
),
|
||||
gr.Image.update(
|
||||
visible=True,
|
||||
),
|
||||
]
|
||||
|
||||
def on_select_subdir(subdir) -> list:
|
||||
# evt.value is the subdirectory name
|
||||
new_images = outputgallery_filenames(subdir)
|
||||
new_label = (
|
||||
f"{len(new_images)} images in {os.path.join(output_dir, subdir)}"
|
||||
)
|
||||
return [
|
||||
new_images,
|
||||
gr.Gallery.update(
|
||||
value=new_images,
|
||||
label=new_label,
|
||||
visible=len(new_images) > 0,
|
||||
),
|
||||
gr.Image.update(
|
||||
label=new_label,
|
||||
visible=len(new_images) == 0,
|
||||
),
|
||||
]
|
||||
|
||||
def on_refresh(current_subdir: str) -> list:
|
||||
# get an up to date subdirectory list
|
||||
refreshed_subdirs = output_subdirs()
|
||||
# get the images using either the current subdirectory or the most recent valid one
|
||||
new_subdir = (
|
||||
current_subdir
|
||||
if current_subdir in refreshed_subdirs
|
||||
else refreshed_subdirs[0]
|
||||
)
|
||||
new_images = outputgallery_filenames(new_subdir)
|
||||
new_label = f"{len(new_images)} images in {os.path.join(output_dir, new_subdir)}"
|
||||
|
||||
return [
|
||||
gr.Dropdown.update(
|
||||
choices=refreshed_subdirs,
|
||||
value=new_subdir,
|
||||
),
|
||||
refreshed_subdirs,
|
||||
new_images,
|
||||
gr.Gallery.update(
|
||||
value=new_images, label=new_label, visible=len(new_images) > 0
|
||||
),
|
||||
gr.Image.update(
|
||||
label=new_label,
|
||||
visible=len(new_images) == 0,
|
||||
),
|
||||
]
|
||||
|
||||
def on_new_image(subdir, subdir_paths, status) -> list:
|
||||
# prevent error triggered when an image generates before the tab has even been selected
|
||||
subdir_paths = (
|
||||
subdir_paths
|
||||
if len(subdir_paths) > 0
|
||||
else [get_generated_imgs_todays_subdir()]
|
||||
)
|
||||
|
||||
# only update if the current subdir is the most recent one as new images only go there
|
||||
if subdir_paths[0] == subdir:
|
||||
new_images = outputgallery_filenames(subdir)
|
||||
new_label = f"{len(new_images)} images in {os.path.join(output_dir, subdir)} - {status}"
|
||||
|
||||
return [
|
||||
new_images,
|
||||
gr.Gallery.update(
|
||||
value=new_images,
|
||||
label=new_label,
|
||||
visible=len(new_images) > 0,
|
||||
),
|
||||
gr.Image.update(
|
||||
label=new_label,
|
||||
visible=len(new_images) == 0,
|
||||
),
|
||||
]
|
||||
else:
|
||||
# otherwise change nothing, (only untyped gradio gr.update() does this)
|
||||
return [gr.update(), gr.update(), gr.update()]
|
||||
|
||||
def on_select_image(images: list[str], evt: gr.SelectData) -> list:
|
||||
# evt.index is an index into the full list of filenames for the current subdirectory
|
||||
filename = images[evt.index]
|
||||
params = displayable_metadata(filename)
|
||||
|
||||
if params:
|
||||
return [
|
||||
filename,
|
||||
list(map(list, params["parameters"].items())),
|
||||
]
|
||||
|
||||
return [
|
||||
filename,
|
||||
[["Status", "No parameters found"]],
|
||||
]
|
||||
|
||||
def on_outputgallery_filename_change(filename: str) -> list:
|
||||
exists = filename != "None" and os.path.exists(filename)
|
||||
return [
|
||||
# disable or enable each of the sendto button based on whether an image is selected
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
]
|
||||
|
||||
# The time first our tab is selected we need to do an initial refresh to populate
|
||||
# the subdirectory select box and the images from the most recent subdirectory.
|
||||
#
|
||||
# We do it at this point rather than setting this up in the controls' definitions
|
||||
# as when you refresh the browser you always get what was *initially* set, which
|
||||
# won't include any new subdirectories or images that might have created since
|
||||
# the application was started. Doing it this way means a browser refresh/reload
|
||||
# always gets the most up to date data.
|
||||
def on_select_tab(subdir_paths):
|
||||
if len(subdir_paths) == 0:
|
||||
return on_refresh("")
|
||||
else:
|
||||
return (
|
||||
# Change nothing, (only untyped gr.update() does this)
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
)
|
||||
|
||||
# Unfortunately as of gradio 3.22.0 gr.update against Galleries doesn't support
|
||||
# things set with .style, nor the elem_classes kwarg so we have to directly set
|
||||
# things up via JavaScript if we want the client to take notice of any of our
|
||||
# changes to the number of columns after it decides to put them back to the
|
||||
# original number when we change something
|
||||
def js_set_columns_in_browser(timeout_length):
|
||||
return f"""
|
||||
(new_cols) => {{
|
||||
setTimeout(() => {{
|
||||
required_style = "auto ".repeat(new_cols).trim();
|
||||
gallery = document.querySelector('#outputgallery_gallery .grid-container');
|
||||
if (gallery) {{
|
||||
gallery.style.gridTemplateColumns = required_style
|
||||
}}
|
||||
}}, {timeout_length});
|
||||
return []; // prevents console error from gradio
|
||||
}}
|
||||
"""
|
||||
|
||||
# --- Wire handlers up to the actions
|
||||
|
||||
# - Many actions reset the number of columns shown in the gallery on the browser end,
|
||||
# so we have to set them back to what we think they should be after the initial
|
||||
# action.
|
||||
# - None of the actions on this tab trigger inference, and we want the user to be able
|
||||
# to do them whilst other tabs have ongoing inference running. Waiting in the queue
|
||||
# behind inference jobs would mean the UI can't fully respond until the inference tasks
|
||||
# complete, hence queue=False on all of these.
|
||||
set_gallery_columns_immediate = dict(
|
||||
fn=None,
|
||||
inputs=[image_columns],
|
||||
# gradio blanks the UI on Chrome on Linux on gallery select if I don't put an output here
|
||||
outputs=[dev_null],
|
||||
_js=js_set_columns_in_browser(0),
|
||||
queue=False,
|
||||
)
|
||||
|
||||
# setting columns after selecting a gallery item needs a real timeout length for the
|
||||
# number of columns to actually be applied. Not really sure why, maybe something has
|
||||
# to finish animating?
|
||||
set_gallery_columns_delayed = dict(
|
||||
set_gallery_columns_immediate, _js=js_set_columns_in_browser(250)
|
||||
)
|
||||
|
||||
# clearing images when we need to completely change what's in the gallery avoids current
|
||||
# images being shown replacing piecemeal and prevents weirdness and errors if the user
|
||||
# selects an image during the replacement phase.
|
||||
clear_gallery = dict(
|
||||
fn=on_clear_gallery,
|
||||
inputs=None,
|
||||
outputs=[gallery, logo],
|
||||
queue=False,
|
||||
)
|
||||
|
||||
image_columns.change(**set_gallery_columns_immediate)
|
||||
|
||||
subdirectories.select(**clear_gallery).then(
|
||||
on_select_subdir,
|
||||
[subdirectories],
|
||||
[gallery_files, gallery, logo],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_immediate)
|
||||
|
||||
refresh.click(**clear_gallery).then(
|
||||
on_refresh,
|
||||
[subdirectories],
|
||||
[subdirectories, subdirectory_paths, gallery_files, gallery, logo],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_immediate)
|
||||
|
||||
gallery.select(
|
||||
on_select_image,
|
||||
[gallery_files],
|
||||
[outputgallery_filename, image_parameters],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_delayed)
|
||||
|
||||
outputgallery_filename.change(
|
||||
on_outputgallery_filename_change,
|
||||
[outputgallery_filename],
|
||||
[
|
||||
outputgallery_sendto_txt2img,
|
||||
outputgallery_sendto_img2img,
|
||||
outputgallery_sendto_inpaint,
|
||||
outputgallery_sendto_outpaint,
|
||||
outputgallery_sendto_upscaler,
|
||||
],
|
||||
queue=False,
|
||||
)
|
||||
|
||||
# We should have been given the .select function for our tab, so set it up
|
||||
def outputgallery_tab_select(select):
|
||||
select(
|
||||
fn=on_select_tab,
|
||||
inputs=[subdirectory_paths],
|
||||
outputs=[
|
||||
subdirectories,
|
||||
subdirectory_paths,
|
||||
gallery_files,
|
||||
gallery,
|
||||
logo,
|
||||
],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_immediate)
|
||||
|
||||
# We should have been passed a list of components on other tabs that update
|
||||
# when a new image has generated on that tab, so set things up so the user
|
||||
# will see that new image if they are looking at today's subdirectory
|
||||
def outputgallery_watch(components: gr.Textbox):
|
||||
for component in components:
|
||||
component.change(
|
||||
on_new_image,
|
||||
inputs=[subdirectories, subdirectory_paths, component],
|
||||
outputs=[gallery_files, gallery, logo],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_immediate)
|
||||
@@ -4,7 +4,6 @@ import os
|
||||
from pathlib import Path
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
StoppingCriteriaList,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import available_devices
|
||||
|
||||
@@ -23,6 +22,7 @@ def user(message, history):
|
||||
|
||||
sharkModel = 0
|
||||
sharded_model = 0
|
||||
vicuna_model = 0
|
||||
|
||||
|
||||
start_message_vicuna = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
|
||||
@@ -33,16 +33,21 @@ def chat(curr_system_message, history, model):
|
||||
print(f"In chat for {model}")
|
||||
global sharded_model
|
||||
global past_key_values
|
||||
global vicuna_model
|
||||
if "vicuna" in model:
|
||||
from apps.language_models.scripts.sharded_vicuna_fp32 import (
|
||||
tokenizer,
|
||||
get_sharded_model,
|
||||
from apps.language_models.src.pipelines.vicuna_pipeline import (
|
||||
Vicuna,
|
||||
)
|
||||
|
||||
SAMPLE_INPUT_LEN = 137
|
||||
curr_system_message = start_message_vicuna
|
||||
if sharded_model == 0:
|
||||
sharded_model = get_sharded_model()
|
||||
if vicuna_model == 0:
|
||||
first_vic_vmfb_path = Path("first_vicuna.vmfb")
|
||||
second_vic_vmfb_path = Path("second_vicuna.vmfb")
|
||||
vicuna_model = Vicuna(
|
||||
"vicuna",
|
||||
first_vicuna_vmfb_path=first_vic_vmfb_path,
|
||||
second_vicuna_vmfb_path=second_vic_vmfb_path,
|
||||
)
|
||||
messages = curr_system_message + "".join(
|
||||
[
|
||||
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
|
||||
@@ -51,73 +56,30 @@ def chat(curr_system_message, history, model):
|
||||
)
|
||||
prompt = messages.strip()
|
||||
print("prompt = ", prompt)
|
||||
input_ids = tokenizer(prompt).input_ids
|
||||
new_sentence = ""
|
||||
for _ in range(200):
|
||||
original_input_ids = input_ids
|
||||
input_id_len = len(input_ids)
|
||||
pad_len = SAMPLE_INPUT_LEN - input_id_len
|
||||
attention_mask = torch.ones([1, input_id_len], dtype=torch.int64)
|
||||
input_ids = torch.tensor(input_ids)
|
||||
input_ids = input_ids.reshape([1, input_id_len])
|
||||
attention_mask = torch.nn.functional.pad(
|
||||
torch.tensor(attention_mask),
|
||||
(0, pad_len),
|
||||
mode="constant",
|
||||
value=0,
|
||||
)
|
||||
sentence = vicuna_model.generate(prompt)
|
||||
|
||||
if _ == 0:
|
||||
output = sharded_model.forward(input_ids, is_first=True)
|
||||
else:
|
||||
output = sharded_model.forward(
|
||||
input_ids, past_key_values=past_key_values, is_first=False
|
||||
)
|
||||
logits = output["logits"]
|
||||
past_key_values = output["past_key_values"]
|
||||
new_word = tokenizer.decode(torch.argmax(logits[:, -1, :], dim=1))
|
||||
if new_word == "</s>":
|
||||
break
|
||||
new_sentence += " " + new_word
|
||||
history[-1][1] = new_sentence
|
||||
partial_text = ""
|
||||
for new_text in sentence.split(" "):
|
||||
# print(new_text)
|
||||
partial_text += new_text + " "
|
||||
history[-1][1] = partial_text
|
||||
# Yield an empty string to cleanup the message textbox and the updated conversation history
|
||||
yield history
|
||||
next_token = torch.argmax(logits[:, input_id_len - 1, :], dim=1)
|
||||
original_input_ids.append(next_token)
|
||||
input_ids = [next_token]
|
||||
print(new_sentence)
|
||||
history[-1][1] = sentence
|
||||
return history
|
||||
|
||||
# else Model is StableLM
|
||||
global sharkModel
|
||||
from apps.language_models.scripts.stablelm import (
|
||||
compile_stableLM,
|
||||
StopOnTokens,
|
||||
generate,
|
||||
StableLMModel,
|
||||
from apps.language_models.src.pipelines.stablelm_pipeline import (
|
||||
SharkStableLM,
|
||||
)
|
||||
|
||||
if sharkModel == 0:
|
||||
# sharkModel = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/disk/phaneesh/stablelm_3b_f32_cuda_2048_newflags.vmfb")
|
||||
max_sequence_len = 256
|
||||
precision = "fp32"
|
||||
model_name_template = (
|
||||
f"stableLM_linalg_{precision}_seqLen{max_sequence_len}"
|
||||
)
|
||||
# max_new_tokens=512
|
||||
shark_slm = SharkStableLM(
|
||||
"StableLM"
|
||||
) # pass elements from UI as required
|
||||
|
||||
m = AutoModelForCausalLM.from_pretrained(
|
||||
"stabilityai/stablelm-tuned-alpha-3b", torch_dtype=torch.float32
|
||||
)
|
||||
stableLMModel = StableLMModel(m)
|
||||
input_ids = torch.randint(3, (1, max_sequence_len))
|
||||
attention_mask = torch.randint(3, (1, max_sequence_len))
|
||||
sharkModel = compile_stableLM(
|
||||
stableLMModel,
|
||||
tuple([input_ids, attention_mask]),
|
||||
model_name_template,
|
||||
None, # provide a fully qualified path to vmfb file if already exists
|
||||
)
|
||||
# Initialize a StopOnTokens object
|
||||
stop = StopOnTokens()
|
||||
# Construct the input message string for the model by concatenating the current system message and conversation history
|
||||
if len(curr_system_message.split()) > 160:
|
||||
print("clearing context")
|
||||
@@ -129,18 +91,10 @@ def chat(curr_system_message, history, model):
|
||||
]
|
||||
)
|
||||
|
||||
generate_kwargs = dict(
|
||||
new_text=messages,
|
||||
max_new_tokens=512,
|
||||
do_sample=True,
|
||||
top_p=0.95,
|
||||
top_k=1000,
|
||||
temperature=1.0,
|
||||
num_beams=1,
|
||||
stopping_criteria=StoppingCriteriaList([stop]),
|
||||
sharkStableLM=sharkModel,
|
||||
)
|
||||
words_list = generate(**generate_kwargs)
|
||||
generate_kwargs = dict(prompt=messages)
|
||||
|
||||
words_list = shark_slm.generate(**generate_kwargs)
|
||||
|
||||
partial_text = ""
|
||||
for new_text in words_list:
|
||||
# print(new_text)
|
||||
@@ -161,17 +115,17 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
"TheBloke/vicuna-7B-1.1-HF",
|
||||
],
|
||||
)
|
||||
device_value = None
|
||||
for d in available_devices:
|
||||
if "vulkan" in d:
|
||||
device_value = d
|
||||
break
|
||||
|
||||
supported_devices = [
|
||||
device for device in available_devices if "cuda" in device
|
||||
]
|
||||
enabled = len(supported_devices) > 0
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=device_value if device_value else available_devices[0],
|
||||
interactive=False,
|
||||
choices=available_devices,
|
||||
value=supported_devices[0]
|
||||
if enabled
|
||||
else "Only CUDA Supported for now",
|
||||
choices=supported_devices,
|
||||
interactive=enabled,
|
||||
)
|
||||
chatbot = gr.Chatbot().style(height=500)
|
||||
with gr.Row():
|
||||
@@ -180,12 +134,13 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
label="Chat Message Box",
|
||||
placeholder="Chat Message Box",
|
||||
show_label=False,
|
||||
interactive=enabled,
|
||||
).style(container=False)
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
submit = gr.Button("Submit")
|
||||
stop = gr.Button("Stop")
|
||||
clear = gr.Button("Clear")
|
||||
submit = gr.Button("Submit", interactive=enabled)
|
||||
stop = gr.Button("Stop", interactive=enabled)
|
||||
clear = gr.Button("Clear", interactive=enabled)
|
||||
system_msg = gr.Textbox(
|
||||
start_message, label="System Message", interactive=False, visible=False
|
||||
)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from pathlib import Path
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
@@ -17,7 +16,8 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
predefined_models,
|
||||
cancel_sd,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.png_metadata import import_png_metadata
|
||||
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
Text2ImagePipeline,
|
||||
@@ -27,7 +27,10 @@ from apps.stable_diffusion.src import (
|
||||
save_output_img,
|
||||
prompt_examples,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generation_text_info,
|
||||
)
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
@@ -202,9 +205,11 @@ def txt2img_inf(
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Text-to-Image", i + 1, batch_count, batch_size
|
||||
)
|
||||
|
||||
return generated_imgs, text_output
|
||||
return generated_imgs, text_output, ""
|
||||
|
||||
|
||||
def encode_pil_to_base64(images):
|
||||
@@ -308,7 +313,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
+ get_custom_model_files("vae"),
|
||||
)
|
||||
with gr.Column(scale=1, min_width=170):
|
||||
png_info_img = gr.Image(
|
||||
txt2img_png_info_img = gr.Image(
|
||||
label="Import PNG info",
|
||||
elem_id="txt2img_prompt_image",
|
||||
type="pil",
|
||||
@@ -469,16 +474,13 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
).style(columns=[2], object_fit="contain")
|
||||
output_dir = (
|
||||
args.output_dir if args.output_dir else Path.cwd()
|
||||
)
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at {output_dir}",
|
||||
value=f"Images will be saved at {get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
txt2img_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
txt2img_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
txt2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
@@ -514,22 +516,30 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
],
|
||||
outputs=[txt2img_gallery, std_output],
|
||||
outputs=[txt2img_gallery, std_output, txt2img_status],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
status_kwargs = dict(
|
||||
fn=lambda bc, bs: status_label("Text-to-Image", 0, bc, bs),
|
||||
inputs=[batch_count, batch_size],
|
||||
outputs=txt2img_status,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
|
||||
**kwargs
|
||||
)
|
||||
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
png_info_img.change(
|
||||
txt2img_png_info_img.change(
|
||||
fn=import_png_metadata,
|
||||
inputs=[
|
||||
png_info_img,
|
||||
txt2img_png_info_img,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
steps,
|
||||
@@ -542,7 +552,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
txt2img_hf_model_id,
|
||||
],
|
||||
outputs=[
|
||||
png_info_img,
|
||||
txt2img_png_info_img,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
steps,
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from pathlib import Path
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
import sys
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
import base64
|
||||
@@ -17,16 +15,16 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
predefined_upscaler_models,
|
||||
cancel_sd,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
UpscalerPipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
|
||||
from apps.stable_diffusion.src.utils import get_generated_imgs_path
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
@@ -66,6 +64,9 @@ def upscaler_inf(
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
@@ -202,13 +203,24 @@ def upscaler_inf(
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
high_res_img.paste(upscaled_image[0], (j * 4, i * 4))
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
high_res_img.paste(upscaled_image[0], (j * 4, i * 4))
|
||||
|
||||
save_output_img(high_res_img, img_seed, extra_info)
|
||||
generated_imgs.append(high_res_img)
|
||||
seeds.append(img_seed)
|
||||
global_obj.get_sd_obj().log += "\n"
|
||||
yield generated_imgs, global_obj.get_sd_obj().log
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(high_res_img, img_seed, extra_info)
|
||||
generated_imgs.append(high_res_img)
|
||||
seeds.append(img_seed)
|
||||
global_obj.get_sd_obj().log += "\n"
|
||||
yield generated_imgs, global_obj.get_sd_obj().log, status_label(
|
||||
"Upscaler", current_batch + 1, batch_count, batch_size
|
||||
)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
@@ -220,7 +232,7 @@ def upscaler_inf(
|
||||
text_output += global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
yield generated_imgs, text_output
|
||||
yield generated_imgs, text_output, ""
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
@@ -495,16 +507,14 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
).style(columns=[2], object_fit="contain")
|
||||
output_dir = (
|
||||
args.output_dir if args.output_dir else Path.cwd()
|
||||
)
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at {output_dir}",
|
||||
value=f"Images will be saved at {get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
upscaler_status = gr.Textbox(visible=False)
|
||||
|
||||
with gr.Row():
|
||||
upscaler_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
upscaler_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
@@ -539,13 +549,21 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
],
|
||||
outputs=[upscaler_gallery, std_output],
|
||||
outputs=[upscaler_gallery, std_output, upscaler_status],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
|
||||
status_kwargs = dict(
|
||||
fn=lambda bc, bs: status_label("Upscaler", 0, bc, bs),
|
||||
inputs=[batch_count, batch_size],
|
||||
outputs=upscaler_status,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
|
||||
**kwargs
|
||||
)
|
||||
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
9
apps/stable_diffusion/web/utils/common_label_calc.py
Normal file
9
apps/stable_diffusion/web/utils/common_label_calc.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# functions for generating labels used in common by tabs across the UI
|
||||
|
||||
|
||||
def status_label(tab_name, batch_index=0, batch_count=1, batch_size=1):
|
||||
if batch_index < batch_count:
|
||||
bs = f"x{batch_size}" if batch_size > 1 else ""
|
||||
return f"{tab_name} generating {batch_index+1}/{batch_count}{bs}"
|
||||
else:
|
||||
return f"{tab_name} complete"
|
||||
@@ -1,9 +1,11 @@
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import gradio
|
||||
from os import listdir
|
||||
|
||||
gradio_tmp_imgs_folder = os.path.join(os.getcwd(), "shark_tmp/")
|
||||
gradio_tmp_galleries_folder = os.path.join(gradio_tmp_imgs_folder, "galleries")
|
||||
|
||||
|
||||
# Clear all gradio tmp images
|
||||
@@ -15,6 +17,10 @@ def clear_gradio_tmp_imgs_folder():
|
||||
if fileName.startswith("tmp") and fileName.endswith(".png"):
|
||||
os.remove(gradio_tmp_imgs_folder + fileName)
|
||||
|
||||
# Clear all gradio tmp files created by galleries
|
||||
if os.path.exists(gradio_tmp_galleries_folder):
|
||||
shutil.rmtree(gradio_tmp_galleries_folder)
|
||||
|
||||
|
||||
# Overwrite save_pil_to_file from gradio to save tmp images generated by gradio into our own tmp folder
|
||||
def save_pil_to_file(pil_image, dir=None):
|
||||
|
||||
6
apps/stable_diffusion/web/utils/metadata/__init__.py
Normal file
6
apps/stable_diffusion/web/utils/metadata/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .png_metadata import (
|
||||
import_png_metadata,
|
||||
)
|
||||
from .display import (
|
||||
displayable_metadata,
|
||||
)
|
||||
31
apps/stable_diffusion/web/utils/metadata/csv_metadata.py
Normal file
31
apps/stable_diffusion/web/utils/metadata/csv_metadata.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import csv
|
||||
import os
|
||||
from .format import humanize, humanizable
|
||||
|
||||
|
||||
def csv_path(image_filename: str):
|
||||
return os.path.join(os.path.dirname(image_filename), "imgs_details.csv")
|
||||
|
||||
|
||||
def has_csv(image_filename: str) -> bool:
|
||||
return os.path.exists(csv_path(image_filename))
|
||||
|
||||
|
||||
def parse_csv(image_filename: str):
|
||||
# We use a reader instead of a DictReader here for images_details.csv files due to the lack of
|
||||
# headers, and then match up the return list for each row with our guess at which column format
|
||||
# the file is using.
|
||||
|
||||
# we assume the final column of the csv has the original filename with full path and match that
|
||||
# against the image_filename. We then exclude the filename from the output, hence the -1's.
|
||||
csv_filename = csv_path(image_filename)
|
||||
|
||||
matches = [
|
||||
humanize(row)
|
||||
for row in csv.reader(open(csv_filename, "r", newline=""))
|
||||
if row
|
||||
and humanizable(row)
|
||||
and os.path.basename(image_filename) in row[-1]
|
||||
]
|
||||
|
||||
return matches[0] if matches else {}
|
||||
50
apps/stable_diffusion/web/utils/metadata/display.py
Normal file
50
apps/stable_diffusion/web/utils/metadata/display.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import json
|
||||
import os
|
||||
from PIL import Image
|
||||
from .png_metadata import parse_generation_parameters
|
||||
from .exif_metadata import has_exif, parse_exif
|
||||
from .csv_metadata import has_csv, parse_csv
|
||||
from .format import compact, humanize
|
||||
|
||||
|
||||
def displayable_metadata(image_filename: str) -> dict:
|
||||
pil_image = Image.open(image_filename)
|
||||
|
||||
# we have PNG generation parameters (preferred, as it's what the txt2img dropzone reads,
|
||||
# and we go via that for SendTo, and is directly tied to the image)
|
||||
if "parameters" in pil_image.info:
|
||||
return {
|
||||
"source": "png",
|
||||
"parameters": compact(
|
||||
parse_generation_parameters(pil_image.info["parameters"])
|
||||
),
|
||||
}
|
||||
|
||||
# we have a matching json file (next most likely to be accurate when it's there)
|
||||
json_path = os.path.splitext(image_filename)[0] + ".json"
|
||||
if os.path.isfile(json_path):
|
||||
with open(json_path) as params_file:
|
||||
return {
|
||||
"source": "json",
|
||||
"parameters": compact(
|
||||
humanize(json.load(params_file), includes_filename=False)
|
||||
),
|
||||
}
|
||||
|
||||
# we have a CSV file so try that (can be different shapes, and it usually has no
|
||||
# headers/param names so of the things we we *know* have parameters, it's the
|
||||
# last resort)
|
||||
if has_csv(image_filename):
|
||||
params = parse_csv(image_filename)
|
||||
if params: # we might not have found the filename in the csv
|
||||
return {
|
||||
"source": "csv",
|
||||
"parameters": compact(params), # already humanized
|
||||
}
|
||||
|
||||
# EXIF data, probably a .jpeg, may well not include parameters, but at least it's *something*
|
||||
if has_exif(image_filename):
|
||||
return {"source": "exif", "parameters": parse_exif(pil_image)}
|
||||
|
||||
# we've got nothing
|
||||
return None
|
||||
52
apps/stable_diffusion/web/utils/metadata/exif_metadata.py
Normal file
52
apps/stable_diffusion/web/utils/metadata/exif_metadata.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from PIL import Image
|
||||
from PIL.ExifTags import Base as EXIFKeys, TAGS, IFD, GPSTAGS
|
||||
|
||||
|
||||
def has_exif(image_filename: str) -> bool:
|
||||
return True if Image.open(image_filename).getexif() else False
|
||||
|
||||
|
||||
def parse_exif(pil_image: Image) -> dict:
|
||||
img_exif = pil_image.getexif()
|
||||
|
||||
# See this stackoverflow answer for where most this comes from: https://stackoverflow.com/a/75357594
|
||||
# I did try to use the exif library but it broke just as much as my initial attempt at this (albeit I
|
||||
# I was probably using it wrong) so I reverted back to using PIL with more filtering and saved a
|
||||
# dependency
|
||||
exif_tags = {
|
||||
TAGS.get(key, key): str(val)
|
||||
for (key, val) in img_exif.items()
|
||||
if key in TAGS
|
||||
and key not in (EXIFKeys.ExifOffset, EXIFKeys.GPSInfo)
|
||||
and val
|
||||
and (not isinstance(val, bytes))
|
||||
and (not str(val).isspace())
|
||||
}
|
||||
|
||||
def try_get_ifd(ifd_id):
|
||||
try:
|
||||
return img_exif.get_ifd(ifd_id).items()
|
||||
except KeyError:
|
||||
return {}
|
||||
|
||||
ifd_tags = {
|
||||
TAGS.get(key, key): str(val)
|
||||
for ifd_id in IFD
|
||||
for (key, val) in try_get_ifd(ifd_id)
|
||||
if ifd_id != IFD.GPSInfo
|
||||
and key in TAGS
|
||||
and val
|
||||
and (not isinstance(val, bytes))
|
||||
and (not str(val).isspace())
|
||||
}
|
||||
|
||||
gps_tags = {
|
||||
GPSTAGS.get(key, key): str(val)
|
||||
for (key, val) in try_get_ifd(IFD.GPSInfo)
|
||||
if key in GPSTAGS
|
||||
and val
|
||||
and (not isinstance(val, bytes))
|
||||
and (not str(val).isspace())
|
||||
}
|
||||
|
||||
return {**exif_tags, **ifd_tags, **gps_tags}
|
||||
115
apps/stable_diffusion/web/utils/metadata/format.py
Normal file
115
apps/stable_diffusion/web/utils/metadata/format.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# As SHARK has evolved more columns have been added to images_details.csv. However, since
|
||||
# no version of the CSV has any headers (yet) we don't actually have anything within the
|
||||
# file that tells us which parameter each column is for. So this is a list of known patterns
|
||||
# indexed by length which is what we're going to have to use to guess which columns are the
|
||||
# right ones for the file we're looking at.
|
||||
|
||||
# The same ordering is used for JSON, but these do have key names, however they are not very
|
||||
# human friendly, nor do they match up with the what is written to the .png headers
|
||||
|
||||
# So these are functions to try and get something consistent out the raw input from all
|
||||
# these sources
|
||||
|
||||
PARAMS_FORMATS = {
|
||||
9: {
|
||||
"VARIANT": "Model",
|
||||
"SCHEDULER": "Sampler",
|
||||
"PROMPT": "Prompt",
|
||||
"NEG_PROMPT": "Negative prompt",
|
||||
"SEED": "Seed",
|
||||
"CFG_SCALE": "CFG scale",
|
||||
"PRECISION": "Precision",
|
||||
"STEPS": "Steps",
|
||||
"OUTPUT": "Filename",
|
||||
},
|
||||
10: {
|
||||
"MODEL": "Model",
|
||||
"VARIANT": "Variant",
|
||||
"SCHEDULER": "Sampler",
|
||||
"PROMPT": "Prompt",
|
||||
"NEG_PROMPT": "Negative prompt",
|
||||
"SEED": "Seed",
|
||||
"CFG_SCALE": "CFG scale",
|
||||
"PRECISION": "Precision",
|
||||
"STEPS": "Steps",
|
||||
"OUTPUT": "Filename",
|
||||
},
|
||||
12: {
|
||||
"VARIANT": "Model",
|
||||
"SCHEDULER": "Sampler",
|
||||
"PROMPT": "Prompt",
|
||||
"NEG_PROMPT": "Negative prompt",
|
||||
"SEED": "Seed",
|
||||
"CFG_SCALE": "CFG scale",
|
||||
"PRECISION": "Precision",
|
||||
"STEPS": "Steps",
|
||||
"HEIGHT": "Height",
|
||||
"WIDTH": "Width",
|
||||
"MAX_LENGTH": "Max Length",
|
||||
"OUTPUT": "Filename",
|
||||
},
|
||||
}
|
||||
|
||||
PARAMS_FORMAT_LONGEST = PARAMS_FORMATS[max(PARAMS_FORMATS.keys())]
|
||||
|
||||
|
||||
def compact(metadata: dict) -> dict:
|
||||
# we don't want to alter the original dictionary
|
||||
result = dict(metadata)
|
||||
|
||||
# discard the filename because we should already have it
|
||||
if result.keys() & {"Filename"}:
|
||||
result.pop("Filename")
|
||||
|
||||
# make showing the sizes more compact by using only one line each
|
||||
if result.keys() & {"Size-1", "Size-2"}:
|
||||
result["Size"] = f"{result.pop('Size-1')}x{result.pop('Size-2')}"
|
||||
elif result.keys() & {"Height", "Width"}:
|
||||
result["Size"] = f"{result.pop('Height')}x{result.pop('Width')}"
|
||||
|
||||
if result.keys() & {"Hires resize-1", "Hires resize-1"}:
|
||||
hires_y = result.pop("Hires resize-1")
|
||||
hires_x = result.pop("Hires resize-2")
|
||||
|
||||
if hires_x == 0 and hires_y == 0:
|
||||
result["Hires resize"] = "None"
|
||||
else:
|
||||
result["Hires resize"] = f"{hires_y}x{hires_x}"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def humanizable(metadata: dict | list[str], includes_filename=True) -> dict:
|
||||
lookup_key = len(metadata) + (0 if includes_filename else 1)
|
||||
return lookup_key in PARAMS_FORMATS.keys()
|
||||
|
||||
|
||||
def humanize(metadata: dict | list[str], includes_filename=True) -> dict:
|
||||
lookup_key = len(metadata) + (0 if includes_filename else 1)
|
||||
|
||||
# For lists we can only work based on the length, we have no other information
|
||||
if isinstance(metadata, list):
|
||||
if humanizable(metadata, includes_filename):
|
||||
return dict(zip(PARAMS_FORMATS[lookup_key].values(), metadata))
|
||||
else:
|
||||
raise KeyError(
|
||||
f"Humanize could not find the format for a parameter list of length {len(metadata)}"
|
||||
)
|
||||
|
||||
# For dictionaries we try to use the matching length parameter format if
|
||||
# available, otherwise we use the longest. Then we swap keys in the
|
||||
# metadata that match keys in the format for the friendlier name that we
|
||||
# have set in the format value
|
||||
if isinstance(metadata, dict):
|
||||
if humanizable(metadata, includes_filename):
|
||||
format = PARAMS_FORMATS[lookup_key]
|
||||
else:
|
||||
format = PARAMS_FORMAT_LONGEST
|
||||
|
||||
return {
|
||||
format[key]: value
|
||||
for (key, value) in metadata.items()
|
||||
if key in format.keys()
|
||||
}
|
||||
|
||||
raise TypeError("Can only humanize parameter lists or dictionaries")
|
||||
@@ -40,7 +40,7 @@ cmake --build build/
|
||||
*Prepare the model*
|
||||
```bash
|
||||
wget https://storage.googleapis.com/shark_tank/latest/resnet50_tf/resnet50_tf.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 resnet50_tf.mlir -o resnet50_tf.vmfb
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 resnet50_tf.mlir -o resnet50_tf.vmfb
|
||||
```
|
||||
*Prepare the input*
|
||||
|
||||
@@ -65,18 +65,18 @@ A tool for benchmarking other models is built and can be invoked with a command
|
||||
see `./build/vulkan_gui/iree-vulkan-gui --help` for an explanation on the function input. For example, stable diffusion unet can be tested with the following commands:
|
||||
```bash
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/stable_diff_tf.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 stable_diff_tf.mlir -o stable_diff_tf.vmfb
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 stable_diff_tf.mlir -o stable_diff_tf.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=2x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32
|
||||
```
|
||||
VAE and Autoencoder are also available
|
||||
```bash
|
||||
# VAE
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/vae_tf/vae.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 vae.mlir -o vae.vmfb
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 vae.mlir -o vae.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x4x64x64xf32
|
||||
|
||||
# CLIP Autoencoder
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/clip_tf/clip_autoencoder.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 clip_autoencoder.mlir -o clip_autoencoder.vmfb
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 clip_autoencoder.mlir -o clip_autoencoder.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x77xi32 --function_input=1x77xi32
|
||||
```
|
||||
|
||||
@@ -21,7 +21,7 @@ endif()
|
||||
# Compile mnist.mlir to mnist.vmfb.
|
||||
set(_COMPILE_TOOL_EXECUTABLE $<TARGET_FILE:iree-compile>)
|
||||
set(_COMPILE_ARGS)
|
||||
list(APPEND _COMPILE_ARGS "--iree-input-type=mhlo")
|
||||
list(APPEND _COMPILE_ARGS "--iree-input-type=auto")
|
||||
list(APPEND _COMPILE_ARGS "--iree-hal-target-backends=llvm-cpu")
|
||||
list(APPEND _COMPILE_ARGS "${IREE_SOURCE_DIR}/samples/models/mnist.mlir")
|
||||
list(APPEND _COMPILE_ARGS "-o")
|
||||
|
||||
@@ -70,11 +70,11 @@ mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"resnet50", frontend="torch"
|
||||
)
|
||||
|
||||
shark_module = SharkInference(mlir_model, func_name, mlir_dialect="linalg")
|
||||
shark_module = SharkInference(mlir_model, mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
path = shark_module.save_module()
|
||||
shark_module.load_module(path)
|
||||
result = shark_module.forward((img.detach().numpy(),))
|
||||
result = shark_module("forward", (img.detach().numpy(),))
|
||||
|
||||
print("The top 3 results obtained via shark_runner is:")
|
||||
print(top3_possibilities(torch.from_numpy(result)))
|
||||
|
||||
@@ -60,7 +60,7 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
def get_iree_frontend_args(frontend):
|
||||
if frontend in ["torch", "pytorch", "linalg", "tm_tensor"]:
|
||||
return ["--iree-llvmcpu-target-cpu-features=host"]
|
||||
elif frontend in ["tensorflow", "tf", "mhlo"]:
|
||||
elif frontend in ["tensorflow", "tf", "mhlo", "stablehlo"]:
|
||||
return [
|
||||
"--iree-llvmcpu-target-cpu-features=host",
|
||||
"--iree-mhlo-demote-i64-to-i32=false",
|
||||
@@ -265,8 +265,8 @@ def compile_module_to_flatbuffer(
|
||||
args += extra_args
|
||||
|
||||
if frontend in ["tensorflow", "tf"]:
|
||||
input_type = "mhlo"
|
||||
elif frontend in ["mhlo", "tosa"]:
|
||||
input_type = "auto"
|
||||
elif frontend in ["stablehlo", "tosa"]:
|
||||
input_type = frontend
|
||||
elif frontend in ["tflite", "tflite-tosa"]:
|
||||
input_type = "tosa"
|
||||
@@ -367,7 +367,7 @@ def export_iree_module_to_vmfb(
|
||||
def export_module_to_mlir_file(module, frontend, directory: str):
|
||||
# TODO: write proper documentation.
|
||||
mlir_str = module
|
||||
if frontend in ["tensorflow", "tf", "mhlo", "tflite"]:
|
||||
if frontend in ["tensorflow", "tf", "mhlo", "stablehlo", "tflite"]:
|
||||
mlir_str = module.decode("utf-8")
|
||||
elif frontend in ["pytorch", "torch"]:
|
||||
mlir_str = module.operation.get_asm()
|
||||
|
||||
@@ -117,7 +117,8 @@ def get_extensions(triple):
|
||||
|
||||
if get_vendor(triple) == "NVIDIA" or arch == "rdna3":
|
||||
ext.append("VK_NV_cooperative_matrix")
|
||||
|
||||
if get_vendor(triple) == ["NVIDIA", "AMD", "Intel"]:
|
||||
ext.append("VK_KHR_shader_integer_dot_product")
|
||||
return make_ext_list(ext_list=ext)
|
||||
|
||||
|
||||
@@ -228,6 +229,7 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
cap["shaderIntegerDotProduct"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
@@ -236,12 +238,12 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
if arch == "rdna3":
|
||||
# TODO: Get scope value
|
||||
cap["coopmatCases"] = [
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>"
|
||||
]
|
||||
|
||||
if product == "rx5700xt":
|
||||
cap["storagePushConstant16"] = False
|
||||
cap["storagePushConstant8"] = False
|
||||
@@ -274,7 +276,7 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
|
||||
cap["shaderIntegerDotProduct"] = True
|
||||
cap["storagePushConstant16"] = False
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
cap["storageBuffer8BitAccess"] = True
|
||||
@@ -305,6 +307,7 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
cap["shaderIntegerDotProduct"] = False
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
@@ -367,6 +370,7 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = False
|
||||
cap["shaderIntegerDotProduct"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
@@ -408,11 +412,12 @@ def get_vulkan_target_capabilities(triple):
|
||||
"Quad",
|
||||
]
|
||||
|
||||
cap["shaderFloat16"] = True
|
||||
cap["shaderFloat16"] = False
|
||||
cap["shaderFloat64"] = True
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
cap["shaderIntegerDotProduct"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
@@ -446,6 +451,7 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
cap["shaderIntegerDotProduct"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
|
||||
@@ -311,11 +311,18 @@ def add_attributes(op: ir.Operation, config: List[Dict]):
|
||||
split_k = config["split_k"]
|
||||
elif "SPIRV" in config["pipeline"]:
|
||||
pipeline = config["pipeline"]
|
||||
tile_sizes = [
|
||||
config["work_group_tile_sizes"],
|
||||
config["parallel_tile_sizes"],
|
||||
config["reduction_tile_sizes"],
|
||||
]
|
||||
if pipeline == "SPIRVMatmulPromoteVectorize":
|
||||
tile_sizes = [
|
||||
config["work_group_tile_sizes"]
|
||||
+ [config["reduction_tile_sizes"][-1]],
|
||||
]
|
||||
else:
|
||||
tile_sizes = [
|
||||
config["work_group_tile_sizes"],
|
||||
config["parallel_tile_sizes"],
|
||||
config["reduction_tile_sizes"],
|
||||
]
|
||||
|
||||
workgroup_size = config["work_group_sizes"]
|
||||
if "vector_tile_sizes" in config.keys():
|
||||
tile_sizes += [config["vector_tile_sizes"]]
|
||||
|
||||
@@ -61,6 +61,8 @@ def download_public_file(
|
||||
continue
|
||||
|
||||
destination_filename = os.path.join(destination_folder_name, blob_name)
|
||||
if os.path.isdir(destination_filename):
|
||||
continue
|
||||
with open(destination_filename, "wb") as f:
|
||||
with tqdm.wrapattr(f, "write", total=blob.size) as file_obj:
|
||||
storage_client.download_blob_to_file(blob, file_obj)
|
||||
@@ -210,6 +212,9 @@ def download_model(
|
||||
+ "_BS"
|
||||
+ str(import_args["batch_size"])
|
||||
)
|
||||
elif any(model in model_name for model in ["clip", "unet", "vae"]):
|
||||
# TODO(Ean Garvey): rework extended naming such that device is only included in model_name after .vmfb compilation.
|
||||
model_dir_name = model_name
|
||||
else:
|
||||
model_dir_name = model_name + "_" + frontend
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
@@ -270,6 +275,9 @@ def download_model(
|
||||
tuned_str = "" if tuned is None else "_" + tuned
|
||||
suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
|
||||
filename = os.path.join(model_dir, model_name + suffix)
|
||||
print(
|
||||
f"Verifying that model artifacts were downloaded successfully to {filename}..."
|
||||
)
|
||||
if not os.path.exists(filename):
|
||||
from tank.generate_sharktank import gen_shark_files
|
||||
|
||||
|
||||
@@ -320,6 +320,9 @@ def transform_fx(fx_g):
|
||||
"device": torch.device(type="cpu"),
|
||||
"pin_memory": False,
|
||||
}
|
||||
kwargs_dict1 = {
|
||||
"dtype": torch.float16,
|
||||
}
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
if node.target in [
|
||||
@@ -327,7 +330,16 @@ def transform_fx(fx_g):
|
||||
torch.ops.aten.empty,
|
||||
torch.ops.aten.zeros,
|
||||
]:
|
||||
node.kwargs = kwargs_dict
|
||||
if node.kwargs.get("dtype") == torch.float32:
|
||||
node.kwargs = kwargs_dict
|
||||
|
||||
# Vicuna
|
||||
if node.target in [
|
||||
torch.ops.aten._to_copy,
|
||||
]:
|
||||
if node.kwargs.get("dtype") == torch.float32:
|
||||
node.kwargs = kwargs_dict1
|
||||
|
||||
# Inputs and outputs of aten.var.mean should be upcasted to fp32.
|
||||
if node.target in [torch.ops.aten.var_mean]:
|
||||
with fx_g.graph.inserting_before(node):
|
||||
@@ -337,6 +349,7 @@ def transform_fx(fx_g):
|
||||
kwargs={},
|
||||
)
|
||||
node.args = (new_node, node.args[1])
|
||||
|
||||
if node.name.startswith("getitem"):
|
||||
with fx_g.graph.inserting_before(node):
|
||||
if node.args[0].target in [torch.ops.aten.var_mean]:
|
||||
@@ -349,6 +362,19 @@ def transform_fx(fx_g):
|
||||
node.replace_all_uses_with(new_node)
|
||||
new_node.args = (node,)
|
||||
new_node.kwargs = {"dtype": torch.float16}
|
||||
|
||||
# Change the default dtype of aten.full op. (Vicuna)
|
||||
if node.target in [torch.ops.aten.full]:
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten._to_copy,
|
||||
args=(node,),
|
||||
kwargs={"dtype": torch.float16},
|
||||
)
|
||||
node.append(new_node)
|
||||
node.replace_all_uses_with(new_node)
|
||||
new_node.args = (node,)
|
||||
new_node.kwargs = {"dtype": torch.float16}
|
||||
|
||||
# aten.empty should be filled with zeros.
|
||||
if node.target in [torch.ops.aten.empty]:
|
||||
with fx_g.graph.inserting_after(node):
|
||||
|
||||
@@ -25,7 +25,14 @@ import sys
|
||||
|
||||
|
||||
# supported dialects by the shark-runtime.
|
||||
supported_dialects = {"linalg", "mhlo", "tosa", "tf-lite", "tm_tensor"}
|
||||
supported_dialects = {
|
||||
"linalg",
|
||||
"mhlo",
|
||||
"stablehlo",
|
||||
"tosa",
|
||||
"tf-lite",
|
||||
"tm_tensor",
|
||||
}
|
||||
|
||||
|
||||
class SharkRunner:
|
||||
|
||||
@@ -59,6 +59,7 @@ class SharkTrainer:
|
||||
"torch",
|
||||
"tensorflow",
|
||||
"tf",
|
||||
"stablehlo",
|
||||
"mhlo",
|
||||
"linalg",
|
||||
"tosa",
|
||||
@@ -84,7 +85,7 @@ class SharkTrainer:
|
||||
"tm_tensor",
|
||||
extra_args=extra_args,
|
||||
)
|
||||
elif self.frontend in ["tensorflow", "tf", "mhlo"]:
|
||||
elif self.frontend in ["tensorflow", "tf", "mhlo", "stablehlo"]:
|
||||
self.shark_runner = SharkRunner(
|
||||
self.model,
|
||||
self.input,
|
||||
|
||||
@@ -8,7 +8,7 @@ distilbert-base-uncased,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
facebook/convnext-tiny-224,mhlo,tf,1e-2,1e-3,tf_vit,nhcw-nhwc,True,True,False,"https://github.com/nod-ai/SHARK/issues/311 & https://github.com/nod-ai/SHARK/issues/342","macos"
|
||||
funnel-transformer/small,mhlo,tf,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/201",""
|
||||
google/electra-small-discriminator,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
google/mobilebert-uncased,mhlo,tf,1e-2,1e-3,default,None,True,False,False,"Fails during iree-compile",""
|
||||
google/mobilebert-uncased,mhlo,tf,1e-2,1e-3,default,None,True,False,False,"Fails during iree-compile","macos"
|
||||
google/vit-base-patch16-224,mhlo,tf,1e-2,1e-3,tf_vit,nhcw-nhwc,False,False,False,"",""
|
||||
microsoft/MiniLM-L12-H384-uncased,mhlo,tf,1e-2,1e-3,tf_hf,None,True,False,False,"Fails during iree-compile.",""
|
||||
microsoft/layoutlm-base-uncased,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
@@ -24,9 +24,9 @@ facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nh
|
||||
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/311",""
|
||||
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390","macos"
|
||||
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"https://github.com/nod-ai/SHARK/issues/344",""
|
||||
mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/388","macos"
|
||||
nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/343","macos"
|
||||
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"https://github.com/nod-ai/SHARK/issues/344","macos"
|
||||
mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,False,True,True,"https://github.com/nod-ai/SHARK/issues/388, https://github.com/nod-ai/SHARK/issues/1487","macos"
|
||||
nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,True,True,True,"https://github.com/nod-ai/SHARK/issues/343,https://github.com/nod-ai/SHARK/issues/1487","macos"
|
||||
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos"
|
||||
resnet18,linalg,torch,1e-2,1e-3,default,None,True,True,False,"","macos"
|
||||
resnet50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
@@ -35,13 +35,13 @@ squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","mac
|
||||
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos"
|
||||
efficientnet-v2-s,mhlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
|
||||
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,False,"https://github.com/nod-ai/SHARK/issues/1243",""
|
||||
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"Fails on MacOS builder, VK device lost","macos"
|
||||
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
|
||||
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
|
||||
efficientnet_b0,mhlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""
|
||||
efficientnet_b7,mhlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails on MacOS builder, VK device lost","macos"
|
||||
gpt2,mhlo,tf,1e-2,1e-3,default,None,True,False,False,"","macos"
|
||||
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.",""
|
||||
t5-base,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.","macos"
|
||||
t5-base,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"","macos"
|
||||
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported","macos"
|
||||
t5-large,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"","macos"
|
||||
stabilityai/stable-diffusion-2-1-base,linalg,torch,1e-3,1e-3,default,None,True,False,False,"",""
|
||||
stabilityai/stable-diffusion-2-1-base,linalg,torch,1e-3,1e-3,default,None,True,False,False,"","macos"
|
||||
|
||||
|
@@ -75,7 +75,7 @@ if __name__ == "__main__":
|
||||
compiler_module,
|
||||
target_backends=[backend],
|
||||
extra_args=args,
|
||||
input_type="mhlo",
|
||||
input_type="auto",
|
||||
)
|
||||
# flatbuffer_blob = compile_str(compiler_module, target_backends=["dylib-llvm-aot"])
|
||||
|
||||
|
||||
@@ -153,7 +153,7 @@ if __name__ == "__main__":
|
||||
compiler_module,
|
||||
target_backends=[backend],
|
||||
extra_args=args,
|
||||
input_type="mhlo",
|
||||
input_type="auto",
|
||||
)
|
||||
|
||||
# Save module as MLIR file in a directory
|
||||
|
||||
@@ -96,7 +96,7 @@ if __name__ == "__main__":
|
||||
compiler_module,
|
||||
target_backends=[backend],
|
||||
extra_args=args,
|
||||
input_type="mhlo",
|
||||
input_type="auto",
|
||||
)
|
||||
# flatbuffer_blob = compile_str(compiler_module, target_backends=["dylib-llvm-aot"])
|
||||
|
||||
|
||||
@@ -91,7 +91,7 @@ if __name__ == "__main__":
|
||||
compiler_module,
|
||||
target_backends=[backend],
|
||||
extra_args=args,
|
||||
input_type="mhlo",
|
||||
input_type="auto",
|
||||
)
|
||||
# flatbuffer_blob = compile_str(compiler_module, target_backends=["dylib-llvm-aot"])
|
||||
|
||||
|
||||
159
tank/examples/opt/opt_causallm.py
Normal file
159
tank/examples/opt/opt_causallm.py
Normal file
@@ -0,0 +1,159 @@
|
||||
import unittest
|
||||
|
||||
import os
|
||||
import pytest
|
||||
import torch_mlir
|
||||
import torch
|
||||
import numpy as np
|
||||
from shark_hf_opt import OPTForCausalLM
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from tank.model_utils import compare_tensors
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
OPT_MODEL = "opt-350m"
|
||||
OPT_MODEL_66B = "facebook/opt-66b"
|
||||
MAX_SEQUENCE_LENGTH = 256
|
||||
MAX_NEW_TOKENS = 200
|
||||
|
||||
|
||||
def create_module(model_name, tokenizer, device):
|
||||
opt_model = OPTForCausalLM.from_pretrained(
|
||||
"facebook/" + model_name, return_dict=False
|
||||
)
|
||||
opt_model.eval()
|
||||
|
||||
encoded_inputs = tokenizer(
|
||||
"This is a sample input for generating the model.",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = (
|
||||
encoded_inputs["input_ids"],
|
||||
encoded_inputs["attention_mask"],
|
||||
)
|
||||
mlir_path = f"./{OPT_MODEL}_causallm_{MAX_SEQUENCE_LENGTH}_torch.mlir"
|
||||
if os.path.isfile(mlir_path):
|
||||
with open(mlir_path, "r") as f:
|
||||
model_mlir = f.read()
|
||||
print(f"Loaded .mlir from {mlir_path}")
|
||||
else:
|
||||
module = torch_mlir.compile(
|
||||
opt_model,
|
||||
inputs,
|
||||
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=True,
|
||||
)
|
||||
|
||||
model_mlir = module.operation.get_asm(
|
||||
large_elements_limit=None, enable_debug_info=True
|
||||
)
|
||||
|
||||
with open(mlir_path, "w") as f:
|
||||
f.write(model_mlir)
|
||||
print(f"Saved mlir at {mlir_path}")
|
||||
|
||||
func_name = "forward"
|
||||
act_out = opt_model(inputs[0], attention_mask=inputs[1], return_dict=False)
|
||||
|
||||
shark_module = SharkInference(
|
||||
model_mlir,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
is_benchmark=False,
|
||||
)
|
||||
vmfb_name = f"{OPT_MODEL}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{device}"
|
||||
shark_module.save_module(module_name=vmfb_name)
|
||||
shark_module.load_module(vmfb_name + ".vmfb")
|
||||
|
||||
results = shark_module("forward", inputs)
|
||||
print(
|
||||
"SHARK logits have shape: ",
|
||||
str(results[0].shape) + " : " + str(results[0]),
|
||||
)
|
||||
print(
|
||||
"PyTorch logits have shape: "
|
||||
+ str(act_out[0].shape)
|
||||
+ " : "
|
||||
+ str(act_out[0])
|
||||
)
|
||||
# exp_out = tokenizer.decode(act_out[0][0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
# shark_out = tokenizer.decode(results[0][0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
return shark_module
|
||||
|
||||
|
||||
def shouldStop(tokens):
|
||||
stop_ids = [50278, 50279, 50277, 0]
|
||||
for stop_id in stop_ids:
|
||||
if tokens[0][-1] == stop_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def generate_new_token(shark_model, tokenizer, new_text):
|
||||
model_inputs = tokenizer(
|
||||
new_text,
|
||||
padding="max_length",
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = (
|
||||
model_inputs["input_ids"],
|
||||
model_inputs["attention_mask"],
|
||||
)
|
||||
sum_attentionmask = torch.sum(model_inputs.attention_mask)
|
||||
output = shark_model("forward", inputs)
|
||||
output = torch.FloatTensor(output[0])
|
||||
next_toks = torch.topk(output, 1)
|
||||
stop_generation = False
|
||||
if shouldStop(next_toks.indices):
|
||||
stop_generation = True
|
||||
new_token = next_toks.indices[int(sum_attentionmask) - 1]
|
||||
detok = tokenizer.decode(
|
||||
new_token,
|
||||
skip_special_tokens=False,
|
||||
clean_up_tokenization_spaces=False,
|
||||
)
|
||||
ret_dict = {
|
||||
"new_token": new_token,
|
||||
"detok": detok,
|
||||
"stop_generation": stop_generation,
|
||||
}
|
||||
return ret_dict
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"facebook/" + OPT_MODEL, use_fast=False
|
||||
)
|
||||
vmfb_path = f"./{OPT_MODEL}_causallm_{MAX_SEQUENCE_LENGTH}_torch_cpu.vmfb"
|
||||
if os.path.isfile(vmfb_path):
|
||||
opt_shark_module = SharkInference(mlir_module=None, device="cpu")
|
||||
opt_shark_module.load_module(vmfb_path)
|
||||
else:
|
||||
opt_shark_module = create_module(OPT_MODEL, tokenizer, "cpu")
|
||||
while True:
|
||||
try:
|
||||
new_text = input("Give me a sentence to complete:")
|
||||
words_list = []
|
||||
|
||||
for i in range(MAX_NEW_TOKENS):
|
||||
generated_token_op = generate_new_token(
|
||||
opt_shark_module, tokenizer, new_text
|
||||
)
|
||||
detok = generated_token_op["detok"]
|
||||
stop_generation = generated_token_op["stop_generation"]
|
||||
if stop_generation:
|
||||
break
|
||||
print(detok, end="", flush=True)
|
||||
words_list.append(detok)
|
||||
if detok == "":
|
||||
break
|
||||
new_text = new_text + detok
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("Exiting program.")
|
||||
break
|
||||
211
tank/examples/opt/opt_causallm_torch_test.py
Normal file
211
tank/examples/opt/opt_causallm_torch_test.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import unittest
|
||||
|
||||
import os
|
||||
import pytest
|
||||
import torch_mlir
|
||||
import torch
|
||||
import numpy as np
|
||||
from shark_hf_opt import OPTForCausalLM
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from tank.model_utils import compare_tensors
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
OPT_MODEL = "facebook/opt-1.3B"
|
||||
OPT_MODEL_66B = "facebook/opt-66b"
|
||||
|
||||
|
||||
class OPTModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
benchmark=False,
|
||||
):
|
||||
self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device, model_name):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
|
||||
# config = OPTConfig()
|
||||
# opt_model = OPTModel(config)
|
||||
opt_model = OPTForCausalLM.from_pretrained(
|
||||
model_name, return_dict=False
|
||||
)
|
||||
opt_model.eval()
|
||||
|
||||
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||
input_ids, attention_mask = (
|
||||
inputs.data["input_ids"],
|
||||
inputs.data["attention_mask"],
|
||||
)
|
||||
np.save("opt_inputs.npy", input_ids.detach())
|
||||
mlir_path = "./OPT1-3b_causallm_torch.mlir"
|
||||
if os.path.isfile(mlir_path):
|
||||
with open(mlir_path, "r") as f:
|
||||
model_mlir = f.read()
|
||||
print(f"Loaded .mlir from {mlir_path}")
|
||||
else:
|
||||
module = torch_mlir.compile(
|
||||
opt_model,
|
||||
input_ids,
|
||||
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=True,
|
||||
)
|
||||
|
||||
model_mlir = module.operation.get_asm(
|
||||
large_elements_limit=None, enable_debug_info=True
|
||||
)
|
||||
|
||||
with open(mlir_path, "w") as f:
|
||||
f.write(model_mlir)
|
||||
print(f"Saved mlir at {mlir_path}")
|
||||
|
||||
func_name = "forward"
|
||||
act_out = opt_model(input_ids, return_dict=False)
|
||||
|
||||
# mlir_importer = SharkImporter(
|
||||
# model,
|
||||
# (input,),
|
||||
# frontend="torch",
|
||||
# )
|
||||
# minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
# is_dynamic=dynamic, tracing_required=True
|
||||
# )
|
||||
|
||||
shark_module = SharkInference(
|
||||
model_mlir,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
is_benchmark=self.benchmark,
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module("forward", (input_ids,))
|
||||
print(
|
||||
"SHARK logits have shape: ",
|
||||
str(results[0].shape) + " : " + str(results[0]),
|
||||
)
|
||||
print(
|
||||
"PyTorch logits have shape: "
|
||||
+ str(act_out[0].shape)
|
||||
+ " : "
|
||||
+ str(act_out[0])
|
||||
)
|
||||
# exp_out = tokenizer.decode(act_out[0][0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
# shark_out = tokenizer.decode(results[0][0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
assert compare_tensors(act_out[0].detach(), results[0])
|
||||
|
||||
if self.benchmark:
|
||||
shark_module.shark_runner.benchmark_all_csv(
|
||||
(input_ids, attention_mask),
|
||||
"opt",
|
||||
dynamic,
|
||||
device,
|
||||
"torch",
|
||||
)
|
||||
|
||||
|
||||
class OPTModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = OPTModuleTester(self)
|
||||
self.module_tester.save_mlir = False
|
||||
self.module_tester.save_vmfb = False
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
def test_1_3b_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device, OPT_MODEL)
|
||||
|
||||
def test_1_3b_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device, OPT_MODEL)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("cuda"), reason=device_driver_info("cuda")
|
||||
)
|
||||
def test_1_3b_static_cuda(self):
|
||||
dynamic = False
|
||||
device = "cuda"
|
||||
self.module_tester.create_and_check_module(dynamic, device, OPT_MODEL)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("cuda"), reason=device_driver_info("cuda")
|
||||
)
|
||||
def test_1_3b_dynamic_cuda(self):
|
||||
dynamic = True
|
||||
device = "cuda"
|
||||
self.module_tester.create_and_check_module(dynamic, device, OPT_MODEL)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_1_3b_static_vulkan(self):
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device, OPT_MODEL)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_1_3b_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device, OPT_MODEL)
|
||||
|
||||
# def test_66b_static_cpu(self):
|
||||
# dynamic = False
|
||||
# device = "cpu"
|
||||
# self.module_tester.create_and_check_module(
|
||||
# dynamic, device, OPT_MODEL_66B
|
||||
# )
|
||||
|
||||
# def test_66b_dynamic_cpu(self):
|
||||
# dynamic = True
|
||||
# device = "cpu"
|
||||
# self.module_tester.create_and_check_module(
|
||||
# dynamic, device, OPT_MODEL_66B
|
||||
# )
|
||||
|
||||
# @pytest.mark.skipif(
|
||||
# check_device_drivers("cuda"), reason=device_driver_info("cuda")
|
||||
# )
|
||||
# def test_66b_static_cuda(self):
|
||||
# dynamic = False
|
||||
# device = "cuda"
|
||||
# self.module_tester.create_and_check_module(
|
||||
# dynamic, device, OPT_MODEL_66B
|
||||
# )
|
||||
|
||||
# @pytest.mark.skipif(
|
||||
# check_device_drivers("cuda"), reason=device_driver_info("cuda")
|
||||
# )
|
||||
# def test_66b_dynamic_cuda(self):
|
||||
# dynamic = True
|
||||
# device = "cuda"
|
||||
# self.module_tester.create_and_check_module(
|
||||
# dynamic, device, OPT_MODEL_66B
|
||||
# )
|
||||
|
||||
# @pytest.mark.skipif(
|
||||
# check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
# )
|
||||
# def test_66b_static_vulkan(self):
|
||||
# dynamic = False
|
||||
# device = "vulkan"
|
||||
# self.module_tester.create_and_check_module(
|
||||
# dynamic, device, OPT_MODEL_66B
|
||||
# )
|
||||
|
||||
# @pytest.mark.skipif(
|
||||
# check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
# )
|
||||
# def test_66b_dynamic_vulkan(self):
|
||||
# dynamic = True
|
||||
# device = "vulkan"
|
||||
# self.module_tester.create_and_check_module(
|
||||
# dynamic, device, OPT_MODEL_66B
|
||||
# )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -2,11 +2,11 @@ import unittest
|
||||
|
||||
import pytest
|
||||
import torch_mlir
|
||||
from hacked_hf_opt import OPTModel
|
||||
from shark_hf_opt import OPTModel
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from tank.model_utils import compare_tensors
|
||||
from transformers import GPT2Tokenizer
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
OPT_MODEL = "facebook/opt-350m"
|
||||
OPT_MODEL_66B = "facebook/opt-66b"
|
||||
@@ -20,7 +20,7 @@ class OPTModuleTester:
|
||||
self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device, model_name):
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
|
||||
# config = OPTConfig()
|
||||
# opt_model = OPTModel(config)
|
||||
opt_model = OPTModel.from_pretrained(model_name)
|
||||
@@ -56,13 +56,12 @@ class OPTModuleTester:
|
||||
|
||||
shark_module = SharkInference(
|
||||
model_mlir,
|
||||
func_name,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
is_benchmark=self.benchmark,
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input_ids, attention_mask))
|
||||
results = shark_module("forward", (input_ids, attention_mask))
|
||||
assert compare_tensors(act_out, results)
|
||||
|
||||
if self.benchmark:
|
||||
|
||||
@@ -279,7 +279,6 @@ class OPTAttention(nn.Module):
|
||||
)
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (
|
||||
bsz * self.num_heads,
|
||||
tgt_len,
|
||||
@@ -832,7 +831,10 @@ class OPTForCausalLM(OPTPreTrainedModel):
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
if isinstance(outputs[1:], tuple):
|
||||
output = (logits,) + outputs[1:]
|
||||
else:
|
||||
output = (logits, outputs[1:])
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
@@ -23,3 +23,5 @@ bert-base-uncased_fp16,True,fp16,False,linalg,False,109M,"nlp;bert-variant;trans
|
||||
bert-large-uncased,True,hf,True,linalg,False,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
|
||||
bert-base-uncased,True,hf,False,stablehlo,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
gpt2,True,hf_causallm,False,stablehlo,True,125M,"nlp;transformer-encoder","-"
|
||||
facebook/opt-125m,True,hf,False,stablehlo,True,125M,"nlp;transformer-encoder","-"
|
||||
distilgpt2,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
|
||||
|
@@ -1,3 +1,3 @@
|
||||
{
|
||||
"version": "2023-03-31_02d52bb"
|
||||
"version": "nightly"
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user