From a8734df03030c2b842f475492cbff2859be72595 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sun, 21 Aug 2022 12:03:37 -0700 Subject: [PATCH] add openpilot tests to tinygrad --- .github/workflows/test.yml | 7 +++++-- accel/opencl/conv.cl | 4 ++-- accel/opencl/matmul.cl | 4 ++-- accel/opencl/ops_opencl.py | 23 +++++++++++------------ openpilot/compile.py | 17 ++++++----------- openpilot/run_thneed.py | 9 ++++++--- 6 files changed, 32 insertions(+), 32 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c4cb04279e..272c7750b3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -103,7 +103,7 @@ jobs: run: OPT=1 GPU=1 python -m pytest -s -v testopencl: - name: OpenCL Tests + name: OpenCL (openpilot) Test runs-on: ubuntu-20.04 if: ${{ false }} @@ -125,7 +125,10 @@ jobs: - name: Install Dependencies run: pip install -e '.[gpu,testing]' - name: Run Pytest (default) - run: OPENCL=1 python -m pytest -s -v + run: | + UNSAFE_FLOAT4=1 DEBUGCL=1 python3 openpilot/compile.py + FLOAT16=1 UNSAFE_FLOAT4=1 DEBUGCL=1 python3 openpilot/compile.py + python3 openpilot/run_thneed.py /tmp/output.thneed testmypy: name: Mypy Tests diff --git a/accel/opencl/conv.cl b/accel/opencl/conv.cl index 7ddafb4417..c9658b7dc8 100644 --- a/accel/opencl/conv.cl +++ b/accel/opencl/conv.cl @@ -1,9 +1,9 @@ //PREFIX __kernel void image_conv( + write_only image2d_t output, read_only image2d_t input, - read_only image2d_t weights, - write_only image2d_t output + read_only image2d_t weights #ifndef NOARGS ,short numPackedInputChannelsForGroup, short totalNumPackedInputChannels, diff --git a/accel/opencl/matmul.cl b/accel/opencl/matmul.cl index 555e899eb7..b6121e6ee0 100644 --- a/accel/opencl/matmul.cl +++ b/accel/opencl/matmul.cl @@ -1,10 +1,10 @@ //PREFIX __kernel void matmul( + write_only image2d_t output, __local float *outputScratch, read_only image1d_t input, - read_only image2d_t weights, - write_only image2d_t output + read_only image2d_t weights //ARGS ) { diff --git a/accel/opencl/ops_opencl.py b/accel/opencl/ops_opencl.py index 6e43761fea..3cdc4e1aae 100644 --- a/accel/opencl/ops_opencl.py +++ b/accel/opencl/ops_opencl.py @@ -3,7 +3,7 @@ from __future__ import annotations import os from tinygrad.llops.ops_gpu import GPUBuffer, CL, CLProgram, CLBuffer -from tinygrad.ops import ProcessingOps +from tinygrad.ops import ProcessingOps, ReduceOps from tinygrad.helpers import prod, ConvArgs from typing import List, Tuple, Optional, Dict, Set import numpy as np @@ -80,8 +80,8 @@ class OpenCLBuffer(GPUBuffer): #print(f"converting {self.shape} back to buffer, image shape is {self._image.shape}") CLProgram("from_image", """ __kernel void from_image( - read_only image2d_t in, - __global float4 *out) { + __global float4 *out, + read_only image2d_t in) { const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; int2 l; l.y = get_global_id(1); @@ -89,7 +89,7 @@ class OpenCLBuffer(GPUBuffer): int W = get_image_width(in); out[l.y*W + l.x] = read_imagef(in, smp, l); } - """)(self._image.shape, None, self._image, self._buf.cl) + """)(self._image.shape, None, self._buf.cl, self._image) self._image = None return self._buf.cl @@ -105,15 +105,15 @@ class OpenCLBuffer(GPUBuffer): #print(f"converting {self.shape} to image with shape {self._image.shape}") CLProgram("to_image", """ __kernel void to_image( - __global const float4 *in, - write_only image2d_t out) { + write_only image2d_t out, + __global const float4 *in) { int2 l; l.y = get_global_id(1); l.x = get_global_id(0); int W = get_image_width(out); write_imagef(out, l, in[l.y*W + l.x]); } - """)(self._image.shape, None, self._buf.cl, self._image) + """)(self._image.shape, None, self._image, self._buf.cl) self._buf = None return self._image @@ -123,12 +123,11 @@ class OpenCLBuffer(GPUBuffer): return type(x)(C.out_shape)._processing_op([("input", x.contiguous_op()), ("weight", w.contiguous_op())], "acc", C) seen = set() - def _processing_op(ret, bufs: List[Tuple[str, OpenCLBuffer]]=[], code:str="acc", C=None, start="0.0", reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc"): + def _processing_op(ret, bufs: List[Tuple[str, OpenCLBuffer]]=[], code:str="acc", C=None, op=ReduceOps.SUM, reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc"): if C is None or earlycode != "acc": # TODO: handle an opencl conv without the conv part - return super()._processing_op(bufs, code, C, start, reduce_shape, earlybufs, earlycode) + return super()._processing_op(bufs, code, C, op, reduce_shape, earlybufs, earlycode) assert earlycode == "acc" - assert start == "0.0" x = [x for x in bufs if x[0] == "input"][0][1] w = [x for x in bufs if x[0] == "weight"][0][1] @@ -228,7 +227,7 @@ class OpenCLBuffer(GPUBuffer): local_work_size = [4, global_work_size[1], lw] #print(global_work_size, local_work_size) - conv_prg(global_work_size, local_work_size, cl.LocalMemory(4 * local_work_size[0] * local_work_size[1] * lw), x.image, w.image, ret.image, *conv_args, *[buf.image if 'image2d_t' in typ else buf.cl for typ, (_, buf) in zip(ewtypes, ewbufs)]) + conv_prg(global_work_size, local_work_size, ret.image, cl.LocalMemory(4 * local_work_size[0] * local_work_size[1] * lw), x.image, w.image, *conv_args, *[buf.image if 'image2d_t' in typ else buf.cl for typ, (_, buf) in zip(ewtypes, ewbufs)]) return ret # this option is unused @@ -259,7 +258,7 @@ class OpenCLBuffer(GPUBuffer): argdtypes=tuple([None, None, None] + [np.int16]*len(conv_args) + [None]*len(ewbufs)) ) global_work_size = [C.cout//4, (C.ox+NUM_OUTPUTS-1)//NUM_OUTPUTS, C.bs*C.oy] - conv_prg(global_work_size, None, x.image, w.image, ret.image, *conv_args, *[buf.image if 'image2d_t' in typ else buf.cl for typ, (_, buf) in zip(ewtypes, ewbufs)]) + conv_prg(global_work_size, None, ret.image, x.image, w.image, *conv_args, *[buf.image if 'image2d_t' in typ else buf.cl for typ, (_, buf) in zip(ewtypes, ewbufs)]) return ret GPUBuffer = OpenCLBuffer diff --git a/openpilot/compile.py b/openpilot/compile.py index 77378ba627..ace68a730b 100644 --- a/openpilot/compile.py +++ b/openpilot/compile.py @@ -178,18 +178,13 @@ def compile(input, output_fn): saved_binaries = set() kernels_to_save = set() - kernels_to_not_save = set() + kernels_to_not_save = set(inputs) import pyopencl as cl for self, args in local_cl_cache: - for i,a in enumerate(args[2:]): - access_qualifer = self.clprg.get_arg_info(i, cl.kernel_arg_info.ACCESS_QUALIFIER) - type_qualifer = self.clprg.get_arg_info(i, cl.kernel_arg_info.TYPE_QUALIFIER) - type_name = self.clprg.get_arg_info(i, cl.kernel_arg_info.TYPE_NAME) - if cl.kernel_arg_access_qualifier.READ_ONLY == access_qualifer or cl.kernel_arg_type_qualifier.CONST == type_qualifer: - kernels_to_save.add(a) - else: - # this is written to at some point, we don't have to save it - kernels_to_not_save.add(a) + # output is always the first parameter + kernels_to_not_save.add(args[2]) + for a in args[3:]: + kernels_to_save.add(a) kernels_to_save -= kernels_to_not_save gobj = 0 @@ -259,7 +254,7 @@ def compile(input, output_fn): }) if needs_load: - data = np.empty(size//2, dtype=np.float32) + data = np.empty(size//(2 if FLOAT16 else 4), dtype=np.float32) CL.enqueue_copy(data, buf.cl, is_blocking=True) if FLOAT16: data = data.astype(np.float16) weights.append(data.tobytes()) diff --git a/openpilot/run_thneed.py b/openpilot/run_thneed.py index 8cdc39e6ae..12e195202f 100644 --- a/openpilot/run_thneed.py +++ b/openpilot/run_thneed.py @@ -11,8 +11,8 @@ THNEED_KERNELS = "../../selfdrive/modeld/thneed/kernels/" def load_thneed_model(fn="model.thneed", float32=False, replace=None): import pyopencl as cl platform = [x for x in cl.get_platforms()] - assert len(platform) == 1 - ctx = cl.Context(devices=platform[0].get_devices(device_type=cl.device_type.GPU)) + assert len(platform) >= 1 + ctx = cl.Context(devices=platform[0].get_devices(device_type=cl.device_type.GPU)[0:1]) q = cl.CommandQueue(ctx, properties=cl.command_queue_properties.PROFILING_ENABLE) mf = cl.mem_flags image_fmt = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.FLOAT if float32 else cl.channel_type.HALF_FLOAT) @@ -110,7 +110,10 @@ def load_thneed_model(fn="model.thneed", float32=False, replace=None): k['args_name'] = [] prg = prgs[k['name']] for i,arg in enumerate(k['args']): - k['args_name'].append(prg.get_arg_info(i, cl.kernel_arg_info.NAME)) + try: + k['args_name'].append(prg.get_arg_info(i, cl.kernel_arg_info.NAME)) + except cl.RuntimeError: + k['args_name'].append("") vision = vision[0:1] vnum = vnum[0] if len(vnum) >= 1 else None