Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
metal PyTorch interop (#9229)
* add from_blob support to mps cuda
* objc_id
* metal pytorch interop
* fix comments

Co-authored-by: George Hotz <geohot@gmail.com>
@@ -20,13 +20,19 @@ tinygrad provides interoperability with OpenCL and PyTorch, allowing efficient t
 
 **Important**: When using external memory pointers with tinygrad tensors, you must ensure these pointers remain valid throughout the entire lifetime of the tinygrad tensor to prevent memory corruption.
 
-### `CUDA` PyTorch Interoperability
+### `CUDA`/`METAL` PyTorch Interoperability
 
-You can seamlessly work with CUDA tensors between PyTorch and tinygrad without data copying:
+You can seamlessly work with CUDA/MPS tensors between PyTorch and tinygrad without data copying:
 ```python
 from tinygrad.dtype import _from_torch_dtype
 tensor1 = torch.tensor([1.0, 2.0, 3.0], device=torch.device("cuda"))
 tiny_tensor1 = Tensor.from_blob(tensor1.data_ptr(), tensor1.shape, dtype=_from_torch_dtype(tensor1.dtype), device='CUDA')
+
+# Before tinygrad calculations, mps needs to be synchronized to make sure data is valid.
+if tensor1.device.type == "mps": torch.mps.synchronize()
+else: torch.cuda.synchronize()
+
+x = (tiny_tensor1 + 1).realize()
 ```
 
 ### `QCOM` OpenCL Interoperability
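The lifetime warning in the doc hunk above deserves emphasis: `from_blob` records only a raw pointer, so if the source torch tensor is garbage collected the tinygrad view dangles. A minimal keep-alive sketch (`tiny_view` is a hypothetical helper, not part of this commit):

```python
import torch
from tinygrad import Tensor
from tinygrad.dtype import _from_torch_dtype

# hypothetical helper (not in this commit): return the view together with the
# torch tensor that owns the memory, so holding the view keeps the pointer valid
def tiny_view(t: torch.Tensor):
  dev = "METAL" if t.device.type == "mps" else "CUDA"  # assumes only these two backends
  view = Tensor.from_blob(t.data_ptr(), t.shape, dtype=_from_torch_dtype(t.dtype), device=dev)
  return view, t

view, owner = tiny_view(torch.rand(8, device="cuda"))  # drop `owner` only after dropping `view`
```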
@@ -3,32 +3,34 @@
 # not a stable API, but works
 import torch, functools
-try:
-  import tinygrad
-except ImportError:
-  import pip
-  pip.main(['install', 'tinygrad'])
 
 from tinygrad import Tensor, TinyJit
 from tinygrad.engine.realize import CompiledRunner
-from tinygrad.helpers import get_single_element, Context
+from tinygrad.helpers import get_single_element, Context, OSX
 from tinygrad.dtype import _from_torch_dtype
 
 @TinyJit
 def f(tg_out, tg_data): return tg_out.assign(tg_data[:, :, 0] * 0.2989 + tg_data[:, :, 1] * 0.5870 + tg_data[:, :, 2] * 0.1140).realize()
 
-def custom_kernel(data: torch.Tensor) -> torch.Tensor:
+def custom_kernel(data: torch.Tensor, device="CUDA") -> torch.Tensor:
   assert data.dtype == torch.float32
-  tg_data = Tensor.from_blob(data.data_ptr(), data.shape, dtype=_from_torch_dtype(data.dtype), device='CUDA')
+  tg_data = Tensor.from_blob(data.data_ptr(), data.shape, dtype=_from_torch_dtype(data.dtype), device=device)
 
   out = torch.empty((data.shape[0], data.shape[1]), dtype=data.dtype, device=data.device)
-  tg_out = Tensor.from_blob(out.data_ptr(), out.shape, dtype=_from_torch_dtype(out.dtype), device='CUDA')
+  tg_out = Tensor.from_blob(out.data_ptr(), out.shape, dtype=_from_torch_dtype(out.dtype), device=device)
+
+  # Need to sync torch to make sure the data is valid.
+  if data.device.type == "mps": torch.mps.synchronize()
+  else: torch.cuda.synchronize()
 
   with Context(BEAM=2): f(tg_out, tg_data)
 
+  # Since realize() is called in f(), at this point tinygrad has finished the computation and the data is valid.
   return out
 
 if __name__ == "__main__":
   for i in range(3):
-    out = custom_kernel(inp:=torch.rand(16, 16, 3, device=torch.device("cuda")))
-    torch.cuda.synchronize()
+    if OSX:
+      out = custom_kernel(inp:=torch.rand(16, 16, 3, device=torch.device("mps")), device="METAL")
+    else:
+      out = custom_kernel(inp:=torch.rand(16, 16, 3, device=torch.device("cuda")), device="CUDA")
     assert torch.allclose(out, inp[:, :, 0] * 0.2989 + inp[:, :, 1] * 0.5870 + inp[:, :, 2] * 0.1140)
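The example above relies on `from_blob` being zero-copy in both directions. A quick sanity check of that property (a sketch using the same CUDA path as the example, not code from the commit):

```python
import torch
from tinygrad import Tensor
from tinygrad.dtype import _from_torch_dtype

t = torch.ones(4, device=torch.device("cuda"))
view = Tensor.from_blob(t.data_ptr(), t.shape, dtype=_from_torch_dtype(t.dtype), device="CUDA")

t.add_(41)                # mutate the shared buffer through torch...
torch.cuda.synchronize()  # ...and wait for the write to land before tinygrad reads
print(view.tolist())      # expected: [42.0, 42.0, 42.0, 42.0]
```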
@@ -192,6 +192,8 @@ class MetalAllocator(LRUAllocator):
     self.dev:MetalDevice = dev
     super().__init__()
   def _alloc(self, size:int, options) -> MetalBuffer:
+    if options.external_ptr: return MetalBuffer(objc_id(options.external_ptr), size)
+
     # Buffer is explicitly released in _free() rather than garbage collected via reference count
     ret = msg("newBufferWithLength:options:", objc_id)(self.dev.sysdevice, ctypes.c_ulong(size), MTLResourceOptions.MTLResourceStorageModeShared)
     if ret.value is None: raise MemoryError(f"Metal OOM while allocating {size=}")
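The new `external_ptr` early-return means tinygrad wraps, but never owns, memory passed in through `from_blob`; only buffers it allocates itself via `newBufferWithLength:options:` should be released in `_free()`. A toy model of that ownership rule (illustrative only; the commit's actual `_free` is not shown in this hunk):

```python
# toy model (not tinygrad code) of the ownership rule implied by external_ptr:
# release only buffers the allocator created itself, never wrapped external memory
class ToyAllocator:
  def __init__(self):
    self.owned: dict[int, bytearray] = {}  # handle -> backing storage we allocated
  def alloc(self, size: int, external_ptr: int | None = None) -> int:
    if external_ptr is not None: return external_ptr  # wrap caller-owned memory, like _alloc above
    storage = bytearray(size)  # stand-in for newBufferWithLength:options:
    self.owned[id(storage)] = storage
    return id(storage)
  def free(self, handle: int) -> None:
    self.owned.pop(handle, None)  # frees only what we allocated; external handles are a no-op
```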