diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 1885499ab0..4f980faf94 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -161,6 +161,10 @@ jobs:
       run: |
         sudo apt update || true
         sudo apt install -y --no-install-recommends ninja-build
+    - name: Lint with ruff
+      run: |
+        pip3 install --upgrade --force-reinstall ruff
+        python3 -m ruff check extra/torch_backend/backend.py
     - name: Test one op
       run: PYTHONPATH=. FORWARD_ONLY=1 TINY_BACKEND=1 python3 test/test_ops.py TestOps.test_add
     - name: Test ResNet-18
diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py
index bdf3972500..f210f56b51 100644
--- a/extra/torch_backend/backend.py
+++ b/extra/torch_backend/backend.py
@@ -1,3 +1,7 @@
+# ruff: noqa: E501, A001, A002, A006
+# A001 Variable `input` is shadowing a Python builtin
+# A002 Function argument `input` is shadowing a Python builtin
+# A006 Lambda argument `input` is shadowing a Python builtin
 from tinygrad import Tensor, dtypes
 from tinygrad.helpers import getenv, prod
 import torch.lib
@@ -280,7 +284,6 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_
   "aten.scatter.value_out": Tensor.scatter,
   "aten.where.self_out": Tensor.where,
   "aten.prod.int_out": Tensor.prod,
-  "aten.div.out_mode": Tensor.div,
   "aten.scatter_add.out": functools.partial(Tensor.scatter_reduce, reduce='sum'),
 }}
 
@@ -340,12 +343,11 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
   # these don't work in out form, they have size 0
   "aten.abs": Tensor.abs,
   "aten.logical_not": Tensor.logical_not,
-  "aten.masked_fill_.Scalar": lambda self,mask,value: self.assign(mask.where(self, value)),
   "aten.multinomial": Tensor.multinomial,
   "aten.pad": Tensor.pad,
   "aten.reflection_pad2d": functools.partial(Tensor.pad, mode="reflect"),
+  "aten.masked_fill_.Scalar": lambda self,mask,value: self.assign(mask.where(self, value)),
   "aten.masked_fill.Scalar": Tensor.masked_fill,
-  "aten.masked_fill_.Scalar": Tensor.masked_fill,
   "aten.masked_fill.Tensor": Tensor.masked_fill,
   "aten.all": Tensor.all,
   "aten.sgn": Tensor.sign,
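
Why the dropped duplicate key matters (aside, not part of the diff): in a Python dict literal a repeated key keeps only the last binding, so the later "aten.masked_fill_.Scalar": Tensor.masked_fill entry was silently overriding the in-place lambda registered earlier in the same literal. A minimal standalone sketch, using placeholder string values rather than the real tinygrad callables:

# Standalone sketch: a repeated key in a dict literal keeps only the last
# binding. The string values here are placeholders, not the real op handlers.
ops = {
    "aten.masked_fill_.Scalar": "in-place lambda",     # registered first
    "aten.masked_fill_.Scalar": "Tensor.masked_fill",  # last binding silently wins
}
assert ops["aten.masked_fill_.Scalar"] == "Tensor.masked_fill"
assert len(ops) == 1

The "# ruff: noqa: E501, A001, A002, A006" header in the same change is ruff's file-level suppression for those rules (long lines and the intentional builtin-shadowing `input` names documented in the comments), which lets the new CI lint step check backend.py without failing on them.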