Compare commits

...

22 Commits

Author SHA1 Message Date
Eugene Brodsky
bb066f6c33 (ci) remove python 3.10 from the test matrix; comment out GPU tests for now 2025-03-28 15:03:13 -04:00
Mary Hipp
3f58c68c09 fix tag invalidation 2025-03-28 10:52:27 -04:00
Mary Hipp
e50c7e5947 restore multiple key 2025-03-28 10:52:27 -04:00
Mary Hipp
4a83700fe4 if clientSideUploading is enabled, handle bulk uploads using that flow 2025-03-28 10:52:27 -04:00
jazzhaiku
a53e1ccf08 Small improvements (#7842)
## Summary

- Extend `ModelOnDisk` with caching, type hints, default args
- Fail early if there is an error classifying a config

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-03-28 12:21:41 +11:00
jazzhaiku
1af9930951 Merge branch 'main' into small-improvements 2025-03-28 12:11:09 +11:00
psychedelicious
c6f96613fc chore(ui): typegen 2025-03-28 08:14:06 +11:00
psychedelicious
258bf736da fix(nodes): handle zero fade size (e.g. mask blur 0)
Closes #7850
2025-03-28 08:14:06 +11:00
jazzhaiku
c9dc27afbb Merge branch 'main' into small-improvements 2025-03-27 08:14:48 +11:00
Billy
efd14ec0e4 Make ruff happy 2025-03-27 08:11:39 +11:00
Billy
21ee2b6251 Merge branch 'small-improvements' of github.com:invoke-ai/InvokeAI into small-improvements 2025-03-27 08:10:38 +11:00
Billy
82dd2d508f Deprecate checkpoint as file, diffusers as directory terminology 2025-03-27 08:10:12 +11:00
jazzhaiku
5a59f6e3b8 Merge branch 'main' into small-improvements 2025-03-27 07:38:13 +11:00
Billy
60b5aef16a Log error -> warning 2025-03-27 06:56:22 +11:00
Billy
0e8b5484d5 Error handling 2025-03-26 19:31:57 +11:00
Billy
454506c83e Type hints 2025-03-26 19:12:49 +11:00
Billy
8f6ab67376 Logs 2025-03-26 16:34:32 +11:00
Billy
5afcc7778f Redundant 2025-03-26 16:32:19 +11:00
Billy
325e07d330 Error handling 2025-03-26 16:30:45 +11:00
Billy
a016bdc159 Add todo 2025-03-26 16:17:26 +11:00
Billy
a14f0b2864 Fail early on invalid config 2025-03-26 16:10:32 +11:00
Billy
721483318a Extend ModelOnDisk 2025-03-26 16:10:00 +11:00
17 changed files with 438 additions and 176 deletions

View File

@@ -61,13 +61,13 @@ jobs:
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
uses: actions/setup-python@v5
with:
python-version: '3.10'
python-version: '3.12'
cache: pip
cache-dependency-path: pyproject.toml
- name: install ruff
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: pip install ruff==0.9.9
run: pip install ruff==0.11.2
shell: bash
- name: ruff check

View File

@@ -39,26 +39,25 @@ jobs:
strategy:
matrix:
python-version:
- '3.10'
- '3.11'
platform:
- linux-cuda-11_7
- linux-rocm-5_2
# - linux-cuda-12_6
# - linux-rocm-6_2
- linux-cpu
- macos-default
- windows-cpu
include:
- platform: linux-cuda-11_7
os: ubuntu-22.04
github-env: $GITHUB_ENV
- platform: linux-rocm-5_2
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
github-env: $GITHUB_ENV
# - platform: linux-cuda-12_6
# os: ubuntu-24.04
# github-env: $GITHUB_ENV
# - platform: linux-rocm-6_2
# os: ubuntu-24.04
# extra-index-url: 'https://download.pytorch.org/whl/rocm6.2'
# github-env: $GITHUB_ENV
- platform: linux-cpu
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/cpu'
os: ubuntu-24.04
github-env: $GITHUB_ENV
extra-index-url: 'https://download.pytorch.org/whl/cpu'
- platform: macos-default
os: macOS-14
github-env: $GITHUB_ENV

View File

@@ -96,6 +96,22 @@ async def upload_image(
raise HTTPException(status_code=500, detail="Failed to create image")
class ImageUploadEntry(BaseModel):
image_dto: ImageDTO = Body(description="The image DTO")
presigned_url: str = Body(description="The URL to get the presigned URL for the image upload")
@images_router.post("/", operation_id="create_image_upload_entry")
async def create_image_upload_entry(
width: int = Body(description="The width of the image"),
height: int = Body(description="The height of the image"),
board_id: Optional[str] = Body(default=None, description="The board to add this image to, if any"),
) -> ImageUploadEntry:
"""Uploads an image from a URL, not implemented"""
raise HTTPException(status_code=501, detail="Not implemented")
@images_router.delete("/i/{image_name}", operation_id="delete_image")
async def delete_image(
image_name: str = Path(description="The name of the image to delete"),

View File

@@ -1095,6 +1095,7 @@ class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Expands a mask with a fade effect. The mask uses black to indicate areas to keep from the generated image and white for areas to discard.
The mask is thresholded to create a binary mask, and then a distance transform is applied to create a fade effect.
The fade size is specified in pixels, and the mask is expanded by that amount. The result is a mask with a smooth transition from black to white.
If the fade size is 0, the mask is returned as-is.
"""
mask: ImageField = InputField(description="The mask to expand")
@@ -1104,6 +1105,11 @@ class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_mask = context.images.get_pil(self.mask.image_name, mode="L")
if self.fade_size_px == 0:
# If the fade size is 0, just return the mask as-is.
image_dto = context.images.save(image=pil_mask, image_category=ImageCategory.MASK)
return ImageOutput.build(image_dto)
np_mask = numpy.array(pil_mask)
# Threshold the mask to create a binary mask - 0 for black, 255 for white

View File

@@ -67,6 +67,11 @@ class InvalidModelConfigException(Exception):
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
class FSLayout(Enum):
FILE = "file"
DIRECTORY = "directory"
class SubmodelDefinition(BaseModel):
path_or_prefix: str
model_type: ModelType
@@ -102,29 +107,31 @@ class ModelOnDisk:
def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
self.path = path
self.format_type = ModelFormat.Diffusers if path.is_dir() else ModelFormat.Checkpoint
# TODO: Revisit checkpoint vs diffusers terminology
self.layout = FSLayout.DIRECTORY if path.is_dir() else FSLayout.FILE
if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
self.name = path.stem
else:
self.name = path.name
self.hash_algo = hash_algo
self._state_dict_cache = {}
def hash(self):
def hash(self) -> str:
return ModelHash(algorithm=self.hash_algo).hash(self.path)
def size(self):
if self.format_type == ModelFormat.Checkpoint:
def size(self) -> int:
if self.layout == FSLayout.FILE:
return self.path.stat().st_size
return sum(file.stat().st_size for file in self.path.rglob("*"))
def component_paths(self):
if self.format_type == ModelFormat.Checkpoint:
def component_paths(self) -> set[Path]:
if self.layout == FSLayout.FILE:
return {self.path}
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
return {f for f in self.path.rglob("*") if f.suffix in extensions}
def repo_variant(self):
if self.format_type == ModelFormat.Checkpoint:
def repo_variant(self) -> Optional[ModelRepoVariant]:
if self.layout == FSLayout.FILE:
return None
weight_files = list(self.path.glob("**/*.safetensors"))
@@ -140,14 +147,30 @@ class ModelOnDisk:
return ModelRepoVariant.ONNX
return ModelRepoVariant.Default
@staticmethod
def load_state_dict(path: Path):
def load_state_dict(self, path: Optional[Path] = None) -> Dict[str | int, Any]:
if path in self._state_dict_cache:
return self._state_dict_cache[path]
if not path:
components = list(self.component_paths())
match components:
case []:
raise ValueError("No weight files found for this model")
case [p]:
path = p
case ps if len(ps) >= 2:
raise ValueError(
f"Multiple weight files found for this model: {ps}. "
f"Please specify the intended file using the 'path' argument"
)
with SilenceWarnings():
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
scan_result = scan_file_path(path)
if scan_result.infected_files != 0 or scan_result.scan_err:
raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
checkpoint = torch.load(path, map_location="cpu")
assert isinstance(checkpoint, dict)
elif path.suffix.endswith(".gguf"):
checkpoint = gguf_sd_loader(path, compute_dtype=torch.float32)
elif path.suffix.endswith(".safetensors"):
@@ -156,6 +179,7 @@ class ModelOnDisk:
raise ValueError(f"Unrecognized model extension: {path.suffix}")
state_dict = checkpoint.get("state_dict", checkpoint)
self._state_dict_cache[path] = state_dict
return state_dict
@@ -238,11 +262,13 @@ class ModelConfigBase(ABC, BaseModel):
for config_cls in sorted_by_match_speed:
try:
return config_cls.from_model_on_disk(mod, **overrides)
except InvalidModelConfigException:
logger.debug(f"ModelConfig '{config_cls.__name__}' failed to parse '{mod.path}', trying next config")
if not config_cls.matches(mod):
continue
except Exception as e:
logger.error(f"Unexpected exception while parsing '{config_cls.__name__}': {e}, trying next config")
logger.warning(f"Unexpected exception while matching {mod.name} to '{config_cls.__name__}': {e}")
continue
else:
return config_cls.from_model_on_disk(mod, **overrides)
raise InvalidModelConfigException("No valid config found")
@@ -285,9 +311,6 @@ class ModelConfigBase(ABC, BaseModel):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
"""Creates an instance of this config or raises InvalidModelConfigException."""
if not cls.matches(mod):
raise InvalidModelConfigException(f"Path {mod.path} does not match {cls.__name__} format")
fields = cls.parse(mod)
cls.cast_overrides(overrides)
fields.update(overrides)
@@ -563,7 +586,7 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
if mod.format_type == ModelFormat.Checkpoint:
if mod.layout == FSLayout.FILE:
return False
config_path = mod.path / "config.json"

View File

@@ -1,6 +1,8 @@
import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { RootState } from 'app/store/store';
import { imageUploadedClientSide } from 'features/gallery/store/actions';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { boardIdSelected, galleryViewChanged } from 'features/gallery/store/gallerySlice';
import { toast } from 'features/toast/toast';
@@ -8,7 +10,8 @@ import { t } from 'i18next';
import { omit } from 'lodash-es';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { getCategories, getListImagesUrl } from 'services/api/util';
const log = logger('gallery');
/**
@@ -34,19 +37,56 @@ let lastUploadedToastTimeout: number | null = null;
export const addImageUploadedFulfilledListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: imagesApi.endpoints.uploadImage.matchFulfilled,
matcher: isAnyOf(imagesApi.endpoints.uploadImage.matchFulfilled, imageUploadedClientSide),
effect: (action, { dispatch, getState }) => {
const imageDTO = action.payload;
let imageDTO: ImageDTO;
let silent;
let isFirstUploadOfBatch = true;
if (imageUploadedClientSide.match(action)) {
imageDTO = action.payload.imageDTO;
silent = action.payload.silent;
isFirstUploadOfBatch = action.payload.isFirstUploadOfBatch;
} else if (imagesApi.endpoints.uploadImage.matchFulfilled(action)) {
imageDTO = action.payload;
silent = action.meta.arg.originalArgs.silent;
isFirstUploadOfBatch = action.meta.arg.originalArgs.isFirstUploadOfBatch ?? true;
} else {
return;
}
if (silent || imageDTO.is_intermediate) {
// If the image is silent or intermediate, we don't want to show a toast
return;
}
if (imageUploadedClientSide.match(action)) {
const categories = getCategories(imageDTO);
const boardId = imageDTO.board_id ?? 'none';
dispatch(
imagesApi.util.invalidateTags([
{
type: 'ImageList',
id: getListImagesUrl({
board_id: boardId,
categories,
}),
},
{
type: 'Board',
id: boardId,
},
{
type: 'BoardImagesTotal',
id: boardId,
},
])
);
}
const state = getState();
log.debug({ imageDTO }, 'Image uploaded');
if (action.meta.arg.originalArgs.silent || imageDTO.is_intermediate) {
// When a "silent" upload is requested, or the image is intermediate, we can skip all post-upload actions,
// like toasts and switching the gallery view
return;
}
const boardId = imageDTO.board_id ?? 'none';
const DEFAULT_UPLOADED_TOAST = {
@@ -80,7 +120,7 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
*
* Default to true to not require _all_ image upload handlers to set this value
*/
const isFirstUploadOfBatch = action.meta.arg.originalArgs.isFirstUploadOfBatch ?? true;
if (isFirstUploadOfBatch) {
dispatch(boardIdSelected({ boardId }));
dispatch(galleryViewChanged('assets'));

View File

@@ -73,6 +73,7 @@ export type AppConfig = {
maxUpscaleDimension?: number;
allowPrivateBoards: boolean;
allowPrivateStylePresets: boolean;
allowClientSideUpload: boolean;
disabledTabs: TabName[];
disabledFeatures: AppFeature[];
disabledSDFeatures: SDFeature[];
@@ -81,7 +82,6 @@ export type AppConfig = {
metadataFetchDebounce?: number;
workflowFetchDebounce?: number;
isLocal?: boolean;
maxImageUploadCount?: number;
sd: {
defaultModel?: string;
disabledControlNetModels: string[];

View File

@@ -0,0 +1,105 @@
import { useStore } from '@nanostores/react';
import { $authToken } from 'app/store/nanostores/authToken';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { imageUploadedClientSide } from 'features/gallery/store/actions';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { useCallback } from 'react';
import { useCreateImageUploadEntryMutation } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
export const useClientSideUpload = () => {
const dispatch = useAppDispatch();
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
const authToken = useStore($authToken);
const [createImageUploadEntry] = useCreateImageUploadEntryMutation();
const clientSideUpload = useCallback(
async (file: File, i: number): Promise<ImageDTO> => {
const image = new Image();
const objectURL = URL.createObjectURL(file);
image.src = objectURL;
let width = 0;
let height = 0;
let thumbnail: Blob | undefined;
await new Promise<void>((resolve) => {
image.onload = () => {
width = image.naturalWidth;
height = image.naturalHeight;
// Calculate thumbnail dimensions maintaining aspect ratio
let thumbWidth = width;
let thumbHeight = height;
if (width > height && width > 256) {
thumbWidth = 256;
thumbHeight = Math.round((height * 256) / width);
} else if (height > 256) {
thumbHeight = 256;
thumbWidth = Math.round((width * 256) / height);
}
const canvas = document.createElement('canvas');
canvas.width = thumbWidth;
canvas.height = thumbHeight;
const ctx = canvas.getContext('2d');
ctx?.drawImage(image, 0, 0, thumbWidth, thumbHeight);
canvas.toBlob(
(blob) => {
if (blob) {
thumbnail = blob;
// Clean up resources
URL.revokeObjectURL(objectURL);
image.src = ''; // Clear image source
image.remove(); // Remove the image element
canvas.width = 0; // Clear canvas
canvas.height = 0;
resolve();
}
},
'image/webp',
0.8
);
};
// Handle load errors
image.onerror = () => {
URL.revokeObjectURL(objectURL);
image.remove();
resolve();
};
});
const { presigned_url, image_dto } = await createImageUploadEntry({
width,
height,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
}).unwrap();
await fetch(`${presigned_url}/?type=full`, {
method: 'PUT',
body: file,
...(authToken && {
headers: {
Authorization: `Bearer ${authToken}`,
},
}),
});
await fetch(`${presigned_url}/?type=thumbnail`, {
method: 'PUT',
body: thumbnail,
...(authToken && {
headers: {
Authorization: `Bearer ${authToken}`,
},
}),
});
dispatch(imageUploadedClientSide({ imageDTO: image_dto, silent: false, isFirstUploadOfBatch: i === 0 }));
return image_dto;
},
[autoAddBoardId, authToken, createImageUploadEntry, dispatch]
);
return clientSideUpload;
};

View File

@@ -3,7 +3,7 @@ import { IconButton } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger';
import { useAppSelector } from 'app/store/storeHooks';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
import { selectIsClientSideUploadEnabled } from 'features/system/store/configSlice';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import type { FileRejection } from 'react-dropzone';
@@ -15,6 +15,7 @@ import type { ImageDTO } from 'services/api/types';
import { assert } from 'tsafe';
import type { SetOptional } from 'type-fest';
import { useClientSideUpload } from './useClientSideUpload';
type UseImageUploadButtonArgs =
| {
isDisabled?: boolean;
@@ -50,8 +51,9 @@ const log = logger('gallery');
*/
export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: UseImageUploadButtonArgs) => {
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
const isClientSideUploadEnabled = useAppSelector(selectIsClientSideUploadEnabled);
const [uploadImage, request] = useUploadImageMutation();
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
const clientSideUpload = useClientSideUpload();
const { t } = useTranslation();
const onDropAccepted = useCallback(
@@ -79,22 +81,27 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
onUpload(imageDTO);
}
} else {
const imageDTOs = await uploadImages(
files.map((file, i) => ({
file,
image_category: 'user',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
silent: false,
isFirstUploadOfBatch: i === 0,
}))
);
let imageDTOs: ImageDTO[] = [];
if (isClientSideUploadEnabled) {
imageDTOs = await Promise.all(files.map((file, i) => clientSideUpload(file, i)));
} else {
imageDTOs = await uploadImages(
files.map((file, i) => ({
file,
image_category: 'user',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
silent: false,
isFirstUploadOfBatch: i === 0,
}))
);
}
if (onUpload) {
onUpload(imageDTOs);
}
}
},
[allowMultiple, autoAddBoardId, onUpload, uploadImage]
[allowMultiple, autoAddBoardId, onUpload, uploadImage, isClientSideUploadEnabled, clientSideUpload]
);
const onDropRejected = useCallback(
@@ -105,10 +112,7 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
file: rejection.file.path,
}));
log.error({ errors }, 'Invalid upload');
const description =
maxImageUploadCount === undefined
? t('toast.uploadFailedInvalidUploadDesc')
: t('toast.uploadFailedInvalidUploadDesc_withCount', { count: maxImageUploadCount });
const description = t('toast.uploadFailedInvalidUploadDesc');
toast({
id: 'UPLOAD_FAILED',
@@ -120,7 +124,7 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
return;
}
},
[maxImageUploadCount, t]
[t]
);
const {
@@ -137,8 +141,7 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
onDropRejected,
disabled: isDisabled,
noDrag: true,
multiple: allowMultiple && (maxImageUploadCount === undefined || maxImageUploadCount > 1),
maxFiles: maxImageUploadCount,
multiple: allowMultiple,
});
return { getUploadButtonProps, getUploadInputProps, openUploader, request };

View File

@@ -8,12 +8,13 @@ import { useStore } from '@nanostores/react';
import { getStore } from 'app/store/nanostores/store';
import { useAppSelector } from 'app/store/storeHooks';
import { $focusedRegion } from 'common/hooks/focus';
import { useClientSideUpload } from 'common/hooks/useClientSideUpload';
import { setFileToPaste } from 'features/controlLayers/components/CanvasPasteModal';
import { DndDropOverlay } from 'features/dnd/DndDropOverlay';
import type { DndTargetState } from 'features/dnd/types';
import { $imageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
import { selectIsClientSideUploadEnabled } from 'features/system/store/configSlice';
import { toast } from 'features/toast/toast';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useEffect, useRef, useState } from 'react';
@@ -53,13 +54,6 @@ const zUploadFile = z
(file) => ({ message: `File extension .${file.name.split('.').at(-1)} is not supported` })
);
const getFilesSchema = (max?: number) => {
if (max === undefined) {
return z.array(zUploadFile);
}
return z.array(zUploadFile).max(max);
};
const sx = {
position: 'absolute',
top: 2,
@@ -74,22 +68,19 @@ const sx = {
export const FullscreenDropzone = memo(() => {
const { t } = useTranslation();
const ref = useRef<HTMLDivElement>(null);
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
const [dndState, setDndState] = useState<DndTargetState>('idle');
const activeTab = useAppSelector(selectActiveTab);
const isImageViewerOpen = useStore($imageViewer);
const isClientSideUploadEnabled = useAppSelector(selectIsClientSideUploadEnabled);
const clientSideUpload = useClientSideUpload();
const validateAndUploadFiles = useCallback(
(files: File[]) => {
async (files: File[]) => {
const { getState } = getStore();
const uploadFilesSchema = getFilesSchema(maxImageUploadCount);
const parseResult = uploadFilesSchema.safeParse(files);
const parseResult = z.array(zUploadFile).safeParse(files);
if (!parseResult.success) {
const description =
maxImageUploadCount === undefined
? t('toast.uploadFailedInvalidUploadDesc')
: t('toast.uploadFailedInvalidUploadDesc_withCount', { count: maxImageUploadCount });
const description = t('toast.uploadFailedInvalidUploadDesc');
toast({
id: 'UPLOAD_FAILED',
@@ -118,17 +109,23 @@ export const FullscreenDropzone = memo(() => {
const autoAddBoardId = selectAutoAddBoardId(getState());
const uploadArgs: UploadImageArg[] = files.map((file, i) => ({
file,
image_category: 'user',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
isFirstUploadOfBatch: i === 0,
}));
if (isClientSideUploadEnabled) {
for (const [i, file] of files.entries()) {
await clientSideUpload(file, i);
}
} else {
const uploadArgs: UploadImageArg[] = files.map((file, i) => ({
file,
image_category: 'user',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
isFirstUploadOfBatch: i === 0,
}));
uploadImages(uploadArgs);
uploadImages(uploadArgs);
}
},
[activeTab, isImageViewerOpen, maxImageUploadCount, t]
[activeTab, isImageViewerOpen, t, isClientSideUploadEnabled, clientSideUpload]
);
const onPaste = useCallback(

View File

@@ -1,31 +1,18 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
import { t } from 'i18next';
import { useMemo } from 'react';
import { PiUploadBold } from 'react-icons/pi';
export const GalleryUploadButton = () => {
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
const uploadOptions = useMemo(() => ({ allowMultiple: maxImageUploadCount !== 1 }), [maxImageUploadCount]);
const uploadApi = useImageUploadButton(uploadOptions);
const uploadApi = useImageUploadButton({ allowMultiple: true });
return (
<>
<IconButton
size="sm"
alignSelf="stretch"
variant="link"
aria-label={
maxImageUploadCount === undefined || maxImageUploadCount > 1
? t('accessibility.uploadImages')
: t('accessibility.uploadImage')
}
tooltip={
maxImageUploadCount === undefined || maxImageUploadCount > 1
? t('accessibility.uploadImages')
: t('accessibility.uploadImage')
}
aria-label={t('accessibility.uploadImages')}
tooltip={t('accessibility.uploadImages')}
icon={<PiUploadBold />}
{...uploadApi.getUploadButtonProps()}
/>

View File

@@ -1,4 +1,5 @@
import { createAction } from '@reduxjs/toolkit';
import type { ImageDTO } from 'services/api/types';
export const sentImageToCanvas = createAction('gallery/sentImageToCanvas');
@@ -7,3 +8,9 @@ export const imageDownloaded = createAction('gallery/imageDownloaded');
export const imageCopiedToClipboard = createAction('gallery/imageCopiedToClipboard');
export const imageOpenedInNewTab = createAction('gallery/imageOpenedInNewTab');
export const imageUploadedClientSide = createAction<{
imageDTO: ImageDTO;
silent: boolean;
isFirstUploadOfBatch: boolean;
}>('gallery/imageUploadedClientSide');

View File

@@ -20,6 +20,7 @@ const initialConfigState: AppConfig = {
shouldFetchMetadataFromApi: false,
allowPrivateBoards: false,
allowPrivateStylePresets: false,
allowClientSideUpload: false,
disabledTabs: [],
disabledFeatures: ['lightbox', 'faceRestore', 'batches'],
disabledSDFeatures: ['variation', 'symmetry', 'hires', 'perlinNoise', 'noiseThreshold'],
@@ -218,6 +219,5 @@ export const selectWorkflowFetchDebounce = createConfigSelector((config) => conf
export const selectMetadataFetchDebounce = createConfigSelector((config) => config.metadataFetchDebounce ?? 300);
export const selectIsModelsTabDisabled = createConfigSelector((config) => config.disabledTabs.includes('models'));
export const selectMaxImageUploadCount = createConfigSelector((config) => config.maxImageUploadCount);
export const selectIsClientSideUploadEnabled = createConfigSelector((config) => config.allowClientSideUpload);
export const selectIsLocal = createSelector(selectConfigSlice, (config) => config.isLocal);

View File

@@ -7,6 +7,8 @@ import type {
DeleteBoardResult,
GraphAndWorkflowResponse,
ImageDTO,
ImageUploadEntryRequest,
ImageUploadEntryResponse,
ListImagesArgs,
ListImagesResponse,
UploadImageArg,
@@ -287,6 +289,7 @@ export const imagesApi = api.injectEndpoints({
},
};
},
invalidatesTags: (result) => {
if (!result || result.is_intermediate) {
// Don't add it to anything
@@ -314,7 +317,13 @@ export const imagesApi = api.injectEndpoints({
];
},
}),
createImageUploadEntry: build.mutation<ImageUploadEntryResponse, ImageUploadEntryRequest>({
query: ({ width, height, board_id }) => ({
url: buildImagesUrl(),
method: 'POST',
body: { width, height, board_id },
}),
}),
deleteBoard: build.mutation<DeleteBoardResult, string>({
query: (board_id) => ({ url: buildBoardsUrl(board_id), method: 'DELETE' }),
invalidatesTags: () => [
@@ -549,6 +558,7 @@ export const {
useGetImageWorkflowQuery,
useLazyGetImageWorkflowQuery,
useUploadImageMutation,
useCreateImageUploadEntryMutation,
useClearIntermediatesMutation,
useAddImagesToBoardMutation,
useRemoveImagesFromBoardMutation,

View File

@@ -466,6 +466,30 @@ export type paths = {
patch?: never;
trace?: never;
};
"/api/v1/images/": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
/**
* List Image Dtos
* @description Gets a list of image DTOs
*/
get: operations["list_image_dtos"];
put?: never;
/**
* Create Image Upload Entry
* @description Uploads an image from a URL, not implemented
*/
post: operations["create_image_upload_entry"];
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/v1/images/i/{image_name}": {
parameters: {
query?: never;
@@ -619,26 +643,6 @@ export type paths = {
patch?: never;
trace?: never;
};
"/api/v1/images/": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
/**
* List Image Dtos
* @description Gets a list of image DTOs
*/
get: operations["list_image_dtos"];
put?: never;
post?: never;
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/v1/images/delete": {
parameters: {
query?: never;
@@ -2358,6 +2362,24 @@ export type components = {
*/
batch_ids: string[];
};
/** Body_create_image_upload_entry */
Body_create_image_upload_entry: {
/**
* Width
* @description The width of the image
*/
width: number;
/**
* Height
* @description The height of the image
*/
height: number;
/**
* Board Id
* @description The board to add this image to, if any
*/
board_id?: string | null;
};
/** Body_create_style_preset */
Body_create_style_preset: {
/**
@@ -6451,6 +6473,7 @@ export type components = {
* @description Expands a mask with a fade effect. The mask uses black to indicate areas to keep from the generated image and white for areas to discard.
* The mask is thresholded to create a binary mask, and then a distance transform is applied to create a fade effect.
* The fade size is specified in pixels, and the mask is expanded by that amount. The result is a mask with a smooth transition from black to white.
* If the fade size is 0, the mask is returned as-is.
*/
ExpandMaskWithFadeInvocation: {
/**
@@ -10753,6 +10776,16 @@ export type components = {
*/
type: "i2l";
};
/** ImageUploadEntry */
ImageUploadEntry: {
/** @description The image DTO */
image_dto: components["schemas"]["ImageDTO"];
/**
* Presigned Url
* @description The URL to get the presigned URL for the image upload
*/
presigned_url: string;
};
/**
* ImageUrlsDTO
* @description The URLs for an image and its thumbnail.
@@ -23218,6 +23251,87 @@ export interface operations {
};
};
};
list_image_dtos: {
parameters: {
query?: {
/** @description The origin of images to list. */
image_origin?: components["schemas"]["ResourceOrigin"] | null;
/** @description The categories of image to include. */
categories?: components["schemas"]["ImageCategory"][] | null;
/** @description Whether to list intermediate images. */
is_intermediate?: boolean | null;
/** @description The board id to filter by. Use 'none' to find images without a board. */
board_id?: string | null;
/** @description The page offset */
offset?: number;
/** @description The number of images per page */
limit?: number;
/** @description The order of sort */
order_dir?: components["schemas"]["SQLiteDirection"];
/** @description Whether to sort by starred images first */
starred_first?: boolean;
/** @description The term to search for */
search_term?: string | null;
};
header?: never;
path?: never;
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["OffsetPaginatedResults_ImageDTO_"];
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
create_image_upload_entry: {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
requestBody: {
content: {
"application/json": components["schemas"]["Body_create_image_upload_entry"];
};
};
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["ImageUploadEntry"];
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
get_image_dto: {
parameters: {
query?: never;
@@ -23571,54 +23685,6 @@ export interface operations {
};
};
};
list_image_dtos: {
parameters: {
query?: {
/** @description The origin of images to list. */
image_origin?: components["schemas"]["ResourceOrigin"] | null;
/** @description The categories of image to include. */
categories?: components["schemas"]["ImageCategory"][] | null;
/** @description Whether to list intermediate images. */
is_intermediate?: boolean | null;
/** @description The board id to filter by. Use 'none' to find images without a board. */
board_id?: string | null;
/** @description The page offset */
offset?: number;
/** @description The number of images per page */
limit?: number;
/** @description The order of sort */
order_dir?: components["schemas"]["SQLiteDirection"];
/** @description Whether to sort by starred images first */
starred_first?: boolean;
/** @description The term to search for */
search_term?: string | null;
};
header?: never;
path?: never;
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["OffsetPaginatedResults_ImageDTO_"];
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
delete_images_from_list: {
parameters: {
query?: never;

View File

@@ -354,3 +354,6 @@ export type UploadImageArg = {
*/
isFirstUploadOfBatch?: boolean;
};
export type ImageUploadEntryResponse = S['ImageUploadEntry'];
export type ImageUploadEntryRequest = paths['/api/v1/images/']['post']['requestBody']['content']['application/json'];

View File

@@ -71,7 +71,7 @@ def create_stripped_model(original_model_path: Path, stripped_model_path: Path)
print(f"Created clone of {original.name} at {stripped.path}")
for component_path in stripped.component_paths():
original_state_dict = ModelOnDisk.load_state_dict(component_path)
original_state_dict = stripped.load_state_dict(component_path)
stripped_state_dict = strip(original_state_dict) # type: ignore
with open(component_path, "w") as f:
json.dump(stripped_state_dict, f, indent=4)