hcq: move cpu to hcq (#11262)

* hcq: move cpu to hcq

* import time

* upd

* fix

* windows support

* hm

* cleaner

* fix timer

* fix timing

* std is ns

* skip profiler

* mypy

* cleaner

* cleanups

* after merge

* default is back
This commit is contained in:
nimlgen
2025-07-21 15:10:38 +03:00
committed by GitHub
parent 816c01c2d4
commit cc3c1e4c14
11 changed files with 131 additions and 108 deletions

View File

@@ -7,28 +7,30 @@
print("******** first, the runtime ***********")
from tinygrad.runtime.ops_cpu import ClangJITCompiler, MallocAllocator, CPUProgram
from tinygrad.runtime.ops_cpu import ClangJITCompiler, CPUDevice, CPUProgram
cpu = CPUDevice()
# allocate some buffers
out = MallocAllocator.alloc(4)
a = MallocAllocator.alloc(4)
b = MallocAllocator.alloc(4)
out = cpu.allocator.alloc(4)
a = cpu.allocator.alloc(4)
b = cpu.allocator.alloc(4)
# load in some values (little endian)
MallocAllocator._copyin(a, memoryview(bytearray([2,0,0,0])))
MallocAllocator._copyin(b, memoryview(bytearray([3,0,0,0])))
cpu.allocator._copyin(a, memoryview(bytearray([2,0,0,0])))
cpu.allocator._copyin(b, memoryview(bytearray([3,0,0,0])))
# compile a program to a binary
lib = ClangJITCompiler().compile("void add(int *out, int *a, int *b) { out[0] = a[0] + b[0]; }")
# create a runtime for the program
fxn = CPUProgram("add", lib)
fxn = cpu.runtime("add", lib)
# run the program
fxn(out, a, b)
# check the data out
print(val := MallocAllocator._as_buffer(out).cast("I").tolist()[0])
print(val := cpu.allocator._as_buffer(out).cast("I").tolist()[0])
assert val == 5
@@ -46,7 +48,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
out = Buffer(DEVICE, 1, dtypes.int32).allocate()
a = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struct.pack("I", 2))))
b = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struct.pack("I", 3))))
# NOTE: a._buf is the same as the return from MallocAllocator.alloc
# NOTE: a._buf is the same as the return from cpu.allocator.alloc
# describe the computation
buf_1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 1)

View File

@@ -10,6 +10,7 @@ import tensorflow as tf
import tf2onnx
from tinygrad.frontend.onnx import OnnxRunner
from tinygrad.tensor import Tensor
from tinygrad.helpers import to_mv
from extra.export_model import export_model_clang, compile_net, jit_model
def get_uncompiled_model2(dataset_size=32, output_size=4):
@@ -47,8 +48,8 @@ def compile_onnx_model(onnx_model):
cprog.append("void initialize(float *weights) {")
weights = bytes()
for name,cl in bufs_to_save.items():
cprog.append(f"memcpy({name}, weights + {len(weights)//4}, {len(cl._buf)*4});")
weights += bytes(cl._buf)
cprog.append(f"memcpy({name}, weights + {len(weights)//4}, {cl._buf.size});")
weights += bytes(to_mv(cl._buf.va_addr, cl._buf.size))
cprog.append("}")
# write the weights to disk

View File

