Compare commits


1 Commit

Author SHA1 Message Date
Elias Joseph
429057f5e0 finalized fixes for sharded llama2 2023-12-06 01:45:44 -08:00
3 changed files with 73 additions and 62 deletions

View File

@@ -512,8 +512,8 @@ class ShardedVicuna(VicunaBase):
         n_devices=None,
     ) -> None:
         self.hf_auth_token = hf_auth_token
-        self.hidden_state_size_dict = {"vicuna": 4096, "llama2_7b": 4096, "llama2_13b" : 5120}
-        self.n_layers_dict = {"vicuna": 32, "llama2_7b": 32, "llama2_13b" : 40}
+        self.hidden_state_size_dict = {"vicuna": 4096, "llama2_7b": 4096, "llama2_13b" : 5120, "llama2_70b" : 8192}
+        self.n_layers_dict = {"vicuna": 32, "llama2_7b": 32, "llama2_13b" : 40, "llama2_70b" : 80}
         super().__init__(
             model_name,
             hf_model_path,
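Note: the only change in this hunk is the pair of new table entries registering llama2_70b (hidden size 8192, 80 decoder layers); those values feed the artifact check and the layer-to-device placement introduced below.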
@@ -535,6 +535,11 @@ class ShardedVicuna(VicunaBase):
         self.dir_path.mkdir(parents=True, exist_ok=True)
         self.shark_model = self.compile(device=device)
 
+    def check_all_artifacts_present(self):
+        file_list = [f"{i}_full" for i in range(self.n_layers_dict[self.model_name])] + ["norm", "embedding", "lmhead"]
+        file_exists_list = [Path(f"{self.dir_name}/{x}.vmfb").exists() or Path(f"{self.dir_name}/{x}.mlir").exists() for x in file_list]
+        return all(file_exists_list)
+
     def get_tokenizer(self):
         kwargs = {}
         if "llama2" in self.model_name:
@@ -1192,10 +1197,9 @@ class ShardedVicuna(VicunaBase):
                 )
                 if device_idx is None:
                     if self.n_devices is not None:
-                        device_idx = idx % self.n_devices
+                        device_idx = (idx * self.n_devices) // self.n_layers_dict[self.model_name]
                     else:
                         device_idx = None
-                    print(device_idx, self.n_devices)
                 module = SharkInference(
                     None,
                     device=device,
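This hunk replaces round-robin layer placement with contiguous blocks, so neighboring layers share a device and data crosses devices only at shard boundaries; the stray debug print is dropped as well. A quick comparison of the two formulas, assuming this commit's llama2_70b layer count and a hypothetical 4 devices:

    n_layers, n_devices = 80, 4
    old = [idx % n_devices for idx in range(n_layers)]
    # -> 0, 1, 2, 3, 0, 1, 2, 3, ...  (adjacent layers land on different devices)
    new = [(idx * n_devices) // n_layers for idx in range(n_layers)]
    # -> twenty 0s, then twenty 1s, 2s, 3s  (adjacent layers are colocated)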
@@ -1211,7 +1215,7 @@ class ShardedVicuna(VicunaBase):
                 )
                 if device_idx is None:
                     if self.n_devices is not None:
-                        device_idx = idx % self.n_devices
+                        device_idx = (idx * self.n_devices) // self.n_layers_dict[self.model_name]
                     else:
                         device_idx = None
                 module = SharkInference(
@@ -1322,34 +1326,40 @@ class ShardedVicuna(VicunaBase):
         )
         if self.precision in ["int4", "int8"]:
-            from brevitas_examples.common.generative.quantize import (
-                quantize_model,
-            )
-            from brevitas_examples.llm.llm_quant.run_utils import (
-                get_model_impl,
-            )
-            print("Applying weight quantization..")
-            weight_bit_width = 4 if self.precision == "int4" else 8
-            quantize_model(
-                get_model_impl(vicuna_model).layers,
-                dtype=torch.float32,
-                weight_quant_type="asym",
-                weight_bit_width=weight_bit_width,
-                weight_param_method="stats",
-                weight_scale_precision="float_scale",
-                weight_quant_granularity="per_group",
-                weight_group_size=self.weight_group_size,
-                quantize_weight_zero_point=False,
-                input_bit_width=None,
-                input_scale_type="float",
-                input_param_method="stats",
-                input_quant_type="asym",
-                input_quant_granularity="per_tensor",
-                quantize_input_zero_point=False,
-                seqlen=2048,
-            )
-            print("Weight quantization applied.")
+            if not self.check_all_artifacts_present():
+                print("Applying weight quantization..")
+                from brevitas_examples.common.generative.quantize import (
+                    quantize_model,
+                )
+                from brevitas_examples.llm.llm_quant.run_utils import (
+                    get_model_impl,
+                )
+                weight_bit_width = 4 if self.precision == "int4" else 8
+                quantize_model(
+                    get_model_impl(vicuna_model).layers,
+                    dtype=torch.float32,
+                    weight_quant_type="asym",
+                    weight_bit_width=weight_bit_width,
+                    weight_param_method="stats",
+                    weight_scale_precision="float_scale",
+                    weight_quant_granularity="per_group",
+                    weight_group_size=self.weight_group_size,
+                    quantize_weight_zero_point=False,
+                    input_bit_width=None,
+                    input_scale_type="float",
+                    input_param_method="stats",
+                    input_quant_type="asym",
+                    input_quant_granularity="per_tensor",
+                    quantize_input_zero_point=False,
+                    seqlen=2048,
+                )
+                print("Weight quantization applied.")
+            else:
+                print("Skipping quantization, as all required artifacts are present")
         placeholder_pkv_segment = tuple(
             (
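With the artifact check available, the Brevitas quantization pass now runs only when some shard still has to be exported; if every .vmfb/.mlir is already present, the expensive quantize_model call is skipped outright. The quantize_model arguments themselves are unchanged, only re-indented under the new guard.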
@@ -1448,7 +1458,11 @@ class ShardedVicuna(VicunaBase):
         )
         if not compressed:
-            shark_layers = [CompiledVicunaLayer(m) for m in modules]
+            if self.n_devices is None:
+                breakpoints = None
+            else:
+                breakpoints = [x for x in range(0,len(modules),(self.n_devices % 2) + (len(modules)//(self.n_devices)))][1:] + [len(modules)]
+            shark_layers = [CompiledVicunaLayer(m, i, breakpoints) for (i, m) in enumerate(modules)]
         else:
             shark_layers = [CompiledEightLayerLayer(m) for m in modules]
         vicuna_model.model.compressedlayers = shark_layers
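breakpoints records the cumulative layer indices at which execution moves to the next device. A worked example under the same assumptions as above (80 layers, 4 devices; the (self.n_devices % 2) term appears to widen the stride by one when the device count is odd, so the trailing shard is not left oversized):

    modules_len, n_devices = 80, 4
    step = (n_devices % 2) + modules_len // n_devices    # 0 + 20 = 20
    breakpoints = list(range(0, modules_len, step))[1:] + [modules_len]
    # -> [20, 40, 60, 80]; a layer whose idx + 1 is in this list is the
    #    last layer on its device, which is exactly what the reworked
    #    CompiledVicunaLayer below tests for.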

View File

@@ -67,7 +67,6 @@ class ShardedVicunaModel(torch.nn.Module):
     def __init__(self, model, layers, lmhead, embedding, norm):
         super().__init__()
         self.model = model
-        # assert len(layers) == len(model.model.layers)
         self.model.model.config.use_cache = True
         self.model.model.config.output_attentions = False
         self.layers = layers
@@ -169,9 +168,11 @@ class VicunaEmbeddingCompiled(torch.nn.Module):
 class CompiledVicunaLayer(torch.nn.Module):
-    def __init__(self, shark_module):
+    def __init__(self, shark_module, idx, breakpoints):
         super().__init__()
         self.model = shark_module
+        self.idx = idx
+        self.breakpoints = breakpoints
 
     def forward(
         self,
@@ -182,11 +183,11 @@ class CompiledVicunaLayer(torch.nn.Module):
         output_attentions=False,
         use_cache=True,
     ):
+        if self.breakpoints is None:
+            is_breakpoint = False
+        else:
+            is_breakpoint = self.idx + 1 in self.breakpoints
         if past_key_value is None:
-            # hidden_states = hidden_states.detach()
-            # attention_mask = attention_mask.detach()
-            # position_ids = position_ids.detach()
             output = self.model(
                 "first_vicuna_forward",
                 (
@@ -194,17 +195,17 @@ class CompiledVicunaLayer(torch.nn.Module):
                     attention_mask,
                     position_ids,
                 ),
-                send_to_host=True,
+                send_to_host=is_breakpoint,
             )
-            ### send_to_host=True
-            output0 = torch.tensor(output[0])
-            output1 = torch.tensor(output[1])
-            output2 = torch.tensor(output[2])
-            ### send_to_host=False
-            # output0 = output[0]
-            # output1 = output[1]
-            # output2 = output[2]
+            if is_breakpoint:
+                output0 = torch.tensor(output[0])
+                output1 = torch.tensor(output[1])
+                output2 = torch.tensor(output[2])
+            else:
+                output0 = output[0]
+                output1 = output[1]
+                output2 = output[2]
             return (
                 output0,
@@ -214,10 +215,6 @@ class CompiledVicunaLayer(torch.nn.Module):
                 ),
             )
         else:
-            # hidden_states = hidden_states.detach()
-            # attention_mask = attention_mask.detach()
-            # position_ids = position_ids.detach()
-            # pkv0 = past_key_value[0].detach()
             pkv0 = past_key_value[0]
             pkv1 = past_key_value[1]
             output = self.model(
@@ -229,17 +226,17 @@ class CompiledVicunaLayer(torch.nn.Module):
                     pkv0,
                     pkv1,
                 ),
-                send_to_host=True,
+                send_to_host=is_breakpoint,
             )
-            ### send_to_host=True
-            output0 = torch.tensor(output[0])
-            output1 = torch.tensor(output[1])
-            output2 = torch.tensor(output[2])
-            ### send_to_host=False
-            # output0 = output[0]
-            # output1 = output[1]
-            # output2 = output[2]
+            if is_breakpoint:
+                output0 = torch.tensor(output[0])
+                output1 = torch.tensor(output[1])
+                output2 = torch.tensor(output[2])
+            else:
+                output0 = output[0]
+                output1 = output[1]
+                output2 = output[2]
             return (
                 output0,
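Taken together with the breakpoints above: both forward paths now pass send_to_host=is_breakpoint, so results are copied back to the host (the torch.tensor(...) calls) only on the last layer of each shard, where the next layer lives on a different device. Every other layer hands device-resident buffers straight to its successor, replacing the old always-copy behavior and the hand-toggled ### send_to_host comment blocks. Under the 4-device example above, only layers 19, 39, 59, and 79 pay the host round-trip.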

View File

@@ -264,7 +264,7 @@ def chat(
     for text, msg, exec_time in vicuna_model.generate(prompt, cli=cli):
         if msg is None:
             if is_first:
-                prefill_time = exec_time
+                prefill_time = exec_time / 1000
                 is_first = False
             else:
                 total_time_ms += exec_time
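Assuming generate() yields exec_time in milliseconds (as the total_time_ms accumulator suggests), the division reports the prefill phase in seconds; decode timings continue to accumulate in milliseconds.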