feat(ui): add button to ref image to recall size & optimize for model

This is useful for FLUX Kontext, where you typically want the generation
size to at least roughly match the first ref image size.
This commit is contained in:
psychedelicious
2025-08-05 07:47:20 +10:00
parent 111408c046
commit 61ff9ee3a7
4 changed files with 52 additions and 1 deletions

View File

@@ -1,15 +1,20 @@
import { Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { UploadImageIconButton } from 'common/hooks/useImageUploadButton';
import { bboxSizeOptimized, bboxSizeRecalled } from 'features/controlLayers/store/canvasSlice';
import { useCanvasIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { sizeOptimized, sizeRecalled } from 'features/controlLayers/store/paramsSlice';
import type { ImageWithDims } from 'features/controlLayers/store/types';
import type { setGlobalReferenceImageDndTarget, setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { DndImage } from 'features/dnd/DndImage';
import { DndImageIcon } from 'features/dnd/DndImageIcon';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
import { PiArrowCounterClockwiseBold, PiRulerBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { $isConnected } from 'services/events/stores';
@@ -29,7 +34,10 @@ export const RefImageImage = memo(
dndTargetData,
}: Props<T>) => {
const { t } = useTranslation();
const store = useAppStore();
const isConnected = useStore($isConnected);
const tab = useAppSelector(selectActiveTab);
const isStaging = useCanvasIsStaging();
const { currentData: imageDTO, isError } = useGetImageDTOQuery(image?.image_name ?? skipToken);
const handleResetControlImage = useCallback(() => {
onChangeImage(null);
@@ -48,6 +56,20 @@ export const RefImageImage = memo(
[onChangeImage]
);
const recallSizeAndOptimize = useCallback(() => {
if (!imageDTO || (tab === 'canvas' && isStaging)) {
return;
}
const { width, height } = imageDTO;
if (tab === 'canvas') {
store.dispatch(bboxSizeRecalled({ width, height }));
store.dispatch(bboxSizeOptimized());
} else if (tab === 'generate') {
store.dispatch(sizeRecalled({ width, height }));
store.dispatch(sizeOptimized());
}
}, [imageDTO, isStaging, store, tab]);
return (
<Flex position="relative" w="full" h="full" alignItems="center" data-error={!imageDTO && !image?.image_name}>
{!imageDTO && (
@@ -69,6 +91,14 @@ export const RefImageImage = memo(
tooltip={t('common.reset')}
/>
</Flex>
<Flex position="absolute" flexDir="column" bottom={2} insetInlineEnd={2} gap={1}>
<DndImageIcon
onClick={recallSizeAndOptimize}
icon={<PiRulerBold size={16} />}
tooltip={t('parameters.useSize')}
isDisabled={!imageDTO || (tab === 'canvas' && isStaging)}
/>
</Flex>
</>
)}
<DndDropTarget dndTarget={dndTarget} dndTargetData={dndTargetData} label={t('gallery.drop')} />

View File

@@ -1091,6 +1091,15 @@ const slice = createSlice({
syncScaledSize(state);
},
bboxSizeRecalled: (state, action: PayloadAction<{ width: number; height: number }>) => {
const { width, height } = action.payload;
const gridSize = getGridSize(state.bbox.modelBase);
state.bbox.rect.width = Math.max(roundDownToMultiple(width, gridSize), 64);
state.bbox.rect.height = Math.max(roundDownToMultiple(height, gridSize), 64);
state.bbox.aspectRatio.value = state.bbox.rect.width / state.bbox.rect.height;
state.bbox.aspectRatio.id = 'Free';
state.bbox.aspectRatio.isLocked = true;
},
bboxAspectRatioLockToggled: (state) => {
state.bbox.aspectRatio.isLocked = !state.bbox.aspectRatio.isLocked;
syncScaledSize(state);
@@ -1627,6 +1636,7 @@ export const {
bboxScaledWidthChanged,
bboxScaledHeightChanged,
bboxScaleMethodChanged,
bboxSizeRecalled,
bboxWidthChanged,
bboxHeightChanged,
bboxAspectRatioLockToggled,

View File

@@ -241,6 +241,15 @@ const slice = createSlice({
},
//#region Dimensions
sizeRecalled: (state, action: PayloadAction<{ width: number; height: number }>) => {
const { width, height } = action.payload;
const gridSize = getGridSize(state.model?.base);
state.dimensions.rect.width = Math.max(roundDownToMultiple(width, gridSize), 64);
state.dimensions.rect.height = Math.max(roundDownToMultiple(height, gridSize), 64);
state.dimensions.aspectRatio.value = state.dimensions.rect.width / state.dimensions.rect.height;
state.dimensions.aspectRatio.id = 'Free';
state.dimensions.aspectRatio.isLocked = true;
},
widthChanged: (state, action: PayloadAction<{ width: number; updateAspectRatio?: boolean; clamp?: boolean }>) => {
const { width, updateAspectRatio, clamp } = action.payload;
const gridSize = getGridSize(state.model?.base);
@@ -429,6 +438,7 @@ export const {
modelChanged,
// Dimensions
sizeRecalled,
widthChanged,
heightChanged,
aspectRatioLockToggled,

View File

@@ -27,6 +27,7 @@ export const DndImageIcon = memo((props: Props) => {
return (
<IconButton
onClick={onClick}
tooltip={tooltip}
aria-label={tooltip}
icon={icon}
variant="link"