This commit is contained in:
George Hotz
2026-01-08 00:24:10 -08:00
parent 5e923ccb5e
commit 37b4751958
3 changed files with 2 additions and 1 deletions

View File

@@ -172,6 +172,7 @@ def expr(s: str) -> UOp:
# Direct UOp mappings for functions
if name in _DIRECT_OPS: return UOp(_DIRECT_OPS[name], srcs[0].dtype, srcs)
if name == 'fma': return UOp(Ops.MULACC, srcs[2].dtype, (srcs[0], srcs[1], srcs[2]))
if name == 'isNAN': return UOp(Ops.CMPNE, dtypes.bool, (srcs[0], srcs[0]))
output_dtype = _infer_fn_dtype(name, srcs)
return UOp(Ops.CUSTOM, output_dtype, srcs, arg=name)
# MEM[addr] -> CUSTOM('MEM', addr), MEM[addr].type -> BITCAST

View File

@@ -84,6 +84,7 @@ def _pr(n, d=0):
case UOp(Ops.NEG, _, (x,)): return f"-{_pr(x)}"
case UOp(Ops.XOR, _, (x,)) if len(n.src) == 1: return f"~{_pr(x)}"
case UOp(Ops.CMPEQ, _, (x,)) if len(n.src) == 1: return f"!{_pr(x)}"
case UOp(Ops.CMPNE, dtypes.bool, (a, b)) if a == b: return f"isNAN({_pr(a)})"
case UOp(_, _, (l, r), _) if n.op in _OP_SYMS:
sym = _OP_SYMS[n.op]
left, right = l, r

View File

@@ -234,7 +234,6 @@ def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
if name == 'clamp':
c = UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[0], a[1])), a[1], a[0]))
return UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[2], c)), a[2], c))
if name == 'isNAN': return UOp(Ops.CMPNE, dtypes.bool, (a[0], a[0]))
if name == 'isINF': return UOp(Ops.OR, dtypes.bool, (UOp(Ops.CMPEQ, dtypes.bool, (a[0], UOp.const(a[0].dtype, float('inf')))),
UOp(Ops.CMPEQ, dtypes.bool, (a[0], UOp.const(a[0].dtype, float('-inf'))))))
if name in ('isQuietNAN', 'isSignalNAN'):