From 529832d223b0c63acf3f6d253a76180cdacce211 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Tue, 13 Aug 2024 02:15:11 +0800
Subject: [PATCH] refactor ast creation [compare_schedule] (#6050)

* refactor scheduler lazyop creation [compare_schedule]

* helpful prints

* this will become the default
---
 tinygrad/engine/schedule.py | 36 ++++++++++++++++--------------------
 1 file changed, 16 insertions(+), 20 deletions(-)

diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index f68a0e092d..40fd24b3e2 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -57,20 +57,18 @@ def _recursive_lazyop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer,
   if (buf, st) in cache: return cache[(buf, st)]
   arg = buf.arg
 
-  # consts are always fused and generated
-  if buf.op is MetaOps.CONST:
-    unbound_st, st_var_vals = st.simplify().unbind()
-    var_vals.update(st_var_vals)
-    if isinstance(arg, Variable):
-      arg, var_val = arg.unbind()
-      var_vals[arg] = var_val
-    else: assert isinstance(arg, get_args(ConstType)), f"cannot create ConstBuffer with value {arg}"
-    return LazyOp(BufferOps.CONST, (), ConstBuffer(arg, buf.dtype, unbound_st))
-
-  # if we aren't fusing it, it's a load and we add it to the inputs
+  # buffer ops define ShapeTracker
   if buf.realized is not None or (buf in realizes and buf not in outputs):
     unbound_st, st_var_vals = st.simplify().unbind()
     var_vals.update(st_var_vals)
+    # if it's a const, we generate it
+    if buf.op is MetaOps.CONST:
+      if isinstance(arg, Variable):
+        arg, var_val = arg.unbind()
+        var_vals[arg] = var_val
+      else: assert isinstance(arg, get_args(ConstType)), f"cannot create ConstBuffer with value {arg}"
+      return LazyOp(BufferOps.CONST, (), ConstBuffer(arg, buf.dtype, unbound_st))
+    # otherwise, it's a load and we add it to the inputs
     if buf in assign_targets:
       # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
       if unbound_st.contiguous or (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and\
@@ -80,12 +78,7 @@ def _recursive_lazyop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer,
                          +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
     return LazyOp(BufferOps.LOAD, (), MemBuffer(len(outputs)+inputs.setdefault(buf, len(inputs)), buf.dtype, unbound_st))
 
-  # if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
-  if buf.op in {MetaOps.CONTIGUOUS, MetaOps.ASSIGN}:
-    assert buf in outputs
-    return _recursive_lazyop(buf.srcs[0], st, outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache)
-
-  # if it's a reduce, we have to change the shapetracker
+  # reduce ops change ShapeTracker
   if buf.op in ReduceOps:
     # if we are merging the reduce, skip it
     if (buf, st) not in reduce_info:
@@ -93,9 +86,12 @@ def _recursive_lazyop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer,
       return _recursive_lazyop(buf.srcs[0], st, outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache)
     st, arg = reduce_info[(buf, st)]
 
-  # otherwise we fuse it like normal
-  return cache.setdefault((buf, st), LazyOp(cast(Op,buf.op), tuple(_recursive_lazyop(x, st, outputs, var_vals, inputs, realizes, assign_targets, \
-    reduce_info, cache) for x in buf.srcs), arg))
+  # elementwise ops pass shapetracker
+  in_ops = tuple(_recursive_lazyop(x, st, outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache) for x in buf.srcs)
+  if buf.op in {MetaOps.CONTIGUOUS, MetaOps.ASSIGN}:
+    assert buf in outputs, f"{buf.op} must be writable"
+    return in_ops[0]
+  return cache.setdefault((buf, st), LazyOp(cast(Op, buf.op), in_ops, arg))
 
 def _permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
   permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis) + axis
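
For context, a minimal self-contained sketch of the dispatch order the refactored _recursive_lazyop settles on, as named by its new comments: buffer ops define the ShapeTracker and end the recursion, reduce ops change it for their sources, and elementwise ops (with CONTIGUOUS/ASSIGN simply forwarding their input) pass it through unchanged. This is illustrative only; every name below (Node, to_ast, the string "ops") is hypothetical and not tinygrad's API.

# toy model of the scheduler's AST-building recursion (hypothetical names)
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass(frozen=True)
class Node:
  op: str                         # "BUFFER" | "CONST" | "SUM" | "ADD" | "CONTIGUOUS"
  srcs: Tuple["Node", ...] = ()
  arg: Optional[object] = None

def to_ast(node: Node, st: str) -> tuple:
  # buffer ops define the shapetracker: the recursion bottoms out in a LOAD/CONST leaf
  if node.op in {"BUFFER", "CONST"}: return (node.op, st, node.arg)
  # reduce ops change the shapetracker their sources are rendered with
  if node.op == "SUM": return ("SUM", to_ast(node.srcs[0], f"{st}:pre_reduce(axis={node.arg})"), node.arg)
  # elementwise ops pass the shapetracker through unchanged
  in_ops = tuple(to_ast(s, st) for s in node.srcs)
  if node.op == "CONTIGUOUS": return in_ops[0]   # nothing to emit here, just forward the source
  return (node.op, *in_ops)

if __name__ == "__main__":
  a, b = Node("BUFFER", arg=0), Node("BUFFER", arg=1)
  print(to_ast(Node("SUM", (Node("ADD", (a, b)),), arg=1), "out_st"))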