Mirror of https://github.com/Significant-Gravitas/AutoGPT.git, synced 2026-01-09 15:17:59 -05:00
Refactor Exa blocks to use exa_py SDK and remove legacy API code
Replaces direct HTTP calls and custom API helpers with the official exa_py AsyncExa SDK across all Exa-related blocks. Removes the now-unnecessary _api.py and its legacy polling helpers, updates block outputs to use Pydantic models converted from SDK dataclasses, and introduces stable preview models for webset previews. This improves maintainability and reliability, and future-proofs the integration against API changes.
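The same pattern repeats throughout the diffs below; here is a condensed, illustrative sketch of the before/after shape of a block's run method (identifiers are abbreviated from the real code, not copied verbatim):

# Before: hand-built HTTP request against the REST endpoint, raw JSON handling
headers = {"x-api-key": credentials.api_key.get_secret_value()}
response = await Requests().post("https://api.exa.ai/answer", headers=headers, json={"query": query, "stream": False})
data = response.json()
yield "answer", data.get("answer")

# After: the AsyncExa client owns transport, and SDK dataclasses are converted
# into the block's own Pydantic models so outputs stay stable across SDK updates
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
response = await aexa.answer(query=query, text=text, stream=False)
yield "answer", response.answer
yield "citations", [AnswerCitation.from_sdk(c) for c in response.citations or []]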
.gitignore (vendored): 1 addition
@@ -178,3 +178,4 @@ autogpt_platform/backend/settings.py
 *.ign.*
 .test-contents
 .claude/settings.local.json
+/autogpt_platform/backend/logs
@@ -1,349 +0,0 @@
"""
Exa Websets API utilities.

This module provides common building blocks for interacting with the Exa API:
- URL construction (ExaApiUrls)
- Header building
- Item counting
- Pagination helpers
- Generic polling
"""

import asyncio
import time
from typing import Any, Callable, Iterator, Tuple

from backend.sdk import Requests


class ExaApiUrls:
    """Centralized URL builder for Exa Websets API endpoints."""

    BASE = "https://api.exa.ai/websets/v0"

    # ==================== Webset Endpoints ====================

    @classmethod
    def websets(cls) -> str:
        """List all websets endpoint."""
        return f"{cls.BASE}/websets"

    @classmethod
    def webset(cls, webset_id: str) -> str:
        """Get/update/delete webset endpoint."""
        return f"{cls.BASE}/websets/{webset_id}"

    @classmethod
    def webset_cancel(cls, webset_id: str) -> str:
        """Cancel webset endpoint."""
        return f"{cls.BASE}/websets/{webset_id}/cancel"

    @classmethod
    def webset_preview(cls) -> str:
        """Preview webset query endpoint."""
        return f"{cls.BASE}/websets/preview"

    # ==================== Item Endpoints ====================

    @classmethod
    def webset_items(cls, webset_id: str) -> str:
        """List webset items endpoint."""
        return f"{cls.BASE}/websets/{webset_id}/items"

    @classmethod
    def webset_item(cls, webset_id: str, item_id: str) -> str:
        """Get/delete specific item endpoint."""
        return f"{cls.BASE}/websets/{webset_id}/items/{item_id}"

    # ==================== Search Endpoints ====================

    @classmethod
    def webset_searches(cls, webset_id: str) -> str:
        """List/create searches endpoint."""
        return f"{cls.BASE}/websets/{webset_id}/searches"

    @classmethod
    def webset_search(cls, webset_id: str, search_id: str) -> str:
        """Get specific search endpoint."""
        return f"{cls.BASE}/websets/{webset_id}/searches/{search_id}"

    @classmethod
    def webset_search_cancel(cls, webset_id: str, search_id: str) -> str:
        """Cancel search endpoint."""
        return f"{cls.BASE}/websets/{webset_id}/searches/{search_id}/cancel"

    # ==================== Enrichment Endpoints ====================

    @classmethod
    def webset_enrichments(cls, webset_id: str) -> str:
        """List/create enrichments endpoint."""
        return f"{cls.BASE}/websets/{webset_id}/enrichments"

    @classmethod
    def webset_enrichment(cls, webset_id: str, enrichment_id: str) -> str:
        """Get/update/delete enrichment endpoint."""
        return f"{cls.BASE}/websets/{webset_id}/enrichments/{enrichment_id}"

    @classmethod
    def webset_enrichment_cancel(cls, webset_id: str, enrichment_id: str) -> str:
        """Cancel enrichment endpoint."""
        return f"{cls.BASE}/websets/{webset_id}/enrichments/{enrichment_id}/cancel"

    # ==================== Monitor Endpoints ====================

    @classmethod
    def monitors(cls) -> str:
        """List/create monitors endpoint."""
        return f"{cls.BASE}/monitors"

    @classmethod
    def monitor(cls, monitor_id: str) -> str:
        """Get/update/delete monitor endpoint."""
        return f"{cls.BASE}/monitors/{monitor_id}"

    # ==================== Import Endpoints ====================

    @classmethod
    def imports(cls) -> str:
        """List/create imports endpoint."""
        return f"{cls.BASE}/imports"

    @classmethod
    def import_(cls, import_id: str) -> str:
        """Get/delete import endpoint."""
        return f"{cls.BASE}/imports/{import_id}"


def build_headers(api_key: str, include_content_type: bool = False) -> dict:
    """
    Build standard Exa API headers.

    Args:
        api_key: The API key for authentication
        include_content_type: Whether to include Content-Type: application/json header

    Returns:
        Dictionary of headers ready for API requests

    Example:
        >>> headers = build_headers("sk-123456")
        >>> headers = build_headers("sk-123456", include_content_type=True)
    """
    headers = {"x-api-key": api_key}
    if include_content_type:
        headers["Content-Type"] = "application/json"
    return headers


async def get_item_count(webset_id: str, headers: dict) -> int:
    """
    Get the total item count for a webset efficiently.

    This makes a request with limit=1 and reads from pagination data
    to avoid fetching all items.

    Args:
        webset_id: The webset ID
        headers: Request headers with API key

    Returns:
        Total number of items in the webset

    Example:
        >>> count = await get_item_count("ws-123", headers)
    """
    url = ExaApiUrls.webset_items(webset_id)
    response = await Requests().get(url, headers=headers, params={"limit": 1})
    data = response.json()

    # Prefer pagination total if available
    if "pagination" in data:
        return data["pagination"].get("total", 0)

    # Fallback to data length
    return len(data.get("data", []))


def yield_paginated_results(
    data: dict, list_key: str = "items", item_key: str = "item"
) -> Iterator[Tuple[str, Any]]:
    """
    Yield paginated API results in standard format.

    This helper yields both the full list and individual items for flexible
    graph connections, plus pagination metadata.

    Args:
        data: API response data containing 'data', 'hasMore', 'nextCursor' fields
        list_key: Output key name for the full list (default: "items")
        item_key: Output key name for individual items (default: "item")

    Yields:
        Tuples of (key, value) for block output:
        - (list_key, list): Full list of items
        - (item_key, item): Each individual item (yielded separately)
        - ("has_more", bool): Whether more results exist
        - ("next_cursor", str|None): Cursor for next page

    Example:
        >>> for key, value in yield_paginated_results(response_data, "websets", "webset"):
        >>>     yield key, value
    """
    items = data.get("data", [])

    # Yield full list for batch processing
    yield list_key, items

    # Yield individual items for single-item processing chains
    for item in items:
        yield item_key, item

    # Yield pagination metadata
    yield "has_more", data.get("hasMore", False)
    yield "next_cursor", data.get("nextCursor")


async def poll_until_complete(
    url: str,
    headers: dict,
    is_complete: Callable[[dict], bool],
    extract_result: Callable[[dict], Any],
    timeout: int = 300,
    initial_interval: float = 5.0,
    max_interval: float = 30.0,
    backoff_factor: float = 1.5,
) -> Any:
    """
    Generic polling function with exponential backoff for async operations.

    This function polls an API endpoint until a completion condition is met,
    using exponential backoff to reduce API load.

    Args:
        url: API endpoint to poll
        headers: Request headers with API key
        is_complete: Function that takes response data and returns True when complete
        extract_result: Function that extracts the result from response data
        timeout: Maximum time to wait in seconds (default: 300)
        initial_interval: Starting interval between polls in seconds (default: 5.0)
        max_interval: Maximum interval between polls in seconds (default: 30.0)
        backoff_factor: Factor to multiply interval by each iteration (default: 1.5)

    Returns:
        The result extracted by extract_result function when operation completes

    Raises:
        TimeoutError: If operation doesn't complete within timeout

    Example:
        >>> result = await poll_until_complete(
        >>>     url=ExaApiUrls.webset(webset_id),
        >>>     headers=build_headers(api_key),
        >>>     is_complete=lambda data: data.get("status") == "idle",
        >>>     extract_result=lambda data: data.get("itemsCount", 0),
        >>>     timeout=300
        >>> )
    """
    start_time = time.time()
    interval = initial_interval

    while time.time() - start_time < timeout:
        response = await Requests().get(url, headers=headers)
        data = response.json()

        if is_complete(data):
            return extract_result(data)

        await asyncio.sleep(interval)
        interval = min(interval * backoff_factor, max_interval)

    # Timeout reached - raise error
    raise TimeoutError(f"Operation did not complete within {timeout} seconds")


async def poll_webset_until_idle(
    webset_id: str, headers: dict, timeout: int = 300
) -> int:
    """
    Poll a webset until it reaches 'idle' status.

    Convenience wrapper around poll_until_complete specifically for websets.

    Args:
        webset_id: The webset ID to poll
        headers: Request headers with API key
        timeout: Maximum time to wait in seconds

    Returns:
        The item count when webset becomes idle

    Raises:
        TimeoutError: If webset doesn't become idle within timeout
    """
    return await poll_until_complete(
        url=ExaApiUrls.webset(webset_id),
        headers=headers,
        is_complete=lambda data: data.get("status", {}).get("type") == "idle",
        extract_result=lambda data: data.get("itemsCount", 0),
        timeout=timeout,
    )


async def poll_search_until_complete(
    webset_id: str, search_id: str, headers: dict, timeout: int = 300
) -> int:
    """
    Poll a search until it completes (completed/failed/cancelled).

    Convenience wrapper around poll_until_complete specifically for searches.

    Args:
        webset_id: The webset ID
        search_id: The search ID to poll
        headers: Request headers with API key
        timeout: Maximum time to wait in seconds

    Returns:
        The number of items found when search completes

    Raises:
        TimeoutError: If search doesn't complete within timeout
    """
    return await poll_until_complete(
        url=ExaApiUrls.webset_search(webset_id, search_id),
        headers=headers,
        is_complete=lambda data: data.get("status")
        in ["completed", "failed", "cancelled"],
        extract_result=lambda data: data.get("progress", {}).get("found", 0),
        timeout=timeout,
    )


async def poll_enrichment_until_complete(
    webset_id: str, enrichment_id: str, headers: dict, timeout: int = 300
) -> int:
    """
    Poll an enrichment until it completes (completed/failed/cancelled).

    Convenience wrapper around poll_until_complete specifically for enrichments.

    Args:
        webset_id: The webset ID
        enrichment_id: The enrichment ID to poll
        headers: Request headers with API key
        timeout: Maximum time to wait in seconds

    Returns:
        The number of items enriched when operation completes

    Raises:
        TimeoutError: If enrichment doesn't complete within timeout
    """
    return await poll_until_complete(
        url=ExaApiUrls.webset_enrichment(webset_id, enrichment_id),
        headers=headers,
        is_complete=lambda data: data.get("status")
        in ["completed", "failed", "cancelled"],
        extract_result=lambda data: data.get("progress", {}).get("processedItems", 0),
        timeout=timeout,
    )
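With the helpers above removed, blocks that need to wait on long-running webset operations can loop directly on the SDK client instead. A minimal sketch under stated assumptions: it assumes the websets client exposes a get() method and that the returned object carries status and items_count attributes, neither of which is shown in this excerpt, so treat those names as placeholders rather than the actual replacement code.

import asyncio
import time

from exa_py import AsyncExa


async def wait_for_webset_idle(api_key: str, webset_id: str, timeout: int = 300) -> int:
    # Hypothetical stand-in for the removed poll_webset_until_idle helper.
    aexa = AsyncExa(api_key=api_key)
    start, interval = time.time(), 5.0
    while time.time() - start < timeout:
        webset = aexa.websets.get(webset_id)  # assumed SDK call; mirrors aexa.websets.preview() used later in this diff
        status = getattr(webset.status, "value", webset.status)  # enum or plain string, depending on SDK version
        if status == "idle":
            return getattr(webset, "items_count", 0) or 0
        await asyncio.sleep(interval)
        interval = min(interval * 1.5, 30.0)  # same exponential backoff the old helper used
    raise TimeoutError(f"Webset {webset_id} did not become idle within {timeout} seconds")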
@@ -1,5 +1,7 @@
|
||||
from typing import Optional
|
||||
|
||||
from exa_py import AsyncExa
|
||||
from exa_py.api import AnswerResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.sdk import (
|
||||
@@ -10,12 +12,10 @@ from backend.sdk import (
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
MediaFileType,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
from .helpers import CostDollars
|
||||
|
||||
|
||||
class AnswerCitation(BaseModel):
|
||||
@@ -23,25 +23,33 @@ class AnswerCitation(BaseModel):
|
||||
|
||||
id: str = SchemaField(description="The temporary ID for the document")
|
||||
url: str = SchemaField(description="The URL of the search result")
|
||||
title: Optional[str] = SchemaField(
|
||||
description="The title of the search result", default=None
|
||||
)
|
||||
author: Optional[str] = SchemaField(
|
||||
description="The author of the content", default=None
|
||||
)
|
||||
title: Optional[str] = SchemaField(description="The title of the search result")
|
||||
author: Optional[str] = SchemaField(description="The author of the content")
|
||||
publishedDate: Optional[str] = SchemaField(
|
||||
description="An estimate of the creation date", default=None
|
||||
)
|
||||
text: Optional[str] = SchemaField(
|
||||
description="The full text content of the source", default=None
|
||||
description="An estimate of the creation date"
|
||||
)
|
||||
text: Optional[str] = SchemaField(description="The full text content of the source")
|
||||
image: Optional[MediaFileType] = SchemaField(
|
||||
description="The URL of the image associated with the result", default=None
|
||||
description="The URL of the image associated with the result"
|
||||
)
|
||||
favicon: Optional[MediaFileType] = SchemaField(
|
||||
description="The URL of the favicon for the domain", default=None
|
||||
description="The URL of the favicon for the domain"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_sdk(cls, sdk_citation) -> "AnswerCitation":
|
||||
"""Convert SDK AnswerResult (dataclass) to our Pydantic model."""
|
||||
return cls(
|
||||
id=getattr(sdk_citation, "id", ""),
|
||||
url=getattr(sdk_citation, "url", ""),
|
||||
title=getattr(sdk_citation, "title", None),
|
||||
author=getattr(sdk_citation, "author", None),
|
||||
publishedDate=getattr(sdk_citation, "published_date", None),
|
||||
text=getattr(sdk_citation, "text", None),
|
||||
image=getattr(sdk_citation, "image", None),
|
||||
favicon=getattr(sdk_citation, "favicon", None),
|
||||
)
|
||||
|
||||
|
||||
class ExaAnswerBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
@@ -62,18 +70,12 @@ class ExaAnswerBlock(Block):
|
||||
description="The generated answer based on search results"
|
||||
)
|
||||
citations: list[AnswerCitation] = SchemaField(
|
||||
description="Search results used to generate the answer",
|
||||
default_factory=list,
|
||||
description="Search results used to generate the answer"
|
||||
)
|
||||
citation: AnswerCitation = SchemaField(
|
||||
description="Individual citation from the answer"
|
||||
)
|
||||
cost_dollars: Optional[CostDollars] = SchemaField(
|
||||
description="Cost breakdown for the request", default=None
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the request failed", default=""
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -87,38 +89,27 @@ class ExaAnswerBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
url = "https://api.exa.ai/answer"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
payload = {
|
||||
"query": input_data.query,
|
||||
"text": input_data.text,
|
||||
# We don't support streaming in blocks
|
||||
"stream": False,
|
||||
}
|
||||
# Get answer using SDK (stream=False for blocks) - this IS async, needs await
|
||||
response = await aexa.answer(
|
||||
query=input_data.query, text=input_data.text, stream=False
|
||||
)
|
||||
|
||||
try:
|
||||
response = await Requests().post(url, headers=headers, json=payload)
|
||||
data = response.json()
|
||||
# this should remain true as long as they don't start defaulting to streaming only.
|
||||
# provides a bit of safety for sdk updates.
|
||||
assert type(response) is AnswerResponse
|
||||
|
||||
# Yield the answer
|
||||
if "answer" in data:
|
||||
yield "answer", data["answer"]
|
||||
# Yield the answer
|
||||
yield "answer", response.answer
|
||||
|
||||
# Yield citations as a list
|
||||
if "citations" in data:
|
||||
yield "citations", data["citations"]
|
||||
# Convert citations to our Pydantic model using from_sdk()
|
||||
citations = [
|
||||
AnswerCitation.from_sdk(sdk_citation)
|
||||
for sdk_citation in response.citations or []
|
||||
]
|
||||
|
||||
# Also yield individual citations
|
||||
for citation in data["citations"]:
|
||||
yield "citation", citation
|
||||
|
||||
# Yield cost information if present
|
||||
if "costDollars" in data:
|
||||
yield "cost_dollars", data["costDollars"]
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
yield "citations", citations
|
||||
for citation in citations:
|
||||
yield "citation", citation
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from exa_py import AsyncExa
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.sdk import (
|
||||
@@ -10,7 +11,6 @@ from backend.sdk import (
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
@@ -105,27 +105,22 @@ class ExaContentsBlock(Block):
|
||||
|
||||
class Output(BlockSchema):
|
||||
results: list[ExaSearchResults] = SchemaField(
|
||||
description="List of document contents with metadata", default_factory=list
|
||||
description="List of document contents with metadata"
|
||||
)
|
||||
result: ExaSearchResults = SchemaField(
|
||||
description="Single document content result"
|
||||
)
|
||||
context: str = SchemaField(
|
||||
description="A formatted string of the results ready for LLMs", default=""
|
||||
)
|
||||
request_id: str = SchemaField(
|
||||
description="Unique identifier for the request", default=""
|
||||
description="A formatted string of the results ready for LLMs"
|
||||
)
|
||||
request_id: str = SchemaField(description="Unique identifier for the request")
|
||||
statuses: list[ContentStatus] = SchemaField(
|
||||
description="Status information for each requested URL",
|
||||
default_factory=list,
|
||||
description="Status information for each requested URL"
|
||||
)
|
||||
cost_dollars: Optional[CostDollars] = SchemaField(
|
||||
description="Cost breakdown for the request", default=None
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the request failed", default=""
|
||||
description="Cost breakdown for the request"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -139,27 +134,22 @@ class ExaContentsBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
url = "https://api.exa.ai/contents"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
# Validate input
|
||||
if not input_data.urls and not input_data.ids:
|
||||
raise ValueError("Either 'urls' or 'ids' must be provided")
|
||||
|
||||
# Build payload with urls or deprecated ids
|
||||
payload = {}
|
||||
# Build kwargs for SDK call
|
||||
sdk_kwargs = {}
|
||||
|
||||
# Prefer urls over ids
|
||||
if input_data.urls:
|
||||
payload["urls"] = input_data.urls
|
||||
sdk_kwargs["urls"] = input_data.urls
|
||||
elif input_data.ids:
|
||||
payload["ids"] = input_data.ids
|
||||
else:
|
||||
yield "error", "Either 'urls' or 'ids' must be provided"
|
||||
return
|
||||
sdk_kwargs["ids"] = input_data.ids
|
||||
|
||||
# Handle text field - when true, include HTML tags for better LLM understanding
|
||||
if input_data.text:
|
||||
payload["text"] = {"includeHtmlTags": True}
|
||||
sdk_kwargs["text"] = {"includeHtmlTags": True}
|
||||
|
||||
# Handle highlights - only include if modified from defaults
|
||||
if input_data.highlights and (
|
||||
@@ -174,7 +164,7 @@ class ExaContentsBlock(Block):
|
||||
)
|
||||
if input_data.highlights.query:
|
||||
highlights_dict["query"] = input_data.highlights.query
|
||||
payload["highlights"] = highlights_dict
|
||||
sdk_kwargs["highlights"] = highlights_dict
|
||||
|
||||
# Handle summary - only include if modified from defaults
|
||||
if input_data.summary and (
|
||||
@@ -186,23 +176,23 @@ class ExaContentsBlock(Block):
|
||||
summary_dict["query"] = input_data.summary.query
|
||||
if input_data.summary.schema:
|
||||
summary_dict["schema"] = input_data.summary.schema
|
||||
payload["summary"] = summary_dict
|
||||
sdk_kwargs["summary"] = summary_dict
|
||||
|
||||
# Handle livecrawl
|
||||
if input_data.livecrawl:
|
||||
payload["livecrawl"] = input_data.livecrawl.value
|
||||
sdk_kwargs["livecrawl"] = input_data.livecrawl.value
|
||||
|
||||
# Handle livecrawl_timeout
|
||||
if input_data.livecrawl_timeout is not None:
|
||||
payload["livecrawlTimeout"] = input_data.livecrawl_timeout
|
||||
sdk_kwargs["livecrawl_timeout"] = input_data.livecrawl_timeout
|
||||
|
||||
# Handle subpages
|
||||
if input_data.subpages is not None:
|
||||
payload["subpages"] = input_data.subpages
|
||||
sdk_kwargs["subpages"] = input_data.subpages
|
||||
|
||||
# Handle subpage_target
|
||||
if input_data.subpage_target:
|
||||
payload["subpageTarget"] = input_data.subpage_target
|
||||
sdk_kwargs["subpage_target"] = input_data.subpage_target
|
||||
|
||||
# Handle extras - only include if modified from defaults
|
||||
if input_data.extras and (
|
||||
@@ -212,38 +202,36 @@ class ExaContentsBlock(Block):
|
||||
if input_data.extras.links:
|
||||
extras_dict["links"] = input_data.extras.links
|
||||
if input_data.extras.image_links:
|
||||
extras_dict["imageLinks"] = input_data.extras.image_links
|
||||
payload["extras"] = extras_dict
|
||||
extras_dict["image_links"] = input_data.extras.image_links
|
||||
sdk_kwargs["extras"] = extras_dict
|
||||
|
||||
# Always enable context for LLM-ready output
|
||||
payload["context"] = True
|
||||
sdk_kwargs["context"] = True
|
||||
|
||||
try:
|
||||
response = await Requests().post(url, headers=headers, json=payload)
|
||||
data = response.json()
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
response = await aexa.get_contents(**sdk_kwargs)
|
||||
|
||||
# Extract all response fields
|
||||
yield "results", data.get("results", [])
|
||||
# SearchResponse is a dataclass, convert results to our Pydantic models
|
||||
converted_results = [
|
||||
ExaSearchResults.from_sdk(sdk_result)
|
||||
for sdk_result in response.results or []
|
||||
]
|
||||
|
||||
# Yield individual results
|
||||
for result in data.get("results", []):
|
||||
yield "result", result
|
||||
yield "results", converted_results
|
||||
|
||||
# Yield context if present
|
||||
if "context" in data:
|
||||
yield "context", data["context"]
|
||||
# Yield individual results
|
||||
for result in converted_results:
|
||||
yield "result", result
|
||||
|
||||
# Yield request ID if present
|
||||
if "requestId" in data:
|
||||
yield "request_id", data["requestId"]
|
||||
# Yield context if present
|
||||
if response.context:
|
||||
yield "context", response.context
|
||||
|
||||
# Yield statuses if present
|
||||
if "statuses" in data:
|
||||
yield "statuses", data["statuses"]
|
||||
# Yield statuses if present
|
||||
if response.statuses:
|
||||
yield "statuses", response.statuses
|
||||
|
||||
# Yield cost information if present
|
||||
if "costDollars" in data:
|
||||
yield "cost_dollars", data["costDollars"]
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
# Yield cost information if present
|
||||
if response.cost_dollars:
|
||||
yield "cost_dollars", response.cost_dollars
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import asyncio
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, Literal, Optional, TypeVar, Union
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
|
||||
from backend.sdk import BaseModel, MediaFileType, SchemaField
|
||||
|
||||
@@ -257,6 +255,25 @@ class ExaSearchResults(BaseModel):
|
||||
subpages: list[dict] = SchemaField(default_factory=list)
|
||||
extras: ExaSearchExtras | None = None
|
||||
|
||||
@classmethod
|
||||
def from_sdk(cls, sdk_result) -> "ExaSearchResults":
|
||||
"""Convert SDK Result (dataclass) to our Pydantic model."""
|
||||
return cls(
|
||||
id=getattr(sdk_result, "id", ""),
|
||||
url=getattr(sdk_result, "url", None),
|
||||
title=getattr(sdk_result, "title", None),
|
||||
author=getattr(sdk_result, "author", None),
|
||||
publishedDate=getattr(sdk_result, "published_date", None),
|
||||
text=getattr(sdk_result, "text", None),
|
||||
highlights=getattr(sdk_result, "highlights", None) or [],
|
||||
highlightScores=getattr(sdk_result, "highlight_scores", None) or [],
|
||||
summary=getattr(sdk_result, "summary", None),
|
||||
subpages=getattr(sdk_result, "subpages", None) or [],
|
||||
image=getattr(sdk_result, "image", None),
|
||||
favicon=getattr(sdk_result, "favicon", None),
|
||||
extras=getattr(sdk_result, "extras", None),
|
||||
)
|
||||
|
||||
|
||||
# Cost tracking models
|
||||
class CostBreakdown(BaseModel):
|
||||
@@ -440,258 +457,3 @@ def add_optional_fields(
|
||||
payload[api_field] = value.value
|
||||
else:
|
||||
payload[api_field] = value
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
async def poll_until_complete(
|
||||
check_fn: Callable[[], tuple[bool, T]],
|
||||
timeout: int = 300,
|
||||
initial_interval: float = 5.0,
|
||||
max_interval: float = 30.0,
|
||||
backoff_factor: float = 1.5,
|
||||
progress_callback: Optional[Callable[[str], None]] = None,
|
||||
) -> T:
|
||||
"""
|
||||
Generic polling function for async operations.
|
||||
|
||||
Args:
|
||||
check_fn: Function that returns (is_complete, result)
|
||||
timeout: Maximum time to wait in seconds
|
||||
initial_interval: Initial wait interval between polls
|
||||
max_interval: Maximum wait interval between polls
|
||||
backoff_factor: Factor to increase interval by each iteration
|
||||
progress_callback: Optional callback for progress updates
|
||||
|
||||
Returns:
|
||||
The result from check_fn when complete
|
||||
|
||||
Raises:
|
||||
TimeoutError: If operation doesn't complete within timeout
|
||||
"""
|
||||
start_time = time.time()
|
||||
interval = initial_interval
|
||||
attempt = 0
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
attempt += 1
|
||||
is_complete, result = check_fn()
|
||||
|
||||
if is_complete:
|
||||
if progress_callback:
|
||||
progress_callback(f"✓ Operation completed after {attempt} attempts")
|
||||
return result
|
||||
|
||||
# Calculate remaining time
|
||||
elapsed = time.time() - start_time
|
||||
remaining = timeout - elapsed
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
f"⏳ Attempt {attempt}: Operation still in progress. "
|
||||
f"Elapsed: {int(elapsed)}s, Remaining: {int(remaining)}s"
|
||||
)
|
||||
|
||||
# Wait before next poll
|
||||
wait_time = min(interval, remaining)
|
||||
if wait_time > 0:
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
# Exponential backoff
|
||||
interval = min(interval * backoff_factor, max_interval)
|
||||
|
||||
raise TimeoutError(f"Operation did not complete within {timeout} seconds")
|
||||
|
||||
|
||||
async def poll_webset_status(
|
||||
webset_id: str,
|
||||
api_key: str,
|
||||
target_status: str = "idle",
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[str], None]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Poll a webset until it reaches the target status.
|
||||
|
||||
Args:
|
||||
webset_id: Webset ID to poll
|
||||
api_key: API key for authentication
|
||||
target_status: Status to wait for (default: "idle")
|
||||
timeout: Maximum time to wait in seconds
|
||||
progress_callback: Optional callback for progress updates
|
||||
|
||||
Returns:
|
||||
The webset data when target status is reached
|
||||
"""
|
||||
import httpx
|
||||
|
||||
def check_status() -> tuple[bool, Dict[str, Any]]:
|
||||
with httpx.Client() as client:
|
||||
response = client.get(
|
||||
f"https://api.exa.ai/v1alpha/websets/{webset_id}",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
status = data.get("status", {}).get("type")
|
||||
is_complete = status == target_status
|
||||
|
||||
if progress_callback and not is_complete:
|
||||
items_count = data.get("itemsCount", 0)
|
||||
progress_callback(f"Status: {status}, Items: {items_count}")
|
||||
|
||||
return is_complete, data
|
||||
|
||||
return await poll_until_complete(
|
||||
check_fn=check_status, timeout=timeout, progress_callback=progress_callback
|
||||
)
|
||||
|
||||
|
||||
async def poll_search_completion(
|
||||
webset_id: str,
|
||||
search_id: str,
|
||||
api_key: str,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[str], None]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Poll a search until it completes.
|
||||
|
||||
Args:
|
||||
webset_id: Webset ID
|
||||
search_id: Search ID to poll
|
||||
api_key: API key for authentication
|
||||
timeout: Maximum time to wait in seconds
|
||||
progress_callback: Optional callback for progress updates
|
||||
|
||||
Returns:
|
||||
The search data when complete
|
||||
"""
|
||||
import httpx
|
||||
|
||||
def check_search() -> tuple[bool, Dict[str, Any]]:
|
||||
with httpx.Client() as client:
|
||||
response = client.get(
|
||||
f"https://api.exa.ai/v1alpha/websets/{webset_id}/searches/{search_id}",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
status = data.get("status")
|
||||
is_complete = status in ["completed", "failed", "cancelled"]
|
||||
|
||||
if progress_callback and not is_complete:
|
||||
items_found = data.get("results", {}).get("itemsFound", 0)
|
||||
progress_callback(
|
||||
f"Search status: {status}, Items found: {items_found}"
|
||||
)
|
||||
|
||||
return is_complete, data
|
||||
|
||||
return await poll_until_complete(
|
||||
check_fn=check_search, timeout=timeout, progress_callback=progress_callback
|
||||
)
|
||||
|
||||
|
||||
async def poll_enrichment_completion(
|
||||
webset_id: str,
|
||||
enrichment_id: str,
|
||||
api_key: str,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[str], None]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Poll an enrichment until it completes.
|
||||
|
||||
Args:
|
||||
webset_id: Webset ID
|
||||
enrichment_id: Enrichment ID to poll
|
||||
api_key: API key for authentication
|
||||
timeout: Maximum time to wait in seconds
|
||||
progress_callback: Optional callback for progress updates
|
||||
|
||||
Returns:
|
||||
The enrichment data when complete
|
||||
"""
|
||||
import httpx
|
||||
|
||||
def check_enrichment() -> tuple[bool, Dict[str, Any]]:
|
||||
with httpx.Client() as client:
|
||||
response = client.get(
|
||||
f"https://api.exa.ai/v1alpha/websets/{webset_id}/enrichments/{enrichment_id}",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
status = data.get("status")
|
||||
is_complete = status in ["completed", "failed", "cancelled"]
|
||||
|
||||
if progress_callback and not is_complete:
|
||||
progress = data.get("progress", {})
|
||||
processed = progress.get("processedItems", 0)
|
||||
total = progress.get("totalItems", 0)
|
||||
progress_callback(
|
||||
f"Enrichment status: {status}, Progress: {processed}/{total}"
|
||||
)
|
||||
|
||||
return is_complete, data
|
||||
|
||||
return await poll_until_complete(
|
||||
check_fn=check_enrichment, timeout=timeout, progress_callback=progress_callback
|
||||
)
|
||||
|
||||
|
||||
def format_progress_message(
|
||||
operation_type: str, current_state: str, details: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Format a progress message for display.
|
||||
|
||||
Args:
|
||||
operation_type: Type of operation (webset, search, enrichment)
|
||||
current_state: Current state description
|
||||
details: Optional details to include
|
||||
|
||||
Returns:
|
||||
Formatted progress message
|
||||
"""
|
||||
message_parts = [f"[{operation_type.upper()}]", current_state]
|
||||
|
||||
if details:
|
||||
detail_parts = []
|
||||
for key, value in details.items():
|
||||
detail_parts.append(f"{key}: {value}")
|
||||
if detail_parts:
|
||||
message_parts.append(f"({', '.join(detail_parts)})")
|
||||
|
||||
return " ".join(message_parts)
|
||||
|
||||
|
||||
def calculate_polling_stats(
|
||||
start_time: float, timeout: int, attempts: int
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Calculate polling statistics.
|
||||
|
||||
Args:
|
||||
start_time: Start time (from time.time())
|
||||
timeout: Maximum timeout in seconds
|
||||
attempts: Number of polling attempts made
|
||||
|
||||
Returns:
|
||||
Dictionary with polling statistics
|
||||
"""
|
||||
elapsed = time.time() - start_time
|
||||
remaining = max(0, timeout - elapsed)
|
||||
|
||||
return {
|
||||
"elapsed_seconds": int(elapsed),
|
||||
"remaining_seconds": int(remaining),
|
||||
"attempts": attempts,
|
||||
"average_interval": elapsed / attempts if attempts > 0 else 0,
|
||||
"progress_percentage": min(100, (elapsed / timeout) * 100),
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from exa_py import AsyncExa
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -9,7 +11,6 @@ from backend.sdk import (
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
@@ -18,8 +19,6 @@ from .helpers import (
|
||||
ContentSettings,
|
||||
CostDollars,
|
||||
ExaSearchResults,
|
||||
add_optional_fields,
|
||||
format_date_fields,
|
||||
process_contents_settings,
|
||||
)
|
||||
|
||||
@@ -106,11 +105,9 @@ class ExaSearchBlock(Block):
|
||||
results: list[ExaSearchResults] = SchemaField(
|
||||
description="List of search results"
|
||||
)
|
||||
result: ExaSearchResults = SchemaField(
|
||||
description="Single search result",
|
||||
)
|
||||
result: ExaSearchResults = SchemaField(description="Single search result")
|
||||
context: str = SchemaField(
|
||||
description="A formatted string of the search results ready for LLMs.",
|
||||
description="A formatted string of the search results ready for LLMs."
|
||||
)
|
||||
search_type: str = SchemaField(
|
||||
description="For auto searches, indicates which search type was selected."
|
||||
@@ -120,11 +117,9 @@ class ExaSearchBlock(Block):
|
||||
description="The search type that was actually used for this request (neural or keyword)"
|
||||
)
|
||||
cost_dollars: Optional[CostDollars] = SchemaField(
|
||||
description="Cost breakdown for the request", default=None
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the request failed",
|
||||
description="Cost breakdown for the request"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -138,82 +133,88 @@ class ExaSearchBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
url = "https://api.exa.ai/search"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
|
||||
payload = {
|
||||
# Build kwargs for SDK call
|
||||
sdk_kwargs = {
|
||||
"query": input_data.query,
|
||||
"numResults": input_data.number_of_results,
|
||||
"num_results": input_data.number_of_results,
|
||||
}
|
||||
|
||||
# Handle contents field with helper function
|
||||
content_settings = process_contents_settings(input_data.contents)
|
||||
if content_settings:
|
||||
payload["contents"] = content_settings
|
||||
# Handle type field
|
||||
if input_data.type:
|
||||
sdk_kwargs["type"] = input_data.type.value
|
||||
|
||||
# Handle date fields with helper function
|
||||
date_field_mapping = {
|
||||
"start_crawl_date": "startCrawlDate",
|
||||
"end_crawl_date": "endCrawlDate",
|
||||
"start_published_date": "startPublishedDate",
|
||||
"end_published_date": "endPublishedDate",
|
||||
}
|
||||
payload.update(format_date_fields(input_data, date_field_mapping))
|
||||
# Handle category field
|
||||
if input_data.category:
|
||||
sdk_kwargs["category"] = input_data.category.value
|
||||
|
||||
# Handle enum fields separately since they need special processing
|
||||
for field_name, api_field in [("type", "type"), ("category", "category")]:
|
||||
value = getattr(input_data, field_name, None)
|
||||
if value:
|
||||
payload[api_field] = value.value if hasattr(value, "value") else value
|
||||
# Handle user_location
|
||||
if input_data.user_location:
|
||||
sdk_kwargs["user_location"] = input_data.user_location
|
||||
|
||||
# Handle other optional fields
|
||||
optional_field_mapping = {
|
||||
"user_location": "userLocation",
|
||||
"include_domains": "includeDomains",
|
||||
"exclude_domains": "excludeDomains",
|
||||
"include_text": "includeText",
|
||||
"exclude_text": "excludeText",
|
||||
}
|
||||
add_optional_fields(input_data, optional_field_mapping, payload)
|
||||
# Handle domains
|
||||
if input_data.include_domains:
|
||||
sdk_kwargs["include_domains"] = input_data.include_domains
|
||||
if input_data.exclude_domains:
|
||||
sdk_kwargs["exclude_domains"] = input_data.exclude_domains
|
||||
|
||||
# Add moderation field
|
||||
# Handle dates
|
||||
if input_data.start_crawl_date:
|
||||
sdk_kwargs["start_crawl_date"] = input_data.start_crawl_date.isoformat()
|
||||
if input_data.end_crawl_date:
|
||||
sdk_kwargs["end_crawl_date"] = input_data.end_crawl_date.isoformat()
|
||||
if input_data.start_published_date:
|
||||
sdk_kwargs["start_published_date"] = (
|
||||
input_data.start_published_date.isoformat()
|
||||
)
|
||||
if input_data.end_published_date:
|
||||
sdk_kwargs["end_published_date"] = input_data.end_published_date.isoformat()
|
||||
|
||||
# Handle text filters
|
||||
if input_data.include_text:
|
||||
sdk_kwargs["include_text"] = input_data.include_text
|
||||
if input_data.exclude_text:
|
||||
sdk_kwargs["exclude_text"] = input_data.exclude_text
|
||||
|
||||
# Handle moderation
|
||||
if input_data.moderation:
|
||||
payload["moderation"] = input_data.moderation
|
||||
sdk_kwargs["moderation"] = input_data.moderation
|
||||
|
||||
# Always enable context for LLM-ready output
|
||||
payload["context"] = True
|
||||
# Handle contents - check if we need to use search_and_contents
|
||||
content_settings = process_contents_settings(input_data.contents)
|
||||
|
||||
try:
|
||||
response = await Requests().post(url, headers=headers, json=payload)
|
||||
data = response.json()
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
# Extract all response fields
|
||||
yield "results", data.get("results", [])
|
||||
for result in data.get("results", []):
|
||||
yield "result", result
|
||||
if content_settings:
|
||||
# Use search_and_contents when contents are requested
|
||||
sdk_kwargs["text"] = content_settings.get("text", False)
|
||||
if "highlights" in content_settings:
|
||||
sdk_kwargs["highlights"] = content_settings["highlights"]
|
||||
if "summary" in content_settings:
|
||||
sdk_kwargs["summary"] = content_settings["summary"]
|
||||
response = await aexa.search_and_contents(**sdk_kwargs)
|
||||
else:
|
||||
# Use regular search when no contents requested
|
||||
response = await aexa.search(**sdk_kwargs)
|
||||
|
||||
# Yield context if present
|
||||
if "context" in data:
|
||||
yield "context", data["context"]
|
||||
# SearchResponse is a dataclass, convert results to our Pydantic models
|
||||
converted_results = [
|
||||
ExaSearchResults.from_sdk(sdk_result)
|
||||
for sdk_result in response.results or []
|
||||
]
|
||||
|
||||
# Yield search type if present
|
||||
if "searchType" in data:
|
||||
yield "search_type", data["searchType"]
|
||||
yield "results", converted_results
|
||||
for result in converted_results:
|
||||
yield "result", result
|
||||
|
||||
# Yield request ID if present
|
||||
if "requestId" in data:
|
||||
yield "request_id", data["requestId"]
|
||||
# Yield context if present
|
||||
if response.context:
|
||||
yield "context", response.context
|
||||
|
||||
# Yield resolved search type if present
|
||||
if "resolvedSearchType" in data:
|
||||
yield "resolved_search_type", data["resolvedSearchType"]
|
||||
# Yield resolved search type if present
|
||||
if response.resolved_search_type:
|
||||
yield "resolved_search_type", response.resolved_search_type
|
||||
|
||||
# Yield cost information if present
|
||||
if "costDollars" in data:
|
||||
yield "cost_dollars", data["costDollars"]
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
# Yield cost information if present
|
||||
if response.cost_dollars:
|
||||
yield "cost_dollars", response.cost_dollars
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from exa_py import AsyncExa
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -8,7 +10,6 @@ from backend.sdk import (
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
@@ -17,8 +18,6 @@ from .helpers import (
|
||||
ContentSettings,
|
||||
CostDollars,
|
||||
ExaSearchResults,
|
||||
add_optional_fields,
|
||||
format_date_fields,
|
||||
process_contents_settings,
|
||||
)
|
||||
|
||||
@@ -82,18 +81,16 @@ class ExaFindSimilarBlock(Block):
|
||||
description="List of similar documents with metadata and content"
|
||||
)
|
||||
result: ExaSearchResults = SchemaField(
|
||||
description="Single similar document result",
|
||||
description="Single similar document result"
|
||||
)
|
||||
context: str = SchemaField(
|
||||
description="A formatted string of the results ready for LLMs.",
|
||||
description="A formatted string of the results ready for LLMs."
|
||||
)
|
||||
request_id: str = SchemaField(description="Unique identifier for the request")
|
||||
cost_dollars: Optional[CostDollars] = SchemaField(
|
||||
description="Cost breakdown for the request", default=None
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the request failed",
|
||||
description="Cost breakdown for the request"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -107,67 +104,72 @@ class ExaFindSimilarBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
url = "https://api.exa.ai/findSimilar"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
|
||||
payload = {
|
||||
# Build kwargs for SDK call
|
||||
sdk_kwargs = {
|
||||
"url": input_data.url,
|
||||
"numResults": input_data.number_of_results,
|
||||
"num_results": input_data.number_of_results,
|
||||
}
|
||||
|
||||
# Handle contents field with helper function
|
||||
content_settings = process_contents_settings(input_data.contents)
|
||||
if content_settings:
|
||||
payload["contents"] = content_settings
|
||||
# Handle domains
|
||||
if input_data.include_domains:
|
||||
sdk_kwargs["include_domains"] = input_data.include_domains
|
||||
if input_data.exclude_domains:
|
||||
sdk_kwargs["exclude_domains"] = input_data.exclude_domains
|
||||
|
||||
# Handle date fields with helper function
|
||||
date_field_mapping = {
|
||||
"start_crawl_date": "startCrawlDate",
|
||||
"end_crawl_date": "endCrawlDate",
|
||||
"start_published_date": "startPublishedDate",
|
||||
"end_published_date": "endPublishedDate",
|
||||
}
|
||||
payload.update(format_date_fields(input_data, date_field_mapping))
|
||||
# Handle dates
|
||||
if input_data.start_crawl_date:
|
||||
sdk_kwargs["start_crawl_date"] = input_data.start_crawl_date.isoformat()
|
||||
if input_data.end_crawl_date:
|
||||
sdk_kwargs["end_crawl_date"] = input_data.end_crawl_date.isoformat()
|
||||
if input_data.start_published_date:
|
||||
sdk_kwargs["start_published_date"] = (
|
||||
input_data.start_published_date.isoformat()
|
||||
)
|
||||
if input_data.end_published_date:
|
||||
sdk_kwargs["end_published_date"] = input_data.end_published_date.isoformat()
|
||||
|
||||
# Handle other optional fields
|
||||
optional_field_mapping = {
|
||||
"include_domains": "includeDomains",
|
||||
"exclude_domains": "excludeDomains",
|
||||
"include_text": "includeText",
|
||||
"exclude_text": "excludeText",
|
||||
}
|
||||
add_optional_fields(input_data, optional_field_mapping, payload)
|
||||
# Handle text filters
|
||||
if input_data.include_text:
|
||||
sdk_kwargs["include_text"] = input_data.include_text
|
||||
if input_data.exclude_text:
|
||||
sdk_kwargs["exclude_text"] = input_data.exclude_text
|
||||
|
||||
# Add moderation field
|
||||
# Handle moderation
|
||||
if input_data.moderation:
|
||||
payload["moderation"] = input_data.moderation
|
||||
sdk_kwargs["moderation"] = input_data.moderation
|
||||
|
||||
# Always enable context for LLM-ready output
|
||||
payload["context"] = True
|
||||
# Handle contents - check if we need to use find_similar_and_contents
|
||||
content_settings = process_contents_settings(input_data.contents)
|
||||
|
||||
try:
|
||||
response = await Requests().post(url, headers=headers, json=payload)
|
||||
data = response.json()
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
# Extract all response fields
|
||||
yield "results", data.get("results", [])
|
||||
for result in data.get("results", []):
|
||||
yield "result", result
|
||||
if content_settings:
|
||||
# Use find_similar_and_contents when contents are requested
|
||||
sdk_kwargs["text"] = content_settings.get("text", False)
|
||||
if "highlights" in content_settings:
|
||||
sdk_kwargs["highlights"] = content_settings["highlights"]
|
||||
if "summary" in content_settings:
|
||||
sdk_kwargs["summary"] = content_settings["summary"]
|
||||
response = await aexa.find_similar_and_contents(**sdk_kwargs)
|
||||
else:
|
||||
# Use regular find_similar when no contents requested
|
||||
response = await aexa.find_similar(**sdk_kwargs)
|
||||
|
||||
# Yield context if present
|
||||
if "context" in data:
|
||||
yield "context", data["context"]
|
||||
# SearchResponse is a dataclass, convert results to our Pydantic models
|
||||
converted_results = [
|
||||
ExaSearchResults.from_sdk(sdk_result)
|
||||
for sdk_result in response.results or []
|
||||
]
|
||||
|
||||
# Yield request ID if present
|
||||
if "requestId" in data:
|
||||
yield "request_id", data["requestId"]
|
||||
yield "results", converted_results
|
||||
for result in converted_results:
|
||||
yield "result", result
|
||||
|
||||
# Yield cost information if present
|
||||
if "costDollars" in data:
|
||||
yield "cost_dollars", data["costDollars"]
|
||||
# Yield context if present
|
||||
if response.context:
|
||||
yield "context", response.context
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
# Yield cost information if present
|
||||
if response.cost_dollars:
|
||||
yield "cost_dollars", response.cost_dollars
|
||||
|
||||
@@ -131,45 +131,33 @@ class ExaWebsetWebhookBlock(Block):
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
"""Process incoming Exa webhook payload."""
|
||||
try:
|
||||
payload = input_data.payload
|
||||
payload = input_data.payload
|
||||
|
||||
# Extract event details
|
||||
event_type = payload.get("eventType", "unknown")
|
||||
event_id = payload.get("eventId", "")
|
||||
# Extract event details
|
||||
event_type = payload.get("eventType", "unknown")
|
||||
event_id = payload.get("eventId", "")
|
||||
|
||||
# Get webset ID from payload or input
|
||||
webset_id = payload.get("websetId", input_data.webset_id)
|
||||
# Get webset ID from payload or input
|
||||
webset_id = payload.get("websetId", input_data.webset_id)
|
||||
|
||||
# Check if we should process this event based on filter
|
||||
should_process = self._should_process_event(
|
||||
event_type, input_data.event_filter
|
||||
)
|
||||
# Check if we should process this event based on filter
|
||||
should_process = self._should_process_event(event_type, input_data.event_filter)
|
||||
|
||||
if not should_process:
|
||||
# Skip events that don't match our filter
|
||||
return
|
||||
if not should_process:
|
||||
# Skip events that don't match our filter
|
||||
return
|
||||
|
||||
# Extract event data
|
||||
event_data = payload.get("data", {})
|
||||
timestamp = payload.get("occurredAt", payload.get("createdAt", ""))
|
||||
metadata = payload.get("metadata", {})
|
||||
# Extract event data
|
||||
event_data = payload.get("data", {})
|
||||
timestamp = payload.get("occurredAt", payload.get("createdAt", ""))
|
||||
metadata = payload.get("metadata", {})
|
||||
|
||||
yield "event_type", event_type
|
||||
yield "event_id", event_id
|
||||
yield "webset_id", webset_id
|
||||
yield "data", event_data
|
||||
yield "timestamp", timestamp
|
||||
yield "metadata", metadata
|
||||
|
||||
except Exception as e:
|
||||
# Handle errors gracefully
|
||||
yield "event_type", "error"
|
||||
yield "event_id", ""
|
||||
yield "webset_id", input_data.webset_id
|
||||
yield "data", {"error": str(e)}
|
||||
yield "timestamp", ""
|
||||
yield "metadata", {}
|
||||
yield "event_type", event_type
|
||||
yield "event_id", event_id
|
||||
yield "webset_id", webset_id
|
||||
yield "data", event_data
|
||||
yield "timestamp", timestamp
|
||||
yield "metadata", metadata
|
||||
|
||||
def _should_process_event(
|
||||
self, event_type: str, event_filter: WebsetEventFilter
|
||||
|
||||
@@ -702,9 +702,6 @@ class ExaGetWebsetBlock(Block):
|
||||
description="The enrichments applied to the webset"
|
||||
)
|
||||
monitors: list[dict] = SchemaField(description="The monitors for the webset")
|
||||
items: Optional[list[dict]] = SchemaField(
|
||||
description="The items in the webset (if expand_items is true)"
|
||||
)
|
||||
metadata: dict = SchemaField(
|
||||
description="Key-value pairs associated with the webset"
|
||||
)
|
||||
@@ -761,7 +758,6 @@ class ExaGetWebsetBlock(Block):
|
||||
yield "searches", searches_data
|
||||
yield "enrichments", enrichments_data
|
||||
yield "monitors", monitors_data
|
||||
yield "items", None # SDK doesn't expand items by default
|
||||
yield "metadata", sdk_webset.metadata or {}
|
||||
yield "created_at", (
|
||||
sdk_webset.created_at.isoformat() if sdk_webset.created_at else ""
|
||||
@@ -877,6 +873,92 @@ class ExaCancelWebsetBlock(Block):
|
||||
yield "success", "true"
|
||||
|
||||
|
||||
# Mirrored models for Preview response stability
|
||||
class PreviewCriterionModel(BaseModel):
|
||||
"""Stable model for preview criteria."""
|
||||
|
||||
description: str
|
||||
|
||||
@classmethod
|
||||
def from_sdk(cls, sdk_criterion) -> "PreviewCriterionModel":
|
||||
"""Convert SDK criterion to our model."""
|
||||
return cls(description=sdk_criterion.description)
|
||||
|
||||
|
||||
class PreviewEnrichmentModel(BaseModel):
|
||||
"""Stable model for preview enrichment."""
|
||||
|
||||
description: str
|
||||
format: str
|
||||
options: List[str]
|
||||
|
||||
@classmethod
|
||||
def from_sdk(cls, sdk_enrichment) -> "PreviewEnrichmentModel":
|
||||
"""Convert SDK enrichment to our model."""
|
||||
format_str = (
|
||||
sdk_enrichment.format.value
|
||||
if hasattr(sdk_enrichment.format, "value")
|
||||
else str(sdk_enrichment.format)
|
||||
)
|
||||
|
||||
options_list = []
|
||||
if sdk_enrichment.options:
|
||||
for opt in sdk_enrichment.options:
|
||||
opt_dict = opt.model_dump(by_alias=True)
|
||||
options_list.append(opt_dict.get("label", ""))
|
||||
|
||||
return cls(
|
||||
description=sdk_enrichment.description,
|
||||
format=format_str,
|
||||
options=options_list,
|
||||
)
|
||||
|
||||
|
||||
class PreviewSearchModel(BaseModel):
|
||||
"""Stable model for preview search details."""
|
||||
|
||||
entity_type: str
|
||||
entity_description: Optional[str]
|
||||
criteria: List[PreviewCriterionModel]
|
||||
|
||||
@classmethod
|
||||
def from_sdk(cls, sdk_search) -> "PreviewSearchModel":
|
||||
"""Convert SDK search preview to our model."""
|
||||
# Extract entity type from union
|
||||
entity_dict = sdk_search.entity.model_dump(by_alias=True)
|
||||
entity_type = entity_dict.get("type", "auto")
|
||||
entity_description = entity_dict.get("description")
|
||||
|
||||
# Convert criteria
|
||||
criteria = [
|
||||
PreviewCriterionModel.from_sdk(c) for c in sdk_search.criteria or []
|
||||
]
|
||||
|
||||
return cls(
|
||||
entity_type=entity_type,
|
||||
entity_description=entity_description,
|
||||
criteria=criteria,
|
||||
)
|
||||
|
||||
|
||||
class PreviewWebsetModel(BaseModel):
|
||||
"""Stable model for preview response."""
|
||||
|
||||
search: PreviewSearchModel
|
||||
enrichments: List[PreviewEnrichmentModel]
|
||||
|
||||
@classmethod
|
||||
def from_sdk(cls, sdk_preview) -> "PreviewWebsetModel":
|
||||
"""Convert SDK PreviewWebsetResponse to our model."""
|
||||
|
||||
search = PreviewSearchModel.from_sdk(sdk_preview.search)
|
||||
enrichments = [
|
||||
PreviewEnrichmentModel.from_sdk(e) for e in sdk_preview.enrichments or []
|
||||
]
|
||||
|
||||
return cls(search=search, enrichments=enrichments)
|
||||
|
||||
|
||||
class ExaPreviewWebsetBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = exa.credentials_field(
@@ -898,16 +980,19 @@ class ExaPreviewWebsetBlock(Block):
        )

    class Output(BlockSchema):
        preview: PreviewWebsetModel = SchemaField(
            description="Full preview response with search and enrichment details"
        )
        entity_type: str = SchemaField(
            description="The detected or specified entity type"
        )
        entity_description: Optional[str] = SchemaField(
            description="Description of the entity type"
        )
        criteria: list[dict] = SchemaField(
        criteria: list[PreviewCriterionModel] = SchemaField(
            description="Generated search criteria that will be used"
        )
        enrichment_columns: list[dict] = SchemaField(
        enrichment_columns: list[PreviewEnrichmentModel] = SchemaField(
            description="Available enrichment columns that can be extracted"
        )
        interpretation: str = SchemaField(
@@ -949,24 +1034,16 @@ class ExaPreviewWebsetBlock(Block):
            payload["entity"] = entity

        # Preview webset using SDK - no await needed
        preview = aexa.websets.preview(params=payload)
        sdk_preview = aexa.websets.preview(params=payload)

        # Extract entity information
        entity_dict = preview.entity.model_dump(by_alias=True, exclude_none=True)
        entity_type = entity_dict.get("type", "auto")
        entity_description = entity_dict.get("description")
        # Convert to our stable Pydantic model
        preview = PreviewWebsetModel.from_sdk(sdk_preview)

        # Extract criteria
        criteria = [
            c.model_dump(by_alias=True, exclude_none=True)
            for c in preview.criteria or []
        ]

        # Extract enrichment columns
        enrichments = [
            e.model_dump(by_alias=True, exclude_none=True)
            for e in preview.enrichment_columns or []
        ]
        # Extract details for individual fields (for easier graph connections)
        entity_type = preview.search.entity_type
        entity_description = preview.search.entity_description
        criteria = preview.search.criteria
        enrichments = preview.enrichments

        # Generate interpretation
        interpretation = f"Query will search for {entity_type}"
@@ -977,7 +1054,7 @@ class ExaPreviewWebsetBlock(Block):
        if enrichments:
            interpretation += f" and {len(enrichments)} available enrichment columns"

        # Generate suggestions (could be enhanced based on the response)
        # Generate suggestions
        suggestions = []
        if not criteria:
            suggestions.append(
@@ -988,6 +1065,10 @@ class ExaPreviewWebsetBlock(Block):
                "Consider specifying what data points you want to extract"
            )

        # Yield full model first
        yield "preview", preview

        # Then yield individual fields for graph flexibility
        yield "entity_type", entity_type
        yield "entity_description", entity_description
        yield "criteria", criteria
@@ -1073,6 +1154,42 @@ class ExaWebsetStatusBlock(Block):
        yield "is_processing", is_processing


# Summary models for ExaWebsetSummaryBlock
class SearchSummaryModel(BaseModel):
    """Summary of searches in a webset."""

    total_searches: int
    completed_searches: int
    total_items_found: int
    queries: List[str]


class EnrichmentSummaryModel(BaseModel):
    """Summary of enrichments in a webset."""

    total_enrichments: int
    completed_enrichments: int
    enrichment_types: List[str]
    titles: List[str]


class MonitorSummaryModel(BaseModel):
    """Summary of monitors in a webset."""

    total_monitors: int
    active_monitors: int
    next_run: Optional[datetime] = None


class WebsetStatisticsModel(BaseModel):
    """Various statistics about a webset."""

    total_operations: int
    is_processing: bool
    has_monitors: bool
    avg_items_per_search: float

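# Minimal usage sketch (illustrative only, values are hypothetical): the summary
# models above are plain Pydantic models, so callers can build and serialize
# them directly, e.g.:
#
#     stats = WebsetStatisticsModel(
#         total_operations=3,
#         is_processing=False,
#         has_monitors=True,
#         avg_items_per_search=12.5,
#     )
#     stats.model_dump()  # {"total_operations": 3, "is_processing": False, ...}
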
class ExaWebsetSummaryBlock(Block):
    """Get a comprehensive summary of a webset including samples and statistics."""

@@ -1105,23 +1222,22 @@ class ExaWebsetSummaryBlock(Block):

    class Output(BlockSchema):
        webset_id: str = SchemaField(description="The webset identifier")
        title: Optional[str] = SchemaField(
            description="Title of the webset if available"
        )
        status: str = SchemaField(description="Current status")
        entity_type: str = SchemaField(description="Type of entities in the webset")
        total_items: int = SchemaField(description="Total number of items")
        sample_items: list[dict] = SchemaField(
        sample_items: list[Dict[str, Any]] = SchemaField(
            description="Sample items from the webset"
        )
        search_summary: dict = SchemaField(description="Summary of searches performed")
        enrichment_summary: dict = SchemaField(
        search_summary: SearchSummaryModel = SchemaField(
            description="Summary of searches performed"
        )
        enrichment_summary: EnrichmentSummaryModel = SchemaField(
            description="Summary of enrichments applied"
        )
        monitor_summary: dict = SchemaField(
        monitor_summary: MonitorSummaryModel = SchemaField(
            description="Summary of monitors configured"
        )
        statistics: dict = SchemaField(
        statistics: WebsetStatisticsModel = SchemaField(
            description="Various statistics about the webset"
        )
        created_at: str = SchemaField(description="When the webset was created")
@@ -1159,10 +1275,11 @@ class ExaWebsetSummaryBlock(Block):
        searches = webset.searches or []
        if searches:
            first_search = searches[0]
            entity_dict = first_search.entity.model_dump(
                by_alias=True, exclude_none=True
            )
            entity_type = entity_dict.get("type", "unknown")
            if first_search.entity:
                entity_dict = first_search.entity.model_dump(
                    by_alias=True, exclude_none=True
                )
                entity_type = entity_dict.get("type", "unknown")

        # Get sample items if requested
        sample_items_data = []
@@ -1178,81 +1295,88 @@ class ExaWebsetSummaryBlock(Block):
            ]
            total_items = len(sample_items_data)

        # Build search summary
        search_summary = {}
        # Build search summary using Pydantic model
        search_summary = SearchSummaryModel(
            total_searches=0,
            completed_searches=0,
            total_items_found=0,
            queries=[],
        )
        if input_data.include_search_details and searches:
            search_summary = {
                "total_searches": len(searches),
                "completed_searches": sum(
            search_summary = SearchSummaryModel(
                total_searches=len(searches),
                completed_searches=sum(
                    1
                    for s in searches
                    if (s.status.value if hasattr(s.status, "value") else str(s.status))
                    == "completed"
                ),
                "total_items_found": sum(
                    s.progress.found if s.progress else 0 for s in searches
                total_items_found=int(
                    sum(s.progress.found if s.progress else 0 for s in searches)
                ),
                "queries": [s.query for s in searches[:3]],  # First 3 queries
            }
                queries=[s.query for s in searches[:3]],  # First 3 queries
            )

        # Build enrichment summary
        enrichment_summary = {}
        # Build enrichment summary using Pydantic model
        enrichment_summary = EnrichmentSummaryModel(
            total_enrichments=0,
            completed_enrichments=0,
            enrichment_types=[],
            titles=[],
        )
        enrichments = webset.enrichments or []
        if input_data.include_enrichment_details and enrichments:
            enrichment_summary = {
                "total_enrichments": len(enrichments),
                "completed_enrichments": sum(
            enrichment_summary = EnrichmentSummaryModel(
                total_enrichments=len(enrichments),
                completed_enrichments=sum(
                    1
                    for e in enrichments
                    if (e.status.value if hasattr(e.status, "value") else str(e.status))
                    == "completed"
                ),
                "enrichment_types": list(
                enrichment_types=list(
                    set(
                        (
                            e.format.value
                            if hasattr(e.format, "value")
                            else str(e.format)
                            if e.format and hasattr(e.format, "value")
                            else str(e.format) if e.format else "text"
                        )
                        for e in enrichments
                    )
                ),
                "titles": [
                    (e.title or e.description or "")[:50] for e in enrichments[:3]
                ],
            }
                titles=[(e.title or e.description or "")[:50] for e in enrichments[:3]],
            )

        # Build monitor summary
        # Build monitor summary using Pydantic model
        monitors = webset.monitors or []
        monitor_summary = {
            "total_monitors": len(monitors),
            "active_monitors": sum(
        next_run_dt = None
        if monitors:
            next_runs = [m.next_run_at for m in monitors if m.next_run_at]
            if next_runs:
                next_run_dt = min(next_runs)

        monitor_summary = MonitorSummaryModel(
            total_monitors=len(monitors),
            active_monitors=sum(
                1
                for m in monitors
                if (m.status.value if hasattr(m.status, "value") else str(m.status))
                == "enabled"
            ),
        }
            next_run=next_run_dt,
        )

        if monitors:
            next_runs = [m.next_run_at.isoformat() for m in monitors if m.next_run_at]
            if next_runs:
                monitor_summary["next_run"] = min(next_runs)

        # Build statistics
        statistics = {
            "total_operations": len(searches) + len(enrichments),
            "is_processing": status in ["running", "pending"],
            "has_monitors": len(monitors) > 0,
            "avg_items_per_search": (
                search_summary.get("total_items_found", 0) / len(searches)
                if searches
                else 0
        # Build statistics using Pydantic model
        statistics = WebsetStatisticsModel(
            total_operations=len(searches) + len(enrichments),
            is_processing=status in ["running", "pending"],
            has_monitors=len(monitors) > 0,
            avg_items_per_search=(
                search_summary.total_items_found / len(searches) if searches else 0
            ),
        }
        )

        yield "webset_id", webset_id
        yield "title", None  # SDK doesn't have title field
        yield "status", status
        yield "entity_type", entity_type
        yield "total_items", total_items

@@ -9,9 +9,10 @@ import csv
import json
from enum import Enum
from io import StringIO
from typing import Optional
from typing import Optional, Union

from exa_py import AsyncExa
from exa_py.websets.types import CreateImportResponse
from exa_py.websets.types import Import as SdkImport
from pydantic import BaseModel

@@ -38,7 +39,8 @@ class ImportModel(BaseModel):
    format: str
    entity_type: str
    count: int
    size: int
    upload_url: Optional[str]  # Only in CreateImportResponse
    upload_valid_until: Optional[str]  # Only in CreateImportResponse
    failed_reason: str
    failed_message: str
    metadata: dict
@@ -46,11 +48,15 @@ class ImportModel(BaseModel):
    updated_at: str

    @classmethod
    def from_sdk(cls, import_obj: SdkImport) -> "ImportModel":
        """Convert SDK Import to our stable model."""
        # Extract entity type from union
        entity_dict = import_obj.entity.model_dump(by_alias=True, exclude_none=True)
        entity_type = entity_dict.get("type", "unknown")
    def from_sdk(
        cls, import_obj: Union[SdkImport, CreateImportResponse]
    ) -> "ImportModel":
        """Convert SDK Import or CreateImportResponse to our stable model."""
        # Extract entity type from union (may be None)
        entity_type = "unknown"
        if import_obj.entity:
            entity_dict = import_obj.entity.model_dump(by_alias=True, exclude_none=True)
            entity_type = entity_dict.get("type", "unknown")

        # Handle status enum
        status_str = (
@@ -66,15 +72,29 @@ class ImportModel(BaseModel):
            else str(import_obj.format)
        )

        # Handle failed_reason enum (may be None or enum)
        failed_reason_str = ""
        if import_obj.failed_reason:
            failed_reason_str = (
                import_obj.failed_reason.value
                if hasattr(import_obj.failed_reason, "value")
                else str(import_obj.failed_reason)
            )

        return cls(
            id=import_obj.id,
            status=status_str,
            title=import_obj.title or "",
            format=format_str,
            entity_type=entity_type,
            count=import_obj.count or 0,
            size=import_obj.size or 0,
            failed_reason=import_obj.failed_reason or "",
            count=int(import_obj.count or 0),
            upload_url=getattr(
                import_obj, "upload_url", None
            ),  # Only in CreateImportResponse
            upload_valid_until=getattr(
                import_obj, "upload_valid_until", None
            ),  # Only in CreateImportResponse
            failed_reason=failed_reason_str,
            failed_message=import_obj.failed_message or "",
            metadata=import_obj.metadata or {},
            created_at=(
@@ -160,6 +180,12 @@ class ExaCreateImportBlock(Block):
        title: str = SchemaField(description="Title of the import")
        count: int = SchemaField(description="Number of items in the import")
        entity_type: str = SchemaField(description="Type of entities imported")
        upload_url: Optional[str] = SchemaField(
            description="Upload URL for CSV data (only if csv_data not provided in request)"
        )
        upload_valid_until: Optional[str] = SchemaField(
            description="Expiration time for upload URL (only if upload_url is provided)"
        )
        created_at: str = SchemaField(description="When the import was created")
        error: str = SchemaField(description="Error message if the import failed")

@@ -243,6 +269,8 @@ class ExaCreateImportBlock(Block):
        yield "title", import_obj.title
        yield "count", import_obj.count
        yield "entity_type", import_obj.entity_type
        yield "upload_url", import_obj.upload_url
        yield "upload_valid_until", import_obj.upload_valid_until
        yield "created_at", import_obj.created_at


@@ -265,6 +293,12 @@ class ExaGetImportBlock(Block):
        format: str = SchemaField(description="Format of the imported data")
        entity_type: str = SchemaField(description="Type of entities imported")
        count: int = SchemaField(description="Number of items imported")
        upload_url: Optional[str] = SchemaField(
            description="Upload URL for CSV data (if import not yet uploaded)"
        )
        upload_valid_until: Optional[str] = SchemaField(
            description="Expiration time for upload URL (if applicable)"
        )
        failed_reason: Optional[str] = SchemaField(
            description="Reason for failure (if applicable)"
        )
@@ -304,6 +338,8 @@ class ExaGetImportBlock(Block):
        yield "format", import_obj.format
        yield "entity_type", import_obj.entity_type
        yield "count", import_obj.count
        yield "upload_url", import_obj.upload_url
        yield "upload_valid_until", import_obj.upload_valid_until
        yield "failed_reason", import_obj.failed_reason
        yield "failed_message", import_obj.failed_message
        yield "created_at", import_obj.created_at
@@ -442,10 +478,10 @@ class ExaExportWebsetBlock(Block):
            description="Include enrichment data in export",
        )
        max_items: int = SchemaField(
            default=1000,
            default=100,
            description="Maximum number of items to export",
            ge=1,
            le=10000,
            le=100,
        )

    class Output(BlockSchema):

@@ -9,7 +9,14 @@ from typing import Any, Dict, List, Optional

from exa_py import AsyncExa
from exa_py.websets.types import WebsetItem as SdkWebsetItem
from pydantic import BaseModel
from exa_py.websets.types import (
    WebsetItemArticleProperties,
    WebsetItemCompanyProperties,
    WebsetItemCustomProperties,
    WebsetItemPersonProperties,
    WebsetItemResearchPaperProperties,
)
from pydantic import AnyUrl, BaseModel

from backend.sdk import (
    APIKeyCredentials,
@@ -24,16 +31,50 @@ from backend.sdk import (
from ._config import exa


# Mirrored model for enrichment results
class EnrichmentResultModel(BaseModel):
    """Stable output model mirroring SDK EnrichmentResult."""

    enrichment_id: str
    format: str
    result: Optional[List[str]]
    reasoning: Optional[str]
    references: List[Dict[str, Any]]

    @classmethod
    def from_sdk(cls, sdk_enrich) -> "EnrichmentResultModel":
        """Convert SDK EnrichmentResult to our model."""
        format_str = (
            sdk_enrich.format.value
            if hasattr(sdk_enrich.format, "value")
            else str(sdk_enrich.format)
        )

        # Convert references to dicts
        references_list = []
        if sdk_enrich.references:
            for ref in sdk_enrich.references:
                references_list.append(ref.model_dump(by_alias=True, exclude_none=True))

        return cls(
            enrichment_id=sdk_enrich.enrichment_id,
            format=format_str,
            result=sdk_enrich.result,
            reasoning=sdk_enrich.reasoning,
            references=references_list,
        )

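# Minimal usage sketch (illustrative only, not part of this change): assuming
# `sdk_item` is an SDK WebsetItem whose `enrichments` is a list of SDK
# EnrichmentResult objects, they can be keyed by enrichment_id like so:
#
#     enrichments = {
#         e.enrichment_id: EnrichmentResultModel.from_sdk(e)
#         for e in (sdk_item.enrichments or [])
#     }
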
# Mirrored model for stability - don't use SDK types directly in block outputs
class WebsetItemModel(BaseModel):
    """Stable output model mirroring SDK WebsetItem."""

    id: str
    url: str
    url: Optional[AnyUrl]
    title: str
    content: str
    entity_data: Dict[str, Any]
    enrichments: Dict[str, Any]
    enrichments: Dict[str, EnrichmentResultModel]  # Changed from Dict[str, Any]
    verification_status: str
    created_at: str
    updated_at: str
@@ -43,23 +84,51 @@ class WebsetItemModel(BaseModel):
        """Convert SDK WebsetItem to our stable model."""
        # Extract properties from the union type
        properties_dict = {}
        url_value = None
        title = ""
        content = ""

        if hasattr(item, "properties") and item.properties:
            properties_dict = item.properties.model_dump(
                by_alias=True, exclude_none=True
            )

        # Convert enrichments from list to dict keyed by enrichment_id
        enrichments_dict = {}
            # URL is always available on all property types
            url_value = item.properties.url

            # Extract title using isinstance checks on the union type
            if isinstance(item.properties, WebsetItemPersonProperties):
                title = item.properties.person.name
                content = ""  # Person type has no content
            elif isinstance(item.properties, WebsetItemCompanyProperties):
                title = item.properties.company.name
                content = item.properties.content or ""
            elif isinstance(item.properties, WebsetItemArticleProperties):
                title = item.properties.description
                content = item.properties.content or ""
            elif isinstance(item.properties, WebsetItemResearchPaperProperties):
                title = item.properties.description
                content = item.properties.content or ""
            elif isinstance(item.properties, WebsetItemCustomProperties):
                title = item.properties.description
                content = item.properties.content or ""
            else:
                # Fallback
                title = item.properties.description
                content = getattr(item.properties, "content", "")

        # Convert enrichments from list to dict keyed by enrichment_id using Pydantic models
        enrichments_dict: Dict[str, EnrichmentResultModel] = {}
        if hasattr(item, "enrichments") and item.enrichments:
            for enrich in item.enrichments:
                enrichment_data = enrich.model_dump(by_alias=True, exclude_none=True)
                enrichments_dict[enrich.enrichment_id] = enrichment_data
            for sdk_enrich in item.enrichments:
                enrich_model = EnrichmentResultModel.from_sdk(sdk_enrich)
                enrichments_dict[enrich_model.enrichment_id] = enrich_model

        return cls(
            id=item.id,
            url=properties_dict.get("url", ""),
            title=properties_dict.get("title", ""),
            content=properties_dict.get("content", ""),
            url=url_value,
            title=title,
            content=content or "",
            entity_data=properties_dict,
            enrichments=enrichments_dict,
            verification_status="",  # Not yet exposed in SDK
@@ -174,12 +243,12 @@ class ExaListWebsetItemsBlock(Block):
        items: list[WebsetItemModel] = SchemaField(
            description="List of webset items",
        )
        webset_id: str = SchemaField(
            description="The ID of the webset",
        )
        item: WebsetItemModel = SchemaField(
            description="Individual item (yielded for each item in the list)",
        )
        total_count: Optional[int] = SchemaField(
            description="Total number of items in the webset",
        )
        has_more: bool = SchemaField(
            description="Whether there are more items to paginate through",
        )
@@ -250,9 +319,9 @@ class ExaListWebsetItemsBlock(Block):
            yield "item", item

        # Yield pagination metadata
        yield "total_count", None  # SDK doesn't provide total in pagination
        yield "has_more", response.has_more
        yield "next_cursor", response.next_cursor
        yield "webset_id", input_data.webset_id


class ExaDeleteWebsetItemBlock(Block):
@@ -336,9 +405,6 @@ class ExaBulkWebsetItemsBlock(Block):
        total_retrieved: int = SchemaField(
            description="Total number of items retrieved"
        )
        total_in_webset: Optional[int] = SchemaField(
            description="Total number of items in the webset"
        )
        truncated: bool = SchemaField(
            description="Whether results were truncated due to max_items limit"
        )
@@ -388,7 +454,6 @@ class ExaBulkWebsetItemsBlock(Block):
            yield "item", item

        yield "total_retrieved", len(all_items)
        yield "total_in_webset", None  # SDK doesn't provide total count
        yield "truncated", len(all_items) >= input_data.max_items


@@ -8,8 +8,10 @@ to complete, with progress tracking and timeout management.
import asyncio
import time
from enum import Enum
from typing import Any, Dict

from exa_py import AsyncExa
from pydantic import BaseModel

from backend.sdk import (
    APIKeyCredentials,
@@ -23,6 +25,19 @@ from backend.sdk import (

from ._config import exa

# Import WebsetItemModel for use in enrichment samples
# This is safe as websets_items doesn't import from websets_polling
from .websets_items import WebsetItemModel


# Model for sample enrichment data
class SampleEnrichmentModel(BaseModel):
    """Sample enrichment result for display."""

    item_id: str
    item_title: str
    enrichment_data: Dict[str, Any]

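# Minimal usage sketch (illustrative only, values are hypothetical):
#
#     sample = SampleEnrichmentModel(
#         item_id="witem_123",
#         item_title="Example Corp",
#         enrichment_data={"format": "text", "result": ["San Francisco"]},
#     )
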
class WebsetTargetStatus(str, Enum):
    IDLE = "idle"
@@ -470,7 +485,7 @@ class ExaWaitForEnrichmentBlock(Block):
            description="Title/description of the enrichment"
        )
        elapsed_time: float = SchemaField(description="Total time elapsed in seconds")
        sample_data: list[dict] = SchemaField(
        sample_data: list[SampleEnrichmentModel] = SchemaField(
            description="Sample of enriched data (if requested)"
        )
        timed_out: bool = SchemaField(description="Whether the operation timed out")
@@ -567,40 +582,29 @@ class ExaWaitForEnrichmentBlock(Block):

    async def _get_sample_enrichments(
        self, webset_id: str, enrichment_id: str, aexa: AsyncExa
    ) -> tuple[list[dict], int]:
    ) -> tuple[list[SampleEnrichmentModel], int]:
        """Get sample enriched data and count."""
        # Get a few items to see enrichment results using SDK
        response = aexa.websets.items.list(webset_id=webset_id, limit=5)

        sample_data = []
        sample_data: list[SampleEnrichmentModel] = []
        enriched_count = 0

        for item in response.data:
            # Check if item has this enrichment
            if item.enrichments:
                for enrich in item.enrichments:
                    if enrich.enrichment_id == enrichment_id:
                        enriched_count += 1
                        enrich_dict = enrich.model_dump(
                            by_alias=True, exclude_none=True
                        )
                        sample_data.append(
                            {
                                "item_id": item.id,
                                "item_title": (
                                    item.properties.title
                                    if hasattr(item.properties, "title")
                                    else ""
                                ),
                                "enrichment_data": enrich_dict,
                            }
                        )
                        break
        for sdk_item in response.data:
            # Convert to our WebsetItemModel first
            item = WebsetItemModel.from_sdk(sdk_item)

        # Estimate total enriched count based on sample
        # Note: This is an estimate - would need to check all items for accurate count
        if enriched_count > 0 and len(response.data) > 0:
            # For now, just return the sample count as we don't have total item count easily
            pass
            # Check if this item has the enrichment we're looking for
            if enrichment_id in item.enrichments:
                enriched_count += 1
                enrich_model = item.enrichments[enrichment_id]

                # Create sample using our typed model
                sample = SampleEnrichmentModel(
                    item_id=item.id,
                    item_title=item.title,
                    enrichment_data=enrich_model.model_dump(exclude_none=True),
                )
                sample_data.append(sample)

        return sample_data, enriched_count