Refactor Exa blocks to use exa_py SDK and remove legacy API code

Replaces direct HTTP calls and custom API helpers with the official exa_py AsyncExa SDK across all Exa-related blocks. Removes the now-unnecessary _api.py module and legacy polling helpers, updates block outputs to use Pydantic models converted from SDK dataclasses, and introduces stable preview models for webset previews. This improves maintainability and reliability and future-proofs the integration against API changes.
This commit is contained in:
Nicholas Tindle
2025-10-29 22:21:59 -05:00
parent 1181f9990e
commit e321551492
12 changed files with 632 additions and 1019 deletions
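The recurring pattern across these blocks is: call the AsyncExa SDK, then mirror the returned dataclass into a block-owned Pydantic model via a from_sdk() classmethod so block outputs stay stable even if SDK internals change. A minimal sketch of that conversion pattern follows; the SdkResult stand-in and its fields are illustrative only, not the real exa_py types.

from dataclasses import dataclass
from typing import Optional

from pydantic import BaseModel


@dataclass
class SdkResult:
    # Stand-in for an exa_py result dataclass (hypothetical, for illustration).
    id: str
    url: Optional[str] = None
    title: Optional[str] = None
    published_date: Optional[str] = None


class ResultModel(BaseModel):
    """Stable output model owned by the block, decoupled from SDK internals."""

    id: str
    url: Optional[str] = None
    title: Optional[str] = None
    publishedDate: Optional[str] = None

    @classmethod
    def from_sdk(cls, sdk_result: SdkResult) -> "ResultModel":
        # getattr() with defaults keeps the conversion tolerant of missing fields.
        return cls(
            id=getattr(sdk_result, "id", ""),
            url=getattr(sdk_result, "url", None),
            title=getattr(sdk_result, "title", None),
            publishedDate=getattr(sdk_result, "published_date", None),
        )


print(ResultModel.from_sdk(SdkResult(id="doc_1", url="https://example.com")))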

.gitignore
View File

@@ -178,3 +178,4 @@ autogpt_platform/backend/settings.py
*.ign.*
.test-contents
.claude/settings.local.json
/autogpt_platform/backend/logs

View File

@@ -1,349 +0,0 @@
"""
Exa Websets API utilities.
This module provides common building blocks for interacting with the Exa API:
- URL construction (ExaApiUrls)
- Header building
- Item counting
- Pagination helpers
- Generic polling
"""
import asyncio
import time
from typing import Any, Callable, Iterator, Tuple
from backend.sdk import Requests
class ExaApiUrls:
"""Centralized URL builder for Exa Websets API endpoints."""
BASE = "https://api.exa.ai/websets/v0"
# ==================== Webset Endpoints ====================
@classmethod
def websets(cls) -> str:
"""List all websets endpoint."""
return f"{cls.BASE}/websets"
@classmethod
def webset(cls, webset_id: str) -> str:
"""Get/update/delete webset endpoint."""
return f"{cls.BASE}/websets/{webset_id}"
@classmethod
def webset_cancel(cls, webset_id: str) -> str:
"""Cancel webset endpoint."""
return f"{cls.BASE}/websets/{webset_id}/cancel"
@classmethod
def webset_preview(cls) -> str:
"""Preview webset query endpoint."""
return f"{cls.BASE}/websets/preview"
# ==================== Item Endpoints ====================
@classmethod
def webset_items(cls, webset_id: str) -> str:
"""List webset items endpoint."""
return f"{cls.BASE}/websets/{webset_id}/items"
@classmethod
def webset_item(cls, webset_id: str, item_id: str) -> str:
"""Get/delete specific item endpoint."""
return f"{cls.BASE}/websets/{webset_id}/items/{item_id}"
# ==================== Search Endpoints ====================
@classmethod
def webset_searches(cls, webset_id: str) -> str:
"""List/create searches endpoint."""
return f"{cls.BASE}/websets/{webset_id}/searches"
@classmethod
def webset_search(cls, webset_id: str, search_id: str) -> str:
"""Get specific search endpoint."""
return f"{cls.BASE}/websets/{webset_id}/searches/{search_id}"
@classmethod
def webset_search_cancel(cls, webset_id: str, search_id: str) -> str:
"""Cancel search endpoint."""
return f"{cls.BASE}/websets/{webset_id}/searches/{search_id}/cancel"
# ==================== Enrichment Endpoints ====================
@classmethod
def webset_enrichments(cls, webset_id: str) -> str:
"""List/create enrichments endpoint."""
return f"{cls.BASE}/websets/{webset_id}/enrichments"
@classmethod
def webset_enrichment(cls, webset_id: str, enrichment_id: str) -> str:
"""Get/update/delete enrichment endpoint."""
return f"{cls.BASE}/websets/{webset_id}/enrichments/{enrichment_id}"
@classmethod
def webset_enrichment_cancel(cls, webset_id: str, enrichment_id: str) -> str:
"""Cancel enrichment endpoint."""
return f"{cls.BASE}/websets/{webset_id}/enrichments/{enrichment_id}/cancel"
# ==================== Monitor Endpoints ====================
@classmethod
def monitors(cls) -> str:
"""List/create monitors endpoint."""
return f"{cls.BASE}/monitors"
@classmethod
def monitor(cls, monitor_id: str) -> str:
"""Get/update/delete monitor endpoint."""
return f"{cls.BASE}/monitors/{monitor_id}"
# ==================== Import Endpoints ====================
@classmethod
def imports(cls) -> str:
"""List/create imports endpoint."""
return f"{cls.BASE}/imports"
@classmethod
def import_(cls, import_id: str) -> str:
"""Get/delete import endpoint."""
return f"{cls.BASE}/imports/{import_id}"
def build_headers(api_key: str, include_content_type: bool = False) -> dict:
"""
Build standard Exa API headers.
Args:
api_key: The API key for authentication
include_content_type: Whether to include Content-Type: application/json header
Returns:
Dictionary of headers ready for API requests
Example:
>>> headers = build_headers("sk-123456")
>>> headers = build_headers("sk-123456", include_content_type=True)
"""
headers = {"x-api-key": api_key}
if include_content_type:
headers["Content-Type"] = "application/json"
return headers
async def get_item_count(webset_id: str, headers: dict) -> int:
"""
Get the total item count for a webset efficiently.
This makes a request with limit=1 and reads from pagination data
to avoid fetching all items.
Args:
webset_id: The webset ID
headers: Request headers with API key
Returns:
Total number of items in the webset
Example:
>>> count = await get_item_count("ws-123", headers)
"""
url = ExaApiUrls.webset_items(webset_id)
response = await Requests().get(url, headers=headers, params={"limit": 1})
data = response.json()
# Prefer pagination total if available
if "pagination" in data:
return data["pagination"].get("total", 0)
# Fallback to data length
return len(data.get("data", []))
def yield_paginated_results(
data: dict, list_key: str = "items", item_key: str = "item"
) -> Iterator[Tuple[str, Any]]:
"""
Yield paginated API results in standard format.
This helper yields both the full list and individual items for flexible
graph connections, plus pagination metadata.
Args:
data: API response data containing 'data', 'hasMore', 'nextCursor' fields
list_key: Output key name for the full list (default: "items")
item_key: Output key name for individual items (default: "item")
Yields:
Tuples of (key, value) for block output:
- (list_key, list): Full list of items
- (item_key, item): Each individual item (yielded separately)
- ("has_more", bool): Whether more results exist
- ("next_cursor", str|None): Cursor for next page
Example:
>>> for key, value in yield_paginated_results(response_data, "websets", "webset"):
>>> yield key, value
"""
items = data.get("data", [])
# Yield full list for batch processing
yield list_key, items
# Yield individual items for single-item processing chains
for item in items:
yield item_key, item
# Yield pagination metadata
yield "has_more", data.get("hasMore", False)
yield "next_cursor", data.get("nextCursor")
async def poll_until_complete(
url: str,
headers: dict,
is_complete: Callable[[dict], bool],
extract_result: Callable[[dict], Any],
timeout: int = 300,
initial_interval: float = 5.0,
max_interval: float = 30.0,
backoff_factor: float = 1.5,
) -> Any:
"""
Generic polling function with exponential backoff for async operations.
This function polls an API endpoint until a completion condition is met,
using exponential backoff to reduce API load.
Args:
url: API endpoint to poll
headers: Request headers with API key
is_complete: Function that takes response data and returns True when complete
extract_result: Function that extracts the result from response data
timeout: Maximum time to wait in seconds (default: 300)
initial_interval: Starting interval between polls in seconds (default: 5.0)
max_interval: Maximum interval between polls in seconds (default: 30.0)
backoff_factor: Factor to multiply interval by each iteration (default: 1.5)
Returns:
The result extracted by extract_result function when operation completes
Raises:
TimeoutError: If operation doesn't complete within timeout
Example:
>>> result = await poll_until_complete(
>>> url=ExaApiUrls.webset(webset_id),
>>> headers=build_headers(api_key),
>>> is_complete=lambda data: data.get("status") == "idle",
>>> extract_result=lambda data: data.get("itemsCount", 0),
>>> timeout=300
>>> )
"""
start_time = time.time()
interval = initial_interval
while time.time() - start_time < timeout:
response = await Requests().get(url, headers=headers)
data = response.json()
if is_complete(data):
return extract_result(data)
await asyncio.sleep(interval)
interval = min(interval * backoff_factor, max_interval)
# Timeout reached - raise error
raise TimeoutError(f"Operation did not complete within {timeout} seconds")
async def poll_webset_until_idle(
webset_id: str, headers: dict, timeout: int = 300
) -> int:
"""
Poll a webset until it reaches 'idle' status.
Convenience wrapper around poll_until_complete specifically for websets.
Args:
webset_id: The webset ID to poll
headers: Request headers with API key
timeout: Maximum time to wait in seconds
Returns:
The item count when webset becomes idle
Raises:
TimeoutError: If webset doesn't become idle within timeout
"""
return await poll_until_complete(
url=ExaApiUrls.webset(webset_id),
headers=headers,
is_complete=lambda data: data.get("status", {}).get("type") == "idle",
extract_result=lambda data: data.get("itemsCount", 0),
timeout=timeout,
)
async def poll_search_until_complete(
webset_id: str, search_id: str, headers: dict, timeout: int = 300
) -> int:
"""
Poll a search until it completes (completed/failed/cancelled).
Convenience wrapper around poll_until_complete specifically for searches.
Args:
webset_id: The webset ID
search_id: The search ID to poll
headers: Request headers with API key
timeout: Maximum time to wait in seconds
Returns:
The number of items found when search completes
Raises:
TimeoutError: If search doesn't complete within timeout
"""
return await poll_until_complete(
url=ExaApiUrls.webset_search(webset_id, search_id),
headers=headers,
is_complete=lambda data: data.get("status")
in ["completed", "failed", "cancelled"],
extract_result=lambda data: data.get("progress", {}).get("found", 0),
timeout=timeout,
)
async def poll_enrichment_until_complete(
webset_id: str, enrichment_id: str, headers: dict, timeout: int = 300
) -> int:
"""
Poll an enrichment until it completes (completed/failed/cancelled).
Convenience wrapper around poll_until_complete specifically for enrichments.
Args:
webset_id: The webset ID
enrichment_id: The enrichment ID to poll
headers: Request headers with API key
timeout: Maximum time to wait in seconds
Returns:
The number of items enriched when operation completes
Raises:
TimeoutError: If enrichment doesn't complete within timeout
"""
return await poll_until_complete(
url=ExaApiUrls.webset_enrichment(webset_id, enrichment_id),
headers=headers,
is_complete=lambda data: data.get("status")
in ["completed", "failed", "cancelled"],
extract_result=lambda data: data.get("progress", {}).get("processedItems", 0),
timeout=timeout,
)

View File

@@ -1,5 +1,7 @@
from typing import Optional
from exa_py import AsyncExa
from exa_py.api import AnswerResponse
from pydantic import BaseModel
from backend.sdk import (
@@ -10,12 +12,10 @@ from backend.sdk import (
BlockSchema,
CredentialsMetaInput,
MediaFileType,
Requests,
SchemaField,
)
from ._config import exa
from .helpers import CostDollars
class AnswerCitation(BaseModel):
@@ -23,25 +23,33 @@ class AnswerCitation(BaseModel):
id: str = SchemaField(description="The temporary ID for the document")
url: str = SchemaField(description="The URL of the search result")
title: Optional[str] = SchemaField(
description="The title of the search result", default=None
)
author: Optional[str] = SchemaField(
description="The author of the content", default=None
)
title: Optional[str] = SchemaField(description="The title of the search result")
author: Optional[str] = SchemaField(description="The author of the content")
publishedDate: Optional[str] = SchemaField(
description="An estimate of the creation date", default=None
)
text: Optional[str] = SchemaField(
description="The full text content of the source", default=None
description="An estimate of the creation date"
)
text: Optional[str] = SchemaField(description="The full text content of the source")
image: Optional[MediaFileType] = SchemaField(
description="The URL of the image associated with the result", default=None
description="The URL of the image associated with the result"
)
favicon: Optional[MediaFileType] = SchemaField(
description="The URL of the favicon for the domain", default=None
description="The URL of the favicon for the domain"
)
@classmethod
def from_sdk(cls, sdk_citation) -> "AnswerCitation":
"""Convert SDK AnswerResult (dataclass) to our Pydantic model."""
return cls(
id=getattr(sdk_citation, "id", ""),
url=getattr(sdk_citation, "url", ""),
title=getattr(sdk_citation, "title", None),
author=getattr(sdk_citation, "author", None),
publishedDate=getattr(sdk_citation, "published_date", None),
text=getattr(sdk_citation, "text", None),
image=getattr(sdk_citation, "image", None),
favicon=getattr(sdk_citation, "favicon", None),
)
class ExaAnswerBlock(Block):
class Input(BlockSchema):
@@ -62,18 +70,12 @@ class ExaAnswerBlock(Block):
description="The generated answer based on search results"
)
citations: list[AnswerCitation] = SchemaField(
description="Search results used to generate the answer",
default_factory=list,
description="Search results used to generate the answer"
)
citation: AnswerCitation = SchemaField(
description="Individual citation from the answer"
)
cost_dollars: Optional[CostDollars] = SchemaField(
description="Cost breakdown for the request", default=None
)
error: str = SchemaField(
description="Error message if the request failed", default=""
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
@@ -87,38 +89,27 @@ class ExaAnswerBlock(Block):
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
url = "https://api.exa.ai/answer"
headers = {
"Content-Type": "application/json",
"x-api-key": credentials.api_key.get_secret_value(),
}
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
payload = {
"query": input_data.query,
"text": input_data.text,
# We don't support streaming in blocks
"stream": False,
}
# Get answer using SDK (stream=False for blocks) - this IS async, needs await
response = await aexa.answer(
query=input_data.query, text=input_data.text, stream=False
)
try:
response = await Requests().post(url, headers=headers, json=payload)
data = response.json()
# This should remain true as long as the SDK doesn't start defaulting to streaming-only responses;
# it provides a bit of safety against SDK updates.
assert type(response) is AnswerResponse
# Yield the answer
if "answer" in data:
yield "answer", data["answer"]
# Yield the answer
yield "answer", response.answer
# Yield citations as a list
if "citations" in data:
yield "citations", data["citations"]
# Convert citations to our Pydantic model using from_sdk()
citations = [
AnswerCitation.from_sdk(sdk_citation)
for sdk_citation in response.citations or []
]
# Also yield individual citations
for citation in data["citations"]:
yield "citation", citation
# Yield cost information if present
if "costDollars" in data:
yield "cost_dollars", data["costDollars"]
except Exception as e:
yield "error", str(e)
yield "citations", citations
for citation in citations:
yield "citation", citation

View File

@@ -1,6 +1,7 @@
from enum import Enum
from typing import Optional
from exa_py import AsyncExa
from pydantic import BaseModel
from backend.sdk import (
@@ -10,7 +11,6 @@ from backend.sdk import (
BlockOutput,
BlockSchema,
CredentialsMetaInput,
Requests,
SchemaField,
)
@@ -105,27 +105,22 @@ class ExaContentsBlock(Block):
class Output(BlockSchema):
results: list[ExaSearchResults] = SchemaField(
description="List of document contents with metadata", default_factory=list
description="List of document contents with metadata"
)
result: ExaSearchResults = SchemaField(
description="Single document content result"
)
context: str = SchemaField(
description="A formatted string of the results ready for LLMs", default=""
)
request_id: str = SchemaField(
description="Unique identifier for the request", default=""
description="A formatted string of the results ready for LLMs"
)
request_id: str = SchemaField(description="Unique identifier for the request")
statuses: list[ContentStatus] = SchemaField(
description="Status information for each requested URL",
default_factory=list,
description="Status information for each requested URL"
)
cost_dollars: Optional[CostDollars] = SchemaField(
description="Cost breakdown for the request", default=None
)
error: str = SchemaField(
description="Error message if the request failed", default=""
description="Cost breakdown for the request"
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
@@ -139,27 +134,22 @@ class ExaContentsBlock(Block):
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
url = "https://api.exa.ai/contents"
headers = {
"Content-Type": "application/json",
"x-api-key": credentials.api_key.get_secret_value(),
}
# Validate input
if not input_data.urls and not input_data.ids:
raise ValueError("Either 'urls' or 'ids' must be provided")
# Build payload with urls or deprecated ids
payload = {}
# Build kwargs for SDK call
sdk_kwargs = {}
# Prefer urls over ids
if input_data.urls:
payload["urls"] = input_data.urls
sdk_kwargs["urls"] = input_data.urls
elif input_data.ids:
payload["ids"] = input_data.ids
else:
yield "error", "Either 'urls' or 'ids' must be provided"
return
sdk_kwargs["ids"] = input_data.ids
# Handle text field - when true, include HTML tags for better LLM understanding
if input_data.text:
payload["text"] = {"includeHtmlTags": True}
sdk_kwargs["text"] = {"includeHtmlTags": True}
# Handle highlights - only include if modified from defaults
if input_data.highlights and (
@@ -174,7 +164,7 @@ class ExaContentsBlock(Block):
)
if input_data.highlights.query:
highlights_dict["query"] = input_data.highlights.query
payload["highlights"] = highlights_dict
sdk_kwargs["highlights"] = highlights_dict
# Handle summary - only include if modified from defaults
if input_data.summary and (
@@ -186,23 +176,23 @@ class ExaContentsBlock(Block):
summary_dict["query"] = input_data.summary.query
if input_data.summary.schema:
summary_dict["schema"] = input_data.summary.schema
payload["summary"] = summary_dict
sdk_kwargs["summary"] = summary_dict
# Handle livecrawl
if input_data.livecrawl:
payload["livecrawl"] = input_data.livecrawl.value
sdk_kwargs["livecrawl"] = input_data.livecrawl.value
# Handle livecrawl_timeout
if input_data.livecrawl_timeout is not None:
payload["livecrawlTimeout"] = input_data.livecrawl_timeout
sdk_kwargs["livecrawl_timeout"] = input_data.livecrawl_timeout
# Handle subpages
if input_data.subpages is not None:
payload["subpages"] = input_data.subpages
sdk_kwargs["subpages"] = input_data.subpages
# Handle subpage_target
if input_data.subpage_target:
payload["subpageTarget"] = input_data.subpage_target
sdk_kwargs["subpage_target"] = input_data.subpage_target
# Handle extras - only include if modified from defaults
if input_data.extras and (
@@ -212,38 +202,36 @@ class ExaContentsBlock(Block):
if input_data.extras.links:
extras_dict["links"] = input_data.extras.links
if input_data.extras.image_links:
extras_dict["imageLinks"] = input_data.extras.image_links
payload["extras"] = extras_dict
extras_dict["image_links"] = input_data.extras.image_links
sdk_kwargs["extras"] = extras_dict
# Always enable context for LLM-ready output
payload["context"] = True
sdk_kwargs["context"] = True
try:
response = await Requests().post(url, headers=headers, json=payload)
data = response.json()
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
response = await aexa.get_contents(**sdk_kwargs)
# Extract all response fields
yield "results", data.get("results", [])
# SearchResponse is a dataclass, convert results to our Pydantic models
converted_results = [
ExaSearchResults.from_sdk(sdk_result)
for sdk_result in response.results or []
]
# Yield individual results
for result in data.get("results", []):
yield "result", result
yield "results", converted_results
# Yield context if present
if "context" in data:
yield "context", data["context"]
# Yield individual results
for result in converted_results:
yield "result", result
# Yield request ID if present
if "requestId" in data:
yield "request_id", data["requestId"]
# Yield context if present
if response.context:
yield "context", response.context
# Yield statuses if present
if "statuses" in data:
yield "statuses", data["statuses"]
# Yield statuses if present
if response.statuses:
yield "statuses", response.statuses
# Yield cost information if present
if "costDollars" in data:
yield "cost_dollars", data["costDollars"]
except Exception as e:
yield "error", str(e)
# Yield cost information if present
if response.cost_dollars:
yield "cost_dollars", response.cost_dollars

View File

@@ -1,7 +1,5 @@
import asyncio
import time
from enum import Enum
from typing import Any, Callable, Dict, Literal, Optional, TypeVar, Union
from typing import Any, Dict, Literal, Optional, Union
from backend.sdk import BaseModel, MediaFileType, SchemaField
@@ -257,6 +255,25 @@ class ExaSearchResults(BaseModel):
subpages: list[dict] = SchemaField(default_factory=list)
extras: ExaSearchExtras | None = None
@classmethod
def from_sdk(cls, sdk_result) -> "ExaSearchResults":
"""Convert SDK Result (dataclass) to our Pydantic model."""
return cls(
id=getattr(sdk_result, "id", ""),
url=getattr(sdk_result, "url", None),
title=getattr(sdk_result, "title", None),
author=getattr(sdk_result, "author", None),
publishedDate=getattr(sdk_result, "published_date", None),
text=getattr(sdk_result, "text", None),
highlights=getattr(sdk_result, "highlights", None) or [],
highlightScores=getattr(sdk_result, "highlight_scores", None) or [],
summary=getattr(sdk_result, "summary", None),
subpages=getattr(sdk_result, "subpages", None) or [],
image=getattr(sdk_result, "image", None),
favicon=getattr(sdk_result, "favicon", None),
extras=getattr(sdk_result, "extras", None),
)
# Cost tracking models
class CostBreakdown(BaseModel):
@@ -440,258 +457,3 @@ def add_optional_fields(
payload[api_field] = value.value
else:
payload[api_field] = value
T = TypeVar("T")
async def poll_until_complete(
check_fn: Callable[[], tuple[bool, T]],
timeout: int = 300,
initial_interval: float = 5.0,
max_interval: float = 30.0,
backoff_factor: float = 1.5,
progress_callback: Optional[Callable[[str], None]] = None,
) -> T:
"""
Generic polling function for async operations.
Args:
check_fn: Function that returns (is_complete, result)
timeout: Maximum time to wait in seconds
initial_interval: Initial wait interval between polls
max_interval: Maximum wait interval between polls
backoff_factor: Factor to increase interval by each iteration
progress_callback: Optional callback for progress updates
Returns:
The result from check_fn when complete
Raises:
TimeoutError: If operation doesn't complete within timeout
"""
start_time = time.time()
interval = initial_interval
attempt = 0
while time.time() - start_time < timeout:
attempt += 1
is_complete, result = check_fn()
if is_complete:
if progress_callback:
progress_callback(f"✓ Operation completed after {attempt} attempts")
return result
# Calculate remaining time
elapsed = time.time() - start_time
remaining = timeout - elapsed
if progress_callback:
progress_callback(
f"⏳ Attempt {attempt}: Operation still in progress. "
f"Elapsed: {int(elapsed)}s, Remaining: {int(remaining)}s"
)
# Wait before next poll
wait_time = min(interval, remaining)
if wait_time > 0:
await asyncio.sleep(wait_time)
# Exponential backoff
interval = min(interval * backoff_factor, max_interval)
raise TimeoutError(f"Operation did not complete within {timeout} seconds")
async def poll_webset_status(
webset_id: str,
api_key: str,
target_status: str = "idle",
timeout: int = 300,
progress_callback: Optional[Callable[[str], None]] = None,
) -> Dict[str, Any]:
"""
Poll a webset until it reaches the target status.
Args:
webset_id: Webset ID to poll
api_key: API key for authentication
target_status: Status to wait for (default: "idle")
timeout: Maximum time to wait in seconds
progress_callback: Optional callback for progress updates
Returns:
The webset data when target status is reached
"""
import httpx
def check_status() -> tuple[bool, Dict[str, Any]]:
with httpx.Client() as client:
response = client.get(
f"https://api.exa.ai/v1alpha/websets/{webset_id}",
headers={"Authorization": f"Bearer {api_key}"},
)
response.raise_for_status()
data = response.json()
status = data.get("status", {}).get("type")
is_complete = status == target_status
if progress_callback and not is_complete:
items_count = data.get("itemsCount", 0)
progress_callback(f"Status: {status}, Items: {items_count}")
return is_complete, data
return await poll_until_complete(
check_fn=check_status, timeout=timeout, progress_callback=progress_callback
)
async def poll_search_completion(
webset_id: str,
search_id: str,
api_key: str,
timeout: int = 300,
progress_callback: Optional[Callable[[str], None]] = None,
) -> Dict[str, Any]:
"""
Poll a search until it completes.
Args:
webset_id: Webset ID
search_id: Search ID to poll
api_key: API key for authentication
timeout: Maximum time to wait in seconds
progress_callback: Optional callback for progress updates
Returns:
The search data when complete
"""
import httpx
def check_search() -> tuple[bool, Dict[str, Any]]:
with httpx.Client() as client:
response = client.get(
f"https://api.exa.ai/v1alpha/websets/{webset_id}/searches/{search_id}",
headers={"Authorization": f"Bearer {api_key}"},
)
response.raise_for_status()
data = response.json()
status = data.get("status")
is_complete = status in ["completed", "failed", "cancelled"]
if progress_callback and not is_complete:
items_found = data.get("results", {}).get("itemsFound", 0)
progress_callback(
f"Search status: {status}, Items found: {items_found}"
)
return is_complete, data
return await poll_until_complete(
check_fn=check_search, timeout=timeout, progress_callback=progress_callback
)
async def poll_enrichment_completion(
webset_id: str,
enrichment_id: str,
api_key: str,
timeout: int = 300,
progress_callback: Optional[Callable[[str], None]] = None,
) -> Dict[str, Any]:
"""
Poll an enrichment until it completes.
Args:
webset_id: Webset ID
enrichment_id: Enrichment ID to poll
api_key: API key for authentication
timeout: Maximum time to wait in seconds
progress_callback: Optional callback for progress updates
Returns:
The enrichment data when complete
"""
import httpx
def check_enrichment() -> tuple[bool, Dict[str, Any]]:
with httpx.Client() as client:
response = client.get(
f"https://api.exa.ai/v1alpha/websets/{webset_id}/enrichments/{enrichment_id}",
headers={"Authorization": f"Bearer {api_key}"},
)
response.raise_for_status()
data = response.json()
status = data.get("status")
is_complete = status in ["completed", "failed", "cancelled"]
if progress_callback and not is_complete:
progress = data.get("progress", {})
processed = progress.get("processedItems", 0)
total = progress.get("totalItems", 0)
progress_callback(
f"Enrichment status: {status}, Progress: {processed}/{total}"
)
return is_complete, data
return await poll_until_complete(
check_fn=check_enrichment, timeout=timeout, progress_callback=progress_callback
)
def format_progress_message(
operation_type: str, current_state: str, details: Optional[Dict[str, Any]] = None
) -> str:
"""
Format a progress message for display.
Args:
operation_type: Type of operation (webset, search, enrichment)
current_state: Current state description
details: Optional details to include
Returns:
Formatted progress message
"""
message_parts = [f"[{operation_type.upper()}]", current_state]
if details:
detail_parts = []
for key, value in details.items():
detail_parts.append(f"{key}: {value}")
if detail_parts:
message_parts.append(f"({', '.join(detail_parts)})")
return " ".join(message_parts)
def calculate_polling_stats(
start_time: float, timeout: int, attempts: int
) -> Dict[str, Any]:
"""
Calculate polling statistics.
Args:
start_time: Start time (from time.time())
timeout: Maximum timeout in seconds
attempts: Number of polling attempts made
Returns:
Dictionary with polling statistics
"""
elapsed = time.time() - start_time
remaining = max(0, timeout - elapsed)
return {
"elapsed_seconds": int(elapsed),
"remaining_seconds": int(remaining),
"attempts": attempts,
"average_interval": elapsed / attempts if attempts > 0 else 0,
"progress_percentage": min(100, (elapsed / timeout) * 100),
}

View File

@@ -2,6 +2,8 @@ from datetime import datetime
from enum import Enum
from typing import Optional
from exa_py import AsyncExa
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -9,7 +11,6 @@ from backend.sdk import (
BlockOutput,
BlockSchema,
CredentialsMetaInput,
Requests,
SchemaField,
)
@@ -18,8 +19,6 @@ from .helpers import (
ContentSettings,
CostDollars,
ExaSearchResults,
add_optional_fields,
format_date_fields,
process_contents_settings,
)
@@ -106,11 +105,9 @@ class ExaSearchBlock(Block):
results: list[ExaSearchResults] = SchemaField(
description="List of search results"
)
result: ExaSearchResults = SchemaField(
description="Single search result",
)
result: ExaSearchResults = SchemaField(description="Single search result")
context: str = SchemaField(
description="A formatted string of the search results ready for LLMs.",
description="A formatted string of the search results ready for LLMs."
)
search_type: str = SchemaField(
description="For auto searches, indicates which search type was selected."
@@ -120,11 +117,9 @@ class ExaSearchBlock(Block):
description="The search type that was actually used for this request (neural or keyword)"
)
cost_dollars: Optional[CostDollars] = SchemaField(
description="Cost breakdown for the request", default=None
)
error: str = SchemaField(
description="Error message if the request failed",
description="Cost breakdown for the request"
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
@@ -138,82 +133,88 @@ class ExaSearchBlock(Block):
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
url = "https://api.exa.ai/search"
headers = {
"Content-Type": "application/json",
"x-api-key": credentials.api_key.get_secret_value(),
}
payload = {
# Build kwargs for SDK call
sdk_kwargs = {
"query": input_data.query,
"numResults": input_data.number_of_results,
"num_results": input_data.number_of_results,
}
# Handle contents field with helper function
content_settings = process_contents_settings(input_data.contents)
if content_settings:
payload["contents"] = content_settings
# Handle type field
if input_data.type:
sdk_kwargs["type"] = input_data.type.value
# Handle date fields with helper function
date_field_mapping = {
"start_crawl_date": "startCrawlDate",
"end_crawl_date": "endCrawlDate",
"start_published_date": "startPublishedDate",
"end_published_date": "endPublishedDate",
}
payload.update(format_date_fields(input_data, date_field_mapping))
# Handle category field
if input_data.category:
sdk_kwargs["category"] = input_data.category.value
# Handle enum fields separately since they need special processing
for field_name, api_field in [("type", "type"), ("category", "category")]:
value = getattr(input_data, field_name, None)
if value:
payload[api_field] = value.value if hasattr(value, "value") else value
# Handle user_location
if input_data.user_location:
sdk_kwargs["user_location"] = input_data.user_location
# Handle other optional fields
optional_field_mapping = {
"user_location": "userLocation",
"include_domains": "includeDomains",
"exclude_domains": "excludeDomains",
"include_text": "includeText",
"exclude_text": "excludeText",
}
add_optional_fields(input_data, optional_field_mapping, payload)
# Handle domains
if input_data.include_domains:
sdk_kwargs["include_domains"] = input_data.include_domains
if input_data.exclude_domains:
sdk_kwargs["exclude_domains"] = input_data.exclude_domains
# Add moderation field
# Handle dates
if input_data.start_crawl_date:
sdk_kwargs["start_crawl_date"] = input_data.start_crawl_date.isoformat()
if input_data.end_crawl_date:
sdk_kwargs["end_crawl_date"] = input_data.end_crawl_date.isoformat()
if input_data.start_published_date:
sdk_kwargs["start_published_date"] = (
input_data.start_published_date.isoformat()
)
if input_data.end_published_date:
sdk_kwargs["end_published_date"] = input_data.end_published_date.isoformat()
# Handle text filters
if input_data.include_text:
sdk_kwargs["include_text"] = input_data.include_text
if input_data.exclude_text:
sdk_kwargs["exclude_text"] = input_data.exclude_text
# Handle moderation
if input_data.moderation:
payload["moderation"] = input_data.moderation
sdk_kwargs["moderation"] = input_data.moderation
# Always enable context for LLM-ready output
payload["context"] = True
# Handle contents - check if we need to use search_and_contents
content_settings = process_contents_settings(input_data.contents)
try:
response = await Requests().post(url, headers=headers, json=payload)
data = response.json()
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
# Extract all response fields
yield "results", data.get("results", [])
for result in data.get("results", []):
yield "result", result
if content_settings:
# Use search_and_contents when contents are requested
sdk_kwargs["text"] = content_settings.get("text", False)
if "highlights" in content_settings:
sdk_kwargs["highlights"] = content_settings["highlights"]
if "summary" in content_settings:
sdk_kwargs["summary"] = content_settings["summary"]
response = await aexa.search_and_contents(**sdk_kwargs)
else:
# Use regular search when no contents requested
response = await aexa.search(**sdk_kwargs)
# Yield context if present
if "context" in data:
yield "context", data["context"]
# SearchResponse is a dataclass, convert results to our Pydantic models
converted_results = [
ExaSearchResults.from_sdk(sdk_result)
for sdk_result in response.results or []
]
# Yield search type if present
if "searchType" in data:
yield "search_type", data["searchType"]
yield "results", converted_results
for result in converted_results:
yield "result", result
# Yield request ID if present
if "requestId" in data:
yield "request_id", data["requestId"]
# Yield context if present
if response.context:
yield "context", response.context
# Yield resolved search type if present
if "resolvedSearchType" in data:
yield "resolved_search_type", data["resolvedSearchType"]
# Yield resolved search type if present
if response.resolved_search_type:
yield "resolved_search_type", response.resolved_search_type
# Yield cost information if present
if "costDollars" in data:
yield "cost_dollars", data["costDollars"]
except Exception as e:
yield "error", str(e)
# Yield cost information if present
if response.cost_dollars:
yield "cost_dollars", response.cost_dollars

View File

@@ -1,6 +1,8 @@
from datetime import datetime
from typing import Optional
from exa_py import AsyncExa
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -8,7 +10,6 @@ from backend.sdk import (
BlockOutput,
BlockSchema,
CredentialsMetaInput,
Requests,
SchemaField,
)
@@ -17,8 +18,6 @@ from .helpers import (
ContentSettings,
CostDollars,
ExaSearchResults,
add_optional_fields,
format_date_fields,
process_contents_settings,
)
@@ -82,18 +81,16 @@ class ExaFindSimilarBlock(Block):
description="List of similar documents with metadata and content"
)
result: ExaSearchResults = SchemaField(
description="Single similar document result",
description="Single similar document result"
)
context: str = SchemaField(
description="A formatted string of the results ready for LLMs.",
description="A formatted string of the results ready for LLMs."
)
request_id: str = SchemaField(description="Unique identifier for the request")
cost_dollars: Optional[CostDollars] = SchemaField(
description="Cost breakdown for the request", default=None
)
error: str = SchemaField(
description="Error message if the request failed",
description="Cost breakdown for the request"
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
@@ -107,67 +104,72 @@ class ExaFindSimilarBlock(Block):
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
url = "https://api.exa.ai/findSimilar"
headers = {
"Content-Type": "application/json",
"x-api-key": credentials.api_key.get_secret_value(),
}
payload = {
# Build kwargs for SDK call
sdk_kwargs = {
"url": input_data.url,
"numResults": input_data.number_of_results,
"num_results": input_data.number_of_results,
}
# Handle contents field with helper function
content_settings = process_contents_settings(input_data.contents)
if content_settings:
payload["contents"] = content_settings
# Handle domains
if input_data.include_domains:
sdk_kwargs["include_domains"] = input_data.include_domains
if input_data.exclude_domains:
sdk_kwargs["exclude_domains"] = input_data.exclude_domains
# Handle date fields with helper function
date_field_mapping = {
"start_crawl_date": "startCrawlDate",
"end_crawl_date": "endCrawlDate",
"start_published_date": "startPublishedDate",
"end_published_date": "endPublishedDate",
}
payload.update(format_date_fields(input_data, date_field_mapping))
# Handle dates
if input_data.start_crawl_date:
sdk_kwargs["start_crawl_date"] = input_data.start_crawl_date.isoformat()
if input_data.end_crawl_date:
sdk_kwargs["end_crawl_date"] = input_data.end_crawl_date.isoformat()
if input_data.start_published_date:
sdk_kwargs["start_published_date"] = (
input_data.start_published_date.isoformat()
)
if input_data.end_published_date:
sdk_kwargs["end_published_date"] = input_data.end_published_date.isoformat()
# Handle other optional fields
optional_field_mapping = {
"include_domains": "includeDomains",
"exclude_domains": "excludeDomains",
"include_text": "includeText",
"exclude_text": "excludeText",
}
add_optional_fields(input_data, optional_field_mapping, payload)
# Handle text filters
if input_data.include_text:
sdk_kwargs["include_text"] = input_data.include_text
if input_data.exclude_text:
sdk_kwargs["exclude_text"] = input_data.exclude_text
# Add moderation field
# Handle moderation
if input_data.moderation:
payload["moderation"] = input_data.moderation
sdk_kwargs["moderation"] = input_data.moderation
# Always enable context for LLM-ready output
payload["context"] = True
# Handle contents - check if we need to use find_similar_and_contents
content_settings = process_contents_settings(input_data.contents)
try:
response = await Requests().post(url, headers=headers, json=payload)
data = response.json()
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
# Extract all response fields
yield "results", data.get("results", [])
for result in data.get("results", []):
yield "result", result
if content_settings:
# Use find_similar_and_contents when contents are requested
sdk_kwargs["text"] = content_settings.get("text", False)
if "highlights" in content_settings:
sdk_kwargs["highlights"] = content_settings["highlights"]
if "summary" in content_settings:
sdk_kwargs["summary"] = content_settings["summary"]
response = await aexa.find_similar_and_contents(**sdk_kwargs)
else:
# Use regular find_similar when no contents requested
response = await aexa.find_similar(**sdk_kwargs)
# Yield context if present
if "context" in data:
yield "context", data["context"]
# SearchResponse is a dataclass, convert results to our Pydantic models
converted_results = [
ExaSearchResults.from_sdk(sdk_result)
for sdk_result in response.results or []
]
# Yield request ID if present
if "requestId" in data:
yield "request_id", data["requestId"]
yield "results", converted_results
for result in converted_results:
yield "result", result
# Yield cost information if present
if "costDollars" in data:
yield "cost_dollars", data["costDollars"]
# Yield context if present
if response.context:
yield "context", response.context
except Exception as e:
yield "error", str(e)
# Yield cost information if present
if response.cost_dollars:
yield "cost_dollars", response.cost_dollars

View File

@@ -131,45 +131,33 @@ class ExaWebsetWebhookBlock(Block):
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
"""Process incoming Exa webhook payload."""
try:
payload = input_data.payload
payload = input_data.payload
# Extract event details
event_type = payload.get("eventType", "unknown")
event_id = payload.get("eventId", "")
# Extract event details
event_type = payload.get("eventType", "unknown")
event_id = payload.get("eventId", "")
# Get webset ID from payload or input
webset_id = payload.get("websetId", input_data.webset_id)
# Get webset ID from payload or input
webset_id = payload.get("websetId", input_data.webset_id)
# Check if we should process this event based on filter
should_process = self._should_process_event(
event_type, input_data.event_filter
)
# Check if we should process this event based on filter
should_process = self._should_process_event(event_type, input_data.event_filter)
if not should_process:
# Skip events that don't match our filter
return
if not should_process:
# Skip events that don't match our filter
return
# Extract event data
event_data = payload.get("data", {})
timestamp = payload.get("occurredAt", payload.get("createdAt", ""))
metadata = payload.get("metadata", {})
# Extract event data
event_data = payload.get("data", {})
timestamp = payload.get("occurredAt", payload.get("createdAt", ""))
metadata = payload.get("metadata", {})
yield "event_type", event_type
yield "event_id", event_id
yield "webset_id", webset_id
yield "data", event_data
yield "timestamp", timestamp
yield "metadata", metadata
except Exception as e:
# Handle errors gracefully
yield "event_type", "error"
yield "event_id", ""
yield "webset_id", input_data.webset_id
yield "data", {"error": str(e)}
yield "timestamp", ""
yield "metadata", {}
yield "event_type", event_type
yield "event_id", event_id
yield "webset_id", webset_id
yield "data", event_data
yield "timestamp", timestamp
yield "metadata", metadata
def _should_process_event(
self, event_type: str, event_filter: WebsetEventFilter

View File

@@ -702,9 +702,6 @@ class ExaGetWebsetBlock(Block):
description="The enrichments applied to the webset"
)
monitors: list[dict] = SchemaField(description="The monitors for the webset")
items: Optional[list[dict]] = SchemaField(
description="The items in the webset (if expand_items is true)"
)
metadata: dict = SchemaField(
description="Key-value pairs associated with the webset"
)
@@ -761,7 +758,6 @@ class ExaGetWebsetBlock(Block):
yield "searches", searches_data
yield "enrichments", enrichments_data
yield "monitors", monitors_data
yield "items", None # SDK doesn't expand items by default
yield "metadata", sdk_webset.metadata or {}
yield "created_at", (
sdk_webset.created_at.isoformat() if sdk_webset.created_at else ""
@@ -877,6 +873,92 @@ class ExaCancelWebsetBlock(Block):
yield "success", "true"
# Mirrored models for Preview response stability
class PreviewCriterionModel(BaseModel):
"""Stable model for preview criteria."""
description: str
@classmethod
def from_sdk(cls, sdk_criterion) -> "PreviewCriterionModel":
"""Convert SDK criterion to our model."""
return cls(description=sdk_criterion.description)
class PreviewEnrichmentModel(BaseModel):
"""Stable model for preview enrichment."""
description: str
format: str
options: List[str]
@classmethod
def from_sdk(cls, sdk_enrichment) -> "PreviewEnrichmentModel":
"""Convert SDK enrichment to our model."""
format_str = (
sdk_enrichment.format.value
if hasattr(sdk_enrichment.format, "value")
else str(sdk_enrichment.format)
)
options_list = []
if sdk_enrichment.options:
for opt in sdk_enrichment.options:
opt_dict = opt.model_dump(by_alias=True)
options_list.append(opt_dict.get("label", ""))
return cls(
description=sdk_enrichment.description,
format=format_str,
options=options_list,
)
class PreviewSearchModel(BaseModel):
"""Stable model for preview search details."""
entity_type: str
entity_description: Optional[str]
criteria: List[PreviewCriterionModel]
@classmethod
def from_sdk(cls, sdk_search) -> "PreviewSearchModel":
"""Convert SDK search preview to our model."""
# Extract entity type from union
entity_dict = sdk_search.entity.model_dump(by_alias=True)
entity_type = entity_dict.get("type", "auto")
entity_description = entity_dict.get("description")
# Convert criteria
criteria = [
PreviewCriterionModel.from_sdk(c) for c in sdk_search.criteria or []
]
return cls(
entity_type=entity_type,
entity_description=entity_description,
criteria=criteria,
)
class PreviewWebsetModel(BaseModel):
"""Stable model for preview response."""
search: PreviewSearchModel
enrichments: List[PreviewEnrichmentModel]
@classmethod
def from_sdk(cls, sdk_preview) -> "PreviewWebsetModel":
"""Convert SDK PreviewWebsetResponse to our model."""
search = PreviewSearchModel.from_sdk(sdk_preview.search)
enrichments = [
PreviewEnrichmentModel.from_sdk(e) for e in sdk_preview.enrichments or []
]
return cls(search=search, enrichments=enrichments)
class ExaPreviewWebsetBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = exa.credentials_field(
@@ -898,16 +980,19 @@ class ExaPreviewWebsetBlock(Block):
)
class Output(BlockSchema):
preview: PreviewWebsetModel = SchemaField(
description="Full preview response with search and enrichment details"
)
entity_type: str = SchemaField(
description="The detected or specified entity type"
)
entity_description: Optional[str] = SchemaField(
description="Description of the entity type"
)
criteria: list[dict] = SchemaField(
criteria: list[PreviewCriterionModel] = SchemaField(
description="Generated search criteria that will be used"
)
enrichment_columns: list[dict] = SchemaField(
enrichment_columns: list[PreviewEnrichmentModel] = SchemaField(
description="Available enrichment columns that can be extracted"
)
interpretation: str = SchemaField(
@@ -949,24 +1034,16 @@ class ExaPreviewWebsetBlock(Block):
payload["entity"] = entity
# Preview webset using SDK - no await needed
preview = aexa.websets.preview(params=payload)
sdk_preview = aexa.websets.preview(params=payload)
# Extract entity information
entity_dict = preview.entity.model_dump(by_alias=True, exclude_none=True)
entity_type = entity_dict.get("type", "auto")
entity_description = entity_dict.get("description")
# Convert to our stable Pydantic model
preview = PreviewWebsetModel.from_sdk(sdk_preview)
# Extract criteria
criteria = [
c.model_dump(by_alias=True, exclude_none=True)
for c in preview.criteria or []
]
# Extract enrichment columns
enrichments = [
e.model_dump(by_alias=True, exclude_none=True)
for e in preview.enrichment_columns or []
]
# Extract details for individual fields (for easier graph connections)
entity_type = preview.search.entity_type
entity_description = preview.search.entity_description
criteria = preview.search.criteria
enrichments = preview.enrichments
# Generate interpretation
interpretation = f"Query will search for {entity_type}"
@@ -977,7 +1054,7 @@ class ExaPreviewWebsetBlock(Block):
if enrichments:
interpretation += f" and {len(enrichments)} available enrichment columns"
# Generate suggestions (could be enhanced based on the response)
# Generate suggestions
suggestions = []
if not criteria:
suggestions.append(
@@ -988,6 +1065,10 @@ class ExaPreviewWebsetBlock(Block):
"Consider specifying what data points you want to extract"
)
# Yield full model first
yield "preview", preview
# Then yield individual fields for graph flexibility
yield "entity_type", entity_type
yield "entity_description", entity_description
yield "criteria", criteria
@@ -1073,6 +1154,42 @@ class ExaWebsetStatusBlock(Block):
yield "is_processing", is_processing
# Summary models for ExaWebsetSummaryBlock
class SearchSummaryModel(BaseModel):
"""Summary of searches in a webset."""
total_searches: int
completed_searches: int
total_items_found: int
queries: List[str]
class EnrichmentSummaryModel(BaseModel):
"""Summary of enrichments in a webset."""
total_enrichments: int
completed_enrichments: int
enrichment_types: List[str]
titles: List[str]
class MonitorSummaryModel(BaseModel):
"""Summary of monitors in a webset."""
total_monitors: int
active_monitors: int
next_run: Optional[datetime] = None
class WebsetStatisticsModel(BaseModel):
"""Various statistics about a webset."""
total_operations: int
is_processing: bool
has_monitors: bool
avg_items_per_search: float
class ExaWebsetSummaryBlock(Block):
"""Get a comprehensive summary of a webset including samples and statistics."""
@@ -1105,23 +1222,22 @@ class ExaWebsetSummaryBlock(Block):
class Output(BlockSchema):
webset_id: str = SchemaField(description="The webset identifier")
title: Optional[str] = SchemaField(
description="Title of the webset if available"
)
status: str = SchemaField(description="Current status")
entity_type: str = SchemaField(description="Type of entities in the webset")
total_items: int = SchemaField(description="Total number of items")
sample_items: list[dict] = SchemaField(
sample_items: list[Dict[str, Any]] = SchemaField(
description="Sample items from the webset"
)
search_summary: dict = SchemaField(description="Summary of searches performed")
enrichment_summary: dict = SchemaField(
search_summary: SearchSummaryModel = SchemaField(
description="Summary of searches performed"
)
enrichment_summary: EnrichmentSummaryModel = SchemaField(
description="Summary of enrichments applied"
)
monitor_summary: dict = SchemaField(
monitor_summary: MonitorSummaryModel = SchemaField(
description="Summary of monitors configured"
)
statistics: dict = SchemaField(
statistics: WebsetStatisticsModel = SchemaField(
description="Various statistics about the webset"
)
created_at: str = SchemaField(description="When the webset was created")
@@ -1159,10 +1275,11 @@ class ExaWebsetSummaryBlock(Block):
searches = webset.searches or []
if searches:
first_search = searches[0]
entity_dict = first_search.entity.model_dump(
by_alias=True, exclude_none=True
)
entity_type = entity_dict.get("type", "unknown")
if first_search.entity:
entity_dict = first_search.entity.model_dump(
by_alias=True, exclude_none=True
)
entity_type = entity_dict.get("type", "unknown")
# Get sample items if requested
sample_items_data = []
@@ -1178,81 +1295,88 @@ class ExaWebsetSummaryBlock(Block):
]
total_items = len(sample_items_data)
# Build search summary
search_summary = {}
# Build search summary using Pydantic model
search_summary = SearchSummaryModel(
total_searches=0,
completed_searches=0,
total_items_found=0,
queries=[],
)
if input_data.include_search_details and searches:
search_summary = {
"total_searches": len(searches),
"completed_searches": sum(
search_summary = SearchSummaryModel(
total_searches=len(searches),
completed_searches=sum(
1
for s in searches
if (s.status.value if hasattr(s.status, "value") else str(s.status))
== "completed"
),
"total_items_found": sum(
s.progress.found if s.progress else 0 for s in searches
total_items_found=int(
sum(s.progress.found if s.progress else 0 for s in searches)
),
"queries": [s.query for s in searches[:3]], # First 3 queries
}
queries=[s.query for s in searches[:3]], # First 3 queries
)
# Build enrichment summary
enrichment_summary = {}
# Build enrichment summary using Pydantic model
enrichment_summary = EnrichmentSummaryModel(
total_enrichments=0,
completed_enrichments=0,
enrichment_types=[],
titles=[],
)
enrichments = webset.enrichments or []
if input_data.include_enrichment_details and enrichments:
enrichment_summary = {
"total_enrichments": len(enrichments),
"completed_enrichments": sum(
enrichment_summary = EnrichmentSummaryModel(
total_enrichments=len(enrichments),
completed_enrichments=sum(
1
for e in enrichments
if (e.status.value if hasattr(e.status, "value") else str(e.status))
== "completed"
),
"enrichment_types": list(
enrichment_types=list(
set(
(
e.format.value
if hasattr(e.format, "value")
else str(e.format)
if e.format and hasattr(e.format, "value")
else str(e.format) if e.format else "text"
)
for e in enrichments
)
),
"titles": [
(e.title or e.description or "")[:50] for e in enrichments[:3]
],
}
titles=[(e.title or e.description or "")[:50] for e in enrichments[:3]],
)
# Build monitor summary
# Build monitor summary using Pydantic model
monitors = webset.monitors or []
monitor_summary = {
"total_monitors": len(monitors),
"active_monitors": sum(
next_run_dt = None
if monitors:
next_runs = [m.next_run_at for m in monitors if m.next_run_at]
if next_runs:
next_run_dt = min(next_runs)
monitor_summary = MonitorSummaryModel(
total_monitors=len(monitors),
active_monitors=sum(
1
for m in monitors
if (m.status.value if hasattr(m.status, "value") else str(m.status))
== "enabled"
),
}
next_run=next_run_dt,
)
if monitors:
next_runs = [m.next_run_at.isoformat() for m in monitors if m.next_run_at]
if next_runs:
monitor_summary["next_run"] = min(next_runs)
# Build statistics
statistics = {
"total_operations": len(searches) + len(enrichments),
"is_processing": status in ["running", "pending"],
"has_monitors": len(monitors) > 0,
"avg_items_per_search": (
search_summary.get("total_items_found", 0) / len(searches)
if searches
else 0
# Build statistics using Pydantic model
statistics = WebsetStatisticsModel(
total_operations=len(searches) + len(enrichments),
is_processing=status in ["running", "pending"],
has_monitors=len(monitors) > 0,
avg_items_per_search=(
search_summary.total_items_found / len(searches) if searches else 0
),
}
)
yield "webset_id", webset_id
yield "title", None # SDK doesn't have title field
yield "status", status
yield "entity_type", entity_type
yield "total_items", total_items

View File

@@ -9,9 +9,10 @@ import csv
import json
from enum import Enum
from io import StringIO
from typing import Optional
from typing import Optional, Union
from exa_py import AsyncExa
from exa_py.websets.types import CreateImportResponse
from exa_py.websets.types import Import as SdkImport
from pydantic import BaseModel
@@ -38,7 +39,8 @@ class ImportModel(BaseModel):
format: str
entity_type: str
count: int
size: int
upload_url: Optional[str] # Only in CreateImportResponse
upload_valid_until: Optional[str] # Only in CreateImportResponse
failed_reason: str
failed_message: str
metadata: dict
@@ -46,11 +48,15 @@ class ImportModel(BaseModel):
updated_at: str
@classmethod
def from_sdk(cls, import_obj: SdkImport) -> "ImportModel":
"""Convert SDK Import to our stable model."""
# Extract entity type from union
entity_dict = import_obj.entity.model_dump(by_alias=True, exclude_none=True)
entity_type = entity_dict.get("type", "unknown")
def from_sdk(
cls, import_obj: Union[SdkImport, CreateImportResponse]
) -> "ImportModel":
"""Convert SDK Import or CreateImportResponse to our stable model."""
# Extract entity type from union (may be None)
entity_type = "unknown"
if import_obj.entity:
entity_dict = import_obj.entity.model_dump(by_alias=True, exclude_none=True)
entity_type = entity_dict.get("type", "unknown")
# Handle status enum
status_str = (
@@ -66,15 +72,29 @@ class ImportModel(BaseModel):
else str(import_obj.format)
)
# Handle failed_reason enum (may be None or enum)
failed_reason_str = ""
if import_obj.failed_reason:
failed_reason_str = (
import_obj.failed_reason.value
if hasattr(import_obj.failed_reason, "value")
else str(import_obj.failed_reason)
)
return cls(
id=import_obj.id,
status=status_str,
title=import_obj.title or "",
format=format_str,
entity_type=entity_type,
count=import_obj.count or 0,
size=import_obj.size or 0,
failed_reason=import_obj.failed_reason or "",
count=int(import_obj.count or 0),
upload_url=getattr(
import_obj, "upload_url", None
), # Only in CreateImportResponse
upload_valid_until=getattr(
import_obj, "upload_valid_until", None
), # Only in CreateImportResponse
failed_reason=failed_reason_str,
failed_message=import_obj.failed_message or "",
metadata=import_obj.metadata or {},
created_at=(
@@ -160,6 +180,12 @@ class ExaCreateImportBlock(Block):
title: str = SchemaField(description="Title of the import")
count: int = SchemaField(description="Number of items in the import")
entity_type: str = SchemaField(description="Type of entities imported")
upload_url: Optional[str] = SchemaField(
description="Upload URL for CSV data (only if csv_data not provided in request)"
)
upload_valid_until: Optional[str] = SchemaField(
description="Expiration time for upload URL (only if upload_url is provided)"
)
created_at: str = SchemaField(description="When the import was created")
error: str = SchemaField(description="Error message if the import failed")
@@ -243,6 +269,8 @@ class ExaCreateImportBlock(Block):
yield "title", import_obj.title
yield "count", import_obj.count
yield "entity_type", import_obj.entity_type
yield "upload_url", import_obj.upload_url
yield "upload_valid_until", import_obj.upload_valid_until
yield "created_at", import_obj.created_at
@@ -265,6 +293,12 @@ class ExaGetImportBlock(Block):
format: str = SchemaField(description="Format of the imported data")
entity_type: str = SchemaField(description="Type of entities imported")
count: int = SchemaField(description="Number of items imported")
upload_url: Optional[str] = SchemaField(
description="Upload URL for CSV data (if import not yet uploaded)"
)
upload_valid_until: Optional[str] = SchemaField(
description="Expiration time for upload URL (if applicable)"
)
failed_reason: Optional[str] = SchemaField(
description="Reason for failure (if applicable)"
)
@@ -304,6 +338,8 @@ class ExaGetImportBlock(Block):
yield "format", import_obj.format
yield "entity_type", import_obj.entity_type
yield "count", import_obj.count
yield "upload_url", import_obj.upload_url
yield "upload_valid_until", import_obj.upload_valid_until
yield "failed_reason", import_obj.failed_reason
yield "failed_message", import_obj.failed_message
yield "created_at", import_obj.created_at
@@ -442,10 +478,10 @@ class ExaExportWebsetBlock(Block):
description="Include enrichment data in export",
)
max_items: int = SchemaField(
default=1000,
default=100,
description="Maximum number of items to export",
ge=1,
le=10000,
le=100,
)
class Output(BlockSchema):

View File

@@ -9,7 +9,14 @@ from typing import Any, Dict, List, Optional
from exa_py import AsyncExa
from exa_py.websets.types import WebsetItem as SdkWebsetItem
from pydantic import BaseModel
from exa_py.websets.types import (
WebsetItemArticleProperties,
WebsetItemCompanyProperties,
WebsetItemCustomProperties,
WebsetItemPersonProperties,
WebsetItemResearchPaperProperties,
)
from pydantic import AnyUrl, BaseModel
from backend.sdk import (
APIKeyCredentials,
@@ -24,16 +31,50 @@ from backend.sdk import (
from ._config import exa
# Mirrored model for enrichment results
class EnrichmentResultModel(BaseModel):
"""Stable output model mirroring SDK EnrichmentResult."""
enrichment_id: str
format: str
result: Optional[List[str]]
reasoning: Optional[str]
references: List[Dict[str, Any]]
@classmethod
def from_sdk(cls, sdk_enrich) -> "EnrichmentResultModel":
"""Convert SDK EnrichmentResult to our model."""
format_str = (
sdk_enrich.format.value
if hasattr(sdk_enrich.format, "value")
else str(sdk_enrich.format)
)
# Convert references to dicts
references_list = []
if sdk_enrich.references:
for ref in sdk_enrich.references:
references_list.append(ref.model_dump(by_alias=True, exclude_none=True))
return cls(
enrichment_id=sdk_enrich.enrichment_id,
format=format_str,
result=sdk_enrich.result,
reasoning=sdk_enrich.reasoning,
references=references_list,
)
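A short illustrative use of the model above, flattening a converted enrichment into a one-line summary (the helper is hypothetical, not part of this commit):

def summarize_enrichment(model: EnrichmentResultModel) -> str:
    # Sketch only: result is Optional[List[str]] and references are plain dicts,
    # so both can be joined and counted without touching SDK types.
    result_text = "; ".join(model.result or [])
    return f"[{model.format}] {result_text} ({len(model.references)} references)"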
# Mirrored model for stability - don't use SDK types directly in block outputs
class WebsetItemModel(BaseModel):
"""Stable output model mirroring SDK WebsetItem."""
id: str
url: str
url: Optional[AnyUrl]
title: str
content: str
entity_data: Dict[str, Any]
enrichments: Dict[str, Any]
enrichments: Dict[str, EnrichmentResultModel] # Changed from Dict[str, Any]
verification_status: str
created_at: str
updated_at: str
@@ -43,23 +84,51 @@ class WebsetItemModel(BaseModel):
"""Convert SDK WebsetItem to our stable model."""
# Extract properties from the union type
properties_dict = {}
url_value = None
title = ""
content = ""
if hasattr(item, "properties") and item.properties:
properties_dict = item.properties.model_dump(
by_alias=True, exclude_none=True
)
# Convert enrichments from list to dict keyed by enrichment_id
enrichments_dict = {}
# URL is always available on all property types
url_value = item.properties.url
# Extract title using isinstance checks on the union type
if isinstance(item.properties, WebsetItemPersonProperties):
title = item.properties.person.name
content = "" # Person type has no content
elif isinstance(item.properties, WebsetItemCompanyProperties):
title = item.properties.company.name
content = item.properties.content or ""
elif isinstance(item.properties, WebsetItemArticleProperties):
title = item.properties.description
content = item.properties.content or ""
elif isinstance(item.properties, WebsetItemResearchPaperProperties):
title = item.properties.description
content = item.properties.content or ""
elif isinstance(item.properties, WebsetItemCustomProperties):
title = item.properties.description
content = item.properties.content or ""
else:
# Fallback
title = item.properties.description
content = getattr(item.properties, "content", "")
# Convert enrichments from list to dict keyed by enrichment_id using Pydantic models
enrichments_dict: Dict[str, EnrichmentResultModel] = {}
if hasattr(item, "enrichments") and item.enrichments:
for enrich in item.enrichments:
enrichment_data = enrich.model_dump(by_alias=True, exclude_none=True)
enrichments_dict[enrich.enrichment_id] = enrichment_data
for sdk_enrich in item.enrichments:
enrich_model = EnrichmentResultModel.from_sdk(sdk_enrich)
enrichments_dict[enrich_model.enrichment_id] = enrich_model
return cls(
id=item.id,
url=properties_dict.get("url", ""),
title=properties_dict.get("title", ""),
content=properties_dict.get("content", ""),
url=url_value,
title=title,
content=content or "",
entity_data=properties_dict,
enrichments=enrichments_dict,
verification_status="", # Not yet exposed in SDK
@@ -174,12 +243,12 @@ class ExaListWebsetItemsBlock(Block):
items: list[WebsetItemModel] = SchemaField(
description="List of webset items",
)
webset_id: str = SchemaField(
description="The ID of the webset",
)
item: WebsetItemModel = SchemaField(
description="Individual item (yielded for each item in the list)",
)
total_count: Optional[int] = SchemaField(
description="Total number of items in the webset",
)
has_more: bool = SchemaField(
description="Whether there are more items to paginate through",
)
@@ -250,9 +319,9 @@ class ExaListWebsetItemsBlock(Block):
yield "item", item
# Yield pagination metadata
yield "total_count", None # SDK doesn't provide total in pagination
yield "has_more", response.has_more
yield "next_cursor", response.next_cursor
yield "webset_id", input_data.webset_id
class ExaDeleteWebsetItemBlock(Block):
@@ -336,9 +405,6 @@ class ExaBulkWebsetItemsBlock(Block):
total_retrieved: int = SchemaField(
description="Total number of items retrieved"
)
total_in_webset: Optional[int] = SchemaField(
description="Total number of items in the webset"
)
truncated: bool = SchemaField(
description="Whether results were truncated due to max_items limit"
)
@@ -388,7 +454,6 @@ class ExaBulkWebsetItemsBlock(Block):
yield "item", item
yield "total_retrieved", len(all_items)
yield "total_in_webset", None # SDK doesn't provide total count
yield "truncated", len(all_items) >= input_data.max_items

View File

@@ -8,8 +8,10 @@ to complete, with progress tracking and timeout management.
import asyncio
import time
from enum import Enum
from typing import Any, Dict
from exa_py import AsyncExa
from pydantic import BaseModel
from backend.sdk import (
APIKeyCredentials,
@@ -23,6 +25,19 @@ from backend.sdk import (
from ._config import exa
# Import WebsetItemModel for use in enrichment samples
# This is safe as websets_items doesn't import from websets_polling
from .websets_items import WebsetItemModel
# Model for sample enrichment data
class SampleEnrichmentModel(BaseModel):
"""Sample enrichment result for display."""
item_id: str
item_title: str
enrichment_data: Dict[str, Any]
class WebsetTargetStatus(str, Enum):
IDLE = "idle"
@@ -470,7 +485,7 @@ class ExaWaitForEnrichmentBlock(Block):
description="Title/description of the enrichment"
)
elapsed_time: float = SchemaField(description="Total time elapsed in seconds")
sample_data: list[dict] = SchemaField(
sample_data: list[SampleEnrichmentModel] = SchemaField(
description="Sample of enriched data (if requested)"
)
timed_out: bool = SchemaField(description="Whether the operation timed out")
@@ -567,40 +582,29 @@ class ExaWaitForEnrichmentBlock(Block):
async def _get_sample_enrichments(
self, webset_id: str, enrichment_id: str, aexa: AsyncExa
) -> tuple[list[dict], int]:
) -> tuple[list[SampleEnrichmentModel], int]:
"""Get sample enriched data and count."""
# Get a few items to see enrichment results using SDK
response = aexa.websets.items.list(webset_id=webset_id, limit=5)
sample_data = []
sample_data: list[SampleEnrichmentModel] = []
enriched_count = 0
for item in response.data:
# Check if item has this enrichment
if item.enrichments:
for enrich in item.enrichments:
if enrich.enrichment_id == enrichment_id:
enriched_count += 1
enrich_dict = enrich.model_dump(
by_alias=True, exclude_none=True
)
sample_data.append(
{
"item_id": item.id,
"item_title": (
item.properties.title
if hasattr(item.properties, "title")
else ""
),
"enrichment_data": enrich_dict,
}
)
break
for sdk_item in response.data:
# Convert to our WebsetItemModel first
item = WebsetItemModel.from_sdk(sdk_item)
# Estimate total enriched count based on sample
# Note: This is an estimate - would need to check all items for accurate count
if enriched_count > 0 and len(response.data) > 0:
# For now, just return the sample count as we don't have total item count easily
pass
# Check if this item has the enrichment we're looking for
if enrichment_id in item.enrichments:
enriched_count += 1
enrich_model = item.enrichments[enrichment_id]
# Create sample using our typed model
sample = SampleEnrichmentModel(
item_id=item.id,
item_title=item.title,
enrichment_data=enrich_model.model_dump(exclude_none=True),
)
sample_data.append(sample)
return sample_data, enriched_count
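The wait blocks in this file share one shape: fetch the resource, check whether it is done, sleep, and give up once a deadline passes, reporting timed_out. A generic sketch of that loop, independent of the Exa SDK (the interval and timeout defaults are illustrative):

import asyncio
import time
from typing import Awaitable, Callable, Optional, TypeVar

T = TypeVar("T")

async def poll_until(
    fetch: Callable[[], Awaitable[T]],
    is_done: Callable[[T], bool],
    timeout_seconds: float = 300.0,
    interval_seconds: float = 5.0,
) -> tuple[Optional[T], bool]:
    # Sketch of the polling pattern: returns (last_result, timed_out), where
    # timed_out mirrors the timed_out output the wait blocks expose.
    deadline = time.monotonic() + timeout_seconds
    result: Optional[T] = None
    while time.monotonic() < deadline:
        result = await fetch()
        if is_done(result):
            return result, False
        await asyncio.sleep(interval_seconds)
    return result, True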