don't call contiguous for unpadded const into multi tensor (#4032)

* don't call contiguous for unpadded const into multi tensor

fixed multi const folding for a sharded const.
still WIP; need to be careful that this does not break the multi-device cache somewhere

* ehh need a memory test for that

* simple sharded memory test
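
A minimal sketch of what such a sharded memory check could look like (the device strings, shape, and threshold here are assumptions, not the exact test added in this commit): sharding an unpadded const should no longer allocate a per-device buffer, because from_sharded now skips the contiguous call for it.

# hypothetical sketch, not the test from this PR
from tinygrad import Tensor, Device
from tinygrad.helpers import GlobalCounters

devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))  # assumed two-device setup

start = GlobalCounters.mem_used
t = Tensor.ones(256, 256).shard(devices, axis=0)  # unpadded const, split along axis 0
t.realize()
# each shard stays an unrealized const, so nowhere near 256*256*4 bytes should be allocated
assert GlobalCounters.mem_used - start < 256 * 256 * 4
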
Author: chenyu
Date: 2024-04-01 19:22:14 -04:00
Committed by: GitHub
Parent: d6ba44bc1e
Commit: 82440d3416
3 changed files with 29 additions and 7 deletions


@@ -70,8 +70,9 @@ class MultiLazyBuffer:
   @staticmethod
   def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int]=None):
-    lbs = [lb.contiguous() if lb.base != lb else lb] * len(devices)
-    return MultiLazyBuffer([lb.copy_to_device(d).contiguous() for lb,d in zip(to_sharded(lbs, axis) if axis is not None else lbs, devices)], axis)
+    lbs = [lb.contiguous() if lb.base != lb and not lb.is_unrealized_unpadded_const() else lb] * len(devices)
+    sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis) if axis is not None else lbs, devices)]
+    return MultiLazyBuffer([lb if lb.is_unrealized_unpadded_const() else lb.contiguous() for lb in sharded_lbs], axis)
   def copy_to_device(self, device:str) -> LazyBuffer:
     if self.axis is None: return self.lbs[self.real.index(True)].copy_to_device(device)
@@ -82,9 +83,6 @@ class MultiLazyBuffer:
       llbs.append(lb.pad(pad_arg))
     return functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), llbs)
-  # TODO: fix this
-  def is_unrealized_unpadded_const(self): return False
   # passthroughs
   def is_realized(self) -> bool: return all([lb.base.realized is not None for lb, r in zip(self.lbs, self.real) if r is True])
   def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
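
Why skipping contiguous matters here: calling contiguous() on a const forces it to materialize into a real device buffer, while an unrealized unpadded const folds directly into the kernels that consume it. A rough single-device analogue of that effect (assumed behavior, illustrative only, not code from this PR):

from tinygrad import Tensor
from tinygrad.helpers import GlobalCounters

before = GlobalCounters.mem_used
a = Tensor.full((1024, 1024), 2.0)   # stays an unrealized, unpadded const
(a * 3).sum().realize()              # the const folds into the kernel; no 4MB input buffer
folded = GlobalCounters.mem_used - before

b = Tensor.full((1024, 1024), 2.0).contiguous()
(b * 3).sum().realize()              # contiguous materializes the const as a 4MB buffer
assert GlobalCounters.mem_used - before > folded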