Compare commits

...

5 Commits

Author SHA1 Message Date
psychedelicious
c2e9bdc6c5 feat(ui): handle new progress event
Minor changes to use the new progress event. Only additional feature is if the progress has a message, it is displayed as a tooltip on the progress bar.
2024-08-04 19:02:45 +10:00
psychedelicious
4e0e9041e2 chore(ui): typegen 2024-08-04 18:51:42 +10:00
psychedelicious
5f94340e4f feat(app): merge progress events into one
- Merged `InvocationGenericProgressEvent` and `InvocationDenoiseProgressEvent` into single `InvocationProgressEvent`
- Simplified API - message is required, percentage and image are optional, no steps/total steps
- Added helper to build a `ProgressImage`
- Added field validation to `ProgressImage` width and height
- Added `ProgressImage` to `invocation_api.py`
- Updated `InvocationContext` utils
2024-08-04 18:47:45 +10:00
psychedelicious
682280683a feat(app): signal progress while processing spandrel tiles 2024-08-03 22:02:11 +10:00
psychedelicious
487815b181 feat(app): generic progress events
Some processes have steps, like denoising or a tiled spandel.

Denoising has its own step callback but we don't have any generic way to signal progress. Processes like a tiled spandrel run show indeterminate progress in the client.

This change introduces a new event to handle this: `InvocationGenericProgressEvent`

A simplified helper is added to the invocation API so nodes can easily emit progress as they do their thing.
2024-08-03 22:01:36 +10:00
22 changed files with 696 additions and 323 deletions

View File

