diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index 0f0da97735..ccf9a50e7a 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -1,8 +1,8 @@ from __future__ import annotations -import functools, math, operator +import functools, math, operator, itertools from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict, Callable, cast from collections import defaultdict -from tinygrad.helpers import DEBUG, flatten, all_same +from tinygrad.helpers import DEBUG, flatten from tinygrad.dtype import dtypes, DType from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps from tinygrad.shape.symbolic import sint, Variable, Node, NumNode, MulNode, DivNode, SumNode @@ -20,10 +20,13 @@ class UOps(Enum): class UOp: uop: UOps dtype: Optional[DType] - vin: Tuple[UOp, ...] - arg: Any + vin: Tuple[UOp, ...] = tuple() + arg: Any = None def __repr__(self): return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}" + @staticmethod + def const(dtype, val): + return UOp(UOps.CONST, dtype, arg=float(val) if dtypes.is_float(dtype) else (int(val) if dtypes.is_int(dtype) else bool(val))) def hook_overflow(dv, fxn): def wfxn(*args): @@ -58,6 +61,59 @@ def uop_alu_resolve(u:UOp) -> sint: def phi_resolve_acc(u:UOp) -> UOp: return u if u.uop is UOps.DEFINE_ACC else phi_resolve_acc(u.vin[0]) +def _match(uop:UOp, pattern:Dict[str, Any], store:Dict[str, UOp]) -> bool: + for k,v in pattern.items(): + if k == "__name__": + if v in store and store[v] != uop: return False + store[v] = uop + elif k == "vin": + # only one if it's a tuple + # try all permutations if it's a list + # repeat if it's a dict + for vp in itertools.permutations(v) if isinstance(v, list) else ([v] if isinstance(v, tuple) else [(v,)*len(uop.vin)]): + if len(uop.vin) != len(vp): return False + new_store = store.copy() + if all(_match(uu, vv, new_store) for uu, vv in zip(uop.vin, vp)): + for k,v in new_store.items(): store[k] = v + return True + return False + else: + if uop.__getattribute__(k) != v: return False + return True + +def rewrite(uop:UOp, patterns:List[Tuple[Dict[str, Any], Any]]) -> Optional[UOp]: + for p,fxn in patterns: + store: Dict[str, UOp] = {} + if _match(uop, p, store): + return fxn(**store) + return None + +constant_fold_patterns = [ + # const rules + ({"__name__": "root", "uop": UOps.GEP, "vin": ({"__name__": "c", "uop": UOps.CONST},)}, lambda root, c: UOp.const(root.dtype, c.arg)), + ({"__name__": "root", "uop": UOps.CAST, "vin": {"__name__": "c", "uop": UOps.CONST}}, lambda root,c: UOp.const(root.dtype, c.arg)), + # a phi without loops (len(vin)==2) is a noop + ({"uop": UOps.PHI, "vin": ({}, {"__name__": "x"})}, lambda x: x), + # x+-y -> x-y + ({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "x"}, {"__name__": "my", "uop": UOps.ALU, "arg": UnaryOps.NEG})}, + lambda x, my: UOp(UOps.ALU, x.dtype, (x, my.vin[0]), BinaryOps.SUB)), + # a conditional with the same results either way is a noop, also fold const conditionals + ({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({}, {"__name__": "val"}, {"__name__": "val"})}, lambda val: val), + ({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({"__name__": "gate", "uop": UOps.CONST}, {"__name__": "c0"}, {"__name__": "c1"})}, + lambda gate, c0, c1: c0 if gate.arg else c1), + # ** constant folding ** + ({"__name__": "root", "uop": UOps.ALU, "vin": {"uop": UOps.CONST}}, + lambda root: UOp(UOps.CONST, root.dtype, arg=exec_alu(root.arg, root.dtype, [x.arg for x in root.vin]))), + # ** self folding ** + ({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": 0}]}, lambda x: x), # x+0 -> x or 0+x -> x + ({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": 1}]}, lambda x: x), # x*1 -> x or 1*x -> x + ({"uop": UOps.ALU, "arg": BinaryOps.SUB, "vin": ({"__name__": "x"}, {"uop": UOps.CONST, "arg": 0})}, lambda x: x), # x-0 -> x + ({"uop": UOps.ALU, "arg": BinaryOps.DIV, "vin": ({"__name__": "x"}, {"uop": UOps.CONST, "arg": 1})}, lambda x: x), # x/1 -> x + # ** zero folding ** + ({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{}, {"__name__": "c", "uop": UOps.CONST, "arg": 0}]}, lambda c: c), # x*0 -> 0 or 0*x -> 0 + ({"uop": UOps.ALU, "arg": BinaryOps.SUB, "vin": ({"__name__": "x"}, {"__name__": "x"})}, lambda x: UOp.const(x.dtype, 0)), # x-x -> 0 +] + class UOpGraph: def __init__(self, start_uops:Optional[List[UOp]]=None): # list of uops @@ -81,33 +137,14 @@ class UOpGraph: def add(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None, cachable=True, insert_before=None, simplify=True) -> UOp: - if simplify: - if uop is UOps.PHI and len(vin) == 2: return vin[1] # a phi without loops is a noop - if uop is UOps.GEP and vin[0].uop is UOps.CONST: return self.add(UOps.CONST, dtype, arg=vin[0].arg, insert_before=insert_before) - if uop is UOps.CAST and all(x.uop is UOps.CONST for x in vin) and all_same([x.arg for x in vin]): - return self.add(UOps.CONST, dtype, arg=vin[0].arg, insert_before=insert_before) - if uop is UOps.ALU: - # rewrites. NOTE: the rewritten NEG op is still around... - if arg is BinaryOps.ADD and vin[1].uop is UOps.ALU and vin[1].arg is UnaryOps.NEG: - return self.add(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable, insert_before) - # constant folding - if arg is TernaryOps.WHERE and vin[1] == vin[2]: return vin[1] # a conditional with the same results either way is a noop - if arg is TernaryOps.WHERE and vin[0].uop is UOps.CONST: return vin[1] if vin[0].arg else vin[2] - if all(x.uop is UOps.CONST for x in vin): - return self.add(UOps.CONST, dtype, arg=exec_alu(arg, dtype, [x.arg for x in vin]), insert_before=insert_before) - # zero folding - for x in [0,1]: - if arg is BinaryOps.ADD and vin[x].uop is UOps.CONST and vin[x].arg == 0.0: return vin[1-x] - if arg is BinaryOps.MUL and vin[x].uop is UOps.CONST and vin[x].arg == 1.0: return vin[1-x] - if arg is BinaryOps.MUL and vin[x].uop is UOps.CONST and vin[x].arg == 0.0: return vin[x] - if arg is BinaryOps.SUB and vin[1].uop is UOps.CONST and vin[1].arg == 0.0: return vin[0] - if arg is BinaryOps.DIV and vin[1].uop is UOps.CONST and vin[1].arg == 1.0: return vin[0] - - key = (uop, dtype, vin, arg) + ret = UOp(uop, dtype, vin, arg) + if simplify and (rewritten:=rewrite(ret, constant_fold_patterns)) is not None: + if rewritten in self.uops: return rewritten # ignore cachable + ret = rewritten + key = (ret.uop, ret.dtype, ret.vin, ret.arg) if insert_before is None: insert_before = len(self.uops) # check if the cached expr is valid with the given insert place. if cachable and (expr:=self.saved_exprs.get(key, None)) is not None and self.uops.index(expr) <= insert_before: return expr - ret = UOp(uop, dtype, vin, arg) self.uops.insert(insert_before, ret) if cachable: self.saved_exprs[key] = ret return ret