mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
build Tensor._tri with movements only (#5110)
* build Tensor._tri with movements only
  doesn't need arange, saved a kernel in attention mask
* simpler, more tests
This commit is contained in:
@@ -301,6 +301,12 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(3,3)], lambda x: x.tril(2))
|
||||
helper_test_op([(3,3)], lambda x: x.tril(-1))
|
||||
helper_test_op([(3,3)], lambda x: x.tril(-2))
|
||||
helper_test_op([(4,5)], lambda x: x.tril(4))
|
||||
helper_test_op([(4,5)], lambda x: x.tril(5))
|
||||
helper_test_op([(4,5)], lambda x: x.tril(6))
|
||||
helper_test_op([(4,5)], lambda x: x.tril(-4))
|
||||
helper_test_op([(4,5)], lambda x: x.tril(-5))
|
||||
helper_test_op([(4,5)], lambda x: x.tril(-6))
|
||||
helper_test_op([(5,3,3)], lambda x: x.tril())
|
||||
helper_test_op([(5,0,3)], lambda x: x.tril())
|
||||
helper_test_op([(5,3,3)], lambda x: x.tril(1))
|
||||
@@ -311,6 +317,12 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(3,3)], lambda x: x.triu(2))
|
||||
helper_test_op([(3,3)], lambda x: x.triu(-1))
|
||||
helper_test_op([(3,3)], lambda x: x.triu(-2))
|
||||
helper_test_op([(4,5)], lambda x: x.triu(4))
|
||||
helper_test_op([(4,5)], lambda x: x.triu(5))
|
||||
helper_test_op([(4,5)], lambda x: x.triu(6))
|
||||
helper_test_op([(4,5)], lambda x: x.triu(-4))
|
||||
helper_test_op([(4,5)], lambda x: x.triu(-5))
|
||||
helper_test_op([(4,5)], lambda x: x.triu(-6))
|
||||
helper_test_op([(5,3,3)], lambda x: x.triu())
|
||||
helper_test_op([(5,0,3)], lambda x: x.triu())
|
||||
helper_test_op([(5,3,3)], lambda x: x.triu(1))
|
||||
|
||||
@@ -31,7 +31,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
|
||||
for i, s in enumerate(sched):
|
||||
print("kernel", i+1)
|
||||
for op in s.ast: print_tree(op)
|
||||
if len(sched) != allowed: raise KernelCountException()
|
||||
if len(sched) != allowed: raise KernelCountException(f"{len(sched)=} != {allowed}")
|
||||
# test the (non loadops) ops linearize
|
||||
for s in sched:
|
||||
if s.ast[0].op in LoadOps: continue
|
||||
@@ -797,7 +797,7 @@ class TestSchedule(unittest.TestCase):
|
||||
def test_scaled_dot_product_attention_causal_fusion(self):
|
||||
x, y, z, m = (Tensor.empty(32, 8, 16, 16) for _ in range(4))
|
||||
out = Tensor.scaled_dot_product_attention(x, y, z, attn_mask=m, is_causal=True)
|
||||
check_schedule(out, 7)
|
||||
check_schedule(out, 6)
|
||||
|
||||
def test_adam_step_fusion(self):
|
||||
with Tensor.train():
|
||||
|
||||
@@ -1796,9 +1796,13 @@ class Tensor:
|
||||
|
||||
@staticmethod
|
||||
def _tri(r:sint, c:sint, diagonal:int=0, **kwargs) -> Tensor:
|
||||
assert all_int((r,c)), "does not support symbolic"
|
||||
if r == 0: return Tensor.zeros((r, c), **kwargs)
|
||||
return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-diagonal, c-diagonal, **kwargs).unsqueeze(0).expand(r,c)
|
||||
assert isinstance(r, int) and isinstance(c, int), f"does not support symbolic, getting {r=}, {c=}"
|
||||
if r == 0 or c == 0 or diagonal >= c: return Tensor.zeros(r,c,**kwargs)
|
||||
if r+diagonal <= 0: return Tensor.ones(r,c,**kwargs)
|
||||
s = r+c-1
|
||||
# build a (s, s) upper triangle
|
||||
t = Tensor.ones(s,s,**kwargs).pad((None,(0,s))).flatten().shrink(((0,s*(2*s-1)),)).reshape(s,-1).shrink((None,(0,s)))
|
||||
return t[:r,-diagonal:c-diagonal] if diagonal <= 0 else t[diagonal:r+diagonal,:c]
|
||||
|
||||
def triu(self, diagonal:int=0) -> Tensor:
|
||||
"""
|
||||
@@ -1821,7 +1825,7 @@ class Tensor:
|
||||
print(t.triu(diagonal=-1).numpy())
|
||||
```
|
||||
"""
|
||||
return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal, device=self.device).where(self, 0).cast(self.dtype)
|
||||
return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal, device=self.device, dtype=dtypes.bool).where(self, 0).cast(self.dtype)
|
||||
|
||||
def tril(self, diagonal:int=0) -> Tensor:
|
||||
"""
|
||||
@@ -1844,7 +1848,7 @@ class Tensor:
|
||||
print(t.tril(diagonal=-1).numpy())
|
||||
```
|
||||
"""
|
||||
return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device).where(0, self).cast(self.dtype)
|
||||
return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device, dtype=dtypes.bool).where(0, self).cast(self.dtype)
|
||||
|
||||
# ***** unary ops *****
|
||||
|
||||
|
||||
Reference in New Issue
Block a user