shrink MLB on sharded axis (#3255)

* shrink MLB on sharded axis

use a one-hot structure to store the real partition. the goal is an unsynced BatchNorm2d that can run on multiple GPUs for training (a toy sketch of the sharding idea follows the draft link below).

draft version in https://github.com/chenyuxyz/tinygrad/pull/109
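
a toy sketch of the sharding idea in plain numpy, not tinygrad internals (ShardedBuffer, shard, and shrink are made-up stand-ins for MultiLazyBuffer behavior, and this version only handles shrinks that land on partition boundaries; the commit also covers unclean shrinks):

import numpy as np
from dataclasses import dataclass

@dataclass
class ShardedBuffer:      # stand-in for a multi-device tensor
    chunks: list          # one np.ndarray per device, split along axis 0
    real: list            # mask of which chunks the current view covers

def shard(x, n):
    return ShardedBuffer(np.split(x, n, axis=0), [True] * n)

def shrink(sb, start, stop):
    # no data moves: every chunk stays on its device, and the mask records
    # which partitions fall inside [start, stop); selecting exactly one
    # device's partition yields a one-hot mask
    size = sb.chunks[0].shape[0]
    real = [start <= i * size and (i + 1) * size <= stop for i in range(len(sb.chunks))]
    return ShardedBuffer(sb.chunks, real)

x = np.arange(8).reshape(8, 1)
sb = shrink(shard(x, 4), 2, 4)  # view of rows 2:4 == device 1's partition
print(sb.real)                  # [False, True, False, False]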

* SYNCBN flag

* test unclean shrinks

* UnsyncedBatchNorm reuses BatchNorm (see the batchnorm sketch after this list)

* more robust pad arg check

* better types

* more tests!

* 6 GPUs in benchmark

* disable slow GPUS=6 benchmark
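
a minimal numpy sketch of the unsynced-vs-synced distinction referenced above (illustrative only: batchnorm_stats is made up, not tinygrad's UnsyncedBatchNorm, and the sync argument stands in for the SYNCBN flag; how that flag is wired into the example script is not shown here):

import numpy as np

def batchnorm_stats(shards, sync):
    # shards: one (N, C, H, W) array per device, the batch split along N
    if sync:
        # SYNCBN=1: one set of per-channel stats over the full batch,
        # shared by every device (requires a cross-device reduction)
        full = np.concatenate(shards, axis=0)
        stats = (full.mean(axis=(0, 2, 3)), full.var(axis=(0, 2, 3)))
        return [stats] * len(shards)
    # SYNCBN=0: each device computes stats over its own shard only
    return [(s.mean(axis=(0, 2, 3)), s.var(axis=(0, 2, 3))) for s in shards]

shards = [np.random.randn(8, 16, 4, 4).astype(np.float32) for _ in range(6)]
per_gpu = batchnorm_stats(shards, sync=False)  # 6 independent (mean, var) pairs
synced  = batchnorm_stats(shards, sync=True)   # identical stats on every device

with sync off there is no cross-device reduction in the forward pass, which is what makes the multi-GPU training path cheap.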
chenyu
2024-01-31 21:48:25 -05:00
committed by GitHub
parent a3652e6ddc
commit 18e854cdbf
4 changed files with 317 additions and 29 deletions

@@ -139,6 +139,9 @@ jobs:
         HIP=1 JIT=1 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt
     - name: Run 10 CIFAR training steps
       run: STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt
+    # TODO: enable this. it took 3 minutes in CI and made the full training one more than 5 minutes
+    # - name: Run 10 CIFAR training steps w HALF and 6 GPUS
+    #   run: time HALF=1 STEPS=10 BS=1536 GPUS=6 python3 examples/hlb_cifar10.py
     - name: Run full CIFAR training w HALF
       run: time HALF=1 STEPS=1000 python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
     # TODO: make wino faster so we can enable both