JitItem -> ExecItem (#4146)

* JitItem -> ExecItem

* execitem in realize

* cleaner

* JITRunner -> Runner
This commit is contained in:
George Hotz
2024-04-11 08:24:57 -07:00
committed by GitHub
parent e79a11b99c
commit b7e281cf10
10 changed files with 55 additions and 58 deletions

View File

@@ -5,11 +5,11 @@ from tinygrad.helpers import init_c_var, GraphException, getenv
from tinygrad.device import CompiledASTRunner, update_stats, Buffer, MultiDeviceJITGraph, BufferXfer, Device, BufferOptions
from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
from tinygrad.shape.symbolic import Variable
from tinygrad.engine.jit import JitItem, get_input_replace, get_jit_stats, \
get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
from tinygrad.engine.realize import ExecItem
from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
class CUDAGraph(MultiDeviceJITGraph):
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
# Check all jit items are compatible.
if not all(isinstance(ji.prg, CompiledASTRunner) or isinstance(ji.prg, BufferXfer) for ji in jit_cache): raise GraphException

View File

@@ -5,8 +5,8 @@ from tinygrad.buffer import Buffer, BufferOptions
from tinygrad.device import Compiled, CompiledASTRunner, BufferXfer, MultiDeviceJITGraph, update_stats, Device
from tinygrad.shape.symbolic import Variable
from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
from tinygrad.engine.jit import JitItem, get_input_replace, get_jit_stats, \
get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
from tinygrad.engine.realize import ExecItem
from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
import tinygrad.runtime.autogen.hsa as hsa
from tinygrad.runtime.driver.hsa import check, AQLQueue, AQL_PACKET_SIZE, EMPTY_SIGNAL
@@ -26,7 +26,7 @@ class VirtAQLQueue(AQLQueue):
self.available_packet_slots -= 1
class HSAGraph(MultiDeviceJITGraph):
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
self.jit_cache = jit_cache
self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache) #type:ignore

View File

@@ -3,12 +3,13 @@ import Metal
from tinygrad.dtype import dtypes
from tinygrad.helpers import dedup, unwrap2, GraphException
from tinygrad.device import Buffer, CompiledASTRunner, update_stats
from tinygrad.engine.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims
from tinygrad.engine.realize import ExecItem
from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims
from tinygrad.shape.symbolic import Variable
from tinygrad.runtime.ops_metal import MetalDevice, wait_check
class MetalGraph:
def __init__(self, device:MetalDevice, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
def __init__(self, device:MetalDevice, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
if not all(isinstance(ji.prg, CompiledASTRunner) for ji in jit_cache): raise GraphException
self.jit_cache = jit_cache

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import os, mmap, _posixshmem, io, functools
from typing import Dict, List, Any, Optional
from tinygrad.helpers import prod, OSX
from tinygrad.device import Compiled, Allocator, JITRunner, Buffer
from tinygrad.device import Compiled, Allocator, Runner, Buffer
from tinygrad.ops import UnaryOps, LazyOp, BufferOps
from tinygrad.shape.view import strides_for_shape
@@ -32,7 +32,7 @@ class DiskAllocator(Allocator):
else:
dest[:] = src._buf()
class DiskRunner(JITRunner):
class DiskRunner(Runner):
def __init__(self, ast:LazyOp):
# two ASTs are allowed here.
assert ast.op is BufferOps.STORE, "output of AST must be store"