Merge branch 'main' into model-classification-api

This commit is contained in:
Billy
2025-03-13 13:03:51 +11:00
83 changed files with 1965 additions and 1051 deletions

View File

@@ -60,7 +60,7 @@
"@fontsource-variable/inter": "^5.1.0",
"@invoke-ai/ui-library": "^0.0.46",
"@nanostores/react": "^0.7.3",
"@reduxjs/toolkit": "2.6.0",
"@reduxjs/toolkit": "2.6.1",
"@roarr/browser-log-writer": "^1.3.0",
"@xyflow/react": "^12.4.2",
"async-mutex": "^0.5.0",

View File

@@ -30,8 +30,8 @@ dependencies:
specifier: ^0.7.3
version: 0.7.3(nanostores@0.11.3)(react@18.3.1)
'@reduxjs/toolkit':
specifier: 2.6.0
version: 2.6.0(react-redux@9.1.2)(react@18.3.1)
specifier: 2.6.1
version: 2.6.1(react-redux@9.1.2)(react@18.3.1)
'@roarr/browser-log-writer':
specifier: ^1.3.0
version: 1.3.0
@@ -2311,8 +2311,8 @@ packages:
- supports-color
dev: true
/@reduxjs/toolkit@2.6.0(react-redux@9.1.2)(react@18.3.1):
resolution: {integrity: sha512-mWJCYpewLRyTuuzRSEC/IwIBBkYg2dKtQas8mty5MaV2iXzcmicS3gW554FDeOvLnY3x13NIk8MB1e8wHO7rqQ==}
/@reduxjs/toolkit@2.6.1(react-redux@9.1.2)(react@18.3.1):
resolution: {integrity: sha512-SSlIqZNYhqm/oMkXbtofwZSt9lrncblzo6YcZ9zoX+zLngRBrCOjK4lNLdkNucJF58RHOWrD9txT3bT3piH7Zw==}
peerDependencies:
react: ^16.9.0 || ^17.0.0 || ^18 || ^19
react-redux: ^7.2.1 || ^8.1.3 || ^9.0.0

View File

@@ -113,7 +113,8 @@
"end": "Ende",
"layout": "Layout",
"board": "Ordner",
"combinatorial": "Kombinatorisch"
"combinatorial": "Kombinatorisch",
"saveChanges": "Änderungen speichern"
},
"gallery": {
"galleryImageSize": "Bildgröße",
@@ -761,7 +762,16 @@
"workflowDeleted": "Arbeitsablauf gelöscht",
"errorCopied": "Fehler kopiert",
"layerCopiedToClipboard": "Ebene in die Zwischenablage kopiert",
"sentToCanvas": "An Leinwand gesendet"
"sentToCanvas": "An Leinwand gesendet",
"problemDeletingWorkflow": "Problem beim Löschen des Arbeitsablaufs",
"uploadFailedInvalidUploadDesc_withCount_one": "Es darf maximal 1 PNG- oder JPEG-Bild sein.",
"uploadFailedInvalidUploadDesc_withCount_other": "Es dürfen maximal {{count}} PNG- oder JPEG-Bilder sein.",
"problemRetrievingWorkflow": "Problem beim Abrufen des Arbeitsablaufs",
"uploadFailedInvalidUploadDesc": "Müssen PNG- oder JPEG-Bilder sein.",
"pasteSuccess": "Eingefügt in {{destination}}",
"pasteFailed": "Einfügen fehlgeschlagen",
"unableToCopy": "Kopieren nicht möglich",
"unableToCopyDesc_theseSteps": "diese Schritte"
},
"accessibility": {
"uploadImage": "Bild hochladen",
@@ -1314,7 +1324,8 @@
"nodeName": "Knotenname",
"description": "Beschreibung",
"loadWorkflowDesc": "Arbeitsablauf laden?",
"loadWorkflowDesc2": "Ihr aktueller Arbeitsablauf enthält nicht gespeicherte Änderungen."
"loadWorkflowDesc2": "Ihr aktueller Arbeitsablauf enthält nicht gespeicherte Änderungen.",
"loadingTemplates": "Lade {{name}}"
},
"hrf": {
"enableHrf": "Korrektur für hohe Auflösungen",

View File

@@ -1692,10 +1692,11 @@
"filterByTags": "Filter by Tags",
"yourWorkflows": "Your Workflows",
"recentlyOpened": "Recently Opened",
"noRecentWorkflows": "No Recent Workflows",
"private": "Private",
"shared": "Shared",
"browseWorkflows": "Browse Workflows",
"resetTags": "Reset Tags",
"deselectAll": "Deselect All",
"opened": "Opened",
"openWorkflow": "Open Workflow",
"updated": "Updated",
@@ -1726,6 +1727,7 @@
"loadWorkflow": "$t(common.load) Workflow",
"autoLayout": "Auto Layout",
"edit": "Edit",
"view": "View",
"download": "Download",
"copyShareLink": "Copy Share Link",
"copyShareLinkForWorkflow": "Copy Share Link for Workflow",
@@ -1741,6 +1743,8 @@
"row": "Row",
"column": "Column",
"container": "Container",
"containerRowLayout": "Container (row layout)",
"containerColumnLayout": "Container (column layout)",
"heading": "Heading",
"text": "Text",
"divider": "Divider",

View File

@@ -1771,7 +1771,6 @@
"projectWorkflows": "Workflows du projet",
"copyShareLink": "Copier le lien de partage",
"chooseWorkflowFromLibrary": "Choisir le Workflow dans la Bibliothèque",
"uploadAndSaveWorkflow": "Importer dans la bibliothèque",
"edit": "Modifer",
"deleteWorkflow2": "Êtes-vous sûr de vouloir supprimer ce Workflow? Cette action ne peut pas être annulé.",
"download": "Télécharger",

View File

@@ -109,7 +109,8 @@
"board": "Bacheca",
"layout": "Schema",
"row": "Riga",
"column": "Colonna"
"column": "Colonna",
"saveChanges": "Salva modifiche"
},
"gallery": {
"galleryImageSize": "Dimensione dell'immagine",
@@ -784,7 +785,7 @@
"serverError": "Errore del Server",
"connected": "Connesso al server",
"canceled": "Elaborazione annullata",
"uploadFailedInvalidUploadDesc": "Devono essere immagini PNG o JPEG.",
"uploadFailedInvalidUploadDesc": "Devono essere immagini PNG, JPEG o WEBP.",
"parameterSet": "Parametro richiamato",
"parameterNotSet": "Parametro non richiamato",
"problemCopyingImage": "Impossibile copiare l'immagine",
@@ -835,9 +836,9 @@
"linkCopied": "Collegamento copiato",
"addedToUncategorized": "Aggiunto alle risorse della bacheca $t(boards.uncategorized)",
"imagesWillBeAddedTo": "Le immagini caricate verranno aggiunte alle risorse della bacheca {{boardName}}.",
"uploadFailedInvalidUploadDesc_withCount_one": "Devi caricare al massimo 1 immagine PNG o JPEG.",
"uploadFailedInvalidUploadDesc_withCount_many": "Devi caricare al massimo {{count}} immagini PNG o JPEG.",
"uploadFailedInvalidUploadDesc_withCount_other": "Devi caricare al massimo {{count}} immagini PNG o JPEG.",
"uploadFailedInvalidUploadDesc_withCount_one": "Devi caricare al massimo 1 immagine PNG, JPEG o WEBP.",
"uploadFailedInvalidUploadDesc_withCount_many": "Devi caricare al massimo {{count}} immagini PNG, JPEG o WEBP.",
"uploadFailedInvalidUploadDesc_withCount_other": "Devi caricare al massimo {{count}} immagini PNG, JPEG o WEBP.",
"outOfMemoryErrorDescLocal": "Segui la nostra <LinkComponent>guida per bassa VRAM</LinkComponent> per ridurre gli OOM.",
"pasteFailed": "Incolla non riuscita",
"pasteSuccess": "Incollato su {{destination}}",
@@ -1704,7 +1705,7 @@
"saveWorkflow": "Salva flusso di lavoro",
"openWorkflow": "Apri flusso di lavoro",
"clearWorkflowSearchFilter": "Cancella il filtro di ricerca del flusso di lavoro",
"workflowLibrary": "Libreria",
"workflowLibrary": "Libreria flussi di lavoro",
"workflowSaved": "Flusso di lavoro salvato",
"unnamedWorkflow": "Flusso di lavoro senza nome",
"savingWorkflow": "Salvataggio del flusso di lavoro...",
@@ -1734,7 +1735,6 @@
"userWorkflows": "Flussi di lavoro utente",
"projectWorkflows": "Flussi di lavoro del progetto",
"defaultWorkflows": "Flussi di lavoro predefiniti",
"uploadAndSaveWorkflow": "Carica nella libreria",
"chooseWorkflowFromLibrary": "Scegli il flusso di lavoro dalla libreria",
"deleteWorkflow2": "Vuoi davvero eliminare questo flusso di lavoro? Questa operazione non può essere annullata.",
"edit": "Modifica",
@@ -1772,7 +1772,19 @@
"container": "Contenitore",
"text": "Testo",
"numberInput": "Ingresso numerico"
}
},
"loadMore": "Carica altro",
"searchPlaceholder": "Cerca per nome, descrizione o etichetta",
"filterByTags": "Filtra per etichetta",
"shared": "Condiviso",
"browseWorkflows": "Sfoglia i flussi di lavoro",
"resetTags": "Reimposta le etichette",
"allLoaded": "Tutti i flussi di lavoro caricati",
"saveChanges": "Salva modifiche",
"yourWorkflows": "I tuoi flussi di lavoro",
"recentlyOpened": "Aperto di recente",
"workflowThumbnail": "Miniatura del flusso di lavoro",
"private": "Privato"
},
"accordions": {
"compositing": {
@@ -2330,8 +2342,8 @@
"watchRecentReleaseVideos": "Guarda i video su questa versione",
"watchUiUpdatesOverview": "Guarda le novità dell'interfaccia",
"items": [
"Editor del flusso di lavoro: nuovo generatore di moduli trascina-e-rilascia per una creazione più facile del flusso di lavoro.",
"Altri miglioramenti: messa in coda dei lotti più rapida, migliore ampliamento, selettore colore migliorato e nodi metadati."
"Gestione della memoria: nuova impostazione per gli utenti con GPU Nvidia per ridurre l'utilizzo della VRAM.",
"Prestazioni: continui miglioramenti alle prestazioni e alla reattività complessive dell'applicazione."
]
},
"system": {

View File

@@ -1566,7 +1566,6 @@
"defaultWorkflows": "Стандартные рабочие процессы",
"deleteWorkflow2": "Вы уверены, что хотите удалить этот рабочий процесс? Это нельзя отменить.",
"chooseWorkflowFromLibrary": "Выбрать рабочий процесс из библиотеки",
"uploadAndSaveWorkflow": "Загрузить в библиотеку",
"edit": "Редактировать",
"download": "Скачать",
"copyShareLink": "Скопировать ссылку на общий доступ",

View File

@@ -235,7 +235,8 @@
"column": "Cột",
"layout": "Bố Cục",
"row": "Hàng",
"board": "Bảng"
"board": "Bảng",
"saveChanges": "Lưu Thay Đổi"
},
"prompt": {
"addPromptTrigger": "Thêm Prompt Trigger",
@@ -766,7 +767,9 @@
"urlUnauthorizedErrorMessage2": "Tìm hiểu thêm.",
"urlForbidden": "Bạn không có quyền truy cập vào model này",
"urlForbiddenErrorMessage": "Bạn có thể cần yêu cầu quyền truy cập từ trang web đang cung cấp model.",
"urlUnauthorizedErrorMessage": "Bạn có thể cần thiếp lập một token API để dùng được model này."
"urlUnauthorizedErrorMessage": "Bạn có thể cần thiếp lập một token API để dùng được model này.",
"fluxRedux": "FLUX Redux",
"sigLip": "SigLIP"
},
"metadata": {
"guidance": "Hướng Dẫn",
@@ -979,7 +982,7 @@
"unknownInput": "Đầu Vào Không Rõ: {{name}}",
"validateConnections": "Xác Thực Kết Nối Và Đồ Thị",
"workflowNotes": "Ghi Chú",
"workflowTags": "Thẻ Tên",
"workflowTags": "Nhãn",
"editMode": "Chỉnh sửa trong Trình Biên Tập Workflow",
"edit": "Chỉnh Sửa",
"executionStateInProgress": "Đang Xử Lý",
@@ -2021,7 +2024,7 @@
},
"mergingLayers": "Đang gộp layer",
"controlLayerEmptyState": "<UploadButton>Tải lên ảnh</UploadButton>, kéo thả ảnh từ <GalleryButton>thư viện</GalleryButton> vào layer này, hoặc vẽ trên canvas để bắt đầu.",
"referenceImageEmptyState": "<UploadButton>Tải lên ảnh</UploadButton> hoặc kéo thả ảnh từ <GalleryButton>thư viện</GalleryButton> vào layer này để bắt đầu.",
"referenceImageEmptyState": "<UploadButton>Tải lên hình ảnh</UploadButton>, kéo ảnh từ <GalleryButton>thư viện ảnh</GalleryButton> vào layer này, hoặc <PullBboxButton>kéo hộp giới hạn vào layer này</PullBboxButton> để bắt đầu.",
"useImage": "Dùng Hình Ảnh",
"resetCanvasLayers": "Khởi Động Lại Layer Canvas",
"asRasterLayer": "Như $t(controlLayers.rasterLayer)",
@@ -2137,7 +2140,7 @@
"toast": {
"imageUploadFailed": "Tải Lên Ảnh Thất Bại",
"layerCopiedToClipboard": "Sao Chép Layer Vào Clipboard",
"uploadFailedInvalidUploadDesc_withCount_other": "Tối đa là {{count}} ảnh PNG hoặc JPEG.",
"uploadFailedInvalidUploadDesc_withCount_other": "Tối đa là {{count}} ảnh PNG, JPEG hoặc WEBP.",
"imageCopied": "Ảnh Đã Được Sao Chép",
"sentToUpscale": "Chuyển Vào Upscale",
"unableToLoadImage": "Không Thể Tải Hình Ảnh",
@@ -2149,7 +2152,7 @@
"unableToLoadImageMetadata": "Không Thể Tải Metadata Của Ảnh",
"workflowLoaded": "Workflow Đã Tải",
"uploadFailed": "Tải Lên Thất Bại",
"uploadFailedInvalidUploadDesc": "Phải là ảnh PNG hoặc JPEG.",
"uploadFailedInvalidUploadDesc": "Phải là ảnh PNG, JPEG hoặc WEBP.",
"serverError": "Lỗi Server",
"addedToBoard": "Thêm vào tài nguyên của bảng {{name}}",
"sessionRef": "Phiên: {{sessionId}}",
@@ -2252,11 +2255,10 @@
"convertGraph": "Chuyển Đổi Đồ Thị",
"saveWorkflowToProject": "Lưu Workflow Vào Dự Án",
"workflowName": "Tên Workflow",
"workflowLibrary": "Thư Viện",
"workflowLibrary": "Thư Viện Workflow",
"opened": "Ngày Mở",
"deleteWorkflow": "Xoá Workflow",
"workflowEditorMenu": "Menu Biên Tập Workflow",
"uploadAndSaveWorkflow": "Tải Lên Thư Viện",
"openLibrary": "Mở Thư Viện",
"builder": {
"resetAllNodeFields": "Tải Lại Các Vùng Node",
@@ -2287,7 +2289,19 @@
"heading": "Đầu Dòng",
"text": "Văn Bản",
"divider": "Gạch Chia"
}
},
"yourWorkflows": "Workflow Của Bạn",
"browseWorkflows": "Khám Phá Workflow",
"workflowThumbnail": "Ảnh Minh Họa Workflow",
"saveChanges": "Lưu Thay Đổi",
"allLoaded": "Đã Tải Tất Cả Workflow",
"shared": "Nhóm",
"searchPlaceholder": "Tìm theo tên, mô tả, hoặc nhãn",
"filterByTags": "Lọc Theo Nhãn",
"recentlyOpened": "Mở Gần Đây",
"private": "Cá Nhân",
"resetTags": "Khởi Động Lại Nhãn",
"loadMore": "Tải Thêm"
},
"upscaling": {
"missingUpscaleInitialImage": "Thiếu ảnh dùng để upscale",
@@ -2322,8 +2336,8 @@
"watchRecentReleaseVideos": "Xem Video Phát Hành Mới Nhất",
"watchUiUpdatesOverview": "Xem Tổng Quan Về Những Cập Nhật Cho Giao Diện Người Dùng",
"items": [
"Trình Biên Tập Workflow: trình tạo vùng nhập dưới dạng kéo thả nhằm tạo dựng workflow dễ dàng hơn.",
"Các nâng cấp khác: Xếp hàng tạo sinh theo nhóm nhanh hơn, upscale tốt hơn, trình chọn màu được cải thiện, và node chứa metadata."
"Trình Quản Lý Bộ Nhớ: Thiết lập mới cho người dùng với GPU Nvidia để giảm lượng VRAM sử dng.",
"Hiệu suất: Các cải thiện tiếp theo nhm gói gọn hiệu suất và khả năng phản hồi của ứng dụng."
]
},
"upsell": {

View File

@@ -1629,7 +1629,6 @@
"projectWorkflows": "项目工作流程",
"copyShareLink": "复制分享链接",
"chooseWorkflowFromLibrary": "从库中选择工作流程",
"uploadAndSaveWorkflow": "上传到库",
"deleteWorkflow2": "您确定要删除此工作流程吗?此操作无法撤销。"
},
"accordions": {

View File

@@ -16,10 +16,16 @@ import { $openAPISchemaUrl } from 'app/store/nanostores/openAPISchemaUrl';
import { $projectId, $projectName, $projectUrl } from 'app/store/nanostores/projectId';
import { $queueId, DEFAULT_QUEUE_ID } from 'app/store/nanostores/queueId';
import { $store } from 'app/store/nanostores/store';
import { $workflowCategories } from 'app/store/nanostores/workflowCategories';
import { createStore } from 'app/store/store';
import type { PartialAppConfig } from 'app/types/invokeai';
import Loading from 'common/components/Loading/Loading';
import type { WorkflowTagCategory } from 'features/nodes/store/workflowLibrarySlice';
import {
$workflowLibraryCategoriesOptions,
$workflowLibraryTagCategoriesOptions,
DEFAULT_WORKFLOW_LIBRARY_CATEGORIES,
DEFAULT_WORKFLOW_LIBRARY_TAG_CATEGORIES,
} from 'features/nodes/store/workflowLibrarySlice';
import type { WorkflowCategory } from 'features/nodes/types/workflow';
import type { PropsWithChildren, ReactNode } from 'react';
import React, { lazy, memo, useEffect, useLayoutEffect, useMemo } from 'react';
@@ -48,6 +54,7 @@ interface Props extends PropsWithChildren {
isDebugging?: boolean;
logo?: ReactNode;
workflowCategories?: WorkflowCategory[];
workflowTagCategories?: WorkflowTagCategory[];
loggingOverrides?: LoggingOverrides;
}
@@ -68,6 +75,7 @@ const InvokeAIUI = ({
isDebugging = false,
logo,
workflowCategories,
workflowTagCategories,
loggingOverrides,
}: Props) => {
useLayoutEffect(() => {
@@ -195,14 +203,24 @@ const InvokeAIUI = ({
useEffect(() => {
if (workflowCategories) {
$workflowCategories.set(workflowCategories);
$workflowLibraryCategoriesOptions.set(workflowCategories);
}
return () => {
$workflowCategories.set([]);
$workflowLibraryCategoriesOptions.set(DEFAULT_WORKFLOW_LIBRARY_CATEGORIES);
};
}, [workflowCategories]);
useEffect(() => {
if (workflowTagCategories) {
$workflowLibraryTagCategoriesOptions.set(workflowTagCategories);
}
return () => {
$workflowLibraryTagCategoriesOptions.set(DEFAULT_WORKFLOW_LIBRARY_TAG_CATEGORIES);
};
}, [workflowTagCategories]);
useEffect(() => {
if (socketOptions) {
$socketOptions.set(socketOptions);

View File

@@ -15,7 +15,7 @@ import { $isWorkflowLibraryModalOpen } from 'features/nodes/store/workflowLibrar
import { $isStylePresetsMenuOpen, activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
import { toast } from 'features/toast/toast';
import { activeTabCanvasRightPanelChanged, setActiveTab } from 'features/ui/store/uiSlice';
import { useGetAndLoadLibraryWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadLibraryWorkflow';
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
import { atom } from 'nanostores';
import { useCallback, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
@@ -57,7 +57,7 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
const { t } = useTranslation();
const didParseOpenAPISchema = useStore($hasTemplates);
const store = useAppStore();
const { getAndLoadWorkflow } = useGetAndLoadLibraryWorkflow();
const loadWorkflowWithDialog = useLoadWorkflowWithDialog();
const handleSendToCanvas = useCallback(
async (imageName: string) => {
@@ -113,10 +113,15 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
const handleLoadWorkflow = useCallback(
async (workflowId: string) => {
// This shows a toast
await getAndLoadWorkflow(workflowId);
store.dispatch(setActiveTab('workflows'));
await loadWorkflowWithDialog({
type: 'library',
data: workflowId,
onSuccess: () => {
store.dispatch(setActiveTab('workflows'));
},
});
},
[getAndLoadWorkflow, store]
[loadWorkflowWithDialog, store]
);
const handleSelectStylePreset = useCallback(

View File

@@ -31,6 +31,7 @@ import type { AnyModelConfig } from 'services/api/types';
import {
isCLIPEmbedModelConfig,
isControlLayerModelConfig,
isFluxReduxModelConfig,
isFluxVAEModelConfig,
isIPAdapterModelConfig,
isLoRAModelConfig,
@@ -77,6 +78,7 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
handleT5EncoderModels(models, state, dispatch, log);
handleCLIPEmbedModels(models, state, dispatch, log);
handleFLUXVAEModels(models, state, dispatch, log);
handleFLUXReduxModels(models, state, dispatch, log);
},
});
};
@@ -209,6 +211,10 @@ const handleControlAdapterModels: ModelHandler = (models, state, dispatch, log)
const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
const ipaModels = models.filter(isIPAdapterModelConfig);
selectCanvasSlice(state).referenceImages.entities.forEach((entity) => {
if (entity.ipAdapter.type !== 'ip_adapter') {
return;
}
const selectedIPAdapterModel = entity.ipAdapter.model;
// `null` is a valid IP adapter model - no need to do anything.
if (!selectedIPAdapterModel) {
@@ -224,6 +230,10 @@ const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
entity.referenceImages.forEach(({ id: referenceImageId, ipAdapter }) => {
if (ipAdapter.type !== 'ip_adapter') {
return;
}
const selectedIPAdapterModel = ipAdapter.model;
// `null` is a valid IP adapter model - no need to do anything.
if (!selectedIPAdapterModel) {
@@ -241,6 +251,49 @@ const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
});
};
const handleFLUXReduxModels: ModelHandler = (models, state, dispatch, log) => {
const fluxReduxModels = models.filter(isFluxReduxModelConfig);
selectCanvasSlice(state).referenceImages.entities.forEach((entity) => {
if (entity.ipAdapter.type !== 'flux_redux') {
return;
}
const selectedFLUXReduxModel = entity.ipAdapter.model;
// `null` is a valid FLUX Redux model - no need to do anything.
if (!selectedFLUXReduxModel) {
return;
}
const isModelAvailable = fluxReduxModels.some((m) => m.key === selectedFLUXReduxModel.key);
if (isModelAvailable) {
return;
}
log.debug({ selectedFLUXReduxModel }, 'Selected FLUX Redux model is not available, clearing');
dispatch(referenceImageIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), modelConfig: null }));
});
selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
entity.referenceImages.forEach(({ id: referenceImageId, ipAdapter }) => {
if (ipAdapter.type !== 'flux_redux') {
return;
}
const selectedFLUXReduxModel = ipAdapter.model;
// `null` is a valid FLUX Redux model - no need to do anything.
if (!selectedFLUXReduxModel) {
return;
}
const isModelAvailable = fluxReduxModels.some((m) => m.key === selectedFLUXReduxModel.key);
if (isModelAvailable) {
return;
}
log.debug({ selectedFLUXReduxModel }, 'Selected FLUX Redux model is not available, clearing');
dispatch(
rgIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), referenceImageId, modelConfig: null })
);
});
});
};
const handlePostProcessingModel: ModelHandler = (models, state, dispatch, log) => {
const selectedPostProcessingModel = state.upscale.postProcessingModel;
const allSpandrelModels = models.filter(isSpandrelImageToImageModelConfig);

View File

@@ -1,4 +0,0 @@
import type { WorkflowCategory } from 'features/nodes/types/workflow';
import { atom } from 'nanostores';
export const $workflowCategories = atom<WorkflowCategory[]>(['user', 'default']);

View File

@@ -19,6 +19,7 @@ import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/galle
import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
import { modelManagerV2PersistConfig, modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { nodesPersistConfig, nodesSlice, nodesUndoableConfig } from 'features/nodes/store/nodesSlice';
import { workflowLibraryPersistConfig, workflowLibrarySlice } from 'features/nodes/store/workflowLibrarySlice';
import { workflowSettingsPersistConfig, workflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
import { upscalePersistConfig, upscaleSlice } from 'features/parameters/store/upscaleSlice';
@@ -68,6 +69,7 @@ const allReducers = {
[canvasSettingsSlice.name]: canvasSettingsSlice.reducer,
[canvasStagingAreaSlice.name]: canvasStagingAreaSlice.reducer,
[lorasSlice.name]: lorasSlice.reducer,
[workflowLibrarySlice.name]: workflowLibrarySlice.reducer,
};
const rootReducer = combineReducers(allReducers);
@@ -113,6 +115,7 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
[canvasSettingsPersistConfig.name]: canvasSettingsPersistConfig,
[canvasStagingAreaPersistConfig.name]: canvasStagingAreaPersistConfig,
[lorasPersistConfig.name]: lorasPersistConfig,
[workflowLibraryPersistConfig.name]: workflowLibraryPersistConfig,
};
const unserialize: UnserializeFunction = (data, key) => {

View File

@@ -9,7 +9,7 @@ import {
useAddRegionalGuidance,
useAddRegionalReferenceImage,
} from 'features/controlLayers/hooks/addLayerHooks';
import { selectIsFLUX, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import { selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
@@ -22,7 +22,6 @@ export const CanvasAddEntityButtons = memo(() => {
const addControlLayer = useAddControlLayer();
const addGlobalReferenceImage = useAddGlobalReferenceImage();
const addRegionalReferenceImage = useAddRegionalReferenceImage();
const isFLUX = useAppSelector(selectIsFLUX);
const isSD3 = useAppSelector(selectIsSD3);
return (
@@ -75,7 +74,7 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addRegionalReferenceImage}
isDisabled={isFLUX || isSD3}
isDisabled={isSD3}
>
{t('controlLayers.regionalReferenceImage')}
</Button>

View File

@@ -9,7 +9,7 @@ import {
useAddRegionalReferenceImage,
} from 'features/controlLayers/hooks/addLayerHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { selectIsFLUX, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import { selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
@@ -23,7 +23,6 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
const addRegionalReferenceImage = useAddRegionalReferenceImage();
const addRasterLayer = useAddRasterLayer();
const addControlLayer = useAddControlLayer();
const isFLUX = useAppSelector(selectIsFLUX);
const isSD3 = useAppSelector(selectIsSD3);
return (
@@ -52,7 +51,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={isSD3}>
{t('controlLayers.regionalGuidance')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalReferenceImage} isDisabled={isFLUX || isSD3}>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalReferenceImage} isDisabled={isSD3}>
{t('controlLayers.regionalReferenceImage')}
</MenuItem>
</MenuGroup>

View File

@@ -0,0 +1,61 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, FormControl } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import type { CLIPVisionModelV2 } from 'features/controlLayers/store/types';
import { isCLIPVisionModelV2 } from 'features/controlLayers/store/types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { assert } from 'tsafe';
// at this time, ViT-L is the only supported clip model for FLUX IP adapter
const FLUX_CLIP_VISION = 'ViT-L';
const CLIP_VISION_OPTIONS = [
{ label: 'ViT-H', value: 'ViT-H' },
{ label: 'ViT-G', value: 'ViT-G' },
{ label: FLUX_CLIP_VISION, value: FLUX_CLIP_VISION },
];
type Props = {
model: CLIPVisionModelV2;
onChange: (clipVisionModel: CLIPVisionModelV2) => void;
};
export const CLIPVisionModel = memo(({ model, onChange }: Props) => {
const { t } = useTranslation();
const _onChangeCLIPVisionModel = useCallback<ComboboxOnChange>(
(v) => {
assert(isCLIPVisionModelV2(v?.value));
onChange(v.value);
},
[onChange]
);
const isFLUX = useAppSelector(selectIsFLUX);
const clipVisionOptions = useMemo(() => {
return CLIP_VISION_OPTIONS.map((option) => ({
...option,
isDisabled: isFLUX && option.value !== FLUX_CLIP_VISION,
}));
}, [isFLUX]);
const clipVisionModelValue = useMemo(() => {
return CLIP_VISION_OPTIONS.find((o) => o.value === model);
}, [model]);
return (
<FormControl width="max-content" minWidth={28}>
<Combobox
options={clipVisionOptions}
placeholder={t('common.placeholderSelectAModel')}
value={clipVisionModelValue}
onChange={_onChangeCLIPVisionModel}
/>
</FormControl>
);
});
CLIPVisionModel.displayName = 'CLIPVisionModel';

View File

@@ -1,40 +1,36 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { selectBase, selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import type { CLIPVisionModelV2 } from 'features/controlLayers/store/types';
import { isCLIPVisionModelV2 } from 'features/controlLayers/store/types';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useIPAdapterModels } from 'services/api/hooks/modelsByType';
import type { AnyModelConfig, IPAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
// at this time, ViT-L is the only supported clip model for FLUX IP adapter
const FLUX_CLIP_VISION = 'ViT-L';
const CLIP_VISION_OPTIONS = [
{ label: 'ViT-H', value: 'ViT-H' },
{ label: 'ViT-G', value: 'ViT-G' },
{ label: FLUX_CLIP_VISION, value: FLUX_CLIP_VISION },
];
import { useIPAdapterOrFLUXReduxModels } from 'services/api/hooks/modelsByType';
import type { AnyModelConfig, FLUXReduxModelConfig, IPAdapterModelConfig } from 'services/api/types';
type Props = {
isRegionalGuidance: boolean;
modelKey: string | null;
onChangeModel: (modelConfig: IPAdapterModelConfig) => void;
clipVisionModel: CLIPVisionModelV2;
onChangeCLIPVisionModel: (clipVisionModel: CLIPVisionModelV2) => void;
onChangeModel: (modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig) => void;
};
export const IPAdapterModel = memo(({ modelKey, onChangeModel, clipVisionModel, onChangeCLIPVisionModel }: Props) => {
export const IPAdapterModel = memo(({ isRegionalGuidance, modelKey, onChangeModel }: Props) => {
const { t } = useTranslation();
const currentBaseModel = useAppSelector(selectBase);
const [modelConfigs, { isLoading }] = useIPAdapterModels();
const filter = useCallback(
(config: IPAdapterModelConfig | FLUXReduxModelConfig) => {
// FLUX supports regional guidance for FLUX Redux models only - not IP Adapter models.
if (isRegionalGuidance && config.base === 'flux' && config.type === 'ip_adapter') {
return false;
}
return true;
},
[isRegionalGuidance]
);
const [modelConfigs, { isLoading }] = useIPAdapterOrFLUXReduxModels(filter);
const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]);
const _onChangeModel = useCallback(
(modelConfig: IPAdapterModelConfig | null) => {
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | null) => {
if (!modelConfig) {
return;
}
@@ -43,21 +39,11 @@ export const IPAdapterModel = memo(({ modelKey, onChangeModel, clipVisionModel,
[onChangeModel]
);
const _onChangeCLIPVisionModel = useCallback<ComboboxOnChange>(
(v) => {
assert(isCLIPVisionModelV2(v?.value));
onChangeCLIPVisionModel(v.value);
},
[onChangeCLIPVisionModel]
);
const isFLUX = useAppSelector(selectIsFLUX);
const getIsDisabled = useCallback(
(model: AnyModelConfig): boolean => {
const isCompatible = currentBaseModel === model.base;
const hasMainModel = Boolean(currentBaseModel);
return !hasMainModel || !isCompatible;
const hasSameBase = currentBaseModel === model.base;
return !hasMainModel || !hasSameBase;
},
[currentBaseModel]
);
@@ -70,41 +56,18 @@ export const IPAdapterModel = memo(({ modelKey, onChangeModel, clipVisionModel,
isLoading,
});
const clipVisionOptions = useMemo(() => {
return CLIP_VISION_OPTIONS.map((option) => ({
...option,
isDisabled: isFLUX && option.value !== FLUX_CLIP_VISION,
}));
}, [isFLUX]);
const clipVisionModelValue = useMemo(() => {
return CLIP_VISION_OPTIONS.find((o) => o.value === clipVisionModel);
}, [clipVisionModel]);
return (
<Flex gap={2}>
<Tooltip label={selectedModel?.description}>
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} w="full">
<Combobox
options={options}
placeholder={t('common.placeholderSelectAModel')}
value={value}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
{selectedModel?.format === 'checkpoint' && (
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} width="max-content" minWidth={28}>
<Combobox
options={clipVisionOptions}
placeholder={t('common.placeholderSelectAModel')}
value={clipVisionModelValue}
onChange={_onChangeCLIPVisionModel}
/>
</FormControl>
)}
</Flex>
<Tooltip label={selectedModel?.description}>
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} w="full">
<Combobox
options={options}
placeholder={t('common.placeholderSelectAModel')}
value={value}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
);
});

View File

@@ -1,9 +1,10 @@
import { Box, Flex, IconButton } from '@invoke-ai/ui-library';
import { Flex, IconButton } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginEndStepPct';
import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/common/CanvasEntitySettingsWrapper';
import { Weight } from 'features/controlLayers/components/common/Weight';
import { CLIPVisionModel } from 'features/controlLayers/components/IPAdapter/CLIPVisionModel';
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
import { IPAdapterSettingsEmptyState } from 'features/controlLayers/components/IPAdapter/IPAdapterSettingsEmptyState';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
@@ -25,7 +26,7 @@ import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold } from 'react-icons/pi';
import type { ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import type { FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import { IPAdapterImagePreview } from './IPAdapterImagePreview';
import { IPAdapterModel } from './IPAdapterModel';
@@ -65,7 +66,7 @@ const IPAdapterSettingsContent = memo(() => {
);
const onChangeModel = useCallback(
(modelConfig: IPAdapterModelConfig) => {
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig) => {
dispatch(referenceImageIPAdapterModelChanged({ entityIdentifier, modelConfig }));
},
[dispatch, entityIdentifier]
@@ -98,14 +99,14 @@ const IPAdapterSettingsContent = memo(() => {
<CanvasEntitySettingsWrapper>
<Flex flexDir="column" gap={2} position="relative" w="full">
<Flex gap={2} alignItems="center" w="full">
<Box minW={0} w="full" transitionProperty="common" transitionDuration="0.1s">
<IPAdapterModel
modelKey={ipAdapter.model?.key ?? null}
onChangeModel={onChangeModel}
clipVisionModel={ipAdapter.clipVisionModel}
onChangeCLIPVisionModel={onChangeCLIPVisionModel}
/>
</Box>
<IPAdapterModel
isRegionalGuidance={false}
modelKey={ipAdapter.model?.key ?? null}
onChangeModel={onChangeModel}
/>
{ipAdapter.type === 'ip_adapter' && (
<CLIPVisionModel model={ipAdapter.clipVisionModel} onChange={onChangeCLIPVisionModel} />
)}
<IconButton
onClick={pullBboxIntoIPAdapter}
isDisabled={isBusy}
@@ -116,12 +117,14 @@ const IPAdapterSettingsContent = memo(() => {
/>
</Flex>
<Flex gap={2} w="full" alignItems="center">
<Flex flexDir="column" gap={2} w="full">
{!isFLUX && <IPAdapterMethod method={ipAdapter.method} onChange={onChangeIPMethod} />}
<Weight weight={ipAdapter.weight} onChange={onChangeWeight} />
<BeginEndStepPct beginEndStepPct={ipAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
</Flex>
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1">
{ipAdapter.type === 'ip_adapter' && (
<Flex flexDir="column" gap={2} w="full">
{!isFLUX && <IPAdapterMethod method={ipAdapter.method} onChange={onChangeIPMethod} />}
<Weight weight={ipAdapter.weight} onChange={onChangeWeight} />
<BeginEndStepPct beginEndStepPct={ipAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
</Flex>
)}
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1" flexGrow={1}>
<IPAdapterImagePreview
image={ipAdapter.image}
onChangeImage={onChangeImage}

View File

@@ -1,8 +1,9 @@
import { Box, Flex, IconButton, Spacer, Text } from '@invoke-ai/ui-library';
import { Flex, IconButton, Spacer, Text } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginEndStepPct';
import { Weight } from 'features/controlLayers/components/common/Weight';
import { CLIPVisionModel } from 'features/controlLayers/components/IPAdapter/CLIPVisionModel';
import { IPAdapterImagePreview } from 'features/controlLayers/components/IPAdapter/IPAdapterImagePreview';
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
import { IPAdapterModel } from 'features/controlLayers/components/IPAdapter/IPAdapterModel';
@@ -26,7 +27,7 @@ import { setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold, PiXBold } from 'react-icons/pi';
import type { ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import type { FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
type Props = {
@@ -73,7 +74,7 @@ const RegionalGuidanceIPAdapterSettingsContent = memo(({ referenceImageId }: Pro
);
const onChangeModel = useCallback(
(modelConfig: IPAdapterModelConfig) => {
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig) => {
dispatch(rgIPAdapterModelChanged({ entityIdentifier, referenceImageId, modelConfig }));
},
[dispatch, entityIdentifier, referenceImageId]
@@ -125,14 +126,14 @@ const RegionalGuidanceIPAdapterSettingsContent = memo(({ referenceImageId }: Pro
</Flex>
<Flex flexDir="column" gap={2} position="relative" w="full">
<Flex gap={2} alignItems="center" w="full">
<Box minW={0} w="full" transitionProperty="common" transitionDuration="0.1s">
<IPAdapterModel
modelKey={ipAdapter.model?.key ?? null}
onChangeModel={onChangeModel}
clipVisionModel={ipAdapter.clipVisionModel}
onChangeCLIPVisionModel={onChangeCLIPVisionModel}
/>
</Box>
<IPAdapterModel
isRegionalGuidance={true}
modelKey={ipAdapter.model?.key ?? null}
onChangeModel={onChangeModel}
/>
{ipAdapter.type === 'ip_adapter' && (
<CLIPVisionModel model={ipAdapter.clipVisionModel} onChange={onChangeCLIPVisionModel} />
)}
<IconButton
onClick={pullBboxIntoIPAdapter}
isDisabled={isBusy}
@@ -143,12 +144,14 @@ const RegionalGuidanceIPAdapterSettingsContent = memo(({ referenceImageId }: Pro
/>
</Flex>
<Flex gap={2} w="full">
<Flex flexDir="column" gap={2} w="full">
<IPAdapterMethod method={ipAdapter.method} onChange={onChangeIPMethod} />
<Weight weight={ipAdapter.weight} onChange={onChangeWeight} />
<BeginEndStepPct beginEndStepPct={ipAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
</Flex>
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1">
{ipAdapter.type === 'ip_adapter' && (
<Flex flexDir="column" gap={2} w="full">
<IPAdapterMethod method={ipAdapter.method} onChange={onChangeIPMethod} />
<Weight weight={ipAdapter.weight} onChange={onChangeWeight} />
<BeginEndStepPct beginEndStepPct={ipAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
</Flex>
)}
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1" flexGrow={1}>
<IPAdapterImagePreview
image={ipAdapter.image}
onChangeImage={onChangeImage}

View File

@@ -38,6 +38,7 @@ import type { UndoableOptions } from 'redux-undo';
import type {
ControlLoRAModelConfig,
ControlNetModelConfig,
FLUXReduxModelConfig,
ImageDTO,
IPAdapterModelConfig,
T2IAdapterModelConfig,
@@ -76,6 +77,7 @@ import {
imageDTOToImageWithDims,
initialControlLoRA,
initialControlNet,
initialFLUXRedux,
initialIPAdapter,
initialT2IAdapter,
} from './util';
@@ -619,11 +621,16 @@ export const canvasSlice = createSlice({
if (!entity) {
return;
}
if (entity.ipAdapter.type !== 'ip_adapter') {
return;
}
entity.ipAdapter.method = method;
},
referenceImageIPAdapterModelChanged: (
state,
action: PayloadAction<EntityIdentifierPayload<{ modelConfig: IPAdapterModelConfig | null }, 'reference_image'>>
action: PayloadAction<
EntityIdentifierPayload<{ modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | null }, 'reference_image'>
>
) => {
const { entityIdentifier, modelConfig } = action.payload;
const entity = selectEntity(state, entityIdentifier);
@@ -631,12 +638,39 @@ export const canvasSlice = createSlice({
return;
}
entity.ipAdapter.model = modelConfig ? zModelIdentifierField.parse(modelConfig) : null;
// Ensure that the IP Adapter model is compatible with the CLIP Vision model
if (entity.ipAdapter.model?.base === 'flux') {
entity.ipAdapter.clipVisionModel = 'ViT-L';
} else if (entity.ipAdapter.clipVisionModel === 'ViT-L') {
// Fall back to ViT-H (ViT-G would also work)
entity.ipAdapter.clipVisionModel = 'ViT-H';
if (!entity.ipAdapter.model) {
return;
}
if (entity.ipAdapter.type === 'ip_adapter' && entity.ipAdapter.model.type === 'flux_redux') {
// Switching from ip_adapter to flux_redux
entity.ipAdapter = {
...initialFLUXRedux,
image: entity.ipAdapter.image,
model: entity.ipAdapter.model,
};
return;
}
if (entity.ipAdapter.type === 'flux_redux' && entity.ipAdapter.model.type === 'ip_adapter') {
// Switching from flux_redux to ip_adapter
entity.ipAdapter = {
...initialIPAdapter,
image: entity.ipAdapter.image,
model: entity.ipAdapter.model,
};
return;
}
if (entity.ipAdapter.type === 'ip_adapter') {
// Ensure that the IP Adapter model is compatible with the CLIP Vision model
if (entity.ipAdapter.model?.base === 'flux') {
entity.ipAdapter.clipVisionModel = 'ViT-L';
} else if (entity.ipAdapter.clipVisionModel === 'ViT-L') {
// Fall back to ViT-H (ViT-G would also work)
entity.ipAdapter.clipVisionModel = 'ViT-H';
}
}
},
referenceImageIPAdapterCLIPVisionModelChanged: (
@@ -648,6 +682,9 @@ export const canvasSlice = createSlice({
if (!entity) {
return;
}
if (entity.ipAdapter.type !== 'ip_adapter') {
return;
}
entity.ipAdapter.clipVisionModel = clipVisionModel;
},
referenceImageIPAdapterWeightChanged: (
@@ -659,6 +696,9 @@ export const canvasSlice = createSlice({
if (!entity) {
return;
}
if (entity.ipAdapter.type !== 'ip_adapter') {
return;
}
entity.ipAdapter.weight = weight;
},
referenceImageIPAdapterBeginEndStepPctChanged: (
@@ -670,6 +710,9 @@ export const canvasSlice = createSlice({
if (!entity) {
return;
}
if (entity.ipAdapter.type !== 'ip_adapter') {
return;
}
entity.ipAdapter.beginEndStepPct = beginEndStepPct;
},
//#region Regional Guidance
@@ -843,6 +886,10 @@ export const canvasSlice = createSlice({
if (!referenceImage) {
return;
}
if (referenceImage.ipAdapter.type !== 'ip_adapter') {
return;
}
referenceImage.ipAdapter.weight = weight;
},
rgIPAdapterBeginEndStepPctChanged: (
@@ -856,6 +903,10 @@ export const canvasSlice = createSlice({
if (!referenceImage) {
return;
}
if (referenceImage.ipAdapter.type !== 'ip_adapter') {
return;
}
referenceImage.ipAdapter.beginEndStepPct = beginEndStepPct;
},
rgIPAdapterMethodChanged: (
@@ -869,6 +920,10 @@ export const canvasSlice = createSlice({
if (!referenceImage) {
return;
}
if (referenceImage.ipAdapter.type !== 'ip_adapter') {
return;
}
referenceImage.ipAdapter.method = method;
},
rgIPAdapterModelChanged: (
@@ -877,7 +932,7 @@ export const canvasSlice = createSlice({
EntityIdentifierPayload<
{
referenceImageId: string;
modelConfig: IPAdapterModelConfig | null;
modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | null;
},
'regional_guidance'
>
@@ -889,12 +944,39 @@ export const canvasSlice = createSlice({
return;
}
referenceImage.ipAdapter.model = modelConfig ? zModelIdentifierField.parse(modelConfig) : null;
// Ensure that the IP Adapter model is compatible with the CLIP Vision model
if (referenceImage.ipAdapter.model?.base === 'flux') {
referenceImage.ipAdapter.clipVisionModel = 'ViT-L';
} else if (referenceImage.ipAdapter.clipVisionModel === 'ViT-L') {
// Fall back to ViT-H (ViT-G would also work)
referenceImage.ipAdapter.clipVisionModel = 'ViT-H';
if (!referenceImage.ipAdapter.model) {
return;
}
if (referenceImage.ipAdapter.type === 'ip_adapter' && referenceImage.ipAdapter.model.type === 'flux_redux') {
// Switching from ip_adapter to flux_redux
referenceImage.ipAdapter = {
...initialFLUXRedux,
image: referenceImage.ipAdapter.image,
model: referenceImage.ipAdapter.model,
};
return;
}
if (referenceImage.ipAdapter.type === 'flux_redux' && referenceImage.ipAdapter.model.type === 'ip_adapter') {
// Switching from flux_redux to ip_adapter
referenceImage.ipAdapter = {
...initialIPAdapter,
image: referenceImage.ipAdapter.image,
model: referenceImage.ipAdapter.model,
};
return;
}
if (referenceImage.ipAdapter.type === 'ip_adapter') {
// Ensure that the IP Adapter model is compatible with the CLIP Vision model
if (referenceImage.ipAdapter.model?.base === 'flux') {
referenceImage.ipAdapter.clipVisionModel = 'ViT-L';
} else if (referenceImage.ipAdapter.clipVisionModel === 'ViT-L') {
// Fall back to ViT-H (ViT-G would also work)
referenceImage.ipAdapter.clipVisionModel = 'ViT-H';
}
}
},
rgIPAdapterCLIPVisionModelChanged: (
@@ -908,6 +990,10 @@ export const canvasSlice = createSlice({
if (!referenceImage) {
return;
}
if (referenceImage.ipAdapter.type !== 'ip_adapter') {
return;
}
referenceImage.ipAdapter.clipVisionModel = clipVisionModel;
},
//#region Inpaint mask

View File

@@ -233,6 +233,13 @@ const zIPAdapterConfig = z.object({
});
export type IPAdapterConfig = z.infer<typeof zIPAdapterConfig>;
const zFLUXReduxConfig = z.object({
type: z.literal('flux_redux'),
image: zImageWithDims.nullable(),
model: zServerValidatedModelIdentifierField.nullable(),
});
export type FLUXReduxConfig = z.infer<typeof zFLUXReduxConfig>;
const zCanvasEntityBase = z.object({
id: zId,
name: zName,
@@ -242,10 +249,16 @@ const zCanvasEntityBase = z.object({
const zCanvasReferenceImageState = zCanvasEntityBase.extend({
type: z.literal('reference_image'),
ipAdapter: zIPAdapterConfig,
ipAdapter: z.discriminatedUnion('type', [zIPAdapterConfig, zFLUXReduxConfig]),
});
export type CanvasReferenceImageState = z.infer<typeof zCanvasReferenceImageState>;
export const isIPAdapterConfig = (config: IPAdapterConfig | FLUXReduxConfig): config is IPAdapterConfig =>
config.type === 'ip_adapter';
export const isFLUXReduxConfig = (config: IPAdapterConfig | FLUXReduxConfig): config is FLUXReduxConfig =>
config.type === 'flux_redux';
const zFillStyle = z.enum(['solid', 'grid', 'crosshatch', 'diagonal', 'horizontal', 'vertical']);
export type FillStyle = z.infer<typeof zFillStyle>;
export const isFillStyle = (v: unknown): v is FillStyle => zFillStyle.safeParse(v).success;
@@ -253,7 +266,7 @@ const zFill = z.object({ style: zFillStyle, color: zRgbColor });
const zRegionalGuidanceReferenceImageState = z.object({
id: zId,
ipAdapter: zIPAdapterConfig,
ipAdapter: z.discriminatedUnion('type', [zIPAdapterConfig, zFLUXReduxConfig]),
});
export type RegionalGuidanceReferenceImageState = z.infer<typeof zRegionalGuidanceReferenceImageState>;

View File

@@ -9,6 +9,7 @@ import type {
CanvasRegionalGuidanceState,
ControlLoRAConfig,
ControlNetConfig,
FLUXReduxConfig,
ImageWithDims,
IPAdapterConfig,
RgbColor,
@@ -70,6 +71,11 @@ export const initialIPAdapter: IPAdapterConfig = {
clipVisionModel: 'ViT-H',
weight: 1,
};
export const initialFLUXRedux: FLUXReduxConfig = {
type: 'flux_redux',
image: null,
model: null,
};
export const initialT2IAdapter: T2IAdapterConfig = {
type: 't2i_adapter',
model: null,

View File

@@ -44,33 +44,33 @@ export const getRegionalGuidanceWarnings = (
if (model.base === 'sd-3' || model.base === 'sd-2') {
// Unsupported model architecture
warnings.push(WARNINGS.UNSUPPORTED_MODEL);
} else if (model.base === 'flux') {
return warnings;
}
if (model.base === 'flux') {
// Some features are not supported for flux models
if (entity.negativePrompt !== null) {
warnings.push(WARNINGS.RG_NEGATIVE_PROMPT_NOT_SUPPORTED);
}
if (entity.referenceImages.length > 0) {
warnings.push(WARNINGS.RG_REFERENCE_IMAGES_NOT_SUPPORTED);
}
if (entity.autoNegative) {
warnings.push(WARNINGS.RG_AUTO_NEGATIVE_NOT_SUPPORTED);
}
} else {
entity.referenceImages.forEach(({ ipAdapter }) => {
if (!ipAdapter.model) {
// No model selected
warnings.push(WARNINGS.IP_ADAPTER_NO_MODEL_SELECTED);
} else if (ipAdapter.model.base !== model.base) {
// Supported model architecture but doesn't match
warnings.push(WARNINGS.IP_ADAPTER_INCOMPATIBLE_BASE_MODEL);
}
if (!ipAdapter.image) {
// No image selected
warnings.push(WARNINGS.IP_ADAPTER_NO_IMAGE_SELECTED);
}
});
}
entity.referenceImages.forEach(({ ipAdapter }) => {
if (!ipAdapter.model) {
// No model selected
warnings.push(WARNINGS.IP_ADAPTER_NO_MODEL_SELECTED);
} else if (ipAdapter.model.base !== model.base) {
// Supported model architecture but doesn't match
warnings.push(WARNINGS.IP_ADAPTER_INCOMPATIBLE_BASE_MODEL);
}
if (!ipAdapter.image) {
// No image selected
warnings.push(WARNINGS.IP_ADAPTER_NO_IMAGE_SELECTED);
}
});
}
return warnings;
@@ -82,22 +82,27 @@ export const getGlobalReferenceImageWarnings = (
): WarningTKey[] => {
const warnings: WarningTKey[] = [];
if (!entity.ipAdapter.model) {
// No model selected
warnings.push(WARNINGS.IP_ADAPTER_NO_MODEL_SELECTED);
} else if (model) {
if (model) {
if (model.base === 'sd-3' || model.base === 'sd-2') {
// Unsupported model architecture
warnings.push(WARNINGS.UNSUPPORTED_MODEL);
} else if (entity.ipAdapter.model.base !== model.base) {
return warnings;
}
const { ipAdapter } = entity;
if (!ipAdapter.model) {
// No model selected
warnings.push(WARNINGS.IP_ADAPTER_NO_MODEL_SELECTED);
} else if (ipAdapter.model.base !== model.base) {
// Supported model architecture but doesn't match
warnings.push(WARNINGS.IP_ADAPTER_INCOMPATIBLE_BASE_MODEL);
}
}
if (!entity.ipAdapter.image) {
// No image selected
warnings.push(WARNINGS.IP_ADAPTER_NO_IMAGE_SELECTED);
if (!entity.ipAdapter.image) {
// No image selected
warnings.push(WARNINGS.IP_ADAPTER_NO_IMAGE_SELECTED);
}
}
return warnings;

View File

@@ -1,9 +1,8 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { SpinnerIcon } from 'features/gallery/components/ImageContextMenu/SpinnerIcon';
import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
import { $hasTemplates } from 'features/nodes/store/nodesSlice';
import { useGetAndLoadEmbeddedWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow';
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFlowArrowBold } from 'react-icons/pi';
@@ -11,19 +10,15 @@ import { PiFlowArrowBold } from 'react-icons/pi';
export const ImageMenuItemLoadWorkflow = memo(() => {
const { t } = useTranslation();
const imageDTO = useImageDTOContext();
const [getAndLoadEmbeddedWorkflow, { isLoading }] = useGetAndLoadEmbeddedWorkflow();
const loadWorkflowWithDialog = useLoadWorkflowWithDialog();
const hasTemplates = useStore($hasTemplates);
const onClick = useCallback(() => {
getAndLoadEmbeddedWorkflow(imageDTO.image_name);
}, [getAndLoadEmbeddedWorkflow, imageDTO.image_name]);
loadWorkflowWithDialog({ type: 'image', data: imageDTO.image_name });
}, [loadWorkflowWithDialog, imageDTO.image_name]);
return (
<MenuItem
icon={isLoading ? <SpinnerIcon /> : <PiFlowArrowBold />}
onClickCapture={onClick}
isDisabled={!imageDTO.has_workflow || !hasTemplates}
>
<MenuItem icon={<PiFlowArrowBold />} onClickCapture={onClick} isDisabled={!imageDTO.has_workflow || !hasTemplates}>
{t('nodes.loadWorkflow')}
</MenuItem>
);

View File

@@ -1,7 +0,0 @@
import { Flex, Spinner } from '@invoke-ai/ui-library';
export const SpinnerIcon = () => (
<Flex w="14px" alignItems="center" justifyContent="center">
<Spinner size="xs" />
</Flex>
);

View File

@@ -17,7 +17,7 @@ import {
} from 'features/stylePresets/store/stylePresetSlice';
import { toast } from 'features/toast/toast';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { useGetAndLoadEmbeddedWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow';
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
import { useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
@@ -147,14 +147,15 @@ export const useImageActions = (imageDTO: ImageDTO) => {
});
}, [metadata, imageDTO]);
const [getAndLoadEmbeddedWorkflow] = useGetAndLoadEmbeddedWorkflow();
const loadWorkflowWithDialog = useLoadWorkflowWithDialog();
const loadWorkflow = useCallback(() => {
const loadWorkflowFromImage = useCallback(() => {
if (!imageDTO.has_workflow || !hasTemplates) {
return;
}
getAndLoadEmbeddedWorkflow(imageDTO.image_name);
}, [getAndLoadEmbeddedWorkflow, hasTemplates, imageDTO.has_workflow, imageDTO.image_name]);
loadWorkflowWithDialog({ type: 'image', data: imageDTO.image_name });
}, [hasTemplates, imageDTO.has_workflow, imageDTO.image_name, loadWorkflowWithDialog]);
const recallSize = useCallback(() => {
if (isStaging) {
@@ -180,7 +181,7 @@ export const useImageActions = (imageDTO: ImageDTO) => {
recallSeed,
recallPrompts,
createAsPreset,
loadWorkflow,
loadWorkflow: loadWorkflowFromImage,
hasWorkflow: imageDTO.has_workflow,
recallSize,
upscale,

View File

@@ -6,7 +6,7 @@ import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { FluxReduxModelFieldInputInstance, FluxReduxModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useFluxReduxModels } from 'services/api/hooks/modelsByType';
import type { FluxReduxModelConfig } from 'services/api/types';
import type { FLUXReduxModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@@ -19,7 +19,7 @@ const FluxReduxModelFieldInputComponent = (
const [modelConfigs, { isLoading }] = useFluxReduxModels();
const _onChange = useCallback(
(value: FluxReduxModelConfig | null) => {
(value: FLUXReduxModelConfig | null) => {
if (!value) {
return;
}

View File

@@ -1,4 +1,4 @@
import { Text } from '@invoke-ai/ui-library';
import { Text, Tooltip } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { linkifyOptions, linkifySx } from 'common/components/linkify';
import { selectWorkflowDescription } from 'features/nodes/store/workflowSlice';
@@ -13,9 +13,11 @@ export const ActiveWorkflowDescription = memo(() => {
}
return (
<Text color="base.300" fontStyle="italic" pb={2} sx={linkifySx}>
<Linkify options={linkifyOptions}>{description}</Linkify>
</Text>
<Tooltip label={description}>
<Text color="base.300" fontStyle="italic" sx={linkifySx} noOfLines={1}>
<Linkify options={linkifyOptions}>{description}</Linkify>
</Text>
</Tooltip>
);
});

View File

@@ -9,7 +9,7 @@ import { useZoomToNode } from 'features/nodes/hooks/useZoomToNode';
import { formElementRemoved } from 'features/nodes/store/workflowSlice';
import type { FormElement, NodeFieldElement } from 'features/nodes/types/workflow';
import { isContainerElement, isNodeFieldElement } from 'features/nodes/types/workflow';
import { startCase } from 'lodash-es';
import { camelCase } from 'lodash-es';
import type { RefObject } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -103,15 +103,16 @@ const RemoveElementButton = memo(({ element }: { element: FormElement }) => {
RemoveElementButton.displayName = 'RemoveElementButton';
const Label = memo(({ element }: { element: FormElement }) => {
const { t } = useTranslation();
const label = useMemo(() => {
if (isContainerElement(element) && element.data.layout === 'column') {
return `Container (column layout)`;
return t('workflows.builder.containerColumnLayout');
}
if (isContainerElement(element) && element.data.layout === 'row') {
return `Container (row layout)`;
return t('workflows.builder.containerRowLayout');
}
return startCase(element.type);
}, [element]);
return t(`workflows.builder.${camelCase(element.type)}`);
}, [element, t]);
return (
<Text fontWeight="semibold" noOfLines={1} wordBreak="break-all" userSelect="none">

View File

@@ -8,6 +8,7 @@ import { useTranslation } from 'react-i18next';
const headingSx: SystemStyleObject = {
fontWeight: 'bold',
fontSize: '2xl',
whiteSpace: 'pre-wrap',
'&[data-is-empty="true"]': {
opacity: 0.3,
},

View File

@@ -7,6 +7,7 @@ import { useTranslation } from 'react-i18next';
const textSx: SystemStyleObject = {
fontSize: 'md',
whiteSpace: 'pre-wrap',
overflowWrap: 'anywhere',
'&[data-is-empty="true"]': {
opacity: 0.3,

View File

@@ -20,7 +20,8 @@ export const DeleteWorkflow = ({ workflowId }: { workflowId: string }) => {
<Tooltip label={t('workflows.delete')} closeOnScroll>
<IconButton
size="sm"
variant="ghost"
variant="link"
alignSelf="stretch"
aria-label={t('workflows.delete')}
onClick={handleClickDelete}
colorScheme="error"

View File

@@ -21,7 +21,8 @@ export const DownloadWorkflow = ({ workflowId }: { workflowId: string }) => {
<Tooltip label={t('workflows.download')} closeOnScroll>
<IconButton
size="sm"
variant="ghost"
variant="link"
alignSelf="stretch"
aria-label={t('workflows.download')}
onClick={handleClickDownload}
icon={<PiDownloadSimpleBold />}

View File

@@ -1,27 +1,37 @@
import { IconButton, Tooltip } from '@invoke-ai/ui-library';
import { useLoadWorkflow } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
import { useAppDispatch } from 'app/store/storeHooks';
import { workflowModeChanged } from 'features/nodes/store/workflowSlice';
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
import type { MouseEvent } from 'react';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPencilBold } from 'react-icons/pi';
export const EditWorkflow = ({ workflowId }: { workflowId: string }) => {
const loadWorkflow = useLoadWorkflow();
const dispatch = useAppDispatch();
const loadWorkflowWithDialog = useLoadWorkflowWithDialog();
const { t } = useTranslation();
const handleClickEdit = useCallback(
(e: MouseEvent<HTMLButtonElement>) => {
e.stopPropagation();
loadWorkflow.loadWithDialog(workflowId, 'edit');
loadWorkflowWithDialog({
type: 'library',
data: workflowId,
onSuccess: () => {
dispatch(workflowModeChanged('edit'));
},
});
},
[loadWorkflow, workflowId]
[dispatch, loadWorkflowWithDialog, workflowId]
);
return (
<Tooltip label={t('workflows.edit')} closeOnScroll>
<IconButton
size="sm"
variant="ghost"
variant="link"
alignSelf="stretch"
aria-label={t('workflows.edit')}
onClick={handleClickEdit}
icon={<PiPencilBold />}

View File

@@ -1,32 +0,0 @@
import { IconButton, Tooltip } from '@invoke-ai/ui-library';
import { useLoadWorkflow } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
import type { MouseEvent } from 'react';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFloppyDiskBold } from 'react-icons/pi';
// needs to clone and save workflow to account without taking over editor
export const SaveWorkflow = ({ workflowId }: { workflowId: string }) => {
const loadWorkflow = useLoadWorkflow();
const { t } = useTranslation();
const handleClickSave = useCallback(
(e: MouseEvent<HTMLButtonElement>) => {
e.stopPropagation();
loadWorkflow.loadWithDialog(workflowId, 'view');
},
[loadWorkflow, workflowId]
);
return (
<Tooltip label={t('workflows.edit')} closeOnScroll>
<IconButton
size="sm"
variant="ghost"
aria-label={t('workflows.edit')}
onClick={handleClickSave}
icon={<PiFloppyDiskBold />}
/>
</Tooltip>
);
};

View File

@@ -22,7 +22,8 @@ export const ShareWorkflowButton = memo(({ workflow }: { workflow: WorkflowRecor
<Tooltip label={t('workflows.copyShareLink')} closeOnScroll>
<IconButton
size="sm"
variant="ghost"
variant="link"
alignSelf="stretch"
aria-label={t('workflows.copyShareLink')}
onClick={handleClickShare}
icon={<PiShareFatBold />}

View File

@@ -1,28 +1,38 @@
import { IconButton, Tooltip } from '@invoke-ai/ui-library';
import { useLoadWorkflow } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
import { useAppDispatch } from 'app/store/storeHooks';
import { workflowModeChanged } from 'features/nodes/store/workflowSlice';
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
import type { MouseEvent } from 'react';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiEyeBold } from 'react-icons/pi';
export const ViewWorkflow = ({ workflowId }: { workflowId: string }) => {
const loadWorkflow = useLoadWorkflow();
const dispatch = useAppDispatch();
const loadWorkflowWithDialog = useLoadWorkflowWithDialog();
const { t } = useTranslation();
const handleClickLoad = useCallback(
(e: MouseEvent<HTMLButtonElement>) => {
e.stopPropagation();
loadWorkflow.loadWithDialog(workflowId, 'view');
loadWorkflowWithDialog({
type: 'library',
data: workflowId,
onSuccess: () => {
dispatch(workflowModeChanged('view'));
},
});
},
[loadWorkflow, workflowId]
[dispatch, loadWorkflowWithDialog, workflowId]
);
return (
<Tooltip label={t('workflows.edit')} closeOnScroll>
<Tooltip label={t('workflows.view')} closeOnScroll>
<IconButton
size="sm"
variant="ghost"
aria-label={t('workflows.edit')}
variant="link"
alignSelf="stretch"
aria-label={t('workflows.view')}
onClick={handleClickLoad}
icon={<PiEyeBold />}
/>

View File

@@ -8,16 +8,27 @@ import {
ModalHeader,
ModalOverlay,
} from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { useWorkflowLibraryModal } from 'features/nodes/store/workflowLibraryModal';
import {
$workflowLibraryCategoriesOptions,
selectWorkflowLibraryView,
workflowLibraryViewChanged,
} from 'features/nodes/store/workflowLibrarySlice';
import { memo, useEffect, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetCountsByCategoryQuery } from 'services/api/endpoints/workflows';
import { WorkflowLibrarySideNav } from './WorkflowLibrarySideNav';
import { WorkflowLibraryTopNav } from './WorkflowLibraryTopNav';
import { WorkflowList } from './WorkflowList';
export const WorkflowLibraryModal = () => {
export const WorkflowLibraryModal = memo(() => {
const { t } = useTranslation();
const workflowLibraryModal = useWorkflowLibraryModal();
const didSync = useSyncInitialWorkflowLibraryCategories();
return (
<Modal isOpen={workflowLibraryModal.isOpen} onClose={workflowLibraryModal.close} isCentered>
<ModalOverlay />
@@ -30,16 +41,102 @@ export const WorkflowLibraryModal = () => {
<ModalHeader>{t('workflows.workflowLibrary')}</ModalHeader>
<ModalCloseButton />
<ModalBody pb={6}>
<Flex gap={4} h="100%">
<WorkflowLibrarySideNav />
<Divider orientation="vertical" />
<Flex flexDir="column" flex={1} gap={4}>
<WorkflowLibraryTopNav />
<WorkflowList />
{didSync && (
<Flex gap={4} h="100%">
<WorkflowLibrarySideNav />
<Divider orientation="vertical" />
<Flex flexDir="column" flex={1} gap={4}>
<WorkflowLibraryTopNav />
<WorkflowList />
</Flex>
</Flex>
</Flex>
)}
{!didSync && <IAINoContentFallback label={t('workflows.loading')} icon={null} />}
</ModalBody>
</ModalContent>
</Modal>
);
});
WorkflowLibraryModal.displayName = 'WorkflowLibraryModal';
/**
* On first app load, if the user's selected view has no workflows, switches to the next available view.
*/
const useSyncInitialWorkflowLibraryCategories = () => {
const dispatch = useAppDispatch();
const view = useAppSelector(selectWorkflowLibraryView);
const categoryOptions = useStore($workflowLibraryCategoriesOptions);
const [didSync, setDidSync] = useState(false);
const recentWorkflowsCountQueryArg = useMemo(
() =>
({
categories: ['user', 'project', 'default'],
has_been_opened: true,
}) satisfies Parameters<typeof useGetCountsByCategoryQuery>[0],
[]
);
const yourWorkflowsCountQueryArg = useMemo(
() =>
({
categories: ['user', 'project'],
}) satisfies Parameters<typeof useGetCountsByCategoryQuery>[0],
[]
);
const queryOptions = useMemo(
() =>
({
selectFromResult: ({ data, isLoading }) => {
if (!data) {
return { count: 0, isLoading: true };
}
return {
count: Object.values(data).reduce((acc, count) => acc + count, 0),
isLoading,
};
},
}) satisfies Parameters<typeof useGetCountsByCategoryQuery>[1],
[]
);
const { count: recentWorkflowsCount, isLoading: isLoadingRecentWorkflowsCount } = useGetCountsByCategoryQuery(
recentWorkflowsCountQueryArg,
queryOptions
);
const { count: yourWorkflowsCount, isLoading: isLoadingYourWorkflowsCount } = useGetCountsByCategoryQuery(
yourWorkflowsCountQueryArg,
queryOptions
);
useEffect(() => {
if (didSync || isLoadingRecentWorkflowsCount || isLoadingYourWorkflowsCount) {
return;
}
// If the user's selected view has no workflows, switch to the next available view
if (recentWorkflowsCount === 0 && view === 'recent') {
if (yourWorkflowsCount > 0) {
dispatch(workflowLibraryViewChanged('yours'));
} else {
dispatch(workflowLibraryViewChanged('defaults'));
}
} else if (yourWorkflowsCount === 0 && (view === 'yours' || view === 'shared' || view === 'private')) {
if (recentWorkflowsCount > 0) {
dispatch(workflowLibraryViewChanged('recent'));
} else {
dispatch(workflowLibraryViewChanged('defaults'));
}
}
setDidSync(true);
}, [
categoryOptions,
didSync,
dispatch,
isLoadingRecentWorkflowsCount,
isLoadingYourWorkflowsCount,
recentWorkflowsCount,
view,
yourWorkflowsCount,
]);
return didSync;
};

View File

@@ -1,145 +1,55 @@
import type { ButtonProps, CheckboxProps } from '@invoke-ai/ui-library';
import { Button, Checkbox, Collapse, Flex, Icon, Spacer, Text } from '@invoke-ai/ui-library';
import { Button, Checkbox, Collapse, Flex, Spacer, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { $workflowCategories } from 'app/store/nanostores/workflowCategories';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { WORKFLOW_TAGS, type WorkflowTag } from 'features/nodes/store/types';
import type { WorkflowLibraryView, WorkflowTagCategory } from 'features/nodes/store/workflowLibrarySlice';
import {
$workflowLibraryCategoriesOptions,
$workflowLibraryTagCategoriesOptions,
$workflowLibraryTagOptions,
selectWorkflowLibrarySelectedTags,
selectWorkflowSelectedCategories,
workflowSelectedCategoriesChanged,
workflowSelectedTagsRese,
workflowSelectedTagToggled,
} from 'features/nodes/store/workflowSlice';
import { useLoadWorkflow } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
selectWorkflowLibraryView,
workflowLibraryTagsReset,
workflowLibraryTagToggled,
workflowLibraryViewChanged,
} from 'features/nodes/store/workflowLibrarySlice';
import { NewWorkflowButton } from 'features/workflowLibrary/components/NewWorkflowButton';
import { UploadWorkflowButton } from 'features/workflowLibrary/components/UploadWorkflowButton';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold, PiUsersBold } from 'react-icons/pi';
import { useDispatch } from 'react-redux';
import { useGetCountsQuery, useListWorkflowsQuery } from 'services/api/endpoints/workflows';
import type { S } from 'services/api/types';
import { useGetCountsByTagQuery } from 'services/api/endpoints/workflows';
export const WorkflowLibrarySideNav = () => {
const { t } = useTranslation();
const dispatch = useDispatch();
const categories = useAppSelector(selectWorkflowSelectedCategories);
const categoryOptions = useStore($workflowCategories);
const selectedTags = useAppSelector(selectWorkflowLibrarySelectedTags);
const selectYourWorkflows = useCallback(() => {
dispatch(workflowSelectedCategoriesChanged(categoryOptions.includes('project') ? ['user', 'project'] : ['user']));
}, [categoryOptions, dispatch]);
const selectPrivateWorkflows = useCallback(() => {
dispatch(workflowSelectedCategoriesChanged(['user']));
}, [dispatch]);
const selectSharedWorkflows = useCallback(() => {
dispatch(workflowSelectedCategoriesChanged(['project']));
}, [dispatch]);
const selectDefaultWorkflows = useCallback(() => {
dispatch(workflowSelectedCategoriesChanged(['default']));
}, [dispatch]);
const resetTags = useCallback(() => {
dispatch(workflowSelectedTagsRese());
}, [dispatch]);
const isYourWorkflowsSelected = useMemo(() => {
if (categoryOptions.includes('project')) {
return categories.includes('user') && categories.includes('project');
} else {
return categories.includes('user');
}
}, [categoryOptions, categories]);
const isPrivateWorkflowsExclusivelySelected = useMemo(() => {
return categories.length === 1 && categories.includes('user');
}, [categories]);
const isSharedWorkflowsExclusivelySelected = useMemo(() => {
return categories.length === 1 && categories.includes('project');
}, [categories]);
const isDefaultWorkflowsExclusivelySelected = useMemo(() => {
return categories.length === 1 && categories.includes('default');
}, [categories]);
const categoryOptions = useStore($workflowLibraryCategoriesOptions);
const view = useAppSelector(selectWorkflowLibraryView);
return (
<Flex h="full" minH={0} overflow="hidden" flexDir="column" w={64} gap={1}>
<Flex flexDir="column" w="full" pb={2}>
<Text px={3} py={2} fontSize="md" fontWeight="semibold">
{t('workflows.recentlyOpened')}
</Text>
<Flex flexDir="column" gap={2} pl={4}>
<RecentWorkflows />
</Flex>
<WorkflowLibraryViewButton view="recent">{t('workflows.recentlyOpened')}</WorkflowLibraryViewButton>
</Flex>
<Flex flexDir="column" w="full" pb={2}>
<CategoryButton isSelected={isYourWorkflowsSelected} onClick={selectYourWorkflows}>
{t('workflows.yourWorkflows')}
</CategoryButton>
<WorkflowLibraryViewButton view="yours">{t('workflows.yourWorkflows')}</WorkflowLibraryViewButton>
{categoryOptions.includes('project') && (
<Collapse
in={
isYourWorkflowsSelected || isPrivateWorkflowsExclusivelySelected || isSharedWorkflowsExclusivelySelected
}
>
<Collapse in={view === 'yours' || view === 'shared' || view === 'private'}>
<Flex flexDir="column" gap={2} pl={4} pt={2}>
<CategoryButton
size="sm"
onClick={selectPrivateWorkflows}
isSelected={isPrivateWorkflowsExclusivelySelected}
>
<WorkflowLibraryViewButton size="sm" view="private">
{t('workflows.private')}
</CategoryButton>
<CategoryButton
size="sm"
rightIcon={<PiUsersBold />}
onClick={selectSharedWorkflows}
isSelected={isSharedWorkflowsExclusivelySelected}
>
</WorkflowLibraryViewButton>
<WorkflowLibraryViewButton size="sm" rightIcon={<PiUsersBold />} view="shared">
{t('workflows.shared')}
<Spacer />
</CategoryButton>
</WorkflowLibraryViewButton>
</Flex>
</Collapse>
)}
</Flex>
<Flex h="full" minH={0} overflow="hidden" flexDir="column">
<CategoryButton isSelected={isDefaultWorkflowsExclusivelySelected} onClick={selectDefaultWorkflows}>
{t('workflows.browseWorkflows')}
</CategoryButton>
<Collapse in={isDefaultWorkflowsExclusivelySelected}>
<Flex flexDir="column" gap={2} pl={4} py={2} overflow="hidden" h="100%" minH={0}>
<Button
isDisabled={!isDefaultWorkflowsExclusivelySelected || selectedTags.length === 0}
onClick={resetTags}
size="sm"
variant="link"
fontWeight="bold"
justifyContent="flex-start"
flexGrow={0}
flexShrink={0}
leftIcon={<PiArrowCounterClockwiseBold />}
h={8}
>
{t('workflows.resetTags')}
</Button>
<Flex flexDir="column" gap={2} overflow="auto">
{WORKFLOW_TAGS.map((tagCategory) => (
<TagCategory
key={tagCategory.category}
tagCategory={tagCategory}
isDisabled={!isDefaultWorkflowsExclusivelySelected}
/>
))}
</Flex>
</Flex>
</Collapse>
<WorkflowLibraryViewButton view="defaults">{t('workflows.browseWorkflows')}</WorkflowLibraryViewButton>
<DefaultsViewCheckboxesCollapsible />
</Flex>
<Spacer />
<NewWorkflowButton />
@@ -148,61 +58,107 @@ export const WorkflowLibrarySideNav = () => {
);
};
const recentWorkflowsQueryArg = {
page: 0,
per_page: 5,
order_by: 'opened_at',
direction: 'DESC',
} satisfies Parameters<typeof useListWorkflowsQuery>[0];
const RecentWorkflows = memo(() => {
const DefaultsViewCheckboxesCollapsible = memo(() => {
const { t } = useTranslation();
const { data, isLoading } = useListWorkflowsQuery(recentWorkflowsQueryArg);
const dispatch = useDispatch();
const tags = useAppSelector(selectWorkflowLibrarySelectedTags);
const tagCategoryOptions = useStore($workflowLibraryTagCategoriesOptions);
const view = useAppSelector(selectWorkflowLibraryView);
if (isLoading) {
return <Text variant="subtext">{t('common.loading')}</Text>;
}
if (!data) {
return <Text variant="subtext">{t('workflows.noRecentWorkflows')}</Text>;
}
const resetTags = useCallback(() => {
dispatch(workflowLibraryTagsReset());
}, [dispatch]);
return (
<>
{data.items.map((workflow) => {
return <RecentWorkflowButton key={workflow.workflow_id} workflow={workflow} />;
})}
</>
<Collapse in={view === 'defaults'}>
<Flex flexDir="column" gap={2} pl={4} py={2} overflow="hidden" h="100%" minH={0}>
<Button
isDisabled={tags.length === 0}
onClick={resetTags}
size="sm"
variant="link"
fontWeight="bold"
justifyContent="flex-start"
flexGrow={0}
flexShrink={0}
leftIcon={<PiArrowCounterClockwiseBold />}
h={8}
>
{t('workflows.deselectAll')}
</Button>
<Flex flexDir="column" gap={2} overflow="auto">
{tagCategoryOptions.map((tagCategory) => (
<TagCategory key={tagCategory.categoryTKey} tagCategory={tagCategory} />
))}
</Flex>
</Flex>
</Collapse>
);
});
RecentWorkflows.displayName = 'RecentWorkflows';
DefaultsViewCheckboxesCollapsible.displayName = 'DefaultsViewCheckboxes';
const RecentWorkflowButton = memo(({ workflow }: { workflow: S['WorkflowRecordListItemWithThumbnailDTO'] }) => {
const loadWorkflow = useLoadWorkflow();
const load = useCallback(() => {
loadWorkflow.loadWithDialog(workflow.workflow_id, 'view');
}, [loadWorkflow, workflow.workflow_id]);
return (
<Flex
role="button"
key={workflow.workflow_id}
gap={2}
alignItems="center"
_hover={{ textDecoration: 'underline' }}
color="base.300"
onClick={load}
>
<Text as="span" noOfLines={1} w="full" fontWeight="semibold">
{workflow.name}
</Text>
{workflow.category === 'project' && <Icon as={PiUsersBold} boxSize="12px" />}
</Flex>
const useCountForIndividualTag = (tag: string) => {
const allTags = useStore($workflowLibraryTagOptions);
const queryArg = useMemo(
() =>
({
tags: allTags,
categories: ['default'],
}) satisfies Parameters<typeof useGetCountsByTagQuery>[0],
[allTags]
);
const queryOptions = useMemo(
() =>
({
selectFromResult: ({ data }) => ({
count: data?.[tag] ?? 0,
}),
}) satisfies Parameters<typeof useGetCountsByTagQuery>[1],
[tag]
);
});
RecentWorkflowButton.displayName = 'RecentWorkflowButton';
const CategoryButton = memo(({ isSelected, ...rest }: ButtonProps & { isSelected: boolean }) => {
const { count } = useGetCountsByTagQuery(queryArg, queryOptions);
return count;
};
const useCountForTagCategory = (tagCategory: WorkflowTagCategory) => {
const allTags = useStore($workflowLibraryTagOptions);
const queryArg = useMemo(
() =>
({
tags: allTags,
categories: ['default'], // We only allow filtering by tag for default workflows
}) satisfies Parameters<typeof useGetCountsByTagQuery>[0],
[allTags]
);
const queryOptions = useMemo(
() =>
({
selectFromResult: ({ data }) => {
if (!data) {
return { count: 0 };
}
return {
count: tagCategory.tags.reduce((acc, tag) => acc + (data[tag] ?? 0), 0),
};
},
}) satisfies Parameters<typeof useGetCountsByTagQuery>[1],
[tagCategory]
);
const { count } = useGetCountsByTagQuery(queryArg, queryOptions);
return count;
};
const WorkflowLibraryViewButton = memo(({ view, ...rest }: ButtonProps & { view: WorkflowLibraryView }) => {
const dispatch = useDispatch();
const selectedView = useAppSelector(selectWorkflowLibraryView);
const onClick = useCallback(() => {
dispatch(workflowLibraryViewChanged(view));
}, [dispatch, view]);
return (
<Button
variant="ghost"
@@ -210,61 +166,54 @@ const CategoryButton = memo(({ isSelected, ...rest }: ButtonProps & { isSelected
size="md"
flexShrink={0}
w="full"
onClick={onClick}
{...rest}
bg={isSelected ? 'base.700' : undefined}
color={isSelected ? 'base.50' : undefined}
bg={selectedView === view ? 'base.700' : undefined}
color={selectedView === view ? 'base.50' : undefined}
/>
);
});
CategoryButton.displayName = 'NavButton';
WorkflowLibraryViewButton.displayName = 'NavButton';
const TagCategory = memo(
({ tagCategory, isDisabled }: { tagCategory: (typeof WORKFLOW_TAGS)[number]; isDisabled: boolean }) => {
const { count } = useGetCountsQuery(
{ tags: [...tagCategory.tags], categories: ['default'] },
{ selectFromResult: ({ data }) => ({ count: data ?? 0 }) }
);
if (count === 0) {
return null;
}
return (
<Flex flexDir="column" gap={2}>
<Text fontWeight="semibold" color="base.300" opacity={isDisabled ? 0.5 : 1} flexShrink={0}>
{tagCategory.category}
</Text>
<Flex flexDir="column" gap={2} pl={4}>
{tagCategory.tags.map((tag) => (
<TagCheckbox key={tag} tag={tag} isDisabled={isDisabled} />
))}
</Flex>
</Flex>
);
}
);
TagCategory.displayName = 'TagCategory';
const TagCheckbox = memo(({ tag, ...rest }: CheckboxProps & { tag: WorkflowTag }) => {
const dispatch = useAppDispatch();
const selectedTags = useAppSelector(selectWorkflowLibrarySelectedTags);
const isSelected = selectedTags.includes(tag);
const onChange = useCallback(() => {
dispatch(workflowSelectedTagToggled(tag));
}, [dispatch, tag]);
const { count } = useGetCountsQuery(
{ tags: [tag], categories: ['default'] },
{ selectFromResult: ({ data }) => ({ count: data ?? 0 }) }
);
const TagCategory = memo(({ tagCategory }: { tagCategory: WorkflowTagCategory }) => {
const { t } = useTranslation();
const count = useCountForTagCategory(tagCategory);
if (count === 0) {
return null;
}
return (
<Checkbox isChecked={isSelected} onChange={onChange} {...rest} flexShrink={0}>
<Flex flexDir="column" gap={2}>
<Text fontWeight="semibold" color="base.300" flexShrink={0}>
{t(tagCategory.categoryTKey)}
</Text>
<Flex flexDir="column" gap={2} pl={4}>
{tagCategory.tags.map((tag) => (
<TagCheckbox key={tag} tag={tag} />
))}
</Flex>
</Flex>
);
});
TagCategory.displayName = 'TagCategory';
const TagCheckbox = memo(({ tag, ...rest }: CheckboxProps & { tag: string }) => {
const dispatch = useAppDispatch();
const selectedTags = useAppSelector(selectWorkflowLibrarySelectedTags);
const isChecked = selectedTags.includes(tag);
const count = useCountForIndividualTag(tag);
const onChange = useCallback(() => {
dispatch(workflowLibraryTagToggled(tag));
}, [dispatch, tag]);
if (count === 0) {
return null;
}
return (
<Checkbox isChecked={isChecked} onChange={onChange} {...rest} flexShrink={0}>
<Text>{`${tag} (${count})`}</Text>
</Checkbox>
);

View File

@@ -2,31 +2,59 @@ import { Button, Flex, Grid, GridItem, Spacer, Spinner } from '@invoke-ai/ui-lib
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import type { WorkflowLibraryView } from 'features/nodes/store/workflowLibrarySlice';
import {
selectWorkflowLibraryDirection,
selectWorkflowLibraryHasSearchTerm,
selectWorkflowLibraryOrderBy,
selectWorkflowLibrarySearchTerm,
selectWorkflowLibrarySelectedTags,
selectWorkflowOrderBy,
selectWorkflowOrderDirection,
selectWorkflowSearchTerm,
selectWorkflowSelectedCategories,
} from 'features/nodes/store/workflowSlice';
selectWorkflowLibraryView,
} from 'features/nodes/store/workflowLibrarySlice';
import type { WorkflowCategory } from 'features/nodes/types/workflow';
import { memo, useCallback, useMemo, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import type { useListWorkflowsQuery } from 'services/api/endpoints/workflows';
import { useListWorkflowsInfiniteInfiniteQuery } from 'services/api/endpoints/workflows';
import type { S } from 'services/api/types';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import { useDebounce } from 'use-debounce';
import { WorkflowListItem } from './WorkflowListItem';
const PER_PAGE = 30;
const getCategories = (view: WorkflowLibraryView): WorkflowCategory[] => {
switch (view) {
case 'defaults':
return ['default'];
case 'recent':
return ['user', 'project', 'default'];
case 'yours':
return ['user', 'project'];
case 'private':
return ['user'];
case 'shared':
return ['project'];
default:
assert<Equals<typeof view, never>>(false);
}
};
const getHasBeenOpened = (view: WorkflowLibraryView): boolean | undefined => {
if (view === 'recent') {
return true;
}
return undefined;
};
const useInfiniteQueryAry = () => {
const categories = useAppSelector(selectWorkflowSelectedCategories);
const orderBy = useAppSelector(selectWorkflowOrderBy);
const direction = useAppSelector(selectWorkflowOrderDirection);
const query = useAppSelector(selectWorkflowSearchTerm);
const tags = useAppSelector(selectWorkflowLibrarySelectedTags);
const [debouncedQuery] = useDebounce(query, 500);
const orderBy = useAppSelector(selectWorkflowLibraryOrderBy);
const direction = useAppSelector(selectWorkflowLibraryDirection);
const searchTerm = useAppSelector(selectWorkflowLibrarySearchTerm);
const selectedTags = useAppSelector(selectWorkflowLibrarySelectedTags);
const view = useAppSelector(selectWorkflowLibraryView);
const [debouncedSearchTerm] = useDebounce(searchTerm, 500);
const queryArg = useMemo(() => {
return {
@@ -34,11 +62,12 @@ const useInfiniteQueryAry = () => {
per_page: PER_PAGE,
order_by: orderBy ?? 'opened_at',
direction,
categories,
query: debouncedQuery,
tags: categories.length === 1 && categories.includes('default') ? tags : [],
} satisfies Parameters<typeof useListWorkflowsQuery>[0];
}, [orderBy, direction, categories, debouncedQuery, tags]);
categories: getCategories(view),
query: debouncedSearchTerm,
tags: view === 'defaults' ? selectedTags : [],
has_been_opened: getHasBeenOpened(view),
} satisfies Parameters<typeof useListWorkflowsInfiniteInfiniteQuery>[0];
}, [orderBy, direction, view, debouncedSearchTerm, selectedTags]);
return queryArg;
};
@@ -52,9 +81,7 @@ const queryOptions = {
},
} satisfies Parameters<typeof useListWorkflowsInfiniteInfiniteQuery>[1];
export const WorkflowList = () => {
const searchTerm = useAppSelector(selectWorkflowSearchTerm);
const { t } = useTranslation();
export const WorkflowList = memo(() => {
const queryArg = useInfiniteQueryAry();
const { items, isFetching, isLoading, fetchNextPage, hasNextPage } = useListWorkflowsInfiniteInfiniteQuery(
queryArg,
@@ -70,14 +97,7 @@ export const WorkflowList = () => {
}
if (items.length === 0) {
return (
<IAINoContentFallback
fontSize="sm"
py={4}
label={searchTerm ? t('nodes.noMatchingWorkflows') : t('nodes.noWorkflows')}
icon={null}
/>
);
return <NoItems />;
}
return (
@@ -88,8 +108,23 @@ export const WorkflowList = () => {
isFetching={isFetching}
/>
);
};
});
WorkflowList.displayName = 'WorkflowList';
const NoItems = memo(() => {
const { t } = useTranslation();
const hasSearchTerm = useAppSelector(selectWorkflowLibraryHasSearchTerm);
return (
<IAINoContentFallback
fontSize="sm"
py={4}
label={hasSearchTerm ? t('nodes.noMatchingWorkflows') : t('nodes.noWorkflows')}
icon={null}
/>
);
});
NoItems.displayName = 'NoItems';
const WorkflowListContent = memo(
({
items,
@@ -153,7 +188,7 @@ const WorkflowListContent = memo(
<Flex flexDir="column" gap={4} flex={1} minH={0}>
<Grid
ref={ref}
templateColumns="repeat(auto-fill, minmax(340px, 1fr))"
templateColumns="repeat(auto-fill, minmax(360px, 1fr))"
gridAutoFlow="dense"
gap={4}
overflow="scroll"

View File

@@ -1,23 +1,22 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Badge, Flex, Icon, Image, Spacer, Text } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { ShareWorkflowButton } from 'features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowLibraryListItemActions/ShareWorkflow';
import { selectWorkflowId } from 'features/nodes/store/workflowSlice';
import { useLoadWorkflow } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
import { selectWorkflowId, workflowModeChanged } from 'features/nodes/store/workflowSlice';
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
import InvokeLogo from 'public/assets/images/invoke-symbol-wht-lrg.svg';
import { useCallback, useMemo } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiImageBold, PiUsersBold } from 'react-icons/pi';
import { PiImage, PiUsersBold } from 'react-icons/pi';
import type { WorkflowRecordListItemWithThumbnailDTO } from 'services/api/types';
import { DeleteWorkflow } from './WorkflowLibraryListItemActions/DeleteWorkflow';
import { DownloadWorkflow } from './WorkflowLibraryListItemActions/DownloadWorkflow';
import { EditWorkflow } from './WorkflowLibraryListItemActions/EditWorkflow';
import { SaveWorkflow } from './WorkflowLibraryListItemActions/SaveWorkflow';
import { ViewWorkflow } from './WorkflowLibraryListItemActions/ViewWorkflow';
const IMAGE_THUMBNAIL_SIZE = '80px';
const FALLBACK_ICON_SIZE = '24px';
const IMAGE_THUMBNAIL_SIZE = '108px';
const FALLBACK_ICON_SIZE = '32px';
const WORKFLOW_ACTION_BUTTONS_CN = 'workflow-action-buttons';
@@ -30,97 +29,97 @@ const sx: SystemStyleObject = {
},
};
export const WorkflowListItem = ({ workflow }: { workflow: WorkflowRecordListItemWithThumbnailDTO }) => {
export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordListItemWithThumbnailDTO }) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const workflowId = useAppSelector(selectWorkflowId);
const loadWorkflow = useLoadWorkflow();
const loadWorkflowWithDialog = useLoadWorkflowWithDialog();
const isActive = useMemo(() => {
return workflowId === workflow.workflow_id;
}, [workflowId, workflow.workflow_id]);
const handleClickLoad = useCallback(() => {
loadWorkflow.loadWithDialog(workflow.workflow_id, 'view');
}, [loadWorkflow, workflow.workflow_id]);
loadWorkflowWithDialog({
type: 'library',
data: workflow.workflow_id,
onSuccess: () => {
dispatch(workflowModeChanged('view'));
},
});
}, [dispatch, loadWorkflowWithDialog, workflow.workflow_id]);
return (
<Flex
position="relative"
role="button"
gap={4}
onClick={handleClickLoad}
cursor="pointer"
bg="base.750"
p={2}
ps={3}
borderRadius="base"
w="full"
alignItems="stretch"
sx={sx}
gap={2}
>
<Image
src={workflow.thumbnail_url ?? undefined}
fallbackStrategy="beforeLoadOrError"
fallback={
<Flex
height={IMAGE_THUMBNAIL_SIZE}
minWidth={IMAGE_THUMBNAIL_SIZE}
bg="base.650"
borderRadius="base"
alignItems="center"
justifyContent="center"
>
<Icon color="base.500" as={PiImageBold} boxSize={FALLBACK_ICON_SIZE} />
</Flex>
}
objectFit="cover"
objectPosition="50% 50%"
height={IMAGE_THUMBNAIL_SIZE}
width={IMAGE_THUMBNAIL_SIZE}
minHeight={IMAGE_THUMBNAIL_SIZE}
minWidth={IMAGE_THUMBNAIL_SIZE}
borderRadius="base"
/>
<Flex flexDir="column" gap={1} justifyContent="flex-start">
<Flex gap={4} alignItems="center">
<Text noOfLines={2}>{workflow.name}</Text>
{isActive && (
<Badge color="invokeBlue.400" borderColor="invokeBlue.700" borderWidth={1} bg="transparent" flexShrink={0}>
{t('workflows.opened')}
</Badge>
)}
</Flex>
<Text variant="subtext" fontSize="xs" noOfLines={2}>
{workflow.description}
</Text>
<Flex p={2} pr={0}>
<Image
src={workflow.thumbnail_url ?? undefined}
fallbackStrategy="beforeLoadOrError"
fallback={workflow.category === 'default' ? <DefaultThumbnailFallback /> : <UserThumbnailFallback />}
objectFit="cover"
objectPosition="50% 50%"
height={IMAGE_THUMBNAIL_SIZE}
width={IMAGE_THUMBNAIL_SIZE}
minHeight={IMAGE_THUMBNAIL_SIZE}
minWidth={IMAGE_THUMBNAIL_SIZE}
borderRadius="base"
/>
</Flex>
<Spacer />
<Flex flexDir="column" gap={1} justifyContent="space-between" position="relative">
<Flex gap={1} justifyContent="flex-end" w="full" p={2}>
{workflow.category === 'project' && <Icon as={PiUsersBold} color="base.200" />}
{workflow.category === 'default' && (
<Image src={InvokeLogo} alt="invoke-logo" w="14px" h="14px" minW="14px" minH="14px" userSelect="none" />
)}
<Flex flexDir="column" gap={1} justifyContent="space-between" w="full">
<Flex flexDir="column" gap={1} alignItems="flex-start" pt={2} pe={2} w="full">
<Flex gap={2} alignItems="flex-start" justifyContent="space-between" w="full">
<Text noOfLines={2}>{workflow.name}</Text>
<Flex gap={2} alignItems="center">
{isActive && (
<Badge
color="invokeBlue.400"
borderColor="invokeBlue.700"
borderWidth={1}
bg="transparent"
flexShrink={0}
variant="subtle"
>
{t('workflows.opened')}
</Badge>
)}
{workflow.category === 'project' && <Icon as={PiUsersBold} color="base.200" />}
{workflow.category === 'default' && (
<Image
src={InvokeLogo}
alt="invoke-logo"
w="14px"
h="14px"
minW="14px"
minH="14px"
userSelect="none"
opacity={0.5}
/>
)}
</Flex>
</Flex>
<Text variant="subtext" fontSize="xs" noOfLines={3}>
{workflow.description}
</Text>
</Flex>
<Flex
alignItems="center"
gap={1}
display="none"
className={WORKFLOW_ACTION_BUTTONS_CN}
position="absolute"
right={0}
bottom={0}
>
{workflow.category === 'default' && (
<>
{/* need to consider what is useful here and which icons show that. idea is to "try it out"/"view" or "clone for your own changes" */}
<ViewWorkflow workflowId={workflow.workflow_id} />
<SaveWorkflow workflowId={workflow.workflow_id} />
</>
<Flex className={WORKFLOW_ACTION_BUTTONS_CN} alignItems="center" display="none" h={8}>
{workflow.opened_at && (
<Text variant="subtext" fontSize="xs" noOfLines={1} justifySelf="flex-end" pb={0.5}>
{t('workflows.opened')} {new Date(workflow.opened_at).toLocaleString()}
</Text>
)}
<Spacer />
{workflow.category === 'default' && <ViewWorkflow workflowId={workflow.workflow_id} />}
{workflow.category !== 'default' && (
<>
<EditWorkflow workflowId={workflow.workflow_id} />
@@ -133,4 +132,39 @@ export const WorkflowListItem = ({ workflow }: { workflow: WorkflowRecordListIte
</Flex>
</Flex>
);
};
});
WorkflowListItem.displayName = 'WorkflowListItem';
const UserThumbnailFallback = memo(() => {
return (
<Flex
height={IMAGE_THUMBNAIL_SIZE}
minWidth={IMAGE_THUMBNAIL_SIZE}
bg="base.600"
borderRadius="base"
alignItems="center"
justifyContent="center"
opacity={0.3}
>
<Icon as={PiImage} boxSize={FALLBACK_ICON_SIZE} />
</Flex>
);
});
UserThumbnailFallback.displayName = 'UserThumbnailFallback';
const DefaultThumbnailFallback = memo(() => {
return (
<Flex
height={IMAGE_THUMBNAIL_SIZE}
minWidth={IMAGE_THUMBNAIL_SIZE}
bg="base.600"
borderRadius="base"
alignItems="center"
justifyContent="center"
opacity={0.3}
>
<Image src={InvokeLogo} alt="invoke-logo" userSelect="none" boxSize={FALLBACK_ICON_SIZE} p={1} />
</Flex>
);
});
DefaultThumbnailFallback.displayName = 'DefaultThumbnailFallback';

View File

@@ -1,6 +1,9 @@
import { Flex, IconButton, Input, InputGroup, InputRightElement } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectWorkflowSearchTerm, workflowSearchTermChanged } from 'features/nodes/store/workflowSlice';
import {
selectWorkflowLibrarySearchTerm,
workflowLibrarySearchTermChanged,
} from 'features/nodes/store/workflowLibrarySlice';
import type { ChangeEvent, KeyboardEvent, RefObject } from 'react';
import { memo, useCallback, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
@@ -8,18 +11,18 @@ import { PiXBold } from 'react-icons/pi';
export const WorkflowSearch = memo(({ searchInputRef }: { searchInputRef: RefObject<HTMLInputElement> }) => {
const dispatch = useAppDispatch();
const searchTerm = useAppSelector(selectWorkflowSearchTerm);
const searchTerm = useAppSelector(selectWorkflowLibrarySearchTerm);
const { t } = useTranslation();
const handleWorkflowSearch = useCallback(
(newSearchTerm: string) => {
dispatch(workflowSearchTermChanged(newSearchTerm));
dispatch(workflowLibrarySearchTermChanged(newSearchTerm));
},
[dispatch]
);
const clearWorkflowSearch = useCallback(() => {
dispatch(workflowSearchTermChanged(''));
dispatch(workflowLibrarySearchTermChanged(''));
}, [dispatch]);
const handleKeydown = useCallback(

View File

@@ -1,11 +1,11 @@
import { Flex, FormControl, FormLabel, Select } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
selectWorkflowOrderBy,
selectWorkflowOrderDirection,
workflowOrderByChanged,
workflowOrderDirectionChanged,
} from 'features/nodes/store/workflowSlice';
selectWorkflowLibraryDirection,
selectWorkflowLibraryOrderBy,
workflowLibraryDirectionChanged,
workflowLibraryOrderByChanged,
} from 'features/nodes/store/workflowLibrarySlice';
import type { ChangeEvent } from 'react';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -22,8 +22,8 @@ const isDirection = (v: unknown): v is Direction => zDirection.safeParse(v).succ
export const WorkflowSortControl = () => {
const { t } = useTranslation();
const orderBy = useAppSelector(selectWorkflowOrderBy);
const direction = useAppSelector(selectWorkflowOrderDirection);
const orderBy = useAppSelector(selectWorkflowLibraryOrderBy);
const direction = useAppSelector(selectWorkflowLibraryDirection);
const ORDER_BY_LABELS = useMemo(
() => ({
@@ -50,7 +50,7 @@ export const WorkflowSortControl = () => {
if (!isOrderBy(e.target.value)) {
return;
}
dispatch(workflowOrderByChanged(e.target.value));
dispatch(workflowLibraryOrderByChanged(e.target.value));
},
[dispatch]
);
@@ -60,7 +60,7 @@ export const WorkflowSortControl = () => {
if (!isDirection(e.target.value)) {
return;
}
dispatch(workflowOrderDirectionChanged(e.target.value));
dispatch(workflowLibraryDirectionChanged(e.target.value));
},
[dispatch]
);

View File

@@ -1,8 +1,7 @@
import type { HandleType } from '@xyflow/react';
import type { FieldInputTemplate, FieldOutputTemplate, StatefulFieldValue } from 'features/nodes/types/field';
import type { AnyEdge, AnyNode, InvocationTemplate, NodeExecutionState } from 'features/nodes/types/invocation';
import type { WorkflowCategory, WorkflowV3 } from 'features/nodes/types/workflow';
import type { SQLiteDirection, WorkflowRecordOrderBy } from 'services/api/types';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
export type Templates = Record<string, InvocationTemplate>;
export type NodeExecutionStates = Record<string, NodeExecutionState | undefined>;
@@ -22,22 +21,9 @@ export type NodesState = {
export type WorkflowMode = 'edit' | 'view';
export const WORKFLOW_TAGS = [
{ category: 'Industry', tags: ['Architecture', 'Fashion', 'Game Dev', 'Food'] },
{ category: 'Common Tasks', tags: ['Upscaling', 'Text to Image', 'Image to Image'] },
{ category: 'Model Architecture', tags: ['SD1.5', 'SDXL', 'Bria', 'FLUX'] },
{ category: 'Tech Showcase', tags: ['Control', 'Reference Image'] },
] as const;
export type WorkflowTag = (typeof WORKFLOW_TAGS)[number]['tags'][number];
export type WorkflowsState = Omit<WorkflowV3, 'nodes' | 'edges'> & {
_version: 1;
isTouched: boolean;
mode: WorkflowMode;
selectedTags: WorkflowTag[];
selectedCategories: WorkflowCategory[];
searchTerm: string;
orderBy?: WorkflowRecordOrderBy;
orderDirection: SQLiteDirection;
formFieldInitialValues: Record<string, StatefulFieldValue>;
};

View File

@@ -0,0 +1,103 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import type { WorkflowCategory } from 'features/nodes/types/workflow';
import { atom, computed } from 'nanostores';
import type { SQLiteDirection, WorkflowRecordOrderBy } from 'services/api/types';
export type WorkflowLibraryView = 'recent' | 'yours' | 'private' | 'shared' | 'defaults';
type WorkflowLibraryState = {
view: WorkflowLibraryView;
orderBy: WorkflowRecordOrderBy;
direction: SQLiteDirection;
searchTerm: string;
selectedTags: string[];
};
const initialWorkflowLibraryState: WorkflowLibraryState = {
searchTerm: '',
orderBy: 'opened_at',
direction: 'DESC',
selectedTags: [],
view: 'defaults',
};
export const workflowLibrarySlice = createSlice({
name: 'workflowLibrary',
initialState: initialWorkflowLibraryState,
reducers: {
workflowLibrarySearchTermChanged: (state, action: PayloadAction<string>) => {
state.searchTerm = action.payload;
},
workflowLibraryOrderByChanged: (state, action: PayloadAction<WorkflowRecordOrderBy>) => {
state.orderBy = action.payload;
},
workflowLibraryDirectionChanged: (state, action: PayloadAction<SQLiteDirection>) => {
state.direction = action.payload;
},
workflowLibraryViewChanged: (state, action: PayloadAction<WorkflowLibraryState['view']>) => {
state.view = action.payload;
state.searchTerm = '';
},
workflowLibraryTagToggled: (state, action: PayloadAction<string>) => {
const tag = action.payload;
const tags = state.selectedTags;
if (tags.includes(tag)) {
state.selectedTags = tags.filter((t) => t !== tag);
} else {
state.selectedTags = [...tags, tag];
}
},
workflowLibraryTagsReset: (state) => {
state.selectedTags = [];
},
},
});
export const {
workflowLibrarySearchTermChanged,
workflowLibraryOrderByChanged,
workflowLibraryDirectionChanged,
workflowLibraryTagToggled,
workflowLibraryTagsReset,
workflowLibraryViewChanged,
} = workflowLibrarySlice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateWorkflowLibraryState = (state: any): any => state;
export const workflowLibraryPersistConfig: PersistConfig<WorkflowLibraryState> = {
name: workflowLibrarySlice.name,
initialState: initialWorkflowLibraryState,
migrate: migrateWorkflowLibraryState,
persistDenylist: [],
};
const selectWorkflowLibrarySlice = (state: RootState) => state.workflowLibrary;
const createWorkflowLibrarySelector = <T>(selector: Selector<WorkflowLibraryState, T>) =>
createSelector(selectWorkflowLibrarySlice, selector);
export const selectWorkflowLibrarySearchTerm = createWorkflowLibrarySelector(({ searchTerm }) => searchTerm);
export const selectWorkflowLibraryHasSearchTerm = createWorkflowLibrarySelector(({ searchTerm }) => !!searchTerm);
export const selectWorkflowLibraryOrderBy = createWorkflowLibrarySelector(({ orderBy }) => orderBy);
export const selectWorkflowLibraryDirection = createWorkflowLibrarySelector(({ direction }) => direction);
export const selectWorkflowLibrarySelectedTags = createWorkflowLibrarySelector(({ selectedTags }) => selectedTags);
export const selectWorkflowLibraryView = createWorkflowLibrarySelector(({ view }) => view);
export const DEFAULT_WORKFLOW_LIBRARY_CATEGORIES = ['user', 'default'] satisfies WorkflowCategory[];
export const $workflowLibraryCategoriesOptions = atom<WorkflowCategory[]>(DEFAULT_WORKFLOW_LIBRARY_CATEGORIES);
export type WorkflowTagCategory = { categoryTKey: string; tags: string[] };
export const DEFAULT_WORKFLOW_LIBRARY_TAG_CATEGORIES: WorkflowTagCategory[] = [
{ categoryTKey: 'Industry', tags: ['Architecture', 'Fashion', 'Game Dev', 'Food'] },
{ categoryTKey: 'Common Tasks', tags: ['Upscaling', 'Text to Image', 'Image to Image'] },
{ categoryTKey: 'Model Architecture', tags: ['SD1.5', 'SDXL', 'Bria', 'FLUX'] },
{ categoryTKey: 'Tech Showcase', tags: ['Control', 'Reference Image'] },
];
export const $workflowLibraryTagCategoriesOptions = atom<WorkflowTagCategory[]>(
DEFAULT_WORKFLOW_LIBRARY_TAG_CATEGORIES
);
export const $workflowLibraryTagOptions = computed($workflowLibraryTagCategoriesOptions, (tagCategories) =>
tagCategories.flatMap(({ tags }) => tags)
);

View File

@@ -11,12 +11,7 @@ import {
} from 'features/nodes/components/sidePanel/builder/form-manipulation';
import { workflowLoaded } from 'features/nodes/store/actions';
import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesChanged } from 'features/nodes/store/nodesSlice';
import type {
NodesState,
WorkflowMode,
WorkflowsState as WorkflowState,
WorkflowTag,
} from 'features/nodes/store/types';
import type { NodesState, WorkflowMode, WorkflowsState as WorkflowState } from 'features/nodes/store/types';
import type { FieldIdentifier, StatefulFieldValue } from 'features/nodes/types/field';
import { isInvocationNode } from 'features/nodes/types/invocation';
import type {
@@ -40,7 +35,6 @@ import {
} from 'features/nodes/types/workflow';
import { isEqual } from 'lodash-es';
import { useMemo } from 'react';
import type { SQLiteDirection, WorkflowRecordOrderBy } from 'services/api/types';
import { selectNodesSlice } from './selectors';
@@ -83,11 +77,6 @@ const initialWorkflowState: WorkflowState = {
isTouched: false,
mode: 'view',
formFieldInitialValues: {},
searchTerm: '',
orderBy: 'opened_at', // initial value is decided in component
orderDirection: 'DESC',
selectedTags: [],
selectedCategories: ['user'],
...getBlankWorkflow(),
};
@@ -98,19 +87,6 @@ export const workflowSlice = createSlice({
workflowModeChanged: (state, action: PayloadAction<WorkflowMode>) => {
state.mode = action.payload;
},
workflowSearchTermChanged: (state, action: PayloadAction<string>) => {
state.searchTerm = action.payload;
},
workflowOrderByChanged: (state, action: PayloadAction<WorkflowRecordOrderBy>) => {
state.orderBy = action.payload;
},
workflowOrderDirectionChanged: (state, action: PayloadAction<SQLiteDirection>) => {
state.orderDirection = action.payload;
},
workflowSelectedCategoriesChanged: (state, action: PayloadAction<WorkflowCategory[]>) => {
state.selectedCategories = action.payload;
state.searchTerm = '';
},
workflowNameChanged: (state, action: PayloadAction<string>) => {
state.name = action.payload;
state.isTouched = true;
@@ -150,24 +126,13 @@ export const workflowSlice = createSlice({
workflowSaved: (state) => {
state.isTouched = false;
},
workflowSelectedTagToggled: (state, action: PayloadAction<WorkflowTag>) => {
const tag = action.payload;
const tags = state.selectedTags;
if (tags.includes(tag)) {
state.selectedTags = tags.filter((t) => t !== tag);
} else {
state.selectedTags = [...tags, tag];
}
},
workflowSelectedTagsRese: (state) => {
state.selectedTags = [];
},
formReset: (state) => {
const rootElement = buildContainer('column', []);
state.form = {
elements: { [rootElement.id]: rootElement },
rootElementId: rootElement.id,
};
state.isTouched = true;
},
formElementAdded: (
state,
@@ -184,29 +149,36 @@ export const workflowSlice = createSlice({
if (isNodeFieldElement(element)) {
state.formFieldInitialValues[element.id] = initialValue;
}
state.isTouched = true;
},
formElementRemoved: (state, action: PayloadAction<{ id: string }>) => {
const { form } = state;
const { id } = action.payload;
removeElement({ form, id });
delete state.formFieldInitialValues[id];
state.isTouched = true;
},
formElementReparented: (state, action: PayloadAction<{ id: string; newParentId: string; index: number }>) => {
const { form } = state;
const { id, newParentId, index } = action.payload;
reparentElement({ form, id, newParentId, index });
state.isTouched = true;
},
formElementHeadingDataChanged: (state, action: FormElementDataChangedAction<HeadingElement>) => {
formElementDataChangedReducer(state, action, isHeadingElement);
state.isTouched = true;
},
formElementTextDataChanged: (state, action: FormElementDataChangedAction<TextElement>) => {
formElementDataChangedReducer(state, action, isTextElement);
state.isTouched = true;
},
formElementNodeFieldDataChanged: (state, action: FormElementDataChangedAction<NodeFieldElement>) => {
formElementDataChangedReducer(state, action, isNodeFieldElement);
state.isTouched = true;
},
formElementContainerDataChanged: (state, action: FormElementDataChangedAction<ContainerElement>) => {
formElementDataChangedReducer(state, action, isContainerElement);
state.isTouched = true;
},
formFieldInitialValuesChanged: (
state,
@@ -214,6 +186,7 @@ export const workflowSlice = createSlice({
) => {
const { formFieldInitialValues } = action.payload;
state.formFieldInitialValues = formFieldInitialValues;
state.isTouched = true;
},
},
extraReducers: (builder) => {
@@ -314,12 +287,6 @@ export const {
workflowContactChanged,
workflowIDChanged,
workflowSaved,
workflowSearchTermChanged,
workflowOrderByChanged,
workflowOrderDirectionChanged,
workflowSelectedCategoriesChanged,
workflowSelectedTagToggled,
workflowSelectedTagsRese,
formReset,
formElementAdded,
formElementRemoved,
@@ -382,12 +349,7 @@ export const selectWorkflowName = createWorkflowSelector((workflow) => workflow.
export const selectWorkflowId = createWorkflowSelector((workflow) => workflow.id);
export const selectWorkflowMode = createWorkflowSelector((workflow) => workflow.mode);
export const selectWorkflowIsTouched = createWorkflowSelector((workflow) => workflow.isTouched);
export const selectWorkflowSearchTerm = createWorkflowSelector((workflow) => workflow.searchTerm);
export const selectWorkflowOrderBy = createWorkflowSelector((workflow) => workflow.orderBy);
export const selectWorkflowOrderDirection = createWorkflowSelector((workflow) => workflow.orderDirection);
export const selectWorkflowSelectedCategories = createWorkflowSelector((workflow) => workflow.selectedCategories);
export const selectWorkflowDescription = createWorkflowSelector((workflow) => workflow.description);
export const selectWorkflowLibrarySelectedTags = createWorkflowSelector((workflow) => workflow.selectedTags);
export const selectWorkflowForm = createWorkflowSelector((workflow) => workflow.form);
export const selectCleanEditor = createSelector([selectNodesSlice, selectWorkflowSlice], (nodes, workflow) => {

View File

@@ -0,0 +1,55 @@
import type { CanvasReferenceImageState, FLUXReduxConfig } from 'features/controlLayers/store/types';
import { isFLUXReduxConfig } from 'features/controlLayers/store/types';
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import type { Invocation } from 'services/api/types';
import { assert } from 'tsafe';
type AddFLUXReduxResult = {
addedFLUXReduxes: number;
};
type AddFLUXReduxArg = {
entities: CanvasReferenceImageState[];
g: Graph;
collector: Invocation<'collect'>;
model: ParameterModel;
};
export const addFLUXReduxes = ({ entities, g, collector, model }: AddFLUXReduxArg): AddFLUXReduxResult => {
const validFLUXReduxes = entities
.filter((entity) => entity.isEnabled)
.filter((entity) => isFLUXReduxConfig(entity.ipAdapter))
.filter((entity) => getGlobalReferenceImageWarnings(entity, model).length === 0);
const result: AddFLUXReduxResult = {
addedFLUXReduxes: 0,
};
for (const { id, ipAdapter } of validFLUXReduxes) {
assert(isFLUXReduxConfig(ipAdapter), 'This should have been filtered out');
result.addedFLUXReduxes++;
addFLUXRedux(id, ipAdapter, g, collector);
}
return result;
};
const addFLUXRedux = (id: string, ipAdapter: FLUXReduxConfig, g: Graph, collector: Invocation<'collect'>) => {
const { model: fluxReduxModel, image } = ipAdapter;
assert(image, 'FLUX Redux image is required');
assert(fluxReduxModel, 'FLUX Redux model is required');
const node = g.addNode({
id: `flux_redux_${id}`,
type: 'flux_redux',
redux_model: fluxReduxModel,
image: {
image_name: image.image_name,
},
});
g.addEdge(node, 'redux_cond', collector, 'item');
};

View File

@@ -1,4 +1,8 @@
import type { CanvasReferenceImageState } from 'features/controlLayers/store/types';
import {
type CanvasReferenceImageState,
type IPAdapterConfig,
isIPAdapterConfig,
} from 'features/controlLayers/store/types';
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
@@ -19,23 +23,24 @@ type AddIPAdaptersArg = {
export const addIPAdapters = ({ entities, g, collector, model }: AddIPAdaptersArg): AddIPAdaptersResult => {
const validIPAdapters = entities
.filter((entity) => entity.isEnabled)
.filter((entity) => isIPAdapterConfig(entity.ipAdapter))
.filter((entity) => getGlobalReferenceImageWarnings(entity, model).length === 0);
const result: AddIPAdaptersResult = {
addedIPAdapters: 0,
};
for (const ipa of validIPAdapters) {
for (const { id, ipAdapter } of validIPAdapters) {
assert(isIPAdapterConfig(ipAdapter), 'This should have been filtered out');
result.addedIPAdapters++;
addIPAdapter(ipa, g, collector);
addIPAdapter(id, ipAdapter, g, collector);
}
return result;
};
const addIPAdapter = (entity: CanvasReferenceImageState, g: Graph, collector: Invocation<'collect'>) => {
const { id, ipAdapter } = entity;
const addIPAdapter = (id: string, ipAdapter: IPAdapterConfig, g: Graph, collector: Invocation<'collect'>) => {
const { weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
assert(image, 'IP Adapter image is required');
assert(model, 'IP Adapter model is required');

View File

@@ -18,6 +18,7 @@ type AddedRegionResult = {
addedNegativePrompt: boolean;
addedAutoNegativePositivePrompt: boolean;
addedIPAdapters: number;
addedFLUXReduxes: number;
};
type AddRegionsArg = {
@@ -31,6 +32,7 @@ type AddRegionsArg = {
posCondCollect: Invocation<'collect'>;
negCondCollect: Invocation<'collect'> | null;
ipAdapterCollect: Invocation<'collect'>;
fluxReduxCollect: Invocation<'collect'> | null;
};
/**
@@ -45,6 +47,7 @@ type AddRegionsArg = {
* @param posCondCollect The positive conditioning collector
* @param negCondCollect The negative conditioning collector
* @param ipAdapterCollect The IP adapter collector
* @param fluxReduxConnect The IP adapter collector
* @returns A promise that resolves to the regions that were successfully added to the graph
*/
@@ -59,6 +62,7 @@ export const addRegions = async ({
posCondCollect,
negCondCollect,
ipAdapterCollect,
fluxReduxCollect,
}: AddRegionsArg): Promise<AddedRegionResult[]> => {
const isSDXL = model.base === 'sdxl';
const isFLUX = model.base === 'flux';
@@ -75,6 +79,7 @@ export const addRegions = async ({
addedNegativePrompt: false,
addedAutoNegativePositivePrompt: false,
addedIPAdapters: 0,
addedFLUXReduxes: 0,
};
const getImageDTOResult = await withResultAsync(() => {
@@ -269,30 +274,52 @@ export const addRegions = async ({
}
for (const { id, ipAdapter } of region.referenceImages) {
assert(!isFLUX, 'Regional IP adapters are not supported for FLUX.');
if (ipAdapter.type === 'ip_adapter') {
assert(!isFLUX, 'Regional IP adapters are not supported for FLUX.');
result.addedIPAdapters++;
const { weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
assert(model, 'IP Adapter model is required');
assert(image, 'IP Adapter image is required');
result.addedIPAdapters++;
const { weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
assert(model, 'IP Adapter model is required');
assert(image, 'IP Adapter image is required');
const ipAdapterNode = g.addNode({
id: `ip_adapter_${id}`,
type: 'ip_adapter',
weight,
method,
ip_adapter_model: model,
clip_vision_model: clipVisionModel,
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
image: {
image_name: image.image_name,
},
});
const ipAdapterNode = g.addNode({
id: `ip_adapter_${id}`,
type: 'ip_adapter',
weight,
method,
ip_adapter_model: model,
clip_vision_model: clipVisionModel,
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
image: {
image_name: image.image_name,
},
});
// Connect the mask to the conditioning
g.addEdge(maskToTensor, 'mask', ipAdapterNode, 'mask');
g.addEdge(ipAdapterNode, 'ip_adapter', ipAdapterCollect, 'item');
// Connect the mask to the conditioning
g.addEdge(maskToTensor, 'mask', ipAdapterNode, 'mask');
g.addEdge(ipAdapterNode, 'ip_adapter', ipAdapterCollect, 'item');
} else if (ipAdapter.type === 'flux_redux') {
assert(isFLUX, 'Regional FLUX Redux requires FLUX.');
assert(fluxReduxCollect !== null, 'FLUX Redux collector is required.');
result.addedFLUXReduxes++;
const { model: fluxReduxModel, image } = ipAdapter;
assert(fluxReduxModel, 'FLUX Redux model is required');
assert(image, 'FLUX Redux image is required');
const fluxReduxNode = g.addNode({
id: `flux_redux_${id}`,
type: 'flux_redux',
redux_model: fluxReduxModel,
image: {
image_name: image.image_name,
},
});
// Connect the mask to the conditioning
g.addEdge(maskToTensor, 'mask', fluxReduxNode, 'mask');
g.addEdge(fluxReduxNode, 'redux_cond', fluxReduxCollect, 'item');
}
}
results.push(result);

View File

@@ -7,6 +7,7 @@ import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { addFLUXLoRAs } from 'features/nodes/util/graph/generation/addFLUXLoRAs';
import { addFLUXReduxes } from 'features/nodes/util/graph/generation/addFLUXRedux';
import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage';
import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint';
import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
@@ -233,6 +234,17 @@ export const buildFLUXGraph = async (
model: modelConfig,
});
const fluxReduxCollect = g.addNode({
type: 'collect',
id: getPrefixedId('ip_adapter_collector'),
});
const fluxReduxResult = addFLUXReduxes({
entities: canvas.referenceImages.entities,
g,
collector: fluxReduxCollect,
model: modelConfig,
});
const regionsResult = await addRegions({
manager,
regions: canvas.regionalGuidance.entities,
@@ -244,6 +256,7 @@ export const buildFLUXGraph = async (
posCondCollect,
negCondCollect: null,
ipAdapterCollect,
fluxReduxCollect,
});
const totalIPAdaptersAdded =
@@ -254,6 +267,16 @@ export const buildFLUXGraph = async (
g.deleteNode(ipAdapterCollect.id);
}
const totalReduxesAdded =
fluxReduxResult.addedFLUXReduxes + regionsResult.reduce((acc, r) => acc + r.addedFLUXReduxes, 0);
if (totalReduxesAdded > 0) {
g.addEdge(fluxReduxCollect, 'collection', denoise, 'redux_conditioning');
} else {
g.deleteNode(fluxReduxCollect.id);
}
// TODO: Add FLUX Reduxes to denoise node like we do for ipa
if (state.system.shouldUseNSFWChecker) {
canvasOutput = addNSFWChecker(g, canvasOutput);
}

View File

@@ -281,6 +281,7 @@ export const buildSD1Graph = async (
posCondCollect,
negCondCollect,
ipAdapterCollect,
fluxReduxCollect: null,
});
const totalIPAdaptersAdded =

View File

@@ -286,6 +286,7 @@ export const buildSDXLGraph = async (
posCondCollect,
negCondCollect,
ipAdapterCollect,
fluxReduxCollect: null,
});
const totalIPAdaptersAdded =

View File

@@ -1,75 +1,148 @@
import { ConfirmationAlertDialog, Flex, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAppSelector } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { useWorkflowLibraryModal } from 'features/nodes/store/workflowLibraryModal';
import { selectWorkflowIsTouched, workflowModeChanged } from 'features/nodes/store/workflowSlice';
import { useGetAndLoadLibraryWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadLibraryWorkflow';
import { selectWorkflowIsTouched } from 'features/nodes/store/workflowSlice';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { useLoadWorkflowFromFile } from 'features/workflowLibrary/hooks/useLoadWorkflowFromFile';
import { useLoadWorkflowFromImage } from 'features/workflowLibrary/hooks/useLoadWorkflowFromImage';
import { useLoadWorkflowFromLibrary } from 'features/workflowLibrary/hooks/useLoadWorkflowFromLibrary';
import { useLoadWorkflowFromObject } from 'features/workflowLibrary/hooks/useLoadWorkflowFromObject';
import { atom } from 'nanostores';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const $workflowToLoad = atom<{ workflowId: string; mode: 'view' | 'edit'; isOpen: boolean } | null>(null);
const cleanup = () => $workflowToLoad.set(null);
type Callbacks = {
onSuccess?: (workflow: WorkflowV3) => void;
onError?: () => void;
onCompleted?: () => void;
};
export const useLoadWorkflow = () => {
const dispatch = useAppDispatch();
type LoadLibraryWorkflowData = Callbacks & {
type: 'library';
data: string;
};
type LoadWorkflowFromObjectData = Callbacks & {
type: 'object';
data: unknown;
};
type LoadWorkflowFromFileData = Callbacks & {
type: 'file';
data: File;
};
type LoadWorkflowFromImageData = Callbacks & {
type: 'image';
data: string;
};
type DialogStateExtra = {
isOpen: boolean;
};
const $dialogState = atom<
| (LoadLibraryWorkflowData & DialogStateExtra)
| (LoadWorkflowFromObjectData & DialogStateExtra)
| (LoadWorkflowFromFileData & DialogStateExtra)
| (LoadWorkflowFromImageData & DialogStateExtra)
| null
>(null);
const cleanup = () => $dialogState.set(null);
const useLoadImmediate = () => {
const workflowLibraryModal = useWorkflowLibraryModal();
const { getAndLoadWorkflow } = useGetAndLoadLibraryWorkflow();
const isTouched = useAppSelector(selectWorkflowIsTouched);
const loadWorkflowFromLibrary = useLoadWorkflowFromLibrary();
const loadWorkflowFromFile = useLoadWorkflowFromFile();
const loadWorkflowFromImage = useLoadWorkflowFromImage();
const loadWorkflowFromObject = useLoadWorkflowFromObject();
const loadImmediate = useCallback(async () => {
const workflow = $workflowToLoad.get();
if (!workflow) {
const dialogState = $dialogState.get();
if (!dialogState) {
return;
}
const { workflowId, mode } = workflow;
await getAndLoadWorkflow(workflowId);
dispatch(workflowModeChanged(mode));
const { type, data, onSuccess, onError, onCompleted } = dialogState;
const options = {
onSuccess,
onError,
onCompleted,
};
if (type === 'object') {
await loadWorkflowFromObject(data, options);
} else if (type === 'file') {
await loadWorkflowFromFile(data, options);
} else if (type === 'library') {
await loadWorkflowFromLibrary(data, options);
} else if (type === 'image') {
await loadWorkflowFromImage(data, options);
}
cleanup();
workflowLibraryModal.close();
}, [dispatch, getAndLoadWorkflow, workflowLibraryModal]);
}, [
loadWorkflowFromFile,
loadWorkflowFromImage,
loadWorkflowFromLibrary,
loadWorkflowFromObject,
workflowLibraryModal,
]);
const loadWithDialog = useCallback(
(workflowId: string, mode: 'view' | 'edit') => {
return loadImmediate;
};
/**
* Handles loading workflows from various sources. If there are unsaved changes, the user will be prompted to confirm
* before loading the workflow.
*/
export const useLoadWorkflowWithDialog = () => {
const isTouched = useAppSelector(selectWorkflowIsTouched);
const loadImmediate = useLoadImmediate();
const loadWorkflowWithDialog = useCallback(
/**
* Loads a workflow from various sources. If there are unsaved changes, the user will be prompted to confirm before
* loading the workflow. The workflow will be loaded immediately if there are no unsaved changes. On success, error
* or completion, the corresponding callback will be called.
*
* @param data - The data to load the workflow from.
* @param data.type - The type of data to load the workflow from.
* @param data.data - The data to load the workflow from. The type of this data depends on the `type` field.
* @param data.onSuccess - A callback to call when the workflow is successfully loaded.
* @param data.onError - A callback to call when an error occurs while loading the workflow.
* @param data.onCompleted - A callback to call when the loading process is completed (both success and error).
*/
(
data: LoadLibraryWorkflowData | LoadWorkflowFromObjectData | LoadWorkflowFromFileData | LoadWorkflowFromImageData
) => {
if (!isTouched) {
$workflowToLoad.set({
workflowId,
mode,
isOpen: false,
});
$dialogState.set({ ...data, isOpen: false });
loadImmediate();
} else {
$workflowToLoad.set({
workflowId,
mode,
isOpen: true,
});
$dialogState.set({ ...data, isOpen: true });
}
},
[loadImmediate, isTouched]
);
return {
loadImmediate,
loadWithDialog,
} as const;
return loadWorkflowWithDialog;
};
export const LoadWorkflowConfirmationAlertDialog = memo(() => {
useAssertSingleton('LoadWorkflowConfirmationAlertDialog');
const { t } = useTranslation();
const workflow = useStore($workflowToLoad);
const loadWorkflow = useLoadWorkflow();
const workflow = useStore($dialogState);
const loadImmediate = useLoadImmediate();
return (
<ConfirmationAlertDialog
isOpen={!!workflow?.isOpen}
onClose={cleanup}
title={t('nodes.loadWorkflow')}
acceptCallback={loadWorkflow.loadImmediate}
acceptCallback={loadImmediate}
useInert={false}
acceptButtonText={t('common.load')}
>
<Flex flexDir="column" gap={2}>
<Text>{t('nodes.loadWorkflowDesc')}</Text>

View File

@@ -15,7 +15,7 @@ import {
} from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { graphToWorkflow } from 'features/nodes/util/workflow/graphToWorkflow';
import { useLoadWorkflow } from 'features/workflowLibrary/hooks/useLoadWorkflow';
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
import { atom } from 'nanostores';
import type { ChangeEvent } from 'react';
import { useCallback, useState } from 'react';
@@ -37,16 +37,17 @@ export const useLoadWorkflowFromGraphModal = () => {
export const LoadWorkflowFromGraphModal = () => {
const { t } = useTranslation();
const _loadWorkflow = useLoadWorkflow();
const { isOpen, onClose } = useLoadWorkflowFromGraphModal();
const loadWorkflowWithDialog = useLoadWorkflowWithDialog();
const [graphRaw, setGraphRaw] = useState<string>('');
const [workflowRaw, setWorkflowRaw] = useState<string>('');
const [unvalidatedWorkflow, setUnvalidatedWorkflow] = useState<unknown>();
const [unvalidatedWorkflowAsString, setUnvalidatedWorkflowAsString] = useState<string>('');
const [shouldAutoLayout, setShouldAutoLayout] = useState(true);
const onChangeGraphRaw = useCallback((e: ChangeEvent<HTMLTextAreaElement>) => {
setGraphRaw(e.target.value);
}, []);
const onChangeWorkflowRaw = useCallback((e: ChangeEvent<HTMLTextAreaElement>) => {
setWorkflowRaw(e.target.value);
setUnvalidatedWorkflow(e.target.value);
}, []);
const onChangeShouldAutoLayout = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setShouldAutoLayout(e.target.checked);
@@ -54,12 +55,13 @@ export const LoadWorkflowFromGraphModal = () => {
const parse = useCallback(() => {
const graph = JSON.parse(graphRaw);
const workflow = graphToWorkflow(graph, shouldAutoLayout);
setWorkflowRaw(JSON.stringify(workflow, null, 2));
setUnvalidatedWorkflow(workflow);
setUnvalidatedWorkflowAsString(JSON.stringify(workflow, null, 2));
}, [graphRaw, shouldAutoLayout]);
const loadWorkflow = useCallback(() => {
_loadWorkflow({ workflow: workflowRaw, graph: null });
const loadWorkflow = useCallback(async () => {
await loadWorkflowWithDialog({ type: 'object', data: unvalidatedWorkflow });
onClose();
}, [_loadWorkflow, onClose, workflowRaw]);
}, [loadWorkflowWithDialog, unvalidatedWorkflow, onClose]);
return (
<Modal isOpen={isOpen} onClose={onClose} isCentered useInert={false}>
<ModalOverlay />
@@ -95,7 +97,7 @@ export const LoadWorkflowFromGraphModal = () => {
<FormLabel>{t('nodes.workflow')}</FormLabel>
<Textarea
h="full"
value={workflowRaw}
value={unvalidatedWorkflowAsString}
fontFamily="monospace"
whiteSpace="pre-wrap"
overflowWrap="normal"

View File

@@ -20,7 +20,7 @@ export const NewWorkflowButton = memo(() => {
);
return (
<Button onClick={onClickNewWorkflow} variant="ghost" leftIcon={<PiFilePlusBold />}>
<Button onClick={onClickNewWorkflow} variant="ghost" leftIcon={<PiFilePlusBold />} justifyContent="flex-start">
{t('nodes.newWorkflow')}
</Button>
);

View File

@@ -59,6 +59,7 @@ export const NewWorkflowConfirmationAlertDialog = memo(() => {
title={t('nodes.newWorkflow')}
acceptCallback={newWorkflow.createImmediate}
useInert={false}
acceptButtonText={t('common.load')}
>
<Flex flexDir="column" gap={2}>
<Text>{t('nodes.newWorkflowDesc')}</Text>

View File

@@ -12,9 +12,9 @@ import {
Input,
} from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { $workflowCategories } from 'app/store/nanostores/workflowCategories';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { deepClone } from 'common/util/deepClone';
import { $workflowLibraryCategoriesOptions } from 'features/nodes/store/workflowLibrarySlice';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { isDraftWorkflow, useCreateLibraryWorkflow } from 'features/workflowLibrary/hooks/useCreateNewWorkflow';
import { t } from 'i18next';
@@ -83,7 +83,7 @@ export const SaveWorkflowAsDialog = () => {
};
const Content = memo(({ workflow, cancelRef }: { workflow: WorkflowV3; cancelRef: RefObject<HTMLButtonElement> }) => {
const workflowCategories = useStore($workflowCategories);
const workflowCategories = useStore($workflowLibraryCategoriesOptions);
const [name, setName] = useState(() => {
if (workflow) {
return getInitialName(workflow);

View File

@@ -1,33 +1,25 @@
import { Button } from '@invoke-ai/ui-library';
import { useWorkflowLibraryModal } from 'features/nodes/store/workflowLibraryModal';
import { saveWorkflowAs } from 'features/workflowLibrary/components/SaveWorkflowAsDialog';
import { useLoadWorkflowFromFile } from 'features/workflowLibrary/hooks/useLoadWorkflowFromFile';
import { memo, useCallback, useRef } from 'react';
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
import { memo, useCallback } from 'react';
import { useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { PiUploadSimpleBold } from 'react-icons/pi';
export const UploadWorkflowButton = memo(() => {
const { t } = useTranslation();
const resetRef = useRef<() => void>(null);
const workflowLibraryModal = useWorkflowLibraryModal();
const loadWorkflowFromFile = useLoadWorkflowFromFile({
resetRef,
onSuccess: (workflow) => {
workflowLibraryModal.close();
saveWorkflowAs(workflow);
},
});
const loadWorkflowWithDialog = useLoadWorkflowWithDialog();
const onDropAccepted = useCallback(
(files: File[]) => {
if (!files[0]) {
([file]: File[]) => {
if (!file) {
return;
}
loadWorkflowFromFile(files[0]);
loadWorkflowWithDialog({
type: 'file',
data: file,
});
},
[loadWorkflowFromFile]
[loadWorkflowWithDialog]
);
const { getInputProps, getRootProps } = useDropzone({
@@ -36,9 +28,16 @@ export const UploadWorkflowButton = memo(() => {
noDrag: true,
multiple: false,
});
return (
<>
<Button leftIcon={<PiUploadSimpleBold />} {...getRootProps()} pointerEvents="auto" variant="ghost">
<Button
leftIcon={<PiUploadSimpleBold />}
{...getRootProps()}
pointerEvents="auto"
variant="ghost"
justifyContent="flex-start"
>
{t('workflows.uploadWorkflow')}
</Button>

View File

@@ -1,23 +1,25 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useLoadWorkflowFromFile } from 'features/workflowLibrary/hooks/useLoadWorkflowFromFile';
import { memo, useCallback, useRef } from 'react';
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
import { memo, useCallback } from 'react';
import { useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { PiUploadSimpleBold } from 'react-icons/pi';
const UploadWorkflowMenuItem = () => {
const { t } = useTranslation();
const resetRef = useRef<() => void>(null);
const loadWorkflowFromFile = useLoadWorkflowFromFile({ resetRef });
const loadWorkflowWithDialog = useLoadWorkflowWithDialog();
const onDropAccepted = useCallback(
(files: File[]) => {
if (!files[0]) {
([file]: File[]) => {
if (!file) {
return;
}
loadWorkflowFromFile(files[0]);
loadWorkflowWithDialog({
type: 'file',
data: file,
});
},
[loadWorkflowFromFile]
[loadWorkflowWithDialog]
);
const { getRootProps, getInputProps } = useDropzone({
@@ -26,6 +28,7 @@ const UploadWorkflowMenuItem = () => {
noDrag: true,
multiple: false,
});
return (
<MenuItem as="button" icon={<PiUploadSimpleBold />} {...getRootProps()}>
{t('workflows.uploadWorkflow')}

View File

@@ -13,7 +13,7 @@ import { useGetFormFieldInitialValues } from 'features/workflowLibrary/hooks/use
import { newWorkflowSaved } from 'features/workflowLibrary/store/actions';
import { useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { useCreateWorkflowMutation, workflowsApi } from 'services/api/endpoints/workflows';
import { useCreateWorkflowMutation, useUpdateOpenedAtMutation, workflowsApi } from 'services/api/endpoints/workflows';
import type { SetFieldType } from 'type-fest';
/**
@@ -44,6 +44,7 @@ export const useCreateLibraryWorkflow = (): CreateLibraryWorkflowReturn => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const [createWorkflow, { isLoading, isError }] = useCreateWorkflowMutation();
const [updateOpenedAt] = useUpdateOpenedAtMutation();
const getFormFieldInitialValues = useGetFormFieldInitialValues();
const toast = useToast();
@@ -66,11 +67,11 @@ export const useCreateLibraryWorkflow = (): CreateLibraryWorkflowReturn => {
dispatch(workflowIDChanged(id));
dispatch(workflowNameChanged(name));
dispatch(workflowCategoryChanged(category));
dispatch(workflowSaved());
dispatch(newWorkflowSaved({ category }));
// When a workflow is saved, the form field initial values are updated to the current form field values
dispatch(formFieldInitialValuesChanged({ formFieldInitialValues: getFormFieldInitialValues() }));
updateOpenedAt({ workflow_id: id });
dispatch(workflowSaved());
onSuccess && onSuccess();
toast.update(toastRef.current, {
title: t('workflows.workflowSaved'),
@@ -92,7 +93,7 @@ export const useCreateLibraryWorkflow = (): CreateLibraryWorkflowReturn => {
}
}
},
[toast, t, createWorkflow, dispatch, getFormFieldInitialValues]
[toast, t, createWorkflow, dispatch, getFormFieldInitialValues, updateOpenedAt]
);
return {
createNewWorkflow,

View File

@@ -1,44 +0,0 @@
import { toast } from 'features/toast/toast';
import { useLoadWorkflow } from 'features/workflowLibrary/hooks/useLoadWorkflow';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useLazyGetImageWorkflowQuery } from 'services/api/endpoints/images';
type UseGetAndLoadEmbeddedWorkflowOptions = {
onSuccess?: () => void;
onError?: () => void;
};
export const useGetAndLoadEmbeddedWorkflow = (options?: UseGetAndLoadEmbeddedWorkflowOptions) => {
const { t } = useTranslation();
const [_getAndLoadEmbeddedWorkflow, result] = useLazyGetImageWorkflowQuery();
const loadWorkflow = useLoadWorkflow();
const getAndLoadEmbeddedWorkflow = useCallback(
async (imageName: string) => {
try {
const { data } = await _getAndLoadEmbeddedWorkflow(imageName);
if (data) {
loadWorkflow(data);
// No toast - the listener for this action does that after the workflow is loaded
options?.onSuccess && options?.onSuccess();
} else {
toast({
id: 'PROBLEM_RETRIEVING_WORKFLOW',
title: t('toast.problemRetrievingWorkflow'),
status: 'error',
});
}
} catch {
toast({
id: 'PROBLEM_RETRIEVING_WORKFLOW',
title: t('toast.problemRetrievingWorkflow'),
status: 'error',
});
options?.onError && options?.onError();
}
},
[_getAndLoadEmbeddedWorkflow, loadWorkflow, options, t]
);
return [getAndLoadEmbeddedWorkflow, result] as const;
};

View File

@@ -1,47 +0,0 @@
import { useToast } from '@invoke-ai/ui-library';
import { useLoadWorkflow } from 'features/workflowLibrary/hooks/useLoadWorkflow';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useLazyGetWorkflowQuery, useUpdateOpenedAtMutation, workflowsApi } from 'services/api/endpoints/workflows';
type UseGetAndLoadLibraryWorkflowOptions = {
onSuccess?: () => void;
onError?: () => void;
};
type UseGetAndLoadLibraryWorkflowReturn = {
getAndLoadWorkflow: (workflow_id: string) => Promise<void>;
getAndLoadWorkflowResult: ReturnType<typeof useLazyGetWorkflowQuery>[1];
};
type UseGetAndLoadLibraryWorkflow = (arg?: UseGetAndLoadLibraryWorkflowOptions) => UseGetAndLoadLibraryWorkflowReturn;
export const useGetAndLoadLibraryWorkflow: UseGetAndLoadLibraryWorkflow = (arg) => {
const toast = useToast();
const { t } = useTranslation();
const loadWorkflow = useLoadWorkflow();
const [getWorkflow, getAndLoadWorkflowResult] = useLazyGetWorkflowQuery();
const [updateOpenedAt] = useUpdateOpenedAtMutation();
const getAndLoadWorkflow = useCallback(
async (workflow_id: string) => {
try {
const { workflow } = await getWorkflow(workflow_id).unwrap();
// This action expects a stringified workflow, instead of updating the routes and services we will just stringify it here
await loadWorkflow({ workflow: JSON.stringify(workflow), graph: null });
updateOpenedAt({ workflow_id });
// No toast - the listener for this action does that after the workflow is loaded
arg?.onSuccess && arg.onSuccess();
} catch {
toast({
id: `AUTH_ERROR_TOAST_${workflowsApi.endpoints.getWorkflow.name}`,
title: t('toast.problemRetrievingWorkflow'),
status: 'error',
});
arg?.onError && arg.onError();
}
},
[getWorkflow, loadWorkflow, updateOpenedAt, arg, toast, t]
);
return { getAndLoadWorkflow, getAndLoadWorkflowResult };
};

View File

@@ -1,46 +1,57 @@
import { useAppDispatch } from 'app/store/storeHooks';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { useLoadWorkflow } from 'features/workflowLibrary/hooks/useLoadWorkflow';
import { useValidateAndLoadWorkflow } from 'features/workflowLibrary/hooks/useValidateAndLoadWorkflow';
import { workflowLoadedFromFile } from 'features/workflowLibrary/store/actions';
import type { RefObject } from 'react';
import { useCallback } from 'react';
import { assert } from 'tsafe';
type useLoadWorkflowFromFileOptions = {
resetRef: RefObject<() => void>;
onSuccess?: (workflow: WorkflowV3) => void;
};
type UseLoadWorkflowFromFile = (options: useLoadWorkflowFromFileOptions) => (file: File | null) => void;
export const useLoadWorkflowFromFile: UseLoadWorkflowFromFile = ({ resetRef, onSuccess }) => {
/**
* Loads a workflow from a file.
*
* You probably should instead use `useLoadWorkflowWithDialog`, which opens a dialog to prevent loss of unsaved changes
* and handles the loading process.
*/
export const useLoadWorkflowFromFile = () => {
const dispatch = useAppDispatch();
const loadWorkflow = useLoadWorkflow();
const validatedAndLoadWorkflow = useValidateAndLoadWorkflow();
const loadWorkflowFromFile = useCallback(
(file: File | null) => {
if (!file) {
return;
}
const reader = new FileReader();
reader.onload = async () => {
const rawJSON = reader.result;
(
file: File,
options: {
onSuccess?: (workflow: WorkflowV3) => void;
onError?: () => void;
onCompleted?: () => void;
} = {}
) => {
return new Promise<WorkflowV3 | void>((resolve, reject) => {
const reader = new FileReader();
reader.onload = async () => {
const rawJSON = reader.result;
const { onSuccess, onError, onCompleted } = options;
try {
const unvalidatedWorkflow = JSON.parse(rawJSON as string);
const validatedWorkflow = await validatedAndLoadWorkflow(unvalidatedWorkflow);
try {
const workflow = await loadWorkflow({ workflow: String(rawJSON), graph: null });
assert(workflow !== null);
dispatch(workflowLoadedFromFile());
onSuccess && onSuccess(workflow);
} catch (e) {
reader.abort();
}
};
if (!validatedWorkflow) {
reader.abort();
onError?.();
return;
}
dispatch(workflowLoadedFromFile());
onSuccess?.(validatedWorkflow);
resolve(validatedWorkflow);
} catch {
// This is catching the error from the parsing the JSON file
onError?.();
reject();
} finally {
onCompleted?.();
}
};
reader.readAsText(file);
// Reset the file picker internal state so that the same file can be loaded again
resetRef.current?.();
reader.readAsText(file);
});
},
[resetRef, loadWorkflow, dispatch, onSuccess]
[validatedAndLoadWorkflow, dispatch]
);
return loadWorkflowFromFile;

View File

@@ -0,0 +1,70 @@
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { graphToWorkflow } from 'features/nodes/util/workflow/graphToWorkflow';
import { toast } from 'features/toast/toast';
import { useValidateAndLoadWorkflow } from 'features/workflowLibrary/hooks/useValidateAndLoadWorkflow';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useLazyGetImageWorkflowQuery } from 'services/api/endpoints/images';
import type { NonNullableGraph } from 'services/api/types';
import { assert } from 'tsafe';
/**
* Loads a workflow from an image.
*
* You probably should instead use `useLoadWorkflowWithDialog`, which opens a dialog to prevent loss of unsaved changes
* and handles the loading process.
*/
export const useLoadWorkflowFromImage = () => {
const { t } = useTranslation();
const [getWorkflowAndGraphFromImage] = useLazyGetImageWorkflowQuery();
const validateAndLoadWorkflow = useValidateAndLoadWorkflow();
const loadWorkflowFromImage = useCallback(
async (
imageName: string,
options: {
onSuccess?: (workflow: WorkflowV3) => void;
onError?: () => void;
onCompleted?: () => void;
} = {}
) => {
const { onSuccess, onError, onCompleted } = options;
try {
const { workflow, graph } = await getWorkflowAndGraphFromImage(imageName).unwrap();
// Images may have a workflow and/or a graph. We can load either into the workflow editor, but we prefer the
// workflow.
const unvalidatedWorkflow = workflow
? JSON.parse(workflow)
: graph
? graphToWorkflow(JSON.parse(graph) as NonNullableGraph, true)
: null;
assert(unvalidatedWorkflow !== null, 'No workflow or graph provided');
const validatedWorkflow = await validateAndLoadWorkflow(unvalidatedWorkflow);
if (!validatedWorkflow) {
onError?.();
return;
}
onSuccess?.(validatedWorkflow);
} catch {
// This is catching:
// - the error from the getWorkflowAndGraphFromImage query
// - the error from parsing the workflow or graph
toast({
id: 'PROBLEM_RETRIEVING_WORKFLOW',
title: t('toast.problemRetrievingWorkflow'),
status: 'error',
});
onError?.();
} finally {
onCompleted?.();
}
},
[getWorkflowAndGraphFromImage, validateAndLoadWorkflow, t]
);
return loadWorkflowFromImage;
};

View File

@@ -0,0 +1,57 @@
import { useToast } from '@invoke-ai/ui-library';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { useValidateAndLoadWorkflow } from 'features/workflowLibrary/hooks/useValidateAndLoadWorkflow';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useLazyGetWorkflowQuery, useUpdateOpenedAtMutation, workflowsApi } from 'services/api/endpoints/workflows';
/**
* Loads a workflow from the library.
*
* You probably should instead use `useLoadWorkflowWithDialog`, which opens a dialog to prevent loss of unsaved changes
* and handles the loading process.
*/
export const useLoadWorkflowFromLibrary = () => {
const toast = useToast();
const { t } = useTranslation();
const validateAndLoadWorkflow = useValidateAndLoadWorkflow();
const [getWorkflow] = useLazyGetWorkflowQuery();
const [updateOpenedAt] = useUpdateOpenedAtMutation();
const loadWorkflowFromLibrary = useCallback(
async (
workflowId: string,
options: {
onSuccess?: (workflow: WorkflowV3) => void;
onError?: () => void;
onCompleted?: () => void;
} = {}
) => {
const { onSuccess, onError, onCompleted } = options;
try {
const res = await getWorkflow(workflowId).unwrap();
const validatedWorkflow = await validateAndLoadWorkflow(res.workflow);
if (!validatedWorkflow) {
onError?.();
return;
}
updateOpenedAt({ workflow_id: workflowId });
onSuccess?.(validatedWorkflow);
} catch {
// This is catching the error from the getWorkflow query
toast({
id: `AUTH_ERROR_TOAST_${workflowsApi.endpoints.getWorkflow.name}`,
title: t('toast.problemRetrievingWorkflow'),
status: 'error',
});
onError?.();
} finally {
onCompleted?.();
}
},
[getWorkflow, validateAndLoadWorkflow, updateOpenedAt, toast, t]
);
return loadWorkflowFromLibrary;
};

View File

@@ -0,0 +1,39 @@
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { useValidateAndLoadWorkflow } from 'features/workflowLibrary/hooks/useValidateAndLoadWorkflow';
import { useCallback } from 'react';
/**
* Loads a workflow from an object.
*
* You probably should instead use `useLoadWorkflowWithDialog`, which opens a dialog to prevent loss of unsaved changes
* and handles the loading process.
*/
export const useLoadWorkflowFromObject = () => {
const validateAndLoadWorkflow = useValidateAndLoadWorkflow();
const loadWorkflowFromObject = useCallback(
async (
unvalidatedWorkflow: unknown,
options: {
onSuccess?: (workflow: WorkflowV3) => void;
onError?: () => void;
onCompleted?: () => void;
} = {}
) => {
const { onSuccess, onError, onCompleted } = options;
try {
const validatedWorkflow = await validateAndLoadWorkflow(unvalidatedWorkflow);
if (!validatedWorkflow) {
onError?.();
return;
}
onSuccess?.(validatedWorkflow);
} finally {
onCompleted?.();
}
},
[validateAndLoadWorkflow]
);
return loadWorkflowFromObject;
};

View File

@@ -4,51 +4,55 @@ import { $nodeExecutionStates } from 'features/nodes/hooks/useNodeExecutionState
import { workflowLoaded } from 'features/nodes/store/actions';
import { $templates } from 'features/nodes/store/nodesSlice';
import { $needsFit } from 'features/nodes/store/reactFlowInstance';
import type { Templates } from 'features/nodes/store/types';
import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { graphToWorkflow } from 'features/nodes/util/workflow/graphToWorkflow';
import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { serializeError } from 'serialize-error';
import { checkBoardAccess, checkImageAccess, checkModelAccess } from 'services/api/hooks/accessChecks';
import type { GraphAndWorkflowResponse, NonNullableGraph } from 'services/api/types';
import { z } from 'zod';
import { fromZodError } from 'zod-validation-error';
const log = logger('workflows');
const getWorkflowFromStringifiedWorkflowOrGraph = async (data: GraphAndWorkflowResponse, templates: Templates) => {
if (data.workflow) {
// Prefer to load the workflow if it's available - it has more information
const parsed = JSON.parse(data.workflow);
return await validateWorkflow({
workflow: parsed,
templates,
checkImageAccess,
checkBoardAccess,
checkModelAccess,
});
} else if (data.graph) {
// Else we fall back on the graph, using the graphToWorkflow function to convert and do layout
const parsed = JSON.parse(data.graph);
const workflow = graphToWorkflow(parsed as NonNullableGraph, true);
return await validateWorkflow({ workflow, templates, checkImageAccess, checkBoardAccess, checkModelAccess });
} else {
throw new Error('No workflow or graph provided');
}
};
export const useLoadWorkflow = () => {
/**
* This hook manages the lower-level workflow validation and loading process.
*
* You probably should instead use `useLoadWorkflowWithDialog`, which opens a dialog to prevent loss of unsaved changes
* and handles the loading process.
*
* Internally, `useLoadWorkflowWithDialog` uses these hooks...
*
* - `useLoadWorkflowFromFile`
* - `useLoadWorkflowFromImage`
* - `useLoadWorkflowFromLibrary`
* - `useLoadWorkflowFromObject`
*
* ...each of which internally uses hook.
*/
export const useValidateAndLoadWorkflow = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const loadWorkflow = useCallback(
async (data: GraphAndWorkflowResponse): Promise<WorkflowV3 | null> => {
const validateAndLoadWorkflow = useCallback(
/**
* Validate and load a workflow into the editor.
*
* The unvalidated workflow should be a JS object. Do not pass a raw JSON string.
*
* This function catches all errors. It toasts and logs on success and error.
*/
async (unvalidatedWorkflow: unknown): Promise<WorkflowV3 | null> => {
try {
const templates = $templates.get();
const { workflow, warnings } = await getWorkflowFromStringifiedWorkflowOrGraph(data, templates);
const { workflow, warnings } = await validateWorkflow({
workflow: unvalidatedWorkflow,
templates,
checkImageAccess,
checkBoardAccess,
checkModelAccess,
});
$nodeExecutionStates.set({});
dispatch(workflowLoaded(workflow));
@@ -119,5 +123,5 @@ export const useLoadWorkflow = () => {
[dispatch, t]
);
return loadWorkflow;
return validateAndLoadWorkflow;
};

View File

@@ -1,6 +1,7 @@
import queryString from 'query-string';
import type { paths } from 'services/api/schema';
import type { ApiTagDescription } from '..';
import { api, buildV1Url, LIST_TAG } from '..';
/**
@@ -19,15 +20,6 @@ export const workflowsApi = api.injectEndpoints({
>({
query: (workflow_id) => buildWorkflowsUrl(`i/${workflow_id}`),
providesTags: (result, error, workflow_id) => [{ type: 'Workflow', id: workflow_id }, 'FetchOnReconnect'],
onQueryStarted: async (arg, api) => {
const { dispatch, queryFulfilled } = api;
try {
await queryFulfilled;
dispatch(workflowsApi.util.invalidateTags([{ type: 'WorkflowsRecent', id: LIST_TAG }]));
} catch {
// no-op
}
},
}),
deleteWorkflow: build.mutation<void, string>({
query: (workflow_id) => ({
@@ -35,9 +27,11 @@ export const workflowsApi = api.injectEndpoints({
method: 'DELETE',
}),
invalidatesTags: (result, error, workflow_id) => [
// Because this may change the order of the list, we need to invalidate the whole list
{ type: 'Workflow', id: LIST_TAG },
{ type: 'Workflow', id: workflow_id },
{ type: 'WorkflowsRecent', id: LIST_TAG },
'WorkflowTagCounts',
'WorkflowCategoryCounts',
],
}),
createWorkflow: build.mutation<
@@ -50,8 +44,10 @@ export const workflowsApi = api.injectEndpoints({
body: { workflow },
}),
invalidatesTags: [
// Because this may change the order of the list, we need to invalidate the whole list
{ type: 'Workflow', id: LIST_TAG },
{ type: 'WorkflowsRecent', id: LIST_TAG },
'WorkflowTagCounts',
'WorkflowCategoryCounts',
],
}),
updateWorkflow: build.mutation<
@@ -64,27 +60,28 @@ export const workflowsApi = api.injectEndpoints({
body: { workflow },
}),
invalidatesTags: (response, error, workflow) => [
{ type: 'WorkflowsRecent', id: LIST_TAG },
{ type: 'Workflow', id: LIST_TAG },
{ type: 'Workflow', id: workflow.id },
'WorkflowTagCounts',
'WorkflowCategoryCounts',
],
}),
listWorkflows: build.query<
paths['/api/v1/workflows/']['get']['responses']['200']['content']['application/json'],
NonNullable<paths['/api/v1/workflows/']['get']['parameters']['query']>
getCountsByTag: build.query<
paths['/api/v1/workflows/counts_by_tag']['get']['responses']['200']['content']['application/json'],
NonNullable<paths['/api/v1/workflows/counts_by_tag']['get']['parameters']['query']>
>({
query: (params) => ({
url: `${buildWorkflowsUrl()}?${queryString.stringify(params, { arrayFormat: 'none' })}`,
url: `${buildWorkflowsUrl('counts_by_tag')}?${queryString.stringify(params, { arrayFormat: 'none' })}`,
}),
providesTags: ['FetchOnReconnect', { type: 'Workflow', id: LIST_TAG }],
providesTags: ['WorkflowTagCounts'],
}),
getCounts: build.query<
paths['/api/v1/workflows/counts']['get']['responses']['200']['content']['application/json'],
NonNullable<paths['/api/v1/workflows/counts']['get']['parameters']['query']>
getCountsByCategory: build.query<
paths['/api/v1/workflows/counts_by_category']['get']['responses']['200']['content']['application/json'],
NonNullable<paths['/api/v1/workflows/counts_by_category']['get']['parameters']['query']>
>({
query: (params) => ({
url: `${buildWorkflowsUrl('counts')}?${queryString.stringify(params, { arrayFormat: 'none' })}`,
url: `${buildWorkflowsUrl('counts_by_category')}?${queryString.stringify(params, { arrayFormat: 'none' })}`,
}),
providesTags: ['WorkflowCategoryCounts'],
}),
listWorkflowsInfinite: build.infiniteQuery<
paths['/api/v1/workflows/']['get']['responses']['200']['content']['application/json'],
@@ -108,13 +105,29 @@ export const workflowsApi = api.injectEndpoints({
return firstPageParam > -1 ? firstPageParam - 1 : undefined;
},
},
providesTags: (result) => {
const tags: ApiTagDescription[] = ['FetchOnReconnect', { type: 'Workflow', id: LIST_TAG }];
if (result) {
tags.push(
...result.pages
.map(({ items }) => items)
.flat()
.map((workflow) => ({ type: 'Workflow', id: workflow.workflow_id }) as const)
);
}
return tags;
},
}),
updateOpenedAt: build.mutation<void, { workflow_id: string }>({
query: ({ workflow_id }) => ({
url: buildWorkflowsUrl(`i/${workflow_id}/opened_at`),
method: 'PUT',
}),
invalidatesTags: (result, error, { workflow_id }) => [{ type: 'Workflow', id: workflow_id }],
invalidatesTags: (result, error, { workflow_id }) => [
{ type: 'Workflow', id: workflow_id },
// Because this may change the order of the list, we need to invalidate the whole list
{ type: 'Workflow', id: LIST_TAG },
],
}),
setWorkflowThumbnail: build.mutation<void, { workflow_id: string; image: File }>({
query: ({ workflow_id, image }) => {
@@ -126,33 +139,27 @@ export const workflowsApi = api.injectEndpoints({
body: formData,
};
},
invalidatesTags: (result, error, { workflow_id }) => [
{ type: 'Workflow', id: workflow_id },
{ type: 'WorkflowsRecent', id: LIST_TAG },
],
invalidatesTags: (result, error, { workflow_id }) => [{ type: 'Workflow', id: workflow_id }],
}),
deleteWorkflowThumbnail: build.mutation<void, string>({
query: (workflow_id) => ({
url: buildWorkflowsUrl(`i/${workflow_id}/thumbnail`),
method: 'DELETE',
}),
invalidatesTags: (result, error, workflow_id) => [
{ type: 'Workflow', id: workflow_id },
{ type: 'WorkflowsRecent', id: LIST_TAG },
],
invalidatesTags: (result, error, workflow_id) => [{ type: 'Workflow', id: workflow_id }],
}),
}),
});
export const {
useUpdateOpenedAtMutation,
useGetCountsQuery,
useGetCountsByTagQuery,
useGetCountsByCategoryQuery,
useLazyGetWorkflowQuery,
useGetWorkflowQuery,
useCreateWorkflowMutation,
useDeleteWorkflowMutation,
useUpdateWorkflowMutation,
useListWorkflowsQuery,
useListWorkflowsInfiniteInfiniteQuery,
useSetWorkflowThumbnailMutation,
useDeleteWorkflowThumbnailMutation,

View File

@@ -39,7 +39,7 @@ const buildModelsHook =
typeGuard: (config: AnyModelConfig, excludeSubmodels?: boolean) => config is T,
excludeSubmodels?: boolean
) =>
() => {
(filter: (config: T) => boolean = () => true) => {
const result = useGetModelConfigsQuery(undefined);
const modelConfigs = useMemo(() => {
if (!result.data) {
@@ -48,8 +48,9 @@ const buildModelsHook =
return modelConfigsAdapterSelectors
.selectAll(result.data)
.filter((config) => typeGuard(config, excludeSubmodels));
}, [result]);
.filter((config) => typeGuard(config, excludeSubmodels))
.filter(filter);
}, [filter, result.data]);
return [modelConfigs, result] as const;
};
@@ -78,6 +79,9 @@ export const useFluxVAEModels = (args?: ModelHookArgs) =>
export const useCLIPVisionModels = buildModelsHook(isCLIPVisionModelConfig);
export const useSigLipModels = buildModelsHook(isSigLipModelConfig);
export const useFluxReduxModels = buildModelsHook(isFluxReduxModelConfig);
export const useIPAdapterOrFLUXReduxModels = buildModelsHook(
(config) => isIPAdapterModelConfig(config) || isFluxReduxModelConfig(config)
);
// const buildModelsSelector =
// <T extends AnyModelConfig>(typeGuard: (config: AnyModelConfig) => config is T): Selector<RootState, T[]> =>

View File

@@ -44,7 +44,8 @@ const tagTypes = [
'LoRAModel',
'SDXLRefinerModel',
'Workflow',
'WorkflowsRecent',
'WorkflowTagCounts',
'WorkflowCategoryCounts',
'StylePreset',
'Schema',
'QueueCountsByDestination',

View File

@@ -1438,7 +1438,7 @@ export type paths = {
patch?: never;
trace?: never;
};
"/api/v1/workflows/counts": {
"/api/v1/workflows/counts_by_tag": {
parameters: {
query?: never;
header?: never;
@@ -1446,10 +1446,30 @@ export type paths = {
cookie?: never;
};
/**
* Get Counts
* @description Gets a the count of workflows that include the specified tags and categories
* Get Counts By Tag
* @description Counts workflows by tag
*/
get: operations["get_counts"];
get: operations["get_counts_by_tag"];
put?: never;
post?: never;
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/v1/workflows/counts_by_category": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
/**
* Counts By Category
* @description Counts workflows by category
*/
get: operations["counts_by_category"];
put?: never;
post?: never;
delete?: never;
@@ -7970,12 +7990,6 @@ export type components = {
* @default null
*/
redux_model?: components["schemas"]["ModelIdentifierField"];
/**
* SigLIP Model
* @description The SigLIP model to use.
* @default null
*/
siglip_model?: components["schemas"]["ModelIdentifierField"];
/**
* type
* @default flux_redux
@@ -21100,7 +21114,7 @@ export type components = {
* Opened At
* @description The opened timestamp of the workflow.
*/
opened_at: string;
opened_at: string | null;
/** @description The workflow. */
workflow: components["schemas"]["Workflow"];
};
@@ -21130,7 +21144,7 @@ export type components = {
* Opened At
* @description The opened timestamp of the workflow.
*/
opened_at: string;
opened_at: string | null;
/**
* Description
* @description The description of the workflow.
@@ -21181,7 +21195,7 @@ export type components = {
* Opened At
* @description The opened timestamp of the workflow.
*/
opened_at: string;
opened_at: string | null;
/** @description The workflow. */
workflow: components["schemas"]["Workflow"];
/**
@@ -24243,6 +24257,8 @@ export interface operations {
tags?: string[] | null;
/** @description The text to query by (matches name and description) */
query?: string | null;
/** @description Whether to include/exclude recent workflows */
has_been_opened?: boolean | null;
};
header?: never;
path?: never;
@@ -24417,13 +24433,15 @@ export interface operations {
};
};
};
get_counts: {
get_counts_by_tag: {
parameters: {
query?: {
/** @description The tags to include */
tags?: string[] | null;
query: {
/** @description The tags to get counts for */
tags: string[];
/** @description The categories to include */
categories?: components["schemas"]["WorkflowCategory"][] | null;
/** @description Whether to include/exclude recent workflows */
has_been_opened?: boolean | null;
};
header?: never;
path?: never;
@@ -24437,7 +24455,45 @@ export interface operations {
[name: string]: unknown;
};
content: {
"application/json": number;
"application/json": {
[key: string]: number;
};
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
counts_by_category: {
parameters: {
query: {
/** @description The categories to include */
categories: components["schemas"]["WorkflowCategory"][];
/** @description Whether to include/exclude recent workflows */
has_been_opened?: boolean | null;
};
header?: never;
path?: never;
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": {
[key: string]: number;
};
};
};
/** @description Validation Error */

View File

@@ -63,7 +63,7 @@ type DiffusersModelConfig = S['MainDiffusersConfig'];
export type CheckpointModelConfig = S['MainCheckpointConfig'];
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
export type SigLipModelConfig = S['SigLIPConfig'];
export type FluxReduxModelConfig = S['FluxReduxConfig'];
export type FLUXReduxModelConfig = S['FluxReduxConfig'];
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
export type AnyModelConfig =
| ControlLoRAModelConfig
@@ -80,7 +80,7 @@ export type AnyModelConfig =
| MainModelConfig
| CLIPVisionDiffusersConfig
| SigLipModelConfig
| FluxReduxModelConfig;
| FLUXReduxModelConfig;
/**
* Checks if a list of submodels contains any that match a given variant or type
@@ -217,7 +217,7 @@ export const isSigLipModelConfig = (config: AnyModelConfig): config is SigLipMod
return config.type === 'siglip';
};
export const isFluxReduxModelConfig = (config: AnyModelConfig): config is FluxReduxModelConfig => {
export const isFluxReduxModelConfig = (config: AnyModelConfig): config is FLUXReduxModelConfig => {
return config.type === 'flux_redux';
};