add UOps.SWIZZLE (#6271)

* add UOps.SWIZZLE

* flip swizzle init

* generic st_fixup
This commit is contained in:
qazal
2024-08-26 16:08:51 +08:00
committed by GitHub
parent 002f60b4c3
commit 1c0456af89
2 changed files with 9 additions and 2 deletions

View File

@@ -83,6 +83,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
# reduce ops change ShapeTracker
if buf.op in ReduceOps:
swizzle = (UOp(UOps.SWIZZLE, src=(st.to_uop(),)),) if not st.contiguous and AST_REWRITE else ()
rinfo: Optional[Tuple[ShapeTracker, Tuple[int, ...]]] = (ShapeTracker.from_shape(buf.srcs[0].shape), buf.arg) \
if AST_REWRITE else reduce_info.get((buf, st))
rsrc = _recursive_uop(buf.srcs[0], st:=(rinfo[0] if rinfo else st), outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache)
@@ -91,7 +92,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
if rinfo is None:
assert rsrc.op is UOps.REDUCE_AXIS and rsrc.arg[0] is alu_op, f"can't merge reduceop {buf.op} with {rsrc}\n{st}"
return rsrc
return cache.setdefault((buf, st), UOp(UOps.REDUCE_AXIS, dtype, (rsrc,), (alu_op, rinfo[1])))
return cache.setdefault((buf, st), UOp(UOps.REDUCE_AXIS, dtype, (rsrc,)+swizzle, (alu_op, rinfo[1])))
# elementwise ops pass shapetracker
in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache) for x in buf.srcs)
@@ -172,6 +173,11 @@ def swizzle_reduceop(input_st:ShapeTracker, swizzle:ShapeTracker, axis:Tuple[int
# ***** reduceop fusor *****
def apply_swizzle(root:UOp, rsrc:UOp, swizzle:UOp) -> UOp:
uop_sts: Dict[UOp, ShapeTracker] = {}
new_input_st, new_axis = swizzle_reduceop(get_output_st(rsrc, uop_sts), swizzle.arg, root.arg[1])
return replace(root, src=(st_fixup(rsrc, lambda _:new_input_st, uop_sts, {}),), arg=(root.arg[0], new_axis))
def push_reduceop_shape(root:UOp) -> Optional[UOp]:
reduceops = [x for x in root.parents if x.op is UOps.REDUCE_AXIS]
if len(reduceops) == 0: return None
@@ -181,6 +187,7 @@ def push_reduceop_shape(root:UOp) -> Optional[UOp]:
return st_fixup(root, lambda st:st.reshape(rshape), uop_sts, {})
reduceop_fusor = PatternMatcher([
(UPat(UOps.REDUCE_AXIS, src=(UPat(name="rsrc"), UPat(UOps.SWIZZLE, src=(UPat(name="swizzle"),))), name="root"), apply_swizzle),
(UPat(UOps.STORE, name="root"), push_reduceop_shape),
])

View File

@@ -41,7 +41,7 @@ def identity_element(op:BinaryOps, dt:DType): return dtypes.as_const({BinaryOps.
# the order of these UOps controls the order of the toposort
class UOps(Enum):
# ops that aren't rendered
SINK = auto(); EXT = auto(); EXPAND = auto(); CONTRACT = auto(); SHAPETRACKER = auto() # noqa: E702
SINK = auto(); EXT = auto(); EXPAND = auto(); CONTRACT = auto(); SHAPETRACKER = auto(); SWIZZLE = auto() # noqa: E702
DEFINE_GLOBAL = auto(); DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702
CONST = auto(); SPECIAL = auto() # noqa: E702
NOOP = auto(); GEP = auto() # noqa: E702