mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
Add cout to conv_args as a 13th field; leave the first 12 fields unchanged
This commit is contained in:
@@ -30,7 +30,7 @@ def binary_broadcast(x_shape, y_shape, extra=False):
|
||||
def get_conv_args(x_shape, w_shape, stride, groups):
|
||||
# TODO: https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html#tensor-layout
|
||||
conv_args = namedtuple('conv_args',
|
||||
['H', 'W', 'groups', 'rcout', 'cin', 'oy', 'ox', 'iy', 'ix', 'ys', 'xs', 'bs'])
|
||||
['H', 'W', 'groups', 'rcout', 'cin', 'oy', 'ox', 'iy', 'ix', 'ys', 'xs', 'bs', 'cout'])
|
||||
cout,cin,H,W = w_shape
|
||||
ys,xs = (stride, stride) if isinstance(stride, int) else stride
|
||||
bs,cin_,iy,ix = x_shape
|
||||
@@ -38,4 +38,4 @@ def get_conv_args(x_shape, w_shape, stride, groups):
|
||||
if cin*groups != cin_: raise Exception(f"Input Tensor shape {x_shape} does not match the shape of the weights {w_shape}. ({cin*groups} vs. {cin_})")
|
||||
assert cout % groups == 0
|
||||
rcout = cout//groups
|
||||
return conv_args(H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs)
|
||||
return conv_args(H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs, cout)
|
||||
|
||||
@@ -222,7 +222,7 @@ def conv(x,w,ret,C):
|
||||
output[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] = acc;
|
||||
}""")
|
||||
|
||||
conv_prg([C.bs*C.groups*C.rcout, C.oy, C.ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in C])
|
||||
conv_prg([C.bs*C.groups*C.rcout, C.oy, C.ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in C[0:12]])
|
||||
|
||||
# tensx = (bs, groups*cin, iy, ix)
|
||||
# tensw = (groups*rcout, cin, H, W)
|
||||
@@ -248,7 +248,7 @@ def convdw(x,grad_output,dw,C):
|
||||
} }
|
||||
dw[get_global_id(0)*H*W + y*W + x] = acc;
|
||||
}""")
|
||||
convdw_prg([C.groups*C.rcout*C.cin, C.H, C.W], None, x.cl, grad_output.cl, dw.cl, *[i32(x) for x in C])
|
||||
convdw_prg([C.groups*C.rcout*C.cin, C.H, C.W], None, x.cl, grad_output.cl, dw.cl, *[i32(x) for x in C[0:12]])
|
||||
|
||||
def convdx(grad_output,w,dx,C):
|
||||
convdx_prg = clbuild("convdx", """
|
||||
@@ -275,7 +275,7 @@ def convdx(grad_output,w,dx,C):
|
||||
} }
|
||||
}
|
||||
""")
|
||||
convdx_prg([C.bs, C.groups, C.cin], None, w.cl, grad_output.cl, dx.cl, *[i32(x) for x in C])
|
||||
convdx_prg([C.bs, C.groups, C.cin], None, w.cl, grad_output.cl, dx.cl, *[i32(x) for x in C[0:12]])
|
||||
|
||||
def processing_op(op,a,b,ret,C):
|
||||
if op == ProcessingOps.CONV: conv(a,b,ret,C)
|
||||
|
||||
Reference in New Issue
Block a user