From bce652767c8a7209fbe1b6f9bec0e73fe995d2c9 Mon Sep 17 00:00:00 2001 From: Elias Joseph Date: Fri, 21 Jul 2023 18:28:24 -0400 Subject: [PATCH] branch to run vicuna with 4 shards --- apps/language_models/scripts/vicuna.py | 421 ++++++++++++++++-- .../model_wrappers/vicuna_sharded_model.py | 7 +- 2 files changed, 398 insertions(+), 30 deletions(-) diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py index 73afe840..3f3ea281 100644 --- a/apps/language_models/scripts/vicuna.py +++ b/apps/language_models/scripts/vicuna.py @@ -4,7 +4,9 @@ import re from io import BytesIO from pathlib import Path from tqdm import tqdm -from typing import List, Tuple +from typing import List, Optional, Tuple, Union +import numpy as np +import iree.runtime import torch import torch_mlir @@ -127,6 +129,328 @@ brevitas_matmul_rhs_group_quant_library = [ brevitas〇matmul_rhs_group_quant〡dtype, brevitas〇matmul_rhs_group_quant〡has_value_semantics] +class EightLayerLayerSV(torch.nn.Module): + + def __init__(self, layers): + super().__init__() + assert(len(layers) == 8) + self.layers = layers + + def forward(self, hidden_states, attention_mask, position_ids, pkv00, pkv01, pkv10, pkv11, pkv20, pkv21, pkv30, pkv31, pkv40, pkv41, pkv50, pkv51, pkv60, pkv61, pkv70, pkv71): + pkvs = [(pkv00, pkv01), (pkv10, pkv11), (pkv20, pkv21), (pkv30, pkv31), (pkv40, pkv41), (pkv50, pkv51), (pkv60, pkv61), (pkv70, pkv71)] + new_pkvs = [] + for layer, pkv in zip(self.layers, pkvs): + outputs = layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=( + pkv[0], + pkv[1], + ), + use_cache=True, + ) + + hidden_states = outputs[0] + new_pkvs.append((outputs[-1][0], outputs[-1][1], )) + ((new_pkv00, new_pkv01), (new_pkv10, new_pkv11), (new_pkv20, new_pkv21), (new_pkv30, new_pkv31), (new_pkv40, new_pkv41), (new_pkv50, new_pkv51), (new_pkv60, new_pkv61), (new_pkv70, new_pkv71)) = new_pkvs + return hidden_states, new_pkv00, new_pkv01, new_pkv10, new_pkv11, new_pkv20, new_pkv21, new_pkv30, new_pkv31, new_pkv40, new_pkv41, new_pkv50, new_pkv51, new_pkv60, new_pkv61, new_pkv70, new_pkv71 + +class EightLayerLayerFV(torch.nn.Module): + + def __init__(self, layers): + super().__init__() + assert(len(layers) == 8) + self.layers = layers + + def forward(self, hidden_states, attention_mask, position_ids): + new_pkvs = [] + for layer in self.layers: + outputs = layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=None, + use_cache=True, + ) + + hidden_states = outputs[0] + new_pkvs.append((outputs[-1][0], outputs[-1][1], )) + ((new_pkv00, new_pkv01), (new_pkv10, new_pkv11), (new_pkv20, new_pkv21), (new_pkv30, new_pkv31), (new_pkv40, new_pkv41), (new_pkv50, new_pkv51), (new_pkv60, new_pkv61), (new_pkv70, new_pkv71)) = new_pkvs + return hidden_states, new_pkv00, new_pkv01, new_pkv10, new_pkv11, new_pkv20, new_pkv21, new_pkv30, new_pkv31, new_pkv40, new_pkv41, new_pkv50, new_pkv51, new_pkv60, new_pkv61, new_pkv70, new_pkv71 + + +class CompiledEightLayerLayerSV(torch.nn.Module): + + def __init__(self, model): + super().__init__() + self.model = model + + def forward( + self, + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions=False, + use_cache=True, + ): + hidden_states = hidden_states.detach() + attention_mask = attention_mask.detach() + position_ids = position_ids.detach() + ((pkv00, pkv01), (pkv10, pkv11), (pkv20, pkv21), (pkv30, pkv31), (pkv40, pkv41), (pkv50, pkv51), (pkv60, pkv61), 
(pkv70, pkv71)) = past_key_value
+        pkv00 = pkv00.detach()
+        pkv01 = pkv01.detach()
+        pkv10 = pkv10.detach()
+        pkv11 = pkv11.detach()
+        pkv20 = pkv20.detach()
+        pkv21 = pkv21.detach()
+        pkv30 = pkv30.detach()
+        pkv31 = pkv31.detach()
+        pkv40 = pkv40.detach()
+        pkv41 = pkv41.detach()
+        pkv50 = pkv50.detach()
+        pkv51 = pkv51.detach()
+        pkv60 = pkv60.detach()
+        pkv61 = pkv61.detach()
+        pkv70 = pkv70.detach()
+        pkv71 = pkv71.detach()
+
+        output = self.model("forward", (hidden_states, attention_mask, position_ids, pkv00, pkv01, pkv10, pkv11, pkv20, pkv21, pkv30, pkv31, pkv40, pkv41, pkv50, pkv51, pkv60, pkv61, pkv70, pkv71))
+        return (output[0], (output[1][0], output[1][1]), (output[2][0], output[2][1]), (output[3][0], output[3][1]), (output[4][0], output[4][1]), (output[5][0], output[5][1]), (output[6][0], output[6][1]), (output[7][0], output[7][1]), (output[8][0], output[8][1]),)
+
+def forward_compressed(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    ):
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # retrieve input_ids and inputs_embeds
+    if input_ids is not None and inputs_embeds is not None:
+        raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+    elif input_ids is not None:
+        batch_size, seq_length = input_ids.shape
+    elif inputs_embeds is not None:
+        batch_size, seq_length, _ = inputs_embeds.shape
+    else:
+        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+    seq_length_with_past = seq_length
+    past_key_values_length = 0
+
+    if past_key_values is not None:
+        past_key_values_length = past_key_values[0][0].shape[2]
+        seq_length_with_past = seq_length_with_past + past_key_values_length
+
+    if position_ids is None:
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+        position_ids = torch.arange(
+            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+        )
+        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+    else:
+        position_ids = position_ids.view(-1, seq_length).long()
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+    # embed positions
+    if attention_mask is None:
+        attention_mask = torch.ones(
+            (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+        )
+    attention_mask = self._prepare_decoder_attention_mask(
+        attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+    )
+
+    hidden_states = inputs_embeds
+
+    if self.gradient_checkpointing and self.training:
+        if use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.compressedlayers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[8 * idx:8 * (idx + 1)] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) +from time import time +class CompiledEightLayerLayer(torch.nn.Module): + + def __init__(self, model): + super().__init__() + self.model = model + + def forward( + self, + hidden_states, + attention_mask, + position_ids, + past_key_value = None, + output_attentions=False, + use_cache=True, + ): + t2 = time() + if past_key_value is None: + hidden_states = hidden_states.detach() + attention_mask = attention_mask.detach() + position_ids = position_ids.detach() + t1 = time() + + output = self.model("first_vicuna_forward", (hidden_states, attention_mask, position_ids)) + #output2 = (output[0], (output[1], output[2],), (output[3], output[4],),(output[5], output[6],),(output[7], output[8],),(output[9],output[10],),(output[11], output[12],),(output[13], output[14],),(output[15], output[16],),) + #return output2 + return (torch.tensor(output[0]), (torch.tensor(output[1]), torch.tensor(output[2]),), (torch.tensor(output[3]), torch.tensor(output[4]),),(torch.tensor(output[5]), torch.tensor(output[6]),),(torch.tensor(output[7]), torch.tensor(output[8]),),(torch.tensor(output[9]), torch.tensor(output[10]),),(torch.tensor(output[11]), torch.tensor(output[12]),),(torch.tensor(output[13]), torch.tensor(output[14]),),(torch.tensor(output[15]), torch.tensor(output[16]),),) + else: + ((pkv00, pkv01), (pkv10, pkv11), (pkv20, pkv21), (pkv30, pkv31), (pkv40, pkv41), (pkv50, pkv51), (pkv60, pkv61), (pkv70, pkv71)) = past_key_value + + try: + hidden_states = hidden_states.detach() + attention_mask = attention_mask.detach() + position_ids = position_ids.detach() + pkv00 = pkv00.detach() + pkv01 = pkv01.detach() + pkv10 = pkv10.detach() + pkv11 = pkv11.detach() + pkv20 = pkv20.detach() + pkv21 = pkv21.detach() + pkv30 = pkv30.detach() + pkv31 = pkv31.detach() + pkv40 = pkv40.detach() + pkv41 = 
pkv41.detach()
+                pkv50 = pkv50.detach()
+                pkv51 = pkv51.detach()
+                pkv60 = pkv60.detach()
+                pkv61 = pkv61.detach()
+                pkv70 = pkv70.detach()
+                pkv71 = pkv71.detach()
+            except AttributeError:  # inputs may already be IREE device arrays without .detach()
+                pass
+
+            t1 = time()
+            if isinstance(hidden_states, iree.runtime.array_interop.DeviceArray):
+                hidden_states = np.array(hidden_states, hidden_states.dtype)
+                hidden_states = torch.tensor(hidden_states)
+                hidden_states = hidden_states.detach()
+
+
+            output = self.model("second_vicuna_forward", (hidden_states, attention_mask, position_ids, pkv00, pkv01, pkv10, pkv11, pkv20, pkv21, pkv30, pkv31, pkv40, pkv41, pkv50, pkv51, pkv60, pkv61, pkv70, pkv71), send_to_host=False)
+            del pkv00
+            del pkv01
+            del pkv10
+            del pkv11
+            del pkv20
+            del pkv21
+            del pkv30
+            del pkv31
+            del pkv40
+            del pkv41
+            del pkv50
+            del pkv51
+            del pkv60
+            del pkv61
+            del pkv70
+            del pkv71
+            #print(f"sv0 pass completed in {time() - t2} seconds")
+            """
+            try:
+                pkv00 = np.asarray(pkv00, pkv00.dtype)
+                pkv01 = np.asarray(pkv01, pkv01.dtype)
+                pkv10 = np.asarray(pkv10, pkv10.dtype)
+                pkv11 = np.asarray(pkv11, pkv11.dtype)
+                pkv20 = np.asarray(pkv20, pkv20.dtype)
+                pkv21 = np.asarray(pkv21, pkv21.dtype)
+                pkv30 = np.asarray(pkv30, pkv30.dtype)
+                pkv31 = np.asarray(pkv31, pkv31.dtype)
+                pkv40 = np.asarray(pkv40, pkv40.dtype)
+                pkv41 = np.asarray(pkv41, pkv41.dtype)
+                pkv50 = np.asarray(pkv50, pkv50.dtype)
+                pkv51 = np.asarray(pkv51, pkv51.dtype)
+                pkv60 = np.asarray(pkv60, pkv60.dtype)
+                pkv61 = np.asarray(pkv61, pkv61.dtype)
+                pkv70 = np.asarray(pkv70, pkv70.dtype)
+                pkv71 = np.asarray(pkv71, pkv71.dtype)
+                print("iree arrays converted")
+            except:
+                x = 10
+            """
+            output2 = (output[0], (output[1], output[2],), (output[3], output[4],), (output[5], output[6],), (output[7], output[8],), (output[9], output[10],), (output[11], output[12],), (output[13], output[14],), (output[15], output[16],),)
+            #output2 = (torch.tensor(output[0]), (torch.tensor(output[1]), torch.tensor(output[2]),), (torch.tensor(output[3]), torch.tensor(output[4]),), (torch.tensor(output[5]), torch.tensor(output[6]),), (torch.tensor(output[7]), torch.tensor(output[8]),), (torch.tensor(output[9]), torch.tensor(output[10]),), (torch.tensor(output[11]), torch.tensor(output[12]),), (torch.tensor(output[13]), torch.tensor(output[14]),), (torch.tensor(output[15]), torch.tensor(output[16]),),)
+            #print(output2[1][0])
+            return output2
+
+
 class ShardedVicuna(SharkLLMBase):
     # Class representing Sharded Vicuna Model
@@ -281,7 +605,10 @@ class ShardedVicuna(SharkLLMBase):
             vname = vname.strip()
             vbody = re.sub("arith.constant", "", vbody)
             vbody = vbody.strip()
-            vdtype = vbody.split(":")[1].strip()
+            if ":" in vbody:
+                vdtype = vbody.split(":")[1].strip()
+            else:
+                vdtype = vbody.split(" ")[-1].strip()
             fixed_vdtype = vdtype
             vdtypes.append(vdtype)
             vdtype = re.sub("\d{1,}x", "?x", vdtype)
@@ -339,19 +666,20 @@ class ShardedVicuna(SharkLLMBase):
             hidden_states,
             attention_mask,
             position_ids,
-            past_key_value0=None,
-            past_key_value1=None,
+            past_key_values=None,
         ):
             # Compile a hidden decoder layer of vicuna
-            if past_key_value0 is None and past_key_value1 is None:
+            if past_key_values is None:
                 model_inputs = (hidden_states, attention_mask, position_ids)
             else:
+                ((pkv00, pkv01), (pkv10, pkv11), (pkv20, pkv21), (pkv30, pkv31), (pkv40, pkv41), (pkv50, pkv51), (pkv60, pkv61), (pkv70, pkv71)) = past_key_values
+
+
                 model_inputs = (
                     hidden_states,
                     attention_mask,
                     position_ids,
-                    past_key_value0,
-                    past_key_value1,
+                    pkv00, pkv01, pkv10, pkv11, pkv20, pkv21, pkv30, pkv31, pkv40, pkv41, pkv50, pkv51, pkv60, pkv61, pkv70, pkv71
) mlir_bytecode = import_with_fx( vicuna_layer, @@ -414,7 +742,7 @@ class ShardedVicuna(SharkLLMBase): device=device, mlir_dialect="tm_tensor", device_idx=device_idx, - mmap=False, + mmap=True, ) if vmfb_path.exists(): shark_module.load_module(vmfb_path) @@ -457,7 +785,7 @@ class ShardedVicuna(SharkLLMBase): device=device, mlir_dialect="tm_tensor", device_idx=device_idx, - mmap=False, + mmap=True, ) if vmfb_path.exists(): shark_module.load_module(vmfb_path) @@ -499,7 +827,7 @@ class ShardedVicuna(SharkLLMBase): device=device, mlir_dialect="tm_tensor", device_idx=device_idx, - mmap=False, + mmap=True, ) if vmfb_path.exists(): shark_module.load_module(vmfb_path) @@ -545,12 +873,24 @@ class ShardedVicuna(SharkLLMBase): position_ids_placeholder1 = TensorPlaceholder.like( inputs1[2], dynamic_axes=[1] ) - pkv0_placeholder = TensorPlaceholder.like( - inputs1[3], dynamic_axes=[2] - ) - pkv1_placeholder = TensorPlaceholder.like( - inputs1[4], dynamic_axes=[2] - ) + + + pkv00_placeholder = TensorPlaceholder.like(inputs1[3][0][0], dynamic_axes=[2]) + pkv01_placeholder = TensorPlaceholder.like(inputs1[3][0][1], dynamic_axes=[2]) + pkv10_placeholder = TensorPlaceholder.like(inputs1[3][1][0], dynamic_axes=[2]) + pkv11_placeholder = TensorPlaceholder.like(inputs1[3][1][1], dynamic_axes=[2]) + pkv20_placeholder = TensorPlaceholder.like(inputs1[3][2][0], dynamic_axes=[2]) + pkv21_placeholder = TensorPlaceholder.like(inputs1[3][2][1], dynamic_axes=[2]) + pkv30_placeholder = TensorPlaceholder.like(inputs1[3][3][0], dynamic_axes=[2]) + pkv31_placeholder = TensorPlaceholder.like(inputs1[3][3][1], dynamic_axes=[2]) + pkv40_placeholder = TensorPlaceholder.like(inputs1[3][4][0], dynamic_axes=[2]) + pkv41_placeholder = TensorPlaceholder.like(inputs1[3][4][1], dynamic_axes=[2]) + pkv50_placeholder = TensorPlaceholder.like(inputs1[3][5][0], dynamic_axes=[2]) + pkv51_placeholder = TensorPlaceholder.like(inputs1[3][5][1], dynamic_axes=[2]) + pkv60_placeholder = TensorPlaceholder.like(inputs1[3][6][0], dynamic_axes=[2]) + pkv61_placeholder = TensorPlaceholder.like(inputs1[3][6][1], dynamic_axes=[2]) + pkv70_placeholder = TensorPlaceholder.like(inputs1[3][7][0], dynamic_axes=[2]) + pkv71_placeholder = TensorPlaceholder.like(inputs1[3][7][1], dynamic_axes=[2]) print(f"Compiling layer {idx} mlir") ts_g = self.compile_vicuna_layer( @@ -596,8 +936,7 @@ class ShardedVicuna(SharkLLMBase): inputs1[0], inputs1[1], inputs1[2], - inputs1[3], - inputs1[4], + inputs1[3] ) if self.precision in ["int4", "int8"]: module1 = torch_mlir.compile( @@ -606,8 +945,8 @@ class ShardedVicuna(SharkLLMBase): inputs1[0], attention_mask_placeholder1, inputs1[2], - pkv0_placeholder, - pkv1_placeholder, + pkv00_placeholder, pkv01_placeholder, pkv10_placeholder, pkv11_placeholder, pkv20_placeholder, pkv21_placeholder,pkv30_placeholder, pkv31_placeholder,pkv40_placeholder, pkv41_placeholder, pkv50_placeholder, pkv51_placeholder, pkv60_placeholder, pkv61_placeholder, pkv70_placeholder, pkv71_placeholder + ), output_type="torch", backend_legal_ops=["brevitas.matmul_rhs_group_quant"], @@ -629,8 +968,7 @@ class ShardedVicuna(SharkLLMBase): inputs1[0], attention_mask_placeholder1, inputs1[2], - pkv0_placeholder, - pkv1_placeholder, + pkv00_placeholder, pkv01_placeholder, pkv10_placeholder, pkv11_placeholder, pkv20_placeholder, pkv21_placeholder,pkv30_placeholder, pkv31_placeholder,pkv40_placeholder, pkv41_placeholder, pkv50_placeholder, pkv51_placeholder, pkv60_placeholder, pkv61_placeholder, pkv70_placeholder, pkv71_placeholder ), 
torch_mlir.OutputType.LINALG_ON_TENSORS, use_tracing=False, @@ -653,7 +991,7 @@ class ShardedVicuna(SharkLLMBase): device=device, device_idx=idx % 4, mlir_dialect="tm_tensor", - mmap=False, + mmap=True, ) module.load_module(vmfb_path) else: @@ -666,7 +1004,7 @@ class ShardedVicuna(SharkLLMBase): device=device, device_idx=idx % 4, mlir_dialect="tm_tensor", - mmap=False, + mmap=True, ) module.save_module( module_name=f"{idx}_full", @@ -695,7 +1033,7 @@ class ShardedVicuna(SharkLLMBase): weight_quant_type="asym", weight_bit_width=weight_bit_width, weight_param_method="stats", - weight_scale_type="float", + weight_scale_precision="float", weight_quant_granularity="per_group", weight_group_size=self.weight_group_size, quantize_weight_zero_point=False, @@ -709,6 +1047,9 @@ class ShardedVicuna(SharkLLMBase): ) print("Weight quantization applied.") + placeholder_pkv_segment = tuple((torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),) for _ in range(8)) + placeholder_pkv_full = tuple((torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),) for _ in range(32)) + placeholder_input0 = ( torch.zeros([1, SAMPLE_INPUT_LEN, 4096]), torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]), @@ -723,6 +1064,13 @@ class ShardedVicuna(SharkLLMBase): torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]), ) + placeholder_input2 = ( + torch.zeros([1, 1, 4096]), + torch.zeros([1, 1, 1, SAMPLE_INPUT_LEN + 1]), + torch.zeros([1, 1], dtype=torch.int64), + placeholder_pkv_segment + ) + norm = VicunaNorm(vicuna_model.model.norm) device_idx = self.get_device_index( r"vicuna\.model\.model\.norm(?:\.|\s|$)" @@ -765,14 +1113,28 @@ class ShardedVicuna(SharkLLMBase): layers1 = [ SecondVicunaLayer(layer) for layer in vicuna_model.model.layers ] + + layers00 = EightLayerLayerFV(vicuna_model.model.layers[0:8]) + layers01 = EightLayerLayerFV(vicuna_model.model.layers[8:16]) + layers02 = EightLayerLayerFV(vicuna_model.model.layers[16:24]) + layers03 = EightLayerLayerFV(vicuna_model.model.layers[24:32]) + layers10 = EightLayerLayerSV(vicuna_model.model.layers[0:8]) + layers11 = EightLayerLayerSV(vicuna_model.model.layers[8:16]) + layers12 = EightLayerLayerSV(vicuna_model.model.layers[16:24]) + layers13 = EightLayerLayerSV(vicuna_model.model.layers[24:32]) + layers0 = [layers00, layers01, layers02, layers03] + layers1 = [layers10, layers11, layers12, layers13] + #vicuna_model.model.forward = forward_compressed + _, modules = self.compile_to_vmfb_one_model( placeholder_input0, layers0, - placeholder_input1, + placeholder_input2, layers1, device=device, ) - shark_layers = [CompiledVicunaLayer(m) for m in modules] + shark_layers = [CompiledEightLayerLayer(m) for m in modules] + vicuna_model.model.compressedlayers = shark_layers sharded_model = ShardedVicunaModel( vicuna_model, @@ -823,6 +1185,9 @@ class ShardedVicuna(SharkLLMBase): if is_first: prompt = params["prompt"] input_ids = self.tokenizer(prompt).input_ids + #crop input_ids + input_ids = input_ids[len(input_ids) - 20:] + ############ input_id_len = len(input_ids) input_ids = torch.tensor(input_ids) input_ids = input_ids.reshape([1, input_id_len]) @@ -1548,7 +1913,7 @@ if __name__ == "__main__": config_json=config_json, weight_group_size=args.weight_group_size, ) - prompt_history = "A chat between a curious user and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
+    prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
     prologue_prompt = "ASSISTANT:\n"
 
     while True:
diff --git a/apps/language_models/src/model_wrappers/vicuna_sharded_model.py b/apps/language_models/src/model_wrappers/vicuna_sharded_model.py
index 1a46c3b5..5d132969 100644
--- a/apps/language_models/src/model_wrappers/vicuna_sharded_model.py
+++ b/apps/language_models/src/model_wrappers/vicuna_sharded_model.py
@@ -66,7 +66,7 @@ class ShardedVicunaModel(torch.nn.Module):
     def __init__(self, model, layers, lmhead, embedding, norm):
         super().__init__()
         self.model = model
-        assert len(layers) == len(model.model.layers)
+        # assert len(layers) == len(model.model.layers)  # disabled: layers are now grouped into 8-layer shards
         self.model.model.config.use_cache = True
         self.model.model.config.output_attentions = False
         self.layers = layers
@@ -132,7 +132,10 @@ class VicunaNormCompiled(torch.nn.Module):
         self.model = shark_module
 
     def forward(self, hidden_states):
-        hidden_states.detach()
+        try:
+            hidden_states.detach()
+        except AttributeError:  # hidden_states may be an IREE array without .detach()
+            pass
         output = self.model("forward", (hidden_states,))
         output = torch.tensor(output)
         return output