refactor device, credit martinloretzzz

This commit is contained in:
George Hotz
2023-02-22 17:28:04 -08:00
parent a3ddc1d484
commit ac8daaeea5
2 changed files with 13 additions and 11 deletions

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Optional, Tuple, Union, List, Dict, Any, ClassVar, Type
import sys, weakref, os, importlib, inspect
import sys, weakref, importlib, inspect
from weakref import WeakValueDictionary
from tinygrad.helpers import ConvArgs, prod, DEBUG
from tinygrad.shape import ShapeTracker
@@ -16,18 +16,21 @@ NOCONV = getenv("NOCONV", 0)
IMAGE = getenv("IMAGE", 0)
LAZY = getenv("LAZY", 1)
def get_buffer(name, base='tinygrad.llops'):
try:
return (name.upper(), [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{base}.ops_{name}'), inspect.isclass) if (cname.lower() == name + "buffer")][0])
except ImportError as e: # NOTE: this can't be put on one line due to mypy issue
print(name, "backend not available", e, file=sys.stderr)
class _Device:
def __init__(self) -> None:
self._buffers : Dict[str, Type[DeviceBuffer]] = {x[0]:x[1] for x in [
get_buffer('cpu'), get_buffer('gpu'), get_buffer('llvm'), get_buffer('torch'),
get_buffer('triton', 'accel.triton')] if x is not None}
self.DEFAULT : str = "CPU"
self._buffers : Dict[str, Type[DeviceBuffer]] = {}
for op in [os.path.splitext(x)[0] for x in sorted(os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "llops"))) if x.startswith("ops_")]:
name = op[len("ops_"):].upper()
if os.environ.get(name, 0) == "1": self.DEFAULT = name # note: DEFAULT can be a Device that can't be imported. better than silent use of a different device
try:
self._buffers[name] = [cls for cname, cls in inspect.getmembers(importlib.import_module('tinygrad.llops.'+op), inspect.isclass) if (cname.upper() == name + "BUFFER")][0]
self.__setattr__(name, name)
except ImportError as e: # NOTE: this can't be put on one line due to mypy issue
print(op, "not available", e, file=sys.stderr)
for name in self._buffers:
if getenv(name) == 1: self.DEFAULT = name # note: DEFAULT can be a Device that can't be imported. better than silent use of a different device
self.__setattr__(name, name)
Device = _Device()
# TODO: movement ops that only change shape are really nops. treat them as such

View File

@@ -1 +0,0 @@
../../accel/triton/ops_triton.py