diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 271a2e3be3..347fba7e97 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -4,7 +4,8 @@ import os from argparse import Namespace from ...backend import Globals -from ..services.generate_initializer import get_generate +from ..services.model_manager_initializer import get_model_manager +from ..services.restoration_services import RestorationServices from ..services.graph import GraphExecutionState from ..services.image_storage import DiskImageStorage from ..services.invocation_queue import MemoryInvocationQueue @@ -37,18 +38,16 @@ class ApiDependencies: invoker: Invoker = None @staticmethod - def initialize(args, config, event_handler_id: int): - Globals.try_patchmatch = args.patchmatch - Globals.always_use_cpu = args.always_use_cpu - Globals.internet_available = args.internet_available and check_internet() - Globals.disable_xformers = not args.xformers - Globals.ckpt_convert = args.ckpt_convert + def initialize(config, event_handler_id: int): + Globals.try_patchmatch = config.patchmatch + Globals.always_use_cpu = config.always_use_cpu + Globals.internet_available = config.internet_available and check_internet() + Globals.disable_xformers = not config.xformers + Globals.ckpt_convert = config.ckpt_convert # TODO: Use a logger print(f">> Internet connectivity is {Globals.internet_available}") - generate = get_generate(args, config) - events = FastAPIEventService(event_handler_id) output_folder = os.path.abspath( @@ -61,7 +60,7 @@ class ApiDependencies: db_location = os.path.join(output_folder, "invokeai.db") services = InvocationServices( - generate=generate, + model_manager=get_model_manager(config), events=events, images=images, queue=MemoryInvocationQueue(), @@ -69,6 +68,7 @@ class ApiDependencies: filename=db_location, table_name="graph_executions" ), processor=DefaultInvocationProcessor(), + restoration=RestorationServices(config), ) ApiDependencies.invoker = Invoker(services) diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index fb64ca3b7a..7bc38dc2dc 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -1,5 +1,4 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - import asyncio from inspect import signature @@ -53,11 +52,11 @@ config = {} # Add startup event to load dependencies @app.on_event("startup") async def startup_event(): - args = Args() - config = args.parse_args() + config = Args() + config.parse_args() ApiDependencies.initialize( - args=args, config=config, event_handler_id=event_handler_id + config=config, event_handler_id=event_handler_id ) diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py index 9dc1429d92..732a233cb4 100644 --- a/invokeai/app/cli_app.py +++ b/invokeai/app/cli_app.py @@ -17,7 +17,8 @@ from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_gra from .invocations import * from .invocations.baseinvocation import BaseInvocation from .services.events import EventServiceBase -from .services.generate_initializer import get_generate +from .services.model_manager_initializer import get_model_manager +from .services.restoration_services import RestorationServices from .services.graph import EdgeConnection, GraphExecutionState from .services.image_storage import DiskImageStorage from .services.invocation_queue import MemoryInvocationQueue @@ -126,14 +127,9 @@ def invoke_all(context: CliContext): def invoke_cli(): - args = Args() - config = args.parse_args() - - generate = get_generate(args, config) - - # NOTE: load model on first use, uncomment to load at startup - # TODO: Make this a config option? - # generate.load_model() + config = Args() + config.parse_args() + model_manager = get_model_manager(config) events = EventServiceBase() @@ -145,7 +141,7 @@ def invoke_cli(): db_location = os.path.join(output_folder, "invokeai.db") services = InvocationServices( - generate=generate, + model_manager=model_manager, events=events, images=DiskImageStorage(output_folder), queue=MemoryInvocationQueue(), @@ -153,6 +149,7 @@ def invoke_cli(): filename=db_location, table_name="graph_executions" ), processor=DefaultInvocationProcessor(), + restoration=RestorationServices(config), ) invoker = Invoker(services) diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py index 15c5f17438..c1a0028293 100644 --- a/invokeai/app/invocations/generate.py +++ b/invokeai/app/invocations/generate.py @@ -12,12 +12,12 @@ from ..services.image_storage import ImageType from ..services.invocation_services import InvocationServices from .baseinvocation import BaseInvocation, InvocationContext from .image import ImageField, ImageOutput +from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator SAMPLER_NAME_VALUES = Literal[ - "ddim", "plms", "k_lms", "k_dpm_2", "k_dpm_2_a", "k_euler", "k_euler_a", "k_heun" + tuple(InvokeAIGenerator.schedulers()) ] - # Text to image class TextToImageInvocation(BaseInvocation): """Generates an image using text2img.""" @@ -57,19 +57,18 @@ class TextToImageInvocation(BaseInvocation): # Handle invalid model parameter # TODO: figure out if this can be done via a validator that uses the model_cache # TODO: How to get the default model name now? - if self.model is None or self.model == "": - self.model = context.services.generate.model_name - - # Set the model (if already cached, this does nothing) - context.services.generate.set_model(self.model) - - results = context.services.generate.prompt2image( + # (right now uses whatever current model is set in model manager) + model= context.services.model_manager.get_model() + outputs = Txt2Img(model).generate( prompt=self.prompt, step_callback=step_callback, **self.dict( exclude={"prompt"} ), # Shorthand for passing all of the parameters above manually ) + # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object + # each time it is called. We only need the first one. + generate_output = next(outputs) # Results are image and seed, unwrap for now and ignore the seed # TODO: pre-seed? @@ -78,7 +77,7 @@ class TextToImageInvocation(BaseInvocation): image_name = context.services.images.create_name( context.graph_execution_state_id, self.id ) - context.services.images.save(image_type, image_name, results[0][0]) + context.services.images.save(image_type, image_name, generate_output.image) return ImageOutput( image=ImageField(image_type=image_type, image_name=image_name) ) @@ -115,23 +114,20 @@ class ImageToImageInvocation(TextToImageInvocation): # Handle invalid model parameter # TODO: figure out if this can be done via a validator that uses the model_cache # TODO: How to get the default model name now? - if self.model is None or self.model == "": - self.model = context.services.generate.model_name - - # Set the model (if already cached, this does nothing) - context.services.generate.set_model(self.model) - - results = context.services.generate.prompt2image( - prompt=self.prompt, - init_img=image, - init_mask=mask, - step_callback=step_callback, - **self.dict( - exclude={"prompt", "image", "mask"} - ), # Shorthand for passing all of the parameters above manually + model = context.services.model_manager.get_model() + generator_output = next( + Img2Img(model).generate( + prompt=self.prompt, + init_img=image, + init_mask=mask, + step_callback=step_callback, + **self.dict( + exclude={"prompt", "image", "mask"} + ), # Shorthand for passing all of the parameters above manually + ) ) - result_image = results[0][0] + result_image = generator_output.image # Results are image and seed, unwrap for now and ignore the seed # TODO: pre-seed? @@ -145,7 +141,6 @@ class ImageToImageInvocation(TextToImageInvocation): image=ImageField(image_type=image_type, image_name=image_name) ) - class InpaintInvocation(ImageToImageInvocation): """Generates an image using inpaint.""" @@ -180,23 +175,20 @@ class InpaintInvocation(ImageToImageInvocation): # Handle invalid model parameter # TODO: figure out if this can be done via a validator that uses the model_cache # TODO: How to get the default model name now? - if self.model is None or self.model == "": - self.model = context.services.generate.model_name - - # Set the model (if already cached, this does nothing) - context.services.generate.set_model(self.model) - - results = context.services.generate.prompt2image( - prompt=self.prompt, - init_img=image, - init_mask=mask, - step_callback=step_callback, - **self.dict( - exclude={"prompt", "image", "mask"} - ), # Shorthand for passing all of the parameters above manually + manager = context.services.model_manager.get_model() + generator_output = next( + Inpaint(model).generate( + prompt=self.prompt, + init_img=image, + init_mask=mask, + step_callback=step_callback, + **self.dict( + exclude={"prompt", "image", "mask"} + ), # Shorthand for passing all of the parameters above manually + ) ) - result_image = results[0][0] + result_image = generator_output.image # Results are image and seed, unwrap for now and ignore the seed # TODO: pre-seed? diff --git a/invokeai/app/invocations/reconstruct.py b/invokeai/app/invocations/reconstruct.py index a90c33605e..c4d8f3ac7c 100644 --- a/invokeai/app/invocations/reconstruct.py +++ b/invokeai/app/invocations/reconstruct.py @@ -8,7 +8,6 @@ from ..services.invocation_services import InvocationServices from .baseinvocation import BaseInvocation, InvocationContext from .image import ImageField, ImageOutput - class RestoreFaceInvocation(BaseInvocation): """Restores faces in an image.""" #fmt: off @@ -23,7 +22,7 @@ class RestoreFaceInvocation(BaseInvocation): image = context.services.images.get( self.image.image_type, self.image.image_name ) - results = context.services.generate.upscale_and_reconstruct( + results = context.services.restoration.upscale_and_reconstruct( image_list=[[image, 0]], upscale=None, strength=self.strength, # GFPGAN strength diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index dcc39fc9ad..4079877fdb 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -26,7 +26,7 @@ class UpscaleInvocation(BaseInvocation): image = context.services.images.get( self.image.image_type, self.image.image_name ) - results = context.services.generate.upscale_and_reconstruct( + results = context.services.restoration.upscale_and_reconstruct( image_list=[[image, 0]], upscale=(self.level, self.strength), strength=0.0, # GFPGAN strength diff --git a/invokeai/app/services/generate_initializer.py b/invokeai/app/services/generate_initializer.py deleted file mode 100644 index 9801909742..0000000000 --- a/invokeai/app/services/generate_initializer.py +++ /dev/null @@ -1,255 +0,0 @@ -import os -import sys -import traceback -from argparse import Namespace - -import invokeai.version -from invokeai.backend import Generate, ModelManager - -from ...backend import Globals - - -# TODO: most of this code should be split into individual services as the Generate.py code is deprecated -def get_generate(args, config) -> Generate: - if not args.conf: - config_file = os.path.join(Globals.root, "configs", "models.yaml") - if not os.path.exists(config_file): - report_model_error( - args, FileNotFoundError(f"The file {config_file} could not be found.") - ) - - print(f">> {invokeai.version.__app_name__}, version {invokeai.version.__version__}") - print(f'>> InvokeAI runtime directory is "{Globals.root}"') - - # these two lines prevent a horrible warning message from appearing - # when the frozen CLIP tokenizer is imported - import transformers # type: ignore - - transformers.logging.set_verbosity_error() - import diffusers - - diffusers.logging.set_verbosity_error() - - # Loading Face Restoration and ESRGAN Modules - gfpgan, codeformer, esrgan = load_face_restoration(args) - - # normalize the config directory relative to root - if not os.path.isabs(args.conf): - args.conf = os.path.normpath(os.path.join(Globals.root, args.conf)) - - if args.embeddings: - if not os.path.isabs(args.embedding_path): - embedding_path = os.path.normpath( - os.path.join(Globals.root, args.embedding_path) - ) - else: - embedding_path = args.embedding_path - else: - embedding_path = None - - # migrate legacy models - ModelManager.migrate_models() - - # load the infile as a list of lines - if args.infile: - try: - if os.path.isfile(args.infile): - infile = open(args.infile, "r", encoding="utf-8") - elif args.infile == "-": # stdin - infile = sys.stdin - else: - raise FileNotFoundError(f"{args.infile} not found.") - except (FileNotFoundError, IOError) as e: - print(f"{e}. Aborting.") - sys.exit(-1) - - # creating a Generate object: - try: - gen = Generate( - conf=args.conf, - model=args.model, - sampler_name=args.sampler_name, - embedding_path=embedding_path, - full_precision=args.full_precision, - precision=args.precision, - gfpgan=gfpgan, - codeformer=codeformer, - esrgan=esrgan, - free_gpu_mem=args.free_gpu_mem, - safety_checker=args.safety_checker, - max_loaded_models=args.max_loaded_models, - ) - except (FileNotFoundError, TypeError, AssertionError) as e: - report_model_error(opt, e) - except (IOError, KeyError) as e: - print(f"{e}. Aborting.") - sys.exit(-1) - - if args.seamless: - print(">> changed to seamless tiling mode") - - # preload the model - try: - gen.load_model() - except KeyError: - pass - except Exception as e: - report_model_error(args, e) - - # try to autoconvert new models - # autoimport new .ckpt files - if path := args.autoconvert: - gen.model_manager.autoconvert_weights( - conf_path=args.conf, - weights_directory=path, - ) - - return gen - - -def load_face_restoration(opt): - try: - gfpgan, codeformer, esrgan = None, None, None - if opt.restore or opt.esrgan: - from invokeai.backend.restoration import Restoration - - restoration = Restoration() - if opt.restore: - gfpgan, codeformer = restoration.load_face_restore_models( - opt.gfpgan_model_path - ) - else: - print(">> Face restoration disabled") - if opt.esrgan: - esrgan = restoration.load_esrgan(opt.esrgan_bg_tile) - else: - print(">> Upscaling disabled") - else: - print(">> Face restoration and upscaling disabled") - except (ModuleNotFoundError, ImportError): - print(traceback.format_exc(), file=sys.stderr) - print(">> You may need to install the ESRGAN and/or GFPGAN modules") - return gfpgan, codeformer, esrgan - - -def report_model_error(opt: Namespace, e: Exception): - print(f'** An error occurred while attempting to initialize the model: "{str(e)}"') - print( - "** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models." - ) - yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE") - if yes_to_all: - print( - "** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE" - ) - else: - response = input( - "Do you want to run invokeai-configure script to select and/or reinstall models? [y] " - ) - if response.startswith(("n", "N")): - return - - print("invokeai-configure is launching....\n") - - # Match arguments that were set on the CLI - # only the arguments accepted by the configuration script are parsed - root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else [] - config = ["--config", opt.conf] if opt.conf is not None else [] - previous_args = sys.argv - sys.argv = ["invokeai-configure"] - sys.argv.extend(root_dir) - sys.argv.extend(config) - if yes_to_all is not None: - for arg in yes_to_all.split(): - sys.argv.append(arg) - - from invokeai.frontend.install import invokeai_configure - - invokeai_configure() - # TODO: Figure out how to restart - # print('** InvokeAI will now restart') - # sys.argv = previous_args - # main() # would rather do a os.exec(), but doesn't exist? - # sys.exit(0) - - -# Temporary initializer for Generate until we migrate off of it -def old_get_generate(args, config) -> Generate: - # TODO: Remove the need for globals - from invokeai.backend.globals import Globals - - # alert - setting globals here - Globals.root = os.path.expanduser( - args.root_dir or os.environ.get("INVOKEAI_ROOT") or os.path.abspath(".") - ) - Globals.try_patchmatch = args.patchmatch - - print(f'>> InvokeAI runtime directory is "{Globals.root}"') - - # these two lines prevent a horrible warning message from appearing - # when the frozen CLIP tokenizer is imported - import transformers - - transformers.logging.set_verbosity_error() - - # Loading Face Restoration and ESRGAN Modules - gfpgan, codeformer, esrgan = None, None, None - try: - if config.restore or config.esrgan: - from ldm.invoke.restoration import Restoration - - restoration = Restoration() - if config.restore: - gfpgan, codeformer = restoration.load_face_restore_models( - config.gfpgan_model_path - ) - else: - print(">> Face restoration disabled") - if config.esrgan: - esrgan = restoration.load_esrgan(config.esrgan_bg_tile) - else: - print(">> Upscaling disabled") - else: - print(">> Face restoration and upscaling disabled") - except (ModuleNotFoundError, ImportError): - print(traceback.format_exc(), file=sys.stderr) - print(">> You may need to install the ESRGAN and/or GFPGAN modules") - - # normalize the config directory relative to root - if not os.path.isabs(config.conf): - config.conf = os.path.normpath(os.path.join(Globals.root, config.conf)) - - if config.embeddings: - if not os.path.isabs(config.embedding_path): - embedding_path = os.path.normpath( - os.path.join(Globals.root, config.embedding_path) - ) - else: - embedding_path = None - - # TODO: lazy-initialize this by wrapping it - try: - generate = Generate( - conf=config.conf, - model=config.model, - sampler_name=config.sampler_name, - embedding_path=embedding_path, - full_precision=config.full_precision, - precision=config.precision, - gfpgan=gfpgan, - codeformer=codeformer, - esrgan=esrgan, - free_gpu_mem=config.free_gpu_mem, - safety_checker=config.safety_checker, - max_loaded_models=config.max_loaded_models, - ) - except (FileNotFoundError, TypeError, AssertionError): - # emergency_model_reconfigure() # TODO? - sys.exit(-1) - except (IOError, KeyError) as e: - print(f"{e}. Aborting.") - sys.exit(-1) - - generate.free_gpu_mem = config.free_gpu_mem - - return generate diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 42cbd6c271..7f24c34378 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -1,36 +1,39 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) -from invokeai.backend import Generate +from invokeai.backend import ModelManager from .events import EventServiceBase from .image_storage import ImageStorageBase +from .restoration_services import RestorationServices from .invocation_queue import InvocationQueueABC from .item_storage import ItemStorageABC - class InvocationServices: """Services that can be used by invocations""" - generate: Generate # TODO: wrap Generate, or split it up from model? events: EventServiceBase images: ImageStorageBase queue: InvocationQueueABC + model_manager: ModelManager + restoration: RestorationServices # NOTE: we must forward-declare any types that include invocations, since invocations can use services graph_execution_manager: ItemStorageABC["GraphExecutionState"] processor: "InvocationProcessorABC" def __init__( - self, - generate: Generate, - events: EventServiceBase, - images: ImageStorageBase, - queue: InvocationQueueABC, - graph_execution_manager: ItemStorageABC["GraphExecutionState"], - processor: "InvocationProcessorABC", + self, + model_manager: ModelManager, + events: EventServiceBase, + images: ImageStorageBase, + queue: InvocationQueueABC, + graph_execution_manager: ItemStorageABC["GraphExecutionState"], + processor: "InvocationProcessorABC", + restoration: RestorationServices, ): - self.generate = generate + self.model_manager = model_manager self.events = events self.images = images self.queue = queue self.graph_execution_manager = graph_execution_manager self.processor = processor + self.restoration = restoration diff --git a/invokeai/app/services/model_manager_initializer.py b/invokeai/app/services/model_manager_initializer.py new file mode 100644 index 0000000000..3ef79f0b7e --- /dev/null +++ b/invokeai/app/services/model_manager_initializer.py @@ -0,0 +1,120 @@ +import os +import sys +import torch +from argparse import Namespace +from invokeai.backend import Args +from omegaconf import OmegaConf +from pathlib import Path + +import invokeai.version +from ...backend import ModelManager +from ...backend.util import choose_precision, choose_torch_device +from ...backend import Globals + +# TODO: Replace with an abstract class base ModelManagerBase +def get_model_manager(config: Args) -> ModelManager: + if not config.conf: + config_file = os.path.join(Globals.root, "configs", "models.yaml") + if not os.path.exists(config_file): + report_model_error( + config, FileNotFoundError(f"The file {config_file} could not be found.") + ) + + print(f">> {invokeai.version.__app_name__}, version {invokeai.version.__version__}") + print(f'>> InvokeAI runtime directory is "{Globals.root}"') + + # these two lines prevent a horrible warning message from appearing + # when the frozen CLIP tokenizer is imported + import transformers # type: ignore + + transformers.logging.set_verbosity_error() + import diffusers + + diffusers.logging.set_verbosity_error() + + # normalize the config directory relative to root + if not os.path.isabs(config.conf): + config.conf = os.path.normpath(os.path.join(Globals.root, config.conf)) + + if config.embeddings: + if not os.path.isabs(config.embedding_path): + embedding_path = os.path.normpath( + os.path.join(Globals.root, config.embedding_path) + ) + else: + embedding_path = config.embedding_path + else: + embedding_path = None + + # migrate legacy models + ModelManager.migrate_models() + + # creating the model manager + try: + device = torch.device(choose_torch_device()) + precision = 'float16' if config.precision=='float16' \ + else 'float32' if config.precision=='float32' \ + else choose_precision(device) + + model_manager = ModelManager( + OmegaConf.load(config.conf), + precision=precision, + device_type=device, + max_loaded_models=config.max_loaded_models, + embedding_path = Path(embedding_path), + ) + except (FileNotFoundError, TypeError, AssertionError) as e: + report_model_error(config, e) + except (IOError, KeyError) as e: + print(f"{e}. Aborting.") + sys.exit(-1) + + # try to autoconvert new models + # autoimport new .ckpt files + if path := config.autoconvert: + model_manager.autoconvert_weights( + conf_path=config.conf, + weights_directory=path, + ) + + return model_manager + +def report_model_error(opt: Namespace, e: Exception): + print(f'** An error occurred while attempting to initialize the model: "{str(e)}"') + print( + "** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models." + ) + yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE") + if yes_to_all: + print( + "** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE" + ) + else: + response = input( + "Do you want to run invokeai-configure script to select and/or reinstall models? [y] " + ) + if response.startswith(("n", "N")): + return + + print("invokeai-configure is launching....\n") + + # Match arguments that were set on the CLI + # only the arguments accepted by the configuration script are parsed + root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else [] + config = ["--config", opt.conf] if opt.conf is not None else [] + previous_config = sys.argv + sys.argv = ["invokeai-configure"] + sys.argv.extend(root_dir) + sys.argv.extend(config.to_dict()) + if yes_to_all is not None: + for arg in yes_to_all.split(): + sys.argv.append(arg) + + from invokeai.frontend.install import invokeai_configure + + invokeai_configure() + # TODO: Figure out how to restart + # print('** InvokeAI will now restart') + # sys.argv = previous_args + # main() # would rather do a os.exec(), but doesn't exist? + # sys.exit(0) diff --git a/invokeai/app/services/restoration_services.py b/invokeai/app/services/restoration_services.py new file mode 100644 index 0000000000..f5fc687c11 --- /dev/null +++ b/invokeai/app/services/restoration_services.py @@ -0,0 +1,109 @@ +import sys +import traceback +import torch +from ...backend.restoration import Restoration +from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE + +# This should be a real base class for postprocessing functions, +# but right now we just instantiate the existing gfpgan, esrgan +# and codeformer functions. +class RestorationServices: + '''Face restoration and upscaling''' + + def __init__(self,args): + try: + gfpgan, codeformer, esrgan = None, None, None + if args.restore or args.esrgan: + restoration = Restoration() + if args.restore: + gfpgan, codeformer = restoration.load_face_restore_models( + args.gfpgan_model_path + ) + else: + print(">> Face restoration disabled") + if args.esrgan: + esrgan = restoration.load_esrgan(args.esrgan_bg_tile) + else: + print(">> Upscaling disabled") + else: + print(">> Face restoration and upscaling disabled") + except (ModuleNotFoundError, ImportError): + print(traceback.format_exc(), file=sys.stderr) + print(">> You may need to install the ESRGAN and/or GFPGAN modules") + self.device = torch.device(choose_torch_device()) + self.gfpgan = gfpgan + self.codeformer = codeformer + self.esrgan = esrgan + + # note that this one method does gfpgan and codepath reconstruction, as well as + # esrgan upscaling + # TO DO: refactor into separate methods + def upscale_and_reconstruct( + self, + image_list, + facetool="gfpgan", + upscale=None, + upscale_denoise_str=0.75, + strength=0.0, + codeformer_fidelity=0.75, + save_original=False, + image_callback=None, + prefix=None, + ): + results = [] + for r in image_list: + image, seed = r + try: + if strength > 0: + if self.gfpgan is not None or self.codeformer is not None: + if facetool == "gfpgan": + if self.gfpgan is None: + print( + ">> GFPGAN not found. Face restoration is disabled." + ) + else: + image = self.gfpgan.process(image, strength, seed) + if facetool == "codeformer": + if self.codeformer is None: + print( + ">> CodeFormer not found. Face restoration is disabled." + ) + else: + cf_device = ( + CPU_DEVICE if self.device == MPS_DEVICE else self.device + ) + image = self.codeformer.process( + image=image, + strength=strength, + device=cf_device, + seed=seed, + fidelity=codeformer_fidelity, + ) + else: + print(">> Face Restoration is disabled.") + if upscale is not None: + if self.esrgan is not None: + if len(upscale) < 2: + upscale.append(0.75) + image = self.esrgan.process( + image, + upscale[1], + seed, + int(upscale[0]), + denoise_str=upscale_denoise_str, + ) + else: + print(">> ESRGAN is disabled. Image not upscaled.") + except Exception as e: + print( + f">> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}" + ) + + if image_callback is not None: + image_callback(image, seed, upscaled=True, use_prefix=prefix) + else: + r[0] = image + + results.append([image, seed]) + + return results diff --git a/invokeai/backend/__init__.py b/invokeai/backend/__init__.py index 06089369c2..06066dd6b1 100644 --- a/invokeai/backend/__init__.py +++ b/invokeai/backend/__init__.py @@ -2,6 +2,15 @@ Initialization file for invokeai.backend """ from .generate import Generate +from .generator import ( + InvokeAIGeneratorBasicParams, + InvokeAIGenerator, + InvokeAIGeneratorOutput, + Txt2Img, + Img2Img, + Inpaint +) from .model_management import ModelManager +from .safety_checker import SafetyChecker from .args import Args from .globals import Globals diff --git a/invokeai/backend/generate.py b/invokeai/backend/generate.py index 35dba41ffb..1b19a1aa7e 100644 --- a/invokeai/backend/generate.py +++ b/invokeai/backend/generate.py @@ -25,18 +25,19 @@ from accelerate.utils import set_seed from diffusers.pipeline_utils import DiffusionPipeline from diffusers.utils.import_utils import is_xformers_available from omegaconf import OmegaConf +from pathlib import Path from .args import metadata_from_png from .generator import infill_methods from .globals import Globals, global_cache_dir from .image_util import InitImageResizer, PngWriter, Txt2Mask, configure_model_padding from .model_management import ModelManager +from .safety_checker import SafetyChecker from .prompting import get_uc_and_c_and_ec from .prompting.conditioning import log_tokenization from .stable_diffusion import HuggingFaceConceptsLibrary from .util import choose_precision, choose_torch_device - def fix_func(orig): if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): @@ -222,6 +223,7 @@ class Generate: self.precision, max_loaded_models=max_loaded_models, sequential_offload=self.free_gpu_mem, + embedding_path=Path(self.embedding_path), ) # don't accept invalid models fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME @@ -244,31 +246,8 @@ class Generate: # load safety checker if requested if safety_checker: - try: - print(">> Initializing NSFW checker") - from diffusers.pipelines.stable_diffusion.safety_checker import ( - StableDiffusionSafetyChecker, - ) - from transformers import AutoFeatureExtractor - - safety_model_id = "CompVis/stable-diffusion-safety-checker" - safety_model_path = global_cache_dir("hub") - self.safety_checker = StableDiffusionSafetyChecker.from_pretrained( - safety_model_id, - local_files_only=True, - cache_dir=safety_model_path, - ) - self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained( - safety_model_id, - local_files_only=True, - cache_dir=safety_model_path, - ) - self.safety_checker.to(self.device) - except Exception: - print( - "** An error was encountered while installing the safety checker:" - ) - print(traceback.format_exc()) + print(">> Initializing NSFW checker") + self.safety_checker = SafetyChecker(self.device) else: print(">> NSFW checker is disabled") @@ -523,15 +502,6 @@ class Generate: generator.set_variation(self.seed, variation_amount, with_variations) generator.use_mps_noise = use_mps_noise - checker = ( - { - "checker": self.safety_checker, - "extractor": self.safety_feature_extractor, - } - if self.safety_checker - else None - ) - results = generator.generate( prompt, iterations=iterations, @@ -558,7 +528,7 @@ class Generate: embiggen_strength=embiggen_strength, inpaint_replace=inpaint_replace, mask_blur_radius=mask_blur_radius, - safety_checker=checker, + safety_checker=self.safety_checker, seam_size=seam_size, seam_blur=seam_blur, seam_strength=seam_strength, @@ -940,18 +910,6 @@ class Generate: self.generators = {} set_seed(random.randrange(0, np.iinfo(np.uint32).max)) - if self.embedding_path is not None: - print(f">> Loading embeddings from {self.embedding_path}") - for root, _, files in os.walk(self.embedding_path): - for name in files: - ti_path = os.path.join(root, name) - self.model.textual_inversion_manager.load_textual_inversion( - ti_path, defer_injecting_tokens=True - ) - print( - f'>> Textual inversion triggers: {", ".join(sorted(self.model.textual_inversion_manager.get_all_trigger_strings()))}' - ) - self.model_name = model_name self._set_scheduler() # requires self.model_name to be set first return self.model @@ -998,7 +956,7 @@ class Generate: ): results = [] for r in image_list: - image, seed = r + image, seed, _ = r try: if strength > 0: if self.gfpgan is not None or self.codeformer is not None: diff --git a/invokeai/backend/generator/__init__.py b/invokeai/backend/generator/__init__.py index b01e93ad81..9d6263453a 100644 --- a/invokeai/backend/generator/__init__.py +++ b/invokeai/backend/generator/__init__.py @@ -1,5 +1,13 @@ """ Initialization file for the invokeai.generator package """ -from .base import Generator +from .base import ( + InvokeAIGenerator, + InvokeAIGeneratorBasicParams, + InvokeAIGeneratorOutput, + Txt2Img, + Img2Img, + Inpaint, + Generator, +) from .inpaint import infill_methods diff --git a/invokeai/backend/generator/base.py b/invokeai/backend/generator/base.py index 881d3deaff..4ec0f9d54f 100644 --- a/invokeai/backend/generator/base.py +++ b/invokeai/backend/generator/base.py @@ -4,11 +4,15 @@ including img2img, txt2img, and inpaint """ from __future__ import annotations +import itertools +import dataclasses +import diffusers import os import random import traceback +from abc import ABCMeta +from argparse import Namespace from contextlib import nullcontext -from pathlib import Path import cv2 import numpy as np @@ -17,13 +21,258 @@ from PIL import Image, ImageChops, ImageFilter from accelerate.utils import set_seed from diffusers import DiffusionPipeline from tqdm import trange +from typing import List, Iterator, Type +from dataclasses import dataclass, field +from diffusers.schedulers import SchedulerMixin as Scheduler -import invokeai.assets.web as web_assets +from ..image_util import configure_model_padding from ..util.util import rand_perlin_2d +from ..safety_checker import SafetyChecker +from ..prompting.conditioning import get_uc_and_c_and_ec +from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline downsampling = 8 -CAUTION_IMG = "caution.png" +@dataclass +class InvokeAIGeneratorBasicParams: + seed: int=None + width: int=512 + height: int=512 + cfg_scale: int=7.5 + steps: int=20 + ddim_eta: float=0.0 + scheduler: int='ddim' + precision: str='float16' + perlin: float=0.0 + threshold: int=0.0 + seamless: bool=False + seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y']) + h_symmetry_time_pct: float=None + v_symmetry_time_pct: float=None + variation_amount: float = 0.0 + with_variations: list=field(default_factory=list) + safety_checker: SafetyChecker=None + +@dataclass +class InvokeAIGeneratorOutput: + ''' + InvokeAIGeneratorOutput is a dataclass that contains the outputs of a generation + operation, including the image, its seed, the model name used to generate the image + and the model hash, as well as all the generate() parameters that went into + generating the image (in .params, also available as attributes) + ''' + image: Image + seed: int + model_hash: str + attention_maps_images: List[Image] + params: Namespace + +# we are interposing a wrapper around the original Generator classes so that +# old code that calls Generate will continue to work. +class InvokeAIGenerator(metaclass=ABCMeta): + scheduler_map = dict( + ddim=diffusers.DDIMScheduler, + dpmpp_2=diffusers.DPMSolverMultistepScheduler, + k_dpm_2=diffusers.KDPM2DiscreteScheduler, + k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler, + k_dpmpp_2=diffusers.DPMSolverMultistepScheduler, + k_euler=diffusers.EulerDiscreteScheduler, + k_euler_a=diffusers.EulerAncestralDiscreteScheduler, + k_heun=diffusers.HeunDiscreteScheduler, + k_lms=diffusers.LMSDiscreteScheduler, + plms=diffusers.PNDMScheduler, + ) + + def __init__(self, + model_info: dict, + params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(), + ): + self.model_info=model_info + self.params=params + + def generate(self, + prompt: str='', + callback: callable=None, + step_callback: callable=None, + iterations: int=1, + **keyword_args, + )->Iterator[InvokeAIGeneratorOutput]: + ''' + Return an iterator across the indicated number of generations. + Each time the iterator is called it will return an InvokeAIGeneratorOutput + object. Use like this: + + outputs = txt2img.generate(prompt='banana sushi', iterations=5) + for result in outputs: + print(result.image, result.seed) + + In the typical case of wanting to get just a single image, iterations + defaults to 1 and do: + + output = next(txt2img.generate(prompt='banana sushi') + + Pass None to get an infinite iterator. + + outputs = txt2img.generate(prompt='banana sushi', iterations=None) + for o in outputs: + print(o.image, o.seed) + + ''' + generator_args = dataclasses.asdict(self.params) + generator_args.update(keyword_args) + + model_info = self.model_info + model_name = model_info['model_name'] + model:StableDiffusionGeneratorPipeline = model_info['model'] + model_hash = model_info['hash'] + scheduler: Scheduler = self.get_scheduler( + model=model, + scheduler_name=generator_args.get('scheduler') + ) + uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model) + gen_class = self._generator_class() + generator = gen_class(model, self.params.precision) + if self.params.variation_amount > 0: + generator.set_variation(generator_args.get('seed'), + generator_args.get('variation_amount'), + generator_args.get('with_variations') + ) + + if isinstance(model, DiffusionPipeline): + for component in [model.unet, model.vae]: + configure_model_padding(component, + generator_args.get('seamless',False), + generator_args.get('seamless_axes') + ) + else: + configure_model_padding(model, + generator_args.get('seamless',False), + generator_args.get('seamless_axes') + ) + + iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1) + for i in iteration_count: + results = generator.generate(prompt, + conditioning=(uc, c, extra_conditioning_info), + sampler=scheduler, + **generator_args, + ) + output = InvokeAIGeneratorOutput( + image=results[0][0], + seed=results[0][1], + attention_maps_images=results[0][2], + model_hash = model_hash, + params=Namespace(model_name=model_name,**generator_args), + ) + if callback: + callback(output) + yield output + + @classmethod + def schedulers(self)->List[str]: + ''' + Return list of all the schedulers that we currently handle. + ''' + return list(self.scheduler_map.keys()) + + def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]): + return generator_class(model, self.params.precision) + + def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler: + scheduler_class = self.scheduler_map.get(scheduler_name,'ddim') + scheduler = scheduler_class.from_config(model.scheduler.config) + # hack copied over from generate.py + if not hasattr(scheduler, 'uses_inpainting_model'): + scheduler.uses_inpainting_model = lambda: False + return scheduler + + @classmethod + def _generator_class(cls)->Type[Generator]: + ''' + In derived classes return the name of the generator to apply. + If you don't override will return the name of the derived + class, which nicely parallels the generator class names. + ''' + return Generator + +# ------------------------------------ +class Txt2Img(InvokeAIGenerator): + @classmethod + def _generator_class(cls): + from .txt2img import Txt2Img + return Txt2Img + +# ------------------------------------ +class Img2Img(InvokeAIGenerator): + def generate(self, + init_image: Image | torch.FloatTensor, + strength: float=0.75, + **keyword_args + )->List[InvokeAIGeneratorOutput]: + return super().generate(init_image=init_image, + strength=strength, + **keyword_args + ) + @classmethod + def _generator_class(cls): + from .img2img import Img2Img + return Img2Img + +# ------------------------------------ +# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff +class Inpaint(Img2Img): + def generate(self, + mask_image: Image | torch.FloatTensor, + # Seam settings - when 0, doesn't fill seam + seam_size: int = 0, + seam_blur: int = 0, + seam_strength: float = 0.7, + seam_steps: int = 10, + tile_size: int = 32, + inpaint_replace=False, + infill_method=None, + inpaint_width=None, + inpaint_height=None, + inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF), + **keyword_args + )->List[InvokeAIGeneratorOutput]: + return super().generate( + mask_image=mask_image, + seam_size=seam_size, + seam_blur=seam_blur, + seam_strength=seam_strength, + seam_steps=seam_steps, + tile_size=tile_size, + inpaint_replace=inpaint_replace, + infill_method=infill_method, + inpaint_width=inpaint_width, + inpaint_height=inpaint_height, + inpaint_fill=inpaint_fill, + **keyword_args + ) + @classmethod + def _generator_class(cls): + from .inpaint import Inpaint + return Inpaint + +# ------------------------------------ +class Embiggen(Txt2Img): + def generate( + self, + embiggen: list=None, + embiggen_tiles: list = None, + strength: float=0.75, + **kwargs)->List[InvokeAIGeneratorOutput]: + return super().generate(embiggen=embiggen, + embiggen_tiles=embiggen_tiles, + strength=strength, + **kwargs) + + @classmethod + def _generator_class(cls): + from .embiggen import Embiggen + return Embiggen + class Generator: downsampling_factor: int @@ -44,7 +293,6 @@ class Generator: self.with_variations = [] self.use_mps_noise = False self.free_gpu_mem = None - self.caution_img = None # this is going to be overridden in img2img.py, txt2img.py and inpaint.py def get_make_image(self, prompt, **kwargs): @@ -64,10 +312,10 @@ class Generator: def generate( self, prompt, - init_image, width, height, sampler, + init_image=None, iterations=1, seed=None, image_callback=None, @@ -76,7 +324,7 @@ class Generator: perlin=0.0, h_symmetry_time_pct=None, v_symmetry_time_pct=None, - safety_checker: dict = None, + safety_checker: SafetyChecker=None, free_gpu_mem: bool = False, **kwargs, ): @@ -130,9 +378,9 @@ class Generator: image = make_image(x_T) if self.safety_checker is not None: - image = self.safety_check(image) + image = self.safety_checker.check(image) - results.append([image, seed]) + results.append([image, seed, attention_maps_images]) if image_callback is not None: attention_maps_image = ( @@ -292,16 +540,6 @@ class Generator: seed = random.randrange(0, np.iinfo(np.uint32).max) return (seed, initial_noise) - # returns a tensor filled with random numbers from a normal distribution - def get_noise(self, width, height): - """ - Returns a tensor filled with random numbers, either form a normal distribution - (txt2img) or from the latent image (img2img, inpaint) - """ - raise NotImplementedError( - "get_noise() must be implemented in a descendent class" - ) - def get_perlin_noise(self, width, height): fixdevice = "cpu" if (self.model.device.type == "mps") else self.model.device # limit noise to only the diffusion image channels, not the mask channels @@ -361,53 +599,6 @@ class Generator: return v2 - def safety_check(self, image: Image.Image): - """ - If the CompViz safety checker flags an NSFW image, we - blur it out. - """ - import diffusers - - checker = self.safety_checker["checker"] - extractor = self.safety_checker["extractor"] - features = extractor([image], return_tensors="pt") - features.to(self.model.device) - - # unfortunately checker requires the numpy version, so we have to convert back - x_image = np.array(image).astype(np.float32) / 255.0 - x_image = x_image[None].transpose(0, 3, 1, 2) - - diffusers.logging.set_verbosity_error() - checked_image, has_nsfw_concept = checker( - images=x_image, clip_input=features.pixel_values - ) - if has_nsfw_concept[0]: - print( - "** An image with potential non-safe content has been detected. A blurred image will be returned. **" - ) - return self.blur(image) - else: - return image - - def blur(self, input): - blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32)) - try: - caution = self.get_caution_img() - if caution: - blurry.paste(caution, (0, 0), caution) - except FileNotFoundError: - pass - return blurry - - def get_caution_img(self): - path = None - if self.caution_img: - return self.caution_img - path = Path(web_assets.__path__[0]) / CAUTION_IMG - caution = Image.open(path) - self.caution_img = caution.resize((caution.width // 2, caution.height // 2)) - return self.caution_img - # this is a handy routine for debugging use. Given a generated sample, # convert it into a PNG image and store it at the indicated path def save_sample(self, sample, filepath): diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 4627f283f5..06b1490c93 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -34,8 +34,7 @@ from picklescan.scanner import scan_file_path from invokeai.backend.globals import Globals, global_cache_dir from ..stable_diffusion import StableDiffusionGeneratorPipeline -from ..util import CPU_DEVICE, ask_user, download_with_resume - +from ..util import CUDA_DEVICE, CPU_DEVICE, ask_user, download_with_resume class SDLegacyType(Enum): V1 = 1 @@ -51,23 +50,29 @@ VAE_TO_REPO_ID = { # hack, see note in convert_and_import() } class ModelManager(object): + ''' + Model manager handles loading, caching, importing, deleting, converting, and editing models. + ''' def __init__( - self, - config: OmegaConf, - device_type: torch.device = CPU_DEVICE, - precision: str = "float16", - max_loaded_models=DEFAULT_MAX_MODELS, - sequential_offload=False, + self, + config: OmegaConf|Path, + device_type: torch.device = CUDA_DEVICE, + precision: str = "float16", + max_loaded_models=DEFAULT_MAX_MODELS, + sequential_offload=False, + embedding_path: Path=None, ): """ - Initialize with the path to the models.yaml config file, - the torch device type, and precision. The optional - min_avail_mem argument specifies how much unused system - (CPU) memory to preserve. The cache of models in RAM will - grow until this value is approached. Default is 2G. + Initialize with the path to the models.yaml config file or + an initialized OmegaConf dictionary. Optional parameters + are the torch device type, precision, max_loaded_models, + and sequential_offload boolean. Note that the default device + type and precision are set up for a CUDA system running at half precision. """ # prevent nasty-looking CLIP log message transformers.logging.set_verbosity_error() + if not isinstance(config, DictConfig): + config = OmegaConf.load(config) self.config = config self.precision = precision self.device = torch.device(device_type) @@ -76,6 +81,7 @@ class ModelManager(object): self.stack = [] # this is an LRU FIFO self.current_model = None self.sequential_offload = sequential_offload + self.embedding_path = embedding_path def valid_model(self, model_name: str) -> bool: """ @@ -84,12 +90,15 @@ class ModelManager(object): """ return model_name in self.config - def get_model(self, model_name: str): + def get_model(self, model_name: str=None)->dict: """ Given a model named identified in models.yaml, return the model object. If in RAM will load into GPU VRAM. If on disk, will load from there. """ + if not model_name: + return self.current_model if self.current_model else self.get_model(self.default_model()) + if not self.valid_model(model_name): print( f'** "{model_name}" is not a known model name. Please check your models.yaml file' @@ -112,6 +121,7 @@ class ModelManager(object): else: # we're about to load a new model, so potentially offload the least recently used one requested_model, width, height, hash = self._load_model(model_name) self.models[model_name] = { + "model_name": model_name, "model": requested_model, "width": width, "height": height, @@ -121,6 +131,7 @@ class ModelManager(object): self.current_model = model_name self._push_newest_model(model_name) return { + "model_name": model_name, "model": requested_model, "width": width, "height": height, @@ -425,6 +436,7 @@ class ModelManager(object): height = width print(f" | Default image dimensions = {width} x {height}") + self._add_embeddings_to_model(pipeline) return pipeline, width, height, model_hash @@ -1061,6 +1073,19 @@ class ModelManager(object): self.stack.remove(model_name) self.stack.append(model_name) + def _add_embeddings_to_model(self, model: StableDiffusionGeneratorPipeline): + if self.embedding_path is not None: + print(f">> Loading embeddings from {self.embedding_path}") + for root, _, files in os.walk(self.embedding_path): + for name in files: + ti_path = os.path.join(root, name) + model.textual_inversion_manager.load_textual_inversion( + ti_path, defer_injecting_tokens=True + ) + print( + f'>> Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}' + ) + def _has_cuda(self) -> bool: return self.device.type == "cuda" diff --git a/invokeai/backend/safety_checker.py b/invokeai/backend/safety_checker.py new file mode 100644 index 0000000000..2e6c4fd479 --- /dev/null +++ b/invokeai/backend/safety_checker.py @@ -0,0 +1,82 @@ +''' +SafetyChecker class - checks images against the StabilityAI NSFW filter +and blurs images that contain potential NSFW content. +''' +import diffusers +import numpy as np +import torch +import traceback +from diffusers.pipelines.stable_diffusion.safety_checker import ( + StableDiffusionSafetyChecker, +) +from pathlib import Path +from PIL import Image, ImageFilter +from transformers import AutoFeatureExtractor + +import invokeai.assets.web as web_assets +from .globals import global_cache_dir +from .util import CPU_DEVICE + +class SafetyChecker(object): + CAUTION_IMG = "caution.png" + + def __init__(self, device: torch.device): + path = Path(web_assets.__path__[0]) / self.CAUTION_IMG + caution = Image.open(path) + self.caution_img = caution.resize((caution.width // 2, caution.height // 2)) + self.device = device + + try: + safety_model_id = "CompVis/stable-diffusion-safety-checker" + safety_model_path = global_cache_dir("hub") + self.safety_checker = StableDiffusionSafetyChecker.from_pretrained( + safety_model_id, + local_files_only=True, + cache_dir=safety_model_path, + ) + self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained( + safety_model_id, + local_files_only=True, + cache_dir=safety_model_path, + ) + except Exception: + print( + "** An error was encountered while installing the safety checker:" + ) + print(traceback.format_exc()) + + def check(self, image: Image.Image): + """ + Check provided image against the StabilityAI safety checker and return + + """ + + self.safety_checker.to(self.device) + features = self.safety_feature_extractor([image], return_tensors="pt") + features.to(self.device) + + # unfortunately checker requires the numpy version, so we have to convert back + x_image = np.array(image).astype(np.float32) / 255.0 + x_image = x_image[None].transpose(0, 3, 1, 2) + + diffusers.logging.set_verbosity_error() + checked_image, has_nsfw_concept = self.safety_checker( + images=x_image, clip_input=features.pixel_values + ) + self.safety_checker.to(CPU_DEVICE) # offload + if has_nsfw_concept[0]: + print( + "** An image with potential non-safe content has been detected. A blurred image will be returned. **" + ) + return self.blur(image) + else: + return image + + def blur(self, input): + blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32)) + try: + if caution := self.caution_img: + blurry.paste(caution, (0, 0), caution) + except FileNotFoundError: + pass + return blurry diff --git a/static/dream_web/favicon.ico b/static/dream_web/favicon.ico new file mode 100644 index 0000000000..51eb844a6a Binary files /dev/null and b/static/dream_web/favicon.ico differ diff --git a/static/dream_web/index.css b/static/dream_web/index.css new file mode 100644 index 0000000000..25a0994a3d --- /dev/null +++ b/static/dream_web/index.css @@ -0,0 +1,179 @@ +:root { + --fields-dark:#DCDCDC; + --fields-light:#F5F5F5; +} + +* { + font-family: 'Arial'; + font-size: 100%; +} +body { + font-size: 1em; +} +textarea { + font-size: 0.95em; +} +header, form, #progress-section { + margin-left: auto; + margin-right: auto; + max-width: 1024px; + text-align: center; +} +fieldset { + border: none; + line-height: 2.2em; +} +fieldset > legend { + width: auto; + margin-left: 0; + margin-right: auto; + font-weight:bold; +} +select, input { + margin-right: 10px; + padding: 2px; +} +input:disabled { + cursor:auto; +} +input[type=submit] { + cursor: pointer; + background-color: #666; + color: white; +} +input[type=checkbox] { + cursor: pointer; + margin-right: 0px; + width: 20px; + height: 20px; + vertical-align: middle; +} +input#seed { + margin-right: 0px; +} +div { + padding: 10px 10px 10px 10px; +} +header { + margin-bottom: 16px; +} +header h1 { + margin-bottom: 0; + font-size: 2em; +} +#search-box { + display: flex; +} +#scaling-inprocess-message { + font-weight: bold; + font-style: italic; + display: none; +} +#prompt { + flex-grow: 1; + padding: 5px 10px 5px 10px; + border: 1px solid #999; + outline: none; +} +#submit { + padding: 5px 10px 5px 10px; + border: 1px solid #999; +} +#reset-all, #remove-image { + margin-top: 12px; + font-size: 0.8em; + background-color: pink; + border: 1px solid #999; + border-radius: 4px; +} +#results { + text-align: center; + margin: auto; + padding-top: 10px; +} +#results figure { + display: inline-block; + margin: 10px; +} +#results figcaption { + font-size: 0.8em; + padding: 3px; + color: #888; + cursor: pointer; +} +#results img { + border-radius: 5px; + object-fit: contain; + background-color: var(--fields-dark); +} +#fieldset-config { + line-height:2em; +} +input[type="number"] { + width: 60px; +} +#seed { + width: 150px; +} +button#reset-seed { + font-size: 1.7em; + background: #efefef; + border: 1px solid #999; + border-radius: 4px; + line-height: 0.8; + margin: 0 10px 0 0; + padding: 0 5px 3px; + vertical-align: middle; +} +label { + white-space: nowrap; +} +#progress-section { + display: none; +} +#progress-image { + width: 30vh; + height: 30vh; + object-fit: contain; + background-color: var(--fields-dark); +} +#cancel-button { + cursor: pointer; + color: red; +} +#txt2img { + background-color: var(--fields-dark); +} +#variations { + background-color: var(--fields-light); +} +#initimg { + background-color: var(--fields-dark); +} +#img2img { + background-color: var(--fields-light); +} +#initimg > :not(legend) { + background-color: var(--fields-light); + margin: .5em; +} + +#postprocess, #initimg { + display:flex; + flex-wrap:wrap; + padding: 0; + margin-top: 1em; + background-color: var(--fields-dark); +} +#postprocess > fieldset, #initimg > * { + flex-grow: 1; +} +#postprocess > fieldset { + background-color: var(--fields-dark); +} +#progress-section { + background-color: var(--fields-light); +} +#no-results-message:not(:only-child) { + display: none; +} diff --git a/static/dream_web/index.html b/static/dream_web/index.html new file mode 100644 index 0000000000..feb542adb2 --- /dev/null +++ b/static/dream_web/index.html @@ -0,0 +1,187 @@ + + + + Stable Diffusion Dream Server + + + + + + + + + + + +
+

Stable Diffusion Dream Server

+
+ For news and support for this web service, visit our GitHub + site +
+
+ +
+ +
+
+ + + + + + + + + + + + + + + +
+ + + + + + + + + +
+ + + + + +
+
+ + + + +
+
+
+ + + + +
+ + + +
+
+ + + + + + + + +
+
+
+
+ + + + + + +
+
+ + + + + + + + +
+
+ +
+
+
+
+ + +
+ +
+ Postprocessing...1/3 +
+
+
+ +
+
+
+ + + diff --git a/static/dream_web/index.js b/static/dream_web/index.js new file mode 100644 index 0000000000..55bc4bdb8f --- /dev/null +++ b/static/dream_web/index.js @@ -0,0 +1,396 @@ +const socket = io(); + +var priorResultsLoadState = { + page: 0, + pages: 1, + per_page: 10, + total: 20, + offset: 0, // number of items generated since last load + loading: false, + initialized: false +}; + +function loadPriorResults() { + // Fix next page by offset + let offsetPages = priorResultsLoadState.offset / priorResultsLoadState.per_page; + priorResultsLoadState.page += offsetPages; + priorResultsLoadState.pages += offsetPages; + priorResultsLoadState.total += priorResultsLoadState.offset; + priorResultsLoadState.offset = 0; + + if (priorResultsLoadState.loading) { + return; + } + + if (priorResultsLoadState.page >= priorResultsLoadState.pages) { + return; // Nothing more to load + } + + // Load + priorResultsLoadState.loading = true + let url = new URL('/api/images', document.baseURI); + url.searchParams.append('page', priorResultsLoadState.initialized ? priorResultsLoadState.page + 1 : priorResultsLoadState.page); + url.searchParams.append('per_page', priorResultsLoadState.per_page); + fetch(url.href, { + method: 'GET', + headers: new Headers({'content-type': 'application/json'}) + }) + .then(response => response.json()) + .then(data => { + priorResultsLoadState.page = data.page; + priorResultsLoadState.pages = data.pages; + priorResultsLoadState.per_page = data.per_page; + priorResultsLoadState.total = data.total; + + data.items.forEach(function(dreamId, index) { + let src = 'api/images/' + dreamId; + fetch('/api/images/' + dreamId + '/metadata', { + method: 'GET', + headers: new Headers({'content-type': 'application/json'}) + }) + .then(response => response.json()) + .then(metadata => { + let seed = metadata.seed || 0; // TODO: Parse old metadata + appendOutput(src, seed, metadata, true); + }); + }); + + // Load until page is full + if (!priorResultsLoadState.initialized) { + if (document.body.scrollHeight <= window.innerHeight) { + loadPriorResults(); + } + } + }) + .finally(() => { + priorResultsLoadState.loading = false; + priorResultsLoadState.initialized = true; + }); +} + +function resetForm() { + var form = document.getElementById('generate-form'); + form.querySelector('fieldset').removeAttribute('disabled'); +} + +function initProgress(totalSteps, showProgressImages) { + // TODO: Progress could theoretically come from multiple jobs at the same time (in the future) + let progressSectionEle = document.querySelector('#progress-section'); + progressSectionEle.style.display = 'initial'; + let progressEle = document.querySelector('#progress-bar'); + progressEle.setAttribute('max', totalSteps); + + let progressImageEle = document.querySelector('#progress-image'); + progressImageEle.src = BLANK_IMAGE_URL; + progressImageEle.style.display = showProgressImages ? 'initial': 'none'; +} + +function setProgress(step, totalSteps, src) { + let progressEle = document.querySelector('#progress-bar'); + progressEle.setAttribute('value', step); + + if (src) { + let progressImageEle = document.querySelector('#progress-image'); + progressImageEle.src = src; + } +} + +function resetProgress(hide = true) { + if (hide) { + let progressSectionEle = document.querySelector('#progress-section'); + progressSectionEle.style.display = 'none'; + } + let progressEle = document.querySelector('#progress-bar'); + progressEle.setAttribute('value', 0); +} + +function toBase64(file) { + return new Promise((resolve, reject) => { + const r = new FileReader(); + r.readAsDataURL(file); + r.onload = () => resolve(r.result); + r.onerror = (error) => reject(error); + }); +} + +function ondragdream(event) { + let dream = event.target.dataset.dream; + event.dataTransfer.setData("dream", dream); +} + +function seedClick(event) { + // Get element + var image = event.target.closest('figure').querySelector('img'); + var dream = JSON.parse(decodeURIComponent(image.dataset.dream)); + + let form = document.querySelector("#generate-form"); + for (const [k, v] of new FormData(form)) { + if (k == 'initimg') { continue; } + let formElem = form.querySelector(`*[name=${k}]`); + formElem.value = dream[k] !== undefined ? dream[k] : formElem.defaultValue; + } + + document.querySelector("#seed").value = dream.seed; + document.querySelector('#iterations').value = 1; // Reset to 1 iteration since we clicked a single image (not a full job) + + // NOTE: leaving this manual for the user for now - it was very confusing with this behavior + // document.querySelector("#with_variations").value = variations || ''; + // if (document.querySelector("#variation_amount").value <= 0) { + // document.querySelector("#variation_amount").value = 0.2; + // } + + saveFields(document.querySelector("#generate-form")); +} + +function appendOutput(src, seed, config, toEnd=false) { + let outputNode = document.createElement("figure"); + let altText = seed.toString() + " | " + config.prompt; + + // img needs width and height for lazy loading to work + // TODO: store the full config in a data attribute on the image? + const figureContents = ` + + ${altText} + +
${seed}
+ `; + + outputNode.innerHTML = figureContents; + + if (toEnd) { + document.querySelector("#results").append(outputNode); + } else { + document.querySelector("#results").prepend(outputNode); + } + document.querySelector("#no-results-message")?.remove(); +} + +function saveFields(form) { + for (const [k, v] of new FormData(form)) { + if (typeof v !== 'object') { // Don't save 'file' type + localStorage.setItem(k, v); + } + } +} + +function loadFields(form) { + for (const [k, v] of new FormData(form)) { + const item = localStorage.getItem(k); + if (item != null) { + form.querySelector(`*[name=${k}]`).value = item; + } + } +} + +function clearFields(form) { + localStorage.clear(); + let prompt = form.prompt.value; + form.reset(); + form.prompt.value = prompt; +} + +const BLANK_IMAGE_URL = 'data:image/svg+xml,'; +async function generateSubmit(form) { + // Convert file data to base64 + // TODO: Should probably uplaod files with formdata or something, and store them in the backend? + let formData = Object.fromEntries(new FormData(form)); + if (!formData.enable_generate && !formData.enable_init_image) { + gen_label = document.querySelector("label[for=enable_generate]").innerHTML; + initimg_label = document.querySelector("label[for=enable_init_image]").innerHTML; + alert(`Error: one of "${gen_label}" or "${initimg_label}" must be set`); + } + + + formData.initimg_name = formData.initimg.name + formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null; + + // Evaluate all checkboxes + let checkboxes = form.querySelectorAll('input[type=checkbox]'); + checkboxes.forEach(function (checkbox) { + if (checkbox.checked) { + formData[checkbox.name] = 'true'; + } + }); + + let strength = formData.strength; + let totalSteps = formData.initimg ? Math.floor(strength * formData.steps) : formData.steps; + let showProgressImages = formData.progress_images; + + // Set enabling flags + + + // Initialize the progress bar + initProgress(totalSteps, showProgressImages); + + // POST, use response to listen for events + fetch(form.action, { + method: form.method, + headers: new Headers({'content-type': 'application/json'}), + body: JSON.stringify(formData), + }) + .then(response => response.json()) + .then(data => { + var jobId = data.jobId; + socket.emit('join_room', { 'room': jobId }); + }); + + form.querySelector('fieldset').setAttribute('disabled',''); +} + +function fieldSetEnableChecked(event) { + cb = event.target; + fields = cb.closest('fieldset'); + fields.disabled = !cb.checked; +} + +// Socket listeners +socket.on('job_started', (data) => {}) + +socket.on('dream_result', (data) => { + var jobId = data.jobId; + var dreamId = data.dreamId; + var dreamRequest = data.dreamRequest; + var src = 'api/images/' + dreamId; + + priorResultsLoadState.offset += 1; + appendOutput(src, dreamRequest.seed, dreamRequest); + + resetProgress(false); +}) + +socket.on('dream_progress', (data) => { + // TODO: it'd be nice if we could get a seed reported here, but the generator would need to be updated + var step = data.step; + var totalSteps = data.totalSteps; + var jobId = data.jobId; + var dreamId = data.dreamId; + + var progressType = data.progressType + if (progressType === 'GENERATION') { + var src = data.hasProgressImage ? + 'api/intermediates/' + dreamId + '/' + step + : null; + setProgress(step, totalSteps, src); + } else if (progressType === 'UPSCALING_STARTED') { + // step and totalSteps are used for upscale count on this message + document.getElementById("processing_cnt").textContent = step; + document.getElementById("processing_total").textContent = totalSteps; + document.getElementById("scaling-inprocess-message").style.display = "block"; + } else if (progressType == 'UPSCALING_DONE') { + document.getElementById("scaling-inprocess-message").style.display = "none"; + } +}) + +socket.on('job_canceled', (data) => { + resetForm(); + resetProgress(); +}) + +socket.on('job_done', (data) => { + jobId = data.jobId + socket.emit('leave_room', { 'room': jobId }); + + resetForm(); + resetProgress(); +}) + +window.onload = async () => { + document.querySelector("#prompt").addEventListener("keydown", (e) => { + if (e.key === "Enter" && !e.shiftKey) { + const form = e.target.form; + generateSubmit(form); + } + }); + document.querySelector("#generate-form").addEventListener('submit', (e) => { + e.preventDefault(); + const form = e.target; + + generateSubmit(form); + }); + document.querySelector("#generate-form").addEventListener('change', (e) => { + saveFields(e.target.form); + }); + document.querySelector("#reset-seed").addEventListener('click', (e) => { + document.querySelector("#seed").value = 0; + saveFields(e.target.form); + }); + document.querySelector("#reset-all").addEventListener('click', (e) => { + clearFields(e.target.form); + }); + document.querySelector("#remove-image").addEventListener('click', (e) => { + initimg.value=null; + }); + loadFields(document.querySelector("#generate-form")); + + document.querySelector('#cancel-button').addEventListener('click', () => { + fetch('/api/cancel').catch(e => { + console.error(e); + }); + }); + document.documentElement.addEventListener('keydown', (e) => { + if (e.key === "Escape") + fetch('/api/cancel').catch(err => { + console.error(err); + }); + }); + + if (!config.gfpgan_model_exists) { + document.querySelector("#gfpgan").style.display = 'none'; + } + + window.addEventListener("scroll", () => { + if ((window.innerHeight + window.pageYOffset) >= document.body.offsetHeight) { + loadPriorResults(); + } + }); + + + + // Enable/disable forms by checkboxes + document.querySelectorAll("legend > input[type=checkbox]").forEach(function(cb) { + cb.addEventListener('change', fieldSetEnableChecked); + fieldSetEnableChecked({ target: cb}) + }); + + + // Load some of the previous results + loadPriorResults(); + + // Image drop/upload WIP + /* + let drop = document.getElementById('dropper'); + function ondrop(event) { + let dreamData = event.dataTransfer.getData('dream'); + if (dreamData) { + var dream = JSON.parse(decodeURIComponent(dreamData)); + alert(dream.dreamId); + } + }; + + function ondragenter(event) { + event.preventDefault(); + }; + + function ondragover(event) { + event.preventDefault(); + }; + + function ondragleave(event) { + + } + + drop.addEventListener('drop', ondrop); + drop.addEventListener('dragenter', ondragenter); + drop.addEventListener('dragover', ondragover); + drop.addEventListener('dragleave', ondragleave); + */ +}; diff --git a/static/legacy_web/favicon.ico b/static/legacy_web/favicon.ico new file mode 100644 index 0000000000..51eb844a6a Binary files /dev/null and b/static/legacy_web/favicon.ico differ diff --git a/static/legacy_web/index.css b/static/legacy_web/index.css new file mode 100644 index 0000000000..51f0f267c3 --- /dev/null +++ b/static/legacy_web/index.css @@ -0,0 +1,152 @@ +* { + font-family: 'Arial'; + font-size: 100%; +} +body { + font-size: 1em; +} +textarea { + font-size: 0.95em; +} +header, form, #progress-section { + margin-left: auto; + margin-right: auto; + max-width: 1024px; + text-align: center; +} +fieldset { + border: none; + line-height: 2.2em; +} +select, input { + margin-right: 10px; + padding: 2px; +} +input[type=submit] { + background-color: #666; + color: white; +} +input[type=checkbox] { + margin-right: 0px; + width: 20px; + height: 20px; + vertical-align: middle; +} +input#seed { + margin-right: 0px; +} +div { + padding: 10px 10px 10px 10px; +} +header { + margin-bottom: 16px; +} +header h1 { + margin-bottom: 0; + font-size: 2em; +} +#search-box { + display: flex; +} +#scaling-inprocess-message { + font-weight: bold; + font-style: italic; + display: none; +} +#prompt { + flex-grow: 1; + padding: 5px 10px 5px 10px; + border: 1px solid #999; + outline: none; +} +#submit { + padding: 5px 10px 5px 10px; + border: 1px solid #999; +} +#reset-all, #remove-image { + margin-top: 12px; + font-size: 0.8em; + background-color: pink; + border: 1px solid #999; + border-radius: 4px; +} +#results { + text-align: center; + margin: auto; + padding-top: 10px; +} +#results figure { + display: inline-block; + margin: 10px; +} +#results figcaption { + font-size: 0.8em; + padding: 3px; + color: #888; + cursor: pointer; +} +#results img { + border-radius: 5px; + object-fit: cover; +} +#fieldset-config { + line-height:2em; + background-color: #F0F0F0; +} +input[type="number"] { + width: 60px; +} +#seed { + width: 150px; +} +button#reset-seed { + font-size: 1.7em; + background: #efefef; + border: 1px solid #999; + border-radius: 4px; + line-height: 0.8; + margin: 0 10px 0 0; + padding: 0 5px 3px; + vertical-align: middle; +} +label { + white-space: nowrap; +} +#progress-section { + display: none; +} +#progress-image { + width: 30vh; + height: 30vh; +} +#cancel-button { + cursor: pointer; + color: red; +} +#basic-parameters { + background-color: #EEEEEE; +} +#txt2img { + background-color: #DCDCDC; +} +#variations { + background-color: #EEEEEE; +} +#img2img { + background-color: #DCDCDC; +} +#gfpgan { + background-color: #EEEEEE; +} +#progress-section { + background-color: #F5F5F5; +} +.section-header { + text-align: left; + font-weight: bold; + padding: 0 0 0 0; +} +#no-results-message:not(:only-child) { + display: none; +} + diff --git a/static/legacy_web/index.html b/static/legacy_web/index.html new file mode 100644 index 0000000000..c96eed54c3 --- /dev/null +++ b/static/legacy_web/index.html @@ -0,0 +1,137 @@ + + + Stable Diffusion Dream Server + + + + + + + + +
+

Stable Diffusion Dream Server

+
+ For news and support for this web service, visit our GitHub site +
+
+ +
+
+
+ +
+
+
Basic options
+ + + + + + + + + + +
+ + + + + + + + + +
+ + + + + +
+ + + + + + +
+
+
Image-to-image options
+ + + +
+ + + + +
+
+
Post-processing options
+ + + + + + +
+
+
+
+
+ + +
+ +
+ Postprocessing...1/3 +
+ +
+ +
+
+

No results...

+
+
+
+ + diff --git a/static/legacy_web/index.js b/static/legacy_web/index.js new file mode 100644 index 0000000000..57ad076062 --- /dev/null +++ b/static/legacy_web/index.js @@ -0,0 +1,213 @@ +function toBase64(file) { + return new Promise((resolve, reject) => { + const r = new FileReader(); + r.readAsDataURL(file); + r.onload = () => resolve(r.result); + r.onerror = (error) => reject(error); + }); +} + +function appendOutput(src, seed, config) { + let outputNode = document.createElement("figure"); + + let variations = config.with_variations; + if (config.variation_amount > 0) { + variations = (variations ? variations + ',' : '') + seed + ':' + config.variation_amount; + } + let baseseed = (config.with_variations || config.variation_amount > 0) ? config.seed : seed; + let altText = baseseed + ' | ' + (variations ? variations + ' | ' : '') + config.prompt; + + // img needs width and height for lazy loading to work + const figureContents = ` + + ${altText} + +
${seed}
+ `; + + outputNode.innerHTML = figureContents; + let figcaption = outputNode.querySelector('figcaption'); + + // Reload image config + figcaption.addEventListener('click', () => { + let form = document.querySelector("#generate-form"); + for (const [k, v] of new FormData(form)) { + if (k == 'initimg') { continue; } + form.querySelector(`*[name=${k}]`).value = config[k]; + } + + document.querySelector("#seed").value = baseseed; + document.querySelector("#with_variations").value = variations || ''; + if (document.querySelector("#variation_amount").value <= 0) { + document.querySelector("#variation_amount").value = 0.2; + } + + saveFields(document.querySelector("#generate-form")); + }); + + document.querySelector("#results").prepend(outputNode); +} + +function saveFields(form) { + for (const [k, v] of new FormData(form)) { + if (typeof v !== 'object') { // Don't save 'file' type + localStorage.setItem(k, v); + } + } +} + +function loadFields(form) { + for (const [k, v] of new FormData(form)) { + const item = localStorage.getItem(k); + if (item != null) { + form.querySelector(`*[name=${k}]`).value = item; + } + } +} + +function clearFields(form) { + localStorage.clear(); + let prompt = form.prompt.value; + form.reset(); + form.prompt.value = prompt; +} + +const BLANK_IMAGE_URL = 'data:image/svg+xml,'; +async function generateSubmit(form) { + const prompt = document.querySelector("#prompt").value; + + // Convert file data to base64 + let formData = Object.fromEntries(new FormData(form)); + formData.initimg_name = formData.initimg.name + formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null; + + let strength = formData.strength; + let totalSteps = formData.initimg ? Math.floor(strength * formData.steps) : formData.steps; + + let progressSectionEle = document.querySelector('#progress-section'); + progressSectionEle.style.display = 'initial'; + let progressEle = document.querySelector('#progress-bar'); + progressEle.setAttribute('max', totalSteps); + let progressImageEle = document.querySelector('#progress-image'); + progressImageEle.src = BLANK_IMAGE_URL; + + progressImageEle.style.display = {}.hasOwnProperty.call(formData, 'progress_images') ? 'initial': 'none'; + + // Post as JSON, using Fetch streaming to get results + fetch(form.action, { + method: form.method, + body: JSON.stringify(formData), + }).then(async (response) => { + const reader = response.body.getReader(); + + let noOutputs = true; + while (true) { + let {value, done} = await reader.read(); + value = new TextDecoder().decode(value); + if (done) { + progressSectionEle.style.display = 'none'; + break; + } + + for (let event of value.split('\n').filter(e => e !== '')) { + const data = JSON.parse(event); + + if (data.event === 'result') { + noOutputs = false; + appendOutput(data.url, data.seed, data.config); + progressEle.setAttribute('value', 0); + progressEle.setAttribute('max', totalSteps); + } else if (data.event === 'upscaling-started') { + document.getElementById("processing_cnt").textContent=data.processed_file_cnt; + document.getElementById("scaling-inprocess-message").style.display = "block"; + } else if (data.event === 'upscaling-done') { + document.getElementById("scaling-inprocess-message").style.display = "none"; + } else if (data.event === 'step') { + progressEle.setAttribute('value', data.step); + if (data.url) { + progressImageEle.src = data.url; + } + } else if (data.event === 'canceled') { + // avoid alerting as if this were an error case + noOutputs = false; + } + } + } + + // Re-enable form, remove no-results-message + form.querySelector('fieldset').removeAttribute('disabled'); + document.querySelector("#prompt").value = prompt; + document.querySelector('progress').setAttribute('value', '0'); + + if (noOutputs) { + alert("Error occurred while generating."); + } + }); + + // Disable form while generating + form.querySelector('fieldset').setAttribute('disabled',''); + document.querySelector("#prompt").value = `Generating: "${prompt}"`; +} + +async function fetchRunLog() { + try { + let response = await fetch('/run_log.json') + const data = await response.json(); + for(let item of data.run_log) { + appendOutput(item.url, item.seed, item); + } + } catch (e) { + console.error(e); + } +} + +window.onload = async () => { + document.querySelector("#prompt").addEventListener("keydown", (e) => { + if (e.key === "Enter" && !e.shiftKey) { + const form = e.target.form; + generateSubmit(form); + } + }); + document.querySelector("#generate-form").addEventListener('submit', (e) => { + e.preventDefault(); + const form = e.target; + + generateSubmit(form); + }); + document.querySelector("#generate-form").addEventListener('change', (e) => { + saveFields(e.target.form); + }); + document.querySelector("#reset-seed").addEventListener('click', (e) => { + document.querySelector("#seed").value = -1; + saveFields(e.target.form); + }); + document.querySelector("#reset-all").addEventListener('click', (e) => { + clearFields(e.target.form); + }); + document.querySelector("#remove-image").addEventListener('click', (e) => { + initimg.value=null; + }); + loadFields(document.querySelector("#generate-form")); + + document.querySelector('#cancel-button').addEventListener('click', () => { + fetch('/cancel').catch(e => { + console.error(e); + }); + }); + document.documentElement.addEventListener('keydown', (e) => { + if (e.key === "Escape") + fetch('/cancel').catch(err => { + console.error(err); + }); + }); + + if (!config.gfpgan_model_exists) { + document.querySelector("#gfpgan").style.display = 'none'; + } + await fetchRunLog() +}; diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py index 4c22507098..b722539935 100644 --- a/tests/nodes/test_graph_execution_state.py +++ b/tests/nodes/test_graph_execution_state.py @@ -21,12 +21,13 @@ def simple_graph(): def mock_services(): # NOTE: none of these are actually called by the test invocations return InvocationServices( - generate = None, + model_manager = None, events = None, images = None, queue = MemoryInvocationQueue(), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), - processor = DefaultInvocationProcessor() + processor = DefaultInvocationProcessor(), + restoration = None, ) def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]: diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index 6a7867bffe..718baa7a1f 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -21,12 +21,13 @@ def simple_graph(): def mock_services() -> InvocationServices: # NOTE: none of these are actually called by the test invocations return InvocationServices( - generate = None, # type: ignore + model_manager = None, # type: ignore events = TestEventService(), images = None, # type: ignore queue = MemoryInvocationQueue(), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), - processor = DefaultInvocationProcessor() + processor = DefaultInvocationProcessor(), + restoration = None, ) @pytest.fixture()