unify T = TypeVar("T") (#7342)

This commit is contained in:
chenyu
2024-10-28 18:43:44 -04:00
committed by GitHub
parent 293adc141a
commit 6021bf87f4
3 changed files with 6 additions and 9 deletions

View File

@@ -1,5 +1,5 @@
import sys, time, logging, difflib
from typing import Callable, Optional, Tuple, TypeVar
from typing import Callable, Optional, Tuple
import numpy as np
from tinygrad import Tensor, Device, dtypes
from tinygrad.ops import UOp, UOps, sint
@@ -8,7 +8,7 @@ from tinygrad.tensor import _to_np_dtype
from tinygrad.engine.realize import Runner
from tinygrad.dtype import ConstType, DType
from tinygrad.nn.state import get_parameters
from tinygrad.helpers import CI, OSX, getenv, colored
from tinygrad.helpers import CI, OSX, T, getenv, colored
def derandomize_model(model):
for p in get_parameters(model):
@@ -73,7 +73,6 @@ def ast_const(dtype:DType, val:ConstType, shape:Tuple[sint, ...]=(), st:Optional
st_src = (st.to_uop() if st is not None else ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),)
return UOp(UOps.VALID, dtypes.bool, st_src).where(UOp.const(dtype, val), UOp.const(dtype, 0))
T = TypeVar("T")
def timeit(fxn:Callable[..., T], *args, **kwargs) -> Tuple[T, float]:
st = time.perf_counter_ns()
ret = fxn(*args, **kwargs)

View File

@@ -1,12 +1,12 @@
from __future__ import annotations
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Type, TypeVar, DefaultDict
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Type, DefaultDict
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib
from enum import auto, IntEnum, Enum
from dataclasses import dataclass, field
from collections import defaultdict
from weakref import WeakValueDictionary
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp
from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp, T
if TYPE_CHECKING:
from tinygrad.shape.shapetracker import ShapeTracker
@@ -36,7 +36,6 @@ class MetaOps(FastEnum):
EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); ASSIGN = auto(); VIEW = auto() # noqa: E702
Op = Union[UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps]
T = TypeVar("T")
class MathTrait:
# required to implement
def alu(self:T, arg:Union[UnaryOps, BinaryOps, TernaryOps], *src) -> T: raise NotImplementedError

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import os, subprocess, pathlib, ctypes, tempfile, functools
from typing import List, Any, Tuple, Optional, cast, TypeVar
from tinygrad.helpers import prod, getenv, DEBUG
from typing import List, Any, Tuple, Optional, cast
from tinygrad.helpers import prod, getenv, DEBUG, T
from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator
from tinygrad.renderer.cstyle import MetalRenderer
@@ -32,7 +32,6 @@ libobjc.sel_registerName.restype = objc_id
libmetal.MTLCreateSystemDefaultDevice.restype = objc_instance
libdispatch.dispatch_data_create.restype = objc_instance
T = TypeVar("T")
# Ignore mypy error reporting incompatible default, because typevar default only works on python 3.12
def msg(ptr: objc_id, selector: str, /, *args: Any, restype: type[T] = objc_id) -> T: # type: ignore [assignment]
sender = libobjc["objc_msgSend"] # Using attribute access returns a new reference so setting restype is safe