Compare commits

..

9 Commits

Author SHA1 Message Date
Ryan Dick
ffa89126d1 WIP experimentation with replacing einops with native torch ops for faster execution. More investigation needed. 2024-10-22 14:38:39 +00:00
psychedelicious
d1bb4c2c70 fix(nodes): FluxDenoiseInvocation.controlnet_vae missing default=None 2024-10-22 10:54:15 +11:00
Mary Hipp
bbd89d54b4 add it to list 2024-10-19 14:08:49 +11:00
Mary Hipp
ee61006a49 add starter model 2024-10-19 14:08:49 +11:00
psychedelicious
0b43f5fd64 docs(ui): improve docstrings for LoggingOverrides 2024-10-19 08:04:20 +11:00
psychedelicious
6c61266990 refactor(ui): logging config handling
Introduce two-stage logging configuration and overrides for enabled status, log level and log namespaces.

The first stage in `<InvokeAIUI />`, before we set up redux (and therefore before we have access to the user's configured logging setup). In this stage, we use the overrides or default values.

The second stage is in `<App />`, after we set up redux, via `useSyncLoggingConfig`. In this stage, we use the overrides or the user's configured logging setup. This hook also handles pushing changes made by the user into localstorage.

Other changes:
- Extract logging config to util function
- Remove the `useEffect` from `SettingsModal` that was changing the logging settings
- Remove extraneous log effects from `useLogger`
- Export new `LoggingOverrides` type
2024-10-19 08:04:20 +11:00
Maximilian Maag
2d5afe8094 fix(installer): Print maximize suggestion when Python is found, not when it's missing 2024-10-18 16:35:51 -04:00
Maximilian Maag
2430137d19 fix(installer): Avoid misleading error message when searching for python binary
which prints a message to stderr when it doesn't find anything. In this case,
not finding anything is expected so the error is misleading.
2024-10-18 16:35:51 -04:00
psychedelicious
5440c03767 fix(app): directory traversal when deleting images 2024-10-18 14:27:41 +11:00
14 changed files with 236 additions and 78 deletions

View File

@@ -12,7 +12,7 @@ MINIMUM_PYTHON_VERSION=3.10.0
MAXIMUM_PYTHON_VERSION=3.11.100
PYTHON=""
for candidate in python3.11 python3.10 python3 python ; do
if ppath=`which $candidate`; then
if ppath=`which $candidate 2>/dev/null`; then
# when using `pyenv`, the executable for an inactive Python version will exist but will not be operational
# we check that this found executable can actually run
if [ $($candidate --version &>/dev/null; echo ${PIPESTATUS}) -gt 0 ]; then continue; fi
@@ -30,10 +30,11 @@ done
if [ -z "$PYTHON" ]; then
echo "A suitable Python interpreter could not be found"
echo "Please install Python $MINIMUM_PYTHON_VERSION or higher (maximum $MAXIMUM_PYTHON_VERSION) before running this script. See instructions at $INSTRUCTIONS for help."
echo "For the best user experience we suggest enlarging or maximizing this window now."
read -p "Press any key to exit"
exit -1
fi
echo "For the best user experience we suggest enlarging or maximizing this window now."
exec $PYTHON ./lib/main.py ${@}
read -p "Press any key to exit"

View File

@@ -96,6 +96,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
default=None, input=Input.Connection, description="ControlNet models."
)
controlnet_vae: VAEField | None = InputField(
default=None,
description=FieldDescriptions.vae,
input=Input.Connection,
)

View File

