remove ctx from buffers (#333)

This commit is contained in:
George Hotz
2022-06-18 17:27:10 -07:00
committed by GitHub
parent 77f5cef8a6
commit aa164d901e
4 changed files with 56 additions and 40 deletions

View File

@@ -18,7 +18,7 @@ class CPUBuffer(np.ndarray):
def fromCPU(x): return x
def toCPU(x): return x
def unary_op(ctx, op, x):
def unary_op(op, x):
if op == UnaryOps.RELU: return x.relu()
elif op == UnaryOps.EXP: return x.exp()
elif op == UnaryOps.LOG: return x.log()
@@ -26,7 +26,7 @@ def unary_op(ctx, op, x):
elif op == UnaryOps.SIGN: return x.sign()
else: raise Exception(f"{op} isn't supported")
def binary_op(ctx, op, x, y):
def binary_op(op, x, y):
if op == BinaryOps.ADD: return x+y
elif op == BinaryOps.SUB: return x-y
elif op == BinaryOps.MUL: return x*y
@@ -35,7 +35,7 @@ def binary_op(ctx, op, x, y):
elif op == BinaryOps.CMPEQ: return 1.0*(x==y)
else: raise Exception(f"{op} isn't supported")
def reduce_op(ctx, op, inp, new_shape):
def reduce_op(op, inp, new_shape):
if inp.shape == new_shape: # this is just a copy, regardless of the reduce op
return inp[:]
else:
@@ -48,7 +48,7 @@ def reduce_op(ctx, op, inp, new_shape):
elif op == ReduceOps.MAX: return inp.amax(axis, keepdims=True)
else: raise Exception(f"{op} isn't supported")
def movement_op(ctx, op, x, arg=None):
def movement_op(op, x, arg=None):
if op == MovementOps.RESHAPE: return x.reshape(arg)
elif op == MovementOps.PERMUTE: return x.permute(arg)
elif op == MovementOps.FLIP: return x.flip(arg)
@@ -60,7 +60,7 @@ def movement_op(ctx, op, x, arg=None):
elif op == MovementOps.EXPAND: return x.expand(arg)
else: raise Exception(f"{op} isn't supported")
def processing_op(ctx,op,x,w,C):
def processing_op(op,x,w,C):
assert op == ProcessingOps.CONV, f"{op} isn't supported"
if C.px > 0 or C.py > 0: x = np.pad(x, [(0,0), (0,0), (C.py, C.py), (C.px, C.px)])
gx = x.reshape(C.bs,C.groups,C.cin,x.shape[2],x.shape[3])

View File

@@ -22,11 +22,15 @@ def roundup(x, n=4): return (x+(n-1))//n * n
class GPUBuffer:
def __init__(self, shape, hostbuf=None):
require_init_gpu()
self.shape, self.dtype = tuple(shape), np.float32
self.cl = hostbuf.cl if isinstance(hostbuf, GPUBuffer) else cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE, 4*roundup(prod(shape))) # padding
if hostbuf is not None and not isinstance(hostbuf, GPUBuffer):
self.shape = tuple(shape)
self.cl = cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE, 4*roundup(prod(self.shape))) # padding
if hostbuf is not None:
# TODO: this doesn't have to block
cl.enqueue_copy(cl_queue, self.cl, hostbuf.astype(np.float32).ravel())
@property
def dtype(self): return np.float32
def __repr__(self):
return f"<GPUBuffer with shape {self.shape!r}>"
@@ -39,14 +43,21 @@ class GPUBuffer:
cl.enqueue_copy(cl_queue, data, self.cl, is_blocking=True)
return data
@functools.lru_cache
def clbuild(name, prg, options=tuple()):
clprg = cl.Program(cl_ctx, prg).build(options=options).__getattr__(name)
def run(*args): clprg(cl_queue, *args)
return run
class CLProgram:
def __init__(self, name, prg, options, argdtypes):
self.built = cl.Program(cl_ctx, prg).build(options=options)
self.clprg = self.built.__getattr__(name)
if argdtypes is not None: self.clprg.set_scalar_arg_dtypes(argdtypes)
def __call__(self, *args): self.clprg(cl_queue, *args)
def unary_op(ctx, op, x):
ret = ctx.buffer(x.shape)
@functools.lru_cache(maxsize=None)
def clbuild(name, prg, options=tuple(), argdtypes=None):
#print("cache miss")
#print(prg)
return CLProgram(name, prg, options, argdtypes)
def unary_op(op, x):
ret = GPUBuffer(x.shape)
if op == UnaryOps.RELU: code = 'max(a, (float)0.)'
elif op == UnaryOps.EXP: code = 'exp(a)'
elif op == UnaryOps.LOG: code = 'log(a)'
@@ -62,8 +73,8 @@ def unary_op(ctx, op, x):
unop([roundup(prod(ret.shape))//4], None, x.cl, ret.cl)
return ret
def binary_op(ctx, op, x, y):
ret = ctx.buffer(x.shape)
def binary_op(op, x, y):
ret = GPUBuffer(x.shape)
if op == BinaryOps.ADD: code = "a+b"
elif op == BinaryOps.SUB: code = "a-b"
elif op == BinaryOps.MUL: code = "a*b"
@@ -82,8 +93,8 @@ def binary_op(ctx, op, x, y):
binop([roundup(prod(ret.shape))//4], None, x.cl, y.cl, ret.cl)
return ret
def reduce_op(ctx, op, inp, new_shape):
ret = ctx.buffer(new_shape)
def reduce_op(op, inp, new_shape):
ret = GPUBuffer(new_shape)
if op == ReduceOps.SUM: code, start = "out += a", "0.0"
elif op == ReduceOps.MAX: code, start = "out = max(a,out)", "-INFINITY"
else: raise Exception(f"{op} isn't supported")
@@ -116,23 +127,20 @@ def reduce_op(ctx, op, inp, new_shape):
clbuild("reduce", prg)([prod(ret.shape)], None, inp.cl, ret.cl)
return ret
def contiguous(ctx, x, st, ret=None):
if ret is None: ret = ctx.buffer(st.shape)
def contiguous(x, st, ret=None):
if ret is None: ret = GPUBuffer(st.shape)
clbuild("contiguous", """__kernel void contiguous(__global const float *x, __global float *ret) {
int gid = get_global_id(0); int valid = 1; int idx = gid; """+st.expr().replace('//', '/')+""";
ret[gid] = valid ? x[idx] : 0.0; // should never be out-of-bounds accesses
}""")([prod(ret.shape)], None, x.cl, ret.cl)
return ret
def movement_op(ctx, op, x, arg=None):
return contiguous(ctx, x, ShapeTracker(*x.shape).movement_op(op, arg))
def movement_op(op, x, arg=None):
return contiguous(x, ShapeTracker(*x.shape).movement_op(op, arg))
def processing_op(ctx,op,x,w,C):
ret = ctx.buffer((C.bs, C.cout, C.oy, C.ox))
def processing_op(op,x,w,C):
ret = GPUBuffer((C.bs, C.cout, C.oy, C.ox))
assert op == ProcessingOps.CONV, f"{op} isn't supported"
# input = (bs, groups, cin, iy, ix)
# weight = (groups, rcout, cin, H, W)
# output = (bs, groups, rcout, oy, ox)
conv_prg = clbuild("conv", """
__kernel void conv(__global const float *input, __global const float *weight, __global float *output,
int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs, int dx, int dy, int px, int py) {
@@ -146,8 +154,15 @@ def processing_op(ctx,op,x,w,C):
int IY = Y*ys;
int IX = X*xs;
int gid = get_global_id(0)*oy*ox + Y*ox + X;
float acc = 0.0;
for (int ci = 0; ci < cin; ci++) {
#ifdef ONEBYONE
acc += input[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + IY*ix + IX] * \
weight[g*rcout*cin + c*cin + ci];
#else
for (int y = 0; y < H; y++) { for (int x = 0; x < W; x++) {
int idx_y = y*dy + IY - py;
int idx_x = x*dx + IX - px;
@@ -155,9 +170,12 @@ def processing_op(ctx,op,x,w,C):
acc += valid ? input[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + idx_y*ix + idx_x] * \
weight[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x] : 0.0;
} }
#endif
}
output[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] = acc;
}""")
conv_prg([C.bs*C.groups*C.rcout, C.oy, C.ox], None, x.cl, w.cl, ret.cl, *[np.int32(x) for x in list(C[0:12])+[C.dx, C.dy, C.px, C.py]])
output[gid] = acc;
}""",
options=tuple(["-DONEBYONE"]) if C.H == 1 and C.W == 1 and C.px == 0 and C.py == 0 else tuple(),
argdtypes=tuple([None, None, None] + [np.int32]*16))
conv_prg([C.bs*C.cout, C.oy, C.ox], None, x.cl, w.cl, ret.cl,
*[x for x in list(C[0:12])+[C.dx, C.dy, C.px, C.py]])
return ret

View File

@@ -25,6 +25,6 @@ from tinygrad.llops.ops_cpu import unary_op, binary_op, reduce_op, movement_op
from tinygrad.ops import ProcessingOps
def processing_op(ctx,op,x,w,C):
def processing_op(op,x,w,C):
assert op == ProcessingOps.CONV, f"{op} isn't supported"
return torch.conv2d(x, w, stride=(C.ys, C.xs), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px))

View File

@@ -44,14 +44,14 @@ def log_op(op, ret, inp):
class Ops:
def unary_op(ctx, op:UnaryOps, x):
ret = ctx.op.unary_op(ctx, op, x)
ret = ctx.op.unary_op(op, x)
log_op(op, ret, [x])
assert isinstance(ret, ctx.buffer)
assert ret.shape == x.shape
return ret
def reduce_op(ctx, op:ReduceOps, x, new_shape):
ret = ctx.op.reduce_op(ctx, op, x, new_shape)
ret = ctx.op.reduce_op(op, x, new_shape)
log_op(op, ret, [x])
assert isinstance(ret, ctx.buffer)
assert ret.shape == tuple(new_shape)
@@ -59,24 +59,22 @@ class Ops:
def binary_op(ctx, op:BinaryOps, x, y):
assert x.shape == y.shape
ret = ctx.op.binary_op(ctx, op, x, y)
ret = ctx.op.binary_op(op, x, y)
log_op(op, ret, [x, y])
assert isinstance(ret, ctx.buffer)
assert ret.shape == x.shape
return ret
def movement_op(ctx, op:MovementOps, x, arg=None):
ret = ctx.op.movement_op(ctx, op, x, arg)
ret = ctx.op.movement_op(op, x, arg)
log_op(op, ret, [x])
assert isinstance(ret, ctx.buffer)
assert ret.shape == ShapeTracker(*x.shape).movement_op(op, arg).shape
return ret
def processing_op(ctx, op:ProcessingOps, x, y, C):
if getattr(ctx.op, "preprocessing_op", None) is not None: x,y,C = ctx.op.preprocessing_op(ctx, op, x, y, C)
ret = ctx.op.processing_op(ctx, op, x, y, C)
ret = ctx.op.processing_op(op, x, y, C)
log_op(op, ret, [x, y])
if getattr(ctx.op, "postprocessing_op", None) is not None: ret = ctx.op.postprocessing_op(ctx, op, ret, C)
assert isinstance(ret, ctx.buffer)
assert ret.shape == (C.bs, C.cout, C.oy, C.ox)
return ret