more typing work [pr] (#8345)

This commit is contained in:
George Hotz
2024-12-19 21:46:35 -08:00
committed by GitHub
parent 9c77e9f9b7
commit 62e5d96446
9 changed files with 38 additions and 46 deletions

View File

@@ -1,11 +1,10 @@
from typing import Optional
import ctypes, subprocess, pathlib, tempfile
from tinygrad.device import Compiled, Compiler, MallocAllocator
from tinygrad.helpers import cpu_time_execution, cpu_objdump
from tinygrad.renderer.cstyle import ClangRenderer
class ClangCompiler(Compiler):
def __init__(self, cachekey="compile_clang", args:Optional[list[str]]=None, objdump_tool='objdump'):
def __init__(self, cachekey="compile_clang", args:list[str]|None=None, objdump_tool='objdump'):
self.args = ['-march=native'] if args is None else args
self.objdump_tool = objdump_tool
super().__init__(cachekey)

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import ctypes, ctypes.util, functools
from typing import Optional
from tinygrad.helpers import DEBUG, getenv, from_mv, init_c_var, init_c_struct_t
from tinygrad.device import Compiled, BufferSpec, LRUAllocator
from tinygrad.renderer.cstyle import CUDARenderer
@@ -19,7 +18,7 @@ def encode_args(args, vals) -> tuple[ctypes.Structure, ctypes.Array]:
ctypes.cast(ctypes.pointer(ctypes.c_size_t(ctypes.sizeof(c_args))), ctypes.c_void_p), ctypes.c_void_p(0))
return c_args, vargs
def cu_time_execution(cb, enable=False) -> Optional[float]:
def cu_time_execution(cb, enable=False) -> float|None:
if not enable: return cb()
evs = [init_c_var(cuda.CUevent(), lambda x: cuda.cuEventCreate(ctypes.byref(x), 0)) for _ in range(2)]
cuda.cuEventRecord(evs[0], None)
@@ -110,7 +109,7 @@ class CUDADevice(Compiled):
CUDADevice.peer_access = True
self.arch = f"sm_{major.value}{minor.value}"
self.pending_copyin: list[tuple[int, int, Optional[BufferSpec]]] = []
self.pending_copyin: list[tuple[int, int, BufferSpec|None]] = []
CUDADevice.devices.append(self)
from tinygrad.runtime.graph.cuda import CUDAGraph

View File

@@ -1,11 +1,10 @@
from typing import Any
from dataclasses import dataclass
import tinygrad.runtime.autogen.libc as libc
@dataclass(frozen=True)
class ElfSection: name:str; header:libc.Elf64_Shdr; content:bytes # noqa: E702
def elf_loader(blob:bytes, force_section_align:int=1) -> tuple[memoryview, list[ElfSection], Any]:
def elf_loader(blob:bytes, force_section_align:int=1) -> tuple[memoryview, list[ElfSection], list[tuple]]:
def _strtab(blob: bytes, idx: int) -> str: return blob[idx:blob.find(b'\x00', idx)].decode('utf-8')
header = libc.Elf64_Ehdr.from_buffer_copy(blob)

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Optional, Dict, cast, Type, TypeVar, Generic, Any
from typing import Optional, cast, Type, TypeVar, Generic, Any
import contextlib, decimal, statistics, time, ctypes, array
from tinygrad.helpers import PROFILE, from_mv, getenv, to_mv, round_up
from tinygrad.renderer import Renderer
@@ -308,7 +308,7 @@ class HCQCompiled(Compiled, Generic[SignalType]):
self.timeline_signal:SignalType = self.signal_t(value=0, timeline_for_device=self)
self._shadow_timeline_signal:SignalType = self.signal_t(value=0, timeline_for_device=self)
self.sig_prof_records:list[tuple[HCQSignal, HCQSignal, str, bool]] = []
self.raw_prof_records:list[tuple[decimal.Decimal, decimal.Decimal, str, bool, Optional[Dict]]] = []
self.raw_prof_records:list[tuple[decimal.Decimal, decimal.Decimal, str, bool, Optional[dict]]] = []
self.dep_prof_records:list[tuple[decimal.Decimal, decimal.Decimal, HCQCompiled, bool, decimal.Decimal, decimal.Decimal, HCQCompiled, bool]] = []
from tinygrad.runtime.graph.hcq import HCQGraph