import unittest, functools
from typing import List
from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit
from tinygrad.device import BufferCopy
from tinygrad.ops import LoadOps, ReduceOps
from tinygrad.helpers import CI
from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.realize import create_schedule
import numpy as np
from hypothesis import given, strategies as strat, settings

settings.register_profile("my_profile", max_examples=200, deadline=None)
settings.load_profile("my_profile")

d_zero = f"{Device.DEFAULT}:0"
d0, d1 = f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2"
d2, d3 = f"{Device.DEFAULT}:3", f"{Device.DEFAULT}:4"
devices_2 = (d0, d1)
devices_3 = (d0, d1, d2)
N = 128

# shard_x is "data parallel"
# shard_w is "model parallel"
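#
# Sharding X along the batch axis while the weights stay replicated is the data-parallel
# case; sharding W while X stays replicated is the model-parallel case. A rough
# data-parallel sketch (illustrative only, shapes made up, using the same shard API
# exercised by the tests below):
#   X = Tensor.rand(256, 64).shard((d0, d1), axis=0)   # each device holds half the batch
#   W = Tensor.rand(64, 8).shard((d0, d1), axis=None)  # weights replicated on both devices
#   out = X @ W                                        # per-device matmul on the local shard
#   out.numpy()                                        # gathers the sharded result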

@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
class TestMultiTensor(unittest.TestCase):
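  # to_/to replicate a tensor across devices (test_to checks every lazybuffer keeps the
  # full (256,) shape), while shard_/shard split it along the given axis (test_shard
  # checks each device holds a (128,) slice); shard with axis=None also replicates, and
  # the trailing-underscore variants modify the tensor in place.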
  def test_to(self):
    X = Tensor.ones(256).contiguous().realize()
    X.to_((d0, d1))
    for lb in X.lazydata.lbs:
      assert lb.shape == (256,)
    (X + X).realize()

  def test_shard(self):
    X = Tensor.ones(256).contiguous().realize()
    X.shard_((d0, d1), 0)
    for lb in X.lazydata.lbs:
      assert lb.shape == (128,)
    (X + X).realize()

  def test_shard_same_device(self):
    X = Tensor.ones(256).contiguous().realize()
    X.shard_((d0, X.device), 0)
    (X + X).realize()

  def test_shard_plus_one_sum(self):
    X = Tensor.ones(256).contiguous().realize()
    X.shard_([d0, d1], 0)
    (X + 1).sum().realize()

  def test_shard_plus_one_sum_d_zero(self):
    X = Tensor.ones(256).contiguous().realize()
    X.shard_([d_zero, d1], 0)
    (X + 1).sum().realize()

  def test_numpy(self):
    X = Tensor.ones(256)
    X.shard_((d0, d1), 0)
    np.testing.assert_allclose(X.numpy(), 1)

  def _test_simple_add_axis(self, shard_x, shard_w):
    X = Tensor.ones(256).contiguous().realize()
    W = Tensor.ones(256).contiguous().realize()
    X.shard_((d0, d1), shard_x)
    W.shard_((d0, d1), shard_w)
    O = X + W
    np.testing.assert_allclose(O.numpy(), 2)

  def test_simple_add(self): return self._test_simple_add_axis(None, None)
  def test_simple_add_X(self): return self._test_simple_add_axis(0, None)
  def test_simple_add_W(self): return self._test_simple_add_axis(None, 0)
  def test_simple_add_XW(self): return self._test_simple_add_axis(0, 0)

  def test_four_add(self):
    X = Tensor.ones(256, 256).contiguous().realize()
    W = Tensor.ones(256, 256).contiguous().realize()
    X.shard_((d0, d1, d2, d3), 1)
    W.shard_((d0, d1, d2, d3), None)
    O = X + W
    np.testing.assert_allclose(O.numpy(), 2)

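  # Reducing over the sharded axis needs the partial results combined across devices
  # (an all-reduce style step), while reducing over a non-sharded axis stays local to
  # each shard; test_simple_reduce checks both against the single-device numpy result.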
  @given(strat.sampled_from((4, 5)), strat.sampled_from((devices_2, devices_3)), strat.sampled_from((ReduceOps.SUM, ReduceOps.MAX)),
         strat.sampled_from((None, 0, 1)), strat.sampled_from((None, 0, 1)), strat.sampled_from((1, 0, -1)))
  def test_simple_reduce(self, N, devices, rop, shard_axis, reduce_axis, sign):
    X = Tensor.rand(N*N).reshape(N, N).mul(sign)
    n = X.numpy()
    X.shard_(devices, shard_axis)
    f = {ReduceOps.SUM: lambda x: x.sum(reduce_axis), ReduceOps.MAX: lambda x: x.max(reduce_axis)}[rop]
    fX = f(X)
    fn = f(n)
    np.testing.assert_allclose(fX.numpy(), fn, rtol=1e-6, atol=1e-6)

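  # For the sharded matmuls below: sharding X's rows or W's columns just shards the output,
  # but if the contraction dimension (X's columns / W's rows) is the sharded axis, each
  # device only computes a partial product that has to be summed across devices.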
  def _test_matmul_shard_axis(self, shard_x, shard_w, device):
    X = Tensor.kaiming_uniform(N, N).realize()
    W = Tensor.kaiming_uniform(N, N).realize()
    Xs = X.shard(device, shard_x)
    Ws = W.shard(device, shard_w)
    O = (Xs@Ws)
    np.testing.assert_allclose(X.numpy() @ W.numpy(), O.to(Device.DEFAULT).numpy(), atol=1e-5)

  def _test_double_matmul_shard_axis(self, shard_x, shard_w, device):
    X = Tensor.kaiming_uniform(N, N).realize()
    W1 = Tensor.kaiming_uniform(N, N).realize()
    W2 = Tensor.kaiming_uniform(N, N).realize()
    Xs = X.shard(device, shard_x)
    W1s = W1.shard(device, shard_w)
    W2s = W2.shard(device, shard_w)
    O = (Xs@W1s)@W2s
    np.testing.assert_allclose((X.numpy() @ W1.numpy()) @ W2.numpy(), O.to(Device.DEFAULT).numpy(), atol=1e-5)

  def test_matmul_shard_none(self): return self._test_matmul_shard_axis(None, None, devices_2)
  def test_matmul_shard_X_0(self): return self._test_matmul_shard_axis(0, None, devices_2)
  def test_matmul_shard_X_1(self): return self._test_matmul_shard_axis(1, None, devices_2)
  def test_matmul_shard_W_0(self): return self._test_matmul_shard_axis(None, 0, devices_2)
  def test_matmul_shard_W_1(self): return self._test_matmul_shard_axis(None, 1, devices_2)

  def test_matmul_shard_0_0(self): return self._test_matmul_shard_axis(0, 0, devices_2)
  def test_matmul_shard_0_1(self): return self._test_matmul_shard_axis(0, 1, devices_2)
  def test_matmul_shard_1_0(self): return self._test_matmul_shard_axis(1, 0, devices_2)
  def test_matmul_shard_1_1(self): return self._test_matmul_shard_axis(1, 1, devices_2)

  def test_double_matmul_shard_X_0(self): return self._test_double_matmul_shard_axis(0, None, devices_2)
  def test_double_matmul_shard_X_1(self): return self._test_double_matmul_shard_axis(1, None, devices_2)
  def test_double_matmul_shard_W_0(self): return self._test_double_matmul_shard_axis(None, 0, devices_2)
  def test_double_matmul_shard_W_1(self): return self._test_double_matmul_shard_axis(None, 1, devices_2)

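  # Data-parallel convolution: the conv weights are replicated on every device and only
  # the image batch is sharded on axis 0, so each device convolves its own slice of the
  # batch independently.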
  def test_conv_data_shard(self):
    conv = nn.Conv2d(3, 16, 3, bias=False)
    for p in get_parameters(conv): p.shard_((d0, d1))
    fake_image = Tensor.rand((2, 3, 32, 32)).shard((d0, d1), axis=0)
    out = conv(fake_image)
    out.numpy()

  def test_conv_bias_data_shard(self):
    conv = nn.Conv2d(3, 16, 3)
    for p in get_parameters(conv): p.shard_((d0, d1))
    fake_image = Tensor.rand((2, 3, 32, 32)).shard((d0, d1), axis=0)
    out = conv(fake_image)
    out.numpy()

  def test_backprop_conv(self):
    conv = nn.Conv2d(3, 16, 3)
    for p in get_parameters(conv): p.shard_((d0, d1))
    optim = nn.optim.Adam(get_parameters(conv))
    fake_image = Tensor.rand((2, 3, 32, 32)).shard((d0, d1), axis=0)
    out = conv(fake_image)
    optim.zero_grad()
    out.mean().backward()
    #for p in get_parameters(conv): p.grad.realize()
    optim.step()
    out.numpy()

  def test_lr_scheduler_OneCycleLR(self):
    from extra.lr_scheduler import OneCycleLR
    conv = nn.Conv2d(3, 16, 3)
    for p in get_parameters(conv): p.shard_((d0, d1))
    optim = nn.optim.SGD(get_parameters(conv))
    lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
    lr_sched.step()

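  # Embedding sharding: the embedding table is split along axis=1 (the embedding
  # dimension), so every device stores a slice of every row, while the integer index
  # tensor is replicated (axis=None).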
  def test_embedding(self):
    B, T, embed_size, vocab_size = 4, 10, 20, 28

    layer = nn.Embedding(vocab_size, embed_size)
    x = Tensor(np.random.randint(0, vocab_size, (B, T)))
    z = layer(x)

    layer_sharded = nn.Embedding(vocab_size, embed_size)
    layer_sharded.weight.assign(layer.weight.shard((d0, d1), axis=1)).realize()
    x_sharded = x.shard((d0, d1), axis=None)
    z_shard = layer_sharded(x_sharded)

    np.testing.assert_allclose(z.numpy(), z_shard.numpy(), atol=1e-6, rtol=1e-6)

  def test_rmsnorm(self):
    from extra.models.llama import RMSNorm
    B, T, embed_size = 4, 10, 20

    layer_norm = RMSNorm(embed_size)
    x = Tensor.rand((B, T, embed_size)).contiguous().realize()
    y = layer_norm(x)

    # for norm layers, the correct way to shard weights is duplication
    layer_norm_sharded = RMSNorm(embed_size)
    layer_norm_sharded.weight.shard_((d0, d1), axis=None).realize()

    # if x is being sharded, then all-reduce is involved
    x_sharded = x.shard((d0, d1), axis=2).realize()
    y_shard = layer_norm_sharded(x_sharded).realize()
    np.testing.assert_allclose(y.numpy(), y_shard.numpy(), atol=1e-6, rtol=1e-6)

    # if x is being duplicated, then the operations remain inside each GPU
    # which is the common case
    x_sharded = x.shard((d0, d1), axis=None).realize()
    y_shard = layer_norm_sharded(x_sharded).realize()
    np.testing.assert_allclose(y.numpy(), y_shard.numpy(), atol=1e-6, rtol=1e-6)

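  # Data-parallel ResNet: the parameters are replicated on both devices, the image batch
  # is sharded on axis 0, and the sharded forward pass (and, in the train-step test, the
  # gradients) should match the single-device result to within float tolerance.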
  def test_data_parallel_resnet(self):
    import sys, pathlib
    sys.path.append((pathlib.Path(__file__).parent.parent / "extra" / "models").as_posix())
    from resnet import ResNet18

    fake_image = Tensor.rand((2, 3, 224, 224))
    fake_image_sharded = fake_image.shard((d0, d1), axis=0)
    m = ResNet18()
    m.load_from_pretrained()
    real_output = m(fake_image).log_softmax().numpy()
    for p in get_parameters(m): p.shard_((d0, d1)).realize()
    GlobalCounters.reset()
    shard_output = m(fake_image_sharded).log_softmax().realize()
    assert shard_output.lazydata.lbs[0].shape == (1, 1000)
    assert shard_output.lazydata.lbs[1].shape == (1, 1000)
    shard_output_np = shard_output.numpy()
    np.testing.assert_allclose(real_output, shard_output_np, atol=1e-6, rtol=1e-6)

  def test_data_parallel_resnet_train_step(self):
    import sys, pathlib
    sys.path.append((pathlib.Path(__file__).parent.parent / "extra" / "models").as_posix())
    from resnet import ResNet18
    from examples.mlperf.optimizers import LARS

    fake_image = Tensor.rand((2, 3, 224, 224))
    fake_image_sharded = fake_image.shard((d0, d1), axis=0)
    labels = Tensor.randint(2, low=0, high=1000)
    labels_sharded = labels.shard((d0, d1), axis=0)

    m = ResNet18()
    optimizer = LARS(get_parameters(m), 0.1) # set requires_grad for all params

    optimizer.zero_grad()
    m.load_from_pretrained()
    output = m(fake_image).sparse_categorical_crossentropy(labels, label_smoothing=0.1)
    output.backward()
    grad = m.conv1.weight.grad.numpy()

    for p in get_parameters(m): p.shard_((d0, d1)).realize()
    GlobalCounters.reset()
    optimizer.zero_grad()
    shard_output = m(fake_image_sharded).sparse_categorical_crossentropy(labels_sharded, label_smoothing=0.1)
    assert shard_output.lazydata.axis is None
    shard_output.backward()
    shard_grad = m.conv1.weight.grad.numpy()
    # sometimes there are zeros in these grads... why?
    np.testing.assert_allclose(grad, shard_grad, atol=1e-6, rtol=1e-6)

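  # TinyJit with multi-device tensors: after the warm-up calls, jit_cache should hold the
  # captured kernels for both shards and replay them on later calls, whether the sharded
  # inputs are passed in (test_multi_tensor_jit_param) or created inside the jitted body
  # (test_multi_tensor_jit_body); the asserts below only check the cache is non-empty.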
  def test_multi_tensor_jit_param(self):
    @TinyJit
    def jf(a, b) -> Tensor:
      return (a + b).realize()

    for _ in range(5):
      a = Tensor.ones(256).contiguous().realize()
      b = Tensor.ones(256).contiguous().realize()
      a.shard_((d0, d1))
      b.shard_((d0, d1))
      c = jf(a, b)
      np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
    assert len(jf.jit_cache) > 0

  def test_multi_tensor_jit_body(self):
    @TinyJit
    def jf() -> Tensor:
      a = Tensor.ones(256).contiguous().realize()
      b = Tensor.ones(256).contiguous().realize()
      a.shard_((d0, d1))
      b.shard_((d0, d1))
      return (a + b).realize()

    for _ in range(5):
      r = jf()
      np.testing.assert_allclose(r.numpy(), np.ones(256)+np.ones(256), atol=1e-4, rtol=1e-5)
    assert len(jf.jit_cache) > 0

  #@unittest.skipIf(CI and Device.DEFAULT=="METAL", "no ICB in CI, creation of graph fails")
  @unittest.skip("test broken")
  def test_multi_device_jit_graph(self):
    if Device[d0].graph is None or Device[d1].graph is None: raise unittest.SkipTest("only test graphs")

    @TinyJit
    def jf(a: Tensor, b: Tensor, c: Tensor, d:Tensor):
      # Create 80 entries on device 0: 2 batches.
      for _ in range(40):
        a = ((a + b).realize() + (a * b).realize()).realize()
      # Create 80 entries on device 1: 2 batches.
      for _ in range(40):
        c = ((c + d).realize() + (c * d).realize()).realize()
      # Create a copy from device 0 to 1: 1 entry.
      a = a.to(d1).realize()
      # Creates one last entry on device 1: 1 batch.
      return (a + c).realize()

    a = Tensor.randn(10, 10, device=d0).realize()
    b = Tensor.randn(10, 10, device=d0).realize()
    c = Tensor.randn(10, 10, device=d1).realize()
    d = Tensor.randn(10, 10, device=d1).realize()

    ref = jf(a, b, c, d).numpy()
    for _ in range(5):
      o = jf(a, b, c, d).numpy()
      np.testing.assert_allclose(ref, o, atol=1e-4, rtol=1e-5)

    graph_d0 = Device[d0].graph.func if isinstance(Device[d0].graph, functools.partial) else Device[d0].graph
    graph_d1 = Device[d1].graph.func if isinstance(Device[d1].graph, functools.partial) else Device[d1].graph
    # Checking that 2 graphs per device, 1 copy and 1 last graph on device 1 are created.
    assert isinstance(jf.jit_cache[0].prg, graph_d0)
    assert isinstance(jf.jit_cache[1].prg, graph_d0)
    assert isinstance(jf.jit_cache[2].prg, graph_d1)
    assert isinstance(jf.jit_cache[3].prg, graph_d1)
    assert isinstance(jf.jit_cache[4].prg, BufferCopy)
    assert isinstance(jf.jit_cache[5].prg, graph_d1)

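  # Uneven shard: 257 elements on the sharded axis generally don't split evenly across
  # the 1 to 5 devices tried, so shards can have different sizes; views
  # (reshape/shrink/expand/permute) still have to match the single-device numpy reference.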
  def test_uneven_shard(self):
    for N in range(1, 6):
      X = Tensor.rand(4, 1, 257).contiguous().realize()
      n = X.numpy()
      devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
      X.shard_(devices, 2)
      np.testing.assert_equal(X.numpy(), n)
      np.testing.assert_equal(X.reshape(2, 2, 257).numpy(), n.reshape((2, 2, 257)))
      np.testing.assert_equal(X.shrink(((0,2), (0, 1), (0,257))).numpy(), n[0:2, 0:1, 0:257])
      np.testing.assert_equal(X.expand((4, 4, 257)).numpy(), np.tile(n, (1, 4, 1)))
      np.testing.assert_equal(X.permute((0, 2, 1)).numpy(), np.transpose(n, (0, 2, 1)))

  def test_bn_ast_on_devices(self):
    devices = (d0, d1, d2, d3)
    t = Tensor.empty((16, 64, 112, 112)).shard(devices, axis=0)
    bn = nn.BatchNorm2d(64)
    for p in get_parameters(bn): p.shard_(devices).realize()

    out = bn(t)
    scheds = [sched for sched in create_schedule(out.lazydata.lbs) if sched.outputs[0].device in devices and sched.ast[0].op is not LoadOps.COPY]
    assert set(out.device for sched in scheds for out in sched.outputs) == set(devices), "should have ast on each shard device"
    asts = [sched.ast for sched in scheds]
    assert len(asts) == 8, len(asts)
    # test case to show that ast can be different on devices
    # TODO: make ast identical on devices
    #assert len(set(asts)) == 4, len(asts)
    # for i, ast in enumerate(asts):
    #   print(f"{i} {ast}")

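  # Reshapes on a sharded tensor must keep the sharded axis recoverable: splitting it
  # evenly or merging it with axes to its right is allowed, but folding it into the axis
  # on its left ("no left join" below) or splitting it unevenly across the devices
  # (test_reshape_on_axis_uneven) is rejected.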
  def test_reshape_on_axis(self):
    devices = (d0, d1, d2)

    t0 = Tensor.rand((26, 15, 7)).shard(devices, axis=1)

    # test split and rejoin to the right
    t1 = t0.reshape((26, 3, 5, 7))
    t2 = t0.reshape((26, 3, 35))
    t3 = t1.reshape((26, 15, 7))
    t4 = t2.reshape((26, 105,))

    for t in [t0, t1, t2, t3, t4]:
      assert t.lazydata.axis == 1
      np.testing.assert_allclose(t.numpy().flatten(), t0.numpy().flatten())

    # test shape-one axis
    t5 = t4.reshape((26, 1, 105))
    assert t5.lazydata.axis == 2
    np.testing.assert_allclose(t.numpy().flatten(), t5.numpy().flatten())

    # test split and rejoin to the right and reshape to the left
    t5 = t0.reshape((2, 13, 3, 5, 7))
    t6 = t0.reshape((13, 2, 3, 7, 5))
    t7 = t0.reshape((1, 13, 2, 3, 1, 7, 5))
    np.testing.assert_allclose(t5.numpy().flatten(), t0.numpy().flatten())
    assert t5.lazydata.axis == 2
    np.testing.assert_allclose(t6.numpy().flatten(), t0.numpy().flatten())
    assert t6.lazydata.axis == 2
    np.testing.assert_allclose(t7.numpy().flatten(), t0.numpy().flatten())
    assert t7.lazydata.axis == 3

    # test no left join
    with self.assertRaises((AssertionError, ValueError)):
      t0.reshape((26*15,7))

  def test_reshape_on_axis_uneven(self):
    devices = (d0, d1, d2)
    t0 = Tensor.rand((4, 8, 15)).shard(devices, axis=1)

    # no split axis if uneven
    with self.assertRaises((AssertionError, ValueError)):
      t0.reshape((4,4,2,15))

    # ok to split reshape left and right though
    t1 = t0.reshape(2, 2, 8, 3, 5)
    np.testing.assert_allclose(t0.numpy().flatten(), t1.numpy().flatten())
    assert t1.lazydata.axis == 2

  def test_mlb_assign_change_axis(self):
    devices = (d0, d1)

    t_none = Tensor.zeros((16, 16)).shard(devices).contiguous().realize()
    t_zero = Tensor.ones((16, 16)).shard(devices, axis=0)
    with self.assertRaises(AssertionError):
      # don't allow assigns that change axes
      t_none.assign(t_zero)

  def test_dropout_on_shard(self):
    Tensor.training = True
    X = Tensor.ones(256).to(devices_2)
    output = X.dropout(0.5)
    output.numpy()
    Tensor.training = False

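# Shrinking a multitensor along its sharded axis selects a subset of the shards: shrinks
# must land on device boundaries, the surviving shards are tracked in lazydata.real, and a
# shrunk piece has to be padded back out to the whole axis before it can be combined with
# another piece (see test_add_two_partitions).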
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
  # shrink a multitensor on sharded axis
  def test_shrink_bad_args(self):
    t = Tensor.arange(64).reshape(8, 8).contiguous().realize()
    t.shard_([f"{Device.DEFAULT}:{i}" for i in range(4)], axis=0)

    with self.assertRaises(AssertionError):
      # sharded axis shrink on non-device boundary is not allowed
      a = t.shrink(((0, 3), (0, 8)))
    with self.assertRaises(AssertionError):
      # cannot shrink sharded and non-sharded axis at the same time
      a = t.shrink(((0, 2), (2, 4)))

    a = t.shrink(((0, 2), (0, 8)))
    assert a.shape == (2, 8)
    assert a.lazydata.real == [True, False, False, False]

    with self.assertRaises(AssertionError):
      # cannot pad sharded and non-sharded axis at the same time
      p = a.pad(((0, 6), (0, 1)))

    with self.assertRaises(AssertionError):
      # can only pad to whole axis
      p = a.pad(((1, 5), (0, 0)))

    p = a.pad(((0, 6), (0, 0)))
    assert p.shape == (8, 8)
    assert p.lazydata.real == [True, True, True, True]

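  # Each shrunk piece should behave exactly like the equivalent small single-device
  # tensor: casts, elementwise ops, reductions, and movement ops are all compared against
  # a numpy copy of the same rows.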
  def test_ops(self):
    t = Tensor.arange(64).reshape(8, 8).contiguous().realize()
    t.shard_([f"{Device.DEFAULT}:{i}" for i in range(4)], axis=0)
    for i in range(4):
      print(f"{i=}")
      a = t.shrink(((0+2*i,2+2*i),None))
      b = Tensor(t.numpy()[0+2*i:2+2*i])
      assert a.shape == b.shape == (2, 8)
      assert a.lazydata.real == [i==j for j in range(4)]
      np.testing.assert_allclose(a.numpy(), b.numpy())
      # cast
      np.testing.assert_allclose(a.float().numpy(), b.float().numpy())

      # elementwise
      np.testing.assert_allclose(a.exp().numpy(), b.exp().numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.reciprocal().numpy(), b.reciprocal().numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.pow(-0.5).numpy(), b.pow(-0.5).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose((a+a).numpy(), (b+b).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_equal((a+1).numpy(), (b+1).numpy())
      np.testing.assert_equal((1+a).numpy(), (1+b).numpy())
      np.testing.assert_allclose((a.where(a+a, a)).numpy(), (b.where(b+b, b)).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose((a.where(1, 0)).numpy(), (b.where(1, 0)).numpy(), rtol=1e-7, atol=1e-3)

      # reduce
      np.testing.assert_allclose(a.max().numpy(), b.max().numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.sum().numpy(), b.sum().numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.mean().numpy(), b.mean().numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.max(0).numpy(), b.max(0).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.sum(0).numpy(), b.sum(0).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.mean(0).numpy(), b.mean(0).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.max(1).numpy(), b.max(1).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.sum(1).numpy(), b.sum(1).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.mean(1).numpy(), b.mean(1).numpy(), rtol=1e-7, atol=1e-3)

      # pad it back
      np.testing.assert_allclose(a.pad(((2*i, 2*(4-i-1)), None)).numpy(), b.pad(((2*i, 2*(4-i-1)), None)).numpy(), rtol=1e-7, atol=1e-3)

      # other movement
      np.testing.assert_allclose(a.pad((None, (1, 1))).numpy(), b.pad((None, (1, 1))).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.shrink((None, (1, 3))).numpy(), b.shrink((None, (1, 3))).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.permute((1, 0)).numpy(), b.permute((1, 0)).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.reshape((2, 2, 4)).numpy(), b.reshape((2, 2, 4)).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), b.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.flip(-1).numpy(), b.flip(-1).numpy(), rtol=1e-7, atol=1e-3)

  def test_uneven(self):
    t = Tensor.arange(24).reshape(3, 8).contiguous().realize()
    t.shard_([f"{Device.DEFAULT}:{i}" for i in range(2)], axis=0)

    a = t.shrink(((0, 2), None))
    b = t.shrink(((2, 3), None))
    na = t.numpy()[0:2]
    nb = t.numpy()[2:3]
    np.testing.assert_equal(a.numpy(), na)
    np.testing.assert_equal(b.numpy(), nb)
    np.testing.assert_equal((a+1).numpy(), na+1)
    np.testing.assert_equal((b+1).numpy(), nb+1)
    np.testing.assert_equal((1+a).numpy(), 1+na)
    np.testing.assert_equal((1+b).numpy(), 1+nb)
    np.testing.assert_equal((a+a).numpy(), na+na)
    np.testing.assert_equal((b+b).numpy(), nb+nb)

  def test_add_two_partitions(self):
    t = Tensor.arange(64).reshape(8, 8).contiguous().realize()
    t.shard_([f"{Device.DEFAULT}:{i}" for i in range(4)], axis=0)

    a = t.shrink(((2, 4), None))
    b = t.shrink(((6, 8), None))
    na = t.numpy()[2:4]
    nb = t.numpy()[6:8]
    np.testing.assert_equal(a.numpy(), na)
    np.testing.assert_equal(b.numpy(), nb)
    with self.assertRaises(AssertionError):
      # cannot add directly
      c = a + b

    c = a.pad(((2, 4), None)) + b.pad(((6, 0), None))
    expected = np.concatenate([np.zeros_like(t.numpy()[0:2]), na, np.zeros_like(t.numpy()[4:6]), nb])
    np.testing.assert_equal(c.numpy(), expected)

  def test_add_different_tensors(self):
    devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
    x = Tensor.arange(64).reshape(8, 8).contiguous().realize().shard(devices, axis=0)

    to_add = []
    for i in range(len(devices)):
      to_add.append((Tensor.ones(2, 8) * i).shard(devices))

    added:List[Tensor] = []
    for bound, a in zip(x.lazydata.bounds, to_add):
      added.append(x[bound[0]:bound[1]] + a)

    output = added[0].cat(*added[1:])
    expected = np.arange(64).reshape((8,8)) + np.array([[0,0,1,1,2,2,3,3] for _ in range(8)]).T
    np.testing.assert_allclose(output.numpy(), expected)

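  # Unsynced batchnorm: each device gets its own BatchNorm2d (or its own slice of the
  # running stats in UnsyncedBatchNorm), so mean/var are computed per shard instead of
  # being reduced across devices; the trainable weight/bias can still be shared.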
  def test_unsynced_backprop_conv_bn(self):
    from extra.lr_scheduler import OneCycleLR

    convs = [nn.Conv2d(3, 16, 3), nn.Conv2d(3, 16, 3)]
    bns = [nn.BatchNorm2d(16), nn.BatchNorm2d(16)]

    for p in get_parameters(convs + bns):
      p.shard_((d1, d2))
    optim = nn.optim.Adam(get_parameters(convs + bns))
    lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
    lr_sched.step()

    fake_image = Tensor.rand((8, 3, 32, 32)).shard((d1, d2), axis=0)

    f1 = fake_image.shrink(((0, 4), None, None, None))
    f2 = fake_image.shrink(((4, 8), None, None, None))

    out1 = bns[0](convs[0](f1))
    out2 = bns[1](convs[1](f2))
    out = out1.cat(out2)
    optim.zero_grad()
    out.mean().backward()
    optim.step()
    out.numpy()

  def test_unsynced_backprop_standalone_bn(self):
    from extra.lr_scheduler import OneCycleLR
    GPUS = (d1, d2)

    class BatchNorm:
      def __init__(self, num_features):
        self.bns:List[nn.BatchNorm2d] = []
        for _ in GPUS:
          bn = nn.BatchNorm2d(num_features, track_running_stats=False, eps=1e-12, momentum=0.85, affine=True)
          self.bns.append(bn)

      def __call__(self, x:Tensor):
        bn_ts = []
        for bound, bn in zip(x.lazydata.bounds, self.bns):
          xi = x.shrink((bound, None, None, None))
          bni = bn(xi)
          bn_ts.append(bni)
        return bn_ts[0].cat(*bn_ts[1:])

    with Tensor.train():
      conv = nn.Conv2d(3, 16, 3)
      bn = BatchNorm(16)

      for p in get_parameters([conv, bn]):
        p.shard_(GPUS)
      optim = nn.optim.Adam(get_parameters([conv, bn]))
      lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
      lr_sched.step()

      fake_image = Tensor.rand((8, 3, 32, 32)).shard(GPUS, axis=0)

      out = bn(conv(fake_image))
      optim.zero_grad()
      out.mean().backward()
      optim.step()

  def test_unsynced_backprop_sync_weights(self):
    from extra.lr_scheduler import OneCycleLR
    from examples.hlb_cifar10 import UnsyncedBatchNorm
    GPUS = (d1, d2)

    with Tensor.train():
      conv = nn.Conv2d(3, 16, 3)
      bn = UnsyncedBatchNorm(16, num_devices=len(GPUS))

      for k, p in get_state_dict([conv, bn]).items():
        if 'running_mean' in k or 'running_var' in k:
          p.shard_(GPUS, axis=0)
        else:
          p.to_(GPUS)
      optim = nn.optim.Adam(get_parameters([conv, bn]))
      lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
      lr_sched.step()

      fake_image = Tensor.rand((8, 3, 32, 32)).shard(GPUS, axis=0)

      out = bn(conv(fake_image))
      optim.zero_grad()
      out.mean().backward()
      optim.step()

  @given(strat.sampled_from((False, True)))
  def test_batchnorm(self, is_training):
    devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
    x = Tensor.arange(4096).reshape(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0)

    with Tensor.train(is_training):
      bns = []
      for _ in range(len(devices)):
        bn = nn.BatchNorm2d(8)
        for p in get_parameters(bn):
          p.shard_(devices)
        bn.weight.requires_grad = True
        bn.bias.requires_grad = True
        bns.append(bn)

      bn_ts = []
      for bound, bn in zip(x.lazydata.bounds, bns):
        bni = bn(x[bound[0]:bound[1]])
        bn_ts.append(bni)

      bn_ts[0].cat(*bn_ts[1:]).numpy()

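  # A synced BatchNorm2d over a sharded batch reduces its statistics across all devices,
  # while UnsyncedBatchNorm keeps one set of stats per device; this test currently only
  # checks that both versions produce a non-empty schedule (comparing the cross-device
  # kernels and copies is still the TODO below).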
  def test_synced_vs_unsynced_bn(self):
    from examples.hlb_cifar10 import UnsyncedBatchNorm
    from tinygrad.nn import BatchNorm2d
    devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
    x = Tensor.ones(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0)

    with Tensor.train():
      synced_bn = BatchNorm2d(8)
      unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices))

      for p in get_parameters(synced_bn):
        p.shard_(devices)
      for k, p in get_state_dict(unsynced_bn).items():
        if 'running_mean' in k or 'running_var' in k:
          p.shard_(devices, axis=0)
        else:
          p.to_(devices)

      synced_out = synced_bn(x)
      synced_si = [si for si in create_schedule(synced_out.lazydata.lbs)]
      unsynced_out = unsynced_bn(x)
      unsynced_si = [si for si in create_schedule(unsynced_out.lazydata.lbs)]

    # TODO: test synced / unsynced batchnorm cross device kernel and copies
    assert synced_si
    assert unsynced_si

if __name__ == '__main__':
  unittest.main()