Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git
add codegen support in vic pipeline
committed by Phaneesh Barwaria
parent be417f0bf4
commit 1c7eecc981
@@ -150,10 +150,14 @@ class ShardedVicuna(SharkLLMBase):
         self.shark_model = self.compile(device=device)
 
     def get_tokenizer(self):
         # Retrieve the tokenizer from Huggingface
-        tokenizer = AutoTokenizer.from_pretrained(
-            self.hf_model_path, use_fast=False
-        )
+        if self.model_name == "codegen":
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.hf_model_path, trust_remote_code=True,
+            )
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.hf_model_path, use_fast=False
+            )
         return tokenizer
 
     def get_src_model(self):
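For reference, a standalone sketch of the tokenizer selection this hunk introduces. The checkpoint name "Salesforce/codegen-350M-mono" is an assumed example for illustration, not necessarily the checkpoint SHARK uses:

from transformers import AutoTokenizer

def get_tokenizer(hf_model_path, model_name):
    # The diff opts into trust_remote_code for codegen checkpoints,
    # which lets the hub load any custom tokenizer code they ship.
    if model_name == "codegen":
        return AutoTokenizer.from_pretrained(hf_model_path, trust_remote_code=True)
    # Vicuna-style models keep the slow (SentencePiece) tokenizer.
    return AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)

tokenizer = get_tokenizer("Salesforce/codegen-350M-mono", "codegen")
print(tokenizer("def hello_world():").input_ids)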
@@ -786,7 +790,7 @@ class ShardedVicuna(SharkLLMBase):
     def compile(self, device="cpu"):
         return self.get_sharded_model(device=device)
 
-    def generate(self, prompt, cli=False):
+    def generate(self, prompt, cli=True):
         # TODO: refactor for cleaner integration
 
         tokens_generated = []
@@ -909,9 +913,14 @@ class UnshardedVicuna(SharkLLMBase):
         )
 
     def get_tokenizer(self):
-        tokenizer = AutoTokenizer.from_pretrained(
-            self.hf_model_path, use_fast=False
-        )
+        if self.model_name == "codegen":
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.hf_model_path, trust_remote_code=True,
+            )
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.hf_model_path, use_fast=False
+            )
         return tokenizer
 
     def get_src_model(self):
@@ -961,9 +970,16 @@ class UnshardedVicuna(SharkLLMBase):
             )
 
         if not mlir_generated:
-            compilation_prompt = "".join(["0" for _ in range(17)])
+            # Select a compilation prompt such that the resulting input_ids
+            # from the model's tokenizer has shape [1, 19]
+            if self.model_name == "codegen":
+                compilation_prompt = "def hello_world():\n print('Hello World')\n print('Hello World')"
+            else:
+                compilation_prompt = "".join(["0" for _ in range(17)])
             compilation_input_ids = self.tokenizer(
-                compilation_prompt
+                compilation_prompt,
+                return_tensors="pt",
             ).input_ids
             compilation_input_ids = torch.tensor(
                 compilation_input_ids
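The hunk above requires the compilation prompt to tokenize to input_ids of shape [1, 19], since the MLIR is compiled against a fixed input shape. A minimal sanity check of that invariant, again assuming the hypothetical "Salesforce/codegen-350M-mono" checkpoint:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
prompt = "def hello_world():\n print('Hello World')\n print('Hello World')"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
# The pipeline expects exactly [1, 19] here; verify before compiling,
# since a different tokenizer would produce a different token count.
print(input_ids.shape)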
@@ -1334,10 +1350,11 @@ class UnshardedVicuna(SharkLLMBase):
             if type(res_tokens[i]) != int:
                 res_tokens[i] = int(res_tokens[i][0])
 
-        res_str = self.tokenizer.decode(res_tokens)
+        skip_sp_tok = True if self.model_name == "codegen" else False
+        res_str = self.tokenizer.decode(res_tokens, skip_special_tokens=skip_sp_tok)
         return res_str
 
-    def generate(self, prompt, cli=False):
+    def generate(self, prompt, cli=True):
         # TODO: refactor for cleaner integration
         import gc
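Skipping special tokens matters for code models because markers like the end-of-text token would otherwise leak into generated source. A small illustration with the GPT-2 tokenizer, chosen only because CodeGen uses a GPT-2-style BPE vocabulary (an assumption for the demo, not SHARK's actual setup):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
ids = tokenizer("print('hi')").input_ids + [tokenizer.eos_token_id]
print(tokenizer.decode(ids))                            # print('hi')<|endoftext|>
print(tokenizer.decode(ids, skip_special_tokens=True))  # print('hi')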
@@ -1481,7 +1498,8 @@ class UnshardedVicuna(SharkLLMBase):
                 load_inputs=False,
             )
 
-            detok = self.tokenizer.decode(token)
+            skip_sp_tok = True if self.model_name == "codegen" else False
+            detok = self.tokenizer.decode(token, skip_special_tokens=skip_sp_tok)
             if debug:
                 print(
                     f"[DEBUG] is_first: {is_first} |"