From ac8daaeea540f0de8d0af33b44245d7478fee7ab Mon Sep 17 00:00:00 2001 From: George Hotz Date: Wed, 22 Feb 2023 17:28:04 -0800 Subject: [PATCH] refactor device, credit martinloretzzz --- tinygrad/lazy.py | 23 +++++++++++++---------- tinygrad/llops/ops_triton.py | 1 - 2 files changed, 13 insertions(+), 11 deletions(-) delete mode 120000 tinygrad/llops/ops_triton.py diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index f162333327..0a9fcc7ae2 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import Optional, Tuple, Union, List, Dict, Any, ClassVar, Type -import sys, weakref, os, importlib, inspect +import sys, weakref, importlib, inspect from weakref import WeakValueDictionary from tinygrad.helpers import ConvArgs, prod, DEBUG from tinygrad.shape import ShapeTracker @@ -16,18 +16,21 @@ NOCONV = getenv("NOCONV", 0) IMAGE = getenv("IMAGE", 0) LAZY = getenv("LAZY", 1) +def get_buffer(name, base='tinygrad.llops'): + try: + return (name.upper(), [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{base}.ops_{name}'), inspect.isclass) if (cname.lower() == name + "buffer")][0]) + except ImportError as e: # NOTE: this can't be put on one line due to mypy issue + print(name, "backend not available", e, file=sys.stderr) + class _Device: def __init__(self) -> None: + self._buffers : Dict[str, Type[DeviceBuffer]] = {x[0]:x[1] for x in [ + get_buffer('cpu'), get_buffer('gpu'), get_buffer('llvm'), get_buffer('torch'), + get_buffer('triton', 'accel.triton')] if x is not None} self.DEFAULT : str = "CPU" - self._buffers : Dict[str, Type[DeviceBuffer]] = {} - for op in [os.path.splitext(x)[0] for x in sorted(os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "llops"))) if x.startswith("ops_")]: - name = op[len("ops_"):].upper() - if os.environ.get(name, 0) == "1": self.DEFAULT = name # note: DEFAULT can be a Device that can't be imported. better than silent use of a different device - try: - self._buffers[name] = [cls for cname, cls in inspect.getmembers(importlib.import_module('tinygrad.llops.'+op), inspect.isclass) if (cname.upper() == name + "BUFFER")][0] - self.__setattr__(name, name) - except ImportError as e: # NOTE: this can't be put on one line due to mypy issue - print(op, "not available", e, file=sys.stderr) + for name in self._buffers: + if getenv(name) == 1: self.DEFAULT = name # note: DEFAULT can be a Device that can't be imported. better than silent use of a different device + self.__setattr__(name, name) Device = _Device() # TODO: movement ops that only change shape are really nops. treat them as such diff --git a/tinygrad/llops/ops_triton.py b/tinygrad/llops/ops_triton.py deleted file mode 120000 index e0d70a618a..0000000000 --- a/tinygrad/llops/ops_triton.py +++ /dev/null @@ -1 +0,0 @@ -../../accel/triton/ops_triton.py \ No newline at end of file