diff --git a/test/test_dtype.py b/test/test_dtype.py index b10ba41669..faaff6b3a6 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -348,14 +348,9 @@ class TestEqStrDType(unittest.TestCase): assert dtypes.imageh((1,2,4)) != dtypes.imageh((1,4,2)), "different shape doesn't match" assert dtypes.imageh((1,2,4)) == dtypes.imageh((1,2,4)), "same shape matches" assert isinstance(dtypes.imageh((1,2,4)), ImageDType) - def test_ptr_ne(self): - if PtrDType is None: raise unittest.SkipTest("no PtrDType support") - # TODO: is this the wrong behavior? - assert dtypes.float32.ptr() == dtypes.float32 - assert not (dtypes.float32.ptr() != dtypes.float32) + def test_ptr_eq(self): assert dtypes.float32.ptr() == dtypes.float32.ptr() assert not (dtypes.float32.ptr() != dtypes.float32.ptr()) - #assert dtypes.float32.ptr() != dtypes.float32 def test_strs(self): if PtrDType is None: raise unittest.SkipTest("no PtrDType support") self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))") diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index eb1d12b5b6..2dcbd2ea81 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union, Callable import math, struct, ctypes, functools -from dataclasses import dataclass +from dataclasses import dataclass, replace from tinygrad.helpers import getenv ConstType = Union[float, int, bool] @@ -17,40 +17,36 @@ class DType: def __lt__(self, o:DType): return (self.priority, self.itemsize, self.name, self.fmt, self.count) < (o.priority, o.itemsize, o.name, o.fmt, o.count) @property def vcount(self): return self.count - def vec(self, sz:int): + def vec(self, sz:int) -> DType: assert self.count == 1, f"can't vectorize {self} with size {sz}" if sz == 1 or self.name == 'void': return self # void doesn't vectorize, and sz=1 is scalar return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz) - def ptr(self, local=False, v=1) -> Union[PtrDType, ImageDType]: - return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, local, v) + def ptr(self, local=False) -> Union[PtrDType, ImageDType]: + return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, local) def scalar(self) -> DType: return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self -@dataclass(frozen=True) -class ImageDType(DType): - shape: Tuple[int, ...] # arbitrary arg for the dtype, used in image for the shape - base: DType - local: bool = False # images are never local - def scalar(self) -> DType: return self.base - def vec(self, sz:int): return self.base.vec(sz) - def ptr(self, local=False, v=1) -> Union[PtrDType, ImageDType]: return self - def __repr__(self): return f"dtypes.{self.name}({self.shape})" - @dataclass(frozen=True) class PtrDType(DType): base: DType - local: bool - v: int - def __hash__(self): return super().__hash__() - def scalar(self) -> DType: return self.base.ptr(self.local, 1) - def vec(self, sz:int) -> DType: return self.base.ptr(self.local, sz) + local: bool = False + v: int = 1 + def scalar(self) -> PtrDType: return replace(self, v=1) + def vec(self, sz:int) -> PtrDType: return replace(self, v=sz) + def ptr(self, local=False): raise RuntimeError("can't make a pointer from a pointer") @property def vcount(self): return self.v - # local isn't used in the compare - def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count - def __ne__(self, dt): return not (self == dt) - def __repr__(self): - arg = (["local=true"] if self.local else []) + ([f"v={self.v}"] if self.v != 1 else []) - return f"{self.base.__repr__()}.ptr({','.join(arg)})" + def __repr__(self): return f"{self.base.__repr__()}.ptr({'local=true' if self.local else ''})" + (f'.vec({self.v})' if self.v != 1 else '') + +@dataclass(frozen=True) +class ImageDType(PtrDType): + shape: Tuple[int, ...] = () # shape of the Image + # NOTE: scalar/vec are wrong + def scalar(self): return self.base + def vec(self, sz:int): return self.base.vec(sz) + def ptr(self, local=False) -> Union[PtrDType, ImageDType]: + assert not local, "images can't be local" + return self + def __repr__(self): return f"dtypes.{self.name}({self.shape})" class dtypes: @staticmethod