mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
@@ -1,5 +1,5 @@
|
||||
import itertools
|
||||
from tinygrad.opt.kernel import Kernel, Opt, OptOps, KernelOptError
|
||||
from tinygrad.opt.kernel import Kernel, Opt, OptOps, KernelOptError, AxisType
|
||||
from tinygrad.helpers import getenv, DEBUG, all_int, prod, NOLOCALS
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.uop.ops import Ops, resolve
|
||||
@@ -72,8 +72,8 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
|
||||
for axis, upcast_amount in itertools.product(range(k.first_reduce), ([128] if not len(upcasted_axis) else []) if is_dsp else [3,4]):
|
||||
# if we haven't upcasted it, it's not symbolic, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
|
||||
if axis not in upcasted_axis and isinstance(k.full_shape[axis], int) and k.full_shape[axis]%upcast_amount == 0 and \
|
||||
any(st.views[-1].strides[axis] == 0 and not any(x == 0 for x in k.sts[buf_index].real_strides()[k.first_upcast:]) \
|
||||
for buf_index, st in enumerate(k.sts)):
|
||||
any(st.views[-1].strides[axis] == 0 and not any(x == 0 for x in st.real_strides()[k.first_upcast:]) \
|
||||
for st in k.sts):
|
||||
xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in k.sts),
|
||||
sum(st.views[-1].strides[axis] for st in k.sts), axis, upcast_amount))
|
||||
if xb_choices:
|
||||
@@ -83,10 +83,9 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
|
||||
upcasted_axis.add(xb_choices[0][2])
|
||||
else: break
|
||||
|
||||
# if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast.
|
||||
# if last dim is small(ish) and it's a reduce dim, loop unroll the reduce
|
||||
if k.first_reduce < k.first_upcast and \
|
||||
(prod(k.full_shape[k.first_upcast:]) <= 4 or (k.sts[0].shape[k.first_upcast:] == k.full_shape[k.first_upcast:])) and \
|
||||
(k.upcasted == 0 or prod(k.full_shape[-k.upcasted:]) < 64):
|
||||
(prod(k.full_shape[k.first_upcast:]) <= 4 or (AxisType.UNROLL not in k.axis_types)) and (prod(k.full_shape[k.first_upcast:]) < 64):
|
||||
if isinstance(s:=k.full_unupcasted_shape[-1], int) and s <= 32: # NOTE: cannot loop unroll symbolic axis
|
||||
k.apply_opt(Opt(OptOps.UNROLL, len(k.full_unupcasted_shape)-1-k.first_reduce, 0))
|
||||
# if it's small, upcast a second reduce dimension too
|
||||
|
||||
Reference in New Issue
Block a user