move more: isINF and min/max lowering from _transform_call into expr, plus isINF pretty-printing in _pr

This commit is contained in:
George Hotz
2026-01-08 04:18:46 -08:00
parent d84db5851f
commit 4a7456caef
3 changed files with 11 additions and 3 deletions

View File

@@ -195,6 +195,14 @@ def expr(s: str) -> UOp:
x, trunc = srcs[0], UOp(Ops.TRUNC, srcs[0].dtype, (srcs[0],))
floor = UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (x, trunc)), UOp(Ops.SUB, x.dtype, (trunc, _typed_const(x, 1))), trunc))
return UOp(Ops.SUB, x.dtype, (x, floor))
if name == 'isINF':
x = srcs[0]
return UOp(Ops.OR, dtypes.bool, (UOp(Ops.CMPEQ, dtypes.bool, (x, _typed_const(x, float('inf')))),
UOp(Ops.CMPEQ, dtypes.bool, (x, _typed_const(x, float('-inf'))))))
if name in ('min', 'max'):
a, b = srcs[0], srcs[1]
cmp = UOp(Ops.CMPLT, dtypes.bool, (a, b) if name == 'min' else (b, a))
return UOp(Ops.WHERE, a.dtype, (cmp, a, b))
output_dtype = _infer_fn_dtype(name, srcs)
return UOp(Ops.CUSTOM, output_dtype, srcs, arg=name)
# MEM[addr] -> CUSTOM('MEM', addr), MEM[addr].type -> BITCAST

View File

@@ -85,6 +85,9 @@ def _pr(n, d=0):
case UOp(Ops.XOR, _, (x,)) if len(n.src) == 1: return f"~{_pr(x)}"
case UOp(Ops.CMPEQ, _, (x,)) if len(n.src) == 1: return f"!{_pr(x)}"
case UOp(Ops.CMPNE, dtypes.bool, (a, b)) if a == b: return f"isNAN({_pr(a)})"
# isINF(x) -> OR(CMPEQ(x, +inf), CMPEQ(x, -inf))
case UOp(Ops.OR, dtypes.bool, (UOp(Ops.CMPEQ, _, (x1, UOp(Ops.CONST, _, _, c1))), UOp(Ops.CMPEQ, _, (x2, UOp(Ops.CONST, _, _, c2))))) if x1 == x2 and c1 == float('inf') and c2 == float('-inf'):
return f"isINF({_pr(x1)})"
# fract(x) -> SUB(x, floor(x)) where floor(x) = WHERE(CMPLT(x, TRUNC(x)), SUB(TRUNC(x), 1), TRUNC(x))
case UOp(Ops.SUB, _, (x1, UOp(Ops.WHERE, _, (UOp(Ops.CMPLT, _, (x2, UOp(Ops.TRUNC, _, (x3,)))), UOp(Ops.SUB, _, (UOp(Ops.TRUNC, _, (x4,)), UOp(Ops.CONST, _, _, c))), UOp(Ops.TRUNC, _, (x5,)))))) if c in (1, 1.0) and x1 == x2 == x3 == x4 == x5:
return f"fract({_pr(x1)})"

View File

@@ -227,8 +227,6 @@ def _minmax(args: list[UOp], is_min: bool) -> UOp:
def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
if name == 'MEM': return a[0]
if name == 'isINF': return UOp(Ops.OR, dtypes.bool, (UOp(Ops.CMPEQ, dtypes.bool, (a[0], UOp.const(a[0].dtype, float('inf')))),
UOp(Ops.CMPEQ, dtypes.bool, (a[0], UOp.const(a[0].dtype, float('-inf'))))))
if name in ('isQuietNAN', 'isSignalNAN'):
bits, exp_shift, exp_mask, mant_mask = _fp_bits(a[0])
# Use the dtype from bits (uint32/uint64/uint16) to determine which quiet bit to use
@@ -295,7 +293,6 @@ def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
assert a[0].op == Ops.CONST and a[0].arg == 2.0
return UOp(Ops.EXP2, a[0].dtype, (a[1] if a[1].dtype == a[0].dtype else UOp(Ops.CAST, a[0].dtype, (a[1],)),))
if name == 'ldexp': return UOp(Ops.MUL, a[0].dtype, (a[0], UOp(Ops.EXP2, a[0].dtype, (UOp(Ops.CAST, a[0].dtype, (a[1],)),))))
if name in ('min', 'max'): return _minmax(a, is_min=(name == 'min'))
if name in CVT_MAP:
dt, clamp = CVT_MAP[name]
v = UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[0], UOp.const(a[0].dtype, 0.0))), UOp.const(a[0].dtype, 0.0), a[0])) if clamp else a[0]