Resolve merge conflicts with dev

https://claude.ai/code/session_01TPw8kd7p8qwsuNa5qBRcHc
This commit is contained in:
Claude
2026-03-30 11:45:31 +00:00
21 changed files with 1208 additions and 33 deletions

View File

@@ -18,7 +18,7 @@ from prisma.types import (
from backend.data import db
from backend.util.json import SafeJson, sanitize_string
from .model import ChatMessage, ChatSession, ChatSessionInfo
from .model import ChatMessage, ChatSession, ChatSessionInfo, invalidate_session_cache
logger = logging.getLogger(__name__)
@@ -217,6 +217,9 @@ async def add_chat_messages_batch(
if msg.get("function_call") is not None:
data["functionCall"] = SafeJson(msg["function_call"])
if msg.get("duration_ms") is not None:
data["durationMs"] = msg["duration_ms"]
messages_data.append(data)
# Run create_many and session update in parallel within transaction
@@ -359,3 +362,22 @@ async def update_tool_message_content(
f"tool_call_id {tool_call_id}: {e}"
)
return False
async def set_turn_duration(session_id: str, duration_ms: int) -> None:
"""Set durationMs on the last assistant message in a session.
Also invalidates the Redis session cache so the next GET returns
the updated duration.
"""
last_msg = await PrismaChatMessage.prisma().find_first(
where={"sessionId": session_id, "role": "assistant"},
order={"sequence": "desc"},
)
if last_msg:
await PrismaChatMessage.prisma().update(
where={"id": last_msg.id},
data={"durationMs": duration_ms},
)
# Invalidate cache so the session is re-fetched from DB with durationMs
await invalidate_session_cache(session_id)

View File

@@ -54,6 +54,7 @@ class ChatMessage(BaseModel):
refusal: str | None = None
tool_calls: list[dict] | None = None
function_call: dict | None = None
duration_ms: int | None = None
@staticmethod
def from_db(prisma_message: PrismaChatMessage) -> "ChatMessage":
@@ -66,6 +67,7 @@ class ChatMessage(BaseModel):
refusal=prisma_message.refusal,
tool_calls=_parse_json_field(prisma_message.toolCalls),
function_call=_parse_json_field(prisma_message.functionCall),
duration_ms=prisma_message.durationMs,
)

View File

@@ -26,6 +26,7 @@ import orjson
from redis.exceptions import RedisError
from backend.api.model import CopilotCompletionPayload
from backend.data.db_accessors import chat_db
from backend.data.notification_bus import (
AsyncRedisNotificationEventBus,
NotificationEvent,
@@ -111,6 +112,14 @@ def _parse_session_meta(meta: dict[Any, Any], session_id: str = "") -> ActiveSes
``session_id`` is used as a fallback for ``turn_id`` when the meta hash
pre-dates the turn_id field (backward compat for in-flight sessions).
"""
created_at = datetime.now(timezone.utc)
created_at_raw = meta.get("created_at")
if created_at_raw:
try:
created_at = datetime.fromisoformat(str(created_at_raw))
except (ValueError, TypeError):
pass
return ActiveSession(
session_id=meta.get("session_id", "") or session_id,
user_id=meta.get("user_id", "") or None,
@@ -119,6 +128,7 @@ def _parse_session_meta(meta: dict[Any, Any], session_id: str = "") -> ActiveSes
turn_id=meta.get("turn_id", "") or session_id,
blocking=meta.get("blocking") == "1",
status=meta.get("status", "running"), # type: ignore[arg-type]
created_at=created_at,
)
@@ -802,6 +812,33 @@ async def mark_session_completed(
f"Failed to publish error event for session {session_id}: {e}"
)
# Compute wall-clock duration from session created_at.
# Only persist when (a) the session completed successfully and
# (b) created_at was actually present in Redis meta (not a fallback).
duration_ms: int | None = None
if meta and not error_message:
created_at_raw = meta.get("created_at")
if created_at_raw:
try:
created_at = datetime.fromisoformat(str(created_at_raw))
if created_at.tzinfo is None:
created_at = created_at.replace(tzinfo=timezone.utc)
elapsed = datetime.now(timezone.utc) - created_at
duration_ms = max(0, int(elapsed.total_seconds() * 1000))
except (ValueError, TypeError):
logger.warning(
"Failed to compute session duration for %s (created_at=%r)",
session_id,
created_at_raw,
)
# Persist duration on the last assistant message
if duration_ms is not None:
try:
await chat_db().set_turn_duration(session_id, duration_ms)
except Exception as e:
logger.warning(f"Failed to save turn duration for {session_id}: {e}")
# Publish StreamFinish AFTER status is set to "completed"/"failed".
# This is the SINGLE place that publishes StreamFinish — services and
# the processor must NOT publish it themselves.

View File

@@ -537,7 +537,7 @@ async def check_hitl_review(
)
synthetic_node_exec_id = (
f"{synthetic_node_id}{COPILOT_NODE_EXEC_ID_SEPARATOR}" f"{uuid.uuid4().hex[:8]}"
f"{synthetic_node_id}{COPILOT_NODE_EXEC_ID_SEPARATOR}{uuid.uuid4().hex[:8]}"
)
review_context = ExecutionContext(
@@ -582,7 +582,16 @@ def _resolve_discriminated_credentials(
block: AnyBlockSchema,
input_data: dict[str, Any],
) -> dict[str, CredentialsFieldInfo]:
"""Resolve credential requirements, applying discriminator logic where needed."""
"""Resolve credential requirements, applying discriminator logic where needed.
Handles two discrimination modes:
1. **Provider-based** (``discriminator_mapping`` is set): the discriminator
field value selects the provider (e.g. an AI model name -> provider).
2. **URL/host-based** (``discriminator`` is set but ``discriminator_mapping``
is ``None``): the discriminator field value (typically a URL) is added to
``discriminator_values`` so that host-scoped credential matching can
compare the credential's host against the target URL.
"""
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return {}
@@ -592,25 +601,42 @@ def _resolve_discriminated_credentials(
for field_name, field_info in credentials_fields_info.items():
effective_field_info = field_info
if field_info.discriminator and field_info.discriminator_mapping:
if field_info.discriminator:
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(field_info.discriminator)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default
if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
effective_field_info.discriminator_values.add(discriminator_value)
logger.debug(
"Discriminated provider for %s: %s -> %s",
field_name,
discriminator_value,
effective_field_info.provider,
)
if discriminator_value is not None:
if field_info.discriminator_mapping:
# Provider-based discrimination (e.g. model -> provider)
if discriminator_value in field_info.discriminator_mapping:
effective_field_info = field_info.discriminate(
discriminator_value
)
effective_field_info.discriminator_values.add(
discriminator_value
)
# Model names are safe to log (not PII); URLs are
# intentionally omitted in the host-based branch below.
logger.debug(
"Discriminated provider for %s: %s -> %s",
field_name,
discriminator_value,
effective_field_info.provider,
)
else:
# URL/host-based discrimination (e.g. url -> host matching).
# Deep copy to avoid mutating the cached schema-level
# field_info (model_copy() is shallow — the mutable set
# would be shared).
effective_field_info = field_info.model_copy(deep=True)
effective_field_info.discriminator_values.add(discriminator_value)
logger.debug(
"Added discriminator value for host matching on %s",
field_name,
)
resolved[field_name] = effective_field_info

View File

@@ -0,0 +1,916 @@
"""Tests for credential resolution across all credential types in the CoPilot.
These tests verify that:
1. `_resolve_discriminated_credentials` correctly populates discriminator_values
for URL-based (host-scoped) and provider-based (api_key) credential fields.
2. `find_matching_credential` correctly matches credentials for all types:
APIKeyCredentials, OAuth2Credentials, UserPasswordCredentials, and
HostScopedCredentials.
3. The full `resolve_block_credentials` flow correctly resolves matching
credentials or reports them as missing for each credential type.
4. `RunBlockTool._execute` end-to-end tests return correct response types.
"""
from unittest.mock import AsyncMock, patch
from pydantic import SecretStr
from backend.blocks.http import SendAuthenticatedWebRequestBlock
from backend.data.model import (
APIKeyCredentials,
CredentialsFieldInfo,
CredentialsType,
HostScopedCredentials,
OAuth2Credentials,
UserPasswordCredentials,
)
from backend.integrations.providers import ProviderName
from ._test_data import make_session
from .helpers import _resolve_discriminated_credentials, resolve_block_credentials
from .models import BlockDetailsResponse, SetupRequirementsResponse
from .run_block import RunBlockTool
from .utils import find_matching_credential
_TEST_USER_ID = "test-user-http-cred"
# Properly typed constants to avoid type: ignore on CredentialsFieldInfo construction.
_HOST_SCOPED_TYPES: frozenset[CredentialsType] = frozenset(["host_scoped"])
_API_KEY_TYPES: frozenset[CredentialsType] = frozenset(["api_key"])
_OAUTH2_TYPES: frozenset[CredentialsType] = frozenset(["oauth2"])
_USER_PASSWORD_TYPES: frozenset[CredentialsType] = frozenset(["user_password"])
# ---------------------------------------------------------------------------
# _resolve_discriminated_credentials tests
# ---------------------------------------------------------------------------
class TestResolveDiscriminatedCredentials:
"""Tests for _resolve_discriminated_credentials with URL-based discrimination."""
def _get_auth_block(self):
return SendAuthenticatedWebRequestBlock()
def test_url_discriminator_populates_discriminator_values(self):
"""When input_data contains a URL, discriminator_values should include it."""
block = self._get_auth_block()
input_data = {"url": "https://api.example.com/v1/data"}
result = _resolve_discriminated_credentials(block, input_data)
assert "credentials" in result
field_info = result["credentials"]
assert "https://api.example.com/v1/data" in field_info.discriminator_values
def test_url_discriminator_without_url_keeps_empty_values(self):
"""When no URL is provided, discriminator_values should remain empty."""
block = self._get_auth_block()
input_data = {}
result = _resolve_discriminated_credentials(block, input_data)
assert "credentials" in result
field_info = result["credentials"]
assert len(field_info.discriminator_values) == 0
def test_url_discriminator_does_not_mutate_original_field_info(self):
"""The original block schema field_info must not be mutated."""
block = self._get_auth_block()
# Grab a reference to the original schema-level field_info
original_info = block.input_schema.get_credentials_fields_info()["credentials"]
# Call with a URL, which adds to discriminator_values on the copy
_resolve_discriminated_credentials(
block, {"url": "https://api.example.com/v1/data"}
)
# The original object must remain unchanged
assert len(original_info.discriminator_values) == 0
# And a fresh call without URL should also return empty values
result = _resolve_discriminated_credentials(block, {})
field_info = result["credentials"]
assert len(field_info.discriminator_values) == 0
def test_url_discriminator_preserves_provider_and_type(self):
"""Provider and supported_types should be preserved after URL discrimination."""
block = self._get_auth_block()
input_data = {"url": "https://api.example.com/v1/data"}
result = _resolve_discriminated_credentials(block, input_data)
field_info = result["credentials"]
assert ProviderName.HTTP in field_info.provider
assert "host_scoped" in field_info.supported_types
def test_provider_discriminator_still_works(self):
"""Verify provider-based discrimination (e.g. model -> provider) is preserved.
The refactored conditional in _resolve_discriminated_credentials split the
original single ``if`` into nested ``if/else`` branches. This test ensures
the provider-based path still narrows the provider correctly.
"""
from backend.blocks.llm import AITextGeneratorBlock
block = AITextGeneratorBlock()
input_data = {"model": "gpt-4o-mini"}
result = _resolve_discriminated_credentials(block, input_data)
assert "credentials" in result
field_info = result["credentials"]
# Should narrow provider to openai
assert ProviderName.OPENAI in field_info.provider
assert "gpt-4o-mini" in field_info.discriminator_values
# ---------------------------------------------------------------------------
# find_matching_credential tests (host-scoped)
# ---------------------------------------------------------------------------
class TestFindMatchingHostScopedCredential:
"""Tests for find_matching_credential with host-scoped credentials."""
def _make_host_scoped_cred(
self, host: str, cred_id: str = "test-cred-id"
) -> HostScopedCredentials:
return HostScopedCredentials(
id=cred_id,
provider="http",
host=host,
headers={"Authorization": SecretStr("Bearer test-token")},
title=f"Cred for {host}",
)
def _make_field_info(
self, discriminator_values: set | None = None
) -> CredentialsFieldInfo:
return CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.HTTP]),
credentials_types=_HOST_SCOPED_TYPES,
credentials_scopes=None,
discriminator="url",
discriminator_values=discriminator_values or set(),
)
def test_matches_credential_for_correct_host(self):
"""A host-scoped credential matching the URL host should be returned."""
cred = self._make_host_scoped_cred("api.example.com")
field_info = self._make_field_info({"https://api.example.com/v1/data"})
result = find_matching_credential([cred], field_info)
assert result is not None
assert result.id == cred.id
def test_rejects_credential_for_wrong_host(self):
"""A host-scoped credential for a different host should not match."""
cred = self._make_host_scoped_cred("api.github.com")
field_info = self._make_field_info({"https://api.stripe.com/v1/charges"})
result = find_matching_credential([cred], field_info)
assert result is None
def test_matches_any_when_no_discriminator_values(self):
"""With empty discriminator_values, any host-scoped credential matches.
Note: this tests the current fallback behavior in _credential_is_for_host()
where empty discriminator_values means "no host constraint" and any
host-scoped credential is accepted. This is by design for the case where
the target URL is not yet known (e.g. schema preview with empty input).
"""
cred = self._make_host_scoped_cred("api.anything.com")
field_info = self._make_field_info(set())
result = find_matching_credential([cred], field_info)
assert result is not None
def test_wildcard_host_matching(self):
"""Wildcard host (*.example.com) should match subdomains."""
cred = self._make_host_scoped_cred("*.example.com")
field_info = self._make_field_info({"https://api.example.com/v1/data"})
result = find_matching_credential([cred], field_info)
assert result is not None
def test_selects_correct_credential_from_multiple(self):
"""When multiple host-scoped credentials exist, the correct one is selected."""
cred_github = self._make_host_scoped_cred("api.github.com", "github-cred")
cred_stripe = self._make_host_scoped_cred("api.stripe.com", "stripe-cred")
field_info = self._make_field_info({"https://api.stripe.com/v1/charges"})
result = find_matching_credential([cred_github, cred_stripe], field_info)
assert result is not None
assert result.id == "stripe-cred"
# ---------------------------------------------------------------------------
# find_matching_credential tests (api_key)
# ---------------------------------------------------------------------------
class TestFindMatchingAPIKeyCredential:
"""Tests for find_matching_credential with API key credentials."""
def _make_api_key_cred(
self, provider: str = "google_maps", cred_id: str = "test-api-key-id"
) -> APIKeyCredentials:
return APIKeyCredentials(
id=cred_id,
provider=provider,
api_key=SecretStr("sk-test-key-123"),
title=f"API key for {provider}",
expires_at=None,
)
def _make_field_info(
self, provider: ProviderName = ProviderName.GOOGLE_MAPS
) -> CredentialsFieldInfo:
return CredentialsFieldInfo(
credentials_provider=frozenset([provider]),
credentials_types=_API_KEY_TYPES,
credentials_scopes=None,
)
def test_matches_credential_for_correct_provider(self):
"""An API key credential matching the provider should be returned."""
cred = self._make_api_key_cred("google_maps")
field_info = self._make_field_info(ProviderName.GOOGLE_MAPS)
result = find_matching_credential([cred], field_info)
assert result is not None
assert result.id == cred.id
def test_rejects_credential_for_wrong_provider(self):
"""An API key credential for a different provider should not match."""
cred = self._make_api_key_cred("openai")
field_info = self._make_field_info(ProviderName.GOOGLE_MAPS)
result = find_matching_credential([cred], field_info)
assert result is None
def test_rejects_credential_for_wrong_type(self):
"""An OAuth2 credential should not match an api_key requirement."""
oauth_cred = OAuth2Credentials(
id="oauth-cred-id",
provider="google_maps",
access_token=SecretStr("mock-token"),
scopes=[],
title="OAuth cred (wrong type)",
)
field_info = self._make_field_info(ProviderName.GOOGLE_MAPS)
result = find_matching_credential([oauth_cred], field_info)
assert result is None
def test_selects_correct_credential_from_multiple(self):
"""When multiple API key credentials exist, the correct provider is selected."""
cred_maps = self._make_api_key_cred("google_maps", "maps-key")
cred_openai = self._make_api_key_cred("openai", "openai-key")
field_info = self._make_field_info(ProviderName.OPENAI)
result = find_matching_credential([cred_maps, cred_openai], field_info)
assert result is not None
assert result.id == "openai-key"
def test_returns_none_when_no_credentials(self):
"""Should return None when the credential list is empty."""
field_info = self._make_field_info(ProviderName.GOOGLE_MAPS)
result = find_matching_credential([], field_info)
assert result is None
# ---------------------------------------------------------------------------
# find_matching_credential tests (oauth2)
# ---------------------------------------------------------------------------
class TestFindMatchingOAuth2Credential:
"""Tests for find_matching_credential with OAuth2 credentials."""
def _make_oauth2_cred(
self,
provider: str = "google",
scopes: list[str] | None = None,
cred_id: str = "test-oauth2-id",
) -> OAuth2Credentials:
return OAuth2Credentials(
id=cred_id,
provider=provider,
access_token=SecretStr("mock-access-token"),
refresh_token=SecretStr("mock-refresh-token"),
access_token_expires_at=1234567890,
scopes=scopes or [],
title=f"OAuth2 cred for {provider}",
)
def _make_field_info(
self,
provider: ProviderName = ProviderName.GOOGLE,
required_scopes: frozenset[str] | None = None,
) -> CredentialsFieldInfo:
return CredentialsFieldInfo(
credentials_provider=frozenset([provider]),
credentials_types=_OAUTH2_TYPES,
credentials_scopes=required_scopes,
)
def test_matches_credential_for_correct_provider(self):
"""An OAuth2 credential matching the provider should be returned."""
cred = self._make_oauth2_cred("google")
field_info = self._make_field_info(ProviderName.GOOGLE)
result = find_matching_credential([cred], field_info)
assert result is not None
assert result.id == cred.id
def test_rejects_credential_for_wrong_provider(self):
"""An OAuth2 credential for a different provider should not match."""
cred = self._make_oauth2_cred("github")
field_info = self._make_field_info(ProviderName.GOOGLE)
result = find_matching_credential([cred], field_info)
assert result is None
def test_matches_credential_with_required_scopes(self):
"""An OAuth2 credential with all required scopes should match."""
cred = self._make_oauth2_cred(
"google",
scopes=[
"https://www.googleapis.com/auth/gmail.readonly",
"https://www.googleapis.com/auth/gmail.send",
],
)
field_info = self._make_field_info(
ProviderName.GOOGLE,
required_scopes=frozenset(
["https://www.googleapis.com/auth/gmail.readonly"]
),
)
result = find_matching_credential([cred], field_info)
assert result is not None
def test_rejects_credential_with_insufficient_scopes(self):
"""An OAuth2 credential missing required scopes should not match."""
cred = self._make_oauth2_cred(
"google",
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
)
field_info = self._make_field_info(
ProviderName.GOOGLE,
required_scopes=frozenset(
[
"https://www.googleapis.com/auth/gmail.readonly",
"https://www.googleapis.com/auth/gmail.send",
]
),
)
result = find_matching_credential([cred], field_info)
assert result is None
def test_matches_credential_when_no_scopes_required(self):
"""An OAuth2 credential should match when no scopes are required."""
cred = self._make_oauth2_cred("google", scopes=[])
field_info = self._make_field_info(ProviderName.GOOGLE)
result = find_matching_credential([cred], field_info)
assert result is not None
def test_selects_correct_credential_from_multiple(self):
"""When multiple OAuth2 credentials exist, the correct one is selected."""
cred_google = self._make_oauth2_cred("google", cred_id="google-cred")
cred_github = self._make_oauth2_cred("github", cred_id="github-cred")
field_info = self._make_field_info(ProviderName.GITHUB)
result = find_matching_credential([cred_google, cred_github], field_info)
assert result is not None
assert result.id == "github-cred"
def test_returns_none_when_no_credentials(self):
"""Should return None when the credential list is empty."""
field_info = self._make_field_info(ProviderName.GOOGLE)
result = find_matching_credential([], field_info)
assert result is None
# ---------------------------------------------------------------------------
# find_matching_credential tests (user_password)
# ---------------------------------------------------------------------------
class TestFindMatchingUserPasswordCredential:
"""Tests for find_matching_credential with user/password credentials."""
def _make_user_password_cred(
self, provider: str = "smtp", cred_id: str = "test-userpass-id"
) -> UserPasswordCredentials:
return UserPasswordCredentials(
id=cred_id,
provider=provider,
username=SecretStr("test-user"),
password=SecretStr("test-pass"),
title=f"Credentials for {provider}",
)
def _make_field_info(
self, provider: ProviderName = ProviderName.SMTP
) -> CredentialsFieldInfo:
return CredentialsFieldInfo(
credentials_provider=frozenset([provider]),
credentials_types=_USER_PASSWORD_TYPES,
credentials_scopes=None,
)
def test_matches_credential_for_correct_provider(self):
"""A user/password credential matching the provider should be returned."""
cred = self._make_user_password_cred("smtp")
field_info = self._make_field_info(ProviderName.SMTP)
result = find_matching_credential([cred], field_info)
assert result is not None
assert result.id == cred.id
def test_rejects_credential_for_wrong_provider(self):
"""A user/password credential for a different provider should not match."""
cred = self._make_user_password_cred("smtp")
field_info = self._make_field_info(ProviderName.HUBSPOT)
result = find_matching_credential([cred], field_info)
assert result is None
def test_rejects_credential_for_wrong_type(self):
"""An API key credential should not match a user_password requirement."""
api_key_cred = APIKeyCredentials(
id="api-key-cred-id",
provider="smtp",
api_key=SecretStr("wrong-type-key"),
title="API key cred (wrong type)",
)
field_info = self._make_field_info(ProviderName.SMTP)
result = find_matching_credential([api_key_cred], field_info)
assert result is None
def test_selects_correct_credential_from_multiple(self):
"""When multiple user/password credentials exist, the correct one is selected."""
cred_smtp = self._make_user_password_cred("smtp", "smtp-cred")
cred_hubspot = self._make_user_password_cred("hubspot", "hubspot-cred")
field_info = self._make_field_info(ProviderName.HUBSPOT)
result = find_matching_credential([cred_smtp, cred_hubspot], field_info)
assert result is not None
assert result.id == "hubspot-cred"
def test_returns_none_when_no_credentials(self):
"""Should return None when the credential list is empty."""
field_info = self._make_field_info(ProviderName.SMTP)
result = find_matching_credential([], field_info)
assert result is None
# ---------------------------------------------------------------------------
# find_matching_credential tests (mixed credential types)
# ---------------------------------------------------------------------------
class TestFindMatchingCredentialMixedTypes:
"""Tests that find_matching_credential correctly filters by type in a mixed list."""
def test_selects_api_key_from_mixed_list(self):
"""API key requirement should skip OAuth2 and user_password credentials."""
oauth_cred = OAuth2Credentials(
id="oauth-id",
provider="openai",
access_token=SecretStr("token"),
scopes=[],
)
userpass_cred = UserPasswordCredentials(
id="userpass-id",
provider="openai",
username=SecretStr("user"),
password=SecretStr("pass"),
)
api_key_cred = APIKeyCredentials(
id="apikey-id",
provider="openai",
api_key=SecretStr("sk-key"),
)
field_info = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.OPENAI]),
credentials_types=_API_KEY_TYPES,
credentials_scopes=None,
)
result = find_matching_credential(
[oauth_cred, userpass_cred, api_key_cred], field_info
)
assert result is not None
assert result.id == "apikey-id"
def test_selects_oauth2_from_mixed_list(self):
"""OAuth2 requirement should skip API key and user_password credentials."""
api_key_cred = APIKeyCredentials(
id="apikey-id",
provider="google",
api_key=SecretStr("key"),
)
userpass_cred = UserPasswordCredentials(
id="userpass-id",
provider="google",
username=SecretStr("user"),
password=SecretStr("pass"),
)
oauth_cred = OAuth2Credentials(
id="oauth-id",
provider="google",
access_token=SecretStr("token"),
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
)
field_info = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.GOOGLE]),
credentials_types=_OAUTH2_TYPES,
credentials_scopes=frozenset(
["https://www.googleapis.com/auth/gmail.readonly"]
),
)
result = find_matching_credential(
[api_key_cred, userpass_cred, oauth_cred], field_info
)
assert result is not None
assert result.id == "oauth-id"
def test_selects_user_password_from_mixed_list(self):
"""User/password requirement should skip API key and OAuth2 credentials."""
api_key_cred = APIKeyCredentials(
id="apikey-id",
provider="smtp",
api_key=SecretStr("key"),
)
oauth_cred = OAuth2Credentials(
id="oauth-id",
provider="smtp",
access_token=SecretStr("token"),
scopes=[],
)
userpass_cred = UserPasswordCredentials(
id="userpass-id",
provider="smtp",
username=SecretStr("user"),
password=SecretStr("pass"),
)
field_info = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.SMTP]),
credentials_types=_USER_PASSWORD_TYPES,
credentials_scopes=None,
)
result = find_matching_credential(
[api_key_cred, oauth_cred, userpass_cred], field_info
)
assert result is not None
assert result.id == "userpass-id"
def test_returns_none_when_only_wrong_types_available(self):
"""Should return None when all available creds have the wrong type."""
oauth_cred = OAuth2Credentials(
id="oauth-id",
provider="google_maps",
access_token=SecretStr("token"),
scopes=[],
)
field_info = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.GOOGLE_MAPS]),
credentials_types=_API_KEY_TYPES,
credentials_scopes=None,
)
result = find_matching_credential([oauth_cred], field_info)
assert result is None
# ---------------------------------------------------------------------------
# resolve_block_credentials tests (integration — all credential types)
# ---------------------------------------------------------------------------
class TestResolveBlockCredentials:
"""Integration tests for resolve_block_credentials across credential types."""
async def test_matches_host_scoped_credential_for_url(self):
"""resolve_block_credentials should match a host-scoped cred for the given URL."""
block = SendAuthenticatedWebRequestBlock()
input_data = {"url": "https://api.example.com/v1/data"}
mock_cred = HostScopedCredentials(
id="matching-cred-id",
provider="http",
host="api.example.com",
headers={"Authorization": SecretStr("Bearer token")},
title="Example API Cred",
)
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[mock_cred],
):
matched, missing = await resolve_block_credentials(
_TEST_USER_ID, block, input_data
)
assert "credentials" in matched
assert matched["credentials"].id == "matching-cred-id"
assert len(missing) == 0
async def test_reports_missing_when_no_matching_host(self):
"""resolve_block_credentials should report missing creds when host doesn't match."""
block = SendAuthenticatedWebRequestBlock()
input_data = {"url": "https://api.stripe.com/v1/charges"}
wrong_host_cred = HostScopedCredentials(
id="wrong-cred-id",
provider="http",
host="api.github.com",
headers={"Authorization": SecretStr("Bearer token")},
title="GitHub API Cred",
)
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[wrong_host_cred],
):
matched, missing = await resolve_block_credentials(
_TEST_USER_ID, block, input_data
)
assert len(matched) == 0
assert len(missing) == 1
async def test_reports_missing_when_no_credentials(self):
"""resolve_block_credentials should report missing when user has no creds at all."""
block = SendAuthenticatedWebRequestBlock()
input_data = {"url": "https://api.example.com/v1/data"}
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[],
):
matched, missing = await resolve_block_credentials(
_TEST_USER_ID, block, input_data
)
assert len(matched) == 0
assert len(missing) == 1
async def test_matches_api_key_credential_for_llm_block(self):
"""resolve_block_credentials should match an API key cred for an LLM block."""
from backend.blocks.llm import AITextGeneratorBlock
block = AITextGeneratorBlock()
input_data = {"model": "gpt-4o-mini"}
mock_cred = APIKeyCredentials(
id="openai-key-id",
provider="openai",
api_key=SecretStr("sk-test-key"),
title="OpenAI API Key",
)
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[mock_cred],
):
matched, missing = await resolve_block_credentials(
_TEST_USER_ID, block, input_data
)
assert "credentials" in matched
assert matched["credentials"].id == "openai-key-id"
assert len(missing) == 0
async def test_reports_missing_api_key_for_wrong_provider(self):
"""resolve_block_credentials should report missing when API key provider mismatches."""
from backend.blocks.llm import AITextGeneratorBlock
block = AITextGeneratorBlock()
input_data = {"model": "gpt-4o-mini"}
wrong_provider_cred = APIKeyCredentials(
id="wrong-key-id",
provider="google_maps",
api_key=SecretStr("sk-wrong"),
title="Google Maps Key",
)
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[wrong_provider_cred],
):
matched, missing = await resolve_block_credentials(
_TEST_USER_ID, block, input_data
)
assert len(matched) == 0
assert len(missing) == 1
async def test_matches_oauth2_credential_for_google_block(self):
"""resolve_block_credentials should match an OAuth2 cred for a Google block."""
from backend.blocks.google.gmail import GmailReadBlock
block = GmailReadBlock()
input_data = {}
mock_cred = OAuth2Credentials(
id="google-oauth-id",
provider="google",
access_token=SecretStr("mock-token"),
refresh_token=SecretStr("mock-refresh"),
access_token_expires_at=9999999999,
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
title="Google OAuth",
)
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[mock_cred],
):
matched, missing = await resolve_block_credentials(
_TEST_USER_ID, block, input_data
)
assert "credentials" in matched
assert matched["credentials"].id == "google-oauth-id"
assert len(missing) == 0
async def test_reports_missing_oauth2_with_insufficient_scopes(self):
"""resolve_block_credentials should report missing when OAuth2 scopes are insufficient."""
from backend.blocks.google.gmail import GmailSendBlock
block = GmailSendBlock()
input_data = {}
# GmailSendBlock requires gmail.send scope; provide only readonly
insufficient_cred = OAuth2Credentials(
id="limited-oauth-id",
provider="google",
access_token=SecretStr("mock-token"),
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
title="Google OAuth (limited)",
)
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[insufficient_cred],
):
matched, missing = await resolve_block_credentials(
_TEST_USER_ID, block, input_data
)
assert len(matched) == 0
assert len(missing) == 1
async def test_matches_user_password_credential_for_email_block(self):
"""resolve_block_credentials should match a user/password cred for an SMTP block."""
from backend.blocks.email_block import SendEmailBlock
block = SendEmailBlock()
input_data = {}
mock_cred = UserPasswordCredentials(
id="smtp-cred-id",
provider="smtp",
username=SecretStr("test-user"),
password=SecretStr("test-pass"),
title="SMTP Credentials",
)
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[mock_cred],
):
matched, missing = await resolve_block_credentials(
_TEST_USER_ID, block, input_data
)
assert "credentials" in matched
assert matched["credentials"].id == "smtp-cred-id"
assert len(missing) == 0
async def test_reports_missing_user_password_for_wrong_provider(self):
"""resolve_block_credentials should report missing when user/password provider mismatches."""
from backend.blocks.email_block import SendEmailBlock
block = SendEmailBlock()
input_data = {}
wrong_cred = UserPasswordCredentials(
id="wrong-cred-id",
provider="dataforseo",
username=SecretStr("user"),
password=SecretStr("pass"),
title="DataForSEO Creds",
)
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[wrong_cred],
):
matched, missing = await resolve_block_credentials(
_TEST_USER_ID, block, input_data
)
assert len(matched) == 0
assert len(missing) == 1
# ---------------------------------------------------------------------------
# RunBlockTool integration tests for authenticated HTTP
# ---------------------------------------------------------------------------
class TestRunBlockToolAuthenticatedHttp:
"""End-to-end tests for RunBlockTool with SendAuthenticatedWebRequestBlock."""
async def test_returns_setup_requirements_when_creds_missing(self):
"""When no matching host-scoped credential exists, return SetupRequirementsResponse."""
session = make_session(user_id=_TEST_USER_ID)
block = SendAuthenticatedWebRequestBlock()
with patch(
"backend.copilot.tools.helpers.get_block",
return_value=block,
):
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[],
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id=block.id,
input_data={"url": "https://api.example.com/data", "method": "GET"},
)
assert isinstance(response, SetupRequirementsResponse)
assert "credentials" in response.message.lower()
async def test_returns_details_when_creds_matched_but_missing_required_inputs(self):
"""When creds present + required inputs missing -> BlockDetailsResponse.
Note: with input_data={}, no URL is provided so discriminator_values is
empty, meaning _credential_is_for_host() matches any host-scoped
credential vacuously. This test exercises the "creds present + inputs
missing" branch, not host-based matching (which is covered by
TestFindMatchingHostScopedCredential and TestResolveBlockCredentials).
"""
session = make_session(user_id=_TEST_USER_ID)
block = SendAuthenticatedWebRequestBlock()
mock_cred = HostScopedCredentials(
id="matching-cred-id",
provider="http",
host="api.example.com",
headers={"Authorization": SecretStr("Bearer token")},
title="Example API Cred",
)
with patch(
"backend.copilot.tools.helpers.get_block",
return_value=block,
):
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[mock_cred],
):
tool = RunBlockTool()
# Call with empty input to get schema
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id=block.id,
input_data={},
)
assert isinstance(response, BlockDetailsResponse)
assert response.block.name == block.name
# The matched credential should be included in the details
assert len(response.block.credentials) > 0
assert response.block.credentials[0].id == "matching-cred-id"

