Compare commits

..

9 Commits

Author SHA1 Message Date
psychedelicious
c1a4376b75 chore: bump version to v6.1.0rc1 2025-07-11 08:20:02 +10:00
psychedelicious
ef4d5d7377 feat(ui): virtualized list for staging area
Make the staging area a virtualized list so it doesn't choke when there
are a large number (i.e. more than a few hundred) of queue items.
2025-07-11 07:50:57 +10:00
Mary Hipp Rogers
6b0dfd8427 dont reset canvas if studio is loaded with canvas destination (#8252)
Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
2025-07-10 09:36:41 -04:00
psychedelicious
471c010217 fix(ui): invalid language crashes app
- Apparently locales must use hyphens instead of underscores. This must
have been a fairly recent change that we didn't catch. It caused i18n to
throw for Brasilian Portuguese and both Simplified and Traditional
Mandarin. Change the locales to use the right strings.
- Move the theme + locale provider inside of the error boundary. This
allows errors with locals to be caught by the error boundary instead of
hard-crashing the app. The error screen is unstyled if this happens but
at least it has the reset button.
- Add a migration for the system slice to fix existing users' language
selections. For example, if the user had an incorrect language setting
of `zh_CN`, it will be changed to the correct `zh-CN`.
2025-07-10 14:27:36 +10:00
psychedelicious
b1193022f7 fix(ui): sometimes images added to gallery show as placeholder only
The range-based fetching logic had a subtle bug - it didn't keep track
of what the _current_ visible range is - only the ranges that the user
last scrolled to.

When an image was added to the gallery, the logic saw that the images
had changed, but thought it had already loaded everything it needed to,
so it didn't load the new image.

The updated logic tracks the current visible range separately from the
accumulated scroll ranges to address this issue.
2025-07-10 14:27:36 +10:00
psychedelicious
2152ca092c fix(ui): workaround for dockview bug that lets you drag tabs in certain ways 2025-07-10 14:27:36 +10:00
psychedelicious
ccc62ba56d perf(ui): revised range-based fetching strategy
When the user scrolls in the gallery, we are alerted of the new range of
visible images. Then we fetch those specific images.

Previously, each change of range triggered a throttled function to fetch
that range. The throttle timeout was 100ms.

Now, each change of range appends that range to a list of ranges and
triggers the throttled fetch. The timeout is increased to 500ms, but to
compensate, each fetch handles all ranges that had been accumulated
since the last fetch.

The result is far fewer network requests, but each of them gets more
images.
2025-07-10 14:27:36 +10:00
psychedelicious
9cf82de8c5 fix(ui): check for absolute value of scroll velocity to handle scrolling up 2025-07-10 14:27:36 +10:00
psychedelicious
aced349152 perf(ui): increase viewport in gallery
This allows us to prefetch more images and reduce how often placeholders
are shown as we fetch more images in the gallery.
2025-07-10 14:27:36 +10:00
21 changed files with 290 additions and 791 deletions

View File

@@ -11,6 +11,7 @@ import { memo, useCallback } from 'react';
import { ErrorBoundary } from 'react-error-boundary';
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
import ThemeLocaleProvider from './ThemeLocaleProvider';
const DEFAULT_CONFIG = {};
interface Props {
@@ -30,12 +31,14 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
return (
<ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
<AppContent />
{!didStudioInit && <Loading />}
</Box>
<GlobalHookIsolator config={config} studioInitAction={studioInitAction} />
<GlobalModalIsolator />
<ThemeLocaleProvider>
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
<AppContent />
{!didStudioInit && <Loading />}
</Box>
<GlobalHookIsolator config={config} studioInitAction={studioInitAction} />
<GlobalModalIsolator />
</ThemeLocaleProvider>
</ErrorBoundary>
);
};

View File

@@ -42,7 +42,6 @@ import { $socketOptions } from 'services/events/stores';
import type { ManagerOptions, SocketOptions } from 'socket.io-client';
const App = lazy(() => import('./App'));
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
interface Props extends PropsWithChildren {
apiUrl?: string;
@@ -330,9 +329,7 @@ const InvokeAIUI = ({
<React.StrictMode>
<Provider store={store}>
<React.Suspense fallback={<Loading />}>
<ThemeLocaleProvider>
<App config={config} studioInitAction={studioInitAction} />
</ThemeLocaleProvider>
<App config={config} studioInitAction={studioInitAction} />
</React.Suspense>
</Provider>
</React.StrictMode>

View File

@@ -170,7 +170,6 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
case 'canvas':
// Go to the canvas tab, open the launchpad
await navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
store.dispatch(canvasReset());
break;
case 'workflows':
// Go to the workflows tab

View File

@@ -27,6 +27,7 @@ const sx = {
alignItems: 'center',
justifyContent: 'center',
flexShrink: 0,
h: 'full',
aspectRatio: '1/1',
borderWidth: 2,
borderRadius: 'base',
@@ -39,11 +40,11 @@ const sx = {
type Props = {
item: S['SessionQueueItem'];
number: number;
index: number;
isSelected: boolean;
};
export const QueueItemPreviewMini = memo(({ item, isSelected, number }: Props) => {
export const QueueItemPreviewMini = memo(({ item, isSelected, index }: Props) => {
const dispatch = useAppDispatch();
const ctx = useCanvasSessionContext();
const { imageLoaded } = useProgressData(ctx.$progressData, item.item_id);
@@ -69,7 +70,7 @@ export const QueueItemPreviewMini = memo(({ item, isSelected, number }: Props) =
return (
<Flex
id={getQueueItemElementId(item.item_id)}
id={getQueueItemElementId(index)}
sx={sx}
data-selected={isSelected}
onClick={onClick}
@@ -78,7 +79,7 @@ export const QueueItemPreviewMini = memo(({ item, isSelected, number }: Props) =
<QueueItemStatusLabel item={item} position="absolute" margin="auto" />
{imageDTO && <DndImage imageDTO={imageDTO} onLoad={onLoad} asThumbnail />}
{!imageLoaded && <QueueItemProgressImage itemId={item.item_id} position="absolute" />}
<QueueItemNumber number={number} position="absolute" top={0} left={1} />
<QueueItemNumber number={index + 1} position="absolute" top={0} left={1} />
<QueueItemCircularProgress itemId={item.item_id} status={item.status} position="absolute" top={1} right={2} />
</Flex>
);

View File

@@ -1,17 +1,148 @@
import { Flex } from '@invoke-ai/ui-library';
import { Box, Flex, forwardRef } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { logger } from 'app/logging/logger';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { QueueItemPreviewMini } from 'features/controlLayers/components/SimpleSession/QueueItemPreviewMini';
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { memo, useEffect } from 'react';
import { useOverlayScrollbars } from 'overlayscrollbars-react';
import type { CSSProperties, RefObject } from 'react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import type { Components, ItemContent, ListRange, VirtuosoHandle, VirtuosoProps } from 'react-virtuoso';
import { Virtuoso } from 'react-virtuoso';
import type { S } from 'services/api/types';
import { getQueueItemElementId } from './shared';
const log = logger('system');
const virtuosoStyles = {
width: '100%',
height: '72px',
} satisfies CSSProperties;
type VirtuosoContext = { selectedItemId: number | null };
/**
* Scroll the item at the given index into view if it is not currently visible.
*/
const scrollIntoView = (
targetIndex: number,
rootEl: HTMLDivElement,
virtuosoHandle: VirtuosoHandle,
range: ListRange
) => {
if (range.endIndex === 0) {
// No range is rendered; no need to scroll to anything.
return;
}
const targetItem = rootEl.querySelector(`#${getQueueItemElementId(targetIndex)}`);
if (!targetItem) {
if (targetIndex > range.endIndex) {
virtuosoHandle.scrollToIndex({
index: targetIndex,
behavior: 'auto',
align: 'end',
});
} else if (targetIndex < range.startIndex) {
virtuosoHandle.scrollToIndex({
index: targetIndex,
behavior: 'auto',
align: 'start',
});
} else {
log.debug(
`Unable to find queue item at index ${targetIndex} but it is in the rendered range ${range.startIndex}-${range.endIndex}`
);
}
return;
}
// We found the image in the DOM, but it might be in the overscan range - rendered but not in the visible viewport.
// Check if it is in the viewport and scroll if necessary.
const itemRect = targetItem.getBoundingClientRect();
const rootRect = rootEl.getBoundingClientRect();
if (itemRect.left < rootRect.left) {
virtuosoHandle.scrollToIndex({
index: targetIndex,
behavior: 'auto',
align: 'start',
});
} else if (itemRect.right > rootRect.right) {
virtuosoHandle.scrollToIndex({
index: targetIndex,
behavior: 'auto',
align: 'end',
});
} else {
// Image is already in view
}
return;
};
const useScrollableStagingArea = (rootRef: RefObject<HTMLDivElement>) => {
const [scroller, scrollerRef] = useState<HTMLElement | null>(null);
const [initialize, osInstance] = useOverlayScrollbars({
defer: true,
events: {
initialized(osInstance) {
// force overflow styles
const { viewport } = osInstance.elements();
viewport.style.overflowX = `var(--os-viewport-overflow-x)`;
viewport.style.overflowY = `var(--os-viewport-overflow-y)`;
},
},
options: {
scrollbars: {
visibility: 'auto',
autoHide: 'scroll',
autoHideDelay: 1300,
theme: 'os-theme-dark',
},
overflow: {
y: 'hidden',
x: 'scroll',
},
},
});
useEffect(() => {
const { current: root } = rootRef;
if (scroller && root) {
initialize({
target: root,
elements: {
viewport: scroller,
},
});
}
return () => {
osInstance()?.destroy();
};
}, [scroller, initialize, osInstance, rootRef]);
return scrollerRef;
};
export const StagingAreaItemsList = memo(() => {
const canvasManager = useCanvasManagerSafe();
const ctx = useCanvasSessionContext();
const virtuosoRef = useRef<VirtuosoHandle>(null);
const rangeRef = useRef<ListRange>({ startIndex: 0, endIndex: 0 });
const rootRef = useRef<HTMLDivElement>(null);
const items = useStore(ctx.$items);
const selectedItemId = useStore(ctx.$selectedItemId);
const context = useMemo(() => ({ selectedItemId }), [selectedItemId]);
const scrollerRef = useScrollableStagingArea(rootRef);
useEffect(() => {
if (!canvasManager) {
return;
@@ -20,21 +151,64 @@ export const StagingAreaItemsList = memo(() => {
return canvasManager.stagingArea.connectToSession(ctx.$selectedItemId, ctx.$progressData, ctx.$isPending);
}, [canvasManager, ctx.$progressData, ctx.$selectedItemId, ctx.$isPending]);
useEffect(() => {
return ctx.$selectedItemIndex.listen((index) => {
if (!virtuosoRef.current) {
return;
}
if (!rootRef.current) {
return;
}
if (index === null) {
return;
}
scrollIntoView(index, rootRef.current, virtuosoRef.current, rangeRef.current);
});
}, [ctx.$selectedItemIndex]);
const onRangeChanged = useCallback((range: ListRange) => {
rangeRef.current = range;
}, []);
return (
<Flex position="relative" maxW="full" w="full" h="72px">
<ScrollableContent overflowX="scroll" overflowY="hidden">
<Flex gap={2} w="full" h="full" justifyContent="safe center">
{items.map((item, i) => (
<QueueItemPreviewMini
key={`${item.item_id}-mini`}
item={item}
number={i + 1}
isSelected={selectedItemId === item.item_id}
/>
))}
</Flex>
</ScrollableContent>
</Flex>
<Box data-overlayscrollbars-initialize="" ref={rootRef} position="relative" w="full" h="full">
<Virtuoso<S['SessionQueueItem'], VirtuosoContext>
ref={virtuosoRef}
context={context}
data={items}
horizontalDirection
style={virtuosoStyles}
itemContent={itemContent}
components={components}
rangeChanged={onRangeChanged}
// Virtuoso expects the ref to be of HTMLElement | null | Window, but overlayscrollbars doesn't allow Window
scrollerRef={scrollerRef as VirtuosoProps<S['SessionQueueItem'], VirtuosoContext>['scrollerRef']}
/>
</Box>
);
});
StagingAreaItemsList.displayName = 'StagingAreaItemsList';
const itemContent: ItemContent<S['SessionQueueItem'], VirtuosoContext> = (index, item, { selectedItemId }) => (
<QueueItemPreviewMini
key={`${item.item_id}-mini`}
item={item}
index={index}
isSelected={selectedItemId === item.item_id}
/>
);
const listSx = {
'& > * + *': {
pl: 2,
},
};
const components: Components<S['SessionQueueItem'], VirtuosoContext> = {
List: forwardRef(({ context: _, ...rest }, ref) => {
return <Flex ref={ref} sx={listSx} {...rest} />;
}),
};

View File

@@ -13,7 +13,7 @@ export const getProgressMessage = (data?: S['InvocationProgressEvent'] | null) =
export const DROP_SHADOW = 'drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 4px rgba(0, 0, 0, 0.3))';
export const getQueueItemElementId = (itemId: number) => `queue-item-status-card-${itemId}`;
export const getQueueItemElementId = (index: number) => `queue-item-preview-${index}`;
export const getOutputImageName = (item: S['SessionQueueItem']) => {
const nodeId = Object.entries(item.session.source_prepared_mapping).find(([nodeId]) =>

View File

@@ -1,7 +1,6 @@
import { ButtonGroup, Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { getQueueItemElementId } from 'features/controlLayers/components/SimpleSession/shared';
import { StagingAreaToolbarAcceptButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarAcceptButton';
import { StagingAreaToolbarDiscardAllButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarDiscardAllButton';
import { StagingAreaToolbarDiscardSelectedButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarDiscardSelectedButton';
@@ -12,7 +11,7 @@ import { StagingAreaToolbarPrevButton } from 'features/controlLayers/components/
import { StagingAreaToolbarSaveSelectedToGalleryButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarSaveSelectedToGalleryButton';
import { StagingAreaToolbarToggleShowResultsButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarToggleShowResultsButton';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { memo, useEffect } from 'react';
import { memo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { StagingAreaAutoSwitchButtons } from './StagingAreaAutoSwitchButtons';
@@ -23,16 +22,6 @@ export const StagingAreaToolbar = memo(() => {
const ctx = useCanvasSessionContext();
useEffect(() => {
return ctx.$selectedItemId.listen((id) => {
if (id !== null) {
document
.getElementById(getQueueItemElementId(id))
?.scrollIntoView({ block: 'nearest', inline: 'nearest', behavior: 'auto' });
}
});
}, [ctx.$selectedItemId]);
useHotkeys('meta+left', ctx.selectFirst, { preventDefault: true });
useHotkeys('meta+right', ctx.selectLast, { preventDefault: true });

View File

@@ -482,11 +482,6 @@ export const NewGallery = memo(() => {
const context = useMemo<GridContext>(() => ({ imageNames, queryArgs }), [imageNames, queryArgs]);
// Item content function
const itemContent: GridItemContent<string, GridContext> = useCallback((index, imageName) => {
return <ImageAtPosition index={index} imageName={imageName} />;
}, []);
if (isLoading) {
return (
<Flex w="full" h="full" alignItems="center" justifyContent="center" gap={4}>
@@ -511,7 +506,7 @@ export const NewGallery = memo(() => {
ref={virtuosoRef}
context={context}
data={imageNames}
increaseViewportBy={2048}
increaseViewportBy={4096}
itemContent={itemContent}
computeItemKey={computeItemKey}
components={components}
@@ -528,8 +523,12 @@ export const NewGallery = memo(() => {
NewGallery.displayName = 'NewGallery';
const scrollSeekConfiguration: ScrollSeekConfiguration = {
enter: (velocity) => velocity > 4096,
exit: (velocity) => velocity === 0,
enter: (velocity) => {
return Math.abs(velocity) > 2048;
},
exit: (velocity) => {
return velocity === 0;
},
};
// Styles
@@ -549,6 +548,10 @@ const ListComponent: GridComponents<GridContext>['List'] = forwardRef(({ context
});
ListComponent.displayName = 'ListComponent';
const itemContent: GridItemContent<string, GridContext> = (index, imageName) => {
return <ImageAtPosition index={index} imageName={imageName} />;
};
const ItemComponent: GridComponents<GridContext>['Item'] = forwardRef(({ context: _, ...rest }, ref) => (
<GridItem ref={ref} aspectRatio="1/1" {...rest} />
));

View File

@@ -1,5 +1,5 @@
import { useAppStore } from 'app/store/storeHooks';
import { useCallback, useEffect, useRef } from 'react';
import { useCallback, useEffect, useState } from 'react';
import type { ListRange } from 'react-virtuoso';
import { imagesApi, useGetImageDTOsByNamesMutation } from 'services/api/endpoints/images';
import { useThrottledCallback } from 'use-debounce';
@@ -13,33 +13,20 @@ interface UseRangeBasedImageFetchingReturn {
onRangeChanged: (range: ListRange) => void;
}
const getUncachedNames = (imageNames: string[], cachedImageNames: string[], range: ListRange): string[] => {
if (range.startIndex === range.endIndex) {
// If the start and end indices are the same, no range to fetch
return [];
}
const getUncachedNames = (imageNames: string[], cachedImageNames: string[], ranges: ListRange[]): string[] => {
const uncachedNamesSet = new Set<string>();
const cachedImageNamesSet = new Set(cachedImageNames);
if (imageNames.length === 0) {
return [];
}
const start = Math.max(0, range.startIndex);
const end = Math.min(imageNames.length - 1, range.endIndex);
if (cachedImageNames.length === 0) {
return imageNames.slice(start, end + 1);
}
const uncachedNames: string[] = [];
for (let i = start; i <= end; i++) {
const imageName = imageNames[i]!;
if (!cachedImageNames.includes(imageName)) {
uncachedNames.push(imageName);
for (const range of ranges) {
for (let i = range.startIndex; i <= range.endIndex; i++) {
const n = imageNames[i]!;
if (n && !cachedImageNamesSet.has(n)) {
uncachedNamesSet.add(n);
}
}
}
return uncachedNames;
return Array.from(uncachedNamesSet);
};
/**
@@ -53,39 +40,36 @@ export const useRangeBasedImageFetching = ({
}: UseRangeBasedImageFetchingArgs): UseRangeBasedImageFetchingReturn => {
const store = useAppStore();
const [getImageDTOsByNames] = useGetImageDTOsByNamesMutation();
const lastRangeRef = useRef<ListRange | null>(null);
const [lastRange, setLastRange] = useState<ListRange | null>(null);
const [pendingRanges, setPendingRanges] = useState<ListRange[]>([]);
const fetchImages = useCallback(
(visibleRange: ListRange) => {
(ranges: ListRange[], imageNames: string[]) => {
if (!enabled) {
return;
}
const cachedImageNames = imagesApi.util.selectCachedArgsForQuery(store.getState(), 'getImageDTO');
const uncachedNames = getUncachedNames(imageNames, cachedImageNames, visibleRange);
const uncachedNames = getUncachedNames(imageNames, cachedImageNames, ranges);
if (uncachedNames.length === 0) {
return;
}
getImageDTOsByNames({ image_names: uncachedNames });
lastRangeRef.current = visibleRange;
setPendingRanges([]);
},
[enabled, getImageDTOsByNames, imageNames, store]
[enabled, getImageDTOsByNames, store]
);
const throttledFetchImages = useThrottledCallback(fetchImages, 100);
const throttledFetchImages = useThrottledCallback(fetchImages, 500);
const onRangeChanged = useCallback(
(range: ListRange) => {
throttledFetchImages(range);
},
[throttledFetchImages]
);
const onRangeChanged = useCallback((range: ListRange) => {
setLastRange(range);
setPendingRanges((prev) => [...prev, range]);
}, []);
useEffect(() => {
if (!lastRangeRef.current) {
return;
}
throttledFetchImages(lastRangeRef.current);
}, [imageNames, throttledFetchImages]);
const combinedRanges = lastRange ? [...pendingRanges, lastRange] : pendingRanges;
throttledFetchImages(combinedRanges, imageNames);
}, [imageNames, lastRange, pendingRanges, throttledFetchImages]);
return {
onRangeChanged,

View File

@@ -26,14 +26,14 @@ const optionsObject: Record<Language, string> = {
nl: 'Nederlands',
pl: 'Polski',
pt: 'Português',
pt_BR: 'Português do Brasil',
'pt-BR': 'Português do Brasil',
ru: 'Русский',
sv: 'Svenska',
tr: 'Türkçe',
ua: 'Украї́нська',
vi: 'Tiếng Việt',
zh_CN: '简体中文',
zh_Hant: '漢語',
'zh-CN': '简体中文',
'zh-Hant': '漢語',
};
const options = map(optionsObject, (label, value) => ({ label, value }));

View File

@@ -9,7 +9,7 @@ import { uniq } from 'es-toolkit/compat';
import type { Language, SystemState } from './types';
const initialSystemState: SystemState = {
_version: 1,
_version: 2,
shouldConfirmOnDelete: true,
shouldAntialiasProgressImage: false,
shouldConfirmOnNewSession: true,
@@ -96,6 +96,10 @@ const migrateSystemState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
if (state._version === 1) {
state.language = (state as SystemState).language.replace('_', '-');
state._version = 2;
}
return state;
};

View File

@@ -17,20 +17,20 @@ const zLanguage = z.enum([
'nl',
'pl',
'pt',
'pt_BR',
'pt-BR',
'ru',
'sv',
'tr',
'ua',
'vi',
'zh_CN',
'zh_Hant',
'zh-CN',
'zh-Hant',
]);
export type Language = z.infer<typeof zLanguage>;
export const isLanguage = (v: unknown): v is Language => zLanguage.safeParse(v).success;
export interface SystemState {
_version: 1;
_version: 2;
shouldConfirmOnDelete: boolean;
shouldAntialiasProgressImage: boolean;
shouldConfirmOnNewSession: boolean;

View File

@@ -13,7 +13,7 @@ export const StagingArea = memo(() => {
}
return (
<Flex position="absolute" flexDir="column" bottom={4} gap={2} align="center" justify="center" left={4} right={4}>
<Flex position="absolute" flexDir="column" bottom={2} gap={2} align="center" justify="center" left={2} right={2}>
<StagingAreaItemsList />
<StagingAreaToolbar />
</Flex>

View File

@@ -16,6 +16,8 @@ import {
PiTextAaBold,
} from 'react-icons/pi';
import { useHackOutDvTabDraggable } from './use-hack-out-dv-tab-draggable';
const TAB_ICONS: Record<TabName, IconType> = {
generate: PiTextAaBold,
canvas: PiBoundingBoxBold,
@@ -41,6 +43,8 @@ export const TabWithLaunchpadIcon = memo((props: IDockviewPanelHeaderProps) => {
setFocusedRegion(props.params.focusRegion);
}, [props.params.focusRegion]);
useHackOutDvTabDraggable(ref);
return (
<Flex ref={ref} alignItems="center" h="full" px={4} gap={3} onPointerDown={onPointerDown}>
<Icon as={TAB_ICONS[activeTab]} color="invokeYellow.300" boxSize={5} />

View File

@@ -5,6 +5,7 @@ import type { IDockviewPanelHeaderProps } from 'dockview';
import { memo, useCallback, useRef } from 'react';
import type { PanelParameters } from './auto-layout-context';
import { useHackOutDvTabDraggable } from './use-hack-out-dv-tab-draggable';
export const TabWithoutCloseButton = memo((props: IDockviewPanelHeaderProps<PanelParameters>) => {
const ref = useRef<HTMLDivElement>(null);
@@ -20,6 +21,8 @@ export const TabWithoutCloseButton = memo((props: IDockviewPanelHeaderProps<Pane
setFocusedRegion(props.params.focusRegion);
}, [props.params.focusRegion]);
useHackOutDvTabDraggable(ref);
return (
<Flex ref={ref} alignItems="center" h="full" onPointerDown={onPointerDown}>
<Text userSelect="none" px={4}>

View File

@@ -7,6 +7,7 @@ import { memo, useCallback, useRef } from 'react';
import { useIsGenerationInProgress } from 'services/api/endpoints/queue';
import type { PanelParameters } from './auto-layout-context';
import { useHackOutDvTabDraggable } from './use-hack-out-dv-tab-draggable';
export const TabWithoutCloseButtonAndWithProgressIndicator = memo(
(props: IDockviewPanelHeaderProps<PanelParameters>) => {
@@ -25,6 +26,8 @@ export const TabWithoutCloseButtonAndWithProgressIndicator = memo(
setFocusedRegion(props.params.focusRegion);
}, [props.params.focusRegion]);
useHackOutDvTabDraggable(ref);
return (
<Flex ref={ref} position="relative" alignItems="center" h="full" onPointerDown={onPointerDown}>
<Text userSelect="none" px={4}>

View File

@@ -0,0 +1,22 @@
import type { RefObject } from 'react';
import { useEffect } from 'react';
/**
* Prevent undesired dnd behavior in Dockview tabs.
*
* Dockview always sets the draggable flag on its tab elements, even when dnd is disabled. This hook traverses
* up from the provided ref to find the closest tab element and sets its `draggable` attribute to `false`.
*/
export const useHackOutDvTabDraggable = (ref: RefObject<HTMLElement>) => {
useEffect(() => {
const el = ref.current;
if (!el) {
return;
}
const parentTab = el.closest('.dv-tab');
if (!parentTab) {
return;
}
parentTab.setAttribute('draggable', 'false');
}, [ref]);
};

View File

@@ -1 +1 @@
__version__ = "6.0.0"
__version__ = "6.1.0rc1"

View File

@@ -1,263 +0,0 @@
#!/usr/bin/env python3
"""
Export LoRA metadata from InvokeAI database to JSON files.
This script exports LoRA metadata to JSON files with the following format:
{
"description": "",
"sd version": "Unknown",
"activation text": "",
"preferred weight": 0,
"negative text": "",
"notes": ""
}
"""
import argparse
import json
import re
import sys
from pathlib import Path
from typing import Any, Dict, Optional
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records.model_records_sql import ModelRecordServiceSQL
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
from invokeai.backend.util.logging import InvokeAILogger
def map_base_model_to_sd_version(base_model: BaseModelType) -> str:
"""Map BaseModelType to SD version string."""
mapping = {
BaseModelType.StableDiffusion1: "SD 1.5",
BaseModelType.StableDiffusion2: "SD 2.x",
BaseModelType.StableDiffusionXL: "SDXL",
BaseModelType.Flux: "FLUX",
}
return mapping.get(base_model, "Unknown")
def parse_description(description: Optional[str]) -> Dict[str, Any]:
"""Parse description field to extract structured data."""
result = {
"description": "",
"preferred_weight": 0,
"negative_text": "",
"notes": ""
}
if not description:
return result
# Try to extract structured parts from description
lines = description.split("\n")
current_section = "description"
section_content = []
for line in lines:
line = line.strip()
# Check for section markers
if line.startswith("Preferred weight:"):
# Save previous section
if current_section == "description" and section_content:
result["description"] = "\n".join(section_content).strip()
# Extract weight
weight_match = re.search(r"Preferred weight:\s*([\d.]+)", line)
if weight_match:
try:
result["preferred_weight"] = float(weight_match.group(1)) # type: ignore
except ValueError:
pass
current_section = "after_weight"
section_content = []
elif line.startswith("Negative prompt:"):
# Extract negative text
negative_text = line[len("Negative prompt:"):].strip()
result["negative_text"] = negative_text
current_section = "after_negative"
section_content = []
elif line.startswith("Notes:"):
# Extract notes
notes = line[len("Notes:"):].strip()
result["notes"] = notes
current_section = "notes"
section_content = [notes] if notes else []
elif line and current_section == "notes":
# Continue adding to notes
section_content.append(line)
elif line and current_section == "description":
# Add to description
section_content.append(line)
# Save final section
if current_section == "description" and section_content:
result["description"] = "\n".join(section_content).strip()
elif current_section == "notes" and section_content:
result["notes"] = "\n".join(section_content).strip()
return result
def export_lora_metadata(lora_model: AnyModelConfig) -> Dict[str, Any]:
"""Export LoRA model metadata to JSON format."""
# Parse description to extract structured data
parsed = parse_description(lora_model.description)
# Build activation text from trigger phrases
activation_text = ""
if hasattr(lora_model, 'trigger_phrases') and lora_model.trigger_phrases:
activation_text = ", ".join(sorted(lora_model.trigger_phrases))
# Build final JSON structure
return {
"description": parsed["description"],
"sd version": map_base_model_to_sd_version(lora_model.base),
"activation text": activation_text,
"preferred weight": parsed["preferred_weight"],
"negative text": parsed["negative_text"],
"notes": parsed["notes"]
}
def main():
parser = argparse.ArgumentParser(
description="Export LoRA metadata from InvokeAI database to JSON files"
)
parser.add_argument(
"--output-dir",
type=Path,
default=Path("."),
help="Directory to save JSON files (default: current directory)",
)
parser.add_argument(
"--model-name",
type=str,
help="Export only the specified LoRA model by name",
)
parser.add_argument(
"--model-key",
type=str,
help="Export only the specified LoRA model by key",
)
parser.add_argument(
"--filename-pattern",
type=str,
default="{name}.json",
help="Filename pattern for JSON files (default: {name}.json). "
"Available placeholders: {name}, {key}, {base}",
)
parser.add_argument(
"--overwrite",
action="store_true",
help="Overwrite existing JSON files",
)
parser.add_argument(
"--pretty",
action="store_true",
help="Pretty-print JSON output",
)
args = parser.parse_args()
# Initialize configuration and services
config = InvokeAIAppConfig.get_config()
logger = InvokeAILogger.get_logger("export_lora_metadata")
# Initialize database
db = SqliteDatabase(db_path=config.db_path, logger=logger)
model_record_service = ModelRecordServiceSQL(db, logger)
# Create output directory if needed
args.output_dir.mkdir(parents=True, exist_ok=True)
# Get LoRA models to export
lora_models = []
if args.model_key:
try:
model = model_record_service.get_model(args.model_key)
if model.type != ModelType.LoRA:
print(f"Error: Model {args.model_key} is not a LoRA model", file=sys.stderr)
sys.exit(1)
lora_models = [model]
except Exception:
print(f"Error: Model with key {args.model_key} not found", file=sys.stderr)
sys.exit(1)
elif args.model_name:
models = model_record_service.search_by_attr(
model_name=args.model_name,
model_type=ModelType.LoRA
)
if not models:
print(f"Error: No LoRA model found with name '{args.model_name}'", file=sys.stderr)
sys.exit(1)
lora_models = models
else:
# Export all LoRA models
lora_models = model_record_service.search_by_attr(model_type=ModelType.LoRA)
if not lora_models:
print("No LoRA models found in database", file=sys.stderr)
sys.exit(1)
print(f"Exporting {len(lora_models)} LoRA model(s)...")
# Export each model
exported_count = 0
skipped_count = 0
for lora_model in lora_models:
# Generate filename
filename = args.filename_pattern.format(
name=lora_model.name,
key=lora_model.key,
base=lora_model.base.value
)
# Sanitize filename
filename = re.sub(r'[<>:"/\\|?*]', '_', filename)
output_path = args.output_dir / filename
# Check if file exists
if output_path.exists() and not args.overwrite:
print(f"Skipping {lora_model.name}: {output_path} already exists")
skipped_count += 1
continue
# Export metadata
metadata = export_lora_metadata(lora_model)
# Write JSON file
try:
with open(output_path, "w", encoding="utf-8") as f:
if args.pretty:
json.dump(metadata, f, indent=2, ensure_ascii=False)
else:
json.dump(metadata, f, ensure_ascii=False)
print(f"Exported {lora_model.name}{output_path}")
exported_count += 1
except Exception as e:
print(f"Error exporting {lora_model.name}: {e}", file=sys.stderr)
# Summary
print(f"\nExport complete:")
print(f" Exported: {exported_count}")
if skipped_count > 0:
print(f" Skipped: {skipped_count} (use --overwrite to replace)")
if __name__ == "__main__":
main()

View File

@@ -1,288 +0,0 @@
#!/usr/bin/env python3
"""
Import LoRA metadata from JSON files into InvokeAI database.
This script reads JSON files with the following format:
{
"description": "",
"sd version": "Unknown",
"activation text": "",
"preferred weight": 0,
"negative text": "",
"notes": ""
}
And imports the metadata into existing LoRA models in the InvokeAI database.
"""
import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict, Optional
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordChanges, ModelRecordServiceBase
from invokeai.app.services.model_records.model_records_sql import ModelRecordServiceSQL
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
from invokeai.backend.util.logging import InvokeAILogger
def map_sd_version_to_base_model(sd_version: str) -> Optional[BaseModelType]:
"""Map SD version string to BaseModelType."""
sd_version_lower = sd_version.lower()
if "xl" in sd_version_lower or "sdxl" in sd_version_lower:
return BaseModelType.StableDiffusionXL
elif "2" in sd_version_lower:
return BaseModelType.StableDiffusion2
elif "1" in sd_version_lower or "1.5" in sd_version_lower:
return BaseModelType.StableDiffusion1
elif "flux" in sd_version_lower:
return BaseModelType.Flux
else:
return None # Will not update base model if unknown
def build_description(json_data: Dict[str, Any]) -> str:
"""Build a comprehensive description from JSON data."""
parts = []
if json_data.get("description"):
parts.append(json_data["description"])
if json_data.get("preferred weight") and json_data["preferred weight"] != 0:
parts.append(f"Preferred weight: {json_data['preferred weight']}")
if json_data.get("negative text"):
parts.append(f"Negative prompt: {json_data['negative text']}")
if json_data.get("notes"):
parts.append(f"Notes: {json_data['notes']}")
return "\n\n".join(parts) if parts else ""
def process_lora_metadata(
model_record_service: ModelRecordServiceBase,
lora_model: AnyModelConfig,
json_data: Dict[str, Any],
update_base_model: bool = False,
) -> bool:
"""Process and update a single LoRA model with metadata from JSON."""
changes = ModelRecordChanges()
# Map activation text to trigger phrases
if json_data.get("activation text"):
activation_texts = [text.strip() for text in json_data["activation text"].split(",")]
changes.trigger_phrases = set(activation_texts)
# Build description from multiple fields
description = build_description(json_data)
if description:
changes.description = description
# Optionally update base model type
if update_base_model and json_data.get("sd version"):
base_model = map_sd_version_to_base_model(json_data["sd version"])
if base_model:
changes.base = base_model
# Only update if we have changes
if changes.model_dump(exclude_none=True):
try:
model_record_service.update_model(lora_model.key, changes)
return True
except Exception as e:
print(f"Error updating model {lora_model.name}: {e}")
return False
return False
def main():
parser = argparse.ArgumentParser(
description="Import LoRA metadata from JSON files into InvokeAI database"
)
parser.add_argument(
"json_file",
type=Path,
help="Path to JSON file containing LoRA metadata",
)
parser.add_argument(
"--model-name",
type=str,
help="Name of the LoRA model to update (if not specified, will try to match based on filename)",
)
parser.add_argument(
"--model-key",
type=str,
help="Key of the LoRA model to update (takes precedence over --model-name)",
)
parser.add_argument(
"--update-base-model",
action="store_true",
help="Update the base model type based on 'sd version' field",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Show what would be updated without making changes",
)
parser.add_argument(
"--batch",
action="store_true",
help="Process multiple JSON files in batch mode (json_file should be a directory)",
)
args = parser.parse_args()
# Initialize configuration and services
config = InvokeAIAppConfig.get_config()
logger = InvokeAILogger.get_logger("import_lora_metadata")
# Initialize database
db = SqliteDatabase(db_path=config.db_path, logger=logger)
model_record_service = ModelRecordServiceSQL(db, logger)
# Process single file or batch
if args.batch:
if not args.json_file.is_dir():
print(f"Error: {args.json_file} is not a directory", file=sys.stderr)
sys.exit(1)
json_files = list(args.json_file.glob("*.json"))
if not json_files:
print(f"No JSON files found in {args.json_file}", file=sys.stderr)
sys.exit(1)
print(f"Found {len(json_files)} JSON files to process")
for json_file in json_files:
process_single_file(
json_file, model_record_service, args, logger
)
else:
process_single_file(
args.json_file, model_record_service, args, logger
)
def process_single_file(
json_file: Path,
model_record_service: ModelRecordServiceBase,
args: argparse.Namespace,
logger: Any,
) -> None:
"""Process a single JSON file."""
if not json_file.exists():
print(f"Error: {json_file} does not exist", file=sys.stderr)
return
try:
with open(json_file, "r") as f:
json_data = json.load(f)
except json.JSONDecodeError as e:
print(f"Error reading {json_file}: {e}", file=sys.stderr)
return
# Find the LoRA model to update
lora_model = None
if args.model_key:
try:
lora_model = model_record_service.get_model(args.model_key)
if lora_model.type != ModelType.LoRA:
print(f"Error: Model {args.model_key} is not a LoRA model", file=sys.stderr)
return
except Exception:
print(f"Error: Model with key {args.model_key} not found", file=sys.stderr)
return
elif args.model_name:
# Search for LoRA by name
models = model_record_service.search_by_attr(
model_name=args.model_name,
model_type=ModelType.LoRA
)
if not models:
print(f"Error: No LoRA model found with name '{args.model_name}'", file=sys.stderr)
return
elif len(models) > 1:
print(f"Error: Multiple LoRA models found with name '{args.model_name}':", file=sys.stderr)
for model in models:
print(f" - {model.key}: {model.name} ({model.base})")
print("Please specify --model-key to select one", file=sys.stderr)
return
lora_model = models[0]
else:
# Try to match based on filename
base_name = json_file.stem
models = model_record_service.search_by_attr(
model_name=base_name,
model_type=ModelType.LoRA
)
if not models:
# Try partial match
all_loras = model_record_service.search_by_attr(model_type=ModelType.LoRA)
matches = [m for m in all_loras if base_name.lower() in m.name.lower()]
if not matches:
print(f"Error: No LoRA model found matching filename '{base_name}'", file=sys.stderr)
return
elif len(matches) > 1:
print(f"Error: Multiple LoRA models found matching '{base_name}':", file=sys.stderr)
for model in matches:
print(f" - {model.key}: {model.name} ({model.base})")
print("Please specify --model-name or --model-key", file=sys.stderr)
return
lora_model = matches[0]
elif len(models) > 1:
print(f"Error: Multiple LoRA models found with name '{base_name}':", file=sys.stderr)
for model in models:
print(f" - {model.key}: {model.name} ({model.base})")
print("Please specify --model-key to select one", file=sys.stderr)
return
else:
lora_model = models[0]
# Display current and proposed changes
print(f"\nProcessing: {json_file.name}")
print(f"Target LoRA: {lora_model.name} (key: {lora_model.key})")
print(f"Current base model: {lora_model.base}")
if args.dry_run:
print("\n--- DRY RUN MODE ---")
print("Proposed changes:")
if json_data.get("activation text"):
print(f" Trigger phrases: {json_data['activation text']}")
description = build_description(json_data)
if description:
print(f" Description: {description[:100]}..." if len(description) > 100 else f" Description: {description}")
if args.update_base_model and json_data.get("sd version"):
base_model = map_sd_version_to_base_model(json_data["sd version"])
if base_model:
print(f" Base model: {lora_model.base}{base_model}")
print("--- END DRY RUN ---\n")
else:
# Apply the updates
success = process_lora_metadata(
model_record_service,
lora_model,
json_data,
args.update_base_model
)
if success:
print("✓ Successfully updated metadata")
else:
print("✗ No changes made")
if __name__ == "__main__":
main()

View File

@@ -1,136 +0,0 @@
# LoRA Metadata Import/Export Tools
These scripts allow you to import and export LoRA metadata between JSON files and the InvokeAI database.
## JSON Format
The JSON format used by these tools is:
```json
{
"description": "Description of the LoRA",
"sd version": "SDXL",
"activation text": "trigger1, trigger2",
"preferred weight": 0.8,
"negative text": "negative prompts to avoid",
"notes": "Additional notes about the LoRA"
}
```
## Import Script: `import_lora_metadata.py`
Imports metadata from JSON files into existing LoRA models in the InvokeAI database.
### Usage
```bash
# Import metadata for a single LoRA (matches by filename)
python scripts/import_lora_metadata.py my_lora.json
# Import metadata by specifying the model name
python scripts/import_lora_metadata.py metadata.json --model-name "My LoRA Model"
# Import metadata by specifying the model key
python scripts/import_lora_metadata.py metadata.json --model-key "abc123def456"
# Dry run to see what would be changed
python scripts/import_lora_metadata.py my_lora.json --dry-run
# Update the base model type based on "sd version" field
python scripts/import_lora_metadata.py my_lora.json --update-base-model
# Batch import multiple JSON files from a directory
python scripts/import_lora_metadata.py /path/to/json/directory --batch
```
### Field Mappings
- `activation text``trigger_phrases` (comma-separated list)
- `description``description`
- `preferred weight`, `negative text`, `notes` → Combined into `description` field
- `sd version``base` (when `--update-base-model` is used)
### SD Version Mapping
When using `--update-base-model`, the script maps SD versions as follows:
- Contains "xl" or "sdxl" → StableDiffusionXL
- Contains "2" → StableDiffusion2
- Contains "1" or "1.5" → StableDiffusion1
- Contains "flux" → Flux
- Other → No update
## Export Script: `export_lora_metadata.py`
Exports LoRA metadata from the InvokeAI database to JSON files.
### Usage
```bash
# Export all LoRA models to current directory
python scripts/export_lora_metadata.py
# Export to specific directory
python scripts/export_lora_metadata.py --output-dir /path/to/output
# Export specific LoRA by name
python scripts/export_lora_metadata.py --model-name "My LoRA Model"
# Export specific LoRA by key
python scripts/export_lora_metadata.py --model-key "abc123def456"
# Custom filename pattern
python scripts/export_lora_metadata.py --filename-pattern "{base}_{name}.json"
# Pretty-print JSON output
python scripts/export_lora_metadata.py --pretty
# Overwrite existing files
python scripts/export_lora_metadata.py --overwrite
```
### Filename Patterns
Available placeholders for `--filename-pattern`:
- `{name}` - LoRA model name
- `{key}` - LoRA model key/ID
- `{base}` - Base model type (e.g., "sdxl", "sd-1", etc.)
## Examples
### Example 1: Import metadata for a newly added LoRA
1. Add a LoRA model to InvokeAI (e.g., `anime_style_v2.safetensors`)
2. Create a JSON file with metadata (`anime_style_v2.json`):
```json
{
"description": "Anime style LoRA trained on modern anime artwork",
"sd version": "SDXL",
"activation text": "anime style, modern anime",
"preferred weight": 0.7,
"negative text": "realistic, photorealistic",
"notes": "Works best with anime-focused base models"
}
```
3. Import the metadata:
```bash
python scripts/import_lora_metadata.py anime_style_v2.json
```
### Example 2: Batch export and import
1. Export all LoRA metadata:
```bash
python scripts/export_lora_metadata.py --output-dir ./lora_metadata --pretty
```
2. Edit the JSON files as needed
3. Import all metadata back:
```bash
python scripts/import_lora_metadata.py ./lora_metadata --batch
```
## Notes
- The import script requires that LoRA models already exist in the InvokeAI database
- When importing, the script will try to match JSON filenames to LoRA model names
- Use `--dry-run` to preview changes before applying them
- The scripts preserve existing data when possible (e.g., appending to descriptions rather than replacing)