mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Merge branch 'development' into fix-high-step-count
This commit is contained in:
@@ -1,4 +0,0 @@
|
||||
'''
|
||||
Initialization file for the ldm.dream.generator package
|
||||
'''
|
||||
from .base import Generator
|
||||
@@ -1,4 +0,0 @@
|
||||
'''
|
||||
Initialization file for the ldm.dream.restoration package
|
||||
'''
|
||||
from .base import Restoration
|
||||
332
ldm/generate.py
332
ldm/generate.py
@@ -19,7 +19,7 @@ import cv2
|
||||
import skimage
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from ldm.dream.generator.base import downsampling
|
||||
from ldm.invoke.generator.base import downsampling
|
||||
from PIL import Image, ImageOps
|
||||
from torch import nn
|
||||
from pytorch_lightning import seed_everything, logging
|
||||
@@ -28,30 +28,15 @@ from ldm.util import instantiate_from_config
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
from ldm.models.diffusion.ksampler import KSampler
|
||||
from ldm.dream.pngwriter import PngWriter
|
||||
from ldm.dream.args import metadata_from_png
|
||||
from ldm.dream.image_util import InitImageResizer
|
||||
from ldm.dream.devices import choose_torch_device, choose_precision
|
||||
from ldm.dream.conditioning import get_uc_and_c
|
||||
|
||||
def fix_func(orig):
|
||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||
def new_func(*args, **kw):
|
||||
device = kw.get("device", "mps")
|
||||
kw["device"]="cpu"
|
||||
return orig(*args, **kw).to(device)
|
||||
return new_func
|
||||
return orig
|
||||
|
||||
torch.rand = fix_func(torch.rand)
|
||||
torch.rand_like = fix_func(torch.rand_like)
|
||||
torch.randn = fix_func(torch.randn)
|
||||
torch.randn_like = fix_func(torch.randn_like)
|
||||
torch.randint = fix_func(torch.randint)
|
||||
torch.randint_like = fix_func(torch.randint_like)
|
||||
torch.bernoulli = fix_func(torch.bernoulli)
|
||||
torch.multinomial = fix_func(torch.multinomial)
|
||||
|
||||
from ldm.invoke.pngwriter import PngWriter
|
||||
from ldm.invoke.args import metadata_from_png
|
||||
from ldm.invoke.image_util import InitImageResizer
|
||||
from ldm.invoke.devices import choose_torch_device, choose_precision
|
||||
from ldm.invoke.conditioning import get_uc_and_c
|
||||
from ldm.invoke.model_cache import ModelCache
|
||||
from ldm.invoke.seamless import configure_model_padding
|
||||
from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
|
||||
|
||||
def fix_func(orig):
|
||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||
def new_func(*args, **kw):
|
||||
@@ -174,14 +159,14 @@ class Generate:
|
||||
config = None,
|
||||
gfpgan=None,
|
||||
codeformer=None,
|
||||
esrgan=None
|
||||
esrgan=None,
|
||||
free_gpu_mem=False,
|
||||
):
|
||||
models = OmegaConf.load(conf)
|
||||
mconfig = models[model]
|
||||
self.weights = mconfig.weights if weights is None else weights
|
||||
self.config = mconfig.config if config is None else config
|
||||
self.height = mconfig.height
|
||||
self.width = mconfig.width
|
||||
mconfig = OmegaConf.load(conf)
|
||||
self.model_name = model
|
||||
self.height = None
|
||||
self.width = None
|
||||
self.model_cache = None
|
||||
self.iterations = 1
|
||||
self.steps = 50
|
||||
self.cfg_scale = 7.5
|
||||
@@ -190,8 +175,11 @@ class Generate:
|
||||
self.precision = precision
|
||||
self.strength = 0.75
|
||||
self.seamless = False
|
||||
self.seamless_axes = {'x','y'}
|
||||
self.hires_fix = False
|
||||
self.embedding_path = embedding_path
|
||||
self.model = None # empty for now
|
||||
self.model_hash = None
|
||||
self.sampler = None
|
||||
self.device = None
|
||||
self.session_peakmem = None
|
||||
@@ -201,11 +189,15 @@ class Generate:
|
||||
self.gfpgan = gfpgan
|
||||
self.codeformer = codeformer
|
||||
self.esrgan = esrgan
|
||||
self.free_gpu_mem = free_gpu_mem
|
||||
self.size_matters = True # used to warn once about large image sizes and VRAM
|
||||
self.txt2mask = None
|
||||
|
||||
# Note that in previous versions, there was an option to pass the
|
||||
# device to Generate(). However the device was then ignored, so
|
||||
# it wasn't actually doing anything. This logic could be reinstated.
|
||||
device_type = choose_torch_device()
|
||||
print(f'>> Using device_type {device_type}')
|
||||
self.device = torch.device(device_type)
|
||||
if full_precision:
|
||||
if self.precision != 'auto':
|
||||
@@ -216,6 +208,9 @@ class Generate:
|
||||
if self.precision == 'auto':
|
||||
self.precision = choose_precision(self.device)
|
||||
|
||||
# model caching system for fast switching
|
||||
self.model_cache = ModelCache(mconfig,self.device,self.precision)
|
||||
|
||||
# for VRAM usage statistics
|
||||
self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
|
||||
transformers.logging.set_verbosity_error()
|
||||
@@ -267,6 +262,7 @@ class Generate:
|
||||
height = None,
|
||||
sampler_name = None,
|
||||
seamless = False,
|
||||
seamless_axes = {'x','y'},
|
||||
log_tokenization = False,
|
||||
with_variations = None,
|
||||
variation_amount = 0.0,
|
||||
@@ -275,6 +271,7 @@ class Generate:
|
||||
# these are specific to img2img and inpaint
|
||||
init_img = None,
|
||||
init_mask = None,
|
||||
text_mask = None,
|
||||
fit = False,
|
||||
strength = None,
|
||||
init_color = None,
|
||||
@@ -283,10 +280,12 @@ class Generate:
|
||||
embiggen_tiles = None,
|
||||
# these are specific to GFPGAN/ESRGAN
|
||||
facetool = None,
|
||||
gfpgan_strength = 0,
|
||||
facetool_strength = 0,
|
||||
codeformer_fidelity = None,
|
||||
save_original = False,
|
||||
upscale = None,
|
||||
# this is specific to inpainting and causes more extreme inpainting
|
||||
inpaint_replace = 0.0,
|
||||
# Set this True to handle KeyboardInterrupt internally
|
||||
catch_interrupts = False,
|
||||
hires_fix = False,
|
||||
@@ -303,9 +302,12 @@ class Generate:
|
||||
height // height of image, in multiples of 64 (512)
|
||||
cfg_scale // how strongly the prompt influences the image (7.5) (must be >1)
|
||||
seamless // whether the generated image should tile
|
||||
hires_fix // whether the Hires Fix should be applied during generation
|
||||
init_img // path to an initial image
|
||||
init_mask // path to a mask for the initial image
|
||||
text_mask // a text string that will be used to guide clipseg generation of the init_mask
|
||||
strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
|
||||
gfpgan_strength // strength for GFPGAN. 0.0 preserves image exactly, 1.0 replaces it completely
|
||||
facetool_strength // strength for GFPGAN/CodeFormer. 0.0 preserves image exactly, 1.0 replaces it completely
|
||||
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
|
||||
step_callback // a function or method that will be called each step
|
||||
image_callback // a function or method that will be called each time an image is generated
|
||||
@@ -327,15 +329,17 @@ class Generate:
|
||||
def process_image(image,seed):
|
||||
image.save(f{'images/seed.png'})
|
||||
|
||||
The callback used by the prompt2png() can be found in ldm/dream_util.py. It contains code
|
||||
to create the requested output directory, select a unique informative name for each image, and
|
||||
write the prompt into the PNG metadata.
|
||||
The code used to save images to a directory can be found in ldm/invoke/pngwriter.py.
|
||||
It contains code to create the requested output directory, select a unique informative
|
||||
name for each image, and write the prompt into the PNG metadata.
|
||||
"""
|
||||
# TODO: convert this into a getattr() loop
|
||||
steps = steps or self.steps
|
||||
width = width or self.width
|
||||
height = height or self.height
|
||||
seamless = seamless or self.seamless
|
||||
seamless_axes = seamless_axes or self.seamless_axes
|
||||
hires_fix = hires_fix or self.hires_fix
|
||||
cfg_scale = cfg_scale or self.cfg_scale
|
||||
ddim_eta = ddim_eta or self.ddim_eta
|
||||
iterations = iterations or self.iterations
|
||||
@@ -346,11 +350,14 @@ class Generate:
|
||||
with_variations = [] if with_variations is None else with_variations
|
||||
|
||||
# will instantiate the model or return it from cache
|
||||
model = self.load_model()
|
||||
|
||||
for m in model.modules():
|
||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||
m.padding_mode = 'circular' if seamless else m._orig_padding_mode
|
||||
model = self.set_model(self.model_name)
|
||||
|
||||
# self.width and self.height are set by set_model()
|
||||
# to the width and height of the image training set
|
||||
width = width or self.width
|
||||
height = height or self.height
|
||||
|
||||
configure_model_padding(model, seamless, seamless_axes)
|
||||
|
||||
assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
|
||||
assert threshold >= 0.0, '--threshold must be >=0.0'
|
||||
@@ -378,6 +385,7 @@ class Generate:
|
||||
f'variation weights must be in [0.0, 1.0]: got {[weight for _, weight in with_variations]}'
|
||||
|
||||
width, height, _ = self._resolution_check(width, height, log=True)
|
||||
assert inpaint_replace >=0.0 and inpaint_replace <= 1.0,'inpaint_replace must be between 0.0 and 1.0'
|
||||
|
||||
if sampler_name and (sampler_name != self.sampler_name):
|
||||
self.sampler_name = sampler_name
|
||||
@@ -404,7 +412,10 @@ class Generate:
|
||||
width,
|
||||
height,
|
||||
fit=fit,
|
||||
text_mask=text_mask,
|
||||
)
|
||||
|
||||
# TODO: Hacky selection of operation to perform. Needs to be refactored.
|
||||
if (init_image is not None) and (mask_image is not None):
|
||||
generator = self._make_inpaint()
|
||||
elif (embiggen != None or embiggen_tiles != None):
|
||||
@@ -417,7 +428,9 @@ class Generate:
|
||||
generator = self._make_txt2img()
|
||||
|
||||
generator.set_variation(
|
||||
self.seed, variation_amount, with_variations)
|
||||
self.seed, variation_amount, with_variations
|
||||
)
|
||||
|
||||
results = generator.generate(
|
||||
prompt,
|
||||
iterations=iterations,
|
||||
@@ -439,6 +452,7 @@ class Generate:
|
||||
perlin=perlin,
|
||||
embiggen=embiggen,
|
||||
embiggen_tiles=embiggen_tiles,
|
||||
inpaint_replace=inpaint_replace,
|
||||
)
|
||||
|
||||
if init_color:
|
||||
@@ -446,11 +460,11 @@ class Generate:
|
||||
reference_image_path = init_color,
|
||||
image_callback = image_callback)
|
||||
|
||||
if upscale is not None or gfpgan_strength > 0:
|
||||
if upscale is not None or facetool_strength > 0:
|
||||
self.upscale_and_reconstruct(results,
|
||||
upscale = upscale,
|
||||
facetool = facetool,
|
||||
strength = gfpgan_strength,
|
||||
strength = facetool_strength,
|
||||
codeformer_fidelity = codeformer_fidelity,
|
||||
save_original = save_original,
|
||||
image_callback = image_callback)
|
||||
@@ -493,7 +507,7 @@ class Generate:
|
||||
self,
|
||||
image_path,
|
||||
tool = 'gfpgan', # one of 'upscale', 'gfpgan', 'codeformer', 'outpaint', or 'embiggen'
|
||||
gfpgan_strength = 0.0,
|
||||
facetool_strength = 0.0,
|
||||
codeformer_fidelity = 0.75,
|
||||
upscale = None,
|
||||
out_direction = None,
|
||||
@@ -540,11 +554,11 @@ class Generate:
|
||||
facetool = 'codeformer'
|
||||
elif tool == 'upscale':
|
||||
facetool = 'gfpgan' # but won't be run
|
||||
gfpgan_strength = 0
|
||||
facetool_strength = 0
|
||||
return self.upscale_and_reconstruct(
|
||||
[[image,seed]],
|
||||
facetool = facetool,
|
||||
strength = gfpgan_strength,
|
||||
strength = facetool_strength,
|
||||
codeformer_fidelity = codeformer_fidelity,
|
||||
save_original = save_original,
|
||||
upscale = upscale,
|
||||
@@ -553,7 +567,7 @@ class Generate:
|
||||
)
|
||||
|
||||
elif tool == 'outcrop':
|
||||
from ldm.dream.restoration.outcrop import Outcrop
|
||||
from ldm.invoke.restoration.outcrop import Outcrop
|
||||
extend_instructions = {}
|
||||
for direction,pixels in _pairwise(opt.outcrop):
|
||||
extend_instructions[direction]=int(pixels)
|
||||
@@ -590,7 +604,7 @@ class Generate:
|
||||
image_callback = callback,
|
||||
)
|
||||
elif tool == 'outpaint':
|
||||
from ldm.dream.restoration.outpaint import Outpaint
|
||||
from ldm.invoke.restoration.outpaint import Outpaint
|
||||
restorer = Outpaint(image,self)
|
||||
return restorer.process(
|
||||
opt,
|
||||
@@ -614,105 +628,112 @@ class Generate:
|
||||
width,
|
||||
height,
|
||||
fit=False,
|
||||
text_mask=None,
|
||||
):
|
||||
init_image = None
|
||||
init_mask = None
|
||||
if not img:
|
||||
return None, None
|
||||
|
||||
image = self._load_img(
|
||||
img,
|
||||
width,
|
||||
height,
|
||||
)
|
||||
image = self._load_img(img)
|
||||
|
||||
if image.width < self.width and image.height < self.height:
|
||||
print(f'>> WARNING: img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions')
|
||||
|
||||
# if image has a transparent area and no mask was provided, then try to generate mask
|
||||
if self._has_transparency(image) and not mask:
|
||||
print(
|
||||
'>> Initial image has transparent areas. Will inpaint in these regions.')
|
||||
if self._check_for_erasure(image):
|
||||
print(
|
||||
'>> WARNING: Colors underneath the transparent region seem to have been erased.\n',
|
||||
'>> Inpainting will be suboptimal. Please preserve the colors when making\n',
|
||||
'>> a transparency mask, or provide mask explicitly using --init_mask (-M).'
|
||||
)
|
||||
if self._has_transparency(image):
|
||||
self._transparency_check_and_warning(image, mask)
|
||||
# this returns a torch tensor
|
||||
init_mask = self._create_init_mask(image,width,height,fit=fit)
|
||||
init_mask = self._create_init_mask(image, width, height, fit=fit)
|
||||
|
||||
if (image.width * image.height) > (self.width * self.height):
|
||||
if (image.width * image.height) > (self.width * self.height) and self.size_matters:
|
||||
print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
|
||||
self.size_matters = False
|
||||
|
||||
init_image = self._create_init_image(image,width,height,fit=fit) # this returns a torch tensor
|
||||
|
||||
if mask:
|
||||
mask_image = self._load_img(
|
||||
mask, width, height) # this returns an Image
|
||||
mask_image = self._load_img(mask) # this returns an Image
|
||||
init_mask = self._create_init_mask(mask_image,width,height,fit=fit)
|
||||
|
||||
elif text_mask:
|
||||
init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)
|
||||
|
||||
return init_image, init_mask
|
||||
|
||||
def _make_base(self):
|
||||
if not self.generators.get('base'):
|
||||
from ldm.dream.generator import Generator
|
||||
from ldm.invoke.generator import Generator
|
||||
self.generators['base'] = Generator(self.model, self.precision)
|
||||
return self.generators['base']
|
||||
|
||||
def _make_img2img(self):
|
||||
if not self.generators.get('img2img'):
|
||||
from ldm.dream.generator.img2img import Img2Img
|
||||
from ldm.invoke.generator.img2img import Img2Img
|
||||
self.generators['img2img'] = Img2Img(self.model, self.precision)
|
||||
return self.generators['img2img']
|
||||
|
||||
def _make_embiggen(self):
|
||||
if not self.generators.get('embiggen'):
|
||||
from ldm.dream.generator.embiggen import Embiggen
|
||||
from ldm.invoke.generator.embiggen import Embiggen
|
||||
self.generators['embiggen'] = Embiggen(self.model, self.precision)
|
||||
return self.generators['embiggen']
|
||||
|
||||
def _make_txt2img(self):
|
||||
if not self.generators.get('txt2img'):
|
||||
from ldm.dream.generator.txt2img import Txt2Img
|
||||
from ldm.invoke.generator.txt2img import Txt2Img
|
||||
self.generators['txt2img'] = Txt2Img(self.model, self.precision)
|
||||
self.generators['txt2img'].free_gpu_mem = self.free_gpu_mem
|
||||
return self.generators['txt2img']
|
||||
|
||||
def _make_txt2img2img(self):
|
||||
if not self.generators.get('txt2img2'):
|
||||
from ldm.dream.generator.txt2img2img import Txt2Img2Img
|
||||
from ldm.invoke.generator.txt2img2img import Txt2Img2Img
|
||||
self.generators['txt2img2'] = Txt2Img2Img(self.model, self.precision)
|
||||
self.generators['txt2img2'].free_gpu_mem = self.free_gpu_mem
|
||||
return self.generators['txt2img2']
|
||||
|
||||
def _make_inpaint(self):
|
||||
if not self.generators.get('inpaint'):
|
||||
from ldm.dream.generator.inpaint import Inpaint
|
||||
from ldm.invoke.generator.inpaint import Inpaint
|
||||
self.generators['inpaint'] = Inpaint(self.model, self.precision)
|
||||
return self.generators['inpaint']
|
||||
|
||||
def load_model(self):
|
||||
"""Load and initialize the model from configuration variables passed at object creation time"""
|
||||
if self.model is None:
|
||||
seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
|
||||
try:
|
||||
model = self._load_model_from_config(self.config, self.weights)
|
||||
if self.embedding_path is not None:
|
||||
model.embedding_manager.load(
|
||||
self.embedding_path, self.precision == 'float32' or self.precision == 'autocast'
|
||||
)
|
||||
self.model = model.to(self.device)
|
||||
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
|
||||
self.model.cond_stage_model.device = self.device
|
||||
except AttributeError as e:
|
||||
print(f'>> Error loading model. {str(e)}', file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
raise SystemExit from e
|
||||
'''
|
||||
preload model identified in self.model_name
|
||||
'''
|
||||
self.set_model(self.model_name)
|
||||
|
||||
self._set_sampler()
|
||||
def set_model(self,model_name):
|
||||
"""
|
||||
Given the name of a model defined in models.yaml, will load and initialize it
|
||||
and return the model object. Previously-used models will be cached.
|
||||
"""
|
||||
if self.model_name == model_name and self.model is not None:
|
||||
return self.model
|
||||
|
||||
for m in self.model.modules():
|
||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||
m._orig_padding_mode = m.padding_mode
|
||||
model_data = self.model_cache.get_model(model_name)
|
||||
if model_data is None or len(model_data) == 0:
|
||||
print(f'** Model switch failed **')
|
||||
return self.model
|
||||
|
||||
self.model = model_data['model']
|
||||
self.width = model_data['width']
|
||||
self.height= model_data['height']
|
||||
self.model_hash = model_data['hash']
|
||||
|
||||
# uncache generators so they pick up new models
|
||||
self.generators = {}
|
||||
|
||||
seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
|
||||
if self.embedding_path is not None:
|
||||
self.model.embedding_manager.load(
|
||||
self.embedding_path, self.precision == 'float32' or self.precision == 'autocast'
|
||||
)
|
||||
|
||||
self._set_sampler()
|
||||
self.model_name = model_name
|
||||
return self.model
|
||||
|
||||
def correct_colors(self,
|
||||
@@ -784,6 +805,23 @@ class Generate:
|
||||
else:
|
||||
r[0] = image
|
||||
|
||||
def apply_textmask(self, image_path:str, prompt:str, callback, threshold:float=0.5):
|
||||
assert os.path.exists(image_path), '** "{image_path}" not found. Please enter the name of an existing image file to mask **'
|
||||
basename,_ = os.path.splitext(os.path.basename(image_path))
|
||||
if self.txt2mask is None:
|
||||
self.txt2mask = Txt2Mask(device = self.device)
|
||||
segmented = self.txt2mask.segment(image_path,prompt)
|
||||
trans = segmented.to_transparent()
|
||||
inverse = segmented.to_transparent(invert=True)
|
||||
mask = segmented.to_mask(threshold)
|
||||
|
||||
path_filter = re.compile(r'[<>:"/\\|?*]')
|
||||
safe_prompt = path_filter.sub('_', prompt)[:50].rstrip(' .')
|
||||
|
||||
callback(trans,f'{safe_prompt}.deselected',use_prefix=basename)
|
||||
callback(inverse,f'{safe_prompt}.selected',use_prefix=basename)
|
||||
callback(mask,f'{safe_prompt}.masked',use_prefix=basename)
|
||||
|
||||
# to help WebGUI - front end to generator util function
|
||||
def sample_to_image(self, samples):
|
||||
return self._make_base().sample_to_image(samples)
|
||||
@@ -816,54 +854,7 @@ class Generate:
|
||||
|
||||
print(msg)
|
||||
|
||||
# Be warned: config is the path to the model config file, not the dream conf file!
|
||||
# Also note that we can get config and weights from self, so why do we need to
|
||||
# pass them as args?
|
||||
def _load_model_from_config(self, config, weights):
|
||||
print(f'>> Loading model from {weights}')
|
||||
|
||||
# for usage statistics
|
||||
device_type = choose_torch_device()
|
||||
if device_type == 'cuda':
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
tic = time.time()
|
||||
|
||||
# this does the work
|
||||
c = OmegaConf.load(config)
|
||||
with open(weights,'rb') as f:
|
||||
weight_bytes = f.read()
|
||||
self.model_hash = self._cached_sha256(weights,weight_bytes)
|
||||
pl_sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
|
||||
del weight_bytes
|
||||
sd = pl_sd['state_dict']
|
||||
model = instantiate_from_config(c.model)
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
|
||||
if self.precision == 'float16':
|
||||
print('>> Using faster float16 precision')
|
||||
model.to(torch.float16)
|
||||
else:
|
||||
print('>> Using more accurate float32 precision')
|
||||
|
||||
model.to(self.device)
|
||||
model.eval()
|
||||
|
||||
# usage statistics
|
||||
toc = time.time()
|
||||
print(
|
||||
f'>> Model loaded in', '%4.2fs' % (toc - tic)
|
||||
)
|
||||
if self._has_cuda():
|
||||
print(
|
||||
'>> Max VRAM used to load the model:',
|
||||
'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
|
||||
'\n>> Current VRAM usage:'
|
||||
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
def _load_img(self, img, width, height)->Image:
|
||||
def _load_img(self, img)->Image:
|
||||
if isinstance(img, Image.Image):
|
||||
image = img
|
||||
print(
|
||||
@@ -881,6 +872,7 @@ class Generate:
|
||||
print(
|
||||
f'>> loaded input image of size {image.width}x{image.height}'
|
||||
)
|
||||
image = ImageOps.exif_transpose(image)
|
||||
return image
|
||||
|
||||
def _create_init_image(self, image, width, height, fit=True):
|
||||
@@ -889,7 +881,6 @@ class Generate:
|
||||
image = self._fit_image(image, (width, height))
|
||||
else:
|
||||
image = self._squeeze_image(image)
|
||||
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
@@ -906,7 +897,6 @@ class Generate:
|
||||
image = self._fit_image(image, (width, height))
|
||||
else:
|
||||
image = self._squeeze_image(image)
|
||||
|
||||
image = image.resize((image.width//downsampling, image.height //
|
||||
downsampling), resample=Image.Resampling.NEAREST)
|
||||
image = np.array(image)
|
||||
@@ -926,6 +916,29 @@ class Generate:
|
||||
mask = ImageOps.invert(mask)
|
||||
return mask
|
||||
|
||||
# TODO: The latter part of this method repeats code from _create_init_mask()
|
||||
def _txt2mask(self, image:Image, text_mask:list, width, height, fit=True) -> Image:
|
||||
prompt = text_mask[0]
|
||||
confidence_level = text_mask[1] if len(text_mask)>1 else 0.5
|
||||
if self.txt2mask is None:
|
||||
self.txt2mask = Txt2Mask(device = self.device)
|
||||
|
||||
segmented = self.txt2mask.segment(image, prompt)
|
||||
mask = segmented.to_mask(float(confidence_level))
|
||||
mask = mask.convert('RGB')
|
||||
# now we adjust the size
|
||||
if fit:
|
||||
mask = self._fit_image(mask, (width, height))
|
||||
else:
|
||||
mask = self._squeeze_image(mask)
|
||||
mask = mask.resize((mask.width//downsampling, mask.height //
|
||||
downsampling), resample=Image.Resampling.NEAREST)
|
||||
mask = np.array(mask)
|
||||
mask = mask.astype(np.float32) / 255.0
|
||||
mask = mask[None].transpose(0, 3, 1, 2)
|
||||
mask = torch.from_numpy(mask)
|
||||
return mask.to(self.device)
|
||||
|
||||
def _has_transparency(self, image):
|
||||
if image.info.get("transparency", None) is not None:
|
||||
return True
|
||||
@@ -953,6 +966,17 @@ class Generate:
|
||||
colored += 1
|
||||
return colored == 0
|
||||
|
||||
def _transparency_check_and_warning(self,image, mask):
|
||||
if not mask:
|
||||
print(
|
||||
'>> Initial image has transparent areas. Will inpaint in these regions.')
|
||||
if self._check_for_erasure(image):
|
||||
print(
|
||||
'>> WARNING: Colors underneath the transparent region seem to have been erased.\n',
|
||||
'>> Inpainting will be suboptimal. Please preserve the colors when making\n',
|
||||
'>> a transparency mask, or provide mask explicitly using --init_mask (-M).'
|
||||
)
|
||||
|
||||
def _squeeze_image(self, image):
|
||||
x, y, resize_needed = self._resolution_check(image.width, image.height)
|
||||
if resize_needed:
|
||||
@@ -996,26 +1020,6 @@ class Generate:
|
||||
def _has_cuda(self):
|
||||
return self.device.type == 'cuda'
|
||||
|
||||
def _cached_sha256(self,path,data):
|
||||
dirname = os.path.dirname(path)
|
||||
basename = os.path.basename(path)
|
||||
base, _ = os.path.splitext(basename)
|
||||
hashpath = os.path.join(dirname,base+'.sha256')
|
||||
if os.path.exists(hashpath) and os.path.getmtime(path) <= os.path.getmtime(hashpath):
|
||||
with open(hashpath) as f:
|
||||
hash = f.read()
|
||||
return hash
|
||||
print(f'>> Calculating sha256 hash of weights file')
|
||||
tic = time.time()
|
||||
sha = hashlib.sha256()
|
||||
sha.update(data)
|
||||
hash = sha.hexdigest()
|
||||
toc = time.time()
|
||||
print(f'>> sha256 = {hash}','(%4.2fs)' % (toc - tic))
|
||||
with open(hashpath,'w') as f:
|
||||
f.write(hash)
|
||||
return hash
|
||||
|
||||
def write_intermediate_images(self,modulus,path):
|
||||
counter = -1
|
||||
if not os.path.exists(path):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Helper class for dealing with image generation arguments.
|
||||
|
||||
The Args class parses both the command line (shell) arguments, as well as the
|
||||
command string passed at the dream> prompt. It serves as the definitive repository
|
||||
command string passed at the invoke> prompt. It serves as the definitive repository
|
||||
of all the arguments used by Generate and their default values, and implements the
|
||||
preliminary metadata standards discussed here:
|
||||
|
||||
@@ -19,7 +19,7 @@ To use:
|
||||
print('oops')
|
||||
sys.exit(-1)
|
||||
|
||||
# read in a command passed to the dream> prompt:
|
||||
# read in a command passed to the invoke> prompt:
|
||||
opts = opt.parse_cmd('do androids dream of electric sheep? -H256 -W1024 -n4')
|
||||
|
||||
# The Args object acts like a namespace object
|
||||
@@ -64,7 +64,7 @@ To generate a dict representing RFC266 metadata:
|
||||
This will generate an RFC266 dictionary that can then be turned into a JSON
|
||||
and written to the PNG file. The optional seeds, weights, model_hash and
|
||||
postprocesser arguments are not available to the opt object and so must be
|
||||
provided externally. See how dream.py does it.
|
||||
provided externally. See how invoke.py does it.
|
||||
|
||||
Note that this function was originally called format_metadata() and a wrapper
|
||||
is provided that issues a deprecation notice.
|
||||
@@ -82,6 +82,7 @@ with metadata_from_png():
|
||||
|
||||
import argparse
|
||||
from argparse import Namespace, RawTextHelpFormatter
|
||||
import pydoc
|
||||
import shlex
|
||||
import json
|
||||
import hashlib
|
||||
@@ -90,8 +91,8 @@ import re
|
||||
import copy
|
||||
import base64
|
||||
import functools
|
||||
import ldm.dream.pngwriter
|
||||
from ldm.dream.conditioning import split_weighted_subprompts
|
||||
import ldm.invoke.pngwriter
|
||||
from ldm.invoke.conditioning import split_weighted_subprompts
|
||||
|
||||
SAMPLER_CHOICES = [
|
||||
'ddim',
|
||||
@@ -115,12 +116,42 @@ PRECISION_CHOICES = [
|
||||
APP_ID = 'lstein/stable-diffusion'
|
||||
APP_VERSION = 'v1.15'
|
||||
|
||||
class ArgFormatter(argparse.RawTextHelpFormatter):
|
||||
# use defined argument order to display usage
|
||||
def _format_usage(self, usage, actions, groups, prefix):
|
||||
if prefix is None:
|
||||
prefix = 'usage: '
|
||||
|
||||
# if usage is specified, use that
|
||||
if usage is not None:
|
||||
usage = usage % dict(prog=self._prog)
|
||||
|
||||
# if no optionals or positionals are available, usage is just prog
|
||||
elif usage is None and not actions:
|
||||
usage = 'invoke>'
|
||||
elif usage is None:
|
||||
prog='invoke>'
|
||||
# build full usage string
|
||||
action_usage = self._format_actions_usage(actions, groups) # NEW
|
||||
usage = ' '.join([s for s in [prog, action_usage] if s])
|
||||
# omit the long line wrapping code
|
||||
# prefix with 'usage:'
|
||||
return '%s%s\n\n' % (prefix, usage)
|
||||
|
||||
class PagingArgumentParser(argparse.ArgumentParser):
|
||||
'''
|
||||
A custom ArgumentParser that uses pydoc to page its output.
|
||||
'''
|
||||
def print_help(self, file=None):
|
||||
text = self.format_help()
|
||||
pydoc.pager(text)
|
||||
|
||||
class Args(object):
|
||||
def __init__(self,arg_parser=None,cmd_parser=None):
|
||||
'''
|
||||
Initialize new Args class. It takes two optional arguments, an argparse
|
||||
parser for switches given on the shell command line, and an argparse
|
||||
parser for switches given on the dream> CLI line. If one or both are
|
||||
parser for switches given on the invoke> CLI line. If one or both are
|
||||
missing, it creates appropriate parsers internally.
|
||||
'''
|
||||
self._arg_parser = arg_parser or self._create_arg_parser()
|
||||
@@ -137,7 +168,7 @@ class Args(object):
|
||||
return None
|
||||
|
||||
def parse_cmd(self,cmd_string):
|
||||
'''Parse a dream>-style command string '''
|
||||
'''Parse a invoke>-style command string '''
|
||||
command = cmd_string.replace("'", "\\'")
|
||||
try:
|
||||
elements = shlex.split(command)
|
||||
@@ -208,12 +239,17 @@ class Args(object):
|
||||
switches.append(f'--init_color {a["init_color"]}')
|
||||
if a['strength'] and a['strength']>0:
|
||||
switches.append(f'-f {a["strength"]}')
|
||||
if a['inpaint_replace']:
|
||||
switches.append(f'--inpaint_replace')
|
||||
else:
|
||||
switches.append(f'-A {a["sampler_name"]}')
|
||||
|
||||
# gfpgan-specific parameters
|
||||
if a['gfpgan_strength']:
|
||||
switches.append(f'-G {a["gfpgan_strength"]}')
|
||||
# facetool-specific parameters, only print if running facetool
|
||||
if a['facetool_strength']:
|
||||
switches.append(f'-G {a["facetool_strength"]}')
|
||||
switches.append(f'-ft {a["facetool"]}')
|
||||
if a["facetool"] == "codeformer":
|
||||
switches.append(f'-cf {a["codeformer_fidelity"]}')
|
||||
|
||||
if a['outcrop']:
|
||||
switches.append(f'-c {" ".join([str(u) for u in a["outcrop"]])}')
|
||||
@@ -231,14 +267,15 @@ class Args(object):
|
||||
# outpainting parameters
|
||||
if a['out_direction']:
|
||||
switches.append(f'-D {" ".join([str(u) for u in a["out_direction"]])}')
|
||||
|
||||
# LS: slight semantic drift which needs addressing in the future:
|
||||
# 1. Variations come out of the stored metadata as a packed string with the keyword "variations"
|
||||
# 2. However, they come out of the CLI (and probably web) with the keyword "with_variations" and
|
||||
# in broken-out form. Variation (1) should be changed to comply with (2)
|
||||
if a['with_variations']:
|
||||
if a['with_variations'] and len(a['with_variations'])>0:
|
||||
formatted_variations = ','.join(f'{seed}:{weight}' for seed, weight in (a["with_variations"]))
|
||||
switches.append(f'-V {formatted_variations}')
|
||||
if 'variations' in a:
|
||||
if 'variations' in a and len(a['variations'])>0:
|
||||
switches.append(f'-V {a["variations"]}')
|
||||
return ' '.join(switches)
|
||||
|
||||
@@ -341,6 +378,14 @@ class Args(object):
|
||||
default='stable-diffusion-1.4',
|
||||
help='Indicates which diffusion model to load. (currently "stable-diffusion-1.4" (default) or "laion400m")',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--png_compression','-z',
|
||||
type=int,
|
||||
default=6,
|
||||
choices=range(0,9),
|
||||
dest='png_compression',
|
||||
help='level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.'
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--sampler',
|
||||
'-A',
|
||||
@@ -478,23 +523,23 @@ class Args(object):
|
||||
)
|
||||
return parser
|
||||
|
||||
# This creates the parser that processes commands on the dream> command line
|
||||
# This creates the parser that processes commands on the invoke> command line
|
||||
def _create_dream_cmd_parser(self):
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=RawTextHelpFormatter,
|
||||
parser = PagingArgumentParser(
|
||||
formatter_class=ArgFormatter,
|
||||
description=
|
||||
"""
|
||||
*Image generation:*
|
||||
dream> a fantastic alien landscape -W576 -H512 -s60 -n4
|
||||
invoke> a fantastic alien landscape -W576 -H512 -s60 -n4
|
||||
|
||||
*postprocessing*
|
||||
!fix applies upscaling/facefixing to a previously-generated image.
|
||||
dream> !fix 0000045.4829112.png -G1 -U4 -ft codeformer
|
||||
invoke> !fix 0000045.4829112.png -G1 -U4 -ft codeformer
|
||||
|
||||
*History manipulation*
|
||||
!fetch retrieves the command used to generate an earlier image.
|
||||
dream> !fetch 0000015.8929913.png
|
||||
dream> a fantastic alien landscape -W 576 -H 512 -s 60 -A plms -C 7.5
|
||||
invoke> !fetch 0000015.8929913.png
|
||||
invoke> a fantastic alien landscape -W 576 -H 512 -s 60 -A plms -C 7.5
|
||||
|
||||
!history lists all the commands issued during the current session.
|
||||
|
||||
@@ -605,6 +650,21 @@ class Args(object):
|
||||
dest='hires_fix',
|
||||
help='Create hires image using img2img to prevent duplicated objects'
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--save_intermediates',
|
||||
type=int,
|
||||
default=0,
|
||||
dest='save_intermediates',
|
||||
help='Save every nth intermediate image into an "intermediates" directory within the output directory'
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--png_compression','-z',
|
||||
type=int,
|
||||
default=6,
|
||||
choices=range(0,10),
|
||||
dest='png_compression',
|
||||
help='level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.'
|
||||
)
|
||||
img2img_group.add_argument(
|
||||
'-I',
|
||||
'--init_img',
|
||||
@@ -617,6 +677,14 @@ class Args(object):
|
||||
type=str,
|
||||
help='Path to input mask for inpainting mode (supersedes width and height)',
|
||||
)
|
||||
img2img_group.add_argument(
|
||||
'-tm',
|
||||
'--text_mask',
|
||||
nargs='+',
|
||||
type=str,
|
||||
help='Use the clipseg classifier to generate the mask area for inpainting. Provide a description of the area to mask ("a mug"), optionally followed by the confidence level threshold (0-1.0; defaults to 0.5).',
|
||||
default=None,
|
||||
)
|
||||
img2img_group.add_argument(
|
||||
'--init_color',
|
||||
type=str,
|
||||
@@ -652,6 +720,13 @@ class Args(object):
|
||||
metavar=('direction','pixels'),
|
||||
help='Outcrop the image with one or more direction/pixel pairs: -c top 64 bottom 128 left 64 right 64',
|
||||
)
|
||||
img2img_group.add_argument(
|
||||
'-r',
|
||||
'--inpaint_replace',
|
||||
type=float,
|
||||
default=0.0,
|
||||
help='when inpainting, adjust how aggressively to replace the part of the picture under the mask, from 0.0 (a gentle merge) to 1.0 (replace entirely)',
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'-ft',
|
||||
'--facetool',
|
||||
@@ -661,6 +736,7 @@ class Args(object):
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'-G',
|
||||
'--facetool_strength',
|
||||
'--gfpgan_strength',
|
||||
type=float,
|
||||
help='The strength at which to apply the face restoration to the result.',
|
||||
@@ -708,6 +784,12 @@ class Args(object):
|
||||
action='store_true',
|
||||
help='Change the model to seamless tiling (circular) mode',
|
||||
)
|
||||
special_effects_group.add_argument(
|
||||
'--seamless_axes',
|
||||
default=['x', 'y'],
|
||||
type=list[str],
|
||||
help='Specify which axes to use circular convolution on.',
|
||||
)
|
||||
variation_group.add_argument(
|
||||
'-v',
|
||||
'--variation_amount',
|
||||
@@ -757,7 +839,8 @@ def metadata_dumps(opt,
|
||||
|
||||
# remove any image keys not mentioned in RFC #266
|
||||
rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps',
|
||||
'cfg_scale','threshold','perlin','step_number','width','height','extra','strength']
|
||||
'cfg_scale','threshold','perlin','step_number','width','height','extra','strength',
|
||||
'init_img','init_mask']
|
||||
|
||||
rfc_dict ={}
|
||||
|
||||
@@ -778,11 +861,15 @@ def metadata_dumps(opt,
|
||||
# 'variations' should always exist and be an array, empty or consisting of {'seed': seed, 'weight': weight} pairs
|
||||
rfc_dict['variations'] = [{'seed':x[0],'weight':x[1]} for x in opt.with_variations] if opt.with_variations else []
|
||||
|
||||
# if variations are present then we need to replace 'seed' with 'orig_seed'
|
||||
if hasattr(opt,'first_seed'):
|
||||
rfc_dict['seed'] = opt.first_seed
|
||||
|
||||
if opt.init_img:
|
||||
rfc_dict['type'] = 'img2img'
|
||||
rfc_dict['strength_steps'] = rfc_dict.pop('strength')
|
||||
rfc_dict['orig_hash'] = calculate_init_img_hash(opt.init_img)
|
||||
rfc_dict['sampler'] = 'ddim' # TODO: FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS
|
||||
rfc_dict['type'] = 'img2img'
|
||||
rfc_dict['strength_steps'] = rfc_dict.pop('strength')
|
||||
rfc_dict['orig_hash'] = calculate_init_img_hash(opt.init_img)
|
||||
rfc_dict['inpaint_replace'] = opt.inpaint_replace
|
||||
else:
|
||||
rfc_dict['type'] = 'txt2img'
|
||||
rfc_dict.pop('strength')
|
||||
@@ -811,7 +898,7 @@ def metadata_from_png(png_file_path) -> Args:
|
||||
an Args object containing the image metadata. Note that this
|
||||
returns a single Args object, not multiple.
|
||||
'''
|
||||
meta = ldm.dream.pngwriter.retrieve_metadata(png_file_path)
|
||||
meta = ldm.invoke.pngwriter.retrieve_metadata(png_file_path)
|
||||
if 'sd-metadata' in meta and len(meta['sd-metadata'])>0 :
|
||||
return metadata_loads(meta)[0]
|
||||
else:
|
||||
4
ldm/invoke/generator/__init__.py
Normal file
4
ldm/invoke/generator/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
'''
|
||||
Initialization file for the ldm.invoke.generator package
|
||||
'''
|
||||
from .base import Generator
|
||||
@@ -1,15 +1,16 @@
|
||||
'''
|
||||
Base class for ldm.dream.generator.*
|
||||
Base class for ldm.invoke.generator.*
|
||||
including img2img, txt2img, and inpaint
|
||||
'''
|
||||
import torch
|
||||
import numpy as np
|
||||
import random
|
||||
import os
|
||||
from tqdm import tqdm, trange
|
||||
from PIL import Image
|
||||
from einops import rearrange, repeat
|
||||
from pytorch_lightning import seed_everything
|
||||
from ldm.dream.devices import choose_autocast
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.util import rand_perlin_2d
|
||||
|
||||
downsampling = 8
|
||||
@@ -21,6 +22,8 @@ class Generator():
|
||||
self.seed = None
|
||||
self.latent_channels = model.channels
|
||||
self.downsampling_factor = downsampling # BUG: should come from model or config
|
||||
self.perlin = 0.0
|
||||
self.threshold = 0
|
||||
self.variation_amount = 0
|
||||
self.with_variations = []
|
||||
|
||||
@@ -122,8 +125,8 @@ class Generator():
|
||||
raise NotImplementedError("get_noise() must be implemented in a descendent class")
|
||||
|
||||
def get_perlin_noise(self,width,height):
|
||||
return torch.stack([rand_perlin_2d((height, width), (8, 8)).to(self.model.device) for _ in range(self.latent_channels)], dim=0)
|
||||
|
||||
fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device
|
||||
return torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device)
|
||||
|
||||
def new_seed(self):
|
||||
self.seed = random.randrange(0, np.iinfo(np.uint32).max)
|
||||
@@ -166,3 +169,14 @@ class Generator():
|
||||
|
||||
return v2
|
||||
|
||||
# this is a handy routine for debugging use. Given a generated sample,
|
||||
# convert it into a PNG image and store it at the indicated path
|
||||
def save_sample(self, sample, filepath):
|
||||
image = self.sample_to_image(sample)
|
||||
dirname = os.path.dirname(filepath) or '.'
|
||||
if not os.path.exists(dirname):
|
||||
print(f'** creating directory {dirname}')
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
image.save(filepath,'PNG')
|
||||
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
'''
|
||||
ldm.dream.generator.embiggen descends from ldm.dream.generator
|
||||
and generates with ldm.dream.generator.img2img
|
||||
ldm.invoke.generator.embiggen descends from ldm.invoke.generator
|
||||
and generates with ldm.invoke.generator.img2img
|
||||
'''
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import trange
|
||||
from PIL import Image
|
||||
from ldm.dream.generator.base import Generator
|
||||
from ldm.dream.generator.img2img import Img2Img
|
||||
from ldm.dream.devices import choose_autocast
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.invoke.generator.img2img import Img2Img
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
|
||||
class Embiggen(Generator):
|
||||
@@ -107,7 +107,7 @@ class Embiggen(Generator):
|
||||
initsuperwidth = round(initsuperwidth*embiggen[0])
|
||||
initsuperheight = round(initsuperheight*embiggen[0])
|
||||
if embiggen[1] > 0: # No point in ESRGAN upscaling if strength is set zero
|
||||
from ldm.dream.restoration.realesrgan import ESRGAN
|
||||
from ldm.invoke.restoration.realesrgan import ESRGAN
|
||||
esrgan = ESRGAN()
|
||||
print(
|
||||
f'>> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}')
|
||||
@@ -1,11 +1,11 @@
|
||||
'''
|
||||
ldm.dream.generator.img2img descends from ldm.dream.generator
|
||||
ldm.invoke.generator.img2img descends from ldm.invoke.generator
|
||||
'''
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from ldm.dream.devices import choose_autocast
|
||||
from ldm.dream.generator.base import Generator
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
|
||||
class Img2Img(Generator):
|
||||
@@ -49,6 +49,7 @@ class Img2Img(Generator):
|
||||
img_callback = step_callback,
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=uc,
|
||||
init_latent = self.init_latent, # changes how noising is performed in ksampler
|
||||
)
|
||||
|
||||
return self.sample_to_image(samples)
|
||||
@@ -1,12 +1,12 @@
|
||||
'''
|
||||
ldm.dream.generator.inpaint descends from ldm.dream.generator
|
||||
ldm.invoke.generator.inpaint descends from ldm.invoke.generator
|
||||
'''
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from einops import rearrange, repeat
|
||||
from ldm.dream.devices import choose_autocast
|
||||
from ldm.dream.generator.img2img import Img2Img
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.invoke.generator.img2img import Img2Img
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.ksampler import KSampler
|
||||
|
||||
@@ -18,7 +18,7 @@ class Inpaint(Img2Img):
|
||||
@torch.no_grad()
|
||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,init_image,mask_image,strength,
|
||||
step_callback=None,**kwargs):
|
||||
step_callback=None,inpaint_replace=False,**kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and
|
||||
the initial image + mask. Return value depends on the seed at
|
||||
@@ -27,7 +27,7 @@ class Inpaint(Img2Img):
|
||||
# klms samplers not supported yet, so ignore previous sampler
|
||||
if isinstance(sampler,KSampler):
|
||||
print(
|
||||
f">> sampler '{sampler.__class__.__name__}' is not yet supported for inpainting, using DDIMSampler instead."
|
||||
f">> Using recommended DDIM sampler for inpainting."
|
||||
)
|
||||
sampler = DDIMSampler(self.model, device=self.model.device)
|
||||
|
||||
@@ -58,6 +58,14 @@ class Inpaint(Img2Img):
|
||||
noise=x_T
|
||||
)
|
||||
|
||||
# to replace masked area with latent noise, weighted by inpaint_replace strength
|
||||
if inpaint_replace > 0.0:
|
||||
print(f'>> inpaint will replace what was under the mask with a strength of {inpaint_replace}')
|
||||
l_noise = self.get_noise(kwargs['width'],kwargs['height'])
|
||||
inverted_mask = 1.0-mask_image # there will be 1s where the mask is
|
||||
masked_region = (1.0-inpaint_replace) * inverted_mask * z_enc + inpaint_replace * inverted_mask * l_noise
|
||||
z_enc = z_enc * mask_image + masked_region
|
||||
|
||||
# decode it
|
||||
samples = sampler.decode(
|
||||
z_enc,
|
||||
@@ -1,10 +1,10 @@
|
||||
'''
|
||||
ldm.dream.generator.txt2img inherits from ldm.dream.generator
|
||||
ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
|
||||
'''
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from ldm.dream.generator.base import Generator
|
||||
from ldm.invoke.generator.base import Generator
|
||||
|
||||
class Txt2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
@@ -74,3 +74,4 @@ class Txt2Img(Generator):
|
||||
if self.perlin > 0.0:
|
||||
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
|
||||
return x
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
'''
|
||||
ldm.dream.generator.txt2img inherits from ldm.dream.generator
|
||||
ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
|
||||
'''
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import math
|
||||
from ldm.dream.generator.base import Generator
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ class Txt2Img2Img(Generator):
|
||||
)
|
||||
|
||||
print(
|
||||
f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height}"
|
||||
f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
|
||||
)
|
||||
|
||||
# resizing
|
||||
@@ -75,17 +75,19 @@ class Txt2Img2Img(Generator):
|
||||
)
|
||||
|
||||
t_enc = int(strength * steps)
|
||||
ddim_sampler = DDIMSampler(self.model, device=self.model.device)
|
||||
ddim_sampler.make_schedule(
|
||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||
)
|
||||
|
||||
x = self.get_noise(width,height,False)
|
||||
|
||||
z_enc = sampler.stochastic_encode(
|
||||
z_enc = ddim_sampler.stochastic_encode(
|
||||
samples,
|
||||
torch.tensor([t_enc]).to(self.model.device),
|
||||
noise=x
|
||||
noise=self.get_noise(width,height,False)
|
||||
)
|
||||
|
||||
# decode it
|
||||
samples = sampler.decode(
|
||||
samples = ddim_sampler.decode(
|
||||
z_enc,
|
||||
c,
|
||||
t_enc,
|
||||
281
ldm/invoke/model_cache.py
Normal file
281
ldm/invoke/model_cache.py
Normal file
@@ -0,0 +1,281 @@
|
||||
'''
|
||||
Manage a cache of Stable Diffusion model files for fast switching.
|
||||
They are moved between GPU and CPU as necessary. If CPU memory falls
|
||||
below a preset minimum, the least recently used model will be
|
||||
cleared and loaded from disk when next needed.
|
||||
'''
|
||||
|
||||
import torch
|
||||
import os
|
||||
import io
|
||||
import time
|
||||
import gc
|
||||
import hashlib
|
||||
import psutil
|
||||
import transformers
|
||||
from sys import getrefcount
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.errors import ConfigAttributeError
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
GIGS=2**30
|
||||
AVG_MODEL_SIZE=2.1*GIGS
|
||||
DEFAULT_MIN_AVAIL=2*GIGS
|
||||
|
||||
class ModelCache(object):
|
||||
def __init__(self, config:OmegaConf, device_type:str, precision:str, min_avail_mem=DEFAULT_MIN_AVAIL):
|
||||
'''
|
||||
Initialize with the path to the models.yaml config file,
|
||||
the torch device type, and precision. The optional
|
||||
min_avail_mem argument specifies how much unused system
|
||||
(CPU) memory to preserve. The cache of models in RAM will
|
||||
grow until this value is approached. Default is 2G.
|
||||
'''
|
||||
# prevent nasty-looking CLIP log message
|
||||
transformers.logging.set_verbosity_error()
|
||||
self.config = config
|
||||
self.precision = precision
|
||||
self.device = torch.device(device_type)
|
||||
self.min_avail_mem = min_avail_mem
|
||||
self.models = {}
|
||||
self.stack = [] # this is an LRU FIFO
|
||||
self.current_model = None
|
||||
|
||||
def get_model(self, model_name:str):
|
||||
'''
|
||||
Given a model named identified in models.yaml, return
|
||||
the model object. If in RAM will load into GPU VRAM.
|
||||
If on disk, will load from there.
|
||||
'''
|
||||
if model_name not in self.config:
|
||||
print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
|
||||
return None
|
||||
|
||||
if self.current_model != model_name:
|
||||
self.unload_model(self.current_model)
|
||||
|
||||
if model_name in self.models:
|
||||
requested_model = self.models[model_name]['model']
|
||||
print(f'>> Retrieving model {model_name} from system RAM cache')
|
||||
self.models[model_name]['model'] = self._model_from_cpu(requested_model)
|
||||
width = self.models[model_name]['width']
|
||||
height = self.models[model_name]['height']
|
||||
hash = self.models[model_name]['hash']
|
||||
else:
|
||||
self._check_memory()
|
||||
try:
|
||||
requested_model, width, height, hash = self._load_model(model_name)
|
||||
self.models[model_name] = {}
|
||||
self.models[model_name]['model'] = requested_model
|
||||
self.models[model_name]['width'] = width
|
||||
self.models[model_name]['height'] = height
|
||||
self.models[model_name]['hash'] = hash
|
||||
except Exception as e:
|
||||
print(f'** model {model_name} could not be loaded: {str(e)}')
|
||||
print(f'** restoring {self.current_model}')
|
||||
return self.get_model(self.current_model)
|
||||
|
||||
self.current_model = model_name
|
||||
self._push_newest_model(model_name)
|
||||
return {
|
||||
'model':requested_model,
|
||||
'width':width,
|
||||
'height':height,
|
||||
'hash': hash
|
||||
}
|
||||
|
||||
def list_models(self) -> dict:
|
||||
'''
|
||||
Return a dict of models in the format:
|
||||
{ model_name1: {'status': ('active'|'cached'|'not loaded'),
|
||||
'description': description,
|
||||
},
|
||||
model_name2: { etc }
|
||||
'''
|
||||
result = {}
|
||||
for name in self.config:
|
||||
try:
|
||||
description = self.config[name].description
|
||||
except ConfigAttributeError:
|
||||
description = '<no description>'
|
||||
if self.current_model == name:
|
||||
status = 'active'
|
||||
elif name in self.models:
|
||||
status = 'cached'
|
||||
else:
|
||||
status = 'not loaded'
|
||||
result[name]={}
|
||||
result[name]['status']=status
|
||||
result[name]['description']=description
|
||||
return result
|
||||
|
||||
def print_models(self):
|
||||
'''
|
||||
Print a table of models, their descriptions, and load status
|
||||
'''
|
||||
models = self.list_models()
|
||||
for name in models:
|
||||
line = f'{name:25s} {models[name]["status"]:>10s} {models[name]["description"]}'
|
||||
if models[name]['status'] == 'active':
|
||||
print(f'\033[1m{line}\033[0m')
|
||||
else:
|
||||
print(line)
|
||||
|
||||
def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->str:
|
||||
'''
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
On a successful update, the config will be changed in memory and a YAML
|
||||
string will be returned.
|
||||
'''
|
||||
omega = self.config
|
||||
# check that all the required fields are present
|
||||
for field in ('description','weights','height','width','config'):
|
||||
assert field in model_attributes, f'required field {field} is missing'
|
||||
|
||||
assert (clobber or model_name not in omega), f'attempt to overwrite existing model definition "{model_name}"'
|
||||
config = omega[model_name] if model_name in omega else {}
|
||||
for field in model_attributes:
|
||||
config[field] = model_attributes[field]
|
||||
|
||||
omega[model_name] = config
|
||||
return OmegaConf.to_yaml(omega)
|
||||
|
||||
def _check_memory(self):
|
||||
avail_memory = psutil.virtual_memory()[1]
|
||||
if AVG_MODEL_SIZE + self.min_avail_mem > avail_memory:
|
||||
least_recent_model = self._pop_oldest_model()
|
||||
if least_recent_model is not None:
|
||||
del self.models[least_recent_model]
|
||||
gc.collect()
|
||||
|
||||
|
||||
def _load_model(self, model_name:str):
|
||||
"""Load and initialize the model from configuration variables passed at object creation time"""
|
||||
if model_name not in self.config:
|
||||
print(f'"{model_name}" is not a known model name. Please check your models.yaml file')
|
||||
return None
|
||||
|
||||
mconfig = self.config[model_name]
|
||||
config = mconfig.config
|
||||
weights = mconfig.weights
|
||||
width = mconfig.width
|
||||
height = mconfig.height
|
||||
|
||||
print(f'>> Loading {model_name} from {weights}')
|
||||
|
||||
# for usage statistics
|
||||
if self._has_cuda():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
tic = time.time()
|
||||
|
||||
# this does the work
|
||||
c = OmegaConf.load(config)
|
||||
with open(weights,'rb') as f:
|
||||
weight_bytes = f.read()
|
||||
model_hash = self._cached_sha256(weights,weight_bytes)
|
||||
pl_sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
|
||||
del weight_bytes
|
||||
sd = pl_sd['state_dict']
|
||||
model = instantiate_from_config(c.model)
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
|
||||
if self.precision == 'float16':
|
||||
print(' | Using faster float16 precision')
|
||||
model.to(torch.float16)
|
||||
else:
|
||||
print(' | Using more accurate float32 precision')
|
||||
|
||||
model.to(self.device)
|
||||
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
|
||||
model.cond_stage_model.device = self.device
|
||||
model.eval()
|
||||
|
||||
for m in model.modules():
|
||||
if isinstance(m, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
|
||||
m._orig_padding_mode = m.padding_mode
|
||||
|
||||
# usage statistics
|
||||
toc = time.time()
|
||||
print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
|
||||
if self._has_cuda():
|
||||
print(
|
||||
'>> Max VRAM used to load the model:',
|
||||
'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
|
||||
'\n>> Current VRAM usage:'
|
||||
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
|
||||
)
|
||||
return model, width, height, model_hash
|
||||
|
||||
def unload_model(self, model_name:str):
|
||||
if model_name not in self.models:
|
||||
return
|
||||
print(f'>> Caching model {model_name} in system RAM')
|
||||
model = self.models[model_name]['model']
|
||||
self.models[model_name]['model'] = self._model_to_cpu(model)
|
||||
gc.collect()
|
||||
if self._has_cuda():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def _model_to_cpu(self,model):
|
||||
if self.device != 'cpu':
|
||||
model.cond_stage_model.device = 'cpu'
|
||||
model.first_stage_model.to('cpu')
|
||||
model.cond_stage_model.to('cpu')
|
||||
model.model.to('cpu')
|
||||
return model.to('cpu')
|
||||
else:
|
||||
return model
|
||||
|
||||
def _model_from_cpu(self,model):
|
||||
if self.device != 'cpu':
|
||||
model.to(self.device)
|
||||
model.first_stage_model.to(self.device)
|
||||
model.cond_stage_model.to(self.device)
|
||||
model.cond_stage_model.device = self.device
|
||||
return model
|
||||
|
||||
def _pop_oldest_model(self):
|
||||
'''
|
||||
Remove the first element of the FIFO, which ought
|
||||
to be the least recently accessed model. Do not
|
||||
pop the last one, because it is in active use!
|
||||
'''
|
||||
if len(self.stack) > 1:
|
||||
return self.stack.pop(0)
|
||||
|
||||
def _push_newest_model(self,model_name:str):
|
||||
'''
|
||||
Maintain a simple FIFO. First element is always the
|
||||
least recent, and last element is always the most recent.
|
||||
'''
|
||||
try:
|
||||
self.stack.remove(model_name)
|
||||
except ValueError:
|
||||
pass
|
||||
self.stack.append(model_name)
|
||||
|
||||
def _has_cuda(self):
|
||||
return self.device.type == 'cuda'
|
||||
|
||||
def _cached_sha256(self,path,data):
|
||||
dirname = os.path.dirname(path)
|
||||
basename = os.path.basename(path)
|
||||
base, _ = os.path.splitext(basename)
|
||||
hashpath = os.path.join(dirname,base+'.sha256')
|
||||
if os.path.exists(hashpath) and os.path.getmtime(path) <= os.path.getmtime(hashpath):
|
||||
with open(hashpath) as f:
|
||||
hash = f.read()
|
||||
return hash
|
||||
print(f'>> Calculating sha256 hash of weights file')
|
||||
tic = time.time()
|
||||
sha = hashlib.sha256()
|
||||
sha.update(data)
|
||||
hash = sha.hexdigest()
|
||||
toc = time.time()
|
||||
print(f'>> sha256 = {hash}','(%4.2fs)' % (toc - tic))
|
||||
with open(hashpath,'w') as f:
|
||||
f.write(hash)
|
||||
return hash
|
||||
@@ -33,13 +33,13 @@ class PngWriter:
|
||||
|
||||
# saves image named _image_ to outdir/name, writing metadata from prompt
|
||||
# returns full path of output
|
||||
def save_image_and_prompt_to_png(self, image, dream_prompt, name, metadata=None):
|
||||
def save_image_and_prompt_to_png(self, image, dream_prompt, name, metadata=None, compress_level=6):
|
||||
path = os.path.join(self.outdir, name)
|
||||
info = PngImagePlugin.PngInfo()
|
||||
info.add_text('Dream', dream_prompt)
|
||||
if metadata:
|
||||
info.add_text('sd-metadata', json.dumps(metadata))
|
||||
image.save(path, 'PNG', pnginfo=info)
|
||||
image.save(path, 'PNG', pnginfo=info, compress_level=compress_level)
|
||||
return path
|
||||
|
||||
def retrieve_metadata(self,img_basename):
|
||||
@@ -66,3 +66,43 @@ def write_metadata(img_path:str, meta:dict):
|
||||
info = PngImagePlugin.PngInfo()
|
||||
info.add_text('sd-metadata', json.dumps(meta))
|
||||
im.save(img_path,'PNG',pnginfo=info)
|
||||
|
||||
class PromptFormatter:
|
||||
def __init__(self, t2i, opt):
|
||||
self.t2i = t2i
|
||||
self.opt = opt
|
||||
|
||||
# note: the t2i object should provide all these values.
|
||||
# there should be no need to or against opt values
|
||||
def normalize_prompt(self):
|
||||
"""Normalize the prompt and switches"""
|
||||
t2i = self.t2i
|
||||
opt = self.opt
|
||||
|
||||
switches = list()
|
||||
switches.append(f'"{opt.prompt}"')
|
||||
switches.append(f'-s{opt.steps or t2i.steps}')
|
||||
switches.append(f'-W{opt.width or t2i.width}')
|
||||
switches.append(f'-H{opt.height or t2i.height}')
|
||||
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
|
||||
switches.append(f'-A{opt.sampler_name or t2i.sampler_name}')
|
||||
# to do: put model name into the t2i object
|
||||
# switches.append(f'--model{t2i.model_name}')
|
||||
if opt.seamless or t2i.seamless:
|
||||
switches.append(f'--seamless')
|
||||
if opt.init_img:
|
||||
switches.append(f'-I{opt.init_img}')
|
||||
if opt.fit:
|
||||
switches.append(f'--fit')
|
||||
if opt.strength and opt.init_img is not None:
|
||||
switches.append(f'-f{opt.strength or t2i.strength}')
|
||||
if opt.gfpgan_strength:
|
||||
switches.append(f'-G{opt.gfpgan_strength}')
|
||||
if opt.upscale:
|
||||
switches.append(f'-U {" ".join([str(u) for u in opt.upscale])}')
|
||||
if opt.variation_amount > 0:
|
||||
switches.append(f'-v{opt.variation_amount}')
|
||||
if opt.with_variations:
|
||||
formatted_variations = ','.join(f'{seed}:{weight}' for seed, weight in opt.with_variations)
|
||||
switches.append(f'-V{formatted_variations}')
|
||||
return ' '.join(switches)
|
||||
@@ -1,17 +1,17 @@
|
||||
"""
|
||||
Readline helper functions for dream.py (linux and mac only).
|
||||
Readline helper functions for invoke.py.
|
||||
You may import the global singleton `completer` to get access to the
|
||||
completer object itself. This is useful when you want to autocomplete
|
||||
seeds:
|
||||
|
||||
from ldm.dream.readline import completer
|
||||
from ldm.invoke.readline import completer
|
||||
completer.add_seed(18247566)
|
||||
completer.add_seed(9281839)
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
import atexit
|
||||
from ldm.dream.args import Args
|
||||
from ldm.invoke.args import Args
|
||||
|
||||
# ---------------readline utilities---------------------
|
||||
try:
|
||||
@@ -20,7 +20,9 @@ try:
|
||||
except (ImportError,ModuleNotFoundError):
|
||||
readline_available = False
|
||||
|
||||
IMG_EXTENSIONS = ('.png','.jpg','.jpeg')
|
||||
IMG_EXTENSIONS = ('.png','.jpg','.jpeg','.PNG','.JPG','.JPEG','.gif','.GIF')
|
||||
WEIGHT_EXTENSIONS = ('.ckpt','.bae')
|
||||
CONFIG_EXTENSIONS = ('.yaml','.yml')
|
||||
COMMANDS = (
|
||||
'--steps','-s',
|
||||
'--seed','-S',
|
||||
@@ -31,6 +33,7 @@ COMMANDS = (
|
||||
'--perlin',
|
||||
'--grid','-g',
|
||||
'--individual','-i',
|
||||
'--save_intermediates',
|
||||
'--init_img','-I',
|
||||
'--init_mask','-M',
|
||||
'--init_color',
|
||||
@@ -41,13 +44,27 @@ COMMANDS = (
|
||||
'--embedding_path',
|
||||
'--device',
|
||||
'--grid','-g',
|
||||
'--gfpgan_strength','-G',
|
||||
'--facetool','-ft',
|
||||
'--facetool_strength','-G',
|
||||
'--codeformer_fidelity','-cf',
|
||||
'--upscale','-U',
|
||||
'-save_orig','--save_original',
|
||||
'--skip_normalize','-x',
|
||||
'--log_tokenization','-t',
|
||||
'--hires_fix',
|
||||
'--inpaint_replace','-r',
|
||||
'--png_compression','-z',
|
||||
'--text_mask','-tm',
|
||||
'!fix','!fetch','!history','!search','!clear',
|
||||
'!mask',
|
||||
'!models','!switch','!import_model','!edit_model'
|
||||
)
|
||||
MODEL_COMMANDS = (
|
||||
'!switch',
|
||||
'!edit_model',
|
||||
)
|
||||
WEIGHT_COMMANDS = (
|
||||
'!import_model',
|
||||
)
|
||||
IMG_PATH_COMMANDS = (
|
||||
'--outdir[=\s]',
|
||||
@@ -55,32 +72,42 @@ IMG_PATH_COMMANDS = (
|
||||
IMG_FILE_COMMANDS=(
|
||||
'!fix',
|
||||
'!fetch',
|
||||
'!mask',
|
||||
'--init_img[=\s]','-I',
|
||||
'--init_mask[=\s]','-M',
|
||||
'--init_color[=\s]',
|
||||
'--embedding_path[=\s]',
|
||||
)
|
||||
path_regexp = '('+'|'.join(IMG_PATH_COMMANDS+IMG_FILE_COMMANDS) + ')\s*\S*$'
|
||||
path_regexp = '('+'|'.join(IMG_PATH_COMMANDS+IMG_FILE_COMMANDS) + ')\s*\S*$'
|
||||
weight_regexp = '('+'|'.join(WEIGHT_COMMANDS) + ')\s*\S*$'
|
||||
|
||||
class Completer(object):
|
||||
def __init__(self, options):
|
||||
def __init__(self, options, models=[]):
|
||||
self.options = sorted(options)
|
||||
self.models = sorted(models)
|
||||
self.seeds = set()
|
||||
self.matches = list()
|
||||
self.default_dir = None
|
||||
self.linebuffer = None
|
||||
self.auto_history_active = True
|
||||
self.extensions = None
|
||||
return
|
||||
|
||||
def complete(self, text, state):
|
||||
'''
|
||||
Completes dream command line.
|
||||
Completes invoke command line.
|
||||
BUG: it doesn't correctly complete files that have spaces in the name.
|
||||
'''
|
||||
buffer = readline.get_line_buffer()
|
||||
|
||||
if state == 0:
|
||||
if re.search(path_regexp,buffer):
|
||||
|
||||
# extensions defined, so go directly into path completion mode
|
||||
if self.extensions is not None:
|
||||
self.matches = self._path_completions(text, state, self.extensions)
|
||||
|
||||
# looking for an image file
|
||||
elif re.search(path_regexp,buffer):
|
||||
do_shortcut = re.search('^'+'|'.join(IMG_FILE_COMMANDS),buffer)
|
||||
self.matches = self._path_completions(text, state, IMG_EXTENSIONS,shortcut_ok=do_shortcut)
|
||||
|
||||
@@ -88,6 +115,13 @@ class Completer(object):
|
||||
elif re.search('(-S\s*|--seed[=\s])\d*$',buffer):
|
||||
self.matches= self._seed_completions(text,state)
|
||||
|
||||
# looking for a model
|
||||
elif re.match('^'+'|'.join(MODEL_COMMANDS),buffer):
|
||||
self.matches= self._model_completions(text, state)
|
||||
|
||||
elif re.search(weight_regexp,buffer):
|
||||
self.matches = self._path_completions(text, state, WEIGHT_EXTENSIONS)
|
||||
|
||||
# This is the first time for this text, so build a match list.
|
||||
elif text:
|
||||
self.matches = [
|
||||
@@ -104,6 +138,13 @@ class Completer(object):
|
||||
response = None
|
||||
return response
|
||||
|
||||
def complete_extensions(self, extensions:list):
|
||||
'''
|
||||
If called with a list of extensions, will force completer
|
||||
to do file path completions.
|
||||
'''
|
||||
self.extensions=extensions
|
||||
|
||||
def add_history(self,line):
|
||||
'''
|
||||
Pass thru to readline
|
||||
@@ -188,6 +229,21 @@ class Completer(object):
|
||||
matches.sort()
|
||||
return matches
|
||||
|
||||
def _model_completions(self, text, state):
|
||||
m = re.search('(!switch\s+)(\w*)',text)
|
||||
if m:
|
||||
switch = m.groups()[0]
|
||||
partial = m.groups()[1]
|
||||
else:
|
||||
switch = ''
|
||||
partial = text
|
||||
matches = list()
|
||||
for s in self.models:
|
||||
if s.startswith(partial):
|
||||
matches.append(switch+s)
|
||||
matches.sort()
|
||||
return matches
|
||||
|
||||
def _pre_input_hook(self):
|
||||
if self.linebuffer:
|
||||
readline.insert_text(self.linebuffer)
|
||||
@@ -266,9 +322,9 @@ class DummyCompleter(Completer):
|
||||
def set_line(self,line):
|
||||
print(f'# {line}')
|
||||
|
||||
def get_completer(opt:Args)->Completer:
|
||||
def get_completer(opt:Args, models=[])->Completer:
|
||||
if readline_available:
|
||||
completer = Completer(COMMANDS)
|
||||
completer = Completer(COMMANDS,models)
|
||||
|
||||
readline.set_completer(
|
||||
completer.complete
|
||||
@@ -287,7 +343,7 @@ def get_completer(opt:Args)->Completer:
|
||||
readline.parse_and_bind('set skip-completed-text on')
|
||||
readline.parse_and_bind('set show-all-if-ambiguous on')
|
||||
|
||||
histfile = os.path.join(os.path.expanduser(opt.outdir), '.dream_history')
|
||||
histfile = os.path.join(os.path.expanduser(opt.outdir), '.invoke_history')
|
||||
try:
|
||||
readline.read_history_file(histfile)
|
||||
readline.set_history_length(1000)
|
||||
4
ldm/invoke/restoration/__init__.py
Normal file
4
ldm/invoke/restoration/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
'''
|
||||
Initialization file for the ldm.invoke.restoration package
|
||||
'''
|
||||
from .base import Restoration
|
||||
@@ -23,16 +23,16 @@ class Restoration():
|
||||
|
||||
# Face Restore Models
|
||||
def load_gfpgan(self, gfpgan_dir, gfpgan_model_path):
|
||||
from ldm.dream.restoration.gfpgan import GFPGAN
|
||||
from ldm.invoke.restoration.gfpgan import GFPGAN
|
||||
return GFPGAN(gfpgan_dir, gfpgan_model_path)
|
||||
|
||||
def load_codeformer(self):
|
||||
from ldm.dream.restoration.codeformer import CodeFormerRestoration
|
||||
from ldm.invoke.restoration.codeformer import CodeFormerRestoration
|
||||
return CodeFormerRestoration()
|
||||
|
||||
# Upscale Models
|
||||
def load_esrgan(self, esrgan_bg_tile=400):
|
||||
from ldm.dream.restoration.realesrgan import ESRGAN
|
||||
from ldm.invoke.restoration.realesrgan import ESRGAN
|
||||
esrgan = ESRGAN(esrgan_bg_tile)
|
||||
print('>> ESRGAN Initialized')
|
||||
return esrgan;
|
||||
@@ -8,7 +8,7 @@ pretrained_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v
|
||||
|
||||
class CodeFormerRestoration():
|
||||
def __init__(self,
|
||||
codeformer_dir='ldm/dream/restoration/codeformer',
|
||||
codeformer_dir='ldm/invoke/restoration/codeformer',
|
||||
codeformer_model_path='weights/codeformer.pth') -> None:
|
||||
self.model_path = os.path.join(codeformer_dir, codeformer_model_path)
|
||||
self.codeformer_model_exists = os.path.isfile(self.model_path)
|
||||
@@ -27,7 +27,7 @@ class CodeFormerRestoration():
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from basicsr.utils import img2tensor, tensor2img
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
from ldm.dream.restoration.codeformer_arch import CodeFormer
|
||||
from ldm.invoke.restoration.codeformer_arch import CodeFormer
|
||||
from torchvision.transforms.functional import normalize
|
||||
from PIL import Image
|
||||
|
||||
@@ -35,16 +35,18 @@ class CodeFormerRestoration():
|
||||
|
||||
cf = cf_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(device)
|
||||
|
||||
checkpoint_path = load_file_from_url(url=pretrained_model_url, model_dir=os.path.abspath('ldm/dream/restoration/codeformer/weights'), progress=True)
|
||||
checkpoint_path = load_file_from_url(url=pretrained_model_url, model_dir=os.path.abspath('ldm/invoke/restoration/codeformer/weights'), progress=True)
|
||||
checkpoint = torch.load(checkpoint_path)['params_ema']
|
||||
cf.load_state_dict(checkpoint)
|
||||
cf.eval()
|
||||
|
||||
image = image.convert('RGB')
|
||||
# Codeformer expects a BGR np array; make array and flip channels
|
||||
bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1]
|
||||
|
||||
face_helper = FaceRestoreHelper(upscale_factor=1, use_parse=True, device=device)
|
||||
face_helper.clean_all()
|
||||
face_helper.read_image(np.array(image, dtype=np.uint8))
|
||||
face_helper.read_image(bgr_image_array)
|
||||
face_helper.get_face_landmarks_5(resize=640, eye_dist_threshold=5)
|
||||
face_helper.align_warp_face()
|
||||
|
||||
@@ -71,7 +73,8 @@ class CodeFormerRestoration():
|
||||
|
||||
restored_img = face_helper.paste_faces_to_input_image()
|
||||
|
||||
res = Image.fromarray(restored_img)
|
||||
# Flip the channels back to RGB
|
||||
res = Image.fromarray(restored_img[...,::-1])
|
||||
|
||||
if strength < 1.0:
|
||||
# Resize the image to the new image if the sizes have changed
|
||||
@@ -5,7 +5,7 @@ from torch import nn, Tensor
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, List
|
||||
|
||||
from ldm.dream.restoration.vqgan_arch import *
|
||||
from ldm.invoke.restoration.vqgan_arch import *
|
||||
from basicsr.utils import get_root_logger
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
|
||||
@@ -55,13 +55,18 @@ class GFPGAN():
|
||||
|
||||
image = image.convert('RGB')
|
||||
|
||||
# GFPGAN expects a BGR np array; make array and flip channels
|
||||
bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1]
|
||||
|
||||
_, _, restored_img = self.gfpgan.enhance(
|
||||
np.array(image, dtype=np.uint8),
|
||||
bgr_image_array,
|
||||
has_aligned=False,
|
||||
only_center_face=False,
|
||||
paste_back=True,
|
||||
)
|
||||
res = Image.fromarray(restored_img)
|
||||
|
||||
# Flip the channels back to RGB
|
||||
res = Image.fromarray(restored_img[...,::-1])
|
||||
|
||||
if strength < 1.0:
|
||||
# Resize the image to the new image if the sizes have changed
|
||||
@@ -13,8 +13,6 @@ class Outpaint(object):
|
||||
seed = old_opt.seed
|
||||
prompt = old_opt.prompt
|
||||
|
||||
print(f'DEBUG: old seed={seed}, old prompt = {prompt}')
|
||||
|
||||
def wrapped_callback(img,seed,**kwargs):
|
||||
image_callback(img,seed,use_prefix=prefix,**kwargs)
|
||||
|
||||
@@ -60,14 +60,18 @@ class ESRGAN():
|
||||
print(
|
||||
f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x'
|
||||
)
|
||||
|
||||
# REALSRGAN expects a BGR np array; make array and flip channels
|
||||
bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1]
|
||||
|
||||
output, _ = upsampler.enhance(
|
||||
np.array(image, dtype=np.uint8),
|
||||
bgr_image_array,
|
||||
outscale=upsampler_scale,
|
||||
alpha_upsampler='realesrgan',
|
||||
)
|
||||
|
||||
res = Image.fromarray(output)
|
||||
# Flip the channels back to RGB
|
||||
res = Image.fromarray(output[...,::-1])
|
||||
|
||||
if strength < 1.0:
|
||||
# Resize the image to the new image if the sizes have changed
|
||||
30
ldm/invoke/seamless.py
Normal file
30
ldm/invoke/seamless.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import torch.nn as nn
|
||||
|
||||
def _conv_forward_asymmetric(self, input, weight, bias):
|
||||
"""
|
||||
Patch for Conv2d._conv_forward that supports asymmetric padding
|
||||
"""
|
||||
working = nn.functional.pad(input, self.asymmetric_padding['x'], mode=self.asymmetric_padding_mode['x'])
|
||||
working = nn.functional.pad(working, self.asymmetric_padding['y'], mode=self.asymmetric_padding_mode['y'])
|
||||
return nn.functional.conv2d(working, weight, bias, self.stride, nn.modules.utils._pair(0), self.dilation, self.groups)
|
||||
|
||||
def configure_model_padding(model, seamless, seamless_axes):
|
||||
"""
|
||||
Modifies the 2D convolution layers to use a circular padding mode based on the `seamless` and `seamless_axes` options.
|
||||
"""
|
||||
for m in model.modules():
|
||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||
if seamless:
|
||||
m.asymmetric_padding_mode = {}
|
||||
m.asymmetric_padding = {}
|
||||
m.asymmetric_padding_mode['x'] = 'circular' if ('x' in seamless_axes) else 'constant'
|
||||
m.asymmetric_padding['x'] = (m._reversed_padding_repeated_twice[0], m._reversed_padding_repeated_twice[1], 0, 0)
|
||||
m.asymmetric_padding_mode['y'] = 'circular' if ('y' in seamless_axes) else 'constant'
|
||||
m.asymmetric_padding['y'] = (0, 0, m._reversed_padding_repeated_twice[2], m._reversed_padding_repeated_twice[3])
|
||||
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
|
||||
else:
|
||||
m._conv_forward = nn.Conv2d._conv_forward.__get__(m, nn.Conv2d)
|
||||
if hasattr(m, 'asymmetric_padding_mode'):
|
||||
del m.asymmetric_padding_mode
|
||||
if hasattr(m, 'asymmetric_padding'):
|
||||
del m.asymmetric_padding
|
||||
@@ -4,9 +4,9 @@ import copy
|
||||
import base64
|
||||
import mimetypes
|
||||
import os
|
||||
from ldm.dream.args import Args, metadata_dumps
|
||||
from ldm.invoke.args import Args, metadata_dumps
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from ldm.dream.pngwriter import PngWriter
|
||||
from ldm.invoke.pngwriter import PngWriter
|
||||
from threading import Event
|
||||
|
||||
def build_opt(post_data, seed, gfpgan_model_exists):
|
||||
@@ -31,12 +31,13 @@ def build_opt(post_data, seed, gfpgan_model_exists):
|
||||
setattr(opt, 'embiggen', None)
|
||||
setattr(opt, 'embiggen_tiles', None)
|
||||
|
||||
setattr(opt, 'gfpgan_strength', float(post_data['gfpgan_strength']) if gfpgan_model_exists else 0)
|
||||
setattr(opt, 'facetool_strength', float(post_data['facetool_strength']) if gfpgan_model_exists else 0)
|
||||
setattr(opt, 'upscale', [int(post_data['upscale_level']), float(post_data['upscale_strength'])] if post_data['upscale_level'] != '' else None)
|
||||
setattr(opt, 'progress_images', 'progress_images' in post_data)
|
||||
setattr(opt, 'seed', None if int(post_data['seed']) == -1 else int(post_data['seed']))
|
||||
setattr(opt, 'threshold', float(post_data['threshold']))
|
||||
setattr(opt, 'perlin', float(post_data['perlin']))
|
||||
setattr(opt, 'hires_fix', 'hires_fix' in post_data)
|
||||
setattr(opt, 'variation_amount', float(post_data['variation_amount']) if int(post_data['seed']) != -1 else 0)
|
||||
setattr(opt, 'with_variations', [])
|
||||
setattr(opt, 'embiggen', None)
|
||||
@@ -196,7 +197,7 @@ class DreamServer(BaseHTTPRequestHandler):
|
||||
) + '\n',"utf-8"))
|
||||
|
||||
# control state of the "postprocessing..." message
|
||||
upscaling_requested = opt.upscale or opt.gfpgan_strength > 0
|
||||
upscaling_requested = opt.upscale or opt.facetool_strength > 0
|
||||
nonlocal images_generated # NB: Is this bad python style? It is typical usage in a perl closure.
|
||||
nonlocal images_upscaled # NB: Is this bad python style? It is typical usage in a perl closure.
|
||||
if upscaled:
|
||||
246
ldm/invoke/server_legacy.py
Normal file
246
ldm/invoke/server_legacy.py
Normal file
@@ -0,0 +1,246 @@
|
||||
import argparse
|
||||
import json
|
||||
import base64
|
||||
import mimetypes
|
||||
import os
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from ldm.invoke.pngwriter import PngWriter, PromptFormatter
|
||||
from threading import Event
|
||||
|
||||
def build_opt(post_data, seed, gfpgan_model_exists):
|
||||
opt = argparse.Namespace()
|
||||
setattr(opt, 'prompt', post_data['prompt'])
|
||||
setattr(opt, 'init_img', post_data['initimg'])
|
||||
setattr(opt, 'strength', float(post_data['strength']))
|
||||
setattr(opt, 'iterations', int(post_data['iterations']))
|
||||
setattr(opt, 'steps', int(post_data['steps']))
|
||||
setattr(opt, 'width', int(post_data['width']))
|
||||
setattr(opt, 'height', int(post_data['height']))
|
||||
setattr(opt, 'seamless', 'seamless' in post_data)
|
||||
setattr(opt, 'fit', 'fit' in post_data)
|
||||
setattr(opt, 'mask', 'mask' in post_data)
|
||||
setattr(opt, 'invert_mask', 'invert_mask' in post_data)
|
||||
setattr(opt, 'cfg_scale', float(post_data['cfg_scale']))
|
||||
setattr(opt, 'sampler_name', post_data['sampler_name'])
|
||||
setattr(opt, 'gfpgan_strength', float(post_data['gfpgan_strength']) if gfpgan_model_exists else 0)
|
||||
setattr(opt, 'upscale', [int(post_data['upscale_level']), float(post_data['upscale_strength'])] if post_data['upscale_level'] != '' else None)
|
||||
setattr(opt, 'progress_images', 'progress_images' in post_data)
|
||||
setattr(opt, 'seed', None if int(post_data['seed']) == -1 else int(post_data['seed']))
|
||||
setattr(opt, 'variation_amount', float(post_data['variation_amount']) if int(post_data['seed']) != -1 else 0)
|
||||
setattr(opt, 'with_variations', [])
|
||||
|
||||
broken = False
|
||||
if int(post_data['seed']) != -1 and post_data['with_variations'] != '':
|
||||
for part in post_data['with_variations'].split(','):
|
||||
seed_and_weight = part.split(':')
|
||||
if len(seed_and_weight) != 2:
|
||||
print(f'could not parse with_variation part "{part}"')
|
||||
broken = True
|
||||
break
|
||||
try:
|
||||
seed = int(seed_and_weight[0])
|
||||
weight = float(seed_and_weight[1])
|
||||
except ValueError:
|
||||
print(f'could not parse with_variation part "{part}"')
|
||||
broken = True
|
||||
break
|
||||
opt.with_variations.append([seed, weight])
|
||||
|
||||
if broken:
|
||||
raise CanceledException
|
||||
|
||||
if len(opt.with_variations) == 0:
|
||||
opt.with_variations = None
|
||||
|
||||
return opt
|
||||
|
||||
class CanceledException(Exception):
|
||||
pass
|
||||
|
||||
class DreamServer(BaseHTTPRequestHandler):
|
||||
model = None
|
||||
outdir = None
|
||||
canceled = Event()
|
||||
|
||||
def do_GET(self):
|
||||
if self.path == "/":
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "text/html")
|
||||
self.end_headers()
|
||||
with open("./static/dream_web/index.html", "rb") as content:
|
||||
self.wfile.write(content.read())
|
||||
elif self.path == "/config.js":
|
||||
# unfortunately this import can't be at the top level, since that would cause a circular import
|
||||
from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "application/javascript")
|
||||
self.end_headers()
|
||||
config = {
|
||||
'gfpgan_model_exists': gfpgan_model_exists
|
||||
}
|
||||
self.wfile.write(bytes("let config = " + json.dumps(config) + ";\n", "utf-8"))
|
||||
elif self.path == "/run_log.json":
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "application/json")
|
||||
self.end_headers()
|
||||
output = []
|
||||
|
||||
log_file = os.path.join(self.outdir, "dream_web_log.txt")
|
||||
if os.path.exists(log_file):
|
||||
with open(log_file, "r") as log:
|
||||
for line in log:
|
||||
url, config = line.split(": {", maxsplit=1)
|
||||
config = json.loads("{" + config)
|
||||
config["url"] = url.lstrip(".")
|
||||
if os.path.exists(url):
|
||||
output.append(config)
|
||||
|
||||
self.wfile.write(bytes(json.dumps({"run_log": output}), "utf-8"))
|
||||
elif self.path == "/cancel":
|
||||
self.canceled.set()
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "application/json")
|
||||
self.end_headers()
|
||||
self.wfile.write(bytes('{}', 'utf8'))
|
||||
else:
|
||||
path = "." + self.path
|
||||
cwd = os.path.realpath(os.getcwd())
|
||||
is_in_cwd = os.path.commonprefix((os.path.realpath(path), cwd)) == cwd
|
||||
if not (is_in_cwd and os.path.exists(path)):
|
||||
self.send_response(404)
|
||||
return
|
||||
mime_type = mimetypes.guess_type(path)[0]
|
||||
if mime_type is not None:
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", mime_type)
|
||||
self.end_headers()
|
||||
with open("." + self.path, "rb") as content:
|
||||
self.wfile.write(content.read())
|
||||
else:
|
||||
self.send_response(404)
|
||||
|
||||
def do_POST(self):
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "application/json")
|
||||
self.end_headers()
|
||||
|
||||
# unfortunately this import can't be at the top level, since that would cause a circular import
|
||||
# TODO temporarily commented out, import fails for some reason
|
||||
# from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
|
||||
gfpgan_model_exists = False
|
||||
|
||||
content_length = int(self.headers['Content-Length'])
|
||||
post_data = json.loads(self.rfile.read(content_length))
|
||||
opt = build_opt(post_data, self.model.seed, gfpgan_model_exists)
|
||||
|
||||
self.canceled.clear()
|
||||
print(f">> Request to generate with prompt: {opt.prompt}")
|
||||
# In order to handle upscaled images, the PngWriter needs to maintain state
|
||||
# across images generated by each call to prompt2img(), so we define it in
|
||||
# the outer scope of image_done()
|
||||
config = post_data.copy() # Shallow copy
|
||||
config['initimg'] = config.pop('initimg_name', '')
|
||||
|
||||
images_generated = 0 # helps keep track of when upscaling is started
|
||||
images_upscaled = 0 # helps keep track of when upscaling is completed
|
||||
pngwriter = PngWriter(self.outdir)
|
||||
|
||||
prefix = pngwriter.unique_prefix()
|
||||
# if upscaling is requested, then this will be called twice, once when
|
||||
# the images are first generated, and then again when after upscaling
|
||||
# is complete. The upscaling replaces the original file, so the second
|
||||
# entry should not be inserted into the image list.
|
||||
def image_done(image, seed, upscaled=False, first_seed=-1, use_prefix=None):
|
||||
print(f'First seed: {first_seed}')
|
||||
name = f'{prefix}.{seed}.png'
|
||||
iter_opt = argparse.Namespace(**vars(opt)) # copy
|
||||
if opt.variation_amount > 0:
|
||||
this_variation = [[seed, opt.variation_amount]]
|
||||
if opt.with_variations is None:
|
||||
iter_opt.with_variations = this_variation
|
||||
else:
|
||||
iter_opt.with_variations = opt.with_variations + this_variation
|
||||
iter_opt.variation_amount = 0
|
||||
elif opt.with_variations is None:
|
||||
iter_opt.seed = seed
|
||||
normalized_prompt = PromptFormatter(self.model, iter_opt).normalize_prompt()
|
||||
path = pngwriter.save_image_and_prompt_to_png(image, f'{normalized_prompt} -S{iter_opt.seed}', name)
|
||||
|
||||
if int(config['seed']) == -1:
|
||||
config['seed'] = seed
|
||||
# Append post_data to log, but only once!
|
||||
if not upscaled:
|
||||
with open(os.path.join(self.outdir, "dream_web_log.txt"), "a") as log:
|
||||
log.write(f"{path}: {json.dumps(config)}\n")
|
||||
|
||||
self.wfile.write(bytes(json.dumps(
|
||||
{'event': 'result', 'url': path, 'seed': seed, 'config': config}
|
||||
) + '\n',"utf-8"))
|
||||
|
||||
# control state of the "postprocessing..." message
|
||||
upscaling_requested = opt.upscale or opt.gfpgan_strength > 0
|
||||
nonlocal images_generated # NB: Is this bad python style? It is typical usage in a perl closure.
|
||||
nonlocal images_upscaled # NB: Is this bad python style? It is typical usage in a perl closure.
|
||||
if upscaled:
|
||||
images_upscaled += 1
|
||||
else:
|
||||
images_generated += 1
|
||||
if upscaling_requested:
|
||||
action = None
|
||||
if images_generated >= opt.iterations:
|
||||
if images_upscaled < opt.iterations:
|
||||
action = 'upscaling-started'
|
||||
else:
|
||||
action = 'upscaling-done'
|
||||
if action:
|
||||
x = images_upscaled + 1
|
||||
self.wfile.write(bytes(json.dumps(
|
||||
{'event': action, 'processed_file_cnt': f'{x}/{opt.iterations}'}
|
||||
) + '\n',"utf-8"))
|
||||
|
||||
step_writer = PngWriter(os.path.join(self.outdir, "intermediates"))
|
||||
step_index = 1
|
||||
def image_progress(sample, step):
|
||||
if self.canceled.is_set():
|
||||
self.wfile.write(bytes(json.dumps({'event':'canceled'}) + '\n', 'utf-8'))
|
||||
raise CanceledException
|
||||
path = None
|
||||
# since rendering images is moderately expensive, only render every 5th image
|
||||
# and don't bother with the last one, since it'll render anyway
|
||||
nonlocal step_index
|
||||
if opt.progress_images and step % 5 == 0 and step < opt.steps - 1:
|
||||
image = self.model.sample_to_image(sample)
|
||||
name = f'{prefix}.{opt.seed}.{step_index}.png'
|
||||
metadata = f'{opt.prompt} -S{opt.seed} [intermediate]'
|
||||
path = step_writer.save_image_and_prompt_to_png(image, metadata, name)
|
||||
step_index += 1
|
||||
self.wfile.write(bytes(json.dumps(
|
||||
{'event': 'step', 'step': step + 1, 'url': path}
|
||||
) + '\n',"utf-8"))
|
||||
|
||||
try:
|
||||
if opt.init_img is None:
|
||||
# Run txt2img
|
||||
self.model.prompt2image(**vars(opt), step_callback=image_progress, image_callback=image_done)
|
||||
else:
|
||||
# Decode initimg as base64 to temp file
|
||||
with open("./img2img-tmp.png", "wb") as f:
|
||||
initimg = opt.init_img.split(",")[1] # Ignore mime type
|
||||
f.write(base64.b64decode(initimg))
|
||||
opt1 = argparse.Namespace(**vars(opt))
|
||||
opt1.init_img = "./img2img-tmp.png"
|
||||
|
||||
try:
|
||||
# Run img2img
|
||||
self.model.prompt2image(**vars(opt1), step_callback=image_progress, image_callback=image_done)
|
||||
finally:
|
||||
# Remove the temp file
|
||||
os.remove("./img2img-tmp.png")
|
||||
except CanceledException:
|
||||
print(f"Canceled.")
|
||||
return
|
||||
|
||||
|
||||
class ThreadingDreamServer(ThreadingHTTPServer):
|
||||
def __init__(self, server_address):
|
||||
super(ThreadingDreamServer, self).__init__(server_address, DreamServer)
|
||||
131
ldm/invoke/txt2mask.py
Normal file
131
ldm/invoke/txt2mask.py
Normal file
@@ -0,0 +1,131 @@
|
||||
'''Makes available the Txt2Mask class, which assists in the automatic
|
||||
assignment of masks via text prompt using clipseg.
|
||||
|
||||
Here is typical usage:
|
||||
|
||||
from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
|
||||
from PIL import Image
|
||||
|
||||
txt2mask = Txt2Mask(self.device)
|
||||
segmented = txt2mask.segment(Image.open('/path/to/img.png'),'a bagel')
|
||||
|
||||
# this will return a grayscale Image of the segmented data
|
||||
grayscale = segmented.to_grayscale()
|
||||
|
||||
# this will return a semi-transparent image in which the
|
||||
# selected object(s) are opaque and the rest is at various
|
||||
# levels of transparency
|
||||
transparent = segmented.to_transparent()
|
||||
|
||||
# this will return a masked image suitable for use in inpainting:
|
||||
mask = segmented.to_mask(threshold=0.5)
|
||||
|
||||
The threshold used in the call to to_mask() selects pixels for use in
|
||||
the mask that exceed the indicated confidence threshold. Values range
|
||||
from 0.0 to 1.0. The higher the threshold, the more confident the
|
||||
algorithm is. In limited testing, I have found that values around 0.5
|
||||
work fine.
|
||||
'''
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from clipseg_models.clipseg import CLIPDensePredT
|
||||
from einops import rearrange, repeat
|
||||
from PIL import Image, ImageOps
|
||||
from torchvision import transforms
|
||||
|
||||
CLIP_VERSION = 'ViT-B/16'
|
||||
CLIPSEG_WEIGHTS = 'src/clipseg/weights/rd64-uni.pth'
|
||||
CLIPSEG_SIZE = 352
|
||||
|
||||
class SegmentedGrayscale(object):
|
||||
def __init__(self, image:Image, heatmap:torch.Tensor):
|
||||
self.heatmap = heatmap
|
||||
self.image = image
|
||||
|
||||
def to_grayscale(self)->Image:
|
||||
return self._rescale(Image.fromarray(np.uint8(self.heatmap*255)))
|
||||
|
||||
def to_mask(self,threshold:float=0.5)->Image:
|
||||
discrete_heatmap = self.heatmap.lt(threshold).int()
|
||||
return self._rescale(Image.fromarray(np.uint8(discrete_heatmap*255),mode='L'))
|
||||
|
||||
def to_transparent(self,invert:bool=False)->Image:
|
||||
transparent_image = self.image.copy()
|
||||
gs = self.to_grayscale()
|
||||
# The following line looks like a bug, but isn't.
|
||||
# For img2img, we want the selected regions to be transparent,
|
||||
# but to_grayscale() returns the opposite.
|
||||
gs = ImageOps.invert(gs) if not invert else gs
|
||||
transparent_image.putalpha(gs)
|
||||
return transparent_image
|
||||
|
||||
# unscales and uncrops the 352x352 heatmap so that it matches the image again
|
||||
def _rescale(self, heatmap:Image)->Image:
|
||||
size = self.image.width if (self.image.width > self.image.height) else self.image.height
|
||||
resized_image = heatmap.resize(
|
||||
(size,size),
|
||||
resample=Image.Resampling.LANCZOS
|
||||
)
|
||||
return resized_image.crop((0,0,self.image.width,self.image.height))
|
||||
|
||||
class Txt2Mask(object):
|
||||
'''
|
||||
Create new Txt2Mask object. The optional device argument can be one of
|
||||
'cuda', 'mps' or 'cpu'.
|
||||
'''
|
||||
def __init__(self,device='cpu'):
|
||||
print('>> Initializing clipseg model for text to mask inference')
|
||||
self.device = device
|
||||
self.model = CLIPDensePredT(version=CLIP_VERSION, reduce_dim=64, )
|
||||
self.model.eval()
|
||||
# initially we keep everything in cpu to conserve space
|
||||
self.model.to('cpu')
|
||||
self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS, map_location=torch.device('cpu')), strict=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def segment(self, image, prompt:str) -> SegmentedGrayscale:
|
||||
'''
|
||||
Given a prompt string such as "a bagel", tries to identify the object in the
|
||||
provided image and returns a SegmentedGrayscale object in which the brighter
|
||||
pixels indicate where the object is inferred to be.
|
||||
'''
|
||||
self._to_device(self.device)
|
||||
prompts = [prompt] # right now we operate on just a single prompt at a time
|
||||
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
transforms.Resize((CLIPSEG_SIZE, CLIPSEG_SIZE)), # must be multiple of 64...
|
||||
])
|
||||
|
||||
if type(image) is str:
|
||||
image = Image.open(image).convert('RGB')
|
||||
|
||||
image = ImageOps.exif_transpose(image)
|
||||
img = self._scale_and_crop(image)
|
||||
img = transform(img).unsqueeze(0)
|
||||
|
||||
preds = self.model(img.repeat(len(prompts),1,1,1), prompts)[0]
|
||||
heatmap = torch.sigmoid(preds[0][0]).cpu()
|
||||
self._to_device('cpu')
|
||||
return SegmentedGrayscale(image, heatmap)
|
||||
|
||||
def _to_device(self, device):
|
||||
self.model.to(device)
|
||||
|
||||
def _scale_and_crop(self, image:Image)->Image:
|
||||
scaled_image = Image.new('RGB',(CLIPSEG_SIZE,CLIPSEG_SIZE))
|
||||
if image.width > image.height: # width is constraint
|
||||
scale = CLIPSEG_SIZE / image.width
|
||||
else:
|
||||
scale = CLIPSEG_SIZE / image.height
|
||||
scaled_image.paste(
|
||||
image.resize(
|
||||
(int(scale * image.width),
|
||||
int(scale * image.height)
|
||||
),
|
||||
resample=Image.Resampling.LANCZOS
|
||||
),box=(0,0)
|
||||
)
|
||||
return scaled_image
|
||||
@@ -4,7 +4,7 @@ import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from ldm.dream.devices import choose_torch_device
|
||||
from ldm.invoke.devices import choose_torch_device
|
||||
from ldm.models.diffusion.sampler import Sampler
|
||||
from ldm.modules.diffusionmodules.util import noise_like
|
||||
|
||||
|
||||
@@ -106,7 +106,7 @@ class DDPM(pl.LightningModule):
|
||||
], 'currently only supporting "eps" and "x0"'
|
||||
self.parameterization = parameterization
|
||||
print(
|
||||
f'{self.__class__.__name__}: Running in {self.parameterization}-prediction mode'
|
||||
f' | {self.__class__.__name__}: Running in {self.parameterization}-prediction mode'
|
||||
)
|
||||
self.cond_stage_model = None
|
||||
self.clip_denoised = clip_denoised
|
||||
@@ -1353,7 +1353,7 @@ class LatentDiffusion(DDPM):
|
||||
num_downs = self.first_stage_model.encoder.num_resolutions - 1
|
||||
rescale_latent = 2 ** (num_downs)
|
||||
|
||||
# get top left postions of patches as conforming for the bbbox tokenizer, therefore we
|
||||
# get top left positions of patches as conforming for the bbbox tokenizer, therefore we
|
||||
# need to rescale the tl patch coordinates to be in between (0,1)
|
||||
tl_patch_coordinates = [
|
||||
(
|
||||
|
||||
@@ -2,9 +2,15 @@
|
||||
import k_diffusion as K
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ldm.dream.devices import choose_torch_device
|
||||
from ldm.invoke.devices import choose_torch_device
|
||||
from ldm.models.diffusion.sampler import Sampler
|
||||
from ldm.util import rand_perlin_2d
|
||||
from ldm.modules.diffusionmodules.util import (
|
||||
make_ddim_sampling_parameters,
|
||||
make_ddim_timesteps,
|
||||
noise_like,
|
||||
extract_into_tensor,
|
||||
)
|
||||
|
||||
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
|
||||
if threshold <= 0.0:
|
||||
@@ -51,8 +57,9 @@ class KSampler(Sampler):
|
||||
schedule,
|
||||
steps=model.num_timesteps,
|
||||
)
|
||||
self.ds = None
|
||||
self.s_in = None
|
||||
self.sigmas = None
|
||||
self.ds = None
|
||||
self.s_in = None
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||
x_in = torch.cat([x] * 2)
|
||||
@@ -81,13 +88,55 @@ class KSampler(Sampler):
|
||||
)
|
||||
self.model = outer_model
|
||||
self.ddim_num_steps = ddim_num_steps
|
||||
sigmas = self.model.get_sigmas(ddim_num_steps)
|
||||
self.sigmas = sigmas
|
||||
# we don't need both of these sigmas, but storing them here to make
|
||||
# comparison easier later on
|
||||
self.model_sigmas = self.model.get_sigmas(ddim_num_steps)
|
||||
self.karras_sigmas = K.sampling.get_sigmas_karras(
|
||||
n=ddim_num_steps,
|
||||
sigma_min=self.model.sigmas[0].item(),
|
||||
sigma_max=self.model.sigmas[-1].item(),
|
||||
rho=7.,
|
||||
device=self.device,
|
||||
)
|
||||
self.sigmas = self.model_sigmas
|
||||
#self.sigmas = self.karras_sigmas
|
||||
|
||||
# ALERT: We are completely overriding the sample() method in the base class, which
|
||||
# means that inpainting will (probably?) not work correctly. To get this to work
|
||||
# we need to be able to modify the inner loop of k_heun, k_lms, etc, as is done
|
||||
# in an ugly way in the lstein/k-diffusion branch.
|
||||
# means that inpainting will not work. To get this to work we need to be able to
|
||||
# modify the inner loop of k_heun, k_lms, etc, as is done in an ugly way
|
||||
# in the lstein/k-diffusion branch.
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(
|
||||
self,
|
||||
z_enc,
|
||||
cond,
|
||||
t_enc,
|
||||
img_callback=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
use_original_steps=False,
|
||||
init_latent = None,
|
||||
mask = None,
|
||||
):
|
||||
samples,_ = self.sample(
|
||||
batch_size = 1,
|
||||
S = t_enc,
|
||||
x_T = z_enc,
|
||||
shape = z_enc.shape[1:],
|
||||
conditioning = cond,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning = unconditional_conditioning,
|
||||
img_callback = img_callback,
|
||||
x0 = init_latent,
|
||||
mask = mask
|
||||
)
|
||||
return samples
|
||||
|
||||
# this is a no-op, provided here for compatibility with ddim and plms samplers
|
||||
@torch.no_grad()
|
||||
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
||||
return x0
|
||||
|
||||
# Most of these arguments are ignored and are only present for compatibility with
|
||||
# other samples
|
||||
@@ -123,24 +172,35 @@ class KSampler(Sampler):
|
||||
if img_callback is not None:
|
||||
img_callback(k_callback_values['x'],k_callback_values['i'])
|
||||
|
||||
# sigmas = self.model.get_sigmas(S)
|
||||
# sigmas are now set up in make_schedule - we take the last steps items
|
||||
# if make_schedule() hasn't been called, we do it now
|
||||
if self.sigmas is None:
|
||||
self.make_schedule(
|
||||
ddim_num_steps=S,
|
||||
ddim_eta = eta,
|
||||
verbose = False,
|
||||
)
|
||||
|
||||
# sigmas are set up in make_schedule - we take the last steps items
|
||||
total_steps = len(self.sigmas)
|
||||
sigmas = self.sigmas[-S-1:]
|
||||
|
||||
# x_T is variation noise. When an init image is provided (in x0) we need to add
|
||||
# more randomness to the starting image.
|
||||
if x_T is not None:
|
||||
x = x_T * sigmas[0]
|
||||
if x0 is not None:
|
||||
x = x_T + torch.randn_like(x0, device=self.device) * sigmas[0]
|
||||
else:
|
||||
x = x_T * sigmas[0]
|
||||
else:
|
||||
x = (
|
||||
torch.randn([batch_size, *shape], device=self.device)
|
||||
* sigmas[0]
|
||||
) # for GPU draw
|
||||
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
|
||||
|
||||
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
|
||||
extra_args = {
|
||||
'cond': conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': unconditional_guidance_scale,
|
||||
}
|
||||
print(f'>> Sampling with k_{self.schedule}')
|
||||
print(f'>> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)')
|
||||
return (
|
||||
K.sampling.__dict__[f'sample_{self.schedule}'](
|
||||
model_wrap_cfg, x, sigmas, extra_args=extra_args,
|
||||
@@ -149,6 +209,8 @@ class KSampler(Sampler):
|
||||
None,
|
||||
)
|
||||
|
||||
# this code will support inpainting if and when ksampler API modified or
|
||||
# a workaround is found.
|
||||
@torch.no_grad()
|
||||
def p_sample(
|
||||
self,
|
||||
@@ -195,11 +257,17 @@ class KSampler(Sampler):
|
||||
|
||||
return img, None, None
|
||||
|
||||
# REVIEW THIS METHOD: it has never been tested. In particular,
|
||||
# we should not be multiplying by self.sigmas[0] if we
|
||||
# are at an intermediate step in img2img. See similar in
|
||||
# sample() which does work.
|
||||
def get_initial_image(self,x_T,shape,steps):
|
||||
print(f'WARNING: ksampler.get_initial_image(): get_initial_image needs testing')
|
||||
x = (torch.randn(shape, device=self.device) * self.sigmas[0])
|
||||
if x_T is not None:
|
||||
return x_T + x_T * self.sigmas[0]
|
||||
return x_T + x
|
||||
else:
|
||||
return (torch.randn(shape, device=self.device) * self.sigmas[0])
|
||||
return x
|
||||
|
||||
def prepare_to_sample(self,t_enc):
|
||||
self.t_enc = t_enc
|
||||
@@ -213,29 +281,3 @@ class KSampler(Sampler):
|
||||
'''
|
||||
return self.model.inner_model.q_sample(x0,ts)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(
|
||||
self,
|
||||
z_enc,
|
||||
cond,
|
||||
t_enc,
|
||||
img_callback=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
use_original_steps=False,
|
||||
init_latent = None,
|
||||
mask = None,
|
||||
):
|
||||
samples,_ = self.sample(
|
||||
batch_size = 1,
|
||||
S = t_enc,
|
||||
x_T = z_enc,
|
||||
shape = z_enc.shape[1:],
|
||||
conditioning = cond,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning = unconditional_conditioning,
|
||||
img_callback = img_callback,
|
||||
x0 = init_latent,
|
||||
mask = mask
|
||||
)
|
||||
return samples
|
||||
|
||||
@@ -4,7 +4,7 @@ import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from ldm.dream.devices import choose_torch_device
|
||||
from ldm.invoke.devices import choose_torch_device
|
||||
from ldm.models.diffusion.sampler import Sampler
|
||||
from ldm.modules.diffusionmodules.util import noise_like
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from ldm.dream.devices import choose_torch_device
|
||||
from ldm.invoke.devices import choose_torch_device
|
||||
|
||||
from ldm.modules.diffusionmodules.util import (
|
||||
make_ddim_sampling_parameters,
|
||||
@@ -20,6 +20,7 @@ from ldm.modules.diffusionmodules.util import (
|
||||
class Sampler(object):
|
||||
def __init__(self, model, schedule='linear', steps=None, device=None, **kwargs):
|
||||
self.model = model
|
||||
self.ddim_timesteps = None
|
||||
self.ddpm_num_timesteps = steps
|
||||
self.schedule = schedule
|
||||
self.device = device or choose_torch_device()
|
||||
@@ -39,6 +40,7 @@ class Sampler(object):
|
||||
ddim_eta=0.0,
|
||||
verbose=False,
|
||||
):
|
||||
self.total_steps = ddim_num_steps
|
||||
self.ddim_timesteps = make_ddim_timesteps(
|
||||
ddim_discr_method=ddim_discretize,
|
||||
num_ddim_timesteps=ddim_num_steps,
|
||||
@@ -138,7 +140,7 @@ class Sampler(object):
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
img_callback=None, # TODO: this is very confusing because it is called "step_callback" elsewhere. Change.
|
||||
quantize_x0=False,
|
||||
eta=0.0,
|
||||
mask=None,
|
||||
@@ -156,6 +158,14 @@ class Sampler(object):
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
# check to see if make_schedule() has run, and if not, run it
|
||||
if self.ddim_timesteps is None:
|
||||
self.make_schedule(
|
||||
ddim_num_steps=S,
|
||||
ddim_eta = eta,
|
||||
verbose = False,
|
||||
)
|
||||
|
||||
ts = self.get_timesteps(S)
|
||||
|
||||
# sampling
|
||||
@@ -211,6 +221,7 @@ class Sampler(object):
|
||||
if ddim_use_original_steps
|
||||
else np.flip(timesteps)
|
||||
)
|
||||
|
||||
total_steps=steps
|
||||
|
||||
iterator = tqdm(
|
||||
@@ -305,7 +316,7 @@ class Sampler(object):
|
||||
|
||||
time_range = np.flip(timesteps)
|
||||
total_steps = timesteps.shape[0]
|
||||
print(f'>> Running {self.__class__.__name__} Sampling with {total_steps} timesteps')
|
||||
print(f'>> Running {self.__class__.__name__} sampling starting at step {self.total_steps - t_start} of {self.total_steps} ({total_steps} new sampling steps)')
|
||||
|
||||
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
||||
x_dec = x_latent
|
||||
|
||||
@@ -49,9 +49,15 @@ class Upsample(nn.Module):
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
cpu_m1_cond = True if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() and \
|
||||
x.size()[0] * x.size()[1] * x.size()[2] * x.size()[3] % 2**27 == 0 else False
|
||||
if cpu_m1_cond:
|
||||
x = x.to('cpu') # send to cpu
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
if cpu_m1_cond:
|
||||
x = x.to('mps') # return to mps
|
||||
return x
|
||||
|
||||
|
||||
@@ -117,6 +123,14 @@ class ResnetBlock(nn.Module):
|
||||
padding=0)
|
||||
|
||||
def forward(self, x, temb):
|
||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||
x_size = x.size()
|
||||
if (x_size[0] * x_size[1] * x_size[2] * x_size[3]) % 2**29 == 0:
|
||||
self.to('cpu')
|
||||
x = x.to('cpu')
|
||||
else:
|
||||
self.to('mps')
|
||||
x = x.to('mps')
|
||||
h = self.norm1(x)
|
||||
h = silu(h)
|
||||
h = self.conv1(h)
|
||||
@@ -245,7 +259,7 @@ class AttnBlock(nn.Module):
|
||||
|
||||
def make_attn(in_channels, attn_type="vanilla"):
|
||||
assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
|
||||
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
||||
print(f" | Making attention of type '{attn_type}' with {in_channels} in_channels")
|
||||
if attn_type == "vanilla":
|
||||
return AttnBlock(in_channels)
|
||||
elif attn_type == "none":
|
||||
@@ -521,7 +535,7 @@ class Decoder(nn.Module):
|
||||
block_in = ch*ch_mult[self.num_resolutions-1]
|
||||
curr_res = resolution // 2**(self.num_resolutions-1)
|
||||
self.z_shape = (1,z_channels,curr_res,curr_res)
|
||||
print("Working with z of shape {} = {} dimensions.".format(
|
||||
print(" | Working with z of shape {} = {} dimensions.".format(
|
||||
self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
# z to block_in
|
||||
|
||||
@@ -66,7 +66,7 @@ def make_ddim_timesteps(
|
||||
c = num_ddpm_timesteps // num_ddim_timesteps
|
||||
if c < 1:
|
||||
c = 1
|
||||
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
|
||||
ddim_timesteps = (np.arange(0, num_ddim_timesteps) * c).astype(int)
|
||||
elif ddim_discr_method == 'quad':
|
||||
ddim_timesteps = (
|
||||
(
|
||||
@@ -83,8 +83,8 @@ def make_ddim_timesteps(
|
||||
|
||||
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
|
||||
# add one to get the final alpha values right (the ones from first scale to data during sampling)
|
||||
# steps_out = ddim_timesteps + 1
|
||||
steps_out = ddim_timesteps
|
||||
steps_out = ddim_timesteps + 1
|
||||
# steps_out = ddim_timesteps
|
||||
|
||||
if verbose:
|
||||
print(f'Selected timesteps for ddim sampler: {steps_out}')
|
||||
|
||||
@@ -5,7 +5,7 @@ import clip
|
||||
from einops import rearrange, repeat
|
||||
from transformers import CLIPTokenizer, CLIPTextModel
|
||||
import kornia
|
||||
from ldm.dream.devices import choose_torch_device
|
||||
from ldm.invoke.devices import choose_torch_device
|
||||
|
||||
from ldm.modules.x_transformer import (
|
||||
Encoder,
|
||||
|
||||
24
ldm/util.py
24
ldm/util.py
@@ -75,7 +75,7 @@ def count_params(model, verbose=False):
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
if verbose:
|
||||
print(
|
||||
f'{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
|
||||
f' | {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
|
||||
)
|
||||
return total_params
|
||||
|
||||
@@ -214,20 +214,24 @@ def parallel_data_prefetch(
|
||||
else:
|
||||
return gather_res
|
||||
|
||||
def rand_perlin_2d(shape, res, fade = lambda t: 6*t**5 - 15*t**4 + 10*t**3):
|
||||
def rand_perlin_2d(shape, res, device, fade = lambda t: 6*t**5 - 15*t**4 + 10*t**3):
|
||||
delta = (res[0] / shape[0], res[1] / shape[1])
|
||||
d = (shape[0] // res[0], shape[1] // res[1])
|
||||
|
||||
grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1]), indexing='ij'), dim = -1) % 1
|
||||
angles = 2*math.pi*torch.rand(res[0]+1, res[1]+1)
|
||||
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim = -1)
|
||||
grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1]), indexing='ij'), dim = -1).to(device) % 1
|
||||
|
||||
rand_val = torch.rand(res[0]+1, res[1]+1)
|
||||
|
||||
angles = 2*math.pi*rand_val
|
||||
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim = -1).to(device)
|
||||
|
||||
tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1)
|
||||
|
||||
dot = lambda grad, shift: (torch.stack((grid[:shape[0],:shape[1],0] + shift[0], grid[:shape[0],:shape[1], 1] + shift[1] ), dim = -1) * grad[:shape[0], :shape[1]]).sum(dim = -1)
|
||||
|
||||
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
|
||||
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
|
||||
n01 = dot(tile_grads([0, -1],[1, None]), [0, -1])
|
||||
n11 = dot(tile_grads([1, None], [1, None]), [-1,-1])
|
||||
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]).to(device)
|
||||
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]).to(device)
|
||||
n01 = dot(tile_grads([0, -1],[1, None]), [0, -1]).to(device)
|
||||
n11 = dot(tile_grads([1, None], [1, None]), [-1,-1]).to(device)
|
||||
t = fade(grid[:shape[0], :shape[1]])
|
||||
return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
|
||||
return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(device)
|
||||
|
||||
Reference in New Issue
Block a user