mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
[SD] Fix SAMPLE_INPUT_LEN import issue
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
committed by
Phaneesh Barwaria
parent
fa833f8366
commit
7b74c86e42
@@ -17,7 +17,6 @@ from transformers import (
|
||||
from apps.stable_diffusion.web.ui.utils import available_devices
|
||||
from apps.language_models.scripts.sharded_vicuna_fp32 import (
|
||||
tokenizer,
|
||||
SAMPLE_INPUT_LEN,
|
||||
get_sharded_model,
|
||||
)
|
||||
|
||||
@@ -50,6 +49,7 @@ def chat(curr_system_message, history, model):
|
||||
global sharded_model
|
||||
global past_key_values
|
||||
if "vicuna" in model:
|
||||
SAMPLE_INPUT_LEN = 137
|
||||
curr_system_message = start_message_vicuna
|
||||
if sharded_model == 0:
|
||||
sharded_model = get_sharded_model()
|
||||
@@ -60,6 +60,7 @@ def chat(curr_system_message, history, model):
|
||||
]
|
||||
)
|
||||
prompt = messages.strip()
|
||||
print("prompt = ", prompt)
|
||||
input_ids = tokenizer(prompt).input_ids
|
||||
new_sentence = ""
|
||||
for _ in range(200):
|
||||
|
||||
Reference in New Issue
Block a user