mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-01-11 23:08:19 -05:00)

Compare commits: diffusers-… → fix-shardi… (6 commits)

- 6cefb450bf
- 0c0329699a
- 65bec26d76
- 7c1981b201
- 8f9e837d50
- e43876cff5
```diff
@@ -12,6 +12,8 @@ import sys
 import time
+from dataclasses import dataclass
+from os import environ
 
 import torch
 import torch_mlir
```
```diff
@@ -510,6 +512,8 @@ class ShardedVicuna(VicunaBase):
         n_devices=None,
     ) -> None:
         self.hf_auth_token = hf_auth_token
+        self.hidden_state_size_dict = {"vicuna": 4096, "llama2_7b": 4096, "llama2_13b": 5120, "llama2_70b": 8192}
+        self.n_layers_dict = {"vicuna": 32, "llama2_7b": 32, "llama2_13b": 40, "llama2_70b": 80}
         super().__init__(
             model_name,
             hf_model_path,
```
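These two tables replace hard-coded `4096`/`32` constants that recur through the rest of the diff; every placeholder shape and MLIR type string below is now derived from them. A minimal sketch of the lookup (model names as in the diff):

```python
hidden_state_size_dict = {"vicuna": 4096, "llama2_7b": 4096, "llama2_13b": 5120, "llama2_70b": 8192}
n_layers_dict = {"vicuna": 32, "llama2_7b": 32, "llama2_13b": 40, "llama2_70b": 80}

# e.g. llama2_13b: hidden size 5120, 40 decoder layers
model_name = "llama2_13b"
print(hidden_state_size_dict[model_name], n_layers_dict[model_name])  # 5120 40
```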
```diff
@@ -531,6 +535,11 @@ class ShardedVicuna(VicunaBase):
         self.dir_path.mkdir(parents=True, exist_ok=True)
         self.shark_model = self.compile(device=device)
 
+    def check_all_artifacts_present(self):
+        file_list = [f"{i}_full" for i in range(self.n_layers_dict[self.model_name])] + ["norm", "embedding", "lmhead"]
+        file_exists_list = [Path(f"{self.dir_name}/{x}.vmfb").exists() or Path(f"{self.dir_name}/{x}.mlir").exists() for x in file_list]
+        return all(file_exists_list)
+
     def get_tokenizer(self):
         kwargs = {}
         if "llama2" in self.model_name:
```
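For illustration, a standalone sketch of the artifact check introduced above, assuming a llama2_13b shard directory; the file names mirror the diff (`{i}_full`, `norm`, `embedding`, `lmhead`):

```python
from pathlib import Path

# Hypothetical free-function version of check_all_artifacts_present();
# n_layers mirrors n_layers_dict from the diff ("llama2_13b": 40).
def all_artifacts_present(dir_name: str, n_layers: int) -> bool:
    file_list = [f"{i}_full" for i in range(n_layers)] + ["norm", "embedding", "lmhead"]
    # A shard is usable if either its compiled .vmfb or its .mlir source exists.
    return all(
        Path(f"{dir_name}/{x}.vmfb").exists() or Path(f"{dir_name}/{x}.mlir").exists()
        for x in file_list
    )

print(all_artifacts_present("llama2_13b_shards", 40))
```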
```diff
@@ -711,6 +720,27 @@ class ShardedVicuna(VicunaBase):
         device_idx = max(idx_votes, key=idx_votes.get)
         return device_idx
 
+    def write_dynamic_inputs_lmhead(self, ir, sample_input_length):
+        if self.precision in ["fp16", "int4"]:
+            precision_str = "f16"
+        else:
+            precision_str = "f32"
+        lines = ir.splitlines()
+        new_lines = []
+        for line in lines:
+            if f"%cst_0 =" in line:
+                new_lines.append(line)
+                new_lines.append("%c1 = arith.constant 1 : index")
+                new_lines.append(f"%dim = tensor.dim %arg0, %c1 : tensor<1x?x{self.hidden_state_size_dict[self.model_name]}x{precision_str}>")
+            else:
+                line = re.sub(f"{sample_input_length}x", "?x", line)
+                if "?x" in line:
+                    line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
+                new_lines.append(line)
+
+        return "\n".join(new_lines)
+
     def compile_lmhead(
         self,
         lmh,
```
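The new method turns an lmhead module traced at a fixed sequence length into one with a dynamic dimension by plain text substitution on the IR. A toy demonstration of the two regexes (the MLIR line is invented; 137 is the sample length the diff passes in):

```python
import re

sample_input_length = 137
line = "%0 = tensor.empty() : tensor<1x137x5120xf16>"  # invented example line

# Replace the static sample length with MLIR's dynamic-dim marker.
line = re.sub(f"{sample_input_length}x", "?x", line)
if "?x" in line:
    # tensor.empty of a dynamic tensor must receive the dim value; %dim is
    # the tensor.dim the method injects next to %cst_0.
    line = re.sub(r"tensor\.empty\(\)", "tensor.empty(%dim)", line)

print(line)  # %0 = tensor.empty(%dim) : tensor<1x?x5120xf16>
```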
```diff
@@ -775,14 +805,21 @@ class ShardedVicuna(VicunaBase):
             use_tracing=False,
             verbose=False,
         )
 
+        """
         bytecode_stream = BytesIO()
         module.operation.write_bytecode(bytecode_stream)
         bytecode = bytecode_stream.getvalue()
         f_ = open(mlir_path, "wb")
         f_.write(bytecode)
         f_.close()
+        """
+        module = str(module)
+        if self.precision in ["int4", "fp16"]:
+            module = self.write_dynamic_inputs_lmhead(module, 137)
         filepath = Path(f"{self.dir_name}/lmhead.mlir")
+        f_ = open(mlir_path, "w+")
+        f_.write(module)
+        f_.close()
         # download_public_file(
         #     "gs://shark_tank/elias/compressed_sv/lmhead.mlir",
         #     filepath.absolute(),
```
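The direct bytecode write is stubbed out because the lmhead IR now has to pass through write_dynamic_inputs_lmhead, which operates on the module as text; the module is therefore stringified, rewritten for fp16/int4, and saved as textual MLIR. As in the diff, the open() call still targets mlir_path even though filepath is computed immediately above.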
```diff
@@ -795,7 +832,7 @@ class ShardedVicuna(VicunaBase):
             device=device,
             mlir_dialect="tm_tensor",
             device_idx=device_idx,
-            mmap=False,
+            mmap=True,
         )
         if vmfb_path.exists():
             shark_module.load_module(vmfb_path)
```
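The same mmap=False to mmap=True flip appears on every SharkInference constructed in this change: the compiled .vmfb is memory-mapped (see load_vmfb_using_mmap at the bottom of the diff) instead of read into a buffer, which matters when a sharded model creates one module per layer.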
```diff
@@ -883,7 +920,7 @@ class ShardedVicuna(VicunaBase):
             device=device,
             mlir_dialect="tm_tensor",
             device_idx=device_idx,
-            mmap=False,
+            mmap=True,
         )
         if vmfb_path.exists():
             shark_module.load_module(vmfb_path)
```
```diff
@@ -964,7 +1001,7 @@ class ShardedVicuna(VicunaBase):
             device=device,
             mlir_dialect="tm_tensor",
             device_idx=device_idx,
-            mmap=False,
+            mmap=True,
         )
         if vmfb_path.exists():
             shark_module.load_module(vmfb_path)
```
```diff
@@ -1160,7 +1197,7 @@ class ShardedVicuna(VicunaBase):
                 )
                 if device_idx is None:
                     if self.n_devices is not None:
-                        device_idx = idx % self.n_devices
+                        device_idx = (idx * self.n_devices) // self.n_layers_dict[self.model_name]
                     else:
                         device_idx = None
                 module = SharkInference(
```
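The replaced line changes layer placement from round-robin to contiguous blocks. A worked sketch with assumed counts (40 layers on 4 devices, i.e. the llama2_13b entry above):

```python
n_layers, n_devices = 40, 4  # assumed: llama2_13b sharded across 4 devices

round_robin = [idx % n_devices for idx in range(n_layers)]
contiguous = [(idx * n_devices) // n_layers for idx in range(n_layers)]

print(round_robin[:8])  # [0, 1, 2, 3, 0, 1, 2, 3]  adjacent layers hop devices
print(contiguous[:12])  # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1]  layers 0-9 stay on device 0
```

Keeping consecutive layers on one device confines cross-device traffic to block boundaries, which the breakpoint logic further down exploits.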
```diff
@@ -1168,7 +1205,7 @@ class ShardedVicuna(VicunaBase):
                     device=device,
                     device_idx=device_idx,
                     mlir_dialect="tm_tensor",
-                    mmap=False,
+                    mmap=True,
                 )
                 module.load_module(vmfb_path)
             else:
```
```diff
@@ -1178,15 +1215,15 @@ class ShardedVicuna(VicunaBase):
                 )
                 if device_idx is None:
                     if self.n_devices is not None:
-                        device_idx = idx % self.n_devices
+                        device_idx = (idx * self.n_devices) // self.n_layers_dict[self.model_name]
                     else:
-                        device_idx = 0
+                        device_idx = None
                 module = SharkInference(
                     mlirs[idx],
                     device=device,
                     device_idx=device_idx,
                     mlir_dialect="tm_tensor",
-                    mmap=False,
+                    mmap=True,
                 )
                 module.save_module(
                     module_name=f"{self.dir_name}/{idx}_full",
```
```diff
@@ -1238,7 +1275,7 @@ class ShardedVicuna(VicunaBase):
                 if self.n_devices is not None:
                     device_idx = idx % self.n_devices
                 else:
-                    device_idx = 0
+                    device_idx = None
                 module = SharkInference(
                     None,
                     device=device,
```
```diff
@@ -1256,7 +1293,7 @@ class ShardedVicuna(VicunaBase):
                 if self.n_devices is not None:
                     device_idx = idx % self.n_devices
                 else:
-                    device_idx = 0
+                    device_idx = None
                 module = SharkInference(
                     mlirs[idx],
                     device=device,
```
```diff
@@ -1289,72 +1326,79 @@ class ShardedVicuna(VicunaBase):
         )
 
         if self.precision in ["int4", "int8"]:
-            from brevitas_examples.common.generative.quantize import (
-                quantize_model,
-            )
-            from brevitas_examples.llm.llm_quant.run_utils import (
-                get_model_impl,
-            )
-
-            print("Applying weight quantization..")
-            weight_bit_width = 4 if self.precision == "int4" else 8
-            quantize_model(
-                get_model_impl(vicuna_model).layers,
-                dtype=torch.float32,
-                weight_quant_type="asym",
-                weight_bit_width=weight_bit_width,
-                weight_param_method="stats",
-                weight_scale_precision="float_scale",
-                weight_quant_granularity="per_group",
-                weight_group_size=self.weight_group_size,
-                quantize_weight_zero_point=False,
-                input_bit_width=None,
-                input_scale_type="float",
-                input_param_method="stats",
-                input_quant_type="asym",
-                input_quant_granularity="per_tensor",
-                quantize_input_zero_point=False,
-                seqlen=2048,
-            )
-            print("Weight quantization applied.")
+            if not self.check_all_artifacts_present():
+                print("Applying weight quantization..")
+                from brevitas_examples.common.generative.quantize import (
+                    quantize_model,
+                )
+                from brevitas_examples.llm.llm_quant.run_utils import (
+                    get_model_impl,
+                )
+                weight_bit_width = 4 if self.precision == "int4" else 8
+
+                quantize_model(
+                    get_model_impl(vicuna_model).layers,
+                    dtype=torch.float32,
+                    weight_quant_type="asym",
+                    weight_bit_width=weight_bit_width,
+                    weight_param_method="stats",
+                    weight_scale_precision="float_scale",
+                    weight_quant_granularity="per_group",
+                    weight_group_size=self.weight_group_size,
+                    quantize_weight_zero_point=False,
+                    input_bit_width=None,
+                    input_scale_type="float",
+                    input_param_method="stats",
+                    input_quant_type="asym",
+                    input_quant_granularity="per_tensor",
+                    quantize_input_zero_point=False,
+                    seqlen=2048,
+                )
+
+                print("Weight quantization applied.")
+
+            else:
+                print("Skipping quantization, as all required artifacts are present")
 
         placeholder_pkv_segment = tuple(
             (
-                torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
-                torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
+                torch.zeros([1, self.n_layers_dict[self.model_name], SAMPLE_INPUT_LEN, 128]),
+                torch.zeros([1, self.n_layers_dict[self.model_name], SAMPLE_INPUT_LEN, 128]),
             )
             for _ in range(8)
         )
         placeholder_pkv_full = tuple(
             (
-                torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
-                torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
+                torch.zeros([1, self.n_layers_dict[self.model_name], SAMPLE_INPUT_LEN, 128]),
+                torch.zeros([1, self.n_layers_dict[self.model_name], SAMPLE_INPUT_LEN, 128]),
             )
-            for _ in range(32)
+            for _ in range(self.n_layers_dict[self.model_name])
         )
 
         placeholder_input0 = (
-            torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
+            torch.zeros([1, SAMPLE_INPUT_LEN, self.hidden_state_size_dict[self.model_name]]),
             torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]),
             torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64),
         )
 
         placeholder_input1 = (
-            torch.zeros([1, 1, 4096]),
+            torch.zeros([1, 1, self.hidden_state_size_dict[self.model_name]]),
             torch.zeros([1, 1, 1, SAMPLE_INPUT_LEN + 1]),
             torch.zeros([1, 1], dtype=torch.int64),
-            torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
-            torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
+            torch.zeros([1, self.n_layers_dict[self.model_name], SAMPLE_INPUT_LEN, 128]),
+            torch.zeros([1, self.n_layers_dict[self.model_name], SAMPLE_INPUT_LEN, 128]),
         )
 
         norm = VicunaNorm(vicuna_model.model.norm)
         device_idx = self.get_device_index(
             r"vicuna\.model\.model\.norm(?:\.|\s|$)"
         )
-        print(device_idx)
+        # HC device_idx for non-layer vmfbs
+        device_idx = 0
         norm = self.compile_norm(
             norm,
-            torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
+            torch.zeros([1, SAMPLE_INPUT_LEN, self.hidden_state_size_dict[self.model_name]]),
             device=self.device,
             device_idx=device_idx,
         )
```
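The placeholder shapes these lines now produce, assuming llama2_13b and taking SAMPLE_INPUT_LEN to be the 137 used for the lmhead sample earlier (an assumption; the constant is defined outside this diff):

```python
import torch

SAMPLE_INPUT_LEN = 137       # assumed; defined elsewhere in the file
n_layers, hidden = 40, 5120  # llama2_13b entries of the two dicts

pkv = torch.zeros([1, n_layers, SAMPLE_INPUT_LEN, 128])
hidden_states = torch.zeros([1, SAMPLE_INPUT_LEN, hidden])
print(tuple(pkv.shape), tuple(hidden_states.shape))  # (1, 40, 137, 128) (1, 137, 5120)
```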
```diff
@@ -1363,7 +1407,8 @@ class ShardedVicuna(VicunaBase):
         device_idx = self.get_device_index(
             r"vicuna\.model\.model\.embed_tokens(?:\.|\s|$)"
         )
-        print(device_idx)
+        # HC device_idx for non-layer vmfbs
+        device_idx = 0
         embeddings = self.compile_embedding(
             embeddings,
             (torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64)),
```
```diff
@@ -1375,10 +1420,11 @@ class ShardedVicuna(VicunaBase):
         device_idx = self.get_device_index(
             r"vicuna\.model\.lm_head(?:\.|\s|$)"
         )
-        print(device_idx)
+        # HC device_idx for non-layer vmfbs
+        device_idx = 0
         lmhead = self.compile_lmhead(
             lmhead,
-            torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
+            torch.zeros([1, SAMPLE_INPUT_LEN, self.hidden_state_size_dict[self.model_name]]),
             device=self.device,
             device_idx=device_idx,
         )
```
```diff
@@ -1412,7 +1458,8 @@ class ShardedVicuna(VicunaBase):
         )
 
         if not compressed:
-            shark_layers = [CompiledVicunaLayer(m) for m in modules]
+            breakpoints = [x for x in range(0, len(modules), (self.n_devices % 2) + (len(modules) // (self.n_devices)))][1:] + [len(modules)]
+            shark_layers = [CompiledVicunaLayer(m, i, breakpoints) for (i, m) in enumerate(modules)]
         else:
             shark_layers = [CompiledEightLayerLayer(m) for m in modules]
             vicuna_model.model.compressedlayers = shark_layers
```
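How the new breakpoints list falls out, with the same assumed counts (40 modules, 4 devices):

```python
n_modules, n_devices = 40, 4  # assumed counts, e.g. llama2_13b on 4 devices

step = (n_devices % 2) + (n_modules // n_devices)  # 0 + 10 = 10
breakpoints = [x for x in range(0, n_modules, step)][1:] + [n_modules]
print(breakpoints)  # [10, 20, 30, 40]

# CompiledVicunaLayer treats layer idx as a breakpoint when idx + 1 is in
# this list, so layers 9, 19, 29 and 39, the last layer of each device
# block, are the only ones whose outputs are sent back to the host.
```

These boundaries line up with the contiguous layer-to-device assignment introduced above.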
```diff
@@ -1667,12 +1714,7 @@ class UnshardedVicuna(VicunaBase):
         new_lines = []
 
         # Using a while loop and the pop method to avoid creating a copy of module
-        if "llama2_13b" in self.model_name:
-            pkv_tensor_shape = "tensor<1x40x?x128x"
-        elif "llama2_70b" in self.model_name:
-            pkv_tensor_shape = "tensor<1x8x?x128x"
-        else:
-            pkv_tensor_shape = "tensor<1x32x?x128x"
+        pkv_tensor_shape = f"tensor<1x{self.n_layers_dict[self.model_name]}x?x128x"
         if self.precision in ["fp16", "int4", "int8"]:
             pkv_tensor_shape += "f16>"
         else:
```
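Note the two mappings disagree for llama2_70b: the removed branch used 8 here, that model's grouped-query-attention KV-head count (the second dimension of a past-key-value tensor is a head count, not a layer count), while n_layers_dict maps llama2_70b to 80. For vicuna, llama2_7b and llama2_13b the two numbers happen to coincide.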
```diff
@@ -1,4 +1,5 @@
 import torch
+import time
 
 
 class FirstVicunaLayer(torch.nn.Module):
```
```diff
@@ -110,9 +111,11 @@ class LMHeadCompiled(torch.nn.Module):
         self.model = shark_module
 
     def forward(self, hidden_states):
-        hidden_states = hidden_states.detach()
+        hidden_states_sample = hidden_states.detach()
+
         output = self.model("forward", (hidden_states,))
         output = torch.tensor(output)
+
         return output
```
```diff
@@ -136,8 +139,9 @@ class VicunaNormCompiled(torch.nn.Module):
             hidden_states.detach()
         except:
             pass
-        output = self.model("forward", (hidden_states,))
+        output = self.model("forward", (hidden_states,), send_to_host=True)
         output = torch.tensor(output)
+
         return output
```
```diff
@@ -158,15 +162,18 @@ class VicunaEmbeddingCompiled(torch.nn.Module):
 
     def forward(self, input_ids):
         input_ids.detach()
-        output = self.model("forward", (input_ids,))
+        output = self.model("forward", (input_ids,), send_to_host=True)
         output = torch.tensor(output)
+
         return output
 
 
 class CompiledVicunaLayer(torch.nn.Module):
-    def __init__(self, shark_module):
+    def __init__(self, shark_module, idx, breakpoints):
         super().__init__()
         self.model = shark_module
+        self.idx = idx
+        self.breakpoints = breakpoints
 
     def forward(
         self,
```
```diff
@@ -177,10 +184,12 @@ class CompiledVicunaLayer(torch.nn.Module):
         output_attentions=False,
         use_cache=True,
     ):
+        is_breakpoint = self.idx + 1 in self.breakpoints
+
         if past_key_value is None:
-            hidden_states = hidden_states.detach()
-            attention_mask = attention_mask.detach()
-            position_ids = position_ids.detach()
+            # hidden_states = hidden_states.detach()
+            # attention_mask = attention_mask.detach()
+            # position_ids = position_ids.detach()
 
             output = self.model(
                 "first_vicuna_forward",
                 (
```
```diff
@@ -188,11 +197,17 @@ class CompiledVicunaLayer(torch.nn.Module):
                     attention_mask,
                     position_ids,
                 ),
+                send_to_host=is_breakpoint,
             )
 
-            output0 = torch.tensor(output[0])
-            output1 = torch.tensor(output[1])
-            output2 = torch.tensor(output[2])
+            if is_breakpoint:
+                output0 = torch.tensor(output[0])
+                output1 = torch.tensor(output[1])
+                output2 = torch.tensor(output[2])
+            else:
+                output0 = output[0]
+                output1 = output[1]
+                output2 = output[2]
 
             return (
                 output0,
```
```diff
@@ -202,11 +217,12 @@ class CompiledVicunaLayer(torch.nn.Module):
                 ),
             )
         else:
-            hidden_states = hidden_states.detach()
-            attention_mask = attention_mask.detach()
-            position_ids = position_ids.detach()
-            pkv0 = past_key_value[0].detach()
-            pkv1 = past_key_value[1].detach()
+            # hidden_states = hidden_states.detach()
+            # attention_mask = attention_mask.detach()
+            # position_ids = position_ids.detach()
+            # pkv0 = past_key_value[0].detach()
+            pkv0 = past_key_value[0]
+            pkv1 = past_key_value[1]
             output = self.model(
                 "second_vicuna_forward",
                 (
```
```diff
@@ -216,11 +232,17 @@ class CompiledVicunaLayer(torch.nn.Module):
                     pkv0,
                     pkv1,
                 ),
+                send_to_host=is_breakpoint,
             )
 
-            output0 = torch.tensor(output[0])
-            output1 = torch.tensor(output[1])
-            output2 = torch.tensor(output[2])
+            if is_breakpoint:
+                output0 = torch.tensor(output[0])
+                output1 = torch.tensor(output[1])
+                output2 = torch.tensor(output[2])
+            else:
+                output0 = output[0]
+                output1 = output[1]
+                output2 = output[2]
 
             return (
                 output0,
```
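With send_to_host=is_breakpoint, a layer's outputs stay on the device as buffers until the last layer of a device block; only at a breakpoint are they copied to the host and rewrapped with torch.tensor before moving to the next device. The detach() calls, which forced host tensors on every layer, are commented out for the same reason.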
```diff
@@ -355,11 +355,15 @@ def get_iree_module(
         device = iree_device_map(device)
         print("registering device id: ", device_idx)
         haldriver = ireert.get_driver(device)
+        hal_device_id = haldriver.query_available_devices()[device_idx][
+            "device_id"
+        ]
         haldevice = haldriver.create_device(
-            haldriver.query_available_devices()[device_idx]["device_id"],
+            hal_device_id,
             allocators=shark_args.device_allocator,
         )
         config = ireert.Config(device=haldevice)
+        config.id = hal_device_id
     else:
         config = get_iree_runtime_config(device)
     vm_module = ireert.VmModule.from_buffer(
```
```diff
@@ -398,15 +402,16 @@ def load_vmfb_using_mmap(
         haldriver = ireert.get_driver(device)
         dl.log(f"ireert.get_driver()")
 
+        hal_device_id = haldriver.query_available_devices()[device_idx][
+            "device_id"
+        ]
         haldevice = haldriver.create_device(
-            haldriver.query_available_devices()[device_idx]["device_id"],
+            hal_device_id,
             allocators=shark_args.device_allocator,
         )
         dl.log(f"ireert.create_device()")
         config = ireert.Config(device=haldevice)
-        config.id = haldriver.query_available_devices()[device_idx][
-            "device_id"
-        ]
+        config.id = hal_device_id
         dl.log(f"ireert.Config()")
     else:
         config = get_iree_runtime_config(device)
```
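Both hunks factor the repeated query_available_devices()[device_idx]["device_id"] lookup into a single hal_device_id. A standalone sketch of that selection path, assuming an IREE build that ships the local-task driver (the driver name and index are placeholders):

```python
import iree.runtime as ireert

driver_name = "local-task"  # placeholder; e.g. "vulkan" or "rocm" depending on the install
device_idx = 0              # placeholder index into the enumerated devices

haldriver = ireert.get_driver(driver_name)
available = haldriver.query_available_devices()
print([d["device_id"] for d in available])  # ids of every device this driver sees

# Resolve the id once, then reuse it for device creation and the config tag.
hal_device_id = available[device_idx]["device_id"]
haldevice = haldriver.create_device(hal_device_id)
config = ireert.Config(device=haldevice)
```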