JIT_BATCH_SIZE is a ContextVar (#11228)

This commit is contained in:
uuuvn
2025-07-14 11:03:45 +00:00
committed by GitHub
parent c4a920d95c
commit b2cc6cfa1b
2 changed files with 3 additions and 2 deletions

View File

@@ -1,7 +1,7 @@
from typing import TypeVar, Generic, Callable, Union, cast, Optional, Any
import functools, collections
from tinygrad.tensor import Tensor
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, dedup, partition, unwrap
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, partition, unwrap
from tinygrad.device import Buffer, Compiled, Device, MultiBuffer
from tinygrad.dtype import DType
from tinygrad.uop.ops import UOp, Variable, sym_infer, Ops
@@ -198,7 +198,7 @@ class CapturedJit(Generic[ReturnType]):
if b is not None: b.ensure_allocated()
# create graph if needed
if JIT < 2:
self._jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals, max_batch_size=getenv("JIT_BATCH_SIZE", 32))
self._jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals, max_batch_size=JIT_BATCH_SIZE.value)
self._input_replace = get_input_replace(self._jit_cache, input_buffers)
self._first_run = False

View File

@@ -118,6 +118,7 @@ class ContextVar:
DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0)
JIT = ContextVar("JIT", 2 if platform.system() == 'Darwin' and ('Intel' in platform.processor() or 'i386' in platform.processor()) else 1)
JIT_BATCH_SIZE = ContextVar("JIT_BATCH_SIZE", 32)
WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
USE_TC, TC_SELECT, TC_OPT, AMX = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0)
TRANSCENDENTAL, TC_SEARCH_OVER_SHAPE, NOLOCALS = ContextVar("TRANSCENDENTAL", 1), ContextVar("TC_SEARCH_OVER_SHAPE", 1), ContextVar("NOLOCALS", 0)