From 8fbe2e4aeda62da705ee3ecc9f6cce8f4316317d Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 21 Jun 2022 10:07:49 -0700 Subject: [PATCH] No ctx in llops (#345) * remove ctx from gpu ops * ctx for the others * this is okay * mlops are not static. fix lazy * cl is property, _processing_op is class method * kernel_name * contiguous_op --- accel/lazy/ops_lazy.py | 70 +++++----- tinygrad/helpers.py | 4 +- tinygrad/llops/ops_cpu.py | 108 +++++++-------- tinygrad/llops/ops_gpu.py | 262 ++++++++++++++++++------------------ tinygrad/llops/ops_torch.py | 16 +-- tinygrad/mlops.py | 28 ---- tinygrad/ops.py | 10 +- tinygrad/tensor.py | 5 +- 8 files changed, 239 insertions(+), 264 deletions(-) diff --git a/accel/lazy/ops_lazy.py b/accel/lazy/ops_lazy.py index cc23bf5d8a..78f73d542b 100644 --- a/accel/lazy/ops_lazy.py +++ b/accel/lazy/ops_lazy.py @@ -2,6 +2,9 @@ from __future__ import annotations from typing import Union, NamedTuple, List, Any, Tuple, Dict from tinygrad.shapetracker import ShapeTracker import functools, operator +from tinygrad.helpers import prod +import sys +sys.setrecursionlimit(10000) from tinygrad.ops import ReduceOps, BinaryOps, MovementOps, ProcessingOps, log_op, DEBUG, GRAPH from enum import Enum @@ -70,6 +73,36 @@ class LazyBuffer: def toCPU(self): return self.realize().toCPU() + def unary_op(x, op): return elementwise_op(op, (x,)) + def binary_op(x, op, y:LazyBuffer): return elementwise_op(op, (x,y)) + def contiguous_op(x): return x if x.st.contiguous else LazyBuffer(x.shape, LoadOps, LazyOp(LoadOps.CONTIGUOUS, (x,))) + + @functools.lru_cache(maxsize=None) + def movement_op(x, op:MovementOps, arg) -> LazyBuffer: + if SHUFFLE_MOVEMENT_OPS and x.optype == BinaryOps: + # if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead + def replace_with_movement_op(y:Union[LazyOp, LazyBuffer]) -> LazyBuffer: + if isinstance(y, LazyBuffer): return y.movement_op(op, arg) + return elementwise_op(y.op, tuple(replace_with_movement_op(z) for z in y.src)) + return replace_with_movement_op(x.op) + + # if a MovementOp is applied to a MovementOp, merge them and use one buffer + ret = LazyBuffer(x.st, MovementOps, LazyOp(op, (x.op if MERGE_MOVEMENT_OPS and x.optype == MovementOps and x.realized is None else x,), arg)) + ret.shape = ret.st.movement_op(op, arg).shape # update the shape after we modify the ShapeTracker + + if REMOVE_MOVEMENT_NOPS and x.optype == MovementOps and x.realized is None and ret.st.contiguous: + root = get_root(x.op) + if ret.st.shape == root.shape: + return root + + return ret + + def reduce_op(x, op, new_shape:Tuple[int]): + return LazyBuffer(new_shape, ReduceOps, LazyOp(op, (x,), new_shape)) + + def processing_op(x, op, w:LazyBuffer, C): + return LazyBuffer(C.out_shape, ProcessingOps, LazyOp(op, (x.contiguous_op(), w.contiguous_op()), C)) + def ast_op(op: Op, srcs_code: List[str]) -> str: code = gops.code_for_op[op] if len(srcs_code) >= 1: code = code.replace("A", srcs_code[0]) @@ -117,37 +150,6 @@ def elementwise_op(op, srcs:Tuple[LazyBuffer]) -> LazyBuffer: return LazyBuffer(out_shape, BinaryOps, LazyOp(op, srcs)) -def unary_op(op, x): return elementwise_op(op, (x,)) -def binary_op(op, x, y): return elementwise_op(op, (x,y)) - -@functools.lru_cache(maxsize=None) -def movement_op(op:MovementOps, x:LazyBuffer, arg): - if SHUFFLE_MOVEMENT_OPS and x.optype == BinaryOps: - # if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead - 
def replace_with_movement_op(y:Union[LazyOp, LazyBuffer]) -> LazyBuffer: - if isinstance(y, LazyBuffer): return movement_op(op, y, arg) - return elementwise_op(y.op, tuple(replace_with_movement_op(z) for z in y.src)) - return replace_with_movement_op(x.op) - - # if a MovementOp is applied to a MovementOp, merge them and use one buffer - ret = LazyBuffer(x.st, MovementOps, LazyOp(op, (x.op if MERGE_MOVEMENT_OPS and x.optype == MovementOps and x.realized is None else x,), arg)) - ret.shape = ret.st.movement_op(op, arg).shape # update the shape after we modify the ShapeTracker - - if REMOVE_MOVEMENT_NOPS and x.optype == MovementOps and x.realized is None and ret.st.contiguous: - root = get_root(x.op) - if ret.st.shape == root.shape: - return root - - return ret - -def reduce_op(op, x, new_shape): - return LazyBuffer(new_shape, ReduceOps, LazyOp(op, (x,), new_shape)) - -def processing_op(op, x, w, C): - if not x.st.contiguous: x = LazyBuffer(x.shape, LoadOps, LazyOp(LoadOps.CONTIGUOUS, (x,))) - if not w.st.contiguous: w = LazyBuffer(w.shape, LoadOps, LazyOp(LoadOps.CONTIGUOUS, (w,))) - return LazyBuffer(C.out_shape, ProcessingOps, LazyOp(op, (x, w), C)) - # these functions determines the backing buffer import tinygrad.llops.ops_gpu as gops @@ -181,7 +183,7 @@ def _realize_binary_op(self:LazyBuffer) -> Tuple[gops.GPUBuffer, List[gops.GPUBu real_dict[s] = f"arg_{len(real_srcs)}" real_srcs.append((f"arg_{len(real_srcs)}", s.realize())) code = ast(self.op, real_dict) - return gops._processing_op(self.shape, real_srcs, code, arg), [x[1] for x in real_srcs] + return gops.GPUBuffer(self.shape)._processing_op(real_srcs, code, arg), [x[1] for x in real_srcs] def _realize(self:LazyBuffer) -> Tuple[gops.GPUBuffer, List[gops.GPUBuffer]]: if self.optype == LoadOps and self.op.op == LoadOps.FROMCPU: @@ -189,10 +191,10 @@ def _realize(self:LazyBuffer) -> Tuple[gops.GPUBuffer, List[gops.GPUBuffer]]: return gops.GPUBuffer.fromCPU(self.op.arg), [] elif self.optype == LoadOps and self.op.op == LoadOps.CONTIGUOUS: real_src = self.op.src[0].realize() - return gops.contiguous(real_src), [real_src] + return real_src.contiguous(), [real_src] elif self.optype == ReduceOps: real_src = self.op.src[0].realize() - return gops.reduce_op(self.op.op, real_src, self.op.arg), [real_src] + return real_src.reduce_op(self.op.op, self.op.arg), [real_src] elif self.optype == MovementOps: real_src = get_root(self.op).realize() return gops.GPUBuffer(self.st, real_src), [real_src] diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 96acb7c29e..c025f208a9 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -5,7 +5,7 @@ def prod(x): return math.prod(x) def reduce_shape(shape, axis): return [1 if i in axis else shape[i] for i in range(len(shape))] -conv_args = namedtuple('conv_args', ['H', 'W', 'groups', 'rcout', 'cin', 'oy', 'ox', 'iy', 'ix', 'ys', 'xs', 'bs', 'cout', 'py', 'px', 'dy', 'dx', 'out_shape']) +ConvArgs = namedtuple('ConvArgs', ['H', 'W', 'groups', 'rcout', 'cin', 'oy', 'ox', 'iy', 'ix', 'ys', 'xs', 'bs', 'cout', 'py', 'px', 'dy', 'dx', 'out_shape']) def get_conv_args(x_shape, w_shape, stride=1, groups=1, padding=0, dilation=1): # TODO: https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html#tensor-layout cout,cin,H,W = w_shape @@ -19,4 +19,4 @@ def get_conv_args(x_shape, w_shape, stride=1, groups=1, padding=0, dilation=1): ox = (ix + 2*px - dx * (W-1) - 1)//xs + 1 if cin*groups != cin_: raise Exception(f"Input Tensor shape {x_shape} does not match the shape of the weights 
{w_shape}. ({cin*groups} vs. {cin_})") assert cout % groups == 0 - return conv_args(H, W, groups, cout//groups, cin, oy, ox, iy, ix, ys, xs, bs, cout, py, px, dy, dx, (bs, cout, oy, ox)) + return ConvArgs(H, W, groups, cout//groups, cin, oy, ox, iy, ix, ys, xs, bs, cout, py, px, dy, dx, (bs, cout, oy, ox)) diff --git a/tinygrad/llops/ops_cpu.py b/tinygrad/llops/ops_cpu.py index a9d618ebe4..e807a4c750 100644 --- a/tinygrad/llops/ops_cpu.py +++ b/tinygrad/llops/ops_cpu.py @@ -17,60 +17,60 @@ class CPUBuffer(np.ndarray): def fromCPU(x): return x def toCPU(x): return x -def unary_op(op, x): - if op == UnaryOps.NOOP: return x[:] - elif op == UnaryOps.RELU: return x.relu() - elif op == UnaryOps.EXP: return x.exp() - elif op == UnaryOps.LOG: return x.log() - elif op == UnaryOps.NEG: return -x - elif op == UnaryOps.SIGN: return x.sign() - else: raise Exception(f"{op} isn't supported") - -def binary_op(op, x, y): - if op == BinaryOps.ADD: return x+y - elif op == BinaryOps.SUB: return x-y - elif op == BinaryOps.MUL: return x*y - elif op == BinaryOps.DIV: return y/x - elif op == BinaryOps.POW: return x**y - elif op == BinaryOps.CMPEQ: return 1.0*(x==y) - else: raise Exception(f"{op} isn't supported") - -def reduce_op(op, inp, new_shape): - if inp.shape == new_shape: # this is just a copy, regardless of the reduce op - return inp[:] - else: - if new_shape == (1,): # full reduce - axis = tuple(range(len(inp.shape))) - else: - assert len(inp.shape) == len(new_shape) - axis = tuple([i for i,(a,b) in enumerate(zip(inp.shape, new_shape)) if a != b]) - if op == ReduceOps.SUM: return inp.sum(axis, keepdims=True) - elif op == ReduceOps.MAX: return inp.amax(axis, keepdims=True) + def unary_op(x, op): + if op == UnaryOps.NOOP: return x[:] + elif op == UnaryOps.RELU: return x.relu() + elif op == UnaryOps.EXP: return x.exp() + elif op == UnaryOps.LOG: return x.log() + elif op == UnaryOps.NEG: return -x + elif op == UnaryOps.SIGN: return x.sign() else: raise Exception(f"{op} isn't supported") -def movement_op(op, x, arg=None): - if op == MovementOps.RESHAPE: return x.reshape(arg) - elif op == MovementOps.PERMUTE: return x.permute(arg) - elif op == MovementOps.FLIP: return x.flip(arg) - elif op == MovementOps.SLICE: - padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)] - slicee = [(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)] - return x.custompad(padding)[tuple([slice(x[0], x[1], None) for x in slicee])] - elif op == MovementOps.EXPAND: return x.expand(arg) - else: raise Exception(f"{op} isn't supported") + def binary_op(x, op, y): + if op == BinaryOps.ADD: return x+y + elif op == BinaryOps.SUB: return x-y + elif op == BinaryOps.MUL: return x*y + elif op == BinaryOps.DIV: return y/x + elif op == BinaryOps.POW: return x**y + elif op == BinaryOps.CMPEQ: return 1.0*(x==y) + else: raise Exception(f"{op} isn't supported") -def processing_op(op,x,w,C): - assert op == ProcessingOps.CONV, f"{op} isn't supported" - if C.px > 0 or C.py > 0: x = np.pad(x, [(0,0), (0,0), (C.py, C.py), (C.px, C.px)]) - gx = x.reshape(C.bs,C.groups,C.cin,x.shape[2],x.shape[3]) - tx = np.lib.stride_tricks.as_strided(gx, - shape=(C.bs, C.groups, C.cin, C.oy, C.ox, C.H, C.W), - strides=(*gx.strides[0:3], gx.strides[3]*C.ys, gx.strides[4]*C.xs, gx.strides[3]*C.dy, gx.strides[4]*C.dx), - writeable=False, - ) - tw = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W) - tmp = np.empty((C.bs,C.groups,C.oy,C.ox,C.rcout),dtype=x.dtype) - for g in range(C.groups): - #ijYXyx,kjyx -> iYXk ->ikYX - tmp[:,g] = 
np.tensordot(tx[:,g], tw[g], ((1,4,5),(1,2,3))) - return np.moveaxis(tmp,4,2).reshape(C.bs, C.groups*C.rcout, C.oy, C.ox).view(CPUBuffer) + def reduce_op(x, op, new_shape): + if x.shape == new_shape: # this is just a copy, regardless of the reduce op + return x[:] + else: + if new_shape == (1,): # full reduce + axis = tuple(range(len(x.shape))) + else: + assert len(x.shape) == len(new_shape) + axis = tuple([i for i,(a,b) in enumerate(zip(x.shape, new_shape)) if a != b]) + if op == ReduceOps.SUM: return x.sum(axis, keepdims=True) + elif op == ReduceOps.MAX: return x.amax(axis, keepdims=True) + else: raise Exception(f"{op} isn't supported") + + def movement_op(x, op, arg=None): + if op == MovementOps.RESHAPE: return x.reshape(arg) + elif op == MovementOps.PERMUTE: return x.permute(arg) + elif op == MovementOps.FLIP: return x.flip(arg) + elif op == MovementOps.SLICE: + padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)] + slicee = [(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)] + return x.custompad(padding)[tuple([slice(x[0], x[1], None) for x in slicee])] + elif op == MovementOps.EXPAND: return x.expand(arg) + else: raise Exception(f"{op} isn't supported") + + def processing_op(x,op,w,C): + assert op == ProcessingOps.CONV, f"{op} isn't supported" + if C.px > 0 or C.py > 0: x = np.pad(x, [(0,0), (0,0), (C.py, C.py), (C.px, C.px)]) + gx = x.reshape(C.bs,C.groups,C.cin,x.shape[2],x.shape[3]) + tx = np.lib.stride_tricks.as_strided(gx, + shape=(C.bs, C.groups, C.cin, C.oy, C.ox, C.H, C.W), + strides=(*gx.strides[0:3], gx.strides[3]*C.ys, gx.strides[4]*C.xs, gx.strides[3]*C.dy, gx.strides[4]*C.dx), + writeable=False, + ) + tw = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W) + tmp = np.empty((C.bs,C.groups,C.oy,C.ox,C.rcout),dtype=x.dtype) + for g in range(C.groups): + #ijYXyx,kjyx -> iYXk ->ikYX + tmp[:,g] = np.tensordot(tx[:,g], tw[g], ((1,4,5),(1,2,3))) + return np.moveaxis(tmp,4,2).reshape(C.bs, C.groups*C.rcout, C.oy, C.ox).view(CPUBuffer) diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py index dccf34e5d5..48b857cbb8 100644 --- a/tinygrad/llops/ops_gpu.py +++ b/tinygrad/llops/ops_gpu.py @@ -1,9 +1,9 @@ +from __future__ import annotations import functools import numpy as np import pyopencl as cl -from typing import List, Tuple -from tinygrad.helpers import prod -from tinygrad.llops.ops_cpu import unary_op +from typing import List, Tuple, Optional +from tinygrad.helpers import prod, ConvArgs from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps from tinygrad.shapetracker import ShapeTracker, View, strides_for_shape @@ -19,12 +19,33 @@ def require_init_gpu(): cl_ctx = cl.Context(devices=devices) cl_queue = cl.CommandQueue(cl_ctx) # this is an in-order command queue +@functools.lru_cache(maxsize=None) +class CLProgram: + def __init__(self, name, prg, options=tuple(), argdtypes=None): + self.name = name + self.built = cl.Program(cl_ctx, prg).build(options=options) + self.clprg = self.built.__getattr__(name) + if argdtypes is not None: self.clprg.set_scalar_arg_dtypes(argdtypes) + def __call__(self, *args): + #print(f"running {self.name} with {args[0]} count {len(args)-2}") + self.clprg(cl_queue, *args) + +code_for_op = { + UnaryOps.NOOP: "(A)", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.EXP: "exp(A)", UnaryOps.LOG: "log(A)", UnaryOps.NEG: "(-(A))", UnaryOps.SIGN: "sign(A)", + BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", BinaryOps.DIV: "(B/A)", BinaryOps.POW: "pow(A,B)", 
BinaryOps.CMPEQ: "(A==B)", +} + class GPUBuffer: - def __init__(self, shape, hostbuf=None): + def __init__(self, shape, hostbuf:Optional[GPUBuffer]=None): require_init_gpu() self.st = ShapeTracker(shape) self.shape = self.st.shape - self.cl = hostbuf.cl if hostbuf is not None else cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE, 4*prod(self.shape)) + self._buf = hostbuf._buf if hostbuf is not None else None + + @property + def cl(self): + if self._buf is None: self._buf = cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE, 4*prod(self.shape)) + return self._buf def __repr__(self): return f"" @@ -38,143 +59,128 @@ class GPUBuffer: def toCPU(self): data = np.empty(self.shape, dtype=np.float32) - cl.enqueue_copy(cl_queue, data, contiguous(self).cl, is_blocking=True) + cl.enqueue_copy(cl_queue, data, self.contiguous_op().cl, is_blocking=True) return data -@functools.lru_cache(maxsize=None) -class CLProgram: - def __init__(self, name, prg, options=tuple(), argdtypes=None): - self.name = name - self.built = cl.Program(cl_ctx, prg).build(options=options) - self.clprg = self.built.__getattr__(name) - if argdtypes is not None: self.clprg.set_scalar_arg_dtypes(argdtypes) - def __call__(self, *args): - #print(f"running {self.name} with {args[0]} count {len(args)-2}") - self.clprg(cl_queue, *args) + def contiguous_view(x, name:str) -> str: + return f"inline float get_{name}(__global const float *x, int gid) {{ int valid = 1; int idx = gid; {x.st.expr().replace('//', '/')}; return valid ? x[idx] : 0.0;}}" -def contiguous_view(name:str, x:GPUBuffer) -> str: - return f"inline float get_{name}(__global const float *x, int gid) {{ int valid = 1; int idx = gid; {x.st.expr().replace('//', '/')}; return valid ? x[idx] : 0.0;}}" + def unary_op(x, op:UnaryOps): return type(x)(x.shape)._processing_op([("A", x)], code_for_op[op]) + def binary_op(x, op:BinaryOps, y:GPUBuffer): return type(x)(x.shape)._processing_op([("A", x), ("B", y)], code_for_op[op]) + def contiguous_op(x): return x if x.st.contiguous else x.unary_op(UnaryOps.NOOP) -def _processing_op(out_shape: Tuple[int], bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C=None): - ret = GPUBuffer(out_shape) - options = [] + def movement_op(x, op:MovementOps, arg) -> GPUBuffer: + ret = GPUBuffer(x.st, x) + ret.shape = ret.st.movement_op(op, arg).shape + return ret - if C is not None: - ints = ''.join(f"int {x} = {getattr(C, x)};" for x in ["H", "W", "ys", "xs", "dx", "dy", "px", "py", "groups", "rcout", "cin"]) - params = [(f"int {x}", getattr(C, x)) for x in ["oy", "ox", "iy", "ix"]] - if C.px == 0 and C.py == 0: options.append("-DALLVALID") - if C.oy == 1 and C.ox == 1: options.append("-DONEBYONE") - global_size = [C.bs*C.cout, C.oy, C.ox] - assert bufs[0][0] == "input" and bufs[1][0] == "weight" - ewbufs = bufs[2:] # input and weight are consumed by the convs - else: - ints, params = '', [] - options.append("-DNOCONV") - global_size = [prod(ret.shape), 1, 1] - ewbufs = bufs + def processing_op(x, op:ProcessingOps, w:GPUBuffer, C:ConvArgs): + assert op == ProcessingOps.CONV, f"{op} isn't supported" + return type(x)(C.out_shape)._processing_op([("input", x.contiguous_op()), ("weight", w.contiguous_op())], "acc", C) - elementwise_prefix = '\n'.join([contiguous_view(name, buf) for name, buf in ewbufs])+ \ - "inline float _ewop("+','.join(["int gid", "float acc"]+[f"__global const float *{name}_g" for name, _ in ewbufs])+") {"+ \ - '\n'.join([f"float {name} = get_{name}({name}_g, gid);" for name, _ in ewbufs])+ \ - f"return {code}; }}" + def reduce_op(x, op:ReduceOps, 
new_shape:Tuple[int]): + ret = GPUBuffer(new_shape) + if op == ReduceOps.SUM: code, start = "out += a", "0.0" + elif op == ReduceOps.MAX: code, start = "out = max(a,out)", "-INFINITY" + else: raise Exception(f"{op} isn't supported") - conv_params = ["__global float* restrict output"] + \ - [f"__global const float *{name}_g" for name, _ in bufs] + \ - [x[0] for x in params] - conv_prg = CLProgram("conv", elementwise_prefix+""" - __kernel void conv("""+','.join(conv_params)+""") { - float acc = 0.0; - int gid = get_global_id(0); - """+ints+""" + # reverse operation of expand, this validates inputs + st = ShapeTracker(ret.shape).movement_op(MovementOps.EXPAND, x.shape) + # this takes a ret index to an inp index, indexing 0 on the reduced strides + view = View(ret.shape, strides_for_shape(x.shape)) -#ifndef NOCONV - int B = gid/(groups*rcout); // range 0-bs - int g = (gid/rcout)%groups; - int c = gid % rcout; + # generate loops with combined adjacent reduce axis + acc = 1 + loop_start, loop_end = [], [] + for shp,stride in st.views[-1].shape_strides[::-1]: + if stride == 0: + loop_start.append(f"for (int axis_{len(loop_start)} = 0; axis_{len(loop_start)} < {shp}; axis_{len(loop_start)}++) {{") + loop_end.append(f"idx += {acc}; }} idx -= {shp*acc};") + acc *= shp -#ifdef ONEBYONE - int Y = 0; - int X = 0; -#else - int Y = get_global_id(1); // range 0-oy - int X = get_global_id(2); // range 0-ox - gid = gid*oy*ox + Y*ox + X; -#endif + # TODO: support multistage reduces + CLProgram("reduce", x.contiguous_view('A')+""" + __kernel void reduce(__global const float *a_g, __global float *res_g) { + int gid = get_global_id(0); int idx = gid;"""+view.expr.replace('//', '/')+"""; + float out = """+start+""";\n"""+ \ + '\n'.join(loop_start[::-1])+""" + float a = get_A(a_g, idx); + """+code+""";\n"""+ \ + '\n'.join(loop_end)+""" + res_g[gid] = out; + }""")([prod(ret.shape)], None, x.cl, ret.cl) + return ret - int IY = Y*ys; - int IX = X*xs; + def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None) -> GPUBuffer: + options = [] + if C is not None: + ints = ''.join(f"int {x} = {getattr(C, x)};" for x in ["H", "W", "ys", "xs", "dx", "dy", "px", "py", "groups", "rcout", "cin"]) + params = [(f"int {x}", getattr(C, x)) for x in ["oy", "ox", "iy", "ix"]] + if C.px == 0 and C.py == 0: options.append("-DALLVALID") + if C.oy == 1 and C.ox == 1: options.append("-DONEBYONE") + global_size = [C.bs*C.cout, C.oy, C.ox] + assert bufs[0][0] == "input" and bufs[1][0] == "weight" + ewbufs = bufs[2:] # input and weight are consumed by the convs + kernel_name = "conv" + else: + ints, params = '', [] + options.append("-DNOCONV") + global_size = [prod(ret.shape), 1, 1] + ewbufs = bufs + kernel_name = "elementwise" - for (int ci = 0; ci < cin; ci++) { - for (int y = 0; y < H; y++) { for (int x = 0; x < W; x++) { - int idx_y = y*dy + IY - py; - int idx_x = x*dx + IX - px; -#ifdef ALLVALID - acc += input_g[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + idx_y*ix + idx_x] * \ - weight_g[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x]; -#else - int valid = (idx_y >= 0 && idx_y < iy && idx_x >= 0 && idx_x < ix); - acc += valid * input_g[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + clamp(idx_y, 0, iy-1)*ix + clamp(idx_x, 0, ix-1)] * \ - weight_g[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x]; -#endif - } } - } -#endif + elementwise_prefix = '\n'.join([buf.contiguous_view(name) for name, buf in ewbufs])+ \ + "inline float _ewop("+','.join(["int gid", "float acc"]+[f"__global const float 
*{name}_g" for name, _ in ewbufs])+") {"+ \ + '\n'.join([f"float {name} = get_{name}({name}_g, gid);" for name, _ in ewbufs])+ \ + f"return {code}; }}" - output[gid] = _ewop("""+','.join(["gid", "acc"]+[f"{name}_g" for name, _ in ewbufs])+"""); - }""", options=tuple(options), argdtypes=tuple([None]*(1+len(bufs)) + [np.int32]*len(params))) - conv_prg(global_size, None, ret.cl, *[buf.cl for _, buf in bufs], *[x[1] for x in params]) - return ret + conv_params = ["__global float* restrict output"] + \ + [f"__global const float *{name}_g" for name, _ in bufs] + \ + [x[0] for x in params] + conv_prg = CLProgram(kernel_name, elementwise_prefix+f"__kernel void {kernel_name}("+','.join(conv_params)+""") { + float acc = 0.0; + int gid = get_global_id(0); + """+ints+""" + + #ifndef NOCONV + int B = gid/(groups*rcout); // range 0-bs + int g = (gid/rcout)%groups; + int c = gid % rcout; + + #ifdef ONEBYONE + int Y = 0; + int X = 0; + #else + int Y = get_global_id(1); // range 0-oy + int X = get_global_id(2); // range 0-ox + gid = gid*oy*ox + Y*ox + X; + #endif + + int IY = Y*ys; + int IX = X*xs; + + for (int ci = 0; ci < cin; ci++) { + for (int y = 0; y < H; y++) { for (int x = 0; x < W; x++) { + int idx_y = y*dy + IY - py; + int idx_x = x*dx + IX - px; + #ifdef ALLVALID + acc += input_g[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + idx_y*ix + idx_x] * \ + weight_g[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x]; + #else + int valid = (idx_y >= 0 && idx_y < iy && idx_x >= 0 && idx_x < ix); + acc += valid * input_g[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + clamp(idx_y, 0, iy-1)*ix + clamp(idx_x, 0, ix-1)] * \ + weight_g[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x]; + #endif + } } + } + #endif + + output[gid] = _ewop("""+','.join(["gid", "acc"]+[f"{name}_g" for name, _ in ewbufs])+"""); + }""", options=tuple(options), argdtypes=tuple([None]*(1+len(bufs)) + [np.int32]*len(params))) + conv_prg(global_size, None, ret.cl, *[buf.cl for _, buf in bufs], *[x[1] for x in params]) + return ret -# gpu ops -code_for_op = { - UnaryOps.NOOP: "(A)", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.EXP: "exp(A)", UnaryOps.LOG: "log(A)", UnaryOps.NEG: "(-(A))", UnaryOps.SIGN: "sign(A)", - BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", BinaryOps.DIV: "(B/A)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)", -} -def unary_op(op, x): return _processing_op(x.shape, [("A", x)], code_for_op[op]) -def binary_op(op, x, y): return _processing_op(x.shape, [("A", x), ("B", y)], code_for_op[op]) -def contiguous(x:GPUBuffer): return x if x.st.contiguous else unary_op(UnaryOps.NOOP, x) -def movement_op(op, x, arg): - ret = GPUBuffer(x.st, x) - ret.shape = ret.st.movement_op(op, arg).shape - return ret - -def processing_op(op, x, w, C): - assert op == ProcessingOps.CONV, f"{op} isn't supported" - return _processing_op(C.out_shape, [("input", contiguous(x)), ("weight", contiguous(w))], "acc", C) - -def reduce_op(op, x, new_shape): - ret = GPUBuffer(new_shape) - if op == ReduceOps.SUM: code, start = "out += a", "0.0" - elif op == ReduceOps.MAX: code, start = "out = max(a,out)", "-INFINITY" - else: raise Exception(f"{op} isn't supported") - - # reverse operation of expand, this validates inputs - st = ShapeTracker(ret.shape).movement_op(MovementOps.EXPAND, x.shape) - # this takes a ret index to an inp index, indexing 0 on the reduced strides - view = View(ret.shape, strides_for_shape(x.shape)) - - # generate loops with combined adjacent reduce axis - acc = 1 - loop_start, loop_end = [], [] - for shp,stride 
in st.views[-1].shape_strides[::-1]: - if stride == 0: - loop_start.append(f"for (int axis_{len(loop_start)} = 0; axis_{len(loop_start)} < {shp}; axis_{len(loop_start)}++) {{") - loop_end.append(f"idx += {acc}; }} idx -= {shp*acc};") - acc *= shp - - # TODO: support multistage reduces - CLProgram("reduce", contiguous_view('A', x)+""" - __kernel void reduce(__global const float *a_g, __global float *res_g) { - int gid = get_global_id(0); int idx = gid;"""+view.expr.replace('//', '/')+"""; - float out = """+start+""";\n"""+ \ - '\n'.join(loop_start[::-1])+""" - float a = get_A(a_g, idx); - """+code+""";\n"""+ \ - '\n'.join(loop_end)+""" - res_g[gid] = out; - }""")([prod(ret.shape)], None, x.cl, ret.cl) - return ret diff --git a/tinygrad/llops/ops_torch.py b/tinygrad/llops/ops_torch.py index 28ec1390ab..422e31f55b 100644 --- a/tinygrad/llops/ops_torch.py +++ b/tinygrad/llops/ops_torch.py @@ -1,5 +1,7 @@ import torch import numpy as np +from tinygrad.llops.ops_cpu import CPUBuffer +from tinygrad.ops import ProcessingOps device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") class TorchBuffer(torch.Tensor): @@ -17,14 +19,8 @@ class TorchBuffer(torch.Tensor): def getdtype(self): return np.float32 -# ************* unary+binary+reduce+movement ops ************* + unary_op, binary_op, reduce_op, movement_op = CPUBuffer.unary_op, CPUBuffer.binary_op, CPUBuffer.reduce_op, CPUBuffer.movement_op -from tinygrad.llops.ops_cpu import unary_op, binary_op, reduce_op, movement_op - -# ************* processing ops ************* - -from tinygrad.ops import ProcessingOps - -def processing_op(op,x,w,C): - assert op == ProcessingOps.CONV, f"{op} isn't supported" - return torch.conv2d(x, w, stride=(C.ys, C.xs), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px)) + def processing_op(x,op,w,C): + assert op == ProcessingOps.CONV, f"{op} isn't supported" + return torch.conv2d(x, w, stride=(C.ys, C.xs), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px)) diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py index 9237dac387..f583401dab 100644 --- a/tinygrad/mlops.py +++ b/tinygrad/mlops.py @@ -6,12 +6,10 @@ from tinygrad.tensor import Function # ************* unary ops ************* class _UnaryOp(Function): - @staticmethod def forward(ctx, input): ctx.save_for_backward(input) return ctx.unary_op(ctx.fop, input) - @staticmethod def backward(ctx, grad_output): input, = ctx.saved_tensors return ctx.binary_op(ctx.bop, input, grad_output) @@ -19,7 +17,6 @@ class _UnaryOp(Function): class ReLU(_UnaryOp): fop = UnaryOps.RELU - @staticmethod def backward(ctx, grad_output): input, = ctx.saved_tensors ret = ctx.unary_op(UnaryOps.SIGN, input) @@ -31,7 +28,6 @@ class Log(_UnaryOp): bop = BinaryOps.DIV class Exp(_UnaryOp): - @staticmethod def forward(ctx, input): ret = ctx.unary_op(UnaryOps.EXP, input) ctx.save_for_backward(ret) # we save the output here, not the input @@ -42,24 +38,20 @@ class Exp(_UnaryOp): # ************* reduce ops ************* class Sum(Function): - @staticmethod def forward(ctx, input, axis=None): ctx.save_for_backward(input.shape) return ctx.reduce_op(ReduceOps.SUM, input, reduce_shape(input.shape, axis)) - @staticmethod def backward(ctx, grad_output): shape_input, = ctx.saved_tensors return ctx.movement_op(MovementOps.EXPAND, grad_output, shape_input) class Max(Function): - @staticmethod def forward(ctx, input, axis=None): ret = ctx.reduce_op(ReduceOps.MAX, input, reduce_shape(input.shape, axis)) ctx.save_for_backward(input, ret) return ret - @staticmethod def 
backward(ctx, grad_output): input, ret = ctx.saved_tensors @@ -77,32 +69,26 @@ class Max(Function): # ************* binary ops ************* class Add(Function): - @staticmethod def forward(ctx, x, y): return ctx.binary_op(BinaryOps.ADD, x, y) - @staticmethod def backward(ctx, grad_output): return grad_output if ctx.needs_input_grad[0] else None, \ grad_output if ctx.needs_input_grad[1] else None class Sub(Function): - @staticmethod def forward(ctx, x, y): return ctx.binary_op(BinaryOps.SUB, x, y) - @staticmethod def backward(ctx, grad_output): return grad_output if ctx.needs_input_grad[0] else None, \ ctx.unary_op(UnaryOps.NEG, grad_output) if ctx.needs_input_grad[1] else None class Mul(Function): - @staticmethod def forward(ctx, x, y): ctx.save_for_backward(x, y) return ctx.binary_op(BinaryOps.MUL, x, y) - @staticmethod def backward(ctx, grad_output): x,y = ctx.saved_tensors grad_x = ctx.binary_op(BinaryOps.MUL, y, grad_output) if ctx.needs_input_grad[0] else None @@ -110,13 +96,11 @@ class Mul(Function): return grad_x, grad_y class Pow(Function): - @staticmethod def forward(ctx, x, y): ret = ctx.binary_op(BinaryOps.POW, x, y) ctx.save_for_backward(x, y, ret) return ret - @staticmethod def backward(ctx, grad_output): x,y,powxy = ctx.saved_tensors grad_x, grad_y = None, None @@ -133,58 +117,48 @@ class Pow(Function): # NOTE: this is sum in reverse class Expand(Function): - @staticmethod def forward(ctx, x, shape): ctx.save_for_backward(x.shape) return ctx.movement_op(MovementOps.EXPAND, x, shape) - @staticmethod def backward(ctx, grad_output): in_shape, = ctx.saved_tensors return ctx.reduce_op(ReduceOps.SUM, grad_output, in_shape) class Flip(Function): - @staticmethod def forward(ctx, x, axis): ctx.save_for_backward(axis) return ctx.movement_op(MovementOps.FLIP, x, axis) - @staticmethod def backward(ctx, grad_output): axis, = ctx.saved_tensors return ctx.movement_op(MovementOps.FLIP, grad_output, axis) class Reshape(Function): - @staticmethod def forward(ctx, x, shape): ctx.save_for_backward(x.shape) shape = tuple(-prod(x.shape) // prod(shape) if s == -1 else s for s in shape) return ctx.movement_op(MovementOps.RESHAPE, x, shape) - @staticmethod def backward(ctx, grad_output): in_shape, = ctx.saved_tensors return ctx.movement_op(MovementOps.RESHAPE, grad_output, in_shape) class Permute(Function): - @staticmethod def forward(ctx, x, order=(1,0)): ctx.save_for_backward(order) return ctx.movement_op(MovementOps.PERMUTE, x, order) - @staticmethod def backward(ctx, grad_output): order, = ctx.saved_tensors norder = np.argsort(order).tolist() return ctx.movement_op(MovementOps.PERMUTE, grad_output, norder) class Slice(Function): - @staticmethod def forward(ctx, x, arg=None): ctx.save_for_backward(x.shape, arg) return ctx.movement_op(MovementOps.SLICE, x, arg) - @staticmethod def backward(ctx, grad_output): shape, arg = ctx.saved_tensors narg = [(0-p[0], grad_output.shape[i]+(shape[i]-p[1])) for i,p in enumerate(arg)] @@ -193,13 +167,11 @@ class Slice(Function): # ************* processing ops ************* class Conv2D(Function): - @staticmethod def forward(ctx, x, w, stride=1, groups=1, dilation=1, padding=0): C = get_conv_args(x.shape, w.shape, stride, groups, dilation=dilation, padding=padding) ctx.save_for_backward(x,w,C) return ctx.processing_op(ProcessingOps.CONV, x, w, C) - @staticmethod def backward(ctx, grad_output): x, w, C = ctx.saved_tensors dx, dw = None, None diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 79f25233fb..3e547ed07c 100644 --- a/tinygrad/ops.py +++ 
b/tinygrad/ops.py @@ -51,14 +51,14 @@ def log_op(optype, op, ret, inp): class Ops: def unary_op(ctx, op:UnaryOps, x): - ret = ctx.op.unary_op(op, x) + ret = x.unary_op(op) if 'LAZY' not in ctx.device: log_op(UnaryOps, op, ret, [x]) assert isinstance(ret, ctx.buffer) assert ret.shape == x.shape return ret def reduce_op(ctx, op:ReduceOps, x, new_shape): - ret = ctx.op.reduce_op(op, x, tuple(new_shape)) + ret = x.reduce_op(op, tuple(new_shape)) if 'LAZY' not in ctx.device: log_op(ReduceOps, op, ret, [x]) assert isinstance(ret, ctx.buffer) assert ret.shape == tuple(new_shape) @@ -66,14 +66,14 @@ class Ops: def binary_op(ctx, op:BinaryOps, x, y): assert x.shape == y.shape - ret = ctx.op.binary_op(op, x, y) + ret = x.binary_op(op, y) if 'LAZY' not in ctx.device: log_op(BinaryOps, op, ret, [x, y]) assert isinstance(ret, ctx.buffer) assert ret.shape == x.shape return ret def movement_op(ctx, op:MovementOps, x, arg): - ret = ctx.op.movement_op(op, x, tuple(arg)) + ret = x.movement_op(op, tuple(arg)) if 'LAZY' not in ctx.device: log_op(MovementOps, op, ret, [x]) assert isinstance(ret, ctx.buffer) # this check is slow @@ -81,7 +81,7 @@ class Ops: return ret def processing_op(ctx, op:ProcessingOps, x, y, C): - ret = ctx.op.processing_op(op, x, y, C) + ret = x.processing_op(op, y, C) if 'LAZY' not in ctx.device: log_op(ProcessingOps, op, ret, [x, y]) assert isinstance(ret, ctx.buffer) assert ret.shape == C.out_shape diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 16c8904a7b..d9f0acb369 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -111,7 +111,7 @@ class Tensor: if not any(x.requires_grad for x in t0._ctx.parents): continue assert (t0.grad is not None) - grads = t0._ctx.backward(t0._ctx, t0.grad.data) + grads = t0._ctx.backward(t0.grad.data) grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None for g in ([grads] if len(t0._ctx.parents) == 1 else grads)] for t, g in zip(t0._ctx.parents, grads): @@ -380,7 +380,6 @@ class Function(Ops): self.saved_tensors = [] buffer = property(lambda self: Device.buffers[self.device]) - op = property(lambda self: Device.llops[self.device]) def save_for_backward(self, *x): if self.requires_grad: @@ -389,7 +388,7 @@ class Function(Ops): @classmethod def apply(cls, *x:List[Tensor], **kwargs): ctx = cls(x[0].device, *x) - ret = Tensor(cls.forward(ctx, *[t.data for t in x], **kwargs), + ret = Tensor(ctx.forward(*[t.data for t in x], **kwargs), device=ctx.device, requires_grad=ctx.requires_grad) if ctx.requires_grad: ret._ctx = ctx # used by autograd engine
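
Editor's note (not part of the patch): the core of this change is that low-level ops stop being module-level functions reached through the context (`ctx.op.unary_op(op, x)`) and become methods on each backend's buffer class, so the device-agnostic layer in tinygrad/ops.py can simply call `x.unary_op(op)`. The sketch below illustrates that pattern only; `ToyBuffer`, the trimmed-down op enums, and the `__main__` demo are invented for illustration and are not code from this patch or from tinygrad.

from enum import Enum
import numpy as np

UnaryOps = Enum("UnaryOps", ["NOOP", "RELU", "NEG"])
BinaryOps = Enum("BinaryOps", ["ADD", "MUL"])

class ToyBuffer(np.ndarray):
  # mirrors CPUBuffer.fromCPU: wrap a numpy array in the buffer subclass
  @staticmethod
  def fromCPU(x): return x.view(ToyBuffer)

  # after the patch, ops live on the buffer, so callers write x.unary_op(op)
  def unary_op(x, op):
    if op == UnaryOps.NOOP: return x[:]
    elif op == UnaryOps.RELU: return np.maximum(x, 0)
    elif op == UnaryOps.NEG: return -x
    else: raise Exception(f"{op} isn't supported")

  def binary_op(x, op, y):
    if op == BinaryOps.ADD: return x+y
    elif op == BinaryOps.MUL: return x*y
    else: raise Exception(f"{op} isn't supported")

# the device-agnostic layer no longer needs ctx.op: any buffer class that
# defines these methods plugs in, which is why the lazy, CPU, GPU, and torch
# backends can all share the same dispatch code
def unary_op(op, x): return x.unary_op(op)
def binary_op(op, x, y): return x.binary_op(op, y)

if __name__ == "__main__":
  a = ToyBuffer.fromCPU(np.array([-1., 2., -3.], dtype=np.float32))
  b = ToyBuffer.fromCPU(np.array([4., 5., 6.], dtype=np.float32))
  print(unary_op(UnaryOps.RELU, a))      # [0. 2. 0.]
  print(binary_op(BinaryOps.ADD, a, b))  # [3. 7. 3.]

The same move explains the other bullets in the commit message: mlops Function.forward/backward drop @staticmethod so they are called on the ctx instance, GPUBuffer._processing_op becomes a method on the output buffer, and GPUBuffer.cl becomes a property that allocates the OpenCL buffer lazily on first access instead of in __init__.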