Enhance tensor random functions with dtype support (#10214)

* Enhance tensor random functions with dtype support
- Updated `aten.uniform_` and `aten.normal_` to include dtype parameter in backend.py
- Added unit tests for uniform and normal tensor generation with specific dtypes in test.py

* Refactor test name for clarity
- Renamed `test_normal_dtype` to `test_normal` in `extra/torch_backend/test.py`
- Aims to improve readability and better reflect the test's purpose
This commit is contained in:
Xingyu
2025-05-09 08:48:07 +08:00
committed by GitHub
parent b6904bbf83
commit a21369d039
2 changed files with 12 additions and 2 deletions

View File

@@ -526,8 +526,8 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
self.assign(Tensor.randint(*self.shape, low=dtypes.min(self.dtype), high=dtypes.max(self.dtype), device=self.device, dtype=self.dtype))),
"aten.random_.from": inplace_fn("self")(lambda self, from_, to:
self.assign(Tensor.randint(*self.shape, low=from_, high=to, device=self.device, dtype=self.dtype))),
"aten.uniform_": inplace_fn("self")(lambda self, low=0, high=1: self.assign(Tensor.uniform(*self.shape, low=low, high=high))),
"aten.normal_": inplace_fn("self")(lambda self, mean=0, std=1: self.assign(Tensor.normal(*self.shape, mean=mean, std=std))),
"aten.uniform_": inplace_fn("self")(lambda self, low=0, high=1: self.assign(Tensor.uniform(*self.shape, low=low, high=high, dtype=self.dtype))),
"aten.normal_": inplace_fn("self")(lambda self, mean=0, std=1: self.assign(Tensor.normal(*self.shape, mean=mean, std=std, dtype=self.dtype))),
# these don't work in out form, they have size 0
"aten.abs": Tensor.abs,
"aten.logical_not": Tensor.logical_not,

View File

@@ -153,6 +153,16 @@ class TestTorchBackend(unittest.TestCase):
res = torch.ops.aten.isin.Tensor_Tensor_out(a, b, invert=invert, assume_unique=assume_unique, out=out)
np.testing.assert_equal(out.cpu().numpy(), expected.cpu().numpy())
def test_uniform(self):
for torch_dtype in [torch.float32, torch.float16]:
a = torch.rand(10, 10, device=device, dtype=torch_dtype)
self.assertEqual(a.dtype, torch_dtype)
def test_normal(self):
for torch_dtype in [torch.float32, torch.float16]:
a = torch.randn(10, 10, device=device, dtype=torch_dtype)
self.assertEqual(a.dtype, torch_dtype)
@unittest.skip("meh")
def test_str(self):
a = torch.ones(4, device=device)