mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix bug in jit prune with copy [pr] (#8073)
This commit is contained in:
@@ -492,7 +492,43 @@ class TestCopyInsideJit(unittest.TestCase):
|
|||||||
a = Tensor.rand(16,16,device="CLANG").realize()
|
a = Tensor.rand(16,16,device="CLANG").realize()
|
||||||
b = Tensor.rand(16,16).realize()
|
b = Tensor.rand(16,16).realize()
|
||||||
out = add(a,b)
|
out = add(a,b)
|
||||||
self.assertListEqual(out.flatten().tolist(), [x+y for x,y in zip(a.flatten().tolist(), b.flatten().tolist())])
|
np.testing.assert_allclose(out.flatten().tolist(), [x+y for x,y in zip(a.flatten().tolist(), b.flatten().tolist())])
|
||||||
|
|
||||||
|
class TestJitPrune(unittest.TestCase):
  """Compare a function wrapped in TinyJit with and without `prune=True`."""

  def test_simple_prune(self):
    """Both JIT variants must match the reference; their captured caches hold 2 vs 1 items."""
    weights = Tensor.rand(16).realize()

    def w2(x) -> Tensor:
      # the doubled weights do not depend on the argument x
      doubled = (weights*2).contiguous()
      return doubled + x

    w2_noprune = TinyJit(w2)
    w2_prune = TinyJit(w2, prune=True)

    for _ in range(3):
      inp = Tensor.rand(16).realize()
      result = w2_noprune(inp)
      got = result.tolist()
      want = [w*2+v for w,v in zip(weights.tolist(), inp.tolist())]
      np.testing.assert_allclose(got, want)
    # NOTE(review): cache-size check placed after the loop, since `captured` is read
    # only once the JIT has run -- confirm against the upstream file's indentation
    assert len(w2_noprune.captured.jit_cache) == 2

    for _ in range(3):
      inp = Tensor.rand(16).realize()
      result = w2_prune(inp)
      got = result.tolist()
      want = [w*2+v for w,v in zip(weights.tolist(), inp.tolist())]
      np.testing.assert_allclose(got, want)
    assert len(w2_prune.captured.jit_cache) == 1

  def test_prune_w_copy_correct(self):
    """Results stay correct when the jitted function moves its input with `.to(Device.DEFAULT)`."""
    weights = Tensor.rand(16).realize()

    def w2(x) -> Tensor:
      doubled = (weights*2).contiguous()
      moved = x.to(Device.DEFAULT)
      return doubled + moved

    w2_noprune = TinyJit(w2)
    w2_prune = TinyJit(w2, prune=True)

    for _ in range(3):
      inp = Tensor.rand(16, device="CLANG").realize()
      result = w2_noprune(inp)
      got = result.tolist()
      want = [w*2+v for w,v in zip(weights.tolist(), inp.tolist())]
      np.testing.assert_allclose(got, want)

    for _ in range(3):
      inp = Tensor.rand(16, device="CLANG").realize()
      result = w2_prune(inp)
      got = result.tolist()
      want = [w*2+v for w,v in zip(weights.tolist(), inp.tolist())]
      np.testing.assert_allclose(got, want)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from tinygrad.device import Buffer, Compiled, Device
|
|||||||
from tinygrad.dtype import DType
|
from tinygrad.dtype import DType
|
||||||
from tinygrad.ops import UOp, ssimplify, Variable, sint, sym_infer
|
from tinygrad.ops import UOp, ssimplify, Variable, sint, sym_infer
|
||||||
from tinygrad.shape.shapetracker import ShapeTracker
|
from tinygrad.shape.shapetracker import ShapeTracker
|
||||||
from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferXfer, CompiledRunner, Runner
|
from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner
|
||||||
from tinygrad.engine.memory import _internal_memory_planner
|
from tinygrad.engine.memory import _internal_memory_planner
|
||||||
from tinygrad.nn.state import get_parameters
|
from tinygrad.nn.state import get_parameters
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -268,6 +268,8 @@ class TinyJit(Generic[ReturnType]):
|
|||||||
if any(b in depends for b in ei.bufs):
|
if any(b in depends for b in ei.bufs):
|
||||||
if isinstance(ei.prg, CompiledRunner):
|
if isinstance(ei.prg, CompiledRunner):
|
||||||
depends.update(cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs)
|
depends.update(cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs)
|
||||||
|
if isinstance(ei.prg, (BufferCopy, BufferXfer)):
|
||||||
|
depends.add(cast(Buffer, ei.bufs[0]))
|
||||||
pruned, onetime = partition(jit_cache,
|
pruned, onetime = partition(jit_cache,
|
||||||
lambda ei: not isinstance(ei.prg, CompiledRunner) or any(ei.bufs[out] in depends for out in ei.prg.p.outs))
|
lambda ei: not isinstance(ei.prg, CompiledRunner) or any(ei.bufs[out] in depends for out in ei.prg.p.outs))
|
||||||
if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
|
if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
|
||||||
|
|||||||
Reference in New Issue
Block a user