Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-02-11 15:15:13 -05:00)
don't call contiguous for unpadded const into multi tensor (#4032)
* don't call contiguous for unpadded const into multi tensor: fixes multi const folding for sharded consts. Still WIP; need to be careful that this does not break the multi-device cache somewhere.
* ehh, need a memory test for that
* simple sharded memory test
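The diff below changes MultiLazyBuffer.from_sharded so that a lazy, unpadded constant can be sharded across devices without being forced through contiguous(), which previously realized a separate buffer per device and prevented the const from folding into later kernels. A minimal usage sketch of the case this targets, assuming tinygrad's Tensor.shard API; the device names and shapes are illustrative:

  from tinygrad import Tensor, Device

  devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))  # illustrative device names
  t = Tensor.zeros(256, 256).shard(devices, axis=0)           # shard an unrealized const
  out = (t + 1).sum()                                         # the const can fold into this kernel
  out.realize()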
@@ -70,8 +70,9 @@ class MultiLazyBuffer:
   @staticmethod
   def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int]=None):
-    lbs = [lb.contiguous() if lb.base != lb else lb] * len(devices)
-    return MultiLazyBuffer([lb.copy_to_device(d).contiguous() for lb,d in zip(to_sharded(lbs, axis) if axis is not None else lbs, devices)], axis)
+    lbs = [lb.contiguous() if lb.base != lb and not lb.is_unrealized_unpadded_const() else lb] * len(devices)
+    sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis) if axis is not None else lbs, devices)]
+    return MultiLazyBuffer([lb if lb.is_unrealized_unpadded_const() else lb.contiguous() for lb in sharded_lbs], axis)
 
   def copy_to_device(self, device:str) -> LazyBuffer:
     if self.axis is None: return self.lbs[self.real.index(True)].copy_to_device(device)
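Only the contiguous() decisions change here: shards that are not unrealized unpadded consts are still made contiguous, just after copy_to_device rather than fused with it. For context, an unrealized unpadded const is a LazyBuffer that is still a lazy CONST load and whose views apply no padding mask. A rough sketch of that predicate, written against LazyBuffer's base/st fields as an illustration (not necessarily the library's exact code):

  from tinygrad.ops import LoadOps

  def looks_like_unpadded_const(lb) -> bool:  # hypothetical helper for illustration
    # still a lazy constant: never realized, and its base op is a CONST load
    is_lazy_const = lb.base.realized is None and lb.base.op is LoadOps.CONST
    # unpadded: no view in its ShapeTracker carries a padding mask
    return is_lazy_const and all(v.mask is None for v in lb.st.views)

Such buffers can stay lazy across the shard, so the scheduler can later fold the constant directly into the kernels that consume it.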
@@ -82,9 +83,6 @@ class MultiLazyBuffer:
       llbs.append(lb.pad(pad_arg))
     return functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), llbs)
 
-  # TODO: fix this
-  def is_unrealized_unpadded_const(self): return False
-
   # passthroughs
   def is_realized(self) -> bool: return all([lb.base.realized is not None for lb, r in zip(self.lbs, self.real) if r is True])
   def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
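The commit message also mentions a simple sharded memory test. A hedged sketch of what such a check could look like, using GlobalCounters.mem_used from tinygrad.helpers; the devices, sizes, and bound are illustrative, and this is not the test added by the PR:

  from tinygrad import Tensor, Device
  from tinygrad.helpers import GlobalCounters

  devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))  # illustrative device names
  before = GlobalCounters.mem_used
  t = Tensor.zeros(256, 256).shard(devices, axis=0)           # sharded unrealized const
  out = (t + 1).realize()
  # with the const kept lazy, only the outputs of the per-device (t + 1) kernels
  # should be allocated: roughly one full tensor in total, not one full tensor
  # per device plus the outputs (the bound below is illustrative)
  assert GlobalCounters.mem_used - before <= 256 * 256 * 4 * 1.25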