Mirror of https://github.com/tinygrad/tinygrad.git — synced 2026-01-09 23:18:04 -05:00.
Enhance tensor random functions with dtype support (#10214)
* Enhance tensor random functions with dtype support

  - Updated `aten.uniform_` and `aten.normal_` to include the dtype parameter in backend.py
  - Added unit tests for uniform and normal tensor generation with specific dtypes in test.py

* Refactor test name for clarity

  - Renamed `test_normal_dtype` to `test_normal` in `extra/torch_backend/test.py`
  - Aims to improve readability and better reflect the test's purpose
This commit is contained in:
@@ -526,8 +526,8 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
|
||||
self.assign(Tensor.randint(*self.shape, low=dtypes.min(self.dtype), high=dtypes.max(self.dtype), device=self.device, dtype=self.dtype))),
|
||||
"aten.random_.from": inplace_fn("self")(lambda self, from_, to:
|
||||
self.assign(Tensor.randint(*self.shape, low=from_, high=to, device=self.device, dtype=self.dtype))),
|
||||
"aten.uniform_": inplace_fn("self")(lambda self, low=0, high=1: self.assign(Tensor.uniform(*self.shape, low=low, high=high))),
|
||||
"aten.normal_": inplace_fn("self")(lambda self, mean=0, std=1: self.assign(Tensor.normal(*self.shape, mean=mean, std=std))),
|
||||
"aten.uniform_": inplace_fn("self")(lambda self, low=0, high=1: self.assign(Tensor.uniform(*self.shape, low=low, high=high, dtype=self.dtype))),
|
||||
"aten.normal_": inplace_fn("self")(lambda self, mean=0, std=1: self.assign(Tensor.normal(*self.shape, mean=mean, std=std, dtype=self.dtype))),
|
||||
# these don't work in out form, they have size 0
|
||||
"aten.abs": Tensor.abs,
|
||||
"aten.logical_not": Tensor.logical_not,
|
||||
|
||||
@@ -153,6 +153,16 @@ class TestTorchBackend(unittest.TestCase):
|
||||
res = torch.ops.aten.isin.Tensor_Tensor_out(a, b, invert=invert, assume_unique=assume_unique, out=out)
|
||||
np.testing.assert_equal(out.cpu().numpy(), expected.cpu().numpy())
|
||||
|
||||
def test_uniform(self):
  """Uniform random tensors requested with an explicit dtype must come back in that dtype."""
  requested_dtypes = (torch.float32, torch.float16)
  for want in requested_dtypes:
    out = torch.rand(10, 10, device=device, dtype=want)
    self.assertEqual(out.dtype, want)
|
||||
|
||||
def test_normal(self):
  """Normal random tensors requested with an explicit dtype must come back in that dtype."""
  requested_dtypes = (torch.float32, torch.float16)
  for want in requested_dtypes:
    out = torch.randn(10, 10, device=device, dtype=want)
    self.assertEqual(out.dtype, want)
|
||||
|
||||
@unittest.skip("meh")
|
||||
def test_str(self):
|
||||
a = torch.ones(4, device=device)
|
||||
|
||||
Reference in New Issue
Block a user