@@ -110,15 +110,26 @@ class DiskImageFileStorage(ImageFileStorageBase):
except Exception as e:
raise ImageFileDeleteException from e
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(self, image_name: str, thumbnail: bool = False) -> Path:
path = self.__output_folder / image_name
base_folder = self.__thumbnails_folder if thumbnail else self.__output_folder
filename = get_thumbnail_name(image_name) if thumbnail else image_name
if thumbnail:
thumbnail_name = get_thumbnail_name(image_name)
path = self.__thumbnails_folder / thumbnail_name
# Strip any path information from the filename
basename = Path(filename).name
return path
if basename != filename:
raise ValueError("Invalid image name, potential directory traversal detected")
image_path = base_folder / basename
# Ensure the image path is within the base folder to prevent directory traversal
resolved_base = base_folder.resolve()
resolved_image_path = image_path.resolve()
if not resolved_image_path.is_relative_to(resolved_base):
raise ValueError("Image path outside outputs folder, potential directory traversal detected")
return resolved_image_path
def validate_path(self, path: Union[str, Path]) -> bool:
"""Validates the path given for an image or thumbnail."""

View File

@@ -9,8 +9,12 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
q, k = apply_rope(q, k, pe)
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "B H L D -> B L (H D)")
# Replaced original einops.rearrange(...) call with torch.reshape(...) for slightly faster performance.
# Original call: x = rearrange(x, "B H L D -> B L (H D)")
# x = x.permute(0, 2, 1, 3) # BHLD -> BLHD
# x = x.reshape(x.shape[0], x.shape[1], -1) # BLHD -> BL(HD)
x = rearrange(x, "B H L D -> B L (H D)")
return x
@@ -23,6 +27,9 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos, omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
# Replaced original einops.rearrange(...) call with torch.view(...) for slightly faster performance.
# Original call: out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
# out = out.view(*out.shape[:-1], 2, 2)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.float()

View File

@@ -4,7 +4,6 @@ import math
from dataclasses import dataclass
import torch
from einops import rearrange
from torch import Tensor, nn
from invokeai.backend.flux.math import attention, rope
@@ -94,13 +93,14 @@ class SelfAttention(nn.Module):
self.norm = QKNorm(head_dim)
self.proj = nn.Linear(dim, dim)
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
qkv = self.qkv(x)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
x = attention(q, k, v, pe=pe)
x = self.proj(x)
return x
# Unused code for reference:
# def forward(self, x: Tensor, pe: Tensor) -> Tensor:
# qkv = self.qkv(x)
# q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
# q, k = self.norm(q, k, v)
# x = attention(q, k, v, pe=pe)
# x = self.proj(x)
# return x
@dataclass
@@ -163,14 +163,22 @@ class DoubleStreamBlock(nn.Module):
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
# img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(
2, 0, 3, 1, 4
)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
# txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(
2, 0, 3, 1, 4
)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention
@@ -229,7 +237,8 @@ class SingleStreamBlock(nn.Module):
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
# q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k = self.norm(q, k, v)
# compute attention

View File

