diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts index 16fe987350..0d2e299cb1 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts @@ -24,12 +24,13 @@ import { selectCanvasSlice, selectEntity, } from 'features/controlLayers/store/selectors'; -import { - type CanvasEntityIdentifier, - type CanvasRenderableEntityState, - isRasterLayerEntityIdentifier, - type Rect, +import type { + CanvasEntityIdentifier, + CanvasRenderableEntityState, + LifecycleCallback, + Rect, } from 'features/controlLayers/store/types'; +import { isRasterLayerEntityIdentifier } from 'features/controlLayers/store/types'; import { toast } from 'features/toast/toast'; import Konva from 'konva'; import { atom } from 'nanostores'; @@ -40,11 +41,6 @@ import stableHash from 'stable-hash'; import { assert } from 'tsafe'; import type { Jsonifiable, JsonObject } from 'type-fest'; -// Ideally, we'd type `adapter` as `CanvasEntityAdapterBase`, but the generics make this tricky. `CanvasEntityAdapter` -// is a union of all entity adapters and is functionally identical to `CanvasEntityAdapterBase`. We'll need to do a -// type assertion below in the `onInit` method, which calls these callbacks. -type InitCallback = (adapter: CanvasEntityAdapter) => Promise; - export abstract class CanvasEntityAdapterBase< T extends CanvasRenderableEntityState, U extends string, @@ -118,7 +114,7 @@ export abstract class CanvasEntityAdapterBase< /** * Callbacks that are executed when the module is initialized. */ - private static initCallbacks = new Set(); + private static initCallbacks = new Set(); /** * Register a callback to be run when an entity adapter is initialized. @@ -165,7 +161,7 @@ export abstract class CanvasEntityAdapterBase< * return false; * }); */ - static registerInitCallback = (callback: InitCallback) => { + static registerInitCallback = (callback: LifecycleCallback) => { const wrapped = async (adapter: CanvasEntityAdapter) => { const result = await callback(adapter); if (result) { diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityTransformer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityTransformer.ts index d50d51b390..189273a759 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityTransformer.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityTransformer.ts @@ -13,7 +13,7 @@ import { roundRect, } from 'features/controlLayers/konva/util'; import { selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors'; -import type { Coordinate, Rect, RectWithRotation } from 'features/controlLayers/store/types'; +import type { Coordinate, LifecycleCallback, Rect, RectWithRotation } from 'features/controlLayers/store/types'; import { toast } from 'features/toast/toast'; import Konva from 'konva'; import type { GroupConfig } from 'konva/lib/Group'; @@ -123,7 +123,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase { /** * Whether the transformer is currently calculating the rect of the parent. */ - $isPendingRectCalculation = atom(true); + $isPendingRectCalculation = atom(false); /** * A set of subscriptions that should be cleaned up when the transformer is destroyed. @@ -177,6 +177,11 @@ export class CanvasEntityTransformer extends CanvasModuleBase { */ transformMutex = new Mutex(); + /** + * Callbacks that are executed when the bbox is updated. + */ + private static bboxUpdatedCallbacks = new Set(); + konva: { transformer: Konva.Transformer; proxyRect: Konva.Rect; @@ -908,6 +913,8 @@ export class CanvasEntityTransformer extends CanvasModuleBase { this.parent.renderer.konva.objectGroup.setAttrs(groupAttrs); this.parent.bufferRenderer.konva.group.setAttrs(groupAttrs); } + + CanvasEntityTransformer.runBboxUpdatedCallbacks(this.parent); }; calculateRect = debounce(() => { @@ -1026,6 +1033,23 @@ export class CanvasEntityTransformer extends CanvasModuleBase { this.konva.outlineRect.visible(false); }; + static registerBboxUpdatedCallback = (callback: LifecycleCallback) => { + const wrapped = async (adapter: CanvasEntityAdapter) => { + const result = await callback(adapter); + if (result) { + this.bboxUpdatedCallbacks.delete(wrapped); + } + return result; + }; + this.bboxUpdatedCallbacks.add(wrapped); + }; + + private static runBboxUpdatedCallbacks = (adapter: CanvasEntityAdapter) => { + for (const callback of this.bboxUpdatedCallbacks) { + callback(adapter); + } + }; + repr = () => { return { id: this.id, diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index e0f9a59846..e681ca73ea 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -1,3 +1,4 @@ +import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types'; import { fetchModelConfigByIdentifier } from 'features/metadata/util/modelFetchingHelpers'; import { zMainModelBase, zModelIdentifierField } from 'features/nodes/types/common'; import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas'; @@ -611,3 +612,7 @@ export const isMaskEntityIdentifier = ( ): entityIdentifier is CanvasEntityIdentifier<'inpaint_mask' | 'regional_guidance'> => { return isInpaintMaskEntityIdentifier(entityIdentifier) || isRegionalGuidanceEntityIdentifier(entityIdentifier); }; + +// Ideally, we'd type `adapter` as `CanvasEntityAdapterBase`, but the generics make this tricky. `CanvasEntityAdapter` +// is a union of all entity adapters and is functionally identical to `CanvasEntityAdapterBase`. +export type LifecycleCallback = (adapter: CanvasEntityAdapter) => Promise; diff --git a/invokeai/frontend/web/src/features/imageActions/actions.ts b/invokeai/frontend/web/src/features/imageActions/actions.ts index f0460229a8..359899d494 100644 --- a/invokeai/frontend/web/src/features/imageActions/actions.ts +++ b/invokeai/frontend/web/src/features/imageActions/actions.ts @@ -1,6 +1,7 @@ import type { AppDispatch, RootState } from 'app/store/store'; import { deepClone } from 'common/util/deepClone'; import { selectDefaultIPAdapter, selectDefaultRefImageConfig } from 'features/controlLayers/hooks/addLayerHooks'; +import { CanvasEntityTransformer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityTransformer'; import { getPrefixedId } from 'features/controlLayers/konva/util'; import { canvasReset } from 'features/controlLayers/store/actions'; import { @@ -173,12 +174,24 @@ export const newCanvasFromImage = async (arg: { imageObject = imageDTOToImageObject(imageDTO); } + const addFitOnLayerInitCallback = (adapterId: string) => { + CanvasEntityTransformer.registerBboxUpdatedCallback((adapter) => { + // Skip the callback if the adapter is not the one we are creating + if (adapter.id !== adapterId) { + return Promise.resolve(false); + } + adapter.manager.stage.fitBboxAndLayersToStage(); + return Promise.resolve(true); + }); + }; + switch (type) { case 'raster_layer': { const overrides = { id: getPrefixedId('raster_layer'), objects: [imageObject], } satisfies Partial; + addFitOnLayerInitCallback(overrides.id); dispatch(canvasReset()); // The `bboxChangedFromCanvas` reducer does no validation! Careful! dispatch(bboxChangedFromCanvas({ x: 0, y: 0, width, height })); @@ -191,6 +204,7 @@ export const newCanvasFromImage = async (arg: { objects: [imageObject], controlAdapter: deepClone(initialControlNet), } satisfies Partial; + addFitOnLayerInitCallback(overrides.id); dispatch(canvasReset()); // The `bboxChangedFromCanvas` reducer does no validation! Careful! dispatch(bboxChangedFromCanvas({ x: 0, y: 0, width, height })); @@ -202,6 +216,7 @@ export const newCanvasFromImage = async (arg: { id: getPrefixedId('inpaint_mask'), objects: [imageObject], } satisfies Partial; + addFitOnLayerInitCallback(overrides.id); dispatch(canvasReset()); // The `bboxChangedFromCanvas` reducer does no validation! Careful! dispatch(bboxChangedFromCanvas({ x: 0, y: 0, width, height })); @@ -213,6 +228,7 @@ export const newCanvasFromImage = async (arg: { id: getPrefixedId('regional_guidance'), objects: [imageObject], } satisfies Partial; + addFitOnLayerInitCallback(overrides.id); dispatch(canvasReset()); // The `bboxChangedFromCanvas` reducer does no validation! Careful! dispatch(bboxChangedFromCanvas({ x: 0, y: 0, width, height }));