diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index bf545219b2..affb5ded4b 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -1,3 +1,4 @@ +from __future__ import annotations from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union from dataclasses import dataclass import functools @@ -16,7 +17,7 @@ class DType: def vec(self, sz:int): assert sz > 1 and self.count == 1, f"can't vectorize {self} with size {sz}" return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz) - def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self + def scalar(self) -> DType: return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self # dependent typing? @dataclass(frozen=True, repr=False) @@ -70,6 +71,7 @@ class dtypes: @staticmethod def fields() -> Dict[str, DType]: return DTYPES_DICT # TODO: priority should be higher than bool + void: Final[DType] = DType(-1, 0, "void", None, 1) pyint: Final[DType] = DType(-1, 8, "pyint", None, 1) # arbitrary precision integer, same itemsize to int64 so min/max works bool: Final[DType] = DType(0, 1, "bool", '?', 1) int8: Final[DType] = DType(1, 1, "char", 'b', 1) @@ -123,9 +125,10 @@ def least_upper_dtype(*ds:DType) -> DType: def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32) # HACK: staticmethods are not callable in 3.8 so we have to compare the class -DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default', 'pyint')) or v.__class__ is staticmethod)} +DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default', 'pyint', 'void')) or v.__class__ is staticmethod)} INVERSE_DTYPES_DICT = {v.name:k for k,v in DTYPES_DICT.items()} INVERSE_DTYPES_DICT['pyint'] = 'pyint' +INVERSE_DTYPES_DICT['void'] = 'void' def sum_acc_dtype(dt:DType): # default acc dtype for sum