Mirror of https://github.com/nod-ai/SHARK-Studio.git, synced 2026-01-11 23:08:19 -05:00
Compare commits
1 commit
diffusers-...sharding-f
Commit 429057f5e0
@@ -512,8 +512,8 @@ class ShardedVicuna(VicunaBase):
         n_devices=None,
     ) -> None:
         self.hf_auth_token = hf_auth_token
-        self.hidden_state_size_dict = {"vicuna": 4096, "llama2_7b": 4096, "llama2_13b": 5120}
-        self.n_layers_dict = {"vicuna": 32, "llama2_7b": 32, "llama2_13b": 40}
+        self.hidden_state_size_dict = {"vicuna": 4096, "llama2_7b": 4096, "llama2_13b": 5120, "llama2_70b": 8192}
+        self.n_layers_dict = {"vicuna": 32, "llama2_7b": 32, "llama2_13b": 40, "llama2_70b": 80}
         super().__init__(
             model_name,
             hf_model_path,
@@ -535,6 +535,11 @@ class ShardedVicuna(VicunaBase):
         self.dir_path.mkdir(parents=True, exist_ok=True)
         self.shark_model = self.compile(device=device)
 
+    def check_all_artifacts_present(self):
+        file_list = [f"{i}_full" for i in range(self.n_layers_dict[self.model_name])] + ["norm", "embedding", "lmhead"]
+        file_exists_list = [Path(f"{self.dir_name}/{x}.vmfb").exists() or Path(f"{self.dir_name}/{x}.mlir").exists() for x in file_list]
+        return all(file_exists_list)
+
     def get_tokenizer(self):
         kwargs = {}
         if "llama2" in self.model_name:
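The new check_all_artifacts_present() lets the compile and quantization paths skip work when every shard artifact already exists on disk. A minimal standalone sketch of the same check, assuming only what the diff shows (one "{i}_full" artifact per layer plus norm/embedding/lmhead, in either .vmfb or .mlir form; the directory name in the usage comment is a made-up example):

from pathlib import Path

def artifacts_present(dir_name: str, n_layers: int) -> bool:
    # One artifact per transformer layer, plus the shared pieces.
    names = [f"{i}_full" for i in range(n_layers)] + ["norm", "embedding", "lmhead"]
    # Either a compiled .vmfb or its exported .mlir counts as present.
    return all(
        Path(f"{dir_name}/{name}.vmfb").exists() or Path(f"{dir_name}/{name}.mlir").exists()
        for name in names
    )

# Hypothetical usage: 80 layers for llama2_70b, per the new n_layers_dict entry.
# artifacts_present("llama2_70b_shards", 80)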
@@ -1192,10 +1197,9 @@ class ShardedVicuna(VicunaBase):
             )
             if device_idx is None:
                 if self.n_devices is not None:
-                    device_idx = idx % self.n_devices
+                    device_idx = (idx * self.n_devices) // self.n_layers_dict[self.model_name]
                 else:
                     device_idx = None
-            print(device_idx, self.n_devices)
             module = SharkInference(
                 None,
                 device=device,
@@ -1211,7 +1215,7 @@ class ShardedVicuna(VicunaBase):
             )
             if device_idx is None:
                 if self.n_devices is not None:
-                    device_idx = idx % self.n_devices
+                    device_idx = (idx * self.n_devices) // self.n_layers_dict[self.model_name]
                 else:
                     device_idx = None
             module = SharkInference(
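A note on the placement change in the two hunks above: the old `idx % self.n_devices` round-robins consecutive layers across devices, so activations would hop devices after every layer, while the new `(idx * self.n_devices) // n_layers` gives each device one contiguous block of layers, so activations cross a device boundary only n_devices - 1 times. A minimal sketch of the difference, assuming 8 layers on 2 devices:

n_layers, n_devices = 8, 2

round_robin = [idx % n_devices for idx in range(n_layers)]
contiguous = [(idx * n_devices) // n_layers for idx in range(n_layers)]

print(round_robin)  # [0, 1, 0, 1, 0, 1, 0, 1] -> a transfer after every layer
print(contiguous)   # [0, 0, 0, 0, 1, 1, 1, 1] -> a single transfer, after layer 3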
@@ -1322,34 +1326,40 @@ class ShardedVicuna(VicunaBase):
             )
 
         if self.precision in ["int4", "int8"]:
-            from brevitas_examples.common.generative.quantize import (
-                quantize_model,
-            )
-            from brevitas_examples.llm.llm_quant.run_utils import (
-                get_model_impl,
-            )
-
-            print("Applying weight quantization..")
-            weight_bit_width = 4 if self.precision == "int4" else 8
-            quantize_model(
-                get_model_impl(vicuna_model).layers,
-                dtype=torch.float32,
-                weight_quant_type="asym",
-                weight_bit_width=weight_bit_width,
-                weight_param_method="stats",
-                weight_scale_precision="float_scale",
-                weight_quant_granularity="per_group",
-                weight_group_size=self.weight_group_size,
-                quantize_weight_zero_point=False,
-                input_bit_width=None,
-                input_scale_type="float",
-                input_param_method="stats",
-                input_quant_type="asym",
-                input_quant_granularity="per_tensor",
-                quantize_input_zero_point=False,
-                seqlen=2048,
-            )
-            print("Weight quantization applied.")
+            if not self.check_all_artifacts_present():
+                print("Applying weight quantization..")
+                from brevitas_examples.common.generative.quantize import (
+                    quantize_model,
+                )
+                from brevitas_examples.llm.llm_quant.run_utils import (
+                    get_model_impl,
+                )
+                weight_bit_width = 4 if self.precision == "int4" else 8
+
+                quantize_model(
+                    get_model_impl(vicuna_model).layers,
+                    dtype=torch.float32,
+                    weight_quant_type="asym",
+                    weight_bit_width=weight_bit_width,
+                    weight_param_method="stats",
+                    weight_scale_precision="float_scale",
+                    weight_quant_granularity="per_group",
+                    weight_group_size=self.weight_group_size,
+                    quantize_weight_zero_point=False,
+                    input_bit_width=None,
+                    input_scale_type="float",
+                    input_param_method="stats",
+                    input_quant_type="asym",
+                    input_quant_granularity="per_tensor",
+                    quantize_input_zero_point=False,
+                    seqlen=2048,
+                )
+
+                print("Weight quantization applied.")
+
+            else:
+                print("Skipping quantization, as all required artifacts are present")
+
 
         placeholder_pkv_segment = tuple(
             (
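For context on the quantize_model arguments above (asym, per_group, stats, float_scale): per-group weight quantization derives a separate scale and zero point from the min/max statistics of each weight_group_size-wide slice of a weight row. The following is a conceptual sketch of what those settings mean, written in plain PyTorch; it illustrates the idea only and is not how brevitas is implemented internally:

import torch

def quantize_per_group_asym(w: torch.Tensor, bits: int = 4, group_size: int = 128):
    # Split a flat weight vector into groups and quantize each group independently.
    w = w.reshape(-1, group_size)
    qmax = 2**bits - 1
    # "stats" style parameters: per-group min/max.
    lo = w.min(dim=1, keepdim=True).values
    hi = w.max(dim=1, keepdim=True).values
    scale = (hi - lo).clamp(min=1e-8) / qmax
    zero_point = torch.round(-lo / scale)
    q = torch.clamp(torch.round(w / scale) + zero_point, 0, qmax)
    return (q - zero_point) * scale  # dequantized approximation of w

w = torch.randn(4096)
w_hat = quantize_per_group_asym(w, bits=4, group_size=128)
print((w - w_hat.reshape(-1)).abs().max())  # small, bounded per-group error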
@@ -1448,7 +1458,11 @@ class ShardedVicuna(VicunaBase):
             )
 
         if not compressed:
-            shark_layers = [CompiledVicunaLayer(m) for m in modules]
+            if self.n_devices is None:
+                breakpoints = None
+            else:
+                breakpoints = [x for x in range(0, len(modules), (self.n_devices % 2) + (len(modules) // (self.n_devices)))][1:] + [len(modules)]
+            shark_layers = [CompiledVicunaLayer(m, i, breakpoints) for (i, m) in enumerate(modules)]
         else:
             shark_layers = [CompiledEightLayerLayer(m) for m in modules]
         vicuna_model.model.compressedlayers = shark_layers
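The breakpoints expression is dense: it builds the list of layer counts after which activations must cross to the next device (the forward pass below tests `self.idx + 1 in self.breakpoints`). A small sketch of what it evaluates to, using the same arithmetic as the diff:

def shard_breakpoints(n_modules: int, n_devices: int):
    # Step through the layers in device-sized blocks; odd device counts widen the step by 1.
    step = (n_devices % 2) + (n_modules // n_devices)
    return [x for x in range(0, n_modules, step)][1:] + [n_modules]

print(shard_breakpoints(32, 4))  # [8, 16, 24, 32]: host transfer after layers 7, 15, 23, 31
print(shard_breakpoints(32, 3))  # [11, 22, 32]: the odd-device case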
@@ -67,7 +67,6 @@ class ShardedVicunaModel(torch.nn.Module):
     def __init__(self, model, layers, lmhead, embedding, norm):
         super().__init__()
         self.model = model
-        # assert len(layers) == len(model.model.layers)
         self.model.model.config.use_cache = True
         self.model.model.config.output_attentions = False
         self.layers = layers
@@ -169,9 +168,11 @@ class VicunaEmbeddingCompiled(torch.nn.Module):
 
 
 class CompiledVicunaLayer(torch.nn.Module):
-    def __init__(self, shark_module):
+    def __init__(self, shark_module, idx, breakpoints):
         super().__init__()
         self.model = shark_module
+        self.idx = idx
+        self.breakpoints = breakpoints
 
     def forward(
         self,
@@ -182,11 +183,11 @@ class CompiledVicunaLayer(torch.nn.Module):
         output_attentions=False,
         use_cache=True,
     ):
+        if self.breakpoints is None:
+            is_breakpoint = False
+        else:
+            is_breakpoint = self.idx + 1 in self.breakpoints
         if past_key_value is None:
-            # hidden_states = hidden_states.detach()
-            # attention_mask = attention_mask.detach()
-            # position_ids = position_ids.detach()
-
             output = self.model(
                 "first_vicuna_forward",
                 (
@@ -194,17 +195,17 @@ class CompiledVicunaLayer(torch.nn.Module):
                     attention_mask,
                     position_ids,
                 ),
-                send_to_host=True,
+                send_to_host=is_breakpoint,
             )
 
-            ### send_to_host=True
-            output0 = torch.tensor(output[0])
-            output1 = torch.tensor(output[1])
-            output2 = torch.tensor(output[2])
-            ### send_to_host=False
-            # output0 = output[0]
-            # output1 = output[1]
-            # output2 = output[2]
+            if is_breakpoint:
+                output0 = torch.tensor(output[0])
+                output1 = torch.tensor(output[1])
+                output2 = torch.tensor(output[2])
+            else:
+                output0 = output[0]
+                output1 = output[1]
+                output2 = output[2]
 
             return (
                 output0,
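The effect of threading is_breakpoint into send_to_host, here and in the past_key_value branch below: intermediate results stay as device-resident buffers between layers that share a device, and are only materialized as host tensors at shard boundaries, where they must cross to the next device. A schematic sketch of the gating (run_layer is a stand-in for the SharkInference module call, not a real API):

import torch

def forward_one_layer(run_layer, inputs, is_breakpoint):
    # run_layer stands in for `self.model(func_name, inputs, send_to_host=...)`.
    output = run_layer(inputs, send_to_host=is_breakpoint)
    if is_breakpoint:
        # Crossing to the next device: copy results back as host tensors.
        return tuple(torch.tensor(o) for o in output[:3])
    # Staying on the same device: pass the device buffers through untouched.
    return tuple(output[:3])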
@@ -214,10 +215,6 @@ class CompiledVicunaLayer(torch.nn.Module):
                 ),
             )
         else:
-            # hidden_states = hidden_states.detach()
-            # attention_mask = attention_mask.detach()
-            # position_ids = position_ids.detach()
-            # pkv0 = past_key_value[0].detach()
             pkv0 = past_key_value[0]
             pkv1 = past_key_value[1]
             output = self.model(
@@ -229,17 +226,17 @@ class CompiledVicunaLayer(torch.nn.Module):
                     pkv0,
                     pkv1,
                 ),
-                send_to_host=True,
+                send_to_host=is_breakpoint,
             )
 
-            ### send_to_host=True
-            output0 = torch.tensor(output[0])
-            output1 = torch.tensor(output[1])
-            output2 = torch.tensor(output[2])
-            ### send_to_host=False
-            # output0 = output[0]
-            # output1 = output[1]
-            # output2 = output[2]
+            if is_breakpoint:
+                output0 = torch.tensor(output[0])
+                output1 = torch.tensor(output[1])
+                output2 = torch.tensor(output[2])
+            else:
+                output0 = output[0]
+                output1 = output[1]
+                output2 = output[2]
 
             return (
                 output0,
@@ -264,7 +264,7 @@ def chat(
     for text, msg, exec_time in vicuna_model.generate(prompt, cli=cli):
        if msg is None:
            if is_first:
-                prefill_time = exec_time
+                prefill_time = exec_time / 1000
                is_first = False
            else:
                total_time_ms += exec_time
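On the chat() fix: exec_time from generate() is evidently reported in milliseconds, so the prefill figure is now divided by 1000 to get seconds, while subsequent decode steps keep accumulating in total_time_ms. A worked sketch of the resulting bookkeeping (all sample numbers invented):

prefill_ms = 812.0                 # first exec_time from generate(), in ms
prefill_time = prefill_ms / 1000   # seconds, as in the fixed line above
total_time_ms = 35.0 * 127         # e.g. 127 decode steps at ~35 ms each
decode_tok_per_s = 127 / (total_time_ms / 1000)
print(f"prefill: {prefill_time:.2f} s, decode: {decode_tok_per_s:.1f} tokens/s")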