diff --git a/tinygrad/llops/ops_cpu.py b/tinygrad/llops/ops_cpu.py index 95727f5712..58f07af045 100644 --- a/tinygrad/llops/ops_cpu.py +++ b/tinygrad/llops/ops_cpu.py @@ -60,23 +60,18 @@ def movement_op(op, x, ret, arg=None): elif op == MovementOps.EXPAND: ret[:] = x.expand(arg) else: raise Exception(f"{op} isn't supported") -def get_tx(x, C): +def processing_op(op,x,w,ret,C): + assert op == ProcessingOps.CONV, f"{op} isn't supported" + if C.px > 0 or C.py > 0: x = np.pad(x, [(0,0), (0,0), (C.py, C.py), (C.px, C.px)]) gx = x.reshape(C.bs,C.groups,C.cin,x.shape[2],x.shape[3]) - return np.lib.stride_tricks.as_strided(gx, + tx = np.lib.stride_tricks.as_strided(gx, shape=(C.bs, C.groups, C.cin, C.oy, C.ox, C.H, C.W), strides=(*gx.strides[0:3], gx.strides[3]*C.ys, gx.strides[4]*C.xs, gx.strides[3]*C.dy, gx.strides[4]*C.dx), writeable=False, ) - -def conv(x,w,ret,C): - if C.px > 0 or C.py > 0: x = np.pad(x, [(0,0), (0,0), (C.py, C.py), (C.px, C.px)]) - tx = get_tx(x, C) tw = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W) tmp = np.zeros((C.bs,C.groups,C.oy,C.ox,C.rcout),dtype=x.dtype) for g in range(C.groups): #ijYXyx,kjyx -> iYXk ->ikYX tmp[:,g] += np.tensordot(tx[:,g], tw[g], ((1,4,5),(1,2,3))) ret[:] = np.moveaxis(tmp,4,2).reshape(C.bs, C.groups*C.rcout, C.oy, C.ox) - -def processing_op(op,a,b,ret,C): - if op == ProcessingOps.CONV: conv(a,b,ret,C) diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py index 8e1fe04a5b..1806337cff 100644 --- a/tinygrad/llops/ops_gpu.py +++ b/tinygrad/llops/ops_gpu.py @@ -42,9 +42,6 @@ class GPUBuffer: cl.enqueue_copy(cl_queue, data, self.cl, is_blocking=True) return data -def buffer_np(x): - return cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE | cl.mem_flags.COPY_HOST_PTR, hostbuf=x) - @functools.lru_cache def clbuild(name, prg): clprg = cl.Program(cl_ctx, prg).build().__getattr__(name) @@ -126,7 +123,8 @@ def contiguous(x, ret, st): def movement_op(op, x, ret, arg=None): contiguous(x, ret, ShapeTracker(*x.shape).movement_op(op, arg)) -def conv(x,w,ret,C): +def processing_op(op,x,w,ret,C): + assert op == ProcessingOps.CONV, f"{op} isn't supported" # input = (bs, groups, cin, iy, ix) # weight = (groups, rcout, cin, H, W) # output = (bs, groups, rcout, oy, ox) @@ -157,6 +155,3 @@ def conv(x,w,ret,C): }""") conv_prg([C.bs*C.groups*C.rcout, C.oy, C.ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in list(C[0:12])+[C.dx, C.dy, C.px, C.py]]) - -def processing_op(op,a,b,ret,C): - if op == ProcessingOps.CONV: conv(a,b,ret,C) diff --git a/tinygrad/llops/ops_torch.py b/tinygrad/llops/ops_torch.py index 3826b3285c..9471a9cbee 100644 --- a/tinygrad/llops/ops_torch.py +++ b/tinygrad/llops/ops_torch.py @@ -26,6 +26,7 @@ from tinygrad.llops.ops_cpu import unary_op, binary_op, reduce_op, movement_op from tinygrad.ops import ProcessingOps def processing_op(op,x,w,ret,C): + assert op == ProcessingOps.CONV, f"{op} isn't supported" # stride is the same as doing the full conv and slicing with stride at the end # dilation is the same as conving with a larger weight matrix with 0s added ret[:] = torch.conv2d(x, w, stride=(C.ys, C.xs), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px))