Add Onnx Huggingface to test/models/test_onnx.py (#11468)
* BOOM
* cache extra/huggingface/models/
* why max buffer size is not 0
* override MAX_BUFFER_SIZE
* less models
* remove more models and change cache dir to already cached dir
* only metal
* less is more?
* remove check ops
* why is this not setting the ENVVAR
* ughhhhh just test in models
* only cpu and gpu
* only cpu actually
* just override it idk
* final
* move extra dependencies up top
* simplification
* fix print
* make README better
* revert ops_disk fix for now
* clean up test_onnx
* remove testing fashion clip model cuz sloooowwwwww
* actually let METAL run this
* fix comment mistake
* fix download path in run_models
* does this work?
* cleanup setup and teardown
* contextvar like this?
* prove model is cached
* do I need to increment DOWNLOAD_CACHE_VERSION?
* see if cached with incremented DOWNLOAD_CACHE_VERSION
* use warnings to see if the model exists
* revert DOWNLOAD_CACHE_VERSION stuff and clean up
* add retry to download
* nit
.github/workflows/benchmark.yml (vendored, 2 changes)
@@ -62,8 +62,6 @@ jobs:
         run: BENCHMARK_LOG=stable_diffusion_xl CAPTURE_PROCESS_REPLAY=0 JIT=1 python3.11 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
       - name: Run model inference benchmark
         run: METAL=1 python3.11 test/external/external_model_benchmark.py
-      - name: Run huggingface_onnx test
-        run: METAL=1 python3.11 extra/huggingface_onnx/run_models.py test --debug FacebookAI/xlm-roberta-large
       - name: Test speed vs torch
         run: BIG=2 MPS=1 python3.11 test/test_speed_v_torch.py | tee torch_speed.txt
       - name: Test tensor cores
extra/huggingface_onnx/README.md (new file, 61 lines)
@@ -0,0 +1,61 @@
# HuggingFace ONNX

Tool for discovering, downloading, and validating ONNX models from HuggingFace.

## Extra Dependencies

```bash
pip install huggingface_hub pyyaml requests onnx onnxruntime numpy
```

## Huggingface Manager (discovering and downloading)

The `huggingface_manager.py` script discovers top ONNX models from HuggingFace, collects metadata, and optionally downloads them.

```bash
# Download top 50 models sorted by downloads
python huggingface_manager.py --limit 50 --download

# Just collect metadata (no download)
python huggingface_manager.py --limit 100

# Sort by likes instead of downloads
python huggingface_manager.py --limit 20 --sort likes --download

# Custom output file
python huggingface_manager.py --limit 10 --output my_models.yaml
```

### Output Format

The tool generates a YAML file with the following structure:

```yaml
repositories:
  "model-name":
    url: "https://huggingface.co/model-name"
    download_path: "/path/to/models/..."  # set when --download is used
    files:
      - file: "model.onnx"
        size: "90.91MB"
total_size: "2.45GB"
created_at: "2024-01-15T10:30:00Z"
```
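Downstream scripts consume this file directly; a minimal sketch of reading it back (field names as generated above, file name is the default `huggingface_repos.yaml`):

```python
import yaml

# enumerate every ONNX file recorded in the generated report
with open("huggingface_repos.yaml") as f:
  data = yaml.safe_load(f)
for repo, info in data["repositories"].items():
  for entry in info["files"]:
    print(repo, entry["file"], entry["size"])
```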

## Run Models (validation)

The `run_models.py` script validates ONNX models against ONNX Runtime for correctness.

```bash
# Validate models from a YAML configuration file
python run_models.py --validate huggingface_repos.yaml

# Debug a specific repository (downloads and validates all ONNX models in it)
python run_models.py --debug sentence-transformers/all-MiniLM-L6-v2

# Debug a specific model file
python run_models.py --debug sentence-transformers/all-MiniLM-L6-v2/onnx/model.onnx

# Truncate the model to debug and validate intermediate results
DEBUGONNX=1 python run_models.py --debug sentence-transformers/all-MiniLM-L6-v2/onnx/model.onnx --truncate 10
```
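For orientation, a rough sketch of the comparison this validation performs: run the model under tinygrad's ONNX frontend and under ONNX Runtime, then compare outputs within `rtol`/`atol`. This is illustrative, not the actual `extra/onnx_helpers.py` implementation, and it assumes `OnnxRunner` accepts a dict of numpy inputs:

```python
import numpy as np
import onnxruntime as ort
from tinygrad.frontend.onnx import OnnxRunner

def validate_sketch(model_path, np_inputs, rtol=4e-3, atol=3e-2):
  # reference outputs from ONNX Runtime on CPU
  ref = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"]).run(None, np_inputs)
  # tinygrad outputs through the ONNX frontend
  out = OnnxRunner(model_path)(np_inputs)
  for expected, actual in zip(ref, out.values()):
    np.testing.assert_allclose(expected, actual.numpy(), rtol=rtol, atol=atol)
```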
@@ -1,85 +0,0 @@
import yaml, time, requests, argparse
from pathlib import Path
from huggingface_hub import list_models, HfApi
from tinygrad.helpers import tqdm

HUGGINGFACE_URL = "https://huggingface.co"
SKIPPED_FILES = [
  "fp16", "int8", "uint8", "quantized",     # numerical accuracy issues
  "avx2", "arm64", "avx512", "avx512_vnni", # numerical accuracy issues
  "q4", "q4f16", "bnb4",                    # unimplemented quantization
  "model_O4",                               # requires non-cpu ort runner and MemcpyFromHost op
  "merged",                                 # TODO implement attribute with graph type and Loop op
]
SKIPPED_REPO_PATHS = [
  # Invalid model-index
  "AdamCodd/vit-base-nsfw-detector",
  # TODO: implement attribute with graph type and Loop op
  "minishlab/potion-base-8M", "minishlab/M2V_base_output", "minishlab/potion-retrieval-32M",
  # TODO: implement SimplifiedLayerNormalization, SkipSimplifiedLayerNormalization, GroupQueryAttention
  "HuggingFaceTB/SmolLM2-360M-Instruct",
  # TODO: implement SimplifiedLayerNormalization, SkipSimplifiedLayerNormalization, RotaryEmbedding, MultiHeadAttention
  "HuggingFaceTB/SmolLM2-1.7B-Instruct",
  # TODO: implement RandomNormalLike
  "stabilityai/stable-diffusion-xl-base-1.0", "stabilityai/sdxl-turbo", "SimianLuo/LCM_Dreamshaper_v7",
  # TODO: implement NonZero
  "mangoapps/fb_zeroshot_mnli_onnx",
  # TODO huge Concat in here with 1024 (1, 3, 32, 32) Tensors, and maybe a MOD bug with const folding
  "briaai/RMBG-2.0",
]

def get_top_repos(n: int, sort: str) -> list[str]: # list["FacebookAI/xlm-roberta-large", ...]
  print(f"** Getting top {n} models sorted by {sort} **")
  repos = []
  i = 0
  for model in list_models(filter="onnx", sort=sort):
    if model.id in SKIPPED_REPO_PATHS: continue
    print(f"{i+1}/{n}: {model.id} ({getattr(model, sort)})")
    repos.append(model.id)
    i += 1
    if i == n: break
  return repos

def get_metadata(repos:list[str]) -> dict:
  api = HfApi()
  repos_metadata = {"repositories": {}}
  total_size = 0

  # TODO: speed head requests up with async?
  for repo in tqdm(repos, desc="Getting metadata"):
    files_metadata = []
    model_info = api.model_info(repo)

    for file in model_info.siblings:
      filename = file.rfilename
      if not (filename.endswith('.onnx') or filename.endswith('.onnx_data')): continue
      if any(skip_str in filename for skip_str in SKIPPED_FILES): continue
      head = requests.head(f"{HUGGINGFACE_URL}/{repo}/resolve/main/{filename}", allow_redirects=True)
      file_size = file.size or int(head.headers.get('Content-Length', 0))
      files_metadata.append({"file": filename, "size": f"{file_size/1e6:.2f}MB"})
      total_size += file_size

    repos_metadata["repositories"][repo] = {
      "url": f"{HUGGINGFACE_URL}/{repo}",
      "download_path": None,
      "files": files_metadata,
    }
  repos_metadata['total_size'] = f"{total_size/1e9:.2f}GB"
  repos_metadata['created_at'] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
  return repos_metadata

if __name__ == "__main__":
  sort = "downloads"  # downloads over the most recent 30 days
  huggingface_onnx_dir = Path(__file__).parent

  parser = argparse.ArgumentParser(description="Produces a YAML file with metadata of top huggingface onnx models")
  parser.add_argument("--limit", type=int, required=True, help="Number of top repositories to process (e.g., 100)")
  parser.add_argument("--output", type=str, default="huggingface_repos.yaml", help="Output YAML file name to save the report")
  args = parser.parse_args()

  top_repos = get_top_repos(args.limit, sort)
  metadata = get_metadata(top_repos)
  yaml_path = huggingface_onnx_dir / args.output
  with open(yaml_path, 'w') as f:
    yaml.dump(metadata, f, sort_keys=False)
  print(f"YAML saved to: {str(yaml_path)}")
extra/huggingface_onnx/download_models.py (deleted)

@@ -1,29 +0,0 @@
import yaml, argparse
from pathlib import Path
from huggingface_hub import snapshot_download

def download_models(yaml_file: str, download_dir: str) -> None:
  with open(yaml_file, 'r') as f: metadata = yaml.safe_load(f)
  n = len(metadata["repositories"])

  for i, (model_id, model_data) in enumerate(metadata["repositories"].items()):
    print(f"Downloading {i+1}/{n}: {model_id}...")
    allow_patterns = [file_info["file"] for file_info in model_data["files"]]
    root_path = Path(snapshot_download(repo_id=model_id, allow_patterns=allow_patterns, cache_dir=download_dir))
    # download configs too (the sizes are small)
    snapshot_download(repo_id=model_id, allow_patterns=["*config.json"], cache_dir=download_dir)
    print(f"Downloaded model files to: {root_path}")
    model_data["download_path"] = str(root_path)

  # Save the updated metadata back to the YAML file
  with open(yaml_file, 'w') as f: yaml.dump(metadata, f, sort_keys=False)
  print("Download completed according to YAML file.")

if __name__ == "__main__":
  parser = argparse.ArgumentParser(description="Download models from Huggingface Hub based on a YAML configuration file.")
  parser.add_argument("input", type=str, help="Path to the input YAML configuration file containing model information.")
  args = parser.parse_args()

  models_folder = Path(__file__).parent / "models"
  models_folder.mkdir(parents=True, exist_ok=True)
  download_models(args.input, str(models_folder))
extra/huggingface_onnx/huggingface_manager.py (new file, 230 lines)
@@ -0,0 +1,230 @@
import yaml
import time
import requests
import argparse
from pathlib import Path
from huggingface_hub import list_models, HfApi, snapshot_download
from tinygrad.helpers import _ensure_downloads_dir
DOWNLOADS_DIR = _ensure_downloads_dir() / "models"
from tinygrad.helpers import tqdm

def snapshot_download_with_retry(*, repo_id: str, allow_patterns: list[str]|tuple[str, ...]|None=None, cache_dir: str|Path|None=None,
                                 tries: int=2, **kwargs) -> Path:
  for attempt in range(tries):
    try:
      return Path(snapshot_download(
        repo_id=repo_id,
        allow_patterns=allow_patterns,
        cache_dir=str(cache_dir) if cache_dir is not None else None,
        **kwargs
      ))
    except Exception as e:
      if attempt == tries-1: raise
      time.sleep(1)

# Constants for filtering models
HUGGINGFACE_URL = "https://huggingface.co"
SKIPPED_FILES = [
  "fp16", "int8", "uint8", "quantized",     # numerical accuracy issues
  "avx2", "arm64", "avx512", "avx512_vnni", # numerical accuracy issues
  "q4", "q4f16", "bnb4",                    # unimplemented quantization
  "model_O4",                               # requires non-cpu ort runner and MemcpyFromHost op
  "merged",                                 # TODO implement attribute with graph type and Loop op
]

SKIPPED_REPO_PATHS = [
  # Invalid model-index
  "AdamCodd/vit-base-nsfw-detector",
  # TODO: implement attribute with graph type and Loop op
  "minishlab/potion-base-8M", "minishlab/M2V_base_output", "minishlab/potion-retrieval-32M",
  # TODO: implement SimplifiedLayerNormalization, SkipSimplifiedLayerNormalization, GroupQueryAttention
  "HuggingFaceTB/SmolLM2-360M-Instruct",
  # TODO: implement SimplifiedLayerNormalization, SkipSimplifiedLayerNormalization, RotaryEmbedding, MultiHeadAttention
  "HuggingFaceTB/SmolLM2-1.7B-Instruct",
  # TODO: implement RandomNormalLike
  "stabilityai/stable-diffusion-xl-base-1.0", "stabilityai/sdxl-turbo", "SimianLuo/LCM_Dreamshaper_v7",
  # TODO: implement NonZero
  "mangoapps/fb_zeroshot_mnli_onnx",
  # TODO huge Concat in here with 1024 (1, 3, 32, 32) Tensors, and maybe a MOD bug with const folding
  "briaai/RMBG-2.0",
]


class HuggingFaceONNXManager:
  def __init__(self):
    self.base_dir = Path(__file__).parent
    self.models_dir = DOWNLOADS_DIR
    self.api = HfApi()

  def discover_models(self, limit: int, sort: str = "downloads") -> list[str]:
    print(f"Discovering top {limit} ONNX models sorted by {sort}...")
    repos = []
    i = 0

    for model in list_models(filter="onnx", sort=sort):
      if model.id in SKIPPED_REPO_PATHS:
        continue

      print(f"  {i+1}/{limit}: {model.id} ({getattr(model, sort)})")
      repos.append(model.id)
      i += 1
      if i == limit:
        break

    print(f"Found {len(repos)} suitable ONNX models")
    return repos

  def collect_metadata(self, repos: list[str]) -> dict:
    print(f"Collecting metadata for {len(repos)} repositories...")
    metadata = {"repositories": {}}
    total_size = 0

    for repo in tqdm(repos, desc="Collecting metadata"):
      try:
        files_metadata = []
        model_info = self.api.model_info(repo)

        for file in model_info.siblings:
          filename = file.rfilename
          if not (filename.endswith('.onnx') or filename.endswith('.onnx_data')):
            continue
          if any(skip_str in filename for skip_str in SKIPPED_FILES):
            continue

          # Get file size from API or HEAD request
          try:
            head = requests.head(
              f"{HUGGINGFACE_URL}/{repo}/resolve/main/{filename}",
              allow_redirects=True,
              timeout=10
            )
            file_size = file.size or int(head.headers.get('Content-Length', 0))
          except requests.RequestException:
            file_size = file.size or 0

          files_metadata.append({
            "file": filename,
            "size": f"{file_size/1e6:.2f}MB"
          })
          total_size += file_size

        if files_metadata:  # Only add repos with valid ONNX files
          metadata["repositories"][repo] = {
            "url": f"{HUGGINGFACE_URL}/{repo}",
            "download_path": None,
            "files": files_metadata,
          }

      except Exception as e:
        print(f"WARNING: Failed to collect metadata for {repo}: {e}")
        continue

    metadata['total_size'] = f"{total_size/1e9:.2f}GB"
    metadata['created_at'] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())

    print(f"Collected metadata for {len(metadata['repositories'])} repositories")
    print(f"Total estimated download size: {metadata['total_size']}")

    return metadata

  def download_models(self, metadata: dict) -> dict:
    self.models_dir.mkdir(parents=True, exist_ok=True)

    repos = metadata["repositories"]
    n = len(repos)

    print(f"Downloading {n} repositories to {self.models_dir}...")

    for i, (model_id, model_data) in enumerate(repos.items()):
      print(f"  Downloading {i+1}/{n}: {model_id}...")

      try:
        # Download ONNX model files
        allow_patterns = [file_info["file"] for file_info in model_data["files"]]
        root_path = snapshot_download_with_retry(
          repo_id=model_id,
          allow_patterns=allow_patterns,
          cache_dir=str(self.models_dir)
        )

        # Download config files (usually small)
        snapshot_download_with_retry(
          repo_id=model_id,
          allow_patterns=["*config.json"],
          cache_dir=str(self.models_dir)
        )

        model_data["download_path"] = str(root_path)
        print(f"  Downloaded to: {root_path}")

      except Exception as e:
        print(f"  ERROR: Failed to download {model_id}: {e}")
        model_data["download_path"] = None
        continue

    successful_downloads = sum(1 for repo in repos.values() if repo["download_path"] is not None)
    print(f"Successfully downloaded {successful_downloads}/{n} repositories")
    print(f"All models saved to: {self.models_dir}")

    return metadata

  def save_metadata(self, metadata: dict, output_file: str):
    yaml_path = self.base_dir / output_file
    with open(yaml_path, 'w') as f:
      yaml.dump(metadata, f, sort_keys=False)
    print(f"Metadata saved to: {yaml_path}")

  def discover_and_download(self, limit: int, output_file: str = "huggingface_repos.yaml",
                            sort: str = "downloads", download: bool = True):
    print(f"Starting HuggingFace ONNX workflow...")
    print(f"  Limit: {limit} models")
    print(f"  Sort by: {sort}")
    print(f"  Download: {'Yes' if download else 'No'}")
    print(f"  Output: {output_file}")
    print("-" * 50)

    repos = self.discover_models(limit, sort)

    metadata = self.collect_metadata(repos)

    if download:
      metadata = self.download_models(metadata)

    self.save_metadata(metadata, output_file)

    print("-" * 50)
    print("Workflow completed successfully!")
    if download:
      successful = sum(1 for repo in metadata["repositories"].values()
                       if repo["download_path"] is not None)
      print(f"{successful}/{len(metadata['repositories'])} models downloaded")

    return metadata


if __name__ == "__main__":
  parser = argparse.ArgumentParser(
    description="HuggingFace ONNX Model Manager - Discover, collect metadata, and download ONNX models",
  )

  parser.add_argument("--limit", type=int, help="Number of top repositories to process")
  parser.add_argument("--output", type=str, default="huggingface_repos.yaml",
                      help="Output YAML file name (default: huggingface_repos.yaml)")
  parser.add_argument("--sort", type=str, default="downloads",
                      choices=["downloads", "likes", "created", "modified"],
                      help="Sort criteria for model discovery (default: downloads)")
  parser.add_argument("--download", action="store_true", default=False,
                      help="Download models after collecting metadata")

  args = parser.parse_args()

  if not args.limit: parser.error("--limit is required")

  manager = HuggingFaceONNXManager()
  manager.discover_and_download(
    limit=args.limit,
    output_file=args.output,
    sort=args.sort,
    download=args.download
  )
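Besides the CLI entry point, the manager class can be driven from Python; a minimal sketch (assuming it is run from inside `extra/huggingface_onnx/` so the module import resolves):

```python
from huggingface_manager import HuggingFaceONNXManager

# collect metadata for the top 5 models without downloading the weights
manager = HuggingFaceONNXManager()
metadata = manager.discover_and_download(limit=5, output_file="top5.yaml", sort="downloads", download=False)
print(metadata["total_size"])
```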
extra/huggingface_onnx/run_models.py

@@ -1,10 +1,11 @@
-import onnx, yaml, tempfile, time, collections, pprint, argparse, json
+import onnx, yaml, tempfile, time, argparse, json
 from pathlib import Path
+from typing import Any
 from tinygrad.frontend.onnx import OnnxRunner
-from extra.onnx import get_onnx_ops
 from extra.onnx_helpers import validate, get_example_inputs
+from extra.huggingface_onnx.huggingface_manager import DOWNLOADS_DIR, snapshot_download_with_retry
 
-def get_config(root_path: Path):
+def get_config(root_path: Path) -> dict[str, Any]:
   ret = {}
   for path in root_path.rglob("*config.json"):
     config = json.load(path.open())
@@ -12,19 +13,19 @@ def get_config(root_path: Path):
     ret.update(config)
   return ret
 
-def run_huggingface_validate(onnx_model_path, config, rtol, atol):
-  onnx_runner = OnnxRunner(onnx_model_path)
-  inputs = get_example_inputs(onnx_runner.graph_inputs, config)
-  validate(onnx_model_path, inputs, rtol=rtol, atol=atol)
-
-def get_tolerances(file_name): # -> rtol, atol
+def get_tolerances(file_name: str) -> tuple[float, float]:
   # TODO very high rtol atol
   if "fp16" in file_name: return 9e-2, 9e-2
   if any(q in file_name for q in ["int8", "uint8", "quantized"]): return 4, 4
   return 4e-3, 3e-2
 
+def run_huggingface_validate(onnx_model_path: str | Path, config: dict[str, Any], rtol: float, atol: float):
+  onnx_runner = OnnxRunner(onnx_model_path)
+  inputs = get_example_inputs(onnx_runner.graph_inputs, config)
+  validate(onnx_model_path, inputs, rtol=rtol, atol=atol)
+
 def validate_repos(models:dict[str, tuple[Path, Path]]):
-  print(f"** Validating {len(model_paths)} models **")
+  print(f"** Validating {len(models)} models **")
   for model_id, (root_path, relative_path) in models.items():
     print(f"validating model {model_id}")
     model_path = root_path / relative_path
@@ -36,25 +37,6 @@ def validate_repos(models:dict[str, tuple[Path, Path]]):
     et = time.time() - st
     print(f"passed, took {et:.2f}s")
 
-def retrieve_op_stats(models:dict[str, tuple[Path, Path]]) -> dict:
-  ret = {}
-  op_counter = collections.Counter()
-  unsupported_ops = collections.defaultdict(set)
-  supported_ops = get_onnx_ops()
-  print(f"** Retrieving stats from {len(model_paths)} models **")
-  for model_id, (root_path, relative_path) in models.items():
-    print(f"examining {model_id}")
-    model_path = root_path / relative_path
-    onnx_runner = OnnxRunner(model_path)
-    for node in onnx_runner.graph_nodes:
-      op_counter[node.op] += 1
-      if node.op not in supported_ops:
-        unsupported_ops[node.op].add(model_id)
-    del onnx_runner
-  ret["unsupported_ops"] = {k:list(v) for k, v in unsupported_ops.items()}
-  ret["op_counter"] = op_counter.most_common()
-  return ret
-
 def debug_run(model_path, truncate, config, rtol, atol):
   if truncate != -1:
     model = onnx.load(model_path)
@@ -71,12 +53,9 @@ def debug_run(model_path, truncate, config, rtol, atol):
   run_huggingface_validate(model_path, config, rtol, atol)
 
 if __name__ == "__main__":
-  parser = argparse.ArgumentParser(description="Huggingface ONNX Model Validator and Ops Checker")
-  parser.add_argument("input", type=str, help="Path to the input YAML configuration file containing model information.")
-  parser.add_argument("--check_ops", action="store_true", default=False,
-                      help="Check support for ONNX operations in models from the YAML file")
-  parser.add_argument("--validate", action="store_true", default=False,
-                      help="Validate correctness of models from the YAML file")
+  parser = argparse.ArgumentParser(description="Huggingface ONNX Model Validator")
+  parser.add_argument("--validate", type=str, default="",
+                      help="Validate correctness of models from the specified YAML configuration file")
   parser.add_argument("--debug", type=str, default="",
                       help="""Validates without explicitly needing a YAML or models pre-installed.
   provide repo id (e.g. "minishlab/potion-base-8M") to validate all onnx models inside the repo
@@ -85,13 +64,13 @@ if __name__ == "__main__":
   parser.add_argument("--truncate", type=int, default=-1, help="Truncate the ONNX model so intermediate results can be validated")
   args = parser.parse_args()
 
-  if not (args.check_ops or args.validate or args.debug):
-    parser.error("Please provide either --validate, --check_ops, or --debug.")
+  if not (args.validate or args.debug):
+    parser.error("Please provide either --validate <yaml_file> or --debug <repo_id>.")
   if args.truncate != -1 and not args.debug:
     parser.error("--truncate and --debug should be used together for debugging")
 
-  if args.check_ops or args.validate:
-    with open(args.input, 'r') as f:
+  if args.validate:
+    with open(args.validate, 'r') as f:
       data = yaml.safe_load(f)
       assert all(repo["download_path"] is not None for repo in data["repositories"].values()), "please run `download_models.py` for this yaml"
       model_paths = {
@@ -101,22 +80,16 @@ if __name__ == "__main__":
         if model["file"].endswith(".onnx")
       }
 
-    if args.check_ops:
-      pprint.pprint(retrieve_op_stats(model_paths))
-
-    if args.validate:
-      validate_repos(model_paths)
+    validate_repos(model_paths)
 
   if args.debug:
-    from huggingface_hub import snapshot_download
-    download_dir = Path(__file__).parent / "models"
     path:list[str] = args.debug.split("/")
     if len(path) == 2:
      # repo id
      # validates all onnx models inside repo
      repo_id = "/".join(path)
-     root_path = Path(snapshot_download(repo_id=repo_id, allow_patterns=["*.onnx", "*.onnx_data"], cache_dir=download_dir))
-     snapshot_download(repo_id=repo_id, allow_patterns=["*config.json"], cache_dir=download_dir)
+     root_path = snapshot_download_with_retry(repo_id=repo_id, allow_patterns=["*.onnx", "*.onnx_data"], cache_dir=DOWNLOADS_DIR)
+     snapshot_download_with_retry(repo_id=repo_id, allow_patterns=["*config.json"], cache_dir=DOWNLOADS_DIR)
     config = get_config(root_path)
     for onnx_model in root_path.rglob("*.onnx"):
       rtol, atol = get_tolerances(onnx_model.name)
@@ -128,8 +101,8 @@ if __name__ == "__main__":
      onnx_model = path[-1]
      assert path[-1].endswith(".onnx")
      repo_id, relative_path = "/".join(path[:2]), "/".join(path[2:])
-     root_path = Path(snapshot_download(repo_id=repo_id, allow_patterns=[relative_path], cache_dir=download_dir))
-     snapshot_download(repo_id=repo_id, allow_patterns=["*config.json"], cache_dir=download_dir)
+     root_path = snapshot_download_with_retry(repo_id=repo_id, allow_patterns=[relative_path], cache_dir=DOWNLOADS_DIR)
+     snapshot_download_with_retry(repo_id=repo_id, allow_patterns=["*config.json"], cache_dir=DOWNLOADS_DIR)
     config = get_config(root_path)
     rtol, atol = get_tolerances(onnx_model)
     print(f"validating {relative_path} with truncate={args.truncate}, {rtol=}, {atol=}")
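The `--debug` argument parsing above hinges on the number of path segments: two segments name a repo (validate every ONNX model in it), more segments name a single model file. A small worked example:

```python
# repo id: two segments, validates every .onnx file in the repo
path = "minishlab/potion-base-8M".split("/")
assert len(path) == 2

# model file: more than two segments, validates just that file
path = "sentence-transformers/all-MiniLM-L6-v2/onnx/model.onnx".split("/")
repo_id, relative_path = "/".join(path[:2]), "/".join(path[2:])
assert repo_id == "sentence-transformers/all-MiniLM-L6-v2"
assert relative_path == "onnx/model.onnx"
```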
@@ -48,6 +48,9 @@ class Domain(enum.Enum):
   AI_ONNX_TRAINING = "ai.onnx.training"
   AI_ONNX_PREVIEW_TRAINING = "ai.onnx.preview.training"
   MICROSOFT_CONTRIB_OPS = "com.microsoft"
+  MICROSOFT_NCHWC = "com.microsoft.nchwc"
+  MICROSOFT_EXPERIMENTAL = "com.microsoft.experimental"
+  PYTORCH_ATEN = "org.pytorch.aten"
   @classmethod
   def from_onnx(cls, domain: str | None) -> "Domain": return cls.ONNX if domain is None or domain == "" else cls(domain)
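For clarity, how `from_onnx` dispatches (assuming `Domain` is imported from the module patched above): an empty or missing domain string maps to the default ONNX domain, anything else must match an enum value or the enum lookup raises:

```python
assert Domain.from_onnx("") is Domain.ONNX
assert Domain.from_onnx(None) is Domain.ONNX
assert Domain.from_onnx("com.microsoft") is Domain.MICROSOFT_CONTRIB_OPS
assert Domain.from_onnx("org.pytorch.aten") is Domain.PYTORCH_ATEN
try: Domain.from_onnx("some.unknown.domain")
except ValueError: pass  # unknown domains raise ValueError
```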
test/models/test_onnx.py

@@ -9,7 +9,15 @@ except ModuleNotFoundError:
   raise unittest.SkipTest("onnx not installed, skipping onnx test")
 from tinygrad.frontend.onnx import OnnxRunner
 from tinygrad.tensor import Tensor
-from tinygrad.helpers import CI, fetch, temp
+from tinygrad.device import Device
+from tinygrad.helpers import CI, fetch, temp, Context
+
+try:
+  from extra.onnx_helpers import validate
+  from extra.huggingface_onnx.huggingface_manager import DOWNLOADS_DIR, snapshot_download_with_retry
+  HUGGINGFACE_AVAILABLE = True
+except ModuleNotFoundError:
+  HUGGINGFACE_AVAILABLE = False
 
 def run_onnx_torch(onnx_model, inputs):
   import torch
@@ -137,5 +145,36 @@ class TestOnnxModel(unittest.TestCase):
     print(cls, _LABELS[cls])
     assert "car" in _LABELS[cls] or _LABELS[cls] == "convertible"
 
+@unittest.skipUnless(HUGGINGFACE_AVAILABLE and Device.DEFAULT == "METAL", "only run on METAL")
+class TestHuggingFaceOnnxModels(unittest.TestCase):
+  @classmethod
+  def setUpClass(cls):
+    cls._ctx = Context(MAX_BUFFER_SIZE=0)
+    cls._ctx.__enter__()
+
+  @classmethod
+  def tearDownClass(cls):
+    cls._ctx.__exit__()
+
+  def _validate(self, repo_id, model_file, custom_inputs, rtol=1e-4, atol=1e-4):
+    onnx_model_path = snapshot_download_with_retry(
+      repo_id=repo_id,
+      allow_patterns=["*.onnx", "*.onnx_data"],
+      cache_dir=str(DOWNLOADS_DIR)
+    )
+    onnx_model_path = onnx_model_path / model_file
+    file_size = onnx_model_path.stat().st_size
+    print(f"Validating model: {repo_id}/{model_file} ({file_size/1e6:.2f}M)")
+    validate(onnx_model_path, custom_inputs, rtol=rtol, atol=atol)
+
+  def test_xlm_roberta_large(self):
+    repo_id = "FacebookAI/xlm-roberta-large"
+    model_file = "onnx/model.onnx"
+    custom_inputs = {
+      "input_ids": np.random.randint(0, 250002, (1, 11), dtype=np.int64),
+      "attention_mask": np.ones((1, 11), dtype=np.int64),
+    }
+    self._validate(repo_id, model_file, custom_inputs)
+
 if __name__ == "__main__":
   unittest.main()
tinygrad/device.py

@@ -3,8 +3,8 @@ from dataclasses import dataclass, replace
 from collections import defaultdict
 from typing import Any, Generic, TypeVar, Iterator
 import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal, time
-from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, \
-  colored, Context, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE, cpu_events, ProfileEvent, ProfilePointEvent, dedup
+from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored, \
+  Context, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup
 from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
 from tinygrad.renderer import Renderer

@@ -124,7 +124,7 @@ class Buffer:
   def allocate(self, opaque=None, external_ptr=None) -> Buffer:
     assert not self.is_initialized(), "can't allocate already allocated buffer"
     if DEBUG >= 7: print(f"buffer: allocate {self.nbytes} bytes on {self.device}")
-    if (mbs:=getenv("MAX_BUFFER_SIZE", 0)) > 0 and self.size > mbs: raise RuntimeError(f"buffer of size {self.size/1e6:.2f}M is too large")
+    if MAX_BUFFER_SIZE > 0 and self.size > MAX_BUFFER_SIZE: raise RuntimeError(f"buffer of size {self.size/1e6:.2f}M is too large")
     self.allocator:Allocator = Device[self.device].allocator
     if external_ptr is not None:
       self.options = replace(self.options, external_ptr=external_ptr) if self.options else BufferSpec(external_ptr=external_ptr)
tinygrad/helpers.py

@@ -139,7 +139,7 @@ DISABLE_COMPILER_CACHE = ContextVar("DISABLE_COMPILER_CACHE", 0)
 DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES = ContextVar("DONT_REALIZE_EXPAND", 0), ContextVar("DONT_GROUP_REDUCES", 0)
 QUANTIZE, VALIDATE_WITH_CPU = ContextVar("QUANTIZE", 0), ContextVar("VALIDATE_WITH_CPU", 0)
 CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), ContextVar("FUSE_OPTIM", 0)
-ALLOW_DEVICE_USAGE, AMD_LLVM = ContextVar("ALLOW_DEVICE_USAGE", 1), ContextVar("AMD_LLVM", 1)
+ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, AMD_LLVM = ContextVar("ALLOW_DEVICE_USAGE", 1), ContextVar("MAX_BUFFER_SIZE", 0), ContextVar("AMD_LLVM", 1)
 
 @dataclass(frozen=True)
 class Metadata:
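These two hunks work together: `MAX_BUFFER_SIZE` moves from a raw `getenv` lookup to a `ContextVar`, so tests can scope it with `Context` instead of mutating the environment; the new HuggingFace test enters `Context(MAX_BUFFER_SIZE=0)` in `setUpClass` to lift the cap. A minimal sketch (assuming tinygrad's `ContextVar` exposes the current setting via `.value`):

```python
from tinygrad.helpers import Context, MAX_BUFFER_SIZE

with Context(MAX_BUFFER_SIZE=0):
  # inside this context Buffer.allocate applies no size cap (0 disables the check)
  assert MAX_BUFFER_SIZE.value == 0
```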