Files
tinygrad/extra/assembly/amd/pcode_transform.py
George Hotz 322eb1fbc8 bitcast
2026-01-11 09:50:19 +09:00

302 lines
22 KiB
Python

# Transform parsed pcode CUSTOM ops to UOps using PatternMatcher
from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat, graph_rewrite, GroupOp
from tinygrad.uop.spec import shared_spec, type_verify
from tinygrad.dtype import dtypes, DType
from extra.assembly.amd.pcode_parse import parse, If, For, Lambda, Break, Return
# ═══════════════════════════════════════════════════════════════════════════════
# TYPE MAPPINGS
# ═══════════════════════════════════════════════════════════════════════════════
_DT_SUFFIX = {
  'f16': dtypes.float16, 'f32': dtypes.float32, 'f64': dtypes.float64, 'bf16': dtypes.bfloat16,
  'i8': dtypes.int8, 'i16': dtypes.int16, 'i32': dtypes.int32, 'i64': dtypes.int64,
  'u8': dtypes.uint8, 'u16': dtypes.uint16, 'u32': dtypes.uint32, 'u64': dtypes.uint64,
}
# Float -> unsigned conversions marked as "clamping": they stay CUSTOM ops instead of being lowered to a plain CAST
_CLAMP_CASTS = {'f32_to_u32', 'f64_to_u32', 'f16_to_u32', 'f32_to_u64', 'f64_to_u64', 'f16_to_u64'}
# Special conversions needing custom handling - excluded from auto-generated casts
_SPECIAL_CASTS = _CLAMP_CASTS | {'bf16_to_f32', 'u32_to_u16', 'i32_to_i16'}  # the extra three are bit manipulation
# Auto-generate all {src}_to_{dst} and v_cvt_{dst}_{src} cast mappings
_CAST_MAP = {f'{s}_to_{d}': _DT_SUFFIX[d] for s in _DT_SUFFIX for d in _DT_SUFFIX if s != d and f'{s}_to_{d}' not in _SPECIAL_CASTS}
_CAST_MAP.update({f'v_cvt_{d}_{s}': _DT_SUFFIX[d] for s in _DT_SUFFIX for d in _DT_SUFFIX if s != d and f'{s}_to_{d}' not in _SPECIAL_CASTS})
# Remaining CUSTOM ops that need dtype inference (not transformed to real UOps)
_BOOL_FNS = {'isDENORM', 'isQuietNAN', 'isSignalNAN', 'isEven', 'LT_NEG_ZERO', 'GT_NEG_ZERO'}
_U32_FNS = {'sign', 'exponent', 'ABSDIFF', 'SAT8', 'BYTE_PERMUTE', 'count_ones', 'countbits', 'reverse_bits',
            'u8_to_u32', 'u4_to_u32', 's_ff1_i32_b32', 's_ff1_i32_b64', 'v_sad_u8', 'v_msad_u8'}
# CUSTOM ops whose result dtype is fixed by their name
_CVT_FNS = {'signext_from_bit': dtypes.int64,
            'f16_to_snorm': dtypes.int16, 'f16_to_unorm': dtypes.uint16, 'f32_to_snorm': dtypes.int16, 'f32_to_unorm': dtypes.uint16}
# Every clamping cast (in both the {src}_to_{dst} and v_cvt_{dst}_{src} spellings) returns its destination type.
# Without an entry here _prop_custom falls back to the first typed source and would type e.g. f16_to_u64 as float16.
_CVT_FNS.update({name: _DT_SUFFIX[name.split('_to_')[1]] for name in _CLAMP_CASTS})
_CVT_FNS.update({f'v_cvt_{d}_{s}': _DT_SUFFIX[d] for s, d in (name.split('_to_') for name in _CLAMP_CASTS)})
# ═══════════════════════════════════════════════════════════════════════════════
# HELPERS
# ═══════════════════════════════════════════════════════════════════════════════
def _typed_const(src: UOp, val) -> UOp:
  """Build the constant `val` using the dtype of `src`.

  A void `src` has no dtype to copy yet, so a void CONST is built that keeps `src` as its source
  (presumably so later propagation can still relate the constant to it - TODO confirm)."""
  if src.dtype == dtypes.void:
    return UOp(Ops.CONST, dtypes.void, (src,), val)
  return UOp.const(src.dtype, val)
