mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
pass in shorts
This commit is contained in:
@@ -11,14 +11,16 @@ __kernel void image_conv(
|
||||
short numPackedOutputChannelsForGroup,
|
||||
short totalNumPackedOutputChannels,
|
||||
short numOutputColumns,
|
||||
short numOutputRows, short numInputRows,
|
||||
short filterSizeX, short filterSizeY,
|
||||
short numOutputRows, short numInputRows
|
||||
/*short filterSizeX, short filterSizeY,
|
||||
short paddingX, short paddingY,
|
||||
short strideX, short strideY,
|
||||
short dilationX, short dilationY
|
||||
short dilationX, short dilationY*/
|
||||
//ARGS
|
||||
) {
|
||||
|
||||
//SHORTS
|
||||
|
||||
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
|
||||
|
||||
float4 outputValues[NUM_OUTPUTS];
|
||||
@@ -91,14 +93,19 @@ __kernel void image_conv(
|
||||
}
|
||||
|
||||
// insert unary and binary ops here
|
||||
|
||||
// output to memory
|
||||
int2 outputLocation;
|
||||
short outputColumn = startOutputColumn;
|
||||
outputLocation.y = outputRow;
|
||||
for (short i = 0; i < NUM_OUTPUTS; ++i) {
|
||||
outputLocation.x = mad24(outputColumn, totalNumPackedOutputChannels, packedOutputChannel);
|
||||
//BINOP
|
||||
++outputColumn;
|
||||
}
|
||||
|
||||
// output to memory
|
||||
outputColumn = startOutputColumn;
|
||||
for (short i = 0; i < NUM_OUTPUTS; ++i) {
|
||||
outputLocation.x = mad24(outputColumn, totalNumPackedOutputChannels, packedOutputChannel);
|
||||
if (outputColumn < numOutputColumns) {
|
||||
write_imagef(output, outputLocation, outputValues[i]);
|
||||
}
|
||||
|
||||
@@ -150,14 +150,17 @@ class OpenCLBuffer(GPUBuffer):
|
||||
assert C.py == 0, "batched conv doesn't work with y-padding"
|
||||
assert C.cout%4 == 0
|
||||
conv_src = CONV_SRC
|
||||
conv_short_names = ["filterSizeX", "filterSizeY", "paddingX", "paddingY", "strideX", "strideY", "dilationX", "dilationY"]
|
||||
conv_shorts = [C.W, C.H, C.px, C.py, C.xs, C.ys, C.dx, C.dy]
|
||||
replacements["//SHORTS"] = ''.join([f"short {name} = {val};" for name,val in zip(conv_short_names, conv_shorts)])
|
||||
for k,v in replacements.items():
|
||||
conv_src = conv_src.replace(k, v)
|
||||
#print(conv_src)
|
||||
conv_args = [max(1, C.cin//4), C.groups*C.cin//4, max(1, C.rcout//4), C.cout//4, C.ox, C.oy, C.iy]
|
||||
conv_prg = CLProgram("image_conv", conv_src,
|
||||
options=tuple(options),
|
||||
argdtypes=tuple([None, None, None] + [np.int16]*15 + [None]*len(ewbufs))
|
||||
argdtypes=tuple([None, None, None] + [np.int16]*len(conv_args) + [None]*len(ewbufs))
|
||||
)
|
||||
conv_args = [max(1, C.cin//4), C.groups*C.cin//4, max(1, C.rcout//4), C.cout//4, C.ox, C.oy, C.iy, C.W, C.H, C.px, C.py, C.xs, C.ys, C.dx, C.dy]
|
||||
global_work_size = [C.cout//4, (C.ox+3)//4, C.bs*C.oy]
|
||||
conv_prg(global_work_size, None, x.image, w.image, ret.image, *conv_args, *[buf.image if 'image2d_t' in typ else buf.cl for typ, (_, buf) in zip(ewtypes, ewbufs)])
|
||||
return ret
|
||||
|
||||
Reference in New Issue
Block a user