BufferSpec and ProgramSpec [pr] (#7814)

* BufferSpec and ProgramSpec [pr]

* delete preallocate, it's unused

* Revert "delete preallocate, it's unused"

This reverts commit dcfcfaccde.
George Hotz
2024-11-21 12:18:05 +08:00
committed by GitHub
parent 490a6130af
commit c5d458ce02
27 changed files with 139 additions and 135 deletions
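
Not part of the diff: a hedged before/after sketch of the rename this PR performs. The classes keep their fields and behavior; only the names change (BufferOptions -> BufferSpec, Program -> ProgramSpec). Written against a tinygrad checkout at or after this commit.

from tinygrad import Device, dtypes
from tinygrad.device import Buffer, BufferSpec   # was: from tinygrad.device import Buffer, BufferOptions
from tinygrad.renderer import ProgramSpec        # was: from tinygrad.renderer import Program

# allocation flags that used to live on BufferOptions now live on BufferSpec
buf = Buffer(Device.DEFAULT, 4, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
buf.copyin(memoryview(bytearray(b"\x01\x02\x03\x04")))
print(bytes(buf.as_buffer()))  # b'\x01\x02\x03\x04'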

View File

@@ -1,7 +1,7 @@
import ctypes, collections, time, itertools
from typing import List, Any, Dict, cast, Optional, Tuple
from tinygrad.helpers import init_c_var, round_up
from tinygrad.device import Buffer, BufferOptions
from tinygrad.device import Buffer, BufferSpec
from tinygrad.device import Compiled, Device
from tinygrad.ops import Variable
from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
@@ -44,7 +44,7 @@ class HSAGraph(MultiGraphRunner):
kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
for ji in self.jit_cache:
if isinstance(ji.prg, CompiledRunner): kernargs_size[ji.prg.dev] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferOptions()) for dev,sz in kernargs_size.items()}
kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferSpec()) for dev,sz in kernargs_size.items()}
# Fill initial arguments.
self.ji_kargs_structs: Dict[int, ctypes.Structure] = {}

View File

@@ -4,7 +4,7 @@ from typing import Tuple, TypeVar, List, Any, cast, Set
import tinygrad.runtime.autogen.hip as hip
from tinygrad.helpers import DEBUG, getenv, init_c_var
from tinygrad.helpers import from_mv, round_up, to_mv, colored, init_c_struct_t
from tinygrad.device import Compiled, LRUAllocator, BufferOptions, Runner, Device, Buffer, MallocAllocator, update_stats, Compiler, CompilerOptions
from tinygrad.device import Compiled, LRUAllocator, BufferSpec, Runner, Device, Buffer, MallocAllocator, update_stats, Compiler, CompilerOptions
from tinygrad.renderer.cstyle import HIPRenderer
from tinygrad.runtime.support.hip_comgr import compile_hip
from tinygrad.renderer.rdna import uops_to_rdna
@@ -93,7 +93,7 @@ class HIPAllocator(LRUAllocator):
def _alloc(self, size:int):
hip_set_device(self.device.device)
return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipMalloc(ctypes.byref(x), size)))
def _alloc_with_options(self, size:int, options:BufferOptions):
def _alloc_with_options(self, size:int, options:BufferSpec):
hip_set_device(self.device.device)
if options.uncached:
return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipExtMallocWithFlags(ctypes.byref(x), size, 3))) # hipDeviceMallocUncached = 3
@@ -105,7 +105,7 @@ class HIPAllocator(LRUAllocator):
def copy_from_fd(self, dest, fd, offset, size):
hip_set_device(self.device.device)
if not hasattr(self, 'hb'):
self.hb = [self._alloc_with_options(CHUNK_SIZE, BufferOptions(host=True)) for _ in range(2)]
self.hb = [self._alloc_with_options(CHUNK_SIZE, BufferSpec(host=True)) for _ in range(2)]
self.hb_events = [None, None]
self.hb_polarity = 0
fo = io.FileIO(fd, "a+b", closefd=False)
@@ -128,7 +128,7 @@ class HIPAllocator(LRUAllocator):
minor_offset = 0 # only on the first
def _copyin(self, dest:T, src: memoryview):
hip_set_device(self.device.device)
host_mem = self._alloc_with_options(len(src), BufferOptions(host=True))
host_mem = self._alloc_with_options(len(src), BufferSpec(host=True))
self.device.pending_copyin.append(host_mem)
ctypes.memmove(host_mem, from_mv(src), len(src))
check(hip.hipMemcpyAsync(dest, host_mem, len(src), hip.hipMemcpyHostToDevice, None))

View File

@@ -3,7 +3,7 @@ import ctypes, functools, subprocess, io, atexit, collections, json
from typing import Tuple, TypeVar, List, Dict, Any
import tinygrad.runtime.autogen.hsa as hsa
from tinygrad.helpers import DEBUG, init_c_var, from_mv, round_up, to_mv, init_c_struct_t, getenv, PROFILE
from tinygrad.device import Compiled, Compiler, CompileError, BufferOptions, LRUAllocator
from tinygrad.device import Compiled, Compiler, CompileError, BufferSpec, LRUAllocator
from tinygrad.renderer.cstyle import HIPRenderer
from tinygrad.runtime.support.hsa import check, scan_agents, find_memory_pool, AQLQueue
from tinygrad.runtime.support.hip_comgr import compile_hip
@@ -102,7 +102,7 @@ class HSAAllocator(LRUAllocator):
self.device = device
super().__init__()
def _alloc(self, size:int, options:BufferOptions):
def _alloc(self, size:int, options:BufferSpec):
if options.host:
check(hsa.hsa_amd_memory_pool_allocate(HSADevice.cpu_mempool, size, 0, ctypes.byref(mem := ctypes.c_void_p())))
check(hsa.hsa_amd_agents_allow_access(2, (hsa.hsa_agent_t*2)(HSADevice.cpu_agent, self.device.agent), None, mem))
@@ -112,14 +112,14 @@ class HSAAllocator(LRUAllocator):
check(hsa.hsa_amd_agents_allow_access(len(HSADevice.agents[hsa.HSA_DEVICE_TYPE_GPU]), c_agents, None, buf))
return buf.value
def _free(self, opaque:T, options:BufferOptions):
def _free(self, opaque:T, options:BufferSpec):
HSADevice.synchronize_system()
check(hsa.hsa_amd_memory_pool_free(opaque))
def _copyin(self, dest:T, src: memoryview):
# Async copyin sync model uses barriers on the main hw queue, since barriers are guaranteed to execute in order with all other packets.
self.device.hw_queue.submit_barrier([], sync_signal := self.device.alloc_signal(reusable=True))
mem = self._alloc(src.nbytes, BufferOptions(host=True))
mem = self._alloc(src.nbytes, BufferSpec(host=True))
ctypes.memmove(mem, from_mv(src), src.nbytes)
check(hsa.hsa_amd_memory_async_copy_on_engine(dest, self.device.agent, mem, HSADevice.cpu_agent, src.nbytes, 1, ctypes.byref(sync_signal),
copy_signal := self.device.alloc_signal(reusable=True), hsa.HSA_AMD_SDMA_ENGINE_0, True))
@@ -131,7 +131,7 @@ class HSAAllocator(LRUAllocator):
self.device.hw_queue.submit_barrier([], sync_signal := self.device.alloc_signal(reusable=True))
if not hasattr(self, 'hb'):
self.hb = [self._alloc(CHUNK_SIZE, BufferOptions(host=True)) for _ in range(2)]
self.hb = [self._alloc(CHUNK_SIZE, BufferSpec(host=True)) for _ in range(2)]
self.hb_signals = [self.device.alloc_signal(reusable=False) for _ in range(2)]
self.hb_polarity = 0
self.sdma = [hsa.HSA_AMD_SDMA_ENGINE_0, hsa.HSA_AMD_SDMA_ENGINE_1]
@@ -256,7 +256,7 @@ class HSADevice(Compiled):
def _new_kernargs_region(self, sz:int):
if hasattr(self, 'kernarg_start_addr'): self.delayed_free.append(self.kernarg_start_addr)
self.kernarg_start_addr: int = self.allocator._alloc(sz, BufferOptions())
self.kernarg_start_addr: int = self.allocator._alloc(sz, BufferSpec())
self.kernarg_next_addr = self.kernarg_start_addr
self.kernarg_pool_sz: int = sz

View File

@@ -1,6 +1,6 @@
from typing import Tuple, Dict, List
from tinygrad.dtype import DType
from tinygrad.renderer import Program
from tinygrad.renderer import ProgramSpec
from tinygrad.tensor import Device, Tensor
from tinygrad.engine.jit import TinyJit
from tinygrad.nn.state import get_state_dict
@@ -24,7 +24,7 @@ web_utils = {
def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
for ji in run.jit_cache:
fxn: Program = ji.prg.p
fxn: ProgramSpec = ji.prg.p
functions[fxn.function_name] = fxn.src # NOTE: this assumes all with the same name are the same
cargs = []
for i,arg in enumerate(ji.bufs):

View File

@@ -4,7 +4,7 @@ import triton.language as tl
from triton.compiler import AttrsDescriptor, ASTSource, compile as triton_compile
import numpy as np
from tinygrad import Tensor, dtypes, Device
from tinygrad.engine.realize import CompiledRunner, ExecItem, Program
from tinygrad.engine.realize import CompiledRunner, ExecItem, ProgramSpec
from tinygrad.helpers import getenv
np.set_printoptions(suppress=True)
@@ -85,7 +85,7 @@ if __name__ == "__main__":
# remove debug sections
src = src.split("\t.file")[0]
assert '.extern .shared' not in src
prg = Program("matmul_kernel", src, device=Device.DEFAULT,
prg = ProgramSpec("matmul_kernel", src, device=Device.DEFAULT,
global_size=[M//BLOCK_SIZE_M, N//BLOCK_SIZE_N, 1], local_size=[32*compiled.metadata.num_warps, 1, 1],
mem_estimate=A.nbytes() + B.nbytes() + C.nbytes())
ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)

View File

@@ -1,7 +1,7 @@
import unittest, ctypes, struct, time, array
from tinygrad import Device, Tensor, dtypes
from tinygrad.helpers import to_mv, CI
from tinygrad.device import Buffer, BufferOptions
from tinygrad.device import Buffer, BufferSpec
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import get_runner
@@ -255,8 +255,8 @@ class TestHCQ(unittest.TestCase):
def test_copy_bandwidth(self):
# THEORY: the bandwidth is low here because it's only using one SDMA queue. I suspect it's more stable like this at least.
SZ = 2_000_000_000
a = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferOptions(nolru=True)).allocate()
b = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferOptions(nolru=True)).allocate()
a = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
b = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
q = TestHCQ.copy_queue()
q.copy(a._buf.va_addr, b._buf.va_addr, SZ)
et = _time_queue(q, TestHCQ.d0)
@@ -266,8 +266,8 @@ class TestHCQ(unittest.TestCase):
def test_cross_device_copy_bandwidth(self):
SZ = 2_000_000_000
b = Buffer(f"{Device.DEFAULT}:1", SZ, dtypes.uint8, options=BufferOptions(nolru=True)).allocate()
a = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferOptions(nolru=True)).allocate()
b = Buffer(f"{Device.DEFAULT}:1", SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
a = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
TestHCQ.d0._gpu_map(b._buf)
q = TestHCQ.copy_queue()
q.copy(a._buf.va_addr, b._buf.va_addr, SZ)

View File

@@ -1,7 +1,7 @@
import unittest, ctypes, struct
from tinygrad import Device, Tensor, dtypes
from tinygrad.helpers import CI, getenv
from tinygrad.device import Buffer, BufferOptions
from tinygrad.device import Buffer, BufferSpec
from tinygrad.runtime.support.hcq import HCQCompiled
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import get_runner, CompiledRunner
@@ -149,8 +149,8 @@ class TestHCQ(unittest.TestCase):
runner = CompiledRunner(k.to_program())
zb = Buffer(Device.DEFAULT, 3 * 3 * 3, dtypes.int, options=BufferOptions(cpu_access=True, nolru=True)).ensure_allocated()
zt = Buffer(Device.DEFAULT, 3 * 3 * 3, dtypes.int, options=BufferOptions(cpu_access=True, nolru=True)).ensure_allocated()
zb = Buffer(Device.DEFAULT, 3 * 3 * 3, dtypes.int, options=BufferSpec(cpu_access=True, nolru=True)).ensure_allocated()
zt = Buffer(Device.DEFAULT, 3 * 3 * 3, dtypes.int, options=BufferSpec(cpu_access=True, nolru=True)).ensure_allocated()
ctypes.memset(zb._buf.va_addr, 0, zb.nbytes)
kernargs = runner.clprg.fill_kernargs([zt._buf, zb._buf])
@@ -190,8 +190,8 @@ class TestHCQ(unittest.TestCase):
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
sz = 64 << 20
buf1 = Buffer(Device.DEFAULT, sz, dtypes.int8, options=BufferOptions(nolru=True)).ensure_allocated()
buf2 = Buffer(Device.DEFAULT, sz, dtypes.int8, options=BufferOptions(host=True, nolru=True)).ensure_allocated()
buf1 = Buffer(Device.DEFAULT, sz, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated()
buf2 = Buffer(Device.DEFAULT, sz, dtypes.int8, options=BufferSpec(host=True, nolru=True)).ensure_allocated()
ctypes.memset(buf2._buf.va_addr, 1, sz)
TestHCQ.d0.hw_copy_queue_t().wait(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value - 1) \
@@ -224,8 +224,8 @@ class TestHCQ(unittest.TestCase):
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
sz = 64 << 20
buf1 = Buffer(Device.DEFAULT, sz, dtypes.int8, options=BufferOptions(nolru=True)).ensure_allocated()
buf2 = Buffer(Device.DEFAULT, sz, dtypes.int8, options=BufferOptions(host=True, nolru=True)).ensure_allocated()
buf1 = Buffer(Device.DEFAULT, sz, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated()
buf2 = Buffer(Device.DEFAULT, sz, dtypes.int8, options=BufferSpec(host=True, nolru=True)).ensure_allocated()
ctypes.memset(buf2._buf.va_addr, 1, sz)
q = TestHCQ.d0.hw_copy_queue_t().wait(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value - 1) \
@@ -301,8 +301,8 @@ class TestHCQ(unittest.TestCase):
# THEORY: the bandwidth is low here because it's only using one SDMA queue. I suspect it's more stable like this at least.
SZ = 200_000_000
a = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferOptions(nolru=True)).allocate()
b = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferOptions(nolru=True)).allocate()
a = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
b = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
sig_st, sig_en = TestHCQ.d0.signal_t(), TestHCQ.d0.signal_t()
TestHCQ.d0.hw_copy_queue_t().timestamp(sig_st) \
@@ -329,8 +329,8 @@ class TestHCQ(unittest.TestCase):
TestHCQ.d0._prof_setup()
SZ = 200_000_000
b = Buffer(f"{Device.DEFAULT}:1", SZ, dtypes.uint8, options=BufferOptions(nolru=True)).allocate()
a = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferOptions(nolru=True)).allocate()
b = Buffer(f"{Device.DEFAULT}:1", SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
a = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
TestHCQ.d0._gpu_map(b._buf)
sig_st, sig_en = TestHCQ.d0.signal_t(), TestHCQ.d0.signal_t()
@@ -367,8 +367,8 @@ class TestHCQ(unittest.TestCase):
def test_small_copies_from_host_buf(self):
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
buf1 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferOptions(nolru=True)).ensure_allocated()
buf2 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferOptions(host=True, nolru=True)).ensure_allocated()
buf1 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated()
buf2 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferSpec(host=True, nolru=True)).ensure_allocated()
for i in range(256):
ctypes.memset(buf2._buf.va_addr, i, 1)
@@ -384,9 +384,9 @@ class TestHCQ(unittest.TestCase):
def test_small_copies_from_host_buf_intercopy(self):
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
buf1 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferOptions(nolru=True)).ensure_allocated()
buf2 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferOptions(nolru=True)).ensure_allocated()
buf3 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferOptions(host=True, nolru=True)).ensure_allocated()
buf1 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated()
buf2 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated()
buf3 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferSpec(host=True, nolru=True)).ensure_allocated()
for i in range(256):
ctypes.memset(buf3._buf.va_addr, i, 1)
@@ -406,9 +406,9 @@ class TestHCQ(unittest.TestCase):
try: _ = Device[f"{Device.DEFAULT}:1"]
except Exception: self.skipTest("no multidevice, test skipped")
buf1 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferOptions(nolru=True)).ensure_allocated()
buf2 = Buffer(f"{Device.DEFAULT}:1", 1, dtypes.int8, options=BufferOptions(nolru=True)).ensure_allocated()
buf3 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferOptions(host=True, nolru=True)).ensure_allocated()
buf1 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated()
buf2 = Buffer(f"{Device.DEFAULT}:1", 1, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated()
buf3 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferSpec(host=True, nolru=True)).ensure_allocated()
TestHCQ.d0.allocator.map(buf2._buf)
for i in range(256):
@@ -428,8 +428,8 @@ class TestHCQ(unittest.TestCase):
b = a + 1
runner = get_runner(TestHCQ.d0.device, create_schedule([b.lazydata])[-1].ast)
buf1 = Buffer(Device.DEFAULT, 2, dtypes.int8, options=BufferOptions(nolru=True)).ensure_allocated()
buf2 = Buffer(Device.DEFAULT, 2, dtypes.int8, options=BufferOptions(cpu_access=True, nolru=True)).ensure_allocated()
buf1 = Buffer(Device.DEFAULT, 2, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated()
buf2 = Buffer(Device.DEFAULT, 2, dtypes.int8, options=BufferSpec(cpu_access=True, nolru=True)).ensure_allocated()
kernargs_ptr = runner.clprg.fill_kernargs([buf1._buf, buf2._buf])
@@ -449,9 +449,9 @@ class TestHCQ(unittest.TestCase):
def test_memory_barrier_before_copy(self):
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
buf1 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferOptions(nolru=True)).ensure_allocated()
buf2 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferOptions(nolru=True)).ensure_allocated()
buf3 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferOptions(cpu_access=True, nolru=True)).ensure_allocated()
buf1 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated()
buf2 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated()
buf3 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferSpec(cpu_access=True, nolru=True)).ensure_allocated()
for i in range(256):
ctypes.memset(buf3._buf.va_addr, i, 1)

View File

@@ -1,7 +1,7 @@
import unittest, struct, contextlib, tempfile, pathlib, json, time, atexit, random
from tinygrad import Device, Tensor, dtypes, TinyJit
from tinygrad.helpers import CI, getenv, Context
from tinygrad.device import Buffer, BufferOptions
from tinygrad.device import Buffer, BufferSpec
from tinygrad.runtime.support.hcq import ProfileLogger, HCQCompiled
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import get_runner
@@ -94,7 +94,7 @@ class TestProfiler(unittest.TestCase):
helper_validate_node(kernel_node, profile=profile, pid_name=Device.DEFAULT, tid_name="COMPUTE")
def test_profile_copyin(self):
buf1 = Buffer(Device.DEFAULT, 2, dtypes.float, options=BufferOptions(nolru=True)).ensure_allocated()
buf1 = Buffer(Device.DEFAULT, 2, dtypes.float, options=BufferSpec(nolru=True)).ensure_allocated()
with helper_collect_profile(TestProfiler.d0) as profile:
buf1.copyin(memoryview(bytearray(struct.pack("ff", 0, 1))))
@@ -104,7 +104,7 @@ class TestProfiler(unittest.TestCase):
def test_profile_multiops(self):
runner_name = TestProfiler.runner.clprg.name
buf1 = Buffer(Device.DEFAULT, 2, dtypes.float, options=BufferOptions(nolru=True)).ensure_allocated()
buf1 = Buffer(Device.DEFAULT, 2, dtypes.float, options=BufferSpec(nolru=True)).ensure_allocated()
with helper_collect_profile(TestProfiler.d0) as profile:
buf1.copyin(memoryview(bytearray(struct.pack("ff", 0, 1))))
@@ -125,8 +125,8 @@ class TestProfiler(unittest.TestCase):
def test_profile_multidev_copyin(self):
d1 = Device[f"{Device.DEFAULT}:1"]
buf1 = Buffer(Device.DEFAULT, 2, dtypes.float, options=BufferOptions(nolru=True)).ensure_allocated()
buf2 = Buffer(f"{Device.DEFAULT}:1", 2, dtypes.float, options=BufferOptions(nolru=True)).ensure_allocated()
buf1 = Buffer(Device.DEFAULT, 2, dtypes.float, options=BufferSpec(nolru=True)).ensure_allocated()
buf2 = Buffer(f"{Device.DEFAULT}:1", 2, dtypes.float, options=BufferSpec(nolru=True)).ensure_allocated()
with helper_collect_profile(TestProfiler.d0, d1) as profile:
buf1.copyin(memoryview(bytearray(struct.pack("ff", 0, 1))))
@@ -197,7 +197,7 @@ class TestProfiler(unittest.TestCase):
expected_diff = 100000 # sleep in us
devs = [Device[f"{Device.DEFAULT}:{i}"] for i in range(6)]
bufs = [Buffer(f"{Device.DEFAULT}:{i}", 2, dtypes.float, options=BufferOptions(nolru=True)).ensure_allocated() for i in range(6)]
bufs = [Buffer(f"{Device.DEFAULT}:{i}", 2, dtypes.float, options=BufferSpec(nolru=True)).ensure_allocated() for i in range(6)]
# enqueue ops on different queues to check the timer sync
cpu_time = []

View File

@@ -9,7 +9,7 @@ from tinygrad.engine.realize import CompiledRunner
from tinygrad.helpers import dedup, flatten, prod
from tinygrad.renderer.cstyle import CStyleLanguage
from tinygrad.ops import UOp, Ops
from tinygrad.renderer import Program
from tinygrad.renderer import ProgramSpec
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.lazy import LazyBuffer
@@ -23,7 +23,7 @@ def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE]
inbufs = [cast(LazyBuffer,x.lazydata).base.buffer for x in inputs]
src = Device[Device.DEFAULT].renderer.render("test", uops)
ei = CompiledRunner(Program("test", src, Device.DEFAULT, uops=uops, local_size=local_size))
ei = CompiledRunner(ProgramSpec("test", src, Device.DEFAULT, uops=uops, local_size=local_size))
ei.exec(outbufs+inbufs)
return [np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in outbufs]

View File

@@ -7,7 +7,7 @@ from tinygrad.helpers import CI, DEBUG, getenv, Context
from tinygrad.dtype import dtypes, DType
from tinygrad.device import Buffer, Device
from tinygrad.ops import Ops, UOp, UPat, KernelInfo, exec_alu, spec # noqa F401
from tinygrad.renderer import Program
from tinygrad.renderer import ProgramSpec
from tinygrad.engine.schedule import create_schedule, to_si
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
from tinygrad.codegen.linearize import linearize_uop
@@ -20,7 +20,7 @@ def _uops_to_prg(uops_list):
uops = linearize_uop(full_graph_rewrite(UOp.sink(*uops_list), opts=Device[Device.DEFAULT].renderer))
src = Device[Device.DEFAULT].renderer.render("test", uops)
has_local = Device[Device.DEFAULT].renderer.has_local
return CompiledRunner(Program("test", src, Device.DEFAULT, uops=uops,
return CompiledRunner(ProgramSpec("test", src, Device.DEFAULT, uops=uops,
global_size=[1,1,1] if has_local else None, local_size=[1,1,1] if has_local else None))
def uop(uops:List[UOp], uop:Ops, dtype:Optional[DType], src:Tuple[UOp, ...], arg:Any=None) -> UOp:

View File

@@ -1,6 +1,6 @@
#!/usr/bin/env python
import unittest
from tinygrad.device import Device, BufferOptions
from tinygrad.device import Device, BufferSpec
from tinygrad.dtype import dtypes
@unittest.skipUnless(Device.DEFAULT == "QCOM", "QCOM device required to run")
@@ -9,7 +9,7 @@ class TestQcom(unittest.TestCase):
dev = Device["QCOM"]
def __validate(imgdt, expected_pitch):
img = dev.allocator.alloc(imgdt.shape[0] * imgdt.shape[1] * 16, options:=BufferOptions(image=imgdt))
img = dev.allocator.alloc(imgdt.shape[0] * imgdt.shape[1] * 16, options:=BufferSpec(image=imgdt))
pitch = (img.descriptor[2] & 0x1fffff80) >> 7
assert pitch == expected_pitch, f"Failed pitch for image: {imgdt}. Got 0x{pitch:X}, expected 0x{expected_pitch:X}"
dev.allocator.free(img, imgdt.shape[0] * imgdt.shape[1] * 16, options)

View File

@@ -8,7 +8,7 @@ from enum import Enum, auto
from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, PatternMatcher, can_pad, print_uops, type_verify, resolve, Variable, sint, \
graph_rewrite, track_rewrites, UPat
from tinygrad.device import Device
from tinygrad.renderer import Renderer, TensorCore, Program
from tinygrad.renderer import Renderer, TensorCore, ProgramSpec
from tinygrad.dtype import ImageDType
from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put, unwrap
from tinygrad.helpers import DEBUG, TC_OPT, USE_TC, AMX
@@ -702,7 +702,7 @@ class Kernel:
if DEBUG >= 5: print_uops(self.uops)
return self
def to_program(self, name_override:Optional[str]=None) -> Program:
def to_program(self, name_override:Optional[str]=None) -> ProgramSpec:
self.linearize()
src = self.opts.render(name:=to_function_name(ansiname:=(name_override if name_override is not None else self.name)), self.uops)
@@ -715,7 +715,7 @@ class Kernel:
mem_bytes = sum(max(x.src[0].dtype.itemsize * x.st_arg.real_size() for x in group)
for _, group in itertools.groupby([x for x in self.ast.parents if x.op in GroupOp.Buffer and x.src[0].op is Ops.DEFINE_GLOBAL],
key=lambda x: (x.op, x.src[0].arg)))
return Program(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
return ProgramSpec(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
# the living definition of intermediate UOps

View File

@@ -44,7 +44,8 @@ Device = _Device()
# **************** Buffer + Allocators ****************
@dataclass(frozen=True, eq=True)
class BufferOptions:
class BufferSpec:
# TODO: move device, size, dtype here?
image: Optional[ImageDType] = None
uncached: bool = False
cpu_access: bool = False
@@ -53,9 +54,9 @@ class BufferOptions:
external_ptr: Optional[int] = None
class Buffer:
def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None,
def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferSpec]=None,
initial_value:Optional[bytes]=None, lb_refcount=0, base:Optional[Buffer]=None, offset:int=0, preallocate=False):
if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype) # TODO: image hack shouldn't be here. where should it be?
if isinstance(dtype, ImageDType): options = BufferSpec(image=dtype) # TODO: image hack shouldn't be here. where should it be?
else: assert isinstance(dtype, DType) and not isinstance(dtype, PtrDType)
self.device, self.size, self.dtype, self.options, self.offset = device, size, dtype, options, offset
if base is None:
@@ -82,7 +83,7 @@ class Buffer:
assert not self.is_allocated(), "can't allocate already allocated buffer"
self.allocator = Device[self.device].allocator
if external_ptr is not None:
self.options = replace(self.options, external_ptr=external_ptr) if self.options else BufferOptions(external_ptr=external_ptr)
self.options = replace(self.options, external_ptr=external_ptr) if self.options else BufferSpec(external_ptr=external_ptr)
if self._base is not None:
self._base.ensure_allocated()
assert hasattr(self.allocator, "_offset"), "offset function required for view"
@@ -135,12 +136,15 @@ class Buffer:
# TODO: size, dest, src are the same type. can we enforce this?
class Allocator:
def alloc(self, size:int, options:Optional[BufferOptions]=None):
# overridden in LRUAllocator
def alloc(self, size:int, options:Optional[BufferSpec]=None):
assert not isinstance(size, int) or size > 0, f"alloc size must be positive, getting {size}"
return self._alloc(size, options if options is not None else BufferOptions())
def _alloc(self, size:int, options:BufferOptions): raise NotImplementedError("need alloc")
def free(self, opaque, size:int, options:Optional[BufferOptions]=None): self._free(opaque, options if options is not None else BufferOptions())
def _free(self, opaque, options:BufferOptions): pass # if opaque is a Python object, you don't need a free
return self._alloc(size, options if options is not None else BufferSpec())
def free(self, opaque, size:int, options:Optional[BufferSpec]=None): self._free(opaque, options if options is not None else BufferSpec())
# implemented by the runtime
def _alloc(self, size:int, options:BufferSpec): raise NotImplementedError("need alloc")
def _free(self, opaque, options:BufferSpec): pass # if opaque is a Python object, you don't need a free
def _copyin(self, dest, src:memoryview): raise NotImplementedError("need copyin")
def _copyout(self, dest:memoryview, src): raise NotImplementedError("need copyout")
# def _as_buffer(self, src) -> memoryview:
@@ -152,8 +156,8 @@ class LRUAllocator(Allocator): # pylint: disable=abstract-method
The LRU Allocator is responsible for caching buffers.
It ensures that buffers are not freed until it is absolutely necessary, optimizing performance.
"""
def __init__(self): self.cache: Dict[Tuple[int, Optional[BufferOptions]], Any] = defaultdict(list)
def alloc(self, size:int, options:Optional[BufferOptions]=None):
def __init__(self): self.cache: Dict[Tuple[int, Optional[BufferSpec]], Any] = defaultdict(list)
def alloc(self, size:int, options:Optional[BufferSpec]=None):
if len(c := self.cache[(size, options)]): return c.pop()
try: return super().alloc(size, options)
except (RuntimeError, MemoryError):
@@ -163,12 +167,12 @@ class LRUAllocator(Allocator): # pylint: disable=abstract-method
for (sz,options),opaques in self.cache.items():
for opaque in opaques: super().free(opaque, sz, options)
opaques.clear()
def free(self, opaque:Any, size:int, options:Optional[BufferOptions]=None):
def free(self, opaque:Any, size:int, options:Optional[BufferSpec]=None):
if getenv("LRU", 1) and (options is None or not options.nolru): self.cache[(size, options)].append(opaque)
else: super().free(opaque, size, options)
class _MallocAllocator(LRUAllocator):
def _alloc(self, size:int, options:BufferOptions):
def _alloc(self, size:int, options:BufferSpec):
return (ctypes.c_uint8 * size).from_address(options.external_ptr) if options.external_ptr else (ctypes.c_uint8 * size)()
def _as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
def _copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
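
Not part of the diff: a short sketch of why BufferSpec stays a frozen, eq=True dataclass. The LRUAllocator above keys its cache on (size, options), so two equal specs must hash and compare equal to land in the same bucket, and nolru=True opts a buffer out of that cache on free.

from tinygrad.device import BufferSpec

a = BufferSpec(cpu_access=True)
b = BufferSpec(cpu_access=True)
assert a == b and hash(a) == hash(b)   # same LRU cache bucket for a given size
assert a != BufferSpec(host=True)      # different spec, different bucket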

View File

@@ -5,7 +5,7 @@ from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BE
from tinygrad.ops import Ops, UOp, Variable, sym_infer, sint
from tinygrad.dtype import dtypes
from tinygrad.device import Device, Buffer
from tinygrad.renderer import Renderer, Program
from tinygrad.renderer import Renderer, ProgramSpec
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.schedule import ScheduleItem
@@ -75,9 +75,9 @@ class Runner:
raise NotImplementedError("override this")
class CompiledRunner(Runner):
def __init__(self, p:Program, precompiled:Optional[bytes]=None):
def __init__(self, p:ProgramSpec, precompiled:Optional[bytes]=None):
if DEBUG >= 4: print(p.src)
self.p:Program = p
self.p:ProgramSpec = p
self.lib:bytes = precompiled if precompiled is not None else Device[p.device].compiler.compile_cached(p.src)
if DEBUG >= 6: Device[p.device].compiler.disassemble(self.lib)
self.clprg = Device[p.device].runtime(p.function_name, self.lib)
@@ -148,7 +148,7 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
if bret:=method_cache.get(bkey):
method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib)
else:
prg: Program = get_kernel(Device[device].renderer, ast).to_program()
prg: ProgramSpec = get_kernel(Device[device].renderer, ast).to_program()
if getenv("FUZZ_UOPS"):
from test.external.fuzz_uops import UOpsFuzzerRunner
return UOpsFuzzerRunner(replace(prg, device=device))

View File

@@ -9,7 +9,7 @@ from tinygrad.dtype import ImageDType, PtrDType
from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
from tinygrad.tensor import Tensor
from tinygrad.engine.realize import CompiledRunner
from tinygrad.renderer import Program
from tinygrad.renderer import ProgramSpec
actions = [Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
actions += [Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4,7] for axis in range(5)]
@@ -33,7 +33,7 @@ def _get_test_global_size(global_size, max_global_size, var_vals):
break
return test_global_size, factor
def _time_program(p:Program, lib:bytes, var_vals:Dict[Variable, int], rawbufs:List[Buffer], early_stop:Optional[float]=None,
def _time_program(p:ProgramSpec, lib:bytes, var_vals:Dict[Variable, int], rawbufs:List[Buffer], early_stop:Optional[float]=None,
max_global_size:Optional[int]=65536, clear_l2=False, cnt=3, name="test") -> List[float]:
factor = 1
if p.global_size is not None and max_global_size is not None:
@@ -55,7 +55,7 @@ def _time_program(p:Program, lib:bytes, var_vals:Dict[Variable, int], rawbufs:Li
class TimeoutException(Exception): pass
def timeout_handler(signum, frame): raise TimeoutException()
def _try_compile_linearized_w_idx(x:Tuple[int,Kernel], compiler:Compiler) -> Tuple[int, Optional[Tuple[Program, bytes, float]]]:
def _try_compile_linearized_w_idx(x:Tuple[int,Kernel], compiler:Compiler) -> Tuple[int, Optional[Tuple[ProgramSpec, bytes, float]]]:
if hasattr(signal, "alarm"):
signal.signal(getattr(signal, 'SIGALRM'), timeout_handler)
# set timeout

View File

@@ -23,7 +23,7 @@ class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x
def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
@dataclass
class Program:
class ProgramSpec:
name:str
src:str
device:str

View File

@@ -2,7 +2,7 @@ import collections, time
from typing import List, Any, Dict, cast, Optional, Tuple, Set
from tinygrad.helpers import round_up, PROFILE, memsize_to_str
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState
from tinygrad.device import Buffer, BufferOptions, Compiled, Device
from tinygrad.device import Buffer, BufferSpec, Compiled, Device
from tinygrad.ops import Variable
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
from tinygrad.engine.jit import MultiGraphRunner
@@ -17,7 +17,7 @@ class HCQGraph(MultiGraphRunner):
for ji in self.jit_cache:
if not isinstance(ji.prg, CompiledRunner): continue
kernargs_size[ji.prg.dev] += round_up(ji.prg.clprg.kernargs_alloc_size, 16)
self.kernargs_bufs: Dict[Compiled, HCQBuffer] = {dev:dev.allocator._alloc(sz, BufferOptions(cpu_access=True)) for dev,sz in kernargs_size.items()}
self.kernargs_bufs: Dict[Compiled, HCQBuffer] = {dev:dev.allocator._alloc(sz, BufferSpec(cpu_access=True)) for dev,sz in kernargs_size.items()}
# Fill initial arguments.
self.ji_args: Dict[int, HCQArgsState] = {}
@@ -197,4 +197,4 @@ class HCQGraph(MultiGraphRunner):
if PROFILE and self.kickoff_value >= 1: self.collect_timestamps()
for fdev, buf in self.kernargs_bufs.items(): fdev.allocator._free(buf, BufferOptions(cpu_access=True))
for fdev, buf in self.kernargs_bufs.items(): fdev.allocator._free(buf, BufferSpec(cpu_access=True))

View File

@@ -4,7 +4,7 @@ import os, ctypes, ctypes.util, functools, pathlib, mmap, errno, time, array, co
assert sys.platform != 'win32'
from dataclasses import dataclass
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, HCQArgsState, HCQSignal, HCQProgram
from tinygrad.device import BufferOptions
from tinygrad.device import BufferSpec
from tinygrad.helpers import getenv, to_mv, round_up, data64_le, mv_address
from tinygrad.renderer.cstyle import AMDRenderer
from tinygrad.runtime.autogen import kfd, hsa, amd_gpu, libc
@@ -61,7 +61,7 @@ class AMDComputeQueue(HWQueue): # pylint: disable=abstract-method
def __del__(self):
if self.binded_device is not None:
self.binded_device.allocator.free(self.hw_page, self.hw_page.size, BufferOptions(cpu_access=True, nolru=True, uncached=True))
self.binded_device.allocator.free(self.hw_page, self.hw_page.size, BufferSpec(cpu_access=True, nolru=True, uncached=True))
def _acquire_mem(self, addr=0x0, sz=(1 << 64)-1, gli=1, glm=1, glk=1, glv=1, gl1=1, gl2=1):
self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_ACQUIRE_MEM, 6), 0, *data64_le(sz), *data64_le(addr), 0,
@@ -166,7 +166,7 @@ class AMDComputeQueue(HWQueue): # pylint: disable=abstract-method
def bind(self, dev:AMDDevice):
self.binded_device = dev
self.hw_page = dev.allocator.alloc(len(self.q) * 4, BufferOptions(cpu_access=True, nolru=True, uncached=True))
self.hw_page = dev.allocator.alloc(len(self.q) * 4, BufferSpec(cpu_access=True, nolru=True, uncached=True))
hw_view = to_mv(self.hw_page.va_addr, self.hw_page.size).cast("I")
for i, value in enumerate(self.q): hw_view[i] = value
@@ -274,7 +274,7 @@ class AMDProgram(HCQProgram):
self.dev: AMDDevice = dev
self.name, self.lib = name, lib
image, sections, _ = elf_loader(self.lib)
self.lib_gpu = self.dev.allocator.alloc(round_up(image.nbytes, 0x1000), BufferOptions(cpu_access=True, nolru=True))
self.lib_gpu = self.dev.allocator.alloc(round_up(image.nbytes, 0x1000), BufferSpec(cpu_access=True, nolru=True))
ctypes.memmove(self.lib_gpu.va_addr, mv_address(image), image.nbytes)
entry_point = min(sh.header.sh_addr for sh in sections if sh.header.sh_type == libc.SHT_PROGBITS and sh.header.sh_flags & libc.SHF_ALLOC)
@@ -303,17 +303,17 @@ class AMDProgram(HCQProgram):
super().__init__(AMDArgsState, self.dev, self.name, kernargs_alloc_size=self.kernargs_segment_size+additional_alloc_sz)
def __del__(self):
if hasattr(self, 'lib_gpu'): self.dev.allocator.free(self.lib_gpu, self.lib_gpu.size, BufferOptions(cpu_access=True, nolru=True))
if hasattr(self, 'lib_gpu'): self.dev.allocator.free(self.lib_gpu, self.lib_gpu.size, BufferSpec(cpu_access=True, nolru=True))
class AMDAllocator(HCQAllocator['AMDDevice']):
def __init__(self, dev:AMDDevice): super().__init__(dev, batch_size=SDMA_MAX_COPY_SIZE)
def _alloc(self, size:int, options:BufferOptions) -> HCQBuffer:
def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer:
if options.host: return self.dev._gpu_alloc(size, kfd.KFD_IOC_ALLOC_MEM_FLAGS_USERPTR, public=True)
if options.cpu_access and options.uncached: return self.dev._gpu_alloc(size, kfd.KFD_IOC_ALLOC_MEM_FLAGS_GTT, uncached=True)
return self.dev._gpu_alloc(size, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM, public=options.cpu_access)
def _free(self, opaque, options:BufferOptions):
def _free(self, opaque, options:BufferSpec):
self.dev.synchronize()
self.dev._gpu_free(opaque)

View File

@@ -13,14 +13,14 @@ from http.server import HTTPServer, BaseHTTPRequestHandler
from tinygrad.renderer import Renderer
from tinygrad.dtype import dtypes
from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, Timing
from tinygrad.device import Compiled, Allocator, Compiler, Device, BufferOptions
from tinygrad.device import Compiled, Allocator, Compiler, Device, BufferSpec
# ***** API *****
class CloudRequest: pass
@dataclass(frozen=True)
class BufferAlloc(CloudRequest): buffer_num: int; size: int; options: BufferOptions # noqa: E702
class BufferAlloc(CloudRequest): buffer_num: int; size: int; options: BufferSpec # noqa: E702
@dataclass(frozen=True)
class BufferFree(CloudRequest): buffer_num: int # noqa: E702
@@ -43,7 +43,7 @@ class ProgramExec(CloudRequest):
global_size: Optional[Tuple[int, ...]]; local_size: Optional[Tuple[int, ...]]; wait: bool # noqa: E702
# for safe deserialization
whitelist = {x.__name__:x for x in [BufferAlloc, BufferFree, CopyIn, CopyOut, ProgramAlloc, ProgramFree, ProgramExec, BufferOptions]}
whitelist = {x.__name__:x for x in [BufferAlloc, BufferFree, CopyIn, CopyOut, ProgramAlloc, ProgramFree, ProgramExec, BufferSpec]}
eval_fxns = {ast.Constant: lambda x: x.value, ast.Tuple: lambda x: tuple(map(safe_eval, x.elts)), ast.List: lambda x: list(map(safe_eval, x.elts)),
ast.Call: lambda x: safe_eval(x.func)(*[safe_eval(arg) for arg in x.args], **{kwarg.arg: safe_eval(kwarg.value) for kwarg in x.keywords}),
ast.Name: lambda x: whitelist[x.id], ast.Attribute: lambda x: {"imagef": dtypes.imagef, "imageh": dtypes.imageh}[x.attr]}
@@ -76,7 +76,7 @@ class BatchRequest:
class CloudSession:
programs: Dict[Tuple[str, str], Any] = field(default_factory=dict)
# TODO: the buffer should track this internally
buffers: Dict[int, Tuple[Any, int, Optional[BufferOptions]]] = field(default_factory=dict)
buffers: Dict[int, Tuple[Any, int, Optional[BufferSpec]]] = field(default_factory=dict)
class CloudHandler(BaseHTTPRequestHandler):
protocol_version = 'HTTP/1.1'
@@ -143,7 +143,7 @@ class CloudAllocator(Allocator):
self.device = dev
super().__init__()
# TODO: ideally we shouldn't have to deal with images here
def _alloc(self, size:int, options:BufferOptions) -> int:
def _alloc(self, size:int, options:BufferSpec) -> int:
self.device.buffer_num += 1
self.device.req.q(BufferAlloc(self.device.buffer_num, size, options))
return self.device.buffer_num
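
Not part of the diff: a hedged aside on the whitelist change above. The cloud protocol serializes requests with repr() and rebuilds them with safe_eval over a restricted AST, so the class name appearing in a repr'd request has to resolve through the whitelist; after the rename that name is BufferSpec. Plain eval with a restricted namespace stands in for safe_eval here.

from tinygrad.device import BufferSpec

spec = BufferSpec(cpu_access=True, nolru=True)
req_text = repr(spec)                                  # e.g. "BufferSpec(image=None, uncached=False, cpu_access=True, ...)"
rebuilt = eval(req_text, {"BufferSpec": BufferSpec})   # stand-in for safe_eval + whitelist lookup
assert rebuilt == spec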

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import ctypes, ctypes.util, functools
from typing import Tuple, Optional, List
from tinygrad.helpers import DEBUG, getenv, from_mv, init_c_var, init_c_struct_t
from tinygrad.device import Compiled, BufferOptions, LRUAllocator
from tinygrad.device import Compiled, BufferSpec, LRUAllocator
from tinygrad.renderer.cstyle import CUDARenderer
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.runtime.autogen import cuda
@@ -63,17 +63,17 @@ class CUDAAllocator(LRUAllocator):
def __init__(self, device:CUDADevice):
self.device = device
super().__init__()
def _alloc(self, size, options:BufferOptions):
def _alloc(self, size, options:BufferSpec):
check(cuda.cuCtxSetCurrent(self.device.context))
if options.host: return init_c_var(ctypes.c_void_p(), lambda x: check(cuda.cuMemHostAlloc(ctypes.byref(x), size, 0x01)))
return init_c_var(cuda.CUdeviceptr(), lambda x: check(cuda.cuMemAlloc_v2(ctypes.byref(x), size)))
def _free(self, opaque, options:BufferOptions):
def _free(self, opaque, options:BufferSpec):
if options.host: check(cuda.cuMemFreeHost(opaque))
else: check(cuda.cuMemFree_v2(opaque))
def _copyin(self, dest, src:memoryview):
check(cuda.cuCtxSetCurrent(self.device.context))
host_mem = self.alloc(len(src), BufferOptions(host=True))
self.device.pending_copyin.append((host_mem, len(src), BufferOptions(host=True)))
host_mem = self.alloc(len(src), BufferSpec(host=True))
self.device.pending_copyin.append((host_mem, len(src), BufferSpec(host=True)))
ctypes.memmove(host_mem, from_mv(src), len(src))
check(cuda.cuMemcpyHtoDAsync_v2(dest, host_mem, len(src), None))
def _copyout(self, dest:memoryview, src):
@@ -110,7 +110,7 @@ class CUDADevice(Compiled):
CUDADevice.peer_access = True
self.arch = f"sm_{major.value}{minor.value}"
self.pending_copyin: List[Tuple[int, int, Optional[BufferOptions]]] = []
self.pending_copyin: List[Tuple[int, int, Optional[BufferSpec]]] = []
CUDADevice.devices.append(self)
from tinygrad.runtime.graph.cuda import CUDAGraph

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Tuple, Any
import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys
assert sys.platform != 'win32'
from tinygrad.device import BufferOptions, Compiled, Allocator
from tinygrad.device import BufferSpec, Compiled, Allocator
from tinygrad.helpers import from_mv, getenv, round_up, mv_address, to_mv
from tinygrad.runtime.ops_clang import ClangCompiler
from tinygrad.renderer.cstyle import DSPRenderer
@@ -43,13 +43,13 @@ class DSPAllocator(Allocator):
self.dev = dev
super().__init__()
def _alloc(self, size:int, options:BufferOptions):
def _alloc(self, size:int, options:BufferSpec):
b = qcom_dsp.ION_IOC_ALLOC(self.dev.ion_fd, len=size, align=0x200, heap_id_mask=1<<qcom_dsp.ION_SYSTEM_HEAP_ID, flags=qcom_dsp.ION_FLAG_CACHED)
share_info = qcom_dsp.ION_IOC_SHARE(self.dev.ion_fd, handle=b.handle)
va_addr = libc.mmap(0, size, mmap.PROT_READ|mmap.PROT_WRITE, mmap.MAP_SHARED, share_info.fd, 0)
return DSPBuffer(va_addr, size, share_info, offset=0)
def _free(self, opaque:DSPBuffer, options:BufferOptions):
def _free(self, opaque:DSPBuffer, options:BufferSpec):
libc.munmap(opaque.va_addr, opaque.size)
os.close(opaque.share_info.fd)
qcom_dsp.ION_IOC_FREE(self.dev.ion_fd, handle=opaque.share_info.handle)
@@ -75,7 +75,7 @@ class DSPDevice(Compiled):
ClangCompiler("compile_dsp", args=compiler_args, objdump_tool='llvm-objdump'), functools.partial(DSPProgram, self))
fastrpc_shell = memoryview(bytearray(pathlib.Path('/dsp/cdsp/fastrpc_shell_3').read_bytes()))
self.shell_buf = self.allocator.alloc(round_up(fastrpc_shell.nbytes, 0x1000), BufferOptions(nolru=True))
self.shell_buf = self.allocator.alloc(round_up(fastrpc_shell.nbytes, 0x1000), BufferSpec(nolru=True))
ctypes.memmove(self.shell_buf.va_addr, mv_address(fastrpc_shell), fastrpc_shell.nbytes)
self.init_dsp()

View File

@@ -4,7 +4,7 @@ import ctypes, functools, hashlib
from tinygrad.runtime.autogen import opencl as cl
from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, OSX, DEBUG, getenv, mv_address
from tinygrad.renderer.cstyle import OpenCLRenderer, IntelRenderer
from tinygrad.device import BufferOptions, LRUAllocator, Compiled, Compiler, CompileError
from tinygrad.device import BufferSpec, LRUAllocator, Compiled, Compiler, CompileError
# see test/external/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something
OSX_TIMING_RATIO = (125/3) if OSX else 1.0
@@ -44,7 +44,7 @@ class CLProgram:
if hasattr(self, 'kernel'): check(cl.clReleaseKernel(self.kernel))
if hasattr(self, 'program'): check(cl.clReleaseProgram(self.program))
def __call__(self, *bufs:Tuple[ctypes._CData, BufferOptions], global_size:Tuple[int,int,int]=(1,1,1), local_size:Optional[Tuple[int,int,int]]=None, vals:Tuple[int, ...]=(), wait=False) -> Optional[float]: # noqa: E501
def __call__(self, *bufs:Tuple[ctypes._CData, BufferSpec], global_size:Tuple[int,int,int]=(1,1,1), local_size:Optional[Tuple[int,int,int]]=None, vals:Tuple[int, ...]=(), wait=False) -> Optional[float]: # noqa: E501
for i,(b,_) in enumerate(bufs): cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(b), ctypes.byref(b))
for i,v in enumerate(vals,start=len(bufs)): cl.clSetKernelArg(self.kernel, i, 4, ctypes.byref(ctypes.c_int32(v)))
if local_size is not None: global_size = cast(Tuple[int,int,int], tuple(int(g*l) for g,l in zip(global_size, local_size)))
@@ -62,14 +62,14 @@ class CLAllocator(LRUAllocator):
def __init__(self, dev:CLDevice):
self.dev = dev
super().__init__()
def _alloc(self, size:int, options:BufferOptions) -> Tuple[ctypes._CData, BufferOptions]:
def _alloc(self, size:int, options:BufferSpec) -> Tuple[ctypes._CData, BufferSpec]:
if options.image is not None:
return (checked(cl.clCreateImage2D(self.dev.context, cl.CL_MEM_READ_WRITE,
cl.cl_image_format(cl.CL_RGBA, {2: cl.CL_HALF_FLOAT, 4: cl.CL_FLOAT}[options.image.itemsize]),
options.image.shape[1], options.image.shape[0], 0, None, status := ctypes.c_int32()), status), options)
return (checked(cl.clCreateBuffer(self.dev.context, cl.CL_MEM_READ_WRITE, size, None, status := ctypes.c_int32()), status), options)
def _free(self, opaque:Tuple[ctypes._CData, BufferOptions], options:BufferOptions): check(cl.clReleaseMemObject(opaque[0]))
def _copyin(self, dest:Tuple[ctypes._CData, BufferOptions], src:memoryview):
def _free(self, opaque:Tuple[ctypes._CData, BufferSpec], options:BufferSpec): check(cl.clReleaseMemObject(opaque[0]))
def _copyin(self, dest:Tuple[ctypes._CData, BufferSpec], src:memoryview):
if dest[1].image is not None:
check(cl.clEnqueueWriteImage(self.dev.queue, dest[0], False, (ctypes.c_size_t * 3)(0,0,0),
(ctypes.c_size_t * 3)(dest[1].image.shape[1],dest[1].image.shape[0],1), 0, 0, from_mv(src), 0, None, None))
@@ -77,7 +77,7 @@ class CLAllocator(LRUAllocator):
if mv_address(src) % 16: src = memoryview(bytearray(src))
check(cl.clEnqueueWriteBuffer(self.dev.queue, dest[0], False, 0, len(src)*src.itemsize, from_mv(src), 0, None, None))
self.dev.pending_copyin.append(src) # NOTE: these can't be freed until the GPU actually executes this command
def _copyout(self, dest:memoryview, src:Tuple[ctypes._CData, BufferOptions]):
def _copyout(self, dest:memoryview, src:Tuple[ctypes._CData, BufferSpec]):
if src[1].image is not None:
check(cl.clEnqueueReadImage(self.dev.queue, src[0], False, (ctypes.c_size_t * 3)(0,0,0),
(ctypes.c_size_t * 3)(src[1].image.shape[1],src[1].image.shape[0],1), 0, 0, from_mv(dest), 0, None, None))

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import ctypes, functools
from typing import Tuple
from tinygrad.helpers import init_c_var, from_mv, init_c_struct_t, getenv
from tinygrad.device import Compiled, LRUAllocator, BufferOptions
from tinygrad.device import Compiled, LRUAllocator, BufferSpec
from tinygrad.runtime.autogen import hip
from tinygrad.runtime.support.compiler_hip import AMDCompiler
from tinygrad.renderer.cstyle import HIPRenderer
@@ -46,10 +46,10 @@ class HIPAllocator(LRUAllocator):
def __init__(self, dev:HIPDevice):
self.dev = dev
super().__init__()
def _alloc(self, size:int, options:BufferOptions):
def _alloc(self, size:int, options:BufferSpec):
check(hip.hipSetDevice(self.dev.device_id))
return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipMalloc(ctypes.byref(x), size)))
def _free(self, opaque, options:BufferOptions): check(hip.hipFree(opaque))
def _free(self, opaque, options:BufferSpec): check(hip.hipFree(opaque))
def _copyin(self, dest, src: memoryview):
check(hip.hipSetDevice(self.dev.device_id))
check(hip.hipMemcpy(dest, from_mv(src), len(src), hip.hipMemcpyHostToDevice))

View File

@@ -5,7 +5,7 @@ from typing import Tuple, List, Any, cast, Union, Dict, Type, Optional
from dataclasses import dataclass
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, hcq_command
from tinygrad.runtime.support.hcq import HCQArgsState, HCQProgram, HCQSignal
from tinygrad.device import BufferOptions
from tinygrad.device import BufferSpec
from tinygrad.helpers import getenv, mv_address, init_c_struct_t, to_mv, round_up, data64, data64_le, DEBUG, prod
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.cstyle import NVRenderer
@@ -85,7 +85,7 @@ class NVSignal(HCQSignal):
class NVCommandQueue(HWQueue[NVSignal, 'NVDevice', 'NVProgram', 'NVArgsState']): # pylint: disable=abstract-method
def __del__(self):
if self.binded_device is not None: self.binded_device.allocator.free(self.hw_page, self.hw_page.size, BufferOptions(cpu_access=True, nolru=True))
if self.binded_device is not None: self.binded_device.allocator.free(self.hw_page, self.hw_page.size, BufferSpec(cpu_access=True, nolru=True))
@hcq_command
def setup(self, compute_class=None, copy_class=None, local_mem_window=None, shared_mem_window=None, local_mem=None, local_mem_tpc_bytes=None):
@@ -108,7 +108,7 @@ class NVCommandQueue(HWQueue[NVSignal, 'NVDevice', 'NVProgram', 'NVArgsState']):
def bind(self, dev:NVDevice):
self.binded_device = dev
self.hw_page = dev.allocator.alloc(len(self.q) * 4, BufferOptions(cpu_access=True, nolru=True))
self.hw_page = dev.allocator.alloc(len(self.q) * 4, BufferSpec(cpu_access=True, nolru=True))
hw_view = to_mv(self.hw_page.va_addr, self.hw_page.size).cast("I")
for i, value in enumerate(self.q): hw_view[i] = value
@@ -228,7 +228,7 @@ class NVProgram(HCQProgram):
else: image, sections, relocs = elf_loader(self.lib, force_section_align=128)
# NOTE: Ensure at least 4KB of space after the program to mitigate prefetch memory faults.
self.lib_gpu = self.dev.allocator.alloc(round_up(image.nbytes, 0x1000) + 0x1000, BufferOptions(cpu_access=True))
self.lib_gpu = self.dev.allocator.alloc(round_up(image.nbytes, 0x1000) + 0x1000, BufferSpec(cpu_access=True))
self.prog_addr, self.prog_sz, self.regs_usage, self.shmem_usage, self.lcmem_usage = self.lib_gpu.va_addr, image.nbytes, 0, 0x400, 0
self.constbufs: Dict[int, Tuple[int, int]] = {0: (0, 0x160)} # Dict[constbuf index, Tuple[va_addr, size]]
@@ -281,7 +281,7 @@ class NVProgram(HCQProgram):
super().__init__(NVArgsState, self.dev, self.name, kernargs_alloc_size=round_up(self.constbufs[0][1], 1 << 8) + (8 << 8))
def __del__(self):
if hasattr(self, 'lib_gpu'): self.dev.allocator.free(self.lib_gpu, self.lib_gpu.size, BufferOptions(cpu_access=True))
if hasattr(self, 'lib_gpu'): self.dev.allocator.free(self.lib_gpu, self.lib_gpu.size, BufferSpec(cpu_access=True))
def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
if prod(local_size) > 1024 or self.max_threads < prod(local_size) or self.lcmem_usage > cast(NVDevice, self.dev).slm_per_thread:
@@ -291,11 +291,11 @@ class NVProgram(HCQProgram):
return super().__call__(*bufs, global_size=global_size, local_size=local_size, vals=vals, wait=wait)
class NVAllocator(HCQAllocator['NVDevice']):
def _alloc(self, size:int, options:BufferOptions) -> HCQBuffer:
def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer:
if options.host: return self.dev._gpu_host_alloc(size, tag="user host memory")
return self.dev._gpu_alloc(size, map_to_cpu=options.cpu_access, huge_page=(size > (16 << 20)), tag=f"user memory ({options})")
def _free(self, opaque, options:BufferOptions):
def _free(self, opaque, options:BufferSpec):
self.dev.synchronize()
self.dev._gpu_free(opaque)

View File

@@ -3,7 +3,7 @@ import os, ctypes, functools, mmap, struct, array, decimal, math, sys
assert sys.platform != 'win32'
from types import SimpleNamespace
from typing import Tuple, List, Any, cast, Optional
from tinygrad.device import BufferOptions
from tinygrad.device import BufferSpec
from tinygrad.runtime.support.hcq import HCQBuffer, HWQueue, HCQProgram, HCQCompiled, HCQSignal, HCQAllocator, HCQArgsState
from tinygrad.runtime.autogen import kgsl, adreno, libc
from tinygrad.runtime.ops_gpu import CLCompiler, CLDevice
@@ -50,7 +50,7 @@ class QCOMComputeQueue(HWQueue): # pylint: disable=abstract-method
super().__init__()
def __del__(self):
if self.binded_device is not None: self.binded_device.allocator.free(self.hw_page, self.hw_page.size, BufferOptions(cpu_access=True, nolru=True))
if self.binded_device is not None: self.binded_device.allocator.free(self.hw_page, self.hw_page.size, BufferSpec(cpu_access=True, nolru=True))
def cmd(self, opcode: int, *vals: int): self.q += [pkt7_hdr(opcode, len(vals)), *vals]
@@ -98,7 +98,7 @@ class QCOMComputeQueue(HWQueue): # pylint: disable=abstract-method
def bind(self, dev:QCOMDevice):
self.binded_device = dev
self.hw_page = dev.allocator.alloc(len(self.q) * 4, BufferOptions(cpu_access=True, nolru=True))
self.hw_page = dev.allocator.alloc(len(self.q) * 4, BufferSpec(cpu_access=True, nolru=True))
self.submit_req, self.obj = self._build_gpu_command(self.binded_device, self.hw_page.va_addr)
# From now on, the queue is on the device for faster submission.
self.q = to_mv(self.obj.gpuaddr, len(self.q) * 4).cast("I") # type: ignore
@@ -213,7 +213,7 @@ class QCOMProgram(HCQProgram):
self.name, self.lib = name, lib
self._parse_lib()
self.lib_gpu: HCQBuffer = self.dev.allocator.alloc(self.image_size, options=BufferOptions(cpu_access=True, nolru=True))
self.lib_gpu: HCQBuffer = self.dev.allocator.alloc(self.image_size, options=BufferSpec(cpu_access=True, nolru=True))
to_mv(self.lib_gpu.va_addr, self.image_size)[:] = self.image
self.pvtmem_size_per_item: int = round_up(self.pvtmem, 512) >> 9
@@ -283,7 +283,7 @@ class QCOMProgram(HCQProgram):
self.fregs, self.hregs = _read_lib(reg_desc_off + 0x14), _read_lib(reg_desc_off + 0x18)
def __del__(self):
if hasattr(self, 'lib_gpu'): self.dev.allocator.free(self.lib_gpu, self.lib_gpu.size, options=BufferOptions(cpu_access=True, nolru=True))
if hasattr(self, 'lib_gpu'): self.dev.allocator.free(self.lib_gpu, self.lib_gpu.size, options=BufferSpec(cpu_access=True, nolru=True))
class QCOMBuffer(HCQBuffer):
def __init__(self, va_addr:int, size:int, info=None, mapped=False, desc=None, ibo=None, pitch=None, real_stride=None, **kwargs):
@@ -293,7 +293,7 @@ class QCOMBuffer(HCQBuffer):
self.desc, self.ibo, self.pitch, self.real_stride = [0] * 16, [0] * 16, pitch, real_stride
class QCOMAllocator(HCQAllocator):
def _alloc(self, size:int, options:BufferOptions) -> HCQBuffer:
def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer:
if options.image is not None:
imgw, imgh, itemsize_log = options.image.shape[1], options.image.shape[0], int(math.log2(options.image.itemsize))
pitchalign = max(6, 11 - int(math.log2(imgh))) if imgh > 1 else 6
@@ -337,7 +337,7 @@ class QCOMAllocator(HCQAllocator):
self.dev.synchronize()
return to_mv(src.va_addr, src.size)
def _free(self, opaque, options:BufferOptions):
def _free(self, opaque, options:BufferSpec):
self.dev.synchronize()
self.dev._gpu_free(opaque)

View File

@@ -3,7 +3,7 @@ from typing import List, Optional, Dict, Tuple, cast, Protocol, Type, Union, Typ
import contextlib, decimal, statistics, random, json, atexit, time, array, ctypes, functools
from tinygrad.helpers import PROFILEPATH, PROFILE, from_mv, getenv
from tinygrad.renderer import Renderer
from tinygrad.device import BufferOptions, Compiler, Compiled, LRUAllocator
from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator
# **************** for HCQ Compatible Devices ****************
@@ -372,7 +372,7 @@ class HCQCompiled(Compiled, Generic[SignalType]):
from tinygrad.runtime.graph.hcq import HCQGraph
super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)
self.kernargs_page:HCQBuffer = self.allocator.alloc(16 << 20, BufferOptions(cpu_access=True))
self.kernargs_page:HCQBuffer = self.allocator.alloc(16 << 20, BufferSpec(cpu_access=True))
self.kernargs_ptr:int = self.kernargs_page.va_addr
self.devices.append(self)
@@ -483,11 +483,11 @@ class HCQAllocator(LRUAllocator, Generic[DeviceType]): # pylint: disable=abstrac
def __init__(self, dev:DeviceType, batch_size:int=(2 << 20), batch_cnt:int=32):
self.dev:DeviceType = dev
self.b = [self._alloc(batch_size, BufferOptions(host=True)) for _ in range(batch_cnt)]
self.b = [self._alloc(batch_size, BufferSpec(host=True)) for _ in range(batch_cnt)]
self.b_timeline, self.b_next = [0] * len(self.b), 0
super().__init__()
def _alloc(self, size:int, options:BufferOptions) -> HCQBuffer: raise NotImplementedError("need hcq compat alloc")
def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer: raise NotImplementedError("need hcq compat alloc")
def _copyin(self, dest:HCQBuffer, src:memoryview):
assert self.dev.hw_copy_queue_t is not None

View File

@@ -10,7 +10,7 @@ from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_u
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN
from tinygrad.multi import MultiLazyBuffer
from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait
from tinygrad.device import Device, Buffer, BufferOptions
from tinygrad.device import Device, Buffer, BufferSpec
from tinygrad.engine.lazy import LazyBuffer
from tinygrad.engine.realize import run_schedule
from tinygrad.engine.memory import memory_planner
@@ -259,7 +259,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
# NOTE: this realizes on the object from as_buffer being a Python object
cpu = self.cast(self.dtype.base).contiguous().to("CLANG").realize()
buf = cast(Buffer, cast(LazyBuffer, cpu.lazydata).base.realized)
if self.device != "CLANG": buf.options = BufferOptions(nolru=True)
if self.device != "CLANG": buf.options = BufferSpec(nolru=True)
return buf.as_buffer(allow_zero_copy=True if self.device != "CLANG" else False)
def data(self) -> memoryview: