diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 5eeef3be16..f8e29d2fda 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -1,5 +1,5 @@
 from collections import namedtuple
-import math
+import os, math
 
 def prod(x): return math.prod(x)
 
@@ -25,3 +25,14 @@ def get_conv_args(x_shape, w_shape, stride=1, groups=1, padding=0, dilation=1):
   if cin*groups != cin_: raise Exception(f"Input Tensor shape {x_shape} does not match the shape of the weights {w_shape}. ({cin*groups} vs. {cin_})")
   assert cout % groups == 0
   return ConvArgs(H, W, groups, cout//groups, cin, oy, ox, iy, ix, sy, sx, bs, cout, py, py_, px, px_, dy, dx, (bs, cout, oy, ox))
+
+def get_available_llops():
+  import importlib, inspect
+  _buffers, DEFAULT = {}, "CPU"
+  for op in [os.path.splitext(x)[0] for x in sorted(os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "llops"))) if x.startswith("ops_")]:
+    name = op[len("ops_"):].upper()
+    DEFAULT = name if os.environ.get(name, 0) == "1" else DEFAULT
+    try: _buffers[name] = [cls for cname, cls in inspect.getmembers(importlib.import_module('tinygrad.llops.'+op), inspect.isclass) if (cname.upper() == name + "BUFFER")][0]
+    except ImportError as e:
+      print(op, "not available", e)
+  return _buffers, DEFAULT
\ No newline at end of file
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index aaf5bcfc6d..aa235a1bc1 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -3,7 +3,7 @@ from enum import Enum
 from typing import Optional, Tuple, NamedTuple, Union, Any, List, Dict, Type
 from copy import copy
 import os, sys, functools, operator
-from tinygrad.helpers import ConvArgs
+from tinygrad.helpers import ConvArgs, get_available_llops
 from tinygrad.shapetracker import ShapeTracker
 
 # lazy can recurse a lot
@@ -78,65 +78,59 @@ def log_op(optype : OpType, op : List[Op], ret : DeviceBuffer, inp : List[Device
 
 # **** enumerate supported devices ****
 
-def get_buffers():
-  import importlib, inspect
-  _buffers, DEFAULT = {}, "CPU"
-  for op in [os.path.splitext(x)[0] for x in sorted(os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "llops"))) if x.startswith("ops_")]:
-    name = op[len("ops_"):].upper()
-    DEFAULT = name if os.environ.get(name, 0) == "1" else DEFAULT
-    try: _buffers[name] = [cls for cname, cls in inspect.getmembers(importlib.import_module('tinygrad.llops.'+op), inspect.isclass) if (cname.upper() == name + "BUFFER")][0]
-    except ImportError as e:
-      print(op, "not available", e)
-  return _buffers, DEFAULT
-
 class Device:
-  _buffers, DEFAULT = get_buffers()
+  _buffers, DEFAULT = get_available_llops()
   for name in _buffers.keys():
     vars()[name] = name
 
-# TODO: make a _realize function for each type called by realize
-def _realize(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer]]:
-  if self.optype == LoadOps and self.op.op == LoadOps.FROMCPU:
-    return Device._buffers[self.device].fromCPU(self.op.arg), []
-  elif self.optype == ReduceOps:
-    real_src = self.op.src[0].realize(self.device)
-    return real_src.reduce_op(self.op.op, self.op.arg), [real_src]
-  elif self.optype == MovementOps:
-    real_src = get_lazybuffers(self.op)[0].realize(self.device)
-    if getattr(real_src, "shapeTrackerView", None) is not None:
-      return real_src.shapeTrackerView(self.st), [real_src]
-    else:
-      # slow path, creates middle buffers
-      return functools.reduce(lambda x,o: x.movement_op(o.op, o.arg), get_lazyops(self.op)[::-1], real_src), [real_src]
-  elif self.optype == BinaryOps:
-    real_srcs : Dict[LazyBuffer, DeviceBuffer] = {}
-    [real_srcs.setdefault(x,x.realize(self.device)) for x in get_lazybuffers(self.op) if x not in real_srcs]
-    if getattr(Device._buffers[self.device], "_processing_op", None) is not None:
-      buf_names : Dict[DeviceBuffer, str] = {x:f"arg_{i}" for i, x in enumerate(real_srcs.values())}
+def _realize_loadops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer]]:
+  assert self.op.op == LoadOps.FROMCPU
+  return Device._buffers[self.device].fromCPU(self.op.arg), []
 
-      def ast_op(op, srcs_code: List[str]) -> str:
-        code = Device._buffers[self.device].code_for_op[op]
-        if len(srcs_code) >= 1: code = code.replace("A", srcs_code[0])
-        if len(srcs_code) >= 2: code = code.replace("B", srcs_code[1])
-        return code
+def _realize_reduceops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer]]:
+  real_src = self.op.src[0].realize(self.device)
+  return real_src.reduce_op(self.op.op, self.op.arg), [real_src]
 
-      def _ast(x: Union[LazyBuffer, LazyOp]) -> str:
-        if isinstance(x, LazyBuffer): return buf_names[real_srcs[x]]
-        return ast_op(x.op, [_ast(src) for src in x.src])
+def _realize_movementops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer]]:
+  real_src = get_lazybuffers(self.op)[0].realize(self.device)
+  if getattr(real_src, "shapeTrackerView", None) is not None:
+    return real_src.shapeTrackerView(self.st), [real_src]
+  else:
+    # slow path, creates middle buffers
+    return functools.reduce(lambda x,o: x.movement_op(o.op, o.arg), get_lazyops(self.op)[::-1], real_src), [real_src]
 
-      return Device._buffers[self.device](self.shape)._processing_op([(y,x) for (x,y) in buf_names.items()], _ast(self.op)), list(real_srcs.values())
-    else:
-      # slow path, creates middle buffers
-      def ast_eval(x: Union[LazyBuffer, LazyOp]) -> DeviceBuffer:
-        if isinstance(x, LazyBuffer): return real_srcs[x]
-        if isinstance(x.op, UnaryOps): return ast_eval(x.src[0]).unary_op(x.op)
-        if isinstance(x.op, BinaryOps): return ast_eval(x.src[0]).binary_op(x.op, ast_eval(x.src[1]))
-      return ast_eval(self.op), list(real_srcs.values())
-  elif self.optype == ProcessingOps:
-    real_src_x = self.op.src[0].realize(self.device)
-    real_src_w = self.op.src[1].realize(self.device)
-    return real_src_x.processing_op(self.op.op, real_src_w, self.op.arg), [real_src_x, real_src_w]
-  else: raise NotImplementedError(f"can't handle optype {self.optype}")
+def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer]]:
+  real_srcs : Dict[LazyBuffer, DeviceBuffer] = {}
+  # TODO: if there's *one* processing op in here, we can corealize it. we can corealize binary op siblings as well
+  [real_srcs.setdefault(x,x.realize(self.device)) for x in get_lazybuffers(self.op) if x not in real_srcs]
+  if getattr(Device._buffers[self.device], "_processing_op", None) is not None:
+    buf_names : Dict[DeviceBuffer, str] = {x:f"arg_{i}" for i, x in enumerate(real_srcs.values())}
+
+    def ast_op(op, srcs_code: List[str]) -> str:
+      code = Device._buffers[self.device].code_for_op[op]
+      if len(srcs_code) >= 1: code = code.replace("A", srcs_code[0])
+      if len(srcs_code) >= 2: code = code.replace("B", srcs_code[1])
+      return code
+
+    def _ast(x: Union[LazyBuffer, LazyOp]) -> str:
+      if isinstance(x, LazyBuffer): return buf_names[real_srcs[x]]
+      return ast_op(x.op, [_ast(src) for src in x.src])
+
+    return Device._buffers[self.device](self.shape)._processing_op([(y,x) for (x,y) in buf_names.items()], _ast(self.op)), list(real_srcs.values())
+  else:
+    # slow path, creates middle buffers
+    def ast_eval(x: Union[LazyBuffer, LazyOp]) -> DeviceBuffer:
+      if isinstance(x, LazyBuffer): return real_srcs[x]
+      if isinstance(x.op, UnaryOps): return ast_eval(x.src[0]).unary_op(x.op)
+      if isinstance(x.op, BinaryOps): return ast_eval(x.src[0]).binary_op(x.op, ast_eval(x.src[1]))
+    return ast_eval(self.op), list(real_srcs.values())
+
+def _realize_processingops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer]]:
+  real_src_x = self.op.src[0].realize(self.device)
+  real_src_w = self.op.src[1].realize(self.device)
+  return real_src_x.processing_op(self.op.op, real_src_w, self.op.arg), [real_src_x, real_src_w]
+
+_realize = {LoadOps:_realize_loadops, ReduceOps:_realize_reduceops, MovementOps:_realize_movementops, BinaryOps:_realize_binaryops, ProcessingOps:_realize_processingops}
 
 # **** lazy operations ****
 
@@ -165,7 +159,7 @@ class LazyBuffer:
     if required_device is not None: assert required_device == self.device
     if self.realized is None:
       # we haven't realized the Buffer yet
-      self.realized, real_srcs = _realize(self)
+      self.realized, real_srcs = _realize[self.optype](self)
       # in lazy mode, we don't log until we realize
      log_op(self.optype, [x.op for x in get_lazyops(self.op)], self.realized, real_srcs)
       # no need to keep the op after realization
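
Reviewer note: the heart of this change is turning _realize's if/elif chain into a table of per-optype functions, looked up as _realize[self.optype](self) at the single call site in LazyBuffer.realize. A minimal, self-contained sketch of that dispatch-table pattern follows; the OpType enum and the two stand-in functions are illustrative only, not tinygrad's actual classes:

    from enum import Enum

    class OpType(Enum):
      LOAD = 0
      BINARY = 1

    # one small realize function per optype, mirroring _realize_loadops etc.
    def _realize_load(arg): return list(arg)               # stand-in for fromCPU
    def _realize_binary(arg): return [x * 2 for x in arg]  # stand-in for a binary op

    # the old if/elif chain becomes a module-level dict keyed on the optype;
    # dispatch is a single lookup instead of a linear scan of branches
    _realize = {OpType.LOAD: _realize_load, OpType.BINARY: _realize_binary}

    print(_realize[OpType.LOAD]((1, 2, 3)))    # [1, 2, 3]
    print(_realize[OpType.BINARY]((1, 2, 3)))  # [2, 4, 6]

One behavioral difference worth noting: an unhandled optype now surfaces as a KeyError from the dict lookup rather than the old NotImplementedError. Separately, get_available_llops (moved to helpers.py unchanged) still picks the default backend from an environment variable named after the backend, so e.g. running with GPU=1 should make GPU the default, assuming an ops_gpu.py module is present and imports cleanly.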