def _floor(x: UOp, dt: DType) -> UOp:
  """Lower floor(x) to TRUNC/CMPLT/SUB/WHERE ops of dtype `dt`.

  Assuming TRUNC rounds toward zero, it only overshoots when x < trunc(x) (negative non-integers); in that case
  one is subtracted from the truncated value."""
  truncated = UOp(Ops.TRUNC, dt, (x,))
  overshot = UOp(Ops.CMPLT, dtypes.bool, (x, truncated))
  stepped_down = UOp(Ops.SUB, dt, (truncated, _typed_const(x, 1)))
  return UOp(Ops.WHERE, dt, (overshot, stepped_down, truncated))
def _minmax(a: UOp, b: UOp, is_min: bool, dt: DType|None = None) -> UOp:
  """Lower min(a, b) (is_min=True) or max(a, b) (is_min=False) to WHERE(CMPLT(...), a, b).

  min selects a when a < b, max selects a when b < a, otherwise b is selected. The result dtype is `dt` when given,
  else a's dtype, or b's dtype while a is still void."""
  lhs, rhs = (a, b) if is_min else (b, a)
  choose_a = UOp(Ops.CMPLT, dtypes.bool, (lhs, rhs))
  return UOp(Ops.WHERE, dt or (b.dtype if a.dtype == dtypes.void else a.dtype), (choose_a, a, b))
def _minmax3(a: UOp, b: UOp, c: UOp, is_min: bool, dt: DType|None = None) -> UOp:
  """Three-way min/max folded as minmax(minmax(a, b), c), both steps sharing one result dtype.

  Without an explicit `dt`, the dtype of the first non-void of a and b is used; otherwise c's dtype (void or not)."""
  if not dt:
    dt = next((u.dtype for u in (a, b) if u.dtype != dtypes.void), c.dtype)
  first_pair = _minmax(a, b, is_min, dt)
  return _minmax(first_pair, c, is_min, dt)
def _first_nonvoid(*srcs: UOp) -> DType:
  """Return the dtype of the first source that is not void; dtypes.void if all are void or none are given."""
  for s in srcs:
    if s.dtype != dtypes.void:
      return s.dtype
  return dtypes.void
def _var_name(u: UOp) -> str|None:
  """Return the variable name `u` refers to, or None.

  Handles a DEFINE_VAR directly and a CUSTOMI whose base (src[0]) is a DEFINE_VAR. A DEFINE_VAR arg is either the
  bare name or a tuple whose first element is the name."""
  if u.op == Ops.CUSTOMI and u.src[0].op == Ops.DEFINE_VAR:
    u = u.src[0]
  if u.op != Ops.DEFINE_VAR:
    return None
  return u.arg[0] if isinstance(u.arg, tuple) else u.arg
# ═══════════════════════════════════════════════════════════════════════════════
# PATTERN HANDLERS
# ═══════════════════════════════════════════════════════════════════════════════
def _typed_minmax2(a, b, op):
  """Rewrite CUSTOM 'v_min_<t>' / 'v_max_<t>' (<t> a key of _DT_SUFFIX) into a WHERE-based min/max of dtype <t>.

  Returns None for any other CUSTOM op, leaving it to later rules."""
  name = op.arg
  if not isinstance(name, str):
    return None
  is_min = name.startswith('v_min_')
  if not is_min and not name.startswith('v_max_'):
    return None
  dt = _DT_SUFFIX.get(name.split('_')[-1])
  return None if dt is None else _minmax(a, b, is_min, dt)
def _typed_minmax3(a, b, c, op):
  """Rewrite CUSTOM 'v_min3_<t>' / 'v_max3_<t>' (<t> a key of _DT_SUFFIX) into a three-way min/max of dtype <t>.

  Returns None for any other CUSTOM op, leaving it to later rules."""
  name = op.arg
  if not isinstance(name, str):
    return None
  is_min = name.startswith('v_min3_')
  if not is_min and not name.startswith('v_max3_'):
    return None
  dt = _DT_SUFFIX.get(name.split('_')[-1])
  return None if dt is None else _minmax3(a, b, c, is_min, dt)
def _typed_cast(x, op):
  """Lower a CUSTOM conversion named in _CAST_MAP (e.g. 'f32_to_i32', 'v_cvt_i32_f32') to a plain CAST.

  Returns None when the op name is not an auto-generated cast."""
  target = _CAST_MAP.get(op.arg)
  if target is None:
    return None
  return UOp(Ops.CAST, target, (x,))
# Variable type tracking/propagation
def _track_var(ctx, u):
  """Record a typed DEFINE_VAR's dtype in ctx (name -> dtype), asserting it agrees with any earlier record.

  Always returns None: this rule only observes, so the UOp is not rewritten here and later rules still apply."""
  if ctx is not None and u.dtype != dtypes.void:
    key = u.arg[0] if isinstance(u.arg, tuple) else u.arg
    known = ctx.setdefault(key, u.dtype)
    assert known == u.dtype, f"variable '{key}' declared with conflicting types: {known} vs {u.dtype}"
  return None
