don't compare with pointer dtype (#7394)

* don't compare with pointer dtype

* more cleanup

* images are pointers

* handle IMAGE better

* cleaner test_image

* this work

* pr match

* cleanup
This commit is contained in:
George Hotz
2024-10-30 16:48:27 +07:00
committed by GitHub
parent 95390df02a
commit 76a41a1083
2 changed files with 22 additions and 31 deletions

View File

@@ -348,14 +348,9 @@ class TestEqStrDType(unittest.TestCase):
assert dtypes.imageh((1,2,4)) != dtypes.imageh((1,4,2)), "different shape doesn't match"
assert dtypes.imageh((1,2,4)) == dtypes.imageh((1,2,4)), "same shape matches"
assert isinstance(dtypes.imageh((1,2,4)), ImageDType)
def test_ptr_ne(self):
if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
# TODO: is this the wrong behavior?
assert dtypes.float32.ptr() == dtypes.float32
assert not (dtypes.float32.ptr() != dtypes.float32)
def test_ptr_eq(self):
assert dtypes.float32.ptr() == dtypes.float32.ptr()
assert not (dtypes.float32.ptr() != dtypes.float32.ptr())
#assert dtypes.float32.ptr() != dtypes.float32
def test_strs(self):
if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))")

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union, Callable
import math, struct, ctypes, functools
from dataclasses import dataclass
from dataclasses import dataclass, replace
from tinygrad.helpers import getenv
ConstType = Union[float, int, bool]
@@ -17,40 +17,36 @@ class DType:
def __lt__(self, o:DType): return (self.priority, self.itemsize, self.name, self.fmt, self.count) < (o.priority, o.itemsize, o.name, o.fmt, o.count)
@property
def vcount(self): return self.count
def vec(self, sz:int):
def vec(self, sz:int) -> DType:
assert self.count == 1, f"can't vectorize {self} with size {sz}"
if sz == 1 or self.name == 'void': return self # void doesn't vectorize, and sz=1 is scalar
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
def ptr(self, local=False, v=1) -> Union[PtrDType, ImageDType]:
return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, local, v)
def ptr(self, local=False) -> Union[PtrDType, ImageDType]:
return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, local)
def scalar(self) -> DType: return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
@dataclass(frozen=True)
class ImageDType(DType):
shape: Tuple[int, ...] # arbitrary arg for the dtype, used in image for the shape
base: DType
local: bool = False # images are never local
def scalar(self) -> DType: return self.base
def vec(self, sz:int): return self.base.vec(sz)
def ptr(self, local=False, v=1) -> Union[PtrDType, ImageDType]: return self
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
@dataclass(frozen=True)
class PtrDType(DType):
base: DType
local: bool
v: int
def __hash__(self): return super().__hash__()
def scalar(self) -> DType: return self.base.ptr(self.local, 1)
def vec(self, sz:int) -> DType: return self.base.ptr(self.local, sz)
local: bool = False
v: int = 1
def scalar(self) -> PtrDType: return replace(self, v=1)
def vec(self, sz:int) -> PtrDType: return replace(self, v=sz)
def ptr(self, local=False): raise RuntimeError("can't make a pointer from a pointer")
@property
def vcount(self): return self.v
# local isn't used in the compare
def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count
def __ne__(self, dt): return not (self == dt)
def __repr__(self):
arg = (["local=true"] if self.local else []) + ([f"v={self.v}"] if self.v != 1 else [])
return f"{self.base.__repr__()}.ptr({','.join(arg)})"
def __repr__(self): return f"{self.base.__repr__()}.ptr({'local=true' if self.local else ''})" + (f'.vec({self.v})' if self.v != 1 else '')
@dataclass(frozen=True)
class ImageDType(PtrDType):
shape: Tuple[int, ...] = () # shape of the Image
# NOTE: scalar/vec are wrong
def scalar(self): return self.base
def vec(self, sz:int): return self.base.vec(sz)
def ptr(self, local=False) -> Union[PtrDType, ImageDType]:
assert not local, "images can't be local"
return self
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
class dtypes:
@staticmethod