diff --git a/tinygrad/opt/kernel.py b/tinygrad/opt/kernel.py
index 98d1fcdcea..6e525400d5 100644
--- a/tinygrad/opt/kernel.py
+++ b/tinygrad/opt/kernel.py
@@ -14,7 +14,7 @@ from tinygrad.dtype import ImageDType, AddrSpace
 from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, argfix, DEBUG, TC_SELECT, TC_OPT, AMX
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import strides_for_shape, get_contraction
-from tinygrad.opt.swizzler import view_left
+from tinygrad.opt.swizzler import view_left, view_left_through_load

 class OptOps(Enum):
   TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
@@ -503,4 +503,4 @@ class Kernel:
     self.finalized = True
     fixed_ast = fixup_ast(self.ast)
     del fixup_ast
-    return graph_rewrite(fixed_ast, view_left, name="fixup optimized AST")
+    return graph_rewrite(fixed_ast, view_left+view_left_through_load, name="fixup optimized AST")
diff --git a/tinygrad/opt/swizzler.py b/tinygrad/opt/swizzler.py
index a200bf9d60..44e50549f4 100644
--- a/tinygrad/opt/swizzler.py
+++ b/tinygrad/opt/swizzler.py
@@ -44,12 +44,18 @@ def reduce_push_add_ones(src:UOp, r:UOp, view:UOp):

 view_left = merge_views+PatternMatcher([
   # view before elementwise and buffer ops
-  (UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.BIND, Ops.LOAD, Ops.STORE, Ops.VALID, Ops.SINK}, name="e"),), name="view"),
+  (UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.BIND, Ops.STORE, Ops.VALID, Ops.SINK}, name="e"),), name="view"),
    lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src))),
   # if there's ones added after reduce, put this before the reduce
   (UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), reduce_push_add_ones),
 ])

+view_left_through_load = PatternMatcher([
+  # view before load
+  (UPat(Ops.VIEW, src=(UPat(Ops.LOAD, name="e"),), name="view"),
+   lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src))),
+])
+
 def apply_swizzle(u:UOp) -> UOp: return graph_rewrite(u, view_left, name="Sub View Left")

 # change reduceop axes and input ShapeTrackers, view gets replaced with a reshape.
@@ -96,7 +102,7 @@ view_right = merge_views+PatternMatcher([
   # apply view after reduceops
   (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="src"),), name="v"),), name="r"), reduceop_view_right),
   # apply view after elementwise ops
-  (UPat(GroupOp.All-{Ops.SINK, Ops.REDUCE_AXIS, Ops.LOAD, Ops.STORE}, name="root"), elementwise_view_right),
+  (UPat(GroupOp.All-{Ops.SINK, Ops.REDUCE_AXIS, Ops.STORE}, name="root"), elementwise_view_right),
   # merge axes for double reduce (invert of SPLIT_REDUCEOP=1)
   (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
    lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] is r2.arg[0] else None),
@@ -112,9 +118,7 @@ def check_load_st(glbl:UOp, view:UOp):
   raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
                      +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))

-fix_kernel_ops = PatternMatcher([
-  # add the LOAD
-  (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: x.replace(tag=None).view(x.st).load() if x.tag is not None else None),
+fix_kernel_ops = view_left_through_load+PatternMatcher([
   # STORE (except for meta ops)
   (UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"),
    lambda sink: UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(s.st.real_size()), (), i).view(s.st), s) for i,x in enumerate(sink.src)])),
diff --git a/tinygrad/schedule/grouper.py b/tinygrad/schedule/grouper.py
index 66063c0246..ea1c733e1c 100644
--- a/tinygrad/schedule/grouper.py
+++ b/tinygrad/schedule/grouper.py
@@ -3,7 +3,7 @@ from tinygrad.helpers import all_int, prod, unwrap, dedup, DONT_REALIZE_EXPAND,
 from tinygrad.shape.shapetracker import ShapeTracker

 ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
-                     Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL}
+                     Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL, Ops.LOAD}

 # **** Grouper decides which of the UOps realize

diff --git a/tinygrad/schedule/kernelize.py b/tinygrad/schedule/kernelize.py
index 8c8e10f0df..8ad6db77dc 100644
--- a/tinygrad/schedule/kernelize.py
+++ b/tinygrad/schedule/kernelize.py
@@ -151,7 +151,7 @@ create_kernels = PatternMatcher([

 early_buffer_ops = PatternMatcher([
   # LOAD
-  (UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x), tag=1)),
+  (UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)).view(x.st).load()),
   # no SINK for meta ops
   (UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x),
 ])
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index da63e24ea0..add5ca012f 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -154,8 +154,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
       sz = cast(PtrDType, self.dtype).size
       return ShapeTracker.from_shape((sz,)) if sz > 0 else None
-    # hack for PTX, CASTing the ptr loses the shape. even worse hack with tag
-    if self.op is Ops.CAST and self.src[0].op is Ops.DEFINE_GLOBAL and self.src[0].tag is None: return None
+    # hack for PTX, CASTing the ptr loses the shape
+    if self.op is Ops.CAST and self.src[0].op is Ops.DEFINE_GLOBAL: return None
     # otherwise we get the shape from sources
     if not (src_sts := [x.st for x in self.src if x.st is not None]): return None
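
Review note: the patch splits the "push a VIEW through a LOAD" rule out of view_left into its own view_left_through_load matcher, creates the LOAD at kernelize time (DEFINE_GLOBAL.view(st).load() in early_buffer_ops instead of the tagged-DEFINE_GLOBAL hack), and adds the new matcher only where the push is still wanted (kernel fixup and fix_kernel_ops). The sketch below is a toy model of that one rewrite, not tinygrad's real UOp/PatternMatcher/graph_rewrite machinery; Node, push_view_through_load, and rewrite are invented names for illustration only.

# Toy model (NOT tinygrad's actual API) of the view_left_through_load rewrite:
# a VIEW sitting on top of a LOAD is pushed through it, so the view lands on
# the load's source instead.
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Node:
  op: str                            # e.g. "BUFFER", "LOAD", "VIEW"
  src: tuple["Node", ...] = ()
  st: tuple[int, ...] | None = None  # stand-in for a ShapeTracker

def view(node: Node, st) -> Node: return Node("VIEW", (node,), st)

def push_view_through_load(n: Node) -> Node | None:
  # matches VIEW(LOAD(x)) and rewrites it to LOAD(VIEW(x)), mirroring
  # lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src))
  if n.op == "VIEW" and len(n.src) == 1 and n.src[0].op == "LOAD":
    load = n.src[0]
    return replace(load, src=tuple(view(s, n.st) for s in load.src))
  return None

def rewrite(n: Node) -> Node:
  # one bottom-up pass; tinygrad's graph_rewrite is fixpoint-based and more general
  n = replace(n, src=tuple(rewrite(s) for s in n.src))
  return push_view_through_load(n) or n

if __name__ == "__main__":
  buf = Node("BUFFER", st=(16,))
  g = view(Node("LOAD", (buf,)), (4, 4))               # VIEW on top of a LOAD
  out = rewrite(g)
  assert out.op == "LOAD" and out.src[0].op == "VIEW"  # the view moved below the load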