update openpilot slow conv uop ast (#13197)

the two remaining slow ones
This commit is contained in:
chenyu
2025-11-10 14:03:20 -08:00
committed by GitHub
parent 0c978d45e6
commit 829cdafccc
2 changed files with 56 additions and 236 deletions

View File

@@ -147,10 +147,10 @@ src = renderer.render(uops)
lib = compiler.compile(src)
ps = ProgramSpec("cat", src, Device.DEFAULT, ast, uops)
print(ps.src)
print(ps.applied_opts)
# TODO: this is faster with no GROUP and with NOLOCALS
# (Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UNROLL, axis=19, arg=4), Opt(op=OptOps.UNROLL, axis=17, arg=4), Opt(op=OptOps.UNROLL, axis=15, arg=4), Opt(op=OptOps.UNROLL, axis=13, arg=4), Opt(op=OptOps.UNROLL, axis=11, arg=4), Opt(op=OptOps.UNROLL, axis=9, arg=4), Opt(op=OptOps.UNROLL, axis=7, arg=4), Opt(op=OptOps.UNROLL, axis=5, arg=4), Opt(op=OptOps.UNROLL, axis=3, arg=4), Opt(op=OptOps.UNROLL, axis=1, arg=4), Opt(op=OptOps.GROUPTOP, axis=0, arg=16))
# print(ps.src)
# print(ps.applied_opts)
# NOTE: this is faster with no GROUP and with NOLOCALS
# (Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UNROLL, axis=19, arg=4), Opt(op=OptOps.UNROLL, axis=17, arg=4), Opt(op=OptOps.UNROLL, axis=15, arg=4), Opt(op=OptOps.UNROLL, axis=13, arg=4), Opt(op=OptOps.UNROLL, axis=11, arg=4), Opt(op=OptOps.UNROLL, axis=9, arg=4), Opt(op=OptOps.UNROLL, axis=7, arg=4), Opt(op=OptOps.UNROLL, axis=5, arg=4), Opt(op=OptOps.UNROLL, axis=3, arg=4), Opt(op=OptOps.UNROLL, axis=1, arg=4), Opt(op=OptOps.NOLOCALS, axis=None, arg=None))
cr = CompiledRunner(ps, precompiled=lib)
gs = sorted(dedup([u for u in ast.toposort() if u.op is Ops.DEFINE_GLOBAL]), key=lambda u: u.arg)

View File

@@ -1,242 +1,65 @@
# ruff: noqa: E501
# ruff: noqa: E501 E712
from tinygrad import dtypes, Device
from tinygrad.uop.ops import UOp, AxisType, Ops, KernelInfo
from tinygrad.codegen import full_rewrite
from tinygrad.codegen.opt import Opt, OptOps
# from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.renderer import ProgramSpec
from tinygrad.engine.realize import CompiledRunner
from tinygrad.helpers import dedup
from tinygrad.helpers import dedup, getenv
from tinygrad.device import Buffer
from tinygrad.dtype import ImageDType
from tinygrad.dtype import ImageDType, Invalid
# PYTHONPATH="." DEBUG=5 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx
# kernel 672
# faster on d59d4cd, 50% slower with the new linearizer
# PYTHONPATH="." DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx
""" d59d4cd
c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((32, 1024, 4)), arg=0, src=())
c1 = UOp.range(UOp.const(dtypes.index, 64), 3, AxisType.LOOP)
c2 = UOp.range(UOp.const(dtypes.index, 64), 4, AxisType.LOOP)
c3 = UOp.range(UOp.const(dtypes.index, 32), 2, AxisType.LOOP)
c4 = (((c1*UOp.const(dtypes.index, 64))+c2)+(c3*UOp.const(dtypes.index, 4096)))
c5 = UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((32, 1024, 4)), arg=1, src=())
c6 = c5.index(c4).load()
c7 = UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((32, 3072, 4)), arg=2, src=())
c8 = UOp.range(UOp.const(dtypes.index, 48), 0, AxisType.REDUCE)
c9 = UOp.range(UOp.const(dtypes.index, 4), 1, AxisType.REDUCE)
c10 = c7.index(((((c8*UOp.const(dtypes.index, 4))+c9)+(c1*UOp.const(dtypes.index, 192)))+(c3*UOp.const(dtypes.index, 12288)))).load()
c11 = UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((16, 192, 4)), arg=3, src=())
c12 = c11.index(((((c9*UOp.const(dtypes.index, 4))+(c2%UOp.const(dtypes.index, 4)))+(c8*UOp.const(dtypes.index, 16)))+((c2//UOp.const(dtypes.index, 4))*UOp.const(dtypes.index, 768)))).load()
c13 = UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(64), arg=4, src=())
c14 = c13.index(c2).load()
c15 = UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(64), arg=5, src=())
c16 = c15.index(c2).load()
c17 = (c6+(((c10*c12.cast(dtypes.float)).cast(dtypes.float).reduce(c8, c9, arg=Ops.ADD)+c14.cast(dtypes.float))*c16.cast(dtypes.float)))
c18 = c0.index(c4).store(c17, c3, c1, c2)
ast = c18.sink()
more upcast axis : [(3, 320, 0, 4)]
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void r_512_16_4_4_48_4(write_only image2d_t data0_131072, read_only image2d_t data1_131072, read_only image2d_t data2_393216, read_only image2d_t data3_12288, __global half* data4_64, __global half* data5_64) {
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
float acc0[16];
int idx0 = get_global_id(0); /* 16 */
int idx1 = get_global_id(1); /* 512 */
int alu0 = (idx1>>4);
*(acc0+0) = 0.0f;
*(acc0+1) = 0.0f;
*(acc0+2) = 0.0f;
*(acc0+3) = 0.0f;
*(acc0+4) = 0.0f;
*(acc0+5) = 0.0f;
*(acc0+6) = 0.0f;
*(acc0+7) = 0.0f;
*(acc0+8) = 0.0f;
*(acc0+9) = 0.0f;
*(acc0+10) = 0.0f;
*(acc0+11) = 0.0f;
*(acc0+12) = 0.0f;
*(acc0+13) = 0.0f;
*(acc0+14) = 0.0f;
*(acc0+15) = 0.0f;
for (int Ridx0 = 0; Ridx0 < 48; Ridx0++) {
int alu17 = ((idx1*192)+Ridx0);
int alu18 = (alu17+48);
int alu19 = (alu17+96);
int alu20 = (alu17+144);
int alu21 = (Ridx0<<2);
float4 val0 = read_imagef(data3_12288, smp, (int2)(alu21,idx0));
float4 val1 = read_imagef(data3_12288, smp, (int2)((alu21+1),idx0));
float4 val2 = read_imagef(data3_12288, smp, (int2)((alu21+2),idx0));
float4 val3 = read_imagef(data3_12288, smp, (int2)((alu21+3),idx0));
float4 val4 = read_imagef(data2_393216, smp, (int2)((alu18-(3072*(((alu18>>10)*43)>>7))),alu0));
float4 val5 = read_imagef(data2_393216, smp, (int2)((alu19-(3072*(((alu19>>10)*43)>>7))),alu0));
float4 val6 = read_imagef(data2_393216, smp, (int2)((alu20-(3072*(((alu20>>10)*43)>>7))),alu0));
float4 val7 = read_imagef(data2_393216, smp, (int2)((alu17-(3072*(((alu17>>10)*43)>>7))),alu0));
*(acc0+1) = ((*(acc0+1))+(val4.x*val0.x)+(val4.y*val1.x)+(val4.z*val2.x)+(val4.w*val3.x));
*(acc0+5) = ((*(acc0+5))+(val4.x*val0.y)+(val4.y*val1.y)+(val4.z*val2.y)+(val4.w*val3.y));
*(acc0+9) = ((*(acc0+9))+(val4.x*val0.z)+(val4.y*val1.z)+(val4.z*val2.z)+(val4.w*val3.z));
*(acc0+13) = ((*(acc0+13))+(val4.x*val0.w)+(val4.y*val1.w)+(val4.z*val2.w)+(val4.w*val3.w));
*(acc0+2) = ((*(acc0+2))+(val5.x*val0.x)+(val5.y*val1.x)+(val5.z*val2.x)+(val5.w*val3.x));
*(acc0+6) = ((*(acc0+6))+(val5.x*val0.y)+(val5.y*val1.y)+(val5.z*val2.y)+(val5.w*val3.y));
*(acc0+10) = ((*(acc0+10))+(val5.x*val0.z)+(val5.y*val1.z)+(val5.z*val2.z)+(val5.w*val3.z));
*(acc0+14) = ((*(acc0+14))+(val5.x*val0.w)+(val5.y*val1.w)+(val5.z*val2.w)+(val5.w*val3.w));
*(acc0+3) = ((*(acc0+3))+(val6.x*val0.x)+(val6.y*val1.x)+(val6.z*val2.x)+(val6.w*val3.x));
*(acc0+7) = ((*(acc0+7))+(val6.x*val0.y)+(val6.y*val1.y)+(val6.z*val2.y)+(val6.w*val3.y));
*(acc0+11) = ((*(acc0+11))+(val6.x*val0.z)+(val6.y*val1.z)+(val6.z*val2.z)+(val6.w*val3.z));
*(acc0+15) = ((*(acc0+15))+(val6.x*val0.w)+(val6.y*val1.w)+(val6.z*val2.w)+(val6.w*val3.w));
*(acc0+0) = ((*(acc0+0))+(val7.x*val0.x)+(val7.y*val1.x)+(val7.z*val2.x)+(val7.w*val3.x));
*(acc0+4) = ((*(acc0+4))+(val7.x*val0.y)+(val7.y*val1.y)+(val7.z*val2.y)+(val7.w*val3.y));
*(acc0+8) = ((*(acc0+8))+(val7.x*val0.z)+(val7.y*val1.z)+(val7.z*val2.z)+(val7.w*val3.z));
*(acc0+12) = ((*(acc0+12))+(val7.x*val0.w)+(val7.y*val1.w)+(val7.z*val2.w)+(val7.w*val3.w));
}
int alu39 = (idx0<<2);
half4 val8 = (*((__global half4*)((data4_64+alu39))));
half4 val9 = (*((__global half4*)((data5_64+alu39))));
int alu40 = (idx0+(idx1<<6));
int2 cast0 = (int2)((alu40&1023),alu0);
float4 val10 = read_imagef(data1_131072, smp, cast0);
int2 cast1 = (int2)(((alu40+16)&1023),alu0);
float4 val11 = read_imagef(data1_131072, smp, cast1);
int2 cast2 = (int2)(((alu40+32)&1023),alu0);
float4 val12 = read_imagef(data1_131072, smp, cast2);
int2 cast3 = (int2)(((alu40+48)&1023),alu0);
float4 val13 = read_imagef(data1_131072, smp, cast3);
float cast4 = ((float)(val8.x));
float cast5 = ((float)(val9.x));
float cast6 = ((float)(val8.y));
float cast7 = ((float)(val9.y));
float cast8 = ((float)(val8.z));
float cast9 = ((float)(val9.z));
float cast10 = ((float)(val8.w));
float cast11 = ((float)(val9.w));
write_imagef(data0_131072, cast0, (float4)((val10.x+(((*(acc0+0))+cast4)*cast5)),(val10.y+(((*(acc0+4))+cast6)*cast7)),(val10.z+(((*(acc0+8))+cast8)*cast9)),(val10.w+(((*(acc0+12))+cast10)*cast11))));
write_imagef(data0_131072, cast1, (float4)((val11.x+(((*(acc0+1))+cast4)*cast5)),(val11.y+(((*(acc0+5))+cast6)*cast7)),(val11.z+(((*(acc0+9))+cast8)*cast9)),(val11.w+(((*(acc0+13))+cast10)*cast11))));
write_imagef(data0_131072, cast2, (float4)((val12.x+(((*(acc0+2))+cast4)*cast5)),(val12.y+(((*(acc0+6))+cast6)*cast7)),(val12.z+(((*(acc0+10))+cast8)*cast9)),(val12.w+(((*(acc0+14))+cast10)*cast11))));
write_imagef(data0_131072, cast3, (float4)((val13.x+(((*(acc0+3))+cast4)*cast5)),(val13.y+(((*(acc0+7))+cast6)*cast7)),(val13.z+(((*(acc0+11))+cast8)*cast9)),(val13.w+(((*(acc0+15))+cast10)*cast11))));
}
*** QCOM 672 r_512_16_4_4_48_4 arg 6 mem 0.10 GB tm 322.55us/ 77.83ms ( 157 GFLOPS 4|160 GB/s) ['mul', '__add__', 'conv2d']
"""
def vision_conv_143():
c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((16, 1024, 4)), (), 0)
c2 = UOp.range(32, 3, AxisType.LOOP)
c5 = UOp.range(128, 4, AxisType.LOOP)
c8 = UOp.range(16, 2, AxisType.LOOP)
c16 = UOp.range(7, 0, AxisType.REDUCE)
c17 = c8*2+c16
c24 = ((c17<3)!=True)&(c17<35)
c26 = UOp.range(7, 1, AxisType.REDUCE)
c27 = c2*2+c26
c32 = ((c27<3)!=True)&(c27<67)
c34 = UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((32, 1024, 4)), (), 1)
c38 = c5//2
c45 = (c32&c24).where((c27*64+c38+c17*4096+-12480), UOp.const(dtypes.index, Invalid))
c48 = (c24&c32).where(c34.index(c45), UOp.const(dtypes.float, 0.0))
c49 = UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((64, 49, 4)), (), 2)
c61 = c48*c49.index((c26*4+c5%2+c16*28+c38*196))
c63 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(128), (), 3)
c65 = c61.reduce(c16, c26, arg=Ops.ADD)+c63.index(c5)
c67 = c0.index((c2*128+c5+c8*4096), ptr=True).store(c65).end(c8, c2, c5)
""" master 99e76f33a0f4ec84c79c1271dbc955fe6b5a7778
c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((32, 1024, 4)), (), 0)
c2 = UOp.range(64, 3, AxisType.LOOP)
c4 = UOp.range(64, 4, AxisType.LOOP)
c7 = UOp.range(32, 2, AxisType.LOOP)
c10 = (((c2*64)+c4)+(c7*4096))
c12 = UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((32, 1024, 4)), (), 1)
c14 = UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((32, 3072, 4)), (), 2)
c16 = UOp.range(48, 0, AxisType.REDUCE)
c19 = UOp.range(4, 1, AxisType.REDUCE)
c28 = UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((16, 192, 4)), (), 3)
c40 = (c14.index(((((c16*4)+c19)+(c2*192))+(c7*12288)))*c28.index(((((c19*4)+(c4%4))+(c16*16))+((c4//4)*768))))
c42 = UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(64), (), 4)
c46 = UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(64), (), 5)
c50 = (c12.index(c10)+((c40.reduce(c16, c19, arg=Ops.ADD)+c42.index(c4).cast(dtypes.float))*c46.index(c4).cast(dtypes.float)))
c52 = c0.index(c10, ptr=True).store(c50).end(c7, c2, c4)
ast = c52.sink()
more upcast axis : [(3, 320, 0, 4)]
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void r_512_16_4_4_48_4(write_only image2d_t data0_131072, read_only image2d_t data1_131072, read_only image2d_t data2_393216, read_only image2d_t data3_12288, __global half* data4_64, __global half* data5_64) {
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
float acc0[16];
int idx0 = get_global_id(0); /* 16 */
int idx1 = get_global_id(1); /* 512 */
*(acc0+0) = 0.0f;
*(acc0+1) = 0.0f;
*(acc0+2) = 0.0f;
*(acc0+3) = 0.0f;
*(acc0+4) = 0.0f;
*(acc0+5) = 0.0f;
*(acc0+6) = 0.0f;
*(acc0+7) = 0.0f;
*(acc0+8) = 0.0f;
*(acc0+9) = 0.0f;
*(acc0+10) = 0.0f;
*(acc0+11) = 0.0f;
*(acc0+12) = 0.0f;
*(acc0+13) = 0.0f;
*(acc0+14) = 0.0f;
*(acc0+15) = 0.0f;
int alu16 = (idx0<<2);
half4 val0 = (*((__global half4*)((data4_64+alu16))));
half4 val1 = (*((__global half4*)((data5_64+alu16))));
int alu17 = (idx0+(idx1<<6));
int alu18 = (idx1>>4);
int2 cast0 = (int2)((alu17&1023),alu18);
float4 val2 = read_imagef(data1_131072, smp, cast0);
int2 cast1 = (int2)(((alu17+16)&1023),alu18);
float4 val3 = read_imagef(data1_131072, smp, cast1);
int2 cast2 = (int2)(((alu17+32)&1023),alu18);
float4 val4 = read_imagef(data1_131072, smp, cast2);
int2 cast3 = (int2)(((alu17+48)&1023),alu18);
float4 val5 = read_imagef(data1_131072, smp, cast3);
for (int Ridx0 = 0; Ridx0 < 48; Ridx0++) {
int alu19 = ((idx1*192)+Ridx0);
int alu20 = (alu19+48);
int alu21 = (alu19+96);
int alu22 = (alu19+144);
int alu23 = (Ridx0<<2);
float4 val6 = read_imagef(data3_12288, smp, (int2)(alu23,idx0));
float4 val7 = read_imagef(data3_12288, smp, (int2)((alu23+1),idx0));
float4 val8 = read_imagef(data3_12288, smp, (int2)((alu23+2),idx0));
float4 val9 = read_imagef(data3_12288, smp, (int2)((alu23+3),idx0));
float4 val10 = read_imagef(data2_393216, smp, (int2)((alu20-(3072*(((alu20>>10)*43)>>7))),alu18));
*(acc0+1) = ((*(acc0+1))+(val10.x*val6.x)+(val10.y*val7.x)+(val10.z*val8.x)+(val10.w*val9.x));
*(acc0+5) = ((*(acc0+5))+(val10.x*val6.y)+(val10.y*val7.y)+(val10.z*val8.y)+(val10.w*val9.y));
*(acc0+9) = ((*(acc0+9))+(val10.x*val6.z)+(val10.y*val7.z)+(val10.z*val8.z)+(val10.w*val9.z));
*(acc0+13) = ((*(acc0+13))+(val10.x*val6.w)+(val10.y*val7.w)+(val10.z*val8.w)+(val10.w*val9.w));
float4 val11 = read_imagef(data2_393216, smp, (int2)((alu21-(3072*(((alu21>>10)*43)>>7))),alu18));
*(acc0+2) = ((*(acc0+2))+(val11.x*val6.x)+(val11.y*val7.x)+(val11.z*val8.x)+(val11.w*val9.x));
*(acc0+6) = ((*(acc0+6))+(val11.x*val6.y)+(val11.y*val7.y)+(val11.z*val8.y)+(val11.w*val9.y));
*(acc0+10) = ((*(acc0+10))+(val11.x*val6.z)+(val11.y*val7.z)+(val11.z*val8.z)+(val11.w*val9.z));
*(acc0+14) = ((*(acc0+14))+(val11.x*val6.w)+(val11.y*val7.w)+(val11.z*val8.w)+(val11.w*val9.w));
float4 val12 = read_imagef(data2_393216, smp, (int2)((alu22-(3072*(((alu22>>10)*43)>>7))),alu18));
*(acc0+3) = ((*(acc0+3))+(val12.x*val6.x)+(val12.y*val7.x)+(val12.z*val8.x)+(val12.w*val9.x));
*(acc0+7) = ((*(acc0+7))+(val12.x*val6.y)+(val12.y*val7.y)+(val12.z*val8.y)+(val12.w*val9.y));
*(acc0+11) = ((*(acc0+11))+(val12.x*val6.z)+(val12.y*val7.z)+(val12.z*val8.z)+(val12.w*val9.z));
*(acc0+15) = ((*(acc0+15))+(val12.x*val6.w)+(val12.y*val7.w)+(val12.z*val8.w)+(val12.w*val9.w));
float4 val13 = read_imagef(data2_393216, smp, (int2)((alu19-(3072*(((alu19>>10)*43)>>7))),alu18));
*(acc0+0) = ((*(acc0+0))+(val13.x*val6.x)+(val13.y*val7.x)+(val13.z*val8.x)+(val13.w*val9.x));
*(acc0+4) = ((*(acc0+4))+(val13.x*val6.y)+(val13.y*val7.y)+(val13.z*val8.y)+(val13.w*val9.y));
*(acc0+8) = ((*(acc0+8))+(val13.x*val6.z)+(val13.y*val7.z)+(val13.z*val8.z)+(val13.w*val9.z));
*(acc0+12) = ((*(acc0+12))+(val13.x*val6.w)+(val13.y*val7.w)+(val13.z*val8.w)+(val13.w*val9.w));
}
float cast4 = ((float)(val0.x));
float cast5 = ((float)(val1.x));
float cast6 = ((float)(val0.y));
float cast7 = ((float)(val1.y));
float cast8 = ((float)(val0.z));
float cast9 = ((float)(val1.z));
float cast10 = ((float)(val0.w));
float cast11 = ((float)(val1.w));
write_imagef(data0_131072, cast0, (float4)((val2.x+(((*(acc0+0))+cast4)*cast5)),(val2.y+(((*(acc0+4))+cast6)*cast7)),(val2.z+(((*(acc0+8))+cast8)*cast9)),(val2.w+(((*(acc0+12))+cast10)*cast11))));
write_imagef(data0_131072, cast1, (float4)((val3.x+(((*(acc0+1))+cast4)*cast5)),(val3.y+(((*(acc0+5))+cast6)*cast7)),(val3.z+(((*(acc0+9))+cast8)*cast9)),(val3.w+(((*(acc0+13))+cast10)*cast11))));
write_imagef(data0_131072, cast2, (float4)((val4.x+(((*(acc0+2))+cast4)*cast5)),(val4.y+(((*(acc0+6))+cast6)*cast7)),(val4.z+(((*(acc0+10))+cast8)*cast9)),(val4.w+(((*(acc0+14))+cast10)*cast11))));
write_imagef(data0_131072, cast3, (float4)((val5.x+(((*(acc0+3))+cast4)*cast5)),(val5.y+(((*(acc0+7))+cast6)*cast7)),(val5.z+(((*(acc0+11))+cast8)*cast9)),(val5.w+(((*(acc0+15))+cast10)*cast11))));
}
*** QCOM 672 r_512_16_4_4_48_4 arg 6 mem 0.10 GB tm 527.97us/ 78.94ms ( 96 GFLOPS 3|98 GB/s) ['conv2d', 'mul', '__add__']
"""
opts = None
return c67.sink(arg=KernelInfo(name="conv", opts_to_apply=opts))
c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((32, 1024, 4)), (), 0)
c2 = UOp.range(64, 3, AxisType.LOOP)
c4 = UOp.range(64, 4, AxisType.LOOP)
c7 = UOp.range(32, 2, AxisType.LOOP)
c10 = (((c2*64)+c4)+(c7*4096))
c12 = UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((32, 1024, 4)), (), 1)
c14 = UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((32, 3072, 4)), (), 2)
c16 = UOp.range(48, 0, AxisType.REDUCE)
c19 = UOp.range(4, 1, AxisType.REDUCE)
c28 = UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((16, 192, 4)), (), 3)
c40 = (c14.index(((((c16*4)+c19)+(c2*192))+(c7*12288)))*c28.index(((((c19*4)+(c4%4))+(c16*16))+((c4//4)*768))))
c42 = UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(64), (), 4)
c46 = UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(64), (), 5)
c50 = (c12.index(c10)+((c40.reduce(c16, c19, arg=Ops.ADD)+c42.index(c4).cast(dtypes.float))*c46.index(c4).cast(dtypes.float)))
c52 = c0.index(c10, ptr=True).store(c50).end(c7, c2, c4)
def vision_conv_153():
c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((8, 1024, 4)), (), 0)
c2 = UOp.range(16, 3, AxisType.LOOP)
c5 = UOp.range(256, 4, AxisType.LOOP)
c8 = UOp.range(8, 2, AxisType.LOOP)
c16 = UOp.range(7, 0, AxisType.REDUCE)
c17 = c8*2+c16
c24 = ((c17<3)!=True)&(c17<19)
c26 = UOp.range(7, 1, AxisType.REDUCE)
c27 = c2*2+c26
c32 = ((c27<3)!=True)&(c27<35)
c34 = UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((16, 1024, 4)), (), 1)
c38 = c5//2
c45 = (c32&c24).where((c27*128+c38+c17*4096+-12672), UOp.const(dtypes.index, Invalid))
c48 = (c24&c32).where(c34.index(c45), UOp.const(dtypes.float, 0.0))
c49 = UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((128, 49, 4)), (), 2)
c61 = c48*c49.index((c26*4+c5%2+c16*28+c38*196))
c63 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(256), (), 3)
c65 = c61.reduce(c16, c26, arg=Ops.ADD)+c63.index(c5)
c67 = c0.index((c2*256+c5+c8*4096), ptr=True).store(c65).end(c8, c2, c5)
# NOLOCALS=1 IMAGE=2 DEV=CL
opts = (Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.NOLOCALS, axis=None, arg=None))
opts = None
return c67.sink(arg=KernelInfo(name="conv", opts_to_apply=opts))
ast = c52.sink(arg=KernelInfo(name="conv", opts_to_apply=opts))
ast = vision_conv_143() if getenv("NUM", 143) == 143 else vision_conv_153()
compiler = Device.default.compiler
renderer = Device.default.renderer
@@ -247,14 +70,11 @@ src = renderer.render(uops)
lib = compiler.compile(src)
ps = ProgramSpec("conv", src, Device.DEFAULT, ast, uops)
print(ps.src)
print(ps.applied_opts)
cr = CompiledRunner(ps, precompiled=lib)
gs = sorted(dedup([u for u in ast.toposort() if u.op is Ops.DEFINE_GLOBAL]), key=lambda u: u.arg)
print(len(gs))
print([g.dtype for g in gs])
# print(len(gs))
# print([g.dtype for g in gs])
bufs = [Buffer(ps.device, g.size, g.dtype if isinstance(g.dtype, ImageDType) else g.dtype._base).ensure_allocated() for g in gs]
t = cr(bufs, wait=True)