put the allocations back in the ops

This commit is contained in:
George Hotz
2022-06-16 12:12:55 -07:00
parent ce15bf2bdb
commit 9306759cbc
3 changed files with 60 additions and 47 deletions

View File

@@ -12,55 +12,55 @@ class CPUBuffer(np.ndarray):
def amax(x, *args, **kwargs): return np.amax(x, *args, **kwargs)
def permute(x, order): return x.transpose(order)
def custompad(x, padding): return np.pad(x, padding)
def expand(x, new_shape): return np.broadcast_to(x, new_shape)
def expand(x, new_shape): return np.broadcast_to(x, new_shape).view(CPUBuffer)
@staticmethod
def fromCPU(x): return x
def toCPU(x): return x
def unary_op(op, x, ret):
if op == UnaryOps.RELU: ret[:] = x.relu()
elif op == UnaryOps.EXP: ret[:] = x.exp()
elif op == UnaryOps.LOG: ret[:] = x.log()
elif op == UnaryOps.NEG: ret[:] = -x
elif op == UnaryOps.SIGN: ret[:] = x.sign()
def unary_op(ctx, op, x):
if op == UnaryOps.RELU: return x.relu()
elif op == UnaryOps.EXP: return x.exp()
elif op == UnaryOps.LOG: return x.log()
elif op == UnaryOps.NEG: return -x
elif op == UnaryOps.SIGN: return x.sign()
else: raise Exception(f"{op} isn't supported")
def binary_op(op, x, y, ret):
if op == BinaryOps.ADD: ret[:] = x+y
elif op == BinaryOps.SUB: ret[:] = x-y
elif op == BinaryOps.MUL: ret[:] = x*y
elif op == BinaryOps.DIV: ret[:] = y/x
elif op == BinaryOps.POW: ret[:] = x**y
elif op == BinaryOps.CMPEQ: ret[:] = 1.0*(x==y)
def binary_op(ctx, op, x, y):
if op == BinaryOps.ADD: return x+y
elif op == BinaryOps.SUB: return x-y
elif op == BinaryOps.MUL: return x*y
elif op == BinaryOps.DIV: return y/x
elif op == BinaryOps.POW: return x**y
elif op == BinaryOps.CMPEQ: return 1.0*(x==y)
else: raise Exception(f"{op} isn't supported")
def reduce_op(op, inp, ret):
if inp.shape == ret.shape: # this is just a copy, regardless of the reduce op
ret[:] = inp
def reduce_op(ctx, op, inp, new_shape):
if inp.shape == new_shape: # this is just a copy, regardless of the reduce op
return inp[:]
else:
if ret.shape == (1,): # full reduce
if new_shape == (1,): # full reduce
axis = tuple(range(len(inp.shape)))
else:
assert len(inp.shape) == len(ret.shape)
axis = tuple([i for i,(a,b) in enumerate(zip(inp.shape, ret.shape)) if a != b])
if op == ReduceOps.SUM: ret[:] = inp.sum(axis, keepdims=True)
elif op == ReduceOps.MAX: ret[:] = inp.amax(axis, keepdims=True)
assert len(inp.shape) == len(new_shape)
axis = tuple([i for i,(a,b) in enumerate(zip(inp.shape, new_shape)) if a != b])
if op == ReduceOps.SUM: return inp.sum(axis, keepdims=True)
elif op == ReduceOps.MAX: return inp.amax(axis, keepdims=True)
else: raise Exception(f"{op} isn't supported")
def movement_op(op, x, ret, arg=None):
if op == MovementOps.RESHAPE: ret[:] = x.reshape(arg)
elif op == MovementOps.PERMUTE: ret[:] = x.permute(arg)
elif op == MovementOps.FLIP: ret[:] = x.flip(arg)
def movement_op(ctx, op, x, arg=None):
if op == MovementOps.RESHAPE: return x.reshape(arg)
elif op == MovementOps.PERMUTE: return x.permute(arg)
elif op == MovementOps.FLIP: return x.flip(arg)
elif op == MovementOps.SLICE:
padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
x = x.custompad(padding)
slicee = [(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)]
ret[:] = x[tuple([slice(x[0], x[1], None) for x in slicee])]
elif op == MovementOps.EXPAND: ret[:] = x.expand(arg)
return x[tuple([slice(x[0], x[1], None) for x in slicee])].view(CPUBuffer)
elif op == MovementOps.EXPAND: return x.expand(arg)
else: raise Exception(f"{op} isn't supported")
def processing_op(op,x,w,ret,C):
def processing_op(ctx, op,x,w,out_shape,C):
assert op == ProcessingOps.CONV, f"{op} isn't supported"
if C.px > 0 or C.py > 0: x = np.pad(x, [(0,0), (0,0), (C.py, C.py), (C.px, C.px)])
gx = x.reshape(C.bs,C.groups,C.cin,x.shape[2],x.shape[3])
@@ -74,4 +74,4 @@ def processing_op(op,x,w,ret,C):
for g in range(C.groups):
#ijYXyx,kjyx -> iYXk ->ikYX
tmp[:,g] += np.tensordot(tx[:,g], tw[g], ((1,4,5),(1,2,3)))
ret[:] = np.moveaxis(tmp,4,2).reshape(C.bs, C.groups*C.rcout, C.oy, C.ox)
return np.moveaxis(tmp,4,2).reshape(C.bs, C.groups*C.rcout, C.oy, C.ox).view(CPUBuffer)

