mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 05:48:08 -05:00
that should be right
This commit is contained in:
@@ -235,7 +235,8 @@ def compile(input, output_fn):
|
||||
needs_load = a in kernels_to_save
|
||||
row_pitch = (a.shape[0]*4*(2 if FLOAT16 else 4) + 63)//64 * 64
|
||||
size = row_pitch * a.shape[1]
|
||||
buf = CLBuffer(size * (1 if FLOAT16 else 2))
|
||||
# this is *2 if float16 and *4 if float32
|
||||
buf = CLBuffer(size * (2 if FLOAT16 else 1))
|
||||
|
||||
# zero out the buffer
|
||||
zeros = np.zeros(size, dtype=np.uint8)
|
||||
|
||||
Reference in New Issue
Block a user