Fix IREE eager backend device string (#237)

This commit is contained in:
Quinn Dawkins
2022-08-03 15:09:52 -04:00
committed by GitHub
parent 38664a4c68
commit 934f15ebb7

View File

@@ -48,8 +48,8 @@ class EagerModeIREELinalgOnTensorsBackend(TorchMLIREagerBackend):
def __init__(self, device: str):
self.torch_device_str = device
self.iree_device_str = IREE_DEVICE_MAP[device]
self.config = ireert.Config(self.iree_device_str)
self.config = ireert.Config(IREE_DEVICE_MAP[device])
self.raw_device_str = device
def get_torch_metadata(
self, tensor: DeviceArray, kwargs: Dict[str, Any]
@@ -71,7 +71,7 @@ class EagerModeIREELinalgOnTensorsBackend(TorchMLIREagerBackend):
"EagerMode",
)
callable, _ = get_iree_compiled_module(
imported_module, self.iree_device_str, func_name=fn_name
imported_module, self.raw_device_str, func_name=fn_name
)
return callable