fix jit realize issue (#3258)

This commit is contained in:
George Hotz
2024-01-26 18:27:35 -08:00
committed by GitHub
parent 4197ef17c4
commit c4d870db0d
3 changed files with 20 additions and 2 deletions

17
test/external/external_jit_failure.py vendored Normal file
View File

@@ -0,0 +1,17 @@
from tinygrad import Tensor, TinyJit, Device
import numpy as np
# Number of devices to shard across, and the edge length of the cubic test tensor.
GPUS = 4
N = 128
# One canonicalized device name per index, derived from the default device.
ds = tuple(Device.canonicalize(f"{Device.DEFAULT}:{idx}") for idx in range(GPUS))
# Random N x N x N tensor sharded over `ds` (the trailing 0 is presumably the shard axis -- confirm).
t = Tensor.rand(N, N, N).shard(ds, 0)
# numpy copy of the sharded tensor, used below as the reference data.
n = t.numpy()
@TinyJit
def allreduce(t:Tensor) -> Tensor:
  """Return `t` summed over axis 0; the function is wrapped by TinyJit."""
  # NOTE(review): the `.realize()` call is left commented out, presumably on purpose
  # so that the JIT itself has to realize the returned tensor -- confirm before re-enabling.
  reduced = t.sum(0) #.realize()
  return reduced
# Call the jitted function ten times, printing the iteration index, and check
# every result against numpy's own sum over axis 0 of the reference array.
for step in range(10):
  print(step)
  reduced = allreduce(t).numpy()
  expected = n.sum(0)
  np.testing.assert_allclose(reduced, expected, atol=1e-4, rtol=1e-4)

View File

@@ -133,6 +133,7 @@ class TinyJit(Generic[ReturnType]):
elif self.cnt == 0:
# jit ignore
self.ret = self.fxn(*args, **kwargs)
for p in get_parameters(self.ret): p.realize()
# clear jit inputs
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None

View File

@@ -160,7 +160,7 @@ class HIPSyncEvent(JITRunner):
to_mv(rawbufs[0]._buf, 4).cast("I")[0] = 0
hip_set_device(self.device.device)
check(hip.hipStreamWriteValue32(None, rawbufs[0]._buf, 1, 0))
update_stats(colored("sync", "red"), 0, 0, {}, None, 1, device=self.dname)
update_stats(colored("sync", "red"), 0, 0, {}, None, 1, jit, device=self.dname)
class HIPWaitEvent(JITRunner):
def __init__(self, device):
@@ -169,4 +169,4 @@ class HIPWaitEvent(JITRunner):
def __call__(self, rawbufs:List[Buffer], var_vals, wait=False, jit=False):
hip_set_device(self.device.device)
check(hip.hipStreamWaitValue32(None, rawbufs[0]._buf, 1, 1, 0xFFFFFFFF))
update_stats(colored("wait", "RED"), 0, 0, {}, None, 1, device=self.dname)
update_stats(colored("wait", "RED"), 0, 0, {}, None, 1, jit, device=self.dname)