minor cleanups

This commit is contained in:
George Hotz
2022-06-15 22:27:46 -07:00
parent 3667200df5
commit bcfbb4c81b
3 changed files with 7 additions and 16 deletions

View File

@@ -60,23 +60,18 @@ def movement_op(op, x, ret, arg=None):
elif op == MovementOps.EXPAND: ret[:] = x.expand(arg)
else: raise Exception(f"{op} isn't supported")
def get_tx(x, C):
def processing_op(op,x,w,ret,C):
assert op == ProcessingOps.CONV, f"{op} isn't supported"
if C.px > 0 or C.py > 0: x = np.pad(x, [(0,0), (0,0), (C.py, C.py), (C.px, C.px)])
gx = x.reshape(C.bs,C.groups,C.cin,x.shape[2],x.shape[3])
return np.lib.stride_tricks.as_strided(gx,
tx = np.lib.stride_tricks.as_strided(gx,
shape=(C.bs, C.groups, C.cin, C.oy, C.ox, C.H, C.W),
strides=(*gx.strides[0:3], gx.strides[3]*C.ys, gx.strides[4]*C.xs, gx.strides[3]*C.dy, gx.strides[4]*C.dx),
writeable=False,
)
def conv(x,w,ret,C):
if C.px > 0 or C.py > 0: x = np.pad(x, [(0,0), (0,0), (C.py, C.py), (C.px, C.px)])
tx = get_tx(x, C)
tw = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W)
tmp = np.zeros((C.bs,C.groups,C.oy,C.ox,C.rcout),dtype=x.dtype)
for g in range(C.groups):
#ijYXyx,kjyx -> iYXk ->ikYX
tmp[:,g] += np.tensordot(tx[:,g], tw[g], ((1,4,5),(1,2,3)))
ret[:] = np.moveaxis(tmp,4,2).reshape(C.bs, C.groups*C.rcout, C.oy, C.ox)
def processing_op(op,a,b,ret,C):
if op == ProcessingOps.CONV: conv(a,b,ret,C)

View File

@@ -42,9 +42,6 @@ class GPUBuffer:
cl.enqueue_copy(cl_queue, data, self.cl, is_blocking=True)
return data
def buffer_np(x):
return cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE | cl.mem_flags.COPY_HOST_PTR, hostbuf=x)
@functools.lru_cache
def clbuild(name, prg):
clprg = cl.Program(cl_ctx, prg).build().__getattr__(name)
@@ -126,7 +123,8 @@ def contiguous(x, ret, st):
def movement_op(op, x, ret, arg=None):
contiguous(x, ret, ShapeTracker(*x.shape).movement_op(op, arg))
def conv(x,w,ret,C):
def processing_op(op,x,w,ret,C):
assert op == ProcessingOps.CONV, f"{op} isn't supported"
# input = (bs, groups, cin, iy, ix)
# weight = (groups, rcout, cin, H, W)
# output = (bs, groups, rcout, oy, ox)
@@ -157,6 +155,3 @@ def conv(x,w,ret,C):
}""")
conv_prg([C.bs*C.groups*C.rcout, C.oy, C.ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in list(C[0:12])+[C.dx, C.dy, C.px, C.py]])
def processing_op(op,a,b,ret,C):
if op == ProcessingOps.CONV: conv(a,b,ret,C)

View File

@@ -26,6 +26,7 @@ from tinygrad.llops.ops_cpu import unary_op, binary_op, reduce_op, movement_op
from tinygrad.ops import ProcessingOps
def processing_op(op,x,w,ret,C):
assert op == ProcessingOps.CONV, f"{op} isn't supported"
# stride is the same as doing the full conv and slicing with stride at the end
# dilation is the same as conving with a larger weight matrix with 0s added
ret[:] = torch.conv2d(x, w, stride=(C.ys, C.xs), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px))