commit 99b287ed87 (parent cdf2be74f9)
Author: George Hotz
Date:   2022-07-03 13:54:30 -07:00

    typechecks

5 changed files with 29 additions and 25 deletions

tinygrad/llops/ops_gpu.py

@@ -2,8 +2,8 @@ from __future__ import annotations
 import os
 import functools
 import numpy as np
-import pyopencl as cl
-from typing import List, Tuple, Optional
+import pyopencl as cl # type: ignore
+from typing import List, Tuple, Optional, Any
 from tinygrad.helpers import prod, ConvArgs
 from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
 from tinygrad.shapetracker import ShapeTracker, View, strides_for_shape
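Side note on the pattern above: pyopencl ships without type stubs, and `# type: ignore` on the import line silences only mypy's missing-stubs error for that line; the module's attributes become Any, while the rest of the file stays checked. A minimal sketch, assuming pyopencl is installed:

    import pyopencl as cl  # type: ignore  # no stubs: cl.* is Any to mypy
    from typing import Optional

    # annotations naming the untyped module's classes still parse and
    # document intent; mypy just can't check them deeply
    ctx: Optional[cl.Context] = None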
@@ -11,6 +11,8 @@ from tinygrad.ops import DEBUG
 class CL:
   CACHE = None
+  cl_ctx : Optional[cl.Context] = None
+  cl_queue : Optional[cl.CommandQueue] = None
   def __init__(self):
     if getattr(CL, "cl_queue", None) is not None: return
     devices = sum([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()], [])
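Declaring cl_ctx/cl_queue at class level gives mypy their types before __init__ populates them, and makes the shared-singleton intent explicit. An illustrative sketch of the same pattern, with hypothetical names:

    from typing import Optional

    class Conn:  # hypothetical stand-in for CL
        handle: Optional[str] = None  # typed up front, None until first init
        def __init__(self):
            if Conn.handle is not None: return  # already initialized: no-op
            Conn.handle = "opened"

    Conn(); Conn()      # the second construction does nothing
    print(Conn.handle)  # -> opened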
@@ -29,7 +31,7 @@ class CL:
 @functools.lru_cache(maxsize=None)
 class CLProgram:
-  def __init__(self, name, prg, options=tuple(), argdtypes=None):
+  def __init__(self, name:str, prg:str, options=tuple(), argdtypes=None):
     self.name, self.prg = name, prg
     self.built = cl.Program(CL().cl_ctx, self.prg).build(options=options)
     self.clprg = self.built.__getattr__(self.name)
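Since CLProgram is wrapped in functools.lru_cache, its constructor arguments double as cache keys, so identical kernels are compiled once. A self-contained sketch of that memoization trick (class and arguments hypothetical):

    import functools

    @functools.lru_cache(maxsize=None)
    class Kernel:  # hypothetical stand-in for CLProgram
        def __init__(self, name: str, src: str):
            print("compiling", name)  # runs once per unique (name, src)
            self.name, self.src = name, src

    a = Kernel("add", "__kernel void add() {}")
    b = Kernel("add", "__kernel void add() {}")
    assert a is b  # cache hit: same instance, no recompile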
@@ -51,7 +53,7 @@ class GPUBuffer:
   def __init__(self, shape, hostbuf:Optional[GPUBuffer]=None):
     self.st = ShapeTracker(shape)
     self.shape = self.st.shape
-    self._buf = hostbuf._buf if hostbuf is not None else None
+    self._buf : cl.Buffer = hostbuf._buf if hostbuf is not None else None
   @property
   def cl(self):
@@ -102,7 +104,8 @@ class GPUBuffer:
     # generate loops with combined adjacent reduce axis
     acc = 1
-    loop_start, loop_end = [], []
+    loop_start : List[str] = []
+    loop_end : List[str] = []
     for shp,stride in st.views[-1].shape_strides[::-1]:
       if stride == 0:
         loop_start.append(f"for (int axis_{len(loop_start)} = 0; axis_{len(loop_start)} < {shp}; axis_{len(loop_start)}++) {{")
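The tuple assignment is split here out of necessity, not style: Python's grammar only permits a variable annotation on a single target, so the one-line form cannot carry types. Illustratively:

    from typing import List

    # loop_start, loop_end : List[str] = [], []  # SyntaxError: annotation needs a lone target
    loop_start: List[str] = []
    loop_end: List[str] = []
    loop_start.append("for (int axis_0 = 0; axis_0 < 4; axis_0++) {")
    loop_end.append("}")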
@@ -185,11 +188,6 @@ class GPUBuffer:
     float acc = 0.0;
     int gid = get_global_id(0);
     """+ints+conv_src+"""output[gid] = _ewop("""+','.join(["gid", "acc"]+[f"{name}_g" for name, _ in ewbufs])+""");
-    }""", options=tuple(options), argdtypes=tuple([None]*(1+len(bufs)) + [np.int32]*len(params)))
+    }""", options=tuple(options), argdtypes=tuple(None if i < 1+len(bufs) else np.int32 for i in range(1+len(bufs)+len(params))))
     conv_prg(global_size, None, ret.cl, *[buf.cl for _, buf in bufs], *[x[1] for x in params])
     return ret
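The argdtypes rewrite is behavior-preserving: both expressions yield 1+len(bufs) Nones followed by len(params) np.int32 entries, the new one without building throwaway lists. A quick illustrative check with dummy stand-ins (only the lengths matter):

    import numpy as np

    bufs, params = [0, 0], [0, 0, 0]
    old = tuple([None]*(1+len(bufs)) + [np.int32]*len(params))
    new = tuple(None if i < 1+len(bufs) else np.int32 for i in range(1+len(bufs)+len(params)))
    assert old == new  # (None, None, None, int32, int32, int32)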

tinygrad/llops/ops_torch.py

@@ -1,5 +1,5 @@
 import torch
-from tinygrad.llops.ops_cpu import CPUBuffer
+from tinygrad.llops.ops_cpu import CPUBuffer # type: ignore
 from tinygrad.ops import MovementOps, ProcessingOps
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

tinygrad/ops.py

@@ -1,6 +1,6 @@
 from __future__ import annotations
 from enum import Enum
-from typing import Tuple, NamedTuple, Union, Any, List
+from typing import Optional, Tuple, NamedTuple, Union, Any, List, Dict, Type
 import functools, operator
 from tinygrad.helpers import ConvArgs
 from tinygrad.shapetracker import ShapeTracker
@@ -11,6 +11,7 @@ MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE", "EXPAND", "FLI
 ProcessingOps = Enum("ProcessingOps", ["CONV"])
 LoadOps = Enum("LoadOps", ["FROMCPU"])
 Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps]
+OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[ProcessingOps], Type[LoadOps]]
 # lazy can recurse a lot
 import sys
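OpType exists because some call sites (LazyBuffer.__init__ below) are handed the enum class itself rather than one of its members; annotating those as Op was the type error. A minimal sketch of the distinction:

    from enum import Enum
    from typing import Type, Union

    BinaryOps = Enum("BinaryOps", ["ADD", "MUL"])
    ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
    Op = Union[BinaryOps, ReduceOps]                  # an enum member
    OpType = Union[Type[BinaryOps], Type[ReduceOps]]  # the enum class itself

    op: Op = BinaryOps.ADD
    optype: OpType = BinaryOps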
@@ -20,7 +21,7 @@ import os
 DEBUG = int(os.getenv("DEBUG", "0"))
 GRAPH = int(os.getenv("GRAPH", "0"))
 from collections import defaultdict
-cnts = defaultdict(int)
+cnts : Dict[Op, int] = defaultdict(int)
 import atexit
 if DEBUG:
@@ -29,7 +30,7 @@ if DEBUG:
   atexit.register(debug_exit)
 if GRAPH:
-  import networkx as nx
+  import networkx as nx # type: ignore
   G = nx.DiGraph()
   def save_graph_exit():
     print("saving", G)
@@ -72,6 +73,8 @@ def log_op(optype, op, ret, inp):
 # **** enumerate supported devices ****
 import importlib, inspect
+def find_buffer(llo, name):
+  return [cls for cname, cls in inspect.getmembers(llo, inspect.isclass) if (cname.upper() == name + "BUFFER")][0]
 class Device:
   _ops = sorted(os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "llops")))
   DEFAULT = None
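Hoisting find_buffer to module level (it was previously redefined inside the class-body loop; see the next hunk) also makes it independently testable. A runnable sketch of what it does, against a hypothetical fake module:

    import inspect, types

    llo = types.ModuleType("fake_llops")  # stand-in for tinygrad.llops.ops_gpu

    class GPUBuffer: pass
    llo.GPUBuffer = GPUBuffer

    def find_buffer(llo, name):
        # pick the class named <NAME>BUFFER, matched case-insensitively
        return [cls for cname, cls in inspect.getmembers(llo, inspect.isclass) if cname.upper() == name + "BUFFER"][0]

    print(find_buffer(llo, "GPU"))  # -> <class '__main__.GPUBuffer'>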
@@ -81,11 +84,10 @@ class Device:
     vars()[name] = name
     DEFAULT = name if os.environ.get(name, 0) == "1" else DEFAULT
     try:
-      def find_buffer(llo, name): return [cls for cname, cls in inspect.getmembers(llo, inspect.isclass) if (cname.upper() == name + "BUFFER")][0]
       buffers[name] = find_buffer(importlib.import_module('tinygrad.llops.'+op), name)
     except ImportError as e:
       print(op, "not available", e)
-  DEFAULT = CPU if DEFAULT is None else DEFAULT
+  DEFAULT = "CPU" if DEFAULT is None else DEFAULT
 # TODO: get device buffer types
 DeviceBuffer = Any
@@ -115,7 +117,7 @@ def _realize(self:LazyBuffer) -> DeviceBuffer:
 class LazyOp(NamedTuple):
   op: Op
-  src: Tuple[Union[LazyOp, LazyBuffer]]
+  src: Tuple[Union[LazyOp, LazyBuffer], ...] # type: ignore
   arg: Any = None
 # TODO: add dest to support multiple outputs
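The src fix is the recurring theme of this commit: Tuple[X] means a tuple of exactly one X, so variable-length homogeneous tuples need the Tuple[X, ...] form. Illustratively:

    from typing import Tuple

    one: Tuple[int] = (1,)             # exactly one element
    many: Tuple[int, ...] = (1, 2, 3)  # any length, all ints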
@@ -125,11 +127,11 @@ def get_lazyops(op:LazyOp) -> List[LazyOp]: return functools.reduce(operator.add
 LAZY = int(os.getenv("LAZY", "0"))
 class LazyBuffer:
-  def __init__(self, device, shape:Union[ShapeTracker, Tuple[int]], optype:Op, op:LazyOp):
+  def __init__(self, device, shape:Union[ShapeTracker, Tuple[int, ...]], optype:OpType, op:LazyOp):
     self.st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape))
     self.shape = self.st.shape
     self.optype, self.op = optype, op
-    self.realized = None
+    self.realized : Optional[DeviceBuffer] = None
     self.device = device
     if not LAZY: self.realize()
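realized starting as None and being filled in later is exactly what Optional is for; without the annotation, mypy would infer the unhelpful type None from the initial assignment. A minimal sketch of the pattern, with hypothetical names:

    from typing import List, Optional

    class Lazy:  # hypothetical stand-in for LazyBuffer
        def __init__(self):
            self.realized: Optional[List[float]] = None  # nothing computed yet
        def realize(self) -> List[float]:
            if self.realized is None:
                self.realized = [0.0, 1.0]  # "compute" once, then cache
            return self.realized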

tinygrad/shapetracker.py

@@ -9,7 +9,7 @@ def divmodidx(acc, d, mod=True):
   return f"({lr}%{d})" if mod else lr # don't mod the top shape dimension
 @functools.lru_cache(maxsize=None)
-def to_shape_strides(shape:Tuple[int], strides:Tuple[int]) -> List[Tuple[int, int]]:
+def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tuple[int, int]]:
   assert len(shape) == len(strides)
   ret = [(shape[0], strides[0])]
   for i in range(1, len(shape)):
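Beyond the Tuple[int, ...] fix, lru_cache is why this function takes tuples in the first place: cache keys must be hashable, and list arguments would raise TypeError. A simplified, illustrative reconstruction of the adjacent-dimension merge (the real function handles more cases, such as zero strides):

    import functools
    from typing import List, Tuple

    @functools.lru_cache(maxsize=None)
    def merge_dims(shape: Tuple[int, ...], strides: Tuple[int, ...]) -> List[Tuple[int, int]]:
        ret = [(shape[0], strides[0])]
        for i in range(1, len(shape)):
            if ret[-1][1] == shape[i]*strides[i]:  # contiguous with previous dim
                ret[-1] = (ret[-1][0]*shape[i], strides[i])
            else:
                ret.append((shape[i], strides[i]))
        return ret

    print(merge_dims((2, 3, 4), (12, 4, 1)))  # -> [(24, 1)]: fully contiguous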
@@ -55,6 +55,8 @@ class ZeroView:
       acc *= self.shape[0]
     self.expr = 'valid=' + ' && '.join(expr)
+ViewTypes = Union[View, ZeroView]
 @functools.lru_cache(maxsize=None)
 def strides_for_shape(shape):
   strides = [1]
@@ -69,9 +71,8 @@ def view_from_shape(shape:Tuple):
   return View(tuple(shape), strides_for_shape(shape))
 class ShapeTracker:
-  def __init__(self, shape:Union[ShapeTracker, Tuple[int]]):
-    if isinstance(shape, ShapeTracker): self.views = shape.views[:]
-    else: self.views = [view_from_shape(shape)]
+  def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]]):
+    self.views : List[ViewTypes] = shape.views[:] if isinstance(shape, ShapeTracker) else [view_from_shape(shape)]
   @property
   def contiguous(self):

tinygrad/tensor.py

@@ -337,12 +337,15 @@ class Function:
     self.requires_grad = any(self.needs_input_grad)
     self.saved_tensors = []
+  def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
+  def backward(self, *args, **kwargs): raise NotImplementedError(f"backward not implemented for {type(self)}")
   def save_for_backward(self, *x):
     # NOTE: it doesn't hurt to save this since the ctx will be freed fast without grad
     self.saved_tensors.extend(x)
   @classmethod
-  def apply(cls, *x:List[Tensor], **kwargs):
+  def apply(cls, *x:Tensor, **kwargs):
     ctx = cls(x[0].device, *x)
     ret = Tensor(ctx.forward(*[t.lazydata for t in x], **kwargs),
       device=ctx.device, requires_grad=ctx.requires_grad)
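Two things close out the commit here. The forward/backward stubs give mypy (and readers) a visible base-class signature that subclasses override, failing loudly if one is missing. And the apply fix corrects a common varargs mistake: in def apply(cls, *x:Tensor), the annotation types each positional argument, so x is a Tuple[Tensor, ...] inside the body; the old *x:List[Tensor] claimed every argument was itself a list. A small illustrative sketch:

    from typing import Tuple

    def apply(*x: int) -> Tuple[int, ...]:
        # each element is an int; the pack itself is a tuple
        return x

    print(apply(1, 2, 3))  # -> (1, 2, 3)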