mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
isNAN
This commit is contained in:
@@ -172,6 +172,7 @@ def expr(s: str) -> UOp:
|
||||
# Direct UOp mappings for functions
|
||||
if name in _DIRECT_OPS: return UOp(_DIRECT_OPS[name], srcs[0].dtype, srcs)
|
||||
if name == 'fma': return UOp(Ops.MULACC, srcs[2].dtype, (srcs[0], srcs[1], srcs[2]))
|
||||
if name == 'isNAN': return UOp(Ops.CMPNE, dtypes.bool, (srcs[0], srcs[0]))
|
||||
output_dtype = _infer_fn_dtype(name, srcs)
|
||||
return UOp(Ops.CUSTOM, output_dtype, srcs, arg=name)
|
||||
# MEM[addr] -> CUSTOM('MEM', addr), MEM[addr].type -> BITCAST
|
||||
|
||||
@@ -84,6 +84,7 @@ def _pr(n, d=0):
|
||||
case UOp(Ops.NEG, _, (x,)): return f"-{_pr(x)}"
|
||||
case UOp(Ops.XOR, _, (x,)) if len(n.src) == 1: return f"~{_pr(x)}"
|
||||
case UOp(Ops.CMPEQ, _, (x,)) if len(n.src) == 1: return f"!{_pr(x)}"
|
||||
case UOp(Ops.CMPNE, dtypes.bool, (a, b)) if a == b: return f"isNAN({_pr(a)})"
|
||||
case UOp(_, _, (l, r), _) if n.op in _OP_SYMS:
|
||||
sym = _OP_SYMS[n.op]
|
||||
left, right = l, r
|
||||
|
||||
@@ -234,7 +234,6 @@ def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
|
||||
if name == 'clamp':
|
||||
c = UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[0], a[1])), a[1], a[0]))
|
||||
return UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[2], c)), a[2], c))
|
||||
if name == 'isNAN': return UOp(Ops.CMPNE, dtypes.bool, (a[0], a[0]))
|
||||
if name == 'isINF': return UOp(Ops.OR, dtypes.bool, (UOp(Ops.CMPEQ, dtypes.bool, (a[0], UOp.const(a[0].dtype, float('inf')))),
|
||||
UOp(Ops.CMPEQ, dtypes.bool, (a[0], UOp.const(a[0].dtype, float('-inf'))))))
|
||||
if name in ('isQuietNAN', 'isSignalNAN'):
|
||||
|
||||
Reference in New Issue
Block a user