mypy emulator pre-commit passing (#15379)

* fix dict typing: widen `srcs` annotations to `dict[str, UOp | int]`

* add type: ignores

* fix pcode to emit UOps instead of ints for the VGPR bit-slice hi/lo bounds
This commit is contained in:
qazal
2026-03-20 07:44:09 +02:00
committed by GitHub
parent 87c4ec1724
commit cf6a429aaa
2 changed files with 29 additions and 24 deletions

View File

@@ -337,7 +337,7 @@ def get_pcode(op) -> str:
pcode = pcode.replace('VCC = 0x0LL', 'VCC.u64[laneId] = 0').replace('VCC = 0x1LL', 'VCC.u64[laneId] = 1')
return pcode
def parse_pcode(pcode: str, srcs: dict[str, UOp] | None = None) -> tuple[dict, list[tuple[str, UOp]]]:
def parse_pcode(pcode: str, srcs: dict[str, UOp | int] | None = None) -> tuple[dict, list[tuple[str, UOp]]]:
env: dict = srcs.copy() if srcs else {}
assigns: list[tuple[str, UOp]] = []
raw_lines = [l.strip().rstrip(';') for l in pcode.split('\n') if l.strip() and not l.strip().startswith('//')]
@@ -620,7 +620,7 @@ class _Ctx:
elif dest.startswith('VCC'): stores.extend(self.wmask(_c(VCC_LO.offset), val))
return stores
def compile_sop_pcode(self, op, srcs: dict[str, UOp], sdst_reg: UOp, sdst_size: int) -> UOp:
def compile_sop_pcode(self, op, srcs: dict[str, UOp | int], sdst_reg: UOp, sdst_size: int) -> UOp:
"""Compile a scalar instruction with dynamic destination register."""
pcode = get_pcode(op)
srcs.update({'VCC': self.rmask(_c(VCC_LO.offset)), 'EXEC': self.rexec(), 'SCC': self.rsgpr_dyn(_c(SCC.offset)),
@@ -653,7 +653,7 @@ class _Ctx:
elif dest.startswith('VGPR['): stores.append(self.vgpr.index(val[0].cast(dtypes.int)).store(val[1].cast(dtypes.uint32)))
return UOp.sink(*stores, *self.inc_pc())
def compile_vop_pcode(self, op, srcs: dict[str, UOp], lane: UOp, vdst_reg: UOp, exec_mask: UOp,
def compile_vop_pcode(self, op, srcs: dict[str, UOp | int], lane: UOp, vdst_reg: UOp, exec_mask: UOp,
opsel_dst_hi: bool | UOp = False, sdst_reg: int | None = None, clmp: int = 0,
src0_off: UOp | None = None) -> UOp:
"""Compile VOP instruction. Returns sink with stores and inc_pc."""
@@ -688,6 +688,7 @@ class _Ctx:
if clmp and int_saturate is None and any(p in op.name for p in ('_SUB_U32', '_ADD_U32', '_SUB_U16', '_ADD_U16')):
s0, s1 = srcs.get('S0'), srcs.get('S1')
if s0 is not None and s1 is not None:
assert isinstance(s0, UOp) and isinstance(s1, UOp)
a, b = (s1.cast(dtypes.uint32), s0.cast(dtypes.uint32)) if 'SUBREV' in op.name else (s0.cast(dtypes.uint32), s1.cast(dtypes.uint32))
if 'SUB' in op.name:
int_saturate = (a < b).where(_c(0), a - b) # underflow -> 0
@@ -700,14 +701,13 @@ class _Ctx:
for dest, val in assigns:
# VGPR bit-slice assignment: VGPR[lane][reg][hi:lo] = (vgpr_idx, rhs_val, hi, lo[, cond]) -> read-modify-write
if dest.startswith('VGPR[') and re.search(r'\[\d+:\d+\]', dest):
vgpr_idx, rhs_val, hi_bit, lo_bit = val[:4]
branch_cond = val[4] if len(val) > 4 else None # optional condition from if/else branch
# VGPR bit-slice: (vgpr_idx, rhs_val, hi_bit, lo_bit) - hi/lo are UOp constants
hi_bit, lo_bit = int(val[2].arg), int(val[3].arg)
width = hi_bit - lo_bit + 1
old = self.vgpr.index(vgpr_idx.cast(dtypes.int), ptr=True).load()
new_val = _set_bits(old, _val_to_bits(rhs_val), width, lo_bit).cast(dtypes.uint32)
old = self.vgpr.index(val[0].cast(dtypes.int), ptr=True).load()
new_val = _set_bits(old, _val_to_bits(val[1]), width, lo_bit).cast(dtypes.uint32)
active = _lane_active(exec_mask, lane)
if branch_cond is not None: active = active & _to_u32(branch_cond).ne(_c(0))
raw_stores.append(('vgpr_direct', self.vgpr.index(vgpr_idx.cast(dtypes.int), active).store(new_val)))
raw_stores.append(('vgpr_direct', self.vgpr.index(val[0].cast(dtypes.int), active).store(new_val)))
continue
if 'D0' in dest and '[laneId]' in dest:
old_vcc = self.rmask(_c(VCC_LO.offset))
@@ -715,13 +715,13 @@ class _Ctx:
raw_stores.extend([('vcc', s) for s in self.wmask(_c(VCC_LO.offset), new_vcc)])
elif dest.startswith('D0'):
if (slice_match := re.match(r'D0\[(\d+)\s*:\s*(\d+)\]', dest)):
hi_bit, lo_bit = int(slice_match.group(1)), int(slice_match.group(2))
if hi_bit != 31 or lo_bit != 0:
width, slice_mask = hi_bit - lo_bit + 1, (1 << (hi_bit - lo_bit + 1)) - 1
d0_hi_bit, d0_lo_bit = int(slice_match.group(1)), int(slice_match.group(2))
if d0_hi_bit != 31 or d0_lo_bit != 0:
d0_width, slice_mask = d0_hi_bit - d0_lo_bit + 1, (1 << (d0_hi_bit - d0_lo_bit + 1)) - 1
val_bits = val.bitcast(dtypes.uint16).cast(dtypes.uint32) if val.dtype == dtypes.half else \
val.cast(dtypes.uint32) if val.dtype in (dtypes.uint16, dtypes.int16) else \
val.cast(dtypes.uint32) & UOp.const(dtypes.uint32, slice_mask)
raw_stores.append(('vgpr_slice', (lo_bit, width, val_bits)))
raw_stores.append(('vgpr_slice', (d0_lo_bit, d0_width, val_bits)))
continue
# For integer ops with clamp, use pre-computed saturated value; for floats, clamp to [0,1]
if int_saturate is not None: val = int_saturate
@@ -917,8 +917,10 @@ def _compile_sdwa(inst: irc.VOP1_SDWA | irc.VOP2_SDWA | irc.VOP2_SDWA_SDST | irc
is_vopc = isinstance(inst, irc.VOPC_SDWA_SDST)
exec_mask = ctx.rexec()
# sd=1 means use sdst register, sd=0 means use VCC (for VOPC_SDWA_SDST and VOP2_SDWA_SDST)
has_sdst = isinstance(inst, (irc.VOP2_SDWA_SDST, irc.VOPC_SDWA_SDST))
sdst_off = _c(inst.sdst.offset) if has_sdst and getattr(inst, 'sd', 0) else _c(VCC_LO.offset)
if isinstance(inst, (irc.VOP2_SDWA_SDST, irc.VOPC_SDWA_SDST)):
sdst_off = _c(inst.sdst.offset) if getattr(inst, 'sd', False) else _c(VCC_LO.offset)
else:
sdst_off = _c(VCC_LO.offset)
# Read SDWA fields (these are dynamic but shared across lanes)
src0_sel = ctx.inst_field(type(inst).src0_sel)
src0_sext = ctx.inst_field(type(inst).src0_sext)
@@ -937,7 +939,7 @@ def _compile_sdwa(inst: irc.VOP1_SDWA | irc.VOP2_SDWA | irc.VOP2_SDWA_SDST | irc
s0 = _sdwa_select(s0_raw, src0_sel, src0_sext)
s1_raw = ctx.rsgpr_dyn(vsrc1_reg) if inst.s1 else ctx.rvgpr_dyn(vsrc1_reg, lc)
s1 = _sdwa_select(s1_raw, src1_sel, src1_sext)
srcs: dict[str, UOp] = {'S0': s0, 'S1': s1, 'laneId': lc}
srcs = {'S0': s0, 'S1': s1, 'laneId': lc}
for dest, val in parse_pcode(pcode, srcs)[1]:
if '[laneId]' in dest and ('D0' in dest or 'EXEC' in dest): return val.cast(dtypes.uint32)
return _c(0)
@@ -947,27 +949,27 @@ def _compile_sdwa(inst: irc.VOP1_SDWA | irc.VOP2_SDWA | irc.VOP2_SDWA_SDST | irc
# Non-VOPC path: VOP1_SDWA, VOP2_SDWA, VOP2_SDWA_SDST — uses lane loop
lane = ctx.range()
vdst_reg = ctx.inst_field(type(inst).vdst)
vdst_reg = ctx.inst_field(type(inst).vdst) # type: ignore[union-attr]
s0_raw = ctx.rsgpr_dyn(vsrc0_reg) if inst.s0 else ctx.rvgpr_dyn(vsrc0_reg, lane)
s0 = _sdwa_select(s0_raw, src0_sel, src0_sext)
if isinstance(inst, (irc.VOP2_SDWA, irc.VOP2_SDWA_SDST)):
s1_raw = ctx.rsgpr_dyn(vsrc1_reg) if inst.s1 else ctx.rvgpr_dyn(vsrc1_reg, lane)
s1 = _sdwa_select(s1_raw, src1_sel, src1_sext)
srcs: dict[str, UOp] = {'S0': s0, 'S1': s1, 'D0': ctx.rvgpr_dyn(vdst_reg, lane)}
srcs:dict[str, UOp | int] = {'S0': s0, 'S1': s1, 'D0': ctx.rvgpr_dyn(vdst_reg, lane)}
else:
srcs = {'S0': s0}
# dst_sel and dst_unused
has_dst_sel = hasattr(type(inst), 'dst_sel')
if has_dst_sel:
dst_sel = ctx.inst_field(type(inst).dst_sel)
dst_unused = ctx.inst_field(type(inst).dst_unused)
dst_sel = ctx.inst_field(type(inst).dst_sel) # type: ignore[union-attr]
dst_unused = ctx.inst_field(type(inst).dst_unused) # type: ignore[union-attr]
srcs.update({'VCC': ctx.rmask(_c(VCC_LO.offset)), 'EXEC': exec_mask, 'SCC': ctx.rsgpr_dyn(_c(SCC.offset)),
'laneId': lane, 'VDST': vdst_reg, 'ROUND_MODE': _c(0), 'ROUND_TOWARD_ZERO': _c(0),
'ROUND_NEAREST_EVEN': _c(0), '_vgpr': ctx.vgpr, '_wave_size': ctx.wave_size,
'SDWA_SRC0_SEL': _c(0), 'BYTE0': _c(0), 'BYTE1': _c(1), 'BYTE2': _c(2), 'BYTE3': _c(3),
'WORD0': _c(0), 'WORD1': _c(1)})
_, assigns = parse_pcode(pcode, srcs)
stores: list[UOp] = []
stores = []
vcc_val = None
for dest, val in assigns:
if 'D0' in dest and '[laneId]' in dest:
@@ -1020,7 +1022,7 @@ def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP2 | ir4.VOP1 | ir4.VO
src0_reg = src0_hi.where(src0_off - _c(384), _c(0))
s0 = src0_hi.where(_hi16(ctx.rvgpr_dyn(src0_reg, lane)), s0)
d0 = _cond_hi16(write_hi_half, ctx.rvgpr_dyn(vdst_reg, lane))
srcs = {'S0': s0, 'D0': d0}
srcs:dict[str, UOp | int] = {'S0': s0, 'D0': d0}
else:
vsrc1_reg = ctx.inst_field(type(inst).vsrc1)
vsrc1_hi = bits['s0'] == 16 and (vsrc1_reg >= _c(128))
@@ -1256,6 +1258,7 @@ def _compile_mfma(inst: irc.VOP3P, ctx: _Ctx) -> UOp:
src1_is_vgpr = src1_off >= _c(256)
m = _re.search(r'(\d+)X(\d+)X(\d+)', op_name)
if m is None: raise ValueError(f"could not parse MFMA dimensions from {op_name}")
M, N, K = int(m.group(1)), int(m.group(2)), int(m.group(3))
is_bf16 = 'BF16' in op_name
@@ -1545,7 +1548,7 @@ def _compile_vop3p(inst: ir3.VOP3P | ir4.VOP3P | irc.VOP3P, ctx: _Ctx) -> UOp:
is_pk_f32 = 'PK' in op_name and 'F32' in op_name and 'MOV' not in op_name # CDNA packed F32 ops
is_pk_mov_b32 = 'PK_MOV_B32' in op_name # CDNA packed MOV needs special handling
do_cast = any(x in op_name for x in ('F16', 'F32', 'BF16')) and 'IU' not in op_name and not is_pk_f32
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None # type: ignore[union-attr]
src0 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src0), lane, 16, literal=literal, do_cast=do_cast)
src1 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src1), lane, 16, literal=literal, do_cast=do_cast)
src2 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src2), lane, 16, literal=literal, do_cast=do_cast)
@@ -1571,6 +1574,7 @@ def _compile_vop3p(inst: ir3.VOP3P | ir4.VOP3P | irc.VOP3P, ctx: _Ctx) -> UOp:
stores = [ctx.wvgpr_dyn(vdst_reg, lane, lo_out, exec_mask), ctx.wvgpr_dyn(vdst_reg + _c(1), lane, hi_out, exec_mask)]
return UOp.sink(UOp.group(*stores).end(lane), *ctx.inc_pc())
srcs: dict[str, UOp | int] = {}
if is_pk_f32:
# CDNA packed F32: read 32-bit sources, build 64-bit packed values using opsel.
# For VGPRs: opsel selects between v[reg] (0) and v[reg+1] (1) for each half.
@@ -1648,6 +1652,7 @@ def _compile_vopd(inst: ir3.VOPD | ir4.VOPD, ctx: _Ctx) -> UOp:
lane = ctx.range()
srcy0, srcy1 = ctx.rsrc_dyn(srcy0_off, lane, literal=literal), ctx.rvgpr_dyn(vsrcy1_reg, lane)
all_stores = []
srcs:dict[str, UOp | int] = {}
for op, src0_off, vsrc1_reg, vdst_reg, label in [(inst.opx, srcx0_off, vsrcx1_reg, vdstx_reg, 'X'),
(inst.opy, srcy0_off, vsrcy1_reg, vdsty_reg, 'Y')]:
vop = VOPD_TO_VOP2.get(op)

View File

@@ -1073,7 +1073,7 @@ def parse_block(lines: list[str], start: int, env: dict[str, VarVal], funcs: dic
ws = env.get('_wave_size', 32)
vgpr_idx = _to_u32(rg) * _u32(ws) + _to_u32(ln)
if assigns is not None:
assigns.append((f'VGPR[{_tok_str(lane_toks)}][{_tok_str(reg_toks)}][{hi_val}:{lo_val}]', (vgpr_idx, val, hi_val, lo_val)))
assigns.append((f'VGPR[{_tok_str(lane_toks)}][{_tok_str(reg_toks)}][{hi_val}:{lo_val}]', (vgpr_idx, val, _u32(hi_val), _u32(lo_val))))
i += 1
continue
if j < len(toks) and toks[j].type == 'DOT': j += 2 # skip .type suffix