def _prop_var(ctx, u):
  """Type a void DEFINE_VAR with the dtype already recorded for its name in ctx.

  Returns None (leave it void for now) when there is no ctx or the name has not been recorded yet."""
  if ctx is not None:
    key = u.arg[0] if isinstance(u.arg, tuple) else u.arg
    if key in ctx:
      return UOp(Ops.DEFINE_VAR, ctx[key], arg=u.arg)
  return None
def _prop_assign(ctx, lhs, rhs):
  """On the first assignment to an untyped variable, give the variable (and the ASSIGN) the rhs dtype.

  Records name -> rhs.dtype in ctx and rebuilds the ASSIGN with a typed DEFINE_VAR target. Returns None when there
  is no ctx, the rhs is still void, lhs is not a DEFINE_VAR, or the name is already recorded."""
  if ctx is None or rhs.dtype == dtypes.void or lhs.op != Ops.DEFINE_VAR:
    return None
  name = _var_name(lhs)
  if name is None or name in ctx:
    return None
  ctx[name] = rhs.dtype
  typed_target = UOp(Ops.DEFINE_VAR, rhs.dtype, arg=lhs.arg)
  return UOp(Ops.ASSIGN, rhs.dtype, (typed_target, rhs))
# Dtype propagation for void-typed ops (forward propagation)
def _prop_binop(l, r, __OP__, **kw):
  """Forward-propagate a dtype onto a void binary op (`__OP__`) from its operands l and r.

  Result dtype:
    - SHL/SHR, and POW with a float base: the left operand's dtype (right's while left is void). The shift amount /
      int exponent must not decide the result type; otherwise a float base with a wider int exponent would give an
      int-typed POW, while pcode_spec only whitelists float-typed POW with an int exponent.
    - otherwise: the wider of two typed operands (left wins ties), or whichever operand is typed.
  Returns None (no rewrite yet) while both operands are void. The matched op's arg is preserved.
  """
  op = __OP__.op
  left_typed, right_typed = l.dtype != dtypes.void, r.dtype != dtypes.void
  if op in {Ops.SHL, Ops.SHR} or (op == Ops.POW and l.dtype in dtypes.floats):
    dt = l.dtype if left_typed else r.dtype
  elif left_typed and right_typed:
    dt = l.dtype if l.dtype.itemsize >= r.dtype.itemsize else r.dtype
  else:
    dt = l.dtype if left_typed else r.dtype
  if dt == dtypes.void: return None
  # pattern captures never include a key named 'arg', so kw.get('arg') alone would always drop the op's own arg
  return UOp(op, dt, (l, r), kw.get('arg', __OP__.arg))
# Back-propagate type to void DEFINE_VAR source
def _backprop_binop(ctx, op, void_var, typed_src):
  """Back-propagate: give the void DEFINE_VAR operand of binary `op` the dtype of its typed sibling operand.

  When ctx is given, the name -> dtype binding is recorded there (asserting no conflict with an earlier one); the
  pcode_pm rules in this file pass ctx=None. The op is rebuilt with its original operand order, dtype and arg."""
  dt = typed_src.dtype
  var_key = void_var.arg[0] if isinstance(void_var.arg, tuple) else void_var.arg
  if ctx is not None:
    if var_key not in ctx:
      ctx[var_key] = dt
    else:
      assert ctx[var_key] == dt, f"variable '{var_key}' has conflicting types: {ctx[var_key]} vs {dt}"
  typed_var = UOp(Ops.DEFINE_VAR, dt, arg=void_var.arg)
  # keep the operand order of the original op
  if op.src[0] is void_var:
    srcs = (typed_var, typed_src)
  else:
    srcs = (typed_src, typed_var)
  return UOp(op.op, op.dtype, srcs, op.arg)
def _prop_unop(x, __OP__, **kw):
  """Forward-propagate: a void unary op (`__OP__`) takes its operand's dtype.

  Returns None while the operand is still void. The matched op's arg is preserved (an explicit `arg` keyword, if a
  caller ever passes one, still takes precedence)."""
  if x.dtype == dtypes.void: return None
  # pattern captures never include a key named 'arg', so kw.get('arg') alone would always drop the op's own arg
  return UOp(__OP__.op, x.dtype, (x,), kw.get('arg', __OP__.arg))
def _prop_mulacc(a, b, c, **kw):
  """Type a void MULACC (the op 'fma' lowers to) from its third operand c; None while c is still void."""
  acc_dt = c.dtype
  if acc_dt == dtypes.void:
    return None
  # NOTE(review): pattern captures presumably never include 'arg', so this is always None - confirm
  return UOp(Ops.MULACC, acc_dt, (a, b, c), kw.get('arg'))
