mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
make MultiLazyBuffer.bounds a property [pr] (#8614)
determined by lbs shapes and axis
This commit is contained in:
@@ -46,9 +46,6 @@ class MultiLazyBuffer(MathTrait):
|
||||
assert all(isinstance(x, UOp) for x in lbs) and len(lbs), "all lbs must be LazyBuffers, and we need at least one of them"
|
||||
assert all_same([x.dtype for x in lbs]), f"all multilazybuffer needs same dtype, getting {[x.dtype for x in lbs]}"
|
||||
self.lbs, self.axis, self.dtype, self.device, self.real = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs), real or [True]*len(lbs)
|
||||
if axis is not None:
|
||||
splits = list(itertools.accumulate([lb.shape[axis] for lb in lbs], initial=0))
|
||||
self.bounds = tuple(zip(splits, splits[1:]))
|
||||
|
||||
@property
|
||||
def shape(self): return tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape))
|
||||
@@ -59,6 +56,11 @@ class MultiLazyBuffer(MathTrait):
|
||||
@property
|
||||
def real_lbs(self): return [lb for lb,r in zip(self.lbs, self.real) if r]
|
||||
|
||||
@property
|
||||
def bounds(self):
|
||||
if self.axis is None: raise RuntimeError("bounds is not defined when axis is None")
|
||||
return tuple(itertools.pairwise(itertools.accumulate([lb.shape[self.axis] for lb in self.lbs], initial=0)))
|
||||
|
||||
def __repr__(self): return f"<MLB {self.axis=} {self.real=} {chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user