mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
59 Commits
feat/githu
...
fix/block-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dc8f176a05 | ||
|
|
d2e54d51e4 | ||
|
|
40f5532b1e | ||
|
|
a08673cdbe | ||
|
|
c87e954644 | ||
|
|
349b9fc009 | ||
|
|
f2f48e98c8 | ||
|
|
dbad60933d | ||
|
|
d13ebbf64b | ||
|
|
ec0d9ab6ff | ||
|
|
e706c5afc2 | ||
|
|
93fc88ca1e | ||
|
|
28ff7b6057 | ||
|
|
378bd3afcc | ||
|
|
19e79dd236 | ||
|
|
f5f4c7d8f9 | ||
|
|
7d869d8288 | ||
|
|
40df7165d0 | ||
|
|
a56dc42a59 | ||
|
|
288ced743b | ||
|
|
41e2e80f60 | ||
|
|
c3eac2a6af | ||
|
|
5b6b68e469 | ||
|
|
296372b8b9 | ||
|
|
ee8896c818 | ||
|
|
699ecc8cec | ||
|
|
211be3aff1 | ||
|
|
3120981e4b | ||
|
|
5966d3669d | ||
|
|
c81ab1fc3b | ||
|
|
5446c7f18f | ||
|
|
2b0c9ba703 | ||
|
|
195c7011ae | ||
|
|
d4944fb22b | ||
|
|
a5ed8fefa9 | ||
|
|
a52a777b29 | ||
|
|
8bec7a6933 | ||
|
|
e73791efed | ||
|
|
2d161ce2b9 | ||
|
|
6fc4989654 | ||
|
|
976443bf6e | ||
|
|
4ceb15b3f1 | ||
|
|
3096f94996 | ||
|
|
6f90729612 | ||
|
|
ebf89dde8b | ||
|
|
5d057e97e5 | ||
|
|
1d2f641a26 | ||
|
|
dcb71ab0b9 | ||
|
|
8136b90860 | ||
|
|
4d179a7c37 | ||
|
|
f78adcdc65 | ||
|
|
40388b7520 | ||
|
|
dd7be1158b | ||
|
|
c0e59f0a6b | ||
|
|
104d1f1bf4 | ||
|
|
d9e9cd4c98 | ||
|
|
ca416300ec | ||
|
|
c589cd0c43 | ||
|
|
b6d863fcd2 |
@@ -5,16 +5,26 @@ Pluggable system for different content sources (store agents, blocks, docs).
|
||||
Each handler knows how to fetch and process its content type for embedding.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import itertools
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, get_args, get_origin
|
||||
from typing import TYPE_CHECKING, Any, get_args, get_origin
|
||||
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.blocks import get_blocks
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.db import query_raw_with_schema
|
||||
from backend.util.text import split_camelcase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.blocks._base import AnyBlockSchema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -154,6 +164,28 @@ class StoreAgentHandler(ContentHandler):
|
||||
}
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _get_enabled_blocks() -> dict[str, AnyBlockSchema]:
|
||||
"""Return ``{block_id: block_instance}`` for all enabled, instantiable blocks.
|
||||
|
||||
Disabled blocks and blocks that fail to instantiate are silently skipped
|
||||
(with a warning log), so callers never need their own try/except loop.
|
||||
|
||||
Results are cached for the process lifetime via ``lru_cache`` because
|
||||
blocks are registered at import time and never change while running.
|
||||
"""
|
||||
enabled: dict[str, AnyBlockSchema] = {}
|
||||
for block_id, block_cls in get_blocks().items():
|
||||
try:
|
||||
instance = block_cls()
|
||||
except Exception as e:
|
||||
logger.warning(f"Skipping block {block_id}: init failed: {e}")
|
||||
continue
|
||||
if not instance.disabled:
|
||||
enabled[block_id] = instance
|
||||
return enabled
|
||||
|
||||
|
||||
class BlockHandler(ContentHandler):
|
||||
"""Handler for block definitions (Python classes)."""
|
||||
|
||||
@@ -163,16 +195,14 @@ class BlockHandler(ContentHandler):
|
||||
|
||||
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
||||
"""Fetch blocks without embeddings."""
|
||||
from backend.blocks import get_blocks
|
||||
|
||||
# Get all available blocks
|
||||
all_blocks = get_blocks()
|
||||
|
||||
# Check which ones have embeddings
|
||||
if not all_blocks:
|
||||
# to_thread keeps the first (heavy) call off the event loop. On
|
||||
# subsequent calls the lru_cache makes this a dict lookup, so the
|
||||
# thread-pool overhead is negligible compared to the DB queries below.
|
||||
enabled = await asyncio.to_thread(_get_enabled_blocks)
|
||||
if not enabled:
|
||||
return []
|
||||
|
||||
block_ids = list(all_blocks.keys())
|
||||
block_ids = list(enabled.keys())
|
||||
|
||||
# Query for existing embeddings
|
||||
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
|
||||
@@ -187,52 +217,42 @@ class BlockHandler(ContentHandler):
|
||||
)
|
||||
|
||||
existing_ids = {row["contentId"] for row in existing_result}
|
||||
missing_blocks = [
|
||||
(block_id, block_cls)
|
||||
for block_id, block_cls in all_blocks.items()
|
||||
if block_id not in existing_ids
|
||||
]
|
||||
|
||||
# Convert to ContentItem
|
||||
# Convert to ContentItem — disabled filtering already done by
|
||||
# _get_enabled_blocks so batch_size won't be exhausted by disabled blocks.
|
||||
missing = ((bid, b) for bid, b in enabled.items() if bid not in existing_ids)
|
||||
items = []
|
||||
for block_id, block_cls in missing_blocks[:batch_size]:
|
||||
for block_id, block in itertools.islice(missing, batch_size):
|
||||
try:
|
||||
block_instance = block_cls()
|
||||
|
||||
if block_instance.disabled:
|
||||
continue
|
||||
|
||||
# Build searchable text from block metadata
|
||||
parts = []
|
||||
if block_instance.name:
|
||||
parts.append(block_instance.name)
|
||||
if block_instance.description:
|
||||
parts.append(block_instance.description)
|
||||
if block_instance.categories:
|
||||
parts.append(
|
||||
" ".join(str(cat.value) for cat in block_instance.categories)
|
||||
if not block.name:
|
||||
logger.warning(
|
||||
f"Block {block_id} has no name — using block_id as fallback"
|
||||
)
|
||||
display_name = split_camelcase(block.name) if block.name else ""
|
||||
parts = []
|
||||
if display_name:
|
||||
parts.append(display_name)
|
||||
if block.description:
|
||||
parts.append(block.description)
|
||||
if block.categories:
|
||||
parts.append(" ".join(str(cat.value) for cat in block.categories))
|
||||
|
||||
# Add input schema field descriptions
|
||||
block_input_fields = block_instance.input_schema.model_fields
|
||||
parts += [
|
||||
f"{field_name}: {field_info.description}"
|
||||
for field_name, field_info in block_input_fields.items()
|
||||
for field_name, field_info in block.input_schema.model_fields.items()
|
||||
if field_info.description
|
||||
]
|
||||
|
||||
searchable_text = " ".join(parts)
|
||||
|
||||
categories_list = (
|
||||
[cat.value for cat in block_instance.categories]
|
||||
if block_instance.categories
|
||||
else []
|
||||
[cat.value for cat in block.categories] if block.categories else []
|
||||
)
|
||||
|
||||
# Extract provider names from credentials fields
|
||||
credentials_info = (
|
||||
block_instance.input_schema.get_credentials_fields_info()
|
||||
)
|
||||
credentials_info = block.input_schema.get_credentials_fields_info()
|
||||
is_integration = len(credentials_info) > 0
|
||||
provider_names = [
|
||||
provider.value.lower()
|
||||
@@ -243,7 +263,7 @@ class BlockHandler(ContentHandler):
|
||||
# Check if block has LlmModel field in input schema
|
||||
has_llm_model_field = any(
|
||||
_contains_type(field.annotation, LlmModel)
|
||||
for field in block_instance.input_schema.model_fields.values()
|
||||
for field in block.input_schema.model_fields.values()
|
||||
)
|
||||
|
||||
items.append(
|
||||
@@ -252,13 +272,13 @@ class BlockHandler(ContentHandler):
|
||||
content_type=ContentType.BLOCK,
|
||||
searchable_text=searchable_text,
|
||||
metadata={
|
||||
"name": block_instance.name,
|
||||
"name": display_name or block.name or block_id,
|
||||
"categories": categories_list,
|
||||
"providers": provider_names,
|
||||
"has_llm_model_field": has_llm_model_field,
|
||||
"is_integration": is_integration,
|
||||
},
|
||||
user_id=None, # Blocks are public
|
||||
user_id=None,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -269,22 +289,13 @@ class BlockHandler(ContentHandler):
|
||||
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""Get statistics about block embedding coverage."""
|
||||
from backend.blocks import get_blocks
|
||||
|
||||
all_blocks = get_blocks()
|
||||
|
||||
# Filter out disabled blocks - they're not indexed
|
||||
enabled_block_ids = [
|
||||
block_id
|
||||
for block_id, block_cls in all_blocks.items()
|
||||
if not block_cls().disabled
|
||||
]
|
||||
total_blocks = len(enabled_block_ids)
|
||||
enabled = await asyncio.to_thread(_get_enabled_blocks)
|
||||
total_blocks = len(enabled)
|
||||
|
||||
if total_blocks == 0:
|
||||
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
|
||||
|
||||
block_ids = enabled_block_ids
|
||||
block_ids = list(enabled.keys())
|
||||
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
|
||||
|
||||
embedded_result = await query_raw_with_schema(
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""
|
||||
E2E tests for content handlers (blocks, store agents, documentation).
|
||||
|
||||
Tests the full flow: discovering content → generating embeddings → storing.
|
||||
Tests for content handlers (blocks, store agents, documentation).
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
@@ -15,15 +13,103 @@ from backend.api.features.store.content_handlers import (
|
||||
BlockHandler,
|
||||
DocumentationHandler,
|
||||
StoreAgentHandler,
|
||||
_get_enabled_blocks,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_block_cache():
|
||||
"""Clear the lru_cache on _get_enabled_blocks before each test."""
|
||||
_get_enabled_blocks.cache_clear()
|
||||
yield
|
||||
_get_enabled_blocks.cache_clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper to build a mock block class that returns a pre-configured instance
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_block_class(
|
||||
*,
|
||||
name: str = "Block",
|
||||
description: str = "",
|
||||
disabled: bool = False,
|
||||
categories: list[MagicMock] | None = None,
|
||||
fields: dict[str, str] | None = None,
|
||||
raise_on_init: Exception | None = None,
|
||||
) -> MagicMock:
|
||||
cls = MagicMock()
|
||||
if raise_on_init is not None:
|
||||
cls.side_effect = raise_on_init
|
||||
return cls
|
||||
inst = MagicMock()
|
||||
inst.name = name
|
||||
inst.disabled = disabled
|
||||
inst.description = description
|
||||
inst.categories = categories or []
|
||||
field_mocks = {
|
||||
fname: MagicMock(description=fdesc) for fname, fdesc in (fields or {}).items()
|
||||
}
|
||||
inst.input_schema.model_fields = field_mocks
|
||||
inst.input_schema.get_credentials_fields_info.return_value = {}
|
||||
cls.return_value = inst
|
||||
return cls
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_enabled_blocks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_enabled_blocks_filters_disabled():
|
||||
"""Disabled blocks are excluded."""
|
||||
blocks = {
|
||||
"enabled": _make_block_class(name="E", disabled=False),
|
||||
"disabled": _make_block_class(name="D", disabled=True),
|
||||
}
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
result = _get_enabled_blocks()
|
||||
assert list(result.keys()) == ["enabled"]
|
||||
|
||||
|
||||
def test_get_enabled_blocks_skips_broken():
|
||||
"""Blocks that raise on init are skipped without crashing."""
|
||||
blocks = {
|
||||
"good": _make_block_class(name="Good"),
|
||||
"bad": _make_block_class(raise_on_init=RuntimeError("boom")),
|
||||
}
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
result = _get_enabled_blocks()
|
||||
assert list(result.keys()) == ["good"]
|
||||
|
||||
|
||||
def test_get_enabled_blocks_cached():
|
||||
"""_get_enabled_blocks() calls get_blocks() only once across multiple calls."""
|
||||
blocks = {"b1": _make_block_class(name="B1")}
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
) as mock_get_blocks:
|
||||
result1 = _get_enabled_blocks()
|
||||
result2 = _get_enabled_blocks()
|
||||
assert result1 is result2
|
||||
mock_get_blocks.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# StoreAgentHandler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_agent_handler_get_missing_items(mocker):
|
||||
"""Test StoreAgentHandler fetches approved agents without embeddings."""
|
||||
handler = StoreAgentHandler()
|
||||
|
||||
# Mock database query
|
||||
mock_missing = [
|
||||
{
|
||||
"id": "agent-1",
|
||||
@@ -54,9 +140,7 @@ async def test_store_agent_handler_get_stats(mocker):
|
||||
"""Test StoreAgentHandler returns correct stats."""
|
||||
handler = StoreAgentHandler()
|
||||
|
||||
# Mock approved count query
|
||||
mock_approved = [{"count": 50}]
|
||||
# Mock embedded count query
|
||||
mock_embedded = [{"count": 30}]
|
||||
|
||||
with patch(
|
||||
@@ -70,74 +154,130 @@ async def test_store_agent_handler_get_stats(mocker):
|
||||
assert stats["without_embeddings"] == 20
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BlockHandler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_missing_items(mocker):
|
||||
async def test_block_handler_get_missing_items():
|
||||
"""Test BlockHandler discovers blocks without embeddings."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock get_blocks to return test blocks
|
||||
mock_block_class = MagicMock()
|
||||
mock_block_instance = MagicMock()
|
||||
mock_block_instance.name = "Calculator Block"
|
||||
mock_block_instance.description = "Performs calculations"
|
||||
mock_block_instance.categories = [MagicMock(value="MATH")]
|
||||
mock_block_instance.disabled = False
|
||||
mock_field = MagicMock()
|
||||
mock_field.description = "Math expression to evaluate"
|
||||
mock_block_instance.input_schema.model_fields = {"expression": mock_field}
|
||||
mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
|
||||
mock_block_class.return_value = mock_block_instance
|
||||
|
||||
mock_blocks = {"block-uuid-1": mock_block_class}
|
||||
|
||||
# Mock existing embeddings query (no embeddings exist)
|
||||
mock_existing = []
|
||||
blocks = {
|
||||
"block-uuid-1": _make_block_class(
|
||||
name="CalculatorBlock",
|
||||
description="Performs calculations",
|
||||
categories=[MagicMock(value="MATH")],
|
||||
fields={"expression": "Math expression to evaluate"},
|
||||
),
|
||||
}
|
||||
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=mock_existing,
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].content_id == "block-uuid-1"
|
||||
assert items[0].content_type == ContentType.BLOCK
|
||||
# CamelCase should be split in searchable text and metadata name
|
||||
assert "Calculator Block" in items[0].searchable_text
|
||||
assert "Performs calculations" in items[0].searchable_text
|
||||
assert "MATH" in items[0].searchable_text
|
||||
assert "expression: Math expression" in items[0].searchable_text
|
||||
assert items[0].metadata["name"] == "Calculator Block"
|
||||
assert items[0].user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_stats(mocker):
|
||||
async def test_block_handler_get_missing_items_splits_camelcase():
|
||||
"""CamelCase block names are split for better search indexing."""
|
||||
handler = BlockHandler()
|
||||
|
||||
blocks = {
|
||||
"ai-block": _make_block_class(name="AITextGeneratorBlock"),
|
||||
}
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert "AI Text Generator Block" in items[0].searchable_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_missing_items_batch_size_zero():
|
||||
"""batch_size=0 returns an empty list; the DB is still queried to find missing IDs."""
|
||||
handler = BlockHandler()
|
||||
|
||||
blocks = {"b1": _make_block_class(name="B1")}
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
) as mock_query:
|
||||
items = await handler.get_missing_items(batch_size=0)
|
||||
assert items == []
|
||||
# DB query is still issued to learn which blocks lack embeddings;
|
||||
# the empty result comes from itertools.islice limiting to 0 items.
|
||||
mock_query.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_disabled_dont_exhaust_batch():
|
||||
"""Disabled blocks don't consume batch budget, so enabled blocks get indexed."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# 5 disabled + 3 enabled, batch_size=2
|
||||
blocks = {
|
||||
**{
|
||||
f"dis-{i}": _make_block_class(name=f"D{i}", disabled=True) for i in range(5)
|
||||
},
|
||||
**{f"en-{i}": _make_block_class(name=f"E{i}") for i in range(3)},
|
||||
}
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=2)
|
||||
|
||||
assert len(items) == 2
|
||||
assert all(item.content_id.startswith("en-") for item in items)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_stats():
|
||||
"""Test BlockHandler returns correct stats."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock get_blocks - each block class returns an instance with disabled=False
|
||||
def make_mock_block_class():
|
||||
mock_class = MagicMock()
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.disabled = False
|
||||
mock_class.return_value = mock_instance
|
||||
return mock_class
|
||||
|
||||
mock_blocks = {
|
||||
"block-1": make_mock_block_class(),
|
||||
"block-2": make_mock_block_class(),
|
||||
"block-3": make_mock_block_class(),
|
||||
blocks = {
|
||||
"block-1": _make_block_class(name="B1"),
|
||||
"block-2": _make_block_class(name="B2"),
|
||||
"block-3": _make_block_class(name="B3"),
|
||||
}
|
||||
|
||||
# Mock embedded count query (2 blocks have embeddings)
|
||||
mock_embedded = [{"count": 2}]
|
||||
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
@@ -150,21 +290,123 @@ async def test_block_handler_get_stats(mocker):
|
||||
assert stats["without_embeddings"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_stats_skips_broken():
|
||||
"""get_stats skips broken blocks instead of crashing."""
|
||||
handler = BlockHandler()
|
||||
|
||||
blocks = {
|
||||
"good": _make_block_class(name="Good"),
|
||||
"bad": _make_block_class(raise_on_init=RuntimeError("boom")),
|
||||
}
|
||||
|
||||
mock_embedded = [{"count": 1}]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=mock_embedded,
|
||||
):
|
||||
stats = await handler.get_stats()
|
||||
|
||||
assert stats["total"] == 1 # only the good block
|
||||
assert stats["with_embeddings"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_handles_none_name():
|
||||
"""When block.name is None the fallback display name logic is used."""
|
||||
handler = BlockHandler()
|
||||
|
||||
blocks = {
|
||||
"none-name-block": _make_block_class(
|
||||
name="placeholder", # will be overridden to None below
|
||||
description="A block with no name",
|
||||
),
|
||||
}
|
||||
# Override the name to None after construction so _make_block_class
|
||||
# doesn't interfere with the mock wiring.
|
||||
blocks["none-name-block"].return_value.name = None
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
# display_name should be "" because block.name is None
|
||||
# searchable_text should still contain the description
|
||||
assert "A block with no name" in items[0].searchable_text
|
||||
# metadata["name"] falls back to block_id when both display_name
|
||||
# and block.name are falsy, ensuring it is always a non-empty string.
|
||||
assert items[0].metadata["name"] == "none-name-block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_handles_empty_attributes():
|
||||
"""Test BlockHandler handles blocks with empty/falsy attribute values."""
|
||||
handler = BlockHandler()
|
||||
|
||||
blocks = {"block-minimal": _make_block_class(name="Minimal Block")}
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].searchable_text == "Minimal Block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_skips_failed_blocks():
|
||||
"""Test BlockHandler skips blocks that fail to instantiate."""
|
||||
handler = BlockHandler()
|
||||
|
||||
blocks = {
|
||||
"good-block": _make_block_class(name="Good Block", description="Works fine"),
|
||||
"bad-block": _make_block_class(raise_on_init=Exception("Instantiation failed")),
|
||||
}
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].content_id == "good-block"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DocumentationHandler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_get_missing_items(tmp_path, mocker):
|
||||
"""Test DocumentationHandler discovers docs without embeddings."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Create temporary docs directory with test files
|
||||
docs_root = tmp_path / "docs"
|
||||
docs_root.mkdir()
|
||||
|
||||
(docs_root / "guide.md").write_text("# Getting Started\n\nThis is a guide.")
|
||||
(docs_root / "api.mdx").write_text("# API Reference\n\nAPI documentation.")
|
||||
|
||||
# Mock _get_docs_root to return temp dir
|
||||
with patch.object(handler, "_get_docs_root", return_value=docs_root):
|
||||
# Mock existing embeddings query (no embeddings exist)
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
@@ -173,7 +415,6 @@ async def test_documentation_handler_get_missing_items(tmp_path, mocker):
|
||||
|
||||
assert len(items) == 2
|
||||
|
||||
# Check guide.md (content_id format: doc_path::section_index)
|
||||
guide_item = next(
|
||||
(item for item in items if item.content_id == "guide.md::0"), None
|
||||
)
|
||||
@@ -184,7 +425,6 @@ async def test_documentation_handler_get_missing_items(tmp_path, mocker):
|
||||
assert guide_item.metadata["doc_title"] == "Getting Started"
|
||||
assert guide_item.user_id is None
|
||||
|
||||
# Check api.mdx (content_id format: doc_path::section_index)
|
||||
api_item = next(
|
||||
(item for item in items if item.content_id == "api.mdx::0"), None
|
||||
)
|
||||
@@ -197,14 +437,12 @@ async def test_documentation_handler_get_stats(tmp_path, mocker):
|
||||
"""Test DocumentationHandler returns correct stats."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Create temporary docs directory
|
||||
docs_root = tmp_path / "docs"
|
||||
docs_root.mkdir()
|
||||
(docs_root / "doc1.md").write_text("# Doc 1")
|
||||
(docs_root / "doc2.md").write_text("# Doc 2")
|
||||
(docs_root / "doc3.mdx").write_text("# Doc 3")
|
||||
|
||||
# Mock embedded count query (1 doc has embedding)
|
||||
mock_embedded = [{"count": 1}]
|
||||
|
||||
with patch.object(handler, "_get_docs_root", return_value=docs_root):
|
||||
@@ -224,13 +462,11 @@ async def test_documentation_handler_title_extraction(tmp_path):
|
||||
"""Test DocumentationHandler extracts title from markdown heading."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Test with heading
|
||||
doc_with_heading = tmp_path / "with_heading.md"
|
||||
doc_with_heading.write_text("# My Title\n\nContent here")
|
||||
title = handler._extract_doc_title(doc_with_heading)
|
||||
assert title == "My Title"
|
||||
|
||||
# Test without heading
|
||||
doc_without_heading = tmp_path / "no-heading.md"
|
||||
doc_without_heading.write_text("Just content, no heading")
|
||||
title = handler._extract_doc_title(doc_without_heading)
|
||||
@@ -242,7 +478,6 @@ async def test_documentation_handler_markdown_chunking(tmp_path):
|
||||
"""Test DocumentationHandler chunks markdown by headings."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Test document with multiple sections
|
||||
doc_with_sections = tmp_path / "sections.md"
|
||||
doc_with_sections.write_text(
|
||||
"# Document Title\n\n"
|
||||
@@ -254,7 +489,6 @@ async def test_documentation_handler_markdown_chunking(tmp_path):
|
||||
)
|
||||
sections = handler._chunk_markdown_by_headings(doc_with_sections)
|
||||
|
||||
# Should have 3 sections: intro (with doc title), section one, section two
|
||||
assert len(sections) == 3
|
||||
assert sections[0].title == "Document Title"
|
||||
assert sections[0].index == 0
|
||||
@@ -268,7 +502,6 @@ async def test_documentation_handler_markdown_chunking(tmp_path):
|
||||
assert sections[2].index == 2
|
||||
assert "Content for section two" in sections[2].content
|
||||
|
||||
# Test document without headings
|
||||
doc_no_sections = tmp_path / "no-sections.md"
|
||||
doc_no_sections.write_text("Just plain content without any headings.")
|
||||
sections = handler._chunk_markdown_by_headings(doc_no_sections)
|
||||
@@ -282,21 +515,39 @@ async def test_documentation_handler_section_content_ids():
|
||||
"""Test DocumentationHandler creates and parses section content IDs."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Test making content ID
|
||||
content_id = handler._make_section_content_id("docs/guide.md", 2)
|
||||
assert content_id == "docs/guide.md::2"
|
||||
|
||||
# Test parsing content ID
|
||||
doc_path, section_index = handler._parse_section_content_id("docs/guide.md::2")
|
||||
assert doc_path == "docs/guide.md"
|
||||
assert section_index == 2
|
||||
|
||||
# Test parsing legacy format (no section index)
|
||||
doc_path, section_index = handler._parse_section_content_id("docs/old-format.md")
|
||||
assert doc_path == "docs/old-format.md"
|
||||
assert section_index == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_missing_docs_directory():
|
||||
"""Test DocumentationHandler handles missing docs directory gracefully."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
fake_path = Path("/nonexistent/docs")
|
||||
with patch.object(handler, "_get_docs_root", return_value=fake_path):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
assert items == []
|
||||
|
||||
stats = await handler.get_stats()
|
||||
assert stats["total"] == 0
|
||||
assert stats["with_embeddings"] == 0
|
||||
assert stats["without_embeddings"] == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_content_handlers_registry():
|
||||
"""Test all content types are registered."""
|
||||
@@ -307,88 +558,3 @@ async def test_content_handlers_registry():
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.STORE_AGENT], StoreAgentHandler)
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.BLOCK], BlockHandler)
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.DOCUMENTATION], DocumentationHandler)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_handles_empty_attributes():
|
||||
"""Test BlockHandler handles blocks with empty/falsy attribute values."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock block with empty values (all attributes exist but are falsy)
|
||||
mock_block_class = MagicMock()
|
||||
mock_block_instance = MagicMock()
|
||||
mock_block_instance.name = "Minimal Block"
|
||||
mock_block_instance.disabled = False
|
||||
mock_block_instance.description = ""
|
||||
mock_block_instance.categories = set()
|
||||
mock_block_instance.input_schema.model_fields = {}
|
||||
mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
|
||||
mock_block_class.return_value = mock_block_instance
|
||||
|
||||
mock_blocks = {"block-minimal": mock_block_class}
|
||||
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].searchable_text == "Minimal Block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_skips_failed_blocks():
|
||||
"""Test BlockHandler skips blocks that fail to instantiate."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock one good block and one bad block
|
||||
good_block = MagicMock()
|
||||
good_instance = MagicMock()
|
||||
good_instance.name = "Good Block"
|
||||
good_instance.description = "Works fine"
|
||||
good_instance.categories = []
|
||||
good_instance.disabled = False
|
||||
good_instance.input_schema.model_fields = {}
|
||||
good_instance.input_schema.get_credentials_fields_info.return_value = {}
|
||||
good_block.return_value = good_instance
|
||||
|
||||
bad_block = MagicMock()
|
||||
bad_block.side_effect = Exception("Instantiation failed")
|
||||
|
||||
mock_blocks = {"good-block": good_block, "bad-block": bad_block}
|
||||
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
# Should only get the good block
|
||||
assert len(items) == 1
|
||||
assert items[0].content_id == "good-block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_missing_docs_directory():
|
||||
"""Test DocumentationHandler handles missing docs directory gracefully."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Mock _get_docs_root to return non-existent path
|
||||
fake_path = Path("/nonexistent/docs")
|
||||
with patch.object(handler, "_get_docs_root", return_value=fake_path):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
assert items == []
|
||||
|
||||
stats = await handler.get_stats()
|
||||
assert stats["total"] == 0
|
||||
assert stats["with_embeddings"] == 0
|
||||
assert stats["without_embeddings"] == 0
|
||||
|
||||
@@ -15,6 +15,7 @@ from prisma.enums import ContentType
|
||||
from tiktoken import encoding_for_model
|
||||
|
||||
from backend.api.features.store.content_handlers import CONTENT_HANDLERS
|
||||
from backend.blocks import get_blocks
|
||||
from backend.data.db import execute_raw_with_schema, query_raw_with_schema
|
||||
from backend.util.clients import get_openai_client
|
||||
from backend.util.json import dumps
|
||||
@@ -662,8 +663,6 @@ async def cleanup_orphaned_embeddings() -> dict[str, Any]:
|
||||
)
|
||||
current_ids = {row["id"] for row in valid_agents}
|
||||
elif content_type == ContentType.BLOCK:
|
||||
from backend.blocks import get_blocks
|
||||
|
||||
current_ids = set(get_blocks().keys())
|
||||
elif content_type == ContentType.DOCUMENTATION:
|
||||
# Use DocumentationHandler to get section-based content IDs
|
||||
|
||||
@@ -31,12 +31,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def tokenize(text: str) -> list[str]:
|
||||
"""Simple tokenizer for BM25 - lowercase and split on non-alphanumeric."""
|
||||
"""Tokenize text for BM25."""
|
||||
if not text:
|
||||
return []
|
||||
# Lowercase and split on non-alphanumeric characters
|
||||
tokens = re.findall(r"\b\w+\b", text.lower())
|
||||
return tokens
|
||||
return re.findall(r"\b\w+\b", text.lower())
|
||||
|
||||
|
||||
def bm25_rerank(
|
||||
|
||||
@@ -14,9 +14,27 @@ from backend.api.features.store.hybrid_search import (
|
||||
HybridSearchWeights,
|
||||
UnifiedSearchWeights,
|
||||
hybrid_search,
|
||||
tokenize,
|
||||
unified_hybrid_search,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# tokenize (BM25)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_text, expected",
|
||||
[
|
||||
("AITextGeneratorBlock", ["aitextgeneratorblock"]),
|
||||
("hello world", ["hello", "world"]),
|
||||
("", []),
|
||||
("HTTPRequest", ["httprequest"]),
|
||||
],
|
||||
)
|
||||
def test_tokenize(input_text: str, expected: list[str]):
|
||||
assert tokenize(input_text) == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
"""Backward-compatibility shim — ``split_camelcase`` now lives in backend.util.text."""
|
||||
|
||||
from backend.util.text import split_camelcase # noqa: F401
|
||||
|
||||
__all__ = ["split_camelcase"]
|
||||
@@ -0,0 +1,49 @@
|
||||
"""Tests for split_camelcase (now in backend.util.text)."""
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util.text import split_camelcase
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# split_camelcase
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_text, expected",
|
||||
[
|
||||
("AITextGeneratorBlock", "AI Text Generator Block"),
|
||||
("HTTPRequestBlock", "HTTP Request Block"),
|
||||
("simpleWord", "simple Word"),
|
||||
("already spaced", "already spaced"),
|
||||
("XMLParser", "XML Parser"),
|
||||
("getHTTPResponse", "get HTTP Response"),
|
||||
("Block", "Block"),
|
||||
("", ""),
|
||||
("OAuth2Block", "OAuth2 Block"),
|
||||
("IOError", "IO Error"),
|
||||
("getHTTPSResponse", "get HTTPS Response"),
|
||||
# Known limitation: single-letter uppercase prefixes are NOT split.
|
||||
# "ABlock" stays "ABlock" because the algorithm requires the left
|
||||
# part of an uppercase run to retain at least 2 uppercase chars.
|
||||
("ABlock", "ABlock"),
|
||||
# Digit-to-uppercase transitions
|
||||
("Base64Encoder", "Base64 Encoder"),
|
||||
("UTF8Decoder", "UTF8 Decoder"),
|
||||
# Pure digits — no camelCase boundaries to split
|
||||
("123", "123"),
|
||||
# Known limitation: single-letter uppercase segments after digits
|
||||
# are not split from the following word. "3D" is only 1 uppercase
|
||||
# char so the uppercase-run rule cannot fire, producing "3 DRenderer"
|
||||
# rather than the ideal "3D Renderer".
|
||||
("3DRenderer", "3 DRenderer"),
|
||||
# Exception list — compound terms that should stay together
|
||||
("YouTubeBlock", "YouTube Block"),
|
||||
("OpenAIBlock", "OpenAI Block"),
|
||||
("AutoGPTAgent", "AutoGPT Agent"),
|
||||
("GitHubIntegration", "GitHub Integration"),
|
||||
("LinkedInBlock", "LinkedIn Block"),
|
||||
],
|
||||
)
|
||||
def test_split_camelcase(input_text: str, expected: str):
|
||||
assert split_camelcase(input_text) == expected
|
||||
@@ -1,162 +0,0 @@
|
||||
"""Integration credential lookup with per-process TTL cache.
|
||||
|
||||
Provides token retrieval for connected integrations so that copilot tools
|
||||
(e.g. bash_exec) can inject auth tokens into the execution environment without
|
||||
hitting the database on every command.
|
||||
|
||||
Cache semantics (handled automatically by TTLCache):
|
||||
- Token found → cached for _TOKEN_CACHE_TTL (5 min). Avoids repeated DB hits
|
||||
for users who have credentials and are running many bash commands.
|
||||
- No credentials found → cached for _NULL_CACHE_TTL (60 s). Avoids a DB hit
|
||||
on every E2B command for users who haven't connected an account yet, while
|
||||
still picking up a newly-connected account within one minute.
|
||||
|
||||
Both caches are bounded to _CACHE_MAX_SIZE entries; cachetools evicts the
|
||||
least-recently-used entry when the limit is reached.
|
||||
|
||||
Multi-worker note: both caches are in-process only. Each worker/replica
|
||||
maintains its own independent cache, so a credential fetch may be duplicated
|
||||
across processes. This is acceptable for the current goal (reduce DB hits per
|
||||
session per-process), but if cache efficiency across replicas becomes important
|
||||
a shared cache (e.g. Redis) should be used instead.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from cachetools import TTLCache
|
||||
|
||||
from backend.data.model import APIKeyCredentials, OAuth2Credentials
|
||||
from backend.integrations.creds_manager import (
|
||||
IntegrationCredentialsManager,
|
||||
register_creds_changed_hook,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maps provider slug → env var names to inject when the provider is connected.
|
||||
# Add new providers here when adding integration support.
|
||||
# NOTE: keep in sync with connect_integration._PROVIDER_INFO — both registries
|
||||
# must be updated when adding a new provider.
|
||||
PROVIDER_ENV_VARS: dict[str, list[str]] = {
|
||||
"github": ["GH_TOKEN", "GITHUB_TOKEN"],
|
||||
}
|
||||
|
||||
_TOKEN_CACHE_TTL = 300.0 # seconds — for found tokens
|
||||
_NULL_CACHE_TTL = 60.0 # seconds — for "not connected" results
|
||||
_CACHE_MAX_SIZE = 10_000
|
||||
|
||||
# (user_id, provider) → token string. TTLCache handles expiry + eviction.
|
||||
# Thread-safety note: TTLCache is NOT thread-safe, but that is acceptable here
|
||||
# because all callers (get_provider_token, invalidate_user_provider_cache) run
|
||||
# exclusively on the asyncio event loop. There are no await points between a
|
||||
# cache read and its corresponding write within any function, so no concurrent
|
||||
# coroutine can interleave. If ThreadPoolExecutor workers are ever added to
|
||||
# this path, a threading.RLock should be wrapped around these caches.
|
||||
_token_cache: TTLCache[tuple[str, str], str] = TTLCache(
|
||||
maxsize=_CACHE_MAX_SIZE, ttl=_TOKEN_CACHE_TTL
|
||||
)
|
||||
# Separate cache for "no credentials" results with a shorter TTL.
|
||||
_null_cache: TTLCache[tuple[str, str], bool] = TTLCache(
|
||||
maxsize=_CACHE_MAX_SIZE, ttl=_NULL_CACHE_TTL
|
||||
)
|
||||
|
||||
|
||||
def invalidate_user_provider_cache(user_id: str, provider: str) -> None:
|
||||
"""Remove the cached entry for *user_id*/*provider* from both caches.
|
||||
|
||||
Call this after storing new credentials so that the next
|
||||
``get_provider_token()`` call performs a fresh DB lookup instead of
|
||||
serving a stale TTL-cached result.
|
||||
"""
|
||||
key = (user_id, provider)
|
||||
_token_cache.pop(key, None)
|
||||
_null_cache.pop(key, None)
|
||||
|
||||
|
||||
# Register this module's cache-bust function with the credentials manager so
|
||||
# that any create/update/delete operation immediately evicts stale cache
|
||||
# entries. This avoids a lazy import inside creds_manager and eliminates the
|
||||
# circular-import risk.
|
||||
register_creds_changed_hook(invalidate_user_provider_cache)
|
||||
|
||||
# Module-level singleton to avoid re-instantiating IntegrationCredentialsManager
|
||||
# on every cache-miss call to get_provider_token().
|
||||
_manager = IntegrationCredentialsManager()
|
||||
|
||||
|
||||
async def get_provider_token(user_id: str, provider: str) -> str | None:
|
||||
"""Return the user's access token for *provider*, or ``None`` if not connected.
|
||||
|
||||
OAuth2 tokens are preferred (refreshed if needed); API keys are the fallback.
|
||||
Found tokens are cached for _TOKEN_CACHE_TTL (5 min). "Not connected" results
|
||||
are cached for _NULL_CACHE_TTL (60 s) to avoid a DB hit on every bash_exec
|
||||
command for users who haven't connected yet, while still picking up a
|
||||
newly-connected account within one minute.
|
||||
"""
|
||||
cache_key = (user_id, provider)
|
||||
|
||||
if cache_key in _null_cache:
|
||||
return None
|
||||
if cached := _token_cache.get(cache_key):
|
||||
return cached
|
||||
|
||||
manager = _manager
|
||||
try:
|
||||
creds_list = await manager.store.get_creds_by_provider(user_id, provider)
|
||||
except Exception:
|
||||
logger.debug("Failed to fetch %s credentials for user %s", provider, user_id)
|
||||
return None
|
||||
|
||||
# Pass 1: prefer OAuth2 (carry scope info, refreshable via token endpoint).
|
||||
# Sort so broader-scoped tokens come first: a token with "repo" scope covers
|
||||
# full git access, while a public-data-only token lacks push/pull permission.
|
||||
# lock=False — background injection; not worth a distributed lock acquisition.
|
||||
oauth2_creds = sorted(
|
||||
[c for c in creds_list if c.type == "oauth2"],
|
||||
key=lambda c: 0 if "repo" in (cast(OAuth2Credentials, c).scopes or []) else 1,
|
||||
)
|
||||
for creds in oauth2_creds:
|
||||
if creds.type == "oauth2":
|
||||
try:
|
||||
fresh = await manager.refresh_if_needed(
|
||||
user_id, cast(OAuth2Credentials, creds), lock=False
|
||||
)
|
||||
token = fresh.access_token.get_secret_value()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to refresh %s OAuth token for user %s; "
|
||||
"falling back to potentially stale token",
|
||||
provider,
|
||||
user_id,
|
||||
)
|
||||
token = cast(OAuth2Credentials, creds).access_token.get_secret_value()
|
||||
_token_cache[cache_key] = token
|
||||
return token
|
||||
|
||||
# Pass 2: fall back to API key (no expiry, no refresh needed).
|
||||
for creds in creds_list:
|
||||
if creds.type == "api_key":
|
||||
token = cast(APIKeyCredentials, creds).api_key.get_secret_value()
|
||||
_token_cache[cache_key] = token
|
||||
return token
|
||||
|
||||
# No credentials found — cache to avoid repeated DB hits.
|
||||
_null_cache[cache_key] = True
|
||||
return None
|
||||
|
||||
|
||||
async def get_integration_env_vars(user_id: str) -> dict[str, str]:
|
||||
"""Return env vars for all providers the user has connected.
|
||||
|
||||
Iterates :data:`PROVIDER_ENV_VARS`, fetches each token, and builds a flat
|
||||
``{env_var: token}`` dict ready to pass to a subprocess or E2B sandbox.
|
||||
Only providers with a stored credential contribute entries.
|
||||
"""
|
||||
env: dict[str, str] = {}
|
||||
for provider, var_names in PROVIDER_ENV_VARS.items():
|
||||
token = await get_provider_token(user_id, provider)
|
||||
if token:
|
||||
for var in var_names:
|
||||
env[var] = token
|
||||
return env
|
||||
@@ -1,193 +0,0 @@
|
||||
"""Tests for integration_creds — TTL cache and token lookup paths."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.copilot.integration_creds import (
|
||||
_NULL_CACHE_TTL,
|
||||
_TOKEN_CACHE_TTL,
|
||||
PROVIDER_ENV_VARS,
|
||||
_null_cache,
|
||||
_token_cache,
|
||||
get_integration_env_vars,
|
||||
get_provider_token,
|
||||
invalidate_user_provider_cache,
|
||||
)
|
||||
from backend.data.model import APIKeyCredentials, OAuth2Credentials
|
||||
|
||||
_USER = "user-integration-creds-test"
|
||||
_PROVIDER = "github"
|
||||
|
||||
|
||||
def _make_api_key_creds(key: str = "test-api-key") -> APIKeyCredentials:
|
||||
return APIKeyCredentials(
|
||||
id="creds-api-key",
|
||||
provider=_PROVIDER,
|
||||
api_key=SecretStr(key),
|
||||
title="Test API Key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
|
||||
def _make_oauth2_creds(token: str = "test-oauth-token") -> OAuth2Credentials:
|
||||
return OAuth2Credentials(
|
||||
id="creds-oauth2",
|
||||
provider=_PROVIDER,
|
||||
title="Test OAuth",
|
||||
access_token=SecretStr(token),
|
||||
refresh_token=SecretStr("test-refresh"),
|
||||
access_token_expires_at=None,
|
||||
refresh_token_expires_at=None,
|
||||
scopes=[],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_caches():
|
||||
"""Ensure clean caches before and after every test."""
|
||||
_token_cache.clear()
|
||||
_null_cache.clear()
|
||||
yield
|
||||
_token_cache.clear()
|
||||
_null_cache.clear()
|
||||
|
||||
|
||||
class TestInvalidateUserProviderCache:
|
||||
def test_removes_token_entry(self):
|
||||
key = (_USER, _PROVIDER)
|
||||
_token_cache[key] = "tok"
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert key not in _token_cache
|
||||
|
||||
def test_removes_null_entry(self):
|
||||
key = (_USER, _PROVIDER)
|
||||
_null_cache[key] = True
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert key not in _null_cache
|
||||
|
||||
def test_noop_when_key_not_cached(self):
|
||||
# Should not raise even when there is no cache entry.
|
||||
invalidate_user_provider_cache("no-such-user", _PROVIDER)
|
||||
|
||||
def test_only_removes_targeted_key(self):
|
||||
other_key = ("other-user", _PROVIDER)
|
||||
_token_cache[other_key] = "other-tok"
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert other_key in _token_cache
|
||||
|
||||
|
||||
class TestGetProviderToken:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_returns_cached_token_without_db_hit(self):
|
||||
_token_cache[(_USER, _PROVIDER)] = "cached-tok"
|
||||
|
||||
mock_manager = MagicMock()
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "cached-tok"
|
||||
mock_manager.store.get_creds_by_provider.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_returns_none_for_null_cached_provider(self):
|
||||
_null_cache[(_USER, _PROVIDER)] = True
|
||||
|
||||
mock_manager = MagicMock()
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
mock_manager.store.get_creds_by_provider.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_api_key_creds_returned_and_cached(self):
|
||||
api_creds = _make_api_key_creds("my-api-key")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[api_creds])
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "my-api-key"
|
||||
assert _token_cache.get((_USER, _PROVIDER)) == "my-api-key"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth2_preferred_over_api_key(self):
|
||||
oauth_creds = _make_oauth2_creds("oauth-tok")
|
||||
api_creds = _make_api_key_creds("api-tok")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(
|
||||
return_value=[api_creds, oauth_creds]
|
||||
)
|
||||
mock_manager.refresh_if_needed = AsyncMock(return_value=oauth_creds)
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "oauth-tok"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth2_refresh_failure_falls_back_to_stale_token(self):
|
||||
oauth_creds = _make_oauth2_creds("stale-oauth-tok")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[oauth_creds])
|
||||
mock_manager.refresh_if_needed = AsyncMock(side_effect=RuntimeError("network"))
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "stale-oauth-tok"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_no_credentials_caches_null_entry(self):
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
assert _null_cache.get((_USER, _PROVIDER)) is True
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_db_exception_returns_none_without_caching(self):
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(
|
||||
side_effect=RuntimeError("db down")
|
||||
)
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
# DB errors are not cached — next call will retry
|
||||
assert (_USER, _PROVIDER) not in _token_cache
|
||||
assert (_USER, _PROVIDER) not in _null_cache
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_null_cache_has_shorter_ttl_than_token_cache(self):
|
||||
"""Verify the TTL constants are set correctly for each cache."""
|
||||
assert _null_cache.ttl == _NULL_CACHE_TTL
|
||||
assert _token_cache.ttl == _TOKEN_CACHE_TTL
|
||||
assert _NULL_CACHE_TTL < _TOKEN_CACHE_TTL
|
||||
|
||||
|
||||
class TestGetIntegrationEnvVars:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_injects_all_env_vars_for_provider(self):
|
||||
_token_cache[(_USER, "github")] = "gh-tok"
|
||||
|
||||
result = await get_integration_env_vars(_USER)
|
||||
|
||||
for var in PROVIDER_ENV_VARS["github"]:
|
||||
assert result[var] == "gh-tok"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_empty_dict_when_no_credentials(self):
|
||||
_null_cache[(_USER, "github")] = True
|
||||
|
||||
result = await get_integration_env_vars(_USER)
|
||||
|
||||
assert result == {}
|
||||
@@ -95,25 +95,6 @@ Example — committing an image file to GitHub:
|
||||
All tasks must run in the foreground.
|
||||
"""
|
||||
|
||||
# E2B-only notes — E2B has full internet access so gh CLI works there.
|
||||
# Not shown in local (bubblewrap) mode: --unshare-net blocks all network.
|
||||
_E2B_TOOL_NOTES = """
|
||||
### GitHub CLI (`gh`) and git
|
||||
- If the user has connected their GitHub account, both `gh` and `git` are
|
||||
pre-authenticated — use them directly without any manual login step.
|
||||
`git` HTTPS operations (clone, push, pull) work automatically.
|
||||
- If the token changes mid-session (e.g. user reconnects with a new token),
|
||||
run `gh auth setup-git` to re-register the credential helper.
|
||||
- If `gh` or `git` fails with an authentication error (e.g. "authentication
|
||||
required", "could not read Username", or exit code 128), call
|
||||
`connect_integration(provider="github")` to surface the GitHub credentials
|
||||
setup card so the user can connect their account. Once connected, retry
|
||||
the operation.
|
||||
- For operations that need broader access (e.g. private org repos, GitHub
|
||||
Actions), pass the required scopes: e.g.
|
||||
`connect_integration(provider="github", scopes=["repo", "read:org"])`.
|
||||
"""
|
||||
|
||||
|
||||
# Environment-specific supplement templates
|
||||
def _build_storage_supplement(
|
||||
@@ -124,7 +105,6 @@ def _build_storage_supplement(
|
||||
storage_system_1_persistence: list[str],
|
||||
file_move_name_1_to_2: str,
|
||||
file_move_name_2_to_1: str,
|
||||
extra_notes: str = "",
|
||||
) -> str:
|
||||
"""Build storage/filesystem supplement for a specific environment.
|
||||
|
||||
@@ -139,7 +119,6 @@ def _build_storage_supplement(
|
||||
storage_system_1_persistence: List of persistence behavior descriptions
|
||||
file_move_name_1_to_2: Direction label for primary→persistent
|
||||
file_move_name_2_to_1: Direction label for persistent→primary
|
||||
extra_notes: Environment-specific notes appended after shared notes
|
||||
"""
|
||||
# Format lists as bullet points with proper indentation
|
||||
characteristics = "\n".join(f" - {c}" for c in storage_system_1_characteristics)
|
||||
@@ -173,16 +152,12 @@ def _build_storage_supplement(
|
||||
|
||||
### File persistence
|
||||
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
|
||||
{_SHARED_TOOL_NOTES}{extra_notes}"""
|
||||
{_SHARED_TOOL_NOTES}"""
|
||||
|
||||
|
||||
# Pre-built supplements for common environments
|
||||
def _get_local_storage_supplement(cwd: str) -> str:
|
||||
"""Local ephemeral storage (files lost between turns).
|
||||
|
||||
Network is isolated (bubblewrap --unshare-net), so internet-dependent CLIs
|
||||
like gh will not work — no integration env-var notes are included.
|
||||
"""
|
||||
"""Local ephemeral storage (files lost between turns)."""
|
||||
return _build_storage_supplement(
|
||||
working_dir=cwd,
|
||||
sandbox_type="in a network-isolated sandbox",
|
||||
@@ -200,11 +175,7 @@ def _get_local_storage_supplement(cwd: str) -> str:
|
||||
|
||||
|
||||
def _get_cloud_sandbox_supplement() -> str:
|
||||
"""Cloud persistent sandbox (files survive across turns in session).
|
||||
|
||||
E2B has full internet access, so integration tokens (GH_TOKEN etc.) are
|
||||
injected per command in bash_exec — include the CLI guidance notes.
|
||||
"""
|
||||
"""Cloud persistent sandbox (files survive across turns in session)."""
|
||||
return _build_storage_supplement(
|
||||
working_dir="/home/user",
|
||||
sandbox_type="in a cloud sandbox with full internet access",
|
||||
@@ -219,7 +190,6 @@ def _get_cloud_sandbox_supplement() -> str:
|
||||
],
|
||||
file_move_name_1_to_2="Sandbox → Persistent",
|
||||
file_move_name_2_to_1="Persistent → Sandbox",
|
||||
extra_notes=_E2B_TOOL_NOTES,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -769,7 +769,7 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
return None
|
||||
try:
|
||||
sandbox = await get_or_create_sandbox(
|
||||
return await get_or_create_sandbox(
|
||||
session_id,
|
||||
api_key=e2b_api_key,
|
||||
template=config.e2b_sandbox_template,
|
||||
@@ -783,9 +783,7 @@ async def stream_chat_completion_sdk(
|
||||
e2b_err,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
return sandbox
|
||||
return None
|
||||
|
||||
async def _fetch_transcript():
|
||||
"""Download transcript for --resume if applicable."""
|
||||
|
||||
@@ -12,7 +12,6 @@ from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreensho
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .bash_exec import BashExecTool
|
||||
from .connect_integration import ConnectIntegrationTool
|
||||
from .continue_run_block import ContinueRunBlockTool
|
||||
from .create_agent import CreateAgentTool
|
||||
from .customize_agent import CustomizeAgentTool
|
||||
@@ -85,7 +84,6 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"browser_screenshot": BrowserScreenshotTool(),
|
||||
# Sandboxed code execution (bubblewrap)
|
||||
"bash_exec": BashExecTool(),
|
||||
"connect_integration": ConnectIntegrationTool(),
|
||||
# Persistent workspace tools (cloud storage, survives across sessions)
|
||||
# Feature request tools
|
||||
"search_feature_requests": SearchFeatureRequestsTool(),
|
||||
|
||||
@@ -22,7 +22,6 @@ from e2b import AsyncSandbox
|
||||
from e2b.exceptions import TimeoutException
|
||||
|
||||
from backend.copilot.context import E2B_WORKDIR, get_current_sandbox
|
||||
from backend.copilot.integration_creds import get_integration_env_vars
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
@@ -97,9 +96,7 @@ class BashExecTool(BaseTool):
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
if sandbox is not None:
|
||||
return await self._execute_on_e2b(
|
||||
sandbox, command, timeout, session_id, user_id
|
||||
)
|
||||
return await self._execute_on_e2b(sandbox, command, timeout, session_id)
|
||||
|
||||
# Bubblewrap fallback: local isolated execution.
|
||||
if not has_full_sandbox():
|
||||
@@ -136,27 +133,14 @@ class BashExecTool(BaseTool):
|
||||
command: str,
|
||||
timeout: int,
|
||||
session_id: str | None,
|
||||
user_id: str | None = None,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute *command* on the E2B sandbox via commands.run().
|
||||
|
||||
Integration tokens (e.g. GH_TOKEN) are injected into the sandbox env
|
||||
for any user with connected accounts. E2B has full internet access, so
|
||||
CLI tools like ``gh`` work without manual authentication.
|
||||
"""
|
||||
envs: dict[str, str] = {
|
||||
"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin",
|
||||
}
|
||||
if user_id is not None:
|
||||
integration_env = await get_integration_env_vars(user_id)
|
||||
envs.update(integration_env)
|
||||
|
||||
"""Execute *command* on the E2B sandbox via commands.run()."""
|
||||
try:
|
||||
result = await sandbox.commands.run(
|
||||
f"bash -c {shlex.quote(command)}",
|
||||
cwd=E2B_WORKDIR,
|
||||
timeout=timeout,
|
||||
envs=envs,
|
||||
envs={"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin"},
|
||||
)
|
||||
return BashExecResponse(
|
||||
message=f"Command executed on E2B (exit {result.exit_code})",
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
"""Tests for BashExecTool — E2B path with token injection."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from ._test_data import make_session
|
||||
from .bash_exec import BashExecTool
|
||||
from .models import BashExecResponse
|
||||
|
||||
_USER = "user-bash-exec-test"
|
||||
|
||||
|
||||
def _make_tool() -> BashExecTool:
|
||||
return BashExecTool()
|
||||
|
||||
|
||||
def _make_sandbox(exit_code: int = 0, stdout: str = "", stderr: str = "") -> MagicMock:
|
||||
result = MagicMock()
|
||||
result.exit_code = exit_code
|
||||
result.stdout = stdout
|
||||
result.stderr = stderr
|
||||
|
||||
sandbox = MagicMock()
|
||||
sandbox.commands.run = AsyncMock(return_value=result)
|
||||
return sandbox
|
||||
|
||||
|
||||
class TestBashExecE2BTokenInjection:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_token_injected_when_user_id_set(self):
|
||||
"""When user_id is provided, integration env vars are merged into sandbox envs."""
|
||||
tool = _make_tool()
|
||||
session = make_session(user_id=_USER)
|
||||
sandbox = _make_sandbox(stdout="ok")
|
||||
env_vars = {"GH_TOKEN": "gh-secret", "GITHUB_TOKEN": "gh-secret"}
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.bash_exec.get_integration_env_vars",
|
||||
new=AsyncMock(return_value=env_vars),
|
||||
) as mock_get_env:
|
||||
result = await tool._execute_on_e2b(
|
||||
sandbox=sandbox,
|
||||
command="echo hi",
|
||||
timeout=10,
|
||||
session_id=session.session_id,
|
||||
user_id=_USER,
|
||||
)
|
||||
|
||||
mock_get_env.assert_awaited_once_with(_USER)
|
||||
call_kwargs = sandbox.commands.run.call_args[1]
|
||||
assert call_kwargs["envs"]["GH_TOKEN"] == "gh-secret"
|
||||
assert call_kwargs["envs"]["GITHUB_TOKEN"] == "gh-secret"
|
||||
assert isinstance(result, BashExecResponse)
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_no_token_injection_when_user_id_is_none(self):
|
||||
"""When user_id is None, get_integration_env_vars must NOT be called."""
|
||||
tool = _make_tool()
|
||||
session = make_session(user_id=_USER)
|
||||
sandbox = _make_sandbox(stdout="ok")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.bash_exec.get_integration_env_vars",
|
||||
new=AsyncMock(return_value={"GH_TOKEN": "should-not-appear"}),
|
||||
) as mock_get_env:
|
||||
result = await tool._execute_on_e2b(
|
||||
sandbox=sandbox,
|
||||
command="echo hi",
|
||||
timeout=10,
|
||||
session_id=session.session_id,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
mock_get_env.assert_not_called()
|
||||
call_kwargs = sandbox.commands.run.call_args[1]
|
||||
assert "GH_TOKEN" not in call_kwargs["envs"]
|
||||
assert isinstance(result, BashExecResponse)
|
||||
@@ -1,215 +0,0 @@
|
||||
"""Tool for prompting the user to connect a required integration.
|
||||
|
||||
When the copilot encounters an authentication failure (e.g. `gh` CLI returns
|
||||
"authentication required"), it calls this tool to surface the credentials
|
||||
setup card in the chat — the same UI that appears when a GitHub block runs
|
||||
without configured credentials.
|
||||
"""
|
||||
|
||||
import functools
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.models import (
|
||||
ErrorResponse,
|
||||
ResponseType,
|
||||
SetupInfo,
|
||||
SetupRequirementsResponse,
|
||||
ToolResponseBase,
|
||||
UserReadiness,
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
|
||||
|
||||
class _ProviderInfo(TypedDict):
|
||||
name: str
|
||||
types: list[str]
|
||||
# Default OAuth scopes requested when the agent doesn't specify any.
|
||||
scopes: list[str]
|
||||
|
||||
|
||||
class _CredentialEntry(TypedDict):
|
||||
"""Shape of each entry inside SetupRequirementsResponse.user_readiness.missing_credentials."""
|
||||
|
||||
id: str
|
||||
title: str
|
||||
provider: str
|
||||
provider_name: str
|
||||
type: str
|
||||
types: list[str]
|
||||
scopes: list[str]
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _is_github_oauth_configured() -> bool:
|
||||
"""Return True if GitHub OAuth env vars are set.
|
||||
|
||||
Evaluated lazily (not at import time) to avoid triggering Secrets() during
|
||||
module import, which can fail in environments where secrets are not loaded.
|
||||
"""
|
||||
from backend.blocks.github._auth import GITHUB_OAUTH_IS_CONFIGURED
|
||||
|
||||
return GITHUB_OAUTH_IS_CONFIGURED
|
||||
|
||||
|
||||
# Registry of known providers: name + supported credential types for the UI.
|
||||
# When adding a new provider, also add its env var names to
|
||||
# backend.copilot.integration_creds.PROVIDER_ENV_VARS.
|
||||
def _get_provider_info() -> dict[str, _ProviderInfo]:
|
||||
"""Build the provider registry, evaluating OAuth config lazily."""
|
||||
return {
|
||||
"github": {
|
||||
"name": "GitHub",
|
||||
"types": (
|
||||
["api_key", "oauth2"] if _is_github_oauth_configured() else ["api_key"]
|
||||
),
|
||||
# Default: repo scope covers clone/push/pull for public and private repos.
|
||||
# Agent can request additional scopes (e.g. "read:org") via the scopes param.
|
||||
"scopes": ["repo"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ConnectIntegrationTool(BaseTool):
|
||||
"""Surface the credentials setup UI when an integration is not connected."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "connect_integration"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Prompt the user to connect a required integration (e.g. GitHub). "
|
||||
"Call this when an external CLI or API call fails because the user "
|
||||
"has not connected the relevant account. "
|
||||
"The tool surfaces a credentials setup card in the chat so the user "
|
||||
"can authenticate without leaving the page. "
|
||||
"After the user connects the account, retry the operation. "
|
||||
"In E2B/cloud sandbox mode the token (GH_TOKEN/GITHUB_TOKEN) is "
|
||||
"automatically injected per-command in bash_exec — no manual export needed. "
|
||||
"In local bubblewrap mode network is isolated so GitHub CLI commands "
|
||||
"will still fail after connecting; inform the user of this limitation."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"provider": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Integration provider slug, e.g. 'github'. "
|
||||
"Must be one of the supported providers."
|
||||
),
|
||||
"enum": list(_get_provider_info().keys()),
|
||||
},
|
||||
"reason": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Brief explanation of why the integration is needed, "
|
||||
"shown to the user in the setup card."
|
||||
),
|
||||
"maxLength": 500,
|
||||
},
|
||||
"scopes": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"OAuth scopes to request. Omit to use the provider default. "
|
||||
"Add extra scopes when you need more access — e.g. for GitHub: "
|
||||
"'repo' (clone/push/pull), 'read:org' (org membership), "
|
||||
"'workflow' (GitHub Actions). "
|
||||
"Requesting only the scopes you actually need is best practice."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["provider"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
# Require auth so only authenticated users can trigger the setup card.
|
||||
# The card itself is user-agnostic (no per-user data needed), so
|
||||
# user_id is intentionally unused in _execute.
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
) -> ToolResponseBase:
|
||||
del user_id # setup card is user-agnostic; auth is enforced via requires_auth
|
||||
session_id = session.session_id if session else None
|
||||
provider: str = (kwargs.get("provider") or "").strip().lower()
|
||||
reason: str = (kwargs.get("reason") or "").strip()[
|
||||
:500
|
||||
] # cap LLM-controlled text
|
||||
extra_scopes: list[str] = [
|
||||
str(s).strip() for s in (kwargs.get("scopes") or []) if str(s).strip()
|
||||
]
|
||||
|
||||
provider_info = _get_provider_info()
|
||||
info = provider_info.get(provider)
|
||||
if not info:
|
||||
supported = ", ".join(f"'{p}'" for p in provider_info)
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Unknown provider '{provider}'. "
|
||||
f"Supported providers: {supported}."
|
||||
),
|
||||
error="unknown_provider",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
provider_name: str = info["name"]
|
||||
supported_types: list[str] = info["types"]
|
||||
# Merge agent-requested scopes with provider defaults (deduplicated, order preserved).
|
||||
default_scopes: list[str] = info["scopes"]
|
||||
seen: set[str] = set()
|
||||
scopes: list[str] = []
|
||||
for s in default_scopes + extra_scopes:
|
||||
if s not in seen:
|
||||
seen.add(s)
|
||||
scopes.append(s)
|
||||
field_key = f"{provider}_credentials"
|
||||
|
||||
message_parts = [
|
||||
f"To continue, please connect your {provider_name} account.",
|
||||
]
|
||||
if reason:
|
||||
message_parts.append(reason)
|
||||
|
||||
credential_entry: _CredentialEntry = {
|
||||
"id": field_key,
|
||||
"title": f"{provider_name} Credentials",
|
||||
"provider": provider,
|
||||
"provider_name": provider_name,
|
||||
"type": supported_types[0],
|
||||
"types": supported_types,
|
||||
"scopes": scopes,
|
||||
}
|
||||
missing_credentials: dict[str, _CredentialEntry] = {field_key: credential_entry}
|
||||
|
||||
return SetupRequirementsResponse(
|
||||
type=ResponseType.SETUP_REQUIREMENTS,
|
||||
message=" ".join(message_parts),
|
||||
session_id=session_id,
|
||||
setup_info=SetupInfo(
|
||||
agent_id=f"connect_{provider}",
|
||||
agent_name=provider_name,
|
||||
user_readiness=UserReadiness(
|
||||
has_all_credentials=False,
|
||||
missing_credentials=missing_credentials,
|
||||
ready_to_run=False,
|
||||
),
|
||||
requirements={
|
||||
"credentials": [missing_credentials[field_key]],
|
||||
"inputs": [],
|
||||
"execution_modes": [],
|
||||
},
|
||||
),
|
||||
)
|
||||
@@ -1,135 +0,0 @@
|
||||
"""Tests for ConnectIntegrationTool."""
|
||||
|
||||
import pytest
|
||||
|
||||
from ._test_data import make_session
|
||||
from .connect_integration import ConnectIntegrationTool
|
||||
from .models import ErrorResponse, SetupRequirementsResponse
|
||||
|
||||
_TEST_USER_ID = "test-user-connect-integration"
|
||||
|
||||
|
||||
class TestConnectIntegrationTool:
|
||||
def _make_tool(self) -> ConnectIntegrationTool:
|
||||
return ConnectIntegrationTool()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_unknown_provider_returns_error(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="nonexistent"
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "unknown_provider"
|
||||
assert "nonexistent" in result.message
|
||||
assert "github" in result.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_empty_provider_returns_error(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider=""
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "unknown_provider"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_github_provider_returns_setup_response(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
assert result.setup_info.agent_name == "GitHub"
|
||||
assert result.setup_info.agent_id == "connect_github"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_github_has_missing_credentials_in_readiness(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
readiness = result.setup_info.user_readiness
|
||||
assert readiness.has_all_credentials is False
|
||||
assert readiness.ready_to_run is False
|
||||
assert "github_credentials" in readiness.missing_credentials
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_github_requirements_include_credential_entry(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
creds = result.setup_info.requirements["credentials"]
|
||||
assert len(creds) == 1
|
||||
assert creds[0]["provider"] == "github"
|
||||
assert creds[0]["id"] == "github_credentials"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_reason_appears_in_message(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
reason = "Needed to create a pull request."
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github", reason=reason
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
assert reason in result.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_session_id_propagated(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
assert result.session_id == session.session_id
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_provider_case_insensitive(self):
|
||||
"""Provider slug is normalised to lowercase before lookup."""
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="GitHub"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
|
||||
def test_tool_name(self):
|
||||
assert ConnectIntegrationTool().name == "connect_integration"
|
||||
|
||||
def test_requires_auth(self):
|
||||
assert ConnectIntegrationTool().requires_auth is True
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_unauthenticated_user_gets_need_login_response(self):
|
||||
"""execute() with user_id=None must return NeedLoginResponse, not the setup card.
|
||||
|
||||
This verifies that the requires_auth guard in BaseTool.execute() fires
|
||||
before _execute() is called, so unauthenticated callers cannot probe
|
||||
which integrations are configured.
|
||||
"""
|
||||
import json
|
||||
|
||||
tool = self._make_tool()
|
||||
# Session still needs a user_id string; the None is passed to execute()
|
||||
# to simulate an unauthenticated call.
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool.execute(
|
||||
user_id=None,
|
||||
session=session,
|
||||
tool_call_id="test-call-id",
|
||||
provider="github",
|
||||
)
|
||||
raw = result.output
|
||||
output = json.loads(raw) if isinstance(raw, str) else raw
|
||||
assert output.get("type") == "need_login"
|
||||
assert result.success is False
|
||||
@@ -25,35 +25,6 @@ logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
|
||||
_on_creds_changed: Callable[[str, str], None] | None = None
|
||||
|
||||
|
||||
def register_creds_changed_hook(hook: Callable[[str, str], None]) -> None:
|
||||
"""Register a callback invoked after any credential is created/updated/deleted.
|
||||
|
||||
The callback receives ``(user_id, provider)`` and should be idempotent.
|
||||
Only one hook can be registered at a time; calling this again replaces the
|
||||
previous hook. Intended to be called once at application startup by the
|
||||
copilot module to bust its token cache without creating an import cycle.
|
||||
"""
|
||||
global _on_creds_changed
|
||||
_on_creds_changed = hook
|
||||
|
||||
|
||||
def _bust_copilot_cache(user_id: str, provider: str) -> None:
|
||||
"""Invoke the registered hook (if any) to bust downstream token caches."""
|
||||
if _on_creds_changed is not None:
|
||||
try:
|
||||
_on_creds_changed(user_id, provider)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Credential-change hook failed for user=%s provider=%s",
|
||||
user_id,
|
||||
provider,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
class IntegrationCredentialsManager:
|
||||
"""
|
||||
Handles the lifecycle of integration credentials.
|
||||
@@ -98,11 +69,7 @@ class IntegrationCredentialsManager:
|
||||
return self._locks
|
||||
|
||||
async def create(self, user_id: str, credentials: Credentials) -> None:
|
||||
result = await self.store.add_creds(user_id, credentials)
|
||||
# Bust the copilot token cache so that the next bash_exec picks up the
|
||||
# new credential immediately instead of waiting for _NULL_CACHE_TTL.
|
||||
_bust_copilot_cache(user_id, credentials.provider)
|
||||
return result
|
||||
return await self.store.add_creds(user_id, credentials)
|
||||
|
||||
async def exists(self, user_id: str, credentials_id: str) -> bool:
|
||||
return (await self.store.get_creds_by_id(user_id, credentials_id)) is not None
|
||||
@@ -189,8 +156,6 @@ class IntegrationCredentialsManager:
|
||||
|
||||
fresh_credentials = await oauth_handler.refresh_tokens(credentials)
|
||||
await self.store.update_creds(user_id, fresh_credentials)
|
||||
# Bust copilot cache so the refreshed token is picked up immediately.
|
||||
_bust_copilot_cache(user_id, fresh_credentials.provider)
|
||||
if _lock and (await _lock.locked()) and (await _lock.owned()):
|
||||
try:
|
||||
await _lock.release()
|
||||
@@ -203,17 +168,10 @@ class IntegrationCredentialsManager:
|
||||
async def update(self, user_id: str, updated: Credentials) -> None:
|
||||
async with self._locked(user_id, updated.id):
|
||||
await self.store.update_creds(user_id, updated)
|
||||
# Bust the copilot token cache so the updated credential is picked up immediately.
|
||||
_bust_copilot_cache(user_id, updated.provider)
|
||||
|
||||
async def delete(self, user_id: str, credentials_id: str) -> None:
|
||||
async with self._locked(user_id, credentials_id):
|
||||
# Read inside the lock to avoid TOCTOU — another coroutine could
|
||||
# delete the same credential between the read and the delete.
|
||||
creds = await self.store.get_creds_by_id(user_id, credentials_id)
|
||||
await self.store.delete_creds_by_id(user_id, credentials_id)
|
||||
if creds:
|
||||
_bust_copilot_cache(user_id, creds.provider)
|
||||
|
||||
# -- Locking utilities -- #
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import re
|
||||
|
||||
import bleach
|
||||
from bleach.css_sanitizer import CSSSanitizer
|
||||
@@ -154,3 +155,76 @@ class TextFormatter:
|
||||
)
|
||||
|
||||
return rendered_subject_template, rendered_base_template
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CamelCase splitting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Map of split forms back to their canonical compound terms.
|
||||
# Mirrors the frontend exception list in frontend/src/lib/utils.ts.
|
||||
_CAMELCASE_EXCEPTIONS: dict[str, str] = {
|
||||
"Auto GPT": "AutoGPT",
|
||||
"Open AI": "OpenAI",
|
||||
"You Tube": "YouTube",
|
||||
"Git Hub": "GitHub",
|
||||
"Linked In": "LinkedIn",
|
||||
}
|
||||
|
||||
_CAMELCASE_EXCEPTION_RE = re.compile(
|
||||
"|".join(re.escape(k) for k in _CAMELCASE_EXCEPTIONS),
|
||||
)
|
||||
|
||||
|
||||
def split_camelcase(text: str) -> str:
|
||||
"""Split CamelCase into separate words.
|
||||
|
||||
Uses a single-pass character-by-character algorithm to avoid any
|
||||
regex backtracking concerns (guaranteed O(n) time).
|
||||
|
||||
After splitting, known compound terms are restored via an exception
|
||||
list (e.g. ``"YouTube"`` stays ``"YouTube"`` instead of becoming
|
||||
``"You Tube"``). The list mirrors the frontend mapping in
|
||||
``frontend/src/lib/utils.ts``.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> split_camelcase("AITextGeneratorBlock")
|
||||
'AI Text Generator Block'
|
||||
>>> split_camelcase("OAuth2Block")
|
||||
'OAuth2 Block'
|
||||
>>> split_camelcase("YouTubeBlock")
|
||||
'YouTube Block'
|
||||
"""
|
||||
if len(text) <= 1:
|
||||
return text
|
||||
|
||||
parts: list[str] = []
|
||||
prev = 0
|
||||
for i in range(1, len(text)):
|
||||
# Insert split between lowercase/digit and uppercase: "camelCase" -> "camel|Case"
|
||||
if (text[i - 1].islower() or text[i - 1].isdigit()) and text[i].isupper():
|
||||
parts.append(text[prev:i])
|
||||
prev = i
|
||||
# Insert split between uppercase run (2+ chars) and uppercase+lowercase:
|
||||
# "AIText" -> "AI|Text". Requires at least 3 consecutive uppercase chars
|
||||
# before the lowercase so that the left part keeps 2+ uppercase chars
|
||||
# (mirrors the original regex r"([A-Z]{2,})([A-Z][a-z])").
|
||||
elif (
|
||||
i >= 2
|
||||
and text[i - 2].isupper()
|
||||
and text[i - 1].isupper()
|
||||
and text[i].islower()
|
||||
and (i - 1 - prev) >= 2 # left part must retain at least 2 upper chars
|
||||
):
|
||||
parts.append(text[prev : i - 1])
|
||||
prev = i - 1
|
||||
|
||||
parts.append(text[prev:])
|
||||
result = " ".join(parts)
|
||||
|
||||
# Restore known compound terms that should not be split.
|
||||
result = _CAMELCASE_EXCEPTION_RE.sub(
|
||||
lambda m: _CAMELCASE_EXCEPTIONS[m.group()], result
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -3,7 +3,6 @@ import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
|
||||
import { ExclamationMarkIcon } from "@phosphor-icons/react";
|
||||
import { ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { useState } from "react";
|
||||
import { ConnectIntegrationTool } from "../../../tools/ConnectIntegrationTool/ConnectIntegrationTool";
|
||||
import { CreateAgentTool } from "../../../tools/CreateAgent/CreateAgent";
|
||||
import { EditAgentTool } from "../../../tools/EditAgent/EditAgent";
|
||||
import {
|
||||
@@ -130,8 +129,6 @@ export function MessagePartRenderer({ part, messageID, partIndex }: Props) {
|
||||
case "tool-search_docs":
|
||||
case "tool-get_doc_page":
|
||||
return <SearchDocsTool key={key} part={part as ToolUIPart} />;
|
||||
case "tool-connect_integration":
|
||||
return <ConnectIntegrationTool key={key} part={part as ToolUIPart} />;
|
||||
case "tool-run_block":
|
||||
case "tool-continue_run_block":
|
||||
return <RunBlockTool key={key} part={part as ToolUIPart} />;
|
||||
|
||||
@@ -1,104 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import type { SetupRequirementsResponse } from "@/app/api/__generated__/models/setupRequirementsResponse";
|
||||
import type { ToolUIPart } from "ai";
|
||||
import { useState } from "react";
|
||||
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
|
||||
import { ContentMessage } from "../../components/ToolAccordion/AccordionContent";
|
||||
import { SetupRequirementsCard } from "../RunBlock/components/SetupRequirementsCard/SetupRequirementsCard";
|
||||
|
||||
type Props = {
|
||||
part: ToolUIPart;
|
||||
};
|
||||
|
||||
function parseJson(raw: unknown): unknown {
|
||||
if (typeof raw === "string") {
|
||||
try {
|
||||
return JSON.parse(raw);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return raw;
|
||||
}
|
||||
|
||||
function parseOutput(raw: unknown): SetupRequirementsResponse | null {
|
||||
const parsed = parseJson(raw);
|
||||
if (parsed && typeof parsed === "object" && "setup_info" in parsed) {
|
||||
return parsed as SetupRequirementsResponse;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function parseError(raw: unknown): string | null {
|
||||
const parsed = parseJson(raw);
|
||||
if (parsed && typeof parsed === "object" && "message" in parsed) {
|
||||
return String((parsed as { message: unknown }).message);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
export function ConnectIntegrationTool({ part }: Props) {
|
||||
// Persist dismissed state here so SetupRequirementsCard remounts don't re-enable Proceed.
|
||||
const [isDismissed, setIsDismissed] = useState(false);
|
||||
|
||||
const isStreaming =
|
||||
part.state === "input-streaming" || part.state === "input-available";
|
||||
const isError = part.state === "output-error";
|
||||
|
||||
const output =
|
||||
part.state === "output-available"
|
||||
? parseOutput((part as { output?: unknown }).output)
|
||||
: null;
|
||||
|
||||
const errorMessage = isError
|
||||
? (parseError((part as { output?: unknown }).output) ??
|
||||
"Failed to connect integration")
|
||||
: null;
|
||||
|
||||
const rawProvider =
|
||||
(part as { input?: { provider?: string } }).input?.provider ?? "";
|
||||
const providerName =
|
||||
output?.setup_info?.agent_name ??
|
||||
// Sanitize LLM-controlled provider slug: trim and cap at 64 chars to
|
||||
// prevent runaway text in the DOM.
|
||||
(rawProvider ? rawProvider.trim().slice(0, 64) : "integration");
|
||||
|
||||
const label = isStreaming
|
||||
? `Connecting ${providerName}…`
|
||||
: isError
|
||||
? `Failed to connect ${providerName}`
|
||||
: output
|
||||
? `Connect ${output.setup_info?.agent_name ?? providerName}`
|
||||
: `Connect ${providerName}`;
|
||||
|
||||
return (
|
||||
<div className="py-2">
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<MorphingTextAnimation
|
||||
text={label}
|
||||
className={isError ? "text-red-500" : undefined}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{isError && errorMessage && (
|
||||
<p className="mt-1 text-sm text-red-500">{errorMessage}</p>
|
||||
)}
|
||||
|
||||
{output && (
|
||||
<div className="mt-2">
|
||||
{isDismissed ? (
|
||||
<ContentMessage>Connected. Continuing…</ContentMessage>
|
||||
) : (
|
||||
<SetupRequirementsCard
|
||||
output={output}
|
||||
credentialsLabel={`${output.setup_info?.agent_name ?? providerName} credentials`}
|
||||
retryInstruction="I've connected my account. Please continue."
|
||||
onComplete={() => setIsDismissed(true)}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -23,16 +23,12 @@ interface Props {
|
||||
/** Override the label shown above the credentials section.
|
||||
* Defaults to "Credentials". */
|
||||
credentialsLabel?: string;
|
||||
/** Called after Proceed is clicked so the parent can persist the dismissed state
|
||||
* across remounts (avoids re-enabling the Proceed button on remount). */
|
||||
onComplete?: () => void;
|
||||
}
|
||||
|
||||
export function SetupRequirementsCard({
|
||||
output,
|
||||
retryInstruction,
|
||||
credentialsLabel,
|
||||
onComplete,
|
||||
}: Props) {
|
||||
const { onSend } = useCopilotChatActions();
|
||||
|
||||
@@ -72,17 +68,13 @@ export function SetupRequirementsCard({
|
||||
return v !== undefined && v !== null && v !== "";
|
||||
});
|
||||
|
||||
if (hasSent) {
|
||||
return <ContentMessage>Connected. Continuing…</ContentMessage>;
|
||||
}
|
||||
|
||||
const canRun =
|
||||
!hasSent &&
|
||||
(!needsCredentials || isAllCredentialsComplete) &&
|
||||
(!needsInputs || isAllInputsComplete);
|
||||
|
||||
function handleRun() {
|
||||
setHasSent(true);
|
||||
onComplete?.();
|
||||
|
||||
const parts: string[] = [];
|
||||
if (needsCredentials) {
|
||||
|
||||
@@ -125,9 +125,9 @@ export function useCredentialsInput({
|
||||
if (hasAttemptedAutoSelect.current) return;
|
||||
hasAttemptedAutoSelect.current = true;
|
||||
|
||||
// Auto-select only when there is exactly one saved credential.
|
||||
// With multiple options the user must choose — regardless of optional/required.
|
||||
if (savedCreds.length > 1) return;
|
||||
// Auto-select if exactly one credential matches.
|
||||
// For optional fields with multiple options, let the user choose.
|
||||
if (isOptional && savedCreds.length > 1) return;
|
||||
|
||||
const cred = savedCreds[0];
|
||||
onSelectCredential({
|
||||
|
||||
Reference in New Issue
Block a user