Single ReLU in ANE (#188)

* aneworks

* cleanup
This commit is contained in:
George Hotz
2020-12-12 16:11:34 -08:00
committed by GitHub
parent 07ece2105e
commit da873cd556
6 changed files with 87 additions and 35 deletions

View File

@@ -1,10 +1,13 @@
#!/usr/bin/env python3
import os
from ctypes import *
import numpy as np
import faulthandler
faulthandler.enable()
libane = cdll.LoadLibrary("libane.dylib")
libane = cdll.LoadLibrary(os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"libane.dylib"))
libane.ANE_Compile.argtypes = [c_char_p, c_int]
libane.ANE_Compile.restype = c_void_p
@@ -46,6 +49,9 @@ class ANE:
def run(self, prog, tin, tout):
  # Execute a compiled ANE program. `prog` is the handle returned by
  # ANE_Compile; `tin`/`tout` are tensor wrappers whose `.tt` attribute is
  # the underlying handle the native library expects (presumably
  # ANETensor, as produced by tensor() below — confirm).
  libane.ANE_Run(prog, tin.tt, tout.tt)
def tensor(self, shape):
  # Allocate a new ANETensor with the given shape.
  return ANETensor(shape)
if __name__ == "__main__":
ane = ANE()

BIN
ane/ops/relu.hwx Normal file

Binary file not shown.

9
examples/use_ane.py Executable file
View File

@@ -0,0 +1,9 @@
#!/usr/bin/env python3
# Minimal demo of running a single op on the Apple Neural Engine:
# move a tensor to the ANE, apply ReLU there, and copy the results back.
import numpy as np
from tinygrad.tensor import Tensor

# ane_() moves the tensor's data into an ANE-backed buffer in place.
a = Tensor([-2,-1,0,1,2]).ane_()
# cpu() copies the data back into a numpy-backed Tensor for printing.
print(a.cpu())

# relu() dispatches to the implementation registered for this tensor's device.
b = a.relu()
print(b.cpu())

View File

@@ -1,3 +1,16 @@
if __name__ == "__main__":
pass
from .tensor import Tensor, Function, register
from functools import lru_cache
@lru_cache(maxsize=None)
def compile_wrapper(ane, dat):
  """Compile `dat` via `ane.compile`, memoizing per (ane, dat) pair.

  Repeated calls with identical arguments return the previously compiled
  program object instead of recompiling the blob.
  """
  return ane.compile(dat)
class ReLU(Function):
  """ReLU forward pass on the Apple Neural Engine.

  Runs a precompiled ANE program (ane/ops/relu.hwx) against the input
  tensor. No backward is defined for the ANE device in this file.
  """
  @staticmethod
  def forward(ctx, input):
    # Output tensor allocated on the ANE with the same shape as the input.
    ret = ctx.ane.tensor(input.shape)
    # `with` closes the file handle (the original `open(...).read()` leaked
    # it); compile_wrapper memoizes compilation of the blob.
    # NOTE(review): path is relative to the CWD — fragile; consider
    # resolving it against this file's directory.
    with open("ane/ops/relu.hwx", "rb") as f:
      comp = compile_wrapper(ctx.ane, f.read())
    ctx.ane.run(comp, input, ret)
    return ret

register('relu', ReLU, device=Tensor.ANE)

View File

