mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
unify T = TypeVar("T") (#7342)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import sys, time, logging, difflib
|
||||
from typing import Callable, Optional, Tuple, TypeVar
|
||||
from typing import Callable, Optional, Tuple
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.ops import UOp, UOps, sint
|
||||
@@ -8,7 +8,7 @@ from tinygrad.tensor import _to_np_dtype
|
||||
from tinygrad.engine.realize import Runner
|
||||
from tinygrad.dtype import ConstType, DType
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.helpers import CI, OSX, getenv, colored
|
||||
from tinygrad.helpers import CI, OSX, T, getenv, colored
|
||||
|
||||
def derandomize_model(model):
|
||||
for p in get_parameters(model):
|
||||
@@ -73,7 +73,6 @@ def ast_const(dtype:DType, val:ConstType, shape:Tuple[sint, ...]=(), st:Optional
|
||||
st_src = (st.to_uop() if st is not None else ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),)
|
||||
return UOp(UOps.VALID, dtypes.bool, st_src).where(UOp.const(dtype, val), UOp.const(dtype, 0))
|
||||
|
||||
T = TypeVar("T")
|
||||
def timeit(fxn:Callable[..., T], *args, **kwargs) -> Tuple[T, float]:
|
||||
st = time.perf_counter_ns()
|
||||
ret = fxn(*args, **kwargs)
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Type, TypeVar, DefaultDict
|
||||
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Type, DefaultDict
|
||||
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib
|
||||
from enum import auto, IntEnum, Enum
|
||||
from dataclasses import dataclass, field
|
||||
from collections import defaultdict
|
||||
from weakref import WeakValueDictionary
|
||||
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
|
||||
from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp
|
||||
from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp, T
|
||||
if TYPE_CHECKING:
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
|
||||
@@ -36,7 +36,6 @@ class MetaOps(FastEnum):
|
||||
EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); ASSIGN = auto(); VIEW = auto() # noqa: E702
|
||||
Op = Union[UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps]
|
||||
|
||||
T = TypeVar("T")
|
||||
class MathTrait:
|
||||
# required to implement
|
||||
def alu(self:T, arg:Union[UnaryOps, BinaryOps, TernaryOps], *src) -> T: raise NotImplementedError
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import os, subprocess, pathlib, ctypes, tempfile, functools
|
||||
from typing import List, Any, Tuple, Optional, cast, TypeVar
|
||||
from tinygrad.helpers import prod, getenv, DEBUG
|
||||
from typing import List, Any, Tuple, Optional, cast
|
||||
from tinygrad.helpers import prod, getenv, DEBUG, T
|
||||
from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator
|
||||
from tinygrad.renderer.cstyle import MetalRenderer
|
||||
|
||||
@@ -32,7 +32,6 @@ libobjc.sel_registerName.restype = objc_id
|
||||
libmetal.MTLCreateSystemDefaultDevice.restype = objc_instance
|
||||
libdispatch.dispatch_data_create.restype = objc_instance
|
||||
|
||||
T = TypeVar("T")
|
||||
# Ignore mypy error reporting incompatible default, because typevar default only works on python 3.12
|
||||
def msg(ptr: objc_id, selector: str, /, *args: Any, restype: type[T] = objc_id) -> T: # type: ignore [assignment]
|
||||
sender = libobjc["objc_msgSend"] # Using attribute access returns a new reference so setting restype is safe
|
||||
|
||||
Reference in New Issue
Block a user