mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-26 15:28:10 -05:00
Reactivate commavq/gpt2m benchmark (#1731)
* get commavq/gpt2m from huggingface * increase tols
This commit is contained in:
5
test/external/external_model_benchmark.py
vendored
5
test/external/external_model_benchmark.py
vendored
@@ -18,8 +18,7 @@ MODELS = {
|
||||
# "openpilot": "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx",
|
||||
"efficientnet": "https://github.com/onnx/models/raw/main/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx",
|
||||
"shufflenet": "https://github.com/onnx/models/raw/main/vision/classification/shufflenet/model/shufflenet-9.onnx",
|
||||
# cannot download the model from github
|
||||
# "commavq": "https://github.com/commaai/commavq/raw/master/models/gpt2m.onnx",
|
||||
"commavq": "https://huggingface.co/commaai/commavq-gpt2m/resolve/main/gpt2m.onnx",
|
||||
|
||||
# broken in torch MPS
|
||||
#"zfnet": "https://github.com/onnx/models/raw/main/vision/classification/zfnet-512/model/zfnet512-9.onnx",
|
||||
@@ -101,7 +100,7 @@ def benchmark_model(m, validate_outs=False):
|
||||
del ort_sess
|
||||
|
||||
if validate_outs:
|
||||
rtol, atol = 8e-4, 8e-4 # tolerance for fp16 models
|
||||
rtol, atol = 2e-3, 2e-3 # tolerance for fp16 models
|
||||
inputs = {k:Tensor(inp) for k,inp in np_inputs.items()}
|
||||
tinygrad_model = get_run_onnx(onnx_model)
|
||||
tinygrad_out = tinygrad_model(inputs)
|
||||
|
||||
Reference in New Issue
Block a user