MultiLazyBuffer is UOp [pr] (#8662)
* MultiLazyBuffer is UOp [pr]
* this is new mlb
* this is the idea
* progress
* multitensor works
* more movement ops
* this
* MultiLazyBuffer is UOp
* cleanups
* multi axis
* fix more tests
* work
* not that
* add multi grad and move shard to ops
* mops not views
* no double contig
* sweet, all mt tests passing
* port old logic
* remove lbs
* fix realized
* whitespace
* assign tweak
* test_assign_kv_cache_multi passes
* fix is_realized
* fix JIT for multi
* just a few more lines i'll pay them back soon i swear please bro just a few more
* no split reduceop for multi
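Since this change moves shard into the ops layer and makes a sharded Tensor's lazydata a plain UOp, the Tensor-level multi-device API is unchanged. A minimal sketch of that surface, assuming a tinygrad build at or after this commit and two reachable devices (the GPUS tuple here is illustrative, not from the diff):

from tinygrad import Tensor, Device

# illustrative device pair; any two reachable devices work
GPUS = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))

t = Tensor.arange(8).reshape(4, 2).shard(GPUS, axis=0)  # split rows across devices
print(t.device)              # a tuple of device names for a sharded tensor
print(type(t.lazydata))      # after this PR: a UOp, no MultiLazyBuffer wrapper
print((t + 1).sum().item())  # ops on sharded tensors work as before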
@@ -11,7 +11,6 @@ from tinygrad import nn, dtypes, Tensor, Device, GlobalCounters, TinyJit
 from tinygrad.nn.state import get_state_dict, get_parameters
 from tinygrad.nn import optim
 from tinygrad.helpers import Context, BEAM, WINO, getenv, colored, prod
-from tinygrad.multi import MultiLazyBuffer
 
 cifar_mean = [0.4913997551666284, 0.48215855929893703, 0.4465309133731618]
 cifar_std = [0.24703225141799082, 0.24348516474564, 0.26158783926049628]
@@ -35,8 +34,6 @@ class UnsyncedBatchNorm:
     self.num_batches_tracked = Tensor.zeros(1, dtype=dtypes.int, requires_grad=False)
 
   def __call__(self, x:Tensor):
-    if isinstance(x.lazydata, MultiLazyBuffer): assert x.lazydata.axis is None or x.lazydata.axis == 0 and len(x.lazydata.lbs) == self.num_devices
-
     xr = x.reshape(self.num_devices, -1, *x.shape[1:]).cast(dtypes.float32)
     batch_mean, batch_invstd = self.calc_stats(xr)
     ret = xr.batchnorm(
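With the MultiLazyBuffer class gone, the isinstance guard removed above has no direct replacement in the diff; the shard information now lives on the UOp itself. A hedged sketch of an equivalent check after this commit (check_sharding is a hypothetical helper; it assumes the sharded UOp keeps the axis attribute used by the old assert, and swaps len(x.device) for the removed lbs list, since the commit message says "remove lbs"):

from tinygrad import Tensor

def check_sharding(x: Tensor, num_devices: int):
  # sharded tensors report a tuple of devices; single-device tensors a str
  if isinstance(x.device, tuple):
    # mirrors the removed assert, with parentheses making the original
    # "or ... and ..." precedence explicit
    assert x.lazydata.axis is None or (x.lazydata.axis == 0 and len(x.device) == num_devices)

The reshape that remains in __call__ is what actually keeps the statistics unsynced: folding the leading batch dimension into (num_devices, -1, ...) makes calc_stats reduce per device shard rather than across all of them.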