mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-14 16:44:59 -05:00
JitItem -> ExecItem (#4146)
* JitItem -> ExecItem * execitem in realize * cleaner * JITRunner -> Runner
This commit is contained in:
@@ -5,11 +5,11 @@ from tinygrad.helpers import init_c_var, GraphException, getenv
|
||||
from tinygrad.device import CompiledASTRunner, update_stats, Buffer, MultiDeviceJITGraph, BufferXfer, Device, BufferOptions
|
||||
from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.engine.jit import JitItem, get_input_replace, get_jit_stats, \
|
||||
get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
|
||||
from tinygrad.engine.realize import ExecItem
|
||||
from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
|
||||
|
||||
class CUDAGraph(MultiDeviceJITGraph):
|
||||
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
|
||||
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
|
||||
# Check all jit items are compatible.
|
||||
if not all(isinstance(ji.prg, CompiledASTRunner) or isinstance(ji.prg, BufferXfer) for ji in jit_cache): raise GraphException
|
||||
|
||||
|
||||
Reference in New Issue
Block a user