From a0764f0dc031a0f2067261e2ac29567fa1d30faa Mon Sep 17 00:00:00 2001
From: Priyank Patel
Date: Wed, 26 Feb 2025 19:32:25 -0800
Subject: [PATCH] (bounty) Make mnist training run with torch backend (#9233)

* yml changes

* torch backend remove meta decomps and add test

* torch backend bump timeout for tests

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
---
 .github/workflows/test.yml     |   6 +-
 extra/torch_backend/backend.py | 133 ++++++++++++++++++---------------
 extra/torch_backend/test.py    |   5 ++
 3 files changed, 79 insertions(+), 65 deletions(-)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 6253823e67..ea0870666e 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -146,7 +146,7 @@ jobs:
   torchbackend:
     name: Torch Backend Tests
     runs-on: ubuntu-latest
-    timeout-minutes: 10
+    timeout-minutes: 20
     steps:
     - name: Checkout Code
       uses: actions/checkout@v4
@@ -169,10 +169,10 @@
       run: PYTHONPATH=. python3 extra/torch_backend/test.py
     - name: Test one op in torch tests
       run: PYTHONPATH=. DEBUG=2 python3 extra/torch_backend/torch_tests.py TestTinyBackendPRIVATEUSE1.test_unary_log_tiny_float32
+    - name: Test beautiful_mnist in torch with TINY_BACKEND
+      run: PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py
     - name: Test Ops with TINY_BACKEND (expect failure)
       run: PYTHONPATH=. LLVM=1 LLVMOPT=0 TINY_BACKEND=1 python3 -m pytest -n auto test/test_ops.py || true
-    - name: Test beautiful_mnist in torch with TINY_BACKEND (expect failure)
-      run: PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py || true
     - name: Test some torch tests (expect failure)
       run: PYTHONPATH=. python3 -m pytest extra/torch_backend/torch_tests.py -v --tb=no || true

diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py
index 99990132bd..f3e14d03e2 100644
--- a/extra/torch_backend/backend.py
+++ b/extra/torch_backend/backend.py
@@ -73,6 +73,14 @@ def max_pool2d_with_indices(self:Tensor, kernel_size, stride=None, padding=0, di
   # TODO: this is wrong
   return (wrap(ret), wrap(Tensor.zeros_like(ret, dtype=dtypes.int64)))

+@torch.library.impl("aten::max_pool2d_with_indices_backward", "privateuseone")
+def max_pool2d_with_indices_backward(grad_out:Tensor, self:Tensor, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, indices=None):
+  grad_out, self, indices = unwrap(grad_out), unwrap(self), unwrap(indices)
+  # TODO: utilize input indices once they are correct
+  self = self.detach().clone().requires_grad_(True)
+  Tensor.max_pool2d(self, kernel_size, stride, dilation, padding, ceil_mode).backward(grad_out)
+  return wrap(self.grad)
+
 @torch.library.impl("aten::arange", "privateuseone")
 def arange(end, dtype=None, device=None, pin_memory=None):
   return wrap(Tensor.arange(0, end, dtype=_from_torch_dtype(dtype or torch.get_default_dtype())))
@@ -111,6 +119,18 @@ def convolution_overrideable(input, weight, bias, stride, padding, dilation, tra
     groups=groups, stride=stride, dilation=dilation, padding=padding))
   #raise NotImplementedError("need convolution")

+@torch.library.impl("aten::convolution_backward_overrideable", "privateuseone")
+def convolution_backward_overrideable(grad_out, input, weight, stride, padding, dilation, transposed, output_padding, groups, output_mask):
+  if TORCH_DEBUG >= 1:
+    print(f"convolution_backward {input.shape=} {weight.shape=} {stride=} {padding=} {dilation=} {transposed=} {output_padding=} {groups=}")
+  grad_out, input, weight = unwrap(grad_out), unwrap(input), unwrap(weight)
+  input = input.detach().clone().requires_grad_(output_mask[0])
+  weight = weight.detach().clone().requires_grad_(output_mask[1])
+  bias = Tensor.zeros(weight.shape[0]).requires_grad_(output_mask[2])
+  Tensor.conv2d(input, weight, bias, groups=groups, stride=stride, dilation=dilation, padding=padding).backward(grad_out)
+  return tuple(wrap(x.grad) if x.grad is not None else None for x in [input, weight, bias])
+  #raise NotImplementedError("need convolution")
+
 @torch.library.impl("aten::_copy_from", "privateuseone")
 def _copy_from(src, dest, non_blocking=False):
   if str(src.device) == "tiny" and str(dest.device) == "tiny":
@@ -134,68 +154,57 @@ def cat_out(tensors, dim=0, out=None):
 # register some decompositions
 from torch._decomp import get_decompositions
 aten = torch.ops.aten
-decomps = {
-  "post_autograd": [
-    aten.native_batch_norm, aten.native_batch_norm_backward,
-    aten.native_layer_norm_backward,
-    aten.addmm,
-    aten.addcmul,
-    aten.addcdiv,
-    aten._log_softmax_backward_data,
-    aten.threshold_backward,
-    aten.softplus_backward,
-    aten.elu, # elu has a scale + input_scale param
-    aten.softplus,
-    aten.threshold,
-    aten.nll_loss_forward,
-    aten.nll_loss_backward,
-    # AttributeError: 'int' object has no attribute '_broadcasted'
-    aten.sigmoid_backward,
-    aten.tanh_backward,
-    aten.sinc,
-    aten._prelu_kernel,
-    aten.softshrink,
-    aten.hardshrink,
-    aten.log_sigmoid_forward,
-    aten.isneginf,
-    aten.isposinf,
-    aten.nan_to_num,
-    aten.logit,
-    aten.rsub,
-    aten.index_select,
-    aten.native_dropout, aten.native_dropout_backward,
-    aten._softmax_backward_data, aten.embedding_dense_backward,
-    aten.linalg_vector_norm,
-    aten.unfold,
-    # activations
-    aten.hardswish, aten.hardswish_backward,
-    aten.hardtanh, aten.hardtanh_backward,
-    aten.gelu, aten.gelu_backward,
-    # NOTE: this uses index
-    #aten.reflection_pad2d,
-    # NOTE: many of these don't work or cause infinite loops
-    #aten.var_mean,
-    #aten.var,
-    #aten.rsqrt,
-    #aten.max_pool2d_with_indices,
-    # NOTE: these are prims
-    #aten.digamma,
-    #aten.erfinv,
-    #aten.lgamma,
-    # this needs copy_strided
-    #aten.lerp,
-  ],
-  "meta": [
-    aten.max_pool2d_with_indices_backward,
-    aten.convolution_backward,
-  ],
-}
-
-for dctype,lst in decomps.items():
-  for k,v in get_decompositions(lst, type=dctype).items():
-    key = str(k._schema).split("(")[0]
-    if TORCH_DEBUG >= 2: print("register decomp for", k)
-    torch.library.impl(key, "privateuseone")(v)
+decomps = [
+  aten.native_batch_norm, aten.native_batch_norm_backward,
+  aten.native_layer_norm_backward,
+  aten.addmm,
+  aten.addcmul,
+  aten.addcdiv,
+  aten._log_softmax_backward_data,
+  aten.threshold_backward,
+  aten.softplus_backward,
+  aten.elu, # elu has a scale + input_scale param
+  aten.softplus,
+  aten.threshold,
+  aten.nll_loss_forward,
+  aten.nll_loss_backward,
+  # AttributeError: 'int' object has no attribute '_broadcasted'
+  aten.sigmoid_backward,
+  aten.tanh_backward,
+  aten.sinc,
+  aten._prelu_kernel,
+  aten.softshrink,
+  aten.hardshrink,
+  aten.log_sigmoid_forward,
+  aten.isneginf,
+  aten.isposinf,
+  aten.nan_to_num,
+  aten.logit,
+  aten.rsub,
+  aten.index_select,
+  aten.native_dropout, aten.native_dropout_backward,
+  aten._softmax_backward_data, aten.embedding_dense_backward,
+  aten.linalg_vector_norm,
+  # activations
+  aten.hardswish, aten.hardswish_backward,
+  aten.hardtanh, aten.hardtanh_backward,
+  aten.gelu, aten.gelu_backward,
+  # NOTE: many of these don't work or cause infinite loops
+  #aten.var_mean,
+  #aten.var,
+  #aten.rsqrt,
+  #aten.max_pool2d_with_indices,
+  # NOTE: these are prims
+  #aten.digamma,
+  #aten.erfinv,
+  #aten.lgamma,
+  # this needs copy_strided
+  #aten.lerp,
+]
+for k,v in get_decompositions(decomps).items():
+  key = str(k._schema).split("(")[0]
+  if TORCH_DEBUG >= 2: print("register decomp for", k)
+  torch.library.impl(key, "privateuseone")(v)

 # NOTE: we should only implement the "out" form, it should be 0 overhead
 # TODO: due to issue with empty / is_realized, it is slow to use assign so we use replace
diff --git a/extra/torch_backend/test.py b/extra/torch_backend/test.py
index 88fa473741..e75e3ad0e1 100644
--- a/extra/torch_backend/test.py
+++ b/extra/torch_backend/test.py
@@ -77,6 +77,11 @@ class TestTorchBackend(unittest.TestCase):
     c = a == b
     print(c.cpu().numpy())

+  def test_maxpool2d_backward(self):
+    x = torch.arange(3*3, device=device).reshape(1, 1, 3, 3).requires_grad_(True)
+    torch.nn.functional.max_pool2d(x, kernel_size=2, stride=1).sum().backward()
+    np.testing.assert_equal(x.grad.squeeze().cpu().numpy(), [[0, 0, 0], [0, 1, 1], [0, 1, 1]])
+
   @unittest.skip("meh")
   def test_str(self):
     a = torch.ones(4, device=device)
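
Note on the approach (not part of the patch itself): both new backward kernels use the same trick. They replay the forward op on unwrapped tinygrad Tensors with autograd enabled, call .backward(grad_out), and wrap the resulting .grad back into torch tensors. A minimal standalone sketch of that pattern, using only the tinygrad Tensor calls that appear in the diff; the shapes are made up for illustration:

    # sketch of the replay-the-forward trick behind convolution_backward_overrideable
    from tinygrad import Tensor

    grad_out = Tensor.ones(1, 8, 28, 28)                 # upstream gradient (hypothetical shape)
    x = Tensor.randn(1, 1, 28, 28).requires_grad_(True)  # stands in for the unwrapped input
    w = Tensor.randn(8, 1, 3, 3).requires_grad_(True)    # stands in for the unwrapped weight
    b = Tensor.zeros(8).requires_grad_(True)             # the kernel always replays with a zero bias
    # re-run the forward with autograd on, then push the incoming gradient through it
    x.conv2d(w, b, groups=1, stride=1, dilation=1, padding=1).backward(grad_out)
    print(x.grad.shape, w.grad.shape, b.grad.shape)      # (1, 1, 28, 28) (8, 1, 3, 3) (8,)

This recomputes the forward just to get gradients, so it favors simplicity over speed; the TODO in max_pool2d_with_indices_backward exists because the forward currently returns placeholder indices, which makes the replay the workable fallback.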
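
For an end-to-end check through the torch side, in the spirit of the new test_maxpool2d_backward and the beautiful_mnist CI step, here is a hedged usage sketch. The registration import is an assumption; the canonical way to load the backend is whatever extra/torch_backend/test.py does, and it needs PYTHONPATH pointing at the repo root, as in the workflow:

    # assumed usage sketch: drive the new backward kernels through the "tiny" device
    import torch
    import extra.torch_backend.backend  # noqa: F401  # assumed import; registers the privateuseone "tiny" device

    x = torch.randn(2, 1, 28, 28).to("tiny").requires_grad_(True)
    w = torch.randn(8, 1, 3, 3).to("tiny").requires_grad_(True)
    y = torch.nn.functional.conv2d(x, w, padding=1)       # forward: convolution_overrideable (pre-existing)
    z = torch.nn.functional.max_pool2d(y, kernel_size=2)  # forward: max_pool2d_with_indices (pre-existing)
    z.sum().backward()                                    # backward: the two kernels added by this patch
    print(x.grad.shape, w.grad.shape)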