mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-01 03:01:13 -04:00
- Enabled cache stat collection. - Implemented ONNX loading. - Added the ability to specify the repo version variant in the installer CLI. - If the caller asks for a repo version that doesn't exist, fall back to the empty version rather than raising an error.
42 lines
1.5 KiB
Python
42 lines
1.5 KiB
Python
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
|
"""Class for Onnx model loading in InvokeAI."""
|
|
|
|
# This should work the same as Stable Diffusion pipelines
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
from invokeai.backend.model_manager import (
|
|
AnyModel,
|
|
BaseModelType,
|
|
ModelFormat,
|
|
ModelRepoVariant,
|
|
ModelType,
|
|
SubModelType,
|
|
)
|
|
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
|
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
|
|
|
|
|
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive)
class OnnyxDiffusersModel(ModelLoader):
    """Class to load onnx models.

    Registered with ``AnyModelLoader`` for ``ModelType.ONNX`` models of any base
    type, in both the ``ModelFormat.Onnx`` and ``ModelFormat.Olive`` formats.
    Loading is done one submodel at a time; a submodel type is required.
    """

    def _load_model(
        self,
        model_path: Path,
        model_variant: Optional[ModelRepoVariant] = None,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        """Load a single submodel of an ONNX pipeline.

        Args:
            model_path: Root directory of the pipeline. It is passed as-is to
                ``_get_hf_load_class``; the submodel itself is loaded from the
                subdirectory named ``submodel_type.value``.
            model_variant: Optional repo variant. When truthy, its ``.value`` is
                forwarded to ``from_pretrained`` as ``variant``; otherwise
                ``variant=None`` is passed.
            submodel_type: Which submodel to load. Required even though it
                defaults to ``None`` (the default presumably exists to match the
                ``ModelLoader._load_model`` signature — TODO confirm).

        Returns:
            Whatever the resolved load class's ``from_pretrained`` returns.

        Raises:
            ValueError: If ``submodel_type`` is ``None``. This is a subclass of
                ``Exception``, so existing ``except Exception`` handlers still
                catch it.
        """
        if submodel_type is None:
            raise ValueError("A submodel type must be provided when loading onnx pipelines.")

        # Resolve the loader class from the pipeline root before narrowing the
        # path down to the submodel's own directory.
        load_class = self._get_hf_load_class(model_path, submodel_type)
        variant = model_variant.value if model_variant else None
        model_path = model_path / submodel_type.value

        result: AnyModel = load_class.from_pretrained(
            model_path,
            torch_dtype=self._torch_dtype,
            variant=variant,
        )  # type: ignore  # NOTE(review): presumably from_pretrained's declared return type is not AnyModel; confirm
        return result
|