From b4c3780df0cd729e2c8f0d95e53a7d8bb302bf79 Mon Sep 17 00:00:00 2001
From: nimlgen <138685161+nimlgen@users.noreply.github.com>
Date: Tue, 25 Feb 2025 10:32:00 +0300
Subject: [PATCH] hotfix: interop example (#9237)

* hotfix: interop example

* rm this

* fix

* fix ci mps

* atol rtol

* no uaf
---
 examples/torch_cuda_kernel.py |  6 ++--
 test/test_interop.py          | 52 +++++++++++++++++++++++++++++++++++
 2 files changed, 56 insertions(+), 2 deletions(-)
 create mode 100644 test/test_interop.py

diff --git a/examples/torch_cuda_kernel.py b/examples/torch_cuda_kernel.py
index 8c89b7d5a6..5d1706efe8 100644
--- a/examples/torch_cuda_kernel.py
+++ b/examples/torch_cuda_kernel.py
@@ -3,7 +3,7 @@
 # not a stable API, but works

 import torch, functools
-from tinygrad import Tensor, TinyJit
+from tinygrad import Tensor, TinyJit, Device
 from tinygrad.engine.realize import CompiledRunner
 from tinygrad.helpers import get_single_element, Context, OSX
 from tinygrad.dtype import _from_torch_dtype
@@ -24,7 +24,9 @@ def custom_kernel(data: torch.Tensor, device="CUDA") -> torch.Tensor:
   with Context(BEAM=2):
     f(tg_out, tg_data)

-  # Since realize() is called in f(), at this point tinygrad has finished the computation and the data is valid.
+  # Wait for the computation to finish so the data is valid.
+  Device[device].synchronize()
+
   return out

 if __name__ == "__main__":
diff --git a/test/test_interop.py b/test/test_interop.py
new file mode 100644
index 0000000000..dafe54319d
--- /dev/null
+++ b/test/test_interop.py
@@ -0,0 +1,52 @@
+#!/usr/bin/env python
+import unittest
+import torch
+import numpy as np
+
+from tinygrad.helpers import getenv, CI
+from tinygrad.tensor import Tensor
+from tinygrad.device import Device
+from tinygrad.dtype import _from_torch_dtype, _to_torch_dtype
+
+MOCKGPU = getenv("MOCKGPU")
+
+@unittest.skipIf(Device.DEFAULT not in ["METAL", "CUDA"] or MOCKGPU, f"no support on {Device.DEFAULT}")
+class TestInterop(unittest.TestCase):
+  def setUp(self):
+    if Device.DEFAULT == "CUDA": self.torch_device = "cuda"
+    elif Device.DEFAULT == "METAL": self.torch_device = "mps"
+
+  def test_torch_interop(self):
+    inp = torch.rand(2, 2, 3, device=torch.device(self.torch_device))
+
+    if self.torch_device == "mps": torch.mps.synchronize()
+    else: torch.cuda.synchronize()
+
+    tg_data = Tensor.from_blob(inp.data_ptr(), inp.shape, dtype=_from_torch_dtype(inp.dtype))
+
+    tg_out = tg_data[:, :, 0] * 0.2989 + tg_data[:, :, 1] * 0.5870 + tg_data[:, :, 2] * 0.1140
+    tg_res = tg_out.numpy()
+
+    if self.torch_device == "mps" and CI:
+      # MPS backend out of memory: https://discuss.pytorch.org/t/mps-back-end-out-of-memory-on-github-action/189773
+      # Calculate the expected value on the CPU instead.
+      inp = inp.cpu()
+    torch_out = inp[:, :, 0] * 0.2989 + inp[:, :, 1] * 0.5870 + inp[:, :, 2] * 0.1140
+
+    np.testing.assert_allclose(tg_res, torch_out.cpu().numpy(), atol=1e-5, rtol=1e-5)
+
+  def test_torch_interop_write(self):
+    tg_data = Tensor.randn((4, 4), device=Device.DEFAULT)
+
+    out = torch.empty(4, 4, device=torch.device(self.torch_device), dtype=_to_torch_dtype(tg_data.dtype))
+    tg_out = Tensor.from_blob(out.data_ptr(), out.shape, dtype=_from_torch_dtype(out.dtype))
+
+    tg_out.assign(tg_data).realize()
+    Device[Device.DEFAULT].synchronize()
+
+    torch_out_np = out.cpu().numpy()
+
+    np.testing.assert_allclose(tg_data.numpy(), torch_out_np, atol=1e-5, rtol=1e-5)
+
+if __name__ == '__main__':
+  unittest.main()
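
For context (not part of the patch): a minimal standalone sketch of the zero-copy read path the new test exercises, assuming a CUDA torch build and that tinygrad's default device is CUDA (the same condition the test's skipIf enforces). Variable names here are illustrative only; all calls used are the ones appearing in the patch above.

  # Minimal read-side interop sketch (mirrors test_torch_interop above).
  import torch
  from tinygrad import Tensor, Device
  from tinygrad.dtype import _from_torch_dtype

  # assumption: tinygrad's default device is CUDA, like the test requires
  assert Device.DEFAULT == "CUDA"

  inp = torch.rand(2, 2, 3, device="cuda")
  torch.cuda.synchronize()  # make sure torch has finished producing the data

  # wrap the torch allocation without copying; `inp` must stay alive while tinygrad uses it
  tg = Tensor.from_blob(inp.data_ptr(), inp.shape, dtype=_from_torch_dtype(inp.dtype))

  # grayscale conversion computed by tinygrad, then read back to the host
  gray = tg[:, :, 0] * 0.2989 + tg[:, :, 1] * 0.5870 + tg[:, :, 2] * 0.1140
  print(gray.numpy())

The write direction works the other way around, as test_torch_interop_write shows: wrap a torch buffer with Tensor.from_blob, assign into it, then call Device[Device.DEFAULT].synchronize() before torch reads the memory.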