[Llama2] Prefetch llama2 tokenizer configs (#1824)

-- This commit prefetches llama2 tokenizer configs from shark_tank.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
This commit is contained in:
Abhishek Varma
2023-09-08 23:59:54 +05:30
committed by GitHub
parent c5dcfc1f13
commit c854208d49
2 changed files with 18 additions and 10 deletions

3
.gitignore vendored
View File

@@ -196,3 +196,6 @@ db_dir_UserData
# Embeded browser cache and other
apps/stable_diffusion/web/EBWebView/
# Llama2 tokenizer configs
llama2_tokenizer_configs/

View File

@@ -1238,10 +1238,6 @@ class UnshardedVicuna(VicunaBase):
max_num_tokens,
extra_args_cmd=extra_args_cmd,
)
if "llama2" in self.model_name and hf_auth_token == None:
raise ValueError(
"HF auth token required. Pass it using --hf_auth_token flag."
)
self.hf_auth_token = hf_auth_token
if self.model_name == "llama2_7b":
self.hf_model_path = "meta-llama/Llama-2-7b-chat-hf"
@@ -1277,12 +1273,21 @@ class UnshardedVicuna(VicunaBase):
)
def get_tokenizer(self):
kwargs = {"use_auth_token": self.hf_auth_token}
tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path,
use_fast=False,
**kwargs,
)
local_tokenizer_path = Path(Path.cwd(), "llama2_tokenizer_configs")
local_tokenizer_path.mkdir(parents=True, exist_ok=True)
tokenizer_files_to_download = [
"config.json",
"special_tokens_map.json",
"tokenizer.model",
"tokenizer_config.json",
]
for tokenizer_file in tokenizer_files_to_download:
download_public_file(
f"gs://shark_tank/llama2_tokenizer/{tokenizer_file}",
Path(local_tokenizer_path, tokenizer_file),
single_file=True,
)
tokenizer = AutoTokenizer.from_pretrained(str(local_tokenizer_path))
return tokenizer
def get_src_model(self):