Mirror of https://github.com/invoke-ai/InvokeAI.git
Merge branch 'development' into development
diff --git a/ldm/dream/args.py b/ldm/dream/args.py
@@ -2,7 +2,10 @@
 The Args class parses both the command line (shell) arguments, as well as the
 command string passed at the dream> prompt. It serves as the definitive repository
-of all the arguments used by Generate and their default values.
+of all the arguments used by Generate and their default values, and implements the
+preliminary metadata standards discussed here:
+
+https://github.com/lstein/stable-diffusion/issues/266
+
 To use:
   opt = Args()
@@ -52,10 +55,32 @@ you wish to apply logic as to which one to use. For example:
 To add new attributes, edit the _create_arg_parser() and
 _create_dream_cmd_parser() methods.

-We also export the function build_metadata
+**Generating and retrieving sd-metadata**
+
+To generate a dict representing RFC266 metadata:
+
+  metadata = metadata_dumps(opt,<seeds,model_hash,postprocessing>)
+
+This will generate an RFC266 dictionary that can then be turned into JSON
+and written to the PNG file. The optional seeds, model_hash and
+postprocessing arguments are not available to the opt object and so must be
+provided externally. See how dream.py does it.
+
+Note that this function was originally called format_metadata() and a wrapper
+is provided that issues a deprecation notice.
+
+To retrieve a (series of) opt objects corresponding to the metadata, do this:
+
+  opt_list = metadata_loads(metadata)
+
+The metadata should be pulled out of the PNG image. pngwriter has a method
+retrieve_metadata that will do this.
+
 """

 import argparse
+from argparse import Namespace
 import shlex
 import json
 import hashlib
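As a quick illustration of the round trip the docstring describes, here is a minimal sketch. It assumes `opt` has been populated by the Args parsers, the seed and model hash are hypothetical, and the 'sd-metadata' wrapper key (read back by metadata_loads() below) is normally added when pngwriter saves the PNG:

    import json
    from ldm.dream.args import Args, metadata_dumps, metadata_loads

    opt = Args()   # in real use, populated from the shell or dream> prompt
    metadata = metadata_dumps(opt, seeds=[42], model_hash='deadbeef0123')

    payload = json.dumps({'sd-metadata': metadata})   # what lands in the PNG
    # ...write payload into the PNG, pull it back out later...
    for image_opt in metadata_loads(json.loads(payload)):
        print(image_opt.seed)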
@@ -540,17 +565,20 @@ class Args(object):
         )
         return parser

-# very partial implementation of https://github.com/lstein/stable-diffusion/issues/266
-# it does not write all the required top-level metadata, writes too much image
-# data, and doesn't support grids yet. But you gotta start somewhere, no?
-def format_metadata(opt,
-                    seeds=[],
-                    weights=None,
-                    model_hash=None,
-                    postprocessing=None):
+def format_metadata(**kwargs):
+    print('format_metadata() is deprecated. Please use metadata_dumps()')
+    return metadata_dumps(**kwargs)
+
+def metadata_dumps(opt,
+                   seeds=[],
+                   model_hash=None,
+                   postprocessing=None):
     '''
-    Given an Args object, returns a partial implementation of
-    the stable diffusion metadata standard
+    Given an Args object, returns a dict containing the keys and
+    structure of the proposed stable diffusion metadata standard
+    https://github.com/lstein/stable-diffusion/discussions/392
+    This is intended to be turned into JSON and stored in the
+    "sd-metadata" field of the PNG file.
     '''
     # add some RFC266 fields that are generated internally, and not as
     # user args
@@ -598,6 +626,9 @@ def format_metadata(opt,
         rfc_dict['type'] = 'txt2img'

     images = []
+    if len(seeds) == 0 and opt.seed:
+        seeds = [opt.seed]
+
     for seed in seeds:
         rfc_dict['seed'] = seed
         images.append(copy.copy(rfc_dict))
@@ -611,6 +642,27 @@ def format_metadata(opt,
         'images' : images,
     }

+def metadata_loads(metadata):
+    '''
+    Takes the dictionary corresponding to RFC266 (https://github.com/lstein/stable-diffusion/issues/266)
+    and returns a series of opt objects for each of the images described in the dictionary.
+    '''
+    results = []
+    try:
+        images = metadata['sd-metadata']['images']
+        for image in images:
+            # repack the prompt and variations
+            image['prompt']     = ','.join([':'.join([x['prompt'], str(x['weight'])]) for x in image['prompt']])
+            image['variations'] = ','.join([':'.join([str(x['seed']), str(x['weight'])]) for x in image['variations']])
+            opt = Args()
+            opt._cmd_switches = Namespace(**image)
+            results.append(opt)
+    except KeyError:
+        import sys, traceback
+        print('>> badly-formatted metadata', file=sys.stderr)
+        print(traceback.format_exc(), file=sys.stderr)
+    return results
+
 # image can either be a file path on disk or a base64-encoded
 # representation of the file's contents
 def calculate_init_img_hash(image_string):
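For reference, one entry of metadata['sd-metadata']['images'] has roughly this shape (values illustrative), inferred from the repacking code above, which flattens it back into the dream> command syntax:

    image = {
        'prompt':     [{'prompt': 'a fantastic alien landscape', 'weight': 1.0}],
        'variations': [{'seed': 3, 'weight': 0.2}, {'seed': 10, 'weight': 0.1}],
        'seed':       42,
    }
    # after repacking:
    #   image['prompt']     == 'a fantastic alien landscape:1.0'
    #   image['variations'] == '3:0.2,10:0.1'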
diff --git a/ldm/dream/server.py b/ldm/dream/server.py
@@ -4,7 +4,7 @@ import copy
 import base64
 import mimetypes
 import os
-from ldm.dream.args import Args, format_metadata
+from ldm.dream.args import Args, metadata_dumps
 from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
 from ldm.dream.pngwriter import PngWriter
 from threading import Event
@@ -177,10 +177,9 @@ class DreamServer(BaseHTTPRequestHandler):
                 path = pngwriter.save_image_and_prompt_to_png(
                     image,
                     dream_prompt = formatted_prompt,
-                    metadata = format_metadata(iter_opt,
-                                               seeds      = [seed],
-                                               weights    = self.model.weights,
-                                               model_hash = self.model.model_hash
+                    metadata = metadata_dumps(iter_opt,
+                                              seeds      = [seed],
+                                              model_hash = self.model.model_hash
                     ),
                     name = name,
                 )
diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
@@ -90,7 +90,7 @@ class LinearAttention(nn.Module):
         b, c, h, w = x.shape
         qkv = self.to_qkv(x)
         q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
-        k = k.softmax(dim=-1)
+        k = k.softmax(dim=-1)
         context = torch.einsum('bhdn,bhen->bhde', k, v)
         out = torch.einsum('bhde,bhdn->bhen', context, q)
         out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
@@ -167,101 +167,85 @@ class CrossAttention(nn.Module):
             nn.Linear(inner_dim, query_dim),
             nn.Dropout(dropout)
         )

-        if torch.cuda.is_available():
-            self.einsum_op = self.einsum_op_cuda
-        else:
-            self.mem_total = psutil.virtual_memory().total / (1024**3)
-            self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2
-
-    def einsum_op_compvis(self, q, k, v, r1):
-        s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # faster
-        s2 = s1.softmax(dim=-1, dtype=q.dtype)
-        del s1
-        r1 = einsum('b i j, b j d -> b i d', s2, v)
-        del s2
-        return r1
+        if not torch.cuda.is_available():
+            self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)

-    def einsum_op_mps_v1(self, q, k, v, r1):
+    def einsum_op_compvis(self, q, k, v):
+        s = einsum('b i d, b j d -> b i j', q, k)
+        s = s.softmax(dim=-1, dtype=s.dtype)
+        return einsum('b i j, b j d -> b i d', s, v)
+
+    def einsum_op_slice_0(self, q, k, v, slice_size):
+        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        for i in range(0, q.shape[0], slice_size):
+            end = i + slice_size
+            r[i:end] = self.einsum_op_compvis(q[i:end], k[i:end], v[i:end])
+        return r
+
+    def einsum_op_slice_1(self, q, k, v, slice_size):
+        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        for i in range(0, q.shape[1], slice_size):
+            end = i + slice_size
+            r[:, i:end] = self.einsum_op_compvis(q[:, i:end], k, v)
+        return r
+
+    def einsum_op_mps_v1(self, q, k, v):
         if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
-            r1 = self.einsum_op_compvis(q, k, v, r1)
+            return self.einsum_op_compvis(q, k, v)
         else:
             slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
-            for i in range(0, q.shape[1], slice_size):
-                end = i + slice_size
-                s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
-                s2 = s1.softmax(dim=-1, dtype=r1.dtype)
-                del s1
-                r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
-                del s2
-        return r1
+            return self.einsum_op_slice_1(q, k, v, slice_size)

-    def einsum_op_mps_v2(self, q, k, v, r1):
-        if self.mem_total >= 8 and q.shape[1] <= 4096:
-            r1 = self.einsum_op_compvis(q, k, v, r1)
+    def einsum_op_mps_v2(self, q, k, v):
+        if self.mem_total_gb > 8 and q.shape[1] <= 4096:
+            return self.einsum_op_compvis(q, k, v)
         else:
-            slice_size = 1
-            for i in range(0, q.shape[0], slice_size):
-                end = min(q.shape[0], i + slice_size)
-                s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
-                s1 *= self.scale
-                s2 = s1.softmax(dim=-1, dtype=r1.dtype)
-                del s1
-                r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
-                del s2
-        return r1
+            return self.einsum_op_slice_0(q, k, v, 1)

-    def einsum_op_cuda(self, q, k, v, r1):
+    def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
+        size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
+        if size_mb <= max_tensor_mb:
+            return self.einsum_op_compvis(q, k, v)
+        div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
+        if div <= q.shape[0]:
+            return self.einsum_op_slice_0(q, k, v, q.shape[0] // div)
+        return self.einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
+
+    def einsum_op_cuda(self, q, k, v):
         stats = torch.cuda.memory_stats(q.device)
         mem_active = stats['active_bytes.all.current']
         mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+        mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
         mem_free_torch = mem_reserved - mem_active
         mem_free_total = mem_free_cuda + mem_free_torch
+        # Divide factor of safety as there's copying and fragmentation
+        return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))

-        gb = 1024 ** 3
-        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4
-        mem_required = tensor_size * 2.5
-        steps = 1
+    def einsum_op(self, q, k, v):
+        if q.device.type == 'cuda':
+            return self.einsum_op_cuda(q, k, v)

-        if mem_required > mem_free_total:
-            steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
+        if q.device.type == 'mps':
+            if self.mem_total_gb >= 32:
+                return self.einsum_op_mps_v1(q, k, v)
+            return self.einsum_op_mps_v2(q, k, v)

-        if steps > 64:
-            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
-            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
-                               f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
-
-        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
-        for i in range(0, q.shape[1], slice_size):
-            end = min(q.shape[1], i + slice_size)
-            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
-            s2 = s1.softmax(dim=-1, dtype=r1.dtype)
-            del s1
-            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
-            del s2
-        return r1
+        # Smaller slices are faster due to L2/L3/SLC caches.
+        # Tested on i7 with 8MB L3 cache.
+        return self.einsum_op_tensor_mem(q, k, v, 32)

     def forward(self, x, context=None, mask=None):
         h = self.heads

-        q_in = self.to_q(x)
+        q = self.to_q(x)
         context = default(context, x)
-        k_in = self.to_k(context)
-        v_in = self.to_v(context)
-        device_type = 'mps' if x.device.type == 'mps' else 'cuda'
+        k = self.to_k(context) * self.scale
+        v = self.to_v(context)
         del context, x

-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
-        del q_in, k_in, v_in
-        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-        r1 = self.einsum_op(q, k, v, r1)
-        del q, k, v
-
-        r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
-        del r1
-
-        return self.to_out(r2)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+        r = self.einsum_op(q, k, v)
+        return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))


 class BasicTransformerBlock(nn.Module):
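To see how einsum_op_tensor_mem() picks a slice size, here is a worked example with assumed 512x512 txt2img shapes (16 batch*heads, 4096 query and key tokens, float32); the numbers are illustrative, not measured:

    bh, n_q, n_k = 16, 4096, 4096                      # assumed tensor shapes
    size_mb = bh * n_q * n_k * 4 // (1 << 20)          # full attention matrix: 1024 MB
    max_tensor_mb = 32                                 # the CPU budget used by einsum_op()

    # overshoot factor, rounded up to the next power of two
    div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()   # -> 32

    # div (32) exceeds bh (16), so einsum_op_slice_0 can't split finely enough
    # along the batch axis and the code slices the query tokens instead:
    slice_size = max(n_q // div, 1)                    # -> 128 query tokens per slice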
diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py
@@ -3,6 +3,7 @@ import gc
 import math
 import torch
 import torch.nn as nn
+from torch.nn.functional import silu
 import numpy as np
 from einops import rearrange
@@ -32,11 +33,6 @@ def get_timestep_embedding(timesteps, embedding_dim):
     return emb


-def nonlinearity(x):
-    # swish
-    return x*torch.sigmoid(x)
-
-
 def Normalize(in_channels, num_groups=32):
     return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
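The replacement is exact: torch.nn.functional.silu computes the same swish x * sigmoid(x) that the removed nonlinearity() did, as a quick sanity check confirms:

    import torch
    from torch.nn.functional import silu

    x = torch.randn(4, 8)
    assert torch.allclose(silu(x), x * torch.sigmoid(x))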
@@ -122,14 +118,14 @@ class ResnetBlock(nn.Module):

     def forward(self, x, temb):
         h = self.norm1(x)
-        h = nonlinearity(h)
+        h = silu(h)
         h = self.conv1(h)

         if temb is not None:
-            h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+            h = h + self.temb_proj(silu(temb))[:,:,None,None]

         h = self.norm2(h)
-        h = nonlinearity(h)
+        h = silu(h)
         h = self.dropout(h)
         h = self.conv2(h)
@@ -368,7 +364,7 @@ class Model(nn.Module):
             assert t is not None
             temb = get_timestep_embedding(t, self.ch)
             temb = self.temb.dense[0](temb)
-            temb = nonlinearity(temb)
+            temb = silu(temb)
             temb = self.temb.dense[1](temb)
         else:
             temb = None
@@ -402,7 +398,7 @@ class Model(nn.Module):

         # end
         h = self.norm_out(h)
-        h = nonlinearity(h)
+        h = silu(h)
         h = self.conv_out(h)
         return h
@@ -499,7 +495,7 @@ class Encoder(nn.Module):

         # end
         h = self.norm_out(h)
-        h = nonlinearity(h)
+        h = silu(h)
         h = self.conv_out(h)
         return h
@@ -611,7 +607,7 @@ class Decoder(nn.Module):
             return h

         h = self.norm_out(h)
-        h = nonlinearity(h)
+        h = silu(h)
         h = self.conv_out(h)
         if self.tanh_out:
             h = torch.tanh(h)
@@ -649,7 +645,7 @@ class SimpleDecoder(nn.Module):
             x = layer(x)

         h = self.norm_out(x)
-        h = nonlinearity(h)
+        h = silu(h)
         x = self.conv_out(h)
         return x
@@ -697,7 +693,7 @@ class UpsampleDecoder(nn.Module):
             if i_level != self.num_resolutions - 1:
                 h = self.upsample_blocks[k](h)
         h = self.norm_out(h)
-        h = nonlinearity(h)
+        h = silu(h)
         h = self.conv_out(h)
         return h
@@ -873,7 +869,7 @@ class FirstStagePostProcessor(nn.Module):
         z_fs = self.encode_with_pretrained(x)
         z = self.proj_norm(z_fs)
         z = self.proj(z)
-        z = nonlinearity(z)
+        z = silu(z)

         for submodel, downmodel in zip(self.model,self.downsampler):
             z = submodel(z,temb=None)
diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py
@@ -252,12 +252,6 @@ def normalization(channels):
     return GroupNorm32(32, channels)


-# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
-class SiLU(nn.Module):
-    def forward(self, x):
-        return x * torch.sigmoid(x)
-
-
 class GroupNorm32(nn.GroupNorm):
     def forward(self, x):
         return super().forward(x.float()).type(x.dtype)
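The deleted class was only a compatibility shim for pre-1.7 PyTorch; on 1.7 and later the built-in module behaves identically:

    import torch
    import torch.nn as nn

    x = torch.randn(2, 3)
    assert torch.allclose(nn.SiLU()(x), x * torch.sigmoid(x))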