Update backend

This commit is contained in:
Krzysztof Czerwinski
2025-05-25 15:11:29 +02:00
parent 1be830835b
commit bb69e32fee
3 changed files with 278 additions and 111 deletions

View File

@@ -1,25 +1,122 @@
import functools
import logging
import backend.server.model as server_model
from backend.blocks import load_all_blocks
from backend.data.block import Block, BlockSchema, BlockType
from backend.data.block import Block, BlockCategory, BlockSchema
from backend.data.credit import get_block_costs
from backend.integrations.providers import ProviderName
from backend.server.v2.builder.model import (
BlockCategoryResponse,
BlockResponse,
FilterType,
BlockType,
Provider,
ProviderResponse,
SearchBlocksResponse,
)
logger = logging.getLogger(__name__)
def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse]:
categories: dict[BlockCategory, BlockCategoryResponse] = {}
for block_type in load_all_blocks().values():
block: Block[BlockSchema, BlockSchema] = block_type()
# Skip disabled blocks
if block.disabled:
continue
# Skip blocks that don't have categories (all should have at least one)
if not block.categories:
continue
# Add block to the categories
for category in block.categories:
if category not in categories:
categories[category] = BlockCategoryResponse(
name=category.name.lower(),
total_blocks=0,
blocks=[],
)
categories[category].total_blocks += 1
# Append if the category has less than the specified number of blocks
if len(categories[category].blocks) < category_blocks:
categories[category].blocks.append(block.to_dict())
# Sort categories by name
return sorted(categories.values(), key=lambda x: x.name)
def get_blocks(
filter: list[FilterType],
query: str = "",
providers: list[ProviderName] | None = None,
*,
category: str | None = None,
type: BlockType | None = None,
provider: ProviderName | None = None,
page: int = 1,
page_size: int = 50,
) -> BlockResponse:
"""
Get blocks based on either category, type or provider.
Providing nothing assumes category is `all`.
"""
# Only one of category, type, or provider can be specified
if (category and type) or (category and provider) or (type and provider):
raise ValueError("Only one of category, type, or provider can be specified")
blocks: list[Block[BlockSchema, BlockSchema]] = []
skip = (page - 1) * page_size
take = page_size
total = 0
for block_type in load_all_blocks().values():
block: Block[BlockSchema, BlockSchema] = block_type()
# Skip disabled blocks
if block.disabled:
continue
# Skip blocks that don't match the category
if category and category not in {c.name.lower() for c in block.categories}:
continue
# Skip blocks that don't match the type
if (
(type == "input" and block.block_type.value != "Input")
or (type == "output" and block.block_type.value != "Output")
or (type == "action" and block.block_type.value in ("Input", "Output"))
):
continue
# Skip blocks that don't match the provider
if provider:
credentials_info = block.input_schema.get_credentials_fields_info().values()
if not any(provider in info.provider for info in credentials_info):
continue
total += 1
if skip > 0:
skip -= 1
continue
if take > 0:
take -= 1
blocks.append(block)
return BlockResponse(
blocks=[b.to_dict() for b in blocks],
pagination=server_model.Pagination(
total_items=total,
total_pages=total // page_size + (1 if total % page_size > 0 else 0),
current_page=page,
page_size=page_size,
),
)
def search_blocks(
include_blocks: bool = True,
include_integrations: bool = True,
query: str = "",
page: int = 1,
page_size: int = 50,
) -> SearchBlocksResponse:
"""
Get blocks based on the filter and query.
`providers` only applies for `integrations` filter.
@@ -43,30 +140,12 @@ def get_blocks(
continue
keep = False
credentials = list(block.input_schema.get_credentials_fields().values())
# Skip blocks that don't match the filter
if (
("all_blocks" in filter)
or ("input_blocks" in filter and block.block_type == BlockType.INPUT)
or ("output_block" in filter and block.block_type == BlockType.OUTPUT)
):
block_count += 1
if include_integrations and len(credentials) > 0:
keep = True
elif (
"action_blocks" in filter
and block.block_type != BlockType.INPUT
and block.block_type != BlockType.OUTPUT
):
block_count += 1
integration_count += 1
if include_blocks and len(credentials) == 0:
keep = True
elif "integrations" in filter and len(credentials) > 0:
# Only keep if provider is in the list
if providers:
if any(c.provider in providers for c in credentials):
keep = True
integration_count += 1
else:
keep = True
integration_count += 1
block_count += 1
if not keep:
continue
@@ -81,16 +160,18 @@ def get_blocks(
costs = get_block_costs()
return BlockResponse(
blocks=[{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks],
return SearchBlocksResponse(
blocks=BlockResponse(
blocks=[{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks],
pagination=server_model.Pagination(
total_items=total,
total_pages=total // page_size + (1 if total % page_size > 0 else 0),
current_page=page,
page_size=page_size,
),
),
total_block_count=block_count,
total_integration_count=integration_count,
pagination=server_model.Pagination(
total_items=total,
total_pages=total // page_size + (1 if total % page_size > 0 else 0),
current_page=page,
page_size=page_size,
),
)
@@ -135,18 +216,20 @@ def get_providers(
@functools.cache
def _get_all_providers() -> dict[ProviderName, Provider]:
providers = {}
providers: dict[ProviderName, Provider] = {}
for block_type in load_all_blocks().values():
block: Block[BlockSchema, BlockSchema] = block_type()
# Skip disabled blocks
if block.disabled:
continue
credentials = list(block.input_schema.get_credentials_fields().values())
for c in credentials:
if c.provider in providers:
providers[c.provider].integration_count += 1
else:
providers[c.provider] = Provider(
name=c.provider, description="", integration_count=1
)
credentials_info = block.input_schema.get_credentials_fields_info().values()
for info in credentials_info:
for provider in info.provider: # provider is a ProviderName enum member
if provider in providers:
providers[provider].integration_count += 1
else:
providers[provider] = Provider(
name=provider, description="", integration_count=1
)
return providers

View File

@@ -7,61 +7,74 @@ import backend.server.v2.library.model as library_model
import backend.server.v2.store.model as store_model
from backend.integrations.providers import ProviderName
FilterType = (
Literal["all_blocks"]
| Literal["input_blocks"]
| Literal["action_blocks"]
| Literal["output_blocks"]
| Literal["integrations"]
| Literal["providers"]
| Literal["marketplace_agents"]
| Literal["my_agents"]
)
FilterType = Literal[
"blocks",
"integrations",
"providers",
"marketplace_agents",
"my_agents",
]
SearchResultType = (
Literal["blocks"]
| Literal["integrations"]
| Literal["providers"]
| Literal["marketplace_agents"]
| Literal["my_agents"]
)
BlockType = Literal["all", "input", "action", "output"]
BlockData = dict[str, Any]
class SearchOptions(BaseModel):
search_query: str | None = None
filter: list[FilterType] | None = None
providers: list[str] | None = None
by_creator: list[str] | None = None
search_id: str | None = None
page: int | None = None
page_size: int | None = None
# Suggestions
class SuggestionsResponse(BaseModel):
otto_suggestions: list[str]
recent_searches: list[str]
providers: list[ProviderName]
top_blocks: list[BlockData]
# All blocks
class BlockCategoryResponse(BaseModel):
name: str
total_blocks: int
blocks: list[BlockData]
model_config = {"use_enum_values": False} # <== use enum names like "AI"
# Input/Action/Output and see all for block categories
class BlockResponse(BaseModel):
blocks: list[BlockData]
pagination: server_model.Pagination
# Providers
class Provider(BaseModel):
name: ProviderName
description: str
integration_count: int
class BlockResponse(BaseModel):
blocks: list[BlockData]
total_block_count: int
total_integration_count: int
pagination: server_model.Pagination
class ProviderResponse(BaseModel):
providers: list[Provider]
pagination: server_model.Pagination
class BlockSearchResponse(BaseModel):
# Search
class SearchRequest(BaseModel):
search_query: str | None = None
filter: list[FilterType] | None = None
by_creator: list[str] | None = None
search_id: str | None = None
page: int | None = None
page_size: int | None = None
class SearchBlocksResponse(BaseModel):
blocks: BlockResponse
total_block_count: int
total_integration_count: int
class SearchResponse(BaseModel):
items: list[
BlockData | Provider | library_model.LibraryAgent | store_model.StoreAgent
]
total_items: dict[SearchResultType, int]
total_items: dict[FilterType, int]
page: int
more_pages: bool

View File

@@ -1,8 +1,7 @@
import logging
import typing
from typing import Annotated, Sequence
import fastapi
import fastapi.responses
from autogpt_libs.auth.depends import auth_middleware, get_user_id
import backend.server.model as server_model
@@ -12,6 +11,7 @@ import backend.server.v2.library.db as library_db
import backend.server.v2.library.model as library_model
import backend.server.v2.store.db as store_db
import backend.server.v2.store.model as store_model
from backend.integrations.providers import ProviderName
logger = logging.getLogger(__name__)
@@ -37,44 +37,116 @@ def sanitize_query(query: str | None) -> str | None:
)
@router.post(
@router.get(
"/suggestions",
dependencies=[fastapi.Depends(auth_middleware)],
)
async def get_suggestions(
user_id: Annotated[str, fastapi.Depends(get_user_id)],
) -> builder_model.SuggestionsResponse:
# todo kcze temp response
return builder_model.SuggestionsResponse(
otto_suggestions=[
"What blocks do I need to get started?",
"Help me create a list",
"Help me feed my data to Google Maps",
],
recent_searches=[
"image generation",
"deepfake",
"competitor analysis",
],
providers=[
ProviderName.TWITTER,
ProviderName.GITHUB,
ProviderName.HUBSPOT,
ProviderName.EXA,
ProviderName.JINA,
ProviderName.GOOGLE_MAPS,
],
top_blocks=builder_db.get_blocks(page_size=5).blocks,
)
@router.get(
"/categories",
dependencies=[fastapi.Depends(auth_middleware)],
)
async def get_block_categories(
category_blocks: Annotated[int, fastapi.Query()] = 3,
) -> Sequence[builder_model.BlockCategoryResponse]:
return builder_db.get_block_categories(category_blocks)
@router.get(
"/blocks",
tags=["store", "private"],
dependencies=[fastapi.Depends(auth_middleware)],
)
async def get_blocks(
options: builder_model.SearchOptions,
user_id: typing.Annotated[str, fastapi.Depends(get_user_id)],
) -> builder_model.BlockSearchResponse:
category: Annotated[str | None, fastapi.Query()] = None,
type: Annotated[builder_model.BlockType | None, fastapi.Query()] = None,
provider: Annotated[ProviderName | None, fastapi.Query()] = None,
page: Annotated[int, fastapi.Query()] = 1,
page_size: Annotated[int, fastapi.Query()] = 50,
) -> builder_model.BlockResponse:
return builder_db.get_blocks(
category=category,
type=type,
provider=provider,
page=page,
page_size=page_size,
)
@router.get(
"/providers",
dependencies=[fastapi.Depends(auth_middleware)],
)
async def get_providers(
page: Annotated[int, fastapi.Query()] = 1,
page_size: Annotated[int, fastapi.Query()] = 50,
) -> builder_model.ProviderResponse:
return builder_db.get_providers(
page=page,
page_size=page_size,
)
@router.post(
"/search",
tags=["store", "private"],
dependencies=[fastapi.Depends(auth_middleware)],
)
async def search(
options: builder_model.SearchRequest,
user_id: Annotated[str, fastapi.Depends(get_user_id)],
) -> builder_model.SearchResponse:
# If no filters are provided, then we will return all types
if not options.filter:
options.filter = [
"all_blocks",
"blocks",
"integrations",
"providers",
"marketplace_agents",
"my_agents",
"providers",
]
options.search_query = sanitize_query(options.search_query)
options.page = options.page or 1
options.page_size = options.page_size or 50
# Blocks&Integrations
blocks = builder_model.BlockResponse(
blocks=[],
blocks = builder_model.SearchBlocksResponse(
blocks=builder_model.BlockResponse(
blocks=[],
pagination=server_model.Pagination.empty(),
),
total_block_count=0,
total_integration_count=0,
pagination=server_model.Pagination.empty(),
)
if (
"all_blocks" in options.filter
or "input_blocks" in options.filter
or "action_blocks" in options.filter
or "output_blocks" in options.filter
or "integrations" in options.filter
):
blocks = builder_db.get_blocks(
filter=options.filter,
if "blocks" in options.filter or "integrations" in options.filter:
blocks = builder_db.search_blocks(
include_blocks="blocks" in options.filter,
include_integrations="integrations" in options.filter,
query=options.search_query or "",
page=options.page,
page_size=options.page_size,
@@ -85,12 +157,12 @@ async def get_blocks(
providers=[],
pagination=server_model.Pagination.empty(),
)
# if "providers" in options.filter:
# providers = builder_db.get_providers(
# query=options.search_query or "",
# page=options.page,
# page_size=options.page_size,
# )
if "providers" in options.filter:
providers = builder_db.get_providers(
query=options.search_query or "",
page=options.page,
page_size=options.page_size,
)
# Library Agents
my_agents = library_model.LibraryAgentResponse(
@@ -119,7 +191,7 @@ async def get_blocks(
)
more_pages = False
if blocks.pagination.current_page < blocks.pagination.total_pages:
if blocks.blocks.pagination.current_page < blocks.blocks.pagination.total_pages:
more_pages = True
if my_agents.pagination.current_page < my_agents.pagination.total_pages:
more_pages = True
@@ -131,15 +203,14 @@ async def get_blocks(
# todo kcze sort results
return builder_model.BlockSearchResponse(
items=blocks.blocks
return builder_model.SearchResponse(
items=blocks.blocks.blocks
+ providers.providers
+ my_agents.agents
+ marketplace_agents.agents,
total_items={
"blocks": blocks.total_block_count,
"integrations": blocks.total_integration_count,
"providers": providers.pagination.total_items,
"marketplace_agents": marketplace_agents.pagination.total_items,
"my_agents": my_agents.pagination.total_items,
},