mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Resolve merge conflicts with dev
https://claude.ai/code/session_01TPw8kd7p8qwsuNa5qBRcHc
This commit is contained in:
@@ -18,7 +18,7 @@ from prisma.types import (
|
||||
from backend.data import db
|
||||
from backend.util.json import SafeJson, sanitize_string
|
||||
|
||||
from .model import ChatMessage, ChatSession, ChatSessionInfo
|
||||
from .model import ChatMessage, ChatSession, ChatSessionInfo, invalidate_session_cache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -217,6 +217,9 @@ async def add_chat_messages_batch(
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
|
||||
if msg.get("duration_ms") is not None:
|
||||
data["durationMs"] = msg["duration_ms"]
|
||||
|
||||
messages_data.append(data)
|
||||
|
||||
# Run create_many and session update in parallel within transaction
|
||||
@@ -359,3 +362,22 @@ async def update_tool_message_content(
|
||||
f"tool_call_id {tool_call_id}: {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def set_turn_duration(session_id: str, duration_ms: int) -> None:
|
||||
"""Set durationMs on the last assistant message in a session.
|
||||
|
||||
Also invalidates the Redis session cache so the next GET returns
|
||||
the updated duration.
|
||||
"""
|
||||
last_msg = await PrismaChatMessage.prisma().find_first(
|
||||
where={"sessionId": session_id, "role": "assistant"},
|
||||
order={"sequence": "desc"},
|
||||
)
|
||||
if last_msg:
|
||||
await PrismaChatMessage.prisma().update(
|
||||
where={"id": last_msg.id},
|
||||
data={"durationMs": duration_ms},
|
||||
)
|
||||
# Invalidate cache so the session is re-fetched from DB with durationMs
|
||||
await invalidate_session_cache(session_id)
|
||||
|
||||
@@ -54,6 +54,7 @@ class ChatMessage(BaseModel):
|
||||
refusal: str | None = None
|
||||
tool_calls: list[dict] | None = None
|
||||
function_call: dict | None = None
|
||||
duration_ms: int | None = None
|
||||
|
||||
@staticmethod
|
||||
def from_db(prisma_message: PrismaChatMessage) -> "ChatMessage":
|
||||
@@ -66,6 +67,7 @@ class ChatMessage(BaseModel):
|
||||
refusal=prisma_message.refusal,
|
||||
tool_calls=_parse_json_field(prisma_message.toolCalls),
|
||||
function_call=_parse_json_field(prisma_message.functionCall),
|
||||
duration_ms=prisma_message.durationMs,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ import orjson
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from backend.api.model import CopilotCompletionPayload
|
||||
from backend.data.db_accessors import chat_db
|
||||
from backend.data.notification_bus import (
|
||||
AsyncRedisNotificationEventBus,
|
||||
NotificationEvent,
|
||||
@@ -111,6 +112,14 @@ def _parse_session_meta(meta: dict[Any, Any], session_id: str = "") -> ActiveSes
|
||||
``session_id`` is used as a fallback for ``turn_id`` when the meta hash
|
||||
pre-dates the turn_id field (backward compat for in-flight sessions).
|
||||
"""
|
||||
created_at = datetime.now(timezone.utc)
|
||||
created_at_raw = meta.get("created_at")
|
||||
if created_at_raw:
|
||||
try:
|
||||
created_at = datetime.fromisoformat(str(created_at_raw))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
return ActiveSession(
|
||||
session_id=meta.get("session_id", "") or session_id,
|
||||
user_id=meta.get("user_id", "") or None,
|
||||
@@ -119,6 +128,7 @@ def _parse_session_meta(meta: dict[Any, Any], session_id: str = "") -> ActiveSes
|
||||
turn_id=meta.get("turn_id", "") or session_id,
|
||||
blocking=meta.get("blocking") == "1",
|
||||
status=meta.get("status", "running"), # type: ignore[arg-type]
|
||||
created_at=created_at,
|
||||
)
|
||||
|
||||
|
||||
@@ -802,6 +812,33 @@ async def mark_session_completed(
|
||||
f"Failed to publish error event for session {session_id}: {e}"
|
||||
)
|
||||
|
||||
# Compute wall-clock duration from session created_at.
|
||||
# Only persist when (a) the session completed successfully and
|
||||
# (b) created_at was actually present in Redis meta (not a fallback).
|
||||
duration_ms: int | None = None
|
||||
if meta and not error_message:
|
||||
created_at_raw = meta.get("created_at")
|
||||
if created_at_raw:
|
||||
try:
|
||||
created_at = datetime.fromisoformat(str(created_at_raw))
|
||||
if created_at.tzinfo is None:
|
||||
created_at = created_at.replace(tzinfo=timezone.utc)
|
||||
elapsed = datetime.now(timezone.utc) - created_at
|
||||
duration_ms = max(0, int(elapsed.total_seconds() * 1000))
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
"Failed to compute session duration for %s (created_at=%r)",
|
||||
session_id,
|
||||
created_at_raw,
|
||||
)
|
||||
|
||||
# Persist duration on the last assistant message
|
||||
if duration_ms is not None:
|
||||
try:
|
||||
await chat_db().set_turn_duration(session_id, duration_ms)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save turn duration for {session_id}: {e}")
|
||||
|
||||
# Publish StreamFinish AFTER status is set to "completed"/"failed".
|
||||
# This is the SINGLE place that publishes StreamFinish — services and
|
||||
# the processor must NOT publish it themselves.
|
||||
|
||||
@@ -537,7 +537,7 @@ async def check_hitl_review(
|
||||
)
|
||||
|
||||
synthetic_node_exec_id = (
|
||||
f"{synthetic_node_id}{COPILOT_NODE_EXEC_ID_SEPARATOR}" f"{uuid.uuid4().hex[:8]}"
|
||||
f"{synthetic_node_id}{COPILOT_NODE_EXEC_ID_SEPARATOR}{uuid.uuid4().hex[:8]}"
|
||||
)
|
||||
|
||||
review_context = ExecutionContext(
|
||||
@@ -582,7 +582,16 @@ def _resolve_discriminated_credentials(
|
||||
block: AnyBlockSchema,
|
||||
input_data: dict[str, Any],
|
||||
) -> dict[str, CredentialsFieldInfo]:
|
||||
"""Resolve credential requirements, applying discriminator logic where needed."""
|
||||
"""Resolve credential requirements, applying discriminator logic where needed.
|
||||
|
||||
Handles two discrimination modes:
|
||||
1. **Provider-based** (``discriminator_mapping`` is set): the discriminator
|
||||
field value selects the provider (e.g. an AI model name -> provider).
|
||||
2. **URL/host-based** (``discriminator`` is set but ``discriminator_mapping``
|
||||
is ``None``): the discriminator field value (typically a URL) is added to
|
||||
``discriminator_values`` so that host-scoped credential matching can
|
||||
compare the credential's host against the target URL.
|
||||
"""
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
if not credentials_fields_info:
|
||||
return {}
|
||||
@@ -592,25 +601,42 @@ def _resolve_discriminated_credentials(
|
||||
for field_name, field_info in credentials_fields_info.items():
|
||||
effective_field_info = field_info
|
||||
|
||||
if field_info.discriminator and field_info.discriminator_mapping:
|
||||
if field_info.discriminator:
|
||||
discriminator_value = input_data.get(field_info.discriminator)
|
||||
if discriminator_value is None:
|
||||
field = block.input_schema.model_fields.get(field_info.discriminator)
|
||||
if field and field.default is not PydanticUndefined:
|
||||
discriminator_value = field.default
|
||||
|
||||
if (
|
||||
discriminator_value
|
||||
and discriminator_value in field_info.discriminator_mapping
|
||||
):
|
||||
effective_field_info = field_info.discriminate(discriminator_value)
|
||||
effective_field_info.discriminator_values.add(discriminator_value)
|
||||
logger.debug(
|
||||
"Discriminated provider for %s: %s -> %s",
|
||||
field_name,
|
||||
discriminator_value,
|
||||
effective_field_info.provider,
|
||||
)
|
||||
if discriminator_value is not None:
|
||||
if field_info.discriminator_mapping:
|
||||
# Provider-based discrimination (e.g. model -> provider)
|
||||
if discriminator_value in field_info.discriminator_mapping:
|
||||
effective_field_info = field_info.discriminate(
|
||||
discriminator_value
|
||||
)
|
||||
effective_field_info.discriminator_values.add(
|
||||
discriminator_value
|
||||
)
|
||||
# Model names are safe to log (not PII); URLs are
|
||||
# intentionally omitted in the host-based branch below.
|
||||
logger.debug(
|
||||
"Discriminated provider for %s: %s -> %s",
|
||||
field_name,
|
||||
discriminator_value,
|
||||
effective_field_info.provider,
|
||||
)
|
||||
else:
|
||||
# URL/host-based discrimination (e.g. url -> host matching).
|
||||
# Deep copy to avoid mutating the cached schema-level
|
||||
# field_info (model_copy() is shallow — the mutable set
|
||||
# would be shared).
|
||||
effective_field_info = field_info.model_copy(deep=True)
|
||||
effective_field_info.discriminator_values.add(discriminator_value)
|
||||
logger.debug(
|
||||
"Added discriminator value for host matching on %s",
|
||||
field_name,
|
||||
)
|
||||
|
||||
resolved[field_name] = effective_field_info
|
||||
|
||||
|
||||
@@ -0,0 +1,916 @@
|
||||
"""Tests for credential resolution across all credential types in the CoPilot.
|
||||
|
||||
These tests verify that:
|
||||
1. `_resolve_discriminated_credentials` correctly populates discriminator_values
|
||||
for URL-based (host-scoped) and provider-based (api_key) credential fields.
|
||||
2. `find_matching_credential` correctly matches credentials for all types:
|
||||
APIKeyCredentials, OAuth2Credentials, UserPasswordCredentials, and
|
||||
HostScopedCredentials.
|
||||
3. The full `resolve_block_credentials` flow correctly resolves matching
|
||||
credentials or reports them as missing for each credential type.
|
||||
4. `RunBlockTool._execute` end-to-end tests return correct response types.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.http import SendAuthenticatedWebRequestBlock
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsFieldInfo,
|
||||
CredentialsType,
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
from ._test_data import make_session
|
||||
from .helpers import _resolve_discriminated_credentials, resolve_block_credentials
|
||||
from .models import BlockDetailsResponse, SetupRequirementsResponse
|
||||
from .run_block import RunBlockTool
|
||||
from .utils import find_matching_credential
|
||||
|
||||
_TEST_USER_ID = "test-user-http-cred"
|
||||
|
||||
# Properly typed constants to avoid type: ignore on CredentialsFieldInfo construction.
|
||||
_HOST_SCOPED_TYPES: frozenset[CredentialsType] = frozenset(["host_scoped"])
|
||||
_API_KEY_TYPES: frozenset[CredentialsType] = frozenset(["api_key"])
|
||||
_OAUTH2_TYPES: frozenset[CredentialsType] = frozenset(["oauth2"])
|
||||
_USER_PASSWORD_TYPES: frozenset[CredentialsType] = frozenset(["user_password"])
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _resolve_discriminated_credentials tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveDiscriminatedCredentials:
|
||||
"""Tests for _resolve_discriminated_credentials with URL-based discrimination."""
|
||||
|
||||
def _get_auth_block(self):
|
||||
return SendAuthenticatedWebRequestBlock()
|
||||
|
||||
def test_url_discriminator_populates_discriminator_values(self):
|
||||
"""When input_data contains a URL, discriminator_values should include it."""
|
||||
block = self._get_auth_block()
|
||||
input_data = {"url": "https://api.example.com/v1/data"}
|
||||
|
||||
result = _resolve_discriminated_credentials(block, input_data)
|
||||
|
||||
assert "credentials" in result
|
||||
field_info = result["credentials"]
|
||||
assert "https://api.example.com/v1/data" in field_info.discriminator_values
|
||||
|
||||
def test_url_discriminator_without_url_keeps_empty_values(self):
|
||||
"""When no URL is provided, discriminator_values should remain empty."""
|
||||
block = self._get_auth_block()
|
||||
input_data = {}
|
||||
|
||||
result = _resolve_discriminated_credentials(block, input_data)
|
||||
|
||||
assert "credentials" in result
|
||||
field_info = result["credentials"]
|
||||
assert len(field_info.discriminator_values) == 0
|
||||
|
||||
def test_url_discriminator_does_not_mutate_original_field_info(self):
|
||||
"""The original block schema field_info must not be mutated."""
|
||||
block = self._get_auth_block()
|
||||
|
||||
# Grab a reference to the original schema-level field_info
|
||||
original_info = block.input_schema.get_credentials_fields_info()["credentials"]
|
||||
|
||||
# Call with a URL, which adds to discriminator_values on the copy
|
||||
_resolve_discriminated_credentials(
|
||||
block, {"url": "https://api.example.com/v1/data"}
|
||||
)
|
||||
|
||||
# The original object must remain unchanged
|
||||
assert len(original_info.discriminator_values) == 0
|
||||
|
||||
# And a fresh call without URL should also return empty values
|
||||
result = _resolve_discriminated_credentials(block, {})
|
||||
field_info = result["credentials"]
|
||||
assert len(field_info.discriminator_values) == 0
|
||||
|
||||
def test_url_discriminator_preserves_provider_and_type(self):
|
||||
"""Provider and supported_types should be preserved after URL discrimination."""
|
||||
block = self._get_auth_block()
|
||||
input_data = {"url": "https://api.example.com/v1/data"}
|
||||
|
||||
result = _resolve_discriminated_credentials(block, input_data)
|
||||
|
||||
field_info = result["credentials"]
|
||||
assert ProviderName.HTTP in field_info.provider
|
||||
assert "host_scoped" in field_info.supported_types
|
||||
|
||||
def test_provider_discriminator_still_works(self):
|
||||
"""Verify provider-based discrimination (e.g. model -> provider) is preserved.
|
||||
|
||||
The refactored conditional in _resolve_discriminated_credentials split the
|
||||
original single ``if`` into nested ``if/else`` branches. This test ensures
|
||||
the provider-based path still narrows the provider correctly.
|
||||
"""
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
|
||||
block = AITextGeneratorBlock()
|
||||
input_data = {"model": "gpt-4o-mini"}
|
||||
|
||||
result = _resolve_discriminated_credentials(block, input_data)
|
||||
|
||||
assert "credentials" in result
|
||||
field_info = result["credentials"]
|
||||
# Should narrow provider to openai
|
||||
assert ProviderName.OPENAI in field_info.provider
|
||||
assert "gpt-4o-mini" in field_info.discriminator_values
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_matching_credential tests (host-scoped)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFindMatchingHostScopedCredential:
|
||||
"""Tests for find_matching_credential with host-scoped credentials."""
|
||||
|
||||
def _make_host_scoped_cred(
|
||||
self, host: str, cred_id: str = "test-cred-id"
|
||||
) -> HostScopedCredentials:
|
||||
return HostScopedCredentials(
|
||||
id=cred_id,
|
||||
provider="http",
|
||||
host=host,
|
||||
headers={"Authorization": SecretStr("Bearer test-token")},
|
||||
title=f"Cred for {host}",
|
||||
)
|
||||
|
||||
def _make_field_info(
|
||||
self, discriminator_values: set | None = None
|
||||
) -> CredentialsFieldInfo:
|
||||
return CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([ProviderName.HTTP]),
|
||||
credentials_types=_HOST_SCOPED_TYPES,
|
||||
credentials_scopes=None,
|
||||
discriminator="url",
|
||||
discriminator_values=discriminator_values or set(),
|
||||
)
|
||||
|
||||
def test_matches_credential_for_correct_host(self):
|
||||
"""A host-scoped credential matching the URL host should be returned."""
|
||||
cred = self._make_host_scoped_cred("api.example.com")
|
||||
field_info = self._make_field_info({"https://api.example.com/v1/data"})
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
assert result.id == cred.id
|
||||
|
||||
def test_rejects_credential_for_wrong_host(self):
|
||||
"""A host-scoped credential for a different host should not match."""
|
||||
cred = self._make_host_scoped_cred("api.github.com")
|
||||
field_info = self._make_field_info({"https://api.stripe.com/v1/charges"})
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is None
|
||||
|
||||
def test_matches_any_when_no_discriminator_values(self):
|
||||
"""With empty discriminator_values, any host-scoped credential matches.
|
||||
|
||||
Note: this tests the current fallback behavior in _credential_is_for_host()
|
||||
where empty discriminator_values means "no host constraint" and any
|
||||
host-scoped credential is accepted. This is by design for the case where
|
||||
the target URL is not yet known (e.g. schema preview with empty input).
|
||||
"""
|
||||
cred = self._make_host_scoped_cred("api.anything.com")
|
||||
field_info = self._make_field_info(set())
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
|
||||
def test_wildcard_host_matching(self):
|
||||
"""Wildcard host (*.example.com) should match subdomains."""
|
||||
cred = self._make_host_scoped_cred("*.example.com")
|
||||
field_info = self._make_field_info({"https://api.example.com/v1/data"})
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
|
||||
def test_selects_correct_credential_from_multiple(self):
|
||||
"""When multiple host-scoped credentials exist, the correct one is selected."""
|
||||
cred_github = self._make_host_scoped_cred("api.github.com", "github-cred")
|
||||
cred_stripe = self._make_host_scoped_cred("api.stripe.com", "stripe-cred")
|
||||
field_info = self._make_field_info({"https://api.stripe.com/v1/charges"})
|
||||
|
||||
result = find_matching_credential([cred_github, cred_stripe], field_info)
|
||||
assert result is not None
|
||||
assert result.id == "stripe-cred"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_matching_credential tests (api_key)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFindMatchingAPIKeyCredential:
|
||||
"""Tests for find_matching_credential with API key credentials."""
|
||||
|
||||
def _make_api_key_cred(
|
||||
self, provider: str = "google_maps", cred_id: str = "test-api-key-id"
|
||||
) -> APIKeyCredentials:
|
||||
return APIKeyCredentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
api_key=SecretStr("sk-test-key-123"),
|
||||
title=f"API key for {provider}",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
def _make_field_info(
|
||||
self, provider: ProviderName = ProviderName.GOOGLE_MAPS
|
||||
) -> CredentialsFieldInfo:
|
||||
return CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([provider]),
|
||||
credentials_types=_API_KEY_TYPES,
|
||||
credentials_scopes=None,
|
||||
)
|
||||
|
||||
def test_matches_credential_for_correct_provider(self):
|
||||
"""An API key credential matching the provider should be returned."""
|
||||
cred = self._make_api_key_cred("google_maps")
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE_MAPS)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
assert result.id == cred.id
|
||||
|
||||
def test_rejects_credential_for_wrong_provider(self):
|
||||
"""An API key credential for a different provider should not match."""
|
||||
cred = self._make_api_key_cred("openai")
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE_MAPS)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is None
|
||||
|
||||
def test_rejects_credential_for_wrong_type(self):
|
||||
"""An OAuth2 credential should not match an api_key requirement."""
|
||||
oauth_cred = OAuth2Credentials(
|
||||
id="oauth-cred-id",
|
||||
provider="google_maps",
|
||||
access_token=SecretStr("mock-token"),
|
||||
scopes=[],
|
||||
title="OAuth cred (wrong type)",
|
||||
)
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE_MAPS)
|
||||
|
||||
result = find_matching_credential([oauth_cred], field_info)
|
||||
assert result is None
|
||||
|
||||
def test_selects_correct_credential_from_multiple(self):
|
||||
"""When multiple API key credentials exist, the correct provider is selected."""
|
||||
cred_maps = self._make_api_key_cred("google_maps", "maps-key")
|
||||
cred_openai = self._make_api_key_cred("openai", "openai-key")
|
||||
field_info = self._make_field_info(ProviderName.OPENAI)
|
||||
|
||||
result = find_matching_credential([cred_maps, cred_openai], field_info)
|
||||
assert result is not None
|
||||
assert result.id == "openai-key"
|
||||
|
||||
def test_returns_none_when_no_credentials(self):
|
||||
"""Should return None when the credential list is empty."""
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE_MAPS)
|
||||
|
||||
result = find_matching_credential([], field_info)
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_matching_credential tests (oauth2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFindMatchingOAuth2Credential:
|
||||
"""Tests for find_matching_credential with OAuth2 credentials."""
|
||||
|
||||
def _make_oauth2_cred(
|
||||
self,
|
||||
provider: str = "google",
|
||||
scopes: list[str] | None = None,
|
||||
cred_id: str = "test-oauth2-id",
|
||||
) -> OAuth2Credentials:
|
||||
return OAuth2Credentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
access_token=SecretStr("mock-access-token"),
|
||||
refresh_token=SecretStr("mock-refresh-token"),
|
||||
access_token_expires_at=1234567890,
|
||||
scopes=scopes or [],
|
||||
title=f"OAuth2 cred for {provider}",
|
||||
)
|
||||
|
||||
def _make_field_info(
|
||||
self,
|
||||
provider: ProviderName = ProviderName.GOOGLE,
|
||||
required_scopes: frozenset[str] | None = None,
|
||||
) -> CredentialsFieldInfo:
|
||||
return CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([provider]),
|
||||
credentials_types=_OAUTH2_TYPES,
|
||||
credentials_scopes=required_scopes,
|
||||
)
|
||||
|
||||
def test_matches_credential_for_correct_provider(self):
|
||||
"""An OAuth2 credential matching the provider should be returned."""
|
||||
cred = self._make_oauth2_cred("google")
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
assert result.id == cred.id
|
||||
|
||||
def test_rejects_credential_for_wrong_provider(self):
|
||||
"""An OAuth2 credential for a different provider should not match."""
|
||||
cred = self._make_oauth2_cred("github")
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is None
|
||||
|
||||
def test_matches_credential_with_required_scopes(self):
|
||||
"""An OAuth2 credential with all required scopes should match."""
|
||||
cred = self._make_oauth2_cred(
|
||||
"google",
|
||||
scopes=[
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
"https://www.googleapis.com/auth/gmail.send",
|
||||
],
|
||||
)
|
||||
field_info = self._make_field_info(
|
||||
ProviderName.GOOGLE,
|
||||
required_scopes=frozenset(
|
||||
["https://www.googleapis.com/auth/gmail.readonly"]
|
||||
),
|
||||
)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
|
||||
def test_rejects_credential_with_insufficient_scopes(self):
|
||||
"""An OAuth2 credential missing required scopes should not match."""
|
||||
cred = self._make_oauth2_cred(
|
||||
"google",
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
)
|
||||
field_info = self._make_field_info(
|
||||
ProviderName.GOOGLE,
|
||||
required_scopes=frozenset(
|
||||
[
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
"https://www.googleapis.com/auth/gmail.send",
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is None
|
||||
|
||||
def test_matches_credential_when_no_scopes_required(self):
|
||||
"""An OAuth2 credential should match when no scopes are required."""
|
||||
cred = self._make_oauth2_cred("google", scopes=[])
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
|
||||
def test_selects_correct_credential_from_multiple(self):
|
||||
"""When multiple OAuth2 credentials exist, the correct one is selected."""
|
||||
cred_google = self._make_oauth2_cred("google", cred_id="google-cred")
|
||||
cred_github = self._make_oauth2_cred("github", cred_id="github-cred")
|
||||
field_info = self._make_field_info(ProviderName.GITHUB)
|
||||
|
||||
result = find_matching_credential([cred_google, cred_github], field_info)
|
||||
assert result is not None
|
||||
assert result.id == "github-cred"
|
||||
|
||||
def test_returns_none_when_no_credentials(self):
|
||||
"""Should return None when the credential list is empty."""
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE)
|
||||
|
||||
result = find_matching_credential([], field_info)
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_matching_credential tests (user_password)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFindMatchingUserPasswordCredential:
|
||||
"""Tests for find_matching_credential with user/password credentials."""
|
||||
|
||||
def _make_user_password_cred(
|
||||
self, provider: str = "smtp", cred_id: str = "test-userpass-id"
|
||||
) -> UserPasswordCredentials:
|
||||
return UserPasswordCredentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
username=SecretStr("test-user"),
|
||||
password=SecretStr("test-pass"),
|
||||
title=f"Credentials for {provider}",
|
||||
)
|
||||
|
||||
def _make_field_info(
|
||||
self, provider: ProviderName = ProviderName.SMTP
|
||||
) -> CredentialsFieldInfo:
|
||||
return CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([provider]),
|
||||
credentials_types=_USER_PASSWORD_TYPES,
|
||||
credentials_scopes=None,
|
||||
)
|
||||
|
||||
def test_matches_credential_for_correct_provider(self):
|
||||
"""A user/password credential matching the provider should be returned."""
|
||||
cred = self._make_user_password_cred("smtp")
|
||||
field_info = self._make_field_info(ProviderName.SMTP)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
assert result.id == cred.id
|
||||
|
||||
def test_rejects_credential_for_wrong_provider(self):
|
||||
"""A user/password credential for a different provider should not match."""
|
||||
cred = self._make_user_password_cred("smtp")
|
||||
field_info = self._make_field_info(ProviderName.HUBSPOT)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is None
|
||||
|
||||
def test_rejects_credential_for_wrong_type(self):
|
||||
"""An API key credential should not match a user_password requirement."""
|
||||
api_key_cred = APIKeyCredentials(
|
||||
id="api-key-cred-id",
|
||||
provider="smtp",
|
||||
api_key=SecretStr("wrong-type-key"),
|
||||
title="API key cred (wrong type)",
|
||||
)
|
||||
field_info = self._make_field_info(ProviderName.SMTP)
|
||||
|
||||
result = find_matching_credential([api_key_cred], field_info)
|
||||
assert result is None
|
||||
|
||||
def test_selects_correct_credential_from_multiple(self):
|
||||
"""When multiple user/password credentials exist, the correct one is selected."""
|
||||
cred_smtp = self._make_user_password_cred("smtp", "smtp-cred")
|
||||
cred_hubspot = self._make_user_password_cred("hubspot", "hubspot-cred")
|
||||
field_info = self._make_field_info(ProviderName.HUBSPOT)
|
||||
|
||||
result = find_matching_credential([cred_smtp, cred_hubspot], field_info)
|
||||
assert result is not None
|
||||
assert result.id == "hubspot-cred"
|
||||
|
||||
def test_returns_none_when_no_credentials(self):
|
||||
"""Should return None when the credential list is empty."""
|
||||
field_info = self._make_field_info(ProviderName.SMTP)
|
||||
|
||||
result = find_matching_credential([], field_info)
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_matching_credential tests (mixed credential types)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFindMatchingCredentialMixedTypes:
|
||||
"""Tests that find_matching_credential correctly filters by type in a mixed list."""
|
||||
|
||||
def test_selects_api_key_from_mixed_list(self):
|
||||
"""API key requirement should skip OAuth2 and user_password credentials."""
|
||||
oauth_cred = OAuth2Credentials(
|
||||
id="oauth-id",
|
||||
provider="openai",
|
||||
access_token=SecretStr("token"),
|
||||
scopes=[],
|
||||
)
|
||||
userpass_cred = UserPasswordCredentials(
|
||||
id="userpass-id",
|
||||
provider="openai",
|
||||
username=SecretStr("user"),
|
||||
password=SecretStr("pass"),
|
||||
)
|
||||
api_key_cred = APIKeyCredentials(
|
||||
id="apikey-id",
|
||||
provider="openai",
|
||||
api_key=SecretStr("sk-key"),
|
||||
)
|
||||
field_info = CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([ProviderName.OPENAI]),
|
||||
credentials_types=_API_KEY_TYPES,
|
||||
credentials_scopes=None,
|
||||
)
|
||||
|
||||
result = find_matching_credential(
|
||||
[oauth_cred, userpass_cred, api_key_cred], field_info
|
||||
)
|
||||
assert result is not None
|
||||
assert result.id == "apikey-id"
|
||||
|
||||
def test_selects_oauth2_from_mixed_list(self):
|
||||
"""OAuth2 requirement should skip API key and user_password credentials."""
|
||||
api_key_cred = APIKeyCredentials(
|
||||
id="apikey-id",
|
||||
provider="google",
|
||||
api_key=SecretStr("key"),
|
||||
)
|
||||
userpass_cred = UserPasswordCredentials(
|
||||
id="userpass-id",
|
||||
provider="google",
|
||||
username=SecretStr("user"),
|
||||
password=SecretStr("pass"),
|
||||
)
|
||||
oauth_cred = OAuth2Credentials(
|
||||
id="oauth-id",
|
||||
provider="google",
|
||||
access_token=SecretStr("token"),
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
)
|
||||
field_info = CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([ProviderName.GOOGLE]),
|
||||
credentials_types=_OAUTH2_TYPES,
|
||||
credentials_scopes=frozenset(
|
||||
["https://www.googleapis.com/auth/gmail.readonly"]
|
||||
),
|
||||
)
|
||||
|
||||
result = find_matching_credential(
|
||||
[api_key_cred, userpass_cred, oauth_cred], field_info
|
||||
)
|
||||
assert result is not None
|
||||
assert result.id == "oauth-id"
|
||||
|
||||
def test_selects_user_password_from_mixed_list(self):
|
||||
"""User/password requirement should skip API key and OAuth2 credentials."""
|
||||
api_key_cred = APIKeyCredentials(
|
||||
id="apikey-id",
|
||||
provider="smtp",
|
||||
api_key=SecretStr("key"),
|
||||
)
|
||||
oauth_cred = OAuth2Credentials(
|
||||
id="oauth-id",
|
||||
provider="smtp",
|
||||
access_token=SecretStr("token"),
|
||||
scopes=[],
|
||||
)
|
||||
userpass_cred = UserPasswordCredentials(
|
||||
id="userpass-id",
|
||||
provider="smtp",
|
||||
username=SecretStr("user"),
|
||||
password=SecretStr("pass"),
|
||||
)
|
||||
field_info = CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([ProviderName.SMTP]),
|
||||
credentials_types=_USER_PASSWORD_TYPES,
|
||||
credentials_scopes=None,
|
||||
)
|
||||
|
||||
result = find_matching_credential(
|
||||
[api_key_cred, oauth_cred, userpass_cred], field_info
|
||||
)
|
||||
assert result is not None
|
||||
assert result.id == "userpass-id"
|
||||
|
||||
def test_returns_none_when_only_wrong_types_available(self):
|
||||
"""Should return None when all available creds have the wrong type."""
|
||||
oauth_cred = OAuth2Credentials(
|
||||
id="oauth-id",
|
||||
provider="google_maps",
|
||||
access_token=SecretStr("token"),
|
||||
scopes=[],
|
||||
)
|
||||
field_info = CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([ProviderName.GOOGLE_MAPS]),
|
||||
credentials_types=_API_KEY_TYPES,
|
||||
credentials_scopes=None,
|
||||
)
|
||||
|
||||
result = find_matching_credential([oauth_cred], field_info)
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_block_credentials tests (integration — all credential types)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveBlockCredentials:
|
||||
"""Integration tests for resolve_block_credentials across credential types."""
|
||||
|
||||
async def test_matches_host_scoped_credential_for_url(self):
|
||||
"""resolve_block_credentials should match a host-scoped cred for the given URL."""
|
||||
block = SendAuthenticatedWebRequestBlock()
|
||||
input_data = {"url": "https://api.example.com/v1/data"}
|
||||
|
||||
mock_cred = HostScopedCredentials(
|
||||
id="matching-cred-id",
|
||||
provider="http",
|
||||
host="api.example.com",
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
title="Example API Cred",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[mock_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert "credentials" in matched
|
||||
assert matched["credentials"].id == "matching-cred-id"
|
||||
assert len(missing) == 0
|
||||
|
||||
async def test_reports_missing_when_no_matching_host(self):
|
||||
"""resolve_block_credentials should report missing creds when host doesn't match."""
|
||||
block = SendAuthenticatedWebRequestBlock()
|
||||
input_data = {"url": "https://api.stripe.com/v1/charges"}
|
||||
|
||||
wrong_host_cred = HostScopedCredentials(
|
||||
id="wrong-cred-id",
|
||||
provider="http",
|
||||
host="api.github.com",
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
title="GitHub API Cred",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[wrong_host_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert len(matched) == 0
|
||||
assert len(missing) == 1
|
||||
|
||||
async def test_reports_missing_when_no_credentials(self):
|
||||
"""resolve_block_credentials should report missing when user has no creds at all."""
|
||||
block = SendAuthenticatedWebRequestBlock()
|
||||
input_data = {"url": "https://api.example.com/v1/data"}
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert len(matched) == 0
|
||||
assert len(missing) == 1
|
||||
|
||||
async def test_matches_api_key_credential_for_llm_block(self):
|
||||
"""resolve_block_credentials should match an API key cred for an LLM block."""
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
|
||||
block = AITextGeneratorBlock()
|
||||
input_data = {"model": "gpt-4o-mini"}
|
||||
|
||||
mock_cred = APIKeyCredentials(
|
||||
id="openai-key-id",
|
||||
provider="openai",
|
||||
api_key=SecretStr("sk-test-key"),
|
||||
title="OpenAI API Key",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[mock_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert "credentials" in matched
|
||||
assert matched["credentials"].id == "openai-key-id"
|
||||
assert len(missing) == 0
|
||||
|
||||
async def test_reports_missing_api_key_for_wrong_provider(self):
|
||||
"""resolve_block_credentials should report missing when API key provider mismatches."""
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
|
||||
block = AITextGeneratorBlock()
|
||||
input_data = {"model": "gpt-4o-mini"}
|
||||
|
||||
wrong_provider_cred = APIKeyCredentials(
|
||||
id="wrong-key-id",
|
||||
provider="google_maps",
|
||||
api_key=SecretStr("sk-wrong"),
|
||||
title="Google Maps Key",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[wrong_provider_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert len(matched) == 0
|
||||
assert len(missing) == 1
|
||||
|
||||
async def test_matches_oauth2_credential_for_google_block(self):
|
||||
"""resolve_block_credentials should match an OAuth2 cred for a Google block."""
|
||||
from backend.blocks.google.gmail import GmailReadBlock
|
||||
|
||||
block = GmailReadBlock()
|
||||
input_data = {}
|
||||
|
||||
mock_cred = OAuth2Credentials(
|
||||
id="google-oauth-id",
|
||||
provider="google",
|
||||
access_token=SecretStr("mock-token"),
|
||||
refresh_token=SecretStr("mock-refresh"),
|
||||
access_token_expires_at=9999999999,
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
title="Google OAuth",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[mock_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert "credentials" in matched
|
||||
assert matched["credentials"].id == "google-oauth-id"
|
||||
assert len(missing) == 0
|
||||
|
||||
async def test_reports_missing_oauth2_with_insufficient_scopes(self):
|
||||
"""resolve_block_credentials should report missing when OAuth2 scopes are insufficient."""
|
||||
from backend.blocks.google.gmail import GmailSendBlock
|
||||
|
||||
block = GmailSendBlock()
|
||||
input_data = {}
|
||||
|
||||
# GmailSendBlock requires gmail.send scope; provide only readonly
|
||||
insufficient_cred = OAuth2Credentials(
|
||||
id="limited-oauth-id",
|
||||
provider="google",
|
||||
access_token=SecretStr("mock-token"),
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
title="Google OAuth (limited)",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[insufficient_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert len(matched) == 0
|
||||
assert len(missing) == 1
|
||||
|
||||
async def test_matches_user_password_credential_for_email_block(self):
|
||||
"""resolve_block_credentials should match a user/password cred for an SMTP block."""
|
||||
from backend.blocks.email_block import SendEmailBlock
|
||||
|
||||
block = SendEmailBlock()
|
||||
input_data = {}
|
||||
|
||||
mock_cred = UserPasswordCredentials(
|
||||
id="smtp-cred-id",
|
||||
provider="smtp",
|
||||
username=SecretStr("test-user"),
|
||||
password=SecretStr("test-pass"),
|
||||
title="SMTP Credentials",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[mock_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert "credentials" in matched
|
||||
assert matched["credentials"].id == "smtp-cred-id"
|
||||
assert len(missing) == 0
|
||||
|
||||
async def test_reports_missing_user_password_for_wrong_provider(self):
|
||||
"""resolve_block_credentials should report missing when user/password provider mismatches."""
|
||||
from backend.blocks.email_block import SendEmailBlock
|
||||
|
||||
block = SendEmailBlock()
|
||||
input_data = {}
|
||||
|
||||
wrong_cred = UserPasswordCredentials(
|
||||
id="wrong-cred-id",
|
||||
provider="dataforseo",
|
||||
username=SecretStr("user"),
|
||||
password=SecretStr("pass"),
|
||||
title="DataForSEO Creds",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[wrong_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert len(matched) == 0
|
||||
assert len(missing) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RunBlockTool integration tests for authenticated HTTP
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunBlockToolAuthenticatedHttp:
|
||||
"""End-to-end tests for RunBlockTool with SendAuthenticatedWebRequestBlock."""
|
||||
|
||||
async def test_returns_setup_requirements_when_creds_missing(self):
|
||||
"""When no matching host-scoped credential exists, return SetupRequirementsResponse."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
block = SendAuthenticatedWebRequestBlock()
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.get_block",
|
||||
return_value=block,
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id=block.id,
|
||||
input_data={"url": "https://api.example.com/data", "method": "GET"},
|
||||
)
|
||||
|
||||
assert isinstance(response, SetupRequirementsResponse)
|
||||
assert "credentials" in response.message.lower()
|
||||
|
||||
async def test_returns_details_when_creds_matched_but_missing_required_inputs(self):
|
||||
"""When creds present + required inputs missing -> BlockDetailsResponse.
|
||||
|
||||
Note: with input_data={}, no URL is provided so discriminator_values is
|
||||
empty, meaning _credential_is_for_host() matches any host-scoped
|
||||
credential vacuously. This test exercises the "creds present + inputs
|
||||
missing" branch, not host-based matching (which is covered by
|
||||
TestFindMatchingHostScopedCredential and TestResolveBlockCredentials).
|
||||
"""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
block = SendAuthenticatedWebRequestBlock()
|
||||
|
||||
mock_cred = HostScopedCredentials(
|
||||
id="matching-cred-id",
|
||||
provider="http",
|
||||
host="api.example.com",
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
title="Example API Cred",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.get_block",
|
||||
return_value=block,
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[mock_cred],
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
# Call with empty input to get schema
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id=block.id,
|
||||
input_data={},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockDetailsResponse)
|
||||
assert response.block.name == block.name
|
||||
# The matched credential should be included in the details
|
||||
assert len(response.block.credentials) > 0
|
||||
assert response.block.credentials[0].id == "matching-cred-id"
|
||||
@@ -121,7 +121,7 @@ def _serialize_missing_credential(
|
||||
provider = next(iter(field_info.provider), "unknown")
|
||||
scopes = sorted(field_info.required_scopes or [])
|
||||
|
||||
return {
|
||||
result: dict[str, Any] = {
|
||||
"id": field_key,
|
||||
"title": field_key.replace("_", " ").title(),
|
||||
"provider": provider,
|
||||
@@ -131,6 +131,17 @@ def _serialize_missing_credential(
|
||||
"scopes": scopes,
|
||||
}
|
||||
|
||||
# Include discriminator info so the frontend can auto-match
|
||||
# host-scoped credentials (e.g. SendAuthenticatedWebRequestBlock).
|
||||
if field_info.discriminator:
|
||||
result["discriminator"] = field_info.discriminator
|
||||
if field_info.discriminator_values:
|
||||
result["discriminator_values"] = sorted(
|
||||
str(v) for v in field_info.discriminator_values
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def build_missing_credentials_from_graph(
|
||||
graph: GraphModel, matched_credentials: dict[str, CredentialsMetaInput] | None
|
||||
|
||||
@@ -344,6 +344,7 @@ class DatabaseManager(AppService):
|
||||
get_next_sequence = _(chat_db.get_next_sequence)
|
||||
update_tool_message_content = _(chat_db.update_tool_message_content)
|
||||
update_chat_session_title = _(chat_db.update_chat_session_title)
|
||||
set_turn_duration = _(chat_db.set_turn_duration)
|
||||
|
||||
|
||||
class DatabaseManagerClient(AppServiceClient):
|
||||
@@ -540,3 +541,4 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
get_next_sequence = d.get_next_sequence
|
||||
update_tool_message_content = d.update_tool_message_content
|
||||
update_chat_session_title = d.update_chat_session_title
|
||||
set_turn_duration = d.set_turn_duration
|
||||
|
||||
@@ -722,7 +722,7 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
credentials_scopes=self.required_scopes,
|
||||
discriminator=self.discriminator,
|
||||
discriminator_mapping=self.discriminator_mapping,
|
||||
discriminator_values=self.discriminator_values,
|
||||
discriminator_values=set(self.discriminator_values), # defensive copy
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
-- Add durationMs column to ChatMessage for persisting turn elapsed time.
|
||||
ALTER TABLE "ChatMessage" ADD COLUMN "durationMs" INTEGER;
|
||||
@@ -246,7 +246,8 @@ model ChatMessage {
|
||||
functionCall Json? // Deprecated but kept for compatibility
|
||||
|
||||
// Ordering within session
|
||||
sequence Int
|
||||
sequence Int
|
||||
durationMs Int? // Wall-clock milliseconds for this assistant turn
|
||||
|
||||
@@unique([sessionId, sequence])
|
||||
}
|
||||
|
||||
@@ -92,6 +92,8 @@ export function CopilotPage() {
|
||||
isDeleting,
|
||||
handleConfirmDelete,
|
||||
handleCancelDelete,
|
||||
// Historical durations for persisted timer stats
|
||||
historicalDurations,
|
||||
// Rate limit reset
|
||||
rateLimitMessage,
|
||||
dismissRateLimit,
|
||||
@@ -175,6 +177,7 @@ export function CopilotPage() {
|
||||
isUploadingFiles={isUploadingFiles}
|
||||
droppedFiles={droppedFiles}
|
||||
onDroppedFilesConsumed={handleDroppedFilesConsumed}
|
||||
historicalDurations={historicalDurations}
|
||||
/>
|
||||
</div>
|
||||
{isMobile && (
|
||||
|
||||
@@ -27,6 +27,8 @@ export interface ChatContainerProps {
|
||||
droppedFiles?: File[];
|
||||
/** Called after droppedFiles have been consumed by ChatInput. */
|
||||
onDroppedFilesConsumed?: () => void;
|
||||
/** Duration in ms for historical turns, keyed by message ID. */
|
||||
historicalDurations?: Map<string, number>;
|
||||
}
|
||||
export const ChatContainer = ({
|
||||
messages,
|
||||
@@ -44,6 +46,7 @@ export const ChatContainer = ({
|
||||
isUploadingFiles,
|
||||
droppedFiles,
|
||||
onDroppedFilesConsumed,
|
||||
historicalDurations,
|
||||
}: ChatContainerProps) => {
|
||||
const isBusy =
|
||||
status === "streaming" ||
|
||||
@@ -81,6 +84,7 @@ export const ChatContainer = ({
|
||||
isLoading={isLoadingSession}
|
||||
sessionID={sessionId}
|
||||
onRetry={handleRetry}
|
||||
historicalDurations={historicalDurations}
|
||||
/>
|
||||
<motion.div
|
||||
initial={{ opacity: 0 }}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { useMemo } from "react";
|
||||
import { useEffect, useMemo, useRef } from "react";
|
||||
import {
|
||||
Conversation,
|
||||
ConversationContent,
|
||||
@@ -13,6 +13,7 @@ import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner
|
||||
import { FileUIPart, UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { TOOL_PART_PREFIX } from "../JobStatsBar/constants";
|
||||
import { TurnStatsBar } from "../JobStatsBar/TurnStatsBar";
|
||||
import { useElapsedTimer } from "../JobStatsBar/useElapsedTimer";
|
||||
import { CopilotPendingReviews } from "../CopilotPendingReviews/CopilotPendingReviews";
|
||||
import {
|
||||
buildRenderSegments,
|
||||
@@ -37,6 +38,7 @@ interface Props {
|
||||
isLoading: boolean;
|
||||
sessionID?: string | null;
|
||||
onRetry?: () => void;
|
||||
historicalDurations?: Map<string, number>;
|
||||
}
|
||||
|
||||
function renderSegments(
|
||||
@@ -111,6 +113,7 @@ export function ChatMessagesContainer({
|
||||
isLoading,
|
||||
sessionID,
|
||||
onRetry,
|
||||
historicalDurations,
|
||||
}: Props) {
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
const graphExecId = useMemo(() => extractGraphExecId(messages), [messages]);
|
||||
@@ -139,6 +142,25 @@ export function ChatMessagesContainer({
|
||||
const showThinking =
|
||||
status === "submitted" || (status === "streaming" && !hasInflight);
|
||||
|
||||
const isActivelyStreaming = status === "streaming" || status === "submitted";
|
||||
const { elapsedSeconds } = useElapsedTimer(isActivelyStreaming);
|
||||
|
||||
// Freeze elapsed time when streaming ends so TurnStatsBar shows the final value.
|
||||
// Reset when a new streaming turn begins.
|
||||
const frozenElapsedRef = useRef(0);
|
||||
const wasStreamingRef = useRef(false);
|
||||
useEffect(() => {
|
||||
if (isActivelyStreaming) {
|
||||
if (!wasStreamingRef.current) {
|
||||
frozenElapsedRef.current = 0;
|
||||
}
|
||||
if (elapsedSeconds > 0) {
|
||||
frozenElapsedRef.current = elapsedSeconds;
|
||||
}
|
||||
}
|
||||
wasStreamingRef.current = isActivelyStreaming;
|
||||
});
|
||||
|
||||
return (
|
||||
<Conversation className="min-h-0 flex-1">
|
||||
<ConversationContent className="flex flex-1 flex-col gap-6 px-3 py-6">
|
||||
@@ -239,10 +261,19 @@ export function ChatMessagesContainer({
|
||||
{isLastInTurn && !isCurrentlyStreaming && (
|
||||
<TurnStatsBar
|
||||
turnMessages={getTurnMessages(messages, messageIndex)}
|
||||
elapsedSeconds={
|
||||
messageIndex === messages.length - 1
|
||||
? frozenElapsedRef.current
|
||||
: undefined
|
||||
}
|
||||
durationMs={historicalDurations?.get(message.id)}
|
||||
/>
|
||||
)}
|
||||
{isLastAssistant && showThinking && (
|
||||
<ThinkingIndicator active={showThinking} />
|
||||
<ThinkingIndicator
|
||||
active={showThinking}
|
||||
elapsedSeconds={elapsedSeconds}
|
||||
/>
|
||||
)}
|
||||
</MessageContent>
|
||||
{message.role === "user" && textParts.length > 0 && (
|
||||
@@ -268,7 +299,10 @@ export function ChatMessagesContainer({
|
||||
{showThinking && lastMessage?.role !== "assistant" && (
|
||||
<Message from="assistant">
|
||||
<MessageContent className="text-[1rem] leading-relaxed">
|
||||
<ThinkingIndicator active={showThinking} />
|
||||
<ThinkingIndicator
|
||||
active={showThinking}
|
||||
elapsedSeconds={elapsedSeconds}
|
||||
/>
|
||||
</MessageContent>
|
||||
</Message>
|
||||
)}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { formatElapsed } from "../../JobStatsBar/formatElapsed";
|
||||
import { ScaleLoader } from "../../ScaleLoader/ScaleLoader";
|
||||
|
||||
const THINKING_PHRASES = [
|
||||
@@ -27,6 +28,9 @@ const THINKING_PHRASES = [
|
||||
const PHRASE_CYCLE_MS = 6_000;
|
||||
const FADE_DURATION_MS = 300;
|
||||
|
||||
/** Only show elapsed time after this many seconds. */
|
||||
const SHOW_TIME_AFTER_SECONDS = 20;
|
||||
|
||||
/**
|
||||
* Cycles through thinking phrases sequentially with a fade-out/in transition.
|
||||
* Returns the current phrase and whether it's visible (for opacity).
|
||||
@@ -72,10 +76,12 @@ function useCyclingPhrase(active: boolean) {
|
||||
|
||||
interface Props {
|
||||
active: boolean;
|
||||
elapsedSeconds: number;
|
||||
}
|
||||
|
||||
export function ThinkingIndicator({ active }: Props) {
|
||||
export function ThinkingIndicator({ active, elapsedSeconds }: Props) {
|
||||
const { phrase, visible } = useCyclingPhrase(active);
|
||||
const showTime = active && elapsedSeconds >= SHOW_TIME_AFTER_SECONDS;
|
||||
|
||||
return (
|
||||
<span className="inline-flex items-center gap-1.5 text-neutral-500">
|
||||
@@ -88,6 +94,11 @@ export function ThinkingIndicator({ active }: Props) {
|
||||
{phrase}
|
||||
</span>
|
||||
</span>
|
||||
{showTime && (
|
||||
<span className="animate-pulse tabular-nums [animation-duration:1.5s]">
|
||||
• {formatElapsed(elapsedSeconds)}
|
||||
</span>
|
||||
)}
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,21 +1,44 @@
|
||||
import type { UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { formatElapsed } from "./formatElapsed";
|
||||
import { getWorkDoneCounters } from "./useWorkDoneCounters";
|
||||
|
||||
interface Props {
|
||||
turnMessages: UIMessage<unknown, UIDataTypes, UITools>[];
|
||||
elapsedSeconds?: number;
|
||||
durationMs?: number;
|
||||
}
|
||||
|
||||
export function TurnStatsBar({ turnMessages }: Props) {
|
||||
export function TurnStatsBar({
|
||||
turnMessages,
|
||||
elapsedSeconds,
|
||||
durationMs,
|
||||
}: Props) {
|
||||
const { counters } = getWorkDoneCounters(turnMessages);
|
||||
|
||||
if (counters.length === 0) return null;
|
||||
// Prefer live elapsedSeconds, fall back to persisted durationMs
|
||||
const displaySeconds =
|
||||
elapsedSeconds !== undefined && elapsedSeconds > 0
|
||||
? elapsedSeconds
|
||||
: durationMs !== undefined
|
||||
? Math.round(durationMs / 1000)
|
||||
: undefined;
|
||||
|
||||
const hasTime = displaySeconds !== undefined && displaySeconds > 0;
|
||||
|
||||
if (counters.length === 0 && !hasTime) return null;
|
||||
|
||||
return (
|
||||
<div className="mt-2 flex items-center gap-1.5">
|
||||
{hasTime && (
|
||||
<span className="text-[11px] tabular-nums text-neutral-500">
|
||||
Thought for {formatElapsed(displaySeconds)}
|
||||
</span>
|
||||
)}
|
||||
{counters.map(function renderCounter(counter, index) {
|
||||
const needsDot = index > 0 || hasTime;
|
||||
return (
|
||||
<span key={counter.category} className="flex items-center gap-1">
|
||||
{index > 0 && (
|
||||
{needsDot && (
|
||||
<span className="text-xs text-neutral-300">·</span>
|
||||
)}
|
||||
<span className="text-[11px] tabular-nums text-neutral-500">
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
export function formatElapsed(totalSeconds: number): string {
|
||||
const minutes = Math.floor(totalSeconds / 60);
|
||||
const seconds = totalSeconds % 60;
|
||||
|
||||
if (minutes === 0) return `${seconds}s`;
|
||||
return `${minutes}m ${seconds}s`;
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
|
||||
export function useElapsedTimer(isRunning: boolean) {
|
||||
const [elapsedSeconds, setElapsedSeconds] = useState(0);
|
||||
const startTimeRef = useRef<number | null>(null);
|
||||
const intervalRef = useRef<ReturnType<typeof setInterval>>();
|
||||
|
||||
useEffect(() => {
|
||||
if (isRunning) {
|
||||
if (startTimeRef.current === null) {
|
||||
startTimeRef.current = Date.now();
|
||||
setElapsedSeconds(0);
|
||||
}
|
||||
|
||||
intervalRef.current = setInterval(() => {
|
||||
if (startTimeRef.current !== null) {
|
||||
setElapsedSeconds(
|
||||
Math.floor((Date.now() - startTimeRef.current) / 1000),
|
||||
);
|
||||
}
|
||||
}, 1000);
|
||||
|
||||
return () => clearInterval(intervalRef.current);
|
||||
}
|
||||
|
||||
clearInterval(intervalRef.current);
|
||||
startTimeRef.current = null;
|
||||
}, [isRunning]);
|
||||
|
||||
return { elapsedSeconds };
|
||||
}
|
||||
@@ -6,6 +6,7 @@ interface SessionChatMessage {
|
||||
content: string | null;
|
||||
tool_call_id: string | null;
|
||||
tool_calls: unknown[] | null;
|
||||
duration_ms: number | null;
|
||||
}
|
||||
|
||||
function coerceSessionChatMessages(
|
||||
@@ -34,6 +35,8 @@ function coerceSessionChatMessages(
|
||||
? null
|
||||
: String(msg.tool_call_id),
|
||||
tool_calls: Array.isArray(msg.tool_calls) ? msg.tool_calls : null,
|
||||
duration_ms:
|
||||
typeof msg.duration_ms === "number" ? msg.duration_ms : null,
|
||||
};
|
||||
})
|
||||
.filter((m): m is SessionChatMessage => m !== null);
|
||||
@@ -102,7 +105,10 @@ export function convertChatSessionMessagesToUiMessages(
|
||||
sessionId: string,
|
||||
rawMessages: unknown[],
|
||||
options?: { isComplete?: boolean },
|
||||
): UIMessage<unknown, UIDataTypes, UITools>[] {
|
||||
): {
|
||||
messages: UIMessage<unknown, UIDataTypes, UITools>[];
|
||||
durations: Map<string, number>;
|
||||
} {
|
||||
const messages = coerceSessionChatMessages(rawMessages);
|
||||
const toolOutputsByCallId = new Map<string, unknown>();
|
||||
|
||||
@@ -114,6 +120,7 @@ export function convertChatSessionMessagesToUiMessages(
|
||||
}
|
||||
|
||||
const uiMessages: UIMessage<unknown, UIDataTypes, UITools>[] = [];
|
||||
const durations = new Map<string, number>();
|
||||
|
||||
messages.forEach((msg, index) => {
|
||||
if (msg.role === "tool") return;
|
||||
@@ -186,15 +193,24 @@ export function convertChatSessionMessagesToUiMessages(
|
||||
const prevUI = uiMessages[uiMessages.length - 1];
|
||||
if (msg.role === "assistant" && prevUI && prevUI.role === "assistant") {
|
||||
prevUI.parts.push(...parts);
|
||||
// Capture duration on merged message (last assistant msg wins)
|
||||
if (msg.duration_ms != null) {
|
||||
durations.set(prevUI.id, msg.duration_ms);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const msgId = `${sessionId}-${index}`;
|
||||
uiMessages.push({
|
||||
id: `${sessionId}-${index}`,
|
||||
id: msgId,
|
||||
role: msg.role,
|
||||
parts,
|
||||
});
|
||||
|
||||
if (msg.role === "assistant" && msg.duration_ms != null) {
|
||||
durations.set(msgId, msg.duration_ms);
|
||||
}
|
||||
});
|
||||
|
||||
return uiMessages;
|
||||
return { messages: uiMessages, durations };
|
||||
}
|
||||
|
||||
@@ -41,7 +41,15 @@ export function coerceCredentialFields(rawMissingCredentials: unknown): {
|
||||
? cred.scopes.filter((s): s is string => typeof s === "string")
|
||||
: undefined;
|
||||
|
||||
const schema = {
|
||||
const discriminator =
|
||||
typeof cred.discriminator === "string" ? cred.discriminator : undefined;
|
||||
const discriminatorValues = Array.isArray(cred.discriminator_values)
|
||||
? cred.discriminator_values.filter(
|
||||
(v): v is string => typeof v === "string",
|
||||
)
|
||||
: undefined;
|
||||
|
||||
const schema: Record<string, unknown> = {
|
||||
type: "object" as const,
|
||||
properties: {},
|
||||
credentials_provider: [provider],
|
||||
@@ -49,6 +57,13 @@ export function coerceCredentialFields(rawMissingCredentials: unknown): {
|
||||
credentials_scopes: scopes,
|
||||
};
|
||||
|
||||
if (discriminator) {
|
||||
schema.discriminator = discriminator;
|
||||
}
|
||||
if (discriminatorValues && discriminatorValues.length > 0) {
|
||||
schema.discriminator_values = discriminatorValues;
|
||||
}
|
||||
|
||||
credentialFields.push([key, schema]);
|
||||
requiredCredentials.add(key);
|
||||
});
|
||||
|
||||
@@ -61,13 +61,21 @@ export function useChatSession() {
|
||||
// array reference every render. Re-derives only when query data changes.
|
||||
// When the session is complete (no active stream), mark dangling tool
|
||||
// calls as completed so stale spinners don't persist after refresh.
|
||||
const hydratedMessages = useMemo(() => {
|
||||
if (sessionQuery.data?.status !== 200 || !sessionId) return undefined;
|
||||
return convertChatSessionMessagesToUiMessages(
|
||||
const { hydratedMessages, historicalDurations } = useMemo(() => {
|
||||
if (sessionQuery.data?.status !== 200 || !sessionId)
|
||||
return {
|
||||
hydratedMessages: undefined,
|
||||
historicalDurations: new Map<string, number>(),
|
||||
};
|
||||
const result = convertChatSessionMessagesToUiMessages(
|
||||
sessionId,
|
||||
sessionQuery.data.data.messages ?? [],
|
||||
{ isComplete: !hasActiveStream },
|
||||
);
|
||||
return {
|
||||
hydratedMessages: result.messages,
|
||||
historicalDurations: result.durations,
|
||||
};
|
||||
}, [sessionQuery.data, sessionId, hasActiveStream]);
|
||||
|
||||
const { mutateAsync: createSessionMutation, isPending: isCreatingSession } =
|
||||
@@ -122,6 +130,7 @@ export function useChatSession() {
|
||||
sessionId,
|
||||
setSessionId,
|
||||
hydratedMessages,
|
||||
historicalDurations,
|
||||
hasActiveStream,
|
||||
isLoadingSession: sessionQuery.isLoading,
|
||||
isSessionError: sessionQuery.isError,
|
||||
|
||||
@@ -39,6 +39,7 @@ export function useCopilotPage() {
|
||||
sessionId,
|
||||
setSessionId,
|
||||
hydratedMessages,
|
||||
historicalDurations,
|
||||
hasActiveStream,
|
||||
isLoadingSession,
|
||||
isSessionError,
|
||||
@@ -377,6 +378,8 @@ export function useCopilotPage() {
|
||||
handleDeleteClick,
|
||||
handleConfirmDelete,
|
||||
handleCancelDelete,
|
||||
// Historical durations for persisted timer stats
|
||||
historicalDurations,
|
||||
// Rate limit reset
|
||||
rateLimitMessage,
|
||||
dismissRateLimit,
|
||||
|
||||
Reference in New Issue
Block a user