def _prop_where(cond, t, f, **kw):
  """Type a void WHERE from its branches: the true branch's dtype if typed, else the false branch's.

  Returns None while both branches are void."""
  out_dt = t.dtype if t.dtype != dtypes.void else f.dtype
  if out_dt == dtypes.void:
    return None
  # NOTE(review): pattern captures presumably never include 'arg', so this is always None - confirm
  return UOp(Ops.WHERE, out_dt, (cond, t, f), kw.get('arg'))
def _prop_cat(x):
  """Type a void CAT (bit concatenation) from the total bit width of its already-typed parts.

  More than 32 bits -> uint64, otherwise uint32. Void parts are not counted; None is returned when no part is typed.
  NOTE(review): widths above 64 bits are still typed uint64 - confirm pcode never concatenates that much."""
  bits = 0
  for part in x.src:
    if part.dtype != dtypes.void:
      bits += part.dtype.itemsize * 8
  if bits == 0:
    return None
  return UOp(Ops.CAT, dtypes.uint64 if bits > 32 else dtypes.uint32, x.src, x.arg)
def _prop_customi(base, hi, lo, **kw):
  """Type a void CUSTOMI, i.e. an element access or bit slice base[hi:lo].

  - hi is lo (the very same UOp): element access, typed like the base, or uint32 while the base is void
    (register files like SGPR/VGPR are uint32). NOTE(review): if UOps are interned, a constant slice x[n:n]
    also lands here - confirm that is intended.
  - both bounds constant: uint64 when the slice spans more than 32 bits, else uint32.
  - variable bounds: assumed uint32.
  """
  if hi is lo:
    dt = dtypes.uint32 if base.dtype == dtypes.void else base.dtype
  elif hi.op == Ops.CONST and lo.op == Ops.CONST:
    width = abs(int(hi.arg) - int(lo.arg)) + 1
    dt = dtypes.uint32 if width <= 32 else dtypes.uint64
  else:
    dt = dtypes.uint32
  # NOTE(review): pattern captures presumably never include 'arg', so this is always None - confirm
  return UOp(Ops.CUSTOMI, dt, (base, hi, lo), kw.get('arg'))
_PASSTHROUGH_FNS = {'abs', 'cvtToQuietNAN'}  # these preserve input type
def _prop_custom(x):
  """Give a still-void CUSTOM op a result dtype derived from its function name (x.arg).

  Checked in order: boolean predicates, uint32 helpers, fixed-type conversions (_CVT_FNS), a few individually
  known names, then MEM (typed by its BITCAST) and passthrough ops (typed by their CAST wrapper), which are left void.
  Anything else takes the dtype of its first typed source.

  Raises AssertionError when no dtype can be inferred."""
  fn = x.arg
  if fn in _BOOL_FNS: dt = dtypes.bool
  elif fn in _U32_FNS: dt = dtypes.uint32
  elif fn in _CVT_FNS: dt = _CVT_FNS[fn]
  elif fn == 'trig_preop_result': dt = dtypes.float64
  elif fn in ('ConvertFromFormat', 'nop'): dt = dtypes.uint32  # format conversion result / no-op
  elif fn == 'MEM' or fn in _PASSTHROUGH_FNS: return None
  else: dt = _first_nonvoid(*x.src)
  assert dt != dtypes.void, f"cannot infer type for CUSTOM op '{x.arg}'"
  return UOp(Ops.CUSTOM, dt, x.src, x.arg)
