From e8ca3ad05361cb4fa0317c26335f3020f845fb3a Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Wed, 17 Nov 2021 12:46:07 -0800
Subject: [PATCH] add Buffer abstraction for each device

---
 tinygrad/ops_cpu.py   | 21 ++++++++++
 tinygrad/ops_gpu.py   | 75 ++++++++++++++++++++++++----------
 tinygrad/ops_torch.py |  7 ++++
 tinygrad/tensor.py    | 95 ++++++++-----------------------------
 4 files changed, 98 insertions(+), 100 deletions(-)

diff --git a/tinygrad/ops_cpu.py b/tinygrad/ops_cpu.py
index 4c3d251e48..75825933ca 100644
--- a/tinygrad/ops_cpu.py
+++ b/tinygrad/ops_cpu.py
@@ -1,6 +1,27 @@
 import numpy as np
 from .tensor import Function
 
+class CPUBuffer(np.ndarray):
+  def log(x):
+    return np.log(x)
+  def exp(x):
+    return np.exp(x)
+  def relu(x):
+    return np.maximum(x, 0)
+  def expand(x, shp):
+    return np.broadcast_to(x, shp)
+  def amax(x, *args, **kwargs):
+    return np.amax(x, *args, **kwargs)
+  def permute(x, order):
+    return x.transpose(order)
+  def type(x, tt):
+    return x.astype(tt)
+  def toCPU(x):
+    return x
+  @staticmethod
+  def fromCPU(x):
+    return x
+
 # ************* unary ops *************
 
 class ReLU(Function):
diff --git a/tinygrad/ops_gpu.py b/tinygrad/ops_gpu.py
index 3f573a6981..e681687961 100644
--- a/tinygrad/ops_gpu.py
+++ b/tinygrad/ops_gpu.py
@@ -1,13 +1,44 @@
 import functools
 import pyopencl as cl
 import numpy as np
-from .tensor import Function, GPUBuffer
+from .tensor import Function
+
+cl_ctx, cl_queue = None, None
+def require_init_gpu():
+  global cl_ctx, cl_queue
+  if cl_queue is None:
+    devices = cl.get_platforms()[0].get_devices(device_type=cl.device_type.GPU)
+    if len(devices) == 0:
+      devices = cl.get_platforms()[0].get_devices(device_type=cl.device_type.CPU)
+    cl_ctx = cl.Context(devices=devices)
+    # this is an in-order command queue
+    cl_queue = cl.CommandQueue(cl_ctx)
+require_init_gpu()
+
+class GPUBuffer:
+  def __init__(self, shape, hostbuf=None):
+    self.shape, self.dtype = tuple(shape), np.float32
+    self.cl = hostbuf.cl if isinstance(hostbuf, GPUBuffer) else \
+      cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE | (cl.mem_flags.COPY_HOST_PTR if hostbuf is not None else 0), 4*np.prod(shape),
+                hostbuf=hostbuf.astype(np.float32).ravel() if hostbuf is not None else None)
+
+  def __repr__(self):
+    return f"<GPUBuffer with shape {self.shape!r}>"
+
+  @staticmethod
+  def fromCPU(x):
+    return GPUBuffer(x.shape, x.view(np.ndarray))
+
+  def toCPU(self):
+    data = np.empty(self.shape, dtype=np.float32)
+    cl.enqueue_copy(cl_queue, data, self.cl, is_blocking=True)
+    return data
 
 def buffer_new(ctx, shape, zero=False):
   return GPUBuffer(shape, hostbuf=None if not zero else np.zeros(shape, dtype=np.float32))
 
 def buffer_np(ctx, x):
-  return cl.Buffer(ctx.cl_ctx, cl.mem_flags.READ_WRITE | cl.mem_flags.COPY_HOST_PTR, hostbuf=x)
+  return cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE | cl.mem_flags.COPY_HOST_PTR, hostbuf=x)
 
 @functools.lru_cache
 def clbuild(cl_ctx, name, prg):
@@ -21,13 +52,13 @@ i32 = np.int32
 
 def unary_op(ctx, code, x):
   ret = buffer_new(ctx, x.shape)
-  unop = clbuild(ctx.cl_ctx, "unop", """
+  unop = clbuild(cl_ctx, "unop", """
   __kernel void unop(__global const float *a_g, __global float *res_g) {
     int gid = get_global_id(0);
     float a = a_g[gid];
     res_g[gid] = """+code+""";
   }""")
-  unop(ctx.cl_queue, [np.prod(ret.shape)], None, x.cl, ret.cl)
+  unop(cl_queue, [np.prod(ret.shape)], None, x.cl, ret.cl)
   return ret
 
 class ReLU(Function):
@@ -72,7 +103,7 @@ def reduce_op(ctx, code, code2, inp, axis=None, start="0.0"):
     ret.shape = (1,)
 
   # TODO: this is insanely slow
-  reduce = clbuild(ctx.cl_ctx, "reduce", """
+  reduce = clbuild(cl_ctx, "reduce", """
   __kernel void reduce(__global const float *a_g, int sz, __global float *res_g, int prod, int n_dims,
                        __global const int *shape_x, __global const int *shape_ret) {
     int gid = get_global_id(0);
@@ -97,7 +128,7 @@ def reduce_op(ctx, code, code2, inp, axis=None, start="0.0"):
     }
     res_g[gid] = """+code2+""";
   }""")
-  reduce(ctx.cl_queue, [np.prod(osize)], None, inp.cl,
+  reduce(cl_queue, [np.prod(osize)], None, inp.cl,
     i32(np.prod(inp.shape)//np.prod(osize)), ret.cl,
     i32(np.prod(osize)), i32(len(osize)),
     buffer_np(ctx, np.array(inp.shape, dtype=np.int32)),
@@ -174,10 +205,10 @@ def binary_op(ctx, code, x, y):
   for i in range(n_dims): # group together any adjacent dimensions that we can to simplify broadcasting
     push(i32(max(shape_x[i], shape_y[i])), (shape_x[i] > 1, shape_y[i] > 1))
 
-  prg = get_binop_prg(ctx.cl_ctx, code, tuple(complist))
+  prg = get_binop_prg(cl_ctx, code, tuple(complist))
   ret = buffer_new(ctx, shape_ret, zero=True)
   prod_list = np.array(dimlist, dtype=i32)[-1::-1].cumprod(dtype=i32)[-1::-1] # take cumprod from back to front
-  prg.binop(ctx.cl_queue, [prod_list[0]] if len(dimlist) > 0 else [1], None, x.cl, y.cl, ret.cl, *dimlist, *(prod_list[1:]))
+  prg.binop(cl_queue, [prod_list[0]] if len(dimlist) > 0 else [1], None, x.cl, y.cl, ret.cl, *dimlist, *(prod_list[1:]))
   return ret
 
 def unbroadcast(ctx, out, in_sh):
@@ -245,7 +276,7 @@ class Reshape(Function):
 def perm_axis(ctx, inp, order):
   osize = np.array(inp.shape)[list(order)]
   ret = buffer_new(ctx, osize)
-  perm = clbuild(ctx.cl_ctx, "perm", """
+  perm = clbuild(cl_ctx, "perm", """
   __kernel void perm(__global const float *a_g, __global float *res_g, int n_axis,
                      __global const int *shape, __global const int *order) {
     int gid = get_global_id(0);
@@ -259,7 +290,7 @@ def perm_axis(ctx, inp, order):
     }
     res_g[gid] = a_g[idx];
   }""")
-  perm(ctx.cl_queue, [np.prod(osize)], None, inp.cl, ret.cl, i32(len(osize)),
+  perm(cl_queue, [np.prod(osize)], None, inp.cl, ret.cl, i32(len(osize)),
     buffer_np(ctx, np.array(inp.shape, dtype=np.int32)),
     buffer_np(ctx, np.array(order, dtype=np.int32)))
   return ret
@@ -277,7 +308,7 @@ def inner_slice(ctx, x, arg):
   shift = [y[0] for y in arg]
   oshape = [y[1]-y[0] for y in arg]
   ret = buffer_new(ctx, oshape)
-  gslice = clbuild(ctx.cl_ctx, "gslice", """
+  gslice = clbuild(cl_ctx, "gslice", """
  __kernel void gslice(__global const float *input, __global float *output, int prod, int n_dims,
                        __global const int *shape_x, __global const int *shape_ret,
                        __global const int *shift) {
@@ -292,7 +323,7 @@ def inner_slice(ctx, x, arg):
     }
     output[gid] = zero ? input[iptr] : 0.0;
   }""")
-  gslice(ctx.cl_queue, [np.prod(ret.shape)], None,
+  gslice(cl_queue, [np.prod(ret.shape)], None,
     x.cl, ret.cl, i32(np.prod(ret.shape)), i32(len(ret.shape)),
     buffer_np(ctx, np.array(x.shape, dtype=np.int32)), buffer_np(ctx, np.array(ret.shape, dtype=np.int32)),
     buffer_np(ctx, np.array(shift, dtype=np.int32)))
@@ -318,7 +349,7 @@ class Matmul(Function):
     isize, msize, osize = i32(input.shape[-2]), i32(input.shape[-1]), i32(weight.shape[-1])
     ret = buffer_new(ctx, list(input.shape[0:-2])+[isize, osize])
 
-    matmul = clbuild(ctx.cl_ctx, "matmul", """
+    matmul = clbuild(cl_ctx, "matmul", """
     __kernel void matmul(
       __global const float *input, __global const float *weight, __global float *res,
       int isize, int is0, int is1, int msize, int ws0, int ws1, int osize
@@ -339,7 +370,7 @@ class Matmul(Function):
     ctx.save_for_backward(input, weight, matmul, cnt)
 
     # (isize,msize) x (msize,osize) = (isize,osize)
-    matmul(ctx.cl_queue, [isize, osize, cnt], None,
+    matmul(cl_queue, [isize, osize, cnt], None,
       input.cl, weight.cl, ret.cl,
       isize, msize, i32(1), msize, i32(1), osize, osize)
     return ret
@@ -352,12 +383,12 @@ class Matmul(Function):
     grad_weight = buffer_new(ctx, weight.shape)
 
     # (isize,osize) x (msize,osize) = (isize,msize)
-    matmul(ctx.cl_queue, [isize, msize, cnt], None,
+    matmul(cl_queue, [isize, msize, cnt], None,
      grad_output.cl, weight.cl, grad_input.cl,
      isize, osize, i32(1), osize, osize, i32(1), msize)
 
     # (isize,msize) x (isize,osize) = (msize,osize)
-    matmul(ctx.cl_queue, [msize, osize, cnt], None,
+    matmul(cl_queue, [msize, osize, cnt], None,
      input.cl, grad_output.cl, grad_weight.cl,
      msize, i32(1), msize, isize, i32(1), osize, osize)
 
@@ -383,7 +414,7 @@ class Conv2D(Function):
     # weight = (groups, rcout, cin, H, W)
     # output = (bs, groups, rcout, oy, ox)
 
-    conv = clbuild(ctx.cl_ctx, "conv", """
+    conv = clbuild(cl_ctx, "conv", """
     __kernel void conv(__global const float *input, __global const float *weight, __global float *output,
      int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs) {
@@ -408,7 +439,7 @@ class Conv2D(Function):
       output[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] = acc;
     }""")
 
-    conv(ctx.cl_queue, [bs*groups*rcout, oy, ox], None,
+    conv(cl_queue, [bs*groups*rcout, oy, ox], None,
       x.cl, w.cl, ret.cl,
      i32(H), i32(W), i32(groups), i32(rcout), i32(cin), i32(oy), i32(ox), i32(iy), i32(ix), i32(ys), i32(xs)
     )
@@ -433,7 +464,7 @@ class Conv2D(Function):
     # tensw = (groups*rcout, cin, H, W)
     # ggg = (bs, groups*rout, oy, ox)
 
-    convw = clbuild(ctx.cl_ctx, "convw", """
+    convw = clbuild(cl_ctx, "convw", """
     __kernel void convw(__global const float *tensx, __global const float *ggg, __global float *dw,
      int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
@@ -454,7 +485,7 @@ class Conv2D(Function):
      }
      dw[get_global_id(0)*H*W + y*W + x] = acc;
    }""")
-    convx = clbuild(ctx.cl_ctx, "convx", """
+    convx = clbuild(cl_ctx, "convx", """
    __kernel void convx(__global const float *tensw, __global const float *ggg, __global float *dx,
      int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
@@ -480,6 +511,6 @@
     """)
 
     conv_args = i32(H), i32(W), i32(ctx.groups), i32(rcout), i32(cin), i32(oy), i32(ox), i32(iy), i32(ix), i32(ys), i32(xs), i32(bs)
-    convw(ctx.cl_queue, [ctx.groups*rcout*cin, H, W], None, x.cl, grad_output.cl, dw.cl, *conv_args)
-    convx(ctx.cl_queue, [bs, ctx.groups, cin], None, w.cl, grad_output.cl, dx.cl, *conv_args)
+    convw(cl_queue, [ctx.groups*rcout*cin, H, W], None, x.cl, grad_output.cl, dw.cl, *conv_args)
+    convx(cl_queue, [bs, ctx.groups, cin], None, w.cl, grad_output.cl, dx.cl, *conv_args)
     return dx, dw
diff --git a/tinygrad/ops_torch.py b/tinygrad/ops_torch.py
index dc17c2a594..a6677f090b 100644
--- a/tinygrad/ops_torch.py
+++ b/tinygrad/ops_torch.py
@@ -2,6 +2,13 @@ import torch
 import numpy as np
 from .tensor import Function
 
+class TorchBuffer(torch.Tensor):
+  @staticmethod
+  def fromCPU(data):
+    return TorchBuffer(torch.from_numpy(data).requires_grad_(False))
+  def toCPU(x):
+    return x.numpy()
+
 # ************* unary+binary+reduce ops *************
 
 from tinygrad.ops_cpu import ReLU, Log, Exp, Add, Sub, Mul, Pow, Sum, Max
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index e0634fad74..674196ad91 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -24,40 +24,17 @@ class ProfileOp:
     return self
   def __exit__(self, *junk):
     if DEBUG:
-      if cl_queue is not None:
-        cl_queue.finish()
+      # TODO: fix this
+      #if cl_queue is not None:
+      #  cl_queue.finish()
       et = (time.time()-self.st)*1000.
       debug_counts[self.name] += 1
       debug_times[self.name] += et
       print(f"{self.name:>20} : {et:>7.2f} ms {str([y.shape for y in self.x]):>40} {'-> '+str(self.output.shape) if self.output is not None else ''}")
 
-# **** GPU functions ****
-
-cl_ctx, cl_queue = None, None
-def require_init_gpu():
-  if not GPU: raise Exception("No GPU Support, install pyopencl")
-  global cl_ctx, cl_queue
-  if cl_queue is None:
-    devices = cl.get_platforms()[0].get_devices(device_type=cl.device_type.GPU)
-    if len(devices) == 0:
-      devices = cl.get_platforms()[0].get_devices(device_type=cl.device_type.CPU)
-    cl_ctx = cl.Context(devices=devices)
-    # this is an in-order command queue
-    cl_queue = cl.CommandQueue(cl_ctx)
-
-class GPUBuffer:
-  def __init__(self, shape, hostbuf=None):
-    self.shape, self.dtype = tuple(shape), np.float32
-    self.cl = hostbuf.cl if isinstance(hostbuf, GPUBuffer) else \
-      cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE | (cl.mem_flags.COPY_HOST_PTR if hostbuf is not None else 0), 4*np.prod(shape),
-                hostbuf=hostbuf.astype(np.float32).ravel() if hostbuf is not None else None)
-
-  def __repr__(self):
-    return f"<GPUBuffer with shape {self.shape!r}>"
-
 # **** start with two base classes, Tensor and Function ****
 
-class Device: CPU, GPU, TORCH = 0, 1, 2
+class Device: CPU, GPU, TORCH, buffers = 0, 1, 2, {}
 
 DEFAULT_DEVICE = (Device.CPU if os.environ.get("GPU", 0) != "1" else Device.GPU) if os.environ.get("TORCH", 0) != "1" else Device.TORCH
 
@@ -148,58 +125,22 @@ class Tensor:
       gt = Tensor(g, device=self.device, requires_grad=False)
       t.grad = gt if t.grad is None else (t.grad + gt)
 
-  # ***** tinygrad supports CPU and GPU *****
+  # ***** tinygrad supports many devices *****
 
   @staticmethod
   def _move_data(data, device):
-    if isinstance(data, GPUBuffer):
-      if device == Device.GPU: return data
-      old = data
-      data = np.empty(old.shape, dtype=np.float32)
-      with ProfileOp("toCPU", [data]):
-        cl.enqueue_copy(cl_queue, data, old.cl, is_blocking=True)
-
-    if str(type(data)).startswith("torch"):
-      data = data.numpy()
-
-    if not isinstance(data, np.ndarray):
-      data = np.array(data, dtype=np.float32)
+    if isinstance(data, np.ndarray):
+      data = data.view(Device.buffers[Device.CPU])
+    if isinstance(data, Device.buffers[device]):
+      return data
 
     if data.dtype != np.float32 and not Tensor.did_float_warning:
      # warning? float64 is actually needed for numerical jacobian
      print(f"warning, {data.shape!r} isn't float32, it's {data.dtype}")
      Tensor.did_float_warning = True
 
-    if device == Device.CPU:
-      # add these functions to ndarray
-      class CPUBuffer(np.ndarray):
-        def log(x):
-          return np.log(x)
-        def exp(x):
-          return np.exp(x)
-        def relu(x):
-          return np.maximum(x, 0)
-        def expand(x, shp):
-          return np.broadcast_to(x, shp)
-        def amax(x, *args, **kwargs):
-          return np.amax(x, *args, **kwargs)
-        def permute(x, order):
-          return x.transpose(order)
-        def type(x, tt):
-          return x.astype(tt)
-      data = data.view(CPUBuffer)
-
-    if device == Device.GPU:
-      require_init_gpu()
-      with ProfileOp("toGPU", [data]):
-        return GPUBuffer(data.shape, data)
-
-    if device == Device.TORCH:
-      import torch
-      with ProfileOp("toTORCH", [data]):
-        return torch.from_numpy(data).requires_grad_(False)
-
-    return data
+    data = data.toCPU().view(Device.buffers[Device.CPU])
+    return Device.buffers[device].fromCPU(data)
 
   def to_(self, device):
     self.data, self.device = self._move_data(self.data, device), device
@@ -345,7 +286,8 @@ def register(name, fxn, device=Device.CPU):
     tt = [arg for arg in x if isinstance(arg, Tensor)][0]
     x = [Tensor(np.array([arg], dtype=tt.dtype), device=tt.device, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
     f = Tensor.ops[tt.device][name]
-    f.cl_ctx, f.cl_queue, f.device = cl_ctx, cl_queue, tt.device
+    #f.cl_ctx, f.cl_queue, f.device = cl_ctx, cl_queue, tt.device
+    f.device = tt.device
     return f.apply(f, *x, **kwargs)
   setattr(Tensor, name, dispatch)
   if name in ['add', 'sub', 'mul', 'pow', 'matmul']:
@@ -360,21 +302,18 @@ for device in [device for device in Device.__dict__.keys() if device[0] != "_"]:
 # this registers all the operations
 def _register_ops(namespace, device=Device.CPU):
   for name, cls in inspect.getmembers(namespace, inspect.isclass):
-    if name[0] != "_": register(name.lower(), cls, device=device)
+    if name.endswith("Buffer"): Device.buffers[device] = cls
+    elif name[0] != "_": register(name.lower(), cls, device=device)
 
+# TODO: refactor this
 from tinygrad import ops_cpu
 _register_ops(ops_cpu)
 try:
-  import pyopencl as cl
-  # TODO: move this import to require_init_gpu?
   from tinygrad import ops_gpu
   _register_ops(ops_gpu, device=Device.GPU)
-  GPU = True
 except ImportError:
-  # no GPU support
-  GPU = False
+  pass
 try:
-  import torch
   from tinygrad import ops_torch
   _register_ops(ops_torch, device=Device.TORCH)
 except ImportError:
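
The patch itself carries no usage notes, so here is a minimal sketch (not part of the commit) of how the new per-device Buffer protocol is exercised from user code. It sticks to the CPU backend so it runs without pyopencl or torch, and it assumes the rest of tinygrad at this revision behaves as the surrounding code suggests (in particular that Tensor.__init__ routes its data through Tensor._move_data):

    import numpy as np
    from tinygrad.tensor import Tensor, Device

    # _move_data views any incoming ndarray as the buffer class registered for the CPU device
    x = Tensor(np.arange(6, dtype=np.float32).reshape(2, 3))
    print(type(x.data).__name__)        # CPUBuffer, the np.ndarray subclass from ops_cpu.py
    print(Device.buffers[Device.CPU])   # the class _register_ops stored for Device.CPU

    # a device move is now just toCPU() on the old buffer plus fromCPU() on the target buffer
    x.to_(Device.CPU)                   # data is already an instance of the target buffer, so this is a no-op
    host = x.data.toCPU()               # every buffer can hand back host (numpy) data
    assert isinstance(host, np.ndarray)
    # with pyopencl installed, x.to_(Device.GPU) would route through GPUBuffer.fromCPU(...) instead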