mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-11 14:58:11 -05:00
Compare commits
30 Commits
monorimet-
...
fp16cpu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
489a858af1 | ||
|
|
b83d32fafe | ||
|
|
0a618e1863 | ||
|
|
a731eb6ed4 | ||
|
|
2004d16945 | ||
|
|
6e409bfb77 | ||
|
|
77727d149c | ||
|
|
66f6e79d68 | ||
|
|
3b825579a7 | ||
|
|
9f0a421764 | ||
|
|
c28682110c | ||
|
|
caf6cc5d8f | ||
|
|
8614a18474 | ||
|
|
86c1c0c215 | ||
|
|
8bb364bcb8 | ||
|
|
7abddd01ec | ||
|
|
2a451fa0c7 | ||
|
|
9c4610b9da | ||
|
|
a38cc9d216 | ||
|
|
1c382449ec | ||
|
|
7cc9b3f8e8 | ||
|
|
e54517e967 | ||
|
|
326327a799 | ||
|
|
785b65c7b0 | ||
|
|
0d16c81687 | ||
|
|
8dd7850c69 | ||
|
|
e930ba85b4 | ||
|
|
cd732e7a38 | ||
|
|
8e0f8b3227 | ||
|
|
b8210ef796 |
@@ -20,7 +20,7 @@ import gc
|
||||
from pathlib import Path
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
# Brevitas
|
||||
@@ -256,6 +256,11 @@ class H2OGPTSHARKModel(torch.nn.Module):
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
del module
|
||||
|
||||
bytecode = save_mlir(
|
||||
bytecode,
|
||||
model_name=f"h2ogpt_{precision}",
|
||||
frontend="torch",
|
||||
)
|
||||
return bytecode
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
|
||||
442
apps/language_models/scripts/llama_ir_conversion_utils.py
Normal file
442
apps/language_models/scripts/llama_ir_conversion_utils.py
Normal file
@@ -0,0 +1,442 @@
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
from argparse import RawTextHelpFormatter
|
||||
import re, gc
|
||||
|
||||
"""
|
||||
This script can be used as a standalone utility to convert IRs to dynamic + combine them.
|
||||
Following are the various ways this script can be used :-
|
||||
a. To convert a single Linalg IR to dynamic IR:
|
||||
--dynamic --first_ir_path=<PATH TO FIRST IR>
|
||||
b. To convert two Linalg IRs to dynamic IR:
|
||||
--dynamic --first_ir_path=<PATH TO SECOND IR> --first_ir_path=<PATH TO SECOND IR>
|
||||
c. To combine two Linalg IRs into one:
|
||||
--combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>
|
||||
d. To convert both IRs into dynamic as well as combine the IRs:
|
||||
--dynamic --combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>
|
||||
|
||||
NOTE: For dynamic you'll also need to provide the following set of flags:-
|
||||
i. For First Llama : --dynamic_input_size (DEFAULT: 19)
|
||||
ii. For Second Llama: --model_name (DEFAULT: llama2_7b)
|
||||
--precision (DEFAULT: 'int4')
|
||||
You may use --save_dynamic to also save the dynamic IR in option d above.
|
||||
Else for option a. and b. the dynamic IR(s) will get saved by default.
|
||||
"""
|
||||
|
||||
|
||||
def combine_mlir_scripts(
|
||||
first_vicuna_mlir,
|
||||
second_vicuna_mlir,
|
||||
output_name,
|
||||
return_ir=True,
|
||||
):
|
||||
print(f"[DEBUG] combining first and second mlir")
|
||||
print(f"[DEBUG] output_name = {output_name}")
|
||||
maps1 = []
|
||||
maps2 = []
|
||||
constants = set()
|
||||
f1 = []
|
||||
f2 = []
|
||||
|
||||
print(f"[DEBUG] processing first vicuna mlir")
|
||||
first_vicuna_mlir = first_vicuna_mlir.splitlines()
|
||||
while first_vicuna_mlir:
|
||||
line = first_vicuna_mlir.pop(0)
|
||||
if re.search("#map\d*\s*=", line):
|
||||
maps1.append(line)
|
||||
elif re.search("arith.constant", line):
|
||||
constants.add(line)
|
||||
elif not re.search("module", line):
|
||||
line = re.sub("forward", "first_vicuna_forward", line)
|
||||
f1.append(line)
|
||||
f1 = f1[:-1]
|
||||
del first_vicuna_mlir
|
||||
gc.collect()
|
||||
|
||||
for i, map_line in enumerate(maps1):
|
||||
map_var = map_line.split(" ")[0]
|
||||
map_line = re.sub(f"{map_var}(?!\d)", map_var + "_0", map_line)
|
||||
maps1[i] = map_line
|
||||
f1 = [
|
||||
re.sub(f"{map_var}(?!\d)", map_var + "_0", func_line)
|
||||
for func_line in f1
|
||||
]
|
||||
|
||||
print(f"[DEBUG] processing second vicuna mlir")
|
||||
second_vicuna_mlir = second_vicuna_mlir.splitlines()
|
||||
while second_vicuna_mlir:
|
||||
line = second_vicuna_mlir.pop(0)
|
||||
if re.search("#map\d*\s*=", line):
|
||||
maps2.append(line)
|
||||
elif "global_seed" in line:
|
||||
continue
|
||||
elif re.search("arith.constant", line):
|
||||
constants.add(line)
|
||||
elif not re.search("module", line):
|
||||
line = re.sub("forward", "second_vicuna_forward", line)
|
||||
f2.append(line)
|
||||
f2 = f2[:-1]
|
||||
del second_vicuna_mlir
|
||||
gc.collect()
|
||||
|
||||
for i, map_line in enumerate(maps2):
|
||||
map_var = map_line.split(" ")[0]
|
||||
map_line = re.sub(f"{map_var}(?!\d)", map_var + "_1", map_line)
|
||||
maps2[i] = map_line
|
||||
f2 = [
|
||||
re.sub(f"{map_var}(?!\d)", map_var + "_1", func_line)
|
||||
for func_line in f2
|
||||
]
|
||||
|
||||
module_start = 'module attributes {torch.debug_module_name = "_lambda"} {'
|
||||
module_end = "}"
|
||||
|
||||
global_vars = []
|
||||
vnames = []
|
||||
global_var_loading1 = []
|
||||
global_var_loading2 = []
|
||||
|
||||
print(f"[DEBUG] processing constants")
|
||||
counter = 0
|
||||
constants = list(constants)
|
||||
while constants:
|
||||
constant = constants.pop(0)
|
||||
vname, vbody = constant.split("=")
|
||||
vname = re.sub("%", "", vname)
|
||||
vname = vname.strip()
|
||||
vbody = re.sub("arith.constant", "", vbody)
|
||||
vbody = vbody.strip()
|
||||
if len(vbody.split(":")) < 2:
|
||||
print(constant)
|
||||
vdtype = vbody.split(":")[-1].strip()
|
||||
fixed_vdtype = vdtype
|
||||
if "c1_i64" in vname:
|
||||
print(constant)
|
||||
counter += 1
|
||||
if counter == 2:
|
||||
counter = 0
|
||||
print("detected duplicate")
|
||||
continue
|
||||
vnames.append(vname)
|
||||
if "true" not in vname:
|
||||
global_vars.append(
|
||||
f"ml_program.global private @{vname}({vbody}) : {fixed_vdtype}"
|
||||
)
|
||||
global_var_loading1.append(
|
||||
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
|
||||
)
|
||||
global_var_loading2.append(
|
||||
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
|
||||
)
|
||||
else:
|
||||
global_vars.append(
|
||||
f"ml_program.global private @{vname}({vbody}) : i1"
|
||||
)
|
||||
global_var_loading1.append(
|
||||
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
|
||||
)
|
||||
global_var_loading2.append(
|
||||
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
|
||||
)
|
||||
|
||||
new_f1, new_f2 = [], []
|
||||
|
||||
print(f"[DEBUG] processing f1")
|
||||
for line in f1:
|
||||
if "func.func" in line:
|
||||
new_f1.append(line)
|
||||
for global_var in global_var_loading1:
|
||||
new_f1.append(global_var)
|
||||
else:
|
||||
new_f1.append(line)
|
||||
|
||||
print(f"[DEBUG] processing f2")
|
||||
for line in f2:
|
||||
if "func.func" in line:
|
||||
new_f2.append(line)
|
||||
for global_var in global_var_loading2:
|
||||
if (
|
||||
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
|
||||
in global_var
|
||||
):
|
||||
print(global_var)
|
||||
new_f2.append(global_var)
|
||||
else:
|
||||
new_f2.append(line)
|
||||
|
||||
f1 = new_f1
|
||||
f2 = new_f2
|
||||
|
||||
del new_f1
|
||||
del new_f2
|
||||
gc.collect()
|
||||
|
||||
print(
|
||||
[
|
||||
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" in x
|
||||
for x in [maps1, maps2, global_vars, f1, f2]
|
||||
]
|
||||
)
|
||||
|
||||
# doing it this way rather than assembling the whole string
|
||||
# to prevent OOM with 64GiB RAM when encoding the file.
|
||||
|
||||
print(f"[DEBUG] Saving mlir to {output_name}")
|
||||
with open(output_name, "w+") as f_:
|
||||
f_.writelines(line + "\n" for line in maps1)
|
||||
f_.writelines(line + "\n" for line in maps2)
|
||||
f_.writelines(line + "\n" for line in [module_start])
|
||||
f_.writelines(line + "\n" for line in global_vars)
|
||||
f_.writelines(line + "\n" for line in f1)
|
||||
f_.writelines(line + "\n" for line in f2)
|
||||
f_.writelines(line + "\n" for line in [module_end])
|
||||
|
||||
del maps1
|
||||
del maps2
|
||||
del module_start
|
||||
del global_vars
|
||||
del f1
|
||||
del f2
|
||||
del module_end
|
||||
gc.collect()
|
||||
|
||||
if return_ir:
|
||||
print(f"[DEBUG] Reading combined mlir back in")
|
||||
with open(output_name, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def write_in_dynamic_inputs0(module, dynamic_input_size):
|
||||
print("[DEBUG] writing dynamic inputs to first vicuna")
|
||||
# Current solution for ensuring mlir files support dynamic inputs
|
||||
# TODO: find a more elegant way to implement this
|
||||
new_lines = []
|
||||
module = module.splitlines()
|
||||
while module:
|
||||
line = module.pop(0)
|
||||
line = re.sub(f"{dynamic_input_size}x", "?x", line)
|
||||
if "?x" in line:
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
|
||||
line = re.sub(f" {dynamic_input_size},", " %dim,", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub(f"c{dynamic_input_size}", "dim", line)
|
||||
if "%0 = tensor.empty(%dim) : tensor<?xi64>" in line:
|
||||
new_lines.append("%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>")
|
||||
if "%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>" in line:
|
||||
continue
|
||||
|
||||
new_lines.append(line)
|
||||
return "\n".join(new_lines)
|
||||
|
||||
|
||||
def write_in_dynamic_inputs1(module, model_name, precision):
|
||||
print("[DEBUG] writing dynamic inputs to second vicuna")
|
||||
|
||||
def remove_constant_dim(line):
|
||||
if "c19_i64" in line:
|
||||
line = re.sub("c19_i64", "dim_i64", line)
|
||||
if "19x" in line:
|
||||
line = re.sub("19x", "?x", line)
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim\)",
|
||||
"tensor.empty(%dim, %dim)",
|
||||
line,
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub("c19", "dim", line)
|
||||
if " 19," in line:
|
||||
line = re.sub(" 19,", " %dim,", line)
|
||||
if "x20x" in line or "<20x" in line:
|
||||
line = re.sub("20x", "?x", line)
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dimp1)", line)
|
||||
if " 20," in line:
|
||||
line = re.sub(" 20,", " %dimp1,", line)
|
||||
return line
|
||||
|
||||
module = module.splitlines()
|
||||
new_lines = []
|
||||
|
||||
# Using a while loop and the pop method to avoid creating a copy of module
|
||||
if "llama2_13b" in model_name:
|
||||
pkv_tensor_shape = "tensor<1x40x?x128x"
|
||||
elif "llama2_70b" in model_name:
|
||||
pkv_tensor_shape = "tensor<1x8x?x128x"
|
||||
else:
|
||||
pkv_tensor_shape = "tensor<1x32x?x128x"
|
||||
if precision in ["fp16", "int4", "int8"]:
|
||||
pkv_tensor_shape += "f16>"
|
||||
else:
|
||||
pkv_tensor_shape += "f32>"
|
||||
|
||||
while module:
|
||||
line = module.pop(0)
|
||||
if "%c19_i64 = arith.constant 19 : i64" in line:
|
||||
new_lines.append("%c2 = arith.constant 2 : index")
|
||||
new_lines.append(
|
||||
f"%dim_4_int = tensor.dim %arg1, %c2 : {pkv_tensor_shape}"
|
||||
)
|
||||
new_lines.append(
|
||||
"%dim_i64 = arith.index_cast %dim_4_int : index to i64"
|
||||
)
|
||||
continue
|
||||
if "%c2 = arith.constant 2 : index" in line:
|
||||
continue
|
||||
if "%c20_i64 = arith.constant 20 : i64" in line:
|
||||
new_lines.append("%c1_i64 = arith.constant 1 : i64")
|
||||
new_lines.append("%c20_i64 = arith.addi %dim_i64, %c1_i64 : i64")
|
||||
new_lines.append(
|
||||
"%dimp1 = arith.index_cast %c20_i64 : i64 to index"
|
||||
)
|
||||
continue
|
||||
line = remove_constant_dim(line)
|
||||
new_lines.append(line)
|
||||
|
||||
return "\n".join(new_lines)
|
||||
|
||||
|
||||
def save_dynamic_ir(ir_to_save, output_file):
|
||||
if not ir_to_save:
|
||||
return
|
||||
# We only get string output from the dynamic conversion utility.
|
||||
from contextlib import redirect_stdout
|
||||
|
||||
with open(output_file, "w") as f:
|
||||
with redirect_stdout(f):
|
||||
print(ir_to_save)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="llama ir utility",
|
||||
description="\tThis script can be used as a standalone utility to convert IRs to dynamic + combine them.\n"
|
||||
+ "\tFollowing are the various ways this script can be used :-\n"
|
||||
+ "\t\ta. To convert a single Linalg IR to dynamic IR:\n"
|
||||
+ "\t\t\t--dynamic --first_ir_path=<PATH TO FIRST IR>\n"
|
||||
+ "\t\tb. To convert two Linalg IRs to dynamic IR:\n"
|
||||
+ "\t\t\t--dynamic --first_ir_path=<PATH TO SECOND IR> --first_ir_path=<PATH TO SECOND IR>\n"
|
||||
+ "\t\tc. To combine two Linalg IRs into one:\n"
|
||||
+ "\t\t\t--combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>\n"
|
||||
+ "\t\td. To convert both IRs into dynamic as well as combine the IRs:\n"
|
||||
+ "\t\t\t--dynamic --combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>\n\n"
|
||||
+ "\tNOTE: For dynamic you'll also need to provide the following set of flags:-\n"
|
||||
+ "\t\t i. For First Llama : --dynamic_input_size (DEFAULT: 19)\n"
|
||||
+ "\t\tii. For Second Llama: --model_name (DEFAULT: llama2_7b)\n"
|
||||
+ "\t\t\t--precision (DEFAULT: 'int4')\n"
|
||||
+ "\t You may use --save_dynamic to also save the dynamic IR in option d above.\n"
|
||||
+ "\t Else for option a. and b. the dynamic IR(s) will get saved by default.\n",
|
||||
formatter_class=RawTextHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
"-p",
|
||||
default="int4",
|
||||
choices=["fp32", "fp16", "int8", "int4"],
|
||||
help="Precision of the concerned IR",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
default="llama2_7b",
|
||||
choices=["vicuna", "llama2_7b", "llama2_13b", "llama2_70b"],
|
||||
help="Specify which model to run.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--first_ir_path",
|
||||
default=None,
|
||||
help="path to first llama mlir file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--second_ir_path",
|
||||
default=None,
|
||||
help="path to second llama mlir file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dynamic_input_size",
|
||||
type=int,
|
||||
default=19,
|
||||
help="Specify the static input size to replace with dynamic dim.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dynamic",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Converts the IR(s) to dynamic",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_dynamic",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Save the individual IR(s) after converting to dynamic",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--combine",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Converts the IR(s) to dynamic",
|
||||
)
|
||||
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
dynamic = args.dynamic
|
||||
combine = args.combine
|
||||
assert (
|
||||
dynamic or combine
|
||||
), "neither `dynamic` nor `combine` flag is turned on"
|
||||
first_ir_path = args.first_ir_path
|
||||
second_ir_path = args.second_ir_path
|
||||
assert first_ir_path or second_ir_path, "no input ir has been provided"
|
||||
if combine:
|
||||
assert (
|
||||
first_ir_path and second_ir_path
|
||||
), "you will need to provide both IRs to combine"
|
||||
precision = args.precision
|
||||
model_name = args.model_name
|
||||
dynamic_input_size = args.dynamic_input_size
|
||||
save_dynamic = args.save_dynamic
|
||||
|
||||
print(f"Dynamic conversion utility is turned {'ON' if dynamic else 'OFF'}")
|
||||
print(f"Combining IR utility is turned {'ON' if combine else 'OFF'}")
|
||||
|
||||
if dynamic and not combine:
|
||||
save_dynamic = True
|
||||
|
||||
first_ir = None
|
||||
first_dynamic_ir_name = None
|
||||
second_ir = None
|
||||
second_dynamic_ir_name = None
|
||||
if first_ir_path:
|
||||
first_dynamic_ir_name = f"{Path(first_ir_path).stem}_dynamic"
|
||||
with open(first_ir_path, "r") as f:
|
||||
first_ir = f.read()
|
||||
if second_ir_path:
|
||||
second_dynamic_ir_name = f"{Path(second_ir_path).stem}_dynamic"
|
||||
with open(second_ir_path, "r") as f:
|
||||
second_ir = f.read()
|
||||
if dynamic:
|
||||
first_ir = (
|
||||
write_in_dynamic_inputs0(first_ir, dynamic_input_size)
|
||||
if first_ir
|
||||
else None
|
||||
)
|
||||
second_ir = (
|
||||
write_in_dynamic_inputs1(second_ir, model_name, precision)
|
||||
if second_ir
|
||||
else None
|
||||
)
|
||||
if save_dynamic:
|
||||
save_dynamic_ir(first_ir, f"{first_dynamic_ir_name}.mlir")
|
||||
save_dynamic_ir(second_ir, f"{second_dynamic_ir_name}.mlir")
|
||||
|
||||
if combine:
|
||||
combine_mlir_scripts(
|
||||
first_ir,
|
||||
second_ir,
|
||||
f"{model_name}_{precision}.mlir",
|
||||
return_ir=False,
|
||||
)
|
||||
@@ -4,10 +4,12 @@ import re
|
||||
import gc
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from statistics import mean, stdev
|
||||
from tqdm import tqdm
|
||||
from typing import List, Tuple
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
@@ -42,12 +44,18 @@ from apps.language_models.src.model_wrappers.vicuna_model import (
|
||||
SecondVicuna13B,
|
||||
SecondVicuna70B,
|
||||
)
|
||||
from apps.language_models.src.model_wrappers.vicuna_model_gpu import (
|
||||
FirstVicunaGPU,
|
||||
SecondVicuna7BGPU,
|
||||
SecondVicuna13BGPU,
|
||||
SecondVicuna70BGPU,
|
||||
)
|
||||
from apps.language_models.utils import (
|
||||
get_vmfb_from_path,
|
||||
)
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import get_f16_inputs
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
|
||||
@@ -102,7 +110,7 @@ parser.add_argument(
|
||||
"--download_vmfb",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="download vmfb from sharktank, system dependent, YMMV",
|
||||
help="Download vmfb from sharktank, system dependent, YMMV",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
@@ -130,6 +138,38 @@ parser.add_argument(
|
||||
help="Specify target triple for vulkan.",
|
||||
)
|
||||
|
||||
# Microbenchmarking options.
|
||||
parser.add_argument(
|
||||
"--enable_microbenchmark",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Enables the microbenchmarking mode (non-interactive). Uses the system and the user prompt from args.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--microbenchmark_iterations",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Number of microbenchmark iterations. Default: 5.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--microbenchmark_num_tokens",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Generate an exact number of output tokens. Default: 512.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--system_prompt",
|
||||
type=str,
|
||||
default="",
|
||||
help="Specify the system prompt. This is only used with `--enable_microbenchmark`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--user_prompt",
|
||||
type=str,
|
||||
default="Hi",
|
||||
help="Specify the user prompt. This is only used with `--enable_microbenchmark`",
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
|
||||
if len(lhs) == 3 and len(rhs) == 2:
|
||||
@@ -398,7 +438,7 @@ class VicunaBase(SharkLLMBase):
|
||||
is_first=is_first,
|
||||
)
|
||||
else:
|
||||
token = token.to(torch.int64).reshape([1, 1])
|
||||
token = torch.tensor(token).reshape([1, 1])
|
||||
second_input = (token,) + tuple(past_key_values)
|
||||
output = self.shark_model(
|
||||
"second_vicuna_forward", second_input, send_to_host=False
|
||||
@@ -408,6 +448,9 @@ class VicunaBase(SharkLLMBase):
|
||||
_logits = output["logits"]
|
||||
_past_key_values = output["past_key_values"]
|
||||
_token = int(torch.argmax(_logits[:, -1, :], dim=1)[0])
|
||||
elif "cpu" in self.device:
|
||||
_past_key_values = output[1:]
|
||||
_token = int(output[0].to_host())
|
||||
else:
|
||||
_logits = torch.tensor(output[0].to_host())
|
||||
_past_key_values = output[1:]
|
||||
@@ -417,9 +460,10 @@ class VicunaBase(SharkLLMBase):
|
||||
ret_dict = {
|
||||
"token": _token,
|
||||
"detok": _detok,
|
||||
"logits": _logits,
|
||||
"past_key_values": _past_key_values,
|
||||
}
|
||||
if "cpu" not in self.device:
|
||||
ret_dict["logits"] = _logits
|
||||
|
||||
if cli:
|
||||
print(f" token : {_token} | detok : {_detok}")
|
||||
@@ -640,9 +684,7 @@ class ShardedVicuna(VicunaBase):
|
||||
mlir_path = Path(f"lmhead.mlir")
|
||||
vmfb_path = Path(f"lmhead.vmfb")
|
||||
if mlir_path.exists():
|
||||
f_ = open(mlir_path, "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
print(f"Found bytecode module at {mlir_path}.")
|
||||
else:
|
||||
hidden_states = torch_mlir.TensorPlaceholder.like(
|
||||
hidden_states, dynamic_axes=[1]
|
||||
@@ -667,12 +709,10 @@ class ShardedVicuna(VicunaBase):
|
||||
filepath.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
f_ = open(f"lmhead.mlir", "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
mlir_path = filepath
|
||||
|
||||
shark_module = SharkInference(
|
||||
bytecode,
|
||||
mlir_path,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=device_idx,
|
||||
@@ -692,9 +732,7 @@ class ShardedVicuna(VicunaBase):
|
||||
mlir_path = Path(f"norm.mlir")
|
||||
vmfb_path = Path(f"norm.vmfb")
|
||||
if mlir_path.exists():
|
||||
f_ = open(mlir_path, "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
print(f"Found bytecode module at {mlir_path}.")
|
||||
else:
|
||||
hidden_states = torch_mlir.TensorPlaceholder.like(
|
||||
hidden_states, dynamic_axes=[1]
|
||||
@@ -713,12 +751,10 @@ class ShardedVicuna(VicunaBase):
|
||||
filepath.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
f_ = open(f"norm.mlir", "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
mlir_path = filepath
|
||||
|
||||
shark_module = SharkInference(
|
||||
bytecode,
|
||||
mlir_path,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=device_idx,
|
||||
@@ -738,9 +774,7 @@ class ShardedVicuna(VicunaBase):
|
||||
mlir_path = Path(f"embedding.mlir")
|
||||
vmfb_path = Path(f"embedding.vmfb")
|
||||
if mlir_path.exists():
|
||||
f_ = open(mlir_path, "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
print(f"Found bytecode module at {mlir_path}.")
|
||||
else:
|
||||
input_ids = torch_mlir.TensorPlaceholder.like(
|
||||
input_ids, dynamic_axes=[1]
|
||||
@@ -764,12 +798,10 @@ class ShardedVicuna(VicunaBase):
|
||||
filepath.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
f_ = open(f"embedding.mlir", "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
mlir_path = filepath
|
||||
|
||||
shark_module = SharkInference(
|
||||
bytecode,
|
||||
mlir_path,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=device_idx,
|
||||
@@ -1219,6 +1251,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
|
||||
hf_auth_token: str = None,
|
||||
max_num_tokens=512,
|
||||
min_num_tokens=0,
|
||||
device="cpu",
|
||||
vulkan_target_triple="",
|
||||
precision="int8",
|
||||
@@ -1248,8 +1281,9 @@ class UnshardedVicuna(VicunaBase):
|
||||
self.hf_model_path = "meta-llama/Llama-2-70b-chat-hf"
|
||||
print(f"[DEBUG] hf model name: {self.hf_model_path}")
|
||||
self.max_sequence_length = 256
|
||||
self.min_num_tokens = min_num_tokens
|
||||
self.device = device
|
||||
self.vulkan_target_triple = vulkan_target_triple.replace("-","_")
|
||||
self.vulkan_target_triple = vulkan_target_triple
|
||||
self.device_id = device_id
|
||||
self.precision = precision
|
||||
self.download_vmfb = download_vmfb
|
||||
@@ -1271,7 +1305,12 @@ class UnshardedVicuna(VicunaBase):
|
||||
safe_device = self.device.split("-")[0]
|
||||
if suffix in ["mlirbc", "mlir"]:
|
||||
return Path(f"{self.model_name}_{self.precision}.{suffix}")
|
||||
target_triple = "" if self.vulkan_target_triple=="" else f"_{self.vulkan_target_triple}"
|
||||
|
||||
target_triple = ""
|
||||
if self.vulkan_target_triple != "":
|
||||
target_triple = "_"
|
||||
target_triple += "_".join(self.vulkan_target_triple.split("-")[:-1])
|
||||
|
||||
return Path(
|
||||
f"{self.model_name}_{self.precision}_{safe_device}{target_triple}.{suffix}"
|
||||
)
|
||||
@@ -1368,8 +1407,8 @@ class UnshardedVicuna(VicunaBase):
|
||||
elif "llama2_70b" in self.model_name:
|
||||
pkv_tensor_shape = "tensor<1x8x?x128x"
|
||||
else:
|
||||
pkv_tensor_shape = "tensor<1x32x?x128x"
|
||||
if self.precision in ["fp16", "int4", "int8"]:
|
||||
pkv_tensor_shape = "tensor<1x?x32x128x"
|
||||
if self.device!="cpu:" : #precision in ["fp16", "int4", "int8"]:
|
||||
pkv_tensor_shape += "f16>"
|
||||
else:
|
||||
pkv_tensor_shape += "f32>"
|
||||
@@ -1377,9 +1416,9 @@ class UnshardedVicuna(VicunaBase):
|
||||
while module:
|
||||
line = module.pop(0)
|
||||
if "%c19_i64 = arith.constant 19 : i64" in line:
|
||||
new_lines.append("%c2 = arith.constant 2 : index")
|
||||
new_lines.append("%c2 = arith.constant 1 : index")
|
||||
new_lines.append(
|
||||
f"%dim_4_int = tensor.dim %arg1, %c2 : {pkv_tensor_shape}"
|
||||
f"%dim_4_int = tensor.dim %arg1, %c1 : {pkv_tensor_shape}"
|
||||
)
|
||||
new_lines.append(
|
||||
"%dim_i64 = arith.index_cast %dim_4_int : index to i64"
|
||||
@@ -1420,10 +1459,12 @@ class UnshardedVicuna(VicunaBase):
|
||||
print(f"[DEBUG] vmfb found at {self.vicuna_vmfb_path.absolute()}")
|
||||
return
|
||||
|
||||
print(f"[DEBUG] vmfb not found")
|
||||
print(f"[DEBUG] vmfb not found (search path: {self.vicuna_vmfb_path})")
|
||||
mlir_generated = False
|
||||
for suffix in ["mlirbc", "mlir"]:
|
||||
self.vicuna_mlir_path = self.get_model_path(suffix)
|
||||
if "cpu" in self.device and "llama2_7b" in self.vicuna_mlir_path.name:
|
||||
self.vicuna_mlir_path = Path("llama2_7b_int4_f32.mlir")
|
||||
if not self.vicuna_mlir_path.exists() and self.load_mlir_from_shark_tank:
|
||||
print(
|
||||
f"Looking into gs://shark_tank/{self.model_name}/unsharded/mlir/{self.vicuna_mlir_path.name}"
|
||||
@@ -1435,18 +1476,14 @@ class UnshardedVicuna(VicunaBase):
|
||||
)
|
||||
if self.vicuna_mlir_path.exists():
|
||||
print(f"[DEBUG] mlir found at {self.vicuna_mlir_path.absolute()}")
|
||||
with open(self.vicuna_mlir_path, "rb") as f:
|
||||
combined_module = f.read()
|
||||
combined_module = self.vicuna_mlir_path.absolute()
|
||||
mlir_generated = True
|
||||
break
|
||||
|
||||
print(self.device)
|
||||
print(self.device=="cpu")
|
||||
if not mlir_generated:
|
||||
print(f"[DEBUG] mlir not found")
|
||||
# Disabling this path of IR generation for now as it is broken.
|
||||
print("Please check if the mlir file is present at the shark tank. Exiting.")
|
||||
self.shark_model = None
|
||||
sys.exit()
|
||||
return
|
||||
|
||||
print("[DEBUG] generating mlir on device")
|
||||
# Select a compilation prompt such that the resulting input_ids
|
||||
@@ -1468,15 +1505,26 @@ class UnshardedVicuna(VicunaBase):
|
||||
compilation_input_ids
|
||||
).reshape([1, 19])
|
||||
firstVicunaCompileInput = (compilation_input_ids,)
|
||||
model = FirstVicuna(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
if "cpu" in self.device:
|
||||
model = FirstVicuna(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
"fp32",
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
else:
|
||||
model = FirstVicunaGPU(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
"fp16",
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
is_f16 = self.precision in ["fp16", "int4"]
|
||||
is_f16 = self.device!="cpu"
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
firstVicunaCompileInput,
|
||||
@@ -1506,6 +1554,9 @@ class UnshardedVicuna(VicunaBase):
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
if self.cache_vicunas:
|
||||
with open(first_model_path[:-5]+"_torch.mlir", "w+") as f:
|
||||
f.write(str(first_module))
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
first_module,
|
||||
@@ -1556,36 +1607,68 @@ class UnshardedVicuna(VicunaBase):
|
||||
dim1 = 32
|
||||
total_tuple = 64
|
||||
pkv = tuple(
|
||||
(torch.zeros([1, dim1, 19, 128], dtype=torch.float32))
|
||||
(torch.zeros([1, 19, dim1, 128], dtype=torch.float32))
|
||||
for _ in range(total_tuple)
|
||||
)
|
||||
secondVicunaCompileInput = (compilation_input_ids,) + pkv
|
||||
if self.model_name == "llama2_13b":
|
||||
model = SecondVicuna13B(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
elif self.model_name == "llama2_70b":
|
||||
model = SecondVicuna70B(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
if "cpu" in self.device:
|
||||
if self.model_name == "llama2_13b":
|
||||
model = SecondVicuna13B(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
"fp32",
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
elif self.model_name == "llama2_70b":
|
||||
model = SecondVicuna70B(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
"fp32",
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
else:
|
||||
model = SecondVicuna7B(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
"fp32",
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
else:
|
||||
model = SecondVicuna7B(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
if self.model_name == "llama2_13b":
|
||||
model = SecondVicuna13BGPU(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
"fp16",
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
elif self.model_name == "llama2_70b":
|
||||
model = SecondVicuna70BGPU(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
"fp16",
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
else:
|
||||
model = SecondVicuna7BGPU(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
"fp16",
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
is_f16 = self.precision in ["fp16", "int4"]
|
||||
is_f16 = self.device!="cpu"
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
secondVicunaCompileInput,
|
||||
@@ -1595,7 +1678,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
mlir_type="torchscript",
|
||||
)
|
||||
del model
|
||||
if self.precision in ["fp16", "int4"]:
|
||||
if self.device != "cpu":
|
||||
secondVicunaCompileInput = get_f16_inputs(
|
||||
secondVicunaCompileInput,
|
||||
True,
|
||||
@@ -1605,7 +1688,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
for i in range(len(secondVicunaCompileInput)):
|
||||
if i != 0:
|
||||
secondVicunaCompileInput[i] = torch_mlir.TensorPlaceholder.like(
|
||||
secondVicunaCompileInput[i], dynamic_axes=[2]
|
||||
secondVicunaCompileInput[i], dynamic_axes=[1]
|
||||
)
|
||||
secondVicunaCompileInput = tuple(secondVicunaCompileInput)
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
@@ -1620,6 +1703,9 @@ class UnshardedVicuna(VicunaBase):
|
||||
verbose=False,
|
||||
)
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
if self.cache_vicunas:
|
||||
with open(second_model_path[:-5]+"_torch.mlir", "w+") as f:
|
||||
f.write(str(second_module))
|
||||
run_pipeline_with_repro_report(
|
||||
second_module,
|
||||
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
@@ -1653,6 +1739,12 @@ class UnshardedVicuna(VicunaBase):
|
||||
second_module,
|
||||
self.vicuna_mlir_path,
|
||||
)
|
||||
combined_module = save_mlir(
|
||||
combined_module,
|
||||
model_name="self.vicuna_mlir_path",
|
||||
mlir_dialect="tm_tensor",
|
||||
dir=str(os.getcwd()),
|
||||
)
|
||||
del first_module, second_module
|
||||
|
||||
print(self.device)
|
||||
@@ -1696,39 +1788,46 @@ class UnshardedVicuna(VicunaBase):
|
||||
res_tokens = []
|
||||
params = {"prompt": prompt, "is_first": True, "fv": self.shark_model}
|
||||
|
||||
prefill_st_time = time.time()
|
||||
generated_token_op = self.generate_new_token(
|
||||
params=params, sharded=False, cli=cli
|
||||
)
|
||||
prefill_time = time.time() - prefill_st_time
|
||||
|
||||
token = generated_token_op["token"]
|
||||
logits = generated_token_op["logits"]
|
||||
if "cpu" not in self.device:
|
||||
logits = generated_token_op["logits"]
|
||||
pkv = generated_token_op["past_key_values"]
|
||||
detok = generated_token_op["detok"]
|
||||
yield detok, ""
|
||||
yield detok, None, prefill_time
|
||||
|
||||
res_tokens.append(token)
|
||||
if cli:
|
||||
print(f"Assistant: {detok}", end=" ", flush=True)
|
||||
|
||||
for _ in range(self.max_num_tokens - 2):
|
||||
for idx in range(self.max_num_tokens):
|
||||
params = {
|
||||
"token": token,
|
||||
"is_first": False,
|
||||
"logits": logits,
|
||||
"past_key_values": pkv,
|
||||
"sv": self.shark_model,
|
||||
}
|
||||
if "cpu" not in self.device:
|
||||
params["logits"] = logits
|
||||
|
||||
decode_st_time = time.time()
|
||||
generated_token_op = self.generate_new_token(
|
||||
params=params, sharded=False, cli=cli
|
||||
)
|
||||
decode_time_ms = (time.time() - decode_st_time)*1000
|
||||
|
||||
token = generated_token_op["token"]
|
||||
logits = generated_token_op["logits"]
|
||||
if "cpu" not in self.device:
|
||||
logits = generated_token_op["logits"]
|
||||
pkv = generated_token_op["past_key_values"]
|
||||
detok = generated_token_op["detok"]
|
||||
|
||||
if token == 2:
|
||||
if token == 2 and idx >= self.min_num_tokens:
|
||||
break
|
||||
res_tokens.append(token)
|
||||
if detok == "<0x0A>":
|
||||
@@ -1737,10 +1836,10 @@ class UnshardedVicuna(VicunaBase):
|
||||
else:
|
||||
if cli:
|
||||
print(f"{detok}", end=" ", flush=True)
|
||||
yield detok, ""
|
||||
yield detok, None, decode_time_ms
|
||||
|
||||
res_str = self.decode_tokens(res_tokens)
|
||||
yield res_str, "formatted"
|
||||
yield res_str, "formatted", None
|
||||
|
||||
def autocomplete(self, prompt):
|
||||
# use First vic alone to complete a story / prompt / sentence.
|
||||
@@ -1836,7 +1935,7 @@ if __name__ == "__main__":
|
||||
device_id = id
|
||||
break
|
||||
id += 1
|
||||
|
||||
|
||||
assert device_id, f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
|
||||
# Step 2. Add a few flags targetting specific hardwares.
|
||||
if "rdna" in vulkan_target_triple:
|
||||
@@ -1844,7 +1943,7 @@ if __name__ == "__main__":
|
||||
"--iree-spirv-index-bits=64",
|
||||
]
|
||||
_extra_args = _extra_args + flags_to_add
|
||||
|
||||
|
||||
|
||||
vic = None
|
||||
if not args.sharded:
|
||||
@@ -1858,9 +1957,16 @@ if __name__ == "__main__":
|
||||
if args.vicuna_vmfb_path is None
|
||||
else Path(args.vicuna_vmfb_path)
|
||||
)
|
||||
min_tokens = 0
|
||||
max_tokens = 512
|
||||
if args.enable_microbenchmark:
|
||||
min_tokens = max_tokens = args.microbenchmark_num_tokens
|
||||
|
||||
vic = UnshardedVicuna(
|
||||
model_name=args.model_name,
|
||||
hf_auth_token=args.hf_auth_token,
|
||||
max_num_tokens=max_tokens,
|
||||
min_num_tokens=min_tokens,
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
vicuna_mlir_path=vic_mlir_path,
|
||||
@@ -1887,17 +1993,6 @@ if __name__ == "__main__":
|
||||
weight_group_size=args.weight_group_size,
|
||||
extra_args_cmd=_extra_args,
|
||||
)
|
||||
if args.model_name == "vicuna":
|
||||
system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
|
||||
else:
|
||||
system_message = """System: You are a helpful, respectful and honest assistant. Always answer "
|
||||
as helpfully as possible, while being safe. Your answers should not
|
||||
include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal
|
||||
content. Please ensure that your responses are socially unbiased and positive
|
||||
in nature. If a question does not make any sense, or is not factually coherent,
|
||||
explain why instead of answering something not correct. If you don't know the
|
||||
answer to a question, please don't share false information."""
|
||||
prologue_prompt = "ASSISTANT:\n"
|
||||
|
||||
history = []
|
||||
|
||||
@@ -1907,12 +2002,55 @@ if __name__ == "__main__":
|
||||
"llama2_13b": "llama2_13b=>meta-llama/Llama-2-13b-chat-hf",
|
||||
"llama2_70b": "llama2_70b=>meta-llama/Llama-2-70b-chat-hf",
|
||||
}
|
||||
|
||||
iteration = 0
|
||||
|
||||
prefill_times = []
|
||||
avg_decode_speed = []
|
||||
|
||||
while True:
|
||||
# TODO: Add break condition from user input
|
||||
user_prompt = input("User: ")
|
||||
history.append([user_prompt, ""])
|
||||
prompt = create_prompt(args.model_name, history)
|
||||
for text, msg in vic.generate(prompt, cli=True):
|
||||
if "formatted" in msg:
|
||||
print("Response:", text)
|
||||
iteration += 1
|
||||
if not args.enable_microbenchmark:
|
||||
user_prompt = input("User prompt: ")
|
||||
history.append([user_prompt, ""])
|
||||
prompt = create_prompt(args.model_name, history)
|
||||
else:
|
||||
if iteration > args.microbenchmark_iterations:
|
||||
break
|
||||
user_prompt = args.user_prompt
|
||||
prompt = args.system_prompt + user_prompt
|
||||
history = [[user_prompt, ""]]
|
||||
|
||||
token_count = 0
|
||||
total_time_ms = 0.001 # In order to avoid divide by zero error
|
||||
prefill_time = 0
|
||||
is_first = True
|
||||
for text, msg, exec_time in vic.generate(prompt, cli=True):
|
||||
if msg is None:
|
||||
if is_first:
|
||||
prefill_time = exec_time
|
||||
is_first = False
|
||||
else:
|
||||
total_time_ms += exec_time
|
||||
token_count += 1
|
||||
elif "formatted" in msg:
|
||||
history[-1][1] = text
|
||||
tokens_per_sec = (token_count / total_time_ms) * 1000
|
||||
prefill_times.append(prefill_time)
|
||||
avg_decode_speed.append(tokens_per_sec)
|
||||
|
||||
print("\nResponse:", text.strip())
|
||||
print(f"\nNum tokens: {token_count}")
|
||||
print(f"Prefill: {prefill_time:.2f} seconds")
|
||||
print(f"Decode: {tokens_per_sec:.2f} tokens/s")
|
||||
else:
|
||||
sys.exit(
|
||||
"unexpected message from the vicuna generate call, exiting."
|
||||
)
|
||||
|
||||
if args.enable_microbenchmark:
|
||||
print("\n### Final Statistics ###")
|
||||
print("Number of iterations:", iteration - 1)
|
||||
print(f"Prefill: avg. {mean(prefill_times):.2f} s, stdev {stdev(prefill_times):.2f}")
|
||||
print(f"Decode: avg. {mean(avg_decode_speed):.2f} tokens/s, stdev {stdev(avg_decode_speed):.2f}")
|
||||
|
||||
@@ -54,7 +54,6 @@ from apps.language_models.utils import (
|
||||
)
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import get_f16_inputs
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
|
||||
@@ -7,6 +7,7 @@ class FirstVicuna(torch.nn.Module):
|
||||
self,
|
||||
model_path,
|
||||
precision="fp32",
|
||||
accumulates="fp32",
|
||||
weight_group_size=128,
|
||||
model_name="vicuna",
|
||||
hf_auth_token: str = None,
|
||||
@@ -15,6 +16,9 @@ class FirstVicuna(torch.nn.Module):
|
||||
kwargs = {"torch_dtype": torch.float32}
|
||||
if "llama2" in model_name:
|
||||
kwargs["use_auth_token"] = hf_auth_token
|
||||
self.accumulates = (
|
||||
torch.float32 if accumulates == "fp32" else torch.float16
|
||||
)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
@@ -29,7 +33,7 @@ class FirstVicuna(torch.nn.Module):
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
get_model_impl(self.model).layers,
|
||||
dtype=torch.float16 if precision == "int4" else torch.float32,
|
||||
dtype=self.accumulates,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
@@ -43,11 +47,13 @@ class FirstVicuna(torch.nn.Module):
|
||||
def forward(self, input_ids):
|
||||
op = self.model(input_ids=input_ids, use_cache=True)
|
||||
return_vals = []
|
||||
return_vals.append(op.logits)
|
||||
token = torch.argmax(op.logits[:, -1, :], dim=1)
|
||||
return_vals.append(token)
|
||||
|
||||
temp_past_key_values = op.past_key_values
|
||||
for item in temp_past_key_values:
|
||||
return_vals.append(item[0])
|
||||
return_vals.append(item[1])
|
||||
return_vals.append(item[0].transpose(1,2))
|
||||
return_vals.append(item[1].transpose(1,2))
|
||||
return tuple(return_vals)
|
||||
|
||||
|
||||
@@ -56,6 +62,7 @@ class SecondVicuna7B(torch.nn.Module):
|
||||
self,
|
||||
model_path,
|
||||
precision="fp32",
|
||||
accumulates="fp32",
|
||||
weight_group_size=128,
|
||||
model_name="vicuna",
|
||||
hf_auth_token: str = None,
|
||||
@@ -67,6 +74,9 @@ class SecondVicuna7B(torch.nn.Module):
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
self.accumulates = (
|
||||
torch.float32 if accumulates == "fp32" else torch.float16
|
||||
)
|
||||
print(f"[DEBUG] model_path : {model_path}")
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
@@ -78,7 +88,7 @@ class SecondVicuna7B(torch.nn.Module):
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
get_model_impl(self.model).layers,
|
||||
dtype=torch.float16 if precision == "int4" else torch.float32,
|
||||
dtype=self.accumulates,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
@@ -285,15 +295,19 @@ class SecondVicuna7B(torch.nn.Module):
|
||||
i64,
|
||||
),
|
||||
)
|
||||
|
||||
past_key_values = [(x[0].transpose(1,2), x[0].transpose(1,2)) for x in past_key_values]
|
||||
past_key_values = tuple(past_key_values)
|
||||
op = self.model(
|
||||
input_ids=token, use_cache=True, past_key_values=past_key_values
|
||||
)
|
||||
return_vals = []
|
||||
return_vals.append(op.logits)
|
||||
token = torch.argmax(op.logits[:, -1, :], dim=1)
|
||||
return_vals.append(token)
|
||||
temp_past_key_values = op.past_key_values
|
||||
for item in temp_past_key_values:
|
||||
return_vals.append(item[0])
|
||||
return_vals.append(item[1])
|
||||
return_vals.append(item[0].transpose(1,2))
|
||||
return_vals.append(item[1].transpose(1,2))
|
||||
return tuple(return_vals)
|
||||
|
||||
|
||||
@@ -302,6 +316,7 @@ class SecondVicuna13B(torch.nn.Module):
|
||||
self,
|
||||
model_path,
|
||||
precision="int8",
|
||||
accumulates="fp32",
|
||||
weight_group_size=128,
|
||||
model_name="vicuna",
|
||||
hf_auth_token: str = None,
|
||||
@@ -313,6 +328,9 @@ class SecondVicuna13B(torch.nn.Module):
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
self.accumulates = (
|
||||
torch.float32 if accumulates == "fp32" else torch.float16
|
||||
)
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import (
|
||||
@@ -323,7 +341,7 @@ class SecondVicuna13B(torch.nn.Module):
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
get_model_impl(self.model).layers,
|
||||
dtype=torch.float16 if precision == "int4" else torch.float32,
|
||||
dtype=self.accumulates,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
@@ -595,6 +613,7 @@ class SecondVicuna70B(torch.nn.Module):
|
||||
self,
|
||||
model_path,
|
||||
precision="fp32",
|
||||
accumulates="fp32",
|
||||
weight_group_size=128,
|
||||
model_name="vicuna",
|
||||
hf_auth_token: str = None,
|
||||
@@ -606,6 +625,9 @@ class SecondVicuna70B(torch.nn.Module):
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
self.accumulates = (
|
||||
torch.float32 if accumulates == "fp32" else torch.float16
|
||||
)
|
||||
print(f"[DEBUG] model_path : {model_path}")
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
@@ -617,7 +639,7 @@ class SecondVicuna70B(torch.nn.Module):
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
get_model_impl(self.model).layers,
|
||||
dtype=torch.float16,
|
||||
dtype=self.accumulates,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
|
||||
1165
apps/language_models/src/model_wrappers/vicuna_model_gpu.py
Normal file
1165
apps/language_models/src/model_wrappers/vicuna_model_gpu.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -7,9 +7,9 @@ from io import BytesIO
|
||||
from pathlib import Path
|
||||
from contextlib import redirect_stdout
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
from shark.shark_inference import SharkInference
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, GPTQConfig
|
||||
from transformers.generation import (
|
||||
GenerationConfig,
|
||||
LogitsProcessorList,
|
||||
@@ -32,7 +32,7 @@ parser.add_argument(
|
||||
"--falcon_variant_to_use", default="7b", help="7b, 40b, 180b"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision", "-p", default="fp16", help="fp32, fp16, int8, int4"
|
||||
"--precision", "-p", default="fp16", choices=["fp32", "fp16", "int4"]
|
||||
)
|
||||
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
|
||||
parser.add_argument(
|
||||
@@ -119,9 +119,16 @@ class Falcon(SharkLLMBase):
|
||||
"trust_remote_code": True,
|
||||
"token": self.hf_auth_token,
|
||||
}
|
||||
if self.precision == "int4":
|
||||
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
|
||||
kwargs["quantization_config"] = quantization_config
|
||||
kwargs["load_gptq_on_cpu"] = True
|
||||
kwargs["device_map"] = "cpu" if self.device == "cpu" else "cuda:0"
|
||||
falcon_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, **kwargs
|
||||
)
|
||||
if self.precision == "int4":
|
||||
falcon_model = falcon_model.to(torch.float32)
|
||||
return falcon_model
|
||||
|
||||
def compile(self):
|
||||
@@ -173,8 +180,6 @@ class Falcon(SharkLLMBase):
|
||||
print(
|
||||
f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}"
|
||||
)
|
||||
with open(self.falcon_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
mlir_generated = True
|
||||
|
||||
if not mlir_generated:
|
||||
@@ -195,9 +200,10 @@ class Falcon(SharkLLMBase):
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
falconCompileInput,
|
||||
is_f16=self.precision == "fp16",
|
||||
is_f16=self.precision in ["fp16", "int4"],
|
||||
f16_input_mask=[False, False],
|
||||
mlir_type="torchscript",
|
||||
is_gptq=self.precision == "int4",
|
||||
)
|
||||
del model
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
@@ -221,9 +227,12 @@ class Falcon(SharkLLMBase):
|
||||
f_.write(bytecode)
|
||||
print("Saved falcon mlir at ", str(self.falcon_mlir_path))
|
||||
f_.close()
|
||||
del bytecode
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=self.device, mlir_dialect="linalg"
|
||||
mlir_module=self.falcon_mlir_path,
|
||||
device=self.device,
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
self.falcon_vmfb_path.parent.absolute(),
|
||||
@@ -232,7 +241,12 @@ class Falcon(SharkLLMBase):
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
],
|
||||
]
|
||||
+ [
|
||||
"--iree-llvmcpu-use-fast-min-max-ops",
|
||||
]
|
||||
if self.precision == "int4"
|
||||
else [],
|
||||
debug=self.debug,
|
||||
)
|
||||
print("Saved falcon vmfb at ", str(path))
|
||||
@@ -409,7 +423,7 @@ class Falcon(SharkLLMBase):
|
||||
(model_inputs["input_ids"], model_inputs["attention_mask"]),
|
||||
)
|
||||
)
|
||||
if self.precision == "fp16":
|
||||
if self.precision in ["fp16", "int4"]:
|
||||
outputs = outputs.to(dtype=torch.float32)
|
||||
next_token_logits = outputs
|
||||
|
||||
@@ -488,12 +502,22 @@ if __name__ == "__main__":
|
||||
else Path(args.falcon_vmfb_path)
|
||||
)
|
||||
|
||||
if args.falcon_variant_to_use == "180b":
|
||||
hf_model_path_value = "tiiuae/falcon-180B-chat"
|
||||
if args.precision == "int4":
|
||||
if args.falcon_variant_to_use == "180b":
|
||||
hf_model_path_value = "TheBloke/Falcon-180B-Chat-GPTQ"
|
||||
else:
|
||||
hf_model_path_value = (
|
||||
"TheBloke/falcon-"
|
||||
+ args.falcon_variant_to_use
|
||||
+ "-instruct-GPTQ"
|
||||
)
|
||||
else:
|
||||
hf_model_path_value = (
|
||||
"tiiuae/falcon-" + args.falcon_variant_to_use + "-instruct"
|
||||
)
|
||||
if args.falcon_variant_to_use == "180b":
|
||||
hf_model_path_value = "tiiuae/falcon-180B-chat"
|
||||
else:
|
||||
hf_model_path_value = (
|
||||
"tiiuae/falcon-" + args.falcon_variant_to_use + "-instruct"
|
||||
)
|
||||
|
||||
falcon = Falcon(
|
||||
model_name="falcon_" + args.falcon_variant_to_use,
|
||||
@@ -524,7 +548,11 @@ if __name__ == "__main__":
|
||||
prompt = input("Please enter the prompt text: ")
|
||||
print("\nPrompt Text: ", prompt)
|
||||
|
||||
res_str = falcon.generate(prompt)
|
||||
prompt_template = f"""A helpful assistant who helps the user with any questions asked.
|
||||
User: {prompt}
|
||||
Assistant:"""
|
||||
|
||||
res_str = falcon.generate(prompt_template)
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
print(
|
||||
|
||||
@@ -126,7 +126,7 @@ def is_url(input_url):
|
||||
import os
|
||||
import tempfile
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
import torch
|
||||
import torch_mlir
|
||||
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
||||
@@ -235,6 +235,12 @@ def compile_int_precision(
|
||||
mlir_module = BytesIO(mlir_module)
|
||||
bytecode = mlir_module.read()
|
||||
print(f"Elided IR written for {extended_model_name}")
|
||||
bytecode = save_mlir(
|
||||
bytecode,
|
||||
model_name=extended_model_name,
|
||||
frontend="torch",
|
||||
dir=os.getcwd(),
|
||||
)
|
||||
return bytecode
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
|
||||
|
||||
@@ -74,6 +74,9 @@ datas += [
|
||||
# hidden imports for pyinstaller
|
||||
hiddenimports = ["shark", "shark.shark_inference", "apps"]
|
||||
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
|
||||
hiddenimports += [
|
||||
x for x in collect_submodules("diffusers") if "tests" not in x
|
||||
]
|
||||
blacklist = ["tests", "convert"]
|
||||
hiddenimports += [
|
||||
x
|
||||
|
||||
@@ -710,8 +710,11 @@ class SharkifyStableDiffusionModel:
|
||||
return self.text_encoder(input)[0]
|
||||
|
||||
clip_model = CLIPText(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
save_dir = os.path.join(self.sharktank_dir, self.model_name["clip"])
|
||||
save_dir = ""
|
||||
if self.debug:
|
||||
save_dir = os.path.join(
|
||||
self.sharktank_dir, self.model_name["clip"]
|
||||
)
|
||||
os.makedirs(
|
||||
save_dir,
|
||||
exist_ok=True,
|
||||
|
||||
@@ -84,9 +84,6 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
# Disable bindings fusion to work with moltenVK.
|
||||
if sys.platform == "darwin":
|
||||
iree_flags.append("-iree-stream-fuse-binding=false")
|
||||
|
||||
def _import(self):
|
||||
scaling_model = ScalingModel()
|
||||
|
||||
@@ -458,6 +458,14 @@ p.add_argument(
|
||||
help="Specify your own huggingface authentication tokens for models like Llama2.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--device_allocator_heap_key",
|
||||
type=str,
|
||||
default="",
|
||||
help="Specify heap key for device caching allocator."
|
||||
"Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count"
|
||||
"Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)",
|
||||
)
|
||||
##############################################################################
|
||||
# IREE - Vulkan supported flags
|
||||
##############################################################################
|
||||
|
||||
@@ -18,7 +18,7 @@ import tempfile
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
set_iree_vulkan_runtime_flags,
|
||||
get_vulkan_target_triple,
|
||||
@@ -154,8 +154,8 @@ def compile_through_fx(
|
||||
f16_input_mask=f16_input_mask,
|
||||
debug=debug,
|
||||
model_name=extended_model_name,
|
||||
save_dir=save_dir,
|
||||
)
|
||||
|
||||
if use_tuned:
|
||||
if "vae" in extended_model_name.split("_")[0]:
|
||||
args.annotation_model = "vae"
|
||||
@@ -168,6 +168,14 @@ def compile_through_fx(
|
||||
mlir_module, extended_model_name, base_model_id
|
||||
)
|
||||
|
||||
if not os.path.isdir(save_dir):
|
||||
save_dir = ""
|
||||
|
||||
mlir_module = save_mlir(
|
||||
mlir_module,
|
||||
model_name=extended_model_name,
|
||||
dir=save_dir,
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
device=args.device if device is None else device,
|
||||
@@ -179,17 +187,22 @@ def compile_through_fx(
|
||||
mlir_module,
|
||||
)
|
||||
|
||||
del mlir_module
|
||||
gc.collect()
|
||||
|
||||
|
||||
def set_iree_runtime_flags():
|
||||
# TODO: This function should be device-agnostic and piped properly
|
||||
# to general runtime driver init.
|
||||
vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
|
||||
if args.enable_rgp:
|
||||
vulkan_runtime_flags += [
|
||||
f"--enable_rgp=true",
|
||||
f"--vulkan_debug_utils=true",
|
||||
]
|
||||
if args.device_allocator_heap_key:
|
||||
vulkan_runtime_flags += [
|
||||
f"--device_allocator=caching:device_local={args.device_allocator_heap_key}",
|
||||
]
|
||||
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
|
||||
|
||||
|
||||
@@ -522,10 +535,6 @@ def get_opt_flags(model, precision="fp16"):
|
||||
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
|
||||
)
|
||||
|
||||
# Disable bindings fusion to work with moltenVK.
|
||||
if sys.platform == "darwin":
|
||||
iree_flags.append("-iree-stream-fuse-binding=false")
|
||||
|
||||
if "default_compilation_flags" in opt_flags[model][is_tuned][precision]:
|
||||
iree_flags += opt_flags[model][is_tuned][precision][
|
||||
"default_compilation_flags"
|
||||
|
||||
@@ -156,9 +156,9 @@ if __name__ == "__main__":
|
||||
upscaler_sendto_img2img,
|
||||
upscaler_sendto_inpaint,
|
||||
upscaler_sendto_outpaint,
|
||||
lora_train_web,
|
||||
model_web,
|
||||
model_config_web,
|
||||
# lora_train_web,
|
||||
# model_web,
|
||||
# model_config_web,
|
||||
hf_models,
|
||||
modelmanager_sendto_txt2img,
|
||||
modelmanager_sendto_img2img,
|
||||
@@ -250,16 +250,16 @@ if __name__ == "__main__":
|
||||
upscaler_status,
|
||||
]
|
||||
)
|
||||
with gr.TabItem(label="Model Manager", id=6):
|
||||
model_web.render()
|
||||
with gr.TabItem(label="LoRA Training (Experimental)", id=7):
|
||||
lora_train_web.render()
|
||||
with gr.TabItem(label="Chat Bot (Experimental)", id=8):
|
||||
# with gr.TabItem(label="Model Manager", id=6):
|
||||
# model_web.render()
|
||||
# with gr.TabItem(label="LoRA Training (Experimental)", id=7):
|
||||
# lora_train_web.render()
|
||||
with gr.TabItem(label="Chat Bot", id=8):
|
||||
stablelm_chat.render()
|
||||
with gr.TabItem(
|
||||
label="Generate Sharding Config (Experimental)", id=9
|
||||
):
|
||||
model_config_web.render()
|
||||
# with gr.TabItem(
|
||||
# label="Generate Sharding Config (Experimental)", id=9
|
||||
# ):
|
||||
# model_config_web.render()
|
||||
with gr.TabItem(label="MultiModal (Experimental)", id=10):
|
||||
minigpt4_web.render()
|
||||
# with gr.TabItem(label="DocuChat Upload", id=11):
|
||||
|
||||
@@ -212,6 +212,7 @@ with gr.Blocks(title="DocuChat") as h2ogpt_web:
|
||||
else "Only CUDA Supported for now",
|
||||
choices=supported_devices,
|
||||
interactive=enabled,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
|
||||
@@ -396,6 +396,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_models,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
img2img_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
@@ -421,6 +422,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
if args.custom_vae
|
||||
else "None",
|
||||
choices=["None"] + get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
@@ -452,6 +454,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
label="Stencil model",
|
||||
value="None",
|
||||
choices=["None", "canny", "openpose", "scribble"],
|
||||
allow_custom_value=True,
|
||||
)
|
||||
|
||||
def show_canvas(choice):
|
||||
@@ -512,6 +515,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
).replace("\\", "\n\\")
|
||||
i2i_lora_info = f"LoRA Path: {i2i_lora_info}"
|
||||
lora_weights = gr.Dropdown(
|
||||
allow_custom_value=True,
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=i2i_lora_info,
|
||||
elem_id="lora_weights",
|
||||
@@ -535,6 +539,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
label="Scheduler",
|
||||
value="EulerDiscrete",
|
||||
choices=scheduler_list_cpu_only,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -590,6 +595,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
"Cubic",
|
||||
],
|
||||
label="Resample Type",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
ondemand = gr.Checkbox(
|
||||
value=args.ondemand,
|
||||
@@ -648,6 +654,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
|
||||
@@ -344,6 +344,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
custom_checkpoint_type="inpainting"
|
||||
)
|
||||
+ predefined_paint_models,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
inpaint_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
@@ -369,6 +370,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
if args.custom_vae
|
||||
else "None",
|
||||
choices=["None"] + get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
@@ -406,6 +408,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
@@ -424,6 +427,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
label="Scheduler",
|
||||
value="EulerDiscrete",
|
||||
choices=scheduler_list_cpu_only,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -527,6 +531,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
|
||||
@@ -50,6 +50,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_models,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
@@ -73,6 +74,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
@@ -105,6 +107,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
label="Scheduler",
|
||||
value=args.scheduler,
|
||||
choices=scheduler_list,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
@@ -177,6 +180,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
|
||||
@@ -143,6 +143,7 @@ with gr.Blocks() as minigpt4_web:
|
||||
# else "Only CUDA Supported for now",
|
||||
choices=["cuda"],
|
||||
interactive=False,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
|
||||
with gr.Column():
|
||||
|
||||
@@ -98,6 +98,7 @@ with gr.Blocks() as model_web:
|
||||
choices=None,
|
||||
value=None,
|
||||
visible=False,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
# TODO: select and SendTo
|
||||
civit_models = gr.Gallery(
|
||||
|
||||
@@ -351,6 +351,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
custom_checkpoint_type="inpainting"
|
||||
)
|
||||
+ predefined_paint_models,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
outpaint_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
@@ -376,6 +377,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
if args.custom_vae
|
||||
else "None",
|
||||
choices=["None"] + get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
@@ -411,6 +413,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
@@ -429,6 +432,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
label="Scheduler",
|
||||
value="EulerDiscrete",
|
||||
choices=scheduler_list_cpu_only,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -555,6 +559,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
|
||||
@@ -109,6 +109,7 @@ with gr.Blocks() as outputgallery_web:
|
||||
value="",
|
||||
interactive=True,
|
||||
elem_classes="dropdown_no_container",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Column(
|
||||
scale=1,
|
||||
|
||||
@@ -8,7 +8,7 @@ from transformers import (
|
||||
from apps.stable_diffusion.web.ui.utils import available_devices
|
||||
from datetime import datetime as dt
|
||||
import json
|
||||
import time
|
||||
import sys
|
||||
|
||||
|
||||
def user(message, history):
|
||||
@@ -66,8 +66,10 @@ start_message = {
|
||||
}
|
||||
|
||||
|
||||
def create_prompt(model_name, history):
|
||||
system_message = start_message[model_name]
|
||||
def create_prompt(model_name, history, prompt_prefix):
|
||||
system_message = ""
|
||||
if prompt_prefix:
|
||||
system_message = start_message[model_name]
|
||||
|
||||
if "llama2" in model_name:
|
||||
B_INST, E_INST = "[INST]", "[/INST]"
|
||||
@@ -129,7 +131,7 @@ model_vmfb_key = ""
|
||||
|
||||
# TODO: Make chat reusable for UI and API
|
||||
def chat(
|
||||
curr_system_message,
|
||||
prompt_prefix,
|
||||
history,
|
||||
model,
|
||||
device,
|
||||
@@ -163,8 +165,8 @@ def chat(
|
||||
from apps.language_models.scripts.vicuna import UnshardedVicuna
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{precision}"
|
||||
if new_model_vmfb_key != model_vmfb_key:
|
||||
new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{device_id}#{precision}#{download_vmfb}"
|
||||
if vicuna_model is None or new_model_vmfb_key != model_vmfb_key:
|
||||
model_vmfb_key = new_model_vmfb_key
|
||||
max_toks = 128 if model_name == "codegen" else 512
|
||||
|
||||
@@ -237,25 +239,38 @@ def chat(
|
||||
device_id=device_id,
|
||||
)
|
||||
|
||||
prompt = create_prompt(model_name, history)
|
||||
if vicuna_model is None:
|
||||
sys.exit("Unable to instantiate the model object, exiting.")
|
||||
|
||||
prompt = create_prompt(model_name, history, prompt_prefix)
|
||||
|
||||
partial_text = ""
|
||||
count = 0
|
||||
start_time = time.time()
|
||||
for text, msg in progress.tqdm(
|
||||
token_count = 0
|
||||
total_time_ms = 0.001 # In order to avoid divide by zero error
|
||||
prefill_time = 0
|
||||
is_first = True
|
||||
for text, msg, exec_time in progress.tqdm(
|
||||
vicuna_model.generate(prompt, cli=cli),
|
||||
desc="generating response",
|
||||
):
|
||||
count += 1
|
||||
if "formatted" in msg:
|
||||
history[-1][1] = text
|
||||
end_time = time.time()
|
||||
tokens_per_sec = count / (end_time - start_time)
|
||||
yield history, str(format(tokens_per_sec, ".2f")) + " tokens/sec"
|
||||
else:
|
||||
if msg is None:
|
||||
if is_first:
|
||||
prefill_time = exec_time
|
||||
is_first = False
|
||||
else:
|
||||
total_time_ms += exec_time
|
||||
token_count += 1
|
||||
partial_text += text + " "
|
||||
history[-1][1] = partial_text
|
||||
yield history, ""
|
||||
yield history, f"Prefill: {prefill_time:.2f}"
|
||||
elif "formatted" in msg:
|
||||
history[-1][1] = text
|
||||
tokens_per_sec = (token_count / total_time_ms) * 1000
|
||||
yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec"
|
||||
else:
|
||||
sys.exit(
|
||||
"unexpected message from the vicuna generate call, exiting."
|
||||
)
|
||||
|
||||
return history, ""
|
||||
|
||||
@@ -380,6 +395,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
label="Select Model",
|
||||
value=model_choices[0],
|
||||
choices=model_choices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
supported_devices = available_devices
|
||||
enabled = len(supported_devices) > 0
|
||||
@@ -393,25 +409,31 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
else "Only CUDA Supported for now",
|
||||
choices=supported_devices,
|
||||
interactive=enabled,
|
||||
allow_custom_value=True,
|
||||
# multiselect=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value="int8",
|
||||
value="int4",
|
||||
choices=[
|
||||
"int4",
|
||||
"int8",
|
||||
"fp16",
|
||||
],
|
||||
visible=True,
|
||||
visible=False,
|
||||
)
|
||||
tokens_time = gr.Textbox(label="Tokens generated per second")
|
||||
with gr.Column():
|
||||
download_vmfb = gr.Checkbox(
|
||||
label="Download vmfb from Shark tank if available",
|
||||
value=True,
|
||||
interactive=True,
|
||||
)
|
||||
tokens_time = gr.Textbox(label="Tokens generated per second")
|
||||
prompt_prefix = gr.Checkbox(
|
||||
label="Add System Prompt",
|
||||
value=False,
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Row(visible=False):
|
||||
with gr.Group():
|
||||
@@ -438,16 +460,17 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
submit = gr.Button("Submit", interactive=enabled)
|
||||
stop = gr.Button("Stop", interactive=enabled)
|
||||
clear = gr.Button("Clear", interactive=enabled)
|
||||
system_msg = gr.Textbox(
|
||||
start_message, label="System Message", interactive=False, visible=False
|
||||
)
|
||||
|
||||
submit_event = msg.submit(
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
fn=user,
|
||||
inputs=[msg, chatbot],
|
||||
outputs=[msg, chatbot],
|
||||
show_progress=False,
|
||||
queue=False,
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[
|
||||
system_msg,
|
||||
prompt_prefix,
|
||||
chatbot,
|
||||
model,
|
||||
device,
|
||||
@@ -456,14 +479,19 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
config_file,
|
||||
],
|
||||
outputs=[chatbot, tokens_time],
|
||||
show_progress=False,
|
||||
queue=True,
|
||||
)
|
||||
submit_click_event = submit.click(
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
fn=user,
|
||||
inputs=[msg, chatbot],
|
||||
outputs=[msg, chatbot],
|
||||
show_progress=False,
|
||||
queue=False,
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[
|
||||
system_msg,
|
||||
prompt_prefix,
|
||||
chatbot,
|
||||
model,
|
||||
device,
|
||||
@@ -472,6 +500,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
config_file,
|
||||
],
|
||||
outputs=[chatbot, tokens_time],
|
||||
show_progress=False,
|
||||
queue=True,
|
||||
)
|
||||
stop.click(
|
||||
|
||||
@@ -406,6 +406,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_models,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
txt2img_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
@@ -430,6 +431,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
else "None",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Column(scale=1, min_width=170):
|
||||
txt2img_png_info_img = gr.Image(
|
||||
@@ -466,6 +468,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
@@ -484,6 +487,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
label="Scheduler",
|
||||
value=args.scheduler,
|
||||
choices=scheduler_list,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Column():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -568,6 +572,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
"Cubic",
|
||||
],
|
||||
label="Resample Type",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
hiresfix_height = gr.Slider(
|
||||
384,
|
||||
@@ -624,6 +629,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Accordion(label="Prompt Examples!", open=False):
|
||||
ex = gr.Examples(
|
||||
|
||||
@@ -365,6 +365,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
custom_checkpoint_type="upscaler"
|
||||
)
|
||||
+ predefined_upscaler_models,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
upscaler_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
@@ -390,6 +391,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
if args.custom_vae
|
||||
else "None",
|
||||
choices=["None"] + get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
@@ -425,6 +427,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
@@ -443,6 +446,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
label="Scheduler",
|
||||
value="DDIM",
|
||||
choices=scheduler_list_cpu_only,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -547,6 +551,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
|
||||
@@ -8,19 +8,8 @@ torchvision
|
||||
tqdm
|
||||
|
||||
#iree-compiler | iree-runtime should already be installed
|
||||
#these dont work ok osx
|
||||
#iree-tools-tflite
|
||||
#iree-tools-xla
|
||||
#iree-tools-tf
|
||||
|
||||
# TensorFlow and JAX.
|
||||
gin-config
|
||||
tensorflow-macos
|
||||
tensorflow-metal
|
||||
#tf-models-nightly
|
||||
#tensorflow-text-nightly
|
||||
transformers
|
||||
tensorflow-probability
|
||||
#jax[cpu]
|
||||
|
||||
# tflitehub dependencies.
|
||||
|
||||
@@ -9,23 +9,13 @@ tabulate
|
||||
tqdm
|
||||
|
||||
#iree-compiler | iree-runtime should already be installed
|
||||
iree-tools-tflite
|
||||
iree-tools-xla
|
||||
iree-tools-tf
|
||||
|
||||
# TensorFlow and JAX.
|
||||
# Modelling and JAX.
|
||||
gin-config
|
||||
tf-nightly
|
||||
keras-nightly
|
||||
#tf-models-nightly
|
||||
#tensorflow-text-nightly
|
||||
transformers
|
||||
diffusers
|
||||
#tensorflow-probability
|
||||
#jax[cpu]
|
||||
|
||||
|
||||
# tflitehub dependencies.
|
||||
Pillow
|
||||
|
||||
# Testing and support.
|
||||
|
||||
@@ -47,4 +47,4 @@ pefile
|
||||
pyinstaller
|
||||
|
||||
# vicuna quantization
|
||||
brevitas @ git+https://github.com/Xilinx/brevitas.git@dev
|
||||
brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea
|
||||
|
||||
@@ -86,6 +86,7 @@ $PYTHON -m pip install --upgrade -r "$TD/requirements.txt"
|
||||
if [ "$torch_mlir_bin" = true ]; then
|
||||
if [[ $(uname -s) = 'Darwin' ]]; then
|
||||
echo "MacOS detected. Installing torch-mlir from .whl, to avoid dependency problems with torch."
|
||||
$PYTHON -m pip uninstall -y timm #TEMP FIX FOR MAC
|
||||
$PYTHON -m pip install --pre --no-cache-dir torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/
|
||||
else
|
||||
$PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
|
||||
@@ -128,7 +129,13 @@ if [[ ! -z "${IMPORTER}" ]]; then
|
||||
fi
|
||||
fi
|
||||
|
||||
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f https://download.pytorch.org/whl/nightly/cpu/
|
||||
if [[ $(uname -s) = 'Darwin' ]]; then
|
||||
PYTORCH_URL=https://download.pytorch.org/whl/nightly/torch/
|
||||
else
|
||||
PYTORCH_URL=https://download.pytorch.org/whl/nightly/cpu/
|
||||
fi
|
||||
|
||||
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f ${PYTORCH_URL}
|
||||
|
||||
if [[ $(uname -s) = 'Linux' && ! -z "${IMPORTER}" ]]; then
|
||||
T_VER=$($PYTHON -m pip show torch | grep Version)
|
||||
|
||||
@@ -292,9 +292,10 @@ def compile_module_to_flatbuffer(
|
||||
extra_args,
|
||||
model_name="None",
|
||||
debug=False,
|
||||
compile_str=False,
|
||||
):
|
||||
# Setup Compile arguments wrt to frontends.
|
||||
input_type = ""
|
||||
input_type = "auto"
|
||||
args = get_iree_frontend_args(frontend)
|
||||
args += get_iree_device_args(device, extra_args)
|
||||
args += get_iree_common_args(debug=debug)
|
||||
@@ -311,10 +312,7 @@ def compile_module_to_flatbuffer(
|
||||
elif frontend in ["tm_tensor"]:
|
||||
input_type = ireec.InputType.TM_TENSOR
|
||||
|
||||
# TODO: make it simpler.
|
||||
# Compile according to the input type, else just try compiling.
|
||||
if input_type != "":
|
||||
# Currently for MHLO/TOSA.
|
||||
if compile_str:
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
module,
|
||||
target_backends=[iree_target_map(device)],
|
||||
@@ -322,9 +320,10 @@ def compile_module_to_flatbuffer(
|
||||
input_type=input_type,
|
||||
)
|
||||
else:
|
||||
# Currently for Torch.
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
assert os.path.isfile(module)
|
||||
flatbuffer_blob = ireec.compile_file(
|
||||
module,
|
||||
input_type=input_type,
|
||||
target_backends=[iree_target_map(device)],
|
||||
extra_args=args,
|
||||
)
|
||||
@@ -404,6 +403,11 @@ def load_vmfb_using_mmap(
|
||||
dl.log(f"mmap {flatbuffer_blob_or_path}")
|
||||
ctx = ireert.SystemContext(config=config)
|
||||
dl.log(f"ireert.SystemContext created")
|
||||
if "vulkan" in device:
|
||||
# Vulkan pipeline creation consumes significant amount of time.
|
||||
print(
|
||||
"\tCompiling Vulkan shaders. This may take a few minutes."
|
||||
)
|
||||
ctx.add_vm_module(mmaped_vmfb)
|
||||
dl.log(f"module initialized")
|
||||
mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
|
||||
@@ -427,10 +431,17 @@ def get_iree_compiled_module(
|
||||
device_idx: int = None,
|
||||
mmap: bool = False,
|
||||
debug: bool = False,
|
||||
compile_str: bool = False,
|
||||
):
|
||||
"""Given a module returns the compiled .vmfb and configs"""
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(
|
||||
module, device, frontend, model_config_path, extra_args, debug
|
||||
module,
|
||||
device,
|
||||
frontend,
|
||||
model_config_path,
|
||||
extra_args,
|
||||
debug,
|
||||
compile_str,
|
||||
)
|
||||
temp_file_to_unlink = None
|
||||
# TODO: Currently mmap=True control flow path has been switched off for mmap.
|
||||
@@ -487,10 +498,17 @@ def export_iree_module_to_vmfb(
|
||||
module_name: str = None,
|
||||
extra_args: list = [],
|
||||
debug: bool = False,
|
||||
compile_str: bool = False,
|
||||
):
|
||||
# Compiles the module given specs and saves it as .vmfb file.
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(
|
||||
module, device, mlir_dialect, model_config_path, extra_args, debug
|
||||
module,
|
||||
device,
|
||||
mlir_dialect,
|
||||
model_config_path,
|
||||
extra_args,
|
||||
debug,
|
||||
compile_str,
|
||||
)
|
||||
if module_name is None:
|
||||
device_name = (
|
||||
|
||||
@@ -89,24 +89,10 @@ def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
|
||||
|
||||
|
||||
def get_iree_metal_args(device_num=0, extra_args=[]):
|
||||
# res_metal_flag = ["--iree-flow-demote-i64-to-i32"]
|
||||
|
||||
# Add any metal spefic compilation flags here
|
||||
res_metal_flag = []
|
||||
metal_triple_flag = None
|
||||
for arg in extra_args:
|
||||
if "-iree-metal-target-platform=" in arg:
|
||||
print(f"Using target triple {arg} from command line args")
|
||||
metal_triple_flag = arg
|
||||
break
|
||||
|
||||
if metal_triple_flag is None:
|
||||
metal_triple_flag = get_metal_triple_flag(extra_args=extra_args)
|
||||
|
||||
if metal_triple_flag is not None:
|
||||
vulkan_target_env = get_vulkan_target_env_flag(
|
||||
"-iree-vulkan-target-triple=m1-moltenvk-macos"
|
||||
)
|
||||
res_metal_flag.append(vulkan_target_env)
|
||||
if len(extra_args) > 0:
|
||||
res_metal_flag.extend(extra_args)
|
||||
return res_metal_flag
|
||||
|
||||
|
||||
|
||||
@@ -119,6 +119,8 @@ def get_vulkan_target_triple(device_name):
|
||||
# Windows: AMD Radeon RX 7900 XTX
|
||||
elif all(x in device_name for x in ("RX", "7900")):
|
||||
triple = f"rdna3-7900-{system_os}"
|
||||
elif all(x in device_name for x in ("Radeon", "780M")):
|
||||
triple = f"rdna3-780m-{system_os}"
|
||||
elif all(x in device_name for x in ("AMD", "PRO", "W7900")):
|
||||
triple = f"rdna3-w7900-{system_os}"
|
||||
elif any(x in device_name for x in ("AMD", "Radeon")):
|
||||
|
||||
@@ -84,6 +84,13 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
self.extra_args = extra_args
|
||||
self.import_args = {}
|
||||
self.temp_file_to_unlink = None
|
||||
if not os.path.isfile(mlir_module):
|
||||
print(
|
||||
"Warning: Initializing SharkRunner with a mlir string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize SharkInference with a path to a MLIR module on your hard disk instead."
|
||||
)
|
||||
self.compile_str = True
|
||||
else:
|
||||
self.compile_str = False
|
||||
SharkRunner.__init__(
|
||||
self,
|
||||
mlir_module,
|
||||
@@ -98,6 +105,7 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
".",
|
||||
self.mlir_dialect,
|
||||
extra_args=self.extra_args,
|
||||
compile_str=self.compile_str,
|
||||
)
|
||||
params = load_flatbuffer(
|
||||
self.vmfb_file,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import tempfile
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
import torch
|
||||
import torch_mlir
|
||||
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
||||
@@ -130,10 +130,17 @@ def compile_int_precision(
|
||||
mlir_module = mlir_module.encode("UTF-8")
|
||||
mlir_module = BytesIO(mlir_module)
|
||||
bytecode = mlir_module.read()
|
||||
bytecode_path = os.path.join(
|
||||
os.getcwd(), f"{extended_model_name}_linalg.mlirbc"
|
||||
)
|
||||
with open(bytecode_path, "wb") as f:
|
||||
f.write(bytecode)
|
||||
del bytecode
|
||||
del mlir_module
|
||||
print(f"Elided IR written for {extended_model_name}")
|
||||
return bytecode
|
||||
return bytecode_path
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
|
||||
mlir_module=bytecode_path, device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
extra_args = [
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
@@ -148,7 +155,7 @@ def compile_int_precision(
|
||||
generate_vmfb=generate_vmfb,
|
||||
extra_args=extra_args,
|
||||
),
|
||||
bytecode,
|
||||
bytecode_path,
|
||||
)
|
||||
|
||||
|
||||
@@ -201,7 +208,7 @@ def shark_compile_through_fx(
|
||||
]
|
||||
else:
|
||||
(
|
||||
mlir_module,
|
||||
bytecode,
|
||||
_,
|
||||
) = import_with_fx(
|
||||
model=model,
|
||||
@@ -212,6 +219,11 @@ def shark_compile_through_fx(
|
||||
model_name=extended_model_name,
|
||||
save_dir=save_dir,
|
||||
)
|
||||
mlir_module = save_mlir(
|
||||
mlir_module=bytecode,
|
||||
model_name=extended_model_name,
|
||||
mlir_dialect=mlir_dialect,
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
|
||||
@@ -275,11 +275,11 @@ def download_model(
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
tuned_str = "" if tuned is None else "_" + tuned
|
||||
suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
|
||||
filename = os.path.join(model_dir, model_name + suffix)
|
||||
mlir_filename = os.path.join(model_dir, model_name + suffix)
|
||||
print(
|
||||
f"Verifying that model artifacts were downloaded successfully to {filename}..."
|
||||
f"Verifying that model artifacts were downloaded successfully to {mlir_filename}..."
|
||||
)
|
||||
if not os.path.exists(filename):
|
||||
if not os.path.exists(mlir_filename):
|
||||
from tank.generate_sharktank import gen_shark_files
|
||||
|
||||
print(
|
||||
@@ -287,13 +287,11 @@ def download_model(
|
||||
)
|
||||
gen_shark_files(model_name, frontend, WORKDIR, import_args)
|
||||
|
||||
assert os.path.exists(filename), f"MLIR not found at {filename}"
|
||||
with open(filename, mode="rb") as f:
|
||||
mlir_file = f.read()
|
||||
assert os.path.exists(mlir_filename), f"MLIR not found at {mlir_filename}"
|
||||
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
|
||||
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
|
||||
golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
|
||||
|
||||
inputs_tuple = tuple([inputs[key] for key in inputs])
|
||||
golden_out_tuple = tuple([golden_out[key] for key in golden_out])
|
||||
return mlir_file, function_name, inputs_tuple, golden_out_tuple
|
||||
return mlir_filename, function_name, inputs_tuple, golden_out_tuple
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from collections import defaultdict
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
import torchvision.models as models
|
||||
import copy
|
||||
import io
|
||||
@@ -20,10 +20,16 @@ def shark_backend(fx_g: torch.fx.GraphModule, inputs, device: str = "cpu"):
|
||||
bytecode_stream = io.BytesIO()
|
||||
mlir_module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
bytecode_path = save_mlir(
|
||||
bytecode,
|
||||
model_name="shark_eager_module",
|
||||
frontend="torch",
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode,
|
||||
mlir_module=bytecode_path,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
|
||||
@@ -3,8 +3,8 @@ import json
|
||||
import numpy as np
|
||||
|
||||
import torch_mlir
|
||||
from iree.compiler import compile_str
|
||||
from shark.shark_importer import import_with_fx, get_f16_inputs
|
||||
from iree.compiler import compile_file
|
||||
from shark.shark_importer import import_with_fx, get_f16_inputs, save_mlir
|
||||
|
||||
|
||||
class GenerateConfigFile:
|
||||
@@ -54,9 +54,15 @@ class GenerateConfigFile:
|
||||
verbose=False,
|
||||
)
|
||||
module = module.operation.get_asm(large_elements_limit=4)
|
||||
module_file = save_mlir(
|
||||
module,
|
||||
model_name="module_pre_split",
|
||||
frontend="torch",
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
compiled_module_str = str(
|
||||
compile_str(
|
||||
str(module),
|
||||
compile_file(
|
||||
module_file,
|
||||
target_backends=[backend],
|
||||
extra_args=[
|
||||
"--compile-to=flow",
|
||||
|
||||
@@ -451,6 +451,108 @@ def transform_fx(fx_g, quantized=False):
|
||||
fx_g.graph.lint()
|
||||
|
||||
|
||||
def gptq_transforms(fx_g):
|
||||
import torch
|
||||
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
if node.target in [
|
||||
torch.ops.aten.arange,
|
||||
torch.ops.aten.empty,
|
||||
torch.ops.aten.ones,
|
||||
torch.ops.aten._to_copy,
|
||||
]:
|
||||
if node.kwargs.get("device") == torch.device(device="cuda:0"):
|
||||
updated_kwargs = node.kwargs.copy()
|
||||
updated_kwargs["device"] = torch.device(device="cpu")
|
||||
node.kwargs = updated_kwargs
|
||||
|
||||
if node.target in [
|
||||
torch.ops.aten._to_copy,
|
||||
]:
|
||||
if node.kwargs.get("dtype") == torch.bfloat16:
|
||||
updated_kwargs = node.kwargs.copy()
|
||||
updated_kwargs["dtype"] = torch.float16
|
||||
node.kwargs = updated_kwargs
|
||||
|
||||
# Inputs of aten.native_layer_norm should be upcasted to fp32.
|
||||
if node.target in [torch.ops.aten.native_layer_norm]:
|
||||
with fx_g.graph.inserting_before(node):
|
||||
new_node_arg0 = fx_g.graph.call_function(
|
||||
torch.ops.prims.convert_element_type,
|
||||
args=(node.args[0], torch.float32),
|
||||
kwargs={},
|
||||
)
|
||||
node.args = (
|
||||
new_node_arg0,
|
||||
node.args[1],
|
||||
node.args[2],
|
||||
node.args[3],
|
||||
node.args[4],
|
||||
)
|
||||
|
||||
# Inputs of aten.mm should be upcasted to fp32.
|
||||
if node.target in [torch.ops.aten.mm]:
|
||||
with fx_g.graph.inserting_before(node):
|
||||
new_node_arg0 = fx_g.graph.call_function(
|
||||
torch.ops.prims.convert_element_type,
|
||||
args=(node.args[0], torch.float32),
|
||||
kwargs={},
|
||||
)
|
||||
new_node_arg1 = fx_g.graph.call_function(
|
||||
torch.ops.prims.convert_element_type,
|
||||
args=(node.args[1], torch.float32),
|
||||
kwargs={},
|
||||
)
|
||||
node.args = (new_node_arg0, new_node_arg1)
|
||||
|
||||
# Outputs of aten.mm should be downcasted to fp16.
|
||||
if type(node.args[0]) == torch.fx.node.Node and node.args[
|
||||
0
|
||||
].target in [torch.ops.aten.mm]:
|
||||
with fx_g.graph.inserting_before(node):
|
||||
tmp = node.args[0]
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten._to_copy,
|
||||
args=(node.args[0],),
|
||||
kwargs={"dtype": torch.float16},
|
||||
)
|
||||
node.args[0].append(new_node)
|
||||
node.args[0].replace_all_uses_with(new_node)
|
||||
new_node.args = (tmp,)
|
||||
new_node.kwargs = {"dtype": torch.float16}
|
||||
|
||||
# Inputs of aten._softmax should be upcasted to fp32.
|
||||
if node.target in [torch.ops.aten._softmax]:
|
||||
with fx_g.graph.inserting_before(node):
|
||||
new_node_arg0 = fx_g.graph.call_function(
|
||||
torch.ops.prims.convert_element_type,
|
||||
args=(node.args[0], torch.float32),
|
||||
kwargs={},
|
||||
)
|
||||
node.args = (new_node_arg0, node.args[1], node.args[2])
|
||||
|
||||
# Outputs of aten._softmax should be downcasted to fp16.
|
||||
if (
|
||||
type(node.args[0]) == torch.fx.node.Node
|
||||
and node.args[0].target in [torch.ops.aten._softmax]
|
||||
and node.target in [torch.ops.aten.expand]
|
||||
):
|
||||
with fx_g.graph.inserting_before(node):
|
||||
tmp = node.args[0]
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten._to_copy,
|
||||
args=(node.args[0],),
|
||||
kwargs={"dtype": torch.float16},
|
||||
)
|
||||
node.args[0].append(new_node)
|
||||
node.args[0].replace_all_uses_with(new_node)
|
||||
new_node.args = (tmp,)
|
||||
new_node.kwargs = {"dtype": torch.float16}
|
||||
|
||||
fx_g.graph.lint()
|
||||
|
||||
|
||||
# Doesn't replace the None type.
|
||||
def change_fx_graph_return_to_tuple(fx_g):
|
||||
for node in fx_g.graph.nodes:
|
||||
@@ -504,6 +606,7 @@ def import_with_fx(
|
||||
is_dynamic=False,
|
||||
tracing_required=False,
|
||||
precision="fp32",
|
||||
is_gptq=False,
|
||||
):
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
@@ -584,7 +687,7 @@ def import_with_fx(
|
||||
torch.ops.aten.index_add,
|
||||
torch.ops.aten.index_add_,
|
||||
]
|
||||
if precision in ["int4", "int8"]:
|
||||
if precision in ["int4", "int8"] and not is_gptq:
|
||||
from brevitas_examples.llm.llm_quant.export import (
|
||||
block_quant_layer_level_manager,
|
||||
)
|
||||
@@ -653,6 +756,10 @@ def import_with_fx(
|
||||
add_upcast(fx_g)
|
||||
fx_g.recompile()
|
||||
|
||||
if is_gptq:
|
||||
gptq_transforms(fx_g)
|
||||
fx_g.recompile()
|
||||
|
||||
if mlir_type == "fx":
|
||||
return fx_g
|
||||
|
||||
@@ -685,3 +792,25 @@ def import_with_fx(
|
||||
|
||||
mlir_module, func_name = mlir_importer.import_mlir(mlir_type=mlir_type)
|
||||
return mlir_module, func_name
|
||||
|
||||
|
||||
# Saves a .mlir module python object to the directory 'dir' with 'model_name' and returns a path to the saved file.
|
||||
def save_mlir(
|
||||
mlir_module,
|
||||
model_name,
|
||||
mlir_dialect="linalg",
|
||||
frontend="torch",
|
||||
dir=tempfile.gettempdir(),
|
||||
):
|
||||
model_name_mlir = (
|
||||
model_name + "_" + frontend + "_" + mlir_dialect + ".mlir"
|
||||
)
|
||||
if dir == "":
|
||||
dir = tempfile.gettempdir()
|
||||
mlir_path = os.path.join(dir, model_name_mlir)
|
||||
print(f"saving {model_name_mlir} to {dir}")
|
||||
if frontend == "torch":
|
||||
with open(mlir_path, "wb") as mlir_file:
|
||||
mlir_file.write(mlir_module)
|
||||
|
||||
return mlir_path
|
||||
|
||||
@@ -39,7 +39,7 @@ class SharkInference:
|
||||
Attributes
|
||||
----------
|
||||
mlir_module : str
|
||||
mlir_module represented in string; modules from torch-mlir are serialized in bytecode format.
|
||||
mlir_module or path represented in string; modules from torch-mlir are serialized in bytecode format.
|
||||
device : str
|
||||
device to execute the mlir_module on.
|
||||
currently supports cpu, cuda, vulkan, and metal backends.
|
||||
@@ -65,7 +65,7 @@ class SharkInference:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mlir_module: bytes,
|
||||
mlir_module,
|
||||
device: str = "none",
|
||||
mlir_dialect: str = "linalg",
|
||||
is_benchmark: bool = False,
|
||||
@@ -75,6 +75,14 @@ class SharkInference:
|
||||
mmap: bool = True,
|
||||
):
|
||||
self.mlir_module = mlir_module
|
||||
if mlir_module is not None:
|
||||
if mlir_module and not os.path.isfile(mlir_module):
|
||||
print(
|
||||
"Warning: Initializing SharkInference with a mlir string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize SharkInference with a path to a MLIR module on your hard disk instead."
|
||||
)
|
||||
self.compile_str = True
|
||||
else:
|
||||
self.compile_str = False
|
||||
self.device = shark_args.device if device == "none" else device
|
||||
self.mlir_dialect = mlir_dialect
|
||||
self.is_benchmark = is_benchmark
|
||||
@@ -203,6 +211,7 @@ class SharkInference:
|
||||
module_name=module_name,
|
||||
extra_args=extra_args,
|
||||
debug=debug,
|
||||
compile_str=self.compile_str,
|
||||
)
|
||||
|
||||
# load and return the module.
|
||||
|
||||
@@ -45,7 +45,7 @@ class SharkRunner:
|
||||
Attributes
|
||||
----------
|
||||
mlir_module : str
|
||||
mlir_module represented in string.
|
||||
mlir_module path, string, or bytecode.
|
||||
device : str
|
||||
device to execute the mlir_module on.
|
||||
currently supports cpu, cuda, vulkan, and metal backends.
|
||||
@@ -74,6 +74,14 @@ class SharkRunner:
|
||||
device_idx: int = None,
|
||||
):
|
||||
self.mlir_module = mlir_module
|
||||
if self.mlir_module is not None:
|
||||
if not os.path.isfile(mlir_module):
|
||||
print(
|
||||
"Warning: Initializing SharkRunner with a mlir string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize SharkInference with a path to a MLIR module on your hard disk instead."
|
||||
)
|
||||
self.compile_str = True
|
||||
else:
|
||||
self.compile_str = False
|
||||
self.device = shark_args.device if device == "none" else device
|
||||
self.mlir_dialect = mlir_dialect
|
||||
self.extra_args = extra_args
|
||||
@@ -91,6 +99,7 @@ class SharkRunner:
|
||||
self.mlir_dialect,
|
||||
extra_args=self.extra_args,
|
||||
device_idx=self.device_idx,
|
||||
compile_str=self.compile_str,
|
||||
)
|
||||
self.iree_compilation_module = params["vmfb"]
|
||||
self.iree_config = params["config"]
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
from shark.parser import shark_args
|
||||
from shark.shark_runner import SharkRunner
|
||||
from shark.backward_makefx import MakeFxModule
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import sys
|
||||
@@ -84,6 +84,12 @@ class SharkTrainer:
|
||||
training=True,
|
||||
mlir_type=mlir_type,
|
||||
)
|
||||
mlir_module = save_mlir(
|
||||
mlir_module,
|
||||
model_name="shark_model",
|
||||
frontend="torch",
|
||||
mlir_dialect=mlir_type,
|
||||
)
|
||||
self.shark_runner = SharkRunner(
|
||||
mlir_module,
|
||||
self.device,
|
||||
|
||||
@@ -1,24 +1,6 @@
|
||||
resnet50,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
albert-base-v2,stablehlo,tf,1e-2,1e-2,default,None,False,False,False,"",""
|
||||
roberta-base,stablehlo,tf,1e-02,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
|
||||
bert-base-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"","enabled_windows"
|
||||
camembert-base,stablehlo,tf,1e-2,1e-3,default,None,True,True,True,"",""
|
||||
dbmdz/convbert-base-turkish-cased,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,True,True,False,"https://github.com/iree-org/iree/issues/9971",""
|
||||
distilbert-base-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
facebook/convnext-tiny-224,stablehlo,tf,1e-2,1e-3,tf_vit,nhcw-nhwc,True,True,False,"https://github.com/nod-ai/SHARK/issues/311 & https://github.com/nod-ai/SHARK/issues/342","macos"
|
||||
funnel-transformer/small,stablehlo,tf,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/201",""
|
||||
google/electra-small-discriminator,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
google/mobilebert-uncased,stablehlo,tf,1e-2,1e-3,default,None,True,False,False,"Fails during iree-compile","macos"
|
||||
google/vit-base-patch16-224,stablehlo,tf,1e-2,1e-3,tf_vit,nhcw-nhwc,False,False,False,"",""
|
||||
microsoft/MiniLM-L12-H384-uncased,stablehlo,tf,1e-2,1e-3,tf_hf,None,True,False,False,"Fails during iree-compile.",""
|
||||
microsoft/layoutlm-base-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
microsoft/mpnet-base,stablehlo,tf,1e-2,1e-2,default,None,True,True,True,"",""
|
||||
alexnet,linalg,torch,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/879",""
|
||||
bert-base-cased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,True,True,"",""
|
||||
bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
bert-large-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"Fails during iree-compile.",""
|
||||
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/311",""
|
||||
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390","macos"
|
||||
@@ -32,14 +14,8 @@ resnet50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc/img2col,True,True,True,"Numerics issues, awaiting cuda-independent fp16 integration",""
|
||||
squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,False,False,"","macos"
|
||||
efficientnet-v2-s,stablehlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
|
||||
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
|
||||
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
|
||||
efficientnet_b0,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""
|
||||
efficientnet_b7,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails on MacOS builder, VK device lost","macos"
|
||||
gpt2,stablehlo,tf,1e-2,1e-3,default,None,True,False,False,"","macos"
|
||||
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.","macos"
|
||||
t5-base,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"","macos"
|
||||
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported","macos"
|
||||
t5-large,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"","macos"
|
||||
|
||||
|
@@ -36,9 +36,7 @@ def create_module(model_name, tokenizer, device):
|
||||
|
||||
mlir_path = f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch.mlir"
|
||||
if os.path.isfile(mlir_path):
|
||||
with open(mlir_path, "r") as f:
|
||||
model_mlir = f.read()
|
||||
print(f"Loaded .mlir from {mlir_path}")
|
||||
print(f"Found .mlir from {mlir_path}")
|
||||
else:
|
||||
(model_mlir, func_name) = import_with_fx(
|
||||
model=opt_model,
|
||||
@@ -50,9 +48,10 @@ def create_module(model_name, tokenizer, device):
|
||||
with open(mlir_path, "w") as f:
|
||||
f.write(model_mlir)
|
||||
print(f"Saved mlir at {mlir_path}")
|
||||
del model_mlir
|
||||
|
||||
shark_module = SharkInference(
|
||||
model_mlir,
|
||||
mlir_path,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
is_benchmark=False,
|
||||
|
||||
@@ -6,7 +6,7 @@ import numpy as np
|
||||
from shark_opt_wrapper import OPTForCausalLMModel
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
from transformers import AutoTokenizer, OPTForCausalLM
|
||||
|
||||
OPT_MODEL = "facebook/opt-1.3b"
|
||||
@@ -57,9 +57,10 @@ class OPTModuleTester:
|
||||
with open(mlir_path, "w") as f:
|
||||
f.write(mlir_module)
|
||||
print(f"Saved mlir at {mlir_path}")
|
||||
del mlir_module
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
mlir_path,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
is_benchmark=self.benchmark,
|
||||
|
||||
@@ -18,7 +18,6 @@ import collections
|
||||
import json
|
||||
import os
|
||||
import psutil
|
||||
import resource
|
||||
import time
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
import torch
|
||||
from transformers import AutoTokenizer, OPTForCausalLM
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
from shark_opt_wrapper import OPTForCausalLMModel
|
||||
|
||||
model_name = "facebook/opt-1.3b"
|
||||
@@ -25,11 +25,13 @@ inputs = (
|
||||
model=model,
|
||||
inputs=inputs,
|
||||
is_f16=False,
|
||||
debug=True,
|
||||
model_name=model_name.split("/")[1],
|
||||
save_dir=".",
|
||||
)
|
||||
|
||||
mlir_module = save_mlir(
|
||||
mlir_module,
|
||||
model_name=model_name.split("/")[1],
|
||||
frontend="torch",
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
device="cpu-sync",
|
||||
|
||||
@@ -36,7 +36,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
|
||||
get_hf_img_cls_model,
|
||||
get_fp16_model,
|
||||
)
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
|
||||
with open(torch_model_list) as csvfile:
|
||||
torch_reader = csv.reader(csvfile, delimiter=",")
|
||||
@@ -130,133 +130,6 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
|
||||
)
|
||||
|
||||
|
||||
def save_tf_model(tf_model_list, local_tank_cache, import_args):
|
||||
from tank.model_utils_tf import (
|
||||
get_causal_image_model,
|
||||
get_masked_lm_model,
|
||||
get_causal_lm_model,
|
||||
get_keras_model,
|
||||
get_TFhf_model,
|
||||
get_tfhf_seq2seq_model,
|
||||
)
|
||||
import os
|
||||
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
|
||||
import tensorflow as tf
|
||||
|
||||
visible_default = tf.config.list_physical_devices("GPU")
|
||||
try:
|
||||
tf.config.set_visible_devices([], "GPU")
|
||||
visible_devices = tf.config.get_visible_devices()
|
||||
for device in visible_devices:
|
||||
assert device.device_type != "GPU"
|
||||
except:
|
||||
# Invalid device or cannot modify virtual devices once initialized.
|
||||
pass
|
||||
|
||||
with open(tf_model_list) as csvfile:
|
||||
tf_reader = csv.reader(csvfile, delimiter=",")
|
||||
fields = next(tf_reader)
|
||||
for row in tf_reader:
|
||||
tf_model_name = row[0]
|
||||
model_type = row[1]
|
||||
|
||||
model = None
|
||||
input = None
|
||||
print(f"Generating artifacts for model {tf_model_name}")
|
||||
if model_type == "hf":
|
||||
model, input, _ = get_masked_lm_model(
|
||||
tf_model_name, import_args
|
||||
)
|
||||
elif model_type == "img":
|
||||
model, input, _ = get_causal_image_model(
|
||||
tf_model_name, import_args
|
||||
)
|
||||
elif model_type == "keras":
|
||||
model, input, _ = get_keras_model(tf_model_name, import_args)
|
||||
elif model_type == "TFhf":
|
||||
model, input, _ = get_TFhf_model(tf_model_name, import_args)
|
||||
elif model_type == "tfhf_seq2seq":
|
||||
model, input, _ = get_tfhf_seq2seq_model(
|
||||
tf_model_name, import_args
|
||||
)
|
||||
elif model_type == "hf_causallm":
|
||||
model, input, _ = get_causal_lm_model(
|
||||
tf_model_name, import_args
|
||||
)
|
||||
|
||||
tf_model_name = tf_model_name.replace("/", "_")
|
||||
if import_args["batch_size"] != 1:
|
||||
tf_model_dir = os.path.join(
|
||||
local_tank_cache,
|
||||
str(tf_model_name)
|
||||
+ "_tf"
|
||||
+ f"_BS{str(import_args['batch_size'])}",
|
||||
)
|
||||
else:
|
||||
tf_model_dir = os.path.join(
|
||||
local_tank_cache, str(tf_model_name) + "_tf"
|
||||
)
|
||||
os.makedirs(tf_model_dir, exist_ok=True)
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
inputs=input,
|
||||
frontend="tf",
|
||||
)
|
||||
mlir_importer.import_debug(
|
||||
is_dynamic=False,
|
||||
dir=tf_model_dir,
|
||||
model_name=tf_model_name,
|
||||
)
|
||||
|
||||
|
||||
def save_tflite_model(tflite_model_list, local_tank_cache, import_args):
|
||||
from shark.tflite_utils import TFLitePreprocessor
|
||||
|
||||
with open(tflite_model_list) as csvfile:
|
||||
tflite_reader = csv.reader(csvfile, delimiter=",")
|
||||
for row in tflite_reader:
|
||||
print("\n")
|
||||
tflite_model_name = row[0]
|
||||
tflite_model_link = row[1]
|
||||
print("tflite_model_name", tflite_model_name)
|
||||
print("tflite_model_link", tflite_model_link)
|
||||
tflite_model_name_dir = os.path.join(
|
||||
local_tank_cache, str(tflite_model_name) + "_tflite"
|
||||
)
|
||||
os.makedirs(tflite_model_name_dir, exist_ok=True)
|
||||
print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")
|
||||
|
||||
# Preprocess to get SharkImporter input import_args
|
||||
tflite_preprocessor = TFLitePreprocessor(str(tflite_model_name))
|
||||
raw_model_file_path = tflite_preprocessor.get_raw_model_file()
|
||||
inputs = tflite_preprocessor.get_inputs()
|
||||
tflite_interpreter = tflite_preprocessor.get_interpreter()
|
||||
|
||||
# Use SharkImporter to get SharkInference input import_args
|
||||
my_shark_importer = SharkImporter(
|
||||
module=tflite_interpreter,
|
||||
inputs=inputs,
|
||||
frontend="tflite",
|
||||
raw_model_file=raw_model_file_path,
|
||||
)
|
||||
my_shark_importer.import_debug(
|
||||
dir=tflite_model_name_dir,
|
||||
model_name=tflite_model_name,
|
||||
func_name="main",
|
||||
)
|
||||
mlir_hash = create_hash(
|
||||
os.path.join(
|
||||
tflite_model_name_dir,
|
||||
tflite_model_name + "_tflite" + ".mlir",
|
||||
)
|
||||
)
|
||||
np.save(
|
||||
os.path.join(tflite_model_name_dir, "hash"),
|
||||
np.array(mlir_hash),
|
||||
)
|
||||
|
||||
|
||||
def check_requirements(frontend):
|
||||
import importlib
|
||||
|
||||
@@ -265,10 +138,6 @@ def check_requirements(frontend):
|
||||
tv_spec = importlib.util.find_spec("torchvision")
|
||||
has_pkgs = tv_spec is not None
|
||||
|
||||
elif frontend in ["tensorflow", "tf"]:
|
||||
tf_spec = importlib.util.find_spec("tensorflow")
|
||||
has_pkgs = tf_spec is not None
|
||||
|
||||
return has_pkgs
|
||||
|
||||
|
||||
@@ -287,27 +156,11 @@ def gen_shark_files(modelname, frontend, tank_dir, importer_args):
|
||||
torch_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "torch_model_list.csv"
|
||||
)
|
||||
tf_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "tf_model_list.csv"
|
||||
)
|
||||
custom_model_csv = tempfile.NamedTemporaryFile(
|
||||
dir=os.path.dirname(__file__),
|
||||
delete=True,
|
||||
)
|
||||
# Create a temporary .csv with only the desired entry.
|
||||
if frontend == "tf":
|
||||
with open(tf_model_csv, mode="r") as src:
|
||||
reader = csv.reader(src)
|
||||
for row in reader:
|
||||
if row[0] == modelname:
|
||||
target = row
|
||||
with open(custom_model_csv.name, mode="w") as trg:
|
||||
writer = csv.writer(trg)
|
||||
writer.writerow(["modelname", "src"])
|
||||
writer.writerow(target)
|
||||
save_tf_model(custom_model_csv.name, tank_dir, import_args)
|
||||
|
||||
elif frontend == "torch":
|
||||
if frontend == "torch":
|
||||
with open(torch_model_csv, mode="r") as src:
|
||||
reader = csv.reader(src)
|
||||
for row in reader:
|
||||
@@ -341,18 +194,6 @@ if __name__ == "__main__":
|
||||
# Please see: https://github.com/nod-ai/SHARK/blob/main/tank/torch_model_list.csv""",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--tf_model_csv",
|
||||
# type=lambda x: is_valid_file(x),
|
||||
# default="./tank/tf_model_list.csv",
|
||||
# help="Contains the file with tf model name and args.",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--tflite_model_csv",
|
||||
# type=lambda x: is_valid_file(x),
|
||||
# default="./tank/tflite/tflite_model_list.csv",
|
||||
# help="Contains the file with tf model name and args.",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--ci_tank_dir",
|
||||
# type=bool,
|
||||
# default=False,
|
||||
@@ -369,11 +210,5 @@ if __name__ == "__main__":
|
||||
torch_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "torch_model_list.csv"
|
||||
)
|
||||
tf_model_csv = os.path.join(os.path.dirname(__file__), "tf_model_list.csv")
|
||||
tflite_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "tflite", "tflite_model_list.csv"
|
||||
)
|
||||
|
||||
save_torch_model(torch_model_csv, WORKDIR, import_args)
|
||||
# save_tf_model(tf_model_csv, WORKDIR, import_args)
|
||||
# save_tflite_model(tflite_model_csv, WORKDIR, import_args)
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
model_name, model_type
|
||||
albert-base-v2,hf
|
||||
bert-base-uncased,hf
|
||||
camembert-base,hf
|
||||
dbmdz/convbert-base-turkish-cased,hf
|
||||
distilbert-base-uncased,hf
|
||||
google/electra-small-discriminator,hf
|
||||
funnel-transformer/small,hf
|
||||
microsoft/layoutlm-base-uncased,hf
|
||||
google/mobilebert-uncased,hf
|
||||
microsoft/mpnet-base,hf
|
||||
roberta-base,hf
|
||||
resnet50,keras
|
||||
xlm-roberta-base,hf
|
||||
microsoft/MiniLM-L12-H384-uncased,TFhf
|
||||
funnel-transformer/small,hf
|
||||
microsoft/mpnet-base,hf
|
||||
facebook/convnext-tiny-224,img
|
||||
google/vit-base-patch16-224,img
|
||||
efficientnet-v2-s,keras
|
||||
bert-large-uncased,hf
|
||||
t5-base,tfhf_seq2seq
|
||||
t5-large,tfhf_seq2seq
|
||||
efficientnet_b0,keras
|
||||
efficientnet_b7,keras
|
||||
gpt2,hf_causallm
|
||||
t5-base,tfhf_seq2seq
|
||||
t5-large,tfhf_seq2seq
|
||||
|
Reference in New Issue
Block a user