# ═══════════════════════════════════════════════════════════════════════════════
# PATTERN MATCHER
# ═══════════════════════════════════════════════════════════════════════════════
# Float-typed operand capture shared by the float-only CUSTOM lowerings below
_fpat = UPat.var('x', dtype=dtypes.floats)
# Two matchers concatenated into one:
#   1. lowering of named CUSTOM functions (math, predicates, min/max, casts) to real UOps
#   2. dtype inference: math constants, variable typing via ctx, forward/backward dtype propagation, mismatch fixes
pcode_pm = PatternMatcher([
  # Float ops (preserve input type)
  (UPat(Ops.CUSTOM, arg='trunc', src=(_fpat,)), lambda x: UOp(Ops.TRUNC, x.dtype, (x,))),
  (UPat(Ops.CUSTOM, arg='sqrt', src=(_fpat,)), lambda x: UOp(Ops.SQRT, x.dtype, (x,))),
  (UPat(Ops.CUSTOM, arg='exp2', src=(_fpat,)), lambda x: UOp(Ops.EXP2, x.dtype, (x,))),
  (UPat(Ops.CUSTOM, arg='log2', src=(_fpat,)), lambda x: UOp(Ops.LOG2, x.dtype, (x,))),
  (UPat(Ops.CUSTOM, arg='sin', src=(_fpat,)), lambda x: UOp(Ops.SIN, x.dtype, (x,))),
  (UPat(Ops.CUSTOM, arg='rcp', src=(_fpat,)), lambda x: UOp(Ops.RECIPROCAL, x.dtype, (x,))),
  (UPat(Ops.CUSTOM, arg='fma', src=(_fpat, UPat.var('b'), UPat.var('c'))), lambda x, b, c: UOp(Ops.MULACC, x.dtype, (x, b, c))),
  # abs(x) = x < 0 ? -x : x
  (UPat(Ops.CUSTOM, arg='abs', src=(_fpat,)),
   lambda x: UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (x, _typed_const(x, 0))), UOp(Ops.NEG, x.dtype, (x,)), x))),
  # cos(x) = sin(x + pi/2)
  (UPat(Ops.CUSTOM, arg='cos', src=(_fpat,)),
   lambda x: UOp(Ops.SIN, x.dtype, (UOp(Ops.ADD, x.dtype, (x, _typed_const(x, 1.5707963267948966))),))),
  (UPat(Ops.CUSTOM, arg='floor', src=(_fpat,)), lambda x: _floor(x, x.dtype)),
  # fract(x) = x - floor(x)
  (UPat(Ops.CUSTOM, arg='fract', src=(_fpat,)), lambda x: UOp(Ops.SUB, x.dtype, (x, _floor(x, x.dtype)))),
  # rsqrt(x) = 1 / sqrt(x)
  (UPat(Ops.CUSTOM, arg='rsqrt', src=(_fpat,)), lambda x: UOp(Ops.RECIPROCAL, x.dtype, (UOp(Ops.SQRT, x.dtype, (x,)),))),
  # Boolean functions (isNAN: NaN is the only value unequal to itself)
  (UPat(Ops.CUSTOM, arg='isNAN', src=(UPat.var('x'),)), lambda x: UOp(Ops.CMPNE, dtypes.bool, (x, x))),
  (UPat(Ops.CUSTOM, arg='isINF', src=(UPat.var('x'),)), lambda x: UOp(Ops.OR, dtypes.bool, (
    UOp(Ops.CMPEQ, dtypes.bool, (x, _typed_const(x, float('inf')))), UOp(Ops.CMPEQ, dtypes.bool, (x, _typed_const(x, float('-inf'))))))),
  # min/max; clamp(x, lo, hi) = min(max(x, lo), hi)
  (UPat(Ops.CUSTOM, arg='min', src=(UPat.var('a'), UPat.var('b'))), lambda a, b: _minmax(a, b, True)),
  (UPat(Ops.CUSTOM, arg='max', src=(UPat.var('a'), UPat.var('b'))), lambda a, b: _minmax(a, b, False)),
  (UPat(Ops.CUSTOM, arg='clamp', src=(UPat.var('x'), UPat.var('lo'), UPat.var('hi'))), lambda x, lo, hi: _minmax(_minmax(x, lo, False), hi, True)),
  # dtype-suffixed v_min_*/v_max_* and v_min3_*/v_max3_* (the handlers return None for any other CUSTOM)
  (UPat(Ops.CUSTOM, src=(UPat.var('a'), UPat.var('b')), name='op'), _typed_minmax2),
  (UPat(Ops.CUSTOM, src=(UPat.var('a'), UPat.var('b'), UPat.var('c')), name='op'), _typed_minmax3),
  # Type conversions
  (UPat(Ops.CUSTOM, src=(UPat.var('x'),), name='op'), _typed_cast),
  (UPat(Ops.CUSTOM, arg='signext', src=(UPat.var('x', dtype=dtypes.ints),)), lambda x: UOp(Ops.CAST, dtypes.int64, (x,))),
  # bf16_to_f32: place the source in the high 16 bits of a u32 and BITCAST to f32.
  # NOTE(review): the inner step is a value CAST bf16 -> u32, not a BITCAST - confirm the consumer treats it as raw bits.
  (UPat(Ops.CUSTOM, arg='bf16_to_f32', src=(UPat.var('x', dtype=dtypes.bfloat16),)),
   lambda x: UOp(Ops.BITCAST, dtypes.float32, (UOp(Ops.SHL, dtypes.uint32, (UOp(Ops.CAST, dtypes.uint32, (x,)), UOp.const(dtypes.uint32, 16))),))),
  (UPat(Ops.CUSTOM, arg='u32_to_u16', src=(UPat.var('x', dtype=dtypes.uint32),)), lambda x: UOp(Ops.AND, dtypes.uint32, (x, UOp.const(dtypes.uint32, 0xffff)))),
  (UPat(Ops.CUSTOM, arg='i32_to_i16', src=(UPat.var('x', dtype=dtypes.int32),)),
   lambda x: UOp(Ops.CAST, dtypes.int16, (UOp(Ops.AND, dtypes.uint32, (UOp(Ops.CAST, dtypes.uint32, (x,)), UOp.const(dtypes.uint32, 0xffff))),))),
]) + PatternMatcher([
  # Math constants
  (UPat(Ops.DEFINE_VAR, arg=('PI', None, None)), lambda: UOp.const(dtypes.float64, 3.141592653589793)),
  (UPat(Ops.DEFINE_VAR, arg=('INF', None, None)), lambda: UOp.const(dtypes.float64, float('inf'))),
  # Float special values
  (UPat(Ops.DEFINE_VAR, arg=('MAX_FLOAT_F32', None, None)), lambda: UOp.const(dtypes.float32, 3.4028235e+38)),
  (UPat(Ops.DEFINE_VAR, arg=('MAX_FLOAT_F64', None, None)), lambda: UOp.const(dtypes.float64, 1.7976931348623157e+308)),
  (UPat(Ops.DEFINE_VAR, arg=('OVERFLOW_F32', None, None)), lambda: UOp.const(dtypes.float32, float('inf'))),
  (UPat(Ops.DEFINE_VAR, arg=('OVERFLOW_F64', None, None)), lambda: UOp.const(dtypes.float64, float('inf'))),
  (UPat(Ops.DEFINE_VAR, arg=('UNDERFLOW_F32', None, None)), lambda: UOp.const(dtypes.float32, 0.0)),
  (UPat(Ops.DEFINE_VAR, arg=('UNDERFLOW_F64', None, None)), lambda: UOp.const(dtypes.float64, 0.0)),
  # Variable type tracking and propagation (_track_var only records into ctx and never rewrites)
  (UPat(Ops.DEFINE_VAR, name='u'), _track_var),
  (UPat(Ops.DEFINE_VAR, dtype=dtypes.void, name='u'), _prop_var),
  (UPat(Ops.ASSIGN, src=(UPat(Ops.DEFINE_VAR, dtype=dtypes.void, name='lhs'), UPat.var('rhs'))), _prop_assign),
  # Propagate dtype for ASSIGN from rhs, or infer rhs dtype from lhs if rhs is void
  (UPat(Ops.ASSIGN, dtype=dtypes.void, src=(UPat.var('lhs'), UPat.var('rhs'))),
   lambda lhs, rhs: UOp(Ops.ASSIGN, rhs.dtype, (lhs, rhs)) if rhs.dtype != dtypes.void else
   UOp(Ops.ASSIGN, lhs.dtype, (lhs, rhs.replace(dtype=lhs.dtype))) if lhs.dtype != dtypes.void else None),
  # Dtype propagation for void-typed ops
  (UPat((Ops.ADD, Ops.SUB, Ops.MUL, Ops.FDIV, Ops.AND, Ops.OR, Ops.XOR, Ops.SHL, Ops.SHR, Ops.MOD, Ops.POW),
        dtype=dtypes.void, src=(UPat.var('l'), UPat.var('r')), name='__OP__'), _prop_binop),
  (UPat((Ops.NEG, Ops.TRUNC, Ops.SQRT, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.RECIPROCAL),
        dtype=dtypes.void, src=(UPat.var('x'),), name='__OP__'), _prop_unop),
  # Unary XOR (NOT) -> binary XOR with all ones
  (UPat(Ops.XOR, src=(UPat.var('x'),)),
   lambda x: UOp(Ops.XOR, x.dtype, (x, UOp.const(x.dtype, -1))) if x.dtype != dtypes.void else None),
  # Unary CMPEQ (logical NOT) -> CMPEQ(x, 0) with matching type (default to uint32 for void)
  (UPat(Ops.CMPEQ, dtype=dtypes.bool, src=(UPat.var('x'),)),
   lambda x: UOp(Ops.CMPEQ, dtypes.bool, (x, UOp.const(x.dtype if x.dtype != dtypes.void else dtypes.uint32, 0)))),
  (UPat(Ops.MULACC, dtype=dtypes.void, src=(UPat.var('a'), UPat.var('b'), UPat.var('c'))), _prop_mulacc),
  (UPat(Ops.WHERE, dtype=dtypes.void, src=(UPat.var('cond'), UPat.var('t'), UPat.var('f'))), _prop_where),
  (UPat(Ops.CAT, dtype=dtypes.void, name='x'), _prop_cat),
  (UPat(Ops.CUSTOMI, dtype=dtypes.void, src=(UPat.var('base'), UPat.var('hi'), UPat.var('lo'))), _prop_customi),
  (UPat(Ops.CUSTOM, dtype=dtypes.void, name='x'), _prop_custom),
  # Fix comparison type mismatches: cast to larger type
  (UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ, Ops.CMPLE), src=(UPat.var('x'), UPat.var('y')), name='cmp'),
   lambda cmp, x, y: UOp(cmp.op, dtypes.bool, (x, UOp(Ops.CAST, x.dtype, (y,)))) if x.dtype != dtypes.void and y.dtype != dtypes.void and x.dtype != y.dtype and x.dtype.itemsize >= y.dtype.itemsize else None),
  (UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ, Ops.CMPLE), src=(UPat.var('x'), UPat.var('y')), name='cmp'),
   lambda cmp, x, y: UOp(cmp.op, dtypes.bool, (UOp(Ops.CAST, y.dtype, (x,)), y)) if x.dtype != dtypes.void and y.dtype != dtypes.void and x.dtype != y.dtype and y.dtype.itemsize > x.dtype.itemsize else None),
  # Fix WHERE with non-bool condition: cast int condition to bool (test != 0)
  (UPat(Ops.WHERE, src=(UPat.var('c', dtype=dtypes.ints), UPat.var('t'), UPat.var('f'))),
   lambda c, t, f: UOp(Ops.WHERE, t.dtype if t.dtype != dtypes.void else f.dtype, (UOp(Ops.CMPNE, dtypes.bool, (c, UOp.const(c.dtype, 0))), t, f))),
  # Fix logical AND/OR with bool and int: convert int to bool (!= 0).
  # Rebuild with the matched op (op.op) so an OR stays an OR instead of being turned into an AND.
  (UPat((Ops.AND, Ops.OR), src=(UPat.var('x', dtype=dtypes.bool), UPat.var('y', dtype=dtypes.ints)), name='op'),
   lambda op, x, y: UOp(op.op, dtypes.bool, (x, UOp(Ops.CMPNE, dtypes.bool, (y, UOp.const(y.dtype, 0)))))),
  (UPat((Ops.AND, Ops.OR), src=(UPat.var('x', dtype=dtypes.ints), UPat.var('y', dtype=dtypes.bool)), name='op'),
   lambda op, x, y: UOp(op.op, dtypes.bool, (UOp(Ops.CMPNE, dtypes.bool, (x, UOp.const(x.dtype, 0))), y))),
  # Fix binary op type mismatches: cast smaller to larger (excluding POW which allows int exponent)
  (UPat((Ops.ADD, Ops.SUB, Ops.MUL, Ops.FDIV, Ops.AND, Ops.OR, Ops.XOR), src=(UPat.var('x'), UPat.var('y')), name='op'),
   lambda op, x, y: UOp(op.op, op.dtype, (x, UOp(Ops.CAST, x.dtype, (y,)))) if x.dtype != dtypes.void and y.dtype != dtypes.void and x.dtype != y.dtype and x.dtype.itemsize >= y.dtype.itemsize else None),
  (UPat((Ops.ADD, Ops.SUB, Ops.MUL, Ops.FDIV, Ops.AND, Ops.OR, Ops.XOR), src=(UPat.var('x'), UPat.var('y')), name='op'),
   lambda op, x, y: UOp(op.op, op.dtype, (UOp(Ops.CAST, y.dtype, (x,)), y)) if x.dtype != dtypes.void and y.dtype != dtypes.void and x.dtype != y.dtype and y.dtype.itemsize > x.dtype.itemsize else None),
  # Back-propagate types to void DEFINE_VAR sources (ctx=None: the binding is not recorded)
  (UPat((Ops.ADD, Ops.SUB, Ops.MUL, Ops.FDIV, Ops.AND, Ops.OR, Ops.XOR),
        src=(UPat(Ops.DEFINE_VAR, dtype=dtypes.void, name='v'), UPat.var('t')), name='op'),
   lambda op, v, t: _backprop_binop(None, op, v, t) if t.dtype != dtypes.void else None),
  (UPat((Ops.ADD, Ops.SUB, Ops.MUL, Ops.FDIV, Ops.AND, Ops.OR, Ops.XOR),
        src=(UPat.var('t'), UPat(Ops.DEFINE_VAR, dtype=dtypes.void, name='v')), name='op'),
   lambda op, t, v: _backprop_binop(None, op, v, t) if t.dtype != dtypes.void else None),
])
# ═══════════════════════════════════════════════════════════════════════════════
# PCODE SPEC (extends shared_spec with pcode-specific patterns)
# ═══════════════════════════════════════════════════════════════════════════════
# Validity rules passed to type_verify in _transform_uop. Each rule returns True (valid) or False (invalid) for the
# UOps it matches. The pcode rules are concatenated ahead of shared_spec (NOTE(review): presumably so they are tried
# first and override shared_spec for the same UOps - confirm against type_verify).
pcode_spec = PatternMatcher([
  # DEFINE_VAR: pcode names are plain strings or tuples such as ('PI', None, None) rather than tinygrad's
  # int-bounded (name, min, max) variables, so any str/tuple arg is accepted
  (UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg, (str, tuple))),
  # ASSIGN: dtype matches rhs; a void rhs is only allowed when lhs is void too
  (UPat(Ops.ASSIGN, src=(UPat.var("lhs"), UPat.var("rhs")), name="a"),
   lambda a, lhs, rhs: a.dtype == rhs.dtype and (rhs.dtype != dtypes.void or lhs.dtype == dtypes.void)),
  # BITCAST: void source allowed (type view on untyped register); any single-source BITCAST passes
  (UPat(Ops.BITCAST, src=(UPat(),)), lambda: True),
  # CUSTOMI/CAT: must be typed (slice bounds or bit concat determine type)
  (UPat(Ops.CUSTOMI, name="x"), lambda x: x.dtype != dtypes.void),
  (UPat(Ops.CAT, name="x"), lambda x: x.dtype != dtypes.void),
  # CUSTOM: MEM and passthrough ops (abs, cvtToQuietNAN) can be void (wrapped by BITCAST/CAST)
  (UPat(Ops.CUSTOM, name="x"), lambda x: x.dtype != dtypes.void or x.arg in {'MEM', 'abs', 'cvtToQuietNAN'}),
  # POW allows int exponent with float base (float-typed result)
  (UPat(Ops.POW, dtype=dtypes.floats, src=(UPat(dtype=dtypes.floats), UPat(dtype=dtypes.ints))), lambda: True),
]) + shared_spec
# ═══════════════════════════════════════════════════════════════════════════════
# TRANSFORM
# ═══════════════════════════════════════════════════════════════════════════════
def _transform_uop(u: UOp, ctx: dict) -> UOp:
  """Rewrite one UOp tree with pcode_pm, validate the result against pcode_spec, and return it.

  `ctx` is the shared variable-name -> dtype table that the ctx-aware rules (_track_var, _prop_var, _prop_assign)
  read and extend. type_verify presumably raises when the rewritten tree violates the spec."""
  type_verify(rewritten := graph_rewrite(u, pcode_pm, ctx=ctx), pcode_spec)
  return rewritten
