mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
fix bug in jit prune with copy [pr] (#8073)
This commit is contained in:
@@ -492,7 +492,43 @@ class TestCopyInsideJit(unittest.TestCase):
|
||||
a = Tensor.rand(16,16,device="CLANG").realize()
|
||||
b = Tensor.rand(16,16).realize()
|
||||
out = add(a,b)
|
||||
self.assertListEqual(out.flatten().tolist(), [x+y for x,y in zip(a.flatten().tolist(), b.flatten().tolist())])
|
||||
np.testing.assert_allclose(out.flatten().tolist(), [x+y for x,y in zip(a.flatten().tolist(), b.flatten().tolist())])
|
||||
|
||||
class TestJitPrune(unittest.TestCase):
|
||||
def test_simple_prune(self):
|
||||
weights = Tensor.rand(16).realize()
|
||||
def w2(x) -> Tensor: return (weights*2).contiguous() + x
|
||||
w2_noprune = TinyJit(w2)
|
||||
w2_prune = TinyJit(w2, prune=True)
|
||||
|
||||
for _ in range(3):
|
||||
a = Tensor.rand(16).realize()
|
||||
out = w2_noprune(a)
|
||||
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
|
||||
assert len(w2_noprune.captured.jit_cache) == 2
|
||||
|
||||
for _ in range(3):
|
||||
a = Tensor.rand(16).realize()
|
||||
out = w2_prune(a)
|
||||
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
|
||||
assert len(w2_prune.captured.jit_cache) == 1
|
||||
|
||||
def test_prune_w_copy_correct(self):
|
||||
weights = Tensor.rand(16).realize()
|
||||
def w2(x) -> Tensor: return (weights*2).contiguous() + x.to(Device.DEFAULT)
|
||||
w2_noprune = TinyJit(w2)
|
||||
w2_prune = TinyJit(w2, prune=True)
|
||||
|
||||
for _ in range(3):
|
||||
a = Tensor.rand(16, device="CLANG").realize()
|
||||
out = w2_noprune(a)
|
||||
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
|
||||
|
||||
for _ in range(3):
|
||||
a = Tensor.rand(16, device="CLANG").realize()
|
||||
out = w2_prune(a)
|
||||
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -8,7 +8,7 @@ from tinygrad.device import Buffer, Compiled, Device
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.ops import UOp, ssimplify, Variable, sint, sym_infer
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferXfer, CompiledRunner, Runner
|
||||
from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner
|
||||
from tinygrad.engine.memory import _internal_memory_planner
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from dataclasses import dataclass
|
||||
@@ -268,6 +268,8 @@ class TinyJit(Generic[ReturnType]):
|
||||
if any(b in depends for b in ei.bufs):
|
||||
if isinstance(ei.prg, CompiledRunner):
|
||||
depends.update(cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs)
|
||||
if isinstance(ei.prg, (BufferCopy, BufferXfer)):
|
||||
depends.add(cast(Buffer, ei.bufs[0]))
|
||||
pruned, onetime = partition(jit_cache,
|
||||
lambda ei: not isinstance(ei.prg, CompiledRunner) or any(ei.bufs[out] in depends for out in ei.prg.p.outs))
|
||||
if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
|
||||
|
||||
Reference in New Issue
Block a user