MultiLazyBuffer is UOp [pr] (#8662)

* MultiLazyBuffer is UOp [pr]

* this is new mlb

* this is the idea

* progress

* multitensor works

* more movement ops

* this

* MultiLazyBuffer is UOp

* cleanups

* multi axis

* fix more tests

* work

* not that

* add multi grad and move shard to ops

* mops not views

* no double contig

* sweet, all mt tests passing

* port old logic

* remove lbs

* fix realized

* whitespace

* assign tweak

* test_assign_kv_cache_multi passes

* fix is_realized

* fix JIT for multi

* just a few more lines i'll pay them back soon i swear please bro just a few more

* no split reduceop for multi
George Hotz
2025-01-24 13:28:55 +09:00
committed by GitHub
parent eb77488f85
commit e82ba1454b
11 changed files with 277 additions and 210 deletions
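For context, this PR reworks tinygrad's multi-device (sharded) tensors so the multi state lives on the UOp graph instead of in a separate MultiLazyBuffer class. A minimal sketch of the sharding flow the PR's tests exercise ("multitensor works", "sweet, all mt tests passing"), assuming two instances of the default device are available; the ":0"/":1" device names are illustrative, not from this diff:

  from tinygrad import Tensor, Device

  # two instances of the default device; the suffixed names are an assumption
  devices = (f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1")

  t = Tensor.arange(8).reshape(2, 4)
  ts = t.shard(devices, axis=0)   # split axis 0 across the two devices
  print(ts.device)                # a multi tensor reports a tuple of devices
  print(ts.lazydata.axis)         # after this PR, lazydata is a UOp carrying the shard axis
  print((ts + 1).numpy())         # elementwise ops run per-shard, then gather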

examples/hlb_cifar10.py

@@ -11,7 +11,6 @@ from tinygrad import nn, dtypes, Tensor, Device, GlobalCounters, TinyJit
 from tinygrad.nn.state import get_state_dict, get_parameters
 from tinygrad.nn import optim
 from tinygrad.helpers import Context, BEAM, WINO, getenv, colored, prod
-from tinygrad.multi import MultiLazyBuffer
 
 cifar_mean = [0.4913997551666284, 0.48215855929893703, 0.4465309133731618]
 cifar_std = [0.24703225141799082, 0.24348516474564, 0.26158783926049628]
@@ -35,8 +34,6 @@ class UnsyncedBatchNorm:
     self.num_batches_tracked = Tensor.zeros(1, dtype=dtypes.int, requires_grad=False)
   def __call__(self, x:Tensor):
-    if isinstance(x.lazydata, MultiLazyBuffer): assert x.lazydata.axis is None or x.lazydata.axis == 0 and len(x.lazydata.lbs) == self.num_devices
     xr = x.reshape(self.num_devices, -1, *x.shape[1:]).cast(dtypes.float32)
     batch_mean, batch_invstd = self.calc_stats(xr)
     ret = xr.batchnorm(
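The removed line is the point of the change: after this PR, x.lazydata is always a UOp, including for multi-device tensors, so the MultiLazyBuffer isinstance check (and the import deleted in the first hunk) no longer applies, and the PR simply drops the assert. A hedged sketch of an equivalent guard against the post-PR interface, assuming a multi tensor reports a tuple device and its UOp keeps an .axis attribute; check_sharding is a hypothetical helper, not code from this diff:

  from tinygrad import Tensor

  def check_sharding(x: Tensor, num_devices: int):
    # hypothetical helper reproducing the old assert via the UOp interface
    if isinstance(x.device, tuple):  # multi-device tensor
      assert x.lazydata.axis is None or (x.lazydata.axis == 0 and len(x.device) == num_devices)

Note the parentheses make explicit what the old one-liner relied on operator precedence for: the device-count check only applies when the shard axis is 0.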