[SD] Fix SAMPLE_INPUT_LEN import issue

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
Gaurav Shukla
2023-05-11 15:38:59 +05:30
committed by Phaneesh Barwaria
parent fa833f8366
commit 7b74c86e42

View File

@@ -17,7 +17,6 @@ from transformers import (
from apps.stable_diffusion.web.ui.utils import available_devices
from apps.language_models.scripts.sharded_vicuna_fp32 import (
tokenizer,
SAMPLE_INPUT_LEN,
get_sharded_model,
)
@@ -50,6 +49,7 @@ def chat(curr_system_message, history, model):
global sharded_model
global past_key_values
if "vicuna" in model:
SAMPLE_INPUT_LEN = 137
curr_system_message = start_message_vicuna
if sharded_model == 0:
sharded_model = get_sharded_model()
@@ -60,6 +60,7 @@ def chat(curr_system_message, history, model):
]
)
prompt = messages.strip()
print("prompt = ", prompt)
input_ids = tokenizer(prompt).input_ids
new_sentence = ""
for _ in range(200):