move Device back into lazy

This commit is contained in:
George Hotz
2023-02-11 11:26:53 -08:00
parent 9152bb5b4a
commit b9eae94ae9
4 changed files with 21 additions and 26 deletions

View File

@@ -128,8 +128,8 @@ class TritonBuffer(ExplicitExecAST):
return data
@classmethod
def exec_ast(cls, ast:LazyOp):
k = TritonASTKernel(ast)
def exec_ast(cls, ast:LazyOp, output_buffer:Optional[TritonBuffer]=None):
k = TritonASTKernel(ast, output_buffer)
k.codegen()(*k.bufs)
return k.ret

View File

@@ -1,16 +0,0 @@
def get_available_llops():
  """Scan the tinygrad/llops directory for ops_* modules and collect their buffer classes.

  Returns a (buffers, default) pair: a dict mapping backend name (e.g. "GPU")
  to its *Buffer class, and the name of the default backend ("CPU" unless an
  environment variable such as GPU=1 selects another one).
  """
  import os, importlib, inspect
  llops_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llops")
  buffers, default = {}, "CPU"
  for fn in sorted(os.listdir(llops_dir)):
    if not fn.startswith("ops_"): continue
    op = os.path.splitext(fn)[0]
    name = op[len("ops_"):].upper()
    # an environment variable like GPU=1 promotes that backend to the default
    if os.environ.get(name, 0) == "1": default = name
    try:
      module = importlib.import_module('tinygrad.llops.'+op)
      buffers[name] = [cls for cname, cls in inspect.getmembers(module, inspect.isclass) if (cname.upper() == name + "BUFFER")][0]
    except ImportError as e: # NOTE: this can't be put on one line due to mypy issue
      print(op, "not available", e)
  return buffers, default
class Device:
  """Registry of available backends: DEFAULT is the default backend name and
  each discovered backend name is exposed as a class attribute equal to itself
  (e.g. Device.CPU == "CPU")."""
  _buffers, DEFAULT = get_available_llops()
  # mirror each backend name into the class namespace so Device.NAME == "NAME"
  for name in _buffers:
    vars()[name] = name

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Optional, Tuple, Union, List, Dict, Any, ClassVar
import sys, weakref
from typing import Optional, Tuple, Union, List, Dict, Any, ClassVar, Type
import sys, weakref, os, importlib, inspect
from weakref import WeakValueDictionary
from tinygrad.helpers import ConvArgs, prod
from tinygrad.shape import ShapeTracker
@@ -16,8 +16,19 @@ NOCONV = getenv("NOCONV", 0)
IMAGE = getenv("IMAGE", 0)
LAZY = getenv("LAZY", 1)
# late import of Device
from tinygrad.device import Device
class _Device:
  """Backend registry built at construction time by scanning tinygrad/llops.

  After __init__, DEFAULT holds the default backend name, _buffers maps
  backend name -> buffer class, and every successfully imported backend name
  is also set as an instance attribute equal to itself (Device.GPU == "GPU").
  """
  def __init__(self) -> None:
    self.DEFAULT : str = "CPU"
    self._buffers : Dict[str, Type[DeviceBuffer]] = {}
    llops_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llops")
    for fn in sorted(os.listdir(llops_dir)):
      if not fn.startswith("ops_"): continue
      op = os.path.splitext(fn)[0]
      name = op[len("ops_"):].upper()
      # note: DEFAULT can be a Device that can't be imported. better than silent use of a different device
      if os.environ.get(name, 0) == "1": self.DEFAULT = name
      try:
        module = importlib.import_module('tinygrad.llops.'+op)
        self._buffers[name] = [cls for cname, cls in inspect.getmembers(module, inspect.isclass) if (cname.upper() == name + "BUFFER")][0]
        self.__setattr__(name, name)
      except ImportError as e: # NOTE: this can't be put on one line due to mypy issue
        print(op, "not available", e)

# module-level singleton instance replacing the old tinygrad.device module
Device = _Device()
# TODO: movement ops that only change shape are really nops. treat them as such
REMOVE_MOVEMENT_NOPS, MERGE_UNARY_OPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS = OPT>=1, OPT>=1, OPT>=1, OPT>=1

View File

@@ -42,7 +42,7 @@ class DeviceBuffer:
def fromCPU(x:np.ndarray) -> DeviceBuffer: raise NotImplementedError("must be implemented")
def toCPU(self:DeviceBuffer) -> np.ndarray: raise NotImplementedError("must be implemented")
@classmethod
def exec_ast(cls, ast:LazyOp): raise NotImplementedError("must be implemented")
def exec_ast(cls, ast:LazyOp, output_buffer=None): raise NotImplementedError("must be implemented")
# this is a quick "buffer" class for flop tracking
class GenericShape(NamedTuple):
@@ -64,8 +64,8 @@ class GenericExecAST(DeviceBuffer): # pylint: disable=abstract-method
def contiguous(self): return type(self).exec_ast(LazyOp(op=UnaryOps.NOOP, src=(self,)))
def movement_op(self, op:MovementOps, arg=None): return type(self)(self.fxn_for_op[op](self.buf, arg)) if op in self.fxn_for_op else type(self)(getattr(self.buf, op.name.lower())(arg))
@classmethod
def exec_ast(cls, ast:LazyOp, output_buffer:Optional[GenericExecAST]=None, preprocess=lambda x: x):
srcs = [cls.exec_ast(x, preprocess=preprocess) if isinstance(x, LazyOp) else preprocess(x) for x in ast.src]
def exec_ast(cls, ast:LazyOp, output_buffer:Optional[GenericExecAST]=None):
srcs = [cls.exec_ast(x) if isinstance(x, LazyOp) else x for x in ast.src]
if ast.op in BinaryOps: assert srcs[0].shape == srcs[1].shape, f"BinaryOps shape mismatch {srcs[0].shape} != {srcs[1].shape}"
if ast.op in ReduceOps: assert all(r == n or n == 1 for r,n in zip(srcs[0].shape, ast.arg)), f"ReduceOps can't reduce {srcs[0].shape} -> {ast.arg}"
if ast.op in MovementOps: ret = srcs[0].movement_op(ast.op, ast.arg)
@@ -76,7 +76,7 @@ class GenericExecAST(DeviceBuffer): # pylint: disable=abstract-method
return output_buffer
else:
return ret
def get_lazyop_info(ast:LazyOp): return GenericExecAST.exec_ast(ast, preprocess=lambda x: GenericExecAST(GenericShape(x.shape))).buf
# Run the ast with every real buffer swapped for a GenericShape-backed GenericExecAST
# (a shape-only stand-in used for flop tracking); presumably map_buffers rewrites the
# ast's leaves using that dict — TODO confirm against its definition.
def get_lazyop_info(ast:LazyOp): return GenericExecAST.exec_ast(map_buffers({x:GenericExecAST(GenericShape(x.shape)) for x in get_buffers(ast)}, ast)).buf
class GlobalCounters:
global_ops : ClassVar[int] = 0