Tensor.eye with arange (#13287)

with rangify we can write these with arange
This commit is contained in:
chenyu
2025-11-15 09:32:27 -08:00
committed by GitHub
parent 5b823af696
commit e8844853ed
2 changed files with 16 additions and 12 deletions

View File

@@ -7,24 +7,28 @@ from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
from tinygrad.uop.ops import Ops
class TestArange(unittest.TestCase):
def _get_flops(self, N):
def _get_flops(self, tensor, desired):
GlobalCounters.reset()
tt = Tensor.arange(N)
sched = tt.schedule()
sched = tensor.schedule()
self.assertEqual(len(sched), 1)
p = get_program(sched[-1].ast)
ExecItem(CompiledRunner(p), [tt.uop.buffer]).run()
np.testing.assert_equal(tt.numpy(), np.arange(N))
ExecItem(CompiledRunner(p), [tensor.uop.buffer]).run()
np.testing.assert_equal(tensor.numpy(), desired)
return p.estimates.ops
def test_complexity(self):
self.assertEqual(self._get_flops(256), 0)
self.assertEqual(self._get_flops(2560), 0)
def test_arange_complexity(self):
self.assertEqual(self._get_flops(Tensor.arange(256), np.arange(256)), 0)
self.assertEqual(self._get_flops(Tensor.arange(2560), np.arange(2560)), 0)
def test_arange_cat(self):
t = Tensor.arange(2, dtype=dtypes.int)+Tensor([3])
self.assertEqual(t.cat(t).tolist(), [3, 4, 3, 4])
def test_eye_complexity(self):
with Context(NOOPT=1):
# NOTE: not every backend supports CMPEQ
self.assertLessEqual(self._get_flops(Tensor.eye(2560).contiguous(), np.eye(2560)), 2*2560*2560)
DSET, DDIM = 2048, 32
class TestIndexing(unittest.TestCase):

View File

@@ -719,7 +719,7 @@ class Tensor(OpMixin):
return (start + Tensor.arange(steps, **kwargs) * ((stop - start) / (steps - 1))).cast(dtype)
@staticmethod
def eye(n:int, m:int|None=None, **kwargs) -> Tensor:
def eye(n:int, m:int|None=None, dtype=None, device=None, requires_grad:bool|None=None) -> Tensor:
"""
Returns a 2-D tensor with `n` rows and `m` columns, with ones on the diagonal and zeros elsewhere.
@@ -734,9 +734,9 @@ class Tensor(OpMixin):
print(Tensor.eye(2, 4).numpy())
```
"""
if n < 0 or (m is not None and m < 0): raise ValueError(f"cannot have negative {n=}, {m=}")
x = Tensor.ones(n, **kwargs).diag()
return x if m is None else x.pad((None, (0, m-n))) if m > n else x.shrink((None, (0, m)))
if n < 0 or ((m := n if m is None else m) < 0): raise ValueError(f"cannot have negative {n=}, {m=}")
t = (Tensor.arange(n, device=device).unsqueeze(-1) == Tensor.arange(m, device=device))
return t.cast(dtype or dtypes.default_float).requires_grad_(requires_grad)
def full_like(self, fill_value:ConstType, **kwargs) -> Tensor:
"""