@@ -4,7 +4,7 @@ from tinygrad.renderer import ProgramSpec
from tinygrad.tensor import Device, Tensor
from tinygrad.engine.jit import TinyJit
from tinygrad.nn.state import get_state_dict
from tinygrad.helpers import Context
from tinygrad.helpers import Context, to_mv
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import Ops
import json
@@ -68,7 +68,7 @@ def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,in
if not wasm:
for name,cl in bufs_to_save.items():
weight = ''.join(["\\x%02X"%x for x in bytes(cl._buf)])
weight = ''.join(["\\x%02X"%x for x in bytes(to_mv(cl._buf.va_addr, cl._buf.size))])
cprog.append(f"unsigned char {name}_data[] = \"{weight}\";")
cprog += [f"{dtype_map[dtype]} {name}[{len}];" if name not in bufs_to_save else f"{dtype_map[dtype]} *{name} = ({dtype_map[dtype]} *){name}_data;" for name,(len,dtype,_key) in bufs.items() if name not in input_names+output_names]
cprog += [f"void net({forward_args}) {{"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"]

View File

@@ -75,7 +75,7 @@ class TestHCQ(unittest.TestCase):
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@unittest.skipIf(MOCKGPU, "Can't handle async update on MOCKGPU for now")
@unittest.skipIf(MOCKGPU or Device.DEFAULT in {"CPU", "LLVM"}, "Can't handle async update on MOCKGPU for now")
def test_wait_late_set(self):
for queue_type in [TestHCQ.d0.hw_compute_queue_t, TestHCQ.d0.hw_copy_queue_t]:
if queue_type is None: continue
@@ -137,6 +137,7 @@ class TestHCQ(unittest.TestCase):
val = TestHCQ.a.uop.buffer.as_buffer().cast("f")[0]
assert val == 200.0, f"got val {val}"
@unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"}, "No globals/locals on LLVM/CPU")
def test_exec_update(self):
sint_global = (Variable("sint_global", 0, 0xffffffff, dtypes.uint32),) + tuple(TestHCQ.runner.p.global_size[1:])
sint_local = (Variable("sint_local", 0, 0xffffffff, dtypes.uint32),) + tuple(TestHCQ.runner.p.local_size[1:])
@@ -154,6 +155,7 @@ class TestHCQ(unittest.TestCase):
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1]
assert val == 0.0, f"got val {val}, should not be updated"
@unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"}, "No globals/locals on LLVM/CPU")
def test_exec_update_fuzz(self):
virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32)
virt_local = [Variable(f"local_{i}", 0, 0xffffffff, dtypes.uint32) for i in range(3)]
@@ -336,7 +338,7 @@ class TestHCQ(unittest.TestCase):
et = float(sig_en.timestamp - sig_st.timestamp)
print(f"exec kernel time: {et:.2f} us")
assert 0.1 <= et <= (15000 if MOCKGPU else 100)
assert 0.1 <= et <= (15000 if MOCKGPU or Device.DEFAULT in {"CPU", "LLVM"} else 100)
def test_speed_copy_bandwidth(self):
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")

View File

@@ -30,7 +30,10 @@ def helper_profile_filter_device(profile, device:str):
assert len(dev_events) == 1, "only one device registration event is expected"
return [x for x in profile if getattr(x, "device", None) == device], dev_events[0]
@unittest.skipUnless(issubclass(type(Device[Device.DEFAULT]), HCQCompiled) or Device.DEFAULT in {"METAL"}, "HCQ device required to run")
# TODO: support in HCQCompiled
is_cpu_hcq = Device.DEFAULT in {"CPU", "LLVM"}
@unittest.skipUnless((issubclass(type(Device[Device.DEFAULT]), HCQCompiled) and not is_cpu_hcq) or Device.DEFAULT in {"METAL"}, "Dev not supported")
class TestProfiler(unittest.TestCase):
@classmethod
def setUpClass(self):

View File

