mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix fuse gate_contiguous unique (#11504)
This commit is contained in:
@@ -86,6 +86,18 @@ class TestFuse(unittest.TestCase):
|
||||
return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype)
|
||||
self._test_fuse(embedding, a, atol=1e-5)
|
||||
|
||||
def test_attention_kernel_count(self):
|
||||
wq = Tensor.empty(32, 32)
|
||||
wk = Tensor.empty(32, 32)
|
||||
wv = Tensor.empty(32, 32)
|
||||
x = Tensor.empty(2, 100, 32)
|
||||
q = (x @ wq).contiguous()
|
||||
k = (x @ wk).contiguous()
|
||||
v = (x @ wv).contiguous()
|
||||
attn = q.scaled_dot_product_attention(k, v).fuse()
|
||||
s = attn.schedule()
|
||||
self.assertEqual(len(s), 4) # 3 matmul and 1 attention
|
||||
|
||||
def test_flash_attention(self):
|
||||
BS = 4
|
||||
HEADS = 2
|
||||
|
||||
@@ -345,7 +345,7 @@ pm_fuse = PatternMatcher([
|
||||
def do_fusion(x:UOp):
|
||||
found_contiguous = {}
|
||||
def gate_contiguous(x):
|
||||
if is_contiguous:=(x.op is Ops.CONTIGUOUS): found_contiguous[x] = x.replace(src=(UOp(Ops.VIEW, arg=x.st),))
|
||||
if is_contiguous:=(x.op is Ops.CONTIGUOUS): found_contiguous[x] = x.replace(src=(UOp(Ops.VIEW, arg=x.st), UOp.unique()))
|
||||
return not is_contiguous
|
||||
x.toposort(gate=gate_contiguous)
|
||||
del gate_contiguous
|
||||
|
||||
Reference in New Issue
Block a user