Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-04-07 03:00:26 -04:00
conv bw schedule and correctness tests to iterate on (#6461)
first to fix AST_REWRITE=1, then to implement the same fusion for dtypes.half.
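In outline, the new tests build a small conv2d forward+backward graph, create its schedule under the relevant context flags, count the resulting kernels, and check gradients against a torch reference. A condensed, hedged sketch of that flow, using only APIs that appear in the diff below (Context is taken from tinygrad.helpers here, while the test file itself imports it via test.helpers; the kernel count printed is device- and flag-dependent):

from tinygrad import Tensor
from tinygrad.ops import UOps
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule
from tinygrad.helpers import Context

Tensor.manual_seed(0)
img = Tensor.randn(2, 3, 64, 64, requires_grad=True)
w = Tensor.uniform(16, 3, 3, 3, requires_grad=True)
img.conv2d(w).relu().mean().backward()

# schedule the backward pass with conv backward fusion enabled
with Context(FUSE_CONV_BW=1):
  sched = create_schedule([img.grad.lazydata, w.grad.lazydata])
print(len([si for si in sched if si.ast.op is UOps.SINK]))  # number of compute kernels (SINK items)
run_schedule(sched)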
@@ -8,17 +8,17 @@ from typing import List, Optional, Union, cast
from tinygrad import nn, dtypes
from tinygrad.device import Device
from tinygrad.dtype import PtrDType
from tinygrad.dtype import DType, PtrDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.tensor import Tensor
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps
from tinygrad.ops import graph_rewrite
from tinygrad.helpers import AST_REWRITE, CI, DEBUG, FUSE_ARANGE, FUSE_CONV_BW, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap
from tinygrad.helpers import AST_REWRITE, CI, DEBUG, FUSE_ARANGE, flatten, getenv, SPLIT_REDUCEOP, unwrap
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import create_schedule, reduceop_fusor, st_fixup, ScheduleItem
from tinygrad.engine.schedule import create_schedule, reduceop_fusor, st_fixup
from tinygrad.engine.realize import CompiledRunner, run_schedule
from test.helpers import assert_equiv_uops, ast_const, is_dtype_supported, Context, timeit
from test.helpers import ast_const, is_dtype_supported, Context, timeit
from tinygrad.lazy import LazyBuffer, view_supported_devices
from extra.models.llama import precompute_freqs_cis

@@ -45,6 +45,28 @@ def check_schedule(t:Union[Tensor, List[Tensor], LazyBuffer], allowed:int, to_pr
    l.linearize()
  return sched

def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
  old_default_float = dtypes.default_float
  dtypes.default_float = dtype
  Tensor.manual_seed(0)
  BS, CIN = 2, 3
  img = Tensor.randn(BS, CIN, 64, 64, requires_grad=True)
  w = Tensor.uniform(16, CIN, 3, 3, requires_grad=True)
  ret = Tensor.conv2d(img, w).relu().mean().backward()
  dtypes.default_float = old_default_float
  with Context(**kwargs): s = create_schedule([ret.lazydata, img.grad.lazydata, w.grad.lazydata])
  run_schedule(s.copy())
  cnt = len([si for si in s if si.ast.op is UOps.SINK])
  assert cnt == allowed, f"expected {allowed} kernels, got {cnt}"
  if getenv("CHECK", 1):
    import torch
    ref_img = torch.tensor(img.numpy(), requires_grad=True)
    ref_w = torch.tensor(w.numpy(), requires_grad=True)
    torch.nn.functional.conv2d(ref_img, ref_w).relu().mean().backward()
    assert ref_img.grad is not None and ref_w.grad is not None and img.grad is not None and w.grad is not None
    np.testing.assert_allclose(img.grad.numpy(), ref_img.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
    np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)

class TestSchedule(unittest.TestCase):
  def test_basic_binop_fusion(self):
    a = Tensor.empty(10)
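The with Context(**kwargs) line in _test_conv2d is what lets the one-liner tests further down toggle scheduler flags such as FUSE_CONV_BW or AST_REWRITE per call. A minimal sketch of the mechanism, using the DEBUG context variable from tinygrad.helpers (the test file itself imports Context via test.helpers; any registered ContextVar can be overridden the same way):

from tinygrad.helpers import Context, DEBUG

print(DEBUG.value)      # value from the environment (default 0)
with Context(DEBUG=2):
  print(DEBUG.value)    # temporarily 2 inside the block
print(DEBUG.value)      # restored on exit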
@@ -232,53 +254,6 @@ class TestSchedule(unittest.TestCase):
        img_bn.backward()
        check_schedule(opt.schedule_step(), cnt)

  def test_fold_conv_relu_backward(self):
    c1 = nn.Conv2d(3,16,3, bias=False)
    c1.weight.requires_grad = True
    img = Tensor.rand(2,3,64,64, requires_grad=True)

    # run
    c1(img).relu().mean().backward()
    assert img.grad is not None and c1.weight.grad is not None
    run_schedule(check_schedule([img.grad, c1.weight.grad], 7, filter_sink=False))

    # compare
    import torch
    c1_torch = torch.nn.Conv2d(3,16,3, bias=False)
    c1_torch.weight.requires_grad = True
    c1_torch.weight = torch.nn.Parameter(torch.tensor(c1.weight.numpy(), dtype=torch.float32))
    img_torch = torch.tensor(img.numpy(), requires_grad=True)
    c1_torch(img_torch).relu().mean().backward()
    assert img_torch.grad is not None and c1_torch.weight.grad is not None
    np.testing.assert_allclose(c1.weight.grad.numpy(), c1_torch.weight.grad.numpy(), atol=5e-4, rtol=1e-5)
    np.testing.assert_allclose(img.grad.numpy(), img_torch.grad.numpy(), atol=5e-4, rtol=1e-5)

  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  def test_fold_conv_relu_backward_half(self):
    old_float = dtypes.default_float
    dtypes.default_float = dtypes.float16

    c1 = nn.Conv2d(3,16,3, bias=False)
    c1.weight.requires_grad = True
    img = Tensor.rand(2,3,64,64, requires_grad=True)

    # run
    c1(img).relu().mean().backward()
    assert img.grad is not None and c1.weight.grad is not None
    run_schedule(check_schedule([img.grad, c1.weight.grad], 7, filter_sink=False))
    dtypes.default_float = old_float

    # compare
    import torch
    c1_torch = torch.nn.Conv2d(3,16,3, bias=False, dtype=torch.half)
    c1_torch.weight.requires_grad = True
    c1_torch.weight = torch.nn.Parameter(torch.tensor(c1.weight.numpy(), dtype=torch.half))
    img_torch = torch.tensor(img.numpy(), requires_grad=True)
    c1_torch(img_torch).relu().mean().backward()
    assert img_torch.grad is not None and c1_torch.weight.grad is not None
    np.testing.assert_allclose(c1.weight.grad.numpy(), c1_torch.weight.grad.numpy(), atol=5e-4, rtol=1e-5)
    np.testing.assert_allclose(img.grad.numpy(), img_torch.grad.numpy(), atol=5e-4, rtol=1e-5)

  def test_fold_batchnorm_backward(self):
    with Context(FUSE_CONV_BW=1):
      with Tensor.train():
@@ -994,6 +969,18 @@ class TestSchedule(unittest.TestCase):
      c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
      check_schedule(opt.schedule_step(), 22)

  def test_sgd_4convs_fuse_conv_bw(self):
    with Tensor.train():
      img = Tensor.empty(2,3,64,64)
      c1 = nn.Conv2d(3,4,3,bias=False)
      c2 = nn.Conv2d(4,8,3,bias=False)
      c3 = nn.Conv2d(8,16,3,bias=False)
      c4 = nn.Conv2d(16,32,3,bias=False)
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
      opt.zero_grad()
      c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
      with Context(FUSE_CONV_BW=1): check_schedule(opt.schedule_step(), 19)

  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  def test_prefer_half_buffer(self):
    x = Tensor.ones(4).contiguous().realize()
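The two SGD tests above differ only in the FUSE_CONV_BW context: with the conv backward fused, the optimizer step schedules 19 kernels instead of 22 in this configuration. Both counts come from check_schedule, which only counts schedule items whose AST root is a SINK, i.e. actual compute kernels; other schedule items (for example buffer copies) are excluded. The filter, as used throughout this file:

kernels = [si for si in sched if si.ast.op is UOps.SINK]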
@@ -1300,107 +1287,19 @@ class TestSchedule(unittest.TestCase):
    out = x.argmax(1)
    run_schedule(check_schedule(out, 3)) # TODO: push a reduceop through a reshape

class TestConvBW(unittest.TestCase):
  def check_schedule(self, xt, cnt:int, flops=None) -> List[ScheduleItem]:
    with Context(FUSE_CONV_BW=getenv("FUSE_CONV_BW", 1), NOOPT=flops is not None):
      s = create_schedule(flatten([r.lazydata.lbs for r in xt]))
      kernels = [si for si in s if si.ast.op is UOps.SINK]
      for si in kernels: verify_ast(si.ast)
      GlobalCounters.reset()
      run_schedule(s)
    if flops is not None: assert GlobalCounters.global_ops <= flops, f"too many ops {GlobalCounters.global_ops}"
    if FUSE_CONV_BW: self.assertEqual(len(kernels), cnt)
    return kernels

  def test_fold_conv_relu_backward(self):
    c1 = nn.Conv2d(3,16,3, bias=False)
    c1.weight.requires_grad = True
    img = Tensor.rand(2,3,64,64, requires_grad=True)

    # run
    c1(img).relu().mean().backward()
    assert img.grad is not None and c1.weight.grad is not None
    self.check_schedule([img.grad, c1.weight.grad], 4)

    # compare
    import torch
    c1_torch = torch.nn.Conv2d(3,16,3, bias=False)
    c1_torch.weight.requires_grad = True
    c1_torch.weight = torch.nn.Parameter(torch.tensor(c1.weight.numpy(), dtype=torch.float32))
    img_torch = torch.tensor(img.numpy(), requires_grad=True)
    c1_torch(img_torch).relu().mean().backward()
    assert img_torch.grad is not None and c1_torch.weight.grad is not None
    np.testing.assert_allclose(c1.weight.grad.numpy(), c1_torch.weight.grad.numpy(), atol=5e-4, rtol=1e-5)
    np.testing.assert_allclose(img.grad.numpy(), img_torch.grad.numpy(), atol=5e-4, rtol=1e-5)

  def test_fold_conv_relu_backward_ast_rewrite(self):
    # shared params
    Tensor.manual_seed(0)
    img_np = Tensor.randn(2,3,64,64).numpy()
    c1_w = Tensor.randn(16,3,3,3).numpy()
    # graph_rewrite
    GlobalCounters.reset()
    c1 = nn.Conv2d(3,16,3, bias=False)
    c1.weight = Tensor(c1_w, requires_grad=True)
    img = Tensor(img_np, requires_grad=True)
    c1(img).relu().mean().backward()
    assert img.grad is not None and c1.weight.grad is not None
    with Context(AST_REWRITE=1): compare_ast = self.check_schedule([img.grad, c1.weight.grad], 3)[1].ast
    rw_flops = GlobalCounters.global_ops
    # ref
    GlobalCounters.reset()
    c1_ref = nn.Conv2d(3,16,3, bias=False)
    c1_ref.weight = Tensor(c1_w, requires_grad=True)
    img_ref = Tensor(img_np, requires_grad=True)
    c1_ref(img_ref).relu().mean().backward()
    assert img_ref.grad is not None and c1_ref.weight.grad is not None
    with Context(AST_REWRITE=0): ref_ast = self.check_schedule([img_ref.grad, c1_ref.weight.grad], 3)[1].ast
    ref_flops = GlobalCounters.global_ops
    # correctness
    np.testing.assert_allclose(c1.weight.grad.numpy(), c1_ref.weight.grad.numpy(), atol=5e-4, rtol=1e-5)
    np.testing.assert_allclose(img.grad.numpy(), img_ref.grad.numpy(), atol=5e-4, rtol=1e-5)
    # flops, TODO: This will be fixed once SWIZZLE merges view strides.
    with self.assertRaises(AssertionError):
      self.assertEqual(rw_flops, ref_flops)
    assert_equiv_uops(compare_ast, ref_ast)

  def test_conv2d(self): _test_conv2d(8)
  def test_conv2d_fused(self): _test_conv2d(7, FUSE_CONV_BW=1)
  @unittest.expectedFailure
  def test_conv2d_fused_ast_rewrite(self): _test_conv2d(7, FUSE_CONV_BW=1, AST_REWRITE=1)

  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  def test_fold_conv_relu_backward_half(self):
    old_float = dtypes.default_float
    dtypes.default_float = dtypes.float16

    c1 = nn.Conv2d(3,16,3, bias=False)
    c1.weight.requires_grad = True

    # run
    img = Tensor.rand(2,3,64,64, requires_grad=True)
    c1(img).relu().mean().backward()
    dtypes.default_float = old_float
    self.check_schedule([img.grad, c1.weight.grad], 4)

    # compare
    import torch
    c1_torch = torch.nn.Conv2d(3,16,3, bias=False, dtype=torch.half)
    c1_torch.weight.requires_grad = True
    c1_torch.weight = torch.nn.Parameter(torch.tensor(c1.weight.numpy(), dtype=torch.half))
    img_torch = torch.tensor(img.numpy(), requires_grad=True)
    c1_torch(img_torch).relu().mean().backward()
    assert img_torch.grad is not None and c1_torch.weight.grad is not None
    np.testing.assert_allclose(c1.weight.grad.numpy(), c1_torch.weight.grad.numpy(), atol=5e-4, rtol=1e-5)
    np.testing.assert_allclose(img.grad.numpy(), img_torch.grad.numpy(), atol=5e-4, rtol=1e-5)

  def test_sgd_4convs_fuse(self):
    with Tensor.train():
      img = Tensor.empty(2,3,64,64)
      c1 = nn.Conv2d(3,4,3,bias=False)
      c2 = nn.Conv2d(4,8,3,bias=False)
      c3 = nn.Conv2d(8,16,3,bias=False)
      c4 = nn.Conv2d(16,32,3,bias=False)
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
      opt.zero_grad()
      c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
      self.check_schedule(opt.schedule_step(), 19)
  def test_conv2d_half(self): _test_conv2d(8, dtype=dtypes.half)
  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  @unittest.expectedFailure
  def test_conv2d_fused_half(self): _test_conv2d(7, dtype=dtypes.half)
  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  @unittest.expectedFailure
  def test_conv2d_fused_ast_rewrite_half(self): _test_conv2d(7, FUSE_CONV_BW=1, AST_REWRITE=1, dtype=dtypes.half)

class TestIndexing(unittest.TestCase):
  def check_schedule(self, xt:Union[Tensor,List[Tensor]], cnt:int):
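To iterate on just these schedules locally, the usual pattern is to run the test module directly and select the new test class; a sketch (module path as in the tinygrad repo, environment variable names from this diff):

# rough equivalent of: FUSE_CONV_BW=1 python3 test/test_schedule.py TestConvBW
import unittest
unittest.main(module="test.test_schedule", defaultTest="TestConvBW", exit=False)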