From dea2a56dd00bb2e3cd908c50c4d558a3db205225 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sun, 11 Jan 2026 11:03:05 +0900 Subject: [PATCH] more in pcode_transform --- extra/assembly/amd/pcode_transform.py | 39 +++++++++++++++++++-------- extra/assembly/amd/ucode.py | 11 -------- 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/extra/assembly/amd/pcode_transform.py b/extra/assembly/amd/pcode_transform.py index dd85dae045..fe12af3338 100644 --- a/extra/assembly/amd/pcode_transform.py +++ b/extra/assembly/amd/pcode_transform.py @@ -19,13 +19,12 @@ _SPECIAL_CASTS = {'f32_to_u32', 'f64_to_u32', 'f16_to_u32', 'f32_to_u64', 'f64_t _CAST_MAP = {f'{s}_to_{d}': _DT_SUFFIX[d] for s in _DT_SUFFIX for d in _DT_SUFFIX if s != d and f'{s}_to_{d}' not in _SPECIAL_CASTS} _CAST_MAP.update({f'v_cvt_{d}_{s}': _DT_SUFFIX[d] for s in _DT_SUFFIX for d in _DT_SUFFIX if s != d and f'{s}_to_{d}' not in _SPECIAL_CASTS}) -# CUSTOM op return types +# CUSTOM op return types (ops that stay as CUSTOM after transformation) _CUSTOM_TYPES = { - 'isDENORM': dtypes.bool, 'isQuietNAN': dtypes.bool, 'isSignalNAN': dtypes.bool, 'isEven': dtypes.bool, - 'LT_NEG_ZERO': dtypes.bool, 'GT_NEG_ZERO': dtypes.bool, - 'sign': dtypes.uint32, 'exponent': dtypes.uint32, 'ABSDIFF': dtypes.uint32, 'SAT8': dtypes.uint32, - 'BYTE_PERMUTE': dtypes.uint32, 'count_ones': dtypes.uint32, 'countbits': dtypes.uint32, 'reverse_bits': dtypes.uint32, - 'u8_to_u32': dtypes.uint32, 'u4_to_u32': dtypes.uint32, 's_ff1_i32_b32': dtypes.uint32, 's_ff1_i32_b64': dtypes.uint32, + 'isDENORM': dtypes.bool, 'isQuietNAN': dtypes.bool, 'isSignalNAN': dtypes.bool, + 'sign': dtypes.uint32, 'exponent': dtypes.uint32, 'BYTE_PERMUTE': dtypes.uint32, + 'count_ones': dtypes.uint32, 'countbits': dtypes.uint32, 'reverse_bits': dtypes.uint32, + 's_ff1_i32_b32': dtypes.uint32, 's_ff1_i32_b64': dtypes.uint32, 'v_sad_u8': dtypes.uint32, 'v_msad_u8': dtypes.uint32, 'ConvertFromFormat': dtypes.uint32, 'nop': dtypes.uint32, 'f32_to_u32': dtypes.uint32, 'f64_to_u32': dtypes.uint32, 'signext_from_bit': dtypes.int64, 'f16_to_snorm': dtypes.int16, 'f16_to_unorm': dtypes.uint16, 'f32_to_snorm': dtypes.int16, 'f32_to_unorm': dtypes.uint16, @@ -62,8 +61,9 @@ def _vn(u: UOp) -> str|None: # var name if u.op == Ops.DEFINE_VAR: return u.arg[0] if isinstance(u.arg, tuple) else u.arg return _vn(u.src[0]) if u.op == Ops.CUSTOMI and u.src[0].op == Ops.DEFINE_VAR else None -def _typed_const(src: UOp, val) -> UOp: # const matching src type - return UOp.const(src.dtype, val) if src.dtype != dtypes.void else UOp(Ops.CONST, dtypes.void, (src,), val) + + + # ═══════════════════════════════════════════════════════════════════════════════ # PATTERN HANDLERS @@ -160,8 +160,8 @@ pcode_pm = PatternMatcher([ (UPat(Ops.CUSTOM, arg='sin', src=(_fpat,)), lambda x: UOp(Ops.SIN, x.dtype, (x,))), (UPat(Ops.CUSTOM, arg='rcp', src=(_fpat,)), lambda x: UOp(Ops.RECIPROCAL, x.dtype, (x,))), (UPat(Ops.CUSTOM, arg='fma', src=(_fpat, UPat.var('b'), UPat.var('c'))), lambda x, b, c: UOp(Ops.MULACC, x.dtype, (x, b, c))), - (UPat(Ops.CUSTOM, arg='abs', src=(_fpat,)), lambda x: UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (x, _typed_const(x, 0))), UOp(Ops.NEG, x.dtype, (x,)), x))), - (UPat(Ops.CUSTOM, arg='cos', src=(_fpat,)), lambda x: UOp(Ops.SIN, x.dtype, (UOp(Ops.ADD, x.dtype, (x, _typed_const(x, 1.5707963267948966))),))), + (UPat(Ops.CUSTOM, arg='abs', src=(_fpat,)), lambda x: UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (x, _tc(x, 0))), UOp(Ops.NEG, x.dtype, (x,)), x))), + (UPat(Ops.CUSTOM, arg='cos', src=(_fpat,)), lambda x: UOp(Ops.SIN, x.dtype, (UOp(Ops.ADD, x.dtype, (x, _tc(x, 1.5707963267948966))),))), (UPat(Ops.CUSTOM, arg='floor', src=(_fpat,)), lambda x: _floor(x)), (UPat(Ops.CUSTOM, arg='fract', src=(_fpat,)), lambda x: UOp(Ops.SUB, x.dtype, (x, _floor(x)))), (UPat(Ops.CUSTOM, arg='rsqrt', src=(_fpat,)), lambda x: UOp(Ops.RECIPROCAL, x.dtype, (UOp(Ops.SQRT, x.dtype, (x,)),))), @@ -186,7 +186,24 @@ pcode_pm = PatternMatcher([ # Boolean functions (UPat(Ops.CUSTOM, arg='isNAN', src=(UPat.var('x'),)), lambda x: UOp(Ops.CMPNE, dtypes.bool, (x, x))), (UPat(Ops.CUSTOM, arg='isINF', src=(UPat.var('x'),)), lambda x: UOp(Ops.OR, dtypes.bool, ( - UOp(Ops.CMPEQ, dtypes.bool, (x, _typed_const(x, float('inf')))), UOp(Ops.CMPEQ, dtypes.bool, (x, _typed_const(x, float('-inf'))))))), + UOp(Ops.CMPEQ, dtypes.bool, (x, _tc(x, float('inf')))), UOp(Ops.CMPEQ, dtypes.bool, (x, _tc(x, float('-inf'))))))), + # ABSDIFF: |a - b| for unsigned + (UPat(Ops.CUSTOM, arg='ABSDIFF', src=(UPat.var('a'), UPat.var('b'))), + lambda a, b: UOp(Ops.WHERE, a.dtype, (UOp(Ops.CMPLT, dtypes.bool, (b, a)), + UOp(Ops.SUB, a.dtype, (a, b)), UOp(Ops.SUB, a.dtype, (b, a))))), + # SAT8: clamp to signed 8-bit range [-128, 127] + (UPat(Ops.CUSTOM, arg='SAT8', src=(UPat.var('x'),)), + lambda x: UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (x, _tc(x, -128))), _tc(x, -128), + UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (_tc(x, 127), x)), _tc(x, 127), x))))), + # cvtToQuietNAN: passthrough (just returns the arg) + (UPat(Ops.CUSTOM, arg='cvtToQuietNAN', src=(UPat.var('x'),)), lambda x: x), + # LT/GT_NEG_ZERO: signed comparison via bitcast to int (handles f16 fallback to f32 int type) + (UPat(Ops.CUSTOM, arg='LT_NEG_ZERO', src=(UPat.var('a', dtype=dtypes.floats), UPat.var('b', dtype=dtypes.floats))), + lambda a, b: (idt := {dtypes.float64: dtypes.int64, dtypes.float16: dtypes.int16}.get(a.dtype, dtypes.int32)) and + UOp(Ops.CMPLT, dtypes.bool, (UOp(Ops.BITCAST, idt, (a,)), UOp(Ops.BITCAST, idt, (b,))))), + (UPat(Ops.CUSTOM, arg='GT_NEG_ZERO', src=(UPat.var('a', dtype=dtypes.floats), UPat.var('b', dtype=dtypes.floats))), + lambda a, b: (idt := {dtypes.float64: dtypes.int64, dtypes.float16: dtypes.int16}.get(a.dtype, dtypes.int32)) and + UOp(Ops.CMPLT, dtypes.bool, (UOp(Ops.BITCAST, idt, (b,)), UOp(Ops.BITCAST, idt, (a,))))), # min/max (UPat(Ops.CUSTOM, arg='min', src=(UPat.var('a'), UPat.var('b'))), lambda a, b: _minmax(a, b, True)), (UPat(Ops.CUSTOM, arg='max', src=(UPat.var('a'), UPat.var('b'))), lambda a, b: _minmax(a, b, False)), diff --git a/extra/assembly/amd/ucode.py b/extra/assembly/amd/ucode.py index 37d64827bb..9bf7905b4c 100644 --- a/extra/assembly/amd/ucode.py +++ b/extra/assembly/amd/ucode.py @@ -232,18 +232,10 @@ def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp: result = UOp(Ops.BITCAST, a[0].dtype, (UOp(Ops.OR, uint_dt, (UOp(Ops.AND, uint_dt, (bits, UOp.const(uint_dt, (1 << sign_shift) | mant_mask))), UOp.const(uint_dt, (bias - 1) << exp_shift))),)) return UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPEQ, dtypes.bool, (a[0], UOp.const(a[0].dtype, 0.0))), a[0], result)) - if name == 'cvtToQuietNAN': return a[0] if name == 'signext_from_bit': sign = UOp(Ops.SHL, a[0].dtype, (UOp.const(a[0].dtype, 1), UOp(Ops.SUB, a[0].dtype, (_cast(a[1], a[0].dtype), UOp.const(a[0].dtype, 1))))) result = UOp(Ops.SUB, a[0].dtype, (UOp(Ops.XOR, a[0].dtype, (a[0], sign)), sign)) return UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPEQ, dtypes.bool, (a[1], UOp.const(a[1].dtype, 0))), UOp.const(a[0].dtype, 0), result)) - if name == 'ABSDIFF': - gt = UOp(Ops.CMPLT, dtypes.bool, (a[1], a[0])) - return UOp(Ops.SUB, dtypes.uint32, (UOp(Ops.WHERE, dtypes.uint32, (gt, _cast(a[0], dtypes.uint32), _cast(a[1], dtypes.uint32))), - UOp(Ops.WHERE, dtypes.uint32, (gt, _cast(a[1], dtypes.uint32), _cast(a[0], dtypes.uint32))))) - if name == 'SAT8': - c = UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[0], UOp.const(a[0].dtype, -128))), UOp.const(a[0].dtype, -128), a[0])) - return UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (UOp.const(a[0].dtype, 127), c)), UOp.const(a[0].dtype, 127), c)) if name == 'BYTE_PERMUTE': src64, sel = _cast(a[0], dtypes.uint64), UOp(Ops.AND, dtypes.uint32, (_cast(a[1], dtypes.uint32), UOp.const(dtypes.uint32, 0xff))) sel_idx, sel_nib = UOp(Ops.AND, dtypes.uint32, (sel, UOp.const(dtypes.uint32, 7))), UOp(Ops.AND, dtypes.uint32, (sel, UOp.const(dtypes.uint32, 0xf))) @@ -268,9 +260,6 @@ def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp: c = UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[0], UOp.const(a[0].dtype, lo))), UOp.const(a[0].dtype, lo), a[0])) c = UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (UOp.const(a[0].dtype, 1.0), c)), UOp.const(a[0].dtype, 1.0), c)) return UOp(Ops.CAST, out, (UOp(Ops.MUL, a[0].dtype, (c, UOp.const(a[0].dtype, scale))),)) - if name in ('LT_NEG_ZERO', 'GT_NEG_ZERO'): - int_dt = {dtypes.float64: dtypes.int64, dtypes.float16: dtypes.int16}.get(a[0].dtype, dtypes.int32) - return UOp(Ops.CMPLT, dtypes.bool, ((UOp(Ops.BITCAST, int_dt, (a[0],)), UOp(Ops.BITCAST, int_dt, (a[1],))) if 'LT' in name else (UOp(Ops.BITCAST, int_dt, (a[1],)), UOp(Ops.BITCAST, int_dt, (a[0],))))) if name in ('v_sad_u8', 'v_msad_u8'): result = a[2] if len(a) > 2 else UOp.const(dtypes.uint32, 0) for i in range(4):