Merge branch 'dev' into abhi/ci-chromatic

This commit is contained in:
Abhimanyu Yadav
2025-05-08 11:45:11 +05:30
committed by GitHub
24 changed files with 776 additions and 49 deletions

View File

@@ -66,6 +66,13 @@ MEDIA_GCS_BUCKET_NAME=
## and tunnel it to your locally running backend.
PLATFORM_BASE_URL=http://localhost:3000
## Cloudflare Turnstile (CAPTCHA) Configuration
## Get these from the Cloudflare Turnstile dashboard: https://dash.cloudflare.com/?to=/:account/turnstile
## This is the backend secret key
TURNSTILE_SECRET_KEY=
## This is the verify URL
TURNSTILE_VERIFY_URL=https://challenges.cloudflare.com/turnstile/v0/siteverify
## == INTEGRATION CREDENTIALS == ##
# Each set of server side credentials is required for the corresponding 3rd party
# integration to work.

View File

@@ -1,5 +1,5 @@
import logging
from typing import Any
from typing import Any, Optional
from backend.data.block import (
Block,
@@ -11,7 +11,7 @@ from backend.data.block import (
get_block,
)
from backend.data.execution import ExecutionStatus
from backend.data.model import SchemaField
from backend.data.model import CredentialsMetaInput, SchemaField
from backend.util import json
logger = logging.getLogger(__name__)
@@ -23,17 +23,21 @@ class AgentExecutorBlock(Block):
graph_id: str = SchemaField(description="Graph ID")
graph_version: int = SchemaField(description="Graph Version")
data: BlockInput = SchemaField(description="Input data for the graph")
inputs: BlockInput = SchemaField(description="Input data for the graph")
input_schema: dict = SchemaField(description="Input schema for the graph")
output_schema: dict = SchemaField(description="Output schema for the graph")
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = SchemaField(default=None, hidden=True)
@classmethod
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
return data.get("input_schema", {})
@classmethod
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
return data.get("data", {})
return data.get("inputs", {})
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
@@ -67,7 +71,8 @@ class AgentExecutorBlock(Block):
graph_id=input_data.graph_id,
graph_version=input_data.graph_version,
user_id=input_data.user_id,
inputs=input_data.data,
inputs=input_data.inputs,
node_credentials_input_map=input_data.node_credentials_input_map,
)
log_id = f"Graph #{input_data.graph_id}-V{input_data.graph_version}, exec-id: {graph_exec.id}"
logger.info(f"Starting execution of {log_id}")

View File

@@ -276,7 +276,7 @@ class NodeExecutionResult(BaseModel):
node_exec_id=self.node_exec_id,
node_id=self.node_id,
block_id=self.block_id,
data=self.input_data,
inputs=self.input_data,
)
@@ -691,7 +691,7 @@ class NodeExecutionEntry(BaseModel):
node_exec_id: str
node_id: str
block_id: str
data: BlockInput
inputs: BlockInput
class ExecutionQueue(Generic[T]):

View File

