prepare mypy==1.13.0: legacy cast (#7866)

* use helper to narrow literal type

* narrow with asserts instead of cast

* remove parantheses

* tensor.item() calls tensor.data()

* no copy

* proper indexing
This commit is contained in:
JaSpa99
2024-11-27 16:33:35 +01:00
committed by GitHub
parent 753f07e193
commit 38f34ca0cb
3 changed files with 14 additions and 13 deletions

View File

@@ -1,11 +1,13 @@
from __future__ import annotations
from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union, Callable
from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union, Callable, Literal
import math, struct, ctypes, functools
from dataclasses import dataclass, fields
from tinygrad.helpers import getenv
ConstType = Union[float, int, bool]
FmtStr = Literal['?', 'b', 'B', 'h', 'H', 'i', 'I', 'q', 'Q', 'e', 'f', 'd']
# all DTypes should only be created once
class DTypeMetaClass(type):
dcache: Dict[Tuple, DType] = {}
@@ -19,11 +21,11 @@ class DType(metaclass=DTypeMetaClass):
priority: int # this determines when things get upcasted
itemsize: int
name: str
fmt: Optional[str]
fmt: Optional[FmtStr]
count: int
_scalar: Optional[DType]
@staticmethod
def new(priority:int, itemsize:int, name:str, fmt:Optional[str]): return DType(priority, itemsize, name, fmt, 1, None)
def new(priority:int, itemsize:int, name:str, fmt:Optional[FmtStr]): return DType(priority, itemsize, name, fmt, 1, None)
def __reduce__(self): return type(self), tuple(getattr(self, f.name) for f in fields(self))
def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self.scalar().name]}"+(f".vec({self.count})" if self.count > 1 else "")
def __lt__(self, o:DType): return (self.priority, self.itemsize, self.name, self.fmt, self.count) < (o.priority, o.itemsize, o.name, o.fmt, o.count)