mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 05:48:08 -05:00
float16 fixups
This commit is contained in:
@@ -163,7 +163,8 @@ def compile(input, output_fn):
|
||||
tinygrad_out_np = tinygrad_out.numpy()
|
||||
|
||||
# float32 only
|
||||
if int(os.getenv("FLOAT16", 0)) == 0:
|
||||
FLOAT16 = int(os.getenv("FLOAT16", 0))
|
||||
if FLOAT16 == 0:
|
||||
torch_out = run_onnx_torch(onnx_model, np_inputs).numpy()
|
||||
print(tinygrad_out_np, torch_out)
|
||||
np.testing.assert_allclose(torch_out, tinygrad_out_np, atol=1e-4, rtol=1e-2)
|
||||
@@ -227,14 +228,14 @@ def compile(input, output_fn):
|
||||
"id": ptr, "arg_type": "float*", "needs_load": needs_load, "size": a.size,
|
||||
})
|
||||
if needs_load:
|
||||
data = np.empty(a.size, dtype=np.uint8)
|
||||
data = np.empty(a.size//4, dtype=np.float32)
|
||||
CL.enqueue_copy(data, a, is_blocking=True)
|
||||
weights.append(data.tobytes())
|
||||
elif isinstance(a, cl.Image):
|
||||
needs_load = a in kernels_to_save
|
||||
row_pitch = (a.shape[0]*4*2 + 63)//64 * 64
|
||||
row_pitch = (a.shape[0]*4*(2 if FLOAT16 else 4) + 63)//64 * 64
|
||||
size = row_pitch * a.shape[1]
|
||||
buf = CLBuffer(size)
|
||||
buf = CLBuffer(size * (1 if FLOAT16 else 2))
|
||||
|
||||
# zero out the buffer
|
||||
zeros = np.zeros(size, dtype=np.uint8)
|
||||
@@ -248,7 +249,7 @@ def compile(input, output_fn):
|
||||
l.x = get_global_id(0);
|
||||
out[l.y*row_pitch + l.x] = read_imagef(in, smp, l);
|
||||
}
|
||||
""", argdtypes=(None, None, np.int32))(a.shape, None, a, buf.cl, row_pitch)
|
||||
""", argdtypes=(None, None, np.int32))(a.shape, None, a, buf.cl, row_pitch//(4*(2 if FLOAT16 else 4)))
|
||||
|
||||
# multiple of 32 isn't enough
|
||||
jdat['objects'].append({
|
||||
@@ -257,8 +258,9 @@ def compile(input, output_fn):
|
||||
})
|
||||
|
||||
if needs_load:
|
||||
data = np.empty(size, dtype=np.uint8)
|
||||
data = np.empty(size//2, dtype=np.float32)
|
||||
CL.enqueue_copy(data, buf.cl, is_blocking=True)
|
||||
if FLOAT16: data = data.astype(np.float16)
|
||||
weights.append(data.tobytes())
|
||||
else:
|
||||
raise Exception("unknown object", a)
|
||||
|
||||
Reference in New Issue
Block a user