fix SLM for SharkStudio

This commit is contained in:
PhaneeshB
2023-05-17 22:18:15 +05:30
committed by Phaneesh Barwaria
parent 6602a2f5ba
commit 512235892e

View File

@@ -106,6 +106,8 @@ def chat(curr_system_message, history, model):
"stabilityai/stablelm-tuned-alpha-3b", torch_dtype=torch.float32
)
stableLMModel = StableLMModel(m)
input_ids = torch.randint(3, (1, 256))
attention_mask = torch.randint(3, (1, 256))
sharkModel = compile_stableLM(
stableLMModel,
tuple([input_ids, attention_mask]),