conv bw in one kernel with graph_rewrite (#6330)
* double reduce merger
* add test_fold_conv_relu_backward_ast_rewrite
* a correctness test to iterate on
* merge axes the other way around
* better
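The rewrite hinges on a simple identity: REDUCE_AXIS keeps rank (reduced dims stay as size-1), so two nested reductions with the same alu collapse into one reduction over the concatenated axis tuple. A minimal numpy sketch of that identity (numpy standing in for the UOp graph):

import numpy as np

x = np.random.randn(4, 5, 6)

# two chained reductions, each keeping rank like REDUCE_AXIS does
two_step = x.sum(axis=2, keepdims=True).sum(axis=0, keepdims=True)

# the merged form: one reduction over the concatenated axis tuple
one_step = x.sum(axis=(0, 2), keepdims=True)

np.testing.assert_allclose(two_step, one_step)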
@@ -1330,6 +1330,36 @@ class TestConvBW(unittest.TestCase):
    np.testing.assert_allclose(c1.weight.grad.numpy(), c1_torch.weight.grad.numpy(), atol=5e-4, rtol=1e-5)
    np.testing.assert_allclose(img.grad.numpy(), img_torch.grad.numpy(), atol=5e-4, rtol=1e-5)

  def test_fold_conv_relu_backward_ast_rewrite(self):
    # shared params
    Tensor.manual_seed(0)
    img_np = Tensor.randn(2,3,64,64).numpy()
    c1_w = Tensor.randn(16,3,3,3).numpy()

    # graph_rewrite
    GlobalCounters.reset()
    c1 = nn.Conv2d(3,16,3, bias=False)
    c1.weight = Tensor(c1_w, requires_grad=True)
    img = Tensor(img_np, requires_grad=True)
    c1(img).relu().mean().backward()
    assert img.grad is not None and c1.weight.grad is not None
    with Context(AST_REWRITE=1): self.check_schedule([img.grad, c1.weight.grad], 3)
    rw_flops = GlobalCounters.global_ops

    # ref
    GlobalCounters.reset()
    c1_ref = nn.Conv2d(3,16,3, bias=False)
    c1_ref.weight = Tensor(c1_w, requires_grad=True)
    img_ref = Tensor(img_np, requires_grad=True)
    c1_ref(img_ref).relu().mean().backward()
    assert img_ref.grad is not None and c1_ref.weight.grad is not None
    with Context(AST_REWRITE=0): self.check_schedule([img_ref.grad, c1_ref.weight.grad], 3)
    ref_flops = GlobalCounters.global_ops

    # correctness
    np.testing.assert_allclose(c1.weight.grad.numpy(), c1_ref.weight.grad.numpy(), atol=5e-4, rtol=1e-5)
    np.testing.assert_allclose(img.grad.numpy(), img_ref.grad.numpy(), atol=5e-4, rtol=1e-5)

    # flops, TODO: This will be fixed once SWIZZLE merges view strides.
    with self.assertRaises(AssertionError):
      self.assertEqual(rw_flops, ref_flops)

  @unittest.expectedFailure
  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  def test_fold_conv_relu_backward_half(self):
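Why conv backward triggers the double reduce: the weight gradient reduces over the batch axis and both output spatial axes at once. A rough numpy sketch of that reduction (loss = mean of the conv output, relu omitted; names and shapes here are illustrative, not taken from the test):

import numpy as np

N, Cin, Cout, HW, K = 2, 3, 16, 8, 3
O = HW - K + 1                                   # output spatial size, stride 1, no padding
img = np.random.randn(N, Cin, HW, HW)
dout = np.full((N, Cout, O, O), 1.0 / (N * Cout * O * O))  # grad of mean()

# dW[o,c,i,j] = sum over n, y, x of dout[n,o,y,x] * img[n,c,y+i,x+j]:
# one reduction over three axes, which naive scheduling splits across
# kernels and this PR's rewrite keeps together
dW = np.zeros((Cout, Cin, K, K))
for i in range(K):
  for j in range(K):
    dW[:, :, i, j] = np.einsum('noyx,ncyx->oc', dout, img[:, :, i:i+O, j:j+O])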
@@ -180,6 +180,12 @@ def push_swizzle_through_reduce(swizzle:UOp, reduceop:UOp) -> UOp:
  new_input_st, new_axis = swizzle_reduceop(unwrap(get_output_st(rsrc, uop_sts)), swizzle.arg, reduceop.arg[1])
  return UOp(UOps.REDUCE_AXIS, reduceop.dtype, (st_fixup(rsrc, lambda _:new_input_st, uop_sts, {}),), (reduceop.arg[0], new_axis))

def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
  assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
  assert not any(x.op is UOps.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
  new_axis: Tuple[int, ...] = root.arg[1]+first_reduce.arg[1]
  return UOp(UOps.REDUCE_AXIS, first_reduce.dtype, first_reduce.src, (first_reduce.arg[0], new_axis))

def push_reduceop_shape(root:UOp) -> Optional[UOp]:
  reduceops = [x for x in root.parents if x.op is UOps.REDUCE_AXIS]
  if len(reduceops) == 0: return None
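The first assert in merge_double_reduce guards the merge: only reductions with the same alu compose by concatenating axes. A quick numpy check of both the legal and illegal case:

import numpy as np

x = np.random.randn(3, 4, 5)

# same alu: nested sums (or nested maxes) merge into one multi-axis reduce
assert np.allclose(x.sum(axis=1, keepdims=True).sum(axis=0, keepdims=True),
                   x.sum(axis=(0, 1), keepdims=True))
assert np.allclose(x.max(axis=1, keepdims=True).max(axis=0, keepdims=True),
                   x.max(axis=(0, 1), keepdims=True))

# mixed alus don't commute: max-of-sums differs from sum-of-maxes in general,
# so there is no single merged reduce that computes either order
assert not np.allclose(x.sum(axis=1, keepdims=True).max(axis=0, keepdims=True),
                       x.max(axis=0, keepdims=True).sum(axis=1, keepdims=True))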
@@ -190,6 +196,7 @@ def push_reduceop_shape(root:UOp) -> Optional[UOp]:

reduceop_fusor = PatternMatcher([
  (UPat(UOps.SWIZZLE, src=(UPat(UOps.REDUCE_AXIS, name="reduceop"),), name="swizzle"), push_swizzle_through_reduce),
  (UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
  (UPat({UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.STORE}, name="root"), push_reduceop_shape),
])
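Each reduceop_fusor entry pairs a structural pattern with a rewrite function, and matching rules are reapplied until the graph stops changing. A toy fixed-point rewriter over a tuple AST shows the mechanism (an illustration only, not tinygrad's PatternMatcher/graph_rewrite API):

# toy node: ("sum", axes, src) where src is another node or a leaf name
def merge_nested_sums(node):
  # mirrors merge_double_reduce: sum(a, sum(b, x)) -> sum(a+b, x)
  if isinstance(node, tuple) and node[0] == "sum" and isinstance(node[2], tuple) and node[2][0] == "sum":
    return ("sum", node[1] + node[2][1], node[2][2])
  return None

def rewrite(node, rules):
  # keep applying the first rule that fires until a fixed point is reached
  changed = True
  while changed:
    changed = False
    for rule in rules:
      if (new := rule(node)) is not None:
        node, changed = new, True
  return node

print(rewrite(("sum", (0,), ("sum", (2,), "x")), [merge_nested_sums]))  # ('sum', (0, 2), 'x')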