rewrite recip to div (#3690)

* rewrite recip to div

* fix bug in uops add
This commit is contained in:
George Hotz
2024-03-11 15:52:24 -07:00
committed by GitHub
parent aec4c4f01b
commit 2b089bfd18
2 changed files with 15 additions and 3 deletions

View File

@@ -107,7 +107,11 @@ class UOpGraph:
key = (uop, dtype, vin, arg)
if insert_before is None: insert_before = len(self.uops)
# check if the cached expr is valid with the given insert place.
if cachable and (expr:=self.saved_exprs.get(key, None)) is not None and self.uops.index(expr) <= insert_before: return expr
try:
if cachable and (expr:=self.saved_exprs.get(key, None)) is not None and self.uops.index(expr) <= insert_before: return expr
except ValueError:
# this happens if self.uops.index because the UOp was deleted
pass
ret = UOp(uop, dtype, vin, arg)
self.uops.insert(insert_before, ret)
if cachable: self.saved_exprs[key] = ret

View File

@@ -30,7 +30,7 @@ class CStyleLanguage(NamedTuple):
BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})",
BinaryOps.DIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPEQ: lambda a,b,dtype: f"({a}=={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
UnaryOps.RECIP: lambda x,dtype: f"(1.0/{x})", TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}
TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}
# returns a str expression of the casted xs with the given type
def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str:
@@ -95,6 +95,14 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str
depth = 1
def kk(s): kernel.append(" "*depth+s)
# here we do a pretransform on UOps to fix some shortcomings of cstyle
# RECIP is annoying with the type of the const, so we transform it into DIV
for u in uops:
if u.uop is UOps.ALU and u.arg is UnaryOps.RECIP:
const_1 = uops.add(UOps.CONST, u.dtype, arg=1.0, insert_before=uops.uops.index(u))
u.arg = BinaryOps.DIV
u.vin = (const_1, u.vin[0])
c: DefaultDict[str, int] = defaultdict(int)
r: Dict[UOp, str] = {}
def ssa(u, prefix="t"):
@@ -133,7 +141,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str
val = lang.code_for_op[args](*operands, dtype)
assert child_count[u] != 0, f"childless ALU op found {u}"
# TODO: fix index rendering issue. fix clang nested max macro issue
if child_count[u] <= 1 and args not in {UnaryOps.RECIP, BinaryOps.MAX} and not getenv("EXPAND_SSA"): r[u] = val
if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
else: kk(f"{dtype.name} {ssa(u,'alu')} = {val};")
elif uop is UOps.SPECIAL:
kk(f"int {args[1]} = {lang.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")