From 9d399eb98853aeda70f7a821faba4f46c91f4c40 Mon Sep 17 00:00:00 2001 From: Eliasj42 <46754803+Eliasj42@users.noreply.github.com> Date: Tue, 25 Jul 2023 17:00:13 -0700 Subject: [PATCH] fixed bug where device_idx was hardcoded (#1693) Co-authored-by: Elias Joseph --- apps/language_models/scripts/vicuna.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py index 86b64958..7f73c320 100644 --- a/apps/language_models/scripts/vicuna.py +++ b/apps/language_models/scripts/vicuna.py @@ -785,7 +785,7 @@ class ShardedVicuna(VicunaBase): module = SharkInference( None, device=device, - device_idx=idx % 4, + device_idx=device_idx, mlir_dialect="tm_tensor", mmap=False, ) @@ -798,7 +798,7 @@ class ShardedVicuna(VicunaBase): module = SharkInference( mlirs[idx], device=device, - device_idx=idx % 4, + device_idx=device_idx, mlir_dialect="tm_tensor", mmap=False, )