diff --git a/test/unit/test_pattern_matcher.py b/test/unit/test_pattern_matcher.py index 7563bc6095..e9767adb1c 100644 --- a/test/unit/test_pattern_matcher.py +++ b/test/unit/test_pattern_matcher.py @@ -99,7 +99,7 @@ class TestPatternMatcher(unittest.TestCase): c1 = UOp(UOps.ALU, dtypes.float, (y1, y1), BinaryOps.ADD) c2 = UOp(UOps.ALU, dtypes.float, (y1, y2), BinaryOps.ADD) self.assertEqual(matcher.rewrite(c1), c1) - self.assertEqual(matcher.rewrite(c2), None) + self.assertEqual(matcher.rewrite(c2), c1) def test_dtype(self): matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.float32), lambda x: x)]) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index d9adb6d3f0..71544958e0 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -186,7 +186,7 @@ def is_increasing(f:UOp): def replace_uop(uop:UOp, old:UOp, new:UOp): # replace all `old` in `uop` to `new` - return new if uop.key == old.key else uop.replace(src=tuple(replace_uop(s, old, new) for s in uop.src)) + return new if uop is old else uop.replace(src=tuple(replace_uop(s, old, new) for s in uop.src)) def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]: # if it's X <= c, returns X, True, c @@ -229,8 +229,8 @@ def idx_given_valid(valid:UOp, idx:UOp) -> Optional[UOp]: newidxs[1].append(newidx.src[1]) # if every branch in candidate gives the same simplified output, we can rewrite the idx - if len(newidxs[0])==1 or (len(newidxs[0]) > 1 and all_same([i.key for i in newidxs[0]])): idx = idx.replace(src=(newidxs[0][0], idx.src[1])) - if len(newidxs[1])==1 or (len(newidxs[1]) > 1 and all_same([i.key for i in newidxs[1]])): idx = idx.replace(src=(idx.src[0], newidxs[1][0])) + if len(newidxs[0])==1 or (len(newidxs[0]) > 1 and all_same(newidxs[0])): idx = idx.replace(src=(newidxs[0][0], idx.src[1])) + if len(newidxs[1])==1 or (len(newidxs[1]) > 1 and all_same(newidxs[1])): idx = idx.replace(src=(idx.src[0], newidxs[1][0])) return idx def simplify_valid_image_load(load:UOp, buf:UOp): @@ -262,7 +262,7 @@ def simplify_valid_image_load(load:UOp, buf:UOp): drop_stmt.append(stmt) break - if not drop_stmt and idx.key == start_idx.key: return None + if not drop_stmt and idx == start_idx: return None new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s not in drop_stmt]) else None return load.replace(src=((buf, idx, invalid_val, new_valid) if new_valid is not None else (buf, idx))) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 4831729ffc..a3712bcef8 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -4,6 +4,7 @@ from types import FrameType import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle from enum import auto, IntEnum, Enum from dataclasses import dataclass, field +from weakref import WeakValueDictionary from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate from tinygrad.helpers import ContextVar, pretty_print, prod, getenv, all_same from tinygrad.shape.symbolic import Variable, sint @@ -155,7 +156,14 @@ def resolve(x, default:bool=True): except ValueError: return default def smax(lst): return max(lst, key=lambda x: x if isinstance(x, int) else x.max) +ucache:WeakValueDictionary[Tuple, UOp] = WeakValueDictionary() class UOp(MathTrait): + def __reduce__(self): return UOp, (self.op, self.dtype, self.src, self.arg) + def __new__(cls, op:UOps, dtype:DType=dtypes.void, src:Tuple[UOp,...]=tuple(), arg:Any=None): + if (ret:=ucache.get(key:=(op, dtype, src, arg), None)) is not None: return ret + ucache[key] = ret = super().__new__(cls) + return ret + __slots__ = ["op", "dtype", "src", "arg"] def __init__(self, op: UOps, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None): # TODO: instant check rules here make debugging easier @@ -238,7 +246,6 @@ class UOp(MathTrait): out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool return UOp(UOps.ALU, out_dtype, (self,)+src, arg) @staticmethod - @functools.lru_cache(None) def const(dtype:DType, b:Tuple[ConstType, ...]|ConstType|Variable): return UOp._const(dtype, b) @staticmethod def _const(dtype:DType, b:Tuple[ConstType, ...]|ConstType|Variable): @@ -247,7 +254,6 @@ class UOp(MathTrait): if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same return UOp(UOps.VCONST if isinstance(b, tuple) else UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore @staticmethod - @functools.lru_cache(None) def define_var(name:str, dtype:DType, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val)) @staticmethod def range(dtype:DType, start:ConstType, end:ConstType, idx:int): @@ -597,16 +603,14 @@ class RewriteContext: def __init__(self, pm, ctx): self.pm: PatternMatcher = pm self.ctx = ctx - self.nodes: Dict[Tuple, UOp] = {} self.replace: Dict[UOp, UOp] = {} def rewrite(self, n:UOp) -> UOp: if (rn := self.replace.get(n)) is not None: return rn - replace_source = (n.op, n.dtype, new_src:=tuple(map(self.rewrite, n.src)), n.arg) - if (found := self.nodes.get(replace_source)) is not None: self.replace[n] = found - else: - x = UOp(*replace_source) if new_src != n.src else n - self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x, self.ctx)) is not None else x - return found + new_src = tuple(map(self.rewrite, n.src)) + x = UOp(n.op, n.dtype, new_src, n.arg) if new_src != n.src else n + self.replace[n] = ret = self.rewrite(new_x) if (new_x := self.pm.rewrite(x, self.ctx)) is not None else x + return ret + def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp: if TRACK_MATCH_STATS >= 2: from tinygrad.codegen.kernel import Kernel