mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
PtrDType is dataclass [pr] (#7125)
* PtrDType is dataclass [pr] * new dataset --------- Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
Binary file not shown.
@@ -1356,7 +1356,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
stores = [u for u in k.uops if u.op is UOps.STORE]
|
||||
|
||||
# the float4 value stores directly in lds and we skip upcast
|
||||
assert stores[0].src[-1].dtype == dtypes.float.vec(4)
|
||||
self.assertEqual(stores[0].src[-1].dtype, dtypes.float.vec(4))
|
||||
#assert stores[0].src[-1].op is not UOps.VECTORIZE
|
||||
|
||||
# the global store doesn't change
|
||||
|
||||
@@ -18,10 +18,10 @@ class DType:
|
||||
assert self.count == 1, f"can't vectorize {self} with size {sz}"
|
||||
if sz == 1 or self.name == 'void': return self # void doesn't vectorize, and sz=1 is scalar
|
||||
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
|
||||
def ptr(self, local=False) -> Union[PtrDType, ImageDType]: return PtrDType(self, local)
|
||||
def ptr(self, local=False) -> Union[PtrDType, ImageDType]:
|
||||
return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, local)
|
||||
def scalar(self) -> DType: return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
|
||||
|
||||
# dependent typing?
|
||||
@dataclass(frozen=True, repr=False)
|
||||
class ImageDType(DType):
|
||||
shape: Tuple[int, ...] # arbitrary arg for the dtype, used in image for the shape
|
||||
@@ -32,11 +32,12 @@ class ImageDType(DType):
|
||||
def ptr(self, local=False) -> Union[PtrDType, ImageDType]: return self
|
||||
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
|
||||
|
||||
@dataclass(frozen=True, repr=False)
|
||||
class PtrDType(DType):
|
||||
def __init__(self, dt:DType, local=False):
|
||||
self.base, self.local = dt, local
|
||||
super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count)
|
||||
base: DType
|
||||
local: bool
|
||||
def __hash__(self): return super().__hash__()
|
||||
# local isn't used in the compare
|
||||
def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count
|
||||
def __ne__(self, dt): return not (self == dt)
|
||||
def __repr__(self): return f"{super().__repr__()}.ptr(local=True)" if self.local else f"{super().__repr__()}.ptr()"
|
||||
|
||||
Reference in New Issue
Block a user