Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-22 13:28:06 -05:00
namedtuple for conv_args
@@ -1,4 +1,5 @@
 import numpy as np
+from collections import namedtuple

 def binary_broadcast(x_shape, y_shape, extra=False):
   n_dims = max(len(x_shape), len(y_shape))
@@ -22,14 +23,16 @@ def binary_broadcast(x_shape, y_shape, extra=False):
   return (shape_ret, dimlist, complist) if extra else shape_ret

 def get_conv_args(x_shape, w_shape, stride, groups):
+  conv_args = namedtuple('conv_args',
+    ['H', 'W', 'groups', 'rcout', 'cin', 'oy', 'ox', 'iy', 'ix', 'ys', 'xs', 'bs'])
   cout,cin,H,W = w_shape
-  ys,xs = stride
+  ys,xs = (stride, stride) if isinstance(stride, int) else stride
   bs,cin_,iy,ix = x_shape
   oy,ox = (iy-(H-ys))//ys, (ix-(W-xs))//xs
-  if cin*groups != cin_: raise Exception(f"Input Tensor shape {x_shape} does not match the shape of the weights {w.shape}. ({cin*ctx.groups} vs. {cin_})")
+  if cin*groups != cin_: raise Exception(f"Input Tensor shape {x_shape} does not match the shape of the weights {w_shape}. ({cin*groups} vs. {cin_})")
   assert cout % groups == 0
   rcout = cout//groups
-  return H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs
+  return conv_args(H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs)

 from enum import Enum
 UnaryOps = Enum("UnaryOps", ["RELU", "EXP", "LOG", "NEG", "SIGN"])
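For orientation, here is the shape arithmetic above worked through by hand on hypothetical shapes (a batch of 4 RGB 32x32 images, 8 3x3 filters, stride 1, no grouping); the names mirror the helper's locals, and none of these numbers come from the commit itself:

from collections import namedtuple

ConvArgs = namedtuple('conv_args', ['H', 'W', 'groups', 'rcout', 'cin', 'oy', 'ox', 'iy', 'ix', 'ys', 'xs', 'bs'])

x_shape, w_shape, stride, groups = (4, 3, 32, 32), (8, 3, 3, 3), 1, 1
cout, cin, H, W = w_shape
ys, xs = (stride, stride) if isinstance(stride, int) else stride
bs, cin_, iy, ix = x_shape
oy, ox = (iy - (H - ys)) // ys, (ix - (W - xs)) // xs   # (32 - 2) // 1 = 30 on both axes (valid conv, no padding)
rcout = cout // groups                                  # 8 output channels per group

C = ConvArgs(H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs)
assert (C.oy, C.ox, C.rcout, C.bs) == (30, 30, 8, 4)    # named access instead of positional unpacking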
@@ -228,11 +228,7 @@ def matmul(a, b, c, transpose_a=False, transpose_b=False):
     osize)
   return c


 # TODO: combine any of these three?
-def conv(x,w,ret,conv_args):
-  H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args
-
+def conv(x,w,ret,C):
   # input  = (bs, groups, cin, iy, ix)
   # weight = (groups, rcout, cin, H, W)
   # output = (bs, groups, rcout, oy, ox)
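The layout comments describe how the kernels interpret the flat buffers: an input stored as (bs, groups*cin, iy, ix) is read as (bs, groups, cin, iy, ix), and likewise for the weights and output. A small numpy illustration of that view, with made-up sizes:

import numpy as np

bs, groups, cin, iy, ix = 2, 2, 3, 8, 8
x = np.random.randn(bs, groups*cin, iy, ix).astype(np.float32)

# Same memory, with the group axis split out, as the comments above describe.
x_grouped = x.reshape(bs, groups, cin, iy, ix)
assert x_grouped[1, 0, 2, 4, 4] == x[1, 2, 4, 4]   # group 0, channel 2 maps to flat channel 0*cin + 2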
@@ -259,15 +255,14 @@ def conv(x,w,ret,conv_args):
       output[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] = acc;
     }""")

-  conv_prg([bs*groups*rcout, oy, ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in conv_args])
+  conv_prg([C.bs*C.groups*C.rcout, C.oy, C.ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in C])
   return ret

 # tensx = (bs, groups*cin, iy, ix)
 # tensw = (groups*rcout, cin, H, W)
 # ggg = (bs, groups*rout, oy, ox)

-def convdw(x,grad_output,dw,conv_args):
-  H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args
+def convdw(x,grad_output,dw,C):
   convdw_prg = clbuild("convdw", """
   __kernel void convdw(__global const float *tensx, __global const float *ggg, __global float *dw,
     int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
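The *[i32(x) for x in C] splat keeps working after the refactor because a namedtuple iterates its fields in declaration order, and that order matches the kernel's trailing int H, int W, int groups, ..., int bs parameters. A quick sketch of that guarantee (the values are arbitrary):

from collections import namedtuple

ConvArgs = namedtuple('conv_args',
  ['H', 'W', 'groups', 'rcout', 'cin', 'oy', 'ox', 'iy', 'ix', 'ys', 'xs', 'bs'])
C = ConvArgs(3, 3, 1, 8, 3, 30, 30, 32, 32, 1, 1, 4)

# Iteration preserves declaration order, so the splatted ints line up
# with the kernel signature one-for-one.
assert list(C) == [C.H, C.W, C.groups, C.rcout, C.cin, C.oy, C.ox, C.iy, C.ix, C.ys, C.xs, C.bs]
assert C._fields == ('H', 'W', 'groups', 'rcout', 'cin', 'oy', 'ox', 'iy', 'ix', 'ys', 'xs', 'bs')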
@@ -287,11 +282,10 @@ def convdw(x,grad_output,dw,conv_args):
     } }
     dw[get_global_id(0)*H*W + y*W + x] = acc;
   }""")
-  convdw_prg([groups*rcout*cin, H, W], None, x.cl, grad_output.cl, dw.cl, *[i32(x) for x in conv_args])
+  convdw_prg([C.groups*C.rcout*C.cin, C.H, C.W], None, x.cl, grad_output.cl, dw.cl, *[i32(x) for x in C])
   return dw

-def convdx(w,grad_output,dx,conv_args):
-  H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args
+def convdx(w,grad_output,dx,C):
   convdx_prg = clbuild("convdx", """
   __kernel void convdx(__global const float *tensw, __global const float *ggg, __global float *dx,
     int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
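These kernels index flat buffers by hand; the forward kernel's output index B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X is simply the row-major flattening of a (bs, groups, rcout, oy, ox) tensor, and the dw index above flattens (groups*rcout*cin, H, W) the same way. A numpy check of the forward case, with made-up sizes:

import numpy as np

bs, groups, rcout, oy, ox = 2, 1, 4, 5, 5
B, g, c, Y, X = 1, 0, 3, 2, 4   # an arbitrary output element

flat = B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X
assert flat == np.ravel_multi_index((B, g, c, Y, X), (bs, groups, rcout, oy, ox))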
@@ -316,5 +310,5 @@ def convdx(w,grad_output,dx,conv_args):
       } }
     }
     """)
-  convdx_prg([bs, groups, cin], None, w.cl, grad_output.cl, dx.cl, *[i32(x) for x in conv_args])
+  convdx_prg([C.bs, C.groups, C.cin], None, w.cl, grad_output.cl, dx.cl, *[i32(x) for x in C])
   return dx
@@ -23,18 +23,14 @@ from tinygrad.llops.ops_cpu import unary_op, binary_op, reduce_op, reshape, perm

 # ************* processing ops *************

-def conv(x,w,ret,conv_args):
-  # TODO: replace conv_args with stride and groups everywhere
-  H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args
-  ret[:] = torch.nn.functional.conv2d(x, w, stride=(ys,xs), groups=groups)
+def conv(x,w,ret,C):
+  ret[:] = torch.nn.functional.conv2d(x, w, stride=(C.ys,C.xs), groups=C.groups)
   return ret

-def convdw(x,grad_output,dw,conv_args):
-  H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args
-  dw[:] = torch.nn.grad.conv2d_weight(x, dw.shape, grad_output, stride=(ys,xs), groups=groups)
+def convdw(x,grad_output,dw,C):
+  dw[:] = torch.nn.grad.conv2d_weight(x, dw.shape, grad_output, stride=(C.ys,C.xs), groups=C.groups)
   return dw

-def convdx(w,grad_output,dx,conv_args):
-  H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args
-  dx[:] = torch.nn.grad.conv2d_input(dx.shape, w, grad_output, stride=(ys,xs), groups=groups)
+def convdx(w,grad_output,dx,C):
+  dx[:] = torch.nn.grad.conv2d_input(dx.shape, w, grad_output, stride=(C.ys,C.xs), groups=C.groups)
   return dx
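On the torch backend the backward helpers lean on torch.nn.grad.conv2d_weight and torch.nn.grad.conv2d_input instead of autograd. A minimal sanity check of that equivalence, using hypothetical shapes that are not part of this commit:

import torch

x = torch.randn(2, 3, 8, 8, requires_grad=True)
w = torch.randn(4, 3, 3, 3, requires_grad=True)
out = torch.nn.functional.conv2d(x, w, stride=(1,1), groups=1)
grad_output = torch.randn_like(out)
out.backward(grad_output)

# The explicit-gradient helpers should reproduce what autograd computed.
dw = torch.nn.grad.conv2d_weight(x, w.shape, grad_output, stride=(1,1), groups=1)
dx = torch.nn.grad.conv2d_input(x.shape, w, grad_output, stride=(1,1), groups=1)
assert torch.allclose(dw, w.grad, atol=1e-4) and torch.allclose(dx, x.grad, atol=1e-4)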
@@ -179,14 +179,12 @@ class Matmul(Function):

 class Conv2D(Function):
   def forward(ctx, x, w, stride=1, groups=1):
-    if isinstance(ctx.stride, int): ctx.stride = (ctx.stride, ctx.stride)
-    ctx.save_for_backward(x,w)
-    H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args = get_conv_args(x.shape, w.shape, ctx.stride, ctx.groups)
-    return ctx.op.conv(x, w, ctx.buffer((bs, groups*rcout, oy, ox)), conv_args)
+    C = get_conv_args(x.shape, w.shape, stride, groups)
+    ctx.save_for_backward(x,w,C)
+    return ctx.op.conv(x, w, ctx.buffer((C.bs, C.groups*C.rcout, C.oy, C.ox)), C)

   def backward(ctx, grad_output):
-    x, w = ctx.saved_tensors
-    H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args = get_conv_args(x.shape, w.shape, ctx.stride, ctx.groups)
-    dx = ctx.op.convdx(w, grad_output, ctx.buffer((bs, groups*cin, iy, ix)), conv_args) if ctx.needs_input_grad[0] else None
-    dw = ctx.op.convdw(x, grad_output, ctx.buffer((groups*rcout, cin, H, W)), conv_args) if ctx.needs_input_grad[1] else None
+    x, w, C = ctx.saved_tensors
+    dx = ctx.op.convdx(w, grad_output, ctx.buffer((C.bs, C.groups*C.cin, C.iy, C.ix)), C) if ctx.needs_input_grad[0] else None
+    dw = ctx.op.convdw(x, grad_output, ctx.buffer((C.groups*C.rcout, C.cin, C.H, C.W)), C) if ctx.needs_input_grad[1] else None
     return dx, dw
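The payoff in Conv2D is that backward no longer re-derives the conv arguments: forward computes C once, stashes it with save_for_backward alongside the tensors, and backward reads the buffer shapes straight off the namedtuple. For concreteness, those three shapes, reusing the hypothetical 4x3x32x32 input and 8x3x3x3 weights from the earlier example:

from collections import namedtuple

ConvArgs = namedtuple('conv_args', ['H', 'W', 'groups', 'rcout', 'cin', 'oy', 'ox', 'iy', 'ix', 'ys', 'xs', 'bs'])
C = ConvArgs(H=3, W=3, groups=1, rcout=8, cin=3, oy=30, ox=30, iy=32, ix=32, ys=1, xs=1, bs=4)

out_shape = (C.bs, C.groups*C.rcout, C.oy, C.ox)   # buffer handed to conv in forward
dx_shape  = (C.bs, C.groups*C.cin, C.iy, C.ix)     # buffer handed to convdx in backward
dw_shape  = (C.groups*C.rcout, C.cin, C.H, C.W)    # buffer handed to convdw in backward
assert out_shape == (4, 8, 30, 30) and dx_shape == (4, 3, 32, 32) and dw_shape == (8, 3, 3, 3)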