mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
put the allocations back in the ops
This commit is contained in:
@@ -12,55 +12,55 @@ class CPUBuffer(np.ndarray):
|
||||
def amax(x, *args, **kwargs): return np.amax(x, *args, **kwargs)
|
||||
def permute(x, order): return x.transpose(order)
|
||||
def custompad(x, padding): return np.pad(x, padding)
|
||||
def expand(x, new_shape): return np.broadcast_to(x, new_shape)
|
||||
def expand(x, new_shape): return np.broadcast_to(x, new_shape).view(CPUBuffer)
|
||||
|
||||
@staticmethod
|
||||
def fromCPU(x): return x
|
||||
def toCPU(x): return x
|
||||
|
||||
def unary_op(op, x, ret):
|
||||
if op == UnaryOps.RELU: ret[:] = x.relu()
|
||||
elif op == UnaryOps.EXP: ret[:] = x.exp()
|
||||
elif op == UnaryOps.LOG: ret[:] = x.log()
|
||||
elif op == UnaryOps.NEG: ret[:] = -x
|
||||
elif op == UnaryOps.SIGN: ret[:] = x.sign()
|
||||
def unary_op(ctx, op, x):
|
||||
if op == UnaryOps.RELU: return x.relu()
|
||||
elif op == UnaryOps.EXP: return x.exp()
|
||||
elif op == UnaryOps.LOG: return x.log()
|
||||
elif op == UnaryOps.NEG: return -x
|
||||
elif op == UnaryOps.SIGN: return x.sign()
|
||||
else: raise Exception(f"{op} isn't supported")
|
||||
|
||||
def binary_op(op, x, y, ret):
|
||||
if op == BinaryOps.ADD: ret[:] = x+y
|
||||
elif op == BinaryOps.SUB: ret[:] = x-y
|
||||
elif op == BinaryOps.MUL: ret[:] = x*y
|
||||
elif op == BinaryOps.DIV: ret[:] = y/x
|
||||
elif op == BinaryOps.POW: ret[:] = x**y
|
||||
elif op == BinaryOps.CMPEQ: ret[:] = 1.0*(x==y)
|
||||
def binary_op(ctx, op, x, y):
|
||||
if op == BinaryOps.ADD: return x+y
|
||||
elif op == BinaryOps.SUB: return x-y
|
||||
elif op == BinaryOps.MUL: return x*y
|
||||
elif op == BinaryOps.DIV: return y/x
|
||||
elif op == BinaryOps.POW: return x**y
|
||||
elif op == BinaryOps.CMPEQ: return 1.0*(x==y)
|
||||
else: raise Exception(f"{op} isn't supported")
|
||||
|
||||
def reduce_op(op, inp, ret):
|
||||
if inp.shape == ret.shape: # this is just a copy, regardless of the reduce op
|
||||
ret[:] = inp
|
||||
def reduce_op(ctx, op, inp, new_shape):
|
||||
if inp.shape == new_shape: # this is just a copy, regardless of the reduce op
|
||||
return inp[:]
|
||||
else:
|
||||
if ret.shape == (1,): # full reduce
|
||||
if new_shape == (1,): # full reduce
|
||||
axis = tuple(range(len(inp.shape)))
|
||||
else:
|
||||
assert len(inp.shape) == len(ret.shape)
|
||||
axis = tuple([i for i,(a,b) in enumerate(zip(inp.shape, ret.shape)) if a != b])
|
||||
if op == ReduceOps.SUM: ret[:] = inp.sum(axis, keepdims=True)
|
||||
elif op == ReduceOps.MAX: ret[:] = inp.amax(axis, keepdims=True)
|
||||
assert len(inp.shape) == len(new_shape)
|
||||
axis = tuple([i for i,(a,b) in enumerate(zip(inp.shape, new_shape)) if a != b])
|
||||
if op == ReduceOps.SUM: return inp.sum(axis, keepdims=True)
|
||||
elif op == ReduceOps.MAX: return inp.amax(axis, keepdims=True)
|
||||
else: raise Exception(f"{op} isn't supported")
|
||||
|
||||
def movement_op(op, x, ret, arg=None):
|
||||
if op == MovementOps.RESHAPE: ret[:] = x.reshape(arg)
|
||||
elif op == MovementOps.PERMUTE: ret[:] = x.permute(arg)
|
||||
elif op == MovementOps.FLIP: ret[:] = x.flip(arg)
|
||||
def movement_op(ctx, op, x, arg=None):
|
||||
if op == MovementOps.RESHAPE: return x.reshape(arg)
|
||||
elif op == MovementOps.PERMUTE: return x.permute(arg)
|
||||
elif op == MovementOps.FLIP: return x.flip(arg)
|
||||
elif op == MovementOps.SLICE:
|
||||
padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
|
||||
x = x.custompad(padding)
|
||||
slicee = [(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)]
|
||||
ret[:] = x[tuple([slice(x[0], x[1], None) for x in slicee])]
|
||||
elif op == MovementOps.EXPAND: ret[:] = x.expand(arg)
|
||||
return x[tuple([slice(x[0], x[1], None) for x in slicee])].view(CPUBuffer)
|
||||
elif op == MovementOps.EXPAND: return x.expand(arg)
|
||||
else: raise Exception(f"{op} isn't supported")
|
||||
|
||||
def processing_op(op,x,w,ret,C):
|
||||
def processing_op(ctx, op,x,w,out_shape,C):
|
||||
assert op == ProcessingOps.CONV, f"{op} isn't supported"
|
||||
if C.px > 0 or C.py > 0: x = np.pad(x, [(0,0), (0,0), (C.py, C.py), (C.px, C.px)])
|
||||
gx = x.reshape(C.bs,C.groups,C.cin,x.shape[2],x.shape[3])
|
||||
@@ -74,4 +74,4 @@ def processing_op(op,x,w,ret,C):
|
||||
for g in range(C.groups):
|
||||
#ijYXyx,kjyx -> iYXk ->ikYX
|
||||
tmp[:,g] += np.tensordot(tx[:,g], tw[g], ((1,4,5),(1,2,3)))
|
||||
ret[:] = np.moveaxis(tmp,4,2).reshape(C.bs, C.groups*C.rcout, C.oy, C.ox)
|
||||
return np.moveaxis(tmp,4,2).reshape(C.bs, C.groups*C.rcout, C.oy, C.ox).view(CPUBuffer)
|
||||
|
||||
@@ -48,7 +48,8 @@ def clbuild(name, prg):
|
||||
def run(*args): clprg(cl_queue, *args)
|
||||
return run
|
||||
|
||||
def unary_op(op, x, ret):
|
||||
def unary_op(ctx, op, x):
|
||||
ret = ctx.buffer(x.shape)
|
||||
if op == UnaryOps.RELU: code = 'max(a, (float)0.)'
|
||||
elif op == UnaryOps.EXP: code = 'exp(a)'
|
||||
elif op == UnaryOps.LOG: code = 'log(a)'
|
||||
@@ -64,7 +65,8 @@ def unary_op(op, x, ret):
|
||||
unop([roundup(prod(ret.shape))//4], None, x.cl, ret.cl)
|
||||
return ret
|
||||
|
||||
def binary_op(op, x, y, ret):
|
||||
def binary_op(ctx, op, x, y):
|
||||
ret = ctx.buffer(x.shape)
|
||||
if op == BinaryOps.ADD: code = "a+b"
|
||||
elif op == BinaryOps.SUB: code = "a-b"
|
||||
elif op == BinaryOps.MUL: code = "a*b"
|
||||
@@ -83,7 +85,8 @@ def binary_op(op, x, y, ret):
|
||||
binop([roundup(prod(ret.shape))//4], None, x.cl, y.cl, ret.cl)
|
||||
return ret
|
||||
|
||||
def reduce_op(op, inp, ret):
|
||||
def reduce_op(ctx, op, inp, new_shape):
|
||||
ret = ctx.buffer(new_shape)
|
||||
if op == ReduceOps.SUM: code, start = "out += a", "0.0"
|
||||
elif op == ReduceOps.MAX: code, start = "out = max(a,out)", "-INFINITY"
|
||||
else: raise Exception(f"{op} isn't supported")
|
||||
@@ -114,17 +117,21 @@ def reduce_op(op, inp, ret):
|
||||
res_g[gid] = out;
|
||||
}"""
|
||||
clbuild("reduce", prg)([prod(ret.shape)], None, inp.cl, ret.cl)
|
||||
return ret
|
||||
|
||||
def contiguous(x, ret, st):
|
||||
def contiguous(ctx, x, st):
|
||||
ret = ctx.buffer(st.shape)
|
||||
clbuild("contiguous", """__kernel void contiguous(__global const float *x, __global float *ret) {
|
||||
int gid = get_global_id(0); int valid = 1; int idx = gid; """+st.expr().replace('//', '/')+""";
|
||||
ret[gid] = valid ? x[idx] : 0.0; // should never be out-of-bounds accesses
|
||||
}""")([prod(ret.shape)], None, x.cl, ret.cl)
|
||||
return ret
|
||||
|
||||
def movement_op(op, x, ret, arg=None):
|
||||
contiguous(x, ret, ShapeTracker(*x.shape).movement_op(op, arg))
|
||||
def movement_op(ctx, op, x, arg=None):
|
||||
return contiguous(ctx, x, ShapeTracker(*x.shape).movement_op(op, arg))
|
||||
|
||||
def processing_op(op,x,w,ret,C):
|
||||
def processing_op(ctx,op,x,w,out_shape,C):
|
||||
ret = ctx.buffer(out_shape)
|
||||
assert op == ProcessingOps.CONV, f"{op} isn't supported"
|
||||
# input = (bs, groups, cin, iy, ix)
|
||||
# weight = (groups, rcout, cin, H, W)
|
||||
@@ -156,3 +163,4 @@ def processing_op(op,x,w,ret,C):
|
||||
}""")
|
||||
|
||||
conv_prg([C.bs*C.groups*C.rcout, C.oy, C.ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in list(C[0:12])+[C.dx, C.dy, C.px, C.py]])
|
||||
return ret
|
||||
|
||||
@@ -43,33 +43,38 @@ def log_op(op, ret, inp):
|
||||
|
||||
class Ops:
|
||||
def unary_op(ctx, op:UnaryOps, x):
|
||||
ret = ctx.buffer(x.shape)
|
||||
ctx.op.unary_op(op, x, ret)
|
||||
ret = ctx.op.unary_op(ctx, op, x)
|
||||
assert isinstance(ret, ctx.buffer)
|
||||
assert ret.shape == x.shape
|
||||
log_op(op, ret, [x])
|
||||
return ret
|
||||
|
||||
def reduce_op(ctx, op:ReduceOps, x, new_shape):
|
||||
ret = ctx.buffer(new_shape)
|
||||
ctx.op.reduce_op(op, x, ret)
|
||||
ret = ctx.op.reduce_op(ctx, op, x, new_shape)
|
||||
assert isinstance(ret, ctx.buffer)
|
||||
assert ret.shape == tuple(new_shape)
|
||||
log_op(op, ret, [x])
|
||||
return ret
|
||||
|
||||
def binary_op(ctx, op:BinaryOps, x, y):
|
||||
assert x.shape == y.shape
|
||||
ret = ctx.buffer(x.shape)
|
||||
ctx.op.binary_op(op, x, y, ret)
|
||||
ret = ctx.op.binary_op(ctx, op, x, y)
|
||||
assert isinstance(ret, ctx.buffer)
|
||||
assert ret.shape == x.shape
|
||||
log_op(op, ret, [x, y])
|
||||
return ret
|
||||
|
||||
def movement_op(ctx, op:MovementOps, x, arg=None):
|
||||
ret = ctx.buffer(ShapeTracker(*x.shape).movement_op(op, arg).shape)
|
||||
ctx.op.movement_op(op, x, ret, arg)
|
||||
ret = ctx.op.movement_op(ctx, op, x, arg)
|
||||
assert isinstance(ret, ctx.buffer)
|
||||
assert ret.shape == ShapeTracker(*x.shape).movement_op(op, arg).shape
|
||||
log_op(op, ret, [x])
|
||||
return ret
|
||||
|
||||
def processing_op(ctx, op:ProcessingOps, x, y, out_shape, C):
|
||||
# TODO: can we do better than out_shape?
|
||||
ret = ctx.buffer(out_shape)
|
||||
ctx.op.processing_op(op, x, y, ret, C)
|
||||
ret = ctx.op.processing_op(ctx, op, x, y, out_shape, C)
|
||||
assert isinstance(ret, ctx.buffer)
|
||||
assert ret.shape == out_shape
|
||||
log_op(op, ret, [x, y])
|
||||
return ret
|
||||
Reference in New Issue
Block a user