stable diffusion

This commit is contained in:
Robin Rombach
2022-08-10 16:30:49 +02:00
parent f13bf9bf46
commit 2ff270f4e0
43 changed files with 3668 additions and 214 deletions

View File

@@ -455,7 +455,7 @@ class UNetModel(nn.Module):
num_classes=None,
use_checkpoint=False,
use_fp16=False,
num_heads=1,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
@@ -464,21 +464,28 @@ class UNetModel(nn.Module):
use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None # custom support for prediction of discrete ids into codebook of first stage vq model
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True,
):
super().__init__()
if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
from omegaconf.listconfig import ListConfig
if type(context_dim) == ListConfig:
context_dim = list(context_dim)
if num_heads_upsample == -1:
num_heads_upsample = num_heads
if num_heads == -1:
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
self.image_size = image_size
self.in_channels = in_channels
self.model_channels = model_channels
@@ -532,13 +539,20 @@ class UNetModel(nn.Module):
]
ch = mult * model_channels
if ds in attention_resolutions:
dim_head = ch // num_heads
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=num_head_channels,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
@@ -572,7 +586,14 @@ class UNetModel(nn.Module):
ds *= 2
self._feature_size += ch
dim_head = ch // num_heads
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
@@ -586,7 +607,7 @@ class UNetModel(nn.Module):
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=num_head_channels,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
@@ -619,13 +640,20 @@ class UNetModel(nn.Module):
]
ch = model_channels * mult
if ds in attention_resolutions:
dim_head = ch // num_heads
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads_upsample,
num_head_channels=num_head_channels,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
@@ -691,7 +719,6 @@ class UNetModel(nn.Module):
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
assert timesteps is not None, 'need to implement no-timestep usage'
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
@@ -710,14 +737,12 @@ class UNetModel(nn.Module):
h = module(h, emb, context)
h = h.type(x.dtype)
if self.predict_codebook_ids:
#return self.out(h), self.id_predictor(h)
return self.id_predictor(h)
else:
return self.out(h)
class EncoderUNetModel(nn.Module):
# TODO: do we use it ?
"""
The half UNet model with attention and timestep embedding.
For usage, see UNet.