Compare commits

...

6 Commits

31 changed files with 289 additions and 63 deletions

View File

@@ -5,7 +5,8 @@ from abc import ABC, abstractmethod
 from typing import Optional
 from invokeai.app.services.shared.invocation_context import InvocationContextData
-from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
+from invokeai.backend.model_manager import AnyModelConfig, SubModelType
+from invokeai.backend.model_manager.any_model_type import AnyModel
 from invokeai.backend.model_manager.load import LoadedModel
 from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
 from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase

View File

@@ -6,7 +6,8 @@ from typing import Optional, Type
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.app.services.invoker import Invoker
 from invokeai.app.services.shared.invocation_context import InvocationContextData
-from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
+from invokeai.backend.model_manager import AnyModelConfig, SubModelType
+from invokeai.backend.model_manager.any_model_type import AnyModel
 from invokeai.backend.model_manager.load import (
     LoadedModel,
     ModelLoaderRegistry,

View File

@@ -1,6 +1,6 @@
 """Initialization file for model manager service."""
-from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
+from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelType, SubModelType
 from invokeai.backend.model_manager.load import LoadedModel
 from .model_manager_default import ModelManagerService, ModelManagerServiceBase
@@ -8,7 +8,6 @@ from .model_manager_default import ModelManagerService, ModelManagerServiceBase
 __all__ = [
     "ModelManagerServiceBase",
     "ModelManagerService",
-    "AnyModel",
     "AnyModelConfig",
     "BaseModelType",
     "ModelType",

View File

@@ -12,7 +12,6 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
-from ..raw_model import RawModel
 from .resampler import Resampler
@@ -102,7 +101,7 @@ class MLPProjModel(torch.nn.Module):
         return clip_extra_context_tokens
-class IPAdapter(RawModel):
+class IPAdapter(torch.nn.Module):
     """IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""
     def __init__(
@@ -112,6 +111,7 @@ class IPAdapter(RawModel):
         dtype: torch.dtype = torch.float16,
         num_tokens: int = 4,
     ):
+        super().__init__()
        self.device = device
        self.dtype = dtype
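
Why the new super().__init__() calls: every class in this change that swaps its base from RawModel to torch.nn.Module must call super().__init__() before assigning attributes, because nn.Module.__setattr__ depends on internal registries (_parameters, _buffers, _modules) that Module.__init__ creates. A minimal, self-contained sketch (class names are illustrative, not from this diff):

import torch

class GoodModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()  # Sets up _parameters/_buffers/_modules before any assignment.
        self.proj = torch.nn.Linear(4, 4)  # Registered as a submodule by nn.Module.__setattr__.

class BadModule(torch.nn.Module):
    def __init__(self) -> None:
        # No super().__init__() here, so the next line raises:
        # AttributeError: cannot assign module before Module.__init__() call
        self.proj = torch.nn.Linear(4, 4)

GoodModule()
try:
    BadModule()
except AttributeError as err:
    print(err)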

View File

@@ -11,8 +11,6 @@ from typing_extensions import Self
 from invokeai.backend.model_manager import BaseModelType
-from .raw_model import RawModel
 class LoRALayerBase:
     # rank: Optional[int]
@@ -368,15 +366,13 @@ class IA3Layer(LoRALayerBase):
 AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
-class LoRAModelRaw(RawModel): # (torch.nn.Module):
-    _name: str
-    layers: Dict[str, AnyLoRALayer]
+class LoRAModelRaw(torch.nn.Module):
     def __init__(
         self,
         name: str,
         layers: Dict[str, AnyLoRALayer],
     ):
+        super().__init__()
         self._name = name
         self.layers = layers

View File

@@ -1,7 +1,6 @@
 """Re-export frequently-used symbols from the Model Manager backend."""
 from .config import (
-    AnyModel,
     AnyModelConfig,
     BaseModelType,
     InvalidModelConfigException,
@@ -18,7 +17,6 @@ from .probe import ModelProbe
 from .search import ModelSearch
 __all__ = [
-    "AnyModel",
     "AnyModelConfig",
     "BaseModelType",
    "ModelRepoVariant",

View File

@@ -0,0 +1,12 @@
+from typing import Union
+import torch
+from diffusers.models.modeling_utils import ModelMixin
+from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
+from invokeai.backend.lora import LoRAModelRaw
+from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
+from invokeai.backend.textual_inversion import TextualInversionModelRaw
+# ModelMixin is the base class for all diffusers and transformers models
+AnyModel = Union[ModelMixin, torch.nn.Module, IPAdapter, LoRAModelRaw, TextualInversionModelRaw, IAIOnnxRuntimeModel]
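
The union now lives in its own module so that config.py no longer needs to import the concrete model classes (and vice versa), which is the circular-import problem the old RawModel base class existed to dodge. A sketch of how a caller might consume AnyModel (describe_model is a hypothetical helper, not part of this change):

import torch
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager.any_model_type import AnyModel

def describe_model(model: AnyModel) -> str:
    # Check the more specific InvokeAI wrappers before the generic nn.Module case,
    # since after this change they are all torch.nn.Module subclasses.
    if isinstance(model, LoRAModelRaw):
        return f"LoRA with {len(model.layers)} layers"
    if isinstance(model, IPAdapter):
        return "IP-Adapter"
    if isinstance(model, torch.nn.Module):
        return f"nn.Module with {sum(p.numel() for p in model.parameters())} parameters"
    return type(model).__name__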

View File

@@ -24,20 +24,12 @@ import time
 from enum import Enum
 from typing import Literal, Optional, Type, TypeAlias, Union
 import torch
-from diffusers.models.modeling_utils import ModelMixin
 from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
 from typing_extensions import Annotated, Any, Dict
 from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
 from invokeai.app.util.misc import uuid_string
-from ..raw_model import RawModel
-# ModelMixin is the base class for all diffusers and transformers models
-# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
-AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]
 class InvalidModelConfigException(Exception):
     """Exception for when config parser doesn't recognized this combination of model type and format."""

View File

@@ -15,7 +15,7 @@ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
 )
 from omegaconf import DictConfig
-from . import AnyModel
+from invokeai.backend.model_manager.any_model_type import AnyModel
 def convert_ldm_vae_to_diffusers(

View File

@@ -10,8 +10,8 @@ from pathlib import Path
 from typing import Any, Optional
 from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.backend.model_manager.any_model_type import AnyModel
 from invokeai.backend.model_manager.config import (
-    AnyModel,
     AnyModelConfig,
     SubModelType,
 )

View File

@@ -7,11 +7,11 @@ from typing import Optional
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.model_manager import (
-    AnyModel,
     AnyModelConfig,
     InvalidModelConfigException,
     SubModelType,
 )
+from invokeai.backend.model_manager.any_model_type import AnyModel
 from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
 from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
 from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase

View File

@@ -14,7 +14,8 @@ from typing import Dict, Generic, Optional, TypeVar
 import torch
-from invokeai.backend.model_manager.config import AnyModel, SubModelType
+from invokeai.backend.model_manager.any_model_type import AnyModel
+from invokeai.backend.model_manager.config import SubModelType
 class ModelLockerBase(ABC):

View File

@@ -28,7 +28,8 @@ from typing import Dict, List, Optional
 import torch
-from invokeai.backend.model_manager import AnyModel, SubModelType
+from invokeai.backend.model_manager import SubModelType
+from invokeai.backend.model_manager.any_model_type import AnyModel
 from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
 from invokeai.backend.util.devices import choose_torch_device
 from invokeai.backend.util.logging import InvokeAILogger

View File

@@ -4,7 +4,7 @@ Base class and implementation of a class that moves models in and out of VRAM.
 import torch
-from invokeai.backend.model_manager import AnyModel
+from invokeai.backend.model_manager.any_model_type import AnyModel
 from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase

View File

@@ -5,12 +5,12 @@ from pathlib import Path
 from typing import Optional
 from invokeai.backend.model_manager import (
-    AnyModel,
     AnyModelConfig,
     BaseModelType,
     ModelFormat,
     ModelType,
 )
+from invokeai.backend.model_manager.any_model_type import AnyModel
 from invokeai.backend.model_manager.config import CheckpointConfigBase
 from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers

View File

@@ -9,7 +9,6 @@ from diffusers.configuration_utils import ConfigMixin
 from diffusers.models.modeling_utils import ModelMixin
 from invokeai.backend.model_manager import (
-    AnyModel,
     AnyModelConfig,
     BaseModelType,
     InvalidModelConfigException,
@@ -17,6 +16,7 @@ from invokeai.backend.model_manager import (
     ModelType,
     SubModelType,
 )
+from invokeai.backend.model_manager.any_model_type import AnyModel
 from invokeai.backend.model_manager.config import DiffusersConfigBase
 from .. import ModelLoader, ModelLoaderRegistry

View File

@@ -7,9 +7,9 @@ from typing import Optional
 import torch
 from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
-from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
+from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
+from invokeai.backend.model_manager.any_model_type import AnyModel
 from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry
-from invokeai.backend.raw_model import RawModel
 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
@@ -25,7 +25,7 @@ class IPAdapterInvokeAILoader(ModelLoader):
         if submodel_type is not None:
             raise ValueError("There are no submodels in an IP-Adapter model.")
         model_path = Path(config.path)
-        model: RawModel = build_ip_adapter(
+        model = build_ip_adapter(
             ip_adapter_ckpt_path=model_path,
             device=torch.device("cpu"),
             dtype=self._torch_dtype,

View File

@@ -8,13 +8,13 @@ from typing import Optional
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager import (
-    AnyModel,
     AnyModelConfig,
     BaseModelType,
     ModelFormat,
     ModelType,
     SubModelType,
 )
+from invokeai.backend.model_manager.any_model_type import AnyModel
 from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
 from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase

View File

@@ -6,13 +6,13 @@ from pathlib import Path
 from typing import Optional
 from invokeai.backend.model_manager import (
-    AnyModel,
     AnyModelConfig,
     BaseModelType,
     ModelFormat,
     ModelType,
     SubModelType,
 )
+from invokeai.backend.model_manager.any_model_type import AnyModel
 from .. import ModelLoaderRegistry
 from .generic_diffusers import GenericDiffusersLoader

View File

@@ -5,7 +5,6 @@ from pathlib import Path
 from typing import Optional
 from invokeai.backend.model_manager import (
-    AnyModel,
     AnyModelConfig,
     BaseModelType,
     ModelFormat,
@@ -13,6 +12,7 @@ from invokeai.backend.model_manager import (
     SchedulerPredictionType,
     SubModelType,
 )
+from invokeai.backend.model_manager.any_model_type import AnyModel
 from invokeai.backend.model_manager.config import (
     CheckpointConfigBase,
     DiffusersConfigBase,

View File

@@ -5,13 +5,13 @@ from pathlib import Path
 from typing import Optional
 from invokeai.backend.model_manager import (
-    AnyModel,
     AnyModelConfig,
     BaseModelType,
     ModelFormat,
     ModelType,
     SubModelType,
 )
+from invokeai.backend.model_manager.any_model_type import AnyModel
 from invokeai.backend.textual_inversion import TextualInversionModelRaw
 from .. import ModelLoader, ModelLoaderRegistry

View File

@@ -14,7 +14,8 @@ from invokeai.backend.model_manager import (
     ModelFormat,
     ModelType,
 )
-from invokeai.backend.model_manager.config import AnyModel, CheckpointConfigBase
+from invokeai.backend.model_manager.any_model_type import AnyModel
+from invokeai.backend.model_manager.config import CheckpointConfigBase
 from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
 from .. import ModelLoaderRegistry

View File

@@ -8,7 +8,7 @@ from typing import Optional
 import torch
 from diffusers import DiffusionPipeline
-from invokeai.backend.model_manager.config import AnyModel
+from invokeai.backend.model_manager.any_model_type import AnyModel
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel

View File

@@ -17,7 +17,7 @@ def skip_torch_weight_init() -> Generator[None, None, None]:
     completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
     monkey-patches common torch layers to skip the weight initialization step.
     """
-    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
+    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding, torch.nn.LayerNorm]
     saved_functions = [hasattr(m, "reset_parameters") and m.reset_parameters for m in torch_modules]
     try:
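
For context, the pattern the docstring describes looks roughly like this (a standalone sketch, not the verbatim implementation): each layer class's reset_parameters is swapped for a no-op inside the context and restored in a finally block, so layers constructed inside the context skip weight initialization.

import contextlib
from typing import Generator
import torch

@contextlib.contextmanager
def skip_weight_init_sketch() -> Generator[None, None, None]:
    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding, torch.nn.LayerNorm]
    saved = [m.reset_parameters for m in torch_modules]  # Keep the original (unbound) functions.
    try:
        for m in torch_modules:
            m.reset_parameters = lambda self: None  # No-op: leave weights uninitialized.
        yield
    finally:
        for m, fn in zip(torch_modules, saved):
            m.reset_parameters = fn  # Restore the original initializers.

with skip_weight_init_sketch():
    layer = torch.nn.Linear(1024, 1024)  # Weights left uninitialized; caller loads them from disk.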

View File

@@ -13,7 +13,7 @@ from diffusers import OnnxRuntimeModel, UNet2DConditionModel
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
 from invokeai.app.shared.models import FreeUConfig
-from invokeai.backend.model_manager import AnyModel
+from invokeai.backend.model_manager.any_model_type import AnyModel
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel

View File

@@ -6,17 +6,16 @@ from typing import Any, List, Optional, Tuple, Union
 import numpy as np
 import onnx
 import torch
 from onnx import numpy_helper
 from onnxruntime import InferenceSession, SessionOptions, get_available_providers
-from ..raw_model import RawModel
 ONNX_WEIGHTS_NAME = "model.onnx"
 # NOTE FROM LS: This was copied from Stalker's original implementation.
 # I have not yet gone through and fixed all the type hints
-class IAIOnnxRuntimeModel(RawModel):
+class IAIOnnxRuntimeModel(torch.nn.Module):
     class _tensor_access:
         def __init__(self, model): # type: ignore
             self.model = model
@@ -103,7 +102,7 @@ class IAIOnnxRuntimeModel(RawModel):
         self.proto = onnx.load(model_path, load_external_data=False)
         """
+        super().__init__()
         self.proto = onnx.load(model_path, load_external_data=True)
         # self.data = dict()
         # for tensor in self.proto.graph.initializer:

View File

@@ -0,0 +1,50 @@
+from pathlib import Path
+from typing import Optional, Union
+import torch
+from invokeai.backend.model_manager.config import BaseModelType
+from invokeai.backend.peft.sdxl_format_utils import convert_sdxl_keys_to_diffusers_format
+from invokeai.backend.util.serialization import load_state_dict
+class PeftModel:
+    """A class for loading and managing parameter-efficient fine-tuning models."""
+    def __init__(
+        self,
+        name: str,
+        state_dict: dict[str, torch.Tensor],
+        network_alphas: dict[str, torch.Tensor],
+    ):
+        self.name = name
+        self.state_dict = state_dict
+        self.network_alphas = network_alphas
+    def calc_size(self) -> int:
+        model_size = 0
+        for tensor in self.state_dict.values():
+            model_size += tensor.nelement() * tensor.element_size()
+        return model_size
+    @classmethod
+    def from_checkpoint(
+        cls,
+        file_path: Union[str, Path],
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        base_model: Optional[BaseModelType] = None,
+    ):
+        device = device or torch.device("cpu")
+        dtype = dtype or torch.float32
+        file_path = Path(file_path)
+        state_dict = load_state_dict(file_path, device=str(device))
+        if base_model == BaseModelType.StableDiffusionXL:
+            state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
+        # TODO(ryand): We shouldn't be using an unexported function from diffusers here. Consider opening an upstream PR
+        # to move this function to state_dict_utils.py.
+        # state_dict, network_alphas = _convert_kohya_lora_to_diffusers(state_dict)
+        network_alphas: dict[str, torch.Tensor] = {}  # Placeholder until the kohya conversion above is re-enabled;
+        # without it, the name below would be unbound and raise a NameError.
+        return cls(name=file_path.stem, state_dict=state_dict, network_alphas=network_alphas)
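
A hypothetical usage sketch for the new class (the module path invokeai.backend.peft.peft_model is assumed from the package's other imports; the checkpoint path is illustrative):

import torch
from invokeai.backend.model_manager.config import BaseModelType
from invokeai.backend.peft.peft_model import PeftModel  # Assumed module path.

lora = PeftModel.from_checkpoint(
    "/models/loras/pixel-art.safetensors",  # Illustrative path.
    device=torch.device("cpu"),
    base_model=BaseModelType.StableDiffusionXL,  # Triggers the SDXL key conversion.
)
print(lora.name, f"{lora.calc_size() / 1e6:.1f} MB")  # calc_size() returns bytes.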

View File

@@ -0,0 +1,155 @@
+import bisect
+import torch
+def make_sdxl_unet_conversion_map() -> list[tuple[str, str]]:
+    """Create a list of (Stability AI, diffusers) key-prefix pairs that map SDXL UNet state_dict keys from
+    Stability AI format to diffusers format.
+    Ported from:
+    https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
+    """
+    unet_conversion_map_layer: list[tuple[str, str]] = []
+    for i in range(3):  # num_blocks is 3 in sdxl
+        # loop over downblocks/upblocks
+        for j in range(2):
+            # loop over resnets/attentions for downblocks
+            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+            if i < 3:
+                # no attention layers in down_blocks.3
+                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+        for j in range(3):
+            # loop over resnets/attentions for upblocks
+            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+            # The "if i > 0:" guard ("no attention layers in up_blocks.0") is commented out for sdxl.
+            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
+            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+        if i < 3:
+            # no downsample in down_blocks.3
+            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+            # no upsample in up_blocks.3
+            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}."  # change for sdxl
+            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+    hf_mid_atn_prefix = "mid_block.attentions.0."
+    sd_mid_atn_prefix = "middle_block.1."
+    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+    for j in range(2):
+        hf_mid_res_prefix = f"mid_block.resnets.{j}."
+        sd_mid_res_prefix = f"middle_block.{2*j}."
+        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+    unet_conversion_map_resnet = [
+        # (stable-diffusion, HF Diffusers)
+        ("in_layers.0.", "norm1."),
+        ("in_layers.2.", "conv1."),
+        ("out_layers.0.", "norm2."),
+        ("out_layers.3.", "conv2."),
+        ("emb_layers.1.", "time_emb_proj."),
+        ("skip_connection.", "conv_shortcut."),
+    ]
+    unet_conversion_map: list[tuple[str, str]] = []
+    for sd, hf in unet_conversion_map_layer:
+        if "resnets" in hf:
+            for sd_res, hf_res in unet_conversion_map_resnet:
+                unet_conversion_map.append((sd + sd_res, hf + hf_res))
+        else:
+            unet_conversion_map.append((sd, hf))
+    for j in range(2):
+        hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
+        sd_time_embed_prefix = f"time_embed.{j*2}."
+        unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
+    for j in range(2):
+        hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
+        sd_label_embed_prefix = f"label_emb.0.{j*2}."
+        unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
+    unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
+    unet_conversion_map.append(("out.0.", "conv_norm_out."))
+    unet_conversion_map.append(("out.2.", "conv_out."))
+    return unet_conversion_map
+SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
+    sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
+}
+def convert_sdxl_keys_to_diffusers_format(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+    """Convert the keys of an SDXL LoRA state_dict to diffusers format.
+    The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
+    diffusers format, then this function will have no effect.
+    This function is adapted from:
+    https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
+    Args:
+        state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
+    Raises:
+        ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
+    Returns:
+        Dict[str, Tensor]: The diffusers-format state_dict.
+    """
+    converted_count = 0  # The number of Stability AI keys converted to diffusers format.
+    not_converted_count = 0  # The number of keys that were not converted.
+    # Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
+    # For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
+    # `input_blocks_4_1_proj_in`.
+    stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
+    stability_unet_keys.sort()
+    new_state_dict: dict[str, torch.Tensor] = {}
+    for full_key, value in state_dict.items():
+        if full_key.startswith("lora_unet_"):
+            search_key = full_key.replace("lora_unet_", "")
+            # Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
+            position = bisect.bisect_right(stability_unet_keys, search_key)
+            map_key = stability_unet_keys[position - 1]
+            # Now, check if the map_key *actually* matches the search_key.
+            if search_key.startswith(map_key):
+                new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
+                new_state_dict[new_key] = value
+                converted_count += 1
+            else:
+                new_state_dict[full_key] = value
+                not_converted_count += 1
+        elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
+            # The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
+            new_state_dict[full_key] = value
+            continue
+        else:
+            raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
+    if converted_count > 0 and not_converted_count > 0:
+        raise ValueError(
+            f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
+            f" not_converted={not_converted_count}"
+        )
+    return new_state_dict
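
A worked example of the conversion (illustrative input; the expected output key follows from the map above, where input_blocks.4.1 corresponds to i=1, j=0, i.e. down_blocks.1.attentions.0):

import torch
from invokeai.backend.peft.sdxl_format_utils import convert_sdxl_keys_to_diffusers_format

# One Stability AI-format SDXL LoRA key (tensor shape is arbitrary for this example).
sd = {"lora_unet_input_blocks_4_1_proj_in.lora_down.weight": torch.zeros(4, 320)}
out = convert_sdxl_keys_to_diffusers_format(sd)
assert list(out) == ["lora_unet_down_blocks_1_attentions_0_proj_in.lora_down.weight"]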

View File

@@ -1,15 +0,0 @@
-"""Base class for 'Raw' models.
-The RawModel class is the base class of LoRAModelRaw and TextualInversionModelRaw,
-and is used for type checking of calls to the model patcher. Its main purpose
-is to avoid a circular import issues when lora.py tries to import BaseModelType
-from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
-from lora.py.
-The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
-that adds additional methods and attributes.
-"""
-class RawModel:
-    """Base class for 'Raw' model wrappers."""

View File

@@ -9,10 +9,8 @@ from safetensors.torch import load_file
 from transformers import CLIPTokenizer
 from typing_extensions import Self
-from .raw_model import RawModel
-class TextualInversionModelRaw(RawModel):
+class TextualInversionModelRaw(torch.nn.Module):
     embedding: torch.Tensor  # [n, 768]|[n, 1280]
     embedding_2: Optional[torch.Tensor] = None  # [n, 768]|[n, 1280] - for SDXL models

View File

@@ -0,0 +1,37 @@
+from pathlib import Path
+from typing import Any, Optional, Union
+import torch
+from safetensors.torch import load_file
+def state_dict_to(
+    state_dict: dict[str, torch.Tensor], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
+) -> dict[str, torch.Tensor]:
+    new_state_dict: dict[str, torch.Tensor] = {}
+    for k, v in state_dict.items():
+        new_state_dict[k] = v.to(device=device, dtype=dtype, non_blocking=True)
+    return new_state_dict
+def load_state_dict(file_path: Union[str, Path], device: str = "cpu") -> Any:
+    """Load a state_dict from a file that may be in either PyTorch or safetensors format. The file format is inferred
+    from the file extension.
+    """
+    file_path = Path(file_path)
+    if file_path.suffix == ".safetensors":
+        state_dict = load_file(
+            file_path,
+            device=device,
+        )
+    else:
+        # weights_only=True is used to address a security vulnerability that allows arbitrary code execution.
+        # This option was first introduced in https://github.com/pytorch/pytorch/pull/86812.
+        #
+        # mmap=True is used to both reduce memory usage and speed up loading. This setting causes torch.load() to more
+        # closely mirror the behaviour of safetensors.torch.load_file(). This option was first introduced in
+        # https://github.com/pytorch/pytorch/pull/102549. The discussion on that PR provides helpful context.
+        state_dict = torch.load(file_path, map_location=device, weights_only=True, mmap=True)
+    return state_dict
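
An illustrative usage sketch for these helpers (the checkpoint path is hypothetical):

import torch
from invokeai.backend.util.serialization import load_state_dict, state_dict_to

# Load on the CPU first (safetensors, or torch.load with mmap=True for lazy paging),
# then move the tensors to the GPU in half precision in one pass.
sd = load_state_dict("/models/checkpoints/model.safetensors", device="cpu")  # Hypothetical path.
if torch.cuda.is_available():
    sd = state_dict_to(sd, device=torch.device("cuda"), dtype=torch.float16)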