use bitsize on dtype (#14011)

* use bitsize on dtype [pr]

* bitsize

* use bitsize in the js export, though this might be wrong

* reverts

* revert that
This commit is contained in:
George Hotz
2026-01-04 12:16:21 -08:00
committed by GitHub
parent cfb8bf5814
commit 7abf4591ba
5 changed files with 38 additions and 36 deletions

View File

@@ -131,7 +131,7 @@ class TestDType(unittest.TestCase):
def test_finfo(self):
if self.DTYPE not in [dtypes.float16, dtypes.float32, dtypes.float64]: return
info = np.finfo(_to_np_dtype(self.DTYPE))
self.assertEqual(info.bits, self.DTYPE.itemsize*8)
self.assertEqual(info.bits, self.DTYPE.bitsize)
self.assertEqual((info.nexp, info.nmant), dtypes.finfo(self.DTYPE))
def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):

View File

@@ -38,16 +38,18 @@ class AddrSpace(Enum):
@dataclass(frozen=True, eq=False)
class DType(metaclass=DTypeMetaClass):
priority: int # this determines when things get upcasted
itemsize: int
bitsize: int
name: str
fmt: FmtStr|None
count: int
_scalar: DType|None
@property
def itemsize(self) -> int: return (self.bitsize + 7) // 8
@staticmethod
def new(priority:int, itemsize:int, name:str, fmt:FmtStr|None): return DType(priority, itemsize, name, fmt, 1, None)
def new(priority:int, bitsize:int, name:str, fmt:FmtStr|None): return DType(priority, bitsize, name, fmt, 1, None)
def __reduce__(self): return type(self), tuple(getattr(self, f.name) for f in fields(self))
def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self.scalar().name]}"+(f".vec({self.count})" if self.count != 1 else "")
def __lt__(self, o:DType): return (self.priority, self.itemsize, self.name, self.fmt, self.count) < (o.priority, o.itemsize, o.name, o.fmt, o.count)
def __lt__(self, o:DType): return (self.priority, self.bitsize, self.name, self.fmt, self.count) < (o.priority, o.bitsize, o.name, o.fmt, o.count)
@property
def base(self): return self
@property
@@ -56,9 +58,9 @@ class DType(metaclass=DTypeMetaClass):
def vec(self, sz:int) -> DType:
assert self.count == 1, f"can't vectorize {self} with size {sz}"
if sz == 1 or self == dtypes.void: return self # void doesn't vectorize, and sz=1 is scalar
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz, self)
return DType(self.priority, self.bitsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz, self)
def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType:
return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, addrspace, 1, size)
return PtrDType(self.priority, self.bitsize, self.name, self.fmt, self.count, None, self, addrspace, 1, size)
def scalar(self) -> DType: return self._scalar if self._scalar is not None else self
def nbytes(self) -> int: raise RuntimeError("only ptr types have nbytes")
@property
@@ -79,8 +81,8 @@ class PtrDType(DType):
assert self.v == 1, f"can't vectorize ptr {self} with size {sz}"
if sz == 1: return self # sz=1 is a scalar
if isinstance(self, ImageDType):
return ImageDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size, self.shape)
return type(self)(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size)
return ImageDType(self.priority, self.bitsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size, self.shape)
return type(self)(self.priority, self.bitsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size)
def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType: raise RuntimeError("can't make a pointer from a pointer")
def nbytes(self) -> int:
if self.size == -1: raise RuntimeError("can't get nbytes of a pointer with unlimited size")
@@ -142,12 +144,12 @@ class dtypes:
@staticmethod
@functools.cache
def min(dtype:DType):
if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.scalar().itemsize*8-1)
if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.scalar().bitsize-1)
return -float("inf") if dtypes.is_float(dtype) else False
@staticmethod
@functools.cache
def max(dtype:DType):
if dtypes.is_int(dtype): return 2**(dtype.scalar().itemsize*8)-1+dtypes.min(dtype)
if dtypes.is_int(dtype): return 2**(dtype.scalar().bitsize)-1+dtypes.min(dtype)
return float("inf") if dtypes.is_float(dtype) else True
@staticmethod
def finfo(dtype:DType) -> tuple[int, int]:
@@ -158,23 +160,23 @@ class dtypes:
@staticmethod
def fields() -> dict[str, DType]: return DTYPES_DICT
void: Final[DType] = DType.new(-1, 0, "void", None)
index: Final[DType] = DType.new(-1,100, "index", None)
index: Final[DType] = DType.new(-1, 800, "index", None)
bool: Final[DType] = DType.new(0, 1, "bool", '?')
int8: Final[DType] = DType.new(1, 1, "signed char", 'b')
uint8: Final[DType] = DType.new(2, 1, "unsigned char", 'B')
int16: Final[DType] = DType.new(3, 2, "short", 'h')
uint16: Final[DType] = DType.new(4, 2, "unsigned short", 'H')
int32: Final[DType] = DType.new(5, 4, "int", 'i')
uint32: Final[DType] = DType.new(6, 4, "unsigned int", 'I')
int64: Final[DType] = DType.new(7, 8, "long", 'q')
uint64: Final[DType] = DType.new(8, 8, "unsigned long", 'Q')
fp8e4m3: Final[DType] = DType.new(9, 1, "float8_e4m3", None)
fp8e5m2: Final[DType] = DType.new(10, 1, "float8_e5m2", None)
float16: Final[DType] = DType.new(11, 2, "half", 'e')
int8: Final[DType] = DType.new(1, 8, "signed char", 'b')
uint8: Final[DType] = DType.new(2, 8, "unsigned char", 'B')
int16: Final[DType] = DType.new(3, 16, "short", 'h')
uint16: Final[DType] = DType.new(4, 16, "unsigned short", 'H')
int32: Final[DType] = DType.new(5, 32, "int", 'i')
uint32: Final[DType] = DType.new(6, 32, "unsigned int", 'I')
int64: Final[DType] = DType.new(7, 64, "long", 'q')
uint64: Final[DType] = DType.new(8, 64, "unsigned long", 'Q')
fp8e4m3: Final[DType] = DType.new(9, 8, "float8_e4m3", None)
fp8e5m2: Final[DType] = DType.new(10, 8, "float8_e5m2", None)
float16: Final[DType] = DType.new(11, 16, "half", 'e')
# bfloat16 has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16
bfloat16: Final[DType] = DType.new(12, 2, "__bf16", None)
float32: Final[DType] = DType.new(13, 4, "float", 'f')
float64: Final[DType] = DType.new(14, 8, "double", 'd')
bfloat16: Final[DType] = DType.new(12, 16, "__bf16", None)
float32: Final[DType] = DType.new(13, 32, "float", 'f')
float64: Final[DType] = DType.new(14, 64, "double", 'd')
# dtype aliases
half = float16; float = float32; double = float64 # noqa: E702
@@ -183,9 +185,9 @@ class dtypes:
# NOTE: these are image dtypes
@staticmethod
def imageh(shp, pitch=-1): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp, pitch)
def imageh(shp, pitch=-1): return ImageDType(100, 16, "imageh", 'e', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp, pitch)
@staticmethod
def imagef(shp, pitch=-1): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp, pitch)
def imagef(shp, pitch=-1): return ImageDType(100, 32, "imagef", 'f', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp, pitch)
default_float: ClassVar[DType] = float32
default_int: ClassVar[DType] = int32