@@ -1,5 +1,5 @@
import numpy as np
from .tensor import Function, register, GPUBuffer
from .tensor import Function, register, GPUBuffer, Tensor
import pyopencl as cl
import functools
@@ -194,7 +194,7 @@ class Add(Function):
grad_x, grad_y = grad_output, grad_output
shape_x, shape_y = ctx.saved_tensors
return unbroadcast(ctx, grad_x, shape_x), unbroadcast(ctx, grad_y, shape_y),
register('add', Add, gpu=True)
register('add', Add, device=Tensor.GPU)
class Sub(Function):
@staticmethod
@@ -207,7 +207,7 @@ class Sub(Function):
grad_x, grad_y = grad_output, unary_op(ctx, '-a', grad_output)
shape_x, shape_y = ctx.saved_tensors
return unbroadcast(ctx, grad_x, shape_x), unbroadcast(ctx, grad_y, shape_y),
register('sub', Sub, gpu=True)
register('sub', Sub, device=Tensor.GPU)
class Mul(Function):
@staticmethod
@@ -221,7 +221,7 @@ class Mul(Function):
grad_x = binary_op(ctx, 'a*b', y, grad_output)
grad_y = binary_op(ctx, 'a*b', x, grad_output)
return unbroadcast(ctx, grad_x, x.shape), unbroadcast(ctx, grad_y, y.shape),
register('mul', Mul, gpu=True)
register('mul', Mul, device=Tensor.GPU)
class Pow(Function):
@staticmethod
@@ -237,7 +237,7 @@ class Pow(Function):
grad_y = binary_op(ctx, 'a*b', grad_output,
binary_op(ctx, 'pow(a, (float)b) * log(a);', x, y))
return unbroadcast(ctx, grad_x, x.shape), unbroadcast(ctx, grad_y, y.shape),
register('pow', Pow, gpu=True)
register('pow', Pow, device=Tensor.GPU)
class Sum(Function):
@staticmethod
@@ -254,7 +254,7 @@ class Sum(Function):
shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
output = GPUBuffer(shape, hostbuf=grad_output)
return binary_op(ctx, 'a+b', output, buffer_zeros(ctx, input.shape))
register('sum', Sum, gpu=True)
register('sum', Sum, device=Tensor.GPU)
class Dot(Function):
@staticmethod
@@ -308,8 +308,8 @@ class Dot(Function):
i32(1), msize, isize, i32(1), osize, osize)
return grad_input, grad_weight
register('dot', Dot, gpu=True)
register('matmul', Dot, gpu=True)
register('dot', Dot, device=Tensor.GPU)
register('matmul', Dot, device=Tensor.GPU)
# ************* simple ops *************
@@ -352,7 +352,7 @@ class Pad2D(Function):
i32(oy), i32(ox), i32(iy), i32(ix)
)
return ret
register('pad2d', Pad2D, gpu=True)
register('pad2d', Pad2D, device=Tensor.GPU)
class Reshape(Function):
@staticmethod
@@ -368,7 +368,7 @@ class Reshape(Function):
in_shape, = ctx.saved_tensors
grad_output = GPUBuffer(in_shape, hostbuf=grad_output)
return grad_output
register('reshape', Reshape, gpu=True)
register('reshape', Reshape, device=Tensor.GPU)
# ************* activation ops *************
@@ -382,7 +382,7 @@ class ReLU(Function):
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return binary_op(ctx, 'a * (b >= 0)', grad_output, input)
register('relu', ReLU, gpu=True)
register('relu', ReLU, device=Tensor.GPU)
class Sigmoid(Function):
@staticmethod
@@ -395,7 +395,7 @@ class Sigmoid(Function):
def backward(ctx, grad_output):
ret, = ctx.saved_tensors
return binary_op(ctx, 'a * (b * (1 - b));', grad_output, ret)
register('sigmoid', Sigmoid, gpu=True)
register('sigmoid', Sigmoid, device=Tensor.GPU)
class AvgPool2D(Function):
@staticmethod
@@ -410,7 +410,7 @@ class AvgPool2D(Function):
orig_shape, = ctx.saved_tensors
return supersample_op(ctx, grad_output, orig_shape, ctx.kernel_size,
result_op="input[iid] / (ksz.x * ksz.y)")
register('avg_pool2d', AvgPool2D, gpu=True)
register('avg_pool2d', AvgPool2D, device=Tensor.GPU)
class MaxPool2D(Function):
@staticmethod
@@ -430,7 +430,7 @@ class MaxPool2D(Function):
result_op="(maxidx == kernidx) * input[iid]",
decls="int maxidx=((__global float*)input2)[iid]; int kernidx=(gid.x%ksz.x) + ksz.x*(gid.y%ksz.y)",
input2=idxs)
register('max_pool2d', MaxPool2D, gpu=True)
register('max_pool2d', MaxPool2D, device=Tensor.GPU)
class LogSoftmax(Function):
@staticmethod
@@ -447,7 +447,7 @@ class LogSoftmax(Function):
lsum = reduce_op(ctx, "out += a", "out", grad_output, axis=[1])
texp = binary_op(ctx, "exp(a) * b", output, lsum)
return binary_op(ctx, "a - b", grad_output, texp)
register('logsoftmax', LogSoftmax, gpu=True)
register('logsoftmax', LogSoftmax, device=Tensor.GPU)
# ************* conv ops *************
@@ -572,4 +572,4 @@ class Conv2D(Function):
convw(ctx.cl_queue, [ctx.groups*rcout*cin, H, W], None, x.cl, grad_output.cl, dw.cl, *conv_args)
convx(ctx.cl_queue, [bs, ctx.groups, cin], None, w.cl, grad_output.cl, dx.cl, *conv_args)
return dx, dw
register('conv2d', Conv2D, gpu=True)
register('conv2d', Conv2D, device=Tensor.GPU)

View File

@@ -2,6 +2,7 @@
from inspect import signature
import numpy as np
import os
from collections import defaultdict
# **** profiler ****
@@ -53,20 +54,30 @@ class GPUBuffer:
def __repr__(self):
  # Debug representation; only the shape is shown (the buffer's contents
  # live on the GPU and are not fetched here).
  return f"<GPUBuffer with shape {self.shape!r}>"
