diff --git a/openpilot/compile.py b/openpilot/compile.py index dfb13a856d..e8fc88a53d 100644 --- a/openpilot/compile.py +++ b/openpilot/compile.py @@ -235,7 +235,8 @@ def compile(input, output_fn): needs_load = a in kernels_to_save row_pitch = (a.shape[0]*4*(2 if FLOAT16 else 4) + 63)//64 * 64 size = row_pitch * a.shape[1] - buf = CLBuffer(size * (1 if FLOAT16 else 2)) + # this is *2 if float16 and *4 if float32 + buf = CLBuffer(size * (2 if FLOAT16 else 1)) # zero out the buffer zeros = np.zeros(size, dtype=np.uint8)