mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
fix commavq benchmark (#4712)
* fix _slice and assert explicit device * with _slice
This commit is contained in:
13
test/external/external_model_benchmark.py
vendored
13
test/external/external_model_benchmark.py
vendored
@@ -1,5 +1,6 @@
|
||||
import csv, pathlib, time, numpy as np
|
||||
from os import getenv
|
||||
from tinygrad.device import CompileError
|
||||
import torch
|
||||
torch.set_num_threads(1)
|
||||
import onnx
|
||||
@@ -60,8 +61,8 @@ def benchmark_model(m, devices, validate_outs=False):
|
||||
|
||||
# print input names
|
||||
if DEBUG >= 2: print([inp.name for inp in onnx_model.graph.input if inp.name not in excluded])
|
||||
try:
|
||||
for device in devices:
|
||||
for device in devices:
|
||||
try:
|
||||
Device.DEFAULT = device
|
||||
inputs = {k:Tensor(inp) for k,inp in np_inputs.items()}
|
||||
tinygrad_model = get_run_onnx(onnx_model)
|
||||
@@ -72,10 +73,10 @@ def benchmark_model(m, devices, validate_outs=False):
|
||||
for _ in range(3): {k:v.numpy() for k,v in tinygrad_jitted_model(**inputs).items()}
|
||||
benchmark(m, f"tinygrad_{device.lower()}_jit", lambda: {k:v.numpy() for k,v in tinygrad_jitted_model(**inputs).items()}) # noqa: F821
|
||||
del inputs, tinygrad_model, tinygrad_jitted_model
|
||||
except Exception as e:
|
||||
# model crashed
|
||||
print(f"{m} crashed on {device} with: {e}")
|
||||
return
|
||||
except CompileError as e:
|
||||
# METAL fails with buffer count limit
|
||||
if m == "dm" and device == "METAL": return
|
||||
raise e
|
||||
|
||||
# convert model to torch
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user