PtrDType is dataclass [pr] (#7125)

* PtrDType is dataclass [pr]

* new dataset

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
George Hotz
2024-10-18 21:40:33 +08:00
committed by GitHub
parent ea016b55d1
commit b0a13896d7
3 changed files with 7 additions and 6 deletions

Binary file not shown.

View File

@@ -1356,7 +1356,7 @@ class TestLinearizer(unittest.TestCase):
stores = [u for u in k.uops if u.op is UOps.STORE]
# the float4 value stores directly in lds and we skip upcast
assert stores[0].src[-1].dtype == dtypes.float.vec(4)
self.assertEqual(stores[0].src[-1].dtype, dtypes.float.vec(4))
#assert stores[0].src[-1].op is not UOps.VECTORIZE
# the global store doesn't change

View File

@@ -18,10 +18,10 @@ class DType:
assert self.count == 1, f"can't vectorize {self} with size {sz}"
if sz == 1 or self.name == 'void': return self # void doesn't vectorize, and sz=1 is scalar
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
def ptr(self, local=False) -> Union[PtrDType, ImageDType]: return PtrDType(self, local)
def ptr(self, local=False) -> Union[PtrDType, ImageDType]:
return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, local)
def scalar(self) -> DType: return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
# dependent typing?
@dataclass(frozen=True, repr=False)
class ImageDType(DType):
shape: Tuple[int, ...] # arbitrary arg for the dtype, used in image for the shape
@@ -32,11 +32,12 @@ class ImageDType(DType):
def ptr(self, local=False) -> Union[PtrDType, ImageDType]: return self
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
@dataclass(frozen=True, repr=False)
class PtrDType(DType):
def __init__(self, dt:DType, local=False):
self.base, self.local = dt, local
super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count)
base: DType
local: bool
def __hash__(self): return super().__hash__()
# local isn't used in the compare
def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count
def __ne__(self, dt): return not (self == dt)
def __repr__(self): return f"{super().__repr__()}.ptr(local=True)" if self.local else f"{super().__repr__()}.ptr()"