mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
move Device back into lazy
This commit is contained in:
@@ -128,8 +128,8 @@ class TritonBuffer(ExplicitExecAST):
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def exec_ast(cls, ast:LazyOp):
|
||||
k = TritonASTKernel(ast)
|
||||
def exec_ast(cls, ast:LazyOp, output_buffer:Optional[TritonBuffer]=None):
|
||||
k = TritonASTKernel(ast, output_buffer)
|
||||
k.codegen()(*k.bufs)
|
||||
return k.ret
|
||||
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
def get_available_llops():
  """Discover the backend buffer classes shipped in the ``llops`` directory.

  Every ``ops_<name>.py`` entry found in the ``llops`` directory next to this
  module is imported as ``tinygrad.llops.ops_<name>`` and searched for a class
  whose upper-cased name equals ``<NAME>BUFFER``.

  Returns:
    A ``(buffers, default)`` tuple. ``buffers`` maps the upper-cased backend
    name to its buffer class. ``default`` is the name of the last backend (in
    sorted filename order) whose environment variable equals ``"1"``, or
    ``"CPU"`` when no such variable is set.
  """
  import os, importlib, inspect
  _buffers, DEFAULT = {}, "CPU"
  llops_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llops")
  for filename in sorted(os.listdir(llops_dir)):
    if not filename.startswith("ops_"): continue
    op = os.path.splitext(filename)[0]
    name = op[len("ops_"):].upper()
    # DEFAULT is picked from the environment before the import is attempted,
    # so it can name a backend that ends up missing from _buffers
    if os.environ.get(name, 0) == "1": DEFAULT = name
    try:
      module = importlib.import_module('tinygrad.llops.'+op)
    except ImportError as e:  # NOTE: this can't be put on one line due to mypy issue
      print(op, "not available", e)
      continue
    # first class (in inspect.getmembers order) named <NAME>BUFFER, case-insensitively
    buffer_cls = next((cls for cname, cls in inspect.getmembers(module, inspect.isclass) if cname.upper() == name + "BUFFER"), None)
    if buffer_cls is None:
      # a module without a matching class used to raise IndexError and abort
      # discovery of every backend; report it and keep scanning instead
      print(op, "not available", f"no {name}BUFFER class found")
      continue
    _buffers[name] = buffer_cls
  return _buffers, DEFAULT
|
||||
|
||||
class Device:
  # backend registry and default backend name, resolved once when the class body executes
  _buffers, DEFAULT = get_available_llops()
  # publish every discovered backend as a same-named string attribute (e.g. Device.CPU == "CPU");
  # the loop variable `name` is itself left behind in the class namespace
  for name in _buffers: locals()[name] = name
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from typing import Optional, Tuple, Union, List, Dict, Any, ClassVar
|
||||
import sys, weakref
|
||||
from typing import Optional, Tuple, Union, List, Dict, Any, ClassVar, Type
|
||||
import sys, weakref, os, importlib, inspect
|
||||
from weakref import WeakValueDictionary
|
||||
from tinygrad.helpers import ConvArgs, prod
|
||||
from tinygrad.shape import ShapeTracker
|
||||
@@ -16,8 +16,19 @@ NOCONV = getenv("NOCONV", 0)
|
||||
IMAGE = getenv("IMAGE", 0)
|
||||
LAZY = getenv("LAZY", 1)
|
||||
|
||||
# late import of Device
|
||||
from tinygrad.device import Device
|
||||
class _Device:
|
||||
def __init__(self) -> None:
|
||||
self.DEFAULT : str = "CPU"
|
||||
self._buffers : Dict[str, Type[DeviceBuffer]] = {}
|
||||
for op in [os.path.splitext(x)[0] for x in sorted(os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "llops"))) if x.startswith("ops_")]:
|
||||
name = op[len("ops_"):].upper()
|
||||
if os.environ.get(name, 0) == "1": self.DEFAULT = name # note: DEFAULT can be a Device that can't be imported. better than silent use of a different device
|
||||
try:
|
||||
self._buffers[name] = [cls for cname, cls in inspect.getmembers(importlib.import_module('tinygrad.llops.'+op), inspect.isclass) if (cname.upper() == name + "BUFFER")][0]
|
||||
self.__setattr__(name, name)
|
||||
except ImportError as e: # NOTE: this can't be put on one line due to mypy issue
|
||||
print(op, "not available", e)
|
||||
# module-level instance: constructing it runs _Device.__init__, which lists the llops directory and tries to import every ops_* backend
Device = _Device()
|
||||
|
||||
# TODO: movement ops that only change shape are really nops. treat them as such
|
||||
REMOVE_MOVEMENT_NOPS, MERGE_UNARY_OPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS = OPT>=1, OPT>=1, OPT>=1, OPT>=1
|
||||
|
||||
@@ -42,7 +42,7 @@ class DeviceBuffer:
|
||||
def fromCPU(x:np.ndarray) -> DeviceBuffer: raise NotImplementedError("must be implemented")
|
||||
def toCPU(self:DeviceBuffer) -> np.ndarray: raise NotImplementedError("must be implemented")
|
||||
@classmethod
|
||||
def exec_ast(cls, ast:LazyOp): raise NotImplementedError("must be implemented")
|
||||
def exec_ast(cls, ast:LazyOp, output_buffer=None): raise NotImplementedError("must be implemented")
|
||||
|
||||
# this is a quick "buffer" class for flop tracking
|
||||
class GenericShape(NamedTuple):
|
||||
@@ -64,8 +64,8 @@ class GenericExecAST(DeviceBuffer): # pylint: disable=abstract-method
|
||||
def contiguous(self): return type(self).exec_ast(LazyOp(op=UnaryOps.NOOP, src=(self,)))
|
||||
def movement_op(self, op:MovementOps, arg=None): return type(self)(self.fxn_for_op[op](self.buf, arg)) if op in self.fxn_for_op else type(self)(getattr(self.buf, op.name.lower())(arg))
|
||||
@classmethod
|
||||
def exec_ast(cls, ast:LazyOp, output_buffer:Optional[GenericExecAST]=None, preprocess=lambda x: x):
|
||||
srcs = [cls.exec_ast(x, preprocess=preprocess) if isinstance(x, LazyOp) else preprocess(x) for x in ast.src]
|
||||
def exec_ast(cls, ast:LazyOp, output_buffer:Optional[GenericExecAST]=None):
|
||||
srcs = [cls.exec_ast(x) if isinstance(x, LazyOp) else x for x in ast.src]
|
||||
if ast.op in BinaryOps: assert srcs[0].shape == srcs[1].shape, f"BinaryOps shape mismatch {srcs[0].shape} != {srcs[1].shape}"
|
||||
if ast.op in ReduceOps: assert all(r == n or n == 1 for r,n in zip(srcs[0].shape, ast.arg)), f"ReduceOps can't reduce {srcs[0].shape} -> {ast.arg}"
|
||||
if ast.op in MovementOps: ret = srcs[0].movement_op(ast.op, ast.arg)
|
||||
@@ -76,7 +76,7 @@ class GenericExecAST(DeviceBuffer): # pylint: disable=abstract-method
|
||||
return output_buffer
|
||||
else:
|
||||
return ret
|
||||
def get_lazyop_info(ast:LazyOp): return GenericExecAST.exec_ast(ast, preprocess=lambda x: GenericExecAST(GenericShape(x.shape))).buf
|
||||
# Executes `ast` through GenericExecAST after building, for every buffer from get_buffers(ast), a GenericExecAST(GenericShape(x.shape)) stand-in, and returns the result's .buf.
# NOTE(review): presumably map_buffers swaps each dict key found in the AST for its value -- confirm against its definition.
def get_lazyop_info(ast:LazyOp): return GenericExecAST.exec_ast(map_buffers({x:GenericExecAST(GenericShape(x.shape)) for x in get_buffers(ast)}, ast)).buf
|
||||
|
||||
class GlobalCounters:
|
||||
global_ops : ClassVar[int] = 0
|
||||
|
||||
Reference in New Issue
Block a user