create randperm and support pytorch backend (#10019)

This commit is contained in:
Park Jun
2025-04-24 20:29:02 +09:00
committed by GitHub
parent b545338e59
commit c3ad7b2a84
4 changed files with 25 additions and 1 deletions

View File

@@ -112,7 +112,8 @@ def index_put(self, indices, values, accumulate=False):
def isin_tensor_tensor_out(x, y, *, assume_unique=False, invert=False, out=None): return out.copy_(aten.isin(x.cpu(), y.cpu(), assume_unique=assume_unique, invert=invert).tiny())
@torch.library.impl("aten::randperm.generator_out", "privateuseone")
def randperm_generator(n, generator=None, out=None): out.copy_(torch.randperm(n, generator=generator, device="cpu").tiny())
def randperm_generator(n, generator=None, out=None):
return out.copy_(wrap(Tensor.randperm(n, generator=generator, device=unwrap(out).device)))
@torch.library.impl("aten::cummax", "privateuseone")
def cummax(self, dim):

View File

@@ -11,6 +11,16 @@ else:
device = "tiny"
class TestTorchBackend(unittest.TestCase):
def test_randperm_generator_out(self):
n = 10
out = torch.empty(n, dtype=torch.long, device=device)
res = torch.randperm(n, out=out).cpu().numpy()
np.testing.assert_equal(set(res), set(range(n)))
np.testing.assert_equal(out.cpu().numpy(), res)
res2 = torch.randperm(n).cpu().numpy()
np.testing.assert_equal(set(res2), set(range(n)))
def test_numpy_ones(self):
a = torch.ones(4, device=device)
np.testing.assert_equal(a.cpu().numpy(), [1,1,1,1])

View File

@@ -237,6 +237,13 @@ class TestTinygrad(unittest.TestCase):
b = random_fn(10,10).realize()
np.testing.assert_allclose(a.numpy(), b.numpy())
def test_randperm(self):
Tensor.manual_seed(0)
a = Tensor.randperm(10).realize()
np.testing.assert_equal(a.numpy(), [5, 2, 8, 1, 3, 7, 9, 6, 0, 4])
b = Tensor.randperm(1000).realize()
np.testing.assert_equal(set(b.numpy()), set(range(1000)))
def test_randn_isnt_inf_on_zero(self):
# simulate failure case of rand handing a zero to randn
original_rand, Tensor.rand = Tensor.rand, Tensor.zeros

View File

@@ -866,6 +866,12 @@ class Tensor(SimpleMathTrait):
std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
@staticmethod
def randperm(n: int, *, device=None, dtype=dtypes.int32, **kwargs) -> Tensor:
r = Tensor.rand(n, device=device, **kwargs)
_, indices = r.sort()
return indices.cast(dtype)
def multinomial(self:Tensor, num_samples:int = 1, replacement:bool = False) -> Tensor:
assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
assert replacement or num_samples == 1, "no replacement only supports num_samples = 1"