diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py index 86b64958..7f73c320 100644 --- a/apps/language_models/scripts/vicuna.py +++ b/apps/language_models/scripts/vicuna.py @@ -785,7 +785,7 @@ class ShardedVicuna(VicunaBase): module = SharkInference( None, device=device, - device_idx=idx % 4, + device_idx=device_idx, mlir_dialect="tm_tensor", mmap=False, ) @@ -798,7 +798,7 @@ class ShardedVicuna(VicunaBase): module = SharkInference( mlirs[idx], device=device, - device_idx=idx % 4, + device_idx=device_idx, mlir_dialect="tm_tensor", mmap=False, )