mirror of https://github.com/tinygrad/tinygrad.git
[bounty] Don't use numpy inside hlb_cifar10 training loop (#10777)
* Don't use numpy inside hlb_cifar10 training loop
* Lint it
* jit it
* Drop the last half-batch
* Use gather for random_crop and reuse perms
* Wrap train_cifar in FUSE_ARANGE context
* No need to pass FUSE_ARANGE=1 to hlb_cifar10.py
* Add cutmix to jittable augmentations
* Remove .contiguous() from fetch_batches
* Fix indexing boundary

Co-authored-by: Irwin1138 <irwin1139@gmail.com>
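The point of the bounty is that augmentation and batching stay as on-device Tensor ops that TinyJit can capture; a .numpy() call inside the loop forces a copy back to the host and cannot be part of the captured graph. A minimal sketch of that constraint, with an illustrative function that is not part of this commit:

from tinygrad import Tensor, TinyJit

@TinyJit
def double_plus_one(x: Tensor) -> Tensor:
  # only on-device Tensor ops here: a .numpy() call would leave the device
  # and could not be replayed by the JIT
  return (x * 2 + 1).realize()

for _ in range(4):
  out = double_plus_one(Tensor.rand(8, 8).realize())  # later calls replay the captured kernels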
.github/workflows/benchmark_search.yml
@@ -27,7 +27,7 @@ jobs:
         BENCHMARK_LOG=search_sdxl_cached PYTHONPATH=. AMD=1 JITBEAM=2 python examples/sdxl.py --noshow --timing --seed 0
     - name: Run winograd cifar with new search
       run: |
-        BENCHMARK_LOG=search_wino_cifar WINO=1 DEFAULT_FLOAT=HALF FUSE_ARANGE=1 JITBEAM=4 IGNORE_BEAM_CACHE=1 DISABLE_COMPILER_CACHE=1 BS=1024 STEPS=500 python examples/hlb_cifar10.py
+        BENCHMARK_LOG=search_wino_cifar WINO=1 DEFAULT_FLOAT=HALF JITBEAM=4 IGNORE_BEAM_CACHE=1 DISABLE_COMPILER_CACHE=1 BS=1024 STEPS=500 python examples/hlb_cifar10.py
     - name: Run winograd cifar with cached search
       run: |
-        BENCHMARK_LOG=search_wino_cifar_cached WINO=1 DEFAULT_FLOAT=HALF FUSE_ARANGE=1 JITBEAM=4 BS=1024 STEPS=500 python examples/hlb_cifar10.py
+        BENCHMARK_LOG=search_wino_cifar_cached WINO=1 DEFAULT_FLOAT=HALF JITBEAM=4 BS=1024 STEPS=500 python examples/hlb_cifar10.py
examples/hlb_cifar10.py
@@ -7,7 +7,7 @@ import random, time
 import numpy as np
 from typing import Optional
 from extra.lr_scheduler import OneCycleLR
-from tinygrad import nn, dtypes, Tensor, Device, GlobalCounters, TinyJit
+from tinygrad import nn, dtypes, Tensor, Device, GlobalCounters, TinyJit, Variable
 from tinygrad.nn.state import get_state_dict, get_parameters
 from tinygrad.nn import optim
 from tinygrad.helpers import Context, BEAM, WINO, getenv, colored, prod
@@ -145,6 +145,7 @@ hyp = {
   },
 }
 
+@Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1))
 def train_cifar():
 
   def set_seed(seed):
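With train_cifar decorated by Context(FUSE_ARANGE=...), the flag is scoped to the function itself, which is why the workflow above no longer exports FUSE_ARANGE=1. A rough sketch of how the Context helper behaves as a decorator and context manager, using the DEBUG ContextVar purely as an illustration:

from tinygrad.helpers import Context, DEBUG

@Context(DEBUG=2)
def noisy():
  print(DEBUG.value)  # 2 while this function runs

with Context(DEBUG=0):
  noisy()             # the decorator overrides the enclosing value for the call
print(DEBUG.value)    # restored to the previous value afterwards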
@@ -201,24 +202,37 @@ def train_cifar():
     idx_y = Tensor.arange(H, dtype=dtypes.int32).reshape((1,1,H,1))
     return (idx_x >= low_x) * (idx_x < (low_x + mask_size)) * (idx_y >= low_y) * (idx_y < (low_y + mask_size))
 
-  def random_crop(X:Tensor, crop_size=32):
-    mask = make_square_mask(X.shape, crop_size)
-    mask = mask.expand((-1,3,-1,-1))
-    X_cropped = Tensor(X.numpy()[mask.numpy()])
-    return X_cropped.reshape((-1, 3, crop_size, crop_size))
+  # Similar, but different enough.
+  def make_random_crop_indices(shape, mask_size) -> Tensor:
+    BS, _, H, W = shape
+    low_x = Tensor.randint(BS, low=0, high=W-mask_size).reshape(BS,1,1,1)
+    low_y = Tensor.randint(BS, low=0, high=H-mask_size).reshape(BS,1,1,1)
+    idx_x = Tensor.arange(mask_size, dtype=dtypes.int32).reshape((1,1,1,mask_size))
+    idx_y = Tensor.arange(mask_size, dtype=dtypes.int32).reshape((1,1,mask_size,1))
+    return low_x, low_y, idx_x, idx_y
 
-  def cutmix(X:Tensor, Y:Tensor, mask_size=3):
-    # fill the square with randomly selected images from the same batch
+  def random_crop(X:Tensor, crop_size=32):
+    Xs, Ys, Xi, Yi = make_random_crop_indices(X.shape, crop_size)
+    return X.gather(-1, (Xs + Xi).expand(-1, 3, X.shape[2], -1)).gather(-2, ((Ys+Yi).expand(-1, 3, crop_size, crop_size)))
+
+  def cutmix(X, Y, order, mask_size=3):
     mask = make_square_mask(X.shape, mask_size)
-    order = list(range(0, X.shape[0]))
-    random.shuffle(order)
-    X_patch = Tensor(X.numpy()[order], device=X.device, dtype=X.dtype)
-    Y_patch = Tensor(Y.numpy()[order], device=Y.device, dtype=Y.dtype)
+    X_patch, Y_patch = X[order], Y[order]
     X_cutmix = mask.where(X_patch, X)
     mix_portion = float(mask_size**2)/(X.shape[-2]*X.shape[-1])
     Y_cutmix = mix_portion * Y_patch + (1. - mix_portion) * Y
     return X_cutmix, Y_cutmix
 
+  @TinyJit
+  def augmentations(X:Tensor, Y:Tensor):
+    perms = Tensor.randperm(X.shape[0], device=X.device) # We reuse perms for cutmix, because they are expensivne to generate
+    if getenv("RANDOM_CROP", 1):
+      X = random_crop(X, crop_size=32)
+    if getenv("RANDOM_FLIP", 1):
+      X = (Tensor.rand(X.shape[0],1,1,1) < 0.5).where(X.flip(-1), X) # flip LR
+    X, Y = X[perms], Y[perms]
+    return X, Y, *cutmix(X, Y, perms, mask_size=hyp['net']['cutmix_size'])
+
   # the operations that remain inside batch fetcher is the ones that involves random operations
   def fetch_batches(X_in:Tensor, Y_in:Tensor, BS:int, is_train:bool):
     step, epoch = 0, 0
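As a standalone illustration of the gather-based crop added above, the sketch below builds per-image start offsets and small index grids, then gathers along the width and height axes; the batch size and the 34->32 crop are made up, but the ops are the ones used in the hunk:

from tinygrad import Tensor, dtypes

BS, C, H, W, crop = 8, 3, 34, 34, 32
X = Tensor.rand(BS, C, H, W)

# one random top-left corner per image, broadcast against the index grids
low_x = Tensor.randint(BS, low=0, high=W-crop).reshape(BS, 1, 1, 1)
low_y = Tensor.randint(BS, low=0, high=H-crop).reshape(BS, 1, 1, 1)
idx_x = Tensor.arange(crop, dtype=dtypes.int32).reshape(1, 1, 1, crop)
idx_y = Tensor.arange(crop, dtype=dtypes.int32).reshape(1, 1, crop, 1)

# gather columns (last axis) then rows: the data never leaves the device,
# unlike the old boolean-mask + numpy fancy-indexing version
cols = X.gather(-1, (low_x + idx_x).expand(BS, C, H, crop))
out = cols.gather(-2, (low_y + idx_y).expand(BS, C, crop, crop))
print(out.shape)  # (8, 3, 32, 32)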
@@ -226,28 +240,16 @@ def train_cifar():
       st = time.monotonic()
       X, Y = X_in, Y_in
       if is_train:
-        # TODO: these are not jitted
-        if getenv("RANDOM_CROP", 1):
-          X = random_crop(X, crop_size=32)
-        if getenv("RANDOM_FLIP", 1):
-          X = (Tensor.rand(X.shape[0],1,1,1) < 0.5).where(X.flip(-1), X) # flip LR
-        if getenv("CUTMIX", 1):
-          if step >= hyp['net']['cutmix_steps']:
-            X, Y = cutmix(X, Y, mask_size=hyp['net']['cutmix_size'])
-        order = list(range(0, X.shape[0]))
-        random.shuffle(order)
-        X, Y = X.numpy()[order], Y.numpy()[order]
-      else:
-        X, Y = X.numpy(), Y.numpy()
+        X, Y, X_cm, Y_cm = augmentations(X, Y)
+        if getenv("CUTMIX", 1) and step >= hyp['net']['cutmix_steps']: X, Y = X_cm, Y_cm
       et = time.monotonic()
       print(f"shuffling {'training' if is_train else 'test'} dataset in {(et-st)*1e3:.2f} ms ({epoch=})")
-      for i in range(0, X.shape[0], BS):
-        # pad the last batch # TODO: not correct for test
-        batch_end = min(i+BS, Y.shape[0])
-        x = Tensor(X[batch_end-BS:batch_end], device=X_in.device, dtype=X_in.dtype)
-        y = Tensor(Y[batch_end-BS:batch_end], device=Y_in.device, dtype=Y_in.dtype)
+
+      vi = Variable("i", 0, (full_batches := (X.shape[0] // BS) * BS) - BS)
+      for i in range(0, full_batches, BS):
         step += 1
-        yield x, y
+        vib = vi.bind(i)
+        yield X[vib:vib+BS], Y[vib:vib+BS]
       epoch += 1
       if not is_train: break
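The Variable bound in fetch_batches is what keeps batch slicing JIT-friendly: every slice has the same static shape and only the bound start offset changes between calls, so a downstream jitted step can replay the same kernels. A small sketch of the same pattern, with an illustrative jitted consumer (sizes and the batch_sum function are not from this commit):

from tinygrad import Tensor, TinyJit, Variable

X = Tensor.rand(1024, 4).realize()
BS = 256

@TinyJit
def batch_sum(xb: Tensor) -> Tensor:
  return xb.sum(axis=0).realize()

# one symbolic start index covering every full batch, bound per iteration
vi = Variable("i", 0, X.shape[0] - BS)
for i in range(0, X.shape[0], BS):
  vib = vi.bind(i)
  out = batch_sum(X[vib:vib+BS])  # same captured kernels, different bound value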