@@ -186,6 +186,16 @@ dreamshaper_sdxl = StarterModel(
type=ModelType.Main,
dependencies=[sdxl_fp16_vae_fix],
)
archvis_sdxl = StarterModel(
name="Architecture (RealVisXL5)",
base=BaseModelType.StableDiffusionXL,
source="SG161222/RealVisXL_V5.0",
description="A photorealistic model, with architecture among its many use cases",
type=ModelType.Main,
dependencies=[sdxl_fp16_vae_fix],
)
sdxl_refiner = StarterModel(
name="SDXL Refiner",
base=BaseModelType.StableDiffusionXLRefiner,
@@ -545,6 +555,7 @@ STARTER_MODELS: list[StarterModel] = [
deliberate_inpainting_sd1,
juggernaut_sdxl,
dreamshaper_sdxl,
archvis_sdxl,
sdxl_refiner,
sdxl_fp16_vae_fix,
flux_vae,

View File

@@ -4,6 +4,7 @@ import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
import { useStudioInitAction } from 'app/hooks/useStudioInitAction';
import { useSyncQueueStatus } from 'app/hooks/useSyncQueueStatus';
import { useLogger } from 'app/logging/useLogger';
import { useSyncLoggingConfig } from 'app/logging/useSyncLoggingConfig';
import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { PartialAppConfig } from 'app/types/invokeai';
@@ -59,6 +60,7 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
useGlobalModifiersInit();
useGlobalHotkeys();
useGetOpenAPISchemaQuery();
useSyncLoggingConfig();
const { dropzone, isHandlingUpload, setIsHandlingUpload } = useFullscreenDropzone();

View File

@@ -2,6 +2,8 @@ import 'i18n';
import type { Middleware } from '@reduxjs/toolkit';
import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
import type { LoggingOverrides } from 'app/logging/logger';
import { $loggingOverrides, configureLogging } from 'app/logging/logger';
import { $authToken } from 'app/store/nanostores/authToken';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { $customNavComponent } from 'app/store/nanostores/customNavComponent';
@@ -20,7 +22,7 @@ import Loading from 'common/components/Loading/Loading';
import AppDndContext from 'features/dnd/components/AppDndContext';
import type { WorkflowCategory } from 'features/nodes/types/workflow';
import type { PropsWithChildren, ReactNode } from 'react';
import React, { lazy, memo, useEffect, useMemo } from 'react';
import React, { lazy, memo, useEffect, useLayoutEffect, useMemo } from 'react';
import { Provider } from 'react-redux';
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
import { $socketOptions } from 'services/events/stores';
@@ -46,6 +48,7 @@ interface Props extends PropsWithChildren {
isDebugging?: boolean;
logo?: ReactNode;
workflowCategories?: WorkflowCategory[];
loggingOverrides?: LoggingOverrides;
}
const InvokeAIUI = ({
@@ -65,7 +68,26 @@ const InvokeAIUI = ({
isDebugging = false,
logo,
workflowCategories,
loggingOverrides,
}: Props) => {
useLayoutEffect(() => {
/*
* We need to configure logging before anything else happens - useLayoutEffect ensures we set this at the first
* possible opportunity.
*
* Once redux initializes, we will check the user's settings and update the logging config accordingly. See
* `useSyncLoggingConfig`.
*/
$loggingOverrides.set(loggingOverrides);
// Until we get the user's settings, we will use the overrides OR default values.
configureLogging(
loggingOverrides?.logIsEnabled ?? true,
loggingOverrides?.logLevel ?? 'debug',
loggingOverrides?.logNamespaces ?? '*'
);
}, [loggingOverrides]);
useEffect(() => {
// configure API client token
if (token) {

View File

@@ -9,11 +9,10 @@ const serializeMessage: MessageSerializer = (message) => {
};
ROARR.serializeMessage = serializeMessage;
ROARR.write = createLogWriter();
export const BASE_CONTEXT = {};
const BASE_CONTEXT = {};
export const $logger = atom<Logger>(Roarr.child(BASE_CONTEXT));
const $logger = atom<Logger>(Roarr.child(BASE_CONTEXT));
export const zLogNamespace = z.enum([
'canvas',
@@ -35,8 +34,22 @@ export const zLogLevel = z.enum(['trace', 'debug', 'info', 'warn', 'error', 'fat
export type LogLevel = z.infer<typeof zLogLevel>;
export const isLogLevel = (v: unknown): v is LogLevel => zLogLevel.safeParse(v).success;
/**
* Override logging settings.
* @property logIsEnabled Override the enabled log state. Omit to use the user's settings.
* @property logNamespaces Override the enabled log namespaces. Use `"*"` for all namespaces. Omit to use the user's settings.
* @property logLevel Override the log level. Omit to use the user's settings.
*/
export type LoggingOverrides = {
logIsEnabled?: boolean;
logNamespaces?: LogNamespace[] | '*';
logLevel?: LogLevel;
};
export const $loggingOverrides = atom<LoggingOverrides | undefined>();
// Translate human-readable log levels to numbers, used for log filtering
export const LOG_LEVEL_MAP: Record<LogLevel, number> = {
const LOG_LEVEL_MAP: Record<LogLevel, number> = {
trace: 10,
debug: 20,
info: 30,
@@ -44,3 +57,40 @@ export const LOG_LEVEL_MAP: Record<LogLevel, number> = {
error: 50,
fatal: 60,
};
/**
* Configure logging, pushing settings to local storage.
*
* @param logIsEnabled Whether logging is enabled
* @param logLevel The log level
* @param logNamespaces A list of log namespaces to enable, or '*' to enable all
*/
export const configureLogging = (
logIsEnabled: boolean = true,
logLevel: LogLevel = 'warn',
logNamespaces: LogNamespace[] | '*'
): void => {
if (!logIsEnabled) {
// Disable console log output
localStorage.setItem('ROARR_LOG', 'false');
} else {
// Enable console log output
localStorage.setItem('ROARR_LOG', 'true');
// Use a filter to show only logs of the given level
let filter = `context.logLevel:>=${LOG_LEVEL_MAP[logLevel]}`;
const namespaces = logNamespaces === '*' ? zLogNamespace.options : logNamespaces;
if (namespaces.length > 0) {
filter += ` AND (${namespaces.map((ns) => `context.namespace:${ns}`).join(' OR ')})`;
} else {
// This effectively hides all logs because we use namespaces for all logs
filter += ' AND context.namespace:undefined';
}
localStorage.setItem('ROARR_FILTER', filter);
}
ROARR.write = createLogWriter();
};

View File

@@ -1,53 +1,9 @@
import { createLogWriter } from '@roarr/browser-log-writer';
import { useAppSelector } from 'app/store/storeHooks';
import {
selectSystemLogIsEnabled,
selectSystemLogLevel,
selectSystemLogNamespaces,
} from 'features/system/store/systemSlice';
import { useEffect, useMemo } from 'react';
import { ROARR, Roarr } from 'roarr';
import { useMemo } from 'react';
import type { LogNamespace } from './logger';
import { $logger, BASE_CONTEXT, LOG_LEVEL_MAP, logger } from './logger';
import { logger } from './logger';
export const useLogger = (namespace: LogNamespace) => {
const logLevel = useAppSelector(selectSystemLogLevel);
const logNamespaces = useAppSelector(selectSystemLogNamespaces);
const logIsEnabled = useAppSelector(selectSystemLogIsEnabled);
// The provided Roarr browser log writer uses localStorage to config logging to console
useEffect(() => {
if (logIsEnabled) {
// Enable console log output
localStorage.setItem('ROARR_LOG', 'true');
// Use a filter to show only logs of the given level
let filter = `context.logLevel:>=${LOG_LEVEL_MAP[logLevel]}`;
if (logNamespaces.length > 0) {
filter += ` AND (${logNamespaces.map((ns) => `context.namespace:${ns}`).join(' OR ')})`;
} else {
filter += ' AND context.namespace:undefined';
}
localStorage.setItem('ROARR_FILTER', filter);
} else {
// Disable console log output
localStorage.setItem('ROARR_LOG', 'false');
}
ROARR.write = createLogWriter();
}, [logLevel, logIsEnabled, logNamespaces]);
// Update the module-scoped logger context as needed
useEffect(() => {
// TODO: type this properly
//eslint-disable-next-line @typescript-eslint/no-explicit-any
const newContext: Record<string, any> = {
...BASE_CONTEXT,
};
$logger.set(Roarr.child(newContext));
}, []);
const log = useMemo(() => logger(namespace), [namespace]);
return log;

View File

@@ -0,0 +1,43 @@
import { useStore } from '@nanostores/react';
import { $loggingOverrides, configureLogging } from 'app/logging/logger';
import { useAppSelector } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import {
selectSystemLogIsEnabled,
selectSystemLogLevel,
selectSystemLogNamespaces,
} from 'features/system/store/systemSlice';
import { useLayoutEffect } from 'react';
/**
* This hook synchronizes the logging configuration stored in Redux with the logging system, which uses localstorage.
*
* The sync is one-way: from Redux to localstorage. This means that changes made in the UI will be reflected in the
* logging system, but changes made directly to localstorage will not be reflected in the UI.
*
* See {@link configureLogging}
*/
export const useSyncLoggingConfig = () => {
useAssertSingleton('useSyncLoggingConfig');
const loggingOverrides = useStore($loggingOverrides);
const logLevel = useAppSelector(selectSystemLogLevel);
const logNamespaces = useAppSelector(selectSystemLogNamespaces);
const logIsEnabled = useAppSelector(selectSystemLogIsEnabled);
useLayoutEffect(() => {
configureLogging(
loggingOverrides?.logIsEnabled ?? logIsEnabled,
loggingOverrides?.logLevel ?? logLevel,
loggingOverrides?.logNamespaces ?? logNamespaces
);
}, [
logIsEnabled,
logLevel,
logNamespaces,
loggingOverrides?.logIsEnabled,
loggingOverrides?.logLevel,
loggingOverrides?.logNamespaces,
]);
};

View File

@@ -27,7 +27,6 @@ import { SettingsDeveloperLogNamespaces } from 'features/system/components/Setti
import { useClearIntermediates } from 'features/system/components/SettingsModal/useClearIntermediates';
import { StickyScrollable } from 'features/system/components/StickyScrollable';
import {
logIsEnabledChanged,
selectSystemShouldAntialiasProgressImage,
selectSystemShouldConfirmOnDelete,
selectSystemShouldConfirmOnNewSession,
@@ -76,12 +75,6 @@ const SettingsModal = ({ config = defaultConfig, children }: SettingsModalProps)
const dispatch = useAppDispatch();
const { t } = useTranslation();
useEffect(() => {
if (!config?.shouldShowDeveloperSettings) {
dispatch(logIsEnabledChanged(false));
}
}, [dispatch, config?.shouldShowDeveloperSettings]);
const { isNSFWCheckerAvailable, isWatermarkerAvailable } = useGetAppConfigQuery(undefined, {
selectFromResult: ({ data }) => ({
isNSFWCheckerAvailable: data?.nsfw_methods.includes('nsfw_checker') ?? false,

View File

@@ -1,5 +1,6 @@
export { default as InvokeAIUI } from './app/components/InvokeAIUI';
export type { StudioInitAction } from './app/hooks/useStudioInitAction';
export type { LoggingOverrides } from './app/logging/logger';
export type { PartialAppConfig } from './app/types/invokeai';
export { default as ParamMainModelSelect } from './features/parameters/components/MainModel/ParamMainModelSelect';
export { default as HotkeysModal } from './features/system/components/HotkeysModal/HotkeysModal';

View File

@@ -0,0 +1,51 @@
import platform
from pathlib import Path
import pytest
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
@pytest.fixture
def image_names() -> list[str]:
# Determine the platform and return a path that matches its format
if platform.system() == "Windows":
return [
# Relative paths
"folder\\evil.txt",
"folder\\..\\evil.txt",
# Absolute paths
"\\folder\\evil.txt",
"C:\\folder\\..\\evil.txt",
]
else:
return [
# Relative paths
"folder/evil.txt",
"folder/../evil.txt",
# Absolute paths
"/folder/evil.txt",
"/folder/../evil.txt",
]
def test_directory_traversal_protection(tmp_path: Path, image_names: list[str]):
"""Test that the image file storage prevents directory traversal attacks.
There are two safeguards in the `DiskImageFileStorage.get_path` method:
1. Check if the image name contains any directory traversal characters
2. Check if the resulting path is relative to the base folder
This test checks the first safeguard. I'd like to check the second but I cannot figure out a test case that would
pass the first check but fail the second check.
"""
image_files_disk = DiskImageFileStorage(tmp_path)
for name in image_names:
with pytest.raises(ValueError, match="Invalid image name, potential directory traversal detected"):
image_files_disk.get_path(name)
def test_image_paths_relative_to_storage_dir(tmp_path: Path):
image_files_disk = DiskImageFileStorage(tmp_path)
path = image_files_disk.get_path("foo.png")
assert path.is_relative_to(tmp_path)