diff --git a/tinygrad/runtime/support/compiler_cuda.py b/tinygrad/runtime/support/compiler_cuda.py index c8efff4140..e0bc06726c 100644 --- a/tinygrad/runtime/support/compiler_cuda.py +++ b/tinygrad/runtime/support/compiler_cuda.py @@ -62,8 +62,10 @@ class NVCompiler(CUDACompiler): def __init__(self, arch:str): super().__init__(arch, cache_key="nv") def compile(self, src:str) -> bytes: return self._compile_program(src, nvrtc.nvrtcGetCUBIN, nvrtc.nvrtcGetCUBINSize) -class PTXCompiler(CUDACompiler): - def __init__(self, arch:str, cache_key="ptx"): super().__init__(arch, cache_key=cache_key) +class PTXCompiler(Compiler): + def __init__(self, arch:str, cache_key="ptx"): + self.arch = arch + super().__init__(f"compile_{cache_key}_{self.arch}") def compile(self, src:str) -> bytes: return src.replace("TARGET", self.arch).replace("VERSION", "7.8" if self.arch >= "sm_89" else "7.5").encode() class NVPTXCompiler(PTXCompiler): diff --git a/tinygrad/runtime/support/hcq.py b/tinygrad/runtime/support/hcq.py index f6eb60eba3..01aaaa491e 100644 --- a/tinygrad/runtime/support/hcq.py +++ b/tinygrad/runtime/support/hcq.py @@ -247,11 +247,12 @@ class HCQSignal: raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})") @contextlib.contextmanager -def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type=None, queue=None): +def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type:Optional[Type[HWQueue]]=None, queue:Optional[HWQueue]=None): st, en = (dev.signal_t(), dev.signal_t()) if enabled else (None, None) if enabled and queue is not None: queue.timestamp(st) elif enabled: + assert queue_type is not None queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(st).signal(dev.timeline_signal, dev.timeline_value).submit(dev) dev.timeline_value += 1 @@ -259,6 +260,7 @@ def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type=None, queue=None): finally: if enabled and queue is not None: queue.timestamp(en) elif enabled: + assert queue_type is not None queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(en).signal(dev.timeline_signal, dev.timeline_value).submit(dev) dev.timeline_value += 1 @@ -363,7 +365,8 @@ class HCQCompiled(Compiled, Generic[SignalType]): comp_queue_t:Type[HWQueue], copy_queue_t:Optional[Type[HWQueue]]): self.signal_t, self.hw_compute_queue_t, self.hw_copy_queue_t = signal_t, comp_queue_t, copy_queue_t self.timeline_value:int = 1 - self.timeline_signal, self._shadow_timeline_signal = self.signal_t(0, is_timeline=True), self.signal_t(0, is_timeline=True) + self.timeline_signal:SignalType = self.signal_t(0, is_timeline=True) + self._shadow_timeline_signal:SignalType = self.signal_t(0, is_timeline=True) self.sig_prof_records:List[Tuple[HCQSignal, HCQSignal, str, bool]] = [] self.raw_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, str, bool, Optional[Dict]]] = [] self.dep_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, HCQCompiled, bool, decimal.Decimal, decimal.Decimal, HCQCompiled, bool]] = []