mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
fast resnet eval (#3135)
* fast resnet eval * fix HIP multidevice graph * neater expression for devices * lines * add decorator test
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import ctypes
|
||||
from typing import Any, Optional, Tuple, Dict, List, cast
|
||||
import gpuctypes.cuda as cuda
|
||||
from tinygrad.helpers import init_c_var, encode_args_cuda_style
|
||||
from tinygrad.helpers import init_c_var, encode_args_cuda_style, all_same
|
||||
from tinygrad.device import CompiledASTRunner, update_stats, Buffer
|
||||
from tinygrad.runtime.ops_cuda import check, cu_time_execution
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
@@ -9,7 +9,9 @@ from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_
|
||||
|
||||
class CUDAGraph:
|
||||
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
|
||||
if not all(isinstance(ji.prg, CompiledASTRunner) for ji in jit_cache): raise GraphException
|
||||
devices = [ji.prg.clprg.device if isinstance(ji.prg, CompiledASTRunner) else None for ji in jit_cache]
|
||||
if len(devices) == 0 or not all_same(devices) or devices[0] is None: raise GraphException
|
||||
self.device = devices[0]
|
||||
|
||||
self.jit_cache = jit_cache
|
||||
self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
|
||||
|
||||
Reference in New Issue
Block a user