fix bug in jit prune with copy [pr] (#8073)

Author: George Hotz
Date: 2024-12-06 18:38:23 +08:00
Committed by: GitHub
Parent: aae8557ada
Commit: e37bff6c19
2 changed files with 40 additions and 2 deletions
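
What this fixes: with prune=True, a jitted function that copies its input to another device (x.to(Device.DEFAULT)) could have the kernel consuming that copy pruned into the run-once set, so later calls returned results computed from the first input. A minimal reproducer, mirroring the new regression test below (it assumes a CLANG device is available alongside the default device):

import numpy as np
from tinygrad import Tensor, TinyJit, Device

weights = Tensor.rand(16).realize()
def w2(x) -> Tensor: return (weights*2).contiguous() + x.to(Device.DEFAULT)
jit_w2 = TinyJit(w2, prune=True)

for _ in range(3):
  a = Tensor.rand(16, device="CLANG").realize()
  out = jit_w2(a)
  # before this fix the add kernel could be pruned into the run-once set, because the copy
  # of `a` to the default device was not tracked as input-dependent, so `out` stopped following `a`
  np.testing.assert_allclose(out.tolist(), [w*2 + x for w, x in zip(weights.tolist(), a.tolist())])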


@@ -492,7 +492,43 @@ class TestCopyInsideJit(unittest.TestCase):
     a = Tensor.rand(16,16,device="CLANG").realize()
     b = Tensor.rand(16,16).realize()
     out = add(a,b)
-    self.assertListEqual(out.flatten().tolist(), [x+y for x,y in zip(a.flatten().tolist(), b.flatten().tolist())])
+    np.testing.assert_allclose(out.flatten().tolist(), [x+y for x,y in zip(a.flatten().tolist(), b.flatten().tolist())])
+
+class TestJitPrune(unittest.TestCase):
+  def test_simple_prune(self):
+    weights = Tensor.rand(16).realize()
+    def w2(x) -> Tensor: return (weights*2).contiguous() + x
+    w2_noprune = TinyJit(w2)
+    w2_prune = TinyJit(w2, prune=True)
+
+    for _ in range(3):
+      a = Tensor.rand(16).realize()
+      out = w2_noprune(a)
+      np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
+    assert len(w2_noprune.captured.jit_cache) == 2
+
+    for _ in range(3):
+      a = Tensor.rand(16).realize()
+      out = w2_prune(a)
+      np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
+    assert len(w2_prune.captured.jit_cache) == 1
+
+  def test_prune_w_copy_correct(self):
+    weights = Tensor.rand(16).realize()
+    def w2(x) -> Tensor: return (weights*2).contiguous() + x.to(Device.DEFAULT)
+    w2_noprune = TinyJit(w2)
+    w2_prune = TinyJit(w2, prune=True)
+
+    for _ in range(3):
+      a = Tensor.rand(16, device="CLANG").realize()
+      out = w2_noprune(a)
+      np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
+
+    for _ in range(3):
+      a = Tensor.rand(16, device="CLANG").realize()
+      out = w2_prune(a)
+      np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
+
 
 if __name__ == '__main__':
   unittest.main()


@@ -8,7 +8,7 @@ from tinygrad.device import Buffer, Compiled, Device
 from tinygrad.dtype import DType
 from tinygrad.ops import UOp, ssimplify, Variable, sint, sym_infer
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferXfer, CompiledRunner, Runner
+from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner
 from tinygrad.engine.memory import _internal_memory_planner
 from tinygrad.nn.state import get_parameters
 from dataclasses import dataclass

@@ -268,6 +268,8 @@ class TinyJit(Generic[ReturnType]):
         if any(b in depends for b in ei.bufs):
           if isinstance(ei.prg, CompiledRunner):
             depends.update(cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs)
+          if isinstance(ei.prg, (BufferCopy, BufferXfer)):
+            depends.add(cast(Buffer, ei.bufs[0]))
       pruned, onetime = partition(jit_cache,
         lambda ei: not isinstance(ei.prg, CompiledRunner) or any(ei.bufs[out] in depends for out in ei.prg.p.outs))
       if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
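
For context, the prune pass above walks the captured jit_cache once, starting from the buffers that are swapped for fresh inputs on every call, and marks everything downstream of them as input-dependent; items whose outputs never become input-dependent are split into the run-once set. The missing piece was that a BufferCopy/BufferXfer reading an input-dependent buffer did not mark its destination, so kernels consuming a copied input were wrongly dropped. Below is a minimal standalone sketch of that walk; Buf, Kernel, Copy and prune here are simplified stand-ins invented for illustration, not tinygrad's actual classes (those are Buffer, ExecItem with CompiledRunner/BufferCopy/BufferXfer programs, and the partition call in the diff).

# Simplified model of the dependency walk in TinyJit pruning (illustrative stand-ins, not tinygrad's API)
from dataclasses import dataclass

@dataclass(frozen=True)
class Buf: name: str

@dataclass
class Kernel:        # stands in for an ExecItem whose prg is a CompiledRunner
  bufs: list         # bufs[o] for o in outs are outputs, the rest are inputs
  outs: tuple = (0,)

@dataclass
class Copy:          # stands in for BufferCopy/BufferXfer: bufs[0] is dest, bufs[1] is src
  bufs: list

def prune(jit_cache, input_bufs):
  # propagate "depends on a JIT input" forward through the captured schedule
  depends = set(input_bufs)
  for ei in jit_cache:
    if any(b in depends for b in ei.bufs):
      if isinstance(ei, Kernel):
        depends.update(ei.bufs[o] for o in ei.outs)
      if isinstance(ei, Copy):    # the fix: a copy of an input-dependent buffer
        depends.add(ei.bufs[0])   # makes its destination input-dependent too
  # keep items whose outputs depend on an input (non-kernel items are always kept);
  # everything else ran once at capture time and its result is reused as a constant
  return [ei for ei in jit_cache
          if not isinstance(ei, Kernel) or any(ei.bufs[o] in depends for o in ei.outs)]

# weights*2 is input-independent (pruned); add(weights*2, copy(x)) must be kept
x, w, wout, xc, out = (Buf(n) for n in ("x", "weights", "w2_out", "x_copy", "out"))
cache = [Kernel([wout, w]), Copy([xc, x]), Kernel([out, wout, xc])]
assert len(prune(cache, {x})) == 2  # without the Copy branch, the final add would also be dropped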