mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
remove all references to CLI
This commit is contained in:
committed by
psychedelicious
parent
9fa8e38163
commit
d27392cc2d
@@ -1,312 +0,0 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import argparse
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Iterable, Literal, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from ..invocations.baseinvocation import BaseInvocation
|
||||
from ..invocations.image import ImageField
|
||||
from ..services.graph import Edge, GraphExecutionState, LibraryGraph
|
||||
from ..services.invoker import Invoker
|
||||
|
||||
|
||||
def add_field_argument(command_parser, name: str, field, default_override=None):
|
||||
default = (
|
||||
default_override
|
||||
if default_override is not None
|
||||
else field.default
|
||||
if field.default_factory is None
|
||||
else field.default_factory()
|
||||
)
|
||||
if get_origin(field.annotation) == Literal:
|
||||
allowed_values = get_args(field.annotation)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
allowed_types_list = list(allowed_types)
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
||||
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field_type,
|
||||
default=default,
|
||||
choices=allowed_values,
|
||||
help=field.description,
|
||||
)
|
||||
else:
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.annotation,
|
||||
default=default,
|
||||
help=field.description,
|
||||
)
|
||||
|
||||
|
||||
def add_parsers(
|
||||
subparsers,
|
||||
commands: list[type],
|
||||
command_field: str = "type",
|
||||
exclude_fields: list[str] = ["id", "type"],
|
||||
add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None,
|
||||
):
|
||||
"""Adds parsers for each command to the subparsers"""
|
||||
|
||||
# Create subparsers for each command
|
||||
for command in commands:
|
||||
hints = get_type_hints(command)
|
||||
cmd_name = get_args(hints[command_field])[0]
|
||||
command_parser = subparsers.add_parser(cmd_name, help=command.__doc__)
|
||||
|
||||
if add_arguments is not None:
|
||||
add_arguments(command_parser)
|
||||
|
||||
# Convert all fields to arguments
|
||||
fields = command.__fields__ # type: ignore
|
||||
for name, field in fields.items():
|
||||
if name in exclude_fields:
|
||||
continue
|
||||
|
||||
add_field_argument(command_parser, name, field)
|
||||
|
||||
|
||||
def add_graph_parsers(
|
||||
subparsers, graphs: list[LibraryGraph], add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None
|
||||
):
|
||||
for graph in graphs:
|
||||
command_parser = subparsers.add_parser(graph.name, help=graph.description)
|
||||
|
||||
if add_arguments is not None:
|
||||
add_arguments(command_parser)
|
||||
|
||||
# Add arguments for inputs
|
||||
for exposed_input in graph.exposed_inputs:
|
||||
node = graph.graph.get_node(exposed_input.node_path)
|
||||
field = node.__fields__[exposed_input.field]
|
||||
default_override = getattr(node, exposed_input.field)
|
||||
add_field_argument(command_parser, exposed_input.alias, field, default_override)
|
||||
|
||||
|
||||
class CliContext:
|
||||
invoker: Invoker
|
||||
session: GraphExecutionState
|
||||
parser: argparse.ArgumentParser
|
||||
defaults: dict[str, Any]
|
||||
graph_nodes: dict[str, str]
|
||||
nodes_added: list[str]
|
||||
|
||||
def __init__(self, invoker: Invoker, session: GraphExecutionState, parser: argparse.ArgumentParser):
|
||||
self.invoker = invoker
|
||||
self.session = session
|
||||
self.parser = parser
|
||||
self.defaults = dict()
|
||||
self.graph_nodes = dict()
|
||||
self.nodes_added = list()
|
||||
|
||||
def get_session(self):
|
||||
self.session = self.invoker.services.graph_execution_manager.get(self.session.id)
|
||||
return self.session
|
||||
|
||||
def reset(self):
|
||||
self.session = self.invoker.create_execution_state()
|
||||
self.graph_nodes = dict()
|
||||
self.nodes_added = list()
|
||||
# Leave defaults unchanged
|
||||
|
||||
def add_node(self, node: BaseInvocation):
|
||||
self.get_session()
|
||||
self.session.graph.add_node(node)
|
||||
self.nodes_added.append(node.id)
|
||||
self.invoker.services.graph_execution_manager.set(self.session)
|
||||
|
||||
def add_edge(self, edge: Edge):
|
||||
self.get_session()
|
||||
self.session.add_edge(edge)
|
||||
self.invoker.services.graph_execution_manager.set(self.session)
|
||||
|
||||
|
||||
class ExitCli(Exception):
|
||||
"""Exception to exit the CLI"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class BaseCommand(ABC, BaseModel):
|
||||
"""A CLI command"""
|
||||
|
||||
# All commands must include a type name like this:
|
||||
|
||||
@classmethod
|
||||
def get_all_subclasses(cls):
|
||||
subclasses = []
|
||||
toprocess = [cls]
|
||||
while len(toprocess) > 0:
|
||||
next = toprocess.pop(0)
|
||||
next_subclasses = next.__subclasses__()
|
||||
subclasses.extend(next_subclasses)
|
||||
toprocess.extend(next_subclasses)
|
||||
return subclasses
|
||||
|
||||
@classmethod
|
||||
def get_commands(cls):
|
||||
return tuple(BaseCommand.get_all_subclasses())
|
||||
|
||||
@classmethod
|
||||
def get_commands_map(cls):
|
||||
# Get the type strings out of the literals and into a dictionary
|
||||
return dict(map(lambda t: (get_args(get_type_hints(t)["type"])[0], t), BaseCommand.get_all_subclasses()))
|
||||
|
||||
@abstractmethod
|
||||
def run(self, context: CliContext) -> None:
|
||||
"""Run the command. Raise ExitCli to exit."""
|
||||
pass
|
||||
|
||||
|
||||
class ExitCommand(BaseCommand):
|
||||
"""Exits the CLI"""
|
||||
|
||||
type: Literal["exit"] = "exit"
|
||||
|
||||
def run(self, context: CliContext) -> None:
|
||||
raise ExitCli()
|
||||
|
||||
|
||||
class HelpCommand(BaseCommand):
|
||||
"""Shows help"""
|
||||
|
||||
type: Literal["help"] = "help"
|
||||
|
||||
def run(self, context: CliContext) -> None:
|
||||
context.parser.print_help()
|
||||
|
||||
|
||||
def get_graph_execution_history(
|
||||
graph_execution_state: GraphExecutionState,
|
||||
) -> Iterable[str]:
|
||||
"""Gets the history of fully-executed invocations for a graph execution"""
|
||||
return (n for n in reversed(graph_execution_state.executed_history) if n in graph_execution_state.graph.nodes)
|
||||
|
||||
|
||||
def get_invocation_command(invocation) -> str:
|
||||
fields = invocation.__fields__.items()
|
||||
type_hints = get_type_hints(type(invocation))
|
||||
command = [invocation.type]
|
||||
for name, field in fields:
|
||||
if name in ["id", "type"]:
|
||||
continue
|
||||
|
||||
# TODO: add links
|
||||
|
||||
# Skip image fields when serializing command
|
||||
type_hint = type_hints.get(name) or None
|
||||
if type_hint is ImageField or ImageField in get_args(type_hint):
|
||||
continue
|
||||
|
||||
field_value = getattr(invocation, name)
|
||||
field_default = field.default
|
||||
if field_value != field_default:
|
||||
if type_hint is str or str in get_args(type_hint):
|
||||
command.append(f'--{name} "{field_value}"')
|
||||
else:
|
||||
command.append(f"--{name} {field_value}")
|
||||
|
||||
return " ".join(command)
|
||||
|
||||
|
||||
class HistoryCommand(BaseCommand):
|
||||
"""Shows the invocation history"""
|
||||
|
||||
type: Literal["history"] = "history"
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
count: int = Field(default=5, gt=0, description="The number of history entries to show")
|
||||
# fmt: on
|
||||
|
||||
def run(self, context: CliContext) -> None:
|
||||
history = list(get_graph_execution_history(context.get_session()))
|
||||
for i in range(min(self.count, len(history))):
|
||||
entry_id = history[-1 - i]
|
||||
entry = context.get_session().graph.get_node(entry_id)
|
||||
logger.info(f"{entry_id}: {get_invocation_command(entry)}")
|
||||
|
||||
|
||||
class SetDefaultCommand(BaseCommand):
|
||||
"""Sets a default value for a field"""
|
||||
|
||||
type: Literal["default"] = "default"
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
field: str = Field(description="The field to set the default for")
|
||||
value: str = Field(description="The value to set the default to, or None to clear the default")
|
||||
# fmt: on
|
||||
|
||||
def run(self, context: CliContext) -> None:
|
||||
if self.value is None:
|
||||
if self.field in context.defaults:
|
||||
del context.defaults[self.field]
|
||||
else:
|
||||
context.defaults[self.field] = self.value
|
||||
|
||||
|
||||
class DrawGraphCommand(BaseCommand):
|
||||
"""Debugs a graph"""
|
||||
|
||||
type: Literal["draw_graph"] = "draw_graph"
|
||||
|
||||
def run(self, context: CliContext) -> None:
|
||||
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
|
||||
nxgraph = session.graph.nx_graph_flat()
|
||||
|
||||
# Draw the networkx graph
|
||||
plt.figure(figsize=(20, 20))
|
||||
pos = nx.spectral_layout(nxgraph)
|
||||
nx.draw_networkx_nodes(nxgraph, pos, node_size=1000)
|
||||
nx.draw_networkx_edges(nxgraph, pos, width=2)
|
||||
nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
|
||||
plt.axis("off")
|
||||
plt.show()
|
||||
|
||||
|
||||
class DrawExecutionGraphCommand(BaseCommand):
|
||||
"""Debugs an execution graph"""
|
||||
|
||||
type: Literal["draw_xgraph"] = "draw_xgraph"
|
||||
|
||||
def run(self, context: CliContext) -> None:
|
||||
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
|
||||
nxgraph = session.execution_graph.nx_graph_flat()
|
||||
|
||||
# Draw the networkx graph
|
||||
plt.figure(figsize=(20, 20))
|
||||
pos = nx.spectral_layout(nxgraph)
|
||||
nx.draw_networkx_nodes(nxgraph, pos, node_size=1000)
|
||||
nx.draw_networkx_edges(nxgraph, pos, width=2)
|
||||
nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
|
||||
plt.axis("off")
|
||||
plt.show()
|
||||
|
||||
|
||||
class SortedHelpFormatter(argparse.HelpFormatter):
|
||||
def _iter_indented_subactions(self, action):
|
||||
try:
|
||||
get_subactions = action._get_subactions
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
self._indent()
|
||||
if isinstance(action, argparse._SubParsersAction):
|
||||
for subaction in sorted(get_subactions(), key=lambda x: x.dest):
|
||||
yield subaction
|
||||
else:
|
||||
for subaction in get_subactions():
|
||||
yield subaction
|
||||
self._dedent()
|
||||
@@ -1,171 +0,0 @@
|
||||
"""
|
||||
Readline helper functions for cli_app.py
|
||||
You may import the global singleton `completer` to get access to the
|
||||
completer object.
|
||||
"""
|
||||
import atexit
|
||||
import readline
|
||||
import shlex
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Literal, get_args, get_origin, get_type_hints
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from ...backend import ModelManager
|
||||
from ..invocations.baseinvocation import BaseInvocation
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from .commands import BaseCommand
|
||||
|
||||
# singleton object, class variable
|
||||
completer = None
|
||||
|
||||
|
||||
class Completer(object):
|
||||
def __init__(self, model_manager: ModelManager):
|
||||
self.commands = self.get_commands()
|
||||
self.matches = None
|
||||
self.linebuffer = None
|
||||
self.manager = model_manager
|
||||
return
|
||||
|
||||
def complete(self, text, state):
|
||||
"""
|
||||
Complete commands and switches fromm the node CLI command line.
|
||||
Switches are determined in a context-specific manner.
|
||||
"""
|
||||
|
||||
buffer = readline.get_line_buffer()
|
||||
if state == 0:
|
||||
options = None
|
||||
try:
|
||||
current_command, current_switch = self.get_current_command(buffer)
|
||||
options = self.get_command_options(current_command, current_switch)
|
||||
except IndexError:
|
||||
pass
|
||||
options = options or list(self.parse_commands().keys())
|
||||
|
||||
if not text: # first time
|
||||
self.matches = options
|
||||
else:
|
||||
self.matches = [s for s in options if s and s.startswith(text)]
|
||||
|
||||
try:
|
||||
match = self.matches[state]
|
||||
except IndexError:
|
||||
match = None
|
||||
return match
|
||||
|
||||
@classmethod
|
||||
def get_commands(self) -> List[object]:
|
||||
"""
|
||||
Return a list of all the client commands and invocations.
|
||||
"""
|
||||
return BaseCommand.get_commands() + BaseInvocation.get_invocations()
|
||||
|
||||
def get_current_command(self, buffer: str) -> tuple[str, str]:
|
||||
"""
|
||||
Parse the readline buffer to find the most recent command and its switch.
|
||||
"""
|
||||
if len(buffer) == 0:
|
||||
return None, None
|
||||
tokens = shlex.split(buffer)
|
||||
command = None
|
||||
switch = None
|
||||
for t in tokens:
|
||||
if t[0].isalpha():
|
||||
if switch is None:
|
||||
command = t
|
||||
else:
|
||||
switch = t
|
||||
# don't try to autocomplete switches that are already complete
|
||||
if switch and buffer.endswith(" "):
|
||||
switch = None
|
||||
return command or "", switch or ""
|
||||
|
||||
def parse_commands(self) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Return a dict in which the keys are the command name
|
||||
and the values are the parameters the command takes.
|
||||
"""
|
||||
result = dict()
|
||||
for command in self.commands:
|
||||
hints = get_type_hints(command)
|
||||
name = get_args(hints["type"])[0]
|
||||
result.update({name: hints})
|
||||
return result
|
||||
|
||||
def get_command_options(self, command: str, switch: str) -> List[str]:
|
||||
"""
|
||||
Return all the parameters that can be passed to the command as
|
||||
command-line switches. Returns None if the command is unrecognized.
|
||||
"""
|
||||
parsed_commands = self.parse_commands()
|
||||
if command not in parsed_commands:
|
||||
return None
|
||||
|
||||
# handle switches in the format "-foo=bar"
|
||||
argument = None
|
||||
if switch and "=" in switch:
|
||||
switch, argument = switch.split("=")
|
||||
|
||||
parameter = switch.strip("-")
|
||||
if parameter in parsed_commands[command]:
|
||||
if argument is None:
|
||||
return self.get_parameter_options(parameter, parsed_commands[command][parameter])
|
||||
else:
|
||||
return [
|
||||
f"--{parameter}={x}"
|
||||
for x in self.get_parameter_options(parameter, parsed_commands[command][parameter])
|
||||
]
|
||||
else:
|
||||
return [f"--{x}" for x in parsed_commands[command].keys()]
|
||||
|
||||
def get_parameter_options(self, parameter: str, typehint) -> List[str]:
|
||||
"""
|
||||
Given a parameter type (such as Literal), offers autocompletions.
|
||||
"""
|
||||
if get_origin(typehint) == Literal:
|
||||
return get_args(typehint)
|
||||
if parameter == "model":
|
||||
return self.manager.model_names()
|
||||
|
||||
def _pre_input_hook(self):
|
||||
if self.linebuffer:
|
||||
readline.insert_text(self.linebuffer)
|
||||
readline.redisplay()
|
||||
self.linebuffer = None
|
||||
|
||||
|
||||
def set_autocompleter(services: InvocationServices) -> Completer:
|
||||
global completer
|
||||
|
||||
if completer:
|
||||
return completer
|
||||
|
||||
completer = Completer(services.model_manager)
|
||||
|
||||
readline.set_completer(completer.complete)
|
||||
try:
|
||||
readline.set_auto_history(True)
|
||||
except AttributeError:
|
||||
# pyreadline3 does not have a set_auto_history() method
|
||||
pass
|
||||
readline.set_pre_input_hook(completer._pre_input_hook)
|
||||
readline.set_completer_delims(" ")
|
||||
readline.parse_and_bind("tab: complete")
|
||||
readline.parse_and_bind("set print-completions-horizontally off")
|
||||
readline.parse_and_bind("set page-completions on")
|
||||
readline.parse_and_bind("set skip-completed-text on")
|
||||
readline.parse_and_bind("set show-all-if-ambiguous on")
|
||||
|
||||
histfile = Path(services.configuration.root_dir / ".invoke_history")
|
||||
try:
|
||||
readline.read_history_file(histfile)
|
||||
readline.set_history_length(1000)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except OSError: # file likely corrupted
|
||||
newname = f"{histfile}.old"
|
||||
logger.error(f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}")
|
||||
histfile.replace(Path(newname))
|
||||
atexit.register(readline.write_history_file, histfile)
|
||||
@@ -1,484 +0,0 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
||||
|
||||
from .services.config import InvokeAIAppConfig
|
||||
|
||||
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
|
||||
# which are imported/used before parse_args() is called will get the default config values instead of the
|
||||
# values from the command line or config file.
|
||||
|
||||
if True: # hack to make flake8 happy with imports coming after setting up the config
|
||||
import argparse
|
||||
import re
|
||||
import shlex
|
||||
import sqlite3
|
||||
import sys
|
||||
import time
|
||||
from typing import Optional, Union, get_type_hints
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic.fields import Field
|
||||
|
||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
from invokeai.app.services.board_image_record_storage import SqliteBoardImageRecordStorage
|
||||
from invokeai.app.services.board_images import BoardImagesService, BoardImagesServiceDependencies
|
||||
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
|
||||
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||
from invokeai.app.services.invocation_stats import InvocationStatsService
|
||||
from invokeai.app.services.resource_name import SimpleNameService
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, SortedHelpFormatter, add_graph_parsers, add_parsers
|
||||
from .cli.completer import set_autocompleter
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.default_graphs import create_system_graphs, default_text_to_image_graph_id
|
||||
from .services.events import EventServiceBase
|
||||
from .services.graph import (
|
||||
Edge,
|
||||
EdgeConnection,
|
||||
GraphExecutionState,
|
||||
GraphInvocation,
|
||||
LibraryGraph,
|
||||
are_connection_types_compatible,
|
||||
)
|
||||
from .services.image_file_storage import DiskImageFileStorage
|
||||
from .services.invocation_queue import MemoryInvocationQueue
|
||||
from .services.invocation_services import InvocationServices
|
||||
from .services.invoker import Invoker
|
||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
from .services.model_manager_service import ModelManagerService
|
||||
from .services.processor import DefaultInvocationProcessor
|
||||
from .services.sqlite import SqliteItemStorage
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args()
|
||||
logger = InvokeAILogger().get_logger(config=config)
|
||||
|
||||
|
||||
class CliCommand(BaseModel):
|
||||
command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
|
||||
|
||||
|
||||
class InvalidArgs(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def add_invocation_args(command_parser):
|
||||
# Add linking capability
|
||||
command_parser.add_argument(
|
||||
"--link",
|
||||
"-l",
|
||||
action="append",
|
||||
nargs=3,
|
||||
help="A link in the format 'source_node source_field dest_field'. source_node can be relative to history (e.g. -1)",
|
||||
)
|
||||
|
||||
command_parser.add_argument(
|
||||
"--link_node",
|
||||
"-ln",
|
||||
action="append",
|
||||
help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)",
|
||||
)
|
||||
|
||||
|
||||
def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
|
||||
# Create invocation parser
|
||||
parser = argparse.ArgumentParser(formatter_class=SortedHelpFormatter)
|
||||
|
||||
def exit(*args, **kwargs):
|
||||
raise InvalidArgs
|
||||
|
||||
parser.exit = exit
|
||||
subparsers = parser.add_subparsers(dest="type")
|
||||
|
||||
# Create subparsers for each invocation
|
||||
invocations = BaseInvocation.get_all_subclasses()
|
||||
add_parsers(subparsers, invocations, add_arguments=add_invocation_args)
|
||||
|
||||
# Create subparsers for each command
|
||||
commands = BaseCommand.get_all_subclasses()
|
||||
add_parsers(subparsers, commands, exclude_fields=["type"])
|
||||
|
||||
# Create subparsers for exposed CLI graphs
|
||||
# TODO: add a way to identify these graphs
|
||||
text_to_image = services.graph_library.get(default_text_to_image_graph_id)
|
||||
add_graph_parsers(subparsers, [text_to_image], add_arguments=add_invocation_args)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class NodeField:
|
||||
alias: str
|
||||
node_path: str
|
||||
field: str
|
||||
field_type: type
|
||||
|
||||
def __init__(self, alias: str, node_path: str, field: str, field_type: type):
|
||||
self.alias = alias
|
||||
self.node_path = node_path
|
||||
self.field = field
|
||||
self.field_type = field_type
|
||||
|
||||
|
||||
def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str, NodeField]:
|
||||
return {k: NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()}
|
||||
|
||||
|
||||
def get_node_input_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
|
||||
"""Gets the node field for the specified field alias"""
|
||||
exposed_input = next(e for e in graph.exposed_inputs if e.alias == field_alias)
|
||||
node_type = type(graph.graph.get_node(exposed_input.node_path))
|
||||
return NodeField(
|
||||
alias=exposed_input.alias,
|
||||
node_path=f"{node_id}.{exposed_input.node_path}",
|
||||
field=exposed_input.field,
|
||||
field_type=get_type_hints(node_type)[exposed_input.field],
|
||||
)
|
||||
|
||||
|
||||
def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
|
||||
"""Gets the node field for the specified field alias"""
|
||||
exposed_output = next(e for e in graph.exposed_outputs if e.alias == field_alias)
|
||||
node_type = type(graph.graph.get_node(exposed_output.node_path))
|
||||
node_output_type = node_type.get_output_type()
|
||||
return NodeField(
|
||||
alias=exposed_output.alias,
|
||||
node_path=f"{node_id}.{exposed_output.node_path}",
|
||||
field=exposed_output.field,
|
||||
field_type=get_type_hints(node_output_type)[exposed_output.field],
|
||||
)
|
||||
|
||||
|
||||
def get_node_inputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
|
||||
"""Gets the inputs for the specified invocation from the context"""
|
||||
node_type = type(invocation)
|
||||
if node_type is not GraphInvocation:
|
||||
return fields_from_type_hints(get_type_hints(node_type), invocation.id)
|
||||
else:
|
||||
graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
|
||||
return {e.alias: get_node_input_field(graph, e.alias, invocation.id) for e in graph.exposed_inputs}
|
||||
|
||||
|
||||
def get_node_outputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
|
||||
"""Gets the outputs for the specified invocation from the context"""
|
||||
node_type = type(invocation)
|
||||
if node_type is not GraphInvocation:
|
||||
return fields_from_type_hints(get_type_hints(node_type.get_output_type()), invocation.id)
|
||||
else:
|
||||
graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
|
||||
return {e.alias: get_node_output_field(graph, e.alias, invocation.id) for e in graph.exposed_outputs}
|
||||
|
||||
|
||||
def generate_matching_edges(a: BaseInvocation, b: BaseInvocation, context: CliContext) -> list[Edge]:
|
||||
"""Generates all possible edges between two invocations"""
|
||||
afields = get_node_outputs(a, context)
|
||||
bfields = get_node_inputs(b, context)
|
||||
|
||||
matching_fields = set(afields.keys()).intersection(bfields.keys())
|
||||
|
||||
# Remove invalid fields
|
||||
invalid_fields = set(["type", "id"])
|
||||
matching_fields = matching_fields.difference(invalid_fields)
|
||||
|
||||
# Validate types
|
||||
matching_fields = [
|
||||
f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)
|
||||
]
|
||||
|
||||
edges = [
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=afields[alias].node_path, field=afields[alias].field),
|
||||
destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field),
|
||||
)
|
||||
for alias in matching_fields
|
||||
]
|
||||
return edges
|
||||
|
||||
|
||||
class SessionError(Exception):
|
||||
"""Raised when a session error has occurred"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def invoke_all(context: CliContext):
|
||||
"""Runs all invocations in the specified session"""
|
||||
context.invoker.invoke(context.session, invoke_all=True)
|
||||
while not context.get_session().is_complete():
|
||||
# Wait some time
|
||||
time.sleep(0.1)
|
||||
|
||||
# Print any errors
|
||||
if context.session.has_error():
|
||||
for n in context.session.errors:
|
||||
context.invoker.services.logger.error(
|
||||
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
|
||||
)
|
||||
|
||||
raise SessionError()
|
||||
|
||||
|
||||
def invoke_cli():
|
||||
logger.info(f"InvokeAI version {__version__}")
|
||||
# get the optional list of invocations to execute on the command line
|
||||
parser = config.get_parser()
|
||||
parser.add_argument("commands", nargs="*")
|
||||
invocation_commands = parser.parse_args().commands
|
||||
|
||||
# get the optional file to read commands from.
|
||||
# Simplest is to use it for STDIN
|
||||
if infile := config.from_file:
|
||||
sys.stdin = open(infile, "r")
|
||||
|
||||
model_manager = ModelManagerService(config, logger)
|
||||
|
||||
events = EventServiceBase()
|
||||
output_folder = config.output_path
|
||||
|
||||
# TODO: build a file/path manager?
|
||||
if config.use_memory_db:
|
||||
db_location = ":memory:"
|
||||
else:
|
||||
db_location = config.db_path
|
||||
db_location.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution
|
||||
logger.info(f'InvokeAI database location is "{db_location}"')
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
|
||||
|
||||
urls = LocalUrlService()
|
||||
image_record_storage = SqliteImageRecordStorage(conn=db_conn)
|
||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||
names = SimpleNameService()
|
||||
|
||||
board_record_storage = SqliteBoardRecordStorage(conn=db_conn)
|
||||
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn)
|
||||
|
||||
boards = BoardService(
|
||||
services=BoardServiceDependencies(
|
||||
board_image_record_storage=board_image_record_storage,
|
||||
board_record_storage=board_record_storage,
|
||||
image_record_storage=image_record_storage,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
)
|
||||
)
|
||||
|
||||
board_images = BoardImagesService(
|
||||
services=BoardImagesServiceDependencies(
|
||||
board_image_record_storage=board_image_record_storage,
|
||||
board_record_storage=board_record_storage,
|
||||
image_record_storage=image_record_storage,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
)
|
||||
)
|
||||
|
||||
images = ImageService(
|
||||
services=ImageServiceDependencies(
|
||||
board_image_record_storage=board_image_record_storage,
|
||||
image_record_storage=image_record_storage,
|
||||
image_file_storage=image_file_storage,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
names=names,
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
)
|
||||
)
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=model_manager,
|
||||
events=events,
|
||||
latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
|
||||
images=images,
|
||||
boards=boards,
|
||||
board_images=board_images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||
logger=logger,
|
||||
configuration=config,
|
||||
invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size),
|
||||
)
|
||||
|
||||
system_graphs = create_system_graphs(services.graph_library)
|
||||
system_graph_names = set([g.name for g in system_graphs])
|
||||
set_autocompleter(services)
|
||||
|
||||
invoker = Invoker(services)
|
||||
session: GraphExecutionState = invoker.create_execution_state()
|
||||
parser = get_command_parser(services)
|
||||
|
||||
re_negid = re.compile("^-[0-9]+$")
|
||||
|
||||
# Uncomment to print out previous sessions at startup
|
||||
# print(services.session_manager.list())
|
||||
|
||||
context = CliContext(invoker, session, parser)
|
||||
set_autocompleter(services)
|
||||
|
||||
command_line_args_exist = len(invocation_commands) > 0
|
||||
done = False
|
||||
|
||||
while not done:
|
||||
try:
|
||||
if command_line_args_exist:
|
||||
cmd_input = invocation_commands.pop(0)
|
||||
done = len(invocation_commands) == 0
|
||||
else:
|
||||
cmd_input = input("invoke> ")
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
# Ctrl-c exits
|
||||
break
|
||||
|
||||
try:
|
||||
# Refresh the state of the session
|
||||
# history = list(get_graph_execution_history(context.session))
|
||||
history = list(reversed(context.nodes_added))
|
||||
|
||||
# Split the command for piping
|
||||
cmds = cmd_input.split("|")
|
||||
start_id = len(context.nodes_added)
|
||||
current_id = start_id
|
||||
new_invocations = list()
|
||||
for cmd in cmds:
|
||||
if cmd is None or cmd.strip() == "":
|
||||
raise InvalidArgs("Empty command")
|
||||
|
||||
# Parse args to create invocation
|
||||
args = vars(context.parser.parse_args(shlex.split(cmd.strip())))
|
||||
|
||||
# Override defaults
|
||||
for field_name, field_default in context.defaults.items():
|
||||
if field_name in args:
|
||||
args[field_name] = field_default
|
||||
|
||||
# Parse invocation
|
||||
command: CliCommand = None # type:ignore
|
||||
system_graph: Optional[LibraryGraph] = None
|
||||
if args["type"] in system_graph_names:
|
||||
system_graph = next(filter(lambda g: g.name == args["type"], system_graphs))
|
||||
invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))
|
||||
for exposed_input in system_graph.exposed_inputs:
|
||||
if exposed_input.alias in args:
|
||||
node = invocation.graph.get_node(exposed_input.node_path)
|
||||
field = exposed_input.field
|
||||
setattr(node, field, args[exposed_input.alias])
|
||||
command = CliCommand(command=invocation)
|
||||
context.graph_nodes[invocation.id] = system_graph.id
|
||||
else:
|
||||
args["id"] = current_id
|
||||
command = CliCommand(command=args)
|
||||
|
||||
if command is None:
|
||||
continue
|
||||
|
||||
# Run any CLI commands immediately
|
||||
if isinstance(command.command, BaseCommand):
|
||||
# Invoke all current nodes to preserve operation order
|
||||
invoke_all(context)
|
||||
|
||||
# Run the command
|
||||
command.command.run(context)
|
||||
continue
|
||||
|
||||
# TODO: handle linking with library graphs
|
||||
# Pipe previous command output (if there was a previous command)
|
||||
edges: list[Edge] = list()
|
||||
if len(history) > 0 or current_id != start_id:
|
||||
from_id = history[0] if current_id == start_id else str(current_id - 1)
|
||||
from_node = (
|
||||
next(filter(lambda n: n[0].id == from_id, new_invocations))[0]
|
||||
if current_id != start_id
|
||||
else context.session.graph.get_node(from_id)
|
||||
)
|
||||
matching_edges = generate_matching_edges(from_node, command.command, context)
|
||||
edges.extend(matching_edges)
|
||||
|
||||
# Parse provided links
|
||||
if "link_node" in args and args["link_node"]:
|
||||
for link in args["link_node"]:
|
||||
node_id = link
|
||||
if re_negid.match(node_id):
|
||||
node_id = str(current_id + int(node_id))
|
||||
|
||||
link_node = context.session.graph.get_node(node_id)
|
||||
matching_edges = generate_matching_edges(link_node, command.command, context)
|
||||
matching_destinations = [e.destination for e in matching_edges]
|
||||
edges = [e for e in edges if e.destination not in matching_destinations]
|
||||
edges.extend(matching_edges)
|
||||
|
||||
if "link" in args and args["link"]:
|
||||
for link in args["link"]:
|
||||
edges = [
|
||||
e
|
||||
for e in edges
|
||||
if e.destination.node_id != command.command.id or e.destination.field != link[2]
|
||||
]
|
||||
|
||||
node_id = link[0]
|
||||
if re_negid.match(node_id):
|
||||
node_id = str(current_id + int(node_id))
|
||||
|
||||
# TODO: handle missing input/output
|
||||
node_output = get_node_outputs(context.session.graph.get_node(node_id), context)[link[1]]
|
||||
node_input = get_node_inputs(command.command, context)[link[2]]
|
||||
|
||||
edges.append(
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=node_output.node_path, field=node_output.field),
|
||||
destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field),
|
||||
)
|
||||
)
|
||||
|
||||
new_invocations.append((command.command, edges))
|
||||
|
||||
current_id = current_id + 1
|
||||
|
||||
# Add the node to the session
|
||||
context.add_node(command.command)
|
||||
for edge in edges:
|
||||
print(edge)
|
||||
context.add_edge(edge)
|
||||
|
||||
# Execute all remaining nodes
|
||||
invoke_all(context)
|
||||
|
||||
except InvalidArgs:
|
||||
invoker.services.logger.warning('Invalid command, use "help" to list commands')
|
||||
continue
|
||||
|
||||
except ValidationError:
|
||||
invoker.services.logger.warning('Invalid command arguments, run "<command> --help" for summary')
|
||||
|
||||
except SessionError:
|
||||
# Start a new session
|
||||
invoker.services.logger.warning("Session error: creating a new session")
|
||||
context.reset()
|
||||
|
||||
except ExitCli:
|
||||
break
|
||||
|
||||
except SystemExit:
|
||||
continue
|
||||
|
||||
invoker.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if config.version:
|
||||
print(f"InvokeAI version {__version__}")
|
||||
else:
|
||||
invoke_cli()
|
||||
Reference in New Issue
Block a user