@@ -2,9 +2,9 @@ from __future__ import annotations
from dataclasses import dataclass, replace, field
from collections import defaultdict
from typing import Any, Generic, TypeVar, Iterator
import importlib, inspect, functools, pathlib, os, ctypes, ctypes.util, platform, contextlib, sys, re, atexit, pickle, decimal, time
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp, mv_address, \
cpu_time_execution, colored, Context, round_up, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE, cpu_events, ProfileEvent
import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal, time
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, \
colored, Context, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE, cpu_events, ProfileEvent
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
from tinygrad.renderer import Renderer
@@ -260,74 +260,6 @@ class LRUAllocator(Allocator, Generic[DeviceType]):
if LRU and (options is None or not options.nolru): self.cache[(size, options)].append(opaque)
else: super().free(opaque, size, options)
class _MallocAllocator(LRUAllocator['Compiled']):
def _alloc(self, size:int, options:BufferSpec):
# must be aligned to 0x20 for 256-bit ymm registers
# TODO: investigate if this is the cause of nondeterminism in speed
alignment = 0x1000 if size >= 0x1000 else 0x20
return (ctypes.c_uint8 * size).from_address(options.external_ptr) if options.external_ptr else self._alloc_aligned(size, alignment)
def _alloc_aligned(self, size:int, alignment:int):
buffer = (ctypes.c_uint8 * (size + alignment))()
offset = round_up(ctypes.addressof(buffer), alignment) - ctypes.addressof(buffer)
return (ctypes.c_uint8 * size).from_buffer(buffer, offset)
def _as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
def _as_dmaref(self, buf): return DMACPURef(ctypes.addressof(buf), ctypes.sizeof(buf))
def _copyin(self, dest, src:memoryview): ctypes.memmove(dest, mv_address(src), len(src))
def _copyout(self, dest:memoryview, src): ctypes.memmove(mv_address(dest), src, len(dest))
def _offset(self, buf, size:int, offset:int): return from_mv(self._as_buffer(buf)[offset:offset+size])
MallocAllocator = _MallocAllocator(None) # type: ignore
# NOTE: MAP_JIT is added to mmap module in python 3.13
MAP_JIT = 0x0800
# CPUProgram is a jit/shellcode program that can be just mmapped and jumped to
class CPUProgram:
rt_lib = ctypes.CDLL(ctypes.util.find_library('System' if OSX else 'kernel32') if OSX or sys.platform == "win32" else 'libgcc_s.so.1')
def __init__(self, name:str, lib:bytes):
if sys.platform == "win32":
PAGE_EXECUTE_READWRITE = 0x40
MEM_COMMIT = 0x1000
MEM_RESERVE = 0x2000
ctypes.windll.kernel32.VirtualAlloc.restype = ctypes.c_void_p
self.mem = ctypes.windll.kernel32.VirtualAlloc(ctypes.c_void_p(0), ctypes.c_size_t(len(lib)), MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE)
ctypes.memmove(self.mem, lib, len(lib))
ctypes.windll.kernel32.GetCurrentProcess.restype = ctypes.c_void_p
proc = ctypes.windll.kernel32.GetCurrentProcess()
ctypes.windll.kernel32.FlushInstructionCache(ctypes.c_void_p(proc), ctypes.c_void_p(self.mem), ctypes.c_size_t(len(lib)))
self.fxn = ctypes.CFUNCTYPE(None)(self.mem)
else:
from mmap import mmap, PROT_READ, PROT_WRITE, PROT_EXEC, MAP_ANON, MAP_PRIVATE
# On apple silicon with SPRR enabled (it always is in macos) RWX pages are unrepresentable: https://blog.svenpeter.dev/posts/m1_sprr_gxf/
# MAP_JIT allows us to easily flip pages from RW- to R-X and vice versa. It is a noop on intel cpus. (man pthread_jit_write_protect_np)
self.mem = mmap(-1, len(lib), MAP_ANON | MAP_PRIVATE | (MAP_JIT if OSX else 0), PROT_READ | PROT_WRITE | PROT_EXEC)
if OSX: CPUProgram.rt_lib.pthread_jit_write_protect_np(False)
self.mem.write(lib)
if OSX: CPUProgram.rt_lib.pthread_jit_write_protect_np(True)
# __clear_cache isn't a normal libc function, but a compiler support routine found in libgcc_s for gcc and compiler-rt for clang.
# libgcc_s comes as shared library but compiler-rt is only a bunch of static library archives which we can't directly load, but fortunately
# it somehow found its way into libSystem on macos (likely because it used __builtin_clear_cache) and libgcc_s is ~always present on linux
# Using ["name"] instead of .name because otherwise name is getting mangled: https://docs.python.org/3.12/reference/expressions.html#index-5
CPUProgram.rt_lib["__clear_cache"](ctypes.c_void_p(mv_address(self.mem)), ctypes.c_void_p(mv_address(self.mem) + len(lib)))
self.fxn = ctypes.CFUNCTYPE(None)(mv_address(self.mem))
def __call__(self, *bufs, vals=(), wait=False):
args = list(bufs) + list(vals)
# NOTE: replace this by --target={host's triple}-elf in clang args once we only support macos sequoia and later.
# Apple relaxes abi requirement for stack arguments to always be at least 8 byte aligned on arm64
# https://developer.apple.com/documentation/xcode/writing-arm64-code-for-apple-platforms
# This hack is required because clang/llvm bug doesn't allow us to just use {host's triple}+'-elf' (relocation failures)
# The bug was fixed in https://github.com/llvm/llvm-project/commit/454cc36630296262cdb6360b60f90a64a97f7f1a but was only backported to xcode 16+
if platform.machine() == "arm64" and OSX: args = args[:8] + [ctypes.c_int64(a) if isinstance(a, int) else a for a in args[8:]]
return cpu_time_execution(lambda: self.fxn(*args), enable=wait)
def __del__(self):
if sys.platform == 'win32': ctypes.windll.kernel32.VirtualFree(ctypes.c_void_p(self.mem), ctypes.c_size_t(0), 0x8000) #0x8000 - MEM_RELEASE
# **************** for Compiled Devices ****************
class CompileError(Exception): pass

