refactor: gallery scroll

This commit is contained in:
psychedelicious
2025-06-24 15:51:28 +10:00
parent 049a8d8144
commit bee4cf41b4
13 changed files with 928 additions and 17 deletions

View File

@@ -1,7 +1,7 @@
import io
import json
import traceback
from typing import ClassVar, Optional
from typing import ClassVar, Literal, Optional
from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.responses import FileResponse
@@ -562,3 +562,63 @@ async def get_bulk_download_item(
return response
except Exception:
raise HTTPException(status_code=404)
@images_router.get("/collections/counts", operation_id="get_image_collection_counts")
async def get_image_collection_counts(
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to count."),
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
is_intermediate: Optional[bool] = Query(default=None, description="Whether to include intermediate images."),
board_id: Optional[str] = Query(
default=None,
description="The board id to filter by. Use 'none' to find images without a board.",
),
search_term: Optional[str] = Query(default=None, description="The term to search for"),
) -> dict[str, int]:
"""Gets counts for starred and unstarred image collections"""
try:
counts = ApiDependencies.invoker.services.images.get_collection_counts(
image_origin=image_origin,
categories=categories,
is_intermediate=is_intermediate,
board_id=board_id,
search_term=search_term,
)
return counts
except Exception:
raise HTTPException(status_code=500, detail="Failed to get collection counts")
@images_router.get("/collections/{collection}", operation_id="get_image_collection")
async def get_image_collection(
collection: Literal["starred", "unstarred"] = Path(..., description="The collection to retrieve from"),
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."),
board_id: Optional[str] = Query(
default=None,
description="The board id to filter by. Use 'none' to find images without a board.",
),
offset: int = Query(default=0, description="The offset within the collection"),
limit: int = Query(default=50, description="The number of images to return"),
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
search_term: Optional[str] = Query(default=None, description="The term to search for"),
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets images from a specific collection (starred or unstarred)"""
try:
image_dtos = ApiDependencies.invoker.services.images.get_collection_images(
collection=collection,
offset=offset,
limit=limit,
order_dir=order_dir,
image_origin=image_origin,
categories=categories,
is_intermediate=is_intermediate,
board_id=board_id,
search_term=search_term,
)
return image_dtos
except Exception:
raise HTTPException(status_code=500, detail="Failed to get collection images")

View File

@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Optional
from typing import Literal, Optional
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import (
@@ -97,3 +97,31 @@ class ImageRecordStorageBase(ABC):
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
"""Gets the most recent image for a board."""
pass
@abstractmethod
def get_collection_counts(
self,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> dict[str, int]:
"""Gets counts for starred and unstarred image collections."""
pass
@abstractmethod
def get_collection_images(
self,
collection: Literal["starred", "unstarred"],
offset: int = 0,
limit: int = 10,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
"""Gets images from a specific collection (starred or unstarred)."""
pass

View File

@@ -1,6 +1,6 @@
import sqlite3
from datetime import datetime
from typing import Optional, Union, cast
from typing import Literal, Optional, Union, cast
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
from invokeai.app.services.image_records.image_records_base import ImageRecordStorageBase
@@ -386,3 +386,181 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
return None
return deserialize_image_record(dict(result))
def get_collection_counts(
self,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> dict[str, int]:
cursor = self._conn.cursor()
# Build the base query conditions (same as get_many)
base_query = """--sql
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
if categories is not None:
category_strings = [c.value for c in set(categories)]
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
query_params.append(is_intermediate)
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
if search_term:
query_conditions += """--sql
AND (
images.metadata LIKE ?
OR images.created_at LIKE ?
)
"""
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
# Get starred count
starred_query = f"SELECT COUNT(*) {base_query} {query_conditions} AND images.starred = TRUE;"
cursor.execute(starred_query, query_params)
starred_count = cast(int, cursor.fetchone()[0])
# Get unstarred count
unstarred_query = f"SELECT COUNT(*) {base_query} {query_conditions} AND images.starred = FALSE;"
cursor.execute(unstarred_query, query_params)
unstarred_count = cast(int, cursor.fetchone()[0])
return {
"starred_count": starred_count,
"unstarred_count": unstarred_count,
"total_count": starred_count + unstarred_count,
}
def get_collection_images(
self,
collection: Literal["starred", "unstarred"],
offset: int = 0,
limit: int = 10,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
cursor = self._conn.cursor()
# Base queries
count_query = """--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
images_query = f"""--sql
SELECT {IMAGE_DTO_COLS}
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
# Add starred/unstarred filter
is_starred = collection == "starred"
query_conditions += """--sql
AND images.starred = ?
"""
query_params.append(is_starred)
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
if categories is not None:
category_strings = [c.value for c in set(categories)]
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
query_params.append(is_intermediate)
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
if search_term:
query_conditions += """--sql
AND (
images.metadata LIKE ?
OR images.created_at LIKE ?
)
"""
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
# Add ordering and pagination
query_pagination = f"""--sql
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
# Execute images query
images_query += query_conditions + query_pagination + ";"
images_params = query_params.copy()
images_params.extend([limit, offset])
cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
# Execute count query
count_query += query_conditions + ";"
cursor.execute(count_query, query_params)
count = cast(int, cursor.fetchone()[0])
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Callable, Optional
from typing import Callable, Literal, Optional
from PIL.Image import Image as PILImageType
@@ -147,3 +147,31 @@ class ImageServiceABC(ABC):
def delete_images_on_board(self, board_id: str):
"""Deletes all images on a board."""
pass
@abstractmethod
def get_collection_counts(
self,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> dict[str, int]:
"""Gets counts for starred and unstarred image collections."""
pass
@abstractmethod
def get_collection_images(
self,
collection: Literal["starred", "unstarred"],
offset: int = 0,
limit: int = 10,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets images from a specific collection (starred or unstarred)."""
pass

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Literal, Optional
from PIL.Image import Image as PILImageType
@@ -309,3 +309,68 @@ class ImageService(ImageServiceABC):
except Exception as e:
self.__invoker.services.logger.error("Problem getting intermediates count")
raise e
def get_collection_counts(
self,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> dict[str, int]:
try:
return self.__invoker.services.image_records.get_collection_counts(
image_origin=image_origin,
categories=categories,
is_intermediate=is_intermediate,
board_id=board_id,
search_term=search_term,
)
except Exception as e:
self.__invoker.services.logger.error("Problem getting collection counts")
raise e
def get_collection_images(
self,
collection: Literal["starred", "unstarred"],
offset: int = 0,
limit: int = 10,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
try:
results = self.__invoker.services.image_records.get_collection_images(
collection=collection,
offset=offset,
limit=limit,
order_dir=order_dir,
image_origin=image_origin,
categories=categories,
is_intermediate=is_intermediate,
board_id=board_id,
search_term=search_term,
)
image_dtos = [
image_record_to_dto(
image_record=r,
image_url=self.__invoker.services.urls.get_image_url(r.image_name),
thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
)
for r in results.items
]
return OffsetPaginatedResults[ImageDTO](
items=image_dtos,
offset=results.offset,
limit=results.limit,
total=results.total,
)
except Exception as e:
self.__invoker.services.logger.error("Problem getting collection images")
raise e

View File

@@ -12,11 +12,13 @@ module.exports = {
// TODO: ENABLE THIS RULE BEFORE v6.0.0
// 'i18next/no-literal-string': 'error',
// https://eslint.org/docs/latest/rules/no-console
'no-console': 'error',
'no-console': 'warn',
// https://eslint.org/docs/latest/rules/no-promise-executor-return
'no-promise-executor-return': 'error',
// https://eslint.org/docs/latest/rules/require-await
'require-await': 'error',
// TODO: ENABLE THIS RULE BEFORE v6.0.0
'react/display-name': 'off',
'no-restricted-properties': [
'error',
{

View File

@@ -39,7 +39,6 @@ import { authToastMiddleware } from 'services/api/authToastMiddleware';
import type { JsonObject } from 'type-fest';
import { STORAGE_PREFIX } from './constants';
import { getDebugLoggerMiddleware } from './middleware/debugLoggerMiddleware';
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
@@ -177,7 +176,7 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
.concat(api.middleware)
.concat(dynamicMiddlewares)
.concat(authToastMiddleware)
.concat(getDebugLoggerMiddleware())
// .concat(getDebugLoggerMiddleware())
.prepend(listenerMiddleware.middleware),
enhancers: (getDefaultEnhancers) => {
const _enhancers = getDefaultEnhancers().concat(autoBatchEnhancer());

View File

@@ -25,9 +25,8 @@ import { useBoardName } from 'services/api/hooks/useBoardName';
import { GallerySettingsPopover } from './GallerySettingsPopover/GallerySettingsPopover';
import { GalleryUploadButton } from './GalleryUploadButton';
import GalleryImageGrid from './ImageGrid/GalleryImageGrid';
import { GalleryPagination } from './ImageGrid/GalleryPagination';
import { GallerySearch } from './ImageGrid/GallerySearch';
import { NewGallery } from './NewGallery';
const BASE_STYLES: ChakraProps['sx'] = {
fontWeight: 'semibold',
@@ -112,8 +111,9 @@ export const GalleryPanel = memo(() => {
/>
</Box>
</Collapse>
<GalleryImageGrid />
<GalleryPagination />
{/* <GalleryImageGrid />
<GalleryPagination /> */}
<NewGallery />
</Flex>
);
});

View File

@@ -0,0 +1,340 @@
import { Box, Flex, forwardRef, Grid, GridItem, Image, Skeleton, Spinner, Text } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { selectImageCollectionQueryArgs } from 'features/gallery/store/gallerySelectors';
import React, { memo, useCallback, useMemo, useState } from 'react';
import { VirtuosoGrid } from 'react-virtuoso';
import { useGetImageCollectionCountsQuery, useGetImageCollectionQuery } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
// Placeholder image component for now
const ImagePlaceholder = memo(({ image }: { image: ImageDTO }) => (
<Image src={image.thumbnail_url} w="full" h="full" objectFit="contain" />
));
ImagePlaceholder.displayName = 'ImagePlaceholder';
// Loading skeleton component
const ImageSkeleton = memo(() => <Skeleton w="full" h="full" />);
ImageSkeleton.displayName = 'ImageSkeleton';
// Hook to manage image data for virtual scrolling
const useVirtualImageData = () => {
const queryArgs = useAppSelector(selectImageCollectionQueryArgs);
// Get total counts for position mapping
const { data: counts, isLoading: countsLoading } = useGetImageCollectionCountsQuery(queryArgs);
// Cache for loaded image ranges
const [loadedRanges, setLoadedRanges] = useState<Map<string, ImageDTO[]>>(new Map());
// Calculate position mappings
const positionInfo = useMemo(() => {
if (!counts) {
return null;
}
const result = {
totalCount: counts.total_count,
starredCount: counts.starred_count ?? 0,
unstarredCount: counts.unstarred_count ?? 0,
starredEnd: (counts.starred_count ?? 0) - 1,
};
return result;
}, [counts]);
// Clear cache when search parameters change
React.useEffect(() => {
setLoadedRanges(new Map());
}, [queryArgs.board_id, queryArgs.search_term, queryArgs.categories]);
// Return flag to indicate when search parameters have changed
const searchParamsChanged = useMemo(() => queryArgs, [queryArgs]);
// Function to generate cache key for a range
const getRangeKey = useCallback((collection: 'starred' | 'unstarred', offset: number, limit: number) => {
return `${collection}-${offset}-${limit}`;
}, []);
// Function to get images for a specific position range
const getImagesForRange = useCallback(
(startIndex: number, endIndex: number) => {
if (!positionInfo) {
return [];
}
const requestedImages: (ImageDTO | null)[] = new Array(endIndex - startIndex + 1).fill(null);
const rangesToLoad: Array<{
collection: 'starred' | 'unstarred';
offset: number;
limit: number;
targetStartIndex: number;
}> = [];
for (let i = startIndex; i <= endIndex; i++) {
const relativeIndex = i - startIndex;
// Handle case where there are no starred images
if (positionInfo.starredCount === 0 || i >= positionInfo.starredCount) {
// This position is in the unstarred collection
const unstarredOffset = i - positionInfo.starredCount;
const rangeKey = getRangeKey('unstarred', Math.floor(unstarredOffset / 50) * 50, 50);
const cachedRange = loadedRanges.get(rangeKey);
if (cachedRange) {
const imageIndex = unstarredOffset % 50;
if (imageIndex < cachedRange.length) {
requestedImages[relativeIndex] = cachedRange[imageIndex] ?? null;
}
} else {
// Need to load this range
const rangeOffset = Math.floor(unstarredOffset / 50) * 50;
rangesToLoad.push({
collection: 'unstarred',
offset: rangeOffset,
limit: 50,
targetStartIndex: i,
});
}
} else {
// This position is in the starred collection
const starredOffset = i;
const rangeKey = getRangeKey('starred', Math.floor(starredOffset / 50) * 50, 50);
const cachedRange = loadedRanges.get(rangeKey);
if (cachedRange) {
const imageIndex = starredOffset % 50;
if (imageIndex < cachedRange.length) {
requestedImages[relativeIndex] = cachedRange[imageIndex] ?? null;
}
} else {
// Need to load this range
const rangeOffset = Math.floor(starredOffset / 50) * 50;
rangesToLoad.push({
collection: 'starred',
offset: rangeOffset,
limit: 50,
targetStartIndex: i,
});
}
}
}
return { images: requestedImages, rangesToLoad };
},
[positionInfo, loadedRanges, getRangeKey]
);
return {
positionInfo,
countsLoading,
getImagesForRange,
setLoadedRanges,
loadedRanges,
searchParamsChanged,
};
};
// Component to handle loading image ranges
const ImageRangeLoader = memo(
({
collection,
offset,
limit,
onDataLoaded,
}: {
collection: 'starred' | 'unstarred';
offset: number;
limit: number;
onDataLoaded: (key: string, images: ImageDTO[]) => void;
}) => {
const queryArgs = useAppSelector(selectImageCollectionQueryArgs);
const { data } = useGetImageCollectionQuery({
collection,
offset,
limit,
...queryArgs,
});
// Update cache when data is loaded - use useEffect to avoid state update during render
React.useEffect(() => {
if (data?.items) {
const key = `${collection}-${offset}-${limit}`;
onDataLoaded(key, data.items);
}
}, [data, collection, offset, limit, onDataLoaded]);
return null;
}
);
ImageRangeLoader.displayName = 'ImageRangeLoader';
export const NewGallery = memo(() => {
const { positionInfo, countsLoading, getImagesForRange, setLoadedRanges, searchParamsChanged } =
useVirtualImageData();
const [activeRangeLoaders, setActiveRangeLoaders] = useState<Set<string>>(new Set());
// Force initial range loading when position info becomes available
const [hasInitiallyLoaded, setHasInitiallyLoaded] = useState(false);
// Reset hasInitiallyLoaded when search parameters change
React.useEffect(() => {
setHasInitiallyLoaded(false);
setActiveRangeLoaders(new Set());
}, [searchParamsChanged]);
// Use useEffect for initial load to avoid state updates during render
React.useEffect(() => {
if (positionInfo && !hasInitiallyLoaded) {
// Force initial load of first 100 positions to ensure we see both starred and unstarred
const initialResult = getImagesForRange(0, Math.min(99, positionInfo.totalCount - 1));
if (!Array.isArray(initialResult)) {
const { rangesToLoad } = initialResult;
rangesToLoad.forEach((rangeInfo) => {
const key = `${rangeInfo.collection}-${rangeInfo.offset}-${rangeInfo.limit}`;
if (!activeRangeLoaders.has(key)) {
setActiveRangeLoaders((prev) => new Set(prev).add(key));
}
});
}
setHasInitiallyLoaded(true);
}
}, [positionInfo, hasInitiallyLoaded, getImagesForRange, activeRangeLoaders]);
// Handle range changes from virtuoso
const handleRangeChanged = useCallback(
(range: { startIndex: number; endIndex: number }) => {
if (!positionInfo) {
return;
}
const result = getImagesForRange(range.startIndex, range.endIndex);
if (!Array.isArray(result)) {
const { rangesToLoad } = result;
// Start loading any missing ranges
rangesToLoad.forEach((rangeInfo) => {
const key = `${rangeInfo.collection}-${rangeInfo.offset}-${rangeInfo.limit}`;
if (!activeRangeLoaders.has(key)) {
setActiveRangeLoaders((prev) => new Set(prev).add(key));
}
});
}
},
[positionInfo, getImagesForRange, activeRangeLoaders]
);
// Handle when range data is loaded
const handleDataLoaded = useCallback(
(key: string, images: ImageDTO[]) => {
setLoadedRanges((prev) => new Map(prev).set(key, images));
setActiveRangeLoaders((prev) => {
const next = new Set(prev);
next.delete(key);
return next;
});
},
[setLoadedRanges]
);
const computeItemKey = useCallback(
(index: number) => {
const result = getImagesForRange(index, index);
if (Array.isArray(result)) {
return `loading-${index}`;
}
const { images } = result;
const image = images[0];
return image ? `image-${index}-${image.image_name}` : `skeleton-${index}`;
},
[getImagesForRange]
);
// Render item at specific index
const itemContent = useCallback(
(index: number) => {
if (!positionInfo) {
return <ImageSkeleton />;
}
const result = getImagesForRange(index, index);
if (Array.isArray(result)) {
return <ImageSkeleton />;
}
const { images } = result;
const image = images[0];
if (image) {
return <ImagePlaceholder image={image} />;
}
return <ImageSkeleton />;
},
[positionInfo, getImagesForRange]
);
if (countsLoading) {
return (
<Flex height="100%" alignItems="center" justifyContent="center">
<Spinner size="lg" />
<Text ml={4}>Loading gallery...</Text>
</Flex>
);
}
if (!positionInfo || positionInfo.totalCount === 0) {
return (
<Flex height="100%" alignItems="center" justifyContent="center">
<Text color="gray.500">No images found</Text>
</Flex>
);
}
return (
<Box height="100%" width="100%">
{/* Render active range loaders */}
{Array.from(activeRangeLoaders).map((key) => {
const [collection, offset, limit] = key.split('-');
return (
<ImageRangeLoader
key={key}
collection={collection as 'starred' | 'unstarred'}
offset={parseInt(offset ?? '0', 10)}
limit={parseInt(limit ?? '50', 10)}
onDataLoaded={handleDataLoaded}
/>
);
})}
{/* Virtualized grid */}
<VirtuosoGrid
totalCount={positionInfo.totalCount}
overscan={200}
rangeChanged={handleRangeChanged}
itemContent={itemContent}
style={style}
computeItemKey={computeItemKey}
components={components}
/>
</Box>
);
});
NewGallery.displayName = 'NewGallery';
const style = { height: '100%', width: '100%' };
const ListComponent = forwardRef((props, ref) => (
<Grid ref={ref} gridTemplateColumns="repeat(auto-fill, minmax(64px, 1fr))" gap={2} padding={2} {...props} />
));
const ItemComponent = forwardRef((props, ref) => <GridItem ref={ref} aspectRatio="1/1" {...props} />);
const components = {
Item: ItemComponent,
List: ListComponent,
};

View File

@@ -4,7 +4,7 @@ import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { selectGallerySlice } from 'features/gallery/store/gallerySlice';
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types';
import type { ListBoardsArgs, ListImagesArgs } from 'services/api/types';
import type { ListBoardsArgs, ListImagesArgs, SQLiteDirection } from 'services/api/types';
export const selectFirstSelectedImage = createSelector(selectGallerySlice, (gallery) => gallery.selection.at(0));
export const selectLastSelectedImage = createSelector(selectGallerySlice, (gallery) => gallery.selection.at(-1));
@@ -38,6 +38,14 @@ export const selectListBoardsQueryArgs = createMemoizedSelector(
export const selectAutoAddBoardId = createSelector(selectGallerySlice, (gallery) => gallery.autoAddBoardId);
export const selectSelectedBoardId = createSelector(selectGallerySlice, (gallery) => gallery.selectedBoardId);
export const selectImageCollectionQueryArgs = createMemoizedSelector(selectGallerySlice, (gallery) => ({
board_id: gallery.selectedBoardId === 'none' ? undefined : gallery.selectedBoardId,
categories: gallery.galleryView === 'images' ? IMAGE_CATEGORIES : ASSETS_CATEGORIES,
search_term: gallery.searchTerm || undefined,
order_dir: gallery.orderDir as SQLiteDirection,
is_intermediate: false,
}));
export const selectAutoAssignBoardOnClick = createSelector(
selectGallerySlice,
(gallery) => gallery.autoAssignBoardOnClick

View File

@@ -77,7 +77,12 @@ export const imagesApi = api.injectEndpoints({
}),
clearIntermediates: build.mutation<number, void>({
query: () => ({ url: buildImagesUrl('intermediates'), method: 'DELETE' }),
invalidatesTags: ['IntermediatesCount', 'InvocationCacheStatus'],
invalidatesTags: [
'IntermediatesCount',
'InvocationCacheStatus',
'ImageCollectionCounts',
{ type: 'ImageCollection', id: LIST_TAG },
],
}),
getImageDTO: build.query<ImageDTO, string>({
query: (image_name) => ({ url: buildImagesUrl(`i/${image_name}`) }),
@@ -106,7 +111,11 @@ export const imagesApi = api.injectEndpoints({
// We ignore the deleted images when getting tags to invalidate. If we did not, we will invalidate the queries
// that fetch image DTOs, metadata, and workflows. But we have just deleted those images! Invalidating the tags
// will force those queries to re-fetch, and the requests will of course 404.
return getTagsToInvalidateForBoardAffectingMutation(result.affected_boards);
return [
...getTagsToInvalidateForBoardAffectingMutation(result.affected_boards),
'ImageCollectionCounts',
{ type: 'ImageCollection', id: LIST_TAG },
];
},
}),
deleteImages: build.mutation<
@@ -125,7 +134,11 @@ export const imagesApi = api.injectEndpoints({
// We ignore the deleted images when getting tags to invalidate. If we did not, we will invalidate the queries
// that fetch image DTOs, metadata, and workflows. But we have just deleted those images! Invalidating the tags
// will force those queries to re-fetch, and the requests will of course 404.
return getTagsToInvalidateForBoardAffectingMutation(result.affected_boards);
return [
...getTagsToInvalidateForBoardAffectingMutation(result.affected_boards),
'ImageCollectionCounts',
{ type: 'ImageCollection', id: LIST_TAG },
];
},
}),
deleteUncategorizedImages: build.mutation<
@@ -140,7 +153,11 @@ export const imagesApi = api.injectEndpoints({
// We ignore the deleted images when getting tags to invalidate. If we did not, we will invalidate the queries
// that fetch image DTOs, metadata, and workflows. But we have just deleted those images! Invalidating the tags
// will force those queries to re-fetch, and the requests will of course 404.
return getTagsToInvalidateForBoardAffectingMutation(result.affected_boards);
return [
...getTagsToInvalidateForBoardAffectingMutation(result.affected_boards),
'ImageCollectionCounts',
{ type: 'ImageCollection', id: LIST_TAG },
];
},
}),
/**
@@ -184,6 +201,8 @@ export const imagesApi = api.injectEndpoints({
return [
...getTagsToInvalidateForImageMutation(result.starred_images),
...getTagsToInvalidateForBoardAffectingMutation(result.affected_boards),
'ImageCollectionCounts',
{ type: 'ImageCollection', id: LIST_TAG },
];
},
}),
@@ -206,6 +225,8 @@ export const imagesApi = api.injectEndpoints({
return [
...getTagsToInvalidateForImageMutation(result.unstarred_images),
...getTagsToInvalidateForBoardAffectingMutation(result.affected_boards),
'ImageCollectionCounts',
{ type: 'ImageCollection', id: LIST_TAG },
];
},
}),
@@ -399,6 +420,52 @@ export const imagesApi = api.injectEndpoints({
},
}),
}),
/**
* Get counts for starred and unstarred image collections
*/
getImageCollectionCounts: build.query<
paths['/api/v1/images/collections/counts']['get']['responses']['200']['content']['application/json'],
paths['/api/v1/images/collections/counts']['get']['parameters']['query']
>({
query: (queryArgs) => ({
url: buildImagesUrl('collections/counts'),
method: 'GET',
params: queryArgs,
}),
providesTags: ['ImageCollectionCounts', 'FetchOnReconnect'],
}),
/**
* Get images from a specific collection (starred or unstarred)
*/
getImageCollection: build.query<
paths['/api/v1/images/collections/{collection}']['get']['responses']['200']['content']['application/json'],
paths['/api/v1/images/collections/{collection}']['get']['parameters']['path'] &
paths['/api/v1/images/collections/{collection}']['get']['parameters']['query']
>({
query: ({ collection, ...queryArgs }) => ({
url: buildImagesUrl(`collections/${collection}`),
method: 'GET',
params: queryArgs,
}),
providesTags: (result, error, { collection, board_id, categories }) => {
const cacheKey = `${collection}-${board_id || 'all'}-${categories?.join(',') || 'all'}`;
return [{ type: 'ImageCollection', id: cacheKey }, 'FetchOnReconnect'];
},
async onQueryStarted(_, { dispatch, queryFulfilled }) {
// Populate the getImageDTO cache with these images, similar to listImages
const res = await queryFulfilled;
const imageDTOs = res.data.items;
const updates: Param0<typeof imagesApi.util.upsertQueryEntries> = [];
for (const imageDTO of imageDTOs) {
updates.push({
endpointName: 'getImageDTO',
arg: imageDTO.image_name,
value: imageDTO,
});
}
dispatch(imagesApi.util.upsertQueryEntries(updates));
},
}),
}),
});
@@ -420,6 +487,9 @@ export const {
useStarImagesMutation,
useUnstarImagesMutation,
useBulkDownloadImagesMutation,
useGetImageCollectionCountsQuery,
useGetImageCollectionQuery,
useLazyGetImageCollectionQuery,
} = imagesApi;
/**

View File

@@ -23,6 +23,8 @@ const tagTypes = [
'ImageList',
'ImageMetadata',
'ImageWorkflow',
'ImageCollectionCounts',
'ImageCollection',
'ImageMetadataFromFile',
'IntermediatesCount',
'SessionQueueItem',

View File

@@ -752,6 +752,46 @@ export type paths = {
patch?: never;
trace?: never;
};
"/api/v1/images/collections/counts": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
/**
* Get Image Collection Counts
* @description Gets counts for starred and unstarred image collections
*/
get: operations["get_image_collection_counts"];
put?: never;
post?: never;
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/v1/images/collections/{collection}": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
/**
* Get Image Collection
* @description Gets images from a specific collection (starred or unstarred)
*/
get: operations["get_image_collection"];
put?: never;
post?: never;
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/v1/boards/": {
parameters: {
query?: never;
@@ -23675,6 +23715,97 @@ export interface operations {
};
};
};
get_image_collection_counts: {
parameters: {
query?: {
/** @description The origin of images to count. */
image_origin?: components["schemas"]["ResourceOrigin"] | null;
/** @description The categories of image to include. */
categories?: components["schemas"]["ImageCategory"][] | null;
/** @description Whether to include intermediate images. */
is_intermediate?: boolean | null;
/** @description The board id to filter by. Use 'none' to find images without a board. */
board_id?: string | null;
/** @description The term to search for */
search_term?: string | null;
};
header?: never;
path?: never;
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": {
[key: string]: number;
};
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
get_image_collection: {
parameters: {
query?: {
/** @description The origin of images to list. */
image_origin?: components["schemas"]["ResourceOrigin"] | null;
/** @description The categories of image to include. */
categories?: components["schemas"]["ImageCategory"][] | null;
/** @description Whether to list intermediate images. */
is_intermediate?: boolean | null;
/** @description The board id to filter by. Use 'none' to find images without a board. */
board_id?: string | null;
/** @description The offset within the collection */
offset?: number;
/** @description The number of images to return */
limit?: number;
/** @description The order of sort */
order_dir?: components["schemas"]["SQLiteDirection"];
/** @description The term to search for */
search_term?: string | null;
};
header?: never;
path: {
/** @description The collection to retrieve from */
collection: "starred" | "unstarred";
};
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["OffsetPaginatedResults_ImageDTO_"];
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
list_boards: {
parameters: {
query?: {