@@ -20,8 +20,8 @@ from invokeai.app.services.events.events_common import (
DownloadStartedEvent, DownloadStartedEvent,
FastAPIEvent, FastAPIEvent,
InvocationCompleteEvent, InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent, InvocationErrorEvent,
InvocationProgressEvent,
InvocationStartedEvent, InvocationStartedEvent,
ModelEventBase, ModelEventBase,
ModelInstallCancelledEvent, ModelInstallCancelledEvent,
@@ -55,7 +55,7 @@ class BulkDownloadSubscriptionEvent(BaseModel):
QUEUE_EVENTS = { QUEUE_EVENTS = {
InvocationStartedEvent, InvocationStartedEvent,
InvocationDenoiseProgressEvent, InvocationProgressEvent,
InvocationCompleteEvent, InvocationCompleteEvent,
InvocationErrorEvent, InvocationErrorEvent,
QueueItemStatusChangedEvent, QueueItemStatusChangedEvent,

View File

@@ -1,3 +1,4 @@
import functools
from typing import Callable from typing import Callable
import numpy as np import numpy as np
@@ -61,6 +62,7 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
tile_size: int, tile_size: int,
spandrel_model: SpandrelImageToImageModel, spandrel_model: SpandrelImageToImageModel,
is_canceled: Callable[[], bool], is_canceled: Callable[[], bool],
step_callback: Callable[[int, int], None],
) -> Image.Image: ) -> Image.Image:
# Compute the image tiles. # Compute the image tiles.
if tile_size > 0: if tile_size > 0:
@@ -103,7 +105,12 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype) image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
# Run the model on each tile. # Run the model on each tile.
for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"): pbar = tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles")
# Update progress, starting with 0.
step_callback(0, pbar.total)
for tile, scaled_tile in pbar:
# Exit early if the invocation has been canceled. # Exit early if the invocation has been canceled.
if is_canceled(): if is_canceled():
raise CanceledException raise CanceledException
@@ -136,6 +143,8 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
:, :,
] = output_tile[top_overlap:, left_overlap:, :] ] = output_tile[top_overlap:, left_overlap:, :]
step_callback(pbar.n + 1, pbar.total)
# Convert the output tensor to a PIL image. # Convert the output tensor to a PIL image.
np_image = output_tensor.detach().numpy().astype(np.uint8) np_image = output_tensor.detach().numpy().astype(np.uint8)
pil_image = Image.fromarray(np_image) pil_image = Image.fromarray(np_image)
@@ -151,12 +160,20 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
# Load the model. # Load the model.
spandrel_model_info = context.models.load(self.image_to_image_model) spandrel_model_info = context.models.load(self.image_to_image_model)
def step_callback(step: int, total_steps: int) -> None:
context.util.signal_progress(
message=f"Processing image (tile {step}/{total_steps})",
percentage=step / total_steps,
)
# Do the upscaling. # Do the upscaling.
with spandrel_model_info as spandrel_model: with spandrel_model_info as spandrel_model:
assert isinstance(spandrel_model, SpandrelImageToImageModel) assert isinstance(spandrel_model, SpandrelImageToImageModel)
# Upscale the image # Upscale the image
pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled) pil_image = self.upscale_image(
image, self.tile_size, spandrel_model, context.util.is_canceled, step_callback
)
image_dto = context.images.save(image=pil_image) image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto) return ImageOutput.build(image_dto)
@@ -197,12 +214,27 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
target_width = int(image.width * self.scale) target_width = int(image.width * self.scale)
target_height = int(image.height * self.scale) target_height = int(image.height * self.scale)
def step_callback(iteration: int, step: int, total_steps: int) -> None:
context.util.signal_progress(
message=self._get_progress_message(iteration, step, total_steps),
percentage=step / total_steps,
)
# Do the upscaling. # Do the upscaling.
with spandrel_model_info as spandrel_model: with spandrel_model_info as spandrel_model:
assert isinstance(spandrel_model, SpandrelImageToImageModel) assert isinstance(spandrel_model, SpandrelImageToImageModel)
iteration = 1
context.util.signal_progress(self._get_progress_message(iteration))
# First pass of upscaling. Note: `pil_image` will be mutated. # First pass of upscaling. Note: `pil_image` will be mutated.
pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled) pil_image = self.upscale_image(
image,
self.tile_size,
spandrel_model,
context.util.is_canceled,
functools.partial(step_callback, iteration),
)
# Some models don't upscale the image, but we have no way to know this in advance. We'll check if the model # Some models don't upscale the image, but we have no way to know this in advance. We'll check if the model
# upscaled the image and run the loop below if it did. We'll require the model to upscale both dimensions # upscaled the image and run the loop below if it did. We'll require the model to upscale both dimensions
@@ -211,16 +243,22 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
if is_upscale_model: if is_upscale_model:
# This is an upscale model, so we should keep upscaling until we reach the target size. # This is an upscale model, so we should keep upscaling until we reach the target size.
iterations = 1
while pil_image.width < target_width or pil_image.height < target_height: while pil_image.width < target_width or pil_image.height < target_height:
pil_image = self.upscale_image(pil_image, self.tile_size, spandrel_model, context.util.is_canceled) iteration += 1
iterations += 1 context.util.signal_progress(self._get_progress_message(iteration))
pil_image = self.upscale_image(
pil_image,
self.tile_size,
spandrel_model,
context.util.is_canceled,
functools.partial(step_callback, iteration),
)
# Sanity check to prevent excessive or infinite loops. All known upscaling models are at least 2x. # Sanity check to prevent excessive or infinite loops. All known upscaling models are at least 2x.
# Our max scale is 16x, so with a 2x model, we should never exceed 16x == 2^4 -> 4 iterations. # Our max scale is 16x, so with a 2x model, we should never exceed 16x == 2^4 -> 4 iterations.
# We'll allow one extra iteration "just in case" and bail at 5 upscaling iterations. In practice, # We'll allow one extra iteration "just in case" and bail at 5 upscaling iterations. In practice,
# we should never reach this limit. # we should never reach this limit.
if iterations >= 5: if iteration >= 5:
context.logger.warning( context.logger.warning(
"Upscale loop reached maximum iteration count of 5, stopping upscaling early." "Upscale loop reached maximum iteration count of 5, stopping upscaling early."
) )
@@ -251,3 +289,10 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
image_dto = context.images.save(image=pil_image) image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto) return ImageOutput.build(image_dto)
@classmethod
def _get_progress_message(cls, iteration: int, step: int | None = None, total_steps: int | None = None) -> str:
if step is not None and total_steps is not None:
return f"Processing image (iteration {iteration}, tile {step}/{total_steps})"
return f"Processing image (iteration {iteration})"

View File

