Restore vcount [pr] (#7390)

* Revert "Revert "add vcount to PtrDtype (#7388)""

This reverts commit 399a5219dd.

* Revert "Revert "add tests to vcount stuff [pr] (#7389)""

This reverts commit cc8d6dbdf3.

* no ptr
This commit is contained in:
George Hotz
2024-10-30 10:27:55 +07:00
committed by GitHub
parent 399a5219dd
commit 1058f9c9ff
5 changed files with 56 additions and 8 deletions

View File

@@ -15,12 +15,14 @@ class DType:
count: int
def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self.scalar().name]}"+(f".vec({self.count})" if self.count > 1 else "")
def __lt__(self, o:DType): return (self.priority, self.itemsize, self.name, self.fmt, self.count) < (o.priority, o.itemsize, o.name, o.fmt, o.count)
@property
def vcount(self): return self.count
def vec(self, sz:int):
assert self.count == 1, f"can't vectorize {self} with size {sz}"
if sz == 1 or self.name == 'void': return self # void doesn't vectorize, and sz=1 is scalar
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
def ptr(self, local=False) -> Union[PtrDType, ImageDType]:
return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, local)
def ptr(self, local=False, v=1) -> Union[PtrDType, ImageDType]:
return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, local, v)
def scalar(self) -> DType: return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
@dataclass(frozen=True)
@@ -30,18 +32,25 @@ class ImageDType(DType):
local: bool = False # images are never local
def scalar(self) -> DType: return self.base
def vec(self, sz:int): return self.base.vec(sz)
def ptr(self, local=False) -> Union[PtrDType, ImageDType]: return self
def ptr(self, local=False, v=1) -> Union[PtrDType, ImageDType]: return self
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
@dataclass(frozen=True)
class PtrDType(DType):
base: DType
local: bool
v: int
def __hash__(self): return super().__hash__()
def scalar(self) -> DType: return self.base.ptr(self.local, 1)
def vec(self, sz:int) -> DType: return self.base.ptr(self.local, sz)
@property
def vcount(self): return self.v
# local isn't used in the compare
def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count
def __ne__(self, dt): return not (self == dt)
def __repr__(self): return f"{super().__repr__()}.ptr(local=True)" if self.local else f"{super().__repr__()}.ptr()"
def __repr__(self):
arg = (["local=true"] if self.local else []) + ([f"v={self.v}"] if self.v != 1 else [])
return f"{self.base.__repr__()}.ptr({','.join(arg)})"
class dtypes:
@staticmethod