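pass in shorts (filter size, padding, stride and dilation are now written into the kernel source as `short` declarations at the //SHORTS placeholder instead of being passed as kernel arguments)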

This commit is contained in:
George Hotz
2022-06-21 23:33:23 -07:00
parent 18d74c01b1
commit 9ae01290ba
2 changed files with 17 additions and 7 deletions

View File

@@ -11,14 +11,16 @@ __kernel void image_conv(
short numPackedOutputChannelsForGroup,
short totalNumPackedOutputChannels,
short numOutputColumns,
short numOutputRows, short numInputRows,
short filterSizeX, short filterSizeY,
short numOutputRows, short numInputRows
/*short filterSizeX, short filterSizeY,
short paddingX, short paddingY,
short strideX, short strideY,
short dilationX, short dilationY
short dilationX, short dilationY*/
//ARGS
) {
//SHORTS
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
float4 outputValues[NUM_OUTPUTS];
@@ -91,14 +93,19 @@ __kernel void image_conv(
}
// insert unary and binary ops here
// output to memory
int2 outputLocation;
short outputColumn = startOutputColumn;
outputLocation.y = outputRow;
for (short i = 0; i < NUM_OUTPUTS; ++i) {
outputLocation.x = mad24(outputColumn, totalNumPackedOutputChannels, packedOutputChannel);
//BINOP
++outputColumn;
}
// output to memory
outputColumn = startOutputColumn;
for (short i = 0; i < NUM_OUTPUTS; ++i) {
outputLocation.x = mad24(outputColumn, totalNumPackedOutputChannels, packedOutputChannel);
if (outputColumn < numOutputColumns) {
write_imagef(output, outputLocation, outputValues[i]);
}

View File

@@ -150,14 +150,17 @@ class OpenCLBuffer(GPUBuffer):
assert C.py == 0, "batched conv doesn't work with y-padding"
assert C.cout%4 == 0
conv_src = CONV_SRC
conv_short_names = ["filterSizeX", "filterSizeY", "paddingX", "paddingY", "strideX", "strideY", "dilationX", "dilationY"]
conv_shorts = [C.W, C.H, C.px, C.py, C.xs, C.ys, C.dx, C.dy]
replacements["//SHORTS"] = ''.join([f"short {name} = {val};" for name,val in zip(conv_short_names, conv_shorts)])
for k,v in replacements.items():
conv_src = conv_src.replace(k, v)
#print(conv_src)
conv_args = [max(1, C.cin//4), C.groups*C.cin//4, max(1, C.rcout//4), C.cout//4, C.ox, C.oy, C.iy]
conv_prg = CLProgram("image_conv", conv_src,
options=tuple(options),
argdtypes=tuple([None, None, None] + [np.int16]*15 + [None]*len(ewbufs))
argdtypes=tuple([None, None, None] + [np.int16]*len(conv_args) + [None]*len(ewbufs))
)
conv_args = [max(1, C.cin//4), C.groups*C.cin//4, max(1, C.rcout//4), C.cout//4, C.ox, C.oy, C.iy, C.W, C.H, C.px, C.py, C.xs, C.ys, C.dx, C.dy]
global_work_size = [C.cout//4, (C.ox+3)//4, C.bs*C.oy]
conv_prg(global_work_size, None, x.image, w.image, ret.image, *conv_args, *[buf.image if 'image2d_t' in typ else buf.cl for typ, (_, buf) in zip(ewtypes, ewbufs)])
return ret