Compare commits

...

44 Commits

Author SHA1 Message Date
Alexander Eichhorn
7938d840b2 Merge branch 'external-models' into alibabacloud/dashscope
# Conflicts:
#	invokeai/backend/model_manager/starter_models.py
2026-04-14 23:08:06 +02:00
Alexander Eichhorn
450ba7b7e1 Merge branch 'main' into external-models 2026-04-14 20:56:08 +02:00
Alexander Eichhorn
c743106f66 Merge remote-tracking branch 'upstream/main' into external-models 2026-04-14 03:43:39 +02:00
Alexander Eichhorn
cd888654d5 Merge branch 'main' into external-models 2026-04-14 02:09:56 +02:00
Alexander Eichhorn
ec4b87b949 add missing parameter 2026-04-14 01:39:04 +02:00
Alexander Eichhorn
8f00759af0 Chore pnpm fix 2026-04-14 01:07:35 +02:00
Alexander Eichhorn
5c09c823a9 Merge remote-tracking branch 'upstream/main' into external-models 2026-04-14 00:58:09 +02:00
Alexander Eichhorn
ec90b2fbe9 Merge remote-tracking branch 'upstream/main' into external-models 2026-04-12 04:29:17 +02:00
Alexander Eichhorn
17157d7c60 Merge remote-tracking branch 'upstream/main' into external-models 2026-04-12 04:28:47 +02:00
Alexander Eichhorn
853c3ef915 Merge remote-tracking branch 'upstream/main' into external-models 2026-04-07 23:54:26 +02:00
Alexander Eichhorn
3e9e052d5d feat: full canvas workflow integration for external models
- Update buildExternalGraph test to include dimensions in mock params
2026-04-06 23:32:10 +02:00
Alexander Eichhorn
089e2db402 Chore typegen Linux seperator 2026-04-06 23:21:45 +02:00
Alexander Eichhorn
4cbd60b4a5 Merge remote-tracking branch 'upstream/main' into external-models 2026-04-06 23:20:43 +02:00
Alexander Eichhorn
c2016bcfb7 feat: full canvas workflow integration for external models
- Add missing aspect ratios (4:5, 5:4, 8:1, 4:1, 1:4, 1:8) to type
  system for external model support
- Sync canvas bbox when external model resolution preset is selected
- Use params preset dimensions in buildExternalGraph to prevent
  "unsupported aspect ratio" errors
- Lock all bbox controls (resize handles, aspect ratio select,
  width/height sliders, swap/optimal buttons) for external models
  with fixed dimension presets
- Disable denoise strength slider for external models (not applicable)
- Sync bbox aspect ratio changes back to paramsSlice for external models
- Initialize bbox dimensions when switching to an external model
2026-04-06 23:13:10 +02:00
Alexander Eichhorn
813a5e2c2e Chore typegen 2026-03-28 14:59:51 +01:00
Alexander Eichhorn
18315db7f0 Chore Ruff check & format 2026-03-28 14:50:57 +01:00
Alexander Eichhorn
edde0b4737 Merge branch 'main' into external-models 2026-03-28 14:47:39 +01:00
Alexander Eichhorn
27fc650f4f Merge branch 'main' into external-models 2026-03-23 20:23:40 +01:00
Alexander Eichhorn
a1eef791a1 feat: add Alibaba Cloud DashScope external image generation provider
Add AlibabaCloudProvider supporting Qwen Image and Wan model families
via the DashScope API. Includes sync (multimodal-generation) and async
(image-generation with task polling) request modes, five starter models
(Qwen Image 2.0 Pro, 2.0, Max, Wan 2.6 T2I, Qwen Image Edit Max),
config fields for API key and base URL, and frontend registration.
2026-03-20 10:01:49 +01:00
Alexander Eichhorn
d8d0ebc356 Remove unused external model fields and add provider-specific parameters
- Remove negative_prompt, steps, guidance, reference_image_weights,
  reference_image_modes from external model nodes (unused by any provider)
- Remove supports_negative_prompt, supports_steps, supports_guidance
  from ExternalModelCapabilities
- Add provider_options dict to ExternalGenerationRequest for
  provider-specific parameters
- Add OpenAI-specific fields: quality, background, input_fidelity
- Add Gemini-specific fields: temperature, thinking_level
- Add new OpenAI starter models: GPT Image 1.5, GPT Image 1 Mini,
  DALL-E 3, DALL-E 2
- Fix OpenAI provider to use output_format (GPT Image) vs
  response_format (DALL-E) and send model ID in requests
- Add fixed aspect ratio sizes for OpenAI models (bucketing)
- Add ExternalProviderRateLimitError with retry logic for 429 responses
- Add provider-specific UI components in ExternalSettingsAccordion
- Simplify ParamSteps/ParamGuidance by removing dead external overrides
- Update all backend and frontend tests
2026-03-20 08:17:16 +01:00
Alexander Eichhorn
8375f95ea9 feat: add resolution presets and imageConfig support for Gemini 3 models
Add combined resolution preset selector for external models that maps
aspect ratio + image size to fixed dimensions. Gemini 3 Pro and 3.1 Flash
now send imageConfig (aspectRatio + imageSize) via generationConfig instead
of text-based aspect ratio hints used by Gemini 2.5 Flash.

Backend: ExternalResolutionPreset model, resolution_presets capability field,
image_size on ExternalGenerationRequest, and Gemini provider imageConfig logic.

Frontend: ExternalSettingsAccordion with combo resolution select, dimension
slider disabling for fixed-size models, and panel schema constraint wiring
for Steps/Guidance/Seed controls.
2026-03-19 04:36:09 +01:00
Alexander Eichhorn
9e4d0bb191 fix: resolve TypeScript errors and move external provider config to api_keys.yaml
Add 'external', 'external_image_generator', and 'external_api' to Zod
enum schemas (zBaseModelType, zModelType, zModelFormat) to match the
generated OpenAPI types. Remove redundant union workarounds from
component prop types and Record definitions.

Fix type errors in ModelEdit (react-hook-form Control invariance),
parsing.tsx (model identifier narrowing), buildExternalGraph (edge
typing), and ModelSettings import/export buttons.

Move external_gemini_base_url and external_openai_base_url into
api_keys.yaml alongside the API keys so all external provider config
lives in one dedicated file, separate from invokeai.yaml.
2026-03-18 17:03:15 +01:00
CypherNaught-0x
20a400cee8 feat: update gemini image model limits 2026-03-17 14:49:01 +01:00
CypherNaught-0x
40f02aa6c4 feat: add gemini 3.1 flash image preview starter model 2026-03-17 14:43:09 +01:00
CypherNaught-0x
c3a482e80a docs: sync app config docstring order 2026-03-17 14:39:43 +01:00
CypherNaught-0x
257994f552 feat(ui): drive external panels from panel schema 2026-03-17 13:56:07 +01:00
CypherNaught-0x
bafce41856 feat: expose external panel schemas in model configs 2026-03-17 13:56:02 +01:00
CypherNaught-0x
757bd3d002 feat(ui): add provider-specific external generation nodes 2026-03-17 13:36:42 +01:00
CypherNaught-0x
519575e871 fix: sync configured external starter models on startup 2026-03-17 13:33:13 +01:00
Alexander Eichhorn
f39456e6f0 Merge branch 'main' into external-models 2026-03-12 03:49:54 +01:00
Alexander Eichhorn
689725c6e4 Merge branch 'main' into external-models 2026-03-07 03:11:21 +01:00
CypherNaught-0x
10729f40f2 chore: fix linter errors 2026-02-27 16:36:07 +01:00
CypherNaught-0x
362054120e docs: updated external model docs 2026-02-27 11:13:33 +01:00
CypherNaught-0x
b91a156a3d review: save api keys to a seperate file 2026-02-27 11:13:33 +01:00
CypherNaught-0x
c6b0d45c5f chore: fix linter warning 2026-02-27 11:12:23 +01:00
CypherNaught-0x
dc665e08ac review: added optional seed control for external models 2026-02-27 11:12:23 +01:00
CypherNaught-0x
0dd72837d3 review: implemented review comments 2026-02-27 11:12:23 +01:00
CypherNaught-0x
d5a6283f23 review: model descriptions 2026-02-27 11:12:22 +01:00
CypherNaught-0x
6fe1a6f1ac feat: show external mode name during install 2026-02-27 11:12:22 +01:00
CypherNaught-0x
5d34eab6f0 review: enable auto-install/remove fro external models 2026-02-27 11:12:22 +01:00
CypherNaught-0x
1b43769b95 chore: hide Reidentify button for external models 2026-02-27 11:12:22 +01:00
CypherNaught-0x
a9d3b4e17c fix: sorting lint error 2026-02-27 11:12:22 +01:00
CypherNaught-0x
74ecc461b9 feat: support reference images for external models 2026-02-27 11:12:21 +01:00
CypherNaught-0x
19650f6ada feat: initial external model support 2026-02-27 11:12:21 +01:00
131 changed files with 7610 additions and 394 deletions

View File

@@ -0,0 +1,129 @@
# External Provider Integration
This guide covers:
1. Adding a new **external model** (most common; existing provider).
2. Adding a brand-new **external provider** (adapter + config + UI wiring).
## 1) Add a New External Model (Existing Provider)
For provider-backed models (for example, OpenAI or Gemini), the source of truth is
`invokeai/backend/model_manager/starter_models.py`.
### Required model fields
Define a `StarterModel` with:
- `base=BaseModelType.External`
- `type=ModelType.ExternalImageGenerator`
- `format=ModelFormat.ExternalApi`
- `source="external://<provider_id>/<provider_model_id>"`
- `name`, `description`
- `capabilities=ExternalModelCapabilities(...)`
- optional `default_settings=ExternalApiModelDefaultSettings(...)`
Example:
```python
new_external_model = StarterModel(
name="Provider Model Name",
base=BaseModelType.External,
source="external://openai/my-model-id",
description=(
"Provider model (external API). "
"Requires a configured OpenAI API key and may incur provider usage costs."
),
type=ModelType.ExternalImageGenerator,
format=ModelFormat.ExternalApi,
capabilities=ExternalModelCapabilities(
modes=["txt2img", "img2img", "inpaint"],
supports_negative_prompt=False,
supports_seed=False,
supports_guidance=False,
supports_steps=False,
supports_reference_images=True,
max_images_per_request=4,
),
default_settings=ExternalApiModelDefaultSettings(
width=1024,
height=1024,
num_images=1,
),
)
```
Then append it to `STARTER_MODELS`.
### Required description text
External starter model descriptions must clearly state:
- an API key is required
- usage may incur provider-side costs
### Capabilities must be accurate
These flags directly control UI visibility and request payload fields:
- `supports_negative_prompt`
- `supports_seed`
- `supports_guidance`
- `supports_steps`
- `supports_reference_images`
`supports_steps` is especially important: if `False`, steps are hidden for that model and `steps` is sent as `null`.
### Source string stability
Starter overrides are matched by `source` (`external://provider/model-id`). Keep this stable:
- runtime capability/default overrides depend on it
- installation detection in starter-model APIs depends on it
`STARTER_MODELS` enforces unique `source` values with an assertion.
### Install behavior notes
- External starter models are managed in **External Providers** setup (not the regular Starter Models tab).
- External starter models auto-install when a provider is configured.
- Removing a provider API key removes installed external models for that provider.
## 2) Credentials and Config
External provider API keys are stored separately from `invokeai.yaml`:
- default file: `~/invokeai/api_keys.yaml`
- resolved path: `<INVOKEAI_ROOT>/api_keys.yaml`
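Provider base URL overrides (`external_<provider>_base_url`) are stored in `api_keys.yaml` alongside the API keys, so all external provider config lives in one dedicated file, separate from `invokeai.yaml`.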
Environment variables are still supported, e.g.:
- `INVOKEAI_EXTERNAL_GEMINI_API_KEY`
- `INVOKEAI_EXTERNAL_OPENAI_API_KEY`
## 3) Add a New Provider (Only If Needed)
If your model uses a provider that is not already integrated:
1. Add config fields in `invokeai/app/services/config/config_default.py`:
`external_<provider>_api_key` and, optionally, `external_<provider>_base_url`.
2. Add provider field mapping in `invokeai/app/api/routers/app_info.py`
(`EXTERNAL_PROVIDER_FIELDS`).
3. Implement provider adapter in `invokeai/app/services/external_generation/providers/`
by subclassing `ExternalProvider`.
4. Register the provider in `invokeai/app/api/dependencies.py` when building
`ExternalGenerationService`.
5. Add starter model entries using `source="external://<provider>/<model-id>"`.
6. Optional UI ordering tweak:
`invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ExternalProviders/ExternalProvidersForm.tsx`
(`PROVIDER_SORT_ORDER`).
## 4) Optional Manual Installation
You can also install external models directly via:
`POST /api/v2/models/install?source=external://<provider_id>/<provider_model_id>`
If omitted, `path`, `source`, and `hash` are auto-populated for external model configs.
Set capabilities conservatively; the external generation service enforces capability checks at runtime.

View File

@@ -8,6 +8,10 @@ We welcome contributions, whether features, bug fixes, code cleanup, testing, co
If you'd like to help with development, please see our [development guide](contribution_guides/development.md).
## External Providers
If you are adding external image generation providers or configs, see our [external provider integration guide](EXTERNAL_PROVIDERS.md).
**New Contributors:** If you're unfamiliar with contributing to open source projects, take a look at our [new contributor guide](contribution_guides/newContributorChecklist.md).
## Nodes

View File

@@ -16,6 +16,9 @@ from invokeai.app.services.client_state_persistence.client_state_persistence_sql
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.download.download_default import DownloadQueueService
from invokeai.app.services.events.events_fastapievents import FastAPIEventService
from invokeai.app.services.external_generation.external_generation_default import ExternalGenerationService
from invokeai.app.services.external_generation.providers import AlibabaCloudProvider, GeminiProvider, OpenAIProvider
from invokeai.app.services.external_generation.startup import sync_configured_external_starter_models
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
from invokeai.app.services.image_records.image_records_sqlite import SqliteImageRecordStorage
from invokeai.app.services.images.images_default import ImageService
@@ -149,13 +152,23 @@ class ApiDependencies:
),
)
download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
model_record_service = ModelRecordServiceSQL(db=db, logger=logger)
model_manager = ModelManagerService.build_model_manager(
app_config=configuration,
model_record_service=ModelRecordServiceSQL(db=db, logger=logger),
model_record_service=model_record_service,
download_queue=download_queue_service,
events=events,
)
external_generation = ExternalGenerationService(
providers={
AlibabaCloudProvider.provider_id: AlibabaCloudProvider(app_config=configuration, logger=logger),
GeminiProvider.provider_id: GeminiProvider(app_config=configuration, logger=logger),
OpenAIProvider.provider_id: OpenAIProvider(app_config=configuration, logger=logger),
},
logger=logger,
record_store=model_record_service,
)
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
model_relationships = ModelRelationshipsService()
model_relationship_records = SqliteModelRelationshipRecordStorage(db=db)
names = SimpleNameService()
@@ -188,6 +201,7 @@ class ApiDependencies:
model_relationships=model_relationships,
model_relationship_records=model_relationship_records,
download_queue=download_queue_service,
external_generation=external_generation,
names=names,
performance_statistics=performance_statistics,
session_processor=session_processor,
@@ -204,6 +218,16 @@ class ApiDependencies:
)
ApiDependencies.invoker = Invoker(services)
configured_external_providers = {
provider_id
for provider_id, status in external_generation.get_provider_statuses().items()
if status.configured
}
sync_configured_external_starter_models(
configured_provider_ids=configured_external_providers,
model_manager=model_manager,
logger=logger,
)
db.clean()
@staticmethod

View File

@@ -1,21 +1,30 @@
import locale
from enum import Enum
from importlib.metadata import distributions
from pathlib import Path as FilePath
from threading import Lock
import torch
from fastapi import Body
import yaml
from fastapi import Body, HTTPException, Path
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
from invokeai.app.api.auth_dependencies import AdminUserOrDefault
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.config.config_default import (
EXTERNAL_PROVIDER_CONFIG_FIELDS,
DefaultInvokeAIAppConfig,
InvokeAIAppConfig,
get_config,
load_and_migrate_config,
load_external_api_keys,
)
from invokeai.app.services.external_generation.external_generation_common import ExternalProviderStatus
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
from invokeai.app.services.model_records.model_records_base import UnknownModelException
from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
from invokeai.backend.util.logging import logging
from invokeai.version import __version__
@@ -47,7 +56,7 @@ async def get_version() -> AppVersion:
async def get_app_deps() -> dict[str, str]:
deps: dict[str, str] = {dist.metadata["Name"]: dist.version for dist in distributions()}
try:
cuda = torch.version.cuda or "N/A"
cuda = getattr(getattr(torch, "version", None), "cuda", None) or "N/A" # pyright: ignore[reportAttributeAccessIssue]
except Exception:
cuda = "N/A"
@@ -70,6 +79,31 @@ class InvokeAIAppConfigWithSetFields(BaseModel):
config: InvokeAIAppConfig = Field(description="The InvokeAI App Config")
# API response model for one external provider's status.
# Instances are built from an ExternalProviderStatus by status_to_model().
class ExternalProviderStatusModel(BaseModel):
provider_id: str = Field(description="The external provider identifier")
configured: bool = Field(description="Whether credentials are configured for the provider")
message: str | None = Field(default=None, description="Optional provider status detail")
# Request body accepted by set_external_provider_config.
# A field left as None is not touched; a string that is empty after stripping
# whitespace clears the stored value for that field.
class ExternalProviderConfigUpdate(BaseModel):
api_key: str | None = Field(default=None, description="API key for the external provider")
base_url: str | None = Field(default=None, description="Optional base URL override for the provider")
# API response model describing the stored configuration of one external provider.
# Only the presence of an API key is reported (api_key_configured); the key itself is not a field.
class ExternalProviderConfigModel(BaseModel):
provider_id: str = Field(description="The external provider identifier")
api_key_configured: bool = Field(description="Whether an API key is configured")
base_url: str | None = Field(default=None, description="Optional base URL override")
# Maps each supported provider id to the pair of app-config attribute names
# (API key field, base URL field) that hold its settings.
# Provider ids missing from this table are rejected with a 404 by _get_external_provider_fields().
EXTERNAL_PROVIDER_FIELDS: dict[str, tuple[str, str]] = {
"alibabacloud": ("external_alibabacloud_api_key", "external_alibabacloud_base_url"),
"gemini": ("external_gemini_api_key", "external_gemini_base_url"),
"openai": ("external_openai_api_key", "external_openai_base_url"),
}
# Held by _apply_external_provider_update() while it updates and persists provider config.
_EXTERNAL_PROVIDER_CONFIG_LOCK = Lock()
class UpdateAppGenerationSettingsRequest(BaseModel):
"""Writable generation-related app settings."""
@@ -112,6 +146,166 @@ async def update_runtime_config(
return InvokeAIAppConfigWithSetFields(set_fields=config.model_fields_set, config=config)
@app_router.get(
    "/external_providers/status",
    operation_id="get_external_provider_statuses",
    status_code=200,
    response_model=list[ExternalProviderStatusModel],
)
async def get_external_provider_statuses() -> list[ExternalProviderStatusModel]:
    # Report the status of every provider known to the external generation service.
    # Comments are used instead of a docstring so the generated OpenAPI description is unchanged.
    external_generation = ApiDependencies.invoker.services.external_generation
    provider_statuses = external_generation.get_provider_statuses()
    status_models: list[ExternalProviderStatusModel] = []
    for provider_status in provider_statuses.values():
        status_models.append(status_to_model(provider_status))
    return status_models
@app_router.get(
    "/external_providers/config",
    operation_id="get_external_provider_configs",
    status_code=200,
    response_model=list[ExternalProviderConfigModel],
)
async def get_external_provider_configs() -> list[ExternalProviderConfigModel]:
    # One entry per provider listed in EXTERNAL_PROVIDER_FIELDS, in that table's order,
    # all read from the same runtime config object.
    app_config = get_config()
    provider_configs: list[ExternalProviderConfigModel] = []
    for known_provider_id in EXTERNAL_PROVIDER_FIELDS:
        provider_configs.append(_build_external_provider_config(known_provider_id, app_config))
    return provider_configs
@app_router.post(
    "/external_providers/config/{provider_id}",
    operation_id="set_external_provider_config",
    status_code=200,
    response_model=ExternalProviderConfigModel,
)
async def set_external_provider_config(
    provider_id: str = Path(description="The external provider identifier"),
    update: ExternalProviderConfigUpdate = Body(description="External provider configuration settings"),
) -> ExternalProviderConfigModel:
    # Comments are used instead of a docstring so the generated OpenAPI description is unchanged.
    # NOTE(review): this handler declares no auth dependency although it changes stored
    # credentials - confirm access control is enforced elsewhere.
    api_key_field, base_url_field = _get_external_provider_fields(provider_id)

    # Only fields the caller actually sent are considered; a value that is blank after
    # stripping whitespace is stored as None, i.e. "clear this setting".
    submitted = ((api_key_field, update.api_key), (base_url_field, update.base_url))
    updates: dict[str, str | None] = {
        field_name: (raw_value.strip() or None) for field_name, raw_value in submitted if raw_value is not None
    }
    if len(updates) == 0:
        raise HTTPException(status_code=400, detail="No external provider config fields provided")

    # The key counts as removed when it was sent and normalised to None.
    api_key_removed = api_key_field in updates and updates[api_key_field] is None

    _apply_external_provider_update(updates)
    if api_key_removed:
        # Without a key the provider's installed models are removed as well.
        _remove_external_models_for_provider(provider_id)
    return _build_external_provider_config(provider_id, get_config())
@app_router.delete(
    "/external_providers/config/{provider_id}",
    operation_id="reset_external_provider_config",
    status_code=200,
    response_model=ExternalProviderConfigModel,
)
async def reset_external_provider_config(
    provider_id: str = Path(description="The external provider identifier"),
) -> ExternalProviderConfigModel:
    # Clears both the API key and the base URL override, then removes the provider's models.
    # NOTE(review): this handler declares no auth dependency although it deletes stored
    # credentials - confirm access control is enforced elsewhere.
    cleared_fields: dict[str, str | None] = {
        field_name: None for field_name in _get_external_provider_fields(provider_id)
    }
    _apply_external_provider_update(cleared_fields)
    _remove_external_models_for_provider(provider_id)
    return _build_external_provider_config(provider_id, get_config())
def status_to_model(status: ExternalProviderStatus) -> ExternalProviderStatusModel:
    """Copy ``provider_id``, ``configured`` and ``message`` from a service-level status into the API model."""
    provider_id = status.provider_id
    is_configured = status.configured
    detail = status.message
    return ExternalProviderStatusModel(provider_id=provider_id, configured=is_configured, message=detail)
def _get_external_provider_fields(provider_id: str) -> tuple[str, str]:
    """Look up the (API key field, base URL field) config attribute names for a provider.

    Raises:
        HTTPException: 404 when ``provider_id`` is not a key of ``EXTERNAL_PROVIDER_FIELDS``.
    """
    field_names = EXTERNAL_PROVIDER_FIELDS.get(provider_id)
    if field_names is None:
        raise HTTPException(status_code=404, detail=f"Unknown external provider '{provider_id}'")
    return field_names
def _write_external_api_keys_file(api_keys_file_path: FilePath, api_keys: dict[str, str]) -> None:
if not api_keys:
if api_keys_file_path.exists():
api_keys_file_path.unlink()
return
api_keys_file_path.parent.mkdir(parents=True, exist_ok=True)
with open(api_keys_file_path, "w", encoding=locale.getpreferredencoding()) as api_keys_file:
yaml.safe_dump(api_keys, api_keys_file, sort_keys=False)
def _apply_external_provider_update(updates: dict[str, str | None]) -> None:
# Applies external-provider config changes to the running app config and persists them.
# Fields named in EXTERNAL_PROVIDER_CONFIG_FIELDS are persisted to the API keys file;
# any other field in `updates` is persisted to the main config file.
# A value of None removes the persisted entry for that field.
# NOTE(review): indentation is lost in this view; the lock presumably guards the whole body - confirm.
with _EXTERNAL_PROVIDER_CONFIG_LOCK:
runtime_config = get_config()
config_path = runtime_config.config_file_path
api_keys_file_path = runtime_config.api_keys_file_path
# Start from the on-disk config when there is one, otherwise from a default config object.
if config_path.exists():
file_config = load_and_migrate_config(config_path)
else:
file_config = DefaultInvokeAIAppConfig()
# Apply the change to the in-memory runtime config.
runtime_config.update_config(updates)
# Split the update by destination: provider fields vs. everything else.
provider_config_fields = set(EXTERNAL_PROVIDER_CONFIG_FIELDS)
provider_updates = {field: value for field, value in updates.items() if field in provider_config_fields}
non_provider_updates = {field: value for field, value in updates.items() if field not in provider_config_fields}
if non_provider_updates:
file_config.update_config(non_provider_updates)
# Load what is already persisted, then carry over any non-blank provider value that is
# present on the loaded main config but not yet present in the API keys file.
persisted_api_keys = load_external_api_keys(api_keys_file_path)
for field_name in EXTERNAL_PROVIDER_CONFIG_FIELDS:
file_value = getattr(file_config, field_name, None)
if field_name not in persisted_api_keys and isinstance(file_value, str) and file_value.strip():
persisted_api_keys[field_name] = file_value
# Apply the requested provider changes: None drops the entry, any other value overwrites it.
for field_name, value in provider_updates.items():
if value is None:
persisted_api_keys.pop(field_name, None)
else:
persisted_api_keys[field_name] = value
_write_external_api_keys_file(api_keys_file_path, persisted_api_keys)
# Blank the provider fields on the config object so they are not written to the main config file.
for field_name in EXTERNAL_PROVIDER_CONFIG_FIELDS:
setattr(file_config, field_name, None)
# Rebuild the config from a dump that excludes unset and None fields, then write it to config_path.
file_config_to_write = type(file_config).model_validate(
file_config.model_dump(exclude_unset=True, exclude_none=True)
)
file_config_to_write.write_file(config_path, as_example=False)
def _build_external_provider_config(provider_id: str, config: InvokeAIAppConfig) -> ExternalProviderConfigModel:
    """Summarise one provider's settings from ``config`` without exposing the API key.

    Raises:
        HTTPException: 404 (from ``_get_external_provider_fields``) for an unknown provider id.
    """
    api_key_field, base_url_field = _get_external_provider_fields(provider_id)
    stored_api_key = getattr(config, api_key_field)
    base_url_override = getattr(config, base_url_field)
    return ExternalProviderConfigModel(
        provider_id=provider_id,
        # Only the truthiness of the key is reported, never its value.
        api_key_configured=bool(stored_api_key),
        base_url=base_url_override,
    )
def _remove_external_models_for_provider(provider_id: str) -> None:
    """Delete every installed external image-generator model whose ``provider_id`` matches.

    Deletion is best-effort: a failure for one model is logged as a warning and the
    remaining models are still processed.
    """
    manager = ApiDependencies.invoker.services.model_manager
    candidates = manager.store.search_by_attr(
        base_model=BaseModelType.External,
        model_type=ModelType.ExternalImageGenerator,
    )
    # Records without a provider_id attribute compare as None and are therefore skipped.
    owned_by_provider = (
        candidate for candidate in candidates if getattr(candidate, "provider_id", None) == provider_id
    )
    for model in owned_by_provider:
        try:
            manager.install.delete(model.key)
        except UnknownModelException:
            logging.warning(f"External model key '{model.key}' was already removed while resetting '{provider_id}'")
        except Exception as error:
            logging.warning(f"Failed removing external model key '{model.key}' for '{provider_id}': {error}")
@app_router.get(
"/logging",
operation_id="get_log_level",

View File

@@ -30,6 +30,7 @@ from invokeai.app.services.model_records import (
)
from invokeai.app.services.orphaned_models import OrphanedModelInfo
from invokeai.app.util.suppress_output import SuppressOutput
from invokeai.backend.model_manager.configs.external_api import ExternalApiModelConfig
from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory
from invokeai.backend.model_manager.configs.main import (
Main_Checkpoint_SD1_Config,
@@ -75,8 +76,36 @@ class CacheType(str, Enum):
def add_cover_image_to_model_config(config: AnyModelConfig, dependencies: Type[ApiDependencies]) -> AnyModelConfig:
"""Add a cover image URL to a model configuration."""
cover_image = dependencies.invoker.services.model_images.get_url(config.key)
config.cover_image = cover_image
return config
return config.model_copy(update={"cover_image": cover_image})
def apply_external_starter_model_overrides(config: AnyModelConfig) -> AnyModelConfig:
    """Return ``config`` with starter-model metadata layered on top, for external API models only.

    The starter entry is matched by ``source``. Its ``capabilities``, ``default_settings``
    and ``panel_schema`` replace the config's values when they are not ``None``. The input
    is returned unchanged when it is not an external API config, when no starter entry
    matches, or when the entry has nothing to override; otherwise a copy is returned.
    """
    if not isinstance(config, ExternalApiModelConfig):
        return config

    matching_starters = [starter for starter in STARTER_MODELS if starter.source == config.source]
    if not matching_starters:
        return config
    # First match wins, as with a linear search.
    starter_entry = matching_starters[0]

    overrides: dict[str, object] = {}
    for field_name in ("capabilities", "default_settings", "panel_schema"):
        starter_value = getattr(starter_entry, field_name)
        if starter_value is not None:
            overrides[field_name] = starter_value

    return config.model_copy(update=overrides) if overrides else config
def prepare_model_config_for_response(config: AnyModelConfig, dependencies: Type[ApiDependencies]) -> AnyModelConfig:
    """Prepare a model config for an API response: starter overrides first, then the cover image URL."""
    with_starter_overrides = apply_external_starter_model_overrides(config)
    return add_cover_image_to_model_config(with_starter_overrides, dependencies)
##############################################################################
@@ -145,8 +174,8 @@ async def list_model_records(
found_models.extend(
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
)
for model in found_models:
model = add_cover_image_to_model_config(model, ApiDependencies)
for index, model in enumerate(found_models):
found_models[index] = prepare_model_config_for_response(model, ApiDependencies)
return ModelsList(models=found_models)
@@ -166,6 +195,8 @@ async def list_missing_models() -> ModelsList:
missing_models: list[AnyModelConfig] = []
for model_config in record_store.all_models():
if model_config.base == BaseModelType.External or model_config.format == ModelFormat.ExternalApi:
continue
if not (models_path / model_config.path).resolve().exists():
missing_models.append(model_config)
@@ -190,7 +221,7 @@ async def get_model_records_by_attrs(
if not configs:
raise HTTPException(status_code=404, detail="No model found with these attributes")
return configs[0]
return prepare_model_config_for_response(configs[0], ApiDependencies)
@model_manager_router.get(
@@ -207,7 +238,7 @@ async def get_model_records_by_hash(
if not configs:
raise HTTPException(status_code=404, detail="No model found with this hash")
return configs[0]
return prepare_model_config_for_response(configs[0], ApiDependencies)
@model_manager_router.get(
@@ -228,7 +259,7 @@ async def get_model_record(
"""Get a model record"""
try:
config = ApiDependencies.invoker.services.model_manager.store.get_model(key)
return add_cover_image_to_model_config(config, ApiDependencies)
return prepare_model_config_for_response(config, ApiDependencies)
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
@@ -268,7 +299,7 @@ async def reidentify_model(
result.config.name = config.name
result.config.description = config.description
result.config.cover_image = config.cover_image
if hasattr(config, "trigger_phrases") and hasattr(result.config, "trigger_phrases"):
if hasattr(result.config, "trigger_phrases") and hasattr(config, "trigger_phrases"):
result.config.trigger_phrases = config.trigger_phrases
result.config.source = config.source
result.config.source_type = config.source_type
@@ -392,7 +423,7 @@ async def update_model_record(
record_store = ApiDependencies.invoker.services.model_manager.store
try:
config = record_store.update_model(key, changes=changes, allow_class_change=True)
config = add_cover_image_to_model_config(config, ApiDependencies)
config = prepare_model_config_for_response(config, ApiDependencies)
logger.info(f"Updated model: {key}")
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
@@ -1124,7 +1155,7 @@ async def convert_model(
# return the config record for the new diffusers directory
new_config = store.get_model(new_key)
new_config = add_cover_image_to_model_config(new_config, ApiDependencies)
new_config = prepare_model_config_for_response(new_config, ApiDependencies)
return new_config

View File

@@ -0,0 +1,203 @@
from typing import Any, ClassVar, Literal
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
InputField,
MetadataField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageCollectionOutput
from invokeai.app.services.external_generation.external_generation_common import (
ExternalGenerationRequest,
ExternalGenerationResult,
ExternalReferenceImage,
)
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.configs.external_api import ExternalApiModelConfig, ExternalGenerationMode
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
class BaseExternalImageGenerationInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Shared fields and execution logic for nodes that generate images via an external provider.

    Subclasses set ``provider_id`` to bind the node to a single provider and may override
    ``_build_provider_options`` to forward provider-specific settings with the request.
    """

    # When not None, invoke() rejects models whose provider_id differs from this value.
    provider_id: ClassVar[str | None] = None

    model: ModelIdentifierField = InputField(
        description=FieldDescriptions.main_model,
        ui_model_base=[BaseModelType.External],
        ui_model_type=[ModelType.ExternalImageGenerator],
        ui_model_format=[ModelFormat.ExternalApi],
    )
    mode: ExternalGenerationMode = InputField(default="txt2img", description="Generation mode")
    prompt: str = InputField(description="Prompt")
    seed: int | None = InputField(default=None, description=FieldDescriptions.seed)
    num_images: int = InputField(default=1, gt=0, description="Number of images to generate")
    width: int = InputField(default=1024, gt=0, description=FieldDescriptions.width)
    height: int = InputField(default=1024, gt=0, description=FieldDescriptions.height)
    image_size: str | None = InputField(default=None, description="Image size preset (e.g. 1K, 2K, 4K)")
    init_image: ImageField | None = InputField(default=None, description="Init image for img2img/inpaint")
    mask_image: ImageField | None = InputField(default=None, description="Mask image for inpaint")
    reference_images: list[ImageField] = InputField(default=[], description="Reference images")

    def _build_provider_options(self) -> dict[str, Any] | None:
        """Hook for provider-specific subclasses; the base node sends no extra options."""
        return None

    def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
        """Resolve the model, load input images, run the external generation and save the results.

        Raises:
            ValueError: If the selected model is not an external API model, or its provider
                does not match this node's ``provider_id``.
        """
        model_config = context.models.get_config(self.model)
        if not isinstance(model_config, ExternalApiModelConfig):
            raise ValueError("Selected model is not an external API model")
        provider_mismatch = self.provider_id is not None and model_config.provider_id != self.provider_id
        if provider_mismatch:
            raise ValueError(
                f"Selected model provider '{model_config.provider_id}' does not match node provider '{self.provider_id}'"
            )

        # The init image is loaded as RGB and the mask as single-channel "L".
        init_image = (
            context.images.get_pil(self.init_image.image_name, mode="RGB") if self.init_image is not None else None
        )
        mask_image = (
            context.images.get_pil(self.mask_image.image_name, mode="L") if self.mask_image is not None else None
        )
        reference_images: list[ExternalReferenceImage] = [
            ExternalReferenceImage(image=context.images.get_pil(field.image_name, mode="RGB"))
            for field in self.reference_images
        ]

        request = ExternalGenerationRequest(
            model=model_config,
            mode=self.mode,
            prompt=self.prompt,
            seed=self.seed,
            num_images=self.num_images,
            width=self.width,
            height=self.height,
            image_size=self.image_size,
            init_image=init_image,
            mask_image=mask_image,
            reference_images=reference_images,
            metadata=self._build_request_metadata(),
            provider_options=self._build_provider_options(),
        )
        result = context._services.external_generation.generate(request)

        collection: list[ImageField] = []
        for generated in result.images:
            # Each saved image gets its own metadata so the per-image seed is recorded.
            image_metadata = self._build_output_metadata(model_config, result, generated.seed)
            saved = context.images.save(image=generated.image, metadata=image_metadata)
            collection.append(ImageField(image_name=saved.image_name))
        return ImageCollectionOutput(collection=collection)

    def _build_request_metadata(self) -> dict[str, Any] | None:
        """Return the node's metadata payload for the request, or None when no metadata is attached."""
        return None if self.metadata is None else self.metadata.root

    def _build_output_metadata(
        self,
        model_config: ExternalApiModelConfig,
        result: ExternalGenerationResult,
        image_seed: int | None,
    ) -> MetadataField | None:
        """Merge the node metadata with provider details for one generated image.

        Adds ``external_provider`` and ``external_model_id`` always, and ``external_request_id``,
        ``external_provider_metadata`` and ``external_seed`` when the corresponding value is present.
        """
        combined: dict[str, Any] = {}
        if self.metadata is not None:
            combined.update(self.metadata.root)
        combined["external_provider"] = model_config.provider_id
        combined["external_model_id"] = model_config.provider_model_id

        request_id = getattr(result, "provider_request_id", None)
        if request_id:
            combined["external_request_id"] = request_id
        extra_metadata = getattr(result, "provider_metadata", None)
        if extra_metadata:
            combined["external_provider_metadata"] = extra_metadata
        if image_seed is not None:
            combined["external_seed"] = image_seed

        if not combined:
            return None
        return MetadataField(root=combined)
@invocation(
    "external_image_generation",
    title="External Image Generation (Legacy)",
    tags=["external", "generation"],
    category="image",
    version="1.1.0",
    classification=Classification.Internal,
)
class ExternalImageGenerationInvocation(BaseExternalImageGenerationInvocation):
    """Legacy external image generation node kept for backward compatibility.

    It adds no fields and leaves ``provider_id`` unset, so the provider check in the
    base class is skipped and a model from any external provider is accepted.
    """
@invocation(
    "openai_image_generation",
    title="OpenAI Image Generation",
    tags=["external", "generation", "openai"],
    category="image",
    version="1.0.0",
)
class OpenAIImageGenerationInvocation(BaseExternalImageGenerationInvocation):
    """External image generation node restricted to models of the ``openai`` provider."""

    provider_id = "openai"

    quality: Literal["auto", "high", "medium", "low"] = InputField(default="auto", description="Output image quality")
    background: Literal["auto", "transparent", "opaque"] = InputField(
        default="auto", description="Background transparency handling"
    )
    input_fidelity: Literal["low", "high"] | None = InputField(
        default=None, description="Fidelity to source images (edits only)"
    )

    def _build_provider_options(self) -> dict[str, Any]:
        """Collect the OpenAI-specific options; ``input_fidelity`` is only included when set."""
        base_options: dict[str, Any] = {"quality": self.quality, "background": self.background}
        fidelity = self.input_fidelity
        if fidelity is None:
            return base_options
        return {**base_options, "input_fidelity": fidelity}
@invocation(
    "gemini_image_generation",
    title="Gemini Image Generation",
    tags=["external", "generation", "gemini"],
    category="image",
    version="1.0.0",
)
class GeminiImageGenerationInvocation(BaseExternalImageGenerationInvocation):
    """External image generation node restricted to models of the ``gemini`` provider."""

    provider_id = "gemini"

    temperature: float | None = InputField(default=None, ge=0.0, le=2.0, description="Sampling temperature")
    thinking_level: Literal["minimal", "high"] | None = InputField(
        default=None, description="Thinking level for image generation"
    )

    def _build_provider_options(self) -> dict[str, Any] | None:
        """Return the Gemini options that were explicitly set, or None when none were."""
        candidates: dict[str, Any] = {
            "temperature": self.temperature,
            "thinking_level": self.thinking_level,
        }
        chosen = {name: value for name, value in candidates.items() if value is not None}
        return chosen or None

View File

@@ -22,6 +22,7 @@ from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
INIT_FILE = Path("invokeai.yaml")
API_KEYS_FILE = Path("api_keys.yaml")
DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
@@ -30,6 +31,14 @@ ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
CONFIG_SCHEMA_VERSION = "4.0.2"
# Names of the external-provider settings (API keys and base URL overrides).
# load_external_api_keys() reads only these keys from the API keys file; any
# other key in that file is ignored.
EXTERNAL_PROVIDER_CONFIG_FIELDS = (
    "external_alibabacloud_api_key",
    "external_alibabacloud_base_url",
    "external_gemini_api_key",
    "external_gemini_base_url",
    "external_openai_api_key",
    "external_openai_base_url",
)
class URLRegexTokenPair(BaseModel):
@@ -113,6 +122,10 @@ class InvokeAIAppConfig(BaseSettings):
allow_unknown_models: Allow installation of models that we are unable to identify. If enabled, models will be marked as `unknown` in the database, and will not have any metadata associated with them. If disabled, unknown models will be rejected during installation.
multiuser: Enable multiuser support. When disabled, the application runs in single-user mode using a default system account with administrator privileges. When enabled, requires user authentication and authorization.
strict_password_checking: Enforce strict password requirements. When True, passwords must contain uppercase, lowercase, and numbers. When False (default), any password is accepted but its strength (weak/moderate/strong) is reported to the user.
external_gemini_api_key: API key for Gemini image generation.
external_openai_api_key: API key for OpenAI image generation.
external_gemini_base_url: Base URL override for Gemini image generation.
external_openai_base_url: Base URL override for OpenAI image generation.
"""
_root: Optional[Path] = PrivateAttr(default=None)
@@ -211,6 +224,20 @@ class InvokeAIAppConfig(BaseSettings):
multiuser: bool = Field(default=False, description="Enable multiuser support. When disabled, the application runs in single-user mode using a default system account with administrator privileges. When enabled, requires user authentication and authorization.")
strict_password_checking: bool = Field(default=False, description="Enforce strict password requirements. When True, passwords must contain uppercase, lowercase, and numbers. When False (default), any password is accepted but its strength (weak/moderate/strong) is reported to the user.")
# EXTERNAL PROVIDERS
external_alibabacloud_api_key: Optional[str] = Field(default=None, description="API key for Alibaba Cloud DashScope image generation.")
external_alibabacloud_base_url: Optional[str] = Field(
default=None, description="Base URL override for Alibaba Cloud DashScope image generation."
)
external_gemini_api_key: Optional[str] = Field(default=None, description="API key for Gemini image generation.")
external_openai_api_key: Optional[str] = Field(default=None, description="API key for OpenAI image generation.")
external_gemini_base_url: Optional[str] = Field(
default=None, description="Base URL override for Gemini image generation."
)
external_openai_base_url: Optional[str] = Field(
default=None, description="Base URL override for OpenAI image generation."
)
# fmt: on
model_config = SettingsConfigDict(env_prefix="INVOKEAI_", env_ignore_empty=True)
@@ -292,6 +319,13 @@ class InvokeAIAppConfig(BaseSettings):
assert resolved_path is not None
return resolved_path
@property
def api_keys_file_path(self) -> Path:
    """Path to api_keys.yaml, resolved to an absolute path."""
    api_keys_path = self._resolve(API_KEYS_FILE)
    assert api_keys_path is not None
    return api_keys_path
@property
def outputs_path(self) -> Optional[Path]:
"""Path to the outputs directory, resolved to an absolute path.."""
@@ -504,6 +538,36 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
def load_external_api_keys(api_keys_file_path: Path) -> dict[str, str]:
    """Load external provider config (API keys and base URLs) from a dedicated YAML file.

    Only the keys listed in ``EXTERNAL_PROVIDER_CONFIG_FIELDS`` are read. Values are
    stripped of surrounding whitespace; missing, null and blank values are omitted.

    Args:
        api_keys_file_path: Location of the YAML file. A missing file yields an empty dict.

    Returns:
        Mapping of recognized field name to its non-empty, stripped string value.

    Raises:
        RuntimeError: If the file's top level is not a mapping, or a recognized field
            holds a non-string value.
    """
    if not api_keys_file_path.exists():
        return {}

    with open(api_keys_file_path, "rt", encoding=locale.getpreferredencoding()) as file:
        raw_document: Any = yaml.safe_load(file)

    # An empty YAML document parses to None.
    if raw_document is None:
        return {}
    if not isinstance(raw_document, dict):
        raise RuntimeError(f"Failed to load api keys file {api_keys_file_path}: expected a mapping")

    collected: dict[str, str] = {}
    for field_name in EXTERNAL_PROVIDER_CONFIG_FIELDS:
        raw_value = raw_document.get(field_name)
        if raw_value is None:
            continue
        if not isinstance(raw_value, str):
            raise RuntimeError(
                f"Failed to load api keys file {api_keys_file_path}: value for '{field_name}' must be a string"
            )
        cleaned = raw_value.strip()
        if not cleaned:
            # Whitespace-only entries are treated as unset.
            continue
        collected[field_name] = cleaned
    return collected
@lru_cache(maxsize=1)
def get_config() -> InvokeAIAppConfig:
"""Get the global singleton app config.
@@ -520,6 +584,7 @@ def get_config() -> InvokeAIAppConfig:
"""
# This object includes environment variables, as parsed by pydantic-settings
config = InvokeAIAppConfig()
env_fields_set = set(config.model_fields_set)
args = InvokeAIArgs.args
@@ -581,4 +646,11 @@ def get_config() -> InvokeAIAppConfig:
default_config = DefaultInvokeAIAppConfig()
default_config.write_file(config.config_file_path, as_example=False)
api_keys_from_file = load_external_api_keys(config.api_keys_file_path)
if api_keys_from_file:
# API keys file should take precedence over invokeai.yaml, but not over environment variables.
api_keys_to_apply = {key: value for key, value in api_keys_from_file.items() if key not in env_fields_set}
if api_keys_to_apply:
config.update_config(api_keys_to_apply, clobber=True)
return config

View File

@@ -0,0 +1,23 @@
from invokeai.app.services.external_generation.external_generation_base import (
ExternalGenerationServiceBase,
ExternalProvider,
)
from invokeai.app.services.external_generation.external_generation_common import (
ExternalGeneratedImage,
ExternalGenerationRequest,
ExternalGenerationResult,
ExternalProviderStatus,
ExternalReferenceImage,
)
from invokeai.app.services.external_generation.external_generation_default import ExternalGenerationService
# Names re-exported as the public interface of the external_generation package.
__all__ = [
    "ExternalGenerationRequest",
    "ExternalGenerationResult",
    "ExternalGeneratedImage",
    "ExternalGenerationService",
    "ExternalGenerationServiceBase",
    "ExternalProvider",
    "ExternalProviderStatus",
    "ExternalReferenceImage",
]

View File

@@ -0,0 +1,28 @@
class ExternalGenerationError(Exception):
    """Base error for external generation.

    Every other error class in this module derives from it, so catching this
    type catches all of them.
    """
class ExternalProviderNotFoundError(ExternalGenerationError):
    """Raised when no provider is registered for a model's ``provider_id``."""
class ExternalProviderNotConfiguredError(ExternalGenerationError):
    """Raised when a provider is registered but is missing required credentials."""
class ExternalProviderCapabilityError(ExternalGenerationError):
    """Raised when a request asks for something the model's declared capabilities do not support."""
class ExternalProviderRequestError(ExternalGenerationError):
    """Raised when a provider rejects the request or returns an error response."""
class ExternalProviderRateLimitError(ExternalProviderRequestError):
    """Raised when a provider returns HTTP 429 (rate limit exceeded).

    Attributes:
        retry_after: Optional delay, as supplied by whoever raises the error, that the
            caller should wait before retrying; None when no delay was given.
    """

    retry_after: float | None

    def __init__(self, message: str, retry_after: float | None = None) -> None:
        self.retry_after = retry_after
        super().__init__(message)

View File

@@ -0,0 +1,40 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from logging import Logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.external_generation.external_generation_common import (
ExternalGenerationRequest,
ExternalGenerationResult,
ExternalProviderStatus,
)
class ExternalProvider(ABC):
    """Abstract base for a single external image generation provider.

    Concrete providers set ``provider_id`` and implement ``is_configured`` and ``generate``.
    """

    provider_id: str

    def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
        """Store the app config and logger for use by the concrete provider."""
        self._logger = logger
        self._app_config = app_config

    @abstractmethod
    def is_configured(self) -> bool:
        """Report whether the provider is configured and ready to be used."""
        raise NotImplementedError

    @abstractmethod
    def generate(self, request: ExternalGenerationRequest) -> ExternalGenerationResult:
        """Run the generation request against the provider and return its result."""
        raise NotImplementedError

    def get_status(self) -> ExternalProviderStatus:
        """Build a status object from ``provider_id`` and the current ``is_configured()`` value."""
        provider_id = self.provider_id
        configured = self.is_configured()
        return ExternalProviderStatus(provider_id=provider_id, configured=configured)
class ExternalGenerationServiceBase(ABC):
    """Abstract interface of the external generation service."""

    @abstractmethod
    def generate(self, request: ExternalGenerationRequest) -> ExternalGenerationResult:
        """Execute an external generation request and return the provider's result."""
        raise NotImplementedError

    @abstractmethod
    def get_provider_statuses(self) -> dict[str, ExternalProviderStatus]:
        """Return the status of each provider; presumably keyed by provider id — confirm in implementations."""
        raise NotImplementedError

View File

@@ -0,0 +1,52 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from PIL.Image import Image as PILImageType
from invokeai.backend.model_manager.configs.external_api import ExternalApiModelConfig, ExternalGenerationMode
@dataclass(frozen=True)
class ExternalReferenceImage:
    """Immutable wrapper around a PIL image passed to a provider as a reference image."""

    image: PILImageType
@dataclass(frozen=True)
class ExternalGenerationRequest:
    """Immutable description of one image generation call to an external provider."""

    # Config of the external model to use; carries provider_id, provider_model_id and capabilities.
    model: ExternalApiModelConfig
    # Generation mode, e.g. "txt2img", "img2img" or "inpaint".
    mode: ExternalGenerationMode
    prompt: str
    # None leaves the seed unspecified.
    seed: int | None
    num_images: int
    # Requested output size — presumably in pixels; confirm against providers.
    width: int
    height: int
    # Optional size preset label (e.g. 1K, 2K, 4K).
    image_size: str | None
    init_image: PILImageType | None
    mask_image: PILImageType | None
    reference_images: list[ExternalReferenceImage]
    # Free-form metadata forwarded from the invoking node, if any.
    metadata: dict[str, Any] | None
    # Provider-specific options; None when the node supplies none.
    provider_options: dict[str, Any] | None = None
@dataclass(frozen=True)
class ExternalGeneratedImage:
    """One image returned by a provider, with the seed reported for it (if any)."""

    image: PILImageType
    seed: int | None = None
@dataclass(frozen=True)
class ExternalGenerationResult:
    """Immutable result of an external generation call."""

    images: list[ExternalGeneratedImage]
    # NOTE(review): presumably the seed applied to the whole request — confirm in providers.
    seed_used: int | None = None
    # Request identifier reported by the provider, when available.
    provider_request_id: str | None = None
    # Additional provider-specific response data, when available.
    provider_metadata: dict[str, Any] | None = None
    # NOTE(review): key/value meaning is not visible here — confirm against the providers that fill it.
    content_filters: dict[str, str] | None = None
@dataclass(frozen=True)
class ExternalProviderStatus:
    """Immutable snapshot of a provider's configuration state."""

    provider_id: str
    # True when the provider reports itself as configured.
    configured: bool
    # Optional human-readable detail — TODO confirm who populates it.
    message: str | None = None

View File

@@ -0,0 +1,353 @@
from __future__ import annotations
import time
from logging import Logger
from typing import TYPE_CHECKING
from PIL import Image
from PIL.Image import Image as PILImageType
from invokeai.app.services.external_generation.errors import (
ExternalProviderCapabilityError,
ExternalProviderNotConfiguredError,
ExternalProviderNotFoundError,
ExternalProviderRateLimitError,
)
from invokeai.app.services.external_generation.external_generation_base import (
ExternalGenerationServiceBase,
ExternalProvider,
)
from invokeai.app.services.external_generation.external_generation_common import (
ExternalGeneratedImage,
ExternalGenerationRequest,
ExternalGenerationResult,
ExternalProviderStatus,
)
from invokeai.backend.model_manager.configs.external_api import ExternalApiModelConfig, ExternalImageSize
from invokeai.backend.model_manager.starter_models import STARTER_MODELS
if TYPE_CHECKING:
from invokeai.app.services.model_records import ModelRecordServiceBase
class ExternalGenerationService(ExternalGenerationServiceBase):
    """Default external generation service.

    Routes a request to the provider registered for the model, after refreshing the
    model's capabilities, snapping the requested size to a supported one ("bucketing")
    and validating the request against those capabilities.
    """

    # Retry policy for rate-limited provider calls (delays are in seconds).
    _MAX_RETRIES = 3
    _DEFAULT_RETRY_DELAY = 10.0
    _MAX_RETRY_DELAY = 60.0

    def __init__(
        self,
        providers: dict[str, ExternalProvider],
        logger: Logger,
        record_store: ModelRecordServiceBase | None = None,
    ) -> None:
        """Store the provider registry, the logger and the optional model record store."""
        self._providers = providers
        self._logger = logger
        self._record_store = record_store

    def generate(self, request: ExternalGenerationRequest) -> ExternalGenerationResult:
        """Run ``request`` on its provider and return the result.

        For inpaint requests the generated images are resized back to the size the init
        image had before bucketing.

        Raises:
            ExternalProviderNotFoundError: No provider is registered for the model.
            ExternalProviderNotConfiguredError: The provider reports it is not configured.
            ExternalProviderCapabilityError: The request violates the model's capabilities.
        """
        provider = self._providers.get(request.model.provider_id)
        if provider is None:
            raise ExternalProviderNotFoundError(f"No external provider registered for '{request.model.provider_id}'")
        if not provider.is_configured():
            raise ExternalProviderNotConfiguredError(f"Provider '{request.model.provider_id}' is missing credentials")

        request = self._refresh_model_capabilities(request)
        # Capture the init image size before bucketing may resize it.
        inpaint_target = _get_resize_target_for_inpaint(request)
        request = self._bucket_request(request)
        self._validate_request(request)
        result = self._generate_with_retry(provider, request)

        if inpaint_target is not None:
            target_width, target_height = inpaint_target
            return _resize_result_images(result, target_width, target_height)
        return result

    def _generate_with_retry(
        self, provider: ExternalProvider, request: ExternalGenerationRequest
    ) -> ExternalGenerationResult:
        """Call the provider, retrying on rate-limit errors up to ``_MAX_RETRIES`` attempts."""
        final_attempt = self._MAX_RETRIES - 1
        for attempt in range(self._MAX_RETRIES):
            try:
                return provider.generate(request)
            except ExternalProviderRateLimitError as exc:
                if attempt == final_attempt:
                    raise
                # Prefer the delay given on the error, falling back to the default; always capped.
                delay = min(exc.retry_after or self._DEFAULT_RETRY_DELAY, self._MAX_RETRY_DELAY)
            # Only reached when the attempt above was rate limited.
            self._logger.warning(
                "Rate limited by %s (attempt %d/%d), retrying in %.0fs",
                request.model.provider_id,
                attempt + 1,
                self._MAX_RETRIES,
                delay,
            )
            time.sleep(delay)
        raise ExternalProviderRateLimitError("Rate limit exceeded after all retries")

    def get_provider_statuses(self) -> dict[str, ExternalProviderStatus]:
        """Return the status of every registered provider, keyed by provider id."""
        statuses: dict[str, ExternalProviderStatus] = {}
        for provider_id, provider in self._providers.items():
            statuses[provider_id] = provider.get_status()
        return statuses

    def _validate_request(self, request: ExternalGenerationRequest) -> None:
        """Check ``request`` against the model capabilities.

        Raises:
            ExternalProviderCapabilityError: On the first violated constraint.
        """
        capabilities = request.model.capabilities
        model_name = request.model.name
        self._logger.debug(
            "Validating external request provider=%s model=%s mode=%s supported=%s",
            request.model.provider_id,
            request.model.provider_model_id,
            request.mode,
            capabilities.modes,
        )

        if request.mode not in capabilities.modes:
            raise ExternalProviderCapabilityError(f"Mode '{request.mode}' is not supported by {model_name}")
        if request.seed is not None and not capabilities.supports_seed:
            raise ExternalProviderCapabilityError(f"Seed control is not supported by {model_name}")
        if request.reference_images and not capabilities.supports_reference_images:
            raise ExternalProviderCapabilityError(f"Reference images are not supported by {model_name}")

        max_refs = capabilities.max_reference_images
        if max_refs is not None and len(request.reference_images) > max_refs:
            raise ExternalProviderCapabilityError(
                f"{model_name} supports at most {capabilities.max_reference_images} reference images"
            )

        max_images = capabilities.max_images_per_request
        if max_images is not None and request.num_images > max_images:
            raise ExternalProviderCapabilityError(
                f"{model_name} supports at most {capabilities.max_images_per_request} images per request"
            )

        max_size = capabilities.max_image_size
        if max_size is not None and (request.width > max_size.width or request.height > max_size.height):
            raise ExternalProviderCapabilityError(
                f"{model_name} supports a maximum size of {capabilities.max_image_size.width}x{capabilities.max_image_size.height}"
            )

        allowed_ratios = capabilities.allowed_aspect_ratios
        if allowed_ratios:
            aspect_ratio = _format_aspect_ratio(request.width, request.height)
            if aspect_ratio not in allowed_ratios:
                # The exact size may still be a known preset for one of the allowed ratios.
                size_ratio = None
                if capabilities.aspect_ratio_sizes:
                    size_ratio = _ratio_for_size(request.width, request.height, capabilities.aspect_ratio_sizes)
                if size_ratio is None or size_ratio not in allowed_ratios:
                    ratio_label = size_ratio or aspect_ratio
                    raise ExternalProviderCapabilityError(
                        f"{model_name} does not support aspect ratio {ratio_label}"
                    )

        required_modes = capabilities.input_image_required_for or ["img2img", "inpaint"]
        if request.mode in required_modes and request.init_image is None:
            raise ExternalProviderCapabilityError(
                f"Mode '{request.mode}' requires an init image for {model_name}"
            )
        if request.mode == "inpaint" and request.mask_image is None:
            raise ExternalProviderCapabilityError(
                f"Mode '{request.mode}' requires a mask image for {model_name}"
            )

    def _refresh_model_capabilities(self, request: ExternalGenerationRequest) -> ExternalGenerationRequest:
        """Swap the request's model for the current stored record (with starter overrides applied).

        The request is returned unchanged when there is no record store, the lookup fails,
        the record is not the same external model, or nothing would change.
        """
        if self._record_store is None:
            return request

        try:
            record = self._record_store.get_model(request.model.key)
        except Exception:
            # Best effort: any lookup failure keeps the model supplied with the request.
            record = None
        if not isinstance(record, ExternalApiModelConfig):
            return request

        current = request.model
        is_same_model = (
            record.key == current.key
            and record.provider_id == current.provider_id
            and record.provider_model_id == current.provider_model_id
        )
        if not is_same_model:
            return request

        record = _apply_starter_overrides(record)
        if record == request.model:
            return request

        return ExternalGenerationRequest(
            model=record,
            mode=request.mode,
            prompt=request.prompt,
            seed=request.seed,
            num_images=request.num_images,
            width=request.width,
            height=request.height,
            image_size=request.image_size,
            init_image=request.init_image,
            mask_image=request.mask_image,
            reference_images=request.reference_images,
            metadata=request.metadata,
            provider_options=request.provider_options,
        )

    def _bucket_request(self, request: ExternalGenerationRequest) -> ExternalGenerationRequest:
        """Snap the requested size to a preset size of the model, when the model defines any.

        If the exact aspect ratio has a preset size, that size is used. Otherwise, when the
        ratio itself is not allowed, the size of the closest allowed ratio is used.
        """
        capabilities = request.model.capabilities
        allowed_ratios = capabilities.allowed_aspect_ratios
        if not allowed_ratios:
            return request

        preset_sizes = capabilities.aspect_ratio_sizes
        aspect_ratio = _format_aspect_ratio(request.width, request.height)

        exact_size = preset_sizes.get(aspect_ratio) if preset_sizes else None
        if exact_size is not None:
            already_matches = request.width == exact_size.width and request.height == exact_size.height
            if already_matches:
                return request
            return self._bucket_to_size(request, exact_size.width, exact_size.height, aspect_ratio)

        if aspect_ratio in allowed_ratios or not preset_sizes:
            return request

        closest = _select_closest_ratio(request.width, request.height, allowed_ratios)
        if closest is None:
            return request
        closest_size = preset_sizes.get(closest)
        if closest_size is None:
            return request
        return self._bucket_to_size(request, closest_size.width, closest_size.height, closest)

    def _bucket_to_size(
        self,
        request: ExternalGenerationRequest,
        width: int,
        height: int,
        ratio: str,
    ) -> ExternalGenerationRequest:
        """Return a copy of ``request`` at ``width`` x ``height``, resizing init and mask images to match."""
        self._logger.info(
            "Bucketing external request provider=%s model=%s %sx%s -> %sx%s (ratio %s)",
            request.model.provider_id,
            request.model.provider_model_id,
            request.width,
            request.height,
            width,
            height,
            ratio,
        )
        resized_init = _resize_image(request.init_image, width, height, "RGB")
        resized_mask = _resize_image(request.mask_image, width, height, "L")
        return ExternalGenerationRequest(
            model=request.model,
            mode=request.mode,
            prompt=request.prompt,
            seed=request.seed,
            num_images=request.num_images,
            width=width,
            height=height,
            image_size=request.image_size,
            init_image=resized_init,
            mask_image=resized_mask,
            reference_images=request.reference_images,
            metadata=request.metadata,
            provider_options=request.provider_options,
        )
def _format_aspect_ratio(width: int, height: int) -> str:
    """Return ``width:height`` reduced to lowest terms, e.g. 1920x1080 -> "16:9"."""
    common = _gcd(width, height)
    reduced_width, reduced_height = width // common, height // common
    return f"{reduced_width}:{reduced_height}"
def _select_closest_ratio(width: int, height: int, ratios: list[str]) -> str | None:
    """Pick the ratio label from ``ratios`` whose value is nearest to ``width / height``.

    Labels that cannot be parsed are ignored. Returns None when no label is usable.
    On a tie the earliest label in ``ratios`` wins.
    """
    target = width / height
    candidates: list[tuple[str, float]] = [
        (label, numeric) for label in ratios if (numeric := _parse_ratio(label)) is not None
    ]
    if not candidates:
        return None
    best_label, _ = min(candidates, key=lambda pair: abs(pair[1] - target))
    return best_label
def _ratio_for_size(width: int, height: int, sizes: dict[str, ExternalImageSize]) -> str | None:
for ratio, size in sizes.items():
if size.width == width and size.height == height:
return ratio
return None
def _parse_ratio(value: str) -> float | None:
if ":" not in value:
return None
left, right = value.split(":", 1)
try:
numerator = float(left)
denominator = float(right)
except ValueError:
return None
if denominator == 0:
return None
return numerator / denominator
def _gcd(a: int, b: int) -> int:
while b:
a, b = b, a % b
return a
def _resize_image(image: PILImageType | None, width: int, height: int, mode: str) -> PILImageType | None:
if image is None:
return None
if image.width == width and image.height == height:
return image
return image.convert(mode).resize((width, height), Image.Resampling.LANCZOS)
def _get_resize_target_for_inpaint(request: ExternalGenerationRequest) -> tuple[int, int] | None:
if request.mode != "inpaint" or request.init_image is None:
return None
return request.init_image.width, request.init_image.height
def _resize_result_images(result: ExternalGenerationResult, width: int, height: int) -> ExternalGenerationResult:
    """Return a copy of ``result`` whose images are all ``width`` x ``height``.

    Images that already have the target size are reused as-is; the rest are resized
    with Lanczos resampling. Seeds and the other result fields are carried over.
    """

    def _fit(image: PILImageType) -> PILImageType:
        if image.width == width and image.height == height:
            return image
        return image.resize((width, height), Image.Resampling.LANCZOS)

    resized: list[ExternalGeneratedImage] = []
    for generated in result.images:
        resized.append(ExternalGeneratedImage(image=_fit(generated.image), seed=generated.seed))

    return ExternalGenerationResult(
        images=resized,
        seed_used=result.seed_used,
        provider_request_id=result.provider_request_id,
        provider_metadata=result.provider_metadata,
        content_filters=result.content_filters,
    )
def _apply_starter_overrides(model: ExternalApiModelConfig) -> ExternalApiModelConfig:
    """Overlay capabilities and default settings from the matching starter model, if any.

    The starter model is matched by source, using ``external://<provider_id>/<provider_model_id>``
    when the model has no source of its own. Returns ``model`` itself when there is no match
    or the starter defines neither capabilities nor default settings.
    """
    source = model.source or f"external://{model.provider_id}/{model.provider_model_id}"

    starter_match = None
    for starter in STARTER_MODELS:
        if starter.source == source:
            starter_match = starter
            break
    if starter_match is None:
        return model

    overrides: dict[str, object] = {}
    starter_capabilities = starter_match.capabilities
    if starter_capabilities is not None:
        overrides["capabilities"] = starter_capabilities
    starter_defaults = starter_match.default_settings
    if starter_defaults is not None:
        overrides["default_settings"] = starter_defaults

    return model.model_copy(update=overrides) if overrides else model

View File

@@ -0,0 +1,19 @@
from __future__ import annotations
import base64
import io
from PIL import Image
from PIL.Image import Image as PILImageType
def encode_image_base64(image: PILImageType, format: str = "PNG") -> str:
    """Serialize ``image`` in the given format and return the bytes as an ASCII base64 string."""
    with io.BytesIO() as buffer:
        image.save(buffer, format=format)
        raw_bytes = buffer.getvalue()
    return base64.b64encode(raw_bytes).decode("ascii")
def decode_image_base64(encoded: str) -> PILImageType:
    """Decode a base64 string into a PIL image converted to RGB mode."""
    raw_bytes = base64.b64decode(encoded)
    return Image.open(io.BytesIO(raw_bytes)).convert("RGB")

View File

@@ -0,0 +1,5 @@
from invokeai.app.services.external_generation.providers.alibabacloud import AlibabaCloudProvider
from invokeai.app.services.external_generation.providers.gemini import GeminiProvider
from invokeai.app.services.external_generation.providers.openai import OpenAIProvider
# Provider implementations re-exported by this package.
__all__ = ["AlibabaCloudProvider", "GeminiProvider", "OpenAIProvider"]

View File

@@ -0,0 +1,309 @@
from __future__ import annotations
import io
import time
import requests
from PIL import Image
from PIL.Image import Image as PILImageType
from invokeai.app.services.external_generation.errors import ExternalProviderRequestError
from invokeai.app.services.external_generation.external_generation_base import ExternalProvider
from invokeai.app.services.external_generation.external_generation_common import (
ExternalGeneratedImage,
ExternalGenerationRequest,
ExternalGenerationResult,
)
from invokeai.app.services.external_generation.image_utils import decode_image_base64, encode_image_base64
# Models that support the synchronous multimodal-generation endpoint with messages format.
# Model ids listed in neither this set nor _ASYNC_MODELS are also sent to the synchronous
# endpoint by AlibabaCloudProvider.generate().
_SYNC_MODELS = {
    "qwen-image-2.0-pro",
    "qwen-image-2.0",
    "qwen-image-max",
    "wan2.6-t2i",
    "wan2.6-image",
    "qwen-image-edit-max",
}
# Models that use the async image-generation endpoint with flat prompt format
# (a task is submitted and then polled until it finishes).
_ASYNC_MODELS = {
    "qwen-image-plus",
    "qwen-image",
    "qwen-image-edit-plus",
    "qwen-image-edit",
    "wan2.5-t2i-preview",
    "wan2.2-t2i-flash",
    "wanx2.0-t2i-turbo",
}
# Models that support image editing (accept input images).
# Only these receive the request's init image on the synchronous endpoint.
_EDIT_MODELS = {
    "wan2.6-image",
    "qwen-image-edit-max",
    "qwen-image-edit-plus",
    "qwen-image-edit",
}
# Delay between two status polls of an async task.
_TASK_POLL_INTERVAL = 5 # seconds
# Maximum total time to wait for an async task before giving up.
_TASK_POLL_TIMEOUT = 300 # seconds
class AlibabaCloudProvider(ExternalProvider):
provider_id = "alibabacloud"
def is_configured(self) -> bool:
    """Return True when a non-empty Alibaba Cloud API key is present in the app config."""
    api_key = self._app_config.external_alibabacloud_api_key
    return bool(api_key)
def generate(self, request: ExternalGenerationRequest) -> ExternalGenerationResult:
    """Send ``request`` to DashScope via the sync or async endpoint, depending on the model.

    Models listed in ``_ASYNC_MODELS`` (and not in ``_SYNC_MODELS``) use the async task
    endpoint; every other model id, including unknown ones, uses the synchronous endpoint.

    Raises:
        ExternalProviderRequestError: If no API key is configured.
    """
    api_key = self._app_config.external_alibabacloud_api_key
    if not api_key:
        raise ExternalProviderRequestError("Alibaba Cloud DashScope API key is not configured")

    configured_base_url = self._app_config.external_alibabacloud_base_url
    base_url = (configured_base_url or "https://dashscope-intl.aliyuncs.com").rstrip("/")
    model_id = request.model.provider_model_id
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    # DashScope expects the size as "<width>*<height>".
    size = f"{request.width}*{request.height}"

    use_async_endpoint = model_id not in _SYNC_MODELS and model_id in _ASYNC_MODELS
    if use_async_endpoint:
        return self._generate_async(request, base_url, headers, model_id, size)
    return self._generate_sync(request, base_url, headers, model_id, size)
def _generate_sync(
    self,
    request: ExternalGenerationRequest,
    base_url: str,
    headers: dict[str, str],
    model_id: str,
    size: str,
) -> ExternalGenerationResult:
    """Use the synchronous multimodal-generation endpoint (messages format).

    Args:
        request: The generation request.
        base_url: DashScope base URL without a trailing slash.
        headers: HTTP headers including the bearer authorization.
        model_id: Provider model id placed in the payload.
        size: Image size formatted as ``"<width>*<height>"``.

    Returns:
        The parsed generation result.

    Raises:
        ExternalProviderRequestError: If DashScope answers with a non-2xx status.
    """
    endpoint = f"{base_url}/api/v1/services/aigc/multimodal-generation/generation"

    content: list[dict[str, str]] = []

    # Add init image for editing
    if request.init_image is not None and model_id in _EDIT_MODELS:
        content.append({"image": f"data:image/png;base64,{encode_image_base64(request.init_image)}"})

    # Add reference images
    for ref in request.reference_images:
        content.append({"image": f"data:image/png;base64,{encode_image_base64(ref.image)}"})

    content.append({"text": request.prompt})

    parameters: dict[str, object] = {
        "size": size,
        "n": request.num_images,
        "prompt_extend": False,
        "watermark": False,
    }

    # ExternalGenerationRequest declares no `negative_prompt` field, so reading the
    # attribute directly raises AttributeError on every call. Read it defensively so the
    # option is still forwarded if the request type ever provides it.
    negative_prompt = getattr(request, "negative_prompt", None)
    if negative_prompt:
        parameters["negative_prompt"] = negative_prompt

    if request.seed is not None:
        parameters["seed"] = request.seed

    payload: dict[str, object] = {
        "model": model_id,
        "input": {
            "messages": [
                {
                    "role": "user",
                    "content": content,
                }
            ]
        },
        "parameters": parameters,
    }

    response = requests.post(endpoint, headers=headers, json=payload, timeout=120)

    if not response.ok:
        raise ExternalProviderRequestError(
            f"DashScope request failed with status {response.status_code} for model '{model_id}': {response.text}"
        )

    data = response.json()
    request_id = data.get("request_id")

    return self._parse_sync_response(data, request, request_id)
def _generate_async(
    self,
    request: ExternalGenerationRequest,
    base_url: str,
    headers: dict[str, str],
    model_id: str,
    size: str,
) -> ExternalGenerationResult:
    """Use the async image-generation endpoint (flat prompt format) with task polling.

    Args:
        request: The generation request.
        base_url: DashScope base URL without a trailing slash.
        headers: HTTP headers including the bearer authorization.
        model_id: Provider model id placed in the payload.
        size: Image size formatted as ``"<width>*<height>"``.

    Returns:
        The result of the finished task.

    Raises:
        ExternalProviderRequestError: If the submission is rejected or the response
            carries no ``task_id``.
    """
    endpoint = f"{base_url}/api/v1/services/aigc/image-generation/generation"

    # This header asks DashScope to create a task instead of answering inline.
    async_headers = {**headers, "X-DashScope-Async": "enable"}

    parameters: dict[str, object] = {
        "size": size,
        "n": request.num_images,
        "prompt_extend": False,
        "watermark": False,
    }

    # ExternalGenerationRequest declares no `negative_prompt` field, so reading the
    # attribute directly raises AttributeError on every call. Read it defensively so the
    # option is still forwarded if the request type ever provides it.
    negative_prompt = getattr(request, "negative_prompt", None)
    if negative_prompt:
        parameters["negative_prompt"] = negative_prompt

    if request.seed is not None:
        parameters["seed"] = request.seed

    input_data: dict[str, object] = {"prompt": request.prompt}
    if negative_prompt:
        input_data["negative_prompt"] = negative_prompt

    payload: dict[str, object] = {
        "model": model_id,
        "input": input_data,
        "parameters": parameters,
    }

    response = requests.post(endpoint, headers=async_headers, json=payload, timeout=60)

    if not response.ok:
        raise ExternalProviderRequestError(
            f"DashScope async request failed with status {response.status_code} for model '{model_id}': {response.text}"
        )

    data = response.json()
    request_id = data.get("request_id")
    output = data.get("output", {})
    task_id = output.get("task_id")

    if not task_id:
        raise ExternalProviderRequestError(f"DashScope async response missing task_id: {data}")

    return self._poll_task(base_url, headers, task_id, request, request_id)
def _poll_task(
    self,
    base_url: str,
    headers: dict[str, str],
    task_id: str,
    request: ExternalGenerationRequest,
    request_id: str | None,
) -> ExternalGenerationResult:
    """Poll an async task until it reaches a terminal state or the poll timeout elapses.

    Args:
        base_url: DashScope base URL without a trailing path.
        headers: Request headers; only ``Authorization`` is forwarded to the task endpoint.
        task_id: Identifier of the task to poll.
        request: The original generation request, passed through to the response parser.
        request_id: Provider request id from the submit call, passed through to the parser.

    Returns:
        The parsed result once the task reports ``SUCCEEDED``.

    Raises:
        ExternalProviderRequestError: On poll HTTP failure, on a ``FAILED``, ``UNKNOWN`` or
            ``CANCELED`` task status, or when ``_TASK_POLL_TIMEOUT`` seconds have elapsed.
    """
    task_url = f"{base_url}/api/v1/tasks/{task_id}"
    poll_headers = {"Authorization": headers["Authorization"]}
    start_time = time.monotonic()
    while True:
        elapsed = time.monotonic() - start_time
        if elapsed > _TASK_POLL_TIMEOUT:
            raise ExternalProviderRequestError(
                f"DashScope task {task_id} timed out after {_TASK_POLL_TIMEOUT}s"
            )
        # Sleep before every poll, including the first one, to give the task time to start.
        time.sleep(_TASK_POLL_INTERVAL)
        response = requests.get(task_url, headers=poll_headers, timeout=30)
        if not response.ok:
            raise ExternalProviderRequestError(
                f"DashScope task poll failed with status {response.status_code}: {response.text}"
            )
        data = response.json()
        output = data.get("output") if isinstance(data, dict) else None
        if not isinstance(output, dict):
            # A missing/null output carries no status; keep polling until the timeout.
            output = {}
        status = output.get("task_status")
        if status == "SUCCEEDED":
            return self._parse_async_response(output, request, request_id)
        if status in ("FAILED", "UNKNOWN"):
            message = output.get("message", "Unknown error")
            raise ExternalProviderRequestError(f"DashScope task {task_id} failed: {message}")
        if status == "CANCELED":
            # A canceled task never succeeds; fail now instead of polling until the timeout.
            raise ExternalProviderRequestError(f"DashScope task {task_id} was canceled")
        self._logger.debug("DashScope task %s status: %s (%.0fs elapsed)", task_id, status, elapsed)
def _parse_sync_response(
    self,
    data: dict[str, object],
    request: ExternalGenerationRequest,
    request_id: str | None,
) -> ExternalGenerationResult:
    """Turn a synchronous multimodal-generation response into a generation result.

    Every non-empty string found under ``output.choices[*].message.content[*].image`` is
    downloaded, in response order. Malformed choices, messages and parts are skipped.

    Raises:
        ExternalProviderRequestError: If ``output`` or ``choices`` is missing or of the wrong
            type, or if no image could be collected.
    """
    output = data.get("output")
    if not isinstance(output, dict):
        raise ExternalProviderRequestError(f"DashScope response missing output: {data}")
    choices = output.get("choices")
    if not isinstance(choices, list):
        raise ExternalProviderRequestError(f"DashScope response missing choices: {data}")

    # First gather the image URLs in order, then download them.
    image_urls: list[str] = []
    for choice in choices:
        message = choice.get("message") if isinstance(choice, dict) else None
        content = message.get("content") if isinstance(message, dict) else None
        if not isinstance(content, list):
            continue
        for part in content:
            candidate_url = part.get("image") if isinstance(part, dict) else None
            if isinstance(candidate_url, str) and candidate_url:
                image_urls.append(candidate_url)

    images: list[ExternalGeneratedImage] = [
        ExternalGeneratedImage(image=self._download_image(image_url), seed=request.seed)
        for image_url in image_urls
    ]
    if not images:
        raise ExternalProviderRequestError(f"DashScope response contained no images: {data}")

    return ExternalGenerationResult(
        images=images,
        seed_used=request.seed,
        provider_request_id=request_id,
        provider_metadata={"model": request.model.provider_model_id},
    )
def _parse_async_response(
    self,
    output: dict[str, object],
    request: ExternalGenerationRequest,
    request_id: str | None,
) -> ExternalGenerationResult:
    """Parse the ``output`` object of a completed async task into a generation result.

    Each entry of ``output["results"]`` contributes at most one image: the ``url`` is
    downloaded when present, otherwise ``b64_image`` is decoded. Non-dict entries are skipped.

    Raises:
        ExternalProviderRequestError: If ``results`` is not a list or yields no image.
    """
    results = output.get("results")
    if not isinstance(results, list):
        raise ExternalProviderRequestError(f"DashScope async response missing results: {output}")

    images: list[ExternalGeneratedImage] = []
    for result in results:
        if not isinstance(result, dict):
            continue
        url = result.get("url")
        b64_image = result.get("b64_image")
        if isinstance(url, str) and url:
            pil_image = self._download_image(url)
        elif isinstance(b64_image, str) and b64_image:
            # Only used as a fallback so that an entry carrying both encodings of the same
            # image is not returned twice.
            pil_image = decode_image_base64(b64_image)
        else:
            continue
        images.append(ExternalGeneratedImage(image=pil_image, seed=request.seed))

    if not images:
        raise ExternalProviderRequestError(f"DashScope async response contained no images: {output}")
    return ExternalGenerationResult(
        images=images,
        seed_used=request.seed,
        provider_request_id=request_id,
        provider_metadata={"model": request.model.provider_model_id},
    )
def _download_image(self, url: str) -> PILImageType:
    """Fetch ``url`` with a 60 second timeout and return the payload as an RGB PIL image.

    Raises:
        ExternalProviderRequestError: If the HTTP response status is not OK.
    """
    response = requests.get(url, timeout=60)
    if response.ok:
        raw_bytes = io.BytesIO(response.content)
        return Image.open(raw_bytes).convert("RGB")
    raise ExternalProviderRequestError(
        f"Failed to download image from DashScope (status {response.status_code})"
    )

View File

@@ -0,0 +1,282 @@
from __future__ import annotations
import json
import uuid
import requests
from PIL.Image import Image as PILImageType
from invokeai.app.services.external_generation.errors import (
ExternalProviderRateLimitError,
ExternalProviderRequestError,
)
from invokeai.app.services.external_generation.external_generation_base import ExternalProvider
from invokeai.app.services.external_generation.external_generation_common import (
ExternalGeneratedImage,
ExternalGenerationRequest,
ExternalGenerationResult,
)
from invokeai.app.services.external_generation.image_utils import decode_image_base64, encode_image_base64
class GeminiProvider(ExternalProvider):
    """External image-generation provider that calls the Gemini ``generateContent`` REST endpoint."""

    provider_id = "gemini"
    _SYSTEM_INSTRUCTION = (
        "You are an image generation model. Always respond with an image based on the user's prompt. "
        "Do not return text-only responses. If the user input is not an edit instruction, "
        "interpret it as a request to create a new image."
    )

    def is_configured(self) -> bool:
        """Return True when a Gemini API key is set in the app config."""
        return bool(self._app_config.external_gemini_api_key)

    def generate(self, request: ExternalGenerationRequest) -> ExternalGenerationResult:
        """Send one ``generateContent`` request and decode the inline images it returns.

        The init image (if any), the prompt and all reference images are sent as parts of a
        single user message. The aspect ratio is passed through ``imageConfig`` for models that
        declare resolution presets, and through the system instruction otherwise.

        Raises:
            ExternalProviderRateLimitError: On HTTP 429; carries the parsed ``retry-after`` value.
            ExternalProviderRequestError: If the API key is missing, the request fails, the
                response is malformed, a ``fileUri`` is returned instead of inline data, or the
                response contains no images.
        """
        api_key = self._app_config.external_gemini_api_key
        if not api_key:
            raise ExternalProviderRequestError("Gemini API key is not configured")

        base_url = (self._app_config.external_gemini_base_url or "https://generativelanguage.googleapis.com").rstrip(
            "/"
        )
        # Append a default API version unless the configured base URL already ends with one.
        if not base_url.endswith("/v1") and not base_url.endswith("/v1beta"):
            base_url = f"{base_url}/v1beta"
        model_id = request.model.provider_model_id.removeprefix("models/")
        endpoint = f"{base_url}/models/{model_id}:generateContent"

        # Part order: init image, prompt text, then reference images.
        request_parts: list[dict[str, object]] = []
        if request.init_image is not None:
            request_parts.append(
                {
                    "inlineData": {
                        "mimeType": "image/png",
                        "data": encode_image_base64(request.init_image),
                    }
                }
            )
        request_parts.append({"text": request.prompt})
        for reference in request.reference_images:
            request_parts.append(
                {
                    "inlineData": {
                        "mimeType": "image/png",
                        "data": encode_image_base64(reference.image),
                    }
                }
            )

        opts = request.provider_options or {}
        generation_config: dict[str, object] = {
            "candidateCount": request.num_images,
            "responseModalities": ["IMAGE"],
        }
        if "temperature" in opts:
            generation_config["temperature"] = opts["temperature"]
        # thinkingConfig is a field of generationConfig in the REST schema; sending it at the
        # request root is rejected as an unknown field. Non-string values are ignored.
        thinking_level = opts.get("thinking_level")
        if isinstance(thinking_level, str) and thinking_level:
            generation_config["thinkingConfig"] = {"thinkingLevel": thinking_level.upper()}

        aspect_ratio = _select_aspect_ratio(
            request.width,
            request.height,
            request.model.capabilities.allowed_aspect_ratios,
        )
        uses_image_config = request.model.capabilities.resolution_presets is not None
        if uses_image_config:
            image_config: dict[str, str] = {}
            if aspect_ratio is not None:
                image_config["aspectRatio"] = aspect_ratio
            if request.image_size is not None:
                image_config["imageSize"] = request.image_size
            if image_config:
                generation_config["imageConfig"] = image_config

        system_instruction = self._SYSTEM_INSTRUCTION
        if request.init_image is not None:
            system_instruction = (
                f"{system_instruction} An input image is provided. "
                "Treat the prompt as an edit instruction and modify the image accordingly. "
                "Do not return the original image unchanged."
            )
        # Without imageConfig the aspect ratio is requested through the system instruction.
        if not uses_image_config and aspect_ratio is not None:
            system_instruction = f"{system_instruction} Use an aspect ratio of {aspect_ratio}."

        payload: dict[str, object] = {
            "systemInstruction": {"parts": [{"text": system_instruction}]},
            "contents": [{"role": "user", "parts": request_parts}],
            "generationConfig": generation_config,
        }

        self._dump_debug_payload("request", payload)
        response = requests.post(
            endpoint,
            params={"key": api_key},
            json=payload,
            timeout=120,
        )
        if not response.ok:
            if response.status_code == 429:
                retry_after = _parse_retry_after(response.headers.get("retry-after"))
                raise ExternalProviderRateLimitError(
                    f"Gemini rate limit exceeded. {f'Retry after {retry_after:.0f}s.' if retry_after else 'Please try again later.'}",
                    retry_after=retry_after,
                )
            raise ExternalProviderRequestError(
                f"Gemini request failed with status {response.status_code} for model '{model_id}': {response.text}"
            )

        data = response.json()
        self._dump_debug_payload("response", data)
        if not isinstance(data, dict):
            raise ExternalProviderRequestError("Gemini response payload was not a JSON object")

        images: list[ExternalGeneratedImage] = []
        text_parts: list[str] = []
        finish_messages: list[str] = []
        candidates = data.get("candidates")
        if not isinstance(candidates, list):
            raise ExternalProviderRequestError("Gemini response payload missing candidates")

        for candidate in candidates:
            if not isinstance(candidate, dict):
                continue
            # Prefer the human-readable finish message; fall back to the bare finish reason.
            finish_message = candidate.get("finishMessage")
            finish_reason = candidate.get("finishReason")
            if isinstance(finish_message, str):
                finish_messages.append(finish_message)
            elif isinstance(finish_reason, str):
                finish_messages.append(f"Finish reason: {finish_reason}")

            for part in _iter_response_parts(candidate):
                # Both snake_case and camelCase spellings of the keys are accepted.
                inline_data = part.get("inline_data") or part.get("inlineData")
                if isinstance(inline_data, dict):
                    encoded = inline_data.get("data")
                    if encoded:
                        image = decode_image_base64(encoded)
                        images.append(ExternalGeneratedImage(image=image, seed=request.seed))
                        self._dump_debug_image(image)
                        continue
                file_data = part.get("fileData") or part.get("file_data")
                if isinstance(file_data, dict):
                    file_uri = file_data.get("fileUri") or file_data.get("file_uri")
                    if isinstance(file_uri, str) and file_uri:
                        raise ExternalProviderRequestError(
                            f"Gemini returned fileUri instead of inline image data: {file_uri}"
                        )
                text = part.get("text")
                if isinstance(text, str):
                    text_parts.append(text)

        if not images:
            self._logger.error("Gemini response contained no images: %s", data)
            # Add a truncated hint (finish messages first, else returned text) to the error.
            detail = ""
            if finish_messages:
                combined = " ".join(message.strip() for message in finish_messages if message.strip())
                if combined:
                    detail = f" Response status: {combined[:500]}"
            elif text_parts:
                combined = " ".join(text_parts).strip()
                if combined:
                    detail = f" Response text: {combined[:500]}"
            raise ExternalProviderRequestError(f"Gemini response contained no images.{detail}")

        return ExternalGenerationResult(
            images=images,
            seed_used=request.seed,
            provider_metadata={"model": request.model.provider_model_id},
        )

    def _dump_debug_payload(self, label: str, payload: object) -> None:
        """Best-effort dump of a request/response payload as JSON under the outputs directory.

        TODO: remove debug payload dump once Gemini is stable.
        """
        try:
            outputs_path = self._app_config.outputs_path
            if outputs_path is None:
                return
            debug_dir = outputs_path / "external_debug" / "gemini"
            debug_dir.mkdir(parents=True, exist_ok=True)
            path = debug_dir / f"{label}_{uuid.uuid4().hex}.json"
            path.write_text(json.dumps(payload, indent=2, default=str), encoding="utf-8")
        except Exception as exc:
            # Debug output must never break generation; log and carry on.
            self._logger.debug("Failed to write Gemini debug payload: %s", exc)

    def _dump_debug_image(self, image: "PILImageType") -> None:
        """Best-effort dump of a decoded image as PNG under the outputs directory.

        TODO: remove debug image dump once Gemini is stable.
        """
        try:
            outputs_path = self._app_config.outputs_path
            if outputs_path is None:
                return
            debug_dir = outputs_path / "external_debug" / "gemini"
            debug_dir.mkdir(parents=True, exist_ok=True)
            path = debug_dir / f"decoded_{uuid.uuid4().hex}.png"
            image.save(path, format="PNG")
        except Exception as exc:
            # Debug output must never break generation; log and carry on.
            self._logger.debug("Failed to write Gemini debug image: %s", exc)
def _iter_response_parts(candidate: dict[str, object]) -> list[dict[str, object]]:
content = candidate.get("content")
if isinstance(content, dict):
content_parts = content.get("parts")
if isinstance(content_parts, list):
return [part for part in content_parts if isinstance(part, dict)]
contents = candidate.get("contents")
if isinstance(contents, list):
parts: list[dict[str, object]] = []
for item in contents:
if not isinstance(item, dict):
continue
item_parts = item.get("parts")
if isinstance(item_parts, list):
parts.extend([part for part in item_parts if isinstance(part, dict)])
if parts:
return parts
return []
def _select_aspect_ratio(width: int, height: int, allowed: list[str] | None) -> str | None:
    """Pick the aspect-ratio string to request for a ``width`` x ``height`` image.

    Returns None for non-positive dimensions. Without an ``allowed`` list (or when none of its
    entries parses) the reduced ``W:H`` form of the dimensions is returned; otherwise the
    allowed entry whose numeric ratio is closest to ``width / height`` wins, earliest entry
    first on ties.
    """
    if width <= 0 or height <= 0:
        return None

    fallback = _format_aspect_ratio(width, height)
    if not allowed:
        return fallback

    target = width / height
    best_label: str | None = None
    best_distance: float | None = None
    for label in allowed:
        numeric = _parse_ratio(label)
        if numeric is None:
            continue
        distance = abs(numeric - target)
        # Strict comparison keeps the earliest entry when two are equally close.
        if best_distance is None or distance < best_distance:
            best_label = label
            best_distance = distance

    if best_distance is None:
        return fallback
    return best_label
def _format_aspect_ratio(width: int, height: int) -> str | None:
if width <= 0 or height <= 0:
return None
divisor = _gcd(width, height)
return f"{width // divisor}:{height // divisor}"
def _parse_ratio(value: str) -> float | None:
if ":" not in value:
return None
left, right = value.split(":", 1)
try:
numerator = float(left)
denominator = float(right)
except ValueError:
return None
if denominator == 0:
return None
return numerator / denominator
def _parse_retry_after(value: str | None) -> float | None:
if not value:
return None
try:
return float(value)
except ValueError:
return None
def _gcd(a: int, b: int) -> int:
while b:
a, b = b, a % b
return a

View File

@@ -0,0 +1,159 @@
from __future__ import annotations
import io
import requests
from PIL.Image import Image as PILImageType
from invokeai.app.services.external_generation.errors import (
ExternalProviderRateLimitError,
ExternalProviderRequestError,
)
from invokeai.app.services.external_generation.external_generation_base import ExternalProvider
from invokeai.app.services.external_generation.external_generation_common import (
ExternalGeneratedImage,
ExternalGenerationRequest,
ExternalGenerationResult,
)
from invokeai.app.services.external_generation.image_utils import decode_image_base64
class OpenAIProvider(ExternalProvider):
    """External image-generation provider that calls the OpenAI images REST endpoints."""

    provider_id = "openai"
    _GPT_IMAGE_MODELS = {"gpt-image-1", "gpt-image-1.5", "gpt-image-1-mini"}

    def is_configured(self) -> bool:
        """Return True when an OpenAI API key is set in the app config."""
        return bool(self._app_config.external_openai_api_key)

    def generate(self, request: ExternalGenerationRequest) -> ExternalGenerationResult:
        """Generate or edit images through ``/v1/images/generations`` or ``/v1/images/edits``.

        The edits endpoint (multipart upload) is used whenever the mode is not ``txt2img`` or
        reference images are supplied; otherwise the generations endpoint receives a JSON body.

        Raises:
            ExternalProviderRateLimitError: On HTTP 429; carries the parsed ``retry-after`` value.
            ExternalProviderRequestError: If the API key is missing, an edit has no input image,
                the request fails, or the response is malformed or contains no images.
        """
        api_key = self._app_config.external_openai_api_key
        if not api_key:
            raise ExternalProviderRequestError("OpenAI API key is not configured")

        model_id = request.model.provider_model_id
        size = f"{request.width}x{request.height}"
        base_url = (self._app_config.external_openai_base_url or "https://api.openai.com").rstrip("/")
        headers = {"Authorization": f"Bearer {api_key}"}
        use_edits_endpoint = request.mode != "txt2img" or bool(request.reference_images)

        if not use_edits_endpoint:
            payload = self._build_request_fields(request, model_id, size, include_input_fidelity=False)
            response = requests.post(
                f"{base_url}/v1/images/generations",
                headers=headers,
                json=payload,
                timeout=120,
            )
        else:
            # Build (and validate) the uploads first so a missing image fails before any request.
            files = self._build_edit_files(request)
            form_fields = self._build_request_fields(request, model_id, size, include_input_fidelity=True)
            response = requests.post(
                f"{base_url}/v1/images/edits",
                headers=headers,
                data=form_fields,
                files=files,
                timeout=120,
            )

        if not response.ok:
            if response.status_code == 429:
                retry_after = _parse_retry_after(response.headers.get("retry-after"))
                raise ExternalProviderRateLimitError(
                    f"OpenAI rate limit exceeded. {f'Retry after {retry_after:.0f}s.' if retry_after else 'Please try again later.'}",
                    retry_after=retry_after,
                )
            raise ExternalProviderRequestError(
                f"OpenAI request failed with status {response.status_code}: {response.text}"
            )

        response_payload = response.json()
        if not isinstance(response_payload, dict):
            raise ExternalProviderRequestError("OpenAI response payload was not a JSON object")
        data_items = response_payload.get("data")
        if not isinstance(data_items, list):
            raise ExternalProviderRequestError("OpenAI response payload missing image data")

        generated: list[ExternalGeneratedImage] = []
        for item in data_items:
            if not isinstance(item, dict):
                continue
            encoded = item.get("b64_json")
            if not encoded:
                continue
            generated.append(ExternalGeneratedImage(image=decode_image_base64(encoded), seed=request.seed))
        if not generated:
            raise ExternalProviderRequestError("OpenAI response contained no images")

        return ExternalGenerationResult(
            images=generated,
            seed_used=request.seed,
            provider_request_id=response.headers.get("x-request-id"),
            provider_metadata={"model": model_id},
        )

    def _build_request_fields(
        self,
        request: ExternalGenerationRequest,
        model_id: str,
        size: str,
        *,
        include_input_fidelity: bool,
    ) -> dict[str, object]:
        """Build the request fields shared by the generations and edits endpoints."""
        opts = request.provider_options or {}
        fields: dict[str, object] = {
            "model": model_id,
            "prompt": request.prompt,
            "n": request.num_images,
            "size": size,
        }
        if model_id not in self._GPT_IMAGE_MODELS:
            # GPT Image models use output_format; DALL-E uses response_format
            fields["response_format"] = "b64_json"
            return fields

        fields["output_format"] = "png"
        # "auto" is skipped so the provider default applies.
        if opts.get("quality") and opts["quality"] != "auto":
            fields["quality"] = opts["quality"]
        if opts.get("background") and opts["background"] != "auto":
            fields["background"] = opts["background"]
        if include_input_fidelity and opts.get("input_fidelity"):
            fields["input_fidelity"] = opts["input_fidelity"]
        return fields

    def _build_edit_files(
        self, request: ExternalGenerationRequest
    ) -> list[tuple[str, tuple[str, io.BytesIO, str]]]:
        """Encode the init, reference and mask images as PNG multipart uploads for an edit."""
        input_images: list[PILImageType] = []
        if request.init_image is not None:
            input_images.append(request.init_image)
        input_images.extend(reference.image for reference in request.reference_images)
        if not input_images:
            raise ExternalProviderRequestError(
                "OpenAI image edits require at least one image (init image or reference image)"
            )

        files: list[tuple[str, tuple[str, io.BytesIO, str]]] = []
        # A single image goes in the "image" field; several images use the "image[]" array field.
        image_field_name = "image" if len(input_images) == 1 else "image[]"
        for index, image in enumerate(input_images):
            image_buffer = io.BytesIO()
            image.save(image_buffer, format="PNG")
            image_buffer.seek(0)
            files.append((image_field_name, (f"image_{index}.png", image_buffer, "image/png")))

        if request.mask_image is not None:
            mask_buffer = io.BytesIO()
            request.mask_image.save(mask_buffer, format="PNG")
            mask_buffer.seek(0)
            files.append(("mask", ("mask.png", mask_buffer, "image/png")))
        return files
def _parse_retry_after(value: str | None) -> float | None:
if not value:
return None
try:
return float(value)
except ValueError:
return None

View File

@@ -0,0 +1,59 @@
from logging import Logger
from typing import TYPE_CHECKING
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.model_manager.configs.external_api import ExternalApiModelConfig
from invokeai.backend.model_manager.starter_models import STARTER_MODELS
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
if TYPE_CHECKING:
from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
def sync_configured_external_starter_models(
    configured_provider_ids: set[str],
    model_manager: "ModelManagerServiceBase",
    logger: Logger,
) -> list[str]:
    """Queue installs for external starter models that belong to configured providers.

    A starter model is queued when its source starts with ``external://``, its provider id is
    in ``configured_provider_ids``, and no installed external model record already has that
    source.

    Returns:
        The sources that were queued, in starter-model order.
    """
    queued_sources: list[str] = []
    if not configured_provider_ids:
        return queued_sources

    existing_records = model_manager.store.search_by_attr(
        base_model=BaseModelType.External,
        model_type=ModelType.ExternalImageGenerator,
    )
    installed_sources = set()
    for record in existing_records:
        if isinstance(record, ExternalApiModelConfig) and record.source:
            installed_sources.add(record.source)

    for starter in STARTER_MODELS:
        source = starter.source
        if not source.startswith("external://"):
            continue
        # The provider id is the first path segment after the scheme.
        provider_id = source.removeprefix("external://").split("/", 1)[0]
        if provider_id not in configured_provider_ids or source in installed_sources:
            continue

        changes = ModelRecordChanges(
            name=starter.name,
            base=starter.base,
            type=starter.type,
            description=starter.description,
            format=starter.format,
            capabilities=starter.capabilities,
            default_settings=starter.default_settings,
        )
        model_manager.install.heuristic_import(source, config=changes)
        queued_sources.append(source)
        logger.info("Queued external starter model sync for %s", source)

    return queued_sources

View File

@@ -21,6 +21,7 @@ if TYPE_CHECKING:
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.external_generation.external_generation_base import ExternalGenerationServiceBase
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.image_records.image_records_base import ImageRecordStorageBase
from invokeai.app.services.images.images_base import ImageServiceABC
@@ -63,6 +64,7 @@ class InvocationServices:
model_relationships: "ModelRelationshipsServiceABC",
model_relationship_records: "ModelRelationshipRecordStorageBase",
download_queue: "DownloadQueueServiceBase",
external_generation: "ExternalGenerationServiceBase",
performance_statistics: "InvocationStatsServiceBase",
session_queue: "SessionQueueBase",
session_processor: "SessionProcessorBase",
@@ -94,6 +96,7 @@ class InvocationServices:
self.model_relationships = model_relationships
self.model_relationship_records = model_relationship_records
self.download_queue = download_queue
self.external_generation = external_generation
self.performance_statistics = performance_statistics
self.session_queue = session_queue
self.session_processor = session_processor

View File

@@ -139,12 +139,27 @@ class URLModelSource(StringLikeSource):
return str(self.url)
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
class ExternalModelSource(StringLikeSource):
    """Model source that names a model hosted by an external provider.

    Its string form is ``external://<provider_id>/<provider_model_id>``.
    """

    provider_id: str
    provider_model_id: str
    type: Literal["external"] = "external"

    def __str__(self) -> str:
        return "external://{}/{}".format(self.provider_id, self.provider_model_id)
ModelSource = Annotated[
Union[LocalModelSource, HFModelSource, URLModelSource, ExternalModelSource],
Field(discriminator="type"),
]
MODEL_SOURCE_TO_TYPE_MAP = {
URLModelSource: ModelSourceType.Url,
HFModelSource: ModelSourceType.HFRepoID,
LocalModelSource: ModelSourceType.Path,
ExternalModelSource: ModelSourceType.External,
}

View File

@@ -28,6 +28,7 @@ from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
from invokeai.app.services.model_install.model_install_common import (
MODEL_SOURCE_TO_TYPE_MAP,
ExternalModelSource,
HFModelSource,
InstallStatus,
InvalidModelConfigException,
@@ -37,10 +38,15 @@ from invokeai.app.services.model_install.model_install_common import (
StringLikeSource,
URLModelSource,
)
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, UnknownModelException
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.app.util.misc import get_iso_timestamp
from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base
from invokeai.backend.model_manager.configs.external_api import (
ExternalApiModelConfig,
ExternalApiModelDefaultSettings,
ExternalModelCapabilities,
)
from invokeai.backend.model_manager.configs.factory import (
AnyModelConfig,
ModelConfigFactory,
@@ -55,7 +61,13 @@ from invokeai.backend.model_manager.metadata import (
)
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant, ModelSourceType
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelFormat,
ModelRepoVariant,
ModelSourceType,
ModelType,
)
from invokeai.backend.model_manager.util.lora_metadata_extractor import apply_lora_metadata
from invokeai.backend.util import InvokeAILogger
from invokeai.backend.util.catch_sigint import catch_sigint
@@ -459,6 +471,9 @@ class ModelInstallService(ModelInstallServiceBase):
install_job = self._import_from_hf(source, config)
elif isinstance(source, URLModelSource):
install_job = self._import_from_url(source, config)
elif isinstance(source, ExternalModelSource):
install_job = self._import_external_model(source, config)
self._put_in_queue(install_job)
else:
raise ValueError(f"Unsupported model source: '{type(source)}'")
@@ -758,7 +773,13 @@ class ModelInstallService(ModelInstallServiceBase):
source_obj: Optional[StringLikeSource] = None
source_stripped = source.strip('"')
if Path(source_stripped).exists(): # A local file or directory
if source_stripped.startswith("external://"):
external_id = source_stripped.removeprefix("external://")
provider_id, _, provider_model_id = external_id.partition("/")
if not provider_id or not provider_model_id:
raise ValueError(f"Invalid external model source: '{source_stripped}'")
source_obj = ExternalModelSource(provider_id=provider_id, provider_model_id=provider_model_id)
elif Path(source_stripped).exists(): # A local file or directory
source_obj = LocalModelSource(path=Path(source_stripped))
elif match := re.match(hf_repoid_re, source):
source_obj = HFModelSource(
@@ -850,6 +871,9 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.info(f"Installer thread {threading.get_ident()} exiting")
def _register_or_install(self, job: ModelInstallJob) -> None:
if isinstance(job.source, ExternalModelSource):
self._register_external_model(job)
return
# local jobs will be in waiting state, remote jobs will be downloading state
job.total_bytes = self._stat_size(job.local_path)
job.bytes = job.total_bytes
@@ -870,6 +894,71 @@ class ModelInstallService(ModelInstallServiceBase):
job.config_out = self.record_store.get_model(key)
self._signal_job_completed(job)
def _register_external_model(self, job: ModelInstallJob) -> None:
    """Create or replace the model record for an external (API-backed) model install job.

    No bytes are transferred. If a record with the same provider id and provider model id
    already exists it is replaced under its existing key; otherwise a new record is added.

    Raises:
        DuplicateModelException: If no matching external record exists but the chosen key is
            already used by another model.
    """
    job.total_bytes = 0
    job.bytes = 0
    self._signal_job_running(job)

    source_str = str(job.source)
    source_type = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
    job.config_in.source = source_str
    job.config_in.source_type = source_type

    provider_id = job.source.provider_id
    provider_model_id = job.source.provider_model_id
    capabilities = job.config_in.capabilities or ExternalModelCapabilities()
    # Default settings of any other model kind are discarded.
    default_settings = None
    if isinstance(job.config_in.default_settings, ExternalApiModelDefaultSettings):
        default_settings = job.config_in.default_settings
    name = job.config_in.name or f"{provider_id} {provider_model_id}"
    key = job.config_in.key or slugify(f"{provider_id}-{provider_model_id}")

    # Find the first existing record for this exact provider/model pair.
    existing_external = None
    for candidate in self.record_store.search_by_attr(
        base_model=BaseModelType.External, model_type=ModelType.ExternalImageGenerator
    ):
        if (
            isinstance(candidate, ExternalApiModelConfig)
            and candidate.provider_id == provider_id
            and candidate.provider_model_id == provider_model_id
        ):
            existing_external = candidate
            break

    if existing_external is not None:
        key = existing_external.key
    else:
        # A new record needs a free key; a successful lookup means the key is taken.
        try:
            self.record_store.get_model(key)
        except UnknownModelException:
            pass
        else:
            raise DuplicateModelException(
                f"Model key '{key}' already exists. Provide a different key to install this external model."
            )

    config = ExternalApiModelConfig(
        key=key,
        name=name,
        description=job.config_in.description,
        provider_id=provider_id,
        provider_model_id=provider_model_id,
        capabilities=capabilities,
        default_settings=default_settings,
        source=source_str,
        source_type=source_type,
        path="",
        hash="",
        file_size=0,
    )

    if existing_external is None:
        self.record_store.add_model(config)
    else:
        self.record_store.replace_model(existing_external.key, config)

    job.config_out = self.record_store.get_model(config.key)
    self._signal_job_completed(job)
def _set_error(self, install_job: ModelInstallJob, excp: Exception) -> None:
multifile_download_job = install_job._multifile_job
if multifile_download_job and any(
@@ -905,6 +994,8 @@ class ModelInstallService(ModelInstallServiceBase):
"""Scan the models directory for missing models and return a list of them."""
missing_models: list[AnyModelConfig] = []
for model_config in self.record_store.all_models():
if model_config.base == BaseModelType.External or model_config.format == ModelFormat.ExternalApi:
continue
if not (self.app_config.models_path / model_config.path).resolve().exists():
missing_models.append(model_config)
return missing_models
@@ -1046,6 +1137,19 @@ class ModelInstallService(ModelInstallServiceBase):
remote_files=remote_files,
)
def _import_external_model(
    self,
    source: ExternalModelSource,
    config: Optional[ModelRecordChanges] = None,
) -> ModelInstallJob:
    """Build the install job for an external model source.

    The job points at the models directory with ``inplace=True``; empty record changes are used
    when ``config`` is not supplied.
    """
    job_id = self._next_id()
    record_changes = config or ModelRecordChanges()
    models_dir = self._app_config.models_path
    return ModelInstallJob(
        id=job_id,
        source=source,
        config_in=record_changes,
        local_path=models_dir,
        inplace=True,
    )
def _import_remote_model(
self,
source: HFModelSource | URLModelSource,

View File

@@ -13,6 +13,10 @@ from pydantic import BaseModel, Field
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings
from invokeai.backend.model_manager.configs.external_api import (
ExternalApiModelDefaultSettings,
ExternalModelCapabilities,
)
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.configs.lora import LoraModelDefaultSettings
from invokeai.backend.model_manager.configs.main import MainModelDefaultSettings
@@ -87,8 +91,19 @@ class ModelRecordChanges(BaseModelExcludeNull):
file_size: Optional[int] = Field(description="Size of model file", default=None)
format: Optional[str] = Field(description="format of model file", default=None)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[MainModelDefaultSettings | LoraModelDefaultSettings | ControlAdapterDefaultSettings] = (
Field(description="Default settings for this model", default=None)
default_settings: Optional[
MainModelDefaultSettings
| LoraModelDefaultSettings
| ControlAdapterDefaultSettings
| ExternalApiModelDefaultSettings
] = Field(description="Default settings for this model", default=None)
# External API model changes
provider_id: Optional[str] = Field(description="External provider identifier", default=None)
provider_model_id: Optional[str] = Field(description="External provider model identifier", default=None)
capabilities: Optional[ExternalModelCapabilities] = Field(
description="External model capabilities",
default=None,
)
cpu_only: Optional[bool] = Field(description="Whether this model should run on CPU only", default=None)

View File

@@ -388,6 +388,8 @@ class ModelsInterface(InvocationContextInterface):
submodel_type = submodel_type or identifier.submodel_type
model = self._services.model_manager.store.get_model(identifier.key)
self._raise_if_external(model)
message = f"Loading model {model.name}"
if submodel_type:
message += f" ({submodel_type.value})"
@@ -417,12 +419,18 @@ class ModelsInterface(InvocationContextInterface):
if len(configs) > 1:
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
self._raise_if_external(configs[0])
message = f"Loading model {name}"
if submodel_type:
message += f" ({submodel_type.value})"
self._util.signal_progress(message)
return self._services.model_manager.load.load_model(configs[0], submodel_type)
@staticmethod
def _raise_if_external(model: AnyModelConfig) -> None:
    """Reject external API models, identified by an external base or an external-API format.

    Raises:
        ValueError: If ``model`` is an external API model.
    """
    is_external = model.base == BaseModelType.External or model.format == ModelFormat.ExternalApi
    if not is_external:
        return
    raise ValueError("External API models cannot be loaded from disk")
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
"""Get a model's config.

View File

@@ -14,43 +14,61 @@ def _get_processor_invocation_class(processor_type: str):
"""Get the invocation class for a processor type."""
# Import processor invocation classes on demand
processor_class_map = {
"canny_image_processor": lambda: __import__(
"invokeai.app.invocations.canny", fromlist=["CannyEdgeDetectionInvocation"]
).CannyEdgeDetectionInvocation,
"hed_image_processor": lambda: __import__(
"invokeai.app.invocations.hed", fromlist=["HEDEdgeDetectionInvocation"]
).HEDEdgeDetectionInvocation,
"mlsd_image_processor": lambda: __import__(
"invokeai.app.invocations.mlsd", fromlist=["MLSDDetectionInvocation"]
).MLSDDetectionInvocation,
"depth_anything_image_processor": lambda: __import__(
"invokeai.app.invocations.depth_anything", fromlist=["DepthAnythingDepthEstimationInvocation"]
).DepthAnythingDepthEstimationInvocation,
"normalbae_image_processor": lambda: __import__(
"invokeai.app.invocations.normal_bae", fromlist=["NormalMapInvocation"]
).NormalMapInvocation,
"pidi_image_processor": lambda: __import__(
"invokeai.app.invocations.pidi", fromlist=["PiDiNetEdgeDetectionInvocation"]
).PiDiNetEdgeDetectionInvocation,
"lineart_image_processor": lambda: __import__(
"invokeai.app.invocations.lineart", fromlist=["LineartEdgeDetectionInvocation"]
).LineartEdgeDetectionInvocation,
"lineart_anime_image_processor": lambda: __import__(
"invokeai.app.invocations.lineart_anime", fromlist=["LineartAnimeEdgeDetectionInvocation"]
).LineartAnimeEdgeDetectionInvocation,
"content_shuffle_image_processor": lambda: __import__(
"invokeai.app.invocations.content_shuffle", fromlist=["ContentShuffleInvocation"]
).ContentShuffleInvocation,
"dw_openpose_image_processor": lambda: __import__(
"invokeai.app.invocations.dw_openpose", fromlist=["DWOpenposeDetectionInvocation"]
).DWOpenposeDetectionInvocation,
"mediapipe_face_processor": lambda: __import__(
"invokeai.app.invocations.mediapipe_face", fromlist=["MediaPipeFaceDetectionInvocation"]
).MediaPipeFaceDetectionInvocation,
"canny_image_processor": lambda: (
__import__(
"invokeai.app.invocations.canny", fromlist=["CannyEdgeDetectionInvocation"]
).CannyEdgeDetectionInvocation
),
"hed_image_processor": lambda: (
__import__(
"invokeai.app.invocations.hed", fromlist=["HEDEdgeDetectionInvocation"]
).HEDEdgeDetectionInvocation
),
"mlsd_image_processor": lambda: (
__import__("invokeai.app.invocations.mlsd", fromlist=["MLSDDetectionInvocation"]).MLSDDetectionInvocation
),
"depth_anything_image_processor": lambda: (
__import__(
"invokeai.app.invocations.depth_anything", fromlist=["DepthAnythingDepthEstimationInvocation"]
).DepthAnythingDepthEstimationInvocation
),
"normalbae_image_processor": lambda: (
__import__("invokeai.app.invocations.normal_bae", fromlist=["NormalMapInvocation"]).NormalMapInvocation
),
"pidi_image_processor": lambda: (
__import__(
"invokeai.app.invocations.pidi", fromlist=["PiDiNetEdgeDetectionInvocation"]
).PiDiNetEdgeDetectionInvocation
),
"lineart_image_processor": lambda: (
__import__(
"invokeai.app.invocations.lineart", fromlist=["LineartEdgeDetectionInvocation"]
).LineartEdgeDetectionInvocation
),
"lineart_anime_image_processor": lambda: (
__import__(
"invokeai.app.invocations.lineart_anime", fromlist=["LineartAnimeEdgeDetectionInvocation"]
).LineartAnimeEdgeDetectionInvocation
),
"content_shuffle_image_processor": lambda: (
__import__(
"invokeai.app.invocations.content_shuffle", fromlist=["ContentShuffleInvocation"]
).ContentShuffleInvocation
),
"dw_openpose_image_processor": lambda: (
__import__(
"invokeai.app.invocations.dw_openpose", fromlist=["DWOpenposeDetectionInvocation"]
).DWOpenposeDetectionInvocation
),
"mediapipe_face_processor": lambda: (
__import__(
"invokeai.app.invocations.mediapipe_face", fromlist=["MediaPipeFaceDetectionInvocation"]
).MediaPipeFaceDetectionInvocation
),
# Note: zoe_depth_image_processor doesn't have a processor invocation implementation
"color_map_image_processor": lambda: __import__(
"invokeai.app.invocations.color_map", fromlist=["ColorMapInvocation"]
).ColorMapInvocation,
"color_map_image_processor": lambda: (
__import__("invokeai.app.invocations.color_map", fromlist=["ColorMapInvocation"]).ColorMapInvocation
),
}
if processor_type in processor_class_map:

View File

@@ -0,0 +1,113 @@
from typing import Literal, Self
from pydantic import BaseModel, ConfigDict, Field, model_validator
from invokeai.backend.model_manager.configs.base import Config_Base
from invokeai.backend.model_manager.configs.identification_utils import NotAMatchError
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelSourceType, ModelType
ExternalGenerationMode = Literal["txt2img", "img2img", "inpaint"]
ExternalMaskFormat = Literal["alpha", "binary", "none"]
ExternalPanelControlName = Literal["reference_images", "dimensions", "seed"]
# Width/height pair; both values must be strictly positive (gt=0).
class ExternalImageSize(BaseModel):
width: int = Field(gt=0)
height: int = Field(gt=0)
# Unknown keys are rejected rather than silently ignored.
model_config = ConfigDict(extra="forbid")
# One selectable resolution: a display label, the aspect-ratio and image-size strings it
# stands for, and the positive pixel dimensions associated with it.
class ExternalResolutionPreset(BaseModel):
label: str = Field(min_length=1, description="Display label, e.g. '1:1 (1K)'")
aspect_ratio: str = Field(min_length=1, description="Aspect ratio string, e.g. '1:1'")
image_size: str = Field(min_length=1, description="Image size preset, e.g. '1K'")
width: int = Field(gt=0)
height: int = Field(gt=0)
# Unknown keys are rejected rather than silently ignored.
model_config = ConfigDict(extra="forbid")
# Declares what an external model can do. Every field has a default, so an empty
# instance means: txt2img only, negative prompt supported, no reference images, seed,
# guidance or steps, no size/count limits declared, and mask_format "none".
class ExternalModelCapabilities(BaseModel):
# default_factory gives each instance its own list.
modes: list[ExternalGenerationMode] = Field(default_factory=lambda: ["txt2img"])
supports_reference_images: bool = Field(default=False)
# NOTE(review): this is the only supports_* flag that defaults to True — confirm intended.
supports_negative_prompt: bool = Field(default=True)
supports_seed: bool = Field(default=False)
supports_guidance: bool = Field(default=False)
supports_steps: bool = Field(default=False)
# None leaves the limit undeclared; when set it must be > 0.
max_images_per_request: int | None = Field(default=None, gt=0)
max_image_size: ExternalImageSize | None = Field(default=None)
allowed_aspect_ratios: list[str] | None = Field(default=None)
# Presumably keyed by the same ratio strings as allowed_aspect_ratios — not enforced here.
aspect_ratio_sizes: dict[str, ExternalImageSize] | None = Field(default=None)
resolution_presets: list[ExternalResolutionPreset] | None = Field(default=None)
# None leaves the limit undeclared; when set it must be > 0.
max_reference_images: int | None = Field(default=None, gt=0)
mask_format: ExternalMaskFormat = Field(default="none")
input_image_required_for: list[ExternalGenerationMode] | None = Field(default=None)
# Unknown keys are rejected rather than silently ignored.
model_config = ConfigDict(extra="forbid")
# Optional per-model defaults. Each field may be omitted (None); when given it must be > 0.
class ExternalApiModelDefaultSettings(BaseModel):
width: int | None = Field(default=None, gt=0)
height: int | None = Field(default=None, gt=0)
num_images: int | None = Field(default=None, gt=0)
# Unknown keys are rejected rather than silently ignored.
model_config = ConfigDict(extra="forbid")
# A single panel control, identified by one of the ExternalPanelControlName literals.
# All numeric tuning fields are optional and unconstrained here (no min <= max check).
class ExternalModelPanelControl(BaseModel):
name: ExternalPanelControlName
slider_min: float | None = Field(default=None)
slider_max: float | None = Field(default=None)
number_input_min: float | None = Field(default=None)
number_input_max: float | None = Field(default=None)
fine_step: float | None = Field(default=None)
coarse_step: float | None = Field(default=None)
marks: list[float] | None = Field(default=None)
# Unknown keys are rejected rather than silently ignored.
model_config = ConfigDict(extra="forbid")
# Groups panel controls into three lists: prompts, image, generation.
# Each list defaults to a fresh empty list per instance.
class ExternalModelPanelSchema(BaseModel):
prompts: list[ExternalModelPanelControl] = Field(default_factory=list)
image: list[ExternalModelPanelControl] = Field(default_factory=list)
generation: list[ExternalModelPanelControl] = Field(default_factory=list)
# Unknown keys are rejected rather than silently ignored.
model_config = ConfigDict(extra="forbid")
# Config record for a model addressed through an external provider's API.
# The base/type/format literals pin this config to the External taxonomy values, and the
# path/source/hash fields are synthesized from the provider identifiers when left empty.
# (Documented with comments rather than a class docstring so the generated schema
# description is left untouched.)
class ExternalApiModelConfig(Config_Base):
    base: Literal[BaseModelType.External] = Field(default=BaseModelType.External)
    type: Literal[ModelType.ExternalImageGenerator] = Field(default=ModelType.ExternalImageGenerator)
    format: Literal[ModelFormat.ExternalApi] = Field(default=ModelFormat.ExternalApi)
    provider_id: str = Field(min_length=1, description="External provider ID")
    provider_model_id: str = Field(min_length=1, description="Provider-specific model ID")
    capabilities: ExternalModelCapabilities = Field(description="Provider capability matrix")
    default_settings: ExternalApiModelDefaultSettings | None = Field(default=None)
    panel_schema: ExternalModelPanelSchema | None = Field(default=None)
    tags: list[str] | None = Field(default=None)
    is_default: bool = Field(default=False)
    source_type: ModelSourceType = Field(default=ModelSourceType.External)
    # These default to empty/zero and are filled in by the validator below when empty.
    path: str = Field(default="")
    source: str = Field(default="")
    hash: str = Field(default="")
    file_size: int = Field(default=0, ge=0)
    # Unknown keys are rejected rather than silently ignored.
    model_config = ConfigDict(extra="forbid")

    @model_validator(mode="after")
    def _populate_external_fields(self) -> Self:
        """Fill in `path`, `source` and `hash` when they were left empty.

        `path` becomes ``external://<provider_id>/<provider_model_id>``, `source` falls back to
        the (possibly just-derived) `path`, and `hash` becomes
        ``external:<provider_id>:<provider_model_id>``. Values supplied by the caller are kept.

        Returns:
            The validated instance itself.
        """
        if not self.path:
            self.path = f"external://{self.provider_id}/{self.provider_model_id}"
        if not self.source:
            # Order matters: `source` mirrors `path` after `path` has been resolved.
            self.source = self.path
        if not self.hash:
            self.hash = f"external:{self.provider_id}:{self.provider_model_id}"
        return self

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, object]) -> Self:
        """Always decline to match a model found on disk.

        Raises:
            NotAMatchError: Unconditionally; external API models are never probed from disk.
        """
        raise NotAMatchError("external API models are not probed from disk")

View File

@@ -26,6 +26,7 @@ from invokeai.backend.model_manager.configs.controlnet import (
ControlNet_Diffusers_SD2_Config,
ControlNet_Diffusers_SDXL_Config,
)
from invokeai.backend.model_manager.configs.external_api import ExternalApiModelConfig
from invokeai.backend.model_manager.configs.flux_redux import FLUXRedux_Checkpoint_Config
from invokeai.backend.model_manager.configs.identification_utils import NotAMatchError
from invokeai.backend.model_manager.configs.ip_adapter import (
@@ -268,6 +269,7 @@ AnyModelConfig = Annotated[
Annotated[SigLIP_Diffusers_Config, SigLIP_Diffusers_Config.get_tag()],
Annotated[FLUXRedux_Checkpoint_Config, FLUXRedux_Checkpoint_Config.get_tag()],
Annotated[LlavaOnevision_Diffusers_Config, LlavaOnevision_Diffusers_Config.get_tag()],
Annotated[ExternalApiModelConfig, ExternalApiModelConfig.get_tag()],
# Unknown model (fallback)
Annotated[Unknown_Config, Unknown_Config.get_tag()],
],

View File

@@ -2,6 +2,13 @@ from typing import Optional
from pydantic import BaseModel
from invokeai.backend.model_manager.configs.external_api import (
ExternalApiModelDefaultSettings,
ExternalImageSize,
ExternalModelCapabilities,
ExternalModelPanelSchema,
ExternalResolutionPreset,
)
from invokeai.backend.model_manager.taxonomy import (
AnyVariant,
BaseModelType,
@@ -20,6 +27,9 @@ class StarterModelWithoutDependencies(BaseModel):
format: Optional[ModelFormat] = None
variant: Optional[AnyVariant] = None
is_installed: bool = False
capabilities: ExternalModelCapabilities | None = None
default_settings: ExternalApiModelDefaultSettings | None = None
panel_schema: ExternalModelPanelSchema | None = None
# allows us to track what models a user has installed across name changes within starter models
# if you update a starter model name, please add the old one to this list for that starter model
previous_names: list[str] = []
@@ -1001,6 +1011,367 @@ z_image_controlnet_tile = StarterModel(
)
# endregion
# region External API
GEMINI_3_IMAGE_ALLOWED_ASPECT_RATIOS = [
"1:1",
"1:4",
"1:8",
"2:3",
"3:2",
"3:4",
"4:1",
"4:3",
"4:5",
"5:4",
"8:1",
"9:16",
"16:9",
"21:9",
]
GEMINI_3_IMAGE_MAX_SIZE = ExternalImageSize(width=4096, height=4096)
def _gemini_3_resolution_presets(
    image_sizes: list[str],
    aspect_ratios: list[str] | None = None,
) -> list[ExternalResolutionPreset]:
    """Expand size presets and aspect ratios into concrete resolution presets for Gemini 3 models.

    One preset is produced per (image size, aspect ratio) pair, grouped by image size in the given
    order. The named size is applied to the longer side of the image; the shorter side is scaled by
    the ratio, rounded, and clamped to at least 1 pixel, so the pixel dimensions are approximations.

    Args:
        image_sizes: Size preset names; each must be one of "512", "1K", "2K" or "4K".
        aspect_ratios: Ratios written as "W:H". When None, GEMINI_3_IMAGE_ALLOWED_ASPECT_RATIOS is used.

    Returns:
        The generated presets.
    """
    ratios = GEMINI_3_IMAGE_ALLOWED_ASPECT_RATIOS if aspect_ratios is None else aspect_ratios
    longest_side_by_size = {"512": 512, "1K": 1024, "2K": 2048, "4K": 4096}

    def _build(size_name: str, longest_side: int, ratio: str) -> ExternalResolutionPreset:
        # Build a single preset for one size/ratio combination.
        ratio_w, ratio_h = map(int, ratio.split(":"))
        if ratio_w < ratio_h:
            # Portrait: height is the long side.
            height = longest_side
            width = max(1, round(longest_side * ratio_w / ratio_h))
        else:
            # Landscape or square: width is the long side.
            width = longest_side
            height = max(1, round(longest_side * ratio_h / ratio_w))
        return ExternalResolutionPreset(
            label=f"{ratio} ({size_name}) — {width}\u00d7{height}",
            aspect_ratio=ratio,
            image_size=size_name,
            width=width,
            height=height,
        )

    result: list[ExternalResolutionPreset] = []
    for size_name in image_sizes:
        # Resolve the size before touching the ratios so an unknown size fails first.
        longest_side = longest_side_by_size[size_name]
        result.extend(_build(size_name, longest_side, ratio) for ratio in ratios)
    return result
GEMINI_3_PRO_RESOLUTION_PRESETS = _gemini_3_resolution_presets(["1K", "2K", "4K"])
GEMINI_3_1_FLASH_RESOLUTION_PRESETS = _gemini_3_resolution_presets(["512", "1K", "2K", "4K"])
gemini_flash_image = StarterModel(
name="Gemini 2.5 Flash Image",
base=BaseModelType.External,
source="external://gemini/gemini-2.5-flash-image",
description="Google Gemini 2.5 Flash image generation model (external API). Requires a configured Gemini API key and may incur provider usage costs.",
type=ModelType.ExternalImageGenerator,
format=ModelFormat.ExternalApi,
capabilities=ExternalModelCapabilities(
modes=["txt2img", "img2img", "inpaint"],
supports_seed=True,
supports_reference_images=True,
max_images_per_request=1,
allowed_aspect_ratios=[
"1:1",
"2:3",
"3:2",
"3:4",
"4:3",
"4:5",
"5:4",
"9:16",
"16:9",
"21:9",
],
aspect_ratio_sizes={
"1:1": ExternalImageSize(width=1024, height=1024),
"2:3": ExternalImageSize(width=832, height=1248),
"3:2": ExternalImageSize(width=1248, height=832),
"3:4": ExternalImageSize(width=864, height=1184),
"4:3": ExternalImageSize(width=1184, height=864),
"4:5": ExternalImageSize(width=896, height=1152),
"5:4": ExternalImageSize(width=1152, height=896),
"9:16": ExternalImageSize(width=768, height=1344),
"16:9": ExternalImageSize(width=1344, height=768),
"21:9": ExternalImageSize(width=1536, height=672),
},
),
default_settings=ExternalApiModelDefaultSettings(width=1024, height=1024, num_images=1),
panel_schema=ExternalModelPanelSchema(prompts=[{"name": "reference_images"}], image=[{"name": "dimensions"}]),
)
gemini_pro_image_preview = StarterModel(
name="Gemini 3 Pro Image Preview",
base=BaseModelType.External,
source="external://gemini/gemini-3-pro-image-preview",
description="Google Gemini 3 Pro image generation preview model (external API). Supports up to 14 reference images, including up to 6 object references and up to 5 character references. Supports 1K/2K/4K resolution presets. Requires a configured Gemini API key and may incur provider usage costs.",
type=ModelType.ExternalImageGenerator,
format=ModelFormat.ExternalApi,
capabilities=ExternalModelCapabilities(
modes=["txt2img", "img2img", "inpaint"],
supports_seed=True,
supports_reference_images=True,
max_reference_images=14,
max_images_per_request=1,
max_image_size=GEMINI_3_IMAGE_MAX_SIZE,
allowed_aspect_ratios=GEMINI_3_IMAGE_ALLOWED_ASPECT_RATIOS,
resolution_presets=GEMINI_3_PRO_RESOLUTION_PRESETS,
),
default_settings=ExternalApiModelDefaultSettings(width=1024, height=1024, num_images=1),
panel_schema=ExternalModelPanelSchema(prompts=[{"name": "reference_images"}], image=[{"name": "dimensions"}]),
)
gemini_3_1_flash_image_preview = StarterModel(
name="Gemini 3.1 Flash Image Preview",
base=BaseModelType.External,
source="external://gemini/gemini-3.1-flash-image-preview",
description="Google Gemini 3.1 Flash image generation preview model (external API). Supports up to 14 reference images, including up to 10 object references and up to 4 character references. Supports 512/1K/2K/4K resolution presets. Requires a configured Gemini API key and may incur provider usage costs.",
type=ModelType.ExternalImageGenerator,
format=ModelFormat.ExternalApi,
capabilities=ExternalModelCapabilities(
modes=["txt2img", "img2img", "inpaint"],
supports_seed=True,
supports_reference_images=True,
max_reference_images=14,
max_images_per_request=1,
max_image_size=GEMINI_3_IMAGE_MAX_SIZE,
allowed_aspect_ratios=GEMINI_3_IMAGE_ALLOWED_ASPECT_RATIOS,
resolution_presets=GEMINI_3_1_FLASH_RESOLUTION_PRESETS,
),
default_settings=ExternalApiModelDefaultSettings(width=1024, height=1024, num_images=1),
panel_schema=ExternalModelPanelSchema(prompts=[{"name": "reference_images"}], image=[{"name": "dimensions"}]),
)
QWEN_IMAGE_2_ALLOWED_ASPECT_RATIOS = ["1:1", "4:3", "3:4", "16:9", "9:16"]
QWEN_IMAGE_MAX_ALLOWED_ASPECT_RATIOS = ["1:1", "4:3", "3:4", "16:9", "9:16"]
WAN_V2_ALLOWED_ASPECT_RATIOS = ["1:1", "4:3", "3:4", "16:9", "9:16"]
alibabacloud_qwen_image_2_pro = StarterModel(
name="Qwen Image 2.0 Pro",
base=BaseModelType.External,
source="external://alibabacloud/qwen-image-2.0-pro",
description="Alibaba Cloud Qwen Image 2.0 Pro model (external API). Best quality text-to-image with excellent bilingual text rendering. Requires a configured Alibaba Cloud DashScope API key and may incur provider usage costs.",
type=ModelType.ExternalImageGenerator,
format=ModelFormat.ExternalApi,
capabilities=ExternalModelCapabilities(
modes=["txt2img"],
supports_negative_prompt=True,
supports_seed=True,
max_images_per_request=4,
allowed_aspect_ratios=QWEN_IMAGE_2_ALLOWED_ASPECT_RATIOS,
aspect_ratio_sizes={
"1:1": ExternalImageSize(width=2048, height=2048),
"4:3": ExternalImageSize(width=2368, height=1728),
"3:4": ExternalImageSize(width=1728, height=2368),
"16:9": ExternalImageSize(width=2688, height=1536),
"9:16": ExternalImageSize(width=1536, height=2688),
},
),
default_settings=ExternalApiModelDefaultSettings(width=2048, height=2048, num_images=1),
panel_schema=ExternalModelPanelSchema(image=[{"name": "dimensions"}]),
)
alibabacloud_qwen_image_2 = StarterModel(
name="Qwen Image 2.0",
base=BaseModelType.External,
source="external://alibabacloud/qwen-image-2.0",
description="Alibaba Cloud Qwen Image 2.0 model (external API). Fast text-to-image with good bilingual text rendering. Requires a configured Alibaba Cloud DashScope API key and may incur provider usage costs.",
type=ModelType.ExternalImageGenerator,
format=ModelFormat.ExternalApi,
capabilities=ExternalModelCapabilities(
modes=["txt2img"],
supports_negative_prompt=True,
supports_seed=True,
max_images_per_request=4,
allowed_aspect_ratios=QWEN_IMAGE_2_ALLOWED_ASPECT_RATIOS,
aspect_ratio_sizes={
"1:1": ExternalImageSize(width=2048, height=2048),
"4:3": ExternalImageSize(width=2368, height=1728),
"3:4": ExternalImageSize(width=1728, height=2368),
"16:9": ExternalImageSize(width=2688, height=1536),
"9:16": ExternalImageSize(width=1536, height=2688),
},
),
default_settings=ExternalApiModelDefaultSettings(width=2048, height=2048, num_images=1),
panel_schema=ExternalModelPanelSchema(image=[{"name": "dimensions"}]),
)
alibabacloud_qwen_image_max = StarterModel(
name="Qwen Image Max",
base=BaseModelType.External,
source="external://alibabacloud/qwen-image-max",
description="Alibaba Cloud Qwen Image Max model (external API). High quality text-to-image generation. Requires a configured Alibaba Cloud DashScope API key and may incur provider usage costs.",
type=ModelType.ExternalImageGenerator,
format=ModelFormat.ExternalApi,
capabilities=ExternalModelCapabilities(
modes=["txt2img"],
supports_negative_prompt=True,
supports_seed=True,
max_images_per_request=4,
allowed_aspect_ratios=QWEN_IMAGE_MAX_ALLOWED_ASPECT_RATIOS,
aspect_ratio_sizes={
"1:1": ExternalImageSize(width=1328, height=1328),
"4:3": ExternalImageSize(width=1472, height=1104),
"3:4": ExternalImageSize(width=1104, height=1472),
"16:9": ExternalImageSize(width=1664, height=928),
"9:16": ExternalImageSize(width=928, height=1664),
},
),
default_settings=ExternalApiModelDefaultSettings(width=1328, height=1328, num_images=1),
panel_schema=ExternalModelPanelSchema(image=[{"name": "dimensions"}]),
)
alibabacloud_wan26_t2i = StarterModel(
name="Wan 2.6 Text-to-Image",
base=BaseModelType.External,
source="external://alibabacloud/wan2.6-t2i",
description="Alibaba Cloud Wan 2.6 text-to-image model (external API). Photorealistic image generation. Requires a configured Alibaba Cloud DashScope API key and may incur provider usage costs.",
type=ModelType.ExternalImageGenerator,
format=ModelFormat.ExternalApi,
capabilities=ExternalModelCapabilities(
modes=["txt2img"],
supports_negative_prompt=True,
supports_seed=True,
max_images_per_request=4,
allowed_aspect_ratios=WAN_V2_ALLOWED_ASPECT_RATIOS,
aspect_ratio_sizes={
"1:1": ExternalImageSize(width=1024, height=1024),
"4:3": ExternalImageSize(width=1440, height=1080),
"3:4": ExternalImageSize(width=1080, height=1440),
"16:9": ExternalImageSize(width=1440, height=810),
"9:16": ExternalImageSize(width=810, height=1440),
},
),
default_settings=ExternalApiModelDefaultSettings(width=1024, height=1024, num_images=1),
panel_schema=ExternalModelPanelSchema(image=[{"name": "dimensions"}]),
)
alibabacloud_qwen_image_edit_max = StarterModel(
name="Qwen Image Edit Max",
base=BaseModelType.External,
source="external://alibabacloud/qwen-image-edit-max",
description="Alibaba Cloud Qwen Image Edit Max model (external API). Image editing with industrial design and geometric reasoning. Requires a configured Alibaba Cloud DashScope API key and may incur provider usage costs.",
type=ModelType.ExternalImageGenerator,
format=ModelFormat.ExternalApi,
capabilities=ExternalModelCapabilities(
modes=["img2img"],
supports_negative_prompt=True,
supports_seed=True,
max_images_per_request=4,
allowed_aspect_ratios=QWEN_IMAGE_2_ALLOWED_ASPECT_RATIOS,
aspect_ratio_sizes={
"1:1": ExternalImageSize(width=2048, height=2048),
"4:3": ExternalImageSize(width=2368, height=1728),
"3:4": ExternalImageSize(width=1728, height=2368),
"16:9": ExternalImageSize(width=2688, height=1536),
"9:16": ExternalImageSize(width=1536, height=2688),
},
),
default_settings=ExternalApiModelDefaultSettings(width=2048, height=2048, num_images=1),
panel_schema=ExternalModelPanelSchema(image=[{"name": "dimensions"}]),
)
OPENAI_GPT_IMAGE_ASPECT_RATIOS = ["1:1", "3:2", "2:3"]
OPENAI_GPT_IMAGE_ASPECT_RATIO_SIZES = {
"1:1": ExternalImageSize(width=1024, height=1024),
"3:2": ExternalImageSize(width=1536, height=1024),
"2:3": ExternalImageSize(width=1024, height=1536),
}
OPENAI_GPT_IMAGE_PANEL_SCHEMA = ExternalModelPanelSchema(
prompts=[{"name": "reference_images"}], image=[{"name": "dimensions"}]
)
openai_gpt_image_1_5 = StarterModel(
name="GPT Image 1.5",
base=BaseModelType.External,
source="external://openai/gpt-image-1.5",
description="OpenAI GPT-Image-1.5 image generation model. Fastest and most affordable GPT image model. Requires a configured OpenAI API key and may incur provider usage costs.",
type=ModelType.ExternalImageGenerator,
format=ModelFormat.ExternalApi,
capabilities=ExternalModelCapabilities(
modes=["txt2img", "img2img", "inpaint"],
supports_reference_images=True,
max_images_per_request=10,
allowed_aspect_ratios=OPENAI_GPT_IMAGE_ASPECT_RATIOS,
aspect_ratio_sizes=OPENAI_GPT_IMAGE_ASPECT_RATIO_SIZES,
),
default_settings=ExternalApiModelDefaultSettings(width=1024, height=1024, num_images=1),
panel_schema=OPENAI_GPT_IMAGE_PANEL_SCHEMA,
)
openai_gpt_image_1 = StarterModel(
name="GPT Image 1",
base=BaseModelType.External,
source="external://openai/gpt-image-1",
description="OpenAI GPT-Image-1 image generation model. High quality image generation. Requires a configured OpenAI API key and may incur provider usage costs.",
type=ModelType.ExternalImageGenerator,
format=ModelFormat.ExternalApi,
capabilities=ExternalModelCapabilities(
modes=["txt2img", "img2img", "inpaint"],
supports_reference_images=True,
max_images_per_request=10,
allowed_aspect_ratios=OPENAI_GPT_IMAGE_ASPECT_RATIOS,
aspect_ratio_sizes=OPENAI_GPT_IMAGE_ASPECT_RATIO_SIZES,
),
default_settings=ExternalApiModelDefaultSettings(width=1024, height=1024, num_images=1),
panel_schema=OPENAI_GPT_IMAGE_PANEL_SCHEMA,
)
# Starter entry for OpenAI's GPT-Image-1-Mini. It shares the aspect ratios, sizes and panel
# schema constants used by the other GPT Image starter models.
openai_gpt_image_1_mini = StarterModel(
    name="GPT Image 1 Mini",
    base=BaseModelType.External,
    source="external://openai/gpt-image-1-mini",
    # The description is a plain string (never %-formatted), so the percent sign is written once.
    description="OpenAI GPT-Image-1-Mini image generation model. Cost-efficient option, 80% cheaper than GPT-Image-1. Requires a configured OpenAI API key and may incur provider usage costs.",
    type=ModelType.ExternalImageGenerator,
    format=ModelFormat.ExternalApi,
    capabilities=ExternalModelCapabilities(
        modes=["txt2img", "img2img", "inpaint"],
        supports_reference_images=True,
        max_images_per_request=10,
        allowed_aspect_ratios=OPENAI_GPT_IMAGE_ASPECT_RATIOS,
        aspect_ratio_sizes=OPENAI_GPT_IMAGE_ASPECT_RATIO_SIZES,
    ),
    default_settings=ExternalApiModelDefaultSettings(width=1024, height=1024, num_images=1),
    panel_schema=OPENAI_GPT_IMAGE_PANEL_SCHEMA,
)
openai_dall_e_3 = StarterModel(
name="DALL-E 3",
base=BaseModelType.External,
source="external://openai/dall-e-3",
description="OpenAI DALL-E 3 image generation model. Supports vivid and natural styles. Only text-to-image, no editing. Requires a configured OpenAI API key and may incur provider usage costs.",
type=ModelType.ExternalImageGenerator,
format=ModelFormat.ExternalApi,
capabilities=ExternalModelCapabilities(
modes=["txt2img"],
max_images_per_request=1,
allowed_aspect_ratios=["1:1", "7:4", "4:7"],
aspect_ratio_sizes={
"1:1": ExternalImageSize(width=1024, height=1024),
"7:4": ExternalImageSize(width=1792, height=1024),
"4:7": ExternalImageSize(width=1024, height=1792),
},
),
default_settings=ExternalApiModelDefaultSettings(width=1024, height=1024, num_images=1),
panel_schema=ExternalModelPanelSchema(image=[{"name": "dimensions"}]),
)
openai_dall_e_2 = StarterModel(
name="DALL-E 2",
base=BaseModelType.External,
source="external://openai/dall-e-2",
description="OpenAI DALL-E 2 image generation model. Supports square images only. Requires a configured OpenAI API key and may incur provider usage costs.",
type=ModelType.ExternalImageGenerator,
format=ModelFormat.ExternalApi,
capabilities=ExternalModelCapabilities(
modes=["txt2img", "img2img", "inpaint"],
max_images_per_request=10,
allowed_aspect_ratios=["1:1"],
aspect_ratio_sizes={
"1:1": ExternalImageSize(width=1024, height=1024),
},
),
default_settings=ExternalApiModelDefaultSettings(width=1024, height=1024, num_images=1),
panel_schema=ExternalModelPanelSchema(image=[{"name": "dimensions"}]),
)
# region Anima
anima_qwen3_encoder = StarterModel(
name="Anima Qwen3 0.6B Text Encoder",
@@ -1140,6 +1511,19 @@ STARTER_MODELS: list[StarterModel] = [
z_image_qwen3_encoder_quantized,
z_image_controlnet_union,
z_image_controlnet_tile,
gemini_flash_image,
gemini_pro_image_preview,
gemini_3_1_flash_image_preview,
openai_gpt_image_1_5,
openai_gpt_image_1,
openai_gpt_image_1_mini,
openai_dall_e_3,
openai_dall_e_2,
alibabacloud_qwen_image_2_pro,
alibabacloud_qwen_image_2,
alibabacloud_qwen_image_max,
alibabacloud_wan26_t2i,
alibabacloud_qwen_image_edit_max,
anima_preview3,
anima_qwen3_encoder,
anima_vae,

View File

@@ -52,6 +52,8 @@ class BaseModelType(str, Enum):
"""Indicates the model is associated with CogView 4 model architecture."""
ZImage = "z-image"
"""Indicates the model is associated with Z-Image model architecture, including Z-Image-Turbo."""
External = "external"
"""Indicates the model is hosted by an external provider."""
QwenImage = "qwen-image"
"""Indicates the model is associated with Qwen Image Edit 2511 model architecture."""
Anima = "anima"
@@ -80,6 +82,7 @@ class ModelType(str, Enum):
SigLIP = "siglip"
FluxRedux = "flux_redux"
LlavaOnevision = "llava_onevision"
ExternalImageGenerator = "external_image_generator"
Unknown = "unknown"
@@ -187,6 +190,7 @@ class ModelFormat(str, Enum):
BnbQuantizedLlmInt8b = "bnb_quantized_int8b"
BnbQuantizednf4b = "bnb_quantized_nf4b"
GGUFQuantized = "gguf_quantized"
ExternalApi = "external_api"
Unknown = "unknown"
@@ -215,6 +219,7 @@ class ModelSourceType(str, Enum):
Path = "path"
Url = "url"
HFRepoID = "hf_repo_id"
External = "external"
class FluxLoRAFormat(str, Enum):

View File

@@ -1115,6 +1115,22 @@
"fileSize": "File Size",
"filterModels": "Filter models",
"fluxRedux": "FLUX Redux",
"externalImageGenerator": "External Image Generator",
"externalProviders": "External Providers",
"externalSetupTitle": "External Providers Setup",
"externalSetupDescription": "Connect an API key to enable external image generation. External starter models auto-install when a provider is configured.",
"externalInstallDefaults": "Auto-install starter models",
"externalProvidersUnavailable": "External providers are not available in this build.",
"externalSetupFooter": "An API key is required. External providers use remote APIs; usage may incur provider-side costs.",
"externalProviderCardDescription": "Configure {{providerId}} credentials for external image generation.",
"externalApiKey": "API Key",
"externalApiKeyPlaceholder": "Paste your API key",
"externalApiKeyPlaceholderSet": "API key configured",
"externalApiKeyHelper": "Stored in your InvokeAI config file.",
"externalBaseUrl": "Base URL (optional)",
"externalBaseUrlPlaceholder": "https://...",
"externalBaseUrlHelper": "Override the default API base URL if needed.",
"externalResetHelper": "Clear API key and base URL.",
"height": "Height",
"huggingFace": "HuggingFace",
"huggingFacePlaceholder": "owner/model-name",
@@ -1182,6 +1198,21 @@
"modelUpdated": "Model Updated",
"modelUpdateFailed": "Model Update Failed",
"name": "Name",
"externalProvider": "External Provider",
"externalCapabilities": "External Capabilities",
"externalDefaults": "External Defaults",
"providerId": "Provider ID",
"providerModelId": "Provider Model ID",
"supportedModes": "Supported Modes",
"supportsNegativePrompt": "Supports Negative Prompt",
"supportsReferenceImages": "Supports Reference Images",
"supportsSeed": "Supports Seed",
"supportsGuidance": "Supports Guidance",
"maxImagesPerRequest": "Max Images Per Request",
"maxReferenceImages": "Max Reference Images",
"maxImageWidth": "Max Image Width",
"maxImageHeight": "Max Image Height",
"numImages": "Num Images",
"modelPickerFallbackNoModelsInstalled": "No models installed.",
"modelPickerFallbackNoModelsInstalled2": "Visit the <LinkComponent>Model Manager</LinkComponent> to install models.",
"modelPickerFallbackNoModelsInstalledNonAdmin": "No models installed. Ask your InvokeAI administrator (<AdminEmailLink />) to install some models.",
@@ -1226,6 +1257,7 @@
"urlDescription": "Install models from a URL or local file path. Perfect for specific models you want to add.",
"huggingFaceDescription": "Browse and install models directly from HuggingFace repositories.",
"scanFolderDescription": "Scan a local folder to automatically detect and install models.",
"externalDescription": "Connect an API key for an external provider (such as Gemini, OpenAI, or Alibaba Cloud DashScope) to enable external generation. Usage may incur provider-side costs.",
"recommendedModels": "Recommended Models",
"exploreStarter": "Or browse all available starter models",
"quickStart": "Quick Start Bundles",
@@ -1544,6 +1576,7 @@
"copyImage": "Copy Image",
"denoisingStrength": "Denoising Strength",
"disabledNoRasterContent": "Disabled (No Raster Content)",
"disabledNotSupported": "Not supported by model",
"downloadImage": "Download Image",
"general": "General",
"guidance": "Guidance",
@@ -1661,6 +1694,7 @@
"boxBlur": "Box Blur",
"staged": "Staged",
"resolution": "Resolution",
"imageSize": "Image Size",
"modelDisabledForTrial": "Generating with {{modelName}} is not available on trial accounts. Visit your <LinkComponent>account settings</LinkComponent> to upgrade."
},
"dynamicPrompts": {
@@ -1737,7 +1771,11 @@
"intermediatesCleared_one": "Cleared {{count}} Intermediate",
"intermediatesCleared_other": "Cleared {{count}} Intermediates",
"intermediatesClearedFailed": "Problem Clearing Intermediates",
"reloadingIn": "Reloading in"
"reloadingIn": "Reloading in",
"externalProviders": "External Providers",
"externalProviderConfigured": "Configured",
"externalProviderNotConfigured": "API Key Required",
"externalProviderNotConfiguredHint": "Add your API key in Model Manager or the server config to enable this provider."
},
"toast": {
"addedToBoard": "Added to board {{name}}'s assets",

View File

@@ -7,10 +7,12 @@ import {
animaQwen3EncoderModelSelected,
animaT5EncoderModelSelected,
animaVaeModelSelected,
aspectRatioIdChanged,
kleinQwen3EncoderModelSelected,
kleinVaeModelSelected,
modelChanged,
qwenImageComponentSourceSelected,
resolutionPresetSelected,
setZImageScheduler,
syncedToOptimalDimension,
vaeSelected,
@@ -30,6 +32,7 @@ import {
} from 'features/controlLayers/store/selectors';
import {
getEntityIdentifier,
isAspectRatioID,
isFlux2ReferenceImageConfig,
isQwenImageReferenceImageConfig,
} from 'features/controlLayers/store/types';
@@ -59,7 +62,7 @@ import {
selectZImageDiffusersModels,
} from 'services/api/hooks/modelsByType';
import type { FLUXKontextModelConfig, FLUXReduxModelConfig, IPAdapterModelConfig } from 'services/api/types';
import { isFluxKontextModelConfig, isFluxReduxModelConfig } from 'services/api/types';
import { isExternalApiModelConfig, isFluxKontextModelConfig, isFluxReduxModelConfig } from 'services/api/types';
const log = logger('models');
@@ -281,7 +284,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
}
}
if (SUPPORTS_REF_IMAGES_BASE_MODELS.includes(newModel.base)) {
if (newModel.base !== 'external' && SUPPORTS_REF_IMAGES_BASE_MODELS.includes(newModel.base)) {
// Handle incompatible reference image models - switch to first compatible model, with some smart logic
// to choose the best available model based on the new main model.
const allRefImageModels = selectGlobalRefImageModels(state).filter(({ base }) => base === newBase);
@@ -529,6 +532,34 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
dispatch(bboxSyncedToOptimalDimension());
}
}
// When switching to an external model, sync the bbox to the model's first preset dimensions.
// Resolution presets take priority; otherwise the first entry of `aspect_ratio_sizes` is used.
if (newBase === 'external') {
  const modelConfigsResult = selectModelConfigsQuery(getState());
  // The config can only be resolved once the model configs query has data.
  const newModelConfig = modelConfigsResult.data
    ? modelConfigsAdapterSelectors.selectById(modelConfigsResult.data, newModel.key)
    : undefined;
  if (newModelConfig && isExternalApiModelConfig(newModelConfig)) {
    const { aspect_ratio_sizes, resolution_presets } = newModelConfig.capabilities;
    // Guarded lookup instead of a non-null assertion: an empty or missing preset list yields undefined.
    const firstPreset = resolution_presets?.[0];
    if (firstPreset) {
      dispatch(
        resolutionPresetSelected({
          imageSize: firstPreset.image_size,
          aspectRatio: firstPreset.aspect_ratio,
          width: firstPreset.width,
          height: firstPreset.height,
        })
      );
    } else if (aspect_ratio_sizes) {
      const firstRatio = Object.keys(aspect_ratio_sizes)[0];
      const firstSize = firstRatio ? aspect_ratio_sizes[firstRatio] : undefined;
      // Only dispatch when the key is a known aspect ratio ID and has a size attached.
      if (firstRatio && firstSize && isAspectRatioID(firstRatio)) {
        dispatch(aspectRatioIdChanged({ id: firstRatio, fixedSize: firstSize }));
      }
    }
  }
}
},
});
};

View File

@@ -11,7 +11,7 @@ import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import WavyLine from 'common/components/WavyLine';
import { selectImg2imgStrength, setImg2imgStrength } from 'features/controlLayers/store/paramsSlice';
import { selectImg2imgStrength, selectIsExternal, setImg2imgStrength } from 'features/controlLayers/store/paramsSlice';
import { selectActiveRasterLayerEntities } from 'features/controlLayers/store/selectors';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -37,6 +37,7 @@ export const ParamDenoisingStrength = memo(() => {
const img2imgStrength = useAppSelector(selectImg2imgStrength);
const dispatch = useAppDispatch();
const hasRasterLayersWithContent = useAppSelector(selectHasRasterLayersWithContent);
const isExternal = useAppSelector(selectIsExternal);
const selectedModelConfig = useSelectedModelConfig();
const onChange = useCallback(
@@ -55,12 +56,16 @@ export const ParamDenoisingStrength = memo(() => {
// Denoising strength does nothing if there are no raster layers w/ content
return true;
}
if (isExternal) {
// External models don't support denoise strength - they handle img2img via prompt
return true;
}
if (selectedModelConfig && isFluxFillMainModelModelConfig(selectedModelConfig)) {
// Denoising strength is ignored by FLUX Fill, which is indicated by the variant being 'inpaint'
return true;
}
return false;
}, [hasRasterLayersWithContent, selectedModelConfig]);
}, [hasRasterLayersWithContent, isExternal, selectedModelConfig]);
return (
<FormControl isDisabled={isDisabled} p={1} justifyContent="space-between" h={8}>
@@ -96,7 +101,9 @@ export const ParamDenoisingStrength = memo(() => {
</>
) : (
<Flex alignItems="center">
<Badge opacity="0.6">{t('parameters.disabledNoRasterContent')}</Badge>
<Badge opacity="0.6">
{isExternal ? t('parameters.disabledNotSupported') : t('parameters.disabledNoRasterContent')}
</Badge>
</Flex>
)}
</FormControl>

View File

@@ -16,6 +16,7 @@ import { memo, useCallback, useEffect, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiExclamationMarkBold, PiEyeSlashBold, PiImageBold } from 'react-icons/pi';
import { useImageDTOFromCroppableImage } from 'services/api/endpoints/images';
import { isExternalApiModelConfig } from 'services/api/types';
import { RefImageWarningTooltipContent } from './RefImageWarningTooltipContent';
@@ -73,18 +74,19 @@ export const RefImagePreview = memo(() => {
const selectedEntityId = useAppSelector(selectSelectedRefEntityId);
const isPanelOpen = useAppSelector(selectIsRefImagePanelOpen);
const [showWeightDisplay, setShowWeightDisplay] = useState(false);
const isExternalModel = !!mainModelConfig && isExternalApiModelConfig(mainModelConfig);
const imageDTO = useImageDTOFromCroppableImage(entity.config.image);
const sx = useMemo(() => {
if (!isIPAdapterConfig(entity.config)) {
if (!isIPAdapterConfig(entity.config) || isExternalModel) {
return baseSx;
}
return getImageSxWithWeight(entity.config.weight);
}, [entity.config]);
}, [entity.config, isExternalModel]);
useEffect(() => {
if (!isIPAdapterConfig(entity.config)) {
if (!isIPAdapterConfig(entity.config) || isExternalModel) {
return;
}
setShowWeightDisplay(true);
@@ -94,7 +96,7 @@ export const RefImagePreview = memo(() => {
return () => {
window.clearTimeout(timeout);
};
}, [entity.config]);
}, [entity.config, isExternalModel]);
const warnings = useMemo(() => {
return getGlobalReferenceImageWarnings(entity, mainModelConfig);
@@ -156,7 +158,7 @@ export const RefImagePreview = memo(() => {
) : (
<Skeleton h="full" aspectRatio="1/1" />
)}
{isIPAdapterConfig(entity.config) && (
{isIPAdapterConfig(entity.config) && !isExternalModel && (
<Flex
position="absolute"
inset={0}

View File

@@ -15,7 +15,7 @@ import {
useCanvasManagerSafe,
} from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { useRefImageIdContext } from 'features/controlLayers/contexts/RefImageIdContext';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { selectIsFLUX, selectMainModelConfig } from 'features/controlLayers/store/paramsSlice';
import {
refImageFLUXReduxImageInfluenceChanged,
refImageImageChanged,
@@ -50,6 +50,7 @@ import type {
FLUXReduxModelConfig,
IPAdapterModelConfig,
} from 'services/api/types';
import { isExternalApiModelConfig } from 'services/api/types';
import { RefImageImage } from './RefImageImage';
@@ -65,6 +66,7 @@ const RefImageSettingsContent = memo(() => {
const selectConfig = useMemo(() => buildSelectConfig(id), [id]);
const config = useAppSelector(selectConfig);
const tab = useAppSelector(selectActiveTab);
const mainModelConfig = useAppSelector(selectMainModelConfig);
const onChangeBeginEndStepPct = useCallback(
(beginEndStepPct: [number, number]) => {
@@ -125,9 +127,11 @@ const RefImageSettingsContent = memo(() => {
);
const isFLUX = useAppSelector(selectIsFLUX);
const isExternalModel = !!mainModelConfig && isExternalApiModelConfig(mainModelConfig);
// FLUX.2 Klein and Qwen Image Edit have built-in reference image support - no model selector needed
const showModelSelector = !isFlux2ReferenceImageConfig(config) && !isQwenImageReferenceImageConfig(config);
// FLUX.2 Klein, Qwen Image Edit and external API models do not require a ref image model selection.
const showModelSelector =
!isFlux2ReferenceImageConfig(config) && !isQwenImageReferenceImageConfig(config) && !isExternalModel;
return (
<Flex flexDir="column" gap={2} position="relative" w="full">
@@ -155,14 +159,14 @@ const RefImageSettingsContent = memo(() => {
</Flex>
)}
<Flex gap={2} w="full">
{isIPAdapterConfig(config) && (
{isIPAdapterConfig(config) && !isExternalModel && (
<Flex flexDir="column" gap={2} w="full">
{!isFLUX && <IPAdapterMethod method={config.method} onChange={onChangeIPMethod} />}
<Weight weight={config.weight} onChange={onChangeWeight} />
<BeginEndStepPct beginEndStepPct={config.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
</Flex>
)}
{isFLUXReduxConfig(config) && (
{isFLUXReduxConfig(config) && !isExternalModel && (
<Flex flexDir="column" gap={2} w="full" alignItems="flex-start">
<FLUXReduxImageInfluence
imageInfluence={config.imageInfluence ?? 'lowest'}

View File

@@ -75,11 +75,18 @@ export const StagingAreaContextProvider = memo(({ children, sessionId }: PropsWi
onAccept: (item, imageDTO) => {
const bboxRect = selectBboxRect(store.getState());
const { x, y } = bboxRect;
const imageObject = imageDTOToImageObject(imageDTO, { usePixelBbox: false });
const scale = Math.min(bboxRect.width / imageDTO.width, bboxRect.height / imageDTO.height);
const scaledWidth = Math.round(imageDTO.width * scale);
const scaledHeight = Math.round(imageDTO.height * scale);
const position = {
x: x + Math.round((bboxRect.width - scaledWidth) / 2),
y: y + Math.round((bboxRect.height - scaledHeight) / 2),
};
const selectedEntityIdentifier = selectSelectedEntityIdentifier(store.getState());
const imageObject = imageDTOToImageObject(imageDTO);
const overrides: Partial<CanvasRasterLayerState> = {
position: { x, y },
position,
objects: [imageObject],
};
store.dispatch(rasterLayerAdded({ overrides, isSelected: selectedEntityIdentifier?.type === 'raster_layer' }));

View File

@@ -183,6 +183,33 @@ describe('StagingAreaApi Utility Functions', () => {
expect(result).toEqual(['first-image.png', 'second-image.png']);
});
// A collection output (`images: [...]`) must contribute every image name, in order - not just the first.
it('should return all images from image collections', () => {
  const queueItem: S['SessionQueueItem'] = {
    item_id: 1,
    status: 'completed',
    priority: 0,
    destination: 'test-session',
    created_at: '2024-01-01T00:00:00Z',
    updated_at: '2024-01-01T00:00:00Z',
    started_at: '2024-01-01T00:00:01Z',
    completed_at: '2024-01-01T00:01:00Z',
    error: null,
    session: {
      id: 'test-session',
      source_prepared_mapping: {
        canvas_output: ['output-node-id'],
      },
      results: {
        'output-node-id': {
          images: [{ image_name: 'first.png' }, { image_name: 'second.png' }],
        },
      },
    },
  } as unknown as S['SessionQueueItem'];
  expect(getOutputImageNames(queueItem)).toEqual(['first.png', 'second.png']);
});
it('should handle empty session mapping', () => {
const queueItem: S['SessionQueueItem'] = {
item_id: 1,

View File

@@ -1,4 +1,4 @@
import { isImageField } from 'features/nodes/types/common';
import { isImageField, isImageFieldCollection } from 'features/nodes/types/common';
import { isCanvasOutputNodeId } from 'features/nodes/util/graph/graphBuilderUtils';
import type { S } from 'services/api/types';
import { formatProgressMessage } from 'services/events/stores';
@@ -32,6 +32,11 @@ export const getOutputImageNames = (item: S['SessionQueueItem']): string[] => {
if (isImageField(value)) {
imageNames.push(value.image_name);
}
if (isImageFieldCollection(value)) {
for (const img of value) {
imageNames.push(img.image_name);
}
}
}
}

View File

@@ -45,7 +45,7 @@ import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { serializeError } from 'serialize-error';
import type { ImageDTO } from 'services/api/types';
import { type ImageDTO, isExternalApiModelConfig } from 'services/api/types';
import type { JsonObject } from 'type-fest';
const log = logger('canvas');
@@ -90,7 +90,7 @@ const useSaveCanvas = ({ region, saveToGallery, toastOk, toastError, onSave, wit
metadata.negative_prompt = selectNegativePrompt(state);
metadata.seed = selectSeed(state);
const model = selectMainModelConfig(state);
if (model) {
if (model && !isExternalApiModelConfig(model)) {
metadata.model = Graph.getModelMetadataField(model);
}
}

View File

@@ -10,7 +10,7 @@ import {
getPrefixedId,
} from 'features/controlLayers/konva/util';
import { selectBboxOverlay } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectModel } from 'features/controlLayers/store/paramsSlice';
import { selectHasFixedDimensionSizes, selectModel } from 'features/controlLayers/store/paramsSlice';
import { selectBbox } from 'features/controlLayers/store/selectors';
import type { Coordinate, Rect, Tool } from 'features/controlLayers/store/types';
import Konva from 'konva';
@@ -191,6 +191,9 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
// Listen for the model changing - some model types constrain the bbox to a certain size or aspect ratio.
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectModel, this.render));
// Listen for fixed dimension sizes changes - external models may lock bbox resizing
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectHasFixedDimensionSizes, this.render));
// Update on busy state changes
this.subscriptions.add(this.manager.$isBusy.listen(this.render));
@@ -246,6 +249,10 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
if (tool !== 'bbox') {
return NO_ANCHORS;
}
// External models with fixed dimension presets don't allow free bbox resizing
if (this.manager.stateApi.runSelector(selectHasFixedDimensionSizes)) {
return NO_ANCHORS;
}
return ALL_ANCHORS;
};

View File

@@ -7,7 +7,7 @@ import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMul
import { merge } from 'es-toolkit/compat';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { canvasReset } from 'features/controlLayers/store/actions';
import { modelChanged } from 'features/controlLayers/store/paramsSlice';
import { aspectRatioIdChanged, modelChanged, resolutionPresetSelected } from 'features/controlLayers/store/paramsSlice';
import {
selectAllEntities,
selectAllEntitiesOfType,
@@ -31,6 +31,7 @@ import type {
RgbColor,
SimpleAdjustmentsConfig,
} from 'features/controlLayers/store/types';
import { isAspectRatioID } from 'features/controlLayers/store/types';
import {
calculateNewSize,
getScaledBoundingBoxDimensions,
@@ -1288,21 +1289,31 @@ const slice = createSlice({
state.bbox.aspectRatio.isLocked = !state.bbox.aspectRatio.isLocked;
syncScaledSize(state);
},
bboxAspectRatioIdChanged: (state, action: PayloadAction<{ id: AspectRatioID }>) => {
const { id } = action.payload;
bboxAspectRatioIdChanged: (
state,
action: PayloadAction<{ id: AspectRatioID; fixedSize?: { width: number; height: number } }>
) => {
const { id, fixedSize } = action.payload;
state.bbox.aspectRatio.id = id;
if (id === 'Free') {
state.bbox.aspectRatio.isLocked = false;
} else {
state.bbox.aspectRatio.isLocked = true;
state.bbox.aspectRatio.value = ASPECT_RATIO_MAP[id].ratio;
const { width, height } = calculateNewSize(
state.bbox.aspectRatio.value,
state.bbox.rect.width * state.bbox.rect.height,
state.bbox.modelBase
);
state.bbox.rect.width = width;
state.bbox.rect.height = height;
if (fixedSize) {
// External models provide fixed dimensions for each aspect ratio
state.bbox.aspectRatio.value = fixedSize.width / fixedSize.height;
state.bbox.rect.width = fixedSize.width;
state.bbox.rect.height = fixedSize.height;
} else {
state.bbox.aspectRatio.value = ASPECT_RATIO_MAP[id].ratio;
const { width, height } = calculateNewSize(
state.bbox.aspectRatio.value,
state.bbox.rect.width * state.bbox.rect.height,
state.bbox.modelBase
);
state.bbox.rect.width = width;
state.bbox.rect.height = height;
}
}
syncScaledSize(state);
@@ -1800,6 +1811,29 @@ const slice = createSlice({
syncScaledSize(state);
}
});
// Mirror a fixed-size aspect ratio selection (external models with `aspect_ratio_sizes`) onto the bbox.
builder.addCase(aspectRatioIdChanged, (state, action) => {
  const { id, fixedSize } = action.payload;
  if (!fixedSize) {
    // No fixed size in the payload: this case leaves the bbox untouched.
    return;
  }
  const { bbox } = state;
  bbox.rect.width = fixedSize.width;
  bbox.rect.height = fixedSize.height;
  bbox.aspectRatio.value = fixedSize.width / fixedSize.height;
  bbox.aspectRatio.id = id;
  bbox.aspectRatio.isLocked = true;
  syncScaledSize(state);
});
// Mirror a selected resolution preset (external models with `resolution_presets`) onto the bbox.
builder.addCase(resolutionPresetSelected, (state, action) => {
  const { width, height, aspectRatio } = action.payload;
  const { bbox } = state;
  bbox.rect.width = width;
  bbox.rect.height = height;
  bbox.aspectRatio.value = width / height;
  // Preset ratios that are not known aspect ratio IDs are recorded as 'Free'; the ratio stays locked either way.
  bbox.aspectRatio.id = isAspectRatioID(aspectRatio) ? aspectRatio : 'Free';
  bbox.aspectRatio.isLocked = true;
  syncScaledSize(state);
});
},
});

View File

@@ -0,0 +1,133 @@
import type {
ExternalApiModelConfig,
ExternalApiModelDefaultSettings,
ExternalImageSize,
ExternalModelCapabilities,
ExternalModelPanelSchema,
} from 'services/api/types';
import { describe, expect, it } from 'vitest';
import {
selectModelSupportsDimensions,
selectModelSupportsGuidance,
selectModelSupportsNegativePrompt,
selectModelSupportsRefImages,
selectModelSupportsSeed,
selectModelSupportsSteps,
} from './paramsSlice';
/**
 * Reduces a full external model config to the identifier fields (key, hash, name, base, type)
 * that the tests pass to the selectors as their `model` argument.
 */
const buildExternalModelIdentifier = (config: ExternalApiModelConfig) => {
  const { key, hash, name, base, type } = config;
  return { key, hash, name, base, type } as const;
};
/**
 * Builds an `ExternalApiModelConfig` fixture for an OpenAI `gpt-image-1` model.
 *
 * The supplied capabilities are spread first and `max_image_size` is then set to 1024x1024;
 * the panel schema is passed through as given (it may be undefined).
 */
const createExternalConfig = (
  capabilities: ExternalModelCapabilities,
  panelSchema?: ExternalModelPanelSchema
): ExternalApiModelConfig => {
  const sizeLimit: ExternalImageSize = { width: 1024, height: 1024 };
  const settings: ExternalApiModelDefaultSettings = { width: 1024, height: 1024 };
  const fixture: ExternalApiModelConfig = {
    key: 'external-test',
    hash: 'external:openai:gpt-image-1',
    path: 'external://openai/gpt-image-1',
    file_size: 0,
    name: 'External Test',
    description: null,
    source: 'external://openai/gpt-image-1',
    source_type: 'url',
    source_api_response: null,
    cover_image: null,
    base: 'external',
    type: 'external_image_generator',
    format: 'external_api',
    provider_id: 'openai',
    provider_model_id: 'gpt-image-1',
    capabilities: { ...capabilities, max_image_size: sizeLimit },
    default_settings: settings,
    panel_schema: panelSchema,
    tags: ['external'],
    is_default: false,
  };
  return fixture;
};
describe('paramsSlice selectors for external models', () => {
  // Baseline capabilities shared by most cases: text-to-image only, no reference image support.
  const txt2imgOnly: ExternalModelCapabilities = {
    modes: ['txt2img'],
    supports_reference_images: false,
  };

  // Builds a fresh fixture config together with the model identifier derived from it.
  const setup = (capabilities: ExternalModelCapabilities, panelSchema?: ExternalModelPanelSchema) => {
    const config = createExternalConfig(capabilities, panelSchema);
    return { config, model: buildExternalModelIdentifier(config) };
  };

  it('returns false for negative prompt support on external models', () => {
    const { model } = setup(txt2imgOnly);

    expect(selectModelSupportsNegativePrompt.resultFunc(model)).toBe(false);
  });

  it('uses external capabilities for ref image support', () => {
    const { model, config } = setup(txt2imgOnly);

    expect(selectModelSupportsRefImages.resultFunc(model, config)).toBe(false);
  });

  it('returns false for guidance support on external models', () => {
    const { model } = setup(txt2imgOnly);

    expect(selectModelSupportsGuidance.resultFunc(model)).toBe(false);
  });

  it('uses external capabilities for seed support', () => {
    const { model, config } = setup({ ...txt2imgOnly, supports_seed: false });

    expect(selectModelSupportsSeed.resultFunc(model, config)).toBe(false);
  });

  it('returns false for steps support on external models', () => {
    const { model } = setup(txt2imgOnly);

    expect(selectModelSupportsSteps.resultFunc(model)).toBe(false);
  });

  it('prefers panel schema over capabilities for control visibility', () => {
    // Capabilities advertise seed support, but the panel schema lists no seed control.
    const { model, config } = setup(
      {
        modes: ['txt2img'],
        supports_reference_images: true,
        supports_seed: true,
      },
      {
        prompts: [{ name: 'reference_images' }],
        image: [{ name: 'dimensions' }],
        generation: [],
      }
    );

    expect(selectModelSupportsNegativePrompt.resultFunc(model)).toBe(false);
    expect(selectModelSupportsRefImages.resultFunc(model, config)).toBe(true);
    expect(selectModelSupportsGuidance.resultFunc(model)).toBe(false);
    expect(selectModelSupportsSeed.resultFunc(model, config)).toBe(false);
    expect(selectModelSupportsSteps.resultFunc(model)).toBe(false);
    expect(selectModelSupportsDimensions.resultFunc(model, config)).toBe(true);
  });
});

View File

@@ -21,6 +21,7 @@ import {
SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS,
SUPPORTS_REF_IMAGES_BASE_MODELS,
} from 'features/modelManagerV2/models';
import type { BaseModelType } from 'features/nodes/types/common';
import { CLIP_SKIP_MAP } from 'features/parameters/types/constants';
import type {
ParameterCanvasCoherenceMode,
@@ -41,9 +42,11 @@ import type {
ParameterT5EncoderModel,
ParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
import { getExternalPanelControl, hasExternalPanelControl } from 'features/parameters/util/externalPanelSchema';
import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import type { AnyModelConfigWithExternal } from 'services/api/types';
import { isExternalApiModelConfig, isNonRefinerMainModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
const slice = createSlice({
@@ -360,7 +363,7 @@ const slice = createSlice({
//#region Dimensions
sizeRecalled: (state, action: PayloadAction<{ width: number; height: number }>) => {
const { width, height } = action.payload;
const gridSize = getGridSize(state.model?.base);
const gridSize = getGridSize(state.model?.base as BaseModelType | undefined);
state.dimensions.width = Math.max(roundDownToMultiple(width, gridSize), 64);
state.dimensions.height = Math.max(roundDownToMultiple(height, gridSize), 64);
state.dimensions.aspectRatio.value = state.dimensions.width / state.dimensions.height;
@@ -369,7 +372,7 @@ const slice = createSlice({
},
widthChanged: (state, action: PayloadAction<{ width: number; updateAspectRatio?: boolean; clamp?: boolean }>) => {
const { width, updateAspectRatio, clamp } = action.payload;
const gridSize = getGridSize(state.model?.base);
const gridSize = getGridSize(state.model?.base as BaseModelType | undefined);
state.dimensions.width = clamp ? Math.max(roundDownToMultiple(width, gridSize), 64) : width;
if (state.dimensions.aspectRatio.isLocked) {
@@ -387,7 +390,7 @@ const slice = createSlice({
},
heightChanged: (state, action: PayloadAction<{ height: number; updateAspectRatio?: boolean; clamp?: boolean }>) => {
const { height, updateAspectRatio, clamp } = action.payload;
const gridSize = getGridSize(state.model?.base);
const gridSize = getGridSize(state.model?.base as BaseModelType | undefined);
state.dimensions.height = clamp ? Math.max(roundDownToMultiple(height, gridSize), 64) : height;
if (state.dimensions.aspectRatio.isLocked) {
@@ -406,21 +409,30 @@ const slice = createSlice({
aspectRatioLockToggled: (state) => {
state.dimensions.aspectRatio.isLocked = !state.dimensions.aspectRatio.isLocked;
},
aspectRatioIdChanged: (state, action: PayloadAction<{ id: AspectRatioID }>) => {
const { id } = action.payload;
aspectRatioIdChanged: (
state,
action: PayloadAction<{ id: AspectRatioID; fixedSize?: { width: number; height: number } }>
) => {
const { id, fixedSize } = action.payload;
state.dimensions.aspectRatio.id = id;
if (id === 'Free') {
state.dimensions.aspectRatio.isLocked = false;
} else {
state.dimensions.aspectRatio.isLocked = true;
state.dimensions.aspectRatio.value = ASPECT_RATIO_MAP[id].ratio;
const { width, height } = calculateNewSize(
state.dimensions.aspectRatio.value,
state.dimensions.width * state.dimensions.height,
state.model?.base
);
state.dimensions.width = width;
state.dimensions.height = height;
if (fixedSize) {
state.dimensions.aspectRatio.value = fixedSize.width / fixedSize.height;
state.dimensions.width = fixedSize.width;
state.dimensions.height = fixedSize.height;
} else {
state.dimensions.aspectRatio.value = ASPECT_RATIO_MAP[id].ratio;
const { width, height } = calculateNewSize(
state.dimensions.aspectRatio.value,
state.dimensions.width * state.dimensions.height,
state.model?.base as BaseModelType | undefined
);
state.dimensions.width = width;
state.dimensions.height = height;
}
}
},
dimensionsSwapped: (state) => {
@@ -434,7 +446,7 @@ const slice = createSlice({
const { width, height } = calculateNewSize(
state.dimensions.aspectRatio.value,
state.dimensions.width * state.dimensions.height,
state.model?.base
state.model?.base as BaseModelType | undefined
);
state.dimensions.width = width;
state.dimensions.height = height;
@@ -442,12 +454,12 @@ const slice = createSlice({
}
},
sizeOptimized: (state) => {
const optimalDimension = getOptimalDimension(state.model?.base);
const optimalDimension = getOptimalDimension(state.model?.base as BaseModelType | undefined);
if (state.dimensions.aspectRatio.isLocked) {
const { width, height } = calculateNewSize(
state.dimensions.aspectRatio.value,
optimalDimension * optimalDimension,
state.model?.base
state.model?.base as BaseModelType | undefined
);
state.dimensions.width = width;
state.dimensions.height = height;
@@ -458,18 +470,54 @@ const slice = createSlice({
}
},
syncedToOptimalDimension: (state) => {
const optimalDimension = getOptimalDimension(state.model?.base);
const optimalDimension = getOptimalDimension(state.model?.base as BaseModelType | undefined);
if (!getIsSizeOptimal(state.dimensions.width, state.dimensions.height, state.model?.base)) {
if (
!getIsSizeOptimal(
state.dimensions.width,
state.dimensions.height,
state.model?.base as BaseModelType | undefined
)
) {
const bboxDims = calculateNewSize(
state.dimensions.aspectRatio.value,
optimalDimension * optimalDimension,
state.model?.base
state.model?.base as BaseModelType | undefined
);
state.dimensions.width = bboxDims.width;
state.dimensions.height = bboxDims.height;
}
},
// Plain setters for the external-provider parameters: each stores its payload on the matching state field.
imageSizeChanged: (state, { payload }: PayloadAction<string | null>) => {
  state.imageSize = payload;
},
openaiQualityChanged: (state, { payload }: PayloadAction<'auto' | 'high' | 'medium' | 'low'>) => {
  state.openaiQuality = payload;
},
openaiBackgroundChanged: (state, { payload }: PayloadAction<'auto' | 'transparent' | 'opaque'>) => {
  state.openaiBackground = payload;
},
openaiInputFidelityChanged: (state, { payload }: PayloadAction<'low' | 'high' | null>) => {
  state.openaiInputFidelity = payload;
},
geminiTemperatureChanged: (state, { payload }: PayloadAction<number | null>) => {
  state.geminiTemperature = payload;
},
geminiThinkingLevelChanged: (state, { payload }: PayloadAction<'minimal' | 'high' | null>) => {
  state.geminiThinkingLevel = payload;
},
// Applies a resolution preset: stores the preset's image size label and pins the dimensions and
// aspect ratio (locked) to the preset's exact width and height.
// NOTE(review): `aspectRatio` is cast to AspectRatioID without validation - confirm presets only carry known IDs.
resolutionPresetSelected: (
  state,
  action: PayloadAction<{ imageSize: string; aspectRatio: string; width: number; height: number }>
) => {
  const preset = action.payload;
  const { dimensions } = state;
  state.imageSize = preset.imageSize;
  dimensions.width = preset.width;
  dimensions.height = preset.height;
  dimensions.aspectRatio.id = preset.aspectRatio as AspectRatioID;
  dimensions.aspectRatio.value = preset.width / preset.height;
  dimensions.aspectRatio.isLocked = true;
},
paramsReset: (state) => resetState(state),
paramsRecalled: (_state, action: PayloadAction<ParamsState>) => {
return action.payload;
@@ -502,6 +550,9 @@ const hasModelClipSkip = (model: ParameterModel | null) => {
};
const getModelMaxClipSkip = (model: ParameterModel) => {
if (model.base === 'external') {
return undefined;
}
if (model.base === 'sdxl') {
// We don't support user-defined CLIP skip for SDXL because it doesn't do anything useful
return 0;
@@ -611,7 +662,13 @@ export const {
sizeOptimized,
syncedToOptimalDimension,
resolutionPresetSelected,
paramsReset,
openaiQualityChanged,
openaiBackgroundChanged,
openaiInputFidelityChanged,
geminiTemperatureChanged,
geminiThinkingLevelChanged,
paramsRecalled,
animaVaeModelSelected,
animaQwen3EncoderModelSelected,
@@ -656,6 +713,7 @@ export const selectIsCogView4 = createParamsSelector((params) => params.model?.b
export const selectIsZImage = createParamsSelector((params) => params.model?.base === 'z-image');
export const selectIsAnima = createParamsSelector((params) => params.model?.base === 'anima');
export const selectIsFlux2 = createParamsSelector((params) => params.model?.base === 'flux2');
export const selectIsExternal = createParamsSelector((params) => params.model?.base === 'external');
export const selectIsQwenImage = createParamsSelector((params) => params.model?.base === 'qwen-image');
export const selectIsFluxKontext = createParamsSelector((params) => {
if (params.model?.base === 'flux' && params.model?.name.toLowerCase().includes('kontext')) {
@@ -708,19 +766,91 @@ export const selectOptimizedDenoisingEnabled = createParamsSelector((params) =>
export const selectPositivePrompt = createParamsSelector((params) => params.positivePrompt);
export const selectNegativePrompt = createParamsSelector((params) => params.negativePrompt);
export const selectNegativePromptWithFallback = createParamsSelector((params) => params.negativePrompt ?? '');
/**
 * Resolves the full config of the currently selected model from the model configs query.
 * Yields `null` when the query has no data, no model is selected, or the selected key has no config.
 */
export const selectModelConfig = createSelector(
  selectModelConfigsQuery,
  selectParamsSlice,
  (modelConfigs, { model }) => {
    if (!modelConfigs.data || !model) {
      return null;
    }
    const match = modelConfigsAdapterSelectors.selectById(modelConfigs.data, model.key) as
      | AnyModelConfigWithExternal
      | undefined;
    return match ?? null;
  }
);
export const selectHasNegativePrompt = createParamsSelector((params) => params.negativePrompt !== null);
export const selectModelSupportsNegativePrompt = createSelector(
selectModel,
(model) => !!model && SUPPORTS_NEGATIVE_PROMPT_BASE_MODELS.includes(model.base)
);
export const selectModelSupportsRefImages = createSelector(
selectModel,
(model) => !!model && SUPPORTS_REF_IMAGES_BASE_MODELS.includes(model.base)
);
/** Negative prompts need a selected, non-external model whose base is in the supported list. */
export const selectModelSupportsNegativePrompt = createSelector(
  selectModel,
  (model) => !!model && model.base !== 'external' && SUPPORTS_NEGATIVE_PROMPT_BASE_MODELS.includes(model.base)
);
/**
 * Whether reference images are available for the selected model.
 * External API configs are decided by their `prompts` / `reference_images` panel control;
 * every other model is decided by its base, with a bare 'external' base counting as unsupported.
 */
export const selectModelSupportsRefImages = createSelector(
  selectModel,
  selectModelConfig,
  (model, modelConfig) => {
    if (!model) {
      return false;
    }
    if (modelConfig && isExternalApiModelConfig(modelConfig)) {
      return hasExternalPanelControl(modelConfig, 'prompts', 'reference_images');
    }
    return model.base !== 'external' && SUPPORTS_REF_IMAGES_BASE_MODELS.includes(model.base);
  }
);
export const selectModelSupportsOptimizedDenoising = createSelector(
selectModel,
(model) => !!model && SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS.includes(model.base)
(model) => !!model && model.base !== 'external' && SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS.includes(model.base)
);
/** Guidance is available for any selected model except those with the 'external' base. */
export const selectModelSupportsGuidance = createSelector(
  selectModel,
  (model) => !!model && model.base !== 'external'
);
/**
 * Whether a seed control applies to the selected model.
 * External API configs defer to their `image` / `seed` panel control; all other selected models support it.
 */
export const selectModelSupportsSeed = createSelector(selectModel, selectModelConfig, (model, modelConfig) => {
  if (!model) {
    return false;
  }
  return modelConfig && isExternalApiModelConfig(modelConfig)
    ? hasExternalPanelControl(modelConfig, 'image', 'seed')
    : true;
});
/** Steps are available for any selected model except those with the 'external' base. */
export const selectModelSupportsSteps = createSelector(
  selectModel,
  (model) => !!model && model.base !== 'external'
);
/**
 * Whether dimension controls apply to the selected model.
 * External API configs defer to their `image` / `dimensions` panel control; all other selected models support them.
 */
export const selectModelSupportsDimensions = createSelector(selectModel, selectModelConfig, (model, modelConfig) => {
  if (!model) {
    return false;
  }
  return modelConfig && isExternalApiModelConfig(modelConfig)
    ? hasExternalPanelControl(modelConfig, 'image', 'dimensions')
    : true;
});
/** The `image` / `seed` panel control of an external API model config, or `null` for any other model. */
export const selectSeedControl = createSelector(selectModelConfig, (modelConfig) =>
  modelConfig && isExternalApiModelConfig(modelConfig) ? getExternalPanelControl(modelConfig, 'image', 'seed') : null
);
export const selectScheduler = createParamsSelector((params) => params.scheduler);
export const selectFluxScheduler = createParamsSelector((params) => params.fluxScheduler);
export const selectFluxDypePreset = createParamsSelector((params) => params.fluxDypePreset);
@@ -764,24 +894,52 @@ export const selectHeight = createParamsSelector((params) => params.dimensions.h
export const selectAspectRatioID = createParamsSelector((params) => params.dimensions.aspectRatio.id);
export const selectAspectRatioValue = createParamsSelector((params) => params.dimensions.aspectRatio.value);
export const selectAspectRatioIsLocked = createParamsSelector((params) => params.dimensions.aspectRatio.isLocked);
/**
 * The aspect ratios the selected external API model allows (`capabilities.allowed_aspect_ratios`).
 *
 * Returns `null` when the config is missing, is not an external API config, or the list is
 * absent or empty.
 */
export const selectAllowedAspectRatioIDs = createSelector(selectModelConfig, (modelConfig) => {
  if (modelConfig && isExternalApiModelConfig(modelConfig)) {
    const ratios = modelConfig.capabilities.allowed_aspect_ratios;
    if (ratios && ratios.length) {
      return ratios;
    }
  }
  return null;
});
/**
 * `capabilities.aspect_ratio_sizes` of the selected external API model config, or `null` when
 * the config is missing, is not an external API config, or does not define the field.
 */
export const selectAspectRatioSizes = createSelector(selectModelConfig, (modelConfig) =>
  modelConfig && isExternalApiModelConfig(modelConfig) ? (modelConfig.capabilities.aspect_ratio_sizes ?? null) : null
);
/**
 * `capabilities.resolution_presets` of the selected external API model config, or `null` when
 * the config is missing, is not an external API config, or does not define the field.
 */
export const selectResolutionPresets = createSelector(selectModelConfig, (modelConfig) =>
  modelConfig && isExternalApiModelConfig(modelConfig) ? (modelConfig.capabilities.resolution_presets ?? null) : null
);
/**
 * `true` when the selected model exposes per-aspect-ratio sizes (non-null) or at least one
 * resolution preset.
 */
export const selectHasFixedDimensionSizes = createSelector(
  selectAspectRatioSizes,
  selectResolutionPresets,
  (sizes, presets) => {
    if (sizes !== null) {
      return true;
    }
    return presets !== null && presets.length > 0;
  }
);
// Selectors for the external-model options stored on the params state.
// `imageSize` is a nullable string in the params schema.
export const selectImageSize = createParamsSelector((params) => params.imageSize);
// OpenAI-specific external options.
export const selectOpenaiQuality = createParamsSelector((params) => params.openaiQuality);
export const selectOpenaiBackground = createParamsSelector((params) => params.openaiBackground);
export const selectOpenaiInputFidelity = createParamsSelector((params) => params.openaiInputFidelity);
// Gemini-specific external options.
export const selectGeminiTemperature = createParamsSelector((params) => params.geminiTemperature);
export const selectGeminiThinkingLevel = createParamsSelector((params) => params.geminiThinkingLevel);
/**
 * The `provider_id` of the selected external API model config, or `null` when the selected
 * config is missing or is not an external API config.
 */
export const selectExternalProviderId = createSelector(selectModelConfig, (modelConfig) => {
  if (!modelConfig || !isExternalApiModelConfig(modelConfig)) {
    return null;
  }
  return modelConfig.provider_id;
});
export const selectMainModelConfig = createSelector(
selectModelConfigsQuery,
selectParamsSlice,
(modelConfigs, { model }) => {
if (!modelConfigs.data) {
return null;
}
if (!model) {
return null;
}
const modelConfig = modelConfigsAdapterSelectors.selectById(modelConfigs.data, model.key);
if (!modelConfig) {
return null;
}
if (!isNonRefinerMainModelConfig(modelConfig)) {
return null;
}
export const selectMainModelConfig = createSelector(selectModelConfig, (modelConfig) => {
if (!modelConfig) {
return null;
}
if (isExternalApiModelConfig(modelConfig)) {
return modelConfig;
}
);
if (!isNonRefinerMainModelConfig(modelConfig)) {
return null;
}
return modelConfig;
});

View File

@@ -13,6 +13,7 @@ import type {
CanvasRegionalGuidanceState,
CanvasState,
} from 'features/controlLayers/store/types';
import type { BaseModelType } from 'features/nodes/types/common';
import { getGridSize, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
@@ -74,7 +75,7 @@ export const selectHasEntities = createSelector(selectEntityCountAll, (count) =>
* Selects the optimal dimension for the canvas based on the currently-selected model
*/
export const selectOptimalDimension = createSelector(selectParamsSlice, (params): number => {
const modelBase = params.model?.base;
const modelBase = params.model?.base as BaseModelType | undefined;
return getOptimalDimension(modelBase ?? null);
});
@@ -82,7 +83,7 @@ export const selectOptimalDimension = createSelector(selectParamsSlice, (params)
* Selects the grid size for the canvas based on the currently-selected model
*/
export const selectGridSize = createSelector(selectParamsSlice, (params): number => {
const modelBase = params.model?.base;
const modelBase = params.model?.base as BaseModelType | undefined;
return getGridSize(modelBase ?? null);
});

View File

@@ -663,19 +663,42 @@ export const zLoRA = z.object({
});
export type LoRA = z.infer<typeof zLoRA>;
export const zAspectRatioID = z.enum(['Free', '21:9', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16', '9:21']);
// All selectable aspect ratio IDs. 'Free' is the only ID without an entry in ASPECT_RATIO_MAP;
// the remaining IDs are listed from widest (8:1) to tallest (1:8).
export const zAspectRatioID = z.enum([
'Free',
'8:1',
'4:1',
'21:9',
'16:9',
'3:2',
'5:4',
'4:3',
'1:1',
'3:4',
'4:5',
'2:3',
'9:16',
'1:4',
'9:21',
'1:8',
]);
export type AspectRatioID = z.infer<typeof zAspectRatioID>;
export const isAspectRatioID = (v: unknown): v is AspectRatioID => zAspectRatioID.safeParse(v).success;
/**
 * Lookup table for every fixed aspect ratio (all IDs except 'Free').
 *
 * - `ratio`: the numeric value of the ID (first number divided by the second).
 * - `inverseID`: the ID whose two numbers are swapped. '1:1' is its own inverse; every other
 *   entry points at its mirrored counterpart, so following `inverseID` twice returns the
 *   original ID.
 */
export const ASPECT_RATIO_MAP: Record<Exclude<AspectRatioID, 'Free'>, { ratio: number; inverseID: AspectRatioID }> = {
  '8:1': { ratio: 8 / 1, inverseID: '1:8' },
  '4:1': { ratio: 4 / 1, inverseID: '1:4' },
  '21:9': { ratio: 21 / 9, inverseID: '9:21' },
  '16:9': { ratio: 16 / 9, inverseID: '9:16' },
  '3:2': { ratio: 3 / 2, inverseID: '2:3' },
  '5:4': { ratio: 5 / 4, inverseID: '4:5' },
  // The inverse of 4:3 is 3:4 (this entry previously pointed at itself).
  '4:3': { ratio: 4 / 3, inverseID: '3:4' },
  '1:1': { ratio: 1, inverseID: '1:1' },
  '3:4': { ratio: 3 / 4, inverseID: '4:3' },
  '4:5': { ratio: 4 / 5, inverseID: '5:4' },
  '2:3': { ratio: 2 / 3, inverseID: '3:2' },
  '9:16': { ratio: 9 / 16, inverseID: '16:9' },
  '1:4': { ratio: 1 / 4, inverseID: '4:1' },
  '9:21': { ratio: 9 / 21, inverseID: '21:9' },
  '1:8': { ratio: 1 / 8, inverseID: '8:1' },
};
const zAspectRatioConfig = z.object({
@@ -794,6 +817,14 @@ export const zParamsState = z.object({
zImageSeedVarianceEnabled: z.boolean(),
zImageSeedVarianceStrength: z.number().min(0).max(2),
zImageSeedVarianceRandomizePercent: z.number().min(1).max(100),
imageSize: z.string().nullable().default(null),
// OpenAI-specific external options
openaiQuality: z.enum(['auto', 'high', 'medium', 'low']).default('auto'),
openaiBackground: z.enum(['auto', 'transparent', 'opaque']).default('auto'),
openaiInputFidelity: z.enum(['low', 'high']).nullable().default(null),
// Gemini-specific external options
geminiTemperature: z.number().min(0).max(2).nullable().default(null),
geminiThinkingLevel: z.enum(['minimal', 'high']).nullable().default(null),
dimensions: zDimensionsState,
});
export type ParamsState = z.infer<typeof zParamsState>;
@@ -865,6 +896,12 @@ export const getInitialParamsState = (): ParamsState => ({
zImageSeedVarianceEnabled: false,
zImageSeedVarianceStrength: 0.1,
zImageSeedVarianceRandomizePercent: 50,
imageSize: null,
openaiQuality: 'auto',
openaiBackground: 'auto',
openaiInputFidelity: null,
geminiTemperature: null,
geminiThinkingLevel: null,
dimensions: {
width: 512,
height: 512,

View File

@@ -6,7 +6,11 @@ import type {
RefImageState,
} from 'features/controlLayers/store/types';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import type { AnyModelConfig, MainModelConfig } from 'services/api/types';
import {
type AnyModelConfigWithExternal,
isExternalApiModelConfig,
type MainOrExternalModelConfig,
} from 'services/api/types';
const WARNINGS = {
UNSUPPORTED_MODEL: 'controlLayers.warnings.unsupportedModel',
@@ -28,7 +32,7 @@ type WarningTKey = (typeof WARNINGS)[keyof typeof WARNINGS];
export const getRegionalGuidanceWarnings = (
entity: CanvasRegionalGuidanceState,
model: MainModelConfig | null | undefined
model: MainOrExternalModelConfig | null | undefined
): WarningTKey[] => {
const warnings: WarningTKey[] = [];
@@ -100,8 +104,8 @@ export const getRegionalGuidanceWarnings = (
};
export const areBasesCompatibleForRefImage = (
first?: ModelIdentifierField | AnyModelConfig | null,
second?: ModelIdentifierField | AnyModelConfig | null
first?: ModelIdentifierField | AnyModelConfigWithExternal | null,
second?: ModelIdentifierField | AnyModelConfigWithExternal | null
): boolean => {
if (!first || !second) {
return false;
@@ -122,11 +126,19 @@ export const areBasesCompatibleForRefImage = (
export const getGlobalReferenceImageWarnings = (
entity: RefImageState,
model: MainModelConfig | null | undefined
model: MainOrExternalModelConfig | null | undefined
): WarningTKey[] => {
const warnings: WarningTKey[] = [];
if (model) {
if (isExternalApiModelConfig(model)) {
if (!entity.config.image) {
// No image selected
warnings.push(WARNINGS.IP_ADAPTER_NO_IMAGE_SELECTED);
}
return warnings;
}
if (model.base === 'sd-3' || model.base === 'sd-2' || model.base === 'anima') {
// Unsupported model architecture
warnings.push(WARNINGS.UNSUPPORTED_MODEL);
@@ -159,7 +171,7 @@ export const getGlobalReferenceImageWarnings = (
export const getControlLayerWarnings = (
entity: CanvasControlLayerState,
model: MainModelConfig | null | undefined
model: MainOrExternalModelConfig | null | undefined
): WarningTKey[] => {
const warnings: WarningTKey[] = [];
@@ -193,7 +205,7 @@ export const getControlLayerWarnings = (
export const getRasterLayerWarnings = (
_entity: CanvasRasterLayerState,
_model: MainModelConfig | null | undefined
_model: MainOrExternalModelConfig | null | undefined
): WarningTKey[] => {
const warnings: WarningTKey[] = [];
@@ -204,7 +216,7 @@ export const getRasterLayerWarnings = (
export const getInpaintMaskWarnings = (
_entity: CanvasInpaintMaskState,
_model: MainModelConfig | null | undefined
_model: MainOrExternalModelConfig | null | undefined
): WarningTKey[] => {
const warnings: WarningTKey[] = [];

View File

@@ -1186,7 +1186,8 @@ const LoRAs: CollectionMetadataHandler<LoRA[]> = {
const key = getProperty(rawItem, 'lora.key');
assert(isString(key));
// No need to catch here - if this throws, we move on to the next item
identifier = await getModelIdentiferFromKey(key, store);
const modelConfig = await getModelIdentiferFromKey(key, store);
identifier = zModelIdentifierField.parse(modelConfig);
}
assert(identifier.type === 'lora');

View File

@@ -1,10 +1,11 @@
import type { AnyModelVariant, BaseModelType, ModelFormat, ModelType } from 'features/nodes/types/common';
import type { AnyModelConfig } from 'services/api/types';
import {
type AnyModelConfig,
isCLIPEmbedModelConfig,
isCLIPVisionModelConfig,
isControlLoRAModelConfig,
isControlNetModelConfig,
isExternalApiModelConfig,
isFluxReduxModelConfig,
isIPAdapterModelConfig,
isLLaVAModelConfig,
@@ -121,6 +122,11 @@ export const MODEL_CATEGORIES: Record<ModelCategoryType, ModelCategoryData> = {
i18nKey: 'modelManager.llavaOnevision',
filter: isLLaVAModelConfig,
},
external_image_generator: {
category: 'external_image_generator',
i18nKey: 'modelManager.externalImageGenerator',
filter: isExternalApiModelConfig,
},
};
export const MODEL_CATEGORIES_AS_LIST = objectEntries(MODEL_CATEGORIES).map(([category, { i18nKey, filter }]) => ({
@@ -144,6 +150,7 @@ export const MODEL_BASE_TO_COLOR: Record<BaseModelType, string> = {
cogview4: 'red',
'qwen-image': 'orange',
'z-image': 'cyan',
external: 'orange',
anima: 'invokePurple',
unknown: 'red',
};
@@ -169,6 +176,7 @@ export const MODEL_TYPE_TO_LONG_NAME: Record<ModelType, string> = {
clip_embed: 'CLIP Embed',
siglip: 'SigLIP',
flux_redux: 'FLUX Redux',
external_image_generator: 'External Image Generator',
unknown: 'Unknown',
};
@@ -187,6 +195,7 @@ export const MODEL_BASE_TO_LONG_NAME: Record<BaseModelType, string> = {
cogview4: 'CogView4',
'qwen-image': 'Qwen Image',
'z-image': 'Z-Image',
external: 'External',
anima: 'Anima',
unknown: 'Unknown',
};
@@ -206,6 +215,7 @@ export const MODEL_BASE_TO_SHORT_NAME: Record<BaseModelType, string> = {
cogview4: 'CogView4',
'qwen-image': 'QwenImg',
'z-image': 'Z-Image',
external: 'External',
anima: 'Anima',
unknown: 'Unknown',
};
@@ -237,6 +247,7 @@ export const MODEL_FORMAT_TO_LONG_NAME: Record<ModelFormat, string> = {
checkpoint: 'Checkpoint',
lycoris: 'LyCORIS',
onnx: 'ONNX',
external_api: 'External API',
olive: 'Olive',
embedding_file: 'Embedding (file)',
embedding_folder: 'Embedding (folder)',

View File

@@ -1,13 +1,14 @@
import { atom } from 'nanostores';
type InstallModelsTabName = 'launchpad' | 'urlOrLocal' | 'huggingface' | 'scanFolder' | 'starterModels';
type InstallModelsTabName = 'launchpad' | 'urlOrLocal' | 'huggingface' | 'external' | 'scanFolder' | 'starterModels';
const TAB_TO_INDEX_MAP: Record<InstallModelsTabName, number> = {
launchpad: 0,
urlOrLocal: 1,
huggingface: 2,
scanFolder: 3,
starterModels: 4,
external: 3,
scanFolder: 4,
starterModels: 5,
};
export const setInstallModelsTabByName = (tab: InstallModelsTabName) => {

View File

@@ -7,7 +7,10 @@ import { zModelType } from 'features/nodes/types/common';
import { assert } from 'tsafe';
import z from 'zod';
const zModelCategoryType = zModelType.exclude(['onnx']).or(z.literal('refiner'));
const zModelCategoryType = zModelType
.exclude(['onnx'])
.or(z.literal('refiner'))
.or(z.literal('external_image_generator'));
export type ModelCategoryType = z.infer<typeof zModelCategoryType>;
const zFilterableModelType = zModelCategoryType.or(z.literal('missing'));

View File

@@ -0,0 +1,267 @@
import {
Badge,
Button,
Card,
Flex,
FormControl,
FormHelperText,
FormLabel,
Heading,
Input,
Text,
Tooltip,
} from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { useBuildModelInstallArg } from 'features/modelManagerV2/hooks/useBuildModelsToInstall';
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
import { $installModelsTabIndex } from 'features/modelManagerV2/store/installModelsStore';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCheckBold, PiWarningBold } from 'react-icons/pi';
import {
useGetExternalProviderConfigsQuery,
useResetExternalProviderConfigMutation,
useSetExternalProviderConfigMutation,
} from 'services/api/endpoints/appInfo';
import { useGetStarterModelsQuery } from 'services/api/endpoints/models';
import type { ExternalProviderConfig, StarterModel } from 'services/api/types';
// Preferred display order for provider cards. Providers not listed here are shown after
// these, ordered by `provider_id`.
const PROVIDER_SORT_ORDER = ['gemini', 'openai', 'alibabacloud'];
// Props for a single provider card.
type ProviderCardProps = {
provider: ExternalProviderConfig;
// Called with the provider id after a save whose result reports the API key as configured.
onInstallModels: (providerId: string) => void;
};
// Argument passed to the set-provider-config mutation. `api_key` is only set when a non-empty
// key was entered; `base_url` is only set when it differs from the provider's current value.
type UpdatePayload = {
provider_id: string;
api_key?: string;
base_url?: string;
};
/**
 * Install-tab panel listing one configuration card per external provider.
 *
 * Provider configs come from `useGetExternalProviderConfigsQuery`. Starter models whose source
 * starts with `external://` are grouped by the provider id found in that source, so that a card
 * can request installation of all not-yet-installed models of its provider.
 */
export const ExternalProvidersForm = memo(() => {
  const { t } = useTranslation();
  const { data, isLoading } = useGetExternalProviderConfigsQuery();
  const { data: starterModels } = useGetStarterModelsQuery();
  const [installModel] = useInstallModel();
  const { getIsInstalled, buildModelInstallArg } = useBuildModelInstallArg();
  const tabIndex = useStore($installModelsTabIndex);

  // provider id -> that provider's external starter models, sorted by name.
  const externalModelsByProvider = useMemo(() => {
    const byProvider = new Map<string, StarterModel[]>();
    (starterModels?.starter_models ?? []).forEach((starterModel) => {
      const { source } = starterModel;
      if (!source.startsWith('external://')) {
        return;
      }
      // Source looks like `external://<providerId>/...`; the first path segment is the provider.
      const [providerId] = source.replace('external://', '').split('/');
      if (!providerId) {
        return;
      }
      const bucket = byProvider.get(providerId);
      if (bucket) {
        bucket.push(starterModel);
      } else {
        byProvider.set(providerId, [starterModel]);
      }
    });
    byProvider.forEach((bucket) => {
      bucket.sort((a, b) => a.name.localeCompare(b.name));
    });
    return byProvider;
  }, [starterModels]);

  // Queue an install for every model of the provider that is not installed yet.
  const handleInstallProviderModels = useCallback(
    (providerId: string) => {
      const providerModels = externalModelsByProvider.get(providerId);
      if (!providerModels || providerModels.length === 0) {
        return;
      }
      const installArgs = providerModels
        .filter((candidate) => !getIsInstalled(candidate))
        .map(buildModelInstallArg);
      for (const installArg of installArgs) {
        installModel(installArg);
      }
    },
    [buildModelInstallArg, externalModelsByProvider, getIsInstalled, installModel]
  );

  // Providers named in PROVIDER_SORT_ORDER come first in that order; the rest follow by id.
  const sortedProviders = useMemo(() => {
    if (!data) {
      return [];
    }
    return Array.from(data).sort((left, right) => {
      const leftRank = PROVIDER_SORT_ORDER.indexOf(left.provider_id);
      const rightRank = PROVIDER_SORT_ORDER.indexOf(right.provider_id);
      const leftUnranked = leftRank === -1;
      const rightUnranked = rightRank === -1;
      if (leftUnranked && rightUnranked) {
        return left.provider_id.localeCompare(right.provider_id);
      }
      if (leftUnranked) {
        return 1;
      }
      if (rightUnranked) {
        return -1;
      }
      return leftRank - rightRank;
    });
  }, [data]);

  const showEmptyState = !isLoading && sortedProviders.length === 0;

  return (
    <Flex flexDir="column" height="100%" gap={4}>
      <Flex flexDir="column" gap={1}>
        <Heading size="sm">{t('modelManager.externalSetupTitle')}</Heading>
        <Text color="base.300">{t('modelManager.externalSetupDescription')}</Text>
      </Flex>
      <ScrollableContent>
        <Flex flexDir="column" gap={4}>
          {isLoading && <Text color="base.300">{t('common.loading')}</Text>}
          {showEmptyState && <Text color="base.300">{t('modelManager.externalProvidersUnavailable')}</Text>}
          {sortedProviders.map((provider) => (
            <ProviderCard
              key={provider.provider_id}
              provider={provider}
              onInstallModels={handleInstallProviderModels}
            />
          ))}
        </Flex>
      </ScrollableContent>
      {tabIndex === 3 && (
        <Text variant="subtext" color="base.400">
          {t('modelManager.externalSetupFooter')}
        </Text>
      )}
    </Flex>
  );
});
ExternalProvidersForm.displayName = 'ExternalProvidersForm';
/**
 * Configuration card for one external provider: API key and base URL inputs plus
 * save / reset actions.
 *
 * The API key input always starts empty and is cleared after a save whose result reports the
 * key as configured, and after a reset. The base URL input mirrors `provider.base_url` and is
 * re-synced whenever that prop changes.
 */
const ProviderCard = memo(({ provider, onInstallModels }: ProviderCardProps) => {
  const { t } = useTranslation();
  const [apiKey, setApiKey] = useState('');
  const [baseUrl, setBaseUrl] = useState(provider.base_url ?? '');
  const [saveConfig, { isLoading }] = useSetExternalProviderConfigMutation();
  const [resetConfig, { isLoading: isResetting }] = useResetExternalProviderConfigMutation();

  // Keep the local input in step with the provider's base URL.
  useEffect(() => {
    setBaseUrl(provider.base_url ?? '');
  }, [provider.base_url]);

  const handleSave = useCallback(() => {
    const nextApiKey = apiKey.trim();
    const nextBaseUrl = baseUrl.trim();
    const baseUrlChanged = nextBaseUrl !== (provider.base_url ?? '');
    // Nothing to send: no key entered and the base URL is unchanged.
    if (!nextApiKey && !baseUrlChanged) {
      return;
    }
    const payload: UpdatePayload = {
      provider_id: provider.provider_id,
      ...(nextApiKey ? { api_key: nextApiKey } : {}),
      ...(baseUrlChanged ? { base_url: nextBaseUrl } : {}),
    };
    saveConfig(payload)
      .unwrap()
      .then((saved) => {
        if (saved.api_key_configured) {
          setApiKey('');
          onInstallModels(provider.provider_id);
        }
        if (saved.base_url !== undefined) {
          setBaseUrl(saved.base_url ?? '');
        }
      });
  }, [apiKey, baseUrl, onInstallModels, provider.base_url, provider.provider_id, saveConfig]);

  const handleReset = useCallback(() => {
    resetConfig(provider.provider_id)
      .unwrap()
      .then((restored) => {
        setApiKey('');
        setBaseUrl(restored.base_url ?? '');
      });
  }, [provider.provider_id, resetConfig]);

  const handleApiKeyChange = useCallback((event: ChangeEvent<HTMLInputElement>) => {
    setApiKey(event.target.value);
  }, []);

  const handleBaseUrlChange = useCallback((event: ChangeEvent<HTMLInputElement>) => {
    setBaseUrl(event.target.value);
  }, []);

  const isConfigured = provider.api_key_configured;
  const statusBadge = (
    <Badge colorScheme={isConfigured ? 'green' : 'warning'} display="flex" alignItems="center" gap={2}>
      {isConfigured ? <PiCheckBold /> : <PiWarningBold />}
      {isConfigured ? t('settings.externalProviderConfigured') : t('settings.externalProviderNotConfigured')}
    </Badge>
  );

  const apiKeyPlaceholder = isConfigured
    ? t('modelManager.externalApiKeyPlaceholderSet')
    : t('modelManager.externalApiKeyPlaceholder');

  // Save is only enabled when there is a key to send or the base URL was edited.
  const hasNothingToSave = !apiKey.trim() && baseUrl.trim() === (provider.base_url ?? '');

  return (
    <Card p={4} gap={4} variant="outline">
      <Flex justifyContent="space-between" alignItems="center" flexWrap="wrap" gap={3}>
        <Flex flexDir="column" gap={1}>
          <Heading size="xs" textTransform="capitalize">
            {provider.provider_id}
          </Heading>
          <Text variant="subtext">
            {t('modelManager.externalProviderCardDescription', { providerId: provider.provider_id })}
          </Text>
        </Flex>
        {statusBadge}
      </Flex>
      <Flex flexDir="column" gap={4}>
        <FormControl>
          <FormLabel>{t('modelManager.externalApiKey')}</FormLabel>
          <Input
            type="password"
            autoComplete="off"
            placeholder={apiKeyPlaceholder}
            value={apiKey}
            onChange={handleApiKeyChange}
          />
          <FormHelperText>{t('modelManager.externalApiKeyHelper')}</FormHelperText>
        </FormControl>
        <FormControl>
          <FormLabel>{t('modelManager.externalBaseUrl')}</FormLabel>
          <Input
            placeholder={t('modelManager.externalBaseUrlPlaceholder')}
            value={baseUrl}
            onChange={handleBaseUrlChange}
          />
          <FormHelperText>{t('modelManager.externalBaseUrlHelper')}</FormHelperText>
        </FormControl>
        <Flex gap={2} justifyContent="flex-end" flexWrap="wrap">
          <Tooltip label={t('modelManager.externalResetHelper')}>
            <Button variant="ghost" onClick={handleReset} isLoading={isResetting}>
              {t('common.reset')}
            </Button>
          </Tooltip>
          <Button colorScheme="invokeYellow" onClick={handleSave} isLoading={isLoading} isDisabled={hasNothingToSave}>
            {t('common.save')}
          </Button>
        </Flex>
      </Flex>
    </Card>
  );
});
ProviderCard.displayName = 'ProviderCard';

View File

@@ -6,7 +6,7 @@ import { StarterBundleButton } from 'features/modelManagerV2/subpanels/AddModelP
import { StarterBundleTooltipContentCompact } from 'features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterBundleTooltipContentCompact';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFolderOpenBold, PiLinkBold, PiStarBold } from 'react-icons/pi';
import { PiFolderOpenBold, PiLinkBold, PiPlugBold, PiStarBold } from 'react-icons/pi';
import { SiHuggingface } from 'react-icons/si';
import { useGetStarterModelsQuery } from 'services/api/endpoints/models';
@@ -28,6 +28,10 @@ export const LaunchpadForm = memo(() => {
setInstallModelsTabByName('scanFolder');
}, []);
const navigateToExternalTab = useCallback(() => {
setInstallModelsTabByName('external');
}, []);
const navigateToStarterModelsTab = useCallback(() => {
setInstallModelsTabByName('starterModels');
}, []);
@@ -63,6 +67,12 @@ export const LaunchpadForm = memo(() => {
title={t('modelManager.scanFolder')}
description={t('modelManager.launchpad.scanFolderDescription')}
/>
<LaunchpadButton
onClick={navigateToExternalTab}
icon={PiPlugBold}
title={t('modelManager.externalProviders')}
description={t('modelManager.launchpad.externalDescription')}
/>
</Grid>
</Flex>
{/* Recommended Section */}

View File

@@ -285,6 +285,8 @@ export const ModelInstallQueueItem = memo((props: ModelListItemProps) => {
return installJob.source.url;
case 'local':
return installJob.source.path;
case 'external':
return `external://${installJob.source.provider_id}/${installJob.source.provider_model_id}`;
default:
return t('common.unknown');
}
@@ -292,6 +294,8 @@ export const ModelInstallQueueItem = memo((props: ModelListItemProps) => {
const displayStatus = optimisticStatus?.status ?? installJob.status;
const configuredName = installJob.config_in?.name;
const modelName = useMemo(() => {
switch (installJob.source.type) {
case 'hf': {
@@ -302,13 +306,15 @@ export const ModelInstallQueueItem = memo((props: ModelListItemProps) => {
return repo_id;
}
case 'url':
return installJob.source.url.split('/').slice(-1)[0] ?? t('common.unknown');
return configuredName ?? installJob.source.url.split('/').slice(-1)[0] ?? t('common.unknown');
case 'local':
return installJob.source.path.split(/[/\\]/).slice(-1)[0] ?? t('common.unknown');
return configuredName ?? installJob.source.path.split(/[/\\]/).slice(-1)[0] ?? t('common.unknown');
case 'external':
return configuredName ?? `${installJob.source.provider_id}/${installJob.source.provider_model_id}`;
default:
return t('common.unknown');
return configuredName ?? t('common.unknown');
}
}, [installJob.source, t]);
}, [configuredName, installJob.source, t]);
const progressValue = useMemo(() => {
if (displayStatus === 'completed' || displayStatus === 'error' || displayStatus === 'cancelled') {

View File

@@ -21,6 +21,9 @@ export const StarterModelsResults = memo(({ results }: StarterModelsResultsProps
const filteredResults = useMemo(() => {
return results.starter_models.filter((result) => {
if (result.source.startsWith('external://')) {
return false;
}
const trimmedSearchTerm = searchTerm.trim().toLowerCase();
const matchStrings = [
result.name.toLowerCase(),

View File

@@ -2,18 +2,18 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Divider, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { $installModelsTabIndex } from 'features/modelManagerV2/store/installModelsStore';
import { ExternalProvidersForm } from 'features/modelManagerV2/subpanels/AddModelPanel/ExternalProviders/ExternalProvidersForm';
import { HuggingFaceForm } from 'features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceForm';
import { InstallModelForm } from 'features/modelManagerV2/subpanels/AddModelPanel/InstallModelForm';
import { LaunchpadForm } from 'features/modelManagerV2/subpanels/AddModelPanel/LaunchpadForm/LaunchpadForm';
import { ModelInstallQueue } from 'features/modelManagerV2/subpanels/AddModelPanel/ModelInstallQueue/ModelInstallQueue';
import { ScanModelsForm } from 'features/modelManagerV2/subpanels/AddModelPanel/ScanFolder/ScanFolderForm';
import { StarterModelsForm } from 'features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterModelsForm';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCubeBold, PiFolderOpenBold, PiLinkSimpleBold, PiShootingStarBold } from 'react-icons/pi';
import { PiCubeBold, PiFolderOpenBold, PiLinkSimpleBold, PiPlugBold, PiShootingStarBold } from 'react-icons/pi';
import { SiHuggingface } from 'react-icons/si';
import { HuggingFaceForm } from './AddModelPanel/HuggingFaceFolder/HuggingFaceForm';
import { InstallModelForm } from './AddModelPanel/InstallModelForm';
import { LaunchpadForm } from './AddModelPanel/LaunchpadForm/LaunchpadForm';
import { ModelInstallQueue } from './AddModelPanel/ModelInstallQueue/ModelInstallQueue';
import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';
const installModelsTabSx: SystemStyleObject = {
display: 'flex',
gap: 2,
@@ -61,6 +61,10 @@ export const InstallModels = memo(() => {
<SiHuggingface />
{t('modelManager.huggingFace')}
</Tab>
<Tab sx={installModelsTabSx}>
<PiPlugBold />
{t('modelManager.externalProviders')}
</Tab>
<Tab sx={installModelsTabSx}>
<PiFolderOpenBold />
{t('modelManager.scanFolder')}
@@ -80,6 +84,9 @@ export const InstallModels = memo(() => {
<TabPanel height="100%">
<HuggingFaceForm />
</TabPanel>
<TabPanel height="100%">
<ExternalProvidersForm />
</TabPanel>
<TabPanel height="100%">
<ScanModelsForm />
</TabPanel>

View File

@@ -19,6 +19,7 @@ const FORMAT_NAME_MAP: Record<ModelFormat, string> = {
bnb_quantized_nf4b: 'quantized',
gguf_quantized: 'gguf',
omi: 'omi',
external_api: 'external_api',
unknown: 'unknown',
olive: 'olive',
onnx: 'onnx',
@@ -40,6 +41,7 @@ const FORMAT_COLOR_MAP: Record<ModelFormat, string> = {
unknown: 'red',
olive: 'base',
onnx: 'base',
external_api: 'base',
};
const ModelFormatBadge = ({ format }: Props) => {

View File

@@ -5,7 +5,7 @@ import { MODEL_BASE_TO_LONG_NAME } from 'features/modelManagerV2/models';
import { useCallback, useMemo } from 'react';
import type { Control } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { UpdateModelArg } from 'services/api/endpoints/models';
import type { UpdateModelBody } from 'services/api/types';
import { objectEntries } from 'tsafe';
const options: ComboboxOption[] = objectEntries(MODEL_BASE_TO_LONG_NAME).map(([value, label]) => ({
@@ -14,7 +14,7 @@ const options: ComboboxOption[] = objectEntries(MODEL_BASE_TO_LONG_NAME).map(([v
}));
type Props = {
control: Control<UpdateModelArg['body']>;
control: Control<UpdateModelBody>;
};
const BaseModelSelect = ({ control }: Props) => {

View File

@@ -5,7 +5,7 @@ import { MODEL_FORMAT_TO_LONG_NAME } from 'features/modelManagerV2/models';
import { useCallback, useMemo } from 'react';
import type { Control } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { UpdateModelArg } from 'services/api/endpoints/models';
import type { UpdateModelBody } from 'services/api/types';
import { objectEntries } from 'tsafe';
const options: ComboboxOption[] = objectEntries(MODEL_FORMAT_TO_LONG_NAME).map(([value, label]) => ({
@@ -14,7 +14,7 @@ const options: ComboboxOption[] = objectEntries(MODEL_FORMAT_TO_LONG_NAME).map((
}));
type Props = {
control: Control<UpdateModelArg['body']>;
control: Control<UpdateModelBody>;
};
const ModelFormatSelect = ({ control }: Props) => {

View File

@@ -5,7 +5,7 @@ import { MODEL_TYPE_TO_LONG_NAME } from 'features/modelManagerV2/models';
import { useCallback, useMemo } from 'react';
import type { Control } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { UpdateModelArg } from 'services/api/endpoints/models';
import type { UpdateModelBody } from 'services/api/types';
import { objectEntries } from 'tsafe';
const options: ComboboxOption[] = objectEntries(MODEL_TYPE_TO_LONG_NAME).map(([value, label]) => ({
@@ -14,7 +14,7 @@ const options: ComboboxOption[] = objectEntries(MODEL_TYPE_TO_LONG_NAME).map(([v
}));
type Props = {
control: Control<UpdateModelArg['body']>;
control: Control<UpdateModelBody>;
};
const ModelTypeSelect = ({ control }: Props) => {

View File

@@ -5,13 +5,13 @@ import { MODEL_VARIANT_TO_LONG_NAME } from 'features/modelManagerV2/models';
import { useCallback, useMemo } from 'react';
import type { Control } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { UpdateModelArg } from 'services/api/endpoints/models';
import type { UpdateModelBody } from 'services/api/types';
import { objectEntries } from 'tsafe';
const options: ComboboxOption[] = objectEntries(MODEL_VARIANT_TO_LONG_NAME).map(([value, label]) => ({ label, value }));
type Props = {
control: Control<UpdateModelArg['body']>;
control: Control<UpdateModelBody>;
};
const ModelVariantSelect = ({ control }: Props) => {

View File

@@ -4,7 +4,7 @@ import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react';
import type { Control } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { UpdateModelArg } from 'services/api/endpoints/models';
import type { UpdateModelBody } from 'services/api/types';
const options: ComboboxOption[] = [
{ value: 'none', label: '-' },
@@ -14,7 +14,7 @@ const options: ComboboxOption[] = [
];
type Props = {
control: Control<UpdateModelArg['body']>;
control: Control<UpdateModelBody>;
};
const PredictionTypeSelect = ({ control }: Props) => {

View File

@@ -5,6 +5,7 @@ import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiExclamationMarkBold } from 'react-icons/pi';
import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models';
import type { AnyModelConfigWithExternal } from 'services/api/types';
import { ModelEdit } from './ModelEdit';
import { ModelView } from './ModelView';
@@ -21,7 +22,9 @@ export const Model = memo(() => {
if (selectedModelKey === null) {
return null;
}
const modelConfig = modelConfigsAdapterSelectors.selectById(modelConfigs, selectedModelKey);
const modelConfig = modelConfigsAdapterSelectors.selectById(modelConfigs, selectedModelKey) as
| AnyModelConfigWithExternal
| undefined;
if (!modelConfig) {
return null;

View File

@@ -7,11 +7,11 @@ import { memo, type MouseEvent, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold } from 'react-icons/pi';
import { useDeleteModelsMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import type { AnyModelConfigWithExternal } from 'services/api/types';
type Props = {
showLabel?: boolean;
modelConfig: AnyModelConfig;
modelConfig: AnyModelConfigWithExternal;
};
export const ModelDeleteButton = memo(({ showLabel = true, modelConfig }: Props) => {

View File

@@ -15,12 +15,18 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { ModelHeader } from 'features/modelManagerV2/subpanels/ModelPanel/ModelHeader';
import { toast } from 'features/toast/toast';
import { memo, useCallback } from 'react';
import { type SubmitHandler, useForm } from 'react-hook-form';
import { memo, useCallback, useMemo } from 'react';
import { type Control, type SubmitHandler, useForm, useWatch } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { PiCheckBold, PiXBold } from 'react-icons/pi';
import { type UpdateModelArg, useUpdateModelMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import {
type AnyModelConfigWithExternal,
type ExternalApiModelDefaultSettings,
type ExternalModelCapabilities,
isExternalApiModelConfig,
type UpdateModelBody,
} from 'services/api/types';
import BaseModelSelect from './Fields/BaseModelSelect';
import ModelFormatSelect from './Fields/ModelFormatSelect';
@@ -30,7 +36,14 @@ import PredictionTypeSelect from './Fields/PredictionTypeSelect';
import { ModelFooter } from './ModelFooter';
type Props = {
modelConfig: AnyModelConfig;
modelConfig: AnyModelConfigWithExternal;
};
type ModelEditFormValues = UpdateModelBody & {
capabilities?: ExternalModelCapabilities;
provider_id?: string;
provider_model_id?: string;
default_settings?: ExternalApiModelDefaultSettings | null;
};
const stringFieldOptions = {
@@ -41,19 +54,54 @@ export const ModelEdit = memo(({ modelConfig }: Props) => {
const { t } = useTranslation();
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation();
const dispatch = useAppDispatch();
const isExternal = useMemo(() => isExternalApiModelConfig(modelConfig), [modelConfig]);
const form = useForm<UpdateModelArg['body']>({
defaultValues: modelConfig,
const form = useForm<ModelEditFormValues>({
defaultValues: modelConfig as unknown as ModelEditFormValues,
mode: 'onChange',
});
const onSubmit = useCallback<SubmitHandler<UpdateModelArg['body']>>(
const externalModes = useWatch({
control: form.control,
name: 'capabilities.modes',
}) as ExternalModelCapabilities['modes'] | undefined;
const modeSet = useMemo(() => new Set(externalModes ?? []), [externalModes]);
// Flips a single generation mode in the form's `capabilities.modes` list:
// a mode that is currently enabled is removed, otherwise it is appended.
// The field is marked dirty and re-validated on every toggle.
const toggleMode = useCallback(
  (mode: ExternalModelCapabilities['modes'][number]) => {
    const currentModes = externalModes ?? [];
    const isEnabled = modeSet.has(mode);
    const updatedModes = isEnabled
      ? currentModes.filter((candidate) => candidate !== mode)
      : [...currentModes, mode];
    form.setValue('capabilities.modes', updatedModes, { shouldDirty: true, shouldValidate: true });
  },
  [externalModes, form, modeSet]
);
const handleToggleTxt2Img = useCallback(() => toggleMode('txt2img'), [toggleMode]);
const handleToggleImg2Img = useCallback(() => toggleMode('img2img'), [toggleMode]);
const handleToggleInpaint = useCallback(() => toggleMode('inpaint'), [toggleMode]);
/**
 * Converts a raw form field value into a number suitable for the request body.
 *
 * Used as the `setValueAs` transformer for the optional numeric inputs below.
 *
 * @param value - Raw field value. Typed as `unknown` because the body must cope
 *   with both strings and already-numeric values, not only strings.
 * @returns The parsed number, or `null` when the value is nullish, blank
 *   (empty or whitespace-only string), or does not parse to a finite number.
 */
const parseOptionalNumber = useCallback((value: unknown): number | null => {
  if (value === null || value === undefined) {
    return null;
  }
  // A blank input means "no value", not 0 (Number('') and Number('  ') are both 0).
  if (typeof value === 'string' && value.trim() === '') {
    return null;
  }
  const parsed = Number(value);
  // NaN and +/-Infinity cannot be represented in JSON, so treat them as "no value".
  return Number.isFinite(parsed) ? parsed : null;
}, []);
const onSubmit = useCallback<SubmitHandler<ModelEditFormValues>>(
(values) => {
const responseBody: UpdateModelArg = {
key: modelConfig.key,
body: values,
body: values as UpdateModelBody,
};
updateModel(responseBody)
.unwrap()
.then((payload) => {
@@ -133,19 +181,19 @@ export const ModelEdit = memo(({ modelConfig }: Props) => {
<SimpleGrid columns={2} gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.modelType')}</FormLabel>
<ModelTypeSelect control={form.control} />
<ModelTypeSelect control={form.control as unknown as Control<UpdateModelBody>} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.modelFormat')}</FormLabel>
<ModelFormatSelect control={form.control} />
<ModelFormatSelect control={form.control as unknown as Control<UpdateModelBody>} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
<BaseModelSelect control={form.control} />
<BaseModelSelect control={form.control as unknown as Control<UpdateModelBody>} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect control={form.control} />
<ModelVariantSelect control={form.control as unknown as Control<UpdateModelBody>} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
@@ -153,13 +201,143 @@ export const ModelEdit = memo(({ modelConfig }: Props) => {
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
<PredictionTypeSelect control={form.control} />
<PredictionTypeSelect control={form.control as unknown as Control<UpdateModelBody>} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
<Checkbox {...form.register('upcast_attention')} />
</FormControl>
</SimpleGrid>
{isExternal && (
<>
<Heading as="h3" fontSize="md" mt="4">
{t('modelManager.externalProvider')}
</Heading>
<SimpleGrid columns={2} gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.providerId')}</FormLabel>
<Input {...form.register('provider_id', stringFieldOptions)} size="md" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.providerModelId')}</FormLabel>
<Input {...form.register('provider_model_id', stringFieldOptions)} size="md" />
</FormControl>
</SimpleGrid>
<Heading as="h3" fontSize="md" mt="4">
{t('modelManager.externalCapabilities')}
</Heading>
<SimpleGrid columns={2} gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.supportedModes')}</FormLabel>
<Flex gap={3} wrap="wrap">
<Checkbox isChecked={modeSet.has('txt2img')} onChange={handleToggleTxt2Img}>
txt2img
</Checkbox>
<Checkbox isChecked={modeSet.has('img2img')} onChange={handleToggleImg2Img}>
img2img
</Checkbox>
<Checkbox isChecked={modeSet.has('inpaint')} onChange={handleToggleInpaint}>
inpaint
</Checkbox>
</Flex>
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.supportsReferenceImages')}</FormLabel>
<Checkbox {...form.register('capabilities.supports_reference_images')} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.supportsSeed')}</FormLabel>
<Checkbox {...form.register('capabilities.supports_seed')} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.maxImagesPerRequest')}</FormLabel>
<Input
type="number"
{...form.register('capabilities.max_images_per_request', {
setValueAs: parseOptionalNumber,
})}
/>
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.maxReferenceImages')}</FormLabel>
<Input
type="number"
{...form.register('capabilities.max_reference_images', {
setValueAs: parseOptionalNumber,
})}
/>
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.maxImageWidth')}</FormLabel>
<Input
type="number"
{...form.register('capabilities.max_image_size.width', {
setValueAs: parseOptionalNumber,
})}
/>
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.maxImageHeight')}</FormLabel>
<Input
type="number"
{...form.register('capabilities.max_image_size.height', {
setValueAs: parseOptionalNumber,
})}
/>
</FormControl>
</SimpleGrid>
<Heading as="h3" fontSize="md" mt="4">
{t('modelManager.externalDefaults')}
</Heading>
<SimpleGrid columns={2} gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.width')}</FormLabel>
<Input
type="number"
{...form.register('default_settings.width', {
setValueAs: parseOptionalNumber,
})}
/>
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.height')}</FormLabel>
<Input
type="number"
{...form.register('default_settings.height', {
setValueAs: parseOptionalNumber,
})}
/>
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('parameters.steps')}</FormLabel>
<Input
type="number"
{...form.register('default_settings.steps', {
setValueAs: parseOptionalNumber,
})}
/>
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('parameters.guidance')}</FormLabel>
<Input
type="number"
{...form.register('default_settings.guidance', {
setValueAs: parseOptionalNumber,
})}
/>
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.numImages')}</FormLabel>
<Input
type="number"
{...form.register('default_settings.num_images', {
setValueAs: parseOptionalNumber,
})}
/>
</FormControl>
</SimpleGrid>
</>
)}
</Flex>
</form>
</Flex>

View File

@@ -1,7 +1,7 @@
import { Flex, Heading, type SystemStyleObject } from '@invoke-ai/ui-library';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';
import type { AnyModelConfigWithExternal } from 'services/api/types';
import { ModelConvertButton } from './ModelConvertButton';
import { ModelDeleteButton } from './ModelDeleteButton';
@@ -20,7 +20,7 @@ const footerRowSx: SystemStyleObject = {
};
type Props = {
modelConfig: AnyModelConfig;
modelConfig: AnyModelConfigWithExternal;
isEditing: boolean;
};

View File

@@ -4,10 +4,10 @@ import ModelImageUpload from 'features/modelManagerV2/subpanels/ModelPanel/Field
import type { PropsWithChildren } from 'react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';
import type { AnyModelConfigWithExternal } from 'services/api/types';
type Props = PropsWithChildren<{
modelConfig: AnyModelConfig;
modelConfig: AnyModelConfigWithExternal;
}>;
export const ModelHeader = memo(({ modelConfig, children }: Props) => {

View File

@@ -4,15 +4,18 @@ import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiSparkleFill } from 'react-icons/pi';
import { useReidentifyModelMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import { type AnyModelConfigWithExternal, isExternalApiModelConfig } from 'services/api/types';
import { isExternalModel } from './isExternalModel';
interface Props {
modelConfig: AnyModelConfig;
modelConfig: AnyModelConfigWithExternal;
}
export const ModelReidentifyButton = memo(({ modelConfig }: Props) => {
const { t } = useTranslation();
const [reidentifyModel, { isLoading }] = useReidentifyModelMutation();
const isExternal = isExternalApiModelConfig(modelConfig) || isExternalModel(modelConfig.path);
const onClick = useCallback(() => {
reidentifyModel({ key: modelConfig.key })
@@ -40,6 +43,10 @@ export const ModelReidentifyButton = memo(({ modelConfig }: Props) => {
});
}, [modelConfig.key, reidentifyModel, t]);
if (isExternal) {
return null;
}
return (
<Button
onClick={onClick}

View File

@@ -3,13 +3,13 @@ import { toast } from 'features/toast/toast';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiDownloadSimpleBold } from 'react-icons/pi';
import type { AnyModelConfig } from 'services/api/types';
import type { AnyModelConfigWithExternal } from 'services/api/types';
type Props = {
modelConfig: AnyModelConfig;
modelConfig: AnyModelConfigWithExternal;
};
const buildExportData = (modelConfig: AnyModelConfig): Record<string, unknown> => {
const buildExportData = (modelConfig: AnyModelConfigWithExternal): Record<string, unknown> => {
const data: Record<string, unknown> = {};
if (

View File

@@ -5,7 +5,7 @@ import { memo, useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { PiUploadSimpleBold } from 'react-icons/pi';
import { useUpdateModelMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import type { AnyModelConfigWithExternal } from 'services/api/types';
const validateImportData = (data: unknown): data is Record<string, unknown> => {
if (typeof data !== 'object' || data === null || Array.isArray(data)) {
@@ -40,7 +40,7 @@ const validateImportData = (data: unknown): data is Record<string, unknown> => {
};
type Props = {
modelConfig: AnyModelConfig;
modelConfig: AnyModelConfigWithExternal;
};
export const ModelSettingsImportButton = memo(({ modelConfig }: Props) => {

View File

@@ -22,10 +22,10 @@ import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFolderOpenFill } from 'react-icons/pi';
import { useUpdateModelMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import type { AnyModelConfigWithExternal } from 'services/api/types';
interface Props {
modelConfig: AnyModelConfig;
modelConfig: AnyModelConfigWithExternal;
}
export const ModelUpdatePathButton = memo(({ modelConfig }: Props) => {

View File

@@ -12,14 +12,15 @@ import { TriggerPhrases } from 'features/modelManagerV2/subpanels/ModelPanel/Tri
import { filesize } from 'filesize';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type {
AnyModelConfig,
CLIPEmbedModelConfig,
CLIPVisionModelConfig,
LlavaOnevisionModelConfig,
Qwen3EncoderModelConfig,
SigLIPModelConfig,
T5EncoderModelConfig,
import {
type AnyModelConfigWithExternal,
type CLIPEmbedModelConfig,
type CLIPVisionModelConfig,
isExternalApiModelConfig,
type LlavaOnevisionModelConfig,
type Qwen3EncoderModelConfig,
type SigLIPModelConfig,
type T5EncoderModelConfig,
} from 'services/api/types';
import { isExternalModel } from './isExternalModel';
@@ -38,7 +39,7 @@ type EncoderModelConfig =
| SigLIPModelConfig
| LlavaOnevisionModelConfig;
const isEncoderModel = (modelConfig: AnyModelConfig): modelConfig is EncoderModelConfig => {
const isEncoderModel = (modelConfig: AnyModelConfigWithExternal): modelConfig is EncoderModelConfig => {
return (
modelConfig.type === 'clip_embed' ||
modelConfig.type === 't5_encoder' ||
@@ -50,7 +51,7 @@ const isEncoderModel = (modelConfig: AnyModelConfig): modelConfig is EncoderMode
};
type Props = {
modelConfig: AnyModelConfig;
modelConfig: AnyModelConfigWithExternal;
};
export const ModelView = memo(({ modelConfig }: Props) => {
@@ -104,6 +105,12 @@ export const ModelView = memo(({ modelConfig }: Props) => {
<ModelAttrView label={t('modelManager.modelFormat')} value={modelConfig.format} />
<ModelAttrView label={t('modelManager.path')} value={modelConfig.path} />
<ModelAttrView label={t('modelManager.fileSize')} value={filesize(modelConfig.file_size)} />
{isExternalApiModelConfig(modelConfig) && (
<>
<ModelAttrView label={t('modelManager.providerId')} value={modelConfig.provider_id} />
<ModelAttrView label={t('modelManager.providerModelId')} value={modelConfig.provider_model_id} />
</>
)}
{'variant' in modelConfig && modelConfig.variant && (
<ModelAttrView label={t('modelManager.variant')} value={modelConfig.variant} />
)}

View File

@@ -31,10 +31,10 @@ import {
useRemoveModelRelationshipMutation,
} from 'services/api/endpoints/modelRelationships';
import { useGetModelConfigsQuery } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import type { AnyModelConfig, AnyModelConfigWithExternal } from 'services/api/types';
type Props = {
modelConfig: AnyModelConfig;
modelConfig: AnyModelConfigWithExternal;
};
type ModelGroup = {
@@ -52,7 +52,10 @@ type ModelGroup = {
//
// TODO: In the future, refine this logic to more strictly validate
// relationships based on model types or actual usage patterns.
const isBaseCompatible = (a: AnyModelConfig, b: AnyModelConfig): boolean => {
const isBaseCompatible = (a: AnyModelConfigWithExternal, b: AnyModelConfig): boolean => {
if (a.base === 'external') {
return false;
}
if (a.base === 'any' || b.base === 'any') {
return true;
}

View File

@@ -97,6 +97,7 @@ export const zBaseModelType = z.enum([
'cogview4',
'qwen-image',
'z-image',
'external',
'anima',
'unknown',
]);
@@ -133,6 +134,7 @@ export const zModelType = z.enum([
'clip_embed',
'siglip',
'flux_redux',
'external_image_generator',
'unknown',
]);
export type ModelType = z.infer<typeof zModelType>;
@@ -184,6 +186,7 @@ export const zModelFormat = z.enum([
'bnb_quantized_int8b',
'bnb_quantized_nf4b',
'gguf_quantized',
'external_api',
'unknown',
]);
export type ModelFormat = z.infer<typeof zModelFormat>;
@@ -197,6 +200,17 @@ export const zModelIdentifierField = z.object({
submodel_type: zSubModelType.nullish(),
});
export type ModelIdentifierField = z.infer<typeof zModelIdentifierField>;
// Frontend-only identifier for external API models (not part of the backend schema).
// Mirrors the key/hash/name/submodel_type shape of a model identifier, but requires
// non-empty strings and pins `base` and `type` to the external-model literals, so only
// external image generator identifiers pass validation.
export const zExternalModelIdentifierField = z.object({
key: z.string().min(1),
hash: z.string().min(1),
name: z.string().min(1),
base: z.literal('external'),
type: z.literal('external_image_generator'),
submodel_type: zSubModelType.nullish(),
});
// #endregion
// #region Control Adapters

View File

@@ -0,0 +1,201 @@
import type { RootState } from 'app/store/store';
import type { ParamsState, RefImagesState } from 'features/controlLayers/store/types';
import { imageDTOToCroppableImage, initialIPAdapter } from 'features/controlLayers/store/util';
import type {
ExternalApiModelConfig,
ExternalApiModelDefaultSettings,
ExternalImageSize,
ExternalModelCapabilities,
ExternalModelPanelSchema,
ImageDTO,
} from 'services/api/types';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { buildExternalGraph } from './buildExternalGraph';
/**
 * Test fixture: produces an OpenAI-flavoured external API model config.
 *
 * @param overrides - Fields that replace the fixture defaults (shallow merge, applied last).
 * @returns A complete external model config with txt2img-only capabilities,
 *   seed and reference-image support, and 1024x1024 size limits/defaults.
 */
const createExternalModel = (overrides: Partial<ExternalApiModelConfig> = {}): ExternalApiModelConfig => {
  const sizeLimit: ExternalImageSize = { width: 1024, height: 1024 };
  const fixtureDefaults: ExternalApiModelDefaultSettings = { width: 1024, height: 1024 };
  const fixtureCapabilities: ExternalModelCapabilities = {
    modes: ['txt2img'],
    supports_reference_images: true,
    supports_seed: true,
    max_image_size: sizeLimit,
  };

  const baseConfig: ExternalApiModelConfig = {
    key: 'external-test',
    hash: 'external:openai:gpt-image-1',
    path: 'external://openai/gpt-image-1',
    file_size: 0,
    name: 'External Test',
    description: null,
    source: 'external://openai/gpt-image-1',
    source_type: 'url',
    source_api_response: null,
    cover_image: null,
    base: 'external',
    type: 'external_image_generator',
    format: 'external_api',
    provider_id: 'openai',
    provider_model_id: 'gpt-image-1',
    capabilities: fixtureCapabilities,
    default_settings: fixtureDefaults,
    panel_schema: undefined,
    tags: ['external'],
    is_default: false,
  };

  // Overrides win over every default, exactly like a trailing spread in a literal.
  return { ...baseConfig, ...overrides };
};
// Mutable state read by the module mocks below. Each test (or `beforeEach`) assigns
// these before calling `buildExternalGraph`, and the mocked selectors return them as-is.
let mockModelConfig: ExternalApiModelConfig | null = null;
let mockParams: ParamsState;
let mockRefImages: RefImagesState;
let mockSizes: { scaledSize: { width: number; height: number } };
// Fixed output fields returned by the mocked `selectCanvasOutputFields`.
const mockOutputFields = {
id: 'external_output',
use_cache: false,
is_intermediate: false,
board: undefined,
};
// Replace the params selectors so the graph builder sees the mock model/params
// instead of reading from a real store.
vi.mock('features/controlLayers/store/paramsSlice', () => ({
selectModelConfig: () => mockModelConfig,
selectParamsSlice: () => mockParams,
}));
// Replace the reference-image selector with the mock ref image state.
vi.mock('features/controlLayers/store/refImagesSlice', () => ({
selectRefImagesSlice: () => mockRefImages,
}));
// Replace the sizing/output helpers: text-to-image sizes come from `mockSizes`,
// other modes always get a fixed 512x512 rect at the origin.
vi.mock('features/nodes/util/graph/graphBuilderUtils', () => ({
getOriginalAndScaledSizesForTextToImage: () => mockSizes,
getOriginalAndScaledSizesForOtherModes: () => ({
scaledSize: { width: 512, height: 512 },
rect: { x: 0, y: 0, width: 512, height: 512 },
}),
selectCanvasOutputFields: () => mockOutputFields,
}));
// Reset the shared mock state before every test: 768x512 params/sizes and a single
// enabled reference image entity built from a 64x64 `ref.png` image DTO.
beforeEach(() => {
// Only the fields set here are populated; the cast fills in the rest of ParamsState.
mockParams = {
steps: 20,
guidance: 4.5,
dimensions: { width: 768, height: 512, aspectRatio: { id: '3:2', value: 1.5, isLocked: true } },
} as ParamsState;
mockSizes = { scaledSize: { width: 768, height: 512 } };
const imageDTO = { image_name: 'ref.png', width: 64, height: 64 } as ImageDTO;
mockRefImages = {
selectedEntityId: null,
isPanelOpen: false,
entities: [
{
id: 'ref-image-1',
isEnabled: true,
config: {
...initialIPAdapter,
weight: 0.5,
image: imageDTOToCroppableImage(imageDTO),
},
},
],
};
});
// Exercises `buildExternalGraph` in txt2img mode against the mocked selectors above.
// `state` is an empty object cast to RootState because every selector is mocked, and
// `manager` is null because txt2img does not touch the canvas manager.
describe('buildExternalGraph', () => {
// Default fixture: the provider node carries the mode, the params dimensions and
// the reference image, and a seed edge is wired into it.
it('builds txt2img graph with reference images and seed', async () => {
const modelConfig = createExternalModel();
mockModelConfig = modelConfig;
const { g } = await buildExternalGraph({
generationMode: 'txt2img',
state: {} as RootState,
manager: null,
});
const graph = g.getGraph();
const externalNode = Object.values(graph.nodes).find((node) => node.type === 'openai_image_generation') as
| Record<string, unknown>
| undefined;
expect(externalNode).toBeDefined();
expect(externalNode?.type).toBe('openai_image_generation');
expect(externalNode?.mode).toBe('txt2img');
expect(externalNode?.width).toBe(768);
expect(externalNode?.height).toBe(512);
expect((externalNode?.reference_images as Array<{ image_name: string }> | undefined)?.[0]).toEqual({
image_name: 'ref.png',
});
const seedEdge = graph.edges.find((edge) => edge.destination.field === 'seed');
expect(seedEdge).toBeDefined();
});
// The fixture capabilities still say `supports_seed: true`, but the panel schema lists
// no seed control, so no seed edge is expected; reference images are listed and kept.
it('prefers panel schema over capabilities when building node inputs', async () => {
const panelSchema: ExternalModelPanelSchema = {
prompts: [{ name: 'reference_images' }],
image: [{ name: 'dimensions' }],
generation: [],
};
mockModelConfig = createExternalModel({
panel_schema: panelSchema,
});
const { g } = await buildExternalGraph({
generationMode: 'txt2img',
state: {} as RootState,
manager: null,
});
const graph = g.getGraph();
const externalNode = Object.values(graph.nodes).find((node) => node.type === 'openai_image_generation') as
| Record<string, unknown>
| undefined;
expect((externalNode?.reference_images as Array<{ image_name: string }> | undefined)?.[0]).toEqual({
image_name: 'ref.png',
});
const seedEdge = graph.edges.find((edge) => edge.destination.field === 'seed');
expect(seedEdge).toBeUndefined();
});
// A gemini provider id must yield the gemini-specific node type.
it('uses provider-specific node types', async () => {
mockModelConfig = createExternalModel({
provider_id: 'gemini',
provider_model_id: 'gemini-2.5-flash-image',
path: 'external://gemini/gemini-2.5-flash-image',
source: 'external://gemini/gemini-2.5-flash-image',
hash: 'external:gemini:gemini-2.5-flash-image',
});
const { g } = await buildExternalGraph({
generationMode: 'txt2img',
state: {} as RootState,
manager: null,
});
const graph = g.getGraph();
const externalNode = Object.values(graph.nodes).find((node) => node.type === 'gemini_image_generation');
expect(externalNode).toBeDefined();
expect(externalNode?.type).toBe('gemini_image_generation');
});
// The model only advertises img2img, so requesting txt2img must reject.
it('throws when mode is unsupported', async () => {
const modelConfig = createExternalModel({
capabilities: {
modes: ['img2img'],
},
});
mockModelConfig = modelConfig;
await expect(
buildExternalGraph({
generationMode: 'txt2img',
state: {} as RootState,
manager: null,
})
).rejects.toThrow('does not support txt2img');
});
});

View File

@@ -0,0 +1,152 @@
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectModelConfig, selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { type ModelIdentifierField, zImageField } from 'features/nodes/types/common';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import {
getOriginalAndScaledSizesForOtherModes,
selectCanvasOutputFields,
} from 'features/nodes/util/graph/graphBuilderUtils';
import {
type GraphBuilderArg,
type GraphBuilderReturn,
UnsupportedGenerationModeError,
} from 'features/nodes/util/graph/types';
import { hasExternalPanelControl } from 'features/parameters/util/externalPanelSchema';
import {
type AnyInvocation,
type AnyModelConfigWithExternal,
type Invocation,
isExternalApiModelConfig,
} from 'services/api/types';
import { assert } from 'tsafe';
// Maps a model's `provider_id` to the invocation node type used for that provider.
// Providers missing from this map fall back to 'external_image_generation' in buildExternalGraph.
const EXTERNAL_PROVIDER_NODE_TYPES = {
gemini: 'gemini_image_generation',
openai: 'openai_image_generation',
} as const;
/**
 * Builds the generation graph for an external API image model.
 *
 * The graph consists of a positive prompt node, an optional seed node and a single
 * provider image-generation node. For modes other than txt2img, the init image (and,
 * for inpaint/outpaint, the mask image) are composited via the canvas manager.
 *
 * @param arg - Graph builder argument; `generationMode`, `state` and `manager` are used.
 * @returns The graph, the seed node (or `undefined` when no seed control is available)
 *   and the positive prompt node.
 * @throws UnsupportedGenerationModeError when the model's capabilities do not list the requested mode.
 * @throws Assertion failure when no model is selected, the selected model is not an external API
 *   model, or `manager` is missing for a mode other than txt2img.
 */
export const buildExternalGraph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
const { generationMode, state, manager } = arg;
const modelConfig = selectModelConfig(state) as AnyModelConfigWithExternal | null;
assert(modelConfig, 'No model selected');
assert(isExternalApiModelConfig(modelConfig), 'Selected model is not an external API model');
const model = modelConfig;
// Outpaint is treated as inpaint, both for the capability check and for the node's `mode`.
const requestedMode = generationMode === 'outpaint' ? 'inpaint' : generationMode;
if (!model.capabilities.modes.includes(requestedMode)) {
throw new UnsupportedGenerationModeError(`${model.name} does not support ${requestedMode} mode`);
}
const params = selectParamsSlice(state);
const refImages = selectRefImagesSlice(state);
const g = new Graph(getPrefixedId('external_graph'));
// Control availability is decided by hasExternalPanelControl, not by reading capabilities directly here.
const supportsSeed = hasExternalPanelControl(model, 'image', 'seed');
const supportsReferenceImages = hasExternalPanelControl(model, 'prompts', 'reference_images');
const seed = supportsSeed
? g.addNode({
id: getPrefixedId('seed'),
type: 'integer',
})
: null;
const positivePrompt = g.addNode({
id: getPrefixedId('positive_prompt'),
type: 'string',
});
const externalNodeType = EXTERNAL_PROVIDER_NODE_TYPES[model.provider_id as keyof typeof EXTERNAL_PROVIDER_NODE_TYPES];
// Built as a loose record because the set of fields depends on provider and mode.
const externalNode: Record<string, unknown> = {
id: getPrefixedId('external_image_generation'),
// Unknown providers use the generic node type.
type: externalNodeType ?? 'external_image_generation',
model: model as unknown as ModelIdentifierField,
mode: requestedMode,
image_size: params.imageSize ?? null,
num_images: 1,
};
// Provider-specific options
if (model.provider_id === 'openai') {
externalNode.quality = params.openaiQuality;
externalNode.background = params.openaiBackground;
if (params.openaiInputFidelity) {
externalNode.input_fidelity = params.openaiInputFidelity;
}
} else if (model.provider_id === 'gemini') {
if (params.geminiTemperature !== null) {
externalNode.temperature = params.geminiTemperature;
}
if (params.geminiThinkingLevel) {
externalNode.thinking_level = params.geminiThinkingLevel;
}
}
// NOTE(review): `externalNode` keeps being mutated after this call (reference images,
// width/height, init/mask images). This presumably relies on Graph.addNode storing the
// same object reference rather than a copy - confirm against the Graph implementation.
g.addNode(externalNode as AnyInvocation);
if (seed) {
g.addEdgeFromObj({
source: { node_id: seed.id, field: 'value' },
destination: { node_id: externalNode.id as string, field: 'seed' },
});
}
g.addEdgeFromObj({
source: { node_id: positivePrompt.id, field: 'value' },
destination: { node_id: externalNode.id as string, field: 'prompt' },
});
if (supportsReferenceImages) {
// Only enabled entities that have an image contribute; the cropped image is
// preferred over the original when a crop exists.
const referenceImages = refImages.entities
.filter((entity) => entity.isEnabled)
.map((entity) => entity.config)
.filter((config) => config.image)
.map((config) => zImageField.parse(config.image?.crop?.image ?? config.image?.original.image));
if (referenceImages.length > 0) {
externalNode.reference_images = referenceImages;
}
}
// External models require specific dimensions matching their supported presets.
// Always use params dimensions (from selected preset) for the API width/height.
externalNode.width = params.dimensions.width;
externalNode.height = params.dimensions.height;
if (generationMode !== 'txt2img') {
assert(manager, 'Canvas manager is required for img2img/inpaint');
const canvasSettings = selectCanvasSettingsSlice(state);
const { rect } = getOriginalAndScaledSizesForOtherModes(state);
// Composite the visible raster layers over `rect` to produce the init image.
const rasterAdapters = manager.compositor.getVisibleAdaptersOfType('raster_layer');
const initImage = await manager.compositor.getCompositeImageDTO(rasterAdapters, rect, {
is_intermediate: true,
silent: true,
});
externalNode.init_image = { image_name: initImage.image_name };
if (generationMode === 'inpaint' || generationMode === 'outpaint') {
// Composite the visible inpaint masks over the same rect into a grayscale mask image.
const inpaintMaskAdapters = manager.compositor.getVisibleAdaptersOfType('inpaint_mask');
const maskImage = await manager.compositor.getGrayscaleMaskCompositeImageDTO(
inpaintMaskAdapters,
rect,
'denoiseLimit',
canvasSettings.preserveMask,
{
is_intermediate: true,
silent: true,
}
);
externalNode.mask_image = { image_name: maskImage.image_name };
}
}
// Apply the canvas output fields to the provider node.
g.updateNode(externalNode as AnyInvocation, selectCanvasOutputFields(state));
return {
g,
seed: seed ?? undefined,
positivePrompt: positivePrompt as Invocation<'string'>,
};
};

View File

@@ -13,6 +13,7 @@ import {
} from 'features/controlLayers/store/paramsSlice';
import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { addZImageControl } from 'features/nodes/util/graph/generation/addControlAdapters';
import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage';
import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint';
@@ -77,7 +78,7 @@ export const buildZImageGraph = async (arg: GraphBuilderArg): Promise<GraphBuild
model,
vae_model: zImageVaeModel ?? undefined,
qwen3_encoder_model: zImageQwen3EncoderModel ?? undefined,
qwen3_source_model: zImageQwen3SourceModel ?? undefined,
qwen3_source_model: (zImageQwen3SourceModel as ModelIdentifierField | null) ?? undefined,
});
const positivePrompt = g.addNode({

View File

@@ -2,6 +2,7 @@ import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { selectCLIPSkip, selectModel, setClipSkip } from 'features/controlLayers/store/paramsSlice';
import type { BaseModelType } from 'features/nodes/types/common';
import { CLIP_SKIP_MAP } from 'features/parameters/types/constants';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -32,14 +33,14 @@ const ParamClipSkip = () => {
if (!model) {
return CLIP_SKIP_MAP['sd-1']?.maxClip;
}
return CLIP_SKIP_MAP[model.base]?.maxClip;
return CLIP_SKIP_MAP[model.base as BaseModelType]?.maxClip;
}, [model]);
const sliderMarks = useMemo(() => {
if (!model) {
return CLIP_SKIP_MAP['sd-1']?.markers;
}
return CLIP_SKIP_MAP[model.base]?.markers;
return CLIP_SKIP_MAP[model.base as BaseModelType]?.markers;
}, [model]);
if (model?.base === 'sdxl') {

View File

@@ -9,7 +9,7 @@ import {
zImageQwen3SourceModelSelected,
zImageVaeModelSelected,
} from 'features/controlLayers/store/paramsSlice';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { type ModelIdentifierField, zModelIdentifierField } from 'features/nodes/types/common';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useFlux1VAEModels, useQwen3EncoderModels, useZImageDiffusersModels } from 'services/api/hooks/modelsByType';
@@ -136,7 +136,7 @@ const ParamZImageQwen3SourceModelSelect = memo(() => {
const { options, value, onChange, noOptionsMessage } = useModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel: zImageQwen3SourceModel,
selectedModel: zImageQwen3SourceModel as ModelIdentifierField | null,
isLoading,
});

View File

@@ -0,0 +1,42 @@
import type { ExternalApiModelConfig } from 'services/api/types';
import { describe, expect, test } from 'vitest';
/**
 * Test fixture: produces a Gemini-flavoured external API model config whose
 * capabilities restrict the allowed aspect ratios to '1:1' and '16:9'.
 *
 * @param overrides - Fields that replace the fixture defaults (shallow merge, applied last).
 * @returns A complete external model config.
 */
const createExternalModel = (overrides: Partial<ExternalApiModelConfig> = {}): ExternalApiModelConfig => {
  const fixture: ExternalApiModelConfig = {
    key: 'external-test',
    name: 'External Test',
    base: 'external',
    type: 'external_image_generator',
    format: 'external_api',
    provider_id: 'gemini',
    provider_model_id: 'gemini-2.5-flash-image',
    description: 'Test model',
    source: 'external://gemini/gemini-2.5-flash-image',
    source_type: 'url',
    source_api_response: null,
    path: '',
    file_size: 0,
    hash: 'external:gemini:gemini-2.5-flash-image',
    cover_image: null,
    is_default: false,
    tags: ['external'],
    capabilities: {
      modes: ['txt2img'],
      supports_reference_images: false,
      supports_seed: true,
      max_images_per_request: 1,
      max_image_size: null,
      allowed_aspect_ratios: ['1:1', '16:9'],
      max_reference_images: null,
      mask_format: 'none',
      input_image_required_for: null,
    },
    default_settings: null,
  };

  // Caller-supplied fields take precedence over the fixture values.
  return { ...fixture, ...overrides };
};
// NOTE(review): as written, this test only reads back the value that the fixture itself
// sets; no selector or component is invoked in the visible code. Consider asserting on
// the code that consumes `allowed_aspect_ratios` instead - confirm the intended target.
describe('external model aspect ratios (bbox)', () => {
test('uses allowed aspect ratios for external models', () => {
const model = createExternalModel();
expect(model.capabilities.allowed_aspect_ratios).toEqual(['1:1', '16:9']);
});
});

View File

@@ -3,10 +3,16 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { bboxAspectRatioIdChanged } from 'features/controlLayers/store/canvasSlice';
import { useCanvasIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import {
aspectRatioIdChanged,
selectAllowedAspectRatioIDs,
selectAspectRatioSizes,
selectHasFixedDimensionSizes,
} from 'features/controlLayers/store/paramsSlice';
import { selectAspectRatioID } from 'features/controlLayers/store/selectors';
import { isAspectRatioID, zAspectRatioID } from 'features/controlLayers/store/types';
import type { ChangeEventHandler } from 'react';
import { memo, useCallback } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCaretDownBold } from 'react-icons/pi';
@@ -15,24 +21,33 @@ export const BboxAspectRatioSelect = memo(() => {
const dispatch = useAppDispatch();
const id = useAppSelector(selectAspectRatioID);
const isStaging = useCanvasIsStaging();
const allowedAspectRatios = useAppSelector(selectAllowedAspectRatioIDs);
const aspectRatioSizes = useAppSelector(selectAspectRatioSizes);
const hasFixedSizes = useAppSelector(selectHasFixedDimensionSizes);
const options = useMemo(() => allowedAspectRatios ?? zAspectRatioID.options, [allowedAspectRatios]);
const onChange = useCallback<ChangeEventHandler<HTMLSelectElement>>(
(e) => {
if (!isAspectRatioID(e.target.value)) {
return;
}
dispatch(bboxAspectRatioIdChanged({ id: e.target.value }));
const fixedSize = aspectRatioSizes?.[e.target.value] ?? undefined;
dispatch(bboxAspectRatioIdChanged({ id: e.target.value, fixedSize }));
// For external models with fixed sizes, also sync to params so buildExternalGraph uses correct dimensions
if (fixedSize) {
dispatch(aspectRatioIdChanged({ id: e.target.value, fixedSize }));
}
},
[dispatch]
[dispatch, aspectRatioSizes]
);
return (
<FormControl isDisabled={isStaging}>
<FormControl isDisabled={isStaging || hasFixedSizes}>
<InformationalPopover feature="paramAspect">
<FormLabel>{t('parameters.aspect')}</FormLabel>
</InformationalPopover>
<Select size="sm" value={id} onChange={onChange} cursor="pointer" iconSize="0.75rem" icon={<PiCaretDownBold />}>
{zAspectRatioID.options.map((ratio) => (
{options.map((ratio) => (
<option key={ratio} value={ratio}>
{ratio}
</option>

View File

@@ -1,7 +1,8 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { bboxDimensionsSwapped } from 'features/controlLayers/store/canvasSlice';
import { useCanvasIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { selectHasFixedDimensionSizes } from 'features/controlLayers/store/paramsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowsDownUpBold } from 'react-icons/pi';
@@ -10,6 +11,7 @@ export const BboxSwapDimensionsButton = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const isStaging = useCanvasIsStaging();
const hasFixedSizes = useAppSelector(selectHasFixedDimensionSizes);
const onClick = useCallback(() => {
dispatch(bboxDimensionsSwapped());
}, [dispatch]);
@@ -21,7 +23,7 @@ export const BboxSwapDimensionsButton = memo(() => {
variant="ghost"
size="sm"
icon={<PiArrowsDownUpBold />}
isDisabled={isStaging}
isDisabled={isStaging || hasFixedSizes}
/>
);
});

View File

@@ -1,6 +1,9 @@
import { useAppSelector } from 'app/store/storeHooks';
import { useCanvasIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { selectHasFixedDimensionSizes } from 'features/controlLayers/store/paramsSlice';
export const useIsBboxSizeLocked = () => {
const isStaging = useCanvasIsStaging();
return isStaging;
const hasFixedSizes = useAppSelector(selectHasFixedDimensionSizes);
return isStaging || hasFixedSizes;
};

View File

@@ -0,0 +1,42 @@
import type { ExternalApiModelConfig } from 'services/api/types';
import { describe, expect, test } from 'vitest';
/**
 * Builds an external API model config fixture for these tests.
 *
 * @param overrides Fields that replace the fixture defaults (shallow merge).
 * @returns The default fixture with `overrides` applied on top.
 */
const createExternalModel = (overrides: Partial<ExternalApiModelConfig> = {}): ExternalApiModelConfig => {
  const baseConfig: ExternalApiModelConfig = {
    key: 'external-test',
    name: 'External Test',
    base: 'external',
    type: 'external_image_generator',
    format: 'external_api',
    provider_id: 'gemini',
    provider_model_id: 'gemini-2.5-flash-image',
    description: 'Test model',
    source: 'external://gemini/gemini-2.5-flash-image',
    source_type: 'url',
    source_api_response: null,
    path: '',
    file_size: 0,
    hash: 'external:gemini:gemini-2.5-flash-image',
    cover_image: null,
    is_default: false,
    tags: ['external'],
    capabilities: {
      modes: ['txt2img'],
      supports_reference_images: false,
      supports_seed: true,
      max_images_per_request: 1,
      max_image_size: null,
      allowed_aspect_ratios: ['1:1', '16:9'],
      max_reference_images: null,
      mask_format: 'none',
      input_image_required_for: null,
    },
    default_settings: null,
  };
  // Overrides are spread last so callers can replace any top-level field.
  return { ...baseConfig, ...overrides };
};
// NOTE(review): this test only reads back the fixture's own `allowed_aspect_ratios`;
// it does not exercise any selector or component — consider asserting on real code.
describe('external model aspect ratios', () => {
  test('uses allowed aspect ratios for external models', () => {
    const { capabilities } = createExternalModel();
    const expectedRatios = ['1:1', '16:9'];
    expect(capabilities.allowed_aspect_ratios).toEqual(expectedRatios);
  });
});

View File

@@ -1,7 +1,12 @@
import { FormControl, FormLabel, Select } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { aspectRatioIdChanged, selectAspectRatioID } from 'features/controlLayers/store/paramsSlice';
import {
aspectRatioIdChanged,
selectAllowedAspectRatioIDs,
selectAspectRatioID,
selectAspectRatioSizes,
} from 'features/controlLayers/store/paramsSlice';
import { isAspectRatioID, zAspectRatioID } from 'features/controlLayers/store/types';
import type { ChangeEventHandler } from 'react';
import { memo, useCallback } from 'react';
@@ -12,15 +17,19 @@ export const DimensionsAspectRatioSelect = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const id = useAppSelector(selectAspectRatioID);
const allowedAspectRatios = useAppSelector(selectAllowedAspectRatioIDs);
const aspectRatioSizes = useAppSelector(selectAspectRatioSizes);
const options = allowedAspectRatios ?? zAspectRatioID.options;
const onChange = useCallback<ChangeEventHandler<HTMLSelectElement>>(
(e) => {
if (!isAspectRatioID(e.target.value)) {
return;
}
dispatch(aspectRatioIdChanged({ id: e.target.value }));
const fixedSize = aspectRatioSizes?.[e.target.value] ?? undefined;
dispatch(aspectRatioIdChanged({ id: e.target.value, fixedSize }));
},
[dispatch]
[dispatch, aspectRatioSizes]
);
return (
@@ -29,7 +38,7 @@ export const DimensionsAspectRatioSelect = memo(() => {
<FormLabel>{t('parameters.aspect')}</FormLabel>
</InformationalPopover>
<Select size="sm" value={id} onChange={onChange} cursor="pointer" iconSize="0.75rem" icon={<PiCaretDownBold />}>
{zAspectRatioID.options.map((ratio) => (
{options.map((ratio) => (
<option key={ratio} value={ratio}>
{ratio}
</option>

View File

@@ -1,7 +1,7 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { heightChanged, selectHeight } from 'features/controlLayers/store/paramsSlice';
import { heightChanged, selectHasFixedDimensionSizes, selectHeight } from 'features/controlLayers/store/paramsSlice';
import { selectGridSize, selectOptimalDimension } from 'features/controlLayers/store/selectors';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -22,6 +22,7 @@ export const DimensionsHeight = memo(() => {
const optimalDimension = useAppSelector(selectOptimalDimension);
const height = useAppSelector(selectHeight);
const gridSize = useAppSelector(selectGridSize);
const hasFixedSizes = useAppSelector(selectHasFixedDimensionSizes);
const onChange = useCallback(
(v: number) => {
@@ -33,7 +34,7 @@ export const DimensionsHeight = memo(() => {
const marks = useMemo(() => [CONSTRAINTS.sliderMin, optimalDimension, CONSTRAINTS.sliderMax], [optimalDimension]);
return (
<FormControl>
<FormControl isDisabled={hasFixedSizes}>
<InformationalPopover feature="paramHeight">
<FormLabel>{t('parameters.height')}</FormLabel>
</InformationalPopover>

View File

@@ -1,6 +1,10 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { aspectRatioLockToggled, selectAspectRatioIsLocked } from 'features/controlLayers/store/paramsSlice';
import {
aspectRatioLockToggled,
selectAspectRatioIsLocked,
selectHasFixedDimensionSizes,
} from 'features/controlLayers/store/paramsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiLockSimpleFill, PiLockSimpleOpenBold } from 'react-icons/pi';
@@ -9,6 +13,7 @@ export const DimensionsLockAspectRatioButton = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const isLocked = useAppSelector(selectAspectRatioIsLocked);
const hasFixedSizes = useAppSelector(selectHasFixedDimensionSizes);
const onClick = useCallback(() => {
dispatch(aspectRatioLockToggled());
@@ -22,6 +27,7 @@ export const DimensionsLockAspectRatioButton = memo(() => {
variant={isLocked ? 'outline' : 'ghost'}
size="sm"
icon={isLocked ? <PiLockSimpleFill /> : <PiLockSimpleOpenBold />}
isDisabled={hasFixedSizes}
/>
);
});

View File

@@ -1,6 +1,11 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectHeight, selectWidth, sizeOptimized } from 'features/controlLayers/store/paramsSlice';
import {
selectHasFixedDimensionSizes,
selectHeight,
selectWidth,
sizeOptimized,
} from 'features/controlLayers/store/paramsSlice';
import { selectOptimalDimension } from 'features/controlLayers/store/selectors';
import { getIsSizeTooLarge, getIsSizeTooSmall } from 'features/parameters/util/optimalDimension';
import { memo, useCallback, useMemo } from 'react';
@@ -13,6 +18,7 @@ export const DimensionsSetOptimalSizeButton = memo(() => {
const width = useAppSelector(selectWidth);
const height = useAppSelector(selectHeight);
const optimalDimension = useAppSelector(selectOptimalDimension);
const hasFixedSizes = useAppSelector(selectHasFixedDimensionSizes);
const isSizeTooSmall = useMemo(
() => getIsSizeTooSmall(width, height, optimalDimension),
[height, width, optimalDimension]
@@ -43,6 +49,7 @@ export const DimensionsSetOptimalSizeButton = memo(() => {
size="sm"
icon={<PiSparkleFill />}
colorScheme={isSizeTooSmall || isSizeTooLarge ? 'warning' : 'base'}
isDisabled={hasFixedSizes}
/>
);
});

View File

@@ -1,6 +1,6 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { dimensionsSwapped } from 'features/controlLayers/store/paramsSlice';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { dimensionsSwapped, selectHasFixedDimensionSizes } from 'features/controlLayers/store/paramsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowsDownUpBold } from 'react-icons/pi';
@@ -8,6 +8,7 @@ import { PiArrowsDownUpBold } from 'react-icons/pi';
export const DimensionsSwapButton = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const hasFixedSizes = useAppSelector(selectHasFixedDimensionSizes);
const onClick = useCallback(() => {
dispatch(dimensionsSwapped());
}, [dispatch]);
@@ -19,6 +20,7 @@ export const DimensionsSwapButton = memo(() => {
variant="ghost"
size="sm"
icon={<PiArrowsDownUpBold />}
isDisabled={hasFixedSizes}
/>
);
});

View File

@@ -1,7 +1,7 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { selectWidth, widthChanged } from 'features/controlLayers/store/paramsSlice';
import { selectHasFixedDimensionSizes, selectWidth, widthChanged } from 'features/controlLayers/store/paramsSlice';
import { selectGridSize, selectOptimalDimension } from 'features/controlLayers/store/selectors';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -22,6 +22,7 @@ export const DimensionsWidth = memo(() => {
const width = useAppSelector(selectWidth);
const optimalDimension = useAppSelector(selectOptimalDimension);
const gridSize = useAppSelector(selectGridSize);
const hasFixedSizes = useAppSelector(selectHasFixedDimensionSizes);
const onChange = useCallback(
(v: number) => {
@@ -33,7 +34,7 @@ export const DimensionsWidth = memo(() => {
const marks = useMemo(() => [CONSTRAINTS.sliderMin, optimalDimension, CONSTRAINTS.sliderMax], [optimalDimension]);
return (
<FormControl>
<FormControl isDisabled={hasFixedSizes}>
<InformationalPopover feature="paramWidth">
<FormLabel>{t('parameters.width')}</FormLabel>
</InformationalPopover>

View File

@@ -0,0 +1,94 @@
import { FormControl, FormLabel, Select } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
resolutionPresetSelected,
selectAspectRatioID,
selectImageSize,
selectResolutionPresets,
} from 'features/controlLayers/store/paramsSlice';
import type { ChangeEventHandler } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCaretDownBold } from 'react-icons/pi';
/** Joins an aspect ratio and an image size into a single `<aspectRatio>|<imageSize>` option key. */
const makeKey = (aspectRatio: string, imageSize: string) => [aspectRatio, imageSize].join('|');
/**
 * Select control listing the resolution presets returned by `selectResolutionPresets`.
 *
 * Renders nothing when there are no presets. Choosing an option dispatches
 * `resolutionPresetSelected` with the preset's image size, aspect ratio, width and height.
 */
export const ExternalModelImageSizeSelect = memo(() => {
  const { t } = useTranslation();
  const dispatch = useAppDispatch();
  const presets = useAppSelector(selectResolutionPresets);
  const currentAspectRatio = useAppSelector(selectAspectRatioID);
  const currentImageSize = useAppSelector(selectImageSize);

  // Lookup from option key to preset, so the change handler can resolve the selected <option>.
  const presetMap = useMemo(() => {
    if (!presets) {
      return null;
    }
    const lookup = new Map<string, (typeof presets)[number]>();
    presets.forEach((preset) => {
      lookup.set(makeKey(preset.aspect_ratio, preset.image_size), preset);
    });
    return lookup;
  }, [presets]);

  const selectedKey = useMemo(() => {
    const firstPreset = presets?.[0];
    if (!firstPreset) {
      return '';
    }
    if (currentImageSize && currentAspectRatio) {
      const candidate = makeKey(currentAspectRatio, currentImageSize);
      if (presetMap?.has(candidate)) {
        return candidate;
      }
    }
    // The current params do not match any preset: show the first preset as selected.
    return makeKey(firstPreset.aspect_ratio, firstPreset.image_size);
  }, [presets, presetMap, currentAspectRatio, currentImageSize]);

  const onChange = useCallback<ChangeEventHandler<HTMLSelectElement>>(
    (e) => {
      const chosen = presetMap?.get(e.target.value);
      if (!chosen) {
        return;
      }
      const { image_size: imageSize, aspect_ratio: aspectRatio, width, height } = chosen;
      dispatch(resolutionPresetSelected({ imageSize, aspectRatio, width, height }));
    },
    [dispatch, presetMap]
  );

  if (!presets || presets.length === 0) {
    return null;
  }

  const optionElements = presets.map((preset) => {
    const key = makeKey(preset.aspect_ratio, preset.image_size);
    return (
      <option key={key} value={key}>
        {preset.label}
      </option>
    );
  });

  return (
    <FormControl>
      <FormLabel>{t('parameters.resolution')}</FormLabel>
      <Select
        size="sm"
        value={selectedKey}
        onChange={onChange}
        cursor="pointer"
        iconSize="0.75rem"
        icon={<PiCaretDownBold />}
      >
        {optionElements}
      </Select>
    </FormControl>
  );
});
ExternalModelImageSizeSelect.displayName = 'ExternalModelImageSizeSelect';

View File

@@ -0,0 +1,68 @@
import { FormControl, FormLabel, Select } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
aspectRatioIdChanged,
selectAspectRatioID,
selectAspectRatioSizes,
} from 'features/controlLayers/store/paramsSlice';
import { isAspectRatioID } from 'features/controlLayers/store/types';
import type { ChangeEventHandler } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCaretDownBold } from 'react-icons/pi';
/**
 * Select control with one option per entry of `selectAspectRatioSizes`, labelled
 * `<ratio> (<width>×<height>)`.
 *
 * Renders nothing when the selector yields no sizes. Choosing an option dispatches
 * `aspectRatioIdChanged` with the ratio and its size from the map as `fixedSize`.
 */
export const ExternalModelResolutionSelect = memo(() => {
  const { t } = useTranslation();
  const dispatch = useAppDispatch();
  const aspectRatioID = useAppSelector(selectAspectRatioID);
  const aspectRatioSizes = useAppSelector(selectAspectRatioSizes);

  const options = useMemo(
    () =>
      aspectRatioSizes
        ? Object.entries(aspectRatioSizes).map(([ratio, { width, height }]) => ({
            ratio,
            label: `${ratio} (${width}×${height})`,
          }))
        : [],
    [aspectRatioSizes]
  );

  const onChange = useCallback<ChangeEventHandler<HTMLSelectElement>>(
    (e) => {
      const { value: ratio } = e.target;
      // Ignore values that are not known aspect ratio IDs.
      if (isAspectRatioID(ratio)) {
        dispatch(aspectRatioIdChanged({ id: ratio, fixedSize: aspectRatioSizes?.[ratio] ?? undefined }));
      }
    },
    [dispatch, aspectRatioSizes]
  );

  if (!aspectRatioSizes) {
    return null;
  }

  const optionElements = options.map(({ ratio, label }) => (
    <option key={ratio} value={ratio}>
      {label}
    </option>
  ));

  return (
    <FormControl>
      <FormLabel>{t('parameters.resolution')}</FormLabel>
      <Select
        size="sm"
        value={aspectRatioID}
        onChange={onChange}
        cursor="pointer"
        iconSize="0.75rem"
        icon={<PiCaretDownBold />}
      >
        {optionElements}
      </Select>
    </FormControl>
  );
});
ExternalModelResolutionSelect.displayName = 'ExternalModelResolutionSelect';

View File

@@ -0,0 +1,74 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel, Select } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
geminiTemperatureChanged,
geminiThinkingLevelChanged,
selectGeminiTemperature,
selectGeminiThinkingLevel,
} from 'features/controlLayers/store/paramsSlice';
import type { ChangeEventHandler } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCaretDownBold } from 'react-icons/pi';
const TEMPERATURE_MARKS = [0, 1, 2];
// Bounds and fallback shared by the temperature slider and number input.
const TEMPERATURE_MIN = 0;
const TEMPERATURE_MAX = 2;
const TEMPERATURE_DEFAULT = 1;

/** Non-empty values offered by the thinking level select; the empty option maps to `null`. */
const THINKING_LEVELS = ['minimal', 'high'] as const;
type ThinkingLevel = (typeof THINKING_LEVELS)[number];

/** Runtime check that a raw `<select>` value is one of the known thinking levels. */
const isThinkingLevel = (value: string): value is ThinkingLevel =>
  (THINKING_LEVELS as readonly string[]).includes(value);

/**
 * Gemini-specific options: a temperature slider + number input and a thinking level select.
 *
 * Temperature changes dispatch `geminiTemperatureChanged`; thinking level changes dispatch
 * `geminiThinkingLevelChanged` with `null` for the empty ("Default") option.
 */
export const GeminiProviderOptions = memo(() => {
  const { t } = useTranslation();
  const dispatch = useAppDispatch();
  const temperature = useAppSelector(selectGeminiTemperature);
  const thinkingLevel = useAppSelector(selectGeminiThinkingLevel);
  // An unset temperature is displayed as the default value.
  const temperatureValue = temperature ?? TEMPERATURE_DEFAULT;

  const onTemperatureChange = useCallback((v: number) => dispatch(geminiTemperatureChanged(v)), [dispatch]);

  const onThinkingLevelChange = useCallback<ChangeEventHandler<HTMLSelectElement>>(
    (e) => {
      const value = e.target.value;
      if (value === '') {
        dispatch(geminiThinkingLevelChanged(null));
        return;
      }
      // Narrow at runtime instead of casting, so an unexpected value never reaches the store.
      if (isThinkingLevel(value)) {
        dispatch(geminiThinkingLevelChanged(value));
      }
    },
    [dispatch]
  );

  return (
    <>
      <FormControl>
        <FormLabel>{t('parameters.temperature', 'Temperature')}</FormLabel>
        <CompositeSlider
          value={temperatureValue}
          defaultValue={TEMPERATURE_DEFAULT}
          min={TEMPERATURE_MIN}
          max={TEMPERATURE_MAX}
          step={0.1}
          fineStep={0.05}
          onChange={onTemperatureChange}
          marks={TEMPERATURE_MARKS}
        />
        <CompositeNumberInput
          value={temperatureValue}
          defaultValue={TEMPERATURE_DEFAULT}
          min={TEMPERATURE_MIN}
          max={TEMPERATURE_MAX}
          step={0.1}
          fineStep={0.05}
          onChange={onTemperatureChange}
        />
      </FormControl>
      <FormControl>
        <FormLabel>{t('parameters.thinkingLevel', 'Thinking Level')}</FormLabel>
        <Select
          size="sm"
          value={thinkingLevel ?? ''}
          onChange={onThinkingLevelChange}
          icon={<PiCaretDownBold />}
          iconSize="0.75rem"
        >
          <option value="">Default</option>
          <option value="minimal">Minimal</option>
          <option value="high">High</option>
        </Select>
      </FormControl>
    </>
  );
});
GeminiProviderOptions.displayName = 'GeminiProviderOptions';

View File

@@ -0,0 +1,84 @@
import { FormControl, FormLabel, Select } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
openaiBackgroundChanged,
openaiInputFidelityChanged,
openaiQualityChanged,
selectOpenaiBackground,
selectOpenaiInputFidelity,
selectOpenaiQuality,
} from 'features/controlLayers/store/paramsSlice';
import type { ChangeEventHandler } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCaretDownBold } from 'react-icons/pi';
// Values offered by each select below; used to validate raw `<select>` values at runtime.
const OPENAI_QUALITIES = ['auto', 'high', 'medium', 'low'] as const;
const OPENAI_BACKGROUNDS = ['auto', 'transparent', 'opaque'] as const;
const OPENAI_INPUT_FIDELITIES = ['low', 'high'] as const;

/** Type guard: true when `value` is one of the literals in `allowed`. */
const isOneOf = <T extends string>(allowed: readonly T[], value: string): value is T =>
  (allowed as readonly string[]).includes(value);

/**
 * OpenAI-specific options: quality, background and input fidelity selects.
 *
 * Each select dispatches its matching params action. Input fidelity dispatches `null`
 * for the empty ("Default") option. Values outside the known sets are ignored rather
 * than being cast and dispatched.
 */
export const OpenAIProviderOptions = memo(() => {
  const { t } = useTranslation();
  const dispatch = useAppDispatch();
  const quality = useAppSelector(selectOpenaiQuality);
  const background = useAppSelector(selectOpenaiBackground);
  const inputFidelity = useAppSelector(selectOpenaiInputFidelity);

  const onQualityChange = useCallback<ChangeEventHandler<HTMLSelectElement>>(
    (e) => {
      const value = e.target.value;
      if (isOneOf(OPENAI_QUALITIES, value)) {
        dispatch(openaiQualityChanged(value));
      }
    },
    [dispatch]
  );

  const onBackgroundChange = useCallback<ChangeEventHandler<HTMLSelectElement>>(
    (e) => {
      const value = e.target.value;
      if (isOneOf(OPENAI_BACKGROUNDS, value)) {
        dispatch(openaiBackgroundChanged(value));
      }
    },
    [dispatch]
  );

  const onInputFidelityChange = useCallback<ChangeEventHandler<HTMLSelectElement>>(
    (e) => {
      const value = e.target.value;
      if (value === '') {
        dispatch(openaiInputFidelityChanged(null));
        return;
      }
      if (isOneOf(OPENAI_INPUT_FIDELITIES, value)) {
        dispatch(openaiInputFidelityChanged(value));
      }
    },
    [dispatch]
  );

  return (
    <>
      <FormControl>
        <FormLabel>{t('parameters.quality', 'Quality')}</FormLabel>
        <Select size="sm" value={quality} onChange={onQualityChange} icon={<PiCaretDownBold />} iconSize="0.75rem">
          <option value="auto">Auto</option>
          <option value="high">High</option>
          <option value="medium">Medium</option>
          <option value="low">Low</option>
        </Select>
      </FormControl>
      <FormControl>
        <FormLabel>{t('parameters.background', 'Background')}</FormLabel>
        <Select
          size="sm"
          value={background}
          onChange={onBackgroundChange}
          icon={<PiCaretDownBold />}
          iconSize="0.75rem"
        >
          <option value="auto">Auto</option>
          <option value="transparent">Transparent</option>
          <option value="opaque">Opaque</option>
        </Select>
      </FormControl>
      <FormControl>
        <FormLabel>{t('parameters.inputFidelity', 'Input Fidelity')}</FormLabel>
        <Select
          size="sm"
          value={inputFidelity ?? ''}
          onChange={onInputFidelityChange}
          icon={<PiCaretDownBold />}
          iconSize="0.75rem"
        >
          <option value="">Default</option>
          <option value="low">Low</option>
          <option value="high">High</option>
        </Select>
      </FormControl>
    </>
  );
});
OpenAIProviderOptions.displayName = 'OpenAIProviderOptions';

View File

@@ -0,0 +1,60 @@
import type {
ExternalApiModelConfig,
ExternalApiModelDefaultSettings,
ExternalImageSize,
ExternalModelCapabilities,
} from 'services/api/types';
import { describe, expect, it } from 'vitest';
import { isExternalModelUnsupportedForTab } from './mainModelPickerUtils';
/**
 * Builds an external API model config fixture whose capabilities advertise the given modes.
 *
 * @param modes Generation modes to place in `capabilities.modes`.
 * @returns A config fixture with fixed 1024×1024 max image size and default settings.
 */
const createExternalConfig = (modes: ExternalModelCapabilities['modes']): ExternalApiModelConfig => {
  const maxImageSize: ExternalImageSize = { width: 1024, height: 1024 };
  const defaultSettings: ExternalApiModelDefaultSettings = { width: 1024, height: 1024, steps: 30 };
  const externalUrl = 'external://openai/gpt-image-1';

  const config: ExternalApiModelConfig = {
    key: 'external-test',
    hash: 'external:openai:gpt-image-1',
    path: externalUrl,
    file_size: 0,
    name: 'External Test',
    description: null,
    source: externalUrl,
    source_type: 'url',
    source_api_response: null,
    cover_image: null,
    base: 'external',
    type: 'external_image_generator',
    format: 'external_api',
    provider_id: 'openai',
    provider_model_id: 'gpt-image-1',
    capabilities: {
      modes,
      supports_reference_images: false,
      max_image_size: maxImageSize,
    },
    default_settings: defaultSettings,
    tags: ['external'],
    is_default: false,
  };
  return config;
};
// Covers the generate tab (txt2img present and absent) and the canvas tab.
describe('isExternalModelUnsupportedForTab', () => {
  it('disables external models without txt2img for generate', () => {
    const isUnsupported = isExternalModelUnsupportedForTab(createExternalConfig(['img2img', 'inpaint']), 'generate');
    expect(isUnsupported).toBe(true);
  });

  it('allows external models with txt2img for generate', () => {
    const isUnsupported = isExternalModelUnsupportedForTab(createExternalConfig(['txt2img']), 'generate');
    expect(isUnsupported).toBe(false);
  });

  it('allows external models on canvas', () => {
    const isUnsupported = isExternalModelUnsupportedForTab(createExternalConfig(['inpaint']), 'canvas');
    expect(isUnsupported).toBe(false);
  });
});

View File

@@ -0,0 +1,14 @@
import type { TabName } from 'features/ui/store/uiTypes';
import { type AnyModelConfigWithExternal, isExternalApiModelConfig } from 'services/api/types';
/**
 * Reports whether a model must be treated as unsupported on the given tab.
 *
 * Only external API models on the `generate` tab can be unsupported, and only when
 * their capabilities do not list the `txt2img` mode. Every other combination is supported.
 *
 * @param model Model config to check.
 * @param tab Tab the model would be used on.
 * @returns `true` when the model is unsupported on `tab`.
 */
export const isExternalModelUnsupportedForTab = (model: AnyModelConfigWithExternal, tab: TabName): boolean => {
  if (!isExternalApiModelConfig(model) || tab !== 'generate') {
    return false;
  }
  const supportsTextToImage = model.capabilities.modes.includes('txt2img');
  return !supportsTextToImage;
};

View File

@@ -1,5 +1,6 @@
import type { BoxProps, ButtonProps, SystemStyleObject } from '@invoke-ai/ui-library';
import {
Badge,
Button,
Flex,
Icon,
@@ -36,7 +37,11 @@ import { Trans, useTranslation } from 'react-i18next';
import { PiCaretDownBold, PiLinkSimple } from 'react-icons/pi';
import { useGetSetupStatusQuery } from 'services/api/endpoints/auth';
import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships';
import type { AnyModelConfig } from 'services/api/types';
import {
type AnyModelConfigWithExternal,
type ExternalApiModelConfig,
isExternalApiModelConfig,
} from 'services/api/types';
const selectSelectedModelKeys = createMemoizedSelector(selectParamsSlice, selectLoRAsSlice, (params, loras) => {
const keys: string[] = [];
@@ -67,7 +72,7 @@ const selectSelectedModelKeys = createMemoizedSelector(selectParamsSlice, select
type WithStarred<T> = T & { starred?: boolean };
// Type for models with starred field
const getOptionId = <T extends AnyModelConfig>(modelConfig: WithStarred<T>) => modelConfig.key;
const getOptionId = <T extends AnyModelConfigWithExternal>(modelConfig: WithStarred<T>) => modelConfig.key;
const ModelManagerLink = memo((props: ButtonProps) => {
const onClick = useCallback(() => {
@@ -123,19 +128,17 @@ const NoOptionsFallback = memo(({ noOptionsText }: { noOptionsText?: string }) =
});
NoOptionsFallback.displayName = 'NoOptionsFallback';
const getGroupIDFromModelConfig = (modelConfig: AnyModelConfig): string => {
return modelConfig.base;
};
const getGroupIDFromModelConfig = (modelConfig: AnyModelConfigWithExternal): string => modelConfig.base;
const getGroupNameFromModelConfig = (modelConfig: AnyModelConfig): string => {
const getGroupNameFromModelConfig = (modelConfig: AnyModelConfigWithExternal): string => {
return MODEL_BASE_TO_LONG_NAME[modelConfig.base];
};
const getGroupShortNameFromModelConfig = (modelConfig: AnyModelConfig): string => {
const getGroupShortNameFromModelConfig = (modelConfig: AnyModelConfigWithExternal): string => {
return MODEL_BASE_TO_SHORT_NAME[modelConfig.base];
};
const getGroupColorSchemeFromModelConfig = (modelConfig: AnyModelConfig): string => {
const getGroupColorSchemeFromModelConfig = (modelConfig: AnyModelConfigWithExternal): string => {
return MODEL_BASE_TO_COLOR[modelConfig.base];
};
@@ -162,7 +165,7 @@ const removeStarred = <T,>(obj: WithStarred<T>): T => {
};
export const ModelPicker = typedMemo(
<T extends AnyModelConfig = AnyModelConfig>({
<T extends AnyModelConfigWithExternal = AnyModelConfigWithExternal>({
pickerId,
modelConfigs,
selectedModelConfig,
@@ -414,8 +417,10 @@ const optionNameSx: SystemStyleObject = {
};
const PickerOptionComponent = typedMemo(
<T extends AnyModelConfig>({ option, ...rest }: { option: WithStarred<T> } & BoxProps) => {
<T extends AnyModelConfigWithExternal>({ option, ...rest }: { option: WithStarred<T> } & BoxProps) => {
const { isCompactView } = usePickerContext<WithStarred<T>>();
const externalOption = isExternalApiModelConfig(option) ? (option as ExternalApiModelConfig) : null;
const providerLabel = externalOption ? externalOption.provider_id.toUpperCase() : null;
return (
<Flex {...rest} sx={optionSx} data-is-compact={isCompactView}>
@@ -426,6 +431,15 @@ const PickerOptionComponent = typedMemo(
<Text className="picker-option" sx={optionNameSx} data-is-compact={isCompactView}>
{option.name}
</Text>
{!isCompactView && externalOption && (
<Badge
colorScheme={MODEL_BASE_TO_COLOR[externalOption.base as BaseModelType]}
variant="subtle"
flexShrink={0}
>
{providerLabel}
</Badge>
)}
<Spacer />
{option.file_size > 0 && (
<Text
@@ -458,11 +472,13 @@ const BASE_KEYWORDS: { [key in BaseModelType]?: string[] } = {
'sd-3': ['sd3', 'sd3.0', 'sd3.5', 'sd-3'],
};
const isMatch = <T extends AnyModelConfig>(model: WithStarred<T>, searchTerm: string) => {
const isMatch = <T extends AnyModelConfigWithExternal>(model: WithStarred<T>, searchTerm: string) => {
const regex = getRegex(searchTerm);
const bases = BASE_KEYWORDS[model.base] ?? [model.base];
const externalModel = isExternalApiModelConfig(model) ? (model as ExternalApiModelConfig) : null;
const externalSearch = externalModel ? ` ${externalModel.provider_id} ${externalModel.provider_model_id}` : '';
const testString =
`${model.name} ${bases.join(' ')} ${model.type} ${model.description ?? ''} ${model.format}`.toLowerCase();
`${model.name} ${bases.join(' ')} ${model.type} ${model.description ?? ''} ${model.format}${externalSearch}`.toLowerCase();
if (testString.includes(searchTerm) || regex.test(testString)) {
return true;

View File

@@ -2,13 +2,19 @@ import { CompositeNumberInput, FormControl, FormLabel } from '@invoke-ai/ui-libr
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from 'app/constants';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { selectSeed, selectShouldRandomizeSeed, setSeed } from 'features/controlLayers/store/paramsSlice';
import {
selectSeed,
selectSeedControl,
selectShouldRandomizeSeed,
setSeed,
} from 'features/controlLayers/store/paramsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export const ParamSeedNumberInput = memo(() => {
const seed = useAppSelector(selectSeed);
const shouldRandomizeSeed = useAppSelector(selectShouldRandomizeSeed);
const externalControl = useAppSelector(selectSeedControl);
const { t } = useTranslation();
@@ -22,9 +28,10 @@ export const ParamSeedNumberInput = memo(() => {
<FormLabel>{t('parameters.seed')}</FormLabel>
</InformationalPopover>
<CompositeNumberInput
step={1}
min={NUMPY_RAND_MIN}
max={NUMPY_RAND_MAX}
step={externalControl?.coarse_step ?? 1}
fineStep={externalControl?.fine_step ?? undefined}
min={externalControl?.number_input_min ?? NUMPY_RAND_MIN}
max={externalControl?.number_input_max ?? NUMPY_RAND_MAX}
onChange={handleChangeSeed}
value={seed}
flexGrow={1}

View File

@@ -3,6 +3,7 @@ import { roundToMultiple } from 'common/util/roundDownToMultiple';
import { buildZodTypeGuard } from 'common/util/zodUtils';
import {
zAnimaSchedulerField,
zExternalModelIdentifierField,
zFluxDypeExponentField,
zFluxDypePresetField,
zFluxDypeScaleField,
@@ -121,7 +122,7 @@ export const isParameterHeight = isParameterImageDimension;
// #endregion
// #region Model
export const zParameterModel = zModelIdentifierField;
export const zParameterModel = z.union([zModelIdentifierField, zExternalModelIdentifierField]);
export type ParameterModel = z.infer<typeof zParameterModel>;
// #endregion

View File

@@ -0,0 +1,33 @@
import type {
ExternalApiModelConfig,
ExternalModelCapabilities,
ExternalModelPanelControl,
ExternalModelPanelSchema,
ExternalPanelControlName,
} from 'services/api/types';
type ExternalPanelName = keyof ExternalModelPanelSchema;
const buildExternalPanelSchemaFromCapabilities = (
capabilities: ExternalModelCapabilities
): ExternalModelPanelSchema => ({
prompts: [...(capabilities.supports_reference_images ? [{ name: 'reference_images' as const }] : [])],
image: [{ name: 'dimensions' }, ...(capabilities.supports_seed ? [{ name: 'seed' as const }] : [])],
generation: [],
});
const getExternalPanelSchema = (modelConfig: ExternalApiModelConfig): ExternalModelPanelSchema =>
modelConfig.panel_schema ?? buildExternalPanelSchemaFromCapabilities(modelConfig.capabilities);
export const getExternalPanelControl = (
modelConfig: ExternalApiModelConfig,
panel: ExternalPanelName,
controlName: ExternalPanelControlName
): ExternalModelPanelControl | null =>
getExternalPanelSchema(modelConfig)[panel].find((control) => control.name === controlName) ?? null;
/**
 * Reports whether a panel of an external model's schema contains a control with the given name.
 *
 * @param modelConfig External model config whose schema is searched.
 * @param panel Panel (schema key) to search in.
 * @param controlName Name of the control to look for.
 * @returns `true` when `getExternalPanelControl` finds a control, `false` when it returns `null`.
 */
export const hasExternalPanelControl = (
  modelConfig: ExternalApiModelConfig,
  panel: ExternalPanelName,
  controlName: ExternalPanelControlName
): boolean => {
  const foundControl = getExternalPanelControl(modelConfig, panel, controlName);
  return foundControl !== null;
};

View File

@@ -7,9 +7,11 @@ import { withResult, withResultAsync } from 'common/util/result';
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { positivePromptAddedToHistory, selectPositivePrompt } from 'features/controlLayers/store/paramsSlice';
import type { BaseModelType } from 'features/nodes/types/common';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildAnimaGraph } from 'features/nodes/util/graph/generation/buildAnimaGraph';
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
import { buildExternalGraph } from 'features/nodes/util/graph/generation/buildExternalGraph';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildQwenImageGraph } from 'features/nodes/util/graph/generation/buildQwenImageGraph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
@@ -63,6 +65,8 @@ const enqueueCanvas = async (store: AppStore, canvasManager: CanvasManager, prep
return await buildQwenImageGraph(graphBuilderArg);
case 'z-image':
return await buildZImageGraph(graphBuilderArg);
case 'external':
return await buildExternalGraph(graphBuilderArg);
case 'anima':
return await buildAnimaGraph(graphBuilderArg);
default:
@@ -97,7 +101,7 @@ const enqueueCanvas = async (store: AppStore, canvasManager: CanvasManager, prep
prepareLinearUIBatch({
state,
g,
base,
base: base as BaseModelType,
prepend,
seedNode: seed,
positivePromptNode: positivePrompt,

View File

@@ -5,9 +5,11 @@ import { useAppStore } from 'app/store/storeHooks';
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
import { withResult, withResultAsync } from 'common/util/result';
import { positivePromptAddedToHistory, selectPositivePrompt } from 'features/controlLayers/store/paramsSlice';
import type { BaseModelType } from 'features/nodes/types/common';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildAnimaGraph } from 'features/nodes/util/graph/generation/buildAnimaGraph';
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
import { buildExternalGraph } from 'features/nodes/util/graph/generation/buildExternalGraph';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildQwenImageGraph } from 'features/nodes/util/graph/generation/buildQwenImageGraph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
@@ -56,6 +58,8 @@ const enqueueGenerate = async (store: AppStore, prepend: boolean) => {
return await buildQwenImageGraph(graphBuilderArg);
case 'z-image':
return await buildZImageGraph(graphBuilderArg);
case 'external':
return await buildExternalGraph(graphBuilderArg);
case 'anima':
return await buildAnimaGraph(graphBuilderArg);
default:
@@ -90,7 +94,7 @@ const enqueueGenerate = async (store: AppStore, prepend: boolean) => {
prepareLinearUIBatch({
state,
g,
base,
base: base as BaseModelType,
prepend,
seedNode: seed,
positivePromptNode: positivePrompt,

View File

@@ -2,6 +2,7 @@ import { logger } from 'app/logging/logger';
import type { AppStore } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import { positivePromptAddedToHistory, selectPositivePrompt } from 'features/controlLayers/store/paramsSlice';
import type { BaseModelType } from 'features/nodes/types/common';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildMultidiffusionUpscaleGraph } from 'features/nodes/util/graph/buildMultidiffusionUpscaleGraph';
import { useCallback } from 'react';
@@ -26,7 +27,7 @@ const enqueueUpscaling = async (store: AppStore, prepend: boolean) => {
const batchConfig = prepareLinearUIBatch({
state,
g,
base,
base: base as BaseModelType,
prepend,
seedNode: seed,
positivePromptNode: positivePrompt,

Some files were not shown because too many files have changed in this diff Show More