diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index 9639e0595b..7b96d707fd 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -1,11 +1,13 @@ from __future__ import annotations -from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union, Callable +from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union, Callable, Literal import math, struct, ctypes, functools from dataclasses import dataclass, fields from tinygrad.helpers import getenv ConstType = Union[float, int, bool] +FmtStr = Literal['?', 'b', 'B', 'h', 'H', 'i', 'I', 'q', 'Q', 'e', 'f', 'd'] + # all DTypes should only be created once class DTypeMetaClass(type): dcache: Dict[Tuple, DType] = {} @@ -19,11 +21,11 @@ class DType(metaclass=DTypeMetaClass): priority: int # this determines when things get upcasted itemsize: int name: str - fmt: Optional[str] + fmt: Optional[FmtStr] count: int _scalar: Optional[DType] @staticmethod - def new(priority:int, itemsize:int, name:str, fmt:Optional[str]): return DType(priority, itemsize, name, fmt, 1, None) + def new(priority:int, itemsize:int, name:str, fmt:Optional[FmtStr]): return DType(priority, itemsize, name, fmt, 1, None) def __reduce__(self): return type(self), tuple(getattr(self, f.name) for f in fields(self)) def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self.scalar().name]}"+(f".vec({self.count})" if self.count > 1 else "") def __lt__(self, o:DType): return (self.priority, self.itemsize, self.name, self.fmt, self.count) < (o.priority, o.itemsize, o.name, o.fmt, o.count) diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index 1f4ef3ff54..3076cd747f 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -2,7 +2,8 @@ # a python uops emulator # works to test the tensor cores, and all the uops in general # this is the (living) definition of uops -from typing import Tuple, List, Optional, Any, Dict +import sys +from typing import Tuple, List, Optional, Any, Dict, TYPE_CHECKING import pickle, base64, itertools, time, struct from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate from tinygrad.helpers import all_same, getenv, flatten @@ -66,13 +67,11 @@ class PythonProgram: continue assert dtype is not None, f"{uop} is missing a dtype" dl[i] = dtype - if uop is Ops.DEFINE_GLOBAL: + if uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL}: assert dtype.fmt is not None - ul[i] = [pbufs.pop(0).cast(dtype.fmt)] * warp_size - elif uop is Ops.DEFINE_LOCAL: - assert dtype.fmt is not None - lbuf = memoryview(bytearray(arg[1]*dtype.itemsize)) - ul[i] = [lbuf.cast(dtype.fmt)] * warp_size + if TYPE_CHECKING or sys.version_info < (3, 12): assert dtype.fmt != "e" + buf = memoryview(bytearray(arg[1]*dtype.itemsize)) if uop is Ops.DEFINE_LOCAL else pbufs.pop(0) + ul[i] = [buf.cast(dtype.fmt)] * warp_size elif uop is Ops.DEFINE_VAR: ul[i] = [pvals.pop(0)] * warp_size elif uop is Ops.SPECIAL: diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index d09539d498..85ca7f90a3 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2,7 +2,7 @@ from __future__ import annotations import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses, hashlib from contextlib import ContextDecorator -from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Literal +from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Literal, TYPE_CHECKING from collections import defaultdict from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate @@ -273,6 +273,7 @@ class Tensor(SimpleMathTrait): """ assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}" assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}" + if TYPE_CHECKING or sys.version_info < (3, 12): assert self.dtype.fmt != "e" return self._data().cast(self.dtype.fmt) if 0 in self.shape else self._data().cast(self.dtype.fmt, self.shape) def item(self) -> ConstType: @@ -284,9 +285,8 @@ class Tensor(SimpleMathTrait): print(t.item()) ``` """ - assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}" assert self.numel() == 1, "must have one element for item" - return self._data().cast(self.dtype.fmt)[0] + return self.data()[(0,) * len(self.shape)] # TODO: should be Tensor.tolist() -> Union[List[ConstType], ConstType]. The List is Sequence because mypy expects memoryview.tolist() -> list[int] # src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803