mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-19 02:44:40 -05:00
rewrite recip to div (#3690)
* rewrite recip to div
* fix bug in uops add
This commit is contained in:
@@ -30,7 +30,7 @@ class CStyleLanguage(NamedTuple):
|
||||
BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})",
|
||||
BinaryOps.DIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
|
||||
BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPEQ: lambda a,b,dtype: f"({a}=={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
|
||||
UnaryOps.RECIP: lambda x,dtype: f"(1.0/{x})", TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}
|
||||
TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}
|
||||
|
||||
# returns a str expression of the casted xs with the given type
|
||||
def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str:
|
||||
@@ -95,6 +95,14 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str
|
||||
depth = 1
|
||||
def kk(s): kernel.append(" "*depth+s)
|
||||
|
||||
# here we do a pretransform on UOps to fix some shortcomings of cstyle
|
||||
# RECIP is annoying with the type of the const, so we transform it into DIV
|
||||
for u in uops:
|
||||
if u.uop is UOps.ALU and u.arg is UnaryOps.RECIP:
|
||||
const_1 = uops.add(UOps.CONST, u.dtype, arg=1.0, insert_before=uops.uops.index(u))
|
||||
u.arg = BinaryOps.DIV
|
||||
u.vin = (const_1, u.vin[0])
|
||||
|
||||
c: DefaultDict[str, int] = defaultdict(int)
|
||||
r: Dict[UOp, str] = {}
|
||||
def ssa(u, prefix="t"):
|
||||
@@ -133,7 +141,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str
|
||||
val = lang.code_for_op[args](*operands, dtype)
|
||||
assert child_count[u] != 0, f"childless ALU op found {u}"
|
||||
# TODO: fix index rendering issue. fix clang nested max macro issue
|
||||
if child_count[u] <= 1 and args not in {UnaryOps.RECIP, BinaryOps.MAX} and not getenv("EXPAND_SSA"): r[u] = val
|
||||
if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
|
||||
else: kk(f"{dtype.name} {ssa(u,'alu')} = {val};")
|
||||
elif uop is UOps.SPECIAL:
|
||||
kk(f"int {args[1]} = {lang.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
|
||||
|
||||
Reference in New Issue
Block a user