Compare commits

..

6 Commits

Author SHA1 Message Date
psychedelicious
c3304dcbe3 feat(ui): add image batching to workflows (wip) 2024-11-13 14:23:04 -08:00
Darrell
fb19621361 Updated link to flux ip adapter model 2024-11-12 08:11:40 -05:00
psychedelicious
3f880496f7 feat(ui): clarify denoising strength badge text 2024-11-09 08:38:41 +11:00
psychedelicious
79eb8172b6 feat(ui): update warnings on upscaling tab based on model arch
When an unsupported model architecture is selected, show that warning only, without the extra warnings (i.e. no "missing tile controlnet" warning)

Update Invoke tooltip warnings accordingly

Closes #7239
Closes #7177
2024-11-09 07:34:03 +11:00
psychedelicious
5b3e1593ca fix(ui): restore missing image paste handler
Missed migrating this logic over during dnd migration.
2024-11-08 16:42:39 +11:00
psychedelicious
2d08078a7d fix(ui): fit bbox to layers math 2024-11-08 16:40:24 +11:00
35 changed files with 658 additions and 139 deletions

View File

@@ -16,6 +16,7 @@ from pydantic import (
from pydantic_core import to_jsonable_python
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.invocations.fields import ImageField
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError
from invokeai.app.services.workflow_records.workflow_records_common import (
WorkflowWithoutID,
@@ -51,11 +52,7 @@ class SessionQueueItemNotFoundError(ValueError):
# region Batch
BatchDataType = Union[
StrictStr,
float,
int,
]
BatchDataType = Union[StrictStr, float, int, ImageField]
class NodeFieldValue(BaseModel):

View File

@@ -300,7 +300,7 @@ ip_adapter_sdxl = StarterModel(
ip_adapter_flux = StarterModel(
name="Standard Reference (XLabs FLUX IP-Adapter)",
base=BaseModelType.Flux,
source="https://huggingface.co/XLabs-AI/flux-ip-adapter/resolve/main/flux-ip-adapter.safetensors",
source="https://huggingface.co/XLabs-AI/flux-ip-adapter/resolve/main/ip_adapter.safetensors",
description="References images with a more generalized/looser degree of precision.",
type=ModelType.IPAdapter,
dependencies=[clip_vit_l_image_encoder],

View File

@@ -997,7 +997,7 @@
"controlNetControlMode": "Control Mode",
"copyImage": "Copy Image",
"denoisingStrength": "Denoising Strength",
"noRasterLayers": "No Raster Layers",
"disabledNoRasterContent": "Disabled (No Raster Content)",
"downloadImage": "Download Image",
"general": "General",
"guidance": "Guidance",
@@ -1999,7 +1999,9 @@
"upscaleModelDesc": "Upscale (image to image) model",
"missingUpscaleInitialImage": "Missing initial image for upscaling",
"missingUpscaleModel": "Missing upscale model",
"missingTileControlNetModel": "No valid tile ControlNet models installed"
"missingTileControlNetModel": "No valid tile ControlNet models installed",
"incompatibleBaseModel": "Unsupported main model architecture for upscaling",
"incompatibleBaseModelDesc": "Upscaling is supported for SD1.5 and SDXL architecture models only. Change the main model to enable upscaling."
},
"stylePresets": {
"active": "Active",

View File

@@ -1,10 +1,11 @@
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import { isImageBatchNode } from 'features/nodes/types/invocation';
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig } from 'services/api/types';
import type { Batch, BatchConfig } from 'services/api/types';
export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) => {
startAppListening({
@@ -26,6 +27,18 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
delete builtWorkflow.id;
}
const data: Batch['data'] = [];
// Skip edges from batch nodes - these should not be in the graph, they exist only in the UI
const imageBatchNodes = nodes.nodes.filter(isImageBatchNode);
for (const imageBatch of imageBatchNodes) {
const edge = nodes.edges.find((e) => e.source === imageBatch.id);
if (!edge || !edge.targetHandle) {
break;
}
data.push([{ node_path: edge.target, field_name: edge.targetHandle, items: imageBatch.data.images }]);
}
const batchConfig: BatchConfig = {
batch: {
graph,
@@ -33,6 +46,7 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
runs: state.params.iterations,
origin: 'workflows',
destination: 'gallery',
data,
},
prepend: action.payload.prepend,
};

View File

@@ -119,11 +119,20 @@ const createSelector = (
reasons.push({ content: i18n.t('upscaling.exceedsMaxSize') });
}
}
if (!upscale.upscaleModel) {
reasons.push({ content: i18n.t('upscaling.missingUpscaleModel') });
}
if (!upscale.tileControlnetModel) {
reasons.push({ content: i18n.t('upscaling.missingTileControlNetModel') });
if (model && !['sd-1', 'sdxl'].includes(model.base)) {
// When we are using an upsupported model, do not add the other warnings
reasons.push({ content: i18n.t('upscaling.incompatibleBaseModel') });
} else {
// Using a compatible model, add all warnings
if (!model) {
reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') });
}
if (!upscale.upscaleModel) {
reasons.push({ content: i18n.t('upscaling.missingUpscaleModel') });
}
if (!upscale.tileControlnetModel) {
reasons.push({ content: i18n.t('upscaling.missingTileControlNetModel') });
}
}
} else {
if (canvasIsFiltering) {

View File

@@ -17,12 +17,15 @@ import { selectImg2imgStrengthConfig } from 'features/system/store/configSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const selectIsEnabled = createSelector(selectActiveRasterLayerEntities, (entities) => entities.length > 0);
const selectHasRasterLayersWithContent = createSelector(
selectActiveRasterLayerEntities,
(entities) => entities.length > 0
);
export const ParamDenoisingStrength = memo(() => {
const img2imgStrength = useAppSelector(selectImg2imgStrength);
const dispatch = useAppDispatch();
const isEnabled = useAppSelector(selectIsEnabled);
const hasRasterLayersWithContent = useAppSelector(selectHasRasterLayersWithContent);
const onChange = useCallback(
(v: number) => {
@@ -37,16 +40,16 @@ export const ParamDenoisingStrength = memo(() => {
const [invokeBlue300] = useToken('colors', ['invokeBlue.300']);
return (
<FormControl isDisabled={!isEnabled} p={1} justifyContent="space-between" h={8}>
<FormControl isDisabled={!hasRasterLayersWithContent} p={1} justifyContent="space-between" h={8}>
<Flex gap={3} alignItems="center">
<InformationalPopover feature="paramDenoisingStrength">
<FormLabel mr={0}>{`${t('parameters.denoisingStrength')}`}</FormLabel>
</InformationalPopover>
{isEnabled && (
{hasRasterLayersWithContent && (
<WavyLine amplitude={img2imgStrength * 10} stroke={invokeBlue300} strokeWidth={1} width={40} height={14} />
)}
</Flex>
{isEnabled ? (
{hasRasterLayersWithContent ? (
<>
<CompositeSlider
step={config.coarseStep}
@@ -70,9 +73,7 @@ export const ParamDenoisingStrength = memo(() => {
</>
) : (
<Flex alignItems="center">
<Badge opacity="0.6">
{t('common.disabled')} - {t('parameters.noRasterLayers')}
</Badge>
<Badge opacity="0.6">{t('parameters.disabledNoRasterContent')}</Badge>
</Flex>
)}
</FormControl>

View File

@@ -1,13 +1,8 @@
import {
roundDownToMultiple,
roundToMultiple,
roundToMultipleMin,
roundUpToMultiple,
} from 'common/util/roundDownToMultiple';
import { roundToMultiple, roundToMultipleMin } from 'common/util/roundDownToMultiple';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
import { getKonvaNodeDebugAttrs, getPrefixedId } from 'features/controlLayers/konva/util';
import { fitRectToGrid, getKonvaNodeDebugAttrs, getPrefixedId } from 'features/controlLayers/konva/util';
import { selectBboxOverlay } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectBbox } from 'features/controlLayers/store/selectors';
import type { Coordinate, Rect } from 'features/controlLayers/store/types';
@@ -398,18 +393,12 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
}
// Determine the bbox size that fits within the visible rect. The bbox must be at least 64px in width and height,
// and its width and height must be multiples of 8px.
// and its width and height must be multiples of the bbox grid size.
const gridSize = this.manager.stateApi.getBboxGridSize();
// To be conservative, we will round up the x and y to the nearest grid size, and round down the width and height.
// This ensures the bbox is never _larger_ than the visible rect. If the bbox is larger than the visible, we
// will always trigger the outpainting workflow, which is not what the user wants.
const x = roundUpToMultiple(visibleRect.x, gridSize);
const y = roundUpToMultiple(visibleRect.y, gridSize);
const width = roundDownToMultiple(visibleRect.width, gridSize);
const height = roundDownToMultiple(visibleRect.height, gridSize);
const rect = fitRectToGrid(visibleRect, gridSize);
this.manager.stateApi.setGenerationBbox({ x, y, width, height });
this.manager.stateApi.setGenerationBbox(rect);
};
/**

View File

@@ -1,4 +1,6 @@
import { getPrefixedId, getRectUnion } from 'features/controlLayers/konva/util';
import { roundUpToMultiple } from 'common/util/roundDownToMultiple';
import { fitRectToGrid, getPrefixedId, getRectUnion } from 'features/controlLayers/konva/util';
import type { Rect } from 'features/controlLayers/store/types';
import { describe, expect, it } from 'vitest';
describe('util', () => {
@@ -44,4 +46,74 @@ describe('util', () => {
expect(union).toEqual({ x: 0, y: 0, width: 0, height: 0 });
});
});
describe('fitRectToGrid', () => {
it('should fit rect within grid without exceeding bounds', () => {
const rect: Rect = { x: 0, y: 0, width: 1047, height: 1758 };
const gridSize = 50;
const result = fitRectToGrid(rect, gridSize);
expect(result.x).toBe(roundUpToMultiple(rect.x, gridSize));
expect(result.y).toBe(roundUpToMultiple(rect.y, gridSize));
expect(result.width).toBeLessThanOrEqual(rect.width);
expect(result.height).toBeLessThanOrEqual(rect.height);
expect(result.width % gridSize).toBe(0);
expect(result.height % gridSize).toBe(0);
});
it('should handle small rect within grid bounds', () => {
const rect: Rect = { x: 20, y: 30, width: 80, height: 90 };
const gridSize = 25;
const result = fitRectToGrid(rect, gridSize);
expect(result.x).toBe(25);
expect(result.y).toBe(50);
expect(result.width % gridSize).toBe(0);
expect(result.height % gridSize).toBe(0);
expect(result.width).toBeLessThanOrEqual(rect.width);
expect(result.height).toBeLessThanOrEqual(rect.height);
});
it('should handle rect starting outside of grid alignment', () => {
const rect: Rect = { x: 13, y: 27, width: 94, height: 112 };
const gridSize = 20;
const result = fitRectToGrid(rect, gridSize);
expect(result.x).toBe(20);
expect(result.y).toBe(40);
expect(result.width % gridSize).toBe(0);
expect(result.height % gridSize).toBe(0);
expect(result.width).toBeLessThanOrEqual(rect.width);
expect(result.height).toBeLessThanOrEqual(rect.height);
});
it('should return the same rect if already aligned to grid', () => {
const rect: Rect = { x: 100, y: 100, width: 200, height: 300 };
const gridSize = 50;
const result = fitRectToGrid(rect, gridSize);
expect(result).toEqual(rect);
});
it('should handle large grid sizes relative to rect dimensions', () => {
const rect: Rect = { x: 250, y: 300, width: 400, height: 500 };
const gridSize = 100;
const result = fitRectToGrid(rect, gridSize);
expect(result.x).toBe(300);
expect(result.y).toBe(300);
expect(result.width % gridSize).toBe(0);
expect(result.height % gridSize).toBe(0);
expect(result.width).toBeLessThanOrEqual(rect.width);
expect(result.height).toBeLessThanOrEqual(rect.height);
});
it('should handle rect with zero width and height', () => {
const rect: Rect = { x: 40, y: 60, width: 100, height: 200 };
const gridSize = 20;
const result = fitRectToGrid(rect, gridSize);
expect(result).toEqual({ x: 40, y: 60, width: 100, height: 200 });
});
});
});

View File

@@ -1,5 +1,6 @@
import type { Selector, Store } from '@reduxjs/toolkit';
import { $authToken } from 'app/store/nanostores/authToken';
import { roundDownToMultiple, roundUpToMultiple } from 'common/util/roundDownToMultiple';
import type {
CanvasEntityIdentifier,
CanvasObjectState,
@@ -560,6 +561,33 @@ export const getRectIntersection = (...rects: Rect[]): Rect => {
return rect || getEmptyRect();
};
/**
* Fits a rect to the nearest multiple of the grid size, rounding down. The returned rect will be smaller than or equal
* to the input rect, and will be aligned to the grid.
*
* In other words, shrink the rect inwards on each size until it fits within the visible rect and aligns to the grid.
*
* @param rect The rect to fit
* @param gridSize The size of the grid
* @returns The fitted rect
*/
export const fitRectToGrid = (rect: Rect, gridSize: number): Rect => {
// Rounding x and y up effectively shrinks the left and top edges of the rect, and rounding width and height down
// effectively shrinks the right and bottom edges.
const x = roundUpToMultiple(rect.x, gridSize);
const y = roundUpToMultiple(rect.y, gridSize);
// Because we've just shifted the rect's x and y, we need to adjust the width and height by the same amount before
// we round those values down.
const offsetX = x - rect.x;
const offsetY = y - rect.y;
const width = roundDownToMultiple(rect.width - offsetX, gridSize);
const height = roundDownToMultiple(rect.height - offsetY, gridSize);
return { x, y, width, height };
};
/**
* Asserts that the value is never reached. Used for exhaustive checks in switch statements or conditional logic to ensure
* that all possible values are handled.

View File

@@ -20,12 +20,15 @@ const sx = {
},
} satisfies SystemStyleObject;
type Props = ImageProps & {
imageDTO: ImageDTO;
asThumbnail?: boolean;
};
/* eslint-disable-next-line @typescript-eslint/no-namespace */
export namespace DndImage {
export interface Props extends ImageProps {
imageDTO: ImageDTO;
asThumbnail?: boolean;
}
}
export const DndImage = memo(({ imageDTO, asThumbnail, ...rest }: Props) => {
export const DndImage = memo(({ imageDTO, asThumbnail, ...rest }: DndImage.Props) => {
const store = useAppStore();
const [isDragging, setIsDragging] = useState(false);
const [element, ref] = useState<HTMLImageElement | null>(null);

View File

@@ -0,0 +1,26 @@
import { IAINoContentFallback, IAINoContentFallbackWithSpinner } from 'common/components/IAIImageFallback';
import { DndImage } from 'features/dnd/DndImage';
import { memo } from 'react';
import { PiExclamationMarkBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
/* eslint-disable-next-line @typescript-eslint/no-namespace */
namespace DndImageFromImageName {
export interface Props extends Omit<DndImage.Props, 'imageDTO'> {
imageName: string;
}
}
export const DndImageFromImageName = memo(({ imageName, ...rest }: DndImageFromImageName.Props) => {
const query = useGetImageDTOQuery(imageName);
if (query.isLoading) {
return <IAINoContentFallbackWithSpinner />;
}
if (!query.data) {
return <IAINoContentFallback icon={<PiExclamationMarkBold />} />;
}
return <DndImage imageDTO={query.data} {...rest} />;
});
DndImageFromImageName.displayName = 'DndImageFromImageName';

View File

@@ -11,7 +11,7 @@ import type { DndTargetState } from 'features/dnd/types';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
import { toast } from 'features/toast/toast';
import { memo, useEffect, useRef, useState } from 'react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { type UploadImageArg, uploadImages } from 'services/api/endpoints/images';
import { useBoardName } from 'services/api/hooks/useBoardName';
@@ -71,13 +71,46 @@ export const FullscreenDropzone = memo(() => {
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
const [dndState, setDndState] = useState<DndTargetState>('idle');
const uploadFilesSchema = useMemo(() => getFilesSchema(maxImageUploadCount), [maxImageUploadCount]);
const validateAndUploadFiles = useCallback(
(files: File[]) => {
const { getState } = getStore();
const parseResult = uploadFilesSchema.safeParse(files);
if (!parseResult.success) {
const description =
maxImageUploadCount === undefined
? t('toast.uploadFailedInvalidUploadDesc')
: t('toast.uploadFailedInvalidUploadDesc_withCount', { count: maxImageUploadCount });
toast({
id: 'UPLOAD_FAILED',
title: t('toast.uploadFailed'),
description,
status: 'error',
});
return;
}
const autoAddBoardId = selectAutoAddBoardId(getState());
const uploadArgs: UploadImageArg[] = files.map((file) => ({
file,
image_category: 'user',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
}));
uploadImages(uploadArgs);
},
[maxImageUploadCount, t, uploadFilesSchema]
);
useEffect(() => {
const element = ref.current;
if (!element) {
return;
}
const { getState } = getStore();
const uploadFilesSchema = getFilesSchema(maxImageUploadCount);
return combine(
dropTargetForExternal({
@@ -85,32 +118,7 @@ export const FullscreenDropzone = memo(() => {
canDrop: containsFiles,
onDrop: ({ source }) => {
const files = getFiles({ source });
const parseResult = uploadFilesSchema.safeParse(files);
if (!parseResult.success) {
const description =
maxImageUploadCount === undefined
? t('toast.uploadFailedInvalidUploadDesc')
: t('toast.uploadFailedInvalidUploadDesc_withCount', { count: maxImageUploadCount });
toast({
id: 'UPLOAD_FAILED',
title: t('toast.uploadFailed'),
description,
status: 'error',
});
return;
}
const autoAddBoardId = selectAutoAddBoardId(getState());
const uploadArgs: UploadImageArg[] = files.map((file) => ({
file,
image_category: 'user',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
}));
uploadImages(uploadArgs);
validateAndUploadFiles(files);
},
onDragEnter: () => {
setDndState('over');
@@ -131,7 +139,27 @@ export const FullscreenDropzone = memo(() => {
},
})
);
}, [maxImageUploadCount, t]);
}, [validateAndUploadFiles]);
useEffect(() => {
const controller = new AbortController();
document.addEventListener(
'paste',
(e) => {
if (!e.clipboardData?.files) {
return;
}
const files = Array.from(e.clipboardData.files);
validateAndUploadFiles(files);
},
{ signal: controller.signal }
);
return () => {
controller.abort();
};
}, [validateAndUploadFiles]);
return (
<Box ref={ref} data-dnd-state={dndState} sx={sx}>

View File

@@ -18,6 +18,7 @@ import {
setRegionalGuidanceReferenceImage,
setUpscaleInitialImage,
} from 'features/imageActions/actions';
import { batchImageInputNodeImagesAdded } from 'features/nodes/store/nodesSlice';
import type { FieldIdentifier } from 'features/nodes/types/field';
import type { ImageDTO } from 'services/api/types';
import type { JsonObject } from 'type-fest';
@@ -241,6 +242,47 @@ export const setNodeImageFieldImageDndTarget: DndTarget<SetNodeImageFieldImageDn
};
//#endregion
//#region Add Images to Batch Image Input Node
const _addImagesToBatchImageInputNode = buildTypeAndKey('add-images-to-batch-image-input-node');
export type AddImagesToBatchImageInputNodeDndTargetData = DndData<
typeof _addImagesToBatchImageInputNode.type,
typeof _addImagesToBatchImageInputNode.key,
{ nodeId: string }
>;
export const addImagesToBatchImageInputNodeDndTarget: DndTarget<
AddImagesToBatchImageInputNodeDndTargetData,
SingleImageDndSourceData | MultipleImageDndSourceData
> = {
..._addImagesToBatchImageInputNode,
typeGuard: buildTypeGuard(_addImagesToBatchImageInputNode.key),
getData: buildGetData(_addImagesToBatchImageInputNode.key, _addImagesToBatchImageInputNode.type),
isValid: ({ sourceData }) => {
if (singleImageDndSource.typeGuard(sourceData) || multipleImageDndSource.typeGuard(sourceData)) {
return true;
}
return false;
},
handler: ({ sourceData, targetData, dispatch }) => {
if (!singleImageDndSource.typeGuard(sourceData) && !multipleImageDndSource.typeGuard(sourceData)) {
return;
}
const { nodeId } = targetData.payload;
const imageDTOs: ImageDTO[] = [];
if (singleImageDndSource.typeGuard(sourceData)) {
imageDTOs.push(sourceData.payload.imageDTO);
} else {
imageDTOs.push(...sourceData.payload.imageDTOs);
}
const images = imageDTOs.map(({ image_name }) => ({ image_name }));
dispatch(batchImageInputNodeImagesAdded({ nodeId, images }));
},
};
//#endregion
//# Set Comparison Image
const _setComparisonImage = buildTypeAndKey('set-comparison-image');
export type SetComparisonImageDndTargetData = DndData<
@@ -430,6 +472,7 @@ export const dndTargets = [
// Single or Multiple Image
addImageToBoardDndTarget,
removeImageFromBoardDndTarget,
addImagesToBatchImageInputNodeDndTarget,
] as const;
export type AnyDndTarget = (typeof dndTargets)[number];

View File

@@ -347,6 +347,18 @@ const NodeCommandList = memo(({ searchTerm, onSelect }: { searchTerm: string; on
}),
[t]
);
const batchInputImageItem = useMemo<FilterableItem>(
() =>
({
type: 'batch_input',
title: 'batch_input',
description: 'batch_input',
tags: ['batch'],
classification: 'stable',
nodePack: 'invokeai',
}) as const,
[]
);
const items = useMemo<NodeCommandItemData[]>(() => {
// If we have a connection in progress, we need to filter the node choices
@@ -365,7 +377,7 @@ const NodeCommandList = memo(({ searchTerm, onSelect }: { searchTerm: string; on
}
}
for (const item of [currentImageFilterItem, notesFilterItem]) {
for (const item of [currentImageFilterItem, notesFilterItem, batchInputImageItem]) {
if (filter(item, searchTerm)) {
_items.push({
label: item.title,
@@ -403,7 +415,7 @@ const NodeCommandList = memo(({ searchTerm, onSelect }: { searchTerm: string; on
}
return _items;
}, [pendingConnection, currentImageFilterItem, searchTerm, notesFilterItem, templatesArray]);
}, [pendingConnection, templatesArray, searchTerm, currentImageFilterItem, notesFilterItem, batchInputImageItem]);
return (
<>

View File

@@ -2,6 +2,7 @@ import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
import { BatchImageInputNode } from 'features/nodes/components/flow/nodes/BatchInput/BatchImageInputNode';
import { useConnection } from 'features/nodes/hooks/useConnection';
import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste';
import { useSyncExecutionState } from 'features/nodes/hooks/useExecutionState';
@@ -64,6 +65,7 @@ const nodeTypes = {
invocation: InvocationNodeWrapper,
current_image: CurrentImageNode,
notes: NotesNode,
image_batch: BatchImageInputNode,
};
// TODO: can we support reactflow? if not, we could style the attribution so it matches the app

View File

@@ -0,0 +1,119 @@
import { Box, Flex, Grid, GridItem, IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import type { AddImagesToBatchImageInputNodeDndTargetData } from 'features/dnd/dnd';
import { addImagesToBatchImageInputNodeDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { DndImageFromImageName } from 'features/dnd/DndImageFromImageName';
import NodeCollapseButton from 'features/nodes/components/flow/nodes/common/NodeCollapseButton';
import NodeTitle from 'features/nodes/components/flow/nodes/common/NodeTitle';
import NodeWrapper from 'features/nodes/components/flow/nodes/common/NodeWrapper';
import FieldHandle from 'features/nodes/components/flow/nodes/Invocation/fields/FieldHandle';
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
import { batchImageInputNodeReset } from 'features/nodes/store/nodesSlice';
import { imageBatchOutputFieldTemplate } from 'features/nodes/types/field';
import type { ImageBatchNodeData } from 'features/nodes/types/invocation';
import { memo, useCallback, useMemo } from 'react';
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
import type { NodeProps } from 'reactflow';
export const BatchImageInputNode = memo((props: NodeProps<ImageBatchNodeData>) => {
const { id: nodeId, data, selected } = props;
const { images, isOpen } = data;
const dispatch = useAppDispatch();
const onReset = useCallback(() => {
dispatch(batchImageInputNodeReset({ nodeId }));
}, [dispatch, nodeId]);
const targetData = useMemo<AddImagesToBatchImageInputNodeDndTargetData>(
() => addImagesToBatchImageInputNodeDndTarget.getData({ nodeId }),
[nodeId]
);
return (
<NodeWrapper nodeId={nodeId} selected={selected}>
<Flex
layerStyle="nodeHeader"
borderTopRadius="base"
borderBottomRadius={isOpen ? 0 : 'base'}
alignItems="center"
justifyContent="space-between"
h={8}
>
<NodeCollapseButton nodeId={nodeId} isOpen={isOpen} />
<NodeTitle nodeId={nodeId} title="Batch Image Input" />
<Box minW={8} />
</Flex>
{isOpen && (
<>
<Flex
position="relative"
layerStyle="nodeBody"
className="nopan"
cursor="auto"
flexDirection="column"
borderBottomRadius="base"
w="full"
h="full"
p={2}
gap={1}
minH={16}
>
<Grid className="nopan" w="full" h="full" templateColumns="repeat(3, 1fr)" gap={2}>
{images.map(({ image_name }) => (
<GridItem key={image_name}>
<DndImageFromImageName imageName={image_name} asThumbnail />
</GridItem>
))}
</Grid>
<IconButton
aria-label="reset"
icon={<PiArrowCounterClockwiseBold />}
position="absolute"
top={0}
insetInlineEnd={0}
onClick={onReset}
variant="ghost"
/>
</Flex>
</>
)}
<ImageBatchOutputField nodeId={nodeId} />
<DndDropTarget
dndTarget={addImagesToBatchImageInputNodeDndTarget}
dndTargetData={targetData}
label="Add to Batch"
/>
</NodeWrapper>
);
});
BatchImageInputNode.displayName = 'BatchImageInputNode';
const ImageBatchOutputField = memo(({ nodeId }: { nodeId: string }) => {
const { isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim } =
useConnectionState({ nodeId, fieldName: 'images', kind: 'outputs' });
return (
<Flex
position="absolute"
minH={8}
top="50%"
translateY="-50%"
insetInlineEnd={2}
alignItems="center"
opacity={shouldDim ? 0.5 : 1}
transitionProperty="opacity"
transitionDuration="0.1s"
justifyContent="flex-end"
>
<FieldHandle
fieldTemplate={imageBatchOutputFieldTemplate}
handleType="source"
isConnectionInProgress={isConnectionInProgress}
isConnectionStartField={isConnectionStartField}
validationResult={validationResult}
/>
</Flex>
);
});
ImageBatchOutputField.displayName = 'ImageBatchOutputField';

View File

@@ -1,9 +1,9 @@
import { Flex, Icon, Text, Tooltip } from '@invoke-ai/ui-library';
import { compare } from 'compare-versions';
import { useNodeData } from 'features/nodes/hooks/useNodeData';
import { useNode } from 'features/nodes/hooks/useNode';
import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { isInvocationNodeData } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiInfoBold } from 'react-icons/pi';
@@ -25,32 +25,32 @@ const InvocationNodeInfoIcon = ({ nodeId }: Props) => {
export default memo(InvocationNodeInfoIcon);
const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
const data = useNodeData(nodeId);
const node = useNode(nodeId);
const nodeTemplate = useNodeTemplate(nodeId);
const { t } = useTranslation();
const title = useMemo(() => {
if (data?.label && nodeTemplate?.title) {
return `${data.label} (${nodeTemplate.title})`;
if (node.data?.label && nodeTemplate?.title) {
return `${node.data.label} (${nodeTemplate.title})`;
}
if (data?.label && !nodeTemplate) {
return data.label;
if (node.data?.label && !nodeTemplate) {
return node.data.label;
}
if (!data?.label && nodeTemplate) {
if (!node.data?.label && nodeTemplate) {
return nodeTemplate.title;
}
return t('nodes.unknownNode');
}, [data, nodeTemplate, t]);
}, [node.data.label, nodeTemplate, t]);
const versionComponent = useMemo(() => {
if (!isInvocationNodeData(data) || !nodeTemplate) {
if (!isInvocationNode(node) || !nodeTemplate) {
return null;
}
if (!data.version) {
if (!node.data.version) {
return (
<Text as="span" color="error.500">
{t('nodes.versionUnknown')}
@@ -61,35 +61,35 @@ const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
if (!nodeTemplate.version) {
return (
<Text as="span" color="error.500">
{t('nodes.version')} {data.version} ({t('nodes.unknownTemplate')})
{t('nodes.version')} {node.data.version} ({t('nodes.unknownTemplate')})
</Text>
);
}
if (compare(data.version, nodeTemplate.version, '<')) {
if (compare(node.data.version, nodeTemplate.version, '<')) {
return (
<Text as="span" color="error.500">
{t('nodes.version')} {data.version} ({t('nodes.updateNode')})
{t('nodes.version')} {node.data.version} ({t('nodes.updateNode')})
</Text>
);
}
if (compare(data.version, nodeTemplate.version, '>')) {
if (compare(node.data.version, nodeTemplate.version, '>')) {
return (
<Text as="span" color="error.500">
{t('nodes.version')} {data.version} ({t('nodes.updateApp')})
{t('nodes.version')} {node.data.version} ({t('nodes.updateApp')})
</Text>
);
}
return (
<Text as="span">
{t('nodes.version')} {data.version}
{t('nodes.version')} {node.data.version}
</Text>
);
}, [data, nodeTemplate, t]);
}, [node, nodeTemplate, t]);
if (!isInvocationNodeData(data)) {
if (!isInvocationNode(node)) {
return <Text fontWeight="semibold">{t('nodes.unknownNode')}</Text>;
}
@@ -107,7 +107,7 @@ const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
{nodeTemplate?.description}
</Text>
{versionComponent}
{data?.notes && <Text>{data.notes}</Text>}
{node.data?.notes && <Text>{node.data.notes}</Text>}
</Flex>
);
});

View File

@@ -1,15 +1,15 @@
import { FormControl, FormLabel, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useNodeData } from 'features/nodes/hooks/useNodeData';
import { useNode } from 'features/nodes/hooks/useNode';
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
import { isInvocationNodeData } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const NotesTextarea = ({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const data = useNodeData(nodeId);
const node = useNode(nodeId);
const { t } = useTranslation();
const handleNotesChanged = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
@@ -17,13 +17,13 @@ const NotesTextarea = ({ nodeId }: { nodeId: string }) => {
},
[dispatch, nodeId]
);
if (!isInvocationNodeData(data)) {
if (!isInvocationNode(node)) {
return null;
}
return (
<FormControl orientation="vertical" h="full">
<FormLabel>{t('nodes.notes')}</FormLabel>
<Textarea value={data?.notes} onChange={handleNotesChanged} rows={10} resize="none" />
<Textarea value={node.data?.notes} onChange={handleNotesChanged} rows={10} resize="none" />
</FormControl>
);
};

View File

@@ -23,6 +23,8 @@ const InputField = ({ nodeId, fieldName }: Props) => {
const { isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim } =
useConnectionState({ nodeId, fieldName, kind: 'inputs' });
console.log({ isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim });
const isMissingInput = useMemo(() => {
if (!fieldTemplate) {
return false;

View File

@@ -64,7 +64,7 @@ type OutputFieldWrapperProps = PropsWithChildren<{
shouldDim: boolean;
}>;
const OutputFieldWrapper = memo(({ shouldDim, children }: OutputFieldWrapperProps) => (
export const OutputFieldWrapper = memo(({ shouldDim, children }: OutputFieldWrapperProps) => (
<Flex
position="relative"
minH={8}

View File

@@ -2,6 +2,7 @@ import { useStore } from '@nanostores/react';
import { $templates } from 'features/nodes/store/nodesSlice';
import { NODE_WIDTH } from 'features/nodes/types/constants';
import type { AnyNode, InvocationTemplate } from 'features/nodes/types/invocation';
import { buildBatchInputImageNode } from 'features/nodes/util/node/buildBatchInputImageNode';
import { buildCurrentImageNode } from 'features/nodes/util/node/buildCurrentImageNode';
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
import { buildNotesNode } from 'features/nodes/util/node/buildNotesNode';
@@ -14,7 +15,7 @@ export const useBuildNode = () => {
return useCallback(
// string here is "any invocation type"
(type: string | 'current_image' | 'notes'): AnyNode => {
(type: string | 'current_image' | 'notes' | 'batch_input'): AnyNode => {
let _x = window.innerWidth / 2;
let _y = window.innerHeight / 2;
@@ -39,6 +40,10 @@ export const useBuildNode = () => {
return buildNotesNode(position);
}
if (type === 'batch_input') {
return buildBatchInputImageNode(position);
}
// TODO: Keep track of invocation types so we do not need to cast this
// We know it is safe because the caller of this function gets the `type` arg from the list of invocation templates.
const template = templates[type] as InvocationTemplate;

View File

@@ -12,6 +12,11 @@ import {
import { selectNodes, selectNodesSlice } from 'features/nodes/store/selectors';
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
import {
type FieldInputTemplate,
type FieldOutputTemplate,
imageBatchOutputFieldTemplate,
} from 'features/nodes/types/field';
import { useCallback, useMemo } from 'react';
import type { EdgeChange, OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow';
import { useUpdateNodeInternals } from 'reactflow';
@@ -33,13 +38,20 @@ export const useConnection = () => {
return;
}
const template = templates[node.data.type];
if (!template) {
return;
let fieldTemplate: FieldInputTemplate | FieldOutputTemplate | undefined = undefined;
if (node.type === 'image_batch' && handleId === 'images' && handleType === 'source') {
fieldTemplate = imageBatchOutputFieldTemplate;
} else {
const template = templates[node.data.type];
if (!template) {
return;
}
const fieldTemplates = template[handleType === 'source' ? 'outputs' : 'inputs'];
fieldTemplate = fieldTemplates[handleId];
}
const fieldTemplates = template[handleType === 'source' ? 'outputs' : 'inputs'];
const fieldTemplate = fieldTemplates[handleId];
if (!fieldTemplate) {
return;
}

View File

@@ -0,0 +1,19 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNode, selectNodesSlice } from 'features/nodes/store/selectors';
import { useMemo } from 'react';
import type { Node } from 'reactflow';
export const useNode = (nodeId: string): Node => {
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {
return selectNode(nodes, nodeId);
}),
[nodeId]
);
const node = useAppSelector(selector);
return node;
};

View File

@@ -3,6 +3,7 @@ import { createSlice, isAnyOf } from '@reduxjs/toolkit';
import type { PersistConfig } from 'app/store/store';
import { buildUseBoolean } from 'common/hooks/useBoolean';
import { workflowLoaded } from 'features/nodes/store/actions';
import type { ImageField } from 'features/nodes/types/common';
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
import type {
BoardFieldValue,
@@ -58,7 +59,8 @@ import {
zVAEModelFieldValue,
} from 'features/nodes/types/field';
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
import { isImageBatchNode, isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
import { uniqBy } from 'lodash-es';
import { atom, computed } from 'nanostores';
import type { MouseEvent } from 'react';
import type { Edge, EdgeChange, NodeChange, Viewport, XYPosition } from 'reactflow';
@@ -193,7 +195,7 @@ export const nodesSlice = createSlice({
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
const node = state.nodes?.[nodeIndex];
if (!isInvocationNode(node) && !isNotesNode(node)) {
if (!isInvocationNode(node) && !isNotesNode(node) && !isImageBatchNode(node)) {
return;
}
@@ -382,6 +384,34 @@ export const nodesSlice = createSlice({
}
node.data.notes = value;
},
batchImageInputNodeImagesAdded: (state, action: PayloadAction<{ nodeId: string; images: ImageField[] }>) => {
const { nodeId, images } = action.payload;
const node = state.nodes.find((n) => n.id === nodeId);
if (!isImageBatchNode(node)) {
return;
}
node.data.images = uniqBy([...node.data.images, ...images], 'image_name');
},
batchImageInputNodeImagesRemoved: (state, action: PayloadAction<{ nodeId: string; images: ImageField[] }>) => {
const { nodeId, images } = action.payload;
const node = state.nodes.find((n) => n.id === nodeId);
if (!isImageBatchNode(node)) {
return;
}
const imageNamesToRemove = images.map(({ image_name }) => image_name);
node.data.images = uniqBy(
node.data.images.filter(({ image_name }) => imageNamesToRemove.includes(image_name)),
'image_name'
);
},
batchImageInputNodeReset: (state, action: PayloadAction<{ nodeId: string }>) => {
const { nodeId } = action.payload;
const node = state.nodes.find((n) => n.id === nodeId);
if (!isImageBatchNode(node)) {
return;
}
node.data.images = [];
},
nodeEditorReset: (state) => {
state.nodes = [];
state.edges = [];
@@ -443,6 +473,9 @@ export const {
notesNodeValueChanged,
undo,
redo,
batchImageInputNodeImagesAdded,
batchImageInputNodeImagesRemoved,
batchImageInputNodeReset,
} = nodesSlice.actions;
export const $cursorPos = atom<XYPosition | null>(null);

View File

@@ -5,8 +5,15 @@ import type { NodesState } from 'features/nodes/store/types';
import type { FieldInputInstance } from 'features/nodes/types/field';
import type { InvocationNode, InvocationNodeData } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
import type { Node } from 'reactflow';
import { assert } from 'tsafe';
export const selectNode = (nodesSlice: NodesState, nodeId: string): Node => {
const node = nodesSlice.nodes.find((node) => node.id === nodeId);
assert(node !== undefined, `Node ${nodeId} not found`);
return node;
};
export const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode => {
const node = nodesSlice.nodes.find((node) => node.id === nodeId);
assert(isInvocationNode(node), `Node ${nodeId} is not an invocation node`);

View File

@@ -3,6 +3,7 @@ import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual';
import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType';
import { getHasCycles } from 'features/nodes/store/util/getHasCycles';
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
import { imageBatchOutputFieldTemplate } from 'features/nodes/types/field';
import type { AnyNode } from 'features/nodes/types/invocation';
import type { Connection as NullableConnection, Edge } from 'reactflow';
import type { SetNonNullable } from 'type-fest';
@@ -78,23 +79,32 @@ export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, temp
return buildRejectResult('nodes.missingNode');
}
const sourceTemplate = templates[sourceNode.data.type];
if (!sourceTemplate) {
return buildRejectResult('nodes.missingInvocationTemplate');
}
const targetTemplate = templates[targetNode.data.type];
if (!targetTemplate) {
return buildRejectResult('nodes.missingInvocationTemplate');
}
const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle];
if (!sourceFieldTemplate) {
const targetFieldTemplate = targetTemplate.inputs[c.targetHandle];
if (!targetFieldTemplate) {
return buildRejectResult('nodes.missingFieldTemplate');
}
const targetFieldTemplate = targetTemplate.inputs[c.targetHandle];
if (!targetFieldTemplate) {
if (sourceNode.type === 'image_batch') {
const isValid = validateConnectionTypes(imageBatchOutputFieldTemplate.type, targetFieldTemplate.type);
if (!isValid) {
return buildRejectResult('nodes.fieldTypesMustMatch');
} else {
return buildAcceptResult();
}
}
const sourceTemplate = templates[sourceNode.data.type];
if (!sourceTemplate) {
return buildRejectResult('nodes.missingInvocationTemplate');
}
const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle];
if (!sourceFieldTemplate) {
return buildRejectResult('nodes.missingFieldTemplate');
}

View File

@@ -15,6 +15,10 @@ export const validateConnectionTypes = (sourceType: FieldType, targetType: Field
return false;
}
if (sourceType.name === 'ImageBatchField') {
return isSingle(sourceType) && targetType.name === 'ImageField';
}
if (areTypesEqual(sourceType, targetType)) {
return true;
}

View File

@@ -59,6 +59,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
EnumField: 'blue.500',
FloatField: 'orange.500',
ImageField: 'purple.500',
ImageBatchField: 'purple.500',
IntegerField: 'red.500',
IPAdapterField: 'teal.500',
IPAdapterModelField: 'teal.500',

View File

@@ -95,6 +95,10 @@ const zImageFieldType = zFieldTypeBase.extend({
name: z.literal('ImageField'),
originalType: zStatelessFieldType.optional(),
});
const zImageBatchFieldType = zFieldTypeBase.extend({
name: z.literal('ImageBatchField'),
originalType: zStatelessFieldType.optional(),
});
const zBoardFieldType = zFieldTypeBase.extend({
name: z.literal('BoardField'),
originalType: zStatelessFieldType.optional(),
@@ -175,13 +179,14 @@ const zSchedulerFieldType = zFieldTypeBase.extend({
name: z.literal('SchedulerField'),
originalType: zStatelessFieldType.optional(),
});
const zStatefulFieldType = z.union([
export const zStatefulFieldType = z.union([
zIntegerFieldType,
zFloatFieldType,
zStringFieldType,
zBooleanFieldType,
zEnumFieldType,
zImageFieldType,
zImageBatchFieldType,
zBoardFieldType,
zModelIdentifierFieldType,
zMainModelFieldType,
@@ -367,6 +372,19 @@ export const isImageFieldInputInstance = (val: unknown): val is ImageFieldInputI
zImageFieldInputInstance.safeParse(val).success;
export const isImageFieldInputTemplate = (val: unknown): val is ImageFieldInputTemplate =>
zImageFieldInputTemplate.safeParse(val).success;
const zImageBatchFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zImageBatchFieldType,
});
export type ImageBatchOutputFieldTemplate = z.infer<typeof zImageBatchFieldOutputTemplate>;
export const imageBatchOutputFieldTemplate: ImageBatchOutputFieldTemplate = {
fieldKind: 'output',
name: 'images',
ui_hidden: false,
type: { name: 'ImageBatchField', cardinality: 'SINGLE' },
title: 'Image',
};
// #endregion
// #region BoardField
@@ -991,6 +1009,7 @@ const zStatefulFieldOutputTemplate = z.union([
zBooleanFieldOutputTemplate,
zEnumFieldOutputTemplate,
zImageFieldOutputTemplate,
zImageBatchFieldOutputTemplate,
zBoardFieldOutputTemplate,
zModelIdentifierFieldOutputTemplate,
zMainModelFieldOutputTemplate,

View File

@@ -1,7 +1,7 @@
import type { Edge, Node } from 'reactflow';
import { z } from 'zod';
import { zClassification, zProgressImage } from './common';
import { zClassification, zImageField, zProgressImage } from './common';
import { zFieldInputInstance, zFieldInputTemplate, zFieldOutputTemplate } from './field';
import { zSemVer } from './semver';
@@ -49,23 +49,33 @@ const zCurrentImageNodeData = z.object({
label: z.string(),
isOpen: z.boolean(),
});
const zAnyNodeData = z.union([zInvocationNodeData, zNotesNodeData, zCurrentImageNodeData]);
const zImageBatchNodeData = z.object({
id: z.string().trim().min(1),
type: z.literal('image_batch'),
label: z.string(),
isOpen: z.boolean(),
images: z.array(zImageField),
});
const zAnyNodeData = z.union([zInvocationNodeData, zNotesNodeData, zCurrentImageNodeData, zImageBatchNodeData]);
export type NotesNodeData = z.infer<typeof zNotesNodeData>;
export type InvocationNodeData = z.infer<typeof zInvocationNodeData>;
type CurrentImageNodeData = z.infer<typeof zCurrentImageNodeData>;
export type ImageBatchNodeData = z.infer<typeof zImageBatchNodeData>;
type AnyNodeData = z.infer<typeof zAnyNodeData>;
export type InvocationNode = Node<InvocationNodeData, 'invocation'>;
export type NotesNode = Node<NotesNodeData, 'notes'>;
export type CurrentImageNode = Node<CurrentImageNodeData, 'current_image'>;
export type BatchImageInputNode = Node<ImageBatchNodeData, 'image_batch'>;
export type AnyNode = Node<AnyNodeData>;
export const isInvocationNode = (node?: AnyNode | null): node is InvocationNode =>
Boolean(node && node.type === 'invocation');
export const isNotesNode = (node?: AnyNode | null): node is NotesNode => Boolean(node && node.type === 'notes');
export const isInvocationNodeData = (node?: AnyNodeData | null): node is InvocationNodeData =>
Boolean(node && !['notes', 'current_image'].includes(node.type)); // node.type may be 'notes', 'current_image', or any invocation type
export const isImageBatchNode = (node?: AnyNode | null): node is BatchImageInputNode =>
Boolean(node && node.type === 'image_batch');
// #endregion
// #region NodeExecutionState

View File

@@ -1,9 +1,12 @@
import { logger } from 'app/logging/logger';
import type { NodesState } from 'features/nodes/store/types';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { omit, reduce } from 'lodash-es';
import type { AnyInvocation, Graph } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
const log = logger('workflows');
/**
* Builds a graph from the node editor state.
*/
@@ -47,22 +50,31 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
return nodesAccumulator;
}, {});
const filteredNodeIds = filteredNodes.map(({ id }) => id);
// skip out the "dummy" edges between collapsed nodes
const filteredEdges = edges.filter((n) => n.type !== 'collapsed');
const filteredEdges = edges
.filter((edge) => edge.type !== 'collapsed')
.filter((edge) => filteredNodeIds.includes(edge.source) && filteredNodeIds.includes(edge.target));
// Reduce the node editor edges into invocation graph edges
const parsedEdges = filteredEdges.reduce<NonNullable<Graph['edges']>>((edgesAccumulator, edge) => {
const { source, target, sourceHandle, targetHandle } = edge;
if (!sourceHandle || !targetHandle) {
log.warn({ source, target, sourceHandle, targetHandle }, 'Missing source or taget handle for edge');
return edgesAccumulator;
}
// Format the edges and add to the edges array
edgesAccumulator.push({
source: {
node_id: source,
field: sourceHandle as string,
field: sourceHandle,
},
destination: {
node_id: target,
field: targetHandle as string,
field: targetHandle,
},
});

View File

@@ -0,0 +1,22 @@
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
import type { BatchImageInputNode } from 'features/nodes/types/invocation';
import type { XYPosition } from 'reactflow';
import { v4 as uuidv4 } from 'uuid';
export const buildBatchInputImageNode = (position: XYPosition): BatchImageInputNode => {
const nodeId = uuidv4();
const node: BatchImageInputNode = {
...SHARED_NODE_PROPERTIES,
id: nodeId,
type: 'image_batch',
position,
data: {
id: nodeId,
isOpen: true,
label: 'Image Batch',
type: 'image_batch',
images: [],
},
};
return node;
};

View File

@@ -34,8 +34,15 @@ export const UpscaleWarning = () => {
dispatch(tileControlnetModelChanged(validModel || null));
}, [model?.base, modelConfigs, dispatch]);
const isBaseModelCompatible = useMemo(() => {
return model && ['sd-1', 'sdxl'].includes(model.base);
}, [model]);
const modelWarnings = useMemo(() => {
const _warnings: string[] = [];
if (!isBaseModelCompatible) {
return _warnings;
}
if (!model) {
_warnings.push(t('upscaling.mainModelDesc'));
}
@@ -46,7 +53,7 @@ export const UpscaleWarning = () => {
_warnings.push(t('upscaling.upscaleModelDesc'));
}
return _warnings;
}, [model, tileControlnetModel, upscaleModel, t]);
}, [isBaseModelCompatible, model, tileControlnetModel, upscaleModel, t]);
const otherWarnings = useMemo(() => {
const _warnings: string[] = [];
@@ -58,22 +65,25 @@ export const UpscaleWarning = () => {
return _warnings;
}, [isTooLargeToUpscale, t, maxUpscaleDimension]);
const allWarnings = useMemo(() => [...modelWarnings, ...otherWarnings], [modelWarnings, otherWarnings]);
const handleGoToModelManager = useCallback(() => {
dispatch(setActiveTab('models'));
$installModelsTab.set(3);
}, [dispatch]);
if (modelWarnings.length && isModelsTabDisabled) {
if (isBaseModelCompatible && modelWarnings.length > 0 && isModelsTabDisabled) {
return null;
}
if ((!modelWarnings.length && !otherWarnings.length) || isLoading) {
if ((isBaseModelCompatible && allWarnings.length === 0) || isLoading) {
return null;
}
return (
<Flex bg="error.500" borderRadius="base" padding={4} direction="column" fontSize="sm" gap={2}>
{!!modelWarnings.length && (
{!isBaseModelCompatible && <Text>{t('upscaling.incompatibleBaseModelDesc')}</Text>}
{modelWarnings.length > 0 && (
<Text>
<Trans
i18nKey="upscaling.missingModelsWarning"
@@ -85,11 +95,13 @@ export const UpscaleWarning = () => {
/>
</Text>
)}
<UnorderedList>
{[...modelWarnings, ...otherWarnings].map((warning) => (
<ListItem key={warning}>{warning}</ListItem>
))}
</UnorderedList>
{allWarnings.length > 0 && (
<UnorderedList>
{allWarnings.map((warning) => (
<ListItem key={warning}>{warning}</ListItem>
))}
</UnorderedList>
)}
</Flex>
);
};

View File

@@ -1822,7 +1822,7 @@ export type components = {
* Items
* @description The list of items to substitute into the node/field.
*/
items?: (string | number)[];
items?: (string | number | components["schemas"]["ImageField"])[];
};
/**
* BatchEnqueuedEvent
@@ -6751,6 +6751,12 @@ export type components = {
* @default 1
*/
denoising_end?: number;
/**
* Add Noise
* @description Add noise based on denoising start.
* @default true
*/
add_noise?: boolean;
/**
* Transformer
* @description Flux model (Transformer) to load
@@ -13161,7 +13167,7 @@ export type components = {
* Value
* @description The value to substitute into the node/field.
*/
value: string | number;
value: string | number | components["schemas"]["ImageField"];
};
/**
* Noise

View File

@@ -1,6 +1,6 @@
# State dict keys and shapes for an XLabs FLUX IP-Adapter model. Intended to be used for unit tests.
# These keys were extracted from:
# https://huggingface.co/XLabs-AI/flux-ip-adapter/blob/ad16be50d78a07ea83d8c4bde44ff9753235182e/flux-ip-adapter.safetensors
# https://huggingface.co/XLabs-AI/flux-ip-adapter/resolve/main/ip_adapter.safetensors
xlabs_sd_shapes = {
"double_blocks.0.processor.ip_adapter_double_stream_k_proj.bias": [3072],
"double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],