diff --git a/accel/opencl/conv.cl b/accel/opencl/conv.cl index b634d65d16..39f8698a0e 100644 --- a/accel/opencl/conv.cl +++ b/accel/opencl/conv.cl @@ -51,6 +51,31 @@ __kernel void image_conv( inputLocation.y = mad24(outputRow, strideY, -paddingY); #endif +#ifdef DEPTHWISE_UNSTRIDED + for (short rfRow = 0; rfRow < filterSizeY; ++rfRow) { + float4 inputValues[4]; + inputLocation.x = startX; + for (short i = 1; i < 4; ++i) { + inputValues[i] = read_imagef(input, smp, inputLocation); + inputLocation.x += totalNumPackedOutputChannels; + } + for (short rfColumn = 0; rfColumn < filterSizeX; ++rfColumn) { + inputValues[0] = inputValues[1]; + inputValues[1] = inputValues[2]; + inputValues[2] = inputValues[3]; + inputValues[3] = read_imagef(input, smp, inputLocation); + inputLocation.x += totalNumPackedInputChannels; + float4 weightValues = read_imagef(weights, smp, weightLocation); + ++weightLocation.x; + outputValues[0] += inputValues[0] * weightValues; + outputValues[1] += inputValues[1] * weightValues; + outputValues[2] += inputValues[2] * weightValues; + outputValues[3] += inputValues[3] * weightValues; + } + ++inputLocation.y; + } +#else + for (short rfRow = 0; rfRow < filterSizeY; ++rfRow) { // numPackedInputChannelsForGroup is 1 in depthwise for (short packedInputChannel = 0; packedInputChannel < numPackedInputChannelsForGroup; ++packedInputChannel) { @@ -91,6 +116,7 @@ __kernel void image_conv( } inputLocation.y += dilationY; } +#endif // insert unary and binary ops here int2 outputLocation; diff --git a/accel/opencl/ops_opencl.py b/accel/opencl/ops_opencl.py index 39ddba25d0..b17d8c0071 100644 --- a/accel/opencl/ops_opencl.py +++ b/accel/opencl/ops_opencl.py @@ -148,6 +148,8 @@ class OpenCLBuffer(GPUBuffer): if C.bs > 1: options.append("-DBATCH") assert C.py == 0, "batched conv doesn't work with y-padding" + if C.xs == 1 and C.ys == 1 and C.dx == 1 and C.dy == 1 and C.cin == 1: options.append("-DDEPTHWISE_UNSTRIDED") + assert C.cout%4 == 0 conv_src = CONV_SRC conv_short_names = ["filterSizeX", "filterSizeY", "paddingX", "paddingY", "strideX", "strideY", "dilationX", "dilationY"]