mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-12 23:35:12 -05:00
216 lines
7.7 KiB
Python
216 lines
7.7 KiB
Python
"""FastAPI route for model relationship records."""
|
|
|
|
from typing import List
|
|
|
|
from fastapi import APIRouter, Body, HTTPException, Path, status
|
|
from pydantic import BaseModel, Field
|
|
|
|
from invokeai.app.api.dependencies import ApiDependencies
|
|
|
|
# Router that groups every model-relationship endpoint in this module.
# The prefix is prepended to each route registered on it; the tag groups them in the OpenAPI docs.
model_relationships_router = APIRouter(
    prefix="/v1/model_relationships",
    tags=["model_relationships"],
)
|
|
|
|
# === Schemas ===
|
|
|
|
|
|
class ModelRelationshipCreateRequest(BaseModel):
    # Request body naming a pair of model keys. In this module the same shape is
    # accepted both when creating and when removing a relationship.
    # Both fields are required (default is the Ellipsis sentinel).

    # Key of the first model of the pair.
    model_key_1: str = Field(
        default=...,
        description="The key of the first model in the relationship",
        examples=[
            "aa3b247f-90c9-4416-bfcd-aeaa57a5339e",
            "ac32b914-10ab-496e-a24a-3068724b9c35",
            "d944abfd-c7c3-42e2-a4ff-da640b29b8b4",
            "b1c2d3e4-f5a6-7890-abcd-ef1234567890",
            "12345678-90ab-cdef-1234-567890abcdef",
            "fedcba98-7654-3210-fedc-ba9876543210",
        ],
    )

    # Key of the second model of the pair.
    model_key_2: str = Field(
        default=...,
        description="The key of the second model in the relationship",
        examples=[
            "3bb7c0eb-b6c8-469c-ad8c-4d69c06075e4",
            "f0c3da4e-d9ff-42b5-a45c-23be75c887c9",
            "38170dd8-f1e5-431e-866c-2c81f1277fcc",
            "c57fea2d-7646-424c-b9ad-c0ba60fc68be",
            "10f7807b-ab54-46a9-ab03-600e88c630a1",
            "f6c1d267-cf87-4ee0-bee0-37e791eacab7",
        ],
    )
|
|
|
|
|
|
class ModelRelationshipBatchRequest(BaseModel):
    # Request body carrying several model keys whose relationships should be looked up together.

    # Required list of model keys; each `examples` entry is one complete example list.
    model_keys: List[str] = Field(
        default=...,
        description="List of model keys to fetch related models for",
        examples=[
            [
                "aa3b247f-90c9-4416-bfcd-aeaa57a5339e",
                "ac32b914-10ab-496e-a24a-3068724b9c35",
            ],
            [
                "b1c2d3e4-f5a6-7890-abcd-ef1234567890",
                "12345678-90ab-cdef-1234-567890abcdef",
                "fedcba98-7654-3210-fedc-ba9876543210",
            ],
            [
                "3bb7c0eb-b6c8-469c-ad8c-4d69c06075e4",
            ],
        ],
    )
