From cc49e47ea21aab7565c85292a21acbf5d67d6571 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 26 Jan 2026 11:30:18 +0800 Subject: [PATCH] tinygrad changes from ucode (#14336) * tinygrad changes from ucode * dtype --- test/mockgpu/amd/amdgpu.py | 11 ++++++++++- test/test_uop_graph.py | 2 +- test/unit/test_dtype_spec.py | 2 +- tinygrad/dtype.py | 9 ++++++--- tinygrad/runtime/support/compiler_cpu.py | 6 ++++-- 5 files changed, 22 insertions(+), 8 deletions(-) diff --git a/test/mockgpu/amd/amdgpu.py b/test/mockgpu/amd/amdgpu.py index 323ab1b9d9..d57791f916 100644 --- a/test/mockgpu/amd/amdgpu.py +++ b/test/mockgpu/amd/amdgpu.py @@ -9,6 +9,7 @@ SDMA_MAX_COPY_SIZE = 0x400000 regCOMPUTE_PGM_LO = 0x1bac + amd_gpu.GC_BASE__INST0_SEG0 regCOMPUTE_PGM_RSRC2 = 0x1bb3 + amd_gpu.GC_BASE__INST0_SEG0 +regCOMPUTE_TMPRING_SIZE = 0x1bb8 + amd_gpu.GC_BASE__INST0_SEG0 regCOMPUTE_USER_DATA_0 = 0x1be0 + amd_gpu.GC_BASE__INST0_SEG0 regCOMPUTE_NUM_THREAD_X = 0x1ba7 + amd_gpu.GC_BASE__INST0_SEG0 regGRBM_GFX_INDEX = 0x2200 + amd_gpu.GC_BASE__INST0_SEG1 @@ -185,10 +186,18 @@ class PM4Executor(AMDQueue): for st,sz in self.gpu.mapped_ranges: if st <= prg_addr < st+sz: prg_sz = sz - (prg_addr - st) + # Get scratch size from COMPUTE_TMPRING_SIZE register (WAVESIZE field, bits 12:25 for gfx11) + # WAVESIZE = ceildiv(64 * size_per_thread, 256), so size_per_thread ≈ WAVESIZE * 4 + try: tmpring_size = self.gpu.regs[regCOMPUTE_TMPRING_SIZE] + except KeyError: tmpring_size = 0 + wavesize = (tmpring_size >> 12) & 0x3FFF # bits 12:25 for gfx11 + scratch_size = wavesize * 4 # approximate private_segment_size per lane + assert prg_sz > 0, "Invalid prg ptr (not found in mapped ranges)" - # Pass valid memory ranges and rsrc2 to Python emulator for bounds checking and SGPR/VGPR layout + # Pass valid memory ranges, rsrc2, and scratch_size to Python emulator if hasattr(remu, 'valid_mem_ranges'): remu.valid_mem_ranges = self.gpu.mapped_ranges if hasattr(remu, 'rsrc2'): remu.rsrc2 = rsrc2 + if hasattr(remu, 'scratch_size'): remu.scratch_size = scratch_size err = remu.run_asm(prg_addr, prg_sz, *gl, *lc, args_addr) if err != 0: raise RuntimeError("remu does not support the new instruction introduced in this kernel") diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 84341f1b52..b97ea54417 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -530,7 +530,7 @@ class TestUOpGraph(unittest.TestCase): c = r + 1 self.assertIn(r, c.ranges) - e = UOp.const(dtypes.void, None).end(r) + e = UOp.const(dtypes.int, 1).end(r) self.assertNotIn(r, e.ranges) a = c.after(e) diff --git a/test/unit/test_dtype_spec.py b/test/unit/test_dtype_spec.py index 96c4e05089..7e3a904e98 100644 --- a/test/unit/test_dtype_spec.py +++ b/test/unit/test_dtype_spec.py @@ -380,7 +380,7 @@ class TestTypePromotion(unittest.TestCase): assert least_upper_dtype(dtypes.int32, dtypes.uint32) == dtypes.int64 assert least_upper_dtype(dtypes.uint32, dtypes.int64) == dtypes.int64 # similar to jax but we don't use weak type - assert least_upper_dtype(dtypes.int64, dtypes.uint64) == dtypes.fp8e4m3 + assert least_upper_dtype(dtypes.int64, dtypes.uint64) == dtypes.uint64 # is this correct? assert least_upper_dtype(dtypes.float16, dtypes.float32) == dtypes.float32 assert least_upper_dtype(dtypes.float32, dtypes.float64) == dtypes.float64 diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index 78fefaf69a..c86b562cdd 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -141,7 +141,8 @@ class dtypes: if isinstance(val, InvalidType): return val # NOTE: float('nan') != float('nan'), so we canonicalize here if isinstance(val, float) and math.isnan(val): val = math.nan - return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val) + # int is the default + return float(val) if dtypes.is_float(dtype) else bool(val) if dtypes.is_bool(dtype) else int(val) @staticmethod @functools.cache def min(dtype:DType): @@ -169,6 +170,8 @@ class dtypes: uint32: Final[DType] = DType.new(6, 32, "unsigned int", 'I') int64: Final[DType] = DType.new(7, 64, "long", 'q') uint64: Final[DType] = DType.new(8, 64, "unsigned long", 'Q') + _uint128: Final[DType] = DType.new(8, 128, "uint128", None) + _uint256: Final[DType] = DType.new(8, 256, "uint256", None) fp8e4m3: Final[DType] = DType.new(9, 8, "float8_e4m3", None) fp8e5m2: Final[DType] = DType.new(10, 8, "float8_e5m2", None) float16: Final[DType] = DType.new(11, 16, "half", 'e') @@ -208,7 +211,7 @@ def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType) # https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html # we don't support weak type and complex type promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64], - dtypes.int64: [dtypes.fp8e4m3, dtypes.fp8e5m2], dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32], + dtypes.int64: [dtypes.uint64], dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32], dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.fp8e4m3, dtypes.fp8e5m2], dtypes.fp8e5m2: [dtypes.float16, dtypes.bfloat16], dtypes.fp8e4m3: [dtypes.float16, dtypes.bfloat16], dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], } @@ -222,7 +225,7 @@ def least_upper_dtype(*ds:DType) -> DType: if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0] def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.default_float) -DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType) and not k.startswith(("default", "void", "index"))} +DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType) and not k.startswith(("default", "void", "index", "_"))} INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void", "index":"index"} @functools.cache diff --git a/tinygrad/runtime/support/compiler_cpu.py b/tinygrad/runtime/support/compiler_cpu.py index 6c9b554023..5dd8450e46 100644 --- a/tinygrad/runtime/support/compiler_cpu.py +++ b/tinygrad/runtime/support/compiler_cpu.py @@ -67,7 +67,7 @@ class LLVMCompiler(Compiler): llvm.LLVMDisposePassBuilderOptions(self.pbo) llvm.LLVMContextDispose(self.context) - def compile(self, src:str) -> bytes: + def compile_to_obj(self, src:str) -> bytes: self.diag_msgs.clear() src_buf = llvm.LLVMCreateMemoryBufferWithMemoryRangeCopy(ctypes.create_string_buffer(src_bytes:=src.encode()), len(src_bytes), b'src') mod = expect(llvm.LLVMParseIRInContext(self.context, src_buf, ctypes.pointer(m:=llvm.LLVMModuleRef()), err:=cerr()), err, m) @@ -80,7 +80,9 @@ class LLVMCompiler(Compiler): obj = ctypes.string_at(llvm.LLVMGetBufferStart(obj_buf), llvm.LLVMGetBufferSize(obj_buf)) llvm.LLVMDisposeMemoryBuffer(obj_buf) if self.diag_msgs: raise RuntimeError("llvm diagnostic: " + "\n".join(self.diag_msgs)) - return jit_loader(obj) if self.jit else obj + return obj + + def compile(self, src:str) -> bytes: return jit_loader(self.compile_to_obj(src)) if self.jit else self.compile_to_obj(src) def disassemble(self, lib:bytes): capstone_flatdump(lib)