explore global uop cache [pr] (#6863)

* explore global uop cache

* wvd uops

* remove useless lru caches

* key is is

* simpler rewriter
This commit is contained in:
George Hotz
2024-10-03 13:08:13 +08:00
committed by GitHub
parent a26c6a0ad0
commit e10245909a
3 changed files with 18 additions and 14 deletions

View File

@@ -99,7 +99,7 @@ class TestPatternMatcher(unittest.TestCase):
c1 = UOp(UOps.ALU, dtypes.float, (y1, y1), BinaryOps.ADD)
c2 = UOp(UOps.ALU, dtypes.float, (y1, y2), BinaryOps.ADD)
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), None)
self.assertEqual(matcher.rewrite(c2), c1)
def test_dtype(self):
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.float32), lambda x: x)])

View File

@@ -186,7 +186,7 @@ def is_increasing(f:UOp):
def replace_uop(uop:UOp, old:UOp, new:UOp):
# replace all `old` in `uop` to `new`
return new if uop.key == old.key else uop.replace(src=tuple(replace_uop(s, old, new) for s in uop.src))
return new if uop is old else uop.replace(src=tuple(replace_uop(s, old, new) for s in uop.src))
def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
# if it's X <= c, returns X, True, c
@@ -229,8 +229,8 @@ def idx_given_valid(valid:UOp, idx:UOp) -> Optional[UOp]:
newidxs[1].append(newidx.src[1])
# if every branch in candidate gives the same simplified output, we can rewrite the idx
if len(newidxs[0])==1 or (len(newidxs[0]) > 1 and all_same([i.key for i in newidxs[0]])): idx = idx.replace(src=(newidxs[0][0], idx.src[1]))
if len(newidxs[1])==1 or (len(newidxs[1]) > 1 and all_same([i.key for i in newidxs[1]])): idx = idx.replace(src=(idx.src[0], newidxs[1][0]))
if len(newidxs[0])==1 or (len(newidxs[0]) > 1 and all_same(newidxs[0])): idx = idx.replace(src=(newidxs[0][0], idx.src[1]))
if len(newidxs[1])==1 or (len(newidxs[1]) > 1 and all_same(newidxs[1])): idx = idx.replace(src=(idx.src[0], newidxs[1][0]))
return idx
def simplify_valid_image_load(load:UOp, buf:UOp):
@@ -262,7 +262,7 @@ def simplify_valid_image_load(load:UOp, buf:UOp):
drop_stmt.append(stmt)
break
if not drop_stmt and idx.key == start_idx.key: return None
if not drop_stmt and idx == start_idx: return None
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s not in drop_stmt]) else None
return load.replace(src=((buf, idx, invalid_val, new_valid) if new_valid is not None else (buf, idx)))

View File

@@ -4,6 +4,7 @@ from types import FrameType
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle
from enum import auto, IntEnum, Enum
from dataclasses import dataclass, field
from weakref import WeakValueDictionary
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
from tinygrad.helpers import ContextVar, pretty_print, prod, getenv, all_same
from tinygrad.shape.symbolic import Variable, sint
@@ -155,7 +156,14 @@ def resolve(x, default:bool=True):
except ValueError: return default
def smax(lst): return max(lst, key=lambda x: x if isinstance(x, int) else x.max)
ucache:WeakValueDictionary[Tuple, UOp] = WeakValueDictionary()
class UOp(MathTrait):
def __reduce__(self): return UOp, (self.op, self.dtype, self.src, self.arg)
def __new__(cls, op:UOps, dtype:DType=dtypes.void, src:Tuple[UOp,...]=tuple(), arg:Any=None):
if (ret:=ucache.get(key:=(op, dtype, src, arg), None)) is not None: return ret
ucache[key] = ret = super().__new__(cls)
return ret
__slots__ = ["op", "dtype", "src", "arg"]
def __init__(self, op: UOps, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
# TODO: instant check rules here make debugging easier
@@ -238,7 +246,6 @@ class UOp(MathTrait):
out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
return UOp(UOps.ALU, out_dtype, (self,)+src, arg)
@staticmethod
@functools.lru_cache(None)
def const(dtype:DType, b:Tuple[ConstType, ...]|ConstType|Variable): return UOp._const(dtype, b)
@staticmethod
def _const(dtype:DType, b:Tuple[ConstType, ...]|ConstType|Variable):
@@ -247,7 +254,6 @@ class UOp(MathTrait):
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
return UOp(UOps.VCONST if isinstance(b, tuple) else UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore
@staticmethod
@functools.lru_cache(None)
def define_var(name:str, dtype:DType, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
@staticmethod
def range(dtype:DType, start:ConstType, end:ConstType, idx:int):
@@ -597,16 +603,14 @@ class RewriteContext:
def __init__(self, pm, ctx):
self.pm: PatternMatcher = pm
self.ctx = ctx
self.nodes: Dict[Tuple, UOp] = {}
self.replace: Dict[UOp, UOp] = {}
def rewrite(self, n:UOp) -> UOp:
if (rn := self.replace.get(n)) is not None: return rn
replace_source = (n.op, n.dtype, new_src:=tuple(map(self.rewrite, n.src)), n.arg)
if (found := self.nodes.get(replace_source)) is not None: self.replace[n] = found
else:
x = UOp(*replace_source) if new_src != n.src else n
self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x, self.ctx)) is not None else x
return found
new_src = tuple(map(self.rewrite, n.src))
x = UOp(n.op, n.dtype, new_src, n.arg) if new_src != n.src else n
self.replace[n] = ret = self.rewrite(new_x) if (new_x := self.pm.rewrite(x, self.ctx)) is not None else x
return ret
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
if TRACK_MATCH_STATS >= 2:
from tinygrad.codegen.kernel import Kernel