mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
import tinygrad.frontend.torch (#9337)
* import tinygrad.frontend.torch * type ignore
This commit is contained in:
@@ -26,11 +26,8 @@ class Model(nn.Module):
|
||||
return self.lin(torch.flatten(x, 1))
|
||||
|
||||
if __name__ == "__main__":
|
||||
if getenv("TINY_BACKEND2"):
|
||||
import extra.torch_backend.backend2
|
||||
device = torch.device("cpu")
|
||||
elif getenv("TINY_BACKEND"):
|
||||
import extra.torch_backend.backend
|
||||
if getenv("TINY_BACKEND"):
|
||||
import tinygrad.frontend.torch
|
||||
device = torch.device("tiny")
|
||||
else:
|
||||
device = torch.device("mps")
|
||||
|
||||
@@ -9,7 +9,7 @@ from tinygrad.tensor import _to_np_dtype
|
||||
from tinygrad.device import is_dtype_supported
|
||||
|
||||
if getenv("TINY_BACKEND"):
|
||||
import extra.torch_backend.backend # noqa: F401 # pylint: disable=unused-import
|
||||
import tinygrad.frontend.torch # noqa: F401 # pylint: disable=unused-import
|
||||
torch.set_default_device("tiny")
|
||||
|
||||
if CI:
|
||||
|
||||
0
tinygrad/frontend/__init__.py
Normal file
0
tinygrad/frontend/__init__.py
Normal file
5
tinygrad/frontend/torch.py
Normal file
5
tinygrad/frontend/torch.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# type: ignore
|
||||
import sys, pathlib
|
||||
sys.path.append(pathlib.Path(__file__).parent.parent.as_posix())
|
||||
try: import extra.torch_backend.backend # noqa: F401 # pylint: disable=unused-import
|
||||
except ImportError as e: raise ImportError("torch frontend not in release\nTo fix, install tinygrad from a git checkout with pip install -e .") from e
|
||||
Reference in New Issue
Block a user