mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix after end (#13542)
This commit is contained in:
@@ -660,6 +660,18 @@ class TestUOpGraph(unittest.TestCase):
|
||||
bad_gate = UOp.const(dtypes.int, 1)
|
||||
with self.assertRaises(AssertionError): to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])
|
||||
|
||||
def test_after_end(self):
  """An AFTER that depends on an END must not report the ended range in its ranges."""
  loop = UOp.range(10, 0)

  # an expression built from the range reports that range
  inside = loop + 1
  self.assertIn(loop, inside.ranges)

  # the END of the range does not report it
  closed = UOp.const(dtypes.void, None).end(loop)
  self.assertNotIn(loop, closed.ranges)

  # placing the expression after the END drops the range as well
  ordered = inside.after(closed)
  self.assertNotIn(loop, ordered.ranges)
|
||||
|
||||
@track_rewrites()
def expander_rewrite(sink):
  """Run graph_rewrite on `sink` with the combined `sym + expander` matcher."""
  matcher = sym + expander
  return graph_rewrite(sink, matcher)
|
||||
|
||||
|
||||
@@ -422,28 +422,29 @@ class TestTK(unittest.TestCase):
|
||||
norm_vec = warp.zero(norm_vec)
|
||||
|
||||
for tile_col in ker.range(N // BLOCK_SIZE):
|
||||
a_smem = warp.load(a_smem, a, (), (0, 0, 0, tile_col), axis=2)
|
||||
a_reg = warp.load(a_reg, a_smem)
|
||||
a_smem_ = warp.load(a_smem, a, (), (0, 0, 0, tile_col), axis=2)
|
||||
a_reg_ = warp.load(a_reg, a_smem_)
|
||||
|
||||
a_reg *= 1.0 / math.log(2)
|
||||
a_reg_ *= 1.0 / math.log(2)
|
||||
|
||||
max_vec_last = warp.copy(max_vec_last.after(tile_col), max_vec)
|
||||
max_vec = warp.row_reduce(max_vec.after(max_vec_last), a_reg, lambda a, b: a.maximum(b), init_value=-math.inf)
|
||||
a_reg = (a_reg - max_vec).exp2()
|
||||
max_vec = warp.row_reduce(max_vec.after(max_vec_last), a_reg_, lambda a, b: a.maximum(b), init_value=-math.inf)
|
||||
a_reg_ = (a_reg_ - max_vec).exp2()
|
||||
max_vec_last = (max_vec_last - max_vec).exp2()
|
||||
norm_vec *= max_vec_last
|
||||
norm_vec = warp.row_reduce(norm_vec, a_reg, lambda a, b: a + b)
|
||||
norm_vec = warp.row_reduce(norm_vec, a_reg_, lambda a, b: a + b)
|
||||
norm_vec = ker.endrange()
|
||||
max_vec = max_vec.after(norm_vec)
|
||||
|
||||
for tile_col in ker.range(N // BLOCK_SIZE):
|
||||
a_smem = warp.load(a_smem, a, (), (0, 0, 0, tile_col), axis=2)
|
||||
a_reg = warp.load(a_reg.after(norm_vec), a_smem)
|
||||
a_smem_ = warp.load(a_smem, a, (), (0, 0, 0, tile_col), axis=2)
|
||||
a_reg_ = warp.load(a_reg, a_smem_)
|
||||
|
||||
a_reg *= 1.0 / math.log(2)
|
||||
a_reg = (a_reg - max_vec).exp2()
|
||||
a_reg /= norm_vec
|
||||
a_reg_ *= 1.0 / math.log(2)
|
||||
a_reg_ = (a_reg_ - max_vec).exp2()
|
||||
a_reg_ /= norm_vec
|
||||
|
||||
b = warp.store(b, a_reg, (0, 0, 0, tile_col), (), axis=2)
|
||||
b = warp.store(b, a_reg_, (0, 0, 0, tile_col), (), axis=2)
|
||||
|
||||
sink = ker.finish()
|
||||
|
||||
@@ -481,28 +482,29 @@ class TestTK(unittest.TestCase):
|
||||
norm_vec = warp.zero(norm_vec)
|
||||
|
||||
for tile_row in ker.range(N // BLOCK_SIZE):
|
||||
a_smem = warp.load(a_smem, a, (), (0, 0, tile_row, 0), axis=2)
|
||||
a_reg = warp.load(a_reg, a_smem)
|
||||
a_smem_ = warp.load(a_smem, a, (), (0, 0, tile_row, 0), axis=2)
|
||||
a_reg_ = warp.load(a_reg, a_smem_)
|
||||
|
||||
a_reg *= 1.0 / math.log(2)
|
||||
a_reg_ *= 1.0 / math.log(2)
|
||||
|
||||
max_vec_last = warp.copy(max_vec_last.after(tile_row), max_vec)
|
||||
max_vec = warp.col_reduce(max_vec.after(max_vec_last), a_reg, lambda a, b: a.maximum(b), init_value=-math.inf)
|
||||
a_reg = (a_reg - max_vec).exp2()
|
||||
max_vec = warp.col_reduce(max_vec.after(max_vec_last), a_reg_, lambda a, b: a.maximum(b), init_value=-math.inf)
|
||||
a_reg_ = (a_reg_ - max_vec).exp2()
|
||||
max_vec_last = (max_vec_last - max_vec).exp2()
|
||||
norm_vec *= max_vec_last
|
||||
norm_vec = warp.col_reduce(norm_vec, a_reg, lambda a, b: a + b)
|
||||
norm_vec = warp.col_reduce(norm_vec, a_reg_, lambda a, b: a + b)
|
||||
norm_vec = ker.endrange()
|
||||
max_vec = max_vec.after(norm_vec)
|
||||
|
||||
for tile_row in ker.range(N // BLOCK_SIZE):
|
||||
a_smem = warp.load(a_smem, a, (), (0, 0, tile_row, 0), axis=2)
|
||||
a_reg = warp.load(a_reg.after(norm_vec), a_smem)
|
||||
a_smem_ = warp.load(a_smem, a, (), (0, 0, tile_row, 0), axis=2)
|
||||
a_reg_ = warp.load(a_reg.after(norm_vec), a_smem_)
|
||||
|
||||
a_reg *= 1.0 / math.log(2)
|
||||
a_reg = (a_reg - max_vec).exp2()
|
||||
a_reg /= norm_vec
|
||||
a_reg_ *= 1.0 / math.log(2)
|
||||
a_reg_ = (a_reg_ - max_vec).exp2()
|
||||
a_reg_ /= norm_vec
|
||||
|
||||
b = warp.store(b, a_reg, (0, 0, tile_row, 0), (), axis=2)
|
||||
b = warp.store(b, a_reg_, (0, 0, tile_row, 0), (), axis=2)
|
||||
|
||||
sink = ker.finish()
|
||||
|
||||
@@ -605,6 +607,7 @@ class TestTK(unittest.TestCase):
|
||||
att_block_mma = warp.copy(att_block_mma.after(kv_idx, norm_vec), att_block)
|
||||
o_reg = warp.mma_AtB(o_reg, v_reg, att_block_mma)
|
||||
o_reg = ker.endrange()
|
||||
norm_vec = norm_vec.after(o_reg)
|
||||
|
||||
o_reg /= norm_vec
|
||||
|
||||
@@ -630,7 +633,7 @@ class TestTK(unittest.TestCase):
|
||||
ref = q_permuted.scaled_dot_product_attention(k_permuted, v_permuted, is_causal=True, enable_gqa=True).float()
|
||||
ref = ref.permute(0, 2, 1, 3)
|
||||
|
||||
np.testing.assert_allclose(out.numpy(), ref.numpy(), atol=1e-2, rtol=1e-5)
|
||||
np.testing.assert_allclose(out.numpy(), ref.numpy(), atol=2e-2, rtol=2e-2)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -314,6 +314,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
@functools.cached_property
def ended_ranges(self):
  """Return the ranges this UOp ends, as a tuple.

  - For an op listed in `range_start`, these are the sources from index
    `range_start[op]` onward.
  - For `Ops.AFTER`, these are the ended ranges of every source except the
    first, concatenated in source order.
  - Any other op ends no ranges.
  """
  if self.op in range_start:
    first_range_idx = range_start[self.op]
    return self.src[first_range_idx:]
  if self.op is Ops.AFTER:
    # src[0] is skipped; only the remaining sources contribute ended ranges
    per_dep = [dep.ended_ranges for dep in self.src[1:]]
    return tuple(flatten(per_dep))
  return ()
|
||||
|
||||
# determine what ranges this is in
|
||||
|
||||
Reference in New Issue
Block a user