From cf6a429aaa8126165de0cdfc143bd9204a7515e8 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Fri, 20 Mar 2026 07:44:09 +0200 Subject: [PATCH] mypy emulator pre-commit passing (#15379) * fix dict stuff * add type: ignores * fix pcode to put uops not ints --- test/mockgpu/amd/emu.py | 51 +++++++++++++++++++++------------------ test/mockgpu/amd/pcode.py | 2 +- 2 files changed, 29 insertions(+), 24 deletions(-) diff --git a/test/mockgpu/amd/emu.py b/test/mockgpu/amd/emu.py index 548a4525ad..ff99313ce9 100644 --- a/test/mockgpu/amd/emu.py +++ b/test/mockgpu/amd/emu.py @@ -337,7 +337,7 @@ def get_pcode(op) -> str: pcode = pcode.replace('VCC = 0x0LL', 'VCC.u64[laneId] = 0').replace('VCC = 0x1LL', 'VCC.u64[laneId] = 1') return pcode -def parse_pcode(pcode: str, srcs: dict[str, UOp] | None = None) -> tuple[dict, list[tuple[str, UOp]]]: +def parse_pcode(pcode: str, srcs: dict[str, UOp | int] | None = None) -> tuple[dict, list[tuple[str, UOp]]]: env: dict = srcs.copy() if srcs else {} assigns: list[tuple[str, UOp]] = [] raw_lines = [l.strip().rstrip(';') for l in pcode.split('\n') if l.strip() and not l.strip().startswith('//')] @@ -620,7 +620,7 @@ class _Ctx: elif dest.startswith('VCC'): stores.extend(self.wmask(_c(VCC_LO.offset), val)) return stores - def compile_sop_pcode(self, op, srcs: dict[str, UOp], sdst_reg: UOp, sdst_size: int) -> UOp: + def compile_sop_pcode(self, op, srcs: dict[str, UOp | int], sdst_reg: UOp, sdst_size: int) -> UOp: """Compile a scalar instruction with dynamic destination register.""" pcode = get_pcode(op) srcs.update({'VCC': self.rmask(_c(VCC_LO.offset)), 'EXEC': self.rexec(), 'SCC': self.rsgpr_dyn(_c(SCC.offset)), @@ -653,7 +653,7 @@ class _Ctx: elif dest.startswith('VGPR['): stores.append(self.vgpr.index(val[0].cast(dtypes.int)).store(val[1].cast(dtypes.uint32))) return UOp.sink(*stores, *self.inc_pc()) - def compile_vop_pcode(self, op, srcs: dict[str, UOp], lane: UOp, vdst_reg: UOp, exec_mask: UOp, + def compile_vop_pcode(self, op, srcs: dict[str, UOp | int], lane: UOp, vdst_reg: UOp, exec_mask: UOp, opsel_dst_hi: bool | UOp = False, sdst_reg: int | None = None, clmp: int = 0, src0_off: UOp | None = None) -> UOp: """Compile VOP instruction. Returns sink with stores and inc_pc.""" @@ -688,6 +688,7 @@ class _Ctx: if clmp and int_saturate is None and any(p in op.name for p in ('_SUB_U32', '_ADD_U32', '_SUB_U16', '_ADD_U16')): s0, s1 = srcs.get('S0'), srcs.get('S1') if s0 is not None and s1 is not None: + assert isinstance(s0, UOp) and isinstance(s1, UOp) a, b = (s1.cast(dtypes.uint32), s0.cast(dtypes.uint32)) if 'SUBREV' in op.name else (s0.cast(dtypes.uint32), s1.cast(dtypes.uint32)) if 'SUB' in op.name: int_saturate = (a < b).where(_c(0), a - b) # underflow -> 0 @@ -700,14 +701,13 @@ class _Ctx: for dest, val in assigns: # VGPR bit-slice assignment: VGPR[lane][reg][hi:lo] = (vgpr_idx, rhs_val, hi, lo[, cond]) -> read-modify-write if dest.startswith('VGPR[') and re.search(r'\[\d+:\d+\]', dest): - vgpr_idx, rhs_val, hi_bit, lo_bit = val[:4] - branch_cond = val[4] if len(val) > 4 else None # optional condition from if/else branch + # VGPR bit-slice: (vgpr_idx, rhs_val, hi_bit, lo_bit) - hi/lo are UOp constants + hi_bit, lo_bit = int(val[2].arg), int(val[3].arg) width = hi_bit - lo_bit + 1 - old = self.vgpr.index(vgpr_idx.cast(dtypes.int), ptr=True).load() - new_val = _set_bits(old, _val_to_bits(rhs_val), width, lo_bit).cast(dtypes.uint32) + old = self.vgpr.index(val[0].cast(dtypes.int), ptr=True).load() + new_val = _set_bits(old, _val_to_bits(val[1]), width, lo_bit).cast(dtypes.uint32) active = _lane_active(exec_mask, lane) - if branch_cond is not None: active = active & _to_u32(branch_cond).ne(_c(0)) - raw_stores.append(('vgpr_direct', self.vgpr.index(vgpr_idx.cast(dtypes.int), active).store(new_val))) + raw_stores.append(('vgpr_direct', self.vgpr.index(val[0].cast(dtypes.int), active).store(new_val))) continue if 'D0' in dest and '[laneId]' in dest: old_vcc = self.rmask(_c(VCC_LO.offset)) @@ -715,13 +715,13 @@ class _Ctx: raw_stores.extend([('vcc', s) for s in self.wmask(_c(VCC_LO.offset), new_vcc)]) elif dest.startswith('D0'): if (slice_match := re.match(r'D0\[(\d+)\s*:\s*(\d+)\]', dest)): - hi_bit, lo_bit = int(slice_match.group(1)), int(slice_match.group(2)) - if hi_bit != 31 or lo_bit != 0: - width, slice_mask = hi_bit - lo_bit + 1, (1 << (hi_bit - lo_bit + 1)) - 1 + d0_hi_bit, d0_lo_bit = int(slice_match.group(1)), int(slice_match.group(2)) + if d0_hi_bit != 31 or d0_lo_bit != 0: + d0_width, slice_mask = d0_hi_bit - d0_lo_bit + 1, (1 << (d0_hi_bit - d0_lo_bit + 1)) - 1 val_bits = val.bitcast(dtypes.uint16).cast(dtypes.uint32) if val.dtype == dtypes.half else \ val.cast(dtypes.uint32) if val.dtype in (dtypes.uint16, dtypes.int16) else \ val.cast(dtypes.uint32) & UOp.const(dtypes.uint32, slice_mask) - raw_stores.append(('vgpr_slice', (lo_bit, width, val_bits))) + raw_stores.append(('vgpr_slice', (d0_lo_bit, d0_width, val_bits))) continue # For integer ops with clamp, use pre-computed saturated value; for floats, clamp to [0,1] if int_saturate is not None: val = int_saturate @@ -917,8 +917,10 @@ def _compile_sdwa(inst: irc.VOP1_SDWA | irc.VOP2_SDWA | irc.VOP2_SDWA_SDST | irc is_vopc = isinstance(inst, irc.VOPC_SDWA_SDST) exec_mask = ctx.rexec() # sd=1 means use sdst register, sd=0 means use VCC (for VOPC_SDWA_SDST and VOP2_SDWA_SDST) - has_sdst = isinstance(inst, (irc.VOP2_SDWA_SDST, irc.VOPC_SDWA_SDST)) - sdst_off = _c(inst.sdst.offset) if has_sdst and getattr(inst, 'sd', 0) else _c(VCC_LO.offset) + if isinstance(inst, (irc.VOP2_SDWA_SDST, irc.VOPC_SDWA_SDST)): + sdst_off = _c(inst.sdst.offset) if getattr(inst, 'sd', False) else _c(VCC_LO.offset) + else: + sdst_off = _c(VCC_LO.offset) # Read SDWA fields (these are dynamic but shared across lanes) src0_sel = ctx.inst_field(type(inst).src0_sel) src0_sext = ctx.inst_field(type(inst).src0_sext) @@ -937,7 +939,7 @@ def _compile_sdwa(inst: irc.VOP1_SDWA | irc.VOP2_SDWA | irc.VOP2_SDWA_SDST | irc s0 = _sdwa_select(s0_raw, src0_sel, src0_sext) s1_raw = ctx.rsgpr_dyn(vsrc1_reg) if inst.s1 else ctx.rvgpr_dyn(vsrc1_reg, lc) s1 = _sdwa_select(s1_raw, src1_sel, src1_sext) - srcs: dict[str, UOp] = {'S0': s0, 'S1': s1, 'laneId': lc} + srcs = {'S0': s0, 'S1': s1, 'laneId': lc} for dest, val in parse_pcode(pcode, srcs)[1]: if '[laneId]' in dest and ('D0' in dest or 'EXEC' in dest): return val.cast(dtypes.uint32) return _c(0) @@ -947,27 +949,27 @@ def _compile_sdwa(inst: irc.VOP1_SDWA | irc.VOP2_SDWA | irc.VOP2_SDWA_SDST | irc # Non-VOPC path: VOP1_SDWA, VOP2_SDWA, VOP2_SDWA_SDST — uses lane loop lane = ctx.range() - vdst_reg = ctx.inst_field(type(inst).vdst) + vdst_reg = ctx.inst_field(type(inst).vdst) # type: ignore[union-attr] s0_raw = ctx.rsgpr_dyn(vsrc0_reg) if inst.s0 else ctx.rvgpr_dyn(vsrc0_reg, lane) s0 = _sdwa_select(s0_raw, src0_sel, src0_sext) if isinstance(inst, (irc.VOP2_SDWA, irc.VOP2_SDWA_SDST)): s1_raw = ctx.rsgpr_dyn(vsrc1_reg) if inst.s1 else ctx.rvgpr_dyn(vsrc1_reg, lane) s1 = _sdwa_select(s1_raw, src1_sel, src1_sext) - srcs: dict[str, UOp] = {'S0': s0, 'S1': s1, 'D0': ctx.rvgpr_dyn(vdst_reg, lane)} + srcs:dict[str, UOp | int] = {'S0': s0, 'S1': s1, 'D0': ctx.rvgpr_dyn(vdst_reg, lane)} else: srcs = {'S0': s0} # dst_sel and dst_unused has_dst_sel = hasattr(type(inst), 'dst_sel') if has_dst_sel: - dst_sel = ctx.inst_field(type(inst).dst_sel) - dst_unused = ctx.inst_field(type(inst).dst_unused) + dst_sel = ctx.inst_field(type(inst).dst_sel) # type: ignore[union-attr] + dst_unused = ctx.inst_field(type(inst).dst_unused) # type: ignore[union-attr] srcs.update({'VCC': ctx.rmask(_c(VCC_LO.offset)), 'EXEC': exec_mask, 'SCC': ctx.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane, 'VDST': vdst_reg, 'ROUND_MODE': _c(0), 'ROUND_TOWARD_ZERO': _c(0), 'ROUND_NEAREST_EVEN': _c(0), '_vgpr': ctx.vgpr, '_wave_size': ctx.wave_size, 'SDWA_SRC0_SEL': _c(0), 'BYTE0': _c(0), 'BYTE1': _c(1), 'BYTE2': _c(2), 'BYTE3': _c(3), 'WORD0': _c(0), 'WORD1': _c(1)}) _, assigns = parse_pcode(pcode, srcs) - stores: list[UOp] = [] + stores = [] vcc_val = None for dest, val in assigns: if 'D0' in dest and '[laneId]' in dest: @@ -1020,7 +1022,7 @@ def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP2 | ir4.VOP1 | ir4.VO src0_reg = src0_hi.where(src0_off - _c(384), _c(0)) s0 = src0_hi.where(_hi16(ctx.rvgpr_dyn(src0_reg, lane)), s0) d0 = _cond_hi16(write_hi_half, ctx.rvgpr_dyn(vdst_reg, lane)) - srcs = {'S0': s0, 'D0': d0} + srcs:dict[str, UOp | int] = {'S0': s0, 'D0': d0} else: vsrc1_reg = ctx.inst_field(type(inst).vsrc1) vsrc1_hi = bits['s0'] == 16 and (vsrc1_reg >= _c(128)) @@ -1256,6 +1258,7 @@ def _compile_mfma(inst: irc.VOP3P, ctx: _Ctx) -> UOp: src1_is_vgpr = src1_off >= _c(256) m = _re.search(r'(\d+)X(\d+)X(\d+)', op_name) + if m is None: raise ValueError(f"could not parse MFMA dimensions from {op_name}") M, N, K = int(m.group(1)), int(m.group(2)), int(m.group(3)) is_bf16 = 'BF16' in op_name @@ -1545,7 +1548,7 @@ def _compile_vop3p(inst: ir3.VOP3P | ir4.VOP3P | irc.VOP3P, ctx: _Ctx) -> UOp: is_pk_f32 = 'PK' in op_name and 'F32' in op_name and 'MOV' not in op_name # CDNA packed F32 ops is_pk_mov_b32 = 'PK_MOV_B32' in op_name # CDNA packed MOV needs special handling do_cast = any(x in op_name for x in ('F16', 'F32', 'BF16')) and 'IU' not in op_name and not is_pk_f32 - literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None + literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None # type: ignore[union-attr] src0 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src0), lane, 16, literal=literal, do_cast=do_cast) src1 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src1), lane, 16, literal=literal, do_cast=do_cast) src2 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src2), lane, 16, literal=literal, do_cast=do_cast) @@ -1571,6 +1574,7 @@ def _compile_vop3p(inst: ir3.VOP3P | ir4.VOP3P | irc.VOP3P, ctx: _Ctx) -> UOp: stores = [ctx.wvgpr_dyn(vdst_reg, lane, lo_out, exec_mask), ctx.wvgpr_dyn(vdst_reg + _c(1), lane, hi_out, exec_mask)] return UOp.sink(UOp.group(*stores).end(lane), *ctx.inc_pc()) + srcs: dict[str, UOp | int] = {} if is_pk_f32: # CDNA packed F32: read 32-bit sources, build 64-bit packed values using opsel. # For VGPRs: opsel selects between v[reg] (0) and v[reg+1] (1) for each half. @@ -1648,6 +1652,7 @@ def _compile_vopd(inst: ir3.VOPD | ir4.VOPD, ctx: _Ctx) -> UOp: lane = ctx.range() srcy0, srcy1 = ctx.rsrc_dyn(srcy0_off, lane, literal=literal), ctx.rvgpr_dyn(vsrcy1_reg, lane) all_stores = [] + srcs:dict[str, UOp | int] = {} for op, src0_off, vsrc1_reg, vdst_reg, label in [(inst.opx, srcx0_off, vsrcx1_reg, vdstx_reg, 'X'), (inst.opy, srcy0_off, vsrcy1_reg, vdsty_reg, 'Y')]: vop = VOPD_TO_VOP2.get(op) diff --git a/test/mockgpu/amd/pcode.py b/test/mockgpu/amd/pcode.py index b2e2f98308..033ba5e283 100644 --- a/test/mockgpu/amd/pcode.py +++ b/test/mockgpu/amd/pcode.py @@ -1073,7 +1073,7 @@ def parse_block(lines: list[str], start: int, env: dict[str, VarVal], funcs: dic ws = env.get('_wave_size', 32) vgpr_idx = _to_u32(rg) * _u32(ws) + _to_u32(ln) if assigns is not None: - assigns.append((f'VGPR[{_tok_str(lane_toks)}][{_tok_str(reg_toks)}][{hi_val}:{lo_val}]', (vgpr_idx, val, hi_val, lo_val))) + assigns.append((f'VGPR[{_tok_str(lane_toks)}][{_tok_str(reg_toks)}][{hi_val}:{lo_val}]', (vgpr_idx, val, _u32(hi_val), _u32(lo_val)))) i += 1 continue if j < len(toks) and toks[j].type == 'DOT': j += 2 # skip .type suffix