From 0d55aec605faca5ecac007fd058199ec5bd0aa37 Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Tue, 2 Dec 2025 18:42:58 -0800 Subject: [PATCH] fix after end (#13542) --- test/test_uop_graph.py | 12 +++++++++ test/testextra/test_tk.py | 53 +++++++++++++++++++++------------------ tinygrad/uop/ops.py | 1 + 3 files changed, 41 insertions(+), 25 deletions(-) diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 16aba5c44a..5b8bcd094e 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -660,6 +660,18 @@ class TestUOpGraph(unittest.TestCase): bad_gate = UOp.const(dtypes.int, 1) with self.assertRaises(AssertionError): to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))]) + def test_after_end(self): + r = UOp.range(10, 0) + + c = r + 1 + self.assertIn(r, c.ranges) + + e = UOp.const(dtypes.void, None).end(r) + self.assertNotIn(r, e.ranges) + + a = c.after(e) + self.assertNotIn(r, a.ranges) + @track_rewrites() def expander_rewrite(sink): return graph_rewrite(sink, sym + expander) diff --git a/test/testextra/test_tk.py b/test/testextra/test_tk.py index f34875295f..43c82d6859 100644 --- a/test/testextra/test_tk.py +++ b/test/testextra/test_tk.py @@ -422,28 +422,29 @@ class TestTK(unittest.TestCase): norm_vec = warp.zero(norm_vec) for tile_col in ker.range(N // BLOCK_SIZE): - a_smem = warp.load(a_smem, a, (), (0, 0, 0, tile_col), axis=2) - a_reg = warp.load(a_reg, a_smem) + a_smem_ = warp.load(a_smem, a, (), (0, 0, 0, tile_col), axis=2) + a_reg_ = warp.load(a_reg, a_smem_) - a_reg *= 1.0 / math.log(2) + a_reg_ *= 1.0 / math.log(2) max_vec_last = warp.copy(max_vec_last.after(tile_col), max_vec) - max_vec = warp.row_reduce(max_vec.after(max_vec_last), a_reg, lambda a, b: a.maximum(b), init_value=-math.inf) - a_reg = (a_reg - max_vec).exp2() + max_vec = warp.row_reduce(max_vec.after(max_vec_last), a_reg_, lambda a, b: a.maximum(b), init_value=-math.inf) + a_reg_ = (a_reg_ - max_vec).exp2() max_vec_last = (max_vec_last - max_vec).exp2() norm_vec *= max_vec_last - norm_vec = warp.row_reduce(norm_vec, a_reg, lambda a, b: a + b) + norm_vec = warp.row_reduce(norm_vec, a_reg_, lambda a, b: a + b) norm_vec = ker.endrange() + max_vec = max_vec.after(norm_vec) for tile_col in ker.range(N // BLOCK_SIZE): - a_smem = warp.load(a_smem, a, (), (0, 0, 0, tile_col), axis=2) - a_reg = warp.load(a_reg.after(norm_vec), a_smem) + a_smem_ = warp.load(a_smem, a, (), (0, 0, 0, tile_col), axis=2) + a_reg_ = warp.load(a_reg, a_smem_) - a_reg *= 1.0 / math.log(2) - a_reg = (a_reg - max_vec).exp2() - a_reg /= norm_vec + a_reg_ *= 1.0 / math.log(2) + a_reg_ = (a_reg_ - max_vec).exp2() + a_reg_ /= norm_vec - b = warp.store(b, a_reg, (0, 0, 0, tile_col), (), axis=2) + b = warp.store(b, a_reg_, (0, 0, 0, tile_col), (), axis=2) sink = ker.finish() @@ -481,28 +482,29 @@ class TestTK(unittest.TestCase): norm_vec = warp.zero(norm_vec) for tile_row in ker.range(N // BLOCK_SIZE): - a_smem = warp.load(a_smem, a, (), (0, 0, tile_row, 0), axis=2) - a_reg = warp.load(a_reg, a_smem) + a_smem_ = warp.load(a_smem, a, (), (0, 0, tile_row, 0), axis=2) + a_reg_ = warp.load(a_reg, a_smem_) - a_reg *= 1.0 / math.log(2) + a_reg_ *= 1.0 / math.log(2) max_vec_last = warp.copy(max_vec_last.after(tile_row), max_vec) - max_vec = warp.col_reduce(max_vec.after(max_vec_last), a_reg, lambda a, b: a.maximum(b), init_value=-math.inf) - a_reg = (a_reg - max_vec).exp2() + max_vec = warp.col_reduce(max_vec.after(max_vec_last), a_reg_, lambda a, b: a.maximum(b), init_value=-math.inf) + a_reg_ = (a_reg_ - max_vec).exp2() max_vec_last = (max_vec_last - max_vec).exp2() norm_vec *= max_vec_last - norm_vec = warp.col_reduce(norm_vec, a_reg, lambda a, b: a + b) + norm_vec = warp.col_reduce(norm_vec, a_reg_, lambda a, b: a + b) norm_vec = ker.endrange() + max_vec = max_vec.after(norm_vec) for tile_row in ker.range(N // BLOCK_SIZE): - a_smem = warp.load(a_smem, a, (), (0, 0, tile_row, 0), axis=2) - a_reg = warp.load(a_reg.after(norm_vec), a_smem) + a_smem_ = warp.load(a_smem, a, (), (0, 0, tile_row, 0), axis=2) + a_reg_ = warp.load(a_reg.after(norm_vec), a_smem_) - a_reg *= 1.0 / math.log(2) - a_reg = (a_reg - max_vec).exp2() - a_reg /= norm_vec + a_reg_ *= 1.0 / math.log(2) + a_reg_ = (a_reg_ - max_vec).exp2() + a_reg_ /= norm_vec - b = warp.store(b, a_reg, (0, 0, tile_row, 0), (), axis=2) + b = warp.store(b, a_reg_, (0, 0, tile_row, 0), (), axis=2) sink = ker.finish() @@ -605,6 +607,7 @@ class TestTK(unittest.TestCase): att_block_mma = warp.copy(att_block_mma.after(kv_idx, norm_vec), att_block) o_reg = warp.mma_AtB(o_reg, v_reg, att_block_mma) o_reg = ker.endrange() + norm_vec = norm_vec.after(o_reg) o_reg /= norm_vec @@ -630,7 +633,7 @@ class TestTK(unittest.TestCase): ref = q_permuted.scaled_dot_product_attention(k_permuted, v_permuted, is_causal=True, enable_gqa=True).float() ref = ref.permute(0, 2, 1, 3) - np.testing.assert_allclose(out.numpy(), ref.numpy(), atol=1e-2, rtol=1e-5) + np.testing.assert_allclose(out.numpy(), ref.numpy(), atol=2e-2, rtol=2e-2) if __name__ == "__main__": unittest.main() diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 0b49d24a0a..c0bc735824 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -314,6 +314,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): @functools.cached_property def ended_ranges(self): if self.op in range_start: return self.src[range_start[self.op]:] + if self.op is Ops.AFTER: return tuple(flatten([x.ended_ranges for x in self.src[1:]])) return () # determine what ranges this is in