mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fixes from the dsp branch + 12500 lines (#9683)
* fixes from the dsp branch * more changes * those are gep pushing
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import sys, onnx, time
|
||||
import sys, onnx, time, pickle
|
||||
from tinygrad import TinyJit, Device, GlobalCounters, fetch, getenv
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
from extra.onnx_helpers import get_example_inputs, validate
|
||||
@@ -33,4 +33,9 @@ if __name__ == "__main__":
|
||||
|
||||
if getenv("ORT"):
|
||||
validate(onnx_file, new_inputs, rtol=1e-3, atol=1e-3)
|
||||
print("model validated")
|
||||
print("model validated")
|
||||
|
||||
if (fn:=getenv("SAVE_PKL", "")) != "":
|
||||
with open(fn, "wb") as f:
|
||||
pickle.dump(run_onnx_jit, f)
|
||||
print(f"pkl saved to {fn}")
|
||||
|
||||
@@ -70,7 +70,7 @@ if __name__ == "__main__":
|
||||
GlobalCounters.reset()
|
||||
p = run_onnx_jit(**{t_name:img})
|
||||
assert p.shape == (1,1000)
|
||||
t = p.argmax().item()
|
||||
t = p.to('cpu').argmax().item()
|
||||
hit += y==t
|
||||
print(f"target: {y:3d} pred: {t:3d} acc: {hit/(i+1)*100:.2f}%")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user