mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
BufferSpec and ProgramSpec [pr] (#7814)
* BufferSpec and ProgramSpec [pr]
* delete preallocate, it's unused
* Revert "delete preallocate, it's unused"
This reverts commit dcfcfaccde.
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import ctypes, collections, time, itertools
|
||||
from typing import List, Any, Dict, cast, Optional, Tuple
|
||||
from tinygrad.helpers import init_c_var, round_up
|
||||
from tinygrad.device import Buffer, BufferOptions
|
||||
from tinygrad.device import Buffer, BufferSpec
|
||||
from tinygrad.device import Compiled, Device
|
||||
from tinygrad.ops import Variable
|
||||
from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
|
||||
@@ -44,7 +44,7 @@ class HSAGraph(MultiGraphRunner):
|
||||
kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
|
||||
for ji in self.jit_cache:
|
||||
if isinstance(ji.prg, CompiledRunner): kernargs_size[ji.prg.dev] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
|
||||
kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferOptions()) for dev,sz in kernargs_size.items()}
|
||||
kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferSpec()) for dev,sz in kernargs_size.items()}
|
||||
|
||||
# Fill initial arguments.
|
||||
self.ji_kargs_structs: Dict[int, ctypes.Structure] = {}
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Tuple, TypeVar, List, Any, cast, Set
|
||||
import tinygrad.runtime.autogen.hip as hip
|
||||
from tinygrad.helpers import DEBUG, getenv, init_c_var
|
||||
from tinygrad.helpers import from_mv, round_up, to_mv, colored, init_c_struct_t
|
||||
from tinygrad.device import Compiled, LRUAllocator, BufferOptions, Runner, Device, Buffer, MallocAllocator, update_stats, Compiler, CompilerOptions
|
||||
from tinygrad.device import Compiled, LRUAllocator, BufferSpec, Runner, Device, Buffer, MallocAllocator, update_stats, Compiler, CompilerOptions
|
||||
from tinygrad.renderer.cstyle import HIPRenderer
|
||||
from tinygrad.runtime.support.hip_comgr import compile_hip
|
||||
from tinygrad.renderer.rdna import uops_to_rdna
|
||||
@@ -93,7 +93,7 @@ class HIPAllocator(LRUAllocator):
|
||||
def _alloc(self, size:int):
|
||||
hip_set_device(self.device.device)
|
||||
return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipMalloc(ctypes.byref(x), size)))
|
||||
def _alloc_with_options(self, size:int, options:BufferOptions):
|
||||
def _alloc_with_options(self, size:int, options:BufferSpec):
|
||||
hip_set_device(self.device.device)
|
||||
if options.uncached:
|
||||
return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipExtMallocWithFlags(ctypes.byref(x), size, 3))) # hipDeviceMallocUncached = 3
|
||||
@@ -105,7 +105,7 @@ class HIPAllocator(LRUAllocator):
|
||||
def copy_from_fd(self, dest, fd, offset, size):
|
||||
hip_set_device(self.device.device)
|
||||
if not hasattr(self, 'hb'):
|
||||
self.hb = [self._alloc_with_options(CHUNK_SIZE, BufferOptions(host=True)) for _ in range(2)]
|
||||
self.hb = [self._alloc_with_options(CHUNK_SIZE, BufferSpec(host=True)) for _ in range(2)]
|
||||
self.hb_events = [None, None]
|
||||
self.hb_polarity = 0
|
||||
fo = io.FileIO(fd, "a+b", closefd=False)
|
||||
@@ -128,7 +128,7 @@ class HIPAllocator(LRUAllocator):
|
||||
minor_offset = 0 # only on the first
|
||||
def _copyin(self, dest:T, src: memoryview):
|
||||
hip_set_device(self.device.device)
|
||||
host_mem = self._alloc_with_options(len(src), BufferOptions(host=True))
|
||||
host_mem = self._alloc_with_options(len(src), BufferSpec(host=True))
|
||||
self.device.pending_copyin.append(host_mem)
|
||||
ctypes.memmove(host_mem, from_mv(src), len(src))
|
||||
check(hip.hipMemcpyAsync(dest, host_mem, len(src), hip.hipMemcpyHostToDevice, None))
|
||||
|
||||
@@ -3,7 +3,7 @@ import ctypes, functools, subprocess, io, atexit, collections, json
|
||||
from typing import Tuple, TypeVar, List, Dict, Any
|
||||
import tinygrad.runtime.autogen.hsa as hsa
|
||||
from tinygrad.helpers import DEBUG, init_c_var, from_mv, round_up, to_mv, init_c_struct_t, getenv, PROFILE
|
||||
from tinygrad.device import Compiled, Compiler, CompileError, BufferOptions, LRUAllocator
|
||||
from tinygrad.device import Compiled, Compiler, CompileError, BufferSpec, LRUAllocator
|
||||
from tinygrad.renderer.cstyle import HIPRenderer
|
||||
from tinygrad.runtime.support.hsa import check, scan_agents, find_memory_pool, AQLQueue
|
||||
from tinygrad.runtime.support.hip_comgr import compile_hip
|
||||
@@ -102,7 +102,7 @@ class HSAAllocator(LRUAllocator):
|
||||
self.device = device
|
||||
super().__init__()
|
||||
|
||||
def _alloc(self, size:int, options:BufferOptions):
|
||||
def _alloc(self, size:int, options:BufferSpec):
|
||||
if options.host:
|
||||
check(hsa.hsa_amd_memory_pool_allocate(HSADevice.cpu_mempool, size, 0, ctypes.byref(mem := ctypes.c_void_p())))
|
||||
check(hsa.hsa_amd_agents_allow_access(2, (hsa.hsa_agent_t*2)(HSADevice.cpu_agent, self.device.agent), None, mem))
|
||||
@@ -112,14 +112,14 @@ class HSAAllocator(LRUAllocator):
|
||||
check(hsa.hsa_amd_agents_allow_access(len(HSADevice.agents[hsa.HSA_DEVICE_TYPE_GPU]), c_agents, None, buf))
|
||||
return buf.value
|
||||
|
||||
def _free(self, opaque:T, options:BufferOptions):
|
||||
def _free(self, opaque:T, options:BufferSpec):
|
||||
HSADevice.synchronize_system()
|
||||
check(hsa.hsa_amd_memory_pool_free(opaque))
|
||||
|
||||
def _copyin(self, dest:T, src: memoryview):
|
||||
# Async copyin sync model uses barriers on the main hw queue, since barriers are guaranteed to execute in order with all other packets.
|
||||
self.device.hw_queue.submit_barrier([], sync_signal := self.device.alloc_signal(reusable=True))
|
||||
mem = self._alloc(src.nbytes, BufferOptions(host=True))
|
||||
mem = self._alloc(src.nbytes, BufferSpec(host=True))
|
||||
ctypes.memmove(mem, from_mv(src), src.nbytes)
|
||||
check(hsa.hsa_amd_memory_async_copy_on_engine(dest, self.device.agent, mem, HSADevice.cpu_agent, src.nbytes, 1, ctypes.byref(sync_signal),
|
||||
copy_signal := self.device.alloc_signal(reusable=True), hsa.HSA_AMD_SDMA_ENGINE_0, True))
|
||||
@@ -131,7 +131,7 @@ class HSAAllocator(LRUAllocator):
|
||||
self.device.hw_queue.submit_barrier([], sync_signal := self.device.alloc_signal(reusable=True))
|
||||
|
||||
if not hasattr(self, 'hb'):
|
||||
self.hb = [self._alloc(CHUNK_SIZE, BufferOptions(host=True)) for _ in range(2)]
|
||||
self.hb = [self._alloc(CHUNK_SIZE, BufferSpec(host=True)) for _ in range(2)]
|
||||
self.hb_signals = [self.device.alloc_signal(reusable=False) for _ in range(2)]
|
||||
self.hb_polarity = 0
|
||||
self.sdma = [hsa.HSA_AMD_SDMA_ENGINE_0, hsa.HSA_AMD_SDMA_ENGINE_1]
|
||||
@@ -256,7 +256,7 @@ class HSADevice(Compiled):
|
||||
|
||||
def _new_kernargs_region(self, sz:int):
|
||||
if hasattr(self, 'kernarg_start_addr'): self.delayed_free.append(self.kernarg_start_addr)
|
||||
self.kernarg_start_addr: int = self.allocator._alloc(sz, BufferOptions())
|
||||
self.kernarg_start_addr: int = self.allocator._alloc(sz, BufferSpec())
|
||||
self.kernarg_next_addr = self.kernarg_start_addr
|
||||
self.kernarg_pool_sz: int = sz
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Tuple, Dict, List
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.renderer import Program
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.tensor import Device, Tensor
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
from tinygrad.nn.state import get_state_dict
|
||||
@@ -24,7 +24,7 @@ web_utils = {
|
||||
def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
|
||||
functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
|
||||
for ji in run.jit_cache:
|
||||
fxn: Program = ji.prg.p
|
||||
fxn: ProgramSpec = ji.prg.p
|
||||
functions[fxn.function_name] = fxn.src # NOTE: this assumes all with the same name are the same
|
||||
cargs = []
|
||||
for i,arg in enumerate(ji.bufs):
|
||||
|
||||
@@ -4,7 +4,7 @@ import triton.language as tl
|
||||
from triton.compiler import AttrsDescriptor, ASTSource, compile as triton_compile
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, dtypes, Device
|
||||
from tinygrad.engine.realize import CompiledRunner, ExecItem, Program
|
||||
from tinygrad.engine.realize import CompiledRunner, ExecItem, ProgramSpec
|
||||
from tinygrad.helpers import getenv
|
||||
np.set_printoptions(suppress=True)
|
||||
|
||||
@@ -85,7 +85,7 @@ if __name__ == "__main__":
|
||||
# remove debug sections
|
||||
src = src.split("\t.file")[0]
|
||||
assert '.extern .shared' not in src
|
||||
prg = Program("matmul_kernel", src, device=Device.DEFAULT,
|
||||
prg = ProgramSpec("matmul_kernel", src, device=Device.DEFAULT,
|
||||
global_size=[M//BLOCK_SIZE_M, N//BLOCK_SIZE_N, 1], local_size=[32*compiled.metadata.num_warps, 1, 1],
|
||||
mem_estimate=A.nbytes() + B.nbytes() + C.nbytes())
|
||||
ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)
|
||||
|
||||
10
test/external/external_test_hcq.py
vendored
10
test/external/external_test_hcq.py
vendored
@@ -1,7 +1,7 @@
|
||||
import unittest, ctypes, struct, time, array
|
||||
from tinygrad import Device, Tensor, dtypes
|
||||
from tinygrad.helpers import to_mv, CI
|
||||
from tinygrad.device import Buffer, BufferOptions
|
||||
from tinygrad.device import Buffer, BufferSpec
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import get_runner
|
||||
|
||||
@@ -255,8 +255,8 @@ class TestHCQ(unittest.TestCase):
|
||||
def test_copy_bandwidth(self):
|
||||
# THEORY: the bandwidth is low here because it's only using one SDMA queue. I suspect it's more stable like this at least.
|
||||
SZ = 2_000_000_000
|
||||
a = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferOptions(nolru=True)).allocate()
|
||||
b = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferOptions(nolru=True)).allocate()
|
||||
a = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
|
||||
b = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
|
||||
q = TestHCQ.copy_queue()
|
||||
q.copy(a._buf.va_addr, b._buf.va_addr, SZ)
|
||||
et = _time_queue(q, TestHCQ.d0)
|
||||
@@ -266,8 +266,8 @@ class TestHCQ(unittest.TestCase):
|
||||
|
||||
def test_cross_device_copy_bandwidth(self):
|
||||
SZ = 2_000_000_000
|
||||
b = Buffer(f"{Device.DEFAULT}:1", SZ, dtypes.uint8, options=BufferOptions(nolru=True)).allocate()
|
||||
a = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferOptions(nolru=True)).allocate()
|
||||
b = Buffer(f"{Device.DEFAULT}:1", SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
|
||||
a = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
|
||||
TestHCQ.d0._gpu_map(b._buf)
|
||||
q = TestHCQ.copy_queue()
|
||||
q.copy(a._buf.va_addr, b._buf.va_addr, SZ)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import unittest, ctypes, struct
|
||||
from tinygrad import Device, Tensor, dtypes
|
||||
from tinygrad.helpers import CI, getenv
|
||||
from tinygrad.device import Buffer, BufferOptions
|
||||
from tinygrad.device import Buffer, BufferSpec
|
||||
from tinygrad.runtime.support.hcq import HCQCompiled
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import get_runner, CompiledRunner
|
||||
@@ -149,8 +149,8 @@ class TestHCQ(unittest.TestCase):
|
||||
|
||||
runner = CompiledRunner(k.to_program())
|
||||
|
||||
zb = Buffer(Device.DEFAULT, 3 * 3 * 3, dtypes.int, options=BufferOptions(cpu_access=True, nolru=True)).ensure_allocated()
|
||||
zt = Buffer(Device.DEFAULT, 3 * 3 * 3, dtypes.int, options=BufferOptions(cpu_access=True, nolru=True)).ensure_allocated()
|
||||
zb = Buffer(Device.DEFAULT, 3 * 3 * 3, dtypes.int, options=BufferSpec(cpu_access=True, nolru=True)).ensure_allocated()
|
||||
zt = Buffer(Device.DEFAULT, 3 * 3 * 3, dtypes.int, options=BufferSpec(cpu_access=True, nolru=True)).ensure_allocated()
|
||||
ctypes.memset(zb._buf.va_addr, 0, zb.nbytes)
|
||||
kernargs = runner.clprg.fill_kernargs([zt._buf, zb._buf])
|
||||
|
||||
@@ -190,8 +190,8 @@ class TestHCQ(unittest.TestCase):
|
||||
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
|
||||
|
||||
sz = 64 << 20
|
||||
buf1 = Buffer(Device.DEFAULT, sz, dtypes.int8, options=BufferOptions(nolru=True)).ensure_allocated()
|
||||
buf2 = Buffer(Device.DEFAULT, sz, dtypes.int8, options=BufferOptions(host=True, nolru=True)).ensure_allocated()
|
||||
buf1 = Buffer(Device.DEFAULT, sz, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated()
|
||||
buf2 = Buffer(Device.DEFAULT, sz, dtypes.int8, options=BufferSpec(host=True, nolru=True)).ensure_allocated()
|
||||
ctypes.memset(buf2._buf.va_addr, 1, sz)
|
||||
|
||||
TestHCQ.d0.hw_copy_queue_t().wait(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value - 1) \
|
||||
@@ -224,8 +224,8 @@ class TestHCQ(unittest.TestCase):
|
||||
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
|
||||
|
||||
sz = 64 << 20
|
||||
buf1 = Buffer(Device.DEFAULT, sz, dtypes.int8, options=BufferOptions(nolru=True)).ensure_allocated()
|
||||
buf2 = Buffer(Device.DEFAULT, sz, dtypes.int8, options=BufferOptions(host=True, nolru=True)).ensure_allocated()
|
||||
buf1 = Buffer(Device.DEFAULT, sz, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated()
|
||||
buf2 = Buffer(Device.DEFAULT, sz, dtypes.int8, options=BufferSpec(host=True, nolru=True)).ensure_allocated()
|
||||
ctypes.memset(buf2._buf.va_addr, 1, sz)
|
||||
|
||||
q = TestHCQ.d0.hw_copy_queue_t().wait(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value - 1) \
|
||||
@@ -301,8 +301,8 @@ class TestHCQ(unittest.TestCase):
|
||||
|
||||
# THEORY: the bandwidth is low here because it's only using one SDMA queue. I suspect it's more stable like this at least.
|
||||
SZ = 200_000_000
|
||||
a = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferOptions(nolru=True)).allocate()
|
||||
b = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferOptions(nolru=True)).allocate()
|
||||
a = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
|
||||
b = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
|
||||
|
||||
sig_st, sig_en = TestHCQ.d0.signal_t(), TestHCQ.d0.signal_t()
|
||||
TestHCQ.d0.hw_copy_queue_t().timestamp(sig_st) \
|
||||
@@ -329,8 +329,8 @@ class TestHCQ(unittest.TestCase):
|
||||
TestHCQ.d0._prof_setup()
|
||||
|
||||
SZ = 200_000_000
|
||||
b = Buffer(f"{Device.DEFAULT}:1", SZ, dtypes.uint8, options=BufferOptions(nolru=True)).allocate()
|
||||
a = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferOptions(nolru=True)).allocate()
|
||||
b = Buffer(f"{Device.DEFAULT}:1", SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
|
||||
a = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
|
||||
TestHCQ.d0._gpu_map(b._buf)
|
||||
|
||||
sig_st, sig_en = TestHCQ.d0.signal_t(), TestHCQ.d0.signal_t()
|
||||
@@ -367,8 +367,8 @@ class TestHCQ(unittest.TestCase):
|
||||
def test_small_copies_from_host_buf(self):
|
||||
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
|
||||
|
||||
buf1 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferOptions(nolru=True)).ensure_allocated()
|
||||
buf2 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferOptions(host=True, nolru=True)).ensure_allocated()
|
||||
buf1 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated()
|
||||
buf2 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferSpec(host=True, nolru=True)).ensure_allocated()
|
||||
|
||||
for i in range(256):
|
||||
ctypes.memset(buf2._buf.va_addr, i, 1)
|
||||
@@ -384,9 +384,9 @@ class TestHCQ(unittest.TestCase):
|
||||
def test_small_copies_from_host_buf_intercopy(self):
|
||||
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
|
||||
|
||||
buf1 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferOptions(nolru=True)).ensure_allocated()
|
||||
buf2 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferOptions(nolru=True)).ensure_allocated()
|
||||
buf3 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferOptions(host=True, nolru=True)).ensure_allocated()
|
||||
buf1 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated()
|
||||
buf2 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated()
|
||||
buf3 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferSpec(host=True, nolru=True)).ensure_allocated()
|
||||
|
||||
for i in range(256):
|
||||
ctypes.memset(buf3._buf.va_addr, i, 1)
|
||||
@@ -406,9 +406,9 @@ class TestHCQ(unittest.TestCase):
|
||||
try: _ = Device[f"{Device.DEFAULT}:1"]
|
||||
except Exception: self.skipTest("no multidevice, test skipped")
|
||||
|
||||
buf1 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferOptions(nolru=True)).ensure_allocated()
|
||||
buf2 = Buffer(f"{Device.DEFAULT}:1", 1, dtypes.int8, options=BufferOptions(nolru=True)).ensure_allocated()
|
||||
buf3 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferOptions(host=True, nolru=True)).ensure_allocated()
|
||||
buf1 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated()
|
||||
buf2 = Buffer(f"{Device.DEFAULT}:1", 1, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated()
|
||||
buf3 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferSpec(host=True, nolru=True)).ensure_allocated()
|
||||
TestHCQ.d0.allocator.map(buf2._buf)
|
||||
|
||||
for i in range(256):
|
||||
@@ -428,8 +428,8 @@ class TestHCQ(unittest.TestCase):
|
||||
b = a + 1
|
||||
runner = get_runner(TestHCQ.d0.device, create_schedule([b.lazydata])[-1].ast)
|
||||
|
||||
buf1 = Buffer(Device.DEFAULT, 2, dtypes.int8, options=BufferOptions(nolru=True)).ensure_allocated()
|
||||
buf2 = Buffer(Device.DEFAULT, 2, dtypes.int8, options=BufferOptions(cpu_access=True, nolru=True)).ensure_allocated()
|
||||
buf1 = Buffer(Device.DEFAULT, 2, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated()
|
||||
buf2 = Buffer(Device.DEFAULT, 2, dtypes.int8, options=BufferSpec(cpu_access=True, nolru=True)).ensure_allocated()
|
||||
|
||||
kernargs_ptr = runner.clprg.fill_kernargs([buf1._buf, buf2._buf])
|
||||
|
||||
@@ -449,9 +449,9 @@ class TestHCQ(unittest.TestCase):
|
||||
def test_memory_barrier_before_copy(self):
|
||||
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
|
||||
|
||||
buf1 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferOptions(nolru=True)).ensure_allocated()
|
||||
buf2 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferOptions(nolru=True)).ensure_allocated()
|
||||
buf3 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferOptions(cpu_access=True, nolru=True)).ensure_allocated()
|
||||
buf1 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated()
|
||||
buf2 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated()
|
||||
buf3 = Buffer(Device.DEFAULT, 1, dtypes.int8, options=BufferSpec(cpu_access=True, nolru=True)).ensure_allocated()
|
||||
|
||||
for i in range(256):
|
||||
ctypes.memset(buf3._buf.va_addr, i, 1)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import unittest, struct, contextlib, tempfile, pathlib, json, time, atexit, random
|
||||
from tinygrad import Device, Tensor, dtypes, TinyJit
|
||||
from tinygrad.helpers import CI, getenv, Context
|
||||
from tinygrad.device import Buffer, BufferOptions
|
||||
from tinygrad.device import Buffer, BufferSpec
|
||||
from tinygrad.runtime.support.hcq import ProfileLogger, HCQCompiled
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import get_runner
|
||||
@@ -94,7 +94,7 @@ class TestProfiler(unittest.TestCase):
|
||||
helper_validate_node(kernel_node, profile=profile, pid_name=Device.DEFAULT, tid_name="COMPUTE")
|
||||
|
||||
def test_profile_copyin(self):
|
||||
buf1 = Buffer(Device.DEFAULT, 2, dtypes.float, options=BufferOptions(nolru=True)).ensure_allocated()
|
||||
buf1 = Buffer(Device.DEFAULT, 2, dtypes.float, options=BufferSpec(nolru=True)).ensure_allocated()
|
||||
|
||||
with helper_collect_profile(TestProfiler.d0) as profile:
|
||||
buf1.copyin(memoryview(bytearray(struct.pack("ff", 0, 1))))
|
||||
@@ -104,7 +104,7 @@ class TestProfiler(unittest.TestCase):
|
||||
|
||||
def test_profile_multiops(self):
|
||||
runner_name = TestProfiler.runner.clprg.name
|
||||
buf1 = Buffer(Device.DEFAULT, 2, dtypes.float, options=BufferOptions(nolru=True)).ensure_allocated()
|
||||
buf1 = Buffer(Device.DEFAULT, 2, dtypes.float, options=BufferSpec(nolru=True)).ensure_allocated()
|
||||
|
||||
with helper_collect_profile(TestProfiler.d0) as profile:
|
||||
buf1.copyin(memoryview(bytearray(struct.pack("ff", 0, 1))))
|
||||
@@ -125,8 +125,8 @@ class TestProfiler(unittest.TestCase):
|
||||
|
||||
def test_profile_multidev_copyin(self):
|
||||
d1 = Device[f"{Device.DEFAULT}:1"]
|
||||
buf1 = Buffer(Device.DEFAULT, 2, dtypes.float, options=BufferOptions(nolru=True)).ensure_allocated()
|
||||
buf2 = Buffer(f"{Device.DEFAULT}:1", 2, dtypes.float, options=BufferOptions(nolru=True)).ensure_allocated()
|
||||
buf1 = Buffer(Device.DEFAULT, 2, dtypes.float, options=BufferSpec(nolru=True)).ensure_allocated()
|
||||
buf2 = Buffer(f"{Device.DEFAULT}:1", 2, dtypes.float, options=BufferSpec(nolru=True)).ensure_allocated()
|
||||
|
||||
with helper_collect_profile(TestProfiler.d0, d1) as profile:
|
||||
buf1.copyin(memoryview(bytearray(struct.pack("ff", 0, 1))))
|
||||
@@ -197,7 +197,7 @@ class TestProfiler(unittest.TestCase):
|
||||
expected_diff = 100000 # sleep in us
|
||||
|
||||
devs = [Device[f"{Device.DEFAULT}:{i}"] for i in range(6)]
|
||||
bufs = [Buffer(f"{Device.DEFAULT}:{i}", 2, dtypes.float, options=BufferOptions(nolru=True)).ensure_allocated() for i in range(6)]
|
||||
bufs = [Buffer(f"{Device.DEFAULT}:{i}", 2, dtypes.float, options=BufferSpec(nolru=True)).ensure_allocated() for i in range(6)]
|
||||
|
||||
# enqueue ops on different queues to check the timer sync
|
||||
cpu_time = []
|
||||
|
||||
@@ -9,7 +9,7 @@ from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.helpers import dedup, flatten, prod
|
||||
from tinygrad.renderer.cstyle import CStyleLanguage
|
||||
from tinygrad.ops import UOp, Ops
|
||||
from tinygrad.renderer import Program
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.engine.lazy import LazyBuffer
|
||||
|
||||
@@ -23,7 +23,7 @@ def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
|
||||
initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE]
|
||||
inbufs = [cast(LazyBuffer,x.lazydata).base.buffer for x in inputs]
|
||||
src = Device[Device.DEFAULT].renderer.render("test", uops)
|
||||
ei = CompiledRunner(Program("test", src, Device.DEFAULT, uops=uops, local_size=local_size))
|
||||
ei = CompiledRunner(ProgramSpec("test", src, Device.DEFAULT, uops=uops, local_size=local_size))
|
||||
ei.exec(outbufs+inbufs)
|
||||
return [np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in outbufs]
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from tinygrad.helpers import CI, DEBUG, getenv, Context
|
||||
from tinygrad.dtype import dtypes, DType
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.ops import Ops, UOp, UPat, KernelInfo, exec_alu, spec # noqa F401
|
||||
from tinygrad.renderer import Program
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.engine.schedule import create_schedule, to_si
|
||||
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
|
||||
from tinygrad.codegen.linearize import linearize_uop
|
||||
@@ -20,7 +20,7 @@ def _uops_to_prg(uops_list):
|
||||
uops = linearize_uop(full_graph_rewrite(UOp.sink(*uops_list), opts=Device[Device.DEFAULT].renderer))
|
||||
src = Device[Device.DEFAULT].renderer.render("test", uops)
|
||||
has_local = Device[Device.DEFAULT].renderer.has_local
|
||||
return CompiledRunner(Program("test", src, Device.DEFAULT, uops=uops,
|
||||
return CompiledRunner(ProgramSpec("test", src, Device.DEFAULT, uops=uops,
|
||||
global_size=[1,1,1] if has_local else None, local_size=[1,1,1] if has_local else None))
|
||||
|
||||
def uop(uops:List[UOp], uop:Ops, dtype:Optional[DType], src:Tuple[UOp, ...], arg:Any=None) -> UOp:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
import unittest
|
||||
from tinygrad.device import Device, BufferOptions
|
||||
from tinygrad.device import Device, BufferSpec
|
||||
from tinygrad.dtype import dtypes
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "QCOM", "QCOM device required to run")
|
||||
@@ -9,7 +9,7 @@ class TestQcom(unittest.TestCase):
|
||||
dev = Device["QCOM"]
|
||||
|
||||
def __validate(imgdt, expected_pitch):
|
||||
img = dev.allocator.alloc(imgdt.shape[0] * imgdt.shape[1] * 16, options:=BufferOptions(image=imgdt))
|
||||
img = dev.allocator.alloc(imgdt.shape[0] * imgdt.shape[1] * 16, options:=BufferSpec(image=imgdt))
|
||||
pitch = (img.descriptor[2] & 0x1fffff80) >> 7
|
||||
assert pitch == expected_pitch, f"Failed pitch for image: {imgdt}. Got 0x{pitch:X}, expected 0x{expected_pitch:X}"
|
||||
dev.allocator.free(img, imgdt.shape[0] * imgdt.shape[1] * 16, options)
|
||||
|
||||
@@ -8,7 +8,7 @@ from enum import Enum, auto
|
||||
from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, PatternMatcher, can_pad, print_uops, type_verify, resolve, Variable, sint, \
|
||||
graph_rewrite, track_rewrites, UPat
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.renderer import Renderer, TensorCore, Program
|
||||
from tinygrad.renderer import Renderer, TensorCore, ProgramSpec
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put, unwrap
|
||||
from tinygrad.helpers import DEBUG, TC_OPT, USE_TC, AMX
|
||||
@@ -702,7 +702,7 @@ class Kernel:
|
||||
if DEBUG >= 5: print_uops(self.uops)
|
||||
return self
|
||||
|
||||
def to_program(self, name_override:Optional[str]=None) -> Program:
|
||||
def to_program(self, name_override:Optional[str]=None) -> ProgramSpec:
|
||||
self.linearize()
|
||||
src = self.opts.render(name:=to_function_name(ansiname:=(name_override if name_override is not None else self.name)), self.uops)
|
||||
|
||||
@@ -715,7 +715,7 @@ class Kernel:
|
||||
mem_bytes = sum(max(x.src[0].dtype.itemsize * x.st_arg.real_size() for x in group)
|
||||
for _, group in itertools.groupby([x for x in self.ast.parents if x.op in GroupOp.Buffer and x.src[0].op is Ops.DEFINE_GLOBAL],
|
||||
key=lambda x: (x.op, x.src[0].arg)))
|
||||
return Program(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
|
||||
return ProgramSpec(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
|
||||
global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
|
||||
|
||||
# the living definition of intermediate UOps
|
||||
|
||||
@@ -44,7 +44,8 @@ Device = _Device()
|
||||
# **************** Buffer + Allocators ****************
|
||||
|
||||
@dataclass(frozen=True, eq=True)
|
||||
class BufferOptions:
|
||||
class BufferSpec:
|
||||
# TODO: move device, size, dtype here?
|
||||
image: Optional[ImageDType] = None
|
||||
uncached: bool = False
|
||||
cpu_access: bool = False
|
||||
@@ -53,9 +54,9 @@ class BufferOptions:
|
||||
external_ptr: Optional[int] = None
|
||||
|
||||
class Buffer:
|
||||
def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None,
|
||||
def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferSpec]=None,
|
||||
initial_value:Optional[bytes]=None, lb_refcount=0, base:Optional[Buffer]=None, offset:int=0, preallocate=False):
|
||||
if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype) # TODO: image hack shouldn't be here. where should it be?
|
||||
if isinstance(dtype, ImageDType): options = BufferSpec(image=dtype) # TODO: image hack shouldn't be here. where should it be?
|
||||
else: assert isinstance(dtype, DType) and not isinstance(dtype, PtrDType)
|
||||
self.device, self.size, self.dtype, self.options, self.offset = device, size, dtype, options, offset
|
||||
if base is None:
|
||||
@@ -82,7 +83,7 @@ class Buffer:
|
||||
assert not self.is_allocated(), "can't allocate already allocated buffer"
|
||||
self.allocator = Device[self.device].allocator
|
||||
if external_ptr is not None:
|
||||
self.options = replace(self.options, external_ptr=external_ptr) if self.options else BufferOptions(external_ptr=external_ptr)
|
||||
self.options = replace(self.options, external_ptr=external_ptr) if self.options else BufferSpec(external_ptr=external_ptr)
|
||||
if self._base is not None:
|
||||
self._base.ensure_allocated()
|
||||
assert hasattr(self.allocator, "_offset"), "offset function required for view"
|
||||
@@ -135,12 +136,15 @@ class Buffer:
|
||||
|
||||
# TODO: size, dest, src are the same type. can we enforce this?
|
||||
class Allocator:
|
||||
def alloc(self, size:int, options:Optional[BufferOptions]=None):
|
||||
# overriden in LRUAllocator
|
||||
def alloc(self, size:int, options:Optional[BufferSpec]=None):
|
||||
assert not isinstance(size, int) or size > 0, f"alloc size must be positve, getting {size}"
|
||||
return self._alloc(size, options if options is not None else BufferOptions())
|
||||
def _alloc(self, size:int, options:BufferOptions): raise NotImplementedError("need alloc")
|
||||
def free(self, opaque, size:int, options:Optional[BufferOptions]=None): self._free(opaque, options if options is not None else BufferOptions())
|
||||
def _free(self, opaque, options:BufferOptions): pass # if opaque is a Python object, you don't need a free
|
||||
return self._alloc(size, options if options is not None else BufferSpec())
|
||||
def free(self, opaque, size:int, options:Optional[BufferSpec]=None): self._free(opaque, options if options is not None else BufferSpec())
|
||||
|
||||
# implemented by the runtime
|
||||
def _alloc(self, size:int, options:BufferSpec): raise NotImplementedError("need alloc")
|
||||
def _free(self, opaque, options:BufferSpec): pass # if opaque is a Python object, you don't need a free
|
||||
def _copyin(self, dest, src:memoryview): raise NotImplementedError("need copyin")
|
||||
def _copyout(self, dest:memoryview, src): raise NotImplementedError("need copyout")
|
||||
# def _as_buffer(self, src) -> memoryview:
|
||||
@@ -152,8 +156,8 @@ class LRUAllocator(Allocator): # pylint: disable=abstract-method
|
||||
The LRU Allocator is responsible for caching buffers.
|
||||
It ensures that buffers are not freed until it is absolutely necessary, optimizing performance.
|
||||
"""
|
||||
def __init__(self): self.cache: Dict[Tuple[int, Optional[BufferOptions]], Any] = defaultdict(list)
|
||||
def alloc(self, size:int, options:Optional[BufferOptions]=None):
|
||||
def __init__(self): self.cache: Dict[Tuple[int, Optional[BufferSpec]], Any] = defaultdict(list)
|
||||
def alloc(self, size:int, options:Optional[BufferSpec]=None):
|
||||
if len(c := self.cache[(size, options)]): return c.pop()
|
||||
try: return super().alloc(size, options)
|
||||
except (RuntimeError, MemoryError):
|
||||
@@ -163,12 +167,12 @@ class LRUAllocator(Allocator): # pylint: disable=abstract-method
|
||||
for (sz,options),opaques in self.cache.items():
|
||||
for opaque in opaques: super().free(opaque, sz, options)
|
||||
opaques.clear()
|
||||
def free(self, opaque:Any, size:int, options:Optional[BufferOptions]=None):
|
||||
def free(self, opaque:Any, size:int, options:Optional[BufferSpec]=None):
|
||||
if getenv("LRU", 1) and (options is None or not options.nolru): self.cache[(size, options)].append(opaque)
|
||||
else: super().free(opaque, size, options)
|
||||
|
||||
class _MallocAllocator(LRUAllocator):
|
||||
def _alloc(self, size:int, options:BufferOptions):
|
||||
def _alloc(self, size:int, options:BufferSpec):
|
||||
return (ctypes.c_uint8 * size).from_address(options.external_ptr) if options.external_ptr else (ctypes.c_uint8 * size)()
|
||||
def _as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
|
||||
def _copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
|
||||
|
||||
@@ -5,7 +5,7 @@ from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BE
|
||||
from tinygrad.ops import Ops, UOp, Variable, sym_infer, sint
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.renderer import Renderer, Program
|
||||
from tinygrad.renderer import Renderer, ProgramSpec
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.engine.schedule import ScheduleItem
|
||||
|
||||
@@ -75,9 +75,9 @@ class Runner:
|
||||
raise NotImplementedError("override this")
|
||||
|
||||
class CompiledRunner(Runner):
|
||||
def __init__(self, p:Program, precompiled:Optional[bytes]=None):
|
||||
def __init__(self, p:ProgramSpec, precompiled:Optional[bytes]=None):
|
||||
if DEBUG >= 4: print(p.src)
|
||||
self.p:Program = p
|
||||
self.p:ProgramSpec = p
|
||||
self.lib:bytes = precompiled if precompiled is not None else Device[p.device].compiler.compile_cached(p.src)
|
||||
if DEBUG >= 6: Device[p.device].compiler.disassemble(self.lib)
|
||||
self.clprg = Device[p.device].runtime(p.function_name, self.lib)
|
||||
@@ -148,7 +148,7 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
|
||||
if bret:=method_cache.get(bkey):
|
||||
method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib)
|
||||
else:
|
||||
prg: Program = get_kernel(Device[device].renderer, ast).to_program()
|
||||
prg: ProgramSpec = get_kernel(Device[device].renderer, ast).to_program()
|
||||
if getenv("FUZZ_UOPS"):
|
||||
from test.external.fuzz_uops import UOpsFuzzerRunner
|
||||
return UOpsFuzzerRunner(replace(prg, device=device))
|
||||
|
||||
@@ -9,7 +9,7 @@ from tinygrad.dtype import ImageDType, PtrDType
|
||||
from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.renderer import Program
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
|
||||
actions = [Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
|
||||
actions += [Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4,7] for axis in range(5)]
|
||||
@@ -33,7 +33,7 @@ def _get_test_global_size(global_size, max_global_size, var_vals):
|
||||
break
|
||||
return test_global_size, factor
|
||||
|
||||
def _time_program(p:Program, lib:bytes, var_vals:Dict[Variable, int], rawbufs:List[Buffer], early_stop:Optional[float]=None,
|
||||
def _time_program(p:ProgramSpec, lib:bytes, var_vals:Dict[Variable, int], rawbufs:List[Buffer], early_stop:Optional[float]=None,
|
||||
max_global_size:Optional[int]=65536, clear_l2=False, cnt=3, name="test") -> List[float]:
|
||||
factor = 1
|
||||
if p.global_size is not None and max_global_size is not None:
|
||||
@@ -55,7 +55,7 @@ def _time_program(p:Program, lib:bytes, var_vals:Dict[Variable, int], rawbufs:Li
|
||||
class TimeoutException(Exception): pass
|
||||
def timeout_handler(signum, frame): raise TimeoutException()
|
||||
|
||||
def _try_compile_linearized_w_idx(x:Tuple[int,Kernel], compiler:Compiler) -> Tuple[int, Optional[Tuple[Program, bytes, float]]]:
|
||||
def _try_compile_linearized_w_idx(x:Tuple[int,Kernel], compiler:Compiler) -> Tuple[int, Optional[Tuple[ProgramSpec, bytes, float]]]:
|
||||
if hasattr(signal, "alarm"):
|
||||
signal.signal(getattr(signal, 'SIGALRM'), timeout_handler)
|
||||
# set timeout
|
||||
|
||||
@@ -23,7 +23,7 @@ class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x
|
||||
def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
|
||||
|
||||
@dataclass
|
||||
class Program:
|
||||
class ProgramSpec:
|
||||
name:str
|
||||
src:str
|
||||
device:str
|
||||
|
||||
@@ -2,7 +2,7 @@ import collections, time
|
||||
from typing import List, Any, Dict, cast, Optional, Tuple, Set
|
||||
from tinygrad.helpers import round_up, PROFILE, memsize_to_str
|
||||
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState
|
||||
from tinygrad.device import Buffer, BufferOptions, Compiled, Device
|
||||
from tinygrad.device import Buffer, BufferSpec, Compiled, Device
|
||||
from tinygrad.ops import Variable
|
||||
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
|
||||
from tinygrad.engine.jit import MultiGraphRunner
|
||||
@@ -17,7 +17,7 @@ class HCQGraph(MultiGraphRunner):
|
||||
for ji in self.jit_cache:
|
||||
if not isinstance(ji.prg, CompiledRunner): continue
|
||||
kernargs_size[ji.prg.dev] += round_up(ji.prg.clprg.kernargs_alloc_size, 16)
|
||||
self.kernargs_bufs: Dict[Compiled, HCQBuffer] = {dev:dev.allocator._alloc(sz, BufferOptions(cpu_access=True)) for dev,sz in kernargs_size.items()}
|
||||
self.kernargs_bufs: Dict[Compiled, HCQBuffer] = {dev:dev.allocator._alloc(sz, BufferSpec(cpu_access=True)) for dev,sz in kernargs_size.items()}
|
||||
|
||||
# Fill initial arguments.
|
||||
self.ji_args: Dict[int, HCQArgsState] = {}
|
||||
@@ -197,4 +197,4 @@ class HCQGraph(MultiGraphRunner):
|
||||
|
||||
if PROFILE and self.kickoff_value >= 1: self.collect_timestamps()
|
||||
|
||||
for fdev, buf in self.kernargs_bufs.items(): fdev.allocator._free(buf, BufferOptions(cpu_access=True))
|
||||
for fdev, buf in self.kernargs_bufs.items(): fdev.allocator._free(buf, BufferSpec(cpu_access=True))
|
||||
|
||||
@@ -4,7 +4,7 @@ import os, ctypes, ctypes.util, functools, pathlib, mmap, errno, time, array, co
|
||||
assert sys.platform != 'win32'
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, HCQArgsState, HCQSignal, HCQProgram
|
||||
from tinygrad.device import BufferOptions
|
||||
from tinygrad.device import BufferSpec
|
||||
from tinygrad.helpers import getenv, to_mv, round_up, data64_le, mv_address
|
||||
from tinygrad.renderer.cstyle import AMDRenderer
|
||||
from tinygrad.runtime.autogen import kfd, hsa, amd_gpu, libc
|
||||
@@ -61,7 +61,7 @@ class AMDComputeQueue(HWQueue): # pylint: disable=abstract-method
|
||||
|
||||
def __del__(self):
|
||||
if self.binded_device is not None:
|
||||
self.binded_device.allocator.free(self.hw_page, self.hw_page.size, BufferOptions(cpu_access=True, nolru=True, uncached=True))
|
||||
self.binded_device.allocator.free(self.hw_page, self.hw_page.size, BufferSpec(cpu_access=True, nolru=True, uncached=True))
|
||||
|
||||
def _acquire_mem(self, addr=0x0, sz=(1 << 64)-1, gli=1, glm=1, glk=1, glv=1, gl1=1, gl2=1):
|
||||
self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_ACQUIRE_MEM, 6), 0, *data64_le(sz), *data64_le(addr), 0,
|
||||
@@ -166,7 +166,7 @@ class AMDComputeQueue(HWQueue): # pylint: disable=abstract-method
|
||||
|
||||
def bind(self, dev:AMDDevice):
|
||||
self.binded_device = dev
|
||||
self.hw_page = dev.allocator.alloc(len(self.q) * 4, BufferOptions(cpu_access=True, nolru=True, uncached=True))
|
||||
self.hw_page = dev.allocator.alloc(len(self.q) * 4, BufferSpec(cpu_access=True, nolru=True, uncached=True))
|
||||
hw_view = to_mv(self.hw_page.va_addr, self.hw_page.size).cast("I")
|
||||
for i, value in enumerate(self.q): hw_view[i] = value
|
||||
|
||||
@@ -274,7 +274,7 @@ class AMDProgram(HCQProgram):
|
||||
self.dev: AMDDevice = dev
|
||||
self.name, self.lib = name, lib
|
||||
image, sections, _ = elf_loader(self.lib)
|
||||
self.lib_gpu = self.dev.allocator.alloc(round_up(image.nbytes, 0x1000), BufferOptions(cpu_access=True, nolru=True))
|
||||
self.lib_gpu = self.dev.allocator.alloc(round_up(image.nbytes, 0x1000), BufferSpec(cpu_access=True, nolru=True))
|
||||
ctypes.memmove(self.lib_gpu.va_addr, mv_address(image), image.nbytes)
|
||||
|
||||
entry_point = min(sh.header.sh_addr for sh in sections if sh.header.sh_type == libc.SHT_PROGBITS and sh.header.sh_flags & libc.SHF_ALLOC)
|
||||
@@ -303,17 +303,17 @@ class AMDProgram(HCQProgram):
|
||||
super().__init__(AMDArgsState, self.dev, self.name, kernargs_alloc_size=self.kernargs_segment_size+additional_alloc_sz)
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, 'lib_gpu'): self.dev.allocator.free(self.lib_gpu, self.lib_gpu.size, BufferOptions(cpu_access=True, nolru=True))
|
||||
if hasattr(self, 'lib_gpu'): self.dev.allocator.free(self.lib_gpu, self.lib_gpu.size, BufferSpec(cpu_access=True, nolru=True))
|
||||
|
||||
class AMDAllocator(HCQAllocator['AMDDevice']):
|
||||
def __init__(self, dev:AMDDevice): super().__init__(dev, batch_size=SDMA_MAX_COPY_SIZE)
|
||||
|
||||
def _alloc(self, size:int, options:BufferOptions) -> HCQBuffer:
|
||||
def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer:
|
||||
if options.host: return self.dev._gpu_alloc(size, kfd.KFD_IOC_ALLOC_MEM_FLAGS_USERPTR, public=True)
|
||||
if options.cpu_access and options.uncached: return self.dev._gpu_alloc(size, kfd.KFD_IOC_ALLOC_MEM_FLAGS_GTT, uncached=True)
|
||||
return self.dev._gpu_alloc(size, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM, public=options.cpu_access)
|
||||
|
||||
def _free(self, opaque, options:BufferOptions):
|
||||
def _free(self, opaque, options:BufferSpec):
|
||||
self.dev.synchronize()
|
||||
self.dev._gpu_free(opaque)
|
||||
|
||||
|
||||
@@ -13,14 +13,14 @@ from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, Timing
|
||||
from tinygrad.device import Compiled, Allocator, Compiler, Device, BufferOptions
|
||||
from tinygrad.device import Compiled, Allocator, Compiler, Device, BufferSpec
|
||||
|
||||
# ***** API *****
|
||||
|
||||
class CloudRequest: pass
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BufferAlloc(CloudRequest): buffer_num: int; size: int; options: BufferOptions # noqa: E702
|
||||
class BufferAlloc(CloudRequest): buffer_num: int; size: int; options: BufferSpec # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BufferFree(CloudRequest): buffer_num: int # noqa: E702
|
||||
@@ -43,7 +43,7 @@ class ProgramExec(CloudRequest):
|
||||
global_size: Optional[Tuple[int, ...]]; local_size: Optional[Tuple[int, ...]]; wait: bool # noqa: E702
|
||||
|
||||
# for safe deserialization
|
||||
whitelist = {x.__name__:x for x in [BufferAlloc, BufferFree, CopyIn, CopyOut, ProgramAlloc, ProgramFree, ProgramExec, BufferOptions]}
|
||||
whitelist = {x.__name__:x for x in [BufferAlloc, BufferFree, CopyIn, CopyOut, ProgramAlloc, ProgramFree, ProgramExec, BufferSpec]}
|
||||
eval_fxns = {ast.Constant: lambda x: x.value, ast.Tuple: lambda x: tuple(map(safe_eval, x.elts)), ast.List: lambda x: list(map(safe_eval, x.elts)),
|
||||
ast.Call: lambda x: safe_eval(x.func)(*[safe_eval(arg) for arg in x.args], **{kwarg.arg: safe_eval(kwarg.value) for kwarg in x.keywords}),
|
||||
ast.Name: lambda x: whitelist[x.id], ast.Attribute: lambda x: {"imagef": dtypes.imagef, "imageh": dtypes.imageh}[x.attr]}
|
||||
@@ -76,7 +76,7 @@ class BatchRequest:
|
||||
class CloudSession:
|
||||
programs: Dict[Tuple[str, str], Any] = field(default_factory=dict)
|
||||
# TODO: the buffer should track this internally
|
||||
buffers: Dict[int, Tuple[Any, int, Optional[BufferOptions]]] = field(default_factory=dict)
|
||||
buffers: Dict[int, Tuple[Any, int, Optional[BufferSpec]]] = field(default_factory=dict)
|
||||
|
||||
class CloudHandler(BaseHTTPRequestHandler):
|
||||
protocol_version = 'HTTP/1.1'
|
||||
@@ -143,7 +143,7 @@ class CloudAllocator(Allocator):
|
||||
self.device = dev
|
||||
super().__init__()
|
||||
# TODO: ideally we shouldn't have to deal with images here
|
||||
def _alloc(self, size:int, options:BufferOptions) -> int:
|
||||
def _alloc(self, size:int, options:BufferSpec) -> int:
|
||||
self.device.buffer_num += 1
|
||||
self.device.req.q(BufferAlloc(self.device.buffer_num, size, options))
|
||||
return self.device.buffer_num
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
import ctypes, ctypes.util, functools
|
||||
from typing import Tuple, Optional, List
|
||||
from tinygrad.helpers import DEBUG, getenv, from_mv, init_c_var, init_c_struct_t
|
||||
from tinygrad.device import Compiled, BufferOptions, LRUAllocator
|
||||
from tinygrad.device import Compiled, BufferSpec, LRUAllocator
|
||||
from tinygrad.renderer.cstyle import CUDARenderer
|
||||
from tinygrad.renderer.ptx import PTXRenderer
|
||||
from tinygrad.runtime.autogen import cuda
|
||||
@@ -63,17 +63,17 @@ class CUDAAllocator(LRUAllocator):
|
||||
def __init__(self, device:CUDADevice):
|
||||
self.device = device
|
||||
super().__init__()
|
||||
def _alloc(self, size, options:BufferOptions):
|
||||
def _alloc(self, size, options:BufferSpec):
|
||||
check(cuda.cuCtxSetCurrent(self.device.context))
|
||||
if options.host: return init_c_var(ctypes.c_void_p(), lambda x: check(cuda.cuMemHostAlloc(ctypes.byref(x), size, 0x01)))
|
||||
return init_c_var(cuda.CUdeviceptr(), lambda x: check(cuda.cuMemAlloc_v2(ctypes.byref(x), size)))
|
||||
def _free(self, opaque, options:BufferOptions):
|
||||
def _free(self, opaque, options:BufferSpec):
|
||||
if options.host: check(cuda.cuMemFreeHost(opaque))
|
||||
else: check(cuda.cuMemFree_v2(opaque))
|
||||
def _copyin(self, dest, src:memoryview):
|
||||
check(cuda.cuCtxSetCurrent(self.device.context))
|
||||
host_mem = self.alloc(len(src), BufferOptions(host=True))
|
||||
self.device.pending_copyin.append((host_mem, len(src), BufferOptions(host=True)))
|
||||
host_mem = self.alloc(len(src), BufferSpec(host=True))
|
||||
self.device.pending_copyin.append((host_mem, len(src), BufferSpec(host=True)))
|
||||
ctypes.memmove(host_mem, from_mv(src), len(src))
|
||||
check(cuda.cuMemcpyHtoDAsync_v2(dest, host_mem, len(src), None))
|
||||
def _copyout(self, dest:memoryview, src):
|
||||
@@ -110,7 +110,7 @@ class CUDADevice(Compiled):
|
||||
CUDADevice.peer_access = True
|
||||
|
||||
self.arch = f"sm_{major.value}{minor.value}"
|
||||
self.pending_copyin: List[Tuple[int, int, Optional[BufferOptions]]] = []
|
||||
self.pending_copyin: List[Tuple[int, int, Optional[BufferSpec]]] = []
|
||||
CUDADevice.devices.append(self)
|
||||
|
||||
from tinygrad.runtime.graph.cuda import CUDAGraph
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
from typing import Tuple, Any
|
||||
import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys
|
||||
assert sys.platform != 'win32'
|
||||
from tinygrad.device import BufferOptions, Compiled, Allocator
|
||||
from tinygrad.device import BufferSpec, Compiled, Allocator
|
||||
from tinygrad.helpers import from_mv, getenv, round_up, mv_address, to_mv
|
||||
from tinygrad.runtime.ops_clang import ClangCompiler
|
||||
from tinygrad.renderer.cstyle import DSPRenderer
|
||||
@@ -43,13 +43,13 @@ class DSPAllocator(Allocator):
|
||||
self.dev = dev
|
||||
super().__init__()
|
||||
|
||||
def _alloc(self, size:int, options:BufferOptions):
|
||||
def _alloc(self, size:int, options:BufferSpec):
|
||||
b = qcom_dsp.ION_IOC_ALLOC(self.dev.ion_fd, len=size, align=0x200, heap_id_mask=1<<qcom_dsp.ION_SYSTEM_HEAP_ID, flags=qcom_dsp.ION_FLAG_CACHED)
|
||||
share_info = qcom_dsp.ION_IOC_SHARE(self.dev.ion_fd, handle=b.handle)
|
||||
va_addr = libc.mmap(0, size, mmap.PROT_READ|mmap.PROT_WRITE, mmap.MAP_SHARED, share_info.fd, 0)
|
||||
return DSPBuffer(va_addr, size, share_info, offset=0)
|
||||
|
||||
def _free(self, opaque:DSPBuffer, options:BufferOptions):
|
||||
def _free(self, opaque:DSPBuffer, options:BufferSpec):
|
||||
libc.munmap(opaque.va_addr, opaque.size)
|
||||
os.close(opaque.share_info.fd)
|
||||
qcom_dsp.ION_IOC_FREE(self.dev.ion_fd, handle=opaque.share_info.handle)
|
||||
@@ -75,7 +75,7 @@ class DSPDevice(Compiled):
|
||||
ClangCompiler("compile_dsp", args=compiler_args, objdump_tool='llvm-objdump'), functools.partial(DSPProgram, self))
|
||||
|
||||
fastrpc_shell = memoryview(bytearray(pathlib.Path('/dsp/cdsp/fastrpc_shell_3').read_bytes()))
|
||||
self.shell_buf = self.allocator.alloc(round_up(fastrpc_shell.nbytes, 0x1000), BufferOptions(nolru=True))
|
||||
self.shell_buf = self.allocator.alloc(round_up(fastrpc_shell.nbytes, 0x1000), BufferSpec(nolru=True))
|
||||
ctypes.memmove(self.shell_buf.va_addr, mv_address(fastrpc_shell), fastrpc_shell.nbytes)
|
||||
|
||||
self.init_dsp()
|
||||
|
||||
@@ -4,7 +4,7 @@ import ctypes, functools, hashlib
|
||||
from tinygrad.runtime.autogen import opencl as cl
|
||||
from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, OSX, DEBUG, getenv, mv_address
|
||||
from tinygrad.renderer.cstyle import OpenCLRenderer, IntelRenderer
|
||||
from tinygrad.device import BufferOptions, LRUAllocator, Compiled, Compiler, CompileError
|
||||
from tinygrad.device import BufferSpec, LRUAllocator, Compiled, Compiler, CompileError
|
||||
|
||||
# see test/external/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something
|
||||
OSX_TIMING_RATIO = (125/3) if OSX else 1.0
|
||||
@@ -44,7 +44,7 @@ class CLProgram:
|
||||
if hasattr(self, 'kernel'): check(cl.clReleaseKernel(self.kernel))
|
||||
if hasattr(self, 'program'): check(cl.clReleaseProgram(self.program))
|
||||
|
||||
def __call__(self, *bufs:Tuple[ctypes._CData, BufferOptions], global_size:Tuple[int,int,int]=(1,1,1), local_size:Optional[Tuple[int,int,int]]=None, vals:Tuple[int, ...]=(), wait=False) -> Optional[float]: # noqa: E501
|
||||
def __call__(self, *bufs:Tuple[ctypes._CData, BufferSpec], global_size:Tuple[int,int,int]=(1,1,1), local_size:Optional[Tuple[int,int,int]]=None, vals:Tuple[int, ...]=(), wait=False) -> Optional[float]: # noqa: E501
|
||||
for i,(b,_) in enumerate(bufs): cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(b), ctypes.byref(b))
|
||||
for i,v in enumerate(vals,start=len(bufs)): cl.clSetKernelArg(self.kernel, i, 4, ctypes.byref(ctypes.c_int32(v)))
|
||||
if local_size is not None: global_size = cast(Tuple[int,int,int], tuple(int(g*l) for g,l in zip(global_size, local_size)))
|
||||
@@ -62,14 +62,14 @@ class CLAllocator(LRUAllocator):
|
||||
def __init__(self, dev:CLDevice):
|
||||
self.dev = dev
|
||||
super().__init__()
|
||||
def _alloc(self, size:int, options:BufferOptions) -> Tuple[ctypes._CData, BufferOptions]:
|
||||
def _alloc(self, size:int, options:BufferSpec) -> Tuple[ctypes._CData, BufferSpec]:
|
||||
if options.image is not None:
|
||||
return (checked(cl.clCreateImage2D(self.dev.context, cl.CL_MEM_READ_WRITE,
|
||||
cl.cl_image_format(cl.CL_RGBA, {2: cl.CL_HALF_FLOAT, 4: cl.CL_FLOAT}[options.image.itemsize]),
|
||||
options.image.shape[1], options.image.shape[0], 0, None, status := ctypes.c_int32()), status), options)
|
||||
return (checked(cl.clCreateBuffer(self.dev.context, cl.CL_MEM_READ_WRITE, size, None, status := ctypes.c_int32()), status), options)
|
||||
def _free(self, opaque:Tuple[ctypes._CData, BufferOptions], options:BufferOptions): check(cl.clReleaseMemObject(opaque[0]))
|
||||
def _copyin(self, dest:Tuple[ctypes._CData, BufferOptions], src:memoryview):
|
||||
def _free(self, opaque:Tuple[ctypes._CData, BufferSpec], options:BufferSpec): check(cl.clReleaseMemObject(opaque[0]))
|
||||
def _copyin(self, dest:Tuple[ctypes._CData, BufferSpec], src:memoryview):
|
||||
if dest[1].image is not None:
|
||||
check(cl.clEnqueueWriteImage(self.dev.queue, dest[0], False, (ctypes.c_size_t * 3)(0,0,0),
|
||||
(ctypes.c_size_t * 3)(dest[1].image.shape[1],dest[1].image.shape[0],1), 0, 0, from_mv(src), 0, None, None))
|
||||
@@ -77,7 +77,7 @@ class CLAllocator(LRUAllocator):
|
||||
if mv_address(src) % 16: src = memoryview(bytearray(src))
|
||||
check(cl.clEnqueueWriteBuffer(self.dev.queue, dest[0], False, 0, len(src)*src.itemsize, from_mv(src), 0, None, None))
|
||||
self.dev.pending_copyin.append(src) # NOTE: these can't be freed until the GPU actually executes this command
|
||||
def _copyout(self, dest:memoryview, src:Tuple[ctypes._CData, BufferOptions]):
|
||||
def _copyout(self, dest:memoryview, src:Tuple[ctypes._CData, BufferSpec]):
|
||||
if src[1].image is not None:
|
||||
check(cl.clEnqueueReadImage(self.dev.queue, src[0], False, (ctypes.c_size_t * 3)(0,0,0),
|
||||
(ctypes.c_size_t * 3)(src[1].image.shape[1],src[1].image.shape[0],1), 0, 0, from_mv(dest), 0, None, None))
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
import ctypes, functools
|
||||
from typing import Tuple
|
||||
from tinygrad.helpers import init_c_var, from_mv, init_c_struct_t, getenv
|
||||
from tinygrad.device import Compiled, LRUAllocator, BufferOptions
|
||||
from tinygrad.device import Compiled, LRUAllocator, BufferSpec
|
||||
from tinygrad.runtime.autogen import hip
|
||||
from tinygrad.runtime.support.compiler_hip import AMDCompiler
|
||||
from tinygrad.renderer.cstyle import HIPRenderer
|
||||
@@ -46,10 +46,10 @@ class HIPAllocator(LRUAllocator):
|
||||
def __init__(self, dev:HIPDevice):
|
||||
self.dev = dev
|
||||
super().__init__()
|
||||
def _alloc(self, size:int, options:BufferOptions):
|
||||
def _alloc(self, size:int, options:BufferSpec):
|
||||
check(hip.hipSetDevice(self.dev.device_id))
|
||||
return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipMalloc(ctypes.byref(x), size)))
|
||||
def _free(self, opaque, options:BufferOptions): check(hip.hipFree(opaque))
|
||||
def _free(self, opaque, options:BufferSpec): check(hip.hipFree(opaque))
|
||||
def _copyin(self, dest, src: memoryview):
|
||||
check(hip.hipSetDevice(self.dev.device_id))
|
||||
check(hip.hipMemcpy(dest, from_mv(src), len(src), hip.hipMemcpyHostToDevice))
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Tuple, List, Any, cast, Union, Dict, Type, Optional
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, hcq_command
|
||||
from tinygrad.runtime.support.hcq import HCQArgsState, HCQProgram, HCQSignal
|
||||
from tinygrad.device import BufferOptions
|
||||
from tinygrad.device import BufferSpec
|
||||
from tinygrad.helpers import getenv, mv_address, init_c_struct_t, to_mv, round_up, data64, data64_le, DEBUG, prod
|
||||
from tinygrad.renderer.ptx import PTXRenderer
|
||||
from tinygrad.renderer.cstyle import NVRenderer
|
||||
@@ -85,7 +85,7 @@ class NVSignal(HCQSignal):
|
||||
|
||||
class NVCommandQueue(HWQueue[NVSignal, 'NVDevice', 'NVProgram', 'NVArgsState']): # pylint: disable=abstract-method
|
||||
def __del__(self):
|
||||
if self.binded_device is not None: self.binded_device.allocator.free(self.hw_page, self.hw_page.size, BufferOptions(cpu_access=True, nolru=True))
|
||||
if self.binded_device is not None: self.binded_device.allocator.free(self.hw_page, self.hw_page.size, BufferSpec(cpu_access=True, nolru=True))
|
||||
|
||||
@hcq_command
|
||||
def setup(self, compute_class=None, copy_class=None, local_mem_window=None, shared_mem_window=None, local_mem=None, local_mem_tpc_bytes=None):
|
||||
@@ -108,7 +108,7 @@ class NVCommandQueue(HWQueue[NVSignal, 'NVDevice', 'NVProgram', 'NVArgsState']):
|
||||
|
||||
def bind(self, dev:NVDevice):
|
||||
self.binded_device = dev
|
||||
self.hw_page = dev.allocator.alloc(len(self.q) * 4, BufferOptions(cpu_access=True, nolru=True))
|
||||
self.hw_page = dev.allocator.alloc(len(self.q) * 4, BufferSpec(cpu_access=True, nolru=True))
|
||||
hw_view = to_mv(self.hw_page.va_addr, self.hw_page.size).cast("I")
|
||||
for i, value in enumerate(self.q): hw_view[i] = value
|
||||
|
||||
@@ -228,7 +228,7 @@ class NVProgram(HCQProgram):
|
||||
else: image, sections, relocs = elf_loader(self.lib, force_section_align=128)
|
||||
|
||||
# NOTE: Ensure at least 4KB of space after the program to mitigate prefetch memory faults.
|
||||
self.lib_gpu = self.dev.allocator.alloc(round_up(image.nbytes, 0x1000) + 0x1000, BufferOptions(cpu_access=True))
|
||||
self.lib_gpu = self.dev.allocator.alloc(round_up(image.nbytes, 0x1000) + 0x1000, BufferSpec(cpu_access=True))
|
||||
|
||||
self.prog_addr, self.prog_sz, self.regs_usage, self.shmem_usage, self.lcmem_usage = self.lib_gpu.va_addr, image.nbytes, 0, 0x400, 0
|
||||
self.constbufs: Dict[int, Tuple[int, int]] = {0: (0, 0x160)} # Dict[constbuf index, Tuple[va_addr, size]]
|
||||
@@ -281,7 +281,7 @@ class NVProgram(HCQProgram):
|
||||
super().__init__(NVArgsState, self.dev, self.name, kernargs_alloc_size=round_up(self.constbufs[0][1], 1 << 8) + (8 << 8))
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, 'lib_gpu'): self.dev.allocator.free(self.lib_gpu, self.lib_gpu.size, BufferOptions(cpu_access=True))
|
||||
if hasattr(self, 'lib_gpu'): self.dev.allocator.free(self.lib_gpu, self.lib_gpu.size, BufferSpec(cpu_access=True))
|
||||
|
||||
def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
|
||||
if prod(local_size) > 1024 or self.max_threads < prod(local_size) or self.lcmem_usage > cast(NVDevice, self.dev).slm_per_thread:
|
||||
@@ -291,11 +291,11 @@ class NVProgram(HCQProgram):
|
||||
return super().__call__(*bufs, global_size=global_size, local_size=local_size, vals=vals, wait=wait)
|
||||
|
||||
class NVAllocator(HCQAllocator['NVDevice']):
|
||||
def _alloc(self, size:int, options:BufferOptions) -> HCQBuffer:
|
||||
def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer:
|
||||
if options.host: return self.dev._gpu_host_alloc(size, tag="user host memory")
|
||||
return self.dev._gpu_alloc(size, map_to_cpu=options.cpu_access, huge_page=(size > (16 << 20)), tag=f"user memory ({options})")
|
||||
|
||||
def _free(self, opaque, options:BufferOptions):
|
||||
def _free(self, opaque, options:BufferSpec):
|
||||
self.dev.synchronize()
|
||||
self.dev._gpu_free(opaque)
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import os, ctypes, functools, mmap, struct, array, decimal, math, sys
|
||||
assert sys.platform != 'win32'
|
||||
from types import SimpleNamespace
|
||||
from typing import Tuple, List, Any, cast, Optional
|
||||
from tinygrad.device import BufferOptions
|
||||
from tinygrad.device import BufferSpec
|
||||
from tinygrad.runtime.support.hcq import HCQBuffer, HWQueue, HCQProgram, HCQCompiled, HCQSignal, HCQAllocator, HCQArgsState
|
||||
from tinygrad.runtime.autogen import kgsl, adreno, libc
|
||||
from tinygrad.runtime.ops_gpu import CLCompiler, CLDevice
|
||||
@@ -50,7 +50,7 @@ class QCOMComputeQueue(HWQueue): # pylint: disable=abstract-method
|
||||
super().__init__()
|
||||
|
||||
def __del__(self):
|
||||
if self.binded_device is not None: self.binded_device.allocator.free(self.hw_page, self.hw_page.size, BufferOptions(cpu_access=True, nolru=True))
|
||||
if self.binded_device is not None: self.binded_device.allocator.free(self.hw_page, self.hw_page.size, BufferSpec(cpu_access=True, nolru=True))
|
||||
|
||||
def cmd(self, opcode: int, *vals: int): self.q += [pkt7_hdr(opcode, len(vals)), *vals]
|
||||
|
||||
@@ -98,7 +98,7 @@ class QCOMComputeQueue(HWQueue): # pylint: disable=abstract-method
|
||||
|
||||
def bind(self, dev:QCOMDevice):
|
||||
self.binded_device = dev
|
||||
self.hw_page = dev.allocator.alloc(len(self.q) * 4, BufferOptions(cpu_access=True, nolru=True))
|
||||
self.hw_page = dev.allocator.alloc(len(self.q) * 4, BufferSpec(cpu_access=True, nolru=True))
|
||||
self.submit_req, self.obj = self._build_gpu_command(self.binded_device, self.hw_page.va_addr)
|
||||
# From now on, the queue is on the device for faster submission.
|
||||
self.q = to_mv(self.obj.gpuaddr, len(self.q) * 4).cast("I") # type: ignore
|
||||
@@ -213,7 +213,7 @@ class QCOMProgram(HCQProgram):
|
||||
self.name, self.lib = name, lib
|
||||
self._parse_lib()
|
||||
|
||||
self.lib_gpu: HCQBuffer = self.dev.allocator.alloc(self.image_size, options=BufferOptions(cpu_access=True, nolru=True))
|
||||
self.lib_gpu: HCQBuffer = self.dev.allocator.alloc(self.image_size, options=BufferSpec(cpu_access=True, nolru=True))
|
||||
to_mv(self.lib_gpu.va_addr, self.image_size)[:] = self.image
|
||||
|
||||
self.pvtmem_size_per_item: int = round_up(self.pvtmem, 512) >> 9
|
||||
@@ -283,7 +283,7 @@ class QCOMProgram(HCQProgram):
|
||||
self.fregs, self.hregs = _read_lib(reg_desc_off + 0x14), _read_lib(reg_desc_off + 0x18)
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, 'lib_gpu'): self.dev.allocator.free(self.lib_gpu, self.lib_gpu.size, options=BufferOptions(cpu_access=True, nolru=True))
|
||||
if hasattr(self, 'lib_gpu'): self.dev.allocator.free(self.lib_gpu, self.lib_gpu.size, options=BufferSpec(cpu_access=True, nolru=True))
|
||||
|
||||
class QCOMBuffer(HCQBuffer):
|
||||
def __init__(self, va_addr:int, size:int, info=None, mapped=False, desc=None, ibo=None, pitch=None, real_stride=None, **kwargs):
|
||||
@@ -293,7 +293,7 @@ class QCOMBuffer(HCQBuffer):
|
||||
self.desc, self.ibo, self.pitch, self.real_stride = [0] * 16, [0] * 16, pitch, real_stride
|
||||
|
||||
class QCOMAllocator(HCQAllocator):
|
||||
def _alloc(self, size:int, options:BufferOptions) -> HCQBuffer:
|
||||
def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer:
|
||||
if options.image is not None:
|
||||
imgw, imgh, itemsize_log = options.image.shape[1], options.image.shape[0], int(math.log2(options.image.itemsize))
|
||||
pitchalign = max(6, 11 - int(math.log2(imgh))) if imgh > 1 else 6
|
||||
@@ -337,7 +337,7 @@ class QCOMAllocator(HCQAllocator):
|
||||
self.dev.synchronize()
|
||||
return to_mv(src.va_addr, src.size)
|
||||
|
||||
def _free(self, opaque, options:BufferOptions):
|
||||
def _free(self, opaque, options:BufferSpec):
|
||||
self.dev.synchronize()
|
||||
self.dev._gpu_free(opaque)
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import List, Optional, Dict, Tuple, cast, Protocol, Type, Union, Typ
|
||||
import contextlib, decimal, statistics, random, json, atexit, time, array, ctypes, functools
|
||||
from tinygrad.helpers import PROFILEPATH, PROFILE, from_mv, getenv
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.device import BufferOptions, Compiler, Compiled, LRUAllocator
|
||||
from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator
|
||||
|
||||
# **************** for HCQ Compatible Devices ****************
|
||||
|
||||
@@ -372,7 +372,7 @@ class HCQCompiled(Compiled, Generic[SignalType]):
|
||||
from tinygrad.runtime.graph.hcq import HCQGraph
|
||||
super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)
|
||||
|
||||
self.kernargs_page:HCQBuffer = self.allocator.alloc(16 << 20, BufferOptions(cpu_access=True))
|
||||
self.kernargs_page:HCQBuffer = self.allocator.alloc(16 << 20, BufferSpec(cpu_access=True))
|
||||
self.kernargs_ptr:int = self.kernargs_page.va_addr
|
||||
self.devices.append(self)
|
||||
|
||||
@@ -483,11 +483,11 @@ class HCQAllocator(LRUAllocator, Generic[DeviceType]): # pylint: disable=abstrac
|
||||
|
||||
def __init__(self, dev:DeviceType, batch_size:int=(2 << 20), batch_cnt:int=32):
|
||||
self.dev:DeviceType = dev
|
||||
self.b = [self._alloc(batch_size, BufferOptions(host=True)) for _ in range(batch_cnt)]
|
||||
self.b = [self._alloc(batch_size, BufferSpec(host=True)) for _ in range(batch_cnt)]
|
||||
self.b_timeline, self.b_next = [0] * len(self.b), 0
|
||||
super().__init__()
|
||||
|
||||
def _alloc(self, size:int, options:BufferOptions) -> HCQBuffer: raise NotImplementedError("need hcq compat alloc")
|
||||
def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer: raise NotImplementedError("need hcq compat alloc")
|
||||
|
||||
def _copyin(self, dest:HCQBuffer, src:memoryview):
|
||||
assert self.dev.hw_copy_queue_t is not None
|
||||
|
||||
@@ -10,7 +10,7 @@ from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_u
|
||||
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN
|
||||
from tinygrad.multi import MultiLazyBuffer
|
||||
from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait
|
||||
from tinygrad.device import Device, Buffer, BufferOptions
|
||||
from tinygrad.device import Device, Buffer, BufferSpec
|
||||
from tinygrad.engine.lazy import LazyBuffer
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
from tinygrad.engine.memory import memory_planner
|
||||
@@ -259,7 +259,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
||||
# NOTE: this realizes on the object from as_buffer being a Python object
|
||||
cpu = self.cast(self.dtype.base).contiguous().to("CLANG").realize()
|
||||
buf = cast(Buffer, cast(LazyBuffer, cpu.lazydata).base.realized)
|
||||
if self.device != "CLANG": buf.options = BufferOptions(nolru=True)
|
||||
if self.device != "CLANG": buf.options = BufferSpec(nolru=True)
|
||||
return buf.as_buffer(allow_zero_copy=True if self.device != "CLANG" else False)
|
||||
|
||||
def data(self) -> memoryview:
|
||||
|
||||
Reference in New Issue
Block a user