[vicuna] fix shard config generator script (#1747)

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
Gaurav Shukla
2023-08-10 23:56:03 +05:30
committed by GitHub
parent f5e4fa6ffe
commit 3c577f7168
2 changed files with 6 additions and 12 deletions

View File

@@ -161,14 +161,8 @@ def chat(
"llama2_7b",
"llama2_70b",
]:
-if model_name == "vicuna4":
-from apps.language_models.scripts.vicuna import (
-ShardedVicuna as Vicuna,
-)
-else:
-from apps.language_models.scripts.vicuna import (
-UnshardedVicuna as Vicuna,
-)
+from apps.language_models.scripts.vicuna import ShardedVicuna
+from apps.language_models.scripts.vicuna import UnshardedVicuna
from apps.stable_diffusion.src import args
if vicuna_model == 0:
@@ -186,7 +180,7 @@ def chat(
max_toks = 128 if model_name == "codegen" else 512
if model_name == "vicuna4":
-vicuna_model = Vicuna(
+vicuna_model = ShardedVicuna(
model_name,
hf_model_path=model_path,
device=device,
@@ -196,7 +190,7 @@ def chat(
)
else:
if len(devices) == 1 and config_file is None:
-vicuna_model = Vicuna(
+vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
hf_auth_token=args.hf_auth_token,
@@ -211,7 +205,7 @@ def chat(
config_file.close()
else:
config_json = get_default_config()
-vicuna_model = Vicuna(
+vicuna_model = ShardedVicuna(
model_name,
device=device,
precision=precision,