This commit is contained in:
George Hotz
2022-06-21 23:48:58 -07:00
parent 1074dfbb71
commit 9cb0522574
2 changed files with 13 additions and 3 deletions

View File

@@ -5,13 +5,15 @@
__kernel void image_conv(
read_only image2d_t input,
read_only image2d_t weights,
write_only image2d_t output,
short numPackedInputChannelsForGroup,
write_only image2d_t output
#ifndef NOARGS
,short numPackedInputChannelsForGroup,
short totalNumPackedInputChannels,
short numPackedOutputChannelsForGroup,
short totalNumPackedOutputChannels,
short numOutputColumns,
short numOutputRows, short numInputRows
#endif
/*short filterSizeX, short filterSizeY,
short paddingX, short paddingY,
short strideX, short strideY,

View File

@@ -154,11 +154,19 @@ class OpenCLBuffer(GPUBuffer):
conv_src = CONV_SRC
conv_short_names = ["filterSizeX", "filterSizeY", "paddingX", "paddingY", "strideX", "strideY", "dilationX", "dilationY"]
conv_shorts = [C.W, C.H, C.px, C.py, C.xs, C.ys, C.dx, C.dy]
conv_arg_names = ["numPackedInputChannelsForGroup", "totalNumPackedInputChannels", "numPackedOutputChannelsForGroup", "totalNumPackedOutputChannels", "numOutputColumns", "numOutputRows", "numInputRows"]
conv_args = [max(1, C.cin//4), C.groups*C.cin//4, max(1, C.rcout//4), C.cout//4, C.ox, C.oy, C.iy]
# comment out for args
conv_short_names += conv_arg_names
conv_shorts += conv_args
conv_args = []
options.append("-DNOARGS")
replacements["//SHORTS"] = ''.join([f"short {name} = {val};" for name,val in zip(conv_short_names, conv_shorts)])
for k,v in replacements.items():
conv_src = conv_src.replace(k, v)
#print(conv_src)
conv_args = [max(1, C.cin//4), C.groups*C.cin//4, max(1, C.rcout//4), C.cout//4, C.ox, C.oy, C.iy]
conv_prg = CLProgram("image_conv", conv_src,
options=tuple(options),
argdtypes=tuple([None, None, None] + [np.int16]*len(conv_args) + [None]*len(ewbufs))