mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Fix bug in skip_torch_weight_init() where the original behavior of torch.nn.Conv*d modules wasn't being restored correctly.
This commit is contained in:
@@ -42,3 +42,29 @@ def test_skip_torch_weight_init_linear(torch_module, layer_args):
|
||||
assert reset_params_fn_before is reset_params_fn_after
|
||||
assert torch.allclose(layer_before.weight, layer_after.weight)
|
||||
assert torch.allclose(layer_before.bias, layer_after.bias)
|
||||
|
||||
|
||||
def test_skip_torch_weight_init_restores_base_class_behavior():
|
||||
"""Test that `skip_torch_weight_init()` correctly restores the original behavior of torch.nn.Conv*d modules. This
|
||||
test was created to catch a previous bug where `reset_parameters` was being copied from the base `_ConvNd` class to
|
||||
its child classes (like `Conv1d`).
|
||||
"""
|
||||
with skip_torch_weight_init():
|
||||
# There is no need to do anything while the context manager is applied, we're just testing that the original
|
||||
# behavior is restored correctly.
|
||||
pass
|
||||
|
||||
# Mock the behavior of another library that monkey patches `torch.nn.modules.conv._ConvNd.reset_parameters` and
|
||||
# expects it to affect all of the sub-classes (e.g. `torch.nn.Conv1D`, `torch.nn.Conv2D`, etc.).
|
||||
called_monkey_patched_fn = False
|
||||
|
||||
def monkey_patched_fn(*args, **kwargs):
|
||||
nonlocal called_monkey_patched_fn
|
||||
called_monkey_patched_fn = True
|
||||
|
||||
saved_fn = torch.nn.modules.conv._ConvNd.reset_parameters
|
||||
torch.nn.modules.conv._ConvNd.reset_parameters = monkey_patched_fn
|
||||
_ = torch.nn.Conv1d(10, 20, 3)
|
||||
torch.nn.modules.conv._ConvNd.reset_parameters = saved_fn
|
||||
|
||||
assert called_monkey_patched_fn == True
|
||||
|
||||
Reference in New Issue
Block a user