mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
kernel count not relevant if speed is good
This commit is contained in:
@@ -38,21 +38,8 @@ def compile(onnx_file):
|
|||||||
if i == 1: test_val = np.copy(ret)
|
if i == 1: test_val = np.copy(ret)
|
||||||
print(f"captured {len(run_onnx_jit.captured.jit_cache)} kernels")
|
print(f"captured {len(run_onnx_jit.captured.jit_cache)} kernels")
|
||||||
np.testing.assert_equal(test_val, ret, "JIT run failed")
|
np.testing.assert_equal(test_val, ret, "JIT run failed")
|
||||||
|
|
||||||
print("jit run validated")
|
print("jit run validated")
|
||||||
|
|
||||||
# checks from compile2
|
|
||||||
kernel_count = 0
|
|
||||||
read_image_count = 0
|
|
||||||
gated_read_image_count = 0
|
|
||||||
for ei in run_onnx_jit.captured.jit_cache:
|
|
||||||
if isinstance(ei.prg, CompiledRunner):
|
|
||||||
kernel_count += 1
|
|
||||||
read_image_count += ei.prg.p.src.count("read_image")
|
|
||||||
gated_read_image_count += ei.prg.p.src.count("?read_image")
|
|
||||||
for v in [m.group(1) for m in re.finditer(r'(val\d+)\s*=\s*read_imagef\(', ei.prg.p.src)]:
|
|
||||||
if len(re.findall(fr'[\?\:]{v}\.[xyzw]', ei.prg.p.src)) > 0: gated_read_image_count += 1
|
|
||||||
print(f"{kernel_count=}, {read_image_count=}, {gated_read_image_count=}")
|
|
||||||
|
|
||||||
with open(OUTPUT, "wb") as f:
|
with open(OUTPUT, "wb") as f:
|
||||||
pickle.dump(run_onnx_jit, f)
|
pickle.dump(run_onnx_jit, f)
|
||||||
mdl_sz = os.path.getsize(onnx_file)
|
mdl_sz = os.path.getsize(onnx_file)
|
||||||
|
|||||||
Reference in New Issue
Block a user