diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py index e3f2ef237c..0830d5049c 100644 --- a/tinygrad/nn/state.py +++ b/tinygrad/nn/state.py @@ -1,8 +1,8 @@ import json, pathlib, zipfile, pickle, tarfile, struct, functools, io -from typing import Union, Optional, Any, Callable, BinaryIO, Iterable, TypeVar +from typing import Union, Optional, Any, Callable, BinaryIO, Iterable from tinygrad.tensor import Tensor from tinygrad.dtype import dtypes -from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm, round_up +from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm, round_up, T from tinygrad.shape.view import strides_for_shape from tinygrad.multi import MultiLazyBuffer @@ -35,10 +35,9 @@ safe_dtypes = {"BOOL":dtypes.bool, "I8":dtypes.int8, "U8":dtypes.uint8, "I16":dt "I64":dtypes.int64, "U64":dtypes.uint64, "F16":dtypes.float16, "BF16":dtypes.bfloat16, "F32":dtypes.float32, "F64":dtypes.float64} inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()} -R = TypeVar('R') -def accept_filename(func: Callable[[Tensor], R]) -> Callable[[Union[Tensor, str, pathlib.Path]], R]: +def accept_filename(func: Callable[[Tensor], T]) -> Callable[[Union[Tensor, str, pathlib.Path]], T]: @functools.wraps(func) - def wrapper(fn: Union[Tensor, str, pathlib.Path]) -> R: return func(Tensor(pathlib.Path(fn)) if not isinstance(fn, Tensor) else fn) + def wrapper(fn: Union[Tensor, str, pathlib.Path]) -> T: return func(Tensor(pathlib.Path(fn)) if not isinstance(fn, Tensor) else fn) return wrapper @accept_filename diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py index f97c13661e..9bb71dc2d0 100644 --- a/tinygrad/renderer/__init__.py +++ b/tinygrad/renderer/__init__.py @@ -1,9 +1,9 @@ from __future__ import annotations -from typing import Optional, Callable, Any +from typing import Optional, Callable import functools from dataclasses import dataclass, field, replace from tinygrad.helpers import to_function_name, dedup, prod -from tinygrad.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp +from tinygrad.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher from tinygrad.dtype import DType @dataclass(frozen=True) @@ -116,7 +116,7 @@ class Renderer: local_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now shared_max: int = 32768 tensor_cores: list[TensorCore] = [] - extra_matcher: Any = None + extra_matcher: Optional[PatternMatcher] = None code_for_op: dict[Ops, Callable] = {} def __reduce__(self): return self.__class__, () diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index 677a2c8078..8619aaa2ac 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -146,7 +146,7 @@ class AMDComputeQueue(HWQueue): self.indirect_cmd = [amd_gpu.PACKET3(amd_gpu.PACKET3_INDIRECT_BUFFER, 2), *data64_le(self.hw_page.va_addr), len(self._q) | amd_gpu.INDIRECT_BUFFER_VALID] - self._q = hw_view # type: ignore + self._q = hw_view return self def _submit(self, dev:AMDDevice): diff --git a/tinygrad/runtime/ops_cloud.py b/tinygrad/runtime/ops_cloud.py index 5ed31b020a..e1550c5eb4 100644 --- a/tinygrad/runtime/ops_cloud.py +++ b/tinygrad/runtime/ops_cloud.py @@ -5,7 +5,7 @@ # it should be a secure (example: no use of pickle) boundary. HTTP is used for RPC from __future__ import annotations -from typing import Optional, Any, DefaultDict +from typing import Optional, Any from collections import defaultdict from dataclasses import dataclass, field import multiprocessing, functools, http.client, hashlib, json, time, os, binascii, struct, ast, contextlib @@ -81,7 +81,7 @@ class CloudSession: class CloudHandler(BaseHTTPRequestHandler): protocol_version = 'HTTP/1.1' device: str - sessions: DefaultDict[str, CloudSession] = defaultdict(CloudSession) + sessions: defaultdict[str, CloudSession] = defaultdict(CloudSession) def setup(self): super().setup() diff --git a/tinygrad/runtime/ops_nv.py b/tinygrad/runtime/ops_nv.py index c503b56258..d2395e97db 100644 --- a/tinygrad/runtime/ops_nv.py +++ b/tinygrad/runtime/ops_nv.py @@ -111,7 +111,7 @@ class NVCommandQueue(HWQueue[NVSignal, 'NVDevice', 'NVProgram', 'NVArgsState']): for i, value in enumerate(self._q): hw_view[i] = value # From now on, the queue is on the device for faster submission. - self._q = hw_view # type: ignore + self._q = hw_view def _submit_to_gpfifo(self, dev:NVDevice, gpfifo:GPFifo): if dev == self.binded_device: cmdq_addr = self.hw_page.va_addr diff --git a/tinygrad/runtime/ops_qcom.py b/tinygrad/runtime/ops_qcom.py index bcdc9484c1..0b5041f914 100644 --- a/tinygrad/runtime/ops_qcom.py +++ b/tinygrad/runtime/ops_qcom.py @@ -97,7 +97,7 @@ class QCOMComputeQueue(HWQueue): self.hw_page = dev.allocator.alloc(len(self._q) * 4, BufferSpec(cpu_access=True, nolru=True)) self.submit_req, self.obj = self._build_gpu_command(self.binded_device, self.hw_page.va_addr) # From now on, the queue is on the device for faster submission. - self._q = to_mv(self.obj.gpuaddr, len(self._q) * 4).cast("I") # type: ignore + self._q = to_mv(self.obj.gpuaddr, len(self._q) * 4).cast("I") def _submit(self, dev:QCOMDevice): if self.binded_device == dev: submit_req = self.submit_req