mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 21:38:10 -05:00
unstrided
This commit is contained in:
@@ -51,6 +51,31 @@ __kernel void image_conv(
|
||||
inputLocation.y = mad24(outputRow, strideY, -paddingY);
|
||||
#endif
|
||||
|
||||
#ifdef DEPTHWISE_UNSTRIDED
|
||||
for (short rfRow = 0; rfRow < filterSizeY; ++rfRow) {
|
||||
float4 inputValues[4];
|
||||
inputLocation.x = startX;
|
||||
for (short i = 1; i < 4; ++i) {
|
||||
inputValues[i] = read_imagef(input, smp, inputLocation);
|
||||
inputLocation.x += totalNumPackedOutputChannels;
|
||||
}
|
||||
for (short rfColumn = 0; rfColumn < filterSizeX; ++rfColumn) {
|
||||
inputValues[0] = inputValues[1];
|
||||
inputValues[1] = inputValues[2];
|
||||
inputValues[2] = inputValues[3];
|
||||
inputValues[3] = read_imagef(input, smp, inputLocation);
|
||||
inputLocation.x += totalNumPackedInputChannels;
|
||||
float4 weightValues = read_imagef(weights, smp, weightLocation);
|
||||
++weightLocation.x;
|
||||
outputValues[0] += inputValues[0] * weightValues;
|
||||
outputValues[1] += inputValues[1] * weightValues;
|
||||
outputValues[2] += inputValues[2] * weightValues;
|
||||
outputValues[3] += inputValues[3] * weightValues;
|
||||
}
|
||||
++inputLocation.y;
|
||||
}
|
||||
#else
|
||||
|
||||
for (short rfRow = 0; rfRow < filterSizeY; ++rfRow) {
|
||||
// numPackedInputChannelsForGroup is 1 in depthwise
|
||||
for (short packedInputChannel = 0; packedInputChannel < numPackedInputChannelsForGroup; ++packedInputChannel) {
|
||||
@@ -91,6 +116,7 @@ __kernel void image_conv(
|
||||
}
|
||||
inputLocation.y += dilationY;
|
||||
}
|
||||
#endif
|
||||
|
||||
// insert unary and binary ops here
|
||||
int2 outputLocation;
|
||||
|
||||
@@ -148,6 +148,8 @@ class OpenCLBuffer(GPUBuffer):
|
||||
if C.bs > 1:
|
||||
options.append("-DBATCH")
|
||||
assert C.py == 0, "batched conv doesn't work with y-padding"
|
||||
if C.xs == 1 and C.ys == 1 and C.dx == 1 and C.dy == 1 and C.cin == 1: options.append("-DDEPTHWISE_UNSTRIDED")
|
||||
|
||||
assert C.cout%4 == 0
|
||||
conv_src = CONV_SRC
|
||||
conv_short_names = ["filterSizeX", "filterSizeY", "paddingX", "paddingY", "strideX", "strideY", "dilationX", "dilationY"]
|
||||
|
||||
Reference in New Issue
Block a user