diff --git a/test/test_schedule.py b/test/test_schedule.py
index 734e32b813..abd9a71031 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -16,7 +16,7 @@ from tinygrad.shape.view import View
 from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, merge_views, GroupOp
 from tinygrad.codegen.symbolic import symbolic_simple
 from tinygrad.spec import type_verify, shape_spec
-from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same, temp
+from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, all_same, temp
 from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, sym
 from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
 from extra.models.llama import precompute_freqs_cis
@@ -178,6 +178,13 @@ class TestSchedule(unittest.TestCase):
     c = a.sum(axis=0, keepdim=True).permute(2,1,0) + b
     with self.assertRaises(KernelCountException): check_schedule(c, 1)

+  def test_allow_push_permutes(self):
+    a = Tensor.randn(10,10,10).realize()
+    b = Tensor.randn(10,10,1).realize()
+    c = a.sum(axis=0, keepdim=True).permute(2,1,0) + b
+    with Context(DONT_GROUP_REDUCES=1): run_schedule(check_schedule(c, 1))
+    np.testing.assert_allclose(c.numpy(), np.sum(a.numpy(), axis=0, keepdims=True).transpose(2,1,0)+b.numpy())
+
   def test_binop_early_reshape_reduce_fusion(self):
     a = Tensor.empty(100)
     b = Tensor.empty(100)
@@ -247,20 +254,14 @@ class TestSchedule(unittest.TestCase):
   def test_div_collapse_buffer(self):
     a = Tensor.full((4,), 4.0).contiguous().realize()
     b = Tensor.full((4,), 2.0).contiguous().realize()
-    GlobalCounters.reset()
     expr = (a*b)/b
-    expr.realize()
-    self.assertEqual(GlobalCounters.kernel_count, 0) # the scheduler can fold divs now!
-    self.assertEqual(GlobalCounters.global_ops, 0)
+    check_schedule(expr, 0)
     np.testing.assert_allclose(expr.numpy(), np.full((4,), 4.0))

   def test_div_collapse_const(self):
     a = Tensor.full((4,), 4.0).contiguous().realize()
-    GlobalCounters.reset()
     expr = a/a
-    expr.realize()
-    self.assertEqual(GlobalCounters.kernel_count, 0)
-    self.assertEqual(GlobalCounters.global_ops, 0)
+    check_schedule(expr, 0)
     np.testing.assert_allclose(expr.numpy(), np.full((4,), 1.0))

   def test_div_collapse(self):
@@ -316,7 +317,7 @@ class TestSchedule(unittest.TestCase):

   def test_fold_double_unary(self):
     y = Tensor.empty(2)
-    out = y.sum(keepdim=True).sqrt().__neg__()
+    out = y.sum(keepdim=True).sqrt().neg()
     check_schedule(out, 1)

   #@unittest.skip("may want to reconsider this")
@@ -1871,7 +1872,7 @@ class TestIndexing(unittest.TestCase):
     ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
     r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0, 1)))
     r = UOp(Ops.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
-    r = r + 2
+    r = r + r.const_like(2).replace(src=(unwrap(r.st).to_uop(),))
     sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
     rsink = graph_rewrite(sink, view_right)
     # this AST first needs to swizzle, but it doesn't have implicit movementops
@@ -1889,126 +1890,73 @@ class TestIndexing(unittest.TestCase):

 @track_rewrites(named=True)
 def swizzle_rewrite(u:UOp) -> UOp: return graph_rewrite(graph_rewrite(u, view_left), view_right)
-
 def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort if x.op is Ops.VIEW and len(x.src) != 0])
-# these pattern matchers should move to engine/schedule.py
-
-ops_folding = symbolic_simple+PatternMatcher([
-  (UPat(Ops.DETACH, name="x"), lambda x:x.src[0]),
-])
-
-def _load_buffer(ctx:list[UOp], buf:UOp):
-  glbl = UOp(Ops.DEFINE_GLOBAL, buf.dtype.ptr(size=buf.size), (), len(ctx))
-  ctx.append(buf)
-  return UOp(Ops.LOAD, buf.dtype, (glbl, ShapeTracker.from_shape((buf.size,)).to_uop()))
-load_buffers = PatternMatcher([
-  (UPat(Ops.BUFFER, name="buf"), _load_buffer),
-])
-
-# put the entire schedule of the tensor in a single ScheduleItem
-@track_rewrites(named=True)
-def run_tensor_ast(r:Tensor):
-  output = UOp.new_buffer(r.device, r.lazydata.size, r.dtype)
-  glbl = UOp(Ops.DEFINE_GLOBAL, output.dtype.ptr(size=output.size), (), 0)
-  sink = UOp(Ops.STORE, src=(glbl, ShapeTracker.from_shape(r.lazydata.base.shape).to_uop(), r.lazydata.base)).sink()
-  sink = graph_rewrite(sink, remove_movement_ops+ops_folding+load_buffers+view_left, bufs:=[output])
-  sink = graph_rewrite(sink, remove_movement_ops+ops_folding+view_right)
-  si = ScheduleItem(sink, tuple(x.buffer for x in bufs), ())
-  run_schedule([si])
-  return output.realized.as_buffer().cast(output.dtype.fmt, r.shape).tolist()
-
 class TestSwizzle(unittest.TestCase):
   def test_swizzle_simple(self):
+    Tensor.manual_seed(0)
     with Context(DEBUG=0, TRACK_MATCH_STATS=0):
       a = Tensor.randint(32, 32).realize()
-    # double reduce collapses to a single reduce
     r = (a+a).sum(1).sum(0)
-    self.assertEqual(run_tensor_ast(r), (a.numpy()+a.numpy()).sum(1).sum(0))
+    # double reduce collapses to a single reduce
+    with Context(DONT_GROUP_REDUCES=1):
+      run_schedule(check_schedule(r, 1))
+    self.assertEqual(r.numpy(), (a.numpy()+a.numpy()).sum(1).sum(0))

   def test_single_swizzle(self):
+    Tensor.manual_seed(0)
     with Context(DEBUG=0, TRACK_MATCH_STATS=0):
       a = Tensor.randint(4, 1).realize()
       b = Tensor.ones((1, 1), dtype=a.dtype).contiguous().realize()
     # ADD(REDUCE(RESHAPE(LOAD)), LOAD) to ADD(REDUCE(RESHAPE(LOAD))), RESHAPE(LOAD)
     r = a.sum(0)+b
-    self.assertEqual(run_tensor_ast(r), a.numpy().sum(0)+1)
+    run_schedule(check_schedule(r, 1))
+    self.assertEqual(r.numpy(), a.numpy().sum(0)+1)

   def test_double_swizzle_possible(self):
+    Tensor.manual_seed(0)
     with Context(DEBUG=0, TRACK_MATCH_STATS=0):
-      Tensor.manual_seed(0)
       a = Tensor.randint(4,).realize()
       b = Tensor.randint(4,).realize()
     # parallel reduce!
     add = a.sum(0)+b.sum(0)
-    self.assertEqual(run_tensor_ast(add), a.numpy().sum(0)+b.numpy().sum(0))
+    with Context(DONT_GROUP_REDUCES=1):
+      run_schedule(check_schedule(add, 1))
+    self.assertEqual(add.numpy(), a.numpy().sum(0)+b.numpy().sum(0))

-  # TODO: this is failing because it cannot resolve the final shape of two swizzled sources
-  @unittest.expectedFailure
-  def test_softmax(self):
+  @unittest.skip("TODO: how do we express the norm")
+  def test_softmax_one_kernel(self):
+    Tensor.manual_seed(0)
     with Context(DEBUG=0, TRACK_MATCH_STATS=0):
-      Tensor.manual_seed(0)
       a = Tensor.randn(32, 32).realize()
     t = a.softmax()
-    run_tensor_ast(t)
+    with Context(DONT_GROUP_REDUCES=1, DONT_REALIZE_EXPAND=1):
+      check_schedule(t, 1)

-  def test_swizzle_rewrite_alt(self):
-    swizzle = UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 3, 3, 65, 3, 65), strides=(103788, 34596, 3, 558, 1, 9), offset=0, mask=((0, 2), (0, 3), (0, 3), (0, 62), (0, 3), (0, 62)), contiguous=False), View(shape=(2, 3, 256, 256), strides=(114075, 38025, 195, 1), offset=0, mask=((0, 2), (0, 3), (0, 195), (0, 195)), contiguous=False), View(shape=(1, 2, 1, 3, 4, 64, 4, 64), strides=(0, 196608, 0, 65536, 16384, 256, 64, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501
-      UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=(
-        UOp(Ops.LOAD, dtypes.float, arg=None, src=(
-          UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
-          UOp(Ops.VIEW, dtypes.void, arg=(ld_st:=ShapeTracker(views=(View(shape=(2, 1, 3, 16, 62, 62, 3, 3), strides=(0, 0, 9, 27, 0, 0, 3, 1), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa: E501
-    # there's an UNROLL pushing through the REDUCE_AXIS
-    self.assertGreater(prod(swizzle.st.shape), prod(swizzle.src[0].st.shape))
-    ret = swizzle_rewrite(swizzle)
-    # UNROLL is rewritten
-    self.assertEqual(prod(ret.st.shape), prod(ret.src[0].st.shape))
-    # and pushed to the LOAD
-    new_load_st = unwrap([x for x in ret.toposort if x.op is Ops.VIEW][0].st)
-    self.assertGreater(prod(new_load_st.shape), prod(ld_st.shape))
-    self.assertEqual(new_load_st.views[0].strides, (0, 9, 3, 0, 1, 0, 27))
+  def test_argmax_one_kernel(self):
+    Tensor.manual_seed(0)
+    with Context(DEBUG=0, TRACK_MATCH_STATS=0):
+      a = Tensor.randn(10, 20).realize()
+    t = a.argmax(0)
+    with Context(DONT_GROUP_REDUCES=1, DONT_REALIZE_EXPAND=1): t.realize()
+
+  def test_swizzle_reduceop(self):
+    Tensor.manual_seed(0)
+    x = Tensor.randn(4,4).realize()
+    y = Tensor.randn(4,4,4).realize()
+    out = x.reshape(4,4,1).expand(4,4,4).sum(axis=(1,))+y
+    with Context(DONT_REALIZE_EXPAND=1, DONT_GROUP_REDUCES=1):
+      run_schedule(check_schedule(out, 1))
+    np.testing.assert_allclose(out.numpy(), np.tile(x.numpy().reshape(4,4,1), (1,1,4)).sum(axis=1)+y.numpy())

   def test_permute_rewrite(self):
     x = Tensor.randn(4, 4, 16).realize()
     y = Tensor.randn(4, 1, 16).realize()
     z = Tensor.randn(4, 4, 1).realize()
     t = (x*y).sum(axis=(0, 2)).reshape(1, 4, 1).permute(0, 2, 1)+z
+    with Context(DONT_GROUP_REDUCES=1, DONT_REALIZE_EXPAND=1): run_schedule(check_schedule(t, 1))
     t_np = (x.numpy()*y.numpy()).sum(axis=(0, 2)).reshape(1, 4, 1).transpose(0, 2, 1)+z.numpy()
-    np.testing.assert_allclose(run_tensor_ast(t), t_np, atol=1e-6, rtol=1e-3)
-
-  @unittest.expectedFailure
-  def test_fuse_conv2_relu_bw(self):
-    # fuse (relu bw, conv2d, conv2d bw, relu)
-    sink = UOp(Ops.SINK, dtypes.void, arg=None, src=(
-      UOp(Ops.STORE, dtypes.void, arg=None, src=(
-        UOp(Ops.BUFFER, dtypes.float, arg=(10, ('METAL', 128, dtypes.float)), src=()),
-        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 16, 2, 2), strides=(64, 4, 2, 1), offset=0, mask=None, contiguous=True),)), src=()),
-        UOp(Ops.MUL, dtypes.float, arg=None, src=(
-          UOp(Ops.CAST, dtypes.float, arg=None, src=(
-            UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
-              x6:=UOp(Ops.WHERE, dtypes.float, arg=None, src=(
-                UOp(Ops.VALID, dtypes.bool, arg=None, src=(
-                  UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 16, 2, 2), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
-                x9:=UOp(Ops.CONST, dtypes.float, arg=0.0, src=()),
-                x9,)),
-              UOp(Ops.MAX, dtypes.float, arg=None, src=(
-                UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 16, 2, 2), strides=(64, 4, 2, 1), offset=0, mask=None, contiguous=True),)), src=(
-                  UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
-                    UOp(Ops.MUL, dtypes.float, arg=None, src=(
-                      UOp(Ops.LOAD, dtypes.float, arg=None, src=(
-                        UOp(Ops.BUFFER, dtypes.float, arg=(9, ('METAL', 96, dtypes.float)), src=()),
-                        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 16, 2, 2, 3, 3, 3), strides=(48, 0, 0, 4, 1, 16, 4, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
-                      UOp(Ops.PRELOAD, dtypes.float, arg=None, src=(
-                        UOp(Ops.BUFFER, dtypes.float, arg=(16, ('METAL', 432, dtypes.float)), src=()),
-                        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 16, 2, 2, 3, 3, 3), strides=(0, 0, 27, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),
-                x6,)),)),)),
-          UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 16, 2, 2), strides=(64, 4, 2, 1), offset=0, mask=None, contiguous=True),)), src=(
-            UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (4, 6)), src=(
-              UOp(Ops.LOAD, dtypes.float, arg=None, src=(
-                UOp(Ops.BUFFER, dtypes.float, arg=(18, ('METAL', 128, dtypes.float)), src=()),
-                UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 16, 2, 3, 2, 3), strides=(64, 4, 2, 0, 1, 0), offset=0, mask=((0, 2), (0, 16), (0, 2), (0, 1), (0, 2), (0, 1)), contiguous=False), View(shape=(1, 2, 1, 16, 3, 2, 3, 2), strides=(0, 576, 0, 36, 12, 6, 2, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),))
-    ret = swizzle_rewrite(sink)
-    self.assertEqual(swizzle_cnt(ret), 0)
+    np.testing.assert_allclose(t.numpy(), t_np, atol=1e-6, rtol=1e-3)

   @unittest.skip("this swizzle can't be decided after the ADD")
   def test_swizzle_failure_permute(self):
@@ -2052,39 +2000,6 @@ class TestSwizzle(unittest.TestCase):
     ret = swizzle_rewrite(sink)
     self.assertEqual(swizzle_cnt(ret), 0)

-  def test_non_contiguous_view_simplify(self):
-    st = ShapeTracker(views=(View(shape=(2048, 2048), strides=(1, 2048), offset=0, mask=None, contiguous=False),))
-    a = UOp(Ops.LOAD, dtypes.char, (UOp.new_buffer(Device.DEFAULT, 4194304, dtypes.char), st.to_uop()))
-    ret = swizzle_rewrite(a.view(st))
-    self.assertEqual(ret.st_arg, st+st)
-
-  def test_contiguous_view_simplify(self):
-    base = ShapeTracker.from_shape((32, 32))
-    a = UOp(Ops.LOAD, dtypes.char, (UOp.new_buffer(Device.DEFAULT, base.size, dtypes.char), base.to_uop()))
-    swizzle = a.reshape((64, 16))
-    swizzle = graph_rewrite(swizzle, remove_movement_ops)
-    self.assertEqual(swizzle_cnt(swizzle), 1)
-    ret = swizzle_rewrite(swizzle)
-    self.assertEqual(ret.st_arg, base.reshape((64, 16))) # late rewrite
-    reswizzle = a.reshape((64, 16)).reshape((32, 32))
-    self.assertEqual(swizzle_cnt(reswizzle), 0) # instant rule
-    ret = swizzle_rewrite(reswizzle)
-    self.assertEqual(ret.st, reswizzle.st)
-
-  def test_late_fusion_post_permute_simpler(self):
-    base = ShapeTracker.from_shape((32, 16, 1))
-    start = UOp(Ops.LOAD, dtypes.char, (UOp.new_buffer(Device.DEFAULT, base.size, dtypes.char), base.to_uop()))
-    r = start.expand((32, 16, 16)).r(Ops.ADD, (2,))
-    add = r.reshape((16, 32, 1)) + UOp.const(r.dtype, 0)
-    self.assertEqual(add.st, ShapeTracker.from_shape((16, 32, 1)))
-    to_store = add.permute((1, 0, 2)).contiguous()
-    to_store = graph_rewrite(to_store, remove_movement_ops)
-    self.assertEqual(to_store.st, ShapeTracker.from_shape((32, 16, 1)))
-    self.assertEqual(to_store.src[0].st, add.st.permute((1, 0, 2)))
-    self.assertIs(to_store.src[0].op, Ops.VIEW)
-    ret = graph_rewrite(to_store, view_left)
-    self.assertEqual(swizzle_cnt(ret), 1)
-
 def store_val(si:ScheduleItem): return si.ast.src[0].src[2]
 zero_pm = UPat(Ops.CONST, arg=0)
 class TestView(unittest.TestCase):
@@ -2159,7 +2074,7 @@ class TestView(unittest.TestCase):
     self.assertEqual(other_child.tolist(), [2, 3, 4])

 def tensor_rewrite(t) -> UOp: return graph_rewrite(t.lazydata.base, remove_movement_ops+symbolic_simple)
-class TestBigGraph(unittest.TestCase):
+class TestSimplifier(unittest.TestCase):
   def test_sink_childless_const(self):
     x = Tensor(0)
     check_schedule(x, 0)
@@ -2242,12 +2157,11 @@ class TestConst(unittest.TestCase):
     a = Tensor.ones((4,)).pad((1, 1)).contiguous()
     sched = a.schedule()
     print(sched[0].ast)
-    const_ast_pattern = UPat(Ops.SINK, src=(UPat.store(UPat(), UPat(), UPat(Ops.WHERE, src=(UPat(Ops.VALID), UPat.cvar("x"), UPat(Ops.CONST, arg=0)))),))
+    const_ast_pattern = UPat(Ops.SINK, src=(UPat.store(UPat(), UPat(), UPat.where(UPat(Ops.VALID), UPat.cvar("x"), UPat(Ops.CONST, arg=0))),))
     self.assertEqual(len(const_ast_pattern.match(sched[0].ast, {})), 1)
     run_schedule(sched)
     self.assertListEqual(a.tolist(), [0, 1, 1, 1, 1, 0])

-  # TOOD: currently even unmasked constants are VALID until codegen
   def test_unmasked_const_ast(self):
     a = Tensor.ones((4,)).contiguous()
     sched = a.schedule()
@@ -2641,5 +2555,12 @@ class TestUOpBecome(unittest.TestCase):
     assert b.lazydata is c.lazydata
     assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)).match(c.lazydata, {})

+  def test_setitem_becomes_view_of_base(self):
+    a = Tensor.full((4,), 2.).contiguous().realize()
+    b = a.shrink(((0, 2),)).assign(Tensor.full((2,), 1.0))
+    b.realize()
+    assert b.lazydata.is_realized
+    assert b.lazydata.base.buffer._base is None
+
 if __name__ == '__main__':
   unittest.main(verbosity=2)
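The new and rewritten tests above all follow one pattern: build a lazy expression, schedule it with the new DONT_GROUP_REDUCES flag (plus DONT_REALIZE_EXPAND where expands are involved), assert the kernel count, then check numerics against numpy. A minimal standalone sketch of that pattern (illustration only, not part of the patch; the tests use the check_schedule helper defined earlier in test_schedule.py, approximated here with Tensor.schedule()):

import numpy as np
from tinygrad import Tensor
from tinygrad.helpers import Context
from tinygrad.engine.realize import run_schedule

def permute_of_reduce_fuses():
  a = Tensor.randn(10, 10, 10).realize()
  b = Tensor.randn(10, 10, 1).realize()
  c = a.sum(axis=0, keepdim=True).permute(2, 1, 0) + b
  with Context(DONT_GROUP_REDUCES=1):
    sched = c.schedule()   # list of ScheduleItem, one per kernel
    assert len(sched) == 1, f"expected a single fused kernel, got {len(sched)}"
    run_schedule(sched)
  np.testing.assert_allclose(c.numpy(), np.sum(a.numpy(), axis=0, keepdims=True).transpose(2, 1, 0) + b.numpy())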
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index a8d87a3453..8a02529e02 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -5,7 +5,7 @@ from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, grap
 from tinygrad.ops import can_pad, identity_element, resolve, view_left, merge_views
 from tinygrad.codegen.symbolic import symbolic_simple
 from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, unwrap, flatten, getenv, pluralize
-from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, DONT_REALIZE_EXPAND, SPLIT_REDUCEOP
+from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, SPLIT_REDUCEOP
 from tinygrad.dtype import ImageDType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View, strides_for_shape
@@ -165,6 +165,7 @@ create_ctx = PatternMatcher([(UPat(GroupOp.All-{Ops.SINK, Ops.VIEW}, name="u"),
 def group_realizes(sink:UOp) -> dict[UOp, None]:
   # start by adding uops that always realize
   sink = graph_rewrite(sink, do_realize+create_ctx, ctx:=GrouperContext({}, {}, defaultdict(dict)))
+  if DONT_GROUP_REDUCES: return ctx.realizes
   # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
   reduce_for_op: dict[UOp, UOp] = {}
   double_reduces: list[UOp] = []
@@ -230,6 +231,11 @@ class KernelContext:
   realizes: dict[UOp, None]
   ops_metadata: dict[UOp, Metadata]

+def create_kernel(x:UOp, b:UOp):
+  kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x))
+  buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
+  return UOp(Ops.ASSIGN, x.dtype, (buffer, kernel)).reshape(x.shape)
+
 DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER}
 def append_to_kernel(ctx:KernelContext, x:UOp):
   new_srcs: list[UOp] = []
@@ -244,11 +250,14 @@ def append_to_kernel(ctx:KernelContext, x:UOp):
   return x.replace(arg=Kernel(x.arg.ast, new_metadata)) if (new_metadata:=tuple(metadata)) != x.arg.metadata else None

 create_kernels = merge_views+PatternMatcher([
-  # always give assign a kernel
-  (UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), lambda x,b: b.assign(UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x)))),
-  # otherwise check if need to assign this UOp to a new buffer
-  (UPat(GroupOp.All-DONT_PLACE_IN_KERNEL, name="x"), lambda ctx,x: UOp(Ops.ASSIGN, x.dtype, (b:=UOp.new_buffer(x.device, x.size, x.dtype).view(x.st),\
-    UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x)))) if x in ctx.realizes else None),
+  # always give assign/contiguous a kernel
+  (UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), create_kernel),
+  (UPat(Ops.CONTIGUOUS, name="x"), lambda x: create_kernel(x, UOp.new_buffer(x.device, x.size, x.dtype))),
+  # create a buffer for COPY on the new device
+  (UPat(Ops.COPY, src=(UPat(Ops.DEVICE, name="d"), UPat()), name="x"), lambda d,x: create_kernel(x, UOp.new_buffer(d.arg, x.size, x.dtype))),
+  # otherwise check the context if we're realizing this UOp
+  (UPat(GroupOp.All-DONT_PLACE_IN_KERNEL, name="x"),
+   lambda ctx,x: create_kernel(x, UOp.new_buffer(x.device, x.size, x.dtype)) if x in ctx.realizes else None),
   # walk back the local graph until we reach a buffer/assign parent
   (UPat(Ops.KERNEL, name="x"), append_to_kernel),
   # remove CONST/BIND from SINK
@@ -260,13 +269,9 @@ create_kernels = merge_views+PatternMatcher([

 # ** create buffer ops + enumerate buffers

-def load_buf(ctx:list[UOp], x:UOp):
-  if x not in ctx: ctx.append(x)
-  return UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), unwrap(x.st).to_uop()))
-
 add_buffer_ops = PatternMatcher([
   # LOAD
-  (UPat(Ops.BUFFER, name="x"), load_buf),
+  (UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), x.st.to_uop()))),
   # STORE (except for COPY/BUFFER_VIEW)
   (UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
   (UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
@@ -278,8 +283,9 @@ add_buffer_ops = PatternMatcher([
 def apply_swizzle(u:UOp) -> UOp:
   with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left)

-def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
-  input_st = ShapeTracker.from_shape(unwrap(src.st).shape)
+def swizzle_reduceop(r:UOp, src:UOp, view:UOp):
+  if (st:=unwrap(view.st)).contiguous: return None
+  input_st = ShapeTracker.from_shape(src.shape)
   tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg)
   prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):])
   strides = strides_for_shape(rshape)
@@ -290,20 +296,18 @@ def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
   new_axis = tuple(range(len(st.shape), len(st.shape) + len(r.axis_arg)))
   return apply_swizzle(src.view(new_input_st)).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))

-def reduceop_view_right(r:UOp, v:UOp, src:UOp) -> UOp:
-  if not (swizzle_st:=unwrap(v.st)).contiguous or v.size != src.size: raise AssertionError(f"can't push {v} down through {src}")
-  output_shape = swizzle_st.reduce(r.axis_arg)
-  return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, output_shape)) if s != u)).view(ShapeTracker.from_shape(output_shape))
+def reduceop_view_right(src:UOp, v:UOp, r:UOp):
+  assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}"
+  return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, r.shape)) if s != u)).view(ShapeTracker.from_shape(r.shape))

 def elementwise_view_right(root:UOp) -> UOp|None:
-  if len(swizzles:=[x for x in root.src if x.base is not x]) == 0: return None
-  assert all(x.base.st is not None for x in swizzles), f"found shapeless VIEW src in {root}"
+  if not (swizzles:=[x for x in root.src if x.op is Ops.VIEW]): return None
   assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
-  # push the swizzle from src to root
-  output_swizzle = swizzles[0]
-  new_input_st = ShapeTracker.from_shape(output_swizzle.base.shape)
-  ret = root.replace(src=tuple(x if x.st is None else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src))
-  return ret.view(ShapeTracker.from_shape(output_swizzle.shape))
+  # place view after applying the elementwise op
+  new_shape = swizzles[0].base.shape
+  ret = root.replace(src=tuple(x.base if x.base.shape == new_shape else apply_swizzle(x.view(ShapeTracker.from_shape(new_shape))) for x in root.src))
+  # reshape to match downstream shapes
+  return ret.reshape(root.shape)

 def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
   assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
@@ -317,12 +321,12 @@ view_right = merge_views+PatternMatcher([
     lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))),
   # STORE is the last child, so we just merge the ShapeTrackers and store the base
   (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)),
-  # REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view()
-  (UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)),
-  # REDUCE(src.view()) -> REDUCE(src).view()
-  (UPat(Ops.REDUCE_AXIS, src=(UPat.var("src").view(name="v"),), name="r"), reduceop_view_right),
-  # ALU(src.view()) -> ALU(src).view()
-  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), elementwise_view_right),
+  # push a non contiguous ShapeTracker through reduceop
+  (UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop),
+  # apply view after reduceops
+  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat.var("src"),), name="v"),), name="r"), reduceop_view_right),
+  # apply view after elementwise ops
+  (UPat(GroupOp.All-GroupOp.Buffer, name="root"), elementwise_view_right),
   # double reduce op collapses to a single reduce op
   (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
 ])
@@ -369,10 +373,10 @@ fix_kernel_ops = PatternMatcher([

 def fix_kernel_ast(k:UOp, var_vals:dict[Variable, int]) -> UOp:
   assert k.op is Ops.KERNEL, f"kernel isn't kernel, it's {k}"
-  # substitute kernel sources for the target buffer
-  ast = k.arg.ast.substitute({s.src[1].arg.ast:s.src[0] for s in k.src if s.op is Ops.ASSIGN}).sink()
+  # substitute kernel sources for the target buffer + apply reshapes
+  ast = k.arg.ast.substitute({(ast:=s.src[1].arg.ast):s.src[0].view(unwrap(ast.st)) for s in k.src if s.op is Ops.ASSIGN}).sink()
   # add buffer ops
-  ast = graph_rewrite(ast, add_buffer_ops, bufs:=[s.buf_uop for s in k.src], bottom_up=True)
+  ast = graph_rewrite(ast, add_buffer_ops, bufs:=tuple(s.buf_uop for s in k.src), bottom_up=True)
   if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
   # unbind_vars + push views to edges
   ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)
@@ -417,9 +421,11 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
   # map tensors to buffer/const, optionally apply a VIEW on top
   becomes_map: dict[UOp, UOp] = {}
   for k,v in tensor_map.items():
-    # if we created a KERNEL for this tensor, map it to the assigned buffer
-    if (a:=kernel_map.get(v.base)) is not None and a.op is Ops.ASSIGN:
-      becomes_map[k] = a.src[0] if v is v.base else a.src[0].view(unwrap(v.st))
+    # ASSIGN always becomes the target buffer
+    if v.op is Ops.ASSIGN: becomes_map[k] = v.src[0]
+    # if we created a new buffer for this tensor, map it to the assigned buffer
+    elif (a:=kernel_map.get(v.base)) is not None and (a:=a.base).op is Ops.ASSIGN:
+      becomes_map[k] = a.src[0] if a.src[0].st == v.st else a.src[0].view(unwrap(v.st))
     # tensors can also simplify to an existing buffer/const
     else:
       if k is v: continue
@@ -463,8 +469,8 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
       # TODO: move this to create_kernels
      k = fix_kernel_ast(u.src[1], var_vals)
       schedule.append(ScheduleItem(k.arg.ast, tuple(s.buf_uop.buffer for s in k.src), k.arg.metadata))
-      # increment the refcount of the target buf (this is required by the JIT and memory planner)
-      u.buf_uop.buffer.ref(1)
+      # increment the refcount of the target buf (this is required by the JIT and memory planner) TODO: this does not belong here
+      k.src[0].buffer.ref(1)
     for x in children.get(u, []):
       in_degree[x] -= 1
       if in_degree[x] == 0: queue.append(x)
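create_kernels and view_right above are ordinary PatternMatcher tables driven by graph_rewrite: each entry pairs a UPat with a function that returns a replacement UOp, or None to decline the match. A tiny self-contained illustration of that machinery (assumption: illustrative only, not taken from this patch), folding MUL(x, 1) to x the same way the rules above fold views and reduces:

from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite
from tinygrad.dtype import dtypes

# one rule: MUL(x, CONST 1) -> x; graph_rewrite applies rules until a fixed point
fold_mul_one = PatternMatcher([
  (UPat(Ops.MUL, src=(UPat.var("x"), UPat(Ops.CONST, arg=1))), lambda x: x),
])

expr = UOp(Ops.MUL, dtypes.int, (UOp.const(dtypes.int, 3), UOp.const(dtypes.int, 1)))
out = graph_rewrite(expr, fold_mul_one)
assert out.op is Ops.CONST and out.arg == 3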
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 0972f2488b..2b51316083 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -112,7 +112,7 @@ FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_
 SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
 PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)
 CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
-DONT_REALIZE_EXPAND = ContextVar("DONT_REALIZE_EXPAND", 0)
+DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES = ContextVar("DONT_REALIZE_EXPAND", 0), ContextVar("DONT_GROUP_REDUCES", 0)

 @dataclass(frozen=True)
 class Metadata:
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index f2576f2598..2df87e2281 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -292,7 +292,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
       return ShapeTracker.from_shape(
         tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape)))
     if self.op in {Ops.BUFFER, Ops.BUFFER_VIEW}: return ShapeTracker.from_shape((self.size,))
-    if self.op is Ops.KERNEL: return ShapeTracker.from_shape(self.arg.ast.shape)
+    if self.op is Ops.KERNEL: return ShapeTracker.from_shape((self.arg.ast.size,))
     # these ops define a ShapeTracker from the arg
     if self.op is Ops.VIEW: return self.arg
     if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py
index 3e8e8f1e5d..8bf0ba08ed 100644
--- a/tinygrad/runtime/ops_amd.py
+++ b/tinygrad/runtime/ops_amd.py
@@ -584,7 +584,8 @@ class AMDDevice(HCQCompiled):
     sgrp_size_per_cu, lds_size_per_cu, hwreg_size_per_cu = 0x4000, 0x10000, 0x1000
     vgpr_size_per_cu = 0x60000 if self.target in {110000, 110001, 120000, 120001} else 0x40000
     wg_data_size = round_up((vgpr_size_per_cu + sgrp_size_per_cu + lds_size_per_cu + hwreg_size_per_cu) * (self.max_cu_id + 1), mmap.PAGESIZE)
-    ctl_stack_size = round_up(12 * (self.max_cu_id + 1) * (self.max_wave_id + 1) + 8 + 40, mmap.PAGESIZE) if self.target//10000 != 10 else 0x7000
+    ctl_stack_size = round_up(12 * (self.max_cu_id + 1) * (self.max_wave_id + 1) + 8 + 40, mmap.PAGESIZE)
+    if self.target//10000 == 10: ctl_stack_size = min(ctl_stack_size, 0x7000)
     debug_memory_size = round_up((self.max_cu_id + 1) * (self.max_wave_id + 1) * 32, 64)

     self.compute_queue = self.create_queue(kfd.KFD_IOC_QUEUE_TYPE_COMPUTE, 0x800000, ctx_save_restore_size=wg_data_size + ctl_stack_size,
diff --git a/tinygrad/spec.py b/tinygrad/spec.py
index d790728c9f..c94677abcc 100644
--- a/tinygrad/spec.py
+++ b/tinygrad/spec.py
@@ -1,7 +1,7 @@
 from typing import cast
 from tinygrad.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops
 from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType
-from tinygrad.helpers import all_same, all_int, dedup, prod
+from tinygrad.helpers import all_same, dedup, prod

 buffer_spec = PatternMatcher([
   (UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
@@ -9,7 +9,7 @@ buffer_spec = PatternMatcher([
   (UPat(Ops.BUFFER, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="buf"),
    lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
   (UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),), name="buf_view"),
-   lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all_int(buf_view.arg)),
+   lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all(isinstance(arg, (int, UOp)) for arg in buf_view.arg)),
 ])

 # *** this is the spec of a Tensor in UOp ***
@@ -126,10 +126,10 @@ spec = PatternMatcher([
 # *** this is the spec of a Kernel in UOp ***

 kernel_spec = buffer_spec+PatternMatcher([
-  (UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.ASSIGN))), lambda: True),
-  # assign has a buffer view and kernel source, it can optionally depend on other assigns
-  (UPat(Ops.ASSIGN, src=UPat((Ops.BUFFER, Ops.VIEW, Ops.KERNEL, Ops.ASSIGN))), lambda: True),
-  (UPat(GroupOp.All-{Ops.SINK, Ops.VIEW}), lambda: False),
+  (UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN))), lambda: True),
+  # assign has a buffer and kernel source, it can optionally depend on other assigns
+  (UPat(Ops.ASSIGN, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.KERNEL, Ops.ASSIGN))), lambda: True),
+  (UPat(GroupOp.All-{Ops.SINK}), lambda: False),
 ])

 # *** this is the UOp shape spec ***
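The scheduler flags this patch adds or touches (DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES) are plain ContextVars from tinygrad/helpers.py: they take their default from the environment at import time and can be overridden locally with Context, which is how the new tests flip them per block. A short sketch of that behavior with a hypothetical flag (the name below is illustrative, not from the patch):

from tinygrad.helpers import ContextVar, Context

# hypothetical flag, declared the same way DONT_GROUP_REDUCES is declared above:
# the default comes from the MY_SCHED_FLAG environment variable, else 0
MY_SCHED_FLAG = ContextVar("MY_SCHED_FLAG", 0)

print(bool(MY_SCHED_FLAG))       # False unless MY_SCHED_FLAG=1 is exported
with Context(MY_SCHED_FLAG=1):   # Context overrides the value inside the block
  print(bool(MY_SCHED_FLAG))     # True
print(bool(MY_SCHED_FLAG))       # restored to the default on exit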