more in pcode_transform

This commit is contained in:
George Hotz
2026-01-11 11:03:05 +09:00
parent 754dfb15b4
commit dea2a56dd0
2 changed files with 28 additions and 22 deletions

View File

@@ -19,13 +19,12 @@ _SPECIAL_CASTS = {'f32_to_u32', 'f64_to_u32', 'f16_to_u32', 'f32_to_u64', 'f64_t
_CAST_MAP = {f'{s}_to_{d}': _DT_SUFFIX[d] for s in _DT_SUFFIX for d in _DT_SUFFIX if s != d and f'{s}_to_{d}' not in _SPECIAL_CASTS}
_CAST_MAP.update({f'v_cvt_{d}_{s}': _DT_SUFFIX[d] for s in _DT_SUFFIX for d in _DT_SUFFIX if s != d and f'{s}_to_{d}' not in _SPECIAL_CASTS})
# CUSTOM op return types
# CUSTOM op return types (ops that stay as CUSTOM after transformation)
_CUSTOM_TYPES = {
'isDENORM': dtypes.bool, 'isQuietNAN': dtypes.bool, 'isSignalNAN': dtypes.bool, 'isEven': dtypes.bool,
'LT_NEG_ZERO': dtypes.bool, 'GT_NEG_ZERO': dtypes.bool,
'sign': dtypes.uint32, 'exponent': dtypes.uint32, 'ABSDIFF': dtypes.uint32, 'SAT8': dtypes.uint32,
'BYTE_PERMUTE': dtypes.uint32, 'count_ones': dtypes.uint32, 'countbits': dtypes.uint32, 'reverse_bits': dtypes.uint32,
'u8_to_u32': dtypes.uint32, 'u4_to_u32': dtypes.uint32, 's_ff1_i32_b32': dtypes.uint32, 's_ff1_i32_b64': dtypes.uint32,
'isDENORM': dtypes.bool, 'isQuietNAN': dtypes.bool, 'isSignalNAN': dtypes.bool,
'sign': dtypes.uint32, 'exponent': dtypes.uint32, 'BYTE_PERMUTE': dtypes.uint32,
'count_ones': dtypes.uint32, 'countbits': dtypes.uint32, 'reverse_bits': dtypes.uint32,
's_ff1_i32_b32': dtypes.uint32, 's_ff1_i32_b64': dtypes.uint32,
'v_sad_u8': dtypes.uint32, 'v_msad_u8': dtypes.uint32, 'ConvertFromFormat': dtypes.uint32, 'nop': dtypes.uint32,
'f32_to_u32': dtypes.uint32, 'f64_to_u32': dtypes.uint32, 'signext_from_bit': dtypes.int64,
'f16_to_snorm': dtypes.int16, 'f16_to_unorm': dtypes.uint16, 'f32_to_snorm': dtypes.int16, 'f32_to_unorm': dtypes.uint16,
@@ -62,8 +61,9 @@ def _vn(u: UOp) -> str|None: # var name
if u.op == Ops.DEFINE_VAR: return u.arg[0] if isinstance(u.arg, tuple) else u.arg
return _vn(u.src[0]) if u.op == Ops.CUSTOMI and u.src[0].op == Ops.DEFINE_VAR else None
def _typed_const(src: UOp, val) -> UOp:
  """Build a CONST UOp holding ``val`` whose dtype follows ``src``.

  If ``src`` has a non-void dtype, the constant is created via ``UOp.const`` in that dtype.
  If ``src`` is void-typed, a void CONST is built instead, with ``src`` attached as its only source.
  """
  if src.dtype != dtypes.void:
    return UOp.const(src.dtype, val)
  # void-typed source: keep a reference to src on the CONST node
  return UOp(Ops.CONST, dtypes.void, (src,), val)
# ═══════════════════════════════════════════════════════════════════════════════
# PATTERN HANDLERS
@@ -160,8 +160,8 @@ pcode_pm = PatternMatcher([
(UPat(Ops.CUSTOM, arg='sin', src=(_fpat,)), lambda x: UOp(Ops.SIN, x.dtype, (x,))),
(UPat(Ops.CUSTOM, arg='rcp', src=(_fpat,)), lambda x: UOp(Ops.RECIPROCAL, x.dtype, (x,))),
(UPat(Ops.CUSTOM, arg='fma', src=(_fpat, UPat.var('b'), UPat.var('c'))), lambda x, b, c: UOp(Ops.MULACC, x.dtype, (x, b, c))),
(UPat(Ops.CUSTOM, arg='abs', src=(_fpat,)), lambda x: UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (x, _typed_const(x, 0))), UOp(Ops.NEG, x.dtype, (x,)), x))),
(UPat(Ops.CUSTOM, arg='cos', src=(_fpat,)), lambda x: UOp(Ops.SIN, x.dtype, (UOp(Ops.ADD, x.dtype, (x, _typed_const(x, 1.5707963267948966))),))),
(UPat(Ops.CUSTOM, arg='abs', src=(_fpat,)), lambda x: UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (x, _tc(x, 0))), UOp(Ops.NEG, x.dtype, (x,)), x))),
(UPat(Ops.CUSTOM, arg='cos', src=(_fpat,)), lambda x: UOp(Ops.SIN, x.dtype, (UOp(Ops.ADD, x.dtype, (x, _tc(x, 1.5707963267948966))),))),
(UPat(Ops.CUSTOM, arg='floor', src=(_fpat,)), lambda x: _floor(x)),
(UPat(Ops.CUSTOM, arg='fract', src=(_fpat,)), lambda x: UOp(Ops.SUB, x.dtype, (x, _floor(x)))),
(UPat(Ops.CUSTOM, arg='rsqrt', src=(_fpat,)), lambda x: UOp(Ops.RECIPROCAL, x.dtype, (UOp(Ops.SQRT, x.dtype, (x,)),))),
@@ -186,7 +186,24 @@ pcode_pm = PatternMatcher([
# Boolean functions
(UPat(Ops.CUSTOM, arg='isNAN', src=(UPat.var('x'),)), lambda x: UOp(Ops.CMPNE, dtypes.bool, (x, x))),
(UPat(Ops.CUSTOM, arg='isINF', src=(UPat.var('x'),)), lambda x: UOp(Ops.OR, dtypes.bool, (
UOp(Ops.CMPEQ, dtypes.bool, (x, _typed_const(x, float('inf')))), UOp(Ops.CMPEQ, dtypes.bool, (x, _typed_const(x, float('-inf'))))))),
UOp(Ops.CMPEQ, dtypes.bool, (x, _tc(x, float('inf')))), UOp(Ops.CMPEQ, dtypes.bool, (x, _tc(x, float('-inf'))))))),
# ABSDIFF: |a - b| for unsigned
(UPat(Ops.CUSTOM, arg='ABSDIFF', src=(UPat.var('a'), UPat.var('b'))),
lambda a, b: UOp(Ops.WHERE, a.dtype, (UOp(Ops.CMPLT, dtypes.bool, (b, a)),
UOp(Ops.SUB, a.dtype, (a, b)), UOp(Ops.SUB, a.dtype, (b, a))))),
# SAT8: clamp to signed 8-bit range [-128, 127]
(UPat(Ops.CUSTOM, arg='SAT8', src=(UPat.var('x'),)),
lambda x: UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (x, _tc(x, -128))), _tc(x, -128),
UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (_tc(x, 127), x)), _tc(x, 127), x))))),
# cvtToQuietNAN: passthrough (just returns the arg)
(UPat(Ops.CUSTOM, arg='cvtToQuietNAN', src=(UPat.var('x'),)), lambda x: x),
# LT/GT_NEG_ZERO: signed comparison via bitcast to int (f64 -> int64, f16 -> int16, any other float dtype falls back to int32)
(UPat(Ops.CUSTOM, arg='LT_NEG_ZERO', src=(UPat.var('a', dtype=dtypes.floats), UPat.var('b', dtype=dtypes.floats))),
lambda a, b: (idt := {dtypes.float64: dtypes.int64, dtypes.float16: dtypes.int16}.get(a.dtype, dtypes.int32)) and
UOp(Ops.CMPLT, dtypes.bool, (UOp(Ops.BITCAST, idt, (a,)), UOp(Ops.BITCAST, idt, (b,))))),
(UPat(Ops.CUSTOM, arg='GT_NEG_ZERO', src=(UPat.var('a', dtype=dtypes.floats), UPat.var('b', dtype=dtypes.floats))),
lambda a, b: (idt := {dtypes.float64: dtypes.int64, dtypes.float16: dtypes.int16}.get(a.dtype, dtypes.int32)) and
UOp(Ops.CMPLT, dtypes.bool, (UOp(Ops.BITCAST, idt, (b,)), UOp(Ops.BITCAST, idt, (a,))))),
# min/max
(UPat(Ops.CUSTOM, arg='min', src=(UPat.var('a'), UPat.var('b'))), lambda a, b: _minmax(a, b, True)),
(UPat(Ops.CUSTOM, arg='max', src=(UPat.var('a'), UPat.var('b'))), lambda a, b: _minmax(a, b, False)),

