mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
hcq: move cpu to hcq (#11262)
* hcq: move cpu to hcq * import time * upd * fix * windows support * hm * cleaner * fix timer * fix timing * std is ns * skip profiler * mypy * cleaner * cleanups * after merge * default is back
This commit is contained in:
@@ -7,28 +7,30 @@
|
||||
|
||||
print("******** first, the runtime ***********")
|
||||
|
||||
from tinygrad.runtime.ops_cpu import ClangJITCompiler, MallocAllocator, CPUProgram
|
||||
from tinygrad.runtime.ops_cpu import ClangJITCompiler, CPUDevice, CPUProgram
|
||||
|
||||
cpu = CPUDevice()
|
||||
|
||||
# allocate some buffers
|
||||
out = MallocAllocator.alloc(4)
|
||||
a = MallocAllocator.alloc(4)
|
||||
b = MallocAllocator.alloc(4)
|
||||
out = cpu.allocator.alloc(4)
|
||||
a = cpu.allocator.alloc(4)
|
||||
b = cpu.allocator.alloc(4)
|
||||
|
||||
# load in some values (little endian)
|
||||
MallocAllocator._copyin(a, memoryview(bytearray([2,0,0,0])))
|
||||
MallocAllocator._copyin(b, memoryview(bytearray([3,0,0,0])))
|
||||
cpu.allocator._copyin(a, memoryview(bytearray([2,0,0,0])))
|
||||
cpu.allocator._copyin(b, memoryview(bytearray([3,0,0,0])))
|
||||
|
||||
# compile a program to a binary
|
||||
lib = ClangJITCompiler().compile("void add(int *out, int *a, int *b) { out[0] = a[0] + b[0]; }")
|
||||
|
||||
# create a runtime for the program
|
||||
fxn = CPUProgram("add", lib)
|
||||
fxn = cpu.runtime("add", lib)
|
||||
|
||||
# run the program
|
||||
fxn(out, a, b)
|
||||
|
||||
# check the data out
|
||||
print(val := MallocAllocator._as_buffer(out).cast("I").tolist()[0])
|
||||
print(val := cpu.allocator._as_buffer(out).cast("I").tolist()[0])
|
||||
assert val == 5
|
||||
|
||||
|
||||
@@ -46,7 +48,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
|
||||
out = Buffer(DEVICE, 1, dtypes.int32).allocate()
|
||||
a = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struct.pack("I", 2))))
|
||||
b = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struct.pack("I", 3))))
|
||||
# NOTE: a._buf is the same as the return from MallocAllocator.alloc
|
||||
# NOTE: a._buf is the same as the return from cpu.allocator.alloc
|
||||
|
||||
# describe the computation
|
||||
buf_1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 1)
|
||||
|
||||
@@ -10,6 +10,7 @@ import tensorflow as tf
|
||||
import tf2onnx
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import to_mv
|
||||
from extra.export_model import export_model_clang, compile_net, jit_model
|
||||
|
||||
def get_uncompiled_model2(dataset_size=32, output_size=4):
|
||||
@@ -47,8 +48,8 @@ def compile_onnx_model(onnx_model):
|
||||
cprog.append("void initialize(float *weights) {")
|
||||
weights = bytes()
|
||||
for name,cl in bufs_to_save.items():
|
||||
cprog.append(f"memcpy({name}, weights + {len(weights)//4}, {len(cl._buf)*4});")
|
||||
weights += bytes(cl._buf)
|
||||
cprog.append(f"memcpy({name}, weights + {len(weights)//4}, {cl._buf.size});")
|
||||
weights += bytes(to_mv(cl._buf.va_addr, cl._buf.size))
|
||||
cprog.append("}")
|
||||
|
||||
# write the weights to disk
|
||||
|
||||
@@ -4,7 +4,7 @@ from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.tensor import Device, Tensor
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
from tinygrad.nn.state import get_state_dict
|
||||
from tinygrad.helpers import Context
|
||||
from tinygrad.helpers import Context, to_mv
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import Ops
|
||||
import json
|
||||
@@ -68,7 +68,7 @@ def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,in
|
||||
|
||||
if not wasm:
|
||||
for name,cl in bufs_to_save.items():
|
||||
weight = ''.join(["\\x%02X"%x for x in bytes(cl._buf)])
|
||||
weight = ''.join(["\\x%02X"%x for x in bytes(to_mv(cl._buf.va_addr, cl._buf.size))])
|
||||
cprog.append(f"unsigned char {name}_data[] = \"{weight}\";")
|
||||
cprog += [f"{dtype_map[dtype]} {name}[{len}];" if name not in bufs_to_save else f"{dtype_map[dtype]} *{name} = ({dtype_map[dtype]} *){name}_data;" for name,(len,dtype,_key) in bufs.items() if name not in input_names+output_names]
|
||||
cprog += [f"void net({forward_args}) {{"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"]
|
||||
|
||||
@@ -75,7 +75,7 @@ class TestHCQ(unittest.TestCase):
|
||||
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
|
||||
TestHCQ.d0.timeline_value += 1
|
||||
|
||||
@unittest.skipIf(MOCKGPU, "Can't handle async update on MOCKGPU for now")
|
||||
@unittest.skipIf(MOCKGPU or Device.DEFAULT in {"CPU", "LLVM"}, "Can't handle async update on MOCKGPU for now")
|
||||
def test_wait_late_set(self):
|
||||
for queue_type in [TestHCQ.d0.hw_compute_queue_t, TestHCQ.d0.hw_copy_queue_t]:
|
||||
if queue_type is None: continue
|
||||
@@ -137,6 +137,7 @@ class TestHCQ(unittest.TestCase):
|
||||
val = TestHCQ.a.uop.buffer.as_buffer().cast("f")[0]
|
||||
assert val == 200.0, f"got val {val}"
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"}, "No globals/locals on LLVM/CPU")
|
||||
def test_exec_update(self):
|
||||
sint_global = (Variable("sint_global", 0, 0xffffffff, dtypes.uint32),) + tuple(TestHCQ.runner.p.global_size[1:])
|
||||
sint_local = (Variable("sint_local", 0, 0xffffffff, dtypes.uint32),) + tuple(TestHCQ.runner.p.local_size[1:])
|
||||
@@ -154,6 +155,7 @@ class TestHCQ(unittest.TestCase):
|
||||
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1]
|
||||
assert val == 0.0, f"got val {val}, should not be updated"
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"}, "No globals/locals on LLVM/CPU")
|
||||
def test_exec_update_fuzz(self):
|
||||
virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32)
|
||||
virt_local = [Variable(f"local_{i}", 0, 0xffffffff, dtypes.uint32) for i in range(3)]
|
||||
@@ -336,7 +338,7 @@ class TestHCQ(unittest.TestCase):
|
||||
et = float(sig_en.timestamp - sig_st.timestamp)
|
||||
|
||||
print(f"exec kernel time: {et:.2f} us")
|
||||
assert 0.1 <= et <= (15000 if MOCKGPU else 100)
|
||||
assert 0.1 <= et <= (15000 if MOCKGPU or Device.DEFAULT in {"CPU", "LLVM"} else 100)
|
||||
|
||||
def test_speed_copy_bandwidth(self):
|
||||
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
|
||||
|
||||
@@ -30,7 +30,10 @@ def helper_profile_filter_device(profile, device:str):
|
||||
assert len(dev_events) == 1, "only one device registration event is expected"
|
||||
return [x for x in profile if getattr(x, "device", None) == device], dev_events[0]
|
||||
|
||||
@unittest.skipUnless(issubclass(type(Device[Device.DEFAULT]), HCQCompiled) or Device.DEFAULT in {"METAL"}, "HCQ device required to run")
|
||||
# TODO: support in HCQCompiled
|
||||
is_cpu_hcq = Device.DEFAULT in {"CPU", "LLVM"}
|
||||
|
||||
@unittest.skipUnless((issubclass(type(Device[Device.DEFAULT]), HCQCompiled) and not is_cpu_hcq) or Device.DEFAULT in {"METAL"}, "Dev not supported")
|
||||
class TestProfiler(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(self):
|
||||
|
||||
@@ -2,9 +2,9 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass, replace, field
|
||||
from collections import defaultdict
|
||||
from typing import Any, Generic, TypeVar, Iterator
|
||||
import importlib, inspect, functools, pathlib, os, ctypes, ctypes.util, platform, contextlib, sys, re, atexit, pickle, decimal, time
|
||||
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp, mv_address, \
|
||||
cpu_time_execution, colored, Context, round_up, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE, cpu_events, ProfileEvent
|
||||
import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal, time
|
||||
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, \
|
||||
colored, Context, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE, cpu_events, ProfileEvent
|
||||
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
@@ -260,74 +260,6 @@ class LRUAllocator(Allocator, Generic[DeviceType]):
|
||||
if LRU and (options is None or not options.nolru): self.cache[(size, options)].append(opaque)
|
||||
else: super().free(opaque, size, options)
|
||||
|
||||
class _MallocAllocator(LRUAllocator['Compiled']):
  """Host-memory allocator whose buffers are ctypes `c_uint8` arrays."""
  def _alloc(self, size:int, options:BufferSpec):
    # caller-provided memory is wrapped in place, no new storage is created
    if options.external_ptr: return (ctypes.c_uint8 * size).from_address(options.external_ptr)
    # must be aligned to 0x20 for 256-bit ymm registers
    # TODO: investigate if this is the cause of nondeterminism in speed
    return self._alloc_aligned(size, 0x1000 if size >= 0x1000 else 0x20)

  def _alloc_aligned(self, size:int, alignment:int):
    # over-allocate by `alignment` bytes, then expose the `size`-byte window that starts on an aligned address
    backing = (ctypes.c_uint8 * (size + alignment))()
    base = ctypes.addressof(backing)
    return (ctypes.c_uint8 * size).from_buffer(backing, round_up(base, alignment) - base)

  def _as_buffer(self, src) -> memoryview:
    return flat_mv(memoryview(src))

  def _as_dmaref(self, buf):
    return DMACPURef(ctypes.addressof(buf), ctypes.sizeof(buf))

  def _copyin(self, dest, src:memoryview):
    # copies len(src) bytes from the memoryview into `dest`
    ctypes.memmove(dest, mv_address(src), len(src))

  def _copyout(self, dest:memoryview, src):
    # copies len(dest) bytes from `src` into the memoryview
    ctypes.memmove(mv_address(dest), src, len(dest))

  def _offset(self, buf, size:int, offset:int):
    window = self._as_buffer(buf)[offset:offset+size]
    return from_mv(window)
|
||||
|
||||
MallocAllocator = _MallocAllocator(None) # type: ignore
|
||||
|
||||
# NOTE: MAP_JIT is added to mmap module in python 3.13
|
||||
MAP_JIT = 0x0800
|
||||
|
||||
# CPUProgram is a jit/shellcode program that can be just mmapped and jumped to
|
||||
class CPUProgram:
  """A jit/shellcode program: machine code copied into executable memory and called through ctypes."""
  rt_lib = ctypes.CDLL(
    ctypes.util.find_library('System' if OSX else 'kernel32') if OSX or sys.platform == "win32" else 'libgcc_s.so.1')

  def __init__(self, name:str, lib:bytes):
    nbytes = len(lib)
    if sys.platform == "win32":
      PAGE_EXECUTE_READWRITE, MEM_COMMIT, MEM_RESERVE = 0x40, 0x1000, 0x2000
      kernel32 = ctypes.windll.kernel32
      kernel32.VirtualAlloc.restype = ctypes.c_void_p
      self.mem = kernel32.VirtualAlloc(ctypes.c_void_p(0), ctypes.c_size_t(nbytes), MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE)
      ctypes.memmove(self.mem, lib, nbytes)
      kernel32.GetCurrentProcess.restype = ctypes.c_void_p
      proc = kernel32.GetCurrentProcess()
      kernel32.FlushInstructionCache(ctypes.c_void_p(proc), ctypes.c_void_p(self.mem), ctypes.c_size_t(nbytes))
      self.fxn = ctypes.CFUNCTYPE(None)(self.mem)
    else:
      from mmap import mmap, PROT_READ, PROT_WRITE, PROT_EXEC, MAP_ANON, MAP_PRIVATE
      # On apple silicon with SPRR enabled (it always is in macos) RWX pages are unrepresentable: https://blog.svenpeter.dev/posts/m1_sprr_gxf/
      # MAP_JIT allows us to easily flip pages from RW- to R-X and vice versa. It is a noop on intel cpus. (man pthread_jit_write_protect_np)
      map_flags = MAP_ANON | MAP_PRIVATE | (MAP_JIT if OSX else 0)
      self.mem = mmap(-1, nbytes, map_flags, PROT_READ | PROT_WRITE | PROT_EXEC)

      if OSX: CPUProgram.rt_lib.pthread_jit_write_protect_np(False)
      self.mem.write(lib)
      if OSX: CPUProgram.rt_lib.pthread_jit_write_protect_np(True)

      # __clear_cache isn't a normal libc function, but a compiler support routine found in libgcc_s for gcc and compiler-rt for clang.
      # libgcc_s comes as shared library but compiler-rt is only a bunch of static library archives which we can't directly load, but fortunately
      # it somehow found its way into libSystem on macos (likely because it used __builtin_clear_cache) and libgcc_s is ~always present on linux
      # Using ["name"] instead of .name because otherwise name is getting mangled: https://docs.python.org/3.12/reference/expressions.html#index-5
      base = mv_address(self.mem)
      CPUProgram.rt_lib["__clear_cache"](ctypes.c_void_p(base), ctypes.c_void_p(base + nbytes))

      self.fxn = ctypes.CFUNCTYPE(None)(base)

  def __call__(self, *bufs, vals=(), wait=False):
    call_args = [*bufs, *vals]
    # NOTE: replace this by --target={host's triple}-elf in clang args once we only support macos sequoia and later.
    # Apple relaxes abi requirement for stack arguments to always be at least 8 byte aligned on arm64
    # https://developer.apple.com/documentation/xcode/writing-arm64-code-for-apple-platforms
    # This hack is required because clang/llvm bug doesn't allow us to just use {host's triple}+'-elf' (relocation failures)
    # The bug was fixed in https://github.com/llvm/llvm-project/commit/454cc36630296262cdb6360b60f90a64a97f7f1a but was only backported to xcode 16+
    if platform.machine() == "arm64" and OSX:
      # arguments past the eighth that are Python ints are widened to 64 bits
      call_args = call_args[:8] + [ctypes.c_int64(a) if isinstance(a, int) else a for a in call_args[8:]]
    return cpu_time_execution(lambda: self.fxn(*call_args), enable=wait)

  def __del__(self):
    # only the VirtualAlloc'd region is released explicitly here
    if sys.platform == 'win32':
      ctypes.windll.kernel32.VirtualFree(ctypes.c_void_p(self.mem), ctypes.c_size_t(0), 0x8000) #0x8000 - MEM_RELEASE
|
||||
|
||||
# **************** for Compiled Devices ****************
|
||||
|
||||
class CompileError(Exception): pass
|
||||
|
||||
@@ -296,11 +296,6 @@ def fetch(url:str, name:pathlib.Path|str|None=None, subdir:str|None=None, gunzip
|
||||
|
||||
# *** Exec helpers
|
||||
|
||||
def cpu_time_execution(cb, enable):
  """Call `cb()` once; return the elapsed `time.perf_counter()` seconds if `enable` is truthy, otherwise None."""
  if not enable:
    cb()
    return None
  start = time.perf_counter()
  cb()
  return time.perf_counter() - start
|
||||
|
||||
def cpu_objdump(lib, objdump_tool='objdump'):
|
||||
with tempfile.NamedTemporaryFile(delete=True) as f:
|
||||
pathlib.Path(f.name).write_bytes(lib)
|
||||
|
||||
@@ -139,7 +139,7 @@ class LLVMRenderer(Renderer):
|
||||
def render(self, uops: list[UOp]) -> str: return "\n".join((k:=self._render_kernel(uops))[0] + (k[1], self._render_footer(uops)))
|
||||
def _render_footer(self, uops: list[UOp]) -> str: return 'attributes #0 = { alwaysinline nounwind "no-builtins" "no-trapping-math"="true" }'
|
||||
def _render_fn(self, name:str, args:list[tuple[str,DType]], kernel:list[str], prefix:list[str]|None=None) -> str:
|
||||
# NOTE: MallocAllocator promises 0x20 alignment
|
||||
# NOTE: CPUAllocator promises 0x20 alignment
|
||||
sargs = ", ".join([f"{ldt(dt)}{' noalias align 32' if isinstance(dt, PtrDType) else ''} {name}" for name,dt in args])
|
||||
sprefix = "".join([f" {x}" for x in (prefix or []) + [self.abi] if x is not None])
|
||||
return "\n".join([f"define{sprefix} void @{name}({sargs}) #0", "{"] + kernel + [" ret void\n}"])
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import platform, subprocess, sys
|
||||
from tinygrad.helpers import capstone_flatdump, getenv
|
||||
from tinygrad.device import Compiled, Compiler, MallocAllocator, CPUProgram
|
||||
from __future__ import annotations
|
||||
import platform, subprocess, sys, ctypes, functools, time
|
||||
from tinygrad.helpers import capstone_flatdump, getenv, from_mv, to_mv, OSX, mv_address, round_up, wait_cond
|
||||
from tinygrad.device import Compiler, BufferSpec, DMACPURef
|
||||
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocatorBase, HCQBuffer, HWQueue, HCQArgsState, HCQSignal, HCQProgram, MMIOInterface
|
||||
from tinygrad.runtime.support.elf import jit_loader
|
||||
from tinygrad.renderer.cstyle import ClangRenderer
|
||||
from tinygrad.uop.ops import sint
|
||||
|
||||
class ClangJITCompiler(Compiler):
|
||||
def __init__(self, cachekey="compile_clang_jit"): super().__init__(cachekey)
|
||||
@@ -18,5 +21,84 @@ class ClangJITCompiler(Compiler):
|
||||
|
||||
def disassemble(self, lib:bytes): return capstone_flatdump(lib)
|
||||
|
||||
class CPUDevice(Compiled):
  """Compiled device wiring MallocAllocator, ClangRenderer, ClangJITCompiler and CPUProgram together."""
  def __init__(self, device:str):
    super().__init__(device, MallocAllocator, ClangRenderer(), ClangJITCompiler(), CPUProgram)
|
||||
class CPUComputeQueue(HWQueue):
  """Queue whose commands are Python callables; `_submit` runs them in the order they were queued."""
  def _exec(self, prg, bufs, *args):
    # the first `bufs` args are passed as 64-bit unsigned values, the remaining ones as signed ints
    val_ctype = ctypes.c_int64 if platform.machine() == "arm64" else ctypes.c_int32
    ptr_args = [ctypes.c_uint64(a) for a in args[:bufs]]
    val_args = [val_ctype(v) for v in args[bufs:]]
    prg.fxn(*ptr_args, *val_args)

  def _signal(self, signal_addr, value):
    # store `value` as a 32-bit unsigned int at signal_addr
    to_mv(signal_addr, 4).cast('I')[0] = value

  def _wait(self, signal_addr, value):
    # wait_cond is given timeout_ms=60000 and a predicate comparing the 32-bit value at signal_addr against `value`
    wait_cond(lambda: to_mv(signal_addr, 4).cast('I')[0] >= value, timeout_ms=60000)

  def _timestamp(self, timestamp_addr):
    # store time.perf_counter_ns() as a 64-bit unsigned int at timestamp_addr
    to_mv(timestamp_addr, 8).cast('Q')[0] = time.perf_counter_ns()

  def cmd(self, cmd, *args):
    # queue layout per command: fn, argc, args...
    self.q(cmd, len(args), *args)
    return self

  def memory_barrier(self):
    return self

  def exec(self, prg:CPUProgram, args_state:HCQArgsState, global_size, local_size):
    # global_size and local_size are accepted but not used by this queue
    buf_addrs = [x.va_addr for x in args_state.bufs]
    return self.cmd(self._exec, prg, len(args_state.bufs), *buf_addrs, *args_state.vals)

  def wait(self, signal, value=0):
    return self.cmd(self._wait, signal.value_addr, value)

  def timestamp(self, signal):
    return self.cmd(self._timestamp, signal.timestamp_addr)

  def signal(self, signal, value:sint=0):
    return self.cmd(self._signal, signal.value_addr, value)

  def _submit(self, dev):
    # Execute the commands in the queue: fn, argc, args...
    pos = 0
    while pos < len(self._q):
      fn, argc = self._q[pos], self._q[pos + 1]
      fn(*self._q[pos + 2:pos + 2 + argc])
      pos += argc + 2
|
||||
|
||||
# NOTE: MAP_JIT is added to mmap module in python 3.13
|
||||
MAP_JIT = 0x0800
|
||||
|
||||
def _load_cpu_rt_lib() -> ctypes.CDLL:
  """Load the shared library CPUProgram uses for its runtime support calls.

  On macOS and Windows the library is located with `find_library`; elsewhere 'libgcc_s.so.1' is loaded by name.
  """
  if OSX or sys.platform == "win32":
    # `import ctypes` does not load the `ctypes.util` submodule, so import find_library explicitly instead of
    # relying on some other module having imported it first
    from ctypes.util import find_library
    return ctypes.CDLL(find_library('System' if OSX else 'kernel32'))
  return ctypes.CDLL('libgcc_s.so.1')

class CPUProgram(HCQProgram):
  """A jit/shellcode program: `lib` is copied into executable memory and exposed as the ctypes callable `self.fxn`.

  Args:
    dev: device passed through to HCQProgram.__init__.
    name: program name passed through to HCQProgram.__init__.
    lib: the machine code bytes to load.
  """
  rt_lib = _load_cpu_rt_lib()

  def __init__(self, dev, name:str, lib:bytes):
    if sys.platform == "win32":
      PAGE_EXECUTE_READWRITE, MEM_COMMIT, MEM_RESERVE = 0x40, 0x1000, 0x2000
      ctypes.windll.kernel32.VirtualAlloc.restype = ctypes.c_void_p
      self.mem = ctypes.windll.kernel32.VirtualAlloc(ctypes.c_void_p(0), ctypes.c_size_t(len(lib)), MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE)
      ctypes.memmove(self.mem, lib, len(lib))
      ctypes.windll.kernel32.GetCurrentProcess.restype = ctypes.c_void_p
      proc = ctypes.windll.kernel32.GetCurrentProcess()
      ctypes.windll.kernel32.FlushInstructionCache(ctypes.c_void_p(proc), ctypes.c_void_p(self.mem), ctypes.c_size_t(len(lib)))
      self.fxn = ctypes.CFUNCTYPE(None)(self.mem)
    else:
      from mmap import mmap, PROT_READ, PROT_WRITE, PROT_EXEC, MAP_ANON, MAP_PRIVATE
      # On apple silicon with SPRR enabled (it always is in macos) RWX pages are unrepresentable: https://blog.svenpeter.dev/posts/m1_sprr_gxf/
      # MAP_JIT allows us to easily flip pages from RW- to R-X and vice versa. It is a noop on intel cpus. (man pthread_jit_write_protect_np)
      self.mem = mmap(-1, len(lib), MAP_ANON | MAP_PRIVATE | (MAP_JIT if OSX else 0), PROT_READ | PROT_WRITE | PROT_EXEC)

      if OSX: CPUProgram.rt_lib.pthread_jit_write_protect_np(False)
      self.mem.write(lib)
      if OSX: CPUProgram.rt_lib.pthread_jit_write_protect_np(True)

      # __clear_cache isn't a normal libc function, but a compiler support routine found in libgcc_s for gcc and compiler-rt for clang.
      # libgcc_s comes as shared library but compiler-rt is only a bunch of static library archives which we can't directly load, but fortunately
      # it somehow found its way into libSystem on macos (likely because it used __builtin_clear_cache) and libgcc_s is ~always present on linux
      # Using ["name"] instead of .name because otherwise name is getting mangled: https://docs.python.org/3.12/reference/expressions.html#index-5
      CPUProgram.rt_lib["__clear_cache"](ctypes.c_void_p(mv_address(self.mem)), ctypes.c_void_p(mv_address(self.mem) + len(lib)))

      self.fxn = ctypes.CFUNCTYPE(None)(mv_address(self.mem))

    # the code is loaded before the HCQProgram base is initialised
    super().__init__(HCQArgsState, dev, name, kernargs_alloc_size=0)

  def __del__(self):
    # only the VirtualAlloc'd region is released explicitly here
    if sys.platform == 'win32': ctypes.windll.kernel32.VirtualFree(ctypes.c_void_p(self.mem), ctypes.c_size_t(0), 0x8000) #0x8000 - MEM_RELEASE
|
||||
|
||||
class CPUAllocator(HCQAllocatorBase):
  """HCQ allocator over host memory: buffers are ctypes `c_uint8` arrays described by an HCQBuffer."""
  def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer:
    if options.external_ptr:
      # wrap caller-provided memory in place
      storage = (ctypes.c_uint8 * size).from_address(options.external_ptr)
    else:
      # over-allocate by 0x1000 bytes and expose the window starting at the first 0x1000-aligned address
      backing = (ctypes.c_uint8 * (size + 0x1000))()
      base = ctypes.addressof(backing)
      storage = (ctypes.c_uint8 * size).from_buffer(backing, round_up(base, 0x1000) - base)
    va, sz = ctypes.addressof(storage), ctypes.sizeof(storage)
    # the ctypes array itself is handed over as `meta`
    return HCQBuffer(va, sz, meta=storage, view=MMIOInterface(va, sz, fmt='B'), owner=self.dev)

  def _as_buffer(self, src) -> memoryview:
    return to_mv(src.va_addr, src.size)

  def _as_dmaref(self, buf):
    return DMACPURef(buf.va_addr, buf.size)

  def _copyin(self, dest, src:memoryview):
    # copies len(src) bytes from the memoryview to dest.va_addr
    ctypes.memmove(dest.va_addr, from_mv(src), len(src))

  def _copyout(self, dest:memoryview, src):
    # copies len(dest) bytes from src.va_addr into the memoryview
    ctypes.memmove(from_mv(dest), src.va_addr, len(dest))

  def _map(self, buf:HCQBuffer):
    # only validates that the buffer has an MMIOInterface view; nothing else is done here
    view = buf.view
    if view is None or not isinstance(view, MMIOInterface):
      raise RuntimeError("Cannot map buffer without view to cpu")
|
||||
|
||||
class CPUDevice(HCQCompiled):
  """HCQ device built from CPUAllocator, ClangRenderer, ClangJITCompiler, CPUProgram, HCQSignal and CPUComputeQueue."""
  def __init__(self, device:str=""):
    # programs are constructed with this device bound as their first argument
    program_factory = functools.partial(CPUProgram, self)
    super().__init__(device, CPUAllocator(self), ClangRenderer(), ClangJITCompiler(), program_factory, HCQSignal, CPUComputeQueue,
                     supports_graph=False)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from __future__ import annotations
|
||||
import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys, subprocess, struct
|
||||
assert sys.platform != 'win32'
|
||||
from tinygrad.device import BufferSpec, Compiled, Allocator, Compiler, MallocAllocator
|
||||
from tinygrad.device import BufferSpec, Compiled, Allocator, Compiler
|
||||
from tinygrad.runtime.ops_cpu import CPUAllocator
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType
|
||||
from tinygrad.uop.ops import Ops, UOp
|
||||
from tinygrad.helpers import getenv, round_up, mv_address, to_mv, cpu_objdump, DEBUG
|
||||
@@ -131,7 +132,8 @@ class DSPDevice(Compiled):
|
||||
def __init__(self, device:str=""):
|
||||
compiler_args = ["--target=hexagon", "-mcpu=hexagonv65", "-fuse-ld=lld", "-nostdlib", "-mhvx=v65", "-mhvx-length=128b"]
|
||||
if getenv("MOCKDSP"):
|
||||
super().__init__(device, MallocAllocator, MockDSPRenderer(), ClangCompiler(None, ["-static"] + compiler_args, 'llvm-objdump'), MockDSPProgram)
|
||||
super().__init__(device, CPUAllocator(self), MockDSPRenderer(),
|
||||
ClangCompiler(None, ["-static"] + compiler_args, 'llvm-objdump'), MockDSPProgram)
|
||||
else:
|
||||
self.ion_fd = os.open('/dev/ion', os.O_RDONLY)
|
||||
# Generate link script to pass into clang. Aligning all used sections to 4k fixes invoke problem.
|
||||
@@ -293,11 +295,11 @@ class MockDSPProgram:
|
||||
os.chmod(dsp_lib.name, 0o0777)
|
||||
# NOTE: this timing includes a docker launch
|
||||
proc = subprocess.run(["docker", "run", "--rm", "-i", "-v", f"{os.path.abspath(os.path.dirname(dsp_lib.name))}:/work", "-w", "/work",
|
||||
"qemu-hexagon", "-c", f"qemu-hexagon {'-strace' if DEBUG >= 5 else ''} /work/"+os.path.basename(dsp_lib.name)],
|
||||
input=b''.join([bytes(x) for x in bufs] + [struct.pack("I", x) for x in vals]), stdout=subprocess.PIPE, check=True)
|
||||
"qemu-hexagon", "-c", f"qemu-hexagon {'-strace' if DEBUG >= 5 else ''} /work/"+os.path.basename(dsp_lib.name)],
|
||||
input=b''.join([bytes(to_mv(x.va_addr, x.size)) for x in bufs] + [struct.pack("I", x) for x in vals]), stdout=subprocess.PIPE, check=True)
|
||||
offset = 4
|
||||
for x in bufs:
|
||||
x[:] = proc.stdout[offset:offset+len(x)]
|
||||
offset += len(x)
|
||||
x.cpu_view()[:] = proc.stdout[offset:offset+x.size]
|
||||
offset += x.size
|
||||
assert offset == len(proc.stdout)
|
||||
return struct.unpack("I", proc.stdout[0:4])[0] / 1e9 # pretend it's 1 Ghz, but this is an inscount, not a time
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import ctypes, platform
|
||||
from tinygrad.device import Compiled, Compiler, MallocAllocator, CPUProgram
|
||||
import ctypes, platform, functools
|
||||
from tinygrad.device import Compiler
|
||||
from tinygrad.runtime.support.hcq import HCQCompiled, HCQSignal
|
||||
from tinygrad.runtime.ops_cpu import CPUAllocator, CPUProgram, CPUComputeQueue
|
||||
from tinygrad.helpers import OSX, getenv, capstone_flatdump, DEBUG
|
||||
from tinygrad.renderer.llvmir import LLVMRenderer
|
||||
import tinygrad.runtime.autogen.llvm as llvm
|
||||
@@ -69,5 +71,7 @@ class HostLLVMCompiler(LLVMCompiler):
|
||||
cpu, feats = ctypes.string_at(llvm.LLVMGetHostCPUName()), (b'+reserve-x18,' if OSX else b'') + ctypes.string_at(llvm.LLVMGetHostCPUFeatures())
|
||||
super().__init__(cpu.decode(), feats.decode())
|
||||
|
||||
class LLVMDevice(Compiled):
  """Compiled device wiring MallocAllocator, LLVMRenderer, HostLLVMCompiler and CPUProgram together."""
  def __init__(self, device:str):
    super().__init__(device, MallocAllocator, LLVMRenderer(), HostLLVMCompiler(), CPUProgram)
|
||||
class LLVMDevice(HCQCompiled):
  """HCQ device built from CPUAllocator, LLVMRenderer, HostLLVMCompiler, CPUProgram, HCQSignal and CPUComputeQueue."""
  def __init__(self, device:str=""):
    # programs are constructed with this device bound as their first argument
    program_factory = functools.partial(CPUProgram, self)
    super().__init__(device, CPUAllocator(self), LLVMRenderer(), HostLLVMCompiler(), program_factory, HCQSignal, CPUComputeQueue,
                     supports_graph=False)
|
||||
|
||||
Reference in New Issue
Block a user