mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-13 16:15:19 -05:00
JitItem -> ExecItem (#4146)
* JitItem -> ExecItem * execitem in realize * cleaner * JITRunner -> Runner
This commit is contained in:
@@ -5,11 +5,11 @@ from tinygrad.helpers import init_c_var, GraphException, getenv
|
||||
from tinygrad.device import CompiledASTRunner, update_stats, Buffer, MultiDeviceJITGraph, BufferXfer, Device, BufferOptions
|
||||
from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.engine.jit import JitItem, get_input_replace, get_jit_stats, \
|
||||
get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
|
||||
from tinygrad.engine.realize import ExecItem
|
||||
from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
|
||||
|
||||
class CUDAGraph(MultiDeviceJITGraph):
|
||||
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
|
||||
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
|
||||
# Check all jit items are compatible.
|
||||
if not all(isinstance(ji.prg, CompiledASTRunner) or isinstance(ji.prg, BufferXfer) for ji in jit_cache): raise GraphException
|
||||
|
||||
|
||||
@@ -5,8 +5,8 @@ from tinygrad.buffer import Buffer, BufferOptions
|
||||
from tinygrad.device import Compiled, CompiledASTRunner, BufferXfer, MultiDeviceJITGraph, update_stats, Device
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
|
||||
from tinygrad.engine.jit import JitItem, get_input_replace, get_jit_stats, \
|
||||
get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
|
||||
from tinygrad.engine.realize import ExecItem
|
||||
from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
|
||||
import tinygrad.runtime.autogen.hsa as hsa
|
||||
from tinygrad.runtime.driver.hsa import check, AQLQueue, AQL_PACKET_SIZE, EMPTY_SIGNAL
|
||||
|
||||
@@ -26,7 +26,7 @@ class VirtAQLQueue(AQLQueue):
|
||||
self.available_packet_slots -= 1
|
||||
|
||||
class HSAGraph(MultiDeviceJITGraph):
|
||||
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
|
||||
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
|
||||
self.jit_cache = jit_cache
|
||||
self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
|
||||
self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache) #type:ignore
|
||||
|
||||
@@ -3,12 +3,13 @@ import Metal
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import dedup, unwrap2, GraphException
|
||||
from tinygrad.device import Buffer, CompiledASTRunner, update_stats
|
||||
from tinygrad.engine.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims
|
||||
from tinygrad.engine.realize import ExecItem
|
||||
from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.runtime.ops_metal import MetalDevice, wait_check
|
||||
|
||||
class MetalGraph:
|
||||
def __init__(self, device:MetalDevice, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
|
||||
def __init__(self, device:MetalDevice, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
|
||||
if not all(isinstance(ji.prg, CompiledASTRunner) for ji in jit_cache): raise GraphException
|
||||
|
||||
self.jit_cache = jit_cache
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
import os, mmap, _posixshmem, io, functools
|
||||
from typing import Dict, List, Any, Optional
|
||||
from tinygrad.helpers import prod, OSX
|
||||
from tinygrad.device import Compiled, Allocator, JITRunner, Buffer
|
||||
from tinygrad.device import Compiled, Allocator, Runner, Buffer
|
||||
from tinygrad.ops import UnaryOps, LazyOp, BufferOps
|
||||
from tinygrad.shape.view import strides_for_shape
|
||||
|
||||
@@ -32,7 +32,7 @@ class DiskAllocator(Allocator):
|
||||
else:
|
||||
dest[:] = src._buf()
|
||||
|
||||
class DiskRunner(JITRunner):
|
||||
class DiskRunner(Runner):
|
||||
def __init__(self, ast:LazyOp):
|
||||
# two ASTs are allowed here.
|
||||
assert ast.op is BufferOps.STORE, "output of AST must be store"
|
||||
|
||||
Reference in New Issue
Block a user