diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index cb867354a5..649c8f7f18 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -20,7 +20,6 @@ from invokeai.version.invokeai_version import __version__ from ..services.default_graphs import create_system_graphs from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage -from ..services.restoration_services import RestorationServices from ..services.graph import GraphExecutionState, LibraryGraph from ..services.image_file_storage import DiskImageFileStorage from ..services.invocation_queue import MemoryInvocationQueue @@ -58,7 +57,7 @@ class ApiDependencies: @staticmethod def initialize(config, event_handler_id: int, logger: Logger = logger): - logger.debug(f'InvokeAI version {__version__}') + logger.debug(f"InvokeAI version {__version__}") logger.debug(f"Internet connectivity is {config.internet_available}") events = FastAPIEventService(event_handler_id) @@ -117,7 +116,7 @@ class ApiDependencies: ) services = InvocationServices( - model_manager=ModelManagerService(config,logger), + model_manager=ModelManagerService(config, logger), events=events, latents=latents, images=images, @@ -129,7 +128,6 @@ class ApiDependencies: ), graph_execution_manager=graph_execution_manager, processor=DefaultInvocationProcessor(), - restoration=RestorationServices(config, logger), configuration=config, logger=logger, ) diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py index 888d36c4bf..c88cefe2eb 100644 --- a/invokeai/app/cli_app.py +++ b/invokeai/app/cli_app.py @@ -54,7 +54,6 @@ from .services.invocation_services import InvocationServices from .services.invoker import Invoker from .services.model_manager_service import ModelManagerService from .services.processor import DefaultInvocationProcessor -from .services.restoration_services import RestorationServices from .services.sqlite import SqliteItemStorage import torch @@ -295,7 +294,6 @@ def invoke_cli(): ), graph_execution_manager=graph_execution_manager, processor=DefaultInvocationProcessor(), - restoration=RestorationServices(config,logger=logger), logger=logger, configuration=config, ) diff --git a/invokeai/app/invocations/reconstruct.py b/invokeai/app/invocations/reconstruct.py deleted file mode 100644 index 84ace2eefe..0000000000 --- a/invokeai/app/invocations/reconstruct.py +++ /dev/null @@ -1,55 +0,0 @@ -from typing import Literal, Optional - -from pydantic import Field - -from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin - -from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig -from .image import ImageOutput - - -class RestoreFaceInvocation(BaseInvocation): - """Restores faces in an image.""" - - # fmt: off - type: Literal["restore_face"] = "restore_face" - - # Inputs - image: Optional[ImageField] = Field(description="The input image") - strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" ) - # fmt: on - - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "tags": ["restoration", "image"], - }, - } - - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) - results = context.services.restoration.upscale_and_reconstruct( - image_list=[[image, 0]], - upscale=None, - strength=self.strength, # GFPGAN strength - save_original=False, - image_callback=None, - ) - - # Results are image and seed, unwrap for now - # TODO: can this return multiple results? - image_dto = context.services.images.create( - image=results[0][0], - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - ) - - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 4e1da3b040..2f45863bc6 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -10,7 +10,6 @@ if TYPE_CHECKING: from invokeai.app.services.model_manager_service import ModelManagerServiceBase from invokeai.app.services.events import EventServiceBase from invokeai.app.services.latent_storage import LatentsStorageBase - from invokeai.app.services.restoration_services import RestorationServices from invokeai.app.services.invocation_queue import InvocationQueueABC from invokeai.app.services.item_storage import ItemStorageABC from invokeai.app.services.config import InvokeAISettings @@ -34,7 +33,6 @@ class InvocationServices: model_manager: "ModelManagerServiceBase" processor: "InvocationProcessorABC" queue: "InvocationQueueABC" - restoration: "RestorationServices" def __init__( self, @@ -50,7 +48,6 @@ class InvocationServices: model_manager: "ModelManagerServiceBase", processor: "InvocationProcessorABC", queue: "InvocationQueueABC", - restoration: "RestorationServices", ): self.board_images = board_images self.boards = boards @@ -65,4 +62,3 @@ class InvocationServices: self.model_manager = model_manager self.processor = processor self.queue = queue - self.restoration = restoration diff --git a/invokeai/app/services/restoration_services.py b/invokeai/app/services/restoration_services.py deleted file mode 100644 index 5ff0195ca5..0000000000 --- a/invokeai/app/services/restoration_services.py +++ /dev/null @@ -1,113 +0,0 @@ -import sys -import traceback -import torch -from typing import types -from ...backend.restoration import Restoration -from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE - -# This should be a real base class for postprocessing functions, -# but right now we just instantiate the existing gfpgan, esrgan -# and codeformer functions. -class RestorationServices: - '''Face restoration and upscaling''' - - def __init__(self,args,logger:types.ModuleType): - try: - gfpgan, codeformer, esrgan = None, None, None - if args.restore or args.esrgan: - restoration = Restoration() - # TODO: redo for new model structure - if False and args.restore: - gfpgan, codeformer = restoration.load_face_restore_models( - args.gfpgan_model_path - ) - else: - logger.info("Face restoration disabled") - if False and args.esrgan: - esrgan = restoration.load_esrgan(args.esrgan_bg_tile) - else: - logger.info("Upscaling disabled") - else: - logger.info("Face restoration and upscaling disabled") - except (ModuleNotFoundError, ImportError): - print(traceback.format_exc(), file=sys.stderr) - logger.info("You may need to install the ESRGAN and/or GFPGAN modules") - self.device = torch.device(choose_torch_device()) - self.gfpgan = gfpgan - self.codeformer = codeformer - self.esrgan = esrgan - self.logger = logger - self.logger.info('Face restoration initialized') - - # note that this one method does gfpgan and codepath reconstruction, as well as - # esrgan upscaling - # TO DO: refactor into separate methods - def upscale_and_reconstruct( - self, - image_list, - facetool="gfpgan", - upscale=None, - upscale_denoise_str=0.75, - strength=0.0, - codeformer_fidelity=0.75, - save_original=False, - image_callback=None, - prefix=None, - ): - results = [] - for r in image_list: - image, seed = r - try: - if strength > 0: - if self.gfpgan is not None or self.codeformer is not None: - if facetool == "gfpgan": - if self.gfpgan is None: - self.logger.info( - "GFPGAN not found. Face restoration is disabled." - ) - else: - image = self.gfpgan.process(image, strength, seed) - if facetool == "codeformer": - if self.codeformer is None: - self.logger.info( - "CodeFormer not found. Face restoration is disabled." - ) - else: - cf_device = ( - CPU_DEVICE if self.device == MPS_DEVICE else self.device - ) - image = self.codeformer.process( - image=image, - strength=strength, - device=cf_device, - seed=seed, - fidelity=codeformer_fidelity, - ) - else: - self.logger.info("Face Restoration is disabled.") - if upscale is not None: - if self.esrgan is not None: - if len(upscale) < 2: - upscale.append(0.75) - image = self.esrgan.process( - image, - upscale[1], - seed, - int(upscale[0]), - denoise_str=upscale_denoise_str, - ) - else: - self.logger.info("ESRGAN is disabled. Image not upscaled.") - except Exception as e: - self.logger.info( - f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}" - ) - - if image_callback is not None: - image_callback(image, seed, upscaled=True, use_prefix=prefix) - else: - r[0] = image - - results.append([image, seed]) - - return results diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py index f34b18310b..bc4a3f4176 100644 --- a/tests/nodes/test_graph_execution_state.py +++ b/tests/nodes/test_graph_execution_state.py @@ -55,7 +55,6 @@ def mock_services() -> InvocationServices: ), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), processor = DefaultInvocationProcessor(), - restoration = None, # type: ignore configuration = None, # type: ignore ) diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index 19d7dd20b3..4741e7f58b 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -48,7 +48,6 @@ def mock_services() -> InvocationServices: ), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), processor = DefaultInvocationProcessor(), - restoration = None, # type: ignore configuration = None, # type: ignore ) diff --git a/tests/nodes/test_node_graph.py b/tests/nodes/test_node_graph.py index df7378150d..c86be19059 100644 --- a/tests/nodes/test_node_graph.py +++ b/tests/nodes/test_node_graph.py @@ -1,6 +1,6 @@ from .test_nodes import ImageToImageTestInvocation, TextToImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation -from invokeai.app.invocations.upscale import UpscaleInvocation +from invokeai.app.invocations.upscale import RealESRGANInvocation from invokeai.app.invocations.image import * from invokeai.app.invocations.math import AddInvocation, SubtractInvocation from invokeai.app.invocations.params import ParamIntInvocation @@ -19,7 +19,7 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg def test_connections_are_compatible(): from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") from_field = "image" - to_node = UpscaleInvocation(id = "2") + to_node = RealESRGANInvocation(id = "2") to_field = "image" result = are_connections_compatible(from_node, from_field, to_node, to_field) @@ -29,7 +29,7 @@ def test_connections_are_compatible(): def test_connections_are_incompatible(): from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") from_field = "image" - to_node = UpscaleInvocation(id = "2") + to_node = RealESRGANInvocation(id = "2") to_field = "strength" result = are_connections_compatible(from_node, from_field, to_node, to_field) @@ -39,7 +39,7 @@ def test_connections_are_incompatible(): def test_connections_incompatible_with_invalid_fields(): from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") from_field = "invalid_field" - to_node = UpscaleInvocation(id = "2") + to_node = RealESRGANInvocation(id = "2") to_field = "image" # From field is invalid @@ -86,10 +86,10 @@ def test_graph_fails_to_update_node_if_type_changes(): g = Graph() n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") g.add_node(n) - n2 = UpscaleInvocation(id = "2") + n2 = RealESRGANInvocation(id = "2") g.add_node(n2) - nu = UpscaleInvocation(id = "1") + nu = RealESRGANInvocation(id = "1") with pytest.raises(TypeError): g.update_node("1", nu) @@ -98,7 +98,7 @@ def test_graph_allows_non_conflicting_id_change(): g = Graph() n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") g.add_node(n) - n2 = UpscaleInvocation(id = "2") + n2 = RealESRGANInvocation(id = "2") g.add_node(n2) e1 = create_edge(n.id,"image",n2.id,"image") g.add_edge(e1) @@ -128,7 +128,7 @@ def test_graph_fails_to_update_node_id_if_conflict(): def test_graph_adds_edge(): g = Graph() n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") - n2 = UpscaleInvocation(id = "2") + n2 = RealESRGANInvocation(id = "2") g.add_node(n1) g.add_node(n2) e = create_edge(n1.id,"image",n2.id,"image") @@ -139,7 +139,7 @@ def test_graph_adds_edge(): def test_graph_fails_to_add_edge_with_cycle(): g = Graph() - n1 = UpscaleInvocation(id = "1") + n1 = RealESRGANInvocation(id = "1") g.add_node(n1) e = create_edge(n1.id,"image",n1.id,"image") with pytest.raises(InvalidEdgeError): @@ -148,8 +148,8 @@ def test_graph_fails_to_add_edge_with_cycle(): def test_graph_fails_to_add_edge_with_long_cycle(): g = Graph() n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") - n2 = UpscaleInvocation(id = "2") - n3 = UpscaleInvocation(id = "3") + n2 = RealESRGANInvocation(id = "2") + n3 = RealESRGANInvocation(id = "3") g.add_node(n1) g.add_node(n2) g.add_node(n3) @@ -164,7 +164,7 @@ def test_graph_fails_to_add_edge_with_long_cycle(): def test_graph_fails_to_add_edge_with_missing_node_id(): g = Graph() n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") - n2 = UpscaleInvocation(id = "2") + n2 = RealESRGANInvocation(id = "2") g.add_node(n1) g.add_node(n2) e1 = create_edge("1","image","3","image") @@ -177,8 +177,8 @@ def test_graph_fails_to_add_edge_with_missing_node_id(): def test_graph_fails_to_add_edge_when_destination_exists(): g = Graph() n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") - n2 = UpscaleInvocation(id = "2") - n3 = UpscaleInvocation(id = "3") + n2 = RealESRGANInvocation(id = "2") + n3 = RealESRGANInvocation(id = "3") g.add_node(n1) g.add_node(n2) g.add_node(n3) @@ -194,7 +194,7 @@ def test_graph_fails_to_add_edge_when_destination_exists(): def test_graph_fails_to_add_edge_with_mismatched_types(): g = Graph() n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") - n2 = UpscaleInvocation(id = "2") + n2 = RealESRGANInvocation(id = "2") g.add_node(n1) g.add_node(n2) e1 = create_edge("1","image","2","strength") @@ -344,7 +344,7 @@ def test_graph_iterator_invalid_if_output_and_input_types_different(): def test_graph_validates(): g = Graph() n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") - n2 = UpscaleInvocation(id = "2") + n2 = RealESRGANInvocation(id = "2") g.add_node(n1) g.add_node(n2) e1 = create_edge("1","image","2","image") @@ -377,8 +377,8 @@ def test_graph_invalid_if_subgraph_invalid(): def test_graph_invalid_if_has_cycle(): g = Graph() - n1 = UpscaleInvocation(id = "1") - n2 = UpscaleInvocation(id = "2") + n1 = RealESRGANInvocation(id = "1") + n2 = RealESRGANInvocation(id = "2") g.nodes[n1.id] = n1 g.nodes[n2.id] = n2 e1 = create_edge("1","image","2","image") @@ -391,7 +391,7 @@ def test_graph_invalid_if_has_cycle(): def test_graph_invalid_with_invalid_connection(): g = Graph() n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") - n2 = UpscaleInvocation(id = "2") + n2 = RealESRGANInvocation(id = "2") g.nodes[n1.id] = n1 g.nodes[n2.id] = n2 e1 = create_edge("1","image","2","strength") @@ -503,7 +503,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node(): g.add_node(n1) - n2 = UpscaleInvocation(id = "2") + n2 = RealESRGANInvocation(id = "2") g.add_node(n2) with pytest.raises(NodeNotFoundError): @@ -512,7 +512,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node(): def test_graph_gets_networkx_graph(): g = Graph() n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") - n2 = UpscaleInvocation(id = "2") + n2 = RealESRGANInvocation(id = "2") g.add_node(n1) g.add_node(n2) e = create_edge(n1.id,"image",n2.id,"image") @@ -529,7 +529,7 @@ def test_graph_gets_networkx_graph(): def test_graph_can_serialize(): g = Graph() n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") - n2 = UpscaleInvocation(id = "2") + n2 = RealESRGANInvocation(id = "2") g.add_node(n1) g.add_node(n2) e = create_edge(n1.id,"image",n2.id,"image") @@ -541,7 +541,7 @@ def test_graph_can_serialize(): def test_graph_can_deserialize(): g = Graph() n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") - n2 = UpscaleInvocation(id = "2") + n2 = RealESRGANInvocation(id = "2") g.add_node(n1) g.add_node(n2) e = create_edge(n1.id,"image",n2.id,"image")