Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
CACHECOLLECTING -> CAPTURING and don't capture clear_l2 (#5190)
Fixes first-time BEAM slowness.
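For illustration, here is a minimal standalone sketch of the behavior the new test below checks: wrapping work in Context(CAPTURING=0) inside a TinyJit function realizes it normally but keeps it out of the JIT cache. The import paths are an assumption based on tinygrad's layout at the time of this change.

from tinygrad import Tensor, TinyJit
from tinygrad.helpers import Context

@TinyJit
def add(a, b):
  # realized during the tracing calls, but not recorded into the JIT cache
  with Context(CAPTURING=0):
    Tensor.zeros(4, 4).contiguous().realize()
  return (a+b).realize()  # only this kernel is captured

for i in range(5):
  print(add(Tensor([i]), Tensor([i])).numpy())  # [2*i]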
@@ -77,6 +77,30 @@ class TestJit(unittest.TestCase):
       np.testing.assert_allclose(c.numpy(), 2*i)
     assert_jit_cache_len(add, 0)
 
+  def test_jit_not_capturing(self):
+    @TinyJit
+    def add(a, b):
+      Tensor.zeros(4, 4).contiguous().realize() # no-op kernel is captured
+      return (a+b).realize()
+    for i in range(5):
+      a = Tensor([i])
+      b = Tensor([i])
+      c = add(a, b)
+      np.testing.assert_allclose(c.numpy(), 2*i)
+    assert_jit_cache_len(add, 2)
+
+    @TinyJit
+    def add2(a, b):
+      with Context(CAPTURING=0): # not captured
+        Tensor.zeros(4, 4).contiguous().realize()
+      return (a+b).realize()
+    for i in range(5):
+      a = Tensor([i])
+      b = Tensor([i])
+      c = add2(a, b)
+      np.testing.assert_allclose(c.numpy(), 2*i)
+    assert_jit_cache_len(add2, 1)
+
   def test_jit_shape_mismatch(self):
     @TinyJit
     def add(a, b): return (a+b).realize()
@@ -1,7 +1,7 @@
 from typing import List, Dict, Optional, cast, Generator, Tuple
 import time
 from dataclasses import dataclass, replace
-from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int
+from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING
 from tinygrad.ops import BufferOps, LoadOps, LazyOp
 from tinygrad.device import Device, Buffer
 from tinygrad.shape.symbolic import Variable, sym_infer, sint
@@ -188,5 +188,5 @@ capturing: List = [] # put classes with an add method in here
 
 def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, int]]=None, do_update_stats=True):
   for ei in lower_schedule(schedule):
-    if len(capturing): capturing[0].add(ei)
+    if len(capturing) and CAPTURING: capturing[0].add(ei)
     ei.run(var_vals, do_update_stats=do_update_stats)
@@ -45,7 +45,7 @@ def _time_program(p:Program, lib:bytes, var_vals, rawbufs, early_stop=None, max_
   input_bufs = [rawbufs[i] for i,_ in car.p.globals]
   for _ in range(cnt):
     if clear_l2:
-      with Context(DEBUG=0, BEAM=0, CACHECOLLECTING=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
+      with Context(DEBUG=0, BEAM=0, CAPTURING=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
     tms.append(cast(float, car(input_bufs, var_vals, wait=True))*factor)
     if early_stop is not None and early_stop < tms[-1]: break
   return tms
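The "first-time BEAM slowness" in the commit message shows up when BEAM search runs while a TinyJit capture is in progress: before this change, the L2-clearing Tensor.ones(1024,1024) warm-up above was handed to the active capture on each timing iteration; with CAPTURING=0 it now stays out. A hedged sketch of the triggering pattern (the BEAM value and shapes are arbitrary, and imports assume tinygrad's layout at the time):

from tinygrad import Tensor, TinyJit
from tinygrad.helpers import Context

@TinyJit
def mm(a, b): return (a @ b).realize()

a, b = Tensor.rand(64, 64), Tensor.rand(64, 64)
with Context(BEAM=2):   # BEAM search happens while the JIT is still capturing
  for _ in range(3): mm(a, b).numpy()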
@@ -101,7 +101,7 @@ class ContextVar:
   def __lt__(self, x): return self.value < x
 
 DEBUG, IMAGE, BEAM, NOOPT, JIT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0), ContextVar("JIT", 1)
-WINO, THREEFRY, CACHECOLLECTING = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CACHECOLLECTING", 1)
+WINO, THREEFRY, CAPTURING = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CAPTURING", 1)
 GRAPH, GRAPHPATH, SAVE_SCHEDULE, RING = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net"), ContextVar("SAVE_SCHEDULE", 0), ContextVar("RING", 1)
 MULTIOUTPUT, PROFILE = ContextVar("MULTIOUTPUT", 1), ContextVar("PROFILE", 0)
 
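For reference, a simplified sketch of the ContextVar/Context pattern these declarations rely on (an illustrative re-implementation, not the actual tinygrad.helpers code): Context temporarily overrides the named ContextVars for the duration of a with block, which is why Context(CAPTURING=0) above only disables capture within its scope.

import contextlib

class ContextVar:
  _registry = {}
  def __init__(self, key, default):
    self.value = default
    ContextVar._registry[key] = self
  def __bool__(self): return bool(self.value)

@contextlib.contextmanager
def Context(**kwargs):
  # save the current values, override them, and restore on exit
  saved = {k: ContextVar._registry[k].value for k in kwargs}
  for k, v in kwargs.items(): ContextVar._registry[k].value = v
  try: yield
  finally:
    for k, v in saved.items(): ContextVar._registry[k].value = v

CAPTURING = ContextVar("CAPTURING", 1)
assert bool(CAPTURING)
with Context(CAPTURING=0):
  assert not bool(CAPTURING)  # capture disabled only inside the block
assert bool(CAPTURING)        # restored afterwards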