hotfix: interop example (#9237)

* hotfix: interop example

* rm this

* fix

* fix ci mps

* atol rtol

* no uaf
nimlgen authored 2025-02-25 10:32:00 +03:00, committed by GitHub
parent 8c7be428e5
commit b4c3780df0
2 changed files with 56 additions and 2 deletions


@@ -3,7 +3,7 @@
 # not a stable API, but works
 import torch, functools
-from tinygrad import Tensor, TinyJit
+from tinygrad import Tensor, TinyJit, Device
 from tinygrad.engine.realize import CompiledRunner
 from tinygrad.helpers import get_single_element, Context, OSX
 from tinygrad.dtype import _from_torch_dtype
@@ -24,7 +24,9 @@ def custom_kernel(data: torch.Tensor, device="CUDA") -> torch.Tensor:
   with Context(BEAM=2): f(tg_out, tg_data)
-  # Since realize() is called in f(), at this point tinygrad has finished the computation and the data is valid.
+  # Wait for the computation to finish so the data is valid.
+  Device[device].synchronize()
   return out
 if __name__ == "__main__":
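The substance of the fix is the synchronize() call: realize() inside the jitted f() enqueues the kernel but does not guarantee it has finished, so handing `out` back to torch without a device-level sync lets torch read (or free) memory tinygrad may still be writing, the use-after-free the "no uaf" bullet above refers to. A minimal sketch of the resulting pattern, assuming Device.DEFAULT matches the torch device (e.g. "CUDA"/"cuda"); the function name and the doubling op are illustrative, not from the diff:

# Sketch only: mirrors the example's flow with an illustrative op.
import torch
from tinygrad import Tensor, Device
from tinygrad.dtype import _from_torch_dtype

def tg_double(data: torch.Tensor) -> torch.Tensor:
  out = torch.empty_like(data)  # torch owns the output buffer
  # zero-copy views over torch's device memory
  tg_in = Tensor.from_blob(data.data_ptr(), data.shape, dtype=_from_torch_dtype(data.dtype))
  tg_out = Tensor.from_blob(out.data_ptr(), out.shape, dtype=_from_torch_dtype(out.dtype))
  tg_out.assign(tg_in * 2).realize()    # realize() may return before the kernel completes
  Device[Device.DEFAULT].synchronize()  # block until it does, before torch touches `out`
  return out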

test/test_interop.py Normal file

@@ -0,0 +1,52 @@
+#!/usr/bin/env python
+import unittest
+import torch
+import numpy as np
+from tinygrad.helpers import getenv, CI
+from tinygrad.tensor import Tensor
+from tinygrad.device import Device
+from tinygrad.dtype import _from_torch_dtype, _to_torch_dtype
+
+MOCKGPU = getenv("MOCKGPU")
+
+@unittest.skipIf(Device.DEFAULT not in ["METAL", "CUDA"] or MOCKGPU, f"no support on {Device.DEFAULT}")
+class TestInterop(unittest.TestCase):
+  def setUp(self):
+    if Device.DEFAULT == "CUDA": self.torch_device = "cuda"
+    elif Device.DEFAULT == "METAL": self.torch_device = "mps"
+
+  def test_torch_interop(self):
+    inp = torch.rand(2, 2, 3, device=torch.device(self.torch_device))
+    if self.torch_device == "mps": torch.mps.synchronize()
+    else: torch.cuda.synchronize()
+
+    tg_data = Tensor.from_blob(inp.data_ptr(), inp.shape, dtype=_from_torch_dtype(inp.dtype))
+    tg_out = tg_data[:, :, 0] * 0.2989 + tg_data[:, :, 1] * 0.5870 + tg_data[:, :, 2] * 0.1140
+    tg_res = tg_out.numpy()
+
+    if self.torch_device == "mps" and CI:
+      # MPS backend out of memory: https://discuss.pytorch.org/t/mps-back-end-out-of-memory-on-github-action/189773
+      # Calculate expected value on cpu.
+      inp = inp.cpu()
+    torch_out = inp[:, :, 0] * 0.2989 + inp[:, :, 1] * 0.5870 + inp[:, :, 2] * 0.1140
+
+    np.testing.assert_allclose(tg_res, torch_out.cpu().numpy(), atol=1e-5, rtol=1e-5)
+
+  def test_torch_interop_write(self):
+    tg_data = Tensor.randn((4, 4), device=Device.DEFAULT)
+    out = torch.empty(4, 4, device=torch.device(self.torch_device), dtype=_to_torch_dtype(tg_data.dtype))
+
+    tg_out = Tensor.from_blob(out.data_ptr(), out.shape, dtype=_from_torch_dtype(out.dtype))
+    tg_out.assign(tg_data).realize()
+    Device[Device.DEFAULT].synchronize()
+
+    torch_out_np = out.cpu().numpy()
+    np.testing.assert_allclose(tg_data.numpy(), torch_out_np, atol=1e-5, rtol=1e-5)
+
+if __name__ == '__main__':
+  unittest.main()
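Together the two tests pin down both directions of the handoff: whichever runtime wrote the shared buffer last has to synchronize before the other side reads it. torch syncs (torch.cuda.synchronize() / torch.mps.synchronize()) before tinygrad maps the pointer, and tinygrad syncs (Device[...].synchronize()) before torch reads the result. For reference, a standalone sketch of the read direction outside the test harness, assuming Device.DEFAULT is "CUDA"; the grayscale weights are the ones from test_torch_interop:

# Sketch of torch-writes / tinygrad-reads, assuming a CUDA machine.
import torch
from tinygrad import Tensor
from tinygrad.dtype import _from_torch_dtype

inp = torch.rand(2, 2, 3, device="cuda")
torch.cuda.synchronize()  # torch wrote last: flush before tinygrad maps the pointer
tg = Tensor.from_blob(inp.data_ptr(), inp.shape, dtype=_from_torch_dtype(inp.dtype))
gray = tg[:, :, 0] * 0.2989 + tg[:, :, 1] * 0.5870 + tg[:, :, 2] * 0.1140
print(gray.numpy())  # numpy() realizes and copies to host, so no tinygrad-side sync is needed here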