Merge branch 'development' into development

This commit is contained in:
Peter Baylies
2022-09-17 20:35:48 -04:00
committed by GitHub
7 changed files with 209 additions and 184 deletions

View File

@@ -2,7 +2,10 @@
The Args class parses both the command line (shell) arguments, as well as the
command string passed at the dream> prompt. It serves as the definitive repository
of all the arguments used by Generate and their default values.
of all the arguments used by Generate and their default values, and implements the
preliminary metadata standards discussed here:
https://github.com/lstein/stable-diffusion/issues/266
To use:
opt = Args()
@@ -52,10 +55,32 @@ you wish to apply logic as to which one to use. For example:
To add new attributes, edit the _create_arg_parser() and
_create_dream_cmd_parser() methods.
We also export the function build_metadata
**Generating and retrieving sd-metadata**
To generate a dict representing RFC266 metadata:
metadata = metadata_dumps(opt,<seeds,model_hash,postprocesser>)
This will generate an RFC266 dictionary that can then be turned into a JSON
and written to the PNG file. The optional seeds, weights, model_hash and
postprocesser arguments are not available to the opt object and so must be
provided externally. See how dream.py does it.
Note that this function was originally called format_metadata() and a wrapper
is provided that issues a deprecation notice.
To retrieve a (series of) opt objects corresponding to the metadata, do this:
opt_list = metadata_loads(metadata)
The metadata should be pulled out of the PNG image. pngwriter has a method
retrieve_metadata that will do this.
"""
import argparse
from argparse import Namespace
import shlex
import json
import hashlib
@@ -540,17 +565,20 @@ class Args(object):
)
return parser
# very partial implementation of https://github.com/lstein/stable-diffusion/issues/266
# it does not write all the required top-level metadata, writes too much image
# data, and doesn't support grids yet. But you gotta start somewhere, no?
def format_metadata(opt,
seeds=[],
weights=None,
model_hash=None,
postprocessing=None):
def format_metadata(**kwargs):
print(f'format_metadata() is deprecated. Please use metadata_dumps()')
return metadata_dumps(kwargs)
def metadata_dumps(opt,
seeds=[],
model_hash=None,
postprocessing=None):
'''
Given an Args object, returns a partial implementation of
the stable diffusion metadata standard
Given an Args object, returns a dict containing the keys and
structure of the proposed stable diffusion metadata standard
https://github.com/lstein/stable-diffusion/discussions/392
This is intended to be turned into JSON and stored in the
"sd
'''
# add some RFC266 fields that are generated internally, and not as
# user args
@@ -598,6 +626,9 @@ def format_metadata(opt,
rfc_dict['type'] = 'txt2img'
images = []
if len(seeds)==0 and opt.seed:
seeds=[seed]
for seed in seeds:
rfc_dict['seed'] = seed
images.append(copy.copy(rfc_dict))
@@ -611,6 +642,27 @@ def format_metadata(opt,
'images' : images,
}
def metadata_loads(metadata):
'''
Takes the dictionary corresponding to RFC266 (https://github.com/lstein/stable-diffusion/issues/266)
and returns a series of opt objects for each of the images described in the dictionary.
'''
results = []
try:
images = metadata['sd-metadata']['images']
for image in images:
# repack the prompt and variations
image['prompt'] = ','.join([':'.join([x['prompt'], str(x['weight'])]) for x in image['prompt']])
image['variations'] = ','.join([':'.join([str(x['seed']),str(x['weight'])]) for x in image['variations']])
opt = Args()
opt._cmd_switches = Namespace(**image)
results.append(opt)
except KeyError as e:
import sys, traceback
print('>> badly-formatted metadata',file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return results
# image can either be a file path on disk or a base64-encoded
# representation of the file's contents
def calculate_init_img_hash(image_string):

View File

@@ -4,7 +4,7 @@ import copy
import base64
import mimetypes
import os
from ldm.dream.args import Args, format_metadata
from ldm.dream.args import Args, metadata_dumps
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from ldm.dream.pngwriter import PngWriter
from threading import Event
@@ -177,10 +177,9 @@ class DreamServer(BaseHTTPRequestHandler):
path = pngwriter.save_image_and_prompt_to_png(
image,
dream_prompt = formatted_prompt,
metadata = format_metadata(iter_opt,
seeds = [seed],
weights = self.model.weights,
model_hash = self.model.model_hash
metadata = metadata_dumps(iter_opt,
seeds = [seed],
model_hash = self.model.model_hash
),
name = name,
)

View File

@@ -90,7 +90,7 @@ class LinearAttention(nn.Module):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
k = k.softmax(dim=-1)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
@@ -167,101 +167,85 @@ class CrossAttention(nn.Module):
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)
if torch.cuda.is_available():
self.einsum_op = self.einsum_op_cuda
else:
self.mem_total = psutil.virtual_memory().total / (1024**3)
self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2
def einsum_op_compvis(self, q, k, v, r1):
s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # faster
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
r1 = einsum('b i j, b j d -> b i d', s2, v)
del s2
return r1
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
def einsum_op_mps_v1(self, q, k, v, r1):
def einsum_op_compvis(self, q, k, v):
s = einsum('b i d, b j d -> b i j', q, k)
s = s.softmax(dim=-1, dtype=s.dtype)
return einsum('b i j, b j d -> b i d', s, v)
def einsum_op_slice_0(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[0], slice_size):
end = i + slice_size
r[i:end] = self.einsum_op_compvis(q[i:end], k[i:end], v[i:end])
return r
def einsum_op_slice_1(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
r[:, i:end] = self.einsum_op_compvis(q[:, i:end], k, v)
return r
def einsum_op_mps_v1(self, q, k, v):
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
r1 = self.einsum_op_compvis(q, k, v, r1)
return self.einsum_op_compvis(q, k, v)
else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
return r1
return self.einsum_op_slice_1(q, k, v, slice_size)
def einsum_op_mps_v2(self, q, k, v, r1):
if self.mem_total >= 8 and q.shape[1] <= 4096:
r1 = self.einsum_op_compvis(q, k, v, r1)
def einsum_op_mps_v2(self, q, k, v):
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
return self.einsum_op_compvis(q, k, v)
else:
slice_size = 1
for i in range(0, q.shape[0], slice_size):
end = min(q.shape[0], i + slice_size)
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
s1 *= self.scale
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
del s1
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
del s2
return r1
def einsum_op_cuda(self, q, k, v, r1):
return self.einsum_op_slice_0(q, k, v, 1)
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
if size_mb <= max_tensor_mb:
return self.einsum_op_compvis(q, k, v)
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
if div <= q.shape[0]:
return self.einsum_op_slice_0(q, k, v, q.shape[0] // div)
return self.einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
def einsum_op_cuda(self, q, k, v):
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
# Divide factor of safety as there's copying and fragmentation
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4
mem_required = tensor_size * 2.5
steps = 1
def einsum_op(self, q, k, v):
if q.device.type == 'cuda':
return self.einsum_op_cuda(q, k, v)
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
if q.device.type == 'mps':
if self.mem_total_gb >= 32:
return self.einsum_op_mps_v1(q, k, v)
return self.einsum_op_mps_v2(q, k, v)
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = min(q.shape[1], i + slice_size)
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
return r1
# Smaller slices are faster due to L2/L3/SLC caches.
# Tested on i7 with 8MB L3 cache.
return self.einsum_op_tensor_mem(q, k, v, 32)
def forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
q = self.to_q(x)
context = default(context, x)
k_in = self.to_k(context)
v_in = self.to_v(context)
device_type = 'mps' if x.device.type == 'mps' else 'cuda'
k = self.to_k(context) * self.scale
v = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
r1 = self.einsum_op(q, k, v, r1)
del q, k, v
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
r = self.einsum_op(q, k, v)
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
class BasicTransformerBlock(nn.Module):

View File

@@ -3,6 +3,7 @@ import gc
import math
import torch
import torch.nn as nn
from torch.nn.functional import silu
import numpy as np
from einops import rearrange
@@ -32,11 +33,6 @@ def get_timestep_embedding(timesteps, embedding_dim):
return emb
def nonlinearity(x):
# swish
return x*torch.sigmoid(x)
def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
@@ -122,14 +118,14 @@ class ResnetBlock(nn.Module):
def forward(self, x, temb):
h = self.norm1(x)
h = nonlinearity(h)
h = silu(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
h = h + self.temb_proj(silu(temb))[:,:,None,None]
h = self.norm2(h)
h = nonlinearity(h)
h = silu(h)
h = self.dropout(h)
h = self.conv2(h)
@@ -368,7 +364,7 @@ class Model(nn.Module):
assert t is not None
temb = get_timestep_embedding(t, self.ch)
temb = self.temb.dense[0](temb)
temb = nonlinearity(temb)
temb = silu(temb)
temb = self.temb.dense[1](temb)
else:
temb = None
@@ -402,7 +398,7 @@ class Model(nn.Module):
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = silu(h)
h = self.conv_out(h)
return h
@@ -499,7 +495,7 @@ class Encoder(nn.Module):
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = silu(h)
h = self.conv_out(h)
return h
@@ -611,7 +607,7 @@ class Decoder(nn.Module):
return h
h = self.norm_out(h)
h = nonlinearity(h)
h = silu(h)
h = self.conv_out(h)
if self.tanh_out:
h = torch.tanh(h)
@@ -649,7 +645,7 @@ class SimpleDecoder(nn.Module):
x = layer(x)
h = self.norm_out(x)
h = nonlinearity(h)
h = silu(h)
x = self.conv_out(h)
return x
@@ -697,7 +693,7 @@ class UpsampleDecoder(nn.Module):
if i_level != self.num_resolutions - 1:
h = self.upsample_blocks[k](h)
h = self.norm_out(h)
h = nonlinearity(h)
h = silu(h)
h = self.conv_out(h)
return h
@@ -873,7 +869,7 @@ class FirstStagePostProcessor(nn.Module):
z_fs = self.encode_with_pretrained(x)
z = self.proj_norm(z_fs)
z = self.proj(z)
z = nonlinearity(z)
z = silu(z)
for submodel, downmodel in zip(self.model,self.downsampler):
z = submodel(z,temb=None)

View File

@@ -252,12 +252,6 @@ def normalization(channels):
return GroupNorm32(32, channels)
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)