delete FUSE_CONV_BW (#12527)
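FUSE_CONV_BW was a ContextVar defined in tinygrad/helpers.py (default 0) that gated fusion of the convolution backward pass in the scheduler; this commit deletes the flag together with every `Context(FUSE_CONV_BW=1)` block and `FUSE_CONV_BW=1` environment usage in the examples and tests below. A minimal pre-commit usage sketch, assuming only the pattern visible in the deleted code (shapes are illustrative, not taken from this diff):

```python
# Pre-commit usage sketch: FUSE_CONV_BW could be enabled via the environment
# (FUSE_CONV_BW=1 python ...) or per-block via Context, as the deleted code did.
# This no longer works after this commit; shown only to illustrate what is removed.
from tinygrad import Tensor
from tinygrad.helpers import Context, FUSE_CONV_BW  # this import is deleted by the commit

img = Tensor.randn(2, 3, 64, 64, requires_grad=True)
w = Tensor.uniform(16, 3, 3, 3, requires_grad=True)
ret = Tensor.conv2d(img, w).relu().mean().backward()

with Context(FUSE_CONV_BW=1):                 # opt in to conv backward fusion
  s = Tensor.schedule(ret, img.grad, w.grad)  # schedule built with the flag set
```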
@@ -3,7 +3,7 @@ from pathlib import Path
 import multiprocessing
 
 from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
-from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, FUSE_CONV_BW, Profiling
+from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, Profiling
 from tinygrad.nn.state import get_parameters, get_state_dict, load_state_dict, safe_load, safe_save
 from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, Adam, AdamW
 
@@ -707,7 +707,7 @@ def train_unet3d():
   ```BASEDIR=<folder_path> ./examples/mlperf/scripts/setup_kits19_dataset.sh```
 
   2) To start training the model, run the following:
-  ```time PYTHONPATH=. WANDB=1 TRAIN_BEAM=3 FUSE_CONV_BW=1 GPUS=6 BS=6 MODEL=unet3d python3 examples/mlperf/model_train.py```
+  ```time PYTHONPATH=. WANDB=1 TRAIN_BEAM=3 GPUS=6 BS=6 MODEL=unet3d python3 examples/mlperf/model_train.py```
   """
   from examples.mlperf.losses import dice_ce_loss
   from examples.mlperf.metrics import dice_score
@@ -749,7 +749,6 @@ def train_unet3d():
       "train_beam": TRAIN_BEAM,
       "eval_beam": EVAL_BEAM,
       "wino": WINO.value,
-      "fuse_conv_bw": FUSE_CONV_BW.value,
       "gpus": GPUS,
       "default_float": dtypes.default_float.name
     }
test/external/external_test_opt.py
@@ -4,7 +4,7 @@ import numpy as np
 import torch
 
 from tinygrad import GlobalCounters, Tensor, Device
-from tinygrad.helpers import getenv, Context, RANGEIFY
+from tinygrad.helpers import getenv, RANGEIFY
 from tinygrad.nn.state import get_parameters
 from tinygrad.engine.realize import capturing
 from tinygrad.tensor import _to_np_dtype
@@ -217,24 +217,22 @@ class TestOpt(unittest.TestCase):
     assert cache_len == 1, "reduceop was rerun!"
 
   def test_expand_reduce_is_folded_on_same_axis(self):
-    with Context(FUSE_CONV_BW=1):
-      for axis in [0, 1]:
-        for n in [4, 8, 16]:
-          b = torch.ones(n, n).sum(axis).reshape(n, 1).expand(n, n).sum(axis)
-          with CLCache(allowed=3 if RANGEIFY else 2):
-            a = Tensor.ones(n, n).contiguous().sum(axis).reshape(n, 1).expand(n, n).sum(axis)
-            a.realize()
-          np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5)
-
-  def test_expand_reduce_is_folded_on_different_axes(self):
-    with Context(FUSE_CONV_BW=1):
-      axis1, axis2 = 0, 1
+    for axis in [0, 1]:
       for n in [4, 8, 16]:
-        b = torch.ones(n, n).sum(axis1).reshape(n, 1).expand(n, n).sum(axis2)
+        b = torch.ones(n, n).sum(axis).reshape(n, 1).expand(n, n).sum(axis)
         with CLCache(allowed=3 if RANGEIFY else 2):
-          a = Tensor.ones(n, n).contiguous().sum(axis1).reshape(n, 1).expand(n, n).sum(axis2)
+          a = Tensor.ones(n, n).contiguous().sum(axis).reshape(n, 1).expand(n, n).sum(axis)
           a.realize()
         np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5)
+
+  def test_expand_reduce_is_folded_on_different_axes(self):
+    axis1, axis2 = 0, 1
+    for n in [4, 8, 16]:
+      b = torch.ones(n, n).sum(axis1).reshape(n, 1).expand(n, n).sum(axis2)
+      with CLCache(allowed=3 if RANGEIFY else 2):
+        a = Tensor.ones(n, n).contiguous().sum(axis1).reshape(n, 1).expand(n, n).sum(axis2)
+        a.realize()
+      np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5)
 
 if __name__ == '__main__':
   unittest.main()
@@ -45,7 +45,7 @@ def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Te
 def _realize_weights(m):
   for p in nn.state.get_parameters(m): p.realize()
 
-def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
+def _test_conv2d(allowed:int, dtype:DType=dtypes.float):
   old_default_float, dtypes.default_float = dtypes.default_float, dtype
   dtypes.default_float = dtype
   Tensor.manual_seed(0)
@@ -54,7 +54,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
   w = Tensor.uniform(16, CIN, 3, 3, requires_grad=True).realize()
   ret = Tensor.conv2d(img, w).relu().mean().backward()
   dtypes.default_float = old_default_float
-  with Context(**kwargs): s = Tensor.schedule(ret, img.grad, w.grad)
+  s = Tensor.schedule(ret, img.grad, w.grad)
   run_schedule(s.copy())
   cnt = len([si for si in s if si.ast.op is Ops.SINK])
   assert cnt == allowed, f"expected {allowed} kernels, got {cnt}"
@@ -470,15 +470,14 @@ class TestSchedule(unittest.TestCase):
     check_schedule(opt.schedule_step(), cnt)
 
   def test_fold_batchnorm_backward(self):
-    with Context(FUSE_CONV_BW=1):
-      with Tensor.train():
-        x = Tensor.empty((2, 16, 8, 8)).contiguous()
-        bn = nn.BatchNorm2d(16)
-        bn.weight.requires_grad = bn.bias.requires_grad = x.requires_grad = True
-        fw = bn(x).contiguous_backward().relu().contiguous()
-        fw.sum().backward()
-        # TODO: this is too many
-        check_schedule([x.grad, bn.weight.grad, bn.bias.grad, fw], 10)
+    with Tensor.train():
+      x = Tensor.empty((2, 16, 8, 8)).contiguous()
+      bn = nn.BatchNorm2d(16)
+      bn.weight.requires_grad = bn.bias.requires_grad = x.requires_grad = True
+      fw = bn(x).contiguous_backward().relu().contiguous()
+      fw.sum().backward()
+      # TODO: this is too many
+      check_schedule([x.grad, bn.weight.grad, bn.bias.grad, fw], 10)
 
   def test_fold_conv_relu(self):
     c1 = nn.Conv2d(3,16,3)
@@ -1321,7 +1320,7 @@ class TestSchedule(unittest.TestCase):
     opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
     opt.zero_grad()
     c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
-    with Context(FUSE_CONV_BW=1): check_schedule(opt.schedule_step(), 14)
+    check_schedule(opt.schedule_step(), 14)
 
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   @expect_rangeify_fails
@@ -1626,7 +1625,7 @@ class TestSchedule(unittest.TestCase):
     run_schedule(check_schedule(out, 2))
 
   def test_conv2d(self): _test_conv2d(5 if RANGEIFY else 7)
-  def test_conv2d_fused(self): _test_conv2d(5 if RANGEIFY else 5, FUSE_CONV_BW=1)
+  def test_conv2d_fused(self): _test_conv2d(5 if RANGEIFY else 5)
 
   @unittest.skipUnless(is_dtype_supported(dtypes.half) and is_dtype_supported(dtypes.ulong), "need half and ulong")
   def test_conv2d_half(self): _test_conv2d(5 if RANGEIFY else 7, dtype=dtypes.half)
@@ -133,7 +133,6 @@ JIT, JIT_BATCH_SIZE = ContextVar("JIT", 2 if OSX and ARCH_X86 else 1), ContextVa
 WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
 USE_TC, TC_SELECT, TC_OPT, AMX = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0)
 TRANSCENDENTAL, NOLOCALS = ContextVar("TRANSCENDENTAL", 1), ContextVar("NOLOCALS", 0)
-FUSE_CONV_BW = ContextVar("FUSE_CONV_BW", 0)
 SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
 PICKLE_BUFFERS, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("LRU", 1)
 CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
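For reference, the deleted definition followed the same ContextVar pattern as the surrounding flags: the variable takes its default (or the matching environment variable) at import time, exposes the current setting as `.value`, and can be overridden temporarily with `Context(...)`. A minimal sketch with a hypothetical flag, assuming that pattern (MY_FLAG is not a real tinygrad variable):

```python
# Minimal sketch of the ContextVar pattern the deleted line followed.
# MY_FLAG is hypothetical; FUSE_CONV_BW itself no longer exists after this commit.
from tinygrad.helpers import Context, ContextVar

MY_FLAG = ContextVar("MY_FLAG", 0)  # default 0, overridable via the MY_FLAG env var

print(MY_FLAG.value)      # 0 unless MY_FLAG is set in the environment
with Context(MY_FLAG=1):  # temporary override, as the deleted tests did with FUSE_CONV_BW=1
  print(MY_FLAG.value)    # 1 inside the Context block
print(MY_FLAG.value)      # restored afterwards
```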