diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index 5a91e79d0f..392347ce23 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -139,6 +139,9 @@ jobs:
         HIP=1 JIT=1 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt
     - name: Run 10 CIFAR training steps
       run: STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt
+    # # TODO: enable this. it took 3 minutes in CI and pushed the full training run above 5 minutes
+    # - name: Run 10 CIFAR training steps w HALF and 6 GPUS
+    #   run: time HALF=1 STEPS=10 BS=1536 GPUS=6 python3 examples/hlb_cifar10.py
     - name: Run full CIFAR training w HALF
       run: time HALF=1 STEPS=1000 python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
     # # TODO: make wino faster so we can enable both
diff --git a/examples/hlb_cifar10.py b/examples/hlb_cifar10.py
index 1b3b9f5f29..7d54d7375f 100644
--- a/examples/hlb_cifar10.py
+++ b/examples/hlb_cifar10.py
@@ -5,13 +5,14 @@
 # https://siboehm.com/articles/22/CUDA-MMM
 import random, time
 import numpy as np
-from typing import Optional
+from typing import Optional, List
 from extra.datasets import fetch_cifar, cifar_mean, cifar_std
 from extra.lr_scheduler import OneCycleLR
 from tinygrad import nn, dtypes, Tensor, Device, GlobalCounters, TinyJit
 from tinygrad.nn.state import get_state_dict, get_parameters
 from tinygrad.nn import optim
 from tinygrad.helpers import Context, BEAM, WINO, getenv
+from tinygrad.features.multi import MultiLazyBuffer
 
 BS, STEPS = getenv("BS", 512), getenv("STEPS", 1000)
 EVAL_BS = getenv("EVAL_BS", BS)
@@ -33,13 +34,39 @@ class BatchNorm(nn.BatchNorm2d):
     self.weight.requires_grad = False
     self.bias.requires_grad = True
 
+class UnsyncedBatchNorm:
+  def __init__(self, num_features, num_devices=len(GPUS)):
+    self.bns:List[BatchNorm] = []
+    for _ in range(num_devices):
+      bn = BatchNorm(num_features)
+      self.bns.append(bn)
+
+  def __call__(self, x:Tensor):
+    if len(self.bns) == 1: return self.bns[0](x)
+
+    bn_ts = []
+    assert isinstance(x.lazydata, MultiLazyBuffer)
+    for bound, bn in zip(x.lazydata.bounds, self.bns):
+      # TODO: __getitem__ does not work
+      # xi = x[bound]
+      xi = x.shrink((bound, None, None, None))
+      bni = bn(xi)
+      bn_ts.append(bni)
+    # TODO: what do we want to do for inference? average weights? pick any one?
+    # a good start would be to check that each mean/std is similar
+    return bn_ts[0].cat(*bn_ts[1:])
+
 class ConvGroup:
   def __init__(self, channels_in, channels_out):
     self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1, bias=False)
     self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1, bias=False)
-    self.norm1 = BatchNorm(channels_out)
-    self.norm2 = BatchNorm(channels_out)
+    if getenv("SYNCBN"):
+      self.norm1 = BatchNorm(channels_out)
+      self.norm2 = BatchNorm(channels_out)
+    else:
+      self.norm1 = UnsyncedBatchNorm(channels_out)
+      self.norm2 = UnsyncedBatchNorm(channels_out)
 
   def __call__(self, x):
     x = self.conv1(x)
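Note: the idea in `UnsyncedBatchNorm` above is that each device normalizes only its own shard of the batch, so training needs no cross-device reduction for the batch statistics; the per-shard outputs are stitched back together with `cat`. A minimal NumPy sketch of that computation (the shapes and shard count here are illustrative, not taken from the diff):

```python
import numpy as np

# a minimal sketch: per-shard (unsynced) batchnorm over an (N, C, H, W) batch split on axis 0
x = np.random.randn(8, 3, 4, 4).astype(np.float32)
shards = np.split(x, 2, axis=0)                   # each "device" sees only half the batch

outs = []
for xi in shards:
  mean = xi.mean(axis=(0, 2, 3), keepdims=True)   # statistics computed over this shard only
  var = xi.var(axis=(0, 2, 3), keepdims=True)
  outs.append((xi - mean) / np.sqrt(var + 1e-5))

y = np.concatenate(outs, axis=0)                  # mirrors bn_ts[0].cat(*bn_ts[1:])
assert y.shape == x.shape
```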
diff --git a/test/test_multitensor.py b/test/test_multitensor.py
index 3bf2910434..77f5ce202a 100644
--- a/test/test_multitensor.py
+++ b/test/test_multitensor.py
@@ -1,4 +1,5 @@
 import unittest, functools
+from typing import List
 from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit
 from tinygrad.device import BufferCopy
 from tinygrad.ops import LoadOps, ReduceOps
@@ -305,5 +306,234 @@ class TestMultiTensor(unittest.TestCase):
 #   for i, ast in enumerate(asts):
 #     print(f"{i} {ast}")
 
+@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
+class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
+  # shrink a multitensor on sharded axis
+  def test_shrink_bad_args(self):
+    t = Tensor.arange(64).reshape(8, 8).contiguous().realize()
+    t.shard_([f"{Device.DEFAULT}:{i}" for i in range(4)], axis=0)
+
+    with self.assertRaises(AssertionError):
+      # sharded axis shrink on a non-device boundary is not allowed
+      a = t.shrink(((0, 3), (0, 8)))
+    with self.assertRaises(AssertionError):
+      # cannot shrink sharded and non-sharded axis at the same time
+      a = t.shrink(((0, 2), (2, 4)))
+
+    a = t.shrink(((0, 2), (0, 8)))
+    assert a.shape == (2, 8)
+    assert a.lazydata.real == [True, False, False, False]
+
+    with self.assertRaises(AssertionError):
+      # cannot pad sharded and non-sharded axis at the same time
+      p = a.pad(((0, 6), (0, 1)))
+
+    with self.assertRaises(AssertionError):
+      # can only pad to the whole axis
+      p = a.pad(((1, 5), (0, 0)))
+
+    p = a.pad(((0, 6), (0, 0)))
+    assert p.shape == (8, 8)
+    assert p.lazydata.real == [True, True, True, True]
+
+  def test_ops(self):
+    t = Tensor.arange(64).reshape(8, 8).contiguous().realize()
+    t.shard_([f"{Device.DEFAULT}:{i}" for i in range(4)], axis=0)
+    for i in range(4):
+      print(f"{i=}")
+      a = t.shrink(((0+2*i,2+2*i),None))
+      b = Tensor(t.numpy()[0+2*i:2+2*i])
+      assert a.shape == b.shape == (2, 8)
+      assert a.lazydata.real == [i==j for j in range(4)]
+      np.testing.assert_allclose(a.numpy(), b.numpy())
+      # cast
+      np.testing.assert_allclose(a.float().numpy(), b.float().numpy())
+
+      # elementwise
+      np.testing.assert_allclose(a.exp().numpy(), b.exp().numpy(), rtol=1e-7, atol=1e-3)
+      np.testing.assert_allclose(a.reciprocal().numpy(), b.reciprocal().numpy(), rtol=1e-7, atol=1e-3)
+      np.testing.assert_allclose(a.pow(-0.5).numpy(), b.pow(-0.5).numpy(), rtol=1e-7, atol=1e-3)
+      np.testing.assert_allclose((a+a).numpy(), (b+b).numpy(), rtol=1e-7, atol=1e-3)
+      np.testing.assert_equal((a+1).numpy(), (b+1).numpy())
+      np.testing.assert_equal((1+a).numpy(), (1+b).numpy())
+      np.testing.assert_allclose((a.where(a+a, a)).numpy(), (b.where(b+b, b)).numpy(), rtol=1e-7, atol=1e-3)
+      np.testing.assert_allclose((a.where(1, 0)).numpy(), (b.where(1, 0)).numpy(), rtol=1e-7, atol=1e-3)
+
+      # reduce
+      np.testing.assert_allclose(a.max().numpy(), b.max().numpy(), rtol=1e-7, atol=1e-3)
+      np.testing.assert_allclose(a.sum().numpy(), b.sum().numpy(), rtol=1e-7, atol=1e-3)
+      np.testing.assert_allclose(a.mean().numpy(), b.mean().numpy(), rtol=1e-7, atol=1e-3)
+      np.testing.assert_allclose(a.max(0).numpy(), b.max(0).numpy(), rtol=1e-7, atol=1e-3)
+      np.testing.assert_allclose(a.sum(0).numpy(), b.sum(0).numpy(), rtol=1e-7, atol=1e-3)
+      np.testing.assert_allclose(a.mean(0).numpy(), b.mean(0).numpy(), rtol=1e-7, atol=1e-3)
+      np.testing.assert_allclose(a.max(1).numpy(), b.max(1).numpy(), rtol=1e-7, atol=1e-3)
+      np.testing.assert_allclose(a.sum(1).numpy(), b.sum(1).numpy(), rtol=1e-7, atol=1e-3)
+      np.testing.assert_allclose(a.mean(1).numpy(), b.mean(1).numpy(), rtol=1e-7, atol=1e-3)
+
+      # pad it back
+      np.testing.assert_allclose(a.pad(((2*i, 2*(4-i-1)), None)).numpy(), b.pad(((2*i, 2*(4-i-1)), None)).numpy(), rtol=1e-7, atol=1e-3)
+
+      # other movement
+      np.testing.assert_allclose(a.pad((None, (1, 1))).numpy(), b.pad((None, (1, 1))).numpy(), rtol=1e-7, atol=1e-3)
+      np.testing.assert_allclose(a.shrink((None, (1, 3))).numpy(), b.shrink((None, (1, 3))).numpy(), rtol=1e-7, atol=1e-3)
+      np.testing.assert_allclose(a.permute((1, 0)).numpy(), b.permute((1, 0)).numpy(), rtol=1e-7, atol=1e-3)
+      np.testing.assert_allclose(a.reshape((2, 2, 4)).numpy(), b.reshape((2, 2, 4)).numpy(), rtol=1e-7, atol=1e-3)
+      np.testing.assert_allclose(a.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), b.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), rtol=1e-7, atol=1e-3)
+      np.testing.assert_allclose(a.flip(-1).numpy(), b.flip(-1).numpy(), rtol=1e-7, atol=1e-3)
+
+  def test_uneven(self):
+    t = Tensor.arange(24).reshape(3, 8).contiguous().realize()
+    t.shard_([f"{Device.DEFAULT}:{i}" for i in range(2)], axis=0)
+
+    a = t.shrink(((0, 2), None))
+    b = t.shrink(((2, 3), None))
+    na = t.numpy()[0:2]
+    nb = t.numpy()[2:3]
+    np.testing.assert_equal(a.numpy(), na)
+    np.testing.assert_equal(b.numpy(), nb)
+    np.testing.assert_equal((a+1).numpy(), na+1)
+    np.testing.assert_equal((b+1).numpy(), nb+1)
+    np.testing.assert_equal((1+a).numpy(), 1+na)
+    np.testing.assert_equal((1+b).numpy(), 1+nb)
+    np.testing.assert_equal((a+a).numpy(), na+na)
+    np.testing.assert_equal((b+b).numpy(), nb+nb)
+
+  def test_add_two_partitions(self):
+    t = Tensor.arange(64).reshape(8, 8).contiguous().realize()
+    t.shard_([f"{Device.DEFAULT}:{i}" for i in range(4)], axis=0)
+
+    a = t.shrink(((2, 4), None))
+    b = t.shrink(((6, 8), None))
+    na = t.numpy()[2:4]
+    nb = t.numpy()[6:8]
+    np.testing.assert_equal(a.numpy(), na)
+    np.testing.assert_equal(b.numpy(), nb)
+    with self.assertRaises(AssertionError):
+      # cannot add directly
+      c = a + b
+
+    c = a.pad(((2, 4), None)) + b.pad(((6, 0), None))
+    expected = np.concatenate([np.zeros_like(t.numpy()[0:2]), na, np.zeros_like(t.numpy()[4:6]), nb])
+    np.testing.assert_equal(c.numpy(), expected)
+
+  def test_add_different_tensors(self):
+    devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
+    x = Tensor.arange(64).reshape(8, 8).contiguous().realize().shard(devices, axis=0)
+
+    to_add = []
+    for i in range(len(devices)):
+      to_add.append((Tensor.ones(2, 8) * i).shard(devices))
+
+    added:List[Tensor] = []
+    for bound, a in zip(x.lazydata.bounds, to_add):
+      added.append(x[bound[0]:bound[1]] + a)
+
+    output = added[0].cat(*added[1:])
+    expected = np.arange(64).reshape((8,8)) + np.array([[0,0,1,1,2,2,3,3] for _ in range(8)]).T
+    np.testing.assert_allclose(output.numpy(), expected)
+
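Note: the recombination in `test_add_two_partitions` rests on a simple identity: zero-padding each partition back to the full axis and adding reproduces the concatenation. The same identity in pure NumPy (values chosen to match the test):

```python
import numpy as np

t = np.arange(64).reshape(8, 8)
na, nb = t[2:4], t[6:8]             # two partitions of the sharded axis

c = np.zeros_like(t)
c[2:4] += na                        # plays the role of a.pad(((2, 4), None))
c[6:8] += nb                        # plays the role of b.pad(((6, 0), None))

expected = np.concatenate([np.zeros_like(t[0:2]), na, np.zeros_like(t[4:6]), nb])
np.testing.assert_equal(c, expected)
```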
+  def test_unsynced_backprop_conv_bn(self):
+    from extra.lr_scheduler import OneCycleLR
+
+    convs = [nn.Conv2d(3, 16, 3), nn.Conv2d(3, 16, 3)]
+    bns = [nn.BatchNorm2d(16), nn.BatchNorm2d(16)]
+
+    for p in get_parameters(convs + bns):
+      p.shard_((d1, d2))
+    optim = nn.optim.Adam(get_parameters(convs + bns))
+    lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
+    lr_sched.step()
+
+    fake_image = Tensor.rand((8, 3, 32, 32)).shard((d1, d2), axis=0)
+
+    f1 = fake_image.shrink(((0, 4), None, None, None))
+    f2 = fake_image.shrink(((4, 8), None, None, None))
+
+    out1 = bns[0](convs[0](f1))
+    out2 = bns[1](convs[1](f2))
+    out = out1.cat(out2)
+    optim.zero_grad()
+    out.mean().backward()
+    optim.step()
+
+  def test_unsynced_backprop_standalone_bn(self):
+    from extra.lr_scheduler import OneCycleLR
+    GPUS = (d1, d2)
+
+    class BatchNorm:
+      def __init__(self, num_features):
+        self.bns:List[nn.BatchNorm2d] = []
+        for _ in GPUS:
+          bn = nn.BatchNorm2d(num_features, track_running_stats=False, eps=1e-12, momentum=0.85, affine=True)
+          self.bns.append(bn)
+
+      def __call__(self, x:Tensor):
+        bn_ts = []
+        for bound, bn in zip(x.lazydata.bounds, self.bns):
+          xi = x.shrink((bound, None, None, None))
+          bni = bn(xi)
+          bn_ts.append(bni)
+        return bn_ts[0].cat(*bn_ts[1:])
+
+    with Tensor.train():
+      conv = nn.Conv2d(3, 16, 3)
+      bn = BatchNorm(16)
+
+      for p in get_parameters([conv, bn]):
+        p.shard_(GPUS)
+      optim = nn.optim.Adam(get_parameters([conv, bn]))
+      lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
+      lr_sched.step()
+
+      fake_image = Tensor.rand((8, 3, 32, 32)).shard(GPUS, axis=0)
+
+      out = bn(conv(fake_image))
+      optim.zero_grad()
+      out.mean().backward()
+      optim.step()
+
+  @given(strat.sampled_from((False, True)))
+  def test_batchnorm(self, is_training):
+    devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
+    x = Tensor.arange(4096).reshape(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0)
+
+    with Tensor.train(is_training):
+      bns = []
+      for _ in range(len(devices)):
+        bn = nn.BatchNorm2d(8)
+        for p in get_parameters(bn):
+          p.shard_(devices)
+        bn.weight.requires_grad = True
+        bn.bias.requires_grad = True
+        bns.append(bn)
+
+      bn_ts = []
+      for bound, bn in zip(x.lazydata.bounds, bns):
+        bni = bn(x[bound[0]:bound[1]])
+        bn_ts.append(bni)
+
+      bn_ts[0].cat(*bn_ts[1:]).numpy()
+
+  def test_synced_vs_unsynced_bn(self):
+    from examples.hlb_cifar10 import BatchNorm, UnsyncedBatchNorm
+    devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
+    x = Tensor.ones(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0)
+
+    with Tensor.train():
+      synced_bn = BatchNorm(8)
+      unsynced_bn = UnsyncedBatchNorm(8)
+
+      for p in get_parameters([synced_bn, unsynced_bn]):
+        p.shard_(devices)
+
+      synced_out = synced_bn(x)
+      synced_si = [si for si in synced_out.lazydata.schedule()]
+      unsynced_out = unsynced_bn(x)
+      unsynced_si = [si for si in unsynced_out.lazydata.schedule()]
+
+      # TODO: test synced / unsynced batchnorm cross device kernel and copies
+      assert synced_si
+      assert unsynced_si
+
 if __name__ == '__main__':
   unittest.main()
\ No newline at end of file
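Note: these tests drive the `real` mask that `MultiLazyBuffer` gains below. Shrinking to one device's bound marks only that device's buffer as real; padding back to the whole axis restores all of them. A short sketch of that round trip (assumes four devices of the default backend are available, as in the tests above):

```python
from tinygrad import Tensor, Device

devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
t = Tensor.arange(64).reshape(8, 8).contiguous().realize()
t.shard_(devices, axis=0)          # two rows per device: bounds (0,2),(2,4),(4,6),(6,8)

a = t.shrink(((2, 4), (0, 8)))     # exactly the second device's shard
print(a.lazydata.real)             # -> [False, True, False, False]

p = a.pad(((2, 4), (0, 0)))        # pad back to the whole axis
print(p.lazydata.real)             # -> [True, True, True, True]
```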
diff --git a/tinygrad/features/multi.py b/tinygrad/features/multi.py
index 407def782e..0c159e1caf 100644
--- a/tinygrad/features/multi.py
+++ b/tinygrad/features/multi.py
@@ -18,17 +18,26 @@ def to_sharded(lbs:List[LazyBuffer], axis:int) -> List[LazyBuffer]:
   return [lb.shrink(tuple((0,s) if a != axis else (sz*i,min(s,sz*(i+1))) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
 
 class MultiLazyBuffer:
-  def __init__(self, lbs:List[LazyBuffer], axis:Optional[int]):
+  def __init__(self, lbs:List[LazyBuffer], axis:Optional[int], real:Optional[List[bool]]=None):
     assert all(isinstance(x, LazyBuffer) for x in lbs) and len(lbs), "all lbs must be LazyBuffers, and we need at least one of them"
     #assert all_same([(x.shape, x.dtype, x.st) for x in lbs]), "all multilazybuffer needs same shape, dtype, and st"
-    self.lbs, self.axis, self.dtype, self.device = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs)
-    self.shape = tuple(sum(y.shape[a] for y in self.lbs) if a == self.axis else s for a,s in enumerate(lbs[0].shape))
+    self.lbs, self.axis, self.dtype, self.device, self.real = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs), real or [True]*len(lbs)
+    if axis is not None:
+      splits = list(itertools.accumulate([lb.shape[axis] for lb in lbs], initial=0))
+      self.bounds = [(st,ed) for st,ed in zip(splits, splits[1:])]
 
   @property
-  def size(self): return sum(x.size for x in self.lbs)
+  def shape(self):
+    return tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape))
+
+  @property
+  def size(self): return sum(x.size for x in self.real_lbs)
+
+  @property
+  def real_lbs(self): return [lb for lb,r in zip(self.lbs, self.real) if r]
 
   def __repr__(self):
-    return f"<MLB {self.axis=} {chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"
+    return f"<MLB {self.axis=} {self.real=} {chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"
 
   @staticmethod
   def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int]=None):
@@ -36,10 +45,10 @@ class MultiLazyBuffer:
     return MultiLazyBuffer([lb.copy_to_device(d).contiguous() for lb,d in zip(to_sharded(lbs, axis) if axis is not None else lbs, devices)], axis)
 
   def copy_to_device(self, device:str) -> LazyBuffer:
-    if self.axis is None: return self.lbs[0].copy_to_device(device)
+    if self.axis is None: return self.lbs[self.real.index(True)].copy_to_device(device)
     sz = self.lbs[0].shape[self.axis]
     llbs = []
-    for i,lb in enumerate([lb.copy_to_device(device) for lb in self.lbs]):
+    for i,lb in enumerate([lb.copy_to_device(device) for lb in self.real_lbs]):
       pad_arg = tuple((0,0) if a != self.axis else (sz*i, max(0, self.shape[self.axis]-sz*(i+1))) for a in range(len(lb.shape)))
       llbs.append(lb.pad(pad_arg))
     return functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), llbs)
@@ -48,10 +57,10 @@
   def is_unrealized_contiguous_const(self): return False
 
   # passthroughs
-  def schedule(self, seen=None): return create_schedule(self.lbs, seen)
-  def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis)
-  def const(self, val:Scalar) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.lbs], self.axis)
-  def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis)
+  def schedule(self, seen=None): return create_schedule(self.real_lbs, seen)
+  def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
+  def const(self, val:Scalar) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.lbs], self.axis, self.real)
+  def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real)
 
   # elementwise is simple
   def e(self, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:MultiLazyBuffer, arg:Optional[Any]=None) -> MultiLazyBuffer:
@@ -62,11 +71,15 @@
     # NOTE: they all have to share an axis, we always choose [-1]
     axis = axes[-1] if len(axes := dedup([x.axis for x in msrcs if x.axis is not None])) else None
     srcs = []
+    not_all_real = any(not all(mlb.real) for mlb in msrcs)
+    new_real = [all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])] if not_all_real else self.real
+    assert any(new_real), "output contains no real lb"
     for mlb in msrcs:
-      if mlb.axis == axis: srcs.append(mlb.lbs)
+      if mlb.axis == axis or not_all_real: srcs.append(mlb.lbs)
       elif mlb.axis is None and axis is not None: srcs.append(to_sharded(mlb.lbs, axis))
       else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis))
-    return MultiLazyBuffer([lsrcs[0].e(op, *lsrcs[1:], arg=arg) for lsrcs in zip(*srcs)], axis)
+    # NOTE: lsrcs[-1].const(0) is correct for where
+    return MultiLazyBuffer([lsrcs[0].e(op, *lsrcs[1:], arg=arg) if r else lsrcs[-1].const(0) for lsrcs,r in zip(zip(*srcs),new_real)], axis, new_real)
 
   def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]:
     return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))
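Note: in `e()`, whenever any operand carries a partial `real` mask, the output mask is the elementwise AND of all input masks, and non-real shards are filled with `const(0)` placeholders. The mask arithmetic in isolation:

```python
# how e() derives new_real: an output shard is real only if it is real in every input
masks = [[True, False, False, False],   # e.g. self, after a shrink to bounds[0]
         [True, True, True, True]]      # e.g. a fully real operand
new_real = [all(rs) for rs in zip(*masks)]
assert new_real == [True, False, False, False]
assert any(new_real), "output contains no real lb"
```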
@@ -74,33 +87,48 @@
   def r(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> MultiLazyBuffer:
     if self.axis is not None and new_shape[self.axis] == 1:
       # all-reduce on sharded axes
-      return MultiLazyBuffer(all_reduce(op, [x.r(op, new_shape) for x in self.lbs]), None)
+      reduced_parts = [x.r(op, new_shape) if r else x.const(0, shape=new_shape) for x,r in zip(self.lbs, self.real)]
+      if all(self.real): return MultiLazyBuffer(all_reduce(op, reduced_parts), None)
+      return MultiLazyBuffer(reduced_parts, None, self.real)
     # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
-    return MultiLazyBuffer([x.r(op, self._shape_to_single_shard(new_shape, x)) for x in self.lbs], self.axis)
+    return MultiLazyBuffer([x.r(op, self._shape_to_single_shard(new_shape, x)) for x in self.lbs], self.axis, self.real)
 
   # *** movement ops ***
 
   def reshape(self, arg:Tuple[sint, ...]):
-    if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None)
+    if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None, self.real)
     arg_acc:List[sint] = list(itertools.accumulate(arg, operator.mul, initial=1))
     # new_axis is the one that preserves prod(prior to new_axis) and prod(post to new_axis)
     new_axis = [tuple(p) for p in zip(arg_acc, arg_acc[1:])].index((prod(self.shape[:self.axis]), prod(self.shape[:self.axis+1])))
-    return MultiLazyBuffer([x.reshape(tuple(x.shape[self.axis] if a == new_axis else s for a,s in enumerate(arg))) for x in self.lbs], new_axis)
+    return MultiLazyBuffer([x.reshape(tuple(x.shape[self.axis] if a == new_axis else s for a,s in enumerate(arg))) for x in self.lbs],
+                           new_axis, self.real)
 
   def pad(self, arg:Tuple[Tuple[sint, sint], ...]):
-    assert self.axis is None or arg[self.axis] == (0,0), "padding not supported on sharded axis"
-    return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis)
+    assert self.axis is None or arg[self.axis] == (0,0) or not all(self.real), f"padding not supported for {arg=}"
+    # pad on shard axis -> fill others with zeros and set real to all True
+    if self.axis is not None and arg[self.axis] != (0,0):
+      # pad back to whole axis, remove real mask
+      assert all(arg[i] == (0, 0) or i == self.axis for i in range(len(self.shape))), "cannot pad sharded and non-sharded axis at the same time"
+      assert arg[self.axis] == (sum(lb.shape[self.axis] for i,lb in enumerate(self.lbs) if i < self.real.index(True)), \
+        sum(lb.shape[self.axis] for i,lb in enumerate(self.lbs) if i > self.real.index(True))), "can only pad to whole axis"
+      return MultiLazyBuffer([x if r else x.const(0) for x,r in zip(self.lbs, self.real)], self.axis)
+    return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis, self.real)
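Note: the "can only pad to whole axis" assert in `pad` checks that the pad amounts equal the total shard sizes below and above the single real shard, using the `bounds` computed in `__init__`. The bookkeeping in isolation, for four shards of two rows each:

```python
import itertools

sizes = [2, 2, 2, 2]                                   # per-shard sizes along the sharded axis
splits = list(itertools.accumulate(sizes, initial=0))  # [0, 2, 4, 6, 8]
bounds = list(zip(splits, splits[1:]))                 # [(0, 2), (2, 4), (4, 6), (6, 8)]

real = [False, True, False, False]                     # after shrinking to bounds[1]
idx = real.index(True)
pad_arg = (sum(sizes[:idx]), sum(sizes[idx+1:]))       # (2, 4) is the only legal pad
assert pad_arg == (bounds[idx][0], splits[-1] - bounds[idx][1])
```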
 
   def expand(self, arg:Tuple[sint, ...]):
     # NOTE: this assert isn't needed, sharded axis can have dim 1
-    assert self.axis is None or arg[self.axis] == self.shape[self.axis], "expand not supported on sharded axis"
-    return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg, x)) for x in self.lbs], self.axis)
+    assert self.axis is None or arg[self.axis] == self.shape[self.axis], f"expand not supported on sharded axis {arg=}"
+    return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg, x)) for x in self.lbs], self.axis, self.real)
 
   def permute(self, arg:Tuple[int, ...]):
     # all permutes supported!
-    return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None)
+    return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None, self.real)
 
   def shrink(self, arg:Tuple[Tuple[sint, sint], ...]):
-    assert self.axis is None or arg[self.axis] == (0, self.shape[self.axis]), "shrinking not supported on sharded axis"
-    return MultiLazyBuffer(
-      [x.shrink(tuple((0, x.shape[self.axis]) if a == self.axis else (s1,s2) for a,(s1,s2) in enumerate(arg))) for x in self.lbs], self.axis)
+    assert self.axis is None or arg[self.axis] == (0, self.shape[self.axis]) or arg[self.axis] in self.bounds, f"shrinking not supported for {arg=}"
+    if self.axis is not None and arg[self.axis] in self.bounds and arg[self.axis] != (0, self.shape[self.axis]):
+      assert all(arg[i] == (0, s) or i == self.axis for i,s in enumerate(self.shape)), "cannot shrink sharded and non-sharded axis at the same time"
+      idx = self.bounds.index(arg[self.axis])
+      # zero out other lbs to not create lb reference
+      return MultiLazyBuffer([lb if i==idx else lb.const(0) for i,lb in enumerate(self.lbs)], self.axis, [i==idx for i in range(len(self.lbs))])
+    return MultiLazyBuffer([x.shrink(tuple((0, x.shape[self.axis]) if a == self.axis else s for a,s in enumerate(arg))) for x in self.lbs],
+                           self.axis, self.real)
 
   def stride(self, arg:Tuple[int, ...]):
     assert self.axis is None or arg[self.axis] == 1, "flipping not supported on sharded axis"
-    return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis)
+    return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis, self.real)
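Note: taken together, the pattern these changes enable is to shard a tensor, shrink each device's shard at its bound, compute on it independently, and `cat` (or pad-and-add) the results back. An end-to-end sketch, assuming two devices of the default backend are available; the op chain here is illustrative, composed from the ops the tests above exercise:

```python
from tinygrad import Tensor, Device

devices = [f"{Device.DEFAULT}:{i}" for i in range(2)]
x = Tensor.arange(16).reshape(4, 4).contiguous().realize().shard(devices, axis=0)

outs = []
for i, bound in enumerate(x.lazydata.bounds):
  xi = x.shrink((bound, None))   # runs on device i only; the real mask selects its shard
  outs.append(xi + i)            # per-shard compute, the same shape UnsyncedBatchNorm uses
out = outs[0].cat(*outs[1:])     # stitch the shards back together
print(out.numpy())               # rows 0-1 unchanged, rows 2-3 incremented by 1
```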