mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Compare commits
6 Commits
v5.4.1rc2
...
psyche/fea
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c3304dcbe3 | ||
|
|
fb19621361 | ||
|
|
3f880496f7 | ||
|
|
79eb8172b6 | ||
|
|
5b3e1593ca | ||
|
|
2d08078a7d |
@@ -16,6 +16,7 @@ from pydantic import (
|
||||
from pydantic_core import to_jsonable_python
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.invocations.fields import ImageField
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import (
|
||||
WorkflowWithoutID,
|
||||
@@ -51,11 +52,7 @@ class SessionQueueItemNotFoundError(ValueError):
|
||||
|
||||
# region Batch
|
||||
|
||||
BatchDataType = Union[
|
||||
StrictStr,
|
||||
float,
|
||||
int,
|
||||
]
|
||||
BatchDataType = Union[StrictStr, float, int, ImageField]
|
||||
|
||||
|
||||
class NodeFieldValue(BaseModel):
|
||||
|
||||
@@ -300,7 +300,7 @@ ip_adapter_sdxl = StarterModel(
|
||||
ip_adapter_flux = StarterModel(
|
||||
name="Standard Reference (XLabs FLUX IP-Adapter)",
|
||||
base=BaseModelType.Flux,
|
||||
source="https://huggingface.co/XLabs-AI/flux-ip-adapter/resolve/main/flux-ip-adapter.safetensors",
|
||||
source="https://huggingface.co/XLabs-AI/flux-ip-adapter/resolve/main/ip_adapter.safetensors",
|
||||
description="References images with a more generalized/looser degree of precision.",
|
||||
type=ModelType.IPAdapter,
|
||||
dependencies=[clip_vit_l_image_encoder],
|
||||
|
||||
@@ -997,7 +997,7 @@
|
||||
"controlNetControlMode": "Control Mode",
|
||||
"copyImage": "Copy Image",
|
||||
"denoisingStrength": "Denoising Strength",
|
||||
"noRasterLayers": "No Raster Layers",
|
||||
"disabledNoRasterContent": "Disabled (No Raster Content)",
|
||||
"downloadImage": "Download Image",
|
||||
"general": "General",
|
||||
"guidance": "Guidance",
|
||||
@@ -1999,7 +1999,9 @@
|
||||
"upscaleModelDesc": "Upscale (image to image) model",
|
||||
"missingUpscaleInitialImage": "Missing initial image for upscaling",
|
||||
"missingUpscaleModel": "Missing upscale model",
|
||||
"missingTileControlNetModel": "No valid tile ControlNet models installed"
|
||||
"missingTileControlNetModel": "No valid tile ControlNet models installed",
|
||||
"incompatibleBaseModel": "Unsupported main model architecture for upscaling",
|
||||
"incompatibleBaseModelDesc": "Upscaling is supported for SD1.5 and SDXL architecture models only. Change the main model to enable upscaling."
|
||||
},
|
||||
"stylePresets": {
|
||||
"active": "Active",
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { isImageBatchNode } from 'features/nodes/types/invocation';
|
||||
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
|
||||
import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import type { BatchConfig } from 'services/api/types';
|
||||
import type { Batch, BatchConfig } from 'services/api/types';
|
||||
|
||||
export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
@@ -26,6 +27,18 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
|
||||
delete builtWorkflow.id;
|
||||
}
|
||||
|
||||
const data: Batch['data'] = [];
|
||||
|
||||
// Skip edges from batch nodes - these should not be in the graph, they exist only in the UI
|
||||
const imageBatchNodes = nodes.nodes.filter(isImageBatchNode);
|
||||
for (const imageBatch of imageBatchNodes) {
|
||||
const edge = nodes.edges.find((e) => e.source === imageBatch.id);
|
||||
if (!edge || !edge.targetHandle) {
|
||||
break;
|
||||
}
|
||||
data.push([{ node_path: edge.target, field_name: edge.targetHandle, items: imageBatch.data.images }]);
|
||||
}
|
||||
|
||||
const batchConfig: BatchConfig = {
|
||||
batch: {
|
||||
graph,
|
||||
@@ -33,6 +46,7 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
|
||||
runs: state.params.iterations,
|
||||
origin: 'workflows',
|
||||
destination: 'gallery',
|
||||
data,
|
||||
},
|
||||
prepend: action.payload.prepend,
|
||||
};
|
||||
|
||||
@@ -119,11 +119,20 @@ const createSelector = (
|
||||
reasons.push({ content: i18n.t('upscaling.exceedsMaxSize') });
|
||||
}
|
||||
}
|
||||
if (!upscale.upscaleModel) {
|
||||
reasons.push({ content: i18n.t('upscaling.missingUpscaleModel') });
|
||||
}
|
||||
if (!upscale.tileControlnetModel) {
|
||||
reasons.push({ content: i18n.t('upscaling.missingTileControlNetModel') });
|
||||
if (model && !['sd-1', 'sdxl'].includes(model.base)) {
|
||||
// When we are using an upsupported model, do not add the other warnings
|
||||
reasons.push({ content: i18n.t('upscaling.incompatibleBaseModel') });
|
||||
} else {
|
||||
// Using a compatible model, add all warnings
|
||||
if (!model) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') });
|
||||
}
|
||||
if (!upscale.upscaleModel) {
|
||||
reasons.push({ content: i18n.t('upscaling.missingUpscaleModel') });
|
||||
}
|
||||
if (!upscale.tileControlnetModel) {
|
||||
reasons.push({ content: i18n.t('upscaling.missingTileControlNetModel') });
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (canvasIsFiltering) {
|
||||
|
||||
@@ -17,12 +17,15 @@ import { selectImg2imgStrengthConfig } from 'features/system/store/configSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const selectIsEnabled = createSelector(selectActiveRasterLayerEntities, (entities) => entities.length > 0);
|
||||
const selectHasRasterLayersWithContent = createSelector(
|
||||
selectActiveRasterLayerEntities,
|
||||
(entities) => entities.length > 0
|
||||
);
|
||||
|
||||
export const ParamDenoisingStrength = memo(() => {
|
||||
const img2imgStrength = useAppSelector(selectImg2imgStrength);
|
||||
const dispatch = useAppDispatch();
|
||||
const isEnabled = useAppSelector(selectIsEnabled);
|
||||
const hasRasterLayersWithContent = useAppSelector(selectHasRasterLayersWithContent);
|
||||
|
||||
const onChange = useCallback(
|
||||
(v: number) => {
|
||||
@@ -37,16 +40,16 @@ export const ParamDenoisingStrength = memo(() => {
|
||||
const [invokeBlue300] = useToken('colors', ['invokeBlue.300']);
|
||||
|
||||
return (
|
||||
<FormControl isDisabled={!isEnabled} p={1} justifyContent="space-between" h={8}>
|
||||
<FormControl isDisabled={!hasRasterLayersWithContent} p={1} justifyContent="space-between" h={8}>
|
||||
<Flex gap={3} alignItems="center">
|
||||
<InformationalPopover feature="paramDenoisingStrength">
|
||||
<FormLabel mr={0}>{`${t('parameters.denoisingStrength')}`}</FormLabel>
|
||||
</InformationalPopover>
|
||||
{isEnabled && (
|
||||
{hasRasterLayersWithContent && (
|
||||
<WavyLine amplitude={img2imgStrength * 10} stroke={invokeBlue300} strokeWidth={1} width={40} height={14} />
|
||||
)}
|
||||
</Flex>
|
||||
{isEnabled ? (
|
||||
{hasRasterLayersWithContent ? (
|
||||
<>
|
||||
<CompositeSlider
|
||||
step={config.coarseStep}
|
||||
@@ -70,9 +73,7 @@ export const ParamDenoisingStrength = memo(() => {
|
||||
</>
|
||||
) : (
|
||||
<Flex alignItems="center">
|
||||
<Badge opacity="0.6">
|
||||
{t('common.disabled')} - {t('parameters.noRasterLayers')}
|
||||
</Badge>
|
||||
<Badge opacity="0.6">{t('parameters.disabledNoRasterContent')}</Badge>
|
||||
</Flex>
|
||||
)}
|
||||
</FormControl>
|
||||
|
||||
@@ -1,13 +1,8 @@
|
||||
import {
|
||||
roundDownToMultiple,
|
||||
roundToMultiple,
|
||||
roundToMultipleMin,
|
||||
roundUpToMultiple,
|
||||
} from 'common/util/roundDownToMultiple';
|
||||
import { roundToMultiple, roundToMultipleMin } from 'common/util/roundDownToMultiple';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
|
||||
import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
|
||||
import { getKonvaNodeDebugAttrs, getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { fitRectToGrid, getKonvaNodeDebugAttrs, getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { selectBboxOverlay } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { selectBbox } from 'features/controlLayers/store/selectors';
|
||||
import type { Coordinate, Rect } from 'features/controlLayers/store/types';
|
||||
@@ -398,18 +393,12 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
|
||||
}
|
||||
|
||||
// Determine the bbox size that fits within the visible rect. The bbox must be at least 64px in width and height,
|
||||
// and its width and height must be multiples of 8px.
|
||||
// and its width and height must be multiples of the bbox grid size.
|
||||
const gridSize = this.manager.stateApi.getBboxGridSize();
|
||||
|
||||
// To be conservative, we will round up the x and y to the nearest grid size, and round down the width and height.
|
||||
// This ensures the bbox is never _larger_ than the visible rect. If the bbox is larger than the visible, we
|
||||
// will always trigger the outpainting workflow, which is not what the user wants.
|
||||
const x = roundUpToMultiple(visibleRect.x, gridSize);
|
||||
const y = roundUpToMultiple(visibleRect.y, gridSize);
|
||||
const width = roundDownToMultiple(visibleRect.width, gridSize);
|
||||
const height = roundDownToMultiple(visibleRect.height, gridSize);
|
||||
const rect = fitRectToGrid(visibleRect, gridSize);
|
||||
|
||||
this.manager.stateApi.setGenerationBbox({ x, y, width, height });
|
||||
this.manager.stateApi.setGenerationBbox(rect);
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import { getPrefixedId, getRectUnion } from 'features/controlLayers/konva/util';
|
||||
import { roundUpToMultiple } from 'common/util/roundDownToMultiple';
|
||||
import { fitRectToGrid, getPrefixedId, getRectUnion } from 'features/controlLayers/konva/util';
|
||||
import type { Rect } from 'features/controlLayers/store/types';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
describe('util', () => {
|
||||
@@ -44,4 +46,74 @@ describe('util', () => {
|
||||
expect(union).toEqual({ x: 0, y: 0, width: 0, height: 0 });
|
||||
});
|
||||
});
|
||||
|
||||
describe('fitRectToGrid', () => {
|
||||
it('should fit rect within grid without exceeding bounds', () => {
|
||||
const rect: Rect = { x: 0, y: 0, width: 1047, height: 1758 };
|
||||
const gridSize = 50;
|
||||
const result = fitRectToGrid(rect, gridSize);
|
||||
|
||||
expect(result.x).toBe(roundUpToMultiple(rect.x, gridSize));
|
||||
expect(result.y).toBe(roundUpToMultiple(rect.y, gridSize));
|
||||
expect(result.width).toBeLessThanOrEqual(rect.width);
|
||||
expect(result.height).toBeLessThanOrEqual(rect.height);
|
||||
expect(result.width % gridSize).toBe(0);
|
||||
expect(result.height % gridSize).toBe(0);
|
||||
});
|
||||
|
||||
it('should handle small rect within grid bounds', () => {
|
||||
const rect: Rect = { x: 20, y: 30, width: 80, height: 90 };
|
||||
const gridSize = 25;
|
||||
const result = fitRectToGrid(rect, gridSize);
|
||||
|
||||
expect(result.x).toBe(25);
|
||||
expect(result.y).toBe(50);
|
||||
expect(result.width % gridSize).toBe(0);
|
||||
expect(result.height % gridSize).toBe(0);
|
||||
expect(result.width).toBeLessThanOrEqual(rect.width);
|
||||
expect(result.height).toBeLessThanOrEqual(rect.height);
|
||||
});
|
||||
|
||||
it('should handle rect starting outside of grid alignment', () => {
|
||||
const rect: Rect = { x: 13, y: 27, width: 94, height: 112 };
|
||||
const gridSize = 20;
|
||||
const result = fitRectToGrid(rect, gridSize);
|
||||
|
||||
expect(result.x).toBe(20);
|
||||
expect(result.y).toBe(40);
|
||||
expect(result.width % gridSize).toBe(0);
|
||||
expect(result.height % gridSize).toBe(0);
|
||||
expect(result.width).toBeLessThanOrEqual(rect.width);
|
||||
expect(result.height).toBeLessThanOrEqual(rect.height);
|
||||
});
|
||||
|
||||
it('should return the same rect if already aligned to grid', () => {
|
||||
const rect: Rect = { x: 100, y: 100, width: 200, height: 300 };
|
||||
const gridSize = 50;
|
||||
const result = fitRectToGrid(rect, gridSize);
|
||||
|
||||
expect(result).toEqual(rect);
|
||||
});
|
||||
|
||||
it('should handle large grid sizes relative to rect dimensions', () => {
|
||||
const rect: Rect = { x: 250, y: 300, width: 400, height: 500 };
|
||||
const gridSize = 100;
|
||||
const result = fitRectToGrid(rect, gridSize);
|
||||
|
||||
expect(result.x).toBe(300);
|
||||
expect(result.y).toBe(300);
|
||||
expect(result.width % gridSize).toBe(0);
|
||||
expect(result.height % gridSize).toBe(0);
|
||||
expect(result.width).toBeLessThanOrEqual(rect.width);
|
||||
expect(result.height).toBeLessThanOrEqual(rect.height);
|
||||
});
|
||||
|
||||
it('should handle rect with zero width and height', () => {
|
||||
const rect: Rect = { x: 40, y: 60, width: 100, height: 200 };
|
||||
const gridSize = 20;
|
||||
const result = fitRectToGrid(rect, gridSize);
|
||||
|
||||
expect(result).toEqual({ x: 40, y: 60, width: 100, height: 200 });
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { Selector, Store } from '@reduxjs/toolkit';
|
||||
import { $authToken } from 'app/store/nanostores/authToken';
|
||||
import { roundDownToMultiple, roundUpToMultiple } from 'common/util/roundDownToMultiple';
|
||||
import type {
|
||||
CanvasEntityIdentifier,
|
||||
CanvasObjectState,
|
||||
@@ -560,6 +561,33 @@ export const getRectIntersection = (...rects: Rect[]): Rect => {
|
||||
return rect || getEmptyRect();
|
||||
};
|
||||
|
||||
/**
|
||||
* Fits a rect to the nearest multiple of the grid size, rounding down. The returned rect will be smaller than or equal
|
||||
* to the input rect, and will be aligned to the grid.
|
||||
*
|
||||
* In other words, shrink the rect inwards on each size until it fits within the visible rect and aligns to the grid.
|
||||
*
|
||||
* @param rect The rect to fit
|
||||
* @param gridSize The size of the grid
|
||||
* @returns The fitted rect
|
||||
*/
|
||||
export const fitRectToGrid = (rect: Rect, gridSize: number): Rect => {
|
||||
// Rounding x and y up effectively shrinks the left and top edges of the rect, and rounding width and height down
|
||||
// effectively shrinks the right and bottom edges.
|
||||
const x = roundUpToMultiple(rect.x, gridSize);
|
||||
const y = roundUpToMultiple(rect.y, gridSize);
|
||||
|
||||
// Because we've just shifted the rect's x and y, we need to adjust the width and height by the same amount before
|
||||
// we round those values down.
|
||||
const offsetX = x - rect.x;
|
||||
const offsetY = y - rect.y;
|
||||
|
||||
const width = roundDownToMultiple(rect.width - offsetX, gridSize);
|
||||
const height = roundDownToMultiple(rect.height - offsetY, gridSize);
|
||||
|
||||
return { x, y, width, height };
|
||||
};
|
||||
|
||||
/**
|
||||
* Asserts that the value is never reached. Used for exhaustive checks in switch statements or conditional logic to ensure
|
||||
* that all possible values are handled.
|
||||
|
||||
@@ -20,12 +20,15 @@ const sx = {
|
||||
},
|
||||
} satisfies SystemStyleObject;
|
||||
|
||||
type Props = ImageProps & {
|
||||
imageDTO: ImageDTO;
|
||||
asThumbnail?: boolean;
|
||||
};
|
||||
/* eslint-disable-next-line @typescript-eslint/no-namespace */
|
||||
export namespace DndImage {
|
||||
export interface Props extends ImageProps {
|
||||
imageDTO: ImageDTO;
|
||||
asThumbnail?: boolean;
|
||||
}
|
||||
}
|
||||
|
||||
export const DndImage = memo(({ imageDTO, asThumbnail, ...rest }: Props) => {
|
||||
export const DndImage = memo(({ imageDTO, asThumbnail, ...rest }: DndImage.Props) => {
|
||||
const store = useAppStore();
|
||||
const [isDragging, setIsDragging] = useState(false);
|
||||
const [element, ref] = useState<HTMLImageElement | null>(null);
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
import { IAINoContentFallback, IAINoContentFallbackWithSpinner } from 'common/components/IAIImageFallback';
|
||||
import { DndImage } from 'features/dnd/DndImage';
|
||||
import { memo } from 'react';
|
||||
import { PiExclamationMarkBold } from 'react-icons/pi';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-namespace */
|
||||
namespace DndImageFromImageName {
|
||||
export interface Props extends Omit<DndImage.Props, 'imageDTO'> {
|
||||
imageName: string;
|
||||
}
|
||||
}
|
||||
|
||||
export const DndImageFromImageName = memo(({ imageName, ...rest }: DndImageFromImageName.Props) => {
|
||||
const query = useGetImageDTOQuery(imageName);
|
||||
if (query.isLoading) {
|
||||
return <IAINoContentFallbackWithSpinner />;
|
||||
}
|
||||
if (!query.data) {
|
||||
return <IAINoContentFallback icon={<PiExclamationMarkBold />} />;
|
||||
}
|
||||
|
||||
return <DndImage imageDTO={query.data} {...rest} />;
|
||||
});
|
||||
|
||||
DndImageFromImageName.displayName = 'DndImageFromImageName';
|
||||
@@ -11,7 +11,7 @@ import type { DndTargetState } from 'features/dnd/types';
|
||||
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
|
||||
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memo, useEffect, useRef, useState } from 'react';
|
||||
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { type UploadImageArg, uploadImages } from 'services/api/endpoints/images';
|
||||
import { useBoardName } from 'services/api/hooks/useBoardName';
|
||||
@@ -71,13 +71,46 @@ export const FullscreenDropzone = memo(() => {
|
||||
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
|
||||
const [dndState, setDndState] = useState<DndTargetState>('idle');
|
||||
|
||||
const uploadFilesSchema = useMemo(() => getFilesSchema(maxImageUploadCount), [maxImageUploadCount]);
|
||||
|
||||
const validateAndUploadFiles = useCallback(
|
||||
(files: File[]) => {
|
||||
const { getState } = getStore();
|
||||
const parseResult = uploadFilesSchema.safeParse(files);
|
||||
|
||||
if (!parseResult.success) {
|
||||
const description =
|
||||
maxImageUploadCount === undefined
|
||||
? t('toast.uploadFailedInvalidUploadDesc')
|
||||
: t('toast.uploadFailedInvalidUploadDesc_withCount', { count: maxImageUploadCount });
|
||||
|
||||
toast({
|
||||
id: 'UPLOAD_FAILED',
|
||||
title: t('toast.uploadFailed'),
|
||||
description,
|
||||
status: 'error',
|
||||
});
|
||||
return;
|
||||
}
|
||||
const autoAddBoardId = selectAutoAddBoardId(getState());
|
||||
|
||||
const uploadArgs: UploadImageArg[] = files.map((file) => ({
|
||||
file,
|
||||
image_category: 'user',
|
||||
is_intermediate: false,
|
||||
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||
}));
|
||||
|
||||
uploadImages(uploadArgs);
|
||||
},
|
||||
[maxImageUploadCount, t, uploadFilesSchema]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const element = ref.current;
|
||||
if (!element) {
|
||||
return;
|
||||
}
|
||||
const { getState } = getStore();
|
||||
const uploadFilesSchema = getFilesSchema(maxImageUploadCount);
|
||||
|
||||
return combine(
|
||||
dropTargetForExternal({
|
||||
@@ -85,32 +118,7 @@ export const FullscreenDropzone = memo(() => {
|
||||
canDrop: containsFiles,
|
||||
onDrop: ({ source }) => {
|
||||
const files = getFiles({ source });
|
||||
const parseResult = uploadFilesSchema.safeParse(files);
|
||||
|
||||
if (!parseResult.success) {
|
||||
const description =
|
||||
maxImageUploadCount === undefined
|
||||
? t('toast.uploadFailedInvalidUploadDesc')
|
||||
: t('toast.uploadFailedInvalidUploadDesc_withCount', { count: maxImageUploadCount });
|
||||
|
||||
toast({
|
||||
id: 'UPLOAD_FAILED',
|
||||
title: t('toast.uploadFailed'),
|
||||
description,
|
||||
status: 'error',
|
||||
});
|
||||
return;
|
||||
}
|
||||
const autoAddBoardId = selectAutoAddBoardId(getState());
|
||||
|
||||
const uploadArgs: UploadImageArg[] = files.map((file) => ({
|
||||
file,
|
||||
image_category: 'user',
|
||||
is_intermediate: false,
|
||||
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||
}));
|
||||
|
||||
uploadImages(uploadArgs);
|
||||
validateAndUploadFiles(files);
|
||||
},
|
||||
onDragEnter: () => {
|
||||
setDndState('over');
|
||||
@@ -131,7 +139,27 @@ export const FullscreenDropzone = memo(() => {
|
||||
},
|
||||
})
|
||||
);
|
||||
}, [maxImageUploadCount, t]);
|
||||
}, [validateAndUploadFiles]);
|
||||
|
||||
useEffect(() => {
|
||||
const controller = new AbortController();
|
||||
|
||||
document.addEventListener(
|
||||
'paste',
|
||||
(e) => {
|
||||
if (!e.clipboardData?.files) {
|
||||
return;
|
||||
}
|
||||
const files = Array.from(e.clipboardData.files);
|
||||
validateAndUploadFiles(files);
|
||||
},
|
||||
{ signal: controller.signal }
|
||||
);
|
||||
|
||||
return () => {
|
||||
controller.abort();
|
||||
};
|
||||
}, [validateAndUploadFiles]);
|
||||
|
||||
return (
|
||||
<Box ref={ref} data-dnd-state={dndState} sx={sx}>
|
||||
|
||||
@@ -18,6 +18,7 @@ import {
|
||||
setRegionalGuidanceReferenceImage,
|
||||
setUpscaleInitialImage,
|
||||
} from 'features/imageActions/actions';
|
||||
import { batchImageInputNodeImagesAdded } from 'features/nodes/store/nodesSlice';
|
||||
import type { FieldIdentifier } from 'features/nodes/types/field';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
@@ -241,6 +242,47 @@ export const setNodeImageFieldImageDndTarget: DndTarget<SetNodeImageFieldImageDn
|
||||
};
|
||||
//#endregion
|
||||
|
||||
//#region Add Images to Batch Image Input Node
|
||||
const _addImagesToBatchImageInputNode = buildTypeAndKey('add-images-to-batch-image-input-node');
|
||||
export type AddImagesToBatchImageInputNodeDndTargetData = DndData<
|
||||
typeof _addImagesToBatchImageInputNode.type,
|
||||
typeof _addImagesToBatchImageInputNode.key,
|
||||
{ nodeId: string }
|
||||
>;
|
||||
export const addImagesToBatchImageInputNodeDndTarget: DndTarget<
|
||||
AddImagesToBatchImageInputNodeDndTargetData,
|
||||
SingleImageDndSourceData | MultipleImageDndSourceData
|
||||
> = {
|
||||
..._addImagesToBatchImageInputNode,
|
||||
typeGuard: buildTypeGuard(_addImagesToBatchImageInputNode.key),
|
||||
getData: buildGetData(_addImagesToBatchImageInputNode.key, _addImagesToBatchImageInputNode.type),
|
||||
isValid: ({ sourceData }) => {
|
||||
if (singleImageDndSource.typeGuard(sourceData) || multipleImageDndSource.typeGuard(sourceData)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
handler: ({ sourceData, targetData, dispatch }) => {
|
||||
if (!singleImageDndSource.typeGuard(sourceData) && !multipleImageDndSource.typeGuard(sourceData)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { nodeId } = targetData.payload;
|
||||
const imageDTOs: ImageDTO[] = [];
|
||||
|
||||
if (singleImageDndSource.typeGuard(sourceData)) {
|
||||
imageDTOs.push(sourceData.payload.imageDTO);
|
||||
} else {
|
||||
imageDTOs.push(...sourceData.payload.imageDTOs);
|
||||
}
|
||||
|
||||
const images = imageDTOs.map(({ image_name }) => ({ image_name }));
|
||||
|
||||
dispatch(batchImageInputNodeImagesAdded({ nodeId, images }));
|
||||
},
|
||||
};
|
||||
//#endregion
|
||||
|
||||
//# Set Comparison Image
|
||||
const _setComparisonImage = buildTypeAndKey('set-comparison-image');
|
||||
export type SetComparisonImageDndTargetData = DndData<
|
||||
@@ -430,6 +472,7 @@ export const dndTargets = [
|
||||
// Single or Multiple Image
|
||||
addImageToBoardDndTarget,
|
||||
removeImageFromBoardDndTarget,
|
||||
addImagesToBatchImageInputNodeDndTarget,
|
||||
] as const;
|
||||
|
||||
export type AnyDndTarget = (typeof dndTargets)[number];
|
||||
|
||||
@@ -347,6 +347,18 @@ const NodeCommandList = memo(({ searchTerm, onSelect }: { searchTerm: string; on
|
||||
}),
|
||||
[t]
|
||||
);
|
||||
const batchInputImageItem = useMemo<FilterableItem>(
|
||||
() =>
|
||||
({
|
||||
type: 'batch_input',
|
||||
title: 'batch_input',
|
||||
description: 'batch_input',
|
||||
tags: ['batch'],
|
||||
classification: 'stable',
|
||||
nodePack: 'invokeai',
|
||||
}) as const,
|
||||
[]
|
||||
);
|
||||
|
||||
const items = useMemo<NodeCommandItemData[]>(() => {
|
||||
// If we have a connection in progress, we need to filter the node choices
|
||||
@@ -365,7 +377,7 @@ const NodeCommandList = memo(({ searchTerm, onSelect }: { searchTerm: string; on
|
||||
}
|
||||
}
|
||||
|
||||
for (const item of [currentImageFilterItem, notesFilterItem]) {
|
||||
for (const item of [currentImageFilterItem, notesFilterItem, batchInputImageItem]) {
|
||||
if (filter(item, searchTerm)) {
|
||||
_items.push({
|
||||
label: item.title,
|
||||
@@ -403,7 +415,7 @@ const NodeCommandList = memo(({ searchTerm, onSelect }: { searchTerm: string; on
|
||||
}
|
||||
|
||||
return _items;
|
||||
}, [pendingConnection, currentImageFilterItem, searchTerm, notesFilterItem, templatesArray]);
|
||||
}, [pendingConnection, templatesArray, searchTerm, currentImageFilterItem, notesFilterItem, batchInputImageItem]);
|
||||
|
||||
return (
|
||||
<>
|
||||
|
||||
@@ -2,6 +2,7 @@ import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
|
||||
import { BatchImageInputNode } from 'features/nodes/components/flow/nodes/BatchInput/BatchImageInputNode';
|
||||
import { useConnection } from 'features/nodes/hooks/useConnection';
|
||||
import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste';
|
||||
import { useSyncExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
@@ -64,6 +65,7 @@ const nodeTypes = {
|
||||
invocation: InvocationNodeWrapper,
|
||||
current_image: CurrentImageNode,
|
||||
notes: NotesNode,
|
||||
image_batch: BatchImageInputNode,
|
||||
};
|
||||
|
||||
// TODO: can we support reactflow? if not, we could style the attribution so it matches the app
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
import { Box, Flex, Grid, GridItem, IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import type { AddImagesToBatchImageInputNodeDndTargetData } from 'features/dnd/dnd';
|
||||
import { addImagesToBatchImageInputNodeDndTarget } from 'features/dnd/dnd';
|
||||
import { DndDropTarget } from 'features/dnd/DndDropTarget';
|
||||
import { DndImageFromImageName } from 'features/dnd/DndImageFromImageName';
|
||||
import NodeCollapseButton from 'features/nodes/components/flow/nodes/common/NodeCollapseButton';
|
||||
import NodeTitle from 'features/nodes/components/flow/nodes/common/NodeTitle';
|
||||
import NodeWrapper from 'features/nodes/components/flow/nodes/common/NodeWrapper';
|
||||
import FieldHandle from 'features/nodes/components/flow/nodes/Invocation/fields/FieldHandle';
|
||||
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
|
||||
import { batchImageInputNodeReset } from 'features/nodes/store/nodesSlice';
|
||||
import { imageBatchOutputFieldTemplate } from 'features/nodes/types/field';
|
||||
import type { ImageBatchNodeData } from 'features/nodes/types/invocation';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
|
||||
import type { NodeProps } from 'reactflow';
|
||||
|
||||
export const BatchImageInputNode = memo((props: NodeProps<ImageBatchNodeData>) => {
|
||||
const { id: nodeId, data, selected } = props;
|
||||
const { images, isOpen } = data;
|
||||
const dispatch = useAppDispatch();
|
||||
const onReset = useCallback(() => {
|
||||
dispatch(batchImageInputNodeReset({ nodeId }));
|
||||
}, [dispatch, nodeId]);
|
||||
const targetData = useMemo<AddImagesToBatchImageInputNodeDndTargetData>(
|
||||
() => addImagesToBatchImageInputNodeDndTarget.getData({ nodeId }),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<NodeWrapper nodeId={nodeId} selected={selected}>
|
||||
<Flex
|
||||
layerStyle="nodeHeader"
|
||||
borderTopRadius="base"
|
||||
borderBottomRadius={isOpen ? 0 : 'base'}
|
||||
alignItems="center"
|
||||
justifyContent="space-between"
|
||||
h={8}
|
||||
>
|
||||
<NodeCollapseButton nodeId={nodeId} isOpen={isOpen} />
|
||||
<NodeTitle nodeId={nodeId} title="Batch Image Input" />
|
||||
<Box minW={8} />
|
||||
</Flex>
|
||||
{isOpen && (
|
||||
<>
|
||||
<Flex
|
||||
position="relative"
|
||||
layerStyle="nodeBody"
|
||||
className="nopan"
|
||||
cursor="auto"
|
||||
flexDirection="column"
|
||||
borderBottomRadius="base"
|
||||
w="full"
|
||||
h="full"
|
||||
p={2}
|
||||
gap={1}
|
||||
minH={16}
|
||||
>
|
||||
<Grid className="nopan" w="full" h="full" templateColumns="repeat(3, 1fr)" gap={2}>
|
||||
{images.map(({ image_name }) => (
|
||||
<GridItem key={image_name}>
|
||||
<DndImageFromImageName imageName={image_name} asThumbnail />
|
||||
</GridItem>
|
||||
))}
|
||||
</Grid>
|
||||
<IconButton
|
||||
aria-label="reset"
|
||||
icon={<PiArrowCounterClockwiseBold />}
|
||||
position="absolute"
|
||||
top={0}
|
||||
insetInlineEnd={0}
|
||||
onClick={onReset}
|
||||
variant="ghost"
|
||||
/>
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
<ImageBatchOutputField nodeId={nodeId} />
|
||||
|
||||
<DndDropTarget
|
||||
dndTarget={addImagesToBatchImageInputNodeDndTarget}
|
||||
dndTargetData={targetData}
|
||||
label="Add to Batch"
|
||||
/>
|
||||
</NodeWrapper>
|
||||
);
|
||||
});
|
||||
|
||||
BatchImageInputNode.displayName = 'BatchImageInputNode';
|
||||
|
||||
const ImageBatchOutputField = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const { isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim } =
|
||||
useConnectionState({ nodeId, fieldName: 'images', kind: 'outputs' });
|
||||
|
||||
return (
|
||||
<Flex
|
||||
position="absolute"
|
||||
minH={8}
|
||||
top="50%"
|
||||
translateY="-50%"
|
||||
insetInlineEnd={2}
|
||||
alignItems="center"
|
||||
opacity={shouldDim ? 0.5 : 1}
|
||||
transitionProperty="opacity"
|
||||
transitionDuration="0.1s"
|
||||
justifyContent="flex-end"
|
||||
>
|
||||
<FieldHandle
|
||||
fieldTemplate={imageBatchOutputFieldTemplate}
|
||||
handleType="source"
|
||||
isConnectionInProgress={isConnectionInProgress}
|
||||
isConnectionStartField={isConnectionStartField}
|
||||
validationResult={validationResult}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
ImageBatchOutputField.displayName = 'ImageBatchOutputField';
|
||||
@@ -1,9 +1,9 @@
|
||||
import { Flex, Icon, Text, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { compare } from 'compare-versions';
|
||||
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
||||
import { useNode } from 'features/nodes/hooks/useNode';
|
||||
import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { isInvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiInfoBold } from 'react-icons/pi';
|
||||
@@ -25,32 +25,32 @@ const InvocationNodeInfoIcon = ({ nodeId }: Props) => {
|
||||
export default memo(InvocationNodeInfoIcon);
|
||||
|
||||
const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const data = useNodeData(nodeId);
|
||||
const node = useNode(nodeId);
|
||||
const nodeTemplate = useNodeTemplate(nodeId);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const title = useMemo(() => {
|
||||
if (data?.label && nodeTemplate?.title) {
|
||||
return `${data.label} (${nodeTemplate.title})`;
|
||||
if (node.data?.label && nodeTemplate?.title) {
|
||||
return `${node.data.label} (${nodeTemplate.title})`;
|
||||
}
|
||||
|
||||
if (data?.label && !nodeTemplate) {
|
||||
return data.label;
|
||||
if (node.data?.label && !nodeTemplate) {
|
||||
return node.data.label;
|
||||
}
|
||||
|
||||
if (!data?.label && nodeTemplate) {
|
||||
if (!node.data?.label && nodeTemplate) {
|
||||
return nodeTemplate.title;
|
||||
}
|
||||
|
||||
return t('nodes.unknownNode');
|
||||
}, [data, nodeTemplate, t]);
|
||||
}, [node.data.label, nodeTemplate, t]);
|
||||
|
||||
const versionComponent = useMemo(() => {
|
||||
if (!isInvocationNodeData(data) || !nodeTemplate) {
|
||||
if (!isInvocationNode(node) || !nodeTemplate) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!data.version) {
|
||||
if (!node.data.version) {
|
||||
return (
|
||||
<Text as="span" color="error.500">
|
||||
{t('nodes.versionUnknown')}
|
||||
@@ -61,35 +61,35 @@ const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
|
||||
if (!nodeTemplate.version) {
|
||||
return (
|
||||
<Text as="span" color="error.500">
|
||||
{t('nodes.version')} {data.version} ({t('nodes.unknownTemplate')})
|
||||
{t('nodes.version')} {node.data.version} ({t('nodes.unknownTemplate')})
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
|
||||
if (compare(data.version, nodeTemplate.version, '<')) {
|
||||
if (compare(node.data.version, nodeTemplate.version, '<')) {
|
||||
return (
|
||||
<Text as="span" color="error.500">
|
||||
{t('nodes.version')} {data.version} ({t('nodes.updateNode')})
|
||||
{t('nodes.version')} {node.data.version} ({t('nodes.updateNode')})
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
|
||||
if (compare(data.version, nodeTemplate.version, '>')) {
|
||||
if (compare(node.data.version, nodeTemplate.version, '>')) {
|
||||
return (
|
||||
<Text as="span" color="error.500">
|
||||
{t('nodes.version')} {data.version} ({t('nodes.updateApp')})
|
||||
{t('nodes.version')} {node.data.version} ({t('nodes.updateApp')})
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Text as="span">
|
||||
{t('nodes.version')} {data.version}
|
||||
{t('nodes.version')} {node.data.version}
|
||||
</Text>
|
||||
);
|
||||
}, [data, nodeTemplate, t]);
|
||||
}, [node, nodeTemplate, t]);
|
||||
|
||||
if (!isInvocationNodeData(data)) {
|
||||
if (!isInvocationNode(node)) {
|
||||
return <Text fontWeight="semibold">{t('nodes.unknownNode')}</Text>;
|
||||
}
|
||||
|
||||
@@ -107,7 +107,7 @@ const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
|
||||
{nodeTemplate?.description}
|
||||
</Text>
|
||||
{versionComponent}
|
||||
{data?.notes && <Text>{data.notes}</Text>}
|
||||
{node.data?.notes && <Text>{node.data.notes}</Text>}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
import { FormControl, FormLabel, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
||||
import { useNode } from 'features/nodes/hooks/useNode';
|
||||
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const NotesTextarea = ({ nodeId }: { nodeId: string }) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const data = useNodeData(nodeId);
|
||||
const node = useNode(nodeId);
|
||||
const { t } = useTranslation();
|
||||
const handleNotesChanged = useCallback(
|
||||
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||
@@ -17,13 +17,13 @@ const NotesTextarea = ({ nodeId }: { nodeId: string }) => {
|
||||
},
|
||||
[dispatch, nodeId]
|
||||
);
|
||||
if (!isInvocationNodeData(data)) {
|
||||
if (!isInvocationNode(node)) {
|
||||
return null;
|
||||
}
|
||||
return (
|
||||
<FormControl orientation="vertical" h="full">
|
||||
<FormLabel>{t('nodes.notes')}</FormLabel>
|
||||
<Textarea value={data?.notes} onChange={handleNotesChanged} rows={10} resize="none" />
|
||||
<Textarea value={node.data?.notes} onChange={handleNotesChanged} rows={10} resize="none" />
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -23,6 +23,8 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
const { isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim } =
|
||||
useConnectionState({ nodeId, fieldName, kind: 'inputs' });
|
||||
|
||||
console.log({ isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim });
|
||||
|
||||
const isMissingInput = useMemo(() => {
|
||||
if (!fieldTemplate) {
|
||||
return false;
|
||||
|
||||
@@ -64,7 +64,7 @@ type OutputFieldWrapperProps = PropsWithChildren<{
|
||||
shouldDim: boolean;
|
||||
}>;
|
||||
|
||||
const OutputFieldWrapper = memo(({ shouldDim, children }: OutputFieldWrapperProps) => (
|
||||
export const OutputFieldWrapper = memo(({ shouldDim, children }: OutputFieldWrapperProps) => (
|
||||
<Flex
|
||||
position="relative"
|
||||
minH={8}
|
||||
|
||||
@@ -2,6 +2,7 @@ import { useStore } from '@nanostores/react';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { NODE_WIDTH } from 'features/nodes/types/constants';
|
||||
import type { AnyNode, InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import { buildBatchInputImageNode } from 'features/nodes/util/node/buildBatchInputImageNode';
|
||||
import { buildCurrentImageNode } from 'features/nodes/util/node/buildCurrentImageNode';
|
||||
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
|
||||
import { buildNotesNode } from 'features/nodes/util/node/buildNotesNode';
|
||||
@@ -14,7 +15,7 @@ export const useBuildNode = () => {
|
||||
|
||||
return useCallback(
|
||||
// string here is "any invocation type"
|
||||
(type: string | 'current_image' | 'notes'): AnyNode => {
|
||||
(type: string | 'current_image' | 'notes' | 'batch_input'): AnyNode => {
|
||||
let _x = window.innerWidth / 2;
|
||||
let _y = window.innerHeight / 2;
|
||||
|
||||
@@ -39,6 +40,10 @@ export const useBuildNode = () => {
|
||||
return buildNotesNode(position);
|
||||
}
|
||||
|
||||
if (type === 'batch_input') {
|
||||
return buildBatchInputImageNode(position);
|
||||
}
|
||||
|
||||
// TODO: Keep track of invocation types so we do not need to cast this
|
||||
// We know it is safe because the caller of this function gets the `type` arg from the list of invocation templates.
|
||||
const template = templates[type] as InvocationTemplate;
|
||||
|
||||
@@ -12,6 +12,11 @@ import {
|
||||
import { selectNodes, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
|
||||
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
|
||||
import {
|
||||
type FieldInputTemplate,
|
||||
type FieldOutputTemplate,
|
||||
imageBatchOutputFieldTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { EdgeChange, OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow';
|
||||
import { useUpdateNodeInternals } from 'reactflow';
|
||||
@@ -33,13 +38,20 @@ export const useConnection = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
const template = templates[node.data.type];
|
||||
if (!template) {
|
||||
return;
|
||||
let fieldTemplate: FieldInputTemplate | FieldOutputTemplate | undefined = undefined;
|
||||
|
||||
if (node.type === 'image_batch' && handleId === 'images' && handleType === 'source') {
|
||||
fieldTemplate = imageBatchOutputFieldTemplate;
|
||||
} else {
|
||||
const template = templates[node.data.type];
|
||||
if (!template) {
|
||||
return;
|
||||
}
|
||||
|
||||
const fieldTemplates = template[handleType === 'source' ? 'outputs' : 'inputs'];
|
||||
fieldTemplate = fieldTemplates[handleId];
|
||||
}
|
||||
|
||||
const fieldTemplates = template[handleType === 'source' ? 'outputs' : 'inputs'];
|
||||
const fieldTemplate = fieldTemplates[handleId];
|
||||
if (!fieldTemplate) {
|
||||
return;
|
||||
}
|
||||
|
||||
19
invokeai/frontend/web/src/features/nodes/hooks/useNode.ts
Normal file
19
invokeai/frontend/web/src/features/nodes/hooks/useNode.ts
Normal file
@@ -0,0 +1,19 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNode, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { useMemo } from 'react';
|
||||
import type { Node } from 'reactflow';
|
||||
|
||||
export const useNode = (nodeId: string): Node => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
return selectNode(nodes, nodeId);
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const node = useAppSelector(selector);
|
||||
|
||||
return node;
|
||||
};
|
||||
@@ -3,6 +3,7 @@ import { createSlice, isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig } from 'app/store/store';
|
||||
import { buildUseBoolean } from 'common/hooks/useBoolean';
|
||||
import { workflowLoaded } from 'features/nodes/store/actions';
|
||||
import type { ImageField } from 'features/nodes/types/common';
|
||||
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
|
||||
import type {
|
||||
BoardFieldValue,
|
||||
@@ -58,7 +59,8 @@ import {
|
||||
zVAEModelFieldValue,
|
||||
} from 'features/nodes/types/field';
|
||||
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
|
||||
import { isImageBatchNode, isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
|
||||
import { uniqBy } from 'lodash-es';
|
||||
import { atom, computed } from 'nanostores';
|
||||
import type { MouseEvent } from 'react';
|
||||
import type { Edge, EdgeChange, NodeChange, Viewport, XYPosition } from 'reactflow';
|
||||
@@ -193,7 +195,7 @@ export const nodesSlice = createSlice({
|
||||
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
|
||||
|
||||
const node = state.nodes?.[nodeIndex];
|
||||
if (!isInvocationNode(node) && !isNotesNode(node)) {
|
||||
if (!isInvocationNode(node) && !isNotesNode(node) && !isImageBatchNode(node)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -382,6 +384,34 @@ export const nodesSlice = createSlice({
|
||||
}
|
||||
node.data.notes = value;
|
||||
},
|
||||
batchImageInputNodeImagesAdded: (state, action: PayloadAction<{ nodeId: string; images: ImageField[] }>) => {
|
||||
const { nodeId, images } = action.payload;
|
||||
const node = state.nodes.find((n) => n.id === nodeId);
|
||||
if (!isImageBatchNode(node)) {
|
||||
return;
|
||||
}
|
||||
node.data.images = uniqBy([...node.data.images, ...images], 'image_name');
|
||||
},
|
||||
batchImageInputNodeImagesRemoved: (state, action: PayloadAction<{ nodeId: string; images: ImageField[] }>) => {
|
||||
const { nodeId, images } = action.payload;
|
||||
const node = state.nodes.find((n) => n.id === nodeId);
|
||||
if (!isImageBatchNode(node)) {
|
||||
return;
|
||||
}
|
||||
const imageNamesToRemove = images.map(({ image_name }) => image_name);
|
||||
node.data.images = uniqBy(
|
||||
node.data.images.filter(({ image_name }) => imageNamesToRemove.includes(image_name)),
|
||||
'image_name'
|
||||
);
|
||||
},
|
||||
batchImageInputNodeReset: (state, action: PayloadAction<{ nodeId: string }>) => {
|
||||
const { nodeId } = action.payload;
|
||||
const node = state.nodes.find((n) => n.id === nodeId);
|
||||
if (!isImageBatchNode(node)) {
|
||||
return;
|
||||
}
|
||||
node.data.images = [];
|
||||
},
|
||||
nodeEditorReset: (state) => {
|
||||
state.nodes = [];
|
||||
state.edges = [];
|
||||
@@ -443,6 +473,9 @@ export const {
|
||||
notesNodeValueChanged,
|
||||
undo,
|
||||
redo,
|
||||
batchImageInputNodeImagesAdded,
|
||||
batchImageInputNodeImagesRemoved,
|
||||
batchImageInputNodeReset,
|
||||
} = nodesSlice.actions;
|
||||
|
||||
export const $cursorPos = atom<XYPosition | null>(null);
|
||||
|
||||
@@ -5,8 +5,15 @@ import type { NodesState } from 'features/nodes/store/types';
|
||||
import type { FieldInputInstance } from 'features/nodes/types/field';
|
||||
import type { InvocationNode, InvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import type { Node } from 'reactflow';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const selectNode = (nodesSlice: NodesState, nodeId: string): Node => {
|
||||
const node = nodesSlice.nodes.find((node) => node.id === nodeId);
|
||||
assert(node !== undefined, `Node ${nodeId} not found`);
|
||||
return node;
|
||||
};
|
||||
|
||||
export const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode => {
|
||||
const node = nodesSlice.nodes.find((node) => node.id === nodeId);
|
||||
assert(isInvocationNode(node), `Node ${nodeId} is not an invocation node`);
|
||||
|
||||
@@ -3,6 +3,7 @@ import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual';
|
||||
import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType';
|
||||
import { getHasCycles } from 'features/nodes/store/util/getHasCycles';
|
||||
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
|
||||
import { imageBatchOutputFieldTemplate } from 'features/nodes/types/field';
|
||||
import type { AnyNode } from 'features/nodes/types/invocation';
|
||||
import type { Connection as NullableConnection, Edge } from 'reactflow';
|
||||
import type { SetNonNullable } from 'type-fest';
|
||||
@@ -78,23 +79,32 @@ export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, temp
|
||||
return buildRejectResult('nodes.missingNode');
|
||||
}
|
||||
|
||||
const sourceTemplate = templates[sourceNode.data.type];
|
||||
if (!sourceTemplate) {
|
||||
return buildRejectResult('nodes.missingInvocationTemplate');
|
||||
}
|
||||
|
||||
const targetTemplate = templates[targetNode.data.type];
|
||||
if (!targetTemplate) {
|
||||
return buildRejectResult('nodes.missingInvocationTemplate');
|
||||
}
|
||||
|
||||
const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle];
|
||||
if (!sourceFieldTemplate) {
|
||||
const targetFieldTemplate = targetTemplate.inputs[c.targetHandle];
|
||||
if (!targetFieldTemplate) {
|
||||
return buildRejectResult('nodes.missingFieldTemplate');
|
||||
}
|
||||
|
||||
const targetFieldTemplate = targetTemplate.inputs[c.targetHandle];
|
||||
if (!targetFieldTemplate) {
|
||||
if (sourceNode.type === 'image_batch') {
|
||||
const isValid = validateConnectionTypes(imageBatchOutputFieldTemplate.type, targetFieldTemplate.type);
|
||||
if (!isValid) {
|
||||
return buildRejectResult('nodes.fieldTypesMustMatch');
|
||||
} else {
|
||||
return buildAcceptResult();
|
||||
}
|
||||
}
|
||||
|
||||
const sourceTemplate = templates[sourceNode.data.type];
|
||||
if (!sourceTemplate) {
|
||||
return buildRejectResult('nodes.missingInvocationTemplate');
|
||||
}
|
||||
|
||||
const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle];
|
||||
if (!sourceFieldTemplate) {
|
||||
return buildRejectResult('nodes.missingFieldTemplate');
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,10 @@ export const validateConnectionTypes = (sourceType: FieldType, targetType: Field
|
||||
return false;
|
||||
}
|
||||
|
||||
if (sourceType.name === 'ImageBatchField') {
|
||||
return isSingle(sourceType) && targetType.name === 'ImageField';
|
||||
}
|
||||
|
||||
if (areTypesEqual(sourceType, targetType)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -59,6 +59,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
|
||||
EnumField: 'blue.500',
|
||||
FloatField: 'orange.500',
|
||||
ImageField: 'purple.500',
|
||||
ImageBatchField: 'purple.500',
|
||||
IntegerField: 'red.500',
|
||||
IPAdapterField: 'teal.500',
|
||||
IPAdapterModelField: 'teal.500',
|
||||
|
||||
@@ -95,6 +95,10 @@ const zImageFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ImageField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zImageBatchFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ImageBatchField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zBoardFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('BoardField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
@@ -175,13 +179,14 @@ const zSchedulerFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SchedulerField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zStatefulFieldType = z.union([
|
||||
export const zStatefulFieldType = z.union([
|
||||
zIntegerFieldType,
|
||||
zFloatFieldType,
|
||||
zStringFieldType,
|
||||
zBooleanFieldType,
|
||||
zEnumFieldType,
|
||||
zImageFieldType,
|
||||
zImageBatchFieldType,
|
||||
zBoardFieldType,
|
||||
zModelIdentifierFieldType,
|
||||
zMainModelFieldType,
|
||||
@@ -367,6 +372,19 @@ export const isImageFieldInputInstance = (val: unknown): val is ImageFieldInputI
|
||||
zImageFieldInputInstance.safeParse(val).success;
|
||||
export const isImageFieldInputTemplate = (val: unknown): val is ImageFieldInputTemplate =>
|
||||
zImageFieldInputTemplate.safeParse(val).success;
|
||||
|
||||
const zImageBatchFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zImageBatchFieldType,
|
||||
});
|
||||
export type ImageBatchOutputFieldTemplate = z.infer<typeof zImageBatchFieldOutputTemplate>;
|
||||
|
||||
export const imageBatchOutputFieldTemplate: ImageBatchOutputFieldTemplate = {
|
||||
fieldKind: 'output',
|
||||
name: 'images',
|
||||
ui_hidden: false,
|
||||
type: { name: 'ImageBatchField', cardinality: 'SINGLE' },
|
||||
title: 'Image',
|
||||
};
|
||||
// #endregion
|
||||
|
||||
// #region BoardField
|
||||
@@ -991,6 +1009,7 @@ const zStatefulFieldOutputTemplate = z.union([
|
||||
zBooleanFieldOutputTemplate,
|
||||
zEnumFieldOutputTemplate,
|
||||
zImageFieldOutputTemplate,
|
||||
zImageBatchFieldOutputTemplate,
|
||||
zBoardFieldOutputTemplate,
|
||||
zModelIdentifierFieldOutputTemplate,
|
||||
zMainModelFieldOutputTemplate,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import type { Edge, Node } from 'reactflow';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { zClassification, zProgressImage } from './common';
|
||||
import { zClassification, zImageField, zProgressImage } from './common';
|
||||
import { zFieldInputInstance, zFieldInputTemplate, zFieldOutputTemplate } from './field';
|
||||
import { zSemVer } from './semver';
|
||||
|
||||
@@ -49,23 +49,33 @@ const zCurrentImageNodeData = z.object({
|
||||
label: z.string(),
|
||||
isOpen: z.boolean(),
|
||||
});
|
||||
const zAnyNodeData = z.union([zInvocationNodeData, zNotesNodeData, zCurrentImageNodeData]);
|
||||
const zImageBatchNodeData = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
type: z.literal('image_batch'),
|
||||
label: z.string(),
|
||||
isOpen: z.boolean(),
|
||||
images: z.array(zImageField),
|
||||
});
|
||||
const zAnyNodeData = z.union([zInvocationNodeData, zNotesNodeData, zCurrentImageNodeData, zImageBatchNodeData]);
|
||||
|
||||
export type NotesNodeData = z.infer<typeof zNotesNodeData>;
|
||||
export type InvocationNodeData = z.infer<typeof zInvocationNodeData>;
|
||||
type CurrentImageNodeData = z.infer<typeof zCurrentImageNodeData>;
|
||||
export type ImageBatchNodeData = z.infer<typeof zImageBatchNodeData>;
|
||||
|
||||
type AnyNodeData = z.infer<typeof zAnyNodeData>;
|
||||
|
||||
export type InvocationNode = Node<InvocationNodeData, 'invocation'>;
|
||||
export type NotesNode = Node<NotesNodeData, 'notes'>;
|
||||
export type CurrentImageNode = Node<CurrentImageNodeData, 'current_image'>;
|
||||
export type BatchImageInputNode = Node<ImageBatchNodeData, 'image_batch'>;
|
||||
export type AnyNode = Node<AnyNodeData>;
|
||||
|
||||
export const isInvocationNode = (node?: AnyNode | null): node is InvocationNode =>
|
||||
Boolean(node && node.type === 'invocation');
|
||||
export const isNotesNode = (node?: AnyNode | null): node is NotesNode => Boolean(node && node.type === 'notes');
|
||||
export const isInvocationNodeData = (node?: AnyNodeData | null): node is InvocationNodeData =>
|
||||
Boolean(node && !['notes', 'current_image'].includes(node.type)); // node.type may be 'notes', 'current_image', or any invocation type
|
||||
export const isImageBatchNode = (node?: AnyNode | null): node is BatchImageInputNode =>
|
||||
Boolean(node && node.type === 'image_batch');
|
||||
// #endregion
|
||||
|
||||
// #region NodeExecutionState
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { NodesState } from 'features/nodes/store/types';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { omit, reduce } from 'lodash-es';
|
||||
import type { AnyInvocation, Graph } from 'services/api/types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
const log = logger('workflows');
|
||||
|
||||
/**
|
||||
* Builds a graph from the node editor state.
|
||||
*/
|
||||
@@ -47,22 +50,31 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
|
||||
return nodesAccumulator;
|
||||
}, {});
|
||||
|
||||
const filteredNodeIds = filteredNodes.map(({ id }) => id);
|
||||
|
||||
// skip out the "dummy" edges between collapsed nodes
|
||||
const filteredEdges = edges.filter((n) => n.type !== 'collapsed');
|
||||
const filteredEdges = edges
|
||||
.filter((edge) => edge.type !== 'collapsed')
|
||||
.filter((edge) => filteredNodeIds.includes(edge.source) && filteredNodeIds.includes(edge.target));
|
||||
|
||||
// Reduce the node editor edges into invocation graph edges
|
||||
const parsedEdges = filteredEdges.reduce<NonNullable<Graph['edges']>>((edgesAccumulator, edge) => {
|
||||
const { source, target, sourceHandle, targetHandle } = edge;
|
||||
|
||||
if (!sourceHandle || !targetHandle) {
|
||||
log.warn({ source, target, sourceHandle, targetHandle }, 'Missing source or taget handle for edge');
|
||||
return edgesAccumulator;
|
||||
}
|
||||
|
||||
// Format the edges and add to the edges array
|
||||
edgesAccumulator.push({
|
||||
source: {
|
||||
node_id: source,
|
||||
field: sourceHandle as string,
|
||||
field: sourceHandle,
|
||||
},
|
||||
destination: {
|
||||
node_id: target,
|
||||
field: targetHandle as string,
|
||||
field: targetHandle,
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
|
||||
import type { BatchImageInputNode } from 'features/nodes/types/invocation';
|
||||
import type { XYPosition } from 'reactflow';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
export const buildBatchInputImageNode = (position: XYPosition): BatchImageInputNode => {
|
||||
const nodeId = uuidv4();
|
||||
const node: BatchImageInputNode = {
|
||||
...SHARED_NODE_PROPERTIES,
|
||||
id: nodeId,
|
||||
type: 'image_batch',
|
||||
position,
|
||||
data: {
|
||||
id: nodeId,
|
||||
isOpen: true,
|
||||
label: 'Image Batch',
|
||||
type: 'image_batch',
|
||||
images: [],
|
||||
},
|
||||
};
|
||||
return node;
|
||||
};
|
||||
@@ -34,8 +34,15 @@ export const UpscaleWarning = () => {
|
||||
dispatch(tileControlnetModelChanged(validModel || null));
|
||||
}, [model?.base, modelConfigs, dispatch]);
|
||||
|
||||
const isBaseModelCompatible = useMemo(() => {
|
||||
return model && ['sd-1', 'sdxl'].includes(model.base);
|
||||
}, [model]);
|
||||
|
||||
const modelWarnings = useMemo(() => {
|
||||
const _warnings: string[] = [];
|
||||
if (!isBaseModelCompatible) {
|
||||
return _warnings;
|
||||
}
|
||||
if (!model) {
|
||||
_warnings.push(t('upscaling.mainModelDesc'));
|
||||
}
|
||||
@@ -46,7 +53,7 @@ export const UpscaleWarning = () => {
|
||||
_warnings.push(t('upscaling.upscaleModelDesc'));
|
||||
}
|
||||
return _warnings;
|
||||
}, [model, tileControlnetModel, upscaleModel, t]);
|
||||
}, [isBaseModelCompatible, model, tileControlnetModel, upscaleModel, t]);
|
||||
|
||||
const otherWarnings = useMemo(() => {
|
||||
const _warnings: string[] = [];
|
||||
@@ -58,22 +65,25 @@ export const UpscaleWarning = () => {
|
||||
return _warnings;
|
||||
}, [isTooLargeToUpscale, t, maxUpscaleDimension]);
|
||||
|
||||
const allWarnings = useMemo(() => [...modelWarnings, ...otherWarnings], [modelWarnings, otherWarnings]);
|
||||
|
||||
const handleGoToModelManager = useCallback(() => {
|
||||
dispatch(setActiveTab('models'));
|
||||
$installModelsTab.set(3);
|
||||
}, [dispatch]);
|
||||
|
||||
if (modelWarnings.length && isModelsTabDisabled) {
|
||||
if (isBaseModelCompatible && modelWarnings.length > 0 && isModelsTabDisabled) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if ((!modelWarnings.length && !otherWarnings.length) || isLoading) {
|
||||
if ((isBaseModelCompatible && allWarnings.length === 0) || isLoading) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex bg="error.500" borderRadius="base" padding={4} direction="column" fontSize="sm" gap={2}>
|
||||
{!!modelWarnings.length && (
|
||||
{!isBaseModelCompatible && <Text>{t('upscaling.incompatibleBaseModelDesc')}</Text>}
|
||||
{modelWarnings.length > 0 && (
|
||||
<Text>
|
||||
<Trans
|
||||
i18nKey="upscaling.missingModelsWarning"
|
||||
@@ -85,11 +95,13 @@ export const UpscaleWarning = () => {
|
||||
/>
|
||||
</Text>
|
||||
)}
|
||||
<UnorderedList>
|
||||
{[...modelWarnings, ...otherWarnings].map((warning) => (
|
||||
<ListItem key={warning}>{warning}</ListItem>
|
||||
))}
|
||||
</UnorderedList>
|
||||
{allWarnings.length > 0 && (
|
||||
<UnorderedList>
|
||||
{allWarnings.map((warning) => (
|
||||
<ListItem key={warning}>{warning}</ListItem>
|
||||
))}
|
||||
</UnorderedList>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1822,7 +1822,7 @@ export type components = {
|
||||
* Items
|
||||
* @description The list of items to substitute into the node/field.
|
||||
*/
|
||||
items?: (string | number)[];
|
||||
items?: (string | number | components["schemas"]["ImageField"])[];
|
||||
};
|
||||
/**
|
||||
* BatchEnqueuedEvent
|
||||
@@ -6751,6 +6751,12 @@ export type components = {
|
||||
* @default 1
|
||||
*/
|
||||
denoising_end?: number;
|
||||
/**
|
||||
* Add Noise
|
||||
* @description Add noise based on denoising start.
|
||||
* @default true
|
||||
*/
|
||||
add_noise?: boolean;
|
||||
/**
|
||||
* Transformer
|
||||
* @description Flux model (Transformer) to load
|
||||
@@ -13161,7 +13167,7 @@ export type components = {
|
||||
* Value
|
||||
* @description The value to substitute into the node/field.
|
||||
*/
|
||||
value: string | number;
|
||||
value: string | number | components["schemas"]["ImageField"];
|
||||
};
|
||||
/**
|
||||
* Noise
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# State dict keys and shapes for an XLabs FLUX IP-Adapter model. Intended to be used for unit tests.
|
||||
# These keys were extracted from:
|
||||
# https://huggingface.co/XLabs-AI/flux-ip-adapter/blob/ad16be50d78a07ea83d8c4bde44ff9753235182e/flux-ip-adapter.safetensors
|
||||
# https://huggingface.co/XLabs-AI/flux-ip-adapter/resolve/main/ip_adapter.safetensors
|
||||
xlabs_sd_shapes = {
|
||||
"double_blocks.0.processor.ip_adapter_double_stream_k_proj.bias": [3072],
|
||||
"double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
|
||||
|
||||
Reference in New Issue
Block a user