put local on the PtrDtype [run_process_replay] (#6656)

* put local on the PtrDtype [run_process_replay]

* those are local too
This commit is contained in:
George Hotz
2024-09-23 10:29:17 +08:00
committed by GitHub
parent 90c1ccc402
commit e945fa9c5c
7 changed files with 23 additions and 17 deletions

View File

@@ -25,17 +25,20 @@ class DType:
class ImageDType(DType):
shape: Tuple[int, ...] # arbitrary arg for the dtype, used in image for the shape
base: DType
local: bool = False # images are never local
def scalar(self): return self.base
def vec(self, sz:int): return self.base.vec(sz)
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
# @dataclass(frozen=True, init=False, repr=False, eq=False)
class PtrDType(DType):
def __init__(self, dt:DType): super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count)
def __init__(self, dt:DType, local=False):
self.base, self.local = dt, local
super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count)
def __hash__(self): return super().__hash__()
def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count
def __ne__(self, dt): return not (self == dt)
def __repr__(self): return f"PtrDType({super().__repr__()})"
def __repr__(self): return f"PtrDType({super().__repr__()}, local=True)" if self.local else f"PtrDType({super().__repr__()})"
class dtypes:
@staticmethod