From af7745073fe3ad5600cb8bb586aff60bd5fd638a Mon Sep 17 00:00:00 2001 From: Kirill Date: Sun, 12 Mar 2023 20:56:49 +0300 Subject: [PATCH] Add comments to SD (#686) * Add explanation for empty lambdas * Fix my_unpickle if pytorch_lightning is installed * oops --- examples/stable_diffusion.py | 4 ++-- extra/utils.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py index 4df172fe5a..82d8faca59 100644 --- a/examples/stable_diffusion.py +++ b/examples/stable_diffusion.py @@ -163,7 +163,7 @@ class ResBlock: self.out_layers = [ GroupNorm(32, out_channels), Tensor.silu, - lambda x: x, + lambda x: x, # needed for weights loading code to work Conv2d(out_channels, out_channels, 3, padding=1) ] self.skip_connection = Conv2d(channels, out_channels, 1) if channels != out_channels else lambda x: x @@ -213,7 +213,7 @@ class FeedForward: def __init__(self, dim, mult=4): self.net = [ GEGLU(dim, dim*mult), - lambda x: x, + lambda x: x, # needed for weights loading code to work Linear(dim*mult, dim) ] diff --git a/extra/utils.py b/extra/utils.py index ffb3876e97..8f18269b8a 100644 --- a/extra/utils.py +++ b/extra/utils.py @@ -74,8 +74,9 @@ def my_unpickle(fb0): elif name == "_rebuild_parameter": return HackParameter else: + if module.startswith('pytorch_lightning'): return Dummy try: - return pickle.Unpickler.find_class(self, module, name) + return super().find_class(module, name) except Exception: return Dummy