prepare mypy==1.13.0: legacy cast (#7866)

* use helper to narrow literal type

* narrow with asserts instead of cast

* remove parantheses

* tensor.item() calls tensor.data()

* no copy

* proper indexing
This commit is contained in:
JaSpa99
2024-11-27 16:33:35 +01:00
committed by GitHub
parent 753f07e193
commit 38f34ca0cb
3 changed files with 14 additions and 13 deletions

View File

@@ -1,11 +1,13 @@
from __future__ import annotations
from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union, Callable
from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union, Callable, Literal
import math, struct, ctypes, functools
from dataclasses import dataclass, fields
from tinygrad.helpers import getenv
ConstType = Union[float, int, bool]
FmtStr = Literal['?', 'b', 'B', 'h', 'H', 'i', 'I', 'q', 'Q', 'e', 'f', 'd']
# all DTypes should only be created once
class DTypeMetaClass(type):
dcache: Dict[Tuple, DType] = {}
@@ -19,11 +21,11 @@ class DType(metaclass=DTypeMetaClass):
priority: int # this determines when things get upcasted
itemsize: int
name: str
fmt: Optional[str]
fmt: Optional[FmtStr]
count: int
_scalar: Optional[DType]
@staticmethod
def new(priority:int, itemsize:int, name:str, fmt:Optional[str]): return DType(priority, itemsize, name, fmt, 1, None)
def new(priority:int, itemsize:int, name:str, fmt:Optional[FmtStr]): return DType(priority, itemsize, name, fmt, 1, None)
def __reduce__(self): return type(self), tuple(getattr(self, f.name) for f in fields(self))
def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self.scalar().name]}"+(f".vec({self.count})" if self.count > 1 else "")
def __lt__(self, o:DType): return (self.priority, self.itemsize, self.name, self.fmt, self.count) < (o.priority, o.itemsize, o.name, o.fmt, o.count)

View File

@@ -2,7 +2,8 @@
# a python uops emulator
# works to test the tensor cores, and all the uops in general
# this is the (living) definition of uops
from typing import Tuple, List, Optional, Any, Dict
import sys
from typing import Tuple, List, Optional, Any, Dict, TYPE_CHECKING
import pickle, base64, itertools, time, struct
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate
from tinygrad.helpers import all_same, getenv, flatten
@@ -66,13 +67,11 @@ class PythonProgram:
continue
assert dtype is not None, f"{uop} is missing a dtype"
dl[i] = dtype
if uop is Ops.DEFINE_GLOBAL:
if uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL}:
assert dtype.fmt is not None
ul[i] = [pbufs.pop(0).cast(dtype.fmt)] * warp_size
elif uop is Ops.DEFINE_LOCAL:
assert dtype.fmt is not None
lbuf = memoryview(bytearray(arg[1]*dtype.itemsize))
ul[i] = [lbuf.cast(dtype.fmt)] * warp_size
if TYPE_CHECKING or sys.version_info < (3, 12): assert dtype.fmt != "e"
buf = memoryview(bytearray(arg[1]*dtype.itemsize)) if uop is Ops.DEFINE_LOCAL else pbufs.pop(0)
ul[i] = [buf.cast(dtype.fmt)] * warp_size
elif uop is Ops.DEFINE_VAR:
ul[i] = [pvals.pop(0)] * warp_size
elif uop is Ops.SPECIAL:

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses, hashlib
from contextlib import ContextDecorator
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Literal
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Literal, TYPE_CHECKING
from collections import defaultdict
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
@@ -273,6 +273,7 @@ class Tensor(SimpleMathTrait):
"""
assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
if TYPE_CHECKING or sys.version_info < (3, 12): assert self.dtype.fmt != "e"
return self._data().cast(self.dtype.fmt) if 0 in self.shape else self._data().cast(self.dtype.fmt, self.shape)
def item(self) -> ConstType:
@@ -284,9 +285,8 @@ class Tensor(SimpleMathTrait):
print(t.item())
```
"""
assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
assert self.numel() == 1, "must have one element for item"
return self._data().cast(self.dtype.fmt)[0]
return self.data()[(0,) * len(self.shape)]
# TODO: should be Tensor.tolist() -> Union[List[ConstType], ConstType]. The List is Sequence because mypy expects memoryview.tolist() -> list[int]
# src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803