mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-26 07:18:40 -05:00
increase speed of uops (#4637)
* increase speed of uops * not equal * minor speedup
This commit is contained in:
@@ -33,14 +33,12 @@ class UOp:
|
||||
vin: Tuple[UOp, ...] = tuple()
|
||||
arg: Any = None
|
||||
def tuple(self): return (self.uop, self.dtype, self.vin, self.arg)
|
||||
@functools.cached_property
|
||||
def cmp_tuple(self):
|
||||
# NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
|
||||
return (self.uop.value, (self.arg if self.uop is not UOps.DEFINE_VAR else self.arg.expr) if self.uop is not UOps.ALU else \
|
||||
(type(self.uop), self.uop.value), self.dtype, self.vin)
|
||||
def __lt__(self, x:UOp):
|
||||
a, b = self.cmp_tuple(), x.cmp_tuple()
|
||||
try: return a < b
|
||||
except Exception: raise RuntimeError(f"compare failed between {self.uop} and {x.uop} -- {a} and {b}")
|
||||
def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
|
||||
def __repr__(self):
|
||||
return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}"
|
||||
def cast(self, dtype): return UOp(UOps.CAST, dtype, (self,))
|
||||
@@ -74,6 +72,16 @@ def _match(uop:UOp, pattern:Dict[str, Any], store:Dict[str, UOp]) -> bool:
|
||||
if k == "__name__":
|
||||
if v in store and store[v] != uop: return False
|
||||
store[v] = uop
|
||||
elif k == "arg":
|
||||
if uop.arg != v: return False
|
||||
elif k == "dtype":
|
||||
if isinstance(v, set):
|
||||
if uop.dtype not in v: return False
|
||||
elif uop.dtype != v: return False
|
||||
elif k == "uop":
|
||||
if isinstance(v, set):
|
||||
if uop.uop not in v: return False
|
||||
elif uop.uop != v: return False
|
||||
elif k == "vin":
|
||||
# only one if it's a tuple
|
||||
# try all permutations if it's a list
|
||||
@@ -85,11 +93,6 @@ def _match(uop:UOp, pattern:Dict[str, Any], store:Dict[str, UOp]) -> bool:
|
||||
for k,v in new_store.items(): store[k] = v
|
||||
return True
|
||||
return False
|
||||
elif k in {"dtype", "uop"}:
|
||||
if uop.__getattribute__(k) not in (v if isinstance(v, set) else set([v])): return False
|
||||
elif k[:2] == "__": continue
|
||||
else:
|
||||
if uop.__getattribute__(k) != v: return False
|
||||
return True
|
||||
|
||||
class PatternMatcher:
|
||||
@@ -136,10 +139,10 @@ constant_folder = PatternMatcher([
|
||||
{"__name__": "compval", "uop": UOps.CONST})}, {"__name__": "multconst", "uop": UOps.CONST}, {"uop": UOps.CONST, "arg": 0})}, loop_collapse),
|
||||
# sum collapse to mul (with possible GEP)
|
||||
({"uop": UOps.PHI, "vin": ({"__name__": "phi_input", "uop": UOps.DEFINE_ACC, "vin": ({"uop": UOps.LOOP, "__name__": "loop"},)},
|
||||
{"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"__name__": "val1"}, {"__name__": "val2"}]})}, sum_collapse),
|
||||
{"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "val1"}, {"__name__": "val2"})})}, sum_collapse),
|
||||
({"uop": UOps.PHI, "vin": ({"__name__": "phi_input", "uop": UOps.GEP,
|
||||
"vin": ({"uop": UOps.DEFINE_ACC, "vin":({"uop": UOps.LOOP, "__name__": "loop"},)},)},
|
||||
{"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"__name__": "val1"}, {"__name__": "val2"}]})}, sum_collapse),
|
||||
{"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "val1"}, {"__name__": "val2"})})}, sum_collapse),
|
||||
# deal with UNMUL
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"uop": UOps.CONST, "__name__": "c1"},
|
||||
{"uop": UOps.UNMUL, "vin": [{"uop": UOps.CONST, "__name__": "c2"}, {"__name__": "v"}]}]},
|
||||
@@ -168,8 +171,7 @@ constant_folder = PatternMatcher([
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "x"}, {"__name__": "my", "uop": UOps.ALU, "arg": UnaryOps.NEG})},
|
||||
lambda x, my: x-my.vin[0]),
|
||||
# -1*x -> -x
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": -1}]},
|
||||
lambda x: UOp(UOps.ALU, x.dtype, (x,), UnaryOps.NEG)),
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": -1}]}, lambda x: -x),
|
||||
# bool < False is always false, True < bool is always false
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({}, {"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.bool, "arg": False})}, lambda x: x),
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.bool, "arg": True}, {})},
|
||||
@@ -223,7 +225,7 @@ constant_folder = PatternMatcher([
|
||||
{"__name__": "c", "uop": UOps.CONST, "dtype": dtypes.int})},
|
||||
lambda c,x: UOp(UOps.ALU, dtypes.bool, (UOp.const(c.dtype, -c.arg), x), BinaryOps.CMPLT)),
|
||||
# cast NOOP (NOTE: it's str to deal with PtrDType)
|
||||
({"__name__": "root", "uop": UOps.CAST}, lambda root: root.vin[0] if str(root.dtype) == str(root.vin[0].dtype) else None),
|
||||
({"__name__": "root", "uop": UOps.CAST}, lambda root: root.vin[0] if root.dtype is root.vin[0].dtype else None),
|
||||
])
|
||||
|
||||
# *** uop graph ***
|
||||
@@ -268,8 +270,10 @@ class UOpGraph:
|
||||
up = rewritten
|
||||
recurse_cnt += 1
|
||||
changed += recurse_cnt
|
||||
# NOTE: this changes UOp, so we have to delete caches
|
||||
up.vin = tuple(rewrite(x) for x in up.vin)
|
||||
if hasattr(up, "parents"): del up.parents
|
||||
if hasattr(up, "cmp_tuple"): del up.cmp_tuple
|
||||
# replace with cached nodes
|
||||
if found:=self.nodes.get(key:=up.tuple()): return found
|
||||
else: self.nodes[key] = up
|
||||
|
||||
Reference in New Issue
Block a user