Enable downloading vmfb/mlir for webui (#1807)

jinchen62
2023-08-31 11:05:47 -07:00
committed by GitHub
parent 3601dc7c3b
commit 4c3d8a0a7f
3 changed files with 260 additions and 252 deletions
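
The sketch below is a minimal, self-contained model of the flow this commit enables, not the SHARK code itself: the web UI's new download_vmfb checkbox is passed through chat() into UnshardedVicuna, which stores the flag so compile() can fetch a prebuilt vmfb (or an mlirbc/mlir) from the shark_tank bucket before falling back to on-device generation. The UnshardedVicunaSketch class, its constructor arguments, and the local path naming are illustrative assumptions; the bucket layout, flag names, and the shape of the download_public_file call are taken from the diff below.

from pathlib import Path


def download_public_file(gs_url, dest_path, single_file=True):
    # Stand-in for SHARK's gsutil-backed helper; here it only logs the request.
    print(f"[sketch] would fetch {gs_url} -> {dest_path}")


class UnshardedVicunaSketch:
    # Illustrative stand-in for the real UnshardedVicuna class.
    def __init__(self, model_name, device, precision="fp16",
                 download_vmfb=False, load_mlir_from_shark_tank=True):
        self.model_name = model_name
        self.device = device
        self.precision = precision
        self.download_vmfb = download_vmfb
        self.load_mlir_from_shark_tank = load_mlir_from_shark_tank
        self.vicuna_vmfb_path = Path(f"{model_name}_{precision}_{device}.vmfb")
        self.vicuna_mlir_path = self.get_model_path("mlir")
        self.compile()

    def get_model_path(self, suffix="mlir"):
        # Hypothetical naming; the real method also encodes the device.
        return Path(f"{self.model_name}_{self.precision}.{suffix}")

    def compile(self):
        # 1) Prebuilt vmfb from shark_tank, but only if the UI checkbox was ticked.
        if not self.vicuna_vmfb_path.exists() and self.download_vmfb:
            download_public_file(
                f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.vicuna_vmfb_path.name}",
                self.vicuna_vmfb_path.absolute(),
            )
        if self.vicuna_vmfb_path.exists():
            return
        # 2) Otherwise try mlirbc/mlir from shark_tank, 3) else generate locally.
        for suffix in ["mlirbc", "mlir"]:
            self.vicuna_mlir_path = self.get_model_path(suffix)
            if not self.vicuna_mlir_path.exists() and self.load_mlir_from_shark_tank:
                download_public_file(
                    f"gs://shark_tank/{self.model_name}/unsharded/mlir/{self.vicuna_mlir_path.name}",
                    self.vicuna_mlir_path.absolute(),
                )
            if self.vicuna_mlir_path.exists():
                break
        else:
            print("[sketch] no artifact found; would generate MLIR on device")


if __name__ == "__main__":
    # The web UI passes the checkbox value through chat() as download_vmfb.
    UnshardedVicunaSketch("vicuna", "cpu-task", download_vmfb=True)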


@@ -1272,7 +1272,7 @@ class UnshardedVicuna(VicunaBase):
        self.vicuna_vmfb_path = self.get_model_path(suffix="vmfb")
        self.tokenizer = self.get_tokenizer()
        self.cache_vicunas = cache_vicunas
-        self.compile(download_vmfb)
+        self.compile()

    def get_model_path(self, suffix="mlir"):
        safe_device = self.device.split("-")[0]
@@ -1404,13 +1404,13 @@ class UnshardedVicuna(VicunaBase):
        return "\n".join(new_lines)

-    def compile(self, download_vmfb=False):
+    def compile(self):
        # Testing : DO NOT Download Vmfbs if not found. Modify later
        # download vmfbs for A100
-        print(
-            f"Looking into gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.vicuna_vmfb_path.name}"
-        )
-        if not self.vicuna_vmfb_path.exists() and download_vmfb:
+        if not self.vicuna_vmfb_path.exists() and self.download_vmfb:
+            print(
+                f"Looking into gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.vicuna_vmfb_path.name}"
+            )
            download_public_file(
                f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.vicuna_vmfb_path.name}",
                self.vicuna_vmfb_path.absolute(),
@@ -1423,245 +1423,237 @@ class UnshardedVicuna(VicunaBase):
        if self.vicuna_vmfb_path.exists():
            print(f"[DEBUG] vmfb found at {self.vicuna_vmfb_path.absolute()}")
            return
-        print(f"[DEBUG] vmfb not found at {self.vicuna_vmfb_path.absolute()}")
-        if self.vicuna_mlir_path.exists():
-            print(f"[DEBUG] mlir found at {self.vicuna_mlir_path.absolute()}")
-            with open(self.vicuna_mlir_path, "rb") as f:
-                combined_module = f.read()
-        else:
-            print(
-                f"[DEBUG] mlir not found at {self.vicuna_mlir_path.absolute()}"
-            )
-            mlir_generated = False
-            if self.load_mlir_from_shark_tank:
-                # download MLIR from shark tank
-                for suffix in ["mlirbc", "mlir"]:
-                    self.vicuna_mlir_path = self.get_model_path(suffix)
-                    download_public_file(
-                        f"gs://shark_tank/{self.model_name}/unsharded/mlir/{self.vicuna_mlir_path.name}",
-                        self.vicuna_mlir_path.absolute(),
-                        single_file=True,
-                    )
-                    if self.vicuna_mlir_path.exists():
-                        with open(self.vicuna_mlir_path, "rb") as f:
-                            combined_module = f.read()
-                        mlir_generated = True
-                        break
-                self.vicuna_mlir_path = self.get_model_path("mlir")
-                if not mlir_generated:
-                    print(
-                        f"[DEBUG] failed to download {self.vicuna_mlir_path.name} from shark tank"
-                    )
-            if not mlir_generated:
-                print("[DEBUG] generating mlir on device")
+        print(f"[DEBUG] vmfb not found")
+        mlir_generated = False
+        for suffix in ["mlirbc", "mlir"]:
+            self.vicuna_mlir_path = self.get_model_path(suffix)
+            if not self.vicuna_mlir_path.exists() and self.load_mlir_from_shark_tank:
+                print(
+                    f"Looking into gs://shark_tank/{self.model_name}/unsharded/mlir/{self.vicuna_mlir_path.name}"
+                )
+                download_public_file(
+                    f"gs://shark_tank/{self.model_name}/unsharded/mlir/{self.vicuna_mlir_path.name}",
+                    self.vicuna_mlir_path.absolute(),
+                    single_file=True,
+                )
+            if self.vicuna_mlir_path.exists():
+                print(f"[DEBUG] mlir found at {self.vicuna_mlir_path.absolute()}")
+                with open(self.vicuna_mlir_path, "rb") as f:
+                    combined_module = f.read()
+                mlir_generated = True
+                break

+        if not mlir_generated:
+            print(f"[DEBUG] mlir not found")
+            print("[DEBUG] generating mlir on device")
            # Select a compilation prompt such that the resulting input_ids
            # from the model's tokenizer has shape [1, 19]
            if self.model_name == "codegen":
                compilation_prompt = "def hello_world():\n print('Hello World')\n print('Hello World')"
            else:
                compilation_prompt = "".join(["0" for _ in range(17)])

            first_model_path = f"first_{self.model_name}_{self.precision}.mlir"
            if Path(first_model_path).exists():
                print(f"loading {first_model_path}")
                with open(Path(first_model_path), "r") as f:
                    first_module = f.read()
            else:
                # generate first vicuna
                compilation_input_ids = self.tokenizer(
                    compilation_prompt,
                    return_tensors="pt",
                ).input_ids
                compilation_input_ids = torch.tensor(
                    compilation_input_ids
                ).reshape([1, 19])
                firstVicunaCompileInput = (compilation_input_ids,)
                model = FirstVicuna(
                    self.hf_model_path,
                    self.precision,
                    self.weight_group_size,
                    self.model_name,
                    self.hf_auth_token,
                )
                print(f"[DEBUG] generating torchscript graph")
                is_f16 = self.precision in ["fp16", "int4"]
                ts_graph = import_with_fx(
                    model,
                    firstVicunaCompileInput,
                    is_f16=is_f16,
                    precision=self.precision,
                    f16_input_mask=[False, False],
                    mlir_type="torchscript",
                )
                del model
                firstVicunaCompileInput = list(firstVicunaCompileInput)
                firstVicunaCompileInput[
                    0
                ] = torch_mlir.TensorPlaceholder.like(
                    firstVicunaCompileInput[0], dynamic_axes=[1]
                )
                firstVicunaCompileInput = tuple(firstVicunaCompileInput)
                first_module = None
                print(f"[DEBUG] generating torch mlir")
                if self.precision in ["int4", "int8"]:
                    first_module = torch_mlir.compile(
                        ts_graph,
                        [*firstVicunaCompileInput],
                        output_type=torch_mlir.OutputType.TORCH,
                        backend_legal_ops=["quant.matmul_rhs_group_quant"],
                        extra_library=brevitas_matmul_rhs_group_quant_library,
                        use_tracing=False,
                        verbose=False,
                    )
                    print(f"[DEBUG] converting torch to linalg")
                    run_pipeline_with_repro_report(
                        first_module,
                        "builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
                        description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
                    )
                else:
                    first_module = torch_mlir.compile(
                        ts_graph,
                        [*firstVicunaCompileInput],
                        torch_mlir.OutputType.LINALG_ON_TENSORS,
                        use_tracing=False,
                        verbose=False,
                    )
                del ts_graph
                del firstVicunaCompileInput
                gc.collect()
                print(
                    "[DEBUG] successfully generated first vicuna linalg mlir"
                )
                first_module = self.write_in_dynamic_inputs0(
                    str(first_module), dynamic_input_size=19
                )
                if self.cache_vicunas:
                    with open(first_model_path, "w+") as f:
                        f.write(first_module)
                    print("Finished writing IR after dynamic")

            print(f"[DEBUG] Starting generation of second llama")
            second_model_path = f"second_{self.model_name}_{self.precision}.mlir"
            if Path(second_model_path).exists():
                print(f"loading {second_model_path}")
                with open(Path(second_model_path), "r") as f:
                    second_module = f.read()
            else:
                # generate second vicuna
                compilation_input_ids = torch.zeros(
                    [1, 1], dtype=torch.int64
                )
                if self.model_name == "llama2_13b":
                    dim1 = 40
                    total_tuple = 80
                elif self.model_name == "llama2_70b":
                    dim1 = 8
                    total_tuple = 160
                else:
                    dim1 = 32
                    total_tuple = 64
                pkv = tuple(
                    (torch.zeros([1, dim1, 19, 128], dtype=torch.float32))
                    for _ in range(total_tuple)
                )
                secondVicunaCompileInput = (compilation_input_ids,) + pkv
                if self.model_name == "llama2_13b":
                    model = SecondVicuna13B(
                        self.hf_model_path,
                        self.precision,
                        self.weight_group_size,
                        self.model_name,
                        self.hf_auth_token,
                    )
                elif self.model_name == "llama2_70b":
                    model = SecondVicuna70B(
                        self.hf_model_path,
                        self.precision,
                        self.weight_group_size,
                        self.model_name,
                        self.hf_auth_token,
                    )
                else:
                    model = SecondVicuna7B(
                        self.hf_model_path,
                        self.precision,
                        self.weight_group_size,
                        self.model_name,
                        self.hf_auth_token,
                    )
                print(f"[DEBUG] generating torchscript graph")
                is_f16 = self.precision in ["fp16", "int4"]
                ts_graph = import_with_fx(
                    model,
                    secondVicunaCompileInput,
                    is_f16=is_f16,
                    precision=self.precision,
                    f16_input_mask=[False] + [True] * total_tuple,
                    mlir_type="torchscript",
                )
                del model
                if self.precision in ["fp16", "int4"]:
                    secondVicunaCompileInput = get_f16_inputs(
                        secondVicunaCompileInput,
                        True,
                        f16_input_mask=[False] + [True] * total_tuple,
                    )
                secondVicunaCompileInput = list(secondVicunaCompileInput)
                for i in range(len(secondVicunaCompileInput)):
                    if i != 0:
                        secondVicunaCompileInput[
                            i
                        ] = torch_mlir.TensorPlaceholder.like(
                            secondVicunaCompileInput[i], dynamic_axes=[2]
                        )
                secondVicunaCompileInput = tuple(secondVicunaCompileInput)
                print(f"[DEBUG] generating torch mlir")
                if self.precision in ["int4", "int8"]:
                    second_module = torch_mlir.compile(
                        ts_graph,
                        [*secondVicunaCompileInput],
                        output_type=torch_mlir.OutputType.TORCH,
                        backend_legal_ops=["quant.matmul_rhs_group_quant"],
                        extra_library=brevitas_matmul_rhs_group_quant_library,
                        use_tracing=False,
                        verbose=False,
                    )
                    print(f"[DEBUG] converting torch to linalg")
                    run_pipeline_with_repro_report(
                        second_module,
                        "builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
                        description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
                    )
                else:
                    second_module = torch_mlir.compile(
                        ts_graph,
                        [*secondVicunaCompileInput],
                        torch_mlir.OutputType.LINALG_ON_TENSORS,
                        use_tracing=False,
                        verbose=False,
                    )
                del ts_graph
                del secondVicunaCompileInput
                gc.collect()
                print(
                    "[DEBUG] successfully generated second vicuna linalg mlir"
                )
                second_module = self.write_in_dynamic_inputs1(
                    str(second_module)
                )
                if self.cache_vicunas:
                    with open(second_model_path, "w+") as f:
                        f.write(second_module)
                    print("Finished writing IR after dynamic")

            combined_module = self.combine_mlir_scripts(
                first_module,
                second_module,
                self.vicuna_mlir_path,
            )
            del first_module, second_module

        print(self.device)
        if "rocm" in self.device:


@@ -109,7 +109,7 @@ with gr.Blocks() as minigpt4_web:
    gr.Markdown(description)
    with gr.Row():
-        with gr.Column(scale=0.5):
+        with gr.Column():
            image = gr.Image(type="pil")
            upload_button = gr.Button(
                value="Upload & Start Chat",


@@ -160,14 +160,15 @@ def chat(
    model,
    device,
    precision,
+    download_vmfb,
    config_file,
    cli=False,
    progress=gr.Progress(),
):
    global past_key_values
    global model_vmfb_key
    global vicuna_model

    model_name, model_path = list(map(str.strip, model.split("=>")))
    if "cuda" in device:
        device = "cuda"
@@ -177,6 +178,8 @@ def chat(
        device = "cpu-task"
    elif "vulkan" in device:
        device = "vulkan"
+    elif "rocm" in device:
+        device = "rocm"
    else:
        print("unrecognized device")
@@ -194,18 +197,6 @@ def chat(
    from apps.language_models.scripts.vicuna import UnshardedVicuna
    from apps.stable_diffusion.src import args

    if vicuna_model == 0:
-        if "cuda" in device:
-            device = "cuda"
-        elif "sync" in device:
-            device = "cpu-sync"
-        elif "task" in device:
-            device = "cpu-task"
-        elif "vulkan" in device:
-            device = "vulkan"
-        elif "rocm" in device:
-            device = "rocm"
        if new_model_vmfb_key != model_vmfb_key:
            model_vmfb_key = new_model_vmfb_key
            max_toks = 128 if model_name == "codegen" else 512
@@ -237,6 +228,8 @@ def chat(
                device=device,
                precision=precision,
                max_num_tokens=max_toks,
+                download_vmfb=download_vmfb,
+                load_mlir_from_shark_tank=True,
                extra_args_cmd=_extra_args,
            )
        # else:
@@ -360,6 +353,8 @@ def llm_chat_api(InputData: dict):
            device=device,
            precision=precision,
            max_num_tokens=max_toks,
+            download_vmfb=True,
+            load_mlir_from_shark_tank=True,
        )

    # TODO: add role dict for different models
@@ -430,15 +425,14 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
                # show cpu-task device first in list for chatbot
                supported_devices = supported_devices[-1:] + supported_devices[:-1]
                supported_devices = [x for x in supported_devices if "sync" not in x]
-                # print(supported_devices)
-                devices = gr.Dropdown(
+                device = gr.Dropdown(
                    label="Device",
                    value=supported_devices[0]
                    if enabled
                    else "Only CUDA Supported for now",
                    choices=supported_devices,
                    interactive=enabled,
                    # multiselect=True,
                )
                precision = gr.Radio(
                    label="Precision",
@@ -450,7 +444,13 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
                    ],
                    visible=True,
                )
-                tokens_time = gr.Textbox(label="Tokens generated per second")
+                with gr.Column():
+                    download_vmfb = gr.Checkbox(
+                        label="Download vmfb from Shark tank if available",
+                        value=True,
+                        interactive=True,
+                    )
+                    tokens_time = gr.Textbox(label="Tokens generated per second")

            with gr.Row(visible=False):
                with gr.Group():
@@ -485,7 +485,15 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
            fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
        ).then(
            fn=chat,
-            inputs=[system_msg, chatbot, model, devices, precision, config_file],
+            inputs=[
+                system_msg,
+                chatbot,
+                model,
+                device,
+                precision,
+                download_vmfb,
+                config_file,
+            ],
            outputs=[chatbot, tokens_time],
            queue=True,
        )
@@ -493,7 +501,15 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
            fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
        ).then(
            fn=chat,
-            inputs=[system_msg, chatbot, model, devices, precision, config_file],
+            inputs=[
+                system_msg,
+                chatbot,
+                model,
+                device,
+                precision,
+                download_vmfb,
+                config_file,
+            ],
            outputs=[chatbot, tokens_time],
            queue=True,
        )
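
For readers unfamiliar with how the expanded inputs list above reaches chat(), the following is a small, self-contained Gradio sketch (not the SHARK UI): every component listed in inputs is passed positionally to the callback, which is why the diff inserts download_vmfb between precision and config_file. The echo() callback and the layout here are illustrative assumptions; the checkbox label and default mirror the diff.

import gradio as gr


def echo(message, download_vmfb):
    # In the real UI this slot is chat(...); here we just report what arrived.
    return f"download_vmfb={download_vmfb}: {message}"


with gr.Blocks() as demo:
    msg = gr.Textbox(label="Message")
    download_vmfb = gr.Checkbox(
        label="Download vmfb from Shark tank if available",
        value=True,
        interactive=True,
    )
    out = gr.Textbox(label="Reply")
    # Components in inputs are forwarded to fn in list order.
    msg.submit(fn=echo, inputs=[msg, download_vmfb], outputs=out)

if __name__ == "__main__":
    demo.launch()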