Mirror of https://github.com/Significant-Gravitas/AutoGPT.git
Synced 2026-02-12 07:45:14 -05:00

Compare commits: add-llm-ma...fix/copilo (9 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 5348d97437 | |
| | 113e87a23c | |
| | d09f1532a4 | |
| | a78145505b | |
| | 36aeb0b2b3 | |
| | 2a189c44c4 | |
| | 508759610f | |
| | 6573d987ea | |
| | ae8ce8b4ca | |
@@ -122,24 +122,6 @@ class ConnectionManager:

         return len(connections)

-    async def broadcast_to_all(self, *, method: WSMethod, data: dict) -> int:
-        """Broadcast a message to all active websocket connections."""
-        message = WSMessage(
-            method=method,
-            data=data,
-        ).model_dump_json()
-
-        connections = tuple(self.active_connections)
-        if not connections:
-            return 0
-
-        await asyncio.gather(
-            *(connection.send_text(message) for connection in connections),
-            return_exceptions=True,
-        )
-
-        return len(connections)
-
     async def _subscribe(self, channel_key: str, websocket: WebSocket) -> str:
         if channel_key not in self.subscriptions:
             self.subscriptions[channel_key] = set()
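Aside: the removed `broadcast_to_all` relied on `asyncio.gather(..., return_exceptions=True)` so one dead socket cannot cancel the rest of the fan-out. A minimal, self-contained sketch of that pattern — `FakeSocket` is a hypothetical stand-in for a websocket connection:

```python
import asyncio


class FakeSocket:
    """Hypothetical stand-in for a websocket; send_text fails for closed clients."""

    def __init__(self, name: str, closed: bool = False):
        self.name, self.closed = name, closed

    async def send_text(self, message: str) -> None:
        if self.closed:
            raise ConnectionError(f"{self.name} is closed")


async def broadcast(connections: list[FakeSocket], message: str) -> int:
    if not connections:
        return 0
    # return_exceptions=True collects failures instead of cancelling sibling sends
    await asyncio.gather(
        *(c.send_text(message) for c in connections),
        return_exceptions=True,
    )
    # Like the removed method, report attempted sends rather than successes
    return len(connections)


print(asyncio.run(broadcast([FakeSocket("a"), FakeSocket("b", closed=True)], "hi")))  # 2
```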
@@ -10,7 +10,7 @@ from typing_extensions import TypedDict

 import backend.api.features.store.cache as store_cache
 import backend.api.features.store.model as store_model
-import backend.data.block
+import backend.blocks
 from backend.api.external.middleware import require_permission
 from backend.data import execution as execution_db
 from backend.data import graph as graph_db
@@ -67,7 +67,7 @@ async def get_user_info(
     dependencies=[Security(require_permission(APIKeyPermission.READ_BLOCK))],
 )
 async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
-    blocks = [block() for block in backend.data.block.get_blocks().values()]
+    blocks = [block() for block in backend.blocks.get_blocks().values()]
     return [b.to_dict() for b in blocks if not b.disabled]

@@ -83,7 +83,7 @@ async def execute_graph_block(
         require_permission(APIKeyPermission.EXECUTE_BLOCK)
     ),
 ) -> CompletedBlockOutput:
-    obj = backend.data.block.get_block(block_id)
+    obj = backend.blocks.get_block(block_id)
     if not obj:
         raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
     if obj.disabled:
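Aside: these two hunks only swap the import path; the lookup-then-404 pattern is unchanged. A runnable sketch, with a hypothetical in-memory registry standing in for `backend.blocks`:

```python
import fastapi

app = fastapi.FastAPI()

# Hypothetical registry standing in for backend.blocks.get_block()
_BLOCKS: dict[str, dict] = {"block-1": {"name": "Echo", "disabled": False}}


def get_block(block_id: str) -> dict | None:
    return _BLOCKS.get(block_id)


@app.post("/blocks/{block_id}/execute")
async def execute_graph_block(block_id: str) -> dict:
    obj = get_block(block_id)
    if not obj:
        raise fastapi.HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
    if obj["disabled"]:
        raise fastapi.HTTPException(status_code=403, detail="Block is disabled.")
    return obj
```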
@@ -176,64 +176,30 @@ async def get_execution_analytics_config(
         # Return with provider prefix for clarity
         return f"{provider_name}: {model_name}"

-    # Get all models from the registry (dynamic, not hardcoded enum)
-    from backend.data import llm_registry
-    from backend.server.v2.llm import db as llm_db
-
-    # Get the recommended model from the database (configurable via admin UI)
-    recommended_model_slug = await llm_db.get_recommended_model_slug()
-
-    # Build the available models list
-    first_enabled_slug = None
-    for registry_model in llm_registry.iter_dynamic_models():
-        # Only include enabled models in the list
-        if not registry_model.is_enabled:
-            continue
-
-        # Track first enabled model as fallback
-        if first_enabled_slug is None:
-            first_enabled_slug = registry_model.slug
-
-        model = LlmModel(registry_model.slug)
+    # Include all LlmModel values (no more filtering by hardcoded list)
+    recommended_model = LlmModel.GPT4O_MINI.value
+    for model in LlmModel:
         label = generate_model_label(model)
         # Add "(Recommended)" suffix to the recommended model
-        if registry_model.slug == recommended_model_slug:
+        if model.value == recommended_model:
             label += " (Recommended)"

         available_models.append(
             ModelInfo(
-                value=registry_model.slug,
+                value=model.value,
                 label=label,
-                provider=registry_model.metadata.provider,
+                provider=model.provider,
             )
         )

     # Sort models by provider and name for better UX
     available_models.sort(key=lambda x: (x.provider, x.label))

-    # Handle case where no models are available
-    if not available_models:
-        logger.warning(
-            "No enabled LLM models found in registry. "
-            "Ensure models are configured and enabled in the LLM Registry."
-        )
-        # Provide a placeholder entry so admins see meaningful feedback
-        available_models.append(
-            ModelInfo(
-                value="",
-                label="No models available - configure in LLM Registry",
-                provider="none",
-            )
-        )
-
-    # Use the DB recommended model, or fallback to first enabled model
-    final_recommended = recommended_model_slug or first_enabled_slug or ""
-
     return ExecutionAnalyticsConfig(
         available_models=available_models,
         default_system_prompt=DEFAULT_SYSTEM_PROMPT,
         default_user_prompt=DEFAULT_USER_PROMPT,
-        recommended_model=final_recommended,
+        recommended_model=recommended_model,
     )

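Aside: the new code builds the dropdown by iterating the `LlmModel` enum directly. A compact sketch with a hypothetical two-member enum (the real code labels models via `generate_model_label`; `model.value` is used here for brevity):

```python
from enum import Enum


class LlmModel(Enum):  # hypothetical stand-in for the real enum
    GPT4O_MINI = "gpt-4o-mini"
    GPT4O = "gpt-4o"


recommended_model = LlmModel.GPT4O_MINI.value
available_models = []
for model in LlmModel:
    label = model.value
    if model.value == recommended_model:
        label += " (Recommended)"
    available_models.append({"value": model.value, "label": label})

print(available_models[0])  # {'value': 'gpt-4o-mini', 'label': 'gpt-4o-mini (Recommended)'}
```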
@@ -1,593 +0,0 @@
-import logging
-
-import autogpt_libs.auth
-import fastapi
-
-from backend.data import llm_registry
-from backend.data.block_cost_config import refresh_llm_costs
-from backend.server.v2.llm import db as llm_db
-from backend.server.v2.llm import model as llm_model
-
-logger = logging.getLogger(__name__)
-
-router = fastapi.APIRouter(
-    tags=["llm", "admin"],
-    dependencies=[fastapi.Security(autogpt_libs.auth.requires_admin_user)],
-)
-
-
-async def _refresh_runtime_state() -> None:
-    """Refresh the LLM registry and clear all related caches to ensure real-time updates."""
-    logger.info("Refreshing LLM registry runtime state...")
-    try:
-        # Refresh registry from database
-        await llm_registry.refresh_llm_registry()
-        refresh_llm_costs()
-
-        # Clear block schema caches so they're regenerated with updated model options
-        from backend.data.block import BlockSchema
-
-        BlockSchema.clear_all_schema_caches()
-        logger.info("Cleared all block schema caches")
-
-        # Clear the /blocks endpoint cache so frontend gets updated schemas
-        try:
-            from backend.api.features.v1 import _get_cached_blocks
-
-            _get_cached_blocks.cache_clear()
-            logger.info("Cleared /blocks endpoint cache")
-        except Exception as e:
-            logger.warning("Failed to clear /blocks cache: %s", e)
-
-        # Clear the v2 builder caches
-        try:
-            from backend.api.features.builder import db as builder_db
-
-            builder_db._get_all_providers.cache_clear()
-            logger.info("Cleared v2 builder providers cache")
-            builder_db._build_cached_search_results.cache_clear()
-            logger.info("Cleared v2 builder search results cache")
-        except Exception as e:
-            logger.debug("Could not clear v2 builder cache: %s", e)
-
-        # Notify all executor services to refresh their registry cache
-        from backend.data.llm_registry import publish_registry_refresh_notification
-
-        await publish_registry_refresh_notification()
-        logger.info("Published registry refresh notification")
-    except Exception as exc:
-        logger.exception(
-            "LLM runtime state refresh failed; caches may be stale: %s", exc
-        )
-
-
-@router.get(
-    "/providers",
-    summary="List LLM providers",
-    response_model=llm_model.LlmProvidersResponse,
-)
-async def list_llm_providers(include_models: bool = True):
-    providers = await llm_db.list_providers(include_models=include_models)
-    return llm_model.LlmProvidersResponse(providers=providers)
-
-
-@router.post(
-    "/providers",
-    summary="Create LLM provider",
-    response_model=llm_model.LlmProvider,
-)
-async def create_llm_provider(request: llm_model.UpsertLlmProviderRequest):
-    provider = await llm_db.upsert_provider(request=request)
-    await _refresh_runtime_state()
-    return provider
-
-
-@router.patch(
-    "/providers/{provider_id}",
-    summary="Update LLM provider",
-    response_model=llm_model.LlmProvider,
-)
-async def update_llm_provider(
-    provider_id: str,
-    request: llm_model.UpsertLlmProviderRequest,
-):
-    provider = await llm_db.upsert_provider(request=request, provider_id=provider_id)
-    await _refresh_runtime_state()
-    return provider
-
-
-@router.delete(
-    "/providers/{provider_id}",
-    summary="Delete LLM provider",
-    response_model=dict,
-)
-async def delete_llm_provider(provider_id: str):
-    """
-    Delete an LLM provider.
-
-    A provider can only be deleted if it has no associated models.
-    Delete all models from the provider first before deleting the provider.
-    """
-    try:
-        await llm_db.delete_provider(provider_id)
-        await _refresh_runtime_state()
-        logger.info("Deleted LLM provider '%s'", provider_id)
-        return {"success": True, "message": "Provider deleted successfully"}
-    except ValueError as e:
-        logger.warning("Failed to delete provider '%s': %s", provider_id, e)
-        raise fastapi.HTTPException(status_code=400, detail=str(e))
-    except Exception as e:
-        logger.exception("Failed to delete provider '%s': %s", provider_id, e)
-        raise fastapi.HTTPException(status_code=500, detail=str(e))
-
-
-@router.get(
-    "/models",
-    summary="List LLM models",
-    response_model=llm_model.LlmModelsResponse,
-)
-async def list_llm_models(
-    provider_id: str | None = fastapi.Query(default=None),
-    page: int = fastapi.Query(default=1, ge=1, description="Page number (1-indexed)"),
-    page_size: int = fastapi.Query(
-        default=50, ge=1, le=100, description="Number of models per page"
-    ),
-):
-    return await llm_db.list_models(
-        provider_id=provider_id, page=page, page_size=page_size
-    )
-
-
-@router.post(
-    "/models",
-    summary="Create LLM model",
-    response_model=llm_model.LlmModel,
-)
-async def create_llm_model(request: llm_model.CreateLlmModelRequest):
-    model = await llm_db.create_model(request=request)
-    await _refresh_runtime_state()
-    return model
-
-
-@router.patch(
-    "/models/{model_id}",
-    summary="Update LLM model",
-    response_model=llm_model.LlmModel,
-)
-async def update_llm_model(
-    model_id: str,
-    request: llm_model.UpdateLlmModelRequest,
-):
-    model = await llm_db.update_model(model_id=model_id, request=request)
-    await _refresh_runtime_state()
-    return model
-
-
-@router.patch(
-    "/models/{model_id}/toggle",
-    summary="Toggle LLM model availability",
-    response_model=llm_model.ToggleLlmModelResponse,
-)
-async def toggle_llm_model(
-    model_id: str,
-    request: llm_model.ToggleLlmModelRequest,
-):
-    """
-    Toggle a model's enabled status, optionally migrating workflows when disabling.
-
-    If disabling a model and `migrate_to_slug` is provided, all workflows using
-    this model will be migrated to the specified replacement model before disabling.
-    A migration record is created which can be reverted later using the revert endpoint.
-
-    Optional fields:
-    - `migration_reason`: Reason for the migration (e.g., "Provider outage")
-    - `custom_credit_cost`: Custom pricing override for billing during migration
-    """
-    try:
-        result = await llm_db.toggle_model(
-            model_id=model_id,
-            is_enabled=request.is_enabled,
-            migrate_to_slug=request.migrate_to_slug,
-            migration_reason=request.migration_reason,
-            custom_credit_cost=request.custom_credit_cost,
-        )
-        await _refresh_runtime_state()
-        if result.nodes_migrated > 0:
-            logger.info(
-                "Toggled model '%s' to %s and migrated %d nodes to '%s' (migration_id=%s)",
-                result.model.slug,
-                "enabled" if request.is_enabled else "disabled",
-                result.nodes_migrated,
-                result.migrated_to_slug,
-                result.migration_id,
-            )
-        return result
-    except ValueError as exc:
-        logger.warning("Model toggle validation failed: %s", exc)
-        raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
-    except Exception as exc:
-        logger.exception("Failed to toggle LLM model %s: %s", model_id, exc)
-        raise fastapi.HTTPException(
-            status_code=500,
-            detail="Failed to toggle model availability",
-        ) from exc
-
-
-@router.get(
-    "/models/{model_id}/usage",
-    summary="Get model usage count",
-    response_model=llm_model.LlmModelUsageResponse,
-)
-async def get_llm_model_usage(model_id: str):
-    """Get the number of workflow nodes using this model."""
-    try:
-        return await llm_db.get_model_usage(model_id=model_id)
-    except ValueError as exc:
-        raise fastapi.HTTPException(status_code=404, detail=str(exc)) from exc
-    except Exception as exc:
-        logger.exception("Failed to get model usage %s: %s", model_id, exc)
-        raise fastapi.HTTPException(
-            status_code=500,
-            detail="Failed to get model usage",
-        ) from exc
-
-
-@router.delete(
-    "/models/{model_id}",
-    summary="Delete LLM model and migrate workflows",
-    response_model=llm_model.DeleteLlmModelResponse,
-)
-async def delete_llm_model(
-    model_id: str,
-    replacement_model_slug: str | None = fastapi.Query(
-        default=None,
-        description="Slug of the model to migrate existing workflows to (required only if workflows use this model)",
-    ),
-):
-    """
-    Delete a model and optionally migrate workflows using it to a replacement model.
-
-    If no workflows are using this model, it can be deleted without providing a
-    replacement. If workflows exist, replacement_model_slug is required.
-
-    This endpoint:
-    1. Counts how many workflow nodes use the model being deleted
-    2. If nodes exist, validates the replacement model and migrates them
-    3. Deletes the model record
-    4. Refreshes all caches and notifies executors
-
-    Example: DELETE /api/llm/admin/models/{id}?replacement_model_slug=gpt-4o
-    Example (no usage): DELETE /api/llm/admin/models/{id}
-    """
-    try:
-        result = await llm_db.delete_model(
-            model_id=model_id, replacement_model_slug=replacement_model_slug
-        )
-        await _refresh_runtime_state()
-        logger.info(
-            "Deleted model '%s' and migrated %d nodes to '%s'",
-            result.deleted_model_slug,
-            result.nodes_migrated,
-            result.replacement_model_slug,
-        )
-        return result
-    except ValueError as exc:
-        # Validation errors (model not found, replacement invalid, etc.)
-        logger.warning("Model deletion validation failed: %s", exc)
-        raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
-    except Exception as exc:
-        logger.exception("Failed to delete LLM model %s: %s", model_id, exc)
-        raise fastapi.HTTPException(
-            status_code=500,
-            detail="Failed to delete model and migrate workflows",
-        ) from exc
-
-
-# ============================================================================
-# Migration Management Endpoints
-# ============================================================================
-
-
-@router.get(
-    "/migrations",
-    summary="List model migrations",
-    response_model=llm_model.LlmMigrationsResponse,
-)
-async def list_llm_migrations(
-    include_reverted: bool = fastapi.Query(
-        default=False, description="Include reverted migrations in the list"
-    ),
-):
-    """
-    List all model migrations.
-
-    Migrations are created when disabling a model with the migrate_to_slug option.
-    They can be reverted to restore the original model configuration.
-    """
-    try:
-        migrations = await llm_db.list_migrations(include_reverted=include_reverted)
-        return llm_model.LlmMigrationsResponse(migrations=migrations)
-    except Exception as exc:
-        logger.exception("Failed to list migrations: %s", exc)
-        raise fastapi.HTTPException(
-            status_code=500,
-            detail="Failed to list migrations",
-        ) from exc
-
-
-@router.get(
-    "/migrations/{migration_id}",
-    summary="Get migration details",
-    response_model=llm_model.LlmModelMigration,
-)
-async def get_llm_migration(migration_id: str):
-    """Get details of a specific migration."""
-    try:
-        migration = await llm_db.get_migration(migration_id)
-        if not migration:
-            raise fastapi.HTTPException(
-                status_code=404, detail=f"Migration '{migration_id}' not found"
-            )
-        return migration
-    except fastapi.HTTPException:
-        raise
-    except Exception as exc:
-        logger.exception("Failed to get migration %s: %s", migration_id, exc)
-        raise fastapi.HTTPException(
-            status_code=500,
-            detail="Failed to get migration",
-        ) from exc
-
-
-@router.post(
-    "/migrations/{migration_id}/revert",
-    summary="Revert a model migration",
-    response_model=llm_model.RevertMigrationResponse,
-)
-async def revert_llm_migration(
-    migration_id: str,
-    request: llm_model.RevertMigrationRequest | None = None,
-):
-    """
-    Revert a model migration, restoring affected workflows to their original model.
-
-    This only reverts the specific nodes that were part of the migration.
-    The source model must exist for the revert to succeed.
-
-    Options:
-    - `re_enable_source_model`: Whether to re-enable the source model if disabled (default: True)
-
-    Response includes:
-    - `nodes_reverted`: Number of nodes successfully reverted
-    - `nodes_already_changed`: Number of nodes that were modified since migration (not reverted)
-    - `source_model_re_enabled`: Whether the source model was re-enabled
-
-    Requirements:
-    - Migration must not already be reverted
-    - Source model must exist
-    """
-    try:
-        re_enable = request.re_enable_source_model if request else True
-        result = await llm_db.revert_migration(
-            migration_id,
-            re_enable_source_model=re_enable,
-        )
-        await _refresh_runtime_state()
-        logger.info(
-            "Reverted migration '%s': %d nodes restored from '%s' to '%s' "
-            "(%d already changed, source re-enabled=%s)",
-            migration_id,
-            result.nodes_reverted,
-            result.target_model_slug,
-            result.source_model_slug,
-            result.nodes_already_changed,
-            result.source_model_re_enabled,
-        )
-        return result
-    except ValueError as exc:
-        logger.warning("Migration revert validation failed: %s", exc)
-        raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
-    except Exception as exc:
-        logger.exception("Failed to revert migration %s: %s", migration_id, exc)
-        raise fastapi.HTTPException(
-            status_code=500,
-            detail="Failed to revert migration",
-        ) from exc
-
-
-# ============================================================================
-# Creator Management Endpoints
-# ============================================================================
-
-
-@router.get(
-    "/creators",
-    summary="List model creators",
-    response_model=llm_model.LlmCreatorsResponse,
-)
-async def list_llm_creators():
-    """
-    List all model creators.
-
-    Creators are organizations that create/train models (e.g., OpenAI, Meta, Anthropic).
-    This is distinct from providers who host/serve the models (e.g., OpenRouter).
-    """
-    try:
-        creators = await llm_db.list_creators()
-        return llm_model.LlmCreatorsResponse(creators=creators)
-    except Exception as exc:
-        logger.exception("Failed to list creators: %s", exc)
-        raise fastapi.HTTPException(
-            status_code=500,
-            detail="Failed to list creators",
-        ) from exc
-
-
-@router.get(
-    "/creators/{creator_id}",
-    summary="Get creator details",
-    response_model=llm_model.LlmModelCreator,
-)
-async def get_llm_creator(creator_id: str):
-    """Get details of a specific model creator."""
-    try:
-        creator = await llm_db.get_creator(creator_id)
-        if not creator:
-            raise fastapi.HTTPException(
-                status_code=404, detail=f"Creator '{creator_id}' not found"
-            )
-        return creator
-    except fastapi.HTTPException:
-        raise
-    except Exception as exc:
-        logger.exception("Failed to get creator %s: %s", creator_id, exc)
-        raise fastapi.HTTPException(
-            status_code=500,
-            detail="Failed to get creator",
-        ) from exc
-
-
-@router.post(
-    "/creators",
-    summary="Create model creator",
-    response_model=llm_model.LlmModelCreator,
-)
-async def create_llm_creator(request: llm_model.UpsertLlmCreatorRequest):
-    """
-    Create a new model creator.
-
-    A creator represents an organization that creates/trains AI models,
-    such as OpenAI, Anthropic, Meta, or Google.
-    """
-    try:
-        creator = await llm_db.upsert_creator(request=request)
-        await _refresh_runtime_state()
-        logger.info("Created model creator '%s' (%s)", creator.display_name, creator.id)
-        return creator
-    except Exception as exc:
-        logger.exception("Failed to create creator: %s", exc)
-        raise fastapi.HTTPException(
-            status_code=500,
-            detail="Failed to create creator",
-        ) from exc
-
-
-@router.patch(
-    "/creators/{creator_id}",
-    summary="Update model creator",
-    response_model=llm_model.LlmModelCreator,
-)
-async def update_llm_creator(
-    creator_id: str,
-    request: llm_model.UpsertLlmCreatorRequest,
-):
-    """Update an existing model creator."""
-    try:
-        creator = await llm_db.upsert_creator(request=request, creator_id=creator_id)
-        await _refresh_runtime_state()
-        logger.info("Updated model creator '%s' (%s)", creator.display_name, creator_id)
-        return creator
-    except Exception as exc:
-        logger.exception("Failed to update creator %s: %s", creator_id, exc)
-        raise fastapi.HTTPException(
-            status_code=500,
-            detail="Failed to update creator",
-        ) from exc
-
-
-@router.delete(
-    "/creators/{creator_id}",
-    summary="Delete model creator",
-    response_model=dict,
-)
-async def delete_llm_creator(creator_id: str):
-    """
-    Delete a model creator.
-
-    This will remove the creator association from all models that reference it
-    (sets creatorId to NULL), but will not delete the models themselves.
-    """
-    try:
-        await llm_db.delete_creator(creator_id)
-        await _refresh_runtime_state()
-        logger.info("Deleted model creator '%s'", creator_id)
-        return {"success": True, "message": f"Creator '{creator_id}' deleted"}
-    except ValueError as exc:
-        logger.warning("Creator deletion validation failed: %s", exc)
-        raise fastapi.HTTPException(status_code=404, detail=str(exc)) from exc
-    except Exception as exc:
-        logger.exception("Failed to delete creator %s: %s", creator_id, exc)
-        raise fastapi.HTTPException(
-            status_code=500,
-            detail="Failed to delete creator",
-        ) from exc
-
-
-# ============================================================================
-# Recommended Model Endpoints
-# ============================================================================
-
-
-@router.get(
-    "/recommended-model",
-    summary="Get recommended model",
-    response_model=llm_model.RecommendedModelResponse,
-)
-async def get_recommended_model():
-    """
-    Get the currently recommended LLM model.
-
-    The recommended model is shown to users as the default/suggested option
-    in model selection dropdowns.
-    """
-    try:
-        model = await llm_db.get_recommended_model()
-        return llm_model.RecommendedModelResponse(
-            model=model,
-            slug=model.slug if model else None,
-        )
-    except Exception as exc:
-        logger.exception("Failed to get recommended model: %s", exc)
-        raise fastapi.HTTPException(
-            status_code=500,
-            detail="Failed to get recommended model",
-        ) from exc
-
-
-@router.post(
-    "/recommended-model",
-    summary="Set recommended model",
-    response_model=llm_model.SetRecommendedModelResponse,
-)
-async def set_recommended_model(request: llm_model.SetRecommendedModelRequest):
-    """
-    Set a model as the recommended model.
-
-    This clears the recommended flag from any other model and sets it on
-    the specified model. The model must be enabled to be set as recommended.
-
-    The recommended model is displayed to users as the default/suggested
-    option in model selection dropdowns throughout the platform.
-    """
-    try:
-        model, previous_slug = await llm_db.set_recommended_model(request.model_id)
-        await _refresh_runtime_state()
-        logger.info(
-            "Set recommended model to '%s' (previous: %s)",
-            model.slug,
-            previous_slug or "none",
-        )
-        return llm_model.SetRecommendedModelResponse(
-            model=model,
-            previous_recommended_slug=previous_slug,
-            message=f"Model '{model.display_name}' is now the recommended model",
-        )
-    except ValueError as exc:
-        logger.warning("Set recommended model validation failed: %s", exc)
-        raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
-    except Exception as exc:
-        logger.exception("Failed to set recommended model: %s", exc)
-        raise fastapi.HTTPException(
-            status_code=500,
-            detail="Failed to set recommended model",
-        ) from exc
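Aside: the deleted router guarded every endpoint once, at the `APIRouter` level, via `dependencies=[fastapi.Security(...)]`. A minimal sketch of that pattern, with a hypothetical stand-in for `autogpt_libs.auth.requires_admin_user`:

```python
import fastapi


async def requires_admin_user() -> None:
    """Hypothetical stand-in; the real dependency validates the caller's
    JWT and raises HTTPException(403) for non-admins."""


router = fastapi.APIRouter(
    tags=["llm", "admin"],
    # Applied to every route registered on this router
    dependencies=[fastapi.Security(requires_admin_user)],
)


@router.get("/providers")
async def list_llm_providers() -> dict:
    return {"providers": []}


app = fastapi.FastAPI()
app.include_router(router, prefix="/admin/llm")
```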
@@ -1,491 +0,0 @@
-import json
-from unittest.mock import AsyncMock
-
-import fastapi
-import fastapi.testclient
-import pytest
-import pytest_mock
-from autogpt_libs.auth.jwt_utils import get_jwt_payload
-from pytest_snapshot.plugin import Snapshot
-
-import backend.api.features.admin.llm_routes as llm_routes
-from backend.server.v2.llm import model as llm_model
-from backend.util.models import Pagination
-
-app = fastapi.FastAPI()
-app.include_router(llm_routes.router, prefix="/admin/llm")
-
-client = fastapi.testclient.TestClient(app)
-
-
-@pytest.fixture(autouse=True)
-def setup_app_admin_auth(mock_jwt_admin):
-    """Setup admin auth overrides for all tests in this module"""
-    app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
-    yield
-    app.dependency_overrides.clear()
-
-
-def test_list_llm_providers_success(
-    mocker: pytest_mock.MockFixture,
-    configured_snapshot: Snapshot,
-) -> None:
-    """Test successful listing of LLM providers"""
-    # Mock the database function
-    mock_providers = [
-        {
-            "id": "provider-1",
-            "name": "openai",
-            "display_name": "OpenAI",
-            "description": "OpenAI LLM provider",
-            "supports_tools": True,
-            "supports_json_output": True,
-            "supports_reasoning": False,
-            "supports_parallel_tool": True,
-            "metadata": {},
-            "models": [],
-        },
-        {
-            "id": "provider-2",
-            "name": "anthropic",
-            "display_name": "Anthropic",
-            "description": "Anthropic LLM provider",
-            "supports_tools": True,
-            "supports_json_output": True,
-            "supports_reasoning": False,
-            "supports_parallel_tool": True,
-            "metadata": {},
-            "models": [],
-        },
-    ]
-
-    mocker.patch(
-        "backend.api.features.admin.llm_routes.llm_db.list_providers",
-        new=AsyncMock(return_value=mock_providers),
-    )
-
-    response = client.get("/admin/llm/providers")
-
-    assert response.status_code == 200
-    response_data = response.json()
-    assert len(response_data["providers"]) == 2
-    assert response_data["providers"][0]["name"] == "openai"
-
-    # Snapshot test the response (must be string)
-    configured_snapshot.assert_match(
-        json.dumps(response_data, indent=2, sort_keys=True),
-        "list_llm_providers_success.json",
-    )
-
-
-def test_list_llm_models_success(
-    mocker: pytest_mock.MockFixture,
-    configured_snapshot: Snapshot,
-) -> None:
-    """Test successful listing of LLM models with pagination"""
-    # Mock the database function - now returns LlmModelsResponse
-    mock_model = llm_model.LlmModel(
-        id="model-1",
-        slug="gpt-4o",
-        display_name="GPT-4o",
-        description="GPT-4 Optimized",
-        provider_id="provider-1",
-        context_window=128000,
-        max_output_tokens=16384,
-        is_enabled=True,
-        capabilities={},
-        metadata={},
-        costs=[
-            llm_model.LlmModelCost(
-                id="cost-1",
-                credit_cost=10,
-                credential_provider="openai",
-                metadata={},
-            )
-        ],
-    )
-
-    mock_response = llm_model.LlmModelsResponse(
-        models=[mock_model],
-        pagination=Pagination(
-            total_items=1,
-            total_pages=1,
-            current_page=1,
-            page_size=50,
-        ),
-    )
-
-    mocker.patch(
-        "backend.api.features.admin.llm_routes.llm_db.list_models",
-        new=AsyncMock(return_value=mock_response),
-    )
-
-    response = client.get("/admin/llm/models")
-
-    assert response.status_code == 200
-    response_data = response.json()
-    assert len(response_data["models"]) == 1
-    assert response_data["models"][0]["slug"] == "gpt-4o"
-    assert response_data["pagination"]["total_items"] == 1
-    assert response_data["pagination"]["page_size"] == 50
-
-    # Snapshot test the response (must be string)
-    configured_snapshot.assert_match(
-        json.dumps(response_data, indent=2, sort_keys=True),
-        "list_llm_models_success.json",
-    )
-
-
-def test_create_llm_provider_success(
-    mocker: pytest_mock.MockFixture,
-    configured_snapshot: Snapshot,
-) -> None:
-    """Test successful creation of LLM provider"""
-    mock_provider = {
-        "id": "new-provider-id",
-        "name": "groq",
-        "display_name": "Groq",
-        "description": "Groq LLM provider",
-        "supports_tools": True,
-        "supports_json_output": True,
-        "supports_reasoning": False,
-        "supports_parallel_tool": False,
-        "metadata": {},
-    }
-
-    mocker.patch(
-        "backend.api.features.admin.llm_routes.llm_db.upsert_provider",
-        new=AsyncMock(return_value=mock_provider),
-    )
-
-    mock_refresh = mocker.patch(
-        "backend.api.features.admin.llm_routes._refresh_runtime_state",
-        new=AsyncMock(),
-    )
-
-    request_data = {
-        "name": "groq",
-        "display_name": "Groq",
-        "description": "Groq LLM provider",
-        "supports_tools": True,
-        "supports_json_output": True,
-        "supports_reasoning": False,
-        "supports_parallel_tool": False,
-        "metadata": {},
-    }
-
-    response = client.post("/admin/llm/providers", json=request_data)
-
-    assert response.status_code == 200
-    response_data = response.json()
-    assert response_data["name"] == "groq"
-    assert response_data["display_name"] == "Groq"
-
-    # Verify refresh was called
-    mock_refresh.assert_called_once()
-
-    # Snapshot test the response (must be string)
-    configured_snapshot.assert_match(
-        json.dumps(response_data, indent=2, sort_keys=True),
-        "create_llm_provider_success.json",
-    )
-
-
-def test_create_llm_model_success(
-    mocker: pytest_mock.MockFixture,
-    configured_snapshot: Snapshot,
-) -> None:
-    """Test successful creation of LLM model"""
-    mock_model = {
-        "id": "new-model-id",
-        "slug": "gpt-4.1-mini",
-        "display_name": "GPT-4.1 Mini",
-        "description": "Latest GPT-4.1 Mini model",
-        "provider_id": "provider-1",
-        "context_window": 128000,
-        "max_output_tokens": 16384,
-        "is_enabled": True,
-        "capabilities": {},
-        "metadata": {},
-        "costs": [
-            {
-                "id": "cost-id",
-                "credit_cost": 5,
-                "credential_provider": "openai",
-                "metadata": {},
-            }
-        ],
-    }
-
-    mocker.patch(
-        "backend.api.features.admin.llm_routes.llm_db.create_model",
-        new=AsyncMock(return_value=mock_model),
-    )
-
-    mock_refresh = mocker.patch(
-        "backend.api.features.admin.llm_routes._refresh_runtime_state",
-        new=AsyncMock(),
-    )
-
-    request_data = {
-        "slug": "gpt-4.1-mini",
-        "display_name": "GPT-4.1 Mini",
-        "description": "Latest GPT-4.1 Mini model",
-        "provider_id": "provider-1",
-        "context_window": 128000,
-        "max_output_tokens": 16384,
-        "is_enabled": True,
-        "capabilities": {},
-        "metadata": {},
-        "costs": [
-            {
-                "credit_cost": 5,
-                "credential_provider": "openai",
-                "metadata": {},
-            }
-        ],
-    }
-
-    response = client.post("/admin/llm/models", json=request_data)
-
-    assert response.status_code == 200
-    response_data = response.json()
-    assert response_data["slug"] == "gpt-4.1-mini"
-    assert response_data["is_enabled"] is True
-
-    # Verify refresh was called
-    mock_refresh.assert_called_once()
-
-    # Snapshot test the response (must be string)
-    configured_snapshot.assert_match(
-        json.dumps(response_data, indent=2, sort_keys=True),
-        "create_llm_model_success.json",
-    )
-
-
-def test_update_llm_model_success(
-    mocker: pytest_mock.MockFixture,
-    configured_snapshot: Snapshot,
-) -> None:
-    """Test successful update of LLM model"""
-    mock_model = {
-        "id": "model-1",
-        "slug": "gpt-4o",
-        "display_name": "GPT-4o Updated",
-        "description": "Updated description",
-        "provider_id": "provider-1",
-        "context_window": 256000,
-        "max_output_tokens": 32768,
-        "is_enabled": True,
-        "capabilities": {},
-        "metadata": {},
-        "costs": [
-            {
-                "id": "cost-1",
-                "credit_cost": 15,
-                "credential_provider": "openai",
-                "metadata": {},
-            }
-        ],
-    }
-
-    mocker.patch(
-        "backend.api.features.admin.llm_routes.llm_db.update_model",
-        new=AsyncMock(return_value=mock_model),
-    )
-
-    mock_refresh = mocker.patch(
-        "backend.api.features.admin.llm_routes._refresh_runtime_state",
-        new=AsyncMock(),
-    )
-
-    request_data = {
-        "display_name": "GPT-4o Updated",
-        "description": "Updated description",
-        "context_window": 256000,
-        "max_output_tokens": 32768,
-    }
-
-    response = client.patch("/admin/llm/models/model-1", json=request_data)
-
-    assert response.status_code == 200
-    response_data = response.json()
-    assert response_data["display_name"] == "GPT-4o Updated"
-    assert response_data["context_window"] == 256000
-
-    # Verify refresh was called
-    mock_refresh.assert_called_once()
-
-    # Snapshot test the response (must be string)
-    configured_snapshot.assert_match(
-        json.dumps(response_data, indent=2, sort_keys=True),
-        "update_llm_model_success.json",
-    )
-
-
-def test_toggle_llm_model_success(
-    mocker: pytest_mock.MockFixture,
-    configured_snapshot: Snapshot,
-) -> None:
-    """Test successful toggling of LLM model enabled status"""
-    # Create a proper mock model object
-    mock_model = llm_model.LlmModel(
-        id="model-1",
-        slug="gpt-4o",
-        display_name="GPT-4o",
-        description="GPT-4 Optimized",
-        provider_id="provider-1",
-        context_window=128000,
-        max_output_tokens=16384,
-        is_enabled=False,
-        capabilities={},
-        metadata={},
-        costs=[],
-    )
-
-    # Create a proper ToggleLlmModelResponse
-    mock_response = llm_model.ToggleLlmModelResponse(
-        model=mock_model,
-        nodes_migrated=0,
-        migrated_to_slug=None,
-        migration_id=None,
-    )
-
-    mocker.patch(
-        "backend.api.features.admin.llm_routes.llm_db.toggle_model",
-        new=AsyncMock(return_value=mock_response),
-    )
-
-    mock_refresh = mocker.patch(
-        "backend.api.features.admin.llm_routes._refresh_runtime_state",
-        new=AsyncMock(),
-    )
-
-    request_data = {"is_enabled": False}
-
-    response = client.patch("/admin/llm/models/model-1/toggle", json=request_data)
-
-    assert response.status_code == 200
-    response_data = response.json()
-    assert response_data["model"]["is_enabled"] is False
-
-    # Verify refresh was called
-    mock_refresh.assert_called_once()
-
-    # Snapshot test the response (must be string)
-    configured_snapshot.assert_match(
-        json.dumps(response_data, indent=2, sort_keys=True),
-        "toggle_llm_model_success.json",
-    )
-
-
-def test_delete_llm_model_success(
-    mocker: pytest_mock.MockFixture,
-    configured_snapshot: Snapshot,
-) -> None:
-    """Test successful deletion of LLM model with migration"""
-    # Create a proper DeleteLlmModelResponse
-    mock_response = llm_model.DeleteLlmModelResponse(
-        deleted_model_slug="gpt-3.5-turbo",
-        deleted_model_display_name="GPT-3.5 Turbo",
-        replacement_model_slug="gpt-4o-mini",
-        nodes_migrated=42,
-        message="Successfully deleted model 'GPT-3.5 Turbo' (gpt-3.5-turbo) "
-        "and migrated 42 workflow node(s) to 'gpt-4o-mini'.",
-    )
-
-    mocker.patch(
-        "backend.api.features.admin.llm_routes.llm_db.delete_model",
-        new=AsyncMock(return_value=mock_response),
-    )
-
-    mock_refresh = mocker.patch(
-        "backend.api.features.admin.llm_routes._refresh_runtime_state",
-        new=AsyncMock(),
-    )
-
-    response = client.delete(
-        "/admin/llm/models/model-1?replacement_model_slug=gpt-4o-mini"
-    )
-
-    assert response.status_code == 200
-    response_data = response.json()
-    assert response_data["deleted_model_slug"] == "gpt-3.5-turbo"
-    assert response_data["nodes_migrated"] == 42
-    assert response_data["replacement_model_slug"] == "gpt-4o-mini"
-
-    # Verify refresh was called
-    mock_refresh.assert_called_once()
-
-    # Snapshot test the response (must be string)
-    configured_snapshot.assert_match(
-        json.dumps(response_data, indent=2, sort_keys=True),
-        "delete_llm_model_success.json",
-    )
-
-
-def test_delete_llm_model_validation_error(
-    mocker: pytest_mock.MockFixture,
-) -> None:
-    """Test deletion fails with proper error when validation fails"""
-    mocker.patch(
-        "backend.api.features.admin.llm_routes.llm_db.delete_model",
-        new=AsyncMock(side_effect=ValueError("Replacement model 'invalid' not found")),
-    )
-
-    response = client.delete("/admin/llm/models/model-1?replacement_model_slug=invalid")
-
-    assert response.status_code == 400
-    assert "Replacement model 'invalid' not found" in response.json()["detail"]
-
-
-def test_delete_llm_model_no_replacement_with_usage(
-    mocker: pytest_mock.MockFixture,
-) -> None:
-    """Test deletion fails when nodes exist but no replacement is provided"""
-    mocker.patch(
-        "backend.api.features.admin.llm_routes.llm_db.delete_model",
-        new=AsyncMock(
-            side_effect=ValueError(
-                "Cannot delete model 'test-model': 5 workflow node(s) are using it. "
-                "Please provide a replacement_model_slug to migrate them."
-            )
-        ),
-    )
-
-    response = client.delete("/admin/llm/models/model-1")
-
-    assert response.status_code == 400
-    assert "workflow node(s) are using it" in response.json()["detail"]
-
-
-def test_delete_llm_model_no_replacement_no_usage(
-    mocker: pytest_mock.MockFixture,
-) -> None:
-    """Test deletion succeeds when no nodes use the model and no replacement is provided"""
-    mock_response = llm_model.DeleteLlmModelResponse(
-        deleted_model_slug="unused-model",
-        deleted_model_display_name="Unused Model",
-        replacement_model_slug=None,
-        nodes_migrated=0,
-        message="Successfully deleted model 'Unused Model' (unused-model). No workflows were using this model.",
-    )
-
-    mocker.patch(
-        "backend.api.features.admin.llm_routes.llm_db.delete_model",
-        new=AsyncMock(return_value=mock_response),
-    )
-
-    mock_refresh = mocker.patch(
-        "backend.api.features.admin.llm_routes._refresh_runtime_state",
-        new=AsyncMock(),
-    )
-
-    response = client.delete("/admin/llm/models/model-1")
-
-    assert response.status_code == 200
-    response_data = response.json()
-    assert response_data["deleted_model_slug"] == "unused-model"
-    assert response_data["nodes_migrated"] == 0
-    assert response_data["replacement_model_slug"] is None
-    mock_refresh.assert_called_once()
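Aside: the deleted tests replace async DB calls with `unittest.mock.AsyncMock` (via `mocker.patch`) so the routes can be exercised through `TestClient` with no database. The core mechanism, in a self-contained sketch with a hypothetical `FakeDB`:

```python
import asyncio
from unittest.mock import AsyncMock


class FakeDB:
    """Hypothetical stand-in for the llm_db module."""

    async def list_providers(self) -> list[dict]:
        raise RuntimeError("would hit the real database")


db = FakeDB()
# AsyncMock returns an awaitable, so `await db.list_providers()` still works
db.list_providers = AsyncMock(return_value=[{"name": "openai"}])

assert asyncio.run(db.list_providers()) == [{"name": "openai"}]
db.list_providers.assert_awaited_once()
```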
@@ -10,12 +10,16 @@ import backend.api.features.library.db as library_db
 import backend.api.features.library.model as library_model
 import backend.api.features.store.db as store_db
 import backend.api.features.store.model as store_model
-import backend.data.block
 from backend.blocks import load_all_blocks
+from backend.blocks._base import (
+    AnyBlockSchema,
+    BlockCategory,
+    BlockInfo,
+    BlockSchema,
+    BlockType,
+)
 from backend.blocks.llm import LlmModel
-from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
 from backend.data.db import query_raw_with_schema
-from backend.data.llm_registry import get_all_model_slugs_for_validation
 from backend.integrations.providers import ProviderName
 from backend.util.cache import cached
 from backend.util.models import Pagination
@@ -23,7 +27,7 @@ from backend.util.models import Pagination
 from .model import (
     BlockCategoryResponse,
     BlockResponse,
-    BlockType,
+    BlockTypeFilter,
     CountResponse,
     FilterType,
     Provider,
@@ -32,14 +36,7 @@ from .model import (
 )

 logger = logging.getLogger(__name__)
+llm_models = [name.name.lower().replace("_", " ") for name in LlmModel]
-
-def _get_llm_models() -> list[str]:
-    """Get LLM model names for search matching from the registry."""
-    return [
-        slug.lower().replace("-", " ") for slug in get_all_model_slugs_for_validation()
-    ]
-
-

 MAX_LIBRARY_AGENT_RESULTS = 100
 MAX_MARKETPLACE_AGENT_RESULTS = 100
@@ -96,7 +93,7 @@ def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse
 def get_blocks(
     *,
     category: str | None = None,
-    type: BlockType | None = None,
+    type: BlockTypeFilter | None = None,
     provider: ProviderName | None = None,
     page: int = 1,
     page_size: int = 50,
@@ -504,8 +501,8 @@ async def _get_static_counts():
 def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
     for field in schema_cls.model_fields.values():
         if field.annotation == LlmModel:
-            # Check if query matches any value in llm_models from registry
-            if any(query in name for name in _get_llm_models()):
+            # Check if query matches any value in llm_models
+            if any(query in name for name in llm_models):
                 return True
     return False

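Aside: the restored module-level `llm_models` list is built from enum *names*, lowercased, with underscores turned into spaces, so substring search over human-ish names works. A tiny sketch with a hypothetical two-member enum:

```python
from enum import Enum


class LlmModel(Enum):  # hypothetical two-member stand-in
    GPT4O_MINI = "gpt-4o-mini"
    CLAUDE_3_5_SONNET = "claude-3-5-sonnet"


# Mirrors the restored module-level list: enum *names*, lowercased, "_" -> " "
llm_models = [name.name.lower().replace("_", " ") for name in LlmModel]
print(llm_models)  # ['gpt4o mini', 'claude 3 5 sonnet']

query = "mini"
print(any(query in name for name in llm_models))  # True
```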
@@ -677,9 +674,9 @@ async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
     for block_type in load_all_blocks().values():
         block: AnyBlockSchema = block_type()
         if block.disabled or block.block_type in (
-            backend.data.block.BlockType.INPUT,
-            backend.data.block.BlockType.OUTPUT,
-            backend.data.block.BlockType.AGENT,
+            BlockType.INPUT,
+            BlockType.OUTPUT,
+            BlockType.AGENT,
         ):
             continue
         # Find the execution count for this block
@@ -4,7 +4,7 @@ from pydantic import BaseModel

 import backend.api.features.library.model as library_model
 import backend.api.features.store.model as store_model
-from backend.data.block import BlockInfo
+from backend.blocks._base import BlockInfo
 from backend.integrations.providers import ProviderName
 from backend.util.models import Pagination
@@ -15,7 +15,7 @@ FilterType = Literal[
     "my_agents",
 ]

-BlockType = Literal["all", "input", "action", "output"]
+BlockTypeFilter = Literal["all", "input", "action", "output"]


 class SearchEntry(BaseModel):
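Aside: the rename avoids shadowing `backend.blocks._base.BlockType`; the `Literal` alias itself is what makes FastAPI validate the query value. A minimal sketch:

```python
from typing import Annotated, Literal

import fastapi

BlockTypeFilter = Literal["all", "input", "action", "output"]

app = fastapi.FastAPI()


@app.get("/blocks")
async def get_blocks(
    type: Annotated[BlockTypeFilter | None, fastapi.Query()] = None,
) -> dict:
    # FastAPI returns 422 for e.g. ?type=bogus thanks to the Literal constraint
    return {"type": type}
```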
@@ -88,7 +88,7 @@ async def get_block_categories(
 )
 async def get_blocks(
     category: Annotated[str | None, fastapi.Query()] = None,
-    type: Annotated[builder_model.BlockType | None, fastapi.Query()] = None,
+    type: Annotated[builder_model.BlockTypeFilter | None, fastapi.Query()] = None,
     provider: Annotated[ProviderName | None, fastapi.Query()] = None,
     page: Annotated[int, fastapi.Query()] = 1,
     page_size: Annotated[int, fastapi.Query()] = 50,
@@ -2,7 +2,7 @@ import asyncio
 import logging
 import uuid
 from datetime import UTC, datetime
-from typing import Any
+from typing import Any, cast
 from weakref import WeakValueDictionary

 from openai.types.chat import (
@@ -104,6 +104,26 @@ class ChatSession(BaseModel):
     successful_agent_runs: dict[str, int] = {}
     successful_agent_schedules: dict[str, int] = {}

+    def add_tool_call_to_current_turn(self, tool_call: dict) -> None:
+        """Attach a tool_call to the current turn's assistant message.
+
+        Searches backwards for the most recent assistant message (stopping at
+        any user message boundary). If found, appends the tool_call to it.
+        Otherwise creates a new assistant message with the tool_call.
+        """
+        for msg in reversed(self.messages):
+            if msg.role == "user":
+                break
+            if msg.role == "assistant":
+                if not msg.tool_calls:
+                    msg.tool_calls = []
+                msg.tool_calls.append(tool_call)
+                return
+
+        self.messages.append(
+            ChatMessage(role="assistant", content="", tool_calls=[tool_call])
+        )
+
     @staticmethod
     def new(user_id: str) -> "ChatSession":
         return ChatSession(
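Aside: the backward walk stops at the first user message, so a tool call can only ever attach to the *current* turn. A self-contained sketch of that behavior, with a hypothetical dataclass standing in for `ChatMessage`:

```python
from dataclasses import dataclass


@dataclass
class Msg:  # hypothetical stand-in for ChatMessage
    role: str
    content: str = ""
    tool_calls: list | None = None


def add_tool_call(messages: list[Msg], tool_call: dict) -> None:
    # Walk backwards; stopping at a user message keeps earlier turns untouched
    for msg in reversed(messages):
        if msg.role == "user":
            break
        if msg.role == "assistant":
            if not msg.tool_calls:
                msg.tool_calls = []
            msg.tool_calls.append(tool_call)
            return
    messages.append(Msg(role="assistant", tool_calls=[tool_call]))


history = [Msg("user", "hi"), Msg("assistant", "calling a tool")]
add_tool_call(history, {"id": "tc1"})
assert history[1].tool_calls == [{"id": "tc1"}]
```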
@@ -172,6 +192,47 @@ class ChatSession(BaseModel):
             successful_agent_schedules=successful_agent_schedules,
         )

+    @staticmethod
+    def _merge_consecutive_assistant_messages(
+        messages: list[ChatCompletionMessageParam],
+    ) -> list[ChatCompletionMessageParam]:
+        """Merge consecutive assistant messages into single messages.
+
+        Long-running tool flows can create split assistant messages: one with
+        text content and another with tool_calls. Anthropic's API requires
+        tool_result blocks to reference a tool_use in the immediately preceding
+        assistant message, so these splits cause 400 errors via OpenRouter.
+        """
+        if len(messages) < 2:
+            return messages
+
+        result: list[ChatCompletionMessageParam] = [messages[0]]
+        for msg in messages[1:]:
+            prev = result[-1]
+            if prev.get("role") != "assistant" or msg.get("role") != "assistant":
+                result.append(msg)
+                continue
+
+            prev = cast(ChatCompletionAssistantMessageParam, prev)
+            curr = cast(ChatCompletionAssistantMessageParam, msg)
+
+            curr_content = curr.get("content") or ""
+            if curr_content:
+                prev_content = prev.get("content") or ""
+                prev["content"] = (
+                    f"{prev_content}\n{curr_content}" if prev_content else curr_content
+                )
+
+            curr_tool_calls = curr.get("tool_calls")
+            if curr_tool_calls:
+                prev_tool_calls = prev.get("tool_calls")
+                prev["tool_calls"] = (
+                    list(prev_tool_calls) + list(curr_tool_calls)
+                    if prev_tool_calls
+                    else list(curr_tool_calls)
+                )
+        return result
+
     def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
         messages = []
         for message in self.messages:
@@ -258,7 +319,7 @@ class ChatSession(BaseModel):
|
|||||||
name=message.name or "",
|
name=message.name or "",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return messages
|
return self._merge_consecutive_assistant_messages(messages)
|
||||||
|
|
||||||
|
|
||||||
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
||||||
|
|||||||
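To see why the merge matters: Anthropic-style APIs require a tool result to follow the assistant message that carries the matching tool_use. A hedged before/after sketch using plain dicts, as the tests in the next file do:

    tc = {"id": "tc1", "type": "function", "function": {"name": "f", "arguments": "{}"}}
    before = [
        {"role": "assistant", "content": "Let me build that"},
        {"role": "assistant", "content": "", "tool_calls": [tc]},
        {"role": "tool", "content": "ok", "tool_call_id": "tc1"},
    ]
    after = ChatSession._merge_consecutive_assistant_messages(before)  # type: ignore[arg-type]
    # after == [
    #     {"role": "assistant", "content": "Let me build that", "tool_calls": [tc]},
    #     {"role": "tool", "content": "ok", "tool_call_id": "tc1"},
    # ]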
@@ -1,4 +1,16 @@
+from typing import cast
+
 import pytest
+from openai.types.chat import (
+    ChatCompletionAssistantMessageParam,
+    ChatCompletionMessageParam,
+    ChatCompletionToolMessageParam,
+    ChatCompletionUserMessageParam,
+)
+from openai.types.chat.chat_completion_message_tool_call_param import (
+    ChatCompletionMessageToolCallParam,
+    Function,
+)
 
 from .model import (
     ChatMessage,
@@ -117,3 +129,205 @@ async def test_chatsession_db_storage(setup_test_user, test_user_id):
             loaded.tool_calls is not None
         ), f"Tool calls missing for {orig.role} message"
         assert len(orig.tool_calls) == len(loaded.tool_calls)
+
+
+# --------------------------------------------------------------------------- #
+#                  _merge_consecutive_assistant_messages                       #
+# --------------------------------------------------------------------------- #
+
+_tc = ChatCompletionMessageToolCallParam(
+    id="tc1", type="function", function=Function(name="do_stuff", arguments="{}")
+)
+_tc2 = ChatCompletionMessageToolCallParam(
+    id="tc2", type="function", function=Function(name="other", arguments="{}")
+)
+
+
+def test_merge_noop_when_no_consecutive_assistants():
+    """Messages without consecutive assistants are returned unchanged."""
+    msgs = [
+        ChatCompletionUserMessageParam(role="user", content="hi"),
+        ChatCompletionAssistantMessageParam(role="assistant", content="hello"),
+        ChatCompletionUserMessageParam(role="user", content="bye"),
+    ]
+    merged = ChatSession._merge_consecutive_assistant_messages(msgs)
+    assert len(merged) == 3
+    assert [m["role"] for m in merged] == ["user", "assistant", "user"]
+
+
+def test_merge_splits_text_and_tool_calls():
+    """The exact bug scenario: text-only assistant followed by tool_calls-only assistant."""
+    msgs = [
+        ChatCompletionUserMessageParam(role="user", content="build agent"),
+        ChatCompletionAssistantMessageParam(
+            role="assistant", content="Let me build that"
+        ),
+        ChatCompletionAssistantMessageParam(
+            role="assistant", content="", tool_calls=[_tc]
+        ),
+        ChatCompletionToolMessageParam(role="tool", content="ok", tool_call_id="tc1"),
+    ]
+    merged = ChatSession._merge_consecutive_assistant_messages(msgs)
+
+    assert len(merged) == 3
+    assert merged[0]["role"] == "user"
+    assert merged[2]["role"] == "tool"
+    a = cast(ChatCompletionAssistantMessageParam, merged[1])
+    assert a["role"] == "assistant"
+    assert a.get("content") == "Let me build that"
+    assert a.get("tool_calls") == [_tc]
+
+
+def test_merge_combines_tool_calls_from_both():
+    """Both consecutive assistants have tool_calls — they get merged."""
+    msgs: list[ChatCompletionAssistantMessageParam] = [
+        ChatCompletionAssistantMessageParam(
+            role="assistant", content="text", tool_calls=[_tc]
+        ),
+        ChatCompletionAssistantMessageParam(
+            role="assistant", content="", tool_calls=[_tc2]
+        ),
+    ]
+    merged = ChatSession._merge_consecutive_assistant_messages(msgs)  # type: ignore[arg-type]
+
+    assert len(merged) == 1
+    a = cast(ChatCompletionAssistantMessageParam, merged[0])
+    assert a.get("tool_calls") == [_tc, _tc2]
+    assert a.get("content") == "text"
+
+
+def test_merge_three_consecutive_assistants():
+    """Three consecutive assistants collapse into one."""
+    msgs: list[ChatCompletionAssistantMessageParam] = [
+        ChatCompletionAssistantMessageParam(role="assistant", content="a"),
+        ChatCompletionAssistantMessageParam(role="assistant", content="b"),
+        ChatCompletionAssistantMessageParam(
+            role="assistant", content="", tool_calls=[_tc]
+        ),
+    ]
+    merged = ChatSession._merge_consecutive_assistant_messages(msgs)  # type: ignore[arg-type]
+
+    assert len(merged) == 1
+    a = cast(ChatCompletionAssistantMessageParam, merged[0])
+    assert a.get("content") == "a\nb"
+    assert a.get("tool_calls") == [_tc]
+
+
+def test_merge_empty_and_single_message():
+    """Edge cases: empty list and single message."""
+    assert ChatSession._merge_consecutive_assistant_messages([]) == []
+
+    single: list[ChatCompletionMessageParam] = [
+        ChatCompletionUserMessageParam(role="user", content="hi")
+    ]
+    assert ChatSession._merge_consecutive_assistant_messages(single) == single
+
+
+# --------------------------------------------------------------------------- #
+#                       add_tool_call_to_current_turn                          #
+# --------------------------------------------------------------------------- #
+
+_raw_tc = {
+    "id": "tc1",
+    "type": "function",
+    "function": {"name": "f", "arguments": "{}"},
+}
+_raw_tc2 = {
+    "id": "tc2",
+    "type": "function",
+    "function": {"name": "g", "arguments": "{}"},
+}
+
+
+def test_add_tool_call_appends_to_existing_assistant():
+    """When the last assistant is from the current turn, tool_call is added to it."""
+    session = ChatSession.new(user_id="u")
+    session.messages = [
+        ChatMessage(role="user", content="hi"),
+        ChatMessage(role="assistant", content="working on it"),
+    ]
+    session.add_tool_call_to_current_turn(_raw_tc)
+
+    assert len(session.messages) == 2  # no new message created
+    assert session.messages[1].tool_calls == [_raw_tc]
+
+
+def test_add_tool_call_creates_assistant_when_none_exists():
+    """When there's no current-turn assistant, a new one is created."""
+    session = ChatSession.new(user_id="u")
+    session.messages = [
+        ChatMessage(role="user", content="hi"),
+    ]
+    session.add_tool_call_to_current_turn(_raw_tc)
+
+    assert len(session.messages) == 2
+    assert session.messages[1].role == "assistant"
+    assert session.messages[1].tool_calls == [_raw_tc]
+
+
+def test_add_tool_call_does_not_cross_user_boundary():
+    """A user message acts as a boundary — previous assistant is not modified."""
+    session = ChatSession.new(user_id="u")
+    session.messages = [
+        ChatMessage(role="assistant", content="old turn"),
+        ChatMessage(role="user", content="new message"),
+    ]
+    session.add_tool_call_to_current_turn(_raw_tc)
+
+    assert len(session.messages) == 3  # new assistant was created
+    assert session.messages[0].tool_calls is None  # old assistant untouched
+    assert session.messages[2].role == "assistant"
+    assert session.messages[2].tool_calls == [_raw_tc]
+
+
+def test_add_tool_call_multiple_times():
+    """Multiple long-running tool calls accumulate on the same assistant."""
+    session = ChatSession.new(user_id="u")
+    session.messages = [
+        ChatMessage(role="user", content="hi"),
+        ChatMessage(role="assistant", content="doing stuff"),
+    ]
+    session.add_tool_call_to_current_turn(_raw_tc)
+    # Simulate a pending tool result in between (like _yield_tool_call does)
+    session.messages.append(
+        ChatMessage(role="tool", content="pending", tool_call_id="tc1")
+    )
+    session.add_tool_call_to_current_turn(_raw_tc2)
+
+    assert len(session.messages) == 3  # user, assistant, tool — no extra assistant
+    assert session.messages[1].tool_calls == [_raw_tc, _raw_tc2]
+
+
+def test_to_openai_messages_merges_split_assistants():
+    """End-to-end: session with split assistants produces valid OpenAI messages."""
+    session = ChatSession.new(user_id="u")
+    session.messages = [
+        ChatMessage(role="user", content="build agent"),
+        ChatMessage(role="assistant", content="Let me build that"),
+        ChatMessage(
+            role="assistant",
+            content="",
+            tool_calls=[
+                {
+                    "id": "tc1",
+                    "type": "function",
+                    "function": {"name": "create_agent", "arguments": "{}"},
+                }
+            ],
+        ),
+        ChatMessage(role="tool", content="done", tool_call_id="tc1"),
+        ChatMessage(role="assistant", content="Saved!"),
+        ChatMessage(role="user", content="show me an example run"),
+    ]
+    openai_msgs = session.to_openai_messages()
+
+    # The two consecutive assistants at index 1,2 should be merged
+    roles = [m["role"] for m in openai_msgs]
+    assert roles == ["user", "assistant", "tool", "assistant", "user"]
+
+    # The merged assistant should have both content and tool_calls
+    merged = cast(ChatCompletionAssistantMessageParam, openai_msgs[1])
+    assert merged.get("content") == "Let me build that"
+    tc_list = merged.get("tool_calls")
+    assert tc_list is not None and len(list(tc_list)) == 1
+    assert list(tc_list)[0]["id"] == "tc1"
@@ -10,6 +10,8 @@ from typing import Any
 
 from pydantic import BaseModel, Field
 
+from backend.util.json import dumps as json_dumps
+
 
 class ResponseType(str, Enum):
     """Types of streaming responses following AI SDK protocol."""
@@ -193,6 +195,18 @@ class StreamError(StreamBaseResponse):
         default=None, description="Additional error details"
     )
 
+    def to_sse(self) -> str:
+        """Convert to SSE format, only emitting fields required by AI SDK protocol.
+
+        The AI SDK uses z.strictObject({type, errorText}) which rejects
+        any extra fields like `code` or `details`.
+        """
+        data = {
+            "type": self.type.value,
+            "errorText": self.errorText,
+        }
+        return f"data: {json_dumps(data)}\n\n"
+
 
 class StreamHeartbeat(StreamBaseResponse):
     """Heartbeat to keep SSE connection alive during long-running operations.
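The resulting wire frame is easiest to see by example. A hedged sketch, assuming ResponseType's error member serializes to "error", that json_dumps preserves insertion order, and that errorText is the only required constructor argument:

    err = StreamError(errorText="Upstream model returned 400")
    print(err.to_sse())
    # data: {"type": "error", "errorText": "Upstream model returned 400"}
    #
    # Extra model fields such as code and details are deliberately omitted,
    # so the AI SDK's strictObject({type, errorText}) parser accepts the frame.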
@@ -800,9 +800,13 @@ async def stream_chat_completion(
     # Build the messages list in the correct order
     messages_to_save: list[ChatMessage] = []
 
-    # Add assistant message with tool_calls if any
+    # Add assistant message with tool_calls if any.
+    # Use extend (not assign) to preserve tool_calls already added by
+    # _yield_tool_call for long-running tools.
     if accumulated_tool_calls:
-        assistant_response.tool_calls = accumulated_tool_calls
+        if not assistant_response.tool_calls:
+            assistant_response.tool_calls = []
+        assistant_response.tool_calls.extend(accumulated_tool_calls)
         logger.info(
             f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
         )
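Why extend instead of assign: a hedged stand-alone illustration (the names are stand-ins, not from the diff):

    early = [{"id": "tc1"}]        # attached earlier in the turn by _yield_tool_call
    accumulated = [{"id": "tc2"}]  # gathered while streaming the final response

    tool_calls = list(early)
    tool_calls = accumulated        # buggy assignment: "tc1" is silently dropped

    tool_calls = list(early)
    tool_calls.extend(accumulated)  # the fix: ["tc1", "tc2"], both calls preserved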
@@ -1404,13 +1408,9 @@ async def _yield_tool_call(
         operation_id=operation_id,
     )
 
-    # Save assistant message with tool_call FIRST (required by LLM)
-    assistant_message = ChatMessage(
-        role="assistant",
-        content="",
-        tool_calls=[tool_calls[yield_idx]],
-    )
-    session.messages.append(assistant_message)
+    # Attach the tool_call to the current turn's assistant message
+    # (or create one if this is a tool-only response with no text).
+    session.add_tool_call_to_current_turn(tool_calls[yield_idx])
 
     # Then save pending tool result
     pending_message = ChatMessage(
@@ -13,7 +13,8 @@ from backend.api.features.chat.tools.models import (
     NoResultsResponse,
 )
 from backend.api.features.store.hybrid_search import unified_hybrid_search
-from backend.data.block import BlockType, get_block
+from backend.blocks import get_block
+from backend.blocks._base import BlockType
 
 logger = logging.getLogger(__name__)
 
@@ -10,7 +10,7 @@ from backend.api.features.chat.tools.find_block import (
     FindBlockTool,
 )
 from backend.api.features.chat.tools.models import BlockListResponse
-from backend.data.block import BlockType
+from backend.blocks._base import BlockType
 
 from ._test_data import make_session
 
@@ -12,7 +12,8 @@ from backend.api.features.chat.tools.find_block import (
     COPILOT_EXCLUDED_BLOCK_IDS,
     COPILOT_EXCLUDED_BLOCK_TYPES,
 )
-from backend.data.block import AnyBlockSchema, get_block
+from backend.blocks import get_block
+from backend.blocks._base import AnyBlockSchema
 from backend.data.execution import ExecutionContext
 from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
 from backend.data.workspace import get_or_create_workspace
@@ -6,7 +6,7 @@ import pytest
 
 from backend.api.features.chat.tools.models import ErrorResponse
 from backend.api.features.chat.tools.run_block import RunBlockTool
-from backend.data.block import BlockType
+from backend.blocks._base import BlockType
 
 from ._test_data import make_session
 
@@ -12,12 +12,11 @@ import backend.api.features.store.image_gen as store_image_gen
 import backend.api.features.store.media as store_media
 import backend.data.graph as graph_db
 import backend.data.integrations as integrations_db
-from backend.data.block import BlockInput
 from backend.data.db import transaction
 from backend.data.execution import get_graph_execution
 from backend.data.graph import GraphSettings
 from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
-from backend.data.model import CredentialsMetaInput
+from backend.data.model import CredentialsMetaInput, GraphInput
 from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.integrations.webhooks.graph_lifecycle_hooks import (
     on_graph_activate,
@@ -1130,7 +1129,7 @@ async def create_preset_from_graph_execution(
 async def update_preset(
     user_id: str,
     preset_id: str,
-    inputs: Optional[BlockInput] = None,
+    inputs: Optional[GraphInput] = None,
     credentials: Optional[dict[str, CredentialsMetaInput]] = None,
     name: Optional[str] = None,
     description: Optional[str] = None,
@@ -6,9 +6,12 @@ import prisma.enums
 import prisma.models
 import pydantic
 
-from backend.data.block import BlockInput
 from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo
-from backend.data.model import CredentialsMetaInput, is_credentials_field_name
+from backend.data.model import (
+    CredentialsMetaInput,
+    GraphInput,
+    is_credentials_field_name,
+)
 from backend.util.json import loads as json_loads
 from backend.util.models import Pagination
 
@@ -323,7 +326,7 @@ class LibraryAgentPresetCreatable(pydantic.BaseModel):
     graph_id: str
     graph_version: int
 
-    inputs: BlockInput
+    inputs: GraphInput
     credentials: dict[str, CredentialsMetaInput]
 
     name: str
@@ -352,7 +355,7 @@ class LibraryAgentPresetUpdatable(pydantic.BaseModel):
     Request model used when updating a preset for a library agent.
     """
 
-    inputs: Optional[BlockInput] = None
+    inputs: Optional[GraphInput] = None
    credentials: Optional[dict[str, CredentialsMetaInput]] = None
 
     name: Optional[str] = None
@@ -395,7 +398,7 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
                 "Webhook must be included in AgentPreset query when webhookId is set"
             )
 
-        input_data: BlockInput = {}
+        input_data: GraphInput = {}
         input_credentials: dict[str, CredentialsMetaInput] = {}
 
         for preset_input in preset.InputPresets:
@@ -5,8 +5,8 @@ from typing import Optional
 import aiohttp
 from fastapi import HTTPException
 
+from backend.blocks import get_block
 from backend.data import graph as graph_db
-from backend.data.block import get_block
 from backend.util.settings import Settings
 
 from .models import ApiResponse, ChatRequest, GraphData
@@ -152,7 +152,7 @@ class BlockHandler(ContentHandler):
 
     async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
         """Fetch blocks without embeddings."""
-        from backend.data.block import get_blocks
+        from backend.blocks import get_blocks
 
         # Get all available blocks
         all_blocks = get_blocks()
@@ -249,7 +249,7 @@ class BlockHandler(ContentHandler):
 
     async def get_stats(self) -> dict[str, int]:
         """Get statistics about block embedding coverage."""
-        from backend.data.block import get_blocks
+        from backend.blocks import get_blocks
 
         all_blocks = get_blocks()
 
@@ -93,7 +93,7 @@ async def test_block_handler_get_missing_items(mocker):
     mock_existing = []
 
     with patch(
-        "backend.data.block.get_blocks",
+        "backend.blocks.get_blocks",
         return_value=mock_blocks,
     ):
         with patch(
@@ -135,7 +135,7 @@ async def test_block_handler_get_stats(mocker):
     mock_embedded = [{"count": 2}]
 
     with patch(
-        "backend.data.block.get_blocks",
+        "backend.blocks.get_blocks",
         return_value=mock_blocks,
     ):
         with patch(
@@ -327,7 +327,7 @@ async def test_block_handler_handles_missing_attributes():
     mock_blocks = {"block-minimal": mock_block_class}
 
     with patch(
-        "backend.data.block.get_blocks",
+        "backend.blocks.get_blocks",
         return_value=mock_blocks,
     ):
         with patch(
@@ -360,7 +360,7 @@ async def test_block_handler_skips_failed_blocks():
     mock_blocks = {"good-block": good_block, "bad-block": bad_block}
 
     with patch(
-        "backend.data.block.get_blocks",
+        "backend.blocks.get_blocks",
         return_value=mock_blocks,
     ):
         with patch(
@@ -662,7 +662,7 @@ async def cleanup_orphaned_embeddings() -> dict[str, Any]:
         )
         current_ids = {row["id"] for row in valid_agents}
     elif content_type == ContentType.BLOCK:
-        from backend.data.block import get_blocks
+        from backend.blocks import get_blocks
 
         current_ids = set(get_blocks().keys())
     elif content_type == ContentType.DOCUMENTATION:
@@ -7,15 +7,6 @@ from replicate.client import Client as ReplicateClient
 from replicate.exceptions import ReplicateError
 from replicate.helpers import FileOutput
 
-from backend.blocks.ideogram import (
-    AspectRatio,
-    ColorPalettePreset,
-    IdeogramModelBlock,
-    IdeogramModelName,
-    MagicPromptOption,
-    StyleType,
-    UpscaleOption,
-)
 from backend.data.graph import GraphBaseMeta
 from backend.data.model import CredentialsMetaInput, ProviderName
 from backend.integrations.credentials_store import ideogram_credentials
@@ -50,6 +41,16 @@ async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.Bytes
     if not ideogram_credentials.api_key:
         raise ValueError("Missing Ideogram API key")
 
+    from backend.blocks.ideogram import (
+        AspectRatio,
+        ColorPalettePreset,
+        IdeogramModelBlock,
+        IdeogramModelName,
+        MagicPromptOption,
+        StyleType,
+        UpscaleOption,
+    )
+
     name = graph.name
     description = f"{name} ({graph.description})" if graph.description else name
 
@@ -393,7 +393,6 @@ async def get_creators(
 @router.get(
     "/creator/{username}",
     summary="Get creator details",
-    operation_id="getV2GetCreatorDetails",
     tags=["store", "public"],
     response_model=store_model.CreatorDetails,
 )
@@ -40,10 +40,11 @@ from backend.api.model import (
     UpdateTimezoneRequest,
     UploadFileResponse,
 )
+from backend.blocks import get_block, get_blocks
 from backend.data import execution as execution_db
 from backend.data import graph as graph_db
 from backend.data.auth import api_key as api_key_db
-from backend.data.block import BlockInput, CompletedBlockOutput, get_block, get_blocks
+from backend.data.block import BlockInput, CompletedBlockOutput
 from backend.data.credit import (
     AutoTopUpConfig,
     RefundRequest,
@@ -18,7 +18,6 @@ from prisma.errors import PrismaError
 
 import backend.api.features.admin.credit_admin_routes
 import backend.api.features.admin.execution_analytics_routes
-import backend.api.features.admin.llm_routes
 import backend.api.features.admin.store_admin_routes
 import backend.api.features.builder
 import backend.api.features.builder.routes
@@ -39,15 +38,13 @@ import backend.data.db
 import backend.data.graph
 import backend.data.user
 import backend.integrations.webhooks.utils
-import backend.server.v2.llm.routes as public_llm_routes
 import backend.util.service
 import backend.util.settings
 from backend.api.features.chat.completion_consumer import (
     start_completion_consumer,
     stop_completion_consumer,
 )
-from backend.data import llm_registry
-from backend.data.block_cost_config import refresh_llm_costs
+from backend.blocks.llm import DEFAULT_LLM_MODEL
 from backend.data.model import Credentials
 from backend.integrations.providers import ProviderName
 from backend.monitoring.instrumentation import instrument_fastapi
@@ -118,27 +115,11 @@ async def lifespan_context(app: fastapi.FastAPI):
 
     AutoRegistry.patch_integrations()
 
-    # Refresh LLM registry before initializing blocks so blocks can use registry data
-    await llm_registry.refresh_llm_registry()
-    refresh_llm_costs()
-
-    # Clear block schema caches so they're regenerated with updated discriminator_mapping
-    from backend.data.block import BlockSchema
-
-    BlockSchema.clear_all_schema_caches()
-
     await backend.data.block.initialize_blocks()
 
     await backend.data.user.migrate_and_encrypt_user_integrations()
     await backend.data.graph.fix_llm_provider_credentials()
-    # migrate_llm_models uses registry default model
-    from backend.blocks.llm import LlmModel
-
-    default_model_slug = llm_registry.get_default_model_slug()
-    if default_model_slug:
-        await backend.data.graph.migrate_llm_models(LlmModel(default_model_slug))
-    else:
-        logger.warning("Skipping LLM model migration: no default model available")
+    await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
     await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
 
     # Start chat completion consumer for Redis Streams notifications
@@ -340,16 +321,6 @@ app.include_router(
     tags=["v2", "executions", "review"],
     prefix="/api/review",
 )
-app.include_router(
-    backend.api.features.admin.llm_routes.router,
-    tags=["v2", "admin", "llm"],
-    prefix="/api/llm/admin",
-)
-app.include_router(
-    public_llm_routes.router,
-    tags=["v2", "llm"],
-    prefix="/api",
-)
 app.include_router(
     backend.api.features.library.routes.router, tags=["v2"], prefix="/api/library"
 )
@@ -79,39 +79,7 @@ async def event_broadcaster(manager: ConnectionManager):
                     payload=notification.payload,
                 )
 
-        async def registry_refresh_worker():
-            """Listen for LLM registry refresh notifications and broadcast to all clients."""
-            from backend.data.llm_registry import REGISTRY_REFRESH_CHANNEL
-            from backend.data.redis_client import connect_async
-
-            redis = await connect_async()
-            pubsub = redis.pubsub()
-            await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
-            logger.info(
-                "Subscribed to LLM registry refresh notifications for WebSocket broadcast"
-            )
-
-            async for message in pubsub.listen():
-                if (
-                    message["type"] == "message"
-                    and message["channel"] == REGISTRY_REFRESH_CHANNEL
-                ):
-                    logger.info(
-                        "Broadcasting LLM registry refresh to all WebSocket clients"
-                    )
-                    await manager.broadcast_to_all(
-                        method=WSMethod.NOTIFICATION,
-                        data={
-                            "type": "LLM_REGISTRY_REFRESH",
-                            "event": "registry_updated",
-                        },
-                    )
-
-        await asyncio.gather(
-            execution_worker(),
-            notification_worker(),
-            registry_refresh_worker(),
-        )
+        await asyncio.gather(execution_worker(), notification_worker())
     finally:
         # Ensure PubSub connections are closed on any exit to prevent leaks
         await execution_bus.close()
@@ -3,22 +3,19 @@ import logging
 import os
 import re
 from pathlib import Path
-from typing import TYPE_CHECKING, TypeVar
+from typing import Sequence, Type, TypeVar
 
+from backend.blocks._base import AnyBlockSchema, BlockType
 from backend.util.cache import cached
 
 logger = logging.getLogger(__name__)
 
-
-if TYPE_CHECKING:
-    from backend.data.block import Block
-
 T = TypeVar("T")
 
 
 @cached(ttl_seconds=3600)
-def load_all_blocks() -> dict[str, type["Block"]]:
-    from backend.data.block import Block
+def load_all_blocks() -> dict[str, type["AnyBlockSchema"]]:
+    from backend.blocks._base import Block
     from backend.util.settings import Config
 
     # Check if example blocks should be loaded from settings
@@ -50,8 +47,8 @@ def load_all_blocks() -> dict[str, type["Block"]]:
         importlib.import_module(f".{module}", package=__name__)
 
     # Load all Block instances from the available modules
-    available_blocks: dict[str, type["Block"]] = {}
-    for block_cls in all_subclasses(Block):
+    available_blocks: dict[str, type["AnyBlockSchema"]] = {}
+    for block_cls in _all_subclasses(Block):
         class_name = block_cls.__name__
 
         if class_name.endswith("Base"):
@@ -64,7 +61,7 @@ def load_all_blocks() -> dict[str, type["Block"]]:
                 "please name the class with 'Base' at the end"
             )
 
-        block = block_cls.create()
+        block = block_cls()  # pyright: ignore[reportAbstractUsage]
 
         if not isinstance(block.id, str) or len(block.id) != 36:
             raise ValueError(
@@ -105,7 +102,7 @@ def load_all_blocks() -> dict[str, type["Block"]]:
         available_blocks[block.id] = block_cls
 
     # Filter out blocks with incomplete auth configs, e.g. missing OAuth server secrets
-    from backend.data.block import is_block_auth_configured
+    from ._utils import is_block_auth_configured
 
     filtered_blocks = {}
     for block_id, block_cls in available_blocks.items():
@@ -115,11 +112,48 @@ def load_all_blocks() -> dict[str, type["Block"]]:
     return filtered_blocks
 
 
-__all__ = ["load_all_blocks"]
-
-
-def all_subclasses(cls: type[T]) -> list[type[T]]:
+def _all_subclasses(cls: type[T]) -> list[type[T]]:
     subclasses = cls.__subclasses__()
     for subclass in subclasses:
-        subclasses += all_subclasses(subclass)
+        subclasses += _all_subclasses(subclass)
     return subclasses
+
+
+# ============== Block access helper functions ============== #
+
+
+def get_blocks() -> dict[str, Type["AnyBlockSchema"]]:
+    return load_all_blocks()
+
+
+# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
+def get_block(block_id: str) -> "AnyBlockSchema | None":
+    cls = get_blocks().get(block_id)
+    return cls() if cls else None
+
+
+@cached(ttl_seconds=3600)
+def get_webhook_block_ids() -> Sequence[str]:
+    return [
+        id
+        for id, B in get_blocks().items()
+        if B().block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
+    ]
+
+
+@cached(ttl_seconds=3600)
+def get_io_block_ids() -> Sequence[str]:
+    return [
+        id
+        for id, B in get_blocks().items()
+        if B().block_type in (BlockType.INPUT, BlockType.OUTPUT)
+    ]
+
+
+@cached(ttl_seconds=3600)
+def get_human_in_the_loop_block_ids() -> Sequence[str]:
+    return [
+        id
+        for id, B in get_blocks().items()
+        if B().block_type == BlockType.HUMAN_IN_THE_LOOP
+    ]
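A hedged usage sketch for the relocated accessors (the UUID is a placeholder; per the loader above, get_block returns a block instance, or None for an unknown id):

    from backend.blocks import get_block, get_blocks, get_webhook_block_ids

    block = get_block("00000000-0000-0000-0000-000000000000")
    if block is not None and not block.disabled:
        print(block.id, block.block_type)

    print(f"{len(get_blocks())} block classes loaded")
    print(get_webhook_block_ids())  # cached for an hour via @cached(ttl_seconds=3600)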
autogpt_platform/backend/backend/blocks/_base.py (new file, 739 lines)
@@ -0,0 +1,739 @@
|
|||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from enum import Enum
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
ClassVar,
|
||||||
|
Generic,
|
||||||
|
Optional,
|
||||||
|
Type,
|
||||||
|
TypeAlias,
|
||||||
|
TypeVar,
|
||||||
|
cast,
|
||||||
|
get_origin,
|
||||||
|
)
|
||||||
|
|
||||||
|
import jsonref
|
||||||
|
import jsonschema
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from backend.data.block import BlockInput, BlockOutput, BlockOutputEntry
|
||||||
|
from backend.data.model import (
|
||||||
|
Credentials,
|
||||||
|
CredentialsFieldInfo,
|
||||||
|
CredentialsMetaInput,
|
||||||
|
SchemaField,
|
||||||
|
is_credentials_field_name,
|
||||||
|
)
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
|
from backend.util import json
|
||||||
|
from backend.util.exceptions import (
|
||||||
|
BlockError,
|
||||||
|
BlockExecutionError,
|
||||||
|
BlockInputError,
|
||||||
|
BlockOutputError,
|
||||||
|
BlockUnknownError,
|
||||||
|
)
|
||||||
|
from backend.util.settings import Config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
from backend.data.model import ContributorDetails, NodeExecutionStats
|
||||||
|
|
||||||
|
from ..data.graph import Link
|
||||||
|
|
||||||
|
app_config = Config()
|
||||||
|
|
||||||
|
|
||||||
|
BlockTestOutput = BlockOutputEntry | tuple[str, Callable[[Any], bool]]
|
||||||
|
|
||||||
|
|
||||||
|
class BlockType(Enum):
|
||||||
|
STANDARD = "Standard"
|
||||||
|
INPUT = "Input"
|
||||||
|
OUTPUT = "Output"
|
||||||
|
NOTE = "Note"
|
||||||
|
WEBHOOK = "Webhook"
|
||||||
|
WEBHOOK_MANUAL = "Webhook (manual)"
|
||||||
|
AGENT = "Agent"
|
||||||
|
AI = "AI"
|
||||||
|
AYRSHARE = "Ayrshare"
|
||||||
|
HUMAN_IN_THE_LOOP = "Human In The Loop"
|
||||||
|
|
||||||
|
|
||||||
|
class BlockCategory(Enum):
|
||||||
|
AI = "Block that leverages AI to perform a task."
|
||||||
|
SOCIAL = "Block that interacts with social media platforms."
|
||||||
|
TEXT = "Block that processes text data."
|
||||||
|
SEARCH = "Block that searches or extracts information from the internet."
|
||||||
|
BASIC = "Block that performs basic operations."
|
||||||
|
INPUT = "Block that interacts with input of the graph."
|
||||||
|
OUTPUT = "Block that interacts with output of the graph."
|
||||||
|
LOGIC = "Programming logic to control the flow of your agent"
|
||||||
|
COMMUNICATION = "Block that interacts with communication platforms."
|
||||||
|
DEVELOPER_TOOLS = "Developer tools such as GitHub blocks."
|
||||||
|
DATA = "Block that interacts with structured data."
|
||||||
|
HARDWARE = "Block that interacts with hardware."
|
||||||
|
AGENT = "Block that interacts with other agents."
|
||||||
|
CRM = "Block that interacts with CRM services."
|
||||||
|
SAFETY = (
|
||||||
|
"Block that provides AI safety mechanisms such as detecting harmful content"
|
||||||
|
)
|
||||||
|
PRODUCTIVITY = "Block that helps with productivity"
|
||||||
|
ISSUE_TRACKING = "Block that helps with issue tracking"
|
||||||
|
MULTIMEDIA = "Block that interacts with multimedia content"
|
||||||
|
MARKETING = "Block that helps with marketing"
|
||||||
|
|
||||||
|
def dict(self) -> dict[str, str]:
|
||||||
|
return {"category": self.name, "description": self.value}
|
||||||
|
|
||||||
|
|
||||||
|
class BlockCostType(str, Enum):
|
||||||
|
RUN = "run" # cost X credits per run
|
||||||
|
BYTE = "byte" # cost X credits per byte
|
||||||
|
SECOND = "second" # cost X credits per second
|
||||||
|
|
||||||
|
|
||||||
|
class BlockCost(BaseModel):
|
||||||
|
cost_amount: int
|
||||||
|
cost_filter: BlockInput
|
||||||
|
cost_type: BlockCostType
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cost_amount: int,
|
||||||
|
cost_type: BlockCostType = BlockCostType.RUN,
|
||||||
|
cost_filter: Optional[BlockInput] = None,
|
||||||
|
**data: Any,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
cost_amount=cost_amount,
|
||||||
|
cost_filter=cost_filter or {},
|
||||||
|
cost_type=cost_type,
|
||||||
|
**data,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BlockInfo(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
inputSchema: dict[str, Any]
|
||||||
|
outputSchema: dict[str, Any]
|
||||||
|
costs: list[BlockCost]
|
||||||
|
description: str
|
||||||
|
categories: list[dict[str, str]]
|
||||||
|
contributors: list[dict[str, Any]]
|
||||||
|
staticOutput: bool
|
||||||
|
uiType: str
|
||||||
|
|
||||||
|
|
||||||
|
class BlockSchema(BaseModel):
|
||||||
|
cached_jsonschema: ClassVar[dict[str, Any]]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def jsonschema(cls) -> dict[str, Any]:
|
||||||
|
if cls.cached_jsonschema:
|
||||||
|
return cls.cached_jsonschema
|
||||||
|
|
||||||
|
model = jsonref.replace_refs(cls.model_json_schema(), merge_props=True)
|
||||||
|
|
||||||
|
def ref_to_dict(obj):
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
# OpenAPI <3.1 does not support sibling fields that has a $ref key
|
||||||
|
# So sometimes, the schema has an "allOf"/"anyOf"/"oneOf" with 1 item.
|
||||||
|
keys = {"allOf", "anyOf", "oneOf"}
|
||||||
|
one_key = next((k for k in keys if k in obj and len(obj[k]) == 1), None)
|
||||||
|
if one_key:
|
||||||
|
obj.update(obj[one_key][0])
|
||||||
|
|
||||||
|
return {
|
||||||
|
key: ref_to_dict(value)
|
||||||
|
for key, value in obj.items()
|
||||||
|
if not key.startswith("$") and key != one_key
|
||||||
|
}
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
return [ref_to_dict(item) for item in obj]
|
||||||
|
|
||||||
|
return obj
|
||||||
|
|
||||||
|
cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
|
||||||
|
|
||||||
|
return cls.cached_jsonschema
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_data(cls, data: BlockInput) -> str | None:
|
||||||
|
return json.validate_with_jsonschema(
|
||||||
|
schema=cls.jsonschema(),
|
||||||
|
data={k: v for k, v in data.items() if v is not None},
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
||||||
|
return cls.validate_data(data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_field_schema(cls, field_name: str) -> dict[str, Any]:
|
||||||
|
model_schema = cls.jsonschema().get("properties", {})
|
||||||
|
if not model_schema:
|
||||||
|
raise ValueError(f"Invalid model schema {cls}")
|
||||||
|
|
||||||
|
property_schema = model_schema.get(field_name)
|
||||||
|
if not property_schema:
|
||||||
|
raise ValueError(f"Invalid property name {field_name}")
|
||||||
|
|
||||||
|
return property_schema
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_field(cls, field_name: str, data: BlockInput) -> str | None:
|
||||||
|
"""
|
||||||
|
Validate the data against a specific property (one of the input/output name).
|
||||||
|
Returns the validation error message if the data does not match the schema.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
property_schema = cls.get_field_schema(field_name)
|
||||||
|
jsonschema.validate(json.to_dict(data), property_schema)
|
||||||
|
return None
|
||||||
|
except jsonschema.ValidationError as e:
|
||||||
|
return str(e)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_fields(cls) -> set[str]:
|
||||||
|
return set(cls.model_fields.keys())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_required_fields(cls) -> set[str]:
|
||||||
|
return {
|
||||||
|
field
|
||||||
|
for field, field_info in cls.model_fields.items()
|
||||||
|
if field_info.is_required()
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __pydantic_init_subclass__(cls, **kwargs):
|
||||||
|
"""Validates the schema definition. Rules:
|
||||||
|
- Fields with annotation `CredentialsMetaInput` MUST be
|
||||||
|
named `credentials` or `*_credentials`
|
||||||
|
- Fields named `credentials` or `*_credentials` MUST be
|
||||||
|
of type `CredentialsMetaInput`
|
||||||
|
"""
|
||||||
|
super().__pydantic_init_subclass__(**kwargs)
|
||||||
|
|
||||||
|
# Reset cached JSON schema to prevent inheriting it from parent class
|
||||||
|
cls.cached_jsonschema = {}
|
||||||
|
|
||||||
|
credentials_fields = cls.get_credentials_fields()
|
||||||
|
|
||||||
|
for field_name in cls.get_fields():
|
||||||
|
if is_credentials_field_name(field_name):
|
||||||
|
if field_name not in credentials_fields:
|
||||||
|
raise TypeError(
|
||||||
|
f"Credentials field '{field_name}' on {cls.__qualname__} "
|
||||||
|
f"is not of type {CredentialsMetaInput.__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
CredentialsMetaInput.validate_credentials_field_schema(
|
||||||
|
cls.get_field_schema(field_name), field_name
|
||||||
|
)
|
||||||
|
|
||||||
|
elif field_name in credentials_fields:
|
||||||
|
raise KeyError(
|
||||||
|
f"Credentials field '{field_name}' on {cls.__qualname__} "
|
||||||
|
"has invalid name: must be 'credentials' or *_credentials"
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_credentials_fields(cls) -> dict[str, type[CredentialsMetaInput]]:
|
||||||
|
return {
|
||||||
|
field_name: info.annotation
|
||||||
|
for field_name, info in cls.model_fields.items()
|
||||||
|
if (
|
||||||
|
inspect.isclass(info.annotation)
|
||||||
|
and issubclass(
|
||||||
|
get_origin(info.annotation) or info.annotation,
|
||||||
|
CredentialsMetaInput,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_auto_credentials_fields(cls) -> dict[str, dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get fields that have auto_credentials metadata (e.g., GoogleDriveFileInput).
|
||||||
|
|
||||||
|
Returns a dict mapping kwarg_name -> {field_name, auto_credentials_config}
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If multiple fields have the same kwarg_name, as this would
|
||||||
|
cause silent overwriting and only the last field would be processed.
|
||||||
|
"""
|
||||||
|
result: dict[str, dict[str, Any]] = {}
|
||||||
|
schema = cls.jsonschema()
|
||||||
|
properties = schema.get("properties", {})
|
||||||
|
|
||||||
|
for field_name, field_schema in properties.items():
|
||||||
|
auto_creds = field_schema.get("auto_credentials")
|
||||||
|
if auto_creds:
|
||||||
|
kwarg_name = auto_creds.get("kwarg_name", "credentials")
|
||||||
|
if kwarg_name in result:
|
||||||
|
raise ValueError(
|
||||||
|
f"Duplicate auto_credentials kwarg_name '{kwarg_name}' "
|
||||||
|
f"in fields '{result[kwarg_name]['field_name']}' and "
|
||||||
|
f"'{field_name}' on {cls.__qualname__}"
|
||||||
|
)
|
||||||
|
result[kwarg_name] = {
|
||||||
|
"field_name": field_name,
|
||||||
|
"config": auto_creds,
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_credentials_fields_info(cls) -> dict[str, CredentialsFieldInfo]:
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
# Regular credentials fields
|
||||||
|
for field_name in cls.get_credentials_fields().keys():
|
||||||
|
result[field_name] = CredentialsFieldInfo.model_validate(
|
||||||
|
cls.get_field_schema(field_name), by_alias=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Auto-generated credentials fields (from GoogleDriveFileInput etc.)
|
||||||
|
for kwarg_name, info in cls.get_auto_credentials_fields().items():
|
||||||
|
config = info["config"]
|
||||||
|
# Build a schema-like dict that CredentialsFieldInfo can parse
|
||||||
|
auto_schema = {
|
||||||
|
"credentials_provider": [config.get("provider", "google")],
|
||||||
|
"credentials_types": [config.get("type", "oauth2")],
|
||||||
|
"credentials_scopes": config.get("scopes"),
|
||||||
|
}
|
||||||
|
result[kwarg_name] = CredentialsFieldInfo.model_validate(
|
||||||
|
auto_schema, by_alias=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
|
||||||
|
return data # Return as is, by default.
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
|
||||||
|
input_fields_from_nodes = {link.sink_name for link in links}
|
||||||
|
return input_fields_from_nodes - set(data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
||||||
|
return cls.get_required_fields() - set(data)
|
||||||
|
|
||||||
|
|
||||||
|
class BlockSchemaInput(BlockSchema):
|
||||||
|
"""
|
||||||
|
Base schema class for block inputs.
|
||||||
|
All block input schemas should extend this class for consistency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BlockSchemaOutput(BlockSchema):
|
||||||
|
"""
|
||||||
|
Base schema class for block outputs that includes a standard error field.
|
||||||
|
All block output schemas should extend this class to ensure consistent error handling.
|
||||||
|
"""
|
||||||
|
|
||||||
|
error: str = SchemaField(
|
||||||
|
description="Error message if the operation failed", default=""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
BlockSchemaInputType = TypeVar("BlockSchemaInputType", bound=BlockSchemaInput)
|
||||||
|
BlockSchemaOutputType = TypeVar("BlockSchemaOutputType", bound=BlockSchemaOutput)
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyInputSchema(BlockSchemaInput):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyOutputSchema(BlockSchemaOutput):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# For backward compatibility - will be deprecated
|
||||||
|
EmptySchema = EmptyOutputSchema
|
||||||
|
|
||||||
|
|
||||||
|
# --8<-- [start:BlockWebhookConfig]
|
||||||
|
class BlockManualWebhookConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
Configuration model for webhook-triggered blocks on which
|
||||||
|
the user has to manually set up the webhook at the provider.
|
||||||
|
"""
|
||||||
|
|
||||||
|
provider: ProviderName
|
||||||
|
"""The service provider that the webhook connects to"""
|
||||||
|
|
||||||
|
webhook_type: str
|
||||||
|
"""
|
||||||
|
Identifier for the webhook type. E.g. GitHub has repo and organization level hooks.
|
||||||
|
|
||||||
|
Only for use in the corresponding `WebhooksManager`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
event_filter_input: str = ""
|
||||||
|
"""
|
||||||
|
Name of the block's event filter input.
|
||||||
|
Leave empty if the corresponding webhook doesn't have distinct event/payload types.
|
||||||
|
"""
|
||||||
|
|
||||||
|
event_format: str = "{event}"
|
||||||
|
"""
|
||||||
|
Template string for the event(s) that a block instance subscribes to.
|
||||||
|
Applied individually to each event selected in the event filter input.
|
||||||
|
|
||||||
|
Example: `"pull_request.{event}"` -> `"pull_request.opened"`
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class BlockWebhookConfig(BlockManualWebhookConfig):
|
||||||
|
"""
|
||||||
|
Configuration model for webhook-triggered blocks for which
|
||||||
|
the webhook can be automatically set up through the provider's API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
resource_format: str
|
||||||
|
"""
|
||||||
|
Template string for the resource that a block instance subscribes to.
|
||||||
|
Fields will be filled from the block's inputs (except `payload`).
|
||||||
|
|
||||||
|
Example: `f"{repo}/pull_requests"` (note: not how it's actually implemented)
|
||||||
|
|
||||||
|
Only for use in the corresponding `WebhooksManager`.
|
||||||
|
"""
|
||||||
|
# --8<-- [end:BlockWebhookConfig]
|
||||||
|
|
||||||
|
|
||||||
|
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
    def __init__(
        self,
        id: str = "",
        description: str = "",
        contributors: list["ContributorDetails"] = [],
        categories: set[BlockCategory] | None = None,
        input_schema: Type[BlockSchemaInputType] = EmptyInputSchema,
        output_schema: Type[BlockSchemaOutputType] = EmptyOutputSchema,
        test_input: BlockInput | list[BlockInput] | None = None,
        test_output: BlockTestOutput | list[BlockTestOutput] | None = None,
        test_mock: dict[str, Any] | None = None,
        test_credentials: Optional[Credentials | dict[str, Credentials]] = None,
        disabled: bool = False,
        static_output: bool = False,
        block_type: BlockType = BlockType.STANDARD,
        webhook_config: Optional[BlockWebhookConfig | BlockManualWebhookConfig] = None,
        is_sensitive_action: bool = False,
    ):
        """
        Initialize the block with the given schema.

        Args:
            id: The unique identifier for the block; this value will be persisted in
                the DB, so it should be unique and constant across application runs.
                Use the UUID format for the ID.
            description: The description of the block, explaining what the block does.
            contributors: The list of contributors who contributed to the block.
            input_schema: The schema, defined as a Pydantic model, for the input data.
            output_schema: The schema, defined as a Pydantic model, for the output data.
            test_input: The list or single sample input data for the block, for testing.
            test_output: The list or single expected output if the test_input is run.
            test_mock: Function names on the block implementation to mock on test run.
            disabled: If the block is disabled, it will not be available for execution.
            static_output: Whether the output links of the block are static by default.
        """
        from backend.data.model import NodeExecutionStats

        self.id = id
        self.input_schema = input_schema
        self.output_schema = output_schema
        self.test_input = test_input
        self.test_output = test_output
        self.test_mock = test_mock
        self.test_credentials = test_credentials
        self.description = description
        self.categories = categories or set()
        self.contributors = contributors or set()
        self.disabled = disabled
        self.static_output = static_output
        self.block_type = block_type
        self.webhook_config = webhook_config
        self.is_sensitive_action = is_sensitive_action
        self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()

        if self.webhook_config:
            if isinstance(self.webhook_config, BlockWebhookConfig):
                # Enforce presence of credentials field on auto-setup webhook blocks
                if not (cred_fields := self.input_schema.get_credentials_fields()):
                    raise TypeError(
                        "credentials field is required on auto-setup webhook blocks"
                    )
                # Disallow multiple credentials inputs on webhook blocks
                elif len(cred_fields) > 1:
                    raise ValueError(
                        "Multiple credentials inputs not supported on webhook blocks"
                    )

                self.block_type = BlockType.WEBHOOK
            else:
                self.block_type = BlockType.WEBHOOK_MANUAL

            # Enforce shape of webhook event filter, if present
            if self.webhook_config.event_filter_input:
                event_filter_field = self.input_schema.model_fields[
                    self.webhook_config.event_filter_input
                ]
                if not (
                    isinstance(event_filter_field.annotation, type)
                    and issubclass(event_filter_field.annotation, BaseModel)
                    and all(
                        field.annotation is bool
                        for field in event_filter_field.annotation.model_fields.values()
                    )
                ):
                    raise NotImplementedError(
                        f"{self.name} has an invalid webhook event selector: "
                        "field must be a BaseModel and all its fields must be boolean"
                    )

            # Enforce presence of 'payload' input
            if "payload" not in self.input_schema.model_fields:
                raise TypeError(
                    f"{self.name} is webhook-triggered but has no 'payload' input"
                )

            # Disable webhook-triggered block if webhook functionality not available
            if not app_config.platform_base_url:
                self.disabled = True

    @abstractmethod
    async def run(self, input_data: BlockSchemaInputType, **kwargs) -> BlockOutput:
        """
        Run the block with the given input data.
        Args:
            input_data: The input data with the structure of input_schema.

        Kwargs: as of 14/02/2025, these include
            graph_id: The ID of the graph.
            node_id: The ID of the node.
            graph_exec_id: The ID of the graph execution.
            node_exec_id: The ID of the node execution.
            user_id: The ID of the user.

        Returns:
            A Generator that yields (output_name, output_data).
            output_name: One of the output names defined in the Block's output_schema.
            output_data: The data for the output_name, matching the defined schema.
        """
        # --- satisfy the type checker, never executed -------------
        if False:  # noqa: SIM115
            yield "name", "value"  # pyright: ignore[reportMissingYield]
        raise NotImplementedError(f"{self.name} does not implement the run method.")

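    # Illustrative sketch (not part of the diff): the smallest concrete `run`
    # implementation, reusing the hypothetical schemas sketched earlier. Note
    # that `run` is an async generator yielding (output_name, output_data):
    #
    #   class GreeterBlock(Block[GreetingInput, GreetingOutput]):
    #       def __init__(self):
    #           super().__init__(
    #               id="00000000-0000-0000-0000-000000000000",  # placeholder UUID
    #               description="Greets the given person",
    #               input_schema=GreetingInput,
    #               output_schema=GreetingOutput,
    #           )
    #
    #       async def run(self, input_data: GreetingInput, **kwargs) -> BlockOutput:
    #           yield "greeting", f"Hello, {input_data.name}!"
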
    async def run_once(
        self, input_data: BlockSchemaInputType, output: str, **kwargs
    ) -> Any:
        async for item in self.run(input_data, **kwargs):
            name, data = item
            if name == output:
                return data
        raise ValueError(f"{self.name} did not produce any output for {output}")

    def merge_stats(self, stats: "NodeExecutionStats") -> "NodeExecutionStats":
        self.execution_stats += stats
        return self.execution_stats

    @property
    def name(self):
        return self.__class__.__name__

    def to_dict(self):
        return {
            "id": self.id,
            "name": self.name,
            "inputSchema": self.input_schema.jsonschema(),
            "outputSchema": self.output_schema.jsonschema(),
            "description": self.description,
            "categories": [category.dict() for category in self.categories],
            "contributors": [
                contributor.model_dump() for contributor in self.contributors
            ],
            "staticOutput": self.static_output,
            "uiType": self.block_type.value,
        }

    def get_info(self) -> BlockInfo:
        from backend.data.credit import get_block_cost

        return BlockInfo(
            id=self.id,
            name=self.name,
            inputSchema=self.input_schema.jsonschema(),
            outputSchema=self.output_schema.jsonschema(),
            costs=get_block_cost(self),
            description=self.description,
            categories=[category.dict() for category in self.categories],
            contributors=[
                contributor.model_dump() for contributor in self.contributors
            ],
            staticOutput=self.static_output,
            uiType=self.block_type.value,
        )

    async def execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
        try:
            async for output_name, output_data in self._execute(input_data, **kwargs):
                yield output_name, output_data
        except Exception as ex:
            if isinstance(ex, BlockError):
                raise ex
            else:
                raise (
                    BlockExecutionError
                    if isinstance(ex, ValueError)
                    else BlockUnknownError
                )(
                    message=str(ex),
                    block_name=self.name,
                    block_id=self.id,
                ) from ex

    async def is_block_exec_need_review(
        self,
        input_data: BlockInput,
        *,
        user_id: str,
        node_id: str,
        node_exec_id: str,
        graph_exec_id: str,
        graph_id: str,
        graph_version: int,
        execution_context: "ExecutionContext",
        **kwargs,
    ) -> tuple[bool, BlockInput]:
        """
        Check if this block execution needs human review and handle the review process.

        Returns:
            Tuple of (should_pause, input_data_to_use)
            - should_pause: True if execution should be paused for review
            - input_data_to_use: The input data to use (may be modified by reviewer)
        """
        if not (
            self.is_sensitive_action and execution_context.sensitive_action_safe_mode
        ):
            return False, input_data

        from backend.blocks.helpers.review import HITLReviewHelper

        # Handle the review request and get decision
        decision = await HITLReviewHelper.handle_review_decision(
            input_data=input_data,
            user_id=user_id,
            node_id=node_id,
            node_exec_id=node_exec_id,
            graph_exec_id=graph_exec_id,
            graph_id=graph_id,
            graph_version=graph_version,
            block_name=self.name,
            editable=True,
        )

        if decision is None:
            # We're awaiting review - pause execution
            return True, input_data

        if not decision.should_proceed:
            # Review was rejected, raise an error to stop execution
            raise BlockExecutionError(
                message=f"Block execution rejected by reviewer: {decision.message}",
                block_name=self.name,
                block_id=self.id,
            )

        # Review was approved - use the potentially modified data
        # ReviewResult.data must be a dict for block inputs
        reviewed_data = decision.review_result.data
        if not isinstance(reviewed_data, dict):
            raise BlockExecutionError(
                message=f"Review data must be a dict for block input, got {type(reviewed_data).__name__}",
                block_name=self.name,
                block_id=self.id,
            )
        return False, reviewed_data

    async def _execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
        # Check for review requirement only if running within a graph execution context
        # Direct block execution (e.g., from chat) skips the review process
        has_graph_context = all(
            key in kwargs
            for key in (
                "node_exec_id",
                "graph_exec_id",
                "graph_id",
                "execution_context",
            )
        )
        if has_graph_context:
            should_pause, input_data = await self.is_block_exec_need_review(
                input_data, **kwargs
            )
            if should_pause:
                return

        # Validate the input data (original or reviewer-modified) once
        if error := self.input_schema.validate_data(input_data):
            raise BlockInputError(
                message=f"Unable to execute block with invalid input data: {error}",
                block_name=self.name,
                block_id=self.id,
            )

        # Use the validated input data
        async for output_name, output_data in self.run(
            self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
            **kwargs,
        ):
            if output_name == "error":
                raise BlockExecutionError(
                    message=output_data, block_name=self.name, block_id=self.id
                )
            if self.block_type == BlockType.STANDARD and (
                error := self.output_schema.validate_field(output_name, output_data)
            ):
                raise BlockOutputError(
                    message=f"Block produced invalid output data: {error}",
                    block_name=self.name,
                    block_id=self.id,
                )
            yield output_name, output_data

    def is_triggered_by_event_type(
        self, trigger_config: dict[str, Any], event_type: str
    ) -> bool:
        if not self.webhook_config:
            raise TypeError("This method can't be used on non-trigger blocks")
        if not self.webhook_config.event_filter_input:
            return True
        event_filter = trigger_config.get(self.webhook_config.event_filter_input)
        if not event_filter:
            raise ValueError("Event filter is not configured on trigger")
        return event_type in [
            self.webhook_config.event_format.format(event=k)
            for k in event_filter
            if event_filter[k] is True
        ]

# Type alias for any block with standard input/output schemas
AnyBlockSchema: TypeAlias = Block[BlockSchemaInput, BlockSchemaOutput]
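
# Illustrative sketch (not part of the diff): how `is_triggered_by_event_type`
# resolves events for a trigger block whose webhook config uses
# event_filter_input="events" and event_format="pull_request.{event}" (assumed
# values). With a trigger config of {"events": {"opened": True, "closed": False}},
# only "pull_request.opened" matches:
#
#   block.is_triggered_by_event_type(
#       {"events": {"opened": True, "closed": False}},
#       "pull_request.opened",
#   )  # -> True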
autogpt_platform/backend/backend/blocks/_utils.py (new file, 122 lines)
@@ -0,0 +1,122 @@
import logging
import os

from backend.integrations.providers import ProviderName

from ._base import AnyBlockSchema

logger = logging.getLogger(__name__)


def is_block_auth_configured(
    block_cls: type[AnyBlockSchema],
) -> bool:
    """
    Check if a block has a valid authentication method configured at runtime.

    For example, if a block is an OAuth-only block and the env vars are not set,
    do not show it in the UI.
    """
    from backend.sdk.registry import AutoRegistry

    # Create an instance to access input_schema
    try:
        block = block_cls()
    except Exception as e:
        # If we can't create a block instance, assume it's not OAuth-only
        logger.error(f"Error creating block instance for {block_cls.__name__}: {e}")
        return True
    logger.debug(
        f"Checking if block {block_cls.__name__} has a valid provider configured"
    )

    # Get all credential inputs from input schema
    credential_inputs = block.input_schema.get_credentials_fields_info()
    required_inputs = block.input_schema.get_required_fields()
    if not credential_inputs:
        logger.debug(
            f"Block {block_cls.__name__} has no credential inputs - Treating as valid"
        )
        return True

    # Check credential inputs
    if len(required_inputs.intersection(credential_inputs.keys())) == 0:
        logger.debug(
            f"Block {block_cls.__name__} has only optional credential inputs"
            " - will work without credentials configured"
        )

    # Check if the credential inputs for this block are correctly configured
    for field_name, field_info in credential_inputs.items():
        provider_names = field_info.provider
        if not provider_names:
            logger.warning(
                f"Block {block_cls.__name__} "
                f"has credential input '{field_name}' with no provider options"
                " - Disabling"
            )
            return False

        # If a field has multiple possible providers, each one needs to be usable to
        # prevent breaking the UX
        for _provider_name in provider_names:
            provider_name = _provider_name.value
            if provider_name in ProviderName.__members__.values():
                logger.debug(
                    f"Block {block_cls.__name__} credential input '{field_name}' "
                    f"provider '{provider_name}' is part of the legacy provider system"
                    " - Treating as valid"
                )
                break

            provider = AutoRegistry.get_provider(provider_name)
            if not provider:
                logger.warning(
                    f"Block {block_cls.__name__} credential input '{field_name}' "
                    f"refers to unknown provider '{provider_name}' - Disabling"
                )
                return False

            # Check the provider's supported auth types
            if field_info.supported_types != provider.supported_auth_types:
                logger.warning(
                    f"Block {block_cls.__name__} credential input '{field_name}' "
                    f"has mismatched supported auth types (field <> Provider): "
                    f"{field_info.supported_types} != {provider.supported_auth_types}"
                )

            if not (supported_auth_types := provider.supported_auth_types):
                # No auth methods have been configured for this provider
                logger.warning(
                    f"Block {block_cls.__name__} credential input '{field_name}' "
                    f"provider '{provider_name}' "
                    "has no authentication methods configured - Disabling"
                )
                return False

            # Check if provider supports OAuth
            if "oauth2" in supported_auth_types:
                # Check if OAuth environment variables are set
                if (oauth_config := provider.oauth_config) and bool(
                    os.getenv(oauth_config.client_id_env_var)
                    and os.getenv(oauth_config.client_secret_env_var)
                ):
                    logger.debug(
                        f"Block {block_cls.__name__} credential input '{field_name}' "
                        f"provider '{provider_name}' is configured for OAuth"
                    )
                else:
                    logger.error(
                        f"Block {block_cls.__name__} credential input '{field_name}' "
                        f"provider '{provider_name}' "
                        "is missing OAuth client ID or secret - Disabling"
                    )
                    return False

        logger.debug(
            f"Block {block_cls.__name__} credential input '{field_name}' is valid; "
            f"supported credential types: {', '.join(field_info.supported_types)}"
        )

    return True
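
# Illustrative sketch (not part of the diff): a registry might use this helper
# to hide blocks whose auth is not configured. `all_block_classes` is a
# hypothetical iterable of Block subclasses:
#
#   available = [cls for cls in all_block_classes if is_block_auth_configured(cls)]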
@@ -1,7 +1,7 @@
 import logging
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockInput,
@@ -9,13 +9,15 @@ from backend.data.block import (
     BlockSchema,
     BlockSchemaInput,
     BlockType,
-    get_block,
 )
 from backend.data.execution import ExecutionContext, ExecutionStatus, NodesInputMasks
 from backend.data.model import NodeExecutionStats, SchemaField
 from backend.util.json import validate_with_jsonschema
 from backend.util.retry import func_retry

+if TYPE_CHECKING:
+    from backend.executor.utils import LogMetadata
+
 _logger = logging.getLogger(__name__)

@@ -124,9 +126,10 @@ class AgentExecutorBlock(Block):
         graph_version: int,
         graph_exec_id: str,
         user_id: str,
-        logger,
+        logger: "LogMetadata",
     ) -> BlockOutput:

+        from backend.blocks import get_block
         from backend.data.execution import ExecutionEventType
         from backend.executor import utils as execution_utils

@@ -198,7 +201,7 @@ class AgentExecutorBlock(Block):
         self,
         graph_exec_id: str,
         user_id: str,
-        logger,
+        logger: "LogMetadata",
     ) -> None:
         from backend.executor import utils as execution_utils

@@ -1,6 +1,13 @@
 from typing import Any

+from backend.blocks._base import (
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.blocks.llm import (
+    DEFAULT_LLM_MODEL,
     TEST_CREDENTIALS,
     TEST_CREDENTIALS_INPUT,
     AIBlockBase,
@@ -9,13 +16,6 @@ from backend.blocks.llm import (
     LlmModel,
     LLMResponse,
     llm_call,
-    llm_model_schema_extra,
-)
-from backend.data.block import (
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
 )
 from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField

@@ -50,10 +50,9 @@ class AIConditionBlock(AIBlockBase):
     )
     model: LlmModel = SchemaField(
         title="LLM Model",
-        default_factory=LlmModel.default,
+        default=DEFAULT_LLM_MODEL,
         description="The language model to use for evaluating the condition.",
         advanced=False,
-        json_schema_extra=llm_model_schema_extra(),
     )
     credentials: AICredentials = AICredentialsField()

@@ -83,7 +82,7 @@ class AIConditionBlock(AIBlockBase):
             "condition": "the input is an email address",
             "yes_value": "Valid email",
             "no_value": "Not an email",
-            "model": LlmModel.default(),
+            "model": DEFAULT_LLM_MODEL,
             "credentials": TEST_CREDENTIALS_INPUT,
         },
         test_credentials=TEST_CREDENTIALS,

@@ -6,7 +6,7 @@ from pydantic import SecretStr
 from replicate.client import Client as ReplicateClient
 from replicate.helpers import FileOutput

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -5,7 +5,12 @@ from pydantic import SecretStr
 from replicate.client import Client as ReplicateClient
 from replicate.helpers import FileOutput

-from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
+from backend.blocks._base import (
+    Block,
+    BlockCategory,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.data.execution import ExecutionContext
 from backend.data.model import (
     APIKeyCredentials,

@@ -6,7 +6,7 @@ from typing import Literal
 from pydantic import SecretStr
 from replicate.client import Client as ReplicateClient

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -6,7 +6,7 @@ from typing import Literal

 from pydantic import SecretStr

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -1,3 +1,10 @@
+from backend.blocks._base import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.blocks.apollo._api import ApolloClient
 from backend.blocks.apollo._auth import (
     TEST_CREDENTIALS,
@@ -10,13 +17,6 @@ from backend.blocks.apollo.models import (
     PrimaryPhone,
     SearchOrganizationsRequest,
 )
-from backend.data.block import (
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-)
 from backend.data.model import CredentialsField, SchemaField

@@ -1,5 +1,12 @@
 import asyncio

+from backend.blocks._base import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.blocks.apollo._api import ApolloClient
 from backend.blocks.apollo._auth import (
     TEST_CREDENTIALS,
@@ -14,13 +21,6 @@ from backend.blocks.apollo.models import (
     SearchPeopleRequest,
     SenorityLevels,
 )
-from backend.data.block import (
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-)
 from backend.data.model import CredentialsField, SchemaField

@@ -1,3 +1,10 @@
+from backend.blocks._base import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.blocks.apollo._api import ApolloClient
 from backend.blocks.apollo._auth import (
     TEST_CREDENTIALS,
@@ -6,13 +13,6 @@ from backend.blocks.apollo._auth import (
     ApolloCredentialsInput,
 )
 from backend.blocks.apollo.models import Contact, EnrichPersonRequest
-from backend.data.block import (
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-)
 from backend.data.model import CredentialsField, SchemaField

@@ -3,7 +3,7 @@ from typing import Optional

 from pydantic import BaseModel, Field

-from backend.data.block import BlockSchemaInput
+from backend.blocks._base import BlockSchemaInput
 from backend.data.model import SchemaField, UserIntegrations
 from backend.integrations.ayrshare import AyrshareClient
 from backend.util.clients import get_database_manager_async_client

@@ -1,7 +1,7 @@
 import enum
 from typing import Any

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -2,7 +2,7 @@ import os
 import re
 from typing import Type

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -1,7 +1,7 @@
 from enum import Enum
 from typing import Any

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -6,7 +6,7 @@ from typing import Literal, Optional
 from e2b import AsyncSandbox as BaseAsyncSandbox
 from pydantic import BaseModel, SecretStr

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -6,7 +6,7 @@ from e2b_code_interpreter import Result as E2BExecutionResult
 from e2b_code_interpreter.charts import Chart as E2BExecutionResultChart
 from pydantic import BaseModel, Field, JsonValue, SecretStr

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -1,6 +1,6 @@
 import re

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -6,7 +6,7 @@ from openai import AsyncOpenAI
 from openai.types.responses import Response as OpenAIResponse
 from pydantic import SecretStr

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -1,6 +1,6 @@
 from pydantic import BaseModel

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockManualWebhookConfig,

@@ -1,4 +1,4 @@
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -1,6 +1,6 @@
 from typing import Any, List

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -1,6 +1,6 @@
 import codecs

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -8,7 +8,7 @@ from typing import Any, Literal, cast
 import discord
 from pydantic import SecretStr

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -2,7 +2,7 @@
 Discord OAuth-based blocks.
 """

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -7,7 +7,7 @@ from typing import Literal

 from pydantic import BaseModel, ConfigDict, SecretStr

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -2,7 +2,7 @@

 import codecs

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -8,7 +8,7 @@ which provides access to LinkedIn profile data and related information.
 import logging
 from typing import Optional

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -3,6 +3,13 @@ import logging
 from enum import Enum
 from typing import Any

+from backend.blocks._base import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.blocks.fal._auth import (
     TEST_CREDENTIALS,
     TEST_CREDENTIALS_INPUT,
@@ -10,13 +17,6 @@ from backend.blocks.fal._auth import (
     FalCredentialsField,
     FalCredentialsInput,
 )
-from backend.data.block import (
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-)
 from backend.data.execution import ExecutionContext
 from backend.data.model import SchemaField
 from backend.util.file import store_media_file

@@ -5,7 +5,7 @@ from pydantic import SecretStr
 from replicate.client import Client as ReplicateClient
 from replicate.helpers import FileOutput

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -3,7 +3,7 @@ from typing import Optional

 from pydantic import BaseModel

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -5,7 +5,7 @@ from typing import Optional

 from typing_extensions import TypedDict

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -3,7 +3,7 @@ from urllib.parse import urlparse

 from typing_extensions import TypedDict

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -2,7 +2,7 @@ import re

 from typing_extensions import TypedDict

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -2,7 +2,7 @@ import base64

 from typing_extensions import TypedDict

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -4,7 +4,7 @@ from typing import Any, List, Optional

 from typing_extensions import TypedDict

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -3,7 +3,7 @@ from typing import Optional

 from pydantic import BaseModel

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -4,7 +4,7 @@ from pathlib import Path

 from pydantic import BaseModel

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -8,7 +8,7 @@ from google.oauth2.credentials import Credentials
 from googleapiclient.discovery import build
 from pydantic import BaseModel

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -7,14 +7,14 @@ from google.oauth2.credentials import Credentials
 from googleapiclient.discovery import build
 from gravitas_md2gdocs import to_requests

-from backend.blocks.google._drive import GoogleDriveFile, GoogleDriveFileField
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.blocks.google._drive import GoogleDriveFile, GoogleDriveFileField
 from backend.data.model import SchemaField
 from backend.util.settings import Settings

@@ -14,7 +14,7 @@ from google.oauth2.credentials import Credentials
 from googleapiclient.discovery import build
 from pydantic import BaseModel, Field

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -7,14 +7,14 @@ from enum import Enum
 from google.oauth2.credentials import Credentials
 from googleapiclient.discovery import build

-from backend.blocks.google._drive import GoogleDriveFile, GoogleDriveFileField
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.blocks.google._drive import GoogleDriveFile, GoogleDriveFileField
 from backend.data.model import SchemaField
 from backend.util.settings import Settings

@@ -3,7 +3,7 @@ from typing import Literal
 import googlemaps
 from pydantic import BaseModel, SecretStr

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -9,9 +9,7 @@ from typing import Any, Optional
 from prisma.enums import ReviewStatus
 from pydantic import BaseModel

-from backend.data.execution import ExecutionStatus
 from backend.data.human_review import ReviewResult
-from backend.executor.manager import async_update_node_execution_status
 from backend.util.clients import get_database_manager_async_client

 logger = logging.getLogger(__name__)
@@ -43,6 +41,8 @@ class HITLReviewHelper:
     @staticmethod
     async def update_node_execution_status(**kwargs) -> None:
         """Update the execution status of a node."""
+        from backend.executor.manager import async_update_node_execution_status
+
         await async_update_node_execution_status(
             db_client=get_database_manager_async_client(), **kwargs
         )
@@ -88,12 +88,13 @@ class HITLReviewHelper:
         Raises:
             Exception: If review creation or status update fails
         """
+        from backend.data.execution import ExecutionStatus

         # Note: Safe mode checks (human_in_the_loop_safe_mode, sensitive_action_safe_mode)
         # are handled by the caller:
         # - HITL blocks check human_in_the_loop_safe_mode in their run() method
         # - Sensitive action blocks check sensitive_action_safe_mode in is_block_exec_need_review()
         # This function only handles checking for existing approvals.

         # Check if this node has already been approved (normal or auto-approval)
         if approval_result := await HITLReviewHelper.check_approval(
             node_exec_id=node_exec_id,

@@ -8,7 +8,7 @@ from typing import Literal
 import aiofiles
 from pydantic import SecretStr

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

@@ -1,15 +1,15 @@
-from backend.blocks.hubspot._auth import (
-    HubSpotCredentials,
-    HubSpotCredentialsField,
-    HubSpotCredentialsInput,
-)
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.blocks.hubspot._auth import (
+    HubSpotCredentials,
+    HubSpotCredentialsField,
+    HubSpotCredentialsInput,
+)
 from backend.data.model import SchemaField
 from backend.util.request import Requests

@@ -1,15 +1,15 @@
-from backend.blocks.hubspot._auth import (
-    HubSpotCredentials,
-    HubSpotCredentialsField,
-    HubSpotCredentialsInput,
-)
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.blocks.hubspot._auth import (
+    HubSpotCredentials,
+    HubSpotCredentialsField,
+    HubSpotCredentialsInput,
+)
 from backend.data.model import SchemaField
 from backend.util.request import Requests

@@ -1,17 +1,17 @@
 from datetime import datetime, timedelta

-from backend.blocks.hubspot._auth import (
-    HubSpotCredentials,
-    HubSpotCredentialsField,
-    HubSpotCredentialsInput,
-)
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.blocks.hubspot._auth import (
+    HubSpotCredentials,
+    HubSpotCredentialsField,
+    HubSpotCredentialsInput,
+)
 from backend.data.model import SchemaField
 from backend.util.request import Requests

@@ -3,8 +3,7 @@ from typing import Any

 from prisma.enums import ReviewStatus

-from backend.blocks.helpers.review import HITLReviewHelper
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
@@ -12,6 +11,7 @@ from backend.data.block import (
     BlockSchemaOutput,
     BlockType,
 )
+from backend.blocks.helpers.review import HITLReviewHelper
 from backend.data.execution import ExecutionContext
 from backend.data.human_review import ReviewResult
 from backend.data.model import SchemaField
@@ -21,43 +21,71 @@ logger = logging.getLogger(__name__)

 class HumanInTheLoopBlock(Block):
     """
-    This block pauses execution and waits for human approval or modification of the data.
+    Pauses execution and waits for human approval or rejection of the data.

-    When executed, it creates a pending review entry and sets the node execution status
-    to REVIEW. The execution will remain paused until a human user either:
-    - Approves the data (with or without modifications)
-    - Rejects the data
+    When executed, this block creates a pending review entry and sets the node execution
+    status to REVIEW. The execution remains paused until a human user either approves
+    or rejects the data.

-    This is useful for workflows that require human validation or intervention before
-    proceeding to the next steps.
+    **How it works:**
+    - The input data is presented to a human reviewer
+    - The reviewer can approve or reject (and optionally modify the data if editable)
+    - On approval: the data flows out through the `approved_data` output pin
+    - On rejection: the data flows out through the `rejected_data` output pin
+
+    **Important:** The output pins yield the actual data itself, NOT status strings.
+    The approval/rejection decision determines WHICH output pin fires, not the value.
+    You do NOT need to compare the output to "APPROVED" or "REJECTED" - simply connect
+    downstream blocks to the appropriate output pin for each case.
+
+    **Example usage:**
+    - Connect `approved_data` → next step in your workflow (data was approved)
+    - Connect `rejected_data` → error handling or notification (data was rejected)
     """

     class Input(BlockSchemaInput):
-        data: Any = SchemaField(description="The data to be reviewed by a human user")
+        data: Any = SchemaField(
+            description="The data to be reviewed by a human user. "
+            "This exact data will be passed through to either approved_data or "
+            "rejected_data output based on the reviewer's decision."
+        )
         name: str = SchemaField(
-            description="A descriptive name for what this data represents",
+            description="A descriptive name for what this data represents. "
+            "This helps the reviewer understand what they are reviewing.",
         )
         editable: bool = SchemaField(
-            description="Whether the human reviewer can edit the data",
+            description="Whether the human reviewer can edit the data before "
+            "approving or rejecting it",
             default=True,
             advanced=True,
         )

     class Output(BlockSchemaOutput):
         approved_data: Any = SchemaField(
-            description="The data when approved (may be modified by reviewer)"
+            description="Outputs the input data when the reviewer APPROVES it. "
+            "The value is the actual data itself (not a status string like 'APPROVED'). "
+            "If the reviewer edited the data, this contains the modified version. "
+            "Connect downstream blocks here for the 'approved' workflow path."
         )
         rejected_data: Any = SchemaField(
-            description="The data when rejected (may be modified by reviewer)"
+            description="Outputs the input data when the reviewer REJECTS it. "
+            "The value is the actual data itself (not a status string like 'REJECTED'). "
+            "If the reviewer edited the data, this contains the modified version. "
+            "Connect downstream blocks here for the 'rejected' workflow path."
        )
         review_message: str = SchemaField(
-            description="Any message provided by the reviewer", default=""
+            description="Optional message provided by the reviewer explaining their "
+            "decision. Only outputs when the reviewer provides a message; "
+            "this pin does not fire if no message was given.",
+            default="",
        )

     def __init__(self):
         super().__init__(
             id="8b2a7b3c-6e9d-4a5f-8c1b-2e3f4a5b6c7d",
-            description="Pause execution and wait for human approval or modification of data",
+            description="Pause execution for human review. Data flows through "
+            "approved_data or rejected_data output based on the reviewer's decision. "
+            "Outputs contain the actual data, not status strings.",
             categories={BlockCategory.BASIC},
             input_schema=HumanInTheLoopBlock.Input,
             output_schema=HumanInTheLoopBlock.Output,

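# Illustrative note (not part of the diff): per the docstrings above, the
# reviewer's decision selects WHICH pin fires; both pins carry the data itself.
#
#   approved_data -> next workflow step (receives the reviewed data)
#   rejected_data -> error handling / notification (receives the same data)
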
@@ -3,7 +3,7 @@ from typing import Any, Dict, Literal, Optional
|
|||||||
|
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -2,9 +2,7 @@ import copy
|
|||||||
from datetime import date, time
|
from datetime import date, time
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
# Import for Google Drive file input block
|
from backend.blocks._base import (
|
||||||
from backend.blocks.google._drive import AttachmentView, GoogleDriveFile
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
@@ -12,6 +10,9 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockType,
|
BlockType,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Import for Google Drive file input block
|
||||||
|
from backend.blocks.google._drive import AttachmentView, GoogleDriveFile
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.file import store_media_file
|
from backend.util.file import store_media_file
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from backend.data.block import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
from backend.blocks.jina._auth import (
|
from backend.blocks._base import (
|
||||||
JinaCredentials,
|
|
||||||
JinaCredentialsField,
|
|
||||||
JinaCredentialsInput,
|
|
||||||
)
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.blocks.jina._auth import (
|
||||||
|
JinaCredentials,
|
||||||
|
JinaCredentialsField,
|
||||||
|
JinaCredentialsInput,
|
||||||
|
)
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.request import Requests
|
from backend.util.request import Requests
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
-from backend.blocks.jina._auth import (
-    JinaCredentials,
-    JinaCredentialsField,
-    JinaCredentialsInput,
-)
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.blocks.jina._auth import (
+    JinaCredentials,
+    JinaCredentialsField,
+    JinaCredentialsInput,
+)
 from backend.data.model import SchemaField
 from backend.util.request import Requests
@@ -3,18 +3,18 @@ from urllib.parse import quote

 from typing_extensions import TypedDict

-from backend.blocks.jina._auth import (
-    JinaCredentials,
-    JinaCredentialsField,
-    JinaCredentialsInput,
-)
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.blocks.jina._auth import (
+    JinaCredentials,
+    JinaCredentialsField,
+    JinaCredentialsInput,
+)
 from backend.data.model import SchemaField
 from backend.util.request import Requests
@@ -1,5 +1,12 @@
 from urllib.parse import quote

+from backend.blocks._base import (
+    Block,
+    BlockCategory,
+    BlockOutput,
+    BlockSchemaInput,
+    BlockSchemaOutput,
+)
 from backend.blocks.jina._auth import (
     TEST_CREDENTIALS,
     TEST_CREDENTIALS_INPUT,
@@ -8,13 +15,6 @@ from backend.blocks.jina._auth import (
     JinaCredentialsInput,
 )
 from backend.blocks.search import GetRequest
-from backend.data.block import (
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-)
 from backend.data.model import SchemaField
 from backend.util.exceptions import BlockExecutionError
@@ -4,27 +4,24 @@ import logging
 import re
 import secrets
 from abc import ABC
-from enum import Enum
+from enum import Enum, EnumMeta
 from json import JSONDecodeError
-from typing import Any, Iterable, List, Literal, Optional
+from typing import Any, Iterable, List, Literal, NamedTuple, Optional

 import anthropic
 import ollama
 import openai
 from anthropic.types import ToolParam
 from groq import AsyncGroq
-from pydantic import BaseModel, GetCoreSchemaHandler, SecretStr
-from pydantic_core import CoreSchema, core_schema
+from pydantic import BaseModel, SecretStr

-from backend.data import llm_registry
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaInput,
     BlockSchemaOutput,
 )
-from backend.data.llm_registry import ModelMetadata
 from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
@@ -69,123 +66,114 @@ TEST_CREDENTIALS_INPUT = {


 def AICredentialsField() -> AICredentials:
-    """
-    Returns a CredentialsField for LLM providers.
-    The discriminator_mapping will be refreshed when the schema is generated
-    if it's empty, ensuring the LLM registry is loaded.
-    """
-    # Get the mapping now - it may be empty initially, but will be refreshed
-    # when the schema is generated via CredentialsMetaInput._add_json_schema_extra
-    mapping = llm_registry.get_llm_discriminator_mapping()
-
     return CredentialsField(
         description="API key for the LLM provider.",
         discriminator="model",
-        discriminator_mapping=mapping,  # May be empty initially, refreshed later
+        discriminator_mapping={
+            model.value: model.metadata.provider for model in LlmModel
+        },
     )


-def llm_model_schema_extra() -> dict[str, Any]:
-    return {"options": llm_registry.get_llm_model_schema_options()}
+class ModelMetadata(NamedTuple):
+    provider: str
+    context_window: int
+    max_output_tokens: int | None
+    display_name: str
+    provider_name: str
+    creator_name: str
+    price_tier: Literal[1, 2, 3]


-class LlmModelMeta(type):
-    """
-    Metaclass for LlmModel that enables attribute-style access to dynamic models.
-
-    This allows code like `LlmModel.GPT4O` to work by converting the attribute
-    name to a slug format:
-    - GPT4O -> gpt-4o
-    - GPT4O_MINI -> gpt-4o-mini
-    - CLAUDE_3_5_SONNET -> claude-3-5-sonnet
-    """
-
-    def __getattr__(cls, name: str):
-        # Don't intercept private/dunder attributes
-        if name.startswith("_"):
-            raise AttributeError(f"type object 'LlmModel' has no attribute '{name}'")
-
-        # Convert attribute name to slug format:
-        # 1. Lowercase: GPT4O -> gpt4o
-        # 2. Underscores to hyphens: GPT4O_MINI -> gpt4o-mini
-        slug = name.lower().replace("_", "-")
-
-        # Check for exact match in registry first (e.g., "o1" stays "o1")
-        registry_slugs = llm_registry.get_dynamic_model_slugs()
-        if slug in registry_slugs:
-            return cls(slug)
-
-        # If no exact match, try inserting hyphen between letter and digit
-        # e.g., gpt4o -> gpt-4o
-        transformed_slug = re.sub(r"([a-z])(\d)", r"\1-\2", slug)
-        return cls(transformed_slug)
-
-    def __iter__(cls):
-        """Iterate over all models from the registry.
-
-        Yields LlmModel instances for each model in the dynamic registry.
-        Used by __get_pydantic_json_schema__ to build model metadata.
-        """
-        for model in llm_registry.iter_dynamic_models():
-            yield cls(model.slug)
+class LlmModelMeta(EnumMeta):
+    pass


-class LlmModel(str, metaclass=LlmModelMeta):
-    """
-    Dynamic LLM model type that accepts any model slug from the registry.
-
-    This is a string subclass (not an Enum) that allows any model slug value.
-    All models are managed via the LLM Registry in the database.
-
-    Usage:
-        model = LlmModel("gpt-4o")  # Direct construction
-        model = LlmModel.GPT4O  # Attribute access (converted to "gpt-4o")
-        model.value  # Returns the slug string
-        model.provider  # Returns the provider from registry
-    """
-
-    def __new__(cls, value: str):
-        if isinstance(value, LlmModel):
-            return value
-        return str.__new__(cls, value)
-
-    @classmethod
-    def __get_pydantic_core_schema__(
-        cls, source_type: Any, handler: GetCoreSchemaHandler
-    ) -> CoreSchema:
-        """
-        Tell Pydantic how to validate LlmModel.
-
-        Accepts strings and converts them to LlmModel instances.
-        """
-        return core_schema.no_info_after_validator_function(
-            cls,  # The validator function (LlmModel constructor)
-            core_schema.str_schema(),  # Accept string input
-            serialization=core_schema.to_string_ser_schema(),  # Serialize as string
-        )
-
-    @property
-    def value(self) -> str:
-        """Return the model slug (for compatibility with enum-style access)."""
-        return str(self)
-
-    @classmethod
-    def default(cls) -> "LlmModel":
-        """
-        Get the default model from the registry.
-
-        Returns the recommended model if set, otherwise gpt-4o if available
-        and enabled, otherwise the first enabled model from the registry.
-        Falls back to "gpt-4o" if registry is empty (e.g., at module import time).
-        """
-        from backend.data.llm_registry import get_default_model_slug
-
-        slug = get_default_model_slug()
-        if slug is None:
-            # Registry is empty (e.g., at module import time before DB connection).
-            # Fall back to gpt-4o for backward compatibility.
-            slug = "gpt-4o"
-        return cls(slug)
+class LlmModel(str, Enum, metaclass=LlmModelMeta):
+    # OpenAI models
+    O3_MINI = "o3-mini"
+    O3 = "o3-2025-04-16"
+    O1 = "o1"
+    O1_MINI = "o1-mini"
+    # GPT-5 models
+    GPT5_2 = "gpt-5.2-2025-12-11"
+    GPT5_1 = "gpt-5.1-2025-11-13"
+    GPT5 = "gpt-5-2025-08-07"
+    GPT5_MINI = "gpt-5-mini-2025-08-07"
+    GPT5_NANO = "gpt-5-nano-2025-08-07"
+    GPT5_CHAT = "gpt-5-chat-latest"
+    GPT41 = "gpt-4.1-2025-04-14"
+    GPT41_MINI = "gpt-4.1-mini-2025-04-14"
+    GPT4O_MINI = "gpt-4o-mini"
+    GPT4O = "gpt-4o"
+    GPT4_TURBO = "gpt-4-turbo"
+    GPT3_5_TURBO = "gpt-3.5-turbo"
+    # Anthropic models
+    CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
+    CLAUDE_4_OPUS = "claude-opus-4-20250514"
+    CLAUDE_4_SONNET = "claude-sonnet-4-20250514"
+    CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
+    CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
+    CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
+    CLAUDE_4_6_OPUS = "claude-opus-4-6"
+    CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
+    # AI/ML API models
+    AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
+    AIML_API_LLAMA3_1_70B = "nvidia/llama-3.1-nemotron-70b-instruct"
+    AIML_API_LLAMA3_3_70B = "meta-llama/Llama-3.3-70B-Instruct-Turbo"
+    AIML_API_META_LLAMA_3_1_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
+    AIML_API_LLAMA_3_2_3B = "meta-llama/Llama-3.2-3B-Instruct-Turbo"
+    # Groq models
+    LLAMA3_3_70B = "llama-3.3-70b-versatile"
+    LLAMA3_1_8B = "llama-3.1-8b-instant"
+    # Ollama models
+    OLLAMA_LLAMA3_3 = "llama3.3"
+    OLLAMA_LLAMA3_2 = "llama3.2"
+    OLLAMA_LLAMA3_8B = "llama3"
+    OLLAMA_LLAMA3_405B = "llama3.1:405b"
+    OLLAMA_DOLPHIN = "dolphin-mistral:latest"
+    # OpenRouter models
+    OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
+    OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
+    GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
+    GEMINI_3_PRO_PREVIEW = "google/gemini-3-pro-preview"
+    GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
+    GEMINI_2_0_FLASH = "google/gemini-2.0-flash-001"
+    GEMINI_2_5_FLASH_LITE_PREVIEW = "google/gemini-2.5-flash-lite-preview-06-17"
+    GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-lite-001"
+    MISTRAL_NEMO = "mistralai/mistral-nemo"
+    COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
+    COHERE_COMMAND_R_PLUS_08_2024 = "cohere/command-r-plus-08-2024"
+    DEEPSEEK_CHAT = "deepseek/deepseek-chat"  # Actually: DeepSeek V3
+    DEEPSEEK_R1_0528 = "deepseek/deepseek-r1-0528"
+    PERPLEXITY_SONAR = "perplexity/sonar"
+    PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
+    PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
+    NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
+    NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
+    AMAZON_NOVA_LITE_V1 = "amazon/nova-lite-v1"
+    AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
+    AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
+    MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
+    GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
+    META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
+    META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
+    GROK_4 = "x-ai/grok-4"
+    GROK_4_FAST = "x-ai/grok-4-fast"
+    GROK_4_1_FAST = "x-ai/grok-4.1-fast"
+    GROK_CODE_FAST_1 = "x-ai/grok-code-fast-1"
+    KIMI_K2 = "moonshotai/kimi-k2"
+    QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507"
+    QWEN3_CODER = "qwen/qwen3-coder"
+    # Llama API models
+    LLAMA_API_LLAMA_4_SCOUT = "Llama-4-Scout-17B-16E-Instruct-FP8"
+    LLAMA_API_LLAMA4_MAVERICK = "Llama-4-Maverick-17B-128E-Instruct-FP8"
+    LLAMA_API_LLAMA3_3_8B = "Llama-3.3-8B-Instruct"
+    LLAMA_API_LLAMA3_3_70B = "Llama-3.3-70B-Instruct"
+    # v0 by Vercel models
+    V0_1_5_MD = "v0-1.5-md"
+    V0_1_5_LG = "v0-1.5-lg"
+    V0_1_0_MD = "v0-1.0-md"

     @classmethod
     def __get_pydantic_json_schema__(cls, schema, handler):
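The removed metaclass resolved attribute access like `LlmModel.GPT4O` at runtime by slugifying the attribute name and checking the registry. Isolated from the registry, the technique looks like this (standalone sketch; `KNOWN_SLUGS` stands in for the database-backed lookup):

    import re

    # Standalone sketch of the attribute-to-slug trick the removed
    # LlmModelMeta used. KNOWN_SLUGS stands in for the registry lookup.
    KNOWN_SLUGS = {"o1", "gpt-4o", "gpt-4o-mini"}


    class SlugMeta(type):
        def __getattr__(cls, name: str):
            if name.startswith("_"):
                raise AttributeError(name)
            slug = name.lower().replace("_", "-")  # GPT4O_MINI -> gpt4o-mini
            if slug in KNOWN_SLUGS:                # exact match, e.g. "o1"
                return cls(slug)
            # Insert a hyphen between a letter and a digit: gpt4o -> gpt-4o
            return cls(re.sub(r"([a-z])(\d)", r"\1-\2", slug))


    class Model(str, metaclass=SlugMeta):
        pass


    assert Model.O1 == "o1"
    assert Model.GPT4O == "gpt-4o"
    assert Model.GPT4O_MINI == "gpt-4o-mini"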
@@ -193,15 +181,7 @@ class LlmModel(str, metaclass=LlmModelMeta):
         llm_model_metadata = {}
         for model in cls:
             model_name = model.value
-            # Skip disabled models - only show enabled models in the picker
-            if not llm_registry.is_model_enabled(model_name):
-                continue
-            # Use registry directly with None check to gracefully handle
-            # missing metadata during startup/import before registry is populated
-            metadata = llm_registry.get_llm_model_metadata(model_name)
-            if metadata is None:
-                # Skip models without metadata (registry not yet populated)
-                continue
+            metadata = model.metadata
             llm_model_metadata[model_name] = {
                 "creator": metadata.creator_name,
                 "creator_name": metadata.creator_name,
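Both sides of this hunk rely on `__get_pydantic_json_schema__` to publish per-model metadata into the generated JSON schema. The pattern in miniature (a sketch; the key name and fields are illustrative, and exact `$defs` placement depends on the pydantic version):

    import json
    from enum import Enum

    from pydantic import BaseModel


    class Color(str, Enum):
        RED = "red"
        BLUE = "blue"

        @classmethod
        def __get_pydantic_json_schema__(cls, schema, handler):
            # Start from the schema pydantic would emit, then attach metadata.
            json_schema = handler(schema)
            json_schema["x-options"] = {
                m.value: {"label": m.name.title()} for m in cls
            }
            return json_schema


    class Palette(BaseModel):
        color: Color


    print(json.dumps(Palette.model_json_schema(), indent=2))
    # The Color definition in the output carries the extra "x-options" map.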
@@ -217,12 +197,7 @@ class LlmModel(str, metaclass=LlmModelMeta):

     @property
     def metadata(self) -> ModelMetadata:
-        metadata = llm_registry.get_llm_model_metadata(self.value)
-        if metadata:
-            return metadata
-        raise ValueError(
-            f"Missing metadata for model: {self.value}. Model not found in LLM registry."
-        )
+        return MODEL_METADATA[self]

     @property
     def provider(self) -> str:
@@ -237,9 +212,300 @@ class LlmModel(str, metaclass=LlmModelMeta):
         return self.metadata.max_output_tokens


-# Default model constant for backward compatibility
-# Uses the dynamic registry to get the default model
-DEFAULT_LLM_MODEL = LlmModel.default()
+MODEL_METADATA = {
+    # https://platform.openai.com/docs/models
+    LlmModel.O3: ModelMetadata("openai", 200000, 100000, "O3", "OpenAI", "OpenAI", 2),
+    LlmModel.O3_MINI: ModelMetadata(
+        "openai", 200000, 100000, "O3 Mini", "OpenAI", "OpenAI", 1
+    ),  # o3-mini-2025-01-31
+    LlmModel.O1: ModelMetadata(
+        "openai", 200000, 100000, "O1", "OpenAI", "OpenAI", 3
+    ),  # o1-2024-12-17
+    LlmModel.O1_MINI: ModelMetadata(
+        "openai", 128000, 65536, "O1 Mini", "OpenAI", "OpenAI", 2
+    ),  # o1-mini-2024-09-12
+    # GPT-5 models
+    LlmModel.GPT5_2: ModelMetadata(
+        "openai", 400000, 128000, "GPT-5.2", "OpenAI", "OpenAI", 3
+    ),
+    LlmModel.GPT5_1: ModelMetadata(
+        "openai", 400000, 128000, "GPT-5.1", "OpenAI", "OpenAI", 2
+    ),
+    LlmModel.GPT5: ModelMetadata(
+        "openai", 400000, 128000, "GPT-5", "OpenAI", "OpenAI", 1
+    ),
+    LlmModel.GPT5_MINI: ModelMetadata(
+        "openai", 400000, 128000, "GPT-5 Mini", "OpenAI", "OpenAI", 1
+    ),
+    LlmModel.GPT5_NANO: ModelMetadata(
+        "openai", 400000, 128000, "GPT-5 Nano", "OpenAI", "OpenAI", 1
+    ),
+    LlmModel.GPT5_CHAT: ModelMetadata(
+        "openai", 400000, 16384, "GPT-5 Chat Latest", "OpenAI", "OpenAI", 2
+    ),
+    LlmModel.GPT41: ModelMetadata(
+        "openai", 1047576, 32768, "GPT-4.1", "OpenAI", "OpenAI", 1
+    ),
+    LlmModel.GPT41_MINI: ModelMetadata(
+        "openai", 1047576, 32768, "GPT-4.1 Mini", "OpenAI", "OpenAI", 1
+    ),
+    LlmModel.GPT4O_MINI: ModelMetadata(
+        "openai", 128000, 16384, "GPT-4o Mini", "OpenAI", "OpenAI", 1
+    ),  # gpt-4o-mini-2024-07-18
+    LlmModel.GPT4O: ModelMetadata(
+        "openai", 128000, 16384, "GPT-4o", "OpenAI", "OpenAI", 2
+    ),  # gpt-4o-2024-08-06
+    LlmModel.GPT4_TURBO: ModelMetadata(
+        "openai", 128000, 4096, "GPT-4 Turbo", "OpenAI", "OpenAI", 3
+    ),  # gpt-4-turbo-2024-04-09
+    LlmModel.GPT3_5_TURBO: ModelMetadata(
+        "openai", 16385, 4096, "GPT-3.5 Turbo", "OpenAI", "OpenAI", 1
+    ),  # gpt-3.5-turbo-0125
+    # https://docs.anthropic.com/en/docs/about-claude/models
+    LlmModel.CLAUDE_4_1_OPUS: ModelMetadata(
+        "anthropic", 200000, 32000, "Claude Opus 4.1", "Anthropic", "Anthropic", 3
+    ),  # claude-opus-4-1-20250805
+    LlmModel.CLAUDE_4_OPUS: ModelMetadata(
+        "anthropic", 200000, 32000, "Claude Opus 4", "Anthropic", "Anthropic", 3
+    ),  # claude-4-opus-20250514
+    LlmModel.CLAUDE_4_SONNET: ModelMetadata(
+        "anthropic", 200000, 64000, "Claude Sonnet 4", "Anthropic", "Anthropic", 2
+    ),  # claude-4-sonnet-20250514
+    LlmModel.CLAUDE_4_6_OPUS: ModelMetadata(
+        "anthropic", 200000, 128000, "Claude Opus 4.6", "Anthropic", "Anthropic", 3
+    ),  # claude-opus-4-6
+    LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
+        "anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
+    ),  # claude-opus-4-5-20251101
+    LlmModel.CLAUDE_4_5_SONNET: ModelMetadata(
+        "anthropic", 200000, 64000, "Claude Sonnet 4.5", "Anthropic", "Anthropic", 3
+    ),  # claude-sonnet-4-5-20250929
+    LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
+        "anthropic", 200000, 64000, "Claude Haiku 4.5", "Anthropic", "Anthropic", 2
+    ),  # claude-haiku-4-5-20251001
+    LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
+        "anthropic", 200000, 4096, "Claude 3 Haiku", "Anthropic", "Anthropic", 1
+    ),  # claude-3-haiku-20240307
+    # https://docs.aimlapi.com/api-overview/model-database/text-models
+    LlmModel.AIML_API_QWEN2_5_72B: ModelMetadata(
+        "aiml_api", 32000, 8000, "Qwen 2.5 72B Instruct Turbo", "AI/ML", "Qwen", 1
+    ),
+    LlmModel.AIML_API_LLAMA3_1_70B: ModelMetadata(
+        "aiml_api",
+        128000,
+        40000,
+        "Llama 3.1 Nemotron 70B Instruct",
+        "AI/ML",
+        "Nvidia",
+        1,
+    ),
+    LlmModel.AIML_API_LLAMA3_3_70B: ModelMetadata(
+        "aiml_api", 128000, None, "Llama 3.3 70B Instruct Turbo", "AI/ML", "Meta", 1
+    ),
+    LlmModel.AIML_API_META_LLAMA_3_1_70B: ModelMetadata(
+        "aiml_api", 131000, 2000, "Llama 3.1 70B Instruct Turbo", "AI/ML", "Meta", 1
+    ),
+    LlmModel.AIML_API_LLAMA_3_2_3B: ModelMetadata(
+        "aiml_api", 128000, None, "Llama 3.2 3B Instruct Turbo", "AI/ML", "Meta", 1
+    ),
+    # https://console.groq.com/docs/models
+    LlmModel.LLAMA3_3_70B: ModelMetadata(
+        "groq", 128000, 32768, "Llama 3.3 70B Versatile", "Groq", "Meta", 1
+    ),
+    LlmModel.LLAMA3_1_8B: ModelMetadata(
+        "groq", 128000, 8192, "Llama 3.1 8B Instant", "Groq", "Meta", 1
+    ),
+    # https://ollama.com/library
+    LlmModel.OLLAMA_LLAMA3_3: ModelMetadata(
+        "ollama", 8192, None, "Llama 3.3", "Ollama", "Meta", 1
+    ),
+    LlmModel.OLLAMA_LLAMA3_2: ModelMetadata(
+        "ollama", 8192, None, "Llama 3.2", "Ollama", "Meta", 1
+    ),
+    LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata(
+        "ollama", 8192, None, "Llama 3", "Ollama", "Meta", 1
+    ),
+    LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata(
+        "ollama", 8192, None, "Llama 3.1 405B", "Ollama", "Meta", 1
+    ),
+    LlmModel.OLLAMA_DOLPHIN: ModelMetadata(
+        "ollama", 32768, None, "Dolphin Mistral Latest", "Ollama", "Mistral AI", 1
+    ),
+    # https://openrouter.ai/models
+    LlmModel.GEMINI_2_5_PRO: ModelMetadata(
+        "open_router",
+        1050000,
+        8192,
+        "Gemini 2.5 Pro Preview 03.25",
+        "OpenRouter",
+        "Google",
+        2,
+    ),
+    LlmModel.GEMINI_3_PRO_PREVIEW: ModelMetadata(
+        "open_router", 1048576, 65535, "Gemini 3 Pro Preview", "OpenRouter", "Google", 2
+    ),
+    LlmModel.GEMINI_2_5_FLASH: ModelMetadata(
+        "open_router", 1048576, 65535, "Gemini 2.5 Flash", "OpenRouter", "Google", 1
+    ),
+    LlmModel.GEMINI_2_0_FLASH: ModelMetadata(
+        "open_router", 1048576, 8192, "Gemini 2.0 Flash 001", "OpenRouter", "Google", 1
+    ),
+    LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata(
+        "open_router",
+        1048576,
+        65535,
+        "Gemini 2.5 Flash Lite Preview 06.17",
+        "OpenRouter",
+        "Google",
+        1,
+    ),
+    LlmModel.GEMINI_2_0_FLASH_LITE: ModelMetadata(
+        "open_router",
+        1048576,
+        8192,
+        "Gemini 2.0 Flash Lite 001",
+        "OpenRouter",
+        "Google",
+        1,
+    ),
+    LlmModel.MISTRAL_NEMO: ModelMetadata(
+        "open_router", 128000, 4096, "Mistral Nemo", "OpenRouter", "Mistral AI", 1
+    ),
+    LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata(
+        "open_router", 128000, 4096, "Command R 08.2024", "OpenRouter", "Cohere", 1
+    ),
+    LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata(
+        "open_router", 128000, 4096, "Command R Plus 08.2024", "OpenRouter", "Cohere", 2
+    ),
+    LlmModel.DEEPSEEK_CHAT: ModelMetadata(
+        "open_router", 64000, 2048, "DeepSeek Chat", "OpenRouter", "DeepSeek", 1
+    ),
+    LlmModel.DEEPSEEK_R1_0528: ModelMetadata(
+        "open_router", 163840, 163840, "DeepSeek R1 0528", "OpenRouter", "DeepSeek", 1
+    ),
+    LlmModel.PERPLEXITY_SONAR: ModelMetadata(
+        "open_router", 127000, 8000, "Sonar", "OpenRouter", "Perplexity", 1
+    ),
+    LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata(
+        "open_router", 200000, 8000, "Sonar Pro", "OpenRouter", "Perplexity", 2
+    ),
+    LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
+        "open_router",
+        128000,
+        16000,
+        "Sonar Deep Research",
+        "OpenRouter",
+        "Perplexity",
+        3,
+    ),
+    LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: ModelMetadata(
+        "open_router",
+        131000,
+        4096,
+        "Hermes 3 Llama 3.1 405B",
+        "OpenRouter",
+        "Nous Research",
+        1,
+    ),
+    LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: ModelMetadata(
+        "open_router",
+        12288,
+        12288,
+        "Hermes 3 Llama 3.1 70B",
+        "OpenRouter",
+        "Nous Research",
+        1,
+    ),
+    LlmModel.OPENAI_GPT_OSS_120B: ModelMetadata(
+        "open_router", 131072, 131072, "GPT-OSS 120B", "OpenRouter", "OpenAI", 1
+    ),
+    LlmModel.OPENAI_GPT_OSS_20B: ModelMetadata(
+        "open_router", 131072, 32768, "GPT-OSS 20B", "OpenRouter", "OpenAI", 1
+    ),
+    LlmModel.AMAZON_NOVA_LITE_V1: ModelMetadata(
+        "open_router", 300000, 5120, "Nova Lite V1", "OpenRouter", "Amazon", 1
+    ),
+    LlmModel.AMAZON_NOVA_MICRO_V1: ModelMetadata(
+        "open_router", 128000, 5120, "Nova Micro V1", "OpenRouter", "Amazon", 1
+    ),
+    LlmModel.AMAZON_NOVA_PRO_V1: ModelMetadata(
+        "open_router", 300000, 5120, "Nova Pro V1", "OpenRouter", "Amazon", 1
+    ),
+    LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata(
+        "open_router", 65536, 4096, "WizardLM 2 8x22B", "OpenRouter", "Microsoft", 1
+    ),
+    LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata(
+        "open_router", 4096, 4096, "MythoMax L2 13B", "OpenRouter", "Gryphe", 1
+    ),
+    LlmModel.META_LLAMA_4_SCOUT: ModelMetadata(
+        "open_router", 131072, 131072, "Llama 4 Scout", "OpenRouter", "Meta", 1
+    ),
+    LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata(
+        "open_router", 1048576, 1000000, "Llama 4 Maverick", "OpenRouter", "Meta", 1
+    ),
+    LlmModel.GROK_4: ModelMetadata(
+        "open_router", 256000, 256000, "Grok 4", "OpenRouter", "xAI", 3
+    ),
+    LlmModel.GROK_4_FAST: ModelMetadata(
+        "open_router", 2000000, 30000, "Grok 4 Fast", "OpenRouter", "xAI", 1
+    ),
+    LlmModel.GROK_4_1_FAST: ModelMetadata(
+        "open_router", 2000000, 30000, "Grok 4.1 Fast", "OpenRouter", "xAI", 1
+    ),
+    LlmModel.GROK_CODE_FAST_1: ModelMetadata(
+        "open_router", 256000, 10000, "Grok Code Fast 1", "OpenRouter", "xAI", 1
+    ),
+    LlmModel.KIMI_K2: ModelMetadata(
+        "open_router", 131000, 131000, "Kimi K2", "OpenRouter", "Moonshot AI", 1
+    ),
+    LlmModel.QWEN3_235B_A22B_THINKING: ModelMetadata(
+        "open_router",
+        262144,
+        262144,
+        "Qwen 3 235B A22B Thinking 2507",
+        "OpenRouter",
+        "Qwen",
+        1,
+    ),
+    LlmModel.QWEN3_CODER: ModelMetadata(
+        "open_router", 262144, 262144, "Qwen 3 Coder", "OpenRouter", "Qwen", 3
+    ),
+    # Llama API models
+    LlmModel.LLAMA_API_LLAMA_4_SCOUT: ModelMetadata(
+        "llama_api",
+        128000,
+        4028,
+        "Llama 4 Scout 17B 16E Instruct FP8",
+        "Llama API",
+        "Meta",
+        1,
+    ),
+    LlmModel.LLAMA_API_LLAMA4_MAVERICK: ModelMetadata(
+        "llama_api",
+        128000,
+        4028,
+        "Llama 4 Maverick 17B 128E Instruct FP8",
+        "Llama API",
+        "Meta",
+        1,
+    ),
+    LlmModel.LLAMA_API_LLAMA3_3_8B: ModelMetadata(
+        "llama_api", 128000, 4028, "Llama 3.3 8B Instruct", "Llama API", "Meta", 1
+    ),
+    LlmModel.LLAMA_API_LLAMA3_3_70B: ModelMetadata(
+        "llama_api", 128000, 4028, "Llama 3.3 70B Instruct", "Llama API", "Meta", 1
+    ),
+    # v0 by Vercel models
+    LlmModel.V0_1_5_MD: ModelMetadata("v0", 128000, 64000, "v0 1.5 MD", "V0", "V0", 1),
+    LlmModel.V0_1_5_LG: ModelMetadata("v0", 512000, 64000, "v0 1.5 LG", "V0", "V0", 1),
+    LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000, "v0 1.0 MD", "V0", "V0", 1),
+}
+
+DEFAULT_LLM_MODEL = LlmModel.GPT5_2
+
+for model in LlmModel:
+    if model not in MODEL_METADATA:
+        raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")


 class ToolCall(BaseModel):
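The hunk above trades registry lookups for an in-module table: a NamedTuple per model, keyed by enum member, validated exhaustively at import time. The pattern in miniature (a self-contained sketch, not the project's actual table):

    # Miniature of the pattern above: a NamedTuple metadata table keyed by
    # enum members, with an import-time exhaustiveness check.
    from enum import Enum
    from typing import NamedTuple


    class Meta(NamedTuple):
        provider: str
        context_window: int


    class Model(str, Enum):
        GPT4O = "gpt-4o"
        HAIKU = "claude-3-haiku-20240307"


    META = {
        Model.GPT4O: Meta("openai", 128000),
        Model.HAIKU: Meta("anthropic", 200000),
    }

    # Fails fast at import time if someone adds a member without metadata.
    for m in Model:
        if m not in META:
            raise ValueError(f"Missing metadata for {m}")

    print(META[Model.GPT4O].context_window)  # 128000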
@@ -332,11 +598,8 @@ def get_parallel_tool_calls_param(
     llm_model: LlmModel, parallel_tool_calls: bool | None
 ) -> bool | openai.Omit:
     """Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
-    # Check for o-series models (o1, o1-mini, o3-mini, etc.) which don't support
-    # parallel tool calls. Use regex to avoid false positives like "openai/gpt-oss".
-    is_o_series = re.match(r"^o\d", llm_model) is not None
-    if is_o_series or parallel_tool_calls is None:
-        return openai.NOT_GIVEN
+    if llm_model.startswith("o") or parallel_tool_calls is None:
+        return openai.omit
     return parallel_tool_calls
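The left-hand version used `^o\d` precisely to avoid matching slugs that merely begin with "o", as its own comment notes. A quick comparison of the two checks on sample slugs:

    import re

    # Comparing the two o-series checks that differ across this hunk.
    def is_o_series_regex(slug: str) -> bool:
        return re.match(r"^o\d", slug) is not None

    def is_o_series_prefix(slug: str) -> bool:
        return slug.startswith("o")

    for slug in ["o1", "o3-mini", "openai/gpt-oss-120b", "gpt-4o"]:
        print(slug, is_o_series_regex(slug), is_o_series_prefix(slug))
    # o1                    True  True
    # o3-mini               True  True
    # openai/gpt-oss-120b   False True   <- prefix check false-positives here
    # gpt-4o                False False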
@@ -371,93 +634,15 @@ async def llm_call(
     - prompt_tokens: The number of tokens used in the prompt.
     - completion_tokens: The number of tokens used in the completion.
     """
-    # Get model metadata and check if enabled - with fallback support
-    # The model we'll actually use (may differ if original is disabled)
-    model_to_use = llm_model.value
-
-    # Check if model is in registry and if it's enabled
-    from backend.data.llm_registry import (
-        get_fallback_model_for_disabled,
-        get_model_info,
-    )
-
-    model_info = get_model_info(llm_model.value)
-
-    if model_info and not model_info.is_enabled:
-        # Model is disabled - try to find a fallback from the same provider
-        fallback = get_fallback_model_for_disabled(llm_model.value)
-        if fallback:
-            logger.warning(
-                f"Model '{llm_model.value}' is disabled. Using fallback model '{fallback.slug}' from the same provider ({fallback.metadata.provider})."
-            )
-            model_to_use = fallback.slug
-            # Use fallback model's metadata
-            provider = fallback.metadata.provider
-            context_window = fallback.metadata.context_window
-            model_max_output = fallback.metadata.max_output_tokens or int(2**15)
-        else:
-            # No fallback available - raise error
-            raise ValueError(
-                f"LLM model '{llm_model.value}' is disabled and no fallback model "
-                f"from the same provider is available. Please enable the model or "
-                f"select a different model in the block configuration."
-            )
-    else:
-        # Model is enabled or not in registry (legacy/static model)
-        try:
-            provider = llm_model.metadata.provider
-            context_window = llm_model.context_window
-            model_max_output = llm_model.max_output_tokens or int(2**15)
-        except ValueError:
-            # Model not in cache - try refreshing the registry once if we have DB access
-            logger.warning(f"Model {llm_model.value} not found in registry cache")
-
-            # Try refreshing the registry if we have database access
-            from backend.data.db import is_connected
-
-            if is_connected():
-                try:
-                    logger.info(
-                        f"Refreshing LLM registry and retrying lookup for {llm_model.value}"
-                    )
-                    await llm_registry.refresh_llm_registry()
-                    # Try again after refresh
-                    try:
-                        provider = llm_model.metadata.provider
-                        context_window = llm_model.context_window
-                        model_max_output = llm_model.max_output_tokens or int(2**15)
-                        logger.info(
-                            f"Successfully loaded model {llm_model.value} metadata after registry refresh"
-                        )
-                    except ValueError:
-                        # Still not found after refresh
-                        raise ValueError(
-                            f"LLM model '{llm_model.value}' not found in registry after refresh. "
-                            "Please ensure the model is added and enabled in the LLM registry via the admin UI."
-                        )
-                except Exception as refresh_exc:
-                    logger.error(f"Failed to refresh LLM registry: {refresh_exc}")
-                    raise ValueError(
-                        f"LLM model '{llm_model.value}' not found in registry and failed to refresh. "
-                        "Please ensure the model is added to the LLM registry via the admin UI."
-                    ) from refresh_exc
-            else:
-                # No DB access (e.g., in executor without direct DB connection)
-                # The registry should have been loaded on startup
-                raise ValueError(
-                    f"LLM model '{llm_model.value}' not found in registry cache. "
-                    "The registry may need to be refreshed. Please contact support or try again later."
-                )
-
-    # Create effective model for model-specific parameter resolution (e.g., o-series check)
-    # This uses the resolved model_to_use which may differ from llm_model if fallback occurred
-    effective_model = LlmModel(model_to_use)
-
+    provider = llm_model.metadata.provider
+    context_window = llm_model.context_window
     if compress_prompt_to_fit:
         result = await compress_context(
             messages=prompt,
-            target_tokens=context_window // 2,
+            target_tokens=llm_model.context_window // 2,
             client=None,  # Truncation-only, no LLM summarization
+            reserve=0,  # Caller handles response token budget separately
         )
         if result.error:
            logger.warning(
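The removed code implemented a same-provider fallback for disabled models. Detached from the registry internals, that resolution strategy looks like this (standalone sketch; the ModelInfo shape and REGISTRY dict are illustrative stand-ins):

    # Standalone sketch of the disabled-model fallback the removed code did:
    # if the requested model is disabled, fall back to an enabled model from
    # the same provider, otherwise fail loudly.
    from dataclasses import dataclass


    @dataclass
    class ModelInfo:
        slug: str
        provider: str
        is_enabled: bool


    REGISTRY = {
        "gpt-4o": ModelInfo("gpt-4o", "openai", is_enabled=False),
        "gpt-4o-mini": ModelInfo("gpt-4o-mini", "openai", is_enabled=True),
    }


    def resolve_model(slug: str) -> str:
        info = REGISTRY.get(slug)
        if info is None or info.is_enabled:
            return slug  # unknown (legacy) or enabled: use as-is
        for candidate in REGISTRY.values():
            if candidate.provider == info.provider and candidate.is_enabled:
                return candidate.slug  # same-provider fallback
        raise ValueError(f"Model '{slug}' is disabled and no fallback is available")


    print(resolve_model("gpt-4o"))  # gpt-4o-mini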
@@ -468,7 +653,7 @@ async def llm_call(

     # Calculate available tokens based on context window and input length
     estimated_input_tokens = estimate_token_count(prompt)
-    # model_max_output already set above
+    model_max_output = llm_model.max_output_tokens or int(2**15)
     user_max = max_tokens or model_max_output
     available_tokens = max(context_window - estimated_input_tokens, 0)
     max_tokens = max(min(available_tokens, model_max_output, user_max), 1)
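The clamp logic in this hunk, worked through on concrete numbers:

    # Worked example of the token-budget clamp above (numbers illustrative).
    context_window = 128000          # model's total context
    estimated_input_tokens = 120000  # prompt already uses most of it
    model_max_output = 16384         # model's output ceiling (or 2**15 default)
    user_max = 50000                 # caller asked for more than is possible

    available_tokens = max(context_window - estimated_input_tokens, 0)  # 8000
    max_tokens = max(min(available_tokens, model_max_output, user_max), 1)
    print(max_tokens)  # 8000 -- bounded by what's left in the context window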
@@ -479,14 +664,14 @@ async def llm_call(
         response_format = None

         parallel_tool_calls = get_parallel_tool_calls_param(
-            effective_model, parallel_tool_calls
+            llm_model, parallel_tool_calls
         )

         if force_json_output:
             response_format = {"type": "json_object"}

         response = await oai_client.chat.completions.create(
-            model=model_to_use,
+            model=llm_model.value,
             messages=prompt,  # type: ignore
             response_format=response_format,  # type: ignore
             max_completion_tokens=max_tokens,
@@ -533,7 +718,7 @@ async def llm_call(
         )
         try:
             resp = await client.messages.create(
-                model=model_to_use,
+                model=llm_model.value,
                 system=sysprompt,
                 messages=messages,
                 max_tokens=max_tokens,
@@ -597,7 +782,7 @@ async def llm_call(
         client = AsyncGroq(api_key=credentials.api_key.get_secret_value())
         response_format = {"type": "json_object"} if force_json_output else None
         response = await client.chat.completions.create(
-            model=model_to_use,
+            model=llm_model.value,
             messages=prompt,  # type: ignore
             response_format=response_format,  # type: ignore
             max_tokens=max_tokens,
@@ -619,7 +804,7 @@ async def llm_call(
         sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
         usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
         response = await client.generate(
-            model=model_to_use,
+            model=llm_model.value,
             prompt=f"{sys_messages}\n\n{usr_messages}",
             stream=False,
             options={"num_ctx": max_tokens},
@@ -641,7 +826,7 @@ async def llm_call(
         )

         parallel_tool_calls_param = get_parallel_tool_calls_param(
-            effective_model, parallel_tool_calls
+            llm_model, parallel_tool_calls
         )

         response = await client.chat.completions.create(
@@ -649,7 +834,7 @@ async def llm_call(
                 "HTTP-Referer": "https://agpt.co",
                 "X-Title": "AutoGPT",
             },
-            model=model_to_use,
+            model=llm_model.value,
             messages=prompt,  # type: ignore
             max_tokens=max_tokens,
             tools=tools_param,  # type: ignore
@@ -683,7 +868,7 @@ async def llm_call(
         )

         parallel_tool_calls_param = get_parallel_tool_calls_param(
-            effective_model, parallel_tool_calls
+            llm_model, parallel_tool_calls
         )

         response = await client.chat.completions.create(
@@ -691,7 +876,7 @@ async def llm_call(
                 "HTTP-Referer": "https://agpt.co",
                 "X-Title": "AutoGPT",
             },
-            model=model_to_use,
+            model=llm_model.value,
             messages=prompt,  # type: ignore
             max_tokens=max_tokens,
             tools=tools_param,  # type: ignore
@@ -718,7 +903,7 @@ async def llm_call(
             reasoning=reasoning,
         )
     elif provider == "aiml_api":
-        client = openai.AsyncOpenAI(
+        client = openai.OpenAI(
             base_url="https://api.aimlapi.com/v2",
             api_key=credentials.api_key.get_secret_value(),
             default_headers={
@@ -728,8 +913,8 @@ async def llm_call(
             },
         )

-        completion = await client.chat.completions.create(
-            model=model_to_use,
+        completion = client.chat.completions.create(
+            model=llm_model.value,
             messages=prompt,  # type: ignore
             max_tokens=max_tokens,
         )
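The aiml_api branch above switches from the async to the sync OpenAI client. For reference, the two calling conventions side by side (a sketch; the key is a placeholder, and the sync variant blocks the event loop when invoked from async code):

    import openai

    def sync_call() -> None:
        client = openai.OpenAI(
            base_url="https://api.aimlapi.com/v2", api_key="sk-..."
        )
        # Blocking call: inside an `async def`, this stalls the event loop.
        client.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": "hi"}],
        )

    async def async_call() -> None:
        client = openai.AsyncOpenAI(
            base_url="https://api.aimlapi.com/v2", api_key="sk-..."
        )
        # Awaitable call: cooperates with other tasks on the loop.
        await client.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": "hi"}],
        )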
@@ -757,11 +942,11 @@ async def llm_call(
             response_format = {"type": "json_object"}

         parallel_tool_calls_param = get_parallel_tool_calls_param(
-            effective_model, parallel_tool_calls
+            llm_model, parallel_tool_calls
         )

         response = await client.chat.completions.create(
-            model=model_to_use,
+            model=llm_model.value,
             messages=prompt,  # type: ignore
             response_format=response_format,  # type: ignore
             max_tokens=max_tokens,
@@ -812,10 +997,9 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
         )
         model: LlmModel = SchemaField(
             title="LLM Model",
-            default_factory=LlmModel.default,
+            default=DEFAULT_LLM_MODEL,
             description="The language model to use for answering the prompt.",
             advanced=False,
-            json_schema_extra=llm_model_schema_extra(),
         )
         force_json_output: bool = SchemaField(
             title="Restrict LLM to pure JSON output",
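The left side used `default_factory` so the default model is computed lazily (once the registry has loaded); the right side pins a constant at class-definition time. The difference in miniature, using plain pydantic (illustrative names, not the project's SchemaField):

    # default vs default_factory in miniature: a factory defers evaluation to
    # instantiation time, a plain default is fixed when the class is defined.
    from pydantic import BaseModel, Field

    CURRENT = {"model": "gpt-4o"}

    class WithFactory(BaseModel):
        model: str = Field(default_factory=lambda: CURRENT["model"])

    class WithConstant(BaseModel):
        model: str = Field(default="gpt-4o")

    CURRENT["model"] = "gpt-5"
    print(WithFactory().model)   # gpt-5  -- picked up at instantiation
    print(WithConstant().model)  # gpt-4o -- frozen at class definition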
@@ -878,7 +1062,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
             input_schema=AIStructuredResponseGeneratorBlock.Input,
             output_schema=AIStructuredResponseGeneratorBlock.Output,
             test_input={
-                "model": "gpt-4o",  # Using string value - enum accepts any model slug dynamically
+                "model": DEFAULT_LLM_MODEL,
                 "credentials": TEST_CREDENTIALS_INPUT,
                 "expected_format": {
                     "key1": "value1",
@@ -1244,10 +1428,9 @@ class AITextGeneratorBlock(AIBlockBase):
         )
         model: LlmModel = SchemaField(
             title="LLM Model",
-            default_factory=LlmModel.default,
+            default=DEFAULT_LLM_MODEL,
             description="The language model to use for answering the prompt.",
             advanced=False,
-            json_schema_extra=llm_model_schema_extra(),
         )
         credentials: AICredentials = AICredentialsField()
         sys_prompt: str = SchemaField(
@@ -1341,9 +1524,8 @@ class AITextSummarizerBlock(AIBlockBase):
         )
         model: LlmModel = SchemaField(
             title="LLM Model",
-            default_factory=LlmModel.default,
+            default=DEFAULT_LLM_MODEL,
             description="The language model to use for summarizing the text.",
-            json_schema_extra=llm_model_schema_extra(),
         )
         focus: str = SchemaField(
             title="Focus",
@@ -1559,9 +1741,8 @@ class AIConversationBlock(AIBlockBase):
         )
         model: LlmModel = SchemaField(
             title="LLM Model",
-            default_factory=LlmModel.default,
+            default=DEFAULT_LLM_MODEL,
             description="The language model to use for the conversation.",
-            json_schema_extra=llm_model_schema_extra(),
         )
         credentials: AICredentials = AICredentialsField()
         max_tokens: int | None = SchemaField(
@@ -1598,7 +1779,7 @@ class AIConversationBlock(AIBlockBase):
                 },
                 {"role": "user", "content": "Where was it played?"},
             ],
-            "model": "gpt-4o",  # Using string value - enum accepts any model slug dynamically
+            "model": DEFAULT_LLM_MODEL,
             "credentials": TEST_CREDENTIALS_INPUT,
         },
         test_credentials=TEST_CREDENTIALS,
@@ -1661,10 +1842,9 @@ class AIListGeneratorBlock(AIBlockBase):
         )
         model: LlmModel = SchemaField(
             title="LLM Model",
-            default_factory=LlmModel.default,
+            default=DEFAULT_LLM_MODEL,
             description="The language model to use for generating the list.",
             advanced=True,
-            json_schema_extra=llm_model_schema_extra(),
         )
         credentials: AICredentials = AICredentialsField()
         max_retries: int = SchemaField(
@@ -1719,7 +1899,7 @@ class AIListGeneratorBlock(AIBlockBase):
                 "drawing explorers to uncover its mysteries. Each planet showcases the limitless possibilities of "
                 "fictional worlds."
             ),
-            "model": "gpt-4o",  # Using string value - enum accepts any model slug dynamically
+            "model": DEFAULT_LLM_MODEL,
             "credentials": TEST_CREDENTIALS_INPUT,
             "max_retries": 3,
             "force_json_output": False,
@@ -2,7 +2,7 @@ import operator
 from enum import Enum
 from typing import Any

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
@@ -3,7 +3,7 @@ from typing import List, Literal

 from pydantic import SecretStr

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
@@ -3,7 +3,7 @@ from typing import Any, Literal, Optional, Union
 from mem0 import MemoryClient
 from pydantic import BaseModel, SecretStr

-from backend.data.block import Block, BlockOutput, BlockSchemaInput, BlockSchemaOutput
+from backend.blocks._base import Block, BlockOutput, BlockSchemaInput, BlockSchemaOutput
 from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional

 from pydantic import model_validator

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
@@ -2,7 +2,7 @@ from __future__ import annotations

 from typing import Any, Dict, List, Optional

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
@@ -4,7 +4,7 @@ from typing import List, Optional

 from pydantic import BaseModel

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
@@ -1,15 +1,15 @@
-from backend.blocks.nvidia._auth import (
-    NvidiaCredentials,
-    NvidiaCredentialsField,
-    NvidiaCredentialsInput,
-)
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.blocks.nvidia._auth import (
+    NvidiaCredentials,
+    NvidiaCredentialsField,
+    NvidiaCredentialsInput,
+)
 from backend.data.model import SchemaField
 from backend.util.request import Requests
 from backend.util.type import MediaFileType
@@ -6,7 +6,7 @@ from typing import Any, Literal
 import openai
 from pydantic import SecretStr

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
@@ -1,7 +1,7 @@
 import logging
 from typing import Any, Literal

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
@@ -3,7 +3,7 @@ from typing import Any, Literal

 from pinecone import Pinecone, ServerlessSpec

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
@@ -6,7 +6,7 @@ import praw
 from praw.models import Comment, MoreComments, Submission
 from pydantic import BaseModel, SecretStr

-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
@@ -4,19 +4,19 @@ from enum import Enum
 from pydantic import SecretStr
 from replicate.client import Client as ReplicateClient

-from backend.blocks.replicate._auth import (
-    TEST_CREDENTIALS,
-    TEST_CREDENTIALS_INPUT,
-    ReplicateCredentialsInput,
-)
-from backend.blocks.replicate._helper import ReplicateOutputs, extract_result
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.blocks.replicate._auth import (
+    TEST_CREDENTIALS,
+    TEST_CREDENTIALS_INPUT,
+    ReplicateCredentialsInput,
+)
+from backend.blocks.replicate._helper import ReplicateOutputs, extract_result
 from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
@@ -4,19 +4,19 @@ from typing import Optional
 from pydantic import SecretStr
 from replicate.client import Client as ReplicateClient

-from backend.blocks.replicate._auth import (
-    TEST_CREDENTIALS,
-    TEST_CREDENTIALS_INPUT,
-    ReplicateCredentialsInput,
-)
-from backend.blocks.replicate._helper import ReplicateOutputs, extract_result
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.blocks.replicate._auth import (
+    TEST_CREDENTIALS,
+    TEST_CREDENTIALS_INPUT,
+    ReplicateCredentialsInput,
+)
+from backend.blocks.replicate._helper import ReplicateOutputs, extract_result
 from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
 from backend.util.exceptions import BlockExecutionError, BlockInputError
Some files were not shown because too many files have changed in this diff.