that should be right

This commit is contained in:
Comma Device
2022-07-19 19:47:37 -07:00
parent f4ed837f2f
commit 6da956b9fa

View File

@@ -235,7 +235,8 @@ def compile(input, output_fn):
needs_load = a in kernels_to_save
row_pitch = (a.shape[0]*4*(2 if FLOAT16 else 4) + 63)//64 * 64
size = row_pitch * a.shape[1]
buf = CLBuffer(size * (1 if FLOAT16 else 2))
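# this is *2 if float16 and *4 if float32
# NOTE(review): the multiplier on the CLBuffer size below is 2 for float16 and 1 for float32;
# the "*2 / *4" figures match the per-channel factor used in row_pitch above — confirm which is meant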
buf = CLBuffer(size * (2 if FLOAT16 else 1))
# zero out the buffer
zeros = np.zeros(size, dtype=np.uint8)