mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
rename allocator methods to not conflict [pr] (#7788)
* rename allocator methods to not conflict [pr]
* forgot those
* transfer + offset
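This commit gives every backend-specific allocator hook an underscore prefix (_copyin, _copyout, _as_buffer, _offset, _transfer), so the names no longer conflict with the public Buffer.copyin / Buffer.copyout / Buffer.as_buffer wrappers that call them. A minimal sketch of a backend allocator under the new convention, modeled on the PythonAllocator change further down; ToyAllocator is hypothetical, and it assumes the public alloc() wrapper forwards to _alloc the same way free() forwards to _free:

from tinygrad.device import Allocator

class ToyAllocator(Allocator):
  # backend hooks are now underscore-prefixed; alloc()/free() on the base class keep their public names
  def _alloc(self, size, options): return memoryview(bytearray(size))   # opaque "device" buffer is just a memoryview
  def _copyin(self, dest, src:memoryview): dest[:] = src                # host -> device
  def _copyout(self, dest:memoryview, src): dest[:] = src               # device -> host

toy = ToyAllocator()
buf = toy.alloc(4)                                    # public entry point, name unchanged
toy._copyin(buf, memoryview(bytearray([2,0,0,0])))    # renamed hook
out = memoryview(bytearray(4))
toy._copyout(out, buf)
assert bytes(out) == b"\x02\x00\x00\x00"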
@@ -15,8 +15,8 @@ a = MallocAllocator.alloc(4)
b = MallocAllocator.alloc(4)

# load in some values (little endian)
-MallocAllocator.copyin(a, memoryview(bytearray([2,0,0,0])))
-MallocAllocator.copyin(b, memoryview(bytearray([3,0,0,0])))
+MallocAllocator._copyin(a, memoryview(bytearray([2,0,0,0])))
+MallocAllocator._copyin(b, memoryview(bytearray([3,0,0,0])))

# compile a program to a binary
lib = ClangCompiler().compile("void add(int *out, int *a, int *b) { out[0] = a[0] + b[0]; }")
@@ -28,7 +28,7 @@ fxn = ClangProgram("add", lib)
fxn(out, a, b)

# check the data out
-print(val := MallocAllocator.as_buffer(out).cast("I").tolist()[0])
+print(val := MallocAllocator._as_buffer(out).cast("I").tolist()[0])
assert val == 5
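Pieced together from the two documentation hunks above, a runnable version of the low-level example under the new names; the tinygrad.runtime.ops_clang import path is an assumption about where ClangCompiler and ClangProgram lived at the time of this commit:

from tinygrad.device import MallocAllocator
from tinygrad.runtime.ops_clang import ClangCompiler, ClangProgram  # assumed import path

# allocate some buffers
out = MallocAllocator.alloc(4)
a = MallocAllocator.alloc(4)
b = MallocAllocator.alloc(4)

# load in some values (little endian)
MallocAllocator._copyin(a, memoryview(bytearray([2,0,0,0])))
MallocAllocator._copyin(b, memoryview(bytearray([3,0,0,0])))

# compile a program to a binary and wrap it in a runtime
lib = ClangCompiler().compile("void add(int *out, int *a, int *b) { out[0] = a[0] + b[0]; }")
fxn = ClangProgram("add", lib)
fxn(out, a, b)

# check the data out
print(val := MallocAllocator._as_buffer(out).cast("I").tolist()[0])
assert val == 5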
@@ -126,13 +126,13 @@ class HIPAllocator(LRUAllocator):
copied_in += copy_size
self.hb_polarity = (self.hb_polarity+1) % len(self.hb)
minor_offset = 0 # only on the first
-def copyin(self, dest:T, src: memoryview):
+def _copyin(self, dest:T, src: memoryview):
hip_set_device(self.device.device)
host_mem = self._alloc_with_options(len(src), BufferOptions(host=True))
self.device.pending_copyin.append(host_mem)
ctypes.memmove(host_mem, from_mv(src), len(src))
check(hip.hipMemcpyAsync(dest, host_mem, len(src), hip.hipMemcpyHostToDevice, None))
-def copyout(self, dest:memoryview, src:T):
+def _copyout(self, dest:memoryview, src:T):
self.full_synchronize()
hip_set_device(self.device.device)
check(hip.hipMemcpy(from_mv(dest), src, len(dest), hip.hipMemcpyDeviceToHost))

@@ -116,7 +116,7 @@ class HSAAllocator(LRUAllocator):
HSADevice.synchronize_system()
check(hsa.hsa_amd_memory_pool_free(opaque))

-def copyin(self, dest:T, src: memoryview):
+def _copyin(self, dest:T, src: memoryview):
# Async copyin sync model uses barriers on the main hw queue, since barriers are guaranteed to execute in order with all other packets.
self.device.hw_queue.submit_barrier([], sync_signal := self.device.alloc_signal(reusable=True))
mem = self._alloc(src.nbytes, BufferOptions(host=True))

@@ -164,7 +164,7 @@ class HSAAllocator(LRUAllocator):
if copies_called > 1: wait_signals.append(self.hb_signals[self.hb_polarity])
self.device.hw_queue.submit_barrier(wait_signals)

-def copyout(self, dest:memoryview, src:T):
+def _copyout(self, dest:memoryview, src:T):
HSADevice.synchronize_system()
copy_signal = self.device.alloc_signal(reusable=True)
c_agents = (hsa.hsa_agent_t*2)(self.device.agent, HSADevice.cpu_agent)
@@ -39,8 +39,8 @@ class RawWebGLAllocator(Allocator):
tex = ctx.texture(dtype.shape, 1, dtype=dtype_map[dtype.base])
tex.filter = (moderngl.NEAREST, moderngl.NEAREST)
return tex
-def copyin(self, dest:moderngl.Texture, src: memoryview): dest.write(src)
-def copyout(self, dest:memoryview, src: moderngl.Texture):
+def _copyin(self, dest:moderngl.Texture, src: memoryview): dest.write(src)
+def _copyout(self, dest:memoryview, src: moderngl.Texture):
src.read_into(dest)
return dest

@@ -31,8 +31,8 @@ class WebGPUProgram:
class WebGpuAllocator(Allocator):
def _alloc(self, size: int):
return wgpu_device.create_buffer(size=size, usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST | wgpu.BufferUsage.COPY_SRC)
-def copyin(self, dest, src: memoryview): wgpu_device.queue.write_buffer(dest, 0, src)
-def copyout(self, dest, src: memoryview): dest[:] = wgpu_device.queue.read_buffer(src, 0) # TODO: remove this copy
+def _copyin(self, dest, src: memoryview): wgpu_device.queue.write_buffer(dest, 0, src)
+def _copyout(self, dest, src: memoryview): dest[:] = wgpu_device.queue.read_buffer(src, 0) # TODO: remove this copy

class WebGpuDevice(Compiled):
def __init__(self, device:str):
@@ -48,8 +48,8 @@ a = MallocAllocator.alloc(na.size * np.dtype(np.float32).itemsize)
b = MallocAllocator.alloc(nb.size * np.dtype(np.float32).itemsize)
c = MallocAllocator.alloc(nc.size * np.dtype(np.float32).itemsize)

-MallocAllocator.copyin(b, flat_mv(nb.data))
-MallocAllocator.copyin(c, flat_mv(nc.data))
+MallocAllocator._copyin(b, flat_mv(nb.data))
+MallocAllocator._copyin(c, flat_mv(nc.data))

module = ir.Module(name=__file__)
func = ir.Function(module, ir.FunctionType(ir.IntType(64), [ir.FloatType().as_pointer()]*3), name='exec')

@@ -171,7 +171,7 @@ def timeit(fxn):
return time.perf_counter() - st

tm = min([timeit(lambda: prog(a, b, c, N**2)) for _ in range(20)])
-MallocAllocator.copyout(flat_mv(na.data), a)
+MallocAllocator._copyout(flat_mv(na.data), a)
print(f"{N*N:10d} {tm*1e6:9.2f} us, {BW*1e-9/tm:.2f} GB/s")

np.testing.assert_allclose(na[:ns.shape[0]], ns, atol=1e-4, rtol=1e-4)
@@ -23,8 +23,8 @@ a = cudaalloc.alloc(N*N*2 if FLOAT16 else N*N*4)
b = cudaalloc.alloc(N*N*2 if FLOAT16 else N*N*4)
c = cudaalloc.alloc(N*N*4)

-cudaalloc.copyin(a, bytearray(na))
-cudaalloc.copyin(b, bytearray(nb))
+cudaalloc._copyin(a, bytearray(na))
+cudaalloc._copyin(b, bytearray(nb))

FLOPS = N*N*N*2
BW = N*N*3*4

@@ -103,5 +103,5 @@ extern "C" __global__ void wmma_example({'half' if FLOAT16 else 'float'} *a, {'h
global_size, local_size = [(N//16)//4, (N//16)//4, 1], [32, 1, 1]
tm = min([prog(a, b, c, global_size=global_size, local_size=local_size, wait=True) for _ in range(20)])
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s")
-cudaalloc.copyout(flat_mv(nc.data), c)
+cudaalloc._copyout(flat_mv(nc.data), c)
np.testing.assert_allclose(na.T.astype(np.float32) @ nb.T.astype(np.float32), nc.reshape(N,N).T, atol=1e-2)
@@ -40,8 +40,8 @@ c = hipallocator.alloc(N*N*2)
na = np.empty(N*N, np.float32)
nb = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32).astype(np.float16)
nc = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32).astype(np.float16)
-hipallocator.copyin(b, memoryview(bytearray(nb)))
-hipallocator.copyin(c, memoryview(bytearray(nc)))
+hipallocator._copyin(b, memoryview(bytearray(nb)))
+hipallocator._copyin(c, memoryview(bytearray(nc)))

prog_str = f"""
#define F32

@@ -126,13 +126,13 @@ def timeit(fxn):
if RAND:
nb = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32).astype(np.float16)
nc = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32).astype(np.float16)
-hipallocator.copyin(b, memoryview(bytearray(nb)))
-hipallocator.copyin(c, memoryview(bytearray(nc)))
+hipallocator._copyin(b, memoryview(bytearray(nb)))
+hipallocator._copyin(c, memoryview(bytearray(nc)))
return et

print("global/local size", global_size, local_size, f"local_size:{prod(local_size)} total_size:{prod(global_size+local_size)}")
tm = min([timeit(lambda: prog(a, b, c, global_size=global_size, local_size=local_size, wait=True)) for _ in range(CNT)])
-hipallocator.copyout(flat_mv(na.data),a)
+hipallocator._copyout(flat_mv(na.data),a)
na = na.reshape(N,N)
comp = nb.astype(np.float32) @ nc.astype(np.float32)
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s")
@@ -20,8 +20,8 @@ na = np.zeros((N,N),dtype=np.float32)
nb = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32) #.astype(np.int32).astype(np.float32)N
nc = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32) #.astype(np.int32).astype(np.float32)

-metalalloc.copyin(b,nb.tobytes())
-metalalloc.copyin(c,nc.tobytes())
+metalalloc._copyin(b,nb.tobytes())
+metalalloc._copyin(c,nc.tobytes())

FLOPS = N*N*N*2
BW = N*N*3*4

@@ -96,7 +96,7 @@ def timeit(fxn):
return time.perf_counter() - st
tm = min([timeit(lambda: prog(a, b, c, global_size=[N//(8*4), N//(8*4*LID), 1], local_size=[32, LID, 1], wait=True)) for _ in range(20)])
comp = nb@nc
-metalalloc.copyout(flat_mv(na.data), a)
+metalalloc._copyout(flat_mv(na.data), a)
if N <= 32:
print(na)
print(comp)
@@ -80,8 +80,8 @@ kernel void test(device float* data0, const device float* data1, const device fl
a = metalalloc.alloc(M*4)
b = metalalloc.alloc(N*4)
c = metalalloc.alloc(N*M*4)
-metalalloc.copyin(b,nb.tobytes())
-metalalloc.copyin(c,nc.tobytes())
+metalalloc._copyin(b,nb.tobytes())
+metalalloc._copyin(c,nc.tobytes())
def metalrun():
prog(a, b, c, global_size=GLOBAL_SIZE, local_size=LOCAL_SIZE, wait=True)
return a

@@ -93,7 +93,7 @@ def timeit(fxn):
tm = min([timeit(metalrun) for _ in range(200)])
print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in metal")
metal_a = np.zeros(M, dtype=np.float32)
-metalalloc.copyout(flat_mv(metal_a.data), a)
+metalalloc._copyout(flat_mv(metal_a.data), a)
np.testing.assert_allclose(metal_a, torch_a, atol=5e-3)

b = Tensor(nb)
test/external/external_test_speed_llama.py (vendored): 2 changed lines
@@ -14,7 +14,7 @@ class FakeProgram:

class FakeAllocator(Allocator):
def _alloc(self, sz, options): return None
-def copyin(self, dest, src:memoryview): pass
+def _copyin(self, dest, src:memoryview): pass

class TestLLaMASpeed(unittest.TestCase):
def test_llama_compile(self):
@@ -85,8 +85,8 @@ class Buffer:
self.options = replace(self.options, external_ptr=external_ptr) if self.options else BufferOptions(external_ptr=external_ptr)
if self._base is not None:
self._base.ensure_allocated()
-assert hasattr(self.allocator, "offset"), "offset function required for view"
-self._buf: Any = self.allocator.offset(self.base._buf, self.nbytes, self.offset)
+assert hasattr(self.allocator, "_offset"), "offset function required for view"
+self._buf: Any = self.allocator._offset(self.base._buf, self.nbytes, self.offset)
else:
self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, self.options)
if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes

@@ -112,21 +112,21 @@ class Buffer:
(f" offset:{self.offset}" if hasattr(self, "base") else "") + (f" {self.options=}" if self.options is not None else "") + ">"
def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
# zero copy with as_buffer (disabled by default due to use after free)
-if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer') and (self.options is None or self.options.image is None):
-return self.allocator.as_buffer(self._buf)
+if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, '_as_buffer') and (self.options is None or self.options.image is None):
+return self.allocator._as_buffer(self._buf)
assert not force_zero_copy, "force zero copy was passed, but copy is required"
return self.copyout(memoryview(bytearray(self.nbytes)))
def copyin(self, mv:memoryview):
mv = flat_mv(mv)
assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
assert self.is_allocated(), "can't copyin to unallocated buffer"
-self.allocator.copyin(self._buf, mv)
+self.allocator._copyin(self._buf, mv)
return self
def copyout(self, mv:memoryview) -> memoryview:
mv = flat_mv(mv)
assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
assert self.is_allocated(), "can't copyout unallocated buffer"
-self.allocator.copyout(mv, self._buf)
+self.allocator._copyout(mv, self._buf)
return mv
def view(self, size:int, dtype:DType, offset:int) -> Buffer:
assert offset < self.nbytes, "offset must be less than nbytes"
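Note that the user-facing Buffer methods keep their old names; only the allocator hooks they delegate to were renamed. A hedged usage sketch of that split; the Buffer constructor arguments, the dtypes import, and the "CLANG" device name are assumptions about the tree at the time of this commit:

from tinygrad import dtypes
from tinygrad.device import Buffer

b = Buffer("CLANG", 4, dtypes.uint8).allocate()   # device name assumed
b.copyin(memoryview(bytearray([1,2,3,4])))        # Buffer.copyin -> allocator._copyin
print(b.as_buffer().tolist())                     # Buffer.as_buffer -> allocator._as_buffer, or a _copyout fallback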
@@ -141,8 +141,11 @@ class Allocator:
def _alloc(self, size:int, options:BufferOptions): raise NotImplementedError("need alloc")
def free(self, opaque, size:int, options:Optional[BufferOptions]=None): self._free(opaque, options if options is not None else BufferOptions())
def _free(self, opaque, options:BufferOptions): pass # if opaque is a Python object, you don't need a free
-def copyin(self, dest, src:memoryview): raise NotImplementedError("need copyin")
-def copyout(self, dest:memoryview, src): raise NotImplementedError("need copyout")
+def _copyin(self, dest, src:memoryview): raise NotImplementedError("need copyin")
+def _copyout(self, dest:memoryview, src): raise NotImplementedError("need copyout")
+# def _as_buffer(self, src) -> memoryview:
+# def _offset(self, buf, size:int, offset:int):
+# def _transfer(self, dest, src, sz:int, src_dev, dest_dev):

class LRUAllocator(Allocator): # pylint: disable=abstract-method
"""

@@ -167,10 +170,10 @@ class LRUAllocator(Allocator): # pylint: disable=abstract-method
class _MallocAllocator(LRUAllocator):
def _alloc(self, size:int, options:BufferOptions):
return (ctypes.c_uint8 * size).from_address(options.external_ptr) if options.external_ptr else (ctypes.c_uint8 * size)()
-def as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
-def copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
-def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
-def offset(self, buf, size:int, offset:int): return from_mv(self.as_buffer(buf)[offset:offset+size])
+def _as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
+def _copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
+def _copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
+def _offset(self, buf, size:int, offset:int): return from_mv(self._as_buffer(buf)[offset:offset+size])

MallocAllocator = _MallocAllocator()
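The commented stubs added to Allocator document the optional hooks: _as_buffer, _offset, and _transfer are only implemented where a backend supports zero-copy views or device-to-device copies, and callers feature-detect them with hasattr, as Buffer.view does for _offset above and lower_schedule_item does for _transfer below. A self-contained sketch of the optional _offset hook, loosely following the slice-based _MallocAllocator version; ToyViewAllocator is hypothetical:

from tinygrad.device import Allocator

class ToyViewAllocator(Allocator):
  def _alloc(self, size, options): return memoryview(bytearray(size))
  def _copyin(self, dest, src:memoryview): dest[:] = src
  def _copyout(self, dest:memoryview, src): dest[:] = src
  # optional hook: a zero-copy view of `size` bytes starting at `offset` into an existing opaque buffer
  def _offset(self, buf, size:int, offset:int): return buf[offset:offset+size]  # memoryview slice shares storage

alloc = ToyViewAllocator()
whole = alloc.alloc(8)
tail = alloc._offset(whole, 4, 4)                     # view over the last 4 bytes
alloc._copyin(tail, memoryview(bytearray([9,9,9,9])))
assert bytes(whole)[4:] == b"\x09\x09\x09\x09"        # writes through the view land in the parent buffer
assert hasattr(alloc, "_offset")                      # the check Buffer.view relies on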
@@ -121,9 +121,9 @@ class BufferCopy(Runner):
getattr(src.allocator.device, 'fd', None) is not None
if src.device.startswith("DISK") and hasattr(dest.allocator, 'copy_from_disk') and disk_supports_fast_copyout and src.nbytes >= 4096:
dest.allocator.copy_from_disk(dest._buf, src._buf, src.nbytes)
-elif src.device.startswith("DISK") and hasattr(dest.allocator, 'as_buffer'):
+elif src.device.startswith("DISK") and hasattr(dest.allocator, '_as_buffer'):
# fast(ish) path, uses readinto in diskbuffers
-src.allocator.copyout(dest.allocator.as_buffer(dest._buf), src._buf)
+src.allocator._copyout(dest.allocator._as_buffer(dest._buf), src._buf)
else:
dest.copyin(src.as_buffer(allow_zero_copy=True)) # may allocate a CPU buffer depending on allow_zero_copy
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False):

@@ -136,7 +136,7 @@ class BufferCopy(Runner):
return time.perf_counter() - st

class BufferXfer(BufferCopy):
-def copy(self, dest, src): dest.allocator.transfer(dest._buf, src._buf, dest.nbytes, src_dev=src.allocator.device, dest_dev=dest.allocator.device)
+def copy(self, dest, src): dest.allocator._transfer(dest._buf, src._buf, dest.nbytes, src_dev=src.allocator.device, dest_dev=dest.allocator.device)

# **************** method cache ****************

@@ -189,7 +189,7 @@ def lower_schedule_item(si:ScheduleItem) -> ExecItem:
out, arg = si.outputs[0], si.ast.arg
if si.ast.op is Ops.COPY:
kernel_type = BufferCopy
-if hasattr(Device[out.device].allocator, 'transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]:
+if hasattr(Device[out.device].allocator, '_transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]:
kernel_type = BufferXfer
return ExecItem(kernel_type(arg, out.device, si.inputs[0].device), list(si.bufs))
if si.ast.op is Ops.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs))
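As the hunk above shows, defining _transfer is what opts a device into BufferXfer for same-family device-to-device copies; otherwise the scheduler falls back to a staged BufferCopy. A toy, hedged version of the hook in the same memoryview-backed style as the earlier sketches (real backends issue asynchronous copies and synchronize, as the CUDA and HCQ hunks below show); ToyXferAllocator is hypothetical:

from tinygrad.device import Allocator

class ToyXferAllocator(Allocator):
  def _alloc(self, size, options): return memoryview(bytearray(size))
  def _copyin(self, dest, src:memoryview): dest[:] = src
  def _copyout(self, dest:memoryview, src): dest[:] = src
  # optional hook: copy sz bytes between two device buffers without bouncing through the host
  def _transfer(self, dest, src, sz:int, src_dev=None, dest_dev=None): dest[:sz] = src[:sz]

alloc = ToyXferAllocator()
a, b = alloc.alloc(4), alloc.alloc(4)
alloc._copyin(a, memoryview(bytearray([1,2,3,4])))
alloc._transfer(b, a, 4)
assert bytes(b) == b"\x01\x02\x03\x04"
assert hasattr(alloc, "_transfer")                  # the attribute the dispatch rule above keys off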
@@ -55,7 +55,7 @@ class MetalGraph(GraphRunner):
self.all_resources = dedup(all_resources)
self.all_pipelines = dedup(all_pipelines)
self.command_buffer: Any = None
-if len(self.vars): self.int_buf_view = self.device.allocator.as_buffer(self.int_buf).cast('i')
+if len(self.vars): self.int_buf_view = self.device.allocator._as_buffer(self.int_buf).cast('i')
self.range = to_struct(0, len(self.jit_cache))

def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
@@ -104,10 +104,10 @@ class CloudHandler(BaseHTTPRequestHandler):
buf,sz,buffer_options = session.buffers[c.buffer_num]
Device[CloudHandler.dname].allocator.free(buf,sz,buffer_options)
del session.buffers[c.buffer_num]
-case CopyIn(): Device[CloudHandler.dname].allocator.copyin(session.buffers[c.buffer_num][0], memoryview(bytearray(req._h[c.datahash])))
+case CopyIn(): Device[CloudHandler.dname].allocator._copyin(session.buffers[c.buffer_num][0], memoryview(bytearray(req._h[c.datahash])))
case CopyOut():
buf,sz,_ = session.buffers[c.buffer_num]
-Device[CloudHandler.dname].allocator.copyout(memoryview(ret:=bytearray(sz)), buf)
+Device[CloudHandler.dname].allocator._copyout(memoryview(ret:=bytearray(sz)), buf)
case ProgramAlloc():
lib = Device[CloudHandler.dname].compiler.compile_cached(req._h[c.datahash].decode())
session.programs[(c.name, c.datahash)] = Device[CloudHandler.dname].runtime(c.name, lib)

@@ -149,8 +149,8 @@ class CloudAllocator(Allocator):
return self.device.buffer_num
# TODO: options should not be here in any Allocator
def _free(self, opaque:int, options): self.device.req.q(BufferFree(opaque))
-def copyin(self, dest:int, src:memoryview): self.device.req.q(CopyIn(dest, self.device.req.h(bytes(src))))
-def copyout(self, dest:memoryview, src:int):
+def _copyin(self, dest:int, src:memoryview): self.device.req.q(CopyIn(dest, self.device.req.h(bytes(src))))
+def _copyout(self, dest:memoryview, src:int):
self.device.req.q(CopyOut(src))
resp = self.device.batch_submit()
assert len(resp) == len(dest), f"buffer length mismatch {len(resp)} != {len(dest)}"
@@ -70,24 +70,24 @@ class CUDAAllocator(LRUAllocator):
def _free(self, opaque, options:BufferOptions):
if options.host: check(cuda.cuMemFreeHost(opaque))
else: check(cuda.cuMemFree_v2(opaque))
-def copyin(self, dest, src:memoryview):
+def _copyin(self, dest, src:memoryview):
check(cuda.cuCtxSetCurrent(self.device.context))
host_mem = self.alloc(len(src), BufferOptions(host=True))
self.device.pending_copyin.append((host_mem, len(src), BufferOptions(host=True)))
ctypes.memmove(host_mem, from_mv(src), len(src))
check(cuda.cuMemcpyHtoDAsync_v2(dest, host_mem, len(src), None))
-def copyout(self, dest:memoryview, src):
+def _copyout(self, dest:memoryview, src):
CUDADevice.synchronize_system()
check(cuda.cuCtxSetCurrent(self.device.context))
check(cuda.cuMemcpyDtoH_v2(from_mv(dest), src, len(dest)))
-def transfer(self, dest, src, sz:int, src_dev, dest_dev):
+def _transfer(self, dest, src, sz:int, src_dev, dest_dev):
check(cuda.cuCtxSetCurrent(src_dev.context))
check(cuda.cuEventCreate(ctypes.byref(sync_event := cuda.CUevent()), 0))
check(cuda.cuMemcpyDtoDAsync_v2(dest, src, sz, None))
check(cuda.cuEventRecord(sync_event, None))
check(cuda.cuCtxSetCurrent(dest_dev.context))
check(cuda.cuStreamWaitEvent(None, sync_event, 0)) # sync the default stream on the dest dev
-def offset(self, buf, size:int, offset:int): return cuda.CUdeviceptr_v2(buf.value + offset)
+def _offset(self, buf, size:int, offset:int): return cuda.CUdeviceptr_v2(buf.value + offset)

class CUDADevice(Compiled):
devices: List[CUDADevice] = []
@@ -22,9 +22,9 @@ class DiskAllocator(Allocator):
self.device._might_open(size)
return DiskBuffer(self.device, size)
def _free(self, opaque, options): self.device._might_close()
-def as_buffer(self, src:DiskBuffer): return src._buf()
-def copyin(self, dest:DiskBuffer, src:memoryview): dest._buf()[:] = src
-def copyout(self, dest:memoryview, src:DiskBuffer):
+def _as_buffer(self, src:DiskBuffer): return src._buf()
+def _copyin(self, dest:DiskBuffer, src:memoryview): dest._buf()[:] = src
+def _copyout(self, dest:memoryview, src:DiskBuffer):
if OSX and self.device.fd is not None:
# OSX doesn't seem great at mmap, this is faster
with io.FileIO(self.device.fd, "a+b", closefd=False) as fo:

@@ -65,7 +65,7 @@ class DiskAllocator(Allocator):
DiskDevice.io_uring.cq.khead[0] = head + 1 # advance
processed_reqs_cnt += 1

-def offset(self, buf:DiskBuffer, size:int, offset:int): return DiskBuffer(buf.device, size, offset)
+def _offset(self, buf:DiskBuffer, size:int, offset:int): return DiskBuffer(buf.device, size, offset)

class DiskDevice(Compiled):
_tried_io_uring_init = False
@@ -54,10 +54,10 @@ class DSPAllocator(Allocator):
os.close(opaque.share_info.fd)
qcom_dsp.ION_IOC_FREE(self.device.ion_fd, handle=opaque.share_info.handle)

-def as_buffer(self, src:DSPBuffer) -> memoryview: return to_mv(src.va_addr, src.size)
-def copyin(self, dest:DSPBuffer, src:memoryview): ctypes.memmove(dest.va_addr, from_mv(src), src.nbytes)
-def copyout(self, dest:memoryview, src:DSPBuffer): ctypes.memmove(from_mv(dest), src.va_addr, dest.nbytes)
-def offset(self, buf, size:int, offset:int): return DSPBuffer(buf.va_addr+offset, size, buf.share_info, buf.offset+offset)
+def _as_buffer(self, src:DSPBuffer) -> memoryview: return to_mv(src.va_addr, src.size)
+def _copyin(self, dest:DSPBuffer, src:memoryview): ctypes.memmove(dest.va_addr, from_mv(src), src.nbytes)
+def _copyout(self, dest:memoryview, src:DSPBuffer): ctypes.memmove(from_mv(dest), src.va_addr, dest.nbytes)
+def _offset(self, buf, size:int, offset:int): return DSPBuffer(buf.va_addr+offset, size, buf.share_info, buf.offset+offset)

class DSPDevice(Compiled):
def __init__(self, device:str=""):
@@ -69,7 +69,7 @@ class CLAllocator(LRUAllocator):
options.image.shape[1], options.image.shape[0], 0, None, status := ctypes.c_int32()), status), options)
return (checked(cl.clCreateBuffer(self.device.context, cl.CL_MEM_READ_WRITE, size, None, status := ctypes.c_int32()), status), options)
def _free(self, opaque:Tuple[ctypes._CData, BufferOptions], options:BufferOptions): check(cl.clReleaseMemObject(opaque[0]))
-def copyin(self, dest:Tuple[ctypes._CData, BufferOptions], src:memoryview):
+def _copyin(self, dest:Tuple[ctypes._CData, BufferOptions], src:memoryview):
if dest[1].image is not None:
check(cl.clEnqueueWriteImage(self.device.queue, dest[0], False, (ctypes.c_size_t * 3)(0,0,0),
(ctypes.c_size_t * 3)(dest[1].image.shape[1],dest[1].image.shape[0],1), 0, 0, from_mv(src), 0, None, None))

@@ -77,7 +77,7 @@ class CLAllocator(LRUAllocator):
if mv_address(src) % 16: src = memoryview(bytearray(src))
check(cl.clEnqueueWriteBuffer(self.device.queue, dest[0], False, 0, len(src)*src.itemsize, from_mv(src), 0, None, None))
self.device.pending_copyin.append(src) # NOTE: these can't be freed until the GPU actually executes this command
-def copyout(self, dest:memoryview, src:Tuple[ctypes._CData, BufferOptions]):
+def _copyout(self, dest:memoryview, src:Tuple[ctypes._CData, BufferOptions]):
if src[1].image is not None:
check(cl.clEnqueueReadImage(self.device.queue, src[0], False, (ctypes.c_size_t * 3)(0,0,0),
(ctypes.c_size_t * 3)(src[1].image.shape[1],src[1].image.shape[0],1), 0, 0, from_mv(dest), 0, None, None))
@@ -50,10 +50,10 @@ class HIPAllocator(LRUAllocator):
check(hip.hipSetDevice(self.device.device_id))
return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipMalloc(ctypes.byref(x), size)))
def _free(self, opaque, options:BufferOptions): check(hip.hipFree(opaque))
-def copyin(self, dest, src: memoryview):
+def _copyin(self, dest, src: memoryview):
check(hip.hipSetDevice(self.device.device_id))
check(hip.hipMemcpy(dest, from_mv(src), len(src), hip.hipMemcpyHostToDevice))
-def copyout(self, dest:memoryview, src):
+def _copyout(self, dest:memoryview, src):
self.device.synchronize()
check(hip.hipMemcpy(from_mv(dest), src, len(dest), hip.hipMemcpyDeviceToHost))
@@ -139,7 +139,7 @@ class MetalAllocator(LRUAllocator):
if ret.value is None: raise MemoryError(f"Metal OOM while allocating {size=}")
return MetalBuffer(ret, size)
def _free(self, opaque:MetalBuffer, options): msg(opaque.buf, "release")
-def transfer(self, dest:MetalBuffer, src:MetalBuffer, sz:int, src_dev:MetalDevice, dest_dev:MetalDevice):
+def _transfer(self, dest:MetalBuffer, src:MetalBuffer, sz:int, src_dev:MetalDevice, dest_dev:MetalDevice):
dest_dev.synchronize()
src_command_buffer = msg(src_dev.mtl_queue, "commandBuffer", restype=objc_instance)
encoder = msg(src_command_buffer, "blitCommandEncoder", restype=objc_instance)

@@ -155,19 +155,20 @@ class MetalAllocator(LRUAllocator):
src_dev.timeline_value += 1
msg(src_command_buffer, "commit")
src_dev.mtl_buffers_in_flight.append(src_command_buffer)
+# NOTE: this is unused
def from_buffer(self, src:memoryview) -> Optional[Any]:
ptr = (ctypes.c_char * src.nbytes).from_buffer(src)
ret = msg(self.device.device, "newBufferWithBytesNoCopy:length:options:deallocator:", ptr, src.nbytes, 0, None, restype=objc_instance)
if ret: self.device.mv_in_metal.append(src)
return MetalBuffer(ret, src.nbytes)
-def as_buffer(self, src:MetalBuffer) -> memoryview:
+def _as_buffer(self, src:MetalBuffer) -> memoryview:
self.device.synchronize()
ptr = msg(src.buf, "contents", restype=objc_id) # Shared memory, do not release here
array = (ctypes.c_char * (src.offset + src.size)).from_address(ptr.value)
return memoryview(array).cast("B")[src.offset:]
-def copyin(self, dest:MetalBuffer, src:memoryview): self.as_buffer(dest)[:] = src
-def copyout(self, dest:memoryview, src:MetalBuffer): dest[:] = self.as_buffer(src)
-def offset(self, buf:MetalBuffer, size:int, offset:int): return MetalBuffer(buf.buf, size, offset)
+def _copyin(self, dest:MetalBuffer, src:memoryview): self._as_buffer(dest)[:] = src
+def _copyout(self, dest:memoryview, src:MetalBuffer): dest[:] = self._as_buffer(src)
+def _offset(self, buf:MetalBuffer, size:int, offset:int): return MetalBuffer(buf.buf, size, offset)

class MetalDevice(Compiled):
def __init__(self, device:str):
@@ -3,7 +3,7 @@ from tinygrad.helpers import flat_mv
from tinygrad.device import Compiled, Allocator

class NpyAllocator(Allocator): # pylint: disable=abstract-method
-def copyout(self, dest:memoryview, src:np.ndarray): dest[:] = flat_mv(np.require(src, requirements='C').data)
+def _copyout(self, dest:memoryview, src:np.ndarray): dest[:] = flat_mv(np.require(src, requirements='C').data)

class NpyDevice(Compiled):
def __init__(self, device:str): super().__init__(device, NpyAllocator(), None, None, None)
@@ -199,8 +199,8 @@ class PythonCompiler(Compiler):

class PythonAllocator(Allocator):
def _alloc(self, size, options): return memoryview(bytearray(size))
-def copyin(self, dest, src:memoryview): dest[:] = src
-def copyout(self, dest:memoryview, src): dest[:] = src
+def _copyin(self, dest, src:memoryview): dest[:] = src
+def _copyout(self, dest:memoryview, src): dest[:] = src

class PythonDevice(Compiled):
def __init__(self, device:str):
@@ -323,16 +323,16 @@ class QCOMAllocator(HCQAllocator):
ctypes.memmove(dest_addr+dest_off, src_addr+src_off, real_size)
src_off, dest_off = src_off+src_stride, dest_off+dest_stride

-def copyin(self, dest:HCQBuffer, src:memoryview):
+def _copyin(self, dest:HCQBuffer, src:memoryview):
if (qd:=cast(QCOMBuffer, dest)).pitch is not None: self._do_copy(mv_address(src), qd.va_addr, len(src), qd.real_stride, qd.real_stride, qd.pitch)
else: ctypes.memmove(dest.va_addr, mv_address(src), src.nbytes)

-def copyout(self, dest:memoryview, src:HCQBuffer):
+def _copyout(self, dest:memoryview, src:HCQBuffer):
self.device.synchronize()
if (qs:=cast(QCOMBuffer, src)).pitch is not None: self._do_copy(qs.va_addr, mv_address(dest), qs.size, qs.real_stride, qs.pitch, qs.real_stride)
else: ctypes.memmove(from_mv(dest), src.va_addr, dest.nbytes)

-def as_buffer(self, src:HCQBuffer) -> memoryview:
+def _as_buffer(self, src:HCQBuffer) -> memoryview:
self.device.synchronize()
return to_mv(src.va_addr, src.size)
@@ -475,7 +475,7 @@ class HCQAllocator(LRUAllocator): # pylint: disable=abstract-method

def _alloc(self, size:int, options:BufferOptions) -> HCQBuffer: raise NotImplementedError("need hcq compat alloc")

-def copyin(self, dest:HCQBuffer, src:memoryview):
+def _copyin(self, dest:HCQBuffer, src:memoryview):
with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"CPU -> {self.device.dname}", enabled=PROFILE):
for i in range(0, src.nbytes, self.b[0].size):
self.b_next = (self.b_next + 1) % len(self.b)

@@ -503,7 +503,7 @@ class HCQAllocator(LRUAllocator): # pylint: disable=abstract-method
self.b_timeline[batch_info[1]] = self.device.timeline_value
self.device.timeline_value += 1

-def copyout(self, dest:memoryview, src:HCQBuffer):
+def _copyout(self, dest:memoryview, src:HCQBuffer):
self.device.synchronize()

with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"{self.device.dname} -> CPU", enabled=PROFILE):

@@ -516,7 +516,7 @@ class HCQAllocator(LRUAllocator): # pylint: disable=abstract-method

ctypes.memmove(from_mv(dest[i:]), self.b[0].va_addr, lsize)

-def transfer(self, dest:HCQBuffer, src:HCQBuffer, sz:int, src_dev, dest_dev):
+def _transfer(self, dest:HCQBuffer, src:HCQBuffer, sz:int, src_dev, dest_dev):
src_dev.allocator.map(dest)

with hcq_profile(src_dev, queue_type=src_dev.hw_copy_queue_t, desc=f"{src_dev.dname} -> {dest_dev.dname}", enabled=PROFILE):

@@ -534,6 +534,6 @@ class HCQAllocator(LRUAllocator): # pylint: disable=abstract-method

def map(self, buf:HCQBuffer): pass

-def offset(self, buf, size:int, offset:int) -> HCQBuffer:
+def _offset(self, buf, size:int, offset:int) -> HCQBuffer:
return type(buf)(va_addr=buf.va_addr + offset, size=size, **{k:v for k,v in buf.__dict__.items() if k not in ['va_addr', 'size']},
**{x[0]:getattr(buf, x[0]) for x in getattr(buf, '_fields_', []) if x[0] not in ['va_addr', 'size']}, _base=buf)