From cc3c1e4c1456dddc2beb469154d4d1f3e34b5efa Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Mon, 21 Jul 2025 15:10:38 +0300 Subject: [PATCH] hcq: move cpu to hcq (#11262) * hcq: move cpu to hcq * import time * upd * fix * windows support * hm * cleaner * fix timer * fix timing * std is ns * skip profiler * mypy * cleaner * cleanups * after merge * default is back --- docs/abstractions2.py | 20 ++++---- examples/compile_tensorflow.py | 5 +- extra/export_model.py | 4 +- test/test_hcq.py | 6 ++- test/test_profiler.py | 5 +- tinygrad/device.py | 74 ++------------------------- tinygrad/helpers.py | 5 -- tinygrad/renderer/llvmir.py | 2 +- tinygrad/runtime/ops_cpu.py | 92 ++++++++++++++++++++++++++++++++-- tinygrad/runtime/ops_dsp.py | 14 +++--- tinygrad/runtime/ops_llvm.py | 12 +++-- 11 files changed, 131 insertions(+), 108 deletions(-) diff --git a/docs/abstractions2.py b/docs/abstractions2.py index c9f6a92111..f633474202 100644 --- a/docs/abstractions2.py +++ b/docs/abstractions2.py @@ -7,28 +7,30 @@ print("******** first, the runtime ***********") -from tinygrad.runtime.ops_cpu import ClangJITCompiler, MallocAllocator, CPUProgram +from tinygrad.runtime.ops_cpu import ClangJITCompiler, CPUDevice, CPUProgram + +cpu = CPUDevice() # allocate some buffers -out = MallocAllocator.alloc(4) -a = MallocAllocator.alloc(4) -b = MallocAllocator.alloc(4) +out = cpu.allocator.alloc(4) +a = cpu.allocator.alloc(4) +b = cpu.allocator.alloc(4) # load in some values (little endian) -MallocAllocator._copyin(a, memoryview(bytearray([2,0,0,0]))) -MallocAllocator._copyin(b, memoryview(bytearray([3,0,0,0]))) +cpu.allocator._copyin(a, memoryview(bytearray([2,0,0,0]))) +cpu.allocator._copyin(b, memoryview(bytearray([3,0,0,0]))) # compile a program to a binary lib = ClangJITCompiler().compile("void add(int *out, int *a, int *b) { out[0] = a[0] + b[0]; }") # create a runtime for the program -fxn = CPUProgram("add", lib) +fxn = cpu.runtime("add", lib) # run the program fxn(out, a, b) # check the data out -print(val := MallocAllocator._as_buffer(out).cast("I").tolist()[0]) +print(val := cpu.allocator._as_buffer(out).cast("I").tolist()[0]) assert val == 5 @@ -46,7 +48,7 @@ from tinygrad.shape.shapetracker import ShapeTracker out = Buffer(DEVICE, 1, dtypes.int32).allocate() a = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struct.pack("I", 2)))) b = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struct.pack("I", 3)))) -# NOTE: a._buf is the same as the return from MallocAllocator.alloc +# NOTE: a._buf is the same as the return from cpu.allocator.alloc # describe the computation buf_1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 1) diff --git a/examples/compile_tensorflow.py b/examples/compile_tensorflow.py index 1f308c58e3..33434c831c 100644 --- a/examples/compile_tensorflow.py +++ b/examples/compile_tensorflow.py @@ -10,6 +10,7 @@ import tensorflow as tf import tf2onnx from tinygrad.frontend.onnx import OnnxRunner from tinygrad.tensor import Tensor +from tinygrad.helpers import to_mv from extra.export_model import export_model_clang, compile_net, jit_model def get_uncompiled_model2(dataset_size=32, output_size=4): @@ -47,8 +48,8 @@ def compile_onnx_model(onnx_model): cprog.append("void initialize(float *weights) {") weights = bytes() for name,cl in bufs_to_save.items(): - cprog.append(f"memcpy({name}, weights + {len(weights)//4}, {len(cl._buf)*4});") - weights += bytes(cl._buf) + cprog.append(f"memcpy({name}, weights + {len(weights)//4}, {cl._buf.size});") + weights += bytes(to_mv(cl._buf.va_addr, cl._buf.size)) cprog.append("}") # write the weights to disk diff --git a/extra/export_model.py b/extra/export_model.py index 2b2aa1dc62..2d3d342652 100644 --- a/extra/export_model.py +++ b/extra/export_model.py @@ -4,7 +4,7 @@ from tinygrad.renderer import ProgramSpec from tinygrad.tensor import Device, Tensor from tinygrad.engine.jit import TinyJit from tinygrad.nn.state import get_state_dict -from tinygrad.helpers import Context +from tinygrad.helpers import Context, to_mv from tinygrad.dtype import dtypes from tinygrad.uop.ops import Ops import json @@ -68,7 +68,7 @@ def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,in if not wasm: for name,cl in bufs_to_save.items(): - weight = ''.join(["\\x%02X"%x for x in bytes(cl._buf)]) + weight = ''.join(["\\x%02X"%x for x in bytes(to_mv(cl._buf.va_addr, cl._buf.size))]) cprog.append(f"unsigned char {name}_data[] = \"{weight}\";") cprog += [f"{dtype_map[dtype]} {name}[{len}];" if name not in bufs_to_save else f"{dtype_map[dtype]} *{name} = ({dtype_map[dtype]} *){name}_data;" for name,(len,dtype,_key) in bufs.items() if name not in input_names+output_names] cprog += [f"void net({forward_args}) {{"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"] diff --git a/test/test_hcq.py b/test/test_hcq.py index 61996cd3e1..b657620fe7 100644 --- a/test/test_hcq.py +++ b/test/test_hcq.py @@ -75,7 +75,7 @@ class TestHCQ(unittest.TestCase): TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - @unittest.skipIf(MOCKGPU, "Can't handle async update on MOCKGPU for now") + @unittest.skipIf(MOCKGPU or Device.DEFAULT in {"CPU", "LLVM"}, "Can't handle async update on MOCKGPU for now") def test_wait_late_set(self): for queue_type in [TestHCQ.d0.hw_compute_queue_t, TestHCQ.d0.hw_copy_queue_t]: if queue_type is None: continue @@ -137,6 +137,7 @@ class TestHCQ(unittest.TestCase): val = TestHCQ.a.uop.buffer.as_buffer().cast("f")[0] assert val == 200.0, f"got val {val}" + @unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"}, "No globals/locals on LLVM/CPU") def test_exec_update(self): sint_global = (Variable("sint_global", 0, 0xffffffff, dtypes.uint32),) + tuple(TestHCQ.runner.p.global_size[1:]) sint_local = (Variable("sint_local", 0, 0xffffffff, dtypes.uint32),) + tuple(TestHCQ.runner.p.local_size[1:]) @@ -154,6 +155,7 @@ class TestHCQ(unittest.TestCase): val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1] assert val == 0.0, f"got val {val}, should not be updated" + @unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"}, "No globals/locals on LLVM/CPU") def test_exec_update_fuzz(self): virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32) virt_local = [Variable(f"local_{i}", 0, 0xffffffff, dtypes.uint32) for i in range(3)] @@ -336,7 +338,7 @@ class TestHCQ(unittest.TestCase): et = float(sig_en.timestamp - sig_st.timestamp) print(f"exec kernel time: {et:.2f} us") - assert 0.1 <= et <= (15000 if MOCKGPU else 100) + assert 0.1 <= et <= (15000 if MOCKGPU or Device.DEFAULT in {"CPU", "LLVM"} else 100) def test_speed_copy_bandwidth(self): if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue") diff --git a/test/test_profiler.py b/test/test_profiler.py index 6b4d1fc574..44deeace0c 100644 --- a/test/test_profiler.py +++ b/test/test_profiler.py @@ -30,7 +30,10 @@ def helper_profile_filter_device(profile, device:str): assert len(dev_events) == 1, "only one device registration event is expected" return [x for x in profile if getattr(x, "device", None) == device], dev_events[0] -@unittest.skipUnless(issubclass(type(Device[Device.DEFAULT]), HCQCompiled) or Device.DEFAULT in {"METAL"}, "HCQ device required to run") +# TODO: support in HCQCompiled +is_cpu_hcq = Device.DEFAULT in {"CPU", "LLVM"} + +@unittest.skipUnless((issubclass(type(Device[Device.DEFAULT]), HCQCompiled) and not is_cpu_hcq) or Device.DEFAULT in {"METAL"}, "Dev not supported") class TestProfiler(unittest.TestCase): @classmethod def setUpClass(self): diff --git a/tinygrad/device.py b/tinygrad/device.py index 06ef872221..6f9f3999a6 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -2,9 +2,9 @@ from __future__ import annotations from dataclasses import dataclass, replace, field from collections import defaultdict from typing import Any, Generic, TypeVar, Iterator -import importlib, inspect, functools, pathlib, os, ctypes, ctypes.util, platform, contextlib, sys, re, atexit, pickle, decimal, time -from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp, mv_address, \ - cpu_time_execution, colored, Context, round_up, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE, cpu_events, ProfileEvent +import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal, time +from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, \ + colored, Context, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE, cpu_events, ProfileEvent from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype from tinygrad.renderer import Renderer @@ -260,74 +260,6 @@ class LRUAllocator(Allocator, Generic[DeviceType]): if LRU and (options is None or not options.nolru): self.cache[(size, options)].append(opaque) else: super().free(opaque, size, options) -class _MallocAllocator(LRUAllocator['Compiled']): - def _alloc(self, size:int, options:BufferSpec): - # must be aligned to 0x20 for 256-bit ymm registers - # TODO: investigate if this is the cause of nondeterminism in speed - alignment = 0x1000 if size >= 0x1000 else 0x20 - return (ctypes.c_uint8 * size).from_address(options.external_ptr) if options.external_ptr else self._alloc_aligned(size, alignment) - def _alloc_aligned(self, size:int, alignment:int): - buffer = (ctypes.c_uint8 * (size + alignment))() - offset = round_up(ctypes.addressof(buffer), alignment) - ctypes.addressof(buffer) - return (ctypes.c_uint8 * size).from_buffer(buffer, offset) - def _as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src)) - def _as_dmaref(self, buf): return DMACPURef(ctypes.addressof(buf), ctypes.sizeof(buf)) - def _copyin(self, dest, src:memoryview): ctypes.memmove(dest, mv_address(src), len(src)) - def _copyout(self, dest:memoryview, src): ctypes.memmove(mv_address(dest), src, len(dest)) - def _offset(self, buf, size:int, offset:int): return from_mv(self._as_buffer(buf)[offset:offset+size]) - -MallocAllocator = _MallocAllocator(None) # type: ignore - -# NOTE: MAP_JIT is added to mmap module in python 3.13 -MAP_JIT = 0x0800 - -# CPUProgram is a jit/shellcode program that can be just mmapped and jumped to -class CPUProgram: - rt_lib = ctypes.CDLL(ctypes.util.find_library('System' if OSX else 'kernel32') if OSX or sys.platform == "win32" else 'libgcc_s.so.1') - - def __init__(self, name:str, lib:bytes): - if sys.platform == "win32": - PAGE_EXECUTE_READWRITE = 0x40 - MEM_COMMIT = 0x1000 - MEM_RESERVE = 0x2000 - ctypes.windll.kernel32.VirtualAlloc.restype = ctypes.c_void_p - self.mem = ctypes.windll.kernel32.VirtualAlloc(ctypes.c_void_p(0), ctypes.c_size_t(len(lib)), MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE) - ctypes.memmove(self.mem, lib, len(lib)) - ctypes.windll.kernel32.GetCurrentProcess.restype = ctypes.c_void_p - proc = ctypes.windll.kernel32.GetCurrentProcess() - ctypes.windll.kernel32.FlushInstructionCache(ctypes.c_void_p(proc), ctypes.c_void_p(self.mem), ctypes.c_size_t(len(lib))) - self.fxn = ctypes.CFUNCTYPE(None)(self.mem) - else: - from mmap import mmap, PROT_READ, PROT_WRITE, PROT_EXEC, MAP_ANON, MAP_PRIVATE - # On apple silicon with SPRR enabled (it always is in macos) RWX pages are unrepresentable: https://blog.svenpeter.dev/posts/m1_sprr_gxf/ - # MAP_JIT allows us to easily flip pages from RW- to R-X and vice versa. It is a noop on intel cpus. (man pthread_jit_write_protect_np) - self.mem = mmap(-1, len(lib), MAP_ANON | MAP_PRIVATE | (MAP_JIT if OSX else 0), PROT_READ | PROT_WRITE | PROT_EXEC) - - if OSX: CPUProgram.rt_lib.pthread_jit_write_protect_np(False) - self.mem.write(lib) - if OSX: CPUProgram.rt_lib.pthread_jit_write_protect_np(True) - - # __clear_cache isn't a normal libc function, but a compiler support routine found in libgcc_s for gcc and compiler-rt for clang. - # libgcc_s comes as shared library but compiler-rt is only a bunch of static library archives which we can't directly load, but fortunately - # it somehow found its way into libSystem on macos (likely because it used __builtin_clear_cache) and libgcc_s is ~always present on linux - # Using ["name"] instead of .name because otherwise name is getting mangled: https://docs.python.org/3.12/reference/expressions.html#index-5 - CPUProgram.rt_lib["__clear_cache"](ctypes.c_void_p(mv_address(self.mem)), ctypes.c_void_p(mv_address(self.mem) + len(lib))) - - self.fxn = ctypes.CFUNCTYPE(None)(mv_address(self.mem)) - - def __call__(self, *bufs, vals=(), wait=False): - args = list(bufs) + list(vals) - # NOTE: replace this by --target={host's triple}-elf in clang args once we only support macos sequoia and later. - # Apple relaxes abi requirement for stack arguments to always be at least 8 byte aligned on arm64 - # https://developer.apple.com/documentation/xcode/writing-arm64-code-for-apple-platforms - # This hack is required because clang/llvm bug doesn't allow us to just use {host's triple}+'-elf' (relocation failures) - # The bug was fixed in https://github.com/llvm/llvm-project/commit/454cc36630296262cdb6360b60f90a64a97f7f1a but was only backported to xcode 16+ - if platform.machine() == "arm64" and OSX: args = args[:8] + [ctypes.c_int64(a) if isinstance(a, int) else a for a in args[8:]] - return cpu_time_execution(lambda: self.fxn(*args), enable=wait) - - def __del__(self): - if sys.platform == 'win32': ctypes.windll.kernel32.VirtualFree(ctypes.c_void_p(self.mem), ctypes.c_size_t(0), 0x8000) #0x8000 - MEM_RELEASE - # **************** for Compiled Devices **************** class CompileError(Exception): pass diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 7d47014922..b9f94cedd9 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -296,11 +296,6 @@ def fetch(url:str, name:pathlib.Path|str|None=None, subdir:str|None=None, gunzip # *** Exec helpers -def cpu_time_execution(cb, enable): - if enable: st = time.perf_counter() - cb() - if enable: return time.perf_counter()-st - def cpu_objdump(lib, objdump_tool='objdump'): with tempfile.NamedTemporaryFile(delete=True) as f: pathlib.Path(f.name).write_bytes(lib) diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index cf203cee9f..4e05f3cedb 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -139,7 +139,7 @@ class LLVMRenderer(Renderer): def render(self, uops: list[UOp]) -> str: return "\n".join((k:=self._render_kernel(uops))[0] + (k[1], self._render_footer(uops))) def _render_footer(self, uops: list[UOp]) -> str: return 'attributes #0 = { alwaysinline nounwind "no-builtins" "no-trapping-math"="true" }' def _render_fn(self, name:str, args:list[tuple[str,DType]], kernel:list[str], prefix:list[str]|None=None) -> str: - # NOTE: MallocAllocator promises 0x20 alignment + # NOTE: CPUAllocator promises 0x20 alignment sargs = ", ".join([f"{ldt(dt)}{' noalias align 32' if isinstance(dt, PtrDType) else ''} {name}" for name,dt in args]) sprefix = "".join([f" {x}" for x in (prefix or []) + [self.abi] if x is not None]) return "\n".join([f"define{sprefix} void @{name}({sargs}) #0", "{"] + kernel + [" ret void\n}"]) diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py index c5a15afb52..267a6d765b 100644 --- a/tinygrad/runtime/ops_cpu.py +++ b/tinygrad/runtime/ops_cpu.py @@ -1,8 +1,11 @@ -import platform, subprocess, sys -from tinygrad.helpers import capstone_flatdump, getenv -from tinygrad.device import Compiled, Compiler, MallocAllocator, CPUProgram +from __future__ import annotations +import platform, subprocess, sys, ctypes, functools, time +from tinygrad.helpers import capstone_flatdump, getenv, from_mv, to_mv, OSX, mv_address, round_up, wait_cond +from tinygrad.device import Compiler, BufferSpec, DMACPURef +from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocatorBase, HCQBuffer, HWQueue, HCQArgsState, HCQSignal, HCQProgram, MMIOInterface from tinygrad.runtime.support.elf import jit_loader from tinygrad.renderer.cstyle import ClangRenderer +from tinygrad.uop.ops import sint class ClangJITCompiler(Compiler): def __init__(self, cachekey="compile_clang_jit"): super().__init__(cachekey) @@ -18,5 +21,84 @@ class ClangJITCompiler(Compiler): def disassemble(self, lib:bytes): return capstone_flatdump(lib) -class CPUDevice(Compiled): - def __init__(self, device:str): super().__init__(device, MallocAllocator, ClangRenderer(), ClangJITCompiler(), CPUProgram) +class CPUComputeQueue(HWQueue): + def _exec(self, prg, bufs, *args): + prg.fxn(*map(ctypes.c_uint64, args[:bufs]), *map(ctypes.c_int64 if platform.machine() == "arm64" else ctypes.c_int32, args[bufs:])) + def _signal(self, signal_addr, value): to_mv(signal_addr, 4).cast('I')[0] = value + def _wait(self, signal_addr, value): wait_cond(lambda: to_mv(signal_addr, 4).cast('I')[0] >= value, timeout_ms=60000) + def _timestamp(self, timestamp_addr): to_mv(timestamp_addr, 8).cast('Q')[0] = time.perf_counter_ns() + def cmd(self, cmd, *args): + self.q(cmd, len(args), *args) + return self + + def memory_barrier(self): return self + def exec(self, prg:CPUProgram, args_state:HCQArgsState, global_size, local_size): + return self.cmd(self._exec, prg, len(args_state.bufs), *[x.va_addr for x in args_state.bufs], *args_state.vals) + def wait(self, signal, value=0): return self.cmd(self._wait, signal.value_addr, value) + def timestamp(self, signal): return self.cmd(self._timestamp, signal.timestamp_addr) + def signal(self, signal, value:sint=0): return self.cmd(self._signal, signal.value_addr, value) + + def _submit(self, dev): + # Execute the commands in the queue: fn, argc, args... + off = 0 + while off < len(self._q): + self._q[off](*self._q[off + 2:off + 2 + self._q[off + 1]]) + off += self._q[off + 1] + 2 + +# NOTE: MAP_JIT is added to mmap module in python 3.13 +MAP_JIT = 0x0800 + +class CPUProgram(HCQProgram): + rt_lib = ctypes.CDLL(ctypes.util.find_library('System' if OSX else 'kernel32') if OSX or sys.platform == "win32" else 'libgcc_s.so.1') + + def __init__(self, dev, name:str, lib:bytes): + if sys.platform == "win32": + PAGE_EXECUTE_READWRITE, MEM_COMMIT, MEM_RESERVE = 0x40, 0x1000, 0x2000 + ctypes.windll.kernel32.VirtualAlloc.restype = ctypes.c_void_p + self.mem = ctypes.windll.kernel32.VirtualAlloc(ctypes.c_void_p(0), ctypes.c_size_t(len(lib)), MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE) + ctypes.memmove(self.mem, lib, len(lib)) + ctypes.windll.kernel32.GetCurrentProcess.restype = ctypes.c_void_p + proc = ctypes.windll.kernel32.GetCurrentProcess() + ctypes.windll.kernel32.FlushInstructionCache(ctypes.c_void_p(proc), ctypes.c_void_p(self.mem), ctypes.c_size_t(len(lib))) + self.fxn = ctypes.CFUNCTYPE(None)(self.mem) + else: + from mmap import mmap, PROT_READ, PROT_WRITE, PROT_EXEC, MAP_ANON, MAP_PRIVATE + # On apple silicon with SPRR enabled (it always is in macos) RWX pages are unrepresentable: https://blog.svenpeter.dev/posts/m1_sprr_gxf/ + # MAP_JIT allows us to easily flip pages from RW- to R-X and vice versa. It is a noop on intel cpus. (man pthread_jit_write_protect_np) + self.mem = mmap(-1, len(lib), MAP_ANON | MAP_PRIVATE | (MAP_JIT if OSX else 0), PROT_READ | PROT_WRITE | PROT_EXEC) + + if OSX: CPUProgram.rt_lib.pthread_jit_write_protect_np(False) + self.mem.write(lib) + if OSX: CPUProgram.rt_lib.pthread_jit_write_protect_np(True) + + # __clear_cache isn't a normal libc function, but a compiler support routine found in libgcc_s for gcc and compiler-rt for clang. + # libgcc_s comes as shared library but compiler-rt is only a bunch of static library archives which we can't directly load, but fortunately + # it somehow found its way into libSystem on macos (likely because it used __builtin_clear_cache) and libgcc_s is ~always present on linux + # Using ["name"] instead of .name because otherwise name is getting mangled: https://docs.python.org/3.12/reference/expressions.html#index-5 + CPUProgram.rt_lib["__clear_cache"](ctypes.c_void_p(mv_address(self.mem)), ctypes.c_void_p(mv_address(self.mem) + len(lib))) + + self.fxn = ctypes.CFUNCTYPE(None)(mv_address(self.mem)) + + super().__init__(HCQArgsState, dev, name, kernargs_alloc_size=0) + + def __del__(self): + if sys.platform == 'win32': ctypes.windll.kernel32.VirtualFree(ctypes.c_void_p(self.mem), ctypes.c_size_t(0), 0x8000) #0x8000 - MEM_RELEASE + +class CPUAllocator(HCQAllocatorBase): + def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer: + if options.external_ptr: buf = (ctypes.c_uint8 * size).from_address(options.external_ptr) + else: + offset = round_up(ctypes.addressof(tmpbuf:=(ctypes.c_uint8 * (size + 0x1000))()), 0x1000) - ctypes.addressof(tmpbuf) + buf = (ctypes.c_uint8 * size).from_buffer(tmpbuf, offset) + return HCQBuffer(va:=ctypes.addressof(buf), sz:=ctypes.sizeof(buf), meta=buf, view=MMIOInterface(va, sz, fmt='B'), owner=self.dev) + def _as_buffer(self, src) -> memoryview: return to_mv(src.va_addr, src.size) + def _as_dmaref(self, buf): return DMACPURef(buf.va_addr, buf.size) + def _copyin(self, dest, src:memoryview): ctypes.memmove(dest.va_addr, from_mv(src), len(src)) + def _copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src.va_addr, len(dest)) + def _map(self, buf:HCQBuffer): + if buf.view is None or not isinstance(buf.view, MMIOInterface): raise RuntimeError("Cannot map buffer without view to cpu") + +class CPUDevice(HCQCompiled): + def __init__(self, device:str=""): + super().__init__(device, CPUAllocator(self), ClangRenderer(), ClangJITCompiler(), functools.partial(CPUProgram, self), HCQSignal, CPUComputeQueue, + supports_graph=False) diff --git a/tinygrad/runtime/ops_dsp.py b/tinygrad/runtime/ops_dsp.py index 2b921af795..640d992ecf 100644 --- a/tinygrad/runtime/ops_dsp.py +++ b/tinygrad/runtime/ops_dsp.py @@ -1,7 +1,8 @@ from __future__ import annotations import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys, subprocess, struct assert sys.platform != 'win32' -from tinygrad.device import BufferSpec, Compiled, Allocator, Compiler, MallocAllocator +from tinygrad.device import BufferSpec, Compiled, Allocator, Compiler +from tinygrad.runtime.ops_cpu import CPUAllocator from tinygrad.dtype import dtypes, DType, PtrDType from tinygrad.uop.ops import Ops, UOp from tinygrad.helpers import getenv, round_up, mv_address, to_mv, cpu_objdump, DEBUG @@ -131,7 +132,8 @@ class DSPDevice(Compiled): def __init__(self, device:str=""): compiler_args = ["--target=hexagon", "-mcpu=hexagonv65", "-fuse-ld=lld", "-nostdlib", "-mhvx=v65", "-mhvx-length=128b"] if getenv("MOCKDSP"): - super().__init__(device, MallocAllocator, MockDSPRenderer(), ClangCompiler(None, ["-static"] + compiler_args, 'llvm-objdump'), MockDSPProgram) + super().__init__(device, CPUAllocator(self), MockDSPRenderer(), + ClangCompiler(None, ["-static"] + compiler_args, 'llvm-objdump'), MockDSPProgram) else: self.ion_fd = os.open('/dev/ion', os.O_RDONLY) # Generate link script to pass into clang. Aligning all used sections to 4k fixes invoke problem. @@ -293,11 +295,11 @@ class MockDSPProgram: os.chmod(dsp_lib.name, 0o0777) # NOTE: this timing includes a docker launch proc = subprocess.run(["docker", "run", "--rm", "-i", "-v", f"{os.path.abspath(os.path.dirname(dsp_lib.name))}:/work", "-w", "/work", - "qemu-hexagon", "-c", f"qemu-hexagon {'-strace' if DEBUG >= 5 else ''} /work/"+os.path.basename(dsp_lib.name)], - input=b''.join([bytes(x) for x in bufs] + [struct.pack("I", x) for x in vals]), stdout=subprocess.PIPE, check=True) + "qemu-hexagon", "-c", f"qemu-hexagon {'-strace' if DEBUG >= 5 else ''} /work/"+os.path.basename(dsp_lib.name)], + input=b''.join([bytes(to_mv(x.va_addr, x.size)) for x in bufs] + [struct.pack("I", x) for x in vals]), stdout=subprocess.PIPE, check=True) offset = 4 for x in bufs: - x[:] = proc.stdout[offset:offset+len(x)] - offset += len(x) + x.cpu_view()[:] = proc.stdout[offset:offset+x.size] + offset += x.size assert offset == len(proc.stdout) return struct.unpack("I", proc.stdout[0:4])[0] / 1e9 # pretend it's 1 Ghz, but this is an inscount, not a time diff --git a/tinygrad/runtime/ops_llvm.py b/tinygrad/runtime/ops_llvm.py index c628ad138b..3bfa0dd351 100644 --- a/tinygrad/runtime/ops_llvm.py +++ b/tinygrad/runtime/ops_llvm.py @@ -1,5 +1,7 @@ -import ctypes, platform -from tinygrad.device import Compiled, Compiler, MallocAllocator, CPUProgram +import ctypes, platform, functools +from tinygrad.device import Compiler +from tinygrad.runtime.support.hcq import HCQCompiled, HCQSignal +from tinygrad.runtime.ops_cpu import CPUAllocator, CPUProgram, CPUComputeQueue from tinygrad.helpers import OSX, getenv, capstone_flatdump, DEBUG from tinygrad.renderer.llvmir import LLVMRenderer import tinygrad.runtime.autogen.llvm as llvm @@ -69,5 +71,7 @@ class HostLLVMCompiler(LLVMCompiler): cpu, feats = ctypes.string_at(llvm.LLVMGetHostCPUName()), (b'+reserve-x18,' if OSX else b'') + ctypes.string_at(llvm.LLVMGetHostCPUFeatures()) super().__init__(cpu.decode(), feats.decode()) -class LLVMDevice(Compiled): - def __init__(self, device:str): super().__init__(device, MallocAllocator, LLVMRenderer(), HostLLVMCompiler(), CPUProgram) +class LLVMDevice(HCQCompiled): + def __init__(self, device:str=""): + super().__init__(device, CPUAllocator(self), LLVMRenderer(), HostLLVMCompiler(), functools.partial(CPUProgram, self), HCQSignal, CPUComputeQueue, + supports_graph=False)