fixed bug where device_idx was hardcoded (#1693)

Co-authored-by: Elias Joseph <elias@nod-labs.com>
This commit is contained in:
Eliasj42
2023-07-25 17:00:13 -07:00
committed by GitHub
parent 927b662aa7
commit 9d399eb988

View File

@@ -785,7 +785,7 @@ class ShardedVicuna(VicunaBase):
module = SharkInference(
None,
device=device,
-device_idx=idx % 4,
+device_idx=device_idx,
mlir_dialect="tm_tensor",
mmap=False,
)
@@ -798,7 +798,7 @@ class ShardedVicuna(VicunaBase):
module = SharkInference(
mlirs[idx],
device=device,
-device_idx=idx % 4,
+device_idx=device_idx,
mlir_dialect="tm_tensor",
mmap=False,
)