Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-19 15:27:58 -05:00)

Compare commits: lstein/enh...dev/ci/upd (60 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 1f608d3743 | |
| | df024dd982 | |
| | 45da85765c | |
| | 8618e41b32 | |
| | 4687f94141 | |
| | 440912dcff | |
| | 8b87a26e7e | |
| | 44ae93df3e | |
| | 2b213da967 | |
| | e91e1eb9aa | |
| | b24129fb3e | |
| | 350b1421bb | |
| | f01c79a94f | |
| | 463f6352ce | |
| | a80fe05e23 | |
| | 58d7833c5c | |
| | 5012f61599 | |
| | 85c33823c3 | |
| | c83a112669 | |
| | e04ada1319 | |
| | d866dcb3d2 | |
| | 81ec476f3a | |
| | 1e6adf0a06 | |
| | 7d221e2518 | |
| | 56d3cbead0 | |
| | 5e8c97f1ba | |
| | 4687ad4ed6 | |
| | 994b247f8e | |
| | 0419f50ab0 | |
| | f9f40adcdc | |
| | 3264d30b44 | |
| | 4d885653e9 | |
| | 475b6bef53 | |
| | d39de0ad38 | |
| | d14a7d756e | |
| | b050c1bb8f | |
| | 276dfc591b | |
| | b49d76ebee | |
| | 0bc2edc044 | |
| | 16488e7db8 | |
| | 974841926d | |
| | 8db20e0d95 | |
| | f0e07bff5a | |
| | 3ec06a1fc3 | |
| | 6b79e2b407 | |
| | 0f95f7cea3 | |
| | 0b0068ab86 | |
| | d753cff91a | |
| | 89f1909e4b | |
| | 37916a22ad | |
| | 8cb2fa8600 | |
| | 8f460b92f1 | |
| | d99a08a441 | |
| | b164330e3c | |
| | 0b0e6fe448 | |
| | c132dbdefa | |
| | f3081e7013 | |
| | f904f14f9e | |
| | 8917a6d99b | |
| | 5a4765046e | |
15  .github/workflows/mkdocs-material.yml (vendored)
@@ -2,8 +2,7 @@ name: mkdocs-material
 on:
   push:
     branches:
-      - 'main'
-      - 'development'
+      - 'refs/heads/v2.3'
 
 permissions:
   contents: write
@@ -12,6 +11,10 @@ jobs:
   mkdocs-material:
     if: github.event.pull_request.draft == false
     runs-on: ubuntu-latest
+    env:
+      REPO_URL: '${{ github.server_url }}/${{ github.repository }}'
+      REPO_NAME: '${{ github.repository }}'
+      SITE_URL: 'https://${{ github.repository_owner }}.github.io/InvokeAI'
     steps:
       - name: checkout sources
         uses: actions/checkout@v3
@@ -22,11 +25,15 @@
         uses: actions/setup-python@v4
         with:
           python-version: '3.10'
+          cache: pip
+          cache-dependency-path: pyproject.toml
 
       - name: install requirements
+        env:
+          PIP_USE_PEP517: 1
         run: |
           python -m \
-            pip install -r docs/requirements-mkdocs.txt
+            pip install ".[docs]"
 
       - name: confirm buildability
         run: |
@@ -36,7 +43,7 @@
             --verbose
 
       - name: deploy to gh-pages
-        if: ${{ github.ref == 'refs/heads/main' }}
+        if: ${{ github.ref == 'refs/heads/v2.3' }}
         run: |
           python -m \
             mkdocs gh-deploy \
@@ -89,7 +89,7 @@ experimental versions later.
 sudo apt update
 sudo apt install -y software-properties-common
 sudo add-apt-repository -y ppa:deadsnakes/ppa
-sudo apt install python3.10 python3-pip python3.10-venv
+sudo apt install -y python3.10 python3-pip python3.10-venv
 sudo update-alternatives --install /usr/local/bin/python python /usr/bin/python3.10 3
 ```
 
@@ -1,16 +1,13 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
 
 import os
 from argparse import Namespace
 
-from invokeai.app.services.metadata import PngMetadataService, MetadataServiceBase
+import invokeai.backend.util.logging as logger
+from typing import types
 
 from ..services.default_graphs import create_system_graphs
 
 from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
 
-from ...backend.globals import Globals, copy_conf_to_globals
-from ..services.config import InvokeAIWebConfig
+from ...backend import Globals
 from ..services.model_manager_initializer import get_model_manager
 from ..services.restoration_services import RestorationServices
 from ..services.graph import GraphExecutionState, LibraryGraph
@@ -20,6 +17,7 @@ from ..services.invocation_services import InvocationServices
 from ..services.invoker import Invoker
 from ..services.processor import DefaultInvocationProcessor
 from ..services.sqlite import SqliteItemStorage
+from ..services.metadata import PngMetadataService
 from .events import FastAPIEventService
 
@@ -45,11 +43,16 @@ class ApiDependencies:
     invoker: Invoker = None
 
     @staticmethod
-    def initialize(config, event_handler_id: int):
-        copy_conf_to_globals(config)
+    def initialize(config, event_handler_id: int, logger: types.ModuleType=logger):
+        Globals.try_patchmatch = config.patchmatch
+        Globals.always_use_cpu = config.always_use_cpu
+        Globals.internet_available = config.internet_available and check_internet()
+        Globals.disable_xformers = not config.xformers
+        Globals.ckpt_convert = config.ckpt_convert
 
-        # TODO: Use a logger
-        print(f">> Internet connectivity is {Globals.internet_available}")
+        # TO DO: Use the config to select the logger rather than use the default
+        # invokeai logging module
+        logger.info(f"Internet connectivity is {Globals.internet_available}")
 
         events = FastAPIEventService(event_handler_id)
@@ -67,8 +70,9 @@
         db_location = os.path.join(output_folder, "invokeai.db")
 
         services = InvocationServices(
-            model_manager=get_model_manager(config),
+            model_manager=get_model_manager(config,logger),
             events=events,
+            logger=logger,
             latents=latents,
             images=images,
             metadata=metadata,
@@ -80,8 +84,7 @@
                 filename=db_location, table_name="graph_executions"
             ),
             processor=DefaultInvocationProcessor(),
-            restoration=RestorationServices(config),
-            configuration=config,
+            restoration=RestorationServices(config,logger),
         )
 
         create_system_graphs(services.graph_library)
@@ -108,19 +108,20 @@ async def update_model(
 async def delete_model(model_name: str) -> None:
     """Delete Model"""
     model_names = ApiDependencies.invoker.services.model_manager.model_names()
+    logger = ApiDependencies.invoker.services.logger
     model_exists = model_name in model_names
 
     # check if model exists
-    print(f">> Checking for model {model_name}...")
+    logger.info(f"Checking for model {model_name}...")
 
     if model_exists:
-        print(f">> Deleting Model: {model_name}")
+        logger.info(f"Deleting Model: {model_name}")
         ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True)
-        print(f">> Model Deleted: {model_name}")
+        logger.info(f"Model Deleted: {model_name}")
         raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
 
     else:
-        print(f">> Model not found")
+        logger.error(f"Model not found")
         raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
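The change above is representative of the logging migration running through this comparison: bare `print(">> ...")` calls are replaced with the `invokeai.backend.util.logging` module. A minimal sketch of the resulting convention, assuming only the `info`/`warning`/`error` helpers that the hunks themselves use (the model name is purely illustrative):

```python
# Sketch of the logging convention introduced in this comparison.
# Assumes invokeai.backend.util.logging exposes info/warning/error helpers,
# as used in the hunks above; the model name is an illustrative value.
import invokeai.backend.util.logging as logger

model_name = "stable-diffusion-1.5"                 # illustrative value
logger.info(f"Checking for model {model_name}...")  # was: print(f">> Checking for model {model_name}...")
logger.error("Model not found")                     # errors/warnings drop the old ">> " prefix style
```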
@@ -3,6 +3,7 @@ import asyncio
 from inspect import signature
 
 import uvicorn
+import invokeai.backend.util.logging as logger
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
@@ -12,11 +13,11 @@ from fastapi_events.handlers.local import local_handler
 from fastapi_events.middleware import EventHandlerASGIMiddleware
 from pydantic.schema import schema
 
+from ..backend import Args
 from .api.dependencies import ApiDependencies
 from .api.routers import images, sessions, models
 from .api.sockets import SocketIO
 from .invocations.baseinvocation import BaseInvocation
-from .services.config import InvokeAIWebConfig
 
 # Create the app
 # TODO: create this all in a method so configuration/etc. can be passed in?
@@ -32,15 +33,30 @@ app.add_middleware(
     middleware_id=event_handler_id,
 )
 
+# Add CORS
+# TODO: use configuration for this
+origins = []
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
 socket_io = SocketIO(app)
 
-web_config = {}
+config = {}
 
 
 # Add startup event to load dependencies
 @app.on_event("startup")
 async def startup_event():
+    config = Args()
+    config.parse_args()
+
     ApiDependencies.initialize(
-        config=web_config, event_handler_id=event_handler_id
+        config=config, event_handler_id=event_handler_id, logger=logger
     )
@@ -130,21 +146,12 @@ def overridden_redoc():
 
 
 def invoke_api():
-    # parse command-line settings, environment and the init file
-    # (this is a module global)
-    global web_config
-    web_config = InvokeAIWebConfig()
-    app.add_middleware(
-        CORSMiddleware,
-        allow_origins=web_config.allow_origins,
-        allow_credentials=web_config.allow_credentials,
-        allow_methods=web_config.allow_methods,
-        allow_headers=web_config.allow_headers,
-    )
     # Start our own event loop for eventing usage
     # TODO: determine if there's a better way to do this
     loop = asyncio.new_event_loop()
-    config = uvicorn.Config(app=app, host=web_config.host, port=web_config.port, loop=loop)
+    config = uvicorn.Config(app=app, host="0.0.0.0", port=9090, loop=loop)
     # Use access_log to turn off logging
 
     server = uvicorn.Server(config)
     loop.run_until_complete(server.serve())
@@ -1,52 +1,92 @@
-# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
+# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
 
 from abc import ABC, abstractmethod
 import argparse
-from typing import Any, Callable, Iterable, Literal, Union, get_args, get_type_hints
-from pydantic import Field
+from typing import Any, Callable, Iterable, Literal, Union, get_args, get_origin, get_type_hints
+from pydantic import BaseModel, Field
 import networkx as nx
 import matplotlib.pyplot as plt
 
+import invokeai.backend.util.logging as logger
 from ..invocations.baseinvocation import BaseInvocation
-from ..services.config import InvokeAISettings
 from ..invocations.image import ImageField
 from ..services.graph import GraphExecutionState, LibraryGraph, Edge
 from ..services.invoker import Invoker
 
 
+def add_field_argument(command_parser, name: str, field, default_override = None):
+    default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
+    if get_origin(field.type_) == Literal:
+        allowed_values = get_args(field.type_)
+        allowed_types = set()
+        for val in allowed_values:
+            allowed_types.add(type(val))
+        allowed_types_list = list(allowed_types)
+        field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
+
+        command_parser.add_argument(
+            f"--{name}",
+            dest=name,
+            type=field_type,
+            default=default,
+            choices=allowed_values,
+            help=field.field_info.description,
+        )
+    else:
+        command_parser.add_argument(
+            f"--{name}",
+            dest=name,
+            type=field.type_,
+            default=default,
+            help=field.field_info.description,
+        )
+
+
 def add_parsers(
     subparsers,
     commands: list[type],
     command_field: str = "type",
     exclude_fields: list[str] = ["id", "type"],
-    add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None
+    add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
     ):
    """Adds parsers for each command to the subparsers"""

     # Create subparsers for each command
     for command in commands:
-        name = command.cmd_name()
-        command_parser = subparsers.add_parser(name, help=command.__doc__)
+        hints = get_type_hints(command)
+        cmd_name = get_args(hints[command_field])[0]
+        command_parser = subparsers.add_parser(cmd_name, help=command.__doc__)
 
         if add_arguments is not None:
             add_arguments(command_parser)
-        command.add_parser_arguments(command_parser)
+
+        # Convert all fields to arguments
+        fields = command.__fields__ # type: ignore
+        for name, field in fields.items():
+            if name in exclude_fields:
+                continue
+
+            add_field_argument(command_parser, name, field)
 
 
 def add_graph_parsers(
     subparsers,
     graphs: list[LibraryGraph],
-    add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None
+    add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
    ):
     for graph in graphs:
         command_parser = subparsers.add_parser(graph.name, help=graph.description)
 
         if add_arguments is not None:
-            graph.add_parser_arguments(command_parser)
             add_arguments(command_parser)
 
         # Add arguments for inputs
         for exposed_input in graph.exposed_inputs:
             node = graph.graph.get_node(exposed_input.node_path)
             field = node.__fields__[exposed_input.field]
             default_override = getattr(node, exposed_input.field)
-            graph.add_field_argument(command_parser, exposed_input.alias, field, default_override)
+            add_field_argument(command_parser, exposed_input.alias, field, default_override)
 
 
 class CliContext:
     invoker: Invoker
@@ -91,7 +131,7 @@ class ExitCli(Exception):
     pass
 
 
-class BaseCommand(ABC, InvokeAISettings):
+class BaseCommand(ABC, BaseModel):
     """A CLI command"""
 
     # All commands must include a type name like this:
@@ -190,7 +230,7 @@ class HistoryCommand(BaseCommand):
         for i in range(min(self.count, len(history))):
             entry_id = history[-1 - i]
             entry = context.get_session().graph.get_node(entry_id)
-            print(f"{entry_id}: {get_invocation_command(entry)}")
+            logger.info(f"{entry_id}: {get_invocation_command(entry)}")
 
 
 class SetDefaultCommand(BaseCommand):
@@ -10,10 +10,10 @@ import shlex
 from pathlib import Path
 from typing import List, Dict, Literal, get_args, get_type_hints, get_origin
 
-from ...backend import ModelManager
+import invokeai.backend.util.logging as logger
+from ...backend import ModelManager, Globals
 from ..invocations.baseinvocation import BaseInvocation
 from .commands import BaseCommand
-from ..services.invocation_services import InvocationServices
 
 # singleton object, class variable
 completer = None
@@ -131,13 +131,13 @@ class Completer(object):
         readline.redisplay()
         self.linebuffer = None
 
-def set_autocompleter(services: InvocationServices) -> Completer:
+def set_autocompleter(model_manager: ModelManager) -> Completer:
     global completer
 
     if completer:
         return completer
 
-    completer = Completer(services.model_manager)
+    completer = Completer(model_manager)
 
     readline.set_completer(completer.complete)
     # pyreadline3 does not have a set_auto_history() method
@@ -153,7 +153,7 @@ def set_autocompleter(services: InvocationServices) -> Completer:
     readline.parse_and_bind("set skip-completed-text on")
     readline.parse_and_bind("set show-all-if-ambiguous on")
 
-    histfile = Path(services.configuration.root_dir / ".invoke_history")
+    histfile = Path(Globals.root, ".invoke_history")
     try:
         readline.read_history_file(histfile)
         readline.set_history_length(1000)
@@ -161,8 +161,8 @@
         pass
     except OSError: # file likely corrupted
         newname = f"{histfile}.old"
-        print(
-            f"## Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
+        logger.error(
+            f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
         )
         histfile.replace(Path(newname))
     atexit.register(readline.write_history_file, histfile)
@@ -13,20 +13,20 @@ from typing import (
 from pydantic import BaseModel
 from pydantic.fields import Field
 
+import invokeai.backend.util.logging as logger
 from invokeai.app.services.metadata import PngMetadataService
 
 from .services.default_graphs import create_system_graphs
 
 from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
 
-from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, get_graph_execution_history
+from ..backend import Args
+from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers
 from .cli.completer import set_autocompleter
 from .invocations import *
 from .invocations.baseinvocation import BaseInvocation
 from .services.events import EventServiceBase
 from .services.model_manager_initializer import get_model_manager
 from .services.restoration_services import RestorationServices
-from .services.graph import Edge, EdgeConnection, ExposedNodeInput, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
+from .services.graph import Edge, EdgeConnection, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
 from .services.default_graphs import default_text_to_image_graph_id
 from .services.image_storage import DiskImageStorage
 from .services.invocation_queue import MemoryInvocationQueue
@@ -34,8 +34,7 @@ from .services.invocation_services import InvocationServices
 from .services.invoker import Invoker
 from .services.processor import DefaultInvocationProcessor
 from .services.sqlite import SqliteItemStorage
-from .services.config import InvokeAIAppConfig
-from ..backend.globals import copy_conf_to_globals # temporary workaround for code depending on Globals
 
 
 class CliCommand(BaseModel):
     command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
@@ -182,7 +181,7 @@ def invoke_all(context: CliContext):
     # Print any errors
     if context.session.has_error():
         for n in context.session.errors:
-            print(
+            context.invoker.services.logger.error(
                 f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
             )
 
@@ -190,20 +189,24 @@
 
 
 def invoke_cli():
-    config = InvokeAIAppConfig()
-    copy_conf_to_globals(config) # temporary workaround
-    model_manager = get_model_manager(config)
+    config = Args()
+    config.parse_args()
+    model_manager = get_model_manager(config,logger=logger)
+
+    # This initializes the autocompleter and returns it.
+    # Currently nothing is done with the returned Completer
+    # object, but the object can be used to change autocompletion
+    # behavior on the fly, if desired.
+    set_autocompleter(model_manager)
 
     events = EventServiceBase()
 
-    output_folder = config.output_path
     metadata = PngMetadataService()
 
+    output_folder = os.path.abspath(
+        os.path.join(os.path.dirname(__file__), "../../../outputs")
+    )
+
     # TODO: build a file/path manager?
     db_location = os.path.join(output_folder, "invokeai.db")
@@ -221,8 +224,8 @@ def invoke_cli():
             filename=db_location, table_name="graph_executions"
         ),
         processor=DefaultInvocationProcessor(),
-        restoration=RestorationServices(config),
-        configuration=config,
+        restoration=RestorationServices(config,logger=logger),
+        logger=logger,
     )
 
     system_graphs = create_system_graphs(services.graph_library)
@@ -239,8 +242,6 @@ def invoke_cli():
 
     context = CliContext(invoker, session, parser)
 
-    set_autocompleter(services)
-
     while True:
         try:
             cmd_input = input("invoke> ")
@@ -284,17 +285,8 @@ def invoke_cli():
                 command = CliCommand(command = invocation)
                 context.graph_nodes[invocation.id] = system_graph.id
             else:
-                if "id" in args:
-                    args["id"] = args["id"] or current_id
-
-                # remove extraneous fields from initialization
-                exclude = ['link','link_node']
-                command_args = dict()
-                for key,value in args.items():
-                    if key not in exclude:
-                        command_args[key]=value
-
-                command = CliCommand(command=command_args)
+                args["id"] = current_id
+                command = CliCommand(command=args)
 
             if command is None:
                 continue
@@ -373,12 +365,12 @@ def invoke_cli():
             invoke_all(context)
 
         except InvalidArgs:
-            print('Invalid command, use "help" to list commands')
+            invoker.services.logger.warning('Invalid command, use "help" to list commands')
             continue
 
         except SessionError:
             # Start a new session
-            print("Session error: creating a new session")
+            invoker.services.logger.warning("Session error: creating a new session")
             context.reset()
 
         except ExitCli:
@@ -4,10 +4,10 @@ from abc import ABC, abstractmethod
 from inspect import signature
 from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict
 
-from pydantic import BaseModel, BaseSettings, Field
+from pydantic import BaseModel, Field
 
 from ..services.invocation_services import InvocationServices
-from ..services.config import InvokeAISettings
 
 
 class InvocationContext:
     services: InvocationServices
@@ -36,7 +36,7 @@ class BaseInvocationOutput(BaseModel):
         return tuple(subclasses)
 
 
-class BaseInvocation(ABC, InvokeAISettings):
+class BaseInvocation(ABC, BaseModel):
     """A node to process inputs and produce outputs.
     May use dependency injection in __init__ to receive providers.
     """
@@ -101,8 +101,8 @@ class CustomisedSchemaExtra(TypedDict):
     ui: UIConfig
 
 
-class InvocationConfig(BaseSettings.Config):
-    """Customizes pydantic's BaseSettings.Config class for use by Invocations.
+class InvocationConfig(BaseModel.Config):
+    """Customizes pydantic's BaseModel.Config class for use by Invocations.
 
     Provide `schema_extra` a `ui` dict to add hints for generated UIs.
 
245  invokeai/app/invocations/compel.py (Normal file)
@@ -0,0 +1,245 @@
from typing import Literal, Optional, Union
from pydantic import BaseModel, Field

from invokeai.app.invocations.util.choose_model import choose_model
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig

from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager

from compel import Compel
from compel.prompt_parser import (
    Blend,
    CrossAttentionControlSubstitute,
    FlattenedPrompt,
    Fragment,
)

from invokeai.backend.globals import Globals


class ConditioningField(BaseModel):
    conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")
    class Config:
        schema_extra = {"required": ["conditioning_name"]}


class CompelOutput(BaseInvocationOutput):
    """Compel parser output"""

    #fmt: off
    type: Literal["compel_output"] = "compel_output"

    conditioning: ConditioningField = Field(default=None, description="Conditioning")
    #fmt: on


class CompelInvocation(BaseInvocation):
    """Parse prompt using compel package to conditioning."""

    type: Literal["compel"] = "compel"

    prompt: str = Field(default="", description="Prompt")
    model: str = Field(default="", description="Model to use")

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Prompt (Compel)",
                "tags": ["prompt", "compel"],
                "type_hints": {
                    "model": "model"
                }
            },
        }

    def invoke(self, context: InvocationContext) -> CompelOutput:

        # TODO: load without model
        model = choose_model(context.services.model_manager, self.model)
        pipeline = model["model"]
        tokenizer = pipeline.tokenizer
        text_encoder = pipeline.text_encoder

        # TODO: global? input?
        #use_full_precision = precision == "float32" or precision == "autocast"
        #use_full_precision = False

        # TODO: redo TI when separate model loding implemented
        #textual_inversion_manager = TextualInversionManager(
        #    tokenizer=tokenizer,
        #    text_encoder=text_encoder,
        #    full_precision=use_full_precision,
        #)

        def load_huggingface_concepts(concepts: list[str]):
            pipeline.textual_inversion_manager.load_huggingface_concepts(concepts)

        # apply the concepts library to the prompt
        prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
            self.prompt,
            lambda concepts: load_huggingface_concepts(concepts),
            pipeline.textual_inversion_manager.get_all_trigger_strings(),
        )

        # lazy-load any deferred textual inversions.
        # this might take a couple of seconds the first time a textual inversion is used.
        pipeline.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
            prompt_str
        )

        compel = Compel(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            textual_inversion_manager=pipeline.textual_inversion_manager,
            dtype_for_device_getter=torch_dtype,
            truncate_long_prompts=True, # TODO:
        )

        # TODO: support legacy blend?

        prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(prompt_str)

        if getattr(Globals, "log_tokenization", False):
            log_tokenization_for_prompt_object(prompt, tokenizer)

        c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)

        # TODO: long prompt support
        #if not self.truncate_long_prompts:
        #    [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])

        ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
            tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt),
            cross_attention_control_args=options.get("cross_attention_control", None),
        )

        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"

        # TODO: hacky but works ;D maybe rename latents somehow?
        context.services.latents.set(conditioning_name, (c, ec))

        return CompelOutput(
            conditioning=ConditioningField(
                conditioning_name=conditioning_name,
            ),
        )


def get_max_token_count(
    tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
) -> int:
    if type(prompt) is Blend:
        blend: Blend = prompt
        return max(
            [
                get_max_token_count(tokenizer, c, truncate_if_too_long)
                for c in blend.prompts
            ]
        )
    else:
        return len(
            get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long)
        )


def get_tokens_for_prompt_object(
    tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
) -> [str]:
    if type(parsed_prompt) is Blend:
        raise ValueError(
            "Blend is not supported here - you need to get tokens for each of its .children"
        )

    text_fragments = [
        x.text
        if type(x) is Fragment
        else (
            " ".join([f.text for f in x.original])
            if type(x) is CrossAttentionControlSubstitute
            else str(x)
        )
        for x in parsed_prompt.children
    ]
    text = " ".join(text_fragments)
    tokens = tokenizer.tokenize(text)
    if truncate_if_too_long:
        max_tokens_length = tokenizer.model_max_length - 2 # typically 75
        tokens = tokens[0:max_tokens_length]
    return tokens


def log_tokenization_for_prompt_object(
    p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
):
    display_label_prefix = display_label_prefix or ""
    if type(p) is Blend:
        blend: Blend = p
        for i, c in enumerate(blend.prompts):
            log_tokenization_for_prompt_object(
                c,
                tokenizer,
                display_label_prefix=f"{display_label_prefix}(blend part {i + 1}, weight={blend.weights[i]})",
            )
    elif type(p) is FlattenedPrompt:
        flattened_prompt: FlattenedPrompt = p
        if flattened_prompt.wants_cross_attention_control:
            original_fragments = []
            edited_fragments = []
            for f in flattened_prompt.children:
                if type(f) is CrossAttentionControlSubstitute:
                    original_fragments += f.original
                    edited_fragments += f.edited
                else:
                    original_fragments.append(f)
                    edited_fragments.append(f)

            original_text = " ".join([x.text for x in original_fragments])
            log_tokenization_for_text(
                original_text,
                tokenizer,
                display_label=f"{display_label_prefix}(.swap originals)",
            )
            edited_text = " ".join([x.text for x in edited_fragments])
            log_tokenization_for_text(
                edited_text,
                tokenizer,
                display_label=f"{display_label_prefix}(.swap replacements)",
            )
        else:
            text = " ".join([x.text for x in flattened_prompt.children])
            log_tokenization_for_text(
                text, tokenizer, display_label=display_label_prefix
            )


def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
    """shows how the prompt is tokenized
    # usually tokens have '</w>' to indicate end-of-word,
    # but for readability it has been replaced with ' '
    """
    tokens = tokenizer.tokenize(text)
    tokenized = ""
    discarded = ""
    usedTokens = 0
    totalTokens = len(tokens)

    for i in range(0, totalTokens):
        token = tokens[i].replace("</w>", " ")
        # alternate color
        s = (usedTokens % 6) + 1
        if truncate_if_too_long and i >= tokenizer.model_max_length:
            discarded = discarded + f"\x1b[0;3{s};40m{token}"
        else:
            tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
            usedTokens += 1

    if usedTokens > 0:
        print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
        print(f"{tokenized}\x1b[0m")

    if discarded != "":
        print(f"\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
        print(f"{discarded}\x1b[0m")
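For orientation, here is a hedged, stand-alone sketch of the same conditioning step that CompelInvocation performs above, but outside the node graph, using the compel package's public constructor and build_conditioning_tensor call. The checkpoint id and prompts are illustrative assumptions, not part of this change:

```python
# Standalone sketch of the conditioning step CompelInvocation performs above.
# Assumptions: diffusers and compel are installed; the model id is illustrative.
from diffusers import StableDiffusionPipeline
from compel import Compel

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")  # illustrative checkpoint
compel_proc = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)

# Build positive and negative conditioning tensors from prompt strings.
positive = compel_proc.build_conditioning_tensor("a photograph of an astronaut riding a horse")
negative = compel_proc.build_conditioning_tensor("")

# The tensors can then be fed to the pipeline as prompt embeddings,
# mirroring how the node stores (c, ec) for a downstream TextToLatentsInvocation.
images = pipe(prompt_embeds=positive, negative_prompt_embeds=negative, num_inference_steps=10).images
```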
@@ -46,8 +46,8 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
     prompt: Optional[str] = Field(description="The prompt to generate an image from")
     seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
     steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
-    width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
-    height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
+    width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
+    height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
     cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
     scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
     seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
@@ -150,6 +150,9 @@ class ImageToImageInvocation(TextToImageInvocation):
         )
         mask = None
 
+        if self.fit:
+            image = image.resize((self.width, self.height))
+
         # Handle invalid model parameter
         model = choose_model(context.services.model_manager, self.model)
 
@@ -247,8 +250,8 @@ class InpaintInvocation(ImageToImageInvocation):
 
         outputs = Inpaint(model).generate(
             prompt=self.prompt,
-            init_img=image,
-            init_mask=mask,
+            init_image=image,
+            mask_image=mask,
             step_callback=partial(self.dispatch_progress, context, source_node_id),
             **self.dict(
                 exclude={"prompt", "image", "mask"}
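The width/height hunk above relaxes the pydantic constraint from multiple_of=64 to multiple_of=8, so sizes like 520 or 904 become valid. A small, hedged illustration of what that constraint accepts, using plain pydantic v1 (the class name is a hypothetical stand-in, not part of the codebase):

```python
# Illustrative only: the effect of pydantic's multiple_of constraint used by
# the width/height fields above (pydantic v1 style, as in this codebase).
from pydantic import BaseModel, Field, ValidationError

class Dims(BaseModel):  # hypothetical stand-in for the invocation's width/height fields
    width: int = Field(default=512, multiple_of=8, gt=0)
    height: int = Field(default=512, multiple_of=8, gt=0)

print(Dims(width=520, height=576))   # ok: multiples of 8 (both would fail with multiple_of=64)

try:
    Dims(width=513)                  # not a multiple of 8
except ValidationError as err:
    print(err)
```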
@@ -13,13 +13,13 @@ from ...backend.model_management.model_manager
 from ...backend.util.devices import choose_torch_device, torch_dtype
 from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 from ...backend.image_util.seamless import configure_model_padding
-from ...backend.prompting.conditioning import get_uc_and_c_and_ec
 from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
 import numpy as np
 from ..services.image_storage import ImageType
 from .baseinvocation import BaseInvocation, InvocationContext
 from .image import ImageField, ImageOutput, build_image_output
+from .compel import ConditioningField
 from ...backend.stable_diffusion import PipelineIntermediateState
 from diffusers.schedulers import SchedulerMixin as Scheduler
 import diffusers
@@ -113,8 +113,8 @@ class NoiseInvocation(BaseInvocation):
 
     # Inputs
     seed: int = Field(ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", default_factory=random_seed)
-    width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", )
-    height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )
+    width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting noise", )
+    height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting noise", )
 
 
     # Schema customisation
@@ -138,14 +138,14 @@ class NoiseInvocation(BaseInvocation):
 
 # Text to image
 class TextToLatentsInvocation(BaseInvocation):
-    """Generates latents from a prompt."""
+    """Generates latents from conditionings."""
 
     type: Literal["t2l"] = "t2l"
 
     # Inputs
-    # TODO: consider making prompt optional to enable providing prompt through a link
     # fmt: off
-    prompt: Optional[str] = Field(description="The prompt to generate an image from")
+    positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
+    negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
     noise: Optional[LatentsField] = Field(description="The noise to use")
     steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
     cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
@@ -203,8 +203,10 @@ class TextToLatentsInvocation(BaseInvocation):
         return model
 
 
-    def get_conditioning_data(self, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
-        uc, c, extra_conditioning_info = get_uc_and_c_and_ec(self.prompt, model=model)
+    def get_conditioning_data(self, context: InvocationContext, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
+        c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
+        uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
+
         conditioning_data = ConditioningData(
             uc,
             c,
@@ -231,7 +233,7 @@ class TextToLatentsInvocation(BaseInvocation):
             self.dispatch_progress(context, source_node_id, state)
 
         model = self.get_model(context.services.model_manager)
-        conditioning_data = self.get_conditioning_data(model)
+        conditioning_data = self.get_conditioning_data(context, model)
 
         # TODO: Verify the noise is the right size
 
@@ -3,12 +3,11 @@ from invokeai.backend.model_management.model_manager import ModelManager
 
 def choose_model(model_manager: ModelManager, model_name: str):
     """Returns the default model if the `model_name` not a valid model, else returns the selected model."""
+    logger = model_manager.logger
     if model_manager.valid_model(model_name):
         model = model_manager.get_model(model_name)
     else:
         model = model_manager.get_model()
-        print(
-            f"* Warning: '{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead."
-        )
+        logger.warning(f"{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead.")
 
     return model
@@ -1,379 +0,0 @@
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein)

'''Invokeai configuration system.

Arguments and fields are taken from the pydantic definition of the
model. Defaults can be set by creating a yaml configuration file that
has top-level keys corresponding to an invocation name, a command, or
"globals" for global values such as `xformers_enabled`. Currently
graphs cannot be configured this way, but their constituents can be.

[file: invokeai.yaml]

globals:
  nsfw_checker: False
  max_loaded_models: 5

txt2img:
  steps: 20
  scheduler: k_heun
  width: 768

img2img:
  width: 1024
  height: 1024

The default name of the configuration file is `invokeai.yaml`, located
in INVOKEAI_ROOT. You can use any OmegaConf dictionary by passing it
to the config object at initialization time:

 omegaconf = OmegaConf.load('/tmp/init.yaml')
 conf = InvokeAIAppConfig(conf=omegaconf)

By default, InvokeAIAppConfig will parse the contents of argv at
initialization time. You may pass a list of strings in the optional
`argv` argument to use instead of the system argv:

 conf = InvokeAIAppConfig(arg=['--xformers_enabled'])

It is also possible to set a value at initialization time. This value
has highest priority.

 conf = InvokeAIAppConfig(xformers_enabled=True)

Any setting can be overwritten by setting an environment variable of
form: "INVOKEAI_<command>_<value>", as in:

 export INVOKEAI_txt2img_steps=30

Order of precedence (from highest):
 1) initialization options
 2) command line options
 3) environment variable options
 4) config file options
 5) pydantic defaults

Typical usage:

 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.invocations.generate import TextToImageInvocation

 # get global configuration and print its nsfw_checker value
 conf = InvokeAIAppConfig()
 print(conf.nsfw_checker)

 # get the text2image invocation and print its step value
 text2image = TextToImageInvocation()
 print(text2image.steps)

Computed properties:

The InvokeAIAppConfig object has a series of properties that
resolve paths relative to the runtime root directory. They each return
a Path object:

 root_path - path to InvokeAI root
 output_path - path to default outputs directory
 model_conf_path - path to models.yaml
 conf - alias for the above
 embedding_path - path to the embeddings directory
 lora_path - path to the LoRA directory

'''
import argparse
import os
import sys
from argparse import ArgumentParser
from omegaconf import OmegaConf, DictConfig
from pathlib import Path
from pydantic import BaseSettings, Field, parse_obj_as
from typing import Any, ClassVar, Dict, List, Literal, Union, get_origin, get_type_hints, get_args

INIT_FILE = Path('invokeai.yaml')
LEGACY_INIT_FILE = Path('invokeai.init')

class InvokeAISettings(BaseSettings):
    '''
    Runtime configuration settings in which default values are
    read from an omegaconf .yaml file.
    '''
    initconf : ClassVar[DictConfig] = None
    argparse_groups : ClassVar[Dict] = {}

    def parse_args(self, argv: list=sys.argv[1:]):
        parser = self.get_parser()
        opt, _ = parser.parse_known_args(argv)
        for name in self.__fields__:
            if name not in self._excluded():
                setattr(self, name, getattr(opt,name))

    @classmethod
    def add_parser_arguments(cls, parser):
        env_prefix = cls.Config.env_prefix if hasattr(cls.Config,'env_prefix') else 'INVOKEAI_'
        if 'type' in get_type_hints(cls):
            default_settings_stanza = get_args(get_type_hints(cls)['type'])[0]
        else:
            default_settings_stanza = 'globals'
        initconf = cls.initconf.get(default_settings_stanza) if cls.initconf and default_settings_stanza in cls.initconf else None

        fields = cls.__fields__
        cls.argparse_groups = {}
        for name, field in fields.items():
            if name not in cls._excluded():
                env_name = env_prefix+f'{cls.cmd_name()}_{name}'
                if initconf and name in initconf:
                    field.default = initconf.get(name)
                if env_name in os.environ:
                    field.default = os.environ[env_name]
                cls.add_field_argument(parser, name, field)


    @classmethod
    def cmd_name(self, command_field: str='type')->str:
        hints = get_type_hints(self)
        if command_field in hints:
            return get_args(hints[command_field])[0]
        else:
            return 'globals'

    @classmethod
    def get_parser(cls)->ArgumentParser:
        parser = ArgumentParser(
            prog=cls.cmd_name(),
            description=cls.__doc__,
        )
        cls.add_parser_arguments(parser)
        return parser

    @classmethod
    def add_subparser(cls, parser: argparse.ArgumentParser):
        parser.add_parser(cls.cmd_name(), help=cls.__doc__)

    @classmethod
    def _excluded(self)->List[str]:
        return ['type','initconf']

    class Config:
        env_file_encoding = 'utf-8'
        arbitrary_types_allowed = True
        env_prefix = 'INVOKEAI_'
        case_sensitive = True

        @classmethod
        def customise_sources(
            cls,
            init_settings,
            env_settings,
            file_secret_settings,
        ):
            return (
                init_settings,
                cls._omegaconf_settings_source,
                env_settings,
                file_secret_settings,
            )

        @classmethod
        def _omegaconf_settings_source(cls, settings: BaseSettings) -> dict[str, Any]:
            if initconf := InvokeAISettings.initconf:
                return initconf.get(settings.cmd_name(),{})
            else:
                return {}

    @classmethod
    def add_field_argument(cls, command_parser, name: str, field, default_override = None):
        default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
        if category := field.field_info.extra.get("category"):
            if category not in cls.argparse_groups:
                cls.argparse_groups[category] = command_parser.add_argument_group(category)
            argparse_group = cls.argparse_groups[category]
        else:
            argparse_group = command_parser

        if get_origin(field.type_) == Literal:
            allowed_values = get_args(field.type_)
            allowed_types = set()
            for val in allowed_values:
                allowed_types.add(type(val))
            allowed_types_list = list(allowed_types)
            field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore

            argparse_group.add_argument(
                f"--{name}",
                dest=name,
                type=field_type,
                default=default,
                choices=allowed_values,
                help=field.field_info.description,
            )
        else:
            argparse_group.add_argument(
                f"--{name}",
                dest=name,
                type=field.type_,
                default=default,
                action=argparse.BooleanOptionalAction if field.type_==bool else 'store',
                help=field.field_info.description,
            )

def _find_root()->Path:
    if os.environ.get("INVOKEAI_ROOT"):
        root = Path(os.environ.get("INVOKEAI_ROOT")).resolve()
    elif (
        os.environ.get("VIRTUAL_ENV")
        and (Path(os.environ.get("VIRTUAL_ENV"), "..", INIT_FILE).exists()
             or
             Path(os.environ.get("VIRTUAL_ENV"), "..", LEGACY_INIT_FILE).exists()
        )
    ):
        root = Path(os.environ.get("VIRTUAL_ENV"), "..").resolve()
    else:
        root = Path("~/invokeai").expanduser().resolve()
    return root

class InvokeAIAppConfig(InvokeAISettings):
    '''
    Application-wide settings.
    '''
    #fmt: off
    type: Literal["globals"] = "globals"
    root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
    infile : Path = Field(default=None, description='Path to a file of prompt commands to bulk generate from', category='Paths')
    conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
    model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
    outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
    embedding_dir : Path = Field(default='embeddings', description='Path to InvokeAI textual inversion aembeddings directory', category='Paths')
    lora_dir : Path = Field(default='loras', description='Path to InvokeAI LoRA model directory', category='Paths')
    autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.', category='Paths')
    gfpgan_model_dir : Path = Field(default="./models/gfpgan/GFPGANv1.4.pth", description='Path to GFPGAN models directory.', category='Paths')
    embeddings : bool = Field(default=True, description='Load contents of embeddings directory', category='Models')
    xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
    sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
    precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance')
    max_loaded_models : int = Field(default=2, gt=0, description="Maximum number of models to keep in memory for rapid switching", category='Memory/Performance')
    always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
    free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
    nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features')
    restore : bool = Field(default=True, description="Enable/disable face restoration code", category='Features')
    esrgan : bool = Field(default=True, description="Enable/disable upscaling code", category='Features')
    patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
    internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
    log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
    #fmt: on

    def __init__(self, conf: DictConfig = None, argv: List[str]=None, **kwargs):
        '''
        Initialize InvokeAIAppconfig.
        :param conf: alternate Omegaconf dictionary object
        :param argv: aternate sys.argv list
        :param **kwargs: attributes to initialize with
        '''
        super().__init__(**kwargs)

        # Set the runtime root directory. We parse command-line switches here
        # in order to pick up the --root_dir option.
        self.parse_args(argv)
        if not conf:
            try:
                conf = OmegaConf.load(self.root_dir / INIT_FILE)
            except:
                pass
        InvokeAISettings.initconf = conf

        # parse args again in order to pick up settings in configuration file
        self.parse_args(argv)

        # restore initialization values
        hints = get_type_hints(self)
        for k in kwargs:
            setattr(self,k,parse_obj_as(hints[k],kwargs[k]))

    @property
    def root_path(self)->Path:
        '''
        Path to the runtime root directory
        '''
        if self.root:
            return self.root.expanduser()
        else:
            return self.find_root()

    @property
    def root_dir(self)->Path:
        '''
        Alias for above.
        '''
        return self.root_path

    def _resolve(self,partial_path:Path)->Path:
        return (self.root_path / partial_path).resolve()

    @property
    def output_path(self)->Path:
        '''
        Path to defaults outputs directory.
        '''
        return self._resolve(self.outdir)

    @property
    def model_conf_path(self)->Path:
        '''
        Path to models configuration file.
        '''
        return self._resolve(self.conf_path)

    @property
    def conf(self)->Path:
        '''
        Path to models configuration file (alias for model_conf_path).
        '''
        return self.model_conf_path

    @property
    def embedding_path(self)->Path:
        '''
        Path to the textual inversion embeddings directory.
        '''
        return self._resolve(self.embedding_dir) if self.embedding_dir else None

    @property
    def lora_path(self)->Path:
        '''
        Path to the LoRA models directory.
        '''
        return self._resolve(self.lora_dir) if self.lora_dir else None

    @property
    def autoconvert_path(self)->Path:
        '''
        Path to the directory containing models to be imported automatically at startup.
        '''
        return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None

    @property
    def gfpgan_model_path(self)->Path:
        '''
        Path to the GFPGAN model.
        '''
        return self._resolve(self.gfpgan_model_dir) if self.gfpgan_model_dir else None

    @staticmethod
    def find_root()->Path:
        '''
        Choose the runtime root directory when not specified on command line or
        init file.
        '''
        return _find_root()

class InvokeAIWebConfig(InvokeAIAppConfig):
    '''
    Web-specific settings
    '''
    #fmt: off
    type : Literal["web"] = "web"
    allow_origins : List = Field(default=[], description="Allowed CORS origins", category='Cross-Origin Resource Sharing')
    allow_credentials : bool = Field(default=True, description="Allow CORS credentials", category='Cross-Origin Resource Sharing')
    allow_methods : List = Field(default=["*"], description="Methods allowed for CORS", category='Cross-Origin Resource Sharing')
    allow_headers : List = Field(default=["*"], description="Headers allowed for CORS", category='Cross-Origin Resource Sharing')
    host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
    port : int = Field(default=9090, description="Port to bind to", category='Web Server')
    #fmt: on
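As a quick illustration of the precedence rules documented in the removed module's docstring (initialization arguments beat command-line options, environment variables, the invokeai.yaml file, and pydantic defaults, in that order), a hedged sketch follows. The environment variable name follows the INVOKEAI_<command>_<field> pattern used in add_parser_arguments above and is an assumption for illustration; exact string-to-type coercion is not shown:

```python
# Hedged sketch of the removed configuration module's documented precedence:
# init arguments > command line > environment variables > invokeai.yaml > defaults.
import os
from invokeai.app.services.config import InvokeAIAppConfig  # module deleted in this comparison

os.environ["INVOKEAI_globals_max_loaded_models"] = "5"      # env override (beats invokeai.yaml)
conf = InvokeAIAppConfig(xformers_enabled=False, argv=[])   # init-time value: highest priority

print(conf.max_loaded_models)   # per the docstring, reflects the environment override
print(conf.xformers_enabled)    # False, from the initialization argument
print(conf.output_path)         # <root>/outputs, via the computed property
```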
@@ -1,4 +1,5 @@
from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation
from ..invocations.compel import CompelInvocation
from ..invocations.params import ParamIntInvocation
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
from .item_storage import ItemStorageABC
@@ -16,26 +17,32 @@ def create_text_to_image() -> LibraryGraph:
nodes={
'width': ParamIntInvocation(id='width', a=512),
'height': ParamIntInvocation(id='height', a=512),
'seed': ParamIntInvocation(id='seed', a=-1),
'3': NoiseInvocation(id='3'),
'4': TextToLatentsInvocation(id='4'),
'5': LatentsToImageInvocation(id='5')
'4': CompelInvocation(id='4'),
'5': CompelInvocation(id='5'),
'6': TextToLatentsInvocation(id='6'),
'7': LatentsToImageInvocation(id='7'),
},
edges=[
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='4', field='width')),
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='4', field='height')),
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='4', field='noise')),
Edge(source=EdgeConnection(node_id='4', field='latents'), destination=EdgeConnection(node_id='5', field='latents')),
Edge(source=EdgeConnection(node_id='seed', field='a'), destination=EdgeConnection(node_id='3', field='seed')),
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='6', field='noise')),
Edge(source=EdgeConnection(node_id='6', field='latents'), destination=EdgeConnection(node_id='7', field='latents')),
Edge(source=EdgeConnection(node_id='4', field='conditioning'), destination=EdgeConnection(node_id='6', field='positive_conditioning')),
Edge(source=EdgeConnection(node_id='5', field='conditioning'), destination=EdgeConnection(node_id='6', field='negative_conditioning')),
]
),
exposed_inputs=[
ExposedNodeInput(node_path='4', field='prompt', alias='prompt'),
ExposedNodeInput(node_path='4', field='prompt', alias='positive_prompt'),
ExposedNodeInput(node_path='5', field='prompt', alias='negative_prompt'),
ExposedNodeInput(node_path='width', field='a', alias='width'),
ExposedNodeInput(node_path='height', field='a', alias='height')
ExposedNodeInput(node_path='height', field='a', alias='height'),
ExposedNodeInput(node_path='seed', field='a', alias='seed'),
],
exposed_outputs=[
ExposedNodeOutput(node_path='5', field='image', alias='image')
ExposedNodeOutput(node_path='7', field='image', alias='image')
])

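The hunk above rewires the default text-to-image graph so the positive and negative prompts each pass through their own CompelInvocation before TextToLatents. The sketch below is only an illustration of the same node-and-edge wiring pattern in isolation; the absolute import paths are assumptions inferred from the relative imports shown above.

# Illustrative sketch (import paths assumed): a tiny graph routing a seed parameter into a noise node.
from invokeai.app.services.graph import Edge, EdgeConnection, Graph   # assumed path
from invokeai.app.invocations.params import ParamIntInvocation        # assumed path
from invokeai.app.invocations.latent import NoiseInvocation           # assumed path

g = Graph(
    nodes={
        "seed": ParamIntInvocation(id="seed", a=-1),
        "noise": NoiseInvocation(id="noise"),
    },
    edges=[
        # Connect the parameter node's output field 'a' to the noise node's 'seed' input.
        Edge(
            source=EdgeConnection(node_id="seed", field="a"),
            destination=EdgeConnection(node_id="noise", field="seed"),
        ),
    ],
)
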
@@ -1,6 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from typing import Any, Dict, TypedDict, Union
from typing import Any
from invokeai.app.api.models.images import ProgressImage
from invokeai.app.util.misc import get_timestamp

@@ -3,6 +3,7 @@
import copy
import itertools
import uuid
from types import NoneType
from typing import (
Annotated,
Any,
@@ -13,10 +14,9 @@ from typing import (
get_origin,
get_type_hints,
)
NoneType = type(None)

import networkx as nx
from pydantic import BaseModel, root_validator, validator, Extra
from pydantic import BaseModel, root_validator, validator
from pydantic.fields import Field

from ..invocations import *
@@ -25,7 +25,6 @@ from ..invocations.baseinvocation import (
BaseInvocationOutput,
InvocationContext,
)
from .config import InvokeAISettings

class EdgeConnection(BaseModel):
@@ -212,10 +211,9 @@ class CollectInvocation(BaseInvocation):
InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore

class Graph(InvokeAISettings):
id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
type: Literal["graph"] = "graph"

class Graph(BaseModel):
id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
description="The nodes in this graph", default_factory=dict
@@ -751,7 +749,7 @@ class GraphExecutionState(BaseModel):
"""Tracks the state of a graph execution"""

id: str = Field(description="The id of the execution state", default_factory=lambda: uuid.uuid4().__str__())

# TODO: Store a reference to the graph instead of the actual graph?
graph: Graph = Field(description="The graph being executed")

@@ -807,7 +805,7 @@ class GraphExecutionState(BaseModel):
]
}

def next(self) -> Union[BaseInvocation, None]:
def next(self) -> BaseInvocation | None:
"""Gets the next node ready to execute."""

# TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes
@@ -1156,7 +1154,7 @@ class ExposedNodeOutput(BaseModel):
field: str = Field(description="The field name of the output")
alias: str = Field(description="The alias of the output")

class LibraryGraph(InvokeAISettings):
class LibraryGraph(BaseModel):
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4)
graph: Graph = Field(description="The graph")
name: str = Field(description="The name of the graph")
@@ -1164,9 +1162,6 @@ class LibraryGraph(InvokeAISettings):
exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list)
exposed_outputs: list[ExposedNodeOutput] = Field(description="The outputs exposed by this graph", default_factory=list)

class Config:
extra='allow'

@validator('exposed_inputs', 'exposed_outputs')
def validate_exposed_aliases(cls, v):
if len(v) != len(set(i.alias for i in v)):

@@ -3,7 +3,7 @@
import time
from abc import ABC, abstractmethod
from queue import Queue
from typing import Union

from pydantic import BaseModel, Field

@@ -22,7 +22,7 @@ class InvocationQueueABC(ABC):
pass

@abstractmethod
def put(self, item: Union[InvocationQueueItem, None]) -> None:
def put(self, item: InvocationQueueItem | None) -> None:
pass

@abstractmethod
@@ -57,7 +57,7 @@ class MemoryInvocationQueue(InvocationQueueABC):

return item

def put(self, item: Union[InvocationQueueItem, None]) -> None:
def put(self, item: InvocationQueueItem | None) -> None:
self.__queue.put(item)

def cancel(self, graph_execution_state_id: str) -> None:

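Several hunks above swap Union[X, None] annotations for the PEP 604 form X | None, and graph.py now imports NoneType from the types module instead of defining it locally. Both of those spellings require Python 3.10 or newer, which matches the Python 3.10 floor in the install instructions. A tiny illustrative snippet of the equivalence:

# The two annotations below mean the same thing; the second needs Python >= 3.10.
from typing import Union

def next_legacy() -> Union[str, None]:   # also commonly spelled Optional[str]
    return None

def next_modern() -> str | None:         # PEP 604 union syntax
    return None
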
@@ -1,4 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team

from typing import types
from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.backend import ModelManager

@@ -8,7 +10,6 @@ from .image_storage import ImageStorageBase
from .restoration_services import RestorationServices
from .invocation_queue import InvocationQueueABC
from .item_storage import ItemStorageABC
from .config import InvokeAISettings

class InvocationServices:
"""Services that can be used by invocations"""
@@ -20,7 +21,6 @@ class InvocationServices:
queue: InvocationQueueABC
model_manager: ModelManager
restoration: RestorationServices
configuration: InvokeAISettings

# NOTE: we must forward-declare any types that include invocations, since invocations can use services
graph_library: ItemStorageABC["LibraryGraph"]
@@ -31,6 +31,7 @@ class InvocationServices:
self,
model_manager: ModelManager,
events: EventServiceBase,
logger: types.ModuleType,
latents: LatentsStorageBase,
images: ImageStorageBase,
metadata: MetadataServiceBase,
@@ -39,10 +40,10 @@ class InvocationServices:
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
processor: "InvocationProcessorABC",
restoration: RestorationServices,
configuration: InvokeAISettings=None,
):
self.model_manager = model_manager
self.events = events
self.logger = logger
self.latents = latents
self.images = images
self.metadata = metadata
@@ -51,4 +52,3 @@ class InvocationServices:
self.graph_execution_manager = graph_execution_manager
self.processor = processor
self.restoration = restoration
self.configuration = configuration

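The services container now carries a logger typed as types.ModuleType: the invokeai.backend.util.logging module itself is passed around and used in place of bare print() calls throughout these diffs. A short hedged sketch of that module-as-logger pattern, using only the call names the hunks themselves exercise:

# Sketch of the logging pattern used across these hunks. Only info/warning/error/
# debug/critical are shown because those are the calls that appear in the diffs;
# the module's internal formatting is not described here.
import invokeai.backend.util.logging as logger

logger.info("Model manager initialized")
logger.warning("Old initialization file found; this location is deprecated")
logger.error("Could not read metadata")
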
@@ -2,7 +2,6 @@

from abc import ABC
from threading import Event, Thread
from typing import Union

from ..invocations.baseinvocation import InvocationContext
from .graph import Graph, GraphExecutionState
@@ -22,7 +21,7 @@ class Invoker:

def invoke(
self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
) -> Union[str, None]:
) -> str | None:
"""Determines the next node to invoke and returns the id of the invoked node, or None if there are no nodes to execute"""

# Get the next invocation
@@ -45,12 +44,12 @@ class Invoker:

return invocation.id

def create_execution_state(self, graph: Union[Graph, None] = None) -> GraphExecutionState:
def create_execution_state(self, graph: Graph | None = None) -> GraphExecutionState:
"""Creates a new execution state for the given graph"""
new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
self.services.graph_execution_manager.set(new_state)
return new_state

def cancel(self, graph_execution_state_id: str) -> None:
"""Cancels the given execution state"""
self.services.queue.cancel(graph_execution_state_id)
@@ -72,18 +71,12 @@ class Invoker:
for service in vars(self.services):
self.__start_service(getattr(self.services, service))

for service in vars(self.services):
self.__start_service(getattr(self.services, service))

def stop(self) -> None:
"""Stops the invoker. A new invoker will have to be created to execute further."""
# First stop all services
for service in vars(self.services):
self.__stop_service(getattr(self.services, service))

for service in vars(self.services):
self.__stop_service(getattr(self.services, service))

self.services.queue.put(None)

@@ -4,7 +4,7 @@ import os
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
from typing import Dict, Union
from typing import Dict

import torch

@@ -56,7 +56,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
if name in self.__cache:
del self.__cache[name]

def __get_cache(self, name: str) -> Union[torch.Tensor,None]:
def __get_cache(self, name: str) -> torch.Tensor|None:
return None if name not in self.__cache else self.__cache[name]

def __set_cache(self, name: str, data: torch.Tensor):
@@ -90,4 +90,4 @@ class DiskLatentsStorage(LatentsStorageBase):

def get_path(self, name: str) -> str:
return os.path.join(self.__output_folder, name)

@@ -5,22 +5,24 @@ from argparse import Namespace
from invokeai.backend import Args
from omegaconf import OmegaConf
from pathlib import Path
from typing import types

import invokeai.version
from .config import InvokeAISettings
from ...backend import ModelManager
from ...backend.util import choose_precision, choose_torch_device
from ...backend import Globals

# TODO: Replace with an abstract class base ModelManagerBase
def get_model_manager(config:InvokeAISettings) -> ModelManager:
model_config = config.model_conf_path
if not model_config.exists():
report_model_error(
config, FileNotFoundError(f"The file {model_config} could not be found.")
)
def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
if not config.conf:
config_file = os.path.join(Globals.root, "configs", "models.yaml")
if not os.path.exists(config_file):
report_model_error(
config, FileNotFoundError(f"The file {config_file} could not be found."), logger
)

print(f">> {invokeai.version.__app_name__}, version {invokeai.version.__version__}")
print(f'>> InvokeAI runtime directory is "{config.root_dir}"')
logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
logger.info(f'InvokeAI runtime directory is "{Globals.root}"')

# these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported
@@ -30,7 +32,20 @@ def get_model_manager(config:InvokeAISettings) -> ModelManager:
import diffusers

diffusers.logging.set_verbosity_error()
embedding_path = config.embedding_path

# normalize the config directory relative to root
if not os.path.isabs(config.conf):
config.conf = os.path.normpath(os.path.join(Globals.root, config.conf))

if config.embeddings:
if not os.path.isabs(config.embedding_path):
embedding_path = os.path.normpath(
os.path.join(Globals.root, config.embedding_path)
)
else:
embedding_path = config.embedding_path
else:
embedding_path = None

# migrate legacy models
ModelManager.migrate_models()
@@ -43,35 +58,38 @@ def get_model_manager(config:InvokeAISettings) -> ModelManager:
else choose_precision(device)

model_manager = ModelManager(
OmegaConf.load(model_config),
OmegaConf.load(config.conf),
precision=precision,
device_type=device,
max_loaded_models=config.max_loaded_models,
embedding_path = embedding_path,
embedding_path = Path(embedding_path),
logger = logger,
)
except (FileNotFoundError, TypeError, AssertionError) as e:
report_model_error(config, e)
report_model_error(config, e, logger)
except (IOError, KeyError) as e:
print(f"{e}. Aborting.")
logger.error(f"{e}. Aborting.")
sys.exit(-1)

# try to autoconvert new models
# autoimport new .ckpt files
if config.autoconvert_path:
model_manager.heuristic_import(
config.autoconvert_path,
if path := config.autoconvert:
model_manager.autoconvert_weights(
conf_path=config.conf,
weights_directory=path,
)
logger.info('Model manager initialized')
return model_manager

def report_model_error(opt: Namespace, e: Exception):
print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
print(
"** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
def report_model_error(opt: Namespace, e: Exception, logger: types.ModuleType):
logger.error(f'An error occurred while attempting to initialize the model: "{str(e)}"')
logger.error(
"This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
)
yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
if yes_to_all:
print(
"** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
logger.warning(
"Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
)
else:
response = input(
@@ -80,13 +98,12 @@ def report_model_error(opt: Namespace, e: Exception):
if response.startswith(("n", "N")):
return

print("invokeai-configure is launching....\n")
logger.info("invokeai-configure is launching....\n")

# Match arguments that were set on the CLI
# only the arguments accepted by the configuration script are parsed
root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
config = ["--config", opt.conf] if opt.conf is not None else []
previous_config = sys.argv
sys.argv = ["invokeai-configure"]
sys.argv.extend(root_dir)
sys.argv.extend(config.to_dict())

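The rewritten get_model_manager() above resolves user-supplied paths against the runtime root before loading them, and uses a walrus assignment for the optional autoconvert directory. A small hedged sketch of that path-normalization idiom in isolation; the paths below are placeholders, not InvokeAI defaults:

# Standalone sketch of the "resolve relative paths against the runtime root" idiom.
import os

def resolve_against_root(path: str | None, root: str) -> str | None:
    """Return an absolute path, treating a relative path as relative to root."""
    if path is None:
        return None
    return path if os.path.isabs(path) else os.path.normpath(os.path.join(root, path))

conf = resolve_against_root("configs/models.yaml", "/home/user/invokeai")  # placeholder paths
if autoconvert_dir := resolve_against_root(None, "/home/user/invokeai"):
    # Only reached when an autoconvert directory was actually configured.
    print(f"would scan {autoconvert_dir} for .ckpt files")
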
@@ -1,5 +1,5 @@
import traceback
from threading import Event, Thread
from threading import Event, Thread, BoundedSemaphore

from ..invocations.baseinvocation import InvocationContext
from .invocation_queue import InvocationQueueItem
@@ -10,8 +10,11 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
__invoker_thread: Thread
__stop_event: Event
__invoker: Invoker
__threadLimit: BoundedSemaphore

def start(self, invoker) -> None:
# if we do want multithreading at some point, we could make this configurable
self.__threadLimit = BoundedSemaphore(1)
self.__invoker = invoker
self.__stop_event = Event()
self.__invoker_thread = Thread(
@@ -20,7 +23,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
kwargs=dict(stop_event=self.__stop_event),
)
self.__invoker_thread.daemon = (
True # TODO: probably better to just not use threads?
True # TODO: make async and do not use threads
)
self.__invoker_thread.start()

@@ -29,6 +32,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):

def __process(self, stop_event: Event):
try:
self.__threadLimit.acquire()
while not stop_event.is_set():
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
if not queue_item: # Probably stopping
@@ -110,7 +114,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
)

pass

# Check queue to see if this is canceled, and skip if so
if self.__invoker.services.queue.is_canceled(
graph_execution_state.id
@@ -127,4 +131,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
)

except KeyboardInterrupt:
... # Log something?
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor
finally:
self.__threadLimit.release()

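The processor hunk above caps concurrency with a BoundedSemaphore(1) acquired at the top of the worker loop and released in a finally block. Below is a self-contained, hedged sketch of the same pattern with a plain queue; it is illustrative only and does not reuse any InvokeAI classes:

# Standalone sketch of the BoundedSemaphore pattern: one daemon worker thread drains
# a queue, and the semaphore bounds how many workers may run at once.
import queue
from threading import BoundedSemaphore, Event, Thread

work_queue: "queue.Queue[int | None]" = queue.Queue()
stop_event = Event()
thread_limit = BoundedSemaphore(1)  # raise this to allow more concurrent workers

def process(stop_event: Event) -> None:
    try:
        thread_limit.acquire()
        while not stop_event.is_set():
            item = work_queue.get()
            if item is None:  # sentinel value: shut the worker down
                break
            print(f"processed {item}")
    finally:
        thread_limit.release()  # always release, even if processing raised

worker = Thread(target=process, kwargs=dict(stop_event=stop_event), daemon=True)
worker.start()
work_queue.put(1)
work_queue.put(None)
worker.join()
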
@@ -1,6 +1,7 @@
|
||||
import sys
|
||||
import traceback
|
||||
import torch
|
||||
from typing import types
|
||||
from ...backend.restoration import Restoration
|
||||
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
|
||||
|
||||
@@ -10,7 +11,7 @@ from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
|
||||
class RestorationServices:
|
||||
'''Face restoration and upscaling'''
|
||||
|
||||
def __init__(self,args):
|
||||
def __init__(self,args,logger:types.ModuleType):
|
||||
try:
|
||||
gfpgan, codeformer, esrgan = None, None, None
|
||||
if args.restore or args.esrgan:
|
||||
@@ -20,20 +21,22 @@ class RestorationServices:
|
||||
args.gfpgan_model_path
|
||||
)
|
||||
else:
|
||||
print(">> Face restoration disabled")
|
||||
logger.info("Face restoration disabled")
|
||||
if args.esrgan:
|
||||
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
|
||||
else:
|
||||
print(">> Upscaling disabled")
|
||||
logger.info("Upscaling disabled")
|
||||
else:
|
||||
print(">> Face restoration and upscaling disabled")
|
||||
logger.info("Face restoration and upscaling disabled")
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print(">> You may need to install the ESRGAN and/or GFPGAN modules")
|
||||
logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
|
||||
self.device = torch.device(choose_torch_device())
|
||||
self.gfpgan = gfpgan
|
||||
self.codeformer = codeformer
|
||||
self.esrgan = esrgan
|
||||
self.logger = logger
|
||||
self.logger.info('Face restoration initialized')
|
||||
|
||||
# note that this one method does gfpgan and codepath reconstruction, as well as
|
||||
# esrgan upscaling
|
||||
@@ -58,15 +61,15 @@ class RestorationServices:
|
||||
if self.gfpgan is not None or self.codeformer is not None:
|
||||
if facetool == "gfpgan":
|
||||
if self.gfpgan is None:
|
||||
print(
|
||||
">> GFPGAN not found. Face restoration is disabled."
|
||||
self.logger.info(
|
||||
"GFPGAN not found. Face restoration is disabled."
|
||||
)
|
||||
else:
|
||||
image = self.gfpgan.process(image, strength, seed)
|
||||
if facetool == "codeformer":
|
||||
if self.codeformer is None:
|
||||
print(
|
||||
">> CodeFormer not found. Face restoration is disabled."
|
||||
self.logger.info(
|
||||
"CodeFormer not found. Face restoration is disabled."
|
||||
)
|
||||
else:
|
||||
cf_device = (
|
||||
@@ -80,7 +83,7 @@ class RestorationServices:
|
||||
fidelity=codeformer_fidelity,
|
||||
)
|
||||
else:
|
||||
print(">> Face Restoration is disabled.")
|
||||
self.logger.info("Face Restoration is disabled.")
|
||||
if upscale is not None:
|
||||
if self.esrgan is not None:
|
||||
if len(upscale) < 2:
|
||||
@@ -93,10 +96,10 @@ class RestorationServices:
|
||||
denoise_str=upscale_denoise_str,
|
||||
)
|
||||
else:
|
||||
print(">> ESRGAN is disabled. Image not upscaled.")
|
||||
self.logger.info("ESRGAN is disabled. Image not upscaled.")
|
||||
except Exception as e:
|
||||
print(
|
||||
f">> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
||||
self.logger.info(
|
||||
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
||||
)
|
||||
|
||||
if image_callback is not None:
|
||||
|
||||
@@ -96,6 +96,7 @@ from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import invokeai.version
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.image_util import retrieve_metadata
|
||||
|
||||
from .globals import Globals
|
||||
@@ -189,7 +190,7 @@ class Args(object):
|
||||
print(f"{APP_NAME} {APP_VERSION}")
|
||||
sys.exit(0)
|
||||
|
||||
print("* Initializing, be patient...")
|
||||
logger.info("Initializing, be patient...")
|
||||
Globals.root = Path(os.path.abspath(switches.root_dir or Globals.root))
|
||||
Globals.try_patchmatch = switches.patchmatch
|
||||
|
||||
@@ -197,14 +198,13 @@ class Args(object):
|
||||
initfile = os.path.expanduser(os.path.join(Globals.root, Globals.initfile))
|
||||
legacyinit = os.path.expanduser("~/.invokeai")
|
||||
if os.path.exists(initfile):
|
||||
print(
|
||||
f">> Initialization file {initfile} found. Loading...",
|
||||
file=sys.stderr,
|
||||
logger.info(
|
||||
f"Initialization file {initfile} found. Loading...",
|
||||
)
|
||||
sysargs.insert(0, f"@{initfile}")
|
||||
elif os.path.exists(legacyinit):
|
||||
print(
|
||||
f">> WARNING: Old initialization file found at {legacyinit}. This location is deprecated. Please move it to {Globals.root}/invokeai.init."
|
||||
logger.warning(
|
||||
f"Old initialization file found at {legacyinit}. This location is deprecated. Please move it to {Globals.root}/invokeai.init."
|
||||
)
|
||||
sysargs.insert(0, f"@{legacyinit}")
|
||||
Globals.log_tokenization = self._arg_parser.parse_args(
|
||||
@@ -214,7 +214,7 @@ class Args(object):
|
||||
self._arg_switches = self._arg_parser.parse_args(sysargs)
|
||||
return self._arg_switches
|
||||
except Exception as e:
|
||||
print(f"An exception has occurred: {e}")
|
||||
logger.error(f"An exception has occurred: {e}")
|
||||
return None
|
||||
|
||||
def parse_cmd(self, cmd_string):
|
||||
@@ -1154,7 +1154,7 @@ class Args(object):
|
||||
|
||||
|
||||
def format_metadata(**kwargs):
|
||||
print("format_metadata() is deprecated. Please use metadata_dumps()")
|
||||
logger.warning("format_metadata() is deprecated. Please use metadata_dumps()")
|
||||
return metadata_dumps(kwargs)
|
||||
|
||||
|
||||
@@ -1326,7 +1326,7 @@ def metadata_loads(metadata) -> list:
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
print(">> could not read metadata", file=sys.stderr)
|
||||
logger.error("Could not read metadata")
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
return results
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
from omegaconf import OmegaConf
|
||||
from pathlib import Path
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from .args import metadata_from_png
|
||||
from .generator import infill_methods
|
||||
from .globals import Globals, global_cache_dir
|
||||
@@ -195,12 +196,12 @@ class Generate:
|
||||
# device to Generate(). However the device was then ignored, so
|
||||
# it wasn't actually doing anything. This logic could be reinstated.
|
||||
self.device = torch.device(choose_torch_device())
|
||||
print(f">> Using device_type {self.device.type}")
|
||||
logger.info(f"Using device_type {self.device.type}")
|
||||
if full_precision:
|
||||
if self.precision != "auto":
|
||||
raise ValueError("Remove --full_precision / -F if using --precision")
|
||||
print("Please remove deprecated --full_precision / -F")
|
||||
print("If auto config does not work you can use --precision=float32")
|
||||
logger.warning("Please remove deprecated --full_precision / -F")
|
||||
logger.warning("If auto config does not work you can use --precision=float32")
|
||||
self.precision = "float32"
|
||||
if self.precision == "auto":
|
||||
self.precision = choose_precision(self.device)
|
||||
@@ -208,13 +209,13 @@ class Generate:
|
||||
|
||||
if is_xformers_available():
|
||||
if torch.cuda.is_available() and not Globals.disable_xformers:
|
||||
print(">> xformers memory-efficient attention is available and enabled")
|
||||
logger.info("xformers memory-efficient attention is available and enabled")
|
||||
else:
|
||||
print(
|
||||
">> xformers memory-efficient attention is available but disabled"
|
||||
logger.info(
|
||||
"xformers memory-efficient attention is available but disabled"
|
||||
)
|
||||
else:
|
||||
print(">> xformers not installed")
|
||||
logger.info("xformers not installed")
|
||||
|
||||
# model caching system for fast switching
|
||||
self.model_manager = ModelManager(
|
||||
@@ -229,8 +230,8 @@ class Generate:
|
||||
fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
|
||||
model = model or fallback
|
||||
if not self.model_manager.valid_model(model):
|
||||
print(
|
||||
f'** "{model}" is not a known model name; falling back to {fallback}.'
|
||||
logger.warning(
|
||||
f'"{model}" is not a known model name; falling back to {fallback}.'
|
||||
)
|
||||
model = None
|
||||
self.model_name = model or fallback
|
||||
@@ -246,10 +247,10 @@ class Generate:
|
||||
|
||||
# load safety checker if requested
|
||||
if safety_checker:
|
||||
print(">> Initializing NSFW checker")
|
||||
logger.info("Initializing NSFW checker")
|
||||
self.safety_checker = SafetyChecker(self.device)
|
||||
else:
|
||||
print(">> NSFW checker is disabled")
|
||||
logger.info("NSFW checker is disabled")
|
||||
|
||||
def prompt2png(self, prompt, outdir, **kwargs):
|
||||
"""
|
||||
@@ -567,7 +568,7 @@ class Generate:
|
||||
self.clear_cuda_cache()
|
||||
|
||||
if catch_interrupts:
|
||||
print("**Interrupted** Partial results will be returned.")
|
||||
logger.warning("Interrupted** Partial results will be returned.")
|
||||
else:
|
||||
raise KeyboardInterrupt
|
||||
except RuntimeError:
|
||||
@@ -575,11 +576,11 @@ class Generate:
|
||||
self.clear_cuda_cache()
|
||||
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print(">> Could not generate image.")
|
||||
logger.info("Could not generate image.")
|
||||
|
||||
toc = time.time()
|
||||
print("\n>> Usage stats:")
|
||||
print(f">> {len(results)} image(s) generated in", "%4.2fs" % (toc - tic))
|
||||
logger.info("Usage stats:")
|
||||
logger.info(f"{len(results)} image(s) generated in "+"%4.2fs" % (toc - tic))
|
||||
self.print_cuda_stats()
|
||||
return results
|
||||
|
||||
@@ -609,16 +610,16 @@ class Generate:
|
||||
def print_cuda_stats(self):
|
||||
if self._has_cuda():
|
||||
self.gather_cuda_stats()
|
||||
print(
|
||||
">> Max VRAM used for this generation:",
|
||||
"%4.2fG." % (self.max_memory_allocated / 1e9),
|
||||
"Current VRAM utilization:",
|
||||
"%4.2fG" % (self.memory_allocated / 1e9),
|
||||
logger.info(
|
||||
"Max VRAM used for this generation: "+
|
||||
"%4.2fG. " % (self.max_memory_allocated / 1e9)+
|
||||
"Current VRAM utilization: "+
|
||||
"%4.2fG" % (self.memory_allocated / 1e9)
|
||||
)
|
||||
|
||||
print(
|
||||
">> Max VRAM used since script start: ",
|
||||
"%4.2fG" % (self.session_peakmem / 1e9),
|
||||
logger.info(
|
||||
"Max VRAM used since script start: " +
|
||||
"%4.2fG" % (self.session_peakmem / 1e9)
|
||||
)
|
||||
|
||||
# this needs to be generalized to all sorts of postprocessors, which should be wrapped
|
||||
@@ -647,7 +648,7 @@ class Generate:
|
||||
seed = random.randrange(0, np.iinfo(np.uint32).max)
|
||||
|
||||
prompt = opt.prompt or args.prompt or ""
|
||||
print(f'>> using seed {seed} and prompt "{prompt}" for {image_path}')
|
||||
logger.info(f'using seed {seed} and prompt "{prompt}" for {image_path}')
|
||||
|
||||
# try to reuse the same filename prefix as the original file.
|
||||
# we take everything up to the first period
|
||||
@@ -696,8 +697,8 @@ class Generate:
|
||||
try:
|
||||
extend_instructions[direction] = int(pixels)
|
||||
except ValueError:
|
||||
print(
|
||||
'** invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"'
|
||||
logger.warning(
|
||||
'invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"'
|
||||
)
|
||||
|
||||
opt.seed = seed
|
||||
@@ -720,8 +721,8 @@ class Generate:
|
||||
# fetch the metadata from the image
|
||||
generator = self.select_generator(embiggen=True)
|
||||
opt.strength = opt.embiggen_strength or 0.40
|
||||
print(
|
||||
f">> Setting img2img strength to {opt.strength} for happy embiggening"
|
||||
logger.info(
|
||||
f"Setting img2img strength to {opt.strength} for happy embiggening"
|
||||
)
|
||||
generator.generate(
|
||||
prompt,
|
||||
@@ -748,12 +749,12 @@ class Generate:
|
||||
return restorer.process(opt, args, image_callback=callback, prefix=prefix)
|
||||
|
||||
elif tool is None:
|
||||
print(
|
||||
"* please provide at least one postprocessing option, such as -G or -U"
|
||||
logger.warning(
|
||||
"please provide at least one postprocessing option, such as -G or -U"
|
||||
)
|
||||
return None
|
||||
else:
|
||||
print(f"* postprocessing tool {tool} is not yet supported")
|
||||
logger.warning(f"postprocessing tool {tool} is not yet supported")
|
||||
return None
|
||||
|
||||
def select_generator(
|
||||
@@ -797,8 +798,8 @@ class Generate:
|
||||
image = self._load_img(img)
|
||||
|
||||
if image.width < self.width and image.height < self.height:
|
||||
print(
|
||||
f">> WARNING: img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions"
|
||||
logger.warning(
|
||||
f"img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions"
|
||||
)
|
||||
|
||||
# if image has a transparent area and no mask was provided, then try to generate mask
|
||||
@@ -809,8 +810,8 @@ class Generate:
|
||||
if (image.width * image.height) > (
|
||||
self.width * self.height
|
||||
) and self.size_matters:
|
||||
print(
|
||||
">> This input is larger than your defaults. If you run out of memory, please use a smaller image."
|
||||
logger.info(
|
||||
"This input is larger than your defaults. If you run out of memory, please use a smaller image."
|
||||
)
|
||||
self.size_matters = False
|
||||
|
||||
@@ -891,11 +892,11 @@ class Generate:
|
||||
try:
|
||||
model_data = cache.get_model(model_name)
|
||||
except Exception as e:
|
||||
print(f"** model {model_name} could not be loaded: {str(e)}")
|
||||
logger.warning(f"model {model_name} could not be loaded: {str(e)}")
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
if previous_model_name is None:
|
||||
raise e
|
||||
print("** trying to reload previous model")
|
||||
logger.warning("trying to reload previous model")
|
||||
model_data = cache.get_model(previous_model_name) # load previous
|
||||
if model_data is None:
|
||||
raise e
|
||||
@@ -962,15 +963,15 @@ class Generate:
|
||||
if self.gfpgan is not None or self.codeformer is not None:
|
||||
if facetool == "gfpgan":
|
||||
if self.gfpgan is None:
|
||||
print(
|
||||
">> GFPGAN not found. Face restoration is disabled."
|
||||
logger.info(
|
||||
"GFPGAN not found. Face restoration is disabled."
|
||||
)
|
||||
else:
|
||||
image = self.gfpgan.process(image, strength, seed)
|
||||
if facetool == "codeformer":
|
||||
if self.codeformer is None:
|
||||
print(
|
||||
">> CodeFormer not found. Face restoration is disabled."
|
||||
logger.info(
|
||||
"CodeFormer not found. Face restoration is disabled."
|
||||
)
|
||||
else:
|
||||
cf_device = (
|
||||
@@ -984,7 +985,7 @@ class Generate:
|
||||
fidelity=codeformer_fidelity,
|
||||
)
|
||||
else:
|
||||
print(">> Face Restoration is disabled.")
|
||||
logger.info("Face Restoration is disabled.")
|
||||
if upscale is not None:
|
||||
if self.esrgan is not None:
|
||||
if len(upscale) < 2:
|
||||
@@ -997,10 +998,10 @@ class Generate:
|
||||
denoise_str=upscale_denoise_str,
|
||||
)
|
||||
else:
|
||||
print(">> ESRGAN is disabled. Image not upscaled.")
|
||||
logger.info("ESRGAN is disabled. Image not upscaled.")
|
||||
except Exception as e:
|
||||
print(
|
||||
f">> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
||||
logger.info(
|
||||
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
||||
)
|
||||
|
||||
if image_callback is not None:
|
||||
@@ -1066,17 +1067,17 @@ class Generate:
|
||||
if self.sampler_name in scheduler_map:
|
||||
sampler_class = scheduler_map[self.sampler_name]
|
||||
msg = (
|
||||
f">> Setting Sampler to {self.sampler_name} ({sampler_class.__name__})"
|
||||
f"Setting Sampler to {self.sampler_name} ({sampler_class.__name__})"
|
||||
)
|
||||
self.sampler = sampler_class.from_config(self.model.scheduler.config)
|
||||
else:
|
||||
msg = (
|
||||
f">> Unsupported Sampler: {self.sampler_name} "
|
||||
f" Unsupported Sampler: {self.sampler_name} "+
|
||||
f"Defaulting to {default}"
|
||||
)
|
||||
self.sampler = default
|
||||
|
||||
print(msg)
|
||||
logger.info(msg)
|
||||
|
||||
if not hasattr(self.sampler, "uses_inpainting_model"):
|
||||
# FIXME: terrible kludge!
|
||||
@@ -1085,17 +1086,17 @@ class Generate:
|
||||
def _load_img(self, img) -> Image:
|
||||
if isinstance(img, Image.Image):
|
||||
image = img
|
||||
print(f">> using provided input image of size {image.width}x{image.height}")
|
||||
logger.info(f"using provided input image of size {image.width}x{image.height}")
|
||||
elif isinstance(img, str):
|
||||
assert os.path.exists(img), f">> {img}: File not found"
|
||||
assert os.path.exists(img), f"{img}: File not found"
|
||||
|
||||
image = Image.open(img)
|
||||
print(
|
||||
f">> loaded input image of size {image.width}x{image.height} from {img}"
|
||||
logger.info(
|
||||
f"loaded input image of size {image.width}x{image.height} from {img}"
|
||||
)
|
||||
else:
|
||||
image = Image.open(img)
|
||||
print(f">> loaded input image of size {image.width}x{image.height}")
|
||||
logger.info(f"loaded input image of size {image.width}x{image.height}")
|
||||
image = ImageOps.exif_transpose(image)
|
||||
return image
|
||||
|
||||
@@ -1183,14 +1184,14 @@ class Generate:
|
||||
|
||||
def _transparency_check_and_warning(self, image, mask, force_outpaint=False):
|
||||
if not mask:
|
||||
print(
|
||||
">> Initial image has transparent areas. Will inpaint in these regions."
|
||||
logger.info(
|
||||
"Initial image has transparent areas. Will inpaint in these regions."
|
||||
)
|
||||
if (not force_outpaint) and self._check_for_erasure(image):
|
||||
print(
|
||||
">> WARNING: Colors underneath the transparent region seem to have been erased.\n",
|
||||
">> Inpainting will be suboptimal. Please preserve the colors when making\n",
|
||||
">> a transparency mask, or provide mask explicitly using --init_mask (-M).",
|
||||
if (not force_outpaint) and self._check_for_erasure(image):
|
||||
logger.info(
|
||||
"Colors underneath the transparent region seem to have been erased.\n" +
|
||||
"Inpainting will be suboptimal. Please preserve the colors when making\n" +
|
||||
"a transparency mask, or provide mask explicitly using --init_mask (-M)."
|
||||
)
|
||||
|
||||
def _squeeze_image(self, image):
|
||||
@@ -1201,11 +1202,11 @@ class Generate:
|
||||
|
||||
def _fit_image(self, image, max_dimensions):
|
||||
w, h = max_dimensions
|
||||
print(f">> image will be resized to fit inside a box {w}x{h} in size.")
|
||||
logger.info(f"image will be resized to fit inside a box {w}x{h} in size.")
|
||||
# note that InitImageResizer does the multiple of 64 truncation internally
|
||||
image = InitImageResizer(image).resize(width=w, height=h)
|
||||
print(
|
||||
f">> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}"
|
||||
logger.info(
|
||||
f"after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}"
|
||||
)
|
||||
return image
|
||||
|
||||
@@ -1216,8 +1217,8 @@ class Generate:
|
||||
) # resize to integer multiple of 64
|
||||
if h != height or w != width:
|
||||
if log:
|
||||
print(
|
||||
f">> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}"
|
||||
logger.info(
|
||||
f"Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}"
|
||||
)
|
||||
height = h
|
||||
width = w
|
||||
|
||||
@@ -25,6 +25,7 @@ from typing import Callable, List, Iterator, Optional, Type
|
||||
from dataclasses import dataclass, field
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from ..image_util import configure_model_padding
|
||||
from ..util.util import rand_perlin_2d
|
||||
from ..safety_checker import SafetyChecker
|
||||
@@ -372,7 +373,7 @@ class Generator:
|
||||
try:
|
||||
x_T = self.get_noise(width, height)
|
||||
except:
|
||||
print("** An error occurred while getting initial noise **")
|
||||
logger.error("An error occurred while getting initial noise")
|
||||
print(traceback.format_exc())
|
||||
|
||||
# Pass on the seed in case a layer beneath us needs to generate noise on its own.
|
||||
@@ -607,7 +608,7 @@ class Generator:
|
||||
image = self.sample_to_image(sample)
|
||||
dirname = os.path.dirname(filepath) or "."
|
||||
if not os.path.exists(dirname):
|
||||
print(f"** creating directory {dirname}")
|
||||
logger.info(f"creating directory {dirname}")
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
image.save(filepath, "PNG")
|
||||
|
||||
|
||||
@@ -8,10 +8,11 @@ import torch
|
||||
from PIL import Image
|
||||
from tqdm import trange
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from .base import Generator
|
||||
from .img2img import Img2Img
|
||||
|
||||
|
||||
class Embiggen(Generator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
@@ -72,22 +73,22 @@ class Embiggen(Generator):
|
||||
embiggen = [1.0] # If not specified, assume no scaling
|
||||
elif embiggen[0] < 0:
|
||||
embiggen[0] = 1.0
|
||||
print(
|
||||
">> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
|
||||
logger.warning(
|
||||
"Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
|
||||
)
|
||||
if len(embiggen) < 2:
|
||||
embiggen.append(0.75)
|
||||
elif embiggen[1] > 1.0 or embiggen[1] < 0:
|
||||
embiggen[1] = 0.75
|
||||
print(
|
||||
">> Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
|
||||
logger.warning(
|
||||
"Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
|
||||
)
|
||||
if len(embiggen) < 3:
|
||||
embiggen.append(0.25)
|
||||
elif embiggen[2] < 0:
|
||||
embiggen[2] = 0.25
|
||||
print(
|
||||
">> Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
|
||||
logger.warning(
|
||||
"Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
|
||||
)
|
||||
|
||||
# Convert tiles from their user-freindly count-from-one to count-from-zero, because we need to do modulo math
|
||||
@@ -97,8 +98,8 @@ class Embiggen(Generator):
|
||||
embiggen_tiles.sort()
|
||||
|
||||
if strength >= 0.5:
|
||||
print(
|
||||
f"* WARNING: Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
|
||||
logger.warning(
|
||||
f"Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
|
||||
)
|
||||
|
||||
# Prep img2img generator, since we wrap over it
|
||||
@@ -121,8 +122,8 @@ class Embiggen(Generator):
|
||||
from ..restoration.realesrgan import ESRGAN
|
||||
|
||||
esrgan = ESRGAN()
|
||||
print(
|
||||
f">> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
|
||||
logger.info(
|
||||
f"ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
|
||||
)
|
||||
if embiggen[0] > 2:
|
||||
initsuperimage = esrgan.process(
|
||||
@@ -312,10 +313,10 @@ class Embiggen(Generator):
|
||||
def make_image():
|
||||
# Make main tiles -------------------------------------------------
|
||||
if embiggen_tiles:
|
||||
print(f">> Making {len(embiggen_tiles)} Embiggen tiles...")
|
||||
logger.info(f"Making {len(embiggen_tiles)} Embiggen tiles...")
|
||||
else:
|
||||
print(
|
||||
f">> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
|
||||
logger.info(
|
||||
f"Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
|
||||
)
|
||||
|
||||
emb_tile_store = []
|
||||
@@ -361,11 +362,11 @@ class Embiggen(Generator):
|
||||
# newinitimage.save(newinitimagepath)
|
||||
|
||||
if embiggen_tiles:
|
||||
print(
|
||||
logger.debug(
|
||||
f"Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)"
|
||||
)
|
||||
else:
|
||||
print(f"Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles")
|
||||
logger.debug(f"Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles")
|
||||
|
||||
# create a torch tensor from an Image
|
||||
newinitimage = np.array(newinitimage).astype(np.float32) / 255.0
|
||||
@@ -547,8 +548,8 @@ class Embiggen(Generator):
|
||||
# Layer tile onto final image
|
||||
outputsuperimage.alpha_composite(intileimage, (left, top))
|
||||
else:
|
||||
print(
|
||||
"Error: could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
|
||||
logger.error(
|
||||
"Could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
|
||||
)
|
||||
|
||||
# after internal loops and patching up return Embiggen image
|
||||
|
||||
@@ -14,6 +14,8 @@ from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeli
|
||||
from ..stable_diffusion.diffusers_pipeline import ConditioningData
|
||||
from ..stable_diffusion.diffusers_pipeline import trim_to_multiple_of
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
class Txt2Img2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
@@ -77,8 +79,8 @@ class Txt2Img2Img(Generator):
|
||||
# the message below is accurate.
|
||||
init_width = first_pass_latent_output.size()[3] * self.downsampling_factor
|
||||
init_height = first_pass_latent_output.size()[2] * self.downsampling_factor
|
||||
print(
|
||||
f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
|
||||
logger.info(
|
||||
f"Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
|
||||
)
|
||||
|
||||
# resizing
|
||||
|
||||
@@ -16,7 +16,6 @@ import os.path as osp
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from pydantic import BaseSettings
|
||||
|
||||
Globals = Namespace()
|
||||
|
||||
@@ -121,15 +120,3 @@ def global_cache_dir(subdir: Union[str, Path] = "") -> Path:
|
||||
return Path(home, subdir)
|
||||
else:
|
||||
return Path(Globals.root, "models", subdir)
|
||||
|
||||
def copy_conf_to_globals(conf: Union[dict,BaseSettings]):
|
||||
'''
|
||||
Given a dict or dict-like object, copy its keys and
|
||||
values into the Globals Namespace. This is a transitional
|
||||
workaround until we remove Globals entirely.
|
||||
'''
|
||||
if isinstance(conf,BaseSettings):
|
||||
conf = conf.dict()
|
||||
for key in conf.keys():
|
||||
if key is not None:
|
||||
setattr(Globals,key,conf[key])
|
||||
|
||||
@@ -5,10 +5,9 @@ wraps the actual patchmatch object. It respects the global
|
||||
be suppressed or deferred
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
|
||||
class PatchMatch:
|
||||
"""
|
||||
Thin class wrapper around the patchmatch function.
|
||||
@@ -28,12 +27,12 @@ class PatchMatch:
|
||||
from patchmatch import patch_match as pm
|
||||
|
||||
if pm.patchmatch_available:
|
||||
print(">> Patchmatch initialized")
|
||||
logger.info("Patchmatch initialized")
|
||||
else:
|
||||
print(">> Patchmatch not loaded (nonfatal)")
|
||||
logger.info("Patchmatch not loaded (nonfatal)")
|
||||
self.patch_match = pm
|
||||
else:
|
||||
print(">> Patchmatch loading disabled")
|
||||
logger.info("Patchmatch loading disabled")
|
||||
self.tried_load = True
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -30,9 +30,9 @@ work fine.
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageOps
|
||||
from torchvision import transforms
|
||||
from transformers import AutoProcessor, CLIPSegForImageSegmentation
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import global_cache_dir
|
||||
|
||||
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
|
||||
@@ -83,7 +83,7 @@ class Txt2Mask(object):
|
||||
"""
|
||||
|
||||
def __init__(self, device="cpu", refined=False):
|
||||
print(">> Initializing clipseg model for text to mask inference")
|
||||
logger.info("Initializing clipseg model for text to mask inference")
|
||||
|
||||
# BUG: we are not doing anything with the device option at this time
|
||||
self.device = device
|
||||
@@ -101,18 +101,6 @@ class Txt2Mask(object):
|
||||
provided image and returns a SegmentedGrayscale object in which the brighter
|
||||
pixels indicate where the object is inferred to be.
|
||||
"""
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||
),
|
||||
transforms.Resize(
|
||||
(CLIPSEG_SIZE, CLIPSEG_SIZE)
|
||||
), # must be multiple of 64...
|
||||
]
|
||||
)
|
||||
|
||||
if type(image) is str:
|
||||
image = Image.open(image).convert("RGB")
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ from typing import Union
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import global_cache_dir, global_config_dir
|
||||
|
||||
from .model_manager import ModelManager, SDLegacyType
|
||||
@@ -372,9 +373,9 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
||||
unet_key = "model.diffusion_model."
|
||||
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
|
||||
if sum(k.startswith("model_ema") for k in keys) > 100:
|
||||
print(f" | Checkpoint {path} has both EMA and non-EMA weights.")
|
||||
logger.debug(f"Checkpoint {path} has both EMA and non-EMA weights.")
|
||||
if extract_ema:
|
||||
print(" | Extracting EMA weights (usually better for inference)")
|
||||
logger.debug("Extracting EMA weights (usually better for inference)")
|
||||
for key in keys:
|
||||
if key.startswith("model.diffusion_model"):
|
||||
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
||||
@@ -392,8 +393,8 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
||||
key
|
||||
)
|
||||
else:
|
||||
print(
|
||||
" | Extracting only the non-EMA weights (usually better for fine-tuning)"
|
||||
logger.debug(
|
||||
"Extracting only the non-EMA weights (usually better for fine-tuning)"
|
||||
)
|
||||
|
||||
for key in keys:
|
||||
@@ -1115,7 +1116,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
if "global_step" in checkpoint:
|
||||
global_step = checkpoint["global_step"]
|
||||
else:
|
||||
print(" | global_step key not found in model")
|
||||
logger.debug("global_step key not found in model")
|
||||
global_step = None
|
||||
|
||||
# sometimes there is a state_dict key and sometimes not
|
||||
@@ -1229,15 +1230,15 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
# If a replacement VAE path was specified, we'll incorporate that into
|
||||
# the checkpoint model and then convert it
|
||||
if vae_path:
|
||||
print(f" | Converting VAE {vae_path}")
|
||||
logger.debug(f"Converting VAE {vae_path}")
|
||||
replace_checkpoint_vae(checkpoint,vae_path)
|
||||
# otherwise we use the original VAE, provided that
|
||||
# an externally loaded diffusers VAE was not passed
|
||||
elif not vae:
|
||||
print(" | Using checkpoint model's original VAE")
|
||||
logger.debug("Using checkpoint model's original VAE")
|
||||
|
||||
if vae:
|
||||
print(" | Using replacement diffusers VAE")
|
||||
logger.debug("Using replacement diffusers VAE")
|
||||
else: # convert the original or replacement VAE
|
||||
vae_config = create_vae_diffusers_config(
|
||||
original_config, image_size=image_size
|
||||
|
||||
@@ -18,12 +18,13 @@ import warnings
|
||||
from enum import Enum, auto
|
||||
from pathlib import Path
|
||||
from shutil import move, rmtree
|
||||
from typing import Any, Optional, Union, Callable
|
||||
from typing import Any, Optional, Union, Callable, types
|
||||
|
||||
import safetensors
|
||||
import safetensors.torch
|
||||
import torch
|
||||
import transformers
|
||||
import invokeai.backend.util.logging as logger
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
UNet2DConditionModel,
|
||||
@@ -75,6 +76,8 @@ class ModelManager(object):
|
||||
Model manager handles loading, caching, importing, deleting, converting, and editing models.
|
||||
"""
|
||||
|
||||
logger: types.ModuleType = logger
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: OmegaConf | Path,
|
||||
@@ -83,6 +86,7 @@ class ModelManager(object):
|
||||
max_loaded_models=DEFAULT_MAX_MODELS,
|
||||
sequential_offload=False,
|
||||
embedding_path: Path = None,
|
||||
logger: types.ModuleType = logger,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file or
|
||||
@@ -104,6 +108,7 @@ class ModelManager(object):
|
||||
self.current_model = None
|
||||
self.sequential_offload = sequential_offload
|
||||
self.embedding_path = embedding_path
|
||||
self.logger = logger
|
||||
|
||||
def valid_model(self, model_name: str) -> bool:
|
||||
"""
|
||||
@@ -132,8 +137,8 @@ class ModelManager(object):
|
||||
)
|
||||
|
||||
if not self.valid_model(model_name):
|
||||
print(
|
||||
f'** "{model_name}" is not a known model name. Please check your models.yaml file'
|
||||
self.logger.error(
|
||||
f'"{model_name}" is not a known model name. Please check your models.yaml file'
|
||||
)
|
||||
return self.current_model
|
||||
|
||||
@@ -144,7 +149,7 @@ class ModelManager(object):
|
||||
|
||||
if model_name in self.models:
|
||||
requested_model = self.models[model_name]["model"]
|
||||
print(f">> Retrieving model {model_name} from system RAM cache")
|
||||
self.logger.info(f"Retrieving model {model_name} from system RAM cache")
|
||||
requested_model.ready()
|
||||
width = self.models[model_name]["width"]
|
||||
height = self.models[model_name]["height"]
|
||||
@@ -379,7 +384,7 @@ class ModelManager(object):
|
||||
"""
|
||||
omega = self.config
|
||||
if model_name not in omega:
|
||||
print(f"** Unknown model {model_name}")
|
||||
self.logger.error(f"Unknown model {model_name}")
|
||||
return
|
||||
# save these for use in deletion later
|
||||
conf = omega[model_name]
|
||||
@@ -392,13 +397,13 @@ class ModelManager(object):
|
||||
self.stack.remove(model_name)
|
||||
if delete_files:
|
||||
if weights:
|
||||
print(f"** Deleting file {weights}")
|
||||
self.logger.info(f"Deleting file {weights}")
|
||||
Path(weights).unlink(missing_ok=True)
|
||||
elif path:
|
||||
print(f"** Deleting directory {path}")
|
||||
self.logger.info(f"Deleting directory {path}")
|
||||
rmtree(path, ignore_errors=True)
|
||||
elif repo_id:
|
||||
print(f"** Deleting the cached model directory for {repo_id}")
|
||||
self.logger.info(f"Deleting the cached model directory for {repo_id}")
|
||||
self._delete_model_from_cache(repo_id)
|
||||
|
||||
def add_model(
|
||||
@@ -439,7 +444,7 @@ class ModelManager(object):
|
||||
def _load_model(self, model_name: str):
|
||||
"""Load and initialize the model from configuration variables passed at object creation time"""
|
||||
if model_name not in self.config:
|
||||
print(
|
||||
self.logger.error(
|
||||
f'"{model_name}" is not a known model name. Please check your models.yaml file'
|
||||
)
|
||||
return
|
||||
@@ -457,7 +462,7 @@ class ModelManager(object):
|
||||
model_format = mconfig.get("format", "ckpt")
|
||||
if model_format == "ckpt":
|
||||
weights = mconfig.weights
|
||||
print(f">> Loading {model_name} from {weights}")
|
||||
self.logger.info(f"Loading {model_name} from {weights}")
|
||||
model, width, height, model_hash = self._load_ckpt_model(
|
||||
model_name, mconfig
|
||||
)
|
||||
@@ -473,13 +478,15 @@ class ModelManager(object):
|
||||
|
||||
# usage statistics
|
||||
toc = time.time()
|
||||
print(">> Model loaded in", "%4.2fs" % (toc - tic))
|
||||
self.logger.info("Model loaded in " + "%4.2fs" % (toc - tic))
|
||||
if self._has_cuda():
|
||||
print(
|
||||
">> Max VRAM used to load the model:",
|
||||
"%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9),
|
||||
"\n>> Current VRAM usage:"
|
||||
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
|
||||
self.logger.info(
|
||||
"Max VRAM used to load the model: "+
|
||||
"%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9)
|
||||
)
|
||||
self.logger.info(
|
||||
"Current VRAM usage: "+
|
||||
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9)
|
||||
)
|
||||
return model, width, height, model_hash
|
||||
|
||||
@@ -487,11 +494,11 @@ class ModelManager(object):
|
||||
name_or_path = self.model_name_or_path(mconfig)
|
||||
using_fp16 = self.precision == "float16"
|
||||
|
||||
print(f">> Loading diffusers model from {name_or_path}")
|
||||
self.logger.info(f"Loading diffusers model from {name_or_path}")
|
||||
if using_fp16:
|
||||
print(" | Using faster float16 precision")
|
||||
self.logger.debug("Using faster float16 precision")
|
||||
else:
|
||||
print(" | Using more accurate float32 precision")
|
||||
self.logger.debug("Using more accurate float32 precision")
|
||||
|
||||
# TODO: scan weights maybe?
|
||||
pipeline_args: dict[str, Any] = dict(
|
||||
@@ -523,8 +530,8 @@ class ModelManager(object):
|
||||
if str(e).startswith("fp16 is not a valid"):
|
||||
pass
|
||||
else:
|
||||
print(
|
||||
f"** An unexpected error occurred while downloading the model: {e})"
|
||||
self.logger.error(
|
||||
f"An unexpected error occurred while downloading the model: {e})"
|
||||
)
|
||||
if pipeline:
|
||||
break
|
||||
@@ -542,7 +549,7 @@ class ModelManager(object):
|
||||
# square images???
|
||||
width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
|
||||
height = width
|
||||
print(f" | Default image dimensions = {width} x {height}")
|
||||
self.logger.debug(f"Default image dimensions = {width} x {height}")
|
||||
|
||||
return pipeline, width, height, model_hash
|
||||
|
||||
@@ -559,14 +566,14 @@ class ModelManager(object):
|
||||
weights = os.path.normpath(os.path.join(Globals.root, weights))
|
||||
|
||||
# Convert to diffusers and return a diffusers pipeline
|
||||
print(f">> Converting legacy checkpoint {model_name} into a diffusers model...")
|
||||
self.logger.info(f"Converting legacy checkpoint {model_name} into a diffusers model...")
|
||||
|
||||
from . import load_pipeline_from_original_stable_diffusion_ckpt
|
||||
|
||||
try:
|
||||
if self.list_models()[self.current_model]["status"] == "active":
|
||||
self.offload_model(self.current_model)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
vae_path = None
|
||||
@@ -624,7 +631,7 @@ class ModelManager(object):
|
||||
if model_name not in self.models:
|
||||
return
|
||||
|
||||
print(f">> Offloading {model_name} to CPU")
|
||||
self.logger.info(f"Offloading {model_name} to CPU")
|
||||
model = self.models[model_name]["model"]
|
||||
model.offload_all()
|
||||
self.current_model = None
|
||||
@@ -640,30 +647,26 @@ class ModelManager(object):
|
||||
and option to exit if an infected file is identified.
|
||||
"""
|
||||
# scan model
|
||||
print(f" | Scanning Model: {model_name}")
|
||||
self.logger.debug(f"Scanning Model: {model_name}")
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0:
|
||||
if scan_result.infected_files == 1:
|
||||
print(f"\n### Issues Found In Model: {scan_result.issues_count}")
|
||||
print(
|
||||
"### WARNING: The model you are trying to load seems to be infected."
|
||||
)
|
||||
print("### For your safety, InvokeAI will not load this model.")
|
||||
print("### Please use checkpoints from trusted sources.")
|
||||
print("### Exiting InvokeAI")
|
||||
self.logger.critical(f"Issues Found In Model: {scan_result.issues_count}")
|
||||
self.logger.critical("The model you are trying to load seems to be infected.")
|
||||
self.logger.critical("For your safety, InvokeAI will not load this model.")
|
||||
self.logger.critical("Please use checkpoints from trusted sources.")
|
||||
self.logger.critical("Exiting InvokeAI")
|
||||
sys.exit()
|
||||
else:
|
||||
print(
|
||||
"\n### WARNING: InvokeAI was unable to scan the model you are using."
|
||||
)
|
||||
self.logger.warning("InvokeAI was unable to scan the model you are using.")
|
||||
model_safe_check_fail = ask_user(
|
||||
"Do you want to to continue loading the model?", ["y", "n"]
|
||||
)
|
||||
if model_safe_check_fail.lower() != "y":
|
||||
print("### Exiting InvokeAI")
|
||||
self.logger.critical("Exiting InvokeAI")
|
||||
sys.exit()
|
||||
else:
|
||||
print(" | Model scanned ok")
|
||||
self.logger.debug("Model scanned ok")
|
||||
|
||||
def import_diffuser_model(
|
||||
self,
|
||||
@@ -780,26 +783,24 @@ class ModelManager(object):
|
||||
model_path: Path = None
|
||||
thing = path_url_or_repo # to save typing
|
||||
|
||||
print(f">> Probing {thing} for import")
|
||||
self.logger.info(f"Probing {thing} for import")
|
||||
|
||||
if thing.startswith(("http:", "https:", "ftp:")):
|
||||
print(f" | {thing} appears to be a URL")
|
||||
self.logger.info(f"{thing} appears to be a URL")
|
||||
model_path = self._resolve_path(
|
||||
thing, "models/ldm/stable-diffusion-v1"
|
||||
) # _resolve_path does a download if needed
|
||||
|
||||
elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
|
||||
if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
|
||||
print(
|
||||
f" | {Path(thing).name} appears to be part of a diffusers model. Skipping import"
|
||||
)
|
||||
self.logger.debug(f"{Path(thing).name} appears to be part of a diffusers model. Skipping import")
|
||||
return
|
||||
else:
|
||||
print(f" | {thing} appears to be a checkpoint file on disk")
|
||||
self.logger.debug(f"{thing} appears to be a checkpoint file on disk")
|
||||
model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1")
|
||||
|
||||
elif Path(thing).is_dir() and Path(thing, "model_index.json").exists():
|
||||
print(f" | {thing} appears to be a diffusers file on disk")
|
||||
self.logger.debug(f"{thing} appears to be a diffusers file on disk")
|
||||
model_name = self.import_diffuser_model(
|
||||
thing,
|
||||
vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
|
||||
@@ -810,34 +811,30 @@ class ModelManager(object):
|
||||
|
||||
elif Path(thing).is_dir():
|
||||
if (Path(thing) / "model_index.json").exists():
|
||||
print(f" | {thing} appears to be a diffusers model.")
|
||||
self.logger.debug(f"{thing} appears to be a diffusers model.")
|
||||
model_name = self.import_diffuser_model(
|
||||
thing, commit_to_conf=commit_to_conf
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f" |{thing} appears to be a directory. Will scan for models to import"
|
||||
)
|
||||
self.logger.debug(f"{thing} appears to be a directory. Will scan for models to import")
|
||||
for m in list(Path(thing).rglob("*.ckpt")) + list(
|
||||
Path(thing).rglob("*.safetensors")
|
||||
):
|
||||
if model_name := self.heuristic_import(
|
||||
str(m), commit_to_conf=commit_to_conf
|
||||
):
|
||||
print(f" >> {model_name} successfully imported")
|
||||
self.logger.info(f"{model_name} successfully imported")
|
||||
return model_name
|
||||
|
||||
elif re.match(r"^[\w.+-]+/[\w.+-]+$", thing):
|
||||
print(f" | {thing} appears to be a HuggingFace diffusers repo_id")
|
||||
self.logger.debug(f"{thing} appears to be a HuggingFace diffusers repo_id")
|
||||
model_name = self.import_diffuser_model(
|
||||
thing, commit_to_conf=commit_to_conf
|
||||
)
|
||||
pipeline, _, _, _ = self._load_diffusers_model(self.config[model_name])
|
||||
return model_name
|
||||
else:
|
||||
print(
|
||||
f"** {thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id"
|
||||
)
|
||||
self.logger.warning(f"{thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id")
|
||||
|
||||
# Model_path is set in the event of a legacy checkpoint file.
|
||||
# If not set, we're all done
|
||||
@@ -845,7 +842,7 @@ class ModelManager(object):
|
||||
return
|
||||
|
||||
if model_path.stem in self.config: # already imported
|
||||
print(" | Already imported. Skipping")
|
||||
self.logger.debug("Already imported. Skipping")
|
||||
return model_path.stem
|
||||
|
||||
# another round of heuristics to guess the correct config file.
|
||||
@@ -861,39 +858,39 @@ class ModelManager(object):
|
||||
# look for a like-named .yaml file in same directory
|
||||
if model_path.with_suffix(".yaml").exists():
|
||||
model_config_file = model_path.with_suffix(".yaml")
|
||||
print(f" | Using config file {model_config_file.name}")
|
||||
self.logger.debug(f"Using config file {model_config_file.name}")
|
||||
|
||||
else:
|
||||
model_type = self.probe_model_type(checkpoint)
|
||||
if model_type == SDLegacyType.V1:
|
||||
print(" | SD-v1 model detected")
|
||||
self.logger.debug("SD-v1 model detected")
|
||||
model_config_file = Path(
|
||||
Globals.root, "configs/stable-diffusion/v1-inference.yaml"
|
||||
)
|
||||
elif model_type == SDLegacyType.V1_INPAINT:
|
||||
print(" | SD-v1 inpainting model detected")
|
||||
self.logger.debug("SD-v1 inpainting model detected")
|
||||
model_config_file = Path(
|
||||
Globals.root,
|
||||
"configs/stable-diffusion/v1-inpainting-inference.yaml",
|
||||
)
|
||||
elif model_type == SDLegacyType.V2_v:
|
||||
print(" | SD-v2-v model detected")
|
||||
self.logger.debug("SD-v2-v model detected")
|
||||
model_config_file = Path(
|
||||
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
|
||||
)
|
||||
elif model_type == SDLegacyType.V2_e:
|
||||
print(" | SD-v2-e model detected")
|
||||
self.logger.debug("SD-v2-e model detected")
|
||||
model_config_file = Path(
|
||||
Globals.root, "configs/stable-diffusion/v2-inference.yaml"
|
||||
)
|
||||
elif model_type == SDLegacyType.V2:
|
||||
print(
|
||||
f"** {thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
|
||||
self.logger.warning(
|
||||
f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
|
||||
)
|
||||
return
|
||||
else:
|
||||
print(
|
||||
f"** {thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
|
||||
self.logger.warning(
|
||||
f"{thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
|
||||
)
|
||||
return
|
||||
|
||||
@@ -909,7 +906,7 @@ class ModelManager(object):
|
||||
for suffix in ["pt", "ckpt", "safetensors"]:
|
||||
if (model_path.with_suffix(f".vae.{suffix}")).exists():
|
||||
vae_path = model_path.with_suffix(f".vae.{suffix}")
|
||||
print(f" | Using VAE file {vae_path.name}")
|
||||
self.logger.debug(f"Using VAE file {vae_path.name}")
|
||||
vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
|
||||
|
||||
diffuser_path = Path(
|
||||
@@ -955,14 +952,14 @@ class ModelManager(object):
|
||||
from . import convert_ckpt_to_diffusers
|
||||
|
||||
if diffusers_path.exists():
|
||||
print(
|
||||
f"ERROR: The path {str(diffusers_path)} already exists. Please move or remove it and try again."
|
||||
self.logger.error(
|
||||
f"The path {str(diffusers_path)} already exists. Please move or remove it and try again."
|
||||
)
|
||||
return
|
||||
|
||||
model_name = model_name or diffusers_path.name
|
||||
model_description = model_description or f"Converted version of {model_name}"
|
||||
print(f" | Converting {model_name} to diffusers (30-60s)")
|
||||
self.logger.debug(f"Converting {model_name} to diffusers (30-60s)")
|
||||
try:
|
||||
# By passing the specified VAE to the conversion function, the autoencoder
|
||||
# will be built into the model rather than tacked on afterward via the config file
|
||||
@@ -979,10 +976,10 @@ class ModelManager(object):
|
||||
vae_path=vae_path,
|
||||
scan_needed=scan_needed,
|
||||
)
|
||||
print(
|
||||
f" | Success. Converted model is now located at {str(diffusers_path)}"
|
||||
self.logger.debug(
|
||||
f"Success. Converted model is now located at {str(diffusers_path)}"
|
||||
)
|
||||
print(f" | Writing new config file entry for {model_name}")
|
||||
self.logger.debug(f"Writing new config file entry for {model_name}")
|
||||
new_config = dict(
|
||||
path=str(diffusers_path),
|
||||
description=model_description,
|
||||
@@ -993,17 +990,17 @@ class ModelManager(object):
|
||||
self.add_model(model_name, new_config, True)
|
||||
if commit_to_conf:
|
||||
self.commit(commit_to_conf)
|
||||
print(" | Conversion succeeded")
|
||||
self.logger.debug("Conversion succeeded")
|
||||
except Exception as e:
|
||||
print(f"** Conversion failed: {str(e)}")
|
||||
print(
|
||||
"** If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)"
|
||||
self.logger.warning(f"Conversion failed: {str(e)}")
|
||||
self.logger.warning(
|
||||
"If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)"
|
||||
)
|
||||
|
||||
return model_name
|
||||
|
||||
def search_models(self, search_folder):
|
||||
print(f">> Finding Models In: {search_folder}")
|
||||
self.logger.info(f"Finding Models In: {search_folder}")
|
||||
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
|
||||
models_folder_safetensors = Path(search_folder).glob("**/*.safetensors")
|
||||
|
||||
@@ -1027,8 +1024,8 @@ class ModelManager(object):
|
||||
num_loaded_models = len(self.models)
|
||||
if num_loaded_models >= self.max_loaded_models:
|
||||
least_recent_model = self._pop_oldest_model()
|
||||
print(
|
||||
f">> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}"
|
||||
self.logger.info(
|
||||
f"Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}"
|
||||
)
|
||||
if least_recent_model is not None:
|
||||
del self.models[least_recent_model]
|
||||
@@ -1036,8 +1033,8 @@ class ModelManager(object):
|
||||
|
||||
def print_vram_usage(self) -> None:
|
||||
if self._has_cuda:
|
||||
print(
|
||||
">> Current VRAM usage: ",
|
||||
self.logger.info(
|
||||
"Current VRAM usage:"+
|
||||
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
|
||||
)
|
||||
|
||||
@@ -1126,10 +1123,10 @@ class ModelManager(object):
|
||||
dest = hub / model.stem
|
||||
if dest.exists() and not source.exists():
|
||||
continue
|
||||
print(f"** {source} => {dest}")
|
||||
cls.logger.info(f"{source} => {dest}")
|
||||
if source.exists():
|
||||
if dest.is_symlink():
|
||||
print(f"** Found symlink at {dest.name}. Not migrating.")
|
||||
logger.warning(f"Found symlink at {dest.name}. Not migrating.")
|
||||
elif dest.exists():
|
||||
if source.is_dir():
|
||||
rmtree(source)
|
||||
@@ -1146,7 +1143,7 @@ class ModelManager(object):
|
||||
]
|
||||
for d in empty:
|
||||
os.rmdir(d)
|
||||
print("** Migration is done. Continuing...")
|
||||
cls.logger.info("Migration is done. Continuing...")
|
||||
|
||||
def _resolve_path(
|
||||
self, source: Union[str, Path], dest_directory: str
|
||||
@@ -1189,15 +1186,15 @@ class ModelManager(object):
|
||||
|
||||
def _add_embeddings_to_model(self, model: StableDiffusionGeneratorPipeline):
|
||||
if self.embedding_path is not None:
|
||||
print(f">> Loading embeddings from {self.embedding_path}")
|
||||
self.logger.info(f"Loading embeddings from {self.embedding_path}")
|
||||
for root, _, files in os.walk(self.embedding_path):
|
||||
for name in files:
|
||||
ti_path = os.path.join(root, name)
|
||||
model.textual_inversion_manager.load_textual_inversion(
|
||||
ti_path, defer_injecting_tokens=True
|
||||
)
|
||||
print(
|
||||
f'>> Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
|
||||
self.logger.info(
|
||||
f'Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
|
||||
)
|
||||
|
||||
def _has_cuda(self) -> bool:
|
||||
@@ -1219,7 +1216,7 @@ class ModelManager(object):
|
||||
with open(hashpath) as f:
|
||||
hash = f.read()
|
||||
return hash
|
||||
print(" | Calculating sha256 hash of model files")
|
||||
self.logger.debug("Calculating sha256 hash of model files")
|
||||
tic = time.time()
|
||||
sha = hashlib.sha256()
|
||||
count = 0
|
||||
@@ -1231,7 +1228,7 @@ class ModelManager(object):
|
||||
sha.update(chunk)
|
||||
hash = sha.hexdigest()
|
||||
toc = time.time()
|
||||
print(f" | sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
|
||||
self.logger.debug(f"sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
|
||||
with open(hashpath, "w") as f:
|
||||
f.write(hash)
|
||||
return hash
|
||||
@@ -1249,13 +1246,13 @@ class ModelManager(object):
|
||||
hash = f.read()
|
||||
return hash
|
||||
|
||||
print(" | Calculating sha256 hash of weights file")
|
||||
self.logger.debug("Calculating sha256 hash of weights file")
|
||||
tic = time.time()
|
||||
sha = hashlib.sha256()
|
||||
sha.update(data)
|
||||
hash = sha.hexdigest()
|
||||
toc = time.time()
|
||||
print(f">> sha256 = {hash}", "(%4.2fs)" % (toc - tic))
|
||||
self.logger.debug(f"sha256 = {hash} "+"(%4.2fs)" % (toc - tic))
|
||||
|
||||
with open(hashpath, "w") as f:
|
||||
f.write(hash)
|
||||
@@ -1276,12 +1273,12 @@ class ModelManager(object):
|
||||
local_files_only=not Globals.internet_available,
|
||||
)
|
||||
|
||||
print(f" | Loading diffusers VAE from {name_or_path}")
|
||||
self.logger.debug(f"Loading diffusers VAE from {name_or_path}")
|
||||
if using_fp16:
|
||||
vae_args.update(torch_dtype=torch.float16)
|
||||
fp_args_list = [{"revision": "fp16"}, {}]
|
||||
else:
|
||||
print(" | Using more accurate float32 precision")
|
||||
self.logger.debug("Using more accurate float32 precision")
|
||||
fp_args_list = [{}]
|
||||
|
||||
vae = None
|
||||
@@ -1305,12 +1302,12 @@ class ModelManager(object):
|
||||
break
|
||||
|
||||
if not vae and deferred_error:
|
||||
print(f"** Could not load VAE {name_or_path}: {str(deferred_error)}")
|
||||
self.logger.warning(f"Could not load VAE {name_or_path}: {str(deferred_error)}")
|
||||
|
||||
return vae
|
||||
|
||||
@staticmethod
|
||||
def _delete_model_from_cache(repo_id):
|
||||
@classmethod
|
||||
def _delete_model_from_cache(cls,repo_id):
|
||||
cache_info = scan_cache_dir(global_cache_dir("hub"))
|
||||
|
||||
# I'm sure there is a way to do this with comprehensions
|
||||
@@ -1321,8 +1318,8 @@ class ModelManager(object):
|
||||
for revision in repo.revisions:
|
||||
hashes_to_delete.add(revision.commit_hash)
|
||||
strategy = cache_info.delete_revisions(*hashes_to_delete)
|
||||
print(
|
||||
f"** Deletion of this model is expected to free {strategy.expected_freed_size_str}"
|
||||
cls.logger.warning(
|
||||
f"Deletion of this model is expected to free {strategy.expected_freed_size_str}"
|
||||
)
|
||||
strategy.execute()
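The hunk above also changes _delete_model_from_cache from a @staticmethod to a @classmethod so that the method can reach a logger through cls (the logger attribute itself is set outside this excerpt). A minimal, hypothetical sketch of that pattern, using the module-level getLogger from the new logging module introduced later in this diff; the class and method names below are illustrative, not the project's actual ModelManager:

    import invokeai.backend.util.logging as ialog

    class CacheCleaner:
        # class-level logger shared by all instances and classmethods
        logger = ialog.getLogger(__name__)

        @classmethod
        def delete_from_cache(cls, repo_id: str):
            # a staticmethod would have no cls/self through which to reach this logger
            cls.logger.warning(f"Deletion of {repo_id} requested")

Called as CacheCleaner.delete_from_cache("some/repo_id"), this prints the message with the legacy "**" warning prefix supplied by the formatter.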
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from compel.prompt_parser import (
|
||||
PromptParser,
|
||||
)
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
from ..stable_diffusion import InvokeAIDiffuserComponent
|
||||
@@ -162,8 +163,8 @@ def log_tokenization(
|
||||
negative_prompt: Union[Blend, FlattenedPrompt],
|
||||
tokenizer,
|
||||
):
|
||||
print(f"\n>> [TOKENLOG] Parsed Prompt: {positive_prompt}")
|
||||
print(f"\n>> [TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
|
||||
logger.info(f"[TOKENLOG] Parsed Prompt: {positive_prompt}")
|
||||
logger.info(f"[TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
|
||||
|
||||
log_tokenization_for_prompt_object(positive_prompt, tokenizer)
|
||||
log_tokenization_for_prompt_object(
|
||||
@@ -237,12 +238,12 @@ def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_t
|
||||
usedTokens += 1
|
||||
|
||||
if usedTokens > 0:
|
||||
print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
|
||||
print(f"{tokenized}\x1b[0m")
|
||||
logger.info(f'[TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
|
||||
logger.debug(f"{tokenized}\x1b[0m")
|
||||
|
||||
if discarded != "":
|
||||
print(f"\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
|
||||
print(f"{discarded}\x1b[0m")
|
||||
logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
|
||||
logger.debug(f"{discarded}\x1b[0m")
|
||||
|
||||
|
||||
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]:
|
||||
@@ -295,8 +296,8 @@ def split_weighted_subprompts(text, skip_normalize=False) -> list:
|
||||
return parsed_prompts
|
||||
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
|
||||
if weight_sum == 0:
|
||||
print(
|
||||
"* Warning: Subprompt weights add up to zero. Discarding and using even weights instead."
|
||||
logger.warning(
|
||||
"Subprompt weights add up to zero. Discarding and using even weights instead."
|
||||
)
|
||||
equal_weight = 1 / max(len(parsed_prompts), 1)
|
||||
return [(x[0], equal_weight) for x in parsed_prompts]
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
class Restoration:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
@@ -8,17 +10,17 @@ class Restoration:
|
||||
# Load GFPGAN
|
||||
gfpgan = self.load_gfpgan(gfpgan_model_path)
|
||||
if gfpgan.gfpgan_model_exists:
|
||||
print(">> GFPGAN Initialized")
|
||||
logger.info("GFPGAN Initialized")
|
||||
else:
|
||||
print(">> GFPGAN Disabled")
|
||||
logger.info("GFPGAN Disabled")
|
||||
gfpgan = None
|
||||
|
||||
# Load CodeFormer
|
||||
codeformer = self.load_codeformer()
|
||||
if codeformer.codeformer_model_exists:
|
||||
print(">> CodeFormer Initialized")
|
||||
logger.info("CodeFormer Initialized")
|
||||
else:
|
||||
print(">> CodeFormer Disabled")
|
||||
logger.info("CodeFormer Disabled")
|
||||
codeformer = None
|
||||
|
||||
return gfpgan, codeformer
|
||||
@@ -39,5 +41,5 @@ class Restoration:
|
||||
from .realesrgan import ESRGAN
|
||||
|
||||
esrgan = ESRGAN(esrgan_bg_tile)
|
||||
print(">> ESRGAN Initialized")
|
||||
logger.info("ESRGAN Initialized")
|
||||
return esrgan
|
||||
|
||||
@@ -5,6 +5,7 @@ import warnings
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from ..globals import Globals
|
||||
|
||||
pretrained_model_url = (
|
||||
@@ -23,12 +24,12 @@ class CodeFormerRestoration:
|
||||
self.codeformer_model_exists = os.path.isfile(self.model_path)
|
||||
|
||||
if not self.codeformer_model_exists:
|
||||
print("## NOT FOUND: CodeFormer model not found at " + self.model_path)
|
||||
logger.error("NOT FOUND: CodeFormer model not found at " + self.model_path)
|
||||
sys.path.append(os.path.abspath(codeformer_dir))
|
||||
|
||||
def process(self, image, strength, device, seed=None, fidelity=0.75):
|
||||
if seed is not None:
|
||||
print(f">> CodeFormer - Restoring Faces for image seed:{seed}")
|
||||
logger.info(f"CodeFormer - Restoring Faces for image seed:{seed}")
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
@@ -97,7 +98,7 @@ class CodeFormerRestoration:
|
||||
del output
|
||||
torch.cuda.empty_cache()
|
||||
except RuntimeError as error:
|
||||
print(f"\tFailed inference for CodeFormer: {error}.")
|
||||
logger.error(f"Failed inference for CodeFormer: {error}.")
|
||||
restored_face = cropped_face
|
||||
|
||||
restored_face = restored_face.astype("uint8")
|
||||
|
||||
@@ -6,9 +6,9 @@ import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
|
||||
class GFPGAN:
|
||||
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
|
||||
if not os.path.isabs(gfpgan_model_path):
|
||||
@@ -19,7 +19,7 @@ class GFPGAN:
|
||||
self.gfpgan_model_exists = os.path.isfile(self.model_path)
|
||||
|
||||
if not self.gfpgan_model_exists:
|
||||
print("## NOT FOUND: GFPGAN model not found at " + self.model_path)
|
||||
logger.error("NOT FOUND: GFPGAN model not found at " + self.model_path)
|
||||
return None
|
||||
|
||||
def model_exists(self):
|
||||
@@ -27,7 +27,7 @@ class GFPGAN:
|
||||
|
||||
def process(self, image, strength: float, seed: str = None):
|
||||
if seed is not None:
|
||||
print(f">> GFPGAN - Restoring Faces for image seed:{seed}")
|
||||
logger.info(f"GFPGAN - Restoring Faces for image seed:{seed}")
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
@@ -47,14 +47,14 @@ class GFPGAN:
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
print(">> Error loading GFPGAN:", file=sys.stderr)
|
||||
logger.error("Error loading GFPGAN:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
os.chdir(cwd)
|
||||
|
||||
if self.gfpgan is None:
|
||||
print(f">> WARNING: GFPGAN not initialized.")
|
||||
print(
|
||||
f">> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
|
||||
logger.warning("WARNING: GFPGAN not initialized.")
|
||||
logger.warning(
|
||||
f"Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
|
||||
)
|
||||
|
||||
image = image.convert("RGB")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import math
|
||||
|
||||
from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
class Outcrop(object):
|
||||
def __init__(
|
||||
@@ -82,7 +82,7 @@ class Outcrop(object):
|
||||
pixels = extents[direction]
|
||||
# round pixels up to the nearest 64
|
||||
pixels = math.ceil(pixels / 64) * 64
|
||||
print(f">> extending image {direction}ward by {pixels} pixels")
|
||||
logger.info(f"extending image {direction}ward by {pixels} pixels")
|
||||
image = self._rotate(image, direction)
|
||||
image = self._extend(image, pixels)
|
||||
image = self._rotate(image, direction, reverse=True)
|
||||
|
||||
@@ -6,18 +6,13 @@ import torch
|
||||
from PIL import Image
|
||||
from PIL.Image import Image as ImageType
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
|
||||
class ESRGAN:
|
||||
def __init__(self, bg_tile_size=400) -> None:
|
||||
self.bg_tile_size = bg_tile_size
|
||||
|
||||
if not torch.cuda.is_available(): # CPU or MPS on M1
|
||||
use_half_precision = False
|
||||
else:
|
||||
use_half_precision = True
|
||||
|
||||
def load_esrgan_bg_upsampler(self, denoise_str):
|
||||
if not torch.cuda.is_available(): # CPU or MPS on M1
|
||||
use_half_precision = False
|
||||
@@ -74,16 +69,16 @@ class ESRGAN:
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
print(">> Error loading Real-ESRGAN:", file=sys.stderr)
|
||||
logger.error("Error loading Real-ESRGAN:")
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
if upsampler_scale == 0:
|
||||
print(">> Real-ESRGAN: Invalid scaling option. Image not upscaled.")
|
||||
logger.warning("Real-ESRGAN: Invalid scaling option. Image not upscaled.")
|
||||
return image
|
||||
|
||||
if seed is not None:
|
||||
print(
|
||||
f">> Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
|
||||
logger.info(
|
||||
f"Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
|
||||
)
|
||||
# ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
|
||||
image = image.convert("RGB")
|
||||
|
||||
@@ -14,6 +14,7 @@ from PIL import Image, ImageFilter
|
||||
from transformers import AutoFeatureExtractor
|
||||
|
||||
import invokeai.assets.web as web_assets
|
||||
import invokeai.backend.util.logging as logger
|
||||
from .globals import global_cache_dir
|
||||
from .util import CPU_DEVICE
|
||||
|
||||
@@ -40,8 +41,8 @@ class SafetyChecker(object):
|
||||
cache_dir=safety_model_path,
|
||||
)
|
||||
except Exception:
|
||||
print(
|
||||
"** An error was encountered while installing the safety checker:"
|
||||
logger.error(
|
||||
"An error was encountered while installing the safety checker:"
|
||||
)
|
||||
print(traceback.format_exc())
|
||||
|
||||
@@ -65,8 +66,8 @@ class SafetyChecker(object):
|
||||
)
|
||||
self.safety_checker.to(CPU_DEVICE) # offload
|
||||
if has_nsfw_concept[0]:
|
||||
print(
|
||||
"** An image with potential non-safe content has been detected. A blurred image will be returned. **"
|
||||
logger.warning(
|
||||
"An image with potential non-safe content has been detected. A blurred image will be returned."
|
||||
)
|
||||
return self.blur(image)
|
||||
else:
|
||||
|
||||
@@ -17,6 +17,7 @@ from huggingface_hub import (
|
||||
hf_hub_url,
|
||||
)
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
|
||||
@@ -66,11 +67,11 @@ class HuggingFaceConceptsLibrary(object):
|
||||
# when init, add all in dir. when not init, add only concepts added between init and now
|
||||
self.concept_list.extend(list(local_concepts_to_add))
|
||||
except Exception as e:
|
||||
print(
|
||||
f" ** WARNING: Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
|
||||
logger.warning(
|
||||
f"Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
|
||||
)
|
||||
print(
|
||||
" ** You may load .bin and .pt file(s) manually using the --embedding_directory argument."
|
||||
logger.warning(
|
||||
"You may load .bin and .pt file(s) manually using the --embedding_directory argument."
|
||||
)
|
||||
return self.concept_list
|
||||
else:
|
||||
@@ -83,7 +84,7 @@ class HuggingFaceConceptsLibrary(object):
|
||||
be downloaded.
|
||||
"""
|
||||
if not concept_name in self.list_concepts():
|
||||
print(
|
||||
logger.warning(
|
||||
f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
|
||||
)
|
||||
return None
|
||||
@@ -221,7 +222,7 @@ class HuggingFaceConceptsLibrary(object):
|
||||
if chunk == 0:
|
||||
bytes += total
|
||||
|
||||
print(f">> Downloading {repo_id}...", end="")
|
||||
logger.info(f"Downloading {repo_id}...", end="")
|
||||
try:
|
||||
for file in (
|
||||
"README.md",
|
||||
@@ -235,22 +236,22 @@ class HuggingFaceConceptsLibrary(object):
|
||||
)
|
||||
except ul_error.HTTPError as e:
|
||||
if e.code == 404:
|
||||
print(
|
||||
logger.warning(
|
||||
f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
|
||||
)
|
||||
else:
|
||||
print(
|
||||
logger.warning(
|
||||
f"Failed to download {concept_name}/{file} ({str(e)}. Generation will continue without the concept.)"
|
||||
)
|
||||
os.rmdir(dest)
|
||||
return False
|
||||
except ul_error.URLError as e:
|
||||
print(
|
||||
f"ERROR while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
|
||||
logger.error(
|
||||
f"an error occurred while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
|
||||
)
|
||||
os.rmdir(dest)
|
||||
return False
|
||||
print("...{:.2f}Kb".format(bytes / 1024))
|
||||
logger.info("...{:.2f}Kb".format(bytes / 1024))
|
||||
return succeeded
|
||||
|
||||
def _concept_id(self, concept_name: str) -> str:
|
||||
|
||||
@@ -13,9 +13,9 @@ from compel.cross_attention_control import Arguments
|
||||
from diffusers.models.attention_processor import AttentionProcessor
|
||||
from torch import nn
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from ...util import torch_dtype
|
||||
|
||||
|
||||
class CrossAttentionType(enum.Enum):
|
||||
SELF = 1
|
||||
TOKENS = 2
|
||||
@@ -421,7 +421,7 @@ def get_cross_attention_modules(
|
||||
expected_count = 16
|
||||
if cross_attention_modules_in_model_count != expected_count:
|
||||
# non-fatal error but .swap() won't work.
|
||||
print(
|
||||
logger.error(
|
||||
f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
|
||||
+ f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
|
||||
+ "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
|
||||
|
||||
@@ -8,6 +8,7 @@ import torch
|
||||
from diffusers.models.attention_processor import AttentionProcessor
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
from .cross_attention_control import (
|
||||
@@ -466,10 +467,14 @@ class InvokeAIDiffuserComponent:
|
||||
outside = torch.count_nonzero(
|
||||
(latents < -current_threshold) | (latents > current_threshold)
|
||||
)
|
||||
print(
|
||||
f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
|
||||
f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
|
||||
f" | {outside / latents.numel() * 100:.2f}% values outside threshold"
|
||||
logger.info(
|
||||
f"Threshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})"
|
||||
)
|
||||
logger.debug(
|
||||
f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}"
|
||||
)
|
||||
logger.debug(
|
||||
f"{outside / latents.numel() * 100:.2f}% values outside threshold"
|
||||
)
|
||||
|
||||
if maxval < current_threshold and minval > -current_threshold:
|
||||
@@ -496,9 +501,11 @@ class InvokeAIDiffuserComponent:
|
||||
)
|
||||
|
||||
if self.debug_thresholding:
|
||||
print(
|
||||
f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n"
|
||||
f" | {num_altered / latents.numel() * 100:.2f}% values altered"
|
||||
logger.debug(
|
||||
f"min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})"
|
||||
)
|
||||
logger.debug(
|
||||
f"{num_altered / latents.numel() * 100:.2f}% values altered"
|
||||
)
|
||||
|
||||
return latents
|
||||
|
||||
@@ -10,7 +10,7 @@ from torchvision.utils import make_grid
|
||||
|
||||
# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
|
||||
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
||||
|
||||
|
||||
@@ -191,7 +191,7 @@ def mkdirs(paths):
|
||||
def mkdir_and_rename(path):
|
||||
if os.path.exists(path):
|
||||
new_name = path + "_archived_" + get_timestamp()
|
||||
print("Path already exists. Rename it to [{:s}]".format(new_name))
|
||||
logger.error("Path already exists. Rename it to [{:s}]".format(new_name))
|
||||
os.replace(path, new_name)
|
||||
os.makedirs(path)
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from compel.embeddings_provider import BaseTextualInversionManager
|
||||
from picklescan.scanner import scan_file_path
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from .concepts_lib import HuggingFaceConceptsLibrary
|
||||
|
||||
@dataclass
|
||||
@@ -59,12 +60,12 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
or self.has_textual_inversion_for_trigger_string(concept_name)
|
||||
or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>")
|
||||
): # in case a token with literal angle brackets encountered
|
||||
print(f">> Loaded local embedding for trigger {concept_name}")
|
||||
logger.info(f"Loaded local embedding for trigger {concept_name}")
|
||||
continue
|
||||
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
|
||||
if not bin_file:
|
||||
continue
|
||||
print(f">> Loaded remote embedding for trigger {concept_name}")
|
||||
logger.info(f"Loaded remote embedding for trigger {concept_name}")
|
||||
self.load_textual_inversion(bin_file)
|
||||
self.hf_concepts_library.concepts_loaded[concept_name] = True
|
||||
|
||||
@@ -85,8 +86,8 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
embedding_list = self._parse_embedding(str(ckpt_path))
|
||||
for embedding_info in embedding_list:
|
||||
if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
|
||||
print(
|
||||
f" ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
|
||||
logger.warning(
|
||||
f"Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -105,8 +106,8 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
if ckpt_path.name == "learned_embeds.bin"
|
||||
else f"<{ckpt_path.stem}>"
|
||||
)
|
||||
print(
|
||||
f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
|
||||
logger.info(
|
||||
f"{sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
|
||||
)
|
||||
trigger_str = replacement_trigger_str
|
||||
|
||||
@@ -120,8 +121,8 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
self.trigger_to_sourcefile[trigger_str] = sourcefile
|
||||
|
||||
except ValueError as e:
|
||||
print(f' | Ignoring incompatible embedding {embedding_info["name"]}')
|
||||
print(f" | The error was {str(e)}")
|
||||
logger.debug(f'Ignoring incompatible embedding {embedding_info["name"]}')
|
||||
logger.debug(f"The error was {str(e)}")
|
||||
|
||||
def _add_textual_inversion(
|
||||
self, trigger_str, embedding, defer_injecting_tokens=False
|
||||
@@ -133,8 +134,8 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
:return: The token id for the added embedding, either existing or newly-added.
|
||||
"""
|
||||
if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
|
||||
print(
|
||||
f"** TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
|
||||
logger.warning(
|
||||
f"TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
|
||||
)
|
||||
return
|
||||
if not self.full_precision:
|
||||
@@ -155,11 +156,11 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
|
||||
except ValueError as e:
|
||||
if str(e).startswith("Warning"):
|
||||
print(f">> {str(e)}")
|
||||
logger.warning(f"{str(e)}")
|
||||
else:
|
||||
traceback.print_exc()
|
||||
print(
|
||||
f"** TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
|
||||
logger.error(
|
||||
f"TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -219,16 +220,16 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
for ti in self.textual_inversions:
|
||||
if ti.trigger_token_id is None and ti.trigger_string in prompt_string:
|
||||
if ti.embedding_vector_length > 1:
|
||||
print(
|
||||
f">> Preparing tokens for textual inversion {ti.trigger_string}..."
|
||||
logger.info(
|
||||
f"Preparing tokens for textual inversion {ti.trigger_string}..."
|
||||
)
|
||||
try:
|
||||
self._inject_tokens_and_assign_embeddings(ti)
|
||||
except ValueError as e:
|
||||
print(
|
||||
f" | Ignoring incompatible embedding trigger {ti.trigger_string}"
|
||||
logger.debug(
|
||||
f"Ignoring incompatible embedding trigger {ti.trigger_string}"
|
||||
)
|
||||
print(f" | The error was {str(e)}")
|
||||
logger.debug(f"The error was {str(e)}")
|
||||
continue
|
||||
injected_token_ids.append(ti.trigger_token_id)
|
||||
injected_token_ids.extend(ti.pad_token_ids)
|
||||
@@ -306,16 +307,16 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
if suffix in [".pt",".ckpt",".bin"]:
|
||||
scan_result = scan_file_path(embedding_file)
|
||||
if scan_result.infected_files > 0:
|
||||
print(
|
||||
f" ** Security Issues Found in Model: {scan_result.issues_count}"
|
||||
logger.critical(
|
||||
f"Security Issues Found in Model: {scan_result.issues_count}"
|
||||
)
|
||||
print(" ** For your safety, InvokeAI will not load this embed.")
|
||||
logger.critical("For your safety, InvokeAI will not load this embed.")
|
||||
return list()
|
||||
ckpt = torch.load(embedding_file,map_location="cpu")
|
||||
else:
|
||||
ckpt = safetensors.torch.load_file(embedding_file)
|
||||
except Exception as e:
|
||||
print(f" ** Notice: unrecognized embedding file format: {embedding_file}: {e}")
|
||||
logger.warning(f"Notice: unrecognized embedding file format: {embedding_file}: {e}")
|
||||
return list()
|
||||
|
||||
# try to figure out what kind of embedding file it is and parse accordingly
|
||||
@@ -334,7 +335,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
|
||||
def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
|
||||
basename = Path(file_path).stem
|
||||
print(f' | Loading v1 embedding file: {basename}')
|
||||
logger.debug(f'Loading v1 embedding file: {basename}')
|
||||
|
||||
embeddings = list()
|
||||
token_counter = -1
|
||||
@@ -342,7 +343,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
if token_counter < 0:
|
||||
trigger = embedding_ckpt["name"]
|
||||
elif token_counter == 0:
|
||||
trigger = f'<basename>'
|
||||
trigger = '<basename>'
|
||||
else:
|
||||
trigger = f'<{basename}-{int(token_counter:=token_counter)}>'
|
||||
token_counter += 1
|
||||
@@ -365,7 +366,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
This handles embedding .pt file variant #2.
|
||||
"""
|
||||
basename = Path(file_path).stem
|
||||
print(f' | Loading v2 embedding file: {basename}')
|
||||
logger.debug(f'Loading v2 embedding file: {basename}')
|
||||
embeddings = list()
|
||||
|
||||
if isinstance(
|
||||
@@ -384,7 +385,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
)
|
||||
embeddings.append(embedding_info)
|
||||
else:
|
||||
print(f" ** {basename}: Unrecognized embedding format")
|
||||
logger.warning(f"{basename}: Unrecognized embedding format")
|
||||
|
||||
return embeddings
|
||||
|
||||
@@ -393,7 +394,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
Parse 'version 3' of the .pt textual inversion embedding files.
|
||||
"""
|
||||
basename = Path(file_path).stem
|
||||
print(f' | Loading v3 embedding file: {basename}')
|
||||
logger.debug(f'Loading v3 embedding file: {basename}')
|
||||
embedding = embedding_ckpt['emb_params']
|
||||
embedding_info = EmbeddingInfo(
|
||||
name = f'<{basename}>',
|
||||
@@ -411,11 +412,11 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
basename = Path(filepath).stem
|
||||
short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
|
||||
|
||||
print(f' | Loading v4 embedding file: {short_path}')
|
||||
logger.debug(f'Loading v4 embedding file: {short_path}')
|
||||
|
||||
embeddings = list()
|
||||
if list(embedding_ckpt.keys()) == 0:
|
||||
print(f" ** Invalid embeddings file: {short_path}")
|
||||
logger.warning(f"Invalid embeddings file: {short_path}")
|
||||
else:
|
||||
for token,embedding in embedding_ckpt.items():
|
||||
embedding_info = EmbeddingInfo(
|
||||
|
||||
invokeai/backend/util/logging.py (new file, 109 lines)
@@ -0,0 +1,109 @@
# Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team

"""invokeai.util.logging

Logging class for InvokeAI that produces console messages that follow
the conventions established in InvokeAI 1.X through 2.X.

One way to use it:

    from invokeai.backend.util.logging import InvokeAILogger

    logger = InvokeAILogger.getLogger(__name__)
    logger.critical('this is critical')
    logger.error('this is an error')
    logger.warning('this is a warning')
    logger.info('this is info')
    logger.debug('this is debugging')

Console messages:
    ### this is critical
    *** this is an error ***
    ** this is a warning
    >> this is info
       | this is debugging

Another way:
    import invokeai.backend.util.logging as ialog
    ialog.debug('this is a debugging message')
"""
import logging


# module-level convenience functions that delegate to the shared "invokeai" logger
def debug(msg, *args, **kwargs):
    InvokeAILogger.getLogger().debug(msg, *args, **kwargs)

def info(msg, *args, **kwargs):
    InvokeAILogger.getLogger().info(msg, *args, **kwargs)

def warning(msg, *args, **kwargs):
    InvokeAILogger.getLogger().warning(msg, *args, **kwargs)

def error(msg, *args, **kwargs):
    InvokeAILogger.getLogger().error(msg, *args, **kwargs)

def critical(msg, *args, **kwargs):
    InvokeAILogger.getLogger().critical(msg, *args, **kwargs)

def log(level, msg, *args, **kwargs):
    InvokeAILogger.getLogger().log(level, msg, *args, **kwargs)

def disable(level=logging.CRITICAL):
    # disable() and basicConfig() are module-level functions of the logging
    # package, not Logger methods, so delegate to the logging module directly
    logging.disable(level)

def basicConfig(**kwargs):
    logging.basicConfig(**kwargs)

def getLogger(name: str = None) -> logging.Logger:
    return InvokeAILogger.getLogger(name)


class InvokeAILogFormatter(logging.Formatter):
    '''
    Selects a per-level message format so console output keeps the legacy
    InvokeAI prefixes. Repurposed from:
    https://stackoverflow.com/questions/14844970/modifying-logging-message-format-based-on-message-logging-level-in-python3
    '''
    crit_fmt = "### %(msg)s"
    err_fmt = "*** %(msg)s"
    warn_fmt = "** %(msg)s"
    info_fmt = ">> %(msg)s"
    dbg_fmt = "   | %(msg)s"

    def __init__(self):
        super().__init__(fmt="%(levelno)d: %(msg)s", datefmt=None, style='%')

    def format(self, record):
        # Remember the format used when the logging module was installed
        # (in the event that this formatter is used with the vanilla logging module).
        format_orig = self._style._fmt
        if record.levelno == logging.DEBUG:
            self._style._fmt = InvokeAILogFormatter.dbg_fmt
        if record.levelno == logging.INFO:
            self._style._fmt = InvokeAILogFormatter.info_fmt
        if record.levelno == logging.WARNING:
            self._style._fmt = InvokeAILogFormatter.warn_fmt
        if record.levelno == logging.ERROR:
            self._style._fmt = InvokeAILogFormatter.err_fmt
        if record.levelno == logging.CRITICAL:
            self._style._fmt = InvokeAILogFormatter.crit_fmt

        # the parent class does the work, then the original format is restored
        result = super().format(record)
        self._style._fmt = format_orig
        return result


class InvokeAILogger(object):
    loggers = dict()

    @classmethod
    def getLogger(cls, name: str = 'invokeai') -> logging.Logger:
        if name not in cls.loggers:
            logger = logging.getLogger(name)
            logger.setLevel(logging.DEBUG)
            ch = logging.StreamHandler()
            fmt = InvokeAILogFormatter()
            ch.setFormatter(fmt)
            logger.addHandler(ch)
            cls.loggers[name] = logger
        return cls.loggers[name]
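The call sites converted throughout this diff follow one pattern: import this module under the name logger (or obtain a named logger via getLogger) and replace print() calls with level-appropriate logging calls. A minimal usage sketch, assuming the module is importable at the path shown in the file header above; the message strings are just examples echoing ones seen elsewhere in the diff:

    import invokeai.backend.util.logging as logger

    logger.info("Model loaded in 4.20s")       # console: ">> Model loaded in 4.20s"
    logger.warning("GFPGAN not initialized.")  # console: "** GFPGAN not initialized."
    logger.debug("Scanning Model: foo.ckpt")   # console: "   | Scanning Model: foo.ckpt"

The prefixes come from InvokeAILogFormatter, which swaps in a per-level format string so the console output keeps the legacy >>, **, *** and ### markers.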
@@ -18,6 +18,7 @@ import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from tqdm import tqdm
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from .devices import torch_dtype
|
||||
|
||||
|
||||
@@ -38,7 +39,7 @@ def log_txt_as_img(wh, xc, size=10):
|
||||
try:
|
||||
draw.text((0, 0), lines, fill="black", font=font)
|
||||
except UnicodeEncodeError:
|
||||
print("Cant encode string for logging. Skipping.")
|
||||
logger.warning("Cant encode string for logging. Skipping.")
|
||||
|
||||
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
||||
txts.append(txt)
|
||||
@@ -80,8 +81,8 @@ def mean_flat(tensor):
|
||||
def count_params(model, verbose=False):
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
if verbose:
|
||||
print(
|
||||
f" | {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
|
||||
logger.debug(
|
||||
f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
|
||||
)
|
||||
return total_params
|
||||
|
||||
@@ -132,8 +133,8 @@ def parallel_data_prefetch(
|
||||
raise ValueError("list expected but function got ndarray.")
|
||||
elif isinstance(data, abc.Iterable):
|
||||
if isinstance(data, dict):
|
||||
print(
|
||||
'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
|
||||
logger.warning(
|
||||
'"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
|
||||
)
|
||||
data = list(data.values())
|
||||
if target_data_type == "ndarray":
|
||||
@@ -175,7 +176,7 @@ def parallel_data_prefetch(
|
||||
processes += [p]
|
||||
|
||||
# start processes
|
||||
print("Start prefetching...")
|
||||
logger.info("Start prefetching...")
|
||||
import time
|
||||
|
||||
start = time.time()
|
||||
@@ -194,7 +195,7 @@ def parallel_data_prefetch(
|
||||
gather_res[res[0]] = res[1]
|
||||
|
||||
except Exception as e:
|
||||
print("Exception: ", e)
|
||||
logger.error("Exception: ", e)
|
||||
for p in processes:
|
||||
p.terminate()
|
||||
|
||||
@@ -202,7 +203,7 @@ def parallel_data_prefetch(
|
||||
finally:
|
||||
for p in processes:
|
||||
p.join()
|
||||
print(f"Prefetching complete. [{time.time() - start} sec.]")
|
||||
logger.info(f"Prefetching complete. [{time.time() - start} sec.]")
|
||||
|
||||
if target_data_type == "ndarray":
|
||||
if not isinstance(gather_res[0], np.ndarray):
|
||||
@@ -318,23 +319,23 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
|
||||
resp = requests.get(url, headers=header, stream=True) # new request with range
|
||||
|
||||
if exist_size > content_length:
|
||||
print("* corrupt existing file found. re-downloading")
|
||||
logger.warning("corrupt existing file found. re-downloading")
|
||||
os.remove(dest)
|
||||
exist_size = 0
|
||||
|
||||
if resp.status_code == 416 or exist_size == content_length:
|
||||
print(f"* {dest}: complete file found. Skipping.")
|
||||
logger.warning(f"{dest}: complete file found. Skipping.")
|
||||
return dest
|
||||
elif resp.status_code == 206 or exist_size > 0:
|
||||
print(f"* {dest}: partial file found. Resuming...")
|
||||
logger.warning(f"{dest}: partial file found. Resuming...")
|
||||
elif resp.status_code != 200:
|
||||
print(f"** An error occurred during downloading {dest}: {resp.reason}")
|
||||
logger.error(f"An error occurred during downloading {dest}: {resp.reason}")
|
||||
else:
|
||||
print(f"* {dest}: Downloading...")
|
||||
logger.error(f"{dest}: Downloading...")
|
||||
|
||||
try:
|
||||
if content_length < 2000:
|
||||
print(f"*** ERROR DOWNLOADING {url}: {resp.text}")
|
||||
logger.error(f"ERROR DOWNLOADING {url}: {resp.text}")
|
||||
return None
|
||||
|
||||
with open(dest, open_mode) as file, tqdm(
|
||||
@@ -349,7 +350,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
|
||||
size = file.write(data)
|
||||
bar.update(size)
|
||||
except Exception as e:
|
||||
print(f"An error occurred while downloading {dest}: {str(e)}")
|
||||
logger.error(f"An error occurred while downloading {dest}: {str(e)}")
|
||||
return None
|
||||
|
||||
return dest
|
||||
|
||||
@@ -19,6 +19,7 @@ from PIL import Image
|
||||
from PIL.Image import Image as ImageType
|
||||
from werkzeug.utils import secure_filename
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
import invokeai.frontend.web.dist as frontend
|
||||
|
||||
from .. import Generate
|
||||
@@ -77,7 +78,6 @@ class InvokeAIWebServer:
|
||||
mimetypes.add_type("application/javascript", ".js")
|
||||
mimetypes.add_type("text/css", ".css")
|
||||
# Socket IO
|
||||
logger = True if args.web_verbose else False
|
||||
engineio_logger = True if args.web_verbose else False
|
||||
max_http_buffer_size = 10000000
|
||||
|
||||
@@ -213,7 +213,7 @@ class InvokeAIWebServer:
|
||||
self.load_socketio_listeners(self.socketio)
|
||||
|
||||
if args.gui:
|
||||
print(">> Launching Invoke AI GUI")
|
||||
logger.info("Launching Invoke AI GUI")
|
||||
try:
|
||||
from flaskwebgui import FlaskUI
|
||||
|
||||
@@ -231,17 +231,17 @@ class InvokeAIWebServer:
|
||||
sys.exit(0)
|
||||
else:
|
||||
useSSL = args.certfile or args.keyfile
|
||||
print(">> Started Invoke AI Web Server")
|
||||
logger.info("Started Invoke AI Web Server")
|
||||
if self.host == "0.0.0.0":
|
||||
print(
|
||||
logger.info(
|
||||
f"Point your browser at http{'s' if useSSL else ''}://localhost:{self.port} or use the host's DNS name or IP address."
|
||||
)
|
||||
else:
|
||||
print(
|
||||
">> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address."
|
||||
logger.info(
|
||||
"Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address."
|
||||
)
|
||||
print(
|
||||
f">> Point your browser at http{'s' if useSSL else ''}://{self.host}:{self.port}"
|
||||
logger.info(
|
||||
f"Point your browser at http{'s' if useSSL else ''}://{self.host}:{self.port}"
|
||||
)
|
||||
if not useSSL:
|
||||
self.socketio.run(app=self.app, host=self.host, port=self.port)
|
||||
@@ -273,7 +273,7 @@ class InvokeAIWebServer:
|
||||
# path for thumbnail images
|
||||
self.thumbnail_image_path = os.path.join(self.result_path, "thumbnails/")
|
||||
# txt log
|
||||
self.log_path = os.path.join(self.result_path, "invoke_log.txt")
|
||||
self.log_path = os.path.join(self.result_path, "invoke_logger.txt")
|
||||
# make all output paths
|
||||
[
|
||||
os.makedirs(path, exist_ok=True)
|
||||
@@ -290,7 +290,7 @@ class InvokeAIWebServer:
|
||||
def load_socketio_listeners(self, socketio):
|
||||
@socketio.on("requestSystemConfig")
|
||||
def handle_request_capabilities():
|
||||
print(">> System config requested")
|
||||
logger.info("System config requested")
|
||||
config = self.get_system_config()
|
||||
config["model_list"] = self.generate.model_manager.list_models()
|
||||
config["infill_methods"] = infill_methods()
|
||||
@@ -330,7 +330,7 @@ class InvokeAIWebServer:
|
||||
if model_name in current_model_list:
|
||||
update = True
|
||||
|
||||
print(f">> Adding New Model: {model_name}")
|
||||
logger.info(f"Adding New Model: {model_name}")
|
||||
|
||||
self.generate.model_manager.add_model(
|
||||
model_name=model_name,
|
||||
@@ -348,14 +348,14 @@ class InvokeAIWebServer:
|
||||
"update": update,
|
||||
},
|
||||
)
|
||||
print(f">> New Model Added: {model_name}")
|
||||
logger.info(f"New Model Added: {model_name}")
|
||||
except Exception as e:
|
||||
self.handle_exceptions(e)
|
||||
|
||||
@socketio.on("deleteModel")
|
||||
def handle_delete_model(model_name: str):
|
||||
try:
|
||||
print(f">> Deleting Model: {model_name}")
|
||||
logger.info(f"Deleting Model: {model_name}")
|
||||
self.generate.model_manager.del_model(model_name)
|
||||
self.generate.model_manager.commit(opt.conf)
|
||||
updated_model_list = self.generate.model_manager.list_models()
|
||||
@@ -366,14 +366,14 @@ class InvokeAIWebServer:
|
||||
"model_list": updated_model_list,
|
||||
},
|
||||
)
|
||||
print(f">> Model Deleted: {model_name}")
|
||||
logger.info(f"Model Deleted: {model_name}")
|
||||
except Exception as e:
|
||||
self.handle_exceptions(e)
|
||||
|
||||
@socketio.on("requestModelChange")
|
||||
def handle_set_model(model_name: str):
|
||||
try:
|
||||
print(f">> Model change requested: {model_name}")
|
||||
logger.info(f"Model change requested: {model_name}")
|
||||
model = self.generate.set_model(model_name)
|
||||
model_list = self.generate.model_manager.list_models()
|
||||
if model is None:
|
||||
@@ -454,7 +454,7 @@ class InvokeAIWebServer:
|
||||
"update": True,
|
||||
},
|
||||
)
|
||||
print(f">> Model Converted: {model_name}")
|
||||
logger.info(f"Model Converted: {model_name}")
|
||||
except Exception as e:
|
||||
self.handle_exceptions(e)
|
||||
|
||||
@@ -490,7 +490,7 @@ class InvokeAIWebServer:
|
||||
if vae := self.generate.model_manager.config[models_to_merge[0]].get(
|
||||
"vae", None
|
||||
):
|
||||
print(f">> Using configured VAE assigned to {models_to_merge[0]}")
|
||||
logger.info(f"Using configured VAE assigned to {models_to_merge[0]}")
|
||||
merged_model_config.update(vae=vae)
|
||||
|
||||
self.generate.model_manager.import_diffuser_model(
|
||||
@@ -507,8 +507,8 @@ class InvokeAIWebServer:
|
||||
"update": True,
|
||||
},
|
||||
)
|
||||
print(f">> Models Merged: {models_to_merge}")
|
||||
print(f">> New Model Added: {model_merge_info['merged_model_name']}")
|
||||
logger.info(f"Models Merged: {models_to_merge}")
|
||||
logger.info(f"New Model Added: {model_merge_info['merged_model_name']}")
|
||||
except Exception as e:
|
||||
self.handle_exceptions(e)
|
||||
|
||||
@@ -698,7 +698,7 @@ class InvokeAIWebServer:
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
print(f">> Unable to load {path}")
|
||||
logger.info(f"Unable to load {path}")
|
||||
socketio.emit(
|
||||
"error", {"message": f"Unable to load {path}: {str(e)}"}
|
||||
)
|
||||
@@ -735,9 +735,9 @@ class InvokeAIWebServer:
|
||||
printable_parameters["init_mask"][:64] + "..."
|
||||
)
|
||||
|
||||
print(f"\n>> Image Generation Parameters:\n\n{printable_parameters}\n")
|
||||
print(f">> ESRGAN Parameters: {esrgan_parameters}")
|
||||
print(f">> Facetool Parameters: {facetool_parameters}")
|
||||
logger.info(f"Image Generation Parameters:\n\n{printable_parameters}\n")
|
||||
logger.info(f"ESRGAN Parameters: {esrgan_parameters}")
|
||||
logger.info(f"Facetool Parameters: {facetool_parameters}")
|
||||
|
||||
self.generate_images(
|
||||
generation_parameters,
|
||||
@@ -750,8 +750,8 @@ class InvokeAIWebServer:
|
||||
@socketio.on("runPostprocessing")
|
||||
def handle_run_postprocessing(original_image, postprocessing_parameters):
|
||||
try:
|
||||
print(
|
||||
f'>> Postprocessing requested for "{original_image["url"]}": {postprocessing_parameters}'
|
||||
logger.info(
|
||||
f'Postprocessing requested for "{original_image["url"]}": {postprocessing_parameters}'
|
||||
)
|
||||
|
||||
progress = Progress()
|
||||
@@ -861,14 +861,14 @@ class InvokeAIWebServer:
|
||||
|
||||
@socketio.on("cancel")
|
||||
def handle_cancel():
|
||||
print(">> Cancel processing requested")
|
||||
logger.info("Cancel processing requested")
|
||||
self.canceled.set()
|
||||
|
||||
# TODO: I think this needs a safety mechanism.
|
||||
@socketio.on("deleteImage")
|
||||
def handle_delete_image(url, thumbnail, uuid, category):
|
||||
try:
|
||||
print(f'>> Delete requested "{url}"')
|
||||
logger.info(f'Delete requested "{url}"')
|
||||
from send2trash import send2trash
|
||||
|
||||
path = self.get_image_path_from_url(url)
|
||||
@@ -1263,7 +1263,7 @@ class InvokeAIWebServer:
|
||||
image, os.path.basename(path), self.thumbnail_image_path
|
||||
)
|
||||
|
||||
print(f'\n\n>> Image generated: "{path}"\n')
|
||||
logger.info(f'Image generated: "{path}"\n')
|
||||
self.write_log_message(f'[Generated] "{path}": {command}')
|
||||
|
||||
if progress.total_iterations > progress.current_iteration:
|
||||
@@ -1329,7 +1329,7 @@ class InvokeAIWebServer:
|
||||
except Exception as e:
|
||||
# Clear the CUDA cache on an exception
|
||||
self.empty_cuda_cache()
|
||||
print(e)
|
||||
logger.error(e)
|
||||
self.handle_exceptions(e)
|
||||
|
||||
def empty_cuda_cache(self):
|
||||
|
||||
@@ -16,6 +16,7 @@ if sys.platform == "darwin":
import pyparsing # type: ignore

import invokeai.version as invokeai
import invokeai.backend.util.logging as logger

from ...backend import Generate, ModelManager
from ...backend.args import Args, dream_cmd_from_png, metadata_dumps, metadata_from_png
|
||||
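The hunks in this file swap bare print(">> ...") calls for the invokeai.backend.util.logging module imported above. A minimal sketch of the call pattern the new call sites rely on, assuming the module exposes stdlib-style info/warning/error/critical helpers (the wrapper function and its messages below are illustrative, not part of the changeset):

import invokeai.backend.util.logging as logger

def report_startup(root_dir: str, internet_available: bool) -> None:
    # logger.info replaces print(f">> ..."); the ">>" prefix is dropped,
    # presumably because the logging module adds its own level prefix.
    logger.info(f'InvokeAI runtime directory is "{root_dir}"')
    logger.info(f"Internet connectivity is {internet_available}")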
@@ -69,7 +70,7 @@ def main():
# run any post-install patches needed
run_patches()

print(f">> Internet connectivity is {Globals.internet_available}")
logger.info(f"Internet connectivity is {Globals.internet_available}")

if not args.conf:
config_file = os.path.join(Globals.root, "configs", "models.yaml")
@@ -78,8 +79,8 @@ def main():
opt, FileNotFoundError(f"The file {config_file} could not be found.")
)

print(f">> {invokeai.__app_name__}, version {invokeai.__version__}")
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
logger.info(f"{invokeai.__app_name__}, version {invokeai.__version__}")
logger.info(f'InvokeAI runtime directory is "{Globals.root}"')

# loading here to avoid long delays on startup
# these two lines prevent a horrible warning message from appearing
@@ -121,7 +122,7 @@ def main():
else:
raise FileNotFoundError(f"{opt.infile} not found.")
except (FileNotFoundError, IOError) as e:
print(f"{e}. Aborting.")
logger.critical('Aborted',exc_info=True)
sys.exit(-1)

# creating a Generate object:
|
||||
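Several hunks above and below collapse a print(e) / traceback.print_exc() pair into one logging call. A condensed sketch of that pattern, using only calls that appear in this diff (the wrapper function is illustrative):

import sys
import invokeai.backend.util.logging as logger

def load_or_abort(loader) -> None:
    try:
        loader()
    except (FileNotFoundError, IOError):
        # exc_info=True attaches the active exception and its traceback to the
        # log record, so the separate print(f"{e}. Aborting.") line is no longer needed.
        logger.critical("Aborted", exc_info=True)
        sys.exit(-1)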
@@ -142,12 +143,12 @@ def main():
)
except (FileNotFoundError, TypeError, AssertionError) as e:
report_model_error(opt, e)
except (IOError, KeyError) as e:
print(f"{e}. Aborting.")
except (IOError, KeyError):
logger.critical("Aborted",exc_info=True)
sys.exit(-1)

if opt.seamless:
print(">> changed to seamless tiling mode")
logger.info("Changed to seamless tiling mode")

# preload the model
try:
@@ -180,9 +181,7 @@ def main():
f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}'
)
except Exception:
print(">> An error occurred:")
traceback.print_exc()

logger.error("An error occurred",exc_info=True)

# TODO: main_loop() has gotten busy. Needs to be refactored.
def main_loop(gen, opt):
|
||||
@@ -248,7 +247,7 @@ def main_loop(gen, opt):
|
||||
if not opt.prompt:
|
||||
oldargs = metadata_from_png(opt.init_img)
|
||||
opt.prompt = oldargs.prompt
|
||||
print(f'>> Retrieved old prompt "{opt.prompt}" from {opt.init_img}')
|
||||
logger.info(f'Retrieved old prompt "{opt.prompt}" from {opt.init_img}')
|
||||
except (OSError, AttributeError, KeyError):
|
||||
pass
|
||||
|
||||
@@ -265,9 +264,9 @@ def main_loop(gen, opt):
|
||||
if opt.init_img is not None and re.match("^-\\d+$", opt.init_img):
|
||||
try:
|
||||
opt.init_img = last_results[int(opt.init_img)][0]
|
||||
print(f">> Reusing previous image {opt.init_img}")
|
||||
logger.info(f"Reusing previous image {opt.init_img}")
|
||||
except IndexError:
|
||||
print(f">> No previous initial image at position {opt.init_img} found")
|
||||
logger.info(f"No previous initial image at position {opt.init_img} found")
|
||||
opt.init_img = None
|
||||
continue
|
||||
|
||||
@@ -288,9 +287,9 @@ def main_loop(gen, opt):
|
||||
if opt.seed is not None and opt.seed < 0 and operation != "postprocess":
|
||||
try:
|
||||
opt.seed = last_results[opt.seed][1]
|
||||
print(f">> Reusing previous seed {opt.seed}")
|
||||
logger.info(f"Reusing previous seed {opt.seed}")
|
||||
except IndexError:
|
||||
print(f">> No previous seed at position {opt.seed} found")
|
||||
logger.info(f"No previous seed at position {opt.seed} found")
|
||||
opt.seed = None
|
||||
continue
|
||||
|
||||
@@ -309,7 +308,7 @@ def main_loop(gen, opt):
|
||||
subdir = subdir[: (path_max - 39 - len(os.path.abspath(opt.outdir)))]
|
||||
current_outdir = os.path.join(opt.outdir, subdir)
|
||||
|
||||
print('Writing files to directory: "' + current_outdir + '"')
|
||||
logger.info('Writing files to directory: "' + current_outdir + '"')
|
||||
|
||||
# make sure the output directory exists
|
||||
if not os.path.exists(current_outdir):
|
||||
@@ -438,15 +437,14 @@ def main_loop(gen, opt):
|
||||
catch_interrupts=catch_ctrl_c,
|
||||
**vars(opt),
|
||||
)
|
||||
except (PromptParser.ParsingException, pyparsing.ParseException) as e:
|
||||
print("** An error occurred while processing your prompt **")
|
||||
print(f"** {str(e)} **")
|
||||
except (PromptParser.ParsingException, pyparsing.ParseException):
|
||||
logger.error("An error occurred while processing your prompt",exc_info=True)
|
||||
elif operation == "postprocess":
|
||||
print(f">> fixing {opt.prompt}")
|
||||
logger.info(f"fixing {opt.prompt}")
|
||||
opt.last_operation = do_postprocess(gen, opt, image_writer)
|
||||
|
||||
elif operation == "mask":
|
||||
print(f">> generating masks from {opt.prompt}")
|
||||
logger.info(f"generating masks from {opt.prompt}")
|
||||
do_textmask(gen, opt, image_writer)
|
||||
|
||||
if opt.grid and len(grid_images) > 0:
|
||||
@@ -469,12 +467,12 @@ def main_loop(gen, opt):
|
||||
)
|
||||
results = [[path, formatted_dream_prompt]]
|
||||
|
||||
except AssertionError as e:
|
||||
print(e)
|
||||
except AssertionError:
|
||||
logger.error(e)
|
||||
continue
|
||||
|
||||
except OSError as e:
|
||||
print(e)
|
||||
logger.error(e)
|
||||
continue
|
||||
|
||||
print("Outputs:")
|
||||
@@ -513,7 +511,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
||||
gen.set_model(model_name)
|
||||
add_embedding_terms(gen, completer)
|
||||
except KeyError as e:
|
||||
print(str(e))
|
||||
logger.error(e)
|
||||
except Exception as e:
|
||||
report_model_error(opt, e)
|
||||
completer.add_history(command)
|
||||
@@ -527,8 +525,8 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
||||
elif command.startswith("!import"):
|
||||
path = shlex.split(command)
|
||||
if len(path) < 2:
|
||||
print(
|
||||
"** please provide (1) a URL to a .ckpt file to import; (2) a local path to a .ckpt file; or (3) a diffusers repository id in the form stabilityai/stable-diffusion-2-1"
|
||||
logger.warning(
|
||||
"please provide (1) a URL to a .ckpt file to import; (2) a local path to a .ckpt file; or (3) a diffusers repository id in the form stabilityai/stable-diffusion-2-1"
|
||||
)
|
||||
else:
|
||||
try:
|
||||
@@ -541,7 +539,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
||||
elif command.startswith(("!convert", "!optimize")):
|
||||
path = shlex.split(command)
|
||||
if len(path) < 2:
|
||||
print("** please provide the path to a .ckpt or .safetensors model")
|
||||
logger.warning("please provide the path to a .ckpt or .safetensors model")
|
||||
else:
|
||||
try:
|
||||
convert_model(path[1], gen, opt, completer)
|
||||
@@ -553,7 +551,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
||||
elif command.startswith("!edit"):
|
||||
path = shlex.split(command)
|
||||
if len(path) < 2:
|
||||
print("** please provide the name of a model")
|
||||
logger.warning("please provide the name of a model")
|
||||
else:
|
||||
edit_model(path[1], gen, opt, completer)
|
||||
completer.add_history(command)
|
||||
@@ -562,7 +560,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
||||
elif command.startswith("!del"):
|
||||
path = shlex.split(command)
|
||||
if len(path) < 2:
|
||||
print("** please provide the name of a model")
|
||||
logger.warning("please provide the name of a model")
|
||||
else:
|
||||
del_config(path[1], gen, opt, completer)
|
||||
completer.add_history(command)
|
||||
@@ -642,8 +640,8 @@ def import_model(model_path: str, gen, opt, completer):
|
||||
try:
|
||||
default_name = url_attachment_name(model_path)
|
||||
default_name = Path(default_name).stem
|
||||
except Exception as e:
|
||||
print(f"** URL: {str(e)}")
|
||||
except Exception:
|
||||
logger.warning(f"A problem occurred while assigning the name of the downloaded model",exc_info=True)
|
||||
model_name, model_desc = _get_model_name_and_desc(
|
||||
gen.model_manager,
|
||||
completer,
|
||||
@@ -664,11 +662,11 @@ def import_model(model_path: str, gen, opt, completer):
|
||||
model_config_file=config_file,
|
||||
)
|
||||
if not imported_name:
|
||||
print("** Aborting import.")
|
||||
logger.error("Aborting import.")
|
||||
return
|
||||
|
||||
if not _verify_load(imported_name, gen):
|
||||
print("** model failed to load. Discarding configuration entry")
|
||||
logger.error("model failed to load. Discarding configuration entry")
|
||||
gen.model_manager.del_model(imported_name)
|
||||
return
|
||||
if click.confirm("Make this the default model?", default=False):
|
||||
@@ -676,7 +674,7 @@ def import_model(model_path: str, gen, opt, completer):
|
||||
|
||||
gen.model_manager.commit(opt.conf)
|
||||
completer.update_models(gen.model_manager.list_models())
|
||||
print(f">> {imported_name} successfully installed")
|
||||
logger.info(f"{imported_name} successfully installed")
|
||||
|
||||
def _pick_configuration_file(completer)->Path:
|
||||
print(
|
||||
@@ -720,21 +718,21 @@ Please select the type of this model:
|
||||
return choice
|
||||
|
||||
def _verify_load(model_name: str, gen) -> bool:
|
||||
print(">> Verifying that new model loads...")
|
||||
logger.info("Verifying that new model loads...")
|
||||
current_model = gen.model_name
|
||||
try:
|
||||
if not gen.set_model(model_name):
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"** model failed to load: {str(e)}")
|
||||
print(
|
||||
logger.warning(f"model failed to load: {str(e)}")
|
||||
logger.warning(
|
||||
"** note that importing 2.X checkpoints is not supported. Please use !convert_model instead."
|
||||
)
|
||||
return False
|
||||
if click.confirm("Keep model loaded?", default=True):
|
||||
gen.set_model(model_name)
|
||||
else:
|
||||
print(">> Restoring previous model")
|
||||
logger.info("Restoring previous model")
|
||||
gen.set_model(current_model)
|
||||
return True
|
||||
|
||||
@@ -757,7 +755,7 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
|
||||
ckpt_path = None
|
||||
original_config_file = None
|
||||
if model_name_or_path == gen.model_name:
|
||||
print("** Can't convert the active model. !switch to another model first. **")
|
||||
logger.warning("Can't convert the active model. !switch to another model first. **")
|
||||
return
|
||||
elif model_info := manager.model_info(model_name_or_path):
|
||||
if "weights" in model_info:
|
||||
@@ -767,7 +765,7 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
|
||||
model_description = model_info["description"]
|
||||
vae_path = model_info.get("vae")
|
||||
else:
|
||||
print(f"** {model_name_or_path} is not a legacy .ckpt weights file")
|
||||
logger.warning(f"{model_name_or_path} is not a legacy .ckpt weights file")
|
||||
return
|
||||
model_name = manager.convert_and_import(
|
||||
ckpt_path,
|
||||
@@ -788,16 +786,16 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
|
||||
manager.commit(opt.conf)
|
||||
if click.confirm(f"Delete the original .ckpt file at {ckpt_path}?", default=False):
|
||||
ckpt_path.unlink(missing_ok=True)
|
||||
print(f"{ckpt_path} deleted")
|
||||
logger.warning(f"{ckpt_path} deleted")
|
||||
|
||||
|
||||
def del_config(model_name: str, gen, opt, completer):
|
||||
current_model = gen.model_name
|
||||
if model_name == current_model:
|
||||
print("** Can't delete active model. !switch to another model first. **")
|
||||
logger.warning("Can't delete active model. !switch to another model first. **")
|
||||
return
|
||||
if model_name not in gen.model_manager.config:
|
||||
print(f"** Unknown model {model_name}")
|
||||
logger.warning(f"Unknown model {model_name}")
|
||||
return
|
||||
|
||||
if not click.confirm(
|
||||
@@ -810,17 +808,17 @@ def del_config(model_name: str, gen, opt, completer):
|
||||
)
|
||||
gen.model_manager.del_model(model_name, delete_files=delete_completely)
|
||||
gen.model_manager.commit(opt.conf)
|
||||
print(f"** {model_name} deleted")
|
||||
logger.warning(f"{model_name} deleted")
|
||||
completer.update_models(gen.model_manager.list_models())
|
||||
|
||||
|
||||
def edit_model(model_name: str, gen, opt, completer):
|
||||
manager = gen.model_manager
|
||||
if not (info := manager.model_info(model_name)):
|
||||
print(f"** Unknown model {model_name}")
|
||||
logger.warning(f"** Unknown model {model_name}")
|
||||
return
|
||||
|
||||
print(f"\n>> Editing model {model_name} from configuration file {opt.conf}")
|
||||
print()
|
||||
logger.info(f"Editing model {model_name} from configuration file {opt.conf}")
|
||||
new_name = _get_model_name(manager.list_models(), completer, model_name)
|
||||
|
||||
for attribute in info.keys():
|
||||
@@ -858,7 +856,7 @@ def edit_model(model_name: str, gen, opt, completer):
|
||||
manager.set_default_model(new_name)
|
||||
manager.commit(opt.conf)
|
||||
completer.update_models(manager.list_models())
|
||||
print(">> Model successfully updated")
|
||||
logger.info("Model successfully updated")
|
||||
|
||||
|
||||
def _get_model_name(existing_names, completer, default_name: str = "") -> str:
|
||||
@@ -869,11 +867,11 @@ def _get_model_name(existing_names, completer, default_name: str = "") -> str:
if len(model_name) == 0:
model_name = default_name
if not re.match("^[\w._+:/-]+$", model_name):
print(
'** model name must contain only words, digits and the characters "._+:/-" **'
logger.warning(
'model name must contain only words, digits and the characters "._+:/-" **'
)
elif model_name != default_name and model_name in existing_names:
print(f"** the name {model_name} is already in use. Pick another.")
logger.warning(f"the name {model_name} is already in use. Pick another.")
else:
done = True
return model_name
|
||||
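_get_model_name accepts a name only when it matches the character class shown above. A self-contained check with a few illustrative inputs (the repository id example is taken from the !import help text in this diff):

import re

MODEL_NAME_RE = re.compile(r"^[\w._+:/-]+$")

# Word characters, digits and "._+:/-" are allowed; anything else is rejected.
assert MODEL_NAME_RE.match("stable-diffusion-1.5")
assert MODEL_NAME_RE.match("stabilityai/stable-diffusion-2-1")
assert MODEL_NAME_RE.match("my_model+v2")
assert not MODEL_NAME_RE.match("my model")   # spaces are not permitted
assert not MODEL_NAME_RE.match("model?")     # neither is "?"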
@@ -940,11 +938,10 @@ def do_postprocess(gen, opt, callback):
|
||||
opt=opt,
|
||||
)
|
||||
except OSError:
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print(f"** {file_path}: file could not be read")
|
||||
logger.error(f"{file_path}: file could not be read",exc_info=True)
|
||||
return
|
||||
except (KeyError, AttributeError):
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
logger.error(f"an error occurred while applying the {tool} postprocessor",exc_info=True)
|
||||
return
|
||||
return opt.last_operation
|
||||
|
||||
@@ -999,13 +996,13 @@ def prepare_image_metadata(
try:
filename = opt.fnformat.format(**wildcards)
except KeyError as e:
print(
f"** The filename format contains an unknown key '{e.args[0]}'. Will use {{prefix}}.{{seed}}.png' instead"
logger.error(
f"The filename format contains an unknown key '{e.args[0]}'. Will use {{prefix}}.{{seed}}.png' instead"
)
filename = f"{prefix}.{seed}.png"
except IndexError:
print(
"** The filename format is broken or complete. Will use '{prefix}.{seed}.png' instead"
logger.error(
"The filename format is broken or complete. Will use '{prefix}.{seed}.png' instead"
)
filename = f"{prefix}.{seed}.png"
|
||||
|
||||
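prepare_image_metadata formats the output filename from a user-supplied template and a dict of wildcards, falling back to "{prefix}.{seed}.png" when the template names an unknown key or is malformed. A toy reproduction of that fallback (template and wildcard values are illustrative):

wildcards = {"prefix": "000012", "seed": 31337}
fnformat = "{prefix}_{model}.png"   # "{model}" is not a known wildcard here

try:
    filename = fnformat.format(**wildcards)
except KeyError:
    # raised because "model" is missing; fall back to the default pattern, as above
    filename = f"{wildcards['prefix']}.{wildcards['seed']}.png"
except IndexError:
    # raised by templates that contain bare "{}" positional fields
    filename = f"{wildcards['prefix']}.{wildcards['seed']}.png"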
@@ -1094,14 +1091,14 @@ def split_variations(variations_string) -> list:
for part in variations_string.split(","):
seed_and_weight = part.split(":")
if len(seed_and_weight) != 2:
print(f'** Could not parse with_variation part "{part}"')
logger.warning(f'Could not parse with_variation part "{part}"')
broken = True
break
try:
seed = int(seed_and_weight[0])
weight = float(seed_and_weight[1])
except ValueError:
print(f'** Could not parse with_variation part "{part}"')
logger.warning(f'Could not parse with_variation part "{part}"')
broken = True
break
parts.append([seed, weight])
|
||||
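The hunk above parses the with_variations string, a comma-separated list of seed:weight pairs. The same loop condensed into a standalone helper for illustration (the original tracks a broken flag instead of returning early):

def parse_variations(variations_string: str):
    """'1234:0.5,5678:0.25' -> [[1234, 0.5], [5678, 0.25]]; None if any part is malformed."""
    parts = []
    for part in variations_string.split(","):
        seed_and_weight = part.split(":")
        if len(seed_and_weight) != 2:
            return None
        try:
            parts.append([int(seed_and_weight[0]), float(seed_and_weight[1])])
        except ValueError:
            return None
    return parts

assert parse_variations("1234:0.5,5678:0.25") == [[1234, 0.5], [5678, 0.25]]
assert parse_variations("1234") is None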
@@ -1125,23 +1122,23 @@ def load_face_restoration(opt):
|
||||
opt.gfpgan_model_path
|
||||
)
|
||||
else:
|
||||
print(">> Face restoration disabled")
|
||||
logger.info("Face restoration disabled")
|
||||
if opt.esrgan:
|
||||
esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
|
||||
else:
|
||||
print(">> Upscaling disabled")
|
||||
logger.info("Upscaling disabled")
|
||||
else:
|
||||
print(">> Face restoration and upscaling disabled")
|
||||
logger.info("Face restoration and upscaling disabled")
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print(">> You may need to install the ESRGAN and/or GFPGAN modules")
|
||||
logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
|
||||
return gfpgan, codeformer, esrgan
|
||||
|
||||
|
||||
def make_step_callback(gen, opt, prefix):
|
||||
destination = os.path.join(opt.outdir, "intermediates", prefix)
|
||||
os.makedirs(destination, exist_ok=True)
|
||||
print(f">> Intermediate images will be written into {destination}")
|
||||
logger.info(f"Intermediate images will be written into {destination}")
|
||||
|
||||
def callback(state: PipelineIntermediateState):
|
||||
latents = state.latents
|
||||
@@ -1183,21 +1180,20 @@ def retrieve_dream_command(opt, command, completer):
|
||||
try:
|
||||
cmd = dream_cmd_from_png(path)
|
||||
except OSError:
|
||||
print(f"## {tokens[0]}: file could not be read")
|
||||
logger.error(f"{tokens[0]}: file could not be read")
|
||||
except (KeyError, AttributeError, IndexError):
|
||||
print(f"## {tokens[0]}: file has no metadata")
|
||||
logger.error(f"{tokens[0]}: file has no metadata")
|
||||
except:
|
||||
print(f"## {tokens[0]}: file could not be processed")
|
||||
logger.error(f"{tokens[0]}: file could not be processed")
|
||||
if len(cmd) > 0:
|
||||
completer.set_line(cmd)
|
||||
|
||||
|
||||
def write_commands(opt, file_path: str, outfilepath: str):
|
||||
dir, basename = os.path.split(file_path)
|
||||
try:
|
||||
paths = sorted(list(Path(dir).glob(basename)))
|
||||
except ValueError:
|
||||
print(f'## "{basename}": unacceptable pattern')
|
||||
logger.error(f'"{basename}": unacceptable pattern')
|
||||
return
|
||||
|
||||
commands = []
|
||||
@@ -1206,9 +1202,9 @@ def write_commands(opt, file_path: str, outfilepath: str):
|
||||
try:
|
||||
cmd = dream_cmd_from_png(path)
|
||||
except (KeyError, AttributeError, IndexError):
|
||||
print(f"## {path}: file has no metadata")
|
||||
logger.error(f"{path}: file has no metadata")
|
||||
except:
|
||||
print(f"## {path}: file could not be processed")
|
||||
logger.error(f"{path}: file could not be processed")
|
||||
if cmd:
|
||||
commands.append(f"# {path}")
|
||||
commands.append(cmd)
|
||||
@@ -1218,18 +1214,18 @@ def write_commands(opt, file_path: str, outfilepath: str):
|
||||
outfilepath = os.path.join(opt.outdir, basename)
|
||||
with open(outfilepath, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(commands))
|
||||
print(f">> File {outfilepath} with commands created")
|
||||
logger.info(f"File {outfilepath} with commands created")
|
||||
|
||||
|
||||
def report_model_error(opt: Namespace, e: Exception):
|
||||
print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
|
||||
print(
|
||||
"** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
|
||||
logger.warning(f'An error occurred while attempting to initialize the model: "{str(e)}"')
|
||||
logger.warning(
|
||||
"This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
|
||||
)
|
||||
yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
|
||||
if yes_to_all:
|
||||
print(
|
||||
"** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
|
||||
logger.warning(
|
||||
"Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
|
||||
)
|
||||
else:
|
||||
if not click.confirm(
|
||||
@@ -1238,7 +1234,7 @@ def report_model_error(opt: Namespace, e: Exception):
):
return

print("invokeai-configure is launching....\n")
logger.info("invokeai-configure is launching....\n")

# Match arguments that were set on the CLI
# only the arguments accepted by the configuration script are parsed
@@ -1255,7 +1251,7 @@ def report_model_error(opt: Namespace, e: Exception):
from ..install import invokeai_configure

invokeai_configure()
print("** InvokeAI will now restart")
logger.warning("InvokeAI will now restart")
sys.argv = previous_args
main() # would rather do a os.exec(), but doesn't exist?
sys.exit(0)
|
||||
|
||||
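report_model_error above only relaunches invokeai-configure when the INVOKE_MODEL_RECONFIGURE environment variable is set or the user confirms interactively. The gate reduced to a sketch (the helper name and prompt text are illustrative; the diff elides the real prompt):

import os
import click

def should_reconfigure() -> bool:
    # Non-interactive path: a caller can force reconfiguration via the environment.
    if os.environ.get("INVOKE_MODEL_RECONFIGURE"):
        return True
    # Interactive path: ask before launching invokeai-configure.
    return click.confirm("Run invokeai-configure to repair the models file?", default=False)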
@@ -22,6 +22,7 @@ import torch
|
||||
from npyscreen import widget
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import Globals, global_config_dir
|
||||
|
||||
from ...backend.config.model_install_backend import (
|
||||
@@ -455,8 +456,8 @@ def main():
|
||||
Globals.root = os.path.expanduser(get_root(opt.root) or "")
|
||||
|
||||
if not global_config_dir().exists():
|
||||
print(
|
||||
">> Your InvokeAI root directory is not set up. Calling invokeai-configure."
|
||||
logger.info(
|
||||
"Your InvokeAI root directory is not set up. Calling invokeai-configure."
|
||||
)
|
||||
from invokeai.frontend.install import invokeai_configure
|
||||
|
||||
@@ -466,18 +467,18 @@ def main():
|
||||
try:
|
||||
select_and_download_models(opt)
|
||||
except AssertionError as e:
|
||||
print(str(e))
|
||||
logger.error(e)
|
||||
sys.exit(-1)
|
||||
except KeyboardInterrupt:
|
||||
print("\nGoodbye! Come back soon.")
|
||||
logger.info("Goodbye! Come back soon.")
|
||||
except widget.NotEnoughSpaceForWidget as e:
|
||||
if str(e).startswith("Height of 1 allocated"):
|
||||
print(
|
||||
"** Insufficient vertical space for the interface. Please make your window taller and try again"
|
||||
logger.error(
|
||||
"Insufficient vertical space for the interface. Please make your window taller and try again"
|
||||
)
|
||||
elif str(e).startswith("addwstr"):
|
||||
print(
|
||||
"** Insufficient horizontal space for the interface. Please make your window wider and try again."
|
||||
logger.error(
|
||||
"Insufficient horizontal space for the interface. Please make your window wider and try again."
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -27,6 +27,8 @@ from ...backend.globals import (
global_models_dir,
global_set_root,
)

import invokeai.backend.util.logging as logger
from ...backend.model_management import ModelManager
from ...frontend.install.widgets import FloatTitleSlider

@@ -113,7 +115,7 @@ def merge_diffusion_models_and_commit(
model_name=merged_model_name, description=f'Merge of models {", ".join(models)}'
)
if vae := model_manager.config[models[0]].get("vae", None):
print(f">> Using configured VAE assigned to {models[0]}")
logger.info(f"Using configured VAE assigned to {models[0]}")
import_args.update(vae=vae)
model_manager.import_diffuser_model(dump_path, **import_args)
model_manager.commit(config_file)
|
||||
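Both merge paths in this changeset (the web server hunk near the top and merge_diffusion_models_and_commit above) copy the first source model's "vae" entry into the merged model's import arguments. The guard isolated below; the helper name, config shape, and VAE repo id are illustrative, with model_manager.config treated as a mapping of model name to settings dict, which is what both call sites assume:

def inherit_vae(config: dict, models: list, import_args: dict) -> None:
    # The walrus expression fetches and truth-tests the first model's "vae" entry in one step.
    if vae := config[models[0]].get("vae", None):
        import_args.update(vae=vae)

config = {"sd-1.5": {"vae": {"repo_id": "stabilityai/sd-vae-ft-mse"}}, "other": {}}
import_args = {"model_name": "merged-model"}
inherit_vae(config, ["sd-1.5", "other"], import_args)
assert import_args["vae"] == {"repo_id": "stabilityai/sd-vae-ft-mse"}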
@@ -391,10 +393,8 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
for name in self.model_manager.model_names()
|
||||
if self.model_manager.model_info(name).get("format") == "diffusers"
|
||||
]
|
||||
print(model_names)
|
||||
return sorted(model_names)
|
||||
|
||||
|
||||
class Mergeapp(npyscreen.NPSAppManaged):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -414,7 +414,7 @@ def run_gui(args: Namespace):
|
||||
|
||||
args = mergeapp.merge_arguments
|
||||
merge_diffusion_models_and_commit(**args)
|
||||
print(f'>> Models merged into new model: "{args["merged_model_name"]}".')
|
||||
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
|
||||
|
||||
|
||||
def run_cli(args: Namespace):
|
||||
@@ -425,8 +425,8 @@ def run_cli(args: Namespace):
|
||||
|
||||
if not args.merged_model_name:
|
||||
args.merged_model_name = "+".join(args.models)
|
||||
print(
|
||||
f'>> No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
|
||||
logger.info(
|
||||
f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
|
||||
)
|
||||
|
||||
model_manager = ModelManager(OmegaConf.load(global_config_file()))
|
||||
@@ -435,7 +435,7 @@ def run_cli(args: Namespace):
|
||||
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
||||
|
||||
merge_diffusion_models_and_commit(**vars(args))
|
||||
print(f'>> Models merged into new model: "{args.merged_model_name}".')
|
||||
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
|
||||
|
||||
|
||||
def main():
|
||||
@@ -455,17 +455,16 @@ def main():
|
||||
run_cli(args)
|
||||
except widget.NotEnoughSpaceForWidget as e:
|
||||
if str(e).startswith("Height of 1 allocated"):
|
||||
print(
|
||||
"** You need to have at least two diffusers models defined in models.yaml in order to merge"
|
||||
logger.error(
|
||||
"You need to have at least two diffusers models defined in models.yaml in order to merge"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"** Not enough room for the user interface. Try making this window larger."
|
||||
logger.error(
|
||||
"Not enough room for the user interface. Try making this window larger."
|
||||
)
|
||||
sys.exit(-1)
|
||||
except Exception:
|
||||
print(">> An error occurred:")
|
||||
traceback.print_exc()
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
sys.exit(-1)
|
||||
except KeyboardInterrupt:
|
||||
sys.exit(-1)
|
||||
|
||||
@@ -20,6 +20,7 @@ import npyscreen
|
||||
from npyscreen import widget
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import Globals, global_set_root
|
||||
|
||||
from ...backend.training import do_textual_inversion_training, parse_args
|
||||
@@ -368,14 +369,14 @@ def copy_to_embeddings_folder(args: dict):
|
||||
dest_dir_name = args["placeholder_token"].strip("<>")
|
||||
destination = Path(Globals.root, "embeddings", dest_dir_name)
|
||||
os.makedirs(destination, exist_ok=True)
|
||||
print(f">> Training completed. Copying learned_embeds.bin into {str(destination)}")
|
||||
logger.info(f"Training completed. Copying learned_embeds.bin into {str(destination)}")
|
||||
shutil.copy(source, destination)
|
||||
if (
|
||||
input("Delete training logs and intermediate checkpoints? [y] ") or "y"
|
||||
).startswith(("y", "Y")):
|
||||
shutil.rmtree(Path(args["output_dir"]))
|
||||
else:
|
||||
print(f'>> Keeping {args["output_dir"]}')
|
||||
logger.info(f'Keeping {args["output_dir"]}')
|
||||
|
||||
|
||||
def save_args(args: dict):
|
||||
@@ -422,10 +423,10 @@ def do_front_end(args: Namespace):
|
||||
do_textual_inversion_training(**args)
|
||||
copy_to_embeddings_folder(args)
|
||||
except Exception as e:
|
||||
print("** An exception occurred during training. The exception was:")
|
||||
print(str(e))
|
||||
print("** DETAILS:")
|
||||
print(traceback.format_exc())
|
||||
logger.error("An exception occurred during training. The exception was:")
|
||||
logger.error(str(e))
|
||||
logger.error("DETAILS:")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
def main():
|
||||
@@ -437,21 +438,21 @@ def main():
|
||||
else:
|
||||
do_textual_inversion_training(**vars(args))
|
||||
except AssertionError as e:
|
||||
print(str(e))
|
||||
logger.error(e)
|
||||
sys.exit(-1)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
except (widget.NotEnoughSpaceForWidget, Exception) as e:
|
||||
if str(e).startswith("Height of 1 allocated"):
|
||||
print(
|
||||
"** You need to have at least one diffusers models defined in models.yaml in order to train"
|
||||
logger.error(
|
||||
"You need to have at least one diffusers models defined in models.yaml in order to train"
|
||||
)
|
||||
elif str(e).startswith("addwstr"):
|
||||
print(
|
||||
"** Not enough window space for the interface. Please make your window larger and try again."
|
||||
logger.error(
|
||||
"Not enough window space for the interface. Please make your window larger and try again."
|
||||
)
|
||||
else:
|
||||
print(f"** An error has occurred: {str(e)}")
|
||||
logger.error(e)
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
|
||||
@@ -76,6 +76,8 @@
|
||||
"i18next-http-backend": "^2.2.0",
|
||||
"konva": "^9.0.1",
|
||||
"lodash-es": "^4.17.21",
|
||||
"overlayscrollbars": "^2.1.1",
|
||||
"overlayscrollbars-react": "^0.5.0",
|
||||
"patch-package": "^7.0.0",
|
||||
"re-resizable": "^6.9.9",
|
||||
"react": "^18.2.0",
|
||||
@@ -91,6 +93,7 @@
|
||||
"react-rnd": "^10.4.1",
|
||||
"react-transition-group": "^4.4.5",
|
||||
"react-use": "^17.4.0",
|
||||
"react-virtuoso": "^4.3.5",
|
||||
"react-zoom-pan-pinch": "^3.0.7",
|
||||
"reactflow": "^11.7.0",
|
||||
"redux-deep-persist": "^1.0.7",
|
||||
|
||||
@@ -18,6 +18,8 @@ import '@fontsource/inter/600.css';
|
||||
import '@fontsource/inter/700.css';
|
||||
import '@fontsource/inter/800.css';
|
||||
import '@fontsource/inter/900.css';
|
||||
import 'overlayscrollbars/overlayscrollbars.css';
|
||||
import 'theme/css/overlayscrollbars.css';
|
||||
|
||||
type ThemeLocaleProviderProps = {
|
||||
children: ReactNode;
|
||||
|
||||
@@ -73,10 +73,12 @@ const rootPersistConfig = getPersistConfig({
|
||||
...modelsDenylist,
|
||||
...nodesDenylist,
|
||||
...postprocessingDenylist,
|
||||
...resultsDenylist,
|
||||
// ...resultsDenylist,
|
||||
'results',
|
||||
...systemDenylist,
|
||||
...uiDenylist,
|
||||
...uploadsDenylist,
|
||||
// ...uploadsDenylist,
|
||||
'uploads',
|
||||
'hotkeys',
|
||||
'config',
|
||||
],
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { forEach, size } from 'lodash-es';
|
||||
import { ImageField, LatentsField } from 'services/api';
|
||||
import { ImageField, LatentsField, ConditioningField } from 'services/api';
|
||||
|
||||
const OBJECT_TYPESTRING = '[object Object]';
|
||||
const STRING_TYPESTRING = '[object String]';
|
||||
@@ -74,8 +74,38 @@ const parseLatentsField = (latentsField: unknown): LatentsField | undefined => {
|
||||
};
|
||||
};
|
||||
|
||||
const parseConditioningField = (
conditioningField: unknown
): ConditioningField | undefined => {
// Must be an object
if (!isObject(conditioningField)) {
return;
}

// A ConditioningField must have a `conditioning_name`
if (!('conditioning_name' in conditioningField)) {
return;
}

// A ConditioningField's `conditioning_name` must be a string
if (typeof conditioningField.conditioning_name !== 'string') {
return;
}

// Build a valid ConditioningField
return {
conditioning_name: conditioningField.conditioning_name,
};
};
|
||||
|
||||
type NodeMetadata = {
|
||||
[key: string]: string | number | boolean | ImageField | LatentsField;
|
||||
[key: string]:
|
||||
| string
|
||||
| number
|
||||
| boolean
|
||||
| ImageField
|
||||
| LatentsField
|
||||
| ConditioningField;
|
||||
};
|
||||
|
||||
type InvokeAIMetadata = {
|
||||
@@ -101,7 +131,7 @@ export const parseNodeMetadata = (
|
||||
return;
|
||||
}
|
||||
|
||||
// the only valid object types are ImageField and LatentsField
|
||||
// the only valid object types are ImageField, LatentsField and ConditioningField
|
||||
if (isObject(nodeItem)) {
|
||||
if ('image_name' in nodeItem || 'image_type' in nodeItem) {
|
||||
const imageField = parseImageField(nodeItem);
|
||||
@@ -118,6 +148,14 @@ export const parseNodeMetadata = (
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if ('conditioning_name' in nodeItem) {
|
||||
const conditioningField = parseConditioningField(nodeItem);
|
||||
if (conditioningField) {
|
||||
parsed[nodeKey] = conditioningField;
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// otherwise we accept any string, number or boolean
|
||||
|
||||
@@ -5,6 +5,7 @@ import {
|
||||
Image,
|
||||
MenuItem,
|
||||
MenuList,
|
||||
Skeleton,
|
||||
useDisclosure,
|
||||
useTheme,
|
||||
useToast,
|
||||
@@ -12,7 +13,7 @@ import {
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { DragEvent, memo, useCallback, useState } from 'react';
|
||||
import { FaCheck, FaExpand, FaShare, FaTrash } from 'react-icons/fa';
|
||||
import { FaCheck, FaExpand, FaImage, FaShare, FaTrash } from 'react-icons/fa';
|
||||
import DeleteImageModal from './DeleteImageModal';
|
||||
import { ContextMenu } from 'chakra-ui-contextmenu';
|
||||
import * as InvokeAI from 'app/types/invokeai';
|
||||
@@ -268,58 +269,48 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
userSelect="none"
|
||||
draggable={true}
|
||||
onDragStart={handleDragStart}
|
||||
onClick={handleSelectImage}
|
||||
ref={ref}
|
||||
sx={{
|
||||
padding: 2,
|
||||
display: 'flex',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
transition: 'transform 0.2s ease-out',
|
||||
_hover: {
|
||||
cursor: 'pointer',
|
||||
|
||||
zIndex: 2,
|
||||
},
|
||||
_before: {
|
||||
content: '""',
|
||||
display: 'block',
|
||||
paddingBottom: '100%',
|
||||
},
|
||||
aspectRatio: '1/1',
|
||||
}}
|
||||
>
|
||||
<Image
|
||||
loading="lazy"
|
||||
objectFit={
|
||||
shouldUseSingleGalleryColumn ? 'contain' : galleryImageObjectFit
|
||||
}
|
||||
rounded="md"
|
||||
src={getUrl(thumbnail || url)}
|
||||
loading="lazy"
|
||||
fallback={<FaImage />}
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
maxWidth: '100%',
|
||||
maxHeight: '100%',
|
||||
top: '50%',
|
||||
transform: 'translate(-50%,-50%)',
|
||||
...(direction === 'rtl'
|
||||
? { insetInlineEnd: '50%' }
|
||||
: { insetInlineStart: '50%' }),
|
||||
}}
|
||||
/>
|
||||
<Flex
|
||||
onClick={handleSelectImage}
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: '0',
|
||||
insetInlineStart: '0',
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
}}
|
||||
>
|
||||
{isSelected && (
|
||||
{isSelected && (
|
||||
<Flex
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: '0',
|
||||
insetInlineStart: '0',
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
pointerEvents: 'none',
|
||||
}}
|
||||
>
|
||||
<Icon
|
||||
filter={'drop-shadow(0px 0px 1rem black)'}
|
||||
as={FaCheck}
|
||||
sx={{
|
||||
width: '50%',
|
||||
@@ -327,9 +318,9 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
fill: 'ok.500',
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
{isHovered && galleryImageMinimumWidth >= 64 && (
|
||||
</Flex>
|
||||
)}
|
||||
{isHovered && galleryImageMinimumWidth >= 100 && (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
|
||||
@@ -1,5 +1,13 @@
|
||||
import { ButtonGroup, Flex, Grid, Icon, Image, Text } from '@chakra-ui/react';
|
||||
// import { requestImages } from 'app/socketio/actions';
|
||||
import {
|
||||
Box,
|
||||
ButtonGroup,
|
||||
Flex,
|
||||
FlexProps,
|
||||
Grid,
|
||||
Icon,
|
||||
Text,
|
||||
forwardRef,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAICheckbox from 'common/components/IAICheckbox';
|
||||
@@ -15,28 +23,33 @@ import {
|
||||
setShouldUseSingleGalleryColumn,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { togglePinGalleryPanel } from 'features/ui/store/uiSlice';
|
||||
import { useOverlayScrollbars } from 'overlayscrollbars-react';
|
||||
|
||||
import { ChangeEvent, useEffect, useRef, useState } from 'react';
|
||||
import {
|
||||
ChangeEvent,
|
||||
PropsWithChildren,
|
||||
memo,
|
||||
useCallback,
|
||||
useEffect,
|
||||
useRef,
|
||||
useState,
|
||||
} from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { BsPinAngle, BsPinAngleFill } from 'react-icons/bs';
|
||||
import { FaImage, FaUser, FaWrench } from 'react-icons/fa';
|
||||
import { MdPhotoLibrary } from 'react-icons/md';
|
||||
import HoverableImage from './HoverableImage';
|
||||
|
||||
import Scrollable from 'features/ui/components/common/Scrollable';
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
import {
|
||||
resultsAdapter,
|
||||
selectResultsAll,
|
||||
selectResultsTotal,
|
||||
} from '../store/resultsSlice';
|
||||
import { resultsAdapter } from '../store/resultsSlice';
|
||||
import {
|
||||
receivedResultImagesPage,
|
||||
receivedUploadImagesPage,
|
||||
} from 'services/thunks/gallery';
|
||||
import { selectUploadsAll, uploadsAdapter } from '../store/uploadsSlice';
|
||||
import { uploadsAdapter } from '../store/uploadsSlice';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { Virtuoso, VirtuosoGrid } from 'react-virtuoso';
|
||||
|
||||
const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290;
|
||||
|
||||
@@ -68,16 +81,28 @@ const ImageGalleryContent = () => {
|
||||
const { t } = useTranslation();
|
||||
const resizeObserverRef = useRef<HTMLDivElement>(null);
|
||||
const [shouldShouldIconButtons, setShouldShouldIconButtons] = useState(true);
|
||||
const rootRef = useRef(null);
|
||||
const [scroller, setScroller] = useState<HTMLElement | null>(null);
|
||||
const [initialize, osInstance] = useOverlayScrollbars({
|
||||
defer: true,
|
||||
options: {
|
||||
scrollbars: {
|
||||
visibility: 'auto',
|
||||
autoHide: 'leave',
|
||||
autoHideDelay: 1300,
|
||||
theme: 'os-theme-dark',
|
||||
},
|
||||
overflow: { x: 'hidden' },
|
||||
},
|
||||
});
|
||||
|
||||
const {
|
||||
// images,
|
||||
currentCategory,
|
||||
shouldPinGallery,
|
||||
galleryImageMinimumWidth,
|
||||
galleryGridTemplateColumns,
|
||||
galleryImageObjectFit,
|
||||
shouldAutoSwitchToNewImages,
|
||||
// areMoreImagesAvailable,
|
||||
shouldUseSingleGalleryColumn,
|
||||
selectedImage,
|
||||
} = useAppSelector(imageGallerySelector);
|
||||
@@ -85,9 +110,6 @@ const ImageGalleryContent = () => {
|
||||
const { images, areMoreImagesAvailable, isLoading } =
|
||||
useAppSelector(gallerySelector);
|
||||
|
||||
// const handleClickLoadMore = () => {
|
||||
// dispatch(requestImages(currentCategory));
|
||||
// };
|
||||
const handleClickLoadMore = () => {
|
||||
if (currentCategory === 'results') {
|
||||
dispatch(receivedResultImagesPage());
|
||||
@@ -129,6 +151,25 @@ const ImageGalleryContent = () => {
|
||||
return () => resizeObserver.disconnect(); // clean up
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
const { current: root } = rootRef;
|
||||
if (scroller && root) {
|
||||
initialize({
|
||||
target: root,
|
||||
elements: {
|
||||
viewport: scroller,
|
||||
},
|
||||
});
|
||||
}
|
||||
return () => osInstance()?.destroy();
|
||||
}, [scroller, initialize, osInstance]);
|
||||
|
||||
const setScrollerRef = useCallback((ref: HTMLElement | Window | null) => {
|
||||
if (ref instanceof HTMLElement) {
|
||||
setScroller(ref);
|
||||
}
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Flex flexDirection="column" w="full" h="full" gap={4}>
|
||||
<Flex
|
||||
@@ -241,65 +282,119 @@ const ImageGalleryContent = () => {
|
||||
/>
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Scrollable>
|
||||
<Flex direction="column" gap={2} h="full">
|
||||
{images.length || areMoreImagesAvailable ? (
|
||||
<>
|
||||
<Grid
|
||||
gap={2}
|
||||
style={{ gridTemplateColumns: galleryGridTemplateColumns }}
|
||||
>
|
||||
{images.map((image) => {
|
||||
const { name } = image;
|
||||
const isSelected = selectedImage?.name === name;
|
||||
return (
|
||||
<HoverableImage
|
||||
key={`${name}-${image.thumbnail}`}
|
||||
image={image}
|
||||
isSelected={isSelected}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</Grid>
|
||||
<IAIButton
|
||||
onClick={handleClickLoadMore}
|
||||
isDisabled={!areMoreImagesAvailable}
|
||||
isLoading={isLoading}
|
||||
flexShrink={0}
|
||||
>
|
||||
{areMoreImagesAvailable
|
||||
? t('gallery.loadMore')
|
||||
: t('gallery.allImagesLoaded')}
|
||||
</IAIButton>
|
||||
</>
|
||||
) : (
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
gap: 2,
|
||||
padding: 8,
|
||||
h: '100%',
|
||||
w: '100%',
|
||||
color: 'base.500',
|
||||
}}
|
||||
<Flex direction="column" gap={2} h="full">
|
||||
{images.length || areMoreImagesAvailable ? (
|
||||
<>
|
||||
<Box ref={rootRef} data-overlayscrollbars="" h="100%">
|
||||
{shouldUseSingleGalleryColumn ? (
|
||||
<Virtuoso
|
||||
style={{ height: '100%' }}
|
||||
data={images}
|
||||
scrollerRef={(ref) => setScrollerRef(ref)}
|
||||
itemContent={(index, image) => {
|
||||
const { name } = image;
|
||||
const isSelected = selectedImage?.name === name;
|
||||
|
||||
return (
|
||||
<Flex sx={{ pb: 2 }}>
|
||||
<HoverableImage
|
||||
key={`${name}-${image.thumbnail}`}
|
||||
image={image}
|
||||
isSelected={isSelected}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
}}
|
||||
/>
|
||||
) : (
|
||||
<VirtuosoGrid
|
||||
style={{ height: '100%' }}
|
||||
data={images}
|
||||
components={{
|
||||
Item: ItemContainer,
|
||||
List: ListContainer,
|
||||
}}
|
||||
scrollerRef={setScroller}
|
||||
itemContent={(index, image) => {
|
||||
const { name } = image;
|
||||
const isSelected = selectedImage?.name === name;
|
||||
|
||||
return (
|
||||
<HoverableImage
|
||||
key={`${name}-${image.thumbnail}`}
|
||||
image={image}
|
||||
isSelected={isSelected}
|
||||
/>
|
||||
);
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</Box>
|
||||
<IAIButton
|
||||
onClick={handleClickLoadMore}
|
||||
isDisabled={!areMoreImagesAvailable}
|
||||
isLoading={isLoading}
|
||||
flexShrink={0}
|
||||
>
|
||||
<Icon
|
||||
as={MdPhotoLibrary}
|
||||
sx={{
|
||||
w: 16,
|
||||
h: 16,
|
||||
}}
|
||||
/>
|
||||
<Text textAlign="center">{t('gallery.noImagesInGallery')}</Text>
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
</Scrollable>
|
||||
{areMoreImagesAvailable
|
||||
? t('gallery.loadMore')
|
||||
: t('gallery.allImagesLoaded')}
|
||||
</IAIButton>
|
||||
</>
|
||||
) : (
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
gap: 2,
|
||||
padding: 8,
|
||||
h: '100%',
|
||||
w: '100%',
|
||||
color: 'base.500',
|
||||
}}
|
||||
>
|
||||
<Icon
|
||||
as={MdPhotoLibrary}
|
||||
sx={{
|
||||
w: 16,
|
||||
h: 16,
|
||||
}}
|
||||
/>
|
||||
<Text textAlign="center">{t('gallery.noImagesInGallery')}</Text>
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
ImageGalleryContent.displayName = 'ImageGalleryContent';
|
||||
export default ImageGalleryContent;
|
||||
type ItemContainerProps = PropsWithChildren & FlexProps;
|
||||
const ItemContainer = forwardRef((props: ItemContainerProps, ref) => (
|
||||
<Box className="item-container" ref={ref}>
|
||||
{props.children}
|
||||
</Box>
|
||||
));
|
||||
|
||||
type ListContainerProps = PropsWithChildren & FlexProps;
|
||||
const ListContainer = forwardRef((props: ListContainerProps, ref) => {
|
||||
const galleryImageMinimumWidth = useAppSelector(
|
||||
(state: RootState) => state.gallery.galleryImageMinimumWidth
|
||||
);
|
||||
|
||||
return (
|
||||
<Grid
|
||||
{...props}
|
||||
className="list-container"
|
||||
ref={ref}
|
||||
sx={{
|
||||
gap: 2,
|
||||
gridTemplateColumns: `repeat(auto-fit, minmax(${galleryImageMinimumWidth}px, 1fr));`,
|
||||
}}
|
||||
>
|
||||
{props.children}
|
||||
</Grid>
|
||||
);
|
||||
});
|
||||
|
||||
export default memo(ImageGalleryContent);
|
||||
|
||||
@@ -40,6 +40,8 @@ export const gallerySlice = createSlice({
|
||||
action: PayloadAction<SelectedImage | undefined>
|
||||
) => {
|
||||
state.selectedImage = action.payload;
|
||||
// TODO: if the user selects an image, disable the auto switch?
|
||||
// state.shouldAutoSwitchToNewImages = false;
|
||||
},
|
||||
setGalleryImageMinimumWidth: (state, action: PayloadAction<number>) => {
|
||||
state.galleryImageMinimumWidth = action.payload;
|
||||
|
||||
@@ -6,9 +6,11 @@ import BooleanInputFieldComponent from './fields/BooleanInputFieldComponent';
|
||||
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
|
||||
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
|
||||
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
|
||||
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
|
||||
import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
|
||||
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
|
||||
import StringInputFieldComponent from './fields/StringInputFieldComponent';
|
||||
import ItemInputFieldComponent from './fields/ItemInputFieldComponent';
|
||||
|
||||
type InputFieldComponentProps = {
|
||||
nodeId: string;
|
||||
@@ -84,6 +86,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'conditioning' && template.type === 'conditioning') {
|
||||
return (
|
||||
<ConditioningInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'model' && template.type === 'model') {
|
||||
return (
|
||||
<ModelInputFieldComponent
|
||||
@@ -104,6 +116,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'item' && template.type === 'item') {
|
||||
return (
|
||||
<ItemInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return <Box p={2}>Unknown field type: {type}</Box>;
|
||||
};
|
||||
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
import {
|
||||
ConditioningInputFieldTemplate,
|
||||
ConditioningInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo } from 'react';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
const ConditioningInputFieldComponent = (
|
||||
props: FieldComponentProps<
|
||||
ConditioningInputFieldValue,
|
||||
ConditioningInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
export default memo(ConditioningInputFieldComponent);
|
||||
@@ -0,0 +1,17 @@
|
||||
import {
|
||||
ItemInputFieldTemplate,
|
||||
ItemInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo } from 'react';
|
||||
import { FaAddressCard, FaList } from 'react-icons/fa';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
const ItemInputFieldComponent = (
|
||||
props: FieldComponentProps<ItemInputFieldValue, ItemInputFieldTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
|
||||
return <FaAddressCard />;
|
||||
};
|
||||
|
||||
export default memo(ItemInputFieldComponent);
|
||||
@@ -11,8 +11,10 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
|
||||
enum: 'enum',
|
||||
ImageField: 'image',
|
||||
LatentsField: 'latents',
|
||||
ConditioningField: 'conditioning',
|
||||
model: 'model',
|
||||
array: 'array',
|
||||
item: 'item',
|
||||
};
|
||||
|
||||
const COLOR_TOKEN_VALUE = 500;
|
||||
@@ -63,6 +65,12 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||
title: 'Latents',
|
||||
description: 'Latents may be passed between nodes.',
|
||||
},
|
||||
conditioning: {
|
||||
color: 'cyan',
|
||||
colorCssVar: getColorTokenCssVariable('cyan'),
|
||||
title: 'Conditioning',
|
||||
description: 'Conditioning may be passed between nodes.',
|
||||
},
|
||||
model: {
|
||||
color: 'teal',
|
||||
colorCssVar: getColorTokenCssVariable('teal'),
|
||||
@@ -75,4 +83,10 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||
title: 'Array',
|
||||
description: 'TODO: Array type description.',
|
||||
},
|
||||
item: {
|
||||
color: 'gray',
|
||||
colorCssVar: getColorTokenCssVariable('gray'),
|
||||
title: 'Collection Item',
|
||||
description: 'TODO: Collection Item type description.',
|
||||
},
|
||||
};
|
||||
|
||||
@@ -56,8 +56,10 @@ export type FieldType =
|
||||
| 'enum'
|
||||
| 'image'
|
||||
| 'latents'
|
||||
| 'conditioning'
|
||||
| 'model'
|
||||
| 'array';
|
||||
| 'array'
|
||||
| 'item';
|
||||
|
||||
/**
|
||||
* An input field is persisted across reloads as part of the user's local state.
|
||||
@@ -74,9 +76,11 @@ export type InputFieldValue =
|
||||
| BooleanInputFieldValue
|
||||
| ImageInputFieldValue
|
||||
| LatentsInputFieldValue
|
||||
| ConditioningInputFieldValue
|
||||
| EnumInputFieldValue
|
||||
| ModelInputFieldValue
|
||||
| ArrayInputFieldValue;
|
||||
| ArrayInputFieldValue
|
||||
| ItemInputFieldValue;
|
||||
|
||||
/**
|
||||
* An input field template is generated on each page load from the OpenAPI schema.
|
||||
@@ -91,9 +95,11 @@ export type InputFieldTemplate =
|
||||
| BooleanInputFieldTemplate
|
||||
| ImageInputFieldTemplate
|
||||
| LatentsInputFieldTemplate
|
||||
| ConditioningInputFieldTemplate
|
||||
| EnumInputFieldTemplate
|
||||
| ModelInputFieldTemplate
|
||||
| ArrayInputFieldTemplate;
|
||||
| ArrayInputFieldTemplate
|
||||
| ItemInputFieldTemplate;
|
||||
|
||||
/**
|
||||
* An output field is persisted across as part of the user's local state.
|
||||
@@ -162,6 +168,11 @@ export type LatentsInputFieldValue = FieldValueBase & {
|
||||
value?: undefined;
|
||||
};
|
||||
|
||||
export type ConditioningInputFieldValue = FieldValueBase & {
|
||||
type: 'conditioning';
|
||||
value?: undefined;
|
||||
};
|
||||
|
||||
export type ImageInputFieldValue = FieldValueBase & {
|
||||
type: 'image';
|
||||
value?: Pick<ImageField, 'image_name' | 'image_type'>;
|
||||
@@ -177,6 +188,11 @@ export type ArrayInputFieldValue = FieldValueBase & {
|
||||
value?: (string | number)[];
|
||||
};
|
||||
|
||||
export type ItemInputFieldValue = FieldValueBase & {
|
||||
type: 'item';
|
||||
value?: undefined;
|
||||
};
|
||||
|
||||
export type InputFieldTemplateBase = {
|
||||
name: string;
|
||||
title: string;
|
||||
@@ -229,6 +245,11 @@ export type LatentsInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'latents';
|
||||
};
|
||||
|
||||
export type ConditioningInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
type: 'conditioning';
|
||||
};
|
||||
|
||||
export type EnumInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: string | number;
|
||||
type: 'enum';
|
||||
@@ -242,10 +263,15 @@ export type ModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||
};
|
||||
|
||||
export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: (string | number)[];
|
||||
default: [];
|
||||
type: 'array';
|
||||
};
|
||||
|
||||
export type ItemInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
type: 'item';
|
||||
};
|
||||
|
||||
/**
|
||||
* JANKY CUSTOMISATION OF OpenAPI SCHEMA TYPES
|
||||
*/
|
||||
|
||||
@@ -9,12 +9,15 @@ import {
|
||||
ImageInputFieldTemplate,
|
||||
IntegerInputFieldTemplate,
|
||||
LatentsInputFieldTemplate,
|
||||
ConditioningInputFieldTemplate,
|
||||
StringInputFieldTemplate,
|
||||
ModelInputFieldTemplate,
|
||||
InputFieldTemplateBase,
|
||||
OutputFieldTemplate,
|
||||
TypeHints,
|
||||
FieldType,
|
||||
ArrayInputFieldTemplate,
|
||||
ItemInputFieldTemplate,
|
||||
} from '../types/types';
|
||||
|
||||
export type BaseFieldProperties = 'name' | 'title' | 'description';
|
||||
@@ -196,6 +199,21 @@ const buildLatentsInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildConditioningInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ConditioningInputFieldTemplate => {
|
||||
const template: ConditioningInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'conditioning',
|
||||
inputRequirement: 'always',
|
||||
inputKind: 'connection',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildEnumInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@@ -214,6 +232,36 @@ const buildEnumInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildArrayInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ArrayInputFieldTemplate => {
|
||||
const template: ArrayInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'array',
|
||||
inputRequirement: 'always',
|
||||
inputKind: 'direct',
|
||||
default: [],
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildItemInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ItemInputFieldTemplate => {
|
||||
const template: ItemInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'item',
|
||||
inputRequirement: 'always',
|
||||
inputKind: 'direct',
|
||||
default: undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
export const getFieldType = (
|
||||
schemaObject: OpenAPIV3.SchemaObject,
|
||||
name: string,
|
||||
@@ -266,6 +314,9 @@ export const buildInputFieldTemplate = (
|
||||
if (['latents'].includes(fieldType)) {
|
||||
return buildLatentsInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['conditioning'].includes(fieldType)) {
|
||||
return buildConditioningInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['model'].includes(fieldType)) {
|
||||
return buildModelInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
@@ -284,6 +335,12 @@ export const buildInputFieldTemplate = (
|
||||
if (['boolean'].includes(fieldType)) {
|
||||
return buildBooleanInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['array'].includes(fieldType)) {
|
||||
return buildArrayInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['item'].includes(fieldType)) {
|
||||
return buildItemInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
|
||||
return;
|
||||
};
|
||||
|
||||
@@ -48,6 +48,10 @@ export const buildInputFieldValue = (
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'conditioning') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'model') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ export const buildIterateNode = (): IterateInvocation => {
|
||||
return {
|
||||
id: nodeId,
|
||||
type: 'iterate',
|
||||
collection: [],
|
||||
index: 0,
|
||||
// collection: [],
|
||||
// index: 0,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -13,7 +13,7 @@ import {
  buildOutputFieldTemplates,
} from './fieldTemplateBuilders';

const invocationDenylist = ['Graph', 'Collect', 'LoadImage'];
const invocationDenylist = ['Graph', 'LoadImage'];

export const parseSchema = (openAPI: OpenAPIV3.Document) => {
  // filter out non-invocation schemas, plus some tricky invocations for now
@@ -32,49 +32,62 @@ export const parseSchema = (openAPI: OpenAPIV3.Document) => {
    if (isInvocationSchemaObject(schema)) {
      const type = schema.properties.type.default;

      const title =
        schema.ui?.title ??
        schema.title
          .replace('Invocation', '')
          .split(/(?=[A-Z])/) // split PascalCase into array
          .join(' ');
      const title = schema.ui?.title ?? schema.title.replace('Invocation', '');

      const typeHints = schema.ui?.type_hints;

      const inputs = reduce(
        schema.properties,
        (inputsAccumulator, property, propertyName) => {
          if (
            // `type` and `id` are not valid inputs/outputs
            !['type', 'id'].includes(propertyName) &&
            isSchemaObject(property)
          ) {
            let field: InputFieldTemplate | undefined;
            if (propertyName === 'collection') {
              field = {
                default: property.default ?? [],
                name: 'collection',
                title: property.title ?? '',
                description: property.description ?? '',
                type: 'array',
                inputRequirement: 'always',
                inputKind: 'connection',
              };
            } else {
              field = buildInputFieldTemplate(
                property,
                propertyName,
                typeHints
              );
      const inputs: Record<string, InputFieldTemplate> = {};

      if (type === 'collect') {
        const itemProperty = schema.properties[
          'item'
        ] as InvocationSchemaObject;
        // Handle the special Collect node
        inputs.item = {
          type: 'item',
          name: 'item',
          description: itemProperty.description ?? '',
          title: 'Collection Item',
          inputKind: 'connection',
          inputRequirement: 'always',
          default: undefined,
        };
      } else if (type === 'iterate') {
        const itemProperty = schema.properties[
          'collection'
        ] as InvocationSchemaObject;

        inputs.collection = {
          type: 'array',
          name: 'collection',
          title: itemProperty.title ?? '',
          default: [],
          description: itemProperty.description ?? '',
          inputRequirement: 'always',
          inputKind: 'connection',
        };
      } else {
        // All other nodes
        reduce(
          schema.properties,
          (inputsAccumulator, property, propertyName) => {
            if (
              // `type` and `id` are not valid inputs/outputs
              !['type', 'id'].includes(propertyName) &&
              isSchemaObject(property)
            ) {
              const field: InputFieldTemplate | undefined =
                buildInputFieldTemplate(property, propertyName, typeHints);

              if (field) {
                inputsAccumulator[propertyName] = field;
              }
            }
            if (field) {
              inputsAccumulator[propertyName] = field;
            }
          }
          return inputsAccumulator;
        },
        {} as Record<string, InputFieldTemplate>
      );
            return inputsAccumulator;
          },
          inputs
        );
      }

      const rawOutput = (schema as InvocationSchemaObject).output;
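The parseSchema hunk above stops deriving inputs for the Collect and Iterate invocations from their OpenAPI properties and instead hand-builds their connection-only 'item' and 'collection' fields. The standalone sketch below illustrates the wiring these fields enable (many producers feeding one Collect node, whose array then drives an Iterate node); the node ids, field names, and Edge shape are hypothetical, not the project's actual graph types.

  // Illustrative only: why Collect/Iterate need hand-built, connection-only fields.
  type Edge = { from: string; fromField: string; to: string; toField: string };

  const edges: Edge[] = [
    // several producers feed the same connection-only `item` input of a Collect node
    { from: 'noise_a', fromField: 'noise', to: 'collect_1', toField: 'item' },
    { from: 'noise_b', fromField: 'noise', to: 'collect_1', toField: 'item' },
    // the collected array is then fanned back out by an Iterate node downstream
    { from: 'collect_1', fromField: 'collection', to: 'iterate_1', toField: 'collection' },
  ];

  console.log(`${edges.filter((e) => e.toField === 'item').length} values collected`);
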
@@ -13,17 +13,35 @@ const selector = createSelector(
  (generation, hotkeys, config) => {
    const { initial, min, sliderMax, inputMax, fineStep, coarseStep } =
      config.sd.height;
    const { height } = generation;
    const { height, shouldFitToWidthHeight, isImageToImageEnabled } =
      generation;

    const step = hotkeys.shift ? fineStep : coarseStep;

    return { height, initial, min, sliderMax, inputMax, step };
    return {
      height,
      initial,
      min,
      sliderMax,
      inputMax,
      step,
      shouldFitToWidthHeight,
      isImageToImageEnabled,
    };
  }
);

const HeightSlider = () => {
  const { height, initial, min, sliderMax, inputMax, step } =
    useAppSelector(selector);
  const {
    height,
    initial,
    min,
    sliderMax,
    inputMax,
    step,
    shouldFitToWidthHeight,
    isImageToImageEnabled,
  } = useAppSelector(selector);
  const dispatch = useAppDispatch();
  const { t } = useTranslation();

@@ -40,6 +58,7 @@ const HeightSlider = () => {

  return (
    <IAISlider
      isDisabled={!shouldFitToWidthHeight && isImageToImageEnabled}
      label={t('parameters.height')}
      value={height}
      min={min}
@@ -13,17 +13,34 @@ const selector = createSelector(
  (generation, hotkeys, config) => {
    const { initial, min, sliderMax, inputMax, fineStep, coarseStep } =
      config.sd.width;
    const { width } = generation;
    const { width, shouldFitToWidthHeight, isImageToImageEnabled } = generation;

    const step = hotkeys.shift ? fineStep : coarseStep;

    return { width, initial, min, sliderMax, inputMax, step };
    return {
      width,
      initial,
      min,
      sliderMax,
      inputMax,
      step,
      shouldFitToWidthHeight,
      isImageToImageEnabled,
    };
  }
);

const WidthSlider = () => {
  const { width, initial, min, sliderMax, inputMax, step } =
    useAppSelector(selector);
  const {
    width,
    initial,
    min,
    sliderMax,
    inputMax,
    step,
    shouldFitToWidthHeight,
    isImageToImageEnabled,
  } = useAppSelector(selector);
  const dispatch = useAppDispatch();
  const { t } = useTranslation();

@@ -40,6 +57,7 @@ const WidthSlider = () => {

  return (
    <IAISlider
      isDisabled={!shouldFitToWidthHeight && isImageToImageEnabled}
      label={t('parameters.width')}
      value={width}
      min={min}
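Both slider hunks above add the same isDisabled expression. A standalone restatement of that condition follows; the helper name is invented here, and the stated intent (the size inputs are locked while image-to-image is enabled and the output is not being fit to an explicit width/height) is inferred from the parameter names rather than spelled out in the diff.

  // Sketch of the disable rule added to both sliders (assumed semantics).
  interface SizeGate {
    isImageToImageEnabled: boolean;
    shouldFitToWidthHeight: boolean;
  }

  const isSizeSliderDisabled = ({ isImageToImageEnabled, shouldFitToWidthHeight }: SizeGate): boolean =>
    !shouldFitToWidthHeight && isImageToImageEnabled;

  console.log(isSizeSliderDisabled({ isImageToImageEnabled: true, shouldFitToWidthHeight: false })); // true
  console.log(isSizeSliderDisabled({ isImageToImageEnabled: false, shouldFitToWidthHeight: false })); // false
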
@@ -89,13 +89,6 @@ const ProgressImagePreview = () => {
  onResizeStop={(e, direction, ref, delta, position) => {
    const newRect: Partial<Rect> = {};

    console.log(
      ref.style.width,
      ref.style.height,
      position.x,
      position.y
    );

    if (ref.style.width) {
      newRect.width = ref.style.width;
    }
@@ -107,7 +107,7 @@ const initialSystemState: SystemState = {
  subscribedNodeIds: [],
  wereModelsReceived: false,
  wasSchemaParsed: false,
  consoleLogLevel: 'error',
  consoleLogLevel: 'debug',
  shouldLogToConsole: true,
  statusTranslationKey: 'common.statusDisconnected',
  canceledSession: '',
invokeai/frontend/web/src/theme/css/overlayscrollbars.css (new file, 48 lines)
@@ -0,0 +1,48 @@
.os-scrollbar {
  /* The size of the scrollbar */
  /* --os-size: 0; */
  /* The axis-perpendicular padding of the scrollbar (horizontal: padding-y, vertical: padding-x) */
  /* --os-padding-perpendicular: 0; */
  /* The axis padding of the scrollbar (horizontal: padding-x, vertical: padding-y) */
  /* --os-padding-axis: 0; */
  /* The border radius of the scrollbar track */
  /* --os-track-border-radius: 0; */
  /* The background of the scrollbar track */
  --os-track-bg: rgba(0, 0, 0, 0.3);
  /* The :hover background of the scrollbar track */
  --os-track-bg-hover: rgba(0, 0, 0, 0.3);
  /* The :active background of the scrollbar track */
  --os-track-bg-active: rgba(0, 0, 0, 0.3);
  /* The border of the scrollbar track */
  /* --os-track-border: none; */
  /* The :hover border of the scrollbar track */
  /* --os-track-border-hover: none; */
  /* The :active border of the scrollbar track */
  /* --os-track-border-active: none; */
  /* The border radius of the scrollbar handle */
  /* --os-handle-border-radius: 0; */
  /* The background of the scrollbar handle */
  --os-handle-bg: var(--invokeai-colors-accent-500);
  /* The :hover background of the scrollbar handle */
  --os-handle-bg-hover: var(--invokeai-colors-accent-450);
  /* The :active background of the scrollbar handle */
  --os-handle-bg-active: var(--invokeai-colors-accent-400);
  /* The border of the scrollbar handle */
  /* --os-handle-border: none; */
  /* The :hover border of the scrollbar handle */
  /* --os-handle-border-hover: none; */
  /* The :active border of the scrollbar handle */
  /* --os-handle-border-active: none; */
  /* The min size of the scrollbar handle */
  --os-handle-min-size: 50px;
  /* The max size of the scrollbar handle */
  /* --os-handle-max-size: none; */
  /* The axis-perpendicular size of the scrollbar handle (horizontal: height, vertical: width) */
  /* --os-handle-perpendicular-size: 100%; */
  /* The :hover axis-perpendicular size of the scrollbar handle (horizontal: height, vertical: width) */
  /* --os-handle-perpendicular-size-hover: 100%; */
  /* The :active axis-perpendicular size of the scrollbar handle (horizontal: height, vertical: width) */
  /* --os-handle-perpendicular-size-active: 100%; */
  /* Increases the interactive area of the scrollbar handle. */
  /* --os-handle-interactive-area-offset: 0; */
}
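The stylesheet above only themes the scrollbars; it takes effect wherever an OverlayScrollbars instance wraps a scrollable region. A minimal usage sketch with overlayscrollbars-react 0.5 follows; the relative import of the project stylesheet is illustrative, and the options shown are example defaults rather than the values used in this branch.

  import type { ReactNode } from 'react';
  import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
  import 'overlayscrollbars/overlayscrollbars.css'; // library base styles
  import './overlayscrollbars.css'; // the theme overrides added above (path is illustrative)

  // Any content rendered inside picks up the custom-themed scrollbars.
  export const ScrollableRegion = ({ children }: { children: ReactNode }) => (
    <OverlayScrollbarsComponent defer options={{ scrollbars: { autoHide: 'scroll' } }}>
      {children}
    </OverlayScrollbarsComponent>
  );
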
@@ -5100,6 +5100,16 @@ os-tmpdir@~1.0.2:
  resolved "https://registry.yarnpkg.com/os-tmpdir/-/os-tmpdir-1.0.2.tgz#bbe67406c79aa85c5cfec766fe5734555dfa1274"
  integrity sha512-D2FR03Vir7FIu45XBY20mTb+/ZSWB00sjU9jdQXt83gDrI4Ztz5Fs7/yy74g2N5SVQY4xY1qDr4rNddwYRVX0g==

overlayscrollbars-react@^0.5.0:
  version "0.5.0"
  resolved "https://registry.yarnpkg.com/overlayscrollbars-react/-/overlayscrollbars-react-0.5.0.tgz#0272bdc6304c7228a58d30e5b678e97fd5c5d8dd"
  integrity sha512-uCNTnkfWW74veoiEv3kSwoLelKt4e8gTNv65D771X3il0x5g5Yo0fUbro7SpQzR9yNgi23cvB2mQHTTdQH96pA==

overlayscrollbars@^2.1.1:
  version "2.1.1"
  resolved "https://registry.yarnpkg.com/overlayscrollbars/-/overlayscrollbars-2.1.1.tgz#a7414fe9c96cf140dbe4975bbe9312861750388d"
  integrity sha512-xvs2g8Tcq9+CZDpLEUchN3YUzjJhnTWw9kwqT/qcC53FIkOyP9mqnRMot5sW16tcsPT1KaMyzF0AMXw/7E4a8g==

p-cancelable@^1.0.0:
  version "1.1.0"
  resolved "https://registry.yarnpkg.com/p-cancelable/-/p-cancelable-1.1.0.tgz#d078d15a3af409220c886f1d9a0ca2e441ab26cc"
@@ -5612,6 +5622,11 @@ react-use@^17.4.0:
  ts-easing "^0.2.0"
  tslib "^2.1.0"

react-virtuoso@^4.3.5:
  version "4.3.5"
  resolved "https://registry.yarnpkg.com/react-virtuoso/-/react-virtuoso-4.3.5.tgz#1e882d435b2d3d8abf7c4b85235199cbfadd935d"
  integrity sha512-MdWzmM9d8Gt5YGPIgGzRoqnYygTsriWlZrq+SqxphJTiiHs9cffnjf2Beo3SA3wRYzQJD8FI2HXtN5ACWzPFbQ==

react-zoom-pan-pinch@^3.0.7:
  version "3.0.7"
  resolved "https://registry.yarnpkg.com/react-zoom-pan-pinch/-/react-zoom-pan-pinch-3.0.7.tgz#def52f6886bc11e1b160dedf4250aae95470b94d"
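The lockfile hunks above record the newly added overlayscrollbars, overlayscrollbars-react, react-virtuoso, and react-zoom-pan-pinch packages. As a rough illustration of what react-virtuoso is for (windowed rendering of long lists), a small hypothetical example follows; the data and item renderer are placeholders, not code from this branch.

  import { Virtuoso } from 'react-virtuoso';

  // Placeholder data; a real gallery would read this from application state.
  const imageNames = Array.from({ length: 10_000 }, (_, i) => `image_${i}.png`);

  // Only the rows currently in view are mounted, keeping very long lists cheap.
  export const ImageNameList = () => (
    <Virtuoso
      style={{ height: 400 }}
      totalCount={imageNames.length}
      itemContent={(index) => <div>{imageNames[index]}</div>}
    />
  );
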
@@ -25,6 +25,7 @@ def mock_services():
    return InvocationServices(
        model_manager = None, # type: ignore
        events = None, # type: ignore
        logger = None, # type: ignore
        images = None, # type: ignore
        latents = None, # type: ignore
        metadata = None, # type: ignore
@@ -8,6 +8,7 @@ from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, LibraryGraph, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
import pytest


@pytest.fixture
def simple_graph():
    g = Graph()
@@ -22,6 +23,7 @@ def mock_services() -> InvocationServices:
    return InvocationServices(
        model_manager = None, # type: ignore
        events = TestEventService(),
        logger = None, # type: ignore
        images = None, # type: ignore
        latents = None, # type: ignore
        metadata = None, # type: ignore
@@ -316,7 +316,7 @@ def test_graph_iterator_invalid_if_multiple_inputs():

def test_graph_iterator_invalid_if_input_not_list():
    g = Graph()
    n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
    n1 = TextToImageInvocation(id = "1", promopt = "Banana sushi")
    n2 = IterateInvocation(id = "2")
    g.add_node(n1)
    g.add_node(n2)
@@ -463,16 +463,16 @@ def test_graph_subgraph_t2i():

    n4 = ShowImageInvocation(id = "4")
    g.add_node(n4)
    g.add_edge(create_edge("1.5","image","4","image"))
    g.add_edge(create_edge("1.7","image","4","image"))

    # Validate
    dg = g.nx_graph_flat()
    assert set(dg.nodes) == set(['1.width', '1.height', '1.3', '1.4', '1.5', '2', '3', '4'])
    assert set(dg.nodes) == set(['1.width', '1.height', '1.seed', '1.3', '1.4', '1.5', '1.6', '1.7', '2', '3', '4'])
    expected_edges = [(f'1.{e.source.node_id}',f'1.{e.destination.node_id}') for e in lg.graph.edges]
    expected_edges.extend([
        ('2','1.width'),
        ('3','1.height'),
        ('1.5','4')
        ('1.7','4')
    ])
    print(expected_edges)
    print(list(dg.edges))
@@ -1,99 +0,0 @@
import os
import pytest

from omegaconf import OmegaConf
from pathlib import Path

os.environ['INVOKEAI_ROOT']='/tmp'
from invokeai.app.services.config import InvokeAIAppConfig, InvokeAISettings
from invokeai.app.invocations.generate import TextToImageInvocation

init1 = OmegaConf.create(
'''
globals:
  nsfw_checker: False
  max_loaded_models: 5

history:
  count: 100

txt2img:
  steps: 18
  scheduler: k_heun
  width: 768

img2img:
  width: 1024
  height: 1024
'''
)

init2 = OmegaConf.create(
'''
globals:
  nsfw_checker: True
  max_loaded_models: 2

history:
  count: 10
'''
)

def test_use_init():
    # note that we explicitly set omegaconf dict and argv here
    # so that the values aren't read from ~invokeai/invokeai.yaml and
    # sys.argv respectively.
    conf1 = InvokeAIAppConfig(init1,[])
    assert conf1
    assert conf1.max_loaded_models==5
    assert not conf1.nsfw_checker

    conf2 = InvokeAIAppConfig(init2,[])
    assert conf2
    assert conf2.nsfw_checker
    assert conf2.max_loaded_models==2
    assert not hasattr(conf2,'invalid_attribute')


def test_argv_override():
    conf = InvokeAIAppConfig(init1,['--nsfw_checker','--max_loaded=10'])
    assert conf.nsfw_checker
    assert conf.max_loaded_models==10
    assert conf.outdir==Path('outputs') # this is the default

def test_env_override():
    # argv overrides
    conf = InvokeAIAppConfig(conf=init1,argv=['--max_loaded=10'])
    assert conf.nsfw_checker==False

    os.environ['INVOKEAI_globals_nsfw_checker'] = 'True'
    conf = InvokeAIAppConfig(conf=init1,argv=['--max_loaded=10'])
    assert conf.nsfw_checker==True

    conf = InvokeAIAppConfig(conf=init1,argv=['--no-nsfw_checker','--max_loaded=10'])
    assert conf.nsfw_checker==False

    conf = InvokeAIAppConfig(conf=init1,argv=[],max_loaded_models=20)
    assert conf.max_loaded_models==20

    # have to comment this one out because of a race condition in setting same
    # environment variable in the CI test environment
    # assert conf.root==Path('/tmp')

def test_invocation():
    InvokeAISettings.initconf=init1
    invocation = TextToImageInvocation(id='foobar')
    assert invocation.steps==18
    assert invocation.scheduler=='k_heun'
    assert invocation.height==512 # default

    invocation = TextToImageInvocation(id='foobar2',steps=30)
    assert invocation.steps==30

def test_type_coercion():
    conf = InvokeAIAppConfig(argv=['--root=/tmp/foobar'])
    assert conf.root==Path('/tmp/foobar')
    assert isinstance(conf.root,Path)
    conf = InvokeAIAppConfig(argv=['--root=/tmp/foobar'],root='/tmp/different')
    assert conf.root==Path('/tmp/different')
    assert isinstance(conf.root,Path)