mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
calls
This commit is contained in:
@@ -43,6 +43,10 @@ _UNOPS: dict[str, Ops] = {'-': Ops.NEG, '~': Ops.XOR, '!': Ops.CMPEQ}
|
||||
# Direct function -> UOp mappings (parsed directly, not as CUSTOM)
|
||||
# Keyed by the function name as written in the parsed source; the value is the Ops member emitted for it.
_DIRECT_OPS: dict[str, Ops] = {
  'trunc': Ops.TRUNC,
  'sqrt': Ops.SQRT,
  'exp2': Ops.EXP2,
  'log2': Ops.LOG2,
  'sin': Ops.SIN,
  'rcp': Ops.RECIPROCAL,
}
|
||||
|
||||
def _typed_const(src: UOp, val) -> UOp:
  """Build a constant UOp holding `val`, typed after `src`.

  When `src.dtype` is already known, the constant is created directly with that dtype.
  When `src.dtype` is `dtypes.void`, a deferred constant is returned instead: a CONST UOp
  of dtype void that keeps `src` as its single source.
  """
  if src.dtype != dtypes.void:
    return UOp.const(src.dtype, val)
  # dtype not known yet: carry src along as the type source of the deferred const
  return UOp(Ops.CONST, dtypes.void, (src,), val)