# **** ANE functions ****
ane = None
def require_init_ane():
  """Lazily construct the module-level ANE handle on first use.

  Safe to call repeatedly; a no-op once `ane` is set. Importing
  tinygrad.ops_ane is done for its side effect of registering the ANE ops.
  """
  global ane
  if ane is None:
    # from-import avoids the original's transient rebinding of the global
    # `ane`: `import ane.lib.ane` assigns the package object to the global
    # name before the ANE() instance replaces it, which a concurrent
    # reader of `ane` could briefly observe.
    from ane.lib.ane import ANE
    import tinygrad.ops_ane  # noqa: F401 -- registers ANE ops on import
    ane = ANE()
# **** start with two base classes, Tensor and Function ****
class Tensor:
did_float_warning = False
default_gpu = False
ops_cpu, ops_gpu = {}, {}
ops = defaultdict(dict)
CPU, GPU, ANE = 0, 1, 2
def __init__(self, data, gpu=None, requires_grad=True):
if gpu is None:
gpu = Tensor.default_gpu
if isinstance(data, list):
if "ANETensor" in str(type(data)):
self.device = Tensor.ANE
elif isinstance(data, list):
data = np.array(data, dtype=np.float32)
elif GPU and isinstance(data, GPUBuffer):
self.gpu = True
self.device = Tensor.GPU
elif not isinstance(data, np.ndarray):
raise TypeError(f"Error constructing tensor with {data!r}")
@@ -75,7 +86,7 @@ class Tensor:
# warning? float64 is actually needed for numerical jacobian
print(f"warning, {data.shape!r} isn't float32")
Tensor.did_float_warning = True
self.gpu = False
self.device = Tensor.CPU
self.data = data
self.grad = None
@@ -156,19 +167,25 @@ class Tensor:
# ***** tinygrad supports CPU and GPU *****
def cpu(self):
if self.gpu:
if self.device == Tensor.GPU:
with ProfileOp("toCPU", [self]):
ret = Tensor(np.empty(self.shape, dtype=np.float32), gpu=False)
cl.enqueue_copy(cl_queue, ret.data, self.data.cl, is_blocking=True)
if self.grad:
ret.grad = self.grad.cpu()
return ret
elif self.device == Tensor.ANE:
return Tensor(self.data.data().astype(np.float32), gpu=False)
else:
return self
@property
def gpu(self):
  # Backwards-compatibility shim: the old boolean `gpu` flag, now derived
  # from the tri-state `device` field (CPU / GPU / ANE).
  return self.device == Tensor.GPU
def cuda_(self):
self.data = self.cuda().data
self.gpu = True
self.device = Tensor.GPU
def cuda(self):
if not GPU:
@@ -183,6 +200,15 @@ class Tensor:
else:
return self
def ane_(self):
  # In-place move of this tensor's data onto the Apple Neural Engine;
  # returns self for chaining. Only non-GPU tensors may move to the ANE.
  # NOTE(review): assert is stripped under -O; an explicit raise would be
  # safer, but would change the exception type callers might catch.
  assert(not self.gpu)
  require_init_ane()
  self.device = Tensor.ANE
  # Allocate an ANE tensor of the same shape and copy the numpy data in.
  ndata = ane.tensor(self.shape)
  ndata.data()[:] = self.data
  self.data = ndata
  return self
def detach(self):
  # Return a new Tensor wrapping the same underlying data, passing the
  # boolean `gpu` flag through to the constructor.
  # NOTE(review): for an ANE tensor this passes gpu=False and relies on
  # __init__ re-deriving the device from the data's type — verify.
  return Tensor(self.data, self.gpu)
@@ -237,16 +263,13 @@ class Function:
ret._ctx = ctx
return ret
def register(name, fxn, gpu=False):
if gpu:
Tensor.ops_gpu[name] = fxn
else:
Tensor.ops_cpu[name] = fxn
def register(name, fxn, device=Tensor.CPU):
Tensor.ops[device][name] = fxn
def dispatch(*x, **kwargs):
tt = [arg for arg in x if isinstance(arg, Tensor)][0]
x = [Tensor(np.array([arg], dtype=tt.dtype), gpu=tt.gpu, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
f = (Tensor.ops_gpu if tt.gpu else Tensor.ops_cpu)[name]
f.cl_ctx, f.cl_queue = cl_ctx, cl_queue
f = (Tensor.ops[tt.device])[name]
f.cl_ctx, f.cl_queue, f.ane = cl_ctx, cl_queue, ane
return f.apply(f, *x, **kwargs)
setattr(Tensor, name, dispatch)
# TODO: div is a second class op, so it doesn't work here
@@ -259,6 +282,7 @@ def register(name, fxn, gpu=False):
import tinygrad.ops_cpu
try:
import pyopencl as cl
# TODO: move this import to require_init_gpu?
import tinygrad.ops_gpu
GPU = True
except ImportError: