This commit is contained in:
George Hotz
2026-01-08 00:55:57 -08:00
parent 37b4751958
commit d84db5851f
3 changed files with 40 additions and 10 deletions

View File

@@ -43,6 +43,10 @@ _UNOPS: dict[str, Ops] = {'-': Ops.NEG, '~': Ops.XOR, '!': Ops.CMPEQ}
# Direct function -> UOp mappings (parsed directly, not as CUSTOM)
_DIRECT_OPS: dict[str, Ops] = {'trunc': Ops.TRUNC, 'sqrt': Ops.SQRT, 'exp2': Ops.EXP2, 'log2': Ops.LOG2, 'sin': Ops.SIN, 'rcp': Ops.RECIPROCAL}
def _typed_const(src: UOp, val) -> UOp:
"""Create a const with same dtype as src, or a deferred const if src.dtype is void."""
return UOp.const(src.dtype, val) if src.dtype != dtypes.void else UOp(Ops.CONST, dtypes.void, (src,), val)
# Function return type inference for CUSTOM ops
_BOOL_FNS = {'isNAN', 'isINF', 'isDENORM', 'isQuietNAN', 'isSignalNAN', 'isEven', 'LT_NEG_ZERO', 'GT_NEG_ZERO'}
_PASSTHRU_FNS = {'abs', 'floor', 'fract', 'sqrt', 'sin', 'cos', 'trunc', 'fma', 'clamp', 'min', 'max', 'ldexp',
@@ -173,6 +177,24 @@ def expr(s: str) -> UOp:
if name in _DIRECT_OPS: return UOp(_DIRECT_OPS[name], srcs[0].dtype, srcs)
if name == 'fma': return UOp(Ops.MULACC, srcs[2].dtype, (srcs[0], srcs[1], srcs[2]))
if name == 'isNAN': return UOp(Ops.CMPNE, dtypes.bool, (srcs[0], srcs[0]))
if name == 'rsqrt': return UOp(Ops.RECIPROCAL, srcs[0].dtype, (UOp(Ops.SQRT, srcs[0].dtype, (srcs[0],)),))
if name == 'clamp':
x, lo, hi = srcs[0], srcs[1], srcs[2]
c = UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (x, lo)), lo, x))
return UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (hi, c)), hi, c))
if name == 'abs':
x = srcs[0]
return UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (x, _typed_const(x, 0))), UOp(Ops.NEG, x.dtype, (x,)), x))
if name == 'cos':
x = srcs[0]
return UOp(Ops.SIN, x.dtype, (UOp(Ops.ADD, x.dtype, (x, _typed_const(x, 1.5707963267948966))),))
if name == 'floor':
x, trunc = srcs[0], UOp(Ops.TRUNC, srcs[0].dtype, (srcs[0],))
return UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (x, trunc)), UOp(Ops.SUB, x.dtype, (trunc, _typed_const(x, 1))), trunc))
if name == 'fract':
x, trunc = srcs[0], UOp(Ops.TRUNC, srcs[0].dtype, (srcs[0],))
floor = UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (x, trunc)), UOp(Ops.SUB, x.dtype, (trunc, _typed_const(x, 1))), trunc))
return UOp(Ops.SUB, x.dtype, (x, floor))
output_dtype = _infer_fn_dtype(name, srcs)
return UOp(Ops.CUSTOM, output_dtype, srcs, arg=name)
# MEM[addr] -> CUSTOM('MEM', addr), MEM[addr].type -> BITCAST

View File

@@ -85,6 +85,9 @@ def _pr(n, d=0):
case UOp(Ops.XOR, _, (x,)) if len(n.src) == 1: return f"~{_pr(x)}"
case UOp(Ops.CMPEQ, _, (x,)) if len(n.src) == 1: return f"!{_pr(x)}"
case UOp(Ops.CMPNE, dtypes.bool, (a, b)) if a == b: return f"isNAN({_pr(a)})"
# fract(x) -> SUB(x, floor(x)) where floor(x) = WHERE(CMPLT(x, TRUNC(x)), SUB(TRUNC(x), 1), TRUNC(x))
case UOp(Ops.SUB, _, (x1, UOp(Ops.WHERE, _, (UOp(Ops.CMPLT, _, (x2, UOp(Ops.TRUNC, _, (x3,)))), UOp(Ops.SUB, _, (UOp(Ops.TRUNC, _, (x4,)), UOp(Ops.CONST, _, _, c))), UOp(Ops.TRUNC, _, (x5,)))))) if c in (1, 1.0) and x1 == x2 == x3 == x4 == x5:
return f"fract({_pr(x1)})"
case UOp(_, _, (l, r), _) if n.op in _OP_SYMS:
sym = _OP_SYMS[n.op]
left, right = l, r
@@ -92,12 +95,24 @@ def _pr(n, d=0):
if n.tag == 'flipped' and n.op == Ops.CMPLE: sym, left, right = '>=', r, l
if n.tag == '<>' and n.op == Ops.CMPNE: sym = '<>'
return f"{_pr(left)} {sym} {_pr(right)}"
# clamp(x, lo, hi) -> WHERE(CMPLT(hi, WHERE(CMPLT(x, lo), lo, x)), hi, WHERE(CMPLT(x, lo), lo, x))
case UOp(Ops.WHERE, _, (UOp(Ops.CMPLT, _, (hi, UOp(Ops.WHERE, _, (UOp(Ops.CMPLT, _, (x1, lo1)), lo2, x2)))), hi2, UOp(Ops.WHERE, _, (UOp(Ops.CMPLT, _, (x3, lo3)), lo4, x4)))) if hi == hi2 and x1 == x2 == x3 == x4 and lo1 == lo2 == lo3 == lo4:
return f"clamp({_pr(x1)}, {_pr(lo1)}, {_pr(hi)})"
# abs(x) -> WHERE(CMPLT(x, 0), NEG(x), x)
case UOp(Ops.WHERE, _, (UOp(Ops.CMPLT, _, (x1, UOp(Ops.CONST, _, _, c))), UOp(Ops.NEG, _, (x2,)), x3)) if c in (0, 0.0) and x1 == x2 == x3:
return f"abs({_pr(x1)})"
# floor(x) -> WHERE(CMPLT(x, TRUNC(x)), SUB(TRUNC(x), 1), TRUNC(x))
case UOp(Ops.WHERE, _, (UOp(Ops.CMPLT, _, (x1, UOp(Ops.TRUNC, _, (x2,)))), UOp(Ops.SUB, _, (UOp(Ops.TRUNC, _, (x3,)), UOp(Ops.CONST, _, _, c))), UOp(Ops.TRUNC, _, (x4,)))) if c in (1, 1.0) and x1 == x2 == x3 == x4:
return f"floor({_pr(x1)})"
case UOp(Ops.WHERE, _, (c, t, f)): return f"{_pr(c)} ? {_pr(t)} : {_pr(f)}"
case UOp(Ops.TRUNC, _, (x,)): return f"trunc({_pr(x)})"
case UOp(Ops.SQRT, _, (x,)): return f"sqrt({_pr(x)})"
case UOp(Ops.EXP2, _, (x,)): return f"exp2({_pr(x)})"
case UOp(Ops.LOG2, _, (x,)): return f"log2({_pr(x)})"
# cos(x) -> SIN(ADD(x, π/2))
case UOp(Ops.SIN, _, (UOp(Ops.ADD, _, (x, UOp(Ops.CONST, _, _, c))),)) if abs(c - 1.5707963267948966) < 1e-10: return f"cos({_pr(x)})"
case UOp(Ops.SIN, _, (x,)): return f"sin({_pr(x)})"
case UOp(Ops.RECIPROCAL, _, (UOp(Ops.SQRT, _, (x,)),)): return f"rsqrt({_pr(x)})"
case UOp(Ops.RECIPROCAL, _, (x,)): return f"rcp({_pr(x)})"
case UOp(Ops.MULACC, _, (a, b, c)): return f"fma({_pr(a)}, {_pr(b)}, {_pr(c)})"
case UOp(Ops.CUSTOM, _, args, 'MEM'): return f"MEM[{_pr(args[0])}]"

View File

@@ -76,6 +76,9 @@ def _resolve_special_var(name: str, ctx: Ctx, hint: DType = None) -> UOp | None:
def _expr(node: UOp, ctx: Ctx, hint: DType = None) -> UOp:
"""Transform parsed UOp expression to resolved UOp."""
match node:
case UOp(Ops.CONST, dtypes.void, (type_src,), val): # Deferred const: infer type from type_src
resolved_src = _expr(type_src, ctx, hint)
return UOp.const(resolved_src.dtype, val)
case UOp(Ops.CONST, dt, _, val):
dt = dt if dt != dtypes.int32 or hint is None else hint
return UOp.const(dtypes.float32 if isinstance(val, float) and dt not in FLOATS else dt, val)
@@ -224,16 +227,6 @@ def _minmax(args: list[UOp], is_min: bool) -> UOp:
def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
if name == 'MEM': return a[0]
if name == 'abs': return UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[0], UOp.const(a[0].dtype, 0))), UOp(Ops.NEG, a[0].dtype, (a[0],)), a[0]))
if name == 'cos': return UOp(Ops.SIN, a[0].dtype, (UOp(Ops.ADD, a[0].dtype, (a[0], UOp.const(a[0].dtype, 1.5707963267948966))),))
if name == 'rsqrt': return UOp(Ops.RECIPROCAL, a[0].dtype, (UOp(Ops.SQRT, a[0].dtype, (a[0],)),))
if name == 'floor':
trunc = UOp(Ops.TRUNC, a[0].dtype, (a[0],))
return UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[0], trunc)), UOp(Ops.SUB, a[0].dtype, (trunc, UOp.const(a[0].dtype, 1))), trunc))
if name == 'fract': return UOp(Ops.SUB, a[0].dtype, (a[0], _transform_call('floor', a, hint)))
if name == 'clamp':
c = UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[0], a[1])), a[1], a[0]))
return UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[2], c)), a[2], c))
if name == 'isINF': return UOp(Ops.OR, dtypes.bool, (UOp(Ops.CMPEQ, dtypes.bool, (a[0], UOp.const(a[0].dtype, float('inf')))),
UOp(Ops.CMPEQ, dtypes.bool, (a[0], UOp.const(a[0].dtype, float('-inf'))))))
if name in ('isQuietNAN', 'isSignalNAN'):