mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Compare commits
44 Commits
lstein/rec
...
external-m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7938d840b2 | ||
|
|
450ba7b7e1 | ||
|
|
c743106f66 | ||
|
|
cd888654d5 | ||
|
|
ec4b87b949 | ||
|
|
8f00759af0 | ||
|
|
5c09c823a9 | ||
|
|
ec90b2fbe9 | ||
|
|
17157d7c60 | ||
|
|
853c3ef915 | ||
|
|
3e9e052d5d | ||
|
|
089e2db402 | ||
|
|
4cbd60b4a5 | ||
|
|
c2016bcfb7 | ||
|
|
813a5e2c2e | ||
|
|
18315db7f0 | ||
|
|
edde0b4737 | ||
|
|
27fc650f4f | ||
|
|
a1eef791a1 | ||
|
|
d8d0ebc356 | ||
|
|
8375f95ea9 | ||
|
|
9e4d0bb191 | ||
|
|
20a400cee8 | ||
|
|
40f02aa6c4 | ||
|
|
c3a482e80a | ||
|
|
257994f552 | ||
|
|
bafce41856 | ||
|
|
757bd3d002 | ||
|
|
519575e871 | ||
|
|
f39456e6f0 | ||
|
|
689725c6e4 | ||
|
|
10729f40f2 | ||
|
|
362054120e | ||
|
|
b91a156a3d | ||
|
|
c6b0d45c5f | ||
|
|
dc665e08ac | ||
|
|
0dd72837d3 | ||
|
|
d5a6283f23 | ||
|
|
6fe1a6f1ac | ||
|
|
5d34eab6f0 | ||
|
|
1b43769b95 | ||
|
|
a9d3b4e17c | ||
|
|
74ecc461b9 | ||
|
|
19650f6ada |
129
docs/contributing/EXTERNAL_PROVIDERS.md
Normal file
129
docs/contributing/EXTERNAL_PROVIDERS.md
Normal file
@@ -0,0 +1,129 @@
|
||||
# External Provider Integration
|
||||
|
||||
This guide covers:
|
||||
|
||||
1. Adding a new **external model** (most common; existing provider).
|
||||
2. Adding a brand-new **external provider** (adapter + config + UI wiring).
|
||||
|
||||
## 1) Add a New External Model (Existing Provider)
|
||||
|
||||
For provider-backed models (for example, OpenAI or Gemini), the source of truth is
|
||||
`invokeai/backend/model_manager/starter_models.py`.
|
||||
|
||||
### Required model fields
|
||||
|
||||
Define a `StarterModel` with:
|
||||
|
||||
- `base=BaseModelType.External`
|
||||
- `type=ModelType.ExternalImageGenerator`
|
||||
- `format=ModelFormat.ExternalApi`
|
||||
- `source="external://<provider_id>/<provider_model_id>"`
|
||||
- `name`, `description`
|
||||
- `capabilities=ExternalModelCapabilities(...)`
|
||||
- optional `default_settings=ExternalApiModelDefaultSettings(...)`
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
new_external_model = StarterModel(
|
||||
name="Provider Model Name",
|
||||
base=BaseModelType.External,
|
||||
source="external://openai/my-model-id",
|
||||
description=(
|
||||
"Provider model (external API). "
|
||||
"Requires a configured OpenAI API key and may incur provider usage costs."
|
||||
),
|
||||
type=ModelType.ExternalImageGenerator,
|
||||
format=ModelFormat.ExternalApi,
|
||||
capabilities=ExternalModelCapabilities(
|
||||
modes=["txt2img", "img2img", "inpaint"],
|
||||
supports_negative_prompt=False,
|
||||
supports_seed=False,
|
||||
supports_guidance=False,
|
||||
supports_steps=False,
|
||||
supports_reference_images=True,
|
||||
max_images_per_request=4,
|
||||
),
|
||||
default_settings=ExternalApiModelDefaultSettings(
|
||||
width=1024,
|
||||
height=1024,
|
||||
num_images=1,
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
Then append it to `STARTER_MODELS`.
|
||||
|
||||
### Required description text
|
||||
|
||||
External starter model descriptions must clearly state:
|
||||
|
||||
- an API key is required
|
||||
- usage may incur provider-side costs
|
||||
|
||||
### Capabilities must be accurate
|
||||
|
||||
These flags directly control UI visibility and request payload fields:
|
||||
|
||||
- `supports_negative_prompt`
|
||||
- `supports_seed`
|
||||
- `supports_guidance`
|
||||
- `supports_steps`
|
||||
- `supports_reference_images`
|
||||
|
||||
`supports_steps` is especially important: if `False`, steps are hidden for that model and `steps` is sent as `null`.
|
||||
|
||||
### Source string stability
|
||||
|
||||
Starter overrides are matched by `source` (`external://provider/model-id`). Keep this stable:
|
||||
|
||||
- runtime capability/default overrides depend on it
|
||||
- installation detection in starter-model APIs depends on it
|
||||
|
||||
`STARTER_MODELS` enforces unique `source` values with an assertion.
|
||||
|
||||
### Install behavior notes
|
||||
|
||||
- External starter models are managed in **External Providers** setup (not the regular Starter Models tab).
|
||||
- External starter models auto-install when a provider is configured.
|
||||
- Removing a provider API key removes installed external models for that provider.
|
||||
|
||||
## 2) Credentials and Config
|
||||
|
||||
External provider API keys are stored separately from `invokeai.yaml`:
|
||||
|
||||
- default file: `~/invokeai/api_keys.yaml`
|
||||
- resolved path: `<INVOKEAI_ROOT>/api_keys.yaml`
|
||||
|
||||
Non-secret provider settings (for example base URL overrides) stay in `invokeai.yaml`.
|
||||
|
||||
Environment variables are still supported, e.g.:
|
||||
|
||||
- `INVOKEAI_EXTERNAL_GEMINI_API_KEY`
|
||||
- `INVOKEAI_EXTERNAL_OPENAI_API_KEY`
|
||||
|
||||
## 3) Add a New Provider (Only If Needed)
|
||||
|
||||
If your model uses a provider that is not already integrated:
|
||||
|
||||
1. Add config fields in `invokeai/app/services/config/config_default.py`
|
||||
`external_<provider>_api_key` and optional `external_<provider>_base_url`.
|
||||
2. Add provider field mapping in `invokeai/app/api/routers/app_info.py`
|
||||
(`EXTERNAL_PROVIDER_FIELDS`).
|
||||
3. Implement provider adapter in `invokeai/app/services/external_generation/providers/`
|
||||
by subclassing `ExternalProvider`.
|
||||
4. Register the provider in `invokeai/app/api/dependencies.py` when building
|
||||
`ExternalGenerationService`.
|
||||
5. Add starter model entries using `source="external://<provider>/<model-id>"`.
|
||||
6. Optional UI ordering tweak:
|
||||
`invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ExternalProviders/ExternalProvidersForm.tsx`
|
||||
(`PROVIDER_SORT_ORDER`).
|
||||
|
||||
## 4) Optional Manual Installation
|
||||
|
||||
You can also install external models directly via:
|
||||
|
||||
`POST /api/v2/models/install?source=external://<provider_id>/<provider_model_id>`
|
||||
|
||||
If omitted, `path`, `source`, and `hash` are auto-populated for external model configs.
|
||||
Set capabilities conservatively; the external generation service enforces capability checks at runtime.
|
||||
@@ -8,6 +8,10 @@ We welcome contributions, whether features, bug fixes, code cleanup, testing, co
|
||||
|
||||
If you’d like to help with development, please see our [development guide](contribution_guides/development.md).
|
||||
|
||||
## External Providers
|
||||
|
||||
If you are adding external image generation providers or configs, see our [external provider integration guide](EXTERNAL_PROVIDERS.md).
|
||||
|
||||
**New Contributors:** If you’re unfamiliar with contributing to open source projects, take a look at our [new contributor guide](contribution_guides/newContributorChecklist.md).
|
||||
|
||||
## Nodes
|
||||
|
||||
@@ -16,6 +16,9 @@ from invokeai.app.services.client_state_persistence.client_state_persistence_sql
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.download.download_default import DownloadQueueService
|
||||
from invokeai.app.services.events.events_fastapievents import FastAPIEventService
|
||||
from invokeai.app.services.external_generation.external_generation_default import ExternalGenerationService
|
||||
from invokeai.app.services.external_generation.providers import AlibabaCloudProvider, GeminiProvider, OpenAIProvider
|
||||
from invokeai.app.services.external_generation.startup import sync_configured_external_starter_models
|
||||
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
|
||||
from invokeai.app.services.image_records.image_records_sqlite import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images.images_default import ImageService
|
||||
@@ -149,13 +152,23 @@ class ApiDependencies:
|
||||
),
|
||||
)
|
||||
download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
|
||||
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
|
||||
model_record_service = ModelRecordServiceSQL(db=db, logger=logger)
|
||||
model_manager = ModelManagerService.build_model_manager(
|
||||
app_config=configuration,
|
||||
model_record_service=ModelRecordServiceSQL(db=db, logger=logger),
|
||||
model_record_service=model_record_service,
|
||||
download_queue=download_queue_service,
|
||||
events=events,
|
||||
)
|
||||
external_generation = ExternalGenerationService(
|
||||
providers={
|
||||
AlibabaCloudProvider.provider_id: AlibabaCloudProvider(app_config=configuration, logger=logger),
|
||||
GeminiProvider.provider_id: GeminiProvider(app_config=configuration, logger=logger),
|
||||
OpenAIProvider.provider_id: OpenAIProvider(app_config=configuration, logger=logger),
|
||||
},
|
||||
logger=logger,
|
||||
record_store=model_record_service,
|
||||
)
|
||||
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
|
||||
model_relationships = ModelRelationshipsService()
|
||||
model_relationship_records = SqliteModelRelationshipRecordStorage(db=db)
|
||||
names = SimpleNameService()
|
||||
@@ -188,6 +201,7 @@ class ApiDependencies:
|
||||
model_relationships=model_relationships,
|
||||
model_relationship_records=model_relationship_records,
|
||||
download_queue=download_queue_service,
|
||||
external_generation=external_generation,
|
||||
names=names,
|
||||
performance_statistics=performance_statistics,
|
||||
session_processor=session_processor,
|
||||
@@ -204,6 +218,16 @@ class ApiDependencies:
|
||||
)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
configured_external_providers = {
|
||||
provider_id
|
||||
for provider_id, status in external_generation.get_provider_statuses().items()
|
||||
if status.configured
|
||||
}
|
||||
sync_configured_external_starter_models(
|
||||
configured_provider_ids=configured_external_providers,
|
||||
model_manager=model_manager,
|
||||
logger=logger,
|
||||
)
|
||||
db.clean()
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,21 +1,30 @@
|
||||
import locale
|
||||
from enum import Enum
|
||||
from importlib.metadata import distributions
|
||||
from pathlib import Path as FilePath
|
||||
from threading import Lock
|
||||
|
||||
import torch
|
||||
from fastapi import Body
|
||||
import yaml
|
||||
from fastapi import Body, HTTPException, Path
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.api.auth_dependencies import AdminUserOrDefault
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.config.config_default import (
|
||||
EXTERNAL_PROVIDER_CONFIG_FIELDS,
|
||||
DefaultInvokeAIAppConfig,
|
||||
InvokeAIAppConfig,
|
||||
get_config,
|
||||
load_and_migrate_config,
|
||||
load_external_api_keys,
|
||||
)
|
||||
from invokeai.app.services.external_generation.external_generation_common import ExternalProviderStatus
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
|
||||
from invokeai.app.services.model_records.model_records_base import UnknownModelException
|
||||
from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
from invokeai.backend.util.logging import logging
|
||||
from invokeai.version import __version__
|
||||
|
||||
@@ -47,7 +56,7 @@ async def get_version() -> AppVersion:
|
||||
async def get_app_deps() -> dict[str, str]:
|
||||
deps: dict[str, str] = {dist.metadata["Name"]: dist.version for dist in distributions()}
|
||||
try:
|
||||
cuda = torch.version.cuda or "N/A"
|
||||
cuda = getattr(getattr(torch, "version", None), "cuda", None) or "N/A" # pyright: ignore[reportAttributeAccessIssue]
|
||||
except Exception:
|
||||
cuda = "N/A"
|
||||
|
||||
@@ -70,6 +79,31 @@ class InvokeAIAppConfigWithSetFields(BaseModel):
|
||||
config: InvokeAIAppConfig = Field(description="The InvokeAI App Config")
|
||||
|
||||
|
||||
class ExternalProviderStatusModel(BaseModel):
|
||||
provider_id: str = Field(description="The external provider identifier")
|
||||
configured: bool = Field(description="Whether credentials are configured for the provider")
|
||||
message: str | None = Field(default=None, description="Optional provider status detail")
|
||||
|
||||
|
||||
class ExternalProviderConfigUpdate(BaseModel):
|
||||
api_key: str | None = Field(default=None, description="API key for the external provider")
|
||||
base_url: str | None = Field(default=None, description="Optional base URL override for the provider")
|
||||
|
||||
|
||||
class ExternalProviderConfigModel(BaseModel):
|
||||
provider_id: str = Field(description="The external provider identifier")
|
||||
api_key_configured: bool = Field(description="Whether an API key is configured")
|
||||
base_url: str | None = Field(default=None, description="Optional base URL override")
|
||||
|
||||
|
||||
EXTERNAL_PROVIDER_FIELDS: dict[str, tuple[str, str]] = {
|
||||
"alibabacloud": ("external_alibabacloud_api_key", "external_alibabacloud_base_url"),
|
||||
"gemini": ("external_gemini_api_key", "external_gemini_base_url"),
|
||||
"openai": ("external_openai_api_key", "external_openai_base_url"),
|
||||
}
|
||||
_EXTERNAL_PROVIDER_CONFIG_LOCK = Lock()
|
||||
|
||||
|
||||
class UpdateAppGenerationSettingsRequest(BaseModel):
|
||||
"""Writable generation-related app settings."""
|
||||
|
||||
@@ -112,6 +146,166 @@ async def update_runtime_config(
|
||||
return InvokeAIAppConfigWithSetFields(set_fields=config.model_fields_set, config=config)
|
||||
|
||||
|
||||
@app_router.get(
|
||||
"/external_providers/status",
|
||||
operation_id="get_external_provider_statuses",
|
||||
status_code=200,
|
||||
response_model=list[ExternalProviderStatusModel],
|
||||
)
|
||||
async def get_external_provider_statuses() -> list[ExternalProviderStatusModel]:
|
||||
statuses = ApiDependencies.invoker.services.external_generation.get_provider_statuses()
|
||||
return [status_to_model(status) for status in statuses.values()]
|
||||
|
||||
|
||||
@app_router.get(
|
||||
"/external_providers/config",
|
||||
operation_id="get_external_provider_configs",
|
||||
status_code=200,
|
||||
response_model=list[ExternalProviderConfigModel],
|
||||
)
|
||||
async def get_external_provider_configs() -> list[ExternalProviderConfigModel]:
|
||||
config = get_config()
|
||||
return [_build_external_provider_config(provider_id, config) for provider_id in EXTERNAL_PROVIDER_FIELDS]
|
||||
|
||||
|
||||
@app_router.post(
|
||||
"/external_providers/config/{provider_id}",
|
||||
operation_id="set_external_provider_config",
|
||||
status_code=200,
|
||||
response_model=ExternalProviderConfigModel,
|
||||
)
|
||||
async def set_external_provider_config(
|
||||
provider_id: str = Path(description="The external provider identifier"),
|
||||
update: ExternalProviderConfigUpdate = Body(description="External provider configuration settings"),
|
||||
) -> ExternalProviderConfigModel:
|
||||
api_key_field, base_url_field = _get_external_provider_fields(provider_id)
|
||||
updates: dict[str, str | None] = {}
|
||||
|
||||
if update.api_key is not None:
|
||||
api_key = update.api_key.strip()
|
||||
updates[api_key_field] = api_key or None
|
||||
if update.base_url is not None:
|
||||
base_url = update.base_url.strip()
|
||||
updates[base_url_field] = base_url or None
|
||||
|
||||
if not updates:
|
||||
raise HTTPException(status_code=400, detail="No external provider config fields provided")
|
||||
|
||||
api_key_removed = update.api_key is not None and updates.get(api_key_field) is None
|
||||
_apply_external_provider_update(updates)
|
||||
if api_key_removed:
|
||||
_remove_external_models_for_provider(provider_id)
|
||||
return _build_external_provider_config(provider_id, get_config())
|
||||
|
||||
|
||||
@app_router.delete(
|
||||
"/external_providers/config/{provider_id}",
|
||||
operation_id="reset_external_provider_config",
|
||||
status_code=200,
|
||||
response_model=ExternalProviderConfigModel,
|
||||
)
|
||||
async def reset_external_provider_config(
|
||||
provider_id: str = Path(description="The external provider identifier"),
|
||||
) -> ExternalProviderConfigModel:
|
||||
api_key_field, base_url_field = _get_external_provider_fields(provider_id)
|
||||
_apply_external_provider_update({api_key_field: None, base_url_field: None})
|
||||
_remove_external_models_for_provider(provider_id)
|
||||
return _build_external_provider_config(provider_id, get_config())
|
||||
|
||||
|
||||
def status_to_model(status: ExternalProviderStatus) -> ExternalProviderStatusModel:
|
||||
return ExternalProviderStatusModel(
|
||||
provider_id=status.provider_id,
|
||||
configured=status.configured,
|
||||
message=status.message,
|
||||
)
|
||||
|
||||
|
||||
def _get_external_provider_fields(provider_id: str) -> tuple[str, str]:
|
||||
if provider_id not in EXTERNAL_PROVIDER_FIELDS:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown external provider '{provider_id}'")
|
||||
return EXTERNAL_PROVIDER_FIELDS[provider_id]
|
||||
|
||||
|
||||
def _write_external_api_keys_file(api_keys_file_path: FilePath, api_keys: dict[str, str]) -> None:
|
||||
if not api_keys:
|
||||
if api_keys_file_path.exists():
|
||||
api_keys_file_path.unlink()
|
||||
return
|
||||
|
||||
api_keys_file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(api_keys_file_path, "w", encoding=locale.getpreferredencoding()) as api_keys_file:
|
||||
yaml.safe_dump(api_keys, api_keys_file, sort_keys=False)
|
||||
|
||||
|
||||
def _apply_external_provider_update(updates: dict[str, str | None]) -> None:
|
||||
with _EXTERNAL_PROVIDER_CONFIG_LOCK:
|
||||
runtime_config = get_config()
|
||||
config_path = runtime_config.config_file_path
|
||||
api_keys_file_path = runtime_config.api_keys_file_path
|
||||
if config_path.exists():
|
||||
file_config = load_and_migrate_config(config_path)
|
||||
else:
|
||||
file_config = DefaultInvokeAIAppConfig()
|
||||
|
||||
runtime_config.update_config(updates)
|
||||
provider_config_fields = set(EXTERNAL_PROVIDER_CONFIG_FIELDS)
|
||||
provider_updates = {field: value for field, value in updates.items() if field in provider_config_fields}
|
||||
non_provider_updates = {field: value for field, value in updates.items() if field not in provider_config_fields}
|
||||
|
||||
if non_provider_updates:
|
||||
file_config.update_config(non_provider_updates)
|
||||
|
||||
persisted_api_keys = load_external_api_keys(api_keys_file_path)
|
||||
for field_name in EXTERNAL_PROVIDER_CONFIG_FIELDS:
|
||||
file_value = getattr(file_config, field_name, None)
|
||||
if field_name not in persisted_api_keys and isinstance(file_value, str) and file_value.strip():
|
||||
persisted_api_keys[field_name] = file_value
|
||||
|
||||
for field_name, value in provider_updates.items():
|
||||
if value is None:
|
||||
persisted_api_keys.pop(field_name, None)
|
||||
else:
|
||||
persisted_api_keys[field_name] = value
|
||||
|
||||
_write_external_api_keys_file(api_keys_file_path, persisted_api_keys)
|
||||
|
||||
for field_name in EXTERNAL_PROVIDER_CONFIG_FIELDS:
|
||||
setattr(file_config, field_name, None)
|
||||
|
||||
file_config_to_write = type(file_config).model_validate(
|
||||
file_config.model_dump(exclude_unset=True, exclude_none=True)
|
||||
)
|
||||
file_config_to_write.write_file(config_path, as_example=False)
|
||||
|
||||
|
||||
def _build_external_provider_config(provider_id: str, config: InvokeAIAppConfig) -> ExternalProviderConfigModel:
|
||||
api_key_field, base_url_field = _get_external_provider_fields(provider_id)
|
||||
return ExternalProviderConfigModel(
|
||||
provider_id=provider_id,
|
||||
api_key_configured=bool(getattr(config, api_key_field)),
|
||||
base_url=getattr(config, base_url_field),
|
||||
)
|
||||
|
||||
|
||||
def _remove_external_models_for_provider(provider_id: str) -> None:
|
||||
model_manager = ApiDependencies.invoker.services.model_manager
|
||||
external_models = model_manager.store.search_by_attr(
|
||||
base_model=BaseModelType.External,
|
||||
model_type=ModelType.ExternalImageGenerator,
|
||||
)
|
||||
|
||||
for model in external_models:
|
||||
if getattr(model, "provider_id", None) != provider_id:
|
||||
continue
|
||||
try:
|
||||
model_manager.install.delete(model.key)
|
||||
except UnknownModelException:
|
||||
logging.warning(f"External model key '{model.key}' was already removed while resetting '{provider_id}'")
|
||||
except Exception as error:
|
||||
logging.warning(f"Failed removing external model key '{model.key}' for '{provider_id}': {error}")
|
||||
|
||||
|
||||
@app_router.get(
|
||||
"/logging",
|
||||
operation_id="get_log_level",
|
||||
|
||||
@@ -30,6 +30,7 @@ from invokeai.app.services.model_records import (
|
||||
)
|
||||
from invokeai.app.services.orphaned_models import OrphanedModelInfo
|
||||
from invokeai.app.util.suppress_output import SuppressOutput
|
||||
from invokeai.backend.model_manager.configs.external_api import ExternalApiModelConfig
|
||||
from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory
|
||||
from invokeai.backend.model_manager.configs.main import (
|
||||
Main_Checkpoint_SD1_Config,
|
||||
@@ -75,8 +76,36 @@ class CacheType(str, Enum):
|
||||
def add_cover_image_to_model_config(config: AnyModelConfig, dependencies: Type[ApiDependencies]) -> AnyModelConfig:
|
||||
"""Add a cover image URL to a model configuration."""
|
||||
cover_image = dependencies.invoker.services.model_images.get_url(config.key)
|
||||
config.cover_image = cover_image
|
||||
return config
|
||||
return config.model_copy(update={"cover_image": cover_image})
|
||||
|
||||
|
||||
def apply_external_starter_model_overrides(config: AnyModelConfig) -> AnyModelConfig:
|
||||
"""Overlay starter-model metadata onto installed external model configs."""
|
||||
if not isinstance(config, ExternalApiModelConfig):
|
||||
return config
|
||||
|
||||
starter_match = next((starter for starter in STARTER_MODELS if starter.source == config.source), None)
|
||||
if starter_match is None:
|
||||
return config
|
||||
|
||||
model_updates: dict[str, object] = {}
|
||||
if starter_match.capabilities is not None:
|
||||
model_updates["capabilities"] = starter_match.capabilities
|
||||
if starter_match.default_settings is not None:
|
||||
model_updates["default_settings"] = starter_match.default_settings
|
||||
if starter_match.panel_schema is not None:
|
||||
model_updates["panel_schema"] = starter_match.panel_schema
|
||||
|
||||
if not model_updates:
|
||||
return config
|
||||
|
||||
return config.model_copy(update=model_updates)
|
||||
|
||||
|
||||
def prepare_model_config_for_response(config: AnyModelConfig, dependencies: Type[ApiDependencies]) -> AnyModelConfig:
|
||||
"""Apply API-only model config overlays before returning a response."""
|
||||
config = apply_external_starter_model_overrides(config)
|
||||
return add_cover_image_to_model_config(config, dependencies)
|
||||
|
||||
|
||||
##############################################################################
|
||||
@@ -145,8 +174,8 @@ async def list_model_records(
|
||||
found_models.extend(
|
||||
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
|
||||
)
|
||||
for model in found_models:
|
||||
model = add_cover_image_to_model_config(model, ApiDependencies)
|
||||
for index, model in enumerate(found_models):
|
||||
found_models[index] = prepare_model_config_for_response(model, ApiDependencies)
|
||||
return ModelsList(models=found_models)
|
||||
|
||||
|
||||
@@ -166,6 +195,8 @@ async def list_missing_models() -> ModelsList:
|
||||
|
||||
missing_models: list[AnyModelConfig] = []
|
||||
for model_config in record_store.all_models():
|
||||
if model_config.base == BaseModelType.External or model_config.format == ModelFormat.ExternalApi:
|
||||
continue
|
||||
if not (models_path / model_config.path).resolve().exists():
|
||||
missing_models.append(model_config)
|
||||
|
||||
@@ -190,7 +221,7 @@ async def get_model_records_by_attrs(
|
||||
if not configs:
|
||||
raise HTTPException(status_code=404, detail="No model found with these attributes")
|
||||
|
||||
return configs[0]
|
||||
return prepare_model_config_for_response(configs[0], ApiDependencies)
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
@@ -207,7 +238,7 @@ async def get_model_records_by_hash(
|
||||
if not configs:
|
||||
raise HTTPException(status_code=404, detail="No model found with this hash")
|
||||
|
||||
return configs[0]
|
||||
return prepare_model_config_for_response(configs[0], ApiDependencies)
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
@@ -228,7 +259,7 @@ async def get_model_record(
|
||||
"""Get a model record"""
|
||||
try:
|
||||
config = ApiDependencies.invoker.services.model_manager.store.get_model(key)
|
||||
return add_cover_image_to_model_config(config, ApiDependencies)
|
||||
return prepare_model_config_for_response(config, ApiDependencies)
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
@@ -268,7 +299,7 @@ async def reidentify_model(
|
||||
result.config.name = config.name
|
||||
result.config.description = config.description
|
||||
result.config.cover_image = config.cover_image
|
||||
if hasattr(config, "trigger_phrases") and hasattr(result.config, "trigger_phrases"):
|
||||
if hasattr(result.config, "trigger_phrases") and hasattr(config, "trigger_phrases"):
|
||||
result.config.trigger_phrases = config.trigger_phrases
|
||||
result.config.source = config.source
|
||||
result.config.source_type = config.source_type
|
||||
@@ -392,7 +423,7 @@ async def update_model_record(
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
try:
|
||||
config = record_store.update_model(key, changes=changes, allow_class_change=True)
|
||||
config = add_cover_image_to_model_config(config, ApiDependencies)
|
||||
config = prepare_model_config_for_response(config, ApiDependencies)
|
||||
logger.info(f"Updated model: {key}")
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@@ -1124,7 +1155,7 @@ async def convert_model(
|
||||
|
||||
# return the config record for the new diffusers directory
|
||||
new_config = store.get_model(new_key)
|
||||
new_config = add_cover_image_to_model_config(new_config, ApiDependencies)
|
||||
new_config = prepare_model_config_for_response(new_config, ApiDependencies)
|
||||
return new_config
|
||||
|
||||
|
||||
|
||||
203
invokeai/app/invocations/external_image_generation.py
Normal file
203
invokeai/app/invocations/external_image_generation.py
Normal file
@@ -0,0 +1,203 @@
|
||||
from typing import Any, ClassVar, Literal
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
InputField,
|
||||
MetadataField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageCollectionOutput
|
||||
from invokeai.app.services.external_generation.external_generation_common import (
|
||||
ExternalGenerationRequest,
|
||||
ExternalGenerationResult,
|
||||
ExternalReferenceImage,
|
||||
)
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.configs.external_api import ExternalApiModelConfig, ExternalGenerationMode
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
|
||||
|
||||
|
||||
class BaseExternalImageGenerationInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generate images using an external provider."""
|
||||
|
||||
provider_id: ClassVar[str | None] = None
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.main_model,
|
||||
ui_model_base=[BaseModelType.External],
|
||||
ui_model_type=[ModelType.ExternalImageGenerator],
|
||||
ui_model_format=[ModelFormat.ExternalApi],
|
||||
)
|
||||
mode: ExternalGenerationMode = InputField(default="txt2img", description="Generation mode")
|
||||
prompt: str = InputField(description="Prompt")
|
||||
seed: int | None = InputField(default=None, description=FieldDescriptions.seed)
|
||||
num_images: int = InputField(default=1, gt=0, description="Number of images to generate")
|
||||
width: int = InputField(default=1024, gt=0, description=FieldDescriptions.width)
|
||||
height: int = InputField(default=1024, gt=0, description=FieldDescriptions.height)
|
||||
image_size: str | None = InputField(default=None, description="Image size preset (e.g. 1K, 2K, 4K)")
|
||||
init_image: ImageField | None = InputField(default=None, description="Init image for img2img/inpaint")
|
||||
mask_image: ImageField | None = InputField(default=None, description="Mask image for inpaint")
|
||||
reference_images: list[ImageField] = InputField(default=[], description="Reference images")
|
||||
|
||||
def _build_provider_options(self) -> dict[str, Any] | None:
|
||||
"""Override in provider-specific subclasses to pass extra options."""
|
||||
return None
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
|
||||
model_config = context.models.get_config(self.model)
|
||||
if not isinstance(model_config, ExternalApiModelConfig):
|
||||
raise ValueError("Selected model is not an external API model")
|
||||
|
||||
if self.provider_id is not None and model_config.provider_id != self.provider_id:
|
||||
raise ValueError(
|
||||
f"Selected model provider '{model_config.provider_id}' does not match node provider '{self.provider_id}'"
|
||||
)
|
||||
|
||||
init_image = None
|
||||
if self.init_image is not None:
|
||||
init_image = context.images.get_pil(self.init_image.image_name, mode="RGB")
|
||||
|
||||
mask_image = None
|
||||
if self.mask_image is not None:
|
||||
mask_image = context.images.get_pil(self.mask_image.image_name, mode="L")
|
||||
|
||||
reference_images: list[ExternalReferenceImage] = []
|
||||
for image_field in self.reference_images:
|
||||
reference_image = context.images.get_pil(image_field.image_name, mode="RGB")
|
||||
reference_images.append(ExternalReferenceImage(image=reference_image))
|
||||
|
||||
request = ExternalGenerationRequest(
|
||||
model=model_config,
|
||||
mode=self.mode,
|
||||
prompt=self.prompt,
|
||||
seed=self.seed,
|
||||
num_images=self.num_images,
|
||||
width=self.width,
|
||||
height=self.height,
|
||||
image_size=self.image_size,
|
||||
init_image=init_image,
|
||||
mask_image=mask_image,
|
||||
reference_images=reference_images,
|
||||
metadata=self._build_request_metadata(),
|
||||
provider_options=self._build_provider_options(),
|
||||
)
|
||||
|
||||
result = context._services.external_generation.generate(request)
|
||||
|
||||
outputs: list[ImageField] = []
|
||||
for generated in result.images:
|
||||
metadata = self._build_output_metadata(model_config, result, generated.seed)
|
||||
image_dto = context.images.save(image=generated.image, metadata=metadata)
|
||||
outputs.append(ImageField(image_name=image_dto.image_name))
|
||||
|
||||
return ImageCollectionOutput(collection=outputs)
|
||||
|
||||
def _build_request_metadata(self) -> dict[str, Any] | None:
|
||||
if self.metadata is None:
|
||||
return None
|
||||
return self.metadata.root
|
||||
|
||||
def _build_output_metadata(
|
||||
self,
|
||||
model_config: ExternalApiModelConfig,
|
||||
result: ExternalGenerationResult,
|
||||
image_seed: int | None,
|
||||
) -> MetadataField | None:
|
||||
metadata: dict[str, Any] = {}
|
||||
|
||||
if self.metadata is not None:
|
||||
metadata.update(self.metadata.root)
|
||||
|
||||
metadata.update(
|
||||
{
|
||||
"external_provider": model_config.provider_id,
|
||||
"external_model_id": model_config.provider_model_id,
|
||||
}
|
||||
)
|
||||
|
||||
provider_request_id = getattr(result, "provider_request_id", None)
|
||||
if provider_request_id:
|
||||
metadata["external_request_id"] = provider_request_id
|
||||
|
||||
provider_metadata = getattr(result, "provider_metadata", None)
|
||||
if provider_metadata:
|
||||
metadata["external_provider_metadata"] = provider_metadata
|
||||
|
||||
if image_seed is not None:
|
||||
metadata["external_seed"] = image_seed
|
||||
|
||||
if not metadata:
|
||||
return None
|
||||
return MetadataField(root=metadata)
|
||||
|
||||
|
||||
@invocation(
|
||||
"external_image_generation",
|
||||
title="External Image Generation (Legacy)",
|
||||
tags=["external", "generation"],
|
||||
category="image",
|
||||
version="1.1.0",
|
||||
classification=Classification.Internal,
|
||||
)
|
||||
class ExternalImageGenerationInvocation(BaseExternalImageGenerationInvocation):
|
||||
"""Legacy external image generation node kept for backward compatibility."""
|
||||
|
||||
|
||||
@invocation(
|
||||
"openai_image_generation",
|
||||
title="OpenAI Image Generation",
|
||||
tags=["external", "generation", "openai"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class OpenAIImageGenerationInvocation(BaseExternalImageGenerationInvocation):
|
||||
"""Generate images using an OpenAI-hosted external model."""
|
||||
|
||||
provider_id = "openai"
|
||||
|
||||
quality: Literal["auto", "high", "medium", "low"] = InputField(default="auto", description="Output image quality")
|
||||
background: Literal["auto", "transparent", "opaque"] = InputField(
|
||||
default="auto", description="Background transparency handling"
|
||||
)
|
||||
input_fidelity: Literal["low", "high"] | None = InputField(
|
||||
default=None, description="Fidelity to source images (edits only)"
|
||||
)
|
||||
|
||||
def _build_provider_options(self) -> dict[str, Any]:
|
||||
options: dict[str, Any] = {
|
||||
"quality": self.quality,
|
||||
"background": self.background,
|
||||
}
|
||||
if self.input_fidelity is not None:
|
||||
options["input_fidelity"] = self.input_fidelity
|
||||
return options
|
||||
|
||||
|
||||
@invocation(
|
||||
"gemini_image_generation",
|
||||
title="Gemini Image Generation",
|
||||
tags=["external", "generation", "gemini"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class GeminiImageGenerationInvocation(BaseExternalImageGenerationInvocation):
|
||||
"""Generate images using a Gemini-hosted external model."""
|
||||
|
||||
provider_id = "gemini"
|
||||
|
||||
temperature: float | None = InputField(default=None, ge=0.0, le=2.0, description="Sampling temperature")
|
||||
thinking_level: Literal["minimal", "high"] | None = InputField(
|
||||
default=None, description="Thinking level for image generation"
|
||||
)
|
||||
|
||||
def _build_provider_options(self) -> dict[str, Any] | None:
|
||||
options: dict[str, Any] = {}
|
||||
if self.temperature is not None:
|
||||
options["temperature"] = self.temperature
|
||||
if self.thinking_level is not None:
|
||||
options["thinking_level"] = self.thinking_level
|
||||
return options or None
|
||||
@@ -22,6 +22,7 @@ from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
|
||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
|
||||
INIT_FILE = Path("invokeai.yaml")
|
||||
API_KEYS_FILE = Path("api_keys.yaml")
|
||||
DB_FILE = Path("invokeai.db")
|
||||
LEGACY_INIT_FILE = Path("invokeai.init")
|
||||
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
|
||||
@@ -30,6 +31,14 @@ ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8
|
||||
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
|
||||
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
|
||||
CONFIG_SCHEMA_VERSION = "4.0.2"
|
||||
EXTERNAL_PROVIDER_CONFIG_FIELDS = (
|
||||
"external_alibabacloud_api_key",
|
||||
"external_alibabacloud_base_url",
|
||||
"external_gemini_api_key",
|
||||
"external_gemini_base_url",
|
||||
"external_openai_api_key",
|
||||
"external_openai_base_url",
|
||||
)
|
||||
|
||||
|
||||
class URLRegexTokenPair(BaseModel):
|
||||
@@ -113,6 +122,10 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
allow_unknown_models: Allow installation of models that we are unable to identify. If enabled, models will be marked as `unknown` in the database, and will not have any metadata associated with them. If disabled, unknown models will be rejected during installation.
|
||||
multiuser: Enable multiuser support. When disabled, the application runs in single-user mode using a default system account with administrator privileges. When enabled, requires user authentication and authorization.
|
||||
strict_password_checking: Enforce strict password requirements. When True, passwords must contain uppercase, lowercase, and numbers. When False (default), any password is accepted but its strength (weak/moderate/strong) is reported to the user.
|
||||
external_gemini_api_key: API key for Gemini image generation.
|
||||
external_openai_api_key: API key for OpenAI image generation.
|
||||
external_gemini_base_url: Base URL override for Gemini image generation.
|
||||
external_openai_base_url: Base URL override for OpenAI image generation.
|
||||
"""
|
||||
|
||||
_root: Optional[Path] = PrivateAttr(default=None)
|
||||
@@ -211,6 +224,20 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
multiuser: bool = Field(default=False, description="Enable multiuser support. When disabled, the application runs in single-user mode using a default system account with administrator privileges. When enabled, requires user authentication and authorization.")
|
||||
strict_password_checking: bool = Field(default=False, description="Enforce strict password requirements. When True, passwords must contain uppercase, lowercase, and numbers. When False (default), any password is accepted but its strength (weak/moderate/strong) is reported to the user.")
|
||||
|
||||
# EXTERNAL PROVIDERS
|
||||
external_alibabacloud_api_key: Optional[str] = Field(default=None, description="API key for Alibaba Cloud DashScope image generation.")
|
||||
external_alibabacloud_base_url: Optional[str] = Field(
|
||||
default=None, description="Base URL override for Alibaba Cloud DashScope image generation."
|
||||
)
|
||||
external_gemini_api_key: Optional[str] = Field(default=None, description="API key for Gemini image generation.")
|
||||
external_openai_api_key: Optional[str] = Field(default=None, description="API key for OpenAI image generation.")
|
||||
external_gemini_base_url: Optional[str] = Field(
|
||||
default=None, description="Base URL override for Gemini image generation."
|
||||
)
|
||||
external_openai_base_url: Optional[str] = Field(
|
||||
default=None, description="Base URL override for OpenAI image generation."
|
||||
)
|
||||
|
||||
# fmt: on
|
||||
|
||||
model_config = SettingsConfigDict(env_prefix="INVOKEAI_", env_ignore_empty=True)
|
||||
@@ -292,6 +319,13 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
assert resolved_path is not None
|
||||
return resolved_path
|
||||
|
||||
@property
|
||||
def api_keys_file_path(self) -> Path:
|
||||
"""Path to api_keys.yaml, resolved to an absolute path.."""
|
||||
resolved_path = self._resolve(API_KEYS_FILE)
|
||||
assert resolved_path is not None
|
||||
return resolved_path
|
||||
|
||||
@property
|
||||
def outputs_path(self) -> Optional[Path]:
|
||||
"""Path to the outputs directory, resolved to an absolute path.."""
|
||||
@@ -504,6 +538,36 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
||||
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
|
||||
|
||||
|
||||
def load_external_api_keys(api_keys_file_path: Path) -> dict[str, str]:
|
||||
"""Load external provider config (API keys and base URLs) from a dedicated YAML file."""
|
||||
if not api_keys_file_path.exists():
|
||||
return {}
|
||||
|
||||
with open(api_keys_file_path, "rt", encoding=locale.getpreferredencoding()) as file:
|
||||
loaded_api_keys: Any = yaml.safe_load(file)
|
||||
|
||||
if loaded_api_keys is None:
|
||||
return {}
|
||||
|
||||
if not isinstance(loaded_api_keys, dict):
|
||||
raise RuntimeError(f"Failed to load api keys file {api_keys_file_path}: expected a mapping")
|
||||
|
||||
parsed_api_keys: dict[str, str] = {}
|
||||
for field_name in EXTERNAL_PROVIDER_CONFIG_FIELDS:
|
||||
value = loaded_api_keys.get(field_name)
|
||||
if value is None:
|
||||
continue
|
||||
if not isinstance(value, str):
|
||||
raise RuntimeError(
|
||||
f"Failed to load api keys file {api_keys_file_path}: value for '{field_name}' must be a string"
|
||||
)
|
||||
stripped_value = value.strip()
|
||||
if stripped_value:
|
||||
parsed_api_keys[field_name] = stripped_value
|
||||
|
||||
return parsed_api_keys
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_config() -> InvokeAIAppConfig:
|
||||
"""Get the global singleton app config.
|
||||
@@ -520,6 +584,7 @@ def get_config() -> InvokeAIAppConfig:
|
||||
"""
|
||||
# This object includes environment variables, as parsed by pydantic-settings
|
||||
config = InvokeAIAppConfig()
|
||||
env_fields_set = set(config.model_fields_set)
|
||||
|
||||
args = InvokeAIArgs.args
|
||||
|
||||
@@ -581,4 +646,11 @@ def get_config() -> InvokeAIAppConfig:
|
||||
default_config = DefaultInvokeAIAppConfig()
|
||||
default_config.write_file(config.config_file_path, as_example=False)
|
||||
|
||||
api_keys_from_file = load_external_api_keys(config.api_keys_file_path)
|
||||
if api_keys_from_file:
|
||||
# API keys file should take precedence over invokeai.yaml, but not over environment variables.
|
||||
api_keys_to_apply = {key: value for key, value in api_keys_from_file.items() if key not in env_fields_set}
|
||||
if api_keys_to_apply:
|
||||
config.update_config(api_keys_to_apply, clobber=True)
|
||||
|
||||
return config
|
||||
|
||||
23
invokeai/app/services/external_generation/__init__.py
Normal file
23
invokeai/app/services/external_generation/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from invokeai.app.services.external_generation.external_generation_base import (
|
||||
ExternalGenerationServiceBase,
|
||||
ExternalProvider,
|
||||
)
|
||||
from invokeai.app.services.external_generation.external_generation_common import (
|
||||
ExternalGeneratedImage,
|
||||
ExternalGenerationRequest,
|
||||
ExternalGenerationResult,
|
||||
ExternalProviderStatus,
|
||||
ExternalReferenceImage,
|
||||
)
|
||||
from invokeai.app.services.external_generation.external_generation_default import ExternalGenerationService
|
||||
|
||||
__all__ = [
|
||||
"ExternalGenerationRequest",
|
||||
"ExternalGenerationResult",
|
||||
"ExternalGeneratedImage",
|
||||
"ExternalGenerationService",
|
||||
"ExternalGenerationServiceBase",
|
||||
"ExternalProvider",
|
||||
"ExternalProviderStatus",
|
||||
"ExternalReferenceImage",
|
||||
]
|
||||
28
invokeai/app/services/external_generation/errors.py
Normal file
28
invokeai/app/services/external_generation/errors.py
Normal file
@@ -0,0 +1,28 @@
|
||||
class ExternalGenerationError(Exception):
|
||||
"""Base error for external generation."""
|
||||
|
||||
|
||||
class ExternalProviderNotFoundError(ExternalGenerationError):
|
||||
"""Raised when no provider is registered for a model."""
|
||||
|
||||
|
||||
class ExternalProviderNotConfiguredError(ExternalGenerationError):
|
||||
"""Raised when a provider is missing required credentials."""
|
||||
|
||||
|
||||
class ExternalProviderCapabilityError(ExternalGenerationError):
|
||||
"""Raised when a request is not supported by provider capabilities."""
|
||||
|
||||
|
||||
class ExternalProviderRequestError(ExternalGenerationError):
|
||||
"""Raised when a provider rejects the request or returns an error."""
|
||||
|
||||
|
||||
class ExternalProviderRateLimitError(ExternalProviderRequestError):
|
||||
"""Raised when a provider returns HTTP 429 (rate limit exceeded)."""
|
||||
|
||||
retry_after: float | None
|
||||
|
||||
def __init__(self, message: str, retry_after: float | None = None) -> None:
|
||||
super().__init__(message)
|
||||
self.retry_after = retry_after
|
||||
@@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.external_generation.external_generation_common import (
|
||||
ExternalGenerationRequest,
|
||||
ExternalGenerationResult,
|
||||
ExternalProviderStatus,
|
||||
)
|
||||
|
||||
|
||||
class ExternalProvider(ABC):
|
||||
provider_id: str
|
||||
|
||||
def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
|
||||
self._app_config = app_config
|
||||
self._logger = logger
|
||||
|
||||
@abstractmethod
|
||||
def is_configured(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def generate(self, request: ExternalGenerationRequest) -> ExternalGenerationResult:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_status(self) -> ExternalProviderStatus:
|
||||
return ExternalProviderStatus(provider_id=self.provider_id, configured=self.is_configured())
|
||||
|
||||
|
||||
class ExternalGenerationServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def generate(self, request: ExternalGenerationRequest) -> ExternalGenerationResult:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_provider_statuses(self) -> dict[str, ExternalProviderStatus]:
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.backend.model_manager.configs.external_api import ExternalApiModelConfig, ExternalGenerationMode
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExternalReferenceImage:
|
||||
image: PILImageType
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExternalGenerationRequest:
|
||||
model: ExternalApiModelConfig
|
||||
mode: ExternalGenerationMode
|
||||
prompt: str
|
||||
seed: int | None
|
||||
num_images: int
|
||||
width: int
|
||||
height: int
|
||||
image_size: str | None
|
||||
init_image: PILImageType | None
|
||||
mask_image: PILImageType | None
|
||||
reference_images: list[ExternalReferenceImage]
|
||||
metadata: dict[str, Any] | None
|
||||
provider_options: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExternalGeneratedImage:
|
||||
image: PILImageType
|
||||
seed: int | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExternalGenerationResult:
|
||||
images: list[ExternalGeneratedImage]
|
||||
seed_used: int | None = None
|
||||
provider_request_id: str | None = None
|
||||
provider_metadata: dict[str, Any] | None = None
|
||||
content_filters: dict[str, str] | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExternalProviderStatus:
|
||||
provider_id: str
|
||||
configured: bool
|
||||
message: str | None = None
|
||||
@@ -0,0 +1,353 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from logging import Logger
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from PIL import Image
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.services.external_generation.errors import (
|
||||
ExternalProviderCapabilityError,
|
||||
ExternalProviderNotConfiguredError,
|
||||
ExternalProviderNotFoundError,
|
||||
ExternalProviderRateLimitError,
|
||||
)
|
||||
from invokeai.app.services.external_generation.external_generation_base import (
|
||||
ExternalGenerationServiceBase,
|
||||
ExternalProvider,
|
||||
)
|
||||
from invokeai.app.services.external_generation.external_generation_common import (
|
||||
ExternalGeneratedImage,
|
||||
ExternalGenerationRequest,
|
||||
ExternalGenerationResult,
|
||||
ExternalProviderStatus,
|
||||
)
|
||||
from invokeai.backend.model_manager.configs.external_api import ExternalApiModelConfig, ExternalImageSize
|
||||
from invokeai.backend.model_manager.starter_models import STARTER_MODELS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
|
||||
|
||||
class ExternalGenerationService(ExternalGenerationServiceBase):
|
||||
def __init__(
|
||||
self,
|
||||
providers: dict[str, ExternalProvider],
|
||||
logger: Logger,
|
||||
record_store: ModelRecordServiceBase | None = None,
|
||||
) -> None:
|
||||
self._providers = providers
|
||||
self._logger = logger
|
||||
self._record_store = record_store
|
||||
|
||||
def generate(self, request: ExternalGenerationRequest) -> ExternalGenerationResult:
|
||||
provider = self._providers.get(request.model.provider_id)
|
||||
if provider is None:
|
||||
raise ExternalProviderNotFoundError(f"No external provider registered for '{request.model.provider_id}'")
|
||||
|
||||
if not provider.is_configured():
|
||||
raise ExternalProviderNotConfiguredError(f"Provider '{request.model.provider_id}' is missing credentials")
|
||||
|
||||
request = self._refresh_model_capabilities(request)
|
||||
resize_to_original_inpaint_size = _get_resize_target_for_inpaint(request)
|
||||
request = self._bucket_request(request)
|
||||
|
||||
self._validate_request(request)
|
||||
result = self._generate_with_retry(provider, request)
|
||||
|
||||
if resize_to_original_inpaint_size is None:
|
||||
return result
|
||||
|
||||
width, height = resize_to_original_inpaint_size
|
||||
return _resize_result_images(result, width, height)
|
||||
|
||||
_MAX_RETRIES = 3
|
||||
_DEFAULT_RETRY_DELAY = 10.0
|
||||
_MAX_RETRY_DELAY = 60.0
|
||||
|
||||
def _generate_with_retry(
|
||||
self, provider: ExternalProvider, request: ExternalGenerationRequest
|
||||
) -> ExternalGenerationResult:
|
||||
for attempt in range(self._MAX_RETRIES):
|
||||
try:
|
||||
return provider.generate(request)
|
||||
except ExternalProviderRateLimitError as exc:
|
||||
if attempt == self._MAX_RETRIES - 1:
|
||||
raise
|
||||
delay = min(exc.retry_after or self._DEFAULT_RETRY_DELAY, self._MAX_RETRY_DELAY)
|
||||
self._logger.warning(
|
||||
"Rate limited by %s (attempt %d/%d), retrying in %.0fs",
|
||||
request.model.provider_id,
|
||||
attempt + 1,
|
||||
self._MAX_RETRIES,
|
||||
delay,
|
||||
)
|
||||
time.sleep(delay)
|
||||
raise ExternalProviderRateLimitError("Rate limit exceeded after all retries")
|
||||
|
||||
def get_provider_statuses(self) -> dict[str, ExternalProviderStatus]:
|
||||
return {provider_id: provider.get_status() for provider_id, provider in self._providers.items()}
|
||||
|
||||
def _validate_request(self, request: ExternalGenerationRequest) -> None:
|
||||
capabilities = request.model.capabilities
|
||||
|
||||
self._logger.debug(
|
||||
"Validating external request provider=%s model=%s mode=%s supported=%s",
|
||||
request.model.provider_id,
|
||||
request.model.provider_model_id,
|
||||
request.mode,
|
||||
capabilities.modes,
|
||||
)
|
||||
|
||||
if request.mode not in capabilities.modes:
|
||||
raise ExternalProviderCapabilityError(f"Mode '{request.mode}' is not supported by {request.model.name}")
|
||||
|
||||
if request.seed is not None and not capabilities.supports_seed:
|
||||
raise ExternalProviderCapabilityError(f"Seed control is not supported by {request.model.name}")
|
||||
|
||||
if request.reference_images and not capabilities.supports_reference_images:
|
||||
raise ExternalProviderCapabilityError(f"Reference images are not supported by {request.model.name}")
|
||||
|
||||
if capabilities.max_reference_images is not None:
|
||||
if len(request.reference_images) > capabilities.max_reference_images:
|
||||
raise ExternalProviderCapabilityError(
|
||||
f"{request.model.name} supports at most {capabilities.max_reference_images} reference images"
|
||||
)
|
||||
|
||||
if capabilities.max_images_per_request is not None and request.num_images > capabilities.max_images_per_request:
|
||||
raise ExternalProviderCapabilityError(
|
||||
f"{request.model.name} supports at most {capabilities.max_images_per_request} images per request"
|
||||
)
|
||||
|
||||
if capabilities.max_image_size is not None:
|
||||
if request.width > capabilities.max_image_size.width or request.height > capabilities.max_image_size.height:
|
||||
raise ExternalProviderCapabilityError(
|
||||
f"{request.model.name} supports a maximum size of {capabilities.max_image_size.width}x{capabilities.max_image_size.height}"
|
||||
)
|
||||
|
||||
if capabilities.allowed_aspect_ratios:
|
||||
aspect_ratio = _format_aspect_ratio(request.width, request.height)
|
||||
if aspect_ratio not in capabilities.allowed_aspect_ratios:
|
||||
size_ratio = None
|
||||
if capabilities.aspect_ratio_sizes:
|
||||
size_ratio = _ratio_for_size(request.width, request.height, capabilities.aspect_ratio_sizes)
|
||||
if size_ratio is None or size_ratio not in capabilities.allowed_aspect_ratios:
|
||||
ratio_label = size_ratio or aspect_ratio
|
||||
raise ExternalProviderCapabilityError(
|
||||
f"{request.model.name} does not support aspect ratio {ratio_label}"
|
||||
)
|
||||
|
||||
required_modes = capabilities.input_image_required_for or ["img2img", "inpaint"]
|
||||
if request.mode in required_modes and request.init_image is None:
|
||||
raise ExternalProviderCapabilityError(
|
||||
f"Mode '{request.mode}' requires an init image for {request.model.name}"
|
||||
)
|
||||
|
||||
if request.mode == "inpaint" and request.mask_image is None:
|
||||
raise ExternalProviderCapabilityError(
|
||||
f"Mode '{request.mode}' requires a mask image for {request.model.name}"
|
||||
)
|
||||
|
||||
def _refresh_model_capabilities(self, request: ExternalGenerationRequest) -> ExternalGenerationRequest:
|
||||
if self._record_store is None:
|
||||
return request
|
||||
|
||||
try:
|
||||
record = self._record_store.get_model(request.model.key)
|
||||
except Exception:
|
||||
record = None
|
||||
|
||||
if not isinstance(record, ExternalApiModelConfig):
|
||||
return request
|
||||
|
||||
if record.key != request.model.key:
|
||||
return request
|
||||
|
||||
if record.provider_id != request.model.provider_id:
|
||||
return request
|
||||
|
||||
if record.provider_model_id != request.model.provider_model_id:
|
||||
return request
|
||||
|
||||
record = _apply_starter_overrides(record)
|
||||
|
||||
if record == request.model:
|
||||
return request
|
||||
|
||||
return ExternalGenerationRequest(
|
||||
model=record,
|
||||
mode=request.mode,
|
||||
prompt=request.prompt,
|
||||
seed=request.seed,
|
||||
num_images=request.num_images,
|
||||
width=request.width,
|
||||
height=request.height,
|
||||
image_size=request.image_size,
|
||||
init_image=request.init_image,
|
||||
mask_image=request.mask_image,
|
||||
reference_images=request.reference_images,
|
||||
metadata=request.metadata,
|
||||
provider_options=request.provider_options,
|
||||
)
|
||||
|
||||
def _bucket_request(self, request: ExternalGenerationRequest) -> ExternalGenerationRequest:
|
||||
capabilities = request.model.capabilities
|
||||
if not capabilities.allowed_aspect_ratios:
|
||||
return request
|
||||
|
||||
aspect_ratio = _format_aspect_ratio(request.width, request.height)
|
||||
size = None
|
||||
if capabilities.aspect_ratio_sizes:
|
||||
size = capabilities.aspect_ratio_sizes.get(aspect_ratio)
|
||||
|
||||
if size is not None:
|
||||
if request.width == size.width and request.height == size.height:
|
||||
return request
|
||||
return self._bucket_to_size(request, size.width, size.height, aspect_ratio)
|
||||
|
||||
if aspect_ratio in capabilities.allowed_aspect_ratios:
|
||||
return request
|
||||
|
||||
if not capabilities.aspect_ratio_sizes:
|
||||
return request
|
||||
|
||||
closest = _select_closest_ratio(
|
||||
request.width,
|
||||
request.height,
|
||||
capabilities.allowed_aspect_ratios,
|
||||
)
|
||||
if closest is None:
|
||||
return request
|
||||
|
||||
size = capabilities.aspect_ratio_sizes.get(closest)
|
||||
if size is None:
|
||||
return request
|
||||
|
||||
return self._bucket_to_size(request, size.width, size.height, closest)
|
||||
|
||||
def _bucket_to_size(
|
||||
self,
|
||||
request: ExternalGenerationRequest,
|
||||
width: int,
|
||||
height: int,
|
||||
ratio: str,
|
||||
) -> ExternalGenerationRequest:
|
||||
self._logger.info(
|
||||
"Bucketing external request provider=%s model=%s %sx%s -> %sx%s (ratio %s)",
|
||||
request.model.provider_id,
|
||||
request.model.provider_model_id,
|
||||
request.width,
|
||||
request.height,
|
||||
width,
|
||||
height,
|
||||
ratio,
|
||||
)
|
||||
|
||||
return ExternalGenerationRequest(
|
||||
model=request.model,
|
||||
mode=request.mode,
|
||||
prompt=request.prompt,
|
||||
seed=request.seed,
|
||||
num_images=request.num_images,
|
||||
width=width,
|
||||
height=height,
|
||||
image_size=request.image_size,
|
||||
init_image=_resize_image(request.init_image, width, height, "RGB"),
|
||||
mask_image=_resize_image(request.mask_image, width, height, "L"),
|
||||
reference_images=request.reference_images,
|
||||
metadata=request.metadata,
|
||||
provider_options=request.provider_options,
|
||||
)
|
||||
|
||||
|
||||
def _format_aspect_ratio(width: int, height: int) -> str:
|
||||
divisor = _gcd(width, height)
|
||||
return f"{width // divisor}:{height // divisor}"
|
||||
|
||||
|
||||
def _select_closest_ratio(width: int, height: int, ratios: list[str]) -> str | None:
|
||||
ratio = width / height
|
||||
parsed: list[tuple[str, float]] = []
|
||||
for value in ratios:
|
||||
parsed_ratio = _parse_ratio(value)
|
||||
if parsed_ratio is not None:
|
||||
parsed.append((value, parsed_ratio))
|
||||
if not parsed:
|
||||
return None
|
||||
return min(parsed, key=lambda item: abs(item[1] - ratio))[0]
|
||||
|
||||
|
||||
def _ratio_for_size(width: int, height: int, sizes: dict[str, ExternalImageSize]) -> str | None:
|
||||
for ratio, size in sizes.items():
|
||||
if size.width == width and size.height == height:
|
||||
return ratio
|
||||
return None
|
||||
|
||||
|
||||
def _parse_ratio(value: str) -> float | None:
|
||||
if ":" not in value:
|
||||
return None
|
||||
left, right = value.split(":", 1)
|
||||
try:
|
||||
numerator = float(left)
|
||||
denominator = float(right)
|
||||
except ValueError:
|
||||
return None
|
||||
if denominator == 0:
|
||||
return None
|
||||
return numerator / denominator
|
||||
|
||||
|
||||
def _gcd(a: int, b: int) -> int:
|
||||
while b:
|
||||
a, b = b, a % b
|
||||
return a
|
||||
|
||||
|
||||
def _resize_image(image: PILImageType | None, width: int, height: int, mode: str) -> PILImageType | None:
|
||||
if image is None:
|
||||
return None
|
||||
if image.width == width and image.height == height:
|
||||
return image
|
||||
return image.convert(mode).resize((width, height), Image.Resampling.LANCZOS)
|
||||
|
||||
|
||||
def _get_resize_target_for_inpaint(request: ExternalGenerationRequest) -> tuple[int, int] | None:
|
||||
if request.mode != "inpaint" or request.init_image is None:
|
||||
return None
|
||||
return request.init_image.width, request.init_image.height
|
||||
|
||||
|
||||
def _resize_result_images(result: ExternalGenerationResult, width: int, height: int) -> ExternalGenerationResult:
|
||||
resized_images = [
|
||||
ExternalGeneratedImage(
|
||||
image=generated.image
|
||||
if generated.image.width == width and generated.image.height == height
|
||||
else generated.image.resize((width, height), Image.Resampling.LANCZOS),
|
||||
seed=generated.seed,
|
||||
)
|
||||
for generated in result.images
|
||||
]
|
||||
return ExternalGenerationResult(
|
||||
images=resized_images,
|
||||
seed_used=result.seed_used,
|
||||
provider_request_id=result.provider_request_id,
|
||||
provider_metadata=result.provider_metadata,
|
||||
content_filters=result.content_filters,
|
||||
)
|
||||
|
||||
|
||||
def _apply_starter_overrides(model: ExternalApiModelConfig) -> ExternalApiModelConfig:
|
||||
source = model.source or f"external://{model.provider_id}/{model.provider_model_id}"
|
||||
starter_match = next((starter for starter in STARTER_MODELS if starter.source == source), None)
|
||||
if starter_match is None:
|
||||
return model
|
||||
updates: dict[str, object] = {}
|
||||
if starter_match.capabilities is not None:
|
||||
updates["capabilities"] = starter_match.capabilities
|
||||
if starter_match.default_settings is not None:
|
||||
updates["default_settings"] = starter_match.default_settings
|
||||
if not updates:
|
||||
return model
|
||||
return model.model_copy(update=updates)
|
||||
19
invokeai/app/services/external_generation/image_utils.py
Normal file
19
invokeai/app/services/external_generation/image_utils.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
|
||||
from PIL import Image
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
|
||||
def encode_image_base64(image: PILImageType, format: str = "PNG") -> str:
|
||||
buffer = io.BytesIO()
|
||||
image.save(buffer, format=format)
|
||||
return base64.b64encode(buffer.getvalue()).decode("ascii")
|
||||
|
||||
|
||||
def decode_image_base64(encoded: str) -> PILImageType:
|
||||
data = base64.b64decode(encoded)
|
||||
image = Image.open(io.BytesIO(data))
|
||||
return image.convert("RGB")
|
||||
@@ -0,0 +1,5 @@
|
||||
from invokeai.app.services.external_generation.providers.alibabacloud import AlibabaCloudProvider
|
||||
from invokeai.app.services.external_generation.providers.gemini import GeminiProvider
|
||||
from invokeai.app.services.external_generation.providers.openai import OpenAIProvider
|
||||
|
||||
__all__ = ["AlibabaCloudProvider", "GeminiProvider", "OpenAIProvider"]
|
||||
@@ -0,0 +1,309 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import time
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.services.external_generation.errors import ExternalProviderRequestError
|
||||
from invokeai.app.services.external_generation.external_generation_base import ExternalProvider
|
||||
from invokeai.app.services.external_generation.external_generation_common import (
|
||||
ExternalGeneratedImage,
|
||||
ExternalGenerationRequest,
|
||||
ExternalGenerationResult,
|
||||
)
|
||||
from invokeai.app.services.external_generation.image_utils import decode_image_base64, encode_image_base64
|
||||
|
||||
# Models that support the synchronous multimodal-generation endpoint with messages format
|
||||
_SYNC_MODELS = {
|
||||
"qwen-image-2.0-pro",
|
||||
"qwen-image-2.0",
|
||||
"qwen-image-max",
|
||||
"wan2.6-t2i",
|
||||
"wan2.6-image",
|
||||
"qwen-image-edit-max",
|
||||
}
|
||||
|
||||
# Models that use the async image-generation endpoint with flat prompt format
|
||||
_ASYNC_MODELS = {
|
||||
"qwen-image-plus",
|
||||
"qwen-image",
|
||||
"qwen-image-edit-plus",
|
||||
"qwen-image-edit",
|
||||
"wan2.5-t2i-preview",
|
||||
"wan2.2-t2i-flash",
|
||||
"wanx2.0-t2i-turbo",
|
||||
}
|
||||
|
||||
# Models that support image editing (accept input images)
|
||||
_EDIT_MODELS = {
|
||||
"wan2.6-image",
|
||||
"qwen-image-edit-max",
|
||||
"qwen-image-edit-plus",
|
||||
"qwen-image-edit",
|
||||
}
|
||||
|
||||
_TASK_POLL_INTERVAL = 5 # seconds
|
||||
_TASK_POLL_TIMEOUT = 300 # seconds
|
||||
|
||||
|
||||
class AlibabaCloudProvider(ExternalProvider):
|
||||
provider_id = "alibabacloud"
|
||||
|
||||
def is_configured(self) -> bool:
|
||||
return bool(self._app_config.external_alibabacloud_api_key)
|
||||
|
||||
def generate(self, request: ExternalGenerationRequest) -> ExternalGenerationResult:
|
||||
api_key = self._app_config.external_alibabacloud_api_key
|
||||
if not api_key:
|
||||
raise ExternalProviderRequestError("Alibaba Cloud DashScope API key is not configured")
|
||||
|
||||
base_url = (
|
||||
self._app_config.external_alibabacloud_base_url or "https://dashscope-intl.aliyuncs.com"
|
||||
).rstrip("/")
|
||||
model_id = request.model.provider_model_id
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
}
|
||||
size = f"{request.width}*{request.height}"
|
||||
|
||||
if model_id in _SYNC_MODELS or model_id not in _ASYNC_MODELS:
|
||||
return self._generate_sync(request, base_url, headers, model_id, size)
|
||||
else:
|
||||
return self._generate_async(request, base_url, headers, model_id, size)
|
||||
|
||||
def _generate_sync(
|
||||
self,
|
||||
request: ExternalGenerationRequest,
|
||||
base_url: str,
|
||||
headers: dict[str, str],
|
||||
model_id: str,
|
||||
size: str,
|
||||
) -> ExternalGenerationResult:
|
||||
"""Use the synchronous multimodal-generation endpoint (messages format)."""
|
||||
endpoint = f"{base_url}/api/v1/services/aigc/multimodal-generation/generation"
|
||||
|
||||
content: list[dict[str, str]] = []
|
||||
|
||||
# Add init image for editing
|
||||
if request.init_image is not None and model_id in _EDIT_MODELS:
|
||||
content.append({"image": f"data:image/png;base64,{encode_image_base64(request.init_image)}"})
|
||||
|
||||
# Add reference images
|
||||
for ref in request.reference_images:
|
||||
content.append({"image": f"data:image/png;base64,{encode_image_base64(ref.image)}"})
|
||||
|
||||
content.append({"text": request.prompt})
|
||||
|
||||
parameters: dict[str, object] = {
|
||||
"size": size,
|
||||
"n": request.num_images,
|
||||
"prompt_extend": False,
|
||||
"watermark": False,
|
||||
}
|
||||
if request.negative_prompt:
|
||||
parameters["negative_prompt"] = request.negative_prompt
|
||||
if request.seed is not None:
|
||||
parameters["seed"] = request.seed
|
||||
|
||||
payload: dict[str, object] = {
|
||||
"model": model_id,
|
||||
"input": {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": content,
|
||||
}
|
||||
]
|
||||
},
|
||||
"parameters": parameters,
|
||||
}
|
||||
|
||||
response = requests.post(endpoint, headers=headers, json=payload, timeout=120)
|
||||
|
||||
if not response.ok:
|
||||
raise ExternalProviderRequestError(
|
||||
f"DashScope request failed with status {response.status_code} for model '{model_id}': {response.text}"
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
request_id = data.get("request_id")
|
||||
return self._parse_sync_response(data, request, request_id)
|
||||
|
||||
def _generate_async(
|
||||
self,
|
||||
request: ExternalGenerationRequest,
|
||||
base_url: str,
|
||||
headers: dict[str, str],
|
||||
model_id: str,
|
||||
size: str,
|
||||
) -> ExternalGenerationResult:
|
||||
"""Use the async image-generation endpoint (flat prompt format) with task polling."""
|
||||
endpoint = f"{base_url}/api/v1/services/aigc/image-generation/generation"
|
||||
async_headers = {**headers, "X-DashScope-Async": "enable"}
|
||||
|
||||
parameters: dict[str, object] = {
|
||||
"size": size,
|
||||
"n": request.num_images,
|
||||
"prompt_extend": False,
|
||||
"watermark": False,
|
||||
}
|
||||
if request.negative_prompt:
|
||||
parameters["negative_prompt"] = request.negative_prompt
|
||||
if request.seed is not None:
|
||||
parameters["seed"] = request.seed
|
||||
|
||||
input_data: dict[str, object] = {"prompt": request.prompt}
|
||||
if request.negative_prompt:
|
||||
input_data["negative_prompt"] = request.negative_prompt
|
||||
|
||||
payload: dict[str, object] = {
|
||||
"model": model_id,
|
||||
"input": input_data,
|
||||
"parameters": parameters,
|
||||
}
|
||||
|
||||
response = requests.post(endpoint, headers=async_headers, json=payload, timeout=60)
|
||||
|
||||
if not response.ok:
|
||||
raise ExternalProviderRequestError(
|
||||
f"DashScope async request failed with status {response.status_code} for model '{model_id}': {response.text}"
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
request_id = data.get("request_id")
|
||||
output = data.get("output", {})
|
||||
task_id = output.get("task_id")
|
||||
|
||||
if not task_id:
|
||||
raise ExternalProviderRequestError(f"DashScope async response missing task_id: {data}")
|
||||
|
||||
return self._poll_task(base_url, headers, task_id, request, request_id)
|
||||
|
||||
def _poll_task(
|
||||
self,
|
||||
base_url: str,
|
||||
headers: dict[str, str],
|
||||
task_id: str,
|
||||
request: ExternalGenerationRequest,
|
||||
request_id: str | None,
|
||||
) -> ExternalGenerationResult:
|
||||
"""Poll an async task until completion."""
|
||||
task_url = f"{base_url}/api/v1/tasks/{task_id}"
|
||||
start_time = time.monotonic()
|
||||
|
||||
while True:
|
||||
elapsed = time.monotonic() - start_time
|
||||
if elapsed > _TASK_POLL_TIMEOUT:
|
||||
raise ExternalProviderRequestError(
|
||||
f"DashScope task {task_id} timed out after {_TASK_POLL_TIMEOUT}s"
|
||||
)
|
||||
|
||||
time.sleep(_TASK_POLL_INTERVAL)
|
||||
|
||||
response = requests.get(task_url, headers={"Authorization": headers["Authorization"]}, timeout=30)
|
||||
if not response.ok:
|
||||
raise ExternalProviderRequestError(
|
||||
f"DashScope task poll failed with status {response.status_code}: {response.text}"
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
output = data.get("output", {})
|
||||
status = output.get("task_status")
|
||||
|
||||
if status == "SUCCEEDED":
|
||||
return self._parse_async_response(output, request, request_id)
|
||||
elif status in ("FAILED", "UNKNOWN"):
|
||||
message = output.get("message", "Unknown error")
|
||||
raise ExternalProviderRequestError(f"DashScope task {task_id} failed: {message}")
|
||||
|
||||
self._logger.debug("DashScope task %s status: %s (%.0fs elapsed)", task_id, status, elapsed)
|
||||
|
||||
def _parse_sync_response(
|
||||
self,
|
||||
data: dict[str, object],
|
||||
request: ExternalGenerationRequest,
|
||||
request_id: str | None,
|
||||
) -> ExternalGenerationResult:
|
||||
"""Parse the synchronous multimodal-generation response."""
|
||||
output = data.get("output")
|
||||
if not isinstance(output, dict):
|
||||
raise ExternalProviderRequestError(f"DashScope response missing output: {data}")
|
||||
|
||||
choices = output.get("choices")
|
||||
if not isinstance(choices, list):
|
||||
raise ExternalProviderRequestError(f"DashScope response missing choices: {data}")
|
||||
|
||||
images: list[ExternalGeneratedImage] = []
|
||||
for choice in choices:
|
||||
if not isinstance(choice, dict):
|
||||
continue
|
||||
message = choice.get("message")
|
||||
if not isinstance(message, dict):
|
||||
continue
|
||||
content = message.get("content")
|
||||
if not isinstance(content, list):
|
||||
continue
|
||||
for part in content:
|
||||
if not isinstance(part, dict):
|
||||
continue
|
||||
image_url = part.get("image")
|
||||
if isinstance(image_url, str) and image_url:
|
||||
pil_image = self._download_image(image_url)
|
||||
images.append(ExternalGeneratedImage(image=pil_image, seed=request.seed))
|
||||
|
||||
if not images:
|
||||
raise ExternalProviderRequestError(f"DashScope response contained no images: {data}")
|
||||
|
||||
return ExternalGenerationResult(
|
||||
images=images,
|
||||
seed_used=request.seed,
|
||||
provider_request_id=request_id,
|
||||
provider_metadata={"model": request.model.provider_model_id},
|
||||
)
|
||||
|
||||
def _parse_async_response(
|
||||
self,
|
||||
output: dict[str, object],
|
||||
request: ExternalGenerationRequest,
|
||||
request_id: str | None,
|
||||
) -> ExternalGenerationResult:
|
||||
"""Parse the async task completion response."""
|
||||
results = output.get("results")
|
||||
if not isinstance(results, list):
|
||||
raise ExternalProviderRequestError(f"DashScope async response missing results: {output}")
|
||||
|
||||
images: list[ExternalGeneratedImage] = []
|
||||
for result in results:
|
||||
if not isinstance(result, dict):
|
||||
continue
|
||||
url = result.get("url")
|
||||
if isinstance(url, str) and url:
|
||||
pil_image = self._download_image(url)
|
||||
images.append(ExternalGeneratedImage(image=pil_image, seed=request.seed))
|
||||
b64_image = result.get("b64_image")
|
||||
if isinstance(b64_image, str) and b64_image:
|
||||
pil_image = decode_image_base64(b64_image)
|
||||
images.append(ExternalGeneratedImage(image=pil_image, seed=request.seed))
|
||||
|
||||
if not images:
|
||||
raise ExternalProviderRequestError(f"DashScope async response contained no images: {output}")
|
||||
|
||||
return ExternalGenerationResult(
|
||||
images=images,
|
||||
seed_used=request.seed,
|
||||
provider_request_id=request_id,
|
||||
provider_metadata={"model": request.model.provider_model_id},
|
||||
)
|
||||
|
||||
def _download_image(self, url: str) -> PILImageType:
|
||||
"""Download an image from a URL and return it as a PIL Image."""
|
||||
response = requests.get(url, timeout=60)
|
||||
if not response.ok:
|
||||
raise ExternalProviderRequestError(
|
||||
f"Failed to download image from DashScope (status {response.status_code})"
|
||||
)
|
||||
return Image.open(io.BytesIO(response.content)).convert("RGB")
|
||||
282
invokeai/app/services/external_generation/providers/gemini.py
Normal file
282
invokeai/app/services/external_generation/providers/gemini.py
Normal file
@@ -0,0 +1,282 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
|
||||
import requests
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.services.external_generation.errors import (
|
||||
ExternalProviderRateLimitError,
|
||||
ExternalProviderRequestError,
|
||||
)
|
||||
from invokeai.app.services.external_generation.external_generation_base import ExternalProvider
|
||||
from invokeai.app.services.external_generation.external_generation_common import (
|
||||
ExternalGeneratedImage,
|
||||
ExternalGenerationRequest,
|
||||
ExternalGenerationResult,
|
||||
)
|
||||
from invokeai.app.services.external_generation.image_utils import decode_image_base64, encode_image_base64
|
||||
|
||||
|
||||
class GeminiProvider(ExternalProvider):
|
||||
provider_id = "gemini"
|
||||
_SYSTEM_INSTRUCTION = (
|
||||
"You are an image generation model. Always respond with an image based on the user's prompt. "
|
||||
"Do not return text-only responses. If the user input is not an edit instruction, "
|
||||
"interpret it as a request to create a new image."
|
||||
)
|
||||
|
||||
def is_configured(self) -> bool:
|
||||
return bool(self._app_config.external_gemini_api_key)
|
||||
|
||||
def generate(self, request: ExternalGenerationRequest) -> ExternalGenerationResult:
|
||||
api_key = self._app_config.external_gemini_api_key
|
||||
if not api_key:
|
||||
raise ExternalProviderRequestError("Gemini API key is not configured")
|
||||
|
||||
base_url = (self._app_config.external_gemini_base_url or "https://generativelanguage.googleapis.com").rstrip(
|
||||
"/"
|
||||
)
|
||||
if not base_url.endswith("/v1") and not base_url.endswith("/v1beta"):
|
||||
base_url = f"{base_url}/v1beta"
|
||||
model_id = request.model.provider_model_id.removeprefix("models/")
|
||||
endpoint = f"{base_url}/models/{model_id}:generateContent"
|
||||
|
||||
request_parts: list[dict[str, object]] = []
|
||||
|
||||
if request.init_image is not None:
|
||||
request_parts.append(
|
||||
{
|
||||
"inlineData": {
|
||||
"mimeType": "image/png",
|
||||
"data": encode_image_base64(request.init_image),
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
request_parts.append({"text": request.prompt})
|
||||
|
||||
for reference in request.reference_images:
|
||||
request_parts.append(
|
||||
{
|
||||
"inlineData": {
|
||||
"mimeType": "image/png",
|
||||
"data": encode_image_base64(reference.image),
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
opts = request.provider_options or {}
|
||||
|
||||
generation_config: dict[str, object] = {
|
||||
"candidateCount": request.num_images,
|
||||
"responseModalities": ["IMAGE"],
|
||||
}
|
||||
if "temperature" in opts:
|
||||
generation_config["temperature"] = opts["temperature"]
|
||||
aspect_ratio = _select_aspect_ratio(
|
||||
request.width,
|
||||
request.height,
|
||||
request.model.capabilities.allowed_aspect_ratios,
|
||||
)
|
||||
uses_image_config = request.model.capabilities.resolution_presets is not None
|
||||
if uses_image_config:
|
||||
image_config: dict[str, str] = {}
|
||||
if aspect_ratio is not None:
|
||||
image_config["aspectRatio"] = aspect_ratio
|
||||
if request.image_size is not None:
|
||||
image_config["imageSize"] = request.image_size
|
||||
if image_config:
|
||||
generation_config["imageConfig"] = image_config
|
||||
system_instruction = self._SYSTEM_INSTRUCTION
|
||||
if request.init_image is not None:
|
||||
system_instruction = (
|
||||
f"{system_instruction} An input image is provided. "
|
||||
"Treat the prompt as an edit instruction and modify the image accordingly. "
|
||||
"Do not return the original image unchanged."
|
||||
)
|
||||
if not uses_image_config and aspect_ratio is not None:
|
||||
system_instruction = f"{system_instruction} Use an aspect ratio of {aspect_ratio}."
|
||||
|
||||
payload: dict[str, object] = {
|
||||
"systemInstruction": {"parts": [{"text": system_instruction}]},
|
||||
"contents": [{"role": "user", "parts": request_parts}],
|
||||
"generationConfig": generation_config,
|
||||
}
|
||||
if "thinking_level" in opts:
|
||||
payload["thinkingConfig"] = {"thinkingLevel": opts["thinking_level"].upper()}
|
||||
|
||||
self._dump_debug_payload("request", payload)
|
||||
|
||||
response = requests.post(
|
||||
endpoint,
|
||||
params={"key": api_key},
|
||||
json=payload,
|
||||
timeout=120,
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
if response.status_code == 429:
|
||||
retry_after = _parse_retry_after(response.headers.get("retry-after"))
|
||||
raise ExternalProviderRateLimitError(
|
||||
f"Gemini rate limit exceeded. {f'Retry after {retry_after:.0f}s.' if retry_after else 'Please try again later.'}",
|
||||
retry_after=retry_after,
|
||||
)
|
||||
raise ExternalProviderRequestError(
|
||||
f"Gemini request failed with status {response.status_code} for model '{model_id}': {response.text}"
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
self._dump_debug_payload("response", data)
|
||||
if not isinstance(data, dict):
|
||||
raise ExternalProviderRequestError("Gemini response payload was not a JSON object")
|
||||
images: list[ExternalGeneratedImage] = []
|
||||
text_parts: list[str] = []
|
||||
finish_messages: list[str] = []
|
||||
candidates = data.get("candidates")
|
||||
if not isinstance(candidates, list):
|
||||
raise ExternalProviderRequestError("Gemini response payload missing candidates")
|
||||
for candidate in candidates:
|
||||
if not isinstance(candidate, dict):
|
||||
continue
|
||||
finish_message = candidate.get("finishMessage")
|
||||
finish_reason = candidate.get("finishReason")
|
||||
if isinstance(finish_message, str):
|
||||
finish_messages.append(finish_message)
|
||||
elif isinstance(finish_reason, str):
|
||||
finish_messages.append(f"Finish reason: {finish_reason}")
|
||||
for part in _iter_response_parts(candidate):
|
||||
inline_data = part.get("inline_data") or part.get("inlineData")
|
||||
if isinstance(inline_data, dict):
|
||||
encoded = inline_data.get("data")
|
||||
if encoded:
|
||||
image = decode_image_base64(encoded)
|
||||
images.append(ExternalGeneratedImage(image=image, seed=request.seed))
|
||||
self._dump_debug_image(image)
|
||||
continue
|
||||
file_data = part.get("fileData") or part.get("file_data")
|
||||
if isinstance(file_data, dict):
|
||||
file_uri = file_data.get("fileUri") or file_data.get("file_uri")
|
||||
if isinstance(file_uri, str) and file_uri:
|
||||
raise ExternalProviderRequestError(
|
||||
f"Gemini returned fileUri instead of inline image data: {file_uri}"
|
||||
)
|
||||
text = part.get("text")
|
||||
if isinstance(text, str):
|
||||
text_parts.append(text)
|
||||
|
||||
if not images:
|
||||
self._logger.error("Gemini response contained no images: %s", data)
|
||||
detail = ""
|
||||
if finish_messages:
|
||||
combined = " ".join(message.strip() for message in finish_messages if message.strip())
|
||||
if combined:
|
||||
detail = f" Response status: {combined[:500]}"
|
||||
elif text_parts:
|
||||
combined = " ".join(text_parts).strip()
|
||||
if combined:
|
||||
detail = f" Response text: {combined[:500]}"
|
||||
raise ExternalProviderRequestError(f"Gemini response contained no images.{detail}")
|
||||
|
||||
return ExternalGenerationResult(
|
||||
images=images,
|
||||
seed_used=request.seed,
|
||||
provider_metadata={"model": request.model.provider_model_id},
|
||||
)
|
||||
|
||||
def _dump_debug_payload(self, label: str, payload: object) -> None:
|
||||
"""TODO: remove debug payload dump once Gemini is stable."""
|
||||
try:
|
||||
outputs_path = self._app_config.outputs_path
|
||||
if outputs_path is None:
|
||||
return
|
||||
debug_dir = outputs_path / "external_debug" / "gemini"
|
||||
debug_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = debug_dir / f"{label}_{uuid.uuid4().hex}.json"
|
||||
path.write_text(json.dumps(payload, indent=2, default=str), encoding="utf-8")
|
||||
except Exception as exc:
|
||||
self._logger.debug("Failed to write Gemini debug payload: %s", exc)
|
||||
|
||||
def _dump_debug_image(self, image: "PILImageType") -> None:
|
||||
"""TODO: remove debug image dump once Gemini is stable."""
|
||||
try:
|
||||
outputs_path = self._app_config.outputs_path
|
||||
if outputs_path is None:
|
||||
return
|
||||
debug_dir = outputs_path / "external_debug" / "gemini"
|
||||
debug_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = debug_dir / f"decoded_{uuid.uuid4().hex}.png"
|
||||
image.save(path, format="PNG")
|
||||
except Exception as exc:
|
||||
self._logger.debug("Failed to write Gemini debug image: %s", exc)
|
||||
|
||||
|
||||
def _iter_response_parts(candidate: dict[str, object]) -> list[dict[str, object]]:
|
||||
content = candidate.get("content")
|
||||
if isinstance(content, dict):
|
||||
content_parts = content.get("parts")
|
||||
if isinstance(content_parts, list):
|
||||
return [part for part in content_parts if isinstance(part, dict)]
|
||||
contents = candidate.get("contents")
|
||||
if isinstance(contents, list):
|
||||
parts: list[dict[str, object]] = []
|
||||
for item in contents:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
item_parts = item.get("parts")
|
||||
if isinstance(item_parts, list):
|
||||
parts.extend([part for part in item_parts if isinstance(part, dict)])
|
||||
if parts:
|
||||
return parts
|
||||
return []
|
||||
|
||||
|
||||
def _select_aspect_ratio(width: int, height: int, allowed: list[str] | None) -> str | None:
|
||||
if width <= 0 or height <= 0:
|
||||
return None
|
||||
ratio = width / height
|
||||
default_ratio = _format_aspect_ratio(width, height)
|
||||
if not allowed:
|
||||
return default_ratio
|
||||
parsed = [(value, _parse_ratio(value)) for value in allowed]
|
||||
filtered = [(value, parsed_ratio) for value, parsed_ratio in parsed if parsed_ratio is not None]
|
||||
if not filtered:
|
||||
return default_ratio
|
||||
return min(filtered, key=lambda item: abs(item[1] - ratio))[0]
|
||||
|
||||
|
||||
def _format_aspect_ratio(width: int, height: int) -> str | None:
|
||||
if width <= 0 or height <= 0:
|
||||
return None
|
||||
divisor = _gcd(width, height)
|
||||
return f"{width // divisor}:{height // divisor}"
|
||||
|
||||
|
||||
def _parse_ratio(value: str) -> float | None:
|
||||
if ":" not in value:
|
||||
return None
|
||||
left, right = value.split(":", 1)
|
||||
try:
|
||||
numerator = float(left)
|
||||
denominator = float(right)
|
||||
except ValueError:
|
||||
return None
|
||||
if denominator == 0:
|
||||
return None
|
||||
return numerator / denominator
|
||||
|
||||
|
||||
def _parse_retry_after(value: str | None) -> float | None:
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def _gcd(a: int, b: int) -> int:
|
||||
while b:
|
||||
a, b = b, a % b
|
||||
return a
|
||||
159
invokeai/app/services/external_generation/providers/openai.py
Normal file
159
invokeai/app/services/external_generation/providers/openai.py
Normal file
@@ -0,0 +1,159 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
|
||||
import requests
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.services.external_generation.errors import (
|
||||
ExternalProviderRateLimitError,
|
||||
ExternalProviderRequestError,
|
||||
)
|
||||
from invokeai.app.services.external_generation.external_generation_base import ExternalProvider
|
||||
from invokeai.app.services.external_generation.external_generation_common import (
|
||||
ExternalGeneratedImage,
|
||||
ExternalGenerationRequest,
|
||||
ExternalGenerationResult,
|
||||
)
|
||||
from invokeai.app.services.external_generation.image_utils import decode_image_base64
|
||||
|
||||
|
||||
class OpenAIProvider(ExternalProvider):
|
||||
provider_id = "openai"
|
||||
|
||||
_GPT_IMAGE_MODELS = {"gpt-image-1", "gpt-image-1.5", "gpt-image-1-mini"}
|
||||
|
||||
def is_configured(self) -> bool:
|
||||
return bool(self._app_config.external_openai_api_key)
|
||||
|
||||
def generate(self, request: ExternalGenerationRequest) -> ExternalGenerationResult:
|
||||
api_key = self._app_config.external_openai_api_key
|
||||
if not api_key:
|
||||
raise ExternalProviderRequestError("OpenAI API key is not configured")
|
||||
|
||||
model_id = request.model.provider_model_id
|
||||
is_gpt_image = model_id in self._GPT_IMAGE_MODELS
|
||||
size = f"{request.width}x{request.height}"
|
||||
base_url = (self._app_config.external_openai_base_url or "https://api.openai.com").rstrip("/")
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
|
||||
use_edits_endpoint = request.mode != "txt2img" or bool(request.reference_images)
|
||||
|
||||
opts = request.provider_options or {}
|
||||
|
||||
if not use_edits_endpoint:
|
||||
payload: dict[str, object] = {
|
||||
"model": model_id,
|
||||
"prompt": request.prompt,
|
||||
"n": request.num_images,
|
||||
"size": size,
|
||||
}
|
||||
# GPT Image models use output_format; DALL-E uses response_format
|
||||
if is_gpt_image:
|
||||
payload["output_format"] = "png"
|
||||
else:
|
||||
payload["response_format"] = "b64_json"
|
||||
if is_gpt_image:
|
||||
if opts.get("quality") and opts["quality"] != "auto":
|
||||
payload["quality"] = opts["quality"]
|
||||
if opts.get("background") and opts["background"] != "auto":
|
||||
payload["background"] = opts["background"]
|
||||
response = requests.post(
|
||||
f"{base_url}/v1/images/generations",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=120,
|
||||
)
|
||||
else:
|
||||
images: list[PILImageType] = []
|
||||
if request.init_image is not None:
|
||||
images.append(request.init_image)
|
||||
images.extend(reference.image for reference in request.reference_images)
|
||||
if not images:
|
||||
raise ExternalProviderRequestError(
|
||||
"OpenAI image edits require at least one image (init image or reference image)"
|
||||
)
|
||||
|
||||
files: list[tuple[str, tuple[str, io.BytesIO, str]]] = []
|
||||
image_field_name = "image" if len(images) == 1 else "image[]"
|
||||
for index, image in enumerate(images):
|
||||
image_buffer = io.BytesIO()
|
||||
image.save(image_buffer, format="PNG")
|
||||
image_buffer.seek(0)
|
||||
files.append((image_field_name, (f"image_{index}.png", image_buffer, "image/png")))
|
||||
|
||||
if request.mask_image is not None:
|
||||
mask_buffer = io.BytesIO()
|
||||
request.mask_image.save(mask_buffer, format="PNG")
|
||||
mask_buffer.seek(0)
|
||||
files.append(("mask", ("mask.png", mask_buffer, "image/png")))
|
||||
|
||||
data: dict[str, object] = {
|
||||
"model": model_id,
|
||||
"prompt": request.prompt,
|
||||
"n": request.num_images,
|
||||
"size": size,
|
||||
}
|
||||
if is_gpt_image:
|
||||
data["output_format"] = "png"
|
||||
else:
|
||||
data["response_format"] = "b64_json"
|
||||
if is_gpt_image:
|
||||
if opts.get("quality") and opts["quality"] != "auto":
|
||||
data["quality"] = opts["quality"]
|
||||
if opts.get("background") and opts["background"] != "auto":
|
||||
data["background"] = opts["background"]
|
||||
if opts.get("input_fidelity"):
|
||||
data["input_fidelity"] = opts["input_fidelity"]
|
||||
response = requests.post(
|
||||
f"{base_url}/v1/images/edits",
|
||||
headers=headers,
|
||||
data=data,
|
||||
files=files,
|
||||
timeout=120,
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
if response.status_code == 429:
|
||||
retry_after = _parse_retry_after(response.headers.get("retry-after"))
|
||||
raise ExternalProviderRateLimitError(
|
||||
f"OpenAI rate limit exceeded. {f'Retry after {retry_after:.0f}s.' if retry_after else 'Please try again later.'}",
|
||||
retry_after=retry_after,
|
||||
)
|
||||
raise ExternalProviderRequestError(
|
||||
f"OpenAI request failed with status {response.status_code}: {response.text}"
|
||||
)
|
||||
|
||||
response_payload = response.json()
|
||||
if not isinstance(response_payload, dict):
|
||||
raise ExternalProviderRequestError("OpenAI response payload was not a JSON object")
|
||||
images: list[ExternalGeneratedImage] = []
|
||||
data_items = response_payload.get("data")
|
||||
if not isinstance(data_items, list):
|
||||
raise ExternalProviderRequestError("OpenAI response payload missing image data")
|
||||
for item in data_items:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
encoded = item.get("b64_json")
|
||||
if not encoded:
|
||||
continue
|
||||
images.append(ExternalGeneratedImage(image=decode_image_base64(encoded), seed=request.seed))
|
||||
|
||||
if not images:
|
||||
raise ExternalProviderRequestError("OpenAI response contained no images")
|
||||
|
||||
return ExternalGenerationResult(
|
||||
images=images,
|
||||
seed_used=request.seed,
|
||||
provider_request_id=response.headers.get("x-request-id"),
|
||||
provider_metadata={"model": model_id},
|
||||
)
|
||||
|
||||
|
||||
def _parse_retry_after(value: str | None) -> float | None:
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
return None
|
||||
59
invokeai/app/services/external_generation/startup.py
Normal file
59
invokeai/app/services/external_generation/startup.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from logging import Logger
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||
from invokeai.backend.model_manager.configs.external_api import ExternalApiModelConfig
|
||||
from invokeai.backend.model_manager.starter_models import STARTER_MODELS
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
|
||||
|
||||
|
||||
def sync_configured_external_starter_models(
|
||||
configured_provider_ids: set[str],
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
logger: Logger,
|
||||
) -> list[str]:
|
||||
"""Queue missing external starter models for configured providers."""
|
||||
|
||||
if not configured_provider_ids:
|
||||
return []
|
||||
|
||||
installed_sources = {
|
||||
model.source
|
||||
for model in model_manager.store.search_by_attr(
|
||||
base_model=BaseModelType.External,
|
||||
model_type=ModelType.ExternalImageGenerator,
|
||||
)
|
||||
if isinstance(model, ExternalApiModelConfig) and model.source
|
||||
}
|
||||
|
||||
queued_sources: list[str] = []
|
||||
for starter_model in STARTER_MODELS:
|
||||
if not starter_model.source.startswith("external://"):
|
||||
continue
|
||||
|
||||
provider_id = starter_model.source.removeprefix("external://").split("/", 1)[0]
|
||||
if provider_id not in configured_provider_ids:
|
||||
continue
|
||||
|
||||
if starter_model.source in installed_sources:
|
||||
continue
|
||||
|
||||
model_manager.install.heuristic_import(
|
||||
starter_model.source,
|
||||
config=ModelRecordChanges(
|
||||
name=starter_model.name,
|
||||
base=starter_model.base,
|
||||
type=starter_model.type,
|
||||
description=starter_model.description,
|
||||
format=starter_model.format,
|
||||
capabilities=starter_model.capabilities,
|
||||
default_settings=starter_model.default_settings,
|
||||
),
|
||||
)
|
||||
queued_sources.append(starter_model.source)
|
||||
logger.info("Queued external starter model sync for %s", starter_model.source)
|
||||
|
||||
return queued_sources
|
||||
@@ -21,6 +21,7 @@ if TYPE_CHECKING:
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadQueueServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.external_generation.external_generation_base import ExternalGenerationServiceBase
|
||||
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
|
||||
from invokeai.app.services.image_records.image_records_base import ImageRecordStorageBase
|
||||
from invokeai.app.services.images.images_base import ImageServiceABC
|
||||
@@ -63,6 +64,7 @@ class InvocationServices:
|
||||
model_relationships: "ModelRelationshipsServiceABC",
|
||||
model_relationship_records: "ModelRelationshipRecordStorageBase",
|
||||
download_queue: "DownloadQueueServiceBase",
|
||||
external_generation: "ExternalGenerationServiceBase",
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
session_queue: "SessionQueueBase",
|
||||
session_processor: "SessionProcessorBase",
|
||||
@@ -94,6 +96,7 @@ class InvocationServices:
|
||||
self.model_relationships = model_relationships
|
||||
self.model_relationship_records = model_relationship_records
|
||||
self.download_queue = download_queue
|
||||
self.external_generation = external_generation
|
||||
self.performance_statistics = performance_statistics
|
||||
self.session_queue = session_queue
|
||||
self.session_processor = session_processor
|
||||
|
||||
@@ -139,12 +139,27 @@ class URLModelSource(StringLikeSource):
|
||||
return str(self.url)
|
||||
|
||||
|
||||
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
|
||||
class ExternalModelSource(StringLikeSource):
|
||||
"""An external provider model identifier."""
|
||||
|
||||
provider_id: str
|
||||
provider_model_id: str
|
||||
type: Literal["external"] = "external"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"external://{self.provider_id}/{self.provider_model_id}"
|
||||
|
||||
|
||||
ModelSource = Annotated[
|
||||
Union[LocalModelSource, HFModelSource, URLModelSource, ExternalModelSource],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
MODEL_SOURCE_TO_TYPE_MAP = {
|
||||
URLModelSource: ModelSourceType.Url,
|
||||
HFModelSource: ModelSourceType.HFRepoID,
|
||||
LocalModelSource: ModelSourceType.Path,
|
||||
ExternalModelSource: ModelSourceType.External,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
|
||||
from invokeai.app.services.model_install.model_install_common import (
|
||||
MODEL_SOURCE_TO_TYPE_MAP,
|
||||
ExternalModelSource,
|
||||
HFModelSource,
|
||||
InstallStatus,
|
||||
InvalidModelConfigException,
|
||||
@@ -37,10 +38,15 @@ from invokeai.app.services.model_install.model_install_common import (
|
||||
StringLikeSource,
|
||||
URLModelSource,
|
||||
)
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, UnknownModelException
|
||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base
|
||||
from invokeai.backend.model_manager.configs.external_api import (
|
||||
ExternalApiModelConfig,
|
||||
ExternalApiModelDefaultSettings,
|
||||
ExternalModelCapabilities,
|
||||
)
|
||||
from invokeai.backend.model_manager.configs.factory import (
|
||||
AnyModelConfig,
|
||||
ModelConfigFactory,
|
||||
@@ -55,7 +61,13 @@ from invokeai.backend.model_manager.metadata import (
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant, ModelSourceType
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelSourceType,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.util.lora_metadata_extractor import apply_lora_metadata
|
||||
from invokeai.backend.util import InvokeAILogger
|
||||
from invokeai.backend.util.catch_sigint import catch_sigint
|
||||
@@ -459,6 +471,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
install_job = self._import_from_hf(source, config)
|
||||
elif isinstance(source, URLModelSource):
|
||||
install_job = self._import_from_url(source, config)
|
||||
elif isinstance(source, ExternalModelSource):
|
||||
install_job = self._import_external_model(source, config)
|
||||
self._put_in_queue(install_job)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model source: '{type(source)}'")
|
||||
|
||||
@@ -758,7 +773,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source_obj: Optional[StringLikeSource] = None
|
||||
source_stripped = source.strip('"')
|
||||
|
||||
if Path(source_stripped).exists(): # A local file or directory
|
||||
if source_stripped.startswith("external://"):
|
||||
external_id = source_stripped.removeprefix("external://")
|
||||
provider_id, _, provider_model_id = external_id.partition("/")
|
||||
if not provider_id or not provider_model_id:
|
||||
raise ValueError(f"Invalid external model source: '{source_stripped}'")
|
||||
source_obj = ExternalModelSource(provider_id=provider_id, provider_model_id=provider_model_id)
|
||||
elif Path(source_stripped).exists(): # A local file or directory
|
||||
source_obj = LocalModelSource(path=Path(source_stripped))
|
||||
elif match := re.match(hf_repoid_re, source):
|
||||
source_obj = HFModelSource(
|
||||
@@ -850,6 +871,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._logger.info(f"Installer thread {threading.get_ident()} exiting")
|
||||
|
||||
def _register_or_install(self, job: ModelInstallJob) -> None:
|
||||
if isinstance(job.source, ExternalModelSource):
|
||||
self._register_external_model(job)
|
||||
return
|
||||
# local jobs will be in waiting state, remote jobs will be downloading state
|
||||
job.total_bytes = self._stat_size(job.local_path)
|
||||
job.bytes = job.total_bytes
|
||||
@@ -870,6 +894,71 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
job.config_out = self.record_store.get_model(key)
|
||||
self._signal_job_completed(job)
|
||||
|
||||
def _register_external_model(self, job: ModelInstallJob) -> None:
|
||||
job.total_bytes = 0
|
||||
job.bytes = 0
|
||||
self._signal_job_running(job)
|
||||
job.config_in.source = str(job.source)
|
||||
job.config_in.source_type = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
|
||||
|
||||
provider_id = job.source.provider_id
|
||||
provider_model_id = job.source.provider_model_id
|
||||
capabilities = job.config_in.capabilities or ExternalModelCapabilities()
|
||||
default_settings = (
|
||||
job.config_in.default_settings
|
||||
if isinstance(job.config_in.default_settings, ExternalApiModelDefaultSettings)
|
||||
else None
|
||||
)
|
||||
name = job.config_in.name or f"{provider_id} {provider_model_id}"
|
||||
key = job.config_in.key or slugify(f"{provider_id}-{provider_model_id}")
|
||||
|
||||
existing_external = next(
|
||||
(
|
||||
model
|
||||
for model in self.record_store.search_by_attr(
|
||||
base_model=BaseModelType.External, model_type=ModelType.ExternalImageGenerator
|
||||
)
|
||||
if isinstance(model, ExternalApiModelConfig)
|
||||
and model.provider_id == provider_id
|
||||
and model.provider_model_id == provider_model_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if existing_external is not None:
|
||||
key = existing_external.key
|
||||
else:
|
||||
try:
|
||||
self.record_store.get_model(key)
|
||||
raise DuplicateModelException(
|
||||
f"Model key '{key}' already exists. Provide a different key to install this external model."
|
||||
)
|
||||
except UnknownModelException:
|
||||
pass
|
||||
|
||||
config = ExternalApiModelConfig(
|
||||
key=key,
|
||||
name=name,
|
||||
description=job.config_in.description,
|
||||
provider_id=provider_id,
|
||||
provider_model_id=provider_model_id,
|
||||
capabilities=capabilities,
|
||||
default_settings=default_settings,
|
||||
source=str(job.source),
|
||||
source_type=MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__],
|
||||
path="",
|
||||
hash="",
|
||||
file_size=0,
|
||||
)
|
||||
|
||||
if existing_external is not None:
|
||||
self.record_store.replace_model(existing_external.key, config)
|
||||
else:
|
||||
self.record_store.add_model(config)
|
||||
|
||||
job.config_out = self.record_store.get_model(config.key)
|
||||
self._signal_job_completed(job)
|
||||
|
||||
def _set_error(self, install_job: ModelInstallJob, excp: Exception) -> None:
|
||||
multifile_download_job = install_job._multifile_job
|
||||
if multifile_download_job and any(
|
||||
@@ -905,6 +994,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
"""Scan the models directory for missing models and return a list of them."""
|
||||
missing_models: list[AnyModelConfig] = []
|
||||
for model_config in self.record_store.all_models():
|
||||
if model_config.base == BaseModelType.External or model_config.format == ModelFormat.ExternalApi:
|
||||
continue
|
||||
if not (self.app_config.models_path / model_config.path).resolve().exists():
|
||||
missing_models.append(model_config)
|
||||
return missing_models
|
||||
@@ -1046,6 +1137,19 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
remote_files=remote_files,
|
||||
)
|
||||
|
||||
def _import_external_model(
|
||||
self,
|
||||
source: ExternalModelSource,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> ModelInstallJob:
|
||||
return ModelInstallJob(
|
||||
id=self._next_id(),
|
||||
source=source,
|
||||
config_in=config or ModelRecordChanges(),
|
||||
local_path=self._app_config.models_path,
|
||||
inplace=True,
|
||||
)
|
||||
|
||||
def _import_remote_model(
|
||||
self,
|
||||
source: HFModelSource | URLModelSource,
|
||||
|
||||
@@ -13,6 +13,10 @@ from pydantic import BaseModel, Field
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings
|
||||
from invokeai.backend.model_manager.configs.external_api import (
|
||||
ExternalApiModelDefaultSettings,
|
||||
ExternalModelCapabilities,
|
||||
)
|
||||
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
|
||||
from invokeai.backend.model_manager.configs.lora import LoraModelDefaultSettings
|
||||
from invokeai.backend.model_manager.configs.main import MainModelDefaultSettings
|
||||
@@ -87,8 +91,19 @@ class ModelRecordChanges(BaseModelExcludeNull):
|
||||
file_size: Optional[int] = Field(description="Size of model file", default=None)
|
||||
format: Optional[str] = Field(description="format of model file", default=None)
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
default_settings: Optional[MainModelDefaultSettings | LoraModelDefaultSettings | ControlAdapterDefaultSettings] = (
|
||||
Field(description="Default settings for this model", default=None)
|
||||
default_settings: Optional[
|
||||
MainModelDefaultSettings
|
||||
| LoraModelDefaultSettings
|
||||
| ControlAdapterDefaultSettings
|
||||
| ExternalApiModelDefaultSettings
|
||||
] = Field(description="Default settings for this model", default=None)
|
||||
|
||||
# External API model changes
|
||||
provider_id: Optional[str] = Field(description="External provider identifier", default=None)
|
||||
provider_model_id: Optional[str] = Field(description="External provider model identifier", default=None)
|
||||
capabilities: Optional[ExternalModelCapabilities] = Field(
|
||||
description="External model capabilities",
|
||||
default=None,
|
||||
)
|
||||
cpu_only: Optional[bool] = Field(description="Whether this model should run on CPU only", default=None)
|
||||
|
||||
|
||||
@@ -388,6 +388,8 @@ class ModelsInterface(InvocationContextInterface):
|
||||
submodel_type = submodel_type or identifier.submodel_type
|
||||
model = self._services.model_manager.store.get_model(identifier.key)
|
||||
|
||||
self._raise_if_external(model)
|
||||
|
||||
message = f"Loading model {model.name}"
|
||||
if submodel_type:
|
||||
message += f" ({submodel_type.value})"
|
||||
@@ -417,12 +419,18 @@ class ModelsInterface(InvocationContextInterface):
|
||||
if len(configs) > 1:
|
||||
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
|
||||
|
||||
self._raise_if_external(configs[0])
|
||||
message = f"Loading model {name}"
|
||||
if submodel_type:
|
||||
message += f" ({submodel_type.value})"
|
||||
self._util.signal_progress(message)
|
||||
return self._services.model_manager.load.load_model(configs[0], submodel_type)
|
||||
|
||||
@staticmethod
|
||||
def _raise_if_external(model: AnyModelConfig) -> None:
|
||||
if model.base == BaseModelType.External or model.format == ModelFormat.ExternalApi:
|
||||
raise ValueError("External API models cannot be loaded from disk")
|
||||
|
||||
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
|
||||
"""Get a model's config.
|
||||
|
||||
|
||||
@@ -14,43 +14,61 @@ def _get_processor_invocation_class(processor_type: str):
|
||||
"""Get the invocation class for a processor type."""
|
||||
# Import processor invocation classes on demand
|
||||
processor_class_map = {
|
||||
"canny_image_processor": lambda: __import__(
|
||||
"invokeai.app.invocations.canny", fromlist=["CannyEdgeDetectionInvocation"]
|
||||
).CannyEdgeDetectionInvocation,
|
||||
"hed_image_processor": lambda: __import__(
|
||||
"invokeai.app.invocations.hed", fromlist=["HEDEdgeDetectionInvocation"]
|
||||
).HEDEdgeDetectionInvocation,
|
||||
"mlsd_image_processor": lambda: __import__(
|
||||
"invokeai.app.invocations.mlsd", fromlist=["MLSDDetectionInvocation"]
|
||||
).MLSDDetectionInvocation,
|
||||
"depth_anything_image_processor": lambda: __import__(
|
||||
"invokeai.app.invocations.depth_anything", fromlist=["DepthAnythingDepthEstimationInvocation"]
|
||||
).DepthAnythingDepthEstimationInvocation,
|
||||
"normalbae_image_processor": lambda: __import__(
|
||||
"invokeai.app.invocations.normal_bae", fromlist=["NormalMapInvocation"]
|
||||
).NormalMapInvocation,
|
||||
"pidi_image_processor": lambda: __import__(
|
||||
"invokeai.app.invocations.pidi", fromlist=["PiDiNetEdgeDetectionInvocation"]
|
||||
).PiDiNetEdgeDetectionInvocation,
|
||||
"lineart_image_processor": lambda: __import__(
|
||||
"invokeai.app.invocations.lineart", fromlist=["LineartEdgeDetectionInvocation"]
|
||||
).LineartEdgeDetectionInvocation,
|
||||
"lineart_anime_image_processor": lambda: __import__(
|
||||
"invokeai.app.invocations.lineart_anime", fromlist=["LineartAnimeEdgeDetectionInvocation"]
|
||||
).LineartAnimeEdgeDetectionInvocation,
|
||||
"content_shuffle_image_processor": lambda: __import__(
|
||||
"invokeai.app.invocations.content_shuffle", fromlist=["ContentShuffleInvocation"]
|
||||
).ContentShuffleInvocation,
|
||||
"dw_openpose_image_processor": lambda: __import__(
|
||||
"invokeai.app.invocations.dw_openpose", fromlist=["DWOpenposeDetectionInvocation"]
|
||||
).DWOpenposeDetectionInvocation,
|
||||
"mediapipe_face_processor": lambda: __import__(
|
||||
"invokeai.app.invocations.mediapipe_face", fromlist=["MediaPipeFaceDetectionInvocation"]
|
||||
).MediaPipeFaceDetectionInvocation,
|
||||
"canny_image_processor": lambda: (
|
||||
__import__(
|
||||
"invokeai.app.invocations.canny", fromlist=["CannyEdgeDetectionInvocation"]
|
||||
).CannyEdgeDetectionInvocation
|
||||
),
|
||||
"hed_image_processor": lambda: (
|
||||
__import__(
|
||||
"invokeai.app.invocations.hed", fromlist=["HEDEdgeDetectionInvocation"]
|
||||
).HEDEdgeDetectionInvocation
|
||||
),
|
||||
"mlsd_image_processor": lambda: (
|
||||
__import__("invokeai.app.invocations.mlsd", fromlist=["MLSDDetectionInvocation"]).MLSDDetectionInvocation
|
||||
),
|
||||
"depth_anything_image_processor": lambda: (
|
||||
__import__(
|
||||
"invokeai.app.invocations.depth_anything", fromlist=["DepthAnythingDepthEstimationInvocation"]
|
||||
).DepthAnythingDepthEstimationInvocation
|
||||
),
|
||||
"normalbae_image_processor": lambda: (
|
||||
__import__("invokeai.app.invocations.normal_bae", fromlist=["NormalMapInvocation"]).NormalMapInvocation
|
||||
),
|
||||
"pidi_image_processor": lambda: (
|
||||
__import__(
|
||||
"invokeai.app.invocations.pidi", fromlist=["PiDiNetEdgeDetectionInvocation"]
|
||||
).PiDiNetEdgeDetectionInvocation
|
||||
),
|
||||
"lineart_image_processor": lambda: (
|
||||
__import__(
|
||||
"invokeai.app.invocations.lineart", fromlist=["LineartEdgeDetectionInvocation"]
|
||||
).LineartEdgeDetectionInvocation
|
||||
),
|
||||
"lineart_anime_image_processor": lambda: (
|
||||
__import__(
|
||||
"invokeai.app.invocations.lineart_anime", fromlist=["LineartAnimeEdgeDetectionInvocation"]
|
||||
).LineartAnimeEdgeDetectionInvocation
|
||||
),
|
||||
"content_shuffle_image_processor": lambda: (
|
||||
__import__(
|
||||
"invokeai.app.invocations.content_shuffle", fromlist=["ContentShuffleInvocation"]
|
||||
).ContentShuffleInvocation
|
||||
),
|
||||
"dw_openpose_image_processor": lambda: (
|
||||
__import__(
|
||||
"invokeai.app.invocations.dw_openpose", fromlist=["DWOpenposeDetectionInvocation"]
|
||||
).DWOpenposeDetectionInvocation
|
||||
),
|
||||
"mediapipe_face_processor": lambda: (
|
||||
__import__(
|
||||
"invokeai.app.invocations.mediapipe_face", fromlist=["MediaPipeFaceDetectionInvocation"]
|
||||
).MediaPipeFaceDetectionInvocation
|
||||
),
|
||||
# Note: zoe_depth_image_processor doesn't have a processor invocation implementation
|
||||
"color_map_image_processor": lambda: __import__(
|
||||
"invokeai.app.invocations.color_map", fromlist=["ColorMapInvocation"]
|
||||
).ColorMapInvocation,
|
||||
"color_map_image_processor": lambda: (
|
||||
__import__("invokeai.app.invocations.color_map", fromlist=["ColorMapInvocation"]).ColorMapInvocation
|
||||
),
|
||||
}
|
||||
|
||||
if processor_type in processor_class_map:
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
from typing import Literal, Self
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from invokeai.backend.model_manager.configs.base import Config_Base
|
||||
from invokeai.backend.model_manager.configs.identification_utils import NotAMatchError
|
||||
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelSourceType, ModelType
|
||||
|
||||
ExternalGenerationMode = Literal["txt2img", "img2img", "inpaint"]
|
||||
ExternalMaskFormat = Literal["alpha", "binary", "none"]
|
||||
ExternalPanelControlName = Literal["reference_images", "dimensions", "seed"]
|
||||
|
||||
|
||||
class ExternalImageSize(BaseModel):
|
||||
width: int = Field(gt=0)
|
||||
height: int = Field(gt=0)
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class ExternalResolutionPreset(BaseModel):
|
||||
label: str = Field(min_length=1, description="Display label, e.g. '1:1 (1K)'")
|
||||
aspect_ratio: str = Field(min_length=1, description="Aspect ratio string, e.g. '1:1'")
|
||||
image_size: str = Field(min_length=1, description="Image size preset, e.g. '1K'")
|
||||
width: int = Field(gt=0)
|
||||
height: int = Field(gt=0)
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class ExternalModelCapabilities(BaseModel):
|
||||
modes: list[ExternalGenerationMode] = Field(default_factory=lambda: ["txt2img"])
|
||||
supports_reference_images: bool = Field(default=False)
|
||||
supports_negative_prompt: bool = Field(default=True)
|
||||
supports_seed: bool = Field(default=False)
|
||||
supports_guidance: bool = Field(default=False)
|
||||
supports_steps: bool = Field(default=False)
|
||||
max_images_per_request: int | None = Field(default=None, gt=0)
|
||||
max_image_size: ExternalImageSize | None = Field(default=None)
|
||||
allowed_aspect_ratios: list[str] | None = Field(default=None)
|
||||
aspect_ratio_sizes: dict[str, ExternalImageSize] | None = Field(default=None)
|
||||
resolution_presets: list[ExternalResolutionPreset] | None = Field(default=None)
|
||||
max_reference_images: int | None = Field(default=None, gt=0)
|
||||
mask_format: ExternalMaskFormat = Field(default="none")
|
||||
input_image_required_for: list[ExternalGenerationMode] | None = Field(default=None)
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class ExternalApiModelDefaultSettings(BaseModel):
|
||||
width: int | None = Field(default=None, gt=0)
|
||||
height: int | None = Field(default=None, gt=0)
|
||||
num_images: int | None = Field(default=None, gt=0)
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class ExternalModelPanelControl(BaseModel):
|
||||
name: ExternalPanelControlName
|
||||
slider_min: float | None = Field(default=None)
|
||||
slider_max: float | None = Field(default=None)
|
||||
number_input_min: float | None = Field(default=None)
|
||||
number_input_max: float | None = Field(default=None)
|
||||
fine_step: float | None = Field(default=None)
|
||||
coarse_step: float | None = Field(default=None)
|
||||
marks: list[float] | None = Field(default=None)
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class ExternalModelPanelSchema(BaseModel):
|
||||
prompts: list[ExternalModelPanelControl] = Field(default_factory=list)
|
||||
image: list[ExternalModelPanelControl] = Field(default_factory=list)
|
||||
generation: list[ExternalModelPanelControl] = Field(default_factory=list)
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class ExternalApiModelConfig(Config_Base):
|
||||
base: Literal[BaseModelType.External] = Field(default=BaseModelType.External)
|
||||
type: Literal[ModelType.ExternalImageGenerator] = Field(default=ModelType.ExternalImageGenerator)
|
||||
format: Literal[ModelFormat.ExternalApi] = Field(default=ModelFormat.ExternalApi)
|
||||
|
||||
provider_id: str = Field(min_length=1, description="External provider ID")
|
||||
provider_model_id: str = Field(min_length=1, description="Provider-specific model ID")
|
||||
capabilities: ExternalModelCapabilities = Field(description="Provider capability matrix")
|
||||
default_settings: ExternalApiModelDefaultSettings | None = Field(default=None)
|
||||
panel_schema: ExternalModelPanelSchema | None = Field(default=None)
|
||||
tags: list[str] | None = Field(default=None)
|
||||
is_default: bool = Field(default=False)
|
||||
|
||||
source_type: ModelSourceType = Field(default=ModelSourceType.External)
|
||||
path: str = Field(default="")
|
||||
source: str = Field(default="")
|
||||
hash: str = Field(default="")
|
||||
file_size: int = Field(default=0, ge=0)
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _populate_external_fields(self) -> "ExternalApiModelConfig":
|
||||
if not self.path:
|
||||
self.path = f"external://{self.provider_id}/{self.provider_model_id}"
|
||||
if not self.source:
|
||||
self.source = self.path
|
||||
if not self.hash:
|
||||
self.hash = f"external:{self.provider_id}:{self.provider_model_id}"
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, object]) -> Self:
|
||||
raise NotAMatchError("external API models are not probed from disk")
|
||||
|
||||
@@ -26,6 +26,7 @@ from invokeai.backend.model_manager.configs.controlnet import (
|
||||
ControlNet_Diffusers_SD2_Config,
|
||||
ControlNet_Diffusers_SDXL_Config,
|
||||
)
|
||||
from invokeai.backend.model_manager.configs.external_api import ExternalApiModelConfig
|
||||
from invokeai.backend.model_manager.configs.flux_redux import FLUXRedux_Checkpoint_Config
|
||||
from invokeai.backend.model_manager.configs.identification_utils import NotAMatchError
|
||||
from invokeai.backend.model_manager.configs.ip_adapter import (
|
||||
@@ -268,6 +269,7 @@ AnyModelConfig = Annotated[
|
||||
Annotated[SigLIP_Diffusers_Config, SigLIP_Diffusers_Config.get_tag()],
|
||||
Annotated[FLUXRedux_Checkpoint_Config, FLUXRedux_Checkpoint_Config.get_tag()],
|
||||
Annotated[LlavaOnevision_Diffusers_Config, LlavaOnevision_Diffusers_Config.get_tag()],
|
||||
Annotated[ExternalApiModelConfig, ExternalApiModelConfig.get_tag()],
|
||||
# Unknown model (fallback)
|
||||
Annotated[Unknown_Config, Unknown_Config.get_tag()],
|
||||
],
|
||||
|
||||
@@ -2,6 +2,13 @@ from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from invokeai.backend.model_manager.configs.external_api import (
|
||||
ExternalApiModelDefaultSettings,
|
||||
ExternalImageSize,
|
||||
ExternalModelCapabilities,
|
||||
ExternalModelPanelSchema,
|
||||
ExternalResolutionPreset,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
AnyVariant,
|
||||
BaseModelType,
|
||||
@@ -20,6 +27,9 @@ class StarterModelWithoutDependencies(BaseModel):
|
||||
format: Optional[ModelFormat] = None
|
||||
variant: Optional[AnyVariant] = None
|
||||
is_installed: bool = False
|
||||
capabilities: ExternalModelCapabilities | None = None
|
||||
default_settings: ExternalApiModelDefaultSettings | None = None
|
||||
panel_schema: ExternalModelPanelSchema | None = None
|
||||
# allows us to track what models a user has installed across name changes within starter models
|
||||
# if you update a starter model name, please add the old one to this list for that starter model
|
||||
previous_names: list[str] = []
|
||||
@@ -1001,6 +1011,367 @@ z_image_controlnet_tile = StarterModel(
|
||||
)
|
||||
# endregion
|
||||
|
||||
# region External API
|
||||
GEMINI_3_IMAGE_ALLOWED_ASPECT_RATIOS = [
|
||||
"1:1",
|
||||
"1:4",
|
||||
"1:8",
|
||||
"2:3",
|
||||
"3:2",
|
||||
"3:4",
|
||||
"4:1",
|
||||
"4:3",
|
||||
"4:5",
|
||||
"5:4",
|
||||
"8:1",
|
||||
"9:16",
|
||||
"16:9",
|
||||
"21:9",
|
||||
]
|
||||
GEMINI_3_IMAGE_MAX_SIZE = ExternalImageSize(width=4096, height=4096)
|
||||
|
||||
|
||||
def _gemini_3_resolution_presets(
|
||||
image_sizes: list[str],
|
||||
aspect_ratios: list[str] | None = None,
|
||||
) -> list[ExternalResolutionPreset]:
|
||||
"""Build resolution presets for Gemini 3 models.
|
||||
|
||||
Each preset combines an aspect ratio with an image size preset (512/1K/2K/4K).
|
||||
Pixel dimensions are approximations based on the preset name (longest side).
|
||||
"""
|
||||
if aspect_ratios is None:
|
||||
aspect_ratios = GEMINI_3_IMAGE_ALLOWED_ASPECT_RATIOS
|
||||
base_pixels = {"512": 512, "1K": 1024, "2K": 2048, "4K": 4096}
|
||||
presets: list[ExternalResolutionPreset] = []
|
||||
for image_size in image_sizes:
|
||||
base = base_pixels[image_size]
|
||||
for ratio_str in aspect_ratios:
|
||||
w_part, h_part = (int(x) for x in ratio_str.split(":"))
|
||||
if w_part >= h_part:
|
||||
w = base
|
||||
h = max(1, round(base * h_part / w_part))
|
||||
else:
|
||||
h = base
|
||||
w = max(1, round(base * w_part / h_part))
|
||||
presets.append(
|
||||
ExternalResolutionPreset(
|
||||
label=f"{ratio_str} ({image_size}) — {w}\u00d7{h}",
|
||||
aspect_ratio=ratio_str,
|
||||
image_size=image_size,
|
||||
width=w,
|
||||
height=h,
|
||||
)
|
||||
)
|
||||
return presets
|
||||
|
||||
|
||||
GEMINI_3_PRO_RESOLUTION_PRESETS = _gemini_3_resolution_presets(["1K", "2K", "4K"])
|
||||
GEMINI_3_1_FLASH_RESOLUTION_PRESETS = _gemini_3_resolution_presets(["512", "1K", "2K", "4K"])
|
||||
|
||||
gemini_flash_image = StarterModel(
|
||||
name="Gemini 2.5 Flash Image",
|
||||
base=BaseModelType.External,
|
||||
source="external://gemini/gemini-2.5-flash-image",
|
||||
description="Google Gemini 2.5 Flash image generation model (external API). Requires a configured Gemini API key and may incur provider usage costs.",
|
||||
type=ModelType.ExternalImageGenerator,
|
||||
format=ModelFormat.ExternalApi,
|
||||
capabilities=ExternalModelCapabilities(
|
||||
modes=["txt2img", "img2img", "inpaint"],
|
||||
supports_seed=True,
|
||||
supports_reference_images=True,
|
||||
max_images_per_request=1,
|
||||
allowed_aspect_ratios=[
|
||||
"1:1",
|
||||
"2:3",
|
||||
"3:2",
|
||||
"3:4",
|
||||
"4:3",
|
||||
"4:5",
|
||||
"5:4",
|
||||
"9:16",
|
||||
"16:9",
|
||||
"21:9",
|
||||
],
|
||||
aspect_ratio_sizes={
|
||||
"1:1": ExternalImageSize(width=1024, height=1024),
|
||||
"2:3": ExternalImageSize(width=832, height=1248),
|
||||
"3:2": ExternalImageSize(width=1248, height=832),
|
||||
"3:4": ExternalImageSize(width=864, height=1184),
|
||||
"4:3": ExternalImageSize(width=1184, height=864),
|
||||
"4:5": ExternalImageSize(width=896, height=1152),
|
||||
"5:4": ExternalImageSize(width=1152, height=896),
|
||||
"9:16": ExternalImageSize(width=768, height=1344),
|
||||
"16:9": ExternalImageSize(width=1344, height=768),
|
||||
"21:9": ExternalImageSize(width=1536, height=672),
|
||||
},
|
||||
),
|
||||
default_settings=ExternalApiModelDefaultSettings(width=1024, height=1024, num_images=1),
|
||||
panel_schema=ExternalModelPanelSchema(prompts=[{"name": "reference_images"}], image=[{"name": "dimensions"}]),
|
||||
)
|
||||
gemini_pro_image_preview = StarterModel(
|
||||
name="Gemini 3 Pro Image Preview",
|
||||
base=BaseModelType.External,
|
||||
source="external://gemini/gemini-3-pro-image-preview",
|
||||
description="Google Gemini 3 Pro image generation preview model (external API). Supports up to 14 reference images, including up to 6 object references and up to 5 character references. Supports 1K/2K/4K resolution presets. Requires a configured Gemini API key and may incur provider usage costs.",
|
||||
type=ModelType.ExternalImageGenerator,
|
||||
format=ModelFormat.ExternalApi,
|
||||
capabilities=ExternalModelCapabilities(
|
||||
modes=["txt2img", "img2img", "inpaint"],
|
||||
supports_seed=True,
|
||||
supports_reference_images=True,
|
||||
max_reference_images=14,
|
||||
max_images_per_request=1,
|
||||
max_image_size=GEMINI_3_IMAGE_MAX_SIZE,
|
||||
allowed_aspect_ratios=GEMINI_3_IMAGE_ALLOWED_ASPECT_RATIOS,
|
||||
resolution_presets=GEMINI_3_PRO_RESOLUTION_PRESETS,
|
||||
),
|
||||
default_settings=ExternalApiModelDefaultSettings(width=1024, height=1024, num_images=1),
|
||||
panel_schema=ExternalModelPanelSchema(prompts=[{"name": "reference_images"}], image=[{"name": "dimensions"}]),
|
||||
)
|
||||
gemini_3_1_flash_image_preview = StarterModel(
|
||||
name="Gemini 3.1 Flash Image Preview",
|
||||
base=BaseModelType.External,
|
||||
source="external://gemini/gemini-3.1-flash-image-preview",
|
||||
description="Google Gemini 3.1 Flash image generation preview model (external API). Supports up to 14 reference images, including up to 10 object references and up to 4 character references. Supports 512/1K/2K/4K resolution presets. Requires a configured Gemini API key and may incur provider usage costs.",
|
||||
type=ModelType.ExternalImageGenerator,
|
||||
format=ModelFormat.ExternalApi,
|
||||
capabilities=ExternalModelCapabilities(
|
||||
modes=["txt2img", "img2img", "inpaint"],
|
||||
supports_seed=True,
|
||||
supports_reference_images=True,
|
||||
max_reference_images=14,
|
||||
max_images_per_request=1,
|
||||
max_image_size=GEMINI_3_IMAGE_MAX_SIZE,
|
||||
allowed_aspect_ratios=GEMINI_3_IMAGE_ALLOWED_ASPECT_RATIOS,
|
||||
resolution_presets=GEMINI_3_1_FLASH_RESOLUTION_PRESETS,
|
||||
),
|
||||
default_settings=ExternalApiModelDefaultSettings(width=1024, height=1024, num_images=1),
|
||||
panel_schema=ExternalModelPanelSchema(prompts=[{"name": "reference_images"}], image=[{"name": "dimensions"}]),
|
||||
)
|
||||
QWEN_IMAGE_2_ALLOWED_ASPECT_RATIOS = ["1:1", "4:3", "3:4", "16:9", "9:16"]
|
||||
QWEN_IMAGE_MAX_ALLOWED_ASPECT_RATIOS = ["1:1", "4:3", "3:4", "16:9", "9:16"]
|
||||
WAN_V2_ALLOWED_ASPECT_RATIOS = ["1:1", "4:3", "3:4", "16:9", "9:16"]
|
||||
|
||||
alibabacloud_qwen_image_2_pro = StarterModel(
|
||||
name="Qwen Image 2.0 Pro",
|
||||
base=BaseModelType.External,
|
||||
source="external://alibabacloud/qwen-image-2.0-pro",
|
||||
description="Alibaba Cloud Qwen Image 2.0 Pro model (external API). Best quality text-to-image with excellent bilingual text rendering. Requires a configured Alibaba Cloud DashScope API key and may incur provider usage costs.",
|
||||
type=ModelType.ExternalImageGenerator,
|
||||
format=ModelFormat.ExternalApi,
|
||||
capabilities=ExternalModelCapabilities(
|
||||
modes=["txt2img"],
|
||||
supports_negative_prompt=True,
|
||||
supports_seed=True,
|
||||
max_images_per_request=4,
|
||||
allowed_aspect_ratios=QWEN_IMAGE_2_ALLOWED_ASPECT_RATIOS,
|
||||
aspect_ratio_sizes={
|
||||
"1:1": ExternalImageSize(width=2048, height=2048),
|
||||
"4:3": ExternalImageSize(width=2368, height=1728),
|
||||
"3:4": ExternalImageSize(width=1728, height=2368),
|
||||
"16:9": ExternalImageSize(width=2688, height=1536),
|
||||
"9:16": ExternalImageSize(width=1536, height=2688),
|
||||
},
|
||||
),
|
||||
default_settings=ExternalApiModelDefaultSettings(width=2048, height=2048, num_images=1),
|
||||
panel_schema=ExternalModelPanelSchema(image=[{"name": "dimensions"}]),
|
||||
)
|
||||
alibabacloud_qwen_image_2 = StarterModel(
|
||||
name="Qwen Image 2.0",
|
||||
base=BaseModelType.External,
|
||||
source="external://alibabacloud/qwen-image-2.0",
|
||||
description="Alibaba Cloud Qwen Image 2.0 model (external API). Fast text-to-image with good bilingual text rendering. Requires a configured Alibaba Cloud DashScope API key and may incur provider usage costs.",
|
||||
type=ModelType.ExternalImageGenerator,
|
||||
format=ModelFormat.ExternalApi,
|
||||
capabilities=ExternalModelCapabilities(
|
||||
modes=["txt2img"],
|
||||
supports_negative_prompt=True,
|
||||
supports_seed=True,
|
||||
max_images_per_request=4,
|
||||
allowed_aspect_ratios=QWEN_IMAGE_2_ALLOWED_ASPECT_RATIOS,
|
||||
aspect_ratio_sizes={
|
||||
"1:1": ExternalImageSize(width=2048, height=2048),
|
||||
"4:3": ExternalImageSize(width=2368, height=1728),
|
||||
"3:4": ExternalImageSize(width=1728, height=2368),
|
||||
"16:9": ExternalImageSize(width=2688, height=1536),
|
||||
"9:16": ExternalImageSize(width=1536, height=2688),
|
||||
},
|
||||
),
|
||||
default_settings=ExternalApiModelDefaultSettings(width=2048, height=2048, num_images=1),
|
||||
panel_schema=ExternalModelPanelSchema(image=[{"name": "dimensions"}]),
|
||||
)
|
||||
alibabacloud_qwen_image_max = StarterModel(
|
||||
name="Qwen Image Max",
|
||||
base=BaseModelType.External,
|
||||
source="external://alibabacloud/qwen-image-max",
|
||||
description="Alibaba Cloud Qwen Image Max model (external API). High quality text-to-image generation. Requires a configured Alibaba Cloud DashScope API key and may incur provider usage costs.",
|
||||
type=ModelType.ExternalImageGenerator,
|
||||
format=ModelFormat.ExternalApi,
|
||||
capabilities=ExternalModelCapabilities(
|
||||
modes=["txt2img"],
|
||||
supports_negative_prompt=True,
|
||||
supports_seed=True,
|
||||
max_images_per_request=4,
|
||||
allowed_aspect_ratios=QWEN_IMAGE_MAX_ALLOWED_ASPECT_RATIOS,
|
||||
aspect_ratio_sizes={
|
||||
"1:1": ExternalImageSize(width=1328, height=1328),
|
||||
"4:3": ExternalImageSize(width=1472, height=1104),
|
||||
"3:4": ExternalImageSize(width=1104, height=1472),
|
||||
"16:9": ExternalImageSize(width=1664, height=928),
|
||||
"9:16": ExternalImageSize(width=928, height=1664),
|
||||
},
|
||||
),
|
||||
default_settings=ExternalApiModelDefaultSettings(width=1328, height=1328, num_images=1),
|
||||
panel_schema=ExternalModelPanelSchema(image=[{"name": "dimensions"}]),
|
||||
)
|
||||
alibabacloud_wan26_t2i = StarterModel(
|
||||
name="Wan 2.6 Text-to-Image",
|
||||
base=BaseModelType.External,
|
||||
source="external://alibabacloud/wan2.6-t2i",
|
||||
description="Alibaba Cloud Wan 2.6 text-to-image model (external API). Photorealistic image generation. Requires a configured Alibaba Cloud DashScope API key and may incur provider usage costs.",
|
||||
type=ModelType.ExternalImageGenerator,
|
||||
format=ModelFormat.ExternalApi,
|
||||
capabilities=ExternalModelCapabilities(
|
||||
modes=["txt2img"],
|
||||
supports_negative_prompt=True,
|
||||
supports_seed=True,
|
||||
max_images_per_request=4,
|
||||
allowed_aspect_ratios=WAN_V2_ALLOWED_ASPECT_RATIOS,
|
||||
aspect_ratio_sizes={
|
||||
"1:1": ExternalImageSize(width=1024, height=1024),
|
||||
"4:3": ExternalImageSize(width=1440, height=1080),
|
||||
"3:4": ExternalImageSize(width=1080, height=1440),
|
||||
"16:9": ExternalImageSize(width=1440, height=810),
|
||||
"9:16": ExternalImageSize(width=810, height=1440),
|
||||
},
|
||||
),
|
||||
default_settings=ExternalApiModelDefaultSettings(width=1024, height=1024, num_images=1),
|
||||
panel_schema=ExternalModelPanelSchema(image=[{"name": "dimensions"}]),
|
||||
)
|
||||
alibabacloud_qwen_image_edit_max = StarterModel(
|
||||
name="Qwen Image Edit Max",
|
||||
base=BaseModelType.External,
|
||||
source="external://alibabacloud/qwen-image-edit-max",
|
||||
description="Alibaba Cloud Qwen Image Edit Max model (external API). Image editing with industrial design and geometric reasoning. Requires a configured Alibaba Cloud DashScope API key and may incur provider usage costs.",
|
||||
type=ModelType.ExternalImageGenerator,
|
||||
format=ModelFormat.ExternalApi,
|
||||
capabilities=ExternalModelCapabilities(
|
||||
modes=["img2img"],
|
||||
supports_negative_prompt=True,
|
||||
supports_seed=True,
|
||||
max_images_per_request=4,
|
||||
allowed_aspect_ratios=QWEN_IMAGE_2_ALLOWED_ASPECT_RATIOS,
|
||||
aspect_ratio_sizes={
|
||||
"1:1": ExternalImageSize(width=2048, height=2048),
|
||||
"4:3": ExternalImageSize(width=2368, height=1728),
|
||||
"3:4": ExternalImageSize(width=1728, height=2368),
|
||||
"16:9": ExternalImageSize(width=2688, height=1536),
|
||||
"9:16": ExternalImageSize(width=1536, height=2688),
|
||||
},
|
||||
),
|
||||
default_settings=ExternalApiModelDefaultSettings(width=2048, height=2048, num_images=1),
|
||||
panel_schema=ExternalModelPanelSchema(image=[{"name": "dimensions"}]),
|
||||
)
|
||||
OPENAI_GPT_IMAGE_ASPECT_RATIOS = ["1:1", "3:2", "2:3"]
|
||||
OPENAI_GPT_IMAGE_ASPECT_RATIO_SIZES = {
|
||||
"1:1": ExternalImageSize(width=1024, height=1024),
|
||||
"3:2": ExternalImageSize(width=1536, height=1024),
|
||||
"2:3": ExternalImageSize(width=1024, height=1536),
|
||||
}
|
||||
OPENAI_GPT_IMAGE_PANEL_SCHEMA = ExternalModelPanelSchema(
|
||||
prompts=[{"name": "reference_images"}], image=[{"name": "dimensions"}]
|
||||
)
|
||||
|
||||
openai_gpt_image_1_5 = StarterModel(
|
||||
name="GPT Image 1.5",
|
||||
base=BaseModelType.External,
|
||||
source="external://openai/gpt-image-1.5",
|
||||
description="OpenAI GPT-Image-1.5 image generation model. Fastest and most affordable GPT image model. Requires a configured OpenAI API key and may incur provider usage costs.",
|
||||
type=ModelType.ExternalImageGenerator,
|
||||
format=ModelFormat.ExternalApi,
|
||||
capabilities=ExternalModelCapabilities(
|
||||
modes=["txt2img", "img2img", "inpaint"],
|
||||
supports_reference_images=True,
|
||||
max_images_per_request=10,
|
||||
allowed_aspect_ratios=OPENAI_GPT_IMAGE_ASPECT_RATIOS,
|
||||
aspect_ratio_sizes=OPENAI_GPT_IMAGE_ASPECT_RATIO_SIZES,
|
||||
),
|
||||
default_settings=ExternalApiModelDefaultSettings(width=1024, height=1024, num_images=1),
|
||||
panel_schema=OPENAI_GPT_IMAGE_PANEL_SCHEMA,
|
||||
)
|
||||
openai_gpt_image_1 = StarterModel(
|
||||
name="GPT Image 1",
|
||||
base=BaseModelType.External,
|
||||
source="external://openai/gpt-image-1",
|
||||
description="OpenAI GPT-Image-1 image generation model. High quality image generation. Requires a configured OpenAI API key and may incur provider usage costs.",
|
||||
type=ModelType.ExternalImageGenerator,
|
||||
format=ModelFormat.ExternalApi,
|
||||
capabilities=ExternalModelCapabilities(
|
||||
modes=["txt2img", "img2img", "inpaint"],
|
||||
supports_reference_images=True,
|
||||
max_images_per_request=10,
|
||||
allowed_aspect_ratios=OPENAI_GPT_IMAGE_ASPECT_RATIOS,
|
||||
aspect_ratio_sizes=OPENAI_GPT_IMAGE_ASPECT_RATIO_SIZES,
|
||||
),
|
||||
default_settings=ExternalApiModelDefaultSettings(width=1024, height=1024, num_images=1),
|
||||
panel_schema=OPENAI_GPT_IMAGE_PANEL_SCHEMA,
|
||||
)
|
||||
openai_gpt_image_1_mini = StarterModel(
|
||||
name="GPT Image 1 Mini",
|
||||
base=BaseModelType.External,
|
||||
source="external://openai/gpt-image-1-mini",
|
||||
description="OpenAI GPT-Image-1-Mini image generation model. Cost-efficient option, 80%% cheaper than GPT-Image-1. Requires a configured OpenAI API key and may incur provider usage costs.",
|
||||
type=ModelType.ExternalImageGenerator,
|
||||
format=ModelFormat.ExternalApi,
|
||||
capabilities=ExternalModelCapabilities(
|
||||
modes=["txt2img", "img2img", "inpaint"],
|
||||
supports_reference_images=True,
|
||||
max_images_per_request=10,
|
||||
allowed_aspect_ratios=OPENAI_GPT_IMAGE_ASPECT_RATIOS,
|
||||
aspect_ratio_sizes=OPENAI_GPT_IMAGE_ASPECT_RATIO_SIZES,
|
||||
),
|
||||
default_settings=ExternalApiModelDefaultSettings(width=1024, height=1024, num_images=1),
|
||||
panel_schema=OPENAI_GPT_IMAGE_PANEL_SCHEMA,
|
||||
)
|
||||
openai_dall_e_3 = StarterModel(
|
||||
name="DALL-E 3",
|
||||
base=BaseModelType.External,
|
||||
source="external://openai/dall-e-3",
|
||||
description="OpenAI DALL-E 3 image generation model. Supports vivid and natural styles. Only text-to-image, no editing. Requires a configured OpenAI API key and may incur provider usage costs.",
|
||||
type=ModelType.ExternalImageGenerator,
|
||||
format=ModelFormat.ExternalApi,
|
||||
capabilities=ExternalModelCapabilities(
|
||||
modes=["txt2img"],
|
||||
max_images_per_request=1,
|
||||
allowed_aspect_ratios=["1:1", "7:4", "4:7"],
|
||||
aspect_ratio_sizes={
|
||||
"1:1": ExternalImageSize(width=1024, height=1024),
|
||||
"7:4": ExternalImageSize(width=1792, height=1024),
|
||||
"4:7": ExternalImageSize(width=1024, height=1792),
|
||||
},
|
||||
),
|
||||
default_settings=ExternalApiModelDefaultSettings(width=1024, height=1024, num_images=1),
|
||||
panel_schema=ExternalModelPanelSchema(image=[{"name": "dimensions"}]),
|
||||
)
|
||||
openai_dall_e_2 = StarterModel(
|
||||
name="DALL-E 2",
|
||||
base=BaseModelType.External,
|
||||
source="external://openai/dall-e-2",
|
||||
description="OpenAI DALL-E 2 image generation model. Supports square images only. Requires a configured OpenAI API key and may incur provider usage costs.",
|
||||
type=ModelType.ExternalImageGenerator,
|
||||
format=ModelFormat.ExternalApi,
|
||||
capabilities=ExternalModelCapabilities(
|
||||
modes=["txt2img", "img2img", "inpaint"],
|
||||
max_images_per_request=10,
|
||||
allowed_aspect_ratios=["1:1"],
|
||||
aspect_ratio_sizes={
|
||||
"1:1": ExternalImageSize(width=1024, height=1024),
|
||||
},
|
||||
),
|
||||
default_settings=ExternalApiModelDefaultSettings(width=1024, height=1024, num_images=1),
|
||||
panel_schema=ExternalModelPanelSchema(image=[{"name": "dimensions"}]),
|
||||
)
|
||||
# region Anima
|
||||
anima_qwen3_encoder = StarterModel(
|
||||
name="Anima Qwen3 0.6B Text Encoder",
|
||||
@@ -1140,6 +1511,19 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
z_image_qwen3_encoder_quantized,
|
||||
z_image_controlnet_union,
|
||||
z_image_controlnet_tile,
|
||||
gemini_flash_image,
|
||||
gemini_pro_image_preview,
|
||||
gemini_3_1_flash_image_preview,
|
||||
openai_gpt_image_1_5,
|
||||
openai_gpt_image_1,
|
||||
openai_gpt_image_1_mini,
|
||||
openai_dall_e_3,
|
||||
openai_dall_e_2,
|
||||
alibabacloud_qwen_image_2_pro,
|
||||
alibabacloud_qwen_image_2,
|
||||
alibabacloud_qwen_image_max,
|
||||
alibabacloud_wan26_t2i,
|
||||
alibabacloud_qwen_image_edit_max,
|
||||
anima_preview3,
|
||||
anima_qwen3_encoder,
|
||||
anima_vae,
|
||||
|
||||
@@ -52,6 +52,8 @@ class BaseModelType(str, Enum):
|
||||
"""Indicates the model is associated with CogView 4 model architecture."""
|
||||
ZImage = "z-image"
|
||||
"""Indicates the model is associated with Z-Image model architecture, including Z-Image-Turbo."""
|
||||
External = "external"
|
||||
"""Indicates the model is hosted by an external provider."""
|
||||
QwenImage = "qwen-image"
|
||||
"""Indicates the model is associated with Qwen Image Edit 2511 model architecture."""
|
||||
Anima = "anima"
|
||||
@@ -80,6 +82,7 @@ class ModelType(str, Enum):
|
||||
SigLIP = "siglip"
|
||||
FluxRedux = "flux_redux"
|
||||
LlavaOnevision = "llava_onevision"
|
||||
ExternalImageGenerator = "external_image_generator"
|
||||
Unknown = "unknown"
|
||||
|
||||
|
||||
@@ -187,6 +190,7 @@ class ModelFormat(str, Enum):
|
||||
BnbQuantizedLlmInt8b = "bnb_quantized_int8b"
|
||||
BnbQuantizednf4b = "bnb_quantized_nf4b"
|
||||
GGUFQuantized = "gguf_quantized"
|
||||
ExternalApi = "external_api"
|
||||
Unknown = "unknown"
|
||||
|
||||
|
||||
@@ -215,6 +219,7 @@ class ModelSourceType(str, Enum):
|
||||
Path = "path"
|
||||
Url = "url"
|
||||
HFRepoID = "hf_repo_id"
|
||||
External = "external"
|
||||
|
||||
|
||||
class FluxLoRAFormat(str, Enum):
|
||||
|
||||
@@ -1115,6 +1115,22 @@
|
||||
"fileSize": "File Size",
|
||||
"filterModels": "Filter models",
|
||||
"fluxRedux": "FLUX Redux",
|
||||
"externalImageGenerator": "External Image Generator",
|
||||
"externalProviders": "External Providers",
|
||||
"externalSetupTitle": "External Providers Setup",
|
||||
"externalSetupDescription": "Connect an API key to enable external image generation. External starter models auto-install when a provider is configured.",
|
||||
"externalInstallDefaults": "Auto-install starter models",
|
||||
"externalProvidersUnavailable": "External providers are not available in this build.",
|
||||
"externalSetupFooter": "An API key is required. External providers use remote APIs; usage may incur provider-side costs.",
|
||||
"externalProviderCardDescription": "Configure {{providerId}} credentials for external image generation.",
|
||||
"externalApiKey": "API Key",
|
||||
"externalApiKeyPlaceholder": "Paste your API key",
|
||||
"externalApiKeyPlaceholderSet": "API key configured",
|
||||
"externalApiKeyHelper": "Stored in your InvokeAI config file.",
|
||||
"externalBaseUrl": "Base URL (optional)",
|
||||
"externalBaseUrlPlaceholder": "https://...",
|
||||
"externalBaseUrlHelper": "Override the default API base URL if needed.",
|
||||
"externalResetHelper": "Clear API key and base URL.",
|
||||
"height": "Height",
|
||||
"huggingFace": "HuggingFace",
|
||||
"huggingFacePlaceholder": "owner/model-name",
|
||||
@@ -1182,6 +1198,21 @@
|
||||
"modelUpdated": "Model Updated",
|
||||
"modelUpdateFailed": "Model Update Failed",
|
||||
"name": "Name",
|
||||
"externalProvider": "External Provider",
|
||||
"externalCapabilities": "External Capabilities",
|
||||
"externalDefaults": "External Defaults",
|
||||
"providerId": "Provider ID",
|
||||
"providerModelId": "Provider Model ID",
|
||||
"supportedModes": "Supported Modes",
|
||||
"supportsNegativePrompt": "Supports Negative Prompt",
|
||||
"supportsReferenceImages": "Supports Reference Images",
|
||||
"supportsSeed": "Supports Seed",
|
||||
"supportsGuidance": "Supports Guidance",
|
||||
"maxImagesPerRequest": "Max Images Per Request",
|
||||
"maxReferenceImages": "Max Reference Images",
|
||||
"maxImageWidth": "Max Image Width",
|
||||
"maxImageHeight": "Max Image Height",
|
||||
"numImages": "Num Images",
|
||||
"modelPickerFallbackNoModelsInstalled": "No models installed.",
|
||||
"modelPickerFallbackNoModelsInstalled2": "Visit the <LinkComponent>Model Manager</LinkComponent> to install models.",
|
||||
"modelPickerFallbackNoModelsInstalledNonAdmin": "No models installed. Ask your InvokeAI administrator (<AdminEmailLink />) to install some models.",
|
||||
@@ -1226,6 +1257,7 @@
|
||||
"urlDescription": "Install models from a URL or local file path. Perfect for specific models you want to add.",
|
||||
"huggingFaceDescription": "Browse and install models directly from HuggingFace repositories.",
|
||||
"scanFolderDescription": "Scan a local folder to automatically detect and install models.",
|
||||
"externalDescription": "Connect a Gemini or OpenAI API key to enable external generation. Usage may incur provider-side costs.",
|
||||
"recommendedModels": "Recommended Models",
|
||||
"exploreStarter": "Or browse all available starter models",
|
||||
"quickStart": "Quick Start Bundles",
|
||||
@@ -1544,6 +1576,7 @@
|
||||
"copyImage": "Copy Image",
|
||||
"denoisingStrength": "Denoising Strength",
|
||||
"disabledNoRasterContent": "Disabled (No Raster Content)",
|
||||
"disabledNotSupported": "Not supported by model",
|
||||
"downloadImage": "Download Image",
|
||||
"general": "General",
|
||||
"guidance": "Guidance",
|
||||
@@ -1661,6 +1694,7 @@
|
||||
"boxBlur": "Box Blur",
|
||||
"staged": "Staged",
|
||||
"resolution": "Resolution",
|
||||
"imageSize": "Image Size",
|
||||
"modelDisabledForTrial": "Generating with {{modelName}} is not available on trial accounts. Visit your <LinkComponent>account settings</LinkComponent> to upgrade."
|
||||
},
|
||||
"dynamicPrompts": {
|
||||
@@ -1737,7 +1771,11 @@
|
||||
"intermediatesCleared_one": "Cleared {{count}} Intermediate",
|
||||
"intermediatesCleared_other": "Cleared {{count}} Intermediates",
|
||||
"intermediatesClearedFailed": "Problem Clearing Intermediates",
|
||||
"reloadingIn": "Reloading in"
|
||||
"reloadingIn": "Reloading in",
|
||||
"externalProviders": "External Providers",
|
||||
"externalProviderConfigured": "Configured",
|
||||
"externalProviderNotConfigured": "API Key Required",
|
||||
"externalProviderNotConfiguredHint": "Add your API key in Model Manager or the server config to enable this provider."
|
||||
},
|
||||
"toast": {
|
||||
"addedToBoard": "Added to board {{name}}'s assets",
|
||||
|
||||
@@ -7,10 +7,12 @@ import {
|
||||
animaQwen3EncoderModelSelected,
|
||||
animaT5EncoderModelSelected,
|
||||
animaVaeModelSelected,
|
||||
aspectRatioIdChanged,
|
||||
kleinQwen3EncoderModelSelected,
|
||||
kleinVaeModelSelected,
|
||||
modelChanged,
|
||||
qwenImageComponentSourceSelected,
|
||||
resolutionPresetSelected,
|
||||
setZImageScheduler,
|
||||
syncedToOptimalDimension,
|
||||
vaeSelected,
|
||||
@@ -30,6 +32,7 @@ import {
|
||||
} from 'features/controlLayers/store/selectors';
|
||||
import {
|
||||
getEntityIdentifier,
|
||||
isAspectRatioID,
|
||||
isFlux2ReferenceImageConfig,
|
||||
isQwenImageReferenceImageConfig,
|
||||
} from 'features/controlLayers/store/types';
|
||||
@@ -59,7 +62,7 @@ import {
|
||||
selectZImageDiffusersModels,
|
||||
} from 'services/api/hooks/modelsByType';
|
||||
import type { FLUXKontextModelConfig, FLUXReduxModelConfig, IPAdapterModelConfig } from 'services/api/types';
|
||||
import { isFluxKontextModelConfig, isFluxReduxModelConfig } from 'services/api/types';
|
||||
import { isExternalApiModelConfig, isFluxKontextModelConfig, isFluxReduxModelConfig } from 'services/api/types';
|
||||
|
||||
const log = logger('models');
|
||||
|
||||
@@ -281,7 +284,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
|
||||
}
|
||||
}
|
||||
|
||||
if (SUPPORTS_REF_IMAGES_BASE_MODELS.includes(newModel.base)) {
|
||||
if (newModel.base !== 'external' && SUPPORTS_REF_IMAGES_BASE_MODELS.includes(newModel.base)) {
|
||||
// Handle incompatible reference image models - switch to first compatible model, with some smart logic
|
||||
// to choose the best available model based on the new main model.
|
||||
const allRefImageModels = selectGlobalRefImageModels(state).filter(({ base }) => base === newBase);
|
||||
@@ -529,6 +532,34 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
|
||||
dispatch(bboxSyncedToOptimalDimension());
|
||||
}
|
||||
}
|
||||
|
||||
// When switching to an external model, sync bbox to the model's first preset dimensions
|
||||
if (newBase === 'external') {
|
||||
const modelConfigsResult = selectModelConfigsQuery(getState());
|
||||
if (modelConfigsResult.data) {
|
||||
const newModelConfig = modelConfigsAdapterSelectors.selectById(modelConfigsResult.data, newModel.key);
|
||||
if (newModelConfig && isExternalApiModelConfig(newModelConfig)) {
|
||||
const { aspect_ratio_sizes, resolution_presets } = newModelConfig.capabilities;
|
||||
if (resolution_presets && resolution_presets.length > 0) {
|
||||
const firstPreset = resolution_presets[0]!;
|
||||
dispatch(
|
||||
resolutionPresetSelected({
|
||||
imageSize: firstPreset.image_size,
|
||||
aspectRatio: firstPreset.aspect_ratio,
|
||||
width: firstPreset.width,
|
||||
height: firstPreset.height,
|
||||
})
|
||||
);
|
||||
} else if (aspect_ratio_sizes) {
|
||||
const firstRatio = Object.keys(aspect_ratio_sizes)[0];
|
||||
const firstSize = firstRatio ? aspect_ratio_sizes[firstRatio] : undefined;
|
||||
if (firstRatio && firstSize && isAspectRatioID(firstRatio)) {
|
||||
dispatch(aspectRatioIdChanged({ id: firstRatio, fixedSize: firstSize }));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
@@ -11,7 +11,7 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import WavyLine from 'common/components/WavyLine';
|
||||
import { selectImg2imgStrength, setImg2imgStrength } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectImg2imgStrength, selectIsExternal, setImg2imgStrength } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectActiveRasterLayerEntities } from 'features/controlLayers/store/selectors';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -37,6 +37,7 @@ export const ParamDenoisingStrength = memo(() => {
|
||||
const img2imgStrength = useAppSelector(selectImg2imgStrength);
|
||||
const dispatch = useAppDispatch();
|
||||
const hasRasterLayersWithContent = useAppSelector(selectHasRasterLayersWithContent);
|
||||
const isExternal = useAppSelector(selectIsExternal);
|
||||
const selectedModelConfig = useSelectedModelConfig();
|
||||
|
||||
const onChange = useCallback(
|
||||
@@ -55,12 +56,16 @@ export const ParamDenoisingStrength = memo(() => {
|
||||
// Denoising strength does nothing if there are no raster layers w/ content
|
||||
return true;
|
||||
}
|
||||
if (isExternal) {
|
||||
// External models don't support denoise strength - they handle img2img via prompt
|
||||
return true;
|
||||
}
|
||||
if (selectedModelConfig && isFluxFillMainModelModelConfig(selectedModelConfig)) {
|
||||
// Denoising strength is ignored by FLUX Fill, which is indicated by the variant being 'inpaint'
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}, [hasRasterLayersWithContent, selectedModelConfig]);
|
||||
}, [hasRasterLayersWithContent, isExternal, selectedModelConfig]);
|
||||
|
||||
return (
|
||||
<FormControl isDisabled={isDisabled} p={1} justifyContent="space-between" h={8}>
|
||||
@@ -96,7 +101,9 @@ export const ParamDenoisingStrength = memo(() => {
|
||||
</>
|
||||
) : (
|
||||
<Flex alignItems="center">
|
||||
<Badge opacity="0.6">{t('parameters.disabledNoRasterContent')}</Badge>
|
||||
<Badge opacity="0.6">
|
||||
{isExternal ? t('parameters.disabledNotSupported') : t('parameters.disabledNoRasterContent')}
|
||||
</Badge>
|
||||
</Flex>
|
||||
)}
|
||||
</FormControl>
|
||||
|
||||
@@ -16,6 +16,7 @@ import { memo, useCallback, useEffect, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiExclamationMarkBold, PiEyeSlashBold, PiImageBold } from 'react-icons/pi';
|
||||
import { useImageDTOFromCroppableImage } from 'services/api/endpoints/images';
|
||||
import { isExternalApiModelConfig } from 'services/api/types';
|
||||
|
||||
import { RefImageWarningTooltipContent } from './RefImageWarningTooltipContent';
|
||||
|
||||
@@ -73,18 +74,19 @@ export const RefImagePreview = memo(() => {
|
||||
const selectedEntityId = useAppSelector(selectSelectedRefEntityId);
|
||||
const isPanelOpen = useAppSelector(selectIsRefImagePanelOpen);
|
||||
const [showWeightDisplay, setShowWeightDisplay] = useState(false);
|
||||
const isExternalModel = !!mainModelConfig && isExternalApiModelConfig(mainModelConfig);
|
||||
|
||||
const imageDTO = useImageDTOFromCroppableImage(entity.config.image);
|
||||
|
||||
const sx = useMemo(() => {
|
||||
if (!isIPAdapterConfig(entity.config)) {
|
||||
if (!isIPAdapterConfig(entity.config) || isExternalModel) {
|
||||
return baseSx;
|
||||
}
|
||||
return getImageSxWithWeight(entity.config.weight);
|
||||
}, [entity.config]);
|
||||
}, [entity.config, isExternalModel]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isIPAdapterConfig(entity.config)) {
|
||||
if (!isIPAdapterConfig(entity.config) || isExternalModel) {
|
||||
return;
|
||||
}
|
||||
setShowWeightDisplay(true);
|
||||
@@ -94,7 +96,7 @@ export const RefImagePreview = memo(() => {
|
||||
return () => {
|
||||
window.clearTimeout(timeout);
|
||||
};
|
||||
}, [entity.config]);
|
||||
}, [entity.config, isExternalModel]);
|
||||
|
||||
const warnings = useMemo(() => {
|
||||
return getGlobalReferenceImageWarnings(entity, mainModelConfig);
|
||||
@@ -156,7 +158,7 @@ export const RefImagePreview = memo(() => {
|
||||
) : (
|
||||
<Skeleton h="full" aspectRatio="1/1" />
|
||||
)}
|
||||
{isIPAdapterConfig(entity.config) && (
|
||||
{isIPAdapterConfig(entity.config) && !isExternalModel && (
|
||||
<Flex
|
||||
position="absolute"
|
||||
inset={0}
|
||||
|
||||
@@ -15,7 +15,7 @@ import {
|
||||
useCanvasManagerSafe,
|
||||
} from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { useRefImageIdContext } from 'features/controlLayers/contexts/RefImageIdContext';
|
||||
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectIsFLUX, selectMainModelConfig } from 'features/controlLayers/store/paramsSlice';
|
||||
import {
|
||||
refImageFLUXReduxImageInfluenceChanged,
|
||||
refImageImageChanged,
|
||||
@@ -50,6 +50,7 @@ import type {
|
||||
FLUXReduxModelConfig,
|
||||
IPAdapterModelConfig,
|
||||
} from 'services/api/types';
|
||||
import { isExternalApiModelConfig } from 'services/api/types';
|
||||
|
||||
import { RefImageImage } from './RefImageImage';
|
||||
|
||||
@@ -65,6 +66,7 @@ const RefImageSettingsContent = memo(() => {
|
||||
const selectConfig = useMemo(() => buildSelectConfig(id), [id]);
|
||||
const config = useAppSelector(selectConfig);
|
||||
const tab = useAppSelector(selectActiveTab);
|
||||
const mainModelConfig = useAppSelector(selectMainModelConfig);
|
||||
|
||||
const onChangeBeginEndStepPct = useCallback(
|
||||
(beginEndStepPct: [number, number]) => {
|
||||
@@ -125,9 +127,11 @@ const RefImageSettingsContent = memo(() => {
|
||||
);
|
||||
|
||||
const isFLUX = useAppSelector(selectIsFLUX);
|
||||
const isExternalModel = !!mainModelConfig && isExternalApiModelConfig(mainModelConfig);
|
||||
|
||||
// FLUX.2 Klein and Qwen Image Edit have built-in reference image support - no model selector needed
|
||||
const showModelSelector = !isFlux2ReferenceImageConfig(config) && !isQwenImageReferenceImageConfig(config);
|
||||
// FLUX.2 Klein, Qwen Image Edit and external API models do not require a ref image model selection.
|
||||
const showModelSelector =
|
||||
!isFlux2ReferenceImageConfig(config) && !isQwenImageReferenceImageConfig(config) && !isExternalModel;
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={2} position="relative" w="full">
|
||||
@@ -155,14 +159,14 @@ const RefImageSettingsContent = memo(() => {
|
||||
</Flex>
|
||||
)}
|
||||
<Flex gap={2} w="full">
|
||||
{isIPAdapterConfig(config) && (
|
||||
{isIPAdapterConfig(config) && !isExternalModel && (
|
||||
<Flex flexDir="column" gap={2} w="full">
|
||||
{!isFLUX && <IPAdapterMethod method={config.method} onChange={onChangeIPMethod} />}
|
||||
<Weight weight={config.weight} onChange={onChangeWeight} />
|
||||
<BeginEndStepPct beginEndStepPct={config.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
|
||||
</Flex>
|
||||
)}
|
||||
{isFLUXReduxConfig(config) && (
|
||||
{isFLUXReduxConfig(config) && !isExternalModel && (
|
||||
<Flex flexDir="column" gap={2} w="full" alignItems="flex-start">
|
||||
<FLUXReduxImageInfluence
|
||||
imageInfluence={config.imageInfluence ?? 'lowest'}
|
||||
|
||||
@@ -75,11 +75,18 @@ export const StagingAreaContextProvider = memo(({ children, sessionId }: PropsWi
|
||||
onAccept: (item, imageDTO) => {
|
||||
const bboxRect = selectBboxRect(store.getState());
|
||||
const { x, y } = bboxRect;
|
||||
const imageObject = imageDTOToImageObject(imageDTO, { usePixelBbox: false });
|
||||
const scale = Math.min(bboxRect.width / imageDTO.width, bboxRect.height / imageDTO.height);
|
||||
const scaledWidth = Math.round(imageDTO.width * scale);
|
||||
const scaledHeight = Math.round(imageDTO.height * scale);
|
||||
const position = {
|
||||
x: x + Math.round((bboxRect.width - scaledWidth) / 2),
|
||||
y: y + Math.round((bboxRect.height - scaledHeight) / 2),
|
||||
};
|
||||
const selectedEntityIdentifier = selectSelectedEntityIdentifier(store.getState());
|
||||
|
||||
const imageObject = imageDTOToImageObject(imageDTO);
|
||||
const overrides: Partial<CanvasRasterLayerState> = {
|
||||
position: { x, y },
|
||||
position,
|
||||
objects: [imageObject],
|
||||
};
|
||||
store.dispatch(rasterLayerAdded({ overrides, isSelected: selectedEntityIdentifier?.type === 'raster_layer' }));
|
||||
|
||||
@@ -183,6 +183,33 @@ describe('StagingAreaApi Utility Functions', () => {
|
||||
expect(result).toEqual(['first-image.png', 'second-image.png']);
|
||||
});
|
||||
|
||||
it('should return first image from image collections', () => {
|
||||
const queueItem: S['SessionQueueItem'] = {
|
||||
item_id: 1,
|
||||
status: 'completed',
|
||||
priority: 0,
|
||||
destination: 'test-session',
|
||||
created_at: '2024-01-01T00:00:00Z',
|
||||
updated_at: '2024-01-01T00:00:00Z',
|
||||
started_at: '2024-01-01T00:00:01Z',
|
||||
completed_at: '2024-01-01T00:01:00Z',
|
||||
error: null,
|
||||
session: {
|
||||
id: 'test-session',
|
||||
source_prepared_mapping: {
|
||||
canvas_output: ['output-node-id'],
|
||||
},
|
||||
results: {
|
||||
'output-node-id': {
|
||||
images: [{ image_name: 'first.png' }, { image_name: 'second.png' }],
|
||||
},
|
||||
},
|
||||
},
|
||||
} as unknown as S['SessionQueueItem'];
|
||||
|
||||
expect(getOutputImageNames(queueItem)).toEqual(['first.png', 'second.png']);
|
||||
});
|
||||
|
||||
it('should handle empty session mapping', () => {
|
||||
const queueItem: S['SessionQueueItem'] = {
|
||||
item_id: 1,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { isImageField } from 'features/nodes/types/common';
|
||||
import { isImageField, isImageFieldCollection } from 'features/nodes/types/common';
|
||||
import { isCanvasOutputNodeId } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { S } from 'services/api/types';
|
||||
import { formatProgressMessage } from 'services/events/stores';
|
||||
@@ -32,6 +32,11 @@ export const getOutputImageNames = (item: S['SessionQueueItem']): string[] => {
|
||||
if (isImageField(value)) {
|
||||
imageNames.push(value.image_name);
|
||||
}
|
||||
if (isImageFieldCollection(value)) {
|
||||
for (const img of value) {
|
||||
imageNames.push(img.image_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ import { toast } from 'features/toast/toast';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import { type ImageDTO, isExternalApiModelConfig } from 'services/api/types';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
|
||||
const log = logger('canvas');
|
||||
@@ -90,7 +90,7 @@ const useSaveCanvas = ({ region, saveToGallery, toastOk, toastError, onSave, wit
|
||||
metadata.negative_prompt = selectNegativePrompt(state);
|
||||
metadata.seed = selectSeed(state);
|
||||
const model = selectMainModelConfig(state);
|
||||
if (model) {
|
||||
if (model && !isExternalApiModelConfig(model)) {
|
||||
metadata.model = Graph.getModelMetadataField(model);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import {
|
||||
getPrefixedId,
|
||||
} from 'features/controlLayers/konva/util';
|
||||
import { selectBboxOverlay } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { selectModel } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectHasFixedDimensionSizes, selectModel } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectBbox } from 'features/controlLayers/store/selectors';
|
||||
import type { Coordinate, Rect, Tool } from 'features/controlLayers/store/types';
|
||||
import Konva from 'konva';
|
||||
@@ -191,6 +191,9 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
|
||||
// Listen for the model changing - some model types constraint the bbox to a certain size or aspect ratio.
|
||||
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectModel, this.render));
|
||||
|
||||
// Listen for fixed dimension sizes changes - external models may lock bbox resizing
|
||||
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectHasFixedDimensionSizes, this.render));
|
||||
|
||||
// Update on busy state changes
|
||||
this.subscriptions.add(this.manager.$isBusy.listen(this.render));
|
||||
|
||||
@@ -246,6 +249,10 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
|
||||
if (tool !== 'bbox') {
|
||||
return NO_ANCHORS;
|
||||
}
|
||||
// External models with fixed dimension presets don't allow free bbox resizing
|
||||
if (this.manager.stateApi.runSelector(selectHasFixedDimensionSizes)) {
|
||||
return NO_ANCHORS;
|
||||
}
|
||||
return ALL_ANCHORS;
|
||||
};
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMul
|
||||
import { merge } from 'es-toolkit/compat';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { canvasReset } from 'features/controlLayers/store/actions';
|
||||
import { modelChanged } from 'features/controlLayers/store/paramsSlice';
|
||||
import { aspectRatioIdChanged, modelChanged, resolutionPresetSelected } from 'features/controlLayers/store/paramsSlice';
|
||||
import {
|
||||
selectAllEntities,
|
||||
selectAllEntitiesOfType,
|
||||
@@ -31,6 +31,7 @@ import type {
|
||||
RgbColor,
|
||||
SimpleAdjustmentsConfig,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import { isAspectRatioID } from 'features/controlLayers/store/types';
|
||||
import {
|
||||
calculateNewSize,
|
||||
getScaledBoundingBoxDimensions,
|
||||
@@ -1288,21 +1289,31 @@ const slice = createSlice({
|
||||
state.bbox.aspectRatio.isLocked = !state.bbox.aspectRatio.isLocked;
|
||||
syncScaledSize(state);
|
||||
},
|
||||
bboxAspectRatioIdChanged: (state, action: PayloadAction<{ id: AspectRatioID }>) => {
|
||||
const { id } = action.payload;
|
||||
bboxAspectRatioIdChanged: (
|
||||
state,
|
||||
action: PayloadAction<{ id: AspectRatioID; fixedSize?: { width: number; height: number } }>
|
||||
) => {
|
||||
const { id, fixedSize } = action.payload;
|
||||
state.bbox.aspectRatio.id = id;
|
||||
if (id === 'Free') {
|
||||
state.bbox.aspectRatio.isLocked = false;
|
||||
} else {
|
||||
state.bbox.aspectRatio.isLocked = true;
|
||||
state.bbox.aspectRatio.value = ASPECT_RATIO_MAP[id].ratio;
|
||||
const { width, height } = calculateNewSize(
|
||||
state.bbox.aspectRatio.value,
|
||||
state.bbox.rect.width * state.bbox.rect.height,
|
||||
state.bbox.modelBase
|
||||
);
|
||||
state.bbox.rect.width = width;
|
||||
state.bbox.rect.height = height;
|
||||
if (fixedSize) {
|
||||
// External models provide fixed dimensions for each aspect ratio
|
||||
state.bbox.aspectRatio.value = fixedSize.width / fixedSize.height;
|
||||
state.bbox.rect.width = fixedSize.width;
|
||||
state.bbox.rect.height = fixedSize.height;
|
||||
} else {
|
||||
state.bbox.aspectRatio.value = ASPECT_RATIO_MAP[id].ratio;
|
||||
const { width, height } = calculateNewSize(
|
||||
state.bbox.aspectRatio.value,
|
||||
state.bbox.rect.width * state.bbox.rect.height,
|
||||
state.bbox.modelBase
|
||||
);
|
||||
state.bbox.rect.width = width;
|
||||
state.bbox.rect.height = height;
|
||||
}
|
||||
}
|
||||
|
||||
syncScaledSize(state);
|
||||
@@ -1800,6 +1811,29 @@ const slice = createSlice({
|
||||
syncScaledSize(state);
|
||||
}
|
||||
});
|
||||
// Sync bbox when external model resolution preset is selected (aspect_ratio_sizes)
|
||||
builder.addCase(aspectRatioIdChanged, (state, action) => {
|
||||
const { id, fixedSize } = action.payload;
|
||||
// Only sync when fixedSize is provided (external models with aspect_ratio_sizes)
|
||||
if (fixedSize) {
|
||||
state.bbox.rect.width = fixedSize.width;
|
||||
state.bbox.rect.height = fixedSize.height;
|
||||
state.bbox.aspectRatio.value = fixedSize.width / fixedSize.height;
|
||||
state.bbox.aspectRatio.id = id;
|
||||
state.bbox.aspectRatio.isLocked = true;
|
||||
syncScaledSize(state);
|
||||
}
|
||||
});
|
||||
// Sync bbox when external model resolution preset is selected (resolution_presets)
|
||||
builder.addCase(resolutionPresetSelected, (state, action) => {
|
||||
const { width, height, aspectRatio } = action.payload;
|
||||
state.bbox.rect.width = width;
|
||||
state.bbox.rect.height = height;
|
||||
state.bbox.aspectRatio.value = width / height;
|
||||
state.bbox.aspectRatio.id = isAspectRatioID(aspectRatio) ? aspectRatio : 'Free';
|
||||
state.bbox.aspectRatio.isLocked = true;
|
||||
syncScaledSize(state);
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
@@ -0,0 +1,133 @@
|
||||
import type {
|
||||
ExternalApiModelConfig,
|
||||
ExternalApiModelDefaultSettings,
|
||||
ExternalImageSize,
|
||||
ExternalModelCapabilities,
|
||||
ExternalModelPanelSchema,
|
||||
} from 'services/api/types';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import {
|
||||
selectModelSupportsDimensions,
|
||||
selectModelSupportsGuidance,
|
||||
selectModelSupportsNegativePrompt,
|
||||
selectModelSupportsRefImages,
|
||||
selectModelSupportsSeed,
|
||||
selectModelSupportsSteps,
|
||||
} from './paramsSlice';
|
||||
|
||||
const buildExternalModelIdentifier = (config: ExternalApiModelConfig) =>
|
||||
({
|
||||
key: config.key,
|
||||
hash: config.hash,
|
||||
name: config.name,
|
||||
base: config.base,
|
||||
type: config.type,
|
||||
}) as const;
|
||||
|
||||
const createExternalConfig = (
|
||||
capabilities: ExternalModelCapabilities,
|
||||
panelSchema?: ExternalModelPanelSchema
|
||||
): ExternalApiModelConfig => {
|
||||
const maxImageSize: ExternalImageSize = { width: 1024, height: 1024 };
|
||||
const defaultSettings: ExternalApiModelDefaultSettings = { width: 1024, height: 1024 };
|
||||
|
||||
return {
|
||||
key: 'external-test',
|
||||
hash: 'external:openai:gpt-image-1',
|
||||
path: 'external://openai/gpt-image-1',
|
||||
file_size: 0,
|
||||
name: 'External Test',
|
||||
description: null,
|
||||
source: 'external://openai/gpt-image-1',
|
||||
source_type: 'url',
|
||||
source_api_response: null,
|
||||
cover_image: null,
|
||||
base: 'external',
|
||||
type: 'external_image_generator',
|
||||
format: 'external_api',
|
||||
provider_id: 'openai',
|
||||
provider_model_id: 'gpt-image-1',
|
||||
capabilities: { ...capabilities, max_image_size: maxImageSize },
|
||||
default_settings: defaultSettings,
|
||||
panel_schema: panelSchema,
|
||||
tags: ['external'],
|
||||
is_default: false,
|
||||
};
|
||||
};
|
||||
|
||||
describe('paramsSlice selectors for external models', () => {
|
||||
it('returns false for negative prompt support on external models', () => {
|
||||
const config = createExternalConfig({
|
||||
modes: ['txt2img'],
|
||||
supports_reference_images: false,
|
||||
});
|
||||
const model = buildExternalModelIdentifier(config);
|
||||
|
||||
expect(selectModelSupportsNegativePrompt.resultFunc(model)).toBe(false);
|
||||
});
|
||||
|
||||
it('uses external capabilities for ref image support', () => {
|
||||
const config = createExternalConfig({
|
||||
modes: ['txt2img'],
|
||||
supports_reference_images: false,
|
||||
});
|
||||
const model = buildExternalModelIdentifier(config);
|
||||
|
||||
expect(selectModelSupportsRefImages.resultFunc(model, config)).toBe(false);
|
||||
});
|
||||
|
||||
it('returns false for guidance support on external models', () => {
|
||||
const config = createExternalConfig({
|
||||
modes: ['txt2img'],
|
||||
supports_reference_images: false,
|
||||
});
|
||||
const model = buildExternalModelIdentifier(config);
|
||||
|
||||
expect(selectModelSupportsGuidance.resultFunc(model)).toBe(false);
|
||||
});
|
||||
|
||||
it('uses external capabilities for seed support', () => {
|
||||
const config = createExternalConfig({
|
||||
modes: ['txt2img'],
|
||||
supports_reference_images: false,
|
||||
supports_seed: false,
|
||||
});
|
||||
const model = buildExternalModelIdentifier(config);
|
||||
|
||||
expect(selectModelSupportsSeed.resultFunc(model, config)).toBe(false);
|
||||
});
|
||||
|
||||
it('returns false for steps support on external models', () => {
|
||||
const config = createExternalConfig({
|
||||
modes: ['txt2img'],
|
||||
supports_reference_images: false,
|
||||
});
|
||||
const model = buildExternalModelIdentifier(config);
|
||||
|
||||
expect(selectModelSupportsSteps.resultFunc(model)).toBe(false);
|
||||
});
|
||||
|
||||
it('prefers panel schema over capabilities for control visibility', () => {
|
||||
const config = createExternalConfig(
|
||||
{
|
||||
modes: ['txt2img'],
|
||||
supports_reference_images: true,
|
||||
supports_seed: true,
|
||||
},
|
||||
{
|
||||
prompts: [{ name: 'reference_images' }],
|
||||
image: [{ name: 'dimensions' }],
|
||||
generation: [],
|
||||
}
|
||||
);
|
||||
const model = buildExternalModelIdentifier(config);
|
||||
|
||||
expect(selectModelSupportsNegativePrompt.resultFunc(model)).toBe(false);
|
||||
expect(selectModelSupportsRefImages.resultFunc(model, config)).toBe(true);
|
||||
expect(selectModelSupportsGuidance.resultFunc(model)).toBe(false);
|
||||
expect(selectModelSupportsSeed.resultFunc(model, config)).toBe(false);
|
||||
expect(selectModelSupportsSteps.resultFunc(model)).toBe(false);
|
||||
expect(selectModelSupportsDimensions.resultFunc(model, config)).toBe(true);
|
||||
});
|
||||
});
|
||||
@@ -21,6 +21,7 @@ import {
|
||||
SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS,
|
||||
SUPPORTS_REF_IMAGES_BASE_MODELS,
|
||||
} from 'features/modelManagerV2/models';
|
||||
import type { BaseModelType } from 'features/nodes/types/common';
|
||||
import { CLIP_SKIP_MAP } from 'features/parameters/types/constants';
|
||||
import type {
|
||||
ParameterCanvasCoherenceMode,
|
||||
@@ -41,9 +42,11 @@ import type {
|
||||
ParameterT5EncoderModel,
|
||||
ParameterVAEModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { getExternalPanelControl, hasExternalPanelControl } from 'features/parameters/util/externalPanelSchema';
|
||||
import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
|
||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||
import type { AnyModelConfigWithExternal } from 'services/api/types';
|
||||
import { isExternalApiModelConfig, isNonRefinerMainModelConfig } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const slice = createSlice({
|
||||
@@ -360,7 +363,7 @@ const slice = createSlice({
|
||||
//#region Dimensions
|
||||
sizeRecalled: (state, action: PayloadAction<{ width: number; height: number }>) => {
|
||||
const { width, height } = action.payload;
|
||||
const gridSize = getGridSize(state.model?.base);
|
||||
const gridSize = getGridSize(state.model?.base as BaseModelType | undefined);
|
||||
state.dimensions.width = Math.max(roundDownToMultiple(width, gridSize), 64);
|
||||
state.dimensions.height = Math.max(roundDownToMultiple(height, gridSize), 64);
|
||||
state.dimensions.aspectRatio.value = state.dimensions.width / state.dimensions.height;
|
||||
@@ -369,7 +372,7 @@ const slice = createSlice({
|
||||
},
|
||||
widthChanged: (state, action: PayloadAction<{ width: number; updateAspectRatio?: boolean; clamp?: boolean }>) => {
|
||||
const { width, updateAspectRatio, clamp } = action.payload;
|
||||
const gridSize = getGridSize(state.model?.base);
|
||||
const gridSize = getGridSize(state.model?.base as BaseModelType | undefined);
|
||||
state.dimensions.width = clamp ? Math.max(roundDownToMultiple(width, gridSize), 64) : width;
|
||||
|
||||
if (state.dimensions.aspectRatio.isLocked) {
|
||||
@@ -387,7 +390,7 @@ const slice = createSlice({
|
||||
},
|
||||
heightChanged: (state, action: PayloadAction<{ height: number; updateAspectRatio?: boolean; clamp?: boolean }>) => {
|
||||
const { height, updateAspectRatio, clamp } = action.payload;
|
||||
const gridSize = getGridSize(state.model?.base);
|
||||
const gridSize = getGridSize(state.model?.base as BaseModelType | undefined);
|
||||
state.dimensions.height = clamp ? Math.max(roundDownToMultiple(height, gridSize), 64) : height;
|
||||
|
||||
if (state.dimensions.aspectRatio.isLocked) {
|
||||
@@ -406,21 +409,30 @@ const slice = createSlice({
|
||||
aspectRatioLockToggled: (state) => {
|
||||
state.dimensions.aspectRatio.isLocked = !state.dimensions.aspectRatio.isLocked;
|
||||
},
|
||||
aspectRatioIdChanged: (state, action: PayloadAction<{ id: AspectRatioID }>) => {
|
||||
const { id } = action.payload;
|
||||
aspectRatioIdChanged: (
|
||||
state,
|
||||
action: PayloadAction<{ id: AspectRatioID; fixedSize?: { width: number; height: number } }>
|
||||
) => {
|
||||
const { id, fixedSize } = action.payload;
|
||||
state.dimensions.aspectRatio.id = id;
|
||||
if (id === 'Free') {
|
||||
state.dimensions.aspectRatio.isLocked = false;
|
||||
} else {
|
||||
state.dimensions.aspectRatio.isLocked = true;
|
||||
state.dimensions.aspectRatio.value = ASPECT_RATIO_MAP[id].ratio;
|
||||
const { width, height } = calculateNewSize(
|
||||
state.dimensions.aspectRatio.value,
|
||||
state.dimensions.width * state.dimensions.height,
|
||||
state.model?.base
|
||||
);
|
||||
state.dimensions.width = width;
|
||||
state.dimensions.height = height;
|
||||
if (fixedSize) {
|
||||
state.dimensions.aspectRatio.value = fixedSize.width / fixedSize.height;
|
||||
state.dimensions.width = fixedSize.width;
|
||||
state.dimensions.height = fixedSize.height;
|
||||
} else {
|
||||
state.dimensions.aspectRatio.value = ASPECT_RATIO_MAP[id].ratio;
|
||||
const { width, height } = calculateNewSize(
|
||||
state.dimensions.aspectRatio.value,
|
||||
state.dimensions.width * state.dimensions.height,
|
||||
state.model?.base as BaseModelType | undefined
|
||||
);
|
||||
state.dimensions.width = width;
|
||||
state.dimensions.height = height;
|
||||
}
|
||||
}
|
||||
},
|
||||
dimensionsSwapped: (state) => {
|
||||
@@ -434,7 +446,7 @@ const slice = createSlice({
|
||||
const { width, height } = calculateNewSize(
|
||||
state.dimensions.aspectRatio.value,
|
||||
state.dimensions.width * state.dimensions.height,
|
||||
state.model?.base
|
||||
state.model?.base as BaseModelType | undefined
|
||||
);
|
||||
state.dimensions.width = width;
|
||||
state.dimensions.height = height;
|
||||
@@ -442,12 +454,12 @@ const slice = createSlice({
|
||||
}
|
||||
},
|
||||
sizeOptimized: (state) => {
|
||||
const optimalDimension = getOptimalDimension(state.model?.base);
|
||||
const optimalDimension = getOptimalDimension(state.model?.base as BaseModelType | undefined);
|
||||
if (state.dimensions.aspectRatio.isLocked) {
|
||||
const { width, height } = calculateNewSize(
|
||||
state.dimensions.aspectRatio.value,
|
||||
optimalDimension * optimalDimension,
|
||||
state.model?.base
|
||||
state.model?.base as BaseModelType | undefined
|
||||
);
|
||||
state.dimensions.width = width;
|
||||
state.dimensions.height = height;
|
||||
@@ -458,18 +470,54 @@ const slice = createSlice({
|
||||
}
|
||||
},
|
||||
syncedToOptimalDimension: (state) => {
|
||||
const optimalDimension = getOptimalDimension(state.model?.base);
|
||||
const optimalDimension = getOptimalDimension(state.model?.base as BaseModelType | undefined);
|
||||
|
||||
if (!getIsSizeOptimal(state.dimensions.width, state.dimensions.height, state.model?.base)) {
|
||||
if (
|
||||
!getIsSizeOptimal(
|
||||
state.dimensions.width,
|
||||
state.dimensions.height,
|
||||
state.model?.base as BaseModelType | undefined
|
||||
)
|
||||
) {
|
||||
const bboxDims = calculateNewSize(
|
||||
state.dimensions.aspectRatio.value,
|
||||
optimalDimension * optimalDimension,
|
||||
state.model?.base
|
||||
state.model?.base as BaseModelType | undefined
|
||||
);
|
||||
state.dimensions.width = bboxDims.width;
|
||||
state.dimensions.height = bboxDims.height;
|
||||
}
|
||||
},
|
||||
imageSizeChanged: (state, action: PayloadAction<string | null>) => {
|
||||
state.imageSize = action.payload;
|
||||
},
|
||||
openaiQualityChanged: (state, action: PayloadAction<'auto' | 'high' | 'medium' | 'low'>) => {
|
||||
state.openaiQuality = action.payload;
|
||||
},
|
||||
openaiBackgroundChanged: (state, action: PayloadAction<'auto' | 'transparent' | 'opaque'>) => {
|
||||
state.openaiBackground = action.payload;
|
||||
},
|
||||
openaiInputFidelityChanged: (state, action: PayloadAction<'low' | 'high' | null>) => {
|
||||
state.openaiInputFidelity = action.payload;
|
||||
},
|
||||
geminiTemperatureChanged: (state, action: PayloadAction<number | null>) => {
|
||||
state.geminiTemperature = action.payload;
|
||||
},
|
||||
geminiThinkingLevelChanged: (state, action: PayloadAction<'minimal' | 'high' | null>) => {
|
||||
state.geminiThinkingLevel = action.payload;
|
||||
},
|
||||
resolutionPresetSelected: (
|
||||
state,
|
||||
action: PayloadAction<{ imageSize: string; aspectRatio: string; width: number; height: number }>
|
||||
) => {
|
||||
const { imageSize, aspectRatio, width, height } = action.payload;
|
||||
state.imageSize = imageSize;
|
||||
state.dimensions.width = width;
|
||||
state.dimensions.height = height;
|
||||
state.dimensions.aspectRatio.id = aspectRatio as AspectRatioID;
|
||||
state.dimensions.aspectRatio.value = width / height;
|
||||
state.dimensions.aspectRatio.isLocked = true;
|
||||
},
|
||||
paramsReset: (state) => resetState(state),
|
||||
paramsRecalled: (_state, action: PayloadAction<ParamsState>) => {
|
||||
return action.payload;
|
||||
@@ -502,6 +550,9 @@ const hasModelClipSkip = (model: ParameterModel | null) => {
|
||||
};
|
||||
|
||||
const getModelMaxClipSkip = (model: ParameterModel) => {
|
||||
if (model.base === 'external') {
|
||||
return undefined;
|
||||
}
|
||||
if (model.base === 'sdxl') {
|
||||
// We don't support user-defined CLIP skip for SDXL because it doesn't do anything useful
|
||||
return 0;
|
||||
@@ -611,7 +662,13 @@ export const {
|
||||
sizeOptimized,
|
||||
syncedToOptimalDimension,
|
||||
|
||||
resolutionPresetSelected,
|
||||
paramsReset,
|
||||
openaiQualityChanged,
|
||||
openaiBackgroundChanged,
|
||||
openaiInputFidelityChanged,
|
||||
geminiTemperatureChanged,
|
||||
geminiThinkingLevelChanged,
|
||||
paramsRecalled,
|
||||
animaVaeModelSelected,
|
||||
animaQwen3EncoderModelSelected,
|
||||
@@ -656,6 +713,7 @@ export const selectIsCogView4 = createParamsSelector((params) => params.model?.b
|
||||
export const selectIsZImage = createParamsSelector((params) => params.model?.base === 'z-image');
|
||||
export const selectIsAnima = createParamsSelector((params) => params.model?.base === 'anima');
|
||||
export const selectIsFlux2 = createParamsSelector((params) => params.model?.base === 'flux2');
|
||||
export const selectIsExternal = createParamsSelector((params) => params.model?.base === 'external');
|
||||
export const selectIsQwenImage = createParamsSelector((params) => params.model?.base === 'qwen-image');
|
||||
export const selectIsFluxKontext = createParamsSelector((params) => {
|
||||
if (params.model?.base === 'flux' && params.model?.name.toLowerCase().includes('kontext')) {
|
||||
@@ -708,19 +766,91 @@ export const selectOptimizedDenoisingEnabled = createParamsSelector((params) =>
|
||||
export const selectPositivePrompt = createParamsSelector((params) => params.positivePrompt);
|
||||
export const selectNegativePrompt = createParamsSelector((params) => params.negativePrompt);
|
||||
export const selectNegativePromptWithFallback = createParamsSelector((params) => params.negativePrompt ?? '');
|
||||
export const selectModelConfig = createSelector(
|
||||
selectModelConfigsQuery,
|
||||
selectParamsSlice,
|
||||
(modelConfigs, { model }) => {
|
||||
if (!modelConfigs.data) {
|
||||
return null;
|
||||
}
|
||||
if (!model) {
|
||||
return null;
|
||||
}
|
||||
return (
|
||||
(modelConfigsAdapterSelectors.selectById(modelConfigs.data, model.key) as
|
||||
| AnyModelConfigWithExternal
|
||||
| undefined) ?? null
|
||||
);
|
||||
}
|
||||
);
|
||||
export const selectHasNegativePrompt = createParamsSelector((params) => params.negativePrompt !== null);
|
||||
export const selectModelSupportsNegativePrompt = createSelector(
|
||||
selectModel,
|
||||
(model) => !!model && SUPPORTS_NEGATIVE_PROMPT_BASE_MODELS.includes(model.base)
|
||||
);
|
||||
export const selectModelSupportsRefImages = createSelector(
|
||||
selectModel,
|
||||
(model) => !!model && SUPPORTS_REF_IMAGES_BASE_MODELS.includes(model.base)
|
||||
);
|
||||
export const selectModelSupportsNegativePrompt = createSelector(selectModel, (model) => {
|
||||
if (!model) {
|
||||
return false;
|
||||
}
|
||||
if (model.base === 'external') {
|
||||
return false;
|
||||
}
|
||||
return SUPPORTS_NEGATIVE_PROMPT_BASE_MODELS.includes(model.base);
|
||||
});
|
||||
export const selectModelSupportsRefImages = createSelector(selectModel, selectModelConfig, (model, modelConfig) => {
|
||||
if (!model) {
|
||||
return false;
|
||||
}
|
||||
if (modelConfig && isExternalApiModelConfig(modelConfig)) {
|
||||
return hasExternalPanelControl(modelConfig, 'prompts', 'reference_images');
|
||||
}
|
||||
if (model.base === 'external') {
|
||||
return false;
|
||||
}
|
||||
return SUPPORTS_REF_IMAGES_BASE_MODELS.includes(model.base);
|
||||
});
|
||||
export const selectModelSupportsOptimizedDenoising = createSelector(
|
||||
selectModel,
|
||||
(model) => !!model && SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS.includes(model.base)
|
||||
(model) => !!model && model.base !== 'external' && SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS.includes(model.base)
|
||||
);
|
||||
export const selectModelSupportsGuidance = createSelector(selectModel, (model) => {
|
||||
if (!model) {
|
||||
return false;
|
||||
}
|
||||
if (model.base === 'external') {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
export const selectModelSupportsSeed = createSelector(selectModel, selectModelConfig, (model, modelConfig) => {
|
||||
if (!model) {
|
||||
return false;
|
||||
}
|
||||
if (modelConfig && isExternalApiModelConfig(modelConfig)) {
|
||||
return hasExternalPanelControl(modelConfig, 'image', 'seed');
|
||||
}
|
||||
return true;
|
||||
});
|
||||
export const selectModelSupportsSteps = createSelector(selectModel, (model) => {
|
||||
if (!model) {
|
||||
return false;
|
||||
}
|
||||
if (model.base === 'external') {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
export const selectModelSupportsDimensions = createSelector(selectModel, selectModelConfig, (model, modelConfig) => {
|
||||
if (!model) {
|
||||
return false;
|
||||
}
|
||||
if (modelConfig && isExternalApiModelConfig(modelConfig)) {
|
||||
return hasExternalPanelControl(modelConfig, 'image', 'dimensions');
|
||||
}
|
||||
return true;
|
||||
});
|
||||
export const selectSeedControl = createSelector(selectModelConfig, (modelConfig) => {
|
||||
if (modelConfig && isExternalApiModelConfig(modelConfig)) {
|
||||
return getExternalPanelControl(modelConfig, 'image', 'seed');
|
||||
}
|
||||
return null;
|
||||
});
|
||||
export const selectScheduler = createParamsSelector((params) => params.scheduler);
|
||||
export const selectFluxScheduler = createParamsSelector((params) => params.fluxScheduler);
|
||||
export const selectFluxDypePreset = createParamsSelector((params) => params.fluxDypePreset);
|
||||
@@ -764,24 +894,52 @@ export const selectHeight = createParamsSelector((params) => params.dimensions.h
|
||||
export const selectAspectRatioID = createParamsSelector((params) => params.dimensions.aspectRatio.id);
|
||||
export const selectAspectRatioValue = createParamsSelector((params) => params.dimensions.aspectRatio.value);
|
||||
export const selectAspectRatioIsLocked = createParamsSelector((params) => params.dimensions.aspectRatio.isLocked);
|
||||
export const selectAllowedAspectRatioIDs = createSelector(selectModelConfig, (modelConfig) => {
|
||||
if (!modelConfig || !isExternalApiModelConfig(modelConfig)) {
|
||||
return null;
|
||||
}
|
||||
const allowed = modelConfig.capabilities.allowed_aspect_ratios;
|
||||
return allowed?.length ? allowed : null;
|
||||
});
|
||||
export const selectAspectRatioSizes = createSelector(selectModelConfig, (modelConfig) => {
|
||||
if (!modelConfig || !isExternalApiModelConfig(modelConfig)) {
|
||||
return null;
|
||||
}
|
||||
return modelConfig.capabilities.aspect_ratio_sizes ?? null;
|
||||
});
|
||||
export const selectResolutionPresets = createSelector(selectModelConfig, (modelConfig) => {
|
||||
if (!modelConfig || !isExternalApiModelConfig(modelConfig)) {
|
||||
return null;
|
||||
}
|
||||
return modelConfig.capabilities.resolution_presets ?? null;
|
||||
});
|
||||
export const selectHasFixedDimensionSizes = createSelector(
|
||||
selectAspectRatioSizes,
|
||||
selectResolutionPresets,
|
||||
(sizes, presets) => sizes !== null || (presets !== null && presets.length > 0)
|
||||
);
|
||||
export const selectImageSize = createParamsSelector((params) => params.imageSize);
|
||||
export const selectOpenaiQuality = createParamsSelector((params) => params.openaiQuality);
|
||||
export const selectOpenaiBackground = createParamsSelector((params) => params.openaiBackground);
|
||||
export const selectOpenaiInputFidelity = createParamsSelector((params) => params.openaiInputFidelity);
|
||||
export const selectGeminiTemperature = createParamsSelector((params) => params.geminiTemperature);
|
||||
export const selectGeminiThinkingLevel = createParamsSelector((params) => params.geminiThinkingLevel);
|
||||
export const selectExternalProviderId = createSelector(selectModelConfig, (modelConfig) => {
|
||||
if (modelConfig && isExternalApiModelConfig(modelConfig)) {
|
||||
return modelConfig.provider_id;
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
export const selectMainModelConfig = createSelector(
|
||||
selectModelConfigsQuery,
|
||||
selectParamsSlice,
|
||||
(modelConfigs, { model }) => {
|
||||
if (!modelConfigs.data) {
|
||||
return null;
|
||||
}
|
||||
if (!model) {
|
||||
return null;
|
||||
}
|
||||
const modelConfig = modelConfigsAdapterSelectors.selectById(modelConfigs.data, model.key);
|
||||
if (!modelConfig) {
|
||||
return null;
|
||||
}
|
||||
if (!isNonRefinerMainModelConfig(modelConfig)) {
|
||||
return null;
|
||||
}
|
||||
export const selectMainModelConfig = createSelector(selectModelConfig, (modelConfig) => {
|
||||
if (!modelConfig) {
|
||||
return null;
|
||||
}
|
||||
if (isExternalApiModelConfig(modelConfig)) {
|
||||
return modelConfig;
|
||||
}
|
||||
);
|
||||
if (!isNonRefinerMainModelConfig(modelConfig)) {
|
||||
return null;
|
||||
}
|
||||
return modelConfig;
|
||||
});
|
||||
|
||||
@@ -13,6 +13,7 @@ import type {
|
||||
CanvasRegionalGuidanceState,
|
||||
CanvasState,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import type { BaseModelType } from 'features/nodes/types/common';
|
||||
import { getGridSize, getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
@@ -74,7 +75,7 @@ export const selectHasEntities = createSelector(selectEntityCountAll, (count) =>
|
||||
* Selects the optimal dimension for the canvas based on the currently-selected model
|
||||
*/
|
||||
export const selectOptimalDimension = createSelector(selectParamsSlice, (params): number => {
|
||||
const modelBase = params.model?.base;
|
||||
const modelBase = params.model?.base as BaseModelType | undefined;
|
||||
return getOptimalDimension(modelBase ?? null);
|
||||
});
|
||||
|
||||
@@ -82,7 +83,7 @@ export const selectOptimalDimension = createSelector(selectParamsSlice, (params)
|
||||
* Selects the grid size for the canvas based on the currently-selected model
|
||||
*/
|
||||
export const selectGridSize = createSelector(selectParamsSlice, (params): number => {
|
||||
const modelBase = params.model?.base;
|
||||
const modelBase = params.model?.base as BaseModelType | undefined;
|
||||
return getGridSize(modelBase ?? null);
|
||||
});
|
||||
|
||||
|
||||
@@ -663,19 +663,42 @@ export const zLoRA = z.object({
|
||||
});
|
||||
export type LoRA = z.infer<typeof zLoRA>;
|
||||
|
||||
export const zAspectRatioID = z.enum(['Free', '21:9', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16', '9:21']);
|
||||
export const zAspectRatioID = z.enum([
|
||||
'Free',
|
||||
'8:1',
|
||||
'4:1',
|
||||
'21:9',
|
||||
'16:9',
|
||||
'3:2',
|
||||
'5:4',
|
||||
'4:3',
|
||||
'1:1',
|
||||
'3:4',
|
||||
'4:5',
|
||||
'2:3',
|
||||
'9:16',
|
||||
'1:4',
|
||||
'9:21',
|
||||
'1:8',
|
||||
]);
|
||||
export type AspectRatioID = z.infer<typeof zAspectRatioID>;
|
||||
export const isAspectRatioID = (v: unknown): v is AspectRatioID => zAspectRatioID.safeParse(v).success;
|
||||
export const ASPECT_RATIO_MAP: Record<Exclude<AspectRatioID, 'Free'>, { ratio: number; inverseID: AspectRatioID }> = {
|
||||
'8:1': { ratio: 8 / 1, inverseID: '1:8' },
|
||||
'4:1': { ratio: 4 / 1, inverseID: '1:4' },
|
||||
'21:9': { ratio: 21 / 9, inverseID: '9:21' },
|
||||
'16:9': { ratio: 16 / 9, inverseID: '9:16' },
|
||||
'3:2': { ratio: 3 / 2, inverseID: '2:3' },
|
||||
'5:4': { ratio: 5 / 4, inverseID: '4:5' },
|
||||
'4:3': { ratio: 4 / 3, inverseID: '4:3' },
|
||||
'1:1': { ratio: 1, inverseID: '1:1' },
|
||||
'3:4': { ratio: 3 / 4, inverseID: '4:3' },
|
||||
'4:5': { ratio: 4 / 5, inverseID: '5:4' },
|
||||
'2:3': { ratio: 2 / 3, inverseID: '3:2' },
|
||||
'9:16': { ratio: 9 / 16, inverseID: '16:9' },
|
||||
'1:4': { ratio: 1 / 4, inverseID: '4:1' },
|
||||
'9:21': { ratio: 9 / 21, inverseID: '21:9' },
|
||||
'1:8': { ratio: 1 / 8, inverseID: '8:1' },
|
||||
};
|
||||
|
||||
const zAspectRatioConfig = z.object({
|
||||
@@ -794,6 +817,14 @@ export const zParamsState = z.object({
|
||||
zImageSeedVarianceEnabled: z.boolean(),
|
||||
zImageSeedVarianceStrength: z.number().min(0).max(2),
|
||||
zImageSeedVarianceRandomizePercent: z.number().min(1).max(100),
|
||||
imageSize: z.string().nullable().default(null),
|
||||
// OpenAI-specific external options
|
||||
openaiQuality: z.enum(['auto', 'high', 'medium', 'low']).default('auto'),
|
||||
openaiBackground: z.enum(['auto', 'transparent', 'opaque']).default('auto'),
|
||||
openaiInputFidelity: z.enum(['low', 'high']).nullable().default(null),
|
||||
// Gemini-specific external options
|
||||
geminiTemperature: z.number().min(0).max(2).nullable().default(null),
|
||||
geminiThinkingLevel: z.enum(['minimal', 'high']).nullable().default(null),
|
||||
dimensions: zDimensionsState,
|
||||
});
|
||||
export type ParamsState = z.infer<typeof zParamsState>;
|
||||
@@ -865,6 +896,12 @@ export const getInitialParamsState = (): ParamsState => ({
|
||||
zImageSeedVarianceEnabled: false,
|
||||
zImageSeedVarianceStrength: 0.1,
|
||||
zImageSeedVarianceRandomizePercent: 50,
|
||||
imageSize: null,
|
||||
openaiQuality: 'auto',
|
||||
openaiBackground: 'auto',
|
||||
openaiInputFidelity: null,
|
||||
geminiTemperature: null,
|
||||
geminiThinkingLevel: null,
|
||||
dimensions: {
|
||||
width: 512,
|
||||
height: 512,
|
||||
|
||||
@@ -6,7 +6,11 @@ import type {
|
||||
RefImageState,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||
import type { AnyModelConfig, MainModelConfig } from 'services/api/types';
|
||||
import {
|
||||
type AnyModelConfigWithExternal,
|
||||
isExternalApiModelConfig,
|
||||
type MainOrExternalModelConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
const WARNINGS = {
|
||||
UNSUPPORTED_MODEL: 'controlLayers.warnings.unsupportedModel',
|
||||
@@ -28,7 +32,7 @@ type WarningTKey = (typeof WARNINGS)[keyof typeof WARNINGS];
|
||||
|
||||
export const getRegionalGuidanceWarnings = (
|
||||
entity: CanvasRegionalGuidanceState,
|
||||
model: MainModelConfig | null | undefined
|
||||
model: MainOrExternalModelConfig | null | undefined
|
||||
): WarningTKey[] => {
|
||||
const warnings: WarningTKey[] = [];
|
||||
|
||||
@@ -100,8 +104,8 @@ export const getRegionalGuidanceWarnings = (
|
||||
};
|
||||
|
||||
export const areBasesCompatibleForRefImage = (
|
||||
first?: ModelIdentifierField | AnyModelConfig | null,
|
||||
second?: ModelIdentifierField | AnyModelConfig | null
|
||||
first?: ModelIdentifierField | AnyModelConfigWithExternal | null,
|
||||
second?: ModelIdentifierField | AnyModelConfigWithExternal | null
|
||||
): boolean => {
|
||||
if (!first || !second) {
|
||||
return false;
|
||||
@@ -122,11 +126,19 @@ export const areBasesCompatibleForRefImage = (
|
||||
|
||||
export const getGlobalReferenceImageWarnings = (
|
||||
entity: RefImageState,
|
||||
model: MainModelConfig | null | undefined
|
||||
model: MainOrExternalModelConfig | null | undefined
|
||||
): WarningTKey[] => {
|
||||
const warnings: WarningTKey[] = [];
|
||||
|
||||
if (model) {
|
||||
if (isExternalApiModelConfig(model)) {
|
||||
if (!entity.config.image) {
|
||||
// No image selected
|
||||
warnings.push(WARNINGS.IP_ADAPTER_NO_IMAGE_SELECTED);
|
||||
}
|
||||
return warnings;
|
||||
}
|
||||
|
||||
if (model.base === 'sd-3' || model.base === 'sd-2' || model.base === 'anima') {
|
||||
// Unsupported model architecture
|
||||
warnings.push(WARNINGS.UNSUPPORTED_MODEL);
|
||||
@@ -159,7 +171,7 @@ export const getGlobalReferenceImageWarnings = (
|
||||
|
||||
export const getControlLayerWarnings = (
|
||||
entity: CanvasControlLayerState,
|
||||
model: MainModelConfig | null | undefined
|
||||
model: MainOrExternalModelConfig | null | undefined
|
||||
): WarningTKey[] => {
|
||||
const warnings: WarningTKey[] = [];
|
||||
|
||||
@@ -193,7 +205,7 @@ export const getControlLayerWarnings = (
|
||||
|
||||
export const getRasterLayerWarnings = (
|
||||
_entity: CanvasRasterLayerState,
|
||||
_model: MainModelConfig | null | undefined
|
||||
_model: MainOrExternalModelConfig | null | undefined
|
||||
): WarningTKey[] => {
|
||||
const warnings: WarningTKey[] = [];
|
||||
|
||||
@@ -204,7 +216,7 @@ export const getRasterLayerWarnings = (
|
||||
|
||||
export const getInpaintMaskWarnings = (
|
||||
_entity: CanvasInpaintMaskState,
|
||||
_model: MainModelConfig | null | undefined
|
||||
_model: MainOrExternalModelConfig | null | undefined
|
||||
): WarningTKey[] => {
|
||||
const warnings: WarningTKey[] = [];
|
||||
|
||||
|
||||
@@ -1186,7 +1186,8 @@ const LoRAs: CollectionMetadataHandler<LoRA[]> = {
|
||||
const key = getProperty(rawItem, 'lora.key');
|
||||
assert(isString(key));
|
||||
// No need to catch here - if this throws, we move on to the next item
|
||||
identifier = await getModelIdentiferFromKey(key, store);
|
||||
const modelConfig = await getModelIdentiferFromKey(key, store);
|
||||
identifier = zModelIdentifierField.parse(modelConfig);
|
||||
}
|
||||
|
||||
assert(identifier.type === 'lora');
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import type { AnyModelVariant, BaseModelType, ModelFormat, ModelType } from 'features/nodes/types/common';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import {
|
||||
type AnyModelConfig,
|
||||
isCLIPEmbedModelConfig,
|
||||
isCLIPVisionModelConfig,
|
||||
isControlLoRAModelConfig,
|
||||
isControlNetModelConfig,
|
||||
isExternalApiModelConfig,
|
||||
isFluxReduxModelConfig,
|
||||
isIPAdapterModelConfig,
|
||||
isLLaVAModelConfig,
|
||||
@@ -121,6 +122,11 @@ export const MODEL_CATEGORIES: Record<ModelCategoryType, ModelCategoryData> = {
|
||||
i18nKey: 'modelManager.llavaOnevision',
|
||||
filter: isLLaVAModelConfig,
|
||||
},
|
||||
external_image_generator: {
|
||||
category: 'external_image_generator',
|
||||
i18nKey: 'modelManager.externalImageGenerator',
|
||||
filter: isExternalApiModelConfig,
|
||||
},
|
||||
};
|
||||
|
||||
export const MODEL_CATEGORIES_AS_LIST = objectEntries(MODEL_CATEGORIES).map(([category, { i18nKey, filter }]) => ({
|
||||
@@ -144,6 +150,7 @@ export const MODEL_BASE_TO_COLOR: Record<BaseModelType, string> = {
|
||||
cogview4: 'red',
|
||||
'qwen-image': 'orange',
|
||||
'z-image': 'cyan',
|
||||
external: 'orange',
|
||||
anima: 'invokePurple',
|
||||
unknown: 'red',
|
||||
};
|
||||
@@ -169,6 +176,7 @@ export const MODEL_TYPE_TO_LONG_NAME: Record<ModelType, string> = {
|
||||
clip_embed: 'CLIP Embed',
|
||||
siglip: 'SigLIP',
|
||||
flux_redux: 'FLUX Redux',
|
||||
external_image_generator: 'External Image Generator',
|
||||
unknown: 'Unknown',
|
||||
};
|
||||
|
||||
@@ -187,6 +195,7 @@ export const MODEL_BASE_TO_LONG_NAME: Record<BaseModelType, string> = {
|
||||
cogview4: 'CogView4',
|
||||
'qwen-image': 'Qwen Image',
|
||||
'z-image': 'Z-Image',
|
||||
external: 'External',
|
||||
anima: 'Anima',
|
||||
unknown: 'Unknown',
|
||||
};
|
||||
@@ -206,6 +215,7 @@ export const MODEL_BASE_TO_SHORT_NAME: Record<BaseModelType, string> = {
|
||||
cogview4: 'CogView4',
|
||||
'qwen-image': 'QwenImg',
|
||||
'z-image': 'Z-Image',
|
||||
external: 'External',
|
||||
anima: 'Anima',
|
||||
unknown: 'Unknown',
|
||||
};
|
||||
@@ -237,6 +247,7 @@ export const MODEL_FORMAT_TO_LONG_NAME: Record<ModelFormat, string> = {
|
||||
checkpoint: 'Checkpoint',
|
||||
lycoris: 'LyCORIS',
|
||||
onnx: 'ONNX',
|
||||
external_api: 'External API',
|
||||
olive: 'Olive',
|
||||
embedding_file: 'Embedding (file)',
|
||||
embedding_folder: 'Embedding (folder)',
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import { atom } from 'nanostores';
|
||||
|
||||
type InstallModelsTabName = 'launchpad' | 'urlOrLocal' | 'huggingface' | 'scanFolder' | 'starterModels';
|
||||
type InstallModelsTabName = 'launchpad' | 'urlOrLocal' | 'huggingface' | 'external' | 'scanFolder' | 'starterModels';
|
||||
|
||||
const TAB_TO_INDEX_MAP: Record<InstallModelsTabName, number> = {
|
||||
launchpad: 0,
|
||||
urlOrLocal: 1,
|
||||
huggingface: 2,
|
||||
scanFolder: 3,
|
||||
starterModels: 4,
|
||||
external: 3,
|
||||
scanFolder: 4,
|
||||
starterModels: 5,
|
||||
};
|
||||
|
||||
export const setInstallModelsTabByName = (tab: InstallModelsTabName) => {
|
||||
|
||||
@@ -7,7 +7,10 @@ import { zModelType } from 'features/nodes/types/common';
|
||||
import { assert } from 'tsafe';
|
||||
import z from 'zod';
|
||||
|
||||
const zModelCategoryType = zModelType.exclude(['onnx']).or(z.literal('refiner'));
|
||||
const zModelCategoryType = zModelType
|
||||
.exclude(['onnx'])
|
||||
.or(z.literal('refiner'))
|
||||
.or(z.literal('external_image_generator'));
|
||||
export type ModelCategoryType = z.infer<typeof zModelCategoryType>;
|
||||
|
||||
const zFilterableModelType = zModelCategoryType.or(z.literal('missing'));
|
||||
|
||||
@@ -0,0 +1,267 @@
|
||||
import {
|
||||
Badge,
|
||||
Button,
|
||||
Card,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormHelperText,
|
||||
FormLabel,
|
||||
Heading,
|
||||
Input,
|
||||
Text,
|
||||
Tooltip,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { useBuildModelInstallArg } from 'features/modelManagerV2/hooks/useBuildModelsToInstall';
|
||||
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
||||
import { $installModelsTabIndex } from 'features/modelManagerV2/store/installModelsStore';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCheckBold, PiWarningBold } from 'react-icons/pi';
|
||||
import {
|
||||
useGetExternalProviderConfigsQuery,
|
||||
useResetExternalProviderConfigMutation,
|
||||
useSetExternalProviderConfigMutation,
|
||||
} from 'services/api/endpoints/appInfo';
|
||||
import { useGetStarterModelsQuery } from 'services/api/endpoints/models';
|
||||
import type { ExternalProviderConfig, StarterModel } from 'services/api/types';
|
||||
|
||||
const PROVIDER_SORT_ORDER = ['gemini', 'openai', 'alibabacloud'];
|
||||
|
||||
type ProviderCardProps = {
|
||||
provider: ExternalProviderConfig;
|
||||
onInstallModels: (providerId: string) => void;
|
||||
};
|
||||
|
||||
type UpdatePayload = {
|
||||
provider_id: string;
|
||||
api_key?: string;
|
||||
base_url?: string;
|
||||
};
|
||||
|
||||
export const ExternalProvidersForm = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const { data, isLoading } = useGetExternalProviderConfigsQuery();
|
||||
const { data: starterModels } = useGetStarterModelsQuery();
|
||||
const [installModel] = useInstallModel();
|
||||
const { getIsInstalled, buildModelInstallArg } = useBuildModelInstallArg();
|
||||
const tabIndex = useStore($installModelsTabIndex);
|
||||
|
||||
const externalModelsByProvider = useMemo(() => {
|
||||
const groups = new Map<string, StarterModel[]>();
|
||||
for (const model of starterModels?.starter_models ?? []) {
|
||||
if (!model.source.startsWith('external://')) {
|
||||
continue;
|
||||
}
|
||||
const providerId = model.source.replace('external://', '').split('/')[0];
|
||||
if (!providerId) {
|
||||
continue;
|
||||
}
|
||||
const models = groups.get(providerId) ?? [];
|
||||
models.push(model);
|
||||
groups.set(providerId, models);
|
||||
}
|
||||
|
||||
for (const [providerId, models] of groups.entries()) {
|
||||
models.sort((a, b) => a.name.localeCompare(b.name));
|
||||
groups.set(providerId, models);
|
||||
}
|
||||
|
||||
return groups;
|
||||
}, [starterModels]);
|
||||
|
||||
const handleInstallProviderModels = useCallback(
|
||||
(providerId: string) => {
|
||||
const models = externalModelsByProvider.get(providerId);
|
||||
if (!models?.length) {
|
||||
return;
|
||||
}
|
||||
const modelsToInstall = models.filter((model) => !getIsInstalled(model)).map(buildModelInstallArg);
|
||||
modelsToInstall.forEach((model) => installModel(model));
|
||||
},
|
||||
[buildModelInstallArg, externalModelsByProvider, getIsInstalled, installModel]
|
||||
);
|
||||
|
||||
const sortedProviders = useMemo(() => {
|
||||
if (!data) {
|
||||
return [];
|
||||
}
|
||||
return [...data].sort((a, b) => {
|
||||
const aIndex = PROVIDER_SORT_ORDER.indexOf(a.provider_id);
|
||||
const bIndex = PROVIDER_SORT_ORDER.indexOf(b.provider_id);
|
||||
if (aIndex === -1 && bIndex === -1) {
|
||||
return a.provider_id.localeCompare(b.provider_id);
|
||||
}
|
||||
if (aIndex === -1) {
|
||||
return 1;
|
||||
}
|
||||
if (bIndex === -1) {
|
||||
return -1;
|
||||
}
|
||||
return aIndex - bIndex;
|
||||
});
|
||||
}, [data]);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" height="100%" gap={4}>
|
||||
<Flex flexDir="column" gap={1}>
|
||||
<Heading size="sm">{t('modelManager.externalSetupTitle')}</Heading>
|
||||
<Text color="base.300">{t('modelManager.externalSetupDescription')}</Text>
|
||||
</Flex>
|
||||
<ScrollableContent>
|
||||
<Flex flexDir="column" gap={4}>
|
||||
{isLoading && <Text color="base.300">{t('common.loading')}</Text>}
|
||||
{!isLoading && sortedProviders.length === 0 && (
|
||||
<Text color="base.300">{t('modelManager.externalProvidersUnavailable')}</Text>
|
||||
)}
|
||||
{sortedProviders.map((provider) => (
|
||||
<ProviderCard
|
||||
key={provider.provider_id}
|
||||
provider={provider}
|
||||
onInstallModels={handleInstallProviderModels}
|
||||
/>
|
||||
))}
|
||||
</Flex>
|
||||
</ScrollableContent>
|
||||
{tabIndex === 3 && (
|
||||
<Text variant="subtext" color="base.400">
|
||||
{t('modelManager.externalSetupFooter')}
|
||||
</Text>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ExternalProvidersForm.displayName = 'ExternalProvidersForm';
|
||||
|
||||
const ProviderCard = memo(({ provider, onInstallModels }: ProviderCardProps) => {
|
||||
const { t } = useTranslation();
|
||||
const [apiKey, setApiKey] = useState('');
|
||||
const [baseUrl, setBaseUrl] = useState(provider.base_url ?? '');
|
||||
const [saveConfig, { isLoading }] = useSetExternalProviderConfigMutation();
|
||||
const [resetConfig, { isLoading: isResetting }] = useResetExternalProviderConfigMutation();
|
||||
|
||||
useEffect(() => {
|
||||
setBaseUrl(provider.base_url ?? '');
|
||||
}, [provider.base_url]);
|
||||
|
||||
const handleSave = useCallback(() => {
|
||||
const trimmedApiKey = apiKey.trim();
|
||||
const trimmedBaseUrl = baseUrl.trim();
|
||||
const updatePayload: UpdatePayload = {
|
||||
provider_id: provider.provider_id,
|
||||
};
|
||||
if (trimmedApiKey) {
|
||||
updatePayload.api_key = trimmedApiKey;
|
||||
}
|
||||
if (trimmedBaseUrl !== (provider.base_url ?? '')) {
|
||||
updatePayload.base_url = trimmedBaseUrl;
|
||||
}
|
||||
|
||||
if (!updatePayload.api_key && updatePayload.base_url === undefined) {
|
||||
return;
|
||||
}
|
||||
|
||||
saveConfig(updatePayload)
|
||||
.unwrap()
|
||||
.then((result) => {
|
||||
if (result.api_key_configured) {
|
||||
setApiKey('');
|
||||
onInstallModels(provider.provider_id);
|
||||
}
|
||||
if (result.base_url !== undefined) {
|
||||
setBaseUrl(result.base_url ?? '');
|
||||
}
|
||||
});
|
||||
}, [apiKey, baseUrl, onInstallModels, provider.base_url, provider.provider_id, saveConfig]);
|
||||
|
||||
const handleReset = useCallback(() => {
|
||||
resetConfig(provider.provider_id)
|
||||
.unwrap()
|
||||
.then((result) => {
|
||||
setApiKey('');
|
||||
setBaseUrl(result.base_url ?? '');
|
||||
});
|
||||
}, [provider.provider_id, resetConfig]);
|
||||
|
||||
const handleApiKeyChange = useCallback((event: ChangeEvent<HTMLInputElement>) => {
|
||||
setApiKey(event.target.value);
|
||||
}, []);
|
||||
|
||||
const handleBaseUrlChange = useCallback((event: ChangeEvent<HTMLInputElement>) => {
|
||||
setBaseUrl(event.target.value);
|
||||
}, []);
|
||||
|
||||
const statusBadge = provider.api_key_configured ? (
|
||||
<Badge colorScheme="green" display="flex" alignItems="center" gap={2}>
|
||||
<PiCheckBold />
|
||||
{t('settings.externalProviderConfigured')}
|
||||
</Badge>
|
||||
) : (
|
||||
<Badge colorScheme="warning" display="flex" alignItems="center" gap={2}>
|
||||
<PiWarningBold />
|
||||
{t('settings.externalProviderNotConfigured')}
|
||||
</Badge>
|
||||
);
|
||||
|
||||
return (
|
||||
<Card p={4} gap={4} variant="outline">
|
||||
<Flex justifyContent="space-between" alignItems="center" flexWrap="wrap" gap={3}>
|
||||
<Flex flexDir="column" gap={1}>
|
||||
<Heading size="xs" textTransform="capitalize">
|
||||
{provider.provider_id}
|
||||
</Heading>
|
||||
<Text variant="subtext">
|
||||
{t('modelManager.externalProviderCardDescription', { providerId: provider.provider_id })}
|
||||
</Text>
|
||||
</Flex>
|
||||
{statusBadge}
|
||||
</Flex>
|
||||
<Flex flexDir="column" gap={4}>
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.externalApiKey')}</FormLabel>
|
||||
<Input
|
||||
type="password"
|
||||
autoComplete="off"
|
||||
placeholder={
|
||||
provider.api_key_configured
|
||||
? t('modelManager.externalApiKeyPlaceholderSet')
|
||||
: t('modelManager.externalApiKeyPlaceholder')
|
||||
}
|
||||
value={apiKey}
|
||||
onChange={handleApiKeyChange}
|
||||
/>
|
||||
<FormHelperText>{t('modelManager.externalApiKeyHelper')}</FormHelperText>
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.externalBaseUrl')}</FormLabel>
|
||||
<Input
|
||||
placeholder={t('modelManager.externalBaseUrlPlaceholder')}
|
||||
value={baseUrl}
|
||||
onChange={handleBaseUrlChange}
|
||||
/>
|
||||
<FormHelperText>{t('modelManager.externalBaseUrlHelper')}</FormHelperText>
|
||||
</FormControl>
|
||||
<Flex gap={2} justifyContent="flex-end" flexWrap="wrap">
|
||||
<Tooltip label={t('modelManager.externalResetHelper')}>
|
||||
<Button variant="ghost" onClick={handleReset} isLoading={isResetting}>
|
||||
{t('common.reset')}
|
||||
</Button>
|
||||
</Tooltip>
|
||||
<Button
|
||||
colorScheme="invokeYellow"
|
||||
onClick={handleSave}
|
||||
isLoading={isLoading}
|
||||
isDisabled={!apiKey.trim() && baseUrl.trim() === (provider.base_url ?? '')}
|
||||
>
|
||||
{t('common.save')}
|
||||
</Button>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Card>
|
||||
);
|
||||
});
|
||||
|
||||
ProviderCard.displayName = 'ProviderCard';
|
||||
@@ -6,7 +6,7 @@ import { StarterBundleButton } from 'features/modelManagerV2/subpanels/AddModelP
|
||||
import { StarterBundleTooltipContentCompact } from 'features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterBundleTooltipContentCompact';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiFolderOpenBold, PiLinkBold, PiStarBold } from 'react-icons/pi';
|
||||
import { PiFolderOpenBold, PiLinkBold, PiPlugBold, PiStarBold } from 'react-icons/pi';
|
||||
import { SiHuggingface } from 'react-icons/si';
|
||||
import { useGetStarterModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
@@ -28,6 +28,10 @@ export const LaunchpadForm = memo(() => {
|
||||
setInstallModelsTabByName('scanFolder');
|
||||
}, []);
|
||||
|
||||
const navigateToExternalTab = useCallback(() => {
|
||||
setInstallModelsTabByName('external');
|
||||
}, []);
|
||||
|
||||
const navigateToStarterModelsTab = useCallback(() => {
|
||||
setInstallModelsTabByName('starterModels');
|
||||
}, []);
|
||||
@@ -63,6 +67,12 @@ export const LaunchpadForm = memo(() => {
|
||||
title={t('modelManager.scanFolder')}
|
||||
description={t('modelManager.launchpad.scanFolderDescription')}
|
||||
/>
|
||||
<LaunchpadButton
|
||||
onClick={navigateToExternalTab}
|
||||
icon={PiPlugBold}
|
||||
title={t('modelManager.externalProviders')}
|
||||
description={t('modelManager.launchpad.externalDescription')}
|
||||
/>
|
||||
</Grid>
|
||||
</Flex>
|
||||
{/* Recommended Section */}
|
||||
|
||||
@@ -285,6 +285,8 @@ export const ModelInstallQueueItem = memo((props: ModelListItemProps) => {
|
||||
return installJob.source.url;
|
||||
case 'local':
|
||||
return installJob.source.path;
|
||||
case 'external':
|
||||
return `external://${installJob.source.provider_id}/${installJob.source.provider_model_id}`;
|
||||
default:
|
||||
return t('common.unknown');
|
||||
}
|
||||
@@ -292,6 +294,8 @@ export const ModelInstallQueueItem = memo((props: ModelListItemProps) => {
|
||||
|
||||
const displayStatus = optimisticStatus?.status ?? installJob.status;
|
||||
|
||||
const configuredName = installJob.config_in?.name;
|
||||
|
||||
const modelName = useMemo(() => {
|
||||
switch (installJob.source.type) {
|
||||
case 'hf': {
|
||||
@@ -302,13 +306,15 @@ export const ModelInstallQueueItem = memo((props: ModelListItemProps) => {
|
||||
return repo_id;
|
||||
}
|
||||
case 'url':
|
||||
return installJob.source.url.split('/').slice(-1)[0] ?? t('common.unknown');
|
||||
return configuredName ?? installJob.source.url.split('/').slice(-1)[0] ?? t('common.unknown');
|
||||
case 'local':
|
||||
return installJob.source.path.split(/[/\\]/).slice(-1)[0] ?? t('common.unknown');
|
||||
return configuredName ?? installJob.source.path.split(/[/\\]/).slice(-1)[0] ?? t('common.unknown');
|
||||
case 'external':
|
||||
return configuredName ?? `${installJob.source.provider_id}/${installJob.source.provider_model_id}`;
|
||||
default:
|
||||
return t('common.unknown');
|
||||
return configuredName ?? t('common.unknown');
|
||||
}
|
||||
}, [installJob.source, t]);
|
||||
}, [configuredName, installJob.source, t]);
|
||||
|
||||
const progressValue = useMemo(() => {
|
||||
if (displayStatus === 'completed' || displayStatus === 'error' || displayStatus === 'cancelled') {
|
||||
|
||||
@@ -21,6 +21,9 @@ export const StarterModelsResults = memo(({ results }: StarterModelsResultsProps
|
||||
|
||||
const filteredResults = useMemo(() => {
|
||||
return results.starter_models.filter((result) => {
|
||||
if (result.source.startsWith('external://')) {
|
||||
return false;
|
||||
}
|
||||
const trimmedSearchTerm = searchTerm.trim().toLowerCase();
|
||||
const matchStrings = [
|
||||
result.name.toLowerCase(),
|
||||
|
||||
@@ -2,18 +2,18 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box, Divider, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $installModelsTabIndex } from 'features/modelManagerV2/store/installModelsStore';
|
||||
import { ExternalProvidersForm } from 'features/modelManagerV2/subpanels/AddModelPanel/ExternalProviders/ExternalProvidersForm';
|
||||
import { HuggingFaceForm } from 'features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceForm';
|
||||
import { InstallModelForm } from 'features/modelManagerV2/subpanels/AddModelPanel/InstallModelForm';
|
||||
import { LaunchpadForm } from 'features/modelManagerV2/subpanels/AddModelPanel/LaunchpadForm/LaunchpadForm';
|
||||
import { ModelInstallQueue } from 'features/modelManagerV2/subpanels/AddModelPanel/ModelInstallQueue/ModelInstallQueue';
|
||||
import { ScanModelsForm } from 'features/modelManagerV2/subpanels/AddModelPanel/ScanFolder/ScanFolderForm';
|
||||
import { StarterModelsForm } from 'features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterModelsForm';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCubeBold, PiFolderOpenBold, PiLinkSimpleBold, PiShootingStarBold } from 'react-icons/pi';
|
||||
import { PiCubeBold, PiFolderOpenBold, PiLinkSimpleBold, PiPlugBold, PiShootingStarBold } from 'react-icons/pi';
|
||||
import { SiHuggingface } from 'react-icons/si';
|
||||
|
||||
import { HuggingFaceForm } from './AddModelPanel/HuggingFaceFolder/HuggingFaceForm';
|
||||
import { InstallModelForm } from './AddModelPanel/InstallModelForm';
|
||||
import { LaunchpadForm } from './AddModelPanel/LaunchpadForm/LaunchpadForm';
|
||||
import { ModelInstallQueue } from './AddModelPanel/ModelInstallQueue/ModelInstallQueue';
|
||||
import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';
|
||||
|
||||
const installModelsTabSx: SystemStyleObject = {
|
||||
display: 'flex',
|
||||
gap: 2,
|
||||
@@ -61,6 +61,10 @@ export const InstallModels = memo(() => {
|
||||
<SiHuggingface />
|
||||
{t('modelManager.huggingFace')}
|
||||
</Tab>
|
||||
<Tab sx={installModelsTabSx}>
|
||||
<PiPlugBold />
|
||||
{t('modelManager.externalProviders')}
|
||||
</Tab>
|
||||
<Tab sx={installModelsTabSx}>
|
||||
<PiFolderOpenBold />
|
||||
{t('modelManager.scanFolder')}
|
||||
@@ -80,6 +84,9 @@ export const InstallModels = memo(() => {
|
||||
<TabPanel height="100%">
|
||||
<HuggingFaceForm />
|
||||
</TabPanel>
|
||||
<TabPanel height="100%">
|
||||
<ExternalProvidersForm />
|
||||
</TabPanel>
|
||||
<TabPanel height="100%">
|
||||
<ScanModelsForm />
|
||||
</TabPanel>
|
||||
|
||||
@@ -19,6 +19,7 @@ const FORMAT_NAME_MAP: Record<ModelFormat, string> = {
|
||||
bnb_quantized_nf4b: 'quantized',
|
||||
gguf_quantized: 'gguf',
|
||||
omi: 'omi',
|
||||
external_api: 'external_api',
|
||||
unknown: 'unknown',
|
||||
olive: 'olive',
|
||||
onnx: 'onnx',
|
||||
@@ -40,6 +41,7 @@ const FORMAT_COLOR_MAP: Record<ModelFormat, string> = {
|
||||
unknown: 'red',
|
||||
olive: 'base',
|
||||
onnx: 'base',
|
||||
external_api: 'base',
|
||||
};
|
||||
|
||||
const ModelFormatBadge = ({ format }: Props) => {
|
||||
|
||||
@@ -5,7 +5,7 @@ import { MODEL_BASE_TO_LONG_NAME } from 'features/modelManagerV2/models';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { Control } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||
import type { UpdateModelBody } from 'services/api/types';
|
||||
import { objectEntries } from 'tsafe';
|
||||
|
||||
const options: ComboboxOption[] = objectEntries(MODEL_BASE_TO_LONG_NAME).map(([value, label]) => ({
|
||||
@@ -14,7 +14,7 @@ const options: ComboboxOption[] = objectEntries(MODEL_BASE_TO_LONG_NAME).map(([v
|
||||
}));
|
||||
|
||||
type Props = {
|
||||
control: Control<UpdateModelArg['body']>;
|
||||
control: Control<UpdateModelBody>;
|
||||
};
|
||||
|
||||
const BaseModelSelect = ({ control }: Props) => {
|
||||
|
||||
@@ -5,7 +5,7 @@ import { MODEL_FORMAT_TO_LONG_NAME } from 'features/modelManagerV2/models';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { Control } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||
import type { UpdateModelBody } from 'services/api/types';
|
||||
import { objectEntries } from 'tsafe';
|
||||
|
||||
const options: ComboboxOption[] = objectEntries(MODEL_FORMAT_TO_LONG_NAME).map(([value, label]) => ({
|
||||
@@ -14,7 +14,7 @@ const options: ComboboxOption[] = objectEntries(MODEL_FORMAT_TO_LONG_NAME).map((
|
||||
}));
|
||||
|
||||
type Props = {
|
||||
control: Control<UpdateModelArg['body']>;
|
||||
control: Control<UpdateModelBody>;
|
||||
};
|
||||
|
||||
const ModelFormatSelect = ({ control }: Props) => {
|
||||
|
||||
@@ -5,7 +5,7 @@ import { MODEL_TYPE_TO_LONG_NAME } from 'features/modelManagerV2/models';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { Control } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||
import type { UpdateModelBody } from 'services/api/types';
|
||||
import { objectEntries } from 'tsafe';
|
||||
|
||||
const options: ComboboxOption[] = objectEntries(MODEL_TYPE_TO_LONG_NAME).map(([value, label]) => ({
|
||||
@@ -14,7 +14,7 @@ const options: ComboboxOption[] = objectEntries(MODEL_TYPE_TO_LONG_NAME).map(([v
|
||||
}));
|
||||
|
||||
type Props = {
|
||||
control: Control<UpdateModelArg['body']>;
|
||||
control: Control<UpdateModelBody>;
|
||||
};
|
||||
|
||||
const ModelTypeSelect = ({ control }: Props) => {
|
||||
|
||||
@@ -5,13 +5,13 @@ import { MODEL_VARIANT_TO_LONG_NAME } from 'features/modelManagerV2/models';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { Control } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||
import type { UpdateModelBody } from 'services/api/types';
|
||||
import { objectEntries } from 'tsafe';
|
||||
|
||||
const options: ComboboxOption[] = objectEntries(MODEL_VARIANT_TO_LONG_NAME).map(([value, label]) => ({ label, value }));
|
||||
|
||||
type Props = {
|
||||
control: Control<UpdateModelArg['body']>;
|
||||
control: Control<UpdateModelBody>;
|
||||
};
|
||||
|
||||
const ModelVariantSelect = ({ control }: Props) => {
|
||||
|
||||
@@ -4,7 +4,7 @@ import { typedMemo } from 'common/util/typedMemo';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { Control } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||
import type { UpdateModelBody } from 'services/api/types';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'none', label: '-' },
|
||||
@@ -14,7 +14,7 @@ const options: ComboboxOption[] = [
|
||||
];
|
||||
|
||||
type Props = {
|
||||
control: Control<UpdateModelArg['body']>;
|
||||
control: Control<UpdateModelBody>;
|
||||
};
|
||||
|
||||
const PredictionTypeSelect = ({ control }: Props) => {
|
||||
|
||||
@@ -5,6 +5,7 @@ import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiExclamationMarkBold } from 'react-icons/pi';
|
||||
import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfigWithExternal } from 'services/api/types';
|
||||
|
||||
import { ModelEdit } from './ModelEdit';
|
||||
import { ModelView } from './ModelView';
|
||||
@@ -21,7 +22,9 @@ export const Model = memo(() => {
|
||||
if (selectedModelKey === null) {
|
||||
return null;
|
||||
}
|
||||
const modelConfig = modelConfigsAdapterSelectors.selectById(modelConfigs, selectedModelKey);
|
||||
const modelConfig = modelConfigsAdapterSelectors.selectById(modelConfigs, selectedModelKey) as
|
||||
| AnyModelConfigWithExternal
|
||||
| undefined;
|
||||
|
||||
if (!modelConfig) {
|
||||
return null;
|
||||
|
||||
@@ -7,11 +7,11 @@ import { memo, type MouseEvent, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiTrashSimpleBold } from 'react-icons/pi';
|
||||
import { useDeleteModelsMutation } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import type { AnyModelConfigWithExternal } from 'services/api/types';
|
||||
|
||||
type Props = {
|
||||
showLabel?: boolean;
|
||||
modelConfig: AnyModelConfig;
|
||||
modelConfig: AnyModelConfigWithExternal;
|
||||
};
|
||||
|
||||
export const ModelDeleteButton = memo(({ showLabel = true, modelConfig }: Props) => {
|
||||
|
||||
@@ -15,12 +15,18 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { ModelHeader } from 'features/modelManagerV2/subpanels/ModelPanel/ModelHeader';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { type SubmitHandler, useForm } from 'react-hook-form';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { type Control, type SubmitHandler, useForm, useWatch } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCheckBold, PiXBold } from 'react-icons/pi';
|
||||
import { type UpdateModelArg, useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import {
|
||||
type AnyModelConfigWithExternal,
|
||||
type ExternalApiModelDefaultSettings,
|
||||
type ExternalModelCapabilities,
|
||||
isExternalApiModelConfig,
|
||||
type UpdateModelBody,
|
||||
} from 'services/api/types';
|
||||
|
||||
import BaseModelSelect from './Fields/BaseModelSelect';
|
||||
import ModelFormatSelect from './Fields/ModelFormatSelect';
|
||||
@@ -30,7 +36,14 @@ import PredictionTypeSelect from './Fields/PredictionTypeSelect';
|
||||
import { ModelFooter } from './ModelFooter';
|
||||
|
||||
type Props = {
|
||||
modelConfig: AnyModelConfig;
|
||||
modelConfig: AnyModelConfigWithExternal;
|
||||
};
|
||||
|
||||
type ModelEditFormValues = UpdateModelBody & {
|
||||
capabilities?: ExternalModelCapabilities;
|
||||
provider_id?: string;
|
||||
provider_model_id?: string;
|
||||
default_settings?: ExternalApiModelDefaultSettings | null;
|
||||
};
|
||||
|
||||
const stringFieldOptions = {
|
||||
@@ -41,19 +54,54 @@ export const ModelEdit = memo(({ modelConfig }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation();
|
||||
const dispatch = useAppDispatch();
|
||||
const isExternal = useMemo(() => isExternalApiModelConfig(modelConfig), [modelConfig]);
|
||||
|
||||
const form = useForm<UpdateModelArg['body']>({
|
||||
defaultValues: modelConfig,
|
||||
const form = useForm<ModelEditFormValues>({
|
||||
defaultValues: modelConfig as unknown as ModelEditFormValues,
|
||||
mode: 'onChange',
|
||||
});
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<UpdateModelArg['body']>>(
|
||||
const externalModes = useWatch({
|
||||
control: form.control,
|
||||
name: 'capabilities.modes',
|
||||
}) as ExternalModelCapabilities['modes'] | undefined;
|
||||
|
||||
const modeSet = useMemo(() => new Set(externalModes ?? []), [externalModes]);
|
||||
|
||||
const toggleMode = useCallback(
|
||||
(mode: ExternalModelCapabilities['modes'][number]) => {
|
||||
const nextModes = modeSet.has(mode)
|
||||
? externalModes?.filter((value) => value !== mode)
|
||||
: [...(externalModes ?? []), mode];
|
||||
form.setValue('capabilities.modes', nextModes ?? [], { shouldDirty: true, shouldValidate: true });
|
||||
},
|
||||
[externalModes, form, modeSet]
|
||||
);
|
||||
|
||||
const handleToggleTxt2Img = useCallback(() => toggleMode('txt2img'), [toggleMode]);
|
||||
const handleToggleImg2Img = useCallback(() => toggleMode('img2img'), [toggleMode]);
|
||||
const handleToggleInpaint = useCallback(() => toggleMode('inpaint'), [toggleMode]);
|
||||
|
||||
const parseOptionalNumber = useCallback((value: string | null | undefined) => {
|
||||
if (value === null || value === undefined || value === '') {
|
||||
return null;
|
||||
}
|
||||
if (typeof value !== 'string') {
|
||||
return Number.isNaN(Number(value)) ? null : Number(value);
|
||||
}
|
||||
if (value.trim() === '') {
|
||||
return null;
|
||||
}
|
||||
const parsed = Number(value);
|
||||
return Number.isNaN(parsed) ? null : parsed;
|
||||
}, []);
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<ModelEditFormValues>>(
|
||||
(values) => {
|
||||
const responseBody: UpdateModelArg = {
|
||||
key: modelConfig.key,
|
||||
body: values,
|
||||
body: values as UpdateModelBody,
|
||||
};
|
||||
|
||||
updateModel(responseBody)
|
||||
.unwrap()
|
||||
.then((payload) => {
|
||||
@@ -133,19 +181,19 @@ export const ModelEdit = memo(({ modelConfig }: Props) => {
|
||||
<SimpleGrid columns={2} gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.modelType')}</FormLabel>
|
||||
<ModelTypeSelect control={form.control} />
|
||||
<ModelTypeSelect control={form.control as unknown as Control<UpdateModelBody>} />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.modelFormat')}</FormLabel>
|
||||
<ModelFormatSelect control={form.control} />
|
||||
<ModelFormatSelect control={form.control as unknown as Control<UpdateModelBody>} />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
|
||||
<BaseModelSelect control={form.control} />
|
||||
<BaseModelSelect control={form.control as unknown as Control<UpdateModelBody>} />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
||||
<ModelVariantSelect control={form.control} />
|
||||
<ModelVariantSelect control={form.control as unknown as Control<UpdateModelBody>} />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
|
||||
@@ -153,13 +201,143 @@ export const ModelEdit = memo(({ modelConfig }: Props) => {
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
|
||||
<PredictionTypeSelect control={form.control} />
|
||||
<PredictionTypeSelect control={form.control as unknown as Control<UpdateModelBody>} />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
|
||||
<Checkbox {...form.register('upcast_attention')} />
|
||||
</FormControl>
|
||||
</SimpleGrid>
|
||||
{isExternal && (
|
||||
<>
|
||||
<Heading as="h3" fontSize="md" mt="4">
|
||||
{t('modelManager.externalProvider')}
|
||||
</Heading>
|
||||
<SimpleGrid columns={2} gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.providerId')}</FormLabel>
|
||||
<Input {...form.register('provider_id', stringFieldOptions)} size="md" />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.providerModelId')}</FormLabel>
|
||||
<Input {...form.register('provider_model_id', stringFieldOptions)} size="md" />
|
||||
</FormControl>
|
||||
</SimpleGrid>
|
||||
<Heading as="h3" fontSize="md" mt="4">
|
||||
{t('modelManager.externalCapabilities')}
|
||||
</Heading>
|
||||
<SimpleGrid columns={2} gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.supportedModes')}</FormLabel>
|
||||
<Flex gap={3} wrap="wrap">
|
||||
<Checkbox isChecked={modeSet.has('txt2img')} onChange={handleToggleTxt2Img}>
|
||||
txt2img
|
||||
</Checkbox>
|
||||
<Checkbox isChecked={modeSet.has('img2img')} onChange={handleToggleImg2Img}>
|
||||
img2img
|
||||
</Checkbox>
|
||||
<Checkbox isChecked={modeSet.has('inpaint')} onChange={handleToggleInpaint}>
|
||||
inpaint
|
||||
</Checkbox>
|
||||
</Flex>
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.supportsReferenceImages')}</FormLabel>
|
||||
<Checkbox {...form.register('capabilities.supports_reference_images')} />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.supportsSeed')}</FormLabel>
|
||||
<Checkbox {...form.register('capabilities.supports_seed')} />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.maxImagesPerRequest')}</FormLabel>
|
||||
<Input
|
||||
type="number"
|
||||
{...form.register('capabilities.max_images_per_request', {
|
||||
setValueAs: parseOptionalNumber,
|
||||
})}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.maxReferenceImages')}</FormLabel>
|
||||
<Input
|
||||
type="number"
|
||||
{...form.register('capabilities.max_reference_images', {
|
||||
setValueAs: parseOptionalNumber,
|
||||
})}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.maxImageWidth')}</FormLabel>
|
||||
<Input
|
||||
type="number"
|
||||
{...form.register('capabilities.max_image_size.width', {
|
||||
setValueAs: parseOptionalNumber,
|
||||
})}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.maxImageHeight')}</FormLabel>
|
||||
<Input
|
||||
type="number"
|
||||
{...form.register('capabilities.max_image_size.height', {
|
||||
setValueAs: parseOptionalNumber,
|
||||
})}
|
||||
/>
|
||||
</FormControl>
|
||||
</SimpleGrid>
|
||||
<Heading as="h3" fontSize="md" mt="4">
|
||||
{t('modelManager.externalDefaults')}
|
||||
</Heading>
|
||||
<SimpleGrid columns={2} gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.width')}</FormLabel>
|
||||
<Input
|
||||
type="number"
|
||||
{...form.register('default_settings.width', {
|
||||
setValueAs: parseOptionalNumber,
|
||||
})}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.height')}</FormLabel>
|
||||
<Input
|
||||
type="number"
|
||||
{...form.register('default_settings.height', {
|
||||
setValueAs: parseOptionalNumber,
|
||||
})}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('parameters.steps')}</FormLabel>
|
||||
<Input
|
||||
type="number"
|
||||
{...form.register('default_settings.steps', {
|
||||
setValueAs: parseOptionalNumber,
|
||||
})}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('parameters.guidance')}</FormLabel>
|
||||
<Input
|
||||
type="number"
|
||||
{...form.register('default_settings.guidance', {
|
||||
setValueAs: parseOptionalNumber,
|
||||
})}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.numImages')}</FormLabel>
|
||||
<Input
|
||||
type="number"
|
||||
{...form.register('default_settings.num_images', {
|
||||
setValueAs: parseOptionalNumber,
|
||||
})}
|
||||
/>
|
||||
</FormControl>
|
||||
</SimpleGrid>
|
||||
</>
|
||||
)}
|
||||
</Flex>
|
||||
</form>
|
||||
</Flex>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Flex, Heading, type SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import type { AnyModelConfigWithExternal } from 'services/api/types';
|
||||
|
||||
import { ModelConvertButton } from './ModelConvertButton';
|
||||
import { ModelDeleteButton } from './ModelDeleteButton';
|
||||
@@ -20,7 +20,7 @@ const footerRowSx: SystemStyleObject = {
|
||||
};
|
||||
|
||||
type Props = {
|
||||
modelConfig: AnyModelConfig;
|
||||
modelConfig: AnyModelConfigWithExternal;
|
||||
isEditing: boolean;
|
||||
};
|
||||
|
||||
|
||||
@@ -4,10 +4,10 @@ import ModelImageUpload from 'features/modelManagerV2/subpanels/ModelPanel/Field
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import type { AnyModelConfigWithExternal } from 'services/api/types';
|
||||
|
||||
type Props = PropsWithChildren<{
|
||||
modelConfig: AnyModelConfig;
|
||||
modelConfig: AnyModelConfigWithExternal;
|
||||
}>;
|
||||
|
||||
export const ModelHeader = memo(({ modelConfig, children }: Props) => {
|
||||
|
||||
@@ -4,15 +4,18 @@ import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiSparkleFill } from 'react-icons/pi';
|
||||
import { useReidentifyModelMutation } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import { type AnyModelConfigWithExternal, isExternalApiModelConfig } from 'services/api/types';
|
||||
|
||||
import { isExternalModel } from './isExternalModel';
|
||||
|
||||
interface Props {
|
||||
modelConfig: AnyModelConfig;
|
||||
modelConfig: AnyModelConfigWithExternal;
|
||||
}
|
||||
|
||||
export const ModelReidentifyButton = memo(({ modelConfig }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const [reidentifyModel, { isLoading }] = useReidentifyModelMutation();
|
||||
const isExternal = isExternalApiModelConfig(modelConfig) || isExternalModel(modelConfig.path);
|
||||
|
||||
const onClick = useCallback(() => {
|
||||
reidentifyModel({ key: modelConfig.key })
|
||||
@@ -40,6 +43,10 @@ export const ModelReidentifyButton = memo(({ modelConfig }: Props) => {
|
||||
});
|
||||
}, [modelConfig.key, reidentifyModel, t]);
|
||||
|
||||
if (isExternal) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Button
|
||||
onClick={onClick}
|
||||
|
||||
@@ -3,13 +3,13 @@ import { toast } from 'features/toast/toast';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiDownloadSimpleBold } from 'react-icons/pi';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import type { AnyModelConfigWithExternal } from 'services/api/types';
|
||||
|
||||
type Props = {
|
||||
modelConfig: AnyModelConfig;
|
||||
modelConfig: AnyModelConfigWithExternal;
|
||||
};
|
||||
|
||||
const buildExportData = (modelConfig: AnyModelConfig): Record<string, unknown> => {
|
||||
const buildExportData = (modelConfig: AnyModelConfigWithExternal): Record<string, unknown> => {
|
||||
const data: Record<string, unknown> = {};
|
||||
|
||||
if (
|
||||
|
||||
@@ -5,7 +5,7 @@ import { memo, useCallback, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiUploadSimpleBold } from 'react-icons/pi';
|
||||
import { useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import type { AnyModelConfigWithExternal } from 'services/api/types';
|
||||
|
||||
const validateImportData = (data: unknown): data is Record<string, unknown> => {
|
||||
if (typeof data !== 'object' || data === null || Array.isArray(data)) {
|
||||
@@ -40,7 +40,7 @@ const validateImportData = (data: unknown): data is Record<string, unknown> => {
|
||||
};
|
||||
|
||||
type Props = {
|
||||
modelConfig: AnyModelConfig;
|
||||
modelConfig: AnyModelConfigWithExternal;
|
||||
};
|
||||
|
||||
export const ModelSettingsImportButton = memo(({ modelConfig }: Props) => {
|
||||
|
||||
@@ -22,10 +22,10 @@ import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiFolderOpenFill } from 'react-icons/pi';
|
||||
import { useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import type { AnyModelConfigWithExternal } from 'services/api/types';
|
||||
|
||||
interface Props {
|
||||
modelConfig: AnyModelConfig;
|
||||
modelConfig: AnyModelConfigWithExternal;
|
||||
}
|
||||
|
||||
export const ModelUpdatePathButton = memo(({ modelConfig }: Props) => {
|
||||
|
||||
@@ -12,14 +12,15 @@ import { TriggerPhrases } from 'features/modelManagerV2/subpanels/ModelPanel/Tri
|
||||
import { filesize } from 'filesize';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type {
|
||||
AnyModelConfig,
|
||||
CLIPEmbedModelConfig,
|
||||
CLIPVisionModelConfig,
|
||||
LlavaOnevisionModelConfig,
|
||||
Qwen3EncoderModelConfig,
|
||||
SigLIPModelConfig,
|
||||
T5EncoderModelConfig,
|
||||
import {
|
||||
type AnyModelConfigWithExternal,
|
||||
type CLIPEmbedModelConfig,
|
||||
type CLIPVisionModelConfig,
|
||||
isExternalApiModelConfig,
|
||||
type LlavaOnevisionModelConfig,
|
||||
type Qwen3EncoderModelConfig,
|
||||
type SigLIPModelConfig,
|
||||
type T5EncoderModelConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
import { isExternalModel } from './isExternalModel';
|
||||
@@ -38,7 +39,7 @@ type EncoderModelConfig =
|
||||
| SigLIPModelConfig
|
||||
| LlavaOnevisionModelConfig;
|
||||
|
||||
const isEncoderModel = (modelConfig: AnyModelConfig): modelConfig is EncoderModelConfig => {
|
||||
const isEncoderModel = (modelConfig: AnyModelConfigWithExternal): modelConfig is EncoderModelConfig => {
|
||||
return (
|
||||
modelConfig.type === 'clip_embed' ||
|
||||
modelConfig.type === 't5_encoder' ||
|
||||
@@ -50,7 +51,7 @@ const isEncoderModel = (modelConfig: AnyModelConfig): modelConfig is EncoderMode
|
||||
};
|
||||
|
||||
type Props = {
|
||||
modelConfig: AnyModelConfig;
|
||||
modelConfig: AnyModelConfigWithExternal;
|
||||
};
|
||||
|
||||
export const ModelView = memo(({ modelConfig }: Props) => {
|
||||
@@ -104,6 +105,12 @@ export const ModelView = memo(({ modelConfig }: Props) => {
|
||||
<ModelAttrView label={t('modelManager.modelFormat')} value={modelConfig.format} />
|
||||
<ModelAttrView label={t('modelManager.path')} value={modelConfig.path} />
|
||||
<ModelAttrView label={t('modelManager.fileSize')} value={filesize(modelConfig.file_size)} />
|
||||
{isExternalApiModelConfig(modelConfig) && (
|
||||
<>
|
||||
<ModelAttrView label={t('modelManager.providerId')} value={modelConfig.provider_id} />
|
||||
<ModelAttrView label={t('modelManager.providerModelId')} value={modelConfig.provider_model_id} />
|
||||
</>
|
||||
)}
|
||||
{'variant' in modelConfig && modelConfig.variant && (
|
||||
<ModelAttrView label={t('modelManager.variant')} value={modelConfig.variant} />
|
||||
)}
|
||||
|
||||
@@ -31,10 +31,10 @@ import {
|
||||
useRemoveModelRelationshipMutation,
|
||||
} from 'services/api/endpoints/modelRelationships';
|
||||
import { useGetModelConfigsQuery } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import type { AnyModelConfig, AnyModelConfigWithExternal } from 'services/api/types';
|
||||
|
||||
type Props = {
|
||||
modelConfig: AnyModelConfig;
|
||||
modelConfig: AnyModelConfigWithExternal;
|
||||
};
|
||||
|
||||
type ModelGroup = {
|
||||
@@ -52,7 +52,10 @@ type ModelGroup = {
|
||||
//
|
||||
// TODO: In the future, refine this logic to more strictly validate
|
||||
// relationships based on model types or actual usage patterns.
|
||||
const isBaseCompatible = (a: AnyModelConfig, b: AnyModelConfig): boolean => {
|
||||
const isBaseCompatible = (a: AnyModelConfigWithExternal, b: AnyModelConfig): boolean => {
|
||||
if (a.base === 'external') {
|
||||
return false;
|
||||
}
|
||||
if (a.base === 'any' || b.base === 'any') {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -97,6 +97,7 @@ export const zBaseModelType = z.enum([
|
||||
'cogview4',
|
||||
'qwen-image',
|
||||
'z-image',
|
||||
'external',
|
||||
'anima',
|
||||
'unknown',
|
||||
]);
|
||||
@@ -133,6 +134,7 @@ export const zModelType = z.enum([
|
||||
'clip_embed',
|
||||
'siglip',
|
||||
'flux_redux',
|
||||
'external_image_generator',
|
||||
'unknown',
|
||||
]);
|
||||
export type ModelType = z.infer<typeof zModelType>;
|
||||
@@ -184,6 +186,7 @@ export const zModelFormat = z.enum([
|
||||
'bnb_quantized_int8b',
|
||||
'bnb_quantized_nf4b',
|
||||
'gguf_quantized',
|
||||
'external_api',
|
||||
'unknown',
|
||||
]);
|
||||
export type ModelFormat = z.infer<typeof zModelFormat>;
|
||||
@@ -197,6 +200,17 @@ export const zModelIdentifierField = z.object({
|
||||
submodel_type: zSubModelType.nullish(),
|
||||
});
|
||||
export type ModelIdentifierField = z.infer<typeof zModelIdentifierField>;
|
||||
|
||||
// Frontend-only identifier for external API models (not part of the backend schema)
|
||||
export const zExternalModelIdentifierField = z.object({
|
||||
key: z.string().min(1),
|
||||
hash: z.string().min(1),
|
||||
name: z.string().min(1),
|
||||
base: z.literal('external'),
|
||||
type: z.literal('external_image_generator'),
|
||||
submodel_type: zSubModelType.nullish(),
|
||||
});
|
||||
|
||||
// #endregion
|
||||
|
||||
// #region Control Adapters
|
||||
|
||||
@@ -0,0 +1,201 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { ParamsState, RefImagesState } from 'features/controlLayers/store/types';
|
||||
import { imageDTOToCroppableImage, initialIPAdapter } from 'features/controlLayers/store/util';
|
||||
import type {
|
||||
ExternalApiModelConfig,
|
||||
ExternalApiModelDefaultSettings,
|
||||
ExternalImageSize,
|
||||
ExternalModelCapabilities,
|
||||
ExternalModelPanelSchema,
|
||||
ImageDTO,
|
||||
} from 'services/api/types';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { buildExternalGraph } from './buildExternalGraph';
|
||||
|
||||
const createExternalModel = (overrides: Partial<ExternalApiModelConfig> = {}): ExternalApiModelConfig => {
|
||||
const maxImageSize: ExternalImageSize = { width: 1024, height: 1024 };
|
||||
const defaultSettings: ExternalApiModelDefaultSettings = { width: 1024, height: 1024 };
|
||||
const capabilities: ExternalModelCapabilities = {
|
||||
modes: ['txt2img'],
|
||||
supports_reference_images: true,
|
||||
supports_seed: true,
|
||||
max_image_size: maxImageSize,
|
||||
};
|
||||
|
||||
return {
|
||||
key: 'external-test',
|
||||
hash: 'external:openai:gpt-image-1',
|
||||
path: 'external://openai/gpt-image-1',
|
||||
file_size: 0,
|
||||
name: 'External Test',
|
||||
description: null,
|
||||
source: 'external://openai/gpt-image-1',
|
||||
source_type: 'url',
|
||||
source_api_response: null,
|
||||
cover_image: null,
|
||||
base: 'external',
|
||||
type: 'external_image_generator',
|
||||
format: 'external_api',
|
||||
provider_id: 'openai',
|
||||
provider_model_id: 'gpt-image-1',
|
||||
capabilities,
|
||||
default_settings: defaultSettings,
|
||||
panel_schema: undefined,
|
||||
tags: ['external'],
|
||||
is_default: false,
|
||||
...overrides,
|
||||
};
|
||||
};
|
||||
|
||||
let mockModelConfig: ExternalApiModelConfig | null = null;
|
||||
let mockParams: ParamsState;
|
||||
let mockRefImages: RefImagesState;
|
||||
let mockSizes: { scaledSize: { width: number; height: number } };
|
||||
|
||||
const mockOutputFields = {
|
||||
id: 'external_output',
|
||||
use_cache: false,
|
||||
is_intermediate: false,
|
||||
board: undefined,
|
||||
};
|
||||
|
||||
vi.mock('features/controlLayers/store/paramsSlice', () => ({
|
||||
selectModelConfig: () => mockModelConfig,
|
||||
selectParamsSlice: () => mockParams,
|
||||
}));
|
||||
|
||||
vi.mock('features/controlLayers/store/refImagesSlice', () => ({
|
||||
selectRefImagesSlice: () => mockRefImages,
|
||||
}));
|
||||
|
||||
vi.mock('features/nodes/util/graph/graphBuilderUtils', () => ({
|
||||
getOriginalAndScaledSizesForTextToImage: () => mockSizes,
|
||||
getOriginalAndScaledSizesForOtherModes: () => ({
|
||||
scaledSize: { width: 512, height: 512 },
|
||||
rect: { x: 0, y: 0, width: 512, height: 512 },
|
||||
}),
|
||||
selectCanvasOutputFields: () => mockOutputFields,
|
||||
}));
|
||||
|
||||
beforeEach(() => {
|
||||
mockParams = {
|
||||
steps: 20,
|
||||
guidance: 4.5,
|
||||
dimensions: { width: 768, height: 512, aspectRatio: { id: '3:2', value: 1.5, isLocked: true } },
|
||||
} as ParamsState;
|
||||
mockSizes = { scaledSize: { width: 768, height: 512 } };
|
||||
|
||||
const imageDTO = { image_name: 'ref.png', width: 64, height: 64 } as ImageDTO;
|
||||
mockRefImages = {
|
||||
selectedEntityId: null,
|
||||
isPanelOpen: false,
|
||||
entities: [
|
||||
{
|
||||
id: 'ref-image-1',
|
||||
isEnabled: true,
|
||||
config: {
|
||||
...initialIPAdapter,
|
||||
weight: 0.5,
|
||||
image: imageDTOToCroppableImage(imageDTO),
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
});
|
||||
|
||||
describe('buildExternalGraph', () => {
|
||||
it('builds txt2img graph with reference images and seed', async () => {
|
||||
const modelConfig = createExternalModel();
|
||||
mockModelConfig = modelConfig;
|
||||
|
||||
const { g } = await buildExternalGraph({
|
||||
generationMode: 'txt2img',
|
||||
state: {} as RootState,
|
||||
manager: null,
|
||||
});
|
||||
const graph = g.getGraph();
|
||||
const externalNode = Object.values(graph.nodes).find((node) => node.type === 'openai_image_generation') as
|
||||
| Record<string, unknown>
|
||||
| undefined;
|
||||
|
||||
expect(externalNode).toBeDefined();
|
||||
expect(externalNode?.type).toBe('openai_image_generation');
|
||||
expect(externalNode?.mode).toBe('txt2img');
|
||||
expect(externalNode?.width).toBe(768);
|
||||
expect(externalNode?.height).toBe(512);
|
||||
expect((externalNode?.reference_images as Array<{ image_name: string }> | undefined)?.[0]).toEqual({
|
||||
image_name: 'ref.png',
|
||||
});
|
||||
|
||||
const seedEdge = graph.edges.find((edge) => edge.destination.field === 'seed');
|
||||
expect(seedEdge).toBeDefined();
|
||||
});
|
||||
|
||||
it('prefers panel schema over capabilities when building node inputs', async () => {
|
||||
const panelSchema: ExternalModelPanelSchema = {
|
||||
prompts: [{ name: 'reference_images' }],
|
||||
image: [{ name: 'dimensions' }],
|
||||
generation: [],
|
||||
};
|
||||
mockModelConfig = createExternalModel({
|
||||
panel_schema: panelSchema,
|
||||
});
|
||||
|
||||
const { g } = await buildExternalGraph({
|
||||
generationMode: 'txt2img',
|
||||
state: {} as RootState,
|
||||
manager: null,
|
||||
});
|
||||
const graph = g.getGraph();
|
||||
const externalNode = Object.values(graph.nodes).find((node) => node.type === 'openai_image_generation') as
|
||||
| Record<string, unknown>
|
||||
| undefined;
|
||||
|
||||
expect((externalNode?.reference_images as Array<{ image_name: string }> | undefined)?.[0]).toEqual({
|
||||
image_name: 'ref.png',
|
||||
});
|
||||
|
||||
const seedEdge = graph.edges.find((edge) => edge.destination.field === 'seed');
|
||||
expect(seedEdge).toBeUndefined();
|
||||
});
|
||||
|
||||
it('uses provider-specific node types', async () => {
|
||||
mockModelConfig = createExternalModel({
|
||||
provider_id: 'gemini',
|
||||
provider_model_id: 'gemini-2.5-flash-image',
|
||||
path: 'external://gemini/gemini-2.5-flash-image',
|
||||
source: 'external://gemini/gemini-2.5-flash-image',
|
||||
hash: 'external:gemini:gemini-2.5-flash-image',
|
||||
});
|
||||
|
||||
const { g } = await buildExternalGraph({
|
||||
generationMode: 'txt2img',
|
||||
state: {} as RootState,
|
||||
manager: null,
|
||||
});
|
||||
|
||||
const graph = g.getGraph();
|
||||
const externalNode = Object.values(graph.nodes).find((node) => node.type === 'gemini_image_generation');
|
||||
|
||||
expect(externalNode).toBeDefined();
|
||||
expect(externalNode?.type).toBe('gemini_image_generation');
|
||||
});
|
||||
|
||||
it('throws when mode is unsupported', async () => {
|
||||
const modelConfig = createExternalModel({
|
||||
capabilities: {
|
||||
modes: ['img2img'],
|
||||
},
|
||||
});
|
||||
mockModelConfig = modelConfig;
|
||||
|
||||
await expect(
|
||||
buildExternalGraph({
|
||||
generationMode: 'txt2img',
|
||||
state: {} as RootState,
|
||||
manager: null,
|
||||
})
|
||||
).rejects.toThrow('does not support txt2img');
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,152 @@
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { selectModelConfig, selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
|
||||
import { type ModelIdentifierField, zImageField } from 'features/nodes/types/common';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import {
|
||||
getOriginalAndScaledSizesForOtherModes,
|
||||
selectCanvasOutputFields,
|
||||
} from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import {
|
||||
type GraphBuilderArg,
|
||||
type GraphBuilderReturn,
|
||||
UnsupportedGenerationModeError,
|
||||
} from 'features/nodes/util/graph/types';
|
||||
import { hasExternalPanelControl } from 'features/parameters/util/externalPanelSchema';
|
||||
import {
|
||||
type AnyInvocation,
|
||||
type AnyModelConfigWithExternal,
|
||||
type Invocation,
|
||||
isExternalApiModelConfig,
|
||||
} from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const EXTERNAL_PROVIDER_NODE_TYPES = {
|
||||
gemini: 'gemini_image_generation',
|
||||
openai: 'openai_image_generation',
|
||||
} as const;
|
||||
|
||||
export const buildExternalGraph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
|
||||
const { generationMode, state, manager } = arg;
|
||||
|
||||
const modelConfig = selectModelConfig(state) as AnyModelConfigWithExternal | null;
|
||||
assert(modelConfig, 'No model selected');
|
||||
assert(isExternalApiModelConfig(modelConfig), 'Selected model is not an external API model');
|
||||
const model = modelConfig;
|
||||
|
||||
const requestedMode = generationMode === 'outpaint' ? 'inpaint' : generationMode;
|
||||
if (!model.capabilities.modes.includes(requestedMode)) {
|
||||
throw new UnsupportedGenerationModeError(`${model.name} does not support ${requestedMode} mode`);
|
||||
}
|
||||
|
||||
const params = selectParamsSlice(state);
|
||||
const refImages = selectRefImagesSlice(state);
|
||||
|
||||
const g = new Graph(getPrefixedId('external_graph'));
|
||||
const supportsSeed = hasExternalPanelControl(model, 'image', 'seed');
|
||||
const supportsReferenceImages = hasExternalPanelControl(model, 'prompts', 'reference_images');
|
||||
|
||||
const seed = supportsSeed
|
||||
? g.addNode({
|
||||
id: getPrefixedId('seed'),
|
||||
type: 'integer',
|
||||
})
|
||||
: null;
|
||||
|
||||
const positivePrompt = g.addNode({
|
||||
id: getPrefixedId('positive_prompt'),
|
||||
type: 'string',
|
||||
});
|
||||
|
||||
const externalNodeType = EXTERNAL_PROVIDER_NODE_TYPES[model.provider_id as keyof typeof EXTERNAL_PROVIDER_NODE_TYPES];
|
||||
const externalNode: Record<string, unknown> = {
|
||||
id: getPrefixedId('external_image_generation'),
|
||||
type: externalNodeType ?? 'external_image_generation',
|
||||
model: model as unknown as ModelIdentifierField,
|
||||
mode: requestedMode,
|
||||
image_size: params.imageSize ?? null,
|
||||
num_images: 1,
|
||||
};
|
||||
|
||||
// Provider-specific options
|
||||
if (model.provider_id === 'openai') {
|
||||
externalNode.quality = params.openaiQuality;
|
||||
externalNode.background = params.openaiBackground;
|
||||
if (params.openaiInputFidelity) {
|
||||
externalNode.input_fidelity = params.openaiInputFidelity;
|
||||
}
|
||||
} else if (model.provider_id === 'gemini') {
|
||||
if (params.geminiTemperature !== null) {
|
||||
externalNode.temperature = params.geminiTemperature;
|
||||
}
|
||||
if (params.geminiThinkingLevel) {
|
||||
externalNode.thinking_level = params.geminiThinkingLevel;
|
||||
}
|
||||
}
|
||||
g.addNode(externalNode as AnyInvocation);
|
||||
|
||||
if (seed) {
|
||||
g.addEdgeFromObj({
|
||||
source: { node_id: seed.id, field: 'value' },
|
||||
destination: { node_id: externalNode.id as string, field: 'seed' },
|
||||
});
|
||||
}
|
||||
g.addEdgeFromObj({
|
||||
source: { node_id: positivePrompt.id, field: 'value' },
|
||||
destination: { node_id: externalNode.id as string, field: 'prompt' },
|
||||
});
|
||||
|
||||
if (supportsReferenceImages) {
|
||||
const referenceImages = refImages.entities
|
||||
.filter((entity) => entity.isEnabled)
|
||||
.map((entity) => entity.config)
|
||||
.filter((config) => config.image)
|
||||
.map((config) => zImageField.parse(config.image?.crop?.image ?? config.image?.original.image));
|
||||
|
||||
if (referenceImages.length > 0) {
|
||||
externalNode.reference_images = referenceImages;
|
||||
}
|
||||
}
|
||||
|
||||
// External models require specific dimensions matching their supported presets.
|
||||
// Always use params dimensions (from selected preset) for the API width/height.
|
||||
externalNode.width = params.dimensions.width;
|
||||
externalNode.height = params.dimensions.height;
|
||||
|
||||
if (generationMode !== 'txt2img') {
|
||||
assert(manager, 'Canvas manager is required for img2img/inpaint');
|
||||
const canvasSettings = selectCanvasSettingsSlice(state);
|
||||
const { rect } = getOriginalAndScaledSizesForOtherModes(state);
|
||||
|
||||
const rasterAdapters = manager.compositor.getVisibleAdaptersOfType('raster_layer');
|
||||
const initImage = await manager.compositor.getCompositeImageDTO(rasterAdapters, rect, {
|
||||
is_intermediate: true,
|
||||
silent: true,
|
||||
});
|
||||
externalNode.init_image = { image_name: initImage.image_name };
|
||||
|
||||
if (generationMode === 'inpaint' || generationMode === 'outpaint') {
|
||||
const inpaintMaskAdapters = manager.compositor.getVisibleAdaptersOfType('inpaint_mask');
|
||||
const maskImage = await manager.compositor.getGrayscaleMaskCompositeImageDTO(
|
||||
inpaintMaskAdapters,
|
||||
rect,
|
||||
'denoiseLimit',
|
||||
canvasSettings.preserveMask,
|
||||
{
|
||||
is_intermediate: true,
|
||||
silent: true,
|
||||
}
|
||||
);
|
||||
externalNode.mask_image = { image_name: maskImage.image_name };
|
||||
}
|
||||
}
|
||||
|
||||
g.updateNode(externalNode as AnyInvocation, selectCanvasOutputFields(state));
|
||||
|
||||
return {
|
||||
g,
|
||||
seed: seed ?? undefined,
|
||||
positivePrompt: positivePrompt as Invocation<'string'>,
|
||||
};
|
||||
};
|
||||
@@ -13,6 +13,7 @@ import {
|
||||
} from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { addZImageControl } from 'features/nodes/util/graph/generation/addControlAdapters';
|
||||
import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage';
|
||||
import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint';
|
||||
@@ -77,7 +78,7 @@ export const buildZImageGraph = async (arg: GraphBuilderArg): Promise<GraphBuild
|
||||
model,
|
||||
vae_model: zImageVaeModel ?? undefined,
|
||||
qwen3_encoder_model: zImageQwen3EncoderModel ?? undefined,
|
||||
qwen3_source_model: zImageQwen3SourceModel ?? undefined,
|
||||
qwen3_source_model: (zImageQwen3SourceModel as ModelIdentifierField | null) ?? undefined,
|
||||
});
|
||||
|
||||
const positivePrompt = g.addNode({
|
||||
|
||||
@@ -2,6 +2,7 @@ import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { selectCLIPSkip, selectModel, setClipSkip } from 'features/controlLayers/store/paramsSlice';
|
||||
import type { BaseModelType } from 'features/nodes/types/common';
|
||||
import { CLIP_SKIP_MAP } from 'features/parameters/types/constants';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -32,14 +33,14 @@ const ParamClipSkip = () => {
|
||||
if (!model) {
|
||||
return CLIP_SKIP_MAP['sd-1']?.maxClip;
|
||||
}
|
||||
return CLIP_SKIP_MAP[model.base]?.maxClip;
|
||||
return CLIP_SKIP_MAP[model.base as BaseModelType]?.maxClip;
|
||||
}, [model]);
|
||||
|
||||
const sliderMarks = useMemo(() => {
|
||||
if (!model) {
|
||||
return CLIP_SKIP_MAP['sd-1']?.markers;
|
||||
}
|
||||
return CLIP_SKIP_MAP[model.base]?.markers;
|
||||
return CLIP_SKIP_MAP[model.base as BaseModelType]?.markers;
|
||||
}, [model]);
|
||||
|
||||
if (model?.base === 'sdxl') {
|
||||
|
||||
@@ -9,7 +9,7 @@ import {
|
||||
zImageQwen3SourceModelSelected,
|
||||
zImageVaeModelSelected,
|
||||
} from 'features/controlLayers/store/paramsSlice';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { type ModelIdentifierField, zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useFlux1VAEModels, useQwen3EncoderModels, useZImageDiffusersModels } from 'services/api/hooks/modelsByType';
|
||||
@@ -136,7 +136,7 @@ const ParamZImageQwen3SourceModelSelect = memo(() => {
|
||||
const { options, value, onChange, noOptionsMessage } = useModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel: zImageQwen3SourceModel,
|
||||
selectedModel: zImageQwen3SourceModel as ModelIdentifierField | null,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
import type { ExternalApiModelConfig } from 'services/api/types';
|
||||
import { describe, expect, test } from 'vitest';
|
||||
|
||||
const createExternalModel = (overrides: Partial<ExternalApiModelConfig> = {}): ExternalApiModelConfig => ({
|
||||
key: 'external-test',
|
||||
name: 'External Test',
|
||||
base: 'external',
|
||||
type: 'external_image_generator',
|
||||
format: 'external_api',
|
||||
provider_id: 'gemini',
|
||||
provider_model_id: 'gemini-2.5-flash-image',
|
||||
description: 'Test model',
|
||||
source: 'external://gemini/gemini-2.5-flash-image',
|
||||
source_type: 'url',
|
||||
source_api_response: null,
|
||||
path: '',
|
||||
file_size: 0,
|
||||
hash: 'external:gemini:gemini-2.5-flash-image',
|
||||
cover_image: null,
|
||||
is_default: false,
|
||||
tags: ['external'],
|
||||
capabilities: {
|
||||
modes: ['txt2img'],
|
||||
supports_reference_images: false,
|
||||
supports_seed: true,
|
||||
max_images_per_request: 1,
|
||||
max_image_size: null,
|
||||
allowed_aspect_ratios: ['1:1', '16:9'],
|
||||
max_reference_images: null,
|
||||
mask_format: 'none',
|
||||
input_image_required_for: null,
|
||||
},
|
||||
default_settings: null,
|
||||
...overrides,
|
||||
});
|
||||
|
||||
describe('external model aspect ratios (bbox)', () => {
|
||||
test('uses allowed aspect ratios for external models', () => {
|
||||
const model = createExternalModel();
|
||||
expect(model.capabilities.allowed_aspect_ratios).toEqual(['1:1', '16:9']);
|
||||
});
|
||||
});
|
||||
@@ -3,10 +3,16 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { bboxAspectRatioIdChanged } from 'features/controlLayers/store/canvasSlice';
|
||||
import { useCanvasIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import {
|
||||
aspectRatioIdChanged,
|
||||
selectAllowedAspectRatioIDs,
|
||||
selectAspectRatioSizes,
|
||||
selectHasFixedDimensionSizes,
|
||||
} from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectAspectRatioID } from 'features/controlLayers/store/selectors';
|
||||
import { isAspectRatioID, zAspectRatioID } from 'features/controlLayers/store/types';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCaretDownBold } from 'react-icons/pi';
|
||||
|
||||
@@ -15,24 +21,33 @@ export const BboxAspectRatioSelect = memo(() => {
|
||||
const dispatch = useAppDispatch();
|
||||
const id = useAppSelector(selectAspectRatioID);
|
||||
const isStaging = useCanvasIsStaging();
|
||||
const allowedAspectRatios = useAppSelector(selectAllowedAspectRatioIDs);
|
||||
const aspectRatioSizes = useAppSelector(selectAspectRatioSizes);
|
||||
const hasFixedSizes = useAppSelector(selectHasFixedDimensionSizes);
|
||||
const options = useMemo(() => allowedAspectRatios ?? zAspectRatioID.options, [allowedAspectRatios]);
|
||||
|
||||
const onChange = useCallback<ChangeEventHandler<HTMLSelectElement>>(
|
||||
(e) => {
|
||||
if (!isAspectRatioID(e.target.value)) {
|
||||
return;
|
||||
}
|
||||
dispatch(bboxAspectRatioIdChanged({ id: e.target.value }));
|
||||
const fixedSize = aspectRatioSizes?.[e.target.value] ?? undefined;
|
||||
dispatch(bboxAspectRatioIdChanged({ id: e.target.value, fixedSize }));
|
||||
// For external models with fixed sizes, also sync to params so buildExternalGraph uses correct dimensions
|
||||
if (fixedSize) {
|
||||
dispatch(aspectRatioIdChanged({ id: e.target.value, fixedSize }));
|
||||
}
|
||||
},
|
||||
[dispatch]
|
||||
[dispatch, aspectRatioSizes]
|
||||
);
|
||||
|
||||
return (
|
||||
<FormControl isDisabled={isStaging}>
|
||||
<FormControl isDisabled={isStaging || hasFixedSizes}>
|
||||
<InformationalPopover feature="paramAspect">
|
||||
<FormLabel>{t('parameters.aspect')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Select size="sm" value={id} onChange={onChange} cursor="pointer" iconSize="0.75rem" icon={<PiCaretDownBold />}>
|
||||
{zAspectRatioID.options.map((ratio) => (
|
||||
{options.map((ratio) => (
|
||||
<option key={ratio} value={ratio}>
|
||||
{ratio}
|
||||
</option>
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { bboxDimensionsSwapped } from 'features/controlLayers/store/canvasSlice';
|
||||
import { useCanvasIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { selectHasFixedDimensionSizes } from 'features/controlLayers/store/paramsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowsDownUpBold } from 'react-icons/pi';
|
||||
@@ -10,6 +11,7 @@ export const BboxSwapDimensionsButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const isStaging = useCanvasIsStaging();
|
||||
const hasFixedSizes = useAppSelector(selectHasFixedDimensionSizes);
|
||||
const onClick = useCallback(() => {
|
||||
dispatch(bboxDimensionsSwapped());
|
||||
}, [dispatch]);
|
||||
@@ -21,7 +23,7 @@ export const BboxSwapDimensionsButton = memo(() => {
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
icon={<PiArrowsDownUpBold />}
|
||||
isDisabled={isStaging}
|
||||
isDisabled={isStaging || hasFixedSizes}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useCanvasIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { selectHasFixedDimensionSizes } from 'features/controlLayers/store/paramsSlice';
|
||||
|
||||
export const useIsBboxSizeLocked = () => {
|
||||
const isStaging = useCanvasIsStaging();
|
||||
return isStaging;
|
||||
const hasFixedSizes = useAppSelector(selectHasFixedDimensionSizes);
|
||||
return isStaging || hasFixedSizes;
|
||||
};
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
import type { ExternalApiModelConfig } from 'services/api/types';
|
||||
import { describe, expect, test } from 'vitest';
|
||||
|
||||
const createExternalModel = (overrides: Partial<ExternalApiModelConfig> = {}): ExternalApiModelConfig => ({
|
||||
key: 'external-test',
|
||||
name: 'External Test',
|
||||
base: 'external',
|
||||
type: 'external_image_generator',
|
||||
format: 'external_api',
|
||||
provider_id: 'gemini',
|
||||
provider_model_id: 'gemini-2.5-flash-image',
|
||||
description: 'Test model',
|
||||
source: 'external://gemini/gemini-2.5-flash-image',
|
||||
source_type: 'url',
|
||||
source_api_response: null,
|
||||
path: '',
|
||||
file_size: 0,
|
||||
hash: 'external:gemini:gemini-2.5-flash-image',
|
||||
cover_image: null,
|
||||
is_default: false,
|
||||
tags: ['external'],
|
||||
capabilities: {
|
||||
modes: ['txt2img'],
|
||||
supports_reference_images: false,
|
||||
supports_seed: true,
|
||||
max_images_per_request: 1,
|
||||
max_image_size: null,
|
||||
allowed_aspect_ratios: ['1:1', '16:9'],
|
||||
max_reference_images: null,
|
||||
mask_format: 'none',
|
||||
input_image_required_for: null,
|
||||
},
|
||||
default_settings: null,
|
||||
...overrides,
|
||||
});
|
||||
|
||||
describe('external model aspect ratios', () => {
|
||||
test('uses allowed aspect ratios for external models', () => {
|
||||
const model = createExternalModel();
|
||||
expect(model.capabilities.allowed_aspect_ratios).toEqual(['1:1', '16:9']);
|
||||
});
|
||||
});
|
||||
@@ -1,7 +1,12 @@
|
||||
import { FormControl, FormLabel, Select } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { aspectRatioIdChanged, selectAspectRatioID } from 'features/controlLayers/store/paramsSlice';
|
||||
import {
|
||||
aspectRatioIdChanged,
|
||||
selectAllowedAspectRatioIDs,
|
||||
selectAspectRatioID,
|
||||
selectAspectRatioSizes,
|
||||
} from 'features/controlLayers/store/paramsSlice';
|
||||
import { isAspectRatioID, zAspectRatioID } from 'features/controlLayers/store/types';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
@@ -12,15 +17,19 @@ export const DimensionsAspectRatioSelect = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const id = useAppSelector(selectAspectRatioID);
|
||||
const allowedAspectRatios = useAppSelector(selectAllowedAspectRatioIDs);
|
||||
const aspectRatioSizes = useAppSelector(selectAspectRatioSizes);
|
||||
const options = allowedAspectRatios ?? zAspectRatioID.options;
|
||||
|
||||
const onChange = useCallback<ChangeEventHandler<HTMLSelectElement>>(
|
||||
(e) => {
|
||||
if (!isAspectRatioID(e.target.value)) {
|
||||
return;
|
||||
}
|
||||
dispatch(aspectRatioIdChanged({ id: e.target.value }));
|
||||
const fixedSize = aspectRatioSizes?.[e.target.value] ?? undefined;
|
||||
dispatch(aspectRatioIdChanged({ id: e.target.value, fixedSize }));
|
||||
},
|
||||
[dispatch]
|
||||
[dispatch, aspectRatioSizes]
|
||||
);
|
||||
|
||||
return (
|
||||
@@ -29,7 +38,7 @@ export const DimensionsAspectRatioSelect = memo(() => {
|
||||
<FormLabel>{t('parameters.aspect')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Select size="sm" value={id} onChange={onChange} cursor="pointer" iconSize="0.75rem" icon={<PiCaretDownBold />}>
|
||||
{zAspectRatioID.options.map((ratio) => (
|
||||
{options.map((ratio) => (
|
||||
<option key={ratio} value={ratio}>
|
||||
{ratio}
|
||||
</option>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { heightChanged, selectHeight } from 'features/controlLayers/store/paramsSlice';
|
||||
import { heightChanged, selectHasFixedDimensionSizes, selectHeight } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectGridSize, selectOptimalDimension } from 'features/controlLayers/store/selectors';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -22,6 +22,7 @@ export const DimensionsHeight = memo(() => {
|
||||
const optimalDimension = useAppSelector(selectOptimalDimension);
|
||||
const height = useAppSelector(selectHeight);
|
||||
const gridSize = useAppSelector(selectGridSize);
|
||||
const hasFixedSizes = useAppSelector(selectHasFixedDimensionSizes);
|
||||
|
||||
const onChange = useCallback(
|
||||
(v: number) => {
|
||||
@@ -33,7 +34,7 @@ export const DimensionsHeight = memo(() => {
|
||||
const marks = useMemo(() => [CONSTRAINTS.sliderMin, optimalDimension, CONSTRAINTS.sliderMax], [optimalDimension]);
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<FormControl isDisabled={hasFixedSizes}>
|
||||
<InformationalPopover feature="paramHeight">
|
||||
<FormLabel>{t('parameters.height')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { aspectRatioLockToggled, selectAspectRatioIsLocked } from 'features/controlLayers/store/paramsSlice';
|
||||
import {
|
||||
aspectRatioLockToggled,
|
||||
selectAspectRatioIsLocked,
|
||||
selectHasFixedDimensionSizes,
|
||||
} from 'features/controlLayers/store/paramsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiLockSimpleFill, PiLockSimpleOpenBold } from 'react-icons/pi';
|
||||
@@ -9,6 +13,7 @@ export const DimensionsLockAspectRatioButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const isLocked = useAppSelector(selectAspectRatioIsLocked);
|
||||
const hasFixedSizes = useAppSelector(selectHasFixedDimensionSizes);
|
||||
|
||||
const onClick = useCallback(() => {
|
||||
dispatch(aspectRatioLockToggled());
|
||||
@@ -22,6 +27,7 @@ export const DimensionsLockAspectRatioButton = memo(() => {
|
||||
variant={isLocked ? 'outline' : 'ghost'}
|
||||
size="sm"
|
||||
icon={isLocked ? <PiLockSimpleFill /> : <PiLockSimpleOpenBold />}
|
||||
isDisabled={hasFixedSizes}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectHeight, selectWidth, sizeOptimized } from 'features/controlLayers/store/paramsSlice';
|
||||
import {
|
||||
selectHasFixedDimensionSizes,
|
||||
selectHeight,
|
||||
selectWidth,
|
||||
sizeOptimized,
|
||||
} from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectOptimalDimension } from 'features/controlLayers/store/selectors';
|
||||
import { getIsSizeTooLarge, getIsSizeTooSmall } from 'features/parameters/util/optimalDimension';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
@@ -13,6 +18,7 @@ export const DimensionsSetOptimalSizeButton = memo(() => {
|
||||
const width = useAppSelector(selectWidth);
|
||||
const height = useAppSelector(selectHeight);
|
||||
const optimalDimension = useAppSelector(selectOptimalDimension);
|
||||
const hasFixedSizes = useAppSelector(selectHasFixedDimensionSizes);
|
||||
const isSizeTooSmall = useMemo(
|
||||
() => getIsSizeTooSmall(width, height, optimalDimension),
|
||||
[height, width, optimalDimension]
|
||||
@@ -43,6 +49,7 @@ export const DimensionsSetOptimalSizeButton = memo(() => {
|
||||
size="sm"
|
||||
icon={<PiSparkleFill />}
|
||||
colorScheme={isSizeTooSmall || isSizeTooLarge ? 'warning' : 'base'}
|
||||
isDisabled={hasFixedSizes}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { dimensionsSwapped } from 'features/controlLayers/store/paramsSlice';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { dimensionsSwapped, selectHasFixedDimensionSizes } from 'features/controlLayers/store/paramsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowsDownUpBold } from 'react-icons/pi';
|
||||
@@ -8,6 +8,7 @@ import { PiArrowsDownUpBold } from 'react-icons/pi';
|
||||
export const DimensionsSwapButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const hasFixedSizes = useAppSelector(selectHasFixedDimensionSizes);
|
||||
const onClick = useCallback(() => {
|
||||
dispatch(dimensionsSwapped());
|
||||
}, [dispatch]);
|
||||
@@ -19,6 +20,7 @@ export const DimensionsSwapButton = memo(() => {
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
icon={<PiArrowsDownUpBold />}
|
||||
isDisabled={hasFixedSizes}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { selectWidth, widthChanged } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectHasFixedDimensionSizes, selectWidth, widthChanged } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectGridSize, selectOptimalDimension } from 'features/controlLayers/store/selectors';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -22,6 +22,7 @@ export const DimensionsWidth = memo(() => {
|
||||
const width = useAppSelector(selectWidth);
|
||||
const optimalDimension = useAppSelector(selectOptimalDimension);
|
||||
const gridSize = useAppSelector(selectGridSize);
|
||||
const hasFixedSizes = useAppSelector(selectHasFixedDimensionSizes);
|
||||
|
||||
const onChange = useCallback(
|
||||
(v: number) => {
|
||||
@@ -33,7 +34,7 @@ export const DimensionsWidth = memo(() => {
|
||||
const marks = useMemo(() => [CONSTRAINTS.sliderMin, optimalDimension, CONSTRAINTS.sliderMax], [optimalDimension]);
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<FormControl isDisabled={hasFixedSizes}>
|
||||
<InformationalPopover feature="paramWidth">
|
||||
<FormLabel>{t('parameters.width')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
import { FormControl, FormLabel, Select } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
resolutionPresetSelected,
|
||||
selectAspectRatioID,
|
||||
selectImageSize,
|
||||
selectResolutionPresets,
|
||||
} from 'features/controlLayers/store/paramsSlice';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCaretDownBold } from 'react-icons/pi';
|
||||
|
||||
const makeKey = (aspectRatio: string, imageSize: string) => `${aspectRatio}|${imageSize}`;
|
||||
|
||||
export const ExternalModelImageSizeSelect = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const presets = useAppSelector(selectResolutionPresets);
|
||||
const currentAspectRatio = useAppSelector(selectAspectRatioID);
|
||||
const currentImageSize = useAppSelector(selectImageSize);
|
||||
|
||||
const presetMap = useMemo(() => {
|
||||
if (!presets) {
|
||||
return null;
|
||||
}
|
||||
const map = new Map<string, (typeof presets)[number]>();
|
||||
for (const preset of presets) {
|
||||
map.set(makeKey(preset.aspect_ratio, preset.image_size), preset);
|
||||
}
|
||||
return map;
|
||||
}, [presets]);
|
||||
|
||||
const selectedKey = useMemo(() => {
|
||||
if (!presets || presets.length === 0) {
|
||||
return '';
|
||||
}
|
||||
if (currentImageSize && currentAspectRatio) {
|
||||
const key = makeKey(currentAspectRatio, currentImageSize);
|
||||
if (presetMap?.has(key)) {
|
||||
return key;
|
||||
}
|
||||
}
|
||||
// Fallback to first preset
|
||||
return makeKey(presets[0]!.aspect_ratio, presets[0]!.image_size);
|
||||
}, [presets, presetMap, currentAspectRatio, currentImageSize]);
|
||||
|
||||
const onChange = useCallback<ChangeEventHandler<HTMLSelectElement>>(
|
||||
(e) => {
|
||||
const preset = presetMap?.get(e.target.value);
|
||||
if (!preset) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
resolutionPresetSelected({
|
||||
imageSize: preset.image_size,
|
||||
aspectRatio: preset.aspect_ratio,
|
||||
width: preset.width,
|
||||
height: preset.height,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, presetMap]
|
||||
);
|
||||
|
||||
if (!presets || presets.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<FormLabel>{t('parameters.resolution')}</FormLabel>
|
||||
<Select
|
||||
size="sm"
|
||||
value={selectedKey}
|
||||
onChange={onChange}
|
||||
cursor="pointer"
|
||||
iconSize="0.75rem"
|
||||
icon={<PiCaretDownBold />}
|
||||
>
|
||||
{presets.map((preset) => {
|
||||
const key = makeKey(preset.aspect_ratio, preset.image_size);
|
||||
return (
|
||||
<option key={key} value={key}>
|
||||
{preset.label}
|
||||
</option>
|
||||
);
|
||||
})}
|
||||
</Select>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
ExternalModelImageSizeSelect.displayName = 'ExternalModelImageSizeSelect';
|
||||
@@ -0,0 +1,68 @@
|
||||
import { FormControl, FormLabel, Select } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
aspectRatioIdChanged,
|
||||
selectAspectRatioID,
|
||||
selectAspectRatioSizes,
|
||||
} from 'features/controlLayers/store/paramsSlice';
|
||||
import { isAspectRatioID } from 'features/controlLayers/store/types';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCaretDownBold } from 'react-icons/pi';
|
||||
|
||||
export const ExternalModelResolutionSelect = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const aspectRatioID = useAppSelector(selectAspectRatioID);
|
||||
const aspectRatioSizes = useAppSelector(selectAspectRatioSizes);
|
||||
|
||||
const options = useMemo(() => {
|
||||
if (!aspectRatioSizes) {
|
||||
return [];
|
||||
}
|
||||
return Object.entries(aspectRatioSizes).map(([ratio, size]) => ({
|
||||
ratio,
|
||||
label: `${ratio} (${size.width}×${size.height})`,
|
||||
size,
|
||||
}));
|
||||
}, [aspectRatioSizes]);
|
||||
|
||||
const onChange = useCallback<ChangeEventHandler<HTMLSelectElement>>(
|
||||
(e) => {
|
||||
const ratio = e.target.value;
|
||||
if (!isAspectRatioID(ratio)) {
|
||||
return;
|
||||
}
|
||||
const fixedSize = aspectRatioSizes?.[ratio] ?? undefined;
|
||||
dispatch(aspectRatioIdChanged({ id: ratio, fixedSize }));
|
||||
},
|
||||
[dispatch, aspectRatioSizes]
|
||||
);
|
||||
|
||||
if (!aspectRatioSizes) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<FormLabel>{t('parameters.resolution')}</FormLabel>
|
||||
<Select
|
||||
size="sm"
|
||||
value={aspectRatioID}
|
||||
onChange={onChange}
|
||||
cursor="pointer"
|
||||
iconSize="0.75rem"
|
||||
icon={<PiCaretDownBold />}
|
||||
>
|
||||
{options.map(({ ratio, label }) => (
|
||||
<option key={ratio} value={ratio}>
|
||||
{label}
|
||||
</option>
|
||||
))}
|
||||
</Select>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
ExternalModelResolutionSelect.displayName = 'ExternalModelResolutionSelect';
|
||||
74
invokeai/frontend/web/src/features/parameters/components/External/GeminiProviderOptions.tsx
vendored
Normal file
74
invokeai/frontend/web/src/features/parameters/components/External/GeminiProviderOptions.tsx
vendored
Normal file
@@ -0,0 +1,74 @@
|
||||
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel, Select } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
geminiTemperatureChanged,
|
||||
geminiThinkingLevelChanged,
|
||||
selectGeminiTemperature,
|
||||
selectGeminiThinkingLevel,
|
||||
} from 'features/controlLayers/store/paramsSlice';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCaretDownBold } from 'react-icons/pi';
|
||||
|
||||
const TEMPERATURE_MARKS = [0, 1, 2];
|
||||
|
||||
export const GeminiProviderOptions = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const temperature = useAppSelector(selectGeminiTemperature);
|
||||
const thinkingLevel = useAppSelector(selectGeminiThinkingLevel);
|
||||
|
||||
const onTemperatureChange = useCallback((v: number) => dispatch(geminiTemperatureChanged(v)), [dispatch]);
|
||||
|
||||
const onThinkingLevelChange = useCallback<ChangeEventHandler<HTMLSelectElement>>(
|
||||
(e) => {
|
||||
const value = e.target.value;
|
||||
dispatch(geminiThinkingLevelChanged(value === '' ? null : (value as 'minimal' | 'high')));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<FormControl>
|
||||
<FormLabel>{t('parameters.temperature', 'Temperature')}</FormLabel>
|
||||
<CompositeSlider
|
||||
value={temperature ?? 1}
|
||||
defaultValue={1}
|
||||
min={0}
|
||||
max={2}
|
||||
step={0.1}
|
||||
fineStep={0.05}
|
||||
onChange={onTemperatureChange}
|
||||
marks={TEMPERATURE_MARKS}
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
value={temperature ?? 1}
|
||||
defaultValue={1}
|
||||
min={0}
|
||||
max={2}
|
||||
step={0.1}
|
||||
fineStep={0.05}
|
||||
onChange={onTemperatureChange}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel>{t('parameters.thinkingLevel', 'Thinking Level')}</FormLabel>
|
||||
<Select
|
||||
size="sm"
|
||||
value={thinkingLevel ?? ''}
|
||||
onChange={onThinkingLevelChange}
|
||||
icon={<PiCaretDownBold />}
|
||||
iconSize="0.75rem"
|
||||
>
|
||||
<option value="">Default</option>
|
||||
<option value="minimal">Minimal</option>
|
||||
<option value="high">High</option>
|
||||
</Select>
|
||||
</FormControl>
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
GeminiProviderOptions.displayName = 'GeminiProviderOptions';
|
||||
84
invokeai/frontend/web/src/features/parameters/components/External/OpenAIProviderOptions.tsx
vendored
Normal file
84
invokeai/frontend/web/src/features/parameters/components/External/OpenAIProviderOptions.tsx
vendored
Normal file
@@ -0,0 +1,84 @@
|
||||
import { FormControl, FormLabel, Select } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
openaiBackgroundChanged,
|
||||
openaiInputFidelityChanged,
|
||||
openaiQualityChanged,
|
||||
selectOpenaiBackground,
|
||||
selectOpenaiInputFidelity,
|
||||
selectOpenaiQuality,
|
||||
} from 'features/controlLayers/store/paramsSlice';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCaretDownBold } from 'react-icons/pi';
|
||||
|
||||
export const OpenAIProviderOptions = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const quality = useAppSelector(selectOpenaiQuality);
|
||||
const background = useAppSelector(selectOpenaiBackground);
|
||||
const inputFidelity = useAppSelector(selectOpenaiInputFidelity);
|
||||
|
||||
const onQualityChange = useCallback<ChangeEventHandler<HTMLSelectElement>>(
|
||||
(e) => dispatch(openaiQualityChanged(e.target.value as 'auto' | 'high' | 'medium' | 'low')),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onBackgroundChange = useCallback<ChangeEventHandler<HTMLSelectElement>>(
|
||||
(e) => dispatch(openaiBackgroundChanged(e.target.value as 'auto' | 'transparent' | 'opaque')),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onInputFidelityChange = useCallback<ChangeEventHandler<HTMLSelectElement>>(
|
||||
(e) => {
|
||||
const value = e.target.value;
|
||||
dispatch(openaiInputFidelityChanged(value === '' ? null : (value as 'low' | 'high')));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<FormControl>
|
||||
<FormLabel>{t('parameters.quality', 'Quality')}</FormLabel>
|
||||
<Select size="sm" value={quality} onChange={onQualityChange} icon={<PiCaretDownBold />} iconSize="0.75rem">
|
||||
<option value="auto">Auto</option>
|
||||
<option value="high">High</option>
|
||||
<option value="medium">Medium</option>
|
||||
<option value="low">Low</option>
|
||||
</Select>
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel>{t('parameters.background', 'Background')}</FormLabel>
|
||||
<Select
|
||||
size="sm"
|
||||
value={background}
|
||||
onChange={onBackgroundChange}
|
||||
icon={<PiCaretDownBold />}
|
||||
iconSize="0.75rem"
|
||||
>
|
||||
<option value="auto">Auto</option>
|
||||
<option value="transparent">Transparent</option>
|
||||
<option value="opaque">Opaque</option>
|
||||
</Select>
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel>{t('parameters.inputFidelity', 'Input Fidelity')}</FormLabel>
|
||||
<Select
|
||||
size="sm"
|
||||
value={inputFidelity ?? ''}
|
||||
onChange={onInputFidelityChange}
|
||||
icon={<PiCaretDownBold />}
|
||||
iconSize="0.75rem"
|
||||
>
|
||||
<option value="">Default</option>
|
||||
<option value="low">Low</option>
|
||||
<option value="high">High</option>
|
||||
</Select>
|
||||
</FormControl>
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
OpenAIProviderOptions.displayName = 'OpenAIProviderOptions';
|
||||
@@ -0,0 +1,60 @@
|
||||
import type {
|
||||
ExternalApiModelConfig,
|
||||
ExternalApiModelDefaultSettings,
|
||||
ExternalImageSize,
|
||||
ExternalModelCapabilities,
|
||||
} from 'services/api/types';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { isExternalModelUnsupportedForTab } from './mainModelPickerUtils';
|
||||
|
||||
const createExternalConfig = (modes: ExternalModelCapabilities['modes']): ExternalApiModelConfig => {
|
||||
const maxImageSize: ExternalImageSize = { width: 1024, height: 1024 };
|
||||
const defaultSettings: ExternalApiModelDefaultSettings = { width: 1024, height: 1024, steps: 30 };
|
||||
|
||||
return {
|
||||
key: 'external-test',
|
||||
hash: 'external:openai:gpt-image-1',
|
||||
path: 'external://openai/gpt-image-1',
|
||||
file_size: 0,
|
||||
name: 'External Test',
|
||||
description: null,
|
||||
source: 'external://openai/gpt-image-1',
|
||||
source_type: 'url',
|
||||
source_api_response: null,
|
||||
cover_image: null,
|
||||
base: 'external',
|
||||
type: 'external_image_generator',
|
||||
format: 'external_api',
|
||||
provider_id: 'openai',
|
||||
provider_model_id: 'gpt-image-1',
|
||||
capabilities: {
|
||||
modes,
|
||||
supports_reference_images: false,
|
||||
max_image_size: maxImageSize,
|
||||
},
|
||||
default_settings: defaultSettings,
|
||||
tags: ['external'],
|
||||
is_default: false,
|
||||
};
|
||||
};
|
||||
|
||||
describe('isExternalModelUnsupportedForTab', () => {
|
||||
it('disables external models without txt2img for generate', () => {
|
||||
const model = createExternalConfig(['img2img', 'inpaint']);
|
||||
|
||||
expect(isExternalModelUnsupportedForTab(model, 'generate')).toBe(true);
|
||||
});
|
||||
|
||||
it('allows external models with txt2img for generate', () => {
|
||||
const model = createExternalConfig(['txt2img']);
|
||||
|
||||
expect(isExternalModelUnsupportedForTab(model, 'generate')).toBe(false);
|
||||
});
|
||||
|
||||
it('allows external models on canvas', () => {
|
||||
const model = createExternalConfig(['inpaint']);
|
||||
|
||||
expect(isExternalModelUnsupportedForTab(model, 'canvas')).toBe(false);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,14 @@
|
||||
import type { TabName } from 'features/ui/store/uiTypes';
|
||||
import { type AnyModelConfigWithExternal, isExternalApiModelConfig } from 'services/api/types';
|
||||
|
||||
export const isExternalModelUnsupportedForTab = (model: AnyModelConfigWithExternal, tab: TabName): boolean => {
|
||||
if (!isExternalApiModelConfig(model)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (tab === 'generate') {
|
||||
return !model.capabilities.modes.includes('txt2img');
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { BoxProps, ButtonProps, SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import {
|
||||
Badge,
|
||||
Button,
|
||||
Flex,
|
||||
Icon,
|
||||
@@ -36,7 +37,11 @@ import { Trans, useTranslation } from 'react-i18next';
|
||||
import { PiCaretDownBold, PiLinkSimple } from 'react-icons/pi';
|
||||
import { useGetSetupStatusQuery } from 'services/api/endpoints/auth';
|
||||
import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import {
|
||||
type AnyModelConfigWithExternal,
|
||||
type ExternalApiModelConfig,
|
||||
isExternalApiModelConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
const selectSelectedModelKeys = createMemoizedSelector(selectParamsSlice, selectLoRAsSlice, (params, loras) => {
|
||||
const keys: string[] = [];
|
||||
@@ -67,7 +72,7 @@ const selectSelectedModelKeys = createMemoizedSelector(selectParamsSlice, select
|
||||
type WithStarred<T> = T & { starred?: boolean };
|
||||
|
||||
// Type for models with starred field
|
||||
const getOptionId = <T extends AnyModelConfig>(modelConfig: WithStarred<T>) => modelConfig.key;
|
||||
const getOptionId = <T extends AnyModelConfigWithExternal>(modelConfig: WithStarred<T>) => modelConfig.key;
|
||||
|
||||
const ModelManagerLink = memo((props: ButtonProps) => {
|
||||
const onClick = useCallback(() => {
|
||||
@@ -123,19 +128,17 @@ const NoOptionsFallback = memo(({ noOptionsText }: { noOptionsText?: string }) =
|
||||
});
|
||||
NoOptionsFallback.displayName = 'NoOptionsFallback';
|
||||
|
||||
const getGroupIDFromModelConfig = (modelConfig: AnyModelConfig): string => {
|
||||
return modelConfig.base;
|
||||
};
|
||||
const getGroupIDFromModelConfig = (modelConfig: AnyModelConfigWithExternal): string => modelConfig.base;
|
||||
|
||||
const getGroupNameFromModelConfig = (modelConfig: AnyModelConfig): string => {
|
||||
const getGroupNameFromModelConfig = (modelConfig: AnyModelConfigWithExternal): string => {
|
||||
return MODEL_BASE_TO_LONG_NAME[modelConfig.base];
|
||||
};
|
||||
|
||||
const getGroupShortNameFromModelConfig = (modelConfig: AnyModelConfig): string => {
|
||||
const getGroupShortNameFromModelConfig = (modelConfig: AnyModelConfigWithExternal): string => {
|
||||
return MODEL_BASE_TO_SHORT_NAME[modelConfig.base];
|
||||
};
|
||||
|
||||
const getGroupColorSchemeFromModelConfig = (modelConfig: AnyModelConfig): string => {
|
||||
const getGroupColorSchemeFromModelConfig = (modelConfig: AnyModelConfigWithExternal): string => {
|
||||
return MODEL_BASE_TO_COLOR[modelConfig.base];
|
||||
};
|
||||
|
||||
@@ -162,7 +165,7 @@ const removeStarred = <T,>(obj: WithStarred<T>): T => {
|
||||
};
|
||||
|
||||
export const ModelPicker = typedMemo(
|
||||
<T extends AnyModelConfig = AnyModelConfig>({
|
||||
<T extends AnyModelConfigWithExternal = AnyModelConfigWithExternal>({
|
||||
pickerId,
|
||||
modelConfigs,
|
||||
selectedModelConfig,
|
||||
@@ -414,8 +417,10 @@ const optionNameSx: SystemStyleObject = {
|
||||
};
|
||||
|
||||
const PickerOptionComponent = typedMemo(
|
||||
<T extends AnyModelConfig>({ option, ...rest }: { option: WithStarred<T> } & BoxProps) => {
|
||||
<T extends AnyModelConfigWithExternal>({ option, ...rest }: { option: WithStarred<T> } & BoxProps) => {
|
||||
const { isCompactView } = usePickerContext<WithStarred<T>>();
|
||||
const externalOption = isExternalApiModelConfig(option) ? (option as ExternalApiModelConfig) : null;
|
||||
const providerLabel = externalOption ? externalOption.provider_id.toUpperCase() : null;
|
||||
|
||||
return (
|
||||
<Flex {...rest} sx={optionSx} data-is-compact={isCompactView}>
|
||||
@@ -426,6 +431,15 @@ const PickerOptionComponent = typedMemo(
|
||||
<Text className="picker-option" sx={optionNameSx} data-is-compact={isCompactView}>
|
||||
{option.name}
|
||||
</Text>
|
||||
{!isCompactView && externalOption && (
|
||||
<Badge
|
||||
colorScheme={MODEL_BASE_TO_COLOR[externalOption.base as BaseModelType]}
|
||||
variant="subtle"
|
||||
flexShrink={0}
|
||||
>
|
||||
{providerLabel}
|
||||
</Badge>
|
||||
)}
|
||||
<Spacer />
|
||||
{option.file_size > 0 && (
|
||||
<Text
|
||||
@@ -458,11 +472,13 @@ const BASE_KEYWORDS: { [key in BaseModelType]?: string[] } = {
|
||||
'sd-3': ['sd3', 'sd3.0', 'sd3.5', 'sd-3'],
|
||||
};
|
||||
|
||||
const isMatch = <T extends AnyModelConfig>(model: WithStarred<T>, searchTerm: string) => {
|
||||
const isMatch = <T extends AnyModelConfigWithExternal>(model: WithStarred<T>, searchTerm: string) => {
|
||||
const regex = getRegex(searchTerm);
|
||||
const bases = BASE_KEYWORDS[model.base] ?? [model.base];
|
||||
const externalModel = isExternalApiModelConfig(model) ? (model as ExternalApiModelConfig) : null;
|
||||
const externalSearch = externalModel ? ` ${externalModel.provider_id} ${externalModel.provider_model_id}` : '';
|
||||
const testString =
|
||||
`${model.name} ${bases.join(' ')} ${model.type} ${model.description ?? ''} ${model.format}`.toLowerCase();
|
||||
`${model.name} ${bases.join(' ')} ${model.type} ${model.description ?? ''} ${model.format}${externalSearch}`.toLowerCase();
|
||||
|
||||
if (testString.includes(searchTerm) || regex.test(testString)) {
|
||||
return true;
|
||||
|
||||
@@ -2,13 +2,19 @@ import { CompositeNumberInput, FormControl, FormLabel } from '@invoke-ai/ui-libr
|
||||
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from 'app/constants';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { selectSeed, selectShouldRandomizeSeed, setSeed } from 'features/controlLayers/store/paramsSlice';
|
||||
import {
|
||||
selectSeed,
|
||||
selectSeedControl,
|
||||
selectShouldRandomizeSeed,
|
||||
setSeed,
|
||||
} from 'features/controlLayers/store/paramsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const ParamSeedNumberInput = memo(() => {
|
||||
const seed = useAppSelector(selectSeed);
|
||||
const shouldRandomizeSeed = useAppSelector(selectShouldRandomizeSeed);
|
||||
const externalControl = useAppSelector(selectSeedControl);
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
@@ -22,9 +28,10 @@ export const ParamSeedNumberInput = memo(() => {
|
||||
<FormLabel>{t('parameters.seed')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<CompositeNumberInput
|
||||
step={1}
|
||||
min={NUMPY_RAND_MIN}
|
||||
max={NUMPY_RAND_MAX}
|
||||
step={externalControl?.coarse_step ?? 1}
|
||||
fineStep={externalControl?.fine_step ?? undefined}
|
||||
min={externalControl?.number_input_min ?? NUMPY_RAND_MIN}
|
||||
max={externalControl?.number_input_max ?? NUMPY_RAND_MAX}
|
||||
onChange={handleChangeSeed}
|
||||
value={seed}
|
||||
flexGrow={1}
|
||||
|
||||
@@ -3,6 +3,7 @@ import { roundToMultiple } from 'common/util/roundDownToMultiple';
|
||||
import { buildZodTypeGuard } from 'common/util/zodUtils';
|
||||
import {
|
||||
zAnimaSchedulerField,
|
||||
zExternalModelIdentifierField,
|
||||
zFluxDypeExponentField,
|
||||
zFluxDypePresetField,
|
||||
zFluxDypeScaleField,
|
||||
@@ -121,7 +122,7 @@ export const isParameterHeight = isParameterImageDimension;
|
||||
// #endregion
|
||||
|
||||
// #region Model
|
||||
export const zParameterModel = zModelIdentifierField;
|
||||
export const zParameterModel = z.union([zModelIdentifierField, zExternalModelIdentifierField]);
|
||||
export type ParameterModel = z.infer<typeof zParameterModel>;
|
||||
// #endregion
|
||||
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
import type {
|
||||
ExternalApiModelConfig,
|
||||
ExternalModelCapabilities,
|
||||
ExternalModelPanelControl,
|
||||
ExternalModelPanelSchema,
|
||||
ExternalPanelControlName,
|
||||
} from 'services/api/types';
|
||||
|
||||
type ExternalPanelName = keyof ExternalModelPanelSchema;
|
||||
|
||||
const buildExternalPanelSchemaFromCapabilities = (
|
||||
capabilities: ExternalModelCapabilities
|
||||
): ExternalModelPanelSchema => ({
|
||||
prompts: [...(capabilities.supports_reference_images ? [{ name: 'reference_images' as const }] : [])],
|
||||
image: [{ name: 'dimensions' }, ...(capabilities.supports_seed ? [{ name: 'seed' as const }] : [])],
|
||||
generation: [],
|
||||
});
|
||||
|
||||
const getExternalPanelSchema = (modelConfig: ExternalApiModelConfig): ExternalModelPanelSchema =>
|
||||
modelConfig.panel_schema ?? buildExternalPanelSchemaFromCapabilities(modelConfig.capabilities);
|
||||
|
||||
export const getExternalPanelControl = (
|
||||
modelConfig: ExternalApiModelConfig,
|
||||
panel: ExternalPanelName,
|
||||
controlName: ExternalPanelControlName
|
||||
): ExternalModelPanelControl | null =>
|
||||
getExternalPanelSchema(modelConfig)[panel].find((control) => control.name === controlName) ?? null;
|
||||
|
||||
export const hasExternalPanelControl = (
|
||||
modelConfig: ExternalApiModelConfig,
|
||||
panel: ExternalPanelName,
|
||||
controlName: ExternalPanelControlName
|
||||
): boolean => getExternalPanelControl(modelConfig, panel, controlName) !== null;
|
||||
@@ -7,9 +7,11 @@ import { withResult, withResultAsync } from 'common/util/result';
|
||||
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { positivePromptAddedToHistory, selectPositivePrompt } from 'features/controlLayers/store/paramsSlice';
|
||||
import type { BaseModelType } from 'features/nodes/types/common';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { buildAnimaGraph } from 'features/nodes/util/graph/generation/buildAnimaGraph';
|
||||
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
|
||||
import { buildExternalGraph } from 'features/nodes/util/graph/generation/buildExternalGraph';
|
||||
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
|
||||
import { buildQwenImageGraph } from 'features/nodes/util/graph/generation/buildQwenImageGraph';
|
||||
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
|
||||
@@ -63,6 +65,8 @@ const enqueueCanvas = async (store: AppStore, canvasManager: CanvasManager, prep
|
||||
return await buildQwenImageGraph(graphBuilderArg);
|
||||
case 'z-image':
|
||||
return await buildZImageGraph(graphBuilderArg);
|
||||
case 'external':
|
||||
return await buildExternalGraph(graphBuilderArg);
|
||||
case 'anima':
|
||||
return await buildAnimaGraph(graphBuilderArg);
|
||||
default:
|
||||
@@ -97,7 +101,7 @@ const enqueueCanvas = async (store: AppStore, canvasManager: CanvasManager, prep
|
||||
prepareLinearUIBatch({
|
||||
state,
|
||||
g,
|
||||
base,
|
||||
base: base as BaseModelType,
|
||||
prepend,
|
||||
seedNode: seed,
|
||||
positivePromptNode: positivePrompt,
|
||||
|
||||
@@ -5,9 +5,11 @@ import { useAppStore } from 'app/store/storeHooks';
|
||||
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
|
||||
import { withResult, withResultAsync } from 'common/util/result';
|
||||
import { positivePromptAddedToHistory, selectPositivePrompt } from 'features/controlLayers/store/paramsSlice';
|
||||
import type { BaseModelType } from 'features/nodes/types/common';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { buildAnimaGraph } from 'features/nodes/util/graph/generation/buildAnimaGraph';
|
||||
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
|
||||
import { buildExternalGraph } from 'features/nodes/util/graph/generation/buildExternalGraph';
|
||||
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
|
||||
import { buildQwenImageGraph } from 'features/nodes/util/graph/generation/buildQwenImageGraph';
|
||||
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
|
||||
@@ -56,6 +58,8 @@ const enqueueGenerate = async (store: AppStore, prepend: boolean) => {
|
||||
return await buildQwenImageGraph(graphBuilderArg);
|
||||
case 'z-image':
|
||||
return await buildZImageGraph(graphBuilderArg);
|
||||
case 'external':
|
||||
return await buildExternalGraph(graphBuilderArg);
|
||||
case 'anima':
|
||||
return await buildAnimaGraph(graphBuilderArg);
|
||||
default:
|
||||
@@ -90,7 +94,7 @@ const enqueueGenerate = async (store: AppStore, prepend: boolean) => {
|
||||
prepareLinearUIBatch({
|
||||
state,
|
||||
g,
|
||||
base,
|
||||
base: base as BaseModelType,
|
||||
prepend,
|
||||
seedNode: seed,
|
||||
positivePromptNode: positivePrompt,
|
||||
|
||||
@@ -2,6 +2,7 @@ import { logger } from 'app/logging/logger';
|
||||
import type { AppStore } from 'app/store/store';
|
||||
import { useAppStore } from 'app/store/storeHooks';
|
||||
import { positivePromptAddedToHistory, selectPositivePrompt } from 'features/controlLayers/store/paramsSlice';
|
||||
import type { BaseModelType } from 'features/nodes/types/common';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { buildMultidiffusionUpscaleGraph } from 'features/nodes/util/graph/buildMultidiffusionUpscaleGraph';
|
||||
import { useCallback } from 'react';
|
||||
@@ -26,7 +27,7 @@ const enqueueUpscaling = async (store: AppStore, prepend: boolean) => {
|
||||
const batchConfig = prepareLinearUIBatch({
|
||||
state,
|
||||
g,
|
||||
base,
|
||||
base: base as BaseModelType,
|
||||
prepend,
|
||||
seedNode: seed,
|
||||
positivePromptNode: positivePrompt,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user