diff --git a/extra/assembly/amd/pcode_transform.py b/extra/assembly/amd/pcode_transform.py
index 17c2cfca48..7548a0f02c 100644
--- a/extra/assembly/amd/pcode_transform.py
+++ b/extra/assembly/amd/pcode_transform.py
@@ -90,11 +90,31 @@ def _prop_assign(ctx, lhs, rhs):
   ctx[name] = rhs.dtype
   return UOp(Ops.ASSIGN, rhs.dtype, (UOp(Ops.DEFINE_VAR, rhs.dtype, arg=lhs.arg), rhs))
 
-# Dtype propagation for void-typed ops
+# Dtype propagation for void-typed ops (forward propagation)
 def _prop_binop(l, r, __OP__, **kw):
-  dt = l.dtype if l.dtype != dtypes.void else r.dtype
+  # For SHL/SHR, result type comes from left operand (right operand only when left is void)
+  if __OP__.op in {Ops.SHL, Ops.SHR}:
+    dt = l.dtype if l.dtype != dtypes.void else r.dtype
+  # Use larger dtype if both are typed (left wins ties), otherwise first non-void
+  elif l.dtype != dtypes.void and r.dtype != dtypes.void:
+    dt = l.dtype if l.dtype.itemsize >= r.dtype.itemsize else r.dtype
+  else:
+    dt = l.dtype if l.dtype != dtypes.void else r.dtype
   return UOp(__OP__.op, dt, (l, r), kw.get('arg')) if dt != dtypes.void else None
 
+# Back-propagate type to void DEFINE_VAR source
+def _backprop_binop(ctx, op, void_var, typed_src):
+  # void_var is void DEFINE_VAR, typed_src is typed - propagate type to void_var
+  dt = typed_src.dtype
+  name = void_var.arg[0] if isinstance(void_var.arg, tuple) else void_var.arg
+  if ctx is not None:
+    if name in ctx: assert ctx[name] == dt, f"variable '{name}' has conflicting types: {ctx[name]} vs {dt}"
+    else: ctx[name] = dt
+  new_var = UOp(Ops.DEFINE_VAR, dt, arg=void_var.arg)
+  # maintain original order
+  new_srcs = (new_var, typed_src) if op.src[0] is void_var else (typed_src, new_var)
+  return UOp(op.op, op.dtype, new_srcs, op.arg)
+
 def _prop_unop(x, __OP__, **kw):
   return UOp(__OP__.op, x.dtype, (x,), kw.get('arg')) if x.dtype != dtypes.void else None
 
@@ -111,20 +131,28 @@ def _prop_cat(x):
   return UOp(Ops.CAT, dt, x.src, x.arg) if dt != dtypes.void else None
 
 def _prop_customi(base, hi, lo, **kw):
-  if hi is lo: # array element access
-    if base.dtype == dtypes.void: return None
-    dt = base.dtype.scalar() if base.dtype.count > 1 else base.dtype
-  else: # slice - infer from bounds
-    dt = dtypes.uint64 if hi.op == Ops.CONST and lo.op == Ops.CONST and abs(int(hi.arg) - int(lo.arg)) + 1 > 32 else dtypes.uint32
+  if hi is lo: # array element access - use base type (register files like SGPR/VGPR are uint32)
+    dt = base.dtype if base.dtype != dtypes.void else dtypes.uint32
+  elif hi.op == Ops.CONST and lo.op == Ops.CONST: # slice with const bounds
+    dt = dtypes.uint64 if abs(int(hi.arg) - int(lo.arg)) + 1 > 32 else dtypes.uint32
+  else: # slice with variable bounds - assume uint32
+    dt = dtypes.uint32
   return UOp(Ops.CUSTOMI, dt, (base, hi, lo), kw.get('arg'))
 
+_PASSTHROUGH_FNS = {'abs', 'cvtToQuietNAN'} # these preserve input type
+
 def _prop_custom(x):
   if x.arg in _BOOL_FNS: dt = dtypes.bool
   elif x.arg in _U32_FNS: dt = dtypes.uint32
   elif x.arg in _CVT_FNS: dt = _CVT_FNS[x.arg]
   elif x.arg == 'trig_preop_result': dt = dtypes.float64
