Add model and functions

This commit is contained in:
Krzysztof Czerwinski
2025-05-23 16:48:34 +02:00
parent 2a06956802
commit 95387bcf78
5 changed files with 233 additions and 5 deletions

View File

@@ -74,6 +74,15 @@ class Pagination(pydantic.BaseModel):
description="Number of items per page.", examples=[25]
)
@staticmethod
def empty() -> "Pagination":
return Pagination(
total_items=0,
total_pages=0,
current_page=0,
page_size=0,
)
class RequestTopUp(pydantic.BaseModel):
credit_amount: int

View File

@@ -0,0 +1,152 @@
import functools
import backend.server.model as server_model
from backend.blocks import load_all_blocks
from backend.data.block import Block, BlockSchema, BlockType
from backend.data.credit import get_block_costs
from backend.integrations.providers import ProviderName
from backend.server.v2.builder.model import (
BlockResponse,
FilterType,
Provider,
ProviderResponse,
)
def get_blocks(
filter: list[FilterType],
query: str = "",
providers: list[ProviderName] | None = None,
page: int = 1,
page_size: int = 50,
) -> BlockResponse:
"""
Get blocks based on the filter and query.
`providers` only applies for `integrations` filter.
"""
blocks: list[Block[BlockSchema, BlockSchema]] = []
query = query.lower()
total = 0
skip = (page - 1) * page_size
take = page_size
block_count = 0
integration_count = 0
for block_type in load_all_blocks().values():
block: Block[BlockSchema, BlockSchema] = block_type()
# Skip disabled blocks
if block.disabled:
continue
# Skip blocks that don't match the query
if query not in block.name.lower() or query not in block.description.lower():
continue
keep = False
credentials = list(block.input_schema.get_credentials_fields().values())
# Skip blocks that don't match the filter
if (
("all_blocks" in filter)
or ("input_blocks" in filter and block.block_type == BlockType.INPUT)
or ("output_block" in filter and block.block_type == BlockType.OUTPUT)
):
block_count += 1
keep = True
elif (
"action_blocks" in filter
and block.block_type != BlockType.INPUT
and block.block_type != BlockType.OUTPUT
):
block_count += 1
keep = True
elif "integrations" in filter and len(credentials) > 0:
# Only keep if provider is in the list
if providers:
if any(c.provider in providers for c in credentials):
keep = True
integration_count += 1
else:
keep = True
integration_count += 1
if not keep:
continue
total += 1
if skip > 0:
skip -= 1
continue
if take > 0:
take -= 1
blocks.append(block)
costs = get_block_costs()
return BlockResponse(
blocks=[{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks],
total_block_count=block_count,
total_integration_count=integration_count,
pagination=server_model.Pagination(
total_items=total,
total_pages=total // page_size + (1 if total % page_size > 0 else 0),
current_page=page,
page_size=page_size,
),
)
def get_providers(
query: str = "",
page: int = 1,
page_size: int = 50,
) -> ProviderResponse:
providers = []
query = query.lower()
skip = (page - 1) * page_size
take = page_size
all_providers = _get_all_providers()
for provider in all_providers.values():
if (
query not in provider.name.value.lower()
and query not in provider.description.lower()
):
continue
if skip > 0:
skip -= 1
continue
if take > 0:
take -= 1
providers.append(provider)
total = len(all_providers)
return ProviderResponse(
providers=providers,
pagination=server_model.Pagination(
total_items=total,
total_pages=total // page_size + (1 if total % page_size > 0 else 0),
current_page=page,
page_size=page_size,
),
)
@functools.cache
def _get_all_providers() -> dict[ProviderName, Provider]:
providers = {}
for block_type in load_all_blocks().values():
block: Block[BlockSchema, BlockSchema] = block_type()
# Skip disabled blocks
if block.disabled:
continue
credentials = list(block.input_schema.get_credentials_fields().values())
for c in credentials:
if c.provider in providers:
providers[c.provider].integration_count += 1
else:
providers[c.provider] = Provider(
name=c.provider, description="", integration_count=1
)
return providers

View File

@@ -0,0 +1,67 @@
from typing import Any, Literal
from pydantic import BaseModel
import backend.server.model as server_model
import backend.server.v2.library.model as library_model
import backend.server.v2.store.model as store_model
from backend.integrations.providers import ProviderName
FilterType = (
Literal["all_blocks"]
| Literal["input_blocks"]
| Literal["action_blocks"]
| Literal["output_blocks"]
| Literal["integrations"]
| Literal["providers"]
| Literal["marketplace_agents"]
| Literal["my_agents"]
)
SearchResultType = (
Literal["blocks"]
| Literal["integrations"]
| Literal["providers"]
| Literal["marketplace_agents"]
| Literal["my_agents"]
)
BlockData = dict[str, Any]
class SearchOptions(BaseModel):
search_query: str | None = None
filter: list[FilterType] | None = None
providers: list[str] | None = None
by_creator: list[str] | None = None
search_id: str | None = None
page: int | None = None
page_size: int | None = None
class Provider(BaseModel):
name: ProviderName
description: str
integration_count: int
class BlockResponse(BaseModel):
blocks: list[BlockData]
total_block_count: int
total_integration_count: int
pagination: server_model.Pagination
class ProviderResponse(BaseModel):
providers: list[Provider]
pagination: server_model.Pagination
class BlockSearchResponse(BaseModel):
items: list[
BlockData | Provider | library_model.LibraryAgent | store_model.StoreAgent
]
total_items: dict[SearchResultType, int]
page: int
more_pages: bool

View File

@@ -37,7 +37,7 @@ def sanitize_query(query: str | None) -> str | None:
async def get_store_agents(
featured: bool = False,
creator: str | None = None,
creators: list[str] | None = None,
sorted_by: str | None = None,
search_query: str | None = None,
category: str | None = None,
@@ -48,15 +48,15 @@ async def get_store_agents(
Get PUBLIC store agents from the StoreAgent view
"""
logger.debug(
f"Getting store agents. featured={featured}, creator={creator}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
f"Getting store agents. featured={featured}, creator={creators}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
)
sanitized_query = sanitize_query(search_query)
where_clause = {}
if featured:
where_clause["featured"] = featured
if creator:
where_clause["creator_username"] = creator
if creators:
where_clause["creator_username"] = {"in": creators}
if category:
where_clause["categories"] = {"has": category}

View File

@@ -152,7 +152,7 @@ async def get_agents(
try:
agents = await backend.server.v2.store.db.get_store_agents(
featured=featured,
creator=creator,
creators=[creator] if creator else None,
sorted_by=sorted_by,
search_query=search_query,
category=category,