mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into feat/rate-limit-tiering
This commit is contained in:
106
.claude/skills/open-pr/SKILL.md
Normal file
106
.claude/skills/open-pr/SKILL.md
Normal file
@@ -0,0 +1,106 @@
|
||||
---
|
||||
name: open-pr
|
||||
description: Open a pull request with proper PR template, test coverage, and review workflow. Guides agents through creating a PR that follows repo conventions, ensures existing behaviors aren't broken, covers new behaviors with tests, and handles review via bot when local testing isn't possible. TRIGGER when user asks to "open a PR", "create a PR", "make a PR", "submit a PR", "open pull request", "push and create PR", or any variation of opening/submitting a pull request.
|
||||
user-invocable: true
|
||||
args: "[base-branch] — optional target branch (defaults to dev)."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Open a Pull Request
|
||||
|
||||
## Step 1: Pre-flight checks
|
||||
|
||||
Before opening the PR:
|
||||
|
||||
1. Ensure all changes are committed
|
||||
2. Ensure the branch is pushed to the remote (`git push -u origin <branch>`)
|
||||
3. Run linters/formatters across the whole repo (not just changed files) and commit any fixes
|
||||
|
||||
## Step 2: Test coverage
|
||||
|
||||
**This is critical.** Before opening the PR, verify:
|
||||
|
||||
### Existing behavior is not broken
|
||||
- Identify which modules/components your changes touch
|
||||
- Run the existing test suites for those areas
|
||||
- If tests fail, fix them before opening the PR — do not open a PR with known regressions
|
||||
|
||||
### New behavior has test coverage
|
||||
- Every new feature, endpoint, or behavior change needs tests
|
||||
- If you added a new block, add tests for that block
|
||||
- If you changed API behavior, add or update API tests
|
||||
- If you changed frontend behavior, verify it doesn't break existing flows
|
||||
|
||||
If you cannot run the full test suite locally, note which tests you ran and which you couldn't in the test plan.
|
||||
|
||||
## Step 3: Create the PR using the repo template
|
||||
|
||||
Read the canonical PR template at `.github/PULL_REQUEST_TEMPLATE.md` and use it **verbatim** as your PR body:
|
||||
|
||||
1. Read the template: `cat .github/PULL_REQUEST_TEMPLATE.md`
|
||||
2. Preserve the exact section titles and formatting, including:
|
||||
- `### Why / What / How`
|
||||
- `### Changes 🏗️`
|
||||
- `### Checklist 📋`
|
||||
3. Replace HTML comment prompts (`<!-- ... -->`) with actual content; do not leave them in
|
||||
4. **Do not pre-check boxes** — leave all checkboxes as `- [ ]` until each step is actually completed
|
||||
5. Do not alter the template structure, rename sections, or remove any checklist items
|
||||
|
||||
**PR title must use conventional commit format** (e.g., `feat(backend): add new block`, `fix(frontend): resolve routing bug`, `dx(skills): update PR workflow`). See CLAUDE.md for the full list of scopes.
|
||||
|
||||
Use `gh pr create` with the base branch (defaults to `dev` if no `[base-branch]` was provided). Use `--body-file` to avoid shell interpretation of backticks and special characters:
|
||||
|
||||
```bash
|
||||
BASE_BRANCH="${BASE_BRANCH:-dev}"
|
||||
PR_BODY=$(mktemp)
|
||||
cat > "$PR_BODY" << 'PREOF'
|
||||
<filled-in template from .github/PULL_REQUEST_TEMPLATE.md>
|
||||
PREOF
|
||||
gh pr create --base "$BASE_BRANCH" --title "<type>(scope): short description" --body-file "$PR_BODY"
|
||||
rm "$PR_BODY"
|
||||
```
|
||||
|
||||
## Step 4: Review workflow
|
||||
|
||||
### If you have a workspace that allows testing (docker, running backend, etc.)
|
||||
- Run `/pr-test` to do E2E manual testing of the PR using docker compose, agent-browser, and API calls. This is the most thorough way to validate your changes before review.
|
||||
- After testing, run `/pr-review` to self-review the PR for correctness, security, code quality, and testing gaps before requesting human review.
|
||||
|
||||
### If you do NOT have a workspace that allows testing
|
||||
This is common for agents running in worktrees without a full stack. In this case:
|
||||
|
||||
1. Run `/pr-review` locally to catch obvious issues before pushing
|
||||
2. **Comment `/review` on the PR** after creating it to trigger the review bot
|
||||
3. **Poll for the review** rather than blindly waiting — check for new review comments every 30 seconds using `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` and the GraphQL inline threads query. The bot typically responds within 30 minutes, but polling lets the agent react as soon as it arrives.
|
||||
4. Do NOT proceed or merge until the bot review comes back
|
||||
5. Address any issues the bot raises — use `/pr-address` which has a full polling loop with CI + comment tracking
|
||||
|
||||
```bash
|
||||
# After creating the PR:
|
||||
PR_NUMBER=$(gh pr view --json number -q .number)
|
||||
gh pr comment "$PR_NUMBER" --body "/review"
|
||||
# Then use /pr-address to poll for and address the review when it arrives
|
||||
```
|
||||
|
||||
## Step 5: Address review feedback
|
||||
|
||||
Once the review bot or human reviewers leave comments:
|
||||
- Run `/pr-address` to address review comments. It will loop until CI is green and all comments are resolved.
|
||||
- Do not merge without human approval.
|
||||
|
||||
## Related skills
|
||||
|
||||
| Skill | When to use |
|
||||
|---|---|
|
||||
| `/pr-test` | E2E testing with docker compose, agent-browser, API calls — use when you have a running workspace |
|
||||
| `/pr-review` | Review for correctness, security, code quality — use before requesting human review |
|
||||
| `/pr-address` | Address reviewer comments and loop until CI green — use after reviews come in |
|
||||
|
||||
## Step 6: Post-creation
|
||||
|
||||
After the PR is created and review is triggered:
|
||||
- Share the PR URL with the user
|
||||
- If waiting on the review bot, let the user know the expected wait time (~30 min)
|
||||
- Do not merge without human approval
|
||||
195
.claude/skills/setup-repo/SKILL.md
Normal file
195
.claude/skills/setup-repo/SKILL.md
Normal file
@@ -0,0 +1,195 @@
|
||||
---
|
||||
name: setup-repo
|
||||
description: Initialize a worktree-based repo layout for parallel development. Creates a main worktree, a reviews worktree for PR reviews, and N numbered work branches. Handles .env creation, dependency installation, and branchlet config. TRIGGER when user asks to set up the repo from scratch, initialize worktrees, bootstrap their dev environment, "setup repo", "setup worktrees", "initialize dev environment", "set up branches", or when a freshly cloned repo has no sibling worktrees.
|
||||
user-invocable: true
|
||||
args: "No arguments — interactive setup via prompts."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Repository Setup
|
||||
|
||||
This skill sets up a worktree-based development layout from a freshly cloned repo. It creates:
|
||||
- A **main** worktree (the primary checkout)
|
||||
- A **reviews** worktree (for PR reviews)
|
||||
- **N work branches** (branch1..branchN) for parallel development
|
||||
|
||||
## Step 1: Identify the repo
|
||||
|
||||
Determine the repo root and parent directory:
|
||||
|
||||
```bash
|
||||
ROOT=$(git rev-parse --show-toplevel)
|
||||
REPO_NAME=$(basename "$ROOT")
|
||||
PARENT=$(dirname "$ROOT")
|
||||
```
|
||||
|
||||
Detect if the repo is already inside a worktree layout by counting sibling worktrees (not just checking the directory name, which could be anything):
|
||||
|
||||
```bash
|
||||
# Count worktrees that are siblings (live under $PARENT but aren't $ROOT itself)
|
||||
SIBLING_COUNT=$(git worktree list --porcelain 2>/dev/null | grep "^worktree " | grep -c "$PARENT/" || true)
|
||||
if [ "$SIBLING_COUNT" -gt 1 ]; then
|
||||
echo "INFO: Existing worktree layout detected at $PARENT ($SIBLING_COUNT worktrees)"
|
||||
# Use $ROOT as-is; skip renaming/restructuring
|
||||
else
|
||||
echo "INFO: Fresh clone detected, proceeding with setup"
|
||||
fi
|
||||
```
|
||||
|
||||
## Step 2: Ask the user questions
|
||||
|
||||
Use AskUserQuestion to gather setup preferences:
|
||||
|
||||
1. **How many parallel work branches do you need?** (Options: 4, 8, 16, or custom)
|
||||
- These become `branch1` through `branchN`
|
||||
2. **Which branch should be the base?** (Options: origin/master, origin/dev, or custom)
|
||||
- All work branches and reviews will start from this
|
||||
|
||||
## Step 3: Fetch and set up branches
|
||||
|
||||
```bash
|
||||
cd "$ROOT"
|
||||
git fetch origin
|
||||
|
||||
# Create the reviews branch from base (skip if already exists)
|
||||
if git show-ref --verify --quiet refs/heads/reviews; then
|
||||
echo "INFO: Branch 'reviews' already exists, skipping"
|
||||
else
|
||||
git branch reviews <base-branch>
|
||||
fi
|
||||
|
||||
# Create numbered work branches from base (skip if already exists)
|
||||
for i in $(seq 1 "$COUNT"); do
|
||||
if git show-ref --verify --quiet "refs/heads/branch$i"; then
|
||||
echo "INFO: Branch 'branch$i' already exists, skipping"
|
||||
else
|
||||
git branch "branch$i" <base-branch>
|
||||
fi
|
||||
done
|
||||
```
|
||||
|
||||
## Step 4: Create worktrees
|
||||
|
||||
Create worktrees as siblings to the main checkout:
|
||||
|
||||
```bash
|
||||
if [ -d "$PARENT/reviews" ]; then
|
||||
echo "INFO: Worktree '$PARENT/reviews' already exists, skipping"
|
||||
else
|
||||
git worktree add "$PARENT/reviews" reviews
|
||||
fi
|
||||
|
||||
for i in $(seq 1 "$COUNT"); do
|
||||
if [ -d "$PARENT/branch$i" ]; then
|
||||
echo "INFO: Worktree '$PARENT/branch$i' already exists, skipping"
|
||||
else
|
||||
git worktree add "$PARENT/branch$i" "branch$i"
|
||||
fi
|
||||
done
|
||||
```
|
||||
|
||||
## Step 5: Set up environment files
|
||||
|
||||
**Do NOT assume .env files exist.** For each worktree (including main if needed):
|
||||
|
||||
1. Check if `.env` exists in the source worktree for each path
|
||||
2. If `.env` exists, copy it
|
||||
3. If only `.env.default` or `.env.example` exists, copy that as `.env`
|
||||
4. If neither exists, warn the user and list which env files are missing
|
||||
|
||||
Env file locations to check (same as the `/worktree` skill — keep these in sync):
|
||||
- `autogpt_platform/.env`
|
||||
- `autogpt_platform/backend/.env`
|
||||
- `autogpt_platform/frontend/.env`
|
||||
|
||||
> **Note:** This env copying logic intentionally mirrors the `/worktree` skill's approach. If you update the path list or fallback logic here, update `/worktree` as well.
|
||||
|
||||
```bash
|
||||
SOURCE="$ROOT"
|
||||
WORKTREES="reviews"
|
||||
for i in $(seq 1 "$COUNT"); do WORKTREES="$WORKTREES branch$i"; done
|
||||
|
||||
FOUND_ANY_ENV=0
|
||||
for wt in $WORKTREES; do
|
||||
TARGET="$PARENT/$wt"
|
||||
for envpath in autogpt_platform autogpt_platform/backend autogpt_platform/frontend; do
|
||||
if [ -f "$SOURCE/$envpath/.env" ]; then
|
||||
FOUND_ANY_ENV=1
|
||||
cp "$SOURCE/$envpath/.env" "$TARGET/$envpath/.env"
|
||||
elif [ -f "$SOURCE/$envpath/.env.default" ]; then
|
||||
FOUND_ANY_ENV=1
|
||||
cp "$SOURCE/$envpath/.env.default" "$TARGET/$envpath/.env"
|
||||
echo "NOTE: $wt/$envpath/.env was created from .env.default — you may need to edit it"
|
||||
elif [ -f "$SOURCE/$envpath/.env.example" ]; then
|
||||
FOUND_ANY_ENV=1
|
||||
cp "$SOURCE/$envpath/.env.example" "$TARGET/$envpath/.env"
|
||||
echo "NOTE: $wt/$envpath/.env was created from .env.example — you may need to edit it"
|
||||
else
|
||||
echo "WARNING: No .env, .env.default, or .env.example found at $SOURCE/$envpath/"
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
if [ "$FOUND_ANY_ENV" -eq 0 ]; then
|
||||
echo "WARNING: No environment files or templates were found in the source worktree."
|
||||
# Use AskUserQuestion to confirm: "Continue setup without env files?"
|
||||
# If the user declines, stop here and let them set up .env files first.
|
||||
fi
|
||||
```
|
||||
|
||||
## Step 6: Copy branchlet config
|
||||
|
||||
Copy `.branchlet.json` from main to each worktree so branchlet can manage sub-worktrees:
|
||||
|
||||
```bash
|
||||
if [ -f "$ROOT/.branchlet.json" ]; then
|
||||
for wt in $WORKTREES; do
|
||||
cp "$ROOT/.branchlet.json" "$PARENT/$wt/.branchlet.json"
|
||||
done
|
||||
fi
|
||||
```
|
||||
|
||||
## Step 7: Install dependencies
|
||||
|
||||
Install deps in all worktrees. Run these sequentially per worktree:
|
||||
|
||||
```bash
|
||||
for wt in $WORKTREES; do
|
||||
TARGET="$PARENT/$wt"
|
||||
echo "=== Installing deps for $wt ==="
|
||||
(cd "$TARGET/autogpt_platform/autogpt_libs" && poetry install) &&
|
||||
(cd "$TARGET/autogpt_platform/backend" && poetry install && poetry run prisma generate) &&
|
||||
(cd "$TARGET/autogpt_platform/frontend" && pnpm install) &&
|
||||
echo "=== Done: $wt ===" ||
|
||||
echo "=== FAILED: $wt ==="
|
||||
done
|
||||
```
|
||||
|
||||
This is slow. Run in background if possible and notify when complete.
|
||||
|
||||
## Step 8: Verify and report
|
||||
|
||||
After setup, verify and report to the user:
|
||||
|
||||
```bash
|
||||
git worktree list
|
||||
```
|
||||
|
||||
Summarize:
|
||||
- Number of worktrees created
|
||||
- Which env files were copied vs created from defaults vs missing
|
||||
- Any warnings or errors encountered
|
||||
|
||||
## Final directory layout
|
||||
|
||||
```
|
||||
parent/
|
||||
main/ # Primary checkout (already exists)
|
||||
reviews/ # PR review worktree
|
||||
branch1/ # Work branch 1
|
||||
branch2/ # Work branch 2
|
||||
...
|
||||
branchN/ # Work branch N
|
||||
```
|
||||
@@ -1,3 +1,4 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from backend.blocks._base import (
|
||||
@@ -19,6 +20,33 @@ from backend.blocks.llm import (
|
||||
)
|
||||
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField
|
||||
|
||||
# Minimum max_output_tokens accepted by OpenAI-compatible APIs.
|
||||
# A true/false answer fits comfortably within this budget.
|
||||
MIN_LLM_OUTPUT_TOKENS = 16
|
||||
|
||||
|
||||
def _parse_boolean_response(response_text: str) -> tuple[bool, str | None]:
|
||||
"""Parse an LLM response into a boolean result.
|
||||
|
||||
Returns a ``(result, error)`` tuple. *error* is ``None`` when the
|
||||
response is unambiguous; otherwise it contains a diagnostic message
|
||||
and *result* defaults to ``False``.
|
||||
"""
|
||||
text = response_text.strip().lower()
|
||||
if text == "true":
|
||||
return True, None
|
||||
if text == "false":
|
||||
return False, None
|
||||
|
||||
# Fuzzy match – use word boundaries to avoid false positives like "untrue".
|
||||
tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", text))
|
||||
if tokens == {"true"} or tokens == {"yes"} or tokens == {"1"}:
|
||||
return True, None
|
||||
if tokens == {"false"} or tokens == {"no"} or tokens == {"0"}:
|
||||
return False, None
|
||||
|
||||
return False, f"Unclear AI response: '{response_text}'"
|
||||
|
||||
|
||||
class AIConditionBlock(AIBlockBase):
|
||||
"""
|
||||
@@ -162,54 +190,26 @@ class AIConditionBlock(AIBlockBase):
|
||||
]
|
||||
|
||||
# Call the LLM
|
||||
try:
|
||||
response = await self.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=prompt,
|
||||
max_tokens=10, # We only expect a true/false response
|
||||
response = await self.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=prompt,
|
||||
max_tokens=MIN_LLM_OUTPUT_TOKENS,
|
||||
)
|
||||
|
||||
# Extract the boolean result from the response
|
||||
result, error = _parse_boolean_response(response.response)
|
||||
if error:
|
||||
yield "error", error
|
||||
|
||||
# Update internal stats
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=response.prompt_tokens,
|
||||
output_token_count=response.completion_tokens,
|
||||
)
|
||||
|
||||
# Extract the boolean result from the response
|
||||
response_text = response.response.strip().lower()
|
||||
if response_text == "true":
|
||||
result = True
|
||||
elif response_text == "false":
|
||||
result = False
|
||||
else:
|
||||
# If the response is not clear, try to interpret it using word boundaries
|
||||
import re
|
||||
|
||||
# Use word boundaries to avoid false positives like 'untrue' or '10'
|
||||
tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", response_text))
|
||||
|
||||
if tokens == {"true"} or tokens == {"yes"} or tokens == {"1"}:
|
||||
result = True
|
||||
elif tokens == {"false"} or tokens == {"no"} or tokens == {"0"}:
|
||||
result = False
|
||||
else:
|
||||
# Unclear or conflicting response - default to False and yield error
|
||||
result = False
|
||||
yield "error", f"Unclear AI response: '{response.response}'"
|
||||
|
||||
# Update internal stats
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=response.prompt_tokens,
|
||||
output_token_count=response.completion_tokens,
|
||||
)
|
||||
)
|
||||
self.prompt = response.prompt
|
||||
|
||||
except Exception as e:
|
||||
# In case of any error, default to False to be safe
|
||||
result = False
|
||||
# Log the error but don't fail the block execution
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.error(f"AI condition evaluation failed: {str(e)}")
|
||||
yield "error", f"AI evaluation failed: {str(e)}"
|
||||
)
|
||||
self.prompt = response.prompt
|
||||
|
||||
# Yield results
|
||||
yield "result", result
|
||||
|
||||
147
autogpt_platform/backend/backend/blocks/ai_condition_test.py
Normal file
147
autogpt_platform/backend/backend/blocks/ai_condition_test.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""Tests for AIConditionBlock – regression coverage for max_tokens and error propagation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.ai_condition import (
|
||||
MIN_LLM_OUTPUT_TOKENS,
|
||||
AIConditionBlock,
|
||||
_parse_boolean_response,
|
||||
)
|
||||
from backend.blocks.llm import (
|
||||
DEFAULT_LLM_MODEL,
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
AICredentials,
|
||||
LLMResponse,
|
||||
)
|
||||
|
||||
_TEST_AI_CREDENTIALS = cast(AICredentials, TEST_CREDENTIALS_INPUT)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper to collect all yields from the async generator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _collect_outputs(block: AIConditionBlock, input_data, credentials):
|
||||
outputs: dict[str, object] = {}
|
||||
async for name, value in block.run(input_data, credentials=credentials):
|
||||
outputs[name] = value
|
||||
return outputs
|
||||
|
||||
|
||||
def _make_input(**overrides) -> AIConditionBlock.Input:
|
||||
defaults: dict = {
|
||||
"input_value": "hello@example.com",
|
||||
"condition": "the input is an email address",
|
||||
"yes_value": "yes!",
|
||||
"no_value": "no!",
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return AIConditionBlock.Input(**defaults)
|
||||
|
||||
|
||||
def _mock_llm_response(response_text: str) -> LLMResponse:
|
||||
return LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response=response_text,
|
||||
tool_calls=None,
|
||||
prompt_tokens=10,
|
||||
completion_tokens=5,
|
||||
reasoning=None,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_boolean_response unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseBooleanResponse:
|
||||
def test_true_exact(self):
|
||||
assert _parse_boolean_response("true") == (True, None)
|
||||
|
||||
def test_false_exact(self):
|
||||
assert _parse_boolean_response("false") == (False, None)
|
||||
|
||||
def test_true_with_whitespace(self):
|
||||
assert _parse_boolean_response(" True ") == (True, None)
|
||||
|
||||
def test_yes_fuzzy(self):
|
||||
assert _parse_boolean_response("Yes") == (True, None)
|
||||
|
||||
def test_no_fuzzy(self):
|
||||
assert _parse_boolean_response("no") == (False, None)
|
||||
|
||||
def test_one_fuzzy(self):
|
||||
assert _parse_boolean_response("1") == (True, None)
|
||||
|
||||
def test_zero_fuzzy(self):
|
||||
assert _parse_boolean_response("0") == (False, None)
|
||||
|
||||
def test_unclear_response(self):
|
||||
result, error = _parse_boolean_response("I'm not sure")
|
||||
assert result is False
|
||||
assert error is not None
|
||||
assert "Unclear" in error
|
||||
|
||||
def test_conflicting_tokens(self):
|
||||
result, error = _parse_boolean_response("true and false")
|
||||
assert result is False
|
||||
assert error is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regression: max_tokens is set to MIN_LLM_OUTPUT_TOKENS
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMaxTokensRegression:
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_call_receives_min_output_tokens(self):
|
||||
"""max_tokens must be MIN_LLM_OUTPUT_TOKENS (16) – the previous value
|
||||
of 1 was too low and caused OpenAI to reject the request."""
|
||||
block = AIConditionBlock()
|
||||
captured_kwargs: dict = {}
|
||||
|
||||
async def spy_llm_call(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return _mock_llm_response("true")
|
||||
|
||||
block.llm_call = spy_llm_call # type: ignore[assignment]
|
||||
|
||||
input_data = _make_input()
|
||||
await _collect_outputs(block, input_data, credentials=TEST_CREDENTIALS)
|
||||
|
||||
assert captured_kwargs["max_tokens"] == MIN_LLM_OUTPUT_TOKENS
|
||||
assert captured_kwargs["max_tokens"] == 16
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regression: exceptions from llm_call must propagate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExceptionPropagation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_call_exception_propagates(self):
|
||||
"""If llm_call raises, the exception must NOT be swallowed.
|
||||
Previously the block caught all exceptions and silently returned
|
||||
result=False."""
|
||||
block = AIConditionBlock()
|
||||
|
||||
async def boom(**kwargs):
|
||||
raise RuntimeError("LLM provider error")
|
||||
|
||||
block.llm_call = boom # type: ignore[assignment]
|
||||
|
||||
input_data = _make_input()
|
||||
with pytest.raises(RuntimeError, match="LLM provider error"):
|
||||
await _collect_outputs(block, input_data, credentials=TEST_CREDENTIALS)
|
||||
@@ -25,24 +25,64 @@ def build_test_transcript(pairs: list[tuple[str, str]]) -> str:
|
||||
|
||||
Use this helper in any copilot SDK test that needs a well-formed
|
||||
transcript without hitting the real storage layer.
|
||||
|
||||
Delegates to ``build_structured_transcript`` — plain content strings
|
||||
are automatically wrapped in ``[{"type": "text", "text": ...}]`` for
|
||||
assistant messages.
|
||||
"""
|
||||
# Cast widening: tuple[str, str] is structurally compatible with
|
||||
# tuple[str, str | list[dict]] but list invariance requires explicit
|
||||
# annotation.
|
||||
widened: list[tuple[str, str | list[dict]]] = list(pairs)
|
||||
return build_structured_transcript(widened)
|
||||
|
||||
|
||||
def build_structured_transcript(
|
||||
entries: list[tuple[str, str | list[dict]]],
|
||||
) -> str:
|
||||
"""Build a JSONL transcript with structured content blocks.
|
||||
|
||||
Each entry is (role, content) where content is either a plain string
|
||||
(for user messages) or a list of content block dicts (for assistant
|
||||
messages with thinking/tool_use/text blocks).
|
||||
|
||||
Example::
|
||||
|
||||
build_structured_transcript([
|
||||
("user", "Hello"),
|
||||
("assistant", [
|
||||
{"type": "thinking", "thinking": "...", "signature": "sig1"},
|
||||
{"type": "text", "text": "Hi there"},
|
||||
]),
|
||||
])
|
||||
"""
|
||||
lines: list[str] = []
|
||||
last_uuid: str | None = None
|
||||
for role, content in pairs:
|
||||
for role, content in entries:
|
||||
uid = str(uuid4())
|
||||
entry_type = "assistant" if role == "assistant" else "user"
|
||||
msg: dict = {"role": role, "content": content}
|
||||
if role == "assistant":
|
||||
msg.update(
|
||||
{
|
||||
"model": "",
|
||||
"id": f"msg_{uid[:8]}",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": content}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
}
|
||||
)
|
||||
if role == "assistant" and isinstance(content, list):
|
||||
msg: dict = {
|
||||
"role": "assistant",
|
||||
"model": "claude-test",
|
||||
"id": f"msg_{uid[:8]}",
|
||||
"type": "message",
|
||||
"content": content,
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
}
|
||||
elif role == "assistant":
|
||||
msg = {
|
||||
"role": "assistant",
|
||||
"model": "claude-test",
|
||||
"id": f"msg_{uid[:8]}",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": content}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
}
|
||||
else:
|
||||
msg = {"role": role, "content": content}
|
||||
entry = {
|
||||
"type": entry_type,
|
||||
"uuid": uid,
|
||||
|
||||
@@ -442,8 +442,11 @@ class TestCompactTranscript:
|
||||
assert result is not None
|
||||
assert validate_transcript(result)
|
||||
msgs = _transcript_to_messages(result)
|
||||
assert len(msgs) == 2
|
||||
# 3 messages: compressed prefix (2) + preserved last assistant (1)
|
||||
assert len(msgs) == 3
|
||||
assert msgs[1]["content"] == "Summarized response"
|
||||
# The last assistant entry is preserved verbatim from original
|
||||
assert msgs[2]["content"] == "Details"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_on_compression_failure(self, mock_chat_config):
|
||||
|
||||
@@ -15,6 +15,7 @@ from claude_agent_sdk import (
|
||||
ResultMessage,
|
||||
SystemMessage,
|
||||
TextBlock,
|
||||
ThinkingBlock,
|
||||
ToolResultBlock,
|
||||
ToolUseBlock,
|
||||
UserMessage,
|
||||
@@ -100,6 +101,11 @@ class SDKResponseAdapter:
|
||||
StreamTextDelta(id=self.text_block_id, delta=block.text)
|
||||
)
|
||||
|
||||
elif isinstance(block, ThinkingBlock):
|
||||
# Thinking blocks are preserved in the transcript but
|
||||
# not streamed to the frontend — skip silently.
|
||||
pass
|
||||
|
||||
elif isinstance(block, ToolUseBlock):
|
||||
self._end_text_if_open(responses)
|
||||
|
||||
|
||||
@@ -124,8 +124,11 @@ class TestScenarioCompactAndRetry:
|
||||
assert result != original # Must be different
|
||||
assert validate_transcript(result)
|
||||
msgs = _transcript_to_messages(result)
|
||||
assert len(msgs) == 2
|
||||
# 3 messages: compressed prefix (2) + preserved last assistant (1)
|
||||
assert len(msgs) == 3
|
||||
assert msgs[0]["content"] == "[summary of conversation]"
|
||||
# Last assistant preserved verbatim
|
||||
assert msgs[2]["content"] == "Long answer 2"
|
||||
|
||||
def test_compacted_transcript_loads_into_builder(self):
|
||||
"""TranscriptBuilder can load a compacted transcript and continue."""
|
||||
@@ -737,7 +740,10 @@ class TestRetryEdgeCases:
|
||||
assert result is not None
|
||||
assert result != transcript
|
||||
msgs = _transcript_to_messages(result)
|
||||
assert len(msgs) == 2
|
||||
# 3 messages: compressed prefix (2) + preserved last assistant (1)
|
||||
assert len(msgs) == 3
|
||||
# Last assistant preserved verbatim
|
||||
assert msgs[2]["content"] == "Answer 19"
|
||||
|
||||
def test_messages_to_transcript_roundtrip_preserves_content(self):
|
||||
"""Verify messages → transcript → messages preserves all content."""
|
||||
|
||||
@@ -671,7 +671,9 @@ def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
|
||||
"""Convert SDK content blocks to transcript format.
|
||||
|
||||
Handles TextBlock, ToolUseBlock, ToolResultBlock, and ThinkingBlock.
|
||||
Unknown block types are logged and skipped.
|
||||
Raw dicts (e.g. ``redacted_thinking`` blocks that the SDK may not have
|
||||
a typed class for) are passed through verbatim to preserve them in the
|
||||
transcript. Unknown typed block objects are logged and skipped.
|
||||
"""
|
||||
result: list[dict[str, Any]] = []
|
||||
for block in blocks or []:
|
||||
@@ -703,6 +705,9 @@ def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
|
||||
"signature": block.signature,
|
||||
}
|
||||
)
|
||||
elif isinstance(block, dict) and "type" in block:
|
||||
# Preserve raw dict blocks (e.g. redacted_thinking) verbatim.
|
||||
result.append(block)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[SDK] Unknown content block type: {type(block).__name__}. "
|
||||
|
||||
@@ -0,0 +1,822 @@
|
||||
"""Tests for thinking/redacted_thinking block preservation.
|
||||
|
||||
Validates the fix for the Anthropic API error:
|
||||
"thinking or redacted_thinking blocks in the latest assistant message
|
||||
cannot be modified. These blocks must remain as they were in the
|
||||
original response."
|
||||
|
||||
The API requires that thinking blocks in the LAST assistant message are
|
||||
preserved value-identical. Older assistant messages may have thinking blocks
|
||||
stripped entirely. This test suite covers:
|
||||
|
||||
1. _flatten_assistant_content — strips thinking from older messages
|
||||
2. compact_transcript — preserves last assistant's thinking blocks
|
||||
3. response_adapter — handles ThinkingBlock without error
|
||||
4. _format_sdk_content_blocks — preserves redacted_thinking blocks
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from claude_agent_sdk import AssistantMessage, TextBlock, ThinkingBlock
|
||||
|
||||
from backend.copilot.response_model import (
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
StreamTextStart,
|
||||
)
|
||||
from backend.util import json
|
||||
|
||||
from .conftest import build_structured_transcript
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
from .service import _format_sdk_content_blocks
|
||||
from .transcript import (
|
||||
_find_last_assistant_entry,
|
||||
_flatten_assistant_content,
|
||||
_messages_to_transcript,
|
||||
_rechain_tail,
|
||||
_transcript_to_messages,
|
||||
compact_transcript,
|
||||
validate_transcript,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures: realistic thinking block content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
THINKING_BLOCK = {
|
||||
"type": "thinking",
|
||||
"thinking": "Let me analyze the user's request carefully...",
|
||||
"signature": "ErUBCkYIAxgCIkD0V2MsRXPkuGolGexaW9V1kluijxXGF",
|
||||
}
|
||||
|
||||
REDACTED_THINKING_BLOCK = {
|
||||
"type": "redacted_thinking",
|
||||
"data": "EmwKAhgBEgy2VEE8PJaS2oLJCPkaT...",
|
||||
}
|
||||
|
||||
|
||||
def _make_thinking_transcript() -> str:
|
||||
"""Build a transcript with thinking blocks in multiple assistant turns.
|
||||
|
||||
Layout:
|
||||
User 1 → Assistant 1 (thinking + text + tool_use)
|
||||
User 2 (tool_result) → Assistant 2 (thinking + text)
|
||||
User 3 → Assistant 3 (thinking + redacted_thinking + text) ← LAST
|
||||
"""
|
||||
return build_structured_transcript(
|
||||
[
|
||||
("user", "What files are in this project?"),
|
||||
(
|
||||
"assistant",
|
||||
[
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": "I should list the files.",
|
||||
"signature": "sig_old_1",
|
||||
},
|
||||
{"type": "text", "text": "Let me check the files."},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "tu1",
|
||||
"name": "list_files",
|
||||
"input": {"path": "/"},
|
||||
},
|
||||
],
|
||||
),
|
||||
("user", "Here are the files: a.py, b.py"),
|
||||
(
|
||||
"assistant",
|
||||
[
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": "Good, I see two Python files.",
|
||||
"signature": "sig_old_2",
|
||||
},
|
||||
{"type": "text", "text": "I found a.py and b.py."},
|
||||
],
|
||||
),
|
||||
("user", "Tell me about a.py"),
|
||||
(
|
||||
"assistant",
|
||||
[
|
||||
THINKING_BLOCK,
|
||||
REDACTED_THINKING_BLOCK,
|
||||
{"type": "text", "text": "a.py contains the main entry point."},
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _last_assistant_content(transcript_jsonl: str) -> list[dict] | None:
|
||||
"""Extract the content blocks of the last assistant entry in a transcript."""
|
||||
last_content = None
|
||||
for line in transcript_jsonl.strip().split("\n"):
|
||||
entry = json.loads(line)
|
||||
msg = entry.get("message", {})
|
||||
if msg.get("role") == "assistant":
|
||||
last_content = msg.get("content")
|
||||
return last_content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _find_last_assistant_entry — unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFindLastAssistantEntry:
|
||||
def test_splits_at_last_assistant(self):
|
||||
"""Prefix contains everything before last assistant; tail starts at it."""
|
||||
transcript = build_structured_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
("assistant", [{"type": "text", "text": "Hi"}]),
|
||||
("user", "More"),
|
||||
("assistant", [{"type": "text", "text": "Details"}]),
|
||||
]
|
||||
)
|
||||
prefix, tail = _find_last_assistant_entry(transcript)
|
||||
# 3 entries in prefix (user, assistant, user), 1 in tail (last assistant)
|
||||
assert len(prefix) == 3
|
||||
assert len(tail) == 1
|
||||
|
||||
def test_no_assistant_returns_all_in_prefix(self):
|
||||
"""When there's no assistant, all lines are in prefix, tail is empty."""
|
||||
transcript = build_structured_transcript(
|
||||
[("user", "Hello"), ("user", "Another question")]
|
||||
)
|
||||
prefix, tail = _find_last_assistant_entry(transcript)
|
||||
assert len(prefix) == 2
|
||||
assert tail == []
|
||||
|
||||
def test_assistant_at_index_zero(self):
|
||||
"""When assistant is the first entry, prefix is empty."""
|
||||
transcript = build_structured_transcript(
|
||||
[("assistant", [{"type": "text", "text": "Start"}])]
|
||||
)
|
||||
prefix, tail = _find_last_assistant_entry(transcript)
|
||||
assert prefix == []
|
||||
assert len(tail) == 1
|
||||
|
||||
def test_trailing_user_included_in_tail(self):
|
||||
"""User message after last assistant is part of the tail."""
|
||||
transcript = build_structured_transcript(
|
||||
[
|
||||
("user", "Q1"),
|
||||
("assistant", [{"type": "text", "text": "A1"}]),
|
||||
("user", "Q2"),
|
||||
]
|
||||
)
|
||||
prefix, tail = _find_last_assistant_entry(transcript)
|
||||
assert len(prefix) == 1 # first user
|
||||
assert len(tail) == 2 # last assistant + trailing user
|
||||
|
||||
def test_multi_entry_turn_fully_preserved(self):
|
||||
"""An assistant turn spanning multiple JSONL entries (same message.id)
|
||||
must be entirely in the tail, not split across prefix and tail."""
|
||||
# Build manually because build_structured_transcript generates unique ids
|
||||
lines = [
|
||||
json.dumps(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"parentUuid": "",
|
||||
"message": {"role": "user", "content": "Hello"},
|
||||
}
|
||||
),
|
||||
json.dumps(
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1-think",
|
||||
"parentUuid": "u1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_same_turn",
|
||||
"type": "message",
|
||||
"content": [THINKING_BLOCK],
|
||||
"stop_reason": None,
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
),
|
||||
json.dumps(
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1-tool",
|
||||
"parentUuid": "u1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_same_turn",
|
||||
"type": "message",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "tu1",
|
||||
"name": "Bash",
|
||||
"input": {},
|
||||
},
|
||||
],
|
||||
"stop_reason": "tool_use",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
),
|
||||
]
|
||||
transcript = "\n".join(lines) + "\n"
|
||||
prefix, tail = _find_last_assistant_entry(transcript)
|
||||
# Both assistant entries share msg_same_turn → both in tail
|
||||
assert len(prefix) == 1 # only the user entry
|
||||
assert len(tail) == 2 # both assistant entries (thinking + tool_use)
|
||||
|
||||
def test_no_message_id_preserves_last_assistant(self):
|
||||
"""When the last assistant entry has no message.id, it should still
|
||||
be preserved in the tail (fail closed) rather than being compressed."""
|
||||
lines = [
|
||||
json.dumps(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"parentUuid": "",
|
||||
"message": {"role": "user", "content": "Hello"},
|
||||
}
|
||||
),
|
||||
json.dumps(
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"parentUuid": "u1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [THINKING_BLOCK, {"type": "text", "text": "Hi"}],
|
||||
},
|
||||
}
|
||||
),
|
||||
]
|
||||
transcript = "\n".join(lines) + "\n"
|
||||
prefix, tail = _find_last_assistant_entry(transcript)
|
||||
assert len(prefix) == 1 # user entry
|
||||
assert len(tail) == 1 # assistant entry preserved
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _rechain_tail — UUID chain patching
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRechainTail:
|
||||
def test_patches_first_entry_parentuuid(self):
|
||||
"""First tail entry's parentUuid should point to last prefix uuid."""
|
||||
prefix = _messages_to_transcript(
|
||||
[
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi"},
|
||||
]
|
||||
)
|
||||
# Get the last uuid from the prefix
|
||||
last_prefix_uuid = None
|
||||
for line in prefix.strip().split("\n"):
|
||||
entry = json.loads(line)
|
||||
last_prefix_uuid = entry.get("uuid")
|
||||
|
||||
tail_lines = [
|
||||
json.dumps(
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "tail-a1",
|
||||
"parentUuid": "old-parent",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "Tail msg"}],
|
||||
},
|
||||
}
|
||||
)
|
||||
]
|
||||
result = _rechain_tail(prefix, tail_lines)
|
||||
entry = json.loads(result.strip())
|
||||
assert entry["parentUuid"] == last_prefix_uuid
|
||||
assert entry["uuid"] == "tail-a1" # uuid preserved
|
||||
|
||||
def test_chains_multiple_tail_entries(self):
|
||||
"""Subsequent tail entries chain to each other."""
|
||||
prefix = _messages_to_transcript([{"role": "user", "content": "Hi"}])
|
||||
tail_lines = [
|
||||
json.dumps(
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "t1",
|
||||
"parentUuid": "old1",
|
||||
"message": {"role": "assistant", "content": []},
|
||||
}
|
||||
),
|
||||
json.dumps(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "t2",
|
||||
"parentUuid": "old2",
|
||||
"message": {"role": "user", "content": "Follow-up"},
|
||||
}
|
||||
),
|
||||
]
|
||||
result = _rechain_tail(prefix, tail_lines)
|
||||
entries = [json.loads(ln) for ln in result.strip().split("\n")]
|
||||
assert len(entries) == 2
|
||||
# Second entry's parentUuid should be first entry's uuid
|
||||
assert entries[1]["parentUuid"] == "t1"
|
||||
|
||||
def test_empty_tail_returns_empty(self):
|
||||
"""No tail entries → empty string."""
|
||||
prefix = _messages_to_transcript([{"role": "user", "content": "Hi"}])
|
||||
assert _rechain_tail(prefix, []) == ""
|
||||
|
||||
def test_preserves_message_content_verbatim(self):
|
||||
"""Tail message content (including thinking blocks) must not be modified."""
|
||||
prefix = _messages_to_transcript([{"role": "user", "content": "Hi"}])
|
||||
original_content = [
|
||||
THINKING_BLOCK,
|
||||
REDACTED_THINKING_BLOCK,
|
||||
{"type": "text", "text": "Response"},
|
||||
]
|
||||
tail_lines = [
|
||||
json.dumps(
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "t1",
|
||||
"parentUuid": "old",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": original_content,
|
||||
},
|
||||
}
|
||||
)
|
||||
]
|
||||
result = _rechain_tail(prefix, tail_lines)
|
||||
entry = json.loads(result.strip())
|
||||
assert entry["message"]["content"] == original_content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _flatten_assistant_content — thinking blocks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlattenThinkingBlocks:
|
||||
def test_thinking_blocks_are_stripped(self):
|
||||
"""Thinking blocks should not appear in flattened text for compression."""
|
||||
blocks = [
|
||||
{"type": "thinking", "thinking": "secret thoughts", "signature": "sig"},
|
||||
{"type": "text", "text": "Hello user"},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert "secret thoughts" not in result
|
||||
assert "Hello user" in result
|
||||
|
||||
def test_redacted_thinking_blocks_are_stripped(self):
|
||||
"""Redacted thinking blocks should not appear in flattened text."""
|
||||
blocks = [
|
||||
{"type": "redacted_thinking", "data": "encrypted_data"},
|
||||
{"type": "text", "text": "Response text"},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert "encrypted_data" not in result
|
||||
assert "Response text" in result
|
||||
|
||||
def test_thinking_only_message_flattens_to_empty(self):
|
||||
"""A message with only thinking blocks flattens to empty string."""
|
||||
blocks = [
|
||||
{"type": "thinking", "thinking": "just thinking...", "signature": "sig"},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert result == ""
|
||||
|
||||
def test_mixed_thinking_text_tool(self):
|
||||
"""Mixed blocks: only text and tool_use survive flattening."""
|
||||
blocks = [
|
||||
{"type": "thinking", "thinking": "hmm", "signature": "sig"},
|
||||
{"type": "redacted_thinking", "data": "xyz"},
|
||||
{"type": "text", "text": "I'll read the file."},
|
||||
{"type": "tool_use", "name": "Read", "input": {"path": "/x"}},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert "hmm" not in result
|
||||
assert "xyz" not in result
|
||||
assert "I'll read the file." in result
|
||||
assert "[tool_use: Read]" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compact_transcript — thinking block preservation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompactTranscriptThinkingBlocks:
|
||||
"""Verify that compact_transcript preserves thinking blocks in the
|
||||
last assistant message while stripping them from older messages."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_last_assistant_thinking_blocks_preserved(self, mock_chat_config):
|
||||
"""After compaction, the last assistant entry must retain its
|
||||
original thinking and redacted_thinking blocks verbatim."""
|
||||
transcript = _make_thinking_transcript()
|
||||
|
||||
compacted_msgs = [
|
||||
{"role": "user", "content": "[conversation summary]"},
|
||||
{"role": "assistant", "content": "Summarized response"},
|
||||
]
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": True,
|
||||
"messages": compacted_msgs,
|
||||
"original_token_count": 800,
|
||||
"token_count": 200,
|
||||
"messages_summarized": 4,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
|
||||
assert result is not None
|
||||
assert validate_transcript(result)
|
||||
|
||||
last_content = _last_assistant_content(result)
|
||||
assert last_content is not None, "No assistant entry found"
|
||||
assert isinstance(last_content, list)
|
||||
|
||||
# The last assistant must have the thinking blocks preserved
|
||||
block_types = [b["type"] for b in last_content]
|
||||
assert (
|
||||
"thinking" in block_types
|
||||
), "thinking block missing from last assistant message"
|
||||
assert (
|
||||
"redacted_thinking" in block_types
|
||||
), "redacted_thinking block missing from last assistant message"
|
||||
assert "text" in block_types
|
||||
|
||||
# Verify the thinking block content is value-identical
|
||||
thinking_blocks = [b for b in last_content if b["type"] == "thinking"]
|
||||
assert len(thinking_blocks) == 1
|
||||
assert thinking_blocks[0]["thinking"] == THINKING_BLOCK["thinking"]
|
||||
assert thinking_blocks[0]["signature"] == THINKING_BLOCK["signature"]
|
||||
|
||||
redacted_blocks = [b for b in last_content if b["type"] == "redacted_thinking"]
|
||||
assert len(redacted_blocks) == 1
|
||||
assert redacted_blocks[0]["data"] == REDACTED_THINKING_BLOCK["data"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_older_assistant_thinking_blocks_stripped(self, mock_chat_config):
|
||||
"""Older assistant messages should NOT retain thinking blocks
|
||||
after compaction (they're compressed into summaries)."""
|
||||
transcript = _make_thinking_transcript()
|
||||
|
||||
# The compressor will receive messages where older assistant
|
||||
# entries have already had thinking blocks stripped.
|
||||
captured_messages: list[dict] = []
|
||||
|
||||
async def mock_compression(messages, model, log_prefix):
|
||||
captured_messages.extend(messages)
|
||||
return type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": True,
|
||||
"messages": messages,
|
||||
"original_token_count": 800,
|
||||
"token_count": 400,
|
||||
"messages_summarized": 2,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
side_effect=mock_compression,
|
||||
):
|
||||
await compact_transcript(transcript, model="test-model")
|
||||
|
||||
# Check that the messages sent to compression don't contain
|
||||
# thinking content from older assistant messages
|
||||
for msg in captured_messages:
|
||||
if msg["role"] == "assistant":
|
||||
content = msg.get("content", "")
|
||||
assert (
|
||||
"I should list the files." not in content
|
||||
), "Old thinking block content leaked into compression input"
|
||||
assert (
|
||||
"Good, I see two Python files." not in content
|
||||
), "Old thinking block content leaked into compression input"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trailing_user_message_after_last_assistant(self, mock_chat_config):
|
||||
"""When the last entry is a user message, the last *assistant*
|
||||
message's thinking blocks should still be preserved."""
|
||||
transcript = build_structured_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
(
|
||||
"assistant",
|
||||
[
|
||||
THINKING_BLOCK,
|
||||
{"type": "text", "text": "Hi there"},
|
||||
],
|
||||
),
|
||||
("user", "Follow-up question"),
|
||||
]
|
||||
)
|
||||
|
||||
# The compressor only receives the prefix (1 user message); the
|
||||
# tail (assistant + trailing user) is preserved verbatim.
|
||||
compacted_msgs = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
]
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": True,
|
||||
"messages": compacted_msgs,
|
||||
"original_token_count": 400,
|
||||
"token_count": 100,
|
||||
"messages_summarized": 0,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
|
||||
assert result is not None
|
||||
|
||||
last_content = _last_assistant_content(result)
|
||||
assert last_content is not None
|
||||
assert isinstance(last_content, list)
|
||||
block_types = [b["type"] for b in last_content]
|
||||
assert (
|
||||
"thinking" in block_types
|
||||
), "thinking block lost from last assistant despite trailing user msg"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_assistant_with_thinking_preserved(self, mock_chat_config):
|
||||
"""When there's only one assistant message (which is also the last),
|
||||
its thinking blocks must be preserved."""
|
||||
transcript = build_structured_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
(
|
||||
"assistant",
|
||||
[
|
||||
THINKING_BLOCK,
|
||||
{"type": "text", "text": "World"},
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
compacted_msgs = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "World"},
|
||||
]
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": True,
|
||||
"messages": compacted_msgs,
|
||||
"original_token_count": 200,
|
||||
"token_count": 100,
|
||||
"messages_summarized": 0,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
|
||||
assert result is not None
|
||||
|
||||
last_content = _last_assistant_content(result)
|
||||
assert last_content is not None
|
||||
assert isinstance(last_content, list)
|
||||
block_types = [b["type"] for b in last_content]
|
||||
assert "thinking" in block_types
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tail_parentuuid_rewired_to_prefix(self, mock_chat_config):
|
||||
"""After compaction, the first tail entry's parentUuid must point to
|
||||
the last entry in the compressed prefix — not its original parent."""
|
||||
transcript = _make_thinking_transcript()
|
||||
|
||||
compacted_msgs = [
|
||||
{"role": "user", "content": "[conversation summary]"},
|
||||
{"role": "assistant", "content": "Summarized response"},
|
||||
]
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": True,
|
||||
"messages": compacted_msgs,
|
||||
"original_token_count": 800,
|
||||
"token_count": 200,
|
||||
"messages_summarized": 4,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
|
||||
assert result is not None
|
||||
lines = [ln for ln in result.strip().split("\n") if ln.strip()]
|
||||
entries = [json.loads(ln) for ln in lines]
|
||||
|
||||
# Find the boundary: the compressed prefix ends just before the
|
||||
# first tail entry (last assistant in original transcript).
|
||||
tail_start = None
|
||||
for i, entry in enumerate(entries):
|
||||
msg = entry.get("message", {})
|
||||
if isinstance(msg.get("content"), list):
|
||||
# Structured content = preserved tail entry
|
||||
tail_start = i
|
||||
break
|
||||
|
||||
assert tail_start is not None, "Could not find preserved tail entry"
|
||||
assert tail_start > 0, "Tail should not be the first entry"
|
||||
|
||||
# The tail entry's parentUuid must be the uuid of the preceding entry
|
||||
prefix_last_uuid = entries[tail_start - 1]["uuid"]
|
||||
tail_first_parent = entries[tail_start]["parentUuid"]
|
||||
assert tail_first_parent == prefix_last_uuid, (
|
||||
f"Tail parentUuid {tail_first_parent!r} != "
|
||||
f"last prefix uuid {prefix_last_uuid!r}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_thinking_blocks_still_works(self, mock_chat_config):
|
||||
"""Compaction should still work normally when there are no thinking
|
||||
blocks in the transcript."""
|
||||
transcript = build_structured_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
("assistant", [{"type": "text", "text": "Hi"}]),
|
||||
("user", "More"),
|
||||
("assistant", [{"type": "text", "text": "Details"}]),
|
||||
]
|
||||
)
|
||||
|
||||
compacted_msgs = [
|
||||
{"role": "user", "content": "[summary]"},
|
||||
{"role": "assistant", "content": "Summary"},
|
||||
]
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": True,
|
||||
"messages": compacted_msgs,
|
||||
"original_token_count": 200,
|
||||
"token_count": 50,
|
||||
"messages_summarized": 2,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
|
||||
assert result is not None
|
||||
assert validate_transcript(result)
|
||||
# Verify last assistant content is preserved even without thinking blocks
|
||||
last_content = _last_assistant_content(result)
|
||||
assert last_content is not None
|
||||
assert last_content == [{"type": "text", "text": "Details"}]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _transcript_to_messages — thinking block handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTranscriptToMessagesThinking:
|
||||
def test_thinking_blocks_excluded_from_flattened_content(self):
|
||||
"""When _transcript_to_messages flattens content, thinking block
|
||||
text should not leak into the message content string."""
|
||||
transcript = build_structured_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
(
|
||||
"assistant",
|
||||
[
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": "SECRET_THOUGHT",
|
||||
"signature": "sig",
|
||||
},
|
||||
{"type": "text", "text": "Visible response"},
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
messages = _transcript_to_messages(transcript)
|
||||
assistant_msg = [m for m in messages if m["role"] == "assistant"][0]
|
||||
assert "SECRET_THOUGHT" not in assistant_msg["content"]
|
||||
assert "Visible response" in assistant_msg["content"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# response_adapter — ThinkingBlock handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResponseAdapterThinkingBlock:
|
||||
def test_thinking_block_does_not_crash(self):
|
||||
"""ThinkingBlock in AssistantMessage should not cause an error."""
|
||||
adapter = SDKResponseAdapter(message_id="msg-1", session_id="sess-1")
|
||||
msg = AssistantMessage(
|
||||
content=[
|
||||
ThinkingBlock(
|
||||
thinking="Let me think about this...",
|
||||
signature="sig_test_123",
|
||||
),
|
||||
TextBlock(text="Here is my response."),
|
||||
],
|
||||
model="claude-test",
|
||||
)
|
||||
results = adapter.convert_message(msg)
|
||||
# Should produce stream events for text only, no crash
|
||||
types = [type(r) for r in results]
|
||||
assert StreamStartStep in types
|
||||
assert StreamTextStart in types or StreamTextDelta in types
|
||||
|
||||
def test_thinking_block_does_not_emit_stream_events(self):
|
||||
"""ThinkingBlock should NOT produce any StreamTextDelta events
|
||||
containing thinking content."""
|
||||
adapter = SDKResponseAdapter(message_id="msg-1", session_id="sess-1")
|
||||
msg = AssistantMessage(
|
||||
content=[
|
||||
ThinkingBlock(
|
||||
thinking="My secret thoughts",
|
||||
signature="sig_test_456",
|
||||
),
|
||||
TextBlock(text="Public response"),
|
||||
],
|
||||
model="claude-test",
|
||||
)
|
||||
results = adapter.convert_message(msg)
|
||||
text_deltas = [r for r in results if isinstance(r, StreamTextDelta)]
|
||||
for delta in text_deltas:
|
||||
assert "secret thoughts" not in (delta.delta or "")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _format_sdk_content_blocks — redacted_thinking handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFormatSdkContentBlocks:
|
||||
def test_thinking_block_preserved(self):
|
||||
"""ThinkingBlock should be serialized with type, thinking, and signature."""
|
||||
blocks = [
|
||||
ThinkingBlock(thinking="My thoughts", signature="sig123"),
|
||||
TextBlock(text="Response"),
|
||||
]
|
||||
result = _format_sdk_content_blocks(blocks)
|
||||
assert len(result) == 2
|
||||
assert result[0] == {
|
||||
"type": "thinking",
|
||||
"thinking": "My thoughts",
|
||||
"signature": "sig123",
|
||||
}
|
||||
assert result[1] == {"type": "text", "text": "Response"}
|
||||
|
||||
def test_raw_dict_redacted_thinking_preserved(self):
|
||||
"""Raw dict blocks (e.g. redacted_thinking) pass through unchanged."""
|
||||
raw_block = {"type": "redacted_thinking", "data": "EmwKAh...encrypted"}
|
||||
blocks = [
|
||||
raw_block,
|
||||
TextBlock(text="Response"),
|
||||
]
|
||||
result = _format_sdk_content_blocks(blocks)
|
||||
assert len(result) == 2
|
||||
assert result[0] == raw_block
|
||||
assert result[1] == {"type": "text", "text": "Response"}
|
||||
@@ -605,20 +605,31 @@ COMPACT_MSG_ID_PREFIX = "msg_compact_"
|
||||
ENTRY_TYPE_MESSAGE = "message"
|
||||
|
||||
|
||||
_THINKING_BLOCK_TYPES = frozenset({"thinking", "redacted_thinking"})
|
||||
|
||||
|
||||
def _flatten_assistant_content(blocks: list) -> str:
|
||||
"""Flatten assistant content blocks into a single plain-text string.
|
||||
|
||||
Structured ``tool_use`` blocks are converted to ``[tool_use: name]``
|
||||
placeholders. This is intentional: ``compress_context`` requires plain
|
||||
text for token counting and LLM summarization. The structural loss is
|
||||
acceptable because compaction only runs when the original transcript was
|
||||
already too large for the model — a summarized plain-text version is
|
||||
better than no context at all.
|
||||
placeholders. ``thinking`` and ``redacted_thinking`` blocks are
|
||||
silently dropped — they carry no useful context for compression
|
||||
summaries and must not leak into compacted transcripts (the Anthropic
|
||||
API requires thinking blocks in the last assistant message to be
|
||||
value-identical to the original response; including stale thinking
|
||||
text would violate that constraint).
|
||||
|
||||
This is intentional: ``compress_context`` requires plain text for
|
||||
token counting and LLM summarization. The structural loss is
|
||||
acceptable because compaction only runs when the original transcript
|
||||
was already too large for the model.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
for block in blocks:
|
||||
if isinstance(block, dict):
|
||||
btype = block.get("type", "")
|
||||
if btype in _THINKING_BLOCK_TYPES:
|
||||
continue
|
||||
if btype == "text":
|
||||
parts.append(block.get("text", ""))
|
||||
elif btype == "tool_use":
|
||||
@@ -805,6 +816,68 @@ async def _run_compression(
|
||||
)
|
||||
|
||||
|
||||
def _find_last_assistant_entry(
|
||||
content: str,
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""Split JSONL lines into (compressible_prefix, preserved_tail).
|
||||
|
||||
The tail starts at the **first** entry of the last assistant turn and
|
||||
includes everything after it (typically trailing user messages). An
|
||||
assistant turn can span multiple consecutive JSONL entries sharing the
|
||||
same ``message.id`` (e.g., a thinking entry followed by a tool_use
|
||||
entry). All entries of the turn are preserved verbatim.
|
||||
|
||||
The Anthropic API requires that ``thinking`` and ``redacted_thinking``
|
||||
blocks in the **last** assistant message remain value-identical to the
|
||||
original response (the API validates parsed signature values, not raw
|
||||
JSON bytes). By excluding the entire turn from compression we
|
||||
guarantee those blocks are never altered.
|
||||
|
||||
Returns ``(all_lines, [])`` when no assistant entry is found.
|
||||
"""
|
||||
lines = [ln for ln in content.strip().split("\n") if ln.strip()]
|
||||
|
||||
# Parse all lines once to avoid double JSON deserialization.
|
||||
# json.loads with fallback=None returns Any; non-dict entries are
|
||||
# safely skipped by the isinstance(entry, dict) guards below.
|
||||
parsed: list = [json.loads(ln, fallback=None) for ln in lines]
|
||||
|
||||
# Reverse scan: find the message.id and index of the last assistant entry.
|
||||
last_asst_msg_id: str | None = None
|
||||
last_asst_idx: int | None = None
|
||||
for i in range(len(parsed) - 1, -1, -1):
|
||||
entry = parsed[i]
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
msg = entry.get("message", {})
|
||||
if msg.get("role") == "assistant":
|
||||
last_asst_idx = i
|
||||
last_asst_msg_id = msg.get("id")
|
||||
break
|
||||
|
||||
if last_asst_idx is None:
|
||||
return lines, []
|
||||
|
||||
# If the assistant entry has no message.id, fall back to preserving
|
||||
# from that single entry onward — safer than compressing everything.
|
||||
if last_asst_msg_id is None:
|
||||
return lines[:last_asst_idx], lines[last_asst_idx:]
|
||||
|
||||
# Forward scan: find the first entry of this turn (same message.id).
|
||||
first_turn_idx: int | None = None
|
||||
for i, entry in enumerate(parsed):
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
msg = entry.get("message", {})
|
||||
if msg.get("role") == "assistant" and msg.get("id") == last_asst_msg_id:
|
||||
first_turn_idx = i
|
||||
break
|
||||
|
||||
if first_turn_idx is None:
|
||||
return lines, []
|
||||
return lines[:first_turn_idx], lines[first_turn_idx:]
|
||||
|
||||
|
||||
async def compact_transcript(
|
||||
content: str,
|
||||
*,
|
||||
@@ -816,42 +889,50 @@ async def compact_transcript(
|
||||
Converts transcript entries to plain messages, runs ``compress_context``
|
||||
(the same compressor used for pre-query history), and rebuilds JSONL.
|
||||
|
||||
Structured content (``tool_use`` blocks, ``tool_result`` nesting, images)
|
||||
is flattened to plain text for compression. This matches the fidelity of
|
||||
the Plan C (DB compression) fallback path, where
|
||||
``_format_conversation_context`` similarly renders tool calls as
|
||||
``You called tool: name(args)`` and results as ``Tool result: ...``.
|
||||
Neither path preserves structured API content blocks — the compacted
|
||||
context serves as text history for the LLM, which creates proper
|
||||
structured tool calls going forward.
|
||||
The **last assistant entry** (and any entries after it) are preserved
|
||||
verbatim — never flattened or compressed. The Anthropic API requires
|
||||
``thinking`` and ``redacted_thinking`` blocks in the latest assistant
|
||||
message to be value-identical to the original response (the API
|
||||
validates parsed signature values, not raw JSON bytes); compressing
|
||||
them would destroy the cryptographic signatures and cause
|
||||
``invalid_request_error``.
|
||||
|
||||
Images are per-turn attachments loaded from workspace storage by file ID
|
||||
(via ``_prepare_file_attachments``), not part of the conversation history.
|
||||
They are re-attached each turn and are unaffected by compaction.
|
||||
Structured content in *older* assistant entries (``tool_use`` blocks,
|
||||
``thinking`` blocks, ``tool_result`` nesting, images) is flattened to
|
||||
plain text for compression. This matches the fidelity of the Plan C
|
||||
(DB compression) fallback path.
|
||||
|
||||
Returns the compacted JSONL string, or ``None`` on failure.
|
||||
|
||||
See also:
|
||||
``_compress_messages`` in ``service.py`` — compresses ``ChatMessage``
|
||||
lists for pre-query DB history. Both share ``compress_context()``
|
||||
but operate on different input formats (JSONL transcript entries
|
||||
here vs. ChatMessage dicts there).
|
||||
lists for pre-query DB history.
|
||||
"""
|
||||
messages = _transcript_to_messages(content)
|
||||
if len(messages) < 2:
|
||||
logger.warning("%s Too few messages to compact (%d)", log_prefix, len(messages))
|
||||
prefix_lines, tail_lines = _find_last_assistant_entry(content)
|
||||
|
||||
# Build the JSONL string for the compressible prefix
|
||||
prefix_content = "\n".join(prefix_lines) + "\n" if prefix_lines else ""
|
||||
messages = _transcript_to_messages(prefix_content) if prefix_content else []
|
||||
|
||||
if len(messages) + len(tail_lines) < 2:
|
||||
total = len(messages) + len(tail_lines)
|
||||
logger.warning("%s Too few messages to compact (%d)", log_prefix, total)
|
||||
return None
|
||||
if not messages:
|
||||
logger.warning("%s Nothing to compress (only tail entries remain)", log_prefix)
|
||||
return None
|
||||
try:
|
||||
result = await _run_compression(messages, model, log_prefix)
|
||||
if not result.was_compacted:
|
||||
# Compressor says it's within budget, but the SDK rejected it.
|
||||
# Return None so the caller falls through to DB fallback.
|
||||
logger.warning(
|
||||
"%s Compressor reports within budget but SDK rejected — "
|
||||
"signalling failure",
|
||||
log_prefix,
|
||||
)
|
||||
return None
|
||||
if not result.messages:
|
||||
logger.warning("%s Compressor returned empty messages", log_prefix)
|
||||
return None
|
||||
logger.info(
|
||||
"%s Compacted transcript: %d->%d tokens (%d summarized, %d dropped)",
|
||||
log_prefix,
|
||||
@@ -860,7 +941,29 @@ async def compact_transcript(
|
||||
result.messages_summarized,
|
||||
result.messages_dropped,
|
||||
)
|
||||
compacted = _messages_to_transcript(result.messages)
|
||||
compressed_part = _messages_to_transcript(result.messages)
|
||||
|
||||
# Re-append the preserved tail (last assistant + trailing entries)
|
||||
# with parentUuid patched to chain onto the compressed prefix.
|
||||
tail_part = _rechain_tail(compressed_part, tail_lines)
|
||||
compacted = compressed_part + tail_part
|
||||
|
||||
if len(compacted) >= len(content):
|
||||
# Byte count can increase due to preserved tail entries
|
||||
# (thinking blocks, JSON overhead) even when token count
|
||||
# decreased. Log a warning but still return — the API
|
||||
# validates tokens not bytes, and the caller falls through
|
||||
# to DB fallback if the transcript is still too large.
|
||||
logger.warning(
|
||||
"%s Compacted transcript (%d bytes) is not smaller than "
|
||||
"original (%d bytes) — may still reduce token count",
|
||||
log_prefix,
|
||||
len(compacted),
|
||||
len(content),
|
||||
)
|
||||
# Authoritative validation — the caller (_reduce_context) also
|
||||
# validates, but this is the canonical check that guarantees we
|
||||
# never return a malformed transcript from this function.
|
||||
if not validate_transcript(compacted):
|
||||
logger.warning("%s Compacted transcript failed validation", log_prefix)
|
||||
return None
|
||||
@@ -870,3 +973,43 @@ async def compact_transcript(
|
||||
"%s Transcript compaction failed: %s", log_prefix, e, exc_info=True
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _rechain_tail(compressed_prefix: str, tail_lines: list[str]) -> str:
|
||||
"""Patch tail entries so their parentUuid chain links to the compressed prefix.
|
||||
|
||||
The first tail entry's ``parentUuid`` is set to the ``uuid`` of the
|
||||
last entry in the compressed prefix. Subsequent tail entries are
|
||||
rechained to point to their predecessor in the tail — their original
|
||||
``parentUuid`` values may reference entries that were compressed away.
|
||||
"""
|
||||
if not tail_lines:
|
||||
return ""
|
||||
# Find the last uuid in the compressed prefix
|
||||
last_prefix_uuid = ""
|
||||
for line in reversed(compressed_prefix.strip().split("\n")):
|
||||
if not line.strip():
|
||||
continue
|
||||
entry = json.loads(line, fallback=None)
|
||||
if isinstance(entry, dict) and "uuid" in entry:
|
||||
last_prefix_uuid = entry["uuid"]
|
||||
break
|
||||
|
||||
result_lines: list[str] = []
|
||||
prev_uuid: str | None = None
|
||||
for i, line in enumerate(tail_lines):
|
||||
entry = json.loads(line, fallback=None)
|
||||
if not isinstance(entry, dict):
|
||||
# Safety guard: _find_last_assistant_entry already filters empty
|
||||
# lines, and well-formed JSONL always parses to dicts. Non-dict
|
||||
# lines are passed through unchanged; prev_uuid is intentionally
|
||||
# NOT updated so the next dict entry chains to the last known uuid.
|
||||
result_lines.append(line)
|
||||
continue
|
||||
if i == 0:
|
||||
entry["parentUuid"] = last_prefix_uuid
|
||||
elif prev_uuid is not None:
|
||||
entry["parentUuid"] = prev_uuid
|
||||
prev_uuid = entry.get("uuid")
|
||||
result_lines.append(json.dumps(entry, separators=(",", ":")))
|
||||
return "\n".join(result_lines) + "\n"
|
||||
|
||||
@@ -537,7 +537,7 @@ async def check_hitl_review(
|
||||
)
|
||||
|
||||
synthetic_node_exec_id = (
|
||||
f"{synthetic_node_id}{COPILOT_NODE_EXEC_ID_SEPARATOR}" f"{uuid.uuid4().hex[:8]}"
|
||||
f"{synthetic_node_id}{COPILOT_NODE_EXEC_ID_SEPARATOR}{uuid.uuid4().hex[:8]}"
|
||||
)
|
||||
|
||||
review_context = ExecutionContext(
|
||||
@@ -582,7 +582,16 @@ def _resolve_discriminated_credentials(
|
||||
block: AnyBlockSchema,
|
||||
input_data: dict[str, Any],
|
||||
) -> dict[str, CredentialsFieldInfo]:
|
||||
"""Resolve credential requirements, applying discriminator logic where needed."""
|
||||
"""Resolve credential requirements, applying discriminator logic where needed.
|
||||
|
||||
Handles two discrimination modes:
|
||||
1. **Provider-based** (``discriminator_mapping`` is set): the discriminator
|
||||
field value selects the provider (e.g. an AI model name -> provider).
|
||||
2. **URL/host-based** (``discriminator`` is set but ``discriminator_mapping``
|
||||
is ``None``): the discriminator field value (typically a URL) is added to
|
||||
``discriminator_values`` so that host-scoped credential matching can
|
||||
compare the credential's host against the target URL.
|
||||
"""
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
if not credentials_fields_info:
|
||||
return {}
|
||||
@@ -592,25 +601,42 @@ def _resolve_discriminated_credentials(
|
||||
for field_name, field_info in credentials_fields_info.items():
|
||||
effective_field_info = field_info
|
||||
|
||||
if field_info.discriminator and field_info.discriminator_mapping:
|
||||
if field_info.discriminator:
|
||||
discriminator_value = input_data.get(field_info.discriminator)
|
||||
if discriminator_value is None:
|
||||
field = block.input_schema.model_fields.get(field_info.discriminator)
|
||||
if field and field.default is not PydanticUndefined:
|
||||
discriminator_value = field.default
|
||||
|
||||
if (
|
||||
discriminator_value
|
||||
and discriminator_value in field_info.discriminator_mapping
|
||||
):
|
||||
effective_field_info = field_info.discriminate(discriminator_value)
|
||||
effective_field_info.discriminator_values.add(discriminator_value)
|
||||
logger.debug(
|
||||
"Discriminated provider for %s: %s -> %s",
|
||||
field_name,
|
||||
discriminator_value,
|
||||
effective_field_info.provider,
|
||||
)
|
||||
if discriminator_value is not None:
|
||||
if field_info.discriminator_mapping:
|
||||
# Provider-based discrimination (e.g. model -> provider)
|
||||
if discriminator_value in field_info.discriminator_mapping:
|
||||
effective_field_info = field_info.discriminate(
|
||||
discriminator_value
|
||||
)
|
||||
effective_field_info.discriminator_values.add(
|
||||
discriminator_value
|
||||
)
|
||||
# Model names are safe to log (not PII); URLs are
|
||||
# intentionally omitted in the host-based branch below.
|
||||
logger.debug(
|
||||
"Discriminated provider for %s: %s -> %s",
|
||||
field_name,
|
||||
discriminator_value,
|
||||
effective_field_info.provider,
|
||||
)
|
||||
else:
|
||||
# URL/host-based discrimination (e.g. url -> host matching).
|
||||
# Deep copy to avoid mutating the cached schema-level
|
||||
# field_info (model_copy() is shallow — the mutable set
|
||||
# would be shared).
|
||||
effective_field_info = field_info.model_copy(deep=True)
|
||||
effective_field_info.discriminator_values.add(discriminator_value)
|
||||
logger.debug(
|
||||
"Added discriminator value for host matching on %s",
|
||||
field_name,
|
||||
)
|
||||
|
||||
resolved[field_name] = effective_field_info
|
||||
|
||||
|
||||
@@ -0,0 +1,916 @@
|
||||
"""Tests for credential resolution across all credential types in the CoPilot.
|
||||
|
||||
These tests verify that:
|
||||
1. `_resolve_discriminated_credentials` correctly populates discriminator_values
|
||||
for URL-based (host-scoped) and provider-based (api_key) credential fields.
|
||||
2. `find_matching_credential` correctly matches credentials for all types:
|
||||
APIKeyCredentials, OAuth2Credentials, UserPasswordCredentials, and
|
||||
HostScopedCredentials.
|
||||
3. The full `resolve_block_credentials` flow correctly resolves matching
|
||||
credentials or reports them as missing for each credential type.
|
||||
4. `RunBlockTool._execute` end-to-end tests return correct response types.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.http import SendAuthenticatedWebRequestBlock
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsFieldInfo,
|
||||
CredentialsType,
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
from ._test_data import make_session
|
||||
from .helpers import _resolve_discriminated_credentials, resolve_block_credentials
|
||||
from .models import BlockDetailsResponse, SetupRequirementsResponse
|
||||
from .run_block import RunBlockTool
|
||||
from .utils import find_matching_credential
|
||||
|
||||
_TEST_USER_ID = "test-user-http-cred"
|
||||
|
||||
# Properly typed constants to avoid type: ignore on CredentialsFieldInfo construction.
|
||||
_HOST_SCOPED_TYPES: frozenset[CredentialsType] = frozenset(["host_scoped"])
|
||||
_API_KEY_TYPES: frozenset[CredentialsType] = frozenset(["api_key"])
|
||||
_OAUTH2_TYPES: frozenset[CredentialsType] = frozenset(["oauth2"])
|
||||
_USER_PASSWORD_TYPES: frozenset[CredentialsType] = frozenset(["user_password"])
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _resolve_discriminated_credentials tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveDiscriminatedCredentials:
|
||||
"""Tests for _resolve_discriminated_credentials with URL-based discrimination."""
|
||||
|
||||
def _get_auth_block(self):
|
||||
return SendAuthenticatedWebRequestBlock()
|
||||
|
||||
def test_url_discriminator_populates_discriminator_values(self):
|
||||
"""When input_data contains a URL, discriminator_values should include it."""
|
||||
block = self._get_auth_block()
|
||||
input_data = {"url": "https://api.example.com/v1/data"}
|
||||
|
||||
result = _resolve_discriminated_credentials(block, input_data)
|
||||
|
||||
assert "credentials" in result
|
||||
field_info = result["credentials"]
|
||||
assert "https://api.example.com/v1/data" in field_info.discriminator_values
|
||||
|
||||
def test_url_discriminator_without_url_keeps_empty_values(self):
|
||||
"""When no URL is provided, discriminator_values should remain empty."""
|
||||
block = self._get_auth_block()
|
||||
input_data = {}
|
||||
|
||||
result = _resolve_discriminated_credentials(block, input_data)
|
||||
|
||||
assert "credentials" in result
|
||||
field_info = result["credentials"]
|
||||
assert len(field_info.discriminator_values) == 0
|
||||
|
||||
def test_url_discriminator_does_not_mutate_original_field_info(self):
|
||||
"""The original block schema field_info must not be mutated."""
|
||||
block = self._get_auth_block()
|
||||
|
||||
# Grab a reference to the original schema-level field_info
|
||||
original_info = block.input_schema.get_credentials_fields_info()["credentials"]
|
||||
|
||||
# Call with a URL, which adds to discriminator_values on the copy
|
||||
_resolve_discriminated_credentials(
|
||||
block, {"url": "https://api.example.com/v1/data"}
|
||||
)
|
||||
|
||||
# The original object must remain unchanged
|
||||
assert len(original_info.discriminator_values) == 0
|
||||
|
||||
# And a fresh call without URL should also return empty values
|
||||
result = _resolve_discriminated_credentials(block, {})
|
||||
field_info = result["credentials"]
|
||||
assert len(field_info.discriminator_values) == 0
|
||||
|
||||
def test_url_discriminator_preserves_provider_and_type(self):
|
||||
"""Provider and supported_types should be preserved after URL discrimination."""
|
||||
block = self._get_auth_block()
|
||||
input_data = {"url": "https://api.example.com/v1/data"}
|
||||
|
||||
result = _resolve_discriminated_credentials(block, input_data)
|
||||
|
||||
field_info = result["credentials"]
|
||||
assert ProviderName.HTTP in field_info.provider
|
||||
assert "host_scoped" in field_info.supported_types
|
||||
|
||||
def test_provider_discriminator_still_works(self):
|
||||
"""Verify provider-based discrimination (e.g. model -> provider) is preserved.
|
||||
|
||||
The refactored conditional in _resolve_discriminated_credentials split the
|
||||
original single ``if`` into nested ``if/else`` branches. This test ensures
|
||||
the provider-based path still narrows the provider correctly.
|
||||
"""
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
|
||||
block = AITextGeneratorBlock()
|
||||
input_data = {"model": "gpt-4o-mini"}
|
||||
|
||||
result = _resolve_discriminated_credentials(block, input_data)
|
||||
|
||||
assert "credentials" in result
|
||||
field_info = result["credentials"]
|
||||
# Should narrow provider to openai
|
||||
assert ProviderName.OPENAI in field_info.provider
|
||||
assert "gpt-4o-mini" in field_info.discriminator_values
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_matching_credential tests (host-scoped)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFindMatchingHostScopedCredential:
|
||||
"""Tests for find_matching_credential with host-scoped credentials."""
|
||||
|
||||
def _make_host_scoped_cred(
|
||||
self, host: str, cred_id: str = "test-cred-id"
|
||||
) -> HostScopedCredentials:
|
||||
return HostScopedCredentials(
|
||||
id=cred_id,
|
||||
provider="http",
|
||||
host=host,
|
||||
headers={"Authorization": SecretStr("Bearer test-token")},
|
||||
title=f"Cred for {host}",
|
||||
)
|
||||
|
||||
def _make_field_info(
|
||||
self, discriminator_values: set | None = None
|
||||
) -> CredentialsFieldInfo:
|
||||
return CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([ProviderName.HTTP]),
|
||||
credentials_types=_HOST_SCOPED_TYPES,
|
||||
credentials_scopes=None,
|
||||
discriminator="url",
|
||||
discriminator_values=discriminator_values or set(),
|
||||
)
|
||||
|
||||
def test_matches_credential_for_correct_host(self):
|
||||
"""A host-scoped credential matching the URL host should be returned."""
|
||||
cred = self._make_host_scoped_cred("api.example.com")
|
||||
field_info = self._make_field_info({"https://api.example.com/v1/data"})
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
assert result.id == cred.id
|
||||
|
||||
def test_rejects_credential_for_wrong_host(self):
|
||||
"""A host-scoped credential for a different host should not match."""
|
||||
cred = self._make_host_scoped_cred("api.github.com")
|
||||
field_info = self._make_field_info({"https://api.stripe.com/v1/charges"})
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is None
|
||||
|
||||
def test_matches_any_when_no_discriminator_values(self):
|
||||
"""With empty discriminator_values, any host-scoped credential matches.
|
||||
|
||||
Note: this tests the current fallback behavior in _credential_is_for_host()
|
||||
where empty discriminator_values means "no host constraint" and any
|
||||
host-scoped credential is accepted. This is by design for the case where
|
||||
the target URL is not yet known (e.g. schema preview with empty input).
|
||||
"""
|
||||
cred = self._make_host_scoped_cred("api.anything.com")
|
||||
field_info = self._make_field_info(set())
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
|
||||
def test_wildcard_host_matching(self):
|
||||
"""Wildcard host (*.example.com) should match subdomains."""
|
||||
cred = self._make_host_scoped_cred("*.example.com")
|
||||
field_info = self._make_field_info({"https://api.example.com/v1/data"})
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
|
||||
def test_selects_correct_credential_from_multiple(self):
|
||||
"""When multiple host-scoped credentials exist, the correct one is selected."""
|
||||
cred_github = self._make_host_scoped_cred("api.github.com", "github-cred")
|
||||
cred_stripe = self._make_host_scoped_cred("api.stripe.com", "stripe-cred")
|
||||
field_info = self._make_field_info({"https://api.stripe.com/v1/charges"})
|
||||
|
||||
result = find_matching_credential([cred_github, cred_stripe], field_info)
|
||||
assert result is not None
|
||||
assert result.id == "stripe-cred"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_matching_credential tests (api_key)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFindMatchingAPIKeyCredential:
|
||||
"""Tests for find_matching_credential with API key credentials."""
|
||||
|
||||
def _make_api_key_cred(
|
||||
self, provider: str = "google_maps", cred_id: str = "test-api-key-id"
|
||||
) -> APIKeyCredentials:
|
||||
return APIKeyCredentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
api_key=SecretStr("sk-test-key-123"),
|
||||
title=f"API key for {provider}",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
def _make_field_info(
|
||||
self, provider: ProviderName = ProviderName.GOOGLE_MAPS
|
||||
) -> CredentialsFieldInfo:
|
||||
return CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([provider]),
|
||||
credentials_types=_API_KEY_TYPES,
|
||||
credentials_scopes=None,
|
||||
)
|
||||
|
||||
def test_matches_credential_for_correct_provider(self):
|
||||
"""An API key credential matching the provider should be returned."""
|
||||
cred = self._make_api_key_cred("google_maps")
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE_MAPS)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
assert result.id == cred.id
|
||||
|
||||
def test_rejects_credential_for_wrong_provider(self):
|
||||
"""An API key credential for a different provider should not match."""
|
||||
cred = self._make_api_key_cred("openai")
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE_MAPS)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is None
|
||||
|
||||
def test_rejects_credential_for_wrong_type(self):
|
||||
"""An OAuth2 credential should not match an api_key requirement."""
|
||||
oauth_cred = OAuth2Credentials(
|
||||
id="oauth-cred-id",
|
||||
provider="google_maps",
|
||||
access_token=SecretStr("mock-token"),
|
||||
scopes=[],
|
||||
title="OAuth cred (wrong type)",
|
||||
)
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE_MAPS)
|
||||
|
||||
result = find_matching_credential([oauth_cred], field_info)
|
||||
assert result is None
|
||||
|
||||
def test_selects_correct_credential_from_multiple(self):
|
||||
"""When multiple API key credentials exist, the correct provider is selected."""
|
||||
cred_maps = self._make_api_key_cred("google_maps", "maps-key")
|
||||
cred_openai = self._make_api_key_cred("openai", "openai-key")
|
||||
field_info = self._make_field_info(ProviderName.OPENAI)
|
||||
|
||||
result = find_matching_credential([cred_maps, cred_openai], field_info)
|
||||
assert result is not None
|
||||
assert result.id == "openai-key"
|
||||
|
||||
def test_returns_none_when_no_credentials(self):
|
||||
"""Should return None when the credential list is empty."""
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE_MAPS)
|
||||
|
||||
result = find_matching_credential([], field_info)
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_matching_credential tests (oauth2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFindMatchingOAuth2Credential:
|
||||
"""Tests for find_matching_credential with OAuth2 credentials."""
|
||||
|
||||
def _make_oauth2_cred(
|
||||
self,
|
||||
provider: str = "google",
|
||||
scopes: list[str] | None = None,
|
||||
cred_id: str = "test-oauth2-id",
|
||||
) -> OAuth2Credentials:
|
||||
return OAuth2Credentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
access_token=SecretStr("mock-access-token"),
|
||||
refresh_token=SecretStr("mock-refresh-token"),
|
||||
access_token_expires_at=1234567890,
|
||||
scopes=scopes or [],
|
||||
title=f"OAuth2 cred for {provider}",
|
||||
)
|
||||
|
||||
def _make_field_info(
|
||||
self,
|
||||
provider: ProviderName = ProviderName.GOOGLE,
|
||||
required_scopes: frozenset[str] | None = None,
|
||||
) -> CredentialsFieldInfo:
|
||||
return CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([provider]),
|
||||
credentials_types=_OAUTH2_TYPES,
|
||||
credentials_scopes=required_scopes,
|
||||
)
|
||||
|
||||
def test_matches_credential_for_correct_provider(self):
|
||||
"""An OAuth2 credential matching the provider should be returned."""
|
||||
cred = self._make_oauth2_cred("google")
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
assert result.id == cred.id
|
||||
|
||||
def test_rejects_credential_for_wrong_provider(self):
|
||||
"""An OAuth2 credential for a different provider should not match."""
|
||||
cred = self._make_oauth2_cred("github")
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is None
|
||||
|
||||
def test_matches_credential_with_required_scopes(self):
|
||||
"""An OAuth2 credential with all required scopes should match."""
|
||||
cred = self._make_oauth2_cred(
|
||||
"google",
|
||||
scopes=[
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
"https://www.googleapis.com/auth/gmail.send",
|
||||
],
|
||||
)
|
||||
field_info = self._make_field_info(
|
||||
ProviderName.GOOGLE,
|
||||
required_scopes=frozenset(
|
||||
["https://www.googleapis.com/auth/gmail.readonly"]
|
||||
),
|
||||
)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
|
||||
def test_rejects_credential_with_insufficient_scopes(self):
|
||||
"""An OAuth2 credential missing required scopes should not match."""
|
||||
cred = self._make_oauth2_cred(
|
||||
"google",
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
)
|
||||
field_info = self._make_field_info(
|
||||
ProviderName.GOOGLE,
|
||||
required_scopes=frozenset(
|
||||
[
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
"https://www.googleapis.com/auth/gmail.send",
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is None
|
||||
|
||||
def test_matches_credential_when_no_scopes_required(self):
|
||||
"""An OAuth2 credential should match when no scopes are required."""
|
||||
cred = self._make_oauth2_cred("google", scopes=[])
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
|
||||
def test_selects_correct_credential_from_multiple(self):
|
||||
"""When multiple OAuth2 credentials exist, the correct one is selected."""
|
||||
cred_google = self._make_oauth2_cred("google", cred_id="google-cred")
|
||||
cred_github = self._make_oauth2_cred("github", cred_id="github-cred")
|
||||
field_info = self._make_field_info(ProviderName.GITHUB)
|
||||
|
||||
result = find_matching_credential([cred_google, cred_github], field_info)
|
||||
assert result is not None
|
||||
assert result.id == "github-cred"
|
||||
|
||||
def test_returns_none_when_no_credentials(self):
|
||||
"""Should return None when the credential list is empty."""
|
||||
field_info = self._make_field_info(ProviderName.GOOGLE)
|
||||
|
||||
result = find_matching_credential([], field_info)
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_matching_credential tests (user_password)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFindMatchingUserPasswordCredential:
|
||||
"""Tests for find_matching_credential with user/password credentials."""
|
||||
|
||||
def _make_user_password_cred(
|
||||
self, provider: str = "smtp", cred_id: str = "test-userpass-id"
|
||||
) -> UserPasswordCredentials:
|
||||
return UserPasswordCredentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
username=SecretStr("test-user"),
|
||||
password=SecretStr("test-pass"),
|
||||
title=f"Credentials for {provider}",
|
||||
)
|
||||
|
||||
def _make_field_info(
|
||||
self, provider: ProviderName = ProviderName.SMTP
|
||||
) -> CredentialsFieldInfo:
|
||||
return CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([provider]),
|
||||
credentials_types=_USER_PASSWORD_TYPES,
|
||||
credentials_scopes=None,
|
||||
)
|
||||
|
||||
def test_matches_credential_for_correct_provider(self):
|
||||
"""A user/password credential matching the provider should be returned."""
|
||||
cred = self._make_user_password_cred("smtp")
|
||||
field_info = self._make_field_info(ProviderName.SMTP)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
assert result.id == cred.id
|
||||
|
||||
def test_rejects_credential_for_wrong_provider(self):
|
||||
"""A user/password credential for a different provider should not match."""
|
||||
cred = self._make_user_password_cred("smtp")
|
||||
field_info = self._make_field_info(ProviderName.HUBSPOT)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is None
|
||||
|
||||
def test_rejects_credential_for_wrong_type(self):
|
||||
"""An API key credential should not match a user_password requirement."""
|
||||
api_key_cred = APIKeyCredentials(
|
||||
id="api-key-cred-id",
|
||||
provider="smtp",
|
||||
api_key=SecretStr("wrong-type-key"),
|
||||
title="API key cred (wrong type)",
|
||||
)
|
||||
field_info = self._make_field_info(ProviderName.SMTP)
|
||||
|
||||
result = find_matching_credential([api_key_cred], field_info)
|
||||
assert result is None
|
||||
|
||||
def test_selects_correct_credential_from_multiple(self):
|
||||
"""When multiple user/password credentials exist, the correct one is selected."""
|
||||
cred_smtp = self._make_user_password_cred("smtp", "smtp-cred")
|
||||
cred_hubspot = self._make_user_password_cred("hubspot", "hubspot-cred")
|
||||
field_info = self._make_field_info(ProviderName.HUBSPOT)
|
||||
|
||||
result = find_matching_credential([cred_smtp, cred_hubspot], field_info)
|
||||
assert result is not None
|
||||
assert result.id == "hubspot-cred"
|
||||
|
||||
def test_returns_none_when_no_credentials(self):
|
||||
"""Should return None when the credential list is empty."""
|
||||
field_info = self._make_field_info(ProviderName.SMTP)
|
||||
|
||||
result = find_matching_credential([], field_info)
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_matching_credential tests (mixed credential types)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFindMatchingCredentialMixedTypes:
|
||||
"""Tests that find_matching_credential correctly filters by type in a mixed list."""
|
||||
|
||||
def test_selects_api_key_from_mixed_list(self):
|
||||
"""API key requirement should skip OAuth2 and user_password credentials."""
|
||||
oauth_cred = OAuth2Credentials(
|
||||
id="oauth-id",
|
||||
provider="openai",
|
||||
access_token=SecretStr("token"),
|
||||
scopes=[],
|
||||
)
|
||||
userpass_cred = UserPasswordCredentials(
|
||||
id="userpass-id",
|
||||
provider="openai",
|
||||
username=SecretStr("user"),
|
||||
password=SecretStr("pass"),
|
||||
)
|
||||
api_key_cred = APIKeyCredentials(
|
||||
id="apikey-id",
|
||||
provider="openai",
|
||||
api_key=SecretStr("sk-key"),
|
||||
)
|
||||
field_info = CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([ProviderName.OPENAI]),
|
||||
credentials_types=_API_KEY_TYPES,
|
||||
credentials_scopes=None,
|
||||
)
|
||||
|
||||
result = find_matching_credential(
|
||||
[oauth_cred, userpass_cred, api_key_cred], field_info
|
||||
)
|
||||
assert result is not None
|
||||
assert result.id == "apikey-id"
|
||||
|
||||
def test_selects_oauth2_from_mixed_list(self):
|
||||
"""OAuth2 requirement should skip API key and user_password credentials."""
|
||||
api_key_cred = APIKeyCredentials(
|
||||
id="apikey-id",
|
||||
provider="google",
|
||||
api_key=SecretStr("key"),
|
||||
)
|
||||
userpass_cred = UserPasswordCredentials(
|
||||
id="userpass-id",
|
||||
provider="google",
|
||||
username=SecretStr("user"),
|
||||
password=SecretStr("pass"),
|
||||
)
|
||||
oauth_cred = OAuth2Credentials(
|
||||
id="oauth-id",
|
||||
provider="google",
|
||||
access_token=SecretStr("token"),
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
)
|
||||
field_info = CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([ProviderName.GOOGLE]),
|
||||
credentials_types=_OAUTH2_TYPES,
|
||||
credentials_scopes=frozenset(
|
||||
["https://www.googleapis.com/auth/gmail.readonly"]
|
||||
),
|
||||
)
|
||||
|
||||
result = find_matching_credential(
|
||||
[api_key_cred, userpass_cred, oauth_cred], field_info
|
||||
)
|
||||
assert result is not None
|
||||
assert result.id == "oauth-id"
|
||||
|
||||
def test_selects_user_password_from_mixed_list(self):
|
||||
"""User/password requirement should skip API key and OAuth2 credentials."""
|
||||
api_key_cred = APIKeyCredentials(
|
||||
id="apikey-id",
|
||||
provider="smtp",
|
||||
api_key=SecretStr("key"),
|
||||
)
|
||||
oauth_cred = OAuth2Credentials(
|
||||
id="oauth-id",
|
||||
provider="smtp",
|
||||
access_token=SecretStr("token"),
|
||||
scopes=[],
|
||||
)
|
||||
userpass_cred = UserPasswordCredentials(
|
||||
id="userpass-id",
|
||||
provider="smtp",
|
||||
username=SecretStr("user"),
|
||||
password=SecretStr("pass"),
|
||||
)
|
||||
field_info = CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([ProviderName.SMTP]),
|
||||
credentials_types=_USER_PASSWORD_TYPES,
|
||||
credentials_scopes=None,
|
||||
)
|
||||
|
||||
result = find_matching_credential(
|
||||
[api_key_cred, oauth_cred, userpass_cred], field_info
|
||||
)
|
||||
assert result is not None
|
||||
assert result.id == "userpass-id"
|
||||
|
||||
def test_returns_none_when_only_wrong_types_available(self):
|
||||
"""Should return None when all available creds have the wrong type."""
|
||||
oauth_cred = OAuth2Credentials(
|
||||
id="oauth-id",
|
||||
provider="google_maps",
|
||||
access_token=SecretStr("token"),
|
||||
scopes=[],
|
||||
)
|
||||
field_info = CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([ProviderName.GOOGLE_MAPS]),
|
||||
credentials_types=_API_KEY_TYPES,
|
||||
credentials_scopes=None,
|
||||
)
|
||||
|
||||
result = find_matching_credential([oauth_cred], field_info)
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_block_credentials tests (integration — all credential types)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveBlockCredentials:
|
||||
"""Integration tests for resolve_block_credentials across credential types."""
|
||||
|
||||
async def test_matches_host_scoped_credential_for_url(self):
|
||||
"""resolve_block_credentials should match a host-scoped cred for the given URL."""
|
||||
block = SendAuthenticatedWebRequestBlock()
|
||||
input_data = {"url": "https://api.example.com/v1/data"}
|
||||
|
||||
mock_cred = HostScopedCredentials(
|
||||
id="matching-cred-id",
|
||||
provider="http",
|
||||
host="api.example.com",
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
title="Example API Cred",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[mock_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert "credentials" in matched
|
||||
assert matched["credentials"].id == "matching-cred-id"
|
||||
assert len(missing) == 0
|
||||
|
||||
async def test_reports_missing_when_no_matching_host(self):
|
||||
"""resolve_block_credentials should report missing creds when host doesn't match."""
|
||||
block = SendAuthenticatedWebRequestBlock()
|
||||
input_data = {"url": "https://api.stripe.com/v1/charges"}
|
||||
|
||||
wrong_host_cred = HostScopedCredentials(
|
||||
id="wrong-cred-id",
|
||||
provider="http",
|
||||
host="api.github.com",
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
title="GitHub API Cred",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[wrong_host_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert len(matched) == 0
|
||||
assert len(missing) == 1
|
||||
|
||||
async def test_reports_missing_when_no_credentials(self):
|
||||
"""resolve_block_credentials should report missing when user has no creds at all."""
|
||||
block = SendAuthenticatedWebRequestBlock()
|
||||
input_data = {"url": "https://api.example.com/v1/data"}
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert len(matched) == 0
|
||||
assert len(missing) == 1
|
||||
|
||||
async def test_matches_api_key_credential_for_llm_block(self):
|
||||
"""resolve_block_credentials should match an API key cred for an LLM block."""
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
|
||||
block = AITextGeneratorBlock()
|
||||
input_data = {"model": "gpt-4o-mini"}
|
||||
|
||||
mock_cred = APIKeyCredentials(
|
||||
id="openai-key-id",
|
||||
provider="openai",
|
||||
api_key=SecretStr("sk-test-key"),
|
||||
title="OpenAI API Key",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[mock_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert "credentials" in matched
|
||||
assert matched["credentials"].id == "openai-key-id"
|
||||
assert len(missing) == 0
|
||||
|
||||
async def test_reports_missing_api_key_for_wrong_provider(self):
|
||||
"""resolve_block_credentials should report missing when API key provider mismatches."""
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
|
||||
block = AITextGeneratorBlock()
|
||||
input_data = {"model": "gpt-4o-mini"}
|
||||
|
||||
wrong_provider_cred = APIKeyCredentials(
|
||||
id="wrong-key-id",
|
||||
provider="google_maps",
|
||||
api_key=SecretStr("sk-wrong"),
|
||||
title="Google Maps Key",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[wrong_provider_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert len(matched) == 0
|
||||
assert len(missing) == 1
|
||||
|
||||
async def test_matches_oauth2_credential_for_google_block(self):
|
||||
"""resolve_block_credentials should match an OAuth2 cred for a Google block."""
|
||||
from backend.blocks.google.gmail import GmailReadBlock
|
||||
|
||||
block = GmailReadBlock()
|
||||
input_data = {}
|
||||
|
||||
mock_cred = OAuth2Credentials(
|
||||
id="google-oauth-id",
|
||||
provider="google",
|
||||
access_token=SecretStr("mock-token"),
|
||||
refresh_token=SecretStr("mock-refresh"),
|
||||
access_token_expires_at=9999999999,
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
title="Google OAuth",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[mock_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert "credentials" in matched
|
||||
assert matched["credentials"].id == "google-oauth-id"
|
||||
assert len(missing) == 0
|
||||
|
||||
async def test_reports_missing_oauth2_with_insufficient_scopes(self):
|
||||
"""resolve_block_credentials should report missing when OAuth2 scopes are insufficient."""
|
||||
from backend.blocks.google.gmail import GmailSendBlock
|
||||
|
||||
block = GmailSendBlock()
|
||||
input_data = {}
|
||||
|
||||
# GmailSendBlock requires gmail.send scope; provide only readonly
|
||||
insufficient_cred = OAuth2Credentials(
|
||||
id="limited-oauth-id",
|
||||
provider="google",
|
||||
access_token=SecretStr("mock-token"),
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
title="Google OAuth (limited)",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[insufficient_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert len(matched) == 0
|
||||
assert len(missing) == 1
|
||||
|
||||
async def test_matches_user_password_credential_for_email_block(self):
|
||||
"""resolve_block_credentials should match a user/password cred for an SMTP block."""
|
||||
from backend.blocks.email_block import SendEmailBlock
|
||||
|
||||
block = SendEmailBlock()
|
||||
input_data = {}
|
||||
|
||||
mock_cred = UserPasswordCredentials(
|
||||
id="smtp-cred-id",
|
||||
provider="smtp",
|
||||
username=SecretStr("test-user"),
|
||||
password=SecretStr("test-pass"),
|
||||
title="SMTP Credentials",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[mock_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert "credentials" in matched
|
||||
assert matched["credentials"].id == "smtp-cred-id"
|
||||
assert len(missing) == 0
|
||||
|
||||
async def test_reports_missing_user_password_for_wrong_provider(self):
|
||||
"""resolve_block_credentials should report missing when user/password provider mismatches."""
|
||||
from backend.blocks.email_block import SendEmailBlock
|
||||
|
||||
block = SendEmailBlock()
|
||||
input_data = {}
|
||||
|
||||
wrong_cred = UserPasswordCredentials(
|
||||
id="wrong-cred-id",
|
||||
provider="dataforseo",
|
||||
username=SecretStr("user"),
|
||||
password=SecretStr("pass"),
|
||||
title="DataForSEO Creds",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[wrong_cred],
|
||||
):
|
||||
matched, missing = await resolve_block_credentials(
|
||||
_TEST_USER_ID, block, input_data
|
||||
)
|
||||
|
||||
assert len(matched) == 0
|
||||
assert len(missing) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RunBlockTool integration tests for authenticated HTTP
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunBlockToolAuthenticatedHttp:
|
||||
"""End-to-end tests for RunBlockTool with SendAuthenticatedWebRequestBlock."""
|
||||
|
||||
async def test_returns_setup_requirements_when_creds_missing(self):
|
||||
"""When no matching host-scoped credential exists, return SetupRequirementsResponse."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
block = SendAuthenticatedWebRequestBlock()
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.get_block",
|
||||
return_value=block,
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id=block.id,
|
||||
input_data={"url": "https://api.example.com/data", "method": "GET"},
|
||||
)
|
||||
|
||||
assert isinstance(response, SetupRequirementsResponse)
|
||||
assert "credentials" in response.message.lower()
|
||||
|
||||
async def test_returns_details_when_creds_matched_but_missing_required_inputs(self):
|
||||
"""When creds present + required inputs missing -> BlockDetailsResponse.
|
||||
|
||||
Note: with input_data={}, no URL is provided so discriminator_values is
|
||||
empty, meaning _credential_is_for_host() matches any host-scoped
|
||||
credential vacuously. This test exercises the "creds present + inputs
|
||||
missing" branch, not host-based matching (which is covered by
|
||||
TestFindMatchingHostScopedCredential and TestResolveBlockCredentials).
|
||||
"""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
block = SendAuthenticatedWebRequestBlock()
|
||||
|
||||
mock_cred = HostScopedCredentials(
|
||||
id="matching-cred-id",
|
||||
provider="http",
|
||||
host="api.example.com",
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
title="Example API Cred",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.get_block",
|
||||
return_value=block,
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.utils.get_user_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[mock_cred],
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
# Call with empty input to get schema
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id=block.id,
|
||||
input_data={},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockDetailsResponse)
|
||||
assert response.block.name == block.name
|
||||
# The matched credential should be included in the details
|
||||
assert len(response.block.credentials) > 0
|
||||
assert response.block.credentials[0].id == "matching-cred-id"
|
||||
@@ -121,7 +121,7 @@ def _serialize_missing_credential(
|
||||
provider = next(iter(field_info.provider), "unknown")
|
||||
scopes = sorted(field_info.required_scopes or [])
|
||||
|
||||
return {
|
||||
result: dict[str, Any] = {
|
||||
"id": field_key,
|
||||
"title": field_key.replace("_", " ").title(),
|
||||
"provider": provider,
|
||||
@@ -131,6 +131,17 @@ def _serialize_missing_credential(
|
||||
"scopes": scopes,
|
||||
}
|
||||
|
||||
# Include discriminator info so the frontend can auto-match
|
||||
# host-scoped credentials (e.g. SendAuthenticatedWebRequestBlock).
|
||||
if field_info.discriminator:
|
||||
result["discriminator"] = field_info.discriminator
|
||||
if field_info.discriminator_values:
|
||||
result["discriminator_values"] = sorted(
|
||||
str(v) for v in field_info.discriminator_values
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def build_missing_credentials_from_graph(
|
||||
graph: GraphModel, matched_credentials: dict[str, CredentialsMetaInput] | None
|
||||
|
||||
@@ -722,7 +722,7 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
credentials_scopes=self.required_scopes,
|
||||
discriminator=self.discriminator,
|
||||
discriminator_mapping=self.discriminator_mapping,
|
||||
discriminator_values=self.discriminator_values,
|
||||
discriminator_values=set(self.discriminator_values), # defensive copy
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -41,7 +41,15 @@ export function coerceCredentialFields(rawMissingCredentials: unknown): {
|
||||
? cred.scopes.filter((s): s is string => typeof s === "string")
|
||||
: undefined;
|
||||
|
||||
const schema = {
|
||||
const discriminator =
|
||||
typeof cred.discriminator === "string" ? cred.discriminator : undefined;
|
||||
const discriminatorValues = Array.isArray(cred.discriminator_values)
|
||||
? cred.discriminator_values.filter(
|
||||
(v): v is string => typeof v === "string",
|
||||
)
|
||||
: undefined;
|
||||
|
||||
const schema: Record<string, unknown> = {
|
||||
type: "object" as const,
|
||||
properties: {},
|
||||
credentials_provider: [provider],
|
||||
@@ -49,6 +57,13 @@ export function coerceCredentialFields(rawMissingCredentials: unknown): {
|
||||
credentials_scopes: scopes,
|
||||
};
|
||||
|
||||
if (discriminator) {
|
||||
schema.discriminator = discriminator;
|
||||
}
|
||||
if (discriminatorValues && discriminatorValues.length > 0) {
|
||||
schema.discriminator_values = discriminatorValues;
|
||||
}
|
||||
|
||||
credentialFields.push([key, schema]);
|
||||
requiredCredentials.add(key);
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user