init changes from the dtypes_void branch [run_process_replay] (#6475)

This commit is contained in:
qazal
2024-09-11 16:34:50 +08:00
committed by GitHub
parent d6d9234985
commit 78148e16d8

View File

@@ -1,3 +1,4 @@
from __future__ import annotations
from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union
from dataclasses import dataclass
import functools
@@ -16,7 +17,7 @@ class DType:
def vec(self, sz:int):
assert sz > 1 and self.count == 1, f"can't vectorize {self} with size {sz}"
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
def scalar(self) -> DType: return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
# dependent typing?
@dataclass(frozen=True, repr=False)
@@ -70,6 +71,7 @@ class dtypes:
@staticmethod
def fields() -> Dict[str, DType]: return DTYPES_DICT
# TODO: priority should be higher than bool
void: Final[DType] = DType(-1, 0, "void", None, 1)
pyint: Final[DType] = DType(-1, 8, "pyint", None, 1) # arbitrary precision integer, same itemsize to int64 so min/max works
bool: Final[DType] = DType(0, 1, "bool", '?', 1)
int8: Final[DType] = DType(1, 1, "char", 'b', 1)
@@ -123,9 +125,10 @@ def least_upper_dtype(*ds:DType) -> DType:
def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32)
# HACK: staticmethods are not callable in 3.8 so we have to compare the class
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default', 'pyint')) or v.__class__ is staticmethod)}
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default', 'pyint', 'void')) or v.__class__ is staticmethod)}
INVERSE_DTYPES_DICT = {v.name:k for k,v in DTYPES_DICT.items()}
INVERSE_DTYPES_DICT['pyint'] = 'pyint'
INVERSE_DTYPES_DICT['void'] = 'void'
def sum_acc_dtype(dt:DType):
# default acc dtype for sum