From cbfd51f5a5030e1d10ebd31e802c45ec88653403 Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 14 Jan 2025 13:25:54 -0500 Subject: [PATCH] make MultiLazyBuffer.bounds a property [pr] (#8614) determined by lbs shapes and axis --- tinygrad/multi.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tinygrad/multi.py b/tinygrad/multi.py index 97a17fe714..07a3430f7b 100644 --- a/tinygrad/multi.py +++ b/tinygrad/multi.py @@ -46,9 +46,6 @@ class MultiLazyBuffer(MathTrait): assert all(isinstance(x, UOp) for x in lbs) and len(lbs), "all lbs must be LazyBuffers, and we need at least one of them" assert all_same([x.dtype for x in lbs]), f"all multilazybuffer needs same dtype, getting {[x.dtype for x in lbs]}" self.lbs, self.axis, self.dtype, self.device, self.real = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs), real or [True]*len(lbs) - if axis is not None: - splits = list(itertools.accumulate([lb.shape[axis] for lb in lbs], initial=0)) - self.bounds = tuple(zip(splits, splits[1:])) @property def shape(self): return tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape)) @@ -59,6 +56,11 @@ class MultiLazyBuffer(MathTrait): @property def real_lbs(self): return [lb for lb,r in zip(self.lbs, self.real) if r] + @property + def bounds(self): + if self.axis is None: raise RuntimeError("bounds is not defined when axis is None") + return tuple(itertools.pairwise(itertools.accumulate([lb.shape[self.axis] for lb in self.lbs], initial=0))) + def __repr__(self): return f"" @staticmethod