mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-14 00:04:59 -05:00
Adds full support for managing model-to-model relationships in the UI and backend.
Introduces a RelatedModels subpanel for linking and unlinking models in model management.
- Adds REST API routes for adding, removing, and retrieving model relationships.
- New database migration: creates model_relationships table for bidirectional links.
- New service layer (model_relationships) for relationship management.
- Updated frontend: related models float to the top of the LoRA/Main grouped model comboboxes for quick access.
- Added a 'Show Only Related' toggle badge to the MainModelPicker filter bar.
**Amended commit to remove changes to ParamMainModelSelect.tsx and MainModelPicker.tsx to avoid conflict with upstream deletion/rewrite**
196 lines
7.5 KiB
Python
196 lines
7.5 KiB
Python
"""FastAPI route for model relationship records."""
|
|
|
|
from fastapi import HTTPException, APIRouter, Path, Body, status
|
|
from pydantic import BaseModel, Field
|
|
from typing import List
|
|
from invokeai.app.api.dependencies import ApiDependencies
|
|
|
|
# Router holding every model-relationship endpoint in this module; all routes
# registered on it share the "/v1/model_relationships" path prefix and are
# grouped under the "model_relationships" OpenAPI tag.
model_relationships_router = APIRouter(
    tags=["model_relationships"],
    prefix="/v1/model_relationships",
)
|
|
|
|
# === Schemas ===
|
|
|
|
class ModelRelationshipCreateRequest(BaseModel):
|
|
model_key_1: str = Field(..., description="The key of the first model in the relationship", examples=[
|
|
"aa3b247f-90c9-4416-bfcd-aeaa57a5339e",
|
|
"ac32b914-10ab-496e-a24a-3068724b9c35",
|
|
"d944abfd-c7c3-42e2-a4ff-da640b29b8b4",
|
|
"b1c2d3e4-f5a6-7890-abcd-ef1234567890",
|
|
"12345678-90ab-cdef-1234-567890abcdef",
|
|
"fedcba98-7654-3210-fedc-ba9876543210"
|
|
])
|
|
model_key_2: str = Field(..., description="The key of the second model in the relationship", examples=[
|
|
"3bb7c0eb-b6c8-469c-ad8c-4d69c06075e4",
|
|
"f0c3da4e-d9ff-42b5-a45c-23be75c887c9",
|
|
"38170dd8-f1e5-431e-866c-2c81f1277fcc",
|
|
"c57fea2d-7646-424c-b9ad-c0ba60fc68be",
|
|
"10f7807b-ab54-46a9-ab03-600e88c630a1",
|
|
"f6c1d267-cf87-4ee0-bee0-37e791eacab7"
|
|
])
|
|
|
|
class ModelRelationshipBatchRequest(BaseModel):
    # Request body for the batch lookup route: a list of model keys whose
    # related keys should be collected.
    # NOTE: `#` comments instead of a class docstring, since pydantic would
    # surface a docstring as the schema description.

    # Each entry in `examples` is one complete sample value for the field,
    # i.e. a whole list of keys.
    model_keys: List[str] = Field(
        ...,
        examples=[
            [
                "aa3b247f-90c9-4416-bfcd-aeaa57a5339e",
                "ac32b914-10ab-496e-a24a-3068724b9c35",
            ],
            [
                "b1c2d3e4-f5a6-7890-abcd-ef1234567890",
                "12345678-90ab-cdef-1234-567890abcdef",
                "fedcba98-7654-3210-fedc-ba9876543210",
            ],
            [
                "3bb7c0eb-b6c8-469c-ad8c-4d69c06075e4",
            ],
        ],
        description="List of model keys to fetch related models for",
    )
