mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 13:58:00 -05:00
remove Kernel.first_reduce [pr] (#11269)
This commit is contained in:
@@ -16,11 +16,12 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
|
||||
st0, st1 = k.sts[k.bufs.index(mulop.src[0])], k.sts[k.bufs.index(mulop.src[1])]
|
||||
strides0, strides1 = st0.real_strides(), st1.real_strides()
|
||||
def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve(st != 0) for s,st in zip(shape,strides))
|
||||
if strides0[k.first_reduce] == 1 and not (has_expanded_axis(st0.shape, strides0) and has_expanded_axis(st1.shape, strides1)):
|
||||
if strides0[first_reduce:=(k.axes_of(AxisType.REDUCE)[0])] == 1 and \
|
||||
not (has_expanded_axis(st0.shape, strides0) and has_expanded_axis(st1.shape, strides1)):
|
||||
for global_idx in k.axes_of(AxisType.GLOBAL):
|
||||
if k.full_shape[k.first_reduce]%MV_THREADS_PER_ROW == 0 and k.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
|
||||
if k.full_shape[first_reduce]%MV_THREADS_PER_ROW == 0 and k.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
|
||||
if DEBUG >= 3:
|
||||
print(f"MATVEC: {k.full_shape=} {k.first_reduce=} {strides0=} {MV_BLOCKSIZE=} {MV_THREADS_PER_ROW=} {MV_ROWS_PER_THREAD=}")
|
||||
print(f"MATVEC: {k.full_shape=} {first_reduce=} {strides0=} {MV_BLOCKSIZE=} {MV_THREADS_PER_ROW=} {MV_ROWS_PER_THREAD=}")
|
||||
if MV_THREADS_PER_ROW > 1: k.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
|
||||
if MV_BLOCKSIZE > 1: k.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
|
||||
if MV_ROWS_PER_THREAD > 1: k.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
|
||||
@@ -41,7 +42,7 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
|
||||
if (axis:=unit_stride_axes_mul_4[0]) in k.upcastable_dims:
|
||||
k.apply_opt(Opt(OptOps.UPCAST, axis, 4))
|
||||
elif axis in k.unrollable_dims:
|
||||
k.apply_opt(Opt(OptOps.UNROLL, axis-k.first_reduce, 4))
|
||||
k.apply_opt(Opt(OptOps.UNROLL, k.unrollable_dims.index(axis), 4))
|
||||
|
||||
# no more opt if we are grouping
|
||||
if k.group_for_reduces: return k.applied_opts
|
||||
@@ -82,14 +83,14 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
|
||||
upcast_size = prod(k.full_shape[a] for a in k.axes_of(AxisType.UPCAST, AxisType.UNROLL))
|
||||
if k.unrollable_dims and (upcast_size <= 4 or not k.axes_of(AxisType.UNROLL)) and (upcast_size < 64):
|
||||
if (s:=k.full_shape[k.unrollable_dims[-1]]) <= 32:
|
||||
k.apply_opt(Opt(OptOps.UNROLL, k.unrollable_dims[-1]-k.first_reduce, 0))
|
||||
k.apply_opt(Opt(OptOps.UNROLL, len(k.unrollable_dims)-1, 0))
|
||||
# if it's small, upcast a second reduce dimension too
|
||||
if k.unrollable_dims and s <= 3 and k.full_shape[k.unrollable_dims[-1]] <= 3:
|
||||
k.apply_opt(Opt(OptOps.UNROLL, k.unrollable_dims[-1]-k.first_reduce, 0))
|
||||
k.apply_opt(Opt(OptOps.UNROLL, len(k.unrollable_dims)-1, 0))
|
||||
else:
|
||||
for splits in [4]:
|
||||
if k.full_shape[axis:=k.unrollable_dims[-1]]%splits == 0:
|
||||
k.apply_opt(Opt(OptOps.UNROLL, axis-k.first_reduce, splits))
|
||||
k.apply_opt(Opt(OptOps.UNROLL, len(k.unrollable_dims)-1, splits))
|
||||
break
|
||||
|
||||
# if nothing at all is upcasted and it's easy to, do an upcast
|
||||
|
||||
@@ -114,9 +114,6 @@ class Kernel:
|
||||
|
||||
return ret
|
||||
|
||||
@property
|
||||
# Returns element 0 of self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE), annotated as an int axis index.
# The [0] fails if axes_of returns an empty sequence (presumably a kernel with no reduce axes - TODO confirm).
# NOTE(review): this appears to be the property the commit removes (see the commit title); the +/- diff
# markers are not preserved in this rendering, so old and new lines appear side by side.
def first_reduce(self) -> int: return self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE)[0]
|
||||
|
||||
@property
|
||||
# Returns the first entry of self.reduceops, or None when self.reduceops is empty.
def reduceop(self) -> UOp|None: return self.reduceops[0] if len(self.reduceops) > 0 else None
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user