mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 21:38:10 -05:00

remove ctx from buffers (#333)
--- a/tinygrad/llops/ops_cpu.py
+++ b/tinygrad/llops/ops_cpu.py
@@ -18,7 +18,7 @@ class CPUBuffer(np.ndarray):
   def fromCPU(x): return x
   def toCPU(x): return x
 
-def unary_op(ctx, op, x):
+def unary_op(op, x):
   if op == UnaryOps.RELU: return x.relu()
   elif op == UnaryOps.EXP: return x.exp()
   elif op == UnaryOps.LOG: return x.log()
@@ -26,7 +26,7 @@ def unary_op(ctx, op, x):
   elif op == UnaryOps.SIGN: return x.sign()
   else: raise Exception(f"{op} isn't supported")
 
-def binary_op(ctx, op, x, y):
+def binary_op(op, x, y):
   if op == BinaryOps.ADD: return x+y
   elif op == BinaryOps.SUB: return x-y
   elif op == BinaryOps.MUL: return x*y
@@ -35,7 +35,7 @@ def binary_op(ctx, op, x, y):
   elif op == BinaryOps.CMPEQ: return 1.0*(x==y)
   else: raise Exception(f"{op} isn't supported")
 
-def reduce_op(ctx, op, inp, new_shape):
+def reduce_op(op, inp, new_shape):
   if inp.shape == new_shape:   # this is just a copy, regardless of the reduce op
     return inp[:]
   else:
@@ -48,7 +48,7 @@ def reduce_op(ctx, op, inp, new_shape):
   elif op == ReduceOps.MAX: return inp.amax(axis, keepdims=True)
   else: raise Exception(f"{op} isn't supported")
 
-def movement_op(ctx, op, x, arg=None):
+def movement_op(op, x, arg=None):
   if op == MovementOps.RESHAPE: return x.reshape(arg)
   elif op == MovementOps.PERMUTE: return x.permute(arg)
   elif op == MovementOps.FLIP: return x.flip(arg)
@@ -60,7 +60,7 @@ def movement_op(ctx, op, x, arg=None):
   elif op == MovementOps.EXPAND: return x.expand(arg)
   else: raise Exception(f"{op} isn't supported")
 
-def processing_op(ctx,op,x,w,C):
+def processing_op(op,x,w,C):
   assert op == ProcessingOps.CONV, f"{op} isn't supported"
   if C.px > 0 or C.py > 0: x = np.pad(x, [(0,0), (0,0), (C.py, C.py), (C.px, C.px)])
   gx = x.reshape(C.bs,C.groups,C.cin,x.shape[2],x.shape[3])
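With ctx gone, the CPU llops are plain functions keyed only on the op enum and the buffer itself. A minimal self-contained sketch of the same dispatch shape (toy code, not tinygrad's actual CPUBuffer, which subclasses np.ndarray and adds helpers like relu() and amax()):

# Toy sketch of the ctx-free dispatch pattern above; UnaryOps here is a stand-in enum.
from enum import Enum
import numpy as np

UnaryOps = Enum("UnaryOps", ["RELU", "EXP", "LOG"])

def unary_op(op, x):
  # no ctx parameter: the buffer type alone determines the backend
  if op == UnaryOps.RELU: return np.maximum(x, 0)
  elif op == UnaryOps.EXP: return np.exp(x)
  elif op == UnaryOps.LOG: return np.log(x)
  else: raise Exception(f"{op} isn't supported")

print(unary_op(UnaryOps.RELU, np.array([-1.0, 2.0], dtype=np.float32)))  # [0. 2.]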
--- a/tinygrad/llops/ops_gpu.py
+++ b/tinygrad/llops/ops_gpu.py
@@ -22,11 +22,15 @@ def roundup(x, n=4): return (x+(n-1))//n * n
 class GPUBuffer:
   def __init__(self, shape, hostbuf=None):
     require_init_gpu()
-    self.shape, self.dtype = tuple(shape), np.float32
-    self.cl = hostbuf.cl if isinstance(hostbuf, GPUBuffer) else cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE, 4*roundup(prod(shape)))  # padding
-    if hostbuf is not None and not isinstance(hostbuf, GPUBuffer):
+    self.shape = tuple(shape)
+    self.cl = cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE, 4*roundup(prod(self.shape)))  # padding
+    if hostbuf is not None:
       # TODO: this doesn't have to block
       cl.enqueue_copy(cl_queue, self.cl, hostbuf.astype(np.float32).ravel())
 
+  @property
+  def dtype(self): return np.float32
+
   def __repr__(self):
     return f"<GPUBuffer with shape {self.shape!r}>"
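The buffer is still allocated with 4*roundup(prod(shape)) bytes, i.e. the float count padded up to a multiple of 4, which is what lets the elementwise kernels below launch over roundup(prod(shape))//4 work-items without a bounds check. A worked example of that arithmetic (roundup copied from the diff):

# roundup as in the diff: pad a count up to a multiple of n (default 4)
def roundup(x, n=4): return (x+(n-1))//n * n

elems = 3*5                  # a (3, 5) float32 buffer holds 15 floats
print(roundup(elems))        # 16 -> allocated as 4*16 = 64 bytes
print(roundup(elems)//4)     # 4 work-items, 4 floats each, no tail case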
@@ -39,14 +43,21 @@ class GPUBuffer:
     cl.enqueue_copy(cl_queue, data, self.cl, is_blocking=True)
     return data
 
-@functools.lru_cache
-def clbuild(name, prg, options=tuple()):
-  clprg = cl.Program(cl_ctx, prg).build(options=options).__getattr__(name)
-  def run(*args): clprg(cl_queue, *args)
-  return run
+class CLProgram:
+  def __init__(self, name, prg, options, argdtypes):
+    self.built = cl.Program(cl_ctx, prg).build(options=options)
+    self.clprg = self.built.__getattr__(name)
+    if argdtypes is not None: self.clprg.set_scalar_arg_dtypes(argdtypes)
+  def __call__(self, *args): self.clprg(cl_queue, *args)
+
+@functools.lru_cache(maxsize=None)
+def clbuild(name, prg, options=tuple(), argdtypes=None):
+  #print("cache miss")
+  #print(prg)
+  return CLProgram(name, prg, options, argdtypes)
 
-def unary_op(ctx, op, x):
-  ret = ctx.buffer(x.shape)
+def unary_op(op, x):
+  ret = GPUBuffer(x.shape)
   if op == UnaryOps.RELU: code = 'max(a, (float)0.)'
   elif op == UnaryOps.EXP: code = 'exp(a)'
   elif op == UnaryOps.LOG: code = 'log(a)'
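clbuild is now a memoized factory for CLProgram objects: functools.lru_cache keys on the (name, prg, options, argdtypes) arguments, which is also why options defaults to a tuple() rather than a list (lru_cache arguments must be hashable). set_scalar_arg_dtypes is the pyopencl call that coerces plain Python ints at kernel launch, replacing the manual np.int32() wrapping removed from the conv call further down. The same caching pattern in miniature, with compilation stubbed out so it runs without a GPU:

import functools

class FakeProgram:                       # stands in for CLProgram / a compiled kernel
  def __init__(self, name, src):
    print(f"compiling {name}")           # runs once per distinct (name, src)
    self.name, self.src = name, src

@functools.lru_cache(maxsize=None)
def clbuild(name, src, options=tuple()):
  return FakeProgram(name, src)

a = clbuild("relu", "__kernel void relu(...) {}")
b = clbuild("relu", "__kernel void relu(...) {}")
assert a is b                            # cache hit: compiled exactly once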
@@ -62,8 +73,8 @@ def unary_op(ctx, op, x):
   unop([roundup(prod(ret.shape))//4], None, x.cl, ret.cl)
   return ret
 
-def binary_op(ctx, op, x, y):
-  ret = ctx.buffer(x.shape)
+def binary_op(op, x, y):
+  ret = GPUBuffer(x.shape)
   if op == BinaryOps.ADD: code = "a+b"
   elif op == BinaryOps.SUB: code = "a-b"
   elif op == BinaryOps.MUL: code = "a*b"
@@ -82,8 +93,8 @@ def binary_op(ctx, op, x, y):
   binop([roundup(prod(ret.shape))//4], None, x.cl, y.cl, ret.cl)
   return ret
 
-def reduce_op(ctx, op, inp, new_shape):
-  ret = ctx.buffer(new_shape)
+def reduce_op(op, inp, new_shape):
+  ret = GPUBuffer(new_shape)
   if op == ReduceOps.SUM: code, start = "out += a", "0.0"
   elif op == ReduceOps.MAX: code, start = "out = max(a,out)", "-INFINITY"
   else: raise Exception(f"{op} isn't supported")
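For both the unary and binary ops, only the one-line code expression changes per op; it gets spliced into a fixed kernel template and compiled through clbuild. The template itself is outside this diff, so the following is a hypothetical reconstruction of the splicing step (the name build_binary_src is illustrative, and the float4 vectorization is inferred from the //4 launch size):

# Hypothetical splice step; the real template lives in the unchanged part of ops_gpu.py.
def build_binary_src(code):
  return """__kernel void binop(__global const float4 *a_g, __global const float4 *b_g, __global float4 *res_g) {
  int gid = get_global_id(0);
  float4 a = a_g[gid];
  float4 b = b_g[gid];
  res_g[gid] = """+code+""";
}"""

print(build_binary_src("a+b"))   # the BinaryOps.ADD expression from the diff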
@@ -116,23 +127,20 @@ def reduce_op(ctx, op, inp, new_shape):
   clbuild("reduce", prg)([prod(ret.shape)], None, inp.cl, ret.cl)
   return ret
 
-def contiguous(ctx, x, st, ret=None):
-  if ret is None: ret = ctx.buffer(st.shape)
+def contiguous(x, st, ret=None):
+  if ret is None: ret = GPUBuffer(st.shape)
   clbuild("contiguous", """__kernel void contiguous(__global const float *x, __global float *ret) {
     int gid = get_global_id(0); int valid = 1; int idx = gid; """+st.expr().replace('//', '/')+""";
     ret[gid] = valid ? x[idx] : 0.0;  // should never be out-of-bounds accesses
   }""")([prod(ret.shape)], None, x.cl, ret.cl)
   return ret
 
-def movement_op(ctx, op, x, arg=None):
-  return contiguous(ctx, x, ShapeTracker(*x.shape).movement_op(op, arg))
+def movement_op(op, x, arg=None):
+  return contiguous(x, ShapeTracker(*x.shape).movement_op(op, arg))
 
-def processing_op(ctx,op,x,w,C):
-  ret = ctx.buffer((C.bs, C.cout, C.oy, C.ox))
+def processing_op(op,x,w,C):
+  ret = GPUBuffer((C.bs, C.cout, C.oy, C.ox))
   assert op == ProcessingOps.CONV, f"{op} isn't supported"
   # input = (bs, groups, cin, iy, ix)
   # weight = (groups, rcout, cin, H, W)
   # output = (bs, groups, rcout, oy, ox)
   conv_prg = clbuild("conv", """
   __kernel void conv(__global const float *input, __global const float *weight, __global float *output,
     int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs, int dx, int dy, int px, int py) {
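A note on contiguous in the hunk above: it is the single workhorse behind every movement op. ShapeTracker(*x.shape).movement_op(op, arg) records the view, and st.expr() emits the C index expression that maps each output position gid to a source index idx (with a valid flag for out-of-range reads). The same gather written in plain Python for a (2, 3) -> (3, 2) permute (illustrative, not ShapeTracker's actual output):

import numpy as np

x = np.arange(6, dtype=np.float32)   # flat source buffer for a (2, 3) array
out = np.empty(6, dtype=np.float32)  # destination for the permuted (3, 2) view
for gid in range(6):
  Y, X = gid // 2, gid % 2           # position in the (3, 2) output
  idx = X*3 + Y                      # the index math st.expr() would emit as C
  out[gid] = x[idx]                  # ret[gid] = x[idx], as in the kernel above
print(out.reshape(3, 2))             # matches x.reshape(2, 3).T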
@@ -146,8 +154,15 @@ def processing_op(ctx,op,x,w,C):
     int IY = Y*ys;
     int IX = X*xs;
 
+    int gid = get_global_id(0)*oy*ox + Y*ox + X;
+
     float acc = 0.0;
     for (int ci = 0; ci < cin; ci++) {
+
+#ifdef ONEBYONE
+      acc += input[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + IY*ix + IX] * \
+        weight[g*rcout*cin + c*cin + ci];
+#else
       for (int y = 0; y < H; y++) { for (int x = 0; x < W; x++) {
         int idx_y = y*dy + IY - py;
         int idx_x = x*dx + IX - px;
@@ -155,9 +170,12 @@ def processing_op(ctx,op,x,w,C):
         acc += valid ? input[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + idx_y*ix + idx_x] * \
           weight[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x] : 0.0;
       } }
+#endif
     }
-    output[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] = acc;
-  }""")
-
-  conv_prg([C.bs*C.groups*C.rcout, C.oy, C.ox], None, x.cl, w.cl, ret.cl, *[np.int32(x) for x in list(C[0:12])+[C.dx, C.dy, C.px, C.py]])
+    output[gid] = acc;
+  }""",
+    options=tuple(["-DONEBYONE"]) if C.H == 1 and C.W == 1 and C.px == 0 and C.py == 0 else tuple(),
+    argdtypes=tuple([None, None, None] + [np.int32]*16))
+  conv_prg([C.bs*C.cout, C.oy, C.ox], None, x.cl, w.cl, ret.cl,
+    *[x for x in list(C[0:12])+[C.dx, C.dy, C.px, C.py]])
   return ret
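Two things change in the conv launch: the grid is written as C.bs*C.cout instead of C.bs*C.groups*C.rcout (the same number, since the output comment above gives cout = groups*rcout channels), and the manual np.int32(x) wrapping is gone because argdtypes now tells pyopencl to coerce the 16 scalar args via set_scalar_arg_dtypes. The -DONEBYONE define selects the fast path for 1x1 kernels with no padding, where the y/x loops collapse to one multiply-accumulate per input channel. For reference, the output size oy/ox such a conv produces follows the standard formula (a sketch; field names match the conv struct in the diff):

# Standard conv output-size arithmetic with the diff's field names:
# i = input size (iy/ix), p = padding (py/px), d = dilation (dy/dx),
# k = kernel size (H/W), s = stride (ys/xs).
def conv_out_dim(i, p, d, k, s):
  return (i + 2*p - d*(k-1) - 1)//s + 1

print(conv_out_dim(32, 1, 1, 3, 1))  # 3x3 conv, pad 1, stride 1: 32 -> 32
print(conv_out_dim(32, 0, 1, 1, 1))  # the ONEBYONE case: 1x1, no pad: 32 -> 32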
--- a/tinygrad/llops/ops_torch.py
+++ b/tinygrad/llops/ops_torch.py
@@ -25,6 +25,6 @@ from tinygrad.llops.ops_cpu import unary_op, binary_op, reduce_op, movement_op
 from tinygrad.ops import ProcessingOps
 
-def processing_op(ctx,op,x,w,C):
+def processing_op(op,x,w,C):
   assert op == ProcessingOps.CONV, f"{op} isn't supported"
   return torch.conv2d(x, w, stride=(C.ys, C.xs), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px))
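The torch backend change is the same signature cleanup; the body still forwards straight to torch.conv2d. A standalone check that the forwarded keyword arguments line up with the real torch API:

import torch

x = torch.randn(1, 8, 32, 32)   # (bs, groups*cin, iy, ix)
w = torch.randn(16, 4, 3, 3)    # (cout, cin, H, W) with groups=2
out = torch.conv2d(x, w, stride=(1, 1), groups=2, dilation=(1, 1), padding=(1, 1))
print(out.shape)                # torch.Size([1, 16, 32, 32])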
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -44,14 +44,14 @@ def log_op(op, ret, inp):
 
 class Ops:
   def unary_op(ctx, op:UnaryOps, x):
-    ret = ctx.op.unary_op(ctx, op, x)
+    ret = ctx.op.unary_op(op, x)
     log_op(op, ret, [x])
     assert isinstance(ret, ctx.buffer)
     assert ret.shape == x.shape
     return ret
 
   def reduce_op(ctx, op:ReduceOps, x, new_shape):
-    ret = ctx.op.reduce_op(ctx, op, x, new_shape)
+    ret = ctx.op.reduce_op(op, x, new_shape)
     log_op(op, ret, [x])
     assert isinstance(ret, ctx.buffer)
     assert ret.shape == tuple(new_shape)
@@ -59,24 +59,22 @@ class Ops:
 
   def binary_op(ctx, op:BinaryOps, x, y):
     assert x.shape == y.shape
-    ret = ctx.op.binary_op(ctx, op, x, y)
+    ret = ctx.op.binary_op(op, x, y)
     log_op(op, ret, [x, y])
     assert isinstance(ret, ctx.buffer)
     assert ret.shape == x.shape
     return ret
 
   def movement_op(ctx, op:MovementOps, x, arg=None):
-    ret = ctx.op.movement_op(ctx, op, x, arg)
+    ret = ctx.op.movement_op(op, x, arg)
     log_op(op, ret, [x])
     assert isinstance(ret, ctx.buffer)
     assert ret.shape == ShapeTracker(*x.shape).movement_op(op, arg).shape
     return ret
 
   def processing_op(ctx, op:ProcessingOps, x, y, C):
-    if getattr(ctx.op, "preprocessing_op", None) is not None: x,y,C = ctx.op.preprocessing_op(ctx, op, x, y, C)
-    ret = ctx.op.processing_op(ctx, op, x, y, C)
+    ret = ctx.op.processing_op(op, x, y, C)
     log_op(op, ret, [x, y])
-    if getattr(ctx.op, "postprocessing_op", None) is not None: ret = ctx.op.postprocessing_op(ctx, op, ret, C)
     assert isinstance(ret, ctx.buffer)
     assert ret.shape == (C.bs, C.cout, C.oy, C.ox)
     return ret
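Net effect in the dispatcher: ctx survives only in the Ops wrapper, which still needs it to pick the backend module (ctx.op) and the buffer class (ctx.buffer), while the backend functions themselves are now context-free; the pre/postprocessing hooks are dropped outright. A toy of the resulting wrapper shape (names are illustrative):

import numpy as np

class NumpyBackend:                       # plays the role of ctx.op
  @staticmethod
  def unary_op(op, x):
    return np.maximum(x, 0) if op == "relu" else np.exp(x)

def unary_op(backend, op, x):             # plays the role of Ops.unary_op
  ret = backend.unary_op(op, x)           # backend call no longer takes ctx
  print("log_op:", op, list(x.shape))     # graph logging stays in the wrapper
  assert ret.shape == x.shape             # shape invariant checked centrally
  return ret

print(unary_op(NumpyBackend, "relu", np.array([-1.0, 2.0])))  # [0. 2.]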