From 1058f9c9ff27e7c19b1d7b2543c0c3c5ea8ced85 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 30 Oct 2024 10:27:55 +0700 Subject: [PATCH] Restore vcount [pr] (#7390) * Revert "Revert "add vcount to PtrDtype (#7388)"" This reverts commit 399a5219ddaf5784ebc905bcd582bd88403c3dd2. * Revert "Revert "add tests to vcount stuff [pr] (#7389)"" This reverts commit cc8d6dbdf38275f079e03ebe7cd2d9b8127b5575. * no ptr --- test/test_dtype.py | 37 +++++++++++++++++++++++++++++++++++++ tinygrad/dtype.py | 17 +++++++++++++---- tinygrad/ops.py | 4 ++-- tinygrad/renderer/cstyle.py | 2 +- tinygrad/viz/serve.py | 4 +++- 5 files changed, 56 insertions(+), 8 deletions(-) diff --git a/test/test_dtype.py b/test/test_dtype.py index 57ecbd7f5f..27177f1269 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -295,6 +295,43 @@ class TestUint64DType(TestDType): class TestBoolDType(TestDType): DTYPE = dtypes.bool +class TestPtrDType(unittest.TestCase): + def test_vec_double(self): + dt1 = dtypes.float.vec(4).ptr(v=4) + dt2 = dtypes.float.vec(4).ptr().vec(4) + self.assertEqual(dt1, dt2) + self.assertEqual(str(dt1), str(dt2)) + + def test_scalar(self): + dt = dtypes.float.vec(4).ptr().scalar() + self.assertEqual(dt.base, dtypes.float.vec(4)) + + dt = dtypes.float.vec(4).ptr().vec(4).scalar() + self.assertEqual(dt.base, dtypes.float.vec(4)) + + dt = dtypes.float.vec(4).scalar() + self.assertEqual(dt, dtypes.float) + + def test_serialize(self): + dt = dtypes.float.vec(4).ptr(v=4) + self.assertEqual(dt, eval(str(dt))) + + def test_vcount(self): + dt = dtypes.float.ptr().vec(4) + self.assertEqual(dt.vcount, 4) + self.assertEqual(dt.v, 4) + self.assertEqual(dt.count, 1) + + dt = dtypes.float.vec(4).ptr() + self.assertEqual(dt.vcount, 1) + self.assertEqual(dt.v, 1) + self.assertEqual(dt.count, 4) + + dt = dtypes.float.vec(4).ptr(v=4) + self.assertEqual(dt.vcount, 4) + self.assertEqual(dt.v, 4) + self.assertEqual(dt.count, 4) + class 
TestImageDType(unittest.TestCase): def test_image_scalar(self): assert dtypes.imagef((10,10)).scalar() == dtypes.float32 diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index f2236324c7..eb1d12b5b6 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -15,12 +15,14 @@ class DType: count: int def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self.scalar().name]}"+(f".vec({self.count})" if self.count > 1 else "") def __lt__(self, o:DType): return (self.priority, self.itemsize, self.name, self.fmt, self.count) < (o.priority, o.itemsize, o.name, o.fmt, o.count) + @property + def vcount(self): return self.count def vec(self, sz:int): assert self.count == 1, f"can't vectorize {self} with size {sz}" if sz == 1 or self.name == 'void': return self # void doesn't vectorize, and sz=1 is scalar return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz) - def ptr(self, local=False) -> Union[PtrDType, ImageDType]: - return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, local) + def ptr(self, local=False, v=1) -> Union[PtrDType, ImageDType]: + return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, local, v) def scalar(self) -> DType: return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self @dataclass(frozen=True) @@ -30,18 +32,25 @@ class ImageDType(DType): local: bool = False # images are never local def scalar(self) -> DType: return self.base def vec(self, sz:int): return self.base.vec(sz) - def ptr(self, local=False) -> Union[PtrDType, ImageDType]: return self + def ptr(self, local=False, v=1) -> Union[PtrDType, ImageDType]: return self def __repr__(self): return f"dtypes.{self.name}({self.shape})" @dataclass(frozen=True) class PtrDType(DType): base: DType local: bool + v: int def __hash__(self): return super().__hash__() + def scalar(self) -> DType: return self.base.ptr(self.local, 1) + def vec(self, sz:int) -> DType: return 
self.base.ptr(self.local, sz) + @property + def vcount(self): return self.v # local isn't used in the compare def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count def __ne__(self, dt): return not (self == dt) - def __repr__(self): return f"{super().__repr__()}.ptr(local=True)" if self.local else f"{super().__repr__()}.ptr()" + def __repr__(self): + arg = (["local=True"] if self.local else []) + ([f"v={self.v}"] if self.v != 1 else []) + return f"{self.base.__repr__()}.ptr({','.join(arg)})" class dtypes: @staticmethod diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 6e8bfedb3a..20f4b91030 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -318,8 +318,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if self.op is UOps.VCONST: return UOp.const(self.dtype.scalar(), self.arg[i]) if self.op is UOps.CONST: return UOp.const(self.dtype.scalar(), self.arg) i = (i,) - if self.dtype == dtypes.void or (i == tuple(range(len(i))) and self.dtype.count == len(i)): return self - assert len(i) >= 1 and all(x < self.dtype.count for x in i), f"bad GEP on {self.dtype}, {i}" + if self.dtype == dtypes.void or (i == tuple(range(len(i))) and self.dtype.vcount == len(i)): return self + assert len(i) >= 1 and all(x < self.dtype.vcount for x in i), f"bad GEP on {self.dtype}, {i}" return UOp(UOps.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i) @staticmethod def load(*src:UOp, dtype:DType): return UOp(UOps.LOAD, dtype, src) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 557d415057..397d03cba0 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -327,7 +327,7 @@ class CUDARenderer(CStyleLanguage): # TODO: why is dtypes.bfloat16.name == "__bf16"? 
would be easier not override dtypes.name prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"] - for dtype in dedup(uop.dtype for uop in uops if uop.dtype in {dtypes.half, dtypes.bfloat16}): + for dtype in dedup(uop.dtype for uop in uops if uop.dtype in {dtypes.half, dtypes.bfloat16} and not isinstance(uop.dtype, PtrDType)): prefix += [f"#include <cuda_{'fp' if dtype == dtypes.half else 'bf'}16.h>"] + [self.render_vector_prefix(dtype.vec(sz)) for sz in [4, 8]] dt_map = { dtypes.half: "f16", dtypes.bfloat16: "bf16" } diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 5ec2af476f..f370c9ded2 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -71,7 +71,9 @@ def _replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp: replaces[base] = ret = base.replace(src=tuple(_replace_uop(x, replaces) for x in base.src)) return ret @functools.lru_cache(None) -def _prg(k:Optional[Kernel]) -> Optional[str]: return k.to_program().src if isinstance(k, Kernel) else None +def _prg(k:Optional[Kernel]) -> Optional[str]: + try: return k.to_program().src if isinstance(k, Kernel) else None + except Exception: return None def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata) -> GraphRewriteDetails: g = GraphRewriteDetails(**asdict(metadata), graphs=[ctx.sink], diffs=[], changed_nodes=[], kernel_code=pcall(_prg, k)) replaces: Dict[UOp, UOp] = {}