simpler spec

This commit is contained in:
George Hotz
2026-01-11 09:41:57 +09:00
parent 688f57468b
commit c028fdd83c

View File

@@ -90,11 +90,31 @@ def _prop_assign(ctx, lhs, rhs):
ctx[name] = rhs.dtype
return UOp(Ops.ASSIGN, rhs.dtype, (UOp(Ops.DEFINE_VAR, rhs.dtype, arg=lhs.arg), rhs))
# Dtype propagation for void-typed ops
# Dtype propagation for void-typed ops (forward propagation)
def _prop_binop(l, r, __OP__, **kw):
  """Forward-propagate a dtype onto a binary op from its two operands.

  l, r:    the operand nodes; only their `.dtype` is inspected.
  __OP__:  the matched op node; its `.op` is reused for the rebuilt node.
  kw:      optional extras; only `kw.get('arg')` is forwarded.

  Returns a rebuilt UOp carrying the chosen dtype, or None when the chosen
  dtype is still void (i.e. neither operand supplied a type).
  """
  is_shift = __OP__.op in {Ops.SHL, Ops.SHR}
  if not is_shift and l.dtype != dtypes.void and r.dtype != dtypes.void:
    # both operands typed: keep the wider one, the left operand wins ties
    dt = l.dtype if l.dtype.itemsize >= r.dtype.itemsize else r.dtype
  else:
    # shifts, and ops with at most one typed operand: left dtype unless it is void
    dt = l.dtype if l.dtype != dtypes.void else r.dtype
  if dt != dtypes.void:
    return UOp(__OP__.op, dt, (l, r), kw.get('arg'))
  return None
# Back-propagate type to void DEFINE_VAR source
def _backprop_binop(ctx, op, void_var, typed_src):
  """Rebuild binary `op` with its void DEFINE_VAR operand retyped to `typed_src.dtype`.

  ctx:        optional name -> dtype mapping; when not None the variable's dtype is
              recorded in it, or checked against the entry already present.
  op:         the binary node being rewritten; its op, dtype and arg are kept.
  void_var:   the DEFINE_VAR operand whose dtype is to be replaced.
  typed_src:  the other operand, which supplies the dtype.

  Returns a new UOp with the same op/dtype/arg as `op` and the operands in their
  original order.
  """
  dt = typed_src.dtype
  var_arg = void_var.arg
  # arg is either a bare name or a tuple whose first element is the name
  name = var_arg[0] if isinstance(var_arg, tuple) else var_arg
  if ctx is not None:
    if name not in ctx:
      ctx[name] = dt
    else:
      assert ctx[name] == dt, f"variable '{name}' has conflicting types: {ctx[name]} vs {dt}"
  retyped = UOp(Ops.DEFINE_VAR, dt, arg=var_arg)
  # keep the variable on whichever side it originally occupied
  if op.src[0] is void_var:
    srcs = (retyped, typed_src)
  else:
    srcs = (typed_src, retyped)
  return UOp(op.op, op.dtype, srcs, op.arg)
def _prop_unop(x, __OP__, **kw):
  """Give a unary op the dtype of its single operand `x`.

  Returns the rebuilt UOp (op taken from `__OP__.op`, arg from `kw.get('arg')`),
  or None when `x` is still void.
  """
  if x.dtype != dtypes.void:
    return UOp(__OP__.op, x.dtype, (x,), kw.get('arg'))
  return None
