mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-25 03:00:12 -04:00
[vicuna] fix shard config generator script (#1747)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
@@ -161,14 +161,8 @@ def chat(
|
||||
"llama2_7b",
|
||||
"llama2_70b",
|
||||
]:
|
||||
-if model_name == "vicuna4":
|
||||
-from apps.language_models.scripts.vicuna import (
|
||||
-ShardedVicuna as Vicuna,
|
||||
-)
|
||||
-else:
|
||||
-from apps.language_models.scripts.vicuna import (
|
||||
-UnshardedVicuna as Vicuna,
|
||||
-)
|
||||
+from apps.language_models.scripts.vicuna import ShardedVicuna
|
||||
+from apps.language_models.scripts.vicuna import UnshardedVicuna
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
if vicuna_model == 0:
|
||||
@@ -186,7 +180,7 @@ def chat(
|
||||
|
||||
max_toks = 128 if model_name == "codegen" else 512
|
||||
if model_name == "vicuna4":
|
||||
-vicuna_model = Vicuna(
|
||||
+vicuna_model = ShardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
device=device,
|
||||
@@ -196,7 +190,7 @@ def chat(
|
||||
)
|
||||
else:
|
||||
if len(devices) == 1 and config_file is None:
|
||||
-vicuna_model = Vicuna(
|
||||
+vicuna_model = UnshardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
hf_auth_token=args.hf_auth_token,
|
||||
@@ -211,7 +205,7 @@ def chat(
|
||||
config_file.close()
|
||||
else:
|
||||
config_json = get_default_config()
|
||||
-vicuna_model = Vicuna(
|
||||
+vicuna_model = ShardedVicuna(
|
||||
model_name,
|
||||
device=device,
|
||||
precision=precision,
|
||||
|
||||
Reference in New Issue
Block a user