Merge branch 'release/2.3.3-rc3' into feat/lora-support-2.3

This commit is contained in:
Lincoln Stein
2023-03-31 00:33:47 -04:00

View File

@@ -614,30 +614,11 @@ class DDPM(pl.LightningModule):
class LatentDiffusion(DDPM):
"""main class"""
@staticmethod
def _fallback_personalization_config() -> dict:
    """Return a default ``personalization_config`` node.

    This protects us against custom legacy config files that
    don't contain the personalization_config section.

    Returns:
        An ``OmegaConf`` config node (declared ``dict`` for historical
        reasons — it is actually a ``DictConfig``) pointing at the
        default ``EmbeddingManager`` with one placeholder token.
    """
    return OmegaConf.create(
        dict(
            target='ldm.modules.embedding_manager.EmbeddingManager',
            params=dict(
                # One placeholder string and one whole initializer word.
                # The previous list('sculpture') split the word into
                # single characters (['s','c','u',...]), which was
                # clearly unintended given the parallel list('*').
                placeholder_strings=['*'],
                initializer_words=['sculpture'],
                per_image_tokens=False,
                num_vectors_per_token=1,
                progressive_words=False,
            ),
        )
    )
def __init__(
self,
first_stage_config,
cond_stage_config,
personalization_config=_fallback_personalization_config(),
personalization_config=None,
num_timesteps_cond=None,
cond_stage_key='image',
cond_stage_trainable=False,
@@ -695,7 +676,8 @@ class LatentDiffusion(DDPM):
self.model.train = disabled_train
for param in self.model.parameters():
param.requires_grad = False
personalization_config = personalization_config or self._fallback_personalization_config()
self.embedding_manager = self.instantiate_embedding_manager(
personalization_config, self.cond_stage_model
)
@@ -2170,6 +2152,25 @@ class LatentDiffusion(DDPM):
self.emb_ckpt_counter += 500
@classmethod
def _fallback_personalization_config(cls) -> dict:
    """Return a default ``personalization_config`` node.

    This protects us against custom legacy config files that
    don't contain the personalization_config section.

    Returns:
        An ``OmegaConf`` config node (declared ``dict`` for historical
        reasons — it is actually a ``DictConfig``) pointing at the
        default ``EmbeddingManager`` with one placeholder token.
    """
    # NOTE: this is a @classmethod, so the first parameter is the class;
    # it was previously misnamed ``self``.  The rename is safe because
    # the argument is bound implicitly and never passed by callers.
    return OmegaConf.create(
        dict(
            target='ldm.modules.embedding_manager.EmbeddingManager',
            params=dict(
                # One placeholder string and one whole initializer word.
                # The previous list('sculpture') split the word into
                # single characters (['s','c','u',...]), which was
                # clearly unintended given the parallel list('*').
                placeholder_strings=['*'],
                initializer_words=['sculpture'],
                per_image_tokens=False,
                num_vectors_per_token=1,
                progressive_words=False,
            ),
        )
    )
class DiffusionWrapper(pl.LightningModule):
def __init__(self, diff_model_config, conditioning_key):