mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 13:58:00 -05:00
move permute_reduces to uop movementops [run_process_replay] (#6272)
This commit is contained in:
@@ -102,11 +102,6 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
|
||||
if buf.op is UnaryOps.BITCAST: return cache.setdefault((buf, st), UOp(UOps.BITCAST, dtype, in_uops))
|
||||
return cache.setdefault((buf, st), UOp(UOps.ALU, dtype, in_uops, buf.op))
|
||||
|
||||
def _permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
|
||||
permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis) + axis
|
||||
tmp = input_st.permute(permute_axis)
|
||||
return tmp, tmp.shape[-len(axis):]
|
||||
|
||||
def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer, None], outs:List[LazyBuffer],
|
||||
reduce_info:Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]],
|
||||
cache:Dict[Tuple[LazyBuffer, ShapeTracker], Optional[Tuple[LazyBuffer, ShapeTracker]]]) -> \
|
||||
@@ -153,9 +148,14 @@ def reshape_uop(u:UOp, new_shape:Tuple[sint, ...], uop_sts:Dict[UOp, ShapeTracke
|
||||
cache[u] = reshaped = u if new_srcs == u.src else replace(u, src=new_srcs)
|
||||
return reshaped
|
||||
|
||||
def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
|
||||
permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis)+axis
|
||||
tmp = input_st.permute(permute_axis)
|
||||
return tmp, tmp.shape[-len(axis):]
|
||||
|
||||
def swizzle_reduceop(input_st:ShapeTracker, swizzle:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[int, ...]]:
|
||||
# push the movementop to the buffer uop
|
||||
tmp, rshape = _permute_reduce(input_st, axis)
|
||||
tmp, rshape = permute_reduce(input_st, axis)
|
||||
prshape = prod(rshape)
|
||||
strides = strides_for_shape(rshape)
|
||||
nv: List[View] = []
|
||||
@@ -164,7 +164,7 @@ def swizzle_reduceop(input_st:ShapeTracker, swizzle:ShapeTracker, axis:Tuple[int
|
||||
v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None))
|
||||
# update input_st and axis
|
||||
new_input_st = tmp + ShapeTracker(tuple(nv))
|
||||
_, new_rshape = _permute_reduce(new_input_st, axis)
|
||||
_, new_rshape = permute_reduce(new_input_st, axis)
|
||||
new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape)))
|
||||
return new_input_st, new_axis
|
||||
|
||||
|
||||
Reference in New Issue
Block a user