fast resnet eval (#3135)

* fast resnet eval

* fix HIP multidevice graph

* neater expression for devices

* lines

* add decorator test
This commit is contained in:
George Hotz
2024-01-15 14:15:18 -08:00
committed by GitHub
parent b7b494e9b8
commit a464909d79
10 changed files with 222 additions and 56 deletions

View File

@@ -1,7 +1,7 @@
import ctypes
from typing import Any, Optional, Tuple, Dict, List, cast
import gpuctypes.cuda as cuda
from tinygrad.helpers import init_c_var, encode_args_cuda_style
from tinygrad.helpers import init_c_var, encode_args_cuda_style, all_same
from tinygrad.device import CompiledASTRunner, update_stats, Buffer
from tinygrad.runtime.ops_cuda import check, cu_time_execution
from tinygrad.shape.symbolic import Variable
@@ -9,7 +9,9 @@ from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_
class CUDAGraph:
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
if not all(isinstance(ji.prg, CompiledASTRunner) for ji in jit_cache): raise GraphException
devices = [ji.prg.clprg.device if isinstance(ji.prg, CompiledASTRunner) else None for ji in jit_cache]
if len(devices) == 0 or not all_same(devices) or devices[0] is None: raise GraphException
self.device = devices[0]
self.jit_cache = jit_cache
self.input_replace = get_input_replace(jit_cache, input_rawbuffers)