View File

@@ -296,11 +296,6 @@ def fetch(url:str, name:pathlib.Path|str|None=None, subdir:str|None=None, gunzip
# *** Exec helpers
def cpu_time_execution(cb, enable):
if enable: st = time.perf_counter()
cb()
if enable: return time.perf_counter()-st
def cpu_objdump(lib, objdump_tool='objdump'):
with tempfile.NamedTemporaryFile(delete=True) as f:
pathlib.Path(f.name).write_bytes(lib)

View File

@@ -139,7 +139,7 @@ class LLVMRenderer(Renderer):
def render(self, uops: list[UOp]) -> str: return "\n".join((k:=self._render_kernel(uops))[0] + (k[1], self._render_footer(uops)))
def _render_footer(self, uops: list[UOp]) -> str: return 'attributes #0 = { alwaysinline nounwind "no-builtins" "no-trapping-math"="true" }'
def _render_fn(self, name:str, args:list[tuple[str,DType]], kernel:list[str], prefix:list[str]|None=None) -> str:
# NOTE: MallocAllocator promises 0x20 alignment
# NOTE: CPUAllocator promises 0x20 alignment
sargs = ", ".join([f"{ldt(dt)}{' noalias align 32' if isinstance(dt, PtrDType) else ''} {name}" for name,dt in args])
sprefix = "".join([f" {x}" for x in (prefix or []) + [self.abi] if x is not None])
return "\n".join([f"define{sprefix} void @{name}({sargs}) #0", "{"] + kernel + [" ret void\n}"])

View File

@@ -1,8 +1,11 @@
import platform, subprocess, sys
from tinygrad.helpers import capstone_flatdump, getenv
from tinygrad.device import Compiled, Compiler, MallocAllocator, CPUProgram
from __future__ import annotations
import platform, subprocess, sys, ctypes, functools, time
from tinygrad.helpers import capstone_flatdump, getenv, from_mv, to_mv, OSX, mv_address, round_up, wait_cond
from tinygrad.device import Compiler, BufferSpec, DMACPURef
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocatorBase, HCQBuffer, HWQueue, HCQArgsState, HCQSignal, HCQProgram, MMIOInterface
from tinygrad.runtime.support.elf import jit_loader
from tinygrad.renderer.cstyle import ClangRenderer
from tinygrad.uop.ops import sint
class ClangJITCompiler(Compiler):
def __init__(self, cachekey="compile_clang_jit"): super().__init__(cachekey)
@@ -18,5 +21,84 @@ class ClangJITCompiler(Compiler):
def disassemble(self, lib:bytes): return capstone_flatdump(lib)
class CPUDevice(Compiled):
def __init__(self, device:str): super().__init__(device, MallocAllocator, ClangRenderer(), ClangJITCompiler(), CPUProgram)
class CPUComputeQueue(HWQueue):
def _exec(self, prg, bufs, *args):
prg.fxn(*map(ctypes.c_uint64, args[:bufs]), *map(ctypes.c_int64 if platform.machine() == "arm64" else ctypes.c_int32, args[bufs:]))
def _signal(self, signal_addr, value): to_mv(signal_addr, 4).cast('I')[0] = value
def _wait(self, signal_addr, value): wait_cond(lambda: to_mv(signal_addr, 4).cast('I')[0] >= value, timeout_ms=60000)
def _timestamp(self, timestamp_addr): to_mv(timestamp_addr, 8).cast('Q')[0] = time.perf_counter_ns()
def cmd(self, cmd, *args):
self.q(cmd, len(args), *args)
return self
def memory_barrier(self): return self
def exec(self, prg:CPUProgram, args_state:HCQArgsState, global_size, local_size):
return self.cmd(self._exec, prg, len(args_state.bufs), *[x.va_addr for x in args_state.bufs], *args_state.vals)
def wait(self, signal, value=0): return self.cmd(self._wait, signal.value_addr, value)
def timestamp(self, signal): return self.cmd(self._timestamp, signal.timestamp_addr)
def signal(self, signal, value:sint=0): return self.cmd(self._signal, signal.value_addr, value)
def _submit(self, dev):
# Execute the commands in the queue: fn, argc, args...
off = 0
while off < len(self._q):
self._q[off](*self._q[off + 2:off + 2 + self._q[off + 1]])
off += self._q[off + 1] + 2
# NOTE: MAP_JIT is added to mmap module in python 3.13
MAP_JIT = 0x0800
class CPUProgram(HCQProgram):
rt_lib = ctypes.CDLL(ctypes.util.find_library('System' if OSX else 'kernel32') if OSX or sys.platform == "win32" else 'libgcc_s.so.1')
def __init__(self, dev, name:str, lib:bytes):
if sys.platform == "win32":
PAGE_EXECUTE_READWRITE, MEM_COMMIT, MEM_RESERVE = 0x40, 0x1000, 0x2000
ctypes.windll.kernel32.VirtualAlloc.restype = ctypes.c_void_p
self.mem = ctypes.windll.kernel32.VirtualAlloc(ctypes.c_void_p(0), ctypes.c_size_t(len(lib)), MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE)
ctypes.memmove(self.mem, lib, len(lib))
ctypes.windll.kernel32.GetCurrentProcess.restype = ctypes.c_void_p
proc = ctypes.windll.kernel32.GetCurrentProcess()
ctypes.windll.kernel32.FlushInstructionCache(ctypes.c_void_p(proc), ctypes.c_void_p(self.mem), ctypes.c_size_t(len(lib)))
self.fxn = ctypes.CFUNCTYPE(None)(self.mem)
else:
from mmap import mmap, PROT_READ, PROT_WRITE, PROT_EXEC, MAP_ANON, MAP_PRIVATE
# On apple silicon with SPRR enabled (it always is in macos) RWX pages are unrepresentable: https://blog.svenpeter.dev/posts/m1_sprr_gxf/
# MAP_JIT allows us to easily flip pages from RW- to R-X and vice versa. It is a noop on intel cpus. (man pthread_jit_write_protect_np)
self.mem = mmap(-1, len(lib), MAP_ANON | MAP_PRIVATE | (MAP_JIT if OSX else 0), PROT_READ | PROT_WRITE | PROT_EXEC)
if OSX: CPUProgram.rt_lib.pthread_jit_write_protect_np(False)
self.mem.write(lib)
if OSX: CPUProgram.rt_lib.pthread_jit_write_protect_np(True)
# __clear_cache isn't a normal libc function, but a compiler support routine found in libgcc_s for gcc and compiler-rt for clang.
# libgcc_s comes as shared library but compiler-rt is only a bunch of static library archives which we can't directly load, but fortunately
# it somehow found its way into libSystem on macos (likely because it used __builtin_clear_cache) and libgcc_s is ~always present on linux
# Using ["name"] instead of .name because otherwise name is getting mangled: https://docs.python.org/3.12/reference/expressions.html#index-5
CPUProgram.rt_lib["__clear_cache"](ctypes.c_void_p(mv_address(self.mem)), ctypes.c_void_p(mv_address(self.mem) + len(lib)))
self.fxn = ctypes.CFUNCTYPE(None)(mv_address(self.mem))
super().__init__(HCQArgsState, dev, name, kernargs_alloc_size=0)
def __del__(self):
if sys.platform == 'win32': ctypes.windll.kernel32.VirtualFree(ctypes.c_void_p(self.mem), ctypes.c_size_t(0), 0x8000) #0x8000 - MEM_RELEASE
class CPUAllocator(HCQAllocatorBase):
def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer:
if options.external_ptr: buf = (ctypes.c_uint8 * size).from_address(options.external_ptr)
else:
offset = round_up(ctypes.addressof(tmpbuf:=(ctypes.c_uint8 * (size + 0x1000))()), 0x1000) - ctypes.addressof(tmpbuf)
buf = (ctypes.c_uint8 * size).from_buffer(tmpbuf, offset)
return HCQBuffer(va:=ctypes.addressof(buf), sz:=ctypes.sizeof(buf), meta=buf, view=MMIOInterface(va, sz, fmt='B'), owner=self.dev)
def _as_buffer(self, src) -> memoryview: return to_mv(src.va_addr, src.size)
def _as_dmaref(self, buf): return DMACPURef(buf.va_addr, buf.size)
def _copyin(self, dest, src:memoryview): ctypes.memmove(dest.va_addr, from_mv(src), len(src))
def _copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src.va_addr, len(dest))
def _map(self, buf:HCQBuffer):
if buf.view is None or not isinstance(buf.view, MMIOInterface): raise RuntimeError("Cannot map buffer without view to cpu")
class CPUDevice(HCQCompiled):
def __init__(self, device:str=""):
super().__init__(device, CPUAllocator(self), ClangRenderer(), ClangJITCompiler(), functools.partial(CPUProgram, self), HCQSignal, CPUComputeQueue,
supports_graph=False)

