actual tests for the dsp backend [pr] (#9102)

* actual tests for the dsp backend [pr]

* fix name
This commit is contained in:
George Hotz
2025-02-15 15:17:56 +08:00
committed by GitHub
parent 7e09057afa
commit 4672d9af73
2 changed files with 148 additions and 31 deletions

View File

@@ -3,7 +3,7 @@ import numpy as np
from extra.datasets.imagenet import get_imagenet_categories, get_val_files, center_crop
from examples.benchmark_onnx import load_onnx_model
from PIL import Image
from tinygrad import Tensor, dtypes
from tinygrad import Tensor, dtypes, GlobalCounters
from tinygrad.helpers import fetch, getenv
# works:
@@ -19,6 +19,7 @@ from tinygrad.helpers import fetch, getenv
# QUANT=1 python3 examples/test_onnx_imagenet.py
# https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7.onnx
# DONT_REALIZE_EXPAND=1 python3 examples/test_onnx_imagenet.py /tmp/model.quant.onnx
# VIZ=1 DONT_REALIZE_EXPAND=1 python3 examples/benchmark_onnx.py /tmp/model.quant.onnx
def imagenet_dataloader(cnt=0):
@@ -65,7 +66,8 @@ if __name__ == "__main__":
assert t_spec.shape[1:] == (3,224,224), f"shape is {t_spec.shape}"
hit = 0
for i,(img,y) in enumerate(imagenet_dataloader(cnt=100)):
for i,(img,y) in enumerate(imagenet_dataloader(cnt=getenv("CNT", 100))):
GlobalCounters.reset()
p = run_onnx_jit(**{t_name:img})
assert p.shape == (1,1000)
t = p.argmax().item()