fix bug in jit prune with copy [pr] (#8073)

George Hotz
2024-12-06 18:38:23 +08:00
committed by GitHub
parent aae8557ada
commit e37bff6c19
2 changed files with 40 additions and 2 deletions


@@ -492,7 +492,43 @@ class TestCopyInsideJit(unittest.TestCase):
     a = Tensor.rand(16,16,device="CLANG").realize()
     b = Tensor.rand(16,16).realize()
     out = add(a,b)
-    self.assertListEqual(out.flatten().tolist(), [x+y for x,y in zip(a.flatten().tolist(), b.flatten().tolist())])
+    np.testing.assert_allclose(out.flatten().tolist(), [x+y for x,y in zip(a.flatten().tolist(), b.flatten().tolist())])
+
+class TestJitPrune(unittest.TestCase):
+  def test_simple_prune(self):
+    weights = Tensor.rand(16).realize()
+    def w2(x) -> Tensor: return (weights*2).contiguous() + x
+    w2_noprune = TinyJit(w2)
+    w2_prune = TinyJit(w2, prune=True)
+
+    for _ in range(3):
+      a = Tensor.rand(16).realize()
+      out = w2_noprune(a)
+      np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
+    assert len(w2_noprune.captured.jit_cache) == 2
+
+    for _ in range(3):
+      a = Tensor.rand(16).realize()
+      out = w2_prune(a)
+      np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
+    assert len(w2_prune.captured.jit_cache) == 1
+
+  def test_prune_w_copy_correct(self):
+    weights = Tensor.rand(16).realize()
+    def w2(x) -> Tensor: return (weights*2).contiguous() + x.to(Device.DEFAULT)
+    w2_noprune = TinyJit(w2)
+    w2_prune = TinyJit(w2, prune=True)
+
+    for _ in range(3):
+      a = Tensor.rand(16, device="CLANG").realize()
+      out = w2_noprune(a)
+      np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
+
+    for _ in range(3):
+      a = Tensor.rand(16, device="CLANG").realize()
+      out = w2_prune(a)
+      np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
+
 if __name__ == '__main__':
   unittest.main()
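
For context, the sketch below (not tinygrad's actual code) isolates the dependency walk these tests exercise. Item, Kernel, and Copy are hypothetical stand-ins for ExecItem, CompiledRunner, and BufferCopy/BufferXfer, and buffers are plain strings:

# Standalone sketch of JIT prune's dependency walk. Hypothetical stand-ins:
# Item ~ ExecItem, Kernel ~ CompiledRunner, Copy ~ BufferCopy/BufferXfer.
from dataclasses import dataclass, field

@dataclass
class Kernel:
  outs: list = field(default_factory=lambda: [0])  # indices of output buffers

@dataclass
class Copy:
  pass  # bufs[0] is the copy destination, bufs[1] the source

@dataclass
class Item:
  prg: object
  bufs: list  # buffer names; outputs/destination come first

def prune(jit_cache: list, input_bufs: set) -> list:
  # Grow the set of buffers that transitively depend on the JIT inputs.
  depends = set(input_bufs)
  for ei in jit_cache:
    if any(b in depends for b in ei.bufs):
      if isinstance(ei.prg, Kernel):
        depends.update(ei.bufs[out] for out in ei.prg.outs)
      if isinstance(ei.prg, Copy):  # the fix: copies propagate dependence too
        depends.add(ei.bufs[0])
  # Keep items whose outputs depend on inputs; pruned kernels run once at capture.
  return [ei for ei in jit_cache
          if not isinstance(ei.prg, Kernel) or any(ei.bufs[out] in depends for out in ei.prg.outs)]

# Mirrors test_prune_w_copy_correct: the input is copied across devices first.
cache = [Item(Copy(), ["a_copy", "a"]),            # a.to(Device.DEFAULT)
         Item(Kernel(), ["w2", "weights"]),        # (weights*2).contiguous(), input-independent
         Item(Kernel(), ["out", "w2", "a_copy"])]  # w2 + copied input
kept = prune(cache, {"a"})
assert [ei.bufs[0] for ei in kept] == ["a_copy", "out"]  # weights*2 kernel pruned to run once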


@@ -8,7 +8,7 @@ from tinygrad.device import Buffer, Compiled, Device
 from tinygrad.dtype import DType
 from tinygrad.ops import UOp, ssimplify, Variable, sint, sym_infer
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferXfer, CompiledRunner, Runner
+from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner
 from tinygrad.engine.memory import _internal_memory_planner
 from tinygrad.nn.state import get_parameters
 from dataclasses import dataclass
@@ -268,6 +268,8 @@ class TinyJit(Generic[ReturnType]):
      if any(b in depends for b in ei.bufs):
        if isinstance(ei.prg, CompiledRunner):
          depends.update(cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs)
+       if isinstance(ei.prg, (BufferCopy, BufferXfer)):
+         depends.add(cast(Buffer, ei.bufs[0]))
    pruned, onetime = partition(jit_cache,
      lambda ei: not isinstance(ei.prg, CompiledRunner) or any(ei.bufs[out] in depends for out in ei.prg.p.outs))
    if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
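
Before this change, the walk only propagated dependence through CompiledRunner outputs, so a BufferCopy/BufferXfer whose source was a JIT input never marked its destination buffer as input-dependent; a kernel reading the copied buffer could then be wrongly pruned into the one-time list and replay with stale data. The fix also adds the copy's destination (ei.bufs[0]) to depends. A minimal usage sketch mirroring the test above (assumes a CLANG device is available alongside the default device):

from tinygrad import Tensor, TinyJit, Device

weights = Tensor.rand(16).realize()
def w2(x) -> Tensor: return (weights*2).contiguous() + x.to(Device.DEFAULT)

jit_w2 = TinyJit(w2, prune=True)
for _ in range(3):
  a = Tensor.rand(16, device="CLANG").realize()
  out = jit_w2(a)  # the copy of `a` now keeps the dependent add kernel un-pruned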