@@ -111,20 +131,28 @@ def _prop_cat(x):
return UOp(Ops.CAT, dt, x.src, x.arg) if dt != dtypes.void else None
# Choose a dtype for a CUSTOMI node built from (base, hi, lo) and return the rebuilt node.
# `hi is lo` is treated as a single-element access; anything else is treated as a slice.
# NOTE(review): indentation is missing in this listing and the body contains two successive
# dtype-selection chains (an if/else, then an if/elif/else) that look like an older and a newer
# revision side by side. They disagree for a void `base` (first chain returns None, second
# falls back to uint32) -- confirm which chain is live before relying on either.
def _prop_customi(base, hi, lo, **kw):
if hi is lo: # array element access
if base.dtype == dtypes.void: return None
# multi-element base dtypes are reduced to their scalar dtype
dt = base.dtype.scalar() if base.dtype.count > 1 else base.dtype
else: # slice - infer from bounds
# constant bounds spanning more than 32 bits (|hi - lo| + 1 > 32) give uint64, everything else uint32
dt = dtypes.uint64 if hi.op == Ops.CONST and lo.op == Ops.CONST and abs(int(hi.arg) - int(lo.arg)) + 1 > 32 else dtypes.uint32
if hi is lo: # array element access - use base type (register files like SGPR/VGPR are uint32)
dt = base.dtype if base.dtype != dtypes.void else dtypes.uint32
elif hi.op == Ops.CONST and lo.op == Ops.CONST: # slice with const bounds
# same width rule as above: more than 32 bits wide -> uint64
dt = dtypes.uint64 if abs(int(hi.arg) - int(lo.arg)) + 1 > 32 else dtypes.uint32
else: # slice with variable bounds - assume uint32
dt = dtypes.uint32
# arg is forwarded from kw when present (kw.get yields None otherwise)
return UOp(Ops.CUSTOMI, dt, (base, hi, lo), kw.get('arg'))
# Names of CUSTOM functions for which _prop_custom declines to pick a dtype (it returns None for them).
_PASSTHROUGH_FNS = {'abs', 'cvtToQuietNAN'} # these preserve input type
# Pick a result dtype for a CUSTOM node from its function name (x.arg) and return the rebuilt node.
# _BOOL_FNS, _U32_FNS, _CVT_FNS and _first_nonvoid are defined outside this view.
def _prop_custom(x):
if x.arg in _BOOL_FNS: dt = dtypes.bool
elif x.arg in _U32_FNS: dt = dtypes.uint32
# _CVT_FNS is indexed by function name and supplies the dtype directly
elif x.arg in _CVT_FNS: dt = _CVT_FNS[x.arg]
elif x.arg == 'trig_preop_result': dt = dtypes.float64
elif x.arg == 'ConvertFromFormat': dt = dtypes.uint32 # format conversion returns uint32
elif x.arg == 'nop': dt = dtypes.uint32 # nop is a no-op
elif x.arg == 'MEM': return None # MEM gets type from BITCAST
elif x.arg in _PASSTHROUGH_FNS: return None # these get type from source, handled by CAST wrapper
# fallback: take the dtype from the sources via _first_nonvoid, or void when there are no sources
else: dt = _first_nonvoid(*x.src) if x.src else dtypes.void
# NOTE(review): the next line already returns (None when dt is void), which would make the
# assert/return after it unreachable if all three sit at the same nesting level. This looks
# like an older tail followed by its replacement -- confirm which one is intended.
return UOp(Ops.CUSTOM, dt, x.src, x.arg) if dt != dtypes.void else None
assert dt != dtypes.void, f"cannot infer type for CUSTOM op '{x.arg}'"
return UOp(Ops.CUSTOM, dt, x.src, x.arg)
# ═══════════════════════════════════════════════════════════════════════════════
# PATTERN MATCHER
@@ -165,6 +193,16 @@ pcode_pm = PatternMatcher([
(UPat(Ops.CUSTOM, arg='i32_to_i16', src=(UPat.var('x', dtype=dtypes.int32),)),
lambda x: UOp(Ops.CAST, dtypes.int16, (UOp(Ops.AND, dtypes.uint32, (UOp(Ops.CAST, dtypes.uint32, (x,)), UOp.const(dtypes.uint32, 0xffff))),))),
]) + PatternMatcher([
# Math constants
(UPat(Ops.DEFINE_VAR, arg=('PI', None, None)), lambda: UOp.const(dtypes.float64, 3.141592653589793)),
(UPat(Ops.DEFINE_VAR, arg=('INF', None, None)), lambda: UOp.const(dtypes.float64, float('inf'))),
# Float special values
(UPat(Ops.DEFINE_VAR, arg=('MAX_FLOAT_F32', None, None)), lambda: UOp.const(dtypes.float32, 3.4028235e+38)),
(UPat(Ops.DEFINE_VAR, arg=('MAX_FLOAT_F64', None, None)), lambda: UOp.const(dtypes.float64, 1.7976931348623157e+308)),
(UPat(Ops.DEFINE_VAR, arg=('OVERFLOW_F32', None, None)), lambda: UOp.const(dtypes.float32, float('inf'))),
(UPat(Ops.DEFINE_VAR, arg=('OVERFLOW_F64', None, None)), lambda: UOp.const(dtypes.float64, float('inf'))),
(UPat(Ops.DEFINE_VAR, arg=('UNDERFLOW_F32', None, None)), lambda: UOp.const(dtypes.float32, 0.0)),
(UPat(Ops.DEFINE_VAR, arg=('UNDERFLOW_F64', None, None)), lambda: UOp.const(dtypes.float64, 0.0)),
# Variable type tracking and propagation
(UPat(Ops.DEFINE_VAR, name='u'), _track_var),
(UPat(Ops.DEFINE_VAR, dtype=dtypes.void, name='u'), _prop_var),
@@ -194,69 +232,42 @@ pcode_pm = PatternMatcher([
# Fix WHERE with non-bool condition: cast int condition to bool (test != 0)
(UPat(Ops.WHERE, src=(UPat.var('c', dtype=dtypes.ints), UPat.var('t'), UPat.var('f'))),
lambda c, t, f: UOp(Ops.WHERE, t.dtype if t.dtype != dtypes.void else f.dtype, (UOp(Ops.CMPNE, dtypes.bool, (c, UOp.const(c.dtype, 0))), t, f))),
# Fix logical AND/OR with bool and int: convert int to bool (!= 0)
(UPat((Ops.AND, Ops.OR), src=(UPat.var('x', dtype=dtypes.bool), UPat.var('y', dtype=dtypes.ints))),
lambda x, y: UOp(Ops.AND, dtypes.bool, (x, UOp(Ops.CMPNE, dtypes.bool, (y, UOp.const(y.dtype, 0)))))),
(UPat((Ops.AND, Ops.OR), src=(UPat.var('x', dtype=dtypes.ints), UPat.var('y', dtype=dtypes.bool))),
lambda x, y: UOp(Ops.AND, dtypes.bool, (UOp(Ops.CMPNE, dtypes.bool, (x, UOp.const(x.dtype, 0))), y))),
# Fix binary op type mismatches: cast smaller to larger (excluding POW which allows int exponent)
(UPat((Ops.ADD, Ops.SUB, Ops.MUL, Ops.FDIV), src=(UPat.var('x'), UPat.var('y')), name='op'),
(UPat((Ops.ADD, Ops.SUB, Ops.MUL, Ops.FDIV, Ops.AND, Ops.OR, Ops.XOR), src=(UPat.var('x'), UPat.var('y')), name='op'),
lambda op, x, y: UOp(op.op, op.dtype, (x, UOp(Ops.CAST, x.dtype, (y,)))) if x.dtype != dtypes.void and y.dtype != dtypes.void and x.dtype != y.dtype and x.dtype.itemsize >= y.dtype.itemsize else None),
(UPat((Ops.ADD, Ops.SUB, Ops.MUL, Ops.FDIV), src=(UPat.var('x'), UPat.var('y')), name='op'),
(UPat((Ops.ADD, Ops.SUB, Ops.MUL, Ops.FDIV, Ops.AND, Ops.OR, Ops.XOR), src=(UPat.var('x'), UPat.var('y')), name='op'),
lambda op, x, y: UOp(op.op, op.dtype, (UOp(Ops.CAST, y.dtype, (x,)), y)) if x.dtype != dtypes.void and y.dtype != dtypes.void and x.dtype != y.dtype and y.dtype.itemsize > x.dtype.itemsize else None),
# Back-propagate types to void DEFINE_VAR sources
(UPat((Ops.ADD, Ops.SUB, Ops.MUL, Ops.FDIV, Ops.AND, Ops.OR, Ops.XOR),
src=(UPat(Ops.DEFINE_VAR, dtype=dtypes.void, name='v'), UPat.var('t')), name='op'),
lambda op, v, t: _backprop_binop(None, op, v, t) if t.dtype != dtypes.void else None),
(UPat((Ops.ADD, Ops.SUB, Ops.MUL, Ops.FDIV, Ops.AND, Ops.OR, Ops.XOR),
src=(UPat.var('t'), UPat(Ops.DEFINE_VAR, dtype=dtypes.void, name='v')), name='op'),
lambda op, t, v: _backprop_binop(None, op, v, t) if t.dtype != dtypes.void else None),
])
# ═══════════════════════════════════════════════════════════════════════════════
# PCODE SPEC (extends shared_spec with pcode-specific patterns)
# ═══════════════════════════════════════════════════════════════════════════════
def _is_numeric(dt: DType) -> bool:
  """Return True for bool, int and float dtypes, and for custom bit-width dtypes.

  A custom bit-width dtype is recognised purely by name: a leading 'u' or 'i'
  followed only by digits (e.g. u1, i65).
  """
  if dt == dtypes.bool: return True
  if dtypes.is_int(dt): return True
  if dtypes.is_float(dt): return True
  # custom bit-width types
  prefix, width = dt.name[0], dt.name[1:]
  return prefix in ('u', 'i') and width.isdigit()
def _check_binop(x):
  """Binary ALU check: every source dtype must be compatible with the result dtype of `x`.

  Returns True when all sources pass, False as soon as one does not.
  """
  def _compatible(src_dt):
    # untyped sources are always accepted
    if src_dt == dtypes.void: return True
    # same base dtype
    if x.dtype.base == src_dt.base: return True
    # any two numeric types (int/float/custom bit-width) are accepted together
    if _is_numeric(x.dtype) and _is_numeric(src_dt): return True
    # POW allows int exponent with float base
    return x.op == Ops.POW and dtypes.is_float(x.dtype) and dtypes.is_int(src_dt)
  return all(_compatible(s.dtype) for s in x.src)
# pcode_spec: a PatternMatcher built from (pattern, predicate) pairs and then concatenated with
# shared_spec (defined outside this view). Each predicate returns a bool for the matched node.
# NOTE(review): several entries and comments below occur in two wordings (two DEFINE_VAR and
# ASSIGN comments; CUSTOMI/CUSTOM/CAT matched one by one and again as a group; the binary-ALU
# entry next to a separate POW entry). This looks like two revisions interleaved -- confirm
# which set is intended before editing.
pcode_spec = PatternMatcher([
# DEFINE_VAR for register/variable references
# DEFINE_VAR: pcode uses string names, not (name, min, max) tuples with ints
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg, (str, tuple))),
# ASSIGN: dtype matches rhs, rhs must be typed (unless both sides are void)
# ASSIGN: dtype matches rhs (unless both void)
(UPat(Ops.ASSIGN, src=(UPat.var("lhs"), UPat.var("rhs")), name="a"),
lambda a, lhs, rhs: a.dtype == rhs.dtype and (rhs.dtype != dtypes.void or lhs.dtype == dtypes.void)),
# BITCAST for type views on registers
# Pcode-specific ops (void sources allowed - type comes from context)
(UPat(Ops.BITCAST, src=(UPat(),)), lambda: True),
# CUSTOMI for slices and array access
(UPat(Ops.CUSTOMI, src=(UPat(), UPat(), UPat())), lambda: True),
# CUSTOM ops that haven't been transformed yet
(UPat(Ops.CUSTOM), lambda: True),
# CAT for bit concatenation
(UPat(Ops.CAT), lambda: True),
# MULACC: all types match (or void)
(UPat(Ops.MULACC, src=(UPat(), UPat(), UPat()), name="x"),
lambda x: all(s.dtype == x.dtype or s.dtype == dtypes.void for s in x.src)),
# Unary ops: result matches source (or void source)
(UPat((Ops.NEG, Ops.TRUNC, Ops.SQRT, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.RECIPROCAL), src=(UPat.var("x"),), name="u"),
lambda u, x: u.dtype == x.dtype or x.dtype == dtypes.void),
# SHL/SHR: shift amount is uint
(UPat((Ops.SHL, Ops.SHR), src=(UPat.var("x"), UPat.var("y")), name="a"),
lambda a, x, y: (a.dtype == x.dtype or x.dtype == dtypes.void) and y.dtype == dtypes.uint),
# Comparisons: result is bool, sources match (or void)
(UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ, Ops.CMPLE), dtype=dtypes.bool, src=(UPat.var("x"), UPat.var("y"))),
lambda x, y: x.dtype == dtypes.void or y.dtype == dtypes.void or x.dtype == y.dtype),
# Unary comparison (sign check)
(UPat((Ops.CUSTOMI, Ops.CUSTOM, Ops.CAT)), lambda: True),
# Unary comparison (sign check, e.g. !sign(x) parses as CMPEQ(sign(x)))
(UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ, Ops.CMPLE), dtype=dtypes.bool, src=(UPat(),)), lambda: True),
# WHERE: condition is bool, t/f match result (or void)
(UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat.var("t"), UPat.var("f"))),
lambda w, t, f: (w.dtype == t.dtype or t.dtype == dtypes.void) and (w.dtype == f.dtype or f.dtype == dtypes.void)),
# Binary ALU ops
# covers every two-source op in GroupOp.ALU except the ones subtracted here; the check itself is _check_binop
(UPat(GroupOp.ALU-{Ops.SHL, Ops.SHR, Ops.WHERE, Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ, Ops.CMPLE, Ops.NEG, Ops.TRUNC, Ops.SQRT, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.RECIPROCAL},
src=(UPat(), UPat()), name="x"), _check_binop),
# POW allows int exponent with float base
(UPat(Ops.POW, dtype=dtypes.floats, src=(UPat(dtype=dtypes.floats), UPat(dtype=dtypes.ints))), lambda: True),
]) + shared_spec
# ═══════════════════════════════════════════════════════════════════════════════
@@ -278,5 +289,8 @@ def _transform_stmt(stmt, ctx: dict):
case _: return stmt
def parse_transform(pcode: str) -> tuple:
  """Parse `pcode` and pass every parsed statement through _transform_stmt.

  All statements share one name -> dtype table, pre-seeded with the entries below.
  Returns the transformed statements as a tuple, in the order parse() yields them.
  """
  ctx: dict[str, DType] = {'SCC': dtypes.bool}
  ctx.update(dict.fromkeys(('VCC', 'EXEC', 'VDATA', 'SDATA', 'ADDR'), dtypes.uint64))
  # SGPR/VGPR: register files are uint32 arrays
  ctx.update(dict.fromkeys(('VDST', 'ROUND_MODE', 'ROUND_TOWARD_ZERO', 'HW_REGISTERS', 'SGPR', 'VGPR'), dtypes.uint32))
  return tuple([_transform_stmt(stmt, ctx) for stmt in parse(pcode)])