init allocator for compiled backends (#1467)

* init allocator for compiled backends

* Update ops_webgpu.py

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
nimlgen
2023-08-17 20:33:32 +03:00
committed by GitHub
parent a293c18d34
commit bd111411bf
9 changed files with 343 additions and 39 deletions

View File

@@ -5,7 +5,7 @@ import pyopencl as cl # type: ignore
from typing import Optional, List
from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, fromimport
from tinygrad.ops import Compiled
from tinygrad.runtime.lib import RawBufferCopyInOut, RawBufferTransfer
from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer
from tinygrad.codegen.linearizer import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
@@ -17,6 +17,18 @@ ROCM_LLVM_PATH = pathlib.Path("/opt/rocm/llvm/bin")
if DEBUG >= 5:
early_exec = fromimport("extra.helpers", "enable_early_exec")()
class CLAllocator(LRUAllocator):
  """Allocator that creates the raw OpenCL memory objects used as buffers."""

  def _do_alloc(self, size, dtype, device, **kwargs):
    """Create an OpenCL image or buffer on the context selected by `device`.

    `size` is an element count: for an ImageDType it must equal
    prod(dtype.shape); otherwise the buffer is created with
    size * dtype.itemsize as its size argument. Extra keyword arguments are
    accepted but unused. The returned pyopencl object carries a `device`
    attribute holding the integer device index.
    """
    if not isinstance(dtype, ImageDType):
      mem = cl.Buffer(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, size * dtype.itemsize)
    else:
      # NOTE: the memory is a bit off here due to padding, it's buf.row_pitch * buf.height * 4 * dtype.itemsize
      assert size == prod(dtype.shape), f"image size mismatch {size} != {dtype.shape}"
      # RGBA channels; the channel type is chosen by the dtype's item size (2 -> half, 4 -> float)
      image_format = cl.ImageFormat(
        cl.channel_order.RGBA,
        {2: cl.channel_type.HALF_FLOAT, 4: cl.channel_type.FLOAT}[dtype.itemsize],
      )
      # the image shape is passed with the two dtype.shape axes swapped
      image_shape = (dtype.shape[1], dtype.shape[0])
      mem = cl.Image(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, image_format, shape=image_shape)
    setattr(mem, 'device', int(device))  # device is tracked on the underlying buffer
    return mem
class _CL:
def post_init(self, device=None):
platforms: List[List[cl.Device]] = [y for y in ([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()] + [x.get_devices(device_type=cl.device_type.CPU) for x in cl.get_platforms()]) if len(y)]
@@ -24,23 +36,14 @@ class _CL:
self.cl_ctxs: List[cl.Context] = [cl.Context(devices=[x]) for x in platforms[getenv('CL_PLATFORM', 0)] if x.name not in getenv('CL_EXCLUDE', "").split(",")] if device is None else [cl.Context(devices=[platforms[getenv('CL_PLATFORM', 0)][device]])]
if DEBUG >= 1: print(f"using devices: {[ctx.devices[0].hashable_model_and_version_identifier for ctx in self.cl_ctxs]}")
self.cl_queue: List[cl.CommandQueue] = [cl.CommandQueue(ctx, device=ctx.devices[0], properties=cl.command_queue_properties.PROFILING_ENABLE) for ctx in self.cl_ctxs]
self.cl_allocator = CLAllocator(CL.cl_ctxs[0].devices[0].get_info(cl.device_info.GLOBAL_MEM_SIZE))
def synchronize(self):
  """Call finish() on every command queue in self.cl_queue, in order."""
  for queue in self.cl_queue:
    queue.finish()
# Module-level singleton; its contexts/queues are set up right away unless
# the DELAYED_RUNTIME_INIT environment flag is set.
CL = _CL()
if not getenv("DELAYED_RUNTIME_INIT", False):
  CL.post_init()
class CLBuffer(RawBufferCopyInOut, RawBufferTransfer):
def __init__(self, size, dtype, device='0'):
if isinstance(dtype, ImageDType):
fmt = cl.ImageFormat(cl.channel_order.RGBA, {2: cl.channel_type.HALF_FLOAT, 4: cl.channel_type.FLOAT}[dtype.itemsize])
buf = cl.Image(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, fmt, shape=(dtype.shape[1], dtype.shape[0]))
assert size == prod(dtype.shape), f"image size mismatch {size} != {dtype.shape}"
# NOTE: the memory is a bit off here due to padding, it's buf.row_pitch * buf.height * 4 * dtype.itemsize
else:
buf = cl.Buffer(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, size * dtype.itemsize)
setattr(buf, 'device', int(device)) # device is tracked on the underlying buffer
super().__init__(size, dtype, buf)
def __init__(self, size, dtype, device='0'):
  """Initialise through the base class using CL.cl_allocator, forwarding `device` as a keyword argument."""
  super().__init__(size, dtype, allocator=CL.cl_allocator, device=device)
def _copyin(self, x:np.ndarray):
assert not self.dtype.name.startswith("image"), f"can't copyin images {self.dtype}"
self.event = cl.enqueue_copy(CL.cl_queue[self._buf.device], self._buf, np.require(x, requirements='C'), is_blocking=False)
@@ -95,5 +98,4 @@ renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable",
barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
gid = [f'get_group_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True))
# Backend object for this runtime: built from CLBuffer, default LinearizerOptions,
# the renderer defined above, CLProgram, and CL.synchronize.
GPUBuffer = Compiled(CLBuffer, LinearizerOptions(), renderer, CLProgram, CL.synchronize)