mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-08 13:45:50 -05:00
put local on the PtrDtype [run_process_replay] (#6656)
* put local on the PtrDtype [run_process_replay] * those are local too
This commit is contained in:
@@ -25,17 +25,20 @@ class DType:
|
||||
class ImageDType(DType):
|
||||
shape: Tuple[int, ...] # arbitrary arg for the dtype, used in image for the shape
|
||||
base: DType
|
||||
local: bool = False # images are never local
|
||||
def scalar(self): return self.base
|
||||
def vec(self, sz:int): return self.base.vec(sz)
|
||||
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
|
||||
|
||||
# @dataclass(frozen=True, init=False, repr=False, eq=False)
|
||||
class PtrDType(DType):
|
||||
def __init__(self, dt:DType): super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count)
|
||||
def __init__(self, dt:DType, local=False):
|
||||
self.base, self.local = dt, local
|
||||
super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count)
|
||||
def __hash__(self): return super().__hash__()
|
||||
def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count
|
||||
def __ne__(self, dt): return not (self == dt)
|
||||
def __repr__(self): return f"PtrDType({super().__repr__()})"
|
||||
def __repr__(self): return f"PtrDType({super().__repr__()}, local=True)" if self.local else f"PtrDType({super().__repr__()})"
|
||||
|
||||
class dtypes:
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user