feat(ui): split out ref images into own slice (WIP)

This commit is contained in:
psychedelicious
2025-06-12 17:18:06 +10:00
parent a5e5cbd7c3
commit aa93e95a94
62 changed files with 871 additions and 699 deletions

View File

@@ -1,4 +1,5 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { getImageUsage } from 'features/deleteImageModal/store/state';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
@@ -20,9 +21,10 @@ export const addDeleteBoardAndImagesFulfilledListener = (startAppListening: AppS
const nodes = selectNodesSlice(state);
const canvas = selectCanvasSlice(state);
const upscale = selectUpscaleSlice(state);
const refImages = selectRefImagesSlice(state);
deleted_images.forEach((image_name) => {
const imageUsage = getImageUsage(nodes, canvas, upscale, image_name);
const imageUsage = getImageUsage(nodes, canvas, upscale, refImages, image_name);
if (imageUsage.isNodesImage && !wasNodeEditorReset) {
dispatch(nodeEditorReset());

View File

@@ -1,11 +1,7 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch, RootState } from 'app/store/store';
import {
controlLayerModelChanged,
referenceImageIPAdapterModelChanged,
rgIPAdapterModelChanged,
} from 'features/controlLayers/store/canvasSlice';
import { controlLayerModelChanged, rgIPAdapterModelChanged } from 'features/controlLayers/store/canvasSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import {
clipEmbedModelSelected,
@@ -15,6 +11,7 @@ import {
t5EncoderModelSelected,
vaeSelected,
} from 'features/controlLayers/store/paramsSlice';
import { referenceImageIPAdapterModelChanged, selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { modelSelected } from 'features/parameters/store/actions';
@@ -210,7 +207,7 @@ const handleControlAdapterModels: ModelHandler = (models, state, dispatch, log)
const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
const ipaModels = models.filter(isIPAdapterModelConfig);
selectCanvasSlice(state).referenceImages.entities.forEach((entity) => {
selectRefImagesSlice(state).entities.forEach((entity) => {
if (entity.ipAdapter.type !== 'ip_adapter') {
return;
}
@@ -225,7 +222,7 @@ const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
return;
}
log.debug({ selectedIPAdapterModel }, 'Selected IP adapter model is not available, clearing');
dispatch(referenceImageIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), modelConfig: null }));
dispatch(referenceImageIPAdapterModelChanged({ id: entity.id, modelConfig: null }));
});
selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
@@ -254,7 +251,7 @@ const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
const handleFLUXReduxModels: ModelHandler = (models, state, dispatch, log) => {
const fluxReduxModels = models.filter(isFluxReduxModelConfig);
selectCanvasSlice(state).referenceImages.entities.forEach((entity) => {
selectRefImagesSlice(state).entities.forEach((entity) => {
if (entity.ipAdapter.type !== 'flux_redux') {
return;
}
@@ -268,7 +265,7 @@ const handleFLUXReduxModels: ModelHandler = (models, state, dispatch, log) => {
return;
}
log.debug({ selectedFLUXReduxModel }, 'Selected FLUX Redux model is not available, clearing');
dispatch(referenceImageIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), modelConfig: null }));
dispatch(referenceImageIPAdapterModelChanged({ id: entity.id, modelConfig: null }));
});
selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {

View File

@@ -13,6 +13,7 @@ import {
} from 'features/controlLayers/store/canvasStagingAreaSlice';
import { lorasPersistConfig, lorasSlice } from 'features/controlLayers/store/lorasSlice';
import { paramsPersistConfig, paramsSlice } from 'features/controlLayers/store/paramsSlice';
import { refImagesPersistConfig, refImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { dynamicPromptsPersistConfig, dynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/gallerySlice';
import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
@@ -66,6 +67,7 @@ const allReducers = {
[canvasSessionSlice.name]: canvasSessionSlice.reducer,
[lorasSlice.name]: lorasSlice.reducer,
[workflowLibrarySlice.name]: workflowLibrarySlice.reducer,
[refImagesSlice.name]: refImagesSlice.reducer,
};
const rootReducer = combineReducers(allReducers);
@@ -111,6 +113,7 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
[canvasStagingAreaPersistConfig.name]: canvasStagingAreaPersistConfig,
[lorasPersistConfig.name]: lorasPersistConfig,
[workflowLibraryPersistConfig.name]: workflowLibraryPersistConfig,
[refImagesSlice.name]: refImagesPersistConfig,
};
const unserialize: UnserializeFunction = (data, key) => {

View File

@@ -0,0 +1,159 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
/**
* Adapted from https://github.com/chakra-ui/chakra-ui/blob/v2/packages/hooks/src/use-outside-click.ts
*
* The main change here is to support filtering of outside clicks via a `filter` function.
*
* This lets us work around issues with portals and components like popovers, which typically close on an outside click.
*
* For example, consider a popover that has a custom drop-down component inside it, which uses a portal to render
* the drop-down options. The original outside click handler would close the popover when clicking on the drop-down options,
* because the click is outside the popover - but we expect the popover to stay open in this case.
*
* A filter function like this can fix that:
*
* ```ts
* const filter = (el: HTMLElement) => el.className.includes('chakra-portal') || el.id.includes('react-select')
* ```
*
* This ignores clicks on react-select-based drop-downs and Chakra UI portals and is used as the default filter.
*/
import { useCallback, useEffect, useRef } from 'react';
export function useCallbackRef<T extends (...args: any[]) => any>(
callback: T | undefined,
deps: React.DependencyList = []
) {
const callbackRef = useRef(callback);
useEffect(() => {
callbackRef.current = callback;
});
// eslint-disable-next-line react-hooks/exhaustive-deps
return useCallback(((...args) => callbackRef.current?.(...args)) as T, deps);
}
export interface UseOutsideClickProps {
/**
* Whether the hook is enabled
*/
enabled?: boolean;
/**
* The reference to a DOM element.
*/
ref: React.RefObject<HTMLElement | null>;
/**
* Function invoked when a click is triggered outside the referenced element.
*/
handler?: (e: Event) => void;
/**
* A function that filters the elements that should be considered as outside clicks.
*
* If omitted, a default filter function that ignores clicks in Chakra UI portals and react-select components is used.
*/
filter?: (el: HTMLElement) => boolean;
}
const DEFAULT_FILTER = (el: HTMLElement) => el.className.includes('chakra-portal') || el.id.includes('react-select');
/**
* Example, used in components like Dialogs and Popovers, so they can close
* when a user clicks outside them.
*/
export function useFilterableOutsideClick(props: UseOutsideClickProps) {
const { ref, handler, enabled = true, filter = DEFAULT_FILTER } = props;
const savedHandler = useCallbackRef(handler);
const stateRef = useRef({
isPointerDown: false,
ignoreEmulatedMouseEvents: false,
});
const state = stateRef.current;
useEffect(() => {
if (!enabled) {
return;
}
const onPointerDown: any = (e: PointerEvent) => {
if (isValidEvent(e, ref, filter)) {
state.isPointerDown = true;
}
};
const onMouseUp: any = (event: MouseEvent) => {
if (state.ignoreEmulatedMouseEvents) {
state.ignoreEmulatedMouseEvents = false;
return;
}
if (state.isPointerDown && handler && isValidEvent(event, ref)) {
state.isPointerDown = false;
savedHandler(event);
}
};
const onTouchEnd = (event: TouchEvent) => {
state.ignoreEmulatedMouseEvents = true;
if (handler && state.isPointerDown && isValidEvent(event, ref)) {
state.isPointerDown = false;
savedHandler(event);
}
};
const doc = getOwnerDocument(ref.current);
doc.addEventListener('mousedown', onPointerDown, true);
doc.addEventListener('mouseup', onMouseUp, true);
doc.addEventListener('touchstart', onPointerDown, true);
doc.addEventListener('touchend', onTouchEnd, true);
return () => {
doc.removeEventListener('mousedown', onPointerDown, true);
doc.removeEventListener('mouseup', onMouseUp, true);
doc.removeEventListener('touchstart', onPointerDown, true);
doc.removeEventListener('touchend', onTouchEnd, true);
};
}, [handler, ref, savedHandler, state, enabled, filter]);
}
function isValidEvent(
event: Event,
ref: React.RefObject<HTMLElement | null>,
filter?: (el: HTMLElement) => boolean
): boolean {
const target = (event.composedPath?.()[0] ?? event.target) as HTMLElement;
if (target) {
const doc = getOwnerDocument(target);
if (!doc.contains(target)) {
return false;
}
}
if (ref.current?.contains(target)) {
return false;
}
if (filter) {
// Check if the click is inside an element matching the filter.
// This is used for portal-awareness or other general exclusion cases.
let currentElement: HTMLElement | null = target;
// Traverse up the DOM tree from the target element.
while (currentElement && currentElement !== document.body) {
if (filter(currentElement)) {
return false;
}
currentElement = currentElement.parentElement;
}
}
// If the click is not inside the ref and not inside a portal, it's a valid outside click.
return true;
}
function getOwnerDocument(node?: Element | null): Document {
return node?.ownerDocument ?? document;
}

View File

@@ -2,7 +2,6 @@ import { Button, Flex, Heading } from '@invoke-ai/ui-library';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import {
useAddControlLayer,
useAddGlobalReferenceImage,
useAddInpaintMask,
useAddRasterLayer,
useAddRegionalGuidance,
@@ -19,9 +18,7 @@ export const CanvasAddEntityButtons = memo(() => {
const addRegionalGuidance = useAddRegionalGuidance();
const addRasterLayer = useAddRasterLayer();
const addControlLayer = useAddControlLayer();
const addGlobalReferenceImage = useAddGlobalReferenceImage();
const addRegionalReferenceImage = useAddRegionalReferenceImage();
const isReferenceImageEnabled = useIsEntityTypeEnabled('reference_image');
const isRegionalGuidanceEnabled = useIsEntityTypeEnabled('regional_guidance');
const isControlLayerEnabled = useIsEntityTypeEnabled('control_layer');
const isInpaintLayerEnabled = useIsEntityTypeEnabled('inpaint_mask');
@@ -29,21 +26,6 @@ export const CanvasAddEntityButtons = memo(() => {
return (
<Flex w="full" h="full" justifyContent="center" gap={4}>
<Flex position="relative" flexDir="column" gap={4} top="20%">
<Flex flexDir="column" justifyContent="flex-start" gap={2}>
<Heading size="xs">{t('controlLayers.global')}</Heading>
<InformationalPopover feature="globalReferenceImage">
<Button
size="sm"
variant="ghost"
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addGlobalReferenceImage}
isDisabled={!isReferenceImageEnabled}
>
{t('controlLayers.globalReferenceImage')}
</Button>
</InformationalPopover>
</Flex>
<Flex flexDir="column" gap={2}>
<Heading size="xs">{t('controlLayers.regional')}</Heading>
<InformationalPopover feature="inpainting">

View File

@@ -12,9 +12,6 @@ const addControlLayerFromImageDndTargetData = newCanvasEntityFromImageDndTarget.
const addRegionalGuidanceReferenceImageFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({
type: 'regional_guidance_with_reference_image',
});
const addGlobalReferenceImageFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({
type: 'reference_image',
});
export const CanvasDropArea = memo(() => {
const { t } = useTranslation();
@@ -57,14 +54,6 @@ export const CanvasDropArea = memo(() => {
isDisabled={isBusy}
/>
</GridItem>
<GridItem position="relative">
<DndDropTarget
dndTarget={newCanvasEntityFromImageDndTarget}
dndTargetData={addGlobalReferenceImageFromImageDndTargetData}
label={t('controlLayers.canvasContextMenu.newGlobalReferenceImage')}
isDisabled={isBusy}
/>
</GridItem>
</Grid>
</>
);

View File

@@ -14,7 +14,6 @@ import { useEntityTypeInformationalPopover } from 'features/controlLayers/hooks/
import { useEntityTypeTitle } from 'features/controlLayers/hooks/useEntityTypeTitle';
import { entitiesReordered } from 'features/controlLayers/store/canvasSlice';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { isRenderableEntityType } from 'features/controlLayers/store/types';
import { singleCanvasEntityDndSource } from 'features/dnd/dnd';
import { triggerPostMoveFlash } from 'features/dnd/util';
import type { PropsWithChildren } from 'react';
@@ -165,8 +164,8 @@ export const CanvasEntityGroupList = memo(({ isSelected, type, children, entityI
<Spacer />
</Flex>
{isRenderableEntityType(type) && <CanvasEntityMergeVisibleButton type={type} />}
{isRenderableEntityType(type) && <CanvasEntityTypeIsHiddenToggle type={type} />}
<CanvasEntityMergeVisibleButton type={type} />
<CanvasEntityTypeIsHiddenToggle type={type} />
<CanvasEntityAddOfTypeButton type={type} />
</Flex>
<Collapse in={collapse.isTrue} style={fixTooltipCloseOnScrollStyles}>

View File

@@ -2,7 +2,6 @@ import { Flex } from '@invoke-ai/ui-library';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { ControlLayerEntityList } from 'features/controlLayers/components/ControlLayer/ControlLayerEntityList';
import { InpaintMaskList } from 'features/controlLayers/components/InpaintMask/InpaintMaskList';
import { IPAdapterList } from 'features/controlLayers/components/IPAdapter/IPAdapterList';
import { RasterLayerEntityList } from 'features/controlLayers/components/RasterLayer/RasterLayerEntityList';
import { RegionalGuidanceEntityList } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceEntityList';
import { memo } from 'react';
@@ -11,7 +10,6 @@ export const CanvasEntityList = memo(() => {
return (
<ScrollableContent>
<Flex flexDir="column" gap={2} data-testid="control-layers-layer-list" w="full" h="full">
<IPAdapterList />
<InpaintMaskList />
<RegionalGuidanceEntityList />
<ControlLayerEntityList />

View File

@@ -1,7 +1,6 @@
import { IconButton, Menu, MenuButton, MenuGroup, MenuItem, MenuList } from '@invoke-ai/ui-library';
import {
useAddControlLayer,
useAddGlobalReferenceImage,
useAddInpaintMask,
useAddRasterLayer,
useAddRegionalGuidance,
@@ -16,13 +15,11 @@ import { PiPlusBold } from 'react-icons/pi';
export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
const { t } = useTranslation();
const isBusy = useCanvasIsBusy();
const addGlobalReferenceImage = useAddGlobalReferenceImage();
const addInpaintMask = useAddInpaintMask();
const addRegionalGuidance = useAddRegionalGuidance();
const addRegionalReferenceImage = useAddRegionalReferenceImage();
const addRasterLayer = useAddRasterLayer();
const addControlLayer = useAddControlLayer();
const isReferenceImageEnabled = useIsEntityTypeEnabled('reference_image');
const isRegionalGuidanceEnabled = useIsEntityTypeEnabled('regional_guidance');
const isControlLayerEnabled = useIsEntityTypeEnabled('control_layer');
const isInpaintLayerEnabled = useIsEntityTypeEnabled('inpaint_mask');
@@ -41,11 +38,6 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
isDisabled={isBusy}
/>
<MenuList>
<MenuGroup title={t('controlLayers.global')}>
<MenuItem icon={<PiPlusBold />} onClick={addGlobalReferenceImage} isDisabled={!isReferenceImageEnabled}>
{t('controlLayers.globalReferenceImage')}
</MenuItem>
</MenuGroup>
<MenuGroup title={t('controlLayers.regional')}>
<MenuItem icon={<PiPlusBold />} onClick={addInpaintMask} isDisabled={!isInpaintLayerEnabled}>
{t('controlLayers.inpaintMask')}

View File

@@ -22,7 +22,6 @@ import {
selectEntity,
selectSelectedEntityIdentifier,
} from 'features/controlLayers/store/selectors';
import { isRenderableEntity } from 'features/controlLayers/store/types';
import { clamp, round } from 'lodash-es';
import type { KeyboardEvent } from 'react';
import { memo, useCallback, useEffect, useState } from 'react';
@@ -70,9 +69,6 @@ const selectOpacity = createSelector(selectCanvasSlice, (canvas) => {
if (!selectedEntity) {
return 1; // fallback to 100% opacity
}
if (!isRenderableEntity(selectedEntity)) {
return 1; // fallback to 100% opacity
}
// Opacity is a float from 0-1, but we want to display it as a percentage
return selectedEntity.opacity;
});
@@ -134,11 +130,7 @@ export const EntityListSelectedEntityActionBarOpacity = memo(() => {
return (
<Popover>
<FormControl
w="min-content"
gap={2}
isDisabled={selectedEntityIdentifier === null || selectedEntityIdentifier.type === 'reference_image'}
>
<FormControl w="min-content" gap={2} isDisabled={selectedEntityIdentifier === null}>
<FormLabel m={0}>{t('controlLayers.opacity')}</FormLabel>
<PopoverAnchor>
<NumberInput
@@ -167,7 +159,7 @@ export const EntityListSelectedEntityActionBarOpacity = memo(() => {
position="absolute"
insetInlineEnd={0}
h="full"
isDisabled={selectedEntityIdentifier === null || selectedEntityIdentifier.type === 'reference_image'}
isDisabled={selectedEntityIdentifier === null}
/>
</PopoverTrigger>
</NumberInput>
@@ -185,7 +177,7 @@ export const EntityListSelectedEntityActionBarOpacity = memo(() => {
marks={marks}
formatValue={formatSliderValue}
alwaysShowMarks
isDisabled={selectedEntityIdentifier === null || selectedEntityIdentifier.type === 'reference_image'}
isDisabled={selectedEntityIdentifier === null}
/>
</PopoverBody>
</PopoverContent>

View File

@@ -4,31 +4,25 @@ import { CanvasEntityHeader } from 'features/controlLayers/components/common/Can
import { CanvasEntityHeaderCommonActions } from 'features/controlLayers/components/common/CanvasEntityHeaderCommonActions';
import { CanvasEntityEditableTitle } from 'features/controlLayers/components/common/CanvasEntityTitleEdit';
import { IPAdapterSettings } from 'features/controlLayers/components/IPAdapter/IPAdapterSettings';
import { CanvasEntityStateGate } from 'features/controlLayers/contexts/CanvasEntityStateGate';
import { EntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { memo, useMemo } from 'react';
import { RefImageIdContext } from 'features/controlLayers/contexts/RefImageIdContext';
import { memo } from 'react';
type Props = {
id: string;
};
export const IPAdapter = memo(({ id }: Props) => {
const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id, type: 'reference_image' }), [id]);
return (
<EntityIdentifierContext.Provider value={entityIdentifier}>
<CanvasEntityStateGate entityIdentifier={entityIdentifier}>
<CanvasEntityContainer>
<CanvasEntityHeader ps={4} py={5}>
<CanvasEntityEditableTitle />
<Spacer />
<CanvasEntityHeaderCommonActions />
</CanvasEntityHeader>
<IPAdapterSettings />
</CanvasEntityContainer>
</CanvasEntityStateGate>
</EntityIdentifierContext.Provider>
<RefImageIdContext.Provider value={id}>
<CanvasEntityContainer>
<CanvasEntityHeader ps={4} py={5}>
<CanvasEntityEditableTitle />
<Spacer />
<CanvasEntityHeaderCommonActions />
</CanvasEntityHeader>
<IPAdapterSettings />
</CanvasEntityContainer>
</RefImageIdContext.Provider>
);
});

View File

@@ -1,37 +1,37 @@
/* eslint-disable i18next/no-literal-string */
import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import type { FlexProps, SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { CanvasEntityGroupList } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityGroupList';
import { IPAdapter } from 'features/controlLayers/components/IPAdapter/IPAdapter';
import { selectCanvasSlice, selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { RefImagePreview } from 'features/controlLayers/components/IPAdapter/IPAdapterPreview';
import { RefImageIdContext } from 'features/controlLayers/contexts/RefImageIdContext';
import { selectRefImageEntityIds } from 'features/controlLayers/store/refImagesSlice';
import { memo } from 'react';
const selectEntityIdentifiers = createMemoizedSelector(selectCanvasSlice, (canvas) => {
return canvas.referenceImages.entities.map(getEntityIdentifier).toReversed();
});
const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selectedEntityIdentifier) => {
return selectedEntityIdentifier?.type === 'reference_image';
});
const sx: SystemStyleObject = {
opacity: 0.3,
_hover: {
opacity: 1,
},
transitionProperty: 'opacity',
transitionDuration: '0.2s',
};
export const IPAdapterList = memo(() => {
const isSelected = useAppSelector(selectIsSelected);
const entityIdentifiers = useAppSelector(selectEntityIdentifiers);
export const RefImageList = memo((props: FlexProps) => {
const ids = useAppSelector(selectRefImageEntityIds);
if (entityIdentifiers.length === 0) {
if (ids.length === 0) {
return null;
}
if (entityIdentifiers.length > 0) {
return (
<CanvasEntityGroupList type="reference_image" isSelected={isSelected} entityIdentifiers={entityIdentifiers}>
{entityIdentifiers.map((entityIdentifiers) => (
<IPAdapter key={entityIdentifiers.id} id={entityIdentifiers.id} />
))}
</CanvasEntityGroupList>
);
}
return (
<Flex gap={2} {...props}>
{ids.map((id) => (
<RefImageIdContext.Provider key={id} value={id}>
<RefImagePreview />
</RefImageIdContext.Provider>
))}
</Flex>
);
});
IPAdapterList.displayName = 'IPAdapterList';
RefImageList.displayName = 'RefImageList';

View File

@@ -1,5 +1,5 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useRefImageIdContext } from 'features/controlLayers/contexts/RefImageIdContext';
import { usePullBboxIntoGlobalReferenceImage } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { memo } from 'react';
@@ -8,8 +8,8 @@ import { PiBoundingBoxBold } from 'react-icons/pi';
export const IPAdapterMenuItemPullBbox = memo(() => {
const { t } = useTranslation();
const entityIdentifier = useEntityIdentifierContext('reference_image');
const pullBboxIntoIPAdapter = usePullBboxIntoGlobalReferenceImage(entityIdentifier);
const id = useRefImageIdContext();
const pullBboxIntoIPAdapter = usePullBboxIntoGlobalReferenceImage(id);
const isBusy = useCanvasIsBusy();
return (

View File

@@ -0,0 +1,75 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import {
Flex,
Image,
Popover,
PopoverAnchor,
PopoverArrow,
PopoverBody,
PopoverContent,
Portal,
} from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { useDisclosure } from 'common/hooks/useBoolean';
import { useFilterableOutsideClick } from 'common/hooks/useFilterableOutsideClick';
import { IPAdapterSettings } from 'features/controlLayers/components/IPAdapter/IPAdapterSettings';
import { useRefImageIdContext } from 'features/controlLayers/contexts/RefImageIdContext';
import { selectRefImageEntityOrThrow, selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import type { ImageWithDims } from 'features/controlLayers/store/types';
import { memo, useMemo, useRef } from 'react';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
const sx: SystemStyleObject = {
opacity: 0.5,
_hover: {
opacity: 1,
},
"&[data-is-open='true']": {
opacity: 1,
pointerEvents: 'none',
},
transitionProperty: 'opacity',
transitionDuration: '0.2s',
};
export const RefImagePreview = memo(() => {
const id = useRefImageIdContext();
const ref = useRef<HTMLDivElement>(null);
const disclosure = useDisclosure(false);
const selectEntity = useMemo(
() =>
createSelector(selectRefImagesSlice, (refImages) =>
selectRefImageEntityOrThrow(refImages, id, 'RefImagePreview')
),
[id]
);
const entity = useAppSelector(selectEntity);
useFilterableOutsideClick({ ref, handler: disclosure.close });
return (
<Popover isLazy lazyBehavior="unmount" isOpen={disclosure.isOpen} closeOnBlur={false}>
<PopoverAnchor>
<Flex role="button" w={16} h={16} sx={sx} onClick={disclosure.open} data-is-open={disclosure.isOpen}>
<Thumbnail image={entity.ipAdapter.image} />
</Flex>
</PopoverAnchor>
<Portal>
<PopoverContent ref={ref}>
<PopoverArrow />
<PopoverBody>
<IPAdapterSettings />
</PopoverBody>
</PopoverContent>
</Portal>
</Popover>
);
});
RefImagePreview.displayName = 'RefImagePreview';
const Thumbnail = memo(({ image }: { image: ImageWithDims | null }) => {
const { data: imageDTO } = useGetImageDTOQuery(image?.image_name ?? skipToken);
return <Image src={imageDTO?.thumbnail_url} objectFit="contain" maxW="full" maxH="full" />;
});
Thumbnail.displayName = 'Thumbnail';

View File

@@ -1,17 +1,15 @@
import { Flex, IconButton } from '@invoke-ai/ui-library';
import { Flex } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginEndStepPct';
import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/common/CanvasEntitySettingsWrapper';
import { CLIPVisionModel } from 'features/controlLayers/components/common/CLIPVisionModel';
import { FLUXReduxImageInfluence } from 'features/controlLayers/components/common/FLUXReduxImageInfluence';
import { Weight } from 'features/controlLayers/components/common/Weight';
import { CLIPVisionModel } from 'features/controlLayers/components/IPAdapter/CLIPVisionModel';
import { FLUXReduxImageInfluence } from 'features/controlLayers/components/IPAdapter/FLUXReduxImageInfluence';
import { GlobalReferenceImageModel } from 'features/controlLayers/components/IPAdapter/GlobalReferenceImageModel';
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
import { IPAdapterSettingsEmptyState } from 'features/controlLayers/components/IPAdapter/IPAdapterSettingsEmptyState';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { usePullBboxIntoGlobalReferenceImage } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { useRefImageIdContext } from 'features/controlLayers/contexts/RefImageIdContext';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import {
referenceImageIPAdapterBeginEndStepPctChanged,
referenceImageIPAdapterCLIPVisionModelChanged,
@@ -20,11 +18,11 @@ import {
referenceImageIPAdapterMethodChanged,
referenceImageIPAdapterModelChanged,
referenceImageIPAdapterWeightChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntity, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
selectRefImageEntity,
selectRefImageEntityOrThrow,
selectRefImagesSlice,
} from 'features/controlLayers/store/refImagesSlice';
import type {
CanvasEntityIdentifier,
CLIPVisionModelV2,
FLUXReduxImageInfluence as FLUXReduxImageInfluenceType,
IPMethodV2,
@@ -33,141 +31,138 @@ import type { SetGlobalReferenceImageDndTargetData } from 'features/dnd/dnd';
import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold } from 'react-icons/pi';
import type { ApiModelConfig, FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import { IPAdapterImagePreview } from './IPAdapterImagePreview';
const buildSelectIPAdapter = (entityIdentifier: CanvasEntityIdentifier<'reference_image'>) =>
const buildSelectIPAdapter = (id: string) =>
createSelector(
selectCanvasSlice,
(canvas) => selectEntityOrThrow(canvas, entityIdentifier, 'IPAdapterSettings').ipAdapter
selectRefImagesSlice,
(refImages) => selectRefImageEntityOrThrow(refImages, id, 'IPAdapterSettings').ipAdapter
);
const IPAdapterSettingsContent = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext('reference_image');
const selectIPAdapter = useMemo(() => buildSelectIPAdapter(entityIdentifier), [entityIdentifier]);
const id = useRefImageIdContext();
const selectIPAdapter = useMemo(() => buildSelectIPAdapter(id), [id]);
const ipAdapter = useAppSelector(selectIPAdapter);
const onChangeBeginEndStepPct = useCallback(
(beginEndStepPct: [number, number]) => {
dispatch(referenceImageIPAdapterBeginEndStepPctChanged({ entityIdentifier, beginEndStepPct }));
dispatch(referenceImageIPAdapterBeginEndStepPctChanged({ id, beginEndStepPct }));
},
[dispatch, entityIdentifier]
[dispatch, id]
);
const onChangeWeight = useCallback(
(weight: number) => {
dispatch(referenceImageIPAdapterWeightChanged({ entityIdentifier, weight }));
dispatch(referenceImageIPAdapterWeightChanged({ id, weight }));
},
[dispatch, entityIdentifier]
[dispatch, id]
);
const onChangeIPMethod = useCallback(
(method: IPMethodV2) => {
dispatch(referenceImageIPAdapterMethodChanged({ entityIdentifier, method }));
dispatch(referenceImageIPAdapterMethodChanged({ id, method }));
},
[dispatch, entityIdentifier]
[dispatch, id]
);
const onChangeFLUXReduxImageInfluence = useCallback(
(imageInfluence: FLUXReduxImageInfluenceType) => {
dispatch(referenceImageIPAdapterFLUXReduxImageInfluenceChanged({ entityIdentifier, imageInfluence }));
dispatch(referenceImageIPAdapterFLUXReduxImageInfluenceChanged({ id, imageInfluence }));
},
[dispatch, entityIdentifier]
[dispatch, id]
);
const onChangeModel = useCallback(
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig) => {
dispatch(referenceImageIPAdapterModelChanged({ entityIdentifier, modelConfig }));
dispatch(referenceImageIPAdapterModelChanged({ id, modelConfig }));
},
[dispatch, entityIdentifier]
[dispatch, id]
);
const onChangeCLIPVisionModel = useCallback(
(clipVisionModel: CLIPVisionModelV2) => {
dispatch(referenceImageIPAdapterCLIPVisionModelChanged({ entityIdentifier, clipVisionModel }));
dispatch(referenceImageIPAdapterCLIPVisionModelChanged({ id, clipVisionModel }));
},
[dispatch, entityIdentifier]
[dispatch, id]
);
const onChangeImage = useCallback(
(imageDTO: ImageDTO | null) => {
dispatch(referenceImageIPAdapterImageChanged({ entityIdentifier, imageDTO }));
dispatch(referenceImageIPAdapterImageChanged({ id, imageDTO }));
},
[dispatch, entityIdentifier]
[dispatch, id]
);
const dndTargetData = useMemo<SetGlobalReferenceImageDndTargetData>(
() => setGlobalReferenceImageDndTarget.getData({ entityIdentifier }, ipAdapter.image?.image_name),
[entityIdentifier, ipAdapter.image?.image_name]
() => setGlobalReferenceImageDndTarget.getData({ id }, ipAdapter.image?.image_name),
[id, ipAdapter.image?.image_name]
);
const pullBboxIntoIPAdapter = usePullBboxIntoGlobalReferenceImage(entityIdentifier);
const isBusy = useCanvasIsBusy();
// const pullBboxIntoIPAdapter = usePullBboxIntoGlobalReferenceImage(id);
// const isBusy = useCanvasIsBusy();
const isFLUX = useAppSelector(selectIsFLUX);
return (
<CanvasEntitySettingsWrapper>
<Flex flexDir="column" gap={2} position="relative" w="full">
<Flex gap={2} alignItems="center" w="full">
<GlobalReferenceImageModel modelKey={ipAdapter.model?.key ?? null} onChangeModel={onChangeModel} />
{ipAdapter.type === 'ip_adapter' && (
<CLIPVisionModel model={ipAdapter.clipVisionModel} onChange={onChangeCLIPVisionModel} />
)}
<IconButton
<Flex flexDir="column" gap={2} position="relative" w="full">
<Flex gap={2} alignItems="center" w="full">
<GlobalReferenceImageModel modelKey={ipAdapter.model?.key ?? null} onChangeModel={onChangeModel} />
{ipAdapter.type === 'ip_adapter' && (
<CLIPVisionModel model={ipAdapter.clipVisionModel} onChange={onChangeCLIPVisionModel} />
)}
{/* <IconButton
onClick={pullBboxIntoIPAdapter}
isDisabled={isBusy}
variant="ghost"
aria-label={t('controlLayers.pullBboxIntoReferenceImage')}
tooltip={t('controlLayers.pullBboxIntoReferenceImage')}
icon={<PiBoundingBoxBold />}
/>
</Flex>
<Flex gap={2} w="full">
{ipAdapter.type === 'ip_adapter' && (
<Flex flexDir="column" gap={2} w="full">
{!isFLUX && <IPAdapterMethod method={ipAdapter.method} onChange={onChangeIPMethod} />}
<Weight weight={ipAdapter.weight} onChange={onChangeWeight} />
<BeginEndStepPct beginEndStepPct={ipAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
</Flex>
)}
{ipAdapter.type === 'flux_redux' && (
<Flex flexDir="column" gap={2} w="full" alignItems="flex-start">
<FLUXReduxImageInfluence
imageInfluence={ipAdapter.imageInfluence ?? 'lowest'}
onChange={onChangeFLUXReduxImageInfluence}
/>
</Flex>
)}
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1" flexGrow={1}>
<IPAdapterImagePreview
image={ipAdapter.image}
onChangeImage={onChangeImage}
dndTarget={setGlobalReferenceImageDndTarget}
dndTargetData={dndTargetData}
/> */}
</Flex>
<Flex gap={2} w="full">
{ipAdapter.type === 'ip_adapter' && (
<Flex flexDir="column" gap={2} w="full">
{!isFLUX && <IPAdapterMethod method={ipAdapter.method} onChange={onChangeIPMethod} />}
<Weight weight={ipAdapter.weight} onChange={onChangeWeight} />
<BeginEndStepPct beginEndStepPct={ipAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
</Flex>
)}
{ipAdapter.type === 'flux_redux' && (
<Flex flexDir="column" gap={2} w="full" alignItems="flex-start">
<FLUXReduxImageInfluence
imageInfluence={ipAdapter.imageInfluence ?? 'lowest'}
onChange={onChangeFLUXReduxImageInfluence}
/>
</Flex>
)}
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1" flexGrow={1}>
<IPAdapterImagePreview
image={ipAdapter.image}
onChangeImage={onChangeImage}
dndTarget={setGlobalReferenceImageDndTarget}
dndTargetData={dndTargetData}
/>
</Flex>
</Flex>
</CanvasEntitySettingsWrapper>
</Flex>
);
});
IPAdapterSettingsContent.displayName = 'IPAdapterSettingsContent';
const buildSelectIPAdapterHasImage = (entityIdentifier: CanvasEntityIdentifier<'reference_image'>) =>
createSelector(selectCanvasSlice, (canvas) => {
const referenceImage = selectEntity(canvas, entityIdentifier);
const buildSelectIPAdapterHasImage = (id: string) =>
createSelector(selectRefImagesSlice, (refImages) => {
const referenceImage = selectRefImageEntity(refImages, id);
return !!referenceImage && referenceImage.ipAdapter.image !== null;
});
export const IPAdapterSettings = memo(() => {
const entityIdentifier = useEntityIdentifierContext('reference_image');
const id = useRefImageIdContext();
const selectIPAdapterHasImage = useMemo(() => buildSelectIPAdapterHasImage(entityIdentifier), [entityIdentifier]);
const selectIPAdapterHasImage = useMemo(() => buildSelectIPAdapterHasImage(id), [id]);
const hasImage = useAppSelector(selectIPAdapterHasImage);
if (!hasImage) {

View File

@@ -1,7 +1,7 @@
import { Button, Flex, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useRefImageIdContext } from 'features/controlLayers/contexts/RefImageIdContext';
import { usePullBboxIntoGlobalReferenceImage } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import type { SetGlobalReferenceImageDndTargetData } from 'features/dnd/dnd';
@@ -15,24 +15,24 @@ import type { ImageDTO } from 'services/api/types';
export const IPAdapterSettingsEmptyState = memo(() => {
const { t } = useTranslation();
const entityIdentifier = useEntityIdentifierContext('reference_image');
const id = useRefImageIdContext();
const dispatch = useAppDispatch();
const isBusy = useCanvasIsBusy();
const onUpload = useCallback(
(imageDTO: ImageDTO) => {
setGlobalReferenceImage({ imageDTO, entityIdentifier, dispatch });
setGlobalReferenceImage({ imageDTO, id, dispatch });
},
[dispatch, entityIdentifier]
[dispatch, id]
);
const uploadApi = useImageUploadButton({ onUpload, allowMultiple: false });
const onClickGalleryButton = useCallback(() => {
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}, [dispatch]);
const pullBboxIntoIPAdapter = usePullBboxIntoGlobalReferenceImage(entityIdentifier);
const pullBboxIntoIPAdapter = usePullBboxIntoGlobalReferenceImage(id);
const dndTargetData = useMemo<SetGlobalReferenceImageDndTargetData>(
() => setGlobalReferenceImageDndTarget.getData({ entityIdentifier }),
[entityIdentifier]
() => setGlobalReferenceImageDndTarget.getData({ id }),
[id]
);
const components = useMemo(

View File

@@ -2,13 +2,13 @@ import { Flex, IconButton, Spacer, Text } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginEndStepPct';
import { CLIPVisionModel } from 'features/controlLayers/components/common/CLIPVisionModel';
import { FLUXReduxImageInfluence } from 'features/controlLayers/components/common/FLUXReduxImageInfluence';
import { Weight } from 'features/controlLayers/components/common/Weight';
import { CLIPVisionModel } from 'features/controlLayers/components/IPAdapter/CLIPVisionModel';
import { FLUXReduxImageInfluence } from 'features/controlLayers/components/IPAdapter/FLUXReduxImageInfluence';
import { IPAdapterImagePreview } from 'features/controlLayers/components/IPAdapter/IPAdapterImagePreview';
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
import { RegionalReferenceImageModel } from 'features/controlLayers/components/IPAdapter/RegionalReferenceImageModel';
import { RegionalGuidanceIPAdapterSettingsEmptyState } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceIPAdapterSettingsEmptyState';
import { RegionalReferenceImageModel } from 'features/controlLayers/components/RegionalGuidance/RegionalReferenceImageModel';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { usePullBboxIntoRegionalGuidanceReferenceImage } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';

View File

@@ -2,7 +2,6 @@ import { IconButton } from '@invoke-ai/ui-library';
import { NewLayerIcon } from 'features/controlLayers/components/common/icons';
import {
useAddControlLayer,
useAddGlobalReferenceImage,
useAddInpaintMask,
useAddRasterLayer,
useAddRegionalGuidance,
@@ -23,7 +22,6 @@ export const CanvasEntityAddOfTypeButton = memo(({ type }: Props) => {
const addRegionalGuidance = useAddRegionalGuidance();
const addRasterLayer = useAddRasterLayer();
const addControlLayer = useAddControlLayer();
const addGlobalReferenceImage = useAddGlobalReferenceImage();
const onClick = useCallback(() => {
switch (type) {
@@ -39,11 +37,8 @@ export const CanvasEntityAddOfTypeButton = memo(({ type }: Props) => {
case 'control_layer':
addControlLayer();
break;
case 'reference_image':
addGlobalReferenceImage();
break;
}
}, [addControlLayer, addGlobalReferenceImage, addInpaintMask, addRasterLayer, addRegionalGuidance, type]);
}, [addControlLayer, addInpaintMask, addRasterLayer, addRegionalGuidance, type]);
const label = useMemo(() => {
switch (type) {
@@ -55,8 +50,6 @@ export const CanvasEntityAddOfTypeButton = memo(({ type }: Props) => {
return t('controlLayers.addRasterLayer');
case 'control_layer':
return t('controlLayers.addControlLayer');
case 'reference_image':
return t('controlLayers.addGlobalReferenceImage');
}
}, [type, t]);

View File

@@ -4,17 +4,14 @@ import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/com
import { CanvasEntityHeaderWarnings } from 'features/controlLayers/components/common/CanvasEntityHeaderWarnings';
import { CanvasEntityIsBookmarkedForQuickSwitchToggle } from 'features/controlLayers/components/common/CanvasEntityIsBookmarkedForQuickSwitchToggle';
import { CanvasEntityIsLockedToggle } from 'features/controlLayers/components/common/CanvasEntityIsLockedToggle';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { memo } from 'react';
export const CanvasEntityHeaderCommonActions = memo(() => {
const entityIdentifier = useEntityIdentifierContext();
return (
<Flex alignSelf="stretch">
<CanvasEntityHeaderWarnings />
<CanvasEntityIsBookmarkedForQuickSwitchToggle />
{entityIdentifier.type !== 'reference_image' && <CanvasEntityIsLockedToggle />}
<CanvasEntityIsLockedToggle />
<CanvasEntityEnabledToggle />
<CanvasEntityDeleteButton />
</Flex>

View File

@@ -39,11 +39,6 @@ const getIndexAndCount = (
index: canvas.inpaintMasks.entities.findIndex((entity) => entity.id === id),
count: canvas.inpaintMasks.entities.length,
};
} else if (type === 'reference_image') {
return {
index: canvas.referenceImages.entities.findIndex((entity) => entity.id === id),
count: canvas.referenceImages.entities.length,
};
} else {
return {
index: -1,

View File

@@ -3,7 +3,7 @@ import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerP
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { useEntityIdentifierBelowThisOne } from 'features/controlLayers/hooks/useNextRenderableEntityIdentifier';
import type { CanvasRenderableEntityType } from 'features/controlLayers/store/types';
import type { CanvasEntityType } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiStackSimpleBold } from 'react-icons/pi';
@@ -12,7 +12,7 @@ export const CanvasEntityMenuItemsMergeDown = memo(() => {
const { t } = useTranslation();
const canvasManager = useCanvasManager();
const isBusy = useCanvasIsBusy();
const entityIdentifier = useEntityIdentifierContext<CanvasRenderableEntityType>();
const entityIdentifier = useEntityIdentifierContext<CanvasEntityType>();
const entityIdentifierBelowThisOne = useEntityIdentifierBelowThisOne(entityIdentifier);
const mergeDown = useCallback(() => {
if (entityIdentifierBelowThisOne === null) {

View File

@@ -2,13 +2,13 @@ import { IconButton } from '@invoke-ai/ui-library';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { useVisibleEntityCountByType } from 'features/controlLayers/hooks/useVisibleEntityCountByType';
import type { CanvasRenderableEntityType } from 'features/controlLayers/store/types';
import type { CanvasEntityType } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiStackBold } from 'react-icons/pi';
type Props = {
type: CanvasRenderableEntityType;
type: CanvasEntityType;
};
export const CanvasEntityMergeVisibleButton = memo(({ type }: Props) => {

View File

@@ -5,7 +5,7 @@ import type { CanvasEntityAdapterInpaintMask } from 'features/controlLayers/konv
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import type { CanvasEntityAdapterRegionalGuidance } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRegionalGuidance';
import type { CanvasEntityAdapterFromType } from 'features/controlLayers/konva/CanvasEntity/types';
import type { CanvasEntityIdentifier, CanvasRenderableEntityType } from 'features/controlLayers/store/types';
import type { CanvasEntityIdentifier, CanvasEntityType } from 'features/controlLayers/store/types';
import type { PropsWithChildren } from 'react';
import { createContext, memo, useContext, useMemo, useSyncExternalStore } from 'react';
import { assert } from 'tsafe';
@@ -96,15 +96,15 @@ export const RegionalGuidanceAdapterGate = memo(({ children }: PropsWithChildren
return <EntityAdapterContext.Provider value={adapter}>{children}</EntityAdapterContext.Provider>;
});
export const useEntityAdapterContext = <T extends CanvasRenderableEntityType | undefined = CanvasRenderableEntityType>(
export const useEntityAdapterContext = <T extends CanvasEntityType | undefined = CanvasEntityType>(
type?: T
): CanvasEntityAdapterFromType<T extends undefined ? CanvasRenderableEntityType : T> => {
): CanvasEntityAdapterFromType<T extends undefined ? CanvasEntityType : T> => {
const adapter = useContext(EntityAdapterContext);
assert(adapter, 'useEntityIdentifier must be used within a EntityIdentifierProvider');
if (type) {
assert(adapter.entityIdentifier.type === type, 'useEntityIdentifier must be used with the correct type');
}
return adapter as CanvasEntityAdapterFromType<T extends undefined ? CanvasRenderableEntityType : T>;
return adapter as CanvasEntityAdapterFromType<T extends undefined ? CanvasEntityType : T>;
};
RegionalGuidanceAdapterGate.displayName = 'RegionalGuidanceAdapterGate';

View File

@@ -0,0 +1,10 @@
import { createContext, useContext } from 'react';
import { assert } from 'tsafe';
export const RefImageIdContext = createContext<string | null>(null);
export const useRefImageIdContext = (): string => {
const id = useContext(RefImageIdContext);
assert(id, 'useRefImageIdContext must be used within a RefImageIdContext.Provider');
return id;
};

View File

@@ -9,13 +9,13 @@ import {
inpaintMaskDenoiseLimitAdded,
inpaintMaskNoiseAdded,
rasterLayerAdded,
referenceImageAdded,
rgAdded,
rgIPAdapterAdded,
rgNegativePromptChanged,
rgPositivePromptChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectBase, selectMainModelConfig } from 'features/controlLayers/store/paramsSlice';
import { referenceImageAdded } from 'features/controlLayers/store/refImagesSlice';
import { selectCanvasSlice, selectEntity } from 'features/controlLayers/store/selectors';
import type {
CanvasEntityIdentifier,

View File

@@ -9,8 +9,6 @@ import {
controlLayerAdded,
entityRasterized,
rasterLayerAdded,
referenceImageAdded,
referenceImageIPAdapterImageChanged,
rgAdded,
rgIPAdapterImageChanged,
} from 'features/controlLayers/store/canvasSlice';
@@ -20,6 +18,7 @@ import {
selectPositivePrompt,
selectSeed,
} from 'features/controlLayers/store/paramsSlice';
import { referenceImageAdded, referenceImageIPAdapterImageChanged } from 'features/controlLayers/store/refImagesSlice';
import { selectCanvasMetadata } from 'features/controlLayers/store/selectors';
import type {
CanvasControlLayerState,
@@ -306,13 +305,13 @@ export const usePullBboxIntoLayer = (entityIdentifier: CanvasEntityIdentifier<'c
return func;
};
export const usePullBboxIntoGlobalReferenceImage = (entityIdentifier: CanvasEntityIdentifier<'reference_image'>) => {
export const usePullBboxIntoGlobalReferenceImage = (id: string) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const arg = useMemo<UseSaveCanvasArg>(() => {
const onSave = (imageDTO: ImageDTO, _: Rect) => {
dispatch(referenceImageIPAdapterImageChanged({ entityIdentifier, imageDTO }));
dispatch(referenceImageIPAdapterImageChanged({ id, imageDTO }));
};
return {
@@ -322,7 +321,7 @@ export const usePullBboxIntoGlobalReferenceImage = (entityIdentifier: CanvasEnti
toastOk: t('controlLayers.pullBboxIntoReferenceImageOk'),
toastError: t('controlLayers.pullBboxIntoReferenceImageError'),
};
}, [dispatch, entityIdentifier, t]);
}, [dispatch, id, t]);
const func = useSaveCanvas(arg);
return func;

View File

@@ -32,8 +32,6 @@ export const useEntityTypeName = (type: CanvasEntityIdentifier['type']) => {
return t('controlLayers.controlLayer');
case 'raster_layer':
return t('controlLayers.rasterLayer');
case 'reference_image':
return t('controlLayers.globalReferenceImage');
case 'regional_guidance':
return t('controlLayers.regionalGuidance');
default:

View File

@@ -17,8 +17,6 @@ export const useEntityTypeCount = (type: CanvasEntityIdentifier['type']): number
return canvas.inpaintMasks.entities.length;
case 'regional_guidance':
return canvas.regionalGuidance.entities.length;
case 'reference_image':
return canvas.referenceImages.entities.length;
default:
return 0;
}

View File

@@ -13,9 +13,6 @@ export const useEntityTypeInformationalPopover = (type: CanvasEntityIdentifier['
return 'rasterLayer';
case 'regional_guidance':
return 'regionalGuidanceAndReferenceImage';
case 'reference_image':
return 'globalReferenceImage';
default:
return undefined;
}

View File

@@ -17,7 +17,6 @@ export const useEntityTypeIsHidden = (type: CanvasEntityIdentifier['type']): boo
return canvas.inpaintMasks.isHidden;
case 'regional_guidance':
return canvas.regionalGuidance.isHidden;
case 'reference_image':
default:
return false;
}

View File

@@ -15,10 +15,6 @@ export const useEntityTypeString = (type: CanvasEntityIdentifier['type'], plural
return plural ? t('controlLayers.inpaintMask_withCount_other') : t('controlLayers.inpaintMask');
case 'regional_guidance':
return plural ? t('controlLayers.regionalGuidance_withCount_other') : t('controlLayers.regionalGuidance');
case 'reference_image':
return plural
? t('controlLayers.globalReferenceImage_withCount_other')
: t('controlLayers.globalReferenceImage');
default:
return '';
}

View File

@@ -21,8 +21,6 @@ export const useEntityTypeTitle = (type: CanvasEntityIdentifier['type']): string
return t('controlLayers.inpaintMasks_withCount', { count, context });
case 'regional_guidance':
return t('controlLayers.regionalGuidance_withCount', { count, context });
case 'reference_image':
return t('controlLayers.globalReferenceImages_withCount', { count, context });
default:
return '';
}

View File

@@ -7,7 +7,6 @@ import {
selectIsImagen4,
selectIsSD3,
} from 'features/controlLayers/store/paramsSlice';
import { selectActiveReferenceImageEntities } from 'features/controlLayers/store/selectors';
import type { CanvasEntityType } from 'features/controlLayers/store/types';
import { useMemo } from 'react';
import type { Equals } from 'tsafe';
@@ -20,15 +19,9 @@ export const useIsEntityTypeEnabled = (entityType: CanvasEntityType) => {
const isImagen4 = useAppSelector(selectIsImagen4);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
const isFluxKontext = useAppSelector(selectIsFluxKontext);
const activeReferenceImageEntities = useAppSelector(selectActiveReferenceImageEntities);
const isEntityTypeEnabled = useMemo<boolean>(() => {
switch (entityType) {
case 'reference_image':
if (isFluxKontext) {
return activeReferenceImageEntities.length === 0;
}
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4;
case 'regional_guidance':
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4 && !isFluxKontext && !isChatGPT4o;
case 'control_layer':
@@ -40,7 +33,7 @@ export const useIsEntityTypeEnabled = (entityType: CanvasEntityType) => {
default:
assert<Equals<typeof entityType, never>>(false);
}
}, [entityType, isSD3, isCogView4, isImagen3, isImagen4, isFluxKontext, isChatGPT4o, activeReferenceImageEntities]);
}, [entityType, isSD3, isCogView4, isImagen3, isImagen4, isFluxKontext, isChatGPT4o]);
return isEntityTypeEnabled;
};

View File

@@ -1,11 +1,11 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectCanvasSlice, selectEntityIdentifierBelowThisOne } from 'features/controlLayers/store/selectors';
import type { CanvasRenderableEntityIdentifier } from 'features/controlLayers/store/types';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { useMemo } from 'react';
export const useEntityIdentifierBelowThisOne = <T extends CanvasRenderableEntityIdentifier>(
export const useEntityIdentifierBelowThisOne = <T extends CanvasEntityIdentifier>(
entityIdentifier: T
): T | null => {
const selector = useMemo(

View File

@@ -4,7 +4,6 @@ import {
selectActiveControlLayerEntities,
selectActiveInpaintMaskEntities,
selectActiveRasterLayerEntities,
selectActiveReferenceImageEntities,
selectActiveRegionalGuidanceEntities,
} from 'features/controlLayers/store/selectors';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
@@ -22,8 +21,6 @@ export const useVisibleEntityCountByType = (type: CanvasEntityIdentifier['type']
return createSelector(selectActiveInpaintMaskEntities, (entities) => entities.length);
case 'regional_guidance':
return createSelector(selectActiveRegionalGuidanceEntities, (entities) => entities.length);
case 'reference_image':
return createSelector(selectActiveReferenceImageEntities, (entities) => entities.length);
default:
assert(false, 'Invalid entity type');
}

View File

@@ -21,9 +21,9 @@ import {
selectActiveRegionalGuidanceEntities,
} from 'features/controlLayers/store/selectors';
import type {
CanvasRenderableEntityIdentifier,
CanvasRenderableEntityState,
CanvasRenderableEntityType,
CanvasEntityIdentifier,
CanvasEntityState,
CanvasEntityType,
GenerationMode,
Rect,
} from 'features/controlLayers/store/types';
@@ -91,7 +91,7 @@ export class CanvasCompositorModule extends CanvasModuleBase {
* @param type The optional entity type
* @returns The rect
*/
getVisibleRectOfType = (type?: CanvasRenderableEntityType): Rect => {
getVisibleRectOfType = (type?: CanvasEntityType): Rect => {
const rects = [];
for (const adapter of this.manager.getAllAdapters()) {
@@ -139,8 +139,8 @@ export class CanvasCompositorModule extends CanvasModuleBase {
* @param type The entity type
* @returns The adapters for the given entity type that are eligible to be included in a composite
*/
getVisibleAdaptersOfType = <T extends CanvasRenderableEntityType>(type: T): CanvasEntityAdapterFromType<T>[] => {
let entities: CanvasRenderableEntityState[];
getVisibleAdaptersOfType = <T extends CanvasEntityType>(type: T): CanvasEntityAdapterFromType<T>[] => {
let entities: CanvasEntityState[];
switch (type) {
case 'raster_layer':
@@ -327,7 +327,7 @@ export class CanvasCompositorModule extends CanvasModuleBase {
* @param deleteMergedEntities Whether to delete the merged entities after creating the new merged entity
* @returns A promise that resolves to the image DTO, or null if the merge failed
*/
mergeByEntityIdentifiers = async <T extends CanvasRenderableEntityIdentifier>(
mergeByEntityIdentifiers = async <T extends CanvasEntityIdentifier>(
entityIdentifiers: T[],
deleteMergedEntities: boolean
): Promise<ImageDTO | null> => {
@@ -402,8 +402,8 @@ export class CanvasCompositorModule extends CanvasModuleBase {
* @param type The type of entity to merge
* @returns A promise that resolves to the image DTO, or null if the merge failed
*/
mergeVisibleOfType = (type: CanvasRenderableEntityType): Promise<ImageDTO | null> => {
let entities: CanvasRenderableEntityState[];
mergeVisibleOfType = (type: CanvasEntityType): Promise<ImageDTO | null> => {
let entities: CanvasEntityState[];
switch (type) {
case 'raster_layer':

View File

@@ -26,7 +26,7 @@ import {
} from 'features/controlLayers/store/selectors';
import type {
CanvasEntityIdentifier,
CanvasRenderableEntityState,
CanvasEntityState,
LifecycleCallback,
Rect,
} from 'features/controlLayers/store/types';
@@ -42,7 +42,7 @@ import { assert } from 'tsafe';
import type { Jsonifiable, JsonObject } from 'type-fest';
export abstract class CanvasEntityAdapterBase<
T extends CanvasRenderableEntityState,
T extends CanvasEntityState,
U extends string,
> extends CanvasModuleBase {
readonly type: U;

View File

@@ -9,7 +9,7 @@ import { addCoords, getKonvaNodeDebugAttrs, getPrefixedId } from 'features/contr
import { selectAutoProcess } from 'features/controlLayers/store/canvasSettingsSlice';
import type { FilterConfig } from 'features/controlLayers/store/filters';
import { getFilterForModel, IMAGE_FILTERS } from 'features/controlLayers/store/filters';
import type { CanvasImageState, CanvasRenderableEntityType } from 'features/controlLayers/store/types';
import type { CanvasImageState, CanvasEntityType } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { toast } from 'features/toast/toast';
import Konva from 'konva';
@@ -373,7 +373,7 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
* Saves the filtered image as a new entity of the given type.
* @param type The type of entity to save the filtered image as.
*/
saveAs = (type: CanvasRenderableEntityType) => {
saveAs = (type: CanvasEntityType) => {
const imageState = this.$imageState.get();
if (!imageState) {
this.log.warn('No image state to apply filter to');

View File

@@ -2,7 +2,7 @@ import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/kon
import type { CanvasEntityAdapterInpaintMask } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterInpaintMask';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import type { CanvasEntityAdapterRegionalGuidance } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRegionalGuidance';
import type { CanvasRenderableEntityType } from 'features/controlLayers/store/types';
import type { CanvasEntityType } from 'features/controlLayers/store/types';
export type CanvasEntityAdapter =
| CanvasEntityAdapterRasterLayer
@@ -10,7 +10,7 @@ export type CanvasEntityAdapter =
| CanvasEntityAdapterInpaintMask
| CanvasEntityAdapterRegionalGuidance;
export type CanvasEntityAdapterFromType<T extends CanvasRenderableEntityType> = Extract<
export type CanvasEntityAdapterFromType<T extends CanvasEntityType> = Extract<
CanvasEntityAdapter,
{ state: { type: T } }
>;

View File

@@ -16,11 +16,7 @@ import { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/Canvas
import { CanvasWorkerModule } from 'features/controlLayers/konva/CanvasWorkerModule.js';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { $canvasManager } from 'features/controlLayers/store/ephemeral';
import type {
CanvasEntityIdentifier,
CanvasRenderableEntityIdentifier,
CanvasRenderableEntityType,
} from 'features/controlLayers/store/types';
import type { CanvasEntityIdentifier, CanvasEntityType } from 'features/controlLayers/store/types';
import {
isControlLayerEntityIdentifier,
isInpaintMaskEntityIdentifier,
@@ -135,7 +131,7 @@ export class CanvasManager extends CanvasModuleBase {
this.konva.previewLayer.add(this.tool.konva.group);
}
getAdapter = <T extends CanvasRenderableEntityType = CanvasRenderableEntityType>(
getAdapter = <T extends CanvasEntityType = CanvasEntityType>(
entityIdentifier: CanvasEntityIdentifier<T>
): CanvasEntityAdapterFromType<T> | null => {
let adapter: CanvasEntityAdapter | undefined;
@@ -163,7 +159,7 @@ export class CanvasManager extends CanvasModuleBase {
return adapter as CanvasEntityAdapterFromType<T>;
};
deleteAdapter = (entityIdentifier: CanvasRenderableEntityIdentifier): boolean => {
deleteAdapter = (entityIdentifier: CanvasEntityIdentifier): boolean => {
switch (entityIdentifier.type) {
case 'raster_layer':
return this.adapters.rasterLayers.delete(entityIdentifier.id);
@@ -178,7 +174,7 @@ export class CanvasManager extends CanvasModuleBase {
}
};
getAdapters = (entityIdentifiers: CanvasRenderableEntityIdentifier[]): CanvasEntityAdapter[] => {
getAdapters = (entityIdentifiers: CanvasEntityIdentifier[]): CanvasEntityAdapter[] => {
const adapters: CanvasEntityAdapter[] = [];
for (const entityIdentifier of entityIdentifiers) {
const adapter = this.getAdapter(entityIdentifier);
@@ -199,7 +195,7 @@ export class CanvasManager extends CanvasModuleBase {
];
};
createAdapter = (entityIdentifier: CanvasRenderableEntityIdentifier): CanvasEntityAdapter => {
createAdapter = (entityIdentifier: CanvasEntityIdentifier): CanvasEntityAdapter => {
if (isRasterLayerEntityIdentifier(entityIdentifier)) {
const adapter = new CanvasEntityAdapterRasterLayer(entityIdentifier, this);
this.adapters.rasterLayers.set(adapter.id, adapter);

View File

@@ -16,7 +16,7 @@ import {
import { selectAutoProcess } from 'features/controlLayers/store/canvasSettingsSlice';
import type {
CanvasImageState,
CanvasRenderableEntityType,
CanvasEntityType,
Coordinate,
RgbaColor,
SAMPointLabel,
@@ -703,7 +703,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
* Saves the segmented image as a new entity of the given type.
* @param type The type of entity to save the segmented image as.
*/
saveAs = (type: CanvasRenderableEntityType) => {
saveAs = (type: CanvasEntityType) => {
const imageState = this.$imageState.get();
if (!imageState) {
this.log.error('No image state to save as');

View File

@@ -48,7 +48,7 @@ import type {
Rect,
RgbaColor,
} from 'features/controlLayers/store/types';
import { isRenderableEntityIdentifier, RGBA_BLACK } from 'features/controlLayers/store/types';
import { RGBA_BLACK } from 'features/controlLayers/store/types';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { atom, computed } from 'nanostores';
import type { Logger } from 'roarr';
@@ -576,9 +576,6 @@ export class CanvasStateApiModule extends CanvasModuleBase {
if (!state.selectedEntityIdentifier) {
return null;
}
if (!isRenderableEntityIdentifier(state.selectedEntityIdentifier)) {
return null;
}
return this.manager.getAdapter(state.selectedEntityIdentifier);
};

View File

@@ -22,7 +22,6 @@ import type {
Coordinate,
Tool,
} from 'features/controlLayers/store/types';
import { isRenderableEntityType } from 'features/controlLayers/store/types';
import Konva from 'konva';
import type { KonvaEventObject } from 'konva/lib/Node';
import { atom } from 'nanostores';
@@ -180,7 +179,7 @@ export class CanvasToolModule extends CanvasModuleBase {
this.tools.bbox.syncCursorStyle();
} else if (tool === 'colorPicker') {
this.tools.colorPicker.syncCursorStyle();
} else if (selectedEntityAdapter && isRenderableEntityType(selectedEntityAdapter.entityIdentifier.type)) {
} else if (selectedEntityAdapter) {
if (selectedEntityAdapter.$isDisabled.get()) {
stage.setCursor('not-allowed');
} else if (selectedEntityAdapter.$isEntityTypeHidden.get()) {

View File

@@ -36,10 +36,9 @@ import { ASPECT_RATIO_MAP } from 'features/parameters/components/Bbox/constants'
import { API_BASE_MODELS } from 'features/parameters/types/constants';
import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import type { IRect } from 'konva/lib/types';
import { isEqual, merge } from 'lodash-es';
import { merge } from 'lodash-es';
import type { UndoableOptions } from 'redux-undo';
import type {
ApiModelConfig,
ControlLoRAModelConfig,
ControlNetModelConfig,
FLUXReduxModelConfig,
@@ -47,7 +46,6 @@ import type {
IPAdapterModelConfig,
T2IAdapterModelConfig,
} from 'services/api/types';
import { assert } from 'tsafe';
import type {
AspectRatioID,
@@ -55,7 +53,6 @@ import type {
CanvasControlLayerState,
CanvasEntityIdentifier,
CanvasRasterLayerState,
CanvasReferenceImageState,
CanvasRegionalGuidanceState,
CanvasState,
CLIPVisionModelV2,
@@ -77,20 +74,16 @@ import {
isChatGPT4oAspectRatioID,
isFluxKontextAspectRatioID,
isImagenAspectRatioID,
isRenderableEntity,
} from './types';
import {
converters,
getControlLayerState,
getInpaintMaskState,
getRasterLayerState,
getReferenceImageState,
getRegionalGuidanceState,
imageDTOToImageWithDims,
initialChatGPT4oReferenceImage,
initialControlLoRA,
initialControlNet,
initialFluxKontextReferenceImage,
initialFLUXRedux,
initialIPAdapter,
initialT2IAdapter,
@@ -560,204 +553,6 @@ export const canvasSlice = createSlice({
}
layer.withTransparencyEffect = !layer.withTransparencyEffect;
},
//#region Global Reference Images
referenceImageAdded: {
reducer: (
state,
action: PayloadAction<{
id: string;
overrides?: Partial<CanvasReferenceImageState>;
isSelected?: boolean;
isBookmarked?: boolean;
}>
) => {
const { id, overrides, isSelected, isBookmarked } = action.payload;
const entityState = getReferenceImageState(id, overrides);
state.referenceImages.entities.push(entityState);
const entityIdentifier = getEntityIdentifier(entityState);
if (isSelected) {
state.selectedEntityIdentifier = entityIdentifier;
}
if (isBookmarked) {
state.bookmarkedEntityIdentifier = entityIdentifier;
}
},
prepare: (payload?: {
overrides?: Partial<CanvasReferenceImageState>;
isSelected?: boolean;
isBookmarked?: boolean;
}) => ({
payload: { ...payload, id: getPrefixedId('reference_image') },
}),
},
referenceImageRecalled: (state, action: PayloadAction<{ data: CanvasReferenceImageState }>) => {
const { data } = action.payload;
state.referenceImages.entities.push(data);
state.selectedEntityIdentifier = { type: 'reference_image', id: data.id };
},
referenceImageIPAdapterImageChanged: (
state,
action: PayloadAction<EntityIdentifierPayload<{ imageDTO: ImageDTO | null }, 'reference_image'>>
) => {
const { entityIdentifier, imageDTO } = action.payload;
const entity = selectEntity(state, entityIdentifier);
if (!entity) {
return;
}
entity.ipAdapter.image = imageDTO ? imageDTOToImageWithDims(imageDTO) : null;
},
referenceImageIPAdapterMethodChanged: (
state,
action: PayloadAction<EntityIdentifierPayload<{ method: IPMethodV2 }, 'reference_image'>>
) => {
const { entityIdentifier, method } = action.payload;
const entity = selectEntity(state, entityIdentifier);
if (!entity) {
return;
}
if (entity.ipAdapter.type !== 'ip_adapter') {
return;
}
entity.ipAdapter.method = method;
},
referenceImageIPAdapterFLUXReduxImageInfluenceChanged: (
state,
action: PayloadAction<EntityIdentifierPayload<{ imageInfluence: FLUXReduxImageInfluence }, 'reference_image'>>
) => {
const { entityIdentifier, imageInfluence } = action.payload;
const entity = selectEntity(state, entityIdentifier);
if (!entity) {
return;
}
if (entity.ipAdapter.type !== 'flux_redux') {
return;
}
entity.ipAdapter.imageInfluence = imageInfluence;
},
referenceImageIPAdapterModelChanged: (
state,
action: PayloadAction<
EntityIdentifierPayload<
{ modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig | null },
'reference_image'
>
>
) => {
const { entityIdentifier, modelConfig } = action.payload;
const entity = selectEntity(state, entityIdentifier);
if (!entity) {
return;
}
const oldModel = entity.ipAdapter.model;
// First set the new model
entity.ipAdapter.model = modelConfig ? zModelIdentifierField.parse(modelConfig) : null;
if (!entity.ipAdapter.model) {
return;
}
if (isEqual(oldModel, entity.ipAdapter.model)) {
// Nothing changed, so we don't need to do anything
return;
}
// The type of ref image depends on the model. When the user switches the model, we rebuild the ref image.
// When we switch the model, we keep the image the same, but change the other parameters.
if (entity.ipAdapter.model.base === 'chatgpt-4o') {
// Switching to chatgpt-4o ref image
entity.ipAdapter = {
...initialChatGPT4oReferenceImage,
image: entity.ipAdapter.image,
model: entity.ipAdapter.model,
};
return;
}
if (entity.ipAdapter.model.base === 'flux-kontext') {
// Switching to flux-kontext
entity.ipAdapter = {
...initialFluxKontextReferenceImage,
image: entity.ipAdapter.image,
model: entity.ipAdapter.model,
};
return;
}
if (entity.ipAdapter.model.type === 'flux_redux') {
// Switching to flux_redux
entity.ipAdapter = {
...initialFLUXRedux,
image: entity.ipAdapter.image,
model: entity.ipAdapter.model,
};
return;
}
if (entity.ipAdapter.model.type === 'ip_adapter') {
// Switching to ip_adapter
entity.ipAdapter = {
...initialIPAdapter,
image: entity.ipAdapter.image,
model: entity.ipAdapter.model,
};
// Ensure that the IP Adapter model is compatible with the CLIP Vision model
if (entity.ipAdapter.model?.base === 'flux') {
entity.ipAdapter.clipVisionModel = 'ViT-L';
} else if (entity.ipAdapter.clipVisionModel === 'ViT-L') {
// Fall back to ViT-H (ViT-G would also work)
entity.ipAdapter.clipVisionModel = 'ViT-H';
}
return;
}
},
referenceImageIPAdapterCLIPVisionModelChanged: (
state,
action: PayloadAction<EntityIdentifierPayload<{ clipVisionModel: CLIPVisionModelV2 }, 'reference_image'>>
) => {
const { entityIdentifier, clipVisionModel } = action.payload;
const entity = selectEntity(state, entityIdentifier);
if (!entity) {
return;
}
if (entity.ipAdapter.type !== 'ip_adapter') {
return;
}
entity.ipAdapter.clipVisionModel = clipVisionModel;
},
referenceImageIPAdapterWeightChanged: (
state,
action: PayloadAction<EntityIdentifierPayload<{ weight: number }, 'reference_image'>>
) => {
const { entityIdentifier, weight } = action.payload;
const entity = selectEntity(state, entityIdentifier);
if (!entity) {
return;
}
if (entity.ipAdapter.type !== 'ip_adapter') {
return;
}
entity.ipAdapter.weight = weight;
},
referenceImageIPAdapterBeginEndStepPctChanged: (
state,
action: PayloadAction<EntityIdentifierPayload<{ beginEndStepPct: [number, number] }, 'reference_image'>>
) => {
const { entityIdentifier, beginEndStepPct } = action.payload;
const entity = selectEntity(state, entityIdentifier);
if (!entity) {
return;
}
if (entity.ipAdapter.type !== 'ip_adapter') {
return;
}
entity.ipAdapter.beginEndStepPct = beginEndStepPct;
},
//#region Regional Guidance
rgAdded: {
reducer: (
@@ -1466,13 +1261,10 @@ export const canvasSlice = createSlice({
const entity = selectEntity(state, entityIdentifier);
if (!entity) {
return;
} else if (isRenderableEntity(entity)) {
entity.isEnabled = true;
entity.objects = [];
entity.position = { x: 0, y: 0 };
} else {
assert(false, 'Not implemented');
}
entity.isEnabled = true;
entity.objects = [];
entity.position = { x: 0, y: 0 };
},
entityDuplicated: (state, action: PayloadAction<EntityIdentifierPayload>) => {
const { entityIdentifier } = action.payload;
@@ -1501,10 +1293,6 @@ export const canvasSlice = createSlice({
}
state.regionalGuidance.entities.push(newEntity);
break;
case 'reference_image':
newEntity.id = getPrefixedId('reference_image');
state.referenceImages.entities.push(newEntity);
break;
case 'inpaint_mask':
newEntity.id = getPrefixedId('inpaint_mask');
state.inpaintMasks.entities.push(newEntity);
@@ -1558,9 +1346,7 @@ export const canvasSlice = createSlice({
return;
}
if (isRenderableEntity(entity)) {
entity.position = position;
}
entity.position = position;
},
entityMovedBy: (state, action: PayloadAction<EntityMovedByPayload>) => {
const { entityIdentifier, offset } = action.payload;
@@ -1569,10 +1355,6 @@ export const canvasSlice = createSlice({
return;
}
if (!isRenderableEntity(entity)) {
return;
}
entity.position.x += offset.x;
entity.position.y += offset.y;
},
@@ -1583,11 +1365,9 @@ export const canvasSlice = createSlice({
return;
}
if (isRenderableEntity(entity)) {
if (replaceObjects) {
entity.objects = [imageObject];
entity.position = position;
}
if (replaceObjects) {
entity.objects = [imageObject];
entity.position = position;
}
if (isSelected) {
@@ -1601,10 +1381,6 @@ export const canvasSlice = createSlice({
return;
}
if (!isRenderableEntity(entity)) {
assert(false, `Cannot add a brush line to a non-drawable entity of type ${entity.type}`);
}
// TODO(psyche): If we add the object without splatting, the renderer will see it as the same object and not
// re-render it (reference equality check). I don't like this behaviour.
entity.objects.push({
@@ -1620,10 +1396,6 @@ export const canvasSlice = createSlice({
return;
}
if (!isRenderableEntity(entity)) {
assert(false, `Cannot add a eraser line to a non-drawable entity of type ${entity.type}`);
}
// TODO(psyche): If we add the object without splatting, the renderer will see it as the same object and not
// re-render it (reference equality check). I don't like this behaviour.
entity.objects.push({
@@ -1639,10 +1411,6 @@ export const canvasSlice = createSlice({
return;
}
if (!isRenderableEntity(entity)) {
assert(false, `Cannot add a rect to a non-drawable entity of type ${entity.type}`);
}
// TODO(psyche): If we add the object without splatting, the renderer will see it as the same object and not
// re-render it (reference equality check). I don't like this behaviour.
entity.objects.push({ ...rect });
@@ -1673,9 +1441,6 @@ export const canvasSlice = createSlice({
(rg) => rg.id !== entityIdentifier.id
);
break;
case 'reference_image':
state.referenceImages.entities = state.referenceImages.entities.filter((rg) => rg.id !== entityIdentifier.id);
break;
case 'inpaint_mask':
state.inpaintMasks.entities = state.inpaintMasks.entities.filter((rg) => rg.id !== entityIdentifier.id);
break;
@@ -1747,12 +1512,6 @@ export const canvasSlice = createSlice({
entityIdentifiers as CanvasEntityIdentifier<'regional_guidance'>[]
);
break;
case 'reference_image':
state.referenceImages.entities = reorderEntities(
state.referenceImages.entities,
entityIdentifiers as CanvasEntityIdentifier<'reference_image'>[]
);
break;
}
},
entityOpacityChanged: (state, action: PayloadAction<EntityIdentifierPayload<{ opacity: number }>>) => {
@@ -1761,9 +1520,6 @@ export const canvasSlice = createSlice({
if (!entity) {
return;
}
if (entity.type === 'reference_image') {
return;
}
entity.opacity = opacity;
},
allEntitiesOfTypeIsHiddenToggled: (state, action: PayloadAction<{ type: CanvasEntityIdentifier['type'] }>) => {
@@ -1782,9 +1538,6 @@ export const canvasSlice = createSlice({
case 'regional_guidance':
state.regionalGuidance.isHidden = !state.regionalGuidance.isHidden;
break;
case 'reference_image':
// no-op
break;
}
},
allEntitiesDeleted: (state) => {
@@ -1794,14 +1547,12 @@ export const canvasSlice = createSlice({
state.controlLayers = initialState.controlLayers;
state.inpaintMasks = initialState.inpaintMasks;
state.regionalGuidance = initialState.regionalGuidance;
state.referenceImages = initialState.referenceImages;
},
canvasMetadataRecalled: (state, action: PayloadAction<CanvasMetadata>) => {
const { controlLayers, inpaintMasks, rasterLayers, referenceImages, regionalGuidance } = action.payload;
const { controlLayers, inpaintMasks, rasterLayers, regionalGuidance } = action.payload;
state.controlLayers.entities = controlLayers;
state.inpaintMasks.entities = inpaintMasks;
state.rasterLayers.entities = rasterLayers;
state.referenceImages.entities = referenceImages;
state.regionalGuidance.entities = regionalGuidance;
return state;
},
@@ -1928,16 +1679,6 @@ export const {
controlLayerWeightChanged,
controlLayerBeginEndStepPctChanged,
controlLayerWithTransparencyEffectToggled,
// IP Adapters
referenceImageAdded,
// referenceImageRecalled,
referenceImageIPAdapterImageChanged,
referenceImageIPAdapterMethodChanged,
referenceImageIPAdapterModelChanged,
referenceImageIPAdapterCLIPVisionModelChanged,
referenceImageIPAdapterWeightChanged,
referenceImageIPAdapterBeginEndStepPctChanged,
referenceImageIPAdapterFLUXReduxImageInfluenceChanged,
// Regions
rgAdded,
// rgRecalled,

View File

@@ -0,0 +1,323 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { canvasMetadataRecalled } from 'features/controlLayers/store/canvasSlice';
import type { FLUXReduxImageInfluence, RefImagesState } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { isEqual } from 'lodash-es';
import type { ApiModelConfig, FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import type { PartialDeep } from 'type-fest';
import type { CanvasReferenceImageState, CLIPVisionModelV2, IPMethodV2 } from './types';
import { getInitialRefImagesState } from './types';
import {
getReferenceImageState,
imageDTOToImageWithDims,
initialChatGPT4oReferenceImage,
initialFLUXRedux,
initialIPAdapter,
} from './util';
type PayloadWithId<T = void> = T extends void
? { id: string }
: {
id: string;
} & T;
export const refImagesSlice = createSlice({
name: 'refImages',
initialState: getInitialRefImagesState(),
reducers: {
referenceImageAdded: {
reducer: (
state,
action: PayloadAction<{
id: string;
overrides?: PartialDeep<CanvasReferenceImageState>;
isSelected?: boolean;
}>
) => {
const { id, overrides, isSelected } = action.payload;
const entityState = getReferenceImageState(id, overrides);
state.entities.push(entityState);
if (isSelected) {
state.selectedId = entityState.id;
}
},
prepare: (payload?: { overrides?: PartialDeep<CanvasReferenceImageState>; isSelected?: boolean }) => ({
payload: { ...payload, id: getPrefixedId('reference_image') },
}),
},
referenceImageRecalled: (state, action: PayloadAction<{ data: CanvasReferenceImageState }>) => {
const { data } = action.payload;
state.entities.push(data);
state.selectedId = data.id;
},
referenceImageIPAdapterImageChanged: (
state,
action: PayloadAction<PayloadWithId<{ imageDTO: ImageDTO | null }>>
) => {
const { id, imageDTO } = action.payload;
const entity = selectRefImageEntity(state, id);
if (!entity) {
return;
}
entity.ipAdapter.image = imageDTO ? imageDTOToImageWithDims(imageDTO) : null;
},
referenceImageIPAdapterMethodChanged: (state, action: PayloadAction<PayloadWithId<{ method: IPMethodV2 }>>) => {
const { id, method } = action.payload;
const entity = selectRefImageEntity(state, id);
if (!entity) {
return;
}
if (entity.ipAdapter.type !== 'ip_adapter') {
return;
}
entity.ipAdapter.method = method;
},
referenceImageIPAdapterFLUXReduxImageInfluenceChanged: (
state,
action: PayloadAction<PayloadWithId<{ imageInfluence: FLUXReduxImageInfluence }>>
) => {
const { id, imageInfluence } = action.payload;
const entity = selectRefImageEntity(state, id);
if (!entity) {
return;
}
if (entity.ipAdapter.type !== 'flux_redux') {
return;
}
entity.ipAdapter.imageInfluence = imageInfluence;
},
referenceImageIPAdapterModelChanged: (
state,
action: PayloadAction<
PayloadWithId<{ modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig | null }>
>
) => {
const { id, modelConfig } = action.payload;
const entity = selectRefImageEntity(state, id);
if (!entity) {
return;
}
const oldModel = entity.ipAdapter.model;
// First set the new model
entity.ipAdapter.model = modelConfig ? zModelIdentifierField.parse(modelConfig) : null;
if (!entity.ipAdapter.model) {
return;
}
if (isEqual(oldModel, entity.ipAdapter.model)) {
// Nothing changed, so we don't need to do anything
return;
}
// The type of ref image depends on the model. When the user switches the model, we rebuild the ref image.
// When we switch the model, we keep the image the same, but change the other parameters.
if (entity.ipAdapter.model.base === 'chatgpt-4o') {
// Switching to chatgpt-4o ref image
entity.ipAdapter = {
...initialChatGPT4oReferenceImage,
image: entity.ipAdapter.image,
model: entity.ipAdapter.model,
};
return;
}
if (entity.ipAdapter.model.type === 'flux_redux') {
// Switching to flux_redux
entity.ipAdapter = {
...initialFLUXRedux,
image: entity.ipAdapter.image,
model: entity.ipAdapter.model,
};
return;
}
if (entity.ipAdapter.model.type === 'ip_adapter') {
// Switching to ip_adapter
entity.ipAdapter = {
...initialIPAdapter,
image: entity.ipAdapter.image,
model: entity.ipAdapter.model,
};
// Ensure that the IP Adapter model is compatible with the CLIP Vision model
if (entity.ipAdapter.model?.base === 'flux') {
entity.ipAdapter.clipVisionModel = 'ViT-L';
} else if (entity.ipAdapter.clipVisionModel === 'ViT-L') {
// Fall back to ViT-H (ViT-G would also work)
entity.ipAdapter.clipVisionModel = 'ViT-H';
}
return;
}
},
referenceImageIPAdapterCLIPVisionModelChanged: (
state,
action: PayloadAction<PayloadWithId<{ clipVisionModel: CLIPVisionModelV2 }>>
) => {
const { id, clipVisionModel } = action.payload;
const entity = selectRefImageEntity(state, id);
if (!entity) {
return;
}
if (entity.ipAdapter.type !== 'ip_adapter') {
return;
}
entity.ipAdapter.clipVisionModel = clipVisionModel;
},
referenceImageIPAdapterWeightChanged: (state, action: PayloadAction<PayloadWithId<{ weight: number }>>) => {
const { id, weight } = action.payload;
const entity = selectRefImageEntity(state, id);
if (!entity) {
return;
}
if (entity.ipAdapter.type !== 'ip_adapter') {
return;
}
entity.ipAdapter.weight = weight;
},
referenceImageIPAdapterBeginEndStepPctChanged: (
state,
action: PayloadAction<PayloadWithId<{ beginEndStepPct: [number, number] }>>
) => {
const { id, beginEndStepPct } = action.payload;
const entity = selectRefImageEntity(state, id);
if (!entity) {
return;
}
if (entity.ipAdapter.type !== 'ip_adapter') {
return;
}
entity.ipAdapter.beginEndStepPct = beginEndStepPct;
},
//#region Shared entity
entitySelected: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload;
const entity = selectRefImageEntity(state, id);
if (!entity) {
// Cannot select a non-existent entity
return;
}
state.selectedId = id;
},
entityNameChanged: (state, action: PayloadAction<PayloadWithId<{ name: string | null }>>) => {
const { id, name } = action.payload;
const entity = selectRefImageEntity(state, id);
if (!entity) {
return;
}
entity.name = name;
},
entityDuplicated: (state, action: PayloadAction<PayloadWithId>) => {
const { id } = action.payload;
const entity = selectRefImageEntity(state, id);
if (!entity) {
return;
}
const newEntity = deepClone(entity);
if (newEntity.name) {
newEntity.name = `${newEntity.name} (Copy)`;
}
newEntity.id = getPrefixedId('reference_image');
state.entities.push(newEntity);
state.selectedId = newEntity.id;
},
entityIsEnabledToggled: (state, action: PayloadAction<PayloadWithId>) => {
const { id } = action.payload;
const entity = selectRefImageEntity(state, id);
if (!entity) {
return;
}
entity.isEnabled = !entity.isEnabled;
},
entityIsLockedToggled: (state, action: PayloadAction<PayloadWithId>) => {
const { id } = action.payload;
const entity = selectRefImageEntity(state, id);
if (!entity) {
return;
}
entity.isLocked = !entity.isLocked;
},
entityDeleted: (state, action: PayloadAction<PayloadWithId>) => {
const { id } = action.payload;
let selectedId: string | null = null;
const entities = state.entities;
const index = entities.findIndex((entity) => entity.id === id);
const nextIndex = entities.length > 1 ? (index + 1) % entities.length : -1;
if (nextIndex !== -1) {
const nextEntity = entities[nextIndex];
if (nextEntity) {
selectedId = nextEntity.id;
}
}
state.entities = state.entities.filter((rg) => rg.id !== id);
state.selectedId = selectedId;
},
refImagesReset: () => getInitialRefImagesState(),
},
extraReducers(builder) {
builder.addCase(canvasMetadataRecalled, (state, action) => {
const { referenceImages } = action.payload;
state.entities = referenceImages;
});
},
});
export const {
referenceImageAdded,
// referenceImageRecalled,
referenceImageIPAdapterImageChanged,
referenceImageIPAdapterMethodChanged,
referenceImageIPAdapterModelChanged,
referenceImageIPAdapterCLIPVisionModelChanged,
referenceImageIPAdapterWeightChanged,
referenceImageIPAdapterBeginEndStepPctChanged,
referenceImageIPAdapterFLUXReduxImageInfluenceChanged,
} = refImagesSlice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
return state;
};
export const refImagesPersistConfig: PersistConfig<RefImagesState> = {
name: refImagesSlice.name,
initialState: getInitialRefImagesState(),
migrate,
persistDenylist: [],
};
export const selectRefImagesSlice = (state: RootState) => state.refImages;
export const selectReferenceImageEntities = createSelector(selectRefImagesSlice, (state) => state.entities);
export const selectActiveReferenceImageEntities = createSelector(selectReferenceImageEntities, (entities) =>
entities.filter((e) => e.isEnabled)
);
export const selectRefImageEntityIds = createMemoizedSelector(selectReferenceImageEntities, (entities) =>
entities.map((e) => e.id)
);
export const selectRefImageEntity = (state: RefImagesState, id: string) =>
state.entities.find((entity) => entity.id === id) ?? null;
export function selectRefImageEntityOrThrow(
state: RefImagesState,
id: string,
caller: string
): CanvasReferenceImageState {
const entity = selectRefImageEntity(state, id);
assert(entity, `Entity with id ${id} not found in ${caller}`);
return entity;
}

View File

@@ -2,17 +2,16 @@ import type { Selector } from '@reduxjs/toolkit';
import { createSelector } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectReferenceImageEntities } from 'features/controlLayers/store/refImagesSlice';
import type {
CanvasControlLayerState,
CanvasEntityIdentifier,
CanvasEntityState,
CanvasEntityType,
CanvasInpaintMaskState,
CanvasMetadata,
CanvasRasterLayerState,
CanvasRegionalGuidanceState,
CanvasRenderableEntityIdentifier,
CanvasRenderableEntityState,
CanvasRenderableEntityType,
CanvasState,
} from 'features/controlLayers/store/types';
import { getGridSize, getOptimalDimension } from 'features/parameters/util/optimalDimension';
@@ -40,14 +39,13 @@ export const createCanvasSelector = <T>(selector: Selector<CanvasState, T>) =>
const selectEntityCountAll = createCanvasSelector((canvas) => {
return (
canvas.regionalGuidance.entities.length +
canvas.referenceImages.entities.length +
canvas.rasterLayers.entities.length +
canvas.controlLayers.entities.length +
canvas.inpaintMasks.entities.length
);
});
const isVisibleEntity = (entity: CanvasRenderableEntityState) => entity.isEnabled && entity.objects.length > 0;
const isVisibleEntity = (entity: CanvasEntityState) => entity.isEnabled && entity.objects.length > 0;
export const selectRasterLayerEntities = createCanvasSelector((canvas) => canvas.rasterLayers.entities);
export const selectActiveRasterLayerEntities = createSelector(selectRasterLayerEntities, (entities) =>
@@ -69,11 +67,6 @@ export const selectActiveRegionalGuidanceEntities = createSelector(selectRegiona
entities.filter(isVisibleEntity)
);
export const selectReferenceImageEntities = createCanvasSelector((canvas) => canvas.referenceImages.entities);
export const selectActiveReferenceImageEntities = createSelector(selectReferenceImageEntities, (entities) =>
entities.filter((e) => e.isEnabled)
);
/**
* Selects the total _active_ canvas entity count:
* - Regions
@@ -89,20 +82,17 @@ export const selectEntityCountActive = createSelector(
selectActiveControlLayerEntities,
selectActiveInpaintMaskEntities,
selectActiveRegionalGuidanceEntities,
selectActiveReferenceImageEntities,
(
activeRasterLayerEntities,
activeControlLayerEntities,
activeInpaintMaskEntities,
activeRegionalGuidanceEntities,
activeIPAdapterEntities
activeRegionalGuidanceEntities
) => {
return (
activeRasterLayerEntities.length +
activeControlLayerEntities.length +
activeInpaintMaskEntities.length +
activeRegionalGuidanceEntities.length +
activeIPAdapterEntities.length
activeRegionalGuidanceEntities.length
);
}
);
@@ -153,9 +143,6 @@ export function selectEntity<T extends CanvasEntityIdentifier>(
case 'regional_guidance':
entity = state.regionalGuidance.entities.find((entity) => entity.id === id);
break;
case 'reference_image':
entity = state.referenceImages.entities.find((entity) => entity.id === id);
break;
}
// This cast is safe, but TS seems to be unable to infer the type
@@ -165,13 +152,13 @@ export function selectEntity<T extends CanvasEntityIdentifier>(
/**
* Selects the entity identifier for the entity that is below the given entity in terms of draw order.
*/
export function selectEntityIdentifierBelowThisOne<T extends CanvasRenderableEntityIdentifier>(
export function selectEntityIdentifierBelowThisOne<T extends CanvasEntityIdentifier>(
state: CanvasState,
entityIdentifier: T
): Extract<CanvasEntityState, T> | undefined {
const { id, type } = entityIdentifier;
let entities: CanvasRenderableEntityState[];
let entities: CanvasEntityState[];
switch (type) {
case 'raster_layer': {
@@ -244,9 +231,6 @@ export function selectAllEntitiesOfType<T extends CanvasEntityState['type']>(
case 'regional_guidance':
entities = state.regionalGuidance.entities;
break;
case 'reference_image':
entities = state.referenceImages.entities;
break;
}
// This cast is safe, but TS seems to be unable to infer the type
@@ -259,7 +243,6 @@ export function selectAllEntitiesOfType<T extends CanvasEntityState['type']>(
export function selectAllEntities(state: CanvasState): CanvasEntityState[] {
// These are in the same order as they are displayed in the list!
return [
...state.referenceImages.entities.toReversed(),
...state.inpaintMasks.entities.toReversed(),
...state.regionalGuidance.entities.toReversed(),
...state.controlLayers.entities.toReversed(),
@@ -340,7 +323,7 @@ const selectRegionalGuidanceIsHidden = createCanvasSelector((canvas) => canvas.r
/**
* Returns the hidden selector for the given entity type.
*/
export const getSelectIsTypeHidden = (type: CanvasRenderableEntityType) => {
export const getSelectIsTypeHidden = (type: CanvasEntityType) => {
switch (type) {
case 'raster_layer':
return selectRasterLayersIsHidden;
@@ -379,9 +362,6 @@ export const buildSelectHasObjects = (entityIdentifier: CanvasEntityIdentifier)
if (!entity) {
return false;
}
if (entity.type === 'reference_image') {
return entity.ipAdapter.image !== null;
}
return entity.objects.length > 0;
});
};
@@ -397,9 +377,10 @@ export const selectBboxModelBase = createSelector(selectBbox, (bbox) => bbox.mod
export const selectCanvasMetadata = createSelector(
selectCanvasSlice,
(canvas): { canvas_v2_metadata: CanvasMetadata } => {
selectReferenceImageEntities,
(canvas, refImageEntities): { canvas_v2_metadata: CanvasMetadata } => {
const canvas_v2_metadata: CanvasMetadata = {
referenceImages: selectAllEntitiesOfType(canvas, 'reference_image'),
referenceImages: refImageEntities,
controlLayers: selectAllEntitiesOfType(canvas, 'control_layer'),
inpaintMasks: selectAllEntitiesOfType(canvas, 'inpaint_mask'),
rasterLayers: selectAllEntitiesOfType(canvas, 'raster_layer'),

View File

@@ -408,25 +408,14 @@ const zCanvasEntityState = z.discriminatedUnion('type', [
zCanvasControlLayerState,
zCanvasRegionalGuidanceState,
zCanvasInpaintMaskState,
zCanvasReferenceImageState,
]);
export type CanvasEntityState = z.infer<typeof zCanvasEntityState>;
const zCanvasRenderableEntityState = z.discriminatedUnion('type', [
zCanvasRasterLayerState,
zCanvasControlLayerState,
zCanvasRegionalGuidanceState,
zCanvasInpaintMaskState,
]);
export type CanvasRenderableEntityState = z.infer<typeof zCanvasRenderableEntityState>;
export type CanvasRenderableEntityType = CanvasRenderableEntityState['type'];
const zCanvasEntityType = z.union([
zCanvasRasterLayerState.shape.type,
zCanvasControlLayerState.shape.type,
zCanvasRegionalGuidanceState.shape.type,
zCanvasInpaintMaskState.shape.type,
zCanvasReferenceImageState.shape.type,
]);
export type CanvasEntityType = z.infer<typeof zCanvasEntityType>;
@@ -435,7 +424,7 @@ export const zCanvasEntityIdentifer = z.object({
type: zCanvasEntityType,
});
export type CanvasEntityIdentifier<T extends CanvasEntityType = CanvasEntityType> = { id: string; type: T };
export type CanvasRenderableEntityIdentifier = CanvasEntityIdentifier<CanvasRenderableEntityType>;
export type LoRA = {
id: string;
isEnabled: boolean;
@@ -570,9 +559,6 @@ const zRegionalGuidance = z.object({
isHidden: z.boolean(),
entities: z.array(zCanvasRegionalGuidanceState),
});
const zReferenceImages = z.object({
entities: z.array(zCanvasReferenceImageState),
});
const zCanvasState = z.object({
_version: z.literal(3).default(3),
selectedEntityIdentifier: zCanvasEntityIdentifer.nullable().default(null),
@@ -581,7 +567,6 @@ const zCanvasState = z.object({
rasterLayers: zRasterLayers.default({ isHidden: false, entities: [] }),
controlLayers: zControlLayers.default({ isHidden: false, entities: [] }),
regionalGuidance: zRegionalGuidance.default({ isHidden: false, entities: [] }),
referenceImages: zReferenceImages.default({ entities: [] }),
bbox: zBboxState.default({
rect: { x: 0, y: 0, width: 512, height: 512 },
aspectRatio: DEFAULT_ASPECT_RATIO_CONFIG,
@@ -592,6 +577,14 @@ const zCanvasState = z.object({
});
export type CanvasState = z.infer<typeof zCanvasState>;
const zRefImagesState = z.object({
selectedId: zId.nullable().default(null),
entities: z.array(zCanvasReferenceImageState).default(() => []),
});
export type RefImagesState = z.infer<typeof zRefImagesState>;
const INITIAL_REF_IMAGES_STATE = zRefImagesState.parse({});
export const getInitialRefImagesState = () => deepClone(INITIAL_REF_IMAGES_STATE);
/**
* Gets a fresh canvas initial state with no references in memory to existing objects.
*/
@@ -657,17 +650,6 @@ export type GenerationMode = 'txt2img' | 'img2img' | 'inpaint' | 'outpaint';
export type CanvasEntityStateFromType<T extends CanvasEntityType> = Extract<CanvasEntityState, { type: T }>;
export function isRenderableEntityType(
entityType: CanvasEntityState['type']
): entityType is CanvasRenderableEntityState['type'] {
return (
entityType === 'raster_layer' ||
entityType === 'control_layer' ||
entityType === 'regional_guidance' ||
entityType === 'inpaint_mask'
);
}
export function isRasterLayerEntityIdentifier(
entityIdentifier: CanvasEntityIdentifier
): entityIdentifier is CanvasEntityIdentifier<'raster_layer'> {
@@ -725,16 +707,6 @@ export function isSaveableEntityIdentifier(
return isRasterLayerEntityIdentifier(entityIdentifier) || isControlLayerEntityIdentifier(entityIdentifier);
}
export function isRenderableEntity(entity: CanvasEntityState): entity is CanvasRenderableEntityState {
return isRenderableEntityType(entity.type);
}
export function isRenderableEntityIdentifier(
entityIdentifier: CanvasEntityIdentifier
): entityIdentifier is CanvasRenderableEntityIdentifier {
return isRenderableEntityType(entityIdentifier.type);
}
export const getEntityIdentifier = <T extends CanvasEntityType>(
entity: Extract<CanvasEntityState, { type: T }>
): CanvasEntityIdentifier<T> => {

View File

@@ -21,6 +21,7 @@ import type {
import { merge } from 'lodash-es';
import type { ImageDTO } from 'services/api/types';
import { assert } from 'tsafe';
import type { PartialDeep } from 'type-fest';
export const imageDTOToImageObject = (imageDTO: ImageDTO, overrides?: Partial<CanvasImageState>): CanvasImageState => {
const { width, height, image_name } = imageDTO;
@@ -127,7 +128,7 @@ export const initialControlLoRA: ControlLoRAConfig = {
export const getReferenceImageState = (
id: string,
overrides?: Partial<CanvasReferenceImageState>
overrides?: PartialDeep<CanvasReferenceImageState>
): CanvasReferenceImageState => {
const entityState: CanvasReferenceImageState = {
id,
@@ -143,7 +144,7 @@ export const getReferenceImageState = (
export const getRegionalGuidanceState = (
id: string,
overrides?: Partial<CanvasRegionalGuidanceState>
overrides?: PartialDeep<CanvasRegionalGuidanceState>
): CanvasRegionalGuidanceState => {
const entityState: CanvasRegionalGuidanceState = {
id,
@@ -169,7 +170,7 @@ export const getRegionalGuidanceState = (
export const getControlLayerState = (
id: string,
overrides?: Partial<CanvasControlLayerState>
overrides?: PartialDeep<CanvasControlLayerState>
): CanvasControlLayerState => {
const entityState: CanvasControlLayerState = {
id,
@@ -189,7 +190,7 @@ export const getControlLayerState = (
export const getRasterLayerState = (
id: string,
overrides?: Partial<CanvasRasterLayerState>
overrides?: PartialDeep<CanvasRasterLayerState>
): CanvasRasterLayerState => {
const entityState: CanvasRasterLayerState = {
id,
@@ -207,7 +208,7 @@ export const getRasterLayerState = (
export const getInpaintMaskState = (
id: string,
overrides?: Partial<CanvasInpaintMaskState>
overrides?: PartialDeep<CanvasInpaintMaskState>
): CanvasInpaintMaskState => {
const entityState: CanvasInpaintMaskState = {
id,
@@ -232,7 +233,7 @@ export const getInpaintMaskState = (
const convertRasterLayerToControlLayer = (
newId: string,
rasterLayerState: CanvasRasterLayerState,
overrides?: Partial<CanvasControlLayerState>
overrides?: PartialDeep<CanvasControlLayerState>
): CanvasControlLayerState => {
const { name, objects, position } = rasterLayerState;
const controlLayerState = getControlLayerState(newId, {
@@ -247,7 +248,7 @@ const convertRasterLayerToControlLayer = (
const convertRasterLayerToInpaintMask = (
newId: string,
rasterLayerState: CanvasRasterLayerState,
overrides?: Partial<CanvasInpaintMaskState>
overrides?: PartialDeep<CanvasInpaintMaskState>
): CanvasInpaintMaskState => {
const { name, objects, position } = rasterLayerState;
const inpaintMaskState = getInpaintMaskState(newId, {
@@ -262,7 +263,7 @@ const convertRasterLayerToInpaintMask = (
const convertRasterLayerToRegionalGuidance = (
newId: string,
rasterLayerState: CanvasRasterLayerState,
overrides?: Partial<CanvasRegionalGuidanceState>
overrides?: PartialDeep<CanvasRegionalGuidanceState>
): CanvasRegionalGuidanceState => {
const { name, objects, position } = rasterLayerState;
const regionalGuidanceState = getRegionalGuidanceState(newId, {
@@ -277,7 +278,7 @@ const convertRasterLayerToRegionalGuidance = (
const convertControlLayerToRasterLayer = (
newId: string,
controlLayerState: CanvasControlLayerState,
overrides?: Partial<CanvasRasterLayerState>
overrides?: PartialDeep<CanvasRasterLayerState>
): CanvasRasterLayerState => {
const { name, objects, position } = controlLayerState;
const rasterLayerState = getRasterLayerState(newId, {
@@ -292,7 +293,7 @@ const convertControlLayerToRasterLayer = (
const convertControlLayerToInpaintMask = (
newId: string,
rasterLayerState: CanvasControlLayerState,
overrides?: Partial<CanvasInpaintMaskState>
overrides?: PartialDeep<CanvasInpaintMaskState>
): CanvasInpaintMaskState => {
const { name, objects, position } = rasterLayerState;
const inpaintMaskState = getInpaintMaskState(newId, {
@@ -307,7 +308,7 @@ const convertControlLayerToInpaintMask = (
const convertControlLayerToRegionalGuidance = (
newId: string,
rasterLayerState: CanvasControlLayerState,
overrides?: Partial<CanvasRegionalGuidanceState>
overrides?: PartialDeep<CanvasRegionalGuidanceState>
): CanvasRegionalGuidanceState => {
const { name, objects, position } = rasterLayerState;
const regionalGuidanceState = getRegionalGuidanceState(newId, {
@@ -322,7 +323,7 @@ const convertControlLayerToRegionalGuidance = (
const convertInpaintMaskToRegionalGuidance = (
newId: string,
inpaintMaskState: CanvasInpaintMaskState,
overrides?: Partial<CanvasRegionalGuidanceState>
overrides?: PartialDeep<CanvasRegionalGuidanceState>
): CanvasRegionalGuidanceState => {
const { name, objects, position } = inpaintMaskState;
const regionalGuidanceState = getRegionalGuidanceState(newId, {
@@ -337,7 +338,7 @@ const convertInpaintMaskToRegionalGuidance = (
const convertRegionalGuidanceToInpaintMask = (
newId: string,
regionalGuidanceState: CanvasRegionalGuidanceState,
overrides?: Partial<CanvasInpaintMaskState>
overrides?: PartialDeep<CanvasInpaintMaskState>
): CanvasInpaintMaskState => {
const { name, objects, position } = regionalGuidanceState;
const inpaintMaskState = getInpaintMaskState(newId, {

View File

@@ -1,9 +1,14 @@
import { useStore } from '@nanostores/react';
import { getStore, useAppStore } from 'app/store/nanostores/store';
import type { AppDispatch, AppGetState, RootState } from 'app/store/store';
import { entityDeleted, referenceImageIPAdapterImageChanged } from 'features/controlLayers/store/canvasSlice';
import { entityDeleted } from 'features/controlLayers/store/canvasSlice';
import {
referenceImageIPAdapterImageChanged,
selectReferenceImageEntities,
selectRefImagesSlice,
} from 'features/controlLayers/store/refImagesSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { type CanvasState, getEntityIdentifier } from 'features/controlLayers/store/types';
import type { CanvasState, RefImagesState } from 'features/controlLayers/store/types';
import type { ImageUsage } from 'features/deleteImageModal/store/types';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { imageSelected } from 'features/gallery/store/gallerySlice';
@@ -145,8 +150,9 @@ const getImageUsageFromImageDTOs = (imageDTOs: ImageDTO[], state: RootState): Im
const nodes = selectNodesSlice(state);
const canvas = selectCanvasSlice(state);
const upscale = selectUpscaleSlice(state);
const refImages = selectRefImagesSlice(state);
return imageDTOs.map(({ image_name }) => getImageUsage(nodes, canvas, upscale, image_name));
return imageDTOs.map(({ image_name }) => getImageUsage(nodes, canvas, upscale, refImages, image_name));
};
const getImageUsageSummary = (imageUsage: ImageUsage[]): ImageUsage => ({
@@ -221,9 +227,9 @@ const deleteControlLayerImages = (state: RootState, dispatch: AppDispatch, image
};
const deleteReferenceImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
selectCanvasSlice(state).referenceImages.entities.forEach((entity) => {
selectReferenceImageEntities(state).forEach((entity) => {
if (entity.ipAdapter.image?.image_name === imageDTO.image_name) {
dispatch(referenceImageIPAdapterImageChanged({ entityIdentifier: getEntityIdentifier(entity), imageDTO: null }));
dispatch(referenceImageIPAdapterImageChanged({ id: entity.id, imageDTO: null }));
}
});
};
@@ -243,7 +249,13 @@ const deleteRasterLayerImages = (state: RootState, dispatch: AppDispatch, imageD
});
};
export const getImageUsage = (nodes: NodesState, canvas: CanvasState, upscale: UpscaleState, image_name: string) => {
export const getImageUsage = (
nodes: NodesState,
canvas: CanvasState,
upscale: UpscaleState,
refImages: RefImagesState,
image_name: string
) => {
const isNodesImage = nodes.nodes.filter(isInvocationNode).some((node) =>
some(node.data.inputs, (input) => {
if (isImageFieldInputInstance(input)) {
@@ -264,9 +276,7 @@ export const getImageUsage = (nodes: NodesState, canvas: CanvasState, upscale: U
const isUpscaleImage = upscale.upscaleInitialImage?.image_name === image_name;
const isReferenceImage = canvas.referenceImages.entities.some(
({ ipAdapter }) => ipAdapter.image?.image_name === image_name
);
const isReferenceImage = refImages.entities.some(({ ipAdapter }) => ipAdapter.image?.image_name === image_name);
const isRasterLayerImage = canvas.rasterLayers.entities.some(({ objects }) =>
objects.some((obj) => obj.type === 'image' && 'image_name' in obj.image && obj.image.image_name === image_name)

View File

@@ -1,11 +1,7 @@
import { logger } from 'app/logging/logger';
import type { AppDispatch, RootState } from 'app/store/store';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type {
CanvasEntityIdentifier,
CanvasEntityType,
CanvasRenderableEntityIdentifier,
} from 'features/controlLayers/store/types';
import type { CanvasEntityIdentifier, CanvasEntityType } from 'features/controlLayers/store/types';
import { selectComparisonImages } from 'features/gallery/components/ImageViewer/common';
import type { BoardId } from 'features/gallery/store/types';
import {
@@ -133,7 +129,7 @@ const _setGlobalReferenceImage = buildTypeAndKey('set-global-reference-image');
export type SetGlobalReferenceImageDndTargetData = DndData<
typeof _setGlobalReferenceImage.type,
typeof _setGlobalReferenceImage.key,
{ entityIdentifier: CanvasEntityIdentifier<'reference_image'> }
{ id: string }
>;
export const setGlobalReferenceImageDndTarget: DndTarget<
SetGlobalReferenceImageDndTargetData,
@@ -150,8 +146,8 @@ export const setGlobalReferenceImageDndTarget: DndTarget<
},
handler: ({ sourceData, targetData, dispatch }) => {
const { imageDTO } = sourceData.payload;
const { entityIdentifier } = targetData.payload;
setGlobalReferenceImage({ entityIdentifier, imageDTO, dispatch });
const { id } = targetData.payload;
setGlobalReferenceImage({ id, imageDTO, dispatch });
},
};
//#endregion
@@ -352,7 +348,7 @@ type NewCanvasFromImageDndTargetData = DndData<
typeof _newCanvas.type,
typeof _newCanvas.key,
{
type: CanvasEntityType | 'regional_guidance_with_reference_image';
type: CanvasEntityType | 'regional_guidance_with_reference_image' | 'reference_image';
withResize?: boolean;
withInpaintMask?: boolean;
}
@@ -379,7 +375,7 @@ const _replaceCanvasEntityObjectsWithImage = buildTypeAndKey('replace-canvas-ent
export type ReplaceCanvasEntityObjectsWithImageDndTargetData = DndData<
typeof _replaceCanvasEntityObjectsWithImage.type,
typeof _replaceCanvasEntityObjectsWithImage.key,
{ entityIdentifier: CanvasRenderableEntityIdentifier }
{ entityIdentifier: CanvasEntityIdentifier }
>;
export const replaceCanvasEntityObjectsWithImageDndTarget: DndTarget<
ReplaceCanvasEntityObjectsWithImageDndTargetData,

View File

@@ -15,6 +15,7 @@ import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import ImageUsageMessage from 'features/deleteImageModal/components/ImageUsageMessage';
import { getImageUsage } from 'features/deleteImageModal/store/state';
@@ -54,23 +55,26 @@ const DeleteBoardModal = () => {
const selectImageUsageSummary = useMemo(
() =>
createMemoizedSelector([selectNodesSlice, selectCanvasSlice, selectUpscaleSlice], (nodes, canvas, upscale) => {
const allImageUsage = (boardImageNames ?? []).map((imageName) =>
getImageUsage(nodes, canvas, upscale, imageName)
);
createMemoizedSelector(
[selectNodesSlice, selectCanvasSlice, selectUpscaleSlice, selectRefImagesSlice],
(nodes, canvas, upscale, refImages) => {
const allImageUsage = (boardImageNames ?? []).map((imageName) =>
getImageUsage(nodes, canvas, upscale, refImages, imageName)
);
const imageUsageSummary: ImageUsage = {
isUpscaleImage: some(allImageUsage, (i) => i.isUpscaleImage),
isRasterLayerImage: some(allImageUsage, (i) => i.isRasterLayerImage),
isInpaintMaskImage: some(allImageUsage, (i) => i.isInpaintMaskImage),
isRegionalGuidanceImage: some(allImageUsage, (i) => i.isRegionalGuidanceImage),
isNodesImage: some(allImageUsage, (i) => i.isNodesImage),
isControlLayerImage: some(allImageUsage, (i) => i.isControlLayerImage),
isReferenceImage: some(allImageUsage, (i) => i.isReferenceImage),
};
const imageUsageSummary: ImageUsage = {
isUpscaleImage: some(allImageUsage, (i) => i.isUpscaleImage),
isRasterLayerImage: some(allImageUsage, (i) => i.isRasterLayerImage),
isInpaintMaskImage: some(allImageUsage, (i) => i.isInpaintMaskImage),
isRegionalGuidanceImage: some(allImageUsage, (i) => i.isRegionalGuidanceImage),
isNodesImage: some(allImageUsage, (i) => i.isNodesImage),
isControlLayerImage: some(allImageUsage, (i) => i.isControlLayerImage),
isReferenceImage: some(allImageUsage, (i) => i.isReferenceImage),
};
return imageUsageSummary;
}),
return imageUsageSummary;
}
),
[boardImageNames]
);

View File

@@ -3,6 +3,8 @@ import { useAppStore } from 'app/store/nanostores/store';
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { NewLayerIcon } from 'features/controlLayers/components/common/icons';
import { useCanvasIsBusySafe } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { referenceImageAdded } from 'features/controlLayers/store/refImagesSlice';
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
import { sentImageToCanvas } from 'features/gallery/store/actions';
@@ -74,8 +76,8 @@ export const ImageMenuItemNewLayerFromImageSubMenu = memo(() => {
}, [imageDTO, imageViewer, store, t]);
const onClickNewGlobalReferenceImageFromImage = useCallback(() => {
const { dispatch, getState } = store;
createNewCanvasEntityFromImage({ imageDTO, type: 'reference_image', dispatch, getState });
const { dispatch } = store;
dispatch(referenceImageAdded({ overrides: { ipAdapter: { image: imageDTOToImageWithDims(imageDTO) } } }));
dispatch(sentImageToCanvas());
dispatch(setActiveTab('canvas'));
imageViewer.close();

View File

@@ -10,23 +10,21 @@ import {
entityRasterized,
inpaintMaskAdded,
rasterLayerAdded,
referenceImageAdded,
referenceImageIPAdapterImageChanged,
rgAdded,
rgIPAdapterImageChanged,
} from 'features/controlLayers/store/canvasSlice';
import { canvasSessionTypeChanged } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { referenceImageAdded, referenceImageIPAdapterImageChanged } from 'features/controlLayers/store/refImagesSlice';
import { selectBboxModelBase, selectBboxRect } from 'features/controlLayers/store/selectors';
import type {
CanvasControlLayerState,
CanvasEntityIdentifier,
CanvasEntityState,
CanvasEntityType,
CanvasImageState,
CanvasInpaintMaskState,
CanvasRasterLayerState,
CanvasRegionalGuidanceState,
CanvasRenderableEntityIdentifier,
CanvasRenderableEntityState,
} from 'features/controlLayers/store/types';
import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } from 'features/controlLayers/store/util';
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
@@ -41,13 +39,9 @@ import type { ImageDTO } from 'services/api/types';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
export const setGlobalReferenceImage = (arg: {
imageDTO: ImageDTO;
entityIdentifier: CanvasEntityIdentifier<'reference_image'>;
dispatch: AppDispatch;
}) => {
const { imageDTO, entityIdentifier, dispatch } = arg;
dispatch(referenceImageIPAdapterImageChanged({ entityIdentifier, imageDTO }));
export const setGlobalReferenceImage = (arg: { imageDTO: ImageDTO; id: string; dispatch: AppDispatch }) => {
const { imageDTO, id, dispatch } = arg;
dispatch(referenceImageIPAdapterImageChanged({ id, imageDTO }));
};
export const setRegionalGuidanceReferenceImage = (arg: {
@@ -84,7 +78,7 @@ export const createNewCanvasEntityFromImage = (arg: {
type: CanvasEntityType | 'regional_guidance_with_reference_image';
dispatch: AppDispatch;
getState: () => RootState;
overrides?: Partial<Pick<CanvasRenderableEntityState, 'isEnabled' | 'isLocked' | 'name' | 'position'>>;
overrides?: Partial<Pick<CanvasEntityState, 'isEnabled' | 'isLocked' | 'name' | 'position'>>;
}) => {
const { type, imageDTO, dispatch, getState, overrides: _overrides } = arg;
const state = getState();
@@ -117,12 +111,6 @@ export const createNewCanvasEntityFromImage = (arg: {
dispatch(rgAdded({ overrides, isSelected: true }));
break;
}
case 'reference_image': {
const ipAdapter = deepClone(selectDefaultRefImageConfig(getState()));
ipAdapter.image = imageDTOToImageWithDims(imageDTO);
dispatch(referenceImageAdded({ overrides: { ipAdapter }, isSelected: true }));
break;
}
case 'regional_guidance_with_reference_image': {
const ipAdapter = deepClone(selectDefaultIPAdapter(getState()));
ipAdapter.image = imageDTOToImageWithDims(imageDTO);
@@ -146,7 +134,7 @@ export const createNewCanvasEntityFromImage = (arg: {
*/
export const newCanvasFromImage = async (arg: {
imageDTO: ImageDTO;
type: CanvasEntityType | 'regional_guidance_with_reference_image';
type: CanvasEntityType | 'regional_guidance_with_reference_image' | 'reference_image';
withResize?: boolean;
withInpaintMask?: boolean;
dispatch: AppDispatch;
@@ -283,7 +271,7 @@ export const newCanvasFromImage = async (arg: {
export const replaceCanvasEntityObjectsWithImage = (arg: {
imageDTO: ImageDTO;
entityIdentifier: CanvasRenderableEntityIdentifier;
entityIdentifier: CanvasEntityIdentifier;
dispatch: AppDispatch;
getState: () => RootState;
}) => {

View File

@@ -3,6 +3,7 @@ import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectMainModelConfig } from 'features/controlLayers/store/paramsSlice';
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { isChatGPT4oAspectRatioID, isChatGPT4oReferenceImageConfig } from 'features/controlLayers/store/types';
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
@@ -32,6 +33,7 @@ export const buildChatGPT4oGraph = async (
const model = selectMainModelConfig(state);
const canvas = selectCanvasSlice(state);
const refImages = selectRefImagesSlice(state);
const { bbox } = canvas;
const { positivePrompt } = selectPresetModifiedPrompts(state);
@@ -41,7 +43,7 @@ export const buildChatGPT4oGraph = async (
assert(isChatGPT4oAspectRatioID(bbox.aspectRatio.id), 'ChatGPT 4o does not support this aspect ratio');
const validRefImages = canvas.referenceImages.entities
const validRefImages = refImages.entities
.filter((entity) => entity.isEnabled)
.filter((entity) => isChatGPT4oReferenceImageConfig(entity.ipAdapter))
.filter((entity) => getGlobalReferenceImageWarnings(entity, model).length === 0)

View File

@@ -3,6 +3,7 @@ import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectMainModelConfig, selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { addFLUXFill } from 'features/nodes/util/graph/generation/addFLUXFill';
import { addFLUXLoRAs } from 'features/nodes/util/graph/generation/addFLUXLoRAs';
@@ -42,6 +43,7 @@ export const buildFLUXGraph = async (state: RootState, manager?: CanvasManager |
const params = selectParamsSlice(state);
const canvas = selectCanvasSlice(state);
const refImages = selectRefImagesSlice(state);
const { bbox } = canvas;
@@ -271,7 +273,7 @@ export const buildFLUXGraph = async (state: RootState, manager?: CanvasManager |
id: getPrefixedId('ip_adapter_collector'),
});
const ipAdapterResult = addIPAdapters({
entities: canvas.referenceImages.entities,
entities: refImages.entities,
g,
collector: ipAdapterCollect,
model,
@@ -284,7 +286,7 @@ export const buildFLUXGraph = async (state: RootState, manager?: CanvasManager |
id: getPrefixedId('ip_adapter_collector'),
});
const fluxReduxResult = addFLUXReduxes({
entities: canvas.referenceImages.entities,
entities: refImages.entities,
g,
collector: fluxReduxCollect,
model,

View File

@@ -3,6 +3,7 @@ import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectMainModelConfig, selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { addControlNets, addT2IAdapters } from 'features/nodes/util/graph/generation/addControlAdapters';
import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage';
@@ -37,6 +38,7 @@ export const buildSD1Graph = async (state: RootState, manager?: CanvasManager |
const params = selectParamsSlice(state);
const canvas = selectCanvasSlice(state);
const refImages = selectRefImagesSlice(state);
const { bbox } = canvas;
const model = selectMainModelConfig(state);
@@ -265,7 +267,7 @@ export const buildSD1Graph = async (state: RootState, manager?: CanvasManager |
id: getPrefixedId('ip_adapter_collector'),
});
const ipAdapterResult = addIPAdapters({
entities: canvas.referenceImages.entities,
entities: refImages.entities,
g,
collector: ipAdapterCollect,
model,

View File

@@ -3,6 +3,7 @@ import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectMainModelConfig, selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { addControlNets, addT2IAdapters } from 'features/nodes/util/graph/generation/addControlAdapters';
import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage';
@@ -41,6 +42,7 @@ export const buildSDXLGraph = async (state: RootState, manager?: CanvasManager |
const params = selectParamsSlice(state);
const canvas = selectCanvasSlice(state);
const refImages = selectRefImagesSlice(state);
const { bbox } = canvas;
@@ -272,7 +274,7 @@ export const buildSDXLGraph = async (state: RootState, manager?: CanvasManager |
id: getPrefixedId('ip_adapter_collector'),
});
const ipAdapterResult = addIPAdapters({
entities: canvas.referenceImages.entities,
entities: refImages.entities,
g,
collector: ipAdapterCollect,
model,

View File

@@ -1,6 +1,7 @@
import { Box, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { usePersistedTextAreaSize } from 'common/hooks/usePersistedTextareaSize';
import { RefImageList } from 'features/controlLayers/components/IPAdapter/IPAdapterList';
import { positivePromptChanged, selectBase, selectPositivePrompt } from 'features/controlLayers/store/paramsSlice';
import { ShowDynamicPromptsPreviewButton } from 'features/dynamicPrompts/components/ShowDynamicPromptsPreviewButton';
import { PromptLabel } from 'features/parameters/components/Prompts/PromptLabel';
@@ -107,6 +108,7 @@ export const ParamPositivePrompt = memo(() => {
label={`${t('parameters.positivePromptPlaceholder')} (${t('stylePresets.preview')})`}
/>
)}
<RefImageList position="absolute" bottom={2} left={2} />
</Box>
</PromptPopover>
);

View File

@@ -9,8 +9,9 @@ import type { AppConfig } from 'app/types/invokeai';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { selectMainModelConfig, selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { CanvasState, ParamsState } from 'features/controlLayers/store/types';
import type { CanvasState, ParamsState, RefImagesState } from 'features/controlLayers/store/types';
import {
getControlLayerWarnings,
getGlobalReferenceImageWarnings,
@@ -75,6 +76,7 @@ const debouncedUpdateReasons = debounce(
isConnected: boolean,
canvas: CanvasState,
params: ParamsState,
refImages: RefImagesState,
dynamicPrompts: DynamicPromptsState,
canvasIsFiltering: boolean,
canvasIsTransforming: boolean,
@@ -97,6 +99,7 @@ const debouncedUpdateReasons = debounce(
model,
canvas,
params,
refImages,
dynamicPrompts,
canvasIsFiltering,
canvasIsTransforming,
@@ -138,6 +141,7 @@ export const useReadinessWatcher = () => {
const tab = useAppSelector(selectActiveTab);
const canvas = useAppSelector(selectCanvasSlice);
const params = useAppSelector(selectParamsSlice);
const refImages = useAppSelector(selectRefImagesSlice);
const dynamicPrompts = useAppSelector(selectDynamicPromptsSlice);
const nodes = useAppSelector(selectNodesSlice);
const workflowSettings = useAppSelector(selectWorkflowSettingsSlice);
@@ -159,6 +163,7 @@ export const useReadinessWatcher = () => {
isConnected,
canvas,
params,
refImages,
dynamicPrompts,
canvasIsFiltering,
canvasIsTransforming,
@@ -177,6 +182,7 @@ export const useReadinessWatcher = () => {
}, [
store,
canvas,
refImages,
canvasIsCompositing,
canvasIsFiltering,
canvasIsRasterizing,
@@ -334,6 +340,7 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
model: MainModelConfig | null | undefined;
canvas: CanvasState;
params: ParamsState;
refImages: RefImagesState;
dynamicPrompts: DynamicPromptsState;
canvasIsFiltering: boolean;
canvasIsTransforming: boolean;
@@ -347,6 +354,7 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
model,
canvas,
params,
refImages,
dynamicPrompts,
canvasIsFiltering,
canvasIsTransforming,
@@ -514,24 +522,21 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
}
});
const enabledGlobalReferenceLayers = canvas.referenceImages.entities.filter(
(referenceImage) => referenceImage.isEnabled
);
// Flux Kontext only supports 1x Reference Image at a time.
const referenceImageCount = enabledGlobalReferenceLayers.length;
const referenceImageCount = refImages.entities.filter((entity) => entity.isEnabled).length;
if (model?.base === 'flux-kontext' && referenceImageCount > 1) {
reasons.push({ content: i18n.t('parameters.invoke.fluxKontextMultipleReferenceImages') });
}
canvas.referenceImages.entities
refImages.entities
.filter((entity) => entity.isEnabled)
.forEach((entity, i) => {
const layerLiteral = i18n.t('controlLayers.layer_one');
const layerNumber = i + 1;
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]);
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
const problems = getGlobalReferenceImageWarnings(entity, model);
if (problems.length) {