From bd111411bf50b7bdad07b44407ee990574a86c90 Mon Sep 17 00:00:00 2001
From: nimlgen <138685161+nimlgen@users.noreply.github.com>
Date: Thu, 17 Aug 2023 20:33:32 +0300
Subject: [PATCH] init allocator for compiled backends (#1467)

* init allocator for compiled backends

* Update ops_webgpu.py

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
---
 .../external_test_allocator_on_models.py     | 135 ++++++++++++++++++
 test/test_allocators.py                      | 106 ++++++++++++++
 tinygrad/helpers.py                          |   1 +
 tinygrad/runtime/lib.py                      |  51 ++++++-
 tinygrad/runtime/ops_cuda.py                 |   8 +-
 tinygrad/runtime/ops_gpu.py                  |  28 ++--
 tinygrad/runtime/ops_hip.py                  |  11 +-
 tinygrad/runtime/ops_metal.py                |  13 +-
 tinygrad/runtime/ops_webgpu.py               |  29 ++--
 9 files changed, 343 insertions(+), 39 deletions(-)
 create mode 100644 test/external/external_test_allocator_on_models.py
 create mode 100644 test/test_allocators.py

diff --git a/test/external/external_test_allocator_on_models.py b/test/external/external_test_allocator_on_models.py
new file mode 100644
index 0000000000..36717945d3
--- /dev/null
+++ b/test/external/external_test_allocator_on_models.py
@@ -0,0 +1,135 @@
+#!/usr/bin/env python
+import unittest, gc
+import numpy as np
+from tinygrad.tensor import Tensor
+from tinygrad.state import get_parameters, get_state_dict
+from tinygrad.ops import GlobalCounters, LazyOp, LoadOps
+from tinygrad.runtime.lib import RawBuffer, LRUAllocator
+from tinygrad.helpers import dtypes, prod
+from tinygrad.lazy import Device
+
+from examples.llama import Transformer
+
+ALLOCATED_DEV_BUFS = 0
+class FakeDeviceBuffer():
+  def __init__(self, sz, dt, device):
+    self.id = 1
+    self.size = sz
+    self.dtype = dt
+    self.device = device
+
+    global ALLOCATED_DEV_BUFS
+    ALLOCATED_DEV_BUFS += 1
+class FakeAllocator(LRUAllocator):
+  def _do_alloc(self, size, dtype, device, **kwargs): return FakeDeviceBuffer(size, dtype, device)
+  def _do_free(self, buf):
+    buf.id -= 1
+    assert buf.id == 0, f"Free should be called exactly once, but id is {buf.id}"
+
+FAKE_GLOBAL_ALLOCATOR = None
+class FakeBuffer(RawBuffer):
+  def __init__(self, size, dtype, device='0'):
+    global FAKE_GLOBAL_ALLOCATOR
+    super().__init__(size, dtype, allocator=FAKE_GLOBAL_ALLOCATOR, **{'device': device})
+    assert self._buf.size == size and self._buf.dtype == dtype and self._buf.device == device, "This allocator requires an exact match of size, dtype and device."
+  @classmethod
+  def fromCPU(cls, x:np.ndarray, **kwargs): return cls(prod(x.shape), dtypes.from_np(x.dtype), **kwargs)
+  def toCPU(self): return np.empty(self.size, dtype=self.dtype.np)
+class FakeProgram:
+  def __init__(self, name:str, prg:str): pass
+  def __call__(self, global_size, local_size, *bufs, wait=False): pass
+
+def helper_test_correctness(gen, train):
+  from tinygrad.runtime.ops_gpu import CL, CLAllocator
+  old_alloc = CL.cl_allocator
+  CL.cl_allocator = CLAllocator(0)
+  no_alloc_result = train(*gen()).numpy()
+  Device[Device.DEFAULT].synchronize()
+  CL.cl_allocator = CLAllocator(512<<30) # Cache as much as possible (512 GB) to test cache correctness.
+  for _ in range(4):
+    GlobalCounters.reset()
+    np.testing.assert_allclose(train(*gen()).numpy(), no_alloc_result, rtol=1e-3, atol=1e-5)
+    Device[Device.DEFAULT].synchronize()
+    assert len(CL.cl_allocator.cached_buffers) != 0, "Cache must be used"
+  CL.cl_allocator = old_alloc
+
+def __helper_test_alloc_count(gen, train):
+  was_alloc = ALLOCATED_DEV_BUFS
+  for _ in range(2):
+    train(*gen())
+  return ALLOCATED_DEV_BUFS - was_alloc
+
+def helper_test_alloc_count(mm, gen, train):
+  global FAKE_GLOBAL_ALLOCATOR
+  backup_program = Device[Device.DEFAULT].runtime
+  backup_buffer = Device[Device.DEFAULT].buffer
+  Device[Device.DEFAULT].runtime = FakeProgram
+  Device[Device.DEFAULT].buffer = FakeBuffer
+  Device[Device.DEFAULT].method_cache.clear()
+  FAKE_GLOBAL_ALLOCATOR = FakeAllocator(16<<30)
+  new_allocs = __helper_test_alloc_count(gen, train)
+  Device[Device.DEFAULT].method_cache.clear()
+  FAKE_GLOBAL_ALLOCATOR = FakeAllocator(0)
+  old_allocs = __helper_test_alloc_count(gen, train)
+  print(f"{mm}: llama: old allocs count {old_allocs}, new allocs count {new_allocs}")
+  assert new_allocs < old_allocs, "Hmm, doesn't the cache work anymore?"
+  Device[Device.DEFAULT].runtime = backup_program
+  Device[Device.DEFAULT].buffer = backup_buffer
+  FAKE_GLOBAL_ALLOCATOR = None
+
+def check_gc():
+  if Device.DEFAULT == "GPU":
+    gc.collect() # Need to collect Tensors.
+    from extra.introspection import print_objects
+    assert print_objects() == 0
+
+# for speed
+def derandomize(x):
+  if isinstance(x, LazyOp):
+    if x.op == LoadOps.RAND: x.op = LoadOps.EMPTY
+    x.src = [derandomize(s) for s in x.src]
+  else:
+    x.op = derandomize(x.op)
+  return x
+
+def derandomize_model(model):
+  for p in get_parameters(model):
+    p.lazydata = derandomize(p.lazydata)
+    p.realize()
+
+class TestAllocators(unittest.TestCase):
+  @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
+  def test_lru_allocator_tiny_llama(self):
+    old_type = Tensor.default_type
+    Tensor.default_type = dtypes.float16
+
+    args_tiny = {"dim": 1024, "multiple_of": 256, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
+    def __test():
+      model = Transformer(**args_tiny)
+      derandomize_model(model)
+      def test(t): return model(t, 0).realize()
+      helper_test_correctness(lambda: (Tensor([[1,]]),), test)
+    __test()
+    Tensor.default_type = old_type
+    check_gc()
+
+  @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
+  def test_lru_allocator_tiny_llama_alloc_counts(self):
+    args_tiny = {"dim": 1024, "multiple_of": 256, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
+    def test_alloc_count(t):
+      model = Transformer(**args_tiny)
+      for v in get_state_dict(model).values(): v.assign(Tensor.empty(*v.shape, dtype=v.dtype))
+      return model(t, 0).realize()
+    helper_test_alloc_count("llama", lambda: (Tensor([[2,]]),), test_alloc_count)
+    check_gc()
+
+  @unittest.skip("huge for CI")
+  def test_stable_diffusion(self):
+    from examples.stable_diffusion import UNetModel
+    model = UNetModel()
+    derandomize_model(model)
+    def test(t, t2): return model(t, 801, t2).realize()
+    helper_test_correctness(lambda: (Tensor.randn(1, 4, 16, 16),Tensor.randn(1, 77, 768)), test)
+
+if __name__ == "__main__":
+  unittest.main()
diff --git a/test/test_allocators.py b/test/test_allocators.py
new file mode 100644
index 0000000000..5480debbc4
--- /dev/null
+++ b/test/test_allocators.py
@@ -0,0 +1,106 @@
+#!/usr/bin/env python
+import unittest
+import numpy as np
+from weakref import ref
+from tinygrad.ops import GlobalCounters
+from tinygrad.runtime.lib import RawBuffer, LRUAllocator
+from tinygrad.helpers import dtypes, prod
+from tinygrad.lazy import Device
+
+def check_gc():
+  if Device.DEFAULT == "GPU":
+    from extra.introspection import print_objects
+    assert print_objects() == 0
+
+class FakeDeviceBuffer():
+  def __init__(self, sz, dt, device):
+    self.id = 1
+    self.size = sz
+    self.dtype = dt
+    self.device = device
+  def __del__(self):
+    assert self.id == 0, "Should have called _do_free() before"
+
+class FakeAllocator(LRUAllocator):
+  def _do_alloc(self, size, dtype, device, **kwargs): return FakeDeviceBuffer(size, dtype, device)
+  def _do_free(self, buf):
+    buf.id -= 1
+    assert buf.id == 0, f"Free should be called exactly once, but id is {buf.id}"
+
+FAKE_GLOBAL_ALLOCATOR = None
+class FakeBuffer(RawBuffer):
+  def __init__(self, size, dtype, device='0'):
+    global FAKE_GLOBAL_ALLOCATOR
+    super().__init__(size, dtype, allocator=FAKE_GLOBAL_ALLOCATOR, **{'device': device})
+    assert self._buf.size == size and self._buf.dtype == dtype and self._buf.device == device, "This allocator requires an exact match of size, dtype and device."
+  @classmethod
+  def fromCPU(cls, x:np.ndarray, **kwargs): return cls(prod(x.shape), dtypes.from_np(x.dtype), **kwargs)
+  def toCPU(self): return np.empty(self.size, dtype=self.dtype.np)
+
+def alloc(allocator, size, dtype, **kwargs):
+  global FAKE_GLOBAL_ALLOCATOR
+  FAKE_GLOBAL_ALLOCATOR = allocator
+  buf = FakeBuffer(size, dtype, **kwargs)
+  assert buf.dtype == dtype and buf.size == size
+  FAKE_GLOBAL_ALLOCATOR = None
+  return buf
+
+def alloc_free_trace(allocator, size, dtype, **kwargs):
+  buf = alloc(allocator, size, dtype, **kwargs)
+  return ref(buf._buf)
+
+def cmp_trace_and_buf(buf, trace_ref): return trace_ref and trace_ref() == buf._buf
+
+class TestAllocators(unittest.TestCase):
+  def test_lru_allocator_reusage(self):
+    def test():
+      lru_allocator = FakeAllocator(2048)
+      traced_buf = alloc_free_trace(lru_allocator, 16, dtypes.float32)
+      assert GlobalCounters.mem_cached == 16*dtypes.float32.itemsize, "Buffer should be cached"
+      for _ in range(32):
+        def __test():
+          buf = alloc(lru_allocator, 16, dtypes.float32)
+          assert cmp_trace_and_buf(buf, traced_buf), "Buffer should be reused"
+        __test()
+
+      usedbuf = alloc(lru_allocator, 16, dtypes.float32)
+      for _ in range(32):
+        def __test():
+          buf = alloc(lru_allocator, 16, dtypes.float32)
+          assert usedbuf != buf, "Nobody should get the used buffer"
+        __test()
+      assert GlobalCounters.mem_used == 16*dtypes.float32.itemsize, "Only usedbuf is still allocated."
+    test()
+    check_gc()
+
+  def test_lru_allocator_cache_free(self):
+    def test():
+      lru_allocator = FakeAllocator(128)
+      refs = []
+      for _ in range(32):
+        refs.append(alloc_free_trace(lru_allocator, 16, dtypes.float32))
+      for sz in range(32):
+        alloc_free_trace(lru_allocator, sz, dtypes.float32)
+      assert GlobalCounters.mem_used + GlobalCounters.mem_cached <= 128, "Should not allocate on device more than allowed (128)"
+      for r in refs: assert r() is None, "All refs should be dead, since buffers were cleared from cache"
+    test()
+    check_gc()
+
+  def test_lru_allocator_multidevice(self):
+    def test():
+      lru_allocator = FakeAllocator(256)
+      refs=[]
+      for i in range(8):
+        refs.append(alloc_free_trace(lru_allocator, 16, dtypes.float32, device=str(i)))
+      for i in range(64):
+        def __test():
+          dev = str(i % 8)
+          buf = alloc(lru_allocator, 16, dtypes.float32, device=dev)
+          assert cmp_trace_and_buf(buf, refs[i%8]), "Buffer should be reused"
+        __test()
+      for r in refs: assert r() is not None, "All refs should be cached"
    test()
+    check_gc()
+
+if __name__ == "__main__":
+  unittest.main()
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index c5cad3a940..1329f5af09 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -130,6 +130,7 @@ class GlobalCounters:
   time_sum_s: ClassVar[float] = 0.0
   kernel_count: ClassVar[int] = 0
   mem_used: ClassVar[int] = 0   # NOTE: this is not reset
+  mem_cached: ClassVar[int] = 0 # NOTE: this is not reset
   cache: ClassVar[Optional[List[Tuple[Callable, Any]]]] = None
   @staticmethod
   def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count, GlobalCounters.cache = 0,0,0.0,0,None
diff --git a/tinygrad/runtime/lib.py b/tinygrad/runtime/lib.py
index 6bbb553b5d..00dff3cb25 100644
--- a/tinygrad/runtime/lib.py
+++ b/tinygrad/runtime/lib.py
@@ -1,18 +1,21 @@
 import ctypes
 import numpy as np
-from typing import TypeVar, Type, Any
-from tinygrad.helpers import DType, dtypes, prod, GlobalCounters
+from collections import defaultdict, deque
+from typing import TypeVar, Type, Any, Dict, Deque, Tuple
+from tinygrad.helpers import DType, dtypes, prod, GlobalCounters, ImageDType
 
 _T = TypeVar("_T")
 class RawBuffer: # pylint: disable=abstract-method
-  def __init__(self, size:int, dtype:DType, buf:Any=None):
+  def __init__(self, size:int, dtype:DType, buf:Any=None, allocator:Any=None, **kwargs):
     self.size: int = size
     self.dtype: DType = dtype
-    self._buf = buf
+    self._buf = buf if buf is not None else (allocator.alloc(size, dtype, **kwargs) if allocator else None) # If buf is provided, use it. Otherwise try to allocate from the allocator.
     self._memsz: int = size*dtype.itemsize
+    self._allocator = allocator
     GlobalCounters.mem_used += self._memsz
   def __del__(self): # NOTE: if it fails on init (bad dtype), it won't have a _memsz
     if hasattr(self, '_memsz'): GlobalCounters.mem_used -= self._memsz
+    if hasattr(self, '_allocator') and self._allocator: self._allocator.free(self._buf)
   def __repr__(self): return f"buffer<{self.size}, {self.dtype}>"
   @property
   def key(self): return (self.size, self.dtype.key)
@@ -66,3 +69,43 @@ class RawConst(RawBuffer): # pylint: disable=abstract-method
 
 def buf_is_kernel_arg(x) -> bool:
   return x.realized is not None and x.realized.__class__ is not RawConst
+
+class LRUAllocator:
+  def __init__(self, dev_memsz=(4<<30)):
+    self.epoch = 0
+    self.free_space: Dict[Any, int] = defaultdict(lambda: dev_memsz)
+    self.buffer_info: Dict[Any, Tuple[int, DType, str]] = dict()
+    self.cached_buffers: Dict[Tuple[int, ...], Deque[Tuple[Any, int]]] = defaultdict(deque) # Cached buffer storage, split by type and size, newest first.
+    self.aging_order: Dict[Any, Deque[Tuple[Tuple[int, ...], int]]] = defaultdict(deque) # Keys of cached_buffers, ordered from oldest to newest updates.
+  def __del__(self):
+    for v in self.cached_buffers.values():
+      for buf, _ in v: self._free_buffer(buf)
+  def _cache_reuse_buffer(self, rawbufs: Deque[Tuple[Any, int]]): # The newest cached buffer is reused.
+    GlobalCounters.mem_cached -= self._underlying_buf_memsz(rawbufs[0][0])
+    return rawbufs.popleft()[0]
+  def _alloc_buffer(self, size, dtype, device, **kwargs):
+    self.free_space[device] -= size*dtype.itemsize
+    while len(self.aging_order[device]) and self.free_space[device] < 0: # When OOM, remove LRU buffers from the cache.
+      bucket, epoch = self.aging_order[device].popleft()
+      if self.cached_buffers[bucket] and self.cached_buffers[bucket][-1][1] == epoch: self._free_buffer(self.cached_buffers[bucket].pop()[0]) # Free the cached buffer if it is still in the cache.
+    newbuf = self._do_alloc(size, dtype, device, **kwargs)
+    self.buffer_info[newbuf] = (size, dtype, device)
+    return newbuf
+  def _free_buffer(self, buf_to_free):
+    self.free_space[self.buffer_info[buf_to_free][2]] += self._underlying_buf_memsz(buf_to_free)
+    GlobalCounters.mem_cached -= self._underlying_buf_memsz(buf_to_free)
+    self.buffer_info.pop(buf_to_free)
+    self._do_free(buf_to_free)
+  def alloc(self, size, dtype, device='0', **kwargs):
+    rawbufs = self.cached_buffers.get(self._cached_bufkey(size, dtype, device), None)
+    return self._cache_reuse_buffer(rawbufs) if rawbufs else self._alloc_buffer(size, dtype, device, **kwargs)
+  def free(self, buf): # free() just caches the buffer. It might be freed later when an OOM occurs during allocation.
+    self.epoch += 1
+    size, dtype, device = self.buffer_info[buf]
+    self.cached_buffers[self._cached_bufkey(size, dtype, device)].appendleft((buf, self.epoch))
+    self.aging_order[device].append((self._cached_bufkey(size, dtype, device), self.epoch))
+    GlobalCounters.mem_cached += self._underlying_buf_memsz(buf)
+  def _underlying_buf_memsz(self, buf): return self.buffer_info[buf][0] * self.buffer_info[buf][1].itemsize
+  def _cached_bufkey(self, size, dtype, device) -> Tuple[int, ...]: return (device, size, dtype, dtype.shape) if isinstance(dtype, ImageDType) else (device, size, dtype) # Provides the key under which device buffers are cached and reused.
+  def _do_alloc(self, size, dtype, device, **kwargs): raise NotImplementedError("must be implemented")
+  def _do_free(self, buf): pass
diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py
index 2ea3479cd2..53f5ed13b6 100644
--- a/tinygrad/runtime/ops_cuda.py
+++ b/tinygrad/runtime/ops_cuda.py
@@ -4,7 +4,7 @@ import numpy as np
 from pycuda.compiler import compile as cuda_compile # type: ignore
 from tinygrad.helpers import DEBUG, getenv, colored, fromimport
 from tinygrad.ops import Compiled
-from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer
+from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer, LRUAllocator
 from tinygrad.codegen.linearizer import LinearizerOptions
 from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
 
@@ -47,8 +47,12 @@ if getenv("CUDACPU", 0) == 1:
 else:
   import pycuda.autoprimaryctx # type: ignore # pylint: disable=unused-import # noqa: F401
   import pycuda.driver as cuda # type: ignore
+  class CUDAAllocator(LRUAllocator):
+    def _do_alloc(self, size, dtype, device, **kwargs): return cuda.mem_alloc(size * dtype.itemsize) # type: ignore
+    def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize) # Buffers of the same byte length can be reused regardless of dtype.
+  CUDAAlloc = CUDAAllocator(pycuda.driver.Context.get_device().total_memory())
   class RawCUDABuffer(RawBufferCopyInOut): # type: ignore
-    def __init__(self, size, dtype): super().__init__(size, dtype, cuda.mem_alloc(size * dtype.itemsize)) # type: ignore
+    def __init__(self, size, dtype): super().__init__(size, dtype, allocator=CUDAAlloc)
     def _copyin(self, x:np.ndarray, stream:Optional[cuda.Stream]=None): cuda.memcpy_htod_async(self._buf, x.ravel(), stream) # type: ignore
     def _copyout(self, x:np.ndarray): cuda.memcpy_dtoh(x, self._buf) # type: ignore
 
diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py
index 628afaea45..c800010b00 100644
--- a/tinygrad/runtime/ops_gpu.py
+++ b/tinygrad/runtime/ops_gpu.py
@@ -5,7 +5,7 @@ import pyopencl as cl # type: ignore
 from typing import Optional, List
 from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, fromimport
 from tinygrad.ops import Compiled
-from tinygrad.runtime.lib import RawBufferCopyInOut, RawBufferTransfer
+from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer
 from tinygrad.codegen.linearizer import LinearizerOptions
 from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
 
@@ -17,6 +17,18 @@ ROCM_LLVM_PATH = pathlib.Path("/opt/rocm/llvm/bin")
 
 if DEBUG >= 5: early_exec = fromimport("extra.helpers", "enable_early_exec")()
 
+class CLAllocator(LRUAllocator):
+  def _do_alloc(self, size, dtype, device, **kwargs):
+    if isinstance(dtype, ImageDType):
+      # NOTE: the memory is a bit off here due to padding, it's buf.row_pitch * buf.height * 4 * dtype.itemsize
+      assert size == prod(dtype.shape), f"image size mismatch {size} != {dtype.shape}"
+      fmt = cl.ImageFormat(cl.channel_order.RGBA, {2: cl.channel_type.HALF_FLOAT, 4: cl.channel_type.FLOAT}[dtype.itemsize])
+      buf = cl.Image(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, fmt, shape=(dtype.shape[1], dtype.shape[0]))
+    else:
+      buf = cl.Buffer(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, size * dtype.itemsize)
+    setattr(buf, 'device', int(device)) # device is tracked on the underlying buffer
+    return buf
+
 class _CL:
   def post_init(self, device=None):
     platforms: List[List[cl.Device]] = [y for y in ([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()] + [x.get_devices(device_type=cl.device_type.CPU) for x in cl.get_platforms()]) if len(y)]
@@ -24,23 +36,14 @@ class _CL:
     self.cl_ctxs: List[cl.Context] = [cl.Context(devices=[x]) for x in platforms[getenv('CL_PLATFORM', 0)] if x.name not in getenv('CL_EXCLUDE', "").split(",")] if device is None else [cl.Context(devices=[platforms[getenv('CL_PLATFORM', 0)][device]])]
     if DEBUG >= 1: print(f"using devices: {[ctx.devices[0].hashable_model_and_version_identifier for ctx in self.cl_ctxs]}")
     self.cl_queue: List[cl.CommandQueue] = [cl.CommandQueue(ctx, device=ctx.devices[0], properties=cl.command_queue_properties.PROFILING_ENABLE) for ctx in self.cl_ctxs]
+    self.cl_allocator = CLAllocator(CL.cl_ctxs[0].devices[0].get_info(cl.device_info.GLOBAL_MEM_SIZE))
   def synchronize(self):
     for q in self.cl_queue: q.finish()
 CL = _CL()
 CL.post_init() if not getenv("DELAYED_RUNTIME_INIT", False) else None
 
 class CLBuffer(RawBufferCopyInOut, RawBufferTransfer):
-  def __init__(self, size, dtype, device='0'):
-    if isinstance(dtype, ImageDType):
-      fmt = cl.ImageFormat(cl.channel_order.RGBA, {2: cl.channel_type.HALF_FLOAT, 4: cl.channel_type.FLOAT}[dtype.itemsize])
-      buf = cl.Image(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, fmt, shape=(dtype.shape[1], dtype.shape[0]))
-      assert size == prod(dtype.shape), f"image size mismatch {size} != {dtype.shape}"
-      # NOTE: the memory is a bit off here due to padding, it's buf.row_pitch * buf.height * 4 * dtype.itemsize
-    else:
-      buf = cl.Buffer(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, size * dtype.itemsize)
-    setattr(buf, 'device', int(device)) # device is tracked on the underlying buffer
-    super().__init__(size, dtype, buf)
-
+  def __init__(self, size, dtype, device='0'): super().__init__(size, dtype, allocator=CL.cl_allocator, **{'device': device})
   def _copyin(self, x:np.ndarray):
     assert not self.dtype.name.startswith("image"), f"can't copyin images {self.dtype}"
     self.event = cl.enqueue_copy(CL.cl_queue[self._buf.device], self._buf, np.require(x, requirements='C'), is_blocking=False)
@@ -95,5 +98,4 @@ renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
   half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable",
   barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
   gid = [f'get_group_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True))
-
 GPUBuffer = Compiled(CLBuffer, LinearizerOptions(), renderer, CLProgram, CL.synchronize)
diff --git a/tinygrad/runtime/ops_hip.py b/tinygrad/runtime/ops_hip.py
index 86f9d33fbc..07369e08a1 100644
--- a/tinygrad/runtime/ops_hip.py
+++ b/tinygrad/runtime/ops_hip.py
@@ -3,7 +3,7 @@ import ctypes, functools
 import extra.hip_wrapper as hip
 from tinygrad.helpers import DEBUG
 from tinygrad.ops import Compiled
-from tinygrad.runtime.lib import RawBufferCopyInOut
+from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator
 from tinygrad.codegen.linearizer import LinearizerOptions
 from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
 
@@ -14,9 +14,14 @@ if DEBUG >= 5:
 
 # The default HIP stream is used for everything.
 
+class HIPAllocator(LRUAllocator):
+  def _do_alloc(self, size, dtype, device, **kwargs): return hip.hipMalloc(size * dtype.itemsize)
+  def _do_free(self, buf): hip.hipFree(buf)
+  def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize) # Buffers of the same byte length can be reused regardless of dtype.
+HIPAlloc = HIPAllocator(hip.hipGetDeviceProperties(hip.hipGetDevice()).totalGlobalMem)
+
 class RawHIPBuffer(RawBufferCopyInOut):
-  def __init__(self, size, dtype): super().__init__(size, dtype, hip.hipMalloc(size * dtype.itemsize))
-  def __del__(self): hip.hipFree(self._buf)
+  def __init__(self, size, dtype): super().__init__(size, dtype, allocator=HIPAlloc)
   def _copyin(self, x:np.ndarray): hip.hipMemcpyAsync_htod(self._buf, x.ctypes.data, self.size * self.dtype.itemsize, 0)
   def _copyout(self, x:np.ndarray): hip.hipMemcpy_dtoh(x.ctypes.data, self._buf, self.size * self.dtype.itemsize)
 
diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py
index b3600faed7..2165a87d8f 100644
--- a/tinygrad/runtime/ops_metal.py
+++ b/tinygrad/runtime/ops_metal.py
@@ -6,15 +6,21 @@ from tinygrad.codegen.linearizer import LinearizerOptions
 from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
 from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes
 from tinygrad.ops import Compiled
-from tinygrad.runtime.lib import RawBufferMapped
+from tinygrad.runtime.lib import RawBufferMapped, LRUAllocator
 
 METAL_XCODE = getenv("METAL_XCODE")
 
+class MetalAllocator(LRUAllocator):
+  def _do_alloc(self, size, dtype, device, **kwargs): return METAL.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared)
+  def _do_free(self, buf): buf.release()
+  def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize) # Buffers of the same byte length can be reused regardless of dtype.
+
 class _METAL:
   def __init__(self):
     self.mtl_buffers_in_flight: List[Any] = []
     self.device = Metal.MTLCreateSystemDefaultDevice()
     self.mtl_queue = self.device.newCommandQueue()
+    self.allocator = MetalAllocator(self.device.dedicatedMemorySize() or self.device.sharedMemorySize()) # TODO: is there a better way to do this?
   def synchronize(self):
     for cbuf in self.mtl_buffers_in_flight: cbuf.waitUntilCompleted()
 
@@ -24,10 +30,7 @@ METAL = _METAL()
 class RawMetalBuffer(RawBufferMapped):
   def __init__(self, size:int, dtype:DType):
     assert dtype != dtypes.double, f"METAL does not support {dtype.name}"
-    super().__init__(size, dtype, METAL.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared))
-  def __del__(self):
-    self._buf.release()
-    super().__del__()
+    super().__init__(size, dtype, allocator=METAL.allocator)
   def _buffer(self):
     METAL.synchronize()
     return self._buf.contents().as_buffer(self._buf.length())
diff --git a/tinygrad/runtime/ops_webgpu.py b/tinygrad/runtime/ops_webgpu.py
index 50181ca186..2124f83621 100644
--- a/tinygrad/runtime/ops_webgpu.py
+++ b/tinygrad/runtime/ops_webgpu.py
@@ -1,7 +1,7 @@
 import numpy as np
 import functools
 from wgpu.utils._device import get_default_device # type: ignore
-from tinygrad.runtime.lib import RawBufferCopyIn
+from tinygrad.runtime.lib import RawBufferCopyIn, LRUAllocator
 from tinygrad.helpers import dtypes, DType
 from tinygrad.ops import Compiled
 from tinygrad.codegen.linearizer import LinearizerOptions
@@ -9,32 +9,37 @@ from tinygrad.renderer.cstyle import uops_to_cstyle
 from tinygrad.renderer.wgsl import WGSLLanguage
 import wgpu # type: ignore
 
-device = get_default_device()
+wgpu_device = get_default_device()
 
 class WebGPUProgram:
-  def __init__(self, name: str, prg: str, binary=False): self.name,self.prg = name,device.create_shader_module(code=prg)
+  def __init__(self, name: str, prg: str, binary=False): self.name,self.prg = name,wgpu_device.create_shader_module(code=prg)
   def __call__(self, global_size, local_size, *bufs, wait=False):
     assert len(bufs) <= 8, "WEBGPU only supports 8 buffers"
     binding_layouts = [{"binding": i, "visibility": wgpu.ShaderStage.COMPUTE, "buffer": {"type": wgpu.BufferBindingType.storage}} for i in range(len(bufs))]
     bindings = [{"binding": i, "resource": {"buffer": x._buf, "offset": 0, "size": x._buf.size}} for i, x in enumerate(bufs)]
-    bind_group_layout = device.create_bind_group_layout(entries=binding_layouts)
-    pipeline_layout = device.create_pipeline_layout(bind_group_layouts=[bind_group_layout])
-    bind_group = device.create_bind_group(layout=bind_group_layout, entries=bindings)
-    compute_pipeline = device.create_compute_pipeline(layout=pipeline_layout,compute={"module": self.prg, "entry_point": self.name},)
-    command_encoder = device.create_command_encoder()
+    bind_group_layout = wgpu_device.create_bind_group_layout(entries=binding_layouts)
+    pipeline_layout = wgpu_device.create_pipeline_layout(bind_group_layouts=[bind_group_layout])
+    bind_group = wgpu_device.create_bind_group(layout=bind_group_layout, entries=bindings)
+    compute_pipeline = wgpu_device.create_compute_pipeline(layout=pipeline_layout,compute={"module": self.prg, "entry_point": self.name},)
+    command_encoder = wgpu_device.create_command_encoder()
     compute_pass = command_encoder.begin_compute_pass()
     compute_pass.set_pipeline(compute_pipeline)
     compute_pass.set_bind_group(0, bind_group, [], 0, 999999) # last 2 not used
     compute_pass.dispatch_workgroups(*global_size) # x y z
     compute_pass.end()
-    device.queue.submit([command_encoder.finish()])
+    wgpu_device.queue.submit([command_encoder.finish()])
+
+class RawWebGPUAllocator(LRUAllocator):
+  def _do_alloc(self, size, dtype, device, **kwargs): return wgpu_device.create_buffer(size=size*dtype.itemsize, usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST | wgpu.BufferUsage.COPY_SRC)
+  def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize) # Buffers of the same byte length can be reused regardless of dtype.
+WebGPUAlloc = RawWebGPUAllocator(wgpu_device.limits['max_buffer_size'])
 
 class RawWebGPUBuffer(RawBufferCopyIn):
   def __init__(self, size:int, dtype:DType):
     assert dtype not in [dtypes.int8,dtypes.uint8,dtypes.int64,dtypes.uint64,dtypes.double], f"dtype {dtype} not supported on WEBGPU"
-    super().__init__(size, dtype, device.create_buffer(size=size*dtype.itemsize, usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST | wgpu.BufferUsage.COPY_SRC))
-  def _copyin(self, x:np.ndarray): device.queue.write_buffer(self._buf, 0, np.ascontiguousarray(x))
-  def toCPU(self) -> np.ndarray: return np.frombuffer(device.queue.read_buffer(self._buf, 0), dtype=np.dtype(self.dtype.np, metadata={"backing": self})) # type: ignore
+    super().__init__(size, dtype, allocator=WebGPUAlloc)
+  def _copyin(self, x:np.ndarray): wgpu_device.queue.write_buffer(self._buf, 0, np.ascontiguousarray(x))
+  def toCPU(self) -> np.ndarray: return np.frombuffer(wgpu_device.queue.read_buffer(self._buf, 0), dtype=np.dtype(self.dtype.np, metadata={"backing": self})) # type: ignore
 
 renderer = functools.partial(uops_to_cstyle, WGSLLanguage())
 WebGpuBuffer = Compiled(RawWebGPUBuffer, LinearizerOptions(supports_float4=False, local_max=[256, 256, 64], global_max=[65535, 65535, 65535]), renderer, WebGPUProgram)
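
Usage note (not part of the diff above): a minimal sketch of how a compiled backend plugs into the new LRUAllocator, mirroring the FakeAllocator/FakeBuffer pair used in the tests. MyDeviceAllocator, RawMyDeviceBuffer, my_malloc and my_free are hypothetical stand-ins for a real backend's raw allocation calls; only _do_alloc is required, since _do_free defaults to a no-op.

  from tinygrad.runtime.lib import RawBuffer, LRUAllocator
  from tinygrad.helpers import dtypes

  def my_malloc(nbytes): return bytearray(nbytes)  # hypothetical raw device allocation
  def my_free(buf): pass                           # hypothetical raw device free

  class MyDeviceAllocator(LRUAllocator):
    def _do_alloc(self, size, dtype, device, **kwargs): return my_malloc(size * dtype.itemsize)
    def _do_free(self, buf): my_free(buf)

  MyAlloc = MyDeviceAllocator(dev_memsz=4 << 30)  # cache up to ~4 GB per device before LRU eviction kicks in

  class RawMyDeviceBuffer(RawBuffer):
    # Passing allocator= routes allocation through MyAlloc.alloc(size, dtype) in RawBuffer.__init__ and
    # MyAlloc.free(self._buf) in RawBuffer.__del__, so "freed" buffers land in the LRU cache and later
    # requests with the same key (device, size, dtype) are served from it without touching the device.
    def __init__(self, size, dtype): super().__init__(size, dtype, allocator=MyAlloc)

  buf = RawMyDeviceBuffer(16, dtypes.float32)   # allocates through MyAlloc
  del buf                                       # buffer is cached, not freed on the device
  buf2 = RawMyDeviceBuffer(16, dtypes.float32)  # same key, so the cached buffer is reused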