diff --git a/docs/runtime.md b/docs/runtime.md
index 710dc59dcd..d109b156cf 100644
--- a/docs/runtime.md
+++ b/docs/runtime.md
@@ -20,13 +20,19 @@ tinygrad provides interoperability with OpenCL and PyTorch, allowing efficient t
 **Important**: When using external memory pointers with tinygrad tensors, you must ensure these pointers remain valid throughout the entire lifetime of the tinygrad tensor to prevent memory corruption.
 
-### `CUDA` PyTorch Interoperability
+### `CUDA`/`METAL` PyTorch Interoperability
 
-You can seamlessly work with CUDA tensors between PyTorch and tinygrad without data copying:
+You can seamlessly work with CUDA/MPS tensors between PyTorch and tinygrad without data copying:
 
 ```python
 from tinygrad.dtype import _from_torch_dtype
 tensor1 = torch.tensor([1.0, 2.0, 3.0], device=torch.device("cuda"))
 tiny_tensor1 = Tensor.from_blob(tensor1.data_ptr(), tensor1.shape, dtype=_from_torch_dtype(tensor1.dtype), device='CUDA')
+
+# Before any tinygrad calculation, synchronize PyTorch to make sure the data is valid.
+if tensor1.device.type == "mps": torch.mps.synchronize()
+else: torch.cuda.synchronize()
+
+x = (tiny_tensor1 + 1).realize()
 ```
 
 ### `QCOM` OpenCL Interoperability
diff --git a/examples/torch_cuda_kernel.py b/examples/torch_cuda_kernel.py
index f9bbeb5ccb..8c89b7d5a6 100644
--- a/examples/torch_cuda_kernel.py
+++ b/examples/torch_cuda_kernel.py
@@ -3,32 +3,34 @@
 # not a stable API, but works
 import torch, functools
-try:
-  import tinygrad
-except ImportError:
-  import pip
-  pip.main(['install', 'tinygrad'])
-
 from tinygrad import Tensor, TinyJit
 from tinygrad.engine.realize import CompiledRunner
-from tinygrad.helpers import get_single_element, Context
+from tinygrad.helpers import get_single_element, Context, OSX
 from tinygrad.dtype import _from_torch_dtype
 
 @TinyJit
 def f(tg_out, tg_data): return tg_out.assign(tg_data[:, :, 0] * 0.2989 + tg_data[:, :, 1] * 0.5870 + tg_data[:, :, 2] * 0.1140).realize()
 
-def custom_kernel(data: torch.Tensor) -> torch.Tensor:
+def custom_kernel(data: torch.Tensor, device="CUDA") -> torch.Tensor:
   assert data.dtype == torch.float32
-  tg_data = Tensor.from_blob(data.data_ptr(), data.shape, dtype=_from_torch_dtype(data.dtype), device='CUDA')
+  tg_data = Tensor.from_blob(data.data_ptr(), data.shape, dtype=_from_torch_dtype(data.dtype), device=device)
   out = torch.empty((data.shape[0], data.shape[1]), dtype=data.dtype, device=data.device)
-  tg_out = Tensor.from_blob(out.data_ptr(), out.shape, dtype=_from_torch_dtype(out.dtype), device='CUDA')
+  tg_out = Tensor.from_blob(out.data_ptr(), out.shape, dtype=_from_torch_dtype(out.dtype), device=device)
+
+  # Need to sync torch to make sure the data is valid.
+  if data.device.type == "mps": torch.mps.synchronize()
+  else: torch.cuda.synchronize()
   with Context(BEAM=2): f(tg_out, tg_data)
+
+  # Since realize() is called in f(), at this point tinygrad has finished the computation and the data is valid.
   return out
 
 if __name__ == "__main__":
   for i in range(3):
-    out = custom_kernel(inp:=torch.rand(16, 16, 3, device=torch.device("cuda")))
-    torch.cuda.synchronize()
+    if OSX:
+      out = custom_kernel(inp:=torch.rand(16, 16, 3, device=torch.device("mps")), device="METAL")
+    else:
+      out = custom_kernel(inp:=torch.rand(16, 16, 3, device=torch.device("cuda")), device="CUDA")
     assert torch.allclose(out, inp[:, :, 0] * 0.2989 + inp[:, :, 1] * 0.5870 + inp[:, :, 2] * 0.1140)
diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py
index b81a801925..b2b78fc007 100644
--- a/tinygrad/runtime/ops_metal.py
+++ b/tinygrad/runtime/ops_metal.py
@@ -192,6 +192,8 @@ class MetalAllocator(LRUAllocator):
     self.dev:MetalDevice = dev
     super().__init__()
   def _alloc(self, size:int, options) -> MetalBuffer:
+    if options.external_ptr: return MetalBuffer(objc_id(options.external_ptr), size)
+    # Buffer is explicitly released in _free() rather than garbage collected via reference count
     ret = msg("newBufferWithLength:options:", objc_id)(self.dev.sysdevice, ctypes.c_ulong(size), MTLResourceOptions.MTLResourceStorageModeShared)
     if ret.value is None: raise MemoryError(f"Metal OOM while allocating {size=}")
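
For context, the full interop flow that this patch documents, end to end — a minimal sketch, assuming a CUDA-enabled PyTorch build (the `src`/`tiny` names are illustrative, not part of the API):

```python
# Minimal sketch of the interop flow from docs/runtime.md (assumes CUDA-enabled PyTorch).
import torch
from tinygrad import Tensor
from tinygrad.dtype import _from_torch_dtype

src = torch.tensor([1.0, 2.0, 3.0], device="cuda")  # torch owns this memory
tiny = Tensor.from_blob(src.data_ptr(), src.shape, dtype=_from_torch_dtype(src.dtype), device="CUDA")
torch.cuda.synchronize()    # torch must finish writing before tinygrad reads the pointer
out = (tiny + 1).realize()  # realize() means the computation is done when this returns
print(out.numpy())          # -> [2. 3. 4.]
# keep `src` alive as long as `tiny` exists: from_blob borrows the memory, it does not own it
```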