feat(ui): clean up adapter API

This commit is contained in:
psychedelicious
2024-09-05 16:36:04 +10:00
parent e176e48fa3
commit b189937bc9
16 changed files with 162 additions and 195 deletions

View File

@@ -1,12 +1,12 @@
import { Button, ButtonGroup, Flex, Heading, Spacer } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import type { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntityAdapterBase';
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntityAdapter/types';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowsCounterClockwiseBold, PiArrowsOutBold, PiCheckBold, PiXBold } from 'react-icons/pi';
const TransformBox = memo(({ adapter }: { adapter: CanvasEntityAdapterBase }) => {
const TransformBox = memo(({ adapter }: { adapter: CanvasEntityAdapter }) => {
const { t } = useTranslation();
const isProcessing = useStore(adapter.transformer.$isProcessing);

View File

@@ -1,9 +1,9 @@
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterInpaintMask } from 'features/controlLayers/konva/CanvasEntityAdapterInpaintMask';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntityAdapterRasterLayer';
import type { CanvasEntityAdapterRegionalGuidance } from 'features/controlLayers/konva/CanvasEntityAdapterRegionalGuidance';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterInpaintMask } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterInpaintMask';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterRasterLayer';
import type { CanvasEntityAdapterRegionalGuidance } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterRegionalGuidance';
import type { PropsWithChildren } from 'react';
import { createContext, memo, useContext, useMemo, useSyncExternalStore } from 'react';
import { assert } from 'tsafe';

View File

@@ -1,22 +1,26 @@
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterInpaintMask } from 'features/controlLayers/konva/CanvasEntityAdapterInpaintMask';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntityAdapterRasterLayer';
import type { CanvasEntityAdapterRegionalGuidance } from 'features/controlLayers/konva/CanvasEntityAdapterRegionalGuidance';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterInpaintMask } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterInpaintMask';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterRasterLayer';
import type { CanvasEntityAdapterRegionalGuidance } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterRegionalGuidance';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { useMemo } from 'react';
import { assert } from 'tsafe';
export const useEntityAdapter = (
entityIdentifier: CanvasEntityIdentifier
): CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer | CanvasEntityAdapterInpaintMask | CanvasEntityAdapterRegionalGuidance => {
):
| CanvasEntityAdapterRasterLayer
| CanvasEntityAdapterControlLayer
| CanvasEntityAdapterInpaintMask
| CanvasEntityAdapterRegionalGuidance => {
const canvasManager = useCanvasManager();
const adapter = useMemo(() => {
const entity = canvasManager.stateApi.getEntity(entityIdentifier);
assert(entity, 'Entity adapter not found');
return entity.adapter;
}, [canvasManager.stateApi, entityIdentifier]);
const adapter = canvasManager.getAdapter(entityIdentifier);
assert(adapter, 'Entity adapter not found');
return adapter;
}, [canvasManager, entityIdentifier]);
return adapter;
};

View File

