Mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-01-08 21:38:04 -05:00)

Commit: branch to run vicuna with 4 shards
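For orientation: the diff below groups the model's 32 decoder layers into four shards of eight layers each, compiles one module per shard, and spreads the shards across four devices (device_idx=idx % 4). A minimal sketch of that grouping, with an illustrative helper name that does not appear in the patch:

    # hypothetical helper showing the 4 x 8 grouping the patch builds by hand
    def group_into_shards(layers, shard_size=8):
        assert len(layers) % shard_size == 0
        return [layers[i:i + shard_size] for i in range(0, len(layers), shard_size)]

    # for a 32-layer Vicuna decoder this yields the slices [0:8], [8:16], [16:24], [24:32]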
@@ -4,7 +4,9 @@ import re
from io import BytesIO
from pathlib import Path
from tqdm import tqdm
from typing import List, Tuple
from typing import List, Optional, Tuple, Union
import numpy as np
import iree.runtime

import torch
import torch_mlir
@@ -127,6 +129,328 @@ brevitas_matmul_rhs_group_quant_library = [
    brevitas〇matmul_rhs_group_quant〡dtype,
    brevitas〇matmul_rhs_group_quant〡has_value_semantics]

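# EightLayerLayerSV: wraps eight consecutive decoder layers for the cached
# ("second Vicuna") generation pass. It rebuilds the eight (key, value) pairs
# from the sixteen flattened inputs, runs each layer with use_cache=True, and
# returns the new hidden states plus the sixteen updated cache tensors.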
class EightLayerLayerSV(torch.nn.Module):

    def __init__(self, layers):
        super().__init__()
        assert(len(layers) == 8)
        self.layers = layers

    def forward(self, hidden_states, attention_mask, position_ids, pkv00, pkv01, pkv10, pkv11, pkv20, pkv21, pkv30, pkv31, pkv40, pkv41, pkv50, pkv51, pkv60, pkv61, pkv70, pkv71):
        pkvs = [(pkv00, pkv01), (pkv10, pkv11), (pkv20, pkv21), (pkv30, pkv31), (pkv40, pkv41), (pkv50, pkv51), (pkv60, pkv61), (pkv70, pkv71)]
        new_pkvs = []
        for layer, pkv in zip(self.layers, pkvs):
            outputs = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=(
                    pkv[0],
                    pkv[1],
                ),
                use_cache=True,
            )

            hidden_states = outputs[0]
            new_pkvs.append((outputs[-1][0], outputs[-1][1],))
        ((new_pkv00, new_pkv01), (new_pkv10, new_pkv11), (new_pkv20, new_pkv21), (new_pkv30, new_pkv31), (new_pkv40, new_pkv41), (new_pkv50, new_pkv51), (new_pkv60, new_pkv61), (new_pkv70, new_pkv71)) = new_pkvs
        return hidden_states, new_pkv00, new_pkv01, new_pkv10, new_pkv11, new_pkv20, new_pkv21, new_pkv30, new_pkv31, new_pkv40, new_pkv41, new_pkv50, new_pkv51, new_pkv60, new_pkv61, new_pkv70, new_pkv71

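# EightLayerLayerFV: the same eight-layer grouping for the uncached ("first
# Vicuna") prompt pass; it runs each layer with past_key_value=None and returns
# the hidden states plus the freshly built key/value cache of every layer.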
class EightLayerLayerFV(torch.nn.Module):

    def __init__(self, layers):
        super().__init__()
        assert(len(layers) == 8)
        self.layers = layers

    def forward(self, hidden_states, attention_mask, position_ids):
        new_pkvs = []
        for layer in self.layers:
            outputs = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=None,
                use_cache=True,
            )

            hidden_states = outputs[0]
            new_pkvs.append((outputs[-1][0], outputs[-1][1],))
        ((new_pkv00, new_pkv01), (new_pkv10, new_pkv11), (new_pkv20, new_pkv21), (new_pkv30, new_pkv31), (new_pkv40, new_pkv41), (new_pkv50, new_pkv51), (new_pkv60, new_pkv61), (new_pkv70, new_pkv71)) = new_pkvs
        return hidden_states, new_pkv00, new_pkv01, new_pkv10, new_pkv11, new_pkv20, new_pkv21, new_pkv30, new_pkv31, new_pkv40, new_pkv41, new_pkv50, new_pkv51, new_pkv60, new_pkv61, new_pkv70, new_pkv71


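# CompiledEightLayerLayerSV: adapter around a compiled eight-layer shard. It
# detaches the incoming tensors, flattens the nested past_key_value pairs into
# the sixteen positional arguments the compiled "forward" entry point expects,
# and re-nests the outputs as (hidden_states, (k, v) * 8).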
class CompiledEightLayerLayerSV(torch.nn.Module):

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_value,
        output_attentions=False,
        use_cache=True,
    ):
        hidden_states = hidden_states.detach()
        attention_mask = attention_mask.detach()
        position_ids = position_ids.detach()
        ((pkv00, pkv01), (pkv10, pkv11), (pkv20, pkv21), (pkv30, pkv31), (pkv40, pkv41), (pkv50, pkv51), (pkv60, pkv61), (pkv70, pkv71)) = past_key_value
        pkv00 = pkv00.detach()
        pkv01 = pkv01.detach()
        pkv10 = pkv10.detach()
        pkv11 = pkv11.detach()
        pkv20 = pkv20.detach()
        pkv21 = pkv21.detach()
        pkv30 = pkv30.detach()
        pkv31 = pkv31.detach()
        pkv40 = pkv40.detach()
        pkv41 = pkv41.detach()
        pkv50 = pkv50.detach()
        pkv51 = pkv51.detach()
        pkv60 = pkv60.detach()
        pkv61 = pkv61.detach()
        pkv70 = pkv70.detach()
        pkv71 = pkv71.detach()

        output = self.model("forward", (hidden_states, attention_mask, position_ids, pkv00, pkv01, pkv10, pkv11, pkv20, pkv21, pkv30, pkv31, pkv40, pkv41, pkv50, pkv51, pkv60, pkv61, pkv70, pkv71))
        return (output[0], (output[1][0], output[1][1]), (output[2][0], output[2][1]), (output[3][0], output[3][1]), (output[4][0], output[4][1]), (output[5][0], output[5][1]), (output[6][0], output[6][1]), (output[7][0], output[7][1]), (output[8][0], output[8][1]),)

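    # forward_compressed appears to be adapted from the HuggingFace LlamaModel
    # forward: it walks self.compressedlayers (one entry per eight-layer shard)
    # and hands each shard the slice past_key_values[8 * idx : 8 * (idx + 1)]
    # instead of a single layer's cache. It is meant to be patched onto the model
    # (see the commented-out "vicuna_model.model.forward = forward_compressed" below).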
    def forward_compressed(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.compressedlayers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[8 * idx:8 * (idx + 1)] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
from time import time
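# CompiledEightLayerLayer: runtime wrapper for one compiled shard. With no
# past_key_value it calls the shard's "first_vicuna_forward" entry point and
# converts the results back to torch tensors; with a cache it flattens the
# pairs, calls "second_vicuna_forward" (send_to_host=False, so results stay on
# the device as IREE arrays), and re-nests the sixteen returned cache tensors.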
class CompiledEightLayerLayer(torch.nn.Module):

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_value=None,
        output_attentions=False,
        use_cache=True,
    ):
        t2 = time()
        if past_key_value is None:
            hidden_states = hidden_states.detach()
            attention_mask = attention_mask.detach()
            position_ids = position_ids.detach()
            t1 = time()

            output = self.model("first_vicuna_forward", (hidden_states, attention_mask, position_ids))
            #output2 = (output[0], (output[1], output[2],), (output[3], output[4],),(output[5], output[6],),(output[7], output[8],),(output[9],output[10],),(output[11], output[12],),(output[13], output[14],),(output[15], output[16],),)
            #return output2
            return (torch.tensor(output[0]), (torch.tensor(output[1]), torch.tensor(output[2]),), (torch.tensor(output[3]), torch.tensor(output[4]),), (torch.tensor(output[5]), torch.tensor(output[6]),), (torch.tensor(output[7]), torch.tensor(output[8]),), (torch.tensor(output[9]), torch.tensor(output[10]),), (torch.tensor(output[11]), torch.tensor(output[12]),), (torch.tensor(output[13]), torch.tensor(output[14]),), (torch.tensor(output[15]), torch.tensor(output[16]),),)
        else:
            ((pkv00, pkv01), (pkv10, pkv11), (pkv20, pkv21), (pkv30, pkv31), (pkv40, pkv41), (pkv50, pkv51), (pkv60, pkv61), (pkv70, pkv71)) = past_key_value

            try:
                hidden_states = hidden_states.detach()
                attention_mask = attention_mask.detach()
                position_ids = position_ids.detach()
                pkv00 = pkv00.detach()
                pkv01 = pkv01.detach()
                pkv10 = pkv10.detach()
                pkv11 = pkv11.detach()
                pkv20 = pkv20.detach()
                pkv21 = pkv21.detach()
                pkv30 = pkv30.detach()
                pkv31 = pkv31.detach()
                pkv40 = pkv40.detach()
                pkv41 = pkv41.detach()
                pkv50 = pkv50.detach()
                pkv51 = pkv51.detach()
                pkv60 = pkv60.detach()
                pkv61 = pkv61.detach()
                pkv70 = pkv70.detach()
                pkv71 = pkv71.detach()
            except:
                # the cache tensors may already be IREE device arrays (see the
                # DeviceArray check below), which have no .detach(); ignore that case
                x = 10

            t1 = time()
            if type(hidden_states) == iree.runtime.array_interop.DeviceArray:
                hidden_states = np.array(hidden_states, hidden_states.dtype)
                hidden_states = torch.tensor(hidden_states)
                hidden_states = hidden_states.detach()

            output = self.model("second_vicuna_forward", (hidden_states, attention_mask, position_ids, pkv00, pkv01, pkv10, pkv11, pkv20, pkv21, pkv30, pkv31, pkv40, pkv41, pkv50, pkv51, pkv60, pkv61, pkv70, pkv71), send_to_host=False)
            del pkv00
            del pkv01
            del pkv10
            del pkv11
            del pkv20
            del pkv21
            del pkv30
            del pkv31
            del pkv40
            del pkv41
            del pkv50
            del pkv51
            del pkv60
            del pkv61
            del pkv70
            del pkv71
            #print(f"sv0 pass completed in {time() - t2} seconds")
"""
|
||||
try:
|
||||
pkv00 = np.asarray(pkv00, pkv00.dtype)
|
||||
pkv01 = np.asarray(pkv01, pkv01.dtype)
|
||||
pkv10 = np.asarray(pkv10, pkv10.dtype)
|
||||
pkv11 = np.asarray(pkv11, pkv11.dtype)
|
||||
pkv20 = np.asarray(pkv20, pkv20.dtype)
|
||||
pkv21 = np.asarray(pkv21, pkv21.dtype)
|
||||
pkv30 = np.asarray(pkv30, pkv30.dtype)
|
||||
pkv31 = np.asarray(pkv31, pkv31.dtype)
|
||||
pkv40 = np.asarray(pkv40, pkv40.dtype)
|
||||
pkv41 = np.asarray(pkv41, pkv41.dtype)
|
||||
pkv50 = np.asarray(pkv50, pkv50.dtype)
|
||||
pkv51 = np.asarray(pkv51, pkv51.dtype)
|
||||
pkv60 = np.asarray(pkv60, pkv60.dtype)
|
||||
pkv61 = np.asarray(pkv61, pkv61.dtype)
|
||||
pkv70 = np.asarray(pkv70, pkv70.dtype)
|
||||
pkv71 = np.asarray(pkv71, pkv71.dtype)
|
||||
print("iree arrays converted")
|
||||
except:
|
||||
x = 10
|
||||
"""
|
||||
output2 = (output[0], (output[1], output[2],), (output[3], output[4],),(output[5], output[6],),(output[7], output[8],),(output[9],output[10],),(output[11], output[12],),(output[13], output[14],),(output[15], output[16],),)
|
||||
#output2 = (torch.tensor(output[0]), (torch.tensor(output[1]), torch.tensor(output[2]),), (torch.tensor(output[3]), torch.tensor(output[4]),),(torch.tensor(output[5]), torch.tensor(output[6]),),(torch.tensor(output[7]), torch.tensor(output[8]),),(torch.tensor(output[9]), torch.tensor(output[10]),),(torch.tensor(output[11]), torch.tensor(output[12]),),(torch.tensor(output[13]), torch.tensor(output[14]),),(torch.tensor(output[15]), torch.tensor(output[16]),),)
|
||||
#print(output2[1][0])
|
||||
return output2
|
||||
|
||||
|
||||
|
||||
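# ShardedVicuna below is where the pieces defined above are wired in: the 32
# decoder layers are wrapped as four EightLayerLayerFV / EightLayerLayerSV
# groups, compiled to VMFBs, wrapped in CompiledEightLayerLayer, and attached
# to the model as vicuna_model.model.compressedlayers.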
class ShardedVicuna(SharkLLMBase):
    # Class representing Sharded Vicuna Model
@@ -281,7 +605,10 @@ class ShardedVicuna(SharkLLMBase):
            vname = vname.strip()
            vbody = re.sub("arith.constant", "", vbody)
            vbody = vbody.strip()
            vdtype = vbody.split(":")[1].strip()
            if ":" in vbody:
                vdtype = vbody.split(":")[1].strip()
            else:
                vdtype = vbody.split(" ")[-1].strip()
            fixed_vdtype = vdtype
            vdtypes.append(vdtype)
            vdtype = re.sub(r"\d{1,}x", "?x", vdtype)
@@ -339,19 +666,20 @@ class ShardedVicuna(SharkLLMBase):
        hidden_states,
        attention_mask,
        position_ids,
        past_key_value0=None,
        past_key_value1=None,
        past_key_values=None,
    ):
        # Compile a hidden decoder layer of vicuna
        if past_key_value0 is None and past_key_value1 is None:
        if past_key_values is None:
            model_inputs = (hidden_states, attention_mask, position_ids)
        else:
            ((pkv00, pkv01), (pkv10, pkv11), (pkv20, pkv21), (pkv30, pkv31), (pkv40, pkv41), (pkv50, pkv51), (pkv60, pkv61), (pkv70, pkv71)) = past_key_values

            model_inputs = (
                hidden_states,
                attention_mask,
                position_ids,
                past_key_value0,
                past_key_value1,
                pkv00, pkv01, pkv10, pkv11, pkv20, pkv21, pkv30, pkv31, pkv40, pkv41, pkv50, pkv51, pkv60, pkv61, pkv70, pkv71
            )
        mlir_bytecode = import_with_fx(
            vicuna_layer,
@@ -414,7 +742,7 @@ class ShardedVicuna(SharkLLMBase):
            device=device,
            mlir_dialect="tm_tensor",
            device_idx=device_idx,
            mmap=False,
            mmap=True,
        )
        if vmfb_path.exists():
            shark_module.load_module(vmfb_path)
@@ -457,7 +785,7 @@ class ShardedVicuna(SharkLLMBase):
            device=device,
            mlir_dialect="tm_tensor",
            device_idx=device_idx,
            mmap=False,
            mmap=True,
        )
        if vmfb_path.exists():
            shark_module.load_module(vmfb_path)
@@ -499,7 +827,7 @@ class ShardedVicuna(SharkLLMBase):
            device=device,
            mlir_dialect="tm_tensor",
            device_idx=device_idx,
            mmap=False,
            mmap=True,
        )
        if vmfb_path.exists():
            shark_module.load_module(vmfb_path)
@@ -545,12 +873,24 @@ class ShardedVicuna(SharkLLMBase):
            position_ids_placeholder1 = TensorPlaceholder.like(
                inputs1[2], dynamic_axes=[1]
            )
            pkv0_placeholder = TensorPlaceholder.like(
                inputs1[3], dynamic_axes=[2]
            )
            pkv1_placeholder = TensorPlaceholder.like(
                inputs1[4], dynamic_axes=[2]
            )

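            # one dynamic-shape placeholder per cache tensor of a shard: inputs1[3]
            # holds eight (key, value) pairs, and axis 2 (the cached sequence length)
            # is the dimension that grows as tokens are generated.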
            pkv00_placeholder = TensorPlaceholder.like(inputs1[3][0][0], dynamic_axes=[2])
            pkv01_placeholder = TensorPlaceholder.like(inputs1[3][0][1], dynamic_axes=[2])
            pkv10_placeholder = TensorPlaceholder.like(inputs1[3][1][0], dynamic_axes=[2])
            pkv11_placeholder = TensorPlaceholder.like(inputs1[3][1][1], dynamic_axes=[2])
            pkv20_placeholder = TensorPlaceholder.like(inputs1[3][2][0], dynamic_axes=[2])
            pkv21_placeholder = TensorPlaceholder.like(inputs1[3][2][1], dynamic_axes=[2])
            pkv30_placeholder = TensorPlaceholder.like(inputs1[3][3][0], dynamic_axes=[2])
            pkv31_placeholder = TensorPlaceholder.like(inputs1[3][3][1], dynamic_axes=[2])
            pkv40_placeholder = TensorPlaceholder.like(inputs1[3][4][0], dynamic_axes=[2])
            pkv41_placeholder = TensorPlaceholder.like(inputs1[3][4][1], dynamic_axes=[2])
            pkv50_placeholder = TensorPlaceholder.like(inputs1[3][5][0], dynamic_axes=[2])
            pkv51_placeholder = TensorPlaceholder.like(inputs1[3][5][1], dynamic_axes=[2])
            pkv60_placeholder = TensorPlaceholder.like(inputs1[3][6][0], dynamic_axes=[2])
            pkv61_placeholder = TensorPlaceholder.like(inputs1[3][6][1], dynamic_axes=[2])
            pkv70_placeholder = TensorPlaceholder.like(inputs1[3][7][0], dynamic_axes=[2])
            pkv71_placeholder = TensorPlaceholder.like(inputs1[3][7][1], dynamic_axes=[2])

            print(f"Compiling layer {idx} mlir")
            ts_g = self.compile_vicuna_layer(
@@ -596,8 +936,7 @@ class ShardedVicuna(SharkLLMBase):
                inputs1[0],
                inputs1[1],
                inputs1[2],
                inputs1[3],
                inputs1[4],
                inputs1[3]
            )
            if self.precision in ["int4", "int8"]:
                module1 = torch_mlir.compile(
@@ -606,8 +945,8 @@ class ShardedVicuna(SharkLLMBase):
                        inputs1[0],
                        attention_mask_placeholder1,
                        inputs1[2],
                        pkv0_placeholder,
                        pkv1_placeholder,
                        pkv00_placeholder, pkv01_placeholder, pkv10_placeholder, pkv11_placeholder, pkv20_placeholder, pkv21_placeholder, pkv30_placeholder, pkv31_placeholder, pkv40_placeholder, pkv41_placeholder, pkv50_placeholder, pkv51_placeholder, pkv60_placeholder, pkv61_placeholder, pkv70_placeholder, pkv71_placeholder

                    ),
                    output_type="torch",
                    backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
@@ -629,8 +968,7 @@ class ShardedVicuna(SharkLLMBase):
                        inputs1[0],
                        attention_mask_placeholder1,
                        inputs1[2],
                        pkv0_placeholder,
                        pkv1_placeholder,
                        pkv00_placeholder, pkv01_placeholder, pkv10_placeholder, pkv11_placeholder, pkv20_placeholder, pkv21_placeholder, pkv30_placeholder, pkv31_placeholder, pkv40_placeholder, pkv41_placeholder, pkv50_placeholder, pkv51_placeholder, pkv60_placeholder, pkv61_placeholder, pkv70_placeholder, pkv71_placeholder
                    ),
                    torch_mlir.OutputType.LINALG_ON_TENSORS,
                    use_tracing=False,
@@ -653,7 +991,7 @@ class ShardedVicuna(SharkLLMBase):
                    device=device,
                    device_idx=idx % 4,
                    mlir_dialect="tm_tensor",
                    mmap=False,
                    mmap=True,
                )
                module.load_module(vmfb_path)
            else:
@@ -666,7 +1004,7 @@ class ShardedVicuna(SharkLLMBase):
                    device=device,
                    device_idx=idx % 4,
                    mlir_dialect="tm_tensor",
                    mmap=False,
                    mmap=True,
                )
                module.save_module(
                    module_name=f"{idx}_full",
@@ -695,7 +1033,7 @@ class ShardedVicuna(SharkLLMBase):
                weight_quant_type="asym",
                weight_bit_width=weight_bit_width,
                weight_param_method="stats",
                weight_scale_type="float",
                weight_scale_precision="float",
                weight_quant_granularity="per_group",
                weight_group_size=self.weight_group_size,
                quantize_weight_zero_point=False,
@@ -709,6 +1047,9 @@ class ShardedVicuna(SharkLLMBase):
            )
            print("Weight quantization applied.")

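        # dummy key/value caches used only for shape tracing: each entry is a
        # (key, value) pair of shape [1, 32, SAMPLE_INPUT_LEN, 128], eight pairs
        # per shard and thirty-two pairs for the full model.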
        placeholder_pkv_segment = tuple((torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]), torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),) for _ in range(8))
        placeholder_pkv_full = tuple((torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]), torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),) for _ in range(32))

        placeholder_input0 = (
            torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
            torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]),
@@ -723,6 +1064,13 @@ class ShardedVicuna(SharkLLMBase):
            torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
        )

        placeholder_input2 = (
            torch.zeros([1, 1, 4096]),
            torch.zeros([1, 1, 1, SAMPLE_INPUT_LEN + 1]),
            torch.zeros([1, 1], dtype=torch.int64),
            placeholder_pkv_segment
        )

        norm = VicunaNorm(vicuna_model.model.norm)
        device_idx = self.get_device_index(
            r"vicuna\.model\.model\.norm(?:\.|\s|$)"
@@ -765,14 +1113,28 @@ class ShardedVicuna(SharkLLMBase):
        layers1 = [
            SecondVicunaLayer(layer) for layer in vicuna_model.model.layers
        ]

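        # build the four shards: EightLayerLayerFV wrappers for the uncached prompt
        # pass and EightLayerLayerSV wrappers for the cached generation pass, with
        # eight decoder layers per shard.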
        layers00 = EightLayerLayerFV(vicuna_model.model.layers[0:8])
        layers01 = EightLayerLayerFV(vicuna_model.model.layers[8:16])
        layers02 = EightLayerLayerFV(vicuna_model.model.layers[16:24])
        layers03 = EightLayerLayerFV(vicuna_model.model.layers[24:32])
        layers10 = EightLayerLayerSV(vicuna_model.model.layers[0:8])
        layers11 = EightLayerLayerSV(vicuna_model.model.layers[8:16])
        layers12 = EightLayerLayerSV(vicuna_model.model.layers[16:24])
        layers13 = EightLayerLayerSV(vicuna_model.model.layers[24:32])
        layers0 = [layers00, layers01, layers02, layers03]
        layers1 = [layers10, layers11, layers12, layers13]
        #vicuna_model.model.forward = forward_compressed

        _, modules = self.compile_to_vmfb_one_model(
            placeholder_input0,
            layers0,
            placeholder_input1,
            placeholder_input2,
            layers1,
            device=device,
        )
        shark_layers = [CompiledVicunaLayer(m) for m in modules]
        shark_layers = [CompiledEightLayerLayer(m) for m in modules]
        vicuna_model.model.compressedlayers = shark_layers

        sharded_model = ShardedVicunaModel(
            vicuna_model,
@@ -823,6 +1185,9 @@ class ShardedVicuna(SharkLLMBase):
        if is_first:
            prompt = params["prompt"]
            input_ids = self.tokenizer(prompt).input_ids
            # crop input_ids to the last 20 tokens
            input_ids = input_ids[len(input_ids) - 20:]
            ############
            input_id_len = len(input_ids)
            input_ids = torch.tensor(input_ids)
            input_ids = input_ids.reshape([1, input_id_len])
@@ -1548,7 +1913,7 @@ if __name__ == "__main__":
        config_json=config_json,
        weight_group_size=args.weight_group_size,
    )
    prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
    prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives unhelpful, detailed, and rude answers to the user's questions.\n"
    prologue_prompt = "ASSISTANT:\n"

    while True:

@@ -66,7 +66,7 @@ class ShardedVicunaModel(torch.nn.Module):
    def __init__(self, model, layers, lmhead, embedding, norm):
        super().__init__()
        self.model = model
        assert len(layers) == len(model.model.layers)
        #assert len(layers) == len(model.model.layers)
        self.model.model.config.use_cache = True
        self.model.model.config.output_attentions = False
        self.layers = layers
@@ -132,7 +132,10 @@ class VicunaNormCompiled(torch.nn.Module):
        self.model = shark_module

    def forward(self, hidden_states):
        hidden_states.detach()
        try:
            hidden_states.detach()
        except:
            x = 10
        output = self.model("forward", (hidden_states,))
        output = torch.tensor(output)
        return output