@@ -199,11 +199,6 @@ class BaseGraph(BaseDbModel):
)
)
@computed_field
@property
def credentials_input_schema(self) -> dict[str, Any]:
return self._credentials_input_schema.jsonschema()
@staticmethod
def _generate_schema(
*props: tuple[type[AgentInputBlock.Input] | type[AgentOutputBlock.Input], dict],
@@ -236,6 +231,15 @@ class BaseGraph(BaseDbModel):
"required": [p.name for p in schema_fields if p.value is None],
}
class Graph(BaseGraph):
sub_graphs: list[BaseGraph] = [] # Flattened sub-graphs
@computed_field
@property
def credentials_input_schema(self) -> dict[str, Any]:
return self._credentials_input_schema.jsonschema()
@property
def _credentials_input_schema(self) -> type[BlockSchema]:
graph_credentials_inputs = self.aggregate_credentials_inputs()
@@ -314,17 +318,14 @@ class BaseGraph(BaseDbModel):
),
(node.id, field_name),
)
for node in self.nodes
for graph in [self] + self.sub_graphs
for node in graph.nodes
for field_name, field_info in node.block.input_schema.get_credentials_fields_info().items()
)
)
}
class Graph(BaseGraph):
sub_graphs: list[BaseGraph] = [] # Flattened sub-graphs, only used in export
class GraphModel(Graph):
user_id: str
nodes: list[NodeModel] = [] # type: ignore
@@ -400,7 +401,7 @@ class GraphModel(Graph):
if node.block_id != AgentExecutorBlock().id:
continue
node.input_default["user_id"] = user_id
node.input_default.setdefault("data", {})
node.input_default.setdefault("inputs", {})
if (graph_id := node.input_default.get("graph_id")) in graph_id_map:
node.input_default["graph_id"] = graph_id_map[graph_id]
@@ -689,6 +690,7 @@ async def get_graph(
version: int | None = None,
user_id: str | None = None,
for_export: bool = False,
include_subgraphs: bool = False,
) -> GraphModel | None:
"""
Retrieves a graph from the DB.
@@ -725,7 +727,7 @@ async def get_graph(
):
return None
if for_export:
if include_subgraphs or for_export:
sub_graphs = await get_sub_graphs(graph)
return GraphModel.from_db(
graph=graph,

View File

@@ -8,14 +8,18 @@ import threading
from concurrent.futures import Future, ProcessPoolExecutor
from contextlib import contextmanager
from multiprocessing.pool import AsyncResult, Pool
from typing import TYPE_CHECKING, Any, Generator, TypeVar, cast
from typing import TYPE_CHECKING, Any, Generator, Optional, TypeVar, cast
from pika.adapters.blocking_connection import BlockingChannel
from pika.spec import Basic, BasicProperties
from redis.lock import Lock as RedisLock
from backend.blocks.io import AgentOutputBlock
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.model import (
CredentialsMetaInput,
GraphExecutionStats,
NodeExecutionStats,
)
from backend.data.notifications import (
AgentRunData,
LowBalanceData,
@@ -139,6 +143,9 @@ def execute_node(
creds_manager: IntegrationCredentialsManager,
data: NodeExecutionEntry,
execution_stats: NodeExecutionStats | None = None,
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = None,
) -> ExecutionStream:
"""
Execute a node in the graph. This will trigger a block execution on a node,
@@ -186,7 +193,7 @@ def execute_node(
)
# Sanity check: validate the execution input.
input_data, error = validate_exec(node, data.data, resolve_input=False)
input_data, error = validate_exec(node, data.inputs, resolve_input=False)
if input_data is None:
log_metadata.error(f"Skip execution, input validation error: {error}")
push_output("error", error)
@@ -196,8 +203,12 @@ def execute_node(
# Re-shape the input data for agent block.
# AgentExecutorBlock specially separate the node input_data & its input_default.
if isinstance(node_block, AgentExecutorBlock):
input_data = {**node.input_default, "data": input_data}
data.data = input_data
_input_data = AgentExecutorBlock.Input(**node.input_default)
_input_data.inputs = input_data
if node_credentials_input_map:
_input_data.node_credentials_input_map = node_credentials_input_map
input_data = _input_data.model_dump()
data.inputs = input_data
# Execute the node
input_data_str = json.dumps(input_data)
@@ -244,6 +255,7 @@ def execute_node(
graph_exec_id=graph_exec_id,
graph_id=graph_id,
log_metadata=log_metadata,
node_credentials_input_map=node_credentials_input_map,
):
yield execution
@@ -262,6 +274,7 @@ def execute_node(
graph_exec_id=graph_exec_id,
graph_id=graph_id,
log_metadata=log_metadata,
node_credentials_input_map=node_credentials_input_map,
):
yield execution
@@ -291,6 +304,7 @@ def _enqueue_next_nodes(
graph_exec_id: str,
graph_id: str,
log_metadata: LogMetadata,
node_credentials_input_map: Optional[dict[str, dict[str, CredentialsMetaInput]]],
) -> list[NodeExecutionEntry]:
def add_enqueued_execution(
node_exec_id: str, node_id: str, block_id: str, data: BlockInput
@@ -306,7 +320,7 @@ def _enqueue_next_nodes(
node_exec_id=node_exec_id,
node_id=node_id,
block_id=block_id,
data=data,
inputs=data,
)
def register_next_executions(node_link: Link) -> list[NodeExecutionEntry]:
@@ -347,6 +361,15 @@ def _enqueue_next_nodes(
for name in static_link_names:
next_node_input[name] = latest_execution.input_data.get(name)
# Apply node credentials overrides
node_credentials = None
if node_credentials_input_map and (
node_credentials := node_credentials_input_map.get(next_node.id)
):
next_node_input.update(
{k: v.model_dump() for k, v in node_credentials.items()}
)
# Validate the input data for the next node.
next_node_input, validation_msg = validate_exec(next_node, next_node_input)
suffix = f"{next_output_name}>{next_input_name}~{next_node_exec_id}:{validation_msg}"
@@ -389,6 +412,12 @@ def _enqueue_next_nodes(
for input_name in static_link_names:
idata[input_name] = next_node_input[input_name]
# Apply node credentials overrides
if node_credentials:
idata.update(
{k: v.model_dump() for k, v in node_credentials.items()}
)
idata, msg = validate_exec(next_node, idata)
suffix = f"{next_output_name}>{next_input_name}~{ineid}:{msg}"
if not idata:
@@ -478,6 +507,9 @@ class Executor:
cls,
q: ExecutionQueue[NodeExecutionEntry],
node_exec: NodeExecutionEntry,
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = None,
) -> NodeExecutionStats:
log_metadata = LogMetadata(
user_id=node_exec.user_id,
@@ -490,7 +522,7 @@ class Executor:
execution_stats = NodeExecutionStats()
timing_info, _ = cls._on_node_execution(
q, node_exec, log_metadata, execution_stats
q, node_exec, log_metadata, execution_stats, node_credentials_input_map
)
execution_stats.walltime = timing_info.wall_time
execution_stats.cputime = timing_info.cpu_time
@@ -510,6 +542,9 @@ class Executor:
node_exec: NodeExecutionEntry,
log_metadata: LogMetadata,
stats: NodeExecutionStats | None = None,
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = None,
):
try:
log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
@@ -518,6 +553,7 @@ class Executor:
creds_manager=cls.creds_manager,
data=node_exec,
execution_stats=stats,
node_credentials_input_map=node_credentials_input_map,
):
q.add(execution)
log_metadata.info(f"Finished node execution {node_exec.node_exec_id}")
@@ -625,7 +661,9 @@ class Executor:
logger.error(f"Block {node_exec.block_id} not found.")
return
cost, matching_filter = block_usage_cost(block=block, input_data=node_exec.data)
cost, matching_filter = block_usage_cost(
block=block, input_data=node_exec.inputs
)
if cost > 0:
cls.db_client.spend_credits(
user_id=node_exec.user_id,
@@ -797,7 +835,7 @@ class Executor:
if (node_creds_map := graph_exec.node_credentials_input_map) and (
node_field_creds_map := node_creds_map.get(node_id)
):
queued_node_exec.data.update(
queued_node_exec.inputs.update(
{
field_name: creds_meta.model_dump()
for field_name, creds_meta in node_field_creds_map.items()
@@ -807,7 +845,7 @@ class Executor:
# Initiate node execution
running_executions[queued_node_exec.node_id] = cls.executor.apply_async(
cls.on_node_execution,
(queue, queued_node_exec),
(queue, queued_node_exec, node_creds_map),
callback=make_exec_callback(queued_node_exec),
)

View File

@@ -258,7 +258,7 @@ def validate_exec(
If the data is valid, the first element will be the resolved input data, and
the second element will be the block name.
"""
node_block: Block | None = get_block(node.block_id)
node_block = get_block(node.block_id)
if not node_block:
return None, f"Block for {node.block_id} not found."
schema = node_block.input_schema
@@ -608,7 +608,10 @@ async def add_graph_execution_async(
ValueError: If the graph is not found or if there are validation errors.
""" # noqa
graph: GraphModel | None = await get_graph(
graph_id=graph_id, user_id=user_id, version=graph_version
graph_id=graph_id,
user_id=user_id,
version=graph_version,
include_subgraphs=True,
)
if not graph:
raise NotFoundError(f"Graph #{graph_id} not found.")
@@ -668,6 +671,9 @@ def add_graph_execution(
preset_id: Optional[str] = None,
graph_version: Optional[int] = None,
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = None,
) -> GraphExecutionWithNodes:
"""
Adds a graph execution to the queue and returns the execution entry.
@@ -680,6 +686,7 @@ def add_graph_execution(
graph_version: The version of the graph to execute.
graph_credentials_inputs: Credentials inputs to use in the execution.
Keys should map to the keys generated by `GraphModel.aggregate_credentials_inputs`.
node_credentials_input_map: Credentials inputs to use in the execution, mapped to specific nodes.
Returns:
GraphExecutionEntry: The entry for the graph execution.
Raises:
@@ -687,12 +694,15 @@ def add_graph_execution(
"""
db = get_db_client()
graph: GraphModel | None = db.get_graph(
graph_id=graph_id, user_id=user_id, version=graph_version
graph_id=graph_id,
user_id=user_id,
version=graph_version,
include_subgraphs=True,
)
if not graph:
raise NotFoundError(f"Graph #{graph_id} not found.")
node_credentials_input_map = (
node_credentials_input_map = node_credentials_input_map or (
make_node_credentials_input_map(graph, graph_credentials_inputs)
if graph_credentials_inputs
else None

View File

@@ -27,6 +27,7 @@ import backend.server.v2.library.routes
import backend.server.v2.otto.routes
import backend.server.v2.store.model
import backend.server.v2.store.routes
import backend.server.v2.turnstile.routes
import backend.util.service
import backend.util.settings
from backend.blocks.llm import LlmModel
@@ -119,6 +120,9 @@ app.include_router(
app.include_router(
backend.server.v2.otto.routes.router, tags=["v2"], prefix="/api/otto"
)
app.include_router(
backend.server.v2.turnstile.routes.router, tags=["v2"], prefix="/api/turnstile"
)
app.include_router(
backend.server.routers.postmark.postmark.router,

View File

@@ -422,7 +422,11 @@ async def get_graph(
for_export: bool = False,
) -> graph_db.GraphModel:
graph = await graph_db.get_graph(
graph_id, version, user_id=user_id, for_export=for_export
graph_id,
version,
user_id=user_id,
for_export=for_export,
include_subgraphs=True, # needed to construct full credentials input schema
)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")

View File

@@ -0,0 +1,30 @@
from typing import Optional
from pydantic import BaseModel, Field
class TurnstileVerifyRequest(BaseModel):
"""Request model for verifying a Turnstile token."""
token: str = Field(description="The Turnstile token to verify")
action: Optional[str] = Field(
default=None, description="The action that the user is attempting to perform"
)
class TurnstileVerifyResponse(BaseModel):
"""Response model for the Turnstile verification endpoint."""
success: bool = Field(description="Whether the token verification was successful")
error: Optional[str] = Field(
default=None, description="Error message if verification failed"
)
challenge_timestamp: Optional[str] = Field(
default=None, description="Timestamp of the challenge (ISO format)"
)
hostname: Optional[str] = Field(
default=None, description="Hostname of the site where the challenge was solved"
)
action: Optional[str] = Field(
default=None, description="The action associated with this verification"
)

View File

@@ -0,0 +1,108 @@
import logging
import aiohttp
from fastapi import APIRouter
from backend.util.settings import Settings
from .models import TurnstileVerifyRequest, TurnstileVerifyResponse
logger = logging.getLogger(__name__)
router = APIRouter()
settings = Settings()
@router.post("/verify", response_model=TurnstileVerifyResponse)
async def verify_turnstile_token(
request: TurnstileVerifyRequest,
) -> TurnstileVerifyResponse:
"""
Verify a Cloudflare Turnstile token.
This endpoint verifies a token returned by the Cloudflare Turnstile challenge
on the client side. It returns whether the verification was successful.
"""
logger.info(f"Verifying Turnstile token for action: {request.action}")
return await verify_token(request)
async def verify_token(request: TurnstileVerifyRequest) -> TurnstileVerifyResponse:
"""
Verify a Cloudflare Turnstile token by making a request to the Cloudflare API.
"""
# Get the secret key from settings
turnstile_secret_key = settings.secrets.turnstile_secret_key
turnstile_verify_url = settings.secrets.turnstile_verify_url
if not turnstile_secret_key:
logger.error("Turnstile secret key is not configured")
return TurnstileVerifyResponse(
success=False,
error="CONFIGURATION_ERROR",
challenge_timestamp=None,
hostname=None,
action=None,
)
try:
async with aiohttp.ClientSession() as session:
payload = {
"secret": turnstile_secret_key,
"response": request.token,
}
if request.action:
payload["action"] = request.action
logger.debug(f"Verifying Turnstile token with action: {request.action}")
async with session.post(
turnstile_verify_url,
data=payload,
timeout=aiohttp.ClientTimeout(total=10),
) as response:
if response.status != 200:
error_text = await response.text()
logger.error(f"Turnstile API error: {error_text}")
return TurnstileVerifyResponse(
success=False,
error=f"API_ERROR: {response.status}",
challenge_timestamp=None,
hostname=None,
action=None,
)
data = await response.json()
logger.debug(f"Turnstile API response: {data}")
# Parse the response and return a structured object
return TurnstileVerifyResponse(
success=data.get("success", False),
error=(
data.get("error-codes", None)[0]
if data.get("error-codes")
else None
),
challenge_timestamp=data.get("challenge_timestamp"),
hostname=data.get("hostname"),
action=data.get("action"),
)
except aiohttp.ClientError as e:
logger.error(f"Connection error to Turnstile API: {str(e)}")
return TurnstileVerifyResponse(
success=False,
error=f"CONNECTION_ERROR: {str(e)}",
challenge_timestamp=None,
hostname=None,
action=None,
)
except Exception as e:
logger.error(f"Unexpected error in Turnstile verification: {str(e)}")
return TurnstileVerifyResponse(
success=False,
error=f"UNEXPECTED_ERROR: {str(e)}",
challenge_timestamp=None,
hostname=None,
action=None,
)

View File

@@ -350,6 +350,16 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
description="The secret key to use for the unsubscribe user by token",
)
# Cloudflare Turnstile credentials
turnstile_secret_key: str = Field(
default="",
description="Cloudflare Turnstile backend secret key",
)
turnstile_verify_url: str = Field(
default="https://challenges.cloudflare.com/turnstile/v0/siteverify",
description="Cloudflare Turnstile verify URL",
)
# OAuth server credentials for integrations
# --8<-- [start:OAuthServerCredentialsExample]
github_client_id: str = Field(default="", description="GitHub OAuth client ID")

View File

@@ -34,7 +34,7 @@ async def spend_credits(entry: NodeExecutionEntry) -> int:
if not block:
raise RuntimeError(f"Block {entry.block_id} not found")
cost, matching_filter = block_usage_cost(block=block, input_data=entry.data)
cost, matching_filter = block_usage_cost(block=block, input_data=entry.inputs)
await user_credit.spend_credits(
entry.user_id,
cost,
@@ -67,7 +67,7 @@ async def test_block_credit_usage(server: SpinTestServer):
graph_exec_id="test_graph_exec",
node_exec_id="test_node_exec",
block_id=AITextGeneratorBlock().id,
data={
inputs={
"model": "gpt-4-turbo",
"credentials": {
"id": openai_credentials.id,
@@ -87,7 +87,7 @@ async def test_block_credit_usage(server: SpinTestServer):
graph_exec_id="test_graph_exec",
node_exec_id="test_node_exec",
block_id=AITextGeneratorBlock().id,
data={"model": "gpt-4-turbo", "api_key": "owned_api_key"},
inputs={"model": "gpt-4-turbo", "api_key": "owned_api_key"},
),
)
assert spending_amount_2 == 0

View File

@@ -25,3 +25,8 @@ GA_MEASUREMENT_ID=G-FH2XK2W4GN
# When running locally, set NEXT_PUBLIC_BEHAVE_AS=CLOUD to use the a locally hosted marketplace (as is typical in development, and the cloud deployment), otherwise set it to LOCAL to have the marketplace open in a new tab
NEXT_PUBLIC_BEHAVE_AS=LOCAL
NEXT_PUBLIC_SHOW_BILLING_PAGE=false
## Cloudflare Turnstile (CAPTCHA) Configuration
## Get these from the Cloudflare Turnstile dashboard: https://dash.cloudflare.com/?to=/:account/turnstile
## This is the frontend site key
NEXT_PUBLIC_CLOUDFLARE_TURNSTILE_SITE_KEY=

View File

@@ -6,6 +6,7 @@ import * as Sentry from "@sentry/nextjs";
import getServerSupabase from "@/lib/supabase/getServerSupabase";
import BackendAPI from "@/lib/autogpt-server-api";
import { loginFormSchema, LoginProvider } from "@/types/auth";
import { verifyTurnstileToken } from "@/lib/turnstile";
export async function logout() {
return await Sentry.withServerActionInstrumentation(
@@ -39,7 +40,10 @@ async function shouldShowOnboarding() {
);
}
export async function login(values: z.infer<typeof loginFormSchema>) {
export async function login(
values: z.infer<typeof loginFormSchema>,
turnstileToken: string,
) {
return await Sentry.withServerActionInstrumentation("login", {}, async () => {
const supabase = getServerSupabase();
const api = new BackendAPI();
@@ -48,6 +52,12 @@ export async function login(values: z.infer<typeof loginFormSchema>) {
redirect("/error");
}
// Verify Turnstile token if provided
const success = await verifyTurnstileToken(turnstileToken, "login");
if (!success) {
return "CAPTCHA verification failed. Please try again.";
}
// We are sure that the values are of the correct type because zod validates the form
const { data, error } = await supabase.auth.signInWithPassword(values);

View File

@@ -24,9 +24,11 @@ import {
AuthFeedback,
AuthBottomText,
PasswordInput,
Turnstile,
} from "@/components/auth";
import { loginFormSchema } from "@/types/auth";
import { getBehaveAs } from "@/lib/utils";
import { useTurnstile } from "@/hooks/useTurnstile";
export default function LoginPage() {
const { supabase, user, isUserLoading } = useSupabase();
@@ -34,6 +36,12 @@ export default function LoginPage() {
const router = useRouter();
const [isLoading, setIsLoading] = useState(false);
const turnstile = useTurnstile({
action: "login",
autoVerify: false,
resetOnError: true,
});
const form = useForm<z.infer<typeof loginFormSchema>>({
resolver: zodResolver(loginFormSchema),
defaultValues: {
@@ -65,15 +73,23 @@ export default function LoginPage() {
return;
}
const error = await login(data);
if (!turnstile.verified) {
setFeedback("Please complete the CAPTCHA challenge.");
setIsLoading(false);
return;
}
const error = await login(data, turnstile.token as string);
setIsLoading(false);
if (error) {
setFeedback(error);
// Always reset the turnstile on any error
turnstile.reset();
return;
}
setFeedback(null);
},
[form],
[form, turnstile],
);
if (user) {
@@ -140,6 +156,17 @@ export default function LoginPage() {
</FormItem>
)}
/>
{/* Turnstile CAPTCHA Component */}
<Turnstile
siteKey={turnstile.siteKey}
onVerify={turnstile.handleVerify}
onExpire={turnstile.handleExpire}
onError={turnstile.handleError}
action="login"
shouldRender={turnstile.shouldRender}
/>
<AuthButton
onClick={() => onLogin(form.getValues())}
isLoading={isLoading}

View File

@@ -3,8 +3,9 @@ import getServerSupabase from "@/lib/supabase/getServerSupabase";
import { redirect } from "next/navigation";
import * as Sentry from "@sentry/nextjs";
import { headers } from "next/headers";
import { verifyTurnstileToken } from "@/lib/turnstile";
export async function sendResetEmail(email: string) {
export async function sendResetEmail(email: string, turnstileToken: string) {
return await Sentry.withServerActionInstrumentation(
"sendResetEmail",
{},
@@ -20,6 +21,15 @@ export async function sendResetEmail(email: string) {
redirect("/error");
}
// Verify Turnstile token if provided
const success = await verifyTurnstileToken(
turnstileToken,
"reset_password",
);
if (!success) {
return "CAPTCHA verification failed. Please try again.";
}
const { error } = await supabase.auth.resetPasswordForEmail(email, {
redirectTo: `${origin}/reset_password`,
});
@@ -34,7 +44,7 @@ export async function sendResetEmail(email: string) {
);
}
export async function changePassword(password: string) {
export async function changePassword(password: string, turnstileToken: string) {
return await Sentry.withServerActionInstrumentation(
"changePassword",
{},
@@ -45,6 +55,15 @@ export async function changePassword(password: string) {
redirect("/error");
}
// Verify Turnstile token if provided
const success = await verifyTurnstileToken(
turnstileToken,
"change_password",
);
if (!success) {
return "CAPTCHA verification failed. Please try again.";
}
const { error } = await supabase.auth.updateUser({ password });
if (error) {

View File

@@ -5,6 +5,7 @@ import {
AuthButton,
AuthFeedback,
PasswordInput,
Turnstile,
} from "@/components/auth";
import {
Form,
@@ -25,6 +26,7 @@ import { z } from "zod";
import { changePassword, sendResetEmail } from "./actions";
import Spinner from "@/components/Spinner";
import { getBehaveAs } from "@/lib/utils";
import { useTurnstile } from "@/hooks/useTurnstile";
export default function ResetPasswordPage() {
const { supabase, user, isUserLoading } = useSupabase();
@@ -33,6 +35,18 @@ export default function ResetPasswordPage() {
const [isError, setIsError] = useState(false);
const [disabled, setDisabled] = useState(false);
const sendEmailTurnstile = useTurnstile({
action: "reset_password",
autoVerify: false,
resetOnError: true,
});
const changePasswordTurnstile = useTurnstile({
action: "change_password",
autoVerify: false,
resetOnError: true,
});
const sendEmailForm = useForm<z.infer<typeof sendEmailFormSchema>>({
resolver: zodResolver(sendEmailFormSchema),
defaultValues: {
@@ -58,11 +72,22 @@ export default function ResetPasswordPage() {
return;
}
const error = await sendResetEmail(data.email);
if (!sendEmailTurnstile.verified) {
setFeedback("Please complete the CAPTCHA challenge.");
setIsError(true);
setIsLoading(false);
return;
}
const error = await sendResetEmail(
data.email,
sendEmailTurnstile.token as string,
);
setIsLoading(false);
if (error) {
setFeedback(error);
setIsError(true);
sendEmailTurnstile.reset();
return;
}
setDisabled(true);
@@ -71,7 +96,7 @@ export default function ResetPasswordPage() {
);
setIsError(false);
},
[sendEmailForm],
[sendEmailForm, sendEmailTurnstile],
);
const onChangePassword = useCallback(
@@ -84,17 +109,28 @@ export default function ResetPasswordPage() {
return;
}
const error = await changePassword(data.password);
if (!changePasswordTurnstile.verified) {
setFeedback("Please complete the CAPTCHA challenge.");
setIsError(true);
setIsLoading(false);
return;
}
const error = await changePassword(
data.password,
changePasswordTurnstile.token as string,
);
setIsLoading(false);
if (error) {
setFeedback(error);
setIsError(true);
changePasswordTurnstile.reset();
return;
}
setFeedback("Password changed successfully. Redirecting to login.");
setIsError(false);
},
[changePasswordForm],
[changePasswordForm, changePasswordTurnstile],
);
if (isUserLoading) {
@@ -145,6 +181,17 @@ export default function ResetPasswordPage() {
</FormItem>
)}
/>
{/* Turnstile CAPTCHA Component for password change */}
<Turnstile
siteKey={changePasswordTurnstile.siteKey}
onVerify={changePasswordTurnstile.handleVerify}
onExpire={changePasswordTurnstile.handleExpire}
onError={changePasswordTurnstile.handleError}
action="change_password"
shouldRender={changePasswordTurnstile.shouldRender}
/>
<AuthButton
onClick={() => onChangePassword(changePasswordForm.getValues())}
isLoading={isLoading}
@@ -175,6 +222,17 @@ export default function ResetPasswordPage() {
</FormItem>
)}
/>
{/* Turnstile CAPTCHA Component for reset email */}
<Turnstile
siteKey={sendEmailTurnstile.siteKey}
onVerify={sendEmailTurnstile.handleVerify}
onExpire={sendEmailTurnstile.handleExpire}
onError={sendEmailTurnstile.handleError}
action="reset_password"
shouldRender={sendEmailTurnstile.shouldRender}
/>
<AuthButton
onClick={() => onSendEmail(sendEmailForm.getValues())}
isLoading={isLoading}

View File

@@ -6,8 +6,12 @@ import * as Sentry from "@sentry/nextjs";
import getServerSupabase from "@/lib/supabase/getServerSupabase";
import { signupFormSchema } from "@/types/auth";
import BackendAPI from "@/lib/autogpt-server-api";
import { verifyTurnstileToken } from "@/lib/turnstile";
export async function signup(values: z.infer<typeof signupFormSchema>) {
export async function signup(
values: z.infer<typeof signupFormSchema>,
turnstileToken: string,
) {
"use server";
return await Sentry.withServerActionInstrumentation(
"signup",
@@ -19,6 +23,12 @@ export async function signup(values: z.infer<typeof signupFormSchema>) {
redirect("/error");
}
// Verify Turnstile token if provided
const success = await verifyTurnstileToken(turnstileToken, "signup");
if (!success) {
return "CAPTCHA verification failed. Please try again.";
}
// We are sure that the values are of the correct type because zod validates the form
const { data, error } = await supabase.auth.signUp(values);

View File

@@ -25,10 +25,12 @@ import {
AuthButton,
AuthBottomText,
PasswordInput,
Turnstile,
} from "@/components/auth";
import AuthFeedback from "@/components/auth/AuthFeedback";
import { signupFormSchema } from "@/types/auth";
import { getBehaveAs } from "@/lib/utils";
import { useTurnstile } from "@/hooks/useTurnstile";
export default function SignupPage() {
const { supabase, user, isUserLoading } = useSupabase();
@@ -37,6 +39,12 @@ export default function SignupPage() {
const [isLoading, setIsLoading] = useState(false);
//TODO: Remove after closed beta
const turnstile = useTurnstile({
action: "signup",
autoVerify: false,
resetOnError: true,
});
const form = useForm<z.infer<typeof signupFormSchema>>({
resolver: zodResolver(signupFormSchema),
defaultValues: {
@@ -56,20 +64,28 @@ export default function SignupPage() {
return;
}
const error = await signup(data);
if (!turnstile.verified) {
setFeedback("Please complete the CAPTCHA challenge.");
setIsLoading(false);
return;
}
const error = await signup(data, turnstile.token as string);
setIsLoading(false);
if (error) {
if (error === "user_already_exists") {
setFeedback("User with this email already exists");
turnstile.reset();
return;
} else {
setFeedback(error);
turnstile.reset();
}
return;
}
setFeedback(null);
},
[form],
[form, turnstile],
);
if (user) {
@@ -141,6 +157,17 @@ export default function SignupPage() {
</FormItem>
)}
/>
{/* Turnstile CAPTCHA Component */}
<Turnstile
siteKey={turnstile.siteKey}
onVerify={turnstile.handleVerify}
onExpire={turnstile.handleExpire}
onError={turnstile.handleError}
action="signup"
shouldRender={turnstile.shouldRender}
/>
<AuthButton
onClick={() => onSignup(form.getValues())}
isLoading={isLoading}

View File

@@ -0,0 +1,140 @@
"use client";
import { useCallback, useEffect, useRef, useState } from "react";
import { cn } from "@/lib/utils";
export interface TurnstileProps {
siteKey: string;
onVerify: (token: string) => void;
onExpire?: () => void;
onError?: (error: Error) => void;
action?: string;
className?: string;
id?: string;
shouldRender?: boolean;
}
export function Turnstile({
siteKey,
onVerify,
onExpire,
onError,
action,
className,
id = "cf-turnstile",
shouldRender = true,
}: TurnstileProps) {
const containerRef = useRef<HTMLDivElement>(null);
const widgetIdRef = useRef<string | null>(null);
const [loaded, setLoaded] = useState(false);
// Load the Turnstile script
useEffect(() => {
if (typeof window === "undefined" || !shouldRender) return;
// Skip if already loaded
if (window.turnstile) {
setLoaded(true);
return;
}
// Create script element
const script = document.createElement("script");
script.src =
"https://challenges.cloudflare.com/turnstile/v0/api.js?render=explicit";
script.async = true;
script.defer = true;
script.onload = () => {
setLoaded(true);
};
script.onerror = () => {
onError?.(new Error("Failed to load Turnstile script"));
};
document.head.appendChild(script);
return () => {
if (document.head.contains(script)) {
document.head.removeChild(script);
}
};
}, [onError, shouldRender]);
// Initialize and render the widget when script is loaded
useEffect(() => {
if (!loaded || !containerRef.current || !window.turnstile || !shouldRender)
return;
// Reset any existing widget
if (widgetIdRef.current && window.turnstile) {
window.turnstile.reset(widgetIdRef.current);
}
// Render a new widget
if (window.turnstile) {
widgetIdRef.current = window.turnstile.render(containerRef.current, {
sitekey: siteKey,
callback: (token: string) => {
onVerify(token);
},
"expired-callback": () => {
onExpire?.();
},
"error-callback": () => {
onError?.(new Error("Turnstile widget encountered an error"));
},
action,
});
}
return () => {
if (widgetIdRef.current && window.turnstile) {
window.turnstile.remove(widgetIdRef.current);
widgetIdRef.current = null;
}
};
}, [loaded, siteKey, onVerify, onExpire, onError, action, shouldRender]);
// Method to reset the widget manually
const reset = useCallback(() => {
if (loaded && widgetIdRef.current && window.turnstile && shouldRender) {
window.turnstile.reset(widgetIdRef.current);
}
}, [loaded, shouldRender]);
// If shouldRender is false, don't render anything
if (!shouldRender) {
return null;
}
return (
<div
id={id}
ref={containerRef}
className={cn("my-4 flex items-center justify-center", className)}
/>
);
}
// Add TypeScript interface to Window to include turnstile property
declare global {
interface Window {
turnstile?: {
render: (
container: HTMLElement,
options: {
sitekey: string;
callback: (token: string) => void;
"expired-callback"?: () => void;
"error-callback"?: () => void;
action?: string;
},
) => string;
reset: (widgetId: string) => void;
remove: (widgetId: string) => void;
};
}
}
export default Turnstile;

View File

@@ -4,6 +4,7 @@ import AuthCard from "./AuthCard";
import AuthFeedback from "./AuthFeedback";
import AuthHeader from "./AuthHeader";
import { PasswordInput } from "./PasswordInput";
import Turnstile from "./Turnstile";
export {
AuthBottomText,
@@ -12,4 +13,5 @@ export {
AuthFeedback,
AuthHeader,
PasswordInput,
Turnstile,
};

View File

@@ -0,0 +1,169 @@
import { useState, useCallback, useEffect } from "react";
import { verifyTurnstileToken } from "@/lib/turnstile";
import { BehaveAs, getBehaveAs } from "@/lib/utils";
interface UseTurnstileOptions {
action?: string;
autoVerify?: boolean;
onSuccess?: () => void;
onError?: (error: Error) => void;
resetOnError?: boolean;
}
interface UseTurnstileResult {
token: string | null;
verifying: boolean;
verified: boolean;
error: Error | null;
handleVerify: (token: string) => Promise<boolean>;
handleExpire: () => void;
handleError: (error: Error) => void;
reset: () => void;
siteKey: string;
shouldRender: boolean;
}
const TURNSTILE_SITE_KEY =
process.env.NEXT_PUBLIC_CLOUDFLARE_TURNSTILE_SITE_KEY || "";
/**
* Custom hook for managing Turnstile state in forms
*/
export function useTurnstile({
action,
autoVerify = true,
onSuccess,
onError,
resetOnError = false,
}: UseTurnstileOptions = {}): UseTurnstileResult {
const [token, setToken] = useState<string | null>(null);
const [verifying, setVerifying] = useState(false);
const [verified, setVerified] = useState(false);
const [error, setError] = useState<Error | null>(null);
const [shouldRender, setShouldRender] = useState(false);
const [widgetId, setWidgetId] = useState<string | null>(null);
useEffect(() => {
const behaveAs = getBehaveAs();
const hasTurnstileKey = !!TURNSTILE_SITE_KEY;
setShouldRender(behaveAs === BehaveAs.CLOUD && hasTurnstileKey);
if (behaveAs !== BehaveAs.CLOUD || !hasTurnstileKey) {
setVerified(true);
}
}, []);
useEffect(() => {
if (token && !autoVerify && shouldRender) {
setVerified(true);
}
}, [token, autoVerify, shouldRender]);
useEffect(() => {
if (typeof window !== "undefined" && window.turnstile) {
const originalRender = window.turnstile.render;
window.turnstile.render = (container, options) => {
const id = originalRender(container, options);
setWidgetId(id);
return id;
};
}
}, []);
const reset = useCallback(() => {
if (shouldRender && window.turnstile && widgetId) {
window.turnstile.reset(widgetId);
// Always reset the state when reset is called
setToken(null);
setVerified(false);
setVerifying(false);
setError(null);
}
}, [shouldRender, widgetId]);
const handleVerify = useCallback(
async (newToken: string) => {
if (!shouldRender) {
return true;
}
setToken(newToken);
setError(null);
if (autoVerify) {
setVerifying(true);
try {
const success = await verifyTurnstileToken(newToken, action);
setVerified(success);
if (success && onSuccess) {
onSuccess();
} else if (!success) {
const newError = new Error("Turnstile verification failed");
setError(newError);
if (onError) onError(newError);
if (resetOnError) {
setVerified(false);
}
}
setVerifying(false);
return success;
} catch (err) {
const newError =
err instanceof Error
? err
: new Error("Unknown error during verification");
setError(newError);
if (resetOnError) {
setVerified(false);
}
setVerifying(false);
if (onError) onError(newError);
return false;
}
} else {
setVerified(true);
}
return true;
},
[action, autoVerify, onSuccess, onError, resetOnError, shouldRender],
);
const handleExpire = useCallback(() => {
if (shouldRender) {
setToken(null);
setVerified(false);
}
}, [shouldRender]);
const handleError = useCallback(
(err: Error) => {
if (shouldRender) {
setError(err);
if (resetOnError) {
setVerified(false);
}
if (onError) onError(err);
}
},
[onError, shouldRender, resetOnError],
);
return {
token,
verifying,
verified,
error,
handleVerify,
handleExpire,
handleError,
reset,
siteKey: TURNSTILE_SITE_KEY,
shouldRender,
};
}

View File

@@ -0,0 +1,42 @@
/**
* Utility functions for working with Cloudflare Turnstile
*/
import { BehaveAs, getBehaveAs } from "@/lib/utils";
export async function verifyTurnstileToken(
token: string,
action?: string,
): Promise<boolean> {
// Skip verification in local development
const behaveAs = getBehaveAs();
if (behaveAs !== BehaveAs.CLOUD) {
return true;
}
try {
const response = await fetch(
`${process.env.NEXT_PUBLIC_AGPT_SERVER_URL}/turnstile/verify`,
{
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
token,
action,
}),
},
);
if (!response.ok) {
console.error("Turnstile verification failed:", await response.text());
return false;
}
const data = await response.json();
return data.success === true;
} catch (error) {
console.error("Error verifying Turnstile token:", error);
return false;
}
}