mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
prepare mypy==1.13.0: legacy cast (#7866)
* use helper to narrow literal type * narrow with asserts instead of cast * remove parantheses * tensor.item() calls tensor.data() * no copy * proper indexing
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
from __future__ import annotations
|
||||
from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union, Callable
|
||||
from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union, Callable, Literal
|
||||
import math, struct, ctypes, functools
|
||||
from dataclasses import dataclass, fields
|
||||
from tinygrad.helpers import getenv
|
||||
|
||||
ConstType = Union[float, int, bool]
|
||||
|
||||
FmtStr = Literal['?', 'b', 'B', 'h', 'H', 'i', 'I', 'q', 'Q', 'e', 'f', 'd']
|
||||
|
||||
# all DTypes should only be created once
|
||||
class DTypeMetaClass(type):
|
||||
dcache: Dict[Tuple, DType] = {}
|
||||
@@ -19,11 +21,11 @@ class DType(metaclass=DTypeMetaClass):
|
||||
priority: int # this determines when things get upcasted
|
||||
itemsize: int
|
||||
name: str
|
||||
fmt: Optional[str]
|
||||
fmt: Optional[FmtStr]
|
||||
count: int
|
||||
_scalar: Optional[DType]
|
||||
@staticmethod
|
||||
def new(priority:int, itemsize:int, name:str, fmt:Optional[str]): return DType(priority, itemsize, name, fmt, 1, None)
|
||||
def new(priority:int, itemsize:int, name:str, fmt:Optional[FmtStr]): return DType(priority, itemsize, name, fmt, 1, None)
|
||||
def __reduce__(self): return type(self), tuple(getattr(self, f.name) for f in fields(self))
|
||||
def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self.scalar().name]}"+(f".vec({self.count})" if self.count > 1 else "")
|
||||
def __lt__(self, o:DType): return (self.priority, self.itemsize, self.name, self.fmt, self.count) < (o.priority, o.itemsize, o.name, o.fmt, o.count)
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
# a python uops emulator
|
||||
# works to test the tensor cores, and all the uops in general
|
||||
# this is the (living) definition of uops
|
||||
from typing import Tuple, List, Optional, Any, Dict
|
||||
import sys
|
||||
from typing import Tuple, List, Optional, Any, Dict, TYPE_CHECKING
|
||||
import pickle, base64, itertools, time, struct
|
||||
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate
|
||||
from tinygrad.helpers import all_same, getenv, flatten
|
||||
@@ -66,13 +67,11 @@ class PythonProgram:
|
||||
continue
|
||||
assert dtype is not None, f"{uop} is missing a dtype"
|
||||
dl[i] = dtype
|
||||
if uop is Ops.DEFINE_GLOBAL:
|
||||
if uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL}:
|
||||
assert dtype.fmt is not None
|
||||
ul[i] = [pbufs.pop(0).cast(dtype.fmt)] * warp_size
|
||||
elif uop is Ops.DEFINE_LOCAL:
|
||||
assert dtype.fmt is not None
|
||||
lbuf = memoryview(bytearray(arg[1]*dtype.itemsize))
|
||||
ul[i] = [lbuf.cast(dtype.fmt)] * warp_size
|
||||
if TYPE_CHECKING or sys.version_info < (3, 12): assert dtype.fmt != "e"
|
||||
buf = memoryview(bytearray(arg[1]*dtype.itemsize)) if uop is Ops.DEFINE_LOCAL else pbufs.pop(0)
|
||||
ul[i] = [buf.cast(dtype.fmt)] * warp_size
|
||||
elif uop is Ops.DEFINE_VAR:
|
||||
ul[i] = [pvals.pop(0)] * warp_size
|
||||
elif uop is Ops.SPECIAL:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses, hashlib
|
||||
from contextlib import ContextDecorator
|
||||
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Literal
|
||||
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Literal, TYPE_CHECKING
|
||||
from collections import defaultdict
|
||||
|
||||
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
|
||||
@@ -273,6 +273,7 @@ class Tensor(SimpleMathTrait):
|
||||
"""
|
||||
assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
|
||||
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
|
||||
if TYPE_CHECKING or sys.version_info < (3, 12): assert self.dtype.fmt != "e"
|
||||
return self._data().cast(self.dtype.fmt) if 0 in self.shape else self._data().cast(self.dtype.fmt, self.shape)
|
||||
|
||||
def item(self) -> ConstType:
|
||||
@@ -284,9 +285,8 @@ class Tensor(SimpleMathTrait):
|
||||
print(t.item())
|
||||
```
|
||||
"""
|
||||
assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
|
||||
assert self.numel() == 1, "must have one element for item"
|
||||
return self._data().cast(self.dtype.fmt)[0]
|
||||
return self.data()[(0,) * len(self.shape)]
|
||||
|
||||
# TODO: should be Tensor.tolist() -> Union[List[ConstType], ConstType]. The List is Sequence because mypy expects memoryview.tolist() -> list[int]
|
||||
# src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803
|
||||
|
||||
Reference in New Issue
Block a user