mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-08 05:35:11 -05:00
Restore vcount [pr] (#7390)
* Revert "Revert "add vcount to PtrDtype (#7388)"" This reverts commit399a5219dd. * Revert "Revert "add tests to vcount stuff [pr] (#7389)"" This reverts commitcc8d6dbdf3. * no ptr
This commit is contained in:
@@ -15,12 +15,14 @@ class DType:
|
||||
count: int
|
||||
def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self.scalar().name]}"+(f".vec({self.count})" if self.count > 1 else "")
|
||||
def __lt__(self, o:DType): return (self.priority, self.itemsize, self.name, self.fmt, self.count) < (o.priority, o.itemsize, o.name, o.fmt, o.count)
|
||||
@property
|
||||
def vcount(self): return self.count
|
||||
def vec(self, sz:int):
|
||||
assert self.count == 1, f"can't vectorize {self} with size {sz}"
|
||||
if sz == 1 or self.name == 'void': return self # void doesn't vectorize, and sz=1 is scalar
|
||||
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
|
||||
def ptr(self, local=False) -> Union[PtrDType, ImageDType]:
|
||||
return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, local)
|
||||
def ptr(self, local=False, v=1) -> Union[PtrDType, ImageDType]:
|
||||
return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, local, v)
|
||||
def scalar(self) -> DType: return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -30,18 +32,25 @@ class ImageDType(DType):
|
||||
local: bool = False # images are never local
|
||||
def scalar(self) -> DType: return self.base
|
||||
def vec(self, sz:int): return self.base.vec(sz)
|
||||
def ptr(self, local=False) -> Union[PtrDType, ImageDType]: return self
|
||||
def ptr(self, local=False, v=1) -> Union[PtrDType, ImageDType]: return self
|
||||
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PtrDType(DType):
|
||||
base: DType
|
||||
local: bool
|
||||
v: int
|
||||
def __hash__(self): return super().__hash__()
|
||||
def scalar(self) -> DType: return self.base.ptr(self.local, 1)
|
||||
def vec(self, sz:int) -> DType: return self.base.ptr(self.local, sz)
|
||||
@property
|
||||
def vcount(self): return self.v
|
||||
# local isn't used in the compare
|
||||
def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count
|
||||
def __ne__(self, dt): return not (self == dt)
|
||||
def __repr__(self): return f"{super().__repr__()}.ptr(local=True)" if self.local else f"{super().__repr__()}.ptr()"
|
||||
def __repr__(self):
|
||||
arg = (["local=true"] if self.local else []) + ([f"v={self.v}"] if self.v != 1 else [])
|
||||
return f"{self.base.__repr__()}.ptr({','.join(arg)})"
|
||||
|
||||
class dtypes:
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user