View File

@@ -121,7 +121,7 @@ def _serialize_missing_credential(
provider = next(iter(field_info.provider), "unknown")
scopes = sorted(field_info.required_scopes or [])
return {
result: dict[str, Any] = {
"id": field_key,
"title": field_key.replace("_", " ").title(),
"provider": provider,
@@ -131,6 +131,17 @@ def _serialize_missing_credential(
"scopes": scopes,
}
# Include discriminator info so the frontend can auto-match
# host-scoped credentials (e.g. SendAuthenticatedWebRequestBlock).
if field_info.discriminator:
result["discriminator"] = field_info.discriminator
if field_info.discriminator_values:
result["discriminator_values"] = sorted(
str(v) for v in field_info.discriminator_values
)
return result
def build_missing_credentials_from_graph(
graph: GraphModel, matched_credentials: dict[str, CredentialsMetaInput] | None

View File

@@ -344,6 +344,7 @@ class DatabaseManager(AppService):
get_next_sequence = _(chat_db.get_next_sequence)
update_tool_message_content = _(chat_db.update_tool_message_content)
update_chat_session_title = _(chat_db.update_chat_session_title)
set_turn_duration = _(chat_db.set_turn_duration)
class DatabaseManagerClient(AppServiceClient):
@@ -540,3 +541,4 @@ class DatabaseManagerAsyncClient(AppServiceClient):
get_next_sequence = d.get_next_sequence
update_tool_message_content = d.update_tool_message_content
update_chat_session_title = d.update_chat_session_title
set_turn_duration = d.set_turn_duration

View File

@@ -722,7 +722,7 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
credentials_scopes=self.required_scopes,
discriminator=self.discriminator,
discriminator_mapping=self.discriminator_mapping,
discriminator_values=self.discriminator_values,
discriminator_values=set(self.discriminator_values), # defensive copy
)

View File

@@ -0,0 +1,2 @@
-- Add durationMs column to ChatMessage for persisting turn elapsed time.
ALTER TABLE "ChatMessage" ADD COLUMN "durationMs" INTEGER;

View File

@@ -246,7 +246,8 @@ model ChatMessage {
functionCall Json? // Deprecated but kept for compatibility
// Ordering within session
sequence Int
sequence Int
durationMs Int? // Wall-clock milliseconds for this assistant turn
@@unique([sessionId, sequence])
}

View File

@@ -92,6 +92,8 @@ export function CopilotPage() {
isDeleting,
handleConfirmDelete,
handleCancelDelete,
// Historical durations for persisted timer stats
historicalDurations,
// Rate limit reset
rateLimitMessage,
dismissRateLimit,
@@ -175,6 +177,7 @@ export function CopilotPage() {
isUploadingFiles={isUploadingFiles}
droppedFiles={droppedFiles}
onDroppedFilesConsumed={handleDroppedFilesConsumed}
historicalDurations={historicalDurations}
/>
</div>
{isMobile && (

View File

@@ -27,6 +27,8 @@ export interface ChatContainerProps {
droppedFiles?: File[];
/** Called after droppedFiles have been consumed by ChatInput. */
onDroppedFilesConsumed?: () => void;
/** Duration in ms for historical turns, keyed by message ID. */
historicalDurations?: Map<string, number>;
}
export const ChatContainer = ({
messages,
@@ -44,6 +46,7 @@ export const ChatContainer = ({
isUploadingFiles,
droppedFiles,
onDroppedFilesConsumed,
historicalDurations,
}: ChatContainerProps) => {
const isBusy =
status === "streaming" ||
@@ -81,6 +84,7 @@ export const ChatContainer = ({
isLoading={isLoadingSession}
sessionID={sessionId}
onRetry={handleRetry}
historicalDurations={historicalDurations}
/>
<motion.div
initial={{ opacity: 0 }}

View File

@@ -1,4 +1,4 @@
import { useMemo } from "react";
import { useEffect, useMemo, useRef } from "react";
import {
Conversation,
ConversationContent,
@@ -13,6 +13,7 @@ import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner
import { FileUIPart, UIDataTypes, UIMessage, UITools } from "ai";
import { TOOL_PART_PREFIX } from "../JobStatsBar/constants";
import { TurnStatsBar } from "../JobStatsBar/TurnStatsBar";
import { useElapsedTimer } from "../JobStatsBar/useElapsedTimer";
import { CopilotPendingReviews } from "../CopilotPendingReviews/CopilotPendingReviews";
import {
buildRenderSegments,
@@ -37,6 +38,7 @@ interface Props {
isLoading: boolean;
sessionID?: string | null;
onRetry?: () => void;
historicalDurations?: Map<string, number>;
}
function renderSegments(
@@ -111,6 +113,7 @@ export function ChatMessagesContainer({
isLoading,
sessionID,
onRetry,
historicalDurations,
}: Props) {
const lastMessage = messages[messages.length - 1];
const graphExecId = useMemo(() => extractGraphExecId(messages), [messages]);
@@ -139,6 +142,25 @@ export function ChatMessagesContainer({
const showThinking =
status === "submitted" || (status === "streaming" && !hasInflight);
const isActivelyStreaming = status === "streaming" || status === "submitted";
const { elapsedSeconds } = useElapsedTimer(isActivelyStreaming);
// Freeze elapsed time when streaming ends so TurnStatsBar shows the final value.
// Reset when a new streaming turn begins.
const frozenElapsedRef = useRef(0);
const wasStreamingRef = useRef(false);
useEffect(() => {
if (isActivelyStreaming) {
if (!wasStreamingRef.current) {
frozenElapsedRef.current = 0;
}
if (elapsedSeconds > 0) {
frozenElapsedRef.current = elapsedSeconds;
}
}
wasStreamingRef.current = isActivelyStreaming;
});
return (
<Conversation className="min-h-0 flex-1">
<ConversationContent className="flex flex-1 flex-col gap-6 px-3 py-6">
@@ -239,10 +261,19 @@ export function ChatMessagesContainer({
{isLastInTurn && !isCurrentlyStreaming && (
<TurnStatsBar
turnMessages={getTurnMessages(messages, messageIndex)}
elapsedSeconds={
messageIndex === messages.length - 1
? frozenElapsedRef.current
: undefined
}
durationMs={historicalDurations?.get(message.id)}
/>
)}
{isLastAssistant && showThinking && (
<ThinkingIndicator active={showThinking} />
<ThinkingIndicator
active={showThinking}
elapsedSeconds={elapsedSeconds}
/>
)}
</MessageContent>
{message.role === "user" && textParts.length > 0 && (
@@ -268,7 +299,10 @@ export function ChatMessagesContainer({
{showThinking && lastMessage?.role !== "assistant" && (
<Message from="assistant">
<MessageContent className="text-[1rem] leading-relaxed">
<ThinkingIndicator active={showThinking} />
<ThinkingIndicator
active={showThinking}
elapsedSeconds={elapsedSeconds}
/>
</MessageContent>
</Message>
)}

View File

@@ -1,4 +1,5 @@
import { useEffect, useRef, useState } from "react";
import { formatElapsed } from "../../JobStatsBar/formatElapsed";
import { ScaleLoader } from "../../ScaleLoader/ScaleLoader";
const THINKING_PHRASES = [
@@ -27,6 +28,9 @@ const THINKING_PHRASES = [
const PHRASE_CYCLE_MS = 6_000;
const FADE_DURATION_MS = 300;
/** Only show elapsed time after this many seconds. */
const SHOW_TIME_AFTER_SECONDS = 20;
/**
* Cycles through thinking phrases sequentially with a fade-out/in transition.
* Returns the current phrase and whether it's visible (for opacity).
@@ -72,10 +76,12 @@ function useCyclingPhrase(active: boolean) {
interface Props {
active: boolean;
elapsedSeconds: number;
}
export function ThinkingIndicator({ active }: Props) {
export function ThinkingIndicator({ active, elapsedSeconds }: Props) {
const { phrase, visible } = useCyclingPhrase(active);
const showTime = active && elapsedSeconds >= SHOW_TIME_AFTER_SECONDS;
return (
<span className="inline-flex items-center gap-1.5 text-neutral-500">
@@ -88,6 +94,11 @@ export function ThinkingIndicator({ active }: Props) {
{phrase}
</span>
</span>
{showTime && (
<span className="animate-pulse tabular-nums [animation-duration:1.5s]">
{formatElapsed(elapsedSeconds)}
</span>
)}
</span>
);
}

View File

@@ -1,21 +1,44 @@
import type { UIDataTypes, UIMessage, UITools } from "ai";
import { formatElapsed } from "./formatElapsed";
import { getWorkDoneCounters } from "./useWorkDoneCounters";
interface Props {
turnMessages: UIMessage<unknown, UIDataTypes, UITools>[];
elapsedSeconds?: number;
durationMs?: number;
}
export function TurnStatsBar({ turnMessages }: Props) {
export function TurnStatsBar({
turnMessages,
elapsedSeconds,
durationMs,
}: Props) {
const { counters } = getWorkDoneCounters(turnMessages);
if (counters.length === 0) return null;
// Prefer live elapsedSeconds, fall back to persisted durationMs
const displaySeconds =
elapsedSeconds !== undefined && elapsedSeconds > 0
? elapsedSeconds
: durationMs !== undefined
? Math.round(durationMs / 1000)
: undefined;
const hasTime = displaySeconds !== undefined && displaySeconds > 0;
if (counters.length === 0 && !hasTime) return null;
return (
<div className="mt-2 flex items-center gap-1.5">
{hasTime && (
<span className="text-[11px] tabular-nums text-neutral-500">
Thought for {formatElapsed(displaySeconds)}
</span>
)}
{counters.map(function renderCounter(counter, index) {
const needsDot = index > 0 || hasTime;
return (
<span key={counter.category} className="flex items-center gap-1">
{index > 0 && (
{needsDot && (
<span className="text-xs text-neutral-300">&middot;</span>
)}
<span className="text-[11px] tabular-nums text-neutral-500">

View File

@@ -0,0 +1,7 @@
export function formatElapsed(totalSeconds: number): string {
const minutes = Math.floor(totalSeconds / 60);
const seconds = totalSeconds % 60;
if (minutes === 0) return `${seconds}s`;
return `${minutes}m ${seconds}s`;
}

View File

@@ -0,0 +1,31 @@
import { useEffect, useRef, useState } from "react";
export function useElapsedTimer(isRunning: boolean) {
const [elapsedSeconds, setElapsedSeconds] = useState(0);
const startTimeRef = useRef<number | null>(null);
const intervalRef = useRef<ReturnType<typeof setInterval>>();
useEffect(() => {
if (isRunning) {
if (startTimeRef.current === null) {
startTimeRef.current = Date.now();
setElapsedSeconds(0);
}
intervalRef.current = setInterval(() => {
if (startTimeRef.current !== null) {
setElapsedSeconds(
Math.floor((Date.now() - startTimeRef.current) / 1000),
);
}
}, 1000);
return () => clearInterval(intervalRef.current);
}
clearInterval(intervalRef.current);
startTimeRef.current = null;
}, [isRunning]);
return { elapsedSeconds };
}

View File

@@ -6,6 +6,7 @@ interface SessionChatMessage {
content: string | null;
tool_call_id: string | null;
tool_calls: unknown[] | null;
duration_ms: number | null;
}
function coerceSessionChatMessages(
@@ -34,6 +35,8 @@ function coerceSessionChatMessages(
? null
: String(msg.tool_call_id),
tool_calls: Array.isArray(msg.tool_calls) ? msg.tool_calls : null,
duration_ms:
typeof msg.duration_ms === "number" ? msg.duration_ms : null,
};
})
.filter((m): m is SessionChatMessage => m !== null);
@@ -102,7 +105,10 @@ export function convertChatSessionMessagesToUiMessages(
sessionId: string,
rawMessages: unknown[],
options?: { isComplete?: boolean },
): UIMessage<unknown, UIDataTypes, UITools>[] {
): {
messages: UIMessage<unknown, UIDataTypes, UITools>[];
durations: Map<string, number>;
} {
const messages = coerceSessionChatMessages(rawMessages);
const toolOutputsByCallId = new Map<string, unknown>();
@@ -114,6 +120,7 @@ export function convertChatSessionMessagesToUiMessages(
}
const uiMessages: UIMessage<unknown, UIDataTypes, UITools>[] = [];
const durations = new Map<string, number>();
messages.forEach((msg, index) => {
if (msg.role === "tool") return;
@@ -186,15 +193,24 @@ export function convertChatSessionMessagesToUiMessages(
const prevUI = uiMessages[uiMessages.length - 1];
if (msg.role === "assistant" && prevUI && prevUI.role === "assistant") {
prevUI.parts.push(...parts);
// Capture duration on merged message (last assistant msg wins)
if (msg.duration_ms != null) {
durations.set(prevUI.id, msg.duration_ms);
}
return;
}
const msgId = `${sessionId}-${index}`;
uiMessages.push({
id: `${sessionId}-${index}`,
id: msgId,
role: msg.role,
parts,
});
if (msg.role === "assistant" && msg.duration_ms != null) {
durations.set(msgId, msg.duration_ms);
}
});
return uiMessages;
return { messages: uiMessages, durations };
}

View File

@@ -41,7 +41,15 @@ export function coerceCredentialFields(rawMissingCredentials: unknown): {
? cred.scopes.filter((s): s is string => typeof s === "string")
: undefined;
const schema = {
const discriminator =
typeof cred.discriminator === "string" ? cred.discriminator : undefined;
const discriminatorValues = Array.isArray(cred.discriminator_values)
? cred.discriminator_values.filter(
(v): v is string => typeof v === "string",
)
: undefined;
const schema: Record<string, unknown> = {
type: "object" as const,
properties: {},
credentials_provider: [provider],
@@ -49,6 +57,13 @@ export function coerceCredentialFields(rawMissingCredentials: unknown): {
credentials_scopes: scopes,
};
if (discriminator) {
schema.discriminator = discriminator;
}
if (discriminatorValues && discriminatorValues.length > 0) {
schema.discriminator_values = discriminatorValues;
}
credentialFields.push([key, schema]);
requiredCredentials.add(key);
});

View File

@@ -61,13 +61,21 @@ export function useChatSession() {
// array reference every render. Re-derives only when query data changes.
// When the session is complete (no active stream), mark dangling tool
// calls as completed so stale spinners don't persist after refresh.
const hydratedMessages = useMemo(() => {
if (sessionQuery.data?.status !== 200 || !sessionId) return undefined;
return convertChatSessionMessagesToUiMessages(
const { hydratedMessages, historicalDurations } = useMemo(() => {
if (sessionQuery.data?.status !== 200 || !sessionId)
return {
hydratedMessages: undefined,
historicalDurations: new Map<string, number>(),
};
const result = convertChatSessionMessagesToUiMessages(
sessionId,
sessionQuery.data.data.messages ?? [],
{ isComplete: !hasActiveStream },
);
return {
hydratedMessages: result.messages,
historicalDurations: result.durations,
};
}, [sessionQuery.data, sessionId, hasActiveStream]);
const { mutateAsync: createSessionMutation, isPending: isCreatingSession } =
@@ -122,6 +130,7 @@ export function useChatSession() {
sessionId,
setSessionId,
hydratedMessages,
historicalDurations,
hasActiveStream,
isLoadingSession: sessionQuery.isLoading,
isSessionError: sessionQuery.isError,

View File

@@ -39,6 +39,7 @@ export function useCopilotPage() {
sessionId,
setSessionId,
hydratedMessages,
historicalDurations,
hasActiveStream,
isLoadingSession,
isSessionError,
@@ -377,6 +378,8 @@ export function useCopilotPage() {
handleDeleteClick,
handleConfirmDelete,
handleCancelDelete,
// Historical durations for persisted timer stats
historicalDurations,
// Rate limit reset
rateLimitMessage,
dismissRateLimit,