mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-04-23 03:00:31 -04:00)

Remove no longer used code paths, general cleanup of new dequantization code, update probe

committed by Kent Keirsey · parent 7d9f125232 · commit 446e2884bc

@@ -30,7 +30,7 @@ from invokeai.backend.model_manager.config import (
    SchedulerPredictionType,
)
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
from invokeai.backend.quantization.gguf.layers import GGUFTensor
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.util.silence_warnings import SilenceWarnings

@@ -482,7 +482,7 @@ class CheckpointProbeBase(ProbeBase):
            or "model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict
        ):
            return ModelFormat.BnbQuantizednf4b
        elif any(isinstance(v, GGUFTensor) for v in state_dict.values()):
        elif any(isinstance(v, GGMLTensor) for v in state_dict.values()):
            return ModelFormat.GGUFQuantized
        return ModelFormat("checkpoint")
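
For orientation, a minimal sketch (not part of this commit) of how the updated probe check behaves: gguf_sd_loader and GGMLTensor come from this diff, the ModelFormat import location follows the probe's own import block above, and the .gguf path plus the standalone check are illustrative assumptions.

# Hedged sketch: detecting a GGUF-quantized checkpoint after this change.
from pathlib import Path

from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader

state_dict = gguf_sd_loader(Path("flux1-dev-Q4_K_S.gguf"))  # hypothetical file
if any(isinstance(v, GGMLTensor) for v in state_dict.values()):
    detected = ModelFormat.GGUFQuantized
else:
    detected = ModelFormat("checkpoint")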

@@ -1,174 +0,0 @@
# Largely based on https://github.com/city96/ComfyUI-GGUF

from typing import Callable, List, Optional, Union

import gguf
import torch

from invokeai.backend.quantization.gguf.utils import dequantize_tensor, is_quantized

PATCH_TYPES = Union[list[torch.Tensor], tuple[torch.Tensor]]


class GGUFTensor(torch.Tensor):
    """
    Main tensor-like class for storing quantized weights.
    Inherits from torch.Tensor and adds additional attributes.
    """

    tensor_type: Union[torch.dtype, gguf.GGMLQuantizationType, None]
    tensor_shape: torch.Size
    patches: List[Callable[[torch.Tensor], torch.Tensor]]

    def __new__(
        cls,
        data,
        tensor_type: Union[torch.dtype, gguf.GGMLQuantizationType],
        tensor_shape: torch.Size,
        patches: Optional[List[Callable[[torch.Tensor], torch.Tensor]]] = None,
        **kwargs,
    ):
        # Create a new tensor instance using the superclass method
        if isinstance(data, torch.Tensor):
            tensor = data.as_subclass(cls)
        else:
            tensor = torch.tensor(data, **kwargs).as_subclass(cls)
        # Set the additional attributes
        tensor.tensor_type = tensor_type
        tensor.tensor_shape = tensor_shape
        tensor.patches = patches or []
        return tensor

    def __init__(
        self,
        data,
        tensor_type: Union[torch.dtype, gguf.GGMLQuantizationType],
        tensor_shape: torch.Size,
        patches: Optional[List[Callable[[torch.Tensor], torch.Tensor]]] = None,
        **kwargs,
    ):
        # __init__ is not called for torch.Tensor subclasses
        pass

    def to(self, *args, **kwargs):
        # Create a new tensor with the desired type/device and copy attributes
        new = super().to(*args, **kwargs)
        new = new.as_subclass(GGUFTensor)
        new.tensor_type = getattr(self, "tensor_type", self.dtype)
        new.tensor_shape = getattr(self, "tensor_shape", self.size())
        new.patches = getattr(self, "patches", []).copy()
        return new

    def clone(self, *args, **kwargs):
        return self

    def detach(self, *args, **kwargs):
        return self

    def copy_(self, *args, **kwargs):
        # Attempt to copy data into the tensor; handle exceptions gracefully
        try:
            new = super().copy_(*args, **kwargs)
            new = new.as_subclass(GGUFTensor)
            new.tensor_type = getattr(self, "tensor_type", self.dtype)
            new.tensor_shape = getattr(self, "tensor_shape", self.size())
            new.patches = getattr(self, "patches", []).copy()
            return new
        except Exception as e:
            print(f"Ignoring 'copy_' on tensor: {e}")

    def __deepcopy__(self, memo):
        # Create a deep copy of the tensor and copy attributes
        new = super().__deepcopy__(memo)
        if isinstance(new, torch.Tensor):
            new = new.as_subclass(GGUFTensor)
        new.tensor_type = getattr(self, "tensor_type", self.dtype)
        new.tensor_shape = getattr(self, "tensor_shape", self.size())
        new.patches = getattr(self, "patches", []).copy()
        return new

    @property
    def shape(self):
        if not hasattr(self, "tensor_shape"):
            self.tensor_shape = self.size()
        return self.tensor_shape
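
As an illustration of what this class stores (a minimal sketch, not part of the commit; the buffer contents and sizes are made up): the raw quantized blocks live in the tensor's storage, while tensor_shape records the logical, dequantized shape that the overridden shape property reports.

# Hedged sketch with made-up sizes: a 32x32 logical weight held as packed bytes.
import gguf
import torch

raw_blocks = torch.zeros(1088, dtype=torch.uint8)  # pretend Q8_0 block data
logical_shape = torch.Size((32, 32))

qweight = GGUFTensor(raw_blocks, tensor_type=gguf.GGMLQuantizationType.Q8_0, tensor_shape=logical_shape)
print(qweight.shape)   # torch.Size([32, 32]) -- the overridden `shape` property
print(qweight.size())  # torch.Size([1088])   -- the actual quantized storage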

class GGUFLayer(torch.nn.Module):
    """
    This (should) be responsible for de-quantizing on the fly
    """

    dequant_dtype = None
    patch_dtype = None
    torch_compatible_tensor_types = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}

    def is_ggml_quantized(self, *, weight: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None):
        weight = weight if weight is not None else self.weight
        bias = bias if bias is not None else self.bias
        weight_quantized = is_quantized(weight)
        bias_quantized = is_quantized(bias)
        return weight_quantized or bias_quantized

    def _load_from_state_dict(self, state_dict: dict[str, torch.Tensor], prefix: str, *args, **kwargs):
        weight, bias = state_dict.get(f"{prefix}weight", None), state_dict.get(f"{prefix}bias", None)
        if self.is_ggml_quantized(weight=weight, bias=bias):
            return self.ggml_load_from_state_dict(state_dict, prefix, *args, **kwargs)
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def ggml_load_from_state_dict(
        self,
        state_dict: dict[str, torch.Tensor],
        prefix: str,
        local_metadata,
        strict,
        missing_keys: list[str],
        unexpected_keys,
        error_msgs,
    ):
        for k, v in state_dict.items():
            if k.endswith("weight"):
                self.weight = torch.nn.Parameter(v, requires_grad=False)
            elif k.endswith("bias") and v is not None:
                self.bias = torch.nn.Parameter(v, requires_grad=False)
            else:
                missing_keys.append(k)

    def get_weight(self, tensor: Optional[torch.Tensor], dtype: torch.dtype):
        if tensor is None:
            return

        # dequantize tensor while patches load
        weight = dequantize_tensor(tensor, dtype, self.dequant_dtype)
        return weight

    def calc_size(self) -> int:
        """Get the size of this model in bytes."""
        return self.bias.nelement() * self.bias.element_size()

    def cast_bias_weight(
        self,
        input: torch.Tensor,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        bias_dtype: Optional[torch.dtype] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if dtype is None:
            dtype = getattr(input, "dtype", torch.float32)
        if dtype is None:
            raise ValueError("dtype is required")
        if bias_dtype is None:
            bias_dtype = dtype
        if device is None:
            device = input.device

        bias = self.get_weight(self.bias.to(device), dtype)
        if bias is not None:
            bias = bias.to(dtype=bias_dtype, device=device, copy=False)

        weight = self.get_weight(self.weight.to(device), dtype)
        if weight is not None:
            weight = weight.to(dtype=dtype, device=device)
        if weight is None or bias is None:
            raise ValueError("Weight or bias is None")
        return weight, bias
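
To make the "de-quantizing on the fly" docstring concrete, here is a rough sketch (not the repo's implementation, which lives in invokeai.backend.quantization.gguf.utils) of what the dequantize_tensor helper used by get_weight is expected to do: pass torch-native dtypes through and decode packed GGML blocks otherwise. The name dequantize_tensor_sketch and the use of gguf.quants.dequantize are assumptions for illustration.

# Hedged sketch only; assumes the GGUFTensor class above is in scope.
from typing import Optional

import gguf
import torch


def dequantize_tensor_sketch(
    tensor: GGUFTensor, dtype: torch.dtype, dequant_dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    if tensor.tensor_type in {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
        # Already a plain float tensor; just cast.
        return tensor.to(dtype)
    # Decode the packed GGML blocks, then cast to the requested dtype.
    decoded = gguf.quants.dequantize(tensor.cpu().numpy(), tensor.tensor_type)
    out = torch.from_numpy(decoded).reshape(tuple(tensor.tensor_shape))
    return out.to(dequant_dtype if dequant_dtype is not None else dtype)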

@@ -1,19 +1,16 @@
# Largely based on https://github.com/city96/ComfyUI-GGUF

from pathlib import Path

import gguf
import torch

from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.layers import GGUFTensor
from invokeai.backend.quantization.gguf.utils import TORCH_COMPATIBLE_QTYPES


def gguf_sd_loader(path: Path) -> dict[str, GGUFTensor]:
def gguf_sd_loader(path: Path) -> dict[str, GGMLTensor]:
    reader = gguf.GGUFReader(path)

    sd: dict[str, GGUFTensor] = {}
    sd: dict[str, GGMLTensor] = {}
    for tensor in reader.tensors:
        torch_tensor = torch.from_numpy(tensor.data)
        shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
@@ -21,62 +18,3 @@ def gguf_sd_loader(path: Path) -> dict[str, GGUFTensor]:
        torch_tensor = torch_tensor.view(*shape)
        sd[tensor.name] = GGMLTensor(torch_tensor, ggml_quantization_type=tensor.tensor_type, tensor_shape=shape)
    return sd
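
A hedged usage sketch of the loader above; the checkpoint path is hypothetical.

# Hedged usage sketch; the .gguf path is hypothetical.
from pathlib import Path

sd = gguf_sd_loader(Path("models/flux1-dev-Q8_0.gguf"))
print(len(sd), "tensors loaded")
for name in list(sd)[:3]:
    print(name, type(sd[name]).__name__)  # each value is a GGMLTensor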

# def gguf_sd_loader(
#     path: Path, handle_prefix: str = "model.diffusion_model.", data_type: torch.dtype = torch.bfloat16
# ) -> dict[str, GGUFTensor]:
#     """
#     Read state dict as fake tensors
#     """
#     reader = gguf.GGUFReader(path)

#     prefix_len = len(handle_prefix)
#     tensor_names = {tensor.name for tensor in reader.tensors}
#     has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)

#     tensors: list[tuple[str, gguf.ReaderTensor]] = []
#     for tensor in reader.tensors:
#         sd_key = tensor_name = tensor.name
#         if has_prefix:
#             if not tensor_name.startswith(handle_prefix):
#                 continue
#             sd_key = tensor_name[prefix_len:]
#         tensors.append((sd_key, tensor))

#     # detect and verify architecture
#     compat = None
#     arch_str = None
#     arch_field = reader.get_field("general.architecture")
#     if arch_field is not None:
#         if len(arch_field.types) != 1 or arch_field.types[0] != gguf.GGUFValueType.STRING:
#             raise TypeError(f"Bad type for GGUF general.architecture key: expected string, got {arch_field.types!r}")
#         arch_str = str(arch_field.parts[arch_field.data[-1]], encoding="utf-8")
#         if arch_str not in {"flux"}:
#             raise ValueError(f"Unexpected architecture type in GGUF file, expected flux, but got {arch_str!r}")
#     else:
#         arch_str = detect_arch({val[0] for val in tensors})
#         compat = "sd.cpp"

#     # main loading loop
#     state_dict: dict[str, GGUFTensor] = {}
#     qtype_dict: dict[str, int] = {}
#     for sd_key, tensor in tensors:
#         tensor_name = tensor.name
#         tensor_type_str = str(tensor.tensor_type)
#         torch_tensor = torch.from_numpy(tensor.data)  # mmap

#         shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
#         # Workaround for stable-diffusion.cpp SDXL detection.
#         if compat == "sd.cpp" and arch_str == "sdxl":
#             if tensor_name.endswith((".proj_in.weight", ".proj_out.weight")):
#                 while len(shape) > 2 and shape[-1] == 1:
#                     shape = shape[:-1]

#         # add to state dict
#         if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
#             torch_tensor = torch_tensor.view(*shape)
#         state_dict[sd_key] = GGUFTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape)
#         qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1

#     return state_dict

@@ -1,96 +0,0 @@
# Largely based on https://github.com/city96/ComfyUI-GGUF

from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, Optional, Type

import wrapt
from torch import Tensor, bfloat16, dtype, float16, nn

from invokeai.backend.quantization.gguf.layers import GGUFLayer


class TorchPatcher:
    @classmethod
    @contextmanager
    def wrap(cls) -> Generator[None, None, None]:
        # Dictionary to store original torch.nn classes for later restoration
        original_classes: Dict[str, Type[Any]] = {}
        try:
            # Iterate over cls's attributes and replace matching torch.nn classes
            for attr_name in dir(cls):
                if attr_name.startswith("__"):
                    continue
                # Get the class from cls
                patcher_class: Type[Any] = getattr(cls, attr_name)

                # Check if torch.nn has a class with the same name
                if hasattr(nn, attr_name):
                    # Get the original torch.nn class
                    original_class: Type[Any] = getattr(nn, attr_name)

                    # Save the original class for restoration later
                    original_classes[attr_name] = original_class

                    # Apply the patch
                    patched_class = cls.create_patch_function(patcher_class)(original_class)
                    setattr(nn, attr_name, patched_class)
            yield
        finally:
            # Restore the original torch.nn classes
            for attr_name, original_class in original_classes.items():
                setattr(nn, attr_name, original_class)

    @staticmethod
    def create_patch_function(patcher_attr: Type[Any]) -> Callable[[Type[Any]], Type[Any]]:
        # Return a new patch_class function specific to this patcher_attr
        @wrapt.decorator
        def patch_class(
            wrapped: Callable[..., Any],
            instance: Any,
            args: Any,
            kwargs: Any,
        ) -> Any:
            # Call the patcher_attr version of the class
            return patcher_attr(*args, **kwargs)

        return patch_class
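
The wrapt mechanism above can be seen in isolation with a minimal sketch (the names here are invented for illustration): decorating the original class with a function that ignores it and constructs the replacement yields a drop-in callable.

# Hedged, self-contained sketch of the create_patch_function idea.
import wrapt
from torch import nn


class MyLinear(nn.Linear):
    pass


@wrapt.decorator
def construct_replacement(wrapped, instance, args, kwargs):
    # Ignore the wrapped original class and build the replacement instead.
    return MyLinear(*args, **kwargs)


PatchedLinear = construct_replacement(nn.Linear)
layer = PatchedLinear(4, 8)
assert isinstance(layer, MyLinear)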

class GGUFPatcher(TorchPatcher):
    """
    Dequantize weights on the fly before doing the compute
    """

    class Linear(GGUFLayer, nn.Linear):
        def forward(self, input: Tensor) -> Tensor:
            weight, bias = self.cast_bias_weight(input)
            return nn.functional.linear(input, weight, bias)

    class Conv2d(GGUFLayer, nn.Conv2d):
        def forward(self, input: Tensor) -> Tensor:
            weight, bias = self.cast_bias_weight(input)
            return self._conv_forward(input, weight, bias)

    class Embedding(GGUFLayer, nn.Embedding):
        def forward(self, input: Tensor, out_dtype: Optional[dtype] = None) -> Tensor:
            output_dtype = out_dtype
            if not self.weight:
                raise ValueError("Embedding layer must have a weight")
            if self.weight.dtype == float16 or self.weight.dtype == bfloat16:
                out_dtype = None
            weight, _ = self.cast_bias_weight(input, device=input.device, dtype=out_dtype)
            return nn.functional.embedding(
                input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse
            ).to(dtype=output_dtype)

    class LayerNorm(GGUFLayer, nn.LayerNorm):
        def forward(self, input: Tensor) -> Tensor:
            if self.weight is None:
                return nn.functional.layer_norm(input, self.normalized_shape, None, None, self.eps)
            weight, bias = self.cast_bias_weight(input)
            return nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

    class GroupNorm(GGUFLayer, nn.GroupNorm):
        def forward(self, input: Tensor) -> Tensor:
            weight, bias = self.cast_bias_weight(input)
            return nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
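
A hedged sketch of how the now-removed patcher was meant to be used, mirroring the tests in the last hunk below: inside the context manager, modules built from torch.nn pick up the GGUF-aware layer classes, so a quantized state dict loads directly. TinyModel is an invented stand-in; the asset paths and layer sizes are taken from those tests.

# Hedged usage sketch; assumes GGUFPatcher above is in scope.
import torch
from torch import nn

quantized_sd = {
    "linear.weight": torch.load("tests/assets/gguf_qweight.pt"),
    "linear.bias": torch.load("tests/assets/gguf_qbias.pt"),
}


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3072, 18432)

    def forward(self, x):
        return self.linear(x)


with GGUFPatcher.wrap():
    model = TinyModel()                   # nn.Linear resolves to GGUFPatcher.Linear here
    model.load_state_dict(quantized_sd)   # quantized tensors are stored without dequantizing
out = model(torch.randn(1, 3072))         # dequantize-on-the-fly happens in forward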

@@ -1,17 +0,0 @@
import torch

from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.layers import GGUFTensor


def test_ggml_tensor():
    """Smoke test that multiplication works on a GGMLTensor."""
    weight: GGUFTensor = torch.load("tests/assets/gguf_qweight.pt")
    tensor_shape = weight.tensor_shape
    tensor_type = weight.tensor_type
    data = torch.Tensor(weight.data)

    ggml_tensor = GGMLTensor(data, tensor_type, tensor_shape)
    ones = torch.ones([1], dtype=torch.float32)

    _ = ggml_tensor * ones

@@ -1,124 +0,0 @@
import pytest
import torch
import torch.nn as nn

from invokeai.backend.quantization.gguf.layers import GGUFLayer
from invokeai.backend.quantization.gguf.torch_patcher import TorchPatcher

quantized_sd = {
    "linear.weight": torch.load("tests/assets/gguf_qweight.pt"),
    "linear.bias": torch.load("tests/assets/gguf_qbias.pt"),
}


class TestGGUFPatcher(TorchPatcher):
    class Linear(GGUFLayer, nn.Linear):
        def forward(self, input: torch.Tensor) -> torch.Tensor:
            weight, bias = self.cast_bias_weight(input)
            return nn.functional.linear(input, weight, bias)


class Test2GGUFPatcher(TorchPatcher):
    class Linear(GGUFLayer, nn.Linear):
        def forward(self, input: torch.Tensor) -> torch.Tensor:
            weight, bias = self.cast_bias_weight(input)
            return nn.functional.linear(input, weight, bias)


# Define a dummy module for testing
class DummyModule(nn.Module):
    def __init__(self, device: str = "cpu", dtype: torch.dtype = torch.float32):
        super().__init__()
        self.linear = nn.Linear(3072, 18432, device=device, dtype=dtype)

    def forward(self, x):
        x = self.linear(x)
        return x


# Test that TorchPatcher patches and unpatches nn.Linear correctly
def test_torch_patcher_patches_nn_linear():
    original_linear = nn.Linear

    with TorchPatcher.wrap():
        # nn.Linear should not be replaced
        assert nn.Linear is original_linear
    assert nn.Linear is original_linear


# Test that GGUFPatcher patches and unpatches nn.Linear correctly
def test_gguf_patcher_patches_nn_linear():
    original_linear = nn.Linear

    with TestGGUFPatcher.wrap():
        # nn.Linear should be replaced
        assert nn.Linear is not original_linear
        # Create a linear layer and check its type
        linear_layer = nn.Linear(3072, 18432)
        assert isinstance(linear_layer, TestGGUFPatcher.Linear)
    # nn.Linear should be restored
    assert nn.Linear is original_linear


# Test that unpatching restores the original behavior
def test_gguf_patcher_unpatch_restores_behavior():
    device = "cpu"
    dtype = torch.float32

    input_tensor = torch.randn(1, 3072, device=device, dtype=dtype)
    model = DummyModule(device=device, dtype=dtype)
    with pytest.raises(Exception):  # noqa: B017
        model.load_state_dict(quantized_sd)

    original_linear = nn.Linear
    with TestGGUFPatcher.wrap():
        patched_model = DummyModule(device=device, dtype=dtype)
        patched_model.load_state_dict(quantized_sd)
        # Will raise if patch is not applied
        patched_model(input_tensor)
        assert isinstance(nn.Linear(4, 8), TestGGUFPatcher.Linear)

    # Ensure nn.Linear is restored
    assert nn.Linear is not TestGGUFPatcher.Linear
    assert isinstance(nn.Linear(4, 8), original_linear)


# Test that the patched Linear layer behaves as expected
def test_gguf_patcher_linear_layer_behavior():
    device = "cpu"
    dtype = torch.float32

    input_tensor = torch.randn(1, 3072, device=device, dtype=dtype)
    model = DummyModule(device=device, dtype=dtype)
    with pytest.raises(Exception):  # noqa: B017
        model.load_state_dict(quantized_sd)

    with TestGGUFPatcher.wrap():
        patched_model = DummyModule(device=device, dtype=dtype)
        patched_model.load_state_dict(quantized_sd)

        patched_tensor = patched_model(input_tensor)

    # After unpatching, run forward and ensure patched classes are still applied
    assert torch.equal(patched_tensor, patched_model(input_tensor))


# Test that the TorchPatcher works correctly when nesting contexts
def test_torch_patcher_nested_contexts():
    original_linear = nn.Linear

    with TestGGUFPatcher.wrap():
        # First level patching
        first_level_linear = nn.Linear
        assert first_level_linear is not original_linear

        with Test2GGUFPatcher.wrap():
            # Second level patching
            second_level_linear = nn.Linear
            assert second_level_linear is not first_level_linear

        # After exiting inner context, nn.Linear should be restored to first level patch
        assert nn.Linear is first_level_linear

    # After exiting outer context, nn.Linear should be restored to original
    assert nn.Linear is original_linear