From 9b464e34eaf287ce2c00fc03a8d7469f52c7ad8e Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 17 May 2024 21:04:39 -0700 Subject: [PATCH] increase speed of uops (#4637) * increase speed of uops * not equal * minor speedup --- tinygrad/codegen/uops.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index 11ca5495aa..898c1ed66a 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -33,14 +33,12 @@ class UOp: vin: Tuple[UOp, ...] = tuple() arg: Any = None def tuple(self): return (self.uop, self.dtype, self.vin, self.arg) + @functools.cached_property def cmp_tuple(self): # NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX return (self.uop.value, (self.arg if self.uop is not UOps.DEFINE_VAR else self.arg.expr) if self.uop is not UOps.ALU else \ (type(self.uop), self.uop.value), self.dtype, self.vin) - def __lt__(self, x:UOp): - a, b = self.cmp_tuple(), x.cmp_tuple() - try: return a < b - except Exception: raise RuntimeError(f"compare failed between {self.uop} and {x.uop} -- {a} and {b}") + def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple def __repr__(self): return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}" def cast(self, dtype): return UOp(UOps.CAST, dtype, (self,)) @@ -74,6 +72,16 @@ def _match(uop:UOp, pattern:Dict[str, Any], store:Dict[str, UOp]) -> bool: if k == "__name__": if v in store and store[v] != uop: return False store[v] = uop + elif k == "arg": + if uop.arg != v: return False + elif k == "dtype": + if isinstance(v, set): + if uop.dtype not in v: return False + elif uop.dtype != v: return False + elif k == "uop": + if isinstance(v, set): + if uop.uop not in v: return False + elif uop.uop != v: return False elif k == "vin": # only one if it's a tuple # try all permutations if it's a list @@ -85,11 +93,6 @@ def _match(uop:UOp, pattern:Dict[str, Any], store:Dict[str, UOp]) -> bool: for k,v in new_store.items(): store[k] = v return True return False - elif k in {"dtype", "uop"}: - if uop.__getattribute__(k) not in (v if isinstance(v, set) else set([v])): return False - elif k[:2] == "__": continue - else: - if uop.__getattribute__(k) != v: return False return True class PatternMatcher: @@ -136,10 +139,10 @@ constant_folder = PatternMatcher([ {"__name__": "compval", "uop": UOps.CONST})}, {"__name__": "multconst", "uop": UOps.CONST}, {"uop": UOps.CONST, "arg": 0})}, loop_collapse), # sum collapse to mul (with possible GEP) ({"uop": UOps.PHI, "vin": ({"__name__": "phi_input", "uop": UOps.DEFINE_ACC, "vin": ({"uop": UOps.LOOP, "__name__": "loop"},)}, - {"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"__name__": "val1"}, {"__name__": "val2"}]})}, sum_collapse), + {"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "val1"}, {"__name__": "val2"})})}, sum_collapse), ({"uop": UOps.PHI, "vin": ({"__name__": "phi_input", "uop": UOps.GEP, "vin": ({"uop": UOps.DEFINE_ACC, "vin":({"uop": UOps.LOOP, "__name__": "loop"},)},)}, - {"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"__name__": "val1"}, {"__name__": "val2"}]})}, sum_collapse), + {"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "val1"}, {"__name__": "val2"})})}, sum_collapse), # deal with UNMUL ({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"uop": UOps.CONST, "__name__": "c1"}, {"uop": UOps.UNMUL, "vin": [{"uop": UOps.CONST, "__name__": "c2"}, {"__name__": "v"}]}]}, @@ -168,8 +171,7 @@ constant_folder = PatternMatcher([ ({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "x"}, {"__name__": "my", "uop": UOps.ALU, "arg": UnaryOps.NEG})}, lambda x, my: x-my.vin[0]), # -1*x -> -x - ({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": -1}]}, - lambda x: UOp(UOps.ALU, x.dtype, (x,), UnaryOps.NEG)), + ({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": -1}]}, lambda x: -x), # bool < False is always false, True < bool is always false ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({}, {"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.bool, "arg": False})}, lambda x: x), ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.bool, "arg": True}, {})}, @@ -223,7 +225,7 @@ constant_folder = PatternMatcher([ {"__name__": "c", "uop": UOps.CONST, "dtype": dtypes.int})}, lambda c,x: UOp(UOps.ALU, dtypes.bool, (UOp.const(c.dtype, -c.arg), x), BinaryOps.CMPLT)), # cast NOOP (NOTE: it's str to deal with PtrDType) - ({"__name__": "root", "uop": UOps.CAST}, lambda root: root.vin[0] if str(root.dtype) == str(root.vin[0].dtype) else None), + ({"__name__": "root", "uop": UOps.CAST}, lambda root: root.vin[0] if root.dtype is root.vin[0].dtype else None), ]) # *** uop graph *** @@ -268,8 +270,10 @@ class UOpGraph: up = rewritten recurse_cnt += 1 changed += recurse_cnt + # NOTE: this changes UOp, so we have to delete caches up.vin = tuple(rewrite(x) for x in up.vin) if hasattr(up, "parents"): del up.parents + if hasattr(up, "cmp_tuple"): del up.cmp_tuple # replace with cached nodes if found:=self.nodes.get(key:=up.tuple()): return found else: self.nodes[key] = up