conv bw in one kernel with graph_rewrite (#6330)

* double reduce merger

* add test_fold_conv_relu_backward_ast_rewrite

* a correctness test to iterate on

* merge axes the other way around

* better
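
The core of the change is the "double reduce merger": a convolution's backward pass stacks one reduction on top of another, and scheduled naively that becomes two REDUCE_AXIS kernels. The rewrite merges them, which is valid because chaining two reductions with the same op equals a single reduction over the union of the axes. A minimal NumPy sketch of that identity (illustration only, not tinygrad code), using keepdims so axis indices stay stable the way the scheduler's rank-preserving reduces do:

import numpy as np

x = np.random.randn(2, 3, 4, 5)
inner = x.sum(axis=(3,), keepdims=True)      # first REDUCE_AXIS
outer = inner.sum(axis=(1,), keepdims=True)  # second REDUCE_AXIS stacked on top
merged = x.sum(axis=(1, 3), keepdims=True)   # the single merged reduce
assert np.allclose(outer, merged)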
qazal
2024-09-03 03:53:53 +08:00
committed by GitHub
parent bf645d62b3
commit 2f00bf0c78
2 changed files with 37 additions and 0 deletions


@@ -1330,6 +1330,36 @@ class TestConvBW(unittest.TestCase):
    np.testing.assert_allclose(c1.weight.grad.numpy(), c1_torch.weight.grad.numpy(), atol=5e-4, rtol=1e-5)
    np.testing.assert_allclose(img.grad.numpy(), img_torch.grad.numpy(), atol=5e-4, rtol=1e-5)

  def test_fold_conv_relu_backward_ast_rewrite(self):
    # shared params
    Tensor.manual_seed(0)
    img_np = Tensor.randn(2,3,64,64).numpy()
    c1_w = Tensor.randn(16,3,3,3).numpy()
    # graph_rewrite
    GlobalCounters.reset()
    c1 = nn.Conv2d(3,16,3, bias=False)
    c1.weight = Tensor(c1_w, requires_grad=True)
    img = Tensor(img_np, requires_grad=True)
    c1(img).relu().mean().backward()
    assert img.grad is not None and c1.weight.grad is not None
    with Context(AST_REWRITE=1): self.check_schedule([img.grad, c1.weight.grad], 3)
    rw_flops = GlobalCounters.global_ops
    # ref
    GlobalCounters.reset()
    c1_ref = nn.Conv2d(3,16,3, bias=False)
    c1_ref.weight = Tensor(c1_w, requires_grad=True)
    img_ref = Tensor(img_np, requires_grad=True)
    c1_ref(img_ref).relu().mean().backward()
    assert img_ref.grad is not None and c1_ref.weight.grad is not None
    with Context(AST_REWRITE=0): self.check_schedule([img_ref.grad, c1_ref.weight.grad], 3)
    ref_flops = GlobalCounters.global_ops
    # correctness
    np.testing.assert_allclose(c1.weight.grad.numpy(), c1_ref.weight.grad.numpy(), atol=5e-4, rtol=1e-5)
    np.testing.assert_allclose(img.grad.numpy(), img_ref.grad.numpy(), atol=5e-4, rtol=1e-5)
    # flops, TODO: This will be fixed once SWIZZLE merges view strides.
    with self.assertRaises(AssertionError):
      self.assertEqual(rw_flops, ref_flops)

  @unittest.expectedFailure
  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  def test_fold_conv_relu_backward_half(self):
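
The test runs the same backward pass twice and flips the scheduler between the new rewrite path and the reference path with tinygrad's Context manager, which temporarily overrides a ContextVar inside a with block. A minimal sketch of that mechanism, using a hypothetical MY_FLAG rather than the real AST_REWRITE variable:

from tinygrad.helpers import Context, ContextVar

MY_FLAG = ContextVar("MY_FLAG", 0)  # hypothetical flag for illustration, default 0
assert MY_FLAG.value == 0
with Context(MY_FLAG=1):
  assert MY_FLAG.value == 1  # overridden only inside the block
assert MY_FLAG.value == 0    # restored on exit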


@@ -180,6 +180,12 @@ def push_swizzle_through_reduce(swizzle:UOp, reduceop:UOp) -> UOp:
  new_input_st, new_axis = swizzle_reduceop(unwrap(get_output_st(rsrc, uop_sts)), swizzle.arg, reduceop.arg[1])
  return UOp(UOps.REDUCE_AXIS, reduceop.dtype, (st_fixup(rsrc, lambda _:new_input_st, uop_sts, {}),), (reduceop.arg[0], new_axis))

def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
  assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
  assert not any(x.op is UOps.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
  new_axis: Tuple[int, ...] = root.arg[1]+first_reduce.arg[1]
  return UOp(UOps.REDUCE_AXIS, first_reduce.dtype, first_reduce.src, (first_reduce.arg[0], new_axis))

def push_reduceop_shape(root:UOp) -> Optional[UOp]:
  reduceops = [x for x in root.parents if x.op is UOps.REDUCE_AXIS]
  if len(reduceops) == 0: return None
@@ -190,6 +196,7 @@ def push_reduceop_shape(root:UOp) -> Optional[UOp]:
reduceop_fusor = PatternMatcher([
  (UPat(UOps.SWIZZLE, src=(UPat(UOps.REDUCE_AXIS, name="reduceop"),), name="swizzle"), push_swizzle_through_reduce),
  (UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
  (UPat({UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.STORE}, name="root"), push_reduceop_shape),
])
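
For intuition, here is a self-contained toy (not tinygrad's PatternMatcher, which matches on UOps) of what the new merge_double_reduce rule does: whenever one reduce node feeds directly into another, fold them into a single node whose axis tuple is the concatenation of the two, mirroring new_axis = root.arg[1]+first_reduce.arg[1] above:

from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass(frozen=True)
class Reduce:
  src: object            # input node: a nested Reduce or a leaf
  axis: Tuple[int, ...]  # axes this node reduces over

def fold_double_reduce(root: Reduce) -> Optional[Reduce]:
  # match Reduce(Reduce(x)) and fold into one reduce over both axis tuples
  if isinstance(root.src, Reduce):
    return Reduce(root.src.src, root.axis + root.src.axis)
  return None  # no match, leave the node alone

def rewrite(node: Reduce) -> Reduce:
  # keep applying the rule until no double reduce is left
  while (merged := fold_double_reduce(node)) is not None:
    node = merged
  return node

tree = Reduce(Reduce("x", (2,)), (1,))  # two stacked reduces
assert rewrite(tree) == Reduce("x", (1, 2))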