mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Merge branch 'dev' into abhi/ci-chromatic
This commit is contained in:
@@ -66,6 +66,13 @@ MEDIA_GCS_BUCKET_NAME=
|
||||
## and tunnel it to your locally running backend.
|
||||
PLATFORM_BASE_URL=http://localhost:3000
|
||||
|
||||
## Cloudflare Turnstile (CAPTCHA) Configuration
|
||||
## Get these from the Cloudflare Turnstile dashboard: https://dash.cloudflare.com/?to=/:account/turnstile
|
||||
## This is the backend secret key
|
||||
TURNSTILE_SECRET_KEY=
|
||||
## This is the verify URL
|
||||
TURNSTILE_VERIFY_URL=https://challenges.cloudflare.com/turnstile/v0/siteverify
|
||||
|
||||
## == INTEGRATION CREDENTIALS == ##
|
||||
# Each set of server side credentials is required for the corresponding 3rd party
|
||||
# integration to work.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
@@ -11,7 +11,7 @@ from backend.data.block import (
|
||||
get_block,
|
||||
)
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.model import CredentialsMetaInput, SchemaField
|
||||
from backend.util import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -23,17 +23,21 @@ class AgentExecutorBlock(Block):
|
||||
graph_id: str = SchemaField(description="Graph ID")
|
||||
graph_version: int = SchemaField(description="Graph Version")
|
||||
|
||||
data: BlockInput = SchemaField(description="Input data for the graph")
|
||||
inputs: BlockInput = SchemaField(description="Input data for the graph")
|
||||
input_schema: dict = SchemaField(description="Input schema for the graph")
|
||||
output_schema: dict = SchemaField(description="Output schema for the graph")
|
||||
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = SchemaField(default=None, hidden=True)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
|
||||
return data.get("input_schema", {})
|
||||
|
||||
@classmethod
|
||||
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
|
||||
return data.get("data", {})
|
||||
return data.get("inputs", {})
|
||||
|
||||
@classmethod
|
||||
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
||||
@@ -67,7 +71,8 @@ class AgentExecutorBlock(Block):
|
||||
graph_id=input_data.graph_id,
|
||||
graph_version=input_data.graph_version,
|
||||
user_id=input_data.user_id,
|
||||
inputs=input_data.data,
|
||||
inputs=input_data.inputs,
|
||||
node_credentials_input_map=input_data.node_credentials_input_map,
|
||||
)
|
||||
log_id = f"Graph #{input_data.graph_id}-V{input_data.graph_version}, exec-id: {graph_exec.id}"
|
||||
logger.info(f"Starting execution of {log_id}")
|
||||
|
||||
@@ -276,7 +276,7 @@ class NodeExecutionResult(BaseModel):
|
||||
node_exec_id=self.node_exec_id,
|
||||
node_id=self.node_id,
|
||||
block_id=self.block_id,
|
||||
data=self.input_data,
|
||||
inputs=self.input_data,
|
||||
)
|
||||
|
||||
|
||||
@@ -691,7 +691,7 @@ class NodeExecutionEntry(BaseModel):
|
||||
node_exec_id: str
|
||||
node_id: str
|
||||
block_id: str
|
||||
data: BlockInput
|
||||
inputs: BlockInput
|
||||
|
||||
|
||||
class ExecutionQueue(Generic[T]):
|
||||
|
||||
@@ -199,11 +199,6 @@ class BaseGraph(BaseDbModel):
|
||||
)
|
||||
)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def credentials_input_schema(self) -> dict[str, Any]:
|
||||
return self._credentials_input_schema.jsonschema()
|
||||
|
||||
@staticmethod
|
||||
def _generate_schema(
|
||||
*props: tuple[type[AgentInputBlock.Input] | type[AgentOutputBlock.Input], dict],
|
||||
@@ -236,6 +231,15 @@ class BaseGraph(BaseDbModel):
|
||||
"required": [p.name for p in schema_fields if p.value is None],
|
||||
}
|
||||
|
||||
|
||||
class Graph(BaseGraph):
|
||||
sub_graphs: list[BaseGraph] = [] # Flattened sub-graphs
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def credentials_input_schema(self) -> dict[str, Any]:
|
||||
return self._credentials_input_schema.jsonschema()
|
||||
|
||||
@property
|
||||
def _credentials_input_schema(self) -> type[BlockSchema]:
|
||||
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
||||
@@ -314,17 +318,14 @@ class BaseGraph(BaseDbModel):
|
||||
),
|
||||
(node.id, field_name),
|
||||
)
|
||||
for node in self.nodes
|
||||
for graph in [self] + self.sub_graphs
|
||||
for node in graph.nodes
|
||||
for field_name, field_info in node.block.input_schema.get_credentials_fields_info().items()
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
class Graph(BaseGraph):
|
||||
sub_graphs: list[BaseGraph] = [] # Flattened sub-graphs, only used in export
|
||||
|
||||
|
||||
class GraphModel(Graph):
|
||||
user_id: str
|
||||
nodes: list[NodeModel] = [] # type: ignore
|
||||
@@ -400,7 +401,7 @@ class GraphModel(Graph):
|
||||
if node.block_id != AgentExecutorBlock().id:
|
||||
continue
|
||||
node.input_default["user_id"] = user_id
|
||||
node.input_default.setdefault("data", {})
|
||||
node.input_default.setdefault("inputs", {})
|
||||
if (graph_id := node.input_default.get("graph_id")) in graph_id_map:
|
||||
node.input_default["graph_id"] = graph_id_map[graph_id]
|
||||
|
||||
@@ -689,6 +690,7 @@ async def get_graph(
|
||||
version: int | None = None,
|
||||
user_id: str | None = None,
|
||||
for_export: bool = False,
|
||||
include_subgraphs: bool = False,
|
||||
) -> GraphModel | None:
|
||||
"""
|
||||
Retrieves a graph from the DB.
|
||||
@@ -725,7 +727,7 @@ async def get_graph(
|
||||
):
|
||||
return None
|
||||
|
||||
if for_export:
|
||||
if include_subgraphs or for_export:
|
||||
sub_graphs = await get_sub_graphs(graph)
|
||||
return GraphModel.from_db(
|
||||
graph=graph,
|
||||
|
||||
@@ -8,14 +8,18 @@ import threading
|
||||
from concurrent.futures import Future, ProcessPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from multiprocessing.pool import AsyncResult, Pool
|
||||
from typing import TYPE_CHECKING, Any, Generator, TypeVar, cast
|
||||
from typing import TYPE_CHECKING, Any, Generator, Optional, TypeVar, cast
|
||||
|
||||
from pika.adapters.blocking_connection import BlockingChannel
|
||||
from pika.spec import Basic, BasicProperties
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from backend.blocks.io import AgentOutputBlock
|
||||
from backend.data.model import GraphExecutionStats, NodeExecutionStats
|
||||
from backend.data.model import (
|
||||
CredentialsMetaInput,
|
||||
GraphExecutionStats,
|
||||
NodeExecutionStats,
|
||||
)
|
||||
from backend.data.notifications import (
|
||||
AgentRunData,
|
||||
LowBalanceData,
|
||||
@@ -139,6 +143,9 @@ def execute_node(
|
||||
creds_manager: IntegrationCredentialsManager,
|
||||
data: NodeExecutionEntry,
|
||||
execution_stats: NodeExecutionStats | None = None,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
) -> ExecutionStream:
|
||||
"""
|
||||
Execute a node in the graph. This will trigger a block execution on a node,
|
||||
@@ -186,7 +193,7 @@ def execute_node(
|
||||
)
|
||||
|
||||
# Sanity check: validate the execution input.
|
||||
input_data, error = validate_exec(node, data.data, resolve_input=False)
|
||||
input_data, error = validate_exec(node, data.inputs, resolve_input=False)
|
||||
if input_data is None:
|
||||
log_metadata.error(f"Skip execution, input validation error: {error}")
|
||||
push_output("error", error)
|
||||
@@ -196,8 +203,12 @@ def execute_node(
|
||||
# Re-shape the input data for agent block.
|
||||
# AgentExecutorBlock specially separate the node input_data & its input_default.
|
||||
if isinstance(node_block, AgentExecutorBlock):
|
||||
input_data = {**node.input_default, "data": input_data}
|
||||
data.data = input_data
|
||||
_input_data = AgentExecutorBlock.Input(**node.input_default)
|
||||
_input_data.inputs = input_data
|
||||
if node_credentials_input_map:
|
||||
_input_data.node_credentials_input_map = node_credentials_input_map
|
||||
input_data = _input_data.model_dump()
|
||||
data.inputs = input_data
|
||||
|
||||
# Execute the node
|
||||
input_data_str = json.dumps(input_data)
|
||||
@@ -244,6 +255,7 @@ def execute_node(
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
log_metadata=log_metadata,
|
||||
node_credentials_input_map=node_credentials_input_map,
|
||||
):
|
||||
yield execution
|
||||
|
||||
@@ -262,6 +274,7 @@ def execute_node(
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
log_metadata=log_metadata,
|
||||
node_credentials_input_map=node_credentials_input_map,
|
||||
):
|
||||
yield execution
|
||||
|
||||
@@ -291,6 +304,7 @@ def _enqueue_next_nodes(
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
log_metadata: LogMetadata,
|
||||
node_credentials_input_map: Optional[dict[str, dict[str, CredentialsMetaInput]]],
|
||||
) -> list[NodeExecutionEntry]:
|
||||
def add_enqueued_execution(
|
||||
node_exec_id: str, node_id: str, block_id: str, data: BlockInput
|
||||
@@ -306,7 +320,7 @@ def _enqueue_next_nodes(
|
||||
node_exec_id=node_exec_id,
|
||||
node_id=node_id,
|
||||
block_id=block_id,
|
||||
data=data,
|
||||
inputs=data,
|
||||
)
|
||||
|
||||
def register_next_executions(node_link: Link) -> list[NodeExecutionEntry]:
|
||||
@@ -347,6 +361,15 @@ def _enqueue_next_nodes(
|
||||
for name in static_link_names:
|
||||
next_node_input[name] = latest_execution.input_data.get(name)
|
||||
|
||||
# Apply node credentials overrides
|
||||
node_credentials = None
|
||||
if node_credentials_input_map and (
|
||||
node_credentials := node_credentials_input_map.get(next_node.id)
|
||||
):
|
||||
next_node_input.update(
|
||||
{k: v.model_dump() for k, v in node_credentials.items()}
|
||||
)
|
||||
|
||||
# Validate the input data for the next node.
|
||||
next_node_input, validation_msg = validate_exec(next_node, next_node_input)
|
||||
suffix = f"{next_output_name}>{next_input_name}~{next_node_exec_id}:{validation_msg}"
|
||||
@@ -389,6 +412,12 @@ def _enqueue_next_nodes(
|
||||
for input_name in static_link_names:
|
||||
idata[input_name] = next_node_input[input_name]
|
||||
|
||||
# Apply node credentials overrides
|
||||
if node_credentials:
|
||||
idata.update(
|
||||
{k: v.model_dump() for k, v in node_credentials.items()}
|
||||
)
|
||||
|
||||
idata, msg = validate_exec(next_node, idata)
|
||||
suffix = f"{next_output_name}>{next_input_name}~{ineid}:{msg}"
|
||||
if not idata:
|
||||
@@ -478,6 +507,9 @@ class Executor:
|
||||
cls,
|
||||
q: ExecutionQueue[NodeExecutionEntry],
|
||||
node_exec: NodeExecutionEntry,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
) -> NodeExecutionStats:
|
||||
log_metadata = LogMetadata(
|
||||
user_id=node_exec.user_id,
|
||||
@@ -490,7 +522,7 @@ class Executor:
|
||||
|
||||
execution_stats = NodeExecutionStats()
|
||||
timing_info, _ = cls._on_node_execution(
|
||||
q, node_exec, log_metadata, execution_stats
|
||||
q, node_exec, log_metadata, execution_stats, node_credentials_input_map
|
||||
)
|
||||
execution_stats.walltime = timing_info.wall_time
|
||||
execution_stats.cputime = timing_info.cpu_time
|
||||
@@ -510,6 +542,9 @@ class Executor:
|
||||
node_exec: NodeExecutionEntry,
|
||||
log_metadata: LogMetadata,
|
||||
stats: NodeExecutionStats | None = None,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
):
|
||||
try:
|
||||
log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
|
||||
@@ -518,6 +553,7 @@ class Executor:
|
||||
creds_manager=cls.creds_manager,
|
||||
data=node_exec,
|
||||
execution_stats=stats,
|
||||
node_credentials_input_map=node_credentials_input_map,
|
||||
):
|
||||
q.add(execution)
|
||||
log_metadata.info(f"Finished node execution {node_exec.node_exec_id}")
|
||||
@@ -625,7 +661,9 @@ class Executor:
|
||||
logger.error(f"Block {node_exec.block_id} not found.")
|
||||
return
|
||||
|
||||
cost, matching_filter = block_usage_cost(block=block, input_data=node_exec.data)
|
||||
cost, matching_filter = block_usage_cost(
|
||||
block=block, input_data=node_exec.inputs
|
||||
)
|
||||
if cost > 0:
|
||||
cls.db_client.spend_credits(
|
||||
user_id=node_exec.user_id,
|
||||
@@ -797,7 +835,7 @@ class Executor:
|
||||
if (node_creds_map := graph_exec.node_credentials_input_map) and (
|
||||
node_field_creds_map := node_creds_map.get(node_id)
|
||||
):
|
||||
queued_node_exec.data.update(
|
||||
queued_node_exec.inputs.update(
|
||||
{
|
||||
field_name: creds_meta.model_dump()
|
||||
for field_name, creds_meta in node_field_creds_map.items()
|
||||
@@ -807,7 +845,7 @@ class Executor:
|
||||
# Initiate node execution
|
||||
running_executions[queued_node_exec.node_id] = cls.executor.apply_async(
|
||||
cls.on_node_execution,
|
||||
(queue, queued_node_exec),
|
||||
(queue, queued_node_exec, node_creds_map),
|
||||
callback=make_exec_callback(queued_node_exec),
|
||||
)
|
||||
|
||||
|
||||
@@ -258,7 +258,7 @@ def validate_exec(
|
||||
If the data is valid, the first element will be the resolved input data, and
|
||||
the second element will be the block name.
|
||||
"""
|
||||
node_block: Block | None = get_block(node.block_id)
|
||||
node_block = get_block(node.block_id)
|
||||
if not node_block:
|
||||
return None, f"Block for {node.block_id} not found."
|
||||
schema = node_block.input_schema
|
||||
@@ -608,7 +608,10 @@ async def add_graph_execution_async(
|
||||
ValueError: If the graph is not found or if there are validation errors.
|
||||
""" # noqa
|
||||
graph: GraphModel | None = await get_graph(
|
||||
graph_id=graph_id, user_id=user_id, version=graph_version
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
version=graph_version,
|
||||
include_subgraphs=True,
|
||||
)
|
||||
if not graph:
|
||||
raise NotFoundError(f"Graph #{graph_id} not found.")
|
||||
@@ -668,6 +671,9 @@ def add_graph_execution(
|
||||
preset_id: Optional[str] = None,
|
||||
graph_version: Optional[int] = None,
|
||||
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
) -> GraphExecutionWithNodes:
|
||||
"""
|
||||
Adds a graph execution to the queue and returns the execution entry.
|
||||
@@ -680,6 +686,7 @@ def add_graph_execution(
|
||||
graph_version: The version of the graph to execute.
|
||||
graph_credentials_inputs: Credentials inputs to use in the execution.
|
||||
Keys should map to the keys generated by `GraphModel.aggregate_credentials_inputs`.
|
||||
node_credentials_input_map: Credentials inputs to use in the execution, mapped to specific nodes.
|
||||
Returns:
|
||||
GraphExecutionEntry: The entry for the graph execution.
|
||||
Raises:
|
||||
@@ -687,12 +694,15 @@ def add_graph_execution(
|
||||
"""
|
||||
db = get_db_client()
|
||||
graph: GraphModel | None = db.get_graph(
|
||||
graph_id=graph_id, user_id=user_id, version=graph_version
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
version=graph_version,
|
||||
include_subgraphs=True,
|
||||
)
|
||||
if not graph:
|
||||
raise NotFoundError(f"Graph #{graph_id} not found.")
|
||||
|
||||
node_credentials_input_map = (
|
||||
node_credentials_input_map = node_credentials_input_map or (
|
||||
make_node_credentials_input_map(graph, graph_credentials_inputs)
|
||||
if graph_credentials_inputs
|
||||
else None
|
||||
|
||||
@@ -27,6 +27,7 @@ import backend.server.v2.library.routes
|
||||
import backend.server.v2.otto.routes
|
||||
import backend.server.v2.store.model
|
||||
import backend.server.v2.store.routes
|
||||
import backend.server.v2.turnstile.routes
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
from backend.blocks.llm import LlmModel
|
||||
@@ -119,6 +120,9 @@ app.include_router(
|
||||
app.include_router(
|
||||
backend.server.v2.otto.routes.router, tags=["v2"], prefix="/api/otto"
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.v2.turnstile.routes.router, tags=["v2"], prefix="/api/turnstile"
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
backend.server.routers.postmark.postmark.router,
|
||||
|
||||
@@ -422,7 +422,11 @@ async def get_graph(
|
||||
for_export: bool = False,
|
||||
) -> graph_db.GraphModel:
|
||||
graph = await graph_db.get_graph(
|
||||
graph_id, version, user_id=user_id, for_export=for_export
|
||||
graph_id,
|
||||
version,
|
||||
user_id=user_id,
|
||||
for_export=for_export,
|
||||
include_subgraphs=True, # needed to construct full credentials input schema
|
||||
)
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TurnstileVerifyRequest(BaseModel):
|
||||
"""Request model for verifying a Turnstile token."""
|
||||
|
||||
token: str = Field(description="The Turnstile token to verify")
|
||||
action: Optional[str] = Field(
|
||||
default=None, description="The action that the user is attempting to perform"
|
||||
)
|
||||
|
||||
|
||||
class TurnstileVerifyResponse(BaseModel):
|
||||
"""Response model for the Turnstile verification endpoint."""
|
||||
|
||||
success: bool = Field(description="Whether the token verification was successful")
|
||||
error: Optional[str] = Field(
|
||||
default=None, description="Error message if verification failed"
|
||||
)
|
||||
challenge_timestamp: Optional[str] = Field(
|
||||
default=None, description="Timestamp of the challenge (ISO format)"
|
||||
)
|
||||
hostname: Optional[str] = Field(
|
||||
default=None, description="Hostname of the site where the challenge was solved"
|
||||
)
|
||||
action: Optional[str] = Field(
|
||||
default=None, description="The action associated with this verification"
|
||||
)
|
||||
108
autogpt_platform/backend/backend/server/v2/turnstile/routes.py
Normal file
108
autogpt_platform/backend/backend/server/v2/turnstile/routes.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import logging
|
||||
|
||||
import aiohttp
|
||||
from fastapi import APIRouter
|
||||
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .models import TurnstileVerifyRequest, TurnstileVerifyResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
settings = Settings()
|
||||
|
||||
|
||||
@router.post("/verify", response_model=TurnstileVerifyResponse)
|
||||
async def verify_turnstile_token(
|
||||
request: TurnstileVerifyRequest,
|
||||
) -> TurnstileVerifyResponse:
|
||||
"""
|
||||
Verify a Cloudflare Turnstile token.
|
||||
This endpoint verifies a token returned by the Cloudflare Turnstile challenge
|
||||
on the client side. It returns whether the verification was successful.
|
||||
"""
|
||||
logger.info(f"Verifying Turnstile token for action: {request.action}")
|
||||
return await verify_token(request)
|
||||
|
||||
|
||||
async def verify_token(request: TurnstileVerifyRequest) -> TurnstileVerifyResponse:
|
||||
"""
|
||||
Verify a Cloudflare Turnstile token by making a request to the Cloudflare API.
|
||||
"""
|
||||
# Get the secret key from settings
|
||||
turnstile_secret_key = settings.secrets.turnstile_secret_key
|
||||
turnstile_verify_url = settings.secrets.turnstile_verify_url
|
||||
|
||||
if not turnstile_secret_key:
|
||||
logger.error("Turnstile secret key is not configured")
|
||||
return TurnstileVerifyResponse(
|
||||
success=False,
|
||||
error="CONFIGURATION_ERROR",
|
||||
challenge_timestamp=None,
|
||||
hostname=None,
|
||||
action=None,
|
||||
)
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
payload = {
|
||||
"secret": turnstile_secret_key,
|
||||
"response": request.token,
|
||||
}
|
||||
|
||||
if request.action:
|
||||
payload["action"] = request.action
|
||||
|
||||
logger.debug(f"Verifying Turnstile token with action: {request.action}")
|
||||
|
||||
async with session.post(
|
||||
turnstile_verify_url,
|
||||
data=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=10),
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"Turnstile API error: {error_text}")
|
||||
return TurnstileVerifyResponse(
|
||||
success=False,
|
||||
error=f"API_ERROR: {response.status}",
|
||||
challenge_timestamp=None,
|
||||
hostname=None,
|
||||
action=None,
|
||||
)
|
||||
|
||||
data = await response.json()
|
||||
logger.debug(f"Turnstile API response: {data}")
|
||||
|
||||
# Parse the response and return a structured object
|
||||
return TurnstileVerifyResponse(
|
||||
success=data.get("success", False),
|
||||
error=(
|
||||
data.get("error-codes", None)[0]
|
||||
if data.get("error-codes")
|
||||
else None
|
||||
),
|
||||
challenge_timestamp=data.get("challenge_timestamp"),
|
||||
hostname=data.get("hostname"),
|
||||
action=data.get("action"),
|
||||
)
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"Connection error to Turnstile API: {str(e)}")
|
||||
return TurnstileVerifyResponse(
|
||||
success=False,
|
||||
error=f"CONNECTION_ERROR: {str(e)}",
|
||||
challenge_timestamp=None,
|
||||
hostname=None,
|
||||
action=None,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in Turnstile verification: {str(e)}")
|
||||
return TurnstileVerifyResponse(
|
||||
success=False,
|
||||
error=f"UNEXPECTED_ERROR: {str(e)}",
|
||||
challenge_timestamp=None,
|
||||
hostname=None,
|
||||
action=None,
|
||||
)
|
||||
@@ -350,6 +350,16 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
||||
description="The secret key to use for the unsubscribe user by token",
|
||||
)
|
||||
|
||||
# Cloudflare Turnstile credentials
|
||||
turnstile_secret_key: str = Field(
|
||||
default="",
|
||||
description="Cloudflare Turnstile backend secret key",
|
||||
)
|
||||
turnstile_verify_url: str = Field(
|
||||
default="https://challenges.cloudflare.com/turnstile/v0/siteverify",
|
||||
description="Cloudflare Turnstile verify URL",
|
||||
)
|
||||
|
||||
# OAuth server credentials for integrations
|
||||
# --8<-- [start:OAuthServerCredentialsExample]
|
||||
github_client_id: str = Field(default="", description="GitHub OAuth client ID")
|
||||
|
||||
@@ -34,7 +34,7 @@ async def spend_credits(entry: NodeExecutionEntry) -> int:
|
||||
if not block:
|
||||
raise RuntimeError(f"Block {entry.block_id} not found")
|
||||
|
||||
cost, matching_filter = block_usage_cost(block=block, input_data=entry.data)
|
||||
cost, matching_filter = block_usage_cost(block=block, input_data=entry.inputs)
|
||||
await user_credit.spend_credits(
|
||||
entry.user_id,
|
||||
cost,
|
||||
@@ -67,7 +67,7 @@ async def test_block_credit_usage(server: SpinTestServer):
|
||||
graph_exec_id="test_graph_exec",
|
||||
node_exec_id="test_node_exec",
|
||||
block_id=AITextGeneratorBlock().id,
|
||||
data={
|
||||
inputs={
|
||||
"model": "gpt-4-turbo",
|
||||
"credentials": {
|
||||
"id": openai_credentials.id,
|
||||
@@ -87,7 +87,7 @@ async def test_block_credit_usage(server: SpinTestServer):
|
||||
graph_exec_id="test_graph_exec",
|
||||
node_exec_id="test_node_exec",
|
||||
block_id=AITextGeneratorBlock().id,
|
||||
data={"model": "gpt-4-turbo", "api_key": "owned_api_key"},
|
||||
inputs={"model": "gpt-4-turbo", "api_key": "owned_api_key"},
|
||||
),
|
||||
)
|
||||
assert spending_amount_2 == 0
|
||||
|
||||
@@ -25,3 +25,8 @@ GA_MEASUREMENT_ID=G-FH2XK2W4GN
|
||||
# When running locally, set NEXT_PUBLIC_BEHAVE_AS=CLOUD to use the a locally hosted marketplace (as is typical in development, and the cloud deployment), otherwise set it to LOCAL to have the marketplace open in a new tab
|
||||
NEXT_PUBLIC_BEHAVE_AS=LOCAL
|
||||
NEXT_PUBLIC_SHOW_BILLING_PAGE=false
|
||||
|
||||
## Cloudflare Turnstile (CAPTCHA) Configuration
|
||||
## Get these from the Cloudflare Turnstile dashboard: https://dash.cloudflare.com/?to=/:account/turnstile
|
||||
## This is the frontend site key
|
||||
NEXT_PUBLIC_CLOUDFLARE_TURNSTILE_SITE_KEY=
|
||||
|
||||
@@ -6,6 +6,7 @@ import * as Sentry from "@sentry/nextjs";
|
||||
import getServerSupabase from "@/lib/supabase/getServerSupabase";
|
||||
import BackendAPI from "@/lib/autogpt-server-api";
|
||||
import { loginFormSchema, LoginProvider } from "@/types/auth";
|
||||
import { verifyTurnstileToken } from "@/lib/turnstile";
|
||||
|
||||
export async function logout() {
|
||||
return await Sentry.withServerActionInstrumentation(
|
||||
@@ -39,7 +40,10 @@ async function shouldShowOnboarding() {
|
||||
);
|
||||
}
|
||||
|
||||
export async function login(values: z.infer<typeof loginFormSchema>) {
|
||||
export async function login(
|
||||
values: z.infer<typeof loginFormSchema>,
|
||||
turnstileToken: string,
|
||||
) {
|
||||
return await Sentry.withServerActionInstrumentation("login", {}, async () => {
|
||||
const supabase = getServerSupabase();
|
||||
const api = new BackendAPI();
|
||||
@@ -48,6 +52,12 @@ export async function login(values: z.infer<typeof loginFormSchema>) {
|
||||
redirect("/error");
|
||||
}
|
||||
|
||||
// Verify Turnstile token if provided
|
||||
const success = await verifyTurnstileToken(turnstileToken, "login");
|
||||
if (!success) {
|
||||
return "CAPTCHA verification failed. Please try again.";
|
||||
}
|
||||
|
||||
// We are sure that the values are of the correct type because zod validates the form
|
||||
const { data, error } = await supabase.auth.signInWithPassword(values);
|
||||
|
||||
|
||||
@@ -24,9 +24,11 @@ import {
|
||||
AuthFeedback,
|
||||
AuthBottomText,
|
||||
PasswordInput,
|
||||
Turnstile,
|
||||
} from "@/components/auth";
|
||||
import { loginFormSchema } from "@/types/auth";
|
||||
import { getBehaveAs } from "@/lib/utils";
|
||||
import { useTurnstile } from "@/hooks/useTurnstile";
|
||||
|
||||
export default function LoginPage() {
|
||||
const { supabase, user, isUserLoading } = useSupabase();
|
||||
@@ -34,6 +36,12 @@ export default function LoginPage() {
|
||||
const router = useRouter();
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
|
||||
const turnstile = useTurnstile({
|
||||
action: "login",
|
||||
autoVerify: false,
|
||||
resetOnError: true,
|
||||
});
|
||||
|
||||
const form = useForm<z.infer<typeof loginFormSchema>>({
|
||||
resolver: zodResolver(loginFormSchema),
|
||||
defaultValues: {
|
||||
@@ -65,15 +73,23 @@ export default function LoginPage() {
|
||||
return;
|
||||
}
|
||||
|
||||
const error = await login(data);
|
||||
if (!turnstile.verified) {
|
||||
setFeedback("Please complete the CAPTCHA challenge.");
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const error = await login(data, turnstile.token as string);
|
||||
setIsLoading(false);
|
||||
if (error) {
|
||||
setFeedback(error);
|
||||
// Always reset the turnstile on any error
|
||||
turnstile.reset();
|
||||
return;
|
||||
}
|
||||
setFeedback(null);
|
||||
},
|
||||
[form],
|
||||
[form, turnstile],
|
||||
);
|
||||
|
||||
if (user) {
|
||||
@@ -140,6 +156,17 @@ export default function LoginPage() {
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
|
||||
{/* Turnstile CAPTCHA Component */}
|
||||
<Turnstile
|
||||
siteKey={turnstile.siteKey}
|
||||
onVerify={turnstile.handleVerify}
|
||||
onExpire={turnstile.handleExpire}
|
||||
onError={turnstile.handleError}
|
||||
action="login"
|
||||
shouldRender={turnstile.shouldRender}
|
||||
/>
|
||||
|
||||
<AuthButton
|
||||
onClick={() => onLogin(form.getValues())}
|
||||
isLoading={isLoading}
|
||||
|
||||
@@ -3,8 +3,9 @@ import getServerSupabase from "@/lib/supabase/getServerSupabase";
|
||||
import { redirect } from "next/navigation";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
import { headers } from "next/headers";
|
||||
import { verifyTurnstileToken } from "@/lib/turnstile";
|
||||
|
||||
export async function sendResetEmail(email: string) {
|
||||
export async function sendResetEmail(email: string, turnstileToken: string) {
|
||||
return await Sentry.withServerActionInstrumentation(
|
||||
"sendResetEmail",
|
||||
{},
|
||||
@@ -20,6 +21,15 @@ export async function sendResetEmail(email: string) {
|
||||
redirect("/error");
|
||||
}
|
||||
|
||||
// Verify Turnstile token if provided
|
||||
const success = await verifyTurnstileToken(
|
||||
turnstileToken,
|
||||
"reset_password",
|
||||
);
|
||||
if (!success) {
|
||||
return "CAPTCHA verification failed. Please try again.";
|
||||
}
|
||||
|
||||
const { error } = await supabase.auth.resetPasswordForEmail(email, {
|
||||
redirectTo: `${origin}/reset_password`,
|
||||
});
|
||||
@@ -34,7 +44,7 @@ export async function sendResetEmail(email: string) {
|
||||
);
|
||||
}
|
||||
|
||||
export async function changePassword(password: string) {
|
||||
export async function changePassword(password: string, turnstileToken: string) {
|
||||
return await Sentry.withServerActionInstrumentation(
|
||||
"changePassword",
|
||||
{},
|
||||
@@ -45,6 +55,15 @@ export async function changePassword(password: string) {
|
||||
redirect("/error");
|
||||
}
|
||||
|
||||
// Verify Turnstile token if provided
|
||||
const success = await verifyTurnstileToken(
|
||||
turnstileToken,
|
||||
"change_password",
|
||||
);
|
||||
if (!success) {
|
||||
return "CAPTCHA verification failed. Please try again.";
|
||||
}
|
||||
|
||||
const { error } = await supabase.auth.updateUser({ password });
|
||||
|
||||
if (error) {
|
||||
|
||||
@@ -5,6 +5,7 @@ import {
|
||||
AuthButton,
|
||||
AuthFeedback,
|
||||
PasswordInput,
|
||||
Turnstile,
|
||||
} from "@/components/auth";
|
||||
import {
|
||||
Form,
|
||||
@@ -25,6 +26,7 @@ import { z } from "zod";
|
||||
import { changePassword, sendResetEmail } from "./actions";
|
||||
import Spinner from "@/components/Spinner";
|
||||
import { getBehaveAs } from "@/lib/utils";
|
||||
import { useTurnstile } from "@/hooks/useTurnstile";
|
||||
|
||||
export default function ResetPasswordPage() {
|
||||
const { supabase, user, isUserLoading } = useSupabase();
|
||||
@@ -33,6 +35,18 @@ export default function ResetPasswordPage() {
|
||||
const [isError, setIsError] = useState(false);
|
||||
const [disabled, setDisabled] = useState(false);
|
||||
|
||||
const sendEmailTurnstile = useTurnstile({
|
||||
action: "reset_password",
|
||||
autoVerify: false,
|
||||
resetOnError: true,
|
||||
});
|
||||
|
||||
const changePasswordTurnstile = useTurnstile({
|
||||
action: "change_password",
|
||||
autoVerify: false,
|
||||
resetOnError: true,
|
||||
});
|
||||
|
||||
const sendEmailForm = useForm<z.infer<typeof sendEmailFormSchema>>({
|
||||
resolver: zodResolver(sendEmailFormSchema),
|
||||
defaultValues: {
|
||||
@@ -58,11 +72,22 @@ export default function ResetPasswordPage() {
|
||||
return;
|
||||
}
|
||||
|
||||
const error = await sendResetEmail(data.email);
|
||||
if (!sendEmailTurnstile.verified) {
|
||||
setFeedback("Please complete the CAPTCHA challenge.");
|
||||
setIsError(true);
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const error = await sendResetEmail(
|
||||
data.email,
|
||||
sendEmailTurnstile.token as string,
|
||||
);
|
||||
setIsLoading(false);
|
||||
if (error) {
|
||||
setFeedback(error);
|
||||
setIsError(true);
|
||||
sendEmailTurnstile.reset();
|
||||
return;
|
||||
}
|
||||
setDisabled(true);
|
||||
@@ -71,7 +96,7 @@ export default function ResetPasswordPage() {
|
||||
);
|
||||
setIsError(false);
|
||||
},
|
||||
[sendEmailForm],
|
||||
[sendEmailForm, sendEmailTurnstile],
|
||||
);
|
||||
|
||||
const onChangePassword = useCallback(
|
||||
@@ -84,17 +109,28 @@ export default function ResetPasswordPage() {
|
||||
return;
|
||||
}
|
||||
|
||||
const error = await changePassword(data.password);
|
||||
if (!changePasswordTurnstile.verified) {
|
||||
setFeedback("Please complete the CAPTCHA challenge.");
|
||||
setIsError(true);
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const error = await changePassword(
|
||||
data.password,
|
||||
changePasswordTurnstile.token as string,
|
||||
);
|
||||
setIsLoading(false);
|
||||
if (error) {
|
||||
setFeedback(error);
|
||||
setIsError(true);
|
||||
changePasswordTurnstile.reset();
|
||||
return;
|
||||
}
|
||||
setFeedback("Password changed successfully. Redirecting to login.");
|
||||
setIsError(false);
|
||||
},
|
||||
[changePasswordForm],
|
||||
[changePasswordForm, changePasswordTurnstile],
|
||||
);
|
||||
|
||||
if (isUserLoading) {
|
||||
@@ -145,6 +181,17 @@ export default function ResetPasswordPage() {
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
|
||||
{/* Turnstile CAPTCHA Component for password change */}
|
||||
<Turnstile
|
||||
siteKey={changePasswordTurnstile.siteKey}
|
||||
onVerify={changePasswordTurnstile.handleVerify}
|
||||
onExpire={changePasswordTurnstile.handleExpire}
|
||||
onError={changePasswordTurnstile.handleError}
|
||||
action="change_password"
|
||||
shouldRender={changePasswordTurnstile.shouldRender}
|
||||
/>
|
||||
|
||||
<AuthButton
|
||||
onClick={() => onChangePassword(changePasswordForm.getValues())}
|
||||
isLoading={isLoading}
|
||||
@@ -175,6 +222,17 @@ export default function ResetPasswordPage() {
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
|
||||
{/* Turnstile CAPTCHA Component for reset email */}
|
||||
<Turnstile
|
||||
siteKey={sendEmailTurnstile.siteKey}
|
||||
onVerify={sendEmailTurnstile.handleVerify}
|
||||
onExpire={sendEmailTurnstile.handleExpire}
|
||||
onError={sendEmailTurnstile.handleError}
|
||||
action="reset_password"
|
||||
shouldRender={sendEmailTurnstile.shouldRender}
|
||||
/>
|
||||
|
||||
<AuthButton
|
||||
onClick={() => onSendEmail(sendEmailForm.getValues())}
|
||||
isLoading={isLoading}
|
||||
|
||||
@@ -6,8 +6,12 @@ import * as Sentry from "@sentry/nextjs";
|
||||
import getServerSupabase from "@/lib/supabase/getServerSupabase";
|
||||
import { signupFormSchema } from "@/types/auth";
|
||||
import BackendAPI from "@/lib/autogpt-server-api";
|
||||
import { verifyTurnstileToken } from "@/lib/turnstile";
|
||||
|
||||
export async function signup(values: z.infer<typeof signupFormSchema>) {
|
||||
export async function signup(
|
||||
values: z.infer<typeof signupFormSchema>,
|
||||
turnstileToken: string,
|
||||
) {
|
||||
"use server";
|
||||
return await Sentry.withServerActionInstrumentation(
|
||||
"signup",
|
||||
@@ -19,6 +23,12 @@ export async function signup(values: z.infer<typeof signupFormSchema>) {
|
||||
redirect("/error");
|
||||
}
|
||||
|
||||
// Verify Turnstile token if provided
|
||||
const success = await verifyTurnstileToken(turnstileToken, "signup");
|
||||
if (!success) {
|
||||
return "CAPTCHA verification failed. Please try again.";
|
||||
}
|
||||
|
||||
// We are sure that the values are of the correct type because zod validates the form
|
||||
const { data, error } = await supabase.auth.signUp(values);
|
||||
|
||||
|
||||
@@ -25,10 +25,12 @@ import {
|
||||
AuthButton,
|
||||
AuthBottomText,
|
||||
PasswordInput,
|
||||
Turnstile,
|
||||
} from "@/components/auth";
|
||||
import AuthFeedback from "@/components/auth/AuthFeedback";
|
||||
import { signupFormSchema } from "@/types/auth";
|
||||
import { getBehaveAs } from "@/lib/utils";
|
||||
import { useTurnstile } from "@/hooks/useTurnstile";
|
||||
|
||||
export default function SignupPage() {
|
||||
const { supabase, user, isUserLoading } = useSupabase();
|
||||
@@ -37,6 +39,12 @@ export default function SignupPage() {
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
//TODO: Remove after closed beta
|
||||
|
||||
const turnstile = useTurnstile({
|
||||
action: "signup",
|
||||
autoVerify: false,
|
||||
resetOnError: true,
|
||||
});
|
||||
|
||||
const form = useForm<z.infer<typeof signupFormSchema>>({
|
||||
resolver: zodResolver(signupFormSchema),
|
||||
defaultValues: {
|
||||
@@ -56,20 +64,28 @@ export default function SignupPage() {
|
||||
return;
|
||||
}
|
||||
|
||||
const error = await signup(data);
|
||||
if (!turnstile.verified) {
|
||||
setFeedback("Please complete the CAPTCHA challenge.");
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const error = await signup(data, turnstile.token as string);
|
||||
setIsLoading(false);
|
||||
if (error) {
|
||||
if (error === "user_already_exists") {
|
||||
setFeedback("User with this email already exists");
|
||||
turnstile.reset();
|
||||
return;
|
||||
} else {
|
||||
setFeedback(error);
|
||||
turnstile.reset();
|
||||
}
|
||||
return;
|
||||
}
|
||||
setFeedback(null);
|
||||
},
|
||||
[form],
|
||||
[form, turnstile],
|
||||
);
|
||||
|
||||
if (user) {
|
||||
@@ -141,6 +157,17 @@ export default function SignupPage() {
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
|
||||
{/* Turnstile CAPTCHA Component */}
|
||||
<Turnstile
|
||||
siteKey={turnstile.siteKey}
|
||||
onVerify={turnstile.handleVerify}
|
||||
onExpire={turnstile.handleExpire}
|
||||
onError={turnstile.handleError}
|
||||
action="signup"
|
||||
shouldRender={turnstile.shouldRender}
|
||||
/>
|
||||
|
||||
<AuthButton
|
||||
onClick={() => onSignup(form.getValues())}
|
||||
isLoading={isLoading}
|
||||
|
||||
140
autogpt_platform/frontend/src/components/auth/Turnstile.tsx
Normal file
140
autogpt_platform/frontend/src/components/auth/Turnstile.tsx
Normal file
@@ -0,0 +1,140 @@
|
||||
"use client";
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
export interface TurnstileProps {
|
||||
siteKey: string;
|
||||
onVerify: (token: string) => void;
|
||||
onExpire?: () => void;
|
||||
onError?: (error: Error) => void;
|
||||
action?: string;
|
||||
className?: string;
|
||||
id?: string;
|
||||
shouldRender?: boolean;
|
||||
}
|
||||
|
||||
export function Turnstile({
|
||||
siteKey,
|
||||
onVerify,
|
||||
onExpire,
|
||||
onError,
|
||||
action,
|
||||
className,
|
||||
id = "cf-turnstile",
|
||||
shouldRender = true,
|
||||
}: TurnstileProps) {
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const widgetIdRef = useRef<string | null>(null);
|
||||
const [loaded, setLoaded] = useState(false);
|
||||
|
||||
// Load the Turnstile script
|
||||
useEffect(() => {
|
||||
if (typeof window === "undefined" || !shouldRender) return;
|
||||
|
||||
// Skip if already loaded
|
||||
if (window.turnstile) {
|
||||
setLoaded(true);
|
||||
return;
|
||||
}
|
||||
|
||||
// Create script element
|
||||
const script = document.createElement("script");
|
||||
script.src =
|
||||
"https://challenges.cloudflare.com/turnstile/v0/api.js?render=explicit";
|
||||
script.async = true;
|
||||
script.defer = true;
|
||||
|
||||
script.onload = () => {
|
||||
setLoaded(true);
|
||||
};
|
||||
|
||||
script.onerror = () => {
|
||||
onError?.(new Error("Failed to load Turnstile script"));
|
||||
};
|
||||
|
||||
document.head.appendChild(script);
|
||||
|
||||
return () => {
|
||||
if (document.head.contains(script)) {
|
||||
document.head.removeChild(script);
|
||||
}
|
||||
};
|
||||
}, [onError, shouldRender]);
|
||||
|
||||
// Initialize and render the widget when script is loaded
|
||||
useEffect(() => {
|
||||
if (!loaded || !containerRef.current || !window.turnstile || !shouldRender)
|
||||
return;
|
||||
|
||||
// Reset any existing widget
|
||||
if (widgetIdRef.current && window.turnstile) {
|
||||
window.turnstile.reset(widgetIdRef.current);
|
||||
}
|
||||
|
||||
// Render a new widget
|
||||
if (window.turnstile) {
|
||||
widgetIdRef.current = window.turnstile.render(containerRef.current, {
|
||||
sitekey: siteKey,
|
||||
callback: (token: string) => {
|
||||
onVerify(token);
|
||||
},
|
||||
"expired-callback": () => {
|
||||
onExpire?.();
|
||||
},
|
||||
"error-callback": () => {
|
||||
onError?.(new Error("Turnstile widget encountered an error"));
|
||||
},
|
||||
action,
|
||||
});
|
||||
}
|
||||
|
||||
return () => {
|
||||
if (widgetIdRef.current && window.turnstile) {
|
||||
window.turnstile.remove(widgetIdRef.current);
|
||||
widgetIdRef.current = null;
|
||||
}
|
||||
};
|
||||
}, [loaded, siteKey, onVerify, onExpire, onError, action, shouldRender]);
|
||||
|
||||
// Method to reset the widget manually
|
||||
const reset = useCallback(() => {
|
||||
if (loaded && widgetIdRef.current && window.turnstile && shouldRender) {
|
||||
window.turnstile.reset(widgetIdRef.current);
|
||||
}
|
||||
}, [loaded, shouldRender]);
|
||||
|
||||
// If shouldRender is false, don't render anything
|
||||
if (!shouldRender) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
id={id}
|
||||
ref={containerRef}
|
||||
className={cn("my-4 flex items-center justify-center", className)}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
// Add TypeScript interface to Window to include turnstile property
|
||||
declare global {
|
||||
interface Window {
|
||||
turnstile?: {
|
||||
render: (
|
||||
container: HTMLElement,
|
||||
options: {
|
||||
sitekey: string;
|
||||
callback: (token: string) => void;
|
||||
"expired-callback"?: () => void;
|
||||
"error-callback"?: () => void;
|
||||
action?: string;
|
||||
},
|
||||
) => string;
|
||||
reset: (widgetId: string) => void;
|
||||
remove: (widgetId: string) => void;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
export default Turnstile;
|
||||
@@ -4,6 +4,7 @@ import AuthCard from "./AuthCard";
|
||||
import AuthFeedback from "./AuthFeedback";
|
||||
import AuthHeader from "./AuthHeader";
|
||||
import { PasswordInput } from "./PasswordInput";
|
||||
import Turnstile from "./Turnstile";
|
||||
|
||||
export {
|
||||
AuthBottomText,
|
||||
@@ -12,4 +13,5 @@ export {
|
||||
AuthFeedback,
|
||||
AuthHeader,
|
||||
PasswordInput,
|
||||
Turnstile,
|
||||
};
|
||||
|
||||
169
autogpt_platform/frontend/src/hooks/useTurnstile.ts
Normal file
169
autogpt_platform/frontend/src/hooks/useTurnstile.ts
Normal file
@@ -0,0 +1,169 @@
|
||||
import { useState, useCallback, useEffect } from "react";
|
||||
import { verifyTurnstileToken } from "@/lib/turnstile";
|
||||
import { BehaveAs, getBehaveAs } from "@/lib/utils";
|
||||
|
||||
interface UseTurnstileOptions {
|
||||
action?: string;
|
||||
autoVerify?: boolean;
|
||||
onSuccess?: () => void;
|
||||
onError?: (error: Error) => void;
|
||||
resetOnError?: boolean;
|
||||
}
|
||||
|
||||
interface UseTurnstileResult {
|
||||
token: string | null;
|
||||
verifying: boolean;
|
||||
verified: boolean;
|
||||
error: Error | null;
|
||||
handleVerify: (token: string) => Promise<boolean>;
|
||||
handleExpire: () => void;
|
||||
handleError: (error: Error) => void;
|
||||
reset: () => void;
|
||||
siteKey: string;
|
||||
shouldRender: boolean;
|
||||
}
|
||||
|
||||
const TURNSTILE_SITE_KEY =
|
||||
process.env.NEXT_PUBLIC_CLOUDFLARE_TURNSTILE_SITE_KEY || "";
|
||||
|
||||
/**
|
||||
* Custom hook for managing Turnstile state in forms
|
||||
*/
|
||||
export function useTurnstile({
|
||||
action,
|
||||
autoVerify = true,
|
||||
onSuccess,
|
||||
onError,
|
||||
resetOnError = false,
|
||||
}: UseTurnstileOptions = {}): UseTurnstileResult {
|
||||
const [token, setToken] = useState<string | null>(null);
|
||||
const [verifying, setVerifying] = useState(false);
|
||||
const [verified, setVerified] = useState(false);
|
||||
const [error, setError] = useState<Error | null>(null);
|
||||
const [shouldRender, setShouldRender] = useState(false);
|
||||
const [widgetId, setWidgetId] = useState<string | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
const behaveAs = getBehaveAs();
|
||||
const hasTurnstileKey = !!TURNSTILE_SITE_KEY;
|
||||
|
||||
setShouldRender(behaveAs === BehaveAs.CLOUD && hasTurnstileKey);
|
||||
|
||||
if (behaveAs !== BehaveAs.CLOUD || !hasTurnstileKey) {
|
||||
setVerified(true);
|
||||
}
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (token && !autoVerify && shouldRender) {
|
||||
setVerified(true);
|
||||
}
|
||||
}, [token, autoVerify, shouldRender]);
|
||||
|
||||
useEffect(() => {
|
||||
if (typeof window !== "undefined" && window.turnstile) {
|
||||
const originalRender = window.turnstile.render;
|
||||
window.turnstile.render = (container, options) => {
|
||||
const id = originalRender(container, options);
|
||||
setWidgetId(id);
|
||||
return id;
|
||||
};
|
||||
}
|
||||
}, []);
|
||||
|
||||
const reset = useCallback(() => {
|
||||
if (shouldRender && window.turnstile && widgetId) {
|
||||
window.turnstile.reset(widgetId);
|
||||
|
||||
// Always reset the state when reset is called
|
||||
setToken(null);
|
||||
setVerified(false);
|
||||
setVerifying(false);
|
||||
setError(null);
|
||||
}
|
||||
}, [shouldRender, widgetId]);
|
||||
|
||||
const handleVerify = useCallback(
|
||||
async (newToken: string) => {
|
||||
if (!shouldRender) {
|
||||
return true;
|
||||
}
|
||||
|
||||
setToken(newToken);
|
||||
setError(null);
|
||||
|
||||
if (autoVerify) {
|
||||
setVerifying(true);
|
||||
|
||||
try {
|
||||
const success = await verifyTurnstileToken(newToken, action);
|
||||
setVerified(success);
|
||||
|
||||
if (success && onSuccess) {
|
||||
onSuccess();
|
||||
} else if (!success) {
|
||||
const newError = new Error("Turnstile verification failed");
|
||||
setError(newError);
|
||||
if (onError) onError(newError);
|
||||
if (resetOnError) {
|
||||
setVerified(false);
|
||||
}
|
||||
}
|
||||
|
||||
setVerifying(false);
|
||||
return success;
|
||||
} catch (err) {
|
||||
const newError =
|
||||
err instanceof Error
|
||||
? err
|
||||
: new Error("Unknown error during verification");
|
||||
setError(newError);
|
||||
if (resetOnError) {
|
||||
setVerified(false);
|
||||
}
|
||||
setVerifying(false);
|
||||
if (onError) onError(newError);
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
setVerified(true);
|
||||
}
|
||||
|
||||
return true;
|
||||
},
|
||||
[action, autoVerify, onSuccess, onError, resetOnError, shouldRender],
|
||||
);
|
||||
|
||||
const handleExpire = useCallback(() => {
|
||||
if (shouldRender) {
|
||||
setToken(null);
|
||||
setVerified(false);
|
||||
}
|
||||
}, [shouldRender]);
|
||||
|
||||
const handleError = useCallback(
|
||||
(err: Error) => {
|
||||
if (shouldRender) {
|
||||
setError(err);
|
||||
if (resetOnError) {
|
||||
setVerified(false);
|
||||
}
|
||||
if (onError) onError(err);
|
||||
}
|
||||
},
|
||||
[onError, shouldRender, resetOnError],
|
||||
);
|
||||
|
||||
return {
|
||||
token,
|
||||
verifying,
|
||||
verified,
|
||||
error,
|
||||
handleVerify,
|
||||
handleExpire,
|
||||
handleError,
|
||||
reset,
|
||||
siteKey: TURNSTILE_SITE_KEY,
|
||||
shouldRender,
|
||||
};
|
||||
}
|
||||
42
autogpt_platform/frontend/src/lib/turnstile.ts
Normal file
42
autogpt_platform/frontend/src/lib/turnstile.ts
Normal file
@@ -0,0 +1,42 @@
|
||||
/**
|
||||
* Utility functions for working with Cloudflare Turnstile
|
||||
*/
|
||||
import { BehaveAs, getBehaveAs } from "@/lib/utils";
|
||||
|
||||
export async function verifyTurnstileToken(
|
||||
token: string,
|
||||
action?: string,
|
||||
): Promise<boolean> {
|
||||
// Skip verification in local development
|
||||
const behaveAs = getBehaveAs();
|
||||
if (behaveAs !== BehaveAs.CLOUD) {
|
||||
return true;
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(
|
||||
`${process.env.NEXT_PUBLIC_AGPT_SERVER_URL}/turnstile/verify`,
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
token,
|
||||
action,
|
||||
}),
|
||||
},
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
console.error("Turnstile verification failed:", await response.text());
|
||||
return false;
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
return data.success === true;
|
||||
} catch (error) {
|
||||
console.error("Error verifying Turnstile token:", error);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user