namedtuple for conv_args

George Hotz
2022-06-09 09:19:52 -07:00
parent d0c3204996
commit 40f8eb7383
4 changed files with 24 additions and 33 deletions

View File

@@ -1,4 +1,5 @@
 import numpy as np
+from collections import namedtuple
 
 def binary_broadcast(x_shape, y_shape, extra=False):
   n_dims = max(len(x_shape), len(y_shape))
@@ -22,14 +23,16 @@ def binary_broadcast(x_shape, y_shape, extra=False):
   return (shape_ret, dimlist, complist) if extra else shape_ret
 
 def get_conv_args(x_shape, w_shape, stride, groups):
+  conv_args = namedtuple('conv_args',
+    ['H', 'W', 'groups', 'rcout', 'cin', 'oy', 'ox', 'iy', 'ix', 'ys', 'xs', 'bs'])
   cout,cin,H,W = w_shape
-  ys,xs = stride
+  ys,xs = (stride, stride) if isinstance(stride, int) else stride
   bs,cin_,iy,ix = x_shape
   oy,ox = (iy-(H-ys))//ys, (ix-(W-xs))//xs
-  if cin*groups != cin_: raise Exception(f"Input Tensor shape {x_shape} does not match the shape of the weights {w.shape}. ({cin*ctx.groups} vs. {cin_})")
+  if cin*groups != cin_: raise Exception(f"Input Tensor shape {x_shape} does not match the shape of the weights {w_shape}. ({cin*groups} vs. {cin_})")
   assert cout % groups == 0
   rcout = cout//groups
-  return H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs
+  return conv_args(H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs)
 
 from enum import Enum
 UnaryOps = Enum("UnaryOps", ["RELU", "EXP", "LOG", "NEG", "SIGN"])
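
For reference, the returned conv_args namedtuple carries the full conv geometry under named fields instead of relying on positional unpacking. A minimal usage sketch with made-up shapes (assuming get_conv_args is imported from the module edited above):

# hypothetical shapes: batch 1, 8 input channels split into 2 groups, 32x32 image,
# 4 output channels, 3x3 kernel, stride 2
C = get_conv_args(x_shape=(1, 8, 32, 32), w_shape=(4, 4, 3, 3), stride=2, groups=2)
print(C.bs, C.groups, C.cin, C.rcout)  # 1 2 4 2
print(C.iy, C.ix, C.oy, C.ox)          # 32 32 15 15, since oy = (iy-(H-ys))//ys
print(C.ys, C.xs)                      # 2 2 -- an int stride is normalized to a (ys, xs) pair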

View File

@@ -228,11 +228,7 @@ def matmul(a, b, c, transpose_a=False, transpose_b=False):
     osize)
   return c
 
 # TODO: combine any of these three?
-def conv(x,w,ret,conv_args):
-  H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args
+def conv(x,w,ret,C):
   # input = (bs, groups, cin, iy, ix)
   # weight = (groups, rcout, cin, H, W)
   # output = (bs, groups, rcout, oy, ox)
@@ -259,15 +255,14 @@ def conv(x,w,ret,conv_args):
       output[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] = acc;
   }""")
-  conv_prg([bs*groups*rcout, oy, ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in conv_args])
+  conv_prg([C.bs*C.groups*C.rcout, C.oy, C.ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in C])
   return ret
 
 # tensx = (bs, groups*cin, iy, ix)
 # tensw = (groups*rcout, cin, H, W)
 # ggg = (bs, groups*rout, oy, ox)
-def convdw(x,grad_output,dw,conv_args):
-  H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args
+def convdw(x,grad_output,dw,C):
   convdw_prg = clbuild("convdw", """
   __kernel void convdw(__global const float *tensx, __global const float *ggg, __global float *dw,
     int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
@@ -287,11 +282,10 @@ def convdw(x,grad_output,dw,conv_args):
     } }
     dw[get_global_id(0)*H*W + y*W + x] = acc;
   }""")
-  convdw_prg([groups*rcout*cin, H, W], None, x.cl, grad_output.cl, dw.cl, *[i32(x) for x in conv_args])
+  convdw_prg([C.groups*C.rcout*C.cin, C.H, C.W], None, x.cl, grad_output.cl, dw.cl, *[i32(x) for x in C])
   return dw
 
-def convdx(w,grad_output,dx,conv_args):
-  H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args
+def convdx(w,grad_output,dx,C):
   convdx_prg = clbuild("convdx", """
   __kernel void convdx(__global const float *tensw, __global const float *ggg, __global float *dx,
     int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
@@ -316,5 +310,5 @@ def convdx(w,grad_output,dx,conv_args):
     } }
   }
   """)
-  convdx_prg([bs, groups, cin], None, w.cl, grad_output.cl, dx.cl, *[i32(x) for x in conv_args])
+  convdx_prg([C.bs, C.groups, C.cin], None, w.cl, grad_output.cl, dx.cl, *[i32(x) for x in C])
   return dx
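
The *[i32(x) for x in C] splat keeps working after this change because a namedtuple iterates over its fields in declaration order, which is exactly the order of the int parameters in the three kernels (H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs). A standalone sketch of that property; i32 here is a stand-in assumption for the helper used by these calls:

from collections import namedtuple
import numpy as np

conv_args = namedtuple('conv_args', ['H', 'W', 'groups', 'rcout', 'cin', 'oy', 'ox', 'iy', 'ix', 'ys', 'xs', 'bs'])
i32 = np.int32  # stand-in for the i32 helper the kernel calls use

C = conv_args(3, 3, 1, 4, 4, 30, 30, 32, 32, 1, 1, 2)
# iteration order == field declaration order, so the positional kernel args are unchanged
assert list(C) == [C.H, C.W, C.groups, C.rcout, C.cin, C.oy, C.ox, C.iy, C.ix, C.ys, C.xs, C.bs]
kernel_ints = [i32(x) for x in C]  # what gets splatted after x.cl, w.cl, ret.cl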

View File

@@ -23,18 +23,14 @@ from tinygrad.llops.ops_cpu import unary_op, binary_op, reduce_op, reshape, perm
 
 # ************* processing ops *************
 
-def conv(x,w,ret,conv_args):
-  # TODO: replace conv_args with stride and groups everywhere
-  H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args
-  ret[:] = torch.nn.functional.conv2d(x, w, stride=(ys,xs), groups=groups)
+def conv(x,w,ret,C):
+  ret[:] = torch.nn.functional.conv2d(x, w, stride=(C.ys,C.xs), groups=C.groups)
   return ret
 
-def convdw(x,grad_output,dw,conv_args):
-  H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args
-  dw[:] = torch.nn.grad.conv2d_weight(x, dw.shape, grad_output, stride=(ys,xs), groups=groups)
+def convdw(x,grad_output,dw,C):
+  dw[:] = torch.nn.grad.conv2d_weight(x, dw.shape, grad_output, stride=(C.ys,C.xs), groups=C.groups)
   return dw
 
-def convdx(w,grad_output,dx,conv_args):
-  H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args
-  dx[:] = torch.nn.grad.conv2d_input(dx.shape, w, grad_output, stride=(ys,xs), groups=groups)
+def convdx(w,grad_output,dx,C):
+  dx[:] = torch.nn.grad.conv2d_input(dx.shape, w, grad_output, stride=(C.ys,C.xs), groups=C.groups)
   return dx
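
With the named fields, the torch backend stays a thin wrapper: conv2d for the forward pass and torch's conv2d_input / conv2d_weight gradient helpers for the backward passes, all driven by C.ys, C.xs and C.groups. A self-contained sanity check with assumed shapes (not taken from the commit):

import torch

x = torch.randn(2, 4, 16, 16)              # (bs, groups*cin, iy, ix)
w = torch.randn(8, 4, 3, 3)                # (cout, cin, H, W)
ys, xs, groups = 2, 1, 1                   # what C.ys, C.xs, C.groups would hold
out = torch.nn.functional.conv2d(x, w, stride=(ys, xs), groups=groups)            # conv
go = torch.randn_like(out)                                                        # grad_output
dw = torch.nn.grad.conv2d_weight(x, w.shape, go, stride=(ys, xs), groups=groups)  # convdw
dx = torch.nn.grad.conv2d_input(x.shape, w, go, stride=(ys, xs), groups=groups)   # convdx
assert dw.shape == w.shape and dx.shape == x.shape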

View File

@@ -179,14 +179,12 @@ class Matmul(Function):
 
 class Conv2D(Function):
   def forward(ctx, x, w, stride=1, groups=1):
-    if isinstance(ctx.stride, int): ctx.stride = (ctx.stride, ctx.stride)
-    ctx.save_for_backward(x,w)
-    H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args = get_conv_args(x.shape, w.shape, ctx.stride, ctx.groups)
-    return ctx.op.conv(x, w, ctx.buffer((bs, groups*rcout, oy, ox)), conv_args)
+    C = get_conv_args(x.shape, w.shape, stride, groups)
+    ctx.save_for_backward(x,w,C)
+    return ctx.op.conv(x, w, ctx.buffer((C.bs, C.groups*C.rcout, C.oy, C.ox)), C)
 
   def backward(ctx, grad_output):
-    x, w = ctx.saved_tensors
-    H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args = get_conv_args(x.shape, w.shape, ctx.stride, ctx.groups)
-    dx = ctx.op.convdx(w, grad_output, ctx.buffer((bs, groups*cin, iy, ix)), conv_args) if ctx.needs_input_grad[0] else None
-    dw = ctx.op.convdw(x, grad_output, ctx.buffer((groups*rcout, cin, H, W)), conv_args) if ctx.needs_input_grad[1] else None
+    x, w, C = ctx.saved_tensors
+    dx = ctx.op.convdx(w, grad_output, ctx.buffer((C.bs, C.groups*C.cin, C.iy, C.ix)), C) if ctx.needs_input_grad[0] else None
+    dw = ctx.op.convdw(x, grad_output, ctx.buffer((C.groups*C.rcout, C.cin, C.H, C.W)), C) if ctx.needs_input_grad[1] else None
     return dx, dw
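
The buffer shapes in forward and backward all come from the same C. An illustrative walk-through with assumed numbers (bs=2, groups=2, cin=3, rcout=4, 5x5 kernel, 28x28 input, stride 1):

from collections import namedtuple

conv_args = namedtuple('conv_args', ['H', 'W', 'groups', 'rcout', 'cin', 'oy', 'ox', 'iy', 'ix', 'ys', 'xs', 'bs'])
C = conv_args(H=5, W=5, groups=2, rcout=4, cin=3, oy=24, ox=24, iy=28, ix=28, ys=1, xs=1, bs=2)

out_shape = (C.bs, C.groups*C.rcout, C.oy, C.ox)  # forward output buffer: (2, 8, 24, 24)
dx_shape  = (C.bs, C.groups*C.cin, C.iy, C.ix)    # matches x.shape:       (2, 6, 28, 28)
dw_shape  = (C.groups*C.rcout, C.cin, C.H, C.W)   # matches w.shape:       (8, 3, 5, 5)

Passing C through ctx.save_for_backward also means backward no longer recomputes get_conv_args from ctx.stride and ctx.groups.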