mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
don't compare with pointer dtype (#7394)
* don't compare with pointer dtype * more cleanup * images are pointers * handle IMAGE better * cleaner test_image * this work * pr match * cleanup
This commit is contained in:
@@ -348,14 +348,9 @@ class TestEqStrDType(unittest.TestCase):
|
||||
assert dtypes.imageh((1,2,4)) != dtypes.imageh((1,4,2)), "different shape doesn't match"
|
||||
assert dtypes.imageh((1,2,4)) == dtypes.imageh((1,2,4)), "same shape matches"
|
||||
assert isinstance(dtypes.imageh((1,2,4)), ImageDType)
|
||||
def test_ptr_ne(self):
|
||||
if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
|
||||
# TODO: is this the wrong behavior?
|
||||
assert dtypes.float32.ptr() == dtypes.float32
|
||||
assert not (dtypes.float32.ptr() != dtypes.float32)
|
||||
def test_ptr_eq(self):
|
||||
assert dtypes.float32.ptr() == dtypes.float32.ptr()
|
||||
assert not (dtypes.float32.ptr() != dtypes.float32.ptr())
|
||||
#assert dtypes.float32.ptr() != dtypes.float32
|
||||
def test_strs(self):
|
||||
if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
|
||||
self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union, Callable
|
||||
import math, struct, ctypes, functools
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, replace
|
||||
from tinygrad.helpers import getenv
|
||||
|
||||
ConstType = Union[float, int, bool]
|
||||
@@ -17,40 +17,36 @@ class DType:
|
||||
def __lt__(self, o:DType): return (self.priority, self.itemsize, self.name, self.fmt, self.count) < (o.priority, o.itemsize, o.name, o.fmt, o.count)
|
||||
@property
|
||||
def vcount(self): return self.count
|
||||
def vec(self, sz:int):
|
||||
def vec(self, sz:int) -> DType:
|
||||
assert self.count == 1, f"can't vectorize {self} with size {sz}"
|
||||
if sz == 1 or self.name == 'void': return self # void doesn't vectorize, and sz=1 is scalar
|
||||
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
|
||||
def ptr(self, local=False, v=1) -> Union[PtrDType, ImageDType]:
|
||||
return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, local, v)
|
||||
def ptr(self, local=False) -> Union[PtrDType, ImageDType]:
|
||||
return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, local)
|
||||
def scalar(self) -> DType: return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ImageDType(DType):
|
||||
shape: Tuple[int, ...] # arbitrary arg for the dtype, used in image for the shape
|
||||
base: DType
|
||||
local: bool = False # images are never local
|
||||
def scalar(self) -> DType: return self.base
|
||||
def vec(self, sz:int): return self.base.vec(sz)
|
||||
def ptr(self, local=False, v=1) -> Union[PtrDType, ImageDType]: return self
|
||||
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PtrDType(DType):
|
||||
base: DType
|
||||
local: bool
|
||||
v: int
|
||||
def __hash__(self): return super().__hash__()
|
||||
def scalar(self) -> DType: return self.base.ptr(self.local, 1)
|
||||
def vec(self, sz:int) -> DType: return self.base.ptr(self.local, sz)
|
||||
local: bool = False
|
||||
v: int = 1
|
||||
def scalar(self) -> PtrDType: return replace(self, v=1)
|
||||
def vec(self, sz:int) -> PtrDType: return replace(self, v=sz)
|
||||
def ptr(self, local=False): raise RuntimeError("can't make a pointer from a pointer")
|
||||
@property
|
||||
def vcount(self): return self.v
|
||||
# local isn't used in the compare
|
||||
def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count
|
||||
def __ne__(self, dt): return not (self == dt)
|
||||
def __repr__(self):
|
||||
arg = (["local=true"] if self.local else []) + ([f"v={self.v}"] if self.v != 1 else [])
|
||||
return f"{self.base.__repr__()}.ptr({','.join(arg)})"
|
||||
def __repr__(self): return f"{self.base.__repr__()}.ptr({'local=true' if self.local else ''})" + (f'.vec({self.v})' if self.v != 1 else '')
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ImageDType(PtrDType):
|
||||
shape: Tuple[int, ...] = () # shape of the Image
|
||||
# NOTE: scalar/vec are wrong
|
||||
def scalar(self): return self.base
|
||||
def vec(self, sz:int): return self.base.vec(sz)
|
||||
def ptr(self, local=False) -> Union[PtrDType, ImageDType]:
|
||||
assert not local, "images can't be local"
|
||||
return self
|
||||
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
|
||||
|
||||
class dtypes:
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user