View File

@@ -518,7 +518,7 @@ class AMDHIPRenderer(CStyleLanguage):
prefix.append("typedef long unsigned int size_t;")
ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]]
ocml_ops = {Ops.EXP2: ("exp2", "pure"), Ops.LOG2: ("log2", "pure"), Ops.SQRT: ("sqrt", "const"), Ops.SIN: ("sin", ""), Ops.TRUNC: ("trunc", "")}
ocml = [(f"__ocml_{ocml_ops[op][0]}_f{dt.itemsize * 8}", dt.name, dt.name, ocml_ops[op][1])
ocml = [(f"__ocml_{ocml_ops[op][0]}_f{dt.bitsize}", dt.name, dt.name, ocml_ops[op][1])
for op, dt in dedup((u.op, u.dtype.scalar()) for u in uops) if op in ocml_ops and dt in (dtypes.half, dtypes.float, dtypes.double)]
if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("typedef unsigned short hip_bfloat16;")
if any(dt.scalar() == dtypes.half for dt in used_dtypes): prefix.append("#define half _Float16")

View File

@@ -12,7 +12,7 @@ def nsrc(d:mesa.nir_def) -> mesa.nir_src: return mesa.nir_src(ssa=ctypes.pointer
def glsl_type(t:DType): return mesa.glsl_array_type(glsl_type(t.base), t.size, 0).contents if isinstance(t, PtrDType) else {
**{getattr(dtypes,k):g(f"glsl_type_builtin_{v}") for k,v in [('double','double'),('float','float'),('float16','float16_t'),('bool','uint8_t')]},
**{d:g(f"glsl_type_builtin_{'u' * (d in dtypes.uints)}int{str(d.itemsize*8)+'_t' if d.itemsize != 4 else ''}") for d in dtypes.ints}}[t]
**{d:g(f"glsl_type_builtin_{'u' * (d in dtypes.uints)}int{str(d.bitsize)+'_t' if d.itemsize != 4 else ''}") for d in dtypes.ints}}[t]
# alu ops, aop[<dtype>][<op>]
u_aop = { Ops.ADD: "iadd", Ops.MUL: "imul", Ops.IDIV: "udiv", Ops.MOD: "umod", Ops.CMPLT: "ult", Ops.CMPNE: "ine", Ops.CMPEQ: "ieq", Ops.OR: "ior",
@@ -26,7 +26,7 @@ def c(t:DType, u:bool=True) -> str: return "u" if t in dtypes.uints and u else (
def ncast(b:mesa.nir_builder, src:mesa.nir_def, it:DType, ot:DType) -> mesa.nir_def:
if isinstance(it, PtrDType) and ot == dtypes.long: return src
if ot == dtypes.bool: return nalu(b, c(it, False)+'ne'+('u' if c(it) == 'f' else ''), src, nimm(b, 0, it))
return nalu(b, f"{c(it)}2{c(it) if it in dtypes.ints and ot in dtypes.ints else c(ot, ot == dtypes.bool)}{ot.itemsize*8}", src)
return nalu(b, f"{c(it)}2{c(it) if it in dtypes.ints and ot in dtypes.ints else c(ot, ot == dtypes.bool)}{ot.bitsize}", src)
def nif(b:mesa.nir_builder, cond:mesa.nir_def, then_fn:Callable, else_fn:Callable):
nif = mesa.nir_push_if(b, cond)
@@ -71,12 +71,12 @@ def nimm_set(imm:mesa.nir_def, x, dtype:DType):
instr = ctypes.cast(imm.parent_instr, ctypes.POINTER(mesa.nir_load_const_instr))
struct.pack_into(unwrap(dtype.fmt), (ctypes.c_ubyte * dtype.itemsize).from_address(ctypes.addressof(instr.contents.value)), 0, x)
@nir_instr(nc=1, bs=lambda dtype: 1 if dtype == dtypes.bool else dtype.itemsize * 8)
@nir_instr(nc=1, bs=lambda dtype: dtype.bitsize)
def nimm(b:mesa.nir_builder, x, dtype:DType) -> mesa.nir_def:
nimm_set(getattr((instr:=mesa.nir_load_const_instr_create(b.shader, 1, 1 if dtype==dtypes.bool else dtype.itemsize * 8)).contents, "def"), x, dtype)
nimm_set(getattr((instr:=mesa.nir_load_const_instr_create(b.shader, 1, dtype.bitsize)).contents, "def"), x, dtype)
return instr
@nir_instr(nc=1, bs=lambda dtype: 1 if dtype == dtypes.bool else dtype.itemsize * 8)
def nundef(b, dtype): return mesa.nir_undef_instr_create(b.shader, 1, 1 if dtype == dtypes.bool else dtype.itemsize * 8)
@nir_instr(nc=1, bs=lambda dtype: dtype.bitsize)
def nundef(b, dtype): return mesa.nir_undef_instr_create(b.shader, 1, dtype.bitsize)
deref_var = nir_instr(nc=1, bs=32, modes=lambda var:var.data.mode, type=lambda var:var.type, var=lambda var:ctypes.pointer(var))( # pylint: disable=W0108
lambda b, var: mesa.nir_deref_instr_create(b.shader, mesa.nir_deref_type_var))
@@ -86,7 +86,7 @@ def scope(space): return 'global' if space == AddrSpace.GLOBAL else ('shared' if
nstore = nir_instr(has_def=False, df=lambda addr:addr, intrins=lambda space,val: {"WRITE_MASK":(1<<val.num_components)-1, **iointr(space)},
num_components=lambda val:val.num_components, srcs=lambda space, addr, val: [nsrc(val), nsrc(addr)][::1 if space != AddrSpace.REG else -1])(
lambda b, space, addr, val, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}")))
nload = nir_instr(nc=lambda dtype:dtype.count, bs=lambda dtype:dtype.itemsize*8//dtype.count, num_components=lambda dtype:dtype.count,
nload = nir_instr(nc=lambda dtype:dtype.count, bs=lambda dtype:dtype.bitsize//dtype.count, num_components=lambda dtype:dtype.count,
intrins=lambda space:{**({"ACCESS":mesa.ACCESS_CAN_REORDER} if space==AddrSpace.GLOBAL else {}), **iointr(space)}, srcs=lambda addr: [nsrc(addr)])(
lambda b, space, addr, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))

View File

@@ -604,7 +604,7 @@ class Tensor(OpMixin):
bits = bits.bitcast(uint_dtype)
# only randomize the mantissa bits and set the exponent to 1
one = Tensor.ones_like(bits, device=bits.device, dtype=dtype).bitcast(uint_dtype)
bits = bits.rshift((dtype.itemsize * 8) - nmant).bitwise_or(one)
bits = bits.rshift(dtype.bitsize - nmant).bitwise_or(one)
# bitcast back to the original dtype and reshape
out = bits.bitcast(dtype)[:numel].sub(1).reshape(shape).requires_grad_(kwargs.get("requires_grad"))
return out.contiguous() if contiguous else out