From 230ad3a46061983df996100edeb50930f532bd6e Mon Sep 17 00:00:00 2001
From: Alexey Zaytsev
Date: Mon, 23 Jun 2025 21:24:56 -0300
Subject: [PATCH] [bounty] Don't use numpy inside hlb_cifar10 training loop
 (#10777)

* Don't use numpy inside hlb_cifar10 training loop
* Lint it
* jit it
* Drop the last half-batch
* Use gather for random_crop and reuse perms
* Wrap train_cifar in FUSE_ARANGE context
* No need to pass FUSE_ARANGE=1 to hlb_cifar10.py
* Add cutmix to jittable augmentations
* Remove .contiguous() from fetch_batches
* Fix indexing boundary

---------

Co-authored-by: Irwin1138
---
 .github/workflows/benchmark_search.yml |  4 +-
 examples/hlb_cifar10.py                | 64 +++++++++++++-------------
 2 files changed, 35 insertions(+), 33 deletions(-)

diff --git a/.github/workflows/benchmark_search.yml b/.github/workflows/benchmark_search.yml
index df3055ba86..91284d1808 100644
--- a/.github/workflows/benchmark_search.yml
+++ b/.github/workflows/benchmark_search.yml
@@ -27,7 +27,7 @@ jobs:
         BENCHMARK_LOG=search_sdxl_cached PYTHONPATH=. AMD=1 JITBEAM=2 python examples/sdxl.py --noshow --timing --seed 0
     - name: Run winograd cifar with new search
       run: |
-        BENCHMARK_LOG=search_wino_cifar WINO=1 DEFAULT_FLOAT=HALF FUSE_ARANGE=1 JITBEAM=4 IGNORE_BEAM_CACHE=1 DISABLE_COMPILER_CACHE=1 BS=1024 STEPS=500 python examples/hlb_cifar10.py
+        BENCHMARK_LOG=search_wino_cifar WINO=1 DEFAULT_FLOAT=HALF JITBEAM=4 IGNORE_BEAM_CACHE=1 DISABLE_COMPILER_CACHE=1 BS=1024 STEPS=500 python examples/hlb_cifar10.py
     - name: Run winograd cifar with cached search
       run: |
-        BENCHMARK_LOG=search_wino_cifar_cached WINO=1 DEFAULT_FLOAT=HALF FUSE_ARANGE=1 JITBEAM=4 BS=1024 STEPS=500 python examples/hlb_cifar10.py
+        BENCHMARK_LOG=search_wino_cifar_cached WINO=1 DEFAULT_FLOAT=HALF JITBEAM=4 BS=1024 STEPS=500 python examples/hlb_cifar10.py
diff --git a/examples/hlb_cifar10.py b/examples/hlb_cifar10.py
index ff6c48b608..378700de76 100644
--- a/examples/hlb_cifar10.py
+++ b/examples/hlb_cifar10.py
@@ -7,7 +7,7 @@ import random, time
 import numpy as np
 from typing import Optional
 from extra.lr_scheduler import OneCycleLR
-from tinygrad import nn, dtypes, Tensor, Device, GlobalCounters, TinyJit
+from tinygrad import nn, dtypes, Tensor, Device, GlobalCounters, TinyJit, Variable
 from tinygrad.nn.state import get_state_dict, get_parameters
 from tinygrad.nn import optim
 from tinygrad.helpers import Context, BEAM, WINO, getenv, colored, prod
@@ -145,6 +145,7 @@ hyp = {
   },
 }
 
+@Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1))
 def train_cifar():
 
   def set_seed(seed):
@@ -201,24 +202,37 @@ def train_cifar():
     idx_y = Tensor.arange(H, dtype=dtypes.int32).reshape((1,1,H,1))
     return (idx_x >= low_x) * (idx_x < (low_x + mask_size)) * (idx_y >= low_y) * (idx_y < (low_y + mask_size))
 
-  def random_crop(X:Tensor, crop_size=32):
-    mask = make_square_mask(X.shape, crop_size)
-    mask = mask.expand((-1,3,-1,-1))
-    X_cropped = Tensor(X.numpy()[mask.numpy()])
-    return X_cropped.reshape((-1, 3, crop_size, crop_size))
+  # Similar, but different enough.
+  def make_random_crop_indices(shape, mask_size) -> tuple[Tensor, Tensor, Tensor, Tensor]:
+    BS, _, H, W = shape
+    low_x = Tensor.randint(BS, low=0, high=W-mask_size).reshape(BS,1,1,1)
+    low_y = Tensor.randint(BS, low=0, high=H-mask_size).reshape(BS,1,1,1)
+    idx_x = Tensor.arange(mask_size, dtype=dtypes.int32).reshape((1,1,1,mask_size))
+    idx_y = Tensor.arange(mask_size, dtype=dtypes.int32).reshape((1,1,mask_size,1))
+    return low_x, low_y, idx_x, idx_y
 
-  def cutmix(X:Tensor, Y:Tensor, mask_size=3):
-    # fill the square with randomly selected images from the same batch
+  def random_crop(X:Tensor, crop_size=32):
+    Xs, Ys, Xi, Yi = make_random_crop_indices(X.shape, crop_size)
+    return X.gather(-1, (Xs + Xi).expand(-1, 3, X.shape[2], -1)).gather(-2, ((Ys+Yi).expand(-1, 3, crop_size, crop_size)))
+
+  def cutmix(X, Y, order, mask_size=3):
     mask = make_square_mask(X.shape, mask_size)
-    order = list(range(0, X.shape[0]))
-    random.shuffle(order)
-    X_patch = Tensor(X.numpy()[order], device=X.device, dtype=X.dtype)
-    Y_patch = Tensor(Y.numpy()[order], device=Y.device, dtype=Y.dtype)
+    X_patch, Y_patch = X[order], Y[order]
     X_cutmix = mask.where(X_patch, X)
     mix_portion = float(mask_size**2)/(X.shape[-2]*X.shape[-1])
     Y_cutmix = mix_portion * Y_patch + (1. - mix_portion) * Y
     return X_cutmix, Y_cutmix
 
+  @TinyJit
+  def augmentations(X:Tensor, Y:Tensor):
+    perms = Tensor.randperm(X.shape[0], device=X.device) # We reuse perms for cutmix because they are expensive to generate
+    if getenv("RANDOM_CROP", 1):
+      X = random_crop(X, crop_size=32)
+    if getenv("RANDOM_FLIP", 1):
+      X = (Tensor.rand(X.shape[0],1,1,1) < 0.5).where(X.flip(-1), X) # flip LR
+    X, Y = X[perms], Y[perms]
+    return X, Y, *cutmix(X, Y, perms, mask_size=hyp['net']['cutmix_size'])
+
   # the operations that remain inside batch fetcher is the ones that involves random operations
   def fetch_batches(X_in:Tensor, Y_in:Tensor, BS:int, is_train:bool):
     step, epoch = 0, 0
@@ -226,28 +240,16 @@ def train_cifar():
       st = time.monotonic()
       X, Y = X_in, Y_in
       if is_train:
-        # TODO: these are not jitted
-        if getenv("RANDOM_CROP", 1):
-          X = random_crop(X, crop_size=32)
-        if getenv("RANDOM_FLIP", 1):
-          X = (Tensor.rand(X.shape[0],1,1,1) < 0.5).where(X.flip(-1), X) # flip LR
-        if getenv("CUTMIX", 1):
-          if step >= hyp['net']['cutmix_steps']:
-            X, Y = cutmix(X, Y, mask_size=hyp['net']['cutmix_size'])
-        order = list(range(0, X.shape[0]))
-        random.shuffle(order)
-        X, Y = X.numpy()[order], Y.numpy()[order]
-      else:
-        X, Y = X.numpy(), Y.numpy()
+        X, Y, X_cm, Y_cm = augmentations(X, Y)
+        if getenv("CUTMIX", 1) and step >= hyp['net']['cutmix_steps']: X, Y = X_cm, Y_cm
       et = time.monotonic()
       print(f"shuffling {'training' if is_train else 'test'} dataset in {(et-st)*1e3:.2f} ms ({epoch=})")
-      for i in range(0, X.shape[0], BS):
-        # pad the last batch # TODO: not correct for test
-        batch_end = min(i+BS, Y.shape[0])
-        x = Tensor(X[batch_end-BS:batch_end], device=X_in.device, dtype=X_in.dtype)
-        y = Tensor(Y[batch_end-BS:batch_end], device=Y_in.device, dtype=Y_in.dtype)
+
+      vi = Variable("i", 0, (full_batches := (X.shape[0] // BS) * BS) - BS)
+      for i in range(0, full_batches, BS):
         step += 1
-        yield x, y
+        vib = vi.bind(i)
+        yield X[vib:vib+BS], Y[vib:vib+BS]
       epoch += 1
       if not is_train: break
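
Note: the gather-based random_crop above is dense. For reference, here is a
minimal standalone sketch of the same technique (the names X, BS, crop and the
small sizes are illustrative, not from the patch): per-image crop offsets are
generated as tensors, then two gathers slice out the window without leaving
the tinygrad graph.

    from tinygrad import Tensor, dtypes

    BS, C, H, W, crop = 4, 3, 8, 8, 5  # illustrative sizes
    X = Tensor.rand(BS, C, H, W)

    # per-image top-left corner of the crop window, kept on-device as tensors
    low_x = Tensor.randint(BS, low=0, high=W-crop).reshape(BS,1,1,1)
    low_y = Tensor.randint(BS, low=0, high=H-crop).reshape(BS,1,1,1)
    idx_x = Tensor.arange(crop, dtype=dtypes.int32).reshape(1,1,1,crop)
    idx_y = Tensor.arange(crop, dtype=dtypes.int32).reshape(1,1,crop,1)

    # gather along width first (result keeps full height), then along height
    out = X.gather(-1, (low_x+idx_x).expand(BS, C, H, crop))
    out = out.gather(-2, (low_y+idx_y).expand(BS, C, crop, crop))
    print(out.shape)  # (4, 3, 5, 5)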
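
Note: the Variable/bind pattern in fetch_batches keeps the batch slice
symbolic: the offset enters the graph as a bound variable rather than a fresh
constant each step, so the slice shape stays fixed and the compiled kernels
can be reused across batches. A minimal sketch of the pattern (hypothetical
data and sizes, not from the patch):

    from tinygrad import Tensor, Variable

    X = Tensor.rand(100, 8)
    BS = 32
    full_batches = (X.shape[0] // BS) * BS    # 96: the last half-batch is dropped
    vi = Variable("i", 0, full_batches - BS)  # symbolic offset in [0, 64]
    for i in range(0, full_batches, BS):
      vib = vi.bind(i)       # bind the concrete offset for this step
      batch = X[vib:vib+BS]  # always shape (32, 8)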