|
|
|
|
|
|
# === Routes ===
|
|
|
|
|
|
@model_relationships_router.get(
    "/i/{model_key}",
    operation_id="get_related_models",
    response_model=list[str],
    responses={
        200: {
            "description": "A list of related model keys was retrieved successfully",
            "content": {
                "application/json": {
                    "example": [
                        "15e9eb28-8cfe-47c9-b610-37907a79fc3c",
                        "71272e82-0e5f-46d5-bca9-9a61f4bd8a82",
                        "a5d7cd49-1b98-4534-a475-aeee4ccf5fa2",
                    ]
                }
            },
        },
        # NOTE(review): 404 is declared, but this handler never raises it -- any
        # failure from the service is reported as 500 below. Confirm intended contract.
        404: {"description": "The specified model could not be found"},
        422: {"description": "Validation error"},
        # Declared for consistency with the other routes: the except-branch below produces it.
        500: {"description": "Internal server error"},
    },
)
async def get_related_models(
    model_key: str = Path(..., description="The key of the model to get relationships for"),
) -> list[str]:
    """
    Get a list of model keys related to a given model.
    """
    # No `description=` is passed to the decorator, so FastAPI publishes the docstring
    # above as this route's OpenAPI description; keep it short and user-facing.
    #
    # Delegates to the model_relationships service and returns its result unchanged.
    # Any exception from the service is converted into an HTTP 500 carrying the
    # exception text; the original exception is chained so the server-side
    # traceback keeps the root cause.
    try:
        return ApiDependencies.invoker.services.model_relationships.get_related_model_keys(model_key)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@model_relationships_router.post(
    "/",
    status_code=status.HTTP_204_NO_CONTENT,
    responses={
        204: {"description": "The relationship was successfully created"},
        400: {"description": "Invalid model keys or self-referential relationship"},
        409: {"description": "The relationship already exists"},
        422: {"description": "Validation error"},
        500: {"description": "Internal server error"},
    },
    summary="Add Model Relationship",
    description="Creates a **bidirectional** relationship between two models, allowing each to reference the other as related.",
)
async def add_model_relationship(
    req: ModelRelationshipCreateRequest = Body(..., description="The model keys to relate"),
) -> None:
    """
    Add a relationship between two models.

    Relationships are bidirectional and will be accessible from both models.

    - Raises 400 if the two keys are identical.
    - Raises 409 if the service raises ``ValueError`` (documented as: the
      relationship already exists).
    - Raises 500 for any other error raised by the service.
    """
    # Validate BEFORE entering the try block. HTTPException is itself an Exception,
    # so raising it inside the try would be caught by the generic handler below and
    # turned into a 500 instead of the intended 400.
    if req.model_key_1 == req.model_key_2:
        raise HTTPException(status_code=400, detail="Cannot relate a model to itself.")

    try:
        ApiDependencies.invoker.services.model_relationships.add_model_relationship(
            req.model_key_1,
            req.model_key_2,
        )
    except ValueError as e:
        # Chain the cause so the server-side traceback shows the service error.
        raise HTTPException(status_code=409, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@model_relationships_router.delete(
    "/",
    status_code=status.HTTP_204_NO_CONTENT,
    responses={
        204: {"description": "The relationship was successfully removed"},
        400: {"description": "Invalid model keys or self-referential relationship"},
        404: {"description": "The relationship does not exist"},
        422: {"description": "Validation error"},
        500: {"description": "Internal server error"},
    },
    summary="Remove Model Relationship",
    description="Removes a **bidirectional** relationship between two models. The relationship must already exist.",
)
async def remove_model_relationship(
    req: ModelRelationshipCreateRequest = Body(..., description="The model keys to disconnect"),
) -> None:
    """
    Removes a bidirectional relationship between two model keys.

    - Raises 400 if attempting to unlink a model from itself.
    - Raises 404 if the service raises ``ValueError`` (documented as: the
      relationship was not found).
    - Raises 500 for any other error raised by the service.
    """
    # Validate BEFORE entering the try block. HTTPException is itself an Exception,
    # so raising it inside the try would be caught by the generic handler below and
    # turned into a 500 instead of the intended 400.
    if req.model_key_1 == req.model_key_2:
        raise HTTPException(status_code=400, detail="Cannot unlink a model from itself.")

    try:
        ApiDependencies.invoker.services.model_relationships.remove_model_relationship(
            req.model_key_1,
            req.model_key_2,
        )
    except ValueError as e:
        # Chain the cause so the server-side traceback shows the service error.
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@model_relationships_router.post(
    "/batch",
    operation_id="get_related_models_batch",
    response_model=List[str],
    responses={
        200: {
            "description": "Related model keys retrieved successfully",
            "content": {
                "application/json": {
                    "example": [
                        "ca562b14-995e-4a42-90c1-9528f1a5921d",
                        "cc0c2b8a-c62e-41d6-878e-cc74dde5ca8f",
                        "18ca7649-6a9e-47d5-bc17-41ab1e8cec81",
                        "7c12d1b2-0ef9-4bec-ba55-797b2d8f2ee1",
                        "c382eaa3-0e28-4ab0-9446-408667699aeb",
                        "71272e82-0e5f-46d5-bca9-9a61f4bd8a82",
                        "a5d7cd49-1b98-4534-a475-aeee4ccf5fa2",
                    ]
                }
            },
        },
        422: {"description": "Validation error"},
        500: {"description": "Internal server error"},
    },
    summary="Get Related Model Keys (Batch)",
    description="Retrieves all **unique related model keys** for a list of given models. This is useful for contextual suggestions or filtering.",
)
async def get_related_models_batch(
    req: ModelRelationshipBatchRequest = Body(..., description="Model keys to check for related connections"),
) -> list[str]:
    """
    Accepts multiple model keys and returns a flat list of all unique related keys.

    Useful when working with multiple selections in the UI or cross-model comparisons.

    Each distinct input key is looked up once. The result contains every related key
    exactly once, in first-seen order (input-key order, then the order the service
    returns related keys). Any error from the service is reported as HTTP 500.
    """
    try:
        relationships = ApiDependencies.invoker.services.model_relationships
        # A dict is used as an insertion-ordered set. A plain `set` of strings would
        # iterate in an order that changes between interpreter runs (str hashing is
        # randomized per process), making the response order nondeterministic.
        all_related: dict[str, None] = {}
        # dict.fromkeys also drops duplicate request keys so each is queried only once.
        for key in dict.fromkeys(req.model_keys):
            all_related.update(dict.fromkeys(relationships.get_related_model_keys(key)))
        return list(all_related)
    except Exception as e:
        # Chain the cause so the server-side traceback shows the service error.
        raise HTTPException(status_code=500, detail=str(e)) from e
|