mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Compare commits
9 Commits
cursor/eva
...
psychedeli
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c1a4376b75 | ||
|
|
ef4d5d7377 | ||
|
|
6b0dfd8427 | ||
|
|
471c010217 | ||
|
|
b1193022f7 | ||
|
|
2152ca092c | ||
|
|
ccc62ba56d | ||
|
|
9cf82de8c5 | ||
|
|
aced349152 |
@@ -11,6 +11,7 @@ import { memo, useCallback } from 'react';
|
||||
import { ErrorBoundary } from 'react-error-boundary';
|
||||
|
||||
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
|
||||
import ThemeLocaleProvider from './ThemeLocaleProvider';
|
||||
const DEFAULT_CONFIG = {};
|
||||
|
||||
interface Props {
|
||||
@@ -30,12 +31,14 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
|
||||
|
||||
return (
|
||||
<ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
|
||||
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
|
||||
<AppContent />
|
||||
{!didStudioInit && <Loading />}
|
||||
</Box>
|
||||
<GlobalHookIsolator config={config} studioInitAction={studioInitAction} />
|
||||
<GlobalModalIsolator />
|
||||
<ThemeLocaleProvider>
|
||||
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
|
||||
<AppContent />
|
||||
{!didStudioInit && <Loading />}
|
||||
</Box>
|
||||
<GlobalHookIsolator config={config} studioInitAction={studioInitAction} />
|
||||
<GlobalModalIsolator />
|
||||
</ThemeLocaleProvider>
|
||||
</ErrorBoundary>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -42,7 +42,6 @@ import { $socketOptions } from 'services/events/stores';
|
||||
import type { ManagerOptions, SocketOptions } from 'socket.io-client';
|
||||
|
||||
const App = lazy(() => import('./App'));
|
||||
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
|
||||
|
||||
interface Props extends PropsWithChildren {
|
||||
apiUrl?: string;
|
||||
@@ -330,9 +329,7 @@ const InvokeAIUI = ({
|
||||
<React.StrictMode>
|
||||
<Provider store={store}>
|
||||
<React.Suspense fallback={<Loading />}>
|
||||
<ThemeLocaleProvider>
|
||||
<App config={config} studioInitAction={studioInitAction} />
|
||||
</ThemeLocaleProvider>
|
||||
<App config={config} studioInitAction={studioInitAction} />
|
||||
</React.Suspense>
|
||||
</Provider>
|
||||
</React.StrictMode>
|
||||
|
||||
@@ -170,7 +170,6 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
|
||||
case 'canvas':
|
||||
// Go to the canvas tab, open the launchpad
|
||||
await navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
|
||||
store.dispatch(canvasReset());
|
||||
break;
|
||||
case 'workflows':
|
||||
// Go to the workflows tab
|
||||
|
||||
@@ -27,6 +27,7 @@ const sx = {
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
flexShrink: 0,
|
||||
h: 'full',
|
||||
aspectRatio: '1/1',
|
||||
borderWidth: 2,
|
||||
borderRadius: 'base',
|
||||
@@ -39,11 +40,11 @@ const sx = {
|
||||
|
||||
type Props = {
|
||||
item: S['SessionQueueItem'];
|
||||
number: number;
|
||||
index: number;
|
||||
isSelected: boolean;
|
||||
};
|
||||
|
||||
export const QueueItemPreviewMini = memo(({ item, isSelected, number }: Props) => {
|
||||
export const QueueItemPreviewMini = memo(({ item, isSelected, index }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const ctx = useCanvasSessionContext();
|
||||
const { imageLoaded } = useProgressData(ctx.$progressData, item.item_id);
|
||||
@@ -69,7 +70,7 @@ export const QueueItemPreviewMini = memo(({ item, isSelected, number }: Props) =
|
||||
|
||||
return (
|
||||
<Flex
|
||||
id={getQueueItemElementId(item.item_id)}
|
||||
id={getQueueItemElementId(index)}
|
||||
sx={sx}
|
||||
data-selected={isSelected}
|
||||
onClick={onClick}
|
||||
@@ -78,7 +79,7 @@ export const QueueItemPreviewMini = memo(({ item, isSelected, number }: Props) =
|
||||
<QueueItemStatusLabel item={item} position="absolute" margin="auto" />
|
||||
{imageDTO && <DndImage imageDTO={imageDTO} onLoad={onLoad} asThumbnail />}
|
||||
{!imageLoaded && <QueueItemProgressImage itemId={item.item_id} position="absolute" />}
|
||||
<QueueItemNumber number={number} position="absolute" top={0} left={1} />
|
||||
<QueueItemNumber number={index + 1} position="absolute" top={0} left={1} />
|
||||
<QueueItemCircularProgress itemId={item.item_id} status={item.status} position="absolute" top={1} right={2} />
|
||||
</Flex>
|
||||
);
|
||||
|
||||
@@ -1,17 +1,148 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex, forwardRef } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
|
||||
import { QueueItemPreviewMini } from 'features/controlLayers/components/SimpleSession/QueueItemPreviewMini';
|
||||
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { memo, useEffect } from 'react';
|
||||
import { useOverlayScrollbars } from 'overlayscrollbars-react';
|
||||
import type { CSSProperties, RefObject } from 'react';
|
||||
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import type { Components, ItemContent, ListRange, VirtuosoHandle, VirtuosoProps } from 'react-virtuoso';
|
||||
import { Virtuoso } from 'react-virtuoso';
|
||||
import type { S } from 'services/api/types';
|
||||
|
||||
import { getQueueItemElementId } from './shared';
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
const virtuosoStyles = {
|
||||
width: '100%',
|
||||
height: '72px',
|
||||
} satisfies CSSProperties;
|
||||
|
||||
type VirtuosoContext = { selectedItemId: number | null };
|
||||
|
||||
/**
|
||||
* Scroll the item at the given index into view if it is not currently visible.
|
||||
*/
|
||||
const scrollIntoView = (
|
||||
targetIndex: number,
|
||||
rootEl: HTMLDivElement,
|
||||
virtuosoHandle: VirtuosoHandle,
|
||||
range: ListRange
|
||||
) => {
|
||||
if (range.endIndex === 0) {
|
||||
// No range is rendered; no need to scroll to anything.
|
||||
return;
|
||||
}
|
||||
|
||||
const targetItem = rootEl.querySelector(`#${getQueueItemElementId(targetIndex)}`);
|
||||
|
||||
if (!targetItem) {
|
||||
if (targetIndex > range.endIndex) {
|
||||
virtuosoHandle.scrollToIndex({
|
||||
index: targetIndex,
|
||||
behavior: 'auto',
|
||||
align: 'end',
|
||||
});
|
||||
} else if (targetIndex < range.startIndex) {
|
||||
virtuosoHandle.scrollToIndex({
|
||||
index: targetIndex,
|
||||
behavior: 'auto',
|
||||
align: 'start',
|
||||
});
|
||||
} else {
|
||||
log.debug(
|
||||
`Unable to find queue item at index ${targetIndex} but it is in the rendered range ${range.startIndex}-${range.endIndex}`
|
||||
);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// We found the image in the DOM, but it might be in the overscan range - rendered but not in the visible viewport.
|
||||
// Check if it is in the viewport and scroll if necessary.
|
||||
|
||||
const itemRect = targetItem.getBoundingClientRect();
|
||||
const rootRect = rootEl.getBoundingClientRect();
|
||||
|
||||
if (itemRect.left < rootRect.left) {
|
||||
virtuosoHandle.scrollToIndex({
|
||||
index: targetIndex,
|
||||
behavior: 'auto',
|
||||
align: 'start',
|
||||
});
|
||||
} else if (itemRect.right > rootRect.right) {
|
||||
virtuosoHandle.scrollToIndex({
|
||||
index: targetIndex,
|
||||
behavior: 'auto',
|
||||
align: 'end',
|
||||
});
|
||||
} else {
|
||||
// Image is already in view
|
||||
}
|
||||
|
||||
return;
|
||||
};
|
||||
|
||||
const useScrollableStagingArea = (rootRef: RefObject<HTMLDivElement>) => {
|
||||
const [scroller, scrollerRef] = useState<HTMLElement | null>(null);
|
||||
const [initialize, osInstance] = useOverlayScrollbars({
|
||||
defer: true,
|
||||
events: {
|
||||
initialized(osInstance) {
|
||||
// force overflow styles
|
||||
const { viewport } = osInstance.elements();
|
||||
viewport.style.overflowX = `var(--os-viewport-overflow-x)`;
|
||||
viewport.style.overflowY = `var(--os-viewport-overflow-y)`;
|
||||
},
|
||||
},
|
||||
options: {
|
||||
scrollbars: {
|
||||
visibility: 'auto',
|
||||
autoHide: 'scroll',
|
||||
autoHideDelay: 1300,
|
||||
theme: 'os-theme-dark',
|
||||
},
|
||||
overflow: {
|
||||
y: 'hidden',
|
||||
x: 'scroll',
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
const { current: root } = rootRef;
|
||||
|
||||
if (scroller && root) {
|
||||
initialize({
|
||||
target: root,
|
||||
elements: {
|
||||
viewport: scroller,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
return () => {
|
||||
osInstance()?.destroy();
|
||||
};
|
||||
}, [scroller, initialize, osInstance, rootRef]);
|
||||
|
||||
return scrollerRef;
|
||||
};
|
||||
|
||||
export const StagingAreaItemsList = memo(() => {
|
||||
const canvasManager = useCanvasManagerSafe();
|
||||
const ctx = useCanvasSessionContext();
|
||||
const virtuosoRef = useRef<VirtuosoHandle>(null);
|
||||
const rangeRef = useRef<ListRange>({ startIndex: 0, endIndex: 0 });
|
||||
const rootRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const items = useStore(ctx.$items);
|
||||
const selectedItemId = useStore(ctx.$selectedItemId);
|
||||
|
||||
const context = useMemo(() => ({ selectedItemId }), [selectedItemId]);
|
||||
const scrollerRef = useScrollableStagingArea(rootRef);
|
||||
|
||||
useEffect(() => {
|
||||
if (!canvasManager) {
|
||||
return;
|
||||
@@ -20,21 +151,64 @@ export const StagingAreaItemsList = memo(() => {
|
||||
return canvasManager.stagingArea.connectToSession(ctx.$selectedItemId, ctx.$progressData, ctx.$isPending);
|
||||
}, [canvasManager, ctx.$progressData, ctx.$selectedItemId, ctx.$isPending]);
|
||||
|
||||
useEffect(() => {
|
||||
return ctx.$selectedItemIndex.listen((index) => {
|
||||
if (!virtuosoRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!rootRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (index === null) {
|
||||
return;
|
||||
}
|
||||
|
||||
scrollIntoView(index, rootRef.current, virtuosoRef.current, rangeRef.current);
|
||||
});
|
||||
}, [ctx.$selectedItemIndex]);
|
||||
|
||||
const onRangeChanged = useCallback((range: ListRange) => {
|
||||
rangeRef.current = range;
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Flex position="relative" maxW="full" w="full" h="72px">
|
||||
<ScrollableContent overflowX="scroll" overflowY="hidden">
|
||||
<Flex gap={2} w="full" h="full" justifyContent="safe center">
|
||||
{items.map((item, i) => (
|
||||
<QueueItemPreviewMini
|
||||
key={`${item.item_id}-mini`}
|
||||
item={item}
|
||||
number={i + 1}
|
||||
isSelected={selectedItemId === item.item_id}
|
||||
/>
|
||||
))}
|
||||
</Flex>
|
||||
</ScrollableContent>
|
||||
</Flex>
|
||||
<Box data-overlayscrollbars-initialize="" ref={rootRef} position="relative" w="full" h="full">
|
||||
<Virtuoso<S['SessionQueueItem'], VirtuosoContext>
|
||||
ref={virtuosoRef}
|
||||
context={context}
|
||||
data={items}
|
||||
horizontalDirection
|
||||
style={virtuosoStyles}
|
||||
itemContent={itemContent}
|
||||
components={components}
|
||||
rangeChanged={onRangeChanged}
|
||||
// Virtuoso expects the ref to be of HTMLElement | null | Window, but overlayscrollbars doesn't allow Window
|
||||
scrollerRef={scrollerRef as VirtuosoProps<S['SessionQueueItem'], VirtuosoContext>['scrollerRef']}
|
||||
/>
|
||||
</Box>
|
||||
);
|
||||
});
|
||||
StagingAreaItemsList.displayName = 'StagingAreaItemsList';
|
||||
|
||||
const itemContent: ItemContent<S['SessionQueueItem'], VirtuosoContext> = (index, item, { selectedItemId }) => (
|
||||
<QueueItemPreviewMini
|
||||
key={`${item.item_id}-mini`}
|
||||
item={item}
|
||||
index={index}
|
||||
isSelected={selectedItemId === item.item_id}
|
||||
/>
|
||||
);
|
||||
|
||||
const listSx = {
|
||||
'& > * + *': {
|
||||
pl: 2,
|
||||
},
|
||||
};
|
||||
|
||||
const components: Components<S['SessionQueueItem'], VirtuosoContext> = {
|
||||
List: forwardRef(({ context: _, ...rest }, ref) => {
|
||||
return <Flex ref={ref} sx={listSx} {...rest} />;
|
||||
}),
|
||||
};
|
||||
|
||||
@@ -13,7 +13,7 @@ export const getProgressMessage = (data?: S['InvocationProgressEvent'] | null) =
|
||||
|
||||
export const DROP_SHADOW = 'drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 4px rgba(0, 0, 0, 0.3))';
|
||||
|
||||
export const getQueueItemElementId = (itemId: number) => `queue-item-status-card-${itemId}`;
|
||||
export const getQueueItemElementId = (index: number) => `queue-item-preview-${index}`;
|
||||
|
||||
export const getOutputImageName = (item: S['SessionQueueItem']) => {
|
||||
const nodeId = Object.entries(item.session.source_prepared_mapping).find(([nodeId]) =>
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { ButtonGroup, Flex } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
|
||||
import { getQueueItemElementId } from 'features/controlLayers/components/SimpleSession/shared';
|
||||
import { StagingAreaToolbarAcceptButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarAcceptButton';
|
||||
import { StagingAreaToolbarDiscardAllButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarDiscardAllButton';
|
||||
import { StagingAreaToolbarDiscardSelectedButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarDiscardSelectedButton';
|
||||
@@ -12,7 +11,7 @@ import { StagingAreaToolbarPrevButton } from 'features/controlLayers/components/
|
||||
import { StagingAreaToolbarSaveSelectedToGalleryButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarSaveSelectedToGalleryButton';
|
||||
import { StagingAreaToolbarToggleShowResultsButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarToggleShowResultsButton';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { memo, useEffect } from 'react';
|
||||
import { memo } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
|
||||
import { StagingAreaAutoSwitchButtons } from './StagingAreaAutoSwitchButtons';
|
||||
@@ -23,16 +22,6 @@ export const StagingAreaToolbar = memo(() => {
|
||||
|
||||
const ctx = useCanvasSessionContext();
|
||||
|
||||
useEffect(() => {
|
||||
return ctx.$selectedItemId.listen((id) => {
|
||||
if (id !== null) {
|
||||
document
|
||||
.getElementById(getQueueItemElementId(id))
|
||||
?.scrollIntoView({ block: 'nearest', inline: 'nearest', behavior: 'auto' });
|
||||
}
|
||||
});
|
||||
}, [ctx.$selectedItemId]);
|
||||
|
||||
useHotkeys('meta+left', ctx.selectFirst, { preventDefault: true });
|
||||
useHotkeys('meta+right', ctx.selectLast, { preventDefault: true });
|
||||
|
||||
|
||||
@@ -482,11 +482,6 @@ export const NewGallery = memo(() => {
|
||||
|
||||
const context = useMemo<GridContext>(() => ({ imageNames, queryArgs }), [imageNames, queryArgs]);
|
||||
|
||||
// Item content function
|
||||
const itemContent: GridItemContent<string, GridContext> = useCallback((index, imageName) => {
|
||||
return <ImageAtPosition index={index} imageName={imageName} />;
|
||||
}, []);
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<Flex w="full" h="full" alignItems="center" justifyContent="center" gap={4}>
|
||||
@@ -511,7 +506,7 @@ export const NewGallery = memo(() => {
|
||||
ref={virtuosoRef}
|
||||
context={context}
|
||||
data={imageNames}
|
||||
increaseViewportBy={2048}
|
||||
increaseViewportBy={4096}
|
||||
itemContent={itemContent}
|
||||
computeItemKey={computeItemKey}
|
||||
components={components}
|
||||
@@ -528,8 +523,12 @@ export const NewGallery = memo(() => {
|
||||
NewGallery.displayName = 'NewGallery';
|
||||
|
||||
const scrollSeekConfiguration: ScrollSeekConfiguration = {
|
||||
enter: (velocity) => velocity > 4096,
|
||||
exit: (velocity) => velocity === 0,
|
||||
enter: (velocity) => {
|
||||
return Math.abs(velocity) > 2048;
|
||||
},
|
||||
exit: (velocity) => {
|
||||
return velocity === 0;
|
||||
},
|
||||
};
|
||||
|
||||
// Styles
|
||||
@@ -549,6 +548,10 @@ const ListComponent: GridComponents<GridContext>['List'] = forwardRef(({ context
|
||||
});
|
||||
ListComponent.displayName = 'ListComponent';
|
||||
|
||||
const itemContent: GridItemContent<string, GridContext> = (index, imageName) => {
|
||||
return <ImageAtPosition index={index} imageName={imageName} />;
|
||||
};
|
||||
|
||||
const ItemComponent: GridComponents<GridContext>['Item'] = forwardRef(({ context: _, ...rest }, ref) => (
|
||||
<GridItem ref={ref} aspectRatio="1/1" {...rest} />
|
||||
));
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useAppStore } from 'app/store/storeHooks';
|
||||
import { useCallback, useEffect, useRef } from 'react';
|
||||
import { useCallback, useEffect, useState } from 'react';
|
||||
import type { ListRange } from 'react-virtuoso';
|
||||
import { imagesApi, useGetImageDTOsByNamesMutation } from 'services/api/endpoints/images';
|
||||
import { useThrottledCallback } from 'use-debounce';
|
||||
@@ -13,33 +13,20 @@ interface UseRangeBasedImageFetchingReturn {
|
||||
onRangeChanged: (range: ListRange) => void;
|
||||
}
|
||||
|
||||
const getUncachedNames = (imageNames: string[], cachedImageNames: string[], range: ListRange): string[] => {
|
||||
if (range.startIndex === range.endIndex) {
|
||||
// If the start and end indices are the same, no range to fetch
|
||||
return [];
|
||||
}
|
||||
const getUncachedNames = (imageNames: string[], cachedImageNames: string[], ranges: ListRange[]): string[] => {
|
||||
const uncachedNamesSet = new Set<string>();
|
||||
const cachedImageNamesSet = new Set(cachedImageNames);
|
||||
|
||||
if (imageNames.length === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const start = Math.max(0, range.startIndex);
|
||||
const end = Math.min(imageNames.length - 1, range.endIndex);
|
||||
|
||||
if (cachedImageNames.length === 0) {
|
||||
return imageNames.slice(start, end + 1);
|
||||
}
|
||||
|
||||
const uncachedNames: string[] = [];
|
||||
|
||||
for (let i = start; i <= end; i++) {
|
||||
const imageName = imageNames[i]!;
|
||||
if (!cachedImageNames.includes(imageName)) {
|
||||
uncachedNames.push(imageName);
|
||||
for (const range of ranges) {
|
||||
for (let i = range.startIndex; i <= range.endIndex; i++) {
|
||||
const n = imageNames[i]!;
|
||||
if (n && !cachedImageNamesSet.has(n)) {
|
||||
uncachedNamesSet.add(n);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return uncachedNames;
|
||||
return Array.from(uncachedNamesSet);
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -53,39 +40,36 @@ export const useRangeBasedImageFetching = ({
|
||||
}: UseRangeBasedImageFetchingArgs): UseRangeBasedImageFetchingReturn => {
|
||||
const store = useAppStore();
|
||||
const [getImageDTOsByNames] = useGetImageDTOsByNamesMutation();
|
||||
const lastRangeRef = useRef<ListRange | null>(null);
|
||||
const [lastRange, setLastRange] = useState<ListRange | null>(null);
|
||||
const [pendingRanges, setPendingRanges] = useState<ListRange[]>([]);
|
||||
|
||||
const fetchImages = useCallback(
|
||||
(visibleRange: ListRange) => {
|
||||
(ranges: ListRange[], imageNames: string[]) => {
|
||||
if (!enabled) {
|
||||
return;
|
||||
}
|
||||
const cachedImageNames = imagesApi.util.selectCachedArgsForQuery(store.getState(), 'getImageDTO');
|
||||
const uncachedNames = getUncachedNames(imageNames, cachedImageNames, visibleRange);
|
||||
const uncachedNames = getUncachedNames(imageNames, cachedImageNames, ranges);
|
||||
if (uncachedNames.length === 0) {
|
||||
return;
|
||||
}
|
||||
getImageDTOsByNames({ image_names: uncachedNames });
|
||||
lastRangeRef.current = visibleRange;
|
||||
setPendingRanges([]);
|
||||
},
|
||||
[enabled, getImageDTOsByNames, imageNames, store]
|
||||
[enabled, getImageDTOsByNames, store]
|
||||
);
|
||||
|
||||
const throttledFetchImages = useThrottledCallback(fetchImages, 100);
|
||||
const throttledFetchImages = useThrottledCallback(fetchImages, 500);
|
||||
|
||||
const onRangeChanged = useCallback(
|
||||
(range: ListRange) => {
|
||||
throttledFetchImages(range);
|
||||
},
|
||||
[throttledFetchImages]
|
||||
);
|
||||
const onRangeChanged = useCallback((range: ListRange) => {
|
||||
setLastRange(range);
|
||||
setPendingRanges((prev) => [...prev, range]);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (!lastRangeRef.current) {
|
||||
return;
|
||||
}
|
||||
throttledFetchImages(lastRangeRef.current);
|
||||
}, [imageNames, throttledFetchImages]);
|
||||
const combinedRanges = lastRange ? [...pendingRanges, lastRange] : pendingRanges;
|
||||
throttledFetchImages(combinedRanges, imageNames);
|
||||
}, [imageNames, lastRange, pendingRanges, throttledFetchImages]);
|
||||
|
||||
return {
|
||||
onRangeChanged,
|
||||
|
||||
@@ -26,14 +26,14 @@ const optionsObject: Record<Language, string> = {
|
||||
nl: 'Nederlands',
|
||||
pl: 'Polski',
|
||||
pt: 'Português',
|
||||
pt_BR: 'Português do Brasil',
|
||||
'pt-BR': 'Português do Brasil',
|
||||
ru: 'Русский',
|
||||
sv: 'Svenska',
|
||||
tr: 'Türkçe',
|
||||
ua: 'Украї́нська',
|
||||
vi: 'Tiếng Việt',
|
||||
zh_CN: '简体中文',
|
||||
zh_Hant: '漢語',
|
||||
'zh-CN': '简体中文',
|
||||
'zh-Hant': '漢語',
|
||||
};
|
||||
|
||||
const options = map(optionsObject, (label, value) => ({ label, value }));
|
||||
|
||||
@@ -9,7 +9,7 @@ import { uniq } from 'es-toolkit/compat';
|
||||
import type { Language, SystemState } from './types';
|
||||
|
||||
const initialSystemState: SystemState = {
|
||||
_version: 1,
|
||||
_version: 2,
|
||||
shouldConfirmOnDelete: true,
|
||||
shouldAntialiasProgressImage: false,
|
||||
shouldConfirmOnNewSession: true,
|
||||
@@ -96,6 +96,10 @@ const migrateSystemState = (state: any): any => {
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
if (state._version === 1) {
|
||||
state.language = (state as SystemState).language.replace('_', '-');
|
||||
state._version = 2;
|
||||
}
|
||||
return state;
|
||||
};
|
||||
|
||||
|
||||
@@ -17,20 +17,20 @@ const zLanguage = z.enum([
|
||||
'nl',
|
||||
'pl',
|
||||
'pt',
|
||||
'pt_BR',
|
||||
'pt-BR',
|
||||
'ru',
|
||||
'sv',
|
||||
'tr',
|
||||
'ua',
|
||||
'vi',
|
||||
'zh_CN',
|
||||
'zh_Hant',
|
||||
'zh-CN',
|
||||
'zh-Hant',
|
||||
]);
|
||||
export type Language = z.infer<typeof zLanguage>;
|
||||
export const isLanguage = (v: unknown): v is Language => zLanguage.safeParse(v).success;
|
||||
|
||||
export interface SystemState {
|
||||
_version: 1;
|
||||
_version: 2;
|
||||
shouldConfirmOnDelete: boolean;
|
||||
shouldAntialiasProgressImage: boolean;
|
||||
shouldConfirmOnNewSession: boolean;
|
||||
|
||||
@@ -13,7 +13,7 @@ export const StagingArea = memo(() => {
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex position="absolute" flexDir="column" bottom={4} gap={2} align="center" justify="center" left={4} right={4}>
|
||||
<Flex position="absolute" flexDir="column" bottom={2} gap={2} align="center" justify="center" left={2} right={2}>
|
||||
<StagingAreaItemsList />
|
||||
<StagingAreaToolbar />
|
||||
</Flex>
|
||||
|
||||
@@ -16,6 +16,8 @@ import {
|
||||
PiTextAaBold,
|
||||
} from 'react-icons/pi';
|
||||
|
||||
import { useHackOutDvTabDraggable } from './use-hack-out-dv-tab-draggable';
|
||||
|
||||
const TAB_ICONS: Record<TabName, IconType> = {
|
||||
generate: PiTextAaBold,
|
||||
canvas: PiBoundingBoxBold,
|
||||
@@ -41,6 +43,8 @@ export const TabWithLaunchpadIcon = memo((props: IDockviewPanelHeaderProps) => {
|
||||
setFocusedRegion(props.params.focusRegion);
|
||||
}, [props.params.focusRegion]);
|
||||
|
||||
useHackOutDvTabDraggable(ref);
|
||||
|
||||
return (
|
||||
<Flex ref={ref} alignItems="center" h="full" px={4} gap={3} onPointerDown={onPointerDown}>
|
||||
<Icon as={TAB_ICONS[activeTab]} color="invokeYellow.300" boxSize={5} />
|
||||
|
||||
@@ -5,6 +5,7 @@ import type { IDockviewPanelHeaderProps } from 'dockview';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
|
||||
import type { PanelParameters } from './auto-layout-context';
|
||||
import { useHackOutDvTabDraggable } from './use-hack-out-dv-tab-draggable';
|
||||
|
||||
export const TabWithoutCloseButton = memo((props: IDockviewPanelHeaderProps<PanelParameters>) => {
|
||||
const ref = useRef<HTMLDivElement>(null);
|
||||
@@ -20,6 +21,8 @@ export const TabWithoutCloseButton = memo((props: IDockviewPanelHeaderProps<Pane
|
||||
setFocusedRegion(props.params.focusRegion);
|
||||
}, [props.params.focusRegion]);
|
||||
|
||||
useHackOutDvTabDraggable(ref);
|
||||
|
||||
return (
|
||||
<Flex ref={ref} alignItems="center" h="full" onPointerDown={onPointerDown}>
|
||||
<Text userSelect="none" px={4}>
|
||||
|
||||
@@ -7,6 +7,7 @@ import { memo, useCallback, useRef } from 'react';
|
||||
import { useIsGenerationInProgress } from 'services/api/endpoints/queue';
|
||||
|
||||
import type { PanelParameters } from './auto-layout-context';
|
||||
import { useHackOutDvTabDraggable } from './use-hack-out-dv-tab-draggable';
|
||||
|
||||
export const TabWithoutCloseButtonAndWithProgressIndicator = memo(
|
||||
(props: IDockviewPanelHeaderProps<PanelParameters>) => {
|
||||
@@ -25,6 +26,8 @@ export const TabWithoutCloseButtonAndWithProgressIndicator = memo(
|
||||
setFocusedRegion(props.params.focusRegion);
|
||||
}, [props.params.focusRegion]);
|
||||
|
||||
useHackOutDvTabDraggable(ref);
|
||||
|
||||
return (
|
||||
<Flex ref={ref} position="relative" alignItems="center" h="full" onPointerDown={onPointerDown}>
|
||||
<Text userSelect="none" px={4}>
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
import type { RefObject } from 'react';
|
||||
import { useEffect } from 'react';
|
||||
|
||||
/**
|
||||
* Prevent undesired dnd behavior in Dockview tabs.
|
||||
*
|
||||
* Dockview always sets the draggable flag on its tab elements, even when dnd is disabled. This hook traverses
|
||||
* up from the provided ref to find the closest tab element and sets its `draggable` attribute to `false`.
|
||||
*/
|
||||
export const useHackOutDvTabDraggable = (ref: RefObject<HTMLElement>) => {
|
||||
useEffect(() => {
|
||||
const el = ref.current;
|
||||
if (!el) {
|
||||
return;
|
||||
}
|
||||
const parentTab = el.closest('.dv-tab');
|
||||
if (!parentTab) {
|
||||
return;
|
||||
}
|
||||
parentTab.setAttribute('draggable', 'false');
|
||||
}, [ref]);
|
||||
};
|
||||
@@ -1 +1 @@
|
||||
__version__ = "6.0.0"
|
||||
__version__ = "6.1.0rc1"
|
||||
|
||||
@@ -1,263 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Export LoRA metadata from InvokeAI database to JSON files.
|
||||
|
||||
This script exports LoRA metadata to JSON files with the following format:
|
||||
{
|
||||
"description": "",
|
||||
"sd version": "Unknown",
|
||||
"activation text": "",
|
||||
"preferred weight": 0,
|
||||
"negative text": "",
|
||||
"notes": ""
|
||||
}
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_records.model_records_sql import ModelRecordServiceSQL
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
|
||||
def map_base_model_to_sd_version(base_model: BaseModelType) -> str:
|
||||
"""Map BaseModelType to SD version string."""
|
||||
mapping = {
|
||||
BaseModelType.StableDiffusion1: "SD 1.5",
|
||||
BaseModelType.StableDiffusion2: "SD 2.x",
|
||||
BaseModelType.StableDiffusionXL: "SDXL",
|
||||
BaseModelType.Flux: "FLUX",
|
||||
}
|
||||
return mapping.get(base_model, "Unknown")
|
||||
|
||||
|
||||
def parse_description(description: Optional[str]) -> Dict[str, Any]:
|
||||
"""Parse description field to extract structured data."""
|
||||
result = {
|
||||
"description": "",
|
||||
"preferred_weight": 0,
|
||||
"negative_text": "",
|
||||
"notes": ""
|
||||
}
|
||||
|
||||
if not description:
|
||||
return result
|
||||
|
||||
# Try to extract structured parts from description
|
||||
lines = description.split("\n")
|
||||
current_section = "description"
|
||||
section_content = []
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
|
||||
# Check for section markers
|
||||
if line.startswith("Preferred weight:"):
|
||||
# Save previous section
|
||||
if current_section == "description" and section_content:
|
||||
result["description"] = "\n".join(section_content).strip()
|
||||
|
||||
# Extract weight
|
||||
weight_match = re.search(r"Preferred weight:\s*([\d.]+)", line)
|
||||
if weight_match:
|
||||
try:
|
||||
result["preferred_weight"] = float(weight_match.group(1)) # type: ignore
|
||||
except ValueError:
|
||||
pass
|
||||
current_section = "after_weight"
|
||||
section_content = []
|
||||
|
||||
elif line.startswith("Negative prompt:"):
|
||||
# Extract negative text
|
||||
negative_text = line[len("Negative prompt:"):].strip()
|
||||
result["negative_text"] = negative_text
|
||||
current_section = "after_negative"
|
||||
section_content = []
|
||||
|
||||
elif line.startswith("Notes:"):
|
||||
# Extract notes
|
||||
notes = line[len("Notes:"):].strip()
|
||||
result["notes"] = notes
|
||||
current_section = "notes"
|
||||
section_content = [notes] if notes else []
|
||||
|
||||
elif line and current_section == "notes":
|
||||
# Continue adding to notes
|
||||
section_content.append(line)
|
||||
|
||||
elif line and current_section == "description":
|
||||
# Add to description
|
||||
section_content.append(line)
|
||||
|
||||
# Save final section
|
||||
if current_section == "description" and section_content:
|
||||
result["description"] = "\n".join(section_content).strip()
|
||||
elif current_section == "notes" and section_content:
|
||||
result["notes"] = "\n".join(section_content).strip()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def export_lora_metadata(lora_model: AnyModelConfig) -> Dict[str, Any]:
|
||||
"""Export LoRA model metadata to JSON format."""
|
||||
# Parse description to extract structured data
|
||||
parsed = parse_description(lora_model.description)
|
||||
|
||||
# Build activation text from trigger phrases
|
||||
activation_text = ""
|
||||
if hasattr(lora_model, 'trigger_phrases') and lora_model.trigger_phrases:
|
||||
activation_text = ", ".join(sorted(lora_model.trigger_phrases))
|
||||
|
||||
# Build final JSON structure
|
||||
return {
|
||||
"description": parsed["description"],
|
||||
"sd version": map_base_model_to_sd_version(lora_model.base),
|
||||
"activation text": activation_text,
|
||||
"preferred weight": parsed["preferred_weight"],
|
||||
"negative text": parsed["negative_text"],
|
||||
"notes": parsed["notes"]
|
||||
}
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: export LoRA metadata from the InvokeAI
    database to one JSON file per model.

    Selection: --model-key (exact), else --model-name (may match several),
    else every LoRA in the database. Exits with status 1 on lookup failure
    or when nothing is found.
    """
    parser = argparse.ArgumentParser(
        description="Export LoRA metadata from InvokeAI database to JSON files"
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("."),
        help="Directory to save JSON files (default: current directory)",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        help="Export only the specified LoRA model by name",
    )
    parser.add_argument(
        "--model-key",
        type=str,
        help="Export only the specified LoRA model by key",
    )
    parser.add_argument(
        "--filename-pattern",
        type=str,
        default="{name}.json",
        help="Filename pattern for JSON files (default: {name}.json). "
        "Available placeholders: {name}, {key}, {base}",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite existing JSON files",
    )
    parser.add_argument(
        "--pretty",
        action="store_true",
        help="Pretty-print JSON output",
    )

    args = parser.parse_args()

    # Initialize configuration and services
    config = InvokeAIAppConfig.get_config()
    logger = InvokeAILogger.get_logger("export_lora_metadata")

    # Initialize database (opens the app's SQLite model-record store directly)
    db = SqliteDatabase(db_path=config.db_path, logger=logger)
    model_record_service = ModelRecordServiceSQL(db, logger)

    # Create output directory if needed
    args.output_dir.mkdir(parents=True, exist_ok=True)

    # Get LoRA models to export
    lora_models = []

    if args.model_key:
        try:
            model = model_record_service.get_model(args.model_key)
            if model.type != ModelType.LoRA:
                print(f"Error: Model {args.model_key} is not a LoRA model", file=sys.stderr)
                sys.exit(1)
            lora_models = [model]
        # NOTE(review): any lookup failure is reported as "not found";
        # sys.exit above is not swallowed (SystemExit is not an Exception).
        except Exception:
            print(f"Error: Model with key {args.model_key} not found", file=sys.stderr)
            sys.exit(1)

    elif args.model_name:
        models = model_record_service.search_by_attr(
            model_name=args.model_name,
            model_type=ModelType.LoRA
        )
        if not models:
            print(f"Error: No LoRA model found with name '{args.model_name}'", file=sys.stderr)
            sys.exit(1)
        # Unlike the import script, multiple name matches are all exported.
        lora_models = models

    else:
        # Export all LoRA models
        lora_models = model_record_service.search_by_attr(model_type=ModelType.LoRA)

    if not lora_models:
        print("No LoRA models found in database", file=sys.stderr)
        sys.exit(1)

    print(f"Exporting {len(lora_models)} LoRA model(s)...")

    # Export each model
    exported_count = 0
    skipped_count = 0

    for lora_model in lora_models:
        # Generate filename from the user-supplied pattern
        filename = args.filename_pattern.format(
            name=lora_model.name,
            key=lora_model.key,
            base=lora_model.base.value
        )

        # Sanitize filename (strip characters invalid on Windows/POSIX)
        filename = re.sub(r'[<>:"/\\|?*]', '_', filename)

        output_path = args.output_dir / filename

        # Check if file exists; skip instead of clobbering unless --overwrite
        if output_path.exists() and not args.overwrite:
            print(f"Skipping {lora_model.name}: {output_path} already exists")
            skipped_count += 1
            continue

        # Export metadata
        metadata = export_lora_metadata(lora_model)

        # Write JSON file; a failure on one model does not abort the rest
        try:
            with open(output_path, "w", encoding="utf-8") as f:
                if args.pretty:
                    json.dump(metadata, f, indent=2, ensure_ascii=False)
                else:
                    json.dump(metadata, f, ensure_ascii=False)

            print(f"Exported {lora_model.name} → {output_path}")
            exported_count += 1

        except Exception as e:
            print(f"Error exporting {lora_model.name}: {e}", file=sys.stderr)

    # Summary
    print(f"\nExport complete:")
    print(f"  Exported: {exported_count}")
    if skipped_count > 0:
        print(f"  Skipped: {skipped_count} (use --overwrite to replace)")


if __name__ == "__main__":
    main()
|
||||
@@ -1,288 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Import LoRA metadata from JSON files into InvokeAI database.
|
||||
|
||||
This script reads JSON files with the following format:
|
||||
{
|
||||
"description": "",
|
||||
"sd version": "Unknown",
|
||||
"activation text": "",
|
||||
"preferred weight": 0,
|
||||
"negative text": "",
|
||||
"notes": ""
|
||||
}
|
||||
|
||||
And imports the metadata into existing LoRA models in the InvokeAI database.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_records import ModelRecordChanges, ModelRecordServiceBase
|
||||
from invokeai.app.services.model_records.model_records_sql import ModelRecordServiceSQL
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
|
||||
def map_sd_version_to_base_model(sd_version: str) -> Optional[BaseModelType]:
    """Map a free-form SD version string to a BaseModelType.

    Checks are applied in order, most specific first. Returns None when
    the version cannot be identified, in which case the caller leaves the
    model's base type unchanged.

    :param sd_version: value of the JSON "sd version" field, e.g. "SDXL",
        "SD 1.5", "FLUX.1". Matching is case-insensitive substring search.
    :return: the matching BaseModelType, or None for unknown versions.
    """
    sd_version_lower = sd_version.lower()

    # "flux" must be tested before the digit checks: names like "FLUX.1"
    # contain "1" and would otherwise be misclassified as SD 1.x.
    if "flux" in sd_version_lower:
        return BaseModelType.Flux
    elif "xl" in sd_version_lower:
        # "xl" also covers "sdxl" — no separate check needed.
        return BaseModelType.StableDiffusionXL
    elif "2" in sd_version_lower:
        return BaseModelType.StableDiffusion2
    elif "1" in sd_version_lower:
        # "1" also covers "1.5".
        return BaseModelType.StableDiffusion1
    else:
        return None  # Will not update base model if unknown
|
||||
|
||||
|
||||
def build_description(json_data: Dict[str, Any]) -> str:
    """Assemble a single description string from the JSON metadata fields.

    Non-empty fields are concatenated in a fixed order — description,
    preferred weight, negative prompt, notes — separated by blank lines.
    A preferred weight of 0 is treated as "unset" and omitted. Returns an
    empty string when no field contributes anything.
    """
    sections = []

    description = json_data.get("description")
    if description:
        sections.append(description)

    weight = json_data.get("preferred weight")
    if weight and weight != 0:
        sections.append(f"Preferred weight: {weight}")

    negative = json_data.get("negative text")
    if negative:
        sections.append(f"Negative prompt: {negative}")

    notes = json_data.get("notes")
    if notes:
        sections.append(f"Notes: {notes}")

    return "\n\n".join(sections) if sections else ""
|
||||
|
||||
|
||||
def process_lora_metadata(
    model_record_service: ModelRecordServiceBase,
    lora_model: AnyModelConfig,
    json_data: Dict[str, Any],
    update_base_model: bool = False,
) -> bool:
    """Process and update a single LoRA model with metadata from JSON.

    :param model_record_service: record service used to persist changes.
    :param lora_model: the existing LoRA model config to update.
    :param json_data: parsed JSON metadata (keys like "activation text").
    :param update_base_model: when True, also map "sd version" onto the
        model's base type.
    :return: True if the record was updated, False when there was nothing
        to change or the update failed.
    """
    changes = ModelRecordChanges()

    # Map activation text to trigger phrases.
    # "activation text" is a comma-separated string; converting to a set
    # REPLACES any existing trigger phrases and collapses duplicates.
    if json_data.get("activation text"):
        activation_texts = [text.strip() for text in json_data["activation text"].split(",")]
        changes.trigger_phrases = set(activation_texts)

    # Build description from multiple fields (replaces the old description)
    description = build_description(json_data)
    if description:
        changes.description = description

    # Optionally update base model type
    if update_base_model and json_data.get("sd version"):
        base_model = map_sd_version_to_base_model(json_data["sd version"])
        if base_model:
            changes.base = base_model

    # Only update if we have changes — model_dump(exclude_none=True) is an
    # empty dict (falsy) when none of the fields above were set.
    if changes.model_dump(exclude_none=True):
        try:
            model_record_service.update_model(lora_model.key, changes)
            return True
        except Exception as e:
            print(f"Error updating model {lora_model.name}: {e}")
            return False

    return False
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: import LoRA metadata from JSON file(s)
    into existing model records in the InvokeAI database.

    With --batch, json_file is treated as a directory and every *.json
    file inside it is processed; otherwise a single file is processed.
    """
    parser = argparse.ArgumentParser(
        description="Import LoRA metadata from JSON files into InvokeAI database"
    )
    parser.add_argument(
        "json_file",
        type=Path,
        help="Path to JSON file containing LoRA metadata",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        help="Name of the LoRA model to update (if not specified, will try to match based on filename)",
    )
    parser.add_argument(
        "--model-key",
        type=str,
        help="Key of the LoRA model to update (takes precedence over --model-name)",
    )
    parser.add_argument(
        "--update-base-model",
        action="store_true",
        help="Update the base model type based on 'sd version' field",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be updated without making changes",
    )
    parser.add_argument(
        "--batch",
        action="store_true",
        help="Process multiple JSON files in batch mode (json_file should be a directory)",
    )

    args = parser.parse_args()

    # Initialize configuration and services
    config = InvokeAIAppConfig.get_config()
    logger = InvokeAILogger.get_logger("import_lora_metadata")

    # Initialize database (opens the app's SQLite model-record store directly)
    db = SqliteDatabase(db_path=config.db_path, logger=logger)
    model_record_service = ModelRecordServiceSQL(db, logger)

    # Process single file or batch
    if args.batch:
        if not args.json_file.is_dir():
            print(f"Error: {args.json_file} is not a directory", file=sys.stderr)
            sys.exit(1)

        json_files = list(args.json_file.glob("*.json"))
        if not json_files:
            print(f"No JSON files found in {args.json_file}", file=sys.stderr)
            sys.exit(1)

        print(f"Found {len(json_files)} JSON files to process")

        # Each file is independent; a failure in one does not stop the batch.
        for json_file in json_files:
            process_single_file(
                json_file, model_record_service, args, logger
            )
    else:
        process_single_file(
            args.json_file, model_record_service, args, logger
        )
|
||||
|
||||
|
||||
def process_single_file(
    json_file: Path,
    model_record_service: ModelRecordServiceBase,
    args: argparse.Namespace,
    logger: Any,
) -> None:
    """Process a single JSON file.

    Target-model resolution order: args.model_key (exact), then
    args.model_name (must be unique), then the JSON filename stem — first
    as an exact name match, falling back to a case-insensitive substring
    match across all LoRAs. Ambiguous or failed lookups are reported to
    stderr and the file is skipped rather than aborting the run. In
    --dry-run mode the proposed changes are printed but not applied.
    """
    if not json_file.exists():
        print(f"Error: {json_file} does not exist", file=sys.stderr)
        return

    try:
        with open(json_file, "r") as f:
            json_data = json.load(f)
    except json.JSONDecodeError as e:
        print(f"Error reading {json_file}: {e}", file=sys.stderr)
        return

    # Find the LoRA model to update
    lora_model = None

    if args.model_key:
        try:
            lora_model = model_record_service.get_model(args.model_key)
            if lora_model.type != ModelType.LoRA:
                print(f"Error: Model {args.model_key} is not a LoRA model", file=sys.stderr)
                return
        # NOTE(review): any lookup failure is reported as "not found"
        except Exception:
            print(f"Error: Model with key {args.model_key} not found", file=sys.stderr)
            return
    elif args.model_name:
        # Search for LoRA by name
        models = model_record_service.search_by_attr(
            model_name=args.model_name,
            model_type=ModelType.LoRA
        )
        if not models:
            print(f"Error: No LoRA model found with name '{args.model_name}'", file=sys.stderr)
            return
        elif len(models) > 1:
            # Ambiguous: same name under different base models, for example
            print(f"Error: Multiple LoRA models found with name '{args.model_name}':", file=sys.stderr)
            for model in models:
                print(f"  - {model.key}: {model.name} ({model.base})")
            print("Please specify --model-key to select one", file=sys.stderr)
            return
        lora_model = models[0]
    else:
        # Try to match based on filename (stem = filename without extension)
        base_name = json_file.stem
        models = model_record_service.search_by_attr(
            model_name=base_name,
            model_type=ModelType.LoRA
        )
        if not models:
            # Try partial match: case-insensitive substring over all LoRAs
            all_loras = model_record_service.search_by_attr(model_type=ModelType.LoRA)
            matches = [m for m in all_loras if base_name.lower() in m.name.lower()]

            if not matches:
                print(f"Error: No LoRA model found matching filename '{base_name}'", file=sys.stderr)
                return
            elif len(matches) > 1:
                print(f"Error: Multiple LoRA models found matching '{base_name}':", file=sys.stderr)
                for model in matches:
                    print(f"  - {model.key}: {model.name} ({model.base})")
                print("Please specify --model-name or --model-key", file=sys.stderr)
                return
            lora_model = matches[0]
        elif len(models) > 1:
            print(f"Error: Multiple LoRA models found with name '{base_name}':", file=sys.stderr)
            for model in models:
                print(f"  - {model.key}: {model.name} ({model.base})")
            print("Please specify --model-key to select one", file=sys.stderr)
            return
        else:
            lora_model = models[0]

    # Display current and proposed changes
    print(f"\nProcessing: {json_file.name}")
    print(f"Target LoRA: {lora_model.name} (key: {lora_model.key})")
    print(f"Current base model: {lora_model.base}")

    if args.dry_run:
        print("\n--- DRY RUN MODE ---")
        print("Proposed changes:")

        if json_data.get("activation text"):
            print(f"  Trigger phrases: {json_data['activation text']}")

        description = build_description(json_data)
        if description:
            # Truncate long descriptions for readability in the preview
            print(f"  Description: {description[:100]}..." if len(description) > 100 else f"  Description: {description}")

        if args.update_base_model and json_data.get("sd version"):
            base_model = map_sd_version_to_base_model(json_data["sd version"])
            if base_model:
                print(f"  Base model: {lora_model.base} → {base_model}")

        print("--- END DRY RUN ---\n")
    else:
        # Apply the updates
        success = process_lora_metadata(
            model_record_service,
            lora_model,
            json_data,
            args.update_base_model
        )

        if success:
            print("✓ Successfully updated metadata")
        else:
            print("✗ No changes made")
|
||||
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    main()
|
||||
@@ -1,136 +0,0 @@
|
||||
# LoRA Metadata Import/Export Tools
|
||||
|
||||
These scripts allow you to import and export LoRA metadata between JSON files and the InvokeAI database.
|
||||
|
||||
## JSON Format
|
||||
|
||||
The JSON format used by these tools is:
|
||||
|
||||
```json
|
||||
{
|
||||
"description": "Description of the LoRA",
|
||||
"sd version": "SDXL",
|
||||
"activation text": "trigger1, trigger2",
|
||||
"preferred weight": 0.8,
|
||||
"negative text": "negative prompts to avoid",
|
||||
"notes": "Additional notes about the LoRA"
|
||||
}
|
||||
```
|
||||
|
||||
## Import Script: `import_lora_metadata.py`
|
||||
|
||||
Imports metadata from JSON files into existing LoRA models in the InvokeAI database.
|
||||
|
||||
### Usage
|
||||
|
||||
```bash
|
||||
# Import metadata for a single LoRA (matches by filename)
|
||||
python scripts/import_lora_metadata.py my_lora.json
|
||||
|
||||
# Import metadata by specifying the model name
|
||||
python scripts/import_lora_metadata.py metadata.json --model-name "My LoRA Model"
|
||||
|
||||
# Import metadata by specifying the model key
|
||||
python scripts/import_lora_metadata.py metadata.json --model-key "abc123def456"
|
||||
|
||||
# Dry run to see what would be changed
|
||||
python scripts/import_lora_metadata.py my_lora.json --dry-run
|
||||
|
||||
# Update the base model type based on "sd version" field
|
||||
python scripts/import_lora_metadata.py my_lora.json --update-base-model
|
||||
|
||||
# Batch import multiple JSON files from a directory
|
||||
python scripts/import_lora_metadata.py /path/to/json/directory --batch
|
||||
```
|
||||
|
||||
### Field Mappings
|
||||
|
||||
- `activation text` → `trigger_phrases` (comma-separated list)
|
||||
- `description` → `description`
|
||||
- `preferred weight`, `negative text`, `notes` → Combined into `description` field
|
||||
- `sd version` → `base` (when `--update-base-model` is used)
|
||||
|
||||
### SD Version Mapping
|
||||
|
||||
When using `--update-base-model`, the script maps SD versions as follows (checks are applied in the order listed, and the first match wins):
- Contains "xl" or "sdxl" → StableDiffusionXL
- Contains "2" → StableDiffusion2
- Contains "1" (this also covers "1.5") → StableDiffusion1
- Contains "flux" → Flux
- Anything else → No update
|
||||
|
||||
## Export Script: `export_lora_metadata.py`
|
||||
|
||||
Exports LoRA metadata from the InvokeAI database to JSON files.
|
||||
|
||||
### Usage
|
||||
|
||||
```bash
|
||||
# Export all LoRA models to current directory
|
||||
python scripts/export_lora_metadata.py
|
||||
|
||||
# Export to specific directory
|
||||
python scripts/export_lora_metadata.py --output-dir /path/to/output
|
||||
|
||||
# Export specific LoRA by name
|
||||
python scripts/export_lora_metadata.py --model-name "My LoRA Model"
|
||||
|
||||
# Export specific LoRA by key
|
||||
python scripts/export_lora_metadata.py --model-key "abc123def456"
|
||||
|
||||
# Custom filename pattern
|
||||
python scripts/export_lora_metadata.py --filename-pattern "{base}_{name}.json"
|
||||
|
||||
# Pretty-print JSON output
|
||||
python scripts/export_lora_metadata.py --pretty
|
||||
|
||||
# Overwrite existing files
|
||||
python scripts/export_lora_metadata.py --overwrite
|
||||
```
|
||||
|
||||
### Filename Patterns
|
||||
|
||||
Available placeholders for `--filename-pattern`:
|
||||
- `{name}` - LoRA model name
|
||||
- `{key}` - LoRA model key/ID
|
||||
- `{base}` - Base model type (e.g., "sdxl", "sd-1", etc.)
|
||||
|
||||
## Examples
|
||||
|
||||
### Example 1: Import metadata for a newly added LoRA
|
||||
|
||||
1. Add a LoRA model to InvokeAI (e.g., `anime_style_v2.safetensors`)
|
||||
2. Create a JSON file with metadata (`anime_style_v2.json`):
|
||||
```json
|
||||
{
|
||||
"description": "Anime style LoRA trained on modern anime artwork",
|
||||
"sd version": "SDXL",
|
||||
"activation text": "anime style, modern anime",
|
||||
"preferred weight": 0.7,
|
||||
"negative text": "realistic, photorealistic",
|
||||
"notes": "Works best with anime-focused base models"
|
||||
}
|
||||
```
|
||||
3. Import the metadata:
|
||||
```bash
|
||||
python scripts/import_lora_metadata.py anime_style_v2.json
|
||||
```
|
||||
|
||||
### Example 2: Batch export and import
|
||||
|
||||
1. Export all LoRA metadata:
|
||||
```bash
|
||||
python scripts/export_lora_metadata.py --output-dir ./lora_metadata --pretty
|
||||
```
|
||||
2. Edit the JSON files as needed
|
||||
3. Import all metadata back:
|
||||
```bash
|
||||
python scripts/import_lora_metadata.py ./lora_metadata --batch
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- The import script requires that LoRA models already exist in the InvokeAI database
|
||||
- When importing, the script will try to match JSON filenames to LoRA model names
|
||||
- Use `--dry-run` to preview changes before applying them
|
||||
- When importing, the description is rebuilt from the JSON fields and replaces the model's existing description, and the "activation text" values replace any existing trigger phrases
|
||||
Reference in New Issue
Block a user