Merge branch 'dev' into claude/admin-user-management-011CULzkwgiPXZYcvCeozofC

This commit is contained in:
Nicholas Tindle
2025-11-07 11:28:29 -06:00
committed by GitHub
14 changed files with 433 additions and 38 deletions

View File

@@ -44,6 +44,7 @@ from backend.integrations.providers import ProviderName
from backend.monitoring.instrumentation import instrument_fastapi
from backend.server.external.api import external_app
from backend.server.middleware.security import SecurityHeadersMiddleware
from backend.server.utils.cors import build_cors_params
from backend.util import json
from backend.util.cloud_storage import shutdown_cloud_storage_handler
from backend.util.exceptions import (
@@ -309,9 +310,14 @@ async def health():
class AgentServer(backend.util.service.AppProcess):
def run(self):
cors_params = build_cors_params(
settings.config.backend_cors_allow_origins,
settings.config.app_env,
)
server_app = starlette.middleware.cors.CORSMiddleware(
app=app,
allow_origins=settings.config.backend_cors_allow_origins,
**cors_params,
allow_credentials=True,
allow_methods=["*"], # Allows all methods
allow_headers=["*"], # Allows all headers

View File

@@ -1,7 +1,8 @@
"""Helper functions for improved test assertions and error handling."""
import json
from typing import Any, Dict, Optional
from contextlib import contextmanager
from typing import Any, Dict, Iterator, Optional
def assert_response_status(
@@ -107,3 +108,24 @@ def assert_mock_called_with_partial(mock_obj: Any, **expected_kwargs: Any) -> No
assert (
actual_kwargs[key] == expected_value
), f"Mock called with {key}={actual_kwargs[key]}, expected {expected_value}"
@contextmanager
def override_config(settings: Any, attribute: str, value: Any) -> Iterator[None]:
"""Temporarily override a config attribute for testing.
Warning: Directly mutates settings.config. If config is reloaded or cached
elsewhere during the test, side effects may leak. Use with caution in
parallel tests or when config is accessed globally.
Args:
settings: The settings object containing .config
attribute: The config attribute name to override
value: The temporary value to set
"""
original = getattr(settings.config, attribute)
setattr(settings.config, attribute, value)
try:
yield
finally:
setattr(settings.config, attribute, original)

View File

@@ -0,0 +1,67 @@
from __future__ import annotations
import re
from typing import List, Sequence, TypedDict
from backend.util.settings import AppEnvironment
class CorsParams(TypedDict):
allow_origins: List[str]
allow_origin_regex: str | None
def build_cors_params(origins: Sequence[str], app_env: AppEnvironment) -> CorsParams:
allow_origins: List[str] = []
regex_patterns: List[str] = []
if app_env == AppEnvironment.PRODUCTION:
for origin in origins:
if origin.startswith("regex:"):
pattern = origin[len("regex:") :]
pattern_lower = pattern.lower()
if "localhost" in pattern_lower or "127.0.0.1" in pattern_lower:
raise ValueError(
f"Production environment cannot allow localhost origins via regex: {pattern}"
)
try:
compiled = re.compile(pattern)
test_urls = [
"http://localhost:3000",
"http://127.0.0.1:3000",
"https://localhost:8000",
"https://127.0.0.1:8000",
]
for test_url in test_urls:
if compiled.search(test_url):
raise ValueError(
f"Production regex pattern matches localhost/127.0.0.1: {pattern}"
)
except re.error:
pass
continue
lowered = origin.lower()
if "localhost" in lowered or "127.0.0.1" in lowered:
raise ValueError(
"Production environment cannot allow localhost origins"
)
for origin in origins:
if origin.startswith("regex:"):
regex_patterns.append(origin[len("regex:") :])
else:
allow_origins.append(origin)
allow_origin_regex = None
if regex_patterns:
if len(regex_patterns) == 1:
allow_origin_regex = f"^(?:{regex_patterns[0]})$"
else:
combined_pattern = "|".join(f"(?:{pattern})" for pattern in regex_patterns)
allow_origin_regex = f"^(?:{combined_pattern})$"
return {
"allow_origins": allow_origins,
"allow_origin_regex": allow_origin_regex,
}

View File

@@ -0,0 +1,62 @@
import pytest
from backend.server.utils.cors import build_cors_params
from backend.util.settings import AppEnvironment
def test_build_cors_params_splits_regex_patterns() -> None:
origins = [
"https://app.example.com",
"regex:https://.*\\.example\\.com",
]
result = build_cors_params(origins, AppEnvironment.LOCAL)
assert result["allow_origins"] == ["https://app.example.com"]
assert result["allow_origin_regex"] == "^(?:https://.*\\.example\\.com)$"
def test_build_cors_params_combines_multiple_regex_patterns() -> None:
origins = [
"regex:https://alpha.example.com",
"regex:https://beta.example.com",
]
result = build_cors_params(origins, AppEnvironment.DEVELOPMENT)
assert result["allow_origins"] == []
assert result["allow_origin_regex"] == (
"^(?:(?:https://alpha.example.com)|(?:https://beta.example.com))$"
)
def test_build_cors_params_blocks_localhost_literal_in_production() -> None:
with pytest.raises(ValueError):
build_cors_params(["http://localhost:3000"], AppEnvironment.PRODUCTION)
def test_build_cors_params_blocks_localhost_regex_in_production() -> None:
with pytest.raises(ValueError):
build_cors_params(["regex:https://.*localhost.*"], AppEnvironment.PRODUCTION)
def test_build_cors_params_blocks_case_insensitive_localhost_regex() -> None:
with pytest.raises(ValueError, match="localhost origins via regex"):
build_cors_params(["regex:https://(?i)LOCALHOST.*"], AppEnvironment.PRODUCTION)
def test_build_cors_params_blocks_regex_matching_localhost_at_runtime() -> None:
with pytest.raises(ValueError, match="matches localhost"):
build_cors_params(["regex:https?://.*:3000"], AppEnvironment.PRODUCTION)
def test_build_cors_params_allows_vercel_preview_regex() -> None:
result = build_cors_params(
["regex:https://autogpt-git-[a-z0-9-]+\\.vercel\\.app"],
AppEnvironment.PRODUCTION,
)
assert result["allow_origins"] == []
assert result["allow_origin_regex"] == (
"^(?:https://autogpt-git-[a-z0-9-]+\\.vercel\\.app)$"
)

View File

@@ -22,6 +22,7 @@ from backend.server.model import (
WSSubscribeGraphExecutionRequest,
WSSubscribeGraphExecutionsRequest,
)
from backend.server.utils.cors import build_cors_params
from backend.util.retry import continuous_retry
from backend.util.service import AppProcess
from backend.util.settings import AppEnvironment, Config, Settings
@@ -315,9 +316,13 @@ async def health():
class WebsocketServer(AppProcess):
def run(self):
logger.info(f"CORS allow origins: {settings.config.backend_cors_allow_origins}")
cors_params = build_cors_params(
settings.config.backend_cors_allow_origins,
settings.config.app_env,
)
server_app = CORSMiddleware(
app=app,
allow_origins=settings.config.backend_cors_allow_origins,
**cors_params,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],

View File

@@ -8,11 +8,13 @@ from pytest_snapshot.plugin import Snapshot
from backend.data.user import DEFAULT_USER_ID
from backend.server.conn_manager import ConnectionManager
from backend.server.test_helpers import override_config
from backend.server.ws_api import AppEnvironment, WebsocketServer, WSMessage, WSMethod
from backend.server.ws_api import app as websocket_app
from backend.server.ws_api import (
WSMessage,
WSMethod,
handle_subscribe,
handle_unsubscribe,
settings,
websocket_router,
)
@@ -29,6 +31,47 @@ def mock_manager() -> AsyncMock:
return AsyncMock(spec=ConnectionManager)
def test_websocket_server_uses_cors_helper(mocker) -> None:
cors_params = {
"allow_origins": ["https://app.example.com"],
"allow_origin_regex": None,
}
mocker.patch("backend.server.ws_api.uvicorn.run")
cors_middleware = mocker.patch(
"backend.server.ws_api.CORSMiddleware", return_value=object()
)
build_cors = mocker.patch(
"backend.server.ws_api.build_cors_params", return_value=cors_params
)
with override_config(
settings, "backend_cors_allow_origins", cors_params["allow_origins"]
), override_config(settings, "app_env", AppEnvironment.LOCAL):
WebsocketServer().run()
build_cors.assert_called_once_with(
cors_params["allow_origins"], AppEnvironment.LOCAL
)
cors_middleware.assert_called_once_with(
app=websocket_app,
allow_origins=cors_params["allow_origins"],
allow_origin_regex=cors_params["allow_origin_regex"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
def test_websocket_server_blocks_localhost_in_production(mocker) -> None:
mocker.patch("backend.server.ws_api.uvicorn.run")
with override_config(
settings, "backend_cors_allow_origins", ["http://localhost:3000"]
), override_config(settings, "app_env", AppEnvironment.PRODUCTION):
with pytest.raises(ValueError):
WebsocketServer().run()
@pytest.mark.asyncio
async def test_websocket_router_subscribe(
mock_websocket: AsyncMock, mock_manager: AsyncMock, snapshot: Snapshot, mocker

View File

@@ -1,5 +1,6 @@
import json
import os
import re
from enum import Enum
from typing import Any, Dict, Generic, List, Set, Tuple, Type, TypeVar
@@ -427,34 +428,62 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="Maximum message size limit for communication with the message bus",
)
backend_cors_allow_origins: List[str] = Field(default=["http://localhost:3000"])
backend_cors_allow_origins: List[str] = Field(
default=["http://localhost:3000"],
description="Allowed Origins for CORS. Supports exact URLs (http/https) or entries prefixed with "
'"regex:" to match via regular expression.',
)
@field_validator("backend_cors_allow_origins")
@classmethod
def validate_cors_allow_origins(cls, v: List[str]) -> List[str]:
out = []
port = None
has_localhost = False
has_127_0_0_1 = False
for url in v:
url = url.strip()
if url.startswith(("http://", "https://")):
if "localhost" in url:
port = url.split(":")[2]
has_localhost = True
if "127.0.0.1" in url:
port = url.split(":")[2]
has_127_0_0_1 = True
out.append(url)
else:
raise ValueError(f"Invalid URL: {url}")
validated: List[str] = []
localhost_ports: set[str] = set()
ip127_ports: set[str] = set()
if has_127_0_0_1 and not has_localhost:
out.append(f"http://localhost:{port}")
if has_localhost and not has_127_0_0_1:
out.append(f"http://127.0.0.1:{port}")
for raw_origin in v:
origin = raw_origin.strip()
if origin.startswith("regex:"):
pattern = origin[len("regex:") :]
if not pattern:
raise ValueError("Invalid regex pattern: pattern cannot be empty")
try:
re.compile(pattern)
except re.error as exc:
raise ValueError(
f"Invalid regex pattern '{pattern}': {exc}"
) from exc
validated.append(origin)
continue
return out
if origin.startswith(("http://", "https://")):
if "localhost" in origin:
try:
port = origin.split(":")[2]
localhost_ports.add(port)
except IndexError as exc:
raise ValueError(
"localhost origins must include an explicit port, e.g. http://localhost:3000"
) from exc
if "127.0.0.1" in origin:
try:
port = origin.split(":")[2]
ip127_ports.add(port)
except IndexError as exc:
raise ValueError(
"127.0.0.1 origins must include an explicit port, e.g. http://127.0.0.1:3000"
) from exc
validated.append(origin)
continue
raise ValueError(f"Invalid URL or regex origin: {origin}")
for port in ip127_ports - localhost_ports:
validated.append(f"http://localhost:{port}")
for port in localhost_ports - ip127_ports:
validated.append(f"http://127.0.0.1:{port}")
return validated
@classmethod
def settings_customise_sources(

View File

@@ -1,7 +0,0 @@
import { postV1ResetOnboardingProgress } from "@/app/api/__generated__/endpoints/onboarding/onboarding";
import { redirect } from "next/navigation";
export default async function OnboardingResetPage() {
await postV1ResetOnboardingProgress();
redirect("/onboarding/1-welcome");
}

View File

@@ -0,0 +1,32 @@
"use client";
import { postV1ResetOnboardingProgress } from "@/app/api/__generated__/endpoints/onboarding/onboarding";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { redirect } from "next/navigation";
import { useEffect } from "react";
export default function OnboardingResetPage() {
const { toast } = useToast();
useEffect(() => {
postV1ResetOnboardingProgress()
.then(() => {
toast({
title: "Onboarding reset successfully",
description: "You can now start the onboarding process again",
variant: "success",
});
redirect("/onboarding/1-welcome");
})
.catch(() => {
toast({
title: "Failed to reset onboarding",
description: "Please try again later",
variant: "destructive",
});
});
}, []);
return <LoadingSpinner cover />;
}

View File

@@ -1,6 +1,7 @@
import BackendAPI from "@/lib/autogpt-server-api";
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
import { verifyTurnstileToken } from "@/lib/turnstile";
import { environment } from "@/services/environment";
import { loginFormSchema } from "@/types/auth";
import * as Sentry from "@sentry/nextjs";
import { NextResponse } from "next/server";
@@ -26,7 +27,7 @@ export async function POST(request: Request) {
// Verify Turnstile token if provided
const captchaOk = await verifyTurnstileToken(turnstileToken ?? "", "login");
if (!captchaOk) {
if (!captchaOk && !environment.isVercelPreview()) {
return NextResponse.json(
{ error: "CAPTCHA verification failed. Please try again." },
{ status: 400 },

View File

@@ -1,8 +1,9 @@
import { NextResponse } from "next/server";
import * as Sentry from "@sentry/nextjs";
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
import { verifyTurnstileToken } from "@/lib/turnstile";
import { environment } from "@/services/environment";
import { signupFormSchema } from "@/types/auth";
import * as Sentry from "@sentry/nextjs";
import { NextResponse } from "next/server";
import { shouldShowOnboarding } from "../../helpers";
import { isWaitlistError, logWaitlistError } from "../utils";
@@ -31,7 +32,7 @@ export async function POST(request: Request) {
"signup",
);
if (!captchaOk) {
if (!captchaOk && !environment.isVercelPreview()) {
return NextResponse.json(
{ error: "CAPTCHA verification failed. Please try again." },
{ status: 400 },

View File

@@ -0,0 +1,86 @@
import type { Meta, StoryObj } from "@storybook/nextjs";
import { LoadingSpinner } from "./LoadingSpinner";
const meta: Meta<typeof LoadingSpinner> = {
title: "Atoms/LoadingSpinner",
component: LoadingSpinner,
tags: ["autodocs"],
parameters: {
layout: "centered",
docs: {
description: {
component:
"Animated loading indicator using the Phosphor CircleNotch icon. Provide a `size` prop or custom classes to fit different contexts.",
},
},
},
argTypes: {
size: {
control: "select",
options: ["small", "medium", "large"],
description: "Spinner size preset",
},
className: {
control: "text",
description: "Additional CSS classes to customize color or layout",
},
},
args: {
size: "medium",
className: "text-indigo-500",
role: "status",
"aria-label": "loading",
},
};
export default meta;
type Story = StoryObj<typeof meta>;
export const Default: Story = {};
export const Small: Story = {
args: {
size: "small",
},
};
export const Large: Story = {
args: {
size: "large",
},
};
export const CustomColor: Story = {
args: {
className: "text-emerald-500",
},
};
export const Cover: Story = {
args: {
cover: true,
},
};
export const AllSizes: Story = {
render: renderAllSizes,
};
function renderAllSizes() {
return (
<div className="flex items-center gap-8 text-indigo-500">
<div className="flex flex-col items-center gap-2">
<LoadingSpinner size="small" aria-label="loading-small" />
<span className="text-xs capitalize text-zinc-500">Small</span>
</div>
<div className="flex flex-col items-center gap-2">
<LoadingSpinner size="medium" aria-label="loading-medium" />
<span className="text-xs capitalize text-zinc-500">Medium</span>
</div>
<div className="flex flex-col items-center gap-2">
<LoadingSpinner size="large" aria-label="loading-large" />
<span className="text-xs capitalize text-zinc-500">Large</span>
</div>
</div>
);
}

View File

@@ -0,0 +1,43 @@
import { cn } from "@/lib/utils";
import { CircleNotchIcon } from "@phosphor-icons/react/dist/ssr";
import React from "react";
const sizeClassNameMap = {
small: "h-4 w-4",
medium: "h-6 w-6",
large: "h-10 w-10",
} as const;
type SpinnerSize = keyof typeof sizeClassNameMap;
type LoadingSpinnerProps = {
size?: SpinnerSize;
className?: string;
cover?: boolean;
} & React.ComponentPropsWithoutRef<typeof CircleNotchIcon>;
export function LoadingSpinner(props: LoadingSpinnerProps) {
const { size = "medium", className, cover = false, ...restProps } = props;
const spinner = (
<CircleNotchIcon
className={cn(
"animate-spin text-inherit",
sizeClassNameMap[size],
className,
)}
weight="bold"
{...restProps}
/>
);
if (cover) {
return (
<div className="fixed inset-0 z-50 flex items-center justify-center">
{spinner}
</div>
);
}
return spinner;
}

View File

@@ -84,6 +84,10 @@ function isClientSide() {
return typeof window !== "undefined";
}
function isVercelPreview() {
return process.env.VERCEL_ENV === "preview";
}
function isCAPTCHAEnabled() {
return process.env.NEXT_PUBLIC_TURNSTILE === "enabled";
}
@@ -110,6 +114,7 @@ export const environment = {
isDev,
isCloud,
isLocal,
isVercelPreview,
isCAPTCHAEnabled,
areFeatureFlagsEnabled,
};