|
||||
|
||||
# Function return type inference for CUSTOM ops
|
||||
_BOOL_FNS = {'isNAN', 'isINF', 'isDENORM', 'isQuietNAN', 'isSignalNAN', 'isEven', 'LT_NEG_ZERO', 'GT_NEG_ZERO'}
|
||||
_PASSTHRU_FNS = {'abs', 'floor', 'fract', 'sqrt', 'sin', 'cos', 'trunc', 'fma', 'clamp', 'min', 'max', 'ldexp',
|
||||
@@ -173,6 +177,24 @@ def expr(s: str) -> UOp:
|
||||
if name in _DIRECT_OPS: return UOp(_DIRECT_OPS[name], srcs[0].dtype, srcs)
|
||||
if name == 'fma': return UOp(Ops.MULACC, srcs[2].dtype, (srcs[0], srcs[1], srcs[2]))
|
||||
if name == 'isNAN': return UOp(Ops.CMPNE, dtypes.bool, (srcs[0], srcs[0]))
|
||||
if name == 'rsqrt': return UOp(Ops.RECIPROCAL, srcs[0].dtype, (UOp(Ops.SQRT, srcs[0].dtype, (srcs[0],)),))
|
||||
if name == 'clamp':
|
||||
x, lo, hi = srcs[0], srcs[1], srcs[2]
|
||||
c = UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (x, lo)), lo, x))
|
||||
return UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (hi, c)), hi, c))
|
||||
if name == 'abs':
|
||||
x = srcs[0]
|
||||
return UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (x, _typed_const(x, 0))), UOp(Ops.NEG, x.dtype, (x,)), x))
|
||||
if name == 'cos':
|
||||
x = srcs[0]
|
||||
return UOp(Ops.SIN, x.dtype, (UOp(Ops.ADD, x.dtype, (x, _typed_const(x, 1.5707963267948966))),))
|
||||
if name == 'floor':
|
||||
x, trunc = srcs[0], UOp(Ops.TRUNC, srcs[0].dtype, (srcs[0],))
|
||||
return UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (x, trunc)), UOp(Ops.SUB, x.dtype, (trunc, _typed_const(x, 1))), trunc))
|
||||
if name == 'fract':
|
||||
x, trunc = srcs[0], UOp(Ops.TRUNC, srcs[0].dtype, (srcs[0],))
|
||||
floor = UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (x, trunc)), UOp(Ops.SUB, x.dtype, (trunc, _typed_const(x, 1))), trunc))
|
||||
return UOp(Ops.SUB, x.dtype, (x, floor))
|
||||
output_dtype = _infer_fn_dtype(name, srcs)
|
||||
return UOp(Ops.CUSTOM, output_dtype, srcs, arg=name)
|
||||
# MEM[addr] -> CUSTOM('MEM', addr), MEM[addr].type -> BITCAST
|
||||
|
||||
@@ -85,6 +85,9 @@ def _pr(n, d=0):
|
||||
case UOp(Ops.XOR, _, (x,)) if len(n.src) == 1: return f"~{_pr(x)}"
|
||||
case UOp(Ops.CMPEQ, _, (x,)) if len(n.src) == 1: return f"!{_pr(x)}"
|
||||
case UOp(Ops.CMPNE, dtypes.bool, (a, b)) if a == b: return f"isNAN({_pr(a)})"
|
||||
# fract(x) -> SUB(x, floor(x)) where floor(x) = WHERE(CMPLT(x, TRUNC(x)), SUB(TRUNC(x), 1), TRUNC(x))
|
||||
case UOp(Ops.SUB, _, (x1, UOp(Ops.WHERE, _, (UOp(Ops.CMPLT, _, (x2, UOp(Ops.TRUNC, _, (x3,)))), UOp(Ops.SUB, _, (UOp(Ops.TRUNC, _, (x4,)), UOp(Ops.CONST, _, _, c))), UOp(Ops.TRUNC, _, (x5,)))))) if c in (1, 1.0) and x1 == x2 == x3 == x4 == x5:
|
||||
return f"fract({_pr(x1)})"
|
||||
case UOp(_, _, (l, r), _) if n.op in _OP_SYMS:
|
||||
sym = _OP_SYMS[n.op]
|
||||
left, right = l, r
|
||||
@@ -92,12 +95,24 @@ def _pr(n, d=0):
|
||||
if n.tag == 'flipped' and n.op == Ops.CMPLE: sym, left, right = '>=', r, l
|
||||
if n.tag == '<>' and n.op == Ops.CMPNE: sym = '<>'
|
||||
return f"{_pr(left)} {sym} {_pr(right)}"
|
||||
# clamp(x, lo, hi) -> WHERE(CMPLT(hi, WHERE(CMPLT(x, lo), lo, x)), hi, WHERE(CMPLT(x, lo), lo, x))
|
||||
case UOp(Ops.WHERE, _, (UOp(Ops.CMPLT, _, (hi, UOp(Ops.WHERE, _, (UOp(Ops.CMPLT, _, (x1, lo1)), lo2, x2)))), hi2, UOp(Ops.WHERE, _, (UOp(Ops.CMPLT, _, (x3, lo3)), lo4, x4)))) if hi == hi2 and x1 == x2 == x3 == x4 and lo1 == lo2 == lo3 == lo4:
|
||||
return f"clamp({_pr(x1)}, {_pr(lo1)}, {_pr(hi)})"
|
||||
# abs(x) -> WHERE(CMPLT(x, 0), NEG(x), x)
|
||||
case UOp(Ops.WHERE, _, (UOp(Ops.CMPLT, _, (x1, UOp(Ops.CONST, _, _, c))), UOp(Ops.NEG, _, (x2,)), x3)) if c in (0, 0.0) and x1 == x2 == x3:
|
||||
return f"abs({_pr(x1)})"
|
||||
# floor(x) -> WHERE(CMPLT(x, TRUNC(x)), SUB(TRUNC(x), 1), TRUNC(x))
|
||||
case UOp(Ops.WHERE, _, (UOp(Ops.CMPLT, _, (x1, UOp(Ops.TRUNC, _, (x2,)))), UOp(Ops.SUB, _, (UOp(Ops.TRUNC, _, (x3,)), UOp(Ops.CONST, _, _, c))), UOp(Ops.TRUNC, _, (x4,)))) if c in (1, 1.0) and x1 == x2 == x3 == x4:
|
||||
return f"floor({_pr(x1)})"
|
||||
case UOp(Ops.WHERE, _, (c, t, f)): return f"{_pr(c)} ? {_pr(t)} : {_pr(f)}"
|
||||
case UOp(Ops.TRUNC, _, (x,)): return f"trunc({_pr(x)})"
|
||||
case UOp(Ops.SQRT, _, (x,)): return f"sqrt({_pr(x)})"
|
||||
case UOp(Ops.EXP2, _, (x,)): return f"exp2({_pr(x)})"
|
||||
case UOp(Ops.LOG2, _, (x,)): return f"log2({_pr(x)})"
|
||||
# cos(x) -> SIN(ADD(x, π/2))
|
||||
case UOp(Ops.SIN, _, (UOp(Ops.ADD, _, (x, UOp(Ops.CONST, _, _, c))),)) if abs(c - 1.5707963267948966) < 1e-10: return f"cos({_pr(x)})"
|
||||
case UOp(Ops.SIN, _, (x,)): return f"sin({_pr(x)})"
|
||||
case UOp(Ops.RECIPROCAL, _, (UOp(Ops.SQRT, _, (x,)),)): return f"rsqrt({_pr(x)})"
|
||||
case UOp(Ops.RECIPROCAL, _, (x,)): return f"rcp({_pr(x)})"
|
||||
case UOp(Ops.MULACC, _, (a, b, c)): return f"fma({_pr(a)}, {_pr(b)}, {_pr(c)})"
|
||||
case UOp(Ops.CUSTOM, _, args, 'MEM'): return f"MEM[{_pr(args[0])}]"
|
||||
|
||||
@@ -76,6 +76,9 @@ def _resolve_special_var(name: str, ctx: Ctx, hint: DType = None) -> UOp | None:
|
||||
def _expr(node: UOp, ctx: Ctx, hint: DType = None) -> UOp:
|
||||
"""Transform parsed UOp expression to resolved UOp."""
|
||||
match node:
|
||||
case UOp(Ops.CONST, dtypes.void, (type_src,), val): # Deferred const: infer type from type_src
|
||||
resolved_src = _expr(type_src, ctx, hint)
|
||||
return UOp.const(resolved_src.dtype, val)
|
||||
case UOp(Ops.CONST, dt, _, val):
|
||||
dt = dt if dt != dtypes.int32 or hint is None else hint
|
||||
return UOp.const(dtypes.float32 if isinstance(val, float) and dt not in FLOATS else dt, val)
|
||||
@@ -224,16 +227,6 @@ def _minmax(args: list[UOp], is_min: bool) -> UOp:
|
||||
|
||||
def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
|
||||
if name == 'MEM': return a[0]
|
||||
if name == 'abs': return UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[0], UOp.const(a[0].dtype, 0))), UOp(Ops.NEG, a[0].dtype, (a[0],)), a[0]))
|
||||
if name == 'cos': return UOp(Ops.SIN, a[0].dtype, (UOp(Ops.ADD, a[0].dtype, (a[0], UOp.const(a[0].dtype, 1.5707963267948966))),))
|
||||
if name == 'rsqrt': return UOp(Ops.RECIPROCAL, a[0].dtype, (UOp(Ops.SQRT, a[0].dtype, (a[0],)),))
|
||||
if name == 'floor':
|
||||
trunc = UOp(Ops.TRUNC, a[0].dtype, (a[0],))
|
||||
return UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[0], trunc)), UOp(Ops.SUB, a[0].dtype, (trunc, UOp.const(a[0].dtype, 1))), trunc))
|
||||
if name == 'fract': return UOp(Ops.SUB, a[0].dtype, (a[0], _transform_call('floor', a, hint)))
|
||||
if name == 'clamp':
|
||||
c = UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[0], a[1])), a[1], a[0]))
|
||||
return UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[2], c)), a[2], c))
|
||||
if name == 'isINF': return UOp(Ops.OR, dtypes.bool, (UOp(Ops.CMPEQ, dtypes.bool, (a[0], UOp.const(a[0].dtype, float('inf')))),
|
||||
UOp(Ops.CMPEQ, dtypes.bool, (a[0], UOp.const(a[0].dtype, float('-inf'))))))
|
||||
if name in ('isQuietNAN', 'isSignalNAN'):
|
||||
|
||||
Reference in New Issue
Block a user