mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
move more
This commit is contained in:
@@ -195,6 +195,14 @@ def expr(s: str) -> UOp:
|
||||
x, trunc = srcs[0], UOp(Ops.TRUNC, srcs[0].dtype, (srcs[0],))
|
||||
floor = UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (x, trunc)), UOp(Ops.SUB, x.dtype, (trunc, _typed_const(x, 1))), trunc))
|
||||
return UOp(Ops.SUB, x.dtype, (x, floor))
|
||||
if name == 'isINF':
|
||||
x = srcs[0]
|
||||
return UOp(Ops.OR, dtypes.bool, (UOp(Ops.CMPEQ, dtypes.bool, (x, _typed_const(x, float('inf')))),
|
||||
UOp(Ops.CMPEQ, dtypes.bool, (x, _typed_const(x, float('-inf'))))))
|
||||
if name in ('min', 'max'):
|
||||
a, b = srcs[0], srcs[1]
|
||||
cmp = UOp(Ops.CMPLT, dtypes.bool, (a, b) if name == 'min' else (b, a))
|
||||
return UOp(Ops.WHERE, a.dtype, (cmp, a, b))
|
||||
output_dtype = _infer_fn_dtype(name, srcs)
|
||||
return UOp(Ops.CUSTOM, output_dtype, srcs, arg=name)
|
||||
# MEM[addr] -> CUSTOM('MEM', addr), MEM[addr].type -> BITCAST
|
||||
|
||||
@@ -85,6 +85,9 @@ def _pr(n, d=0):
|
||||
case UOp(Ops.XOR, _, (x,)) if len(n.src) == 1: return f"~{_pr(x)}"
|
||||
case UOp(Ops.CMPEQ, _, (x,)) if len(n.src) == 1: return f"!{_pr(x)}"
|
||||
case UOp(Ops.CMPNE, dtypes.bool, (a, b)) if a == b: return f"isNAN({_pr(a)})"
|
||||
# isINF(x) -> OR(CMPEQ(x, +inf), CMPEQ(x, -inf))
|
||||
case UOp(Ops.OR, dtypes.bool, (UOp(Ops.CMPEQ, _, (x1, UOp(Ops.CONST, _, _, c1))), UOp(Ops.CMPEQ, _, (x2, UOp(Ops.CONST, _, _, c2))))) if x1 == x2 and c1 == float('inf') and c2 == float('-inf'):
|
||||
return f"isINF({_pr(x1)})"
|
||||
# fract(x) -> SUB(x, floor(x)) where floor(x) = WHERE(CMPLT(x, TRUNC(x)), SUB(TRUNC(x), 1), TRUNC(x))
|
||||
case UOp(Ops.SUB, _, (x1, UOp(Ops.WHERE, _, (UOp(Ops.CMPLT, _, (x2, UOp(Ops.TRUNC, _, (x3,)))), UOp(Ops.SUB, _, (UOp(Ops.TRUNC, _, (x4,)), UOp(Ops.CONST, _, _, c))), UOp(Ops.TRUNC, _, (x5,)))))) if c in (1, 1.0) and x1 == x2 == x3 == x4 == x5:
|
||||
return f"fract({_pr(x1)})"
|
||||
|
||||
@@ -227,8 +227,6 @@ def _minmax(args: list[UOp], is_min: bool) -> UOp:
|
||||
|
||||
def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
|
||||
if name == 'MEM': return a[0]
|
||||
if name == 'isINF': return UOp(Ops.OR, dtypes.bool, (UOp(Ops.CMPEQ, dtypes.bool, (a[0], UOp.const(a[0].dtype, float('inf')))),
|
||||
UOp(Ops.CMPEQ, dtypes.bool, (a[0], UOp.const(a[0].dtype, float('-inf'))))))
|
||||
if name in ('isQuietNAN', 'isSignalNAN'):
|
||||
bits, exp_shift, exp_mask, mant_mask = _fp_bits(a[0])
|
||||
# Use the dtype from bits (uint32/uint64/uint16) to determine which quiet bit to use
|
||||
@@ -295,7 +293,6 @@ def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
|
||||
assert a[0].op == Ops.CONST and a[0].arg == 2.0
|
||||
return UOp(Ops.EXP2, a[0].dtype, (a[1] if a[1].dtype == a[0].dtype else UOp(Ops.CAST, a[0].dtype, (a[1],)),))
|
||||
if name == 'ldexp': return UOp(Ops.MUL, a[0].dtype, (a[0], UOp(Ops.EXP2, a[0].dtype, (UOp(Ops.CAST, a[0].dtype, (a[1],)),))))
|
||||
if name in ('min', 'max'): return _minmax(a, is_min=(name == 'min'))
|
||||
if name in CVT_MAP:
|
||||
dt, clamp = CVT_MAP[name]
|
||||
v = UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[0], UOp.const(a[0].dtype, 0.0))), UOp.const(a[0].dtype, 0.0), a[0])) if clamp else a[0]
|
||||
|
||||
Reference in New Issue
Block a user