Merge branch 'development' into Improved-fetch-and-option-to-replay-commands-from-file

This commit is contained in:
ArDiouscuros
2022-10-08 13:26:22 +02:00
committed by GitHub
239 changed files with 8262 additions and 3944 deletions

View File

@@ -117,7 +117,7 @@ class PersonalizedBase(Dataset):
self.image_paths = [
os.path.join(self.data_root, file_path)
for file_path in os.listdir(self.data_root)
for file_path in os.listdir(self.data_root) if file_path != ".DS_Store"
]
# self._length = len(self.image_paths)

View File

@@ -93,7 +93,7 @@ class PersonalizedBase(Dataset):
self.image_paths = [
os.path.join(self.data_root, file_path)
for file_path in os.listdir(self.data_root)
for file_path in os.listdir(self.data_root) if file_path != ".DS_Store"
]
# self._length = len(self.image_paths)

View File

@@ -185,15 +185,21 @@ class Args(object):
switches.append(f'-W {a["width"]}')
switches.append(f'-H {a["height"]}')
switches.append(f'-C {a["cfg_scale"]}')
if a['perlin'] > 0:
switches.append(f'--perlin {a["perlin"]}')
if a['threshold'] > 0:
switches.append(f'--threshold {a["threshold"]}')
if a['grid']:
switches.append('--grid')
if a['seamless']:
switches.append('--seamless')
if a['hires_fix']:
switches.append('--hires_fix')
# img2img generations have parameters relevant only to them and have special handling
if a['init_img'] and len(a['init_img'])>0:
switches.append(f'-I {a["init_img"]}')
switches.append(f'-A ddim') # TODO: FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS
switches.append(f'-A {a["sampler_name"]}')
if a['fit']:
switches.append(f'--fit')
if a['init_mask'] and len(a['init_mask'])>0:
@@ -209,6 +215,9 @@ class Args(object):
if a['gfpgan_strength']:
switches.append(f'-G {a["gfpgan_strength"]}')
if a['outcrop']:
switches.append(f'-c {" ".join([str(u) for u in a["outcrop"]])}')
# esrgan-specific parameters
if a['upscale']:
switches.append(f'-U {" ".join([str(u) for u in a["upscale"]])}')
@@ -227,8 +236,8 @@ class Args(object):
# 2. However, they come out of the CLI (and probably web) with the keyword "with_variations" and
# in broken-out form. Variation (1) should be changed to comply with (2)
if a['with_variations']:
formatted_variations = ','.join(f'{seed}:{weight}' for seed, weight in (a["variations"]))
switches.append(f'-V {a["formatted_variations"]}')
formatted_variations = ','.join(f'{seed}:{weight}' for seed, weight in (a["with_variations"]))
switches.append(f'-V {formatted_variations}')
if 'variations' in a:
switches.append(f'-V {a["variations"]}')
return ' '.join(switches)
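As an aside, a minimal sketch of the conversion described in the comment above, using hypothetical seed/weight pairs: the broken-out with_variations list is packed into the seed:weight string passed via -V.

# hypothetical values, illustrating the packed "-V" form
with_variations = [(31415, 0.2), (271828, 0.5)]
formatted_variations = ','.join(f'{seed}:{weight}' for seed, weight in with_variations)
assert formatted_variations == '31415:0.2,271828:0.5'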
@@ -274,7 +283,10 @@ class Args(object):
# the arg value. For example, the --grid and --individual options are a little
# funny because of their push/pull relationship. This is how to handle it.
if name=='grid':
return not cmd_switches.individual and value_arg # arg supersedes cmd
if cmd_switches.individual:
return False
else:
return value_cmd or value_arg
return value_cmd if value_cmd is not None else value_arg
def __setattr__(self,name,value):
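A small self-contained sketch of the --grid/--individual resolution above (the helper name is made up for illustration): an explicit --individual on the command line always wins; otherwise either the stored command value or the current argument may enable the grid.

def resolve_grid(cmd_individual, cmd_grid, arg_grid):
    # hypothetical helper mirroring the __getattr__ logic above
    if cmd_individual:
        return False
    return cmd_grid or arg_grid

assert resolve_grid(True, True, True) is False
assert resolve_grid(False, False, True) is True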
@@ -458,6 +470,12 @@ class Args(object):
default='9090',
help='Web server: Port to listen on'
)
web_server_group.add_argument(
'--gui',
dest='gui',
action='store_true',
help='Start InvokeAI GUI',
)
return parser
# This creates the parser that processes commands on the dream> command line
@@ -529,6 +547,18 @@ class Args(object):
type=float,
help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.',
)
render_group.add_argument(
'--threshold',
default=0.0,
type=float,
help='Latent threshold for classifier free guidance (CFG) - prevent generator from "trying" too hard. Use positive values, 0 disables.',
)
render_group.add_argument(
'--perlin',
default=0.0,
type=float,
help='Perlin noise scale (0.0 - 1.0) - add perlin noise to the initialization instead of the usual gaussian noise.',
)
render_group.add_argument(
'--grid',
'-g',
@@ -569,6 +599,12 @@ class Args(object):
type=str,
help='Directory to save generated images and a log of prompts and seeds',
)
render_group.add_argument(
'--hires_fix',
action='store_true',
dest='hires_fix',
help='Create hires image using img2img to prevent duplicated objects'
)
img2img_group.add_argument(
'-I',
'--init_img',
@@ -608,6 +644,14 @@ class Args(object):
metavar=('direction', 'pixels'),
help='Direction to extend the given image (left|right|top|bottom). If a distance pixel value is not specified it defaults to half the image size'
)
img2img_group.add_argument(
'-c',
'--outcrop',
nargs='+',
type=str,
metavar=('direction','pixels'),
help='Outcrop the image with one or more direction/pixel pairs: -c top 64 bottom 128 left 64 right 64',
)
postprocessing_group.add_argument(
'-ft',
'--facetool',
@@ -705,27 +749,15 @@ def metadata_dumps(opt,
'app_version' : APP_VERSION,
}
# add some RFC266 fields that are generated internally, and not as
# user args
# # add some RFC266 fields that are generated internally, and not as
# # user args
image_dict = opt.to_dict(
postprocessing=postprocessing
postprocessing=postprocessing
)
# 'postprocessing' is either null or an array of postprocessing metadata
if postprocessing:
# TODO: This is just a hack until postprocessing pipeline work completed
image_dict['postprocessing'] = []
if image_dict['gfpgan_strength'] and image_dict['gfpgan_strength'] > 0:
image_dict['postprocessing'].append('GFPGAN (not RFC compliant)')
if image_dict['upscale'] and image_dict['upscale'][0] > 0:
image_dict['postprocessing'].append('ESRGAN (not RFC compliant)')
else:
image_dict['postprocessing'] = None
# remove any image keys not mentioned in RFC #266
rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps',
'cfg_scale','step_number','width','height','extra','strength']
'cfg_scale','threshold','perlin','step_number','width','height','extra','strength']
rfc_dict ={}
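One plausible reading of the whitelisting step this hunk sets up (the copy loop itself falls outside the hunk): only keys named in rfc266_img_fields are carried from image_dict into rfc_dict.

# sketch only; assumes image_dict is the dict built by opt.to_dict() above
rfc_dict = {k: v for k, v in image_dict.items() if k in rfc266_img_fields}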

View File

@@ -10,6 +10,7 @@ from PIL import Image
from einops import rearrange, repeat
from pytorch_lightning import seed_everything
from ldm.dream.devices import choose_autocast
from ldm.util import rand_perlin_2d
downsampling = 8
@@ -37,7 +38,7 @@ class Generator():
self.with_variations = with_variations
def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
image_callback=None, step_callback=None,
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
**kwargs):
scope = choose_autocast(self.precision)
make_image = self.get_make_image(
@@ -46,11 +47,13 @@ class Generator():
width = width,
height = height,
step_callback = step_callback,
threshold = threshold,
perlin = perlin,
**kwargs
)
results = []
seed = seed if seed else self.new_seed()
seed = seed if seed is not None else self.new_seed()
first_seed = seed
seed, initial_noise = self.generate_initial_noise(seed, width, height)
with scope(self.model.device.type), self.model.ema_scope():
@@ -65,10 +68,11 @@ class Generator():
x_T = initial_noise
else:
seed_everything(seed)
if self.model.device.type == 'mps':
try:
x_T = self.get_noise(width,height)
except:
pass
# make_image will do the equivalent of get_noise itself
image = make_image(x_T)
results.append([image, seed])
if image_callback is not None:
@@ -117,6 +121,10 @@ class Generator():
"""
raise NotImplementedError("get_noise() must be implemented in a descendent class")
def get_perlin_noise(self,width,height):
fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device
return torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device)
def new_seed(self):
self.seed = random.randrange(0, np.iinfo(np.uint32).max)
return self.seed
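The new get_perlin_noise() hook backs the --perlin option: a fraction of Perlin noise is blended into the otherwise Gaussian starting latent. A minimal sketch of the blend, with a hypothetical latent shape and a plain random tensor standing in for rand_perlin_2d():

import torch

perlin_amount = 0.4                    # value of --perlin
gaussian = torch.randn(1, 4, 64, 64)   # usual starting noise
perlin = torch.rand(1, 4, 64, 64)      # stand-in for stacked rand_perlin_2d() output
x = (1 - perlin_amount) * gaussian + perlin_amount * perlin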

View File

@@ -10,6 +10,7 @@ from PIL import Image
from ldm.dream.generator.base import Generator
from ldm.dream.generator.img2img import Img2Img
from ldm.dream.devices import choose_autocast
from ldm.models.diffusion.ddim import DDIMSampler
class Embiggen(Generator):
def __init__(self, model, precision):
@@ -349,7 +350,7 @@ class Embiggen(Generator):
prompt,
iterations = 1,
seed = seed,
sampler = sampler,
sampler = DDIMSampler(self.model, device=self.model.device),
steps = steps,
cfg_scale = cfg_scale,
conditioning = conditioning,

View File

@@ -13,20 +13,13 @@ class Img2Img(Generator):
super().__init__(model, precision)
self.init_latent = None # by get_noise()
@torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,init_image,strength,step_callback=None,**kwargs):
conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it.
"""
# PLMS sampler not supported yet, so ignore previous sampler
if not isinstance(sampler,DDIMSampler):
print(
f">> sampler '{sampler.__class__.__name__}' is not yet supported. Using DDIM sampler"
)
sampler = DDIMSampler(self.model, device=self.model.device)
self.perlin = perlin
sampler.make_schedule(
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
@@ -41,7 +34,6 @@ class Img2Img(Generator):
t_enc = int(strength * steps)
uc, c = conditioning
@torch.no_grad()
def make_image(x_T):
# encode (scaled latent)
z_enc = sampler.stochastic_encode(
@@ -57,7 +49,9 @@ class Img2Img(Generator):
img_callback = step_callback,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
init_latent = self.init_latent, # changes how noising is performed in ksampler
)
return self.sample_to_image(samples)
return make_image
@@ -67,6 +61,10 @@ class Img2Img(Generator):
init_latent = self.init_latent
assert init_latent is not None,'call to get_noise() when init_latent not set'
if device.type == 'mps':
return torch.randn_like(init_latent, device='cpu').to(device)
x = torch.randn_like(init_latent, device='cpu').to(device)
else:
return torch.randn_like(init_latent, device=device)
x = torch.randn_like(init_latent, device=device)
if self.perlin > 0.0:
shape = init_latent.shape
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
return x

View File

@@ -8,6 +8,7 @@ from einops import rearrange, repeat
from ldm.dream.devices import choose_autocast
from ldm.dream.generator.img2img import Img2Img
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.ksampler import KSampler
class Inpaint(Img2Img):
def __init__(self, model, precision):
@@ -23,21 +24,20 @@ class Inpaint(Img2Img):
the initial image + mask. Return value depends on the seed at
the time you call it. kwargs are 'init_latent' and 'strength'
"""
mask_image = mask_image[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0)
mask_image = repeat(mask_image, '1 ... -> b ...', b=1)
# PLMS sampler not supported yet, so ignore previous sampler
if not isinstance(sampler,DDIMSampler):
# k-diffusion samplers are not supported yet for inpainting, so ignore the previous sampler
if isinstance(sampler,KSampler):
print(
f">> sampler '{sampler.__class__.__name__}' is not yet supported. Using DDIM sampler"
f">> sampler '{sampler.__class__.__name__}' is not yet supported for inpainting, using DDIMSampler instead."
)
sampler = DDIMSampler(self.model, device=self.model.device)
sampler.make_schedule(
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
)
mask_image = mask_image[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0)
mask_image = repeat(mask_image, '1 ... -> b ...', b=1)
scope = choose_autocast(self.precision)
with scope(self.model.device.type):
self.init_latent = self.model.get_first_stage_encoding(
@@ -57,7 +57,7 @@ class Inpaint(Img2Img):
torch.tensor([t_enc]).to(self.model.device),
noise=x_T
)
# decode it
samples = sampler.decode(
z_enc,
@@ -69,6 +69,7 @@ class Inpaint(Img2Img):
mask = mask_image,
init_latent = self.init_latent
)
return self.sample_to_image(samples)
return make_image

View File

@@ -12,12 +12,13 @@ class Txt2Img(Generator):
@torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,width,height,step_callback=None,**kwargs):
conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it
kwargs are 'width' and 'height'
"""
self.perlin = perlin
uc, c = conditioning
@torch.no_grad()
@@ -30,6 +31,8 @@ class Txt2Img(Generator):
if self.free_gpu_mem and self.model.model.device != self.model.device:
self.model.model.to(self.model.device)
sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
samples, _ = sampler.sample(
batch_size = 1,
@@ -41,7 +44,8 @@ class Txt2Img(Generator):
unconditional_guidance_scale = cfg_scale,
unconditional_conditioning = uc,
eta = ddim_eta,
img_callback = step_callback
img_callback = step_callback,
threshold = threshold,
)
if self.free_gpu_mem:
@@ -56,14 +60,17 @@ class Txt2Img(Generator):
def get_noise(self,width,height):
device = self.model.device
if device.type == 'mps':
return torch.randn([1,
x = torch.randn([1,
self.latent_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
device='cpu').to(device)
else:
return torch.randn([1,
x = torch.randn([1,
self.latent_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
device=device)
if self.perlin > 0.0:
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
return x

View File

@@ -0,0 +1,130 @@
'''
ldm.dream.generator.txt2img2img inherits from ldm.dream.generator
'''
import torch
import numpy as np
import math
from ldm.dream.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler
class Txt2Img2Img(Generator):
def __init__(self, model, precision):
super().__init__(model, precision)
self.init_latent = None # for get_noise()
@torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,width,height,strength,step_callback=None,**kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it
kwargs are 'width' and 'height'
"""
uc, c = conditioning
@torch.no_grad()
def make_image(x_T):
trained_square = 512 * 512
actual_square = width * height
scale = math.sqrt(trained_square / actual_square)
init_width = math.ceil(scale * width / 64) * 64
init_height = math.ceil(scale * height / 64) * 64
shape = [
self.latent_channels,
init_height // self.downsampling_factor,
init_width // self.downsampling_factor,
]
sampler.make_schedule(
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
)
#x = self.get_noise(init_width, init_height)
x = x_T
if self.free_gpu_mem and self.model.model.device != self.model.device:
self.model.model.to(self.model.device)
samples, _ = sampler.sample(
batch_size = 1,
S = steps,
x_T = x,
conditioning = c,
shape = shape,
verbose = False,
unconditional_guidance_scale = cfg_scale,
unconditional_conditioning = uc,
eta = ddim_eta,
img_callback = step_callback
)
print(
f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height}"
)
# resizing
samples = torch.nn.functional.interpolate(
samples,
size=(height // self.downsampling_factor, width // self.downsampling_factor),
mode="bilinear"
)
t_enc = int(strength * steps)
x = self.get_noise(width,height,False)
z_enc = sampler.stochastic_encode(
samples,
torch.tensor([t_enc]).to(self.model.device),
noise=x
)
# decode it
samples = sampler.decode(
z_enc,
c,
t_enc,
img_callback = step_callback,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
)
if self.free_gpu_mem:
self.model.model.to("cpu")
return self.sample_to_image(samples)
return make_image
# returns a tensor filled with random numbers from a normal distribution
def get_noise(self,width,height,scale = True):
# print(f"Get noise: {width}x{height}")
if scale:
trained_square = 512 * 512
actual_square = width * height
scale = math.sqrt(trained_square / actual_square)
scaled_width = math.ceil(scale * width / 64) * 64
scaled_height = math.ceil(scale * height / 64) * 64
else:
scaled_width = width
scaled_height = height
device = self.model.device
if device.type == 'mps':
return torch.randn([1,
self.latent_channels,
scaled_height // self.downsampling_factor,
scaled_width // self.downsampling_factor],
device='cpu').to(device)
else:
return torch.randn([1,
self.latent_channels,
scaled_height // self.downsampling_factor,
scaled_width // self.downsampling_factor],
device=device)
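The first pass above renders at a resolution whose area is close to the 512x512 training size, rounded up to multiples of 64; the latents are then upsampled and finished with an img2img pass. A worked example of the sizing arithmetic for a hypothetical 768x1024 request:

import math

width, height = 768, 1024
scale = math.sqrt((512 * 512) / (width * height))   # ~0.577
init_width = math.ceil(scale * width / 64) * 64     # 448
init_height = math.ceil(scale * height / 64) * 64   # 640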

View File

@@ -37,7 +37,7 @@ class PngWriter:
path = os.path.join(self.outdir, name)
info = PngImagePlugin.PngInfo()
info.add_text('Dream', dream_prompt)
if metadata: # TODO: merge command line app's method of writing metadata and always just write metadata
if metadata:
info.add_text('sd-metadata', json.dumps(metadata))
image.save(path, 'PNG', pnginfo=info)
return path
@@ -61,3 +61,8 @@ def retrieve_metadata(img_path):
dream_prompt = im.text.get('Dream', '')
return {'sd-metadata': json.loads(md), 'Dream': dream_prompt}
def write_metadata(img_path:str, meta:dict):
im = Image.open(img_path)
info = PngImagePlugin.PngInfo()
info.add_text('sd-metadata', json.dumps(meta))
im.save(img_path,'PNG',pnginfo=info)
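A usage sketch for the new write_metadata() round trip (the file path is a made-up example): read the sd-metadata block back with retrieve_metadata(), modify it, and write it into the same PNG.

meta = retrieve_metadata('outputs/000001.123456789.png')['sd-metadata']
meta['postprocessing'] = ['GFPGAN']            # hypothetical edit
write_metadata('outputs/000001.123456789.png', meta)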

View File

@@ -17,7 +17,7 @@ from ldm.dream.args import Args
try:
import readline
readline_available = True
except:
except (ImportError,ModuleNotFoundError):
readline_available = False
IMG_EXTENSIONS = ('.png','.jpg','.jpeg')
@@ -27,6 +27,8 @@ COMMANDS = (
'--iterations','-n',
'--width','-W','--height','-H',
'--cfg_scale','-C',
'--threshold',
'--perlin',
'--grid','-g',
'--individual','-i',
'--init_img','-I',
@@ -44,7 +46,8 @@ COMMANDS = (
'-save_orig','--save_original',
'--skip_normalize','-x',
'--log_tokenization','-t',
'!fix','!fetch','!history','!replay'
'--hires_fix',
'!fix','!fetch','!replay','!history','!search','!clear',
)
IMG_PATH_COMMANDS = (
'--outdir[=\s]',
@@ -59,7 +62,7 @@ IMG_FILE_COMMANDS=(
)
path_regexp = '('+'|'.join(IMG_PATH_COMMANDS+IMG_FILE_COMMANDS) + ')\s*\S*$'
class Completer:
class Completer(object):
def __init__(self, options):
self.options = sorted(options)
self.seeds = set()
@@ -108,6 +111,19 @@ class Completer:
if not self.auto_history_active:
readline.add_history(line)
def clear_history(self):
'''
Pass clear_history() thru to readline
'''
readline.clear_history()
def search_history(self,match:str):
'''
Like show_history() but only shows items that
contain the match string.
'''
self.show_history(match)
def remove_history_item(self,pos):
readline.remove_history_item(pos)
@@ -134,7 +150,7 @@ class Completer:
def get_history_item(self,index):
return readline.get_history_item(index)
def show_history(self):
def show_history(self,match=None):
'''
Print the session history using the pydoc pager
'''
@@ -146,7 +162,10 @@ class Completer:
return
for i in range(0,h_len):
lines.append(f'[{i+1}] {self.get_history_item(i+1)}')
line = self.get_history_item(i+1)
if match and match not in line:
continue
lines.append(f'[{i+1}] {line}')
pydoc.pager('\n'.join(lines))
def set_line(self,line)->None:
@@ -232,6 +251,9 @@ class DummyCompleter(Completer):
def add_history(self,line):
self.history.append(line)
def clear_history(self):
self.history = list()
def get_current_history_length(self):
return len(self.history)

View File

@@ -0,0 +1,111 @@
import warnings
import math
from PIL import Image, ImageFilter
class Outcrop(object):
def __init__(
self,
image,
generate, # current generate object
):
self.image = image
self.generate = generate
def process (
self,
extents:dict,
opt, # current options
orig_opt, # ones originally used to generate the image
image_callback = None,
prefix = None
):
# grow and mask the image
extended_image = self._extend_all(extents)
# switch samplers temporarily
curr_sampler = self.generate.sampler
self.generate.sampler_name = opt.sampler_name
self.generate._set_sampler()
def wrapped_callback(img,seed,**kwargs):
image_callback(img,orig_opt.seed,use_prefix=prefix,**kwargs)
result= self.generate.prompt2image(
orig_opt.prompt,
# seed = orig_opt.seed, # uncomment to make it deterministic
sampler = self.generate.sampler,
steps = opt.steps,
cfg_scale = opt.cfg_scale,
ddim_eta = self.generate.ddim_eta,
width = extended_image.width,
height = extended_image.height,
init_img = extended_image,
strength = opt.strength,
image_callback = wrapped_callback,
)
# swap sampler back
self.generate.sampler = curr_sampler
return result
def _extend_all(
self,
extents:dict,
) -> Image:
'''
Extend the image in direction ('top','bottom','left','right') by
the indicated value. The image canvas is extended, and the empty
rectangular section will be filled with a blurred copy of the
adjacent image.
'''
image = self.image
for direction in extents:
assert direction in ['top', 'left', 'bottom', 'right'],'Direction must be one of "top", "left", "bottom", "right"'
pixels = extents[direction]
# round pixels up to the nearest 64
pixels = math.ceil(pixels/64) * 64
print(f'>> extending image {direction}ward by {pixels} pixels')
image = self._rotate(image,direction)
image = self._extend(image,pixels)
image = self._rotate(image,direction,reverse=True)
return image
def _rotate(self,image:Image,direction:str,reverse=False) -> Image:
'''
Rotates image so that the area to extend is always at the top.
Simplifies logic later. The reverse argument, if true, will undo the
previous transpose.
'''
transposes = {
'right': ['ROTATE_90','ROTATE_270'],
'bottom': ['ROTATE_180','ROTATE_180'],
'left': ['ROTATE_270','ROTATE_90']
}
if direction not in transposes:
return image
transpose = transposes[direction][1 if reverse else 0]
return image.transpose(Image.Transpose.__dict__[transpose])
def _extend(self,image:Image,pixels:int)-> Image:
extended_img = Image.new('RGBA',(image.width,image.height+pixels))
# the first paste places a stretched copy of the top of the old image into
# the new space; a gaussian blur is applied to it further below
# take the top half region, stretch and paste it
top_slice = image.crop(box=(0,0,image.width,pixels//2))
top_slice = top_slice.resize((image.width,pixels))
extended_img.paste(top_slice,box=(0,0))
# second paste creates a copy of the image displaced pixels downward;
# The overall effect is to create a blurred duplicate of the top portion of
# the image.
extended_img.paste(image,box=(0,pixels))
extended_img = extended_img.filter(filter=ImageFilter.GaussianBlur(radius=pixels//2))
extended_img.paste(image,box=(0,pixels))
# now make the top part transparent to use as a mask
alpha = extended_img.getchannel('A')
alpha.paste(0,(0,0,extended_img.width,pixels*2))
extended_img.putalpha(alpha)
return extended_img
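A usage sketch of the new Outcrop class (the opt objects are placeholders; the real call site appears in ldm/generate.py later in this commit): directions are mapped to pixel counts, each side is extended with a blurred, partly transparent border, and the result is re-rendered via img2img.

restorer = Outcrop(image, generate)
result = restorer.process(
    {'top': 64, 'right': 128},   # extend_instructions: direction -> pixels
    opt=opt,                     # current options (placeholder)
    orig_opt=orig_opt,           # options used to generate the original image
)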

View File

@@ -0,0 +1,94 @@
import warnings
import math
from PIL import Image, ImageFilter
class Outpaint(object):
def __init__(self, image, generate):
self.image = image
self.generate = generate
def process(self, opt, old_opt, image_callback = None, prefix = None):
image = self._create_outpaint_image(self.image, opt.out_direction)
seed = old_opt.seed
prompt = old_opt.prompt
print(f'DEBUG: old seed={seed}, old prompt = {prompt}')
def wrapped_callback(img,seed,**kwargs):
image_callback(img,seed,use_prefix=prefix,**kwargs)
return self.generate.prompt2image(
prompt,
seed = seed,
sampler = self.generate.sampler,
steps = opt.steps,
cfg_scale = opt.cfg_scale,
ddim_eta = self.generate.ddim_eta,
width = opt.width,
height = opt.height,
init_img = image,
strength = 0.83,
image_callback = wrapped_callback,
prefix = prefix,
)
def _create_outpaint_image(self, image, direction_args):
assert len(direction_args) in [1, 2], 'Direction (-D) must have exactly one or two arguments.'
if len(direction_args) == 1:
direction = direction_args[0]
pixels = None
elif len(direction_args) == 2:
direction = direction_args[0]
pixels = int(direction_args[1])
assert direction in ['top', 'left', 'bottom', 'right'], 'Direction (-D) must be one of "top", "left", "bottom", "right"'
image = image.convert("RGBA")
# we always extend top, but rotate to extend along the requested side
if direction == 'left':
image = image.transpose(Image.Transpose.ROTATE_270)
elif direction == 'bottom':
image = image.transpose(Image.Transpose.ROTATE_180)
elif direction == 'right':
image = image.transpose(Image.Transpose.ROTATE_90)
pixels = image.height//2 if pixels is None else int(pixels)
assert 0 < pixels < image.height, 'Direction (-D) pixels length must be in the range 0 - image.size'
# the top part of the image is taken from the source image mirrored
# coordinates (0,0) are the upper left corner of an image
top = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM).convert("RGBA")
top = top.crop((0, top.height - pixels, top.width, top.height))
# setting all alpha of the top part to 0
alpha = top.getchannel("A")
alpha.paste(0, (0, 0, top.width, top.height))
top.putalpha(alpha)
# taking the bottom from the original image
bottom = image.crop((0, 0, image.width, image.height - pixels))
new_img = image.copy()
new_img.paste(top, (0, 0))
new_img.paste(bottom, (0, pixels))
# create a 10% dither in the middle
dither = min(image.height//10, pixels)
for x in range(0, image.width, 2):
for y in range(pixels - dither, pixels + dither):
(r, g, b, a) = new_img.getpixel((x, y))
new_img.putpixel((x, y), (r, g, b, 0))
# let's rotate back again
if direction == 'left':
new_img = new_img.transpose(Image.Transpose.ROTATE_90)
elif direction == 'bottom':
new_img = new_img.transpose(Image.Transpose.ROTATE_180)
elif direction == 'right':
new_img = new_img.transpose(Image.Transpose.ROTATE_270)
return new_img
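A worked example of the dither band created by the loop above, assuming a hypothetical 512-pixel-tall image extended by pixels=256:

image_height, pixels = 512, 256            # hypothetical sizes
dither = min(image_height // 10, pixels)   # 51
dither_rows = range(pixels - dither, pixels + dither)   # rows 205..306
# alpha is zeroed on every other column in those rows, softening the seam
# between the mirrored top strip and the original content below it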

View File

@@ -35,6 +35,8 @@ def build_opt(post_data, seed, gfpgan_model_exists):
setattr(opt, 'upscale', [int(post_data['upscale_level']), float(post_data['upscale_strength'])] if post_data['upscale_level'] != '' else None)
setattr(opt, 'progress_images', 'progress_images' in post_data)
setattr(opt, 'seed', None if int(post_data['seed']) == -1 else int(post_data['seed']))
setattr(opt, 'threshold', float(post_data['threshold']))
setattr(opt, 'perlin', float(post_data['perlin']))
setattr(opt, 'variation_amount', float(post_data['variation_amount']) if int(post_data['seed']) != -1 else 0)
setattr(opt, 'with_variations', [])
setattr(opt, 'embiggen', None)

View File

@@ -19,6 +19,7 @@ import cv2
import skimage
from omegaconf import OmegaConf
from ldm.dream.generator.base import downsampling
from PIL import Image, ImageOps
from torch import nn
from pytorch_lightning import seed_everything, logging
@@ -33,23 +34,7 @@ from ldm.dream.image_util import InitImageResizer
from ldm.dream.devices import choose_torch_device, choose_precision
from ldm.dream.conditioning import get_uc_and_c
def fix_func(orig):
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
def new_func(*args, **kw):
device = kw.get("device", "mps")
kw["device"]="cpu"
return orig(*args, **kw).to(device)
return new_func
return orig
torch.rand = fix_func(torch.rand)
torch.rand_like = fix_func(torch.rand_like)
torch.randn = fix_func(torch.randn)
torch.randn_like = fix_func(torch.randn_like)
torch.randint = fix_func(torch.randint)
torch.randint_like = fix_func(torch.randint_like)
torch.bernoulli = fix_func(torch.bernoulli)
torch.multinomial = fix_func(torch.multinomial)
def fix_func(orig):
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
@@ -69,23 +54,7 @@ torch.randint_like = fix_func(torch.randint_like)
torch.bernoulli = fix_func(torch.bernoulli)
torch.multinomial = fix_func(torch.multinomial)
def fix_func(orig):
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
def new_func(*args, **kw):
device = kw.get("device", "mps")
kw["device"]="cpu"
return orig(*args, **kw).to(device)
return new_func
return orig
torch.rand = fix_func(torch.rand)
torch.rand_like = fix_func(torch.rand_like)
torch.randn = fix_func(torch.randn)
torch.randn_like = fix_func(torch.randn_like)
torch.randint = fix_func(torch.randint)
torch.randint_like = fix_func(torch.randint_like)
torch.bernoulli = fix_func(torch.bernoulli)
torch.multinomial = fix_func(torch.multinomial)
"""Simplified text to image API for stable diffusion/latent diffusion
@@ -269,6 +238,8 @@ class Generate:
log_tokenization = False,
with_variations = None,
variation_amount = 0.0,
threshold = 0.0,
perlin = 0.0,
# these are specific to img2img and inpaint
init_img = None,
init_mask = None,
@@ -278,7 +249,6 @@ class Generate:
# these are specific to embiggen (which also relies on img2img args)
embiggen = None,
embiggen_tiles = None,
out_direction = None,
# these are specific to GFPGAN/ESRGAN
facetool = None,
gfpgan_strength = 0,
@@ -287,6 +257,7 @@ class Generate:
upscale = None,
# Set this True to handle KeyboardInterrupt internally
catch_interrupts = False,
hires_fix = False,
**args,
): # eat up additional cruft
"""
@@ -308,6 +279,8 @@ class Generate:
image_callback // a function or method that will be called each time an image is generated
with_variations // a weighted list [(seed_1, weight_1), (seed_2, weight_2), ...] of variations which should be applied before doing any generation
variation_amount // optional 0-1 value to slerp from -S noise to random noise (allows variations on an image)
threshold // optional value >=0 to add thresholding to latent values for k-diffusion samplers (0 disables)
perlin // optional 0-1 value to add a percentage of perlin noise to the initial noise
embiggen // scale factor relative to the size of the --init_img (-I), followed by ESRGAN upscaling strength (0-1.0), followed by minimum amount of overlap between tiles as a decimal ratio (0 - 1.0) or number of pixels
embiggen_tiles // list of tiles by number in order to process and replace onto the image e.g. `0 2 4`
@@ -348,12 +321,16 @@ class Generate:
m.padding_mode = 'circular' if seamless else m._orig_padding_mode
assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
assert threshold >= 0.0, '--threshold must be >=0.0'
assert (
0.0 < strength < 1.0
), 'img2img and inpaint strength can only work with 0.0 < strength < 1.0'
assert (
0.0 <= variation_amount <= 1.0
), '-v --variation_amount must be in [0.0, 1.0]'
assert (
0.0 <= perlin <= 1.0
), '--perlin must be in [0.0, 1.0]'
assert (
(embiggen == None and embiggen_tiles == None) or (
(embiggen != None or embiggen_tiles != None) and init_img != None)
@@ -395,7 +372,6 @@ class Generate:
width,
height,
fit=fit,
out_direction=out_direction,
)
if (init_image is not None) and (mask_image is not None):
generator = self._make_inpaint()
@@ -403,6 +379,8 @@ class Generate:
generator = self._make_embiggen()
elif init_image is not None:
generator = self._make_img2img()
elif hires_fix:
generator = self._make_txt2img2img()
else:
generator = self._make_txt2img()
@@ -425,6 +403,8 @@ class Generate:
init_image=init_image, # notice that init_image is different from init_img
mask_image=mask_image,
strength=strength,
threshold=threshold,
perlin=perlin,
embiggen=embiggen,
embiggen_tiles=embiggen_tiles,
)
@@ -485,6 +465,7 @@ class Generate:
codeformer_fidelity = 0.75,
upscale = None,
out_direction = None,
outcrop = [],
save_original = True, # to get new name
callback = None,
opt = None,
@@ -504,17 +485,22 @@ class Generate:
seed = 42
# try to reuse the same filename prefix as the original file.
# note that this is hacky
# we take everything up to the first period
prefix = None
m = re.search('(\d+)\.',os.path.basename(image_path))
m = re.match('^([^.]+)\.',os.path.basename(image_path))
if m:
prefix = m.groups()[0]
# face fixers and esrgan take an Image, but embiggen takes a path
image = Image.open(image_path)
# Note that we need to adopt a uniform API for the postprocessors.
# This is completely ad hoc at the moment
# used by multiple postprocessing tools
uc, c = get_uc_and_c(
prompt, model =self.model,
skip_normalize=opt.skip_normalize,
log_tokens =opt.log_tokenization
)
if tool in ('gfpgan','codeformer','upscale'):
if tool == 'gfpgan':
facetool = 'gfpgan'
@@ -534,14 +520,24 @@ class Generate:
prefix = prefix,
)
elif tool == 'outcrop':
from ldm.dream.restoration.outcrop import Outcrop
extend_instructions = {}
for direction,pixels in _pairwise(opt.outcrop):
extend_instructions[direction]=int(pixels)
restorer = Outcrop(image,self,)
return restorer.process (
extend_instructions,
opt = opt,
orig_opt = args,
image_callback = callback,
prefix = prefix,
)
elif tool == 'embiggen':
# fetch the metadata from the image
generator = self._make_embiggen()
uc, c = get_uc_and_c(
prompt, model =self.model,
skip_normalize=opt.skip_normalize,
log_tokens =opt.log_tokenization
)
opt.strength = 0.40
print(f'>> Setting img2img strength to {opt.strength} for happy embiggening')
# embiggen takes a image path (sigh)
@@ -562,27 +558,15 @@ class Generate:
image_callback = callback,
)
elif tool == 'outpaint':
oldargs = metadata_from_png(image_path)
opt.strength = 0.83
opt.init_img = image_path
return self.prompt2image(
oldargs.prompt,
out_direction = opt.out_direction,
sampler = self.sampler,
steps = opt.steps,
cfg_scale = opt.cfg_scale,
ddim_eta = self.ddim_eta,
conditioning= get_uc_and_c(
oldargs.prompt, model =self.model,
skip_normalize=opt.skip_normalize,
log_tokens =opt.log_tokenization
),
width = opt.width,
height = opt.height,
init_img = image_path, # not the Image! (sigh)
strength = opt.strength,
from ldm.dream.restoration.outpaint import Outpaint
restorer = Outpaint(image,self)
return restorer.process(
opt,
args,
image_callback = callback,
)
prefix = prefix
)
elif tool is None:
print(f'* please provide at least one postprocessing option, such as -G or -U')
return None
@@ -598,7 +582,6 @@ class Generate:
width,
height,
fit=False,
out_direction=None,
):
init_image = None
init_mask = None
@@ -609,11 +592,7 @@ class Generate:
img,
width,
height,
fit=fit
) # this returns an Image
if out_direction:
image = self._create_outpaint_image(image, out_direction)
init_image = self._create_init_image(image) # this returns a torch tensor
)
# if image has a transparent area and no mask was provided, then try to generate mask
if self._has_transparency(image) and not mask:
@@ -626,12 +605,17 @@ class Generate:
'>> a transparency mask, or provide mask explicitly using --init_mask (-M).'
)
# this returns a torch tensor
init_mask = self._create_init_mask(image)
init_mask = self._create_init_mask(image,width,height,fit=fit)
if (image.width * image.height) > (self.width * self.height):
print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
init_image = self._create_init_image(image,width,height,fit=fit) # this returns a torch tensor
if mask:
mask_image = self._load_img(
mask, width, height, fit=fit) # this returns an Image
init_mask = self._create_init_mask(mask_image)
mask, width, height) # this returns an Image
init_mask = self._create_init_mask(mask_image,width,height,fit=fit)
return init_image, init_mask
@@ -660,6 +644,13 @@ class Generate:
self.generators['txt2img'].free_gpu_mem = self.free_gpu_mem
return self.generators['txt2img']
def _make_txt2img2img(self):
if not self.generators.get('txt2img2'):
from ldm.dream.generator.txt2img2img import Txt2Img2Img
self.generators['txt2img2'] = Txt2Img2Img(self.model, self.precision)
self.generators['txt2img2'].free_gpu_mem = self.free_gpu_mem
return self.generators['txt2img2']
def _make_inpaint(self):
if not self.generators.get('inpaint'):
from ldm.dream.generator.inpaint import Inpaint
@@ -840,7 +831,7 @@ class Generate:
return model
def _load_img(self, img, width, height, fit=False):
def _load_img(self, img, width, height)->Image:
if isinstance(img, Image.Image):
image = img
print(
@@ -858,92 +849,33 @@ class Generate:
print(
f'>> loaded input image of size {image.width}x{image.height}'
)
image = ImageOps.exif_transpose(image)
return image
def _create_init_image(self, image, width, height, fit=True):
image = image.convert('RGB')
if fit:
image = self._fit_image(image, (width, height))
else:
image = self._squeeze_image(image)
return image
def _create_init_image(self, image):
image = image.convert('RGB')
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
image = 2.0 * image - 1.0
return image.to(self.device)
# TODO: outpainting is a post-processing application and should be made to behave
# like the other ones.
def _create_outpaint_image(self, image, direction_args):
assert len(direction_args) in [1, 2], 'Direction (-D) must have exactly one or two arguments.'
if len(direction_args) == 1:
direction = direction_args[0]
pixels = None
elif len(direction_args) == 2:
direction = direction_args[0]
pixels = int(direction_args[1])
assert direction in ['top', 'left', 'bottom', 'right'], 'Direction (-D) must be one of "top", "left", "bottom", "right"'
image = image.convert("RGBA")
# we always extend top, but rotate to extend along the requested side
if direction == 'left':
image = image.transpose(Image.Transpose.ROTATE_270)
elif direction == 'bottom':
image = image.transpose(Image.Transpose.ROTATE_180)
elif direction == 'right':
image = image.transpose(Image.Transpose.ROTATE_90)
pixels = image.height//2 if pixels is None else int(pixels)
assert 0 < pixels < image.height, 'Direction (-D) pixels length must be in the range 0 - image.size'
# the top part of the image is taken from the source image mirrored
# coordinates (0,0) are the upper left corner of an image
top = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM).convert("RGBA")
top = top.crop((0, top.height - pixels, top.width, top.height))
# setting all alpha of the top part to 0
alpha = top.getchannel("A")
alpha.paste(0, (0, 0, top.width, top.height))
top.putalpha(alpha)
# taking the bottom from the original image
bottom = image.crop((0, 0, image.width, image.height - pixels))
new_img = image.copy()
new_img.paste(top, (0, 0))
new_img.paste(bottom, (0, pixels))
# create a 10% dither in the middle
dither = min(image.height//10, pixels)
for x in range(0, image.width, 2):
for y in range(pixels - dither, pixels + dither):
(r, g, b, a) = new_img.getpixel((x, y))
new_img.putpixel((x, y), (r, g, b, 0))
# let's rotate back again
if direction == 'left':
new_img = new_img.transpose(Image.Transpose.ROTATE_90)
elif direction == 'bottom':
new_img = new_img.transpose(Image.Transpose.ROTATE_180)
elif direction == 'right':
new_img = new_img.transpose(Image.Transpose.ROTATE_270)
return new_img
def _create_init_mask(self, image):
def _create_init_mask(self, image, width, height, fit=True):
# convert into a black/white mask
image = self._image_to_mask(image)
image = image.convert('RGB')
# BUG: We need to use the model's downsample factor rather than hardcoding "8"
from ldm.dream.generator.base import downsampling
# now we adjust the size
if fit:
image = self._fit_image(image, (width, height))
else:
image = self._squeeze_image(image)
image = image.resize((image.width//downsampling, image.height //
downsampling), resample=Image.Resampling.NEAREST)
# print(
# f'>> DEBUG: writing the mask to mask.png'
# )
# image.save('mask.png')
image = np.array(image)
image = image.astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
@@ -1025,10 +957,6 @@ class Generate:
height = h
width = w
resize_needed = True
if (width * height) > (self.width * self.height):
print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
return width, height, resize_needed
@@ -1055,3 +983,20 @@ class Generate:
f.write(hash)
return hash
def write_intermediate_images(self,modulus,path):
counter = -1
if not os.path.exists(path):
os.makedirs(path)
def callback(img):
nonlocal counter
counter += 1
if counter % modulus != 0:
return
image = self.sample_to_image(img)
image.save(os.path.join(path,f'{counter:03}.png'),'PNG')
return callback
def _pairwise(iterable):
"s -> (s0, s1), (s2, s3), (s4, s5), ..."
a = iter(iterable)
return zip(a, a)
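A quick example of _pairwise() applied to the outcrop argument list it is used for in process():

pairs = list(_pairwise(['top', '64', 'bottom', '128']))
# [('top', '64'), ('bottom', '128')]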

View File

@@ -5,289 +5,31 @@ import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.dream.devices import choose_torch_device
from ldm.models.diffusion.sampler import Sampler
from ldm.modules.diffusionmodules.util import noise_like
from ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
extract_into_tensor,
)
class DDIMSampler(object):
class DDIMSampler(Sampler):
def __init__(self, model, schedule='linear', device=None, **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.device = device or choose_torch_device()
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device(self.device):
attr = attr.to(dtype=torch.float32, device=self.device)
setattr(self, name, attr)
def make_schedule(
self,
ddim_num_steps,
ddim_discretize='uniform',
ddim_eta=0.0,
verbose=True,
):
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), 'alphas have to be defined for each timestep'
to_torch = (
lambda x: x.clone()
.detach()
.to(torch.float32)
.to(self.model.device)
)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer(
'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)
)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer(
'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))
)
self.register_buffer(
'sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'log_one_minus_alphas_cumprod',
to_torch(np.log(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
)
# ddim sampling parameters
(
ddim_sigmas,
ddim_alphas,
ddim_alphas_prev,
) = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose,
)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer(
'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas)
)
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev)
/ (1 - self.alphas_cumprod)
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
'ddim_sigmas_for_original_num_steps',
sigmas_for_original_sampling_steps,
)
super().__init__(model,schedule,model.num_timesteps,device)
# This is the central routine
@torch.no_grad()
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.0,
mask=None,
x0=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs,
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
)
else:
if conditioning.shape[0] != batch_size:
print(
f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'
)
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(
conditioning,
size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
return samples, intermediates
# This routine gets called from img2img
@torch.no_grad()
def ddim_sampling(
self,
cond,
shape,
x_T=None,
ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = (
self.ddpm_num_timesteps
if ddim_use_original_steps
else self.ddim_timesteps
)
elif timesteps is not None and not ddim_use_original_steps:
subset_end = (
int(
min(timesteps / self.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0]
)
- 1
)
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = (
reversed(range(0, timesteps))
if ddim_use_original_steps
else np.flip(timesteps)
)
total_steps = (
timesteps if ddim_use_original_steps else timesteps.shape[0]
)
print(f'\nRunning DDIM Sampling with {total_steps} timesteps')
iterator = tqdm(
time_range,
desc='DDIM Sampler',
total=total_steps,
dynamic_ncols=True,
)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(
x0, ts
) # TODO: deterministic forward pass?
img = img_orig * mask + (1.0 - mask) * img
outs = self.p_sample_ddim(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
img, pred_x0 = outs
if callback:
callback(i)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
# This routine gets called from ddim_sampling() and decode()
@torch.no_grad()
def p_sample_ddim(
self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
def p_sample(
self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
**kwargs,
):
b, *_, device = *x.shape, x.device
@@ -351,83 +93,5 @@ class DDIMSampler(object):
if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
return x_prev, pred_x0, None
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
return (
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
* noise
)
@torch.no_grad()
def decode(
self,
x_latent,
cond,
t_start,
img_callback=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
use_original_steps=False,
init_latent = None,
mask = None,
):
timesteps = (
np.arange(self.ddpm_num_timesteps)
if use_original_steps
else self.ddim_timesteps
)
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f'Running DDIM Sampling with {total_steps} timesteps')
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
x0 = init_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full(
(x_latent.shape[0],),
step,
device=x_latent.device,
dtype=torch.long,
)
if mask is not None:
assert x0 is not None
xdec_orig = self.model.q_sample(
x0, ts
) # TODO: deterministic forward pass?
x_dec = xdec_orig * mask + (1.0 - mask) * x_dec
x_dec, _ = self.p_sample_ddim(
x_dec,
cond,
ts,
index=index,
use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
if img_callback:
img_callback(x_dec, i)
return x_dec

View File

@@ -701,7 +701,7 @@ class LatentDiffusion(DDPM):
@rank_zero_only
@torch.no_grad()
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
def on_train_batch_start(self, batch, batch_idx, dataloader_idx=None):
# only for very first batch
if (
self.scale_by_std
@@ -1890,7 +1890,7 @@ class LatentDiffusion(DDPM):
N=8,
n_row=4,
sample=True,
ddim_steps=200,
ddim_steps=50,
ddim_eta=1.0,
return_keys=None,
quantize_denoised=True,

View File

@@ -3,26 +3,62 @@ import k_diffusion as K
import torch
import torch.nn as nn
from ldm.dream.devices import choose_torch_device
from ldm.models.diffusion.sampler import Sampler
from ldm.util import rand_perlin_2d
from ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
extract_into_tensor,
)
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
if threshold <= 0.0:
return result
maxval = 0.0 + torch.max(result).cpu().numpy()
minval = 0.0 + torch.min(result).cpu().numpy()
if maxval < threshold and minval > -threshold:
return result
if maxval > threshold:
maxval = min(max(1, scale*maxval), threshold)
if minval < -threshold:
minval = max(min(-1, scale*minval), -threshold)
return torch.clamp(result, min=minval, max=maxval)
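A worked example of cfg_apply_threshold() with hypothetical values: threshold=2.0, the default scale=0.7, and a result whose extremes are -3.0 and +4.0.

import torch

# maxval -> min(max(1, 0.7 * 4.0), 2.0) = 2.0
# minval -> max(min(-1, 0.7 * -3.0), -2.0) = -2.0
clamped = cfg_apply_threshold(torch.tensor([-3.0, 0.5, 4.0]), threshold=2.0)
# tensor([-2.0000, 0.5000, 2.0000])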
class CFGDenoiser(nn.Module):
def __init__(self, model):
def __init__(self, model, threshold = 0, warmup = 0):
super().__init__()
self.inner_model = model
self.threshold = threshold
self.warmup_max = warmup
self.warmup = max(warmup / 10, 1)
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
return uncond + (cond - uncond) * cond_scale
if self.warmup < self.warmup_max:
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
self.warmup += 1
else:
thresh = self.threshold
if thresh > self.threshold:
thresh = self.threshold
return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, thresh)
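A worked example of the warmup ramp above with hypothetical numbers: the clamp starts tight near 1.0 and relaxes linearly toward the configured threshold as the warmup counter approaches warmup_max.

threshold, warmup_max, warmup = 5.0, 40, 10
thresh = max(1, 1 + (threshold - 1) * (warmup / warmup_max))   # 2.0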
class KSampler(object):
class KSampler(Sampler):
def __init__(self, model, schedule='lms', device=None, **kwargs):
super().__init__()
self.model = K.external.CompVisDenoiser(model)
self.schedule = schedule
self.device = device or choose_torch_device()
denoiser = K.external.CompVisDenoiser(model)
super().__init__(
denoiser,
schedule,
steps=model.num_timesteps,
)
self.ds = None
self.s_in = None
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
@@ -33,7 +69,74 @@ class KSampler(object):
).chunk(2)
return uncond + (cond - uncond) * cond_scale
# most of these arguments are ignored and are only present for compatibility with
def make_schedule(
self,
ddim_num_steps,
ddim_discretize='uniform',
ddim_eta=0.0,
verbose=False,
):
outer_model = self.model
self.model = outer_model.inner_model
super().make_schedule(
ddim_num_steps,
ddim_discretize='uniform',
ddim_eta=0.0,
verbose=False,
)
self.model = outer_model
self.ddim_num_steps = ddim_num_steps
# we don't need both of these sigmas, but we store them here to make
# comparison easier later on
self.model_sigmas = self.model.get_sigmas(ddim_num_steps)
self.karras_sigmas = K.sampling.get_sigmas_karras(
n=ddim_num_steps,
sigma_min=self.model.sigmas[0].item(),
sigma_max=self.model.sigmas[-1].item(),
rho=7.,
device=self.device,
)
self.sigmas = self.karras_sigmas
# ALERT: We are completely overriding the sample() method in the base class, which
# means that inpainting will not work. To get this to work we need to be able to
# modify the inner loop of k_heun, k_lms, etc, as is done in an ugly way
# in the lstein/k-diffusion branch.
@torch.no_grad()
def decode(
self,
z_enc,
cond,
t_enc,
img_callback=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
use_original_steps=False,
init_latent = None,
mask = None,
):
samples,_ = self.sample(
batch_size = 1,
S = t_enc,
x_T = z_enc,
shape = z_enc.shape[1:],
conditioning = cond,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning = unconditional_conditioning,
img_callback = img_callback,
x0 = init_latent,
mask = mask
)
return samples
# this is a no-op, provided here for compatibility with ddim and plms samplers
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
return x0
# Most of these arguments are ignored and are only present for compatibility with
# other samplers
@torch.no_grad()
def sample(
@@ -58,27 +161,36 @@ class KSampler(object):
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
threshold = 0,
perlin = 0,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs,
):
def route_callback(k_callback_values):
if img_callback is not None:
img_callback(k_callback_values['x'], k_callback_values['i'])
img_callback(k_callback_values['x'],k_callback_values['i'])
sigmas = self.model.get_sigmas(S)
# sigmas are set up in make_schedule() - we slice off the last S+1 entries here
total_steps = len(self.sigmas)
sigmas = self.sigmas[-S-1:]
# x_T is variation noise. When an init image is provided (in x0) we need to add
# more randomness to the starting image.
if x_T is not None:
x = x_T * sigmas[0]
if x0 is not None:
x = x_T + torch.randn_like(x0, device=self.device) * sigmas[0]
else:
x = x_T * sigmas[0]
else:
x = (
torch.randn([batch_size, *shape], device=self.device)
* sigmas[0]
) # for GPU draw
model_wrap_cfg = CFGDenoiser(self.model)
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
extra_args = {
'cond': conditioning,
'uncond': unconditional_conditioning,
'cond_scale': unconditional_guidance_scale,
}
print(f'>> Sampling with k_{self.schedule}')
return (
K.sampling.__dict__[f'sample_{self.schedule}'](
model_wrap_cfg, x, sigmas, extra_args=extra_args,
@@ -86,3 +198,74 @@ class KSampler(object):
),
None,
)
@torch.no_grad()
def p_sample(
self,
img,
cond,
ts,
index,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
**kwargs,
):
if self.model_wrap is None:
self.model_wrap = CFGDenoiser(self.model)
extra_args = {
'cond': cond,
'uncond': unconditional_conditioning,
'cond_scale': unconditional_guidance_scale,
}
if self.s_in is None:
self.s_in = img.new_ones([img.shape[0]])
if self.ds is None:
self.ds = []
# terrible, confusing names here
steps = self.ddim_num_steps
t_enc = self.t_enc
# sigmas is a full `steps` entries long, but t_enc might
# be less. We start in the middle of the sigma array
# and work our way to the end after t_enc steps.
# index starts at t_enc and works its way to zero,
# so the actual formula for indexing into sigmas:
# sigma_index = (steps-index)
s_index = t_enc - index - 1
img = K.sampling.__dict__[f'_{self.schedule}'](
self.model_wrap,
img,
self.sigmas,
s_index,
s_in = self.s_in,
ds = self.ds,
extra_args=extra_args,
)
return img, None, None
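A worked example of the index bookkeeping above, assuming decode() counts index down from t_enc-1 to 0 the same way the DDIM decode shown earlier in this commit does:

steps, t_enc = 50, 20            # hypothetical values
for i in range(t_enc):
    index = t_enc - i - 1        # t_enc-1 ... 0, as passed in by decode()
    s_index = t_enc - index - 1  # 0 ... t_enc-1 into the sliced sigma array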
# REVIEW THIS METHOD: it has never been tested. In particular,
# we should not be multiplying by self.sigmas[0] if we
# are at an intermediate step in img2img. See similar in
# sample() which does work.
def get_initial_image(self,x_T,shape,steps):
print(f'WARNING: ksampler.get_initial_image(): get_initial_image needs testing')
x = (torch.randn(shape, device=self.device) * self.sigmas[0])
if x_T is not None:
return x_T + x
else:
return x
def prepare_to_sample(self,t_enc):
self.t_enc = t_enc
self.model_wrap = None
self.ds = None
self.s_in = None
def q_sample(self,x0,ts):
'''
Overrides parent method to return the q_sample of the inner model.
'''
return self.model.inner_model.q_sample(x0,ts)

View File

@@ -5,302 +5,34 @@ import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.dream.devices import choose_torch_device
from ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
)
from ldm.models.diffusion.sampler import Sampler
from ldm.modules.diffusionmodules.util import noise_like
class PLMSSampler(object):
class PLMSSampler(Sampler):
def __init__(self, model, schedule='linear', device=None, **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.device = device if device else choose_torch_device()
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device(self.device):
attr = attr.to(torch.float32).to(torch.device(self.device))
setattr(self, name, attr)
def make_schedule(
self,
ddim_num_steps,
ddim_discretize='uniform',
ddim_eta=0.0,
verbose=True,
):
if ddim_eta != 0:
raise ValueError('ddim_eta must be 0 for PLMS')
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), 'alphas have to be defined for each timestep'
to_torch = (
lambda x: x.clone()
.detach()
.to(torch.float32)
.to(self.model.device)
)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer(
'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)
)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer(
'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))
)
self.register_buffer(
'sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'log_one_minus_alphas_cumprod',
to_torch(np.log(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
)
# ddim sampling parameters
(
ddim_sigmas,
ddim_alphas,
ddim_alphas_prev,
) = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose,
)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer(
'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas)
)
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev)
/ (1 - self.alphas_cumprod)
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
'ddim_sigmas_for_original_num_steps',
sigmas_for_original_sampling_steps,
)
super().__init__(model,schedule,model.num_timesteps, device)
# this is the essential routine
@torch.no_grad()
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.0,
mask=None,
x0=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs,
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
)
else:
if conditioning.shape[0] != batch_size:
print(
f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'
)
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
# print(f'Data shape for PLMS sampling is {size}')
samples, intermediates = self.plms_sampling(
conditioning,
size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
return samples, intermediates
@torch.no_grad()
def plms_sampling(
self,
cond,
shape,
x_T=None,
ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = (
self.ddpm_num_timesteps
if ddim_use_original_steps
else self.ddim_timesteps
)
elif timesteps is not None and not ddim_use_original_steps:
subset_end = (
int(
min(timesteps / self.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0]
)
- 1
)
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = (
list(reversed(range(0, timesteps)))
if ddim_use_original_steps
else np.flip(timesteps)
)
total_steps = (
timesteps if ddim_use_original_steps else timesteps.shape[0]
)
# print(f"Running PLMS Sampling with {total_steps} timesteps")
iterator = tqdm(
time_range,
desc='PLMS Sampler',
total=total_steps,
dynamic_ncols=True,
)
old_eps = []
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
ts_next = torch.full(
(b,),
time_range[min(i + 1, len(time_range) - 1)],
device=device,
dtype=torch.long,
)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(
x0, ts
) # TODO: deterministic forward pass?
img = img_orig * mask + (1.0 - mask) * img
outs = self.p_sample_plms(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps,
t_next=ts_next,
)
img, pred_x0, e_t = outs
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
if callback:
callback(i)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_plms(
self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
old_eps=None,
t_next=None,
def p_sample(
self,
x, # image, called 'img' elsewhere
c, # conditioning, called 'cond' elsewhere
t, # timesteps, called 'ts' elsewhere
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
old_eps=[],
t_next=None,
**kwargs,
):
b, *_, device = *x.shape, x.device

View File

@@ -0,0 +1,404 @@
'''
ldm.models.diffusion.sampler
Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc
'''
import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.dream.devices import choose_torch_device
from ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
extract_into_tensor,
)
class Sampler(object):
def __init__(self, model, schedule='linear', steps=None, device=None, **kwargs):
self.model = model
self.ddpm_num_timesteps = steps
self.schedule = schedule
self.device = device or choose_torch_device()
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device(self.device):
attr = attr.to(torch.float32).to(torch.device(self.device))
setattr(self, name, attr)
# This method was copied over from ddim.py and probably does stuff that is
# ddim-specific. Disentangle at some point.
def make_schedule(
self,
ddim_num_steps,
ddim_discretize='uniform',
ddim_eta=0.0,
verbose=False,
):
self.total_steps = ddim_num_steps
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), 'alphas have to be defined for each timestep'
to_torch = (
lambda x: x.clone()
.detach()
.to(torch.float32)
.to(self.model.device)
)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer(
'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)
)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer(
'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))
)
self.register_buffer(
'sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'log_one_minus_alphas_cumprod',
to_torch(np.log(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
)
# ddim sampling parameters
(
ddim_sigmas,
ddim_alphas,
ddim_alphas_prev,
) = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose,
)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer(
'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas)
)
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev)
/ (1 - self.alphas_cumprod)
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
'ddim_sigmas_for_original_num_steps',
sigmas_for_original_sampling_steps,
)
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
return (
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
* noise
)
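# Written out, the return value above is the closed-form forward (noising)
# step of the diffusion process:
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
# e.g. at alpha_bar_t = 0.25 this gives 0.5 * x_0 + (sqrt(3)/2) * noise.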
@torch.no_grad()
def sample(
self,
S, # S is steps
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.0,
mask=None,
x0=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
verbose=False,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs,
):
ts = self.get_timesteps(S)
# sampling
C, H, W = shape
shape = (batch_size, C, H, W)
samples, intermediates = self.do_sampling(
conditioning,
shape,
timesteps=ts,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
steps=S,
)
return samples, intermediates
@torch.no_grad()
def do_sampling(
self,
cond,
shape,
timesteps=None,
x_T=None,
ddim_use_original_steps=False,
callback=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
steps=None,
):
b = shape[0]
time_range = (
list(reversed(range(0, timesteps)))
if ddim_use_original_steps
else np.flip(timesteps)
)
total_steps=steps
iterator = tqdm(
time_range,
desc=f'{self.__class__.__name__}',
total=total_steps,
dynamic_ncols=True,
)
old_eps = []
self.prepare_to_sample(t_enc=total_steps)
img = self.get_initial_image(x_T,shape,total_steps)
# probably don't need this at all
intermediates = {'x_inter': [img], 'pred_x0': [img]}
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full(
(b,),
step,
device=self.device,
dtype=torch.long
)
ts_next = torch.full(
(b,),
time_range[min(i + 1, len(time_range) - 1)],
device=self.device,
dtype=torch.long,
)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(
x0, ts
) # TODO: deterministic forward pass?
img = img_orig * mask + (1.0 - mask) * img
outs = self.p_sample(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps,
t_next=ts_next,
)
img, pred_x0, e_t = outs
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
if callback:
callback(i)
if img_callback:
img_callback(img,i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
# NOTE that decode() and sample() are almost the same code, and do the same thing.
# The variable names are changed in order to be confusing.
@torch.no_grad()
def decode(
self,
x_latent,
cond,
t_start,
img_callback=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
use_original_steps=False,
init_latent = None,
mask = None,
):
timesteps = (
np.arange(self.ddpm_num_timesteps)
if use_original_steps
else self.ddim_timesteps
)
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f'>> Running {self.__class__.__name__} sampling starting at step {self.total_steps - t_start} of {self.total_steps} ({total_steps} new sampling steps)')
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
x0 = init_latent
self.prepare_to_sample(t_enc=total_steps)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full(
(x_latent.shape[0],),
step,
device=x_latent.device,
dtype=torch.long,
)
ts_next = torch.full(
(x_latent.shape[0],),
time_range[min(i + 1, len(time_range) - 1)],
device=self.device,
dtype=torch.long,
)
if mask is not None:
assert x0 is not None
xdec_orig = self.q_sample(x0, ts) # TODO: deterministic forward pass?
x_dec = xdec_orig * mask + (1.0 - mask) * x_dec
outs = self.p_sample(
x_dec,
cond,
ts,
index=index,
use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
t_next = ts_next,
)
x_dec, pred_x0, e_t = outs
if img_callback:
img_callback(x_dec,i)
return x_dec
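# Illustrative img2img flow tying stochastic_encode() and decode() together
# (a sketch; the variable names and the 0.75 strength are hypothetical):
#   t_enc = int(0.75 * ddim_steps)
#   z_enc = sampler.stochastic_encode(init_latent,
#                                     torch.tensor([t_enc]).to(device))
#   result = sampler.decode(z_enc, cond, t_enc,
#                           unconditional_guidance_scale=7.5,
#                           unconditional_conditioning=uc)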
def get_initial_image(self,x_T,shape,timesteps=None):
if x_T is None:
return torch.randn(shape, device=self.device)
else:
return x_T
def p_sample(
self,
img,
cond,
ts,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
old_eps=None,
t_next=None,
steps=None,
):
raise NotImplementedError("p_sample() must be implemented in a descendent class")
def prepare_to_sample(self,t_enc,**kwargs):
'''
Hook called right before the very first invocation of p_sample(), allowing a
subclass to do additional initialization. t_enc is the actual number of steps
that will be run, and may be less than the total number of steps when img2img
is active.
'''
pass
def get_timesteps(self,ddim_steps):
'''
The ddim and plms samplers work on timesteps. This method is called after
ddim_timesteps are created in make_schedule(), and selects the portion of
timesteps that will be used for sampling, depending on the t_enc in img2img.
'''
return self.ddim_timesteps[:ddim_steps]
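# Illustrative example (values hypothetical): with 50 scheduled DDIM steps and
# an img2img strength of 0.8, t_enc = 40, so get_timesteps(40) returns the
# first 40 entries of ddim_timesteps, which decode() then walks in reverse.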
def q_sample(self,x0,ts):
'''
Returns self.model.q_sample(x0,ts). Overridden in the k* samplers to
return self.model.inner_model.q_sample(x0,ts) instead.
'''
return self.model.q_sample(x0,ts)

View File

@@ -169,9 +169,14 @@ class EmbeddingManager(nn.Module):
placeholder_embedding.shape[0], max_step_tokens
)
placeholder_rows, placeholder_cols = torch.where(
tokenized_text == placeholder_token.to(device)
)
if torch.cuda.is_available():
placeholder_rows, placeholder_cols = torch.where(
tokenized_text == placeholder_token.to(device)
)
else:
placeholder_rows, placeholder_cols = torch.where(
tokenized_text == placeholder_token
)
if placeholder_rows.nelement() == 0:
continue

View File

@@ -2,6 +2,7 @@ import importlib
import torch
import numpy as np
import math
from collections import abc
from einops import rearrange
from functools import partial
@@ -212,3 +213,25 @@ def parallel_data_prefetch(
return out
else:
return gather_res
def rand_perlin_2d(shape, res, device, fade = lambda t: 6*t**5 - 15*t**4 + 10*t**3):
delta = (res[0] / shape[0], res[1] / shape[1])
d = (shape[0] // res[0], shape[1] // res[1])
grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1]), indexing='ij'), dim = -1).to(device) % 1
rand_val = torch.rand(res[0]+1, res[1]+1, device=device)  # keep on the same device as grid to avoid a CPU/GPU mismatch
angles = 2*math.pi*rand_val
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim = -1)
tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1)
dot = lambda grad, shift: (torch.stack((grid[:shape[0],:shape[1],0] + shift[0], grid[:shape[0],:shape[1], 1] + shift[1] ), dim = -1) * grad[:shape[0], :shape[1]]).sum(dim = -1)
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
n01 = dot(tile_grads([0, -1],[1, None]), [0, -1])
n11 = dot(tile_grads([1, None], [1, None]), [-1,-1])
t = fade(grid[:shape[0], :shape[1]])
return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
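# Illustrative usage (a sketch, not part of this commit): generate a small
# perlin noise field on CPU. The shape/res values are arbitrary examples;
# shape must be divisible by res along each axis.
if __name__ == '__main__':
    cpu = torch.device('cpu')
    noise = rand_perlin_2d((64, 64), (8, 8), cpu)   # values roughly in [-1, 1]
    print(noise.shape, float(noise.min()), float(noise.max()))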