View File

@@ -48,7 +48,8 @@ def clbuild(name, prg):
def run(*args): clprg(cl_queue, *args)
return run
def unary_op(op, x, ret):
def unary_op(ctx, op, x):
ret = ctx.buffer(x.shape)
if op == UnaryOps.RELU: code = 'max(a, (float)0.)'
elif op == UnaryOps.EXP: code = 'exp(a)'
elif op == UnaryOps.LOG: code = 'log(a)'
@@ -64,7 +65,8 @@ def unary_op(op, x, ret):
unop([roundup(prod(ret.shape))//4], None, x.cl, ret.cl)
return ret
def binary_op(op, x, y, ret):
def binary_op(ctx, op, x, y):
ret = ctx.buffer(x.shape)
if op == BinaryOps.ADD: code = "a+b"
elif op == BinaryOps.SUB: code = "a-b"
elif op == BinaryOps.MUL: code = "a*b"
@@ -83,7 +85,8 @@ def binary_op(op, x, y, ret):
binop([roundup(prod(ret.shape))//4], None, x.cl, y.cl, ret.cl)
return ret
def reduce_op(op, inp, ret):
def reduce_op(ctx, op, inp, new_shape):
ret = ctx.buffer(new_shape)
if op == ReduceOps.SUM: code, start = "out += a", "0.0"
elif op == ReduceOps.MAX: code, start = "out = max(a,out)", "-INFINITY"
else: raise Exception(f"{op} isn't supported")
@@ -114,17 +117,21 @@ def reduce_op(op, inp, ret):
res_g[gid] = out;
}"""
clbuild("reduce", prg)([prod(ret.shape)], None, inp.cl, ret.cl)
return ret
def contiguous(x, ret, st):
def contiguous(ctx, x, st):
ret = ctx.buffer(st.shape)
clbuild("contiguous", """__kernel void contiguous(__global const float *x, __global float *ret) {
int gid = get_global_id(0); int valid = 1; int idx = gid; """+st.expr().replace('//', '/')+""";
ret[gid] = valid ? x[idx] : 0.0; // should never be out-of-bounds accesses
}""")([prod(ret.shape)], None, x.cl, ret.cl)
return ret
def movement_op(op, x, ret, arg=None):
contiguous(x, ret, ShapeTracker(*x.shape).movement_op(op, arg))
def movement_op(ctx, op, x, arg=None):
return contiguous(ctx, x, ShapeTracker(*x.shape).movement_op(op, arg))
def processing_op(op,x,w,ret,C):
def processing_op(ctx,op,x,w,out_shape,C):
ret = ctx.buffer(out_shape)
assert op == ProcessingOps.CONV, f"{op} isn't supported"
# input = (bs, groups, cin, iy, ix)
# weight = (groups, rcout, cin, H, W)
@@ -156,3 +163,4 @@ def processing_op(op,x,w,ret,C):
}""")
conv_prg([C.bs*C.groups*C.rcout, C.oy, C.ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in list(C[0:12])+[C.dx, C.dy, C.px, C.py]])
return ret

View File

@@ -43,33 +43,38 @@ def log_op(op, ret, inp):
class Ops:
def unary_op(ctx, op:UnaryOps, x):
ret = ctx.buffer(x.shape)
ctx.op.unary_op(op, x, ret)
ret = ctx.op.unary_op(ctx, op, x)
assert isinstance(ret, ctx.buffer)
assert ret.shape == x.shape
log_op(op, ret, [x])
return ret
def reduce_op(ctx, op:ReduceOps, x, new_shape):
ret = ctx.buffer(new_shape)
ctx.op.reduce_op(op, x, ret)
ret = ctx.op.reduce_op(ctx, op, x, new_shape)
assert isinstance(ret, ctx.buffer)
assert ret.shape == tuple(new_shape)
log_op(op, ret, [x])
return ret
def binary_op(ctx, op:BinaryOps, x, y):
assert x.shape == y.shape
ret = ctx.buffer(x.shape)
ctx.op.binary_op(op, x, y, ret)
ret = ctx.op.binary_op(ctx, op, x, y)
assert isinstance(ret, ctx.buffer)
assert ret.shape == x.shape
log_op(op, ret, [x, y])
return ret
def movement_op(ctx, op:MovementOps, x, arg=None):
ret = ctx.buffer(ShapeTracker(*x.shape).movement_op(op, arg).shape)
ctx.op.movement_op(op, x, ret, arg)
ret = ctx.op.movement_op(ctx, op, x, arg)
assert isinstance(ret, ctx.buffer)
assert ret.shape == ShapeTracker(*x.shape).movement_op(op, arg).shape
log_op(op, ret, [x])
return ret
def processing_op(ctx, op:ProcessingOps, x, y, out_shape, C):
# TODO: can we do better than out_shape?
ret = ctx.buffer(out_shape)
ctx.op.processing_op(op, x, y, ret, C)
ret = ctx.op.processing_op(ctx, op, x, y, out_shape, C)
assert isinstance(ret, ctx.buffer)
assert ret.shape == out_shape
log_op(op, ret, [x, y])
return ret