mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
cleanup some scheduler rewrites [run_process_replay] (#6474)
This commit is contained in:
@@ -108,27 +108,22 @@ def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTr
|
||||
tmp = input_st.permute(permute_axis)
|
||||
return tmp, tmp.shape[-len(axis):]
|
||||
|
||||
def swizzle_reduceop(input_st:ShapeTracker, swizzle:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[int, ...]]:
  """Push a swizzle (movement op) past a reduce by rewriting the reduce's input ShapeTracker.

  `input_st` is first permuted with `permute_reduce` for `axis`, which also yields the trailing
  reduce shape. Every view of `swizzle` is then widened by that reduce shape: its strides and offset
  are scaled by the number of reduced elements, the reduce dims get their own strides, and an
  existing mask is extended with a full `(0, s)` range per reduce dim.

  Returns the combined ShapeTracker together with the new reduce axes, which are the last
  `len(reduce shape)` dimensions of that tracker.
  """
  # push the movementop to the buffer uop
  permuted_st, reduce_shape = permute_reduce(input_st, axis)
  reduce_size, reduce_strides = prod(reduce_shape), strides_for_shape(reduce_shape)
  full_reduce_mask = tuple((0,s) for s in reduce_shape)
  widened_views = tuple(
    View.create(v.shape+reduce_shape, tuple(st*reduce_size for st in v.strides)+reduce_strides, v.offset*reduce_size,
                v.mask+full_reduce_mask if v.mask is not None else None)
    for v in swizzle.views)
  # update input_st and axis
  new_input_st = permuted_st + ShapeTracker(widened_views)
  _, new_reduce_shape = permute_reduce(new_input_st, axis)
  ndim = len(new_input_st.shape)
  return new_input_st, tuple(range(ndim-len(new_reduce_shape), ndim))
|
||||
|
||||
# ***** reduceop fusor *****
|
||||
|
||||
def push_swizzle_up_through_reduce(swizzle:UOp, reduceop:UOp) -> Optional[UOp]:
  """Rewrite SWIZZLE(REDUCE_AXIS(src)) so that the swizzle is applied to the reduce's source instead.

  Returns None (no rewrite) when the swizzle's ShapeTracker is contiguous. Otherwise returns a new
  SWIZZLE uop with a plain `from_shape` tracker of the swizzle's shape, wrapping a REDUCE_AXIS whose
  source has had the pushed ShapeTracker appended through `st_fixup` and whose axes are the ones
  reported by `swizzle_reduceop`.

  NOTE(review): reduceop.arg is used as a pair here -- arg[0] is forwarded unchanged and arg[1] is
  passed as the reduce axes; confirm against the REDUCE_AXIS definition.
  """
  if swizzle.arg.contiguous: return None
  rsrc = reduceop.src[0]
  # the ShapeTracker/axis computation lives in swizzle_reduceop; do not duplicate it inline
  new_input_st, new_axis = swizzle_reduceop(ShapeTracker.from_shape(unwrap(rsrc.st).shape), swizzle.arg, reduceop.arg[1])
  return UOp(UOps.SWIZZLE, reduceop.dtype, (UOp(UOps.REDUCE_AXIS, reduceop.dtype, (st_fixup(rsrc, lambda st:st+new_input_st, {}),),
                                                (reduceop.arg[0], new_axis)),), ShapeTracker.from_shape(swizzle.arg.shape))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user