prereqs for giving BUFFER UOps a ShapeTracker [pr] (#8809)

This commit is contained in:
qazal
2025-01-30 06:30:24 -05:00
committed by GitHub
parent 78c0455c7a
commit 9df8e34160

View File

@@ -142,8 +142,7 @@ def elementwise_view_right(root:UOp) -> UOp|None:
output_swizzle = swizzles[0]
new_input_st = ShapeTracker.from_shape(output_swizzle.base.shape)
ret = root.replace(src=tuple(x if x.st is None else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src))
# NOTE: swizzle resolves once we hit STORE
return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(output_swizzle.shape))
return ret.view(ShapeTracker.from_shape(output_swizzle.shape))
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
@@ -155,6 +154,8 @@ view_right = merge_views+PatternMatcher([
# STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val))
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))),
lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))),
# STORE is the last child, so we just merge the ShapeTrackers and store the base
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)),
# REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view()
(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)),
# REDUCE(src.view()) -> REDUCE(src).view()
@@ -303,7 +304,7 @@ def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]:
if (p_uop:=ctx.allbufs.get(p:=parents.pop())) is None: continue
if (p_uop:=uval(p_uop)).op is Ops.ASSIGN and p not in group: forced_realize, can_chase = True, False
if p in ctx.realizes: continue
parents.extend([x.base.src[0] for x in p_uop.src if x.base.op is Ops.VIEW and len(x.base.src) != 0])
parents.extend([x.base.buf_uop for x in p_uop.src if x.base.is_realized or (x.base.op is Ops.VIEW and len(x.base.src) != 0)])
if forced_realize or not group:
tr = r
if can_chase: