mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add uop.swizzle(st) (#6476)
This commit is contained in:
@@ -79,8 +79,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
|
||||
# reduce ops change ShapeTracker
|
||||
if buf.op in ReduceOps:
|
||||
rsrc = _recursive_uop(buf.srcs[0], ShapeTracker.from_shape(buf.srcs[0].shape), outputs, var_vals, inputs, realizes, assign_targets, cache)
|
||||
ret = UOp(UOps.REDUCE_AXIS, dtype, (rsrc,), (REDUCE_ALU[cast(ReduceOps, buf.op)], buf.arg))
|
||||
return cache.setdefault((buf, st), UOp(UOps.SWIZZLE, dtype, (ret,), st))
|
||||
return cache.setdefault((buf, st), UOp(UOps.REDUCE_AXIS, dtype, (rsrc,), (REDUCE_ALU[cast(ReduceOps, buf.op)], buf.arg)).swizzle(st))
|
||||
|
||||
# elementwise ops pass shapetracker
|
||||
in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, realizes, assign_targets, cache) for x in buf.srcs)
|
||||
@@ -124,14 +123,13 @@ def push_swizzle_up_through_reduce(swizzle:UOp, reduceop:UOp) -> Optional[UOp]:
|
||||
new_input_st = tmp + ShapeTracker(tuple(nv))
|
||||
_, new_rshape = permute_reduce(new_input_st, reduceop.arg[1])
|
||||
new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape)))
|
||||
return UOp(UOps.SWIZZLE, reduceop.dtype, (UOp(UOps.REDUCE_AXIS, reduceop.dtype, (st_fixup(rsrc, lambda st:st+new_input_st, {}),),
|
||||
(reduceop.arg[0], new_axis)),), ShapeTracker.from_shape(swizzle.arg.shape))
|
||||
return UOp(UOps.REDUCE_AXIS, reduceop.dtype, (st_fixup(rsrc, lambda st:st+new_input_st, {}),),
|
||||
(reduceop.arg[0], new_axis)).swizzle(ShapeTracker.from_shape(swizzle.arg.shape))
|
||||
|
||||
def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp) -> UOp:
|
||||
assert swizzle.arg.contiguous, "can't push a non contiguous SWIZZLE down to STORE"
|
||||
assert prod(swizzle.arg.shape) == prod(unwrap(swizzle.src[0].st).shape), "can't push expands down to STORE"
|
||||
return UOp(UOps.SWIZZLE, root.dtype, (UOp(UOps.REDUCE_AXIS, root.dtype, swizzle.src, root.arg),),
|
||||
ShapeTracker.from_shape(unwrap(swizzle.st).reduce(root.arg[1])))
|
||||
return UOp(UOps.REDUCE_AXIS, root.dtype, swizzle.src, root.arg).swizzle(ShapeTracker.from_shape(unwrap(swizzle.st).reduce(root.arg[1])))
|
||||
|
||||
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
|
||||
assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
|
||||
@@ -147,7 +145,7 @@ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
|
||||
fixup_cache: Dict[UOp, UOp] = {}
|
||||
new_srcs = [x.src[0] if x.op is UOps.SWIZZLE else st_fixup(x, lambda st:st.reshape(sw_input_shape), fixup_cache) for x in root.src]
|
||||
ret = UOp(root.op, root.dtype, tuple(new_srcs), root.arg)
|
||||
return ret if ret.op is UOps.STORE else UOp(UOps.SWIZZLE, None, (ret,), ShapeTracker.from_shape(sw_shape))
|
||||
return ret if ret.op is UOps.STORE else ret.swizzle(ShapeTracker.from_shape(sw_shape))
|
||||
|
||||
reduceop_fusor = PatternMatcher([
|
||||
# push a SWIZZLE up to LOAD, through a reduce (eg. expands)
|
||||
|
||||
@@ -376,6 +376,7 @@ class UOp(MathTrait):
|
||||
assert ret.op is UOps.SHAPETRACKER, f"st_arg trying to return {ret}"
|
||||
return ret.arg
|
||||
def sink(self, *srcs): return UOp(UOps.SINK, None, (self,)+srcs)
|
||||
def swizzle(self, st:ShapeTracker): return UOp(UOps.SWIZZLE, self.dtype, (self,), st)
|
||||
def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,))
|
||||
def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,))
|
||||
def gep(self, i:int): return type(self)(UOps.GEP, self.dtype.scalar() if self.dtype is not None else None, (self,), i)
|
||||
|
||||
Reference in New Issue
Block a user