mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 13:58:00 -05:00
remove UPCASTMID [pr] (#8173)
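In short: this PR deletes the UPCASTMID optimization end to end. The hunks below remove the OptOps enum member, the UPCASTMID branch in Kernel.apply_opt, the image mid-reduce heuristic that auto-applied it in the hand-coded optimizations, and its entry in the BEAM search action table; finally the cache VERSION is bumped from 16 to 17, presumably to invalidate entries cached under the old enum numbering.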
@@ -19,7 +19,7 @@ from tinygrad.codegen.uopgraph import full_graph_rewrite
 from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index, get_contraction

 class OptOps(Enum):
-  TC = auto(); UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
+  TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
   GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
   def __lt__(self, x:OptOps): return self.value < x.value

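Note that the members are numbered with auto(), so deleting UPCASTMID shifts the values of UNROLL and everything after it down by one; that renumbering is a plausible reason for the cache VERSION bump in the last hunk. A quick sanity check of the post-commit enum, assuming OptOps still lives in tinygrad.codegen.kernel as the surrounding imports suggest:

from tinygrad.codegen.kernel import OptOps

# UPCASTMID is gone from the enum after this commit
assert not hasattr(OptOps, "UPCASTMID")
# members after UPCAST now sit one auto() value lower than before
print([(o.name, o.value) for o in OptOps])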
@@ -358,7 +358,7 @@ class Kernel:
     return False

   def apply_opt(self, opt:Opt, append_opt:bool=True):
-    if self.dont_use_locals: check(opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals")
+    if self.dont_use_locals: check(opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP}, "not using locals")

     if opt.op is OptOps.TC:
       check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
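The only change here is dropping UPCASTMID from the dont_use_locals guard; it belonged in that set because, as the removed branch below shows, applying it bumped group_for_reduces, i.e. it implied local and grouped-reduce dimensions.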
@@ -414,14 +414,6 @@ class Kernel:
       check(amt <= 16, "don't upcast more than 16")
       self.shift_to(axis, amt, insert_before=None)
       self.upcast()
-    elif opt.op is OptOps.UPCASTMID: # white
-      check(self.bufs[0].src[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces != 0 and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce") # noqa: E501
-      axes = self.sts[0].unit_stride_axes()
-      check(len(axes) == 1, f"wrong number of stride 1 axis : {axes}")
-      check(axes[0] == axis, "wrong axis")
-      check(amt == 4, "don't upcast mid anything but 4")
-      self.shift_to(axis, amt, insert_before=self.first_reduce + self.group_for_reduces)
-      self.group_for_reduces += 1
     elif opt.op is OptOps.NOLOCALS:
       check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals")
       check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals")
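For the record, the removed branch upcast the kernel's single unit-stride output axis by 4 and inserted it just before first_reduce + group_for_reduces, counting it as one more grouped reduce. Its eligibility test can be restated standalone; the parameter names below are hypothetical stand-ins for the Kernel state the check read:

from math import prod

def could_upcast_mid(dtype_name: str, float4_axis_0: list, group_for_reduces: int,
                     first_reduce: int, out_shape: tuple) -> bool:
  # mirrors the removed check: image dtype, no float4 axis, an active grouped
  # reduce, few global dims, and a non-degenerate output shape
  return (dtype_name.startswith("image") and not float4_axis_0
          and group_for_reduces != 0 and first_reduce <= 2 and prod(out_shape) > 1)

print(could_upcast_mid("imagef", [], 1, 2, (64, 64, 4)))  # True: would have qualified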
@@ -490,13 +482,6 @@ class Kernel:
           break
         except KernelOptError: pass

-    # are we upcasting in mid reduce? (only for images)
-    if self.bufs[0].src[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: # noqa: E501
-      axes = self.sts[0].unit_stride_axes()
-      assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
-      if self.sts[0].shape[axes[0]]%4 == 0:
-        self.apply_opt(Opt(OptOps.UPCASTMID, axes[0], 4))
-
     # upcast float4 images
     for buf_index,buf in enumerate(self.bufs):
       unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
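This heuristic was the only place UPCASTMID was applied automatically; after this hunk the sole remaining reference is the BEAM action removed next. Its final gate was just a divisibility check on the unit-stride axis, e.g. with made-up values:

shape, axes = (64, 64, 12), [2]   # hypothetical output shape and its unit-stride axes
assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
if shape[axes[0]] % 4 == 0:
  print(f"would have applied Opt(OptOps.UPCASTMID, {axes[0]}, 4)")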
@@ -18,7 +18,7 @@ actions += [Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,28,29,
 actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for axis in range(3)]
 if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)]
 actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=6, amt=2)]
-actions += [Opt(op=OptOps.UPCASTMID, axis=1, amt=4), Opt(op=OptOps.TC, axis=0, amt=0)]
+actions += [Opt(op=OptOps.TC, axis=0, amt=0)]
 actions += [Opt(op=OptOps.TC, axis=axis, amt=getenv("TC_OPT", 2)) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce)
 actions += [Opt(op=OptOps.SWAP, axis=axis, amt=amt) for axis in range(5) for amt in range(axis+1, 5)]
 if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
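For context, this table is the BEAM search action space; removing the UPCASTMID entry means the searcher can no longer rediscover the op either. A hedged sketch of how such a table is consumed, mirroring the try/except KernelOptError pattern in the hunk above (the copy() method and import paths are assumptions, not verified against this revision):

from tinygrad.codegen.kernel import Kernel, Opt, KernelOptError

def valid_children(k: Kernel, actions: list) -> list:
  out = []
  for act in actions:
    try:
      k2 = k.copy()       # assumed: search works on copies of the kernel
      k2.apply_opt(act)   # invalid actions raise KernelOptError
      out.append(k2)
    except KernelOptError: pass
  return out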
@@ -161,7 +161,7 @@ _cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches"
 CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(_cache_dir, "tinygrad", "cache.db")))
 CACHELEVEL = getenv("CACHELEVEL", 2)

-VERSION = 16
+VERSION = 17
 _db_connection = None
 def db_connection():
   global _db_connection
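VERSION keys the on-disk cache at CACHEDB, so bumping it from 16 to 17 orphans everything cached before this commit, which is consistent with the enum renumbering noted above. Illustrative only: a minimal version-scoped cache in the same spirit; the table naming and schema here are assumptions, not tinygrad's actual layout:

import sqlite3

VERSION = 17
conn = sqlite3.connect(":memory:")
table = f"cache_{VERSION}"  # bumping VERSION points reads/writes at a fresh table
conn.execute(f"CREATE TABLE IF NOT EXISTS {table} (key TEXT PRIMARY KEY, val BLOB)")
conn.execute(f"INSERT OR REPLACE INTO {table} VALUES (?, ?)", ("kernel_0", b"..."))
print(conn.execute(f"SELECT val FROM {table} WHERE key=?", ("kernel_0",)).fetchone())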