From 28edea5d67f362cdefb70448e3956831296b3afb Mon Sep 17 00:00:00 2001
From: chenyu
Date: Wed, 8 Oct 2025 22:41:38 +0800
Subject: [PATCH] delete FUSE_CONV_BW (#12527)

---
 examples/mlperf/model_train.py     |  5 ++---
 test/external/external_test_opt.py | 28 +++++++++++++---------------
 test/test_schedule.py              | 25 ++++++++++++-------------
 tinygrad/helpers.py                |  1 -
 4 files changed, 27 insertions(+), 32 deletions(-)

diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index db3767edd3..6354155b14 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -3,7 +3,7 @@ from pathlib import Path
 import multiprocessing
 
 from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
-from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, FUSE_CONV_BW, Profiling
+from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, Profiling
 from tinygrad.nn.state import get_parameters, get_state_dict, load_state_dict, safe_load, safe_save
 from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, Adam, AdamW
 
@@ -707,7 +707,7 @@ def train_unet3d():
     ```BASEDIR= ./examples/mlperf/scripts/setup_kits19_dataset.sh```
 
   2) To start training the model, run the following:
-    ```time PYTHONPATH=. WANDB=1 TRAIN_BEAM=3 FUSE_CONV_BW=1 GPUS=6 BS=6 MODEL=unet3d python3 examples/mlperf/model_train.py```
+    ```time PYTHONPATH=. WANDB=1 TRAIN_BEAM=3 GPUS=6 BS=6 MODEL=unet3d python3 examples/mlperf/model_train.py```
   """
   from examples.mlperf.losses import dice_ce_loss
   from examples.mlperf.metrics import dice_score
@@ -749,7 +749,6 @@ def train_unet3d():
       "train_beam": TRAIN_BEAM,
       "eval_beam": EVAL_BEAM,
       "wino": WINO.value,
-      "fuse_conv_bw": FUSE_CONV_BW.value,
       "gpus": GPUS,
       "default_float": dtypes.default_float.name
     }
diff --git a/test/external/external_test_opt.py b/test/external/external_test_opt.py
index c59e60de80..45bb87fd50 100644
--- a/test/external/external_test_opt.py
+++ b/test/external/external_test_opt.py
@@ -4,7 +4,7 @@ import numpy as np
 import torch
 
 from tinygrad import GlobalCounters, Tensor, Device
-from tinygrad.helpers import getenv, Context, RANGEIFY
+from tinygrad.helpers import getenv, RANGEIFY
 from tinygrad.nn.state import get_parameters
 from tinygrad.engine.realize import capturing
 from tinygrad.tensor import _to_np_dtype
@@ -217,24 +217,22 @@ class TestOpt(unittest.TestCase):
     assert cache_len == 1, "reduceop was rerun!"
 
   def test_expand_reduce_is_folded_on_same_axis(self):
-    with Context(FUSE_CONV_BW=1):
-      for axis in [0, 1]:
-        for n in [4, 8, 16]:
-          b = torch.ones(n, n).sum(axis).reshape(n, 1).expand(n, n).sum(axis)
-          with CLCache(allowed=3 if RANGEIFY else 2):
-            a = Tensor.ones(n, n).contiguous().sum(axis).reshape(n, 1).expand(n, n).sum(axis)
-            a.realize()
-          np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5)
-
-  def test_expand_reduce_is_folded_on_different_axes(self):
-    with Context(FUSE_CONV_BW=1):
-      axis1, axis2 = 0, 1
+    for axis in [0, 1]:
       for n in [4, 8, 16]:
-        b = torch.ones(n, n).sum(axis1).reshape(n, 1).expand(n, n).sum(axis2)
+        b = torch.ones(n, n).sum(axis).reshape(n, 1).expand(n, n).sum(axis)
         with CLCache(allowed=3 if RANGEIFY else 2):
-          a = Tensor.ones(n, n).contiguous().sum(axis1).reshape(n, 1).expand(n, n).sum(axis2)
+          a = Tensor.ones(n, n).contiguous().sum(axis).reshape(n, 1).expand(n, n).sum(axis)
           a.realize()
         np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5)
 
+  def test_expand_reduce_is_folded_on_different_axes(self):
+    axis1, axis2 = 0, 1
+    for n in [4, 8, 16]:
+      b = torch.ones(n, n).sum(axis1).reshape(n, 1).expand(n, n).sum(axis2)
+      with CLCache(allowed=3 if RANGEIFY else 2):
+        a = Tensor.ones(n, n).contiguous().sum(axis1).reshape(n, 1).expand(n, n).sum(axis2)
+        a.realize()
+      np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5)
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/test/test_schedule.py b/test/test_schedule.py
index b582b6e089..ae3f0dd1f1 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -45,7 +45,7 @@ def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Te
 def _realize_weights(m):
   for p in nn.state.get_parameters(m): p.realize()
 
-def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
+def _test_conv2d(allowed:int, dtype:DType=dtypes.float):
   old_default_float, dtypes.default_float = dtypes.default_float, dtype
   dtypes.default_float = dtype
   Tensor.manual_seed(0)
@@ -54,7 +54,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
   w = Tensor.uniform(16, CIN, 3, 3, requires_grad=True).realize()
   ret = Tensor.conv2d(img, w).relu().mean().backward()
   dtypes.default_float = old_default_float
-  with Context(**kwargs): s = Tensor.schedule(ret, img.grad, w.grad)
+  s = Tensor.schedule(ret, img.grad, w.grad)
   run_schedule(s.copy())
   cnt = len([si for si in s if si.ast.op is Ops.SINK])
   assert cnt == allowed, f"expected {allowed} kernels, got {cnt}"
@@ -470,15 +470,14 @@ class TestSchedule(unittest.TestCase):
     check_schedule(opt.schedule_step(), cnt)
 
   def test_fold_batchnorm_backward(self):
-    with Context(FUSE_CONV_BW=1):
-      with Tensor.train():
-        x = Tensor.empty((2, 16, 8, 8)).contiguous()
-        bn = nn.BatchNorm2d(16)
-        bn.weight.requires_grad = bn.bias.requires_grad = x.requires_grad = True
-        fw = bn(x).contiguous_backward().relu().contiguous()
-        fw.sum().backward()
-        # TODO: this is too many
-        check_schedule([x.grad, bn.weight.grad, bn.bias.grad, fw], 10)
+    with Tensor.train():
+      x = Tensor.empty((2, 16, 8, 8)).contiguous()
+      bn = nn.BatchNorm2d(16)
+      bn.weight.requires_grad = bn.bias.requires_grad = x.requires_grad = True
+      fw = bn(x).contiguous_backward().relu().contiguous()
+      fw.sum().backward()
+      # TODO: this is too many
+      check_schedule([x.grad, bn.weight.grad, bn.bias.grad, fw], 10)
 
   def test_fold_conv_relu(self):
     c1 = nn.Conv2d(3,16,3)
@@ -1321,7 +1320,7 @@ class TestSchedule(unittest.TestCase):
     opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
     opt.zero_grad()
     c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
-    with Context(FUSE_CONV_BW=1): check_schedule(opt.schedule_step(), 14)
+    check_schedule(opt.schedule_step(), 14)
 
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   @expect_rangeify_fails
@@ -1626,7 +1625,7 @@ class TestSchedule(unittest.TestCase):
     run_schedule(check_schedule(out, 2))
 
   def test_conv2d(self): _test_conv2d(5 if RANGEIFY else 7)
-  def test_conv2d_fused(self): _test_conv2d(5 if RANGEIFY else 5, FUSE_CONV_BW=1)
+  def test_conv2d_fused(self): _test_conv2d(5 if RANGEIFY else 5)
 
   @unittest.skipUnless(is_dtype_supported(dtypes.half) and is_dtype_supported(dtypes.ulong), "need half and ulong")
   def test_conv2d_half(self): _test_conv2d(5 if RANGEIFY else 7, dtype=dtypes.half)
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index a7a39c5cd9..fd13c8b166 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -133,7 +133,6 @@ JIT, JIT_BATCH_SIZE = ContextVar("JIT", 2 if OSX and ARCH_X86 else 1), ContextVa
 WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
 USE_TC, TC_SELECT, TC_OPT, AMX = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0)
 TRANSCENDENTAL, NOLOCALS = ContextVar("TRANSCENDENTAL", 1), ContextVar("NOLOCALS", 0)
-FUSE_CONV_BW = ContextVar("FUSE_CONV_BW", 0)
 SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
 PICKLE_BUFFERS, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("LRU", 1)
 CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
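
Note (not part of the patch): FUSE_CONV_BW was one of the ContextVar flags defined in tinygrad/helpers.py, and the removed test code toggled it with `with Context(FUSE_CONV_BW=1):`. The minimal sketch below only illustrates that generic ContextVar/Context override mechanism; since FUSE_CONV_BW no longer exists after this change, it uses the still-present DEBUG variable purely as a stand-in.

```python
# Illustrative sketch only -- shows the ContextVar/Context pattern the deleted
# FUSE_CONV_BW flag used. DEBUG is a stand-in for the removed variable.
from tinygrad.helpers import Context, DEBUG

print(DEBUG.value)        # default value, unless the DEBUG env var overrides it

with Context(DEBUG=2):    # temporarily override the flag inside this block
  print(DEBUG.value)      # -> 2

print(DEBUG.value)        # restored to the previous value on exit
```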