Files
tinygrad/accel/opencl/matmul.cl
Ollin Boer Bohan 3b1767e013 Fix OpenCL Metal texture issues (#378)
* Fix OpenCL Metal texture issues

Tile CL images when needed, to fit into the 16384 max Metal image size;
gets me to ~4.8s/iteration for SD on M1 Pro with OPENCL=1 FLOAT16=1.

* Minor cleanup

* Fix mish in CI, or no-op?

* Is mish being framed?

* It would help if any of this reproduced locally

* ???

* OPT is reverted; use original mish

* Cleanup post-review

* Fix some shape usage

* Tiler tests, shouldn't oom or overflow either

* Can't CL if there's no CL?

* Run tiler tests even if GPU=1

* relu6 segfault binary chop; revert test

* relu6 segfault binary chop; revert accel

* relu6 segfault binary chop; revert . (???)

* end relu6 segfault binary chop; repo's haunted
2022-09-29 01:21:54 -04:00

50 lines
1.6 KiB
Common Lisp

//PREFIX
/*
 * matmul: each work-item accumulates a partial sum of float4 dot products
 * between texels of `input` and `weights`, stores it in local memory, and
 * after a work-group barrier selected work-items add the partial sums up as
 * float4 vectors and write one texel to `output`.
 *
 * Parameters:
 *   output        - write-only 2D image receiving the final float4 values.
 *   outputScratch - local-memory float buffer holding one partial sum per
 *                   work-item; supplied by the caller.
 *   input         - read-only 2D image, read one float4 texel at a time.
 *   weights       - read-only 2D image, read one float4 texel at a time.
 *
 * NOTE(review): this file looks like a text template. The upper-case marker
 * comments (before the kernel, at the end of the parameter list, at the top
 * of the body, and just before the final write) are presumably replaced by
 * generated code before compilation - confirm against the host code and keep
 * them byte-identical. The three upper-case location identifiers passed to
 * the image reads/writes and the loop bound numPackedInputChannelsForGroup
 * are not defined in this file; they are presumably provided by that
 * generated code or by build-time macros, and may refer to the local
 * variables declared just above their use by name - do not rename those
 * locals without checking.
 */
__kernel void matmul(
write_only image2d_t output,
__local float *outputScratch,
read_only image2d_t input,
read_only image2d_t weights
//ARGS
) {
//SHORTS
// Sampler: unnormalized (integer) coordinates, clamped addressing, nearest-texel filtering.
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
// Work-item coordinates, narrowed to short: global id 2 selects the output
// texel (and the weights row), global id 0 selects the weight column within
// a group of four.
short packedOutputChannel = get_global_id(2);
// Slot of this work-item inside its 4-float-wide scratch row: local_id(1) * 4 + local_id(0).
short scratchOffset = mad24((short)get_local_id(1), 4, (short)get_local_id(0));
short weightIndex = (short)get_global_id(0);
// fast path precompute (32x speedup)
float outputValue = 0.0f;
// Strided loop over input texels: starts at this work-item's global id 1 and
// advances by the global size of dimension 1 until the externally supplied
// bound is reached.
for (short inputSet = (short)get_global_id(1); inputSet < numPackedInputChannelsForGroup; inputSet += get_global_size(1)) {
// Declared for use by the location expression in the read below (presumably).
int2 inputLocation = (int2)(inputSet, 0);
float4 inputValues = read_imagef(input, smp, INPUT_LOCATION);
// Weights texel: x = inputSet * 4 + weightIndex, y = packedOutputChannel.
int2 weightLocation = (int2)(mad24(inputSet, 4, weightIndex), packedOutputChannel);
float4 weightValues = read_imagef(weights, smp, WEIGHT_LOCATION);
// Accumulate the 4-component dot product into the scalar partial sum.
outputValue += dot(inputValues, weightValues);
}
// Scratch slot: local_id(2) * (local_size(1) * 4) + scratchOffset, i.e. each
// local_id(2) value owns a contiguous run of local_size(1) * 4 floats.
short scratchIndex = mad24((short)get_local_id(2), mul24((short)get_local_size(1), 4), scratchOffset);
outputScratch[scratchIndex] = outputValue;
// Make every work-item's partial sum in local memory visible to the rest of
// the work-group before the reduction below reads it.
barrier(CLK_LOCAL_MEM_FENCE);
// Only work-items whose scratchOffset is 0 perform the reduction; for them
// scratchIndex is the start of the run owned by their local_id(2).
if (scratchOffset == 0) {
float4 outputValues = (float4)(0, 0, 0, 0);
// fast path
// Sum consecutive float4 groups of scratch, one group per step.
// NOTE(review): the step count is get_global_size(1) while the scratch layout
// above is strided by get_local_size(1); this reads past the run owned by this
// local_id(2) unless the two sizes are equal (one work-group along dimension
// 1) - confirm the launch configuration. Lane-wise summation also presumably
// assumes a local size of 4 along dimension 0 - TODO confirm.
for (short i = 0; i < (short)get_global_size(1); ++i) {
outputValues += vload4(0, &outputScratch[scratchIndex]);
scratchIndex += 4;
}
// insert unary and binary ops here
// Declared for use by the code spliced in at the marker below and by the
// location expression in the final write (presumably).
int2 outputLocation = (int2)(packedOutputChannel, 0);
//BINOP
// output to memory
write_imagef(output, OUTPUT_LOCATION, outputValues);
}
}