@@ -15,8 +15,8 @@ from invokeai.app.services.events.events_common import (
DownloadStartedEvent, DownloadStartedEvent,
EventBase, EventBase,
InvocationCompleteEvent, InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent, InvocationErrorEvent,
InvocationProgressEvent,
InvocationStartedEvent, InvocationStartedEvent,
ModelInstallCancelledEvent, ModelInstallCancelledEvent,
ModelInstallCompleteEvent, ModelInstallCompleteEvent,
@@ -30,13 +30,12 @@ from invokeai.app.services.events.events_common import (
QueueClearedEvent, QueueClearedEvent,
QueueItemStatusChangedEvent, QueueItemStatusChangedEvent,
) )
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.app.services.session_processor.session_processor_common import ProgressImage
if TYPE_CHECKING: if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.download.download_base import DownloadJob from invokeai.app.services.download.download_base import DownloadJob
from invokeai.app.services.model_install.model_install_common import ModelInstallJob from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import ( from invokeai.app.services.session_queue.session_queue_common import (
BatchStatus, BatchStatus,
EnqueueBatchResult, EnqueueBatchResult,
@@ -58,15 +57,16 @@ class EventServiceBase:
"""Emitted when an invocation is started""" """Emitted when an invocation is started"""
self.dispatch(InvocationStartedEvent.build(queue_item, invocation)) self.dispatch(InvocationStartedEvent.build(queue_item, invocation))
def emit_invocation_denoise_progress( def emit_invocation_progress(
self, self,
queue_item: "SessionQueueItem", queue_item: "SessionQueueItem",
invocation: "BaseInvocation", invocation: "BaseInvocation",
intermediate_state: PipelineIntermediateState, message: str,
progress_image: "ProgressImage", percentage: float | None = None,
image: ProgressImage | None = None,
) -> None: ) -> None:
"""Emitted at each step during denoising of an invocation.""" """Emitted at each step during an invocation"""
self.dispatch(InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image)) self.dispatch(InvocationProgressEvent.build(queue_item, invocation, message, percentage, image))
def emit_invocation_complete( def emit_invocation_complete(
self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", output: "BaseInvocationOutput" self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", output: "BaseInvocationOutput"

View File

@@ -1,4 +1,3 @@
from math import floor
from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar
from fastapi_events.handlers.local import local_handler from fastapi_events.handlers.local import local_handler
@@ -16,7 +15,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput
from invokeai.app.util.misc import get_timestamp from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
if TYPE_CHECKING: if TYPE_CHECKING:
from invokeai.app.services.download.download_base import DownloadJob from invokeai.app.services.download.download_base import DownloadJob
@@ -121,28 +119,28 @@ class InvocationStartedEvent(InvocationEventBase):
@payload_schema.register @payload_schema.register
class InvocationDenoiseProgressEvent(InvocationEventBase): class InvocationProgressEvent(InvocationEventBase):
"""Event model for invocation_denoise_progress""" """Event model for invocation_progress"""
__event_name__ = "invocation_denoise_progress" __event_name__ = "invocation_progress"
progress_image: ProgressImage = Field(description="The progress image sent at each step during processing") message: str = Field(description="A message to display")
step: int = Field(description="The current step of the invocation") percentage: float | None = Field(
total_steps: int = Field(description="The total number of steps in the invocation") default=None, ge=0, le=1, description="The percentage of the progress (omit to indicate indeterminate progress)"
order: int = Field(description="The order of the invocation in the session") )
percentage: float = Field(description="The percentage of completion of the invocation") image: ProgressImage | None = Field(
default=None, description="An image representing the current state of the progress"
)
@classmethod @classmethod
def build( def build(
cls, cls,
queue_item: SessionQueueItem, queue_item: SessionQueueItem,
invocation: AnyInvocation, invocation: AnyInvocation,
intermediate_state: PipelineIntermediateState, message: str,
progress_image: ProgressImage, percentage: float | None = None,
) -> "InvocationDenoiseProgressEvent": image: ProgressImage | None = None,
step = intermediate_state.step ) -> "InvocationProgressEvent":
total_steps = intermediate_state.total_steps
order = intermediate_state.order
return cls( return cls(
queue_id=queue_item.queue_id, queue_id=queue_item.queue_id,
item_id=queue_item.item_id, item_id=queue_item.item_id,
@@ -150,23 +148,11 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
session_id=queue_item.session_id, session_id=queue_item.session_id,
invocation=invocation, invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
progress_image=progress_image, percentage=percentage,
step=step, image=image,
total_steps=total_steps, message=message,
order=order,
percentage=cls.calc_percentage(step, total_steps, order),
) )
@staticmethod
def calc_percentage(step: int, total_steps: int, scheduler_order: float) -> float:
"""Calculate the percentage of completion of denoising."""
if total_steps == 0:
return 0.0
if scheduler_order == 2:
return floor((step + 1 + 1) / 2) / floor((total_steps + 1) / 2)
# order == 1
return (step + 1 + 1) / (total_steps + 1)
@payload_schema.register @payload_schema.register
class InvocationCompleteEvent(InvocationEventBase): class InvocationCompleteEvent(InvocationEventBase):

View File

@@ -1,5 +1,8 @@
from PIL.Image import Image as PILImageType
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.backend.util.util import image_to_dataURL
class SessionProcessorStatus(BaseModel): class SessionProcessorStatus(BaseModel):
is_started: bool = Field(description="Whether the session processor is started") is_started: bool = Field(description="Whether the session processor is started")
@@ -15,6 +18,16 @@ class CanceledException(Exception):
class ProgressImage(BaseModel): class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing""" """The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels") width: int = Field(ge=1, description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels") height: int = Field(ge=1, description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL") dataURL: str = Field(description="The image data as a b64 data URL")
@classmethod
def build(cls, image: PILImageType, size: tuple[int, int] | None = None) -> "ProgressImage":
"""Build a ProgressImage from a PIL image"""
return cls(
width=size[0] if size else image.width,
height=size[1] if size else image.height,
dataURL=image_to_dataURL(image, image_format="JPEG"),
)

View File

@@ -14,6 +14,7 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_records.model_records_base import UnknownModelException from invokeai.app.services.model_records.model_records_base import UnknownModelException
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
AnyModel, AnyModel,
@@ -550,13 +551,64 @@ class UtilInterface(InvocationContextInterface):
""" """
stable_diffusion_step_callback( stable_diffusion_step_callback(
context_data=self._data, signal_progress=self.signal_progress,
intermediate_state=intermediate_state, intermediate_state=intermediate_state,
base_model=base_model, base_model=base_model,
events=self._services.events,
is_canceled=self.is_canceled, is_canceled=self.is_canceled,
) )
def signal_progress(
self, message: str, percentage: float | None = None, image: ProgressImage | None = None
) -> None:
"""Signals the progress of some long-running invocation. The progress is displayed in the UI.
If you have an image to display, use `ProgressImage.build` to create the object.
If your progress image should be displayed at a different size, provide a tuple of `(width, height)` when
building the progress image.
For example, SD denoising progress images are 1/8 the size of the original image. In this case, the progress
image should be built like this to ensure it displays at the correct size:
```py
progress_image = ProgressImage.build(image, (width * 8, height * 8))
```
If your progress image is very large, consider downscaling it to reduce the payload size.
Example:
```py
total_steps = 10
for i in range(total_steps):
# Do some iterative progressing
image = do_iterative_processing(image)
# Calculate the percentage
step = i + 1
percentage = step / total_steps
# Create a short, friendly message
message = f"Processing (step {step}/{total_steps})"
# Build the progress image
progress_image = ProgressImage.build(image)
# Send progress to the UI
context.util.signal_progress(message, percentage, progress_image)
```
Args:
message: A message describing the current status.
percentage: The current percentage completion for the process. Omit for indeterminate progress.
image: An optional progress image to display.
"""
self._services.events.emit_invocation_progress(
queue_item=self._data.queue_item,
invocation=self._data.invocation,
message=message,
percentage=percentage,
image=image,
)
class InvocationContext: class InvocationContext:
"""Provides access to various services and data for the current invocation. """Provides access to various services and data for the current invocation.

View File

@@ -1,4 +1,5 @@
from typing import TYPE_CHECKING, Callable, Optional from math import floor
from typing import Callable, Optional
import torch import torch
from PIL import Image from PIL import Image
@@ -6,11 +7,6 @@ from PIL import Image
from invokeai.app.services.session_processor.session_processor_common import CanceledException, ProgressImage from invokeai.app.services.session_processor.session_processor_common import CanceledException, ProgressImage
from invokeai.backend.model_manager.config import BaseModelType from invokeai.backend.model_manager.config import BaseModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.util.util import image_to_dataURL
if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.shared.invocation_context import InvocationContextData
# fast latents preview matrix for sdxl # fast latents preview matrix for sdxl
# generated by @StAlKeR7779 # generated by @StAlKeR7779
@@ -56,11 +52,25 @@ def sample_to_lowres_estimated_image(
return Image.fromarray(latents_ubyte.numpy()) return Image.fromarray(latents_ubyte.numpy())
def calc_percentage(intermediate_state: PipelineIntermediateState) -> float:
"""Calculate the percentage of completion of denoising."""
step = intermediate_state.step
total_steps = intermediate_state.total_steps
order = intermediate_state.order
if total_steps == 0:
return 0.0
if order == 2:
return floor((step + 1 + 1) / 2) / floor((total_steps + 1) / 2)
# order == 1
return (step + 1 + 1) / (total_steps + 1)
def stable_diffusion_step_callback( def stable_diffusion_step_callback(
context_data: "InvocationContextData", signal_progress: Callable[[str, float | None, ProgressImage | None], None],
intermediate_state: PipelineIntermediateState, intermediate_state: PipelineIntermediateState,
base_model: BaseModelType, base_model: BaseModelType,
events: "EventServiceBase",
is_canceled: Callable[[], bool], is_canceled: Callable[[], bool],
) -> None: ) -> None:
if is_canceled(): if is_canceled():
@@ -86,11 +96,10 @@ def stable_diffusion_step_callback(
width *= 8 width *= 8
height *= 8 height *= 8
dataURL = image_to_dataURL(image, image_format="JPEG") percentage = calc_percentage(intermediate_state)
events.emit_invocation_denoise_progress( signal_progress(
context_data.queue_item, "Denoising",
context_data.invocation, percentage,
intermediate_state, ProgressImage.build(image=image, size=(width, height)),
ProgressImage(dataURL=dataURL, width=width, height=height),
) )

View File

@@ -3,7 +3,7 @@ import { deepClone } from 'common/util/deepClone';
import { isAnyGraphBuilt } from 'features/nodes/store/actions'; import { isAnyGraphBuilt } from 'features/nodes/store/actions';
import { appInfoApi } from 'services/api/endpoints/appInfo'; import { appInfoApi } from 'services/api/endpoints/appInfo';
import type { Graph } from 'services/api/types'; import type { Graph } from 'services/api/types';
import { socketGeneratorProgress } from 'services/events/actions'; import { socketInvocationProgress } from 'services/events/actions';
export const actionSanitizer = <A extends UnknownAction>(action: A): A => { export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
if (isAnyGraphBuilt(action)) { if (isAnyGraphBuilt(action)) {
@@ -24,10 +24,10 @@ export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
}; };
} }
if (socketGeneratorProgress.match(action)) { if (socketInvocationProgress.match(action)) {
const sanitized = deepClone(action); const sanitized = deepClone(action);
if (sanitized.payload.data.progress_image) { if (sanitized.payload.data.image) {
sanitized.payload.data.progress_image.dataURL = '<Progress image omitted>'; sanitized.payload.data.image.dataURL = '<Progress image omitted>';
} }
return sanitized; return sanitized;
} }

View File

@@ -39,9 +39,9 @@ import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddlewa
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings'; import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected'; import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected';
import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected'; import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected';
import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress';
import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete'; import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete';
import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError'; import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError';
import { addInvocationProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationProgress';
import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted'; import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted';
import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall'; import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall';
import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad'; import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad';
@@ -102,7 +102,7 @@ addStagingAreaImageSavedListener(startAppListening);
addCommitStagingAreaImageListener(startAppListening); addCommitStagingAreaImageListener(startAppListening);
// Socket.IO // Socket.IO
addGeneratorProgressEventListener(startAppListening); addInvocationProgressEventListener(startAppListening);
addInvocationCompleteEventListener(startAppListening); addInvocationCompleteEventListener(startAppListening);
addInvocationErrorEventListener(startAppListening); addInvocationErrorEventListener(startAppListening);
addInvocationStartedEventListener(startAppListening); addInvocationStartedEventListener(startAppListening);

View File

@@ -4,21 +4,21 @@ import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize'; import { parseify } from 'common/util/serialize';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation'; import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketGeneratorProgress } from 'services/events/actions'; import { socketInvocationProgress } from 'services/events/actions';
const log = logger('socketio'); const log = logger('socketio');
export const addGeneratorProgressEventListener = (startAppListening: AppStartListening) => { export const addInvocationProgressEventListener = (startAppListening: AppStartListening) => {
startAppListening({ startAppListening({
actionCreator: socketGeneratorProgress, actionCreator: socketInvocationProgress,
effect: (action) => { effect: (action) => {
log.trace(parseify(action.payload), `Generator progress`); log.trace(parseify(action.payload), `Generator progress`);
const { invocation_source_id, step, total_steps, progress_image } = action.payload.data; const { invocation_source_id, percentage, image } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]); const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) { if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS; nes.status = zNodeStatus.enum.IN_PROGRESS;
nes.progress = (step + 1) / total_steps; nes.progress = percentage;
nes.progressImage = progress_image ?? null; nes.progressImage = image ?? null;
upsertExecutionState(nes.nodeId, nes); upsertExecutionState(nes.nodeId, nes);
} }
}, },

View File

@@ -10,8 +10,7 @@ const progressImageSelector = createMemoizedSelector([selectSystemSlice, selectC
const { batchIds } = canvas; const { batchIds } = canvas;
return { return {
progressImage: progressImage: denoiseProgress && batchIds.includes(denoiseProgress.batch_id) ? denoiseProgress.image : undefined,
denoiseProgress && batchIds.includes(denoiseProgress.batch_id) ? denoiseProgress.progress_image : undefined,
boundingBox: canvas.layerState.stagingArea.boundingBox, boundingBox: canvas.layerState.stagingArea.boundingBox,
}; };
}); });

View File

@@ -40,7 +40,7 @@ const selectShouldDisableToolbarButtons = createSelector(
selectGallerySlice, selectGallerySlice,
selectLastSelectedImage, selectLastSelectedImage,
(system, gallery, lastSelectedImage) => { (system, gallery, lastSelectedImage) => {
const hasProgressImage = Boolean(system.denoiseProgress?.progress_image); const hasProgressImage = Boolean(system.denoiseProgress?.image);
return hasProgressImage || !lastSelectedImage; return hasProgressImage || !lastSelectedImage;
} }
); );

View File

@@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks';
import { memo, useMemo } from 'react'; import { memo, useMemo } from 'react';
const CurrentImagePreview = () => { const CurrentImagePreview = () => {
const progress_image = useAppSelector((s) => s.system.denoiseProgress?.progress_image); const image = useAppSelector((s) => s.system.denoiseProgress?.image);
const shouldAntialiasProgressImage = useAppSelector((s) => s.system.shouldAntialiasProgressImage); const shouldAntialiasProgressImage = useAppSelector((s) => s.system.shouldAntialiasProgressImage);
const sx = useMemo<SystemStyleObject>( const sx = useMemo<SystemStyleObject>(
@@ -14,15 +14,15 @@ const CurrentImagePreview = () => {
[shouldAntialiasProgressImage] [shouldAntialiasProgressImage]
); );
if (!progress_image) { if (!image) {
return null; return null;
} }
return ( return (
<Image <Image
src={progress_image.dataURL} src={image.dataURL}
width={progress_image.width} width={image.width}
height={progress_image.height} height={image.height}
draggable={false} draggable={false}
data-testid="progress-image" data-testid="progress-image"
objectFit="contain" objectFit="contain"

View File

@@ -20,7 +20,7 @@ const selector = createMemoizedSelector(selectSystemSlice, selectGallerySlice, (
return { return {
imageDTO, imageDTO,
progressImage: system.denoiseProgress?.progress_image, progressImage: system.denoiseProgress?.image,
}; };
}); });

View File

@@ -1,4 +1,4 @@
import { Progress } from '@invoke-ai/ui-library'; import { Progress, Tooltip } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { selectSystemSlice } from 'features/system/store/systemSlice'; import { selectSystemSlice } from 'features/system/store/systemSlice';
@@ -15,18 +15,21 @@ const ProgressBar = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const { data: queueStatus } = useGetQueueStatusQuery(); const { data: queueStatus } = useGetQueueStatusQuery();
const isConnected = useAppSelector((s) => s.system.isConnected); const isConnected = useAppSelector((s) => s.system.isConnected);
const hasSteps = useAppSelector((s) => Boolean(s.system.denoiseProgress)); const message = useAppSelector((s) => s.system.denoiseProgress?.message);
const hasSteps = useAppSelector((s) => Boolean(s.system.denoiseProgress?.percentage !== undefined));
const value = useAppSelector(selectProgressValue); const value = useAppSelector(selectProgressValue);
return ( return (
<Progress <Tooltip label={message} placement="end">
value={value} <Progress
aria-label={t('accessibility.invokeProgressBar')} value={value}
isIndeterminate={isConnected && Boolean(queueStatus?.queue.in_progress) && !hasSteps} aria-label={t('accessibility.invokeProgressBar')}
h={2} isIndeterminate={isConnected && Boolean(queueStatus?.queue.in_progress) && !hasSteps}
w="full" h={2}
colorScheme="invokeBlue" w="full"
/> colorScheme="invokeBlue"
/>
</Tooltip>
); );
}; };

View File

@@ -5,8 +5,8 @@ import type { LogLevelName } from 'roarr';
import { import {
socketConnected, socketConnected,
socketDisconnected, socketDisconnected,
socketGeneratorProgress,
socketInvocationComplete, socketInvocationComplete,
socketInvocationProgress,
socketInvocationStarted, socketInvocationStarted,
socketModelLoadComplete, socketModelLoadComplete,
socketModelLoadStarted, socketModelLoadStarted,
@@ -95,8 +95,8 @@ export const systemSlice = createSlice({
/** /**
* Generator Progress * Generator Progress
*/ */
builder.addCase(socketGeneratorProgress, (state, action) => { builder.addCase(socketInvocationProgress, (state, action) => {
const { step, total_steps, progress_image, session_id, batch_id, percentage } = action.payload.data; const { image, session_id, batch_id, percentage, message } = action.payload.data;
if (state.cancellations.includes(session_id)) { if (state.cancellations.includes(session_id)) {
// Do not update the progress if this session has been cancelled. This prevents a race condition where we get a // Do not update the progress if this session has been cancelled. This prevents a race condition where we get a
@@ -105,10 +105,9 @@ export const systemSlice = createSlice({
} }
state.denoiseProgress = { state.denoiseProgress = {
step, message,
total_steps,
percentage, percentage,
progress_image, image,
session_id, session_id,
batch_id, batch_id,
}; };

View File

@@ -1,18 +1,9 @@
import type { LogLevel } from 'app/logging/logger'; import type { LogLevel } from 'app/logging/logger';
import type { ProgressImage } from 'services/events/types'; import type { InvocationProgressEvent } from 'services/events/types';
import { z } from 'zod'; import { z } from 'zod';
type SystemStatus = 'CONNECTED' | 'DISCONNECTED' | 'PROCESSING' | 'ERROR' | 'LOADING_MODEL'; type SystemStatus = 'CONNECTED' | 'DISCONNECTED' | 'PROCESSING' | 'ERROR' | 'LOADING_MODEL';
type DenoiseProgress = {
session_id: string;
batch_id: string;
progress_image: ProgressImage | null | undefined;
step: number;
total_steps: number;
percentage: number;
};
const zLanguage = z.enum([ const zLanguage = z.enum([
'ar', 'ar',
'az', 'az',
@@ -45,7 +36,7 @@ export interface SystemState {
isConnected: boolean; isConnected: boolean;
shouldConfirmOnDelete: boolean; shouldConfirmOnDelete: boolean;
enableImageDebugging: boolean; enableImageDebugging: boolean;
denoiseProgress: DenoiseProgress | null; denoiseProgress: Pick<InvocationProgressEvent, 'session_id' | 'batch_id' | 'image' | 'percentage' | 'message'> | null;
consoleLogLevel: LogLevel; consoleLogLevel: LogLevel;
shouldLogToConsole: boolean; shouldLogToConsole: boolean;
shouldAntialiasProgressImage: boolean; shouldAntialiasProgressImage: boolean;

File diff suppressed because one or more lines are too long

View File

@@ -9,8 +9,8 @@ import type {
DownloadProgressEvent, DownloadProgressEvent,
DownloadStartedEvent, DownloadStartedEvent,
InvocationCompleteEvent, InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent, InvocationErrorEvent,
InvocationProgressEvent,
InvocationStartedEvent, InvocationStartedEvent,
ModelInstallCancelledEvent, ModelInstallCancelledEvent,
ModelInstallCompleteEvent, ModelInstallCompleteEvent,
@@ -32,9 +32,7 @@ export const socketDisconnected = createSocketAction('Disconnected');
export const socketInvocationStarted = createSocketAction<InvocationStartedEvent>('InvocationStartedEvent'); export const socketInvocationStarted = createSocketAction<InvocationStartedEvent>('InvocationStartedEvent');
export const socketInvocationComplete = createSocketAction<InvocationCompleteEvent>('InvocationCompleteEvent'); export const socketInvocationComplete = createSocketAction<InvocationCompleteEvent>('InvocationCompleteEvent');
export const socketInvocationError = createSocketAction<InvocationErrorEvent>('InvocationErrorEvent'); export const socketInvocationError = createSocketAction<InvocationErrorEvent>('InvocationErrorEvent');
export const socketGeneratorProgress = createSocketAction<InvocationDenoiseProgressEvent>( export const socketInvocationProgress = createSocketAction<InvocationProgressEvent>('InvocationProgressEvent');
'InvocationDenoiseProgressEvent'
);
export const socketModelLoadStarted = createSocketAction<ModelLoadStartedEvent>('ModelLoadStartedEvent'); export const socketModelLoadStarted = createSocketAction<ModelLoadStartedEvent>('ModelLoadStartedEvent');
export const socketModelLoadComplete = createSocketAction<ModelLoadCompleteEvent>('ModelLoadCompleteEvent'); export const socketModelLoadComplete = createSocketAction<ModelLoadCompleteEvent>('ModelLoadCompleteEvent');
export const socketDownloadStarted = createSocketAction<DownloadStartedEvent>('DownloadStartedEvent'); export const socketDownloadStarted = createSocketAction<DownloadStartedEvent>('DownloadStartedEvent');

View File

@@ -14,9 +14,9 @@ import {
socketDownloadError, socketDownloadError,
socketDownloadProgress, socketDownloadProgress,
socketDownloadStarted, socketDownloadStarted,
socketGeneratorProgress,
socketInvocationComplete, socketInvocationComplete,
socketInvocationError, socketInvocationError,
socketInvocationProgress,
socketInvocationStarted, socketInvocationStarted,
socketModelInstallCancelled, socketModelInstallCancelled,
socketModelInstallComplete, socketModelInstallComplete,
@@ -65,8 +65,8 @@ export const setEventListeners = ({ socket, dispatch }: SetEventListenersArg) =>
socket.on('invocation_started', (data) => { socket.on('invocation_started', (data) => {
dispatch(socketInvocationStarted({ data })); dispatch(socketInvocationStarted({ data }));
}); });
socket.on('invocation_denoise_progress', (data) => { socket.on('invocation_progress', (data) => {
dispatch(socketGeneratorProgress({ data })); dispatch(socketInvocationProgress({ data }));
}); });
socket.on('invocation_error', (data) => { socket.on('invocation_error', (data) => {
dispatch(socketInvocationError({ data })); dispatch(socketInvocationError({ data }));

View File

@@ -4,10 +4,9 @@ export type ModelLoadStartedEvent = S['ModelLoadStartedEvent'];
export type ModelLoadCompleteEvent = S['ModelLoadCompleteEvent']; export type ModelLoadCompleteEvent = S['ModelLoadCompleteEvent'];
export type InvocationStartedEvent = S['InvocationStartedEvent']; export type InvocationStartedEvent = S['InvocationStartedEvent'];
export type InvocationDenoiseProgressEvent = S['InvocationDenoiseProgressEvent']; export type InvocationProgressEvent = S['InvocationProgressEvent'];
export type InvocationCompleteEvent = S['InvocationCompleteEvent']; export type InvocationCompleteEvent = S['InvocationCompleteEvent'];
export type InvocationErrorEvent = S['InvocationErrorEvent']; export type InvocationErrorEvent = S['InvocationErrorEvent'];
export type ProgressImage = InvocationDenoiseProgressEvent['progress_image'];
export type ModelInstallDownloadStartedEvent = S['ModelInstallDownloadStartedEvent']; export type ModelInstallDownloadStartedEvent = S['ModelInstallDownloadStartedEvent'];
export type ModelInstallDownloadProgressEvent = S['ModelInstallDownloadProgressEvent']; export type ModelInstallDownloadProgressEvent = S['ModelInstallDownloadProgressEvent'];
@@ -39,7 +38,7 @@ type ClientEmitSubscribeBulkDownload = {
type ClientEmitUnsubscribeBulkDownload = ClientEmitSubscribeBulkDownload; type ClientEmitUnsubscribeBulkDownload = ClientEmitSubscribeBulkDownload;
export type ServerToClientEvents = { export type ServerToClientEvents = {
invocation_denoise_progress: (payload: InvocationDenoiseProgressEvent) => void; invocation_progress: (payload: InvocationProgressEvent) => void;
invocation_complete: (payload: InvocationCompleteEvent) => void; invocation_complete: (payload: InvocationCompleteEvent) => void;
invocation_error: (payload: InvocationErrorEvent) => void; invocation_error: (payload: InvocationErrorEvent) => void;
invocation_started: (payload: InvocationStartedEvent) => void; invocation_started: (payload: InvocationStartedEvent) => void;

View File

@@ -66,6 +66,7 @@ from invokeai.app.invocations.scheduler import SchedulerOutput
from invokeai.app.services.boards.boards_common import BoardDTO from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_records.image_records_common import ImageCategory from invokeai.app.services.image_records.image_records_common import ImageCategory
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.misc import SEED_MAX, get_random_seed
@@ -176,4 +177,5 @@ __all__ = [
# invokeai.app.util.misc # invokeai.app.util.misc
"SEED_MAX", "SEED_MAX",
"get_random_seed", "get_random_seed",
"ProgressImage",
] ]