mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
remove final range in heuristic [pr] (#11251)
all dims are based on AxisType now
This commit is contained in:
@@ -17,7 +17,7 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
|
||||
strides0, strides1 = st0.real_strides(), st1.real_strides()
|
||||
def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve(st != 0) for s,st in zip(shape,strides))
|
||||
if strides0[k.first_reduce] == 1 and not (has_expanded_axis(st0.shape, strides0) and has_expanded_axis(st1.shape, strides1)):
|
||||
for global_idx in range(k.global_dims):
|
||||
for global_idx in k.axes_of(AxisType.GLOBAL):
|
||||
if k.full_shape[k.first_reduce]%MV_THREADS_PER_ROW == 0 and k.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
|
||||
if DEBUG >= 3:
|
||||
print(f"MATVEC: {k.full_shape=} {k.first_reduce=} {strides0=} {MV_BLOCKSIZE=} {MV_THREADS_PER_ROW=} {MV_ROWS_PER_THREAD=}")
|
||||
@@ -63,7 +63,7 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
|
||||
upcasted_axis: set[int] = set()
|
||||
while resolve(prod(k.sts[0].shape[i] for i in k.upcastable_dims) >= 1024):
|
||||
xb_choices = []
|
||||
# consider all the non reduce axes, and a 3 or 4 reduce. (128 on the DSP)
|
||||
# consider all upcastable axes with 3 or 4 upcast (128 on the DSP)
|
||||
for axis, upcast_amount in itertools.product(k.upcastable_dims, ([128] if not len(upcasted_axis) else []) if is_dsp else [3,4]):
|
||||
# if we haven't upcasted it, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
|
||||
if axis in upcasted_axis or k.full_shape[axis]%upcast_amount != 0: continue
|
||||
@@ -73,7 +73,7 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
|
||||
sum(st.views[-1].strides[axis] for st in k.sts), axis, upcast_amount))
|
||||
if xb_choices:
|
||||
xb_choices = sorted(xb_choices)
|
||||
if DEBUG >= 4: print(f"float4 merging axis : {xb_choices}")
|
||||
if DEBUG >= 4: print(f"more upcast axis : {xb_choices}")
|
||||
k.apply_opt(Opt(OptOps.UPCAST, xb_choices[0][2], xb_choices[0][3]))
|
||||
upcasted_axis.add(xb_choices[0][2])
|
||||
else: break
|
||||
|
||||
Reference in New Issue
Block a user