conv bw schedule and correctness tests to iterate on (#6461)

The plan is first to fix AST_REWRITE=1, then to implement the same fusion for dtypes.half.
Author: qazal
Date: 2024-09-11 08:47:07 +08:00
Committed by: GitHub
Parent: b574caadc9
Commit: 803b8b9313

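For context: the new _test_conv2d helper is the iteration loop in miniature — build a conv2d → relu → mean backward graph, schedule the gradients under the given context flags, count the SINK kernels, and optionally cross-check the gradients against torch. Below is a minimal standalone sketch of that loop, using only APIs this diff already imports (create_schedule, run_schedule, Context, UOps); the printed kernel count is illustrative and depends on the backend and flags, not the value the tests assert.

from tinygrad import Tensor
from tinygrad.ops import UOps
from tinygrad.helpers import Context
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule

# build the same conv -> relu -> mean backward graph the tests use
Tensor.manual_seed(0)
img = Tensor.randn(2, 3, 64, 64, requires_grad=True)
w = Tensor.uniform(16, 3, 3, 3, requires_grad=True)
Tensor.conv2d(img, w).relu().mean().backward()

# schedule the gradients with conv backward fusion enabled, then count real kernels
with Context(FUSE_CONV_BW=1):
  sched = create_schedule([img.grad.lazydata, w.grad.lazydata])
kernels = [si for si in sched if si.ast.op is UOps.SINK]
print(f"FUSE_CONV_BW=1 -> {len(kernels)} kernels")  # the fused test expects 7 here vs. 8 unfused
run_schedule(sched)
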
@@ -8,17 +8,17 @@ from typing import List, Optional, Union, cast
from tinygrad import nn, dtypes
from tinygrad.device import Device
from tinygrad.dtype import PtrDType
from tinygrad.dtype import DType, PtrDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.tensor import Tensor
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps
from tinygrad.ops import graph_rewrite
from tinygrad.helpers import AST_REWRITE, CI, DEBUG, FUSE_ARANGE, FUSE_CONV_BW, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap
from tinygrad.helpers import AST_REWRITE, CI, DEBUG, FUSE_ARANGE, flatten, getenv, SPLIT_REDUCEOP, unwrap
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import create_schedule, reduceop_fusor, st_fixup, ScheduleItem
from tinygrad.engine.schedule import create_schedule, reduceop_fusor, st_fixup
from tinygrad.engine.realize import CompiledRunner, run_schedule
from test.helpers import assert_equiv_uops, ast_const, is_dtype_supported, Context, timeit
from test.helpers import ast_const, is_dtype_supported, Context, timeit
from tinygrad.lazy import LazyBuffer, view_supported_devices
from extra.models.llama import precompute_freqs_cis
@@ -45,6 +45,28 @@ def check_schedule(t:Union[Tensor, List[Tensor], LazyBuffer], allowed:int, to_pr
    l.linearize()
  return sched

def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
  old_default_float, dtypes.default_float = dtypes.default_float, dtype
  dtypes.default_float = dtype
  Tensor.manual_seed(0)
  BS, CIN = 2, 3
  img = Tensor.randn(BS, CIN, 64, 64, requires_grad=True)
  w = Tensor.uniform(16, CIN, 3, 3, requires_grad=True)
  ret = Tensor.conv2d(img, w).relu().mean().backward()
  dtypes.default_float = old_default_float
  with Context(**kwargs): s = create_schedule([ret.lazydata, img.grad.lazydata, w.grad.lazydata])
  run_schedule(s.copy())
  cnt = len([si for si in s if si.ast.op is UOps.SINK])
  assert cnt == allowed, f"expected {allowed} kernels, got {cnt}"
  if getenv("CHECK", 1):
    import torch
    ref_img = torch.tensor(img.numpy(), requires_grad=True)
    ref_w = torch.tensor(w.numpy(), requires_grad=True)
    torch.nn.functional.conv2d(ref_img, ref_w).relu().mean().backward()
    assert ref_img.grad is not None and ref_w.grad is not None and img.grad is not None and w.grad is not None
    np.testing.assert_allclose(img.grad.numpy(), ref_img.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
    np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)

class TestSchedule(unittest.TestCase):
  def test_basic_binop_fusion(self):
    a = Tensor.empty(10)
@@ -232,53 +254,6 @@ class TestSchedule(unittest.TestCase):
        img_bn.backward()
        check_schedule(opt.schedule_step(), cnt)
  def test_fold_conv_relu_backward(self):
    c1 = nn.Conv2d(3,16,3, bias=False)
    c1.weight.requires_grad = True
    img = Tensor.rand(2,3,64,64, requires_grad=True)
    # run
    c1(img).relu().mean().backward()
    assert img.grad is not None and c1.weight.grad is not None
    run_schedule(check_schedule([img.grad, c1.weight.grad], 7, filter_sink=False))
    # compare
    import torch
    c1_torch = torch.nn.Conv2d(3,16,3, bias=False)
    c1_torch.weight.requires_grad = True
    c1_torch.weight = torch.nn.Parameter(torch.tensor(c1.weight.numpy(), dtype=torch.float32))
    img_torch = torch.tensor(img.numpy(), requires_grad=True)
    c1_torch(img_torch).relu().mean().backward()
    assert img_torch.grad is not None and c1_torch.weight.grad is not None
    np.testing.assert_allclose(c1.weight.grad.numpy(), c1_torch.weight.grad.numpy(), atol=5e-4, rtol=1e-5)
    np.testing.assert_allclose(img.grad.numpy(), img_torch.grad.numpy(), atol=5e-4, rtol=1e-5)
  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  def test_fold_conv_relu_backward_half(self):
    old_float = dtypes.default_float
    dtypes.default_float = dtypes.float16
    c1 = nn.Conv2d(3,16,3, bias=False)
    c1.weight.requires_grad = True
    img = Tensor.rand(2,3,64,64, requires_grad=True)
    # run
    c1(img).relu().mean().backward()
    assert img.grad is not None and c1.weight.grad is not None
    run_schedule(check_schedule([img.grad, c1.weight.grad], 7, filter_sink=False))
    dtypes.default_float = old_float
    # compare
    import torch
    c1_torch = torch.nn.Conv2d(3,16,3, bias=False, dtype=torch.half)
    c1_torch.weight.requires_grad = True
    c1_torch.weight = torch.nn.Parameter(torch.tensor(c1.weight.numpy(), dtype=torch.half))
    img_torch = torch.tensor(img.numpy(), requires_grad=True)
    c1_torch(img_torch).relu().mean().backward()
    assert img_torch.grad is not None and c1_torch.weight.grad is not None
    np.testing.assert_allclose(c1.weight.grad.numpy(), c1_torch.weight.grad.numpy(), atol=5e-4, rtol=1e-5)
    np.testing.assert_allclose(img.grad.numpy(), img_torch.grad.numpy(), atol=5e-4, rtol=1e-5)
  def test_fold_batchnorm_backward(self):
    with Context(FUSE_CONV_BW=1):
      with Tensor.train():
@@ -994,6 +969,18 @@ class TestSchedule(unittest.TestCase):
      c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
      check_schedule(opt.schedule_step(), 22)
  def test_sgd_4convs_fuse_conv_bw(self):
    with Tensor.train():
      img = Tensor.empty(2,3,64,64)
      c1 = nn.Conv2d(3,4,3,bias=False)
      c2 = nn.Conv2d(4,8,3,bias=False)
      c3 = nn.Conv2d(8,16,3,bias=False)
      c4 = nn.Conv2d(16,32,3,bias=False)
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
      opt.zero_grad()
      c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
      with Context(FUSE_CONV_BW=1): check_schedule(opt.schedule_step(), 19)
  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  def test_prefer_half_buffer(self):
    x = Tensor.ones(4).contiguous().realize()
@@ -1300,107 +1287,19 @@ class TestSchedule(unittest.TestCase):
    out = x.argmax(1)
    run_schedule(check_schedule(out, 3)) # TODO: push a reduceop through a reshape

class TestConvBW(unittest.TestCase):
  def check_schedule(self, xt, cnt:int, flops=None) -> List[ScheduleItem]:
    with Context(FUSE_CONV_BW=getenv("FUSE_CONV_BW", 1), NOOPT=flops is not None):
      s = create_schedule(flatten([r.lazydata.lbs for r in xt]))
      kernels = [si for si in s if si.ast.op is UOps.SINK]
      for si in kernels: verify_ast(si.ast)
      GlobalCounters.reset()
      run_schedule(s)
      if flops is not None: assert GlobalCounters.global_ops <= flops, f"too many ops {GlobalCounters.global_ops}"
      if FUSE_CONV_BW: self.assertEqual(len(kernels), cnt)
      return kernels
  def test_fold_conv_relu_backward(self):
    c1 = nn.Conv2d(3,16,3, bias=False)
    c1.weight.requires_grad = True
    img = Tensor.rand(2,3,64,64, requires_grad=True)
    # run
    c1(img).relu().mean().backward()
    assert img.grad is not None and c1.weight.grad is not None
    self.check_schedule([img.grad, c1.weight.grad], 4)
    # compare
    import torch
    c1_torch = torch.nn.Conv2d(3,16,3, bias=False)
    c1_torch.weight.requires_grad = True
    c1_torch.weight = torch.nn.Parameter(torch.tensor(c1.weight.numpy(), dtype=torch.float32))
    img_torch = torch.tensor(img.numpy(), requires_grad=True)
    c1_torch(img_torch).relu().mean().backward()
    assert img_torch.grad is not None and c1_torch.weight.grad is not None
    np.testing.assert_allclose(c1.weight.grad.numpy(), c1_torch.weight.grad.numpy(), atol=5e-4, rtol=1e-5)
    np.testing.assert_allclose(img.grad.numpy(), img_torch.grad.numpy(), atol=5e-4, rtol=1e-5)
  def test_fold_conv_relu_backward_ast_rewrite(self):
    # shared params
    Tensor.manual_seed(0)
    img_np = Tensor.randn(2,3,64,64).numpy()
    c1_w = Tensor.randn(16,3,3,3).numpy()
    # graph_rewrite
    GlobalCounters.reset()
    c1 = nn.Conv2d(3,16,3, bias=False)
    c1.weight = Tensor(c1_w, requires_grad=True)
    img = Tensor(img_np, requires_grad=True)
    c1(img).relu().mean().backward()
    assert img.grad is not None and c1.weight.grad is not None
    with Context(AST_REWRITE=1): compare_ast = self.check_schedule([img.grad, c1.weight.grad], 3)[1].ast
    rw_flops = GlobalCounters.global_ops
    # ref
    GlobalCounters.reset()
    c1_ref = nn.Conv2d(3,16,3, bias=False)
    c1_ref.weight = Tensor(c1_w, requires_grad=True)
    img_ref = Tensor(img_np, requires_grad=True)
    c1_ref(img_ref).relu().mean().backward()
    assert img_ref.grad is not None and c1_ref.weight.grad is not None
    with Context(AST_REWRITE=0): ref_ast = self.check_schedule([img_ref.grad, c1_ref.weight.grad], 3)[1].ast
    ref_flops = GlobalCounters.global_ops
    # correctness
    np.testing.assert_allclose(c1.weight.grad.numpy(), c1_ref.weight.grad.numpy(), atol=5e-4, rtol=1e-5)
    np.testing.assert_allclose(img.grad.numpy(), img_ref.grad.numpy(), atol=5e-4, rtol=1e-5)
    # flops, TODO: This will be fixed once SWIZZLE merges view strides.
    with self.assertRaises(AssertionError):
      self.assertEqual(rw_flops, ref_flops)
    assert_equiv_uops(compare_ast, ref_ast)
  def test_conv2d(self): _test_conv2d(8)
  def test_conv2d_fused(self): _test_conv2d(7, FUSE_CONV_BW=1)
  @unittest.expectedFailure
  def test_conv2d_fused_ast_rewrite(self): _test_conv2d(7, FUSE_CONV_BW=1, AST_REWRITE=1)
  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  def test_fold_conv_relu_backward_half(self):
    old_float = dtypes.default_float
    dtypes.default_float = dtypes.float16
    c1 = nn.Conv2d(3,16,3, bias=False)
    c1.weight.requires_grad = True
    # run
    img = Tensor.rand(2,3,64,64, requires_grad=True)
    c1(img).relu().mean().backward()
    dtypes.default_float = old_float
    self.check_schedule([img.grad, c1.weight.grad], 4)
    # compare
    import torch
    c1_torch = torch.nn.Conv2d(3,16,3, bias=False, dtype=torch.half)
    c1_torch.weight.requires_grad = True
    c1_torch.weight = torch.nn.Parameter(torch.tensor(c1.weight.numpy(), dtype=torch.half))
    img_torch = torch.tensor(img.numpy(), requires_grad=True)
    c1_torch(img_torch).relu().mean().backward()
    assert img_torch.grad is not None and c1_torch.weight.grad is not None
    np.testing.assert_allclose(c1.weight.grad.numpy(), c1_torch.weight.grad.numpy(), atol=5e-4, rtol=1e-5)
    np.testing.assert_allclose(img.grad.numpy(), img_torch.grad.numpy(), atol=5e-4, rtol=1e-5)
  def test_sgd_4convs_fuse(self):
    with Tensor.train():
      img = Tensor.empty(2,3,64,64)
      c1 = nn.Conv2d(3,4,3,bias=False)
      c2 = nn.Conv2d(4,8,3,bias=False)
      c3 = nn.Conv2d(8,16,3,bias=False)
      c4 = nn.Conv2d(16,32,3,bias=False)
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
      opt.zero_grad()
      c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
      self.check_schedule(opt.schedule_step(), 19)
  def test_conv2d_half(self): _test_conv2d(8, dtype=dtypes.half)
  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  @unittest.expectedFailure
  def test_conv2d_fused_half(self): _test_conv2d(7, dtype=dtypes.half)
  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  @unittest.expectedFailure
  def test_conv2d_fused_ast_rewrite_half(self): _test_conv2d(7, FUSE_CONV_BW=1, AST_REWRITE=1)

class TestIndexing(unittest.TestCase):
  def check_schedule(self, xt:Union[Tensor,List[Tensor]], cnt:int):