mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Tensor.eye with arange (#13287)
with rangify we can write these with arange
This commit is contained in:
@@ -7,24 +7,28 @@ from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
|
||||
from tinygrad.uop.ops import Ops
|
||||
|
||||
class TestArange(unittest.TestCase):
|
||||
def _get_flops(self, N):
|
||||
def _get_flops(self, tensor, desired):
|
||||
GlobalCounters.reset()
|
||||
tt = Tensor.arange(N)
|
||||
sched = tt.schedule()
|
||||
sched = tensor.schedule()
|
||||
self.assertEqual(len(sched), 1)
|
||||
p = get_program(sched[-1].ast)
|
||||
ExecItem(CompiledRunner(p), [tt.uop.buffer]).run()
|
||||
np.testing.assert_equal(tt.numpy(), np.arange(N))
|
||||
ExecItem(CompiledRunner(p), [tensor.uop.buffer]).run()
|
||||
np.testing.assert_equal(tensor.numpy(), desired)
|
||||
return p.estimates.ops
|
||||
|
||||
def test_complexity(self):
|
||||
self.assertEqual(self._get_flops(256), 0)
|
||||
self.assertEqual(self._get_flops(2560), 0)
|
||||
def test_arange_complexity(self):
|
||||
self.assertEqual(self._get_flops(Tensor.arange(256), np.arange(256)), 0)
|
||||
self.assertEqual(self._get_flops(Tensor.arange(2560), np.arange(2560)), 0)
|
||||
|
||||
def test_arange_cat(self):
|
||||
t = Tensor.arange(2, dtype=dtypes.int)+Tensor([3])
|
||||
self.assertEqual(t.cat(t).tolist(), [3, 4, 3, 4])
|
||||
|
||||
def test_eye_complexity(self):
|
||||
with Context(NOOPT=1):
|
||||
# NOTE: not every backend supports CMPEQ
|
||||
self.assertLessEqual(self._get_flops(Tensor.eye(2560).contiguous(), np.eye(2560)), 2*2560*2560)
|
||||
|
||||
DSET, DDIM = 2048, 32
|
||||
|
||||
class TestIndexing(unittest.TestCase):
|
||||
|
||||
@@ -719,7 +719,7 @@ class Tensor(OpMixin):
|
||||
return (start + Tensor.arange(steps, **kwargs) * ((stop - start) / (steps - 1))).cast(dtype)
|
||||
|
||||
@staticmethod
|
||||
def eye(n:int, m:int|None=None, **kwargs) -> Tensor:
|
||||
def eye(n:int, m:int|None=None, dtype=None, device=None, requires_grad:bool|None=None) -> Tensor:
|
||||
"""
|
||||
Returns a 2-D tensor with `n` rows and `m` columns, with ones on the diagonal and zeros elsewhere.
|
||||
|
||||
@@ -734,9 +734,9 @@ class Tensor(OpMixin):
|
||||
print(Tensor.eye(2, 4).numpy())
|
||||
```
|
||||
"""
|
||||
if n < 0 or (m is not None and m < 0): raise ValueError(f"cannot have negative {n=}, {m=}")
|
||||
x = Tensor.ones(n, **kwargs).diag()
|
||||
return x if m is None else x.pad((None, (0, m-n))) if m > n else x.shrink((None, (0, m)))
|
||||
if n < 0 or ((m := n if m is None else m) < 0): raise ValueError(f"cannot have negative {n=}, {m=}")
|
||||
t = (Tensor.arange(n, device=device).unsqueeze(-1) == Tensor.arange(m, device=device))
|
||||
return t.cast(dtype or dtypes.default_float).requires_grad_(requires_grad)
|
||||
|
||||
def full_like(self, fill_value:ConstType, **kwargs) -> Tensor:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user