+  elif x.arg == 'ConvertFromFormat': dt = dtypes.uint32 # format conversion returns uint32
+  elif x.arg == 'nop': dt = dtypes.uint32 # nop is a no-op
+  elif x.arg == 'MEM': return None # MEM gets type from BITCAST
+  elif x.arg in _PASSTHROUGH_FNS: return None # these get type from source, handled by CAST wrapper
   else: dt = _first_nonvoid(*x.src) if x.src else dtypes.void
-  return UOp(Ops.CUSTOM, dt, x.src, x.arg) if dt != dtypes.void else None
+  assert dt != dtypes.void, f"cannot infer type for CUSTOM op '{x.arg}'"
+  return UOp(Ops.CUSTOM, dt, x.src, x.arg)
 
 # ═══════════════════════════════════════════════════════════════════════════════
 # PATTERN MATCHER
@@ -165,6 +193,16 @@ pcode_pm = PatternMatcher([
   (UPat(Ops.CUSTOM, arg='i32_to_i16', src=(UPat.var('x', dtype=dtypes.int32),)),
    lambda x: UOp(Ops.CAST, dtypes.int16, (UOp(Ops.AND, dtypes.uint32, (UOp(Ops.CAST, dtypes.uint32, (x,)), UOp.const(dtypes.uint32, 0xffff))),))),
 ]) + PatternMatcher([
+  # Math constants
+  (UPat(Ops.DEFINE_VAR, arg=('PI', None, None)), lambda: UOp.const(dtypes.float64, 3.141592653589793)),
+  (UPat(Ops.DEFINE_VAR, arg=('INF', None, None)), lambda: UOp.const(dtypes.float64, float('inf'))),
+  # Float special values
+  (UPat(Ops.DEFINE_VAR, arg=('MAX_FLOAT_F32', None, None)), lambda: UOp.const(dtypes.float32, 3.4028235e+38)),
+  (UPat(Ops.DEFINE_VAR, arg=('MAX_FLOAT_F64', None, None)), lambda: UOp.const(dtypes.float64, 1.7976931348623157e+308)),
+  (UPat(Ops.DEFINE_VAR, arg=('OVERFLOW_F32', None, None)), lambda: UOp.const(dtypes.float32, float('inf'))),
+  (UPat(Ops.DEFINE_VAR, arg=('OVERFLOW_F64', None, None)), lambda: UOp.const(dtypes.float64, float('inf'))),
+  (UPat(Ops.DEFINE_VAR, arg=('UNDERFLOW_F32', None, None)), lambda: UOp.const(dtypes.float32, 0.0)),
+  (UPat(Ops.DEFINE_VAR, arg=('UNDERFLOW_F64', None, None)), lambda: UOp.const(dtypes.float64, 0.0)),
   # Variable type tracking and propagation
   (UPat(Ops.DEFINE_VAR, name='u'), _track_var),
   (UPat(Ops.DEFINE_VAR, dtype=dtypes.void, name='u'), _prop_var),
@@ -194,69 +232,42 @@ pcode_pm = PatternMatcher([
   # Fix WHERE with non-bool condition: cast int condition to bool (test != 0)
   (UPat(Ops.WHERE, src=(UPat.var('c', dtype=dtypes.ints), UPat.var('t'), UPat.var('f'))),
    lambda c, t, f: UOp(Ops.WHERE, t.dtype if t.dtype != dtypes.void else f.dtype, (UOp(Ops.CMPNE, dtypes.bool, (c, UOp.const(c.dtype, 0))), t, f))),
+  # Fix logical AND/OR with bool and int: convert int to bool (!= 0), keeping the matched op (AND stays AND, OR stays OR)
+  (UPat((Ops.AND, Ops.OR), src=(UPat.var('x', dtype=dtypes.bool), UPat.var('y', dtype=dtypes.ints)), name='op'),
+   lambda op, x, y: UOp(op.op, dtypes.bool, (x, UOp(Ops.CMPNE, dtypes.bool, (y, UOp.const(y.dtype, 0)))))),
+  (UPat((Ops.AND, Ops.OR), src=(UPat.var('x', dtype=dtypes.ints), UPat.var('y', dtype=dtypes.bool)), name='op'),
+   lambda op, x, y: UOp(op.op, dtypes.bool, (UOp(Ops.CMPNE, dtypes.bool, (x, UOp.const(x.dtype, 0))), y))),
   # Fix binary op type mismatches: cast smaller to larger (excluding POW which allows int exponent)
-  (UPat((Ops.ADD, Ops.SUB, Ops.MUL, Ops.FDIV), src=(UPat.var('x'), UPat.var('y')), name='op'),
+  (UPat((Ops.ADD, Ops.SUB, Ops.MUL, Ops.FDIV, Ops.AND, Ops.OR, Ops.XOR), src=(UPat.var('x'), UPat.var('y')), name='op'),
    lambda op, x, y: UOp(op.op, op.dtype, (x, UOp(Ops.CAST, x.dtype, (y,)))) if x.dtype != dtypes.void and y.dtype != dtypes.void and x.dtype != y.dtype and x.dtype.itemsize >= y.dtype.itemsize else None),
-  (UPat((Ops.ADD, Ops.SUB, Ops.MUL, Ops.FDIV), src=(UPat.var('x'), UPat.var('y')), name='op'),
+  (UPat((Ops.ADD, Ops.SUB, Ops.MUL, Ops.FDIV, Ops.AND, Ops.OR, Ops.XOR), src=(UPat.var('x'), UPat.var('y')), name='op'),
    lambda op, x, y: UOp(op.op, op.dtype, (UOp(Ops.CAST, y.dtype, (x,)), y)) if x.dtype != dtypes.void and y.dtype != dtypes.void and x.dtype != y.dtype and y.dtype.itemsize > x.dtype.itemsize else None),
+  # Back-propagate types to void DEFINE_VAR sources (ctx is passed as None, so the variable-type table is not updated here)
+  (UPat((Ops.ADD, Ops.SUB, Ops.MUL, Ops.FDIV, Ops.AND, Ops.OR, Ops.XOR),
+        src=(UPat(Ops.DEFINE_VAR, dtype=dtypes.void, name='v'), UPat.var('t')), name='op'),
+   lambda op, v, t: _backprop_binop(None, op, v, t) if t.dtype != dtypes.void else None),
+  (UPat((Ops.ADD, Ops.SUB, Ops.MUL, Ops.FDIV, Ops.AND, Ops.OR, Ops.XOR),
+        src=(UPat.var('t'), UPat(Ops.DEFINE_VAR, dtype=dtypes.void, name='v')), name='op'),
+   lambda op, t, v: _backprop_binop(None, op, v, t) if t.dtype != dtypes.void else None),
 ])
 
 # ═══════════════════════════════════════════════════════════════════════════════
 # PCODE SPEC (extends shared_spec with pcode-specific patterns)
 # ═══════════════════════════════════════════════════════════════════════════════
 
-def _is_numeric(dt: DType) -> bool:
-  """Check if dtype is numeric (bool, int, float, or custom bit-width type like u1, i65)"""
-  if dt == dtypes.bool or dtypes.is_int(dt) or dtypes.is_float(dt): return True
-  return dt.name[0] in ('u', 'i') and dt.name[1:].isdigit() # custom bit-width types
-
-def _check_binop(x):
-  """Binary ALU: result and sources should be compatible"""
-  for s in x.src:
-    if s.dtype == dtypes.void: continue
-    if x.dtype.base == s.dtype.base: continue
-    # Both numeric types (int/float/custom bit-width) are compatible
-    if _is_numeric(x.dtype) and _is_numeric(s.dtype): continue
-    # POW allows int exponent with float base
-    if x.op == Ops.POW and dtypes.is_float(x.dtype) and dtypes.is_int(s.dtype): continue
-    return False
-  return True
-
 pcode_spec = PatternMatcher([
-  # DEFINE_VAR for register/variable references
+  # DEFINE_VAR: pcode arg is a name string or a tuple such as (name, None, None), not (name, min, max) with int bounds
   (UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg, (str, tuple))),
-  # ASSIGN: dtype matches rhs, rhs must be typed (unless both sides are void)
+  # ASSIGN: dtype matches rhs (unless both void)
   (UPat(Ops.ASSIGN, src=(UPat.var("lhs"), UPat.var("rhs")), name="a"),
    lambda a, lhs, rhs: a.dtype == rhs.dtype and (rhs.dtype != dtypes.void or lhs.dtype == dtypes.void)),
-  # BITCAST for type views on registers
+  # Pcode-specific ops (void sources allowed - type comes from context)
   (UPat(Ops.BITCAST, src=(UPat(),)), lambda: True),
-  # CUSTOMI for slices and array access
-  (UPat(Ops.CUSTOMI, src=(UPat(), UPat(), UPat())), lambda: True),
-  # CUSTOM ops that haven't been transformed yet
-  (UPat(Ops.CUSTOM), lambda: True),
-  # CAT for bit concatenation
-  (UPat(Ops.CAT), lambda: True),
-  # MULACC: all types match (or void)
-  (UPat(Ops.MULACC, src=(UPat(), UPat(), UPat()), name="x"),
-   lambda x: all(s.dtype == x.dtype or s.dtype == dtypes.void for s in x.src)),
-  # Unary ops: result matches source (or void source)
-  (UPat((Ops.NEG, Ops.TRUNC, Ops.SQRT, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.RECIPROCAL), src=(UPat.var("x"),), name="u"),
-   lambda u, x: u.dtype == x.dtype or x.dtype == dtypes.void),
-
-  # SHL/SHR: shift amount is uint
-  (UPat((Ops.SHL, Ops.SHR), src=(UPat.var("x"), UPat.var("y")), name="a"),
-   lambda a, x, y: (a.dtype == x.dtype or x.dtype == dtypes.void) and y.dtype == dtypes.uint),
-  # Comparisons: result is bool, sources match (or void)
-  (UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ, Ops.CMPLE), dtype=dtypes.bool, src=(UPat.var("x"), UPat.var("y"))),
-   lambda x, y: x.dtype == dtypes.void or y.dtype == dtypes.void or x.dtype == y.dtype),
-  # Unary comparison (sign check)
+  (UPat((Ops.CUSTOMI, Ops.CUSTOM, Ops.CAT)), lambda: True),
+  # Unary comparison (sign check, e.g. !sign(x) parses as CMPEQ(sign(x)))
   (UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ, Ops.CMPLE), dtype=dtypes.bool, src=(UPat(),)), lambda: True),
-  # WHERE: condition is bool, t/f match result (or void)
-  (UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat.var("t"), UPat.var("f"))),
-   lambda w, t, f: (w.dtype == t.dtype or t.dtype == dtypes.void) and (w.dtype == f.dtype or f.dtype == dtypes.void)),
-  # Binary ALU ops
-  (UPat(GroupOp.ALU-{Ops.SHL, Ops.SHR, Ops.WHERE, Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ, Ops.CMPLE, Ops.NEG, Ops.TRUNC, Ops.SQRT, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.RECIPROCAL},
-        src=(UPat(), UPat()), name="x"), _check_binop),
+  # POW allows int exponent with float base
+  (UPat(Ops.POW, dtype=dtypes.floats, src=(UPat(dtype=dtypes.floats), UPat(dtype=dtypes.ints))), lambda: True),
 ]) + shared_spec
 
 # ═══════════════════════════════════════════════════════════════════════════════
@@ -278,5 +289,8 @@ def _transform_stmt(stmt, ctx: dict):
     case _: return stmt
 
 def parse_transform(pcode: str) -> tuple:
-  ctx: dict[str, DType] = {'SCC': dtypes.bool, 'VCC': dtypes.uint64, 'EXEC': dtypes.uint64}
+  ctx: dict[str, DType] = {'SCC': dtypes.bool, 'VCC': dtypes.uint64, 'EXEC': dtypes.uint64,
+                           'VDATA': dtypes.uint64, 'SDATA': dtypes.uint64, 'ADDR': dtypes.uint64, 'VDST': dtypes.uint32,
+                           'ROUND_MODE': dtypes.uint32, 'ROUND_TOWARD_ZERO': dtypes.uint32, 'HW_REGISTERS': dtypes.uint32,
+                           'SGPR': dtypes.uint32, 'VGPR': dtypes.uint32} # register files are uint32 arrays
   return tuple(_transform_stmt(s, ctx) for s in parse(pcode))