mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
Update backend
This commit is contained in:
@@ -1,25 +1,122 @@
|
||||
import functools
|
||||
import logging
|
||||
|
||||
import backend.server.model as server_model
|
||||
from backend.blocks import load_all_blocks
|
||||
from backend.data.block import Block, BlockSchema, BlockType
|
||||
from backend.data.block import Block, BlockCategory, BlockSchema
|
||||
from backend.data.credit import get_block_costs
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.server.v2.builder.model import (
|
||||
BlockCategoryResponse,
|
||||
BlockResponse,
|
||||
FilterType,
|
||||
BlockType,
|
||||
Provider,
|
||||
ProviderResponse,
|
||||
SearchBlocksResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse]:
|
||||
categories: dict[BlockCategory, BlockCategoryResponse] = {}
|
||||
|
||||
for block_type in load_all_blocks().values():
|
||||
block: Block[BlockSchema, BlockSchema] = block_type()
|
||||
# Skip disabled blocks
|
||||
if block.disabled:
|
||||
continue
|
||||
# Skip blocks that don't have categories (all should have at least one)
|
||||
if not block.categories:
|
||||
continue
|
||||
|
||||
# Add block to the categories
|
||||
for category in block.categories:
|
||||
if category not in categories:
|
||||
categories[category] = BlockCategoryResponse(
|
||||
name=category.name.lower(),
|
||||
total_blocks=0,
|
||||
blocks=[],
|
||||
)
|
||||
|
||||
categories[category].total_blocks += 1
|
||||
|
||||
# Append if the category has less than the specified number of blocks
|
||||
if len(categories[category].blocks) < category_blocks:
|
||||
categories[category].blocks.append(block.to_dict())
|
||||
|
||||
# Sort categories by name
|
||||
return sorted(categories.values(), key=lambda x: x.name)
|
||||
|
||||
|
||||
def get_blocks(
|
||||
filter: list[FilterType],
|
||||
query: str = "",
|
||||
providers: list[ProviderName] | None = None,
|
||||
*,
|
||||
category: str | None = None,
|
||||
type: BlockType | None = None,
|
||||
provider: ProviderName | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> BlockResponse:
|
||||
"""
|
||||
Get blocks based on either category, type or provider.
|
||||
Providing nothing assumes category is `all`.
|
||||
"""
|
||||
# Only one of category, type, or provider can be specified
|
||||
if (category and type) or (category and provider) or (type and provider):
|
||||
raise ValueError("Only one of category, type, or provider can be specified")
|
||||
|
||||
blocks: list[Block[BlockSchema, BlockSchema]] = []
|
||||
skip = (page - 1) * page_size
|
||||
take = page_size
|
||||
total = 0
|
||||
|
||||
for block_type in load_all_blocks().values():
|
||||
block: Block[BlockSchema, BlockSchema] = block_type()
|
||||
# Skip disabled blocks
|
||||
if block.disabled:
|
||||
continue
|
||||
# Skip blocks that don't match the category
|
||||
if category and category not in {c.name.lower() for c in block.categories}:
|
||||
continue
|
||||
# Skip blocks that don't match the type
|
||||
if (
|
||||
(type == "input" and block.block_type.value != "Input")
|
||||
or (type == "output" and block.block_type.value != "Output")
|
||||
or (type == "action" and block.block_type.value in ("Input", "Output"))
|
||||
):
|
||||
continue
|
||||
# Skip blocks that don't match the provider
|
||||
if provider:
|
||||
credentials_info = block.input_schema.get_credentials_fields_info().values()
|
||||
if not any(provider in info.provider for info in credentials_info):
|
||||
continue
|
||||
|
||||
total += 1
|
||||
if skip > 0:
|
||||
skip -= 1
|
||||
continue
|
||||
if take > 0:
|
||||
take -= 1
|
||||
blocks.append(block)
|
||||
|
||||
return BlockResponse(
|
||||
blocks=[b.to_dict() for b in blocks],
|
||||
pagination=server_model.Pagination(
|
||||
total_items=total,
|
||||
total_pages=total // page_size + (1 if total % page_size > 0 else 0),
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def search_blocks(
|
||||
include_blocks: bool = True,
|
||||
include_integrations: bool = True,
|
||||
query: str = "",
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> SearchBlocksResponse:
|
||||
"""
|
||||
Get blocks based on the filter and query.
|
||||
`providers` only applies for `integrations` filter.
|
||||
@@ -43,30 +140,12 @@ def get_blocks(
|
||||
continue
|
||||
keep = False
|
||||
credentials = list(block.input_schema.get_credentials_fields().values())
|
||||
# Skip blocks that don't match the filter
|
||||
if (
|
||||
("all_blocks" in filter)
|
||||
or ("input_blocks" in filter and block.block_type == BlockType.INPUT)
|
||||
or ("output_block" in filter and block.block_type == BlockType.OUTPUT)
|
||||
):
|
||||
block_count += 1
|
||||
if include_integrations and len(credentials) > 0:
|
||||
keep = True
|
||||
elif (
|
||||
"action_blocks" in filter
|
||||
and block.block_type != BlockType.INPUT
|
||||
and block.block_type != BlockType.OUTPUT
|
||||
):
|
||||
block_count += 1
|
||||
integration_count += 1
|
||||
if include_blocks and len(credentials) == 0:
|
||||
keep = True
|
||||
elif "integrations" in filter and len(credentials) > 0:
|
||||
# Only keep if provider is in the list
|
||||
if providers:
|
||||
if any(c.provider in providers for c in credentials):
|
||||
keep = True
|
||||
integration_count += 1
|
||||
else:
|
||||
keep = True
|
||||
integration_count += 1
|
||||
block_count += 1
|
||||
|
||||
if not keep:
|
||||
continue
|
||||
@@ -81,16 +160,18 @@ def get_blocks(
|
||||
|
||||
costs = get_block_costs()
|
||||
|
||||
return BlockResponse(
|
||||
blocks=[{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks],
|
||||
return SearchBlocksResponse(
|
||||
blocks=BlockResponse(
|
||||
blocks=[{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks],
|
||||
pagination=server_model.Pagination(
|
||||
total_items=total,
|
||||
total_pages=total // page_size + (1 if total % page_size > 0 else 0),
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
),
|
||||
total_block_count=block_count,
|
||||
total_integration_count=integration_count,
|
||||
pagination=server_model.Pagination(
|
||||
total_items=total,
|
||||
total_pages=total // page_size + (1 if total % page_size > 0 else 0),
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -135,18 +216,20 @@ def get_providers(
|
||||
|
||||
@functools.cache
|
||||
def _get_all_providers() -> dict[ProviderName, Provider]:
|
||||
providers = {}
|
||||
providers: dict[ProviderName, Provider] = {}
|
||||
|
||||
for block_type in load_all_blocks().values():
|
||||
block: Block[BlockSchema, BlockSchema] = block_type()
|
||||
# Skip disabled blocks
|
||||
if block.disabled:
|
||||
continue
|
||||
credentials = list(block.input_schema.get_credentials_fields().values())
|
||||
for c in credentials:
|
||||
if c.provider in providers:
|
||||
providers[c.provider].integration_count += 1
|
||||
else:
|
||||
providers[c.provider] = Provider(
|
||||
name=c.provider, description="", integration_count=1
|
||||
)
|
||||
|
||||
credentials_info = block.input_schema.get_credentials_fields_info().values()
|
||||
for info in credentials_info:
|
||||
for provider in info.provider: # provider is a ProviderName enum member
|
||||
if provider in providers:
|
||||
providers[provider].integration_count += 1
|
||||
else:
|
||||
providers[provider] = Provider(
|
||||
name=provider, description="", integration_count=1
|
||||
)
|
||||
return providers
|
||||
|
||||
@@ -7,61 +7,74 @@ import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.model as store_model
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
FilterType = (
|
||||
Literal["all_blocks"]
|
||||
| Literal["input_blocks"]
|
||||
| Literal["action_blocks"]
|
||||
| Literal["output_blocks"]
|
||||
| Literal["integrations"]
|
||||
| Literal["providers"]
|
||||
| Literal["marketplace_agents"]
|
||||
| Literal["my_agents"]
|
||||
)
|
||||
FilterType = Literal[
|
||||
"blocks",
|
||||
"integrations",
|
||||
"providers",
|
||||
"marketplace_agents",
|
||||
"my_agents",
|
||||
]
|
||||
|
||||
|
||||
SearchResultType = (
|
||||
Literal["blocks"]
|
||||
| Literal["integrations"]
|
||||
| Literal["providers"]
|
||||
| Literal["marketplace_agents"]
|
||||
| Literal["my_agents"]
|
||||
)
|
||||
BlockType = Literal["all", "input", "action", "output"]
|
||||
|
||||
BlockData = dict[str, Any]
|
||||
|
||||
|
||||
class SearchOptions(BaseModel):
|
||||
search_query: str | None = None
|
||||
filter: list[FilterType] | None = None
|
||||
providers: list[str] | None = None
|
||||
by_creator: list[str] | None = None
|
||||
search_id: str | None = None
|
||||
page: int | None = None
|
||||
page_size: int | None = None
|
||||
# Suggestions
|
||||
class SuggestionsResponse(BaseModel):
|
||||
otto_suggestions: list[str]
|
||||
recent_searches: list[str]
|
||||
providers: list[ProviderName]
|
||||
top_blocks: list[BlockData]
|
||||
|
||||
|
||||
# All blocks
|
||||
class BlockCategoryResponse(BaseModel):
|
||||
name: str
|
||||
total_blocks: int
|
||||
blocks: list[BlockData]
|
||||
|
||||
model_config = {"use_enum_values": False} # <== use enum names like "AI"
|
||||
|
||||
|
||||
# Input/Action/Output and see all for block categories
|
||||
class BlockResponse(BaseModel):
|
||||
blocks: list[BlockData]
|
||||
pagination: server_model.Pagination
|
||||
|
||||
|
||||
# Providers
|
||||
class Provider(BaseModel):
|
||||
name: ProviderName
|
||||
description: str
|
||||
integration_count: int
|
||||
|
||||
|
||||
class BlockResponse(BaseModel):
|
||||
blocks: list[BlockData]
|
||||
total_block_count: int
|
||||
total_integration_count: int
|
||||
pagination: server_model.Pagination
|
||||
|
||||
|
||||
class ProviderResponse(BaseModel):
|
||||
providers: list[Provider]
|
||||
pagination: server_model.Pagination
|
||||
|
||||
|
||||
class BlockSearchResponse(BaseModel):
|
||||
# Search
|
||||
class SearchRequest(BaseModel):
|
||||
search_query: str | None = None
|
||||
filter: list[FilterType] | None = None
|
||||
by_creator: list[str] | None = None
|
||||
search_id: str | None = None
|
||||
page: int | None = None
|
||||
page_size: int | None = None
|
||||
|
||||
|
||||
class SearchBlocksResponse(BaseModel):
|
||||
blocks: BlockResponse
|
||||
total_block_count: int
|
||||
total_integration_count: int
|
||||
|
||||
|
||||
class SearchResponse(BaseModel):
|
||||
items: list[
|
||||
BlockData | Provider | library_model.LibraryAgent | store_model.StoreAgent
|
||||
]
|
||||
total_items: dict[SearchResultType, int]
|
||||
total_items: dict[FilterType, int]
|
||||
page: int
|
||||
more_pages: bool
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import logging
|
||||
import typing
|
||||
from typing import Annotated, Sequence
|
||||
|
||||
import fastapi
|
||||
import fastapi.responses
|
||||
from autogpt_libs.auth.depends import auth_middleware, get_user_id
|
||||
|
||||
import backend.server.model as server_model
|
||||
@@ -12,6 +11,7 @@ import backend.server.v2.library.db as library_db
|
||||
import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.db as store_db
|
||||
import backend.server.v2.store.model as store_model
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -37,44 +37,116 @@ def sanitize_query(query: str | None) -> str | None:
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
@router.get(
|
||||
"/suggestions",
|
||||
dependencies=[fastapi.Depends(auth_middleware)],
|
||||
)
|
||||
async def get_suggestions(
|
||||
user_id: Annotated[str, fastapi.Depends(get_user_id)],
|
||||
) -> builder_model.SuggestionsResponse:
|
||||
# todo kcze temp response
|
||||
return builder_model.SuggestionsResponse(
|
||||
otto_suggestions=[
|
||||
"What blocks do I need to get started?",
|
||||
"Help me create a list",
|
||||
"Help me feed my data to Google Maps",
|
||||
],
|
||||
recent_searches=[
|
||||
"image generation",
|
||||
"deepfake",
|
||||
"competitor analysis",
|
||||
],
|
||||
providers=[
|
||||
ProviderName.TWITTER,
|
||||
ProviderName.GITHUB,
|
||||
ProviderName.HUBSPOT,
|
||||
ProviderName.EXA,
|
||||
ProviderName.JINA,
|
||||
ProviderName.GOOGLE_MAPS,
|
||||
],
|
||||
top_blocks=builder_db.get_blocks(page_size=5).blocks,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/categories",
|
||||
dependencies=[fastapi.Depends(auth_middleware)],
|
||||
)
|
||||
async def get_block_categories(
|
||||
category_blocks: Annotated[int, fastapi.Query()] = 3,
|
||||
) -> Sequence[builder_model.BlockCategoryResponse]:
|
||||
return builder_db.get_block_categories(category_blocks)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/blocks",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Depends(auth_middleware)],
|
||||
)
|
||||
async def get_blocks(
|
||||
options: builder_model.SearchOptions,
|
||||
user_id: typing.Annotated[str, fastapi.Depends(get_user_id)],
|
||||
) -> builder_model.BlockSearchResponse:
|
||||
category: Annotated[str | None, fastapi.Query()] = None,
|
||||
type: Annotated[builder_model.BlockType | None, fastapi.Query()] = None,
|
||||
provider: Annotated[ProviderName | None, fastapi.Query()] = None,
|
||||
page: Annotated[int, fastapi.Query()] = 1,
|
||||
page_size: Annotated[int, fastapi.Query()] = 50,
|
||||
) -> builder_model.BlockResponse:
|
||||
return builder_db.get_blocks(
|
||||
category=category,
|
||||
type=type,
|
||||
provider=provider,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/providers",
|
||||
dependencies=[fastapi.Depends(auth_middleware)],
|
||||
)
|
||||
async def get_providers(
|
||||
page: Annotated[int, fastapi.Query()] = 1,
|
||||
page_size: Annotated[int, fastapi.Query()] = 50,
|
||||
) -> builder_model.ProviderResponse:
|
||||
return builder_db.get_providers(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/search",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Depends(auth_middleware)],
|
||||
)
|
||||
async def search(
|
||||
options: builder_model.SearchRequest,
|
||||
user_id: Annotated[str, fastapi.Depends(get_user_id)],
|
||||
) -> builder_model.SearchResponse:
|
||||
# If no filters are provided, then we will return all types
|
||||
if not options.filter:
|
||||
options.filter = [
|
||||
"all_blocks",
|
||||
"blocks",
|
||||
"integrations",
|
||||
"providers",
|
||||
"marketplace_agents",
|
||||
"my_agents",
|
||||
"providers",
|
||||
]
|
||||
options.search_query = sanitize_query(options.search_query)
|
||||
options.page = options.page or 1
|
||||
options.page_size = options.page_size or 50
|
||||
|
||||
# Blocks&Integrations
|
||||
blocks = builder_model.BlockResponse(
|
||||
blocks=[],
|
||||
blocks = builder_model.SearchBlocksResponse(
|
||||
blocks=builder_model.BlockResponse(
|
||||
blocks=[],
|
||||
pagination=server_model.Pagination.empty(),
|
||||
),
|
||||
total_block_count=0,
|
||||
total_integration_count=0,
|
||||
pagination=server_model.Pagination.empty(),
|
||||
)
|
||||
if (
|
||||
"all_blocks" in options.filter
|
||||
or "input_blocks" in options.filter
|
||||
or "action_blocks" in options.filter
|
||||
or "output_blocks" in options.filter
|
||||
or "integrations" in options.filter
|
||||
):
|
||||
blocks = builder_db.get_blocks(
|
||||
filter=options.filter,
|
||||
if "blocks" in options.filter or "integrations" in options.filter:
|
||||
blocks = builder_db.search_blocks(
|
||||
include_blocks="blocks" in options.filter,
|
||||
include_integrations="integrations" in options.filter,
|
||||
query=options.search_query or "",
|
||||
page=options.page,
|
||||
page_size=options.page_size,
|
||||
@@ -85,12 +157,12 @@ async def get_blocks(
|
||||
providers=[],
|
||||
pagination=server_model.Pagination.empty(),
|
||||
)
|
||||
# if "providers" in options.filter:
|
||||
# providers = builder_db.get_providers(
|
||||
# query=options.search_query or "",
|
||||
# page=options.page,
|
||||
# page_size=options.page_size,
|
||||
# )
|
||||
if "providers" in options.filter:
|
||||
providers = builder_db.get_providers(
|
||||
query=options.search_query or "",
|
||||
page=options.page,
|
||||
page_size=options.page_size,
|
||||
)
|
||||
|
||||
# Library Agents
|
||||
my_agents = library_model.LibraryAgentResponse(
|
||||
@@ -119,7 +191,7 @@ async def get_blocks(
|
||||
)
|
||||
|
||||
more_pages = False
|
||||
if blocks.pagination.current_page < blocks.pagination.total_pages:
|
||||
if blocks.blocks.pagination.current_page < blocks.blocks.pagination.total_pages:
|
||||
more_pages = True
|
||||
if my_agents.pagination.current_page < my_agents.pagination.total_pages:
|
||||
more_pages = True
|
||||
@@ -131,15 +203,14 @@ async def get_blocks(
|
||||
|
||||
# todo kcze sort results
|
||||
|
||||
return builder_model.BlockSearchResponse(
|
||||
items=blocks.blocks
|
||||
return builder_model.SearchResponse(
|
||||
items=blocks.blocks.blocks
|
||||
+ providers.providers
|
||||
+ my_agents.agents
|
||||
+ marketplace_agents.agents,
|
||||
total_items={
|
||||
"blocks": blocks.total_block_count,
|
||||
"integrations": blocks.total_integration_count,
|
||||
"providers": providers.pagination.total_items,
|
||||
"marketplace_agents": marketplace_agents.pagination.total_items,
|
||||
"my_agents": my_agents.pagination.total_items,
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user