delete FUSE_CONV_BW (#12527)

Author: chenyu
Date: 2025-10-08 22:41:38 +08:00
Committed by: GitHub
Parent: 2653147cb7
Commit: 28edea5d67
4 changed files with 27 additions and 32 deletions
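For context, the deleted flag was used in exactly the three ways visible in the diff below: as an environment variable on the unet3d training command, as a scoped `Context(FUSE_CONV_BW=1)` around tests, and read back via `FUSE_CONV_BW.value` for logging config. A minimal sketch of that pre-commit usage (this only imports on trees before this change; it is shown purely to illustrate what is being removed):

```
# Pre-commit usage sketch; FUSE_CONV_BW is deleted by this commit, so this only
# runs on older trees. It shows the ContextVar pattern the diff removes.
from tinygrad.helpers import Context, FUSE_CONV_BW

print(FUSE_CONV_BW.value)        # 0 unless FUSE_CONV_BW=1 was set in the environment
with Context(FUSE_CONV_BW=1):    # scoped override, as the deleted tests did
  print(FUSE_CONV_BW.value)      # -> 1 inside the block, restored afterwards
```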

View File

@@ -3,7 +3,7 @@ from pathlib import Path
 import multiprocessing
 from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
-from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, FUSE_CONV_BW, Profiling
+from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, Profiling
 from tinygrad.nn.state import get_parameters, get_state_dict, load_state_dict, safe_load, safe_save
 from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, Adam, AdamW
@@ -707,7 +707,7 @@ def train_unet3d():
   ```BASEDIR=<folder_path> ./examples/mlperf/scripts/setup_kits19_dataset.sh```
   2) To start training the model, run the following:
-  ```time PYTHONPATH=. WANDB=1 TRAIN_BEAM=3 FUSE_CONV_BW=1 GPUS=6 BS=6 MODEL=unet3d python3 examples/mlperf/model_train.py```
+  ```time PYTHONPATH=. WANDB=1 TRAIN_BEAM=3 GPUS=6 BS=6 MODEL=unet3d python3 examples/mlperf/model_train.py```
   """
   from examples.mlperf.losses import dice_ce_loss
   from examples.mlperf.metrics import dice_score
@@ -749,7 +749,6 @@ def train_unet3d():
"train_beam": TRAIN_BEAM,
"eval_beam": EVAL_BEAM,
"wino": WINO.value,
"fuse_conv_bw": FUSE_CONV_BW.value,
"gpus": GPUS,
"default_float": dtypes.default_float.name
}

View File

@@ -4,7 +4,7 @@ import numpy as np
 import torch
 from tinygrad import GlobalCounters, Tensor, Device
-from tinygrad.helpers import getenv, Context, RANGEIFY
+from tinygrad.helpers import getenv, RANGEIFY
 from tinygrad.nn.state import get_parameters
 from tinygrad.engine.realize import capturing
 from tinygrad.tensor import _to_np_dtype
@@ -217,24 +217,22 @@ class TestOpt(unittest.TestCase):
     assert cache_len == 1, "reduceop was rerun!"

   def test_expand_reduce_is_folded_on_same_axis(self):
-    with Context(FUSE_CONV_BW=1):
-      for axis in [0, 1]:
-        for n in [4, 8, 16]:
-          b = torch.ones(n, n).sum(axis).reshape(n, 1).expand(n, n).sum(axis)
-          with CLCache(allowed=3 if RANGEIFY else 2):
-            a = Tensor.ones(n, n).contiguous().sum(axis).reshape(n, 1).expand(n, n).sum(axis)
-            a.realize()
-            np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5)
-
-  def test_expand_reduce_is_folded_on_different_axes(self):
-    with Context(FUSE_CONV_BW=1):
-      axis1, axis2 = 0, 1
+    for axis in [0, 1]:
       for n in [4, 8, 16]:
-        b = torch.ones(n, n).sum(axis1).reshape(n, 1).expand(n, n).sum(axis2)
+        b = torch.ones(n, n).sum(axis).reshape(n, 1).expand(n, n).sum(axis)
         with CLCache(allowed=3 if RANGEIFY else 2):
-          a = Tensor.ones(n, n).contiguous().sum(axis1).reshape(n, 1).expand(n, n).sum(axis2)
+          a = Tensor.ones(n, n).contiguous().sum(axis).reshape(n, 1).expand(n, n).sum(axis)
           a.realize()
           np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5)

+  def test_expand_reduce_is_folded_on_different_axes(self):
+    axis1, axis2 = 0, 1
+    for n in [4, 8, 16]:
+      b = torch.ones(n, n).sum(axis1).reshape(n, 1).expand(n, n).sum(axis2)
+      with CLCache(allowed=3 if RANGEIFY else 2):
+        a = Tensor.ones(n, n).contiguous().sum(axis1).reshape(n, 1).expand(n, n).sum(axis2)
+        a.realize()
+        np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5)
+
 if __name__ == '__main__':
   unittest.main()
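As a quick sanity check on what these folding tests compute (a standalone sketch, not part of the diff): reducing ones, broadcasting back, and reducing again just yields n*n in every entry, e.g. for n=4, axis=0:

```
# Standalone check of the expression the folding tests exercise (n=4, axis=0):
# ones(4,4).sum(0) -> 4 per entry; expand back to (4,4); sum(0) -> 16 per entry.
import numpy as np
from tinygrad import Tensor

n, axis = 4, 0
a = Tensor.ones(n, n).contiguous().sum(axis).reshape(n, 1).expand(n, n).sum(axis)
np.testing.assert_allclose(a.numpy(), np.full(n, float(n * n)))
```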

View File

@@ -45,7 +45,7 @@ def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Te
 def _realize_weights(m):
   for p in nn.state.get_parameters(m): p.realize()

-def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
+def _test_conv2d(allowed:int, dtype:DType=dtypes.float):
   old_default_float, dtypes.default_float = dtypes.default_float, dtype
   dtypes.default_float = dtype
   Tensor.manual_seed(0)
@@ -54,7 +54,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
   w = Tensor.uniform(16, CIN, 3, 3, requires_grad=True).realize()
   ret = Tensor.conv2d(img, w).relu().mean().backward()
   dtypes.default_float = old_default_float
-  with Context(**kwargs): s = Tensor.schedule(ret, img.grad, w.grad)
+  s = Tensor.schedule(ret, img.grad, w.grad)
   run_schedule(s.copy())
   cnt = len([si for si in s if si.ast.op is Ops.SINK])
   assert cnt == allowed, f"expected {allowed} kernels, got {cnt}"
@@ -470,15 +470,14 @@ class TestSchedule(unittest.TestCase):
       check_schedule(opt.schedule_step(), cnt)

   def test_fold_batchnorm_backward(self):
-    with Context(FUSE_CONV_BW=1):
-      with Tensor.train():
-        x = Tensor.empty((2, 16, 8, 8)).contiguous()
-        bn = nn.BatchNorm2d(16)
-        bn.weight.requires_grad = bn.bias.requires_grad = x.requires_grad = True
-        fw = bn(x).contiguous_backward().relu().contiguous()
-        fw.sum().backward()
-        # TODO: this is too many
-        check_schedule([x.grad, bn.weight.grad, bn.bias.grad, fw], 10)
+    with Tensor.train():
+      x = Tensor.empty((2, 16, 8, 8)).contiguous()
+      bn = nn.BatchNorm2d(16)
+      bn.weight.requires_grad = bn.bias.requires_grad = x.requires_grad = True
+      fw = bn(x).contiguous_backward().relu().contiguous()
+      fw.sum().backward()
+      # TODO: this is too many
+      check_schedule([x.grad, bn.weight.grad, bn.bias.grad, fw], 10)

   def test_fold_conv_relu(self):
     c1 = nn.Conv2d(3,16,3)
@@ -1321,7 +1320,7 @@ class TestSchedule(unittest.TestCase):
     opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
     opt.zero_grad()
     c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
-    with Context(FUSE_CONV_BW=1): check_schedule(opt.schedule_step(), 14)
+    check_schedule(opt.schedule_step(), 14)

   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   @expect_rangeify_fails
@@ -1626,7 +1625,7 @@ class TestSchedule(unittest.TestCase):
     run_schedule(check_schedule(out, 2))

   def test_conv2d(self): _test_conv2d(5 if RANGEIFY else 7)
-  def test_conv2d_fused(self): _test_conv2d(5 if RANGEIFY else 5, FUSE_CONV_BW=1)
+  def test_conv2d_fused(self): _test_conv2d(5 if RANGEIFY else 5)
   @unittest.skipUnless(is_dtype_supported(dtypes.half) and is_dtype_supported(dtypes.ulong), "need half and ulong")
   def test_conv2d_half(self): _test_conv2d(5 if RANGEIFY else 7, dtype=dtypes.half)

View File

@@ -133,7 +133,6 @@ JIT, JIT_BATCH_SIZE = ContextVar("JIT", 2 if OSX and ARCH_X86 else 1), ContextVa
 WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
 USE_TC, TC_SELECT, TC_OPT, AMX = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0)
 TRANSCENDENTAL, NOLOCALS = ContextVar("TRANSCENDENTAL", 1), ContextVar("NOLOCALS", 0)
-FUSE_CONV_BW = ContextVar("FUSE_CONV_BW", 0)
 SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
 PICKLE_BUFFERS, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("LRU", 1)
 CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
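For reference, the ContextVars that remain in this block are consumed the same way the deleted flag was: each exposes a `.value` seeded from an environment variable of the same name and can be overridden temporarily with `Context`. A small post-commit sketch, using WINO only as the example:

```
# Sketch of how the remaining ContextVars are read and overridden; WINO is just
# a surviving example, any of the variables defined above works the same way.
from tinygrad.helpers import Context, WINO

print(WINO.value)        # env var WINO=... at startup, otherwise the default 0
with Context(WINO=1):    # temporary override
  print(WINO.value)      # -> 1
print(WINO.value)        # previous value restored on exit
```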