def _transform_stmt(stmt, ctx: dict):
match stmt:
case If(branches): return If(tuple((_transform_uop(c, ctx) if c is not None else None, tuple(_transform_stmt(s, ctx) for s in b)) for c, b in branches))
case For(var, start, end, body): return For(var, _transform_uop(start, ctx), _transform_uop(end, ctx), tuple(_transform_stmt(s, ctx) for s in body))
case Lambda(name, params, body): return Lambda(name, params, _transform_uop(body, ctx) if isinstance(body, UOp) else tuple(_transform_stmt(s, ctx) for s in body))
case Return(v): return Return(_transform_uop(v, ctx))
case UOp(): return _transform_uop(stmt, ctx)
case _: return stmt
def parse_transform(pcode: str) -> tuple:
  """Parse `pcode` and transform each top-level statement into typed UOps, returning them as a tuple.

  A fresh dtype table is created per call and shared by all statements. It is pre-seeded with predefined names:
  SCC is bool; VCC, EXEC, VDATA, SDATA and ADDR are uint64; VDST, the rounding-mode names, HW_REGISTERS and the
  register files SGPR/VGPR (uint32 arrays) are uint32."""
  ctx: dict[str, DType] = {'SCC': dtypes.bool}
  ctx.update(dict.fromkeys(('VCC', 'EXEC', 'VDATA', 'SDATA', 'ADDR'), dtypes.uint64))
  ctx.update(dict.fromkeys(('VDST', 'ROUND_MODE', 'ROUND_TOWARD_ZERO', 'HW_REGISTERS', 'SGPR', 'VGPR'), dtypes.uint32))
  return tuple(_transform_stmt(stmt, ctx) for stmt in parse(pcode))