mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-10 22:54:59 -05:00
init allocator for compiled backends (#1467)
* init allocator for compiled backends
* Update ops_webgpu.py

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
@@ -5,7 +5,7 @@ import pyopencl as cl # type: ignore
|
||||
from typing import Optional, List
|
||||
from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, fromimport
|
||||
from tinygrad.ops import Compiled
|
||||
from tinygrad.runtime.lib import RawBufferCopyInOut, RawBufferTransfer
|
||||
from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer
|
||||
from tinygrad.codegen.linearizer import LinearizerOptions
|
||||
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
|
||||
|
||||
@@ -17,6 +17,18 @@ ROCM_LLVM_PATH = pathlib.Path("/opt/rocm/llvm/bin")
|
||||
if DEBUG >= 5:
|
||||
early_exec = fromimport("extra.helpers", "enable_early_exec")()
|
||||
|
||||
class CLAllocator(LRUAllocator):
  """LRUAllocator subclass whose raw allocations are pyopencl memory objects."""

  def _do_alloc(self, size, dtype, device, **kwargs):
    """Create the OpenCL memory object for one allocation request.

    size:   number of elements; buffers are allocated with size * dtype.itemsize bytes.
    dtype:  element type. An ImageDType yields an RGBA cl.Image whose shape is
            (dtype.shape[1], dtype.shape[0]); any other dtype yields a cl.Buffer.
    device: index into CL.cl_ctxs (converted with int()).
    kwargs: accepted but not used here.

    Returns the new cl.Image / cl.Buffer with an integer `device` attribute attached.
    Raises AssertionError when an image request's size does not equal prod(dtype.shape),
    and KeyError when an image dtype's itemsize is neither 2 nor 4.
    """
    if not isinstance(dtype, ImageDType):
      mem = cl.Buffer(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, size * dtype.itemsize)
    else:
      # NOTE: the memory is a bit off here due to padding, it's buf.row_pitch * buf.height * 4 * dtype.itemsize
      assert size == prod(dtype.shape), f"image size mismatch {size} != {dtype.shape}"
      # only 2-byte (half) and 4-byte (float) channels are mapped
      channel_type = {2: cl.channel_type.HALF_FLOAT, 4: cl.channel_type.FLOAT}[dtype.itemsize]
      pixel_format = cl.ImageFormat(cl.channel_order.RGBA, channel_type)
      mem = cl.Image(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, pixel_format, shape=(dtype.shape[1], dtype.shape[0]))
    # the device index travels with the underlying memory object
    setattr(mem, 'device', int(device))
    return mem
|
||||
|
||||
class _CL:
|
||||
def post_init(self, device=None):
|
||||
platforms: List[List[cl.Device]] = [y for y in ([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()] + [x.get_devices(device_type=cl.device_type.CPU) for x in cl.get_platforms()]) if len(y)]
|
||||
@@ -24,23 +36,14 @@ class _CL:
|
||||
self.cl_ctxs: List[cl.Context] = [cl.Context(devices=[x]) for x in platforms[getenv('CL_PLATFORM', 0)] if x.name not in getenv('CL_EXCLUDE', "").split(",")] if device is None else [cl.Context(devices=[platforms[getenv('CL_PLATFORM', 0)][device]])]
|
||||
if DEBUG >= 1: print(f"using devices: {[ctx.devices[0].hashable_model_and_version_identifier for ctx in self.cl_ctxs]}")
|
||||
self.cl_queue: List[cl.CommandQueue] = [cl.CommandQueue(ctx, device=ctx.devices[0], properties=cl.command_queue_properties.PROFILING_ENABLE) for ctx in self.cl_ctxs]
|
||||
self.cl_allocator = CLAllocator(CL.cl_ctxs[0].devices[0].get_info(cl.device_info.GLOBAL_MEM_SIZE))
|
||||
def synchronize(self):
  """Call finish() on every command queue in self.cl_queue, in list order."""
  for queue in self.cl_queue:
    queue.finish()
|
||||
# module-level _CL instance; post_init populates its contexts, queues and allocator
CL = _CL()
# when DELAYED_RUNTIME_INIT is set, post_init is not run at import time
if not getenv("DELAYED_RUNTIME_INIT", False):
  CL.post_init()
|
||||
|
||||
class CLBuffer(RawBufferCopyInOut, RawBufferTransfer):
|
||||
def __init__(self, size, dtype, device='0'):
|
||||
if isinstance(dtype, ImageDType):
|
||||
fmt = cl.ImageFormat(cl.channel_order.RGBA, {2: cl.channel_type.HALF_FLOAT, 4: cl.channel_type.FLOAT}[dtype.itemsize])
|
||||
buf = cl.Image(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, fmt, shape=(dtype.shape[1], dtype.shape[0]))
|
||||
assert size == prod(dtype.shape), f"image size mismatch {size} != {dtype.shape}"
|
||||
# NOTE: the memory is a bit off here due to padding, it's buf.row_pitch * buf.height * 4 * dtype.itemsize
|
||||
else:
|
||||
buf = cl.Buffer(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, size * dtype.itemsize)
|
||||
setattr(buf, 'device', int(device)) # device is tracked on the underlying buffer
|
||||
super().__init__(size, dtype, buf)
|
||||
|
||||
def __init__(self, size, dtype, device='0'): super().__init__(size, dtype, allocator=CL.cl_allocator, **{'device': device})
|
||||
def _copyin(self, x:np.ndarray):
|
||||
assert not self.dtype.name.startswith("image"), f"can't copyin images {self.dtype}"
|
||||
self.event = cl.enqueue_copy(CL.cl_queue[self._buf.device], self._buf, np.require(x, requirements='C'), is_blocking=False)
|
||||
@@ -95,5 +98,4 @@ renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
|
||||
half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable",
|
||||
barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
|
||||
gid = [f'get_group_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True))
|
||||
|
||||
# GPU backend object: a Compiled instance built from the CLBuffer class, default
# LinearizerOptions, the cstyle renderer partial, CLProgram and the CL.synchronize hook
GPUBuffer = Compiled(CLBuffer, LinearizerOptions(), renderer, CLProgram, CL.synchronize)
|
||||
|
||||
Reference in New Issue
Block a user