From 0d4ba7dd87b7ca157602dcd09acdfa16427b48c0 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 4 Mar 2025 00:15:29 +0800 Subject: [PATCH] import tinygrad.frontend.torch (#9337) * import tinygrad.frontend.torch * type ignore --- examples/other_mnist/beautiful_mnist_torch.py | 7 ++----- test/test_ops.py | 2 +- tinygrad/frontend/__init__.py | 0 tinygrad/frontend/torch.py | 5 +++++ 4 files changed, 8 insertions(+), 6 deletions(-) create mode 100644 tinygrad/frontend/__init__.py create mode 100644 tinygrad/frontend/torch.py diff --git a/examples/other_mnist/beautiful_mnist_torch.py b/examples/other_mnist/beautiful_mnist_torch.py index 1016fba9d8..e1cef16d9c 100644 --- a/examples/other_mnist/beautiful_mnist_torch.py +++ b/examples/other_mnist/beautiful_mnist_torch.py @@ -26,11 +26,8 @@ class Model(nn.Module): return self.lin(torch.flatten(x, 1)) if __name__ == "__main__": - if getenv("TINY_BACKEND2"): - import extra.torch_backend.backend2 - device = torch.device("cpu") - elif getenv("TINY_BACKEND"): - import extra.torch_backend.backend + if getenv("TINY_BACKEND"): + import tinygrad.frontend.torch device = torch.device("tiny") else: device = torch.device("mps") diff --git a/test/test_ops.py b/test/test_ops.py index 247e49cb89..394ef5604d 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -9,7 +9,7 @@ from tinygrad.tensor import _to_np_dtype from tinygrad.device import is_dtype_supported if getenv("TINY_BACKEND"): - import extra.torch_backend.backend # noqa: F401 # pylint: disable=unused-import + import tinygrad.frontend.torch # noqa: F401 # pylint: disable=unused-import torch.set_default_device("tiny") if CI: diff --git a/tinygrad/frontend/__init__.py b/tinygrad/frontend/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tinygrad/frontend/torch.py b/tinygrad/frontend/torch.py new file mode 100644 index 0000000000..079256f674 --- /dev/null +++ b/tinygrad/frontend/torch.py @@ -0,0 +1,5 @@ +# type: ignore +import sys, pathlib +sys.path.append(pathlib.Path(__file__).parent.parent.as_posix()) +try: import extra.torch_backend.backend # noqa: F401 # pylint: disable=unused-import +except ImportError as e: raise ImportError("torch frontend not in release\nTo fix, install tinygrad from a git checkout with pip install -e .") from e \ No newline at end of file