mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
fix jit realize issue (#3258)
This commit is contained in:
17
test/external/external_jit_failure.py
vendored
Normal file
17
test/external/external_jit_failure.py
vendored
Normal file
@@ -0,0 +1,17 @@
|
||||
from tinygrad import Tensor, TinyJit, Device
|
||||
import numpy as np
|
||||
|
||||
GPUS = 4
|
||||
N = 128
|
||||
ds = tuple([Device.canonicalize(f"{Device.DEFAULT}:{i}") for i in range(GPUS)])
|
||||
t = Tensor.rand(N, N, N).shard(ds, 0)
|
||||
n = t.numpy()
|
||||
|
||||
@TinyJit
|
||||
def allreduce(t:Tensor) -> Tensor:
|
||||
return t.sum(0) #.realize()
|
||||
|
||||
for i in range(10):
|
||||
print(i)
|
||||
tn = allreduce(t).numpy()
|
||||
np.testing.assert_allclose(tn, n.sum(0), atol=1e-4, rtol=1e-4)
|
||||
@@ -133,6 +133,7 @@ class TinyJit(Generic[ReturnType]):
|
||||
elif self.cnt == 0:
|
||||
# jit ignore
|
||||
self.ret = self.fxn(*args, **kwargs)
|
||||
for p in get_parameters(self.ret): p.realize()
|
||||
|
||||
# clear jit inputs
|
||||
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
|
||||
|
||||
@@ -160,7 +160,7 @@ class HIPSyncEvent(JITRunner):
|
||||
to_mv(rawbufs[0]._buf, 4).cast("I")[0] = 0
|
||||
hip_set_device(self.device.device)
|
||||
check(hip.hipStreamWriteValue32(None, rawbufs[0]._buf, 1, 0))
|
||||
update_stats(colored("sync", "red"), 0, 0, {}, None, 1, device=self.dname)
|
||||
update_stats(colored("sync", "red"), 0, 0, {}, None, 1, jit, device=self.dname)
|
||||
|
||||
class HIPWaitEvent(JITRunner):
|
||||
def __init__(self, device):
|
||||
@@ -169,4 +169,4 @@ class HIPWaitEvent(JITRunner):
|
||||
def __call__(self, rawbufs:List[Buffer], var_vals, wait=False, jit=False):
|
||||
hip_set_device(self.device.device)
|
||||
check(hip.hipStreamWaitValue32(None, rawbufs[0]._buf, 1, 1, 0xFFFFFFFF))
|
||||
update_stats(colored("wait", "RED"), 0, 0, {}, None, 1, device=self.dname)
|
||||
update_stats(colored("wait", "RED"), 0, 0, {}, None, 1, jit, device=self.dname)
|
||||
|
||||
Reference in New Issue
Block a user