remove UPCASTMID [pr] (#8173)

This commit is contained in:
George Hotz
2024-12-11 17:29:01 -08:00
committed by GitHub
parent f86e0014b7
commit 151ac5f5a2
3 changed files with 4 additions and 19 deletions

View File

@@ -19,7 +19,7 @@ from tinygrad.codegen.uopgraph import full_graph_rewrite
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index, get_contraction
class OptOps(Enum):
TC = auto(); UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
def __lt__(self, x:OptOps): return self.value < x.value
@@ -358,7 +358,7 @@ class Kernel:
return False
def apply_opt(self, opt:Opt, append_opt:bool=True):
if self.dont_use_locals: check(opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals")
if self.dont_use_locals: check(opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP}, "not using locals")
if opt.op is OptOps.TC:
check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
@@ -414,14 +414,6 @@ class Kernel:
check(amt <= 16, "don't upcast more than 16")
self.shift_to(axis, amt, insert_before=None)
self.upcast()
elif opt.op is OptOps.UPCASTMID: # white
check(self.bufs[0].src[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces != 0 and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce") # noqa: E501
axes = self.sts[0].unit_stride_axes()
check(len(axes) == 1, f"wrong number of stride 1 axis : {axes}")
check(axes[0] == axis, "wrong axis")
check(amt == 4, "don't upcast mid anything but 4")
self.shift_to(axis, amt, insert_before=self.first_reduce + self.group_for_reduces)
self.group_for_reduces += 1
elif opt.op is OptOps.NOLOCALS:
check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals")
check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals")
@@ -490,13 +482,6 @@ class Kernel:
break
except KernelOptError: pass
# are we upcasting in mid reduce? (only for images)
if self.bufs[0].src[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: # noqa: E501
axes = self.sts[0].unit_stride_axes()
assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
if self.sts[0].shape[axes[0]]%4 == 0:
self.apply_opt(Opt(OptOps.UPCASTMID, axes[0], 4))
# upcast float4 images
for buf_index,buf in enumerate(self.bufs):
unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]

View File

@@ -18,7 +18,7 @@ actions += [Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,28,29,
actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for axis in range(3)]
if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)]
actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=6, amt=2)]
actions += [Opt(op=OptOps.UPCASTMID, axis=1, amt=4), Opt(op=OptOps.TC, axis=0, amt=0)]
actions += [Opt(op=OptOps.TC, axis=0, amt=0)]
actions += [Opt(op=OptOps.TC, axis=axis, amt=getenv("TC_OPT", 2)) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce)
actions += [Opt(op=OptOps.SWAP, axis=axis, amt=amt) for axis in range(5) for amt in range(axis+1, 5)]
if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]

View File

@@ -161,7 +161,7 @@ _cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches"
CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(_cache_dir, "tinygrad", "cache.db")))
CACHELEVEL = getenv("CACHELEVEL", 2)
VERSION = 16
VERSION = 17
_db_connection = None
def db_connection():
global _db_connection