def _typed_const(src: UOp, val) -> UOp:
  """Build a constant holding `val` whose dtype follows `src`.

  If `src.dtype` is a concrete (non-void) dtype, the constant is created with that dtype right away.
  If `src.dtype` is void, a deferred CONST node is returned instead: its dtype is void, `src` is kept
  as its single source, and `val` is stored as its arg.
  """
  source_dtype = src.dtype
  if source_dtype != dtypes.void:
    # Concrete dtype available: materialize the constant immediately.
    return UOp.const(source_dtype, val)
  # Void dtype: keep `src` attached to the CONST node instead of choosing a dtype here.
  return UOp(Ops.CONST, dtypes.void, (src,), val)
return UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (x, trunc)), UOp(Ops.SUB, x.dtype, (trunc, _typed_const(x, 1))), trunc)) + if name == 'fract': + x, trunc = srcs[0], UOp(Ops.TRUNC, srcs[0].dtype, (srcs[0],)) + floor = UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (x, trunc)), UOp(Ops.SUB, x.dtype, (trunc, _typed_const(x, 1))), trunc)) + return UOp(Ops.SUB, x.dtype, (x, floor)) output_dtype = _infer_fn_dtype(name, srcs) return UOp(Ops.CUSTOM, output_dtype, srcs, arg=name) # MEM[addr] -> CUSTOM('MEM', addr), MEM[addr].type -> BITCAST diff --git a/extra/assembly/amd/test/test_pcode_parse.py b/extra/assembly/amd/test/test_pcode_parse.py index 7d9bc360f6..6a0053421f 100644 --- a/extra/assembly/amd/test/test_pcode_parse.py +++ b/extra/assembly/amd/test/test_pcode_parse.py @@ -85,6 +85,9 @@ def _pr(n, d=0): case UOp(Ops.XOR, _, (x,)) if len(n.src) == 1: return f"~{_pr(x)}" case UOp(Ops.CMPEQ, _, (x,)) if len(n.src) == 1: return f"!{_pr(x)}" case UOp(Ops.CMPNE, dtypes.bool, (a, b)) if a == b: return f"isNAN({_pr(a)})" + # fract(x) -> SUB(x, floor(x)) where floor(x) = WHERE(CMPLT(x, TRUNC(x)), SUB(TRUNC(x), 1), TRUNC(x)) + case UOp(Ops.SUB, _, (x1, UOp(Ops.WHERE, _, (UOp(Ops.CMPLT, _, (x2, UOp(Ops.TRUNC, _, (x3,)))), UOp(Ops.SUB, _, (UOp(Ops.TRUNC, _, (x4,)), UOp(Ops.CONST, _, _, c))), UOp(Ops.TRUNC, _, (x5,)))))) if c in (1, 1.0) and x1 == x2 == x3 == x4 == x5: + return f"fract({_pr(x1)})" case UOp(_, _, (l, r), _) if n.op in _OP_SYMS: sym = _OP_SYMS[n.op] left, right = l, r @@ -92,12 +95,24 @@ def _pr(n, d=0): if n.tag == 'flipped' and n.op == Ops.CMPLE: sym, left, right = '>=', r, l if n.tag == '<>' and n.op == Ops.CMPNE: sym = '<>' return f"{_pr(left)} {sym} {_pr(right)}" + # clamp(x, lo, hi) -> WHERE(CMPLT(hi, WHERE(CMPLT(x, lo), lo, x)), hi, WHERE(CMPLT(x, lo), lo, x)) + case UOp(Ops.WHERE, _, (UOp(Ops.CMPLT, _, (hi, UOp(Ops.WHERE, _, (UOp(Ops.CMPLT, _, (x1, lo1)), lo2, x2)))), hi2, UOp(Ops.WHERE, _, (UOp(Ops.CMPLT, _, (x3, lo3)), lo4, x4)))) if hi 
== hi2 and x1 == x2 == x3 == x4 and lo1 == lo2 == lo3 == lo4: + return f"clamp({_pr(x1)}, {_pr(lo1)}, {_pr(hi)})" + # abs(x) -> WHERE(CMPLT(x, 0), NEG(x), x) + case UOp(Ops.WHERE, _, (UOp(Ops.CMPLT, _, (x1, UOp(Ops.CONST, _, _, c))), UOp(Ops.NEG, _, (x2,)), x3)) if c in (0, 0.0) and x1 == x2 == x3: + return f"abs({_pr(x1)})" + # floor(x) -> WHERE(CMPLT(x, TRUNC(x)), SUB(TRUNC(x), 1), TRUNC(x)) + case UOp(Ops.WHERE, _, (UOp(Ops.CMPLT, _, (x1, UOp(Ops.TRUNC, _, (x2,)))), UOp(Ops.SUB, _, (UOp(Ops.TRUNC, _, (x3,)), UOp(Ops.CONST, _, _, c))), UOp(Ops.TRUNC, _, (x4,)))) if c in (1, 1.0) and x1 == x2 == x3 == x4: + return f"floor({_pr(x1)})" case UOp(Ops.WHERE, _, (c, t, f)): return f"{_pr(c)} ? {_pr(t)} : {_pr(f)}" case UOp(Ops.TRUNC, _, (x,)): return f"trunc({_pr(x)})" case UOp(Ops.SQRT, _, (x,)): return f"sqrt({_pr(x)})" case UOp(Ops.EXP2, _, (x,)): return f"exp2({_pr(x)})" case UOp(Ops.LOG2, _, (x,)): return f"log2({_pr(x)})" + # cos(x) -> SIN(ADD(x, π/2)) + case UOp(Ops.SIN, _, (UOp(Ops.ADD, _, (x, UOp(Ops.CONST, _, _, c))),)) if abs(c - 1.5707963267948966) < 1e-10: return f"cos({_pr(x)})" case UOp(Ops.SIN, _, (x,)): return f"sin({_pr(x)})" + case UOp(Ops.RECIPROCAL, _, (UOp(Ops.SQRT, _, (x,)),)): return f"rsqrt({_pr(x)})" case UOp(Ops.RECIPROCAL, _, (x,)): return f"rcp({_pr(x)})" case UOp(Ops.MULACC, _, (a, b, c)): return f"fma({_pr(a)}, {_pr(b)}, {_pr(c)})" case UOp(Ops.CUSTOM, _, args, 'MEM'): return f"MEM[{_pr(args[0])}]" diff --git a/extra/assembly/amd/ucode.py b/extra/assembly/amd/ucode.py index f5ce2dd86e..a54b18b6ec 100644 --- a/extra/assembly/amd/ucode.py +++ b/extra/assembly/amd/ucode.py @@ -76,6 +76,9 @@ def _resolve_special_var(name: str, ctx: Ctx, hint: DType = None) -> UOp | None: def _expr(node: UOp, ctx: Ctx, hint: DType = None) -> UOp: """Transform parsed UOp expression to resolved UOp.""" match node: + case UOp(Ops.CONST, dtypes.void, (type_src,), val): # Deferred const: infer type from type_src + resolved_src = _expr(type_src, ctx, hint) + return 
UOp.const(resolved_src.dtype, val) case UOp(Ops.CONST, dt, _, val): dt = dt if dt != dtypes.int32 or hint is None else hint return UOp.const(dtypes.float32 if isinstance(val, float) and dt not in FLOATS else dt, val) @@ -224,16 +227,6 @@ def _minmax(args: list[UOp], is_min: bool) -> UOp: def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp: if name == 'MEM': return a[0] - if name == 'abs': return UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[0], UOp.const(a[0].dtype, 0))), UOp(Ops.NEG, a[0].dtype, (a[0],)), a[0])) - if name == 'cos': return UOp(Ops.SIN, a[0].dtype, (UOp(Ops.ADD, a[0].dtype, (a[0], UOp.const(a[0].dtype, 1.5707963267948966))),)) - if name == 'rsqrt': return UOp(Ops.RECIPROCAL, a[0].dtype, (UOp(Ops.SQRT, a[0].dtype, (a[0],)),)) - if name == 'floor': - trunc = UOp(Ops.TRUNC, a[0].dtype, (a[0],)) - return UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[0], trunc)), UOp(Ops.SUB, a[0].dtype, (trunc, UOp.const(a[0].dtype, 1))), trunc)) - if name == 'fract': return UOp(Ops.SUB, a[0].dtype, (a[0], _transform_call('floor', a, hint))) - if name == 'clamp': - c = UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[0], a[1])), a[1], a[0])) - return UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[2], c)), a[2], c)) if name == 'isINF': return UOp(Ops.OR, dtypes.bool, (UOp(Ops.CMPEQ, dtypes.bool, (a[0], UOp.const(a[0].dtype, float('inf')))), UOp(Ops.CMPEQ, dtypes.bool, (a[0], UOp.const(a[0].dtype, float('-inf')))))) if name in ('isQuietNAN', 'isSignalNAN'):