minor cleanups
@@ -60,23 +60,18 @@ def movement_op(op, x, ret, arg=None):
   elif op == MovementOps.EXPAND: ret[:] = x.expand(arg)
   else: raise Exception(f"{op} isn't supported")
 
-def get_tx(x, C):
+def processing_op(op,x,w,ret,C):
+  assert op == ProcessingOps.CONV, f"{op} isn't supported"
+  if C.px > 0 or C.py > 0: x = np.pad(x, [(0,0), (0,0), (C.py, C.py), (C.px, C.px)])
   gx = x.reshape(C.bs,C.groups,C.cin,x.shape[2],x.shape[3])
-  return np.lib.stride_tricks.as_strided(gx,
+  tx = np.lib.stride_tricks.as_strided(gx,
     shape=(C.bs, C.groups, C.cin, C.oy, C.ox, C.H, C.W),
     strides=(*gx.strides[0:3], gx.strides[3]*C.ys, gx.strides[4]*C.xs, gx.strides[3]*C.dy, gx.strides[4]*C.dx),
     writeable=False,
   )
-
-def conv(x,w,ret,C):
-  if C.px > 0 or C.py > 0: x = np.pad(x, [(0,0), (0,0), (C.py, C.py), (C.px, C.px)])
-  tx = get_tx(x, C)
   tw = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W)
   tmp = np.zeros((C.bs,C.groups,C.oy,C.ox,C.rcout),dtype=x.dtype)
   for g in range(C.groups):
     #ijYXyx,kjyx -> iYXk ->ikYX
     tmp[:,g] += np.tensordot(tx[:,g], tw[g], ((1,4,5),(1,2,3)))
   ret[:] = np.moveaxis(tmp,4,2).reshape(C.bs, C.groups*C.rcout, C.oy, C.ox)
-
-def processing_op(op,a,b,ret,C):
-  if op == ProcessingOps.CONV: conv(a,b,ret,C)
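Note on the CPU hunk: the merged processing_op builds a zero-copy (bs, groups, cin, oy, ox, H, W) window view of the padded input with as_strided, then tensordot contracts the (cin, H, W) axes against the weights per group. Below is a minimal sketch of the same trick, with groups, dilation and padding left out and all sizes illustrative, checked against a naive sliding-window loop; it is an independent example, not code from this commit.

import numpy as np

# toy sizes (illustrative only)
bs, cin, rcout, iy, ix, H, W, ys, xs = 2, 3, 4, 10, 10, 3, 3, 2, 2
oy, ox = (iy - H)//ys + 1, (ix - W)//xs + 1

x = np.random.randn(bs, cin, iy, ix).astype(np.float32)
w = np.random.randn(rcout, cin, H, W).astype(np.float32)

# strided window view: (bs, cin, oy, ox, H, W), no data is copied
s = x.strides
tx = np.lib.stride_tricks.as_strided(x,
  shape=(bs, cin, oy, ox, H, W),
  strides=(s[0], s[1], s[2]*ys, s[3]*xs, s[2], s[3]),
  writeable=False)

# bcYXyx, kcyx -> bYXk -> bkYX
out = np.moveaxis(np.tensordot(tx, w, ((1,4,5),(1,2,3))), 3, 1)

# naive reference: slide the window explicitly
ref = np.zeros_like(out)
for Y in range(oy):
  for X in range(ox):
    patch = x[:, :, Y*ys:Y*ys+H, X*xs:X*xs+W]               # (bs, cin, H, W)
    ref[:, :, Y, X] = np.tensordot(patch, w, ((1,2,3),(1,2,3)))
assert np.allclose(out, ref, atol=1e-4)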
@@ -42,9 +42,6 @@ class GPUBuffer:
     cl.enqueue_copy(cl_queue, data, self.cl, is_blocking=True)
     return data
 
-def buffer_np(x):
-  return cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE | cl.mem_flags.COPY_HOST_PTR, hostbuf=x)
-
 @functools.lru_cache
 def clbuild(name, prg):
   clprg = cl.Program(cl_ctx, prg).build().__getattr__(name)
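Note on the GPU hunk: buffer_np is dropped here, and the surrounding context shows the clbuild pattern that memoizes kernel compilation with functools.lru_cache. The following is a standalone pyopencl sketch of that pattern; the context/queue setup, the example kernel, and the usage are illustrative assumptions, not part of this commit.

import functools
import numpy as np
import pyopencl as cl

cl_ctx = cl.create_some_context()          # illustrative: one default context/queue
cl_queue = cl.CommandQueue(cl_ctx)

@functools.lru_cache(maxsize=None)
def clbuild(name, prg):
  # compiling is slow; lru_cache hands back the same kernel for a repeated (name, prg) pair
  return cl.Program(cl_ctx, prg).build().__getattr__(name)

def buffer_np(x):
  # copy a numpy array into a read/write device buffer
  return cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE | cl.mem_flags.COPY_HOST_PTR, hostbuf=x)

# usage: build once, reuse on later calls
add = clbuild("add", """
__kernel void add(__global const float *a, __global const float *b, __global float *out) {
  int i = get_global_id(0);
  out[i] = a[i] + b[i];
}""")
a, b = np.arange(16, dtype=np.float32), np.ones(16, dtype=np.float32)
out = cl.Buffer(cl_ctx, cl.mem_flags.WRITE_ONLY, a.nbytes)
add(cl_queue, [a.size], None, buffer_np(a), buffer_np(b), out)
res = np.empty_like(a)
cl.enqueue_copy(cl_queue, res, out, is_blocking=True)
assert np.allclose(res, a + b)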
@@ -126,7 +123,8 @@ def contiguous(x, ret, st):
 def movement_op(op, x, ret, arg=None):
   contiguous(x, ret, ShapeTracker(*x.shape).movement_op(op, arg))
 
-def conv(x,w,ret,C):
+def processing_op(op,x,w,ret,C):
+  assert op == ProcessingOps.CONV, f"{op} isn't supported"
   # input = (bs, groups, cin, iy, ix)
   # weight = (groups, rcout, cin, H, W)
   # output = (bs, groups, rcout, oy, ox)
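Note on the shape comments: oy and ox are the output spatial sizes, and this diff does not show how they are derived. Assuming the standard conv2d output-size formula, the bookkeeping works out as in this illustrative example (all numbers made up):

# hypothetical conv params, mirroring the comment layout above
bs, groups, cin, iy, ix = 4, 2, 8, 32, 32        # input  = (bs, groups, cin, iy, ix)
rcout, H, W = 16, 3, 3                           # weight = (groups, rcout, cin, H, W)
ys, xs, dy, dx, py, px = 2, 2, 1, 1, 1, 1        # stride, dilation, padding

# standard conv output-size formula (assumed, not shown in the diff)
oy = (iy + 2*py - dy*(H-1) - 1)//ys + 1          # -> 16
ox = (ix + 2*px - dx*(W-1) - 1)//xs + 1          # -> 16
assert (bs, groups, rcout, oy, ox) == (4, 2, 16, 16, 16)  # output = (bs, groups, rcout, oy, ox)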
@@ -157,6 +155,3 @@ def conv(x,w,ret,C):
 }""")
 
   conv_prg([C.bs*C.groups*C.rcout, C.oy, C.ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in list(C[0:12])+[C.dx, C.dy, C.px, C.py]])
-
-def processing_op(op,a,b,ret,C):
-  if op == ProcessingOps.CONV: conv(a,b,ret,C)
@@ -26,6 +26,7 @@ from tinygrad.llops.ops_cpu import unary_op, binary_op, reduce_op, movement_op
 from tinygrad.ops import ProcessingOps
 
 def processing_op(op,x,w,ret,C):
+  assert op == ProcessingOps.CONV, f"{op} isn't supported"
   # stride is the same as doing the full conv and slicing with stride at the end
   # dilation is the same as conving with a larger weight matrix with 0s added
   ret[:] = torch.conv2d(x, w, stride=(C.ys, C.xs), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px))
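Note on the torch hunk: the two comments can be checked numerically. A small hedged sanity check with toy tensors (not part of the commit), using stride 2 and dilation 2:

import torch

x = torch.randn(1, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)

# stride is the same as doing the full conv and slicing with stride at the end
full = torch.conv2d(x, w)
assert torch.allclose(torch.conv2d(x, w, stride=2), full[:, :, ::2, ::2], atol=1e-5)

# dilation is the same as conving with a larger weight matrix with 0s added
dw = torch.zeros(4, 3, 5, 5)       # 3 taps spread 2 apart -> effective 5x5 kernel
dw[:, :, ::2, ::2] = w
assert torch.allclose(torch.conv2d(x, w, dilation=2), torch.conv2d(x, dw), atol=1e-5)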