View File

@@ -1,7 +1,8 @@
from __future__ import annotations
import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys, subprocess, struct
assert sys.platform != 'win32'
from tinygrad.device import BufferSpec, Compiled, Allocator, Compiler, MallocAllocator
from tinygrad.device import BufferSpec, Compiled, Allocator, Compiler
from tinygrad.runtime.ops_cpu import CPUAllocator
from tinygrad.dtype import dtypes, DType, PtrDType
from tinygrad.uop.ops import Ops, UOp
from tinygrad.helpers import getenv, round_up, mv_address, to_mv, cpu_objdump, DEBUG
@@ -131,7 +132,8 @@ class DSPDevice(Compiled):
def __init__(self, device:str=""):
compiler_args = ["--target=hexagon", "-mcpu=hexagonv65", "-fuse-ld=lld", "-nostdlib", "-mhvx=v65", "-mhvx-length=128b"]
if getenv("MOCKDSP"):
super().__init__(device, MallocAllocator, MockDSPRenderer(), ClangCompiler(None, ["-static"] + compiler_args, 'llvm-objdump'), MockDSPProgram)
super().__init__(device, CPUAllocator(self), MockDSPRenderer(),
ClangCompiler(None, ["-static"] + compiler_args, 'llvm-objdump'), MockDSPProgram)
else:
self.ion_fd = os.open('/dev/ion', os.O_RDONLY)
# Generate link script to pass into clang. Aligning all used sections to 4k fixes invoke problem.
@@ -293,11 +295,11 @@ class MockDSPProgram:
os.chmod(dsp_lib.name, 0o0777)
# NOTE: this timing includes a docker launch
proc = subprocess.run(["docker", "run", "--rm", "-i", "-v", f"{os.path.abspath(os.path.dirname(dsp_lib.name))}:/work", "-w", "/work",
"qemu-hexagon", "-c", f"qemu-hexagon {'-strace' if DEBUG >= 5 else ''} /work/"+os.path.basename(dsp_lib.name)],
input=b''.join([bytes(x) for x in bufs] + [struct.pack("I", x) for x in vals]), stdout=subprocess.PIPE, check=True)
"qemu-hexagon", "-c", f"qemu-hexagon {'-strace' if DEBUG >= 5 else ''} /work/"+os.path.basename(dsp_lib.name)],
input=b''.join([bytes(to_mv(x.va_addr, x.size)) for x in bufs] + [struct.pack("I", x) for x in vals]), stdout=subprocess.PIPE, check=True)
offset = 4
for x in bufs:
x[:] = proc.stdout[offset:offset+len(x)]
offset += len(x)
x.cpu_view()[:] = proc.stdout[offset:offset+x.size]
offset += x.size
assert offset == len(proc.stdout)
return struct.unpack("I", proc.stdout[0:4])[0] / 1e9 # pretend it's 1 Ghz, but this is an inscount, not a time

