hotfix: interop example (#9237)
* hotfix: interop example
* rm this
* fix
* fix ci mps
* atol rtol
* no uaf
@@ -3,7 +3,7 @@
 # not a stable API, but works
 
 import torch, functools
-from tinygrad import Tensor, TinyJit
+from tinygrad import Tensor, TinyJit, Device
 from tinygrad.engine.realize import CompiledRunner
 from tinygrad.helpers import get_single_element, Context, OSX
 from tinygrad.dtype import _from_torch_dtype
@@ -24,7 +24,9 @@ def custom_kernel(data: torch.Tensor, device="CUDA") -> torch.Tensor:
 
   with Context(BEAM=2): f(tg_out, tg_data)
-  # Since realize() is called in f(), at this point tinygrad has finished the computation and the data is valid.
+
+  # Wait for the computation to finish so the data is valid.
+  Device[device].synchronize()
 
   return out
 
 if __name__ == "__main__":
test/test_interop.py (new file, 52 lines)
@@ -0,0 +1,52 @@
+#!/usr/bin/env python
+import unittest
+import torch
+import numpy as np
+
+from tinygrad.helpers import getenv, CI
+from tinygrad.tensor import Tensor
+from tinygrad.device import Device
+from tinygrad.dtype import _from_torch_dtype, _to_torch_dtype
+
+MOCKGPU = getenv("MOCKGPU")
+
+@unittest.skipIf(Device.DEFAULT not in ["METAL", "CUDA"] or MOCKGPU, f"no support on {Device.DEFAULT}")
+class TestInterop(unittest.TestCase):
+  def setUp(self):
+    if Device.DEFAULT == "CUDA": self.torch_device = "cuda"
+    elif Device.DEFAULT == "METAL": self.torch_device = "mps"
+
+  def test_torch_interop(self):
+    inp = torch.rand(2, 2, 3, device=torch.device(self.torch_device))
+
+    if self.torch_device == "mps": torch.mps.synchronize()
+    else: torch.cuda.synchronize()
+
+    tg_data = Tensor.from_blob(inp.data_ptr(), inp.shape, dtype=_from_torch_dtype(inp.dtype))
+
+    tg_out = tg_data[:, :, 0] * 0.2989 + tg_data[:, :, 1] * 0.5870 + tg_data[:, :, 2] * 0.1140
+    tg_res = tg_out.numpy()
+
+    if self.torch_device == "mps" and CI:
+      # MPS backend out of memory: https://discuss.pytorch.org/t/mps-back-end-out-of-memory-on-github-action/189773
+      # Calculate expected value on cpu.
+      inp = inp.cpu()
+    torch_out = inp[:, :, 0] * 0.2989 + inp[:, :, 1] * 0.5870 + inp[:, :, 2] * 0.1140
+
+    np.testing.assert_allclose(tg_res, torch_out.cpu().numpy(), atol=1e-5, rtol=1e-5)
+
+  def test_torch_interop_write(self):
+    tg_data = Tensor.randn((4, 4), device=Device.DEFAULT)
+
+    out = torch.empty(4, 4, device=torch.device(self.torch_device), dtype=_to_torch_dtype(tg_data.dtype))
+    tg_out = Tensor.from_blob(out.data_ptr(), out.shape, dtype=_from_torch_dtype(out.dtype))
+
+    tg_out.assign(tg_data).realize()
+    Device[Device.DEFAULT].synchronize()
+
+    torch_out_np = out.cpu().numpy()
+
+    np.testing.assert_allclose(tg_data.numpy(), torch_out_np, atol=1e-5, rtol=1e-5)
+
+if __name__ == '__main__':
+  unittest.main()
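The dtype bridging in these tests goes through the underscore-private helpers imported at the top, _from_torch_dtype and _to_torch_dtype, which map in opposite directions. A quick round-trip sanity check, offered as a sketch rather than a supported API (being private, these helpers may move):

import torch
from tinygrad.dtype import _from_torch_dtype, _to_torch_dtype

tg_dt = _from_torch_dtype(torch.float32)        # torch dtype -> tinygrad dtype
assert _to_torch_dtype(tg_dt) == torch.float32  # ...and back again

On a machine whose Device.DEFAULT is METAL or CUDA (and without MOCKGPU set), python test/test_interop.py exercises both directions of the interop; everywhere else the skipIf guard bails out.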