mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Merge branch 'dev' into seer/fix-postmark-error-handling
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from e2b_code_interpreter import AsyncSandbox
|
||||
from e2b_code_interpreter import Result as E2BExecutionResult
|
||||
from e2b_code_interpreter.charts import Chart as E2BExecutionResultChart
|
||||
from pydantic import BaseModel, JsonValue, SecretStr
|
||||
|
||||
@@ -37,7 +38,7 @@ class ProgrammingLanguage(Enum):
|
||||
JAVA = "java"
|
||||
|
||||
|
||||
class CodeExecutionResult(BaseModel):
|
||||
class MainCodeExecutionResult(BaseModel):
|
||||
"""
|
||||
*Pydantic model mirroring `e2b_code_interpreter.Result`*
|
||||
|
||||
@@ -47,7 +48,7 @@ class CodeExecutionResult(BaseModel):
|
||||
The result can contain multiple types of data, such as text, images, plots, etc. Each type of data is represented
|
||||
as a string, and the result can contain multiple types of data. The display calls don't have to have text representation,
|
||||
for the actual result the representation is always present for the result, the other representations are always optional.
|
||||
"""
|
||||
""" # noqa
|
||||
|
||||
class Chart(BaseModel, E2BExecutionResultChart):
|
||||
pass
|
||||
@@ -68,14 +69,104 @@ class CodeExecutionResult(BaseModel):
|
||||
"""Extra data that can be included. Not part of the standard types."""
|
||||
|
||||
|
||||
class CodeExecutionBlock(Block):
|
||||
class CodeExecutionResult(MainCodeExecutionResult):
|
||||
__doc__ = MainCodeExecutionResult.__doc__
|
||||
|
||||
is_main_result: bool = False
|
||||
"""Whether this data is the main result of the cell. Data can be produced by display calls of which can be multiple in a cell.""" # noqa
|
||||
|
||||
|
||||
class BaseE2BExecutorMixin:
|
||||
"""Shared implementation methods for E2B executor blocks."""
|
||||
|
||||
async def execute_code(
|
||||
self,
|
||||
api_key: str,
|
||||
code: str,
|
||||
language: ProgrammingLanguage,
|
||||
template_id: str = "",
|
||||
setup_commands: Optional[list[str]] = None,
|
||||
timeout: Optional[int] = None,
|
||||
sandbox_id: Optional[str] = None,
|
||||
dispose_sandbox: bool = False,
|
||||
):
|
||||
"""
|
||||
Unified code execution method that handles all three use cases:
|
||||
1. Create new sandbox and execute (ExecuteCodeBlock)
|
||||
2. Create new sandbox, execute, and return sandbox_id (InstantiateCodeSandboxBlock)
|
||||
3. Connect to existing sandbox and execute (ExecuteCodeStepBlock)
|
||||
""" # noqa
|
||||
sandbox = None
|
||||
try:
|
||||
if sandbox_id:
|
||||
# Connect to existing sandbox (ExecuteCodeStepBlock case)
|
||||
sandbox = await AsyncSandbox.connect(
|
||||
sandbox_id=sandbox_id, api_key=api_key
|
||||
)
|
||||
else:
|
||||
# Create new sandbox (ExecuteCodeBlock/InstantiateCodeSandboxBlock case)
|
||||
sandbox = await AsyncSandbox.create(
|
||||
api_key=api_key, template=template_id, timeout=timeout
|
||||
)
|
||||
if setup_commands:
|
||||
for cmd in setup_commands:
|
||||
await sandbox.commands.run(cmd)
|
||||
|
||||
# Execute the code
|
||||
execution = await sandbox.run_code(
|
||||
code,
|
||||
language=language.value,
|
||||
on_error=lambda e: sandbox.kill(), # Kill the sandbox on error
|
||||
)
|
||||
|
||||
if execution.error:
|
||||
raise Exception(execution.error)
|
||||
|
||||
results = execution.results
|
||||
text_output = execution.text
|
||||
stdout_logs = "".join(execution.logs.stdout)
|
||||
stderr_logs = "".join(execution.logs.stderr)
|
||||
|
||||
return results, text_output, stdout_logs, stderr_logs, sandbox.sandbox_id
|
||||
finally:
|
||||
# Dispose of sandbox if requested to reduce usage costs
|
||||
if dispose_sandbox and sandbox:
|
||||
await sandbox.kill()
|
||||
|
||||
def process_execution_results(
|
||||
self, results: list[E2BExecutionResult]
|
||||
) -> tuple[dict[str, Any] | None, list[dict[str, Any]]]:
|
||||
"""Process and filter execution results."""
|
||||
# Filter out empty formats and convert to dicts
|
||||
processed_results = [
|
||||
{
|
||||
f: value
|
||||
for f in [*r.formats(), "extra", "is_main_result"]
|
||||
if (value := getattr(r, f, None)) is not None
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
if main_result := next(
|
||||
(r for r in processed_results if r.get("is_main_result")), None
|
||||
):
|
||||
# Make main_result a copy we can modify & remove is_main_result
|
||||
(main_result := {**main_result}).pop("is_main_result")
|
||||
|
||||
return main_result, processed_results
|
||||
|
||||
|
||||
class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
|
||||
# TODO : Add support to upload and download files
|
||||
# Currently, You can customized the CPU and Memory, only by creating a pre customized sandbox template
|
||||
# NOTE: Currently, you can only customize the CPU and Memory
|
||||
# by creating a pre customized sandbox template
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.E2B], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
|
||||
description=(
|
||||
"Enter your API key for the E2B platform. "
|
||||
"You can get it in here - https://e2b.dev/docs"
|
||||
),
|
||||
)
|
||||
|
||||
# Todo : Option to run commond in background
|
||||
@@ -108,6 +199,14 @@ class CodeExecutionBlock(Block):
|
||||
description="Execution timeout in seconds", default=300
|
||||
)
|
||||
|
||||
dispose_sandbox: bool = SchemaField(
|
||||
description=(
|
||||
"Whether to dispose of the sandbox immediately after execution. "
|
||||
"If disabled, the sandbox will run until its timeout expires."
|
||||
),
|
||||
default=True,
|
||||
)
|
||||
|
||||
template_id: str = SchemaField(
|
||||
description=(
|
||||
"You can use an E2B sandbox template by entering its ID here. "
|
||||
@@ -119,7 +218,7 @@ class CodeExecutionBlock(Block):
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
main_result: CodeExecutionResult = SchemaField(
|
||||
main_result: MainCodeExecutionResult = SchemaField(
|
||||
title="Main Result", description="The main result from the code execution"
|
||||
)
|
||||
results: list[CodeExecutionResult] = SchemaField(
|
||||
@@ -138,10 +237,10 @@ class CodeExecutionBlock(Block):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0b02b072-abe7-11ef-8372-fb5d162dd712",
|
||||
description="Executes code in an isolated sandbox environment with internet access.",
|
||||
description="Executes code in a sandbox environment with internet access.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=CodeExecutionBlock.Input,
|
||||
output_schema=CodeExecutionBlock.Output,
|
||||
input_schema=ExecuteCodeBlock.Input,
|
||||
output_schema=ExecuteCodeBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
@@ -157,102 +256,54 @@ class CodeExecutionBlock(Block):
|
||||
("stdout_logs", "Hello World\n"),
|
||||
],
|
||||
test_mock={
|
||||
"execute_code": lambda code, language, setup_commands, timeout, api_key, template_id: (
|
||||
"execute_code": lambda api_key, code, language, template_id, setup_commands, timeout, dispose_sandbox: ( # noqa
|
||||
[], # results
|
||||
"Hello World", # text_output
|
||||
"Hello World\n", # stdout_logs
|
||||
"", # stderr_logs
|
||||
"sandbox_id", # sandbox_id
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
async def execute_code(
|
||||
self,
|
||||
code: str,
|
||||
language: ProgrammingLanguage,
|
||||
setup_commands: list[str],
|
||||
timeout: int,
|
||||
api_key: str,
|
||||
template_id: str,
|
||||
):
|
||||
try:
|
||||
sandbox = None
|
||||
if template_id:
|
||||
sandbox = await AsyncSandbox.create(
|
||||
template=template_id, api_key=api_key, timeout=timeout
|
||||
)
|
||||
else:
|
||||
sandbox = await AsyncSandbox.create(api_key=api_key, timeout=timeout)
|
||||
|
||||
if not sandbox:
|
||||
raise Exception("Sandbox not created")
|
||||
|
||||
# Running setup commands
|
||||
for cmd in setup_commands:
|
||||
await sandbox.commands.run(cmd)
|
||||
|
||||
# Executing the code
|
||||
execution = await sandbox.run_code(
|
||||
code,
|
||||
language=language.value,
|
||||
on_error=lambda e: sandbox.kill(), # Kill the sandbox if there is an error
|
||||
)
|
||||
|
||||
if execution.error:
|
||||
raise Exception(execution.error)
|
||||
|
||||
results = execution.results
|
||||
text_output = execution.text
|
||||
stdout_logs = "".join(execution.logs.stdout)
|
||||
stderr_logs = "".join(execution.logs.stderr)
|
||||
|
||||
return results, text_output, stdout_logs, stderr_logs
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
results, text_output, stdout_logs, stderr_logs = await self.execute_code(
|
||||
input_data.code,
|
||||
input_data.language,
|
||||
input_data.setup_commands,
|
||||
input_data.timeout,
|
||||
credentials.api_key.get_secret_value(),
|
||||
input_data.template_id,
|
||||
results, text_output, stdout, stderr, _ = await self.execute_code(
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
code=input_data.code,
|
||||
language=input_data.language,
|
||||
template_id=input_data.template_id,
|
||||
setup_commands=input_data.setup_commands,
|
||||
timeout=input_data.timeout,
|
||||
dispose_sandbox=input_data.dispose_sandbox,
|
||||
)
|
||||
|
||||
# Determine result object shape & filter out empty formats
|
||||
results = [
|
||||
{
|
||||
f: r[f]
|
||||
for f in [*r.formats(), "extra", "is_main_result"]
|
||||
if getattr(r, f, None) is not None
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
main_result, results = self.process_execution_results(results)
|
||||
if main_result:
|
||||
yield "main_result", main_result
|
||||
yield "results", results
|
||||
for r in results:
|
||||
if r.pop("is_main_result", False):
|
||||
yield "main_result", r
|
||||
if text_output:
|
||||
yield "response", text_output
|
||||
if stdout_logs:
|
||||
yield "stdout_logs", stdout_logs
|
||||
if stderr_logs:
|
||||
yield "stderr_logs", stderr_logs
|
||||
if stdout:
|
||||
yield "stdout_logs", stdout
|
||||
if stderr:
|
||||
yield "stderr_logs", stderr
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class InstantiationBlock(Block):
|
||||
class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.E2B], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
|
||||
description=(
|
||||
"Enter your API key for the E2B platform. "
|
||||
"You can get it in here - https://e2b.dev/docs"
|
||||
)
|
||||
)
|
||||
|
||||
# Todo : Option to run commond in background
|
||||
@@ -310,10 +361,13 @@ class InstantiationBlock(Block):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="ff0861c9-1726-4aec-9e5b-bf53f3622112",
|
||||
description="Instantiate an isolated sandbox environment with internet access where to execute code in.",
|
||||
description=(
|
||||
"Instantiate a sandbox environment with internet access "
|
||||
"in which you can execute code with the Execute Code Step block."
|
||||
),
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=InstantiationBlock.Input,
|
||||
output_schema=InstantiationBlock.Output,
|
||||
input_schema=InstantiateCodeSandboxBlock.Input,
|
||||
output_schema=InstantiateCodeSandboxBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
@@ -329,11 +383,12 @@ class InstantiationBlock(Block):
|
||||
("stdout_logs", "Hello World\n"),
|
||||
],
|
||||
test_mock={
|
||||
"execute_code": lambda setup_code, language, setup_commands, timeout, api_key, template_id: (
|
||||
"sandbox_id", # sandbox_id
|
||||
"execute_code": lambda api_key, code, language, template_id, setup_commands, timeout: ( # noqa
|
||||
[], # results
|
||||
"Hello World", # text_output
|
||||
"Hello World\n", # stdout_logs
|
||||
"", # stderr_logs
|
||||
"sandbox_id", # sandbox_id
|
||||
),
|
||||
},
|
||||
)
|
||||
@@ -342,13 +397,13 @@ class InstantiationBlock(Block):
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
sandbox_id, text_output, stdout_logs, stderr_logs = await self.execute_code(
|
||||
input_data.setup_code,
|
||||
input_data.language,
|
||||
input_data.setup_commands,
|
||||
input_data.timeout,
|
||||
credentials.api_key.get_secret_value(),
|
||||
input_data.template_id,
|
||||
_, text_output, stdout, stderr, sandbox_id = await self.execute_code(
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
code=input_data.setup_code,
|
||||
language=input_data.language,
|
||||
template_id=input_data.template_id,
|
||||
setup_commands=input_data.setup_commands,
|
||||
timeout=input_data.timeout,
|
||||
)
|
||||
if sandbox_id:
|
||||
yield "sandbox_id", sandbox_id
|
||||
@@ -357,64 +412,23 @@ class InstantiationBlock(Block):
|
||||
|
||||
if text_output:
|
||||
yield "response", text_output
|
||||
if stdout_logs:
|
||||
yield "stdout_logs", stdout_logs
|
||||
if stderr_logs:
|
||||
yield "stderr_logs", stderr_logs
|
||||
if stdout:
|
||||
yield "stdout_logs", stdout
|
||||
if stderr:
|
||||
yield "stderr_logs", stderr
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
async def execute_code(
|
||||
self,
|
||||
code: str,
|
||||
language: ProgrammingLanguage,
|
||||
setup_commands: list[str],
|
||||
timeout: int,
|
||||
api_key: str,
|
||||
template_id: str,
|
||||
):
|
||||
try:
|
||||
sandbox = None
|
||||
if template_id:
|
||||
sandbox = await AsyncSandbox.create(
|
||||
template=template_id, api_key=api_key, timeout=timeout
|
||||
)
|
||||
else:
|
||||
sandbox = await AsyncSandbox.create(api_key=api_key, timeout=timeout)
|
||||
|
||||
if not sandbox:
|
||||
raise Exception("Sandbox not created")
|
||||
|
||||
# Running setup commands
|
||||
for cmd in setup_commands:
|
||||
await sandbox.commands.run(cmd)
|
||||
|
||||
# Executing the code
|
||||
execution = await sandbox.run_code(
|
||||
code,
|
||||
language=language.value,
|
||||
on_error=lambda e: sandbox.kill(), # Kill the sandbox if there is an error
|
||||
)
|
||||
|
||||
if execution.error:
|
||||
raise Exception(execution.error)
|
||||
|
||||
text_output = execution.text
|
||||
stdout_logs = "".join(execution.logs.stdout)
|
||||
stderr_logs = "".join(execution.logs.stderr)
|
||||
|
||||
return sandbox.sandbox_id, text_output, stdout_logs, stderr_logs
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
class StepExecutionBlock(Block):
|
||||
class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.E2B], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
|
||||
description=(
|
||||
"Enter your API key for the E2B platform. "
|
||||
"You can get it in here - https://e2b.dev/docs"
|
||||
),
|
||||
)
|
||||
|
||||
sandbox_id: str = SchemaField(
|
||||
@@ -435,8 +449,13 @@ class StepExecutionBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
dispose_sandbox: bool = SchemaField(
|
||||
description="Whether to dispose of the sandbox after executing this code.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
main_result: CodeExecutionResult = SchemaField(
|
||||
main_result: MainCodeExecutionResult = SchemaField(
|
||||
title="Main Result", description="The main result from the code execution"
|
||||
)
|
||||
results: list[CodeExecutionResult] = SchemaField(
|
||||
@@ -455,10 +474,10 @@ class StepExecutionBlock(Block):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="82b59b8e-ea10-4d57-9161-8b169b0adba6",
|
||||
description="Execute code in a previously instantiated sandbox environment.",
|
||||
description="Execute code in a previously instantiated sandbox.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=StepExecutionBlock.Input,
|
||||
output_schema=StepExecutionBlock.Output,
|
||||
input_schema=ExecuteCodeStepBlock.Input,
|
||||
output_schema=ExecuteCodeStepBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
@@ -472,74 +491,38 @@ class StepExecutionBlock(Block):
|
||||
("stdout_logs", "Hello World\n"),
|
||||
],
|
||||
test_mock={
|
||||
"execute_step_code": lambda sandbox_id, step_code, language, api_key: (
|
||||
"execute_code": lambda api_key, code, language, sandbox_id, dispose_sandbox: ( # noqa
|
||||
[], # results
|
||||
"Hello World", # text_output
|
||||
"Hello World\n", # stdout_logs
|
||||
"", # stderr_logs
|
||||
sandbox_id, # sandbox_id
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
async def execute_step_code(
|
||||
self,
|
||||
sandbox_id: str,
|
||||
code: str,
|
||||
language: ProgrammingLanguage,
|
||||
api_key: str,
|
||||
):
|
||||
try:
|
||||
sandbox = await AsyncSandbox.connect(sandbox_id=sandbox_id, api_key=api_key)
|
||||
if not sandbox:
|
||||
raise Exception("Sandbox not found")
|
||||
|
||||
# Executing the code
|
||||
execution = await sandbox.run_code(code, language=language.value)
|
||||
|
||||
if execution.error:
|
||||
raise Exception(execution.error)
|
||||
|
||||
results = execution.results
|
||||
text_output = execution.text
|
||||
stdout_logs = "".join(execution.logs.stdout)
|
||||
stderr_logs = "".join(execution.logs.stderr)
|
||||
|
||||
return results, text_output, stdout_logs, stderr_logs
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
results, text_output, stdout_logs, stderr_logs = (
|
||||
await self.execute_step_code(
|
||||
input_data.sandbox_id,
|
||||
input_data.step_code,
|
||||
input_data.language,
|
||||
credentials.api_key.get_secret_value(),
|
||||
)
|
||||
results, text_output, stdout, stderr, _ = await self.execute_code(
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
code=input_data.step_code,
|
||||
language=input_data.language,
|
||||
sandbox_id=input_data.sandbox_id,
|
||||
dispose_sandbox=input_data.dispose_sandbox,
|
||||
)
|
||||
|
||||
# Determine result object shape & filter out empty formats
|
||||
results = [
|
||||
{
|
||||
f: r[f]
|
||||
for f in [*r.formats(), "extra", "is_main_result"]
|
||||
if getattr(r, f, None) is not None
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
main_result, results = self.process_execution_results(results)
|
||||
if main_result:
|
||||
yield "main_result", main_result
|
||||
yield "results", results
|
||||
for r in results:
|
||||
if r.pop("is_main_result", False):
|
||||
yield "main_result", r
|
||||
if text_output:
|
||||
yield "response", text_output
|
||||
if stdout_logs:
|
||||
yield "stdout_logs", stdout_logs
|
||||
if stderr_logs:
|
||||
yield "stderr_logs", stderr_logs
|
||||
if stdout:
|
||||
yield "stdout_logs", stdout
|
||||
if stderr:
|
||||
yield "stderr_logs", stderr
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ScrapeFormat(Enum):
|
||||
MARKDOWN = "markdown"
|
||||
HTML = "html"
|
||||
RAW_HTML = "rawHtml"
|
||||
LINKS = "links"
|
||||
SCREENSHOT = "screenshot"
|
||||
SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
|
||||
JSON = "json"
|
||||
CHANGE_TRACKING = "changeTracking"
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
"""Utility functions for converting between our ScrapeFormat enum and firecrawl FormatOption types."""
|
||||
|
||||
from typing import List
|
||||
|
||||
from firecrawl.v2.types import FormatOption, ScreenshotFormat
|
||||
|
||||
from backend.blocks.firecrawl._api import ScrapeFormat
|
||||
|
||||
|
||||
def convert_to_format_options(
|
||||
formats: List[ScrapeFormat],
|
||||
) -> List[FormatOption]:
|
||||
"""Convert our ScrapeFormat enum values to firecrawl FormatOption types.
|
||||
|
||||
Handles special cases like screenshot@fullPage which needs to be converted
|
||||
to a ScreenshotFormat object.
|
||||
"""
|
||||
result: List[FormatOption] = []
|
||||
|
||||
for format_enum in formats:
|
||||
if format_enum.value == "screenshot@fullPage":
|
||||
# Special case: convert to ScreenshotFormat with full_page=True
|
||||
result.append(ScreenshotFormat(type="screenshot", full_page=True))
|
||||
else:
|
||||
# Regular string literals
|
||||
result.append(format_enum.value)
|
||||
|
||||
return result
|
||||
@@ -1,8 +1,9 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from firecrawl import FirecrawlApp, ScrapeOptions
|
||||
from firecrawl import FirecrawlApp
|
||||
from firecrawl.v2.types import ScrapeOptions
|
||||
|
||||
from backend.blocks.firecrawl._api import ScrapeFormat
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -14,21 +15,10 @@ from backend.sdk import (
|
||||
)
|
||||
|
||||
from ._config import firecrawl
|
||||
|
||||
|
||||
class ScrapeFormat(Enum):
|
||||
MARKDOWN = "markdown"
|
||||
HTML = "html"
|
||||
RAW_HTML = "rawHtml"
|
||||
LINKS = "links"
|
||||
SCREENSHOT = "screenshot"
|
||||
SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
|
||||
JSON = "json"
|
||||
CHANGE_TRACKING = "changeTracking"
|
||||
from ._format_utils import convert_to_format_options
|
||||
|
||||
|
||||
class FirecrawlCrawlBlock(Block):
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = firecrawl.credentials_field()
|
||||
url: str = SchemaField(description="The URL to crawl")
|
||||
@@ -78,18 +68,17 @@ class FirecrawlCrawlBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
# Sync call
|
||||
crawl_result = app.crawl_url(
|
||||
crawl_result = app.crawl(
|
||||
input_data.url,
|
||||
limit=input_data.limit,
|
||||
scrape_options=ScrapeOptions(
|
||||
formats=[format.value for format in input_data.formats],
|
||||
onlyMainContent=input_data.only_main_content,
|
||||
maxAge=input_data.max_age,
|
||||
waitFor=input_data.wait_for,
|
||||
formats=convert_to_format_options(input_data.formats),
|
||||
only_main_content=input_data.only_main_content,
|
||||
max_age=input_data.max_age,
|
||||
wait_for=input_data.wait_for,
|
||||
),
|
||||
)
|
||||
yield "data", crawl_result.data
|
||||
@@ -101,7 +90,7 @@ class FirecrawlCrawlBlock(Block):
|
||||
elif f == ScrapeFormat.HTML:
|
||||
yield "html", data.html
|
||||
elif f == ScrapeFormat.RAW_HTML:
|
||||
yield "raw_html", data.rawHtml
|
||||
yield "raw_html", data.raw_html
|
||||
elif f == ScrapeFormat.LINKS:
|
||||
yield "links", data.links
|
||||
elif f == ScrapeFormat.SCREENSHOT:
|
||||
@@ -109,6 +98,6 @@ class FirecrawlCrawlBlock(Block):
|
||||
elif f == ScrapeFormat.SCREENSHOT_FULL_PAGE:
|
||||
yield "screenshot_full_page", data.screenshot
|
||||
elif f == ScrapeFormat.CHANGE_TRACKING:
|
||||
yield "change_tracking", data.changeTracking
|
||||
yield "change_tracking", data.change_tracking
|
||||
elif f == ScrapeFormat.JSON:
|
||||
yield "json", data.json
|
||||
|
||||
@@ -20,7 +20,6 @@ from ._config import firecrawl
|
||||
|
||||
@cost(BlockCost(2, BlockCostType.RUN))
|
||||
class FirecrawlExtractBlock(Block):
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = firecrawl.credentials_field()
|
||||
urls: list[str] = SchemaField(
|
||||
@@ -53,7 +52,6 @@ class FirecrawlExtractBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
extract_result = app.extract(
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from firecrawl import FirecrawlApp
|
||||
|
||||
from backend.sdk import (
|
||||
@@ -14,14 +16,16 @@ from ._config import firecrawl
|
||||
|
||||
|
||||
class FirecrawlMapWebsiteBlock(Block):
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = firecrawl.credentials_field()
|
||||
|
||||
url: str = SchemaField(description="The website url to map")
|
||||
|
||||
class Output(BlockSchema):
|
||||
links: list[str] = SchemaField(description="The links of the website")
|
||||
links: list[str] = SchemaField(description="List of URLs found on the website")
|
||||
results: list[dict[str, Any]] = SchemaField(
|
||||
description="List of search results with url, title, and description"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -35,12 +39,22 @@ class FirecrawlMapWebsiteBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
# Sync call
|
||||
map_result = app.map_url(
|
||||
map_result = app.map(
|
||||
url=input_data.url,
|
||||
)
|
||||
|
||||
yield "links", map_result.links
|
||||
# Convert SearchResult objects to dicts
|
||||
results_data = [
|
||||
{
|
||||
"url": link.url,
|
||||
"title": link.title,
|
||||
"description": link.description,
|
||||
}
|
||||
for link in map_result.links
|
||||
]
|
||||
|
||||
yield "links", [link.url for link in map_result.links]
|
||||
yield "results", results_data
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from firecrawl import FirecrawlApp
|
||||
|
||||
from backend.blocks.firecrawl._api import ScrapeFormat
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -14,21 +14,10 @@ from backend.sdk import (
|
||||
)
|
||||
|
||||
from ._config import firecrawl
|
||||
|
||||
|
||||
class ScrapeFormat(Enum):
|
||||
MARKDOWN = "markdown"
|
||||
HTML = "html"
|
||||
RAW_HTML = "rawHtml"
|
||||
LINKS = "links"
|
||||
SCREENSHOT = "screenshot"
|
||||
SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
|
||||
JSON = "json"
|
||||
CHANGE_TRACKING = "changeTracking"
|
||||
from ._format_utils import convert_to_format_options
|
||||
|
||||
|
||||
class FirecrawlScrapeBlock(Block):
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = firecrawl.credentials_field()
|
||||
url: str = SchemaField(description="The URL to crawl")
|
||||
@@ -78,12 +67,11 @@ class FirecrawlScrapeBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
scrape_result = app.scrape_url(
|
||||
scrape_result = app.scrape(
|
||||
input_data.url,
|
||||
formats=[format.value for format in input_data.formats],
|
||||
formats=convert_to_format_options(input_data.formats),
|
||||
only_main_content=input_data.only_main_content,
|
||||
max_age=input_data.max_age,
|
||||
wait_for=input_data.wait_for,
|
||||
@@ -96,7 +84,7 @@ class FirecrawlScrapeBlock(Block):
|
||||
elif f == ScrapeFormat.HTML:
|
||||
yield "html", scrape_result.html
|
||||
elif f == ScrapeFormat.RAW_HTML:
|
||||
yield "raw_html", scrape_result.rawHtml
|
||||
yield "raw_html", scrape_result.raw_html
|
||||
elif f == ScrapeFormat.LINKS:
|
||||
yield "links", scrape_result.links
|
||||
elif f == ScrapeFormat.SCREENSHOT:
|
||||
@@ -104,6 +92,6 @@ class FirecrawlScrapeBlock(Block):
|
||||
elif f == ScrapeFormat.SCREENSHOT_FULL_PAGE:
|
||||
yield "screenshot_full_page", scrape_result.screenshot
|
||||
elif f == ScrapeFormat.CHANGE_TRACKING:
|
||||
yield "change_tracking", scrape_result.changeTracking
|
||||
yield "change_tracking", scrape_result.change_tracking
|
||||
elif f == ScrapeFormat.JSON:
|
||||
yield "json", scrape_result.json
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from firecrawl import FirecrawlApp, ScrapeOptions
|
||||
from firecrawl import FirecrawlApp
|
||||
from firecrawl.v2.types import ScrapeOptions
|
||||
|
||||
from backend.blocks.firecrawl._api import ScrapeFormat
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -14,21 +15,10 @@ from backend.sdk import (
|
||||
)
|
||||
|
||||
from ._config import firecrawl
|
||||
|
||||
|
||||
class ScrapeFormat(Enum):
|
||||
MARKDOWN = "markdown"
|
||||
HTML = "html"
|
||||
RAW_HTML = "rawHtml"
|
||||
LINKS = "links"
|
||||
SCREENSHOT = "screenshot"
|
||||
SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
|
||||
JSON = "json"
|
||||
CHANGE_TRACKING = "changeTracking"
|
||||
from ._format_utils import convert_to_format_options
|
||||
|
||||
|
||||
class FirecrawlSearchBlock(Block):
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = firecrawl.credentials_field()
|
||||
query: str = SchemaField(description="The query to search for")
|
||||
@@ -61,7 +51,6 @@ class FirecrawlSearchBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
# Sync call
|
||||
@@ -69,11 +58,12 @@ class FirecrawlSearchBlock(Block):
|
||||
input_data.query,
|
||||
limit=input_data.limit,
|
||||
scrape_options=ScrapeOptions(
|
||||
formats=[format.value for format in input_data.formats],
|
||||
maxAge=input_data.max_age,
|
||||
waitFor=input_data.wait_for,
|
||||
formats=convert_to_format_options(input_data.formats) or None,
|
||||
max_age=input_data.max_age,
|
||||
wait_for=input_data.wait_for,
|
||||
),
|
||||
)
|
||||
yield "data", scrape_result
|
||||
for site in scrape_result.data:
|
||||
yield "site", site
|
||||
if hasattr(scrape_result, "web") and scrape_result.web:
|
||||
for site in scrape_result.web:
|
||||
yield "site", site
|
||||
|
||||
@@ -13,6 +13,11 @@ from backend.data.block import (
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
)
|
||||
from backend.data.dynamic_fields import (
|
||||
extract_base_field_name,
|
||||
get_dynamic_field_description,
|
||||
is_dynamic_field,
|
||||
)
|
||||
from backend.data.model import NodeExecutionStats, SchemaField
|
||||
from backend.util import json
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
@@ -98,6 +103,22 @@ def _create_tool_response(call_id: str, output: Any) -> dict[str, Any]:
|
||||
return {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
|
||||
|
||||
def _convert_raw_response_to_dict(raw_response: Any) -> dict[str, Any]:
|
||||
"""
|
||||
Safely convert raw_response to dictionary format for conversation history.
|
||||
Handles different response types from different LLM providers.
|
||||
"""
|
||||
if isinstance(raw_response, str):
|
||||
# Ollama returns a string, convert to dict format
|
||||
return {"role": "assistant", "content": raw_response}
|
||||
elif isinstance(raw_response, dict):
|
||||
# Already a dict (from tests or some providers)
|
||||
return raw_response
|
||||
else:
|
||||
# OpenAI/Anthropic return objects, convert with json.to_dict
|
||||
return json.to_dict(raw_response)
|
||||
|
||||
|
||||
def get_pending_tool_calls(conversation_history: list[Any]) -> dict[str, int]:
|
||||
"""
|
||||
All the tool calls entry in the conversation history requires a response.
|
||||
@@ -261,6 +282,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
|
||||
@staticmethod
|
||||
def cleanup(s: str):
|
||||
"""Clean up block names for use as tool function names."""
|
||||
return re.sub(r"[^a-zA-Z0-9_-]", "_", s).lower()
|
||||
|
||||
@staticmethod
|
||||
@@ -288,41 +310,66 @@ class SmartDecisionMakerBlock(Block):
|
||||
}
|
||||
sink_block_input_schema = block.input_schema
|
||||
properties = {}
|
||||
field_mapping = {} # clean_name -> original_name
|
||||
|
||||
for link in links:
|
||||
sink_name = SmartDecisionMakerBlock.cleanup(link.sink_name)
|
||||
field_name = link.sink_name
|
||||
is_dynamic = is_dynamic_field(field_name)
|
||||
# Clean property key to ensure Anthropic API compatibility for ALL fields
|
||||
clean_field_name = SmartDecisionMakerBlock.cleanup(field_name)
|
||||
field_mapping[clean_field_name] = field_name
|
||||
|
||||
# Handle dynamic fields (e.g., values_#_*, items_$_*, etc.)
|
||||
# These are fields that get merged by the executor into their base field
|
||||
if (
|
||||
"_#_" in link.sink_name
|
||||
or "_$_" in link.sink_name
|
||||
or "_@_" in link.sink_name
|
||||
):
|
||||
# For dynamic fields, provide a generic string schema
|
||||
# The executor will handle merging these into the appropriate structure
|
||||
properties[sink_name] = {
|
||||
if is_dynamic:
|
||||
# For dynamic fields, use cleaned name but preserve original in description
|
||||
properties[clean_field_name] = {
|
||||
"type": "string",
|
||||
"description": f"Dynamic value for {link.sink_name}",
|
||||
"description": get_dynamic_field_description(field_name),
|
||||
}
|
||||
else:
|
||||
# For regular fields, use the block's schema
|
||||
# For regular fields, use the block's schema directly
|
||||
try:
|
||||
properties[sink_name] = sink_block_input_schema.get_field_schema(
|
||||
link.sink_name
|
||||
properties[clean_field_name] = (
|
||||
sink_block_input_schema.get_field_schema(field_name)
|
||||
)
|
||||
except (KeyError, AttributeError):
|
||||
# If the field doesn't exist in the schema, provide a generic schema
|
||||
properties[sink_name] = {
|
||||
# If field doesn't exist in schema, provide a generic one
|
||||
properties[clean_field_name] = {
|
||||
"type": "string",
|
||||
"description": f"Value for {link.sink_name}",
|
||||
"description": f"Value for {field_name}",
|
||||
}
|
||||
|
||||
# Build the parameters schema using a single unified path
|
||||
base_schema = block.input_schema.jsonschema()
|
||||
base_required = set(base_schema.get("required", []))
|
||||
|
||||
# Compute required fields at the leaf level:
|
||||
# - If a linked field is dynamic and its base is required in the block schema, require the leaf
|
||||
# - If a linked field is regular and is required in the block schema, require the leaf
|
||||
required_fields: set[str] = set()
|
||||
for link in links:
|
||||
field_name = link.sink_name
|
||||
is_dynamic = is_dynamic_field(field_name)
|
||||
# Always use cleaned field name for property key (Anthropic API compliance)
|
||||
clean_field_name = SmartDecisionMakerBlock.cleanup(field_name)
|
||||
|
||||
if is_dynamic:
|
||||
base_name = extract_base_field_name(field_name)
|
||||
if base_name in base_required:
|
||||
required_fields.add(clean_field_name)
|
||||
else:
|
||||
if field_name in base_required:
|
||||
required_fields.add(clean_field_name)
|
||||
|
||||
tool_function["parameters"] = {
|
||||
**block.input_schema.jsonschema(),
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"additionalProperties": False,
|
||||
"required": sorted(required_fields),
|
||||
}
|
||||
|
||||
# Store field mapping for later use in output processing
|
||||
tool_function["_field_mapping"] = field_mapping
|
||||
|
||||
return {"type": "function", "function": tool_function}
|
||||
|
||||
@staticmethod
|
||||
@@ -366,13 +413,12 @@ class SmartDecisionMakerBlock(Block):
|
||||
sink_block_properties = sink_block_input_schema.get("properties", {}).get(
|
||||
link.sink_name, {}
|
||||
)
|
||||
sink_name = SmartDecisionMakerBlock.cleanup(link.sink_name)
|
||||
description = (
|
||||
sink_block_properties["description"]
|
||||
if "description" in sink_block_properties
|
||||
else f"The {link.sink_name} of the tool"
|
||||
)
|
||||
properties[sink_name] = {
|
||||
properties[link.sink_name] = {
|
||||
"type": "string",
|
||||
"description": description,
|
||||
"default": json.dumps(sink_block_properties.get("default", None)),
|
||||
@@ -388,24 +434,17 @@ class SmartDecisionMakerBlock(Block):
|
||||
return {"type": "function", "function": tool_function}
|
||||
|
||||
@staticmethod
|
||||
async def _create_function_signature(node_id: str) -> list[dict[str, Any]]:
|
||||
async def _create_function_signature(
|
||||
node_id: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Creates function signatures for tools linked to a specified node within a graph.
|
||||
|
||||
This method filters the graph links to identify those that are tools and are
|
||||
connected to the given node_id. It then constructs function signatures for each
|
||||
tool based on the metadata and input schema of the linked nodes.
|
||||
Creates function signatures for connected tools.
|
||||
|
||||
Args:
|
||||
node_id: The node_id for which to create function signatures.
|
||||
|
||||
Returns:
|
||||
list[dict[str, Any]]: A list of dictionaries, each representing a function signature
|
||||
for a tool, including its name, description, and parameters.
|
||||
|
||||
Raises:
|
||||
ValueError: If no tool links are found for the specified node_id, or if a sink node
|
||||
or its metadata cannot be found.
|
||||
List of function signatures for tools
|
||||
"""
|
||||
db_client = get_database_manager_async_client()
|
||||
tools = [
|
||||
@@ -430,20 +469,116 @@ class SmartDecisionMakerBlock(Block):
|
||||
raise ValueError(f"Sink node not found: {links[0].sink_id}")
|
||||
|
||||
if sink_node.block_id == AgentExecutorBlock().id:
|
||||
return_tool_functions.append(
|
||||
tool_func = (
|
||||
await SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
sink_node, links
|
||||
)
|
||||
)
|
||||
return_tool_functions.append(tool_func)
|
||||
else:
|
||||
return_tool_functions.append(
|
||||
tool_func = (
|
||||
await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
sink_node, links
|
||||
)
|
||||
)
|
||||
return_tool_functions.append(tool_func)
|
||||
|
||||
return return_tool_functions
|
||||
|
||||
async def _attempt_llm_call_with_validation(
|
||||
self,
|
||||
credentials: llm.APIKeyCredentials,
|
||||
input_data: Input,
|
||||
current_prompt: list[dict],
|
||||
tool_functions: list[dict[str, Any]],
|
||||
):
|
||||
"""
|
||||
Attempt a single LLM call with tool validation.
|
||||
|
||||
Returns the response if successful, raises ValueError if validation fails.
|
||||
"""
|
||||
resp = await llm.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=current_prompt,
|
||||
max_tokens=input_data.max_tokens,
|
||||
tools=tool_functions,
|
||||
ollama_host=input_data.ollama_host,
|
||||
parallel_tool_calls=input_data.multiple_tool_calls,
|
||||
)
|
||||
|
||||
# Track LLM usage stats per call
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=resp.prompt_tokens,
|
||||
output_token_count=resp.completion_tokens,
|
||||
llm_call_count=1,
|
||||
)
|
||||
)
|
||||
|
||||
if not resp.tool_calls:
|
||||
return resp
|
||||
validation_errors_list: list[str] = []
|
||||
for tool_call in resp.tool_calls:
|
||||
tool_name = tool_call.function.name
|
||||
try:
|
||||
tool_args = json.loads(tool_call.function.arguments)
|
||||
except Exception as e:
|
||||
validation_errors_list.append(
|
||||
f"Tool call '{tool_name}' has invalid JSON arguments: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Find the tool definition to get the expected arguments
|
||||
tool_def = next(
|
||||
(
|
||||
tool
|
||||
for tool in tool_functions
|
||||
if tool["function"]["name"] == tool_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
if tool_def is None and len(tool_functions) == 1:
|
||||
tool_def = tool_functions[0]
|
||||
|
||||
# Get parameters schema from tool definition
|
||||
if (
|
||||
tool_def
|
||||
and "function" in tool_def
|
||||
and "parameters" in tool_def["function"]
|
||||
):
|
||||
parameters = tool_def["function"]["parameters"]
|
||||
expected_args = parameters.get("properties", {})
|
||||
required_params = set(parameters.get("required", []))
|
||||
else:
|
||||
expected_args = {arg: {} for arg in tool_args.keys()}
|
||||
required_params = set()
|
||||
|
||||
# Validate tool call arguments
|
||||
provided_args = set(tool_args.keys())
|
||||
expected_args_set = set(expected_args.keys())
|
||||
|
||||
# Check for unexpected arguments (typos)
|
||||
unexpected_args = provided_args - expected_args_set
|
||||
# Only check for missing REQUIRED parameters
|
||||
missing_required_args = required_params - provided_args
|
||||
|
||||
if unexpected_args or missing_required_args:
|
||||
error_msg = f"Tool call '{tool_name}' has parameter errors:"
|
||||
if unexpected_args:
|
||||
error_msg += f" Unknown parameters: {sorted(unexpected_args)}."
|
||||
if missing_required_args:
|
||||
error_msg += f" Missing required parameters: {sorted(missing_required_args)}."
|
||||
error_msg += f" Expected parameters: {sorted(expected_args_set)}."
|
||||
if required_params:
|
||||
error_msg += f" Required parameters: {sorted(required_params)}."
|
||||
validation_errors_list.append(error_msg)
|
||||
|
||||
if validation_errors_list:
|
||||
raise ValueError("; ".join(validation_errors_list))
|
||||
|
||||
return resp
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
@@ -466,27 +601,19 @@ class SmartDecisionMakerBlock(Block):
|
||||
if pending_tool_calls and input_data.last_tool_output is None:
|
||||
raise ValueError(f"Tool call requires an output for {pending_tool_calls}")
|
||||
|
||||
# Only assign the last tool output to the first pending tool call
|
||||
tool_output = []
|
||||
if pending_tool_calls and input_data.last_tool_output is not None:
|
||||
# Get the first pending tool call ID
|
||||
first_call_id = next(iter(pending_tool_calls.keys()))
|
||||
tool_output.append(
|
||||
_create_tool_response(first_call_id, input_data.last_tool_output)
|
||||
)
|
||||
|
||||
# Add tool output to prompt right away
|
||||
prompt.extend(tool_output)
|
||||
|
||||
# Check if there are still pending tool calls after handling the first one
|
||||
remaining_pending_calls = get_pending_tool_calls(prompt)
|
||||
|
||||
# If there are still pending tool calls, yield the conversation and return early
|
||||
if remaining_pending_calls:
|
||||
yield "conversations", prompt
|
||||
return
|
||||
|
||||
# Fallback on adding tool output in the conversation history as user prompt.
|
||||
elif input_data.last_tool_output:
|
||||
logger.error(
|
||||
f"[SmartDecisionMakerBlock-node_exec_id={node_exec_id}] "
|
||||
@@ -519,121 +646,42 @@ class SmartDecisionMakerBlock(Block):
|
||||
):
|
||||
prompt.append({"role": "user", "content": prefix + input_data.prompt})
|
||||
|
||||
# Use retry decorator for LLM calls with validation
|
||||
from backend.util.retry import create_retry_decorator
|
||||
current_prompt = list(prompt)
|
||||
max_attempts = max(1, int(input_data.retry))
|
||||
response = None
|
||||
|
||||
# Create retry decorator that excludes ValueError from retry (for non-LLM errors)
|
||||
llm_retry = create_retry_decorator(
|
||||
max_attempts=input_data.retry,
|
||||
exclude_exceptions=(), # Don't exclude ValueError - we want to retry validation failures
|
||||
context="SmartDecisionMaker LLM call",
|
||||
)
|
||||
|
||||
@llm_retry
|
||||
async def call_llm_with_validation():
|
||||
response = await llm.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=prompt,
|
||||
max_tokens=input_data.max_tokens,
|
||||
tools=tool_functions,
|
||||
ollama_host=input_data.ollama_host,
|
||||
parallel_tool_calls=input_data.multiple_tool_calls,
|
||||
)
|
||||
|
||||
# Track LLM usage stats
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=response.prompt_tokens,
|
||||
output_token_count=response.completion_tokens,
|
||||
llm_call_count=1,
|
||||
last_error = None
|
||||
for attempt in range(max_attempts):
|
||||
try:
|
||||
response = await self._attempt_llm_call_with_validation(
|
||||
credentials, input_data, current_prompt, tool_functions
|
||||
)
|
||||
)
|
||||
break
|
||||
|
||||
if not response.tool_calls:
|
||||
return response, None # No tool calls, return response
|
||||
|
||||
# Validate all tool calls before proceeding
|
||||
validation_errors = []
|
||||
for tool_call in response.tool_calls:
|
||||
tool_name = tool_call.function.name
|
||||
tool_args = json.loads(tool_call.function.arguments)
|
||||
|
||||
# Find the tool definition to get the expected arguments
|
||||
tool_def = next(
|
||||
(
|
||||
tool
|
||||
for tool in tool_functions
|
||||
if tool["function"]["name"] == tool_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
# Get parameters schema from tool definition
|
||||
if (
|
||||
tool_def
|
||||
and "function" in tool_def
|
||||
and "parameters" in tool_def["function"]
|
||||
):
|
||||
parameters = tool_def["function"]["parameters"]
|
||||
expected_args = parameters.get("properties", {})
|
||||
required_params = set(parameters.get("required", []))
|
||||
else:
|
||||
expected_args = {arg: {} for arg in tool_args.keys()}
|
||||
required_params = set()
|
||||
|
||||
# Validate tool call arguments
|
||||
provided_args = set(tool_args.keys())
|
||||
expected_args_set = set(expected_args.keys())
|
||||
|
||||
# Check for unexpected arguments (typos)
|
||||
unexpected_args = provided_args - expected_args_set
|
||||
# Only check for missing REQUIRED parameters
|
||||
missing_required_args = required_params - provided_args
|
||||
|
||||
if unexpected_args or missing_required_args:
|
||||
error_msg = f"Tool call '{tool_name}' has parameter errors:"
|
||||
if unexpected_args:
|
||||
error_msg += f" Unknown parameters: {sorted(unexpected_args)}."
|
||||
if missing_required_args:
|
||||
error_msg += f" Missing required parameters: {sorted(missing_required_args)}."
|
||||
error_msg += f" Expected parameters: {sorted(expected_args_set)}."
|
||||
if required_params:
|
||||
error_msg += f" Required parameters: {sorted(required_params)}."
|
||||
validation_errors.append(error_msg)
|
||||
|
||||
# If validation failed, add feedback and raise for retry
|
||||
if validation_errors:
|
||||
# Add the failed response to conversation
|
||||
prompt.append(response.raw_response)
|
||||
|
||||
# Add error feedback for retry
|
||||
except ValueError as e:
|
||||
last_error = e
|
||||
error_feedback = (
|
||||
"Your tool call had parameter errors. Please fix the following issues and try again:\n"
|
||||
+ "\n".join(f"- {error}" for error in validation_errors)
|
||||
+ "\n\nPlease make sure to use the exact parameter names as specified in the function schema."
|
||||
+ f"- {str(e)}\n"
|
||||
+ "\nPlease make sure to use the exact parameter names as specified in the function schema."
|
||||
)
|
||||
prompt.append({"role": "user", "content": error_feedback})
|
||||
current_prompt = list(current_prompt) + [
|
||||
{"role": "user", "content": error_feedback}
|
||||
]
|
||||
|
||||
raise ValueError(
|
||||
f"Tool call validation failed: {'; '.join(validation_errors)}"
|
||||
)
|
||||
|
||||
return response, validation_errors
|
||||
|
||||
# Call the LLM with retry logic
|
||||
response, validation_errors = await call_llm_with_validation()
|
||||
if response is None:
|
||||
raise last_error or ValueError(
|
||||
"Failed to get valid response after all retry attempts"
|
||||
)
|
||||
|
||||
if not response.tool_calls:
|
||||
yield "finished", response.response
|
||||
return
|
||||
|
||||
# If we get here, validation passed - yield tool outputs
|
||||
for tool_call in response.tool_calls:
|
||||
tool_name = tool_call.function.name
|
||||
tool_args = json.loads(tool_call.function.arguments)
|
||||
|
||||
# Get expected arguments (already validated above)
|
||||
tool_def = next(
|
||||
(
|
||||
tool
|
||||
@@ -651,15 +699,36 @@ class SmartDecisionMakerBlock(Block):
|
||||
else:
|
||||
expected_args = {arg: {} for arg in tool_args.keys()}
|
||||
|
||||
# Yield provided arguments, use .get() for optional parameters
|
||||
for arg_name in expected_args:
|
||||
yield f"tools_^_{tool_name}_~_{arg_name}", tool_args.get(arg_name)
|
||||
# Get field mapping from tool definition
|
||||
field_mapping = (
|
||||
tool_def.get("function", {}).get("_field_mapping", {})
|
||||
if tool_def
|
||||
else {}
|
||||
)
|
||||
|
||||
for clean_arg_name in expected_args:
|
||||
# arg_name is now always the cleaned field name (for Anthropic API compliance)
|
||||
# Get the original field name from field mapping for proper emit key generation
|
||||
original_field_name = field_mapping.get(clean_arg_name, clean_arg_name)
|
||||
arg_value = tool_args.get(clean_arg_name)
|
||||
|
||||
sanitized_tool_name = self.cleanup(tool_name)
|
||||
sanitized_arg_name = self.cleanup(original_field_name)
|
||||
emit_key = f"tools_^_{sanitized_tool_name}_~_{sanitized_arg_name}"
|
||||
|
||||
logger.debug(
|
||||
"[SmartDecisionMakerBlock|geid:%s|neid:%s] emit %s",
|
||||
graph_exec_id,
|
||||
node_exec_id,
|
||||
emit_key,
|
||||
)
|
||||
yield emit_key, arg_value
|
||||
|
||||
# Add reasoning to conversation history if available
|
||||
if response.reasoning:
|
||||
prompt.append(
|
||||
{"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
|
||||
)
|
||||
|
||||
prompt.append(response.raw_response)
|
||||
prompt.append(_convert_raw_response_to_dict(response.raw_response))
|
||||
|
||||
yield "conversations", prompt
|
||||
|
||||
@@ -19,7 +19,7 @@ async def test_block_ids_valid(block: Type[Block]):
|
||||
# Skip list for blocks with known invalid UUIDs
|
||||
skip_blocks = {
|
||||
"GetWeatherInformationBlock",
|
||||
"CodeExecutionBlock",
|
||||
"ExecuteCodeBlock",
|
||||
"CountdownTimerBlock",
|
||||
"TwitterGetListTweetsBlock",
|
||||
"TwitterRemoveListMemberBlock",
|
||||
|
||||
@@ -216,8 +216,17 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
||||
}
|
||||
|
||||
# Mock the _create_function_signature method to avoid database calls
|
||||
with patch("backend.blocks.llm.llm_call", return_value=mock_response), patch.object(
|
||||
SmartDecisionMakerBlock, "_create_function_signature", return_value=[]
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
):
|
||||
|
||||
# Create test input
|
||||
@@ -301,11 +310,16 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
mock_response_with_typo.reasoning = None
|
||||
mock_response_with_typo.raw_response = {"role": "assistant", "content": None}
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call", return_value=mock_response_with_typo
|
||||
"backend.blocks.llm.llm_call",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_with_typo,
|
||||
) as mock_llm_call, patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
@@ -332,7 +346,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
# Verify error message contains details about the typo
|
||||
error_msg = str(exc_info.value)
|
||||
assert "Tool call validation failed" in error_msg
|
||||
assert "Tool call 'search_keywords' has parameter errors" in error_msg
|
||||
assert "Unknown parameters: ['maximum_keyword_difficulty']" in error_msg
|
||||
|
||||
# Verify that LLM was called the expected number of times (retries)
|
||||
@@ -353,11 +367,16 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
mock_response_missing_required.reasoning = None
|
||||
mock_response_missing_required.raw_response = {"role": "assistant", "content": None}
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call", return_value=mock_response_missing_required
|
||||
"backend.blocks.llm.llm_call",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_missing_required,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
@@ -398,11 +417,16 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
mock_response_valid.reasoning = None
|
||||
mock_response_valid.raw_response = {"role": "assistant", "content": None}
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call", return_value=mock_response_valid
|
||||
"backend.blocks.llm.llm_call",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_valid,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
@@ -447,11 +471,16 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
mock_response_all_params.reasoning = None
|
||||
mock_response_all_params.raw_response = {"role": "assistant", "content": None}
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call", return_value=mock_response_all_params
|
||||
"backend.blocks.llm.llm_call",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_all_params,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
@@ -478,3 +507,222 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
assert outputs["tools_^_search_keywords_~_query"] == "test"
|
||||
assert outputs["tools_^_search_keywords_~_max_keyword_difficulty"] == 50
|
||||
assert outputs["tools_^_search_keywords_~_optional_param"] == "custom_value"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_raw_response_conversion():
|
||||
"""Test that SmartDecisionMaker correctly handles different raw_response types with retry mechanism."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Mock tool functions
|
||||
mock_tool_functions = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"param": {"type": "string"}},
|
||||
"required": ["param"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
# Test case 1: Simulate ChatCompletionMessage raw_response that caused the original error
|
||||
class MockChatCompletionMessage:
|
||||
"""Simulate OpenAI's ChatCompletionMessage object that lacks .get() method"""
|
||||
|
||||
def __init__(self, role, content, tool_calls=None):
|
||||
self.role = role
|
||||
self.content = content
|
||||
self.tool_calls = tool_calls or []
|
||||
|
||||
# This is what caused the error - no .get() method
|
||||
# def get(self, key, default=None): # Intentionally missing
|
||||
|
||||
# First response: has invalid parameter name (triggers retry)
|
||||
mock_tool_call_invalid = MagicMock()
|
||||
mock_tool_call_invalid.function.name = "test_tool"
|
||||
mock_tool_call_invalid.function.arguments = (
|
||||
'{"wrong_param": "test_value"}' # Invalid parameter name
|
||||
)
|
||||
|
||||
mock_response_retry = MagicMock()
|
||||
mock_response_retry.response = None
|
||||
mock_response_retry.tool_calls = [mock_tool_call_invalid]
|
||||
mock_response_retry.prompt_tokens = 50
|
||||
mock_response_retry.completion_tokens = 25
|
||||
mock_response_retry.reasoning = None
|
||||
# This would cause the original error without our fix
|
||||
mock_response_retry.raw_response = MockChatCompletionMessage(
|
||||
role="assistant", content=None, tool_calls=[mock_tool_call_invalid]
|
||||
)
|
||||
|
||||
# Second response: successful (correct parameter name)
|
||||
mock_tool_call_valid = MagicMock()
|
||||
mock_tool_call_valid.function.name = "test_tool"
|
||||
mock_tool_call_valid.function.arguments = (
|
||||
'{"param": "test_value"}' # Correct parameter name
|
||||
)
|
||||
|
||||
mock_response_success = MagicMock()
|
||||
mock_response_success.response = None
|
||||
mock_response_success.tool_calls = [mock_tool_call_valid]
|
||||
mock_response_success.prompt_tokens = 50
|
||||
mock_response_success.completion_tokens = 25
|
||||
mock_response_success.reasoning = None
|
||||
mock_response_success.raw_response = MockChatCompletionMessage(
|
||||
role="assistant", content=None, tool_calls=[mock_tool_call_valid]
|
||||
)
|
||||
|
||||
# Mock llm_call to return different responses on different calls
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call", new_callable=AsyncMock
|
||||
) as mock_llm_call, patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
# First call returns response that will trigger retry due to validation error
|
||||
# Second call returns successful response
|
||||
mock_llm_call.side_effect = [mock_response_retry, mock_response_success]
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test prompt",
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
retry=2,
|
||||
)
|
||||
|
||||
# Should succeed after retry, demonstrating our helper function works
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph-id",
|
||||
node_id="test-node-id",
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Verify the tool output was generated successfully
|
||||
assert "tools_^_test_tool_~_param" in outputs
|
||||
assert outputs["tools_^_test_tool_~_param"] == "test_value"
|
||||
|
||||
# Verify conversation history was properly maintained
|
||||
assert "conversations" in outputs
|
||||
conversations = outputs["conversations"]
|
||||
assert len(conversations) > 0
|
||||
|
||||
# The conversations should contain properly converted raw_response objects as dicts
|
||||
# This would have failed with the original bug due to ChatCompletionMessage.get() error
|
||||
for msg in conversations:
|
||||
assert isinstance(msg, dict), f"Expected dict, got {type(msg)}"
|
||||
if msg.get("role") == "assistant":
|
||||
# Should have been converted from ChatCompletionMessage to dict
|
||||
assert "role" in msg
|
||||
|
||||
# Verify LLM was called twice (initial + 1 retry)
|
||||
assert mock_llm_call.call_count == 2
|
||||
|
||||
# Test case 2: Test with different raw_response types (Ollama string, dict)
|
||||
# Test Ollama string response
|
||||
mock_response_ollama = MagicMock()
|
||||
mock_response_ollama.response = "I'll help you with that."
|
||||
mock_response_ollama.tool_calls = None
|
||||
mock_response_ollama.prompt_tokens = 30
|
||||
mock_response_ollama.completion_tokens = 15
|
||||
mock_response_ollama.reasoning = None
|
||||
mock_response_ollama.raw_response = (
|
||||
"I'll help you with that." # Ollama returns string
|
||||
)
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_ollama,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[], # No tools for this test
|
||||
):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Simple prompt",
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph-id",
|
||||
node_id="test-node-id",
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Should finish since no tool calls
|
||||
assert "finished" in outputs
|
||||
assert outputs["finished"] == "I'll help you with that."
|
||||
|
||||
# Test case 3: Test with dict raw_response (some providers/tests)
|
||||
mock_response_dict = MagicMock()
|
||||
mock_response_dict.response = "Test response"
|
||||
mock_response_dict.tool_calls = None
|
||||
mock_response_dict.prompt_tokens = 25
|
||||
mock_response_dict.completion_tokens = 10
|
||||
mock_response_dict.reasoning = None
|
||||
mock_response_dict.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": "Test response",
|
||||
} # Dict format
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_dict,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Another test",
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph-id",
|
||||
node_id="test-node-id",
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
assert "finished" in outputs
|
||||
assert outputs["finished"] == "Test response"
|
||||
|
||||
@@ -48,16 +48,24 @@ async def test_smart_decision_maker_handles_dynamic_dict_fields():
|
||||
assert "parameters" in signature["function"]
|
||||
assert "properties" in signature["function"]["parameters"]
|
||||
|
||||
# Check that dynamic fields are handled
|
||||
# Check that dynamic fields are handled with original names
|
||||
properties = signature["function"]["parameters"]["properties"]
|
||||
assert len(properties) == 3 # Should have all three fields
|
||||
|
||||
# Each dynamic field should have proper schema
|
||||
for prop_value in properties.values():
|
||||
# Check that field names are cleaned (for Anthropic API compatibility)
|
||||
assert "values___name" in properties
|
||||
assert "values___age" in properties
|
||||
assert "values___city" in properties
|
||||
|
||||
# Each dynamic field should have proper schema with descriptive text
|
||||
for field_name, prop_value in properties.items():
|
||||
assert "type" in prop_value
|
||||
assert prop_value["type"] == "string" # Dynamic fields get string type
|
||||
assert "description" in prop_value
|
||||
assert "Dynamic value for" in prop_value["description"]
|
||||
# Check that descriptions properly explain the dynamic field
|
||||
if field_name == "values___name":
|
||||
assert "Dictionary field 'name'" in prop_value["description"]
|
||||
assert "values['name']" in prop_value["description"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -96,10 +104,18 @@ async def test_smart_decision_maker_handles_dynamic_list_fields():
|
||||
properties = signature["function"]["parameters"]["properties"]
|
||||
assert len(properties) == 2 # Should have both list items
|
||||
|
||||
# Each dynamic field should have proper schema
|
||||
for prop_value in properties.values():
|
||||
# Check that field names are cleaned (for Anthropic API compatibility)
|
||||
assert "entries___0" in properties
|
||||
assert "entries___1" in properties
|
||||
|
||||
# Each dynamic field should have proper schema with descriptive text
|
||||
for field_name, prop_value in properties.items():
|
||||
assert prop_value["type"] == "string"
|
||||
assert "Dynamic value for" in prop_value["description"]
|
||||
assert "description" in prop_value
|
||||
# Check that descriptions properly explain the list field
|
||||
if field_name == "entries___0":
|
||||
assert "List item 0" in prop_value["description"]
|
||||
assert "entries[0]" in prop_value["description"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -0,0 +1,553 @@
|
||||
"""Comprehensive tests for SmartDecisionMakerBlock dynamic field handling."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.data_manipulation import AddToListBlock, CreateDictionaryBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.text import MatchTextPatternBlock
|
||||
from backend.data.dynamic_fields import get_dynamic_field_description
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dynamic_field_description_generation():
|
||||
"""Test that dynamic field descriptions are generated correctly."""
|
||||
# Test dictionary field description
|
||||
desc = get_dynamic_field_description("values_#_name")
|
||||
assert "Dictionary field 'name' for base field 'values'" in desc
|
||||
assert "values['name']" in desc
|
||||
|
||||
# Test list field description
|
||||
desc = get_dynamic_field_description("items_$_0")
|
||||
assert "List item 0 for base field 'items'" in desc
|
||||
assert "items[0]" in desc
|
||||
|
||||
# Test object field description
|
||||
desc = get_dynamic_field_description("user_@_email")
|
||||
assert "Object attribute 'email' for base field 'user'" in desc
|
||||
assert "user.email" in desc
|
||||
|
||||
# Test regular field fallback
|
||||
desc = get_dynamic_field_description("regular_field")
|
||||
assert desc == "Value for regular_field"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_block_function_signature_with_dict_fields():
|
||||
"""Test that function signatures are created correctly for dictionary dynamic fields."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Create a mock node for CreateDictionaryBlock
|
||||
mock_node = Mock()
|
||||
mock_node.block = CreateDictionaryBlock()
|
||||
mock_node.block_id = CreateDictionaryBlock().id
|
||||
mock_node.input_default = {}
|
||||
|
||||
# Create mock links with dynamic dictionary fields (source sanitized, sink original)
|
||||
mock_links = [
|
||||
Mock(
|
||||
source_name="tools_^_create_dict_~_values___name", # Sanitized source
|
||||
sink_name="values_#_name", # Original sink
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_create_dict_~_values___age", # Sanitized source
|
||||
sink_name="values_#_age", # Original sink
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_create_dict_~_values___email", # Sanitized source
|
||||
sink_name="values_#_email", # Original sink
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
# Generate function signature
|
||||
signature = await block._create_block_function_signature(mock_node, mock_links) # type: ignore
|
||||
|
||||
# Verify the signature structure
|
||||
assert signature["type"] == "function"
|
||||
assert "function" in signature
|
||||
assert "parameters" in signature["function"]
|
||||
assert "properties" in signature["function"]["parameters"]
|
||||
|
||||
# Check that dynamic fields are handled with original names
|
||||
properties = signature["function"]["parameters"]["properties"]
|
||||
assert len(properties) == 3
|
||||
|
||||
# Check cleaned field names (for Anthropic API compatibility)
|
||||
assert "values___name" in properties
|
||||
assert "values___age" in properties
|
||||
assert "values___email" in properties
|
||||
|
||||
# Check descriptions mention they are dictionary fields
|
||||
assert "Dictionary field" in properties["values___name"]["description"]
|
||||
assert "values['name']" in properties["values___name"]["description"]
|
||||
|
||||
assert "Dictionary field" in properties["values___age"]["description"]
|
||||
assert "values['age']" in properties["values___age"]["description"]
|
||||
|
||||
assert "Dictionary field" in properties["values___email"]["description"]
|
||||
assert "values['email']" in properties["values___email"]["description"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_block_function_signature_with_list_fields():
|
||||
"""Test that function signatures are created correctly for list dynamic fields."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Create a mock node for AddToListBlock
|
||||
mock_node = Mock()
|
||||
mock_node.block = AddToListBlock()
|
||||
mock_node.block_id = AddToListBlock().id
|
||||
mock_node.input_default = {}
|
||||
|
||||
# Create mock links with dynamic list fields
|
||||
mock_links = [
|
||||
Mock(
|
||||
source_name="tools_^_add_list_~_0",
|
||||
sink_name="entries_$_0", # Dynamic list field
|
||||
sink_id="list_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_add_list_~_1",
|
||||
sink_name="entries_$_1", # Dynamic list field
|
||||
sink_id="list_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_add_list_~_2",
|
||||
sink_name="entries_$_2", # Dynamic list field
|
||||
sink_id="list_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
# Generate function signature
|
||||
signature = await block._create_block_function_signature(mock_node, mock_links) # type: ignore
|
||||
|
||||
# Verify the signature structure
|
||||
assert signature["type"] == "function"
|
||||
properties = signature["function"]["parameters"]["properties"]
|
||||
|
||||
# Check cleaned field names (for Anthropic API compatibility)
|
||||
assert "entries___0" in properties
|
||||
assert "entries___1" in properties
|
||||
assert "entries___2" in properties
|
||||
|
||||
# Check descriptions mention they are list items
|
||||
assert "List item 0" in properties["entries___0"]["description"]
|
||||
assert "entries[0]" in properties["entries___0"]["description"]
|
||||
|
||||
assert "List item 1" in properties["entries___1"]["description"]
|
||||
assert "entries[1]" in properties["entries___1"]["description"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_block_function_signature_with_object_fields():
|
||||
"""Test that function signatures are created correctly for object dynamic fields."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Create a mock node for MatchTextPatternBlock (simulating object fields)
|
||||
mock_node = Mock()
|
||||
mock_node.block = MatchTextPatternBlock()
|
||||
mock_node.block_id = MatchTextPatternBlock().id
|
||||
mock_node.input_default = {}
|
||||
|
||||
# Create mock links with dynamic object fields
|
||||
mock_links = [
|
||||
Mock(
|
||||
source_name="tools_^_extract_~_user_name",
|
||||
sink_name="data_@_user_name", # Dynamic object field
|
||||
sink_id="extract_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_extract_~_user_email",
|
||||
sink_name="data_@_user_email", # Dynamic object field
|
||||
sink_id="extract_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
# Generate function signature
|
||||
signature = await block._create_block_function_signature(mock_node, mock_links) # type: ignore
|
||||
|
||||
# Verify the signature structure
|
||||
properties = signature["function"]["parameters"]["properties"]
|
||||
|
||||
# Check cleaned field names (for Anthropic API compatibility)
|
||||
assert "data___user_name" in properties
|
||||
assert "data___user_email" in properties
|
||||
|
||||
# Check descriptions mention they are object attributes
|
||||
assert "Object attribute" in properties["data___user_name"]["description"]
|
||||
assert "data.user_name" in properties["data___user_name"]["description"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_function_signature():
|
||||
"""Test that the mapping between sanitized and original field names is built correctly."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Mock the database client and connected nodes
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
|
||||
) as mock_db:
|
||||
mock_client = AsyncMock()
|
||||
mock_db.return_value = mock_client
|
||||
|
||||
# Create mock nodes and links
|
||||
mock_dict_node = Mock()
|
||||
mock_dict_node.block = CreateDictionaryBlock()
|
||||
mock_dict_node.block_id = CreateDictionaryBlock().id
|
||||
mock_dict_node.input_default = {}
|
||||
|
||||
mock_list_node = Mock()
|
||||
mock_list_node.block = AddToListBlock()
|
||||
mock_list_node.block_id = AddToListBlock().id
|
||||
mock_list_node.input_default = {}
|
||||
|
||||
# Mock links with dynamic fields
|
||||
dict_link1 = Mock(
|
||||
source_name="tools_^_create_dictionary_~_name",
|
||||
sink_name="values_#_name",
|
||||
sink_id="dict_node_id",
|
||||
source_id="test_node_id",
|
||||
)
|
||||
dict_link2 = Mock(
|
||||
source_name="tools_^_create_dictionary_~_age",
|
||||
sink_name="values_#_age",
|
||||
sink_id="dict_node_id",
|
||||
source_id="test_node_id",
|
||||
)
|
||||
list_link = Mock(
|
||||
source_name="tools_^_add_to_list_~_0",
|
||||
sink_name="entries_$_0",
|
||||
sink_id="list_node_id",
|
||||
source_id="test_node_id",
|
||||
)
|
||||
|
||||
mock_client.get_connected_output_nodes.return_value = [
|
||||
(dict_link1, mock_dict_node),
|
||||
(dict_link2, mock_dict_node),
|
||||
(list_link, mock_list_node),
|
||||
]
|
||||
|
||||
# Call the method that builds signatures
|
||||
tool_functions = await block._create_function_signature("test_node_id")
|
||||
|
||||
# Verify we got 2 tool functions (one for dict, one for list)
|
||||
assert len(tool_functions) == 2
|
||||
|
||||
# Verify the tool functions contain the dynamic field names
|
||||
dict_tool = next(
|
||||
(
|
||||
tool
|
||||
for tool in tool_functions
|
||||
if tool["function"]["name"] == "createdictionaryblock"
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert dict_tool is not None
|
||||
dict_properties = dict_tool["function"]["parameters"]["properties"]
|
||||
assert "values___name" in dict_properties
|
||||
assert "values___age" in dict_properties
|
||||
|
||||
list_tool = next(
|
||||
(
|
||||
tool
|
||||
for tool in tool_functions
|
||||
if tool["function"]["name"] == "addtolistblock"
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert list_tool is not None
|
||||
list_properties = list_tool["function"]["parameters"]["properties"]
|
||||
assert "entries___0" in list_properties
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_output_yielding_with_dynamic_fields():
|
||||
"""Test that outputs are yielded correctly with dynamic field names mapped back."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# No more sanitized mapping needed since we removed sanitization
|
||||
|
||||
# Mock LLM response with tool calls
|
||||
mock_response = Mock()
|
||||
mock_response.tool_calls = [
|
||||
Mock(
|
||||
function=Mock(
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"values___name": "Alice",
|
||||
"values___age": 30,
|
||||
"values___email": "alice@example.com",
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
]
|
||||
# Ensure function name is a real string, not a Mock name
|
||||
mock_response.tool_calls[0].function.name = "createdictionaryblock"
|
||||
mock_response.reasoning = "Creating a dictionary with user information"
|
||||
mock_response.raw_response = {"role": "assistant", "content": "test"}
|
||||
mock_response.prompt_tokens = 100
|
||||
mock_response.completion_tokens = 50
|
||||
|
||||
# Mock the LLM call
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.llm.llm_call", new_callable=AsyncMock
|
||||
) as mock_llm:
|
||||
mock_llm.return_value = mock_response
|
||||
|
||||
# Mock the function signature creation
|
||||
with patch.object(
|
||||
block, "_create_function_signature", new_callable=AsyncMock
|
||||
) as mock_sig:
|
||||
mock_sig.return_value = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "createdictionaryblock",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"values___name": {"type": "string"},
|
||||
"values___age": {"type": "number"},
|
||||
"values___email": {"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
# Create input data
|
||||
from backend.blocks import llm
|
||||
|
||||
input_data = block.input_schema(
|
||||
prompt="Create a user dictionary",
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
)
|
||||
|
||||
# Run the block
|
||||
outputs = {}
|
||||
async for output_name, output_value in block.run(
|
||||
input_data,
|
||||
credentials=llm.TEST_CREDENTIALS,
|
||||
graph_id="test_graph",
|
||||
node_id="test_node",
|
||||
graph_exec_id="test_exec",
|
||||
node_exec_id="test_node_exec",
|
||||
user_id="test_user",
|
||||
):
|
||||
outputs[output_name] = output_value
|
||||
|
||||
# Verify the outputs use sanitized field names (matching frontend normalizeToolName)
|
||||
assert "tools_^_createdictionaryblock_~_values___name" in outputs
|
||||
assert outputs["tools_^_createdictionaryblock_~_values___name"] == "Alice"
|
||||
|
||||
assert "tools_^_createdictionaryblock_~_values___age" in outputs
|
||||
assert outputs["tools_^_createdictionaryblock_~_values___age"] == 30
|
||||
|
||||
assert "tools_^_createdictionaryblock_~_values___email" in outputs
|
||||
assert (
|
||||
outputs["tools_^_createdictionaryblock_~_values___email"]
|
||||
== "alice@example.com"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_regular_and_dynamic_fields():
|
||||
"""Test handling of blocks with both regular and dynamic fields."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Create a mock node
|
||||
mock_node = Mock()
|
||||
mock_node.block = Mock()
|
||||
mock_node.block.name = "TestBlock"
|
||||
mock_node.block.description = "A test block"
|
||||
mock_node.block.input_schema = Mock()
|
||||
|
||||
# Mock the get_field_schema to return a proper schema for regular fields
|
||||
def get_field_schema(field_name):
|
||||
if field_name == "regular_field":
|
||||
return {"type": "string", "description": "A regular field"}
|
||||
elif field_name == "values":
|
||||
return {"type": "object", "description": "A dictionary field"}
|
||||
else:
|
||||
raise KeyError(f"Field {field_name} not found")
|
||||
|
||||
mock_node.block.input_schema.get_field_schema = get_field_schema
|
||||
mock_node.block.input_schema.jsonschema = Mock(
|
||||
return_value={"properties": {}, "required": []}
|
||||
)
|
||||
|
||||
# Create links with both regular and dynamic fields
|
||||
mock_links = [
|
||||
Mock(
|
||||
source_name="tools_^_test_~_regular",
|
||||
sink_name="regular_field", # Regular field
|
||||
sink_id="test_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_test_~_dict_key",
|
||||
sink_name="values_#_key1", # Dynamic dict field
|
||||
sink_id="test_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_test_~_dict_key2",
|
||||
sink_name="values_#_key2", # Dynamic dict field
|
||||
sink_id="test_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
# Generate function signature
|
||||
signature = await block._create_block_function_signature(mock_node, mock_links) # type: ignore
|
||||
|
||||
# Check properties
|
||||
properties = signature["function"]["parameters"]["properties"]
|
||||
assert len(properties) == 3
|
||||
|
||||
# Regular field should have its original schema
|
||||
assert "regular_field" in properties
|
||||
assert properties["regular_field"]["description"] == "A regular field"
|
||||
|
||||
# Dynamic fields should have generated descriptions
|
||||
assert "values___key1" in properties
|
||||
assert "Dictionary field" in properties["values___key1"]["description"]
|
||||
|
||||
assert "values___key2" in properties
|
||||
assert "Dictionary field" in properties["values___key2"]["description"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validation_errors_dont_pollute_conversation():
|
||||
"""Test that validation errors are only used during retries and don't pollute the conversation."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Track conversation history changes
|
||||
conversation_snapshots = []
|
||||
|
||||
# Mock response with invalid tool call (missing required parameter)
|
||||
invalid_response = Mock()
|
||||
invalid_response.tool_calls = [
|
||||
Mock(
|
||||
function=Mock(
|
||||
arguments=json.dumps({"wrong_param": "value"}), # Wrong parameter name
|
||||
)
|
||||
)
|
||||
]
|
||||
# Ensure function name is a real string, not a Mock name
|
||||
invalid_response.tool_calls[0].function.name = "test_tool"
|
||||
invalid_response.reasoning = None
|
||||
invalid_response.raw_response = {"role": "assistant", "content": "invalid"}
|
||||
invalid_response.prompt_tokens = 100
|
||||
invalid_response.completion_tokens = 50
|
||||
|
||||
# Mock valid response after retry
|
||||
valid_response = Mock()
|
||||
valid_response.tool_calls = [
|
||||
Mock(function=Mock(arguments=json.dumps({"correct_param": "value"})))
|
||||
]
|
||||
# Ensure function name is a real string, not a Mock name
|
||||
valid_response.tool_calls[0].function.name = "test_tool"
|
||||
valid_response.reasoning = None
|
||||
valid_response.raw_response = {"role": "assistant", "content": "valid"}
|
||||
valid_response.prompt_tokens = 100
|
||||
valid_response.completion_tokens = 50
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_llm_call(**kwargs):
|
||||
nonlocal call_count
|
||||
# Capture conversation state
|
||||
conversation_snapshots.append(kwargs.get("prompt", []).copy())
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return invalid_response
|
||||
else:
|
||||
return valid_response
|
||||
|
||||
# Mock the LLM call
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.llm.llm_call", new_callable=AsyncMock
|
||||
) as mock_llm:
|
||||
mock_llm.side_effect = mock_llm_call
|
||||
|
||||
# Mock the function signature creation
|
||||
with patch.object(
|
||||
block, "_create_function_signature", new_callable=AsyncMock
|
||||
) as mock_sig:
|
||||
mock_sig.return_value = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"correct_param": {
|
||||
"type": "string",
|
||||
"description": "The correct parameter",
|
||||
}
|
||||
},
|
||||
"required": ["correct_param"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
# Create input data
|
||||
from backend.blocks import llm
|
||||
|
||||
input_data = block.input_schema(
|
||||
prompt="Test prompt",
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT,
|
||||
model=llm.LlmModel.GPT4O,
|
||||
retry=3, # Allow retries
|
||||
)
|
||||
|
||||
# Run the block
|
||||
outputs = {}
|
||||
async for output_name, output_value in block.run(
|
||||
input_data,
|
||||
credentials=llm.TEST_CREDENTIALS,
|
||||
graph_id="test_graph",
|
||||
node_id="test_node",
|
||||
graph_exec_id="test_exec",
|
||||
node_exec_id="test_node_exec",
|
||||
user_id="test_user",
|
||||
):
|
||||
outputs[output_name] = output_value
|
||||
|
||||
# Verify we had 2 LLM calls (initial + retry)
|
||||
assert call_count == 2
|
||||
|
||||
# Check the final conversation output
|
||||
final_conversation = outputs.get("conversations", [])
|
||||
|
||||
# The final conversation should NOT contain the validation error message
|
||||
error_messages = [
|
||||
msg
|
||||
for msg in final_conversation
|
||||
if msg.get("role") == "user"
|
||||
and "parameter errors" in msg.get("content", "")
|
||||
]
|
||||
assert (
|
||||
len(error_messages) == 0
|
||||
), "Validation error leaked into final conversation"
|
||||
|
||||
# The final conversation should only have the successful response
|
||||
assert final_conversation[-1]["content"] == "valid"
|
||||
@@ -270,13 +270,17 @@ class GetCurrentDateBlock(Block):
|
||||
test_output=[
|
||||
(
|
||||
"date",
|
||||
lambda t: abs(datetime.now() - datetime.strptime(t, "%Y-%m-%d"))
|
||||
< timedelta(days=8), # 7 days difference + 1 day error margin.
|
||||
lambda t: abs(
|
||||
datetime.now().date() - datetime.strptime(t, "%Y-%m-%d").date()
|
||||
)
|
||||
<= timedelta(days=8), # 7 days difference + 1 day error margin.
|
||||
),
|
||||
(
|
||||
"date",
|
||||
lambda t: abs(datetime.now() - datetime.strptime(t, "%m/%d/%Y"))
|
||||
< timedelta(days=8),
|
||||
lambda t: abs(
|
||||
datetime.now().date() - datetime.strptime(t, "%m/%d/%Y").date()
|
||||
)
|
||||
<= timedelta(days=8),
|
||||
# 7 days difference + 1 day error margin.
|
||||
),
|
||||
(
|
||||
@@ -382,7 +386,7 @@ class GetCurrentDateAndTimeBlock(Block):
|
||||
lambda t: abs(
|
||||
datetime.now().date() - datetime.strptime(t, "%Y/%m/%d").date()
|
||||
)
|
||||
< timedelta(days=1), # Date format only, no time component
|
||||
<= timedelta(days=1), # Date format only, no time component
|
||||
),
|
||||
(
|
||||
"date_time",
|
||||
|
||||
284
autogpt_platform/backend/backend/data/dynamic_fields.py
Normal file
284
autogpt_platform/backend/backend/data/dynamic_fields.py
Normal file
@@ -0,0 +1,284 @@
|
||||
"""
|
||||
Utilities for handling dynamic field names with special delimiters.
|
||||
|
||||
Dynamic fields allow graphs to connect complex data structures using special delimiters:
|
||||
- _#_ for dictionary keys (e.g., "values_#_name" → values["name"])
|
||||
- _$_ for list indices (e.g., "items_$_0" → items[0])
|
||||
- _@_ for object attributes (e.g., "obj_@_attr" → obj.attr)
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from backend.util.mock import MockObject
|
||||
|
||||
# Dynamic field delimiters
|
||||
LIST_SPLIT = "_$_"
|
||||
DICT_SPLIT = "_#_"
|
||||
OBJC_SPLIT = "_@_"
|
||||
|
||||
DYNAMIC_DELIMITERS = (LIST_SPLIT, DICT_SPLIT, OBJC_SPLIT)
|
||||
|
||||
|
||||
def extract_base_field_name(field_name: str) -> str:
|
||||
"""
|
||||
Extract the base field name from a dynamic field name by removing all dynamic suffixes.
|
||||
|
||||
Examples:
|
||||
extract_base_field_name("values_#_name") → "values"
|
||||
extract_base_field_name("items_$_0") → "items"
|
||||
extract_base_field_name("obj_@_attr") → "obj"
|
||||
extract_base_field_name("regular_field") → "regular_field"
|
||||
|
||||
Args:
|
||||
field_name: The field name that may contain dynamic delimiters
|
||||
|
||||
Returns:
|
||||
The base field name without any dynamic suffixes
|
||||
"""
|
||||
base_name = field_name
|
||||
for delimiter in DYNAMIC_DELIMITERS:
|
||||
if delimiter in base_name:
|
||||
base_name = base_name.split(delimiter)[0]
|
||||
return base_name
|
||||
|
||||
|
||||
def is_dynamic_field(field_name: str) -> bool:
|
||||
"""
|
||||
Check if a field name contains dynamic delimiters.
|
||||
|
||||
Args:
|
||||
field_name: The field name to check
|
||||
|
||||
Returns:
|
||||
True if the field contains any dynamic delimiters, False otherwise
|
||||
"""
|
||||
return any(delimiter in field_name for delimiter in DYNAMIC_DELIMITERS)
|
||||
|
||||
|
||||
def get_dynamic_field_description(field_name: str) -> str:
|
||||
"""
|
||||
Generate a description for a dynamic field based on its structure.
|
||||
|
||||
Args:
|
||||
field_name: The full dynamic field name (e.g., "values_#_name")
|
||||
|
||||
Returns:
|
||||
A descriptive string explaining what this dynamic field represents
|
||||
"""
|
||||
base_name = extract_base_field_name(field_name)
|
||||
|
||||
if DICT_SPLIT in field_name:
|
||||
# Extract the key part after _#_
|
||||
parts = field_name.split(DICT_SPLIT)
|
||||
if len(parts) > 1:
|
||||
key = parts[1].split("_")[0] if "_" in parts[1] else parts[1]
|
||||
return f"Dictionary field '{key}' for base field '{base_name}' ({base_name}['{key}'])"
|
||||
elif LIST_SPLIT in field_name:
|
||||
# Extract the index part after _$_
|
||||
parts = field_name.split(LIST_SPLIT)
|
||||
if len(parts) > 1:
|
||||
index = parts[1].split("_")[0] if "_" in parts[1] else parts[1]
|
||||
return (
|
||||
f"List item {index} for base field '{base_name}' ({base_name}[{index}])"
|
||||
)
|
||||
elif OBJC_SPLIT in field_name:
|
||||
# Extract the attribute part after _@_
|
||||
parts = field_name.split(OBJC_SPLIT)
|
||||
if len(parts) > 1:
|
||||
# Get the full attribute name (everything after _@_)
|
||||
attr = parts[1]
|
||||
return f"Object attribute '{attr}' for base field '{base_name}' ({base_name}.{attr})"
|
||||
|
||||
return f"Value for {field_name}"
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Dynamic field parsing and merging utilities
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def _next_delim(s: str) -> tuple[str | None, int]:
|
||||
"""
|
||||
Return the *earliest* delimiter appearing in `s` and its index.
|
||||
|
||||
If none present → (None, -1).
|
||||
"""
|
||||
first: str | None = None
|
||||
pos = len(s) # sentinel: larger than any real index
|
||||
for d in DYNAMIC_DELIMITERS:
|
||||
i = s.find(d)
|
||||
if 0 <= i < pos:
|
||||
first, pos = d, i
|
||||
return first, (pos if first else -1)
|
||||
|
||||
|
||||
def _tokenise(path: str) -> list[tuple[str, str]] | None:
|
||||
"""
|
||||
Convert the raw path string (starting with a delimiter) into
|
||||
[ (delimiter, identifier), … ] or None if the syntax is malformed.
|
||||
"""
|
||||
tokens: list[tuple[str, str]] = []
|
||||
while path:
|
||||
# 1. Which delimiter starts this chunk?
|
||||
delim = next((d for d in DYNAMIC_DELIMITERS if path.startswith(d)), None)
|
||||
if delim is None:
|
||||
return None # invalid syntax
|
||||
|
||||
# 2. Slice off the delimiter, then up to the next delimiter (or EOS)
|
||||
path = path[len(delim) :]
|
||||
nxt_delim, pos = _next_delim(path)
|
||||
token, path = (
|
||||
path[: pos if pos != -1 else len(path)],
|
||||
path[pos if pos != -1 else len(path) :],
|
||||
)
|
||||
if token == "":
|
||||
return None # empty identifier is invalid
|
||||
tokens.append((delim, token))
|
||||
return tokens
|
||||
|
||||
|
||||
def parse_execution_output(output: tuple[str, Any], name: str) -> Any:
|
||||
"""
|
||||
Retrieve a nested value out of `output` using the flattened *name*.
|
||||
|
||||
On any failure (wrong name, wrong type, out-of-range, bad path)
|
||||
returns **None**.
|
||||
|
||||
Args:
|
||||
output: Tuple of (base_name, data) representing a block output entry
|
||||
name: The flattened field name to extract from the output data
|
||||
|
||||
Returns:
|
||||
The value at the specified path, or None if not found/invalid
|
||||
"""
|
||||
base_name, data = output
|
||||
|
||||
# Exact match → whole object
|
||||
if name == base_name:
|
||||
return data
|
||||
|
||||
# Must start with the expected name
|
||||
if not name.startswith(base_name):
|
||||
return None
|
||||
path = name[len(base_name) :]
|
||||
if not path:
|
||||
return None # nothing left to parse
|
||||
|
||||
tokens = _tokenise(path)
|
||||
if tokens is None:
|
||||
return None
|
||||
|
||||
cur: Any = data
|
||||
for delim, ident in tokens:
|
||||
if delim == LIST_SPLIT:
|
||||
# list[index]
|
||||
try:
|
||||
idx = int(ident)
|
||||
except ValueError:
|
||||
return None
|
||||
if not isinstance(cur, list) or idx >= len(cur):
|
||||
return None
|
||||
cur = cur[idx]
|
||||
|
||||
elif delim == DICT_SPLIT:
|
||||
if not isinstance(cur, dict) or ident not in cur:
|
||||
return None
|
||||
cur = cur[ident]
|
||||
|
||||
elif delim == OBJC_SPLIT:
|
||||
if not hasattr(cur, ident):
|
||||
return None
|
||||
cur = getattr(cur, ident)
|
||||
|
||||
else:
|
||||
return None # unreachable
|
||||
|
||||
return cur
|
||||
|
||||
|
||||
def _assign(container: Any, tokens: list[tuple[str, str]], value: Any) -> Any:
|
||||
"""
|
||||
Recursive helper that *returns* the (possibly new) container with
|
||||
`value` assigned along the remaining `tokens` path.
|
||||
"""
|
||||
if not tokens:
|
||||
return value # leaf reached
|
||||
|
||||
delim, ident = tokens[0]
|
||||
rest = tokens[1:]
|
||||
|
||||
# ---------- list ----------
|
||||
if delim == LIST_SPLIT:
|
||||
try:
|
||||
idx = int(ident)
|
||||
except ValueError:
|
||||
raise ValueError("index must be an integer")
|
||||
|
||||
if container is None:
|
||||
container = []
|
||||
elif not isinstance(container, list):
|
||||
container = list(container) if hasattr(container, "__iter__") else []
|
||||
|
||||
while len(container) <= idx:
|
||||
container.append(None)
|
||||
container[idx] = _assign(container[idx], rest, value)
|
||||
return container
|
||||
|
||||
# ---------- dict ----------
|
||||
if delim == DICT_SPLIT:
|
||||
if container is None:
|
||||
container = {}
|
||||
elif not isinstance(container, dict):
|
||||
container = dict(container) if hasattr(container, "items") else {}
|
||||
container[ident] = _assign(container.get(ident), rest, value)
|
||||
return container
|
||||
|
||||
# ---------- object ----------
|
||||
if delim == OBJC_SPLIT:
|
||||
if container is None:
|
||||
container = MockObject()
|
||||
elif not hasattr(container, "__dict__"):
|
||||
# If it's not an object, create a new one
|
||||
container = MockObject()
|
||||
setattr(
|
||||
container,
|
||||
ident,
|
||||
_assign(getattr(container, ident, None), rest, value),
|
||||
)
|
||||
return container
|
||||
|
||||
return value # unreachable
|
||||
|
||||
|
||||
def merge_execution_input(data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Reconstruct nested objects from a *flattened* dict of key → value.
|
||||
|
||||
Raises ValueError on syntactically invalid list indices.
|
||||
|
||||
Args:
|
||||
data: Dictionary with potentially flattened dynamic field keys
|
||||
|
||||
Returns:
|
||||
Dictionary with nested objects reconstructed from flattened keys
|
||||
"""
|
||||
merged: dict[str, Any] = {}
|
||||
|
||||
for key, value in data.items():
|
||||
# Split off the base name (before the first delimiter, if any)
|
||||
delim, pos = _next_delim(key)
|
||||
if delim is None:
|
||||
merged[key] = value
|
||||
continue
|
||||
|
||||
base, path = key[:pos], key[pos:]
|
||||
tokens = _tokenise(path)
|
||||
if tokens is None:
|
||||
# Invalid key; treat as scalar under the raw name
|
||||
merged[key] = value
|
||||
continue
|
||||
|
||||
merged[base] = _assign(merged.get(base), tokens, value)
|
||||
|
||||
data.update(merged)
|
||||
return data
|
||||
@@ -20,6 +20,7 @@ from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.db import prisma as db
|
||||
from backend.data.dynamic_fields import extract_base_field_name
|
||||
from backend.data.includes import MAX_GRAPH_VERSIONS_FETCH
|
||||
from backend.data.model import (
|
||||
CredentialsField,
|
||||
@@ -741,7 +742,7 @@ def _is_tool_pin(name: str) -> bool:
|
||||
|
||||
|
||||
def _sanitize_pin_name(name: str) -> str:
|
||||
sanitized_name = name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
|
||||
sanitized_name = extract_base_field_name(name)
|
||||
if _is_tool_pin(sanitized_name):
|
||||
return "tools"
|
||||
return sanitized_name
|
||||
|
||||
@@ -25,6 +25,7 @@ from backend.data.block import (
|
||||
get_block,
|
||||
)
|
||||
from backend.data.credit import UsageTransactionMetadata
|
||||
from backend.data.dynamic_fields import parse_execution_output
|
||||
from backend.data.execution import (
|
||||
ExecutionQueue,
|
||||
ExecutionStatus,
|
||||
@@ -59,7 +60,6 @@ from backend.executor.utils import (
|
||||
block_usage_cost,
|
||||
create_execution_queue_config,
|
||||
execution_usage_cost,
|
||||
parse_execution_output,
|
||||
validate_exec,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
|
||||
@@ -4,7 +4,7 @@ import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import Future
|
||||
from typing import Any, Mapping, Optional, cast
|
||||
from typing import Mapping, Optional, cast
|
||||
|
||||
from pydantic import BaseModel, JsonValue, ValidationError
|
||||
|
||||
@@ -20,6 +20,9 @@ from backend.data.block import (
|
||||
)
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.db import prisma
|
||||
|
||||
# Import dynamic field utilities from centralized location
|
||||
from backend.data.dynamic_fields import merge_execution_input
|
||||
from backend.data.execution import (
|
||||
ExecutionStatus,
|
||||
GraphExecutionStats,
|
||||
@@ -39,7 +42,6 @@ from backend.util.clients import (
|
||||
)
|
||||
from backend.util.exceptions import GraphValidationError, NotFoundError
|
||||
from backend.util.logging import TruncatedLogger
|
||||
from backend.util.mock import MockObject
|
||||
from backend.util.settings import Config
|
||||
from backend.util.type import convert
|
||||
|
||||
@@ -186,195 +188,7 @@ def _is_cost_filter_match(cost_filter: BlockInput, input_data: BlockInput) -> bo
|
||||
|
||||
# ============ Execution Input Helpers ============ #
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Delimiters
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
LIST_SPLIT = "_$_"
|
||||
DICT_SPLIT = "_#_"
|
||||
OBJC_SPLIT = "_@_"
|
||||
|
||||
_DELIMS = (LIST_SPLIT, DICT_SPLIT, OBJC_SPLIT)
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Tokenisation utilities
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def _next_delim(s: str) -> tuple[str | None, int]:
|
||||
"""
|
||||
Return the *earliest* delimiter appearing in `s` and its index.
|
||||
|
||||
If none present → (None, -1).
|
||||
"""
|
||||
first: str | None = None
|
||||
pos = len(s) # sentinel: larger than any real index
|
||||
for d in _DELIMS:
|
||||
i = s.find(d)
|
||||
if 0 <= i < pos:
|
||||
first, pos = d, i
|
||||
return first, (pos if first else -1)
|
||||
|
||||
|
||||
def _tokenise(path: str) -> list[tuple[str, str]] | None:
|
||||
"""
|
||||
Convert the raw path string (starting with a delimiter) into
|
||||
[ (delimiter, identifier), … ] or None if the syntax is malformed.
|
||||
"""
|
||||
tokens: list[tuple[str, str]] = []
|
||||
while path:
|
||||
# 1. Which delimiter starts this chunk?
|
||||
delim = next((d for d in _DELIMS if path.startswith(d)), None)
|
||||
if delim is None:
|
||||
return None # invalid syntax
|
||||
|
||||
# 2. Slice off the delimiter, then up to the next delimiter (or EOS)
|
||||
path = path[len(delim) :]
|
||||
nxt_delim, pos = _next_delim(path)
|
||||
token, path = (
|
||||
path[: pos if pos != -1 else len(path)],
|
||||
path[pos if pos != -1 else len(path) :],
|
||||
)
|
||||
if token == "":
|
||||
return None # empty identifier is invalid
|
||||
tokens.append((delim, token))
|
||||
return tokens
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Public API – parsing (flattened ➜ concrete)
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def parse_execution_output(output: BlockOutputEntry, name: str) -> JsonValue | None:
|
||||
"""
|
||||
Retrieve a nested value out of `output` using the flattened *name*.
|
||||
|
||||
On any failure (wrong name, wrong type, out-of-range, bad path)
|
||||
returns **None**.
|
||||
"""
|
||||
base_name, data = output
|
||||
|
||||
# Exact match → whole object
|
||||
if name == base_name:
|
||||
return data
|
||||
|
||||
# Must start with the expected name
|
||||
if not name.startswith(base_name):
|
||||
return None
|
||||
path = name[len(base_name) :]
|
||||
if not path:
|
||||
return None # nothing left to parse
|
||||
|
||||
tokens = _tokenise(path)
|
||||
if tokens is None:
|
||||
return None
|
||||
|
||||
cur: JsonValue = data
|
||||
for delim, ident in tokens:
|
||||
if delim == LIST_SPLIT:
|
||||
# list[index]
|
||||
try:
|
||||
idx = int(ident)
|
||||
except ValueError:
|
||||
return None
|
||||
if not isinstance(cur, list) or idx >= len(cur):
|
||||
return None
|
||||
cur = cur[idx]
|
||||
|
||||
elif delim == DICT_SPLIT:
|
||||
if not isinstance(cur, dict) or ident not in cur:
|
||||
return None
|
||||
cur = cur[ident]
|
||||
|
||||
elif delim == OBJC_SPLIT:
|
||||
if not hasattr(cur, ident):
|
||||
return None
|
||||
cur = getattr(cur, ident)
|
||||
|
||||
else:
|
||||
return None # unreachable
|
||||
|
||||
return cur
|
||||
|
||||
|
||||
def _assign(container: Any, tokens: list[tuple[str, str]], value: Any) -> Any:
|
||||
"""
|
||||
Recursive helper that *returns* the (possibly new) container with
|
||||
`value` assigned along the remaining `tokens` path.
|
||||
"""
|
||||
if not tokens:
|
||||
return value # leaf reached
|
||||
|
||||
delim, ident = tokens[0]
|
||||
rest = tokens[1:]
|
||||
|
||||
# ---------- list ----------
|
||||
if delim == LIST_SPLIT:
|
||||
try:
|
||||
idx = int(ident)
|
||||
except ValueError:
|
||||
raise ValueError("index must be an integer")
|
||||
|
||||
if container is None:
|
||||
container = []
|
||||
elif not isinstance(container, list):
|
||||
container = list(container) if hasattr(container, "__iter__") else []
|
||||
|
||||
while len(container) <= idx:
|
||||
container.append(None)
|
||||
container[idx] = _assign(container[idx], rest, value)
|
||||
return container
|
||||
|
||||
# ---------- dict ----------
|
||||
if delim == DICT_SPLIT:
|
||||
if container is None:
|
||||
container = {}
|
||||
elif not isinstance(container, dict):
|
||||
container = dict(container) if hasattr(container, "items") else {}
|
||||
container[ident] = _assign(container.get(ident), rest, value)
|
||||
return container
|
||||
|
||||
# ---------- object ----------
|
||||
if delim == OBJC_SPLIT:
|
||||
if container is None or not isinstance(container, MockObject):
|
||||
container = MockObject()
|
||||
setattr(
|
||||
container,
|
||||
ident,
|
||||
_assign(getattr(container, ident, None), rest, value),
|
||||
)
|
||||
return container
|
||||
|
||||
return value # unreachable
|
||||
|
||||
|
||||
def merge_execution_input(data: BlockInput) -> BlockInput:
|
||||
"""
|
||||
Reconstruct nested objects from a *flattened* dict of key → value.
|
||||
|
||||
Raises ValueError on syntactically invalid list indices.
|
||||
"""
|
||||
merged: BlockInput = {}
|
||||
|
||||
for key, value in data.items():
|
||||
# Split off the base name (before the first delimiter, if any)
|
||||
delim, pos = _next_delim(key)
|
||||
if delim is None:
|
||||
merged[key] = value
|
||||
continue
|
||||
|
||||
base, path = key[:pos], key[pos:]
|
||||
tokens = _tokenise(path)
|
||||
if tokens is None:
|
||||
# Invalid key; treat as scalar under the raw name
|
||||
merged[key] = value
|
||||
continue
|
||||
|
||||
merged[base] = _assign(merged.get(base), tokens, value)
|
||||
|
||||
data.update(merged)
|
||||
return data
|
||||
# Dynamic field utilities are now imported from backend.data.dynamic_fields
|
||||
|
||||
|
||||
def validate_exec(
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import cast
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from backend.executor.utils import merge_execution_input, parse_execution_output
|
||||
from backend.data.dynamic_fields import merge_execution_input, parse_execution_output
|
||||
from backend.util.mock import MockObject
|
||||
|
||||
|
||||
|
||||
@@ -180,7 +180,7 @@ async def callback(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/credentials")
|
||||
@router.get("/credentials", summary="List Credentials")
|
||||
async def list_credentials(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
@@ -221,7 +221,9 @@ async def list_credentials_by_provider(
|
||||
]
|
||||
|
||||
|
||||
@router.get("/{provider}/credentials/{cred_id}")
|
||||
@router.get(
|
||||
"/{provider}/credentials/{cred_id}", summary="Get Specific Credential By ID"
|
||||
)
|
||||
async def get_credential(
|
||||
provider: Annotated[
|
||||
ProviderName, Path(title="The provider to retrieve credentials for")
|
||||
@@ -242,7 +244,7 @@ async def get_credential(
|
||||
return credential
|
||||
|
||||
|
||||
@router.post("/{provider}/credentials", status_code=201)
|
||||
@router.post("/{provider}/credentials", status_code=201, summary="Create Credentials")
|
||||
async def create_credentials(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
provider: Annotated[
|
||||
|
||||
124
autogpt_platform/backend/backend/util/dynamic_fields.py
Normal file
124
autogpt_platform/backend/backend/util/dynamic_fields.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
Utilities for handling dynamic field names and delimiters in the AutoGPT Platform.
|
||||
|
||||
Dynamic fields allow graphs to connect complex data structures using special delimiters:
|
||||
- _#_ for dictionary keys (e.g., "values_#_name" → values["name"])
|
||||
- _$_ for list indices (e.g., "items_$_0" → items[0])
|
||||
- _@_ for object attributes (e.g., "obj_@_attr" → obj.attr)
|
||||
|
||||
This module provides utilities for:
|
||||
- Extracting base field names from dynamic field names
|
||||
- Generating proper schemas for base fields
|
||||
- Creating helper functions for field sanitization
|
||||
"""
|
||||
|
||||
from backend.data.dynamic_fields import DICT_SPLIT, LIST_SPLIT, OBJC_SPLIT
|
||||
|
||||
# All dynamic field delimiters
|
||||
DYNAMIC_DELIMITERS = (DICT_SPLIT, LIST_SPLIT, OBJC_SPLIT)
|
||||
|
||||
|
||||
def extract_base_field_name(field_name: str) -> str:
|
||||
"""
|
||||
Extract the base field name from a dynamic field name.
|
||||
|
||||
Examples:
|
||||
extract_base_field_name("values_#_name") → "values"
|
||||
extract_base_field_name("items_$_0") → "items"
|
||||
extract_base_field_name("obj_@_attr") → "obj"
|
||||
extract_base_field_name("regular_field") → "regular_field"
|
||||
|
||||
Args:
|
||||
field_name: The field name that may contain dynamic delimiters
|
||||
|
||||
Returns:
|
||||
The base field name without any dynamic suffixes
|
||||
"""
|
||||
base_name = field_name
|
||||
for delimiter in DYNAMIC_DELIMITERS:
|
||||
if delimiter in base_name:
|
||||
base_name = base_name.split(delimiter)[0]
|
||||
return base_name
|
||||
|
||||
|
||||
def is_dynamic_field(field_name: str) -> bool:
|
||||
"""
|
||||
Check if a field name contains dynamic delimiters.
|
||||
|
||||
Args:
|
||||
field_name: The field name to check
|
||||
|
||||
Returns:
|
||||
True if the field contains any dynamic delimiters, False otherwise
|
||||
"""
|
||||
return any(delimiter in field_name for delimiter in DYNAMIC_DELIMITERS)
|
||||
|
||||
|
||||
def get_dynamic_field_description(
|
||||
base_field_name: str, original_field_name: str
|
||||
) -> str:
|
||||
"""
|
||||
Generate a description for a dynamic field based on its base field and structure.
|
||||
|
||||
Args:
|
||||
base_field_name: The base field name (e.g., "values")
|
||||
original_field_name: The full dynamic field name (e.g., "values_#_name")
|
||||
|
||||
Returns:
|
||||
A descriptive string explaining what this dynamic field represents
|
||||
"""
|
||||
if DICT_SPLIT in original_field_name:
|
||||
key_part = (
|
||||
original_field_name.split(DICT_SPLIT, 1)[1].split(DICT_SPLIT[0])[0]
|
||||
if DICT_SPLIT in original_field_name
|
||||
else "key"
|
||||
)
|
||||
return f"Dictionary value for {base_field_name}['{key_part}']"
|
||||
elif LIST_SPLIT in original_field_name:
|
||||
index_part = (
|
||||
original_field_name.split(LIST_SPLIT, 1)[1].split(LIST_SPLIT[0])[0]
|
||||
if LIST_SPLIT in original_field_name
|
||||
else "index"
|
||||
)
|
||||
return f"List item for {base_field_name}[{index_part}]"
|
||||
elif OBJC_SPLIT in original_field_name:
|
||||
attr_part = (
|
||||
original_field_name.split(OBJC_SPLIT, 1)[1].split(OBJC_SPLIT[0])[0]
|
||||
if OBJC_SPLIT in original_field_name
|
||||
else "attr"
|
||||
)
|
||||
return f"Object attribute for {base_field_name}.{attr_part}"
|
||||
else:
|
||||
return f"Dynamic value for {base_field_name}"
|
||||
|
||||
|
||||
def group_fields_by_base_name(field_names: list[str]) -> dict[str, list[str]]:
|
||||
"""
|
||||
Group a list of field names by their base field names.
|
||||
|
||||
Args:
|
||||
field_names: List of field names that may contain dynamic delimiters
|
||||
|
||||
Returns:
|
||||
Dictionary mapping base field names to lists of original field names
|
||||
|
||||
Example:
|
||||
group_fields_by_base_name([
|
||||
"values_#_name",
|
||||
"values_#_age",
|
||||
"items_$_0",
|
||||
"regular_field"
|
||||
])
|
||||
→ {
|
||||
"values": ["values_#_name", "values_#_age"],
|
||||
"items": ["items_$_0"],
|
||||
"regular_field": ["regular_field"]
|
||||
}
|
||||
"""
|
||||
grouped = {}
|
||||
for field_name in field_names:
|
||||
base_name = extract_base_field_name(field_name)
|
||||
if base_name not in grouped:
|
||||
grouped[base_name] = []
|
||||
grouped[base_name].append(field_name)
|
||||
return grouped
|
||||
175
autogpt_platform/backend/backend/util/dynamic_fields_test.py
Normal file
175
autogpt_platform/backend/backend/util/dynamic_fields_test.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""Tests for dynamic field utilities."""
|
||||
|
||||
from backend.util.dynamic_fields import (
|
||||
extract_base_field_name,
|
||||
get_dynamic_field_description,
|
||||
group_fields_by_base_name,
|
||||
is_dynamic_field,
|
||||
)
|
||||
|
||||
|
||||
class TestExtractBaseFieldName:
|
||||
"""Test extracting base field names from dynamic field names."""
|
||||
|
||||
def test_extract_dict_field(self):
|
||||
"""Test extracting base name from dictionary fields."""
|
||||
assert extract_base_field_name("values_#_name") == "values"
|
||||
assert extract_base_field_name("data_#_key1_#_key2") == "data"
|
||||
assert extract_base_field_name("config_#_database_#_host") == "config"
|
||||
|
||||
def test_extract_list_field(self):
|
||||
"""Test extracting base name from list fields."""
|
||||
assert extract_base_field_name("items_$_0") == "items"
|
||||
assert extract_base_field_name("results_$_5_$_10") == "results"
|
||||
assert extract_base_field_name("nested_$_0_$_1_$_2") == "nested"
|
||||
|
||||
def test_extract_object_field(self):
|
||||
"""Test extracting base name from object fields."""
|
||||
assert extract_base_field_name("user_@_name") == "user"
|
||||
assert extract_base_field_name("response_@_data_@_items") == "response"
|
||||
assert extract_base_field_name("obj_@_attr1_@_attr2") == "obj"
|
||||
|
||||
def test_extract_mixed_fields(self):
|
||||
"""Test extracting base name from mixed dynamic fields."""
|
||||
assert extract_base_field_name("data_$_0_#_key") == "data"
|
||||
assert extract_base_field_name("items_#_user_@_name") == "items"
|
||||
assert extract_base_field_name("complex_$_0_@_attr_#_key") == "complex"
|
||||
|
||||
def test_extract_regular_field(self):
|
||||
"""Test extracting base name from regular (non-dynamic) fields."""
|
||||
assert extract_base_field_name("regular_field") == "regular_field"
|
||||
assert extract_base_field_name("simple") == "simple"
|
||||
assert extract_base_field_name("") == ""
|
||||
|
||||
def test_extract_field_with_underscores(self):
|
||||
"""Test fields with regular underscores (not dynamic delimiters)."""
|
||||
assert extract_base_field_name("field_name_here") == "field_name_here"
|
||||
assert extract_base_field_name("my_field_#_key") == "my_field"
|
||||
|
||||
|
||||
class TestIsDynamicField:
|
||||
"""Test identifying dynamic fields."""
|
||||
|
||||
def test_is_dynamic_dict_field(self):
|
||||
"""Test identifying dictionary dynamic fields."""
|
||||
assert is_dynamic_field("values_#_name") is True
|
||||
assert is_dynamic_field("data_#_key1_#_key2") is True
|
||||
|
||||
def test_is_dynamic_list_field(self):
|
||||
"""Test identifying list dynamic fields."""
|
||||
assert is_dynamic_field("items_$_0") is True
|
||||
assert is_dynamic_field("results_$_5_$_10") is True
|
||||
|
||||
def test_is_dynamic_object_field(self):
|
||||
"""Test identifying object dynamic fields."""
|
||||
assert is_dynamic_field("user_@_name") is True
|
||||
assert is_dynamic_field("response_@_data_@_items") is True
|
||||
|
||||
def test_is_dynamic_mixed_field(self):
|
||||
"""Test identifying mixed dynamic fields."""
|
||||
assert is_dynamic_field("data_$_0_#_key") is True
|
||||
assert is_dynamic_field("items_#_user_@_name") is True
|
||||
|
||||
def test_is_not_dynamic_field(self):
|
||||
"""Test identifying non-dynamic fields."""
|
||||
assert is_dynamic_field("regular_field") is False
|
||||
assert is_dynamic_field("field_name_here") is False
|
||||
assert is_dynamic_field("simple") is False
|
||||
assert is_dynamic_field("") is False
|
||||
|
||||
|
||||
class TestGetDynamicFieldDescription:
|
||||
"""Test generating descriptions for dynamic fields."""
|
||||
|
||||
def test_dict_field_description(self):
|
||||
"""Test descriptions for dictionary fields."""
|
||||
desc = get_dynamic_field_description("values", "values_#_name")
|
||||
assert "Dictionary value for values['name']" == desc
|
||||
|
||||
desc = get_dynamic_field_description("config", "config_#_database")
|
||||
assert "Dictionary value for config['database']" == desc
|
||||
|
||||
def test_list_field_description(self):
|
||||
"""Test descriptions for list fields."""
|
||||
desc = get_dynamic_field_description("items", "items_$_0")
|
||||
assert "List item for items[0]" == desc
|
||||
|
||||
desc = get_dynamic_field_description("results", "results_$_5")
|
||||
assert "List item for results[5]" == desc
|
||||
|
||||
def test_object_field_description(self):
|
||||
"""Test descriptions for object fields."""
|
||||
desc = get_dynamic_field_description("user", "user_@_name")
|
||||
assert "Object attribute for user.name" == desc
|
||||
|
||||
desc = get_dynamic_field_description("response", "response_@_data")
|
||||
assert "Object attribute for response.data" == desc
|
||||
|
||||
def test_fallback_description(self):
|
||||
"""Test fallback description for non-dynamic fields."""
|
||||
desc = get_dynamic_field_description("field", "field")
|
||||
assert "Dynamic value for field" == desc
|
||||
|
||||
|
||||
class TestGroupFieldsByBaseName:
|
||||
"""Test grouping fields by their base names."""
|
||||
|
||||
def test_group_mixed_fields(self):
|
||||
"""Test grouping a mix of dynamic and regular fields."""
|
||||
fields = [
|
||||
"values_#_name",
|
||||
"values_#_age",
|
||||
"items_$_0",
|
||||
"items_$_1",
|
||||
"user_@_email",
|
||||
"regular_field",
|
||||
"another_field",
|
||||
]
|
||||
|
||||
result = group_fields_by_base_name(fields)
|
||||
|
||||
expected = {
|
||||
"values": ["values_#_name", "values_#_age"],
|
||||
"items": ["items_$_0", "items_$_1"],
|
||||
"user": ["user_@_email"],
|
||||
"regular_field": ["regular_field"],
|
||||
"another_field": ["another_field"],
|
||||
}
|
||||
|
||||
assert result == expected
|
||||
|
||||
def test_group_empty_list(self):
|
||||
"""Test grouping an empty list."""
|
||||
result = group_fields_by_base_name([])
|
||||
assert result == {}
|
||||
|
||||
def test_group_single_field(self):
|
||||
"""Test grouping a single field."""
|
||||
result = group_fields_by_base_name(["values_#_name"])
|
||||
assert result == {"values": ["values_#_name"]}
|
||||
|
||||
def test_group_complex_dynamic_fields(self):
|
||||
"""Test grouping complex nested dynamic fields."""
|
||||
fields = [
|
||||
"data_$_0_#_key1",
|
||||
"data_$_0_#_key2",
|
||||
"data_$_1_#_key1",
|
||||
"other_@_attr",
|
||||
]
|
||||
|
||||
result = group_fields_by_base_name(fields)
|
||||
|
||||
expected = {
|
||||
"data": ["data_$_0_#_key1", "data_$_0_#_key2", "data_$_1_#_key1"],
|
||||
"other": ["other_@_attr"],
|
||||
}
|
||||
|
||||
assert result == expected
|
||||
|
||||
def test_preserve_order(self):
|
||||
"""Test that field order is preserved within groups."""
|
||||
fields = ["values_#_c", "values_#_a", "values_#_b"]
|
||||
result = group_fields_by_base_name(fields)
|
||||
|
||||
# Should preserve the original order
|
||||
assert result["values"] == ["values_#_c", "values_#_a", "values_#_b"]
|
||||
@@ -19,9 +19,48 @@ def _msg_tokens(msg: dict, enc) -> int:
|
||||
"""
|
||||
OpenAI counts ≈3 wrapper tokens per chat message, plus 1 if "name"
|
||||
is present, plus the tokenised content length.
|
||||
For tool calls, we need to count tokens in tool_calls and content fields.
|
||||
"""
|
||||
WRAPPER = 3 + (1 if "name" in msg else 0)
|
||||
return WRAPPER + _tok_len(msg.get("content") or "", enc)
|
||||
|
||||
# Count content tokens
|
||||
content_tokens = _tok_len(msg.get("content") or "", enc)
|
||||
|
||||
# Count tool call tokens for both OpenAI and Anthropic formats
|
||||
tool_call_tokens = 0
|
||||
|
||||
# OpenAI format: tool_calls array at message level
|
||||
if "tool_calls" in msg and isinstance(msg["tool_calls"], list):
|
||||
for tool_call in msg["tool_calls"]:
|
||||
# Count the tool call structure tokens
|
||||
tool_call_tokens += _tok_len(tool_call.get("id", ""), enc)
|
||||
tool_call_tokens += _tok_len(tool_call.get("type", ""), enc)
|
||||
if "function" in tool_call:
|
||||
tool_call_tokens += _tok_len(tool_call["function"].get("name", ""), enc)
|
||||
tool_call_tokens += _tok_len(
|
||||
tool_call["function"].get("arguments", ""), enc
|
||||
)
|
||||
|
||||
# Anthropic format: tool_use within content array
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if isinstance(item, dict) and item.get("type") == "tool_use":
|
||||
# Count the tool use structure tokens
|
||||
tool_call_tokens += _tok_len(item.get("id", ""), enc)
|
||||
tool_call_tokens += _tok_len(item.get("name", ""), enc)
|
||||
tool_call_tokens += _tok_len(json.dumps(item.get("input", {})), enc)
|
||||
elif isinstance(item, dict) and item.get("type") == "tool_result":
|
||||
# Count tool result tokens
|
||||
tool_call_tokens += _tok_len(item.get("tool_use_id", ""), enc)
|
||||
tool_call_tokens += _tok_len(item.get("content", ""), enc)
|
||||
elif isinstance(item, dict) and "content" in item:
|
||||
# Other content types with content field
|
||||
tool_call_tokens += _tok_len(item.get("content", ""), enc)
|
||||
# For list content, override content_tokens since we counted everything above
|
||||
content_tokens = 0
|
||||
|
||||
return WRAPPER + content_tokens + tool_call_tokens
|
||||
|
||||
|
||||
def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
|
||||
|
||||
278
autogpt_platform/backend/backend/util/prompt_test.py
Normal file
278
autogpt_platform/backend/backend/util/prompt_test.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""Tests for prompt utility functions, especially tool call token counting."""
|
||||
|
||||
import pytest
|
||||
from tiktoken import encoding_for_model
|
||||
|
||||
from backend.util import json
|
||||
from backend.util.prompt import _msg_tokens, estimate_token_count
|
||||
|
||||
|
||||
class TestMsgTokens:
|
||||
"""Test the _msg_tokens function with various message types."""
|
||||
|
||||
@pytest.fixture
|
||||
def enc(self):
|
||||
"""Get the encoding for gpt-4o model."""
|
||||
return encoding_for_model("gpt-4o")
|
||||
|
||||
def test_regular_message_token_counting(self, enc):
|
||||
"""Test that regular messages are counted correctly (backward compatibility)."""
|
||||
msg = {"role": "user", "content": "What's the weather like in San Francisco?"}
|
||||
|
||||
tokens = _msg_tokens(msg, enc)
|
||||
|
||||
# Should be wrapper (3) + content tokens
|
||||
expected = 3 + len(enc.encode(msg["content"]))
|
||||
assert tokens == expected
|
||||
assert tokens > 3 # Has content
|
||||
|
||||
def test_regular_message_with_name(self, enc):
|
||||
"""Test that messages with name field get extra wrapper token."""
|
||||
msg = {"role": "user", "name": "test_user", "content": "Hello!"}
|
||||
|
||||
tokens = _msg_tokens(msg, enc)
|
||||
|
||||
# Should be wrapper (3 + 1 for name) + content tokens
|
||||
expected = 4 + len(enc.encode(msg["content"]))
|
||||
assert tokens == expected
|
||||
|
||||
def test_openai_tool_call_token_counting(self, enc):
|
||||
"""Test OpenAI format tool call token counting."""
|
||||
msg = {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_abc123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "San Francisco", "unit": "celsius"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
tokens = _msg_tokens(msg, enc)
|
||||
|
||||
# Should count wrapper + all tool call components
|
||||
expected_tool_tokens = (
|
||||
len(enc.encode("call_abc123"))
|
||||
+ len(enc.encode("function"))
|
||||
+ len(enc.encode("get_weather"))
|
||||
+ len(enc.encode('{"location": "San Francisco", "unit": "celsius"}'))
|
||||
)
|
||||
expected = 3 + expected_tool_tokens # wrapper + tool tokens
|
||||
|
||||
assert tokens == expected
|
||||
assert tokens > 8 # Should be significantly more than just wrapper
|
||||
|
||||
def test_openai_multiple_tool_calls(self, enc):
|
||||
"""Test OpenAI format with multiple tool calls."""
|
||||
msg = {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {"name": "func1", "arguments": '{"arg": "value1"}'},
|
||||
},
|
||||
{
|
||||
"id": "call_2",
|
||||
"type": "function",
|
||||
"function": {"name": "func2", "arguments": '{"arg": "value2"}'},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
tokens = _msg_tokens(msg, enc)
|
||||
|
||||
# Should count all tool calls
|
||||
assert tokens > 20 # Should be more than single tool call
|
||||
|
||||
def test_anthropic_tool_use_token_counting(self, enc):
|
||||
"""Test Anthropic format tool use token counting."""
|
||||
msg = {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_xyz456",
|
||||
"name": "get_weather",
|
||||
"input": {"location": "San Francisco", "unit": "celsius"},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
tokens = _msg_tokens(msg, enc)
|
||||
|
||||
# Should count wrapper + tool use components
|
||||
expected_tool_tokens = (
|
||||
len(enc.encode("toolu_xyz456"))
|
||||
+ len(enc.encode("get_weather"))
|
||||
+ len(
|
||||
enc.encode(json.dumps({"location": "San Francisco", "unit": "celsius"}))
|
||||
)
|
||||
)
|
||||
expected = 3 + expected_tool_tokens # wrapper + tool tokens
|
||||
|
||||
assert tokens == expected
|
||||
assert tokens > 8 # Should be significantly more than just wrapper
|
||||
|
||||
def test_anthropic_tool_result_token_counting(self, enc):
|
||||
"""Test Anthropic format tool result token counting."""
|
||||
msg = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_xyz456",
|
||||
"content": "The weather in San Francisco is 22°C and sunny.",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
tokens = _msg_tokens(msg, enc)
|
||||
|
||||
# Should count wrapper + tool result components
|
||||
expected_tool_tokens = len(enc.encode("toolu_xyz456")) + len(
|
||||
enc.encode("The weather in San Francisco is 22°C and sunny.")
|
||||
)
|
||||
expected = 3 + expected_tool_tokens # wrapper + tool tokens
|
||||
|
||||
assert tokens == expected
|
||||
assert tokens > 8 # Should be significantly more than just wrapper
|
||||
|
||||
def test_anthropic_mixed_content(self, enc):
|
||||
"""Test Anthropic format with mixed content types."""
|
||||
msg = {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "content": "I'll check the weather for you."},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_123",
|
||||
"name": "get_weather",
|
||||
"input": {"location": "SF"},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
tokens = _msg_tokens(msg, enc)
|
||||
|
||||
# Should count all content items
|
||||
assert tokens > 15 # Should count both text and tool use
|
||||
|
||||
def test_empty_content(self, enc):
|
||||
"""Test message with empty or None content."""
|
||||
msg = {"role": "assistant", "content": None}
|
||||
|
||||
tokens = _msg_tokens(msg, enc)
|
||||
assert tokens == 3 # Just wrapper tokens
|
||||
|
||||
msg["content"] = ""
|
||||
tokens = _msg_tokens(msg, enc)
|
||||
assert tokens == 3 # Just wrapper tokens
|
||||
|
||||
def test_string_content_with_tool_calls(self, enc):
|
||||
"""Test OpenAI format where content is string but tool_calls exist."""
|
||||
msg = {
|
||||
"role": "assistant",
|
||||
"content": "Let me check that for you.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_123",
|
||||
"type": "function",
|
||||
"function": {"name": "test_func", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
tokens = _msg_tokens(msg, enc)
|
||||
|
||||
# Should count both content and tool calls
|
||||
content_tokens = len(enc.encode("Let me check that for you."))
|
||||
tool_tokens = (
|
||||
len(enc.encode("call_123"))
|
||||
+ len(enc.encode("function"))
|
||||
+ len(enc.encode("test_func"))
|
||||
+ len(enc.encode("{}"))
|
||||
)
|
||||
expected = 3 + content_tokens + tool_tokens
|
||||
|
||||
assert tokens == expected
|
||||
|
||||
|
||||
class TestEstimateTokenCount:
|
||||
"""Test the estimate_token_count function with conversations containing tool calls."""
|
||||
|
||||
def test_conversation_with_tool_calls(self):
|
||||
"""Test token counting for a complete conversation with tool calls."""
|
||||
conversation = [
|
||||
{"role": "user", "content": "What's the weather like in San Francisco?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_123",
|
||||
"name": "get_weather",
|
||||
"input": {"location": "San Francisco"},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_123",
|
||||
"content": "22°C and sunny",
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "The weather in San Francisco is 22°C and sunny.",
|
||||
},
|
||||
]
|
||||
|
||||
total_tokens = estimate_token_count(conversation)
|
||||
|
||||
# Verify total equals sum of individual messages
|
||||
enc = encoding_for_model("gpt-4o")
|
||||
expected_total = sum(_msg_tokens(msg, enc) for msg in conversation)
|
||||
|
||||
assert total_tokens == expected_total
|
||||
assert total_tokens > 40 # Should be substantial for this conversation
|
||||
|
||||
def test_openai_conversation(self):
|
||||
"""Test token counting for OpenAI format conversation."""
|
||||
conversation = [
|
||||
{"role": "user", "content": "Calculate 2 + 2"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_calc",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "calculate",
|
||||
"arguments": '{"expression": "2 + 2"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_calc", "content": "4"},
|
||||
{"role": "assistant", "content": "The result is 4."},
|
||||
]
|
||||
|
||||
total_tokens = estimate_token_count(conversation)
|
||||
|
||||
# Verify total equals sum of individual messages
|
||||
enc = encoding_for_model("gpt-4o")
|
||||
expected_total = sum(_msg_tokens(msg, enc) for msg in conversation)
|
||||
|
||||
assert total_tokens == expected_total
|
||||
assert total_tokens > 20 # Should be substantial
|
||||
@@ -28,6 +28,7 @@ from fastapi import FastAPI, Request, responses
|
||||
from pydantic import BaseModel, TypeAdapter, create_model
|
||||
|
||||
import backend.util.exceptions as exceptions
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
from backend.util.json import to_dict
|
||||
from backend.util.metrics import sentry_init
|
||||
from backend.util.process import AppProcess, get_service_name
|
||||
@@ -283,6 +284,24 @@ class AppService(BaseAppService, ABC):
|
||||
super().run()
|
||||
self.fastapi_app = FastAPI()
|
||||
|
||||
# Add Prometheus instrumentation to all services
|
||||
try:
|
||||
instrument_fastapi(
|
||||
self.fastapi_app,
|
||||
service_name=self.service_name,
|
||||
expose_endpoint=True,
|
||||
endpoint="/metrics",
|
||||
include_in_schema=False,
|
||||
)
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
f"Prometheus instrumentation not available for {self.service_name}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to instrument {self.service_name} with Prometheus: {e}"
|
||||
)
|
||||
|
||||
# Register the exposed API routes.
|
||||
for attr_name, attr in vars(type(self)).items():
|
||||
if getattr(attr, EXPOSED_FLAG, False):
|
||||
|
||||
71
autogpt_platform/backend/poetry.lock
generated
71
autogpt_platform/backend/poetry.lock
generated
@@ -1240,14 +1240,14 @@ tests = ["coverage", "coveralls", "dill", "mock", "nose"]
|
||||
|
||||
[[package]]
|
||||
name = "faker"
|
||||
version = "37.6.0"
|
||||
version = "37.8.0"
|
||||
description = "Faker is a Python package that generates fake data for you."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "faker-37.6.0-py3-none-any.whl", hash = "sha256:3c5209b23d7049d596a51db5d76403a0ccfea6fc294ffa2ecfef6a8843b1e6a7"},
|
||||
{file = "faker-37.6.0.tar.gz", hash = "sha256:0f8cc34f30095184adf87c3c24c45b38b33ad81c35ef6eb0a3118f301143012c"},
|
||||
{file = "faker-37.8.0-py3-none-any.whl", hash = "sha256:b08233118824423b5fc239f7dd51f145e7018082b4164f8da6a9994e1f1ae793"},
|
||||
{file = "faker-37.8.0.tar.gz", hash = "sha256:090bb5abbec2b30949a95ce1ba6b20d1d0ed222883d63483a0d4be4a970d6fb8"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -1339,20 +1339,21 @@ packaging = ">=20"
|
||||
|
||||
[[package]]
|
||||
name = "firecrawl-py"
|
||||
version = "2.16.3"
|
||||
version = "4.3.6"
|
||||
description = "Python SDK for Firecrawl API"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "firecrawl_py-2.16.3-py3-none-any.whl", hash = "sha256:94bb46af5e0df6c8ec414ac999a5355c0f5a46f15fd1cf5a02a3b31062db0aa8"},
|
||||
{file = "firecrawl_py-2.16.3.tar.gz", hash = "sha256:5fd063ef4acc4c4be62648f1e11467336bc127780b3afc28d39078a012e6a14c"},
|
||||
{file = "firecrawl_py-4.3.6-py3-none-any.whl", hash = "sha256:9b5dffdf5ed08fdbf0966f17e18c1a034d59f42a20b2bf9a6291a83190d7eb0f"},
|
||||
{file = "firecrawl_py-4.3.6.tar.gz", hash = "sha256:303827a86d0f6237a8ddcaa0bcdaa4c5ee11d9a4880b0685302b8d9a0e191ee0"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
aiohttp = "*"
|
||||
httpx = "*"
|
||||
nest-asyncio = "*"
|
||||
pydantic = "*"
|
||||
pydantic = ">=2.0"
|
||||
python-dotenv = "*"
|
||||
requests = "*"
|
||||
websockets = "*"
|
||||
@@ -4912,14 +4913,14 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "pyright"
|
||||
version = "1.1.404"
|
||||
version = "1.1.406"
|
||||
description = "Command line wrapper for pyright"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pyright-1.1.404-py3-none-any.whl", hash = "sha256:c7b7ff1fdb7219c643079e4c3e7d4125f0dafcc19d253b47e898d130ea426419"},
|
||||
{file = "pyright-1.1.404.tar.gz", hash = "sha256:455e881a558ca6be9ecca0b30ce08aa78343ecc031d37a198ffa9a7a1abeb63e"},
|
||||
{file = "pyright-1.1.406-py3-none-any.whl", hash = "sha256:1d81fb43c2407bf566e97e57abb01c811973fdb21b2df8df59f870f688bdca71"},
|
||||
{file = "pyright-1.1.406.tar.gz", hash = "sha256:c4872bc58c9643dac09e8a2e74d472c62036910b3bd37a32813989ef7576ea2c"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -4977,14 +4978,14 @@ testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-mock"
|
||||
version = "3.14.1"
|
||||
version = "3.15.1"
|
||||
description = "Thin-wrapper around the mock package for easier use with pytest"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
python-versions = ">=3.9"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pytest_mock-3.14.1-py3-none-any.whl", hash = "sha256:178aefcd11307d874b4cd3100344e7e2d888d9791a6a1d9bfe90fbc1b74fd1d0"},
|
||||
{file = "pytest_mock-3.14.1.tar.gz", hash = "sha256:159e9edac4c451ce77a5cdb9fc5d1100708d2dd4ba3c3df572f14097351af80e"},
|
||||
{file = "pytest_mock-3.15.1-py3-none-any.whl", hash = "sha256:0a25e2eb88fe5168d535041d09a4529a188176ae608a6d249ee65abc0949630d"},
|
||||
{file = "pytest_mock-3.15.1.tar.gz", hash = "sha256:1849a238f6f396da19762269de72cb1814ab44416fa73a8686deac10b0d87a0f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -5764,31 +5765,31 @@ pyasn1 = ">=0.1.3"
|
||||
|
||||
[[package]]
|
||||
name = "ruff"
|
||||
version = "0.12.11"
|
||||
version = "0.13.3"
|
||||
description = "An extremely fast Python linter and code formatter, written in Rust."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "ruff-0.12.11-py3-none-linux_armv6l.whl", hash = "sha256:93fce71e1cac3a8bf9200e63a38ac5c078f3b6baebffb74ba5274fb2ab276065"},
|
||||
{file = "ruff-0.12.11-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b8e33ac7b28c772440afa80cebb972ffd823621ded90404f29e5ab6d1e2d4b93"},
|
||||
{file = "ruff-0.12.11-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d69fb9d4937aa19adb2e9f058bc4fbfe986c2040acb1a4a9747734834eaa0bfd"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:411954eca8464595077a93e580e2918d0a01a19317af0a72132283e28ae21bee"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6a2c0a2e1a450f387bf2c6237c727dd22191ae8c00e448e0672d624b2bbd7fb0"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8ca4c3a7f937725fd2413c0e884b5248a19369ab9bdd850b5781348ba283f644"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:4d1df0098124006f6a66ecf3581a7f7e754c4df7644b2e6704cd7ca80ff95211"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5a8dd5f230efc99a24ace3b77e3555d3fbc0343aeed3fc84c8d89e75ab2ff793"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4dc75533039d0ed04cd33fb8ca9ac9620b99672fe7ff1533b6402206901c34ee"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4fc58f9266d62c6eccc75261a665f26b4ef64840887fc6cbc552ce5b29f96cc8"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:5a0113bd6eafd545146440225fe60b4e9489f59eb5f5f107acd715ba5f0b3d2f"},
|
||||
{file = "ruff-0.12.11-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:0d737b4059d66295c3ea5720e6efc152623bb83fde5444209b69cd33a53e2000"},
|
||||
{file = "ruff-0.12.11-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:916fc5defee32dbc1fc1650b576a8fed68f5e8256e2180d4d9855aea43d6aab2"},
|
||||
{file = "ruff-0.12.11-py3-none-musllinux_1_2_i686.whl", hash = "sha256:c984f07d7adb42d3ded5be894fb4007f30f82c87559438b4879fe7aa08c62b39"},
|
||||
{file = "ruff-0.12.11-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e07fbb89f2e9249f219d88331c833860489b49cdf4b032b8e4432e9b13e8a4b9"},
|
||||
{file = "ruff-0.12.11-py3-none-win32.whl", hash = "sha256:c792e8f597c9c756e9bcd4d87cf407a00b60af77078c96f7b6366ea2ce9ba9d3"},
|
||||
{file = "ruff-0.12.11-py3-none-win_amd64.whl", hash = "sha256:a3283325960307915b6deb3576b96919ee89432ebd9c48771ca12ee8afe4a0fd"},
|
||||
{file = "ruff-0.12.11-py3-none-win_arm64.whl", hash = "sha256:bae4d6e6a2676f8fb0f98b74594a048bae1b944aab17e9f5d504062303c6dbea"},
|
||||
{file = "ruff-0.12.11.tar.gz", hash = "sha256:c6b09ae8426a65bbee5425b9d0b82796dbb07cb1af045743c79bfb163001165d"},
|
||||
{file = "ruff-0.13.3-py3-none-linux_armv6l.whl", hash = "sha256:311860a4c5e19189c89d035638f500c1e191d283d0cc2f1600c8c80d6dcd430c"},
|
||||
{file = "ruff-0.13.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:2bdad6512fb666b40fcadb65e33add2b040fc18a24997d2e47fee7d66f7fcae2"},
|
||||
{file = "ruff-0.13.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fc6fa4637284708d6ed4e5e970d52fc3b76a557d7b4e85a53013d9d201d93286"},
|
||||
{file = "ruff-0.13.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c9e6469864f94a98f412f20ea143d547e4c652f45e44f369d7b74ee78185838"},
|
||||
{file = "ruff-0.13.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5bf62b705f319476c78891e0e97e965b21db468b3c999086de8ffb0d40fd2822"},
|
||||
{file = "ruff-0.13.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:78cc1abed87ce40cb07ee0667ce99dbc766c9f519eabfd948ed87295d8737c60"},
|
||||
{file = "ruff-0.13.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:4fb75e7c402d504f7a9a259e0442b96403fa4a7310ffe3588d11d7e170d2b1e3"},
|
||||
{file = "ruff-0.13.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:17b951f9d9afb39330b2bdd2dd144ce1c1335881c277837ac1b50bfd99985ed3"},
|
||||
{file = "ruff-0.13.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6052f8088728898e0a449f0dde8fafc7ed47e4d878168b211977e3e7e854f662"},
|
||||
{file = "ruff-0.13.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dc742c50f4ba72ce2a3be362bd359aef7d0d302bf7637a6f942eaa763bd292af"},
|
||||
{file = "ruff-0.13.3-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:8e5640349493b378431637019366bbd73c927e515c9c1babfea3e932f5e68e1d"},
|
||||
{file = "ruff-0.13.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:6b139f638a80eae7073c691a5dd8d581e0ba319540be97c343d60fb12949c8d0"},
|
||||
{file = "ruff-0.13.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:6b547def0a40054825de7cfa341039ebdfa51f3d4bfa6a0772940ed351d2746c"},
|
||||
{file = "ruff-0.13.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9cc48a3564423915c93573f1981d57d101e617839bef38504f85f3677b3a0a3e"},
|
||||
{file = "ruff-0.13.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:1a993b17ec03719c502881cb2d5f91771e8742f2ca6de740034433a97c561989"},
|
||||
{file = "ruff-0.13.3-py3-none-win32.whl", hash = "sha256:f14e0d1fe6460f07814d03c6e32e815bff411505178a1f539a38f6097d3e8ee3"},
|
||||
{file = "ruff-0.13.3-py3-none-win_amd64.whl", hash = "sha256:621e2e5812b691d4f244638d693e640f188bacbb9bc793ddd46837cea0503dd2"},
|
||||
{file = "ruff-0.13.3-py3-none-win_arm64.whl", hash = "sha256:9e9e9d699841eaf4c2c798fa783df2fabc680b72059a02ca0ed81c460bc58330"},
|
||||
{file = "ruff-0.13.3.tar.gz", hash = "sha256:5b0ba0db740eefdfbcce4299f49e9eaefc643d4d007749d77d047c2bab19908e"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -7273,4 +7274,4 @@ cffi = ["cffi (>=1.11)"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.14"
|
||||
content-hash = "a19fcce9d4ab88f14eb1e5baa83e1e5a90c9995b04b84dae7cfa6257cf19012c"
|
||||
content-hash = "ff0f6f8d90793ea95f1f7008f7c845432ff46fca0937d5068b4f7cfec0ee7674"
|
||||
|
||||
@@ -78,7 +78,7 @@ aioclamd = "^1.0.0"
|
||||
setuptools = "^80.9.0"
|
||||
gcloud-aio-storage = "^9.5.0"
|
||||
pandas = "^2.3.1"
|
||||
firecrawl-py = "^2.16.3"
|
||||
firecrawl-py = "^4.3.6"
|
||||
exa-py = "^1.14.20"
|
||||
croniter = "^6.0.0"
|
||||
stagehand = "^0.5.1"
|
||||
@@ -86,16 +86,16 @@ stagehand = "^0.5.1"
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
aiohappyeyeballs = "^2.6.1"
|
||||
black = "^24.10.0"
|
||||
faker = "^37.6.0"
|
||||
faker = "^37.8.0"
|
||||
httpx = "^0.28.1"
|
||||
isort = "^5.13.2"
|
||||
poethepoet = "^0.37.0"
|
||||
pre-commit = "^4.3.0"
|
||||
pyright = "^1.1.404"
|
||||
pytest-mock = "^3.14.0"
|
||||
pyright = "^1.1.406"
|
||||
pytest-mock = "^3.15.1"
|
||||
pytest-watcher = "^0.4.2"
|
||||
requests = "^2.32.5"
|
||||
ruff = "^0.12.11"
|
||||
ruff = "^0.13.3"
|
||||
# NOTE: please insert new dependencies in their alphabetical location
|
||||
|
||||
[build-system]
|
||||
|
||||
@@ -31,7 +31,7 @@ Sentry.init({
|
||||
Sentry.extraErrorDataIntegration(),
|
||||
Sentry.browserProfilingIntegration(),
|
||||
Sentry.httpClientIntegration(),
|
||||
// Sentry.launchDarklyIntegration(),
|
||||
Sentry.launchDarklyIntegration(),
|
||||
Sentry.replayIntegration({
|
||||
unmask: [".sentry-unmask, [data-sentry-unmask]"],
|
||||
}),
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
"dependencies": {
|
||||
"@faker-js/faker": "10.0.0",
|
||||
"@hookform/resolvers": "5.2.1",
|
||||
"@marsidev/react-turnstile": "1.3.1",
|
||||
"@next/third-parties": "15.4.6",
|
||||
"@phosphor-icons/react": "2.1.10",
|
||||
"@radix-ui/react-alert-dialog": "1.1.15",
|
||||
|
||||
14
autogpt_platform/frontend/pnpm-lock.yaml
generated
14
autogpt_platform/frontend/pnpm-lock.yaml
generated
@@ -14,6 +14,9 @@ importers:
|
||||
'@hookform/resolvers':
|
||||
specifier: 5.2.1
|
||||
version: 5.2.1(react-hook-form@7.62.0(react@18.3.1))
|
||||
'@marsidev/react-turnstile':
|
||||
specifier: 1.3.1
|
||||
version: 1.3.1(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
'@next/third-parties':
|
||||
specifier: 15.4.6
|
||||
version: 15.4.6(next@15.4.7(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.55.0)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
|
||||
@@ -1428,6 +1431,12 @@ packages:
|
||||
peerDependencies:
|
||||
jsep: ^0.4.0||^1.0.0
|
||||
|
||||
'@marsidev/react-turnstile@1.3.1':
|
||||
resolution: {integrity: sha512-h2THG/75k4Y049hgjSGPIcajxXnh+IZAiXVbryQyVmagkboN7pJtBgR16g8akjwUBSfRrg6jw6KvPDjscQflog==}
|
||||
peerDependencies:
|
||||
react: ^17.0.2 || ^18.0.0 || ^19.0
|
||||
react-dom: ^17.0.2 || ^18.0.0 || ^19.0
|
||||
|
||||
'@mdx-js/react@3.1.1':
|
||||
resolution: {integrity: sha512-f++rKLQgUVYDAtECQ6fn/is15GkEH9+nZPM3MS0RcxVqoTfawHvDlSCH7JbMhAM6uJ32v3eXLvLmLvjGu7PTQw==}
|
||||
peerDependencies:
|
||||
@@ -8668,6 +8677,11 @@ snapshots:
|
||||
dependencies:
|
||||
jsep: 1.4.0
|
||||
|
||||
'@marsidev/react-turnstile@1.3.1(react-dom@18.3.1(react@18.3.1))(react@18.3.1)':
|
||||
dependencies:
|
||||
react: 18.3.1
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
|
||||
'@mdx-js/react@3.1.1(@types/react@18.3.17)(react@18.3.1)':
|
||||
dependencies:
|
||||
'@types/mdx': 2.0.13
|
||||
|
||||
@@ -7,10 +7,8 @@ import { useMemo } from "react";
|
||||
import { CustomNode } from "./nodes/CustomNode";
|
||||
import { useCustomEdge } from "./edges/useCustomEdge";
|
||||
import CustomEdge from "./edges/CustomEdge";
|
||||
import { RightSidebar } from "../RIghtSidebar";
|
||||
|
||||
export const Flow = () => {
|
||||
// All these 3 are working perfectly
|
||||
const nodes = useNodeStore(useShallow((state) => state.nodes));
|
||||
const onNodesChange = useNodeStore(
|
||||
useShallow((state) => state.onNodesChange),
|
||||
@@ -20,7 +18,6 @@ export const Flow = () => {
|
||||
|
||||
return (
|
||||
<div className="flex h-full w-full dark:bg-slate-900">
|
||||
{/* Builder area - flexible width */}
|
||||
<div className="relative flex-1">
|
||||
<ReactFlow
|
||||
nodes={nodes}
|
||||
@@ -36,9 +33,6 @@ export const Flow = () => {
|
||||
<NewControlPanel />
|
||||
</ReactFlow>
|
||||
</div>
|
||||
<div className="w-[30%]">
|
||||
<RightSidebar />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -42,7 +42,7 @@ const CustomEdge = ({
|
||||
<EdgeLabelRenderer>
|
||||
<Button
|
||||
onClick={() => removeConnection(id)}
|
||||
className={`absolute z-10 min-w-0 p-1`}
|
||||
className={`absolute z-10 h-fit min-w-0 p-1`}
|
||||
variant="secondary"
|
||||
style={{
|
||||
transform: `translate(-50%, -50%) translate(${labelX}px, ${labelY}px)`,
|
||||
|
||||
@@ -8,6 +8,7 @@ import { Switch } from "@/components/atoms/Switch/Switch";
|
||||
import { preprocessInputSchema } from "../processors/input-schema-pre-processor";
|
||||
import { OutputHandler } from "./OutputHandler";
|
||||
import { useNodeStore } from "../../../stores/nodeStore";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
export type CustomNodeData = {
|
||||
hardcodedValues: {
|
||||
@@ -22,14 +23,19 @@ export type CustomNodeData = {
|
||||
export type CustomNode = XYNode<CustomNodeData, "custom">;
|
||||
|
||||
export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
||||
({ data, id }) => {
|
||||
({ data, id, selected }) => {
|
||||
const showAdvanced = useNodeStore(
|
||||
(state) => state.nodeAdvancedStates[id] || false,
|
||||
);
|
||||
const setShowAdvanced = useNodeStore((state) => state.setShowAdvanced);
|
||||
|
||||
return (
|
||||
<div className="rounded-xl border border-slate-200/60 bg-gradient-to-br from-white to-slate-50/30 shadow-lg shadow-slate-900/5 backdrop-blur-sm">
|
||||
<div
|
||||
className={cn(
|
||||
"rounded-xl bg-gradient-to-br from-white to-slate-50/30 shadow-lg shadow-slate-900/5 ring-1 ring-slate-200/60 backdrop-blur-sm",
|
||||
selected && "shadow-2xl ring-2 ring-slate-200",
|
||||
)}
|
||||
>
|
||||
{/* Header */}
|
||||
<div className="flex h-14 items-center justify-center rounded-xl border-b border-slate-200/50 bg-gradient-to-r from-slate-50/80 to-white/90">
|
||||
<Text
|
||||
|
||||
@@ -29,7 +29,7 @@ export const OutputHandler = ({
|
||||
<div className="flex flex-col items-end justify-between gap-2 rounded-b-xl border-t border-slate-200/50 bg-white py-3.5">
|
||||
<Button
|
||||
variant="ghost"
|
||||
className="mr-4 p-0"
|
||||
className="mr-4 h-fit min-w-0 p-0 hover:border-transparent hover:bg-transparent"
|
||||
onClick={() => setIsOutputVisible(!isOutputVisible)}
|
||||
>
|
||||
<Text
|
||||
@@ -54,30 +54,27 @@ export const OutputHandler = ({
|
||||
|
||||
return shouldShow ? (
|
||||
<div key={key} className="relative flex items-center gap-2">
|
||||
<Text
|
||||
variant="body"
|
||||
className="flex items-center gap-2 font-medium text-slate-700"
|
||||
>
|
||||
{property?.description && (
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<span
|
||||
style={{ marginLeft: 6, cursor: "pointer" }}
|
||||
aria-label="info"
|
||||
tabIndex={0}
|
||||
>
|
||||
<InfoIcon />
|
||||
</span>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>{property?.description}</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
)}
|
||||
{property?.description && (
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<span
|
||||
style={{ marginLeft: 6, cursor: "pointer" }}
|
||||
aria-label="info"
|
||||
tabIndex={0}
|
||||
>
|
||||
<InfoIcon />
|
||||
</span>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>{property?.description}</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
)}
|
||||
<Text variant="body" className="text-slate-700">
|
||||
{property?.title || key}{" "}
|
||||
<Text variant="small" as="span" className={colorClass}>
|
||||
({displayType})
|
||||
</Text>
|
||||
</Text>
|
||||
<Text variant="small" as="span" className={colorClass}>
|
||||
({displayType})
|
||||
</Text>
|
||||
<NodeHandle id={key} isConnected={isConnected} side="right" />
|
||||
</div>
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
import React from "react";
|
||||
import { FieldProps } from "@rjsf/utils";
|
||||
import { Input } from "@/components/atoms/Input/Input";
|
||||
|
||||
// We need to add all the logic for the credential fields here
|
||||
export const CredentialsField = (props: FieldProps) => {
|
||||
const { formData = {}, onChange, required: _required, schema } = props;
|
||||
|
||||
const _credentialProvider = schema.credentials_provider;
|
||||
const _credentialType = schema.credentials_types;
|
||||
const _description = schema.description;
|
||||
const _title = schema.title;
|
||||
|
||||
// Helper to update one property
|
||||
const setField = (key: string, value: any) =>
|
||||
onChange({ ...formData, [key]: value });
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-2">
|
||||
<Input
|
||||
hideLabel={true}
|
||||
label={""}
|
||||
id="credentials-id"
|
||||
type="text"
|
||||
value={formData.id || ""}
|
||||
onChange={(e) => setField("id", e.target.value)}
|
||||
placeholder="Enter your API Key"
|
||||
required
|
||||
size="small"
|
||||
wrapperClassName="mb-0"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,93 @@
|
||||
import React, { useEffect } from "react";
|
||||
import { FieldProps } from "@rjsf/utils";
|
||||
import { useCredentialField } from "./useCredentialField";
|
||||
import { KeyIcon, PlusIcon } from "@phosphor-icons/react";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { SelectCredential } from "./SelectCredential";
|
||||
import { Skeleton } from "@/components/__legacy__/ui/skeleton";
|
||||
import { BlockIOCredentialsSubSchema } from "@/lib/autogpt-server-api";
|
||||
import { APIKeyCredentialsModal } from "./models/APIKeyCredentialModal/APIKeyCredentialModal";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
|
||||
export const CredentialsField = (props: FieldProps) => {
|
||||
const { formData = {}, onChange, required: _required, schema } = props;
|
||||
const {
|
||||
credentials,
|
||||
isCredentialListLoading,
|
||||
supportsApiKey,
|
||||
supportsOAuth2,
|
||||
isAPIKeyModalOpen,
|
||||
setIsAPIKeyModalOpen,
|
||||
credentialsExists,
|
||||
} = useCredentialField({
|
||||
credentialSchema: schema as BlockIOCredentialsSubSchema,
|
||||
});
|
||||
|
||||
const setField = (key: string, value: any) =>
|
||||
onChange({ ...formData, [key]: value });
|
||||
|
||||
useEffect(() => {
|
||||
if (!isCredentialListLoading && credentials.length > 0 && !formData.id) {
|
||||
const latestCredential = credentials[credentials.length - 1];
|
||||
setField("id", latestCredential.id);
|
||||
}
|
||||
}, [isCredentialListLoading, credentials, formData.id]);
|
||||
|
||||
const handleCredentialCreated = (credentialId: string) => {
|
||||
setField("id", credentialId);
|
||||
};
|
||||
|
||||
if (isCredentialListLoading) {
|
||||
return (
|
||||
<div className="flex flex-col gap-2">
|
||||
<Skeleton className="h-8 w-full rounded-xlarge" />
|
||||
<Skeleton className="h-8 w-[30%] rounded-xlarge" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-2">
|
||||
{credentialsExists && (
|
||||
<SelectCredential
|
||||
credentials={credentials}
|
||||
value={formData.id}
|
||||
onChange={(value) => setField("id", value)}
|
||||
disabled={false}
|
||||
label="Credential"
|
||||
placeholder="Select credential"
|
||||
/>
|
||||
)}
|
||||
|
||||
<div>
|
||||
{supportsApiKey && (
|
||||
<>
|
||||
<APIKeyCredentialsModal
|
||||
schema={schema as BlockIOCredentialsSubSchema}
|
||||
open={isAPIKeyModalOpen}
|
||||
onClose={() => setIsAPIKeyModalOpen(false)}
|
||||
onSuccess={handleCredentialCreated}
|
||||
/>
|
||||
<Button
|
||||
type="button"
|
||||
className="w-auto min-w-0"
|
||||
size="small"
|
||||
onClick={() => setIsAPIKeyModalOpen(true)}
|
||||
>
|
||||
<KeyIcon />
|
||||
<Text variant="body-medium" className="!text-white opacity-100">
|
||||
Add API key
|
||||
</Text>
|
||||
</Button>
|
||||
</>
|
||||
)}
|
||||
{supportsOAuth2 && (
|
||||
<Button type="button" className="w-fit" size="small">
|
||||
<PlusIcon />
|
||||
Add OAuth2
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,69 @@
|
||||
import React from "react";
|
||||
import { Select } from "@/components/atoms/Select/Select";
|
||||
import { CredentialsMetaResponse } from "@/app/api/__generated__/models/credentialsMetaResponse";
|
||||
import { ArrowSquareOutIcon, KeyIcon } from "@phosphor-icons/react";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import Link from "next/link";
|
||||
|
||||
type SelectCredentialProps = {
|
||||
credentials: CredentialsMetaResponse[];
|
||||
value?: string;
|
||||
onChange: (credentialId: string) => void;
|
||||
disabled?: boolean;
|
||||
label?: string;
|
||||
placeholder?: string;
|
||||
};
|
||||
|
||||
export const SelectCredential: React.FC<SelectCredentialProps> = ({
|
||||
credentials,
|
||||
value,
|
||||
onChange,
|
||||
disabled = false,
|
||||
label = "Credential",
|
||||
placeholder = "Select credential",
|
||||
}) => {
|
||||
const options = credentials.map((cred) => {
|
||||
const details: string[] = [];
|
||||
if (cred.title && cred.title !== cred.provider) {
|
||||
details.push(cred.title);
|
||||
}
|
||||
if (cred.username) {
|
||||
details.push(cred.username);
|
||||
}
|
||||
if (cred.host) {
|
||||
details.push(cred.host);
|
||||
}
|
||||
const label =
|
||||
details.length > 0
|
||||
? `${cred.provider} (${details.join(" - ")})`
|
||||
: cred.provider;
|
||||
|
||||
return {
|
||||
value: cred.id,
|
||||
label,
|
||||
icon: <KeyIcon className="h-4 w-4" />,
|
||||
};
|
||||
});
|
||||
|
||||
return (
|
||||
<div className="flex w-full items-center gap-2">
|
||||
<Select
|
||||
label={label}
|
||||
id="select-credential"
|
||||
wrapperClassName="!mb-0 flex-1"
|
||||
value={value}
|
||||
onValueChange={onChange}
|
||||
options={options}
|
||||
disabled={disabled}
|
||||
placeholder={placeholder}
|
||||
size="small"
|
||||
hideLabel
|
||||
/>
|
||||
<Link href={`/profile/integrations`}>
|
||||
<Button variant="outline" size="icon" className="h-8 w-8 p-0">
|
||||
<ArrowSquareOutIcon className="h-4 w-4" />
|
||||
</Button>
|
||||
</Link>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,101 @@
|
||||
import { CredentialsMetaResponse } from "@/app/api/__generated__/models/credentialsMetaResponse";
|
||||
// Need to replace these icons with phosphor icons
|
||||
import {
|
||||
FaDiscord,
|
||||
FaMedium,
|
||||
FaGithub,
|
||||
FaGoogle,
|
||||
FaHubspot,
|
||||
FaTwitter,
|
||||
} from "react-icons/fa";
|
||||
import { GoogleLogoIcon, KeyIcon, NotionLogoIcon } from "@phosphor-icons/react";
|
||||
|
||||
export const filterCredentialsByProvider = (
|
||||
credentials: CredentialsMetaResponse[] | undefined,
|
||||
provider: string[],
|
||||
) => {
|
||||
const filtered =
|
||||
credentials?.filter((credential) =>
|
||||
provider.includes(credential.provider),
|
||||
) ?? [];
|
||||
return {
|
||||
credentials: filtered,
|
||||
exists: filtered.length > 0,
|
||||
};
|
||||
};
|
||||
|
||||
export function toDisplayName(provider: string): string {
|
||||
console.log("provider", provider);
|
||||
// Special cases that need manual handling
|
||||
const specialCases: Record<string, string> = {
|
||||
aiml_api: "AI/ML",
|
||||
d_id: "D-ID",
|
||||
e2b: "E2B",
|
||||
llama_api: "Llama API",
|
||||
open_router: "Open Router",
|
||||
smtp: "SMTP",
|
||||
revid: "Rev.ID",
|
||||
};
|
||||
|
||||
if (specialCases[provider]) {
|
||||
return specialCases[provider];
|
||||
}
|
||||
|
||||
// General case: convert snake_case to Title Case
|
||||
return provider
|
||||
.split(/[_-]/)
|
||||
.map((word) => word.charAt(0).toUpperCase() + word.slice(1).toLowerCase())
|
||||
.join(" ");
|
||||
}
|
||||
|
||||
export function isCredentialFieldSchema(schema: any): boolean {
|
||||
return (
|
||||
typeof schema === "object" &&
|
||||
schema !== null &&
|
||||
"credentials_provider" in schema
|
||||
);
|
||||
}
|
||||
|
||||
export const providerIcons: Partial<
|
||||
Record<string, React.FC<{ className?: string }>>
|
||||
> = {
|
||||
aiml_api: KeyIcon,
|
||||
anthropic: KeyIcon,
|
||||
apollo: KeyIcon,
|
||||
e2b: KeyIcon,
|
||||
github: FaGithub,
|
||||
google: GoogleLogoIcon,
|
||||
groq: KeyIcon,
|
||||
http: KeyIcon,
|
||||
notion: NotionLogoIcon,
|
||||
nvidia: KeyIcon,
|
||||
discord: FaDiscord,
|
||||
d_id: KeyIcon,
|
||||
google_maps: FaGoogle,
|
||||
jina: KeyIcon,
|
||||
ideogram: KeyIcon,
|
||||
linear: KeyIcon,
|
||||
medium: FaMedium,
|
||||
mem0: KeyIcon,
|
||||
ollama: KeyIcon,
|
||||
openai: KeyIcon,
|
||||
openweathermap: KeyIcon,
|
||||
open_router: KeyIcon,
|
||||
llama_api: KeyIcon,
|
||||
pinecone: KeyIcon,
|
||||
enrichlayer: KeyIcon,
|
||||
slant3d: KeyIcon,
|
||||
screenshotone: KeyIcon,
|
||||
smtp: KeyIcon,
|
||||
replicate: KeyIcon,
|
||||
reddit: KeyIcon,
|
||||
fal: KeyIcon,
|
||||
revid: KeyIcon,
|
||||
twitter: FaTwitter,
|
||||
unreal_speech: KeyIcon,
|
||||
exa: KeyIcon,
|
||||
hubspot: FaHubspot,
|
||||
smartlead: KeyIcon,
|
||||
todoist: KeyIcon,
|
||||
zerobounce: KeyIcon,
|
||||
};
|
||||
@@ -0,0 +1,119 @@
|
||||
import { Input } from "@/components/atoms/Input/Input";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import {
|
||||
Form,
|
||||
FormDescription,
|
||||
FormField,
|
||||
} from "@/components/__legacy__/ui/form";
|
||||
import { BlockIOCredentialsSubSchema } from "@/lib/autogpt-server-api/types"; // we need to find a way to replace it with autogenerated types
|
||||
import { useAPIKeyCredentialsModal } from "./useAPIKeyCredentialsModal";
|
||||
import { toDisplayName } from "../../helpers";
|
||||
|
||||
type Props = {
|
||||
schema: BlockIOCredentialsSubSchema;
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
onSuccess: (credentialId: string) => void;
|
||||
};
|
||||
|
||||
export function APIKeyCredentialsModal({
|
||||
schema,
|
||||
open,
|
||||
onClose,
|
||||
onSuccess,
|
||||
}: Props) {
|
||||
const { form, isLoading, schemaDescription, onSubmit, provider } =
|
||||
useAPIKeyCredentialsModal({ schema, onClose, onSuccess });
|
||||
|
||||
if (isLoading) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title={`Add new API key for ${toDisplayName(provider) ?? ""}`}
|
||||
controlled={{
|
||||
isOpen: open,
|
||||
set: (isOpen) => {
|
||||
if (!isOpen) onClose();
|
||||
},
|
||||
}}
|
||||
onClose={onClose}
|
||||
styling={{
|
||||
maxWidth: "25rem",
|
||||
}}
|
||||
>
|
||||
<Dialog.Content>
|
||||
{schemaDescription && (
|
||||
<p className="mb-4 text-sm text-zinc-600">{schemaDescription}</p>
|
||||
)}
|
||||
|
||||
<Form {...form}>
|
||||
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-2">
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="apiKey"
|
||||
render={({ field }) => (
|
||||
<>
|
||||
<Input
|
||||
id="apiKey"
|
||||
label="API Key"
|
||||
type="password"
|
||||
placeholder="Enter API key..."
|
||||
size="small"
|
||||
hint={
|
||||
schema.credentials_scopes ? (
|
||||
<FormDescription>
|
||||
Required scope(s) for this block:{" "}
|
||||
{schema.credentials_scopes?.map((s, i, a) => (
|
||||
<span key={i}>
|
||||
<code className="text-xs font-bold">{s}</code>
|
||||
{i < a.length - 1 && ", "}
|
||||
</span>
|
||||
))}
|
||||
</FormDescription>
|
||||
) : null
|
||||
}
|
||||
{...field}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
/>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="title"
|
||||
render={({ field }) => (
|
||||
<Input
|
||||
id="title"
|
||||
label="Name"
|
||||
type="text"
|
||||
placeholder="Enter a name for this API key..."
|
||||
size="small"
|
||||
{...field}
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="expiresAt"
|
||||
render={({ field }) => (
|
||||
<Input
|
||||
id="expiresAt"
|
||||
label="Expiration Date"
|
||||
type="datetime-local"
|
||||
placeholder="Select expiration date..."
|
||||
size="small"
|
||||
{...field}
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
<Button type="submit" size="small" className="min-w-68">
|
||||
Save & use this API key
|
||||
</Button>
|
||||
</form>
|
||||
</Form>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
import { z } from "zod";
|
||||
import { useForm, type UseFormReturn } from "react-hook-form";
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import { BlockIOCredentialsSubSchema } from "@/lib/autogpt-server-api/types";
|
||||
import {
|
||||
getGetV1ListCredentialsQueryKey,
|
||||
usePostV1CreateCredentials,
|
||||
} from "@/app/api/__generated__/endpoints/integrations/integrations";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { APIKeyCredentials } from "@/app/api/__generated__/models/aPIKeyCredentials";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { PostV1CreateCredentials201 } from "@/app/api/__generated__/models/postV1CreateCredentials201";
|
||||
|
||||
export type APIKeyFormValues = {
|
||||
apiKey: string;
|
||||
title: string;
|
||||
expiresAt?: string;
|
||||
};
|
||||
|
||||
type useAPIKeyCredentialsModalType = {
|
||||
schema: BlockIOCredentialsSubSchema;
|
||||
onClose: () => void;
|
||||
onSuccess: (credentialId: string) => void;
|
||||
};
|
||||
|
||||
export function useAPIKeyCredentialsModal({
|
||||
schema,
|
||||
onClose,
|
||||
onSuccess,
|
||||
}: useAPIKeyCredentialsModalType): {
|
||||
form: UseFormReturn<APIKeyFormValues>;
|
||||
isLoading: boolean;
|
||||
provider: string;
|
||||
schemaDescription?: string;
|
||||
onSubmit: (values: APIKeyFormValues) => Promise<void>;
|
||||
} {
|
||||
const { toast } = useToast();
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
const { mutateAsync: createCredentials, isPending: isCreatingCredentials } =
|
||||
usePostV1CreateCredentials({
|
||||
mutation: {
|
||||
onSuccess: async (response) => {
|
||||
const credentialId = (response.data as PostV1CreateCredentials201)
|
||||
?.id;
|
||||
onClose();
|
||||
form.reset();
|
||||
toast({
|
||||
title: "Success",
|
||||
description: "Credentials created successfully",
|
||||
variant: "default",
|
||||
});
|
||||
|
||||
await queryClient.refetchQueries({
|
||||
queryKey: getGetV1ListCredentialsQueryKey(),
|
||||
});
|
||||
|
||||
if (credentialId && onSuccess) {
|
||||
onSuccess(credentialId);
|
||||
}
|
||||
},
|
||||
onError: () => {
|
||||
toast({
|
||||
title: "Error",
|
||||
description: "Failed to create credentials.",
|
||||
variant: "destructive",
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const formSchema = z.object({
|
||||
apiKey: z.string().min(1, "API Key is required"),
|
||||
title: z.string().min(1, "Name is required"),
|
||||
expiresAt: z.string().optional(),
|
||||
});
|
||||
|
||||
const form = useForm<APIKeyFormValues>({
|
||||
resolver: zodResolver(formSchema),
|
||||
defaultValues: {
|
||||
apiKey: "",
|
||||
title: "",
|
||||
expiresAt: "",
|
||||
},
|
||||
});
|
||||
|
||||
async function onSubmit(values: APIKeyFormValues) {
|
||||
const expiresAt = values.expiresAt
|
||||
? new Date(values.expiresAt).getTime() / 1000
|
||||
: undefined;
|
||||
|
||||
createCredentials({
|
||||
provider: schema.credentials_provider[0],
|
||||
data: {
|
||||
provider: schema.credentials_provider[0],
|
||||
type: "api_key",
|
||||
api_key: values.apiKey,
|
||||
title: values.title,
|
||||
expires_at: expiresAt,
|
||||
} as APIKeyCredentials,
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
form,
|
||||
isLoading: isCreatingCredentials,
|
||||
provider: schema.credentials_provider[0],
|
||||
schemaDescription: schema.description,
|
||||
onSubmit,
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
import { useGetV1ListCredentials } from "@/app/api/__generated__/endpoints/integrations/integrations";
|
||||
import { CredentialsMetaResponse } from "@/app/api/__generated__/models/credentialsMetaResponse";
|
||||
import { BlockIOCredentialsSubSchema } from "@/lib/autogpt-server-api";
|
||||
import { useState } from "react";
|
||||
import { filterCredentialsByProvider } from "./helpers";
|
||||
|
||||
export const useCredentialField = ({
|
||||
credentialSchema,
|
||||
}: {
|
||||
credentialSchema: BlockIOCredentialsSubSchema; // Here we are using manual typing, we need to fix it with automatic one
|
||||
}) => {
|
||||
const [isAPIKeyModalOpen, setIsAPIKeyModalOpen] = useState(false);
|
||||
|
||||
// Fetch all the credentials from the backend
|
||||
// We will save it in cache for 10 min, if user edits the credential, we will invalidate the cache
|
||||
// Whenever user adds a block, we filter the credentials list and check if this block's provider is in the list
|
||||
const { data: credentials, isLoading: isCredentialListLoading } =
|
||||
useGetV1ListCredentials({
|
||||
query: {
|
||||
refetchInterval: 10 * 60 * 1000,
|
||||
select: (x) => {
|
||||
return x.data as CredentialsMetaResponse[];
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const supportsApiKey = credentialSchema.credentials_types.includes("api_key");
|
||||
const supportsOAuth2 = credentialSchema.credentials_types.includes("oauth2");
|
||||
|
||||
const credentialProviders = credentialSchema.credentials_provider;
|
||||
const { credentials: filteredCredentials, exists: credentialsExists } =
|
||||
filterCredentialsByProvider(credentials, credentialProviders);
|
||||
|
||||
return {
|
||||
credentials: filteredCredentials,
|
||||
isCredentialListLoading,
|
||||
supportsApiKey,
|
||||
supportsOAuth2,
|
||||
isAPIKeyModalOpen,
|
||||
setIsAPIKeyModalOpen,
|
||||
credentialsExists,
|
||||
};
|
||||
};
|
||||
@@ -1,5 +1,5 @@
|
||||
import { RegistryFieldsType } from "@rjsf/utils";
|
||||
import { CredentialsField } from "./CredentialField";
|
||||
import { CredentialsField } from "./CredentialField/CredentialField";
|
||||
import { AnyOfField } from "./AnyOfField/AnyOfField";
|
||||
import { ObjectField } from "./ObjectField";
|
||||
|
||||
|
||||
@@ -15,6 +15,11 @@ import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
import { generateHandleId } from "../../handlers/helpers";
|
||||
import { getTypeDisplayInfo } from "../helpers";
|
||||
import { ArrayEditorContext } from "../../components/ArrayEditor/ArrayEditorContext";
|
||||
import {
|
||||
isCredentialFieldSchema,
|
||||
toDisplayName,
|
||||
} from "../fields/CredentialField/helpers";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
const FieldTemplate: React.FC<FieldTemplateProps> = ({
|
||||
id,
|
||||
@@ -47,6 +52,7 @@ const FieldTemplate: React.FC<FieldTemplateProps> = ({
|
||||
}
|
||||
const isAnyOf = Array.isArray((schema as any)?.anyOf);
|
||||
const isOneOf = Array.isArray((schema as any)?.oneOf);
|
||||
const isCredential = isCredentialFieldSchema(schema);
|
||||
const suppressHandle = isAnyOf || isOneOf;
|
||||
|
||||
if (!showAdvanced && schema.advanced === true && !isConnected) {
|
||||
@@ -63,12 +69,17 @@ const FieldTemplate: React.FC<FieldTemplateProps> = ({
|
||||
<div className="mt-4 w-[400px] space-y-1">
|
||||
{label && schema.type && (
|
||||
<label htmlFor={id} className="flex items-center gap-1">
|
||||
{!suppressHandle && !fromAnyOf && (
|
||||
{!suppressHandle && !fromAnyOf && !isCredential && (
|
||||
<NodeHandle id={fieldKey} isConnected={isConnected} side="left" />
|
||||
)}
|
||||
{!fromAnyOf && (
|
||||
<Text variant="body" className="line-clamp-1">
|
||||
{label}
|
||||
<Text
|
||||
variant="body"
|
||||
className={cn("line-clamp-1", isCredential && "ml-3")}
|
||||
>
|
||||
{isCredential
|
||||
? toDisplayName(schema.credentials_provider[0]) + " credentials"
|
||||
: label}
|
||||
</Text>
|
||||
)}
|
||||
{!fromAnyOf && (
|
||||
|
||||
@@ -4,6 +4,7 @@ import { Button } from "@/components/__legacy__/Button";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
DialogTrigger,
|
||||
@@ -64,6 +65,9 @@ export default function LibraryUploadAgentDialog(): React.ReactNode {
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle className="mb-8 text-center">Upload Agent</DialogTitle>
|
||||
<DialogDescription>
|
||||
Upload your agent by providing a name, description, and JSON file.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<Form {...form}>
|
||||
|
||||
@@ -85,18 +85,16 @@ export default function LoginPage() {
|
||||
/>
|
||||
|
||||
{/* Turnstile CAPTCHA Component */}
|
||||
{turnstile.shouldRender ? (
|
||||
<Turnstile
|
||||
key={captchaKey}
|
||||
siteKey={turnstile.siteKey}
|
||||
onVerify={turnstile.handleVerify}
|
||||
onExpire={turnstile.handleExpire}
|
||||
onError={turnstile.handleError}
|
||||
setWidgetId={turnstile.setWidgetId}
|
||||
action="login"
|
||||
shouldRender={turnstile.shouldRender}
|
||||
/>
|
||||
) : null}
|
||||
<Turnstile
|
||||
key={captchaKey}
|
||||
siteKey={turnstile.siteKey}
|
||||
onVerify={turnstile.handleVerify}
|
||||
onExpire={turnstile.handleExpire}
|
||||
onError={turnstile.handleError}
|
||||
setWidgetId={turnstile.setWidgetId}
|
||||
action="login"
|
||||
shouldRender={turnstile.shouldRender}
|
||||
/>
|
||||
|
||||
<Button
|
||||
variant="primary"
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
import { SearchBar } from "@/components/__legacy__/SearchBar";
|
||||
import { useMainSearchResultPage } from "./useMainSearchResultPage";
|
||||
import { SearchFilterChips } from "@/components/__legacy__/SearchFilterChips";
|
||||
import { SortDropdown } from "@/components/__legacy__/SortDropdown";
|
||||
import { AgentsSection } from "../AgentsSection/AgentsSection";
|
||||
import { Separator } from "@/components/__legacy__/ui/separator";
|
||||
import { FeaturedCreators } from "../FeaturedCreators/FeaturedCreators";
|
||||
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
|
||||
import { MainMarketplacePageLoading } from "../MainMarketplacePageLoading";
|
||||
|
||||
export const MainSearchResultPage = ({
|
||||
searchTerm,
|
||||
sort,
|
||||
}: {
|
||||
searchTerm: string;
|
||||
sort: string;
|
||||
}) => {
|
||||
const {
|
||||
agents,
|
||||
creators,
|
||||
totalCount,
|
||||
agentsCount,
|
||||
creatorsCount,
|
||||
handleFilterChange,
|
||||
handleSortChange,
|
||||
showAgents,
|
||||
showCreators,
|
||||
isAgentsLoading,
|
||||
isCreatorsLoading,
|
||||
isAgentsError,
|
||||
isCreatorsError,
|
||||
} = useMainSearchResultPage({ searchTerm, sort });
|
||||
|
||||
const isLoading = isAgentsLoading || isCreatorsLoading;
|
||||
const hasError = isAgentsError || isCreatorsError;
|
||||
|
||||
if (isLoading) {
|
||||
return <MainMarketplacePageLoading />;
|
||||
}
|
||||
|
||||
if (hasError) {
|
||||
return (
|
||||
<div className="flex min-h-[500px] items-center justify-center">
|
||||
<ErrorCard
|
||||
isSuccess={false}
|
||||
responseError={{ message: "Failed to load marketplace data" }}
|
||||
context="marketplace page"
|
||||
onRetry={() => window.location.reload()}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<div className="w-full">
|
||||
<div className="mx-auto min-h-screen max-w-[1440px] px-10 lg:min-w-[1440px]">
|
||||
<div className="mt-8 flex items-center">
|
||||
<div className="flex-1">
|
||||
<h2 className="text-base font-medium leading-normal text-neutral-800 dark:text-neutral-200">
|
||||
Results for:
|
||||
</h2>
|
||||
<h1 className="font-poppins text-2xl font-semibold leading-[32px] text-neutral-800 dark:text-neutral-100">
|
||||
{searchTerm}
|
||||
</h1>
|
||||
</div>
|
||||
<div className="flex-none">
|
||||
<SearchBar width="w-[439px]" height="h-[60px]" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{totalCount > 0 ? (
|
||||
<>
|
||||
<div className="mt-[36px] flex items-center justify-between">
|
||||
<SearchFilterChips
|
||||
totalCount={totalCount}
|
||||
agentsCount={agentsCount}
|
||||
creatorsCount={creatorsCount}
|
||||
onFilterChange={handleFilterChange}
|
||||
/>
|
||||
<SortDropdown onSort={handleSortChange} />
|
||||
</div>
|
||||
{/* Content section */}
|
||||
<div className="min-h-[500px] max-w-[1440px] space-y-8 py-8">
|
||||
{showAgents && agentsCount > 0 && agents && (
|
||||
<div className="mt-[36px]">
|
||||
<AgentsSection agents={agents} sectionTitle="Agents" />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{showAgents && agentsCount > 0 && creatorsCount > 0 && (
|
||||
<Separator />
|
||||
)}
|
||||
{showCreators && creatorsCount > 0 && creators && (
|
||||
<FeaturedCreators
|
||||
featuredCreators={creators}
|
||||
title="Creators"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<div className="mt-20 flex flex-col items-center justify-center">
|
||||
<h3 className="mb-2 text-xl font-medium text-neutral-600 dark:text-neutral-300">
|
||||
No results found
|
||||
</h3>
|
||||
<p className="text-neutral-500 dark:text-neutral-400">
|
||||
Try adjusting your search terms or filters
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,123 @@
|
||||
import {
|
||||
useGetV2ListStoreAgents,
|
||||
useGetV2ListStoreCreators,
|
||||
} from "@/app/api/__generated__/endpoints/store/store";
|
||||
import { CreatorsResponse } from "@/app/api/__generated__/models/creatorsResponse";
|
||||
import { StoreAgentsResponse } from "@/app/api/__generated__/models/storeAgentsResponse";
|
||||
import { useState, useMemo } from "react";
|
||||
|
||||
interface useMainSearchResultPageType {
|
||||
searchTerm: string;
|
||||
sort: string;
|
||||
}
|
||||
|
||||
export const useMainSearchResultPage = ({
|
||||
searchTerm,
|
||||
sort,
|
||||
}: useMainSearchResultPageType) => {
|
||||
const [showAgents, setShowAgents] = useState(true);
|
||||
const [showCreators, setShowCreators] = useState(true);
|
||||
const [clientSortBy, setClientSortBy] = useState<string>(sort);
|
||||
|
||||
const {
|
||||
data: agentsData,
|
||||
isLoading: isAgentsLoading,
|
||||
isError: isAgentsError,
|
||||
} = useGetV2ListStoreAgents(
|
||||
{
|
||||
search_query: searchTerm,
|
||||
sorted_by: sort,
|
||||
},
|
||||
{
|
||||
query: {
|
||||
select: (x) => {
|
||||
return (x.data as StoreAgentsResponse).agents;
|
||||
},
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
const {
|
||||
data: creatorsData,
|
||||
isLoading: isCreatorsLoading,
|
||||
isError: isCreatorsError,
|
||||
} = useGetV2ListStoreCreators(
|
||||
{ search_query: searchTerm, sorted_by: sort },
|
||||
{
|
||||
query: {
|
||||
select: (x) => {
|
||||
return (x.data as CreatorsResponse).creators;
|
||||
},
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
// This is the strategy, we are using for sorting the agents and creators.
|
||||
// currently we are doing it client side but maybe we will shift it to the server side.
|
||||
// we will store the sortBy state in the url params, and then refetch the data with the new sortBy.
|
||||
|
||||
const agents = useMemo(() => {
|
||||
if (!agentsData) return [];
|
||||
|
||||
const sorted = [...agentsData];
|
||||
|
||||
if (clientSortBy === "runs") {
|
||||
return sorted.sort((a, b) => b.runs - a.runs);
|
||||
} else if (clientSortBy === "rating") {
|
||||
return sorted.sort((a, b) => b.rating - a.rating);
|
||||
} else {
|
||||
return sorted;
|
||||
}
|
||||
}, [agentsData, clientSortBy]);
|
||||
|
||||
const creators = useMemo(() => {
|
||||
if (!creatorsData) return [];
|
||||
|
||||
const sorted = [...creatorsData];
|
||||
|
||||
if (clientSortBy === "runs") {
|
||||
return sorted.sort((a, b) => b.agent_runs - a.agent_runs);
|
||||
} else if (clientSortBy === "rating") {
|
||||
return sorted.sort((a, b) => b.agent_rating - a.agent_rating);
|
||||
} else {
|
||||
return sorted.sort((a, b) => b.num_agents - a.num_agents);
|
||||
}
|
||||
}, [creatorsData, clientSortBy]);
|
||||
|
||||
const agentsCount = agents?.length ?? 0;
|
||||
const creatorsCount = creators?.length ?? 0;
|
||||
const totalCount = agentsCount + creatorsCount;
|
||||
|
||||
const handleFilterChange = (value: string) => {
|
||||
if (value === "agents") {
|
||||
setShowAgents(true);
|
||||
setShowCreators(false);
|
||||
} else if (value === "creators") {
|
||||
setShowAgents(false);
|
||||
setShowCreators(true);
|
||||
} else {
|
||||
setShowAgents(true);
|
||||
setShowCreators(true);
|
||||
}
|
||||
};
|
||||
|
||||
const handleSortChange = (sortValue: string) => {
|
||||
setClientSortBy(sortValue);
|
||||
};
|
||||
|
||||
return {
|
||||
agents,
|
||||
creators,
|
||||
handleFilterChange,
|
||||
handleSortChange,
|
||||
agentsCount,
|
||||
creatorsCount,
|
||||
totalCount,
|
||||
showAgents,
|
||||
showCreators,
|
||||
isAgentsLoading,
|
||||
isCreatorsLoading,
|
||||
isAgentsError,
|
||||
isCreatorsError,
|
||||
};
|
||||
};
|
||||
@@ -0,0 +1,27 @@
|
||||
import { Skeleton } from "@/components/__legacy__/ui/skeleton";
|
||||
|
||||
export const MainSearchResultPageLoading = () => {
|
||||
return (
|
||||
<div className="w-full">
|
||||
<div className="mx-auto min-h-screen max-w-[1440px] px-10 lg:min-w-[1440px]">
|
||||
<div className="mt-8 flex items-center">
|
||||
<div className="flex-1">
|
||||
<Skeleton className="mb-2 h-5 w-32 bg-neutral-200 dark:bg-neutral-700" />
|
||||
<Skeleton className="h-8 w-64 bg-neutral-200 dark:bg-neutral-700" />
|
||||
</div>
|
||||
<div className="flex-none">
|
||||
<Skeleton className="h-[60px] w-[439px] bg-neutral-200 dark:bg-neutral-700" />
|
||||
</div>
|
||||
</div>
|
||||
<div className="mt-[36px] flex items-center justify-between">
|
||||
<Skeleton className="h-8 w-48 bg-neutral-200 dark:bg-neutral-700" />
|
||||
<Skeleton className="h-8 w-32 bg-neutral-200 dark:bg-neutral-700" />
|
||||
</div>
|
||||
<div className="mt-20 flex flex-col items-center justify-center">
|
||||
<Skeleton className="mb-4 h-6 w-40 bg-neutral-200 dark:bg-neutral-700" />
|
||||
<Skeleton className="h-6 w-80 bg-neutral-200 dark:bg-neutral-700" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -1,14 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { use, useCallback, useEffect, useState } from "react";
|
||||
import { AgentsSection } from "@/components/__legacy__/composite/AgentsSection";
|
||||
import { SearchBar } from "@/components/__legacy__/SearchBar";
|
||||
import { FeaturedCreators } from "@/components/__legacy__/composite/FeaturedCreators";
|
||||
import { Separator } from "@/components/__legacy__/ui/separator";
|
||||
import { SearchFilterChips } from "@/components/__legacy__/SearchFilterChips";
|
||||
import { SortDropdown } from "@/components/__legacy__/SortDropdown";
|
||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||
import { Creator, StoreAgent } from "@/lib/autogpt-server-api";
|
||||
import { use } from "react";
|
||||
import { MainSearchResultPage } from "../components/MainSearchResultPage/MainSearchResultPage";
|
||||
|
||||
type MarketplaceSearchPageSearchParams = { searchTerm?: string; sort?: string };
|
||||
|
||||
@@ -18,171 +11,9 @@ export default function MarketplaceSearchPage({
|
||||
searchParams: Promise<MarketplaceSearchPageSearchParams>;
|
||||
}) {
|
||||
return (
|
||||
<SearchResults
|
||||
<MainSearchResultPage
|
||||
searchTerm={use(searchParams).searchTerm || ""}
|
||||
sort={use(searchParams).sort || "trending"}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
function SearchResults({
|
||||
searchTerm,
|
||||
sort,
|
||||
}: {
|
||||
searchTerm: string;
|
||||
sort: string;
|
||||
}): React.ReactElement {
|
||||
const [showAgents, setShowAgents] = useState(true);
|
||||
const [showCreators, setShowCreators] = useState(true);
|
||||
const [agents, setAgents] = useState<StoreAgent[]>([]);
|
||||
const [creators, setCreators] = useState<Creator[]>([]);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const api = useBackendAPI();
|
||||
|
||||
useEffect(() => {
|
||||
const fetchData = async () => {
|
||||
setIsLoading(true);
|
||||
|
||||
try {
|
||||
const [agentsRes, creatorsRes] = await Promise.all([
|
||||
api.getStoreAgents({
|
||||
search_query: searchTerm,
|
||||
sorted_by: sort,
|
||||
}),
|
||||
api.getStoreCreators({
|
||||
search_query: searchTerm,
|
||||
}),
|
||||
]);
|
||||
|
||||
setAgents(agentsRes.agents || []);
|
||||
setCreators(creatorsRes.creators || []);
|
||||
} catch (error) {
|
||||
console.error("Error fetching data:", error);
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
fetchData();
|
||||
}, [api, searchTerm, sort]);
|
||||
|
||||
const agentsCount = agents.length;
|
||||
const creatorsCount = creators.length;
|
||||
const totalCount = agentsCount + creatorsCount;
|
||||
|
||||
const handleFilterChange = (value: string) => {
|
||||
if (value === "agents") {
|
||||
setShowAgents(true);
|
||||
setShowCreators(false);
|
||||
} else if (value === "creators") {
|
||||
setShowAgents(false);
|
||||
setShowCreators(true);
|
||||
} else {
|
||||
setShowAgents(true);
|
||||
setShowCreators(true);
|
||||
}
|
||||
};
|
||||
|
||||
const handleSortChange = useCallback(
|
||||
(sortValue: string) => {
|
||||
let sortBy = "recent";
|
||||
if (sortValue === "runs") {
|
||||
sortBy = "runs";
|
||||
} else if (sortValue === "rating") {
|
||||
sortBy = "rating";
|
||||
}
|
||||
|
||||
const sortedAgents = [...agents].sort((a, b) => {
|
||||
if (sortBy === "runs") {
|
||||
return b.runs - a.runs;
|
||||
} else if (sortBy === "rating") {
|
||||
return b.rating - a.rating;
|
||||
} else {
|
||||
return (
|
||||
new Date(b.updated_at).getTime() - new Date(a.updated_at).getTime()
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
const sortedCreators = [...creators].sort((a, b) => {
|
||||
if (sortBy === "runs") {
|
||||
return b.agent_runs - a.agent_runs;
|
||||
} else if (sortBy === "rating") {
|
||||
return b.agent_rating - a.agent_rating;
|
||||
} else {
|
||||
// Creators don't have updated_at, sort by number of agents as fallback
|
||||
return b.num_agents - a.num_agents;
|
||||
}
|
||||
});
|
||||
|
||||
setAgents(sortedAgents);
|
||||
setCreators(sortedCreators);
|
||||
},
|
||||
[agents, creators],
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="w-full">
|
||||
<div className="mx-auto min-h-screen max-w-[1440px] px-10 lg:min-w-[1440px]">
|
||||
<div className="mt-8 flex items-center">
|
||||
<div className="flex-1">
|
||||
<h2 className="text-base font-medium leading-normal text-neutral-800 dark:text-neutral-200">
|
||||
Results for:
|
||||
</h2>
|
||||
<h1 className="font-poppins text-2xl font-semibold leading-[32px] text-neutral-800 dark:text-neutral-100">
|
||||
{searchTerm}
|
||||
</h1>
|
||||
</div>
|
||||
<div className="flex-none">
|
||||
<SearchBar width="w-[439px]" height="h-[60px]" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{isLoading ? (
|
||||
<div className="mt-20 flex flex-col items-center justify-center">
|
||||
<p className="text-neutral-500 dark:text-neutral-400">Loading...</p>
|
||||
</div>
|
||||
) : totalCount > 0 ? (
|
||||
<>
|
||||
<div className="mt-[36px] flex items-center justify-between">
|
||||
<SearchFilterChips
|
||||
totalCount={totalCount}
|
||||
agentsCount={agentsCount}
|
||||
creatorsCount={creatorsCount}
|
||||
onFilterChange={handleFilterChange}
|
||||
/>
|
||||
<SortDropdown onSort={handleSortChange} />
|
||||
</div>
|
||||
{/* Content section */}
|
||||
<div className="min-h-[500px] max-w-[1440px]">
|
||||
{showAgents && agentsCount > 0 && (
|
||||
<div className="mt-[36px]">
|
||||
<AgentsSection agents={agents} sectionTitle="Agents" />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{showAgents && agentsCount > 0 && creatorsCount > 0 && (
|
||||
<Separator />
|
||||
)}
|
||||
{showCreators && creatorsCount > 0 && (
|
||||
<FeaturedCreators
|
||||
featuredCreators={creators}
|
||||
title="Creators"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<div className="mt-20 flex flex-col items-center justify-center">
|
||||
<h3 className="mb-2 text-xl font-medium text-neutral-600 dark:text-neutral-300">
|
||||
No results found
|
||||
</h3>
|
||||
<p className="text-neutral-500 dark:text-neutral-400">
|
||||
Try adjusting your search terms or filters
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -112,7 +112,7 @@
|
||||
"get": {
|
||||
"tags": ["v1", "integrations"],
|
||||
"summary": "List Credentials",
|
||||
"operationId": "getV1ListCredentials",
|
||||
"operationId": "getV1List credentials",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
@@ -123,7 +123,7 @@
|
||||
"$ref": "#/components/schemas/CredentialsMetaResponse"
|
||||
},
|
||||
"type": "array",
|
||||
"title": "Response Getv1Listcredentials"
|
||||
"title": "Response Getv1List Credentials"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -184,7 +184,7 @@
|
||||
"post": {
|
||||
"tags": ["v1", "integrations"],
|
||||
"summary": "Create Credentials",
|
||||
"operationId": "postV1CreateCredentials",
|
||||
"operationId": "postV1Create credentials",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
@@ -246,7 +246,7 @@
|
||||
"host_scoped": "#/components/schemas/HostScopedCredentials-Output"
|
||||
}
|
||||
},
|
||||
"title": "Response Postv1Createcredentials"
|
||||
"title": "Response Postv1Create Credentials"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -268,8 +268,8 @@
|
||||
"/api/integrations/{provider}/credentials/{cred_id}": {
|
||||
"get": {
|
||||
"tags": ["v1", "integrations"],
|
||||
"summary": "Get Credential",
|
||||
"operationId": "getV1GetCredential",
|
||||
"summary": "Get Specific Credential By ID",
|
||||
"operationId": "getV1Get specific credential by id",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
@@ -315,7 +315,7 @@
|
||||
"host_scoped": "#/components/schemas/HostScopedCredentials-Output"
|
||||
}
|
||||
},
|
||||
"title": "Response Getv1Getcredential"
|
||||
"title": "Response Getv1Get Specific Credential By Id"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import {
|
||||
import { NuqsAdapter } from "nuqs/adapters/next/app";
|
||||
import { TooltipProvider } from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import CredentialsProvider from "@/providers/agent-credentials/credentials-provider";
|
||||
import { SentryUserTracker } from "@/components/monitor/SentryUserTracker";
|
||||
|
||||
export function Providers({ children, ...props }: ThemeProviderProps) {
|
||||
const queryClient = getQueryClient();
|
||||
@@ -20,6 +21,7 @@ export function Providers({ children, ...props }: ThemeProviderProps) {
|
||||
<NuqsAdapter>
|
||||
<NextThemesProvider {...props}>
|
||||
<BackendAPIProvider>
|
||||
<SentryUserTracker />
|
||||
<CredentialsProvider>
|
||||
<LaunchDarklyProvider>
|
||||
<OnboardingProvider>
|
||||
|
||||
@@ -10,9 +10,9 @@ import {
|
||||
import { ChevronDownIcon } from "@radix-ui/react-icons";
|
||||
|
||||
const sortOptions: SortOption[] = [
|
||||
{ label: "Most Recent", value: "recent" },
|
||||
// { label: "Most Recent", value: "recent" }, // we are not using this for now because we don't have date data from the backend
|
||||
{ label: "Most Runs", value: "runs" },
|
||||
{ label: "Highest Rated", value: "rating" },
|
||||
// { label: "Highest Rated", value: "rating" }, // we are not using this for now because we don't have rating data from the backend
|
||||
];
|
||||
|
||||
interface SortOption {
|
||||
|
||||
@@ -39,6 +39,7 @@ export interface TaskGroup {
|
||||
|
||||
export default function Wallet() {
|
||||
const { state, updateState } = useOnboarding();
|
||||
|
||||
const groups = useMemo<TaskGroup[]>(() => {
|
||||
return [
|
||||
{
|
||||
@@ -348,10 +349,11 @@ export default function Wallet() {
|
||||
</div>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent
|
||||
className={cn(
|
||||
"absolute -right-[7.9rem] -top-[3.2rem] z-50 w-[28.5rem] px-[0.625rem] py-2",
|
||||
"rounded-xl border-zinc-100 bg-white shadow-[0_3px_3px] shadow-zinc-200",
|
||||
)}
|
||||
side="bottom"
|
||||
align="end"
|
||||
sideOffset={12}
|
||||
collisionPadding={16}
|
||||
className={cn("z-50 w-[28.5rem] px-[0.625rem] py-2")}
|
||||
>
|
||||
{/* Header */}
|
||||
<div className="mx-1 flex items-center justify-between border-b border-zinc-200 pb-3">
|
||||
|
||||
@@ -40,27 +40,57 @@ export function Turnstile({
|
||||
return;
|
||||
}
|
||||
|
||||
// Create script element
|
||||
const script = document.createElement("script");
|
||||
script.src =
|
||||
const scriptSrc =
|
||||
"https://challenges.cloudflare.com/turnstile/v0/api.js?render=explicit";
|
||||
|
||||
// If a script already exists, reuse it and attach listeners
|
||||
const existingScript = Array.from(document.scripts).find(
|
||||
(s) => s.src === scriptSrc,
|
||||
);
|
||||
|
||||
if (existingScript) {
|
||||
if (window.turnstile) {
|
||||
setLoaded(true);
|
||||
return;
|
||||
}
|
||||
|
||||
const handleLoad: EventListener = () => {
|
||||
setLoaded(true);
|
||||
};
|
||||
const handleError: EventListener = () => {
|
||||
onError?.(new Error("Failed to load Turnstile script"));
|
||||
};
|
||||
|
||||
existingScript.addEventListener("load", handleLoad);
|
||||
existingScript.addEventListener("error", handleError);
|
||||
|
||||
return () => {
|
||||
existingScript.removeEventListener("load", handleLoad);
|
||||
existingScript.removeEventListener("error", handleError);
|
||||
};
|
||||
}
|
||||
|
||||
// Create a single script element if not present and keep it in the document
|
||||
const script = document.createElement("script");
|
||||
script.src = scriptSrc;
|
||||
script.async = true;
|
||||
script.defer = true;
|
||||
|
||||
script.onload = () => {
|
||||
const handleLoad: EventListener = () => {
|
||||
setLoaded(true);
|
||||
};
|
||||
|
||||
script.onerror = () => {
|
||||
const handleError: EventListener = () => {
|
||||
onError?.(new Error("Failed to load Turnstile script"));
|
||||
};
|
||||
|
||||
script.addEventListener("load", handleLoad);
|
||||
script.addEventListener("error", handleError);
|
||||
|
||||
document.head.appendChild(script);
|
||||
|
||||
return () => {
|
||||
if (document.head.contains(script)) {
|
||||
document.head.removeChild(script);
|
||||
}
|
||||
script.removeEventListener("load", handleLoad);
|
||||
script.removeEventListener("error", handleError);
|
||||
};
|
||||
}, [onError, shouldRender]);
|
||||
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
"use client";
|
||||
|
||||
import { useEffect } from "react";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
|
||||
/**
|
||||
* SentryUserTracker component sets user context in Sentry for error tracking.
|
||||
* This component should be placed high in the component tree to ensure user
|
||||
* context is available for all error reports.
|
||||
*
|
||||
* It automatically:
|
||||
* - Sets user context when a user logs in
|
||||
* - Clears user context when a user logs out
|
||||
* - Updates context when user data changes
|
||||
*/
|
||||
export function SentryUserTracker() {
|
||||
const { user, isUserLoading } = useSupabase();
|
||||
|
||||
useEffect(() => {
|
||||
if (user) {
|
||||
// Wait until user loading is complete before setting user context
|
||||
if (isUserLoading) return;
|
||||
|
||||
// Set user context for Sentry error tracking
|
||||
Sentry.setUser({
|
||||
id: user.id,
|
||||
email: user.email ?? undefined,
|
||||
// Add custom attributes
|
||||
...(user.role && { role: user.role }),
|
||||
});
|
||||
} else {
|
||||
// Always clear user context when user is null, regardless of loading state
|
||||
// This ensures logout properly clears the context immediately
|
||||
Sentry.setUser(null);
|
||||
}
|
||||
}, [user, isUserLoading]);
|
||||
|
||||
// This component doesn't render anything
|
||||
return null;
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import type { ReactNode } from "react";
|
||||
import { useMemo } from "react";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { BehaveAs, getBehaveAs } from "@/lib/utils";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
|
||||
const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
|
||||
const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
|
||||
@@ -45,7 +46,10 @@ export function LaunchDarklyProvider({ children }: { children: ReactNode }) {
|
||||
clientSideID={clientId}
|
||||
context={context}
|
||||
reactOptions={{ useCamelCaseFlagKeys: false }}
|
||||
options={{ bootstrap: "localStorage" }}
|
||||
options={{
|
||||
bootstrap: "localStorage",
|
||||
inspectors: [Sentry.buildLaunchDarklyFlagUsedHandler()],
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
</LDProvider>
|
||||
|
||||
@@ -98,7 +98,7 @@ export class MarketplacePage extends BasePage {
|
||||
}
|
||||
|
||||
async searchAndNavigate(query: string, page: Page) {
|
||||
const searchInput = await this.getSearchInput(page);
|
||||
const searchInput = (await this.getSearchInput(page)).first();
|
||||
await searchInput.fill(query);
|
||||
await searchInput.press("Enter");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user