mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
actual tests for the dsp backend [pr] (#9102)
* actual tests for the dsp backend [pr] * fix name
This commit is contained in:
@@ -3,7 +3,7 @@ import numpy as np
|
||||
from extra.datasets.imagenet import get_imagenet_categories, get_val_files, center_crop
|
||||
from examples.benchmark_onnx import load_onnx_model
|
||||
from PIL import Image
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad import Tensor, dtypes, GlobalCounters
|
||||
from tinygrad.helpers import fetch, getenv
|
||||
|
||||
# works:
|
||||
@@ -19,6 +19,7 @@ from tinygrad.helpers import fetch, getenv
|
||||
|
||||
# QUANT=1 python3 examples/test_onnx_imagenet.py
|
||||
# https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7.onnx
|
||||
# DONT_REALIZE_EXPAND=1 python3 examples/test_onnx_imagenet.py /tmp/model.quant.onnx
|
||||
# VIZ=1 DONT_REALIZE_EXPAND=1 python3 examples/benchmark_onnx.py /tmp/model.quant.onnx
|
||||
|
||||
def imagenet_dataloader(cnt=0):
|
||||
@@ -65,7 +66,8 @@ if __name__ == "__main__":
|
||||
assert t_spec.shape[1:] == (3,224,224), f"shape is {t_spec.shape}"
|
||||
|
||||
hit = 0
|
||||
for i,(img,y) in enumerate(imagenet_dataloader(cnt=100)):
|
||||
for i,(img,y) in enumerate(imagenet_dataloader(cnt=getenv("CNT", 100))):
|
||||
GlobalCounters.reset()
|
||||
p = run_onnx_jit(**{t_name:img})
|
||||
assert p.shape == (1,1000)
|
||||
t = p.argmax().item()
|
||||
|
||||
Reference in New Issue
Block a user