fix sync of offset buffers in graphs (#4850)

* correctly sync offset buffers

* test

* style

* run less

* just use base
This commit is contained in:
nimlgen
2024-06-06 16:09:45 +03:00
committed by GitHub
parent eeb5a7af39
commit 47bfd7c2b7
2 changed files with 20 additions and 4 deletions

View File

@@ -7,6 +7,7 @@ from tinygrad.tensor import Tensor
from tinygrad.engine.jit import TinyJit
from tinygrad.device import Device
from tinygrad.helpers import CI
from tinygrad.dtype import dtypes
def _simple_test(add, extract=lambda x: x, N=10):
for _ in range(5):
@@ -304,6 +305,21 @@ class TestJit(unittest.TestCase):
np.testing.assert_allclose(a.numpy(), xc.numpy(), atol=1e-4, rtol=1e-5)
np.testing.assert_allclose(b.numpy(), yc.numpy(), atol=1e-4, rtol=1e-5)
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU/CUDA/METAL in CI, fine to run on AMD/NV")
def test_jitted_view(self):
d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
def f(a):
x1 = a.sum(axis=(1,))
x = (x1 + 5).bitcast(dtypes.int32)
y = x.to(d1)
return y.realize()
jf = TinyJit(f)
for _ in range(5):
a = Tensor.randn(10, 1000, device=d0).realize()
xc = jf(a)
np.testing.assert_allclose((a.numpy().sum(axis=(1,)) + 5).view(np.int32), xc.numpy(), atol=1e-4, rtol=1e-5)
@unittest.skip("Pending multioutput implementation #3607")
class TestMultioutputJit(unittest.TestCase):