branch to run vicuna with 4 shards

This commit is contained in:
Elias Joseph
2023-07-21 18:28:24 -04:00
parent 55a12cc0c4
commit bce652767c
2 changed files with 398 additions and 30 deletions

View File

@@ -4,7 +4,9 @@ import re
from io import BytesIO
from pathlib import Path
from tqdm import tqdm
from typing import List, Tuple
from typing import List, Optional, Tuple, Union
import numpy as np
import iree.runtime
import torch
import torch_mlir
@@ -127,6 +129,328 @@ brevitas_matmul_rhs_group_quant_library = [
brevitas〇matmul_rhs_group_quant〡dtype,
brevitas〇matmul_rhs_group_quant〡has_value_semantics]
class EightLayerLayerSV(torch.nn.Module):
    def __init__(self, layers):
        super().__init__()
        assert len(layers) == 8
        self.layers = layers

    def forward(self, hidden_states, attention_mask, position_ids, pkv00, pkv01, pkv10, pkv11, pkv20, pkv21, pkv30, pkv31, pkv40, pkv41, pkv50, pkv51, pkv60, pkv61, pkv70, pkv71):
        pkvs = [(pkv00, pkv01), (pkv10, pkv11), (pkv20, pkv21), (pkv30, pkv31), (pkv40, pkv41), (pkv50, pkv51), (pkv60, pkv61), (pkv70, pkv71)]
        new_pkvs = []
        for layer, pkv in zip(self.layers, pkvs):
            outputs = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=(pkv[0], pkv[1]),
                use_cache=True,
            )
            hidden_states = outputs[0]
            new_pkvs.append((outputs[-1][0], outputs[-1][1]))
        ((new_pkv00, new_pkv01), (new_pkv10, new_pkv11), (new_pkv20, new_pkv21), (new_pkv30, new_pkv31), (new_pkv40, new_pkv41), (new_pkv50, new_pkv51), (new_pkv60, new_pkv61), (new_pkv70, new_pkv71)) = new_pkvs
        return hidden_states, new_pkv00, new_pkv01, new_pkv10, new_pkv11, new_pkv20, new_pkv21, new_pkv30, new_pkv31, new_pkv40, new_pkv41, new_pkv50, new_pkv51, new_pkv60, new_pkv61, new_pkv70, new_pkv71
class EightLayerLayerFV(torch.nn.Module):
    def __init__(self, layers):
        super().__init__()
        assert len(layers) == 8
        self.layers = layers

    def forward(self, hidden_states, attention_mask, position_ids):
        new_pkvs = []
        for layer in self.layers:
            outputs = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=None,
                use_cache=True,
            )
            hidden_states = outputs[0]
            new_pkvs.append((outputs[-1][0], outputs[-1][1]))
        ((new_pkv00, new_pkv01), (new_pkv10, new_pkv11), (new_pkv20, new_pkv21), (new_pkv30, new_pkv31), (new_pkv40, new_pkv41), (new_pkv50, new_pkv51), (new_pkv60, new_pkv61), (new_pkv70, new_pkv71)) = new_pkvs
        return hidden_states, new_pkv00, new_pkv01, new_pkv10, new_pkv11, new_pkv20, new_pkv21, new_pkv30, new_pkv31, new_pkv40, new_pkv41, new_pkv50, new_pkv51, new_pkv60, new_pkv61, new_pkv70, new_pkv71
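Both wrapper classes exist because the FX/torch-mlir export path wants a flat tuple of tensor arguments: the 32 decoder layers are grouped into shards of eight, and each shard's eight (key, value) cache pairs are splatted into 16 positional tensors. A minimal sketch of the same flatten/regroup round trip, with hypothetical helper names:

import torch

def flatten_pkv(pkvs):
    # [(k0, v0), ..., (k7, v7)] -> [k0, v0, ..., k7, v7]
    return [t for pair in pkvs for t in pair]

def regroup_pkv(flat):
    # [k0, v0, ..., k7, v7] -> [(k0, v0), ..., (k7, v7)]
    return [(flat[i], flat[i + 1]) for i in range(0, len(flat), 2)]

pkvs = [(torch.zeros(1, 32, 7, 128), torch.zeros(1, 32, 7, 128)) for _ in range(8)]
flat = flatten_pkv(pkvs)
assert len(flat) == 16 and regroup_pkv(flat)[0][0] is pkvs[0][0]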
class CompiledEightLayerLayerSV(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_value,
        output_attentions=False,
        use_cache=True,
    ):
        # detach everything so the compiled module receives plain tensors
        hidden_states = hidden_states.detach()
        attention_mask = attention_mask.detach()
        position_ids = position_ids.detach()
        ((pkv00, pkv01), (pkv10, pkv11), (pkv20, pkv21), (pkv30, pkv31), (pkv40, pkv41), (pkv50, pkv51), (pkv60, pkv61), (pkv70, pkv71)) = past_key_value
        pkv00 = pkv00.detach()
        pkv01 = pkv01.detach()
        pkv10 = pkv10.detach()
        pkv11 = pkv11.detach()
        pkv20 = pkv20.detach()
        pkv21 = pkv21.detach()
        pkv30 = pkv30.detach()
        pkv31 = pkv31.detach()
        pkv40 = pkv40.detach()
        pkv41 = pkv41.detach()
        pkv50 = pkv50.detach()
        pkv51 = pkv51.detach()
        pkv60 = pkv60.detach()
        pkv61 = pkv61.detach()
        pkv70 = pkv70.detach()
        pkv71 = pkv71.detach()
        output = self.model("forward", (hidden_states, attention_mask, position_ids, pkv00, pkv01, pkv10, pkv11, pkv20, pkv21, pkv30, pkv31, pkv40, pkv41, pkv50, pkv51, pkv60, pkv61, pkv70, pkv71))
        return (
            output[0],
            (output[1][0], output[1][1]),
            (output[2][0], output[2][1]),
            (output[3][0], output[3][1]),
            (output[4][0], output[4][1]),
            (output[5][0], output[5][1]),
            (output[6][0], output[6][1]),
            (output[7][0], output[7][1]),
            (output[8][0], output[8][1]),
        )
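The sixteen per-tensor detach statements could be collapsed into one helper; a sketch of an equivalent (hypothetical) utility:

import torch

def detach_all(*tensors):
    # detach each tensor from the autograd graph before it crosses into the
    # compiled module
    return tuple(t.detach() for t in tensors)

a = torch.randn(2, 2, requires_grad=True)
b = torch.randn(2, 2, requires_grad=True)
a, b = detach_all(a, b)
assert not a.requires_grad and not b.requires_grad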
def forward_compressed(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape
    elif inputs_embeds is not None:
        batch_size, seq_length, _ = inputs_embeds.shape
    else:
        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

    seq_length_with_past = seq_length
    past_key_values_length = 0
    if past_key_values is not None:
        past_key_values_length = past_key_values[0][0].shape[2]
        seq_length_with_past = seq_length_with_past + past_key_values_length

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
        )
        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    else:
        position_ids = position_ids.view(-1, seq_length).long()

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)
    # embed positions
    if attention_mask is None:
        attention_mask = torch.ones(
            (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
        )
    attention_mask = self._prepare_decoder_attention_mask(
        attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
    )

    hidden_states = inputs_embeds

    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None

    # each entry of self.compressedlayers wraps eight decoder layers, so the
    # flat past_key_values list is consumed in slices of eight pairs
    for idx, decoder_layer in enumerate(self.compressedlayers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        past_key_value = past_key_values[8 * idx:8 * (idx + 1)] if past_key_values is not None else None

        if self.gradient_checkpointing and self.training:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # None for past_key_value
                    return module(*inputs, output_attentions, None)

                return custom_forward

            layer_outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(decoder_layer),
                hidden_states,
                attention_mask,
                position_ids,
                None,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
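forward_compressed walks the four compressed layers and hands each one eight of the 32 flat (key, value) pairs via past_key_values[8 * idx : 8 * (idx + 1)]. A quick check of that slicing arithmetic with stand-in values:

past_key_values = [(f"k{i}", f"v{i}") for i in range(32)]  # stand-ins for tensors
slices = [past_key_values[8 * idx:8 * (idx + 1)] for idx in range(4)]
assert [len(s) for s in slices] == [8, 8, 8, 8]
assert slices[3][0] == ("k24", "v24")  # shard 3 starts at decoder layer 24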
class CompiledEightLayerLayer(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_value=None,
        output_attentions=False,
        use_cache=True,
    ):
        if past_key_value is None:
            hidden_states = hidden_states.detach()
            attention_mask = attention_mask.detach()
            position_ids = position_ids.detach()
            output = self.model(
                "first_vicuna_forward",
                (hidden_states, attention_mask, position_ids),
            )
            # regroup the 17 flat outputs as (hidden_states, 8 x (key, value))
            return (
                torch.tensor(output[0]),
                (torch.tensor(output[1]), torch.tensor(output[2])),
                (torch.tensor(output[3]), torch.tensor(output[4])),
                (torch.tensor(output[5]), torch.tensor(output[6])),
                (torch.tensor(output[7]), torch.tensor(output[8])),
                (torch.tensor(output[9]), torch.tensor(output[10])),
                (torch.tensor(output[11]), torch.tensor(output[12])),
                (torch.tensor(output[13]), torch.tensor(output[14])),
                (torch.tensor(output[15]), torch.tensor(output[16])),
            )
        else:
            ((pkv00, pkv01), (pkv10, pkv11), (pkv20, pkv21), (pkv30, pkv31), (pkv40, pkv41), (pkv50, pkv51), (pkv60, pkv61), (pkv70, pkv71)) = past_key_value
            try:
                hidden_states = hidden_states.detach()
                attention_mask = attention_mask.detach()
                position_ids = position_ids.detach()
                pkv00 = pkv00.detach()
                pkv01 = pkv01.detach()
                pkv10 = pkv10.detach()
                pkv11 = pkv11.detach()
                pkv20 = pkv20.detach()
                pkv21 = pkv21.detach()
                pkv30 = pkv30.detach()
                pkv31 = pkv31.detach()
                pkv40 = pkv40.detach()
                pkv41 = pkv41.detach()
                pkv50 = pkv50.detach()
                pkv51 = pkv51.detach()
                pkv60 = pkv60.detach()
                pkv61 = pkv61.detach()
                pkv70 = pkv70.detach()
                pkv71 = pkv71.detach()
            except AttributeError:
                pass  # inputs are already host arrays with no detach()
            if isinstance(hidden_states, iree.runtime.array_interop.DeviceArray):
                hidden_states = np.array(hidden_states, hidden_states.dtype)
                hidden_states = torch.tensor(hidden_states)
                hidden_states = hidden_states.detach()
            output = self.model(
                "second_vicuna_forward",
                (hidden_states, attention_mask, position_ids, pkv00, pkv01, pkv10, pkv11, pkv20, pkv21, pkv30, pkv31, pkv40, pkv41, pkv50, pkv51, pkv60, pkv61, pkv70, pkv71),
                send_to_host=False,
            )
            # drop host references to the previous cache as soon as possible
            del pkv00, pkv01, pkv10, pkv11, pkv20, pkv21, pkv30, pkv31
            del pkv40, pkv41, pkv50, pkv51, pkv60, pkv61, pkv70, pkv71
            # send_to_host=False keeps the outputs on device, so they are
            # regrouped here without converting back to torch tensors
            output2 = (
                output[0],
                (output[1], output[2]),
                (output[3], output[4]),
                (output[5], output[6]),
                (output[7], output[8]),
                (output[9], output[10]),
                (output[11], output[12]),
                (output[13], output[14]),
                (output[15], output[16]),
            )
            return output2
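With send_to_host=False the second-pass outputs stay on device as IREE DeviceArrays, so anything that later needs a torch.Tensor must cross the boundary explicitly, as the DeviceArray branch above does for hidden_states. A minimal sketch of that bridge, assuming only that the object supports numpy's array protocol (as iree.runtime.array_interop.DeviceArray does):

import numpy as np
import torch

def to_torch(maybe_device_array):
    # np.array() forces a device-to-host copy for array-protocol objects;
    # torch.from_numpy then wraps the host buffer without another copy
    host = np.array(maybe_device_array)
    return torch.from_numpy(host)

x = to_torch(np.ones((1, 32, 7, 128), dtype=np.float32))
assert isinstance(x, torch.Tensor)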
class ShardedVicuna(SharkLLMBase):
# Class representing Sharded Vicuna Model
@@ -281,7 +605,10 @@ class ShardedVicuna(SharkLLMBase):
vname = vname.strip()
vbody = re.sub("arith.constant", "", vbody)
vbody = vbody.strip()
vdtype = vbody.split(":")[1].strip()
if ":" in vbody:
vdtype = vbody.split(":")[1].strip()
else:
vdtype = vbody.split(" ")[-1].strip()
fixed_vdtype = vdtype
vdtypes.append(vdtype)
vdtype = re.sub(r"\d{1,}x", "?x", vdtype)
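This hunk recovers the element type even when the arith.constant body carries no ":", and the re.sub above (now a raw string, so "\d" is not an invalid escape) rewrites every concrete dimension to "?" so one unpinned MLIR type matches any sequence length. A small demonstration of that substitution:

import re

vdtype = "tensor<1x32x137x128xf32>"
dynamic = re.sub(r"\d{1,}x", "?x", vdtype)
assert dynamic == "tensor<?x?x?x?xf32>"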
@@ -339,19 +666,20 @@ class ShardedVicuna(SharkLLMBase):
hidden_states,
attention_mask,
position_ids,
past_key_value0=None,
past_key_value1=None,
past_key_values=None,
):
# Compile a hidden decoder layer of vicuna
if past_key_value0 is None and past_key_value1 is None:
if past_key_values is None:
model_inputs = (hidden_states, attention_mask, position_ids)
else:
((pkv00, pkv01), (pkv10, pkv11), (pkv20, pkv21), (pkv30, pkv31), (pkv40, pkv41), (pkv50, pkv51), (pkv60, pkv61), (pkv70, pkv71)) = past_key_values
model_inputs = (
hidden_states,
attention_mask,
position_ids,
past_key_value0,
past_key_value1,
pkv00, pkv01, pkv10, pkv11, pkv20, pkv21, pkv30, pkv31,
pkv40, pkv41, pkv50, pkv51, pkv60, pkv61, pkv70, pkv71,
)
mlir_bytecode = import_with_fx(
vicuna_layer,
@@ -414,7 +742,7 @@ class ShardedVicuna(SharkLLMBase):
device=device,
mlir_dialect="tm_tensor",
device_idx=device_idx,
mmap=False,
mmap=True,
)
if vmfb_path.exists():
shark_module.load_module(vmfb_path)
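These hunks flip mmap=False to mmap=True so each shard's VMFB is memory-mapped rather than read wholesale into RAM, which matters when four shards are resident at once. Outside the SharkInference wrapper, roughly the same effect with raw iree.runtime might look like this (a sketch; the path is a placeholder):

import iree.runtime as ireert

config = ireert.Config("local-task")
# VmModule.mmap maps the compiled artifact instead of copying it into memory
vm_module = ireert.VmModule.mmap(config.vm_instance, "/path/to/shard_0.vmfb")
ctx = ireert.SystemContext(vm_modules=[vm_module], config=config)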
@@ -457,7 +785,7 @@ class ShardedVicuna(SharkLLMBase):
device=device,
mlir_dialect="tm_tensor",
device_idx=device_idx,
mmap=False,
mmap=True,
)
if vmfb_path.exists():
shark_module.load_module(vmfb_path)
@@ -499,7 +827,7 @@ class ShardedVicuna(SharkLLMBase):
device=device,
mlir_dialect="tm_tensor",
device_idx=device_idx,
mmap=False,
mmap=True,
)
if vmfb_path.exists():
shark_module.load_module(vmfb_path)
@@ -545,12 +873,24 @@ class ShardedVicuna(SharkLLMBase):
position_ids_placeholder1 = TensorPlaceholder.like(
inputs1[2], dynamic_axes=[1]
)
pkv0_placeholder = TensorPlaceholder.like(
inputs1[3], dynamic_axes=[2]
)
pkv1_placeholder = TensorPlaceholder.like(
inputs1[4], dynamic_axes=[2]
)
pkv00_placeholder = TensorPlaceholder.like(inputs1[3][0][0], dynamic_axes=[2])
pkv01_placeholder = TensorPlaceholder.like(inputs1[3][0][1], dynamic_axes=[2])
pkv10_placeholder = TensorPlaceholder.like(inputs1[3][1][0], dynamic_axes=[2])
pkv11_placeholder = TensorPlaceholder.like(inputs1[3][1][1], dynamic_axes=[2])
pkv20_placeholder = TensorPlaceholder.like(inputs1[3][2][0], dynamic_axes=[2])
pkv21_placeholder = TensorPlaceholder.like(inputs1[3][2][1], dynamic_axes=[2])
pkv30_placeholder = TensorPlaceholder.like(inputs1[3][3][0], dynamic_axes=[2])
pkv31_placeholder = TensorPlaceholder.like(inputs1[3][3][1], dynamic_axes=[2])
pkv40_placeholder = TensorPlaceholder.like(inputs1[3][4][0], dynamic_axes=[2])
pkv41_placeholder = TensorPlaceholder.like(inputs1[3][4][1], dynamic_axes=[2])
pkv50_placeholder = TensorPlaceholder.like(inputs1[3][5][0], dynamic_axes=[2])
pkv51_placeholder = TensorPlaceholder.like(inputs1[3][5][1], dynamic_axes=[2])
pkv60_placeholder = TensorPlaceholder.like(inputs1[3][6][0], dynamic_axes=[2])
pkv61_placeholder = TensorPlaceholder.like(inputs1[3][6][1], dynamic_axes=[2])
pkv70_placeholder = TensorPlaceholder.like(inputs1[3][7][0], dynamic_axes=[2])
pkv71_placeholder = TensorPlaceholder.like(inputs1[3][7][1], dynamic_axes=[2])
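Each placeholder marks axis 2, the cached-sequence-length dimension of a [batch, heads, seq, head_dim] cache tensor, as dynamic so one compiled module serves every decode step. The sixteen lines could also be generated in a comprehension; a self-contained sketch with stand-in shapes:

import torch
from torch_mlir import TensorPlaceholder

pkv_pairs = tuple(
    (torch.zeros(1, 32, 7, 128), torch.zeros(1, 32, 7, 128)) for _ in range(8)
)
# axis 2 (cached sequence length) is the only dimension that grows per step
pkv_placeholders = [
    TensorPlaceholder.like(t, dynamic_axes=[2]) for pair in pkv_pairs for t in pair
]
assert len(pkv_placeholders) == 16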
print(f"Compiling layer {idx} mlir")
ts_g = self.compile_vicuna_layer(
@@ -596,8 +936,7 @@ class ShardedVicuna(SharkLLMBase):
inputs1[0],
inputs1[1],
inputs1[2],
inputs1[3],
inputs1[4],
inputs1[3]
)
if self.precision in ["int4", "int8"]:
module1 = torch_mlir.compile(
@@ -606,8 +945,8 @@ class ShardedVicuna(SharkLLMBase):
inputs1[0],
attention_mask_placeholder1,
inputs1[2],
pkv0_placeholder,
pkv1_placeholder,
pkv00_placeholder, pkv01_placeholder, pkv10_placeholder, pkv11_placeholder,
pkv20_placeholder, pkv21_placeholder, pkv30_placeholder, pkv31_placeholder,
pkv40_placeholder, pkv41_placeholder, pkv50_placeholder, pkv51_placeholder,
pkv60_placeholder, pkv61_placeholder, pkv70_placeholder, pkv71_placeholder,
),
output_type="torch",
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
@@ -629,8 +968,7 @@ class ShardedVicuna(SharkLLMBase):
inputs1[0],
attention_mask_placeholder1,
inputs1[2],
pkv0_placeholder,
pkv1_placeholder,
pkv00_placeholder, pkv01_placeholder, pkv10_placeholder, pkv11_placeholder,
pkv20_placeholder, pkv21_placeholder, pkv30_placeholder, pkv31_placeholder,
pkv40_placeholder, pkv41_placeholder, pkv50_placeholder, pkv51_placeholder,
pkv60_placeholder, pkv61_placeholder, pkv70_placeholder, pkv71_placeholder,
),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
@@ -653,7 +991,7 @@ class ShardedVicuna(SharkLLMBase):
device=device,
device_idx=idx % 4,
mlir_dialect="tm_tensor",
mmap=False,
mmap=True,
)
module.load_module(vmfb_path)
else:
@@ -666,7 +1004,7 @@ class ShardedVicuna(SharkLLMBase):
device=device,
device_idx=idx % 4,
mlir_dialect="tm_tensor",
mmap=False,
mmap=True,
)
module.save_module(
module_name=f"{idx}_full",
@@ -695,7 +1033,7 @@ class ShardedVicuna(SharkLLMBase):
weight_quant_type="asym",
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_type="float",
weight_scale_precision="float",
weight_quant_granularity="per_group",
weight_group_size=self.weight_group_size,
quantize_weight_zero_point=False,
@@ -709,6 +1047,9 @@ class ShardedVicuna(SharkLLMBase):
)
print("Weight quantization applied.")
placeholder_pkv_segment = tuple(
    (
        torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
        torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
    )
    for _ in range(8)
)
placeholder_pkv_full = tuple(
    (
        torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
        torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
    )
    for _ in range(32)
)
placeholder_input0 = (
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]),
@@ -723,6 +1064,13 @@ class ShardedVicuna(SharkLLMBase):
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
)
placeholder_input2 = (
torch.zeros([1, 1, 4096]),
torch.zeros([1, 1, 1, SAMPLE_INPUT_LEN + 1]),
torch.zeros([1, 1], dtype=torch.int64),
placeholder_pkv_segment
)
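All of these placeholders share the [1, 32, SAMPLE_INPUT_LEN, 128] layout, i.e. (batch, attention heads, sequence length, head dim) for a 7B-class LLaMA. A hypothetical helper that would build any of them:

import torch

def make_placeholder_pkv(n_pairs, seq_len, heads=32, head_dim=128):
    # one (key, value) zero-tensor pair per decoder layer covered
    return tuple(
        (
            torch.zeros([1, heads, seq_len, head_dim]),
            torch.zeros([1, heads, seq_len, head_dim]),
        )
        for _ in range(n_pairs)
    )

pkv_segment = make_placeholder_pkv(8, 7)   # one 8-layer shard
pkv_full = make_placeholder_pkv(32, 7)     # all 32 decoder layers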
norm = VicunaNorm(vicuna_model.model.norm)
device_idx = self.get_device_index(
r"vicuna\.model\.model\.norm(?:\.|\s|$)"
@@ -765,14 +1113,28 @@ class ShardedVicuna(SharkLLMBase):
layers1 = [
SecondVicunaLayer(layer) for layer in vicuna_model.model.layers
]
layers00 = EightLayerLayerFV(vicuna_model.model.layers[0:8])
layers01 = EightLayerLayerFV(vicuna_model.model.layers[8:16])
layers02 = EightLayerLayerFV(vicuna_model.model.layers[16:24])
layers03 = EightLayerLayerFV(vicuna_model.model.layers[24:32])
layers10 = EightLayerLayerSV(vicuna_model.model.layers[0:8])
layers11 = EightLayerLayerSV(vicuna_model.model.layers[8:16])
layers12 = EightLayerLayerSV(vicuna_model.model.layers[16:24])
layers13 = EightLayerLayerSV(vicuna_model.model.layers[24:32])
layers0 = [layers00, layers01, layers02, layers03]
layers1 = [layers10, layers11, layers12, layers13]
#vicuna_model.model.forward = forward_compressed
_, modules = self.compile_to_vmfb_one_model(
placeholder_input0,
layers0,
placeholder_input1,
placeholder_input2,
layers1,
device=device,
)
shark_layers = [CompiledVicunaLayer(m) for m in modules]
shark_layers = [CompiledEightLayerLayer(m) for m in modules]
vicuna_model.model.compressedlayers = shark_layers
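Earlier hunks compile each shard with device_idx=idx % 4, so the four 8-layer shards land on four devices round-robin; the mapping is simply:

num_devices = 4
for idx in range(4):  # four 8-layer shards
    print(f"shard {idx} -> device {idx % num_devices}")
# shard 0 -> device 0, ..., shard 3 -> device 3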
sharded_model = ShardedVicunaModel(
vicuna_model,
@@ -823,6 +1185,9 @@ class ShardedVicuna(SharkLLMBase):
if is_first:
prompt = params["prompt"]
input_ids = self.tokenizer(prompt).input_ids
# Debug aid: crop the prompt to its last 20 tokens to bound sequence length
input_ids = input_ids[len(input_ids) - 20:]
input_id_len = len(input_ids)
input_ids = torch.tensor(input_ids)
input_ids = input_ids.reshape([1, input_id_len])
@@ -1548,7 +1913,7 @@ if __name__ == "__main__":
config_json=config_json,
weight_group_size=args.weight_group_size,
)
prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives unhelpful, detailed, and rude answers to the user's questions.\n"
prologue_prompt = "ASSISTANT:\n"
while True:

View File

@@ -66,7 +66,7 @@ class ShardedVicunaModel(torch.nn.Module):
def __init__(self, model, layers, lmhead, embedding, norm):
super().__init__()
self.model = model
assert len(layers) == len(model.model.layers)
#assert len(layers) == len(model.model.layers)
self.model.model.config.use_cache = True
self.model.model.config.output_attentions = False
self.layers = layers
@@ -132,7 +132,10 @@ class VicunaNormCompiled(torch.nn.Module):
self.model = shark_module
def forward(self, hidden_states):
hidden_states.detach()
try:
    hidden_states = hidden_states.detach()
except AttributeError:
    pass  # hidden_states may already be a host array with no detach()
output = self.model("forward", (hidden_states,))
output = torch.tensor(output)
return output