diff --git a/tinygrad/device.py b/tinygrad/device.py index 4a001066c3..c687d4e993 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -524,7 +524,7 @@ class HCQAllocator(LRUAllocator): # pylint: disable=abstract-method ctypes.memmove(from_mv(dest[i:]), self.b[0].va_addr, lsize) def transfer(self, dest:HCQBuffer, src:HCQBuffer, sz:int, src_dev, dest_dev): - src_dev._gpu_map(dest) + src_dev.allocator.map(dest) with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"{src_dev.dname} -> {dest_dev.dname}", enabled=PROFILE): src_dev.hw_copy_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \ @@ -539,6 +539,8 @@ class HCQAllocator(LRUAllocator): # pylint: disable=abstract-method .signal(dest_dev.timeline_signal, dest_dev.timeline_value).submit(dest_dev) dest_dev.timeline_value += 1 + def map(self, buf:HCQBuffer): pass + def offset(self, buf, size:int, offset:int) -> HCQBuffer: return type(buf)(va_addr=buf.va_addr + offset, size=size, **{k:v for k,v in buf.__dict__.items() if k not in ['va_addr', 'size']}, **{x[0]:getattr(buf, x[0]) for x in getattr(buf, '_fields_', []) if x[0] not in ['va_addr', 'size']}, _base=buf) diff --git a/tinygrad/runtime/graph/hcq.py b/tinygrad/runtime/graph/hcq.py index c3db52478c..1f156f6f97 100644 --- a/tinygrad/runtime/graph/hcq.py +++ b/tinygrad/runtime/graph/hcq.py @@ -1,7 +1,7 @@ import collections, time from typing import List, Any, Dict, cast, Optional, Tuple, Set from tinygrad.helpers import round_up, to_mv, PROFILE -from tinygrad.device import Buffer, BufferOptions, Compiled, Device +from tinygrad.device import HCQAllocator, Buffer, BufferOptions, Compiled, Device from tinygrad.shape.symbolic import Variable from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner from tinygrad.engine.jit import MultiGraphRunner @@ -100,7 +100,7 @@ class HCQGraph(MultiGraphRunner): if isinstance(ji.prg, CompiledRunner): enqueue_queue.exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals)) elif isinstance(ji.prg, BufferXfer): dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]] - Device[src.device]._gpu_map(dest._buf) #type: ignore + cast(HCQAllocator, Device[src.device].allocator).map(dest._buf) enqueue_queue.copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes) self.copy_to_devs[Device[dest.device]].add(Device[src.device]) self.op_cmd_idx[j] = (enqueue_queue, len(enqueue_queue) - 1) diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index d01397c7b7..5eb4559781 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -362,6 +362,8 @@ class AMDAllocator(HCQAllocator): def _free(self, opaque, options:BufferOptions): self.device._gpu_free(opaque) + def map(self, buf:HCQBuffer): self.device._gpu_map(buf._base if hasattr(buf, '_base') else buf) + MAP_FIXED, MAP_NORESERVE = 0x10, 0x400 @dataclass @@ -380,7 +382,6 @@ class AMDDevice(HCQCompiled): gpus:List[pathlib.Path] = [] def _gpu_map(self, mem): - mem = mem._base if hasattr(mem, '_base') else mem if self.gpu_id in getattr(mem, "mapped_gpu_ids", []): return mem.__setattr__("mapped_gpu_ids", getattr(mem, "mapped_gpu_ids", []) + [self.gpu_id]) c_gpus = (ctypes.c_int32 * len(mem.mapped_gpu_ids))(*mem.mapped_gpu_ids) diff --git a/tinygrad/runtime/ops_nv.py b/tinygrad/runtime/ops_nv.py index e6cb45a654..9cadc0a982 100644 --- a/tinygrad/runtime/ops_nv.py +++ b/tinygrad/runtime/ops_nv.py @@ -339,6 +339,8 @@ class NVAllocator(HCQAllocator): if options.host: self.device._gpu_host_free(opaque) else: self.device._gpu_free(opaque) + def map(self, buf:HCQBuffer): self.device._gpu_map(buf._base if hasattr(buf, '_base') else buf) + @dataclass class GPFifo: ring: memoryview @@ -437,7 +439,6 @@ class NVDevice(HCQCompiled): gpuAttributesCount=1, perGpuAttributes=gpu_attrs, va_addr=va_base, size=size, mapped_gpu_ids=[self.gpu_uuid]) def _gpu_map(self, mem): - mem = mem._base if hasattr(mem, '_base') else mem if self.gpu_uuid in mem.mapped_gpu_ids: return mem.mapped_gpu_ids.append(self.gpu_uuid) self._gpu_uvm_map(mem.va_addr, mem.size, mem.hMemory, create_range=False)