shrink MLB on sharded axis (#3255)
* shrink MLB on sharded axis

  Use a onehot structure to store the real partition. The goal is an unsynced BatchNorm2d that can run on multiple GPUs for training (a sketch of the sharded-shrink usage follows this list). Draft version in https://github.com/chenyuxyz/tinygrad/pull/109.

* SYNCBN flag
* test unclean shrinks
* UnsyncedBatchNorm reuses BatchNorm
* more robust pad arg check
* better types
* more tests!
* 6 GPUs in benchmark
* disable slow GPUS=6 benchmark
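For context beyond the commit message: the change lets a tensor sharded across devices be shrunk down to one shard's bounds without gathering data from the other devices, which is what makes a per-device batchnorm possible. A minimal sketch of that behavior, assuming a tinygrad checkout from around this commit; the two-device GPUS tuple and the tensor shape are illustrative, not taken from the diff:

from tinygrad import Tensor, Device
from tinygrad.features.multi import MultiLazyBuffer

# illustrative device list: two instances of the default backend
GPUS = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))

# shard a batch of 8 CIFAR-sized images across the two devices on axis 0
x = Tensor.rand(8, 3, 32, 32).shard(GPUS, axis=0)
assert isinstance(x.lazydata, MultiLazyBuffer)

# each device owns a contiguous slice of the batch; x.lazydata.bounds
# records those slices. shrinking to one shard's bounds keeps the data
# on that device instead of syncing, which is what this commit enables.
for bound in x.lazydata.bounds:
  xi = x.shrink((bound, None, None, None))
  print(xi.shape)  # (4, 3, 32, 32), one shard per device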
@@ -5,13 +5,14 @@
 # https://siboehm.com/articles/22/CUDA-MMM
 import random, time
 import numpy as np
-from typing import Optional
+from typing import Optional, List
 from extra.datasets import fetch_cifar, cifar_mean, cifar_std
 from extra.lr_scheduler import OneCycleLR
 from tinygrad import nn, dtypes, Tensor, Device, GlobalCounters, TinyJit
 from tinygrad.nn.state import get_state_dict, get_parameters
 from tinygrad.nn import optim
 from tinygrad.helpers import Context, BEAM, WINO, getenv
+from tinygrad.features.multi import MultiLazyBuffer

 BS, STEPS = getenv("BS", 512), getenv("STEPS", 1000)
 EVAL_BS = getenv("EVAL_BS", BS)
@@ -33,13 +34,39 @@ class BatchNorm(nn.BatchNorm2d):
     self.weight.requires_grad = False
     self.bias.requires_grad = True

+class UnsyncedBatchNorm:
+  def __init__(self, num_features, num_devices=len(GPUS)):
+    self.bns:List[BatchNorm] = []
+    for _ in range(num_devices):
+      bn = BatchNorm(num_features)
+      self.bns.append(bn)
+
+  def __call__(self, x:Tensor):
+    if len(self.bns) == 1: return self.bns[0](x)
+
+    bn_ts = []
+    assert isinstance(x.lazydata, MultiLazyBuffer)
+    for bound, bn in zip(x.lazydata.bounds, self.bns):
+      # TODO: __getitem__ does not work
+      # xi = x[bound]
+      xi = x.shrink((bound, None, None, None))
+      bni = bn(xi)
+      bn_ts.append(bni)
+    # TODO: what do we want to do for inference? average weight? pick any one?
+    # a good start would be to check each mean/std are similar
+    return bn_ts[0].cat(*bn_ts[1:])
+
 class ConvGroup:
   def __init__(self, channels_in, channels_out):
     self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1, bias=False)
     self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1, bias=False)

-    self.norm1 = BatchNorm(channels_out)
-    self.norm2 = BatchNorm(channels_out)
+    if getenv("SYNCBN"):
+      self.norm1 = BatchNorm(channels_out)
+      self.norm2 = BatchNorm(channels_out)
+    else:
+      self.norm1 = UnsyncedBatchNorm(channels_out)
+      self.norm2 = UnsyncedBatchNorm(channels_out)

   def __call__(self, x):
     x = self.conv1(x)
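Judging by the imports, this hunk is from examples/hlb_cifar10.py. After the change, multi-GPU runs default to the per-device UnsyncedBatchNorm, with the old synced behavior kept behind the new SYNCBN flag; GPUS is the environment variable the commit message's benchmarks already use. Illustrative invocations (the device count is arbitrary):

GPUS=2 python3 examples/hlb_cifar10.py             # default: one BatchNorm per device, no cross-device sync
GPUS=2 SYNCBN=1 python3 examples/hlb_cifar10.py    # previous behavior: batchnorm stats shared across devices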