mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
[Llama2] Prefetch llama2 tokenizer configs (#1824)
-- This commit prefetches llama2 tokenizer configs from shark_tank. Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -196,3 +196,6 @@ db_dir_UserData
|
||||
|
||||
# Embedded browser cache and other files
|
||||
apps/stable_diffusion/web/EBWebView/
|
||||
|
||||
# Llama2 tokenizer configs
|
||||
llama2_tokenizer_configs/
|
||||
|
||||
@@ -1238,10 +1238,6 @@ class UnshardedVicuna(VicunaBase):
|
||||
max_num_tokens,
|
||||
extra_args_cmd=extra_args_cmd,
|
||||
)
|
||||
if "llama2" in self.model_name and hf_auth_token == None:
|
||||
raise ValueError(
|
||||
"HF auth token required. Pass it using --hf_auth_token flag."
|
||||
)
|
||||
self.hf_auth_token = hf_auth_token
|
||||
if self.model_name == "llama2_7b":
|
||||
self.hf_model_path = "meta-llama/Llama-2-7b-chat-hf"
|
||||
@@ -1277,12 +1273,21 @@ class UnshardedVicuna(VicunaBase):
|
||||
)
|
||||
|
||||
def get_tokenizer(self):
|
||||
kwargs = {"use_auth_token": self.hf_auth_token}
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hf_model_path,
|
||||
use_fast=False,
|
||||
**kwargs,
|
||||
)
|
||||
local_tokenizer_path = Path(Path.cwd(), "llama2_tokenizer_configs")
|
||||
local_tokenizer_path.mkdir(parents=True, exist_ok=True)
|
||||
tokenizer_files_to_download = [
|
||||
"config.json",
|
||||
"special_tokens_map.json",
|
||||
"tokenizer.model",
|
||||
"tokenizer_config.json",
|
||||
]
|
||||
for tokenizer_file in tokenizer_files_to_download:
|
||||
download_public_file(
|
||||
f"gs://shark_tank/llama2_tokenizer/{tokenizer_file}",
|
||||
Path(local_tokenizer_path, tokenizer_file),
|
||||
single_file=True,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(str(local_tokenizer_path))
|
||||
return tokenizer
|
||||
|
||||
def get_src_model(self):
|
||||
|
||||
Reference in New Issue
Block a user