[bounty] Don't use numpy inside hlb_cifar10 training loop (#10777)

* Don't use numpy inside hlb_cifar10 training loop

* Lint it

* jit it

* Drop the last half-batch

* Use gather for random_crop and reuse perms

* Wrap train_cifar in FUSE_ARANGE context

* No need to pass FUSE_ARANGE=1 to hlb_cifar10.py

* Add cutmix to jittable augmentations

* Remove .contiguous() from fetch_batches

* Fix indexing boundary

---------

Co-authored-by: Irwin1138 <irwin1139@gmail.com>
Authored by Alexey Zaytsev, 2025-06-23 21:24:56 -03:00; committed by GitHub
parent 383010555f
commit 230ad3a460
2 changed files with 35 additions and 33 deletions
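The common thread in the change list above: every per-step random operation (shuffle, crop, flip, cutmix) is now expressed as Tensor ops, so the whole augmentation pass can be captured by TinyJit instead of round-tripping through numpy on every step. A minimal sketch of that pattern (illustrative only, not code from this commit):

from tinygrad import Tensor, TinyJit

@TinyJit
def shuffle_batch(X:Tensor, Y:Tensor):
  # Tensor.randperm keeps the permutation on-device, so the JIT can capture it;
  # a numpy-based shuffle here would force a host round-trip every step
  perms = Tensor.randperm(X.shape[0], device=X.device)
  return X[perms].realize(), Y[perms].realize()

X, Y = Tensor.rand(512, 3, 32, 32), Tensor.rand(512, 10)
for _ in range(3):  # TinyJit runs normally, then captures, then replays the captured kernels
  Xs, Ys = shuffle_batch(X, Y)

Tensor indexing like X[perms] lowers to arange-based kernels, which is presumably why the script now enables FUSE_ARANGE by default via the @Context decorator in the second file.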


@@ -27,7 +27,7 @@ jobs:
         BENCHMARK_LOG=search_sdxl_cached PYTHONPATH=. AMD=1 JITBEAM=2 python examples/sdxl.py --noshow --timing --seed 0
     - name: Run winograd cifar with new search
       run: |
-        BENCHMARK_LOG=search_wino_cifar WINO=1 DEFAULT_FLOAT=HALF FUSE_ARANGE=1 JITBEAM=4 IGNORE_BEAM_CACHE=1 DISABLE_COMPILER_CACHE=1 BS=1024 STEPS=500 python examples/hlb_cifar10.py
+        BENCHMARK_LOG=search_wino_cifar WINO=1 DEFAULT_FLOAT=HALF JITBEAM=4 IGNORE_BEAM_CACHE=1 DISABLE_COMPILER_CACHE=1 BS=1024 STEPS=500 python examples/hlb_cifar10.py
     - name: Run winograd cifar with cached search
       run: |
-        BENCHMARK_LOG=search_wino_cifar_cached WINO=1 DEFAULT_FLOAT=HALF FUSE_ARANGE=1 JITBEAM=4 BS=1024 STEPS=500 python examples/hlb_cifar10.py
+        BENCHMARK_LOG=search_wino_cifar_cached WINO=1 DEFAULT_FLOAT=HALF JITBEAM=4 BS=1024 STEPS=500 python examples/hlb_cifar10.py
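FUSE_ARANGE=1 disappears from both benchmark commands because hlb_cifar10.py now enables it by default through the @Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)) decorator added in the second file; passing the env var is only needed to override that default.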


@@ -7,7 +7,7 @@ import random, time
 import numpy as np
 from typing import Optional
 from extra.lr_scheduler import OneCycleLR
-from tinygrad import nn, dtypes, Tensor, Device, GlobalCounters, TinyJit
+from tinygrad import nn, dtypes, Tensor, Device, GlobalCounters, TinyJit, Variable
 from tinygrad.nn.state import get_state_dict, get_parameters
 from tinygrad.nn import optim
 from tinygrad.helpers import Context, BEAM, WINO, getenv, colored, prod
@@ -145,6 +145,7 @@ hyp = {
   },
 }
+@Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1))
 def train_cifar():
   def set_seed(seed):
@@ -201,24 +202,37 @@ def train_cifar():
     idx_y = Tensor.arange(H, dtype=dtypes.int32).reshape((1,1,H,1))
     return (idx_x >= low_x) * (idx_x < (low_x + mask_size)) * (idx_y >= low_y) * (idx_y < (low_y + mask_size))
-  def random_crop(X:Tensor, crop_size=32):
-    mask = make_square_mask(X.shape, crop_size)
-    mask = mask.expand((-1,3,-1,-1))
-    X_cropped = Tensor(X.numpy()[mask.numpy()])
-    return X_cropped.reshape((-1, 3, crop_size, crop_size))
+  # Similar, but different enough.
+  def make_random_crop_indices(shape, mask_size) -> Tensor:
+    BS, _, H, W = shape
+    low_x = Tensor.randint(BS, low=0, high=W-mask_size).reshape(BS,1,1,1)
+    low_y = Tensor.randint(BS, low=0, high=H-mask_size).reshape(BS,1,1,1)
+    idx_x = Tensor.arange(mask_size, dtype=dtypes.int32).reshape((1,1,1,mask_size))
+    idx_y = Tensor.arange(mask_size, dtype=dtypes.int32).reshape((1,1,mask_size,1))
+    return low_x, low_y, idx_x, idx_y
-  def cutmix(X:Tensor, Y:Tensor, mask_size=3):
-    # fill the square with randomly selected images from the same batch
+  def random_crop(X:Tensor, crop_size=32):
+    Xs, Ys, Xi, Yi = make_random_crop_indices(X.shape, crop_size)
+    return X.gather(-1, (Xs + Xi).expand(-1, 3, X.shape[2], -1)).gather(-2, ((Ys+Yi).expand(-1, 3, crop_size, crop_size)))
+  def cutmix(X, Y, order, mask_size=3):
     mask = make_square_mask(X.shape, mask_size)
-    order = list(range(0, X.shape[0]))
-    random.shuffle(order)
-    X_patch = Tensor(X.numpy()[order], device=X.device, dtype=X.dtype)
-    Y_patch = Tensor(Y.numpy()[order], device=Y.device, dtype=Y.dtype)
+    X_patch, Y_patch = X[order], Y[order]
     X_cutmix = mask.where(X_patch, X)
     mix_portion = float(mask_size**2)/(X.shape[-2]*X.shape[-1])
     Y_cutmix = mix_portion * Y_patch + (1. - mix_portion) * Y
     return X_cutmix, Y_cutmix
+  @TinyJit
+  def augmentations(X:Tensor, Y:Tensor):
+    perms = Tensor.randperm(X.shape[0], device=X.device) # We reuse perms for cutmix, because they are expensive to generate
+    if getenv("RANDOM_CROP", 1):
+      X = random_crop(X, crop_size=32)
+    if getenv("RANDOM_FLIP", 1):
+      X = (Tensor.rand(X.shape[0],1,1,1) < 0.5).where(X.flip(-1), X) # flip LR
+    X, Y = X[perms], Y[perms]
+    return X, Y, *cutmix(X, Y, perms, mask_size=hyp['net']['cutmix_size'])
   # the operations that remain inside the batch fetcher are the ones that involve random operations
   def fetch_batches(X_in:Tensor, Y_in:Tensor, BS:int, is_train:bool):
     step, epoch = 0, 0
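The new random_crop above swaps the boolean-mask-through-numpy approach for two gather calls: per-image random offsets plus an arange give each image its own column and row indices, and gathering along the last two axes cuts out the crop without leaving the device. A rough standalone illustration of the same idea, with made-up toy shapes (not part of the diff):

from tinygrad import Tensor, dtypes

BS, C, H, W, crop = 4, 3, 8, 8, 5
X = Tensor.rand(BS, C, H, W)
low_x = Tensor.randint(BS, low=0, high=W-crop).reshape(BS,1,1,1)  # per-image column offset
low_y = Tensor.randint(BS, low=0, high=H-crop).reshape(BS,1,1,1)  # per-image row offset
cols = low_x + Tensor.arange(crop, dtype=dtypes.int32).reshape(1,1,1,crop)  # (BS,1,1,crop)
rows = low_y + Tensor.arange(crop, dtype=dtypes.int32).reshape(1,1,crop,1)  # (BS,1,crop,1)
# gather columns along the last axis first, then rows along axis -2; expand broadcasts the indices
out = X.gather(-1, cols.expand(BS, C, H, crop)).gather(-2, rows.expand(BS, C, crop, crop))
print(out.shape)  # (4, 3, 5, 5): a different crop window per image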
@@ -226,28 +240,16 @@ def train_cifar():
       st = time.monotonic()
       X, Y = X_in, Y_in
       if is_train:
-        # TODO: these are not jitted
-        if getenv("RANDOM_CROP", 1):
-          X = random_crop(X, crop_size=32)
-        if getenv("RANDOM_FLIP", 1):
-          X = (Tensor.rand(X.shape[0],1,1,1) < 0.5).where(X.flip(-1), X) # flip LR
-        if getenv("CUTMIX", 1):
-          if step >= hyp['net']['cutmix_steps']:
-            X, Y = cutmix(X, Y, mask_size=hyp['net']['cutmix_size'])
-        order = list(range(0, X.shape[0]))
-        random.shuffle(order)
-        X, Y = X.numpy()[order], Y.numpy()[order]
-      else:
-        X, Y = X.numpy(), Y.numpy()
+        X, Y, X_cm, Y_cm = augmentations(X, Y)
+        if getenv("CUTMIX", 1) and step >= hyp['net']['cutmix_steps']: X, Y = X_cm, Y_cm
       et = time.monotonic()
       print(f"shuffling {'training' if is_train else 'test'} dataset in {(et-st)*1e3:.2f} ms ({epoch=})")
-      for i in range(0, X.shape[0], BS):
-        # pad the last batch # TODO: not correct for test
-        batch_end = min(i+BS, Y.shape[0])
-        x = Tensor(X[batch_end-BS:batch_end], device=X_in.device, dtype=X_in.dtype)
-        y = Tensor(Y[batch_end-BS:batch_end], device=Y_in.device, dtype=Y_in.dtype)
+      vi = Variable("i", 0, (full_batches := (X.shape[0] // BS) * BS) - BS)
+      for i in range(0, full_batches, BS):
         step += 1
-        yield x, y
+        vib = vi.bind(i)
+        yield X[vib:vib+BS], Y[vib:vib+BS]
       epoch += 1
       if not is_train: break
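The last hunk is what keeps fetch_batches JIT-friendly: the dataset is sliced with a bound Variable, so every batch uses the same symbolic kernel and only the binding (the offset) changes, and only full batches are yielded, which is why the trailing half-batch is dropped. A small sketch of that slicing pattern, with made-up sizes:

from tinygrad import Tensor, Variable

N, BS = 4096, 1024
X = Tensor.rand(N, 3, 32, 32).realize()     # stand-in for the realized dataset
full_batches = (N // BS) * BS               # drop the trailing half-batch so every slice has length BS
vi = Variable("i", 0, full_batches - BS)    # symbolic batch offset and its valid range
for i in range(0, full_batches, BS):
  vib = vi.bind(i)                          # bind the symbolic offset to this step's value
  xb = X[vib:vib+BS]                        # static shape (BS, 3, 32, 32); only the start offset is symbolic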