mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 13:58:00 -05:00
@@ -28,8 +28,7 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
|
||||
|
||||
if k.opts.has_local and k.opts.has_shared and all_int(k.sts[0].shape[:k.first_reduce]):
|
||||
# are we grouping? (requires local shape support)
|
||||
if not [x for x in k.sts[0].unit_stride_axes() if x >= k.first_upcast and k.sts[0].shape[x]%4 == 0] and \
|
||||
k.first_reduce <= 2 and k.first_reduce < k.shape_len and prod(k.sts[0].shape[:k.first_reduce]) <= 2048:
|
||||
if k.first_reduce <= 2 and k.first_reduce < k.shape_len and prod(k.sts[0].shape[:k.first_reduce]) <= 2048:
|
||||
# TODO: use 1024 if it's allowed in a smarter way
|
||||
for sz in ([256, 16] if prod(k.sts[0].shape[:k.first_reduce]) <= 32 else [16]):
|
||||
try: # may fail due to excessive smem usage
|
||||
@@ -40,13 +39,11 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
|
||||
# upcast float4 images
|
||||
for buf_index,buf in enumerate(k.bufs):
|
||||
unit_stride_axes_mul_4 = [i for i in k.sts[buf_index].unit_stride_axes(ignore_valid=True) if k.sts[buf_index].shape[i]%4 == 0]
|
||||
if buf.src[0].dtype.__class__ is ImageDType:
|
||||
#assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {k.bufs[buf_index]}"
|
||||
if len(unit_stride_axes_mul_4) and all(x < k.first_upcast for x in unit_stride_axes_mul_4):
|
||||
if unit_stride_axes_mul_4[0] < k.first_reduce:
|
||||
k.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
|
||||
else:
|
||||
k.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-k.first_reduce, 4))
|
||||
if buf.src[0].dtype.__class__ is ImageDType and len(unit_stride_axes_mul_4) and all(x < k.first_upcast for x in unit_stride_axes_mul_4):
|
||||
if unit_stride_axes_mul_4[0] < k.first_reduce:
|
||||
k.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
|
||||
else:
|
||||
k.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-k.first_reduce, 4))
|
||||
|
||||
# no more opt if we are grouping
|
||||
if k.group_for_reduces: return k.applied_opts
|
||||
|
||||
Reference in New Issue
Block a user