mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 06:18:01 -05:00
prereqs for giving BUFFER UOps a ShapeTracker [pr] (#8809)
This commit is contained in:
@@ -142,8 +142,7 @@ def elementwise_view_right(root:UOp) -> UOp|None:
|
||||
output_swizzle = swizzles[0]
|
||||
new_input_st = ShapeTracker.from_shape(output_swizzle.base.shape)
|
||||
ret = root.replace(src=tuple(x if x.st is None else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src))
|
||||
# NOTE: swizzle resolves once we hit STORE
|
||||
return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(output_swizzle.shape))
|
||||
return ret.view(ShapeTracker.from_shape(output_swizzle.shape))
|
||||
|
||||
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
|
||||
assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
|
||||
@@ -155,6 +154,8 @@ view_right = merge_views+PatternMatcher([
|
||||
# STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val))
|
||||
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))),
|
||||
lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))),
|
||||
# STORE is the last child, so we just merge the ShapeTrackers and store the base
|
||||
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)),
|
||||
# REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view()
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)),
|
||||
# REDUCE(src.view()) -> REDUCE(src).view()
|
||||
@@ -303,7 +304,7 @@ def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]:
|
||||
if (p_uop:=ctx.allbufs.get(p:=parents.pop())) is None: continue
|
||||
if (p_uop:=uval(p_uop)).op is Ops.ASSIGN and p not in group: forced_realize, can_chase = True, False
|
||||
if p in ctx.realizes: continue
|
||||
parents.extend([x.base.src[0] for x in p_uop.src if x.base.op is Ops.VIEW and len(x.base.src) != 0])
|
||||
parents.extend([x.base.buf_uop for x in p_uop.src if x.base.is_realized or (x.base.op is Ops.VIEW and len(x.base.src) != 0)])
|
||||
if forced_realize or not group:
|
||||
tr = r
|
||||
if can_chase:
|
||||
|
||||
Reference in New Issue
Block a user