mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-26 23:38:58 -05:00
refactor to axis_arg [run_process_replay] (#6806)
* refactor to axis_arg [run_process_replay]

* remove more arg[1]s
@@ -402,7 +402,7 @@ class Kernel:
       check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem")
       check(self.first_reduce + self.group_for_reduces <= axis < self.first_upcast, "must be reduce axis to group")
       check(not self.tensor_core, "can't group with tensor cores")
-      check(len(reduce_axes:=[i for r in self.reduceops for i in r.arg[1]]) == len(set(reduce_axes)), "can't group with parallel reduces")
+      check(len(reduce_axes:=[i for r in self.reduceops for i in r.axis_arg]) == len(set(reduce_axes)), "can't group with parallel reduces")
       self.shift_to(axis, amt, top=(opt.op is OptOps.GROUPTOP), insert_before=self.first_reduce + self.group_for_reduces)
       self.group_for_reduces += 1
     elif opt.op is OptOps.UNROLL: # purple
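The grouping check forbids two reduceops that share an axis: the len-vs-set comparison catches any duplicate in the flattened axis lists. A minimal sketch of that test, with hypothetical axis tuples standing in for [r.axis_arg for r in self.reduceops]:

# hypothetical axis tuples; axis 2 appears in both reduces
reduce_axes = [i for axes in [(0, 2), (2, 3)] for i in axes]
# the duplicate makes the lengths differ, so grouping is rejected
assert len(reduce_axes) != len(set(reduce_axes))
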
@@ -749,7 +749,7 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) ->
     return
   for x in uop.src: _assert_valid_uop(x, st, sts)
   # only reduceuop is allowed to change shape, limited to turning n to 1
-  if uop.op in {UOps.REDUCE_AXIS, UOps.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.arg[-1]))
+  if uop.op in {UOps.REDUCE_AXIS, UOps.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.axis_arg))
   # movementops are pushed to SHAPETRACKER and SWIZZLE
   elif uop.op in {UOps.SHAPETRACKER, UOps.SWIZZLE}: st = uop.arg
   # everything else inherits shape
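The comment above pins down the only legal shape change: a reduce turns each reduced dimension from n to 1 while keeping rank. A standalone sketch of that rule, assuming ShapeTracker.reduce collapses exactly the named axes:

from typing import Tuple

# assumed semantics of ShapeTracker.reduce: keep rank, collapse reduced axes to 1
def reduce_shape(shape: Tuple[int, ...], axis: Tuple[int, ...]) -> Tuple[int, ...]:
  return tuple(1 if i in axis else s for i, s in enumerate(shape))

assert reduce_shape((4, 5, 6), (1,)) == (4, 1, 6)
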
@@ -91,7 +91,7 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
 
 def lower_reduce_axis(ctx: IndexContext, x: UOp):
   # NOTE: always using ridxs is fine here
-  reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.arg[1]], lambda y: y.op is UOps.RANGE)
+  reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is UOps.RANGE)
   alu_op: BinaryOps = x.arg[0]
   ret = x.src[0]
   if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
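lower_reduce_axis splits the reduce indices into real RANGE loops and expand placeholders via partition. A self-contained sketch of a split-by-predicate helper matching the call shape above (the exact signature of tinygrad's helper is an assumption here):

from typing import Callable, Iterable, List, Tuple, TypeVar

T = TypeVar("T")
def partition(itr: Iterable[T], fxn: Callable[[T], bool]) -> Tuple[List[T], List[T]]:
  # first list gets the elements matching the predicate, second gets the rest
  a: List[T] = []
  b: List[T] = []
  for x in itr: (a if fxn(x) else b).append(x)
  return a, b

evens, odds = partition(range(6), lambda n: n % 2 == 0)
assert (evens, odds) == ([0, 2, 4], [1, 3, 5])
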
@@ -69,7 +69,7 @@ def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTr
 def push_swizzle_up_through_reduce(swizzle:UOp, reduceop:UOp) -> Optional[UOp]:
   if (swizzle_st:=unwrap(swizzle.st)).contiguous: return None
   rsrc = reduceop.src[0]
-  tmp, rshape = permute_reduce(ShapeTracker.from_shape(unwrap(rsrc.st).shape), reduceop.arg[1])
+  tmp, rshape = permute_reduce(ShapeTracker.from_shape(unwrap(rsrc.st).shape), reduceop.axis_arg)
   prshape = prod(rshape)
   strides = strides_for_shape(rshape)
   nv: List[View] = []
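Pushing the SWIZZLE above the reduce needs fresh strides for the reduced block. A sketch of plain row-major strides as strides_for_shape is assumed to compute them here; tinygrad's real helper may additionally canonicalize (e.g. zero-stride for size-1 dims):

import math
from typing import Tuple

def strides_for_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:
  # row-major: each stride is the product of all dimensions to its right
  return tuple(math.prod(shape[i+1:]) for i in range(len(shape)))

assert strides_for_shape((2, 3, 4)) == (12, 4, 1)
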
@@ -78,7 +78,7 @@ def push_swizzle_up_through_reduce(swizzle:UOp, reduceop:UOp) -> Optional[UOp]:
                   v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None))
   # update input_st and axis
   new_input_st = tmp + ShapeTracker(tuple(nv))
-  _, new_rshape = permute_reduce(new_input_st, reduceop.arg[1])
+  _, new_rshape = permute_reduce(new_input_st, reduceop.axis_arg)
   new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape)))
   return UOp(UOps.REDUCE_AXIS, reduceop.dtype, (st_fixup(rsrc, lambda st:st+new_input_st, {}),),
              (reduceop.arg[0], new_axis)).swizzle(ShapeTracker.from_shape(swizzle_st.shape))
@@ -87,10 +87,9 @@ def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp) -> UOp:
   swizzle_st, src_st = unwrap(swizzle.st), unwrap(swizzle.src[0].st)
   assert swizzle_st.contiguous, "can't push a non contiguous SWIZZLE down to STORE"
   assert prod(swizzle_st.shape) == prod(src_st.shape), "can't push expands down to STORE"
-  op, axis = root.arg
-  output_shape = swizzle_st.reduce(axis)
+  output_shape = swizzle_st.reduce(root.axis_arg)
   new_axis = tuple(i for i,(s,u) in enumerate(zip(src_st.shape, output_shape)) if s != u)
-  return UOp(UOps.REDUCE_AXIS, root.dtype, swizzle.src, (op, new_axis)).swizzle(ShapeTracker.from_shape(output_shape))
+  return UOp(UOps.REDUCE_AXIS, root.dtype, swizzle.src, (root.arg[0], new_axis)).swizzle(ShapeTracker.from_shape(output_shape))
 
 def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
   swizzles = [x for x in root.src if x.op is UOps.SWIZZLE]
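new_axis is recovered purely by shape comparison: any dimension where the source shape and the reduced output shape disagree must have been reduced. A tiny sketch of that line with hypothetical shapes:

# hypothetical shapes standing in for src_st.shape and swizzle_st.reduce(root.axis_arg)
src_shape, output_shape = (4, 5, 6), (4, 1, 6)
new_axis = tuple(i for i, (s, u) in enumerate(zip(src_shape, output_shape)) if s != u)
assert new_axis == (1,)
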
@@ -106,8 +105,7 @@ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
 def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
   assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
   assert not any(x.op is UOps.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
-  new_axis: Tuple[int, ...] = root.arg[1]+first_reduce.arg[1]
-  return UOp(UOps.REDUCE_AXIS, first_reduce.dtype, first_reduce.src, (first_reduce.arg[0], new_axis))
+  return UOp(UOps.REDUCE_AXIS, first_reduce.dtype, first_reduce.src, (first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))
 
 reduceop_fusor = PatternMatcher([
   # push a SWIZZLE up to LOAD, through a reduce (eg. expands)
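merge_double_reduce leans on the n-to-1 rule: because a reduce keeps rank, the outer reduce's axes are still valid indices into the inner reduce's input, so the two axis tuples can simply be concatenated. A numpy sketch of that identity under the keepdims assumption:

import numpy as np

x = np.arange(24).reshape(2, 3, 4)
inner = x.sum(axis=(2,), keepdims=True)      # first_reduce over axis 2
outer = inner.sum(axis=(0,), keepdims=True)  # root over axis 0
merged = x.sum(axis=(0, 2), keepdims=True)   # one reduce over the concatenated axes
assert np.array_equal(outer, merged)
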
@@ -158,7 +158,7 @@ class UOp(MathTrait):
     src_sts = [x.st for x in self.src if x.st is not None]
     assert all_same([x.shape for x in src_sts]), f"UOp parents must have the same shape {self} {[x.shape for x in src_sts]}"
     from tinygrad.shape.shapetracker import ShapeTracker
-    return ShapeTracker.from_shape(src_sts[0].reduce(self.arg[1])) if self.op is UOps.REDUCE_AXIS else src_sts[0]
+    return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg)) if self.op is UOps.REDUCE_AXIS else src_sts[0]
   @functools.cached_property
   def cmp_tuple(self) -> Tuple[int, Any, Optional[DType], Tuple[UOp, ...]]:
     # NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
@@ -179,6 +179,12 @@ class UOp(MathTrait):
     ret = self.src[0 if self.op is UOps.VALID else 1]
     assert ret.op is UOps.SHAPETRACKER, f"st_arg trying to return {ret}"
     return ret.arg
+  @property
+  def axis_arg(self) -> Tuple[int, ...]:
+    assert self.op in {UOps.REDUCE_AXIS, UOps.WMMA}, f"axis_arg called on {self.op}"
+    ret = self.arg[1] if self.op is UOps.REDUCE_AXIS else self.arg[7]
+    assert isinstance(ret, tuple) and all(isinstance(x, int) for x in ret), f"axis_arg trying to return {ret}"
+    return ret
   def sink(self, *srcs:UOp): return UOp(UOps.SINK, dtypes.void, (self,)+srcs)
   def swizzle(self, st:ShapeTracker): return UOp(UOps.SWIZZLE, self.dtype, (self,), st)
   def const_like(self, b:ConstType|Variable|Tuple[ConstType, ...]): return UOp.const(self.dtype, b)
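The new property centralizes the axis lookup that the call sites above spelled as arg[1] (and arg[-1]): REDUCE_AXIS packs its axes at arg[1], WMMA at arg[7], and the assertions catch any other op or malformed arg. A standalone sketch of the accessor on a stand-in class (not tinygrad's real UOp):

from typing import Any, Tuple

class FakeUOp:  # stand-in for UOp, just enough to show the accessor pattern
  def __init__(self, op: str, arg: Any):
    self.op, self.arg = op, arg
  @property
  def axis_arg(self) -> Tuple[int, ...]:
    assert self.op in {"REDUCE_AXIS", "WMMA"}, f"axis_arg called on {self.op}"
    # REDUCE_AXIS args are (alu_op, axis); WMMA keeps its axes in slot 7
    ret = self.arg[1] if self.op == "REDUCE_AXIS" else self.arg[7]
    assert isinstance(ret, tuple) and all(isinstance(x, int) for x in ret), f"axis_arg trying to return {ret}"
    return ret

assert FakeUOp("REDUCE_AXIS", ("ADD", (0, 2))).axis_arg == (0, 2)
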