mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
explore global uop cache [pr] (#6863)
* explore global uop cache
* wvd uops
* remove useless lru caches
* key is is
* simpler rewriter
This commit is contained in:
@@ -99,7 +99,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
c1 = UOp(UOps.ALU, dtypes.float, (y1, y1), BinaryOps.ADD)
|
||||
c2 = UOp(UOps.ALU, dtypes.float, (y1, y2), BinaryOps.ADD)
|
||||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
self.assertEqual(matcher.rewrite(c2), None)
|
||||
self.assertEqual(matcher.rewrite(c2), c1)
|
||||
|
||||
def test_dtype(self):
|
||||
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.float32), lambda x: x)])
|
||||
|
||||
@@ -186,7 +186,7 @@ def is_increasing(f:UOp):
|
||||
|
||||
def replace_uop(uop:UOp, old:UOp, new:UOp):
|
||||
# replace all `old` in `uop` to `new`
|
||||
return new if uop.key == old.key else uop.replace(src=tuple(replace_uop(s, old, new) for s in uop.src))
|
||||
return new if uop is old else uop.replace(src=tuple(replace_uop(s, old, new) for s in uop.src))
|
||||
|
||||
def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
|
||||
# if it's X <= c, returns X, True, c
|
||||
@@ -229,8 +229,8 @@ def idx_given_valid(valid:UOp, idx:UOp) -> Optional[UOp]:
|
||||
newidxs[1].append(newidx.src[1])
|
||||
|
||||
# if every branch in candidate gives the same simplified output, we can rewrite the idx
|
||||
if len(newidxs[0])==1 or (len(newidxs[0]) > 1 and all_same([i.key for i in newidxs[0]])): idx = idx.replace(src=(newidxs[0][0], idx.src[1]))
|
||||
if len(newidxs[1])==1 or (len(newidxs[1]) > 1 and all_same([i.key for i in newidxs[1]])): idx = idx.replace(src=(idx.src[0], newidxs[1][0]))
|
||||
if len(newidxs[0])==1 or (len(newidxs[0]) > 1 and all_same(newidxs[0])): idx = idx.replace(src=(newidxs[0][0], idx.src[1]))
|
||||
if len(newidxs[1])==1 or (len(newidxs[1]) > 1 and all_same(newidxs[1])): idx = idx.replace(src=(idx.src[0], newidxs[1][0]))
|
||||
return idx
|
||||
|
||||
def simplify_valid_image_load(load:UOp, buf:UOp):
|
||||
@@ -262,7 +262,7 @@ def simplify_valid_image_load(load:UOp, buf:UOp):
|
||||
drop_stmt.append(stmt)
|
||||
break
|
||||
|
||||
if not drop_stmt and idx.key == start_idx.key: return None
|
||||
if not drop_stmt and idx == start_idx: return None
|
||||
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s not in drop_stmt]) else None
|
||||
return load.replace(src=((buf, idx, invalid_val, new_valid) if new_valid is not None else (buf, idx)))
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from types import FrameType
|
||||
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle
|
||||
from enum import auto, IntEnum, Enum
|
||||
from dataclasses import dataclass, field
|
||||
from weakref import WeakValueDictionary
|
||||
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
|
||||
from tinygrad.helpers import ContextVar, pretty_print, prod, getenv, all_same
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
@@ -155,7 +156,14 @@ def resolve(x, default:bool=True):
|
||||
except ValueError: return default
|
||||
def smax(lst): return max(lst, key=lambda x: x if isinstance(x, int) else x.max)
|
||||
|
||||
ucache:WeakValueDictionary[Tuple, UOp] = WeakValueDictionary()
|
||||
class UOp(MathTrait):
|
||||
def __reduce__(self): return UOp, (self.op, self.dtype, self.src, self.arg)
|
||||
def __new__(cls, op:UOps, dtype:DType=dtypes.void, src:Tuple[UOp,...]=tuple(), arg:Any=None):
|
||||
if (ret:=ucache.get(key:=(op, dtype, src, arg), None)) is not None: return ret
|
||||
ucache[key] = ret = super().__new__(cls)
|
||||
return ret
|
||||
|
||||
__slots__ = ["op", "dtype", "src", "arg"]
|
||||
def __init__(self, op: UOps, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
|
||||
# TODO: instant check rules here make debugging easier
|
||||
@@ -238,7 +246,6 @@ class UOp(MathTrait):
|
||||
out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
|
||||
return UOp(UOps.ALU, out_dtype, (self,)+src, arg)
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def const(dtype:DType, b:Tuple[ConstType, ...]|ConstType|Variable): return UOp._const(dtype, b)
|
||||
@staticmethod
|
||||
def _const(dtype:DType, b:Tuple[ConstType, ...]|ConstType|Variable):
|
||||
@@ -247,7 +254,6 @@ class UOp(MathTrait):
|
||||
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
|
||||
return UOp(UOps.VCONST if isinstance(b, tuple) else UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def define_var(name:str, dtype:DType, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
|
||||
@staticmethod
|
||||
def range(dtype:DType, start:ConstType, end:ConstType, idx:int):
|
||||
@@ -597,16 +603,14 @@ class RewriteContext:
|
||||
def __init__(self, pm, ctx):
|
||||
self.pm: PatternMatcher = pm
|
||||
self.ctx = ctx
|
||||
self.nodes: Dict[Tuple, UOp] = {}
|
||||
self.replace: Dict[UOp, UOp] = {}
|
||||
def rewrite(self, n:UOp) -> UOp:
|
||||
if (rn := self.replace.get(n)) is not None: return rn
|
||||
replace_source = (n.op, n.dtype, new_src:=tuple(map(self.rewrite, n.src)), n.arg)
|
||||
if (found := self.nodes.get(replace_source)) is not None: self.replace[n] = found
|
||||
else:
|
||||
x = UOp(*replace_source) if new_src != n.src else n
|
||||
self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x, self.ctx)) is not None else x
|
||||
return found
|
||||
new_src = tuple(map(self.rewrite, n.src))
|
||||
x = UOp(n.op, n.dtype, new_src, n.arg) if new_src != n.src else n
|
||||
self.replace[n] = ret = self.rewrite(new_x) if (new_x := self.pm.rewrite(x, self.ctx)) is not None else x
|
||||
return ret
|
||||
|
||||
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
|
||||
if TRACK_MATCH_STATS >= 2:
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
|
||||
Reference in New Issue
Block a user