diff --git a/openpilot/compile.py b/openpilot/compile.py index 1e4557a8c9..dfb13a856d 100644 --- a/openpilot/compile.py +++ b/openpilot/compile.py @@ -163,7 +163,8 @@ def compile(input, output_fn): tinygrad_out_np = tinygrad_out.numpy() # float32 only - if int(os.getenv("FLOAT16", 0)) == 0: + FLOAT16 = int(os.getenv("FLOAT16", 0)) + if FLOAT16 == 0: torch_out = run_onnx_torch(onnx_model, np_inputs).numpy() print(tinygrad_out_np, torch_out) np.testing.assert_allclose(torch_out, tinygrad_out_np, atol=1e-4, rtol=1e-2) @@ -227,14 +228,14 @@ def compile(input, output_fn): "id": ptr, "arg_type": "float*", "needs_load": needs_load, "size": a.size, }) if needs_load: - data = np.empty(a.size, dtype=np.uint8) + data = np.empty(a.size//4, dtype=np.float32) CL.enqueue_copy(data, a, is_blocking=True) weights.append(data.tobytes()) elif isinstance(a, cl.Image): needs_load = a in kernels_to_save - row_pitch = (a.shape[0]*4*2 + 63)//64 * 64 + row_pitch = (a.shape[0]*4*(2 if FLOAT16 else 4) + 63)//64 * 64 size = row_pitch * a.shape[1] - buf = CLBuffer(size) + buf = CLBuffer(size * (1 if FLOAT16 else 2)) # zero out the buffer zeros = np.zeros(size, dtype=np.uint8) @@ -248,7 +249,7 @@ def compile(input, output_fn): l.x = get_global_id(0); out[l.y*row_pitch + l.x] = read_imagef(in, smp, l); } - """, argdtypes=(None, None, np.int32))(a.shape, None, a, buf.cl, row_pitch) + """, argdtypes=(None, None, np.int32))(a.shape, None, a, buf.cl, row_pitch//(4*(2 if FLOAT16 else 4))) # multiple of 32 isn't enough jdat['objects'].append({ @@ -257,8 +258,9 @@ def compile(input, output_fn): }) if needs_load: - data = np.empty(size, dtype=np.uint8) + data = np.empty(size//2, dtype=np.float32) CL.enqueue_copy(data, buf.cl, is_blocking=True) + if FLOAT16: data = data.astype(np.float16) weights.append(data.tobytes()) else: raise Exception("unknown object", a)