mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
add UOps.SWIZZLE (#6271)
* add UOps.SWIZZLE
* flip swizzle init
* generic st_fixup
This commit is contained in:
@@ -83,6 +83,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
|
||||
|
||||
# reduce ops change ShapeTracker
|
||||
if buf.op in ReduceOps:
|
||||
swizzle = (UOp(UOps.SWIZZLE, src=(st.to_uop(),)),) if not st.contiguous and AST_REWRITE else ()
|
||||
rinfo: Optional[Tuple[ShapeTracker, Tuple[int, ...]]] = (ShapeTracker.from_shape(buf.srcs[0].shape), buf.arg) \
|
||||
if AST_REWRITE else reduce_info.get((buf, st))
|
||||
rsrc = _recursive_uop(buf.srcs[0], st:=(rinfo[0] if rinfo else st), outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache)
|
||||
@@ -91,7 +92,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
|
||||
if rinfo is None:
|
||||
assert rsrc.op is UOps.REDUCE_AXIS and rsrc.arg[0] is alu_op, f"can't merge reduceop {buf.op} with {rsrc}\n{st}"
|
||||
return rsrc
|
||||
return cache.setdefault((buf, st), UOp(UOps.REDUCE_AXIS, dtype, (rsrc,), (alu_op, rinfo[1])))
|
||||
return cache.setdefault((buf, st), UOp(UOps.REDUCE_AXIS, dtype, (rsrc,)+swizzle, (alu_op, rinfo[1])))
|
||||
|
||||
# elementwise ops pass shapetracker
|
||||
in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache) for x in buf.srcs)
|
||||
@@ -172,6 +173,11 @@ def swizzle_reduceop(input_st:ShapeTracker, swizzle:ShapeTracker, axis:Tuple[int
|
||||
|
||||
# ***** reduceop fusor *****
|
||||
|
||||
def apply_swizzle(root:UOp, rsrc:UOp, swizzle:UOp) -> UOp:
|
||||
uop_sts: Dict[UOp, ShapeTracker] = {}
|
||||
new_input_st, new_axis = swizzle_reduceop(get_output_st(rsrc, uop_sts), swizzle.arg, root.arg[1])
|
||||
return replace(root, src=(st_fixup(rsrc, lambda _:new_input_st, uop_sts, {}),), arg=(root.arg[0], new_axis))
|
||||
|
||||
def push_reduceop_shape(root:UOp) -> Optional[UOp]:
|
||||
reduceops = [x for x in root.parents if x.op is UOps.REDUCE_AXIS]
|
||||
if len(reduceops) == 0: return None
|
||||
@@ -181,6 +187,7 @@ def push_reduceop_shape(root:UOp) -> Optional[UOp]:
|
||||
return st_fixup(root, lambda st:st.reshape(rshape), uop_sts, {})
|
||||
|
||||
reduceop_fusor = PatternMatcher([
|
||||
(UPat(UOps.REDUCE_AXIS, src=(UPat(name="rsrc"), UPat(UOps.SWIZZLE, src=(UPat(name="swizzle"),))), name="root"), apply_swizzle),
|
||||
(UPat(UOps.STORE, name="root"), push_reduceop_shape),
|
||||
])
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ def identity_element(op:BinaryOps, dt:DType): return dtypes.as_const({BinaryOps.
|
||||
# the order of these UOps controls the order of the toposort
|
||||
class UOps(Enum):
|
||||
# ops that aren't rendered
|
||||
SINK = auto(); EXT = auto(); EXPAND = auto(); CONTRACT = auto(); SHAPETRACKER = auto() # noqa: E702
|
||||
SINK = auto(); EXT = auto(); EXPAND = auto(); CONTRACT = auto(); SHAPETRACKER = auto(); SWIZZLE = auto() # noqa: E702
|
||||
DEFINE_GLOBAL = auto(); DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702
|
||||
CONST = auto(); SPECIAL = auto() # noqa: E702
|
||||
NOOP = auto(); GEP = auto() # noqa: E702
|
||||
|
||||
Reference in New Issue
Block a user