fix after end (#13542)

This commit is contained in:
wozeparrot
2025-12-02 18:42:58 -08:00
committed by GitHub
parent 8902781dc1
commit 0d55aec605
3 changed files with 41 additions and 25 deletions

View File

@@ -660,6 +660,18 @@ class TestUOpGraph(unittest.TestCase):
bad_gate = UOp.const(dtypes.int, 1)
with self.assertRaises(AssertionError): to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])
def test_after_end(self):
  """Check `.ranges` membership of a range before END, on the END, and on an AFTER of that END."""
  loop = UOp.range(10, 0)
  body = loop + 1
  # an expression built from the range reports it in `.ranges`
  self.assertIn(loop, body.ranges)
  closed = UOp.const(dtypes.void, None).end(loop)
  # the END node itself no longer reports the range
  self.assertNotIn(loop, closed.ranges)
  # ordering the expression after the END must also drop the range
  self.assertNotIn(loop, body.after(closed).ranges)
@track_rewrites()
def expander_rewrite(sink):
  """Rewrite `sink` with the combined `sym + expander` rules and return the result."""
  rules = sym + expander
  return graph_rewrite(sink, rules)

View File

@@ -422,28 +422,29 @@ class TestTK(unittest.TestCase):
norm_vec = warp.zero(norm_vec)
for tile_col in ker.range(N // BLOCK_SIZE):
a_smem = warp.load(a_smem, a, (), (0, 0, 0, tile_col), axis=2)
a_reg = warp.load(a_reg, a_smem)
a_smem_ = warp.load(a_smem, a, (), (0, 0, 0, tile_col), axis=2)
a_reg_ = warp.load(a_reg, a_smem_)
a_reg *= 1.0 / math.log(2)
a_reg_ *= 1.0 / math.log(2)
max_vec_last = warp.copy(max_vec_last.after(tile_col), max_vec)
max_vec = warp.row_reduce(max_vec.after(max_vec_last), a_reg, lambda a, b: a.maximum(b), init_value=-math.inf)
a_reg = (a_reg - max_vec).exp2()
max_vec = warp.row_reduce(max_vec.after(max_vec_last), a_reg_, lambda a, b: a.maximum(b), init_value=-math.inf)
a_reg_ = (a_reg_ - max_vec).exp2()
max_vec_last = (max_vec_last - max_vec).exp2()
norm_vec *= max_vec_last
norm_vec = warp.row_reduce(norm_vec, a_reg, lambda a, b: a + b)
norm_vec = warp.row_reduce(norm_vec, a_reg_, lambda a, b: a + b)
norm_vec = ker.endrange()
max_vec = max_vec.after(norm_vec)
for tile_col in ker.range(N // BLOCK_SIZE):
a_smem = warp.load(a_smem, a, (), (0, 0, 0, tile_col), axis=2)
a_reg = warp.load(a_reg.after(norm_vec), a_smem)
a_smem_ = warp.load(a_smem, a, (), (0, 0, 0, tile_col), axis=2)
a_reg_ = warp.load(a_reg, a_smem_)
a_reg *= 1.0 / math.log(2)
a_reg = (a_reg - max_vec).exp2()
a_reg /= norm_vec
a_reg_ *= 1.0 / math.log(2)
a_reg_ = (a_reg_ - max_vec).exp2()
a_reg_ /= norm_vec
b = warp.store(b, a_reg, (0, 0, 0, tile_col), (), axis=2)
b = warp.store(b, a_reg_, (0, 0, 0, tile_col), (), axis=2)
sink = ker.finish()
@@ -481,28 +482,29 @@ class TestTK(unittest.TestCase):
norm_vec = warp.zero(norm_vec)
for tile_row in ker.range(N // BLOCK_SIZE):
a_smem = warp.load(a_smem, a, (), (0, 0, tile_row, 0), axis=2)
a_reg = warp.load(a_reg, a_smem)
a_smem_ = warp.load(a_smem, a, (), (0, 0, tile_row, 0), axis=2)
a_reg_ = warp.load(a_reg, a_smem_)
a_reg *= 1.0 / math.log(2)
a_reg_ *= 1.0 / math.log(2)
max_vec_last = warp.copy(max_vec_last.after(tile_row), max_vec)
max_vec = warp.col_reduce(max_vec.after(max_vec_last), a_reg, lambda a, b: a.maximum(b), init_value=-math.inf)
a_reg = (a_reg - max_vec).exp2()
max_vec = warp.col_reduce(max_vec.after(max_vec_last), a_reg_, lambda a, b: a.maximum(b), init_value=-math.inf)
a_reg_ = (a_reg_ - max_vec).exp2()
max_vec_last = (max_vec_last - max_vec).exp2()
norm_vec *= max_vec_last
norm_vec = warp.col_reduce(norm_vec, a_reg, lambda a, b: a + b)
norm_vec = warp.col_reduce(norm_vec, a_reg_, lambda a, b: a + b)
norm_vec = ker.endrange()
max_vec = max_vec.after(norm_vec)
for tile_row in ker.range(N // BLOCK_SIZE):
a_smem = warp.load(a_smem, a, (), (0, 0, tile_row, 0), axis=2)
a_reg = warp.load(a_reg.after(norm_vec), a_smem)
a_smem_ = warp.load(a_smem, a, (), (0, 0, tile_row, 0), axis=2)
a_reg_ = warp.load(a_reg.after(norm_vec), a_smem_)
a_reg *= 1.0 / math.log(2)
a_reg = (a_reg - max_vec).exp2()
a_reg /= norm_vec
a_reg_ *= 1.0 / math.log(2)
a_reg_ = (a_reg_ - max_vec).exp2()
a_reg_ /= norm_vec
b = warp.store(b, a_reg, (0, 0, tile_row, 0), (), axis=2)
b = warp.store(b, a_reg_, (0, 0, tile_row, 0), (), axis=2)
sink = ker.finish()
@@ -605,6 +607,7 @@ class TestTK(unittest.TestCase):
att_block_mma = warp.copy(att_block_mma.after(kv_idx, norm_vec), att_block)
o_reg = warp.mma_AtB(o_reg, v_reg, att_block_mma)
o_reg = ker.endrange()
norm_vec = norm_vec.after(o_reg)
o_reg /= norm_vec
@@ -630,7 +633,7 @@ class TestTK(unittest.TestCase):
ref = q_permuted.scaled_dot_product_attention(k_permuted, v_permuted, is_causal=True, enable_gqa=True).float()
ref = ref.permute(0, 2, 1, 3)
np.testing.assert_allclose(out.numpy(), ref.numpy(), atol=1e-2, rtol=1e-5)
np.testing.assert_allclose(out.numpy(), ref.numpy(), atol=2e-2, rtol=2e-2)
# Run the unittest runner when this file is executed directly as a script.
if __name__ == "__main__":
unittest.main()

View File

@@ -314,6 +314,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
@functools.cached_property
def ended_ranges(self):
  """Return the sources this UOp reports as ended ranges.

  - op listed in `range_start`: the slice of `src` beginning at that op's start index.
  - `Ops.AFTER`: the `ended_ranges` of every source except the first, flattened into one tuple.
  - anything else: the empty tuple.
  """
  op = self.op
  if op in range_start:
    first_range = range_start[op]
    return self.src[first_range:]
  if op is Ops.AFTER:
    # src[0] is skipped; only the remaining sources contribute their ended ranges
    per_dep = [dep.ended_ranges for dep in self.src[1:]]
    return tuple(flatten(per_dep))
  return ()
# determine what ranges this is in