hotfix: interop example (#9237)

* hotfix: interop example

* rm this

* fix

* fix ci mps

* atol rtol

* no uaf
This commit is contained in:
nimlgen
2025-02-25 10:32:00 +03:00
committed by GitHub
parent 8c7be428e5
commit b4c3780df0
2 changed files with 56 additions and 2 deletions

View File

@@ -3,7 +3,7 @@
# not a stable API, but works
import torch, functools
from tinygrad import Tensor, TinyJit
from tinygrad import Tensor, TinyJit, Device
from tinygrad.engine.realize import CompiledRunner
from tinygrad.helpers import get_single_element, Context, OSX
from tinygrad.dtype import _from_torch_dtype
@@ -24,7 +24,9 @@ def custom_kernel(data: torch.Tensor, device="CUDA") -> torch.Tensor:
with Context(BEAM=2): f(tg_out, tg_data)
# Since realize() is called in f(), at this point tinygrad has finished the computation and the data is valid.
# Wait for computation to finish and the data is valid.
Device[device].synchronize()
return out
if __name__ == "__main__":

52
test/test_interop.py Normal file
View File

@@ -0,0 +1,52 @@
#!/usr/bin/env python
import unittest
import torch
import numpy as np
from tinygrad.helpers import getenv, CI
from tinygrad.tensor import Tensor
from tinygrad.device import Device
from tinygrad.dtype import _from_torch_dtype, _to_torch_dtype
MOCKGPU = getenv("MOCKGPU")
@unittest.skipIf(Device.DEFAULT not in ["METAL", "CUDA"] or MOCKGPU, f"no support on {Device.DEFAULT}")
class TestInterop(unittest.TestCase):
def setUp(self):
if Device.DEFAULT == "CUDA": self.torch_device = "cuda"
elif Device.DEFAULT == "METAL": self.torch_device = "mps"
def test_torch_interop(self):
inp = torch.rand(2, 2, 3, device=torch.device(self.torch_device))
if self.torch_device == "mps": torch.mps.synchronize()
else: torch.cuda.synchronize()
tg_data = Tensor.from_blob(inp.data_ptr(), inp.shape, dtype=_from_torch_dtype(inp.dtype))
tg_out = tg_data[:, :, 0] * 0.2989 + tg_data[:, :, 1] * 0.5870 + tg_data[:, :, 2] * 0.1140
tg_res = tg_out.numpy()
if self.torch_device == "mps" and CI:
# MPS backend out of memory: https://discuss.pytorch.org/t/mps-back-end-out-of-memory-on-github-action/189773
# Calculate expected value on cpu.
inp = inp.cpu()
torch_out = inp[:, :, 0] * 0.2989 + inp[:, :, 1] * 0.5870 + inp[:, :, 2] * 0.1140
np.testing.assert_allclose(tg_res, torch_out.cpu().numpy(), atol=1e-5, rtol=1e-5)
def test_torch_interop_write(self):
tg_data = Tensor.randn((4, 4), device=Device.DEFAULT)
out = torch.empty(4, 4, device=torch.device(self.torch_device), dtype=_to_torch_dtype(tg_data.dtype))
tg_out = Tensor.from_blob(out.data_ptr(), out.shape, dtype=_from_torch_dtype(out.dtype))
tg_out.assign(tg_data).realize()
Device[Device.DEFAULT].synchronize()
torch_out_np = out.cpu().numpy()
np.testing.assert_allclose(tg_data.numpy(), torch_out_np, atol=1e-5, rtol=1e-5)
if __name__ == '__main__':
unittest.main()