mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
noargs
This commit is contained in:
@@ -5,13 +5,15 @@
|
||||
__kernel void image_conv(
|
||||
read_only image2d_t input,
|
||||
read_only image2d_t weights,
|
||||
write_only image2d_t output,
|
||||
short numPackedInputChannelsForGroup,
|
||||
write_only image2d_t output
|
||||
#ifndef NOARGS
|
||||
,short numPackedInputChannelsForGroup,
|
||||
short totalNumPackedInputChannels,
|
||||
short numPackedOutputChannelsForGroup,
|
||||
short totalNumPackedOutputChannels,
|
||||
short numOutputColumns,
|
||||
short numOutputRows, short numInputRows
|
||||
#endif
|
||||
/*short filterSizeX, short filterSizeY,
|
||||
short paddingX, short paddingY,
|
||||
short strideX, short strideY,
|
||||
|
||||
@@ -154,11 +154,19 @@ class OpenCLBuffer(GPUBuffer):
|
||||
conv_src = CONV_SRC
|
||||
conv_short_names = ["filterSizeX", "filterSizeY", "paddingX", "paddingY", "strideX", "strideY", "dilationX", "dilationY"]
|
||||
conv_shorts = [C.W, C.H, C.px, C.py, C.xs, C.ys, C.dx, C.dy]
|
||||
conv_arg_names = ["numPackedInputChannelsForGroup", "totalNumPackedInputChannels", "numPackedOutputChannelsForGroup", "totalNumPackedOutputChannels", "numOutputColumns", "numOutputRows", "numInputRows"]
|
||||
conv_args = [max(1, C.cin//4), C.groups*C.cin//4, max(1, C.rcout//4), C.cout//4, C.ox, C.oy, C.iy]
|
||||
|
||||
# comment out for args
|
||||
conv_short_names += conv_arg_names
|
||||
conv_shorts += conv_args
|
||||
conv_args = []
|
||||
options.append("-DNOARGS")
|
||||
|
||||
replacements["//SHORTS"] = ''.join([f"short {name} = {val};" for name,val in zip(conv_short_names, conv_shorts)])
|
||||
for k,v in replacements.items():
|
||||
conv_src = conv_src.replace(k, v)
|
||||
#print(conv_src)
|
||||
conv_args = [max(1, C.cin//4), C.groups*C.cin//4, max(1, C.rcout//4), C.cout//4, C.ox, C.oy, C.iy]
|
||||
conv_prg = CLProgram("image_conv", conv_src,
|
||||
options=tuple(options),
|
||||
argdtypes=tuple([None, None, None] + [np.int16]*len(conv_args) + [None]*len(ewbufs))
|
||||
|
||||
Reference in New Issue
Block a user