mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
metal PyTorch interop (#9229)
* add from_blob support to mps cuda
* objc_id
* metal pytorch interop
* fix comments

Co-authored-by: George Hotz <geohot@gmail.com>
@@ -20,13 +20,19 @@ tinygrad provides interoperability with OpenCL and PyTorch, allowing efficient t
 **Important**: When using external memory pointers with tinygrad tensors, you must ensure these pointers remain valid throughout the entire lifetime of the tinygrad tensor to prevent memory corruption.
 
-### `CUDA` PyTorch Interoperability
+### `CUDA`/`METAL` PyTorch Interoperability
 
-You can seamlessly work with CUDA tensors between PyTorch and tinygrad without data copying:
+You can seamlessly work with CUDA/MPS tensors between PyTorch and tinygrad without data copying:
 
 ```python
 from tinygrad.dtype import _from_torch_dtype
 tensor1 = torch.tensor([1.0, 2.0, 3.0], device=torch.device("cuda"))
 tiny_tensor1 = Tensor.from_blob(tensor1.data_ptr(), tensor1.shape, dtype=_from_torch_dtype(tensor1.dtype), device='CUDA')
 
+# Before tinygrad calculations, the source device must be synchronized to make sure the data is valid.
+if tensor1.device.type == "mps": torch.mps.synchronize()
+else: torch.cuda.synchronize()
 
 x = (tiny_tensor1 + 1).realize()
 ```
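The `if`/`else` synchronization above pairs with a torch-to-tinygrad device-name mapping (`"cuda"` maps to `CUDA`, `"mps"` maps to `METAL`). A minimal stand-in sketch of that mapping, runnable without any GPU; the `tiny_device` helper is hypothetical and not part of tinygrad:

```python
# Hypothetical helper (an assumption for illustration, not a tinygrad API):
# map a PyTorch device-type string to the tinygrad device name used with
# Tensor.from_blob, following the section above.
def tiny_device(torch_device_type: str) -> str:
    mapping = {"cuda": "CUDA", "mps": "METAL"}
    if torch_device_type not in mapping:
        raise ValueError(f"no zero-copy interop for device type {torch_device_type!r}")
    return mapping[torch_device_type]

print(tiny_device("cuda"))  # CUDA
print(tiny_device("mps"))   # METAL
```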
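The **Important** note above (external pointers must stay valid for the tinygrad tensor's whole lifetime) can be illustrated with plain `ctypes`, no GPU or PyTorch required; the buffer and view here are stand-ins for a PyTorch tensor's storage and a `Tensor.from_blob` result:

```python
import ctypes

# Stand-in for a PyTorch tensor's storage: a raw 3-float buffer we own.
buf = (ctypes.c_float * 3)(1.0, 2.0, 3.0)
addr = ctypes.addressof(buf)          # plays the role of tensor.data_ptr()

# Stand-in for a tensor created via from_blob: a view reading through the raw address.
view = (ctypes.c_float * 3).from_address(addr)
print(list(view))  # [1.0, 2.0, 3.0], valid only while `buf` is alive

# Keeping the source object referenced alongside the view (e.g. holding the PyTorch
# tensor next to the tinygrad tensor) is what keeps the address valid; if `buf` were
# freed first, `view` would read from a dangling pointer.
holder = (buf, view)
```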
### `QCOM` OpenCL Interoperability