separate realize functions for different ops

George Hotz
2022-07-04 09:07:22 -07:00
parent 81b73f97a3
commit 0bdb021880
2 changed files with 60 additions and 55 deletions

tinygrad/helpers.py

@@ -1,5 +1,5 @@
 from collections import namedtuple
-import math
+import os, math
 def prod(x): return math.prod(x)
@@ -25,3 +25,14 @@ def get_conv_args(x_shape, w_shape, stride=1, groups=1, padding=0, dilation=1):
   if cin*groups != cin_: raise Exception(f"Input Tensor shape {x_shape} does not match the shape of the weights {w_shape}. ({cin*groups} vs. {cin_})")
   assert cout % groups == 0
   return ConvArgs(H, W, groups, cout//groups, cin, oy, ox, iy, ix, sy, sx, bs, cout, py, py_, px, px_, dy, dx, (bs, cout, oy, ox))
+
+def get_available_llops():
+  import importlib, inspect
+  _buffers, DEFAULT = {}, "CPU"
+  for op in [os.path.splitext(x)[0] for x in sorted(os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "llops"))) if x.startswith("ops_")]:
+    name = op[len("ops_"):].upper()
+    DEFAULT = name if os.environ.get(name, 0) == "1" else DEFAULT
+    try: _buffers[name] = [cls for cname, cls in inspect.getmembers(importlib.import_module('tinygrad.llops.'+op), inspect.isclass) if (cname.upper() == name + "BUFFER")][0]
+    except ImportError as e:
+      print(op, "not available", e)
+  return _buffers, DEFAULT
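
Note: get_available_llops discovers backends by filename convention: every tinygrad/llops/ops_*.py is imported, the class whose name matches <NAME>BUFFER is registered under NAME, and setting the environment variable NAME=1 makes that backend the default. A minimal usage sketch (the backend list in the comment is illustrative; it depends on which llops import cleanly on your machine):

import os
os.environ["GPU"] = "1"  # request the GPU backend as the default

from tinygrad.helpers import get_available_llops
buffers, default = get_available_llops()
print(default)        # "GPU" -- note DEFAULT is taken from the env var even before the import is attempted
print(list(buffers))  # backends that imported cleanly, e.g. ["CPU", "GPU", "TORCH"]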

tinygrad/lazy.py

@@ -3,7 +3,7 @@ from enum import Enum
 from typing import Optional, Tuple, NamedTuple, Union, Any, List, Dict, Type
 from copy import copy
 import os, sys, functools, operator
-from tinygrad.helpers import ConvArgs
+from tinygrad.helpers import ConvArgs, get_available_llops
 from tinygrad.shapetracker import ShapeTracker
 # lazy can recurse a lot
@@ -78,65 +78,59 @@ def log_op(optype : OpType, op : List[Op], ret : DeviceBuffer, inp : List[Device
 # **** enumerate supported devices ****
-def get_buffers():
-  import importlib, inspect
-  _buffers, DEFAULT = {}, "CPU"
-  for op in [os.path.splitext(x)[0] for x in sorted(os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "llops"))) if x.startswith("ops_")]:
-    name = op[len("ops_"):].upper()
-    DEFAULT = name if os.environ.get(name, 0) == "1" else DEFAULT
-    try: _buffers[name] = [cls for cname, cls in inspect.getmembers(importlib.import_module('tinygrad.llops.'+op), inspect.isclass) if (cname.upper() == name + "BUFFER")][0]
-    except ImportError as e:
-      print(op, "not available", e)
-  return _buffers, DEFAULT
 class Device:
-  _buffers, DEFAULT = get_buffers()
+  _buffers, DEFAULT = get_available_llops()
   for name in _buffers.keys():
     vars()[name] = name
-# TODO: make a _realize function for each type called by realize
-def _realize(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer]]:
-  if self.optype == LoadOps and self.op.op == LoadOps.FROMCPU:
-    return Device._buffers[self.device].fromCPU(self.op.arg), []
-  elif self.optype == ReduceOps:
-    real_src = self.op.src[0].realize(self.device)
-    return real_src.reduce_op(self.op.op, self.op.arg), [real_src]
-  elif self.optype == MovementOps:
-    real_src = get_lazybuffers(self.op)[0].realize(self.device)
-    if getattr(real_src, "shapeTrackerView", None) is not None:
-      return real_src.shapeTrackerView(self.st), [real_src]
-    else:
-      # slow path, creates middle buffers
-      return functools.reduce(lambda x,o: x.movement_op(o.op, o.arg), get_lazyops(self.op)[::-1], real_src), [real_src]
-  elif self.optype == BinaryOps:
-    real_srcs : Dict[LazyBuffer, DeviceBuffer] = {}
-    [real_srcs.setdefault(x,x.realize(self.device)) for x in get_lazybuffers(self.op) if x not in real_srcs]
-    if getattr(Device._buffers[self.device], "_processing_op", None) is not None:
-      buf_names : Dict[DeviceBuffer, str] = {x:f"arg_{i}" for i, x in enumerate(real_srcs.values())}
-      def ast_op(op, srcs_code: List[str]) -> str:
-        code = Device._buffers[self.device].code_for_op[op]
-        if len(srcs_code) >= 1: code = code.replace("A", srcs_code[0])
-        if len(srcs_code) >= 2: code = code.replace("B", srcs_code[1])
-        return code
-      def _ast(x: Union[LazyBuffer, LazyOp]) -> str:
-        if isinstance(x, LazyBuffer): return buf_names[real_srcs[x]]
-        return ast_op(x.op, [_ast(src) for src in x.src])
-      return Device._buffers[self.device](self.shape)._processing_op([(y,x) for (x,y) in buf_names.items()], _ast(self.op)), list(real_srcs.values())
-    else:
-      # slow path, creates middle buffers
-      def ast_eval(x: Union[LazyBuffer, LazyOp]) -> DeviceBuffer:
-        if isinstance(x, LazyBuffer): return real_srcs[x]
-        if isinstance(x.op, UnaryOps): return ast_eval(x.src[0]).unary_op(x.op)
-        if isinstance(x.op, BinaryOps): return ast_eval(x.src[0]).binary_op(x.op, ast_eval(x.src[1]))
-      return ast_eval(self.op), list(real_srcs.values())
-  elif self.optype == ProcessingOps:
-    real_src_x = self.op.src[0].realize(self.device)
-    real_src_w = self.op.src[1].realize(self.device)
-    return real_src_x.processing_op(self.op.op, real_src_w, self.op.arg), [real_src_x, real_src_w]
-  else: raise NotImplementedError(f"can't handle optype {self.optype}")
+def _realize_loadops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer]]:
+  assert self.op.op == LoadOps.FROMCPU
+  return Device._buffers[self.device].fromCPU(self.op.arg), []
+
+def _realize_reduceops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer]]:
+  real_src = self.op.src[0].realize(self.device)
+  return real_src.reduce_op(self.op.op, self.op.arg), [real_src]
+
+def _realize_movementops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer]]:
+  real_src = get_lazybuffers(self.op)[0].realize(self.device)
+  if getattr(real_src, "shapeTrackerView", None) is not None:
+    return real_src.shapeTrackerView(self.st), [real_src]
+  else:
+    # slow path, creates middle buffers
+    return functools.reduce(lambda x,o: x.movement_op(o.op, o.arg), get_lazyops(self.op)[::-1], real_src), [real_src]
+
+def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer]]:
+  real_srcs : Dict[LazyBuffer, DeviceBuffer] = {}
+  # TODO: if there's *one* processing op in here, we can corealize it. we can corealize binary op siblings as well
+  [real_srcs.setdefault(x,x.realize(self.device)) for x in get_lazybuffers(self.op) if x not in real_srcs]
+  if getattr(Device._buffers[self.device], "_processing_op", None) is not None:
+    buf_names : Dict[DeviceBuffer, str] = {x:f"arg_{i}" for i, x in enumerate(real_srcs.values())}
+    def ast_op(op, srcs_code: List[str]) -> str:
+      code = Device._buffers[self.device].code_for_op[op]
+      if len(srcs_code) >= 1: code = code.replace("A", srcs_code[0])
+      if len(srcs_code) >= 2: code = code.replace("B", srcs_code[1])
+      return code
+    def _ast(x: Union[LazyBuffer, LazyOp]) -> str:
+      if isinstance(x, LazyBuffer): return buf_names[real_srcs[x]]
+      return ast_op(x.op, [_ast(src) for src in x.src])
+    return Device._buffers[self.device](self.shape)._processing_op([(y,x) for (x,y) in buf_names.items()], _ast(self.op)), list(real_srcs.values())
+  else:
+    # slow path, creates middle buffers
+    def ast_eval(x: Union[LazyBuffer, LazyOp]) -> DeviceBuffer:
+      if isinstance(x, LazyBuffer): return real_srcs[x]
+      if isinstance(x.op, UnaryOps): return ast_eval(x.src[0]).unary_op(x.op)
+      if isinstance(x.op, BinaryOps): return ast_eval(x.src[0]).binary_op(x.op, ast_eval(x.src[1]))
+    return ast_eval(self.op), list(real_srcs.values())
+
+def _realize_processingops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer]]:
+  real_src_x = self.op.src[0].realize(self.device)
+  real_src_w = self.op.src[1].realize(self.device)
+  return real_src_x.processing_op(self.op.op, real_src_w, self.op.arg), [real_src_x, real_src_w]
+
+_realize = {LoadOps:_realize_loadops, ReduceOps:_realize_reduceops, MovementOps:_realize_movementops, BinaryOps:_realize_binaryops, ProcessingOps:_realize_processingops}
 # **** lazy operations ****
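
Aside: the fast path in _realize_binaryops builds one expression string for the whole op tree by substituting child code into per-op templates via the A/B placeholders, then hands it to the backend's _processing_op. A self-contained sketch of that substitution (the code_for_op table below is made up; the real tables live on each backend's buffer class):

from typing import List

code_for_op = {"ADD": "(A+B)", "MUL": "(A*B)", "RELU": "max(A, (float)0.)"}  # hypothetical templates

def ast_op(op: str, srcs_code: List[str]) -> str:
  code = code_for_op[op]
  if len(srcs_code) >= 1: code = code.replace("A", srcs_code[0])
  if len(srcs_code) >= 2: code = code.replace("B", srcs_code[1])
  return code

# built inside-out, the way _ast recurses over the LazyOp tree
inner = ast_op("ADD", ["arg_0", "arg_1"])
print(ast_op("MUL", [inner, ast_op("RELU", ["arg_0"])]))
# ((arg_0+arg_1)*max(arg_0, (float)0.))

Plain str.replace is safe here only because the generated buffer names (arg_0, arg_1, ...) never contain a capital A or B.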
@@ -165,7 +159,7 @@ class LazyBuffer:
     if required_device is not None: assert required_device == self.device
     if self.realized is None:
      # we haven't realized the Buffer yet
-      self.realized, real_srcs = _realize(self)
+      self.realized, real_srcs = _realize[self.optype](self)
      # in lazy mode, we don't log until we realize
      log_op(self.optype, [x.op for x in get_lazyops(self.op)], self.realized, real_srcs)
      # no need to keep the op after realization
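
The net effect: a dict keyed on the op class replaces the if self.optype == ... chain, so realize does a single lookup and call. A minimal, self-contained sketch of the dispatch-table pattern with toy names (not tinygrad's API):

# toy op-type tags standing in for tinygrad's LoadOps/ReduceOps enums
class LoadOps: pass
class ReduceOps: pass

def _realize_loadops(buf): return f"load {buf['arg']}"
def _realize_reduceops(buf): return f"reduce {buf['arg']}"

# the table: one entry per op class, looked up by buf["optype"]
_realize = {LoadOps: _realize_loadops, ReduceOps: _realize_reduceops}

def realize(buf):
  return _realize[buf["optype"]](buf)

print(realize({"optype": LoadOps, "arg": [1, 2, 3]}))   # load [1, 2, 3]
print(realize({"optype": ReduceOps, "arg": "sum"}))     # reduce sum

One behavioral nit versus the old code: an unknown optype now raises a bare KeyError from the dict lookup instead of the explicit NotImplementedError.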