mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 23:08:06 -05:00
rewrite recip to div (#3690)
* rewrite recip to div * fix bug in uops add
This commit is contained in:
@@ -107,7 +107,11 @@ class UOpGraph:
|
||||
key = (uop, dtype, vin, arg)
|
||||
if insert_before is None: insert_before = len(self.uops)
|
||||
# check if the cached expr is valid with the given insert place.
|
||||
if cachable and (expr:=self.saved_exprs.get(key, None)) is not None and self.uops.index(expr) <= insert_before: return expr
|
||||
try:
|
||||
if cachable and (expr:=self.saved_exprs.get(key, None)) is not None and self.uops.index(expr) <= insert_before: return expr
|
||||
except ValueError:
|
||||
# this happens if self.uops.index because the UOp was deleted
|
||||
pass
|
||||
ret = UOp(uop, dtype, vin, arg)
|
||||
self.uops.insert(insert_before, ret)
|
||||
if cachable: self.saved_exprs[key] = ret
|
||||
|
||||
@@ -30,7 +30,7 @@ class CStyleLanguage(NamedTuple):
|
||||
BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})",
|
||||
BinaryOps.DIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
|
||||
BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPEQ: lambda a,b,dtype: f"({a}=={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
|
||||
UnaryOps.RECIP: lambda x,dtype: f"(1.0/{x})", TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}
|
||||
TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}
|
||||
|
||||
# returns a str expression of the casted xs with the given type
|
||||
def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str:
|
||||
@@ -95,6 +95,14 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str
|
||||
depth = 1
|
||||
def kk(s): kernel.append(" "*depth+s)
|
||||
|
||||
# here we do a pretransform on UOps to fix some shortcomings of cstyle
|
||||
# RECIP is annoying with the type of the const, so we transform it into DIV
|
||||
for u in uops:
|
||||
if u.uop is UOps.ALU and u.arg is UnaryOps.RECIP:
|
||||
const_1 = uops.add(UOps.CONST, u.dtype, arg=1.0, insert_before=uops.uops.index(u))
|
||||
u.arg = BinaryOps.DIV
|
||||
u.vin = (const_1, u.vin[0])
|
||||
|
||||
c: DefaultDict[str, int] = defaultdict(int)
|
||||
r: Dict[UOp, str] = {}
|
||||
def ssa(u, prefix="t"):
|
||||
@@ -133,7 +141,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str
|
||||
val = lang.code_for_op[args](*operands, dtype)
|
||||
assert child_count[u] != 0, f"childless ALU op found {u}"
|
||||
# TODO: fix index rendering issue. fix clang nested max macro issue
|
||||
if child_count[u] <= 1 and args not in {UnaryOps.RECIP, BinaryOps.MAX} and not getenv("EXPAND_SSA"): r[u] = val
|
||||
if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
|
||||
else: kk(f"{dtype.name} {ssa(u,'alu')} = {val};")
|
||||
elif uop is UOps.SPECIAL:
|
||||
kk(f"int {args[1]} = {lang.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
|
||||
|
||||
Reference in New Issue
Block a user