mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 06:18:01 -05:00
refactor device, credit martinloretzzz
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from typing import Optional, Tuple, Union, List, Dict, Any, ClassVar, Type
|
||||
import sys, weakref, os, importlib, inspect
|
||||
import sys, weakref, importlib, inspect
|
||||
from weakref import WeakValueDictionary
|
||||
from tinygrad.helpers import ConvArgs, prod, DEBUG
|
||||
from tinygrad.shape import ShapeTracker
|
||||
@@ -16,18 +16,21 @@ NOCONV = getenv("NOCONV", 0)
|
||||
IMAGE = getenv("IMAGE", 0)
|
||||
LAZY = getenv("LAZY", 1)
|
||||
|
||||
def get_buffer(name, base='tinygrad.llops'):
|
||||
try:
|
||||
return (name.upper(), [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{base}.ops_{name}'), inspect.isclass) if (cname.lower() == name + "buffer")][0])
|
||||
except ImportError as e: # NOTE: this can't be put on one line due to mypy issue
|
||||
print(name, "backend not available", e, file=sys.stderr)
|
||||
|
||||
class _Device:
|
||||
def __init__(self) -> None:
|
||||
self._buffers : Dict[str, Type[DeviceBuffer]] = {x[0]:x[1] for x in [
|
||||
get_buffer('cpu'), get_buffer('gpu'), get_buffer('llvm'), get_buffer('torch'),
|
||||
get_buffer('triton', 'accel.triton')] if x is not None}
|
||||
self.DEFAULT : str = "CPU"
|
||||
self._buffers : Dict[str, Type[DeviceBuffer]] = {}
|
||||
for op in [os.path.splitext(x)[0] for x in sorted(os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "llops"))) if x.startswith("ops_")]:
|
||||
name = op[len("ops_"):].upper()
|
||||
if os.environ.get(name, 0) == "1": self.DEFAULT = name # note: DEFAULT can be a Device that can't be imported. better than silent use of a different device
|
||||
try:
|
||||
self._buffers[name] = [cls for cname, cls in inspect.getmembers(importlib.import_module('tinygrad.llops.'+op), inspect.isclass) if (cname.upper() == name + "BUFFER")][0]
|
||||
self.__setattr__(name, name)
|
||||
except ImportError as e: # NOTE: this can't be put on one line due to mypy issue
|
||||
print(op, "not available", e, file=sys.stderr)
|
||||
for name in self._buffers:
|
||||
if getenv(name) == 1: self.DEFAULT = name # note: DEFAULT can be a Device that can't be imported. better than silent use of a different device
|
||||
self.__setattr__(name, name)
|
||||
Device = _Device()
|
||||
|
||||
# TODO: movement ops that only change shape are really nops. treat them as such
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
../../accel/triton/ops_triton.py
|
||||
Reference in New Issue
Block a user