diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py
index bd39568b50..e811221ba7 100644
--- a/extra/torch_backend/backend.py
+++ b/extra/torch_backend/backend.py
@@ -526,8 +526,8 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
     self.assign(Tensor.randint(*self.shape, low=dtypes.min(self.dtype), high=dtypes.max(self.dtype), device=self.device, dtype=self.dtype))),
   "aten.random_.from": inplace_fn("self")(lambda self, from_, to: self.assign(Tensor.randint(*self.shape, low=from_, high=to, device=self.device, dtype=self.dtype))),
-  "aten.uniform_": inplace_fn("self")(lambda self, low=0, high=1: self.assign(Tensor.uniform(*self.shape, low=low, high=high))),
-  "aten.normal_": inplace_fn("self")(lambda self, mean=0, std=1: self.assign(Tensor.normal(*self.shape, mean=mean, std=std))),
+  "aten.uniform_": inplace_fn("self")(lambda self, low=0, high=1: self.assign(Tensor.uniform(*self.shape, low=low, high=high, dtype=self.dtype))),
+  "aten.normal_": inplace_fn("self")(lambda self, mean=0, std=1: self.assign(Tensor.normal(*self.shape, mean=mean, std=std, dtype=self.dtype))),
   # these don't work in out form, they have size 0
   "aten.abs": Tensor.abs, "aten.logical_not": Tensor.logical_not,
diff --git a/extra/torch_backend/test.py b/extra/torch_backend/test.py
index 99ecb242b8..30117c3dd1 100644
--- a/extra/torch_backend/test.py
+++ b/extra/torch_backend/test.py
@@ -153,6 +153,16 @@ class TestTorchBackend(unittest.TestCase):
     res = torch.ops.aten.isin.Tensor_Tensor_out(a, b, invert=invert, assume_unique=assume_unique, out=out)
     np.testing.assert_equal(out.cpu().numpy(), expected.cpu().numpy())
 
+  def test_uniform(self):
+    for torch_dtype in [torch.float32, torch.float16]:
+      a = torch.rand(10, 10, device=device, dtype=torch_dtype)
+      self.assertEqual(a.dtype, torch_dtype)
+
+  def test_normal(self):
+    for torch_dtype in [torch.float32, torch.float16]:
+      a = torch.randn(10, 10, device=device, dtype=torch_dtype)
+      self.assertEqual(a.dtype, torch_dtype)
+
   @unittest.skip("meh")
   def test_str(self):
     a = torch.ones(4, device=device)