Add cout to conv_args as a 13th field; the kernel launches still pass only the first 12 fields (C[0:12]), so their arguments are unchanged

This commit is contained in:
George Hotz
2022-06-12 00:10:15 -07:00
parent af300b121b
commit d47a421970
2 changed files with 5 additions and 5 deletions

View File

@@ -30,7 +30,7 @@ def binary_broadcast(x_shape, y_shape, extra=False):
def get_conv_args(x_shape, w_shape, stride, groups):
# TODO: https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html#tensor-layout
conv_args = namedtuple('conv_args',
['H', 'W', 'groups', 'rcout', 'cin', 'oy', 'ox', 'iy', 'ix', 'ys', 'xs', 'bs'])
['H', 'W', 'groups', 'rcout', 'cin', 'oy', 'ox', 'iy', 'ix', 'ys', 'xs', 'bs', 'cout'])
cout,cin,H,W = w_shape
ys,xs = (stride, stride) if isinstance(stride, int) else stride
bs,cin_,iy,ix = x_shape
@@ -38,4 +38,4 @@ def get_conv_args(x_shape, w_shape, stride, groups):
if cin*groups != cin_: raise Exception(f"Input Tensor shape {x_shape} does not match the shape of the weights {w_shape}. ({cin*groups} vs. {cin_})")
assert cout % groups == 0
rcout = cout//groups
return conv_args(H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs)
return conv_args(H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs, cout)

View File

@@ -222,7 +222,7 @@ def conv(x,w,ret,C):
output[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] = acc;
}""")
conv_prg([C.bs*C.groups*C.rcout, C.oy, C.ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in C])
conv_prg([C.bs*C.groups*C.rcout, C.oy, C.ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in C[0:12]])
# tensx = (bs, groups*cin, iy, ix)
# tensw = (groups*rcout, cin, H, W)
@@ -248,7 +248,7 @@ def convdw(x,grad_output,dw,C):
} }
dw[get_global_id(0)*H*W + y*W + x] = acc;
}""")
convdw_prg([C.groups*C.rcout*C.cin, C.H, C.W], None, x.cl, grad_output.cl, dw.cl, *[i32(x) for x in C])
convdw_prg([C.groups*C.rcout*C.cin, C.H, C.W], None, x.cl, grad_output.cl, dw.cl, *[i32(x) for x in C[0:12]])
def convdx(grad_output,w,dx,C):
convdx_prg = clbuild("convdx", """
@@ -275,7 +275,7 @@ def convdx(grad_output,w,dx,C):
} }
}
""")
convdx_prg([C.bs, C.groups, C.cin], None, w.cl, grad_output.cl, dx.cl, *[i32(x) for x in C])
convdx_prg([C.bs, C.groups, C.cin], None, w.cl, grad_output.cl, dx.cl, *[i32(x) for x in C[0:12]])
def processing_op(op,a,b,ret,C):
if op == ProcessingOps.CONV: conv(a,b,ret,C)