|
|
|
|
# === Routes ===
|
|
|
|
@model_relationships_router.get(
    "/i/{model_key}",
    operation_id="get_related_models",
    response_model=list[str],
    responses={
        200: {
            "description": "A list of related model keys was retrieved successfully",
            "content": {
                "application/json": {
                    "example": [
                        "15e9eb28-8cfe-47c9-b610-37907a79fc3c",
                        "71272e82-0e5f-46d5-bca9-9a61f4bd8a82",
                        "a5d7cd49-1b98-4534-a475-aeee4ccf5fa2",
                    ],
                },
            },
        },
        404: {"description": "The specified model could not be found"},
        422: {"description": "Validation error"},
    },
)
async def get_related_models(
    model_key: str = Path(..., description="The key of the model to get relationships for"),
) -> list[str]:
    """
    Get a list of model keys related to a given model.
    """
    # No `description=` is passed to the decorator, so FastAPI publishes the
    # docstring above as this route's OpenAPI description; keep it user-facing.
    # NOTE(review): any failure from the service is reported as a 500 here; the
    # 404 advertised in `responses` is never raised by this handler as written.
    try:
        relationships_service = ApiDependencies.invoker.services.model_relationships
        return relationships_service.get_related_model_keys(model_key)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@model_relationships_router.post(
    "/",
    status_code=status.HTTP_204_NO_CONTENT,
    responses={
        204: {"description": "The relationship was successfully created"},
        400: {"description": "Invalid model keys or self-referential relationship"},
        409: {"description": "The relationship already exists"},
        422: {"description": "Validation error"},
        500: {"description": "Internal server error"},
    },
    summary="Add Model Relationship",
    description="Creates a **bidirectional** relationship between two models, allowing each to reference the other as related.",
)
async def add_model_relationship(
    req: ModelRelationshipCreateRequest = Body(..., description="The model keys to relate"),
) -> None:
    """
    Add a relationship between two models.

    Relationships are bidirectional and will be accessible from both models.

    - Raises 400 if both keys are identical (a model cannot relate to itself).
    - Raises 409 if the relationship service raises ``ValueError``
      (documented for this route as: the relationship already exists).
    - Raises 500 for any other error from the relationship service.
    """
    # Validate before entering the try block: an HTTPException raised inside it
    # would be caught by the generic `except Exception` below and re-emitted as
    # a 500 ("400: Cannot relate a model to itself.") instead of a 400.
    if req.model_key_1 == req.model_key_2:
        raise HTTPException(status_code=400, detail="Cannot relate a model to itself.")

    try:
        ApiDependencies.invoker.services.model_relationships.add_model_relationship(
            req.model_key_1,
            req.model_key_2,
        )
    except ValueError as e:
        raise HTTPException(status_code=409, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@model_relationships_router.delete(
    "/",
    status_code=status.HTTP_204_NO_CONTENT,
    responses={
        204: {"description": "The relationship was successfully removed"},
        400: {"description": "Invalid model keys or self-referential relationship"},
        404: {"description": "The relationship does not exist"},
        422: {"description": "Validation error"},
        500: {"description": "Internal server error"},
    },
    summary="Remove Model Relationship",
    description="Removes a **bidirectional** relationship between two models. The relationship must already exist.",
)
async def remove_model_relationship(
    req: ModelRelationshipCreateRequest = Body(..., description="The model keys to disconnect"),
) -> None:
    """
    Removes a bidirectional relationship between two model keys.

    - Raises 400 if attempting to unlink a model from itself.
    - Raises 404 if the relationship service raises ``ValueError``
      (documented for this route as: the relationship does not exist).
    - Raises 500 for any other error from the relationship service.
    """
    # Reject self-links before the try block: an HTTPException raised inside it
    # would be caught by the generic `except Exception` below and re-emitted as
    # a 500 instead of the intended 400.
    if req.model_key_1 == req.model_key_2:
        raise HTTPException(status_code=400, detail="Cannot unlink a model from itself.")

    try:
        ApiDependencies.invoker.services.model_relationships.remove_model_relationship(
            req.model_key_1,
            req.model_key_2,
        )
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
@model_relationships_router.post(
    "/batch",
    operation_id="get_related_models_batch",
    response_model=List[str],
    responses={
        200: {
            "description": "Related model keys retrieved successfully",
            "content": {
                "application/json": {
                    "example": [
                        "ca562b14-995e-4a42-90c1-9528f1a5921d",
                        "cc0c2b8a-c62e-41d6-878e-cc74dde5ca8f",
                        "18ca7649-6a9e-47d5-bc17-41ab1e8cec81",
                        "7c12d1b2-0ef9-4bec-ba55-797b2d8f2ee1",
                        "c382eaa3-0e28-4ab0-9446-408667699aeb",
                        "71272e82-0e5f-46d5-bca9-9a61f4bd8a82",
                        "a5d7cd49-1b98-4534-a475-aeee4ccf5fa2",
                    ]
                }
            },
        },
        422: {"description": "Validation error"},
        500: {"description": "Internal server error"},
    },
    summary="Get Related Model Keys (Batch)",
    description="Retrieves all **unique related model keys** for a list of given models. This is useful for contextual suggestions or filtering.",
)
async def get_related_models_batch(
    req: ModelRelationshipBatchRequest = Body(..., description="Model keys to check for related connections"),
) -> list[str]:
    """
    Accepts multiple model keys and returns a flat list of all unique related keys.

    Useful when working with multiple selections in the UI or cross-model comparisons.

    Keys are returned in first-seen order: related keys of ``req.model_keys[0]``
    first (in the order the service returns them), then any new keys from the
    next input key, and so on. An empty ``model_keys`` list yields ``[]``.
    Any error from the relationship service is reported as a 500.
    """
    try:
        # A dict used as an insertion-ordered set: de-duplicates like the former
        # `set`, but gives a deterministic response order (iteration order of a
        # set of str is subject to per-process hash randomization).
        all_related: dict[str, None] = {}
        for key in req.model_keys:
            related = ApiDependencies.invoker.services.model_relationships.get_related_model_keys(key)
            all_related.update(dict.fromkeys(related))
        return list(all_related)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e