mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
tinygrad changes from ucode (#14336)
* tinygrad changes from ucode * dtype
This commit is contained in:
@@ -9,6 +9,7 @@ SDMA_MAX_COPY_SIZE = 0x400000
|
||||
|
||||
regCOMPUTE_PGM_LO = 0x1bac + amd_gpu.GC_BASE__INST0_SEG0
|
||||
regCOMPUTE_PGM_RSRC2 = 0x1bb3 + amd_gpu.GC_BASE__INST0_SEG0
|
||||
regCOMPUTE_TMPRING_SIZE = 0x1bb8 + amd_gpu.GC_BASE__INST0_SEG0
|
||||
regCOMPUTE_USER_DATA_0 = 0x1be0 + amd_gpu.GC_BASE__INST0_SEG0
|
||||
regCOMPUTE_NUM_THREAD_X = 0x1ba7 + amd_gpu.GC_BASE__INST0_SEG0
|
||||
regGRBM_GFX_INDEX = 0x2200 + amd_gpu.GC_BASE__INST0_SEG1
|
||||
@@ -185,10 +186,18 @@ class PM4Executor(AMDQueue):
|
||||
for st,sz in self.gpu.mapped_ranges:
|
||||
if st <= prg_addr < st+sz: prg_sz = sz - (prg_addr - st)
|
||||
|
||||
# Get scratch size from COMPUTE_TMPRING_SIZE register (WAVESIZE field, bits 12:25 for gfx11)
|
||||
# WAVESIZE = ceildiv(64 * size_per_thread, 256), so size_per_thread ≈ WAVESIZE * 4
|
||||
try: tmpring_size = self.gpu.regs[regCOMPUTE_TMPRING_SIZE]
|
||||
except KeyError: tmpring_size = 0
|
||||
wavesize = (tmpring_size >> 12) & 0x3FFF # bits 12:25 for gfx11
|
||||
scratch_size = wavesize * 4 # approximate private_segment_size per lane
|
||||
|
||||
assert prg_sz > 0, "Invalid prg ptr (not found in mapped ranges)"
|
||||
# Pass valid memory ranges and rsrc2 to Python emulator for bounds checking and SGPR/VGPR layout
|
||||
# Pass valid memory ranges, rsrc2, and scratch_size to Python emulator
|
||||
if hasattr(remu, 'valid_mem_ranges'): remu.valid_mem_ranges = self.gpu.mapped_ranges
|
||||
if hasattr(remu, 'rsrc2'): remu.rsrc2 = rsrc2
|
||||
if hasattr(remu, 'scratch_size'): remu.scratch_size = scratch_size
|
||||
err = remu.run_asm(prg_addr, prg_sz, *gl, *lc, args_addr)
|
||||
if err != 0: raise RuntimeError("remu does not support the new instruction introduced in this kernel")
|
||||
|
||||
|
||||
@@ -530,7 +530,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
c = r + 1
|
||||
self.assertIn(r, c.ranges)
|
||||
|
||||
e = UOp.const(dtypes.void, None).end(r)
|
||||
e = UOp.const(dtypes.int, 1).end(r)
|
||||
self.assertNotIn(r, e.ranges)
|
||||
|
||||
a = c.after(e)
|
||||
|
||||
@@ -380,7 +380,7 @@ class TestTypePromotion(unittest.TestCase):
|
||||
assert least_upper_dtype(dtypes.int32, dtypes.uint32) == dtypes.int64
|
||||
assert least_upper_dtype(dtypes.uint32, dtypes.int64) == dtypes.int64
|
||||
# similar to jax but we don't use weak type
|
||||
assert least_upper_dtype(dtypes.int64, dtypes.uint64) == dtypes.fp8e4m3
|
||||
assert least_upper_dtype(dtypes.int64, dtypes.uint64) == dtypes.uint64 # is this correct?
|
||||
assert least_upper_dtype(dtypes.float16, dtypes.float32) == dtypes.float32
|
||||
assert least_upper_dtype(dtypes.float32, dtypes.float64) == dtypes.float64
|
||||
|
||||
|
||||
@@ -141,7 +141,8 @@ class dtypes:
|
||||
if isinstance(val, InvalidType): return val
|
||||
# NOTE: float('nan') != float('nan'), so we canonicalize here
|
||||
if isinstance(val, float) and math.isnan(val): val = math.nan
|
||||
return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
|
||||
# int is the default
|
||||
return float(val) if dtypes.is_float(dtype) else bool(val) if dtypes.is_bool(dtype) else int(val)
|
||||
@staticmethod
|
||||
@functools.cache
|
||||
def min(dtype:DType):
|
||||
@@ -169,6 +170,8 @@ class dtypes:
|
||||
uint32: Final[DType] = DType.new(6, 32, "unsigned int", 'I')
|
||||
int64: Final[DType] = DType.new(7, 64, "long", 'q')
|
||||
uint64: Final[DType] = DType.new(8, 64, "unsigned long", 'Q')
|
||||
_uint128: Final[DType] = DType.new(8, 128, "uint128", None)
|
||||
_uint256: Final[DType] = DType.new(8, 256, "uint256", None)
|
||||
fp8e4m3: Final[DType] = DType.new(9, 8, "float8_e4m3", None)
|
||||
fp8e5m2: Final[DType] = DType.new(10, 8, "float8_e5m2", None)
|
||||
float16: Final[DType] = DType.new(11, 16, "half", 'e')
|
||||
@@ -208,7 +211,7 @@ def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType)
|
||||
# https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
|
||||
# we don't support weak type and complex type
|
||||
promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64],
|
||||
dtypes.int64: [dtypes.fp8e4m3, dtypes.fp8e5m2], dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32],
|
||||
dtypes.int64: [dtypes.uint64], dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32],
|
||||
dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.fp8e4m3, dtypes.fp8e5m2],
|
||||
dtypes.fp8e5m2: [dtypes.float16, dtypes.bfloat16], dtypes.fp8e4m3: [dtypes.float16, dtypes.bfloat16],
|
||||
dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], }
|
||||
@@ -222,7 +225,7 @@ def least_upper_dtype(*ds:DType) -> DType:
|
||||
if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0]
|
||||
def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.default_float)
|
||||
|
||||
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType) and not k.startswith(("default", "void", "index"))}
|
||||
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType) and not k.startswith(("default", "void", "index", "_"))}
|
||||
INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void", "index":"index"}
|
||||
|
||||
@functools.cache
|
||||
|
||||
@@ -67,7 +67,7 @@ class LLVMCompiler(Compiler):
|
||||
llvm.LLVMDisposePassBuilderOptions(self.pbo)
|
||||
llvm.LLVMContextDispose(self.context)
|
||||
|
||||
def compile(self, src:str) -> bytes:
|
||||
def compile_to_obj(self, src:str) -> bytes:
|
||||
self.diag_msgs.clear()
|
||||
src_buf = llvm.LLVMCreateMemoryBufferWithMemoryRangeCopy(ctypes.create_string_buffer(src_bytes:=src.encode()), len(src_bytes), b'src')
|
||||
mod = expect(llvm.LLVMParseIRInContext(self.context, src_buf, ctypes.pointer(m:=llvm.LLVMModuleRef()), err:=cerr()), err, m)
|
||||
@@ -80,7 +80,9 @@ class LLVMCompiler(Compiler):
|
||||
obj = ctypes.string_at(llvm.LLVMGetBufferStart(obj_buf), llvm.LLVMGetBufferSize(obj_buf))
|
||||
llvm.LLVMDisposeMemoryBuffer(obj_buf)
|
||||
if self.diag_msgs: raise RuntimeError("llvm diagnostic: " + "\n".join(self.diag_msgs))
|
||||
return jit_loader(obj) if self.jit else obj
|
||||
return obj
|
||||
|
||||
def compile(self, src:str) -> bytes: return jit_loader(self.compile_to_obj(src)) if self.jit else self.compile_to_obj(src)
|
||||
|
||||
def disassemble(self, lib:bytes): capstone_flatdump(lib)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user