Add Onnx Huggingface to test/models/test_onnx.py (#11468)

* BOOM

* cache extra/huggingface/models/

* why max buffer size is not 0

* override MAX_BUFFER_SIZE

* less models

* remove more models and change cache dir to already cached dir

* only metal

* less is more?

* remove check ops

* why is this not setting the ENVVAR

* ughhhhh just test in models

* only cpu and gpu

* only cpu actually

* just override it idk

* final

* move extra dependencies up top

* simplification

* fix print

* make README better

* revert ops_disk fix for now

* clean up test_onnx

* remove testing fashion clip model cuz sloooowwwwww

* actually let METAL run this

* fix comment mistake

* fix download path in run_models

* does this work?

* cleanup setup and teardown

* contextvar like this?

* prove model is cached

* do I need to increment DOWNLOAD_CACHE_VERSION?

* see if cached with incremented DOWNLOAD_CACHE_VERSION

* use warnings to see if the model exists

* revert DOWNLOAD_CACHE_VERSION stuff and clean up

* add retry to download

* nit
geohotstan
2025-08-14 23:16:41 +08:00
committed by GitHub
parent 06beeb6e13
commit 1e904155e3
10 changed files with 361 additions and 171 deletions

View File

@@ -62,8 +62,6 @@ jobs:
run: BENCHMARK_LOG=stable_diffusion_xl CAPTURE_PROCESS_REPLAY=0 JIT=1 python3.11 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
- name: Run model inference benchmark
run: METAL=1 python3.11 test/external/external_model_benchmark.py
- name: Run huggingface_onnx test
run: METAL=1 python3.11 extra/huggingface_onnx/run_models.py test --debug FacebookAI/xlm-roberta-large
- name: Test speed vs torch
run: BIG=2 MPS=1 python3.11 test/test_speed_v_torch.py | tee torch_speed.txt
- name: Test tensor cores

View File

@@ -0,0 +1,61 @@
# HuggingFace ONNX
Tool for discovering, downloading, and validating ONNX models from HuggingFace.
## Extra Dependencies
```bash
pip install huggingface_hub pyyaml requests onnx onnxruntime numpy
```
## HuggingFace Manager (discovering and downloading)
The `huggingface_manager.py` script discovers top ONNX models from HuggingFace, collects metadata, and optionally downloads them.
```bash
# Download top 50 models sorted by downloads
python huggingface_manager.py --limit 50 --download
# Just collect metadata (no download)
python huggingface_manager.py --limit 100
# Sort by likes instead of downloads
python huggingface_manager.py --limit 20 --sort likes --download
# Custom output file
python huggingface_manager.py --limit 10 --output my_models.yaml
```
### Output Format
The tool generates a YAML file with the following structure:
```yaml
repositories:
"model-name":
url: "https://huggingface.co/model-name"
    download_path: "/path/to/models/..."  # set when --download is used
files:
- file: "model.onnx"
size: "90.91MB"
total_size: "2.45GB"
created_at: "2024-01-15T10:30:00Z"
```
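Downstream scripts can consume this file directly. A minimal sketch of reading it back (assuming PyYAML and the structure above):
```python
import yaml

with open("huggingface_repos.yaml") as f:
  data = yaml.safe_load(f)

print(f"total size: {data['total_size']}, created: {data['created_at']}")
# list every ONNX file tracked per repository
for repo, info in data["repositories"].items():
  for file_info in info["files"]:
    print(f"{repo}: {file_info['file']} ({file_info['size']})")
```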
## Run Models (validation)
The `run_models.py` script validates ONNX models against ONNX Runtime for correctness.
```bash
# Validate models from a YAML configuration file
python run_models.py --validate huggingface_repos.yaml
# Debug specific repository (downloads and validates all ONNX models)
python run_models.py --debug sentence-transformers/all-MiniLM-L6-v2
# Debug specific model file
python run_models.py --debug sentence-transformers/all-MiniLM-L6-v2/onnx/model.onnx
# Truncate the model to debug and validate intermediate results
DEBUGONNX=1 python run_models.py --debug sentence-transformers/all-MiniLM-L6-v2/onnx/model.onnx --truncate 10
```
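The `--debug` flow can also be driven from Python. A sketch mirroring what `run_models.py` does internally (the repo id is just an example):
```python
from extra.huggingface_onnx.huggingface_manager import DOWNLOADS_DIR, snapshot_download_with_retry
from extra.huggingface_onnx.run_models import get_config, get_tolerances, run_huggingface_validate

repo_id = "sentence-transformers/all-MiniLM-L6-v2"
root = snapshot_download_with_retry(repo_id=repo_id, allow_patterns=["*.onnx", "*.onnx_data"], cache_dir=DOWNLOADS_DIR)
snapshot_download_with_retry(repo_id=repo_id, allow_patterns=["*config.json"], cache_dir=DOWNLOADS_DIR)  # configs shape the example inputs
config = get_config(root)
for model in root.rglob("*.onnx"):
  rtol, atol = get_tolerances(model.name)
  run_huggingface_validate(model, config, rtol, atol)
```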

View File

@@ -1,85 +0,0 @@
import yaml, time, requests, argparse
from pathlib import Path
from huggingface_hub import list_models, HfApi
from tinygrad.helpers import tqdm
HUGGINGFACE_URL = "https://huggingface.co"
SKIPPED_FILES = [
"fp16", "int8", "uint8", "quantized", # numerical accuracy issues
"avx2", "arm64", "avx512", "avx512_vnni", # numerical accuracy issues
"q4", "q4f16", "bnb4", # unimplemented quantization
"model_O4", # requires non cpu ort runner and MemcpyFromHost op
"merged", # TODO implement attribute with graph type and Loop op
]
SKIPPED_REPO_PATHS = [
# Invalid model-index
"AdamCodd/vit-base-nsfw-detector",
# TODO: implement attribute with graph type and Loop op
"minishlab/potion-base-8M", "minishlab/M2V_base_output", "minishlab/potion-retrieval-32M",
# TODO: implement SimplifiedLayerNormalization, SkipSimplifiedLayerNormalization, GroupQueryAttention
"HuggingFaceTB/SmolLM2-360M-Instruct",
# TODO: implement SimplifiedLayerNormalization, SkipSimplifiedLayerNormalization, RotaryEmbedding, MultiHeadAttention
"HuggingFaceTB/SmolLM2-1.7B-Instruct",
  # TODO: implement RandomNormalLike
"stabilityai/stable-diffusion-xl-base-1.0", "stabilityai/sdxl-turbo", 'SimianLuo/LCM_Dreamshaper_v7',
# TODO: implement NonZero
"mangoapps/fb_zeroshot_mnli_onnx",
  # TODO: huge Concat here with 1024 (1, 3, 32, 32) tensors, and maybe a MOD bug in const folding
"briaai/RMBG-2.0",
]
def get_top_repos(n: int, sort: str) -> list[str]: # list["FacebookAI/xlm-roberta-large", ...]
print(f"** Getting top {n} models sorted by {sort} **")
repos = []
i = 0
for model in list_models(filter="onnx", sort=sort):
if model.id in SKIPPED_REPO_PATHS: continue
print(f"{i+1}/{n}: {model.id} ({getattr(model, sort)})")
repos.append(model.id)
i += 1
if i == n: break
return repos
def get_metadata(repos:list[str]) -> dict:
api = HfApi()
repos_metadata = {"repositories": {}}
total_size = 0
  # TODO: speed up HEAD requests with async?
for repo in tqdm(repos, desc="Getting metadata"):
files_metadata = []
model_info = api.model_info(repo)
for file in model_info.siblings:
filename = file.rfilename
if not (filename.endswith('.onnx') or filename.endswith('.onnx_data')): continue
if any(skip_str in filename for skip_str in SKIPPED_FILES): continue
head = requests.head(f"{HUGGINGFACE_URL}/{repo}/resolve/main/{filename}", allow_redirects=True)
file_size = file.size or int(head.headers.get('Content-Length', 0))
files_metadata.append({"file": filename, "size": f"{file_size/1e6:.2f}MB"})
total_size += file_size
repos_metadata["repositories"][repo] = {
"url": f"{HUGGINGFACE_URL}/{repo}",
"download_path": None,
"files": files_metadata,
}
repos_metadata['total_size'] = f"{total_size/1e9:.2f}GB"
repos_metadata['created_at'] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
return repos_metadata
if __name__ == "__main__":
sort = "downloads" # recent 30 days downloads
huggingface_onnx_dir = Path(__file__).parent
parser = argparse.ArgumentParser(description="Produces a YAML file with metadata of top huggingface onnx models")
parser.add_argument("--limit", type=int, required=True, help="Number of top repositories to process (e.g., 100)")
parser.add_argument("--output", type=str, default="huggingface_repos.yaml", help="Output YAML file name to save the report")
args = parser.parse_args()
top_repos = get_top_repos(args.limit, sort)
metadata = get_metadata(top_repos)
yaml_path = huggingface_onnx_dir / args.output
with open(yaml_path, 'w') as f:
yaml.dump(metadata, f, sort_keys=False)
print(f"YAML saved to: {str(yaml_path)}")

View File

@@ -1,29 +0,0 @@
import yaml, argparse
from pathlib import Path
from huggingface_hub import snapshot_download
def download_models(yaml_file: str, download_dir: str) -> None:
with open(yaml_file, 'r') as f: metadata = yaml.safe_load(f)
n = len(metadata["repositories"])
for i, (model_id, model_data) in enumerate(metadata["repositories"].items()):
print(f"Downloading {i+1}/{n}: {model_id}...")
allow_patterns = [file_info["file"] for file_info in model_data["files"]]
root_path = Path(snapshot_download(repo_id=model_id, allow_patterns=allow_patterns, cache_dir=download_dir))
# download configs too (the sizes are small)
snapshot_download(repo_id=model_id, allow_patterns=["*config.json"], cache_dir=download_dir)
print(f"Downloaded model files to: {root_path}")
model_data["download_path"] = str(root_path)
# Save the updated metadata back to the YAML file
with open(yaml_file, 'w') as f: yaml.dump(metadata, f, sort_keys=False)
print("Download completed according to YAML file.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Download models from Huggingface Hub based on a YAML configuration file.")
parser.add_argument("input", type=str, help="Path to the input YAML configuration file containing model information.")
args = parser.parse_args()
models_folder = Path(__file__).parent / "models"
models_folder.mkdir(parents=True, exist_ok=True)
download_models(args.input, str(models_folder))

View File

@@ -0,0 +1,230 @@
import yaml, time, requests, argparse
from pathlib import Path
from huggingface_hub import list_models, HfApi, snapshot_download
from tinygrad.helpers import _ensure_downloads_dir, tqdm

DOWNLOADS_DIR = _ensure_downloads_dir() / "models"
def snapshot_download_with_retry(*, repo_id: str, allow_patterns: list[str]|tuple[str, ...]|None=None, cache_dir: str|Path|None=None,
tries: int=2, **kwargs) -> Path:
for attempt in range(tries):
try:
return Path(snapshot_download(
repo_id=repo_id,
allow_patterns=allow_patterns,
cache_dir=str(cache_dir) if cache_dir is not None else None,
**kwargs
))
    except Exception:
if attempt == tries-1: raise
time.sleep(1)
# Constants for filtering models
HUGGINGFACE_URL = "https://huggingface.co"
SKIPPED_FILES = [
"fp16", "int8", "uint8", "quantized", # numerical accuracy issues
"avx2", "arm64", "avx512", "avx512_vnni", # numerical accuracy issues
"q4", "q4f16", "bnb4", # unimplemented quantization
"model_O4", # requires non cpu ort runner and MemcpyFromHost op
"merged", # TODO implement attribute with graph type and Loop op
]
SKIPPED_REPO_PATHS = [
# Invalid model-index
"AdamCodd/vit-base-nsfw-detector",
# TODO: implement attribute with graph type and Loop op
"minishlab/potion-base-8M", "minishlab/M2V_base_output", "minishlab/potion-retrieval-32M",
# TODO: implement SimplifiedLayerNormalization, SkipSimplifiedLayerNormalization, GroupQueryAttention
"HuggingFaceTB/SmolLM2-360M-Instruct",
# TODO: implement SimplifiedLayerNormalization, SkipSimplifiedLayerNormalization, RotaryEmbedding, MultiHeadAttention
"HuggingFaceTB/SmolLM2-1.7B-Instruct",
# TODO: implement RandomNormalLike
"stabilityai/stable-diffusion-xl-base-1.0", "stabilityai/sdxl-turbo", 'SimianLuo/LCM_Dreamshaper_v7',
# TODO: implement NonZero
"mangoapps/fb_zeroshot_mnli_onnx",
  # TODO: huge Concat here with 1024 (1, 3, 32, 32) tensors, and maybe a MOD bug in const folding
"briaai/RMBG-2.0",
]
class HuggingFaceONNXManager:
def __init__(self):
self.base_dir = Path(__file__).parent
self.models_dir = DOWNLOADS_DIR
self.api = HfApi()
def discover_models(self, limit: int, sort: str = "downloads") -> list[str]:
print(f"Discovering top {limit} ONNX models sorted by {sort}...")
repos = []
i = 0
for model in list_models(filter="onnx", sort=sort):
if model.id in SKIPPED_REPO_PATHS:
continue
print(f" {i+1}/{limit}: {model.id} ({getattr(model, sort)})")
repos.append(model.id)
i += 1
if i == limit:
break
print(f"Found {len(repos)} suitable ONNX models")
return repos
def collect_metadata(self, repos: list[str]) -> dict:
print(f"Collecting metadata for {len(repos)} repositories...")
metadata = {"repositories": {}}
total_size = 0
for repo in tqdm(repos, desc="Collecting metadata"):
try:
files_metadata = []
model_info = self.api.model_info(repo)
for file in model_info.siblings:
filename = file.rfilename
if not (filename.endswith('.onnx') or filename.endswith('.onnx_data')):
continue
if any(skip_str in filename for skip_str in SKIPPED_FILES):
continue
# Get file size from API or HEAD request
try:
head = requests.head(
f"{HUGGINGFACE_URL}/{repo}/resolve/main/{filename}",
allow_redirects=True,
timeout=10
)
file_size = file.size or int(head.headers.get('Content-Length', 0))
except requests.RequestException:
file_size = file.size or 0
files_metadata.append({
"file": filename,
"size": f"{file_size/1e6:.2f}MB"
})
total_size += file_size
if files_metadata: # Only add repos with valid ONNX files
metadata["repositories"][repo] = {
"url": f"{HUGGINGFACE_URL}/{repo}",
"download_path": None,
"files": files_metadata,
}
except Exception as e:
print(f"WARNING: Failed to collect metadata for {repo}: {e}")
continue
metadata['total_size'] = f"{total_size/1e9:.2f}GB"
metadata['created_at'] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
print(f"Collected metadata for {len(metadata['repositories'])} repositories")
print(f"Total estimated download size: {metadata['total_size']}")
return metadata
def download_models(self, metadata: dict) -> dict:
self.models_dir.mkdir(parents=True, exist_ok=True)
repos = metadata["repositories"]
n = len(repos)
print(f"Downloading {n} repositories to {self.models_dir}...")
for i, (model_id, model_data) in enumerate(repos.items()):
print(f" Downloading {i+1}/{n}: {model_id}...")
try:
# Download ONNX model files
allow_patterns = [file_info["file"] for file_info in model_data["files"]]
root_path = snapshot_download_with_retry(
repo_id=model_id,
allow_patterns=allow_patterns,
cache_dir=str(self.models_dir)
)
# Download config files (usually small)
snapshot_download_with_retry(
repo_id=model_id,
allow_patterns=["*config.json"],
cache_dir=str(self.models_dir)
)
model_data["download_path"] = str(root_path)
print(f" Downloaded to: {root_path}")
except Exception as e:
print(f" ERROR: Failed to download {model_id}: {e}")
model_data["download_path"] = None
continue
successful_downloads = sum(1 for repo in repos.values() if repo["download_path"] is not None)
print(f"Successfully downloaded {successful_downloads}/{n} repositories")
print(f"All models saved to: {self.models_dir}")
return metadata
def save_metadata(self, metadata: dict, output_file: str):
yaml_path = self.base_dir / output_file
with open(yaml_path, 'w') as f:
yaml.dump(metadata, f, sort_keys=False)
print(f"Metadata saved to: {yaml_path}")
def discover_and_download(self, limit: int, output_file: str = "huggingface_repos.yaml",
sort: str = "downloads", download: bool = True):
print(f"Starting HuggingFace ONNX workflow...")
print(f" Limit: {limit} models")
print(f" Sort by: {sort}")
print(f" Download: {'Yes' if download else 'No'}")
print(f" Output: {output_file}")
print("-" * 50)
repos = self.discover_models(limit, sort)
metadata = self.collect_metadata(repos)
if download:
metadata = self.download_models(metadata)
self.save_metadata(metadata, output_file)
print("-" * 50)
print("Workflow completed successfully!")
if download:
successful = sum(1 for repo in metadata["repositories"].values()
if repo["download_path"] is not None)
print(f"{successful}/{len(metadata['repositories'])} models downloaded")
return metadata
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="HuggingFace ONNX Model Manager - Discover, collect metadata, and download ONNX models",
)
parser.add_argument("--limit", type=int, help="Number of top repositories to process")
parser.add_argument("--output", type=str, default="huggingface_repos.yaml",
help="Output YAML file name (default: huggingface_repos.yaml)")
parser.add_argument("--sort", type=str, default="downloads",
choices=["downloads", "likes", "created", "modified"],
help="Sort criteria for model discovery (default: downloads)")
parser.add_argument("--download", action="store_true", default=False,
help="Download models after collecting metadata")
args = parser.parse_args()
manager = HuggingFaceONNXManager()
manager.discover_and_download(
limit=args.limit,
output_file=args.output,
sort=args.sort,
download=args.download
)
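The same workflow is available programmatically; a minimal sketch equivalent to the CLI form:
```python
from extra.huggingface_onnx.huggingface_manager import HuggingFaceONNXManager

# equivalent to: python huggingface_manager.py --limit 5 --sort likes (no --download)
manager = HuggingFaceONNXManager()
manager.discover_and_download(limit=5, sort="likes", download=False)
```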

View File

@@ -1,10 +1,11 @@
import onnx, yaml, tempfile, time, collections, pprint, argparse, json
import onnx, yaml, tempfile, time, argparse, json
from pathlib import Path
from typing import Any
from tinygrad.frontend.onnx import OnnxRunner
from extra.onnx import get_onnx_ops
from extra.onnx_helpers import validate, get_example_inputs
from extra.huggingface_onnx.huggingface_manager import DOWNLOADS_DIR, snapshot_download_with_retry
def get_config(root_path: Path):
def get_config(root_path: Path) -> dict[str, Any]:
ret = {}
for path in root_path.rglob("*config.json"):
config = json.load(path.open())
@@ -12,19 +13,19 @@ def get_config(root_path: Path):
ret.update(config)
return ret
def run_huggingface_validate(onnx_model_path, config, rtol, atol):
onnx_runner = OnnxRunner(onnx_model_path)
inputs = get_example_inputs(onnx_runner.graph_inputs, config)
validate(onnx_model_path, inputs, rtol=rtol, atol=atol)
def get_tolerances(file_name): # -> rtol, atol
def get_tolerances(file_name: str) -> tuple[float, float]:
# TODO very high rtol atol
if "fp16" in file_name: return 9e-2, 9e-2
if any(q in file_name for q in ["int8", "uint8", "quantized"]): return 4, 4
return 4e-3, 3e-2
def run_huggingface_validate(onnx_model_path: str | Path, config: dict[str, Any], rtol: float, atol: float):
onnx_runner = OnnxRunner(onnx_model_path)
inputs = get_example_inputs(onnx_runner.graph_inputs, config)
validate(onnx_model_path, inputs, rtol=rtol, atol=atol)
def validate_repos(models:dict[str, tuple[Path, Path]]):
print(f"** Validating {len(model_paths)} models **")
print(f"** Validating {len(models)} models **")
for model_id, (root_path, relative_path) in models.items():
print(f"validating model {model_id}")
model_path = root_path / relative_path
@@ -36,25 +37,6 @@ def validate_repos(models:dict[str, tuple[Path, Path]]):
et = time.time() - st
print(f"passed, took {et:.2f}s")
def retrieve_op_stats(models:dict[str, tuple[Path, Path]]) -> dict:
ret = {}
op_counter = collections.Counter()
unsupported_ops = collections.defaultdict(set)
supported_ops = get_onnx_ops()
print(f"** Retrieving stats from {len(model_paths)} models **")
for model_id, (root_path, relative_path) in models.items():
print(f"examining {model_id}")
model_path = root_path / relative_path
onnx_runner = OnnxRunner(model_path)
for node in onnx_runner.graph_nodes:
op_counter[node.op] += 1
if node.op not in supported_ops:
unsupported_ops[node.op].add(model_id)
del onnx_runner
ret["unsupported_ops"] = {k:list(v) for k, v in unsupported_ops.items()}
ret["op_counter"] = op_counter.most_common()
return ret
def debug_run(model_path, truncate, config, rtol, atol):
if truncate != -1:
model = onnx.load(model_path)
@@ -71,12 +53,9 @@ def debug_run(model_path, truncate, config, rtol, atol):
run_huggingface_validate(model_path, config, rtol, atol)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Huggingface ONNX Model Validator and Ops Checker")
parser.add_argument("input", type=str, help="Path to the input YAML configuration file containing model information.")
parser.add_argument("--check_ops", action="store_true", default=False,
help="Check support for ONNX operations in models from the YAML file")
parser.add_argument("--validate", action="store_true", default=False,
help="Validate correctness of models from the YAML file")
parser = argparse.ArgumentParser(description="Huggingface ONNX Model Validator")
parser.add_argument("--validate", type=str, default="",
help="Validate correctness of models from the specified YAML configuration file")
parser.add_argument("--debug", type=str, default="",
help="""Validates without explicitly needing a YAML or models pre-installed.
provide repo id (e.g. "minishlab/potion-base-8M") to validate all onnx models inside the repo
@@ -85,13 +64,13 @@ if __name__ == "__main__":
parser.add_argument("--truncate", type=int, default=-1, help="Truncate the ONNX model so intermediate results can be validated")
args = parser.parse_args()
if not (args.check_ops or args.validate or args.debug):
parser.error("Please provide either --validate, --check_ops, or --debug.")
if not (args.validate or args.debug):
parser.error("Please provide either --validate <yaml_file> or --debug <repo_id>.")
if args.truncate != -1 and not args.debug:
parser.error("--truncate and --debug should be used together for debugging")
if args.check_ops or args.validate:
with open(args.input, 'r') as f:
if args.validate:
with open(args.validate, 'r') as f:
data = yaml.safe_load(f)
assert all(repo["download_path"] is not None for repo in data["repositories"].values()), "please run `download_models.py` for this yaml"
model_paths = {
@@ -101,22 +80,16 @@ if __name__ == "__main__":
if model["file"].endswith(".onnx")
}
if args.check_ops:
pprint.pprint(retrieve_op_stats(model_paths))
if args.validate:
validate_repos(model_paths)
validate_repos(model_paths)
if args.debug:
from huggingface_hub import snapshot_download
download_dir = Path(__file__).parent / "models"
path:list[str] = args.debug.split("/")
if len(path) == 2:
# repo id
# validates all onnx models inside repo
repo_id = "/".join(path)
root_path = Path(snapshot_download(repo_id=repo_id, allow_patterns=["*.onnx", "*.onnx_data"], cache_dir=download_dir))
snapshot_download(repo_id=repo_id, allow_patterns=["*config.json"], cache_dir=download_dir)
root_path = snapshot_download_with_retry(repo_id=repo_id, allow_patterns=["*.onnx", "*.onnx_data"], cache_dir=DOWNLOADS_DIR)
snapshot_download_with_retry(repo_id=repo_id, allow_patterns=["*config.json"], cache_dir=DOWNLOADS_DIR)
config = get_config(root_path)
for onnx_model in root_path.rglob("*.onnx"):
rtol, atol = get_tolerances(onnx_model.name)
@@ -128,8 +101,8 @@ if __name__ == "__main__":
onnx_model = path[-1]
assert path[-1].endswith(".onnx")
repo_id, relative_path = "/".join(path[:2]), "/".join(path[2:])
root_path = Path(snapshot_download(repo_id=repo_id, allow_patterns=[relative_path], cache_dir=download_dir))
snapshot_download(repo_id=repo_id, allow_patterns=["*config.json"], cache_dir=download_dir)
root_path = snapshot_download_with_retry(repo_id=repo_id, allow_patterns=[relative_path], cache_dir=DOWNLOADS_DIR)
snapshot_download_with_retry(repo_id=repo_id, allow_patterns=["*config.json"], cache_dir=DOWNLOADS_DIR)
config = get_config(root_path)
rtol, atol = get_tolerances(onnx_model)
print(f"validating {relative_path} with truncate={args.truncate}, {rtol=}, {atol=}")

View File

@@ -48,6 +48,9 @@ class Domain(enum.Enum):
AI_ONNX_TRAINING = "ai.onnx.training"
AI_ONNX_PREVIEW_TRAINING = "ai.onnx.preview.training"
MICROSOFT_CONTRIB_OPS = "com.microsoft"
MICROSOFT_NCHWC = "com.microsoft.nchwc"
MICROSOFT_EXPERIMENTAL = "com.microsoft.experimental"
PYTORCH_ATEN = "org.pytorch.aten"
@classmethod
def from_onnx(cls, domain: str | None) -> "Domain": return cls.ONNX if domain is None or domain == "" else cls(domain)
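A quick sketch of what `from_onnx` returns for the new entries (assuming `Domain` has been brought into scope from the module patched above; its import path is not shown here):
```python
assert Domain.from_onnx(None) is Domain.ONNX                      # missing domain defaults to ai.onnx
assert Domain.from_onnx("") is Domain.ONNX                        # empty string too
assert Domain.from_onnx("com.microsoft.nchwc") is Domain.MICROSOFT_NCHWC
assert Domain.from_onnx("org.pytorch.aten") is Domain.PYTORCH_ATEN
```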

View File

@@ -9,7 +9,15 @@ except ModuleNotFoundError:
raise unittest.SkipTest("onnx not installed, skipping onnx test")
from tinygrad.frontend.onnx import OnnxRunner
from tinygrad.tensor import Tensor
from tinygrad.helpers import CI, fetch, temp
from tinygrad.device import Device
from tinygrad.helpers import CI, fetch, temp, Context
try:
from extra.onnx_helpers import validate
from extra.huggingface_onnx.huggingface_manager import DOWNLOADS_DIR, snapshot_download_with_retry
HUGGINGFACE_AVAILABLE = True
except ModuleNotFoundError:
HUGGINGFACE_AVAILABLE = False
def run_onnx_torch(onnx_model, inputs):
import torch
@@ -137,5 +145,36 @@ class TestOnnxModel(unittest.TestCase):
print(cls, _LABELS[cls])
assert "car" in _LABELS[cls] or _LABELS[cls] == "convertible"
@unittest.skipUnless(HUGGINGFACE_AVAILABLE and Device.DEFAULT == "METAL", "only run on METAL")
class TestHuggingFaceOnnxModels(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls._ctx = Context(MAX_BUFFER_SIZE=0)
cls._ctx.__enter__()
@classmethod
def tearDownClass(cls):
cls._ctx.__exit__()
def _validate(self, repo_id, model_file, custom_inputs, rtol=1e-4, atol=1e-4):
onnx_model_path = snapshot_download_with_retry(
repo_id=repo_id,
allow_patterns=["*.onnx", "*.onnx_data"],
cache_dir=str(DOWNLOADS_DIR)
)
onnx_model_path = onnx_model_path / model_file
file_size = onnx_model_path.stat().st_size
print(f"Validating model: {repo_id}/{model_file} ({file_size/1e6:.2f}M)")
validate(onnx_model_path, custom_inputs, rtol=rtol, atol=atol)
def test_xlm_roberta_large(self):
repo_id = "FacebookAI/xlm-roberta-large"
model_file = "onnx/model.onnx"
custom_inputs = {
"input_ids": np.random.randint(0, 250002, (1, 11), dtype=np.int64),
"attention_mask": np.ones((1, 11), dtype=np.int64),
}
self._validate(repo_id, model_file, custom_inputs)
if __name__ == "__main__":
unittest.main()
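Adding another model to the suite follows the same pattern; a sketch of a method to drop into TestHuggingFaceOnnxModels (input names, shapes, and the vocab size are illustrative, not taken from the model's actual signature):
```python
  def test_minilm(self):
    repo_id = "sentence-transformers/all-MiniLM-L6-v2"
    model_file = "onnx/model.onnx"
    custom_inputs = {
      "input_ids": np.random.randint(0, 30522, (1, 16), dtype=np.int64),  # 30522: BERT-style vocab, illustrative
      "attention_mask": np.ones((1, 16), dtype=np.int64),
      "token_type_ids": np.zeros((1, 16), dtype=np.int64),
    }
    self._validate(repo_id, model_file, custom_inputs)
```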

View File

@@ -3,8 +3,8 @@ from dataclasses import dataclass, replace
from collections import defaultdict
from typing import Any, Generic, TypeVar, Iterator
import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal, time
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, \
colored, Context, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE, cpu_events, ProfileEvent, ProfilePointEvent, dedup
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored, \
Context, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
from tinygrad.renderer import Renderer
@@ -124,7 +124,7 @@ class Buffer:
def allocate(self, opaque=None, external_ptr=None) -> Buffer:
assert not self.is_initialized(), "can't allocate already allocated buffer"
if DEBUG >= 7: print(f"buffer: allocate {self.nbytes} bytes on {self.device}")
if (mbs:=getenv("MAX_BUFFER_SIZE", 0)) > 0 and self.size > mbs: raise RuntimeError(f"buffer of size {self.size/1e6:.2f}M is too large")
if MAX_BUFFER_SIZE > 0 and self.size > MAX_BUFFER_SIZE: raise RuntimeError(f"buffer of size {self.size/1e6:.2f}M is too large")
self.allocator:Allocator = Device[self.device].allocator
if external_ptr is not None:
self.options = replace(self.options, external_ptr=external_ptr) if self.options else BufferSpec(external_ptr=external_ptr)
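With MAX_BUFFER_SIZE now a ContextVar, the cap can be scoped (as the test above does with `Context(MAX_BUFFER_SIZE=0)`) rather than set process-wide via the environment. A sketch of the behavior (the cap value is illustrative; note `size` counts elements, not bytes):
```python
from tinygrad import Tensor
from tinygrad.helpers import Context

# MAX_BUFFER_SIZE=0 (the default) disables the check entirely
with Context(MAX_BUFFER_SIZE=1000):
  try:
    Tensor.ones(2048).contiguous().realize()  # 2048 elements > 1000 -> RuntimeError
  except RuntimeError as e:
    print(e)  # "buffer of size 0.00M is too large"
```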

View File

@@ -139,7 +139,7 @@ DISABLE_COMPILER_CACHE = ContextVar("DISABLE_COMPILER_CACHE", 0)
DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES = ContextVar("DONT_REALIZE_EXPAND", 0), ContextVar("DONT_GROUP_REDUCES", 0)
QUANTIZE, VALIDATE_WITH_CPU = ContextVar("QUANTIZE", 0), ContextVar("VALIDATE_WITH_CPU", 0)
CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), ContextVar("FUSE_OPTIM", 0)
ALLOW_DEVICE_USAGE, AMD_LLVM = ContextVar("ALLOW_DEVICE_USAGE", 1), ContextVar("AMD_LLVM", 1)
ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, AMD_LLVM = ContextVar("ALLOW_DEVICE_USAGE", 1), ContextVar("MAX_BUFFER_SIZE", 0), ContextVar("AMD_LLVM", 1)
@dataclass(frozen=True)
class Metadata: