add uop.swizzle(st) (#6476)

This commit is contained in:
qazal
2024-09-11 16:52:42 +08:00
committed by GitHub
parent 78148e16d8
commit 5cc142c8b8
2 changed files with 6 additions and 7 deletions

View File

@@ -79,8 +79,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
# reduce ops change ShapeTracker
if buf.op in ReduceOps:
rsrc = _recursive_uop(buf.srcs[0], ShapeTracker.from_shape(buf.srcs[0].shape), outputs, var_vals, inputs, realizes, assign_targets, cache)
ret = UOp(UOps.REDUCE_AXIS, dtype, (rsrc,), (REDUCE_ALU[cast(ReduceOps, buf.op)], buf.arg))
return cache.setdefault((buf, st), UOp(UOps.SWIZZLE, dtype, (ret,), st))
return cache.setdefault((buf, st), UOp(UOps.REDUCE_AXIS, dtype, (rsrc,), (REDUCE_ALU[cast(ReduceOps, buf.op)], buf.arg)).swizzle(st))
# elementwise ops pass shapetracker
in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, realizes, assign_targets, cache) for x in buf.srcs)
@@ -124,14 +123,13 @@ def push_swizzle_up_through_reduce(swizzle:UOp, reduceop:UOp) -> Optional[UOp]:
new_input_st = tmp + ShapeTracker(tuple(nv))
_, new_rshape = permute_reduce(new_input_st, reduceop.arg[1])
new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape)))
return UOp(UOps.SWIZZLE, reduceop.dtype, (UOp(UOps.REDUCE_AXIS, reduceop.dtype, (st_fixup(rsrc, lambda st:st+new_input_st, {}),),
(reduceop.arg[0], new_axis)),), ShapeTracker.from_shape(swizzle.arg.shape))
return UOp(UOps.REDUCE_AXIS, reduceop.dtype, (st_fixup(rsrc, lambda st:st+new_input_st, {}),),
(reduceop.arg[0], new_axis)).swizzle(ShapeTracker.from_shape(swizzle.arg.shape))
# Rewrites a reduce whose source is a SWIZZLE so that the SWIZZLE ends up below the reduce:
# a REDUCE_AXIS is rebuilt directly over the swizzle's own sources (swizzle.src) with root's dtype
# and arg, and the result is wrapped in a SWIZZLE whose ShapeTracker is built from the shape of
# swizzle.st after reducing over root.arg[1].
# root: the reduce UOp being rewritten (presumably a REDUCE_AXIS whose arg is (alu, axis) -- confirm at the call site).
# swizzle: the SWIZZLE UOp feeding root; its arg is used as a ShapeTracker (.contiguous, .shape).
# Returns the new SWIZZLE-over-REDUCE_AXIS UOp.
def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp) -> UOp:
# only a contiguous swizzle may be moved below the reduce
assert swizzle.arg.contiguous, "can't push a non contiguous SWIZZLE down to STORE"
# the swizzle must keep the element count of its source unchanged (i.e. it is not an expand)
assert prod(swizzle.arg.shape) == prod(unwrap(swizzle.src[0].st).shape), "can't push expands down to STORE"
# NOTE(review): this listing is a diff rendered without +/- markers. The two-line return below builds the
# SWIZZLE with an explicit UOp(UOps.SWIZZLE, ...) call; the single-line return after it is the same
# expression written with the .swizzle(st) helper. Presumably the former is the removed form and the
# latter the added form -- confirm against the commit.
return UOp(UOps.SWIZZLE, root.dtype, (UOp(UOps.REDUCE_AXIS, root.dtype, swizzle.src, root.arg),),
ShapeTracker.from_shape(unwrap(swizzle.st).reduce(root.arg[1])))
# helper-based spelling: REDUCE_AXIS over swizzle.src, then .swizzle(...) with the reduced shape
return UOp(UOps.REDUCE_AXIS, root.dtype, swizzle.src, root.arg).swizzle(ShapeTracker.from_shape(unwrap(swizzle.st).reduce(root.arg[1])))
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
@@ -147,7 +145,7 @@ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
fixup_cache: Dict[UOp, UOp] = {}
new_srcs = [x.src[0] if x.op is UOps.SWIZZLE else st_fixup(x, lambda st:st.reshape(sw_input_shape), fixup_cache) for x in root.src]
ret = UOp(root.op, root.dtype, tuple(new_srcs), root.arg)
return ret if ret.op is UOps.STORE else UOp(UOps.SWIZZLE, None, (ret,), ShapeTracker.from_shape(sw_shape))
return ret if ret.op is UOps.STORE else ret.swizzle(ShapeTracker.from_shape(sw_shape))
reduceop_fusor = PatternMatcher([
# push a SWIZZLE up to LOAD, through a reduce (eg. expands)

View File

@@ -376,6 +376,7 @@ class UOp(MathTrait):
assert ret.op is UOps.SHAPETRACKER, f"st_arg trying to return {ret}"
return ret.arg
# Build a SINK UOp with dtype None whose sources are this UOp followed by any extra `srcs`.
def sink(self, *srcs): return UOp(UOps.SINK, None, (self,)+srcs)
# Wrap this UOp in a SWIZZLE UOp: same dtype as self, self as the only source, and the ShapeTracker `st` as the arg.
def swizzle(self, st:ShapeTracker): return UOp(UOps.SWIZZLE, self.dtype, (self,), st)
# Build a CAST node of the given `dtype` with self as the only source; constructed via type(self), so the result has the same class as self.
def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,))
# Build a BITCAST node of the given `dtype` with self as the only source; constructed via type(self), so the result has the same class as self.
def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,))
# Build a GEP node with self as the only source and `i` as the arg; its dtype is self.dtype.scalar(), or None when self.dtype is None.
def gep(self, i:int): return type(self)(UOps.GEP, self.dtype.scalar() if self.dtype is not None else None, (self,), i)