CACHECOLLECTING -> CAPTURING and don't capture clear_l2 (#5190)

fixed first time BEAM slowness
This commit is contained in:
chenyu
2024-06-27 12:32:28 -04:00
committed by GitHub
parent 01e8838b65
commit ad91962dcf
4 changed files with 28 additions and 4 deletions

View File

@@ -77,6 +77,30 @@ class TestJit(unittest.TestCase):
np.testing.assert_allclose(c.numpy(), 2*i)
assert_jit_cache_len(add, 0)
def test_jit_not_capturing(self):
@TinyJit
def add(a, b):
Tensor.zeros(4, 4).contiguous().realize() # no-op kernel is captured
return (a+b).realize()
for i in range(5):
a = Tensor([i])
b = Tensor([i])
c = add(a, b)
np.testing.assert_allclose(c.numpy(), 2*i)
assert_jit_cache_len(add, 2)
@TinyJit
def add2(a, b):
with Context(CAPTURING=0): # not captured
Tensor.zeros(4, 4).contiguous().realize()
return (a+b).realize()
for i in range(5):
a = Tensor([i])
b = Tensor([i])
c = add2(a, b)
np.testing.assert_allclose(c.numpy(), 2*i)
assert_jit_cache_len(add2, 1)
def test_jit_shape_mismatch(self):
@TinyJit
def add(a, b): return (a+b).realize()

View File

@@ -1,7 +1,7 @@
from typing import List, Dict, Optional, cast, Generator, Tuple
import time
from dataclasses import dataclass, replace
from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int
from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING
from tinygrad.ops import BufferOps, LoadOps, LazyOp
from tinygrad.device import Device, Buffer
from tinygrad.shape.symbolic import Variable, sym_infer, sint
@@ -188,5 +188,5 @@ capturing: List = [] # put classes with an add method in here
def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, int]]=None, do_update_stats=True):
for ei in lower_schedule(schedule):
if len(capturing): capturing[0].add(ei)
if len(capturing) and CAPTURING: capturing[0].add(ei)
ei.run(var_vals, do_update_stats=do_update_stats)

View File

@@ -45,7 +45,7 @@ def _time_program(p:Program, lib:bytes, var_vals, rawbufs, early_stop=None, max_
input_bufs = [rawbufs[i] for i,_ in car.p.globals]
for _ in range(cnt):
if clear_l2:
with Context(DEBUG=0, BEAM=0, CACHECOLLECTING=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
with Context(DEBUG=0, BEAM=0, CAPTURING=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
tms.append(cast(float, car(input_bufs, var_vals, wait=True))*factor)
if early_stop is not None and early_stop < tms[-1]: break
return tms

View File

@@ -101,7 +101,7 @@ class ContextVar:
def __lt__(self, x): return self.value < x
DEBUG, IMAGE, BEAM, NOOPT, JIT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0), ContextVar("JIT", 1)
WINO, THREEFRY, CACHECOLLECTING = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CACHECOLLECTING", 1)
WINO, THREEFRY, CAPTURING = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CAPTURING", 1)
GRAPH, GRAPHPATH, SAVE_SCHEDULE, RING = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net"), ContextVar("SAVE_SCHEDULE", 0), ContextVar("RING", 1)
MULTIOUTPUT, PROFILE = ContextVar("MULTIOUTPUT", 1), ContextVar("PROFILE", 0)