Fix input tensors with non-floating point dtype in the lockstep tracer (#328)

commit 8d21292d34 (parent e304041574)
Author: Quinn Dawkins
Committer: GitHub
Date: 2022-09-13 21:14:38 -04:00

@@ -65,14 +65,18 @@ class TorchMLIRLockstepTensor(TorchMLIRTensor):
             nt = elem.detach().data.numpy()
             if not nt.flags["C_CONTIGUOUS"]:
                 nt = np.ascontiguousarray(nt, dtype=nt.dtype)
-            r.elem = backend.transfer_from_torch_to_device(torch.Tensor(nt))
+            r.elem = backend.transfer_from_torch_to_device(
+                torch.from_numpy(nt)
+            )
         elif isinstance(elem, torch.Tensor):
             r = make_wrapper_subclass_from_torch_tensor(cls, elem, **kwargs)
             # Ditto TODO: Find a better way to handle this
             nt = elem.numpy()
             if not nt.flags["C_CONTIGUOUS"]:
                 nt = np.ascontiguousarray(nt, dtype=nt.dtype)
-            r.elem = backend.transfer_from_torch_to_device(torch.Tensor(nt))
+            r.elem = backend.transfer_from_torch_to_device(
+                torch.from_numpy(nt)
+            )
         # This branch handles the case when a python scalar is passed to some op
         # or is returned from some aten op, such as _local_scalar_dense.
         elif isinstance(elem, (int, float, bool)):
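
For context beyond the diff (not part of the commit): the legacy torch.Tensor(...) constructor copies a numpy array into a tensor of the default dtype (float32), silently casting integer and bool inputs, whereas torch.from_numpy(...) preserves the array's dtype. A minimal standalone sketch of the difference, with hypothetical variable names:

    import numpy as np
    import torch

    nt = np.array([1, 2, 3], dtype=np.int64)

    legacy = torch.Tensor(nt)         # copies and casts to the default dtype
    preserved = torch.from_numpy(nt)  # keeps the source dtype

    print(legacy.dtype)     # torch.float32 -- integer data silently converted
    print(preserved.dtype)  # torch.int64   -- dtype round-trips correctly

Note that torch.from_numpy shares memory with its argument rather than copying; in the patched code that is safe because nt is made C-contiguous (copying if necessary) immediately before the transfer to the device.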