View File

@@ -232,18 +232,10 @@ def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
result = UOp(Ops.BITCAST, a[0].dtype, (UOp(Ops.OR, uint_dt, (UOp(Ops.AND, uint_dt, (bits, UOp.const(uint_dt, (1 << sign_shift) | mant_mask))),
UOp.const(uint_dt, (bias - 1) << exp_shift))),))
return UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPEQ, dtypes.bool, (a[0], UOp.const(a[0].dtype, 0.0))), a[0], result))
if name == 'cvtToQuietNAN': return a[0]
if name == 'signext_from_bit':
sign = UOp(Ops.SHL, a[0].dtype, (UOp.const(a[0].dtype, 1), UOp(Ops.SUB, a[0].dtype, (_cast(a[1], a[0].dtype), UOp.const(a[0].dtype, 1)))))
result = UOp(Ops.SUB, a[0].dtype, (UOp(Ops.XOR, a[0].dtype, (a[0], sign)), sign))
return UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPEQ, dtypes.bool, (a[1], UOp.const(a[1].dtype, 0))), UOp.const(a[0].dtype, 0), result))
if name == 'ABSDIFF':
gt = UOp(Ops.CMPLT, dtypes.bool, (a[1], a[0]))
return UOp(Ops.SUB, dtypes.uint32, (UOp(Ops.WHERE, dtypes.uint32, (gt, _cast(a[0], dtypes.uint32), _cast(a[1], dtypes.uint32))),
UOp(Ops.WHERE, dtypes.uint32, (gt, _cast(a[1], dtypes.uint32), _cast(a[0], dtypes.uint32)))))
if name == 'SAT8':
c = UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[0], UOp.const(a[0].dtype, -128))), UOp.const(a[0].dtype, -128), a[0]))
return UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (UOp.const(a[0].dtype, 127), c)), UOp.const(a[0].dtype, 127), c))
if name == 'BYTE_PERMUTE':
src64, sel = _cast(a[0], dtypes.uint64), UOp(Ops.AND, dtypes.uint32, (_cast(a[1], dtypes.uint32), UOp.const(dtypes.uint32, 0xff)))
sel_idx, sel_nib = UOp(Ops.AND, dtypes.uint32, (sel, UOp.const(dtypes.uint32, 7))), UOp(Ops.AND, dtypes.uint32, (sel, UOp.const(dtypes.uint32, 0xf)))
@@ -268,9 +260,6 @@ def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
c = UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[0], UOp.const(a[0].dtype, lo))), UOp.const(a[0].dtype, lo), a[0]))
c = UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (UOp.const(a[0].dtype, 1.0), c)), UOp.const(a[0].dtype, 1.0), c))
return UOp(Ops.CAST, out, (UOp(Ops.MUL, a[0].dtype, (c, UOp.const(a[0].dtype, scale))),))
if name in ('LT_NEG_ZERO', 'GT_NEG_ZERO'):
int_dt = {dtypes.float64: dtypes.int64, dtypes.float16: dtypes.int16}.get(a[0].dtype, dtypes.int32)
return UOp(Ops.CMPLT, dtypes.bool, ((UOp(Ops.BITCAST, int_dt, (a[0],)), UOp(Ops.BITCAST, int_dt, (a[1],))) if 'LT' in name else (UOp(Ops.BITCAST, int_dt, (a[1],)), UOp(Ops.BITCAST, int_dt, (a[0],)))))
if name in ('v_sad_u8', 'v_msad_u8'):
result = a[2] if len(a) > 2 else UOp.const(dtypes.uint32, 0)
for i in range(4):