diff --git a/extra/assembly/amd/dsl2.py b/extra/assembly/amd/dsl2.py index 78891706b6..679ab84b9e 100644 --- a/extra/assembly/amd/dsl2.py +++ b/extra/assembly/amd/dsl2.py @@ -62,11 +62,8 @@ v = src[256:511] # VGPR0-255 # ══════════════════════════════════════════════════════════════ class BitField: - required_size = None def __init__(self, hi: int, lo: int): self.hi, self.lo = hi, lo - if self.required_size is not None and (hi - lo + 1) != self.required_size: - raise RuntimeError(f"wrong size field: expected {self.required_size}, got {hi - lo + 1}") def mask(self) -> int: return (1 << (self.hi - self.lo + 1)) - 1 def encode(self, val) -> int: return val def decode(self, val): return val @@ -95,26 +92,34 @@ class EnumBitField(BitField): # Typed fields # ══════════════════════════════════════════════════════════════ -class VGPRField(BitField): - required_size = 8 - def encode(self, val): - if not isinstance(val, Reg) or val.offset < 256 or val.offset >= 512: - raise RuntimeError(f"VGPRField requires VGPR, got {val}") - return val.offset - 256 - def decode(self, raw): return src[256 + raw] - class SrcField(BitField): - required_size = 9 + _valid_range = (0, 511) # inclusive _FLOAT_ENC = {0.5: 240, -0.5: 241, 1.0: 242, -1.0: 243, 2.0: 244, -2.0: 245, 4.0: 246, -4.0: 247} + + def __init__(self, hi: int, lo: int): + super().__init__(hi, lo) + expected_size = self._valid_range[1] - self._valid_range[0] + 1 + actual_size = 1 << (hi - lo + 1) + if actual_size != expected_size: + raise RuntimeError(f"{self.__class__.__name__}: field size {hi - lo + 1} bits ({actual_size}) doesn't match range {self._valid_range} ({expected_size})") + def encode(self, val): - if isinstance(val, Reg): return val.offset - if isinstance(val, float): - if val not in self._FLOAT_ENC: raise RuntimeError(f"SrcField: unsupported float constant {val}") - return self._FLOAT_ENC[val] - if isinstance(val, int) and 0 <= val <= 64: return 128 + val - if isinstance(val, int) and -16 <= val < 0: return 192 - val - raise RuntimeError(f"SrcField: invalid value {val}") - def decode(self, raw): return src[raw] + if isinstance(val, Reg): offset = val.offset + elif isinstance(val, float): + if val not in self._FLOAT_ENC: raise RuntimeError(f"unsupported float constant {val}") + offset = self._FLOAT_ENC[val] + elif isinstance(val, int) and 0 <= val <= 64: offset = 128 + val + elif isinstance(val, int) and -16 <= val < 0: offset = 192 - val + else: raise RuntimeError(f"invalid value {val}") + if not (self._valid_range[0] <= offset <= self._valid_range[1]): + raise RuntimeError(f"{self.__class__.__name__}: {val} (offset {offset}) out of range {self._valid_range}") + return offset - self._valid_range[0] + + def decode(self, raw): return src[raw + self._valid_range[0]] + +class VGPRField(SrcField): _valid_range = (256, 511) +class SGPRField(SrcField): _valid_range = (0, 127) +class SSrcField(SrcField): _valid_range = (0, 255) # ══════════════════════════════════════════════════════════════ # Inst base class