From c9d763d33150629cc6d2b65fe35068c6bc3f72e8 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Mon, 30 Sep 2024 09:37:31 +0800
Subject: [PATCH] refactor to axis_arg [run_process_replay] (#6806)

* refactor to axis_arg [run_process_replay]

* remove more arg[1]s
---
 tinygrad/codegen/kernel.py  |  4 ++--
 tinygrad/codegen/lowerer.py |  2 +-
 tinygrad/engine/schedule.py | 12 +++++-------
 tinygrad/ops.py             |  8 +++++++-
 4 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 7925a96077..952a425d6f 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -402,7 +402,7 @@ class Kernel:
       check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem")
       check(self.first_reduce + self.group_for_reduces <= axis < self.first_upcast, "must be reduce axis to group")
       check(not self.tensor_core, "can't group with tensor cores")
-      check(len(reduce_axes:=[i for r in self.reduceops for i in r.arg[1]]) == len(set(reduce_axes)), "can't group with parallel reduces")
+      check(len(reduce_axes:=[i for r in self.reduceops for i in r.axis_arg]) == len(set(reduce_axes)), "can't group with parallel reduces")
       self.shift_to(axis, amt, top=(opt.op is OptOps.GROUPTOP), insert_before=self.first_reduce + self.group_for_reduces)
       self.group_for_reduces += 1
     elif opt.op is OptOps.UNROLL:  # purple
@@ -749,7 +749,7 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) ->
     return
   for x in uop.src: _assert_valid_uop(x, st, sts)
   # only reduceuop is allowed to change shape, limited to turning n to 1
-  if uop.op in {UOps.REDUCE_AXIS, UOps.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.arg[-1]))
+  if uop.op in {UOps.REDUCE_AXIS, UOps.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.axis_arg))
   # movementops are pushed to SHAPETRACKER and SWIZZLE
   elif uop.op in {UOps.SHAPETRACKER, UOps.SWIZZLE}: st = uop.arg
   # everything else inherits shape
diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py
index 5d3c52ee9c..174cd20b0f 100644
--- a/tinygrad/codegen/lowerer.py
+++ b/tinygrad/codegen/lowerer.py
@@ -91,7 +91,7 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
 
 def lower_reduce_axis(ctx: IndexContext, x: UOp):
   # NOTE: always using ridxs is fine here
-  reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.arg[1]], lambda y: y.op is UOps.RANGE)
+  reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is UOps.RANGE)
   alu_op: BinaryOps = x.arg[0]
   ret = x.src[0]
   if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 3677a56e0c..7da288fb81 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -69,7 +69,7 @@ def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTr
 def push_swizzle_up_through_reduce(swizzle:UOp, reduceop:UOp) -> Optional[UOp]:
   if (swizzle_st:=unwrap(swizzle.st)).contiguous: return None
   rsrc = reduceop.src[0]
-  tmp, rshape = permute_reduce(ShapeTracker.from_shape(unwrap(rsrc.st).shape), reduceop.arg[1])
+  tmp, rshape = permute_reduce(ShapeTracker.from_shape(unwrap(rsrc.st).shape), reduceop.axis_arg)
   prshape = prod(rshape)
   strides = strides_for_shape(rshape)
   nv: List[View] = []
@@ -78,7 +78,7 @@ def push_swizzle_up_through_reduce(swizzle:UOp, reduceop:UOp) -> Optional[UOp]:
                                                    v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None))
   # update input_st and axis
   new_input_st = tmp + ShapeTracker(tuple(nv))
-  _, new_rshape = permute_reduce(new_input_st, reduceop.arg[1])
+  _, new_rshape = permute_reduce(new_input_st, reduceop.axis_arg)
   new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape)))
   return UOp(UOps.REDUCE_AXIS, reduceop.dtype, (st_fixup(rsrc, lambda st:st+new_input_st, {}),),
              (reduceop.arg[0], new_axis)).swizzle(ShapeTracker.from_shape(swizzle_st.shape))
@@ -87,10 +87,9 @@ def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp) -> UOp:
   swizzle_st, src_st = unwrap(swizzle.st), unwrap(swizzle.src[0].st)
   assert swizzle_st.contiguous, "can't push a non contiguous SWIZZLE down to STORE"
   assert prod(swizzle_st.shape) == prod(src_st.shape), "can't push expands down to STORE"
-  op, axis = root.arg
-  output_shape = swizzle_st.reduce(axis)
+  output_shape = swizzle_st.reduce(root.axis_arg)
   new_axis = tuple(i for i,(s,u) in enumerate(zip(src_st.shape, output_shape)) if s != u)
-  return UOp(UOps.REDUCE_AXIS, root.dtype, swizzle.src, (op, new_axis)).swizzle(ShapeTracker.from_shape(output_shape))
+  return UOp(UOps.REDUCE_AXIS, root.dtype, swizzle.src, (root.arg[0], new_axis)).swizzle(ShapeTracker.from_shape(output_shape))
 
 def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
   swizzles = [x for x in root.src if x.op is UOps.SWIZZLE]
@@ -106,8 +105,7 @@ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
 def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
   assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
   assert not any(x.op is UOps.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
-  new_axis: Tuple[int, ...] = root.arg[1]+first_reduce.arg[1]
-  return UOp(UOps.REDUCE_AXIS, first_reduce.dtype, first_reduce.src, (first_reduce.arg[0], new_axis))
+  return UOp(UOps.REDUCE_AXIS, first_reduce.dtype, first_reduce.src, (first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))
 
 reduceop_fusor = PatternMatcher([
   # push a SWIZZLE up to LOAD, through a reduce (eg. expands)
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 6cf968e839..7fae98f09c 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -158,7 +158,7 @@ class UOp(MathTrait):
     src_sts = [x.st for x in self.src if x.st is not None]
     assert all_same([x.shape for x in src_sts]), f"UOp parents must have the same shape {self} {[x.shape for x in src_sts]}"
     from tinygrad.shape.shapetracker import ShapeTracker
-    return ShapeTracker.from_shape(src_sts[0].reduce(self.arg[1])) if self.op is UOps.REDUCE_AXIS else src_sts[0]
+    return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg)) if self.op is UOps.REDUCE_AXIS else src_sts[0]
   @functools.cached_property
   def cmp_tuple(self) -> Tuple[int, Any, Optional[DType], Tuple[UOp, ...]]:
     # NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
@@ -179,6 +179,12 @@ class UOp(MathTrait):
     ret = self.src[0 if self.op is UOps.VALID else 1]
     assert ret.op is UOps.SHAPETRACKER, f"st_arg trying to return {ret}"
     return ret.arg
+  @property
+  def axis_arg(self) -> Tuple[int, ...]:
+    assert self.op in {UOps.REDUCE_AXIS, UOps.WMMA}, f"axis_arg called on {self.op}"
+    ret = self.arg[1] if self.op is UOps.REDUCE_AXIS else self.arg[7]
+    assert isinstance(ret, tuple) and all(isinstance(x, int) for x in ret), f"axis_arg trying to return {ret}"
+    return ret
   def sink(self, *srcs:UOp): return UOp(UOps.SINK, dtypes.void, (self,)+srcs)
   def swizzle(self, st:ShapeTracker): return UOp(UOps.SWIZZLE, self.dtype, (self,), st)
   def const_like(self, b:ConstType|Variable|Tuple[ConstType, ...]): return UOp.const(self.dtype, b)
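
A minimal sketch of what the new accessor centralizes, assuming the UOp/UOps definitions from tinygrad.ops at this revision; the helper name below is hypothetical and not part of the patch:

from tinygrad.ops import UOp, UOps

def reduce_axes(u: UOp) -> tuple:
  # hypothetical helper: before this patch, call sites unpacked the raw arg tuple directly,
  #   u.arg[1]  for REDUCE_AXIS (arg is (alu_op, axes))
  #   u.arg[7]  for WMMA (the axes tuple sits at index 7 of its arg)
  # after it, every site goes through the checked property, which asserts the op kind
  # and that the result is a tuple of ints
  return u.axis_arg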