From 2b089bfd18912dcde5942b33ffe541575dbd9ced Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 11 Mar 2024 15:52:24 -0700 Subject: [PATCH] rewrite recip to div (#3690) * rewrite recip to div * fix bug in uops add --- tinygrad/codegen/uops.py | 6 +++++- tinygrad/renderer/cstyle.py | 12 ++++++++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index c8ab02717a..0a42bf8728 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -107,7 +107,11 @@ class UOpGraph: key = (uop, dtype, vin, arg) if insert_before is None: insert_before = len(self.uops) # check if the cached expr is valid with the given insert place. - if cachable and (expr:=self.saved_exprs.get(key, None)) is not None and self.uops.index(expr) <= insert_before: return expr + try: + if cachable and (expr:=self.saved_exprs.get(key, None)) is not None and self.uops.index(expr) <= insert_before: return expr + except ValueError: + # this happens when self.uops.index raises because the UOp was deleted + pass ret = UOp(uop, dtype, vin, arg) self.uops.insert(insert_before, ret) if cachable: self.saved_exprs[key] = ret diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index e3f3838cb7..7c8a107867 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -30,7 +30,7 @@ class CStyleLanguage(NamedTuple): BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", BinaryOps.DIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})", BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPEQ: lambda a,b,dtype: f"({a}=={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})", - UnaryOps.RECIP: lambda x,dtype: f"(1.0/{x})", TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"} + TernaryOps.WHERE: lambda 
a,b,c,dtype: f"({a}?{b}:{c})"} # returns a str expression of the casted xs with the given type def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str: @@ -95,6 +95,14 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str depth = 1 def kk(s): kernel.append(" "*depth+s) + # here we do a pretransform on UOps to fix some shortcomings of cstyle + # RECIP is annoying with the type of the const, so we transform it into DIV + for u in uops: + if u.uop is UOps.ALU and u.arg is UnaryOps.RECIP: + const_1 = uops.add(UOps.CONST, u.dtype, arg=1.0, insert_before=uops.uops.index(u)) + u.arg = BinaryOps.DIV + u.vin = (const_1, u.vin[0]) + c: DefaultDict[str, int] = defaultdict(int) r: Dict[UOp, str] = {} def ssa(u, prefix="t"): @@ -133,7 +141,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str val = lang.code_for_op[args](*operands, dtype) assert child_count[u] != 0, f"childless ALU op found {u}" # TODO: fix index rendering issue. fix clang nested max macro issue - if child_count[u] <= 1 and args not in {UnaryOps.RECIP, BinaryOps.MAX} and not getenv("EXPAND_SSA"): r[u] = val + if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val else: kk(f"{dtype.name} {ssa(u,'alu')} = {val};") elif uop is UOps.SPECIAL: kk(f"int {args[1]} = {lang.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")