@@ -1,7 +1,7 @@
import type { SerializableObject } from 'common/types';
import { deepClone } from 'common/util/deepClone';
import { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntityObjectRenderer';
import { CanvasEntityTransformer } from 'features/controlLayers/konva/CanvasEntityTransformer';
import type { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntityObjectRenderer';
import type { CanvasEntityTransformer } from 'features/controlLayers/konva/CanvasEntityTransformer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import { selectEntity } from 'features/controlLayers/store/selectors';
@@ -36,12 +36,12 @@ export abstract class CanvasEntityAdapterBase<
/**
* The transformer for this entity adapter.
*/
transformer: CanvasEntityTransformer;
abstract transformer: CanvasEntityTransformer;
/**
* The renderer for this entity adapter.
*/
renderer: CanvasEntityObjectRenderer;
abstract renderer: CanvasEntityObjectRenderer;
/**
* The entity's state.
@@ -79,9 +79,6 @@ export abstract class CanvasEntityAdapterBase<
const initialState = this.getSnapshot();
assert(initialState !== undefined, 'Missing entity state on creation');
this.state = initialState;
this.renderer = new CanvasEntityObjectRenderer(this);
this.transformer = new CanvasEntityTransformer(this);
}
/**

View File

@@ -1,5 +1,7 @@
import type { SerializableObject } from 'common/types';
import { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntityAdapterBase';
import { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterBase';
import { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntityObjectRenderer';
import { CanvasEntityTransformer } from 'features/controlLayers/konva/CanvasEntityTransformer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type { CanvasControlLayerState, CanvasEntityIdentifier, Rect } from 'features/controlLayers/store/types';
import type { GroupConfig } from 'konva/lib/Group';
@@ -8,8 +10,13 @@ import { omit } from 'lodash-es';
export class CanvasEntityAdapterControlLayer extends CanvasEntityAdapterBase<CanvasControlLayerState> {
static TYPE = 'control_layer_adapter';
transformer: CanvasEntityTransformer;
renderer: CanvasEntityObjectRenderer;
constructor(entityIdentifier: CanvasEntityIdentifier<'control_layer'>, manager: CanvasManager) {
super(entityIdentifier, manager, CanvasEntityAdapterControlLayer.TYPE);
this.transformer = new CanvasEntityTransformer(this);
this.renderer = new CanvasEntityObjectRenderer(this);
this.subscriptions.add(this.manager.stateApi.store.subscribe(this.sync));
this.sync(true);
}

View File

@@ -1,5 +1,7 @@
import type { SerializableObject } from 'common/types';
import { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntityAdapterBase';
import { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterBase';
import { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntityObjectRenderer';
import { CanvasEntityTransformer } from 'features/controlLayers/konva/CanvasEntityTransformer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type { CanvasEntityIdentifier, CanvasInpaintMaskState, Rect } from 'features/controlLayers/store/types';
import type { GroupConfig } from 'konva/lib/Group';
@@ -8,8 +10,15 @@ import { omit } from 'lodash-es';
export class CanvasEntityAdapterInpaintMask extends CanvasEntityAdapterBase<CanvasInpaintMaskState> {
static TYPE = 'inpaint_mask_adapter';
transformer: CanvasEntityTransformer;
renderer: CanvasEntityObjectRenderer;
constructor(entityIdentifier: CanvasEntityIdentifier<'inpaint_mask'>, manager: CanvasManager) {
super(entityIdentifier, manager, CanvasEntityAdapterInpaintMask.TYPE);
this.transformer = new CanvasEntityTransformer(this);
this.renderer = new CanvasEntityObjectRenderer(this);
this.subscriptions.add(this.manager.stateApi.store.subscribe(this.sync));
this.sync(true);
}

View File

@@ -1,5 +1,7 @@
import type { SerializableObject } from 'common/types';
import { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntityAdapterBase';
import { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterBase';
import { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntityObjectRenderer';
import { CanvasEntityTransformer } from 'features/controlLayers/konva/CanvasEntityTransformer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type { CanvasEntityIdentifier, CanvasRasterLayerState, Rect } from 'features/controlLayers/store/types';
import type { GroupConfig } from 'konva/lib/Group';
@@ -8,8 +10,13 @@ import { omit } from 'lodash-es';
export class CanvasEntityAdapterRasterLayer extends CanvasEntityAdapterBase<CanvasRasterLayerState> {
static TYPE = 'raster_layer_adapter';
transformer: CanvasEntityTransformer;
renderer: CanvasEntityObjectRenderer;
constructor(entityIdentifier: CanvasEntityIdentifier<'raster_layer'>, manager: CanvasManager) {
super(entityIdentifier, manager, CanvasEntityAdapterRasterLayer.TYPE);
this.transformer = new CanvasEntityTransformer(this);
this.renderer = new CanvasEntityObjectRenderer(this);
this.subscriptions.add(this.manager.stateApi.store.subscribe(this.sync));
this.sync(true);
}

View File

@@ -1,5 +1,7 @@
import type { SerializableObject } from 'common/types';
import { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntityAdapterBase';
import { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterBase';
import { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntityObjectRenderer';
import { CanvasEntityTransformer } from 'features/controlLayers/konva/CanvasEntityTransformer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type { CanvasEntityIdentifier, CanvasRegionalGuidanceState, Rect } from 'features/controlLayers/store/types';
import type { GroupConfig } from 'konva/lib/Group';
@@ -8,8 +10,15 @@ import { omit } from 'lodash-es';
export class CanvasEntityAdapterRegionalGuidance extends CanvasEntityAdapterBase<CanvasRegionalGuidanceState> {
static TYPE = 'regional_guidance_adapter';
transformer: CanvasEntityTransformer;
renderer: CanvasEntityObjectRenderer;
constructor(entityIdentifier: CanvasEntityIdentifier<'regional_guidance'>, manager: CanvasManager) {
super(entityIdentifier, manager, CanvasEntityAdapterRegionalGuidance.TYPE);
this.transformer = new CanvasEntityTransformer(this);
this.renderer = new CanvasEntityObjectRenderer(this);
this.subscriptions.add(this.manager.stateApi.store.subscribe(this.sync));
this.sync(true);
}

View File

@@ -0,0 +1,10 @@
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterInpaintMask } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterInpaintMask';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterRasterLayer';
import type { CanvasEntityAdapterRegionalGuidance } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterRegionalGuidance';
export type CanvasEntityAdapter =
| CanvasEntityAdapterRasterLayer
| CanvasEntityAdapterControlLayer
| CanvasEntityAdapterInpaintMask
| CanvasEntityAdapterRegionalGuidance;

View File

@@ -1,5 +1,5 @@
import { rgbColorToString } from 'common/util/colorCodeTransformers';
import type { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntityAdapterBase';
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntityAdapter/types';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import { CanvasObjectBrushLine } from 'features/controlLayers/konva/CanvasObjectBrushLine';
@@ -58,7 +58,7 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase {
readonly type = 'object_renderer';
readonly id: string;
readonly path: string[];
readonly parent: CanvasEntityAdapterBase;
readonly parent: CanvasEntityAdapter;
readonly manager: CanvasManager;
readonly log: Logger;
@@ -129,7 +129,7 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase {
*/
$canvasCache = atom<{ canvas: HTMLCanvasElement; rect: Rect } | null>(null);
constructor(parent: CanvasEntityAdapterBase) {
constructor(parent: CanvasEntityAdapter) {
super();
this.id = getPrefixedId(this.type);
this.parent = parent;

View File

@@ -37,7 +37,6 @@ export class CanvasEntityRendererModule extends CanvasModuleBase {
this.manager.stateApi.$settingsState.set(this.manager.stateApi.getSettings());
this.manager.stateApi.$selectedEntityIdentifier.set(state.selectedEntityIdentifier);
this.manager.stateApi.$selectedEntity.set(this.manager.stateApi.getSelectedEntity());
this.manager.stateApi.$currentFill.set(this.manager.stateApi.getCurrentColor());
if (prevState === state) {

View File

@@ -1,4 +1,4 @@
import type { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntityAdapterBase';
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntityAdapter/types';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import { canvasToImageData, getEmptyRect, getPrefixedId } from 'features/controlLayers/konva/util';
@@ -79,7 +79,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
readonly type = 'entity_transformer';
readonly id: string;
readonly path: string[];
readonly parent: CanvasEntityAdapterBase;
readonly parent: CanvasEntityAdapter;
readonly manager: CanvasManager;
readonly log: Logger;

View File

@@ -1,10 +1,10 @@
import type { SerializableObject } from 'common/types';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterInpaintMask } from 'features/controlLayers/konva/CanvasEntityAdapterInpaintMask';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterInpaintMask } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterInpaintMask';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterRasterLayer';
import type { CanvasEntityAdapterRegionalGuidance } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterRegionalGuidance';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntityAdapterRasterLayer';
import type { CanvasEntityAdapterRegionalGuidance } from 'features/controlLayers/konva/CanvasEntityAdapterRegionalGuidance';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { CanvasEntityIdentifier, CanvasImageState, FilterConfig } from 'features/controlLayers/store/types';
import { IMAGE_FILTERS, imageDTOToImageObject } from 'features/controlLayers/store/types';
@@ -48,16 +48,16 @@ export class CanvasFilterModule extends CanvasModuleBase {
initialize = (entityIdentifier: CanvasEntityIdentifier) => {
this.log.trace('Initializing filter');
const entity = this.manager.stateApi.getEntity(entityIdentifier);
if (!entity) {
const adapter = this.manager.getAdapter(entityIdentifier);
if (!adapter) {
this.log.warn({ entityIdentifier }, 'Unable to find entity');
return;
}
if (entity.type !== 'raster_layer' && entity.type !== 'control_layer') {
if (adapter.entityIdentifier.type !== 'raster_layer' && adapter.entityIdentifier.type !== 'control_layer') {
this.log.warn({ entityIdentifier }, 'Unsupported entity type');
return;
}
this.$adapter.set(entity.adapter);
this.$adapter.set(adapter);
this.manager.tool.$tool.set('view');
};

View File

@@ -6,10 +6,11 @@ import { SyncableMap } from 'common/util/SyncableMap/SyncableMap';
import { CanvasBboxModule } from 'features/controlLayers/konva/CanvasBboxModule';
import { CanvasCacheModule } from 'features/controlLayers/konva/CanvasCacheModule';
import { CanvasCompositorModule } from 'features/controlLayers/konva/CanvasCompositorModule';
import { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntityAdapterControlLayer';
import { CanvasEntityAdapterInpaintMask } from 'features/controlLayers/konva/CanvasEntityAdapterInpaintMask';
import { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntityAdapterRasterLayer';
import { CanvasEntityAdapterRegionalGuidance } from 'features/controlLayers/konva/CanvasEntityAdapterRegionalGuidance';
import { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterControlLayer';
import { CanvasEntityAdapterInpaintMask } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterInpaintMask';
import { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterRasterLayer';
import { CanvasEntityAdapterRegionalGuidance } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterRegionalGuidance';
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntityAdapter/types';
import { CanvasEntityRendererModule } from 'features/controlLayers/konva/CanvasEntityRendererModule';
import { CanvasFilterModule } from 'features/controlLayers/konva/CanvasFilterModule';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
@@ -132,6 +133,21 @@ export class CanvasManager extends CanvasModuleBase {
});
}
getAdapter = (entityIdentifier: CanvasEntityIdentifier): CanvasEntityAdapter | null => {
switch (entityIdentifier.type) {
case 'raster_layer':
return this.adapters.rasterLayers.get(entityIdentifier.id) ?? null;
case 'control_layer':
return this.adapters.controlLayers.get(entityIdentifier.id) ?? null;
case 'regional_guidance':
return this.adapters.regionMasks.get(entityIdentifier.id) ?? null;
case 'inpaint_mask':
return this.adapters.inpaintMasks.get(entityIdentifier.id) ?? null;
default:
return null;
}
};
deleteAdapter = (entityIdentifier: CanvasEntityIdentifier): boolean => {
switch (entityIdentifier.type) {
case 'raster_layer':
@@ -147,12 +163,7 @@ export class CanvasManager extends CanvasModuleBase {
}
};
getAllAdapters = (): (
| CanvasEntityAdapterRasterLayer
| CanvasEntityAdapterControlLayer
| CanvasEntityAdapterRegionalGuidance
| CanvasEntityAdapterInpaintMask
)[] => {
getAllAdapters = (): CanvasEntityAdapter[] => {
return [
...this.adapters.rasterLayers.values(),
...this.adapters.controlLayers.values(),
@@ -196,7 +207,6 @@ export class CanvasManager extends CanvasModuleBase {
this.stateApi.$settingsState.set(this.stateApi.getSettings());
this.stateApi.$selectedEntityIdentifier.set(this.stateApi.getCanvasState().selectedEntityIdentifier);
this.stateApi.$currentFill.set(this.stateApi.getCurrentColor());
this.stateApi.$selectedEntity.set(this.stateApi.getSelectedEntity());
this.stage.initialize();
$canvasManager.set(this);

View File

@@ -1,12 +1,7 @@
import { $alt, $ctrl, $meta, $shift } from '@invoke-ai/ui-library';
import type { AppStore } from 'app/store/store';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntityAdapterBase';
import type { CanvasEntityAdapterInpaintMask } from 'features/controlLayers/konva/CanvasEntityAdapterInpaintMask';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntityAdapterRasterLayer';
import type { CanvasEntityAdapterRegionalGuidance } from 'features/controlLayers/konva/CanvasEntityAdapterRegionalGuidance';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { CanvasSettingsState } from 'features/controlLayers/store/canvasSettingsSlice';
import {
@@ -25,12 +20,8 @@ import {
} from 'features/controlLayers/store/canvasSlice';
import { selectAllRenderableEntities, selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type {
CanvasControlLayerState,
CanvasEntityIdentifier,
CanvasEntityType,
CanvasInpaintMaskState,
CanvasRasterLayerState,
CanvasRegionalGuidanceState,
EntityBrushLineAddedPayload,
EntityEraserLineAddedPayload,
EntityIdentifierPayload,
@@ -49,31 +40,7 @@ import type { BatchConfig } from 'services/api/types';
import { $lastCanvasProgressEvent } from 'services/events/setEventListeners';
import { assert } from 'tsafe';
type EntityStateAndAdapter =
| {
id: string;
type: CanvasRasterLayerState['type'];
state: CanvasRasterLayerState;
adapter: CanvasEntityAdapterRasterLayer;
}
| {
id: string;
type: CanvasControlLayerState['type'];
state: CanvasControlLayerState;
adapter: CanvasEntityAdapterControlLayer;
}
| {
id: string;
type: CanvasInpaintMaskState['type'];
state: CanvasInpaintMaskState;
adapter: CanvasEntityAdapterInpaintMask;
}
| {
id: string;
type: CanvasRegionalGuidanceState['type'];
state: CanvasRegionalGuidanceState;
adapter: CanvasEntityAdapterRegionalGuidance;
};
import type { CanvasEntityAdapter } from './CanvasEntityAdapter/types';
export class CanvasStateApiModule extends CanvasModuleBase {
readonly type = 'state_api';
@@ -265,56 +232,10 @@ export class CanvasStateApiModule extends CanvasModuleBase {
}
};
/**
* Gets an entity by its identifier. The entity's state is retrieved from the redux store, and its adapter is
* retrieved from the canvas manager.
*
* Both state and adapter must exist for the entity to be returned.
*/
getEntity<T extends CanvasEntityIdentifier>({
id,
type,
}: T): Extract<EntityStateAndAdapter, { type: T['type'] }> | null {
const state = this.getCanvasState();
let entityState: EntityStateAndAdapter['state'] | undefined = undefined;
let entityAdapter: EntityStateAndAdapter['adapter'] | undefined = undefined;
switch (type) {
case 'raster_layer':
entityState = state.rasterLayers.entities.find((i) => i.id === id);
entityAdapter = this.manager.adapters.rasterLayers.get(id);
break;
case 'control_layer':
entityState = state.controlLayers.entities.find((i) => i.id === id);
entityAdapter = this.manager.adapters.controlLayers.get(id);
break;
case 'regional_guidance':
entityState = state.regions.entities.find((i) => i.id === id);
entityAdapter = this.manager.adapters.regionMasks.get(id);
break;
case 'inpaint_mask':
entityState = state.inpaintMasks.entities.find((i) => i.id === id);
entityAdapter = this.manager.adapters.inpaintMasks.get(id);
break;
}
if (entityState && entityAdapter) {
return {
id: entityState.id,
type: entityState.type,
state: entityState,
adapter: entityAdapter,
} as Extract<EntityStateAndAdapter, { type: T['type'] }>; // TODO(psyche): make TS happy w/o this cast
}
return null;
}
/**
* Gets the number of entities that are currently rendered on the canvas.
*/
getRenderedEntityCount = () => {
getRenderedEntityCount = (): number => {
const renderableEntities = selectAllRenderableEntities(this.getCanvasState());
let count = 0;
for (const entity of renderableEntities) {
@@ -326,13 +247,12 @@ export class CanvasStateApiModule extends CanvasModuleBase {
};
/**
* Gets the currently selected entity, if any. The entity's state is retrieved from the redux store, and its adapter
* is retrieved from the canvas manager.
* Gets the currently selected entity's adapter
*/
getSelectedEntity = () => {
getSelectedEntityAdapter = (): CanvasEntityAdapter | null => {
const state = this.getCanvasState();
if (state.selectedEntityIdentifier) {
return this.getEntity(state.selectedEntityIdentifier);
return this.manager.getAdapter(state.selectedEntityIdentifier);
}
return null;
};
@@ -347,9 +267,9 @@ export class CanvasStateApiModule extends CanvasModuleBase {
* so the color for lines and rects doesn't matter - it is never seen. The only requirement is that it is opaque. For
* consistency with conventional black and white mask images, we use black as the color for these entities.
*/
getCurrentColor = () => {
getCurrentColor = (): RgbaColor => {
let color: RgbaColor = this.getSettings().color;
const selectedEntity = this.getSelectedEntity();
const selectedEntity = this.getSelectedEntityAdapter();
if (selectedEntity) {
// These two entity types use a compositing rect for opacity. Their fill is always a solid color.
if (selectedEntity.state.type === 'regional_guidance' || selectedEntity.state.type === 'inpaint_mask') {
@@ -368,7 +288,7 @@ export class CanvasStateApiModule extends CanvasModuleBase {
* color.
*/
getBrushPreviewColor = (): RgbaColor => {
const selectedEntity = this.getSelectedEntity();
const selectedEntity = this.getSelectedEntityAdapter();
if (selectedEntity?.state.type === 'regional_guidance' || selectedEntity?.state.type === 'inpaint_mask') {
// TODO(psyche): If we move the brush preview's Konva nodes to the selected entity renderer, we can draw them
// under the entity's compositing rect, so they would use selected entity's selected color and texture. As a
@@ -383,7 +303,7 @@ export class CanvasStateApiModule extends CanvasModuleBase {
/**
* The entity adapter being transformed, if any.
*/
$transformingAdapter = atom<CanvasEntityAdapterBase | null>(null);
$transformingAdapter = atom<CanvasEntityAdapter | null>(null);
/**
* Whether an entity is currently being transformed. Derived from `$transformingAdapter`.
@@ -400,11 +320,6 @@ export class CanvasStateApiModule extends CanvasModuleBase {
*/
$currentFill: WritableAtom<RgbaColor> = atom();
/**
* The currently selected entity, if any. Includes the entity latest state and its adapter.
*/
$selectedEntity: WritableAtom<EntityStateAndAdapter | null> = atom();
/**
* The currently selected entity's identifier, if an entity is selected.
*/

View File

@@ -146,7 +146,7 @@ export class CanvasToolModule extends CanvasModuleBase {
stage.setCursor('not-allowed');
} else if (this.manager.$isBusy.get()) {
stage.setCursor('not-allowed');
} else if (!this.manager.stateApi.getSelectedEntity()?.adapter.getIsInteractable()) {
} else if (!this.manager.stateApi.getSelectedEntityAdapter()?.getIsInteractable()) {
stage.setCursor('not-allowed');
} else if (tool === 'colorPicker' || tool === 'brush' || tool === 'eraser') {
stage.setCursor('none');
@@ -162,7 +162,7 @@ export class CanvasToolModule extends CanvasModuleBase {
render = () => {
const stage = this.manager.stage;
const renderedEntityCount = this.manager.stateApi.getRenderedEntityCount();
const selectedEntity = this.manager.stateApi.getSelectedEntity();
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
const cursorPos = this.$cursorPos.get();
const tool = this.$tool.get();
@@ -282,7 +282,7 @@ export class CanvasToolModule extends CanvasModuleBase {
return false;
} else if (this.manager.filter.$isFiltering.get()) {
return false;
} else if (!this.manager.stateApi.getSelectedEntity()?.adapter.getIsInteractable()) {
} else if (!this.manager.stateApi.getSelectedEntityAdapter()?.getIsInteractable()) {
return false;
} else {
return true;
@@ -299,21 +299,21 @@ export class CanvasToolModule extends CanvasModuleBase {
const isMouseDown = this.$isMouseDown.get();
const settings = this.manager.stateApi.getSettings();
const tool = this.$tool.get();
const selectedEntity = this.manager.stateApi.getSelectedEntity();
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
if (!cursorPos || !isMouseDown || !selectedEntity?.state.isEnabled || selectedEntity.state.isLocked) {
return;
}
if (selectedEntity.adapter.renderer.bufferState?.type !== 'rect') {
selectedEntity.adapter.renderer.commitBuffer();
if (selectedEntity.renderer.bufferState?.type !== 'rect') {
selectedEntity.renderer.commitBuffer();
return;
}
if (tool === 'brush') {
const normalizedPoint = offsetCoord(cursorPos, selectedEntity.state.position);
const alignedPoint = alignCoordForTool(normalizedPoint, settings.brushWidth);
await selectedEntity.adapter.renderer.setBuffer({
await selectedEntity.renderer.setBuffer({
id: getPrefixedId('brush_line'),
type: 'brush_line',
points: [alignedPoint.x, alignedPoint.y],
@@ -327,10 +327,10 @@ export class CanvasToolModule extends CanvasModuleBase {
if (tool === 'eraser') {
const normalizedPoint = offsetCoord(cursorPos, selectedEntity.state.position);
const alignedPoint = alignCoordForTool(normalizedPoint, settings.brushWidth);
if (selectedEntity.adapter.renderer.bufferState) {
selectedEntity.adapter.renderer.commitBuffer();
if (selectedEntity.renderer.bufferState) {
selectedEntity.renderer.commitBuffer();
}
await selectedEntity.adapter.renderer.setBuffer({
await selectedEntity.renderer.setBuffer({
id: getPrefixedId('eraser_line'),
type: 'eraser_line',
points: [alignedPoint.x, alignedPoint.y],
@@ -365,7 +365,7 @@ export class CanvasToolModule extends CanvasModuleBase {
}
const isMouseDown = this.$isMouseDown.get();
const selectedEntity = this.manager.stateApi.getSelectedEntity();
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
if (!cursorPos || !isMouseDown || !selectedEntity?.state.isEnabled || selectedEntity?.state.isLocked) {
return;
@@ -378,11 +378,11 @@ export class CanvasToolModule extends CanvasModuleBase {
const alignedPoint = alignCoordForTool(normalizedPoint, settings.brushWidth);
if (e.evt.shiftKey && lastLinePoint) {
// Create a straight line from the last line point
if (selectedEntity.adapter.renderer.bufferState) {
selectedEntity.adapter.renderer.commitBuffer();
if (selectedEntity.renderer.bufferState) {
selectedEntity.renderer.commitBuffer();
}
await selectedEntity.adapter.renderer.setBuffer({
await selectedEntity.renderer.setBuffer({
id: getPrefixedId('brush_line'),
type: 'brush_line',
points: [
@@ -397,10 +397,10 @@ export class CanvasToolModule extends CanvasModuleBase {
clip: this.getClip(selectedEntity.state),
});
} else {
if (selectedEntity.adapter.renderer.bufferState) {
selectedEntity.adapter.renderer.commitBuffer();
if (selectedEntity.renderer.bufferState) {
selectedEntity.renderer.commitBuffer();
}
await selectedEntity.adapter.renderer.setBuffer({
await selectedEntity.renderer.setBuffer({
id: getPrefixedId('brush_line'),
type: 'brush_line',
points: [alignedPoint.x, alignedPoint.y],
@@ -416,10 +416,10 @@ export class CanvasToolModule extends CanvasModuleBase {
const alignedPoint = alignCoordForTool(normalizedPoint, settings.eraserWidth);
if (e.evt.shiftKey && lastLinePoint) {
// Create a straight line from the last line point
if (selectedEntity.adapter.renderer.bufferState) {
selectedEntity.adapter.renderer.commitBuffer();
if (selectedEntity.renderer.bufferState) {
selectedEntity.renderer.commitBuffer();
}
await selectedEntity.adapter.renderer.setBuffer({
await selectedEntity.renderer.setBuffer({
id: getPrefixedId('eraser_line'),
type: 'eraser_line',
points: [
@@ -433,10 +433,10 @@ export class CanvasToolModule extends CanvasModuleBase {
clip: this.getClip(selectedEntity.state),
});
} else {
if (selectedEntity.adapter.renderer.bufferState) {
selectedEntity.adapter.renderer.commitBuffer();
if (selectedEntity.renderer.bufferState) {
selectedEntity.renderer.commitBuffer();
}
await selectedEntity.adapter.renderer.setBuffer({
await selectedEntity.renderer.setBuffer({
id: getPrefixedId('eraser_line'),
type: 'eraser_line',
points: [alignedPoint.x, alignedPoint.y],
@@ -447,10 +447,10 @@ export class CanvasToolModule extends CanvasModuleBase {
}
if (tool === 'rect') {
if (selectedEntity.adapter.renderer.bufferState) {
selectedEntity.adapter.renderer.commitBuffer();
if (selectedEntity.renderer.bufferState) {
selectedEntity.renderer.commitBuffer();
}
await selectedEntity.adapter.renderer.setBuffer({
await selectedEntity.renderer.setBuffer({
id: getPrefixedId('rect'),
type: 'rect',
rect: { x: Math.round(normalizedPoint.x), y: Math.round(normalizedPoint.y), width: 0, height: 0 },
@@ -473,7 +473,7 @@ export class CanvasToolModule extends CanvasModuleBase {
if (!cursorPos) {
return;
}
const selectedEntity = this.manager.stateApi.getSelectedEntity();
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
const isDrawable = selectedEntity?.state.isEnabled && !selectedEntity.state.isLocked;
if (!isDrawable) {
return;
@@ -481,26 +481,26 @@ export class CanvasToolModule extends CanvasModuleBase {
const tool = this.$tool.get();
if (tool === 'brush') {
if (selectedEntity.adapter.renderer.bufferState?.type === 'brush_line') {
selectedEntity.adapter.renderer.commitBuffer();
if (selectedEntity.renderer.bufferState?.type === 'brush_line') {
selectedEntity.renderer.commitBuffer();
} else {
selectedEntity.adapter.renderer.clearBuffer();
selectedEntity.renderer.clearBuffer();
}
}
if (tool === 'eraser') {
if (selectedEntity.adapter.renderer.bufferState?.type === 'eraser_line') {
selectedEntity.adapter.renderer.commitBuffer();
if (selectedEntity.renderer.bufferState?.type === 'eraser_line') {
selectedEntity.renderer.commitBuffer();
} else {
selectedEntity.adapter.renderer.clearBuffer();
selectedEntity.renderer.clearBuffer();
}
}
if (tool === 'rect') {
if (selectedEntity.adapter.renderer.bufferState?.type === 'rect') {
selectedEntity.adapter.renderer.commitBuffer();
if (selectedEntity.renderer.bufferState?.type === 'rect') {
selectedEntity.renderer.commitBuffer();
} else {
selectedEntity.adapter.renderer.clearBuffer();
selectedEntity.renderer.clearBuffer();
}
}
} finally {
@@ -526,14 +526,14 @@ export class CanvasToolModule extends CanvasModuleBase {
}
const isMouseDown = this.$isMouseDown.get();
const selectedEntity = this.manager.stateApi.getSelectedEntity();
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
const isDrawable = selectedEntity?.state.isEnabled && !selectedEntity.state.isLocked && cursorPos && isMouseDown;
if (!isDrawable) {
return;
}
const bufferState = selectedEntity.adapter.renderer.bufferState;
const bufferState = selectedEntity.renderer.bufferState;
if (!bufferState) {
return;
@@ -557,7 +557,7 @@ export class CanvasToolModule extends CanvasModuleBase {
}
bufferState.points.push(alignedPoint.x, alignedPoint.y);
await selectedEntity.adapter.renderer.setBuffer(bufferState);
await selectedEntity.renderer.setBuffer(bufferState);
} else if (tool === 'eraser' && bufferState.type === 'eraser_line') {
const lastPoint = getLastPointOfLine(bufferState.points);
const minDistance = settings.eraserWidth * this.config.BRUSH_SPACING_TARGET_SCALE;
@@ -574,15 +574,15 @@ export class CanvasToolModule extends CanvasModuleBase {
}
bufferState.points.push(alignedPoint.x, alignedPoint.y);
await selectedEntity.adapter.renderer.setBuffer(bufferState);
await selectedEntity.renderer.setBuffer(bufferState);
} else if (tool === 'rect' && bufferState.type === 'rect') {
const normalizedPoint = offsetCoord(cursorPos, selectedEntity.state.position);
const alignedPoint = floorCoord(normalizedPoint);
bufferState.rect.width = Math.round(alignedPoint.x - bufferState.rect.x);
bufferState.rect.height = Math.round(alignedPoint.y - bufferState.rect.y);
await selectedEntity.adapter.renderer.setBuffer(bufferState);
await selectedEntity.renderer.setBuffer(bufferState);
} else {
selectedEntity?.adapter.renderer.clearBuffer();
selectedEntity?.renderer.clearBuffer();
}
} finally {
this.render();
@@ -595,10 +595,10 @@ export class CanvasToolModule extends CanvasModuleBase {
}
this.$cursorPos.set(null);
const selectedEntity = this.manager.stateApi.getSelectedEntity();
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
if (selectedEntity && selectedEntity.adapter.renderer.bufferState?.type !== 'rect') {
selectedEntity.adapter.renderer.commitBuffer();
if (selectedEntity && selectedEntity.renderer.bufferState?.type !== 'rect') {
selectedEntity.renderer.commitBuffer();
}
this.render();
@@ -636,10 +636,10 @@ export class CanvasToolModule extends CanvasModuleBase {
onWindowPointerUp = () => {
this.$isMouseDown.set(false);
const selectedEntity = this.manager.stateApi.getSelectedEntity();
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
if (selectedEntity && selectedEntity.adapter.renderer.hasBuffer()) {
selectedEntity.adapter.renderer.commitBuffer();
if (selectedEntity && selectedEntity.renderer.hasBuffer()) {
selectedEntity.renderer.commitBuffer();
}
};
@@ -654,9 +654,9 @@ export class CanvasToolModule extends CanvasModuleBase {
if (e.key === 'Escape') {
// Cancel shape drawing on escape
const selectedEntity = this.manager.stateApi.getSelectedEntity();
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
if (selectedEntity) {
selectedEntity.adapter.renderer.clearBuffer();
selectedEntity.renderer.clearBuffer();
}
return;
}