diff --git a/test/test_dtype.py b/test/test_dtype.py index b1eb663fa5..2104b67112 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -131,7 +131,7 @@ class TestDType(unittest.TestCase): def test_finfo(self): if self.DTYPE not in [dtypes.float16, dtypes.float32, dtypes.float64]: return info = np.finfo(_to_np_dtype(self.DTYPE)) - self.assertEqual(info.bits, self.DTYPE.itemsize*8) + self.assertEqual(info.bits, self.DTYPE.bitsize) self.assertEqual((info.nexp, info.nmant), dtypes.finfo(self.DTYPE)) def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None): diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index e539c61738..9a79696c10 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -38,16 +38,18 @@ class AddrSpace(Enum): @dataclass(frozen=True, eq=False) class DType(metaclass=DTypeMetaClass): priority: int # this determines when things get upcasted - itemsize: int + bitsize: int name: str fmt: FmtStr|None count: int _scalar: DType|None + @property + def itemsize(self) -> int: return (self.bitsize + 7) // 8 @staticmethod - def new(priority:int, itemsize:int, name:str, fmt:FmtStr|None): return DType(priority, itemsize, name, fmt, 1, None) + def new(priority:int, bitsize:int, name:str, fmt:FmtStr|None): return DType(priority, bitsize, name, fmt, 1, None) def __reduce__(self): return type(self), tuple(getattr(self, f.name) for f in fields(self)) def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self.scalar().name]}"+(f".vec({self.count})" if self.count != 1 else "") - def __lt__(self, o:DType): return (self.priority, self.itemsize, self.name, self.fmt, self.count) < (o.priority, o.itemsize, o.name, o.fmt, o.count) + def __lt__(self, o:DType): return (self.priority, self.bitsize, self.name, self.fmt, self.count) < (o.priority, o.bitsize, o.name, o.fmt, o.count) @property def base(self): return self @property @@ -56,9 +58,9 @@ class DType(metaclass=DTypeMetaClass): def vec(self, sz:int) -> DType: assert self.count == 1, f"can't vectorize {self} with size {sz}" if sz == 1 or self == dtypes.void: return self # void doesn't vectorize, and sz=1 is scalar - return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz, self) + return DType(self.priority, self.bitsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz, self) def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType: - return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, addrspace, 1, size) + return PtrDType(self.priority, self.bitsize, self.name, self.fmt, self.count, None, self, addrspace, 1, size) def scalar(self) -> DType: return self._scalar if self._scalar is not None else self def nbytes(self) -> int: raise RuntimeError("only ptr types have nbytes") @property @@ -79,8 +81,8 @@ class PtrDType(DType): assert self.v == 1, f"can't vectorize ptr {self} with size {sz}" if sz == 1: return self # sz=1 is a scalar if isinstance(self, ImageDType): - return ImageDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size, self.shape) - return type(self)(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size) + return ImageDType(self.priority, self.bitsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size, self.shape) + return type(self)(self.priority, self.bitsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size) def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType: raise RuntimeError("can't make a pointer from a pointer") def nbytes(self) -> int: if self.size == -1: raise RuntimeError("can't get nbytes of a pointer with unlimited size") @@ -142,12 +144,12 @@ class dtypes: @staticmethod @functools.cache def min(dtype:DType): - if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.scalar().itemsize*8-1) + if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.scalar().bitsize-1) return -float("inf") if dtypes.is_float(dtype) else False @staticmethod @functools.cache def max(dtype:DType): - if dtypes.is_int(dtype): return 2**(dtype.scalar().itemsize*8)-1+dtypes.min(dtype) + if dtypes.is_int(dtype): return 2**(dtype.scalar().bitsize)-1+dtypes.min(dtype) return float("inf") if dtypes.is_float(dtype) else True @staticmethod def finfo(dtype:DType) -> tuple[int, int]: @@ -158,23 +160,23 @@ class dtypes: @staticmethod def fields() -> dict[str, DType]: return DTYPES_DICT void: Final[DType] = DType.new(-1, 0, "void", None) - index: Final[DType] = DType.new(-1,100, "index", None) + index: Final[DType] = DType.new(-1, 800, "index", None) bool: Final[DType] = DType.new(0, 1, "bool", '?') - int8: Final[DType] = DType.new(1, 1, "signed char", 'b') - uint8: Final[DType] = DType.new(2, 1, "unsigned char", 'B') - int16: Final[DType] = DType.new(3, 2, "short", 'h') - uint16: Final[DType] = DType.new(4, 2, "unsigned short", 'H') - int32: Final[DType] = DType.new(5, 4, "int", 'i') - uint32: Final[DType] = DType.new(6, 4, "unsigned int", 'I') - int64: Final[DType] = DType.new(7, 8, "long", 'q') - uint64: Final[DType] = DType.new(8, 8, "unsigned long", 'Q') - fp8e4m3: Final[DType] = DType.new(9, 1, "float8_e4m3", None) - fp8e5m2: Final[DType] = DType.new(10, 1, "float8_e5m2", None) - float16: Final[DType] = DType.new(11, 2, "half", 'e') + int8: Final[DType] = DType.new(1, 8, "signed char", 'b') + uint8: Final[DType] = DType.new(2, 8, "unsigned char", 'B') + int16: Final[DType] = DType.new(3, 16, "short", 'h') + uint16: Final[DType] = DType.new(4, 16, "unsigned short", 'H') + int32: Final[DType] = DType.new(5, 32, "int", 'i') + uint32: Final[DType] = DType.new(6, 32, "unsigned int", 'I') + int64: Final[DType] = DType.new(7, 64, "long", 'q') + uint64: Final[DType] = DType.new(8, 64, "unsigned long", 'Q') + fp8e4m3: Final[DType] = DType.new(9, 8, "float8_e4m3", None) + fp8e5m2: Final[DType] = DType.new(10, 8, "float8_e5m2", None) + float16: Final[DType] = DType.new(11, 16, "half", 'e') # bfloat16 has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16 - bfloat16: Final[DType] = DType.new(12, 2, "__bf16", None) - float32: Final[DType] = DType.new(13, 4, "float", 'f') - float64: Final[DType] = DType.new(14, 8, "double", 'd') + bfloat16: Final[DType] = DType.new(12, 16, "__bf16", None) + float32: Final[DType] = DType.new(13, 32, "float", 'f') + float64: Final[DType] = DType.new(14, 64, "double", 'd') # dtype aliases half = float16; float = float32; double = float64 # noqa: E702 @@ -183,9 +185,9 @@ class dtypes: # NOTE: these are image dtypes @staticmethod - def imageh(shp, pitch=-1): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp, pitch) + def imageh(shp, pitch=-1): return ImageDType(100, 16, "imageh", 'e', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp, pitch) @staticmethod - def imagef(shp, pitch=-1): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp, pitch) + def imagef(shp, pitch=-1): return ImageDType(100, 32, "imagef", 'f', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp, pitch) default_float: ClassVar[DType] = float32 default_int: ClassVar[DType] = int32 diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 6106efbf15..e704fdacce 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -518,7 +518,7 @@ class AMDHIPRenderer(CStyleLanguage): prefix.append("typedef long unsigned int size_t;") ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]] ocml_ops = {Ops.EXP2: ("exp2", "pure"), Ops.LOG2: ("log2", "pure"), Ops.SQRT: ("sqrt", "const"), Ops.SIN: ("sin", ""), Ops.TRUNC: ("trunc", "")} - ocml = [(f"__ocml_{ocml_ops[op][0]}_f{dt.itemsize * 8}", dt.name, dt.name, ocml_ops[op][1]) + ocml = [(f"__ocml_{ocml_ops[op][0]}_f{dt.bitsize}", dt.name, dt.name, ocml_ops[op][1]) for op, dt in dedup((u.op, u.dtype.scalar()) for u in uops) if op in ocml_ops and dt in (dtypes.half, dtypes.float, dtypes.double)] if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("typedef unsigned short hip_bfloat16;") if any(dt.scalar() == dtypes.half for dt in used_dtypes): prefix.append("#define half _Float16") diff --git a/tinygrad/renderer/nir.py b/tinygrad/renderer/nir.py index 3729a0ef6d..cba6a1dd25 100644 --- a/tinygrad/renderer/nir.py +++ b/tinygrad/renderer/nir.py @@ -12,7 +12,7 @@ def nsrc(d:mesa.nir_def) -> mesa.nir_src: return mesa.nir_src(ssa=ctypes.pointer def glsl_type(t:DType): return mesa.glsl_array_type(glsl_type(t.base), t.size, 0).contents if isinstance(t, PtrDType) else { **{getattr(dtypes,k):g(f"glsl_type_builtin_{v}") for k,v in [('double','double'),('float','float'),('float16','float16_t'),('bool','uint8_t')]}, - **{d:g(f"glsl_type_builtin_{'u' * (d in dtypes.uints)}int{str(d.itemsize*8)+'_t' if d.itemsize != 4 else ''}") for d in dtypes.ints}}[t] + **{d:g(f"glsl_type_builtin_{'u' * (d in dtypes.uints)}int{str(d.bitsize)+'_t' if d.itemsize != 4 else ''}") for d in dtypes.ints}}[t] # alu ops, aop[][] u_aop = { Ops.ADD: "iadd", Ops.MUL: "imul", Ops.IDIV: "udiv", Ops.MOD: "umod", Ops.CMPLT: "ult", Ops.CMPNE: "ine", Ops.CMPEQ: "ieq", Ops.OR: "ior", @@ -26,7 +26,7 @@ def c(t:DType, u:bool=True) -> str: return "u" if t in dtypes.uints and u else ( def ncast(b:mesa.nir_builder, src:mesa.nir_def, it:DType, ot:DType) -> mesa.nir_def: if isinstance(it, PtrDType) and ot == dtypes.long: return src if ot == dtypes.bool: return nalu(b, c(it, False)+'ne'+('u' if c(it) == 'f' else ''), src, nimm(b, 0, it)) - return nalu(b, f"{c(it)}2{c(it) if it in dtypes.ints and ot in dtypes.ints else c(ot, ot == dtypes.bool)}{ot.itemsize*8}", src) + return nalu(b, f"{c(it)}2{c(it) if it in dtypes.ints and ot in dtypes.ints else c(ot, ot == dtypes.bool)}{ot.bitsize}", src) def nif(b:mesa.nir_builder, cond:mesa.nir_def, then_fn:Callable, else_fn:Callable): nif = mesa.nir_push_if(b, cond) @@ -71,12 +71,12 @@ def nimm_set(imm:mesa.nir_def, x, dtype:DType): instr = ctypes.cast(imm.parent_instr, ctypes.POINTER(mesa.nir_load_const_instr)) struct.pack_into(unwrap(dtype.fmt), (ctypes.c_ubyte * dtype.itemsize).from_address(ctypes.addressof(instr.contents.value)), 0, x) -@nir_instr(nc=1, bs=lambda dtype: 1 if dtype == dtypes.bool else dtype.itemsize * 8) +@nir_instr(nc=1, bs=lambda dtype: dtype.bitsize) def nimm(b:mesa.nir_builder, x, dtype:DType) -> mesa.nir_def: - nimm_set(getattr((instr:=mesa.nir_load_const_instr_create(b.shader, 1, 1 if dtype==dtypes.bool else dtype.itemsize * 8)).contents, "def"), x, dtype) + nimm_set(getattr((instr:=mesa.nir_load_const_instr_create(b.shader, 1, dtype.bitsize)).contents, "def"), x, dtype) return instr -@nir_instr(nc=1, bs=lambda dtype: 1 if dtype == dtypes.bool else dtype.itemsize * 8) -def nundef(b, dtype): return mesa.nir_undef_instr_create(b.shader, 1, 1 if dtype == dtypes.bool else dtype.itemsize * 8) +@nir_instr(nc=1, bs=lambda dtype: dtype.bitsize) +def nundef(b, dtype): return mesa.nir_undef_instr_create(b.shader, 1, dtype.bitsize) deref_var = nir_instr(nc=1, bs=32, modes=lambda var:var.data.mode, type=lambda var:var.type, var=lambda var:ctypes.pointer(var))( # pylint: disable=W0108 lambda b, var: mesa.nir_deref_instr_create(b.shader, mesa.nir_deref_type_var)) @@ -86,7 +86,7 @@ def scope(space): return 'global' if space == AddrSpace.GLOBAL else ('shared' if nstore = nir_instr(has_def=False, df=lambda addr:addr, intrins=lambda space,val: {"WRITE_MASK":(1<