add codegen support in vic pipeline

PhaneeshB
2023-07-11 19:09:30 +05:30
committed by Phaneesh Barwaria
parent be417f0bf4
commit 1c7eecc981

@@ -150,10 +150,14 @@ class ShardedVicuna(SharkLLMBase):
         self.shark_model = self.compile(device=device)

     def get_tokenizer(self):
-        # Retrieve the tokenizer from Huggingface
-        tokenizer = AutoTokenizer.from_pretrained(
-            self.hf_model_path, use_fast=False
-        )
+        if self.model_name == "codegen":
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.hf_model_path, trust_remote_code=True,
+            )
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.hf_model_path, use_fast=False
+            )
         return tokenizer

     def get_src_model(self):
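
Note: the hunk above switches the tokenizer load path based on the model name. A minimal standalone sketch of the same selection logic, outside the class (the helper name and the checkpoint path below are illustrative assumptions, not part of the commit):

```python
from transformers import AutoTokenizer

def load_tokenizer(model_name, hf_model_path):
    # Mirrors the branch added above: codegen checkpoints may ship custom
    # tokenizer code on the Hub, which requires opting in via
    # trust_remote_code; all other models keep the slow-tokenizer path.
    if model_name == "codegen":
        return AutoTokenizer.from_pretrained(
            hf_model_path, trust_remote_code=True
        )
    return AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)

# Hypothetical usage; the checkpoint path is an assumed example.
tokenizer = load_tokenizer("codegen", "Salesforce/codegen-2B-mono")
print(type(tokenizer).__name__)
```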
@@ -786,7 +790,7 @@ class ShardedVicuna(SharkLLMBase):
     def compile(self, device="cpu"):
         return self.get_sharded_model(device=device)

-    def generate(self, prompt, cli=False):
+    def generate(self, prompt, cli=True):
         # TODO: refactor for cleaner integration
         tokens_generated = []
@@ -909,9 +913,14 @@ class UnshardedVicuna(SharkLLMBase):
         )

     def get_tokenizer(self):
-        tokenizer = AutoTokenizer.from_pretrained(
-            self.hf_model_path, use_fast=False
-        )
+        if self.model_name == "codegen":
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.hf_model_path, trust_remote_code=True,
+            )
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.hf_model_path, use_fast=False
+            )
         return tokenizer

     def get_src_model(self):
@@ -961,9 +970,16 @@ class UnshardedVicuna(SharkLLMBase):
         )
         if not mlir_generated:
-            compilation_prompt = "".join(["0" for _ in range(17)])
+            # Select a compilation prompt such that the resulting input_ids
+            # from the model's tokenizer has shape [1, 19]
+            if self.model_name == "codegen":
+                compilation_prompt = "def hello_world():\n print('Hello World')\n print('Hello World')"
+            else:
+                compilation_prompt = "".join(["0" for _ in range(17)])
             compilation_input_ids = self.tokenizer(
-                compilation_prompt
+                compilation_prompt,
+                return_tensors="pt",
             ).input_ids
             compilation_input_ids = torch.tensor(
                 compilation_input_ids
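
Note: the constraint here is that the statically compiled module expects a fixed-size input, so whichever prompt is chosen must tokenize to exactly 19 tokens (shape [1, 19]). A quick way to sanity-check a candidate prompt; the checkpoint path below is an assumed example, and the exact token count depends on the tokenizer's vocabulary and whitespace handling:

```python
from transformers import AutoTokenizer

# Check that a candidate compilation prompt tokenizes to shape [1, 19],
# as the comment in the hunk above requires.
tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-2B-mono")
prompt = "def hello_world():\n print('Hello World')\n print('Hello World')"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
print(input_ids.shape)  # the pipeline requires torch.Size([1, 19])
```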
@@ -1334,10 +1350,11 @@ class UnshardedVicuna(SharkLLMBase):
             if type(res_tokens[i]) != int:
                 res_tokens[i] = int(res_tokens[i][0])

-        res_str = self.tokenizer.decode(res_tokens)
+        skip_sp_tok = True if self.model_name == "codegen" else False
+        res_str = self.tokenizer.decode(res_tokens, skip_special_tokens=skip_sp_tok)
         return res_str

-    def generate(self, prompt, cli=False):
+    def generate(self, prompt, cli=True):
         # TODO: refactor for cleaner integration

         import gc
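
Note: decoding with skip_special_tokens=True matters for code models because markers such as <|endoftext|> would otherwise appear inside the returned source text. A small illustration, using the GPT-2 tokenizer as a stand-in for a codegen-style vocabulary (the checkpoint choice is an assumption):

```python
from transformers import AutoTokenizer

# GPT-2 used as a stand-in; codegen checkpoints use a GPT-2-style vocab.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokens = tokenizer("print('hi')").input_ids + [tokenizer.eos_token_id]

print(tokenizer.decode(tokens))
# -> print('hi')<|endoftext|>
print(tokenizer.decode(tokens, skip_special_tokens=True))
# -> print('hi')
```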
@@ -1481,7 +1498,8 @@ class UnshardedVicuna(SharkLLMBase):
                 load_inputs=False,
             )
-            detok = self.tokenizer.decode(token)
+            skip_sp_tok = True if self.model_name == "codegen" else False
+            detok = self.tokenizer.decode(token, skip_special_tokens=skip_sp_tok)
             if debug:
                 print(
                     f"[DEBUG] is_first: {is_first} |"