mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-23 05:48:08 -05:00)
separate realize functions for different ops
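In short: the monolithic `_realize` in tinygrad/ops.py, previously one if/elif chain over `self.optype`, is split into five per-optype functions (`_realize_loadops`, `_realize_reduceops`, `_realize_movementops`, `_realize_binaryops`, `_realize_processingops`) dispatched through a dict keyed on the op class, and the device-discovery helper moves from `ops.get_buffers` into tinygrad/helpers.py as `get_available_llops`.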
tinygrad/helpers.py

@@ -1,5 +1,5 @@
 from collections import namedtuple
-import math
+import os, math
 
 def prod(x): return math.prod(x)
 
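(The change to `import os, math` supplies the `os` module that the new `get_available_llops` below needs for `os.listdir`, `os.path`, and `os.environ`.)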
@@ -25,3 +25,14 @@ def get_conv_args(x_shape, w_shape, stride=1, groups=1, padding=0, dilation=1):
   if cin*groups != cin_: raise Exception(f"Input Tensor shape {x_shape} does not match the shape of the weights {w_shape}. ({cin*groups} vs. {cin_})")
   assert cout % groups == 0
   return ConvArgs(H, W, groups, cout//groups, cin, oy, ox, iy, ix, sy, sx, bs, cout, py, py_, px, px_, dy, dx, (bs, cout, oy, ox))
+
+def get_available_llops():
+  import importlib, inspect
+  _buffers, DEFAULT = {}, "CPU"
+  for op in [os.path.splitext(x)[0] for x in sorted(os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "llops"))) if x.startswith("ops_")]:
+    name = op[len("ops_"):].upper()
+    DEFAULT = name if os.environ.get(name, 0) == "1" else DEFAULT
+    try: _buffers[name] = [cls for cname, cls in inspect.getmembers(importlib.import_module('tinygrad.llops.'+op), inspect.isclass) if (cname.upper() == name + "BUFFER")][0]
+    except ImportError as e:
+      print(op, "not available", e)
+  return _buffers, DEFAULT
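A note on the selection logic above: a backend becomes the default when an environment variable named after it is set to "1" (so `GPU=1 python ...` selects GPU). A minimal self-contained sketch of the same pattern, using a hypothetical hard-coded backend list in place of the real scan of tinygrad/llops/ops_*.py:

import os

# hypothetical backend names; the real code derives them from the
# ops_*.py files it finds in tinygrad/llops/
backends = ["CPU", "GPU", "TORCH"]

DEFAULT = "CPU"
for name in backends:
  # e.g. GPU=1 in the environment promotes GPU to the default
  if os.environ.get(name, "0") == "1":
    DEFAULT = name

print(DEFAULT)  # prints "GPU" when run as: GPU=1 python select_backend.py

One quirk worth knowing: the real code calls `os.environ.get(name, 0)` with an int default, which only works because that default is compared against the string "1" and can never match.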
tinygrad/ops.py (102 changed lines)
@@ -3,7 +3,7 @@ from enum import Enum
 from typing import Optional, Tuple, NamedTuple, Union, Any, List, Dict, Type
 from copy import copy
 import os, sys, functools, operator
-from tinygrad.helpers import ConvArgs
+from tinygrad.helpers import ConvArgs, get_available_llops
 from tinygrad.shapetracker import ShapeTracker
 
 # lazy can recurse a lot
@@ -78,65 +78,59 @@ def log_op(optype : OpType, op : List[Op], ret : DeviceBuffer, inp : List[DeviceBuffer]):
 
 # **** enumerate supported devices ****
 
-def get_buffers():
-  import importlib, inspect
-  _buffers, DEFAULT = {}, "CPU"
-  for op in [os.path.splitext(x)[0] for x in sorted(os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "llops"))) if x.startswith("ops_")]:
-    name = op[len("ops_"):].upper()
-    DEFAULT = name if os.environ.get(name, 0) == "1" else DEFAULT
-    try: _buffers[name] = [cls for cname, cls in inspect.getmembers(importlib.import_module('tinygrad.llops.'+op), inspect.isclass) if (cname.upper() == name + "BUFFER")][0]
-    except ImportError as e:
-      print(op, "not available", e)
-  return _buffers, DEFAULT
-
 class Device:
-  _buffers, DEFAULT = get_buffers()
+  _buffers, DEFAULT = get_available_llops()
   for name in _buffers.keys():
     vars()[name] = name
 
-# TODO: make a _realize function for each type called by realize
-def _realize(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer]]:
-  if self.optype == LoadOps and self.op.op == LoadOps.FROMCPU:
-    return Device._buffers[self.device].fromCPU(self.op.arg), []
-  elif self.optype == ReduceOps:
-    real_src = self.op.src[0].realize(self.device)
-    return real_src.reduce_op(self.op.op, self.op.arg), [real_src]
-  elif self.optype == MovementOps:
-    real_src = get_lazybuffers(self.op)[0].realize(self.device)
-    if getattr(real_src, "shapeTrackerView", None) is not None:
-      return real_src.shapeTrackerView(self.st), [real_src]
-    else:
-      # slow path, creates middle buffers
-      return functools.reduce(lambda x,o: x.movement_op(o.op, o.arg), get_lazyops(self.op)[::-1], real_src), [real_src]
-  elif self.optype == BinaryOps:
-    real_srcs : Dict[LazyBuffer, DeviceBuffer] = {}
-    [real_srcs.setdefault(x,x.realize(self.device)) for x in get_lazybuffers(self.op) if x not in real_srcs]
-    if getattr(Device._buffers[self.device], "_processing_op", None) is not None:
-      buf_names : Dict[DeviceBuffer, str] = {x:f"arg_{i}" for i, x in enumerate(real_srcs.values())}
-
-      def ast_op(op, srcs_code: List[str]) -> str:
-        code = Device._buffers[self.device].code_for_op[op]
-        if len(srcs_code) >= 1: code = code.replace("A", srcs_code[0])
-        if len(srcs_code) >= 2: code = code.replace("B", srcs_code[1])
-        return code
-
-      def _ast(x: Union[LazyBuffer, LazyOp]) -> str:
-        if isinstance(x, LazyBuffer): return buf_names[real_srcs[x]]
-        return ast_op(x.op, [_ast(src) for src in x.src])
-
-      return Device._buffers[self.device](self.shape)._processing_op([(y,x) for (x,y) in buf_names.items()], _ast(self.op)), list(real_srcs.values())
-    else:
-      # slow path, creates middle buffers
-      def ast_eval(x: Union[LazyBuffer, LazyOp]) -> DeviceBuffer:
-        if isinstance(x, LazyBuffer): return real_srcs[x]
-        if isinstance(x.op, UnaryOps): return ast_eval(x.src[0]).unary_op(x.op)
-        if isinstance(x.op, BinaryOps): return ast_eval(x.src[0]).binary_op(x.op, ast_eval(x.src[1]))
-      return ast_eval(self.op), list(real_srcs.values())
-  elif self.optype == ProcessingOps:
-    real_src_x = self.op.src[0].realize(self.device)
-    real_src_w = self.op.src[1].realize(self.device)
-    return real_src_x.processing_op(self.op.op, real_src_w, self.op.arg), [real_src_x, real_src_w]
-  else: raise NotImplementedError(f"can't handle optype {self.optype}")
+def _realize_loadops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer]]:
+  assert self.op.op == LoadOps.FROMCPU
+  return Device._buffers[self.device].fromCPU(self.op.arg), []
+
+def _realize_reduceops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer]]:
+  real_src = self.op.src[0].realize(self.device)
+  return real_src.reduce_op(self.op.op, self.op.arg), [real_src]
+
+def _realize_movementops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer]]:
+  real_src = get_lazybuffers(self.op)[0].realize(self.device)
+  if getattr(real_src, "shapeTrackerView", None) is not None:
+    return real_src.shapeTrackerView(self.st), [real_src]
+  else:
+    # slow path, creates middle buffers
+    return functools.reduce(lambda x,o: x.movement_op(o.op, o.arg), get_lazyops(self.op)[::-1], real_src), [real_src]
+
+def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer]]:
+  real_srcs : Dict[LazyBuffer, DeviceBuffer] = {}
+  # TODO: if there's *one* processing op in here, we can corealize it. we can corealize binary op siblings as well
+  [real_srcs.setdefault(x,x.realize(self.device)) for x in get_lazybuffers(self.op) if x not in real_srcs]
+  if getattr(Device._buffers[self.device], "_processing_op", None) is not None:
+    buf_names : Dict[DeviceBuffer, str] = {x:f"arg_{i}" for i, x in enumerate(real_srcs.values())}
+
+    def ast_op(op, srcs_code: List[str]) -> str:
+      code = Device._buffers[self.device].code_for_op[op]
+      if len(srcs_code) >= 1: code = code.replace("A", srcs_code[0])
+      if len(srcs_code) >= 2: code = code.replace("B", srcs_code[1])
+      return code
+
+    def _ast(x: Union[LazyBuffer, LazyOp]) -> str:
+      if isinstance(x, LazyBuffer): return buf_names[real_srcs[x]]
+      return ast_op(x.op, [_ast(src) for src in x.src])
+
+    return Device._buffers[self.device](self.shape)._processing_op([(y,x) for (x,y) in buf_names.items()], _ast(self.op)), list(real_srcs.values())
+  else:
+    # slow path, creates middle buffers
+    def ast_eval(x: Union[LazyBuffer, LazyOp]) -> DeviceBuffer:
+      if isinstance(x, LazyBuffer): return real_srcs[x]
+      if isinstance(x.op, UnaryOps): return ast_eval(x.src[0]).unary_op(x.op)
+      if isinstance(x.op, BinaryOps): return ast_eval(x.src[0]).binary_op(x.op, ast_eval(x.src[1]))
+    return ast_eval(self.op), list(real_srcs.values())
+
+def _realize_processingops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer]]:
+  real_src_x = self.op.src[0].realize(self.device)
+  real_src_w = self.op.src[1].realize(self.device)
+  return real_src_x.processing_op(self.op.op, real_src_w, self.op.arg), [real_src_x, real_src_w]
+
+_realize = {LoadOps:_realize_loadops, ReduceOps:_realize_reduceops, MovementOps:_realize_movementops, BinaryOps:_realize_binaryops, ProcessingOps:_realize_processingops}
 
 # **** lazy operations ****
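The new `_realize` dict is a classic dispatch-table refactor, with the op-family Enum classes themselves as keys (matching what `LazyBuffer.optype` holds). A toy sketch of the pattern, with stand-in handlers rather than tinygrad's real ones:

from enum import Enum

# stand-in op families; tinygrad defines several such Enums
class LoadOps(Enum): FROMCPU = 0
class ReduceOps(Enum): SUM = 0

def _realize_loadops(buf): return f"loaded {buf}"
def _realize_reduceops(buf): return f"reduced {buf}"

# keys are the Enum classes, not their members
_realize = {LoadOps: _realize_loadops, ReduceOps: _realize_reduceops}

optype = ReduceOps  # stand-in for self.optype
print(_realize[optype]("x"))  # -> "reduced x"

Adding a new op family now means writing one `_realize_*` function and registering it in the dict, instead of growing an if/elif chain.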
@@ -165,7 +159,7 @@ class LazyBuffer:
     if required_device is not None: assert required_device == self.device
     if self.realized is None:
       # we haven't realized the Buffer yet
-      self.realized, real_srcs = _realize(self)
+      self.realized, real_srcs = _realize[self.optype](self)
       # in lazy mode, we don't log until we realize
       log_op(self.optype, [x.op for x in get_lazyops(self.op)], self.realized, real_srcs)
       # no need to keep the op after realization
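One behavioral nuance of the new dispatch: the old `else: raise NotImplementedError(f"can't handle optype {self.optype}")` fallback has no counterpart, so an unregistered optype now fails with a `KeyError` at the `_realize[self.optype]` lookup. If the explicit error were wanted back, a minimal sketch (not part of this commit) could wrap the lookup:

def realize_fn_for(optype):
  # _realize is the dispatch dict defined in ops.py above
  try: return _realize[optype]
  except KeyError: raise NotImplementedError(f"can't handle optype {optype}")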