all fields

This commit is contained in:
George Hotz
2026-01-12 11:33:03 +09:00
parent 72a203585e
commit 4cafe9e56f

View File

@@ -62,11 +62,8 @@ v = src[256:511] # VGPR0-255
# ══════════════════════════════════════════════════════════════
class BitField:
required_size = None
def __init__(self, hi: int, lo: int):
self.hi, self.lo = hi, lo
if self.required_size is not None and (hi - lo + 1) != self.required_size:
raise RuntimeError(f"wrong size field: expected {self.required_size}, got {hi - lo + 1}")
def mask(self) -> int: return (1 << (self.hi - self.lo + 1)) - 1
def encode(self, val) -> int: return val
def decode(self, val): return val
@@ -95,26 +92,34 @@ class EnumBitField(BitField):
# Typed fields
# ══════════════════════════════════════════════════════════════
class VGPRField(BitField):
required_size = 8
def encode(self, val):
if not isinstance(val, Reg) or val.offset < 256 or val.offset >= 512:
raise RuntimeError(f"VGPRField requires VGPR, got {val}")
return val.offset - 256
def decode(self, raw): return src[256 + raw]
class SrcField(BitField):
required_size = 9
_valid_range = (0, 511) # inclusive
_FLOAT_ENC = {0.5: 240, -0.5: 241, 1.0: 242, -1.0: 243, 2.0: 244, -2.0: 245, 4.0: 246, -4.0: 247}
def __init__(self, hi: int, lo: int):
super().__init__(hi, lo)
expected_size = self._valid_range[1] - self._valid_range[0] + 1
actual_size = 1 << (hi - lo + 1)
if actual_size != expected_size:
raise RuntimeError(f"{self.__class__.__name__}: field size {hi - lo + 1} bits ({actual_size}) doesn't match range {self._valid_range} ({expected_size})")
def encode(self, val):
if isinstance(val, Reg): return val.offset
if isinstance(val, float):
if val not in self._FLOAT_ENC: raise RuntimeError(f"SrcField: unsupported float constant {val}")
return self._FLOAT_ENC[val]
if isinstance(val, int) and 0 <= val <= 64: return 128 + val
if isinstance(val, int) and -16 <= val < 0: return 192 - val
raise RuntimeError(f"SrcField: invalid value {val}")
def decode(self, raw): return src[raw]
if isinstance(val, Reg): offset = val.offset
elif isinstance(val, float):
if val not in self._FLOAT_ENC: raise RuntimeError(f"unsupported float constant {val}")
offset = self._FLOAT_ENC[val]
elif isinstance(val, int) and 0 <= val <= 64: offset = 128 + val
elif isinstance(val, int) and -16 <= val < 0: offset = 192 - val
else: raise RuntimeError(f"invalid value {val}")
if not (self._valid_range[0] <= offset <= self._valid_range[1]):
raise RuntimeError(f"{self.__class__.__name__}: {val} (offset {offset}) out of range {self._valid_range}")
return offset - self._valid_range[0]
def decode(self, raw): return src[raw + self._valid_range[0]]
class VGPRField(SrcField): _valid_range = (256, 511)
class SGPRField(SrcField): _valid_range = (0, 127)
class SSrcField(SrcField): _valid_range = (0, 255)
# ══════════════════════════════════════════════════════════════
# Inst base class