diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index cb989cb15e..54c50bcf76 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -64,379 +64,338 @@ class InvocationContextData: """The workflow associated with this queue item, if any.""" -class BoardsInterface: - def __init__(self, services: InvocationServices) -> None: - def create(board_name: str) -> BoardDTO: - """ - Creates a board. - - :param board_name: The name of the board to create. - """ - return services.boards.create(board_name) - - def get_dto(board_id: str) -> BoardDTO: - """ - Gets a board DTO. - - :param board_id: The ID of the board to get. - """ - return services.boards.get_dto(board_id) - - def get_all() -> list[BoardDTO]: - """ - Gets all boards. - """ - return services.boards.get_all() - - def add_image_to_board(board_id: str, image_name: str) -> None: - """ - Adds an image to a board. - - :param board_id: The ID of the board to add the image to. - :param image_name: The name of the image to add to the board. - """ - services.board_images.add_image_to_board(board_id, image_name) - - def get_all_image_names_for_board(board_id: str) -> list[str]: - """ - Gets all image names for a board. - - :param board_id: The ID of the board to get the image names for. - """ - return services.board_images.get_all_board_image_names_for_board(board_id) - - self.create = create - self.get_dto = get_dto - self.get_all = get_all - self.add_image_to_board = add_image_to_board - self.get_all_image_names_for_board = get_all_image_names_for_board - - -class LoggerInterface: - def __init__(self, services: InvocationServices) -> None: - def debug(message: str) -> None: - """ - Logs a debug message. - - :param message: The message to log. - """ - services.logger.debug(message) - - def info(message: str) -> None: - """ - Logs an info message. - - :param message: The message to log. - """ - services.logger.info(message) - - def warning(message: str) -> None: - """ - Logs a warning message. - - :param message: The message to log. - """ - services.logger.warning(message) - - def error(message: str) -> None: - """ - Logs an error message. - - :param message: The message to log. - """ - services.logger.error(message) - - self.debug = debug - self.info = info - self.warning = warning - self.error = error - - -class ImagesInterface: +class InvocationContextInterface: def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: - def save( - image: Image, - board_id: Optional[str] = None, - image_category: ImageCategory = ImageCategory.GENERAL, - metadata: Optional[MetadataField] = None, - ) -> ImageDTO: - """ - Saves an image, returning its DTO. - - If the current queue item has a workflow or metadata, it is automatically saved with the image. - - :param image: The image to save, as a PIL image. - :param board_id: The board ID to add the image to, if it should be added. - :param image_category: The category of the image. Only the GENERAL category is added \ - to the gallery. - :param metadata: The metadata to save with the image, if it should have any. If the \ - invocation inherits from `WithMetadata`, that metadata will be used automatically. \ - **Use this only if you want to override or provide metadata manually!** - """ - - # If the invocation inherits metadata, use that. Else, use the metadata passed in. - metadata_ = ( - context_data.invocation.metadata if isinstance(context_data.invocation, WithMetadata) else metadata - ) - - return services.images.create( - image=image, - is_intermediate=context_data.invocation.is_intermediate, - image_category=image_category, - board_id=board_id, - metadata=metadata_, - image_origin=ResourceOrigin.INTERNAL, - workflow=context_data.workflow, - session_id=context_data.session_id, - node_id=context_data.invocation.id, - ) - - def get_pil(image_name: str) -> Image: - """ - Gets an image as a PIL Image object. - - :param image_name: The name of the image to get. - """ - return services.images.get_pil_image(image_name) - - def get_metadata(image_name: str) -> Optional[MetadataField]: - """ - Gets an image's metadata, if it has any. - - :param image_name: The name of the image to get the metadata for. - """ - return services.images.get_metadata(image_name) - - def get_dto(image_name: str) -> ImageDTO: - """ - Gets an image as an ImageDTO object. - - :param image_name: The name of the image to get. - """ - return services.images.get_dto(image_name) - - def update( - image_name: str, - board_id: Optional[str] = None, - is_intermediate: Optional[bool] = False, - ) -> ImageDTO: - """ - Updates an image, returning its updated DTO. - - It is not suggested to update images saved by earlier nodes, as this can cause confusion for users. - - If you use this method, you *must* return the image as an :class:`ImageOutput` for the gallery to - get the updated image. - - :param image_name: The name of the image to update. - :param board_id: The board ID to add the image to, if it should be added. - :param is_intermediate: Whether the image is an intermediate. Intermediate images aren't added to the gallery. - """ - if is_intermediate is not None: - services.images.update(image_name, ImageRecordChanges(is_intermediate=is_intermediate)) - if board_id is None: - services.board_images.remove_image_from_board(image_name) - else: - services.board_images.add_image_to_board(image_name, board_id) - return services.images.get_dto(image_name) - - self.save = save - self.get_pil = get_pil - self.get_metadata = get_metadata - self.get_dto = get_dto - self.update = update + self._services = services + self._context_data = context_data -class LatentsInterface: - def __init__( +class BoardsInterface(InvocationContextInterface): + def create(self, board_name: str) -> BoardDTO: + """ + Creates a board. + + :param board_name: The name of the board to create. + """ + return self._services.boards.create(board_name) + + def get_dto(self, board_id: str) -> BoardDTO: + """ + Gets a board DTO. + + :param board_id: The ID of the board to get. + """ + return self._services.boards.get_dto(board_id) + + def get_all(self) -> list[BoardDTO]: + """ + Gets all boards. + """ + return self._services.boards.get_all() + + def add_image_to_board(self, board_id: str, image_name: str) -> None: + """ + Adds an image to a board. + + :param board_id: The ID of the board to add the image to. + :param image_name: The name of the image to add to the board. + """ + return self._services.board_images.add_image_to_board(board_id, image_name) + + def get_all_image_names_for_board(self, board_id: str) -> list[str]: + """ + Gets all image names for a board. + + :param board_id: The ID of the board to get the image names for. + """ + return self._services.board_images.get_all_board_image_names_for_board(board_id) + + +class LoggerInterface(InvocationContextInterface): + def debug(self, message: str) -> None: + """ + Logs a debug message. + + :param message: The message to log. + """ + self._services.logger.debug(message) + + def info(self, message: str) -> None: + """ + Logs an info message. + + :param message: The message to log. + """ + self._services.logger.info(message) + + def warning(self, message: str) -> None: + """ + Logs a warning message. + + :param message: The message to log. + """ + self._services.logger.warning(message) + + def error(self, message: str) -> None: + """ + Logs an error message. + + :param message: The message to log. + """ + self._services.logger.error(message) + + +class ImagesInterface(InvocationContextInterface): + def save( self, - services: InvocationServices, - context_data: InvocationContextData, - ) -> None: - def save(tensor: Tensor) -> str: - """ - Saves a latents tensor, returning its name. + image: Image, + board_id: Optional[str] = None, + image_category: ImageCategory = ImageCategory.GENERAL, + metadata: Optional[MetadataField] = None, + ) -> ImageDTO: + """ + Saves an image, returning its DTO. - :param tensor: The latents tensor to save. - """ + If the current queue item has a workflow or metadata, it is automatically saved with the image. - # Previously, we added a suffix indicating the type of Tensor we were saving, e.g. - # "mask", "noise", "masked_latents", etc. - # - # Retaining that capability in this wrapper would require either many different methods - # to save latents, or extra args for this method. Instead of complicating the API, we - # will use the same naming scheme for all latents. - # - # This has a very minor impact as we don't use them after a session completes. + :param image: The image to save, as a PIL image. + :param board_id: The board ID to add the image to, if it should be added. + :param image_category: The category of the image. Only the GENERAL category is added \ + to the gallery. + :param metadata: The metadata to save with the image, if it should have any. If the \ + invocation inherits from `WithMetadata`, that metadata will be used automatically. \ + **Use this only if you want to override or provide metadata manually!** + """ - # Previously, invocations chose the name for their latents. This is a bit risky, so we - # will generate a name for them instead. We use a uuid to ensure the name is unique. - # - # Because the name of the latents file will includes the session and invocation IDs, - # we don't need to worry about collisions. A truncated UUIDv4 is fine. + # If the invocation inherits metadata, use that. Else, use the metadata passed in. + metadata_ = ( + self._context_data.invocation.metadata + if isinstance(self._context_data.invocation, WithMetadata) + else metadata + ) - name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}" - services.latents.save( - name=name, - data=tensor, - ) - return name + return self._services.images.create( + image=image, + is_intermediate=self._context_data.invocation.is_intermediate, + image_category=image_category, + board_id=board_id, + metadata=metadata_, + image_origin=ResourceOrigin.INTERNAL, + workflow=self._context_data.workflow, + session_id=self._context_data.session_id, + node_id=self._context_data.invocation.id, + ) - def get(latents_name: str) -> Tensor: - """ - Gets a latents tensor by name. + def get_pil(self, image_name: str) -> Image: + """ + Gets an image as a PIL Image object. - :param latents_name: The name of the latents tensor to get. - """ - return services.latents.get(latents_name) + :param image_name: The name of the image to get. + """ + return self._services.images.get_pil_image(image_name) - self.save = save - self.get = get + def get_metadata(self, image_name: str) -> Optional[MetadataField]: + """ + Gets an image's metadata, if it has any. + :param image_name: The name of the image to get the metadata for. + """ + return self._services.images.get_metadata(image_name) -class ConditioningInterface: - def __init__( + def get_dto(self, image_name: str) -> ImageDTO: + """ + Gets an image as an ImageDTO object. + + :param image_name: The name of the image to get. + """ + return self._services.images.get_dto(image_name) + + def update( self, - services: InvocationServices, - context_data: InvocationContextData, - ) -> None: - # TODO(psyche): We are (ab)using the latents storage service as a general pickle storage - # service, but it is typed to work with Tensors only. We have to fudge the types here. + image_name: str, + board_id: Optional[str] = None, + is_intermediate: Optional[bool] = False, + ) -> ImageDTO: + """ + Updates an image, returning its updated DTO. - def save(conditioning_data: ConditioningFieldData) -> str: - """ - Saves a conditioning data object, returning its name. + It is not suggested to update images saved by earlier nodes, as this can cause confusion for users. - :param conditioning_data: The conditioning data to save. - """ + If you use this method, you *must* return the image as an :class:`ImageOutput` for the gallery to + get the updated image. - # Conditioning data is *not* a Tensor, so we will suffix it to indicate this. - # - # See comment for `LatentsInterface.save` for more info about this method (it's very - # similar). - - name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}__conditioning" - services.latents.save( - name=name, - data=conditioning_data, # type: ignore [arg-type] - ) - return name - - def get(conditioning_name: str) -> ConditioningFieldData: - """ - Gets conditioning data by name. - - :param conditioning_name: The name of the conditioning data to get. - """ - - return services.latents.get(conditioning_name) # type: ignore [return-value] - - self.save = save - self.get = get + :param image_name: The name of the image to update. + :param board_id: The board ID to add the image to, if it should be added. + :param is_intermediate: Whether the image is an intermediate. Intermediate images aren't added to the gallery. + """ + if is_intermediate is not None: + self._services.images.update(image_name, ImageRecordChanges(is_intermediate=is_intermediate)) + if board_id is None: + self._services.board_images.remove_image_from_board(image_name) + else: + self._services.board_images.add_image_to_board(image_name, board_id) + return self._services.images.get_dto(image_name) -class ModelsInterface: - def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: - def exists(model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool: - """ - Checks if a model exists. +class LatentsInterface(InvocationContextInterface): + def save(self, tensor: Tensor) -> str: + """ + Saves a latents tensor, returning its name. - :param model_name: The name of the model to check. - :param base_model: The base model of the model to check. - :param model_type: The type of the model to check. - """ - return services.model_manager.model_exists(model_name, base_model, model_type) + :param tensor: The latents tensor to save. + """ - def load( - model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None - ) -> ModelInfo: - """ - Loads a model, returning its `ModelInfo` object. + # Previously, we added a suffix indicating the type of Tensor we were saving, e.g. + # "mask", "noise", "masked_latents", etc. + # + # Retaining that capability in this wrapper would require either many different methods + # to save latents, or extra args for this method. Instead of complicating the API, we + # will use the same naming scheme for all latents. + # + # This has a very minor impact as we don't use them after a session completes. - :param model_name: The name of the model to get. - :param base_model: The base model of the model to get. - :param model_type: The type of the model to get. - :param submodel: The submodel of the model to get. - """ + # Previously, invocations chose the name for their latents. This is a bit risky, so we + # will generate a name for them instead. We use a uuid to ensure the name is unique. + # + # Because the name of the latents file will includes the session and invocation IDs, + # we don't need to worry about collisions. A truncated UUIDv4 is fine. - # During this call, the model manager emits events with model loading status. The model - # manager itself has access to the events services, but does not have access to the - # required metadata for the events. - # - # For example, it needs access to the node's ID so that the events can be associated - # with the execution of a specific node. - # - # While this is available within the node, it's tedious to need to pass it in on every - # call. We can avoid that by wrapping the method here. + name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}" + self._services.latents.save( + name=name, + data=tensor, + ) + return name - return services.model_manager.get_model( - model_name, base_model, model_type, submodel, context_data=context_data - ) + def get(self, latents_name: str) -> Tensor: + """ + Gets a latents tensor by name. - def get_info(model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: - """ - Gets a model's info, an dict-like object. - - :param model_name: The name of the model to get. - :param base_model: The base model of the model to get. - :param model_type: The type of the model to get. - """ - return services.model_manager.model_info(model_name, base_model, model_type) - - self.exists = exists - self.load = load - self.get_info = get_info + :param latents_name: The name of the latents tensor to get. + """ + return self._services.latents.get(latents_name) -class ConfigInterface: - def __init__(self, services: InvocationServices) -> None: - def get() -> InvokeAIAppConfig: - """ - Gets the app's config. The config is read-only; attempts to mutate it will raise an error. - """ +class ConditioningInterface(InvocationContextInterface): + # TODO(psyche): We are (ab)using the latents storage service as a general pickle storage + # service, but it is typed to work with Tensors only. We have to fudge the types here. + def save(self, conditioning_data: ConditioningFieldData) -> str: + """ + Saves a conditioning data object, returning its name. - # The config can be changed at runtime. - # - # We don't want nodes doing this, so we make a frozen copy. + :param conditioning_context_data: The conditioning data to save. + """ - config = services.configuration.get_config() - # TODO(psyche): If config cannot be changed at runtime, should we cache this? - frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)}) - return frozen_config + # Conditioning data is *not* a Tensor, so we will suffix it to indicate this. + # + # See comment for `LatentsInterface.save` for more info about this method (it's very + # similar). - self.get = get + name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}__conditioning" + self._services.latents.save( + name=name, + data=conditioning_data, # type: ignore [arg-type] + ) + return name + + def get(self, conditioning_name: str) -> ConditioningFieldData: + """ + Gets conditioning data by name. + + :param conditioning_name: The name of the conditioning data to get. + """ + + return self._services.latents.get(conditioning_name) # type: ignore [return-value] -class UtilInterface: - def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: - def sd_step_callback( - intermediate_state: PipelineIntermediateState, - base_model: BaseModelType, - ) -> None: - """ - The step callback emits a progress event with the current step, the total number of - steps, a preview image, and some other internal metadata. +class ModelsInterface(InvocationContextInterface): + def exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool: + """ + Checks if a model exists. - This should be called after each denoising step. + :param model_name: The name of the model to check. + :param base_model: The base model of the model to check. + :param model_type: The type of the model to check. + """ + return self._services.model_manager.model_exists(model_name, base_model, model_type) - :param intermediate_state: The intermediate state of the diffusion pipeline. - :param base_model: The base model for the current denoising step. - """ + def load( + self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None + ) -> ModelInfo: + """ + Loads a model, returning its `ModelInfo` object. - # The step callback needs access to the events and the invocation queue services, but this - # represents a dangerous level of access. - # - # We wrap the step callback so that nodes do not have direct access to these services. + :param model_name: The name of the model to get. + :param base_model: The base model of the model to get. + :param model_type: The type of the model to get. + :param submodel: The submodel of the model to get. + """ - stable_diffusion_step_callback( - context_data=context_data, - intermediate_state=intermediate_state, - base_model=base_model, - invocation_queue=services.queue, - events=services.events, - ) + # During this call, the model manager emits events with model loading status. The model + # manager itself has access to the events services, but does not have access to the + # required metadata for the events. + # + # For example, it needs access to the node's ID so that the events can be associated + # with the execution of a specific node. + # + # While this is available within the node, it's tedious to need to pass it in on every + # call. We can avoid that by wrapping the method here. - self.sd_step_callback = sd_step_callback + return self._services.model_manager.get_model( + model_name, base_model, model_type, submodel, context_data=self._context_data + ) + + def get_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: + """ + Gets a model's info, an dict-like object. + + :param model_name: The name of the model to get. + :param base_model: The base model of the model to get. + :param model_type: The type of the model to get. + """ + return self._services.model_manager.model_info(model_name, base_model, model_type) + + +class ConfigInterface(InvocationContextInterface): + def get(self) -> InvokeAIAppConfig: + """ + Gets the app's config. The config is read-only; attempts to mutate it will raise an error. + """ + + # The config can be changed at runtime. + # + # We don't want nodes doing this, so we make a frozen copy. + + config = self._services.configuration.get_config() + # TODO(psyche): If config cannot be changed at runtime, should we cache this? + frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)}) + return frozen_config + + +class UtilInterface(InvocationContextInterface): + def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None: + """ + The step callback emits a progress event with the current step, the total number of + steps, a preview image, and some other internal metadata. + + This should be called after each denoising step. + + :param intermediate_state: The intermediate state of the diffusion pipeline. + :param base_model: The base model for the current denoising step. + """ + + # The step callback needs access to the events and the invocation queue services, but this + # represents a dangerous level of access. + # + # We wrap the step callback so that nodes do not have direct access to these services. + + stable_diffusion_step_callback( + context_data=self._context_data, + intermediate_state=intermediate_state, + base_model=base_model, + invocation_queue=self._services.queue, + events=self._services.events, + ) deprecation_version = "3.7.0" @@ -600,14 +559,14 @@ def build_invocation_context( :param invocation_context_data: The invocation context data. """ - logger = LoggerInterface(services=services) + logger = LoggerInterface(services=services, context_data=context_data) images = ImagesInterface(services=services, context_data=context_data) latents = LatentsInterface(services=services, context_data=context_data) models = ModelsInterface(services=services, context_data=context_data) - config = ConfigInterface(services=services) + config = ConfigInterface(services=services, context_data=context_data) util = UtilInterface(services=services, context_data=context_data) conditioning = ConditioningInterface(services=services, context_data=context_data) - boards = BoardsInterface(services=services) + boards = BoardsInterface(services=services, context_data=context_data) ctx = InvocationContext( images=images,