increase speed of uops (#4637)

* increase speed of uops

* not equal

* minor speedup
This commit is contained in:
George Hotz
2024-05-17 21:04:39 -07:00
committed by GitHub
parent b74cc1d01a
commit 9b464e34ea

View File

@@ -33,14 +33,12 @@ class UOp:
vin: Tuple[UOp, ...] = tuple()
arg: Any = None
def tuple(self): return (self.uop, self.dtype, self.vin, self.arg)
@functools.cached_property
def cmp_tuple(self):
# NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
return (self.uop.value, (self.arg if self.uop is not UOps.DEFINE_VAR else self.arg.expr) if self.uop is not UOps.ALU else \
(type(self.uop), self.uop.value), self.dtype, self.vin)
def __lt__(self, x:UOp):
a, b = self.cmp_tuple(), x.cmp_tuple()
try: return a < b
except Exception: raise RuntimeError(f"compare failed between {self.uop} and {x.uop} -- {a} and {b}")
def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
def __repr__(self):
return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}"
def cast(self, dtype): return UOp(UOps.CAST, dtype, (self,))
@@ -74,6 +72,16 @@ def _match(uop:UOp, pattern:Dict[str, Any], store:Dict[str, UOp]) -> bool:
if k == "__name__":
if v in store and store[v] != uop: return False
store[v] = uop
elif k == "arg":
if uop.arg != v: return False
elif k == "dtype":
if isinstance(v, set):
if uop.dtype not in v: return False
elif uop.dtype != v: return False
elif k == "uop":
if isinstance(v, set):
if uop.uop not in v: return False
elif uop.uop != v: return False
elif k == "vin":
# only one if it's a tuple
# try all permutations if it's a list
@@ -85,11 +93,6 @@ def _match(uop:UOp, pattern:Dict[str, Any], store:Dict[str, UOp]) -> bool:
for k,v in new_store.items(): store[k] = v
return True
return False
elif k in {"dtype", "uop"}:
if uop.__getattribute__(k) not in (v if isinstance(v, set) else set([v])): return False
elif k[:2] == "__": continue
else:
if uop.__getattribute__(k) != v: return False
return True
class PatternMatcher:
@@ -136,10 +139,10 @@ constant_folder = PatternMatcher([
{"__name__": "compval", "uop": UOps.CONST})}, {"__name__": "multconst", "uop": UOps.CONST}, {"uop": UOps.CONST, "arg": 0})}, loop_collapse),
# sum collapse to mul (with possible GEP)
({"uop": UOps.PHI, "vin": ({"__name__": "phi_input", "uop": UOps.DEFINE_ACC, "vin": ({"uop": UOps.LOOP, "__name__": "loop"},)},
{"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"__name__": "val1"}, {"__name__": "val2"}]})}, sum_collapse),
{"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "val1"}, {"__name__": "val2"})})}, sum_collapse),
({"uop": UOps.PHI, "vin": ({"__name__": "phi_input", "uop": UOps.GEP,
"vin": ({"uop": UOps.DEFINE_ACC, "vin":({"uop": UOps.LOOP, "__name__": "loop"},)},)},
{"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"__name__": "val1"}, {"__name__": "val2"}]})}, sum_collapse),
{"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "val1"}, {"__name__": "val2"})})}, sum_collapse),
# deal with UNMUL
({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"uop": UOps.CONST, "__name__": "c1"},
{"uop": UOps.UNMUL, "vin": [{"uop": UOps.CONST, "__name__": "c2"}, {"__name__": "v"}]}]},
@@ -168,8 +171,7 @@ constant_folder = PatternMatcher([
({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "x"}, {"__name__": "my", "uop": UOps.ALU, "arg": UnaryOps.NEG})},
lambda x, my: x-my.vin[0]),
# -1*x -> -x
({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": -1}]},
lambda x: UOp(UOps.ALU, x.dtype, (x,), UnaryOps.NEG)),
({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": -1}]}, lambda x: -x),
# bool < False is always false, True < bool is always false
({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({}, {"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.bool, "arg": False})}, lambda x: x),
({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.bool, "arg": True}, {})},
@@ -223,7 +225,7 @@ constant_folder = PatternMatcher([
{"__name__": "c", "uop": UOps.CONST, "dtype": dtypes.int})},
lambda c,x: UOp(UOps.ALU, dtypes.bool, (UOp.const(c.dtype, -c.arg), x), BinaryOps.CMPLT)),
# cast NOOP (NOTE: it's str to deal with PtrDType)
({"__name__": "root", "uop": UOps.CAST}, lambda root: root.vin[0] if str(root.dtype) == str(root.vin[0].dtype) else None),
({"__name__": "root", "uop": UOps.CAST}, lambda root: root.vin[0] if root.dtype is root.vin[0].dtype else None),
])
# *** uop graph ***
@@ -268,8 +270,10 @@ class UOpGraph:
up = rewritten
recurse_cnt += 1
changed += recurse_cnt
# NOTE: this changes UOp, so we have to delete caches
up.vin = tuple(rewrite(x) for x in up.vin)
if hasattr(up, "parents"): del up.parents
if hasattr(up, "cmp_tuple"): del up.cmp_tuple
# replace with cached nodes
if found:=self.nodes.get(key:=up.tuple()): return found
else: self.nodes[key] = up