diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 08af5d6cca..78c47dfacf 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -19,7 +19,7 @@ from tinygrad.codegen.uopgraph import full_graph_rewrite from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index, get_contraction class OptOps(Enum): - TC = auto(); UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702 + TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702 GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702 def __lt__(self, x:OptOps): return self.value < x.value @@ -358,7 +358,7 @@ class Kernel: return False def apply_opt(self, opt:Opt, append_opt:bool=True): - if self.dont_use_locals: check(opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals") + if self.dont_use_locals: check(opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP}, "not using locals") if opt.op is OptOps.TC: check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine @@ -414,14 +414,6 @@ class Kernel: check(amt <= 16, "don't upcast more than 16") self.shift_to(axis, amt, insert_before=None) self.upcast() - elif opt.op is OptOps.UPCASTMID: # white - check(self.bufs[0].src[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces != 0 and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce") # noqa: E501 - axes = self.sts[0].unit_stride_axes() - check(len(axes) == 1, f"wrong number of stride 1 axis : {axes}") - check(axes[0] == axis, "wrong axis") - check(amt == 4, "don't upcast mid anything but 4") - self.shift_to(axis, amt, insert_before=self.first_reduce + self.group_for_reduces) - self.group_for_reduces += 1 elif opt.op is OptOps.NOLOCALS: check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals") check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals") @@ -490,13 +482,6 @@ class Kernel: break except KernelOptError: pass - # are we upcasting in mid reduce? (only for images) - if self.bufs[0].src[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: # noqa: E501 - axes = self.sts[0].unit_stride_axes() - assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}" - if self.sts[0].shape[axes[0]]%4 == 0: - self.apply_opt(Opt(OptOps.UPCASTMID, axes[0], 4)) - # upcast float4 images for buf_index,buf in enumerate(self.bufs): unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0] diff --git a/tinygrad/engine/search.py b/tinygrad/engine/search.py index e5e6e7241d..42c488d813 100644 --- a/tinygrad/engine/search.py +++ b/tinygrad/engine/search.py @@ -18,7 +18,7 @@ actions += [Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,28,29, actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for axis in range(3)] if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)] actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=6, amt=2)] -actions += [Opt(op=OptOps.UPCASTMID, axis=1, amt=4), Opt(op=OptOps.TC, axis=0, amt=0)] +actions += [Opt(op=OptOps.TC, axis=0, amt=0)] actions += [Opt(op=OptOps.TC, axis=axis, amt=getenv("TC_OPT", 2)) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce) actions += [Opt(op=OptOps.SWAP, axis=axis, amt=amt) for axis in range(5) for amt in range(axis+1, 5)] if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)] diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 12639799c5..a766f610bf 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -161,7 +161,7 @@ _cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(_cache_dir, "tinygrad", "cache.db"))) CACHELEVEL = getenv("CACHELEVEL", 2) -VERSION = 16 +VERSION = 17 _db_connection = None def db_connection(): global _db_connection