shrink MLB on sharded axis (#3255)

* shrink MLB on sharded axis

use a one-hot structure to store the real partition. the goal is an unsynced BatchNorm2d that can be run on multiple GPUs for training.

draft version in https://github.com/chenyuxyz/tinygrad/pull/109
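
For illustration, a minimal sketch of the idea (assuming a tinygrad build from around this commit; `Tensor.shard`, `MultiLazyBuffer.bounds`, and the `shrink` call mirror the diff below, while the 2-GPU setup is hypothetical):

```python
from tinygrad import Tensor, Device
from tinygrad.features.multi import MultiLazyBuffer

GPUS = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))  # hypothetical 2-device setup
x = Tensor.rand(8, 3, 32, 32).shard(GPUS, axis=0)        # batch axis sharded across devices

assert isinstance(x.lazydata, MultiLazyBuffer)
for bound in x.lazydata.bounds:
  # shrinking on the sharded axis selects one device's real partition,
  # so per-device ops (like an unsynced batchnorm) see only their own shard
  xi = x.shrink((bound, None, None, None))
  print(bound, xi.shape)
```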

* SYNCBN flag (see the usage sketch after this list)

* test unclean shrinks

* UnsyncedBatchNorm reuses BatchNorm

* more robust pad arg check

* better types

* more tests!

* 6 GPUs in benchmark

* disable slow GPUS=6 benchmark
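
As a usage note on the SYNCBN flag (a sketch; `getenv` is tinygrad's real helper, while the actual class selection lives in `ConvGroup` in the diff below):

```python
from tinygrad.helpers import getenv

# getenv("SYNCBN") returns 0 when the flag is unset, so the unsynced
# per-device batchnorm is the default; run with SYNCBN=1 to keep a single
# BatchNorm over the whole sharded batch.
norm_kind = "BatchNorm (synced)" if getenv("SYNCBN") else "UnsyncedBatchNorm (per-device)"
print(norm_kind)
```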
Author: chenyu
Date: 2024-01-31 21:48:25 -05:00
Commit: 18e854cdbf (parent: a3652e6ddc)
4 changed files with 317 additions and 29 deletions

@@ -5,13 +5,14 @@
 # https://siboehm.com/articles/22/CUDA-MMM
 import random, time
 import numpy as np
-from typing import Optional
+from typing import Optional, List
 from extra.datasets import fetch_cifar, cifar_mean, cifar_std
 from extra.lr_scheduler import OneCycleLR
 from tinygrad import nn, dtypes, Tensor, Device, GlobalCounters, TinyJit
 from tinygrad.nn.state import get_state_dict, get_parameters
 from tinygrad.nn import optim
 from tinygrad.helpers import Context, BEAM, WINO, getenv
+from tinygrad.features.multi import MultiLazyBuffer

 BS, STEPS = getenv("BS", 512), getenv("STEPS", 1000)
 EVAL_BS = getenv("EVAL_BS", BS)
@@ -33,13 +34,39 @@ class BatchNorm(nn.BatchNorm2d):
     self.weight.requires_grad = False
     self.bias.requires_grad = True

+class UnsyncedBatchNorm:
+  def __init__(self, num_features, num_devices=len(GPUS)):
+    self.bns:List[BatchNorm] = []
+    for _ in range(num_devices):
+      bn = BatchNorm(num_features)
+      self.bns.append(bn)
+
+  def __call__(self, x:Tensor):
+    if len(self.bns) == 1: return self.bns[0](x)
+    bn_ts = []
+    assert isinstance(x.lazydata, MultiLazyBuffer)
+    for bound, bn in zip(x.lazydata.bounds, self.bns):
+      # TODO: __getitem__ does not work
+      # xi = x[bound]
+      xi = x.shrink((bound, None, None, None))
+      bni = bn(xi)
+      bn_ts.append(bni)
+
+    # TODO: what do we want to do for inference? average weight? pick any one?
+    # a good start would be to check each mean/std are similar
+    return bn_ts[0].cat(*bn_ts[1:])
+
 class ConvGroup:
   def __init__(self, channels_in, channels_out):
     self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1, bias=False)
     self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1, bias=False)

-    self.norm1 = BatchNorm(channels_out)
-    self.norm2 = BatchNorm(channels_out)
+    if getenv("SYNCBN"):
+      self.norm1 = BatchNorm(channels_out)
+      self.norm2 = BatchNorm(channels_out)
+    else:
+      self.norm1 = UnsyncedBatchNorm(channels_out)
+      self.norm2 = UnsyncedBatchNorm(channels_out)

   def __call__(self, x):
     x = self.conv1(x)
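
To make the unsynced semantics concrete, a NumPy reference of what the per-shard normalization above computes (a sketch: the affine weight/bias are omitted and the eps value is illustrative, not taken from the diff):

```python
import numpy as np

x = np.random.randn(8, 4, 2, 2).astype(np.float32)  # a batch in (N, C, H, W)
outs = []
for shard in np.split(x, 2, axis=0):                # pretend 2 devices
  mean = shard.mean(axis=(0, 2, 3), keepdims=True)  # per-shard statistics only
  var = shard.var(axis=(0, 2, 3), keepdims=True)
  outs.append((shard - mean) / np.sqrt(var + 1e-5))
unsynced = np.concatenate(outs, axis=0)             # mirrors bn_ts[0].cat(*bn_ts[1:])

# a synced BatchNorm would use statistics over the whole batch instead:
mean = x.mean(axis=(0, 2, 3), keepdims=True)
synced = (x - mean) / np.sqrt(x.var(axis=(0, 2, 3), keepdims=True) + 1e-5)
assert not np.allclose(unsynced, synced)            # the two differ in general
```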