View File

@@ -1,5 +1,7 @@
import ctypes, platform
from tinygrad.device import Compiled, Compiler, MallocAllocator, CPUProgram
import ctypes, platform, functools
from tinygrad.device import Compiler
from tinygrad.runtime.support.hcq import HCQCompiled, HCQSignal
from tinygrad.runtime.ops_cpu import CPUAllocator, CPUProgram, CPUComputeQueue
from tinygrad.helpers import OSX, getenv, capstone_flatdump, DEBUG
from tinygrad.renderer.llvmir import LLVMRenderer
import tinygrad.runtime.autogen.llvm as llvm
@@ -69,5 +71,7 @@ class HostLLVMCompiler(LLVMCompiler):
cpu, feats = ctypes.string_at(llvm.LLVMGetHostCPUName()), (b'+reserve-x18,' if OSX else b'') + ctypes.string_at(llvm.LLVMGetHostCPUFeatures())
super().__init__(cpu.decode(), feats.decode())
class LLVMDevice(Compiled):
def __init__(self, device:str): super().__init__(device, MallocAllocator, LLVMRenderer(), HostLLVMCompiler(), CPUProgram)
class LLVMDevice(HCQCompiled):
def __init__(self, device:str=""):
super().__init__(device, CPUAllocator(self), LLVMRenderer(), HostLLVMCompiler(), functools.partial(CPUProgram, self), HCQSignal, CPUComputeQueue,
supports_graph=False)