From 9df8e34160af88e26bea0e87f371afc873abe14f Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Thu, 30 Jan 2025 06:30:24 -0500 Subject: [PATCH] prereqs for giving BUFFER UOps a ShapeTracker [pr] (#8809) --- tinygrad/engine/schedule.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index c51b3606e3..9b3dbc5654 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -142,8 +142,7 @@ def elementwise_view_right(root:UOp) -> UOp|None: output_swizzle = swizzles[0] new_input_st = ShapeTracker.from_shape(output_swizzle.base.shape) ret = root.replace(src=tuple(x if x.st is None else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src)) - # NOTE: swizzle resolves once we hit STORE - return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(output_swizzle.shape)) + return ret.view(ShapeTracker.from_shape(output_swizzle.shape)) def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp: assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu" @@ -155,6 +154,8 @@ view_right = merge_views+PatternMatcher([ # STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val)) (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))), lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))), + # STORE is the last child, so we just merge the ShapeTrackers and store the base + (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)), # REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view() (UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)), # REDUCE(src.view()) -> REDUCE(src).view() @@ -303,7 +304,7 @@ def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]: if (p_uop:=ctx.allbufs.get(p:=parents.pop())) is None: continue if (p_uop:=uval(p_uop)).op is Ops.ASSIGN and p not in group: forced_realize, can_chase = True, False if p in ctx.realizes: continue - parents.extend([x.base.src[0] for x in p_uop.src if x.base.op is Ops.VIEW and len(x.base.src) != 0]) + parents.extend([x.base.buf_uop for x in p_uop.src if x.base.is_realized or (x.base.op is Ops.VIEW and len(x.base.src) != 0)]) if forced_realize or not group: tr = r if can_chase: