mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-04-20 03:00:34 -04:00
Compare commits
3 Commits
20231004.9
...
20230911.9
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c5947c998f | ||
|
|
1026d37f28 | ||
|
|
faf2e7bd83 |
@@ -10,7 +10,7 @@ High Performance Machine Learning Distribution
|
||||
<summary>Prerequisites - Drivers </summary>
|
||||
|
||||
#### Install your Windows hardware drivers
|
||||
* [AMD RDNA Users] Download the latest driver (23.2.1 is the oldest supported) [here](https://www.amd.com/en/support).
|
||||
* [AMD RDNA Users] Download the latest driver [here](https://www.amd.com/en/support/kb/release-notes/rn-rad-win-23-2-1).
|
||||
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
|
||||
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ from tqdm import tqdm
|
||||
from typing import List, Tuple
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
@@ -383,7 +382,8 @@ class VicunaBase(SharkLLMBase):
|
||||
if sharded:
|
||||
output = self.shark_model.forward(input_ids, is_first=is_first)
|
||||
else:
|
||||
output = self.shark_model("first_vicuna_forward", (input_ids,), send_to_host=False)
|
||||
output = self.shark_model("first_vicuna_forward", (input_ids,))
|
||||
out_tensor = torch.tensor(output[1:])
|
||||
|
||||
else:
|
||||
token = params["token"]
|
||||
@@ -402,7 +402,7 @@ class VicunaBase(SharkLLMBase):
|
||||
token = token.to(torch.int64).reshape([1, 1])
|
||||
second_input = (token,) + tuple(past_key_values)
|
||||
output = self.shark_model(
|
||||
"second_vicuna_forward", second_input, send_to_host=False
|
||||
"second_vicuna_forward", second_input
|
||||
)
|
||||
|
||||
if sharded:
|
||||
@@ -410,8 +410,8 @@ class VicunaBase(SharkLLMBase):
|
||||
_past_key_values = output["past_key_values"]
|
||||
_token = int(torch.argmax(_logits[:, -1, :], dim=1)[0])
|
||||
else:
|
||||
_logits = torch.tensor(output[0].to_host())
|
||||
_past_key_values = output[1:]
|
||||
_logits = torch.tensor(output[0])
|
||||
_past_key_values = torch.tensor(output[1:])
|
||||
_token = torch.argmax(_logits[:, -1, :], dim=1)
|
||||
|
||||
_detok = self.tokenizer.decode(_token, skip_special_tokens=False)
|
||||
@@ -1221,7 +1221,6 @@ class UnshardedVicuna(VicunaBase):
|
||||
hf_auth_token: str = None,
|
||||
max_num_tokens=512,
|
||||
device="cpu",
|
||||
vulkan_target_triple="",
|
||||
precision="int8",
|
||||
vicuna_mlir_path=None,
|
||||
vicuna_vmfb_path=None,
|
||||
@@ -1231,7 +1230,6 @@ class UnshardedVicuna(VicunaBase):
|
||||
download_vmfb=False,
|
||||
cache_vicunas=False,
|
||||
extra_args_cmd=[],
|
||||
device_id=None,
|
||||
debug=False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
@@ -1250,8 +1248,6 @@ class UnshardedVicuna(VicunaBase):
|
||||
print(f"[DEBUG] hf model name: {self.hf_model_path}")
|
||||
self.max_sequence_length = 256
|
||||
self.device = device
|
||||
self.vulkan_target_triple = vulkan_target_triple
|
||||
self.device_id = device_id
|
||||
self.precision = precision
|
||||
self.download_vmfb = download_vmfb
|
||||
self.vicuna_vmfb_path = vicuna_vmfb_path
|
||||
@@ -1272,14 +1268,8 @@ class UnshardedVicuna(VicunaBase):
|
||||
safe_device = self.device.split("-")[0]
|
||||
if suffix in ["mlirbc", "mlir"]:
|
||||
return Path(f"{self.model_name}_{self.precision}.{suffix}")
|
||||
|
||||
target_triple = ""
|
||||
if self.vulkan_target_triple != "":
|
||||
target_triple = "_"
|
||||
target_triple += "_".join(self.vulkan_target_triple.split("-")[:-1])
|
||||
|
||||
return Path(
|
||||
f"{self.model_name}_{self.precision}_{safe_device}{target_triple}.{suffix}"
|
||||
f"{self.model_name}_{self.precision}_{safe_device}.{suffix}"
|
||||
)
|
||||
|
||||
def get_tokenizer(self):
|
||||
@@ -1420,7 +1410,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
single_file=True,
|
||||
)
|
||||
self.shark_model = get_vmfb_from_path(
|
||||
self.vicuna_vmfb_path, self.device, "tm_tensor", self.device_id
|
||||
self.vicuna_vmfb_path, self.device, "tm_tensor"
|
||||
)
|
||||
if self.shark_model is not None:
|
||||
print(f"[DEBUG] vmfb found at {self.vicuna_vmfb_path.absolute()}")
|
||||
@@ -1668,7 +1658,6 @@ class UnshardedVicuna(VicunaBase):
|
||||
mlir_module=combined_module,
|
||||
device=self.device,
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=self.device_id
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
self.vicuna_vmfb_path.parent.absolute(),
|
||||
@@ -1702,17 +1691,15 @@ class UnshardedVicuna(VicunaBase):
|
||||
res_tokens = []
|
||||
params = {"prompt": prompt, "is_first": True, "fv": self.shark_model}
|
||||
|
||||
prefill_st_time = time.time()
|
||||
generated_token_op = self.generate_new_token(
|
||||
params=params, sharded=False, cli=cli
|
||||
)
|
||||
prefill_time = time.time() - prefill_st_time
|
||||
|
||||
token = generated_token_op["token"]
|
||||
logits = generated_token_op["logits"]
|
||||
pkv = generated_token_op["past_key_values"]
|
||||
detok = generated_token_op["detok"]
|
||||
yield detok, None, prefill_time
|
||||
yield detok, ""
|
||||
|
||||
res_tokens.append(token)
|
||||
if cli:
|
||||
@@ -1727,11 +1714,9 @@ class UnshardedVicuna(VicunaBase):
|
||||
"sv": self.shark_model,
|
||||
}
|
||||
|
||||
decode_st_time = time.time()
|
||||
generated_token_op = self.generate_new_token(
|
||||
params=params, sharded=False, cli=cli
|
||||
)
|
||||
decode_time_ms = (time.time() - decode_st_time)*1000
|
||||
|
||||
token = generated_token_op["token"]
|
||||
logits = generated_token_op["logits"]
|
||||
@@ -1747,10 +1732,10 @@ class UnshardedVicuna(VicunaBase):
|
||||
else:
|
||||
if cli:
|
||||
print(f"{detok}", end=" ", flush=True)
|
||||
yield detok, None, decode_time_ms
|
||||
yield detok, ""
|
||||
|
||||
res_str = self.decode_tokens(res_tokens)
|
||||
yield res_str, "formatted", None
|
||||
yield res_str, "formatted"
|
||||
|
||||
def autocomplete(self, prompt):
|
||||
# use First vic alone to complete a story / prompt / sentence.
|
||||
@@ -1797,26 +1782,14 @@ start_message = {
|
||||
def create_prompt(model_name, history):
|
||||
global start_message
|
||||
system_message = start_message[model_name]
|
||||
if "llama2" in model_name:
|
||||
B_INST, E_INST = "[INST]", "[/INST]"
|
||||
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
||||
conversation = "".join(
|
||||
[
|
||||
f"{B_INST} {item[0].strip()} {E_INST} {item[1].strip()} "
|
||||
for item in history[1:]
|
||||
]
|
||||
)
|
||||
msg = f"{B_INST} {B_SYS} {system_message} {E_SYS} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
|
||||
|
||||
else:
|
||||
conversation = "".join(
|
||||
[
|
||||
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
|
||||
for item in history
|
||||
]
|
||||
)
|
||||
msg = system_message + conversation
|
||||
msg = msg.strip()
|
||||
conversation = "".join(
|
||||
[
|
||||
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
|
||||
for item in history
|
||||
]
|
||||
)
|
||||
msg = system_message + conversation
|
||||
msg = msg.strip()
|
||||
return msg
|
||||
|
||||
|
||||
@@ -1824,37 +1797,11 @@ if __name__ == "__main__":
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
_extra_args = []
|
||||
device_id = None
|
||||
# Process vulkan target triple.
|
||||
# TODO: This feature should just be in a common utils for other LLMs and in general
|
||||
# any model run via SHARK for Vulkan backend.
|
||||
vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
if vulkan_target_triple != "":
|
||||
# vulkan target triple
|
||||
if args.iree_vulkan_target_triple != "":
|
||||
_extra_args.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
# Step 1. Fetch the device ID.
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
get_all_vulkan_devices,
|
||||
get_vulkan_target_triple
|
||||
)
|
||||
vulkaninfo_list = get_all_vulkan_devices()
|
||||
id = 0
|
||||
for device in vulkaninfo_list:
|
||||
target_triple = get_vulkan_target_triple(vulkaninfo_list[id])
|
||||
if target_triple == vulkan_target_triple:
|
||||
device_id = id
|
||||
break
|
||||
id += 1
|
||||
|
||||
assert device_id, f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
|
||||
# Step 2. Add a few flags targetting specific hardwares.
|
||||
if "rdna" in vulkan_target_triple:
|
||||
flags_to_add = [
|
||||
"--iree-spirv-index-bits=64",
|
||||
]
|
||||
_extra_args = _extra_args + flags_to_add
|
||||
|
||||
|
||||
vic = None
|
||||
if not args.sharded:
|
||||
@@ -1880,7 +1827,6 @@ if __name__ == "__main__":
|
||||
download_vmfb=args.download_vmfb,
|
||||
cache_vicunas=args.cache_vicunas,
|
||||
extra_args_cmd=_extra_args,
|
||||
device_id=device_id
|
||||
)
|
||||
else:
|
||||
if args.config is not None:
|
||||
@@ -1922,24 +1868,7 @@ if __name__ == "__main__":
|
||||
user_prompt = input("User: ")
|
||||
history.append([user_prompt, ""])
|
||||
prompt = create_prompt(args.model_name, history)
|
||||
token_count = 0
|
||||
total_time_ms = 0.001 # In order to avoid divide by zero error
|
||||
prefill_time = 0
|
||||
is_first = True
|
||||
for text, msg, exec_time in vic.generate(prompt, cli=True):
|
||||
if msg is None:
|
||||
if is_first:
|
||||
prefill_time = exec_time
|
||||
is_first = False
|
||||
else:
|
||||
total_time_ms += exec_time
|
||||
token_count += 1
|
||||
elif "formatted" in msg:
|
||||
for text, msg in vic.generate(prompt, cli=True):
|
||||
if "formatted" in msg:
|
||||
print("Response:", text)
|
||||
history[-1][1] = text
|
||||
tokens_per_sec = (token_count / total_time_ms) * 1000
|
||||
print(f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec")
|
||||
print("\nResponse:", text)
|
||||
else:
|
||||
sys.exit(
|
||||
"unexpected message from the vicuna generate call, exiting."
|
||||
)
|
||||
|
||||
@@ -28,9 +28,7 @@ parser = argparse.ArgumentParser(
|
||||
description="runs a falcon model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--falcon_variant_to_use", default="7b", help="7b, 40b, 180b"
|
||||
)
|
||||
parser.add_argument("--falcon_variant_to_use", default="7b", help="7b, 40b")
|
||||
parser.add_argument(
|
||||
"--precision", "-p", default="fp16", help="fp32, fp16, int8, int4"
|
||||
)
|
||||
@@ -51,7 +49,7 @@ parser.add_argument(
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_mlir_from_shark_tank",
|
||||
default=True,
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="download precompile mlir from shark tank",
|
||||
)
|
||||
@@ -61,20 +59,13 @@ parser.add_argument(
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Run model in cli mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf_auth_token",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specify your own huggingface authentication token for falcon-180B model.",
|
||||
)
|
||||
|
||||
|
||||
class Falcon(SharkLLMBase):
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
hf_model_path="tiiuae/falcon-7b-instruct",
|
||||
hf_auth_token: str = None,
|
||||
hf_model_path,
|
||||
max_num_tokens=150,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
@@ -83,15 +74,6 @@ class Falcon(SharkLLMBase):
|
||||
debug=False,
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
print("hf_model_path: ", self.hf_model_path)
|
||||
|
||||
if "180b" in self.model_name and hf_auth_token == None:
|
||||
raise ValueError(
|
||||
""" HF auth token required for falcon-180b. Pass it using
|
||||
--hf_auth_token flag. You can ask for the access to the model
|
||||
here: https://huggingface.co/tiiuae/falcon-180B-chat."""
|
||||
)
|
||||
self.hf_auth_token = hf_auth_token
|
||||
self.max_padding_length = 100
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
@@ -99,14 +81,12 @@ class Falcon(SharkLLMBase):
|
||||
self.falcon_mlir_path = falcon_mlir_path
|
||||
self.debug = debug
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.src_model = self.get_src_model()
|
||||
self.shark_model = self.compile()
|
||||
self.src_model = self.get_src_model()
|
||||
|
||||
def get_tokenizer(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hf_model_path,
|
||||
trust_remote_code=True,
|
||||
token=self.hf_auth_token,
|
||||
self.hf_model_path, trust_remote_code=True
|
||||
)
|
||||
tokenizer.padding_side = "left"
|
||||
tokenizer.pad_token_id = 11
|
||||
@@ -114,18 +94,13 @@ class Falcon(SharkLLMBase):
|
||||
|
||||
def get_src_model(self):
|
||||
print("Loading src model: ", self.model_name)
|
||||
kwargs = {
|
||||
"torch_dtype": torch.float,
|
||||
"trust_remote_code": True,
|
||||
"token": self.hf_auth_token,
|
||||
"device_map": "cpu" if args.device == "cpu" else "cuda:0",
|
||||
}
|
||||
kwargs = {"torch_dtype": torch.float, "trust_remote_code": True}
|
||||
falcon_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, **kwargs
|
||||
)
|
||||
return falcon_model
|
||||
|
||||
def compile(self):
|
||||
def compile_falcon(self):
|
||||
if args.use_precompiled_model:
|
||||
if not self.falcon_vmfb_path.exists():
|
||||
# Downloading VMFB from shark_tank
|
||||
@@ -147,39 +122,37 @@ class Falcon(SharkLLMBase):
|
||||
if vmfb is not None:
|
||||
return vmfb
|
||||
|
||||
print(f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}")
|
||||
print(
|
||||
f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}. Trying to work with"
|
||||
f"[DEBUG] mlir path { self.falcon_mlir_path} {'exists' if self.falcon_mlir_path.exists() else 'does not exist'}"
|
||||
)
|
||||
if self.falcon_mlir_path.exists():
|
||||
print(f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}")
|
||||
with open(self.falcon_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
mlir_generated = False
|
||||
print(
|
||||
f"[DEBUG] mlir not found at {self.falcon_mlir_path.absolute()}"
|
||||
# Downloading MLIR from shark_tank
|
||||
download_public_file(
|
||||
"gs://shark_tank/falcon/"
|
||||
+ "falcon_"
|
||||
+ args.falcon_variant_to_use
|
||||
+ "_"
|
||||
+ self.precision
|
||||
+ ".mlir",
|
||||
self.falcon_mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if args.load_mlir_from_shark_tank:
|
||||
# Downloading MLIR from shark_tank
|
||||
print(f"[DEBUG] Trying to download mlir from shark_tank")
|
||||
download_public_file(
|
||||
"gs://shark_tank/falcon/"
|
||||
+ "falcon_"
|
||||
+ args.falcon_variant_to_use
|
||||
+ "_"
|
||||
+ self.precision
|
||||
+ ".mlir",
|
||||
self.falcon_mlir_path.absolute(),
|
||||
single_file=True,
|
||||
if self.falcon_mlir_path.exists():
|
||||
with open(self.falcon_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
mlir_generated = True
|
||||
else:
|
||||
raise ValueError(
|
||||
f"MLIR not found at {self.falcon_mlir_path.absolute()}"
|
||||
" after downloading! Please check path and try again"
|
||||
)
|
||||
if self.falcon_mlir_path.exists():
|
||||
print(
|
||||
f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}"
|
||||
)
|
||||
with open(self.falcon_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
mlir_generated = True
|
||||
|
||||
if not mlir_generated:
|
||||
print(f"[DEBUG] generating MLIR locally")
|
||||
compilation_input_ids = torch.randint(
|
||||
low=1, high=10000, size=(1, 100)
|
||||
)
|
||||
@@ -199,7 +172,6 @@ class Falcon(SharkLLMBase):
|
||||
is_f16=self.precision == "fp16",
|
||||
f16_input_mask=[False, False],
|
||||
mlir_type="torchscript",
|
||||
is_gptq=self.precision == "int4",
|
||||
)
|
||||
del model
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
@@ -219,9 +191,10 @@ class Falcon(SharkLLMBase):
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
del module
|
||||
|
||||
f_ = open(self.falcon_mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
print("Saved falcon mlir at ", str(self.falcon_mlir_path))
|
||||
print(f"[DEBUG] writing mlir to file")
|
||||
with open(f"{self.model_name}.mlir", "wb") as f_:
|
||||
with redirect_stdout(f_):
|
||||
print(module.operation.get_asm())
|
||||
f_.close()
|
||||
|
||||
shark_module = SharkInference(
|
||||
@@ -231,9 +204,11 @@ class Falcon(SharkLLMBase):
|
||||
self.falcon_vmfb_path.parent.absolute(),
|
||||
self.falcon_vmfb_path.stem,
|
||||
extra_args=[
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
"--iree-spirv-index-bits=64",
|
||||
],
|
||||
debug=self.debug,
|
||||
)
|
||||
@@ -242,6 +217,10 @@ class Falcon(SharkLLMBase):
|
||||
|
||||
return shark_module
|
||||
|
||||
def compile(self):
|
||||
falcon_shark_model = self.compile_falcon()
|
||||
return falcon_shark_model
|
||||
|
||||
def generate(self, prompt):
|
||||
model_inputs = self.tokenizer(
|
||||
prompt,
|
||||
@@ -490,26 +469,11 @@ if __name__ == "__main__":
|
||||
else Path(args.falcon_vmfb_path)
|
||||
)
|
||||
|
||||
if args.precision == "int4":
|
||||
if args.falcon_variant_to_use == "180b":
|
||||
hf_model_path_value = "TheBloke/Falcon-180B-Chat-GPTQ"
|
||||
else:
|
||||
hf_model_path_value = (
|
||||
"TheBloke/falcon-"
|
||||
+ args.falcon_variant_to_use
|
||||
+ "-instruct-GPTQ"
|
||||
)
|
||||
else:
|
||||
if args.falcon_variant_to_use == "180b":
|
||||
hf_model_path_value = "tiiuae/falcon-180B-chat"
|
||||
else:
|
||||
hf_model_path_value = (
|
||||
"tiiuae/falcon-" + args.falcon_variant_to_use + "-instruct"
|
||||
)
|
||||
|
||||
falcon = Falcon(
|
||||
model_name="falcon_" + args.falcon_variant_to_use,
|
||||
hf_model_path=hf_model_path_value,
|
||||
"falcon_" + args.falcon_variant_to_use,
|
||||
hf_model_path="tiiuae/falcon-"
|
||||
+ args.falcon_variant_to_use
|
||||
+ "-instruct",
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
falcon_mlir_path=falcon_mlir_path,
|
||||
@@ -536,11 +500,7 @@ if __name__ == "__main__":
|
||||
prompt = input("Please enter the prompt text: ")
|
||||
print("\nPrompt Text: ", prompt)
|
||||
|
||||
prompt_template = f"""A helpful assistant who helps the user with any questions asked.
|
||||
User: {prompt}
|
||||
Assistant:"""
|
||||
|
||||
res_str = falcon.generate(prompt_template)
|
||||
res_str = falcon.generate(prompt)
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
print(
|
||||
|
||||
@@ -8,7 +8,7 @@ from shark.shark_downloader import download_public_file
|
||||
|
||||
# expects a Path / str as arg
|
||||
# returns None if path not found or SharkInference module
|
||||
def get_vmfb_from_path(vmfb_path, device, mlir_dialect, device_id=None):
|
||||
def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
|
||||
if not isinstance(vmfb_path, Path):
|
||||
vmfb_path = Path(vmfb_path)
|
||||
|
||||
@@ -20,7 +20,7 @@ def get_vmfb_from_path(vmfb_path, device, mlir_dialect, device_id=None):
|
||||
print("Loading vmfb from: ", vmfb_path)
|
||||
print("Device from get_vmfb_from_path - ", device)
|
||||
shark_module = SharkInference(
|
||||
None, device=device, mlir_dialect=mlir_dialect, device_idx=device_id
|
||||
None, device=device, mlir_dialect=mlir_dialect
|
||||
)
|
||||
shark_module.load_module(vmfb_path)
|
||||
print("Successfully loaded vmfb")
|
||||
@@ -28,13 +28,7 @@ def get_vmfb_from_path(vmfb_path, device, mlir_dialect, device_id=None):
|
||||
|
||||
|
||||
def get_vmfb_from_config(
|
||||
shark_container,
|
||||
model,
|
||||
precision,
|
||||
device,
|
||||
vmfb_path,
|
||||
padding=None,
|
||||
device_id=None,
|
||||
shark_container, model, precision, device, vmfb_path, padding=None
|
||||
):
|
||||
vmfb_url = (
|
||||
f"gs://shark_tank/{shark_container}/{model}_{precision}_{device}"
|
||||
@@ -43,6 +37,4 @@ def get_vmfb_from_config(
|
||||
vmfb_url = vmfb_url + f"_{padding}"
|
||||
vmfb_url = vmfb_url + ".vmfb"
|
||||
download_public_file(vmfb_url, vmfb_path.absolute(), single_file=True)
|
||||
return get_vmfb_from_path(
|
||||
vmfb_path, device, "tm_tensor", device_id=device_id
|
||||
)
|
||||
return get_vmfb_from_path(vmfb_path, device, "tm_tensor")
|
||||
|
||||
@@ -52,7 +52,6 @@ datas += collect_data_files("jsonschema")
|
||||
datas += collect_data_files("jsonschema_specifications")
|
||||
datas += collect_data_files("cpuinfo")
|
||||
datas += collect_data_files("langchain")
|
||||
datas += collect_data_files("cv2")
|
||||
datas += [
|
||||
("src/utils/resources/prompts.json", "resources"),
|
||||
("src/utils/resources/model_db.json", "resources"),
|
||||
@@ -74,9 +73,6 @@ datas += [
|
||||
# hidden imports for pyinstaller
|
||||
hiddenimports = ["shark", "shark.shark_inference", "apps"]
|
||||
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
|
||||
hiddenimports += [
|
||||
x for x in collect_submodules("diffusers") if "tests" not in x
|
||||
]
|
||||
blacklist = ["tests", "convert"]
|
||||
hiddenimports += [
|
||||
x
|
||||
|
||||
@@ -273,7 +273,6 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
use_stencil,
|
||||
resample_type,
|
||||
):
|
||||
# Control Embedding check & conversion
|
||||
# TODO: 1. Change `num_images_per_prompt`.
|
||||
|
||||
@@ -458,14 +458,6 @@ p.add_argument(
|
||||
help="Specify your own huggingface authentication tokens for models like Llama2.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--device_allocator_heap_key",
|
||||
type=str,
|
||||
default="",
|
||||
help="Specify heap key for device caching allocator."
|
||||
"Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count"
|
||||
"Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)",
|
||||
)
|
||||
##############################################################################
|
||||
# IREE - Vulkan supported flags
|
||||
##############################################################################
|
||||
|
||||
@@ -184,18 +184,12 @@ def compile_through_fx(
|
||||
|
||||
|
||||
def set_iree_runtime_flags():
|
||||
# TODO: This function should be device-agnostic and piped properly
|
||||
# to general runtime driver init.
|
||||
vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
|
||||
if args.enable_rgp:
|
||||
vulkan_runtime_flags += [
|
||||
f"--enable_rgp=true",
|
||||
f"--vulkan_debug_utils=true",
|
||||
]
|
||||
if args.device_allocator_heap_key:
|
||||
vulkan_runtime_flags += [
|
||||
f"--device_allocator=caching:device_local={args.device_allocator_heap_key}",
|
||||
]
|
||||
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
|
||||
|
||||
|
||||
@@ -476,18 +470,7 @@ def get_available_devices():
|
||||
set_iree_runtime_flags()
|
||||
|
||||
available_devices = []
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
get_all_vulkan_devices,
|
||||
)
|
||||
|
||||
vulkaninfo_list = get_all_vulkan_devices()
|
||||
vulkan_devices = []
|
||||
id = 0
|
||||
for device in vulkaninfo_list:
|
||||
vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
|
||||
id += 1
|
||||
if id != 0:
|
||||
print(f"vulkan devices are available.")
|
||||
vulkan_devices = get_devices_by_name("vulkan")
|
||||
available_devices.extend(vulkan_devices)
|
||||
metal_devices = get_devices_by_name("metal")
|
||||
available_devices.extend(metal_devices)
|
||||
@@ -594,7 +577,7 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
|
||||
)
|
||||
num_in_channels = 9 if is_inpaint else 4
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path_or_dict=custom_weights,
|
||||
checkpoint_path=custom_weights,
|
||||
extract_ema=extract_ema,
|
||||
from_safetensors=from_safetensors,
|
||||
num_in_channels=num_in_channels,
|
||||
@@ -844,8 +827,6 @@ def clear_all():
|
||||
elif os.name == "unix":
|
||||
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
|
||||
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
|
||||
if args.local_tank_cache != "":
|
||||
shutil.rmtree(args.local_tank_cache)
|
||||
|
||||
|
||||
def get_generated_imgs_path() -> Path:
|
||||
|
||||
@@ -156,9 +156,9 @@ if __name__ == "__main__":
|
||||
upscaler_sendto_img2img,
|
||||
upscaler_sendto_inpaint,
|
||||
upscaler_sendto_outpaint,
|
||||
# lora_train_web,
|
||||
# model_web,
|
||||
# model_config_web,
|
||||
lora_train_web,
|
||||
model_web,
|
||||
model_config_web,
|
||||
hf_models,
|
||||
modelmanager_sendto_txt2img,
|
||||
modelmanager_sendto_img2img,
|
||||
@@ -250,16 +250,16 @@ if __name__ == "__main__":
|
||||
upscaler_status,
|
||||
]
|
||||
)
|
||||
# with gr.TabItem(label="Model Manager", id=6):
|
||||
# model_web.render()
|
||||
# with gr.TabItem(label="LoRA Training (Experimental)", id=7):
|
||||
# lora_train_web.render()
|
||||
with gr.TabItem(label="Chat Bot", id=8):
|
||||
with gr.TabItem(label="Model Manager", id=6):
|
||||
model_web.render()
|
||||
with gr.TabItem(label="LoRA Training (Experimental)", id=7):
|
||||
lora_train_web.render()
|
||||
with gr.TabItem(label="Chat Bot (Experimental)", id=8):
|
||||
stablelm_chat.render()
|
||||
# with gr.TabItem(
|
||||
# label="Generate Sharding Config (Experimental)", id=9
|
||||
# ):
|
||||
# model_config_web.render()
|
||||
with gr.TabItem(
|
||||
label="Generate Sharding Config (Experimental)", id=9
|
||||
):
|
||||
model_config_web.render()
|
||||
with gr.TabItem(label="MultiModal (Experimental)", id=10):
|
||||
minigpt4_web.render()
|
||||
# with gr.TabItem(label="DocuChat Upload", id=11):
|
||||
|
||||
@@ -8,7 +8,7 @@ from transformers import (
|
||||
from apps.stable_diffusion.web.ui.utils import available_devices
|
||||
from datetime import datetime as dt
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
|
||||
|
||||
def user(message, history):
|
||||
@@ -69,28 +69,25 @@ start_message = {
|
||||
def create_prompt(model_name, history):
|
||||
system_message = start_message[model_name]
|
||||
|
||||
if "llama2" in model_name:
|
||||
B_INST, E_INST = "[INST]", "[/INST]"
|
||||
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
||||
conversation = "".join(
|
||||
[f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]]
|
||||
)
|
||||
msg = f"{B_INST} {B_SYS} {system_message} {E_SYS} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
|
||||
elif model_name in ["vicuna"]:
|
||||
if model_name in [
|
||||
"vicuna",
|
||||
"llama2_7b",
|
||||
"llama2_13b",
|
||||
"llama2_70b",
|
||||
]:
|
||||
conversation = "".join(
|
||||
[
|
||||
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
|
||||
for item in history
|
||||
]
|
||||
)
|
||||
msg = system_message + conversation
|
||||
msg = msg.strip()
|
||||
else:
|
||||
conversation = "".join(
|
||||
["".join([item[0], item[1]]) for item in history]
|
||||
)
|
||||
msg = system_message + conversation
|
||||
msg = msg.strip()
|
||||
|
||||
msg = system_message + conversation
|
||||
msg = msg.strip()
|
||||
return msg
|
||||
|
||||
|
||||
@@ -143,7 +140,6 @@ def chat(
|
||||
global model_vmfb_key
|
||||
global vicuna_model
|
||||
|
||||
device_id = None
|
||||
model_name, model_path = list(map(str.strip, model.split("=>")))
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
@@ -152,7 +148,6 @@ def chat(
|
||||
elif "task" in device:
|
||||
device = "cpu-task"
|
||||
elif "vulkan" in device:
|
||||
device_id = int(device.split("://")[1])
|
||||
device = "vulkan"
|
||||
elif "rocm" in device:
|
||||
device = "rocm"
|
||||
@@ -163,53 +158,18 @@ def chat(
|
||||
from apps.language_models.scripts.vicuna import UnshardedVicuna
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{device_id}#{precision}#{download_vmfb}"
|
||||
if vicuna_model is None or new_model_vmfb_key != model_vmfb_key:
|
||||
new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{precision}"
|
||||
if new_model_vmfb_key != model_vmfb_key:
|
||||
model_vmfb_key = new_model_vmfb_key
|
||||
max_toks = 128 if model_name == "codegen" else 512
|
||||
|
||||
# get iree flags that need to be overridden, from commandline args
|
||||
_extra_args = []
|
||||
# vulkan target triple
|
||||
vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
get_all_vulkan_devices,
|
||||
get_vulkan_target_triple,
|
||||
)
|
||||
|
||||
if device == "vulkan":
|
||||
vulkaninfo_list = get_all_vulkan_devices()
|
||||
if vulkan_target_triple == "":
|
||||
# We already have the device_id extracted via WebUI, so we directly use
|
||||
# that to find the target triple.
|
||||
vulkan_target_triple = get_vulkan_target_triple(
|
||||
vulkaninfo_list[device_id]
|
||||
)
|
||||
if args.iree_vulkan_target_triple != "":
|
||||
_extra_args.append(
|
||||
f"-iree-vulkan-target-triple={vulkan_target_triple}"
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
if "rdna" in vulkan_target_triple:
|
||||
flags_to_add = [
|
||||
"--iree-spirv-index-bits=64",
|
||||
]
|
||||
_extra_args = _extra_args + flags_to_add
|
||||
|
||||
if device_id is None:
|
||||
id = 0
|
||||
for device in vulkaninfo_list:
|
||||
target_triple = get_vulkan_target_triple(
|
||||
vulkaninfo_list[id]
|
||||
)
|
||||
if target_triple == vulkan_target_triple:
|
||||
device_id = id
|
||||
break
|
||||
id += 1
|
||||
|
||||
assert (
|
||||
device_id
|
||||
), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
|
||||
|
||||
print(f"Will use target triple : {vulkan_target_triple}")
|
||||
|
||||
if model_name == "vicuna4":
|
||||
vicuna_model = ShardedVicuna(
|
||||
@@ -228,47 +188,32 @@ def chat(
|
||||
hf_model_path=model_path,
|
||||
hf_auth_token=args.hf_auth_token,
|
||||
device=device,
|
||||
vulkan_target_triple=vulkan_target_triple,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
download_vmfb=download_vmfb,
|
||||
load_mlir_from_shark_tank=True,
|
||||
extra_args_cmd=_extra_args,
|
||||
device_id=device_id,
|
||||
)
|
||||
|
||||
if vicuna_model is None:
|
||||
sys.exit("Unable to instantiate the model object, exiting.")
|
||||
|
||||
prompt = create_prompt(model_name, history)
|
||||
|
||||
partial_text = ""
|
||||
token_count = 0
|
||||
total_time_ms = 0.001 # In order to avoid divide by zero error
|
||||
prefill_time = 0
|
||||
is_first = True
|
||||
for text, msg, exec_time in progress.tqdm(
|
||||
count = 0
|
||||
start_time = time.time()
|
||||
for text, msg in progress.tqdm(
|
||||
vicuna_model.generate(prompt, cli=cli),
|
||||
desc="generating response",
|
||||
):
|
||||
if msg is None:
|
||||
if is_first:
|
||||
prefill_time = exec_time
|
||||
is_first = False
|
||||
else:
|
||||
total_time_ms += exec_time
|
||||
token_count += 1
|
||||
count += 1
|
||||
if "formatted" in msg:
|
||||
history[-1][1] = text
|
||||
end_time = time.time()
|
||||
tokens_per_sec = count / (end_time - start_time)
|
||||
yield history, str(format(tokens_per_sec, ".2f")) + " tokens/sec"
|
||||
else:
|
||||
partial_text += text + " "
|
||||
history[-1][1] = partial_text
|
||||
yield history, f"Prefill: {prefill_time:.2f}"
|
||||
elif "formatted" in msg:
|
||||
history[-1][1] = text
|
||||
tokens_per_sec = (token_count / total_time_ms) * 1000
|
||||
yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec"
|
||||
else:
|
||||
sys.exit(
|
||||
"unexpected message from the vicuna generate call, exiting."
|
||||
)
|
||||
yield history, ""
|
||||
|
||||
return history, ""
|
||||
|
||||
@@ -306,7 +251,6 @@ def llm_chat_api(InputData: dict):
|
||||
UnshardedVicuna,
|
||||
)
|
||||
|
||||
device_id = None
|
||||
if vicuna_model == 0:
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
@@ -315,7 +259,6 @@ def llm_chat_api(InputData: dict):
|
||||
elif "task" in device:
|
||||
device = "cpu-task"
|
||||
elif "vulkan" in device:
|
||||
device_id = int(device.split("://")[1])
|
||||
device = "vulkan"
|
||||
else:
|
||||
print("unrecognized device")
|
||||
@@ -328,7 +271,6 @@ def llm_chat_api(InputData: dict):
|
||||
max_num_tokens=max_toks,
|
||||
download_vmfb=True,
|
||||
load_mlir_from_shark_tank=True,
|
||||
device_id=device_id,
|
||||
)
|
||||
|
||||
# TODO: add role dict for different models
|
||||
@@ -410,20 +352,21 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value="int4",
|
||||
value="int8",
|
||||
choices=[
|
||||
"int4",
|
||||
"int8",
|
||||
"fp16",
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
tokens_time = gr.Textbox(label="Tokens generated per second")
|
||||
download_vmfb = gr.Checkbox(
|
||||
label="Download vmfb from Shark tank if available",
|
||||
value=True,
|
||||
interactive=True,
|
||||
visible=True,
|
||||
)
|
||||
with gr.Column():
|
||||
download_vmfb = gr.Checkbox(
|
||||
label="Download vmfb from Shark tank if available",
|
||||
value=True,
|
||||
interactive=True,
|
||||
)
|
||||
tokens_time = gr.Textbox(label="Tokens generated per second")
|
||||
|
||||
with gr.Row(visible=False):
|
||||
with gr.Group():
|
||||
@@ -455,11 +398,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
)
|
||||
|
||||
submit_event = msg.submit(
|
||||
fn=user,
|
||||
inputs=[msg, chatbot],
|
||||
outputs=[msg, chatbot],
|
||||
show_progress=False,
|
||||
queue=False,
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[
|
||||
@@ -472,15 +411,10 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
config_file,
|
||||
],
|
||||
outputs=[chatbot, tokens_time],
|
||||
show_progress=False,
|
||||
queue=True,
|
||||
)
|
||||
submit_click_event = submit.click(
|
||||
fn=user,
|
||||
inputs=[msg, chatbot],
|
||||
outputs=[msg, chatbot],
|
||||
show_progress=False,
|
||||
queue=False,
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[
|
||||
@@ -493,7 +427,6 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
config_file,
|
||||
],
|
||||
outputs=[chatbot, tokens_time],
|
||||
show_progress=False,
|
||||
queue=True,
|
||||
)
|
||||
stop.click(
|
||||
|
||||
192
inference/CMakeLists.txt
Normal file
192
inference/CMakeLists.txt
Normal file
@@ -0,0 +1,192 @@
|
||||
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cmake_minimum_required(VERSION 3.17)
|
||||
|
||||
project(sharkbackend LANGUAGES C CXX)
|
||||
|
||||
#
|
||||
# Options
|
||||
#
|
||||
|
||||
option(TRITON_ENABLE_GPU "Enable GPU support in backend" ON)
|
||||
option(TRITON_ENABLE_STATS "Include statistics collections in backend" ON)
|
||||
|
||||
set(TRITON_COMMON_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/common repo")
|
||||
set(TRITON_CORE_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/core repo")
|
||||
set(TRITON_BACKEND_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/backend repo")
|
||||
|
||||
if(NOT CMAKE_BUILD_TYPE)
|
||||
set(CMAKE_BUILD_TYPE Release)
|
||||
endif()
|
||||
|
||||
#
|
||||
# Dependencies
|
||||
#
|
||||
# FetchContent requires us to include the transitive closure of all
|
||||
# repos that we depend on so that we can override the tags.
|
||||
#
|
||||
include(FetchContent)
|
||||
|
||||
FetchContent_Declare(
|
||||
repo-common
|
||||
GIT_REPOSITORY https://github.com/triton-inference-server/common.git
|
||||
GIT_TAG ${TRITON_COMMON_REPO_TAG}
|
||||
GIT_SHALLOW ON
|
||||
)
|
||||
FetchContent_Declare(
|
||||
repo-core
|
||||
GIT_REPOSITORY https://github.com/triton-inference-server/core.git
|
||||
GIT_TAG ${TRITON_CORE_REPO_TAG}
|
||||
GIT_SHALLOW ON
|
||||
)
|
||||
FetchContent_Declare(
|
||||
repo-backend
|
||||
GIT_REPOSITORY https://github.com/triton-inference-server/backend.git
|
||||
GIT_TAG ${TRITON_BACKEND_REPO_TAG}
|
||||
GIT_SHALLOW ON
|
||||
)
|
||||
FetchContent_MakeAvailable(repo-common repo-core repo-backend)
|
||||
|
||||
#
|
||||
# The backend must be built into a shared library. Use an ldscript to
|
||||
# hide all symbols except for the TRITONBACKEND API.
|
||||
#
|
||||
configure_file(src/libtriton_dshark.ldscript libtriton_dshark.ldscript COPYONLY)
|
||||
|
||||
add_library(
|
||||
triton-dshark-backend SHARED
|
||||
src/dshark.cc
|
||||
#src/dshark_driver_module.c
|
||||
)
|
||||
|
||||
add_library(
|
||||
SharkBackend::triton-dshark-backend ALIAS triton-dshark-backend
|
||||
)
|
||||
|
||||
target_include_directories(
|
||||
triton-dshark-backend
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/src
|
||||
)
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${PROJECT_BINARY_DIR}/lib/cmake/mlir")
|
||||
|
||||
add_subdirectory(thirdparty/srt EXCLUDE_FROM_ALL)
|
||||
|
||||
target_link_libraries(triton-dshark-backend PRIVATE iree_base_base
|
||||
iree_hal_hal
|
||||
iree_hal_cuda_cuda
|
||||
iree_hal_cuda_registration_registration
|
||||
iree_hal_vmvx_registration_registration
|
||||
iree_hal_dylib_registration_registration
|
||||
iree_modules_hal_hal
|
||||
iree_vm_vm
|
||||
iree_vm_bytecode_module
|
||||
iree_hal_local_loaders_system_library_loader
|
||||
iree_hal_local_loaders_vmvx_module_loader
|
||||
)
|
||||
|
||||
target_compile_features(triton-dshark-backend PRIVATE cxx_std_11)
|
||||
|
||||
|
||||
target_link_libraries(
|
||||
triton-dshark-backend
|
||||
PRIVATE
|
||||
triton-core-serverapi # from repo-core
|
||||
triton-core-backendapi # from repo-core
|
||||
triton-core-serverstub # from repo-core
|
||||
triton-backend-utils # from repo-backend
|
||||
)
|
||||
|
||||
if(WIN32)
|
||||
set_target_properties(
|
||||
triton-dshark-backend PROPERTIES
|
||||
POSITION_INDEPENDENT_CODE ON
|
||||
OUTPUT_NAME triton_dshark
|
||||
)
|
||||
else()
|
||||
set_target_properties(
|
||||
triton-dshark-backend PROPERTIES
|
||||
POSITION_INDEPENDENT_CODE ON
|
||||
OUTPUT_NAME triton_dshark
|
||||
LINK_DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libtriton_dshark.ldscript
|
||||
LINK_FLAGS "-Wl,--version-script libtriton_dshark.ldscript"
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
#
|
||||
# Install
|
||||
#
|
||||
include(GNUInstallDirs)
|
||||
set(INSTALL_CONFIGDIR ${CMAKE_INSTALL_LIBDIR}/cmake/SharkBackend)
|
||||
|
||||
install(
|
||||
TARGETS
|
||||
triton-dshark-backend
|
||||
EXPORT
|
||||
triton-dshark-backend-targets
|
||||
LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/dshark
|
||||
RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/dshark
|
||||
)
|
||||
|
||||
install(
|
||||
EXPORT
|
||||
triton-dshark-backend-targets
|
||||
FILE
|
||||
SharkBackendTargets.cmake
|
||||
NAMESPACE
|
||||
SharkBackend::
|
||||
DESTINATION
|
||||
${INSTALL_CONFIGDIR}
|
||||
)
|
||||
|
||||
include(CMakePackageConfigHelpers)
|
||||
configure_package_config_file(
|
||||
${CMAKE_CURRENT_LIST_DIR}/cmake/SharkBackendConfig.cmake.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/SharkBackendConfig.cmake
|
||||
INSTALL_DESTINATION ${INSTALL_CONFIGDIR}
|
||||
)
|
||||
|
||||
install(
|
||||
FILES
|
||||
${CMAKE_CURRENT_BINARY_DIR}/SharkBackendConfig.cmake
|
||||
DESTINATION ${INSTALL_CONFIGDIR}
|
||||
)
|
||||
|
||||
#
|
||||
# Export from build tree
|
||||
#
|
||||
export(
|
||||
EXPORT triton-dshark-backend-targets
|
||||
FILE ${CMAKE_CURRENT_BINARY_DIR}/SharkBackendTargets.cmake
|
||||
NAMESPACE SharkBackend::
|
||||
)
|
||||
|
||||
export(PACKAGE SharkBackend)
|
||||
|
||||
100
inference/README.md
Normal file
100
inference/README.md
Normal file
@@ -0,0 +1,100 @@
|
||||
# SHARK Triton Backend
|
||||
|
||||
The triton backend for shark.
|
||||
|
||||
# Build
|
||||
|
||||
Install SHARK
|
||||
|
||||
```
|
||||
git clone https://github.com/nod-ai/SHARK.git
|
||||
# skip above step if dshark is already installed
|
||||
cd SHARK/inference
|
||||
```
|
||||
|
||||
install dependancies
|
||||
|
||||
```
|
||||
apt-get install patchelf rapidjson-dev python3-dev
|
||||
git submodule update --init
|
||||
```
|
||||
|
||||
update the submodules of iree
|
||||
|
||||
```
|
||||
cd thirdparty/srt
|
||||
git submodule update --init
|
||||
```
|
||||
|
||||
Next, make the backend and install it
|
||||
|
||||
```
|
||||
cd ../..
|
||||
mkdir build && cd build
|
||||
cmake -DTRITON_ENABLE_GPU=ON \
|
||||
-DIREE_HAL_DRIVER_CUDA=ON \
|
||||
-DIREE_TARGET_BACKEND_CUDA=ON \
|
||||
-DMLIR_ENABLE_CUDA_RUNNER=ON \
|
||||
-DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install \
|
||||
-DTRITON_BACKEND_REPO_TAG=r22.02 \
|
||||
-DTRITON_CORE_REPO_TAG=r22.02 \
|
||||
-DTRITON_COMMON_REPO_TAG=r22.02 ..
|
||||
make install
|
||||
```
|
||||
|
||||
# Incorporating into Triton
|
||||
|
||||
There are much more in depth explenations for the following steps in triton's documentation:
|
||||
https://github.com/triton-inference-server/server/blob/main/docs/compose.md#triton-with-unsupported-and-custom-backends
|
||||
|
||||
There should be a file at /build/install/backends/dshark/libtriton_dshark.so. You will need to copy it into your triton server image.
|
||||
More documentation is in the link above, but to create the docker image, you need to run the compose.py command in the triton-backend server repo
|
||||
|
||||
|
||||
To first build your image, clone the tritonserver repo.
|
||||
|
||||
```
|
||||
git clone https://github.com/triton-inference-server/server.git
|
||||
```
|
||||
|
||||
then run `compose.py` to build a docker compose file
|
||||
```
|
||||
cd server
|
||||
python3 compose.py --repoagent checksum --dry-run
|
||||
```
|
||||
|
||||
Because dshark is a third party backend, you will need to manually modify the `Dockerfile.compose` to include the dshark backend. To do this, in the Dockerfile.compose file produced, copy this line.
|
||||
the dshark backend will be located in the build folder from earlier under `/build/install/backends`
|
||||
|
||||
```
|
||||
COPY /path/to/build/install/backends/dshark /opt/tritonserver/backends/dshark
|
||||
```
|
||||
|
||||
Next run
|
||||
```
|
||||
docker build -t tritonserver_custom -f Dockerfile.compose .
|
||||
docker run -it --gpus=1 --net=host -v/path/to/model_repos:/models tritonserver_custom:latest tritonserver --model-repository=/models
|
||||
```
|
||||
|
||||
where `path/to/model_repos` is where you are storing the models you want to run
|
||||
|
||||
if your not using gpus, omit `--gpus=1`
|
||||
|
||||
```
|
||||
docker run -it --net=host -v/path/to/model_repos:/models tritonserver_custom:latest tritonserver --model-repository=/models
|
||||
```
|
||||
|
||||
# Setting up a model
|
||||
|
||||
to include a model in your backend, add a directory with your model name to your model repository directory. examples of models can be seen here: https://github.com/triton-inference-server/backend/tree/main/examples/model_repos/minimal_models
|
||||
|
||||
make sure to adjust the input correctly in the config.pbtxt file, and save a vmfb file under 1/model.vmfb
|
||||
|
||||
# CUDA
|
||||
|
||||
if you're having issues with cuda, make sure your correct drivers are installed, and that `nvidia-smi` works, and also make sure that the nvcc compiler is on the path.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
39
inference/cmake/SharkBackendConfig.cmake.in
Normal file
39
inference/cmake/SharkBackendConfig.cmake.in
Normal file
@@ -0,0 +1,39 @@
|
||||
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
include(CMakeFindDependencyMacro)
|
||||
|
||||
get_filename_component(
|
||||
SHARKBACKEND_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH
|
||||
)
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH ${SHARKBACKEND_CMAKE_DIR})
|
||||
|
||||
if(NOT TARGET SharkBackend::triton-dshark-backend)
|
||||
include("${SHARKBACKEND_CMAKE_DIR}/SharkBackendTargets.cmake")
|
||||
endif()
|
||||
|
||||
set(SHARKBACKEND_LIBRARIES SharkBackend::triton-dshark-backend)
|
||||
1409
inference/src/dshark.cc
Normal file
1409
inference/src/dshark.cc
Normal file
File diff suppressed because it is too large
Load Diff
30
inference/src/libtriton_dshark.ldscript
Normal file
30
inference/src/libtriton_dshark.ldscript
Normal file
@@ -0,0 +1,30 @@
|
||||
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
{
|
||||
global:
|
||||
TRITONBACKEND_*;
|
||||
local: *;
|
||||
};
|
||||
1
inference/thirdparty/shark-runtime
vendored
Submodule
1
inference/thirdparty/shark-runtime
vendored
Submodule
Submodule inference/thirdparty/shark-runtime added at 7b82d90c72
@@ -16,7 +16,7 @@ iree-tools-tf
|
||||
# TensorFlow and JAX.
|
||||
gin-config
|
||||
tf-nightly
|
||||
keras-nightly
|
||||
keras
|
||||
#tf-models-nightly
|
||||
#tensorflow-text-nightly
|
||||
transformers
|
||||
|
||||
@@ -25,7 +25,7 @@ diffusers
|
||||
accelerate
|
||||
scipy
|
||||
ftfy
|
||||
gradio==3.44.3
|
||||
gradio
|
||||
altair
|
||||
omegaconf
|
||||
# 0.3.2 doesn't have binaries for arm64
|
||||
|
||||
@@ -300,7 +300,6 @@ def compile_module_to_flatbuffer(
|
||||
args += get_iree_common_args(debug=debug)
|
||||
args += get_model_specific_args()
|
||||
args += extra_args
|
||||
args += shark_args.additional_compile_args
|
||||
|
||||
if frontend in ["tensorflow", "tf"]:
|
||||
input_type = "auto"
|
||||
@@ -404,11 +403,6 @@ def load_vmfb_using_mmap(
|
||||
dl.log(f"mmap {flatbuffer_blob_or_path}")
|
||||
ctx = ireert.SystemContext(config=config)
|
||||
dl.log(f"ireert.SystemContext created")
|
||||
if "vulkan" in device:
|
||||
# Vulkan pipeline creation consumes significant amount of time.
|
||||
print(
|
||||
"\tCompiling Vulkan shaders. This may take a few minutes."
|
||||
)
|
||||
ctx.add_vm_module(mmaped_vmfb)
|
||||
dl.log(f"module initialized")
|
||||
mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
|
||||
@@ -503,9 +497,9 @@ def export_iree_module_to_vmfb(
|
||||
)
|
||||
module_name = f"{mlir_dialect}_{device_name}"
|
||||
filename = os.path.join(directory, module_name + ".vmfb")
|
||||
print(f"Saved vmfb in {filename}.")
|
||||
with open(filename, "wb") as f:
|
||||
f.write(flatbuffer_blob)
|
||||
print(f"Saved vmfb in {filename}.")
|
||||
return filename
|
||||
|
||||
|
||||
|
||||
@@ -116,7 +116,7 @@ def get_extensions(triple):
|
||||
]
|
||||
|
||||
if get_vendor(triple) == "NVIDIA" or arch == "rdna3":
|
||||
ext.append("VK_KHR_cooperative_matrix")
|
||||
ext.append("VK_NV_cooperative_matrix")
|
||||
if get_vendor(triple) == ["NVIDIA", "AMD", "Intel"]:
|
||||
ext.append("VK_KHR_shader_integer_dot_product")
|
||||
return make_ext_list(ext_list=ext)
|
||||
@@ -244,7 +244,7 @@ def get_vulkan_target_capabilities(triple):
|
||||
if arch == "rdna3":
|
||||
# TODO: Get scope value
|
||||
cap["coopmatCases"] = [
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, accSat = false, scope = #vk.scope<Subgroup>"
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>"
|
||||
]
|
||||
|
||||
if product == "rx5700xt":
|
||||
@@ -465,9 +465,9 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
cap["coopmatCases"] = [
|
||||
"mSize = 8, nSize = 8, kSize = 32, aType = i8, bType = i8, cType = i32, resultType = i32, accSat = false, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, accSat = false, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f32, resultType = f32, accSat = false, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 8, nSize = 8, kSize = 32, aType = i8, bType = i8, cType = i32, resultType = i32, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f32, resultType = f32, scope = #vk.scope<Subgroup>",
|
||||
]
|
||||
|
||||
elif arch == "adreno":
|
||||
@@ -528,7 +528,7 @@ def get_vulkan_target_capabilities(triple):
|
||||
cmc = ""
|
||||
for case in v:
|
||||
cmc += f"#vk.coop_matrix_props<{case}>, "
|
||||
res += f"cooperativeMatrixPropertiesKHR = [{cmc[:-2]}], "
|
||||
res += f"cooperativeMatrixPropertiesNV = [{cmc[:-2]}], "
|
||||
else:
|
||||
res += f"{k} = {get_comma_sep_str(v)}, "
|
||||
else:
|
||||
|
||||
@@ -23,19 +23,11 @@ from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
|
||||
from shark.parser import shark_args
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_all_vulkan_devices():
|
||||
from iree.runtime import get_driver
|
||||
|
||||
driver = get_driver("vulkan")
|
||||
device_list_src = driver.query_available_devices()
|
||||
device_list_src.sort(key=lambda d: d["path"])
|
||||
return [d["name"] for d in device_list_src]
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_vulkan_device_name(device_num=0):
|
||||
vulkaninfo_list = get_all_vulkan_devices()
|
||||
vulkaninfo_dump, _ = run_cmd("vulkaninfo")
|
||||
vulkaninfo_dump = vulkaninfo_dump.split(linesep)
|
||||
vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s]
|
||||
if len(vulkaninfo_list) == 0:
|
||||
raise ValueError("No device name found in VulkanInfo!")
|
||||
if len(vulkaninfo_list) > 1:
|
||||
@@ -186,7 +178,9 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):
|
||||
@functools.cache
|
||||
def get_iree_vulkan_runtime_flags():
|
||||
vulkan_runtime_flags = [
|
||||
f"--vulkan_large_heap_block_size={shark_args.vulkan_large_heap_block_size}",
|
||||
f"--vulkan_validation_layers={'true' if shark_args.vulkan_validation_layers else 'false'}",
|
||||
f"--vulkan_vma_allocator={'true' if shark_args.vulkan_vma_allocator else 'false'}",
|
||||
]
|
||||
return vulkan_runtime_flags
|
||||
|
||||
|
||||
@@ -14,21 +14,8 @@
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shlex
|
||||
import subprocess
|
||||
|
||||
|
||||
class SplitStrToListAction(argparse.Action):
|
||||
def __init__(self, option_strings, dest, *args, **kwargs):
|
||||
super(SplitStrToListAction, self).__init__(
|
||||
option_strings=option_strings, dest=dest, *args, **kwargs
|
||||
)
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
del parser, option_string
|
||||
setattr(namespace, self.dest, shlex.split(values[0]))
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="SHARK runner.")
|
||||
|
||||
parser.add_argument(
|
||||
@@ -37,13 +24,6 @@ parser.add_argument(
|
||||
default="cpu",
|
||||
help="Device on which shark_runner runs. options are cpu, cuda, and vulkan",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--additional_compile_args",
|
||||
default=list(),
|
||||
nargs=1,
|
||||
action=SplitStrToListAction,
|
||||
help="Additional arguments to pass to the compiler. These are appended as the last arguments.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_tf32",
|
||||
type=bool,
|
||||
@@ -153,6 +133,13 @@ parser.add_argument(
|
||||
help="Profiles vulkan device and collects the .rdc info.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vulkan_large_heap_block_size",
|
||||
default="2073741824",
|
||||
help="Flag for setting VMA preferredLargeHeapBlockSize for "
|
||||
"vulkan device, default is 4G.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vulkan_validation_layers",
|
||||
default=False,
|
||||
@@ -160,4 +147,11 @@ parser.add_argument(
|
||||
help="Flag for disabling vulkan validation layers when benchmarking.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vulkan_vma_allocator",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for enabling / disabling Vulkan VMA Allocator.",
|
||||
)
|
||||
|
||||
shark_args, unknown = parser.parse_known_args()
|
||||
|
||||
@@ -451,65 +451,6 @@ def transform_fx(fx_g, quantized=False):
|
||||
fx_g.graph.lint()
|
||||
|
||||
|
||||
def gptq_transforms(fx_g):
|
||||
import torch
|
||||
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
if node.target in [
|
||||
torch.ops.aten.arange,
|
||||
torch.ops.aten.empty,
|
||||
torch.ops.aten.ones,
|
||||
torch.ops.aten._to_copy,
|
||||
]:
|
||||
if node.kwargs.get("device") == torch.device(device="cuda:0"):
|
||||
updated_kwargs = node.kwargs.copy()
|
||||
updated_kwargs["device"] = torch.device(device="cpu")
|
||||
node.kwargs = updated_kwargs
|
||||
|
||||
if node.target in [
|
||||
torch.ops.aten._to_copy,
|
||||
]:
|
||||
if node.kwargs.get("dtype") == torch.bfloat16:
|
||||
updated_kwargs = node.kwargs.copy()
|
||||
updated_kwargs["dtype"] = torch.float16
|
||||
node.kwargs = updated_kwargs
|
||||
|
||||
# Inputs of aten.native_layer_norm should be upcasted to fp32.
|
||||
if node.target in [torch.ops.aten.native_layer_norm]:
|
||||
with fx_g.graph.inserting_before(node):
|
||||
new_node_arg0 = fx_g.graph.call_function(
|
||||
torch.ops.prims.convert_element_type,
|
||||
args=(node.args[0], torch.float32),
|
||||
kwargs={},
|
||||
)
|
||||
node.args = (
|
||||
new_node_arg0,
|
||||
node.args[1],
|
||||
node.args[2],
|
||||
node.args[3],
|
||||
node.args[4],
|
||||
)
|
||||
|
||||
# Downcasting the result of native_layer_norm back to fp16.
|
||||
if node.name.startswith("getitem"):
|
||||
with fx_g.graph.inserting_before(node):
|
||||
if node.args[0].target in [
|
||||
torch.ops.aten.native_layer_norm
|
||||
]:
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten._to_copy,
|
||||
args=(node,),
|
||||
kwargs={"dtype": torch.float32},
|
||||
)
|
||||
node.append(new_node)
|
||||
node.replace_all_uses_with(new_node)
|
||||
new_node.args = (node,)
|
||||
new_node.kwargs = {"dtype": torch.float32}
|
||||
|
||||
fx_g.graph.lint()
|
||||
|
||||
|
||||
# Doesn't replace the None type.
|
||||
def change_fx_graph_return_to_tuple(fx_g):
|
||||
for node in fx_g.graph.nodes:
|
||||
@@ -563,7 +504,6 @@ def import_with_fx(
|
||||
is_dynamic=False,
|
||||
tracing_required=False,
|
||||
precision="fp32",
|
||||
is_gptq=False,
|
||||
):
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
@@ -644,7 +584,7 @@ def import_with_fx(
|
||||
torch.ops.aten.index_add,
|
||||
torch.ops.aten.index_add_,
|
||||
]
|
||||
if precision in ["int4", "int8"] and not is_gptq:
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.export import (
|
||||
block_quant_layer_level_manager,
|
||||
)
|
||||
@@ -713,10 +653,6 @@ def import_with_fx(
|
||||
add_upcast(fx_g)
|
||||
fx_g.recompile()
|
||||
|
||||
if is_gptq:
|
||||
gptq_transforms(fx_g)
|
||||
fx_g.recompile()
|
||||
|
||||
if mlir_type == "fx":
|
||||
return fx_g
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ import collections
|
||||
import json
|
||||
import os
|
||||
import psutil
|
||||
import resource
|
||||
import time
|
||||
from typing import Tuple
|
||||
|
||||
@@ -167,7 +168,7 @@ def save_json(data, filename):
|
||||
|
||||
|
||||
def collect_huggingface_logits(
|
||||
model_name: str, max_seq_len: int, to_save_json: bool
|
||||
model_name: str, max_seq_len: int, save_json: bool
|
||||
) -> Tuple[float, float]:
|
||||
# Load
|
||||
t0 = time.time()
|
||||
@@ -193,11 +194,11 @@ def collect_huggingface_logits(
|
||||
for idx, tokens in enumerate(tokenized_prompts):
|
||||
print("prompt: {}".format(PROMPTS[idx]))
|
||||
logits = run_huggingface_model(model_wrapper, tokens)
|
||||
if to_save_json:
|
||||
if save_json:
|
||||
results.append([PROMPTS[idx], logits[0].tolist()])
|
||||
run_time = time.time() - t0
|
||||
print("--- Took {} seconds to run Huggingface.".format(run_time))
|
||||
if to_save_json:
|
||||
if save_json:
|
||||
save_json(results, "/tmp/huggingface.json")
|
||||
run_memory_info = get_memory_info()
|
||||
return {
|
||||
@@ -214,10 +215,7 @@ def collect_huggingface_logits(
|
||||
|
||||
|
||||
def collect_shark_logits(
|
||||
model_name: str,
|
||||
max_seq_len: int,
|
||||
recompile_shark: bool,
|
||||
to_save_json: bool,
|
||||
model_name: str, max_seq_len: int, recompile_shark: bool, save_json: bool
|
||||
) -> Tuple[float, float]:
|
||||
# Load
|
||||
t0 = time.time()
|
||||
@@ -248,11 +246,11 @@ def collect_shark_logits(
|
||||
print("prompt: {}".format(PROMPTS[idx]))
|
||||
logits = run_shark_model(model_wrapper, tokens)
|
||||
lst = [e.tolist() for e in logits]
|
||||
if to_save_json:
|
||||
if save_json:
|
||||
results.append([PROMPTS[idx], lst])
|
||||
run_time = time.time() - t0
|
||||
print("--- Took {} seconds to run Shark.".format(run_time))
|
||||
if to_save_json:
|
||||
if save_json:
|
||||
save_json(results, "/tmp/shark.json")
|
||||
platform_postfix = "-compile" if recompile_shark else "-precompiled"
|
||||
run_memory_info = get_memory_info()
|
||||
|
||||
Reference in New Issue
Block a user