fix fuse gate_contiguous unique (#11504)

This commit is contained in:
chenyu
2025-08-04 20:43:31 -07:00
committed by GitHub
parent 7f6acfb0d5
commit f02720ca2d
2 changed files with 13 additions and 1 deletions

View File

@@ -86,6 +86,18 @@ class TestFuse(unittest.TestCase):
return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype)
self._test_fuse(embedding, a, atol=1e-5)
def test_attention_kernel_count(self):
wq = Tensor.empty(32, 32)
wk = Tensor.empty(32, 32)
wv = Tensor.empty(32, 32)
x = Tensor.empty(2, 100, 32)
q = (x @ wq).contiguous()
k = (x @ wk).contiguous()
v = (x @ wv).contiguous()
attn = q.scaled_dot_product_attention(k, v).fuse()
s = attn.schedule()
self.assertEqual(len(s), 4) # 3 matmul and 1 attention
def test_flash_attention(self):
BS = 4
HEADS = 2

View File

@@ -345,7 +345,7 @@ pm_fuse = PatternMatcher([
def do_fusion(x:UOp):
found_contiguous = {}
def gate_contiguous(x):
if is_contiguous:=(x.op is Ops.CONTIGUOUS): found_contiguous[x] = x.replace(src=(UOp(Ops.VIEW, arg=x.st),))
if is_contiguous:=(x.op is Ops.CONTIGUOUS): found_contiguous[x] = x.replace(src=(UOp(Ops.VIEW, arg=x.st), UOp.unique()))
return not is_contiguous
x.toposort(gate=gate_contiguous)
del gate_contiguous