mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
53 Commits
gitlab-eve
...
add-github
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
58b8c66215 | ||
|
|
48ed801b27 | ||
|
|
756caebd27 | ||
|
|
95a5e6da0a | ||
|
|
ea50ade6ec | ||
|
|
420c8d0aa9 | ||
|
|
56bf86ee7f | ||
|
|
eee8655f70 | ||
|
|
8d573d9cc6 | ||
|
|
fd81234ac4 | ||
|
|
080ea0db5e | ||
|
|
5156c580fe | ||
|
|
25262a3a3f | ||
|
|
862c363ded | ||
|
|
cf156b0073 | ||
|
|
703a1eeca2 | ||
|
|
a6573de584 | ||
|
|
23b3b188c4 | ||
|
|
2ff094b363 | ||
|
|
b0169342f7 | ||
|
|
d5036c2813 | ||
|
|
8f0f3e49c8 | ||
|
|
03f49a40a0 | ||
|
|
3a85dbce78 | ||
|
|
4e63531fa6 | ||
|
|
db48a7af26 | ||
|
|
4c8179cd08 | ||
|
|
9e3aed7f53 | ||
|
|
3a40ecb931 | ||
|
|
f8b4f9369f | ||
|
|
5bb6522f2f | ||
|
|
273c38f0b6 | ||
|
|
02b999c166 | ||
|
|
28d26f8178 | ||
|
|
2468708293 | ||
|
|
a89811f952 | ||
|
|
aef5f9cc89 | ||
|
|
aea611602f | ||
|
|
fc4c62a73d | ||
|
|
b41dd2ba8b | ||
|
|
731183e069 | ||
|
|
c22c03eeb6 | ||
|
|
1093afdced | ||
|
|
93355fd770 | ||
|
|
6464eaed3c | ||
|
|
237948978b | ||
|
|
baa3a7e5b7 | ||
|
|
dd7234d712 | ||
|
|
2a6f5c8976 | ||
|
|
e86067c15b | ||
|
|
137bede1f5 | ||
|
|
8a1d80ac8f | ||
|
|
77043da280 |
@@ -46,34 +46,12 @@ These files contain image tags that **must** be updated whenever the SDK version
|
||||
### `openhands/version.py`
|
||||
- Reads version from `pyproject.toml` at runtime → `openhands.__version__`
|
||||
|
||||
### `openhands/resolver/issue_resolver.py`
|
||||
- Builds `ghcr.io/openhands/runtime:{openhands.__version__}-nikolaik` dynamically
|
||||
|
||||
### `openhands/runtime/utils/runtime_build.py`
|
||||
- Base repo URL `ghcr.io/openhands/runtime` is a constant; version comes from elsewhere
|
||||
|
||||
### `.github/scripts/update_pr_description.sh`
|
||||
- Uses `${SHORT_SHA}` variable at CI runtime, not hardcoded
|
||||
|
||||
### `enterprise/Dockerfile`
|
||||
- `ARG BASE="ghcr.io/openhands/openhands"` — base image, version supplied at build time
|
||||
|
||||
## V0 Legacy Files (separate update cadence)
|
||||
|
||||
These reference the V0 runtime image (`ghcr.io/openhands/runtime:X.Y-nikolaik`) for local Docker/Kubernetes paths. They are **not** updated as part of a V1 release but may be updated independently.
|
||||
|
||||
### `Development.md`
|
||||
- `export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/openhands/runtime:X.Y-nikolaik`
|
||||
|
||||
### `openhands/runtime/impl/kubernetes/README.md`
|
||||
- `runtime_container_image = "docker.openhands.dev/openhands/runtime:X.Y-nikolaik"`
|
||||
|
||||
### `enterprise/enterprise_local/README.md`
|
||||
- Uses `ghcr.io/openhands/runtime:main-nikolaik` (points to `main`, not versioned)
|
||||
|
||||
### `third_party/runtime/impl/daytona/README.md`
|
||||
- Uses `${OPENHANDS_VERSION}` variable, not hardcoded
|
||||
|
||||
## Image Registries
|
||||
|
||||
| Registry | Usage |
|
||||
|
||||
228
.github/workflows/e2e-tests.yml
vendored
228
.github/workflows/e2e-tests.yml
vendored
@@ -1,228 +0,0 @@
|
||||
name: End-to-End Tests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, labeled]
|
||||
branches:
|
||||
- main
|
||||
- develop
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
e2e-tests:
|
||||
if: contains(github.event.pull_request.labels.*.name, 'end-to-end') || github.event_name == 'workflow_dispatch'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 60
|
||||
|
||||
env:
|
||||
GITHUB_REPO_NAME: ${{ github.repository }}
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Install poetry via pipx
|
||||
uses: abatilo/actions-poetry@v4
|
||||
with:
|
||||
poetry-version: 2.1.3
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.12'
|
||||
cache: 'poetry'
|
||||
|
||||
- name: Install system dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y libgtk-3-0 libnotify4 libnss3 libxss1 libxtst6 xauth xvfb libgbm1 libasound2t64 netcat-openbsd
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: '22'
|
||||
cache: 'npm'
|
||||
cache-dependency-path: 'frontend/package-lock.json'
|
||||
|
||||
- name: Setup environment for end-to-end tests
|
||||
run: |
|
||||
# Create test results directory
|
||||
mkdir -p test-results
|
||||
|
||||
# Create downloads directory for OpenHands (use a directory in the home folder)
|
||||
mkdir -p $HOME/downloads
|
||||
sudo chown -R $USER:$USER $HOME/downloads
|
||||
sudo chmod -R 755 $HOME/downloads
|
||||
|
||||
- name: Build OpenHands
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL || 'gpt-4o' }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY || 'test-key' }}
|
||||
LLM_BASE_URL: ${{ secrets.LLM_BASE_URL }}
|
||||
INSTALL_DOCKER: 1
|
||||
RUNTIME: docker
|
||||
FRONTEND_PORT: 12000
|
||||
FRONTEND_HOST: 0.0.0.0
|
||||
BACKEND_HOST: 0.0.0.0
|
||||
BACKEND_PORT: 3000
|
||||
ENABLE_BROWSER: true
|
||||
INSTALL_PLAYWRIGHT: 1
|
||||
run: |
|
||||
# Fix poetry.lock file if needed
|
||||
echo "Fixing poetry.lock file if needed..."
|
||||
poetry lock
|
||||
|
||||
# Build OpenHands using make build
|
||||
echo "Running make build..."
|
||||
make build
|
||||
|
||||
# Install Chromium Headless Shell for Playwright (needed for pytest-playwright)
|
||||
echo "Installing Chromium Headless Shell for Playwright..."
|
||||
poetry run playwright install chromium-headless-shell
|
||||
|
||||
# Verify Playwright browsers are installed (for e2e tests only)
|
||||
echo "Verifying Playwright browsers installation for e2e tests..."
|
||||
BROWSER_CHECK=$(poetry run python tests/e2e/check_playwright.py 2>/dev/null)
|
||||
|
||||
if [ "$BROWSER_CHECK" != "chromium_found" ]; then
|
||||
echo "ERROR: Chromium browser not found or not working for e2e tests"
|
||||
echo "$BROWSER_CHECK"
|
||||
exit 1
|
||||
else
|
||||
echo "Playwright browsers are properly installed for e2e tests."
|
||||
fi
|
||||
|
||||
# Docker runtime will handle workspace directory creation
|
||||
|
||||
# Start the application using make run with custom parameters and reduced logging
|
||||
echo "Starting OpenHands using make run..."
|
||||
# Set environment variables to reduce logging verbosity
|
||||
export PYTHONUNBUFFERED=1
|
||||
export LOG_LEVEL=WARNING
|
||||
export UVICORN_LOG_LEVEL=warning
|
||||
export OPENHANDS_LOG_LEVEL=WARNING
|
||||
FRONTEND_PORT=12000 FRONTEND_HOST=0.0.0.0 BACKEND_HOST=0.0.0.0 make run > /tmp/openhands-e2e-test.log 2>&1 &
|
||||
|
||||
# Store the PID of the make run process
|
||||
MAKE_PID=$!
|
||||
echo "OpenHands started with PID: $MAKE_PID"
|
||||
|
||||
# Wait for the application to start
|
||||
echo "Waiting for OpenHands to start..."
|
||||
max_attempts=15
|
||||
attempt=1
|
||||
|
||||
while [ $attempt -le $max_attempts ]; do
|
||||
echo "Checking if OpenHands is running (attempt $attempt of $max_attempts)..."
|
||||
|
||||
# Check if the process is still running
|
||||
if ! ps -p $MAKE_PID > /dev/null; then
|
||||
echo "ERROR: OpenHands process has terminated unexpectedly"
|
||||
echo "Last 50 lines of the log:"
|
||||
tail -n 50 /tmp/openhands-e2e-test.log
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if frontend port is open
|
||||
if nc -z localhost 12000; then
|
||||
# Verify we can get HTML content
|
||||
if curl -s http://localhost:12000 | grep -q "<html"; then
|
||||
echo "SUCCESS: OpenHands is running and serving HTML content on port 12000"
|
||||
break
|
||||
else
|
||||
echo "Port 12000 is open but not serving HTML content yet"
|
||||
fi
|
||||
else
|
||||
echo "Frontend port 12000 is not open yet"
|
||||
fi
|
||||
|
||||
# Show log output on each attempt
|
||||
echo "Recent log output:"
|
||||
tail -n 20 /tmp/openhands-e2e-test.log
|
||||
|
||||
# Wait before next attempt
|
||||
echo "Waiting 10 seconds before next check..."
|
||||
sleep 10
|
||||
attempt=$((attempt + 1))
|
||||
|
||||
# Exit if we've reached the maximum number of attempts
|
||||
if [ $attempt -gt $max_attempts ]; then
|
||||
echo "ERROR: OpenHands failed to start after $max_attempts attempts"
|
||||
echo "Last 50 lines of the log:"
|
||||
tail -n 50 /tmp/openhands-e2e-test.log
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
|
||||
# Final verification that the app is running
|
||||
if ! nc -z localhost 12000 || ! curl -s http://localhost:12000 | grep -q "<html"; then
|
||||
echo "ERROR: OpenHands is not running properly on port 12000"
|
||||
echo "Last 50 lines of the log:"
|
||||
tail -n 50 /tmp/openhands-e2e-test.log
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Print success message
|
||||
echo "OpenHands is running successfully on port 12000"
|
||||
|
||||
- name: Run end-to-end tests
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.E2E_TEST_GITHUB_TOKEN }}
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL || 'gpt-4o' }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY || 'test-key' }}
|
||||
LLM_BASE_URL: ${{ secrets.LLM_BASE_URL }}
|
||||
run: |
|
||||
# Check if the application is running
|
||||
if ! nc -z localhost 12000; then
|
||||
echo "ERROR: OpenHands is not running on port 12000"
|
||||
echo "Last 50 lines of the log:"
|
||||
tail -n 50 /tmp/openhands-e2e-test.log
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Run the tests with detailed output
|
||||
cd tests/e2e
|
||||
poetry run python -m pytest \
|
||||
test_settings.py::test_github_token_configuration \
|
||||
test_conversation.py::test_conversation_start \
|
||||
test_browsing_catchphrase.py::test_browsing_catchphrase \
|
||||
test_multi_conversation_resume.py::test_multi_conversation_resume \
|
||||
-v --no-header --capture=no --timeout=900
|
||||
|
||||
- name: Upload test results
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v7
|
||||
with:
|
||||
name: playwright-report
|
||||
path: tests/e2e/test-results/
|
||||
retention-days: 30
|
||||
|
||||
- name: Upload OpenHands logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v7
|
||||
with:
|
||||
name: openhands-logs
|
||||
path: |
|
||||
/tmp/openhands-e2e-test.log
|
||||
/tmp/openhands-e2e-build.log
|
||||
/tmp/openhands-backend.log
|
||||
/tmp/openhands-frontend.log
|
||||
/tmp/backend-health-check.log
|
||||
/tmp/frontend-check.log
|
||||
/tmp/vite-config.log
|
||||
/tmp/makefile-contents.log
|
||||
retention-days: 30
|
||||
|
||||
- name: Cleanup
|
||||
if: always()
|
||||
run: |
|
||||
# Stop OpenHands processes
|
||||
echo "Stopping OpenHands processes..."
|
||||
pkill -f "python -m openhands.server" || true
|
||||
pkill -f "npm run dev" || true
|
||||
pkill -f "make run" || true
|
||||
|
||||
# Print process status for debugging
|
||||
echo "Checking if any OpenHands processes are still running:"
|
||||
ps aux | grep -E "openhands|npm run dev" || true
|
||||
40
.github/workflows/pr-review-by-openhands.yml
vendored
40
.github/workflows/pr-review-by-openhands.yml
vendored
@@ -2,12 +2,14 @@
|
||||
name: PR Review by OpenHands
|
||||
|
||||
on:
|
||||
# TEMPORARY MITIGATION (Clinejection hardening)
|
||||
#
|
||||
# We temporarily avoid `pull_request_target` here. We'll restore it after the PR review
|
||||
# workflow is fully hardened for untrusted execution.
|
||||
# Use pull_request for same-repo PRs so workflow changes can self-verify in PRs.
|
||||
pull_request:
|
||||
types: [opened, ready_for_review, labeled, review_requested]
|
||||
# Use pull_request_target for fork PRs.
|
||||
# The bot token used here is intentionally scoped to PR review operations,
|
||||
# so the remaining blast radius is bounded even though PR content is untrusted.
|
||||
pull_request_target:
|
||||
types: [opened, ready_for_review, labeled, review_requested]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -16,13 +18,33 @@ permissions:
|
||||
|
||||
jobs:
|
||||
pr-review:
|
||||
# Note: fork PRs will not have access to repository secrets under `pull_request`.
|
||||
# Skip forks to avoid noisy failures until we restore a hardened `pull_request_target` flow.
|
||||
# Run on same-repo PRs via pull_request and on fork PRs via pull_request_target.
|
||||
# Trigger when one of the following conditions is met:
|
||||
# 1. A new non-draft PR is opened by a non-first-time contributor, OR
|
||||
# 2. A draft PR is converted to ready for review by a non-first-time contributor, OR
|
||||
# 3. The 'review-this' label is added, OR
|
||||
# 4. openhands-agent or all-hands-bot is requested as a reviewer
|
||||
# Note: FIRST_TIME_CONTRIBUTOR and NONE PRs require manual trigger via label/reviewer request.
|
||||
# Trigger logic:
|
||||
# 1. Route same-repo PRs through `pull_request` and fork PRs through `pull_request_target`
|
||||
# 2. Auto-trigger on `opened` / `ready_for_review` for non-first-time contributors
|
||||
# 3. Always allow manual triggers via `review-this` or reviewer request
|
||||
# The author association check is duplicated intentionally for both
|
||||
# auto-triggered actions (`opened` and `ready_for_review`).
|
||||
if: |
|
||||
github.event.pull_request.head.repo.full_name == github.repository &&
|
||||
(
|
||||
(github.event.action == 'opened' && github.event.pull_request.draft == false) ||
|
||||
github.event.action == 'ready_for_review' ||
|
||||
(
|
||||
github.event_name == 'pull_request' &&
|
||||
github.event.pull_request.head.repo.full_name == github.repository
|
||||
) ||
|
||||
(
|
||||
github.event_name == 'pull_request_target' &&
|
||||
github.event.pull_request.head.repo.full_name != github.repository
|
||||
)
|
||||
) &&
|
||||
(
|
||||
(github.event.action == 'opened' && github.event.pull_request.draft == false && github.event.pull_request.author_association != 'FIRST_TIME_CONTRIBUTOR' && github.event.pull_request.author_association != 'NONE') ||
|
||||
(github.event.action == 'ready_for_review' && github.event.pull_request.author_association != 'FIRST_TIME_CONTRIBUTOR' && github.event.pull_request.author_association != 'NONE') ||
|
||||
(github.event.action == 'labeled' && github.event.label.name == 'review-this') ||
|
||||
(
|
||||
github.event.action == 'review_requested' &&
|
||||
|
||||
4
.github/workflows/py-tests.yml
vendored
4
.github/workflows/py-tests.yml
vendored
@@ -60,10 +60,6 @@ jobs:
|
||||
run: PYTHONPATH=".:$PYTHONPATH" poetry run pytest --forked -n auto -s ./tests/unit --cov=openhands --cov-branch
|
||||
env:
|
||||
COVERAGE_FILE: ".coverage.${{ matrix.python_version }}"
|
||||
- name: Run Runtime Tests with CLIRuntime
|
||||
run: PYTHONPATH=".:$PYTHONPATH" TEST_RUNTIME=cli poetry run pytest -n 5 --reruns 2 --reruns-delay 3 -s tests/runtime/test_bash.py --cov=openhands --cov-branch
|
||||
env:
|
||||
COVERAGE_FILE: ".coverage.runtime.${{ matrix.python_version }}"
|
||||
- name: Store coverage file
|
||||
uses: actions/upload-artifact@v7
|
||||
with:
|
||||
|
||||
27
AGENTS.md
27
AGENTS.md
@@ -284,6 +284,32 @@ If you are starting a pull request (PR), please follow the template in `.github/
|
||||
|
||||
These details may or may not be useful for your current task.
|
||||
|
||||
### Conversation State Management
|
||||
|
||||
#### Agent State and Sandbox Status:
|
||||
The frontend uses `useAgentState` hook (`frontend/src/hooks/use-agent-state.ts`) to determine the current conversation state. This hook:
|
||||
- Returns `curAgentState` (AgentState enum) for UI state determination
|
||||
- Returns `isArchived` flag when `sandbox_status === "MISSING"` (archived conversations)
|
||||
- Prioritizes live WebSocket execution status over cached API data
|
||||
|
||||
#### Archived Conversations (sandbox_status === "MISSING"):
|
||||
When a conversation's sandbox is no longer available (archived):
|
||||
- `useAgentState` returns `AgentState.STOPPED` and `isArchived: true`
|
||||
- Chat input is replaced with an archived banner (`ArchivedBanner` component)
|
||||
- VS Code tab, Terminal, and Planner show read-only messages instead of loading states
|
||||
- All interactive elements that require a running sandbox are disabled
|
||||
|
||||
#### Testing useAgentState:
|
||||
When mocking `useAgentState` in tests, always include the `isArchived` property:
|
||||
```typescript
|
||||
vi.mock("#/hooks/use-agent-state", () => ({
|
||||
useAgentState: () => ({
|
||||
curAgentState: AgentState.AWAITING_USER_INPUT,
|
||||
isArchived: false,
|
||||
}),
|
||||
}));
|
||||
```
|
||||
|
||||
### Microagents
|
||||
|
||||
Microagents are specialized prompts that enhance OpenHands with domain-specific knowledge and task-specific workflows. They are Markdown files that can include frontmatter for configuration.
|
||||
@@ -363,6 +389,7 @@ There are two main patterns for saving settings in the OpenHands frontend:
|
||||
**When to use each pattern:**
|
||||
- Use Pattern 1 (Immediate Save) for entity management where each item is independent
|
||||
- Use Pattern 2 (Manual Save) for configuration forms where settings are interdependent or need validation
|
||||
- Git provider tokens in the local/OSS integrations settings are managed through the V1 secrets endpoints (`POST`/`DELETE /api/v1/secrets/git-providers`). Do not reuse the logout flow for disconnecting tokens; `useLogout` is for actual app logout and still targets legacy OSS logout behavior.
|
||||
|
||||
### Adding New LLM Models
|
||||
|
||||
|
||||
@@ -36,7 +36,6 @@ Full details in our [Development Guide](./Development.md).
|
||||
|
||||
- **[Frontend](./frontend/README.md)** - React application
|
||||
- **[App Server (V1)](./openhands/app_server/README.md)** - Current FastAPI application server and REST API modules
|
||||
- **[Runtime](./openhands/runtime/README.md)** - Execution environments
|
||||
- **[Evaluation](https://github.com/OpenHands/benchmarks)** - Testing and benchmarks
|
||||
|
||||
## What Can You Build?
|
||||
|
||||
@@ -16,7 +16,7 @@ open source community:
|
||||
|
||||
#### [Aider](https://github.com/paul-gauthier/aider)
|
||||
- License: Apache License 2.0
|
||||
- Description: AI pair programming tool. OpenHands has adapted and integrated its linter module for code-related tasks in [`agentskills utilities`](https://github.com/OpenHands/OpenHands/tree/main/openhands/runtime/plugins/agent_skills/utils/aider)
|
||||
- Description: AI pair programming tool. OpenHands has adapted and integrated its linter module for code-related tasks.
|
||||
|
||||
#### [BrowserGym](https://github.com/ServiceNow/BrowserGym)
|
||||
- License: Apache License 2.0
|
||||
|
||||
@@ -309,16 +309,6 @@ poetry run pytest ./tests/unit/test_*.py
|
||||
|
||||
---
|
||||
|
||||
## Using Existing Docker Images
|
||||
|
||||
To reduce build time, you can use an existing runtime image:
|
||||
|
||||
```bash
|
||||
export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/openhands/runtime:1.2-nikolaik
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Help
|
||||
|
||||
```bash
|
||||
@@ -339,4 +329,3 @@ make help
|
||||
- [/tests/unit/README.md](./tests/unit/README.md): Guide to writing and running unit tests
|
||||
- [OpenHands/benchmarks](https://github.com/OpenHands/benchmarks): Documentation for the evaluation framework and benchmarks
|
||||
- [/skills/README.md](./skills/README.md): Information about the skills architecture and implementation
|
||||
- [/openhands/runtime/README.md](./openhands/runtime/README.md): Documentation for the runtime environment and execution model
|
||||
|
||||
@@ -88,7 +88,6 @@ USER openhands
|
||||
|
||||
COPY --chown=openhands:openhands --chmod=770 ./skills ./skills
|
||||
COPY --chown=openhands:openhands --chmod=770 ./openhands ./openhands
|
||||
COPY --chown=openhands:openhands --chmod=777 ./openhands/runtime/plugins ./openhands/runtime/plugins
|
||||
COPY --chown=openhands:openhands pyproject.toml poetry.lock README.md MANIFEST.in LICENSE ./
|
||||
|
||||
# Add this line to set group ownership of all files/directories not already in "app" group
|
||||
|
||||
@@ -23,18 +23,6 @@ if [ -z "$WORKSPACE_MOUNT_PATH" ]; then
|
||||
unset WORKSPACE_BASE
|
||||
fi
|
||||
|
||||
if [[ "$INSTALL_THIRD_PARTY_RUNTIMES" == "true" ]]; then
|
||||
echo "Downloading and installing third_party_runtimes..."
|
||||
echo "Warning: Third-party runtimes are provided as-is, not actively supported and may be removed in future releases."
|
||||
|
||||
if pip install 'openhands-ai[third_party_runtimes]' -qqq 2> >(tee /dev/stderr); then
|
||||
echo "third_party_runtimes installed successfully."
|
||||
else
|
||||
echo "Failed to install third_party_runtimes." >&2
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
if [[ "$SANDBOX_USER_ID" -eq 0 ]]; then
|
||||
echo "Running OpenHands as root"
|
||||
export RUN_AS_OPENHANDS=false
|
||||
|
||||
@@ -3,9 +3,9 @@ repos:
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
exclude: ^(docs/|modules/|python/|openhands-ui/|third_party/|enterprise/)
|
||||
exclude: ^(docs/|modules/|python/|openhands-ui/|enterprise/)
|
||||
- id: end-of-file-fixer
|
||||
exclude: ^(docs/|modules/|python/|openhands-ui/|third_party/|enterprise/)
|
||||
exclude: ^(docs/|modules/|python/|openhands-ui/|enterprise/)
|
||||
- id: check-yaml
|
||||
args: ["--allow-multiple-documents"]
|
||||
- id: debug-statements
|
||||
@@ -37,12 +37,12 @@ repos:
|
||||
entry: ruff check --config dev_config/python/ruff.toml
|
||||
types_or: [python, pyi, jupyter]
|
||||
args: [--fix, --unsafe-fixes]
|
||||
exclude: ^(third_party/|enterprise/)
|
||||
exclude: ^(enterprise/)
|
||||
# Run the formatter.
|
||||
- id: ruff-format
|
||||
entry: ruff format --config dev_config/python/ruff.toml
|
||||
types_or: [python, pyi, jupyter]
|
||||
exclude: ^(third_party/|enterprise/)
|
||||
exclude: ^(enterprise/)
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.15.0
|
||||
|
||||
@@ -10,7 +10,7 @@ strict_optional = True
|
||||
disable_error_code = type-abstract
|
||||
|
||||
# Exclude third-party runtime directory from type checking
|
||||
exclude = (third_party/|enterprise/)
|
||||
exclude = (enterprise/)
|
||||
|
||||
[mypy-openai.*]
|
||||
follow_imports = skip
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Exclude third-party runtime directory from linting
|
||||
exclude = ["third_party/", "enterprise/"]
|
||||
exclude = ["enterprise/"]
|
||||
|
||||
[lint]
|
||||
select = [
|
||||
|
||||
@@ -61,13 +61,6 @@ export LITE_LLM_API_KEY=<your LLM API key>
|
||||
python enterprise_local/convert_to_env.py
|
||||
```
|
||||
|
||||
You'll also need to set up the runtime image, so that the dev server doesn't try to rebuild it.
|
||||
|
||||
```
|
||||
export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/openhands/runtime:main-nikolaik
|
||||
docker pull $SANDBOX_RUNTIME_CONTAINER_IMAGE
|
||||
```
|
||||
|
||||
By default the application will log in json, you can override.
|
||||
|
||||
```
|
||||
@@ -203,7 +196,6 @@ And then invoking `printenv`. NOTE: _DO NOT DO THIS WITH PROD!!!_ (Hopefully by
|
||||
"REDIS_HOST": "localhost:6379",
|
||||
"OPENHANDS": "<YOUR LOCAL OPENHANDS DIR>",
|
||||
"FRONTEND_DIRECTORY": "<YOUR LOCAL OPENHANDS DIR>/frontend/build",
|
||||
"SANDBOX_RUNTIME_CONTAINER_IMAGE": "ghcr.io/openhands/runtime:main-nikolaik",
|
||||
"FILE_STORE_PATH": "<YOUR HOME DIRECTORY>>/.openhands-state",
|
||||
"OPENHANDS_CONFIG_CLS": "server.config.SaaSServerConfig",
|
||||
"GITHUB_APP_ID": "1062351",
|
||||
@@ -237,7 +229,6 @@ And then invoking `printenv`. NOTE: _DO NOT DO THIS WITH PROD!!!_ (Hopefully by
|
||||
"REDIS_HOST": "localhost:6379",
|
||||
"OPENHANDS": "<YOUR LOCAL OPENHANDS DIR>",
|
||||
"FRONTEND_DIRECTORY": "<YOUR LOCAL OPENHANDS DIR>/frontend/build",
|
||||
"SANDBOX_RUNTIME_CONTAINER_IMAGE": "ghcr.io/openhands/runtime:main-nikolaik",
|
||||
"FILE_STORE_PATH": "<YOUR HOME DIRECTORY>>/.openhands-state",
|
||||
"OPENHANDS_CONFIG_CLS": "server.config.SaaSServerConfig",
|
||||
"GITHUB_APP_ID": "1062351",
|
||||
|
||||
@@ -112,9 +112,6 @@ lines.append(
|
||||
lines.append(
|
||||
'OPENHANDS_BITBUCKET_DATA_CENTER_SERVICE_CLS=integrations.bitbucket_data_center.bitbucket_dc_service.SaaSBitbucketDCService'
|
||||
)
|
||||
lines.append(
|
||||
'OPENHANDS_CONVERSATION_VALIDATOR_CLS=storage.saas_conversation_validator.SaasConversationValidator'
|
||||
)
|
||||
lines.append('POSTHOG_CLIENT_KEY=test')
|
||||
lines.append('ENABLE_PROACTIVE_CONVERSATION_STARTERS=true')
|
||||
lines.append('MAX_CONCURRENT_CONVERSATIONS=10')
|
||||
|
||||
@@ -2,7 +2,6 @@ from types import MappingProxyType
|
||||
|
||||
from github import Auth, Github, GithubIntegration
|
||||
from integrations.github.data_collector import GitHubDataCollector
|
||||
from integrations.github.github_solvability import summarize_issue_solvability
|
||||
from integrations.github.github_view import (
|
||||
GithubFactory,
|
||||
GithubFailingAction,
|
||||
@@ -20,7 +19,6 @@ from integrations.models import (
|
||||
from integrations.types import ResolverViewInterface
|
||||
from integrations.utils import (
|
||||
CONVERSATION_URL,
|
||||
ENABLE_SOLVABILITY_ANALYSIS,
|
||||
HOST_URL,
|
||||
OPENHANDS_RESOLVER_TEMPLATES_DIR,
|
||||
get_session_expired_message,
|
||||
@@ -33,6 +31,7 @@ from server.auth.auth_error import ExpiredError
|
||||
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
from server.auth.token_manager import TokenManager
|
||||
|
||||
from openhands.app_server.secrets.secrets_models import Secrets
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
from openhands.integrations.service_types import AuthenticationError
|
||||
@@ -41,7 +40,6 @@ from openhands.server.types import (
|
||||
MissingSettingsError,
|
||||
SessionExpiredError,
|
||||
)
|
||||
from openhands.storage.data_models.secrets import Secrets
|
||||
|
||||
|
||||
class GithubManager(Manager[GithubViewType]):
|
||||
@@ -358,26 +356,7 @@ class GithubManager(Manager[GithubViewType]):
|
||||
)
|
||||
)
|
||||
|
||||
# We first initialize a conversation and generate the solvability report BEFORE starting the conversation runtime
|
||||
# This helps us accumulate llm spend without requiring a running runtime. This sets us up for
|
||||
# 1. If there is a problem starting the runtime we still have accumulated total conversation cost
|
||||
# 2. In the future, based on the report confidence we can conditionally start the conversation
|
||||
# 3. Once the conversation is started, its base cost will include the report's spend as well which allows us to control max budget per resolver task
|
||||
convo_metadata = await github_view.initialize_new_conversation()
|
||||
solvability_summary = None
|
||||
if not ENABLE_SOLVABILITY_ANALYSIS:
|
||||
logger.info(
|
||||
'[Github]: Solvability report feature is disabled, skipping'
|
||||
)
|
||||
else:
|
||||
try:
|
||||
solvability_summary = await summarize_issue_solvability(
|
||||
github_view, user_token
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f'[Github]: Error summarizing issue solvability: {str(e)}'
|
||||
)
|
||||
conversation_id = await github_view.initialize_new_conversation()
|
||||
|
||||
saas_user_auth = await get_saas_user_auth(
|
||||
github_view.user_info.keycloak_user_id, self.token_manager
|
||||
@@ -386,26 +365,21 @@ class GithubManager(Manager[GithubViewType]):
|
||||
await github_view.create_new_conversation(
|
||||
self.jinja_env,
|
||||
secret_store.provider_tokens,
|
||||
convo_metadata,
|
||||
conversation_id,
|
||||
saas_user_auth,
|
||||
)
|
||||
|
||||
conversation_id = github_view.conversation_id
|
||||
conversation_id_hex = github_view.conversation_id
|
||||
|
||||
logger.info(
|
||||
f'[GitHub] Created conversation {conversation_id} for user {user_info.username}'
|
||||
f'[GitHub] Created conversation {conversation_id_hex} for user {user_info.username}'
|
||||
)
|
||||
|
||||
# V1 callback processors are registered by the view during conversation creation
|
||||
|
||||
# Send message with conversation link
|
||||
conversation_link = CONVERSATION_URL.format(conversation_id)
|
||||
base_msg = f"I'm on it! {user_info.username} can [track my progress at all-hands.dev]({conversation_link})"
|
||||
# Combine messages: include solvability report with "I'm on it!" if successful
|
||||
if solvability_summary:
|
||||
msg_info = f'{base_msg}\n\n{solvability_summary}'
|
||||
else:
|
||||
msg_info = base_msg
|
||||
conversation_link = CONVERSATION_URL.format(conversation_id_hex)
|
||||
msg_info = f"I'm on it! {user_info.username} can [track my progress at all-hands.dev]({conversation_link})"
|
||||
|
||||
except MissingSettingsError as e:
|
||||
logger.warning(
|
||||
|
||||
@@ -1,188 +0,0 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from github import Auth, Github
|
||||
from integrations.github.github_view import (
|
||||
GithubInlinePRComment,
|
||||
GithubIssueComment,
|
||||
GithubPRComment,
|
||||
GithubViewType,
|
||||
)
|
||||
from integrations.solvability.data import load_classifier
|
||||
from integrations.solvability.models.report import SolvabilityReport
|
||||
from integrations.solvability.models.summary import SolvabilitySummary
|
||||
from integrations.utils import ENABLE_SOLVABILITY_ANALYSIS
|
||||
from pydantic import ValidationError
|
||||
from server.config import get_config
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.utils import create_registry_and_conversation_stats
|
||||
|
||||
|
||||
def fetch_github_issue_context(
    github_view: GithubViewType,
    user_token: str,
) -> str:
    """Build a plain-text context summary for a GitHub issue or PR.

    The summary is made of, in order: the title and description read from
    ``github_view``, the issue's labels (looked up through the GitHub API and
    only included when the issue has any), and one ``- author: body`` entry
    for every item in ``github_view.previous_comments``.

    Comments are included in full; no limit on their number or length is
    applied here.

    Args:
        github_view: View of the issue/PR. This function reads its ``title``,
            ``description``, ``full_repo_name``, ``issue_number`` and
            ``previous_comments`` attributes.
        user_token: GitHub user access token used to authenticate the
            label lookup.

    Returns:
        The context sections joined by blank lines.
    """
    # Title and body always come first.
    context_parts = [
        f'Title: {github_view.title}',
        f'Description:\n{github_view.description}',
    ]

    # Labels are fetched from the API; the client is closed on exit.
    with Github(auth=Auth.Token(user_token)) as github_client:
        repo = github_client.get_repo(github_view.full_repo_name)
        issue = repo.get_issue(github_view.issue_number)
        if issue.labels:
            label_names = ', '.join(label.name for label in issue.labels)
            context_parts.append(f'Labels: {label_names}')

    # Previous comments come from the view, not from the API client.
    context_parts.extend(
        f'- {comment.author}: {comment.body}'
        for comment in github_view.previous_comments
    )

    return '\n\n'.join(context_parts)
|
||||
|
||||
|
||||
async def summarize_issue_solvability(
    github_view: GithubViewType,
    user_token: str,
    timeout: float = 60.0 * 5,
) -> str:
    """Generate a solvability summary for the issue/PR behind ``github_view``.

    Args:
        github_view: A GitHub resolver view instance (e.g., GithubIssue, GithubPRComment)
        user_token: GitHub user access token for API access
        timeout: Maximum time in seconds to wait for the solvability report
            (default: 300.0, i.e. five minutes)

    Returns:
        The solvability summary formatted as markdown

    Raises:
        ValueError: If the feature is disabled globally or for the user, if the
            user has no Keycloak ID, stored settings or LLM API key, or if the
            LLM configuration fails validation
        asyncio.TimeoutError: If generating the report exceeds the specified timeout
    """
    if not ENABLE_SOLVABILITY_ANALYSIS:
        raise ValueError('Solvability report feature is disabled')

    if github_view.user_info.keycloak_user_id is None:
        raise ValueError(
            f'[Solvability] No user ID found for user {github_view.user_info.username}'
        )

    # Grab the user's information so we can load their LLM configuration
    store = SaasSettingsStore(
        user_id=github_view.user_info.keycloak_user_id,
        config=get_config(),
    )

    user_settings = await store.load()

    if user_settings is None:
        raise ValueError(
            f'[Solvability] No user settings found for user ID {github_view.user_info.user_id}'
        )

    # Check if solvability analysis is enabled for this user, exit early if
    # needed
    if not getattr(user_settings, 'enable_solvability_analysis', False):
        raise ValueError(
            f'Solvability analysis disabled for user {github_view.user_info.user_id}'
        )

    agent_settings = user_settings.agent_settings
    llm_settings = agent_settings.llm
    if llm_settings.api_key is None:
        raise ValueError(
            f'[Solvability] No LLM API key found for user {github_view.user_info.user_id}'
        )

    try:
        llm_config = LLMConfig(
            model=llm_settings.model,
            api_key=llm_settings.api_key.get_secret_value(),
            base_url=llm_settings.base_url,
        )
    except ValidationError as e:
        # Chain explicitly so the validation failure is kept as the cause.
        raise ValueError(
            f'[Solvability] Invalid LLM configuration for user {github_view.user_info.user_id}: {str(e)}'
        ) from e

    # Fetch the full GitHub issue/PR context using the GitHub API
    start_time = time.time()
    issue_context = fetch_github_issue_context(github_view, user_token)
    logger.info(
        f'[Solvability] Grabbed issue context for {github_view.conversation_id}',
        extra={
            'conversation_id': github_view.conversation_id,
            'response_latency': time.time() - start_time,
            'full_repo_name': github_view.full_repo_name,
            'issue_number': github_view.issue_number,
        },
    )

    # For comment-based triggers, also include the specific comment that triggered the action
    if isinstance(
        github_view, (GithubIssueComment, GithubPRComment, GithubInlinePRComment)
    ):
        issue_context += f'\n\nTriggering Comment:\n{github_view.comment_body}'

    solvability_classifier = load_classifier('default-classifier')

    # The classifier is synchronous, so it is run through call_sync_from_async
    # while the timeout bounds how long we await it.
    async with asyncio.timeout(timeout):
        solvability_report: SolvabilityReport = await call_sync_from_async(
            lambda: solvability_classifier.solvability_report(
                issue_context, llm_config=llm_config
            )
        )

    logger.info(
        f'[Solvability] Generated report for {github_view.conversation_id}',
        extra={
            'conversation_id': github_view.conversation_id,
            'report': solvability_report.model_dump(exclude=['issue']),
        },
    )

    llm_registry, conversation_stats, _ = create_registry_and_conversation_stats(
        get_config(),
        github_view.conversation_id,
        github_view.user_info.keycloak_user_id,
        None,
    )

    solvability_summary = await call_sync_from_async(
        lambda: SolvabilitySummary.from_report(
            solvability_report,
            llm=llm_registry.get_llm(
                service_id='solvability_analysis', config=llm_config
            ),
        )
    )
    conversation_stats.save_metrics()

    logger.info(
        f'[Solvability] Generated summary for {github_view.conversation_id}',
        extra={
            'conversation_id': github_view.conversation_id,
            'summary': solvability_summary.model_dump(exclude=['content']),
        },
    )

    return solvability_summary.format_as_markdown()
|
||||
@@ -14,11 +14,9 @@ from integrations.resolver_org_router import resolve_org_for_repo
|
||||
from integrations.types import ResolverViewInterface, UserData
|
||||
from integrations.utils import (
|
||||
ENABLE_PROACTIVE_CONVERSATION_STARTERS,
|
||||
ENABLE_V1_GITHUB_RESOLVER,
|
||||
HOST,
|
||||
HOST_URL,
|
||||
get_oh_labels,
|
||||
get_user_v1_enabled_setting,
|
||||
has_exact_mention,
|
||||
)
|
||||
from jinja2 import Environment
|
||||
@@ -27,13 +25,13 @@ from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from storage.org_store import OrgStore
|
||||
from storage.proactive_conversation_store import ProactiveConversationStore
|
||||
from storage.saas_conversation_store import SaasConversationStore
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
|
||||
from openhands.agent_server.models import SendMessageRequest
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartRequest,
|
||||
AppConversationStartTaskStatus,
|
||||
ConversationTrigger,
|
||||
)
|
||||
from openhands.app_server.config import get_app_conversation_service
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
@@ -44,20 +42,11 @@ from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
|
||||
from openhands.integrations.service_types import Comment
|
||||
from openhands.sdk import TextContent
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.data_models.conversation_metadata import (
|
||||
ConversationMetadata,
|
||||
ConversationTrigger,
|
||||
)
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.conversation_summary import get_default_conversation_title
|
||||
|
||||
OH_LABEL, INLINE_OH_LABEL = get_oh_labels(HOST)
|
||||
|
||||
|
||||
async def is_v1_enabled_for_github_resolver(user_id: str) -> bool:
    """Report whether the V1 GitHub resolver applies to ``user_id``.

    Combines the result of ``get_user_v1_enabled_setting`` with the
    ``ENABLE_V1_GITHUB_RESOLVER`` flag; the per-user lookup is awaited first.
    """
    user_setting = await get_user_v1_enabled_setting(user_id)
    return user_setting and ENABLE_V1_GITHUB_RESOLVER
|
||||
|
||||
|
||||
async def get_user_proactive_conversation_setting(user_id: str | None) -> bool:
|
||||
"""Get the user's proactive conversation setting.
|
||||
|
||||
@@ -105,7 +94,6 @@ class GithubIssue(ResolverViewInterface):
|
||||
title: str
|
||||
description: str
|
||||
previous_comments: list[Comment]
|
||||
v1_enabled: bool
|
||||
|
||||
def _get_branch_name(self) -> str | None:
|
||||
return getattr(self, 'branch_name', None)
|
||||
@@ -152,11 +140,7 @@ class GithubIssue(ResolverViewInterface):
|
||||
|
||||
return user_secrets.custom_secrets if user_secrets else None
|
||||
|
||||
async def initialize_new_conversation(self) -> ConversationMetadata:
|
||||
self.v1_enabled = await is_v1_enabled_for_github_resolver(
|
||||
self.user_info.keycloak_user_id
|
||||
)
|
||||
|
||||
async def initialize_new_conversation(self) -> UUID:
|
||||
# Resolve target org based on claimed git organizations
|
||||
self.resolved_org_id = await resolve_org_for_repo(
|
||||
provider='github',
|
||||
@@ -164,54 +148,20 @@ class GithubIssue(ResolverViewInterface):
|
||||
keycloak_user_id=self.user_info.keycloak_user_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'[GitHub V1]: User flag found for {self.user_info.keycloak_user_id} is {self.v1_enabled}'
|
||||
)
|
||||
if self.v1_enabled:
|
||||
# Create dummy conversationm metadata
|
||||
# Don't save to conversation store
|
||||
# V1 conversations are stored in a separate table
|
||||
self.conversation_id = uuid4().hex
|
||||
return ConversationMetadata(
|
||||
conversation_id=self.conversation_id,
|
||||
selected_repository=self.full_repo_name,
|
||||
)
|
||||
|
||||
# Create the conversation store with resolver org routing
|
||||
# (bypasses initialize_conversation to avoid threading enterprise-only
|
||||
# resolver_org_id through the generic OSS interface)
|
||||
store = await SaasConversationStore.get_resolver_instance(
|
||||
get_config(),
|
||||
self.user_info.keycloak_user_id,
|
||||
self.resolved_org_id,
|
||||
)
|
||||
|
||||
conversation_id = uuid4().hex
|
||||
conversation_metadata = ConversationMetadata(
|
||||
trigger=ConversationTrigger.RESOLVER,
|
||||
conversation_id=conversation_id,
|
||||
title=get_default_conversation_title(conversation_id),
|
||||
user_id=self.user_info.keycloak_user_id,
|
||||
selected_repository=self.full_repo_name,
|
||||
selected_branch=self._get_branch_name(),
|
||||
git_provider=ProviderType.GITHUB,
|
||||
)
|
||||
await store.save_metadata(conversation_metadata)
|
||||
|
||||
self.conversation_id = conversation_id
|
||||
return conversation_metadata
|
||||
# All conversations use V1 app conversation service
|
||||
conversation_id = uuid4()
|
||||
self.conversation_id = conversation_id.hex
|
||||
return conversation_id
|
||||
|
||||
async def create_new_conversation(
|
||||
self,
|
||||
jinja_env: Environment,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE,
|
||||
conversation_metadata: ConversationMetadata,
|
||||
conversation_id: UUID,
|
||||
saas_user_auth: UserAuth,
|
||||
):
|
||||
# V0 conversation path has been removed - all conversations use V1 app conversation service
|
||||
await self._create_v1_conversation(
|
||||
jinja_env, saas_user_auth, conversation_metadata
|
||||
)
|
||||
await self._create_v1_conversation(jinja_env, saas_user_auth, conversation_id)
|
||||
|
||||
async def _get_v1_initial_user_message(self, jinja_env: Environment) -> str:
|
||||
"""Build the initial user message for V1 resolver conversations.
|
||||
@@ -239,7 +189,7 @@ class GithubIssue(ResolverViewInterface):
|
||||
self,
|
||||
jinja_env: Environment,
|
||||
saas_user_auth: UserAuth,
|
||||
conversation_metadata: ConversationMetadata,
|
||||
conversation_id: UUID,
|
||||
):
|
||||
"""Create conversation using the new V1 app conversation system."""
|
||||
logger.info('[GitHub V1]: Creating V1 conversation')
|
||||
@@ -259,7 +209,7 @@ class GithubIssue(ResolverViewInterface):
|
||||
|
||||
# Create the V1 conversation start request with the callback processor
|
||||
start_request = AppConversationStartRequest(
|
||||
conversation_id=UUID(conversation_metadata.conversation_id),
|
||||
conversation_id=conversation_id,
|
||||
# NOTE: Resolver instructions are intended to be lower priority than the
|
||||
# system prompt, so we inject them into the initial user message.
|
||||
system_message_suffix=None,
|
||||
@@ -813,7 +763,6 @@ class GithubFactory:
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
v1_enabled=False,
|
||||
)
|
||||
|
||||
elif GithubFactory.is_issue_comment(message):
|
||||
@@ -839,7 +788,6 @@ class GithubFactory:
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
v1_enabled=False,
|
||||
)
|
||||
|
||||
elif GithubFactory.is_pr_comment(message):
|
||||
@@ -881,7 +829,6 @@ class GithubFactory:
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
v1_enabled=False,
|
||||
)
|
||||
|
||||
elif GithubFactory.is_inline_pr_comment(message):
|
||||
@@ -915,7 +862,6 @@ class GithubFactory:
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
v1_enabled=False,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@@ -25,6 +25,7 @@ from jinja2 import Environment, FileSystemLoader
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
|
||||
from openhands.app_server.secrets.secrets_models import Secrets
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
@@ -33,7 +34,6 @@ from openhands.server.types import (
|
||||
MissingSettingsError,
|
||||
SessionExpiredError,
|
||||
)
|
||||
from openhands.storage.data_models.secrets import Secrets
|
||||
|
||||
|
||||
class GitlabManager(Manager[GitlabViewType]):
|
||||
@@ -208,8 +208,8 @@ class GitlabManager(Manager[GitlabViewType]):
|
||||
)
|
||||
)
|
||||
|
||||
# Initialize conversation and get metadata (following GitHub pattern)
|
||||
convo_metadata = await gitlab_view.initialize_new_conversation()
|
||||
# Initialize conversation and get UUID
|
||||
conversation_id = await gitlab_view.initialize_new_conversation()
|
||||
|
||||
saas_user_auth = await get_saas_user_auth(
|
||||
gitlab_view.user_info.keycloak_user_id, self.token_manager
|
||||
@@ -218,19 +218,19 @@ class GitlabManager(Manager[GitlabViewType]):
|
||||
await gitlab_view.create_new_conversation(
|
||||
self.jinja_env,
|
||||
secret_store.provider_tokens,
|
||||
convo_metadata,
|
||||
conversation_id,
|
||||
saas_user_auth,
|
||||
)
|
||||
|
||||
conversation_id = gitlab_view.conversation_id
|
||||
conversation_id_hex = gitlab_view.conversation_id
|
||||
|
||||
logger.info(
|
||||
f'[GitLab] Created conversation {conversation_id} for user {user_info.username}'
|
||||
f'[GitLab] Created conversation {conversation_id_hex} for user {user_info.username}'
|
||||
)
|
||||
|
||||
# V1 callback processors are registered by the view during conversation creation
|
||||
|
||||
conversation_link = CONVERSATION_URL.format(conversation_id)
|
||||
conversation_link = CONVERSATION_URL.format(conversation_id_hex)
|
||||
msg_info = f"I'm on it! {user_info.username} can [track my progress at all-hands.dev]({conversation_link})"
|
||||
|
||||
except MissingSettingsError as e:
|
||||
|
||||
@@ -6,22 +6,20 @@ from integrations.resolver_context import ResolverUserContext
|
||||
from integrations.resolver_org_router import resolve_org_for_repo
|
||||
from integrations.types import ResolverViewInterface, UserData
|
||||
from integrations.utils import (
|
||||
ENABLE_V1_GITLAB_RESOLVER,
|
||||
HOST,
|
||||
get_oh_labels,
|
||||
get_user_v1_enabled_setting,
|
||||
has_exact_mention,
|
||||
)
|
||||
from jinja2 import Environment
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from storage.saas_conversation_store import SaasConversationStore
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
|
||||
from openhands.agent_server.models import SendMessageRequest
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartRequest,
|
||||
AppConversationStartTaskStatus,
|
||||
ConversationTrigger,
|
||||
)
|
||||
from openhands.app_server.config import get_app_conversation_service
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
@@ -32,21 +30,12 @@ from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
|
||||
from openhands.integrations.service_types import Comment
|
||||
from openhands.sdk import TextContent
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.data_models.conversation_metadata import (
|
||||
ConversationMetadata,
|
||||
ConversationTrigger,
|
||||
)
|
||||
from openhands.utils.conversation_summary import get_default_conversation_title
|
||||
|
||||
OH_LABEL, INLINE_OH_LABEL = get_oh_labels(HOST)
|
||||
CONFIDENTIAL_NOTE = 'confidential_note'
|
||||
NOTE_TYPES = ['note', CONFIDENTIAL_NOTE]
|
||||
|
||||
|
||||
async def is_v1_enabled_for_gitlab_resolver(user_id: str) -> bool:
    """Report whether the V1 GitLab resolver applies to ``user_id``.

    Combines the result of ``get_user_v1_enabled_setting`` with the
    ``ENABLE_V1_GITLAB_RESOLVER`` flag; the per-user lookup is awaited first.
    """
    user_setting = await get_user_v1_enabled_setting(user_id)
    return user_setting and ENABLE_V1_GITLAB_RESOLVER
|
||||
|
||||
|
||||
# =================================================
|
||||
# SECTION: Factory to create appropriate Gitlab view
|
||||
# =================================================
|
||||
@@ -68,7 +57,6 @@ class GitlabIssue(ResolverViewInterface):
|
||||
description: str
|
||||
previous_comments: list[Comment]
|
||||
is_mr: bool
|
||||
v1_enabled: bool
|
||||
|
||||
def _get_branch_name(self) -> str | None:
|
||||
return getattr(self, 'branch_name', None)
|
||||
@@ -114,10 +102,7 @@ class GitlabIssue(ResolverViewInterface):
|
||||
|
||||
return user_secrets.custom_secrets if user_secrets else None
|
||||
|
||||
async def initialize_new_conversation(self) -> ConversationMetadata:
|
||||
# v1_enabled is already set at construction time in the factory method
|
||||
# This is the source of truth for the conversation type
|
||||
|
||||
async def initialize_new_conversation(self) -> UUID:
|
||||
# Resolve target org based on claimed git organizations
|
||||
self.resolved_org_id = await resolve_org_for_repo(
|
||||
provider='gitlab',
|
||||
@@ -125,57 +110,26 @@ class GitlabIssue(ResolverViewInterface):
|
||||
keycloak_user_id=self.user_info.keycloak_user_id,
|
||||
)
|
||||
|
||||
if self.v1_enabled:
|
||||
# Create dummy conversation metadata
|
||||
# Don't save to conversation store
|
||||
# V1 conversations are stored in a separate table
|
||||
self.conversation_id = uuid4().hex
|
||||
return ConversationMetadata(
|
||||
conversation_id=self.conversation_id,
|
||||
selected_repository=self.full_repo_name,
|
||||
)
|
||||
|
||||
# Create the conversation store with resolver org routing
|
||||
# (bypasses initialize_conversation to avoid threading enterprise-only
|
||||
# resolver_org_id through the generic OSS interface)
|
||||
store = await SaasConversationStore.get_resolver_instance(
|
||||
get_config(),
|
||||
self.user_info.keycloak_user_id,
|
||||
self.resolved_org_id,
|
||||
)
|
||||
|
||||
conversation_id = uuid4().hex
|
||||
conversation_metadata = ConversationMetadata(
|
||||
trigger=ConversationTrigger.RESOLVER,
|
||||
conversation_id=conversation_id,
|
||||
title=get_default_conversation_title(conversation_id),
|
||||
user_id=self.user_info.keycloak_user_id,
|
||||
selected_repository=self.full_repo_name,
|
||||
selected_branch=self._get_branch_name(),
|
||||
git_provider=ProviderType.GITLAB,
|
||||
)
|
||||
await store.save_metadata(conversation_metadata)
|
||||
|
||||
self.conversation_id = conversation_id
|
||||
return conversation_metadata
|
||||
# All conversations use V1 app conversation service
|
||||
conversation_id = uuid4()
|
||||
self.conversation_id = conversation_id.hex
|
||||
return conversation_id
|
||||
|
||||
async def create_new_conversation(
|
||||
self,
|
||||
jinja_env: Environment,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE,
|
||||
conversation_metadata: ConversationMetadata,
|
||||
conversation_id: UUID,
|
||||
saas_user_auth: UserAuth,
|
||||
):
|
||||
# V0 conversation path has been removed - all conversations use V1 app conversation service
|
||||
await self._create_v1_conversation(
|
||||
jinja_env, saas_user_auth, conversation_metadata
|
||||
)
|
||||
await self._create_v1_conversation(jinja_env, saas_user_auth, conversation_id)
|
||||
|
||||
async def _create_v1_conversation(
|
||||
self,
|
||||
jinja_env: Environment,
|
||||
saas_user_auth: UserAuth,
|
||||
conversation_metadata: ConversationMetadata,
|
||||
conversation_id: UUID,
|
||||
):
|
||||
"""Create conversation using the new V1 app conversation system."""
|
||||
logger.info('[GitLab V1]: Creating V1 conversation')
|
||||
@@ -201,7 +155,7 @@ class GitlabIssue(ResolverViewInterface):
|
||||
|
||||
# Create the V1 conversation start request with the callback processor
|
||||
start_request = AppConversationStartRequest(
|
||||
conversation_id=UUID(conversation_metadata.conversation_id),
|
||||
conversation_id=conversation_id,
|
||||
system_message_suffix=conversation_instructions,
|
||||
initial_message=initial_message,
|
||||
selected_repository=self.full_repo_name,
|
||||
@@ -450,16 +404,6 @@ class GitlabFactory:
|
||||
user_id=user_id, username=username, keycloak_user_id=keycloak_user_id
|
||||
)
|
||||
|
||||
# Check v1_enabled at construction time - this is the source of truth
|
||||
v1_enabled = (
|
||||
await is_v1_enabled_for_gitlab_resolver(keycloak_user_id)
|
||||
if keycloak_user_id
|
||||
else False
|
||||
)
|
||||
logger.info(
|
||||
f'[GitLab V1]: User flag found for {keycloak_user_id} is {v1_enabled}'
|
||||
)
|
||||
|
||||
if GitlabFactory.is_labeled_issue(message):
|
||||
issue_iid = payload['object_attributes']['iid']
|
||||
|
||||
@@ -481,7 +425,6 @@ class GitlabFactory:
|
||||
description='',
|
||||
previous_comments=[],
|
||||
is_mr=False,
|
||||
v1_enabled=v1_enabled,
|
||||
)
|
||||
|
||||
elif GitlabFactory.is_issue_comment(message):
|
||||
@@ -512,7 +455,6 @@ class GitlabFactory:
|
||||
description='',
|
||||
previous_comments=[],
|
||||
is_mr=False,
|
||||
v1_enabled=v1_enabled,
|
||||
)
|
||||
|
||||
elif GitlabFactory.is_mr_comment(message):
|
||||
@@ -545,7 +487,6 @@ class GitlabFactory:
|
||||
description='',
|
||||
previous_comments=[],
|
||||
is_mr=True,
|
||||
v1_enabled=v1_enabled,
|
||||
)
|
||||
|
||||
elif GitlabFactory.is_mr_comment(message, inline=True):
|
||||
@@ -586,7 +527,6 @@ class GitlabFactory:
|
||||
description='',
|
||||
previous_comments=[],
|
||||
is_mr=True,
|
||||
v1_enabled=v1_enabled,
|
||||
)
|
||||
|
||||
raise ValueError(f'Unhandled GitLab webhook event: {message}')
|
||||
|
||||
@@ -35,6 +35,7 @@ from openhands.agent_server.models import SendMessageRequest
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartRequest,
|
||||
AppConversationStartTaskStatus,
|
||||
ConversationTrigger,
|
||||
)
|
||||
from openhands.app_server.config import get_app_conversation_service
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
@@ -43,10 +44,6 @@ from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderHandler, ProviderType
|
||||
from openhands.sdk import TextContent
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.data_models.conversation_metadata import (
|
||||
ConversationMetadata,
|
||||
ConversationTrigger,
|
||||
)
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
JIRA_CLOUD_API_URL = 'https://api.atlassian.com/ex/jira'
|
||||
@@ -192,32 +189,30 @@ class JiraNewConversationView(JiraViewInterface):
|
||||
)
|
||||
await integration_store.create_conversation(jira_conversation)
|
||||
|
||||
conversation_metadata = await self._create_v1_metadata()
|
||||
await self._create_v1_conversation(jinja_env, conversation_metadata)
|
||||
conversation_id = await self._initialize_conversation()
|
||||
await self._create_v1_conversation(jinja_env, conversation_id)
|
||||
return self.conversation_id
|
||||
|
||||
async def _create_v1_metadata(self) -> ConversationMetadata:
|
||||
"""Create conversation metadata for V1 conversations.
|
||||
async def _initialize_conversation(self) -> UUID:
|
||||
"""Initialize conversation and return the conversation ID.
|
||||
|
||||
The JiraConversation mapping is saved to the integration store (above), but
|
||||
V1 conversation metadata is managed by the app conversation system, not
|
||||
the legacy conversation store.
|
||||
"""
|
||||
logger.info('[Jira]: Creating V1 metadata')
|
||||
logger.info('[Jira]: Initializing V1 conversation')
|
||||
|
||||
# Generate a dummy conversation for V1 (not saved to store)
|
||||
self.conversation_id = uuid4().hex
|
||||
# Generate a conversation ID for V1
|
||||
conversation_id = uuid4()
|
||||
self.conversation_id = conversation_id.hex
|
||||
self.resolved_org_id = await self._get_resolved_org_id()
|
||||
|
||||
return ConversationMetadata(
|
||||
conversation_id=self.conversation_id,
|
||||
selected_repository=self.selected_repo,
|
||||
)
|
||||
return conversation_id
|
||||
|
||||
async def _create_v1_conversation(
|
||||
self,
|
||||
jinja_env: Environment,
|
||||
conversation_metadata: ConversationMetadata,
|
||||
conversation_id: UUID,
|
||||
):
|
||||
"""Create conversation using the new V1 app conversation system."""
|
||||
logger.info('[Jira]: Creating V1 conversation')
|
||||
@@ -236,7 +231,7 @@ class JiraNewConversationView(JiraViewInterface):
|
||||
|
||||
# Create the V1 conversation start request
|
||||
start_request = AppConversationStartRequest(
|
||||
conversation_id=UUID(conversation_metadata.conversation_id),
|
||||
conversation_id=conversation_id,
|
||||
system_message_suffix=None,
|
||||
initial_message=initial_message,
|
||||
selected_repository=self.selected_repo,
|
||||
|
||||
@@ -27,6 +27,7 @@ from openhands.agent_server.models import SendMessageRequest
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartRequest,
|
||||
AppConversationStartTaskStatus,
|
||||
ConversationTrigger,
|
||||
)
|
||||
from openhands.app_server.config import get_app_conversation_service
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
@@ -35,9 +36,6 @@ from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderHandler, ProviderType
|
||||
from openhands.sdk import TextContent
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.data_models.conversation_metadata import (
|
||||
ConversationTrigger,
|
||||
)
|
||||
|
||||
integration_store = JiraDcIntegrationStore.get_instance()
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ class SlackErrorCode(Enum):
|
||||
PROVIDER_AUTH_FAILED = 'SLACK_ERR_006'
|
||||
LLM_AUTH_FAILED = 'SLACK_ERR_007'
|
||||
MISSING_SETTINGS = 'SLACK_ERR_008'
|
||||
MISSING_SLACK_SCOPES = 'SLACK_ERR_009'
|
||||
UNEXPECTED_ERROR = 'SLACK_ERR_999'
|
||||
|
||||
|
||||
@@ -98,6 +99,11 @@ _USER_MESSAGES: dict[SlackErrorCode, str] = {
|
||||
'{username} please re-login into '
|
||||
f'[OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
),
|
||||
SlackErrorCode.MISSING_SLACK_SCOPES: (
|
||||
'⚠️ The Slack app is missing required permissions. '
|
||||
f'Please ask your workspace admin to re-install the OpenHands Slack App at {HOST_URL}/slack/install '
|
||||
'to authorize the updated permissions.'
|
||||
),
|
||||
SlackErrorCode.UNEXPECTED_ERROR: (
|
||||
'Uh oh! There was an unexpected error (ref: {code}). Please try again later.'
|
||||
),
|
||||
|
||||
@@ -112,7 +112,6 @@ class SlackViewInterface(SlackMessageView, SummaryExtractionTracker, ABC):
|
||||
should_extract: bool
|
||||
send_summary_instruction: bool
|
||||
conversation_id: str
|
||||
v1_enabled: bool
|
||||
|
||||
@abstractmethod
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
|
||||
@@ -5,6 +5,7 @@ from uuid import UUID, uuid4
|
||||
from integrations.models import Message
|
||||
from integrations.resolver_context import ResolverUserContext
|
||||
from integrations.resolver_org_router import resolve_org_for_repo
|
||||
from integrations.slack.slack_errors import SlackError, SlackErrorCode
|
||||
from integrations.slack.slack_types import (
|
||||
SlackMessageView,
|
||||
SlackViewInterface,
|
||||
@@ -13,11 +14,10 @@ from integrations.slack.slack_types import (
|
||||
from integrations.slack.slack_v1_callback_processor import SlackV1CallbackProcessor
|
||||
from integrations.utils import (
|
||||
CONVERSATION_URL,
|
||||
ENABLE_V1_SLACK_RESOLVER,
|
||||
get_user_v1_enabled_setting,
|
||||
)
|
||||
from jinja2 import Environment
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from storage.slack_conversation import SlackConversation
|
||||
from storage.slack_conversation_store import SlackConversationStore
|
||||
from storage.slack_team_store import SlackTeamStore
|
||||
@@ -26,6 +26,7 @@ from storage.slack_user import SlackUser
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartRequest,
|
||||
AppConversationStartTaskStatus,
|
||||
ConversationTrigger,
|
||||
SendMessageRequest,
|
||||
)
|
||||
from openhands.app_server.config import get_app_conversation_service
|
||||
@@ -36,9 +37,6 @@ from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
from openhands.sdk import TextContent
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.data_models.conversation_metadata import (
|
||||
ConversationTrigger,
|
||||
)
|
||||
from openhands.utils.async_utils import GENERAL_TIMEOUT
|
||||
|
||||
# =================================================
|
||||
@@ -51,10 +49,6 @@ slack_conversation_store = SlackConversationStore.get_instance()
|
||||
slack_team_store = SlackTeamStore.get_instance()
|
||||
|
||||
|
||||
async def is_v1_enabled_for_slack_resolver(user_id: str) -> bool:
    """Report whether the V1 Slack resolver applies to ``user_id``.

    Combines the result of ``get_user_v1_enabled_setting`` with the
    ``ENABLE_V1_SLACK_RESOLVER`` flag; the per-user lookup is awaited first.
    """
    user_setting = await get_user_v1_enabled_setting(user_id)
    return user_setting and ENABLE_V1_SLACK_RESOLVER
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlackNewConversationView(SlackViewInterface):
|
||||
bot_access_token: str
|
||||
@@ -70,7 +64,6 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
send_summary_instruction: bool
|
||||
conversation_id: str
|
||||
team_id: str
|
||||
v1_enabled: bool
|
||||
|
||||
def _get_initial_prompt(self, text: str, blocks: list[dict]):
|
||||
bot_id = self._get_bot_id(blocks)
|
||||
@@ -95,24 +88,34 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
messages = []
|
||||
if self.thread_ts:
|
||||
client = WebClient(token=self.bot_access_token)
|
||||
result = client.conversations_replies(
|
||||
channel=self.channel_id,
|
||||
ts=self.thread_ts,
|
||||
inclusive=True,
|
||||
latest=self.message_ts,
|
||||
limit=CONTEXT_LIMIT, # We can be smarter about getting more context/condensing it even in the future
|
||||
)
|
||||
try:
|
||||
result = client.conversations_replies(
|
||||
channel=self.channel_id,
|
||||
ts=self.thread_ts,
|
||||
inclusive=True,
|
||||
latest=self.message_ts,
|
||||
limit=CONTEXT_LIMIT, # We can be smarter about getting more context/condensing it even in the future
|
||||
)
|
||||
except SlackApiError as e:
|
||||
if e.response.get('error') == 'missing_scope':
|
||||
raise SlackError(SlackErrorCode.MISSING_SLACK_SCOPES) from e
|
||||
raise
|
||||
|
||||
messages = result['messages']
|
||||
|
||||
else:
|
||||
client = WebClient(token=self.bot_access_token)
|
||||
result = client.conversations_history(
|
||||
channel=self.channel_id,
|
||||
inclusive=True,
|
||||
latest=self.message_ts,
|
||||
limit=CONTEXT_LIMIT,
|
||||
)
|
||||
try:
|
||||
result = client.conversations_history(
|
||||
channel=self.channel_id,
|
||||
inclusive=True,
|
||||
latest=self.message_ts,
|
||||
limit=CONTEXT_LIMIT,
|
||||
)
|
||||
except SlackApiError as e:
|
||||
if e.response.get('error') == 'missing_scope':
|
||||
raise SlackError(SlackErrorCode.MISSING_SLACK_SCOPES) from e
|
||||
raise
|
||||
|
||||
messages = result['messages']
|
||||
messages.reverse()
|
||||
@@ -149,7 +152,7 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
'Attempting to start conversation without confirming selected repo from user'
|
||||
)
|
||||
|
||||
async def save_slack_convo(self, v1_enabled: bool = False):
|
||||
async def save_slack_convo(self):
|
||||
if self.slack_to_openhands_user:
|
||||
user_info: SlackUser = self.slack_to_openhands_user
|
||||
|
||||
@@ -161,7 +164,6 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
'keycloak_user_id': user_info.keycloak_user_id,
|
||||
'org_id': user_info.org_id,
|
||||
'parent_id': self.thread_ts or self.message_ts,
|
||||
'v1_enabled': v1_enabled,
|
||||
},
|
||||
)
|
||||
slack_conversation = SlackConversation(
|
||||
@@ -171,7 +173,7 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
org_id=user_info.org_id,
|
||||
parent_id=self.thread_ts
|
||||
or self.message_ts, # conversations can start in a thread reply as well; we should always reference the parent's (root-level message's) message ID
|
||||
v1_enabled=v1_enabled,
|
||||
v1_enabled=True, # All conversations are V1
|
||||
)
|
||||
await slack_conversation_store.create_slack_conversation(slack_conversation)
|
||||
|
||||
@@ -268,7 +270,7 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
)
|
||||
|
||||
logger.info(f'[Slack V1]: Created new conversation: {self.conversation_id}')
|
||||
await self.save_slack_convo(v1_enabled=True)
|
||||
await self.save_slack_convo()
|
||||
|
||||
def get_response_msg(self) -> str:
|
||||
user_info: SlackUser = self.slack_to_openhands_user
|
||||
@@ -290,13 +292,18 @@ class SlackUpdateExistingConversationView(SlackNewConversationView):
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
client = WebClient(token=self.bot_access_token)
|
||||
result = client.conversations_replies(
|
||||
channel=self.channel_id,
|
||||
ts=self.message_ts,
|
||||
inclusive=True,
|
||||
latest=self.message_ts,
|
||||
limit=1, # Get exact user message, in future we can be smarter with collecting additional context
|
||||
)
|
||||
try:
|
||||
result = client.conversations_replies(
|
||||
channel=self.channel_id,
|
||||
ts=self.message_ts,
|
||||
inclusive=True,
|
||||
latest=self.message_ts,
|
||||
limit=1, # Get exact user message, in future we can be smarter with collecting additional context
|
||||
)
|
||||
except SlackApiError as e:
|
||||
if e.response.get('error') == 'missing_scope':
|
||||
raise SlackError(SlackErrorCode.MISSING_SLACK_SCOPES) from e
|
||||
raise
|
||||
|
||||
user_message = result['messages'][0]
|
||||
user_message = self._get_initial_prompt(
|
||||
@@ -375,7 +382,7 @@ class SlackUpdateExistingConversationView(SlackNewConversationView):
|
||||
)
|
||||
|
||||
# 6. Send the message to the agent server
|
||||
url = f"{agent_server_url.rstrip('/')}/api/conversations/{UUID(self.conversation_id)}/events"
|
||||
url = f'{agent_server_url.rstrip("/")}/api/conversations/{UUID(self.conversation_id)}/events'
|
||||
|
||||
headers = {'X-Session-API-Key': running_sandbox.session_api_key}
|
||||
payload = send_message_request.model_dump()
|
||||
@@ -516,7 +523,6 @@ class SlackFactory:
|
||||
conversation_id=conversation.conversation_id,
|
||||
slack_conversation=conversation,
|
||||
team_id=team_id,
|
||||
v1_enabled=False,
|
||||
)
|
||||
|
||||
elif SlackFactory.did_user_select_repo_from_form(message):
|
||||
@@ -534,7 +540,6 @@ class SlackFactory:
|
||||
send_summary_instruction=True,
|
||||
conversation_id='',
|
||||
team_id=team_id,
|
||||
v1_enabled=False,
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -552,7 +557,6 @@ class SlackFactory:
|
||||
send_summary_instruction=True,
|
||||
conversation_id='',
|
||||
team_id=team_id,
|
||||
v1_enabled=False,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
"""
|
||||
Utilities for loading and managing pre-trained classifiers.
|
||||
|
||||
Assumes that classifiers are stored adjacent to this file in the `solvability/data` directory, using a simple
|
||||
`name + .json` pattern.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from integrations.solvability.models.classifier import SolvabilityClassifier
|
||||
|
||||
|
||||
def load_classifier(name: str) -> SolvabilityClassifier:
    """
    Load a classifier by name.

    The classifier is read from ``<name>.json`` in the directory that contains
    this module.

    Args:
        name (str): The name of the classifier to load.

    Returns:
        SolvabilityClassifier: The loaded classifier instance.

    Raises:
        FileNotFoundError: If no ``<name>.json`` file exists next to this module.
    """
    classifier_path = Path(__file__).parent / f'{name}.json'

    if not classifier_path.exists():
        raise FileNotFoundError(f"Classifier '{name}' not found at {classifier_path}")

    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    return SolvabilityClassifier.model_validate_json(
        classifier_path.read_text(encoding='utf-8')
    )
|
||||
|
||||
|
||||
def available_classifiers() -> list[str]:
    """
    List all available classifiers in the data directory.

    The data directory is the directory that contains this module; every regular
    file matching ``*.json`` there is treated as a classifier.

    Returns:
        list[str]: Classifier names (without the .json extension), sorted
        alphabetically so the result does not depend on filesystem ordering.
    """
    data_dir = Path(__file__).parent
    return sorted(path.stem for path in data_dir.glob('*.json') if path.is_file())
|
||||
File diff suppressed because one or more lines are too long
@@ -1,38 +0,0 @@
|
||||
"""
|
||||
Solvability Models Package
|
||||
|
||||
This package contains the core machine learning models and components for predicting
|
||||
the solvability of GitHub issues and similar technical problems.
|
||||
|
||||
The solvability prediction system works by:
|
||||
1. Using a Featurizer to extract semantic features from issue descriptions via LLM calls
|
||||
2. Training a RandomForestClassifier on these features to predict solvability
|
||||
3. Generating detailed reports with feature importance analysis
|
||||
|
||||
Key Components:
|
||||
- Feature: Defines individual features that can be extracted from issues
|
||||
- Featurizer: Orchestrates LLM-based feature extraction with sampling and batching
|
||||
- SolvabilityClassifier: Main ML pipeline combining featurization and classification
|
||||
- SolvabilityReport: Comprehensive output with predictions, feature analysis, and metadata
|
||||
- ImportanceStrategy: Configurable methods for calculating feature importance (SHAP, permutation, impurity)
|
||||
"""
|
||||
|
||||
from integrations.solvability.models.classifier import SolvabilityClassifier
|
||||
from integrations.solvability.models.featurizer import (
|
||||
EmbeddingDimension,
|
||||
Feature,
|
||||
FeatureEmbedding,
|
||||
Featurizer,
|
||||
)
|
||||
from integrations.solvability.models.importance_strategy import ImportanceStrategy
|
||||
from integrations.solvability.models.report import SolvabilityReport
|
||||
|
||||
# Public API of the solvability models package.
__all__ = [
    'Feature',
    'EmbeddingDimension',
    'FeatureEmbedding',
    'Featurizer',
    'ImportanceStrategy',
    'SolvabilityClassifier',
    'SolvabilityReport',
]
|
||||
@@ -1,433 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import pickle
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import shap
|
||||
from integrations.solvability.models.featurizer import Feature, Featurizer
|
||||
from integrations.solvability.models.importance_strategy import ImportanceStrategy
|
||||
from integrations.solvability.models.report import SolvabilityReport
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
PrivateAttr,
|
||||
field_serializer,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.exceptions import NotFittedError
|
||||
from sklearn.inspection import permutation_importance
|
||||
from sklearn.utils.validation import check_is_fitted
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
|
||||
|
||||
class SolvabilityClassifier(BaseModel):
|
||||
"""
|
||||
Machine learning pipeline for predicting the solvability of GitHub issues and similar problems.
|
||||
|
||||
This classifier combines LLM-based feature extraction with traditional ML classification:
|
||||
1. Uses a Featurizer to extract semantic boolean features from issue descriptions via LLM calls
|
||||
2. Trains a RandomForestClassifier on these features to predict solvability scores
|
||||
3. Provides feature importance analysis using configurable strategies (SHAP, permutation, impurity)
|
||||
4. Generates comprehensive reports with predictions, feature analysis, and cost metrics
|
||||
|
||||
The classifier supports both training on labeled data and inference on new issues, with built-in
|
||||
support for batch processing and concurrent feature extraction.
|
||||
"""
|
||||
|
||||
identifier: str
|
||||
"""
|
||||
The identifier for the classifier.
|
||||
"""
|
||||
|
||||
featurizer: Featurizer
|
||||
"""
|
||||
The featurizer to use for transforming the input data.
|
||||
"""
|
||||
|
||||
classifier: RandomForestClassifier
|
||||
"""
|
||||
The RandomForestClassifier used for predicting solvability from extracted features.
|
||||
|
||||
This ensemble model provides robust predictions and built-in feature importance metrics.
|
||||
"""
|
||||
|
||||
importance_strategy: ImportanceStrategy = ImportanceStrategy.IMPURITY
|
||||
"""
|
||||
Strategy to use for calculating feature importance.
|
||||
"""
|
||||
|
||||
samples: int = 10
|
||||
"""
|
||||
Number of samples to use for calculating feature embedding coefficients.
|
||||
"""
|
||||
|
||||
random_state: int | None = None
|
||||
"""
|
||||
Random state for reproducibility.
|
||||
"""
|
||||
|
||||
_classifier_attrs: dict[str, Any] = PrivateAttr(default_factory=dict)
|
||||
"""
|
||||
Private dictionary storing cached results from feature extraction and importance calculations.
|
||||
|
||||
Contains keys like 'features_', 'cost_', 'feature_importances_', and 'labels_' that are populated
|
||||
during transform(), fit(), and predict() operations. Access these via the corresponding properties.
|
||||
|
||||
This field is never serialized, so cached values will not persist across model save/load cycles.
|
||||
"""
|
||||
|
||||
model_config = {
|
||||
'arbitrary_types_allowed': True,
|
||||
}
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_random_state(self) -> SolvabilityClassifier:
|
||||
"""
|
||||
Validate the random state configuration between this object and the classifier.
|
||||
"""
|
||||
# If both random states are set, they definitely need to agree.
|
||||
if self.random_state is not None and self.classifier.random_state is not None:
|
||||
if self.random_state != self.classifier.random_state:
|
||||
raise ValueError(
|
||||
'The random state of the classifier and the top-level classifier must agree.'
|
||||
)
|
||||
|
||||
# Otherwise, we'll always set the classifier's random state to the top-level one.
|
||||
self.classifier.random_state = self.random_state
|
||||
|
||||
return self
|
||||
|
||||
@property
|
||||
def features_(self) -> pd.DataFrame:
|
||||
"""
|
||||
Get the features used by the classifier for the most recent inputs.
|
||||
"""
|
||||
if 'features_' not in self._classifier_attrs:
|
||||
raise ValueError(
|
||||
'SolvabilityClassifier.transform() has not yet been called.'
|
||||
)
|
||||
return self._classifier_attrs['features_']
|
||||
|
||||
@property
|
||||
def cost_(self) -> pd.DataFrame:
|
||||
"""
|
||||
Get the cost of the classifier for the most recent inputs.
|
||||
"""
|
||||
if 'cost_' not in self._classifier_attrs:
|
||||
raise ValueError(
|
||||
'SolvabilityClassifier.transform() has not yet been called.'
|
||||
)
|
||||
return self._classifier_attrs['cost_']
|
||||
|
||||
@property
|
||||
def feature_importances_(self) -> np.ndarray:
|
||||
"""
|
||||
Get the feature importances for the most recent inputs.
|
||||
"""
|
||||
if 'feature_importances_' not in self._classifier_attrs:
|
||||
raise ValueError(
|
||||
'No SolvabilityClassifier methods that produce feature importances (.fit(), .predict_proba(), and '
|
||||
'.predict()) have been called.'
|
||||
)
|
||||
return self._classifier_attrs['feature_importances_'] # type: ignore[no-any-return]
|
||||
|
||||
@property
|
||||
def is_fitted(self) -> bool:
|
||||
"""
|
||||
Check if the classifier is fitted.
|
||||
"""
|
||||
try:
|
||||
check_is_fitted(self.classifier)
|
||||
return True
|
||||
except NotFittedError:
|
||||
return False
|
||||
|
||||
def transform(self, issues: pd.Series, llm_config: LLMConfig) -> pd.DataFrame:
|
||||
"""
|
||||
Transform the input issues using the featurizer to extract features.
|
||||
|
||||
This method orchestrates the feature extraction pipeline:
|
||||
1. Uses the featurizer to generate embeddings for all issues
|
||||
2. Converts embeddings to a structured DataFrame
|
||||
3. Separates feature columns from metadata columns
|
||||
4. Stores results for later access via properties
|
||||
|
||||
Args:
|
||||
issues: A pandas Series containing the issue descriptions.
|
||||
llm_config: LLM configuration to use for feature extraction.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: A DataFrame containing only the feature columns (no metadata).
|
||||
"""
|
||||
# Generate feature embeddings for all issues using batch processing
|
||||
feature_embeddings = self.featurizer.embed_batch(
|
||||
issues, samples=self.samples, llm_config=llm_config
|
||||
)
|
||||
df = pd.DataFrame(embedding.to_row() for embedding in feature_embeddings)
|
||||
|
||||
# Split into feature columns (used by classifier) and cost columns (metadata)
|
||||
feature_columns = [feature.identifier for feature in self.featurizer.features]
|
||||
cost_columns = [col for col in df.columns if col not in feature_columns]
|
||||
|
||||
# Store both sets for access via properties
|
||||
self._classifier_attrs['features_'] = df[feature_columns]
|
||||
self._classifier_attrs['cost_'] = df[cost_columns]
|
||||
|
||||
return self.features_
|
||||
|
||||
def fit(
|
||||
self, issues: pd.Series, labels: pd.Series, llm_config: LLMConfig
|
||||
) -> SolvabilityClassifier:
|
||||
"""
|
||||
Fit the classifier to the input issues and labels.
|
||||
|
||||
Args:
|
||||
issues: A pandas Series containing the issue descriptions.
|
||||
|
||||
labels: A pandas Series containing the labels (0 or 1) for each issue.
|
||||
|
||||
llm_config: LLM configuration to use for feature extraction.
|
||||
|
||||
Returns:
|
||||
SolvabilityClassifier: The fitted classifier.
|
||||
"""
|
||||
features = self.transform(issues, llm_config=llm_config)
|
||||
self.classifier.fit(features, labels)
|
||||
|
||||
# Store labels for permutation importance calculation
|
||||
self._classifier_attrs['labels_'] = labels
|
||||
self._classifier_attrs['feature_importances_'] = self._importance(
|
||||
features, self.classifier.predict_proba(features), labels
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def predict_proba(self, issues: pd.Series, llm_config: LLMConfig) -> np.ndarray:
|
||||
"""
|
||||
Predict the solvability probabilities for the input issues.
|
||||
|
||||
Returns class probabilities where the second column represents the probability
|
||||
of the issue being solvable (positive class).
|
||||
|
||||
Args:
|
||||
issues: A pandas Series containing the issue descriptions.
|
||||
llm_config: LLM configuration to use for feature extraction.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Array of shape (n_samples, 2) with probabilities for each class.
|
||||
Column 0: probability of not solvable, Column 1: probability of solvable.
|
||||
"""
|
||||
features = self.transform(issues, llm_config=llm_config)
|
||||
scores = self.classifier.predict_proba(features)
|
||||
|
||||
# Calculate feature importances based on the configured strategy
|
||||
# For permutation importance, we need ground truth labels if available
|
||||
labels = self._classifier_attrs.get('labels_')
|
||||
if (
|
||||
self.importance_strategy == ImportanceStrategy.PERMUTATION
|
||||
and labels is not None
|
||||
):
|
||||
self._classifier_attrs['feature_importances_'] = self._importance(
|
||||
features, scores, labels
|
||||
)
|
||||
else:
|
||||
self._classifier_attrs['feature_importances_'] = self._importance(
|
||||
features, scores
|
||||
)
|
||||
|
||||
return scores # type: ignore[no-any-return]
|
||||
|
||||
def predict(self, issues: pd.Series, llm_config: LLMConfig) -> np.ndarray:
|
||||
"""
|
||||
Predict the solvability of the input issues by returning binary labels.
|
||||
|
||||
Uses a 0.5 probability threshold to convert probabilities to binary predictions.
|
||||
|
||||
Args:
|
||||
issues: A pandas Series containing the issue descriptions.
|
||||
llm_config: LLM configuration to use for feature extraction.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Boolean array where True indicates the issue is predicted as solvable.
|
||||
"""
|
||||
probabilities = self.predict_proba(issues, llm_config=llm_config)
|
||||
# Apply 0.5 threshold to convert probabilities to binary predictions
|
||||
labels = probabilities[:, 1] >= 0.5
|
||||
return labels
|
||||
|
||||
def _importance(
|
||||
self,
|
||||
features: pd.DataFrame,
|
||||
scores: np.ndarray,
|
||||
labels: np.ndarray | None = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Calculate feature importance scores using the configured strategy.
|
||||
|
||||
Different strategies provide different interpretations:
|
||||
- SHAP: Shapley values indicating contribution to individual predictions
|
||||
- PERMUTATION: Decrease in model performance when feature is shuffled
|
||||
- IMPURITY: Gini impurity decrease from splits on each feature
|
||||
|
||||
Args:
|
||||
features: Feature matrix used for predictions.
|
||||
scores: Model prediction scores (unused for some strategies).
|
||||
labels: Ground truth labels (required for permutation importance).
|
||||
|
||||
Returns:
|
||||
np.ndarray: Feature importance scores, one per feature.
|
||||
"""
|
||||
match self.importance_strategy:
|
||||
case ImportanceStrategy.SHAP:
|
||||
# Use SHAP TreeExplainer for tree-based models
|
||||
explainer = shap.TreeExplainer(self.classifier)
|
||||
shap_values = explainer.shap_values(features)
|
||||
# Return mean SHAP values for the positive class (solvable)
|
||||
return shap_values.mean(axis=0)[:, 1] # type: ignore[no-any-return]
|
||||
|
||||
case ImportanceStrategy.PERMUTATION:
|
||||
# Permutation importance requires ground truth labels
|
||||
if labels is None:
|
||||
raise ValueError('Labels are required for permutation importance')
|
||||
result = permutation_importance(
|
||||
self.classifier,
|
||||
features,
|
||||
labels,
|
||||
n_repeats=10, # Number of permutation rounds for stability
|
||||
random_state=self.random_state,
|
||||
)
|
||||
return result.importances_mean # type: ignore[no-any-return]
|
||||
|
||||
case ImportanceStrategy.IMPURITY:
|
||||
# Use built-in feature importances from RandomForest
|
||||
return self.classifier.feature_importances_ # type: ignore[no-any-return]
|
||||
|
||||
case _:
|
||||
raise ValueError(
|
||||
f'Unknown importance strategy: {self.importance_strategy}'
|
||||
)
|
||||
|
||||
def add_features(self, features: list[Feature]) -> SolvabilityClassifier:
|
||||
"""
|
||||
Add new features to the classifier's featurizer.
|
||||
|
||||
Note: Adding features after training requires retraining the classifier
|
||||
since the feature space will have changed.
|
||||
|
||||
Args:
|
||||
features: List of Feature objects to add.
|
||||
|
||||
Returns:
|
||||
SolvabilityClassifier: Self for method chaining.
|
||||
"""
|
||||
for feature in features:
|
||||
if feature not in self.featurizer.features:
|
||||
self.featurizer.features.append(feature)
|
||||
return self
|
||||
|
||||
def forget_features(self, features: list[Feature]) -> SolvabilityClassifier:
|
||||
"""
|
||||
Remove features from the classifier's featurizer.
|
||||
|
||||
Note: Removing features after training requires retraining the classifier
|
||||
since the feature space will have changed.
|
||||
|
||||
Args:
|
||||
features: List of Feature objects to remove.
|
||||
|
||||
Returns:
|
||||
SolvabilityClassifier: Self for method chaining.
|
||||
"""
|
||||
for feature in features:
|
||||
try:
|
||||
self.featurizer.features.remove(feature)
|
||||
except ValueError:
|
||||
# Feature not in list, continue with others
|
||||
continue
|
||||
return self
|
||||
|
||||
@field_serializer('classifier')
|
||||
@staticmethod
|
||||
def _rfc_to_json(rfc: RandomForestClassifier) -> str:
|
||||
"""
|
||||
Convert a RandomForestClassifier to a JSON-compatible value (a string).
|
||||
"""
|
||||
return base64.b64encode(pickle.dumps(rfc)).decode('utf-8')
|
||||
|
||||
@field_validator('classifier', mode='before')
|
||||
@staticmethod
|
||||
def _json_to_rfc(value: str | RandomForestClassifier) -> RandomForestClassifier:
|
||||
"""
|
||||
Convert a JSON-compatible value (a string) back to a RandomForestClassifier.
|
||||
"""
|
||||
if isinstance(value, RandomForestClassifier):
|
||||
return value
|
||||
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
model = pickle.loads(base64.b64decode(value))
|
||||
if isinstance(model, RandomForestClassifier):
|
||||
return model
|
||||
except Exception as e:
|
||||
raise ValueError(f'Failed to decode the classifier: {e}')
|
||||
|
||||
raise ValueError(
|
||||
'The classifier must be a RandomForestClassifier or a JSON-compatible dictionary.'
|
||||
)
|
||||
|
||||
def solvability_report(
|
||||
self, issue: str, llm_config: LLMConfig, **kwargs: Any
|
||||
) -> SolvabilityReport:
|
||||
"""
|
||||
Generate a solvability report for the given issue.
|
||||
|
||||
Args:
|
||||
issue: The issue description for which to generate the report.
|
||||
llm_config: Optional LLM configuration to use for feature extraction.
|
||||
kwargs: Additional metadata to include in the report.
|
||||
|
||||
Returns:
|
||||
SolvabilityReport: The generated solvability report.
|
||||
"""
|
||||
if not self.is_fitted:
|
||||
raise ValueError(
|
||||
'The classifier must be fitted before generating a report.'
|
||||
)
|
||||
|
||||
scores = self.predict_proba(pd.Series([issue]), llm_config=llm_config)
|
||||
|
||||
return SolvabilityReport(
|
||||
identifier=self.identifier,
|
||||
issue=issue,
|
||||
score=scores[0, 1],
|
||||
features=self.features_.iloc[0].to_dict(),
|
||||
samples=self.samples,
|
||||
importance_strategy=self.importance_strategy,
|
||||
# Unlike the features, the importances are just a series with no link
|
||||
# to the actual feature names. For that we have to recombine with the
|
||||
# feature identifiers.
|
||||
feature_importances=dict(
|
||||
zip(
|
||||
self.featurizer.feature_identifiers(),
|
||||
self.feature_importances_.tolist(),
|
||||
)
|
||||
),
|
||||
random_state=self.random_state,
|
||||
metadata=dict(kwargs) if kwargs else None,
|
||||
# Both cost and response_latency are columns in the cost_ DataFrame,
|
||||
# so we can get both by just unpacking the first row.
|
||||
**self.cost_.iloc[0].to_dict(),
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self, issue: str, llm_config: LLMConfig, **kwargs: Any
|
||||
) -> SolvabilityReport:
|
||||
"""
|
||||
Generate a solvability report for the given issue.
|
||||
"""
|
||||
return self.solvability_report(issue, llm_config=llm_config, **kwargs)
|
||||
@@ -1,38 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class DifficultyLevel(Enum):
    """Enum representing the difficulty level based on solvability score.

    Each member's value is a ``(label, threshold, emoji)`` tuple, unpacked into
    attributes of the same names. ``threshold`` is the lowest score (inclusive)
    that maps to the member.
    """

    EASY = ('EASY', 0.7, '🟢')
    MEDIUM = ('MEDIUM', 0.4, '🟡')
    HARD = ('HARD', 0.0, '🔴')

    def __init__(self, label: str, threshold: float, emoji: str) -> None:
        self.label = label
        self.threshold = threshold
        self.emoji = emoji

    @classmethod
    def from_score(cls, score: float) -> DifficultyLevel:
        """Get difficulty level from a solvability score.

        Returns the difficulty level with the highest threshold that is less than or equal to the given score.
        Scores that meet no threshold (negative values, NaN) fall back to the
        level with the lowest threshold.
        """
        # Highest threshold first, so the first match is the most demanding level met.
        ranked = sorted(cls, key=lambda level: level.threshold, reverse=True)
        return next(
            (level for level in ranked if score >= level.threshold),
            ranked[-1],
        )

    def format_display(self) -> str:
        """Format the difficulty level for display."""
        return f'{self.emoji} **Solvability: {self.label}**'
|
||||
@@ -1,368 +0,0 @@
|
||||
import json
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.llm.llm import LLM
|
||||
|
||||
|
||||
class Feature(BaseModel):
    """
    Represents a single boolean feature that can be extracted from issue descriptions.

    Features are semantic properties of issues (e.g., "has_code_example", "requires_debugging")
    that are evaluated by LLMs and used as input to the solvability classifier.
    """

    identifier: str
    """Unique identifier for the feature, used as column name in feature matrices."""

    description: str
    """Human-readable description of what the feature represents, used in LLM prompts."""

    @property
    def to_tool_description_field(self) -> dict[str, Any]:
        """
        Convert this feature to a JSON schema field for LLM tool calling.

        This is a property, so it is read as an attribute rather than called.

        Returns:
            dict: JSON schema field definition for this feature.
        """
        return {
            'type': 'boolean',
            'description': self.description,
        }
|
||||
|
||||
|
||||
class EmbeddingDimension(BaseModel):
    """
    Represents a single dimension (feature evaluation) within a feature embedding sample.

    Each dimension corresponds to one feature being evaluated as true/false for a given issue.
    """

    feature_id: str
    """Identifier of the feature being evaluated."""

    result: bool
    """Boolean result of the feature evaluation for this sample."""
|
||||
|
||||
|
||||
# Type alias for a single embedding sample - maps feature identifiers to boolean values
EmbeddingSample = dict[str, bool]
"""
A single sample from the LLM evaluation of features for an issue.
Maps feature identifiers to their boolean evaluations.
"""
|
||||
|
||||
|
||||
class FeatureEmbedding(BaseModel):
    """
    Represents the complete feature embedding for a single issue, including multiple samples
    and associated metadata about the LLM calls used to generate it.

    Multiple samples are collected to account for LLM variability and provide more robust
    feature estimates through averaging.
    """

    samples: list[EmbeddingSample]
    """List of individual feature evaluation samples from the LLM."""

    prompt_tokens: int | None = None
    """Total prompt tokens consumed across all LLM calls for this embedding."""

    completion_tokens: int | None = None
    """Total completion tokens generated across all LLM calls for this embedding."""

    response_latency: float | None = None
    """Total response latency (seconds) across all LLM calls for this embedding."""

    @property
    def dimensions(self) -> list[str]:
        """
        Get all unique feature identifiers present across all samples.

        Returns:
            list[str]: Feature identifiers that appear in at least one sample, in
            first-seen order so the result is the same on every run.
        """
        # dict.fromkeys de-duplicates while keeping insertion order; a plain set
        # would make the order depend on string hash randomization.
        return list(dict.fromkeys(key for sample in self.samples for key in sample))

    def coefficient(self, dimension: str) -> float | None:
        """
        Calculate the average coefficient (0-1) for a specific feature dimension.

        This computes the proportion of samples where the feature was evaluated as True,
        providing a continuous feature value for the classifier. Samples that do not
        report the feature are left out of the average.

        Args:
            dimension: Feature identifier to calculate coefficient for.

        Returns:
            float | None: Average coefficient (0.0-1.0), or None if dimension not found.
        """
        observed = [
            value
            for value in (sample.get(dimension) for sample in self.samples)
            if value is not None
        ]
        if not observed:
            return None
        return sum(1 for value in observed if value) / len(observed)

    def to_row(self) -> dict[str, Any]:
        """
        Convert the embedding to a flat dictionary suitable for DataFrame construction.

        Returns:
            dict[str, Any]: Dictionary with metadata fields and feature coefficients.
        """
        return {
            'response_latency': self.response_latency,
            'prompt_tokens': self.prompt_tokens,
            'completion_tokens': self.completion_tokens,
            **{dimension: self.coefficient(dimension) for dimension in self.dimensions},
        }

    def sample_entropy(self) -> dict[str, float]:
        """
        Calculate the Shannon entropy of feature evaluations across samples.

        Higher entropy indicates more variability in LLM responses for a feature,
        which may suggest ambiguity in the feature definition or issue description.

        Returns:
            dict[str, float]: Mapping of feature identifiers to their entropy values (0-1).
        """
        from collections import Counter
        from math import log2

        entropy: dict[str, float] = {}
        for dimension in self.dimensions:
            # Count True/False occurrences for this feature across samples.
            # NOTE(review): a sample that omits the feature is counted as False here,
            # whereas coefficient() skips it - confirm which treatment is intended.
            counts = Counter(sample.get(dimension, False) for sample in self.samples)
            # A dimension only exists if some sample reports it, so total is >= 1.
            total = sum(counts.values())
            # Calculate Shannon entropy: -Σ(p * log2(p))
            entropy[dimension] = -sum(
                (count / total) * log2(count / total) for count in counts.values()
            )
        return entropy
|
||||
|
||||
|
||||
class Featurizer(BaseModel):
|
||||
"""
|
||||
Orchestrates LLM-based feature extraction from issue descriptions.
|
||||
|
||||
The Featurizer uses structured LLM tool calling to evaluate boolean features
|
||||
for issue descriptions. It handles prompt construction, tool schema generation,
|
||||
and batch processing with concurrency.
|
||||
"""
|
||||
|
||||
system_prompt: str
|
||||
"""System prompt that provides context and instructions to the LLM."""
|
||||
|
||||
message_prefix: str
|
||||
"""Prefix added to user messages before the issue description."""
|
||||
|
||||
features: list[Feature]
|
||||
"""List of features to extract from each issue description."""
|
||||
|
||||
def system_message(self) -> dict[str, Any]:
|
||||
"""
|
||||
Construct the system message for LLM conversations.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: System message dictionary for LLM API calls.
|
||||
"""
|
||||
return {
|
||||
'role': 'system',
|
||||
'content': self.system_prompt,
|
||||
}
|
||||
|
||||
def user_message(
|
||||
self, issue_description: str, set_cache: bool = True
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Construct the user message containing the issue description.
|
||||
|
||||
Args:
|
||||
issue_description: The description of the issue to analyze.
|
||||
set_cache: Whether to enable ephemeral caching for this message.
|
||||
Should be False for single samples to avoid cache overhead.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: User message dictionary for LLM API calls.
|
||||
"""
|
||||
message: dict[str, Any] = {
|
||||
'role': 'user',
|
||||
'content': f'{self.message_prefix}{issue_description}',
|
||||
}
|
||||
if set_cache:
|
||||
message['cache_control'] = {'type': 'ephemeral'}
|
||||
return message
|
||||
|
||||
@property
|
||||
def tool_choice(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get the tool choice configuration for forcing LLM to use the featurizer tool.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: Tool choice configuration for LLM API calls.
|
||||
"""
|
||||
return {
|
||||
'type': 'function',
|
||||
'function': {'name': 'call_featurizer'},
|
||||
}
|
||||
|
||||
@property
def tool_description(self) -> dict[str, Any]:
    """
    Build the tool schema for the ``call_featurizer`` function.

    Every configured feature contributes one entry to the schema's
    ``properties`` mapping, keyed by the feature identifier and valued by the
    feature's ``to_tool_description_field``.

    Returns:
        dict[str, Any]: Complete tool description for LLM API calls.
    """
    properties = {}
    for feature in self.features:
        properties[feature.identifier] = feature.to_tool_description_field

    parameters = {'type': 'object', 'properties': properties}
    function_spec = {
        'name': 'call_featurizer',
        'description': 'Record the features present in the issue.',
        'parameters': parameters,
    }
    return {'type': 'function', 'function': function_spec}
|
||||
|
||||
def embed(
    self,
    issue_description: str,
    llm_config: LLMConfig,
    temperature: float = 1.0,
    samples: int = 10,
) -> FeatureEmbedding:
    """
    Generate a feature embedding for a single issue description.

    Issues ``samples`` sequential LLM completions, each forced to call the
    featurizer tool, and collects the JSON-decoded tool arguments of each call.
    Token counts and latency are summed across all calls.

    Args:
        issue_description: The description of the issue to analyze.
        llm_config: Configuration for the LLM to use.
        temperature: Sampling temperature for the model. Higher values increase randomness.
        samples: Number of samples to generate for averaging.

    Returns:
        FeatureEmbedding: Complete embedding with samples and metadata.
    """
    embedding_samples: list[dict[str, Any]] = []
    response_latency: float = 0.0
    prompt_tokens: int = 0
    completion_tokens: int = 0

    # TODO: use llm registry
    llm = LLM(llm_config, service_id='solvability')

    # Caching is only requested when the same prompt is sent more than once.
    use_cache = samples > 1

    # Generate multiple samples to account for LLM variability
    for _ in range(samples):
        # perf_counter is monotonic, so the measured latency cannot be skewed
        # (or go negative) when the system clock is adjusted mid-request.
        start_time = time.perf_counter()
        response = llm.completion(
            messages=[
                self.system_message(),
                self.user_message(issue_description, set_cache=use_cache),
            ],
            tools=[self.tool_description],
            tool_choice=self.tool_choice,
            temperature=temperature,
        )
        latency = time.perf_counter() - start_time

        # Parse the structured tool call response containing feature evaluations
        features = response.choices[0].message.tool_calls[0].function.arguments  # type: ignore[index, union-attr]
        embedding = json.loads(features)

        # Accumulate results and metrics
        embedding_samples.append(embedding)
        prompt_tokens += response.usage.prompt_tokens  # type: ignore[union-attr, attr-defined]
        completion_tokens += response.usage.completion_tokens  # type: ignore[union-attr, attr-defined]
        response_latency += latency

    return FeatureEmbedding(
        samples=embedding_samples,
        response_latency=response_latency,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
|
||||
|
||||
def embed_batch(
    self,
    issue_descriptions: list[str],
    llm_config: LLMConfig,
    temperature: float = 1.0,
    samples: int = 10,
) -> list[FeatureEmbedding]:
    """
    Embed several issue descriptions concurrently on a thread pool.

    One ``embed`` task is submitted per description; results are written back
    into the slot matching the description's input position, so the output
    order mirrors the input order regardless of completion order.

    Args:
        issue_descriptions: List of issue descriptions to analyze.
        llm_config: Configuration for the LLM to use.
        temperature: Sampling temperature for the model.
        samples: Number of samples to generate per issue.

    Returns:
        list[FeatureEmbedding]: List of embeddings in the same order as input.
    """
    with ThreadPoolExecutor() as pool:
        # Remember which input position each submitted task belongs to.
        position_of = {}
        for position, description in enumerate(issue_descriptions):
            task = pool.submit(
                self.embed,
                description,
                llm_config,
                temperature=temperature,
                samples=samples,
            )
            position_of[task] = position

        ordered: list[FeatureEmbedding] = [None] * len(issue_descriptions)  # type: ignore[list-item]
        for finished in as_completed(position_of):
            ordered[position_of[finished]] = finished.result()

    return ordered
|
||||
|
||||
def feature_identifiers(self) -> list[str]:
    """
    Collect the identifier of every configured feature.

    Returns:
        list[str]: Feature identifiers, in the order of ``self.features``.
    """
    identifiers: list[str] = []
    for feature in self.features:
        identifiers.append(feature.identifier)
    return identifiers
|
||||
@@ -1,23 +0,0 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ImportanceStrategy(str, Enum):
    """
    How feature importances are calculated.

    Feature importances estimate the predictive power of each feature and are
    used in training loops and explanations.
    """

    # SHAP (SHapley Additive exPlanations) feature importances.
    SHAP = 'shap'

    # Permutation-based feature importances.
    PERMUTATION = 'permutation'

    # Impurity-based feature importances from the RandomForestClassifier.
    IMPURITY = 'impurity'
|
||||
@@ -1,87 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from integrations.solvability.models.importance_strategy import ImportanceStrategy
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SolvabilityReport(BaseModel):
    """
    Solvability prediction and supporting analysis for a single issue.

    Bundles the solvability score, the extracted feature values, the feature
    importance analysis, cost metrics (tokens and latency), and metadata about
    the prediction process. It is the primary output format for solvability
    analysis and can be used for logging, debugging, and generating
    human-readable summaries.
    """

    identifier: str
    """Identifier of the solvability model that produced the report."""

    issue: str
    """Issue description being scored; exactly the input given to the solvability model."""

    score: float
    """[0, 1]-valued likelihood that the issue is solvable."""

    prompt_tokens: int
    """Total prompt tokens used by the API calls that generated the features."""

    completion_tokens: int
    """Total completion tokens used by the API calls that generated the features."""

    response_latency: float
    """Total response latency of the API calls that generated the features."""

    features: dict[str, float]
    """
    [0, 1]-valued score per feature.

    These values are fed to the random forest classifier to generate the solvability score.
    """

    samples: int
    """Number of samples used to compute the feature embedding coefficients."""

    importance_strategy: ImportanceStrategy
    """Strategy used to calculate feature importances."""

    feature_importances: dict[str, float]
    """
    Importance score per feature.

    How to interpret these scores depends on the importance strategy used.
    """

    created_at: datetime = Field(default_factory=datetime.now)
    """Datetime when the report was created."""

    random_state: int | None = None
    """Classifier random state used when generating this report."""

    metadata: dict[str, Any] | None = None
    """Metadata for logging and debugging purposes."""
|
||||
@@ -1,172 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from integrations.solvability.models.difficulty_level import DifficultyLevel
|
||||
from integrations.solvability.models.report import SolvabilityReport
|
||||
from integrations.solvability.prompts import load_prompt
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from openhands.llm import LLM
|
||||
|
||||
|
||||
class SolvabilitySummary(BaseModel):
    """Summary of the solvability analysis in human-readable format."""

    score: float
    """
    Solvability score indicating the likelihood of the issue being solvable.
    """

    summary: str
    """
    The executive summary content generated by the LLM.
    """

    actionable_feedback: str
    """
    Actionable feedback content generated by the LLM.
    """

    positive_feedback: str
    """
    Positive feedback content generated by the LLM, highlighting what is good about the issue.
    """

    prompt_tokens: int
    """
    Number of prompt tokens used in the API call to generate the summary.
    """

    completion_tokens: int
    """
    Number of completion tokens used in the API call to generate the summary.
    """

    response_latency: float
    """
    Response latency of the API call to generate the summary.
    """

    created_at: datetime = Field(default_factory=datetime.now)
    """
    Datetime when the summary was created.
    """

    @staticmethod
    def tool_description() -> dict[str, Any]:
        """Get the tool description for the LLM.

        The schema declares every LLM-supplied model field as required so the
        forced tool call always carries the values ``from_report`` needs.
        """
        return {
            'type': 'function',
            'function': {
                'name': 'solvability_summary',
                'description': 'Generate a human-readable summary of the solvability analysis.',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        'summary': {
                            'type': 'string',
                            'description': 'A high-level (at most two sentences) summary of the solvability report.',
                        },
                        'actionable_feedback': {
                            'type': 'string',
                            'description': (
                                'Bullet list of 1-3 pieces of actionable feedback on how the user can address the lowest scoring relevant features.'
                            ),
                        },
                        'positive_feedback': {
                            'type': 'string',
                            'description': (
                                'Bullet list of 1-3 pieces of positive feedback on the issue, highlighting what is good about it.'
                            ),
                        },
                    },
                    # positive_feedback is a mandatory model field, so it must be
                    # required here too; otherwise an LLM that omits it makes
                    # from_report fail model validation.
                    'required': [
                        'summary',
                        'actionable_feedback',
                        'positive_feedback',
                    ],
                },
            },
        }

    @staticmethod
    def tool_choice() -> dict[str, Any]:
        """Get the tool choice for the LLM."""
        return {
            'type': 'function',
            'function': {
                'name': 'solvability_summary',
            },
        }

    @staticmethod
    def system_message() -> dict[str, Any]:
        """Get the system message for the LLM."""
        return {
            'role': 'system',
            'content': load_prompt('summary_system_message'),
        }

    @staticmethod
    def user_message(report: SolvabilityReport) -> dict[str, Any]:
        """Get the user message for the LLM.

        Renders the ``summary_user_message`` prompt with the dumped report and
        the difficulty level derived from the report's score.
        """
        return {
            'role': 'user',
            'content': load_prompt(
                'summary_user_message',
                report=report.model_dump(),
                difficulty_level=DifficultyLevel.from_score(report.score).value[0],
            ),
        }

    @staticmethod
    def from_report(report: SolvabilityReport, llm: LLM) -> SolvabilitySummary:
        """Create a SolvabilitySummary from a SolvabilityReport.

        Makes a single LLM completion that is forced to call the
        ``solvability_summary`` tool, then builds the summary from the tool
        call's JSON arguments plus usage metrics from the response.

        Args:
            report: The solvability report to summarize.
            llm: The LLM used to generate the summary.

        Returns:
            SolvabilitySummary: The generated summary.
        """
        import time

        # perf_counter is monotonic, so the latency is unaffected by wall-clock
        # adjustments during the request.
        start_time = time.perf_counter()
        response = llm.completion(
            messages=[
                SolvabilitySummary.system_message(),
                SolvabilitySummary.user_message(report),
            ],
            tools=[SolvabilitySummary.tool_description()],
            tool_choice=SolvabilitySummary.tool_choice(),
        )
        response_latency = time.perf_counter() - start_time

        # Grab the arguments from the forced function call
        arguments = json.loads(
            response.choices[0].message.tool_calls[0].function.arguments
        )

        return SolvabilitySummary(
            # The score is copied directly from the report
            score=report.score,
            # Performance and usage metrics are pulled from the response
            prompt_tokens=response.usage.prompt_tokens,
            completion_tokens=response.usage.completion_tokens,
            response_latency=response_latency,
            # Every other field should be taken from the forced function call
            **arguments,
        )

    def format_as_markdown(self) -> str:
        """Format the summary content as Markdown.

        Returns:
            str: The difficulty display and summary, followed by actionable
                feedback unless the issue is EASY and positive feedback unless
                the issue is HARD.
        """
        # Convert score to difficulty level enum
        difficulty_level = DifficultyLevel.from_score(self.score)

        # Create the main difficulty display
        result = f'{difficulty_level.format_display()}\n\n{self.summary}'

        # If not easy, tell the user how the issue could be made easier
        if difficulty_level != DifficultyLevel.EASY:
            result += '\n\nYou can make the issue easier to resolve by addressing these concerns in the conversation:\n\n'
            result += self.actionable_feedback

        # If the difficulty isn't hard, add some positive feedback
        if difficulty_level != DifficultyLevel.HARD:
            result += '\n\nPositive feedback:\n\n'
            result += self.positive_feedback

        return result
|
||||
@@ -1,13 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
import jinja2
|
||||
|
||||
|
||||
def load_prompt(prompt: str, **kwargs) -> str:
    """Render the ``<prompt>.j2`` template that sits next to this module.

    All keyword arguments are forwarded to the template as render variables.
    """
    templates_dir = Path(__file__).parent
    loader = jinja2.FileSystemLoader(templates_dir)
    template = jinja2.Environment(loader=loader).get_template(f'{prompt}.j2')
    return template.render(**kwargs)
|
||||
|
||||
|
||||
__all__ = ['load_prompt']
|
||||
@@ -1,10 +0,0 @@
|
||||
You are a helpful assistant that generates human-readable summaries of solvability reports.
|
||||
The report predicts how likely it is that the issue can be resolved, and is produced purely based on the information provided in the issue description and comments.
|
||||
The report explains which features are present in the issue and how impactful they are to the solvability score (using SHAP values).
|
||||
Your task is to create a concise, high-level summary of the solvability analysis,
|
||||
with an emphasis on the key factors that make the issue easy or hard to resolve.
|
||||
Focus on the features with extreme scores, BUT ONLY if they are related to the issue at hand after careful consideration.
|
||||
You should NEVER mention: SHAP, scores, feature names, or technical metrics.
|
||||
You will also be given the expected difficulty of the issue, as EASY/MEDIUM/HARD.
|
||||
Be sure to frame your responses with that difficulty in mind.
|
||||
For example, if the issue is HARD you should not describe it as "straightforward".
|
||||
@@ -1,9 +0,0 @@
|
||||
Generate a high-level summary of the solvability report:
|
||||
|
||||
{{ report }}
|
||||
|
||||
We estimate the issue is {{ difficulty_level }}.
|
||||
The summary should be concise (at most two sentences) and describe the primary characteristics of this issue.
|
||||
Focus on what information is present and what factors are most relevant to resolution.
|
||||
Actionable feedback should be something that can be addressed by the user purely by providing more information.
|
||||
Positive feedback should explain the features that are positively contributing to the solvability score.
|
||||
@@ -1,6 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from jinja2 import Environment
|
||||
from pydantic import BaseModel
|
||||
@@ -10,7 +11,6 @@ if TYPE_CHECKING:
|
||||
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
|
||||
|
||||
|
||||
class GitLabResourceType(Enum):
|
||||
@@ -53,11 +53,11 @@ class ResolverViewInterface(SummaryExtractionTracker):
|
||||
"""Instructions passed when conversation is first initialized."""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def initialize_new_conversation(self) -> 'ConversationMetadata':
|
||||
"""Initialize a new conversation and return metadata.
|
||||
async def initialize_new_conversation(self) -> UUID:
|
||||
"""Initialize a new conversation and return the conversation ID.
|
||||
|
||||
For V1 conversations, creates a dummy ConversationMetadata.
|
||||
For V0 conversations, initializes through the conversation store.
|
||||
This method resolves the target organization and generates a new
|
||||
conversation ID.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -65,7 +65,7 @@ class ResolverViewInterface(SummaryExtractionTracker):
|
||||
self,
|
||||
jinja_env: Environment,
|
||||
git_provider_tokens: 'PROVIDER_TOKEN_TYPE',
|
||||
conversation_metadata: 'ConversationMetadata',
|
||||
conversation_id: UUID,
|
||||
saas_user_auth: 'UserAuth',
|
||||
) -> None:
|
||||
"""Create a new conversation.
|
||||
@@ -73,7 +73,7 @@ class ResolverViewInterface(SummaryExtractionTracker):
|
||||
Args:
|
||||
jinja_env: Jinja2 environment for template rendering
|
||||
git_provider_tokens: Token mapping for git providers
|
||||
conversation_metadata: Metadata for the conversation
|
||||
conversation_id: The UUID of the conversation to create
|
||||
saas_user_auth: User authentication for SaaS
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -1,23 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from server.constants import WEB_HOST
|
||||
from storage.org_store import OrgStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events import Event, EventSource
|
||||
from openhands.events.action import (
|
||||
AgentFinishAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.event_filter import EventFilter
|
||||
from openhands.events.event_store_abc import EventStoreABC
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
from openhands.integrations.service_types import Repository
|
||||
|
||||
# ---- DO NOT REMOVE ----
|
||||
@@ -27,10 +15,8 @@ HOST = WEB_HOST
|
||||
|
||||
IS_LOCAL_DEPLOYMENT = 'localhost' in HOST
|
||||
HOST_URL = f'https://{HOST}' if not IS_LOCAL_DEPLOYMENT else f'http://{HOST}'
|
||||
GITHUB_WEBHOOK_URL = f'{HOST_URL}/integration/github/events'
|
||||
GITLAB_WEBHOOK_URL = f'{HOST_URL}/integration/gitlab/events'
|
||||
conversation_prefix = 'conversations/{}'
|
||||
CONVERSATION_URL = f'{HOST_URL}/{conversation_prefix}'
|
||||
CONVERSATION_URL = f'{HOST_URL}/conversations/{{}}'
|
||||
|
||||
# Toggle for auto-response feature that proactively starts conversations with users when workflow tests fail
|
||||
ENABLE_PROACTIVE_CONVERSATION_STARTERS = (
|
||||
@@ -77,30 +63,11 @@ def get_user_not_found_message(username: str | None = None) -> str:
|
||||
return f"It looks like you haven't created an OpenHands account yet. Please sign up at [OpenHands Cloud]({HOST_URL}) and try again."
|
||||
|
||||
|
||||
# Toggle for solvability report feature
|
||||
ENABLE_SOLVABILITY_ANALYSIS = (
|
||||
os.getenv('ENABLE_SOLVABILITY_ANALYSIS', 'false').lower() == 'true'
|
||||
)
|
||||
|
||||
# Toggle for V1 GitHub resolver feature
|
||||
ENABLE_V1_GITHUB_RESOLVER = (
|
||||
os.getenv('ENABLE_V1_GITHUB_RESOLVER', 'false').lower() == 'true'
|
||||
)
|
||||
|
||||
ENABLE_V1_SLACK_RESOLVER = (
|
||||
os.getenv('ENABLE_V1_SLACK_RESOLVER', 'false').lower() == 'true'
|
||||
)
|
||||
|
||||
# Toggle for V1 GitLab resolver feature
|
||||
ENABLE_V1_GITLAB_RESOLVER = (
|
||||
os.getenv('ENABLE_V1_GITLAB_RESOLVER', 'false').lower() == 'true'
|
||||
)
|
||||
|
||||
OPENHANDS_RESOLVER_TEMPLATES_DIR = (
|
||||
os.getenv('OPENHANDS_RESOLVER_TEMPLATES_DIR')
|
||||
or 'openhands/integrations/templates/resolver/'
|
||||
)
|
||||
jinja_env = Environment(loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR))
|
||||
_jinja_env = Environment(loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR))
|
||||
|
||||
|
||||
def get_oh_labels(web_host: str) -> tuple[str, str]:
|
||||
@@ -122,31 +89,11 @@ def get_oh_labels(web_host: str) -> tuple[str, str]:
|
||||
|
||||
|
||||
def get_summary_instruction():
|
||||
summary_instruction_template = jinja_env.get_template('summary_prompt.j2')
|
||||
summary_instruction_template = _jinja_env.get_template('summary_prompt.j2')
|
||||
summary_instruction = summary_instruction_template.render()
|
||||
return summary_instruction
|
||||
|
||||
|
||||
async def get_user_v1_enabled_setting(user_id: str | None) -> bool:
    """Look up whether V1 conversations are enabled for a user's current org.

    Args:
        user_id: The keycloak user ID

    Returns:
        The org's ``v1_enabled`` value; False when no user ID is given, no org
        is found, or the org has no value set.
    """
    if user_id:
        org = await OrgStore.get_current_org_from_keycloak_user_id(user_id)
        if org and org.v1_enabled is not None:
            return org.v1_enabled
    return False
|
||||
|
||||
|
||||
def has_exact_mention(text: str, mention: str) -> bool:
|
||||
"""Check if the text contains an exact mention (not part of a larger word).
|
||||
|
||||
@@ -173,205 +120,6 @@ def has_exact_mention(text: str, mention: str) -> bool:
|
||||
return bool(re.search(rf'(?:^|[^\w@]){pattern}(?![\w-])', text_lower))
|
||||
|
||||
|
||||
def confirm_event_type(event: Event):
    """Return True for agent-state-change observations outside the excluded states.

    The excluded states are REJECTED, USER_CONFIRMED, USER_REJECTED, LOADING
    and RUNNING; any other event type yields False.
    """
    if not isinstance(event, AgentStateChangedObservation):
        return False

    excluded_states = (
        AgentState.REJECTED,
        AgentState.USER_CONFIRMED,
        AgentState.USER_REJECTED,
        AgentState.LOADING,
        AgentState.RUNNING,
    )
    return not any(event.agent_state == excluded for excluded in excluded_states)
|
||||
|
||||
|
||||
def get_readable_error_reason(reason: str):
    """Translate a known LLM status code into a user-facing sentence.

    Unrecognized reasons are returned unchanged.
    """
    readable_by_code = {
        'STATUS$ERROR_LLM_AUTHENTICATION': 'Authentication with the LLM provider failed. Please check your API key or credentials',
        'STATUS$ERROR_LLM_SERVICE_UNAVAILABLE': 'The LLM service is temporarily unavailable. Please try again later',
        'STATUS$ERROR_LLM_INTERNAL_SERVER_ERROR': 'The LLM provider encountered an internal error. Please try again soon',
        'STATUS$ERROR_LLM_OUT_OF_CREDITS': "You've run out of credits. Please top up to continue",
        'STATUS$ERROR_LLM_CONTENT_POLICY_VIOLATION': 'Content policy violation. The output was blocked by content filtering policy',
    }
    return readable_by_code.get(reason, reason)
|
||||
|
||||
|
||||
def get_summary_for_agent_state(
    observations: list[AgentStateChangedObservation], conversation_link: str
) -> str:
    """Turn the first agent-state observation into a user-facing status message.

    Handles the RATE_LIMITED, ERROR and AWAITING_USER_INPUT states explicitly;
    an empty observation list or any other state is logged as an error and
    yields a generic "unknown error" message linking to the conversation.

    Args:
        observations: Agent state observations; only the first one is inspected.
        conversation_link: URL of the conversation, embedded in the message.

    Returns:
        The message to show to the user.
    """
    fallback_message = f'OpenHands encountered an unknown error. [See the conversation]({conversation_link}) for more information, or try again'

    if len(observations) == 0:
        logger.error(
            'Unknown error: No agent state observations found',
            extra={'conversation_link': conversation_link},
        )
        return fallback_message

    latest: AgentStateChangedObservation = observations[0]
    agent_state = latest.agent_state

    if agent_state == AgentState.RATE_LIMITED:
        rate_limit_context = {
            'agent_state': agent_state.value,
            'conversation_link': conversation_link,
            'observation_reason': getattr(latest, 'reason', None),
        }
        logger.warning('Agent was rate limited', extra=rate_limit_context)
        return 'OpenHands was rate limited by the LLM provider. Please try again later.'

    if agent_state == AgentState.ERROR:
        readable_reason = get_readable_error_reason(latest.reason)
        error_context = {
            'agent_state': agent_state.value,
            'conversation_link': conversation_link,
            'observation_reason': latest.reason,
            'readable_reason': readable_reason,
        }
        logger.error('Agent encountered an error', extra=error_context)
        return f'OpenHands encountered an error: **{readable_reason}**.\n\n[See the conversation]({conversation_link}) for more information.'

    if agent_state == AgentState.AWAITING_USER_INPUT:
        waiting_context = {
            'agent_state': agent_state.value,
            'conversation_link': conversation_link,
            'observation_reason': getattr(latest, 'reason', None),
        }
        logger.info('Agent is awaiting user input', extra=waiting_context)
        return f'OpenHands is waiting for your input. [Continue the conversation]({conversation_link}) to provide additional instructions.'

    # Any other state is unexpected here: log it as an error.
    if hasattr(agent_state, 'value'):
        state_label = agent_state.value
    else:
        state_label = str(agent_state)
    logger.error(
        'Unknown error: Unhandled agent state',
        extra={
            'agent_state': state_label,
            'conversation_link': conversation_link,
            'observation_reason': getattr(latest, 'reason', None),
        },
    )
    return fallback_message
|
||||
|
||||
|
||||
def get_final_agent_observation(
    event_store: EventStoreABC,
) -> list[AgentStateChangedObservation]:
    """Fetch at most one agent-state-change observation via a reverse search.

    Searches environment-sourced ``AgentStateChangedObservation`` events with
    ``reverse=True`` and ``limit=1``.

    Returns:
        A list holding the matching observation, or an empty list if none exists.
    """
    state_change_filter = EventFilter(
        source=EventSource.ENVIRONMENT,
        include_types=(AgentStateChangedObservation,),
    )
    matches = list(
        event_store.search_events(filter=state_change_filter, limit=1, reverse=True)
    )
    observations = [
        event for event in matches if isinstance(event, AgentStateChangedObservation)
    ]
    # The filter already restricts the type, so nothing should have been dropped.
    assert len(observations) == len(matches)
    return observations
|
||||
|
||||
|
||||
def get_last_user_msg(event_store: EventStoreABC) -> list[MessageAction]:
    """Fetch at most one user-sourced message via a reverse search.

    Searches user-sourced ``MessageAction`` events with ``reverse=True`` and
    ``limit=1``.

    Returns:
        A list holding the matching message, or an empty list if none exists.
    """
    user_message_filter = EventFilter(
        source=EventSource.USER,
        include_types=(MessageAction,),
    )
    matches = list(
        event_store.search_events(filter=user_message_filter, limit=1, reverse=True)
    )
    messages = [event for event in matches if isinstance(event, MessageAction)]
    # The filter already restricts the type, so nothing should have been dropped.
    assert len(messages) == len(matches)
    return messages
|
||||
|
||||
|
||||
def extract_summary_from_event_store(
    event_store: EventStoreABC, conversation_id: str
) -> str:
    """
    Get agent summary or alternative message depending on current AgentState

    Looks for the user message carrying the summary instruction, then for an
    agent message or finish action found by a reverse search starting at that
    instruction's event ID. If either is missing, falls back to a status
    message derived from the final agent state observation.

    Args:
        event_store: Event store of the conversation to inspect.
        conversation_id: ID used to build the conversation link and log context.

    Returns:
        The agent's summary text, or a fallback status message.
    """
    conversation_link = CONVERSATION_URL.format(conversation_id)
    summary_instruction = get_summary_instruction()

    def fallback_summary() -> str:
        # The final agent state is only needed when no summary is available,
        # so the extra event-store query is deferred until then.
        return get_summary_for_agent_state(
            get_final_agent_observation(event_store), conversation_link
        )

    instruction_events = list(
        event_store.search_events(
            filter=EventFilter(
                query=json.dumps(summary_instruction),
                source=EventSource.USER,
                include_types=(MessageAction,),
            ),
            limit=1,
            reverse=True,
        )
    )

    # Find summary instruction event ID
    if not instruction_events:
        logger.warning(
            'no_instruction_event_found', extra={'conversation_id': conversation_id}
        )
        return fallback_summary()  # Agent did not receive summary instruction

    summary_events = list(
        event_store.search_events(
            filter=EventFilter(
                source=EventSource.AGENT,
                include_types=(MessageAction, AgentFinishAction),
            ),
            limit=1,
            reverse=True,
            start_id=instruction_events[0].id,
        )
    )

    if not summary_events:
        logger.warning(
            'no_agent_messages_found', extra={'conversation_id': conversation_id}
        )
        return fallback_summary()  # Agent failed to generate summary

    summary_event = summary_events[0]
    if isinstance(summary_event, MessageAction):
        return summary_event.content

    assert isinstance(summary_event, AgentFinishAction)
    return summary_event.final_thought
|
||||
|
||||
|
||||
def append_conversation_footer(message: str, conversation_id: str) -> str:
    """
    Add a trailing "View full conversation" link to a message.

    Args:
        message: The original message content
        conversation_id: The conversation ID to link to

    Returns:
        The message followed by a blank line and the conversation link
    """
    url = CONVERSATION_URL.format(conversation_id)
    return message + f'\n\n[View full conversation]({url})'
|
||||
|
||||
|
||||
def infer_repo_from_message(user_msg: str) -> list[str]:
|
||||
"""
|
||||
Extract all repository names in the format 'owner/repo' from various Git provider URLs
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
"""Add llm_profiles column to user table.
|
||||
|
||||
The Settings model exposes ``llm_profiles`` (saved LLM configurations plus
|
||||
the active profile name), but the SaaS path persists a flattened Settings
|
||||
dump onto the User/Org rows. Without a column here the field is silently
|
||||
dropped on store() and always defaults to empty on load(), so saved
|
||||
profiles disappear after any settings update or page refresh.
|
||||
|
||||
The column is plain ``String`` because the ORM-level ``EncryptedJSON``
|
||||
TypeDecorator stores JSON-serialized profiles as a JWE-encrypted string —
|
||||
profiles can carry per-profile ``api_key`` values, so the at-rest
|
||||
representation must match the existing org/member encrypted-secret pattern.
|
||||
|
||||
Revision ID: 109
|
||||
Revises: 108
|
||||
Create Date: 2026-04-28
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '109'
|
||||
down_revision: Union[str, None] = '108'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the nullable string column ``llm_profiles`` to the ``user`` table."""
    llm_profiles_column = sa.Column('llm_profiles', sa.String(), nullable=True)
    op.add_column('user', llm_profiles_column)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the ``llm_profiles`` column from the ``user`` table."""
    op.drop_column(table_name='user', column_name='llm_profiles')
|
||||
@@ -0,0 +1,31 @@
|
||||
"""Add agent_kind column to conversation_metadata table.
|
||||
|
||||
Stores the agent type ('llm' or 'acp') for each conversation so the
|
||||
correct agent-server endpoint can be used when routing requests.
|
||||
|
||||
Revision ID: 110
|
||||
Revises: 109
|
||||
Create Date: 2026-04-28
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '110'
|
||||
down_revision: Union[str, None] = '109'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the nullable string column ``agent_kind`` to ``conversation_metadata``."""
    agent_kind_column = sa.Column('agent_kind', sa.String(), nullable=True)
    op.add_column('conversation_metadata', agent_kind_column)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the ``agent_kind`` column from ``conversation_metadata``."""
    table_name, column_name = 'conversation_metadata', 'agent_kind'
    op.drop_column(table_name, column_name)
|
||||
49
enterprise/poetry.lock
generated
49
enterprise/poetry.lock
generated
@@ -4961,14 +4961,14 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "lmnr"
|
||||
version = "0.7.46"
|
||||
version = "0.7.49"
|
||||
description = "Python SDK for Laminar"
|
||||
optional = false
|
||||
python-versions = "<4,>=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "lmnr-0.7.46-py3-none-any.whl", hash = "sha256:596599af3eb999c5fb253640967fa893d34998b78c577b8773c214d89efa81c9"},
|
||||
{file = "lmnr-0.7.46.tar.gz", hash = "sha256:082c9d17a1962b559651eea843eff49c1ec54729654ba37388c4a360e862af78"},
|
||||
{file = "lmnr-0.7.49-py3-none-any.whl", hash = "sha256:510113b02bac3e639fa80244c67ff0be5948234275b0ef04cd310d66c7d720bf"},
|
||||
{file = "lmnr-0.7.49.tar.gz", hash = "sha256:0b6da7d1707ce4e248c15083835a70723be9e6cc652b77ddc95c12e27dc87ef3"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -4984,11 +4984,11 @@ opentelemetry-sdk = ">=1.39.0,<2.0.0"
|
||||
opentelemetry-semantic-conventions = "0.60b1"
|
||||
opentelemetry-semantic-conventions-ai = "0.4.13"
|
||||
orjson = ">=3.0.0,<4.0.0"
|
||||
packaging = ">=22.0"
|
||||
packaging = ">=22.0,<27.0"
|
||||
pydantic = ">=2.0.3,<3.0.0"
|
||||
python-dotenv = ">=1.0,<2.0"
|
||||
tenacity = ">=8.0,<10.0"
|
||||
tqdm = ">=4.0"
|
||||
tqdm = ">=4.0,<5.0"
|
||||
|
||||
[package.extras]
|
||||
alephalpha = ["opentelemetry-instrumentation-alephalpha (==0.52.4)"]
|
||||
@@ -6454,14 +6454,14 @@ llama = ["llama-index (>=0.12.29,<0.13.0)", "llama-index-core (>=0.12.29,<0.13.0
|
||||
|
||||
[[package]]
|
||||
name = "openhands-agent-server"
|
||||
version = "1.17.0"
|
||||
version = "1.19.0"
|
||||
description = "OpenHands Agent Server - REST/WebSocket interface for OpenHands AI Agent"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_agent_server-1.17.0-py3-none-any.whl", hash = "sha256:44336cad001c31caeb516481a5a7aea6dd9b5ab4798461f147b5231668d8fb74"},
|
||||
{file = "openhands_agent_server-1.17.0.tar.gz", hash = "sha256:3a88449a3b9ded653dcd2a8c518810c75602873cf9f7d4e8f9b90fd8fd225652"},
|
||||
{file = "openhands_agent_server-1.19.0-py3-none-any.whl", hash = "sha256:132902dc918f446e3b0f5cda9f4da36a4881fc73fe509eb177959afe988c38bb"},
|
||||
{file = "openhands_agent_server-1.19.0.tar.gz", hash = "sha256:4f81b5ec550881706b361c51a422b6daad2a33c73b94d2f3088c84ed32ce049e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -6523,9 +6523,9 @@ memory-profiler = ">=0.61"
|
||||
numpy = "*"
|
||||
openai = "2.8"
|
||||
openhands-aci = "0.3.3"
|
||||
openhands-agent-server = "1.17"
|
||||
openhands-sdk = "1.17"
|
||||
openhands-tools = "1.17"
|
||||
openhands-agent-server = "1.19"
|
||||
openhands-sdk = "1.19"
|
||||
openhands-tools = "1.19"
|
||||
opentelemetry-api = ">=1.33.1"
|
||||
opentelemetry-exporter-otlp-proto-grpc = ">=1.33.1"
|
||||
orjson = ">=3.11.6"
|
||||
@@ -6547,7 +6547,7 @@ python-docx = "*"
|
||||
python-dotenv = "*"
|
||||
python-frontmatter = ">=1.1"
|
||||
python-json-logger = ">=3.2.1"
|
||||
python-multipart = ">=0.0.22"
|
||||
python-multipart = ">=0.0.26"
|
||||
python-pptx = "*"
|
||||
python-socketio = "5.14"
|
||||
pythonnet = {version = "*", markers = "sys_platform == \"win32\""}
|
||||
@@ -6571,23 +6571,20 @@ uvicorn = "*"
|
||||
whatthepatch = ">=1.0.6"
|
||||
zope-interface = "7.2"
|
||||
|
||||
[package.extras]
|
||||
third-party-runtimes = ["daytona (==0.24.2)", "e2b-code-interpreter (>=2)", "modal (>=0.66.26,<1.2)", "runloop-api-client (==0.50)"]
|
||||
|
||||
[package.source]
|
||||
type = "directory"
|
||||
url = ".."
|
||||
|
||||
[[package]]
|
||||
name = "openhands-sdk"
|
||||
version = "1.17.0"
|
||||
version = "1.19.0"
|
||||
description = "OpenHands SDK - Core functionality for building AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_sdk-1.17.0-py3-none-any.whl", hash = "sha256:3b771e72209453871c3036a562cf33e9ad9642a54bd48edb44f89915ac54709d"},
|
||||
{file = "openhands_sdk-1.17.0.tar.gz", hash = "sha256:3c69df6590f023a514137272d413658848e0d5bc9aecf941b946c8662862779a"},
|
||||
{file = "openhands_sdk-1.19.0-py3-none-any.whl", hash = "sha256:704906533da50f2d0e93bf28609b1a36a4aa4ce578bfac13a3d1a76609d87db8"},
|
||||
{file = "openhands_sdk-1.19.0.tar.gz", hash = "sha256:5611d877e6495a712725569f6bca3de8fabefd9e44c61dc30bd39f8883371508"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -6598,7 +6595,7 @@ fastmcp = ">=3.0.0"
|
||||
filelock = ">=3.20.1"
|
||||
httpx = {version = ">=0.27.0", extras = ["socks"]}
|
||||
litellm = ">=1.82.6,<1.82.7 || >1.82.7,<1.82.8 || >1.82.8"
|
||||
lmnr = ">=0.7.24"
|
||||
lmnr = ">=0.7.47"
|
||||
pydantic = ">=2.12.5"
|
||||
python-frontmatter = ">=1.1.0"
|
||||
python-json-logger = ">=3.3.0"
|
||||
@@ -6610,14 +6607,14 @@ boto3 = ["boto3 (>=1.35.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "openhands-tools"
|
||||
version = "1.17.0"
|
||||
version = "1.19.0"
|
||||
description = "OpenHands Tools - Runtime tools for AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_tools-1.17.0-py3-none-any.whl", hash = "sha256:76cd30fcc153627444f18638bcd926c9190989f80a3492381e84a181c021d815"},
|
||||
{file = "openhands_tools-1.17.0.tar.gz", hash = "sha256:4a9d6c1aec00d366d0feb1ac2e9ee9988ad9806a0ef89f7dbe4655644e639d4a"},
|
||||
{file = "openhands_tools-1.19.0-py3-none-any.whl", hash = "sha256:ff5ddb40d628a468eda4488b2c0045470c88e396bab43330b6b468f3ada47b9e"},
|
||||
{file = "openhands_tools-1.19.0.tar.gz", hash = "sha256:b4dc59a813fe1fe7bda519979498a7bdf07dd8f83ea3f0aad78c154f5fcb9a32"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -14147,6 +14144,14 @@ optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "tree_sitter_c_sharp-0.23.1-cp310-abi3-macosx_10_9_x86_64.whl", hash = "sha256:e87be7572991552606a3155d2f6c2045ded8bce94bfd9f74bf521d949c219a1c"},
|
||||
{file = "tree_sitter_c_sharp-0.23.1-cp310-abi3-macosx_11_0_arm64.whl", hash = "sha256:86c2fdf178c66474a1be2965602818d30780e4e3ed890e3c206931f65d9a154c"},
|
||||
{file = "tree_sitter_c_sharp-0.23.1-cp310-abi3-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:035d259e64c41d02cc45afc3b8b46388b232e7d16d84734d851cca7334761da5"},
|
||||
{file = "tree_sitter_c_sharp-0.23.1-cp310-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fa472cb9de7e14fee9408e144f29f68384cd8e9c677dff0002da19f361a59bdf"},
|
||||
{file = "tree_sitter_c_sharp-0.23.1-cp310-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:1a0ea86eccff74e85ab4a2cf77c813fad7c84162962ce242dff0c51601028832"},
|
||||
{file = "tree_sitter_c_sharp-0.23.1-cp310-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:8ab26dc998bbd4b4287b129f67c10ca715deb402ed77d0645674490ea509097e"},
|
||||
{file = "tree_sitter_c_sharp-0.23.1-cp310-abi3-win_amd64.whl", hash = "sha256:d4486653feaff3314ef45534dcb6f9ea8ab3aa160896287c6473788f88eb38be"},
|
||||
{file = "tree_sitter_c_sharp-0.23.1-cp310-abi3-win_arm64.whl", hash = "sha256:e7a14b76ec23cc8386cf662d5ea602d81331376c93ca6299a97b174047790345"},
|
||||
{file = "tree_sitter_c_sharp-0.23.1-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:2b612a6e5bd17bb7fa2aab4bb6fc1fba45c94f09cb034ab332e45603b86e32fd"},
|
||||
{file = "tree_sitter_c_sharp-0.23.1-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:1a8b98f62bc53efcd4d971151950c9b9cd5cbe3bacdb0cd69fdccac63350d83e"},
|
||||
{file = "tree_sitter_c_sharp-0.23.1-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:986e93d845a438ec3c4416401aa98e6a6f6631d644bbbc2e43fcb915c51d255d"},
|
||||
|
||||
@@ -28,7 +28,6 @@ from server.routes.api_keys import api_router as api_keys_router # noqa: E402
|
||||
from server.routes.auth import api_router, oauth_router # noqa: E402
|
||||
from server.routes.billing import billing_router # noqa: E402
|
||||
from server.routes.email import api_router as email_router # noqa: E402
|
||||
from server.routes.feedback import router as feedback_router # noqa: E402
|
||||
from server.routes.github_proxy import add_github_proxy_routes # noqa: E402
|
||||
from server.routes.integration.jira import jira_integration_router # noqa: E402
|
||||
from server.routes.integration.jira_dc import jira_dc_integration_router # noqa: E402
|
||||
@@ -147,7 +146,6 @@ if BITBUCKET_DATA_CENTER_HOST:
|
||||
|
||||
base_app.include_router(bitbucket_dc_proxy_router)
|
||||
base_app.include_router(email_router) # Add routes for email management
|
||||
base_app.include_router(feedback_router) # Add routes for conversation feedback
|
||||
|
||||
|
||||
base_app.add_middleware(
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import os
|
||||
|
||||
from openhands.integrations.gitlab.constants import GITLAB_HOST
|
||||
|
||||
GITHUB_APP_CLIENT_ID = os.getenv('GITHUB_APP_CLIENT_ID', '').strip()
|
||||
GITHUB_APP_CLIENT_SECRET = os.getenv('GITHUB_APP_CLIENT_SECRET', '').strip()
|
||||
GITHUB_APP_WEBHOOK_SECRET = os.getenv('GITHUB_APP_WEBHOOK_SECRET', '')
|
||||
@@ -14,6 +16,7 @@ KEYCLOAK_SERVER_URL_EXT = os.getenv(
|
||||
KEYCLOAK_ADMIN_PASSWORD = os.getenv('KEYCLOAK_ADMIN_PASSWORD', '')
|
||||
GITLAB_APP_CLIENT_ID = os.getenv('GITLAB_APP_CLIENT_ID', '').strip()
|
||||
GITLAB_APP_CLIENT_SECRET = os.getenv('GITLAB_APP_CLIENT_SECRET', '').strip()
|
||||
GITLAB_TOKEN_URL = f'https://{GITLAB_HOST}/oauth/token'
|
||||
BITBUCKET_APP_CLIENT_ID = os.getenv('BITBUCKET_APP_CLIENT_ID', '').strip()
|
||||
BITBUCKET_APP_CLIENT_SECRET = os.getenv('BITBUCKET_APP_CLIENT_SECRET', '').strip()
|
||||
ENABLE_ENTERPRISE_SSO = os.getenv('ENABLE_ENTERPRISE_SSO', '').strip()
|
||||
|
||||
@@ -35,15 +35,15 @@ from storage.user_authorization_store import UserAuthorizationStore
|
||||
from storage.user_store import UserStore
|
||||
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
|
||||
|
||||
from openhands.app_server.secrets.secrets_models import Secrets
|
||||
from openhands.app_server.settings.settings_models import Settings
|
||||
from openhands.app_server.settings.settings_store import SettingsStore
|
||||
from openhands.integrations.provider import (
|
||||
PROVIDER_TOKEN_TYPE,
|
||||
ProviderToken,
|
||||
ProviderType,
|
||||
)
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.server.user_auth.user_auth import AuthType, UserAuth
|
||||
from openhands.storage.data_models.secrets import Secrets
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
|
||||
token_manager = TokenManager()
|
||||
|
||||
|
||||
@@ -30,6 +30,7 @@ from server.auth.constants import (
|
||||
GITHUB_APP_CLIENT_SECRET,
|
||||
GITLAB_APP_CLIENT_ID,
|
||||
GITLAB_APP_CLIENT_SECRET,
|
||||
GITLAB_TOKEN_URL,
|
||||
KEYCLOAK_REALM_NAME,
|
||||
KEYCLOAK_SERVER_URL,
|
||||
KEYCLOAK_SERVER_URL_EXT,
|
||||
@@ -417,7 +418,7 @@ class TokenManager:
|
||||
return await self._parse_refresh_response(data)
|
||||
|
||||
async def _refresh_gitlab_token(self, refresh_token: str) -> dict[str, str | int]:
|
||||
url = 'https://gitlab.com/oauth/token'
|
||||
url = GITLAB_TOKEN_URL
|
||||
logger.info(f'Refreshing GitLab token with URL: {url}')
|
||||
|
||||
payload = {
|
||||
|
||||
@@ -72,12 +72,6 @@ class SaaSServerConfig(ServerConfig):
|
||||
auth_url: str | None = os.environ.get('AUTH_URL')
|
||||
settings_store_class: str = 'storage.saas_settings_store.SaasSettingsStore'
|
||||
secret_store_class: str = 'storage.saas_secrets_store.SaasSecretsStore'
|
||||
conversation_store_class: str = (
|
||||
'storage.saas_conversation_store.SaasConversationStore'
|
||||
)
|
||||
monitoring_listener_class: str = (
|
||||
'server.saas_monitoring_listener.SaaSMonitoringListener'
|
||||
)
|
||||
user_auth_class: str = 'server.auth.saas_user_auth.SaasUserAuth'
|
||||
# Maintenance window configuration
|
||||
maintenance_start_time: str = os.environ.get(
|
||||
|
||||
@@ -16,8 +16,8 @@ from server.routes.auth import set_response_cookie
|
||||
from server.utils.url_utils import get_cookie_domain, get_cookie_samesite
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.shared import config
|
||||
from openhands.server.user_auth.user_auth import AuthType, UserAuth, get_user_auth
|
||||
from openhands.server.utils import config
|
||||
|
||||
|
||||
class SetAuthCookieMiddleware:
|
||||
|
||||
@@ -703,6 +703,41 @@ async def accept_tos(request: Request):
|
||||
return response
|
||||
|
||||
|
||||
@api_router.get('/onboarding_status')
async def onboarding_status(request: Request):
    """Report whether the current user still has to complete onboarding.

    This lives on its own endpoint rather than on ``GET /api/v1/settings``
    (the natural home for fields like ``email_verified``) because the settings
    response is heavyweight: ``SaasSettingsStore.load`` joins User, Org, and
    OrgMember rows and deep-merges the org-level and member-level
    ``agent_settings`` before returning. Onboarding gating runs on every
    protected-route navigation, so a lightweight read of a single boolean is
    preferable to paying for the full settings aggregation.

    Args:
        request: The incoming request, used to resolve the user's auth context.

    Returns:
        A ``JSONResponse``:
        - 401 with an ``error`` message when no user id can be resolved,
        - 404 with an ``error`` message when no user record exists for the id,
        - 200 with ``{'should_complete_onboarding': <bool>}`` otherwise.
    """
    auth = cast(SaasUserAuth, await get_user_auth(request))
    current_user_id = await auth.get_user_id()
    if not current_user_id:
        return JSONResponse(
            status_code=status.HTTP_401_UNAUTHORIZED,
            content={'error': 'User is not authenticated'},
        )

    user_record = await UserStore.get_user_by_id(current_user_id)
    if not user_record:
        return JSONResponse(
            status_code=status.HTTP_404_NOT_FOUND,
            content={'error': 'User not found'},
        )

    needs_onboarding = await _should_redirect_to_onboarding(
        current_user_id, user_record
    )
    payload = {'should_complete_onboarding': needs_onboarding}
    return JSONResponse(status_code=status.HTTP_200_OK, content=payload)
|
||||
|
||||
|
||||
@api_router.post('/complete_onboarding')
|
||||
async def complete_onboarding(request: Request):
|
||||
"""Mark onboarding as completed for the current user."""
|
||||
|
||||
@@ -1,145 +0,0 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.future import select
|
||||
from storage.database import a_session_maker
|
||||
from storage.feedback import ConversationFeedback
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
from openhands.app_server.utils.dependencies import get_dependencies
|
||||
from openhands.events.event_store import EventStore
|
||||
from openhands.server.shared import file_store
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
# We use the get_dependencies method here to signal to the OpenAPI docs that this endpoint
|
||||
# is protected. The actual protection is provided by SetAuthCookieMiddleware
|
||||
# TODO: It may be an error, but you can actually post feedback to a conversation you don't
|
||||
# own right now - maybe this is useful in the context of public shared conversations?
|
||||
router = APIRouter(
|
||||
prefix='/feedback', tags=['feedback'], dependencies=get_dependencies()
|
||||
)
|
||||
|
||||
|
||||
async def get_event_ids(conversation_id: str, user_id: str) -> List[int]:
    """Return the ids of all events stored for one of the user's conversations.

    Args:
        conversation_id: The conversation whose events are listed.
        user_id: The user expected to own the conversation.

    Returns:
        The ``id`` of every event yielded by the conversation's event store,
        searching from event id 0.

    Raises:
        HTTPException: 404 when no conversation metadata row matches both
            ``conversation_id`` and ``user_id``.
    """
    # Ownership check: the metadata row must match the conversation AND the user.
    ownership_query = select(StoredConversationMetadataSaas).where(
        StoredConversationMetadataSaas.conversation_id == conversation_id,
        StoredConversationMetadataSaas.user_id == user_id,
    )
    async with a_session_maker() as session:
        rows = await session.execute(ownership_query)
        conversation_row = rows.scalars().first()
        if not conversation_row:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f'Conversation {conversation_id} not found',
            )

    # Read events straight from the event store; this works even when the
    # conversation is not currently running.
    store = EventStore(
        sid=conversation_id,
        file_store=file_store,
        user_id=user_id,
    )
    return [stored_event.id for stored_event in store.search_events(start_id=0)]
|
||||
|
||||
|
||||
class FeedbackRequest(BaseModel):
    """Request body for submitting feedback on a conversation.

    The ``rating`` bounds declared on the field are the single place where the
    1-5 range is enforced: request-body validation rejects out-of-range values
    before any handler code runs.
    """

    # Conversation the feedback refers to.
    conversation_id: str
    # Optional event within the conversation that the feedback targets.
    event_id: Optional[int] = None
    # Required rating, constrained to the inclusive range 1..5.
    rating: int = Field(..., ge=1, le=5)
    # Optional free-text explanation for the rating.
    reason: Optional[str] = None
    # Optional free-form extra data stored alongside the feedback.
    metadata: Optional[Dict[str, Any]] = None


@router.post('/conversation', status_code=status.HTTP_201_CREATED)
async def submit_conversation_feedback(feedback: FeedbackRequest):
    """Persist a feedback record for a conversation.

    The feedback carries a rating (1-5) and an optional reason, and is
    associated with a conversation and optionally with one of its events.

    Args:
        feedback: The validated request body. ``rating`` is already guaranteed
            to be within 1..5 by ``FeedbackRequest``, so it is not re-checked
            here; the previous manual range check could never fire for a
            validated body.

    Returns:
        A dict with ``status`` and ``message`` keys confirming the submission.
    """
    new_feedback = ConversationFeedback(
        conversation_id=feedback.conversation_id,
        event_id=feedback.event_id,
        rating=feedback.rating,
        reason=feedback.reason,
        metadata=feedback.metadata,
    )

    # Store the record in its own session/transaction.
    async with a_session_maker() as session:
        session.add(new_feedback)
        await session.commit()

    return {'status': 'success', 'message': 'Feedback submitted successfully'}
|
||||
|
||||
|
||||
@router.get('/conversation/{conversation_id}/batch')
async def get_batch_feedback(conversation_id: str, user_id: str = Depends(get_user_id)):
    """Return the feedback status of every event in a conversation.

    Args:
        conversation_id: The conversation to report on.
        user_id: The requesting user, resolved through ``get_user_id``.

    Returns:
        A dict keyed by the stringified event id. Events with feedback map to
        ``{'exists': True, 'rating': ..., 'reason': ...}``; all others map to
        ``{'exists': False}``. An empty dict is returned when the conversation
        has no events.
    """
    event_ids = await get_event_ids(conversation_id, user_id)
    if not event_ids:
        return {}

    # Fetch the feedback rows for all of the conversation's events in one query.
    feedback_query = select(ConversationFeedback).where(
        ConversationFeedback.conversation_id == conversation_id,
        ConversationFeedback.event_id.in_(event_ids),
    )
    async with a_session_maker() as session:
        rows = await session.execute(feedback_query)
        recorded = {
            row.event_id: {
                'exists': True,
                'rating': row.rating,
                'reason': row.reason,
            }
            for row in rows.scalars()
        }

    # Every event appears in the response, with or without feedback.
    return {
        str(event_id): recorded.get(event_id, {'exists': False})
        for event_id in event_ids
    }
|
||||
@@ -2,21 +2,25 @@ import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
from typing import cast
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Header, HTTPException, Request
|
||||
from fastapi import APIRouter, BackgroundTasks, Header, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from integrations.github.data_collector import GitHubDataCollector
|
||||
from integrations.github.github_manager import GithubManager
|
||||
from integrations.models import Message, SourceType
|
||||
from pydantic import BaseModel
|
||||
from server.auth.constants import (
|
||||
AUTOMATION_EVENT_FORWARDING_ENABLED,
|
||||
GITHUB_APP_WEBHOOK_SECRET,
|
||||
)
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.services.automation_event_service import AutomationEventService
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.server.user_auth.user_auth import get_user_auth
|
||||
|
||||
# Environment variable to disable GitHub webhooks
|
||||
GITHUB_WEBHOOKS_ENABLED = os.environ.get('GITHUB_WEBHOOKS_ENABLED', '1') in (
|
||||
@@ -105,3 +109,42 @@ async def github_events(
|
||||
except Exception as e:
|
||||
logger.exception(f'Error processing GitHub event: {e}')
|
||||
return JSONResponse(status_code=400, content={'error': 'Invalid payload.'})
|
||||
|
||||
|
||||
class GitHubTokenResponse(BaseModel):
    """Body returned by the GitHub token endpoint."""

    # The user's GitHub access token as a plain string.
    access_token: str


@github_integration_router.get('/github/token')
async def get_github_token(request: Request) -> GitHubTokenResponse:
    """Return the GitHub access token of the user making the request.

    The token is read from the provider tokens exposed by the user's auth
    object. NOTE(review): any token refresh, and any 401 for unauthenticated
    callers, would have to come from ``get_user_auth`` /
    ``get_provider_tokens``; neither is visible in this function.

    Args:
        request: The incoming request, used to resolve the user's auth context.

    Returns:
        A ``GitHubTokenResponse`` holding the unwrapped token value.

    Raises:
        HTTPException: 404 when the user has no provider tokens at all, or
            when there is no GitHub entry or the entry carries no token.
    """
    auth = cast(SaasUserAuth, await get_user_auth(request))
    tokens = await auth.get_provider_tokens()
    if not tokens:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail='No provider tokens available for this user.',
        )

    github_entry = tokens.get(ProviderType.GITHUB)
    if not (github_entry and github_entry.token):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail='No GitHub token available for this user.',
        )

    # The token is a secret wrapper; unwrap it only when building the response.
    raw_token = github_entry.token.get_secret_value()
    return GitHubTokenResponse(access_token=raw_token)
|
||||
|
||||
@@ -2,15 +2,7 @@ import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
BackgroundTasks,
|
||||
Depends,
|
||||
Header,
|
||||
HTTPException,
|
||||
Request,
|
||||
status,
|
||||
)
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from integrations.gitlab.gitlab_manager import GitlabManager
|
||||
from integrations.gitlab.gitlab_service import SaaSGitLabService
|
||||
@@ -23,15 +15,12 @@ from integrations.models import Message, SourceType
|
||||
from integrations.types import GitLabResourceType
|
||||
from integrations.utils import GITLAB_WEBHOOK_URL, IS_LOCAL_DEPLOYMENT
|
||||
from pydantic import BaseModel
|
||||
from server.auth.constants import AUTOMATION_EVENT_FORWARDING_ENABLED
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.services.automation_event_service import AutomationEventService
|
||||
from storage.gitlab_webhook import GitlabWebhook
|
||||
from storage.gitlab_webhook_store import GitlabWebhookStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.server.shared import sio
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
@@ -40,7 +29,6 @@ webhook_store = GitlabWebhookStore()
|
||||
|
||||
token_manager = TokenManager()
|
||||
gitlab_manager = GitlabManager(token_manager)
|
||||
automation_event_service = AutomationEventService(token_manager)
|
||||
|
||||
|
||||
# Request/Response models
|
||||
@@ -94,7 +82,6 @@ async def verify_gitlab_signature(
|
||||
@gitlab_integration_router.post('/gitlab/events')
|
||||
async def gitlab_events(
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
x_gitlab_token: str = Header(None),
|
||||
x_openhands_webhook_id: str = Header(None),
|
||||
x_openhands_user_id: str = Header(None),
|
||||
@@ -125,16 +112,6 @@ async def gitlab_events(
|
||||
content={'message': 'Duplicate GitLab event ignored.'},
|
||||
)
|
||||
|
||||
# Forward to automation service (fire-and-forget background task)
|
||||
if AUTOMATION_EVENT_FORWARDING_ENABLED:
|
||||
background_tasks.add_task(
|
||||
automation_event_service.forward_event,
|
||||
provider=ProviderType.GITLAB,
|
||||
payload=payload_data,
|
||||
installation_id=x_openhands_webhook_id,
|
||||
)
|
||||
|
||||
# Existing resolver bot processing
|
||||
message = Message(
|
||||
source=SourceType.GITLAB,
|
||||
message={
|
||||
|
||||
@@ -29,6 +29,7 @@ from server.constants import (
|
||||
SLACK_WEBHOOKS_ENABLED,
|
||||
)
|
||||
from server.logger import logger
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.oauth import AuthorizeUrlGenerator
|
||||
from slack_sdk.signature import SignatureVerifier
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
@@ -46,7 +47,16 @@ slack_router = APIRouter(prefix='/slack')
|
||||
|
||||
# Build https://slack.com/oauth/v2/authorize with sufficient query parameters
|
||||
authorize_url_generator = AuthorizeUrlGenerator(
|
||||
client_id=SLACK_CLIENT_ID, scopes=['app_mentions:read', 'chat:write']
|
||||
client_id=SLACK_CLIENT_ID,
|
||||
scopes=[
|
||||
'app_mentions:read',
|
||||
'chat:write',
|
||||
'users:read',
|
||||
'channels:history',
|
||||
'groups:history',
|
||||
'mpim:history',
|
||||
'im:history',
|
||||
],
|
||||
)
|
||||
token_manager = TokenManager()
|
||||
|
||||
@@ -232,7 +242,24 @@ async def keycloak_callback(
|
||||
|
||||
# Retrieve the display_name from slack
|
||||
client = AsyncWebClient(token=bot_access_token)
|
||||
slack_user_info = await client.users_info(user=slack_user_id)
|
||||
try:
|
||||
slack_user_info = await client.users_info(user=slack_user_id)
|
||||
except SlackApiError as e:
|
||||
if e.response.get('error') == 'missing_scope':
|
||||
logger.warning(
|
||||
'slack_missing_scope_during_install',
|
||||
extra={'slack_user_id': slack_user_id, 'team_id': team_id},
|
||||
)
|
||||
return _html_response(
|
||||
title='Re-installation Required',
|
||||
description=(
|
||||
'The Slack app is missing required permissions. '
|
||||
f'Please <a href="{HOST_URL}/slack/install" style="color:#ecedee;text-decoration:underline;">re-install the OpenHands Slack App</a> '
|
||||
'to authorize the updated permissions.'
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
raise
|
||||
slack_display_name = slack_user_info.data['user']['profile']['display_name']
|
||||
slack_user = SlackUser(
|
||||
keycloak_user_id=keycloak_user_id,
|
||||
@@ -366,7 +393,7 @@ async def on_options_load(request: Request, background_tasks: BackgroundTasks):
|
||||
# Verify this is a block_suggestion payload
|
||||
if payload.get('type') != 'block_suggestion':
|
||||
logger.warning(
|
||||
f"slack_on_options_load: Unexpected payload type: {payload.get('type')}"
|
||||
f'slack_on_options_load: Unexpected payload type: {payload.get("type")}'
|
||||
)
|
||||
return JSONResponse({'options': []})
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ import os
|
||||
from fastmcp import Client, FastMCP
|
||||
from fastmcp.client.transports import NpxStdioTransport
|
||||
|
||||
from openhands.app_server.mcp.mcp_router import mcp_server
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.routes.mcp import mcp_server
|
||||
|
||||
ENABLE_MCP_SEARCH_ENGINE = (
|
||||
os.getenv('ENABLE_MCP_SEARCH_ENGINE', 'false').lower() == 'true'
|
||||
|
||||
@@ -162,7 +162,6 @@ class OrgResponse(BaseModel):
|
||||
search_api_key: str | None = None
|
||||
sandbox_api_key: str | None = None
|
||||
max_budget_per_task: float | None = None
|
||||
enable_solvability_analysis: bool | None = None
|
||||
v1_enabled: bool | None = None
|
||||
credits: float | None = None
|
||||
is_personal: bool = False
|
||||
@@ -195,7 +194,6 @@ class OrgResponse(BaseModel):
|
||||
search_api_key=None,
|
||||
sandbox_api_key=None,
|
||||
max_budget_per_task=org.max_budget_per_task,
|
||||
enable_solvability_analysis=org.enable_solvability_analysis,
|
||||
v1_enabled=org.v1_enabled,
|
||||
credits=credits,
|
||||
is_personal=str(org.id) == user_id if user_id else False,
|
||||
@@ -232,7 +230,6 @@ class OrgUpdate(BaseModel):
|
||||
sandbox_runtime_container_image: str | None = None
|
||||
sandbox_api_key: str | None = None
|
||||
max_budget_per_task: float | None = Field(default=None, gt=0)
|
||||
enable_solvability_analysis: bool | None = None
|
||||
v1_enabled: bool | None = None
|
||||
search_api_key: str | None = None
|
||||
llm_api_key: str | None = None
|
||||
@@ -553,7 +550,6 @@ class OrgAppSettingsResponse(BaseModel):
|
||||
"""Response model for organization app settings."""
|
||||
|
||||
enable_proactive_conversation_starters: bool = True
|
||||
enable_solvability_analysis: bool | None = None
|
||||
max_budget_per_task: float | None = None
|
||||
|
||||
@classmethod
|
||||
@@ -570,7 +566,6 @@ class OrgAppSettingsResponse(BaseModel):
|
||||
enable_proactive_conversation_starters=org.enable_proactive_conversation_starters
|
||||
if org.enable_proactive_conversation_starters is not None
|
||||
else True,
|
||||
enable_solvability_analysis=org.enable_solvability_analysis,
|
||||
max_budget_per_task=org.max_budget_per_task,
|
||||
)
|
||||
|
||||
@@ -579,7 +574,6 @@ class OrgAppSettingsUpdate(BaseModel):
|
||||
"""Request model for updating organization app settings."""
|
||||
|
||||
enable_proactive_conversation_starters: bool | None = None
|
||||
enable_solvability_analysis: bool | None = None
|
||||
max_budget_per_task: float | None = None
|
||||
|
||||
@field_validator('max_budget_per_task')
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
from server.logger import logger
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import (
|
||||
AgentStateChangedObservation,
|
||||
)
|
||||
from openhands.server.monitoring import MonitoringListener
|
||||
|
||||
|
||||
class SaaSMonitoringListener(MonitoringListener):
    """Emit app signals as structured log records for GCP native monitoring."""

    def on_session_event(self, event: Event) -> None:
        """Log a signal when an event reports the agent entering the ERROR state.

        Any other event is ignored.
        """
        is_error_change = (
            isinstance(event, AgentStateChangedObservation)
            and event.agent_state == AgentState.ERROR
        )
        if not is_error_change:
            return
        signal = {'signal': 'saas_agent_status_errors'}
        logger.info('Tracking agent status error', extra=signal)

    def on_agent_session_start(self, success: bool, duration: float) -> None:
        """Log a signal for an agent session start.

        Args:
            success: True if startup completed without error.
            duration: Startup time in seconds as observed by AgentSession.
        """
        signal = {
            'signal': 'saas_agent_session_start',
            'success': success,
            'duration': duration,
        }
        logger.info('Tracking agent session start', extra=signal)

    def on_create_conversation(self) -> None:
        """Log a signal at the beginning of conversation creation.

        Does not currently capture whether the creation succeeds.
        """
        signal = {'signal': 'saas_create_conversation'}
        logger.info('Tracking create conversation', extra=signal)

    @classmethod
    def get_instance(
        cls,
        config: OpenHandsConfig,
    ) -> 'SaaSMonitoringListener':
        """Build a listener; ``config`` is accepted but not used."""
        return cls()
|
||||
@@ -163,8 +163,7 @@ class AutomationEventService:
|
||||
org_id = await self._resolve_git_org(provider, git_org_name)
|
||||
|
||||
# Fallback for personal repos (owner_type indicates individual user)
|
||||
# GitHub uses 'User', GitLab uses 'user'
|
||||
if not org_id and owner_type and owner_type.lower() == 'user':
|
||||
if not org_id and owner_type == 'User':
|
||||
org_id = await self._resolve_personal_org(provider, owner_id)
|
||||
if org_id:
|
||||
logger.info(
|
||||
@@ -207,18 +206,6 @@ class AutomationEventService:
|
||||
owner = repo.get('owner', {})
|
||||
return owner.get('login'), owner.get('type'), owner.get('id')
|
||||
|
||||
if provider == ProviderType.GITLAB:
|
||||
# GitLab uses 'project' instead of 'repository'
|
||||
# path_with_namespace is like "org-name/repo-name" or "user-name/repo-name"
|
||||
project = payload.get('project', {})
|
||||
path_with_namespace = project.get('path_with_namespace', '')
|
||||
git_org = path_with_namespace.split('/')[0] if path_with_namespace else None
|
||||
namespace = project.get('namespace', {})
|
||||
# GitLab uses 'group' for organizations and 'user' for personal projects
|
||||
owner_type = namespace.get('kind')
|
||||
owner_id = namespace.get('id')
|
||||
return git_org, owner_type, owner_id
|
||||
|
||||
logger.warning(f'Unsupported provider ({provider.value})')
|
||||
return None, None, None
|
||||
|
||||
|
||||
@@ -313,11 +313,22 @@ class OrgInvitationService:
|
||||
raise InvitationInvalidError('User not found')
|
||||
|
||||
user_email = user.email
|
||||
# Fallback: fetch email from Keycloak if not in database (for existing users)
|
||||
# Fallback: fetch email from Keycloak if not in database (for existing users).
|
||||
# When found, persist it back to User.email so the members list shows it
|
||||
# without requiring the user to log out and log back in.
|
||||
if not user_email:
|
||||
token_manager = TokenManager()
|
||||
user_info = await token_manager.get_user_info_from_user_id(str(user_id))
|
||||
user_email = user_info.get('email') if user_info else None
|
||||
if user_info:
|
||||
user_email = user_info.get('email')
|
||||
if user_email:
|
||||
await UserStore.backfill_user_email(
|
||||
str(user_id),
|
||||
{
|
||||
'email': user_email,
|
||||
'email_verified': user_info.get('emailVerified', False),
|
||||
},
|
||||
)
|
||||
|
||||
if not user_email:
|
||||
raise EmailMismatchError('Your account does not have an email address')
|
||||
|
||||
@@ -3,7 +3,7 @@ from datetime import datetime
|
||||
# Simplified imports to avoid dependency chain issues
|
||||
# from openhands.integrations.service_types import ProviderType
|
||||
# from openhands.sdk.llm import MetricsSnapshot
|
||||
# from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
# from openhands.app_server.app_conversation.app_conversation_models import ConversationTrigger
|
||||
# For now, use Any to avoid import issues
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -1,295 +0,0 @@
|
||||
import base64
|
||||
import json
|
||||
import pickle
|
||||
from datetime import datetime
|
||||
|
||||
from server.logger import logger
|
||||
from sqlalchemy import and_, select
|
||||
from storage.conversation_callback import (
|
||||
CallbackStatus,
|
||||
ConversationCallback,
|
||||
ConversationCallbackProcessor,
|
||||
)
|
||||
from storage.conversation_work import ConversationWork
|
||||
from storage.database import a_session_maker, session_maker
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
|
||||
from openhands.core.config import load_openhands_config
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.event_store import EventStore
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
from openhands.events.serialization.event import event_from_dict
|
||||
from openhands.server.services.conversation_stats import ConversationStats
|
||||
from openhands.storage import get_file_store
|
||||
from openhands.storage.files import FileStore
|
||||
from openhands.storage.locations import (
|
||||
get_conversation_agent_state_filename,
|
||||
get_conversation_dir,
|
||||
)
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
config = load_openhands_config()
|
||||
file_store = get_file_store(config.file_store, config.file_store_path)
|
||||
|
||||
|
||||
async def process_event(
    user_id: str, conversation_id: str, subpath: str, content: dict
):
    """
    Process a conversation event and invoke any registered callbacks.

    The event payload is written to the file store as JSON. If the payload
    deserializes to an AgentStateChangedObservation, the active callbacks for
    the conversation are invoked and, unless the new agent state is RUNNING,
    the conversation's active working seconds are recomputed.

    Args:
        user_id: The user ID associated with the conversation
        conversation_id: The conversation ID
        subpath: The event subpath, appended as-is to the conversation directory
        content: The event content (a serialized event dictionary)
    """
    logger.debug(
        'process_event',
        extra={
            'user_id': user_id,
            'conversation_id': conversation_id,
            'content': content,
        },
    )
    write_path = get_conversation_dir(conversation_id, user_id) + subpath

    # This writes to the google cloud storage, so we do this in a background
    # thread to not block the main runloop...
    await call_sync_from_async(file_store.write, write_path, json.dumps(content))

    event = event_from_dict(content)
    if not isinstance(event, AgentStateChangedObservation):
        return

    # Load and invoke all active callbacks for this conversation
    await invoke_conversation_callbacks(conversation_id, event)

    # Update active working seconds if agent state is not Running
    if event.agent_state != AgentState.RUNNING:
        event_store = EventStore(conversation_id, file_store, user_id)
        # update_active_working_seconds is synchronous (it replays the event
        # stream and uses a sync DB session), so it is offloaded the same way
        # as the write above instead of stalling the event loop.
        await call_sync_from_async(
            update_active_working_seconds,
            event_store,
            conversation_id,
            user_id,
            file_store,
        )
|
||||
|
||||
|
||||
async def invoke_conversation_callbacks(
    conversation_id: str, observation: AgentStateChangedObservation
):
    """
    Run every ACTIVE callback registered for a conversation.

    A callback whose processor raises is logged, switched to ERROR status and
    given a fresh ``updated_at``; the remaining callbacks still run. Changes
    are committed once after all callbacks have been attempted.

    Args:
        conversation_id: The conversation ID to process callbacks for
        observation: The AgentStateChangedObservation that triggered the callback
    """
    active_callbacks_query = select(ConversationCallback).filter(
        and_(
            ConversationCallback.conversation_id == conversation_id,
            ConversationCallback.status == CallbackStatus.ACTIVE,
        )
    )

    async with a_session_maker() as session:
        rows = await session.execute(active_callbacks_query)

        for callback in rows.scalars().all():
            try:
                processor = callback.get_processor()
                await processor.__call__(callback, observation)
                success_fields = {
                    'conversation_id': conversation_id,
                    'callback_id': callback.id,
                    'processor_type': callback.processor_type,
                }
                logger.info('callback_invoked_successfully', extra=success_fields)
            except Exception as exc:
                failure_fields = {
                    'conversation_id': conversation_id,
                    'callback_id': callback.id,
                    'processor_type': callback.processor_type,
                    'error': str(exc),
                }
                logger.error('callback_invocation_failed', extra=failure_fields)
                # A failing callback is parked in ERROR so it is not picked up
                # again by the ACTIVE filter above.
                callback.status = CallbackStatus.ERROR
                callback.updated_at = datetime.now()

        await session.commit()
|
||||
|
||||
|
||||
def update_conversation_metadata(conversation_id: str, content: dict):
    """
    Update conversation metadata with new content.

    Only the keys ``title``, ``accumulated_cost``, ``prompt_tokens``,
    ``completion_tokens`` and ``total_tokens`` are read from ``content``. A
    missing or falsy value (``None``, ``0``, ``''``) leaves the stored value
    untouched. ``last_updated_at`` is always refreshed.

    If no metadata row exists for ``conversation_id`` a warning is logged and
    nothing is written, instead of failing with an opaque ``AttributeError``.

    Args:
        conversation_id: The conversation ID to update
        content: The metadata content to update
    """
    logger.debug(
        'update_conversation_metadata',
        extra={
            'conversation_id': conversation_id,
            'content': content,
        },
    )
    with session_maker() as session:
        conversation = (
            session.query(StoredConversationMetadata)
            .filter(StoredConversationMetadata.conversation_id == conversation_id)
            .first()
        )
        if conversation is None:
            # .first() returns None when the row is absent; there is nothing
            # to update, so report it and leave the database alone.
            logger.warning(
                'update_conversation_metadata:conversation_not_found',
                extra={'conversation_id': conversation_id},
            )
            return

        conversation.title = content.get('title') or conversation.title
        conversation.last_updated_at = datetime.now()
        for field in (
            'accumulated_cost',
            'prompt_tokens',
            'completion_tokens',
            'total_tokens',
        ):
            # Keep the stored value when the incoming one is missing or falsy.
            setattr(
                conversation, field, content.get(field) or getattr(conversation, field)
            )
        session.commit()
|
||||
|
||||
|
||||
def register_callback_processor(
    conversation_id: str, processor: ConversationCallbackProcessor
) -> int:
    """
    Persist a new ACTIVE callback for a conversation.

    Args:
        conversation_id: The conversation ID to register the callback for
        processor: The ConversationCallbackProcessor instance to register

    Returns:
        int: The ID of the created callback
    """
    with session_maker() as session:
        new_callback = ConversationCallback(
            conversation_id=conversation_id,
            status=CallbackStatus.ACTIVE,
        )
        # Stores the processor's class path and JSON dump on the row.
        new_callback.set_processor(processor)
        session.add(new_callback)
        session.commit()
        callback_id = new_callback.id
        return callback_id
|
||||
|
||||
|
||||
def update_active_working_seconds(
    event_store: EventStore, conversation_id: str, user_id: str, file_store: FileStore
):
    """
    Recompute and store the total seconds a conversation's agent spent running.

    All events are replayed; each stretch from an AgentStateChangedObservation
    with state RUNNING to the next observation with any other state is summed.
    A stretch that is still open at the end of the stream is not counted. The
    total overwrites (or creates) the conversation's ConversationWork row.
    Any exception is logged and swallowed.

    Args:
        event_store: The EventStore instance for reading events
        conversation_id: The conversation ID to process
        user_id: The user ID associated with the conversation
        file_store: The FileStore instance for accessing conversation data
            (not used by the current implementation)
    """
    try:
        run_started_at = None
        accumulated_seconds = 0.0

        for event in event_store.search_events():
            # Only timestamped agent state changes affect the tally.
            if not (
                isinstance(event, AgentStateChangedObservation) and event.timestamp
            ):
                continue

            occurred_at = datetime.fromisoformat(event.timestamp).timestamp()

            if event.agent_state == AgentState.RUNNING:
                # Repeated RUNNING events keep the earliest start time.
                if run_started_at is None:
                    run_started_at = occurred_at
                continue

            if run_started_at is not None:
                # Agent left the running state: close the open stretch.
                accumulated_seconds += occurred_at - run_started_at
                run_started_at = None

        # A stretch still open here is intentionally left out; it will be
        # counted once the agent stops.

        with session_maker() as session:
            work_row = (
                session.query(ConversationWork)
                .filter(ConversationWork.conversation_id == conversation_id)
                .first()
            )

            if not work_row:
                work_row = ConversationWork(
                    conversation_id=conversation_id,
                    user_id=user_id,
                    seconds=accumulated_seconds,
                )
                session.add(work_row)
            else:
                work_row.seconds = accumulated_seconds
                work_row.updated_at = datetime.now().isoformat()

            session.commit()

        success_fields = {
            'conversation_id': conversation_id,
            'user_id': user_id,
            'total_seconds': accumulated_seconds,
        }
        logger.info('updated_active_working_seconds', extra=success_fields)

    except Exception as exc:
        failure_fields = {
            'conversation_id': conversation_id,
            'user_id': user_id,
            'error': str(exc),
        }
        logger.error('failed_to_update_active_working_seconds', extra=failure_fields)
|
||||
|
||||
|
||||
def update_agent_state(user_id: str, conversation_id: str, content: bytes):
    """
    Write a conversation's agent state blob to the file store.

    Only the size of ``content`` is logged, never the payload itself.

    Args:
        user_id: The user ID associated with the conversation
        conversation_id: The conversation ID
        content: The agent state content as bytes
    """
    debug_fields = {
        'user_id': user_id,
        'conversation_id': conversation_id,
        'content_size': len(content),
    }
    logger.debug('update_agent_state', extra=debug_fields)
    file_store.write(
        get_conversation_agent_state_filename(conversation_id, user_id),
        content,
    )
|
||||
|
||||
|
||||
def update_conversation_stats(user_id: str, conversation_id: str, content: bytes):
    """
    Merge incoming conversation stats into the stored stats for a conversation.

    Args:
        user_id: The user ID associated with the conversation
        conversation_id: The conversation ID
        content: Base64-encoded pickle of the incoming metrics
    """
    existing_convo_stats = ConversationStats(
        file_store=file_store, conversation_id=conversation_id, user_id=user_id
    )

    # Holder for the incoming metrics, built with None in place of the
    # file store and user ID.
    incoming_convo_stats = ConversationStats(None, conversation_id, None)
    pickled = base64.b64decode(content)
    # NOTE(review): pickle.loads can execute arbitrary code while unpickling.
    # This is only safe if `content` always comes from a trusted source;
    # confirm against the caller.
    incoming_convo_stats.restored_metrics = pickle.loads(pickled)

    # Merging automatically saves to file store
    existing_convo_stats.merge_and_save(incoming_convo_stats)
|
||||
@@ -16,7 +16,7 @@ from server.verified_models.verified_model_service import (
|
||||
)
|
||||
|
||||
from openhands.app_server.config import get_db_session
|
||||
from openhands.server.routes import public
|
||||
from openhands.app_server.config_api.config_router import get_llm_models_dependency
|
||||
from openhands.utils.llm import ModelsResponse, get_supported_llm_models
|
||||
|
||||
api_router = APIRouter(prefix='/api/admin/verified-models', tags=['Verified Models'])
|
||||
@@ -138,6 +138,4 @@ async def get_saas_llm_models_dependency(request: Request) -> ModelsResponse:
|
||||
# This must be called after the app is created in saas_server.py
|
||||
def override_llm_models_dependency(app):
|
||||
"""Override the default LLM models implementation with SaaS version."""
|
||||
app.dependency_overrides[public.get_llm_models_dependency] = (
|
||||
get_saas_llm_models_dependency
|
||||
)
|
||||
app.dependency_overrides[get_llm_models_dependency] = get_saas_llm_models_dependency
|
||||
|
||||
@@ -2,7 +2,6 @@ from storage.api_key import ApiKey
|
||||
from storage.auth_tokens import AuthTokens
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.billing_session_type import BillingSessionType
|
||||
from storage.conversation_callback import CallbackStatus, ConversationCallback
|
||||
from storage.conversation_work import ConversationWork
|
||||
from storage.feedback import ConversationFeedback, Feedback
|
||||
from storage.github_app_installation import GithubAppInstallation
|
||||
@@ -45,8 +44,6 @@ __all__ = [
|
||||
'AuthTokens',
|
||||
'BillingSession',
|
||||
'BillingSessionType',
|
||||
'CallbackStatus',
|
||||
'ConversationCallback',
|
||||
'ConversationFeedback',
|
||||
'StoredConversationMetadataSaas',
|
||||
'ConversationWork',
|
||||
|
||||
@@ -3,7 +3,7 @@ Unified SQLAlchemy declarative base for all models.
|
||||
|
||||
Re-exports the core Base to ensure enterprise and core models share the same
|
||||
metadata registry. This allows foreign key relationships between enterprise
|
||||
models (e.g., ConversationCallback) and core models (e.g., StoredConversationMetadata).
|
||||
models and core models (e.g., StoredConversationMetadata).
|
||||
|
||||
The core Base now uses SQLAlchemy 2.0 DeclarativeBase for proper type inference
|
||||
with Mapped types, while remaining backward compatible with existing Column()
|
||||
|
||||
@@ -1,115 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import DateTime, ForeignKey, String, Text, text
|
||||
from sqlalchemy import Enum as SQLEnum
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from storage.base import Base
|
||||
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
|
||||
class ConversationCallbackProcessor(BaseModel, ABC):
    """
    Abstract base class for conversation callback processors.

    Conversation processors are invoked when events occur in a conversation
    to perform additional processing, notifications, or integrations.

    Subclasses are pydantic models and must implement the async ``__call__``.
    """

    model_config = ConfigDict(
        # Allow extra fields for flexibility
        extra='allow',
        # Allow arbitrary types
        arbitrary_types_allowed=True,
    )

    @abstractmethod
    async def __call__(
        self,
        callback: ConversationCallback,
        observation: 'AgentStateChangedObservation',
    ) -> None:
        """
        Process a conversation event.

        Args:
            callback: The conversation callback record this processor is
                running for
            observation: The AgentStateChangedObservation that triggered the callback
        """
|
||||
|
||||
|
||||
class CallbackStatus(Enum):
    """Status of a conversation callback.

    Each member's value is its own name as a string.
    """

    # Default status assigned to new ConversationCallback rows.
    ACTIVE = 'ACTIVE'
    # Presumably set once a callback has finished its work; verify against
    # the processors that set it.
    COMPLETED = 'COMPLETED'
    # Presumably set when a callback fails; verify against the invoking code.
    ERROR = 'ERROR'
|
||||
|
||||
|
||||
class ConversationCallback(Base):
    """
    Model for storing conversation callbacks that process conversation events.

    A row holds one serialized ConversationCallbackProcessor: its dotted class
    path in ``processor_type`` and its JSON dump in ``processor_json``. Use
    ``set_processor`` / ``get_processor`` to write and read that pair.
    """

    __tablename__ = 'conversation_callbacks'

    # Auto-incrementing primary key.
    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    # Conversation the callback belongs to; indexed foreign key into
    # conversation_metadata.conversation_id.
    conversation_id: Mapped[str] = mapped_column(
        String,
        ForeignKey('conversation_metadata.conversation_id'),
        nullable=False,
        index=True,
    )
    # Lifecycle state; new rows default to ACTIVE.
    status: Mapped[CallbackStatus] = mapped_column(
        SQLEnum(CallbackStatus), nullable=False, default=CallbackStatus.ACTIVE
    )
    # '<module>.<ClassName>' of the processor, as written by set_processor.
    processor_type: Mapped[str] = mapped_column(String, nullable=False)
    # JSON dump of the processor, as written by set_processor.
    processor_json: Mapped[str] = mapped_column(Text, nullable=False)
    # Set by the database when the row is inserted.
    created_at: Mapped[datetime] = mapped_column(
        DateTime,
        server_default=text('CURRENT_TIMESTAMP'),
        nullable=False,
    )
    # Database default on insert; refreshed with datetime.now on ORM updates.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        server_default=text('CURRENT_TIMESTAMP'),
        onupdate=datetime.now,
        nullable=False,
    )

    def get_processor(self) -> ConversationCallbackProcessor:
        """
        Get the processor instance from the stored processor type and JSON data.

        The class named by ``processor_type`` is resolved with ``get_impl`` and
        then populated by validating ``processor_json`` against it.

        Returns:
            ConversationCallbackProcessor: The processor instance
        """
        # Import the processor class dynamically
        processor_class: type[ConversationCallbackProcessor] = get_impl(
            ConversationCallbackProcessor, self.processor_type
        )
        processor = processor_class.model_validate_json(self.processor_json)
        return processor

    def set_processor(self, processor: ConversationCallbackProcessor) -> None:
        """
        Set the processor instance, storing its type and JSON representation.

        Args:
            processor: The ConversationCallbackProcessor instance to store
        """
        # Fully qualified class path, so get_processor can re-import the class.
        self.processor_type = (
            f'{processor.__class__.__module__}.{processor.__class__.__name__}'
        )
        self.processor_json = processor.model_dump_json()
|
||||
@@ -1,10 +1,14 @@
|
||||
import binascii
|
||||
import hashlib
|
||||
import json
|
||||
from base64 import b64decode, b64encode
|
||||
from typing import Any
|
||||
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
from pydantic import SecretStr
|
||||
from pydantic import BaseModel, SecretStr
|
||||
from server.config import get_config
|
||||
from sqlalchemy import String, TypeDecorator
|
||||
from sqlalchemy.engine.interfaces import Dialect
|
||||
|
||||
_jwt_service = None
|
||||
_fernet = None
|
||||
@@ -135,3 +139,39 @@ def model_to_kwargs(model_instance):
|
||||
column.name: getattr(model_instance, column.name)
|
||||
for column in model_instance.__table__.columns
|
||||
}
|
||||
|
||||
|
||||
class EncryptedJSON(TypeDecorator[dict[str, Any]]):
    """String-backed column type holding a JSON document encrypted at rest.

    Values written to the column may be a plain ``dict`` or a pydantic
    ``BaseModel``. A model is first dumped with
    ``model_dump(mode='json', context={'expose_secrets': True})`` so nested
    ``SecretStr`` values are stored with their real payload: the column is the
    encryption boundary, and masking on the way in would corrupt round-trips.

    Intended for JSON payloads that may contain secrets (e.g. nested
    ``api_key`` fields) where the ``_<field>`` String + property pattern is
    awkward. The column stays a normal ORM attribute while the whole JSON blob
    goes through ``encrypt_value`` / ``decrypt_value``.

    Values read back are always plain dicts, never pydantic models.
    """

    impl = String
    cache_ok = True

    def process_bind_param(
        self, value: BaseModel | dict[str, Any] | None, dialect: Dialect
    ) -> str | None:
        """Serialize ``value`` to JSON and encrypt it; ``None`` passes through."""
        if value is None:
            return None
        payload = (
            value.model_dump(mode='json', context={'expose_secrets': True})
            if isinstance(value, BaseModel)
            else value
        )
        serialized = json.dumps(payload)
        return encrypt_value(serialized)

    def process_result_value(
        self, value: str | None, dialect: Dialect
    ) -> dict[str, Any] | None:
        """Decrypt a stored value and parse it as JSON; ``None`` passes through."""
        if value is not None:
            return json.loads(decrypt_value(value))
        return None
|
||||
|
||||
@@ -19,7 +19,7 @@ from server.constants import (
|
||||
from server.logger import logger
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.app_server.settings.settings_models import Settings
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
# Timeout in seconds for key verification requests to LiteLLM
|
||||
|
||||
@@ -62,9 +62,6 @@ class Org(Base):
|
||||
# encrypted column, don't set directly, set without the underscore
|
||||
_sandbox_api_key: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
max_budget_per_task: Mapped[float | None] = mapped_column(nullable=True)
|
||||
enable_solvability_analysis: Mapped[bool | None] = mapped_column(
|
||||
nullable=True, default=False
|
||||
)
|
||||
v1_enabled: Mapped[bool | None] = mapped_column(nullable=True)
|
||||
conversation_expiration: Mapped[int | None] = mapped_column(nullable=True)
|
||||
byor_export_enabled: Mapped[bool] = mapped_column(nullable=False, default=False)
|
||||
|
||||
@@ -14,7 +14,7 @@ from storage.org_member import OrgMember
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
from openhands.app_server.settings.settings_models import Settings
|
||||
from openhands.utils.jsonpatch_compat import deep_merge
|
||||
|
||||
|
||||
|
||||
@@ -24,9 +24,9 @@ from storage.org_store import OrgStore
|
||||
from storage.role_store import RoleStore
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.app_server.settings.settings_models import Settings
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.sdk.settings import AgentSettings, ConversationSettings
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
|
||||
class OrgService:
|
||||
|
||||
@@ -24,9 +24,9 @@ from storage.org_member import OrgMember
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.app_server.settings.settings_models import Settings
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.sdk.settings import AgentSettings, ConversationSettings
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
from openhands.utils.jsonpatch_compat import deep_merge
|
||||
from openhands.utils.llm import is_openhands_model
|
||||
|
||||
|
||||
@@ -1,280 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC
|
||||
from typing import TYPE_CHECKING, Callable, ContextManager
|
||||
from uuid import UUID
|
||||
|
||||
from storage.database import session_maker
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.orm import Session
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.storage.conversation.conversation_store import ConversationStore
|
||||
from openhands.storage.data_models.conversation_metadata import (
|
||||
ConversationMetadata,
|
||||
ConversationTrigger,
|
||||
)
|
||||
from openhands.storage.data_models.conversation_metadata_result_set import (
|
||||
ConversationMetadataResultSet,
|
||||
)
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.search_utils import offset_to_page_id, page_id_to_offset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SaasConversationStore(ConversationStore):
|
||||
user_id: str
|
||||
session_maker: Callable[[], ContextManager[Session]]
|
||||
org_id: UUID | None = None # will be fetched automatically
|
||||
|
||||
def __init__(
    self,
    user_id: str,
    org_id: UUID | None,
    session_maker: Callable[[], ContextManager[Session]],
    resolver_org_id: UUID | None = None,
):
    """Create a store bound to one user.

    Args:
        user_id: ID of the user whose conversations are accessed; other
            methods parse it with ``UUID(...)``.
        org_id: Organization used to scope queries, or None.
        session_maker: Factory returning a context-managed database session.
        resolver_org_id: Optional organization ID that, when set, is used
            instead of ``org_id`` when saving metadata.
    """
    self.user_id = user_id
    self.org_id = org_id
    self.session_maker = session_maker
    self.resolver_org_id = resolver_org_id
|
||||
|
||||
def _select_by_id(self, session, conversation_id: str):
    """Build the query for one of this user's V0 conversations.

    StoredConversationMetadata is joined to StoredConversationMetadataSaas on
    the conversation ID so the result can be restricted to this store's user
    and, when ``self.org_id`` is set, to that organization.

    Returns the un-executed query.
    """
    same_conversation = (
        StoredConversationMetadata.conversation_id
        == StoredConversationMetadataSaas.conversation_id
    )
    criteria = [
        StoredConversationMetadataSaas.user_id == UUID(self.user_id),
        StoredConversationMetadata.conversation_id == conversation_id,
        StoredConversationMetadata.conversation_version == 'V0',
    ]
    # The organization restriction only applies when the store has one.
    if self.org_id is not None:
        criteria.append(StoredConversationMetadataSaas.org_id == self.org_id)

    return (
        session.query(StoredConversationMetadata)
        .join(StoredConversationMetadataSaas, same_conversation)
        .filter(*criteria)
    )
|
||||
|
||||
def _to_external_model(
    self, conversation_metadata: StoredConversationMetadata
) -> ConversationMetadata:
    """Convert a stored metadata row into a ConversationMetadata value.

    Column values are copied by name, the two timestamps are tagged as UTC,
    ``trigger`` and a string ``git_provider`` are converted to their enums,
    ``user_id`` is taken from this store, and the V1-only columns are dropped
    before the external model is built.
    """
    kwargs = {}
    for column in StoredConversationMetadata.__table__.columns:
        kwargs[column.name] = getattr(conversation_metadata, column.name)

    # TODO: I'm not sure why the timezone is not set on the dates coming back out of the db
    for timestamp_field in ('created_at', 'last_updated_at'):
        kwargs[timestamp_field] = kwargs[timestamp_field].replace(tzinfo=UTC)

    trigger = kwargs['trigger']
    if trigger:
        kwargs['trigger'] = ConversationTrigger(trigger)

    git_provider = kwargs['git_provider']
    if git_provider and isinstance(git_provider, str):
        # Convert string to ProviderType enum
        kwargs['git_provider'] = ProviderType(git_provider)

    kwargs['user_id'] = self.user_id

    # Remove V1 attributes
    for v1_attribute in (
        'max_budget_per_task',
        'cache_read_tokens',
        'cache_write_tokens',
        'reasoning_tokens',
        'context_window',
        'per_turn_token',
        'parent_conversation_id',
    ):
        kwargs.pop(v1_attribute, None)
    # Unlike the attributes above, 'public' is required to be present.
    kwargs.pop('public')

    return ConversationMetadata(**kwargs)
|
||||
|
||||
async def save_metadata(self, metadata: ConversationMetadata):
    """Store conversation metadata together with its SaaS ownership record.

    The main row is merged into StoredConversationMetadata. The matching
    StoredConversationMetadataSaas row is created when missing; when it
    already exists its owner is checked instead of being overwritten.

    Raises:
        ValueError: if an existing SaaS row belongs to a different user, or
            to a different organization than the one being saved under.
    """
    kwargs = dataclasses.asdict(metadata)

    # user_id and org_id are no longer columns of StoredConversationMetadata
    for owner_field in ('user_id', 'org_id'):
        kwargs.pop(owner_field, None)

    # Store the git provider as its plain value rather than the enum
    git_provider = kwargs.get('git_provider')
    if git_provider is not None:
        kwargs['git_provider'] = getattr(git_provider, 'value', git_provider)

    stored_metadata = StoredConversationMetadata(**kwargs)

    # A resolver org_id (from git org claim resolution) takes precedence,
    # same pattern as V1's save_app_conversation_info in
    # saas_app_conversation_info_injector.py
    if self.resolver_org_id is not None:
        org_id = self.resolver_org_id
    else:
        org_id = self.org_id

    def _save_metadata():
        with self.session_maker() as session:
            # Save the main conversation metadata
            session.merge(stored_metadata)

            saas_metadata = (
                session.query(StoredConversationMetadataSaas)
                .filter(
                    StoredConversationMetadataSaas.conversation_id
                    == stored_metadata.conversation_id
                )
                .first()
            )

            if saas_metadata:
                # An ownership record exists: verify it instead of replacing it
                expected_user_id = UUID(self.user_id)
                expected_org_id = org_id

                if saas_metadata.user_id != expected_user_id:
                    raise ValueError(
                        f'Existing user_id ({saas_metadata.user_id}) does not match expected value ({expected_user_id}).'
                    )

                if expected_org_id and saas_metadata.org_id != expected_org_id:
                    raise ValueError(
                        f'Existing org_id ({saas_metadata.org_id}) does not match expected value ({expected_org_id}).'
                    )
            else:
                saas_metadata = StoredConversationMetadataSaas(
                    conversation_id=stored_metadata.conversation_id,
                    user_id=UUID(self.user_id),
                    org_id=org_id,
                )
                session.add(saas_metadata)

            session.commit()

    await call_sync_from_async(_save_metadata)
|
||||
|
||||
async def get_metadata(self, conversation_id: str) -> ConversationMetadata:
    """Load the metadata of one of this user's conversations.

    Raises:
        FileNotFoundError: if no conversation matches the ID within this
            store's user/org scope.
    """

    def _get_metadata():
        with self.session_maker() as session:
            stored_row = self._select_by_id(session, conversation_id).first()
            if stored_row:
                return self._to_external_model(stored_row)
            raise FileNotFoundError(conversation_id)

    return await call_sync_from_async(_get_metadata)
|
||||
|
||||
async def delete_metadata(self, conversation_id: str) -> None:
    """Delete a conversation's metadata if this user/org owns it.

    The main metadata row is removed only when a SaaS ownership record exists
    for the conversation with this store's user and org; otherwise nothing is
    deleted.
    """

    def _delete_metadata():
        with self.session_maker() as session:
            owned_record = (
                session.query(StoredConversationMetadataSaas)
                .filter(
                    StoredConversationMetadataSaas.conversation_id
                    == conversation_id,
                    StoredConversationMetadataSaas.user_id == UUID(self.user_id),
                    StoredConversationMetadataSaas.org_id == self.org_id,
                )
                .first()
            )

            if not owned_record:
                # No SaaS record found → skip deleting main metadata
                session.rollback()
                return

            # Delete both records, but only if the SaaS one exists
            session.query(StoredConversationMetadata).filter(
                StoredConversationMetadata.conversation_id == conversation_id,
            ).delete()
            session.delete(owned_record)
            session.commit()

    await call_sync_from_async(_delete_metadata)
|
||||
|
||||
async def exists(self, conversation_id: str) -> bool:
    """Report whether ``_select_by_id`` finds a truthy result for the id.

    Args:
        conversation_id: Identifier of the conversation to look up.

    Returns:
        ``bool()`` of the scalar produced by the lookup query.
    """

    def _probe():
        with self.session_maker() as db:
            found = self._select_by_id(db, conversation_id).scalar()
            return bool(found)

    return await call_sync_from_async(_probe)
|
||||
|
||||
async def search(
    self,
    page_id: str | None = None,
    limit: int = 20,
) -> ConversationMetadataResultSet:
    """Return one page of this user's V0 conversations, newest first.

    Args:
        page_id: Page token, converted to a row offset with
            ``page_id_to_offset``.
        limit: Maximum number of conversations in the returned page.

    Returns:
        A result set holding at most ``limit`` conversations and the id of
        the next page as computed by ``offset_to_page_id``.
    """
    offset = page_id_to_offset(page_id)

    def _run_query():
        main = StoredConversationMetadata
        saas = StoredConversationMetadataSaas
        with self.session_maker() as db:
            query = db.query(main).join(
                saas, main.conversation_id == saas.conversation_id
            )
            query = query.filter(saas.user_id == UUID(self.user_id))
            query = query.filter(saas.org_id == self.org_id)
            query = query.filter(main.conversation_version == 'V0')
            query = query.order_by(main.created_at.desc())
            # Ask for one row more than the page size to detect a next page.
            rows = query.offset(offset).limit(limit + 1).all()

            page = [self._to_external_model(row) for row in rows]
            has_more = len(page) > limit
            next_page_id = offset_to_page_id(offset + limit, has_more)
            return ConversationMetadataResultSet(page[:limit], next_page_id)

    return await call_sync_from_async(_run_query)
|
||||
|
||||
@classmethod
async def get_instance(
    cls,
    config: OpenHandsConfig,
    user_id: str,  # type: ignore[override]
) -> ConversationStore:
    """Build a store scoped to the user's current organisation.

    Args:
        config: Application configuration (not read by this method).
        user_id: Identifier of the user the store acts for.

    Returns:
        A ``SaasConversationStore`` whose org id is the user's
        ``current_org_id``, or ``None`` when the user lookup returns nothing.
    """
    # Use async version since callers now use asyncio.run_coroutine_threadsafe()
    # to dispatch to the main event loop where asyncpg connections work properly.
    user = await UserStore.get_user_by_id(user_id)
    if user:
        org_id = user.current_org_id
    else:
        org_id = None
    return SaasConversationStore(user_id, org_id, session_maker)
|
||||
|
||||
@classmethod
async def get_resolver_instance(
    cls,
    config: OpenHandsConfig,
    user_id: str,
    resolver_org_id: UUID | None = None,
) -> 'SaasConversationStore':
    """Build a store for resolver conversations with explicit org routing.

    Differs from ``get_instance`` only in forwarding ``resolver_org_id`` to
    the store, which is meant to take precedence over the user's default
    org when conversation metadata is saved.

    Args:
        config: Application configuration (not read by this method).
        user_id: Identifier of the user the store acts for.
        resolver_org_id: Optional org id passed through to the store.

    Returns:
        A ``SaasConversationStore`` carrying both the user's current org id
        (``None`` if the user lookup returns nothing) and ``resolver_org_id``.
    """
    user = await UserStore.get_user_by_id(user_id)
    if user:
        default_org_id = user.current_org_id
    else:
        default_org_id = None
    return SaasConversationStore(
        user_id, default_org_id, session_maker, resolver_org_id
    )
|
||||
@@ -1,148 +0,0 @@
|
||||
from server.auth.auth_error import AuthError, ExpiredError
|
||||
from server.auth.saas_user_auth import saas_user_auth_from_signed_token
|
||||
from server.auth.token_manager import TokenManager
|
||||
from socketio.exceptions import ConnectionRefusedError
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
|
||||
from openhands.core.config import load_openhands_config
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.shared import ConversationStoreImpl
|
||||
from openhands.storage.conversation.conversation_validator import ConversationValidator
|
||||
|
||||
|
||||
class SaasConversationValidator(ConversationValidator):
    """Check that a connecting caller is allowed to join a conversation.

    The caller is identified either by an API key sent as a ``Bearer``
    Authorization header or by the signed ``keycloak_auth`` cookie. The
    resolved user id is then validated against the conversation metadata.
    """

    async def _validate_api_key(self, api_key: str) -> str | None:
        """Validate an API key and return the user_id if valid.

        Args:
            api_key: The API key to validate

        Returns:
            The user_id if the API key is valid and an offline token exists
            for that user, None otherwise (including on any error).
        """
        try:
            token_manager = TokenManager()

            # Validate the API key and get the user_id
            api_key_store = ApiKeyStore.get_instance()
            validation_result = await api_key_store.validate_api_key(api_key)

            if not validation_result:
                logger.warning('Invalid API key')
                return None

            user_id = validation_result.user_id

            # Get the offline token for the user
            offline_token = await token_manager.load_offline_token(user_id)
            if not offline_token:
                logger.warning(f'No offline token found for user {user_id}')
                return None

            return user_id

        except Exception as e:
            # Deliberately best-effort: a failed API-key check returns None so
            # that validate() can fall back to cookie authentication.
            logger.warning(f'Error validating API key: {str(e)}')
            return None

    async def _validate_conversation_access(
        self, conversation_id: str, user_id: str
    ) -> bool:
        """Validate that the user has access to the conversation.

        Args:
            conversation_id: The ID of the conversation
            user_id: The ID of the user

        Returns:
            True if the user has access to the conversation

        Raises:
            ConnectionRefusedError: If the user does not have access to the conversation
        """
        config = load_openhands_config()
        conversation_store = await ConversationStoreImpl.get_instance(config, user_id)

        if not await conversation_store.validate_metadata(conversation_id, user_id):
            logger.error(
                f'User {user_id} is not allowed to join conversation {conversation_id}'
            )
            raise ConnectionRefusedError(
                f'User {user_id} is not allowed to join conversation {conversation_id}'
            )
        return True

    @staticmethod
    def _parse_cookies(cookies_str: str) -> dict[str, str]:
        """Parse a Cookie header value into a name -> value mapping.

        Fragments without an ``=`` or with an empty name are ignored instead
        of raising, so one malformed cookie cannot break authentication.
        """
        cookies: dict[str, str] = {}
        if not cookies_str:
            return cookies
        for fragment in cookies_str.split(';'):
            name, separator, value = fragment.strip().partition('=')
            if separator and name:
                cookies[name] = value
        return cookies

    async def validate(
        self,
        conversation_id: str,
        cookies_str: str,
        authorization_header: str | None = None,
    ) -> str | None:
        """Validate conversation access via API key or keycloak_auth cookie.

        An API key from the Authorization header is tried first. If it is
        absent or does not resolve to a user, the ``keycloak_auth`` cookie is
        used instead.

        Args:
            conversation_id: The ID of the conversation
            cookies_str: The cookies string from the request
            authorization_header: The Authorization header from the request, if available

        Returns:
            The user_id of the authenticated user

        Raises:
            ConnectionRefusedError: If no credentials are supplied, the session
                has expired, or the user does not have access to the conversation
            AuthError: If the authentication fails
            RuntimeError: If there is an error with the configuration or user info
        """
        # Try to authenticate using Authorization header first
        bearer_prefix = 'Bearer '
        if authorization_header and authorization_header.startswith(bearer_prefix):
            # Drop only the leading scheme; str.replace would also remove any
            # later occurrence of the text inside the key itself.
            api_key = authorization_header[len(bearer_prefix) :]
            user_id = await self._validate_api_key(api_key)

            if user_id:
                logger.info(
                    f'User {user_id} is connecting to conversation {conversation_id} via API key'
                )

                await self._validate_conversation_access(conversation_id, user_id)
                return user_id

        # Fall back to cookie authentication
        token_manager = TokenManager()
        config = load_openhands_config()
        cookies = self._parse_cookies(cookies_str)

        signed_token = cookies.get('keycloak_auth', '')
        if not signed_token:
            logger.warning('No keycloak_auth cookie or valid Authorization header')
            raise ConnectionRefusedError(
                'No keycloak_auth cookie or valid Authorization header'
            )
        if not config.jwt_secret:
            raise RuntimeError('JWT secret not found')

        try:
            user_auth = await saas_user_auth_from_signed_token(signed_token)
            access_token = await user_auth.get_access_token()
        except ExpiredError as exc:
            # Keep the original cause attached for debugging.
            raise ConnectionRefusedError('SESSION$TIMEOUT_MESSAGE') from exc
        if access_token is None:
            raise AuthError('no_access_token')
        user_info = await token_manager.get_user_info(access_token.get_secret_value())
        # sub is a required field in KeycloakUserInfo, validation happens in get_user_info
        user_id = user_info.sub

        logger.info(f'User {user_id} is connecting to conversation {conversation_id}')

        await self._validate_conversation_access(conversation_id, user_id)  # type: ignore
        return user_id
|
||||
@@ -10,10 +10,10 @@ from storage.database import a_session_maker
|
||||
from storage.stored_custom_secrets import StoredCustomSecrets
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.app_server.secrets.secrets_models import Secrets
|
||||
from openhands.app_server.secrets.secrets_store import SecretsStore
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.storage.data_models.secrets import Secrets
|
||||
from openhands.storage.secrets.secrets_store import SecretsStore
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -21,9 +21,9 @@ from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.app_server.settings.settings_models import Settings
|
||||
from openhands.app_server.settings.settings_store import SettingsStore
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
from openhands.utils.jsonpatch_compat import deep_merge
|
||||
from openhands.utils.llm import is_openhands_model
|
||||
|
||||
@@ -148,6 +148,10 @@ class SaasSettingsStore(SettingsStore):
|
||||
# Apply default if sandbox_grouping_strategy is None in the database
|
||||
if kwargs.get('sandbox_grouping_strategy') is None:
|
||||
kwargs.pop('sandbox_grouping_strategy', None)
|
||||
# Pre-migration rows read back as None; Settings.llm_profiles is
|
||||
# non-nullable, so let the default_factory take over.
|
||||
if kwargs.get('llm_profiles') is None:
|
||||
kwargs.pop('llm_profiles', None)
|
||||
|
||||
return Settings(**kwargs)
|
||||
|
||||
|
||||
@@ -3,13 +3,14 @@ SQLAlchemy model for User.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, String
|
||||
from sqlalchemy.dialects.postgresql import JSON
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from storage.base import Base
|
||||
from storage.encrypt_utils import EncryptedJSON
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from storage.org import Org
|
||||
@@ -36,6 +37,9 @@ class User(Base):
|
||||
git_user_email: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
sandbox_grouping_strategy: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
disabled_skills: Mapped[list[str] | None] = mapped_column(JSON, nullable=True)
|
||||
llm_profiles: Mapped[dict[str, Any] | None] = mapped_column(
|
||||
EncryptedJSON, nullable=True
|
||||
)
|
||||
onboarding_completed: Mapped[bool | None] = mapped_column(
|
||||
nullable=True, default=False
|
||||
)
|
||||
|
||||
@@ -50,9 +50,6 @@ class UserSettings(Base):
|
||||
search_api_key: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
sandbox_api_key: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
max_budget_per_task: Mapped[float | None] = mapped_column(nullable=True)
|
||||
enable_solvability_analysis: Mapped[bool | None] = mapped_column(
|
||||
nullable=True, default=False
|
||||
)
|
||||
email: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
email_verified: Mapped[bool | None] = mapped_column(nullable=True)
|
||||
git_user_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
@@ -88,8 +85,8 @@ class UserSettings(Base):
|
||||
) # False = not migrated, True = migrated
|
||||
|
||||
def to_settings(self):
|
||||
from openhands.app_server.settings.settings_models import Settings
|
||||
from openhands.sdk.settings import AgentSettings, ConversationSettings
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
return Settings(
|
||||
agent_settings=AgentSettings.model_validate(self.agent_settings or {}),
|
||||
|
||||
@@ -931,7 +931,7 @@ class UserStore:
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
from openhands.app_server.settings.settings_models import Settings
|
||||
|
||||
@staticmethod
|
||||
async def create_default_settings(
|
||||
@@ -945,7 +945,7 @@ class UserStore:
|
||||
if not org_id:
|
||||
return None
|
||||
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
from openhands.app_server.settings.settings_models import Settings
|
||||
|
||||
default_settings = Settings(
|
||||
language='en', enable_proactive_conversation_starters=True
|
||||
@@ -1049,7 +1049,6 @@ class UserStore:
|
||||
if org.sandbox_api_key
|
||||
else None,
|
||||
max_budget_per_task=org.max_budget_per_task,
|
||||
enable_solvability_analysis=org.enable_solvability_analysis,
|
||||
v1_enabled=org.v1_enabled,
|
||||
sandbox_grouping_strategy=org.sandbox_grouping_strategy,
|
||||
agent_settings=agent_settings,
|
||||
|
||||
@@ -32,7 +32,6 @@ from openhands.app_server.sandbox.sandbox_models import (
|
||||
SandboxInfo,
|
||||
SandboxStatus,
|
||||
)
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.sdk.event import ConversationStateUpdateEvent
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -76,7 +75,10 @@ def conversation_state_update_event():
|
||||
|
||||
@pytest.fixture
|
||||
def wrong_event():
|
||||
return MessageAction(content='Hello world')
|
||||
"""Return a mock event that is not a ConversationStateUpdateEvent."""
|
||||
mock_event = MagicMock()
|
||||
mock_event.id = uuid4()
|
||||
return mock_event
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from integrations.github.github_view import (
|
||||
@@ -17,7 +17,6 @@ from jinja2 import Environment, FileSystemLoader
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartTaskStatus,
|
||||
)
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -44,11 +43,8 @@ class _FakeAppConversationService:
|
||||
yield MagicMock(status=AppConversationStartTaskStatus.READY, detail=None)
|
||||
|
||||
|
||||
def _build_conversation_metadata() -> ConversationMetadata:
|
||||
return ConversationMetadata(
|
||||
conversation_id=str(uuid4()),
|
||||
selected_repository='test-owner/test-repo',
|
||||
)
|
||||
def _build_conversation_id() -> UUID:
|
||||
return uuid4()
|
||||
|
||||
|
||||
def _build_user_data() -> UserData:
|
||||
@@ -77,7 +73,6 @@ class TestGithubViewV1InitialUserMessage:
|
||||
title='ignored',
|
||||
description='ignored',
|
||||
previous_comments=[],
|
||||
v1_enabled=True,
|
||||
comment_body='please fix this',
|
||||
comment_id=999,
|
||||
)
|
||||
@@ -98,7 +93,7 @@ class TestGithubViewV1InitialUserMessage:
|
||||
await view._create_v1_conversation(
|
||||
jinja_env=jinja_env,
|
||||
saas_user_auth=MagicMock(),
|
||||
conversation_metadata=_build_conversation_metadata(),
|
||||
conversation_id=_build_conversation_id(),
|
||||
)
|
||||
|
||||
assert len(fake_service.requests) == 1
|
||||
@@ -131,7 +126,6 @@ class TestGithubViewV1InitialUserMessage:
|
||||
title='ignored',
|
||||
description='ignored',
|
||||
previous_comments=[],
|
||||
v1_enabled=True,
|
||||
comment_body='nit: rename variable',
|
||||
comment_id=1001,
|
||||
branch_name='feature-branch',
|
||||
@@ -155,7 +149,7 @@ class TestGithubViewV1InitialUserMessage:
|
||||
await view._create_v1_conversation(
|
||||
jinja_env=jinja_env,
|
||||
saas_user_auth=MagicMock(),
|
||||
conversation_metadata=_build_conversation_metadata(),
|
||||
conversation_id=_build_conversation_id(),
|
||||
)
|
||||
|
||||
assert len(fake_service.requests) == 1
|
||||
@@ -187,7 +181,6 @@ class TestGithubViewV1InitialUserMessage:
|
||||
title='ignored',
|
||||
description='ignored',
|
||||
previous_comments=[],
|
||||
v1_enabled=True,
|
||||
comment_body='please add a null check',
|
||||
comment_id=1002,
|
||||
branch_name='feature-branch',
|
||||
@@ -210,7 +203,7 @@ class TestGithubViewV1InitialUserMessage:
|
||||
await view._create_v1_conversation(
|
||||
jinja_env=jinja_env,
|
||||
saas_user_auth=MagicMock(),
|
||||
conversation_metadata=_build_conversation_metadata(),
|
||||
conversation_id=_build_conversation_id(),
|
||||
)
|
||||
|
||||
req = fake_service.requests[0]
|
||||
|
||||
@@ -5,13 +5,12 @@ All conversations now use V1 app conversation system.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from integrations.gitlab.gitlab_view import GitlabIssue
|
||||
from integrations.types import UserData
|
||||
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_gitlab_view():
|
||||
@@ -35,7 +34,6 @@ def mock_gitlab_view():
|
||||
description='Test description',
|
||||
previous_comments=[],
|
||||
is_mr=False,
|
||||
v1_enabled=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -57,12 +55,9 @@ def mock_saas_user_auth():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_convo_metadata():
|
||||
"""Create a mock ConversationMetadata."""
|
||||
return ConversationMetadata(
|
||||
conversation_id='test_conversation_id',
|
||||
selected_repository='test-group/test-repo',
|
||||
)
|
||||
def mock_conversation_id():
|
||||
"""Create a mock conversation UUID."""
|
||||
return uuid4()
|
||||
|
||||
|
||||
class TestGitlabManagerJobCreation:
|
||||
@@ -81,7 +76,7 @@ class TestGitlabManagerJobCreation:
|
||||
mock_token_manager,
|
||||
mock_gitlab_view,
|
||||
mock_saas_user_auth,
|
||||
mock_convo_metadata,
|
||||
mock_conversation_id,
|
||||
):
|
||||
"""Test that start_job creates a conversation and sends acknowledgment message."""
|
||||
from integrations.gitlab.gitlab_manager import GitlabManager
|
||||
@@ -91,7 +86,7 @@ class TestGitlabManagerJobCreation:
|
||||
|
||||
# Mock the view's methods
|
||||
mock_gitlab_view.initialize_new_conversation = AsyncMock(
|
||||
return_value=mock_convo_metadata
|
||||
return_value=mock_conversation_id
|
||||
)
|
||||
mock_gitlab_view.create_new_conversation = AsyncMock()
|
||||
|
||||
|
||||
@@ -12,6 +12,16 @@ def gitlab_service():
|
||||
return SaaSGitLabService(external_auth_id='test_user_id')
|
||||
|
||||
|
||||
class TestSaaSGitLabServiceInit:
    """Constructor behaviour of SaaSGitLabService."""

    def test_explicit_base_domain_overrides_default(self):
        """A base_domain passed to the constructor determines BASE_URL."""
        gitlab = SaaSGitLabService(base_domain='other.host', external_auth_id='u1')

        expected_url = 'https://other.host/api/v4'
        assert gitlab.BASE_URL == expected_url
|
||||
|
||||
|
||||
class TestGetUserResourcesWithAdminAccess:
|
||||
"""Test cases for get_user_resources_with_admin_access method."""
|
||||
|
||||
|
||||
@@ -28,7 +28,6 @@ from openhands.app_server.sandbox.sandbox_models import (
|
||||
SandboxInfo,
|
||||
SandboxStatus,
|
||||
)
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.sdk.event import ConversationStateUpdateEvent
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -73,7 +72,10 @@ def conversation_state_update_event():
|
||||
|
||||
@pytest.fixture
|
||||
def wrong_event():
|
||||
return MessageAction(content='Hello world')
|
||||
"""Return a mock event that is not a ConversationStateUpdateEvent."""
|
||||
mock_event = MagicMock()
|
||||
mock_event.id = uuid4()
|
||||
return mock_event
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -215,7 +215,6 @@ def new_conversation_view(
|
||||
conversation_id='conv-123',
|
||||
_decrypted_api_key='decrypted_key',
|
||||
)
|
||||
view.v1_enabled = False
|
||||
return view
|
||||
|
||||
|
||||
|
||||
@@ -444,10 +444,10 @@ class TestJiraV1Conversation:
|
||||
"""Tests for V1 conversation creation and callback processor registration."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_v1_metadata_generates_conversation_id(
|
||||
async def test_initialize_conversation_generates_conversation_id(
|
||||
self, new_conversation_view
|
||||
):
|
||||
"""Test that _create_v1_metadata generates a new conversation ID."""
|
||||
"""Test that _initialize_conversation generates a new conversation ID."""
|
||||
new_conversation_view.conversation_id = ''
|
||||
|
||||
with patch.object(
|
||||
@@ -455,17 +455,19 @@ class TestJiraV1Conversation:
|
||||
) as mock_get_org:
|
||||
mock_get_org.return_value = None
|
||||
|
||||
metadata = await new_conversation_view._create_v1_metadata()
|
||||
conversation_id = await new_conversation_view._initialize_conversation()
|
||||
|
||||
# Conversation ID should be generated
|
||||
assert new_conversation_view.conversation_id != ''
|
||||
assert len(new_conversation_view.conversation_id) == 32 # UUID hex format
|
||||
assert metadata.conversation_id == new_conversation_view.conversation_id
|
||||
assert conversation_id.hex == new_conversation_view.conversation_id
|
||||
mock_get_org.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_v1_metadata_sets_resolved_org(self, new_conversation_view):
|
||||
"""Test that _create_v1_metadata sets resolved_org_id."""
|
||||
async def test_initialize_conversation_sets_resolved_org(
|
||||
self, new_conversation_view
|
||||
):
|
||||
"""Test that _initialize_conversation sets resolved_org_id."""
|
||||
from uuid import UUID
|
||||
|
||||
test_org_id = UUID('12345678-1234-5678-1234-567812345678')
|
||||
@@ -475,7 +477,7 @@ class TestJiraV1Conversation:
|
||||
) as mock_get_org:
|
||||
mock_get_org.return_value = test_org_id
|
||||
|
||||
await new_conversation_view._create_v1_metadata()
|
||||
await new_conversation_view._initialize_conversation()
|
||||
|
||||
assert new_conversation_view.resolved_org_id == test_org_id
|
||||
|
||||
|
||||
@@ -28,9 +28,16 @@ from openhands.app_server.sandbox.sandbox_models import (
|
||||
SandboxInfo,
|
||||
SandboxStatus,
|
||||
)
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.sdk.event import ConversationStateUpdateEvent
|
||||
|
||||
|
||||
def _create_mock_event():
|
||||
"""Create a mock event that is not a ConversationStateUpdateEvent."""
|
||||
mock_event = MagicMock()
|
||||
mock_event.id = uuid4()
|
||||
return mock_event
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -105,8 +112,10 @@ class TestSlackV1CallbackProcessor:
|
||||
@pytest.mark.parametrize(
|
||||
'event,expected_result',
|
||||
[
|
||||
# Wrong event types should be ignored
|
||||
(MessageAction(content='Hello world'), None),
|
||||
# Wrong event types should be ignored (use lazy evaluation for mock)
|
||||
pytest.param(
|
||||
None, None, id='wrong_event_type', marks=pytest.mark.wrong_event_type
|
||||
),
|
||||
# Wrong state values should be ignored
|
||||
(
|
||||
ConversationStateUpdateEvent(key='execution_status', value='running'),
|
||||
@@ -120,9 +129,12 @@ class TestSlackV1CallbackProcessor:
|
||||
],
|
||||
)
|
||||
async def test_event_filtering(
|
||||
self, slack_callback_processor, event_callback, event, expected_result
|
||||
self, slack_callback_processor, event_callback, event, expected_result, request
|
||||
):
|
||||
"""Test that processor correctly filters events."""
|
||||
# Handle the mock event case specially
|
||||
if event is None and 'wrong_event_type' in request.node.name:
|
||||
event = _create_mock_event()
|
||||
result = await slack_callback_processor(uuid4(), event_callback, event)
|
||||
assert result == expected_result
|
||||
|
||||
|
||||
@@ -69,7 +69,6 @@ def slack_new_conversation_view(mock_slack_user, mock_user_auth):
|
||||
send_summary_instruction=True,
|
||||
conversation_id='',
|
||||
team_id='T1234567890',
|
||||
v1_enabled=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -99,7 +98,6 @@ def slack_update_conversation_view_v1(mock_slack_user, mock_user_auth):
|
||||
conversation_id=conversation_id,
|
||||
slack_conversation=mock_conversation,
|
||||
team_id='T1234567890',
|
||||
v1_enabled=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -111,18 +109,15 @@ def slack_update_conversation_view_v1(mock_slack_user, mock_user_auth):
|
||||
class TestV1ConversationCreation:
|
||||
"""Test V1 conversation creation in Slack integration."""
|
||||
|
||||
@patch('integrations.slack.slack_view.is_v1_enabled_for_slack_resolver')
|
||||
@patch.object(SlackNewConversationView, '_create_v1_conversation')
|
||||
async def test_v1_conversation_creation(
|
||||
self,
|
||||
mock_create_v1,
|
||||
mock_is_v1_enabled,
|
||||
slack_new_conversation_view,
|
||||
mock_jinja_env,
|
||||
):
|
||||
"""Test that V1 conversations are created correctly."""
|
||||
# Setup mocks
|
||||
mock_is_v1_enabled.return_value = True
|
||||
mock_create_v1.return_value = None
|
||||
|
||||
# Execute
|
||||
@@ -132,7 +127,6 @@ class TestV1ConversationCreation:
|
||||
|
||||
# Verify
|
||||
assert result == slack_new_conversation_view.conversation_id
|
||||
assert slack_new_conversation_view.v1_enabled is True
|
||||
mock_create_v1.assert_called_once()
|
||||
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from enterprise.integrations.resolver_context import ResolverUserContext
|
||||
from openhands.app_server.secrets.secrets_models import Secrets
|
||||
|
||||
# Import the real classes we want to test
|
||||
from openhands.integrations.provider import CustomSecret, ProviderToken
|
||||
@@ -17,7 +18,6 @@ from openhands.integrations.service_types import ProviderType
|
||||
|
||||
# Import the SDK types we need for testing
|
||||
from openhands.sdk.secret import SecretSource, StaticSecret
|
||||
from openhands.storage.data_models.secrets import Secrets
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -1,171 +1,11 @@
|
||||
"""Tests for enterprise integrations utils module."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from integrations.utils import (
|
||||
HOST_URL,
|
||||
append_conversation_footer,
|
||||
get_session_expired_message,
|
||||
get_summary_for_agent_state,
|
||||
get_user_not_found_message,
|
||||
)
|
||||
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
|
||||
|
||||
class TestGetSummaryForAgentState:
|
||||
"""Test cases for get_summary_for_agent_state function."""
|
||||
|
||||
def setup_method(self):
    """Store the conversation link shared by the tests in this class."""
    link = 'https://example.com/conversation/123'
    self.conversation_link = link
|
||||
|
||||
def test_empty_observations_list(self):
    """An empty observation list yields the unknown-error summary with the link."""
    no_observations = []
    summary = get_summary_for_agent_state(no_observations, self.conversation_link)

    assert 'unknown error' in summary.lower()
    assert self.conversation_link in summary
|
||||
|
||||
@pytest.mark.parametrize(
    'state,expected_text,includes_link',
    [
        (AgentState.RATE_LIMITED, 'rate limited', False),
        (AgentState.AWAITING_USER_INPUT, 'waiting for your input', True),
    ],
)
def test_handled_agent_states(self, state, expected_text, includes_link):
    """States with dedicated handling produce their own text and link policy."""
    state_change = AgentStateChangedObservation(
        content=f'Agent state: {state.value}', agent_state=state
    )

    summary = get_summary_for_agent_state([state_change], self.conversation_link)

    assert expected_text in summary.lower()
    # The link must be present exactly when the case says it should be.
    link_present = self.conversation_link in summary
    if includes_link:
        assert link_present
    else:
        assert not link_present
|
||||
|
||||
@pytest.mark.parametrize(
    'state',
    [
        AgentState.FINISHED,
        AgentState.PAUSED,
        AgentState.STOPPED,
        AgentState.AWAITING_USER_CONFIRMATION,
    ],
)
def test_unhandled_agent_states(self, state):
    """States without dedicated handling all yield the unknown-error summary."""
    state_change = AgentStateChangedObservation(
        content=f'Agent state: {state.value}', agent_state=state
    )

    summary = get_summary_for_agent_state([state_change], self.conversation_link)

    assert 'unknown error' in summary.lower()
    assert self.conversation_link in summary
|
||||
|
||||
@pytest.mark.parametrize(
    'error_code,expected_text',
    [
        (
            'STATUS$ERROR_LLM_AUTHENTICATION',
            'authentication with the llm provider failed',
        ),
        (
            'STATUS$ERROR_LLM_SERVICE_UNAVAILABLE',
            'llm service is temporarily unavailable',
        ),
        (
            'STATUS$ERROR_LLM_INTERNAL_SERVER_ERROR',
            'llm provider encountered an internal error',
        ),
        ('STATUS$ERROR_LLM_OUT_OF_CREDITS', "you've run out of credits"),
        ('STATUS$ERROR_LLM_CONTENT_POLICY_VIOLATION', 'content policy violation'),
    ],
)
def test_error_state_readable_reasons(self, error_code, expected_text):
    """Each known error code is rendered as its human-readable reason."""
    error_observation = AgentStateChangedObservation(
        content=f'Agent encountered error: {error_code}',
        agent_state=AgentState.ERROR,
        reason=error_code,
    )

    summary = get_summary_for_agent_state(
        [error_observation], self.conversation_link
    )
    lowered = summary.lower()

    assert 'encountered an error' in lowered
    assert expected_text in lowered
    assert self.conversation_link in summary
|
||||
|
||||
def test_error_state_with_custom_reason(self):
    """An ERROR state with a free-form reason echoes that reason in the summary."""
    error_observation = AgentStateChangedObservation(
        content='Agent encountered an error',
        agent_state=AgentState.ERROR,
        reason='Test error message',
    )

    summary = get_summary_for_agent_state(
        [error_observation], self.conversation_link
    )
    lowered = summary.lower()

    assert 'encountered an error' in lowered
    assert 'test error message' in lowered
    assert self.conversation_link in summary
|
||||
|
||||
def test_multiple_observations_uses_first(self):
|
||||
"""Test that when multiple observations are provided, only the first is used."""
|
||||
observation1 = AgentStateChangedObservation(
|
||||
content='Agent is awaiting user input',
|
||||
agent_state=AgentState.AWAITING_USER_INPUT,
|
||||
)
|
||||
observation2 = AgentStateChangedObservation(
|
||||
content='Agent encountered an error',
|
||||
agent_state=AgentState.ERROR,
|
||||
reason='Should not be used',
|
||||
)
|
||||
|
||||
result = get_summary_for_agent_state(
|
||||
[observation1, observation2], self.conversation_link
|
||||
)
|
||||
|
||||
# Should handle the first observation (AWAITING_USER_INPUT), not the second (ERROR)
|
||||
assert 'waiting for your input' in result.lower()
|
||||
assert 'error' not in result.lower()
|
||||
|
||||
def test_awaiting_user_input_specific_message(self):
|
||||
"""Test that AWAITING_USER_INPUT returns the specific expected message."""
|
||||
observation = AgentStateChangedObservation(
|
||||
content='Agent is awaiting user input',
|
||||
agent_state=AgentState.AWAITING_USER_INPUT,
|
||||
)
|
||||
|
||||
result = get_summary_for_agent_state([observation], self.conversation_link)
|
||||
|
||||
# Test the exact message format
|
||||
assert 'waiting for your input' in result.lower()
|
||||
assert 'continue the conversation' in result.lower()
|
||||
assert self.conversation_link in result
|
||||
assert 'unknown error' not in result.lower()
|
||||
|
||||
def test_rate_limited_specific_message(self):
|
||||
"""Test that RATE_LIMITED returns the specific expected message."""
|
||||
observation = AgentStateChangedObservation(
|
||||
content='Agent was rate limited', agent_state=AgentState.RATE_LIMITED
|
||||
)
|
||||
|
||||
result = get_summary_for_agent_state([observation], self.conversation_link)
|
||||
|
||||
# Test the exact message format
|
||||
assert 'rate limited' in result.lower()
|
||||
assert 'try again later' in result.lower()
|
||||
# RATE_LIMITED doesn't include conversation link in response
|
||||
assert self.conversation_link not in result
|
||||
|
||||
|
||||
class TestGetSessionExpiredMessage:
|
||||
"""Test cases for get_session_expired_message function."""
|
||||
@@ -293,138 +133,3 @@ class TestGetUserNotFoundMessage:
|
||||
result = get_user_not_found_message(None)
|
||||
assert not result.startswith('@')
|
||||
assert 'It looks like' in result
|
||||
|
||||
|
||||
class TestAppendConversationFooter:
|
||||
"""Test cases for append_conversation_footer function."""
|
||||
|
||||
@patch(
|
||||
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
|
||||
)
|
||||
def test_appends_footer_with_markdown_link(self):
|
||||
"""Test that footer is appended with correct markdown link format."""
|
||||
# Arrange
|
||||
message = 'This is a test message'
|
||||
conversation_id = 'test-conv-123'
|
||||
|
||||
# Act
|
||||
result = append_conversation_footer(message, conversation_id)
|
||||
|
||||
# Assert
|
||||
assert result.startswith(message)
|
||||
assert (
|
||||
'[View full conversation](https://example.com/conversations/test-conv-123)'
|
||||
in result
|
||||
)
|
||||
assert result.endswith(
|
||||
'[View full conversation](https://example.com/conversations/test-conv-123)'
|
||||
)
|
||||
|
||||
@patch(
|
||||
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
|
||||
)
|
||||
def test_footer_does_not_contain_html_tags(self):
|
||||
"""Test that footer does not contain HTML tags like <sub>."""
|
||||
# Arrange
|
||||
message = 'Test message'
|
||||
conversation_id = 'test-conv-456'
|
||||
|
||||
# Act
|
||||
result = append_conversation_footer(message, conversation_id)
|
||||
|
||||
# Assert
|
||||
assert '<sub>' not in result
|
||||
assert '</sub>' not in result
|
||||
|
||||
@patch(
|
||||
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
|
||||
)
|
||||
def test_footer_format_with_newlines(self):
|
||||
"""Test that footer is properly separated with newlines."""
|
||||
# Arrange
|
||||
message = 'Original message content'
|
||||
conversation_id = 'test-conv-789'
|
||||
|
||||
# Act
|
||||
result = append_conversation_footer(message, conversation_id)
|
||||
|
||||
# Assert
|
||||
assert (
|
||||
result
|
||||
== 'Original message content\n\n[View full conversation](https://example.com/conversations/test-conv-789)'
|
||||
)
|
||||
|
||||
@patch(
|
||||
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
|
||||
)
|
||||
def test_empty_message_still_appends_footer(self):
|
||||
"""Test that footer is appended even when message is empty."""
|
||||
# Arrange
|
||||
message = ''
|
||||
conversation_id = 'empty-msg-conv'
|
||||
|
||||
# Act
|
||||
result = append_conversation_footer(message, conversation_id)
|
||||
|
||||
# Assert
|
||||
assert result.startswith('\n\n')
|
||||
assert (
|
||||
'[View full conversation](https://example.com/conversations/empty-msg-conv)'
|
||||
in result
|
||||
)
|
||||
|
||||
@patch(
|
||||
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
|
||||
)
|
||||
def test_conversation_id_with_special_characters(self):
|
||||
"""Test that footer handles conversation IDs with special characters."""
|
||||
# Arrange
|
||||
message = 'Test message'
|
||||
conversation_id = 'conv-123_abc-456'
|
||||
|
||||
# Act
|
||||
result = append_conversation_footer(message, conversation_id)
|
||||
|
||||
# Assert
|
||||
expected_url = 'https://example.com/conversations/conv-123_abc-456'
|
||||
assert expected_url in result
|
||||
assert '[View full conversation]' in result
|
||||
|
||||
@patch(
|
||||
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
|
||||
)
|
||||
def test_multiline_message_preserves_content(self):
|
||||
"""Test that multiline messages are preserved correctly."""
|
||||
# Arrange
|
||||
message = 'Line 1\nLine 2\nLine 3'
|
||||
conversation_id = 'multiline-conv'
|
||||
|
||||
# Act
|
||||
result = append_conversation_footer(message, conversation_id)
|
||||
|
||||
# Assert
|
||||
assert result.startswith('Line 1\nLine 2\nLine 3')
|
||||
assert '\n\n[View full conversation]' in result
|
||||
assert message in result
|
||||
|
||||
@patch(
|
||||
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
|
||||
)
|
||||
def test_footer_contains_only_markdown_syntax(self):
|
||||
"""Test that footer uses only markdown syntax, not HTML."""
|
||||
# Arrange
|
||||
message = 'Test message'
|
||||
conversation_id = 'markdown-test'
|
||||
|
||||
# Act
|
||||
result = append_conversation_footer(message, conversation_id)
|
||||
|
||||
# Assert
|
||||
footer_part = result[len(message) :]
|
||||
# Should only contain markdown link syntax: [text](url)
|
||||
assert footer_part.startswith('\n\n[')
|
||||
assert '](' in footer_part
|
||||
assert footer_part.endswith(')')
|
||||
# Should not contain any HTML tags (specifically <sub> tags that were removed)
|
||||
assert '<sub>' not in footer_part
|
||||
assert '</sub>' not in footer_part
|
||||
|
||||
@@ -4,8 +4,10 @@ Tests for:
|
||||
- _should_redirect_to_onboarding() function
|
||||
- _get_post_auth_redirect() function
|
||||
- /complete_onboarding endpoint
|
||||
- /onboarding_status endpoint
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
@@ -17,6 +19,7 @@ from server.routes.auth import (
|
||||
_get_post_auth_redirect,
|
||||
_should_redirect_to_onboarding,
|
||||
complete_onboarding,
|
||||
onboarding_status,
|
||||
)
|
||||
from storage.user import User
|
||||
|
||||
@@ -328,3 +331,78 @@ class TestCompleteOnboardingEndpoint:
|
||||
await complete_onboarding(mock_request)
|
||||
|
||||
mock_mark_completed.assert_called_once_with(user_id)
|
||||
|
||||
|
||||
class TestOnboardingStatusEndpoint:
|
||||
"""Tests for the /onboarding_status API endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_401_when_not_authenticated(self, mock_request):
|
||||
"""Unauthenticated requests return 401."""
|
||||
mock_user_auth = MagicMock(spec=SaasUserAuth)
|
||||
mock_user_auth.get_user_id = AsyncMock(return_value=None)
|
||||
|
||||
with patch(
|
||||
'server.routes.auth.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_auth,
|
||||
):
|
||||
result = await onboarding_status(mock_request)
|
||||
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_true_for_new_cloud_user(self, mock_request, mock_user):
|
||||
"""A cloud user whose onboarding is incomplete should be told to complete it."""
|
||||
user_id = str(uuid.uuid4())
|
||||
mock_user.onboarding_completed = False
|
||||
mock_user_auth = MagicMock(spec=SaasUserAuth)
|
||||
mock_user_auth.get_user_id = AsyncMock(return_value=user_id)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.auth.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_auth,
|
||||
),
|
||||
patch(
|
||||
'server.routes.auth.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
),
|
||||
patch('server.routes.auth.DEPLOYMENT_MODE', 'cloud'),
|
||||
):
|
||||
result = await onboarding_status(mock_request)
|
||||
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == status.HTTP_200_OK
|
||||
body = json.loads(result.body)
|
||||
assert body == {'should_complete_onboarding': True}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_false_for_completed_user(self, mock_request, mock_user):
|
||||
"""A user who already completed onboarding should not be told to complete it."""
|
||||
user_id = str(uuid.uuid4())
|
||||
mock_user.onboarding_completed = True
|
||||
mock_user_auth = MagicMock(spec=SaasUserAuth)
|
||||
mock_user_auth.get_user_id = AsyncMock(return_value=user_id)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.auth.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_auth,
|
||||
),
|
||||
patch(
|
||||
'server.routes.auth.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
),
|
||||
):
|
||||
result = await onboarding_status(mock_request)
|
||||
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == status.HTTP_200_OK
|
||||
body = json.loads(result.body)
|
||||
assert body == {'should_complete_onboarding': False}
|
||||
|
||||
193
enterprise/tests/unit/server/routes/test_github_integration.py
Normal file
193
enterprise/tests/unit/server/routes/test_github_integration.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""Unit tests for GitHub integration routes.
|
||||
|
||||
Tests for:
|
||||
- get_github_token endpoint
|
||||
"""
|
||||
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_github_dependencies():
|
||||
"""Mock module-level dependencies before importing the github module.
|
||||
|
||||
The github.py module instantiates GitHubDataCollector at module level,
|
||||
which requires GitHub App credentials. We mock these dependencies to
|
||||
allow importing the module in test environments without credentials.
|
||||
"""
|
||||
# Store original modules if they exist
|
||||
original_modules = {}
|
||||
modules_to_mock = [
|
||||
'integrations.github.data_collector',
|
||||
'integrations.github.github_manager',
|
||||
'server.routes.integration.github',
|
||||
]
|
||||
for mod in modules_to_mock:
|
||||
if mod in sys.modules:
|
||||
original_modules[mod] = sys.modules[mod]
|
||||
del sys.modules[mod]
|
||||
|
||||
# Create mock GitHubDataCollector that doesn't require credentials
|
||||
mock_data_collector_module = MagicMock()
|
||||
mock_data_collector_instance = MagicMock()
|
||||
mock_data_collector_module.GitHubDataCollector.return_value = (
|
||||
mock_data_collector_instance
|
||||
)
|
||||
sys.modules['integrations.github.data_collector'] = mock_data_collector_module
|
||||
|
||||
# Create mock GithubManager
|
||||
mock_github_manager_module = MagicMock()
|
||||
mock_github_manager_instance = MagicMock()
|
||||
mock_github_manager_module.GithubManager.return_value = mock_github_manager_instance
|
||||
sys.modules['integrations.github.github_manager'] = mock_github_manager_module
|
||||
|
||||
yield
|
||||
|
||||
# Clean up the mocked modules
|
||||
for mod in modules_to_mock:
|
||||
if mod in sys.modules:
|
||||
del sys.modules[mod]
|
||||
|
||||
# Restore original modules
|
||||
for mod, original in original_modules.items():
|
||||
sys.modules[mod] = original
|
||||
|
||||
|
||||
class TestGitHubTokenResponse:
|
||||
"""Test suite for GitHubTokenResponse model."""
|
||||
|
||||
def test_github_token_response_with_valid_token(self):
|
||||
"""GitHubTokenResponse should accept a valid access_token."""
|
||||
from server.routes.integration.github import GitHubTokenResponse
|
||||
|
||||
response = GitHubTokenResponse(access_token='ghp_test_token_12345')
|
||||
assert response.access_token == 'ghp_test_token_12345'
|
||||
|
||||
def test_github_token_response_model_dump(self):
|
||||
"""GitHubTokenResponse model_dump should include access_token."""
|
||||
from server.routes.integration.github import GitHubTokenResponse
|
||||
|
||||
response = GitHubTokenResponse(access_token='ghp_test_token_12345')
|
||||
data = response.model_dump()
|
||||
assert data['access_token'] == 'ghp_test_token_12345'
|
||||
|
||||
|
||||
class TestGetGitHubToken:
|
||||
"""Test suite for get_github_token endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request(self):
|
||||
"""Create a mock request object."""
|
||||
request = MagicMock()
|
||||
request.state = MagicMock()
|
||||
return request
|
||||
|
||||
@pytest.fixture
|
||||
def mock_saas_user_auth(self):
|
||||
"""Create a mock SaasUserAuth object."""
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
|
||||
mock_auth = AsyncMock()
|
||||
mock_auth.get_provider_tokens = AsyncMock(
|
||||
return_value={
|
||||
ProviderType.GITHUB: ProviderToken(
|
||||
token=SecretStr('ghp_test_token_12345')
|
||||
)
|
||||
}
|
||||
)
|
||||
return mock_auth
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_github_token_success(self, mock_request, mock_saas_user_auth):
|
||||
"""Should return GitHub token when user has a valid token."""
|
||||
from server.routes.integration.github import (
|
||||
GitHubTokenResponse,
|
||||
get_github_token,
|
||||
)
|
||||
|
||||
with patch(
|
||||
'server.routes.integration.github.get_user_auth',
|
||||
return_value=mock_saas_user_auth,
|
||||
):
|
||||
result = await get_github_token(mock_request)
|
||||
|
||||
assert isinstance(result, GitHubTokenResponse)
|
||||
assert result.access_token == 'ghp_test_token_12345'
|
||||
mock_saas_user_auth.get_provider_tokens.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_github_token_no_provider_tokens(self, mock_request):
|
||||
"""Should raise 404 when user has no provider tokens."""
|
||||
from fastapi import HTTPException
|
||||
from server.routes.integration.github import get_github_token
|
||||
|
||||
mock_auth = AsyncMock()
|
||||
mock_auth.get_provider_tokens = AsyncMock(return_value=None)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.integration.github.get_user_auth',
|
||||
return_value=mock_auth,
|
||||
),
|
||||
pytest.raises(HTTPException) as exc_info,
|
||||
):
|
||||
await get_github_token(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
assert 'No provider tokens' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_github_token_no_github_token(self, mock_request):
|
||||
"""Should raise 404 when user has provider tokens but no GitHub token."""
|
||||
from fastapi import HTTPException
|
||||
from server.routes.integration.github import get_github_token
|
||||
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
|
||||
mock_auth = AsyncMock()
|
||||
# Return GitLab token but no GitHub token
|
||||
mock_auth.get_provider_tokens = AsyncMock(
|
||||
return_value={
|
||||
ProviderType.GITLAB: ProviderToken(
|
||||
token=SecretStr('glpat_test_token_12345')
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.integration.github.get_user_auth',
|
||||
return_value=mock_auth,
|
||||
),
|
||||
pytest.raises(HTTPException) as exc_info,
|
||||
):
|
||||
await get_github_token(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
assert 'No GitHub token' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_github_token_empty_provider_tokens(self, mock_request):
|
||||
"""Should raise 404 when user has empty provider tokens dict."""
|
||||
from fastapi import HTTPException
|
||||
from server.routes.integration.github import get_github_token
|
||||
|
||||
mock_auth = AsyncMock()
|
||||
mock_auth.get_provider_tokens = AsyncMock(return_value={})
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.integration.github.get_user_auth',
|
||||
return_value=mock_auth,
|
||||
),
|
||||
pytest.raises(HTTPException) as exc_info,
|
||||
):
|
||||
await get_github_token(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
# Empty dict is falsy, so it triggers the "no provider tokens" error
|
||||
assert 'No provider tokens' in exc_info.value.detail
|
||||
@@ -950,7 +950,6 @@ async def test_list_user_orgs_all_fields_present(mock_app_list):
|
||||
sandbox_runtime_container_image='test-runtime',
|
||||
org_version=5,
|
||||
max_budget_per_task=1000.0,
|
||||
enable_solvability_analysis=True,
|
||||
v1_enabled=True,
|
||||
)
|
||||
mock_user = MagicMock()
|
||||
@@ -994,7 +993,6 @@ async def test_list_user_orgs_all_fields_present(mock_app_list):
|
||||
assert org_data['sandbox_runtime_container_image'] == 'test-runtime'
|
||||
assert org_data['org_version'] == 5
|
||||
assert org_data['max_budget_per_task'] == 1000.0
|
||||
assert org_data['enable_solvability_analysis'] is True
|
||||
assert org_data['v1_enabled'] is True
|
||||
assert org_data['credits'] is None
|
||||
|
||||
@@ -3245,7 +3243,6 @@ async def test_get_org_app_settings_success(
|
||||
# Arrange
|
||||
mock_response = OrgAppSettingsResponse(
|
||||
enable_proactive_conversation_starters=True,
|
||||
enable_solvability_analysis=False,
|
||||
max_budget_per_task=10.0,
|
||||
)
|
||||
|
||||
@@ -3268,7 +3265,6 @@ async def test_get_org_app_settings_success(
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
response_data = response.json()
|
||||
assert response_data['enable_proactive_conversation_starters'] is True
|
||||
assert response_data['enable_solvability_analysis'] is False
|
||||
assert response_data['max_budget_per_task'] == 10.0
|
||||
|
||||
|
||||
@@ -3285,7 +3281,6 @@ async def test_get_org_app_settings_with_null_values(
|
||||
# OrgAppSettingsResponse.from_org() handles defaults, so we test the response model
|
||||
mock_response = OrgAppSettingsResponse(
|
||||
enable_proactive_conversation_starters=True, # Default when None in Org
|
||||
enable_solvability_analysis=None,
|
||||
max_budget_per_task=None,
|
||||
)
|
||||
|
||||
@@ -3309,7 +3304,6 @@ async def test_get_org_app_settings_with_null_values(
|
||||
response_data = response.json()
|
||||
# enable_proactive_conversation_starters defaults to True when None
|
||||
assert response_data['enable_proactive_conversation_starters'] is True
|
||||
assert response_data['enable_solvability_analysis'] is None
|
||||
assert response_data['max_budget_per_task'] is None
|
||||
|
||||
|
||||
@@ -3377,7 +3371,6 @@ async def test_update_org_app_settings_success(
|
||||
# Arrange
|
||||
mock_response = OrgAppSettingsResponse(
|
||||
enable_proactive_conversation_starters=False,
|
||||
enable_solvability_analysis=True,
|
||||
max_budget_per_task=25.0,
|
||||
)
|
||||
|
||||
@@ -3398,7 +3391,6 @@ async def test_update_org_app_settings_success(
|
||||
'/api/organizations/app',
|
||||
json={
|
||||
'enable_proactive_conversation_starters': False,
|
||||
'enable_solvability_analysis': True,
|
||||
'max_budget_per_task': 25.0,
|
||||
},
|
||||
)
|
||||
@@ -3407,7 +3399,6 @@ async def test_update_org_app_settings_success(
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
response_data = response.json()
|
||||
assert response_data['enable_proactive_conversation_starters'] is False
|
||||
assert response_data['enable_solvability_analysis'] is True
|
||||
assert response_data['max_budget_per_task'] == 25.0
|
||||
mock_update.assert_called_once()
|
||||
|
||||
@@ -3424,7 +3415,6 @@ async def test_update_org_app_settings_partial_update(
|
||||
# Arrange
|
||||
mock_response = OrgAppSettingsResponse(
|
||||
enable_proactive_conversation_starters=False,
|
||||
enable_solvability_analysis=True,
|
||||
max_budget_per_task=10.0, # Unchanged
|
||||
)
|
||||
|
||||
@@ -3467,7 +3457,6 @@ async def test_update_org_app_settings_set_null(
|
||||
# Arrange
|
||||
mock_response = OrgAppSettingsResponse(
|
||||
enable_proactive_conversation_starters=True,
|
||||
enable_solvability_analysis=True,
|
||||
max_budget_per_task=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -826,249 +826,3 @@ class TestCacheHelpers:
|
||||
service = create_service(mock_token_manager)
|
||||
# Should not raise
|
||||
await service._set_cached_value('test-key', 'test-value', 3600)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# GitLab-specific tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gitlab_group_payload():
|
||||
"""Create a sample GitLab webhook payload for a group-owned project."""
|
||||
return {
|
||||
'object_kind': 'issue',
|
||||
'event_type': 'issue',
|
||||
'project': {
|
||||
'id': 123456,
|
||||
'path_with_namespace': 'test-org/test-repo',
|
||||
'visibility': 'public',
|
||||
'namespace': {
|
||||
'id': 789,
|
||||
'kind': 'group',
|
||||
'name': 'test-org',
|
||||
},
|
||||
},
|
||||
'user': {
|
||||
'id': 12345,
|
||||
'username': 'testuser',
|
||||
},
|
||||
'object_attributes': {
|
||||
'id': 1,
|
||||
'iid': 1,
|
||||
'title': 'Test Issue',
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gitlab_user_payload():
|
||||
"""Create a sample GitLab webhook payload for a user-owned project."""
|
||||
return {
|
||||
'object_kind': 'issue',
|
||||
'event_type': 'issue',
|
||||
'project': {
|
||||
'id': 654321,
|
||||
'path_with_namespace': 'testuser/personal-repo',
|
||||
'visibility': 'private',
|
||||
'namespace': {
|
||||
'id': 12345,
|
||||
'kind': 'user',
|
||||
'name': 'testuser',
|
||||
},
|
||||
},
|
||||
'user': {
|
||||
'id': 12345,
|
||||
'username': 'testuser',
|
||||
},
|
||||
'object_attributes': {
|
||||
'id': 2,
|
||||
'iid': 1,
|
||||
'title': 'Personal Issue',
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TestExtractOwnerInfoGitLab:
|
||||
"""Tests for _extract_owner_info method with GitLab payloads."""
|
||||
|
||||
def test_extract_gitlab_group_owner(self, mock_token_manager, gitlab_group_payload):
|
||||
"""
|
||||
GIVEN: GitLab payload for a group-owned project
|
||||
WHEN: _extract_owner_info is called
|
||||
THEN: Returns correct git_org, owner_type, owner_id
|
||||
"""
|
||||
with patch('server.services.automation_event_service.sio'):
|
||||
service = create_service(mock_token_manager)
|
||||
git_org, owner_type, owner_id = service._extract_owner_info(
|
||||
ProviderType.GITLAB, gitlab_group_payload
|
||||
)
|
||||
|
||||
assert git_org == 'test-org'
|
||||
assert owner_type == 'group'
|
||||
assert owner_id == 789
|
||||
|
||||
def test_extract_gitlab_user_owner(self, mock_token_manager, gitlab_user_payload):
|
||||
"""
|
||||
GIVEN: GitLab payload for a user-owned project
|
||||
WHEN: _extract_owner_info is called
|
||||
THEN: Returns correct git_org, owner_type, owner_id
|
||||
"""
|
||||
with patch('server.services.automation_event_service.sio'):
|
||||
service = create_service(mock_token_manager)
|
||||
git_org, owner_type, owner_id = service._extract_owner_info(
|
||||
ProviderType.GITLAB, gitlab_user_payload
|
||||
)
|
||||
|
||||
assert git_org == 'testuser'
|
||||
assert owner_type == 'user'
|
||||
assert owner_id == 12345
|
||||
|
||||
def test_extract_gitlab_missing_project(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: GitLab payload without project data
|
||||
WHEN: _extract_owner_info is called
|
||||
THEN: Returns None values
|
||||
"""
|
||||
payload = {'object_kind': 'issue'}
|
||||
|
||||
with patch('server.services.automation_event_service.sio'):
|
||||
service = create_service(mock_token_manager)
|
||||
git_org, owner_type, owner_id = service._extract_owner_info(
|
||||
ProviderType.GITLAB, payload
|
||||
)
|
||||
|
||||
assert git_org is None
|
||||
assert owner_type is None
|
||||
assert owner_id is None
|
||||
|
||||
|
||||
class TestResolveOrgContextGitLab:
|
||||
"""Tests for _resolve_org_context method with GitLab payloads."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_gitlab_group_org(
|
||||
self, mock_token_manager, mock_org_git_claim, gitlab_group_payload
|
||||
):
|
||||
"""
|
||||
GIVEN: GitLab payload for a group project with claimed org
|
||||
WHEN: _resolve_org_context is called
|
||||
THEN: Returns correct OrgContext with org_id
|
||||
"""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=None) # Cache miss
|
||||
mock_redis.setex = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'server.services.automation_event_service.resolve_org_for_repo',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_org_git_claim.org_id,
|
||||
), patch('server.services.automation_event_service.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
result = await service._resolve_org_context(
|
||||
ProviderType.GITLAB, gitlab_group_payload
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.org_id == mock_org_git_claim.org_id
|
||||
assert result.git_org == 'test-org'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_gitlab_user_personal_org_fallback(
|
||||
self, mock_token_manager, mock_org_git_claim, gitlab_user_payload
|
||||
):
|
||||
"""
|
||||
GIVEN: GitLab payload for a user project (not claimed via OrgGitClaim)
|
||||
WHEN: _resolve_org_context is called
|
||||
THEN: Falls back to personal org resolution
|
||||
"""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=None) # Cache miss
|
||||
mock_redis.setex = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'server.services.automation_event_service.resolve_org_for_repo',
|
||||
new_callable=AsyncMock,
|
||||
return_value=None, # No OrgGitClaim for user
|
||||
), patch('server.services.automation_event_service.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
# Mock _resolve_personal_org to return a personal org
|
||||
service._resolve_personal_org = AsyncMock(
|
||||
return_value=mock_org_git_claim.org_id
|
||||
)
|
||||
|
||||
result = await service._resolve_org_context(
|
||||
ProviderType.GITLAB, gitlab_user_payload
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.org_id == mock_org_git_claim.org_id
|
||||
assert result.git_org == 'testuser'
|
||||
# Verify personal org fallback was called with GitLab user ID
|
||||
service._resolve_personal_org.assert_called_once_with(
|
||||
ProviderType.GITLAB, 12345
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_gitlab_no_org_found(
|
||||
self, mock_token_manager, gitlab_group_payload
|
||||
):
|
||||
"""
|
||||
GIVEN: GitLab payload for a project with no org claim and no personal org
|
||||
WHEN: _resolve_org_context is called
|
||||
THEN: Returns None
|
||||
"""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=None)
|
||||
mock_redis.setex = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'server.services.automation_event_service.resolve_org_for_repo',
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
), patch('server.services.automation_event_service.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
result = await service._resolve_org_context(
|
||||
ProviderType.GITLAB, gitlab_group_payload
|
||||
)
|
||||
|
||||
# Group owner doesn't fall back to personal org
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestResolveGitOrgGitLab:
|
||||
"""Tests for _resolve_git_org method with GitLab provider."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_gitlab_org_includes_provider_in_cache_key(
|
||||
self, mock_token_manager, mock_org_git_claim
|
||||
):
|
||||
"""
|
||||
GIVEN: GitLab provider with an org name
|
||||
WHEN: _resolve_git_org is called
|
||||
THEN: Cache key includes the gitlab provider name
|
||||
"""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=None)
|
||||
mock_redis.setex = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'server.services.automation_event_service.resolve_org_for_repo',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_org_git_claim.org_id,
|
||||
), patch('server.services.automation_event_service.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
await service._resolve_git_org(ProviderType.GITLAB, 'Test-Org')
|
||||
|
||||
# Verify cache key includes provider and normalized org name
|
||||
get_call_args = mock_redis.get.call_args[0][0]
|
||||
assert 'gitlab' in get_call_args
|
||||
assert 'test-org' in get_call_args # Lowercase normalized
|
||||
|
||||
@@ -29,7 +29,6 @@ def mock_org():
|
||||
org = MagicMock(spec=Org)
|
||||
org.id = uuid.uuid4()
|
||||
org.enable_proactive_conversation_starters = True
|
||||
org.enable_solvability_analysis = False
|
||||
org.max_budget_per_task = 25.0
|
||||
return org
|
||||
|
||||
@@ -67,7 +66,6 @@ async def test_get_org_app_settings_success(
|
||||
# Assert
|
||||
assert isinstance(result, OrgAppSettingsResponse)
|
||||
assert result.enable_proactive_conversation_starters is True
|
||||
assert result.enable_solvability_analysis is False
|
||||
assert result.max_budget_per_task == 25.0
|
||||
mock_store.get_current_org_by_user_id.assert_called_once_with(user_id)
|
||||
|
||||
|
||||
@@ -1,533 +0,0 @@
|
||||
"""
|
||||
Tests for conversation_callback_utils.py
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from server.utils.conversation_callback_utils import update_active_working_seconds
|
||||
from storage.conversation_work import ConversationWork
|
||||
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
from openhands.storage.files import FileStore
|
||||
|
||||
|
||||
class TestUpdateActiveWorkingSeconds:
    """Test the update_active_working_seconds function.

    Every test hands a mocked event store to the function under test and
    then inspects the ``ConversationWork`` row (or the logger) afterwards.
    """

    @staticmethod
    def _state_events(*transitions):
        """Build mocked state-change observations.

        Args:
            *transitions: ``(agent_state, timestamp)`` pairs, in event order.

        Returns:
            A list with one ``Mock(spec=AgentStateChangedObservation)`` per pair.
        """
        observations = []
        for agent_state, timestamp in transitions:
            observation = Mock(spec=AgentStateChangedObservation)
            observation.agent_state = agent_state
            observation.timestamp = timestamp
            observations.append(observation)
        return observations

    @staticmethod
    def _run_update(session_maker, event_store, conversation_id, user_id, file_store):
        """Invoke the function under test with the module's session_maker patched."""
        with patch(
            'server.utils.conversation_callback_utils.session_maker', session_maker
        ):
            update_active_working_seconds(
                event_store, conversation_id, user_id, file_store
            )

    @staticmethod
    def _find_work(session, conversation_id):
        """Return the first ConversationWork row for ``conversation_id`` (or None)."""
        query = session.query(ConversationWork)
        query = query.filter(ConversationWork.conversation_id == conversation_id)
        return query.first()

    @pytest.fixture
    def mock_file_store(self):
        """Provide a Mock constrained to the FileStore interface."""
        return Mock(spec=FileStore)

    @pytest.fixture
    def mock_event_store(self):
        """Provide an unconstrained Mock standing in for the event store."""
        return Mock()

    def test_update_active_working_seconds_multiple_state_changes(
        self, session_maker, mock_event_store, mock_file_store
    ):
        """Several RUNNING periods separated by other states are summed."""
        conversation_id = 'test_conversation_123'
        user_id = 'test_user_456'

        # Timeline (seconds): RUNNING 0-10, RUNNING 15-25, RUNNING 30-40,
        # each closed by a non-running state -> 30 seconds expected in total.
        mock_event_store.search_events.return_value = self._state_events(
            (AgentState.RUNNING, '1970-01-01T00:00:00.000000'),
            (AgentState.AWAITING_USER_INPUT, '1970-01-01T00:00:10.000000'),
            (AgentState.RUNNING, '1970-01-01T00:00:15.000000'),
            (AgentState.FINISHED, '1970-01-01T00:00:25.000000'),
            (AgentState.RUNNING, '1970-01-01T00:00:30.000000'),
            (AgentState.PAUSED, '1970-01-01T00:00:40.000000'),
        )

        self._run_update(
            session_maker, mock_event_store, conversation_id, user_id, mock_file_store
        )

        # A ConversationWork row must now exist with the summed duration.
        with session_maker() as session:
            work = self._find_work(session, conversation_id)

            assert work is not None
            assert work.conversation_id == conversation_id
            assert work.user_id == user_id
            assert work.seconds == 30.0  # Total running time
            assert work.created_at is not None
            assert work.updated_at is not None

    def test_update_active_working_seconds_updates_existing_record(
        self, session_maker, mock_event_store, mock_file_store
    ):
        """A pre-existing ConversationWork row is updated rather than left as-is."""
        conversation_id = 'test_conversation_456'
        user_id = 'test_user_789'

        # Seed a row carrying the previous value.
        with session_maker() as session:
            session.add(
                ConversationWork(
                    conversation_id=conversation_id,
                    user_id=user_id,
                    seconds=15.0,  # Previous value
                )
            )
            session.commit()

        # One 20-second running period.
        mock_event_store.search_events.return_value = self._state_events(
            (AgentState.RUNNING, '1970-01-01T00:00:00.000000'),
            (AgentState.STOPPED, '1970-01-01T00:00:20.000000'),
        )

        self._run_update(
            session_maker, mock_event_store, conversation_id, user_id, mock_file_store
        )

        with session_maker() as session:
            work = self._find_work(session, conversation_id)

            assert work is not None
            assert work.seconds == 20.0  # Updated value
            assert work.user_id == user_id

    def test_update_active_working_seconds_agent_still_running(
        self, session_maker, mock_event_store, mock_file_store
    ):
        """A trailing RUNNING state with no closing event contributes no time."""
        conversation_id = 'test_conversation_789'
        user_id = 'test_user_012'

        # The last event is RUNNING; nothing follows it.
        mock_event_store.search_events.return_value = self._state_events(
            (AgentState.RUNNING, '1970-01-01T00:00:00.000000'),
            (AgentState.AWAITING_USER_INPUT, '1970-01-01T00:00:10.000000'),
            (AgentState.RUNNING, '1970-01-01T00:00:15.000000'),
        )

        self._run_update(
            session_maker, mock_event_store, conversation_id, user_id, mock_file_store
        )

        with session_maker() as session:
            work = self._find_work(session, conversation_id)

            assert work is not None
            assert work.seconds == 10.0  # Only the first completed period

    def test_update_active_working_seconds_no_running_states(
        self, session_maker, mock_event_store, mock_file_store
    ):
        """With no RUNNING state at all, zero seconds are stored."""
        conversation_id = 'test_conversation_000'
        user_id = 'test_user_000'

        mock_event_store.search_events.return_value = self._state_events(
            (AgentState.LOADING, '1970-01-01T00:00:00.000000'),
            (AgentState.AWAITING_USER_INPUT, '1970-01-01T00:00:05.000000'),
            (AgentState.FINISHED, '1970-01-01T00:00:10.000000'),
        )

        self._run_update(
            session_maker, mock_event_store, conversation_id, user_id, mock_file_store
        )

        with session_maker() as session:
            work = self._find_work(session, conversation_id)

            assert work is not None
            assert work.seconds == 0.0

    def test_update_active_working_seconds_mixed_event_types(
        self, session_maker, mock_event_store, mock_file_store
    ):
        """Events that are not AgentStateChangedObservation do not affect the total."""
        conversation_id = 'test_conversation_mixed'
        user_id = 'test_user_mixed'

        started, stopped = self._state_events(
            (AgentState.RUNNING, '1970-01-01T00:00:00.000000'),
            (AgentState.STOPPED, '1970-01-01T00:00:10.000000'),
        )

        # Plain mocks (no spec): not AgentStateChangedObservation instances.
        unrelated = []
        for timestamp in ('1970-01-01T00:00:05.000000', '1970-01-01T00:00:08.000000'):
            other_event = Mock()
            other_event.timestamp = timestamp
            unrelated.append(other_event)

        mock_event_store.search_events.return_value = [started, *unrelated, stopped]

        self._run_update(
            session_maker, mock_event_store, conversation_id, user_id, mock_file_store
        )

        with session_maker() as session:
            work = self._find_work(session, conversation_id)

            assert work is not None
            assert work.seconds == 10.0  # Only the valid state changes

    @patch('server.utils.conversation_callback_utils.logger')
    def test_update_active_working_seconds_handles_exceptions(
        self, mock_logger, session_maker, mock_event_store, mock_file_store
    ):
        """A failing event store is swallowed and reported through logger.error."""
        conversation_id = 'test_conversation_error'
        user_id = 'test_user_error'

        # Make the very first store access blow up.
        mock_event_store.search_events.side_effect = Exception('Test error')

        # No exception may escape the call.
        update_active_working_seconds(
            mock_event_store, conversation_id, user_id, mock_file_store
        )

        # Exactly one error log entry with the expected message and extras.
        mock_logger.error.assert_called_once()
        positional, keyword = mock_logger.error.call_args
        assert positional[0] == 'failed_to_update_active_working_seconds'
        assert keyword['extra']['conversation_id'] == conversation_id
        assert keyword['extra']['user_id'] == user_id
        assert 'Test error' in keyword['extra']['error']

    def test_update_active_working_seconds_complex_state_transitions(
        self, session_maker, mock_event_store, mock_file_store
    ):
        """Running periods closed by error, rate-limit and confirmation states are summed."""
        conversation_id = 'test_conversation_complex'
        user_id = 'test_user_complex'

        mock_event_store.search_events.return_value = self._state_events(
            (AgentState.LOADING, '1970-01-01T00:00:00.000000'),
            # First running period: 5 seconds
            (AgentState.RUNNING, '1970-01-01T00:00:02.000000'),
            (AgentState.ERROR, '1970-01-01T00:00:07.000000'),
            # Second running period: 8 seconds
            (AgentState.RUNNING, '1970-01-01T00:00:10.000000'),
            (AgentState.RATE_LIMITED, '1970-01-01T00:00:18.000000'),
            # Third running period: 3 seconds
            (AgentState.RUNNING, '1970-01-01T00:00:20.000000'),
            (AgentState.AWAITING_USER_CONFIRMATION, '1970-01-01T00:00:23.000000'),
            (AgentState.USER_CONFIRMED, '1970-01-01T00:00:25.000000'),
            # Fourth running period: 7 seconds
            (AgentState.RUNNING, '1970-01-01T00:00:30.000000'),
            (AgentState.FINISHED, '1970-01-01T00:00:37.000000'),
        )

        self._run_update(
            session_maker, mock_event_store, conversation_id, user_id, mock_file_store
        )

        # Running periods: 5 + 8 + 3 + 7 = 23 seconds
        with session_maker() as session:
            work = self._find_work(session, conversation_id)

            assert work is not None
            assert work.seconds == 23.0
            assert work.conversation_id == conversation_id
            assert work.user_id == user_id
|
||||
|
||||
|
||||
class TestInvokeConversationCallbacks:
    """Tests for invoke_conversation_callbacks function.

    The tests swap ``a_session_maker`` in the module under test for a mocked
    async context manager, so no real database session is opened.
    """

    @pytest.fixture
    def mock_observation(self):
        """Provide a mocked AgentStateChangedObservation in the FINISHED state."""
        finished = Mock(spec=AgentStateChangedObservation)
        finished.agent_state = AgentState.FINISHED
        return finished

    @pytest.fixture
    def create_mock_async_session(self):
        """Provide a factory returning ``(session_maker_replacement, session)``.

        The returned session's ``execute`` resolves to a result whose
        ``scalars().all()`` yields the callbacks given to the factory.
        """
        from contextlib import asynccontextmanager
        from unittest.mock import AsyncMock

        def _create(callbacks_list):
            query_result = Mock()
            query_result.scalars.return_value.all.return_value = callbacks_list

            session = Mock()
            session.execute = AsyncMock(return_value=query_result)
            session.commit = AsyncMock(return_value=None)

            @asynccontextmanager
            async def session_scope():
                yield session

            return session_scope, session

        return _create

    @pytest.mark.asyncio
    async def test_invoke_callbacks_with_active_callbacks(
        self, mock_observation, create_mock_async_session
    ):
        """An active callback has its processor fetched and awaited once."""
        from unittest.mock import AsyncMock

        # Arrange
        conversation_id = 'test_conversation_callbacks'
        processor = AsyncMock(return_value=None)

        callback = Mock()
        callback.id = 1
        callback.processor_type = 'test_processor'
        callback.get_processor.return_value = processor

        session_scope, _session = create_mock_async_session([callback])

        # Act
        with patch(
            'server.utils.conversation_callback_utils.a_session_maker',
            session_scope,
        ):
            from server.utils.conversation_callback_utils import (
                invoke_conversation_callbacks,
            )

            await invoke_conversation_callbacks(conversation_id, mock_observation)

        # Assert
        callback.get_processor.assert_called_once()
        processor.assert_called_once_with(callback, mock_observation)

    @pytest.mark.asyncio
    async def test_invoke_callbacks_with_no_active_callbacks(
        self, mock_observation, create_mock_async_session
    ):
        """With no callbacks returned, the call completes and commits once."""
        # Arrange
        conversation_id = 'test_no_callbacks'
        session_scope, session = create_mock_async_session([])

        # Act
        with patch(
            'server.utils.conversation_callback_utils.a_session_maker',
            session_scope,
        ):
            from server.utils.conversation_callback_utils import (
                invoke_conversation_callbacks,
            )

            await invoke_conversation_callbacks(conversation_id, mock_observation)

        # Assert - should complete without errors
        session.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_invoke_callbacks_handles_processor_exception(
        self, mock_observation, create_mock_async_session
    ):
        """A raising processor is logged and its callback is marked as ERROR."""
        from unittest.mock import AsyncMock

        # Arrange
        conversation_id = 'test_callback_error'
        failing_processor = AsyncMock(side_effect=Exception('Processor error'))

        callback = Mock()
        callback.id = 1
        callback.processor_type = 'failing_processor'
        callback.get_processor.return_value = failing_processor
        callback.status = 'active'

        session_scope, _session = create_mock_async_session([callback])

        # Act
        with patch(
            'server.utils.conversation_callback_utils.a_session_maker',
            session_scope,
        ):
            with patch('server.utils.conversation_callback_utils.logger') as mock_logger:
                from server.utils.conversation_callback_utils import (
                    invoke_conversation_callbacks,
                )
                from storage.conversation_callback import CallbackStatus

                await invoke_conversation_callbacks(conversation_id, mock_observation)

        # Assert - callback status should be set to ERROR
        assert callback.status == CallbackStatus.ERROR
        mock_logger.error.assert_called_once()
        logged_args = mock_logger.error.call_args[0]
        assert logged_args[0] == 'callback_invocation_failed'
|
||||
@@ -1,113 +0,0 @@
|
||||
"""
|
||||
Shared fixtures for all tests.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from integrations.solvability.models.classifier import SolvabilityClassifier
|
||||
from integrations.solvability.models.featurizer import (
|
||||
Feature,
|
||||
FeatureEmbedding,
|
||||
Featurizer,
|
||||
)
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
|
||||
|
||||
@pytest.fixture
def features() -> list[Feature]:
    """Provide three simple ``Feature`` objects (feature1..feature3) for tests."""
    specs = (
        ('feature1', 'Test feature 1'),
        ('feature2', 'Test feature 2'),
        ('feature3', 'Test feature 3'),
    )
    return [
        Feature(identifier=identifier, description=description)
        for identifier, description in specs
    ]
|
||||
|
||||
|
||||
@pytest.fixture
def feature_embedding() -> FeatureEmbedding:
    """Provide a ``FeatureEmbedding`` with two samples and fixed cost figures."""
    first_sample = {'feature1': True, 'feature2': False, 'feature3': True}
    second_sample = {'feature1': False, 'feature2': True, 'feature3': True}

    embedding = FeatureEmbedding(
        samples=[first_sample, second_sample],
        prompt_tokens=10,
        completion_tokens=5,
        response_latency=0.1,
    )
    return embedding
|
||||
|
||||
|
||||
@pytest.fixture
def featurizer(mock_llm, features, monkeypatch) -> Featurizer:
    """
    Create a featurizer for testing.

    ``integrations.solvability.models.featurizer.LLM`` is replaced by a
    factory that ignores its arguments and always returns ``mock_llm``, which
    mocks out any calls to LLM.completion.

    The replacement goes through pytest's ``monkeypatch`` fixture so that it is
    reverted automatically when the requesting test finishes. The previous
    ad-hoc ``pytest.MonkeyPatch()`` instance was never undone, which left the
    patched ``LLM`` in place for every test that ran afterwards.

    Args:
        mock_llm: The mocked LLM instance returned by the patched factory.
        features: The features the featurizer is configured with.
        monkeypatch: pytest's built-in patching fixture.

    Returns:
        A ``Featurizer`` built with fixed test prompts and ``features``.
    """
    monkeypatch.setattr(
        'integrations.solvability.models.featurizer.LLM',
        lambda *args, **kwargs: mock_llm,
    )

    return Featurizer(
        system_prompt='Test system prompt',
        message_prefix='Test message prefix: ',
        features=features,
    )
|
||||
|
||||
|
||||
@pytest.fixture
def mock_completion_response() -> MagicMock:
    """Create a mock response for the feature sample model.

    The returned object is a ``MagicMock`` (not a dict, as the previous return
    annotation claimed) with:

    * ``choices[0].message.tool_calls[0].function.arguments`` set to a JSON
      string assigning booleans to feature1..feature3, and
    * ``usage.prompt_tokens`` / ``usage.completion_tokens`` set to 10 and 5.
    """
    tool_call = MagicMock()
    tool_call.function.arguments = (
        '{"feature1": true, "feature2": false, "feature3": true}'
    )

    choice = MagicMock()
    choice.message.tool_calls = [tool_call]

    mock_response = MagicMock()
    mock_response.choices = [choice]
    mock_response.usage.prompt_tokens = 10
    mock_response.usage.completion_tokens = 5
    return mock_response
|
||||
|
||||
|
||||
@pytest.fixture
def mock_llm(mock_completion_response):
    """Provide a mocked LLM whose ``completion()`` returns the canned response."""
    # Dotted keyword configuration sets completion.return_value at creation.
    return MagicMock(**{'completion.return_value': mock_completion_response})
|
||||
|
||||
|
||||
@pytest.fixture
def mock_llm_config():
    """Provide an ``LLMConfig`` that names a placeholder model for tests."""
    config = LLMConfig(model='test-model')
    return config
|
||||
|
||||
|
||||
@pytest.fixture
def mock_classifier():
    """Provide a ``RandomForestClassifier`` already fitted on two dummy rows."""
    forest = RandomForestClassifier(random_state=42)

    # Fit on minimal dummy data up front to avoid errors in tests that use it.
    dummy_rows = np.array([[0, 0, 0], [1, 1, 1]])
    dummy_labels = np.array([0, 1])
    forest.fit(dummy_rows, dummy_labels)

    return forest
|
||||
|
||||
|
||||
@pytest.fixture
def solvability_classifier(featurizer, mock_classifier):
    """Provide a ``SolvabilityClassifier`` wired to the test featurizer and forest."""
    settings = {
        'identifier': 'test-classifier',
        'featurizer': featurizer,
        'classifier': mock_classifier,
        'random_state': 42,
    }
    return SolvabilityClassifier(**settings)
|
||||
@@ -1,218 +0,0 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from integrations.solvability.models.classifier import SolvabilityClassifier
|
||||
from integrations.solvability.models.featurizer import Feature
|
||||
from integrations.solvability.models.importance_strategy import ImportanceStrategy
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
|
||||
|
||||
@pytest.mark.parametrize('random_state', [None, 42])
def test_random_state_initialization(random_state, featurizer):
    """The solvability classifier's random state ends up on the wrapped RFC."""
    # Two starting points for the RFC seed:
    #   * None         -> the solvability classifier propagates its own seed down;
    #   * random_state -> the seeds already agree, so initialization succeeds.
    for forest_seed in (None, random_state):
        wrapped = SolvabilityClassifier(
            identifier='test',
            featurizer=featurizer,
            classifier=RandomForestClassifier(random_state=forest_seed),
            random_state=random_state,
        )

        assert wrapped.random_state == random_state
        assert wrapped.classifier.random_state == random_state
|
||||
|
||||
|
||||
def test_inconsistent_random_state(featurizer):
    """Construction raises ValueError when the RFC seed and classifier seed differ."""
    seeded_forest = RandomForestClassifier(random_state=42)
    conflicting_settings = {
        'identifier': 'test',
        'featurizer': featurizer,
        'classifier': seeded_forest,
        'random_state': 24,
    }

    with pytest.raises(ValueError):
        SolvabilityClassifier(**conflicting_settings)
|
||||
|
||||
|
||||
def test_transform_produces_feature_columns(solvability_classifier, mock_llm_config):
    """transform() returns a DataFrame with one column per configured feature."""
    frame = solvability_classifier.transform(
        pd.Series(['Test issue']), llm_config=mock_llm_config
    )

    assert isinstance(frame, pd.DataFrame)

    missing = [
        feature.identifier
        for feature in solvability_classifier.featurizer.features
        if feature.identifier not in frame.columns
    ]
    assert not missing
|
||||
|
||||
|
||||
def test_transform_sets_classifier_attrs(solvability_classifier, mock_llm_config):
    """transform() populates the ``features_`` and ``cost_`` attributes."""
    transformed = solvability_classifier.transform(
        pd.Series(['Test issue']), llm_config=mock_llm_config
    )

    # features_ must mirror what transform() handed back.
    np.testing.assert_array_equal(transformed, solvability_classifier.features_)

    # cost_ must be a DataFrame carrying every expected cost column.
    cost = solvability_classifier.cost_
    assert cost is not None
    assert isinstance(cost, pd.DataFrame)
    for column in ('prompt_tokens', 'completion_tokens', 'response_latency'):
        assert column in solvability_classifier.cost_.columns
|
||||
|
||||
|
||||
def test_fit_sets_classifier_attrs(solvability_classifier, mock_llm_config):
    """fit() records ``feature_importances_`` on the classifier."""
    training_issues = pd.Series(['Test issue'])
    training_labels = pd.Series([1])

    solvability_classifier.fit(
        training_issues, training_labels, llm_config=mock_llm_config
    )

    # The importances must be registered and exposed as an ndarray.
    importances = solvability_classifier.feature_importances_
    assert 'feature_importances_' in solvability_classifier._classifier_attrs
    assert isinstance(importances, np.ndarray)
|
||||
|
||||
|
||||
def test_predict_proba_sets_classifier_attrs(solvability_classifier, mock_llm_config):
    """predict_proba() records ``feature_importances_`` as a side effect."""
    # Only the side effects matter here, so the probabilities are discarded.
    solvability_classifier.predict_proba(
        pd.Series(['Test issue']), llm_config=mock_llm_config
    )

    importances = solvability_classifier.feature_importances_
    assert 'feature_importances_' in solvability_classifier._classifier_attrs
    assert isinstance(importances, np.ndarray)
|
||||
|
||||
|
||||
def test_predict_sets_classifier_attrs(solvability_classifier, mock_llm_config):
    """predict() records ``feature_importances_`` as a side effect."""
    # Only the side effects matter here, so the predictions are discarded.
    solvability_classifier.predict(
        pd.Series(['Test issue']), llm_config=mock_llm_config
    )

    importances = solvability_classifier.feature_importances_
    assert 'feature_importances_' in solvability_classifier._classifier_attrs
    assert isinstance(importances, np.ndarray)
|
||||
|
||||
|
||||
def test_add_single_feature(solvability_classifier):
    """add_features() with one feature makes it part of the featurizer."""
    extra = Feature(identifier='new_feature', description='New test feature')

    # Precondition: the feature is not configured yet.
    assert extra not in solvability_classifier.featurizer.features

    solvability_classifier.add_features([extra])

    assert extra in solvability_classifier.featurizer.features
|
||||
|
||||
|
||||
def test_add_multiple_features(solvability_classifier):
    """add_features() with several features makes all of them available."""
    extras = [
        Feature(identifier='new_feature_1', description='New test feature 1'),
        Feature(identifier='new_feature_2', description='New test feature 2'),
    ]

    # Precondition: none of the new features is configured yet.
    for extra in extras:
        assert extra not in solvability_classifier.featurizer.features

    solvability_classifier.add_features(extras)

    for extra in extras:
        assert extra in solvability_classifier.featurizer.features
|
||||
|
||||
|
||||
def test_add_features_idempotency(solvability_classifier):
    """Adding an already-present feature again leaves the feature count unchanged."""
    extra = Feature(identifier='new_feature', description='New test feature')

    # First insertion establishes the baseline count.
    solvability_classifier.add_features([extra])
    count_after_first_add = len(solvability_classifier.featurizer.features)

    # Second insertion of the same feature must not grow the list.
    solvability_classifier.add_features([extra])
    count_after_second_add = len(solvability_classifier.featurizer.features)

    assert count_after_second_add == count_after_first_add
|
||||
|
||||
|
||||
@pytest.mark.parametrize('strategy', list(ImportanceStrategy))
def test_importance_strategies(strategy, solvability_classifier, mock_llm_config):
    """Every ImportanceStrategy yields NaN-free feature importances after fit()."""
    training_issues = pd.Series(['Test issue', 'Another test issue'])
    training_labels = pd.Series([1, 0])

    # Select the strategy under test before fitting.
    solvability_classifier.importance_strategy = strategy

    # Fitting is what triggers computation of feature_importances_.
    solvability_classifier.fit(
        training_issues, training_labels, llm_config=mock_llm_config
    )

    assert 'feature_importances_' in solvability_classifier._classifier_attrs
    assert isinstance(solvability_classifier.feature_importances_, np.ndarray)

    # The importances must contain actual values, i.e. no NaN entries.
    has_nan = np.isnan(solvability_classifier.feature_importances_).any()
    assert not has_nan
|
||||
|
||||
|
||||
def test_is_fitted_property(solvability_classifier, mock_llm_config):
    """``is_fitted`` is False for a fresh RFC and True once fit() has run."""
    training_issues = pd.Series(['Test issue', 'Another test issue'])
    training_labels = pd.Series([1, 0])

    # Swap in a brand-new, unfitted forest so the starting state is known.
    unfitted_forest = RandomForestClassifier(random_state=42)
    solvability_classifier.classifier = unfitted_forest
    assert not solvability_classifier.is_fitted

    solvability_classifier.fit(
        training_issues, training_labels, llm_config=mock_llm_config
    )
    assert solvability_classifier.is_fitted
|
||||
|
||||
|
||||
def test_solvability_report_well_formed(solvability_classifier, mock_llm_config):
    """Generate a SolvabilityReport and confirm every field carries a sane value."""
    issues = pd.Series(['Test issue', 'Another test issue'])
    labels = pd.Series([1, 0])

    # The classifier is trained first, then asked to report on the first issue.
    solvability_classifier.fit(issues, labels, llm_config=mock_llm_config)
    target_issue = issues.iloc[0]
    report = solvability_classifier.solvability_report(
        target_issue, llm_config=mock_llm_config
    )

    # Building the report already exercises the pydantic validators; the checks
    # below additionally compare the field values against the classifier.
    assert report.identifier == solvability_classifier.identifier
    assert report.issue == target_issue
    assert 0 <= report.score <= 1
    assert report.samples == solvability_classifier.samples

    expected_feature_ids = set(solvability_classifier.featurizer.feature_identifiers())
    assert set(report.features.keys()) == expected_feature_ids
    assert report.importance_strategy == solvability_classifier.importance_strategy
    assert set(report.feature_importances.keys()) == expected_feature_ids

    assert report.random_state == solvability_classifier.random_state
    assert report.created_at is not None

    # Usage counters and latency can never be negative.
    assert report.prompt_tokens >= 0
    assert report.completion_tokens >= 0
    assert report.response_latency >= 0

    assert report.metadata is None
|
||||
@@ -1,130 +0,0 @@
|
||||
"""
|
||||
Unit tests for data loading functionality in solvability/data.
|
||||
"""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from integrations.solvability.data import available_classifiers, load_classifier
|
||||
from integrations.solvability.models.classifier import SolvabilityClassifier
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
def test_load_classifier_default():
    """Load the default classifier and check its basic attributes."""
    name = 'default-classifier'
    loaded = load_classifier(name)

    assert isinstance(loaded, SolvabilityClassifier)
    assert loaded.identifier == name
    # Both the featurizer and the underlying model must be populated.
    assert loaded.featurizer is not None
    assert loaded.classifier is not None
|
||||
|
||||
|
||||
def test_load_classifier_not_found():
    """An unknown classifier name must raise FileNotFoundError with a clear message."""
    with pytest.raises(FileNotFoundError) as raised:
        load_classifier('non-existent-classifier')

    error_text = str(raised.value)
    assert "Classifier 'non-existent-classifier' not found" in error_text
|
||||
|
||||
|
||||
def test_available_classifiers():
    """The classifier listing is a non-empty list that includes the default one."""
    names = available_classifiers()

    assert isinstance(names, list)
    assert 'default-classifier' in names
    assert len(names) >= 1
|
||||
|
||||
|
||||
def test_load_classifier_with_mock_data(solvability_classifier):
    """Serialize the fixture classifier to a temp directory and load it back."""
    with tempfile.TemporaryDirectory() as tmpdir:
        data_dir = Path(tmpdir)
        serialized = solvability_classifier.model_dump_json()
        (data_dir / 'test-classifier.json').write_text(serialized)

        # Replace Path inside the data module so that ``Path(...).parent``
        # yields the temporary directory instead of the package directory.
        with patch('integrations.solvability.data.Path') as mock_path:
            mock_path.return_value.parent = data_dir

            loaded = load_classifier('test-classifier')

            assert isinstance(loaded, SolvabilityClassifier)
            assert loaded.identifier == 'test-classifier'
|
||||
|
||||
|
||||
def test_available_classifiers_with_mock_directory():
    """Only ``.json`` files in the (mocked) data directory are listed, by stem."""
    with tempfile.TemporaryDirectory() as tmpdir:
        data_dir = Path(tmpdir)

        # Two JSON files that should be listed plus one decoy that should not.
        for filename in ('classifier1.json', 'classifier2.json', 'not-a-json.txt'):
            (data_dir / filename).touch()

        with patch('integrations.solvability.data.Path') as mock_path:
            mock_path.return_value.parent = data_dir

            names = available_classifiers()

            assert len(names) == 2
            assert 'classifier1' in names
            assert 'classifier2' in names
            assert 'not-a-json' not in names
|
||||
|
||||
|
||||
def test_load_classifier_invalid_json():
    """A classifier file holding malformed JSON must raise a ValidationError."""
    with tempfile.TemporaryDirectory() as tmpdir:
        data_dir = Path(tmpdir)
        (data_dir / 'invalid-classifier.json').write_text('{ invalid json content')

        # Point the data module's Path lookup at the temporary directory.
        with patch('integrations.solvability.data.Path') as mock_path:
            mock_path.return_value.parent = data_dir

            with pytest.raises(ValidationError):
                load_classifier('invalid-classifier')
|
||||
|
||||
|
||||
def test_load_classifier_valid_json_invalid_schema():
    """Well-formed JSON that is not a classifier must raise a ValidationError."""
    with tempfile.TemporaryDirectory() as tmpdir:
        data_dir = Path(tmpdir)
        payload = json.dumps({'not': 'a valid classifier'})
        (data_dir / 'invalid-schema.json').write_text(payload)

        # Point the data module's Path lookup at the temporary directory.
        with patch('integrations.solvability.data.Path') as mock_path:
            mock_path.return_value.parent = data_dir

            with pytest.raises(ValidationError):
                load_classifier('invalid-schema')
|
||||
|
||||
|
||||
def test_available_classifiers_empty_directory():
    """Listing classifiers in an empty (mocked) data directory yields an empty list."""
    with tempfile.TemporaryDirectory() as tmpdir:
        empty_dir = Path(tmpdir)

        with patch('integrations.solvability.data.Path') as mock_path:
            mock_path.return_value.parent = empty_dir

            names = available_classifiers()

            assert names == []
|
||||
|
||||
|
||||
def test_load_classifier_path_construction():
    """The file looked up is ``<name>.json`` joined onto the data directory."""
    with patch('integrations.solvability.data.Path') as mock_path:
        data_dir = mock_path.return_value.parent
        # Whatever path results from ``data_dir / ...`` reports that it is missing.
        data_dir.__truediv__.return_value.exists.return_value = False

        with pytest.raises(FileNotFoundError):
            load_classifier('test-name')

        # Exactly one join, using the classifier name plus the .json suffix.
        data_dir.__truediv__.assert_called_once_with('test-name.json')
|
||||
@@ -1,266 +0,0 @@
|
||||
import pytest
|
||||
from integrations.solvability.models.featurizer import Feature, FeatureEmbedding
|
||||
|
||||
|
||||
def test_feature_to_tool_description_field():
    """The tool-description field is a boolean carrying the feature's description."""
    description = 'Test description'
    field = Feature(
        identifier='test', description=description
    ).to_tool_description_field

    # The structure is minimal: a fixed type plus the propagated description.
    assert field['type'] == 'boolean'
    assert field['description'] == description
|
||||
|
||||
|
||||
def test_feature_embedding_dimensions(feature_embedding):
    """``dimensions`` is a list naming exactly the three fixture features."""
    dims = feature_embedding.dimensions

    assert isinstance(dims, list)
    assert set(dims) == {'feature1', 'feature2', 'feature3'}
|
||||
|
||||
|
||||
def test_feature_embedding_coefficients(feature_embedding):
    """``coefficient`` returns the expected value per feature, or None if unknown."""
    # Expected values were worked out by hand from the fixture's samples.
    expected = {'feature1': 0.5, 'feature2': 0.5, 'feature3': 1.0}
    for name, value in expected.items():
        assert feature_embedding.coefficient(name) == value

    # A feature that never appears has no coefficient at all.
    assert feature_embedding.coefficient('non_existent') is None
|
||||
|
||||
|
||||
def test_featurizer_system_message(featurizer):
    """``system_message`` yields a system-role message with the fixture prompt."""
    system_msg = featurizer.system_message()

    assert system_msg['role'] == 'system'
    assert system_msg['content'] == 'Test system prompt'
|
||||
|
||||
|
||||
def test_featurizer_user_message(featurizer):
    """``user_message`` prefixes the issue and adds cache control only on request."""
    expected_content = 'Test message prefix: Test issue'

    # Caching requested: the message carries an ephemeral cache_control entry.
    cached = featurizer.user_message('Test issue', set_cache=True)
    assert cached['role'] == 'user'
    assert cached['content'] == expected_content
    assert 'cache_control' in cached
    assert cached['cache_control']['type'] == 'ephemeral'

    # Caching not requested: same role and content, no cache_control key.
    uncached = featurizer.user_message('Test issue', set_cache=False)
    assert uncached['role'] == 'user'
    assert uncached['content'] == expected_content
    assert 'cache_control' not in uncached
|
||||
|
||||
|
||||
def test_featurizer_tool_choice(featurizer):
    """``tool_choice`` selects the ``call_featurizer`` function."""
    choice = featurizer.tool_choice

    assert choice['type'] == 'function'
    assert choice['function']['name'] == 'call_featurizer'
|
||||
|
||||
|
||||
def test_featurizer_tool_description(featurizer):
    """``tool_description`` describes ``call_featurizer`` with one boolean per feature."""
    tool_desc = featurizer.tool_description
    assert tool_desc['type'] == 'function'

    function_spec = tool_desc['function']
    assert function_spec['name'] == 'call_featurizer'
    assert 'description' in function_spec

    # Every feature must appear as a boolean property with its own description.
    properties = function_spec['parameters']['properties']
    for feature in featurizer.features:
        assert feature.identifier in properties
        entry = properties[feature.identifier]
        assert entry['type'] == 'boolean'
        assert entry['description'] == feature.description
|
||||
|
||||
|
||||
@pytest.mark.parametrize('samples', [1, 10, 100])
def test_featurizer_embed(samples, featurizer, mock_llm_config):
    """``embed`` returns the requested number of samples with consistent metadata."""
    embedding = featurizer.embed(
        'Test issue', llm_config=mock_llm_config, samples=samples
    )

    # One sample per requested draw.
    assert len(embedding.samples) == samples

    # The mocked completion makes every sample identical, with known values.
    first = embedding.samples[0]
    assert all(sample == first for sample in embedding.samples)
    assert first['feature1'] is True
    assert first['feature2'] is False
    assert first['feature3'] is True

    # Token counts are mocked per call, so they scale linearly with the samples.
    assert embedding.prompt_tokens == 10 * samples
    assert embedding.completion_tokens == 5 * samples

    # Latency is measured for real; all that can be checked is that it is positive.
    assert embedding.response_latency > 0.0
|
||||
|
||||
|
||||
@pytest.mark.parametrize('samples', [1, 10, 100])
@pytest.mark.parametrize('batch_size', [1, 10, 100])
def test_featurizer_embed_batch(samples, batch_size, featurizer, mock_llm_config):
    """``embed_batch`` returns one well-formed embedding for every issue in the batch."""
    batch = [f'Issue {i}' for i in range(batch_size)]
    embeddings = featurizer.embed_batch(
        batch,
        llm_config=mock_llm_config,
        samples=samples,
    )

    # There must be exactly one embedding per issue.
    assert len(embeddings) == batch_size

    # The embeddings come from a mocked completion call, so they are all the
    # same; each one gets the same checks as in `test_featurizer_embed`.
    for embedding in embeddings:
        first = embedding.samples[0]
        assert all(sample == first for sample in embedding.samples)
        assert first['feature1'] is True
        assert first['feature2'] is False
        assert first['feature3'] is True

        assert len(embedding.samples) == samples
        assert embedding.prompt_tokens == 10 * samples
        assert embedding.completion_tokens == 5 * samples
        assert embedding.response_latency >= 0.0
|
||||
|
||||
|
||||
def test_featurizer_embed_batch_thread_safety(featurizer, mock_llm_config, monkeypatch):
    """Test embed_batch keeps results aligned with their issues under concurrency.

    The ``LLM`` name in the featurizer module is replaced with a stub whose
    ``completion`` returns a response unique to each issue and sleeps longer for
    earlier issues, so results that complete out of submission order would be
    detected by the per-index assertions below.
    """
    import threading
    import time
    from unittest.mock import MagicMock

    # Large enough batch to stress concurrency; it also drives the delays below.
    batch_size = 20

    # Create unique responses for each issue to verify ordering
    def create_mock_response(issue_index):
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.tool_calls = [MagicMock()]
        # Each issue gets a unique feature pattern based on its index
        mock_response.choices[0].message.tool_calls[0].function.arguments = (
            f'{{"feature1": {str(issue_index % 2 == 0).lower()}, '
            f'"feature2": {str(issue_index % 3 == 0).lower()}, '
            f'"feature3": {str(issue_index % 5 == 0).lower()}}}'
        )
        mock_response.usage.prompt_tokens = 10 + issue_index
        mock_response.usage.completion_tokens = 5 + issue_index
        return mock_response

    # Bookkeeping shared by every call to the stub. ``call_count += 1`` is a
    # read-modify-write, so it is guarded by a lock in case the stub is invoked
    # from several worker threads at once (presumably how embed_batch runs it).
    call_count = 0
    call_order = []
    bookkeeping_lock = threading.Lock()

    def mock_completion(*args, **kwargs):
        nonlocal call_count
        # Extract issue index from the message content
        messages = kwargs.get('messages', args[0] if args else [])
        message_content = messages[1]['content']
        issue_index = int(message_content.split('Issue ')[-1])

        with bookkeeping_lock:
            call_order.append(issue_index)

        # Later issues finish sooner, so completion order is the reverse of
        # submission order; this is what exposes ordering bugs.
        time.sleep(0.01 * (batch_size - issue_index))

        with bookkeeping_lock:
            call_count += 1
        return create_mock_response(issue_index)

    def mock_llm_class(*args, **kwargs):
        mock_llm_instance = MagicMock()
        mock_llm_instance.completion = mock_completion
        return mock_llm_instance

    monkeypatch.setattr(
        'integrations.solvability.models.featurizer.LLM', mock_llm_class
    )

    issues = [f'Issue {i}' for i in range(batch_size)]

    embeddings = featurizer.embed_batch(issues, llm_config=mock_llm_config, samples=1)

    # Verify we got all embeddings
    assert len(embeddings) == batch_size

    # Verify each embedding corresponds to its correct issue index
    for i, embedding in enumerate(embeddings):
        assert len(embedding.samples) == 1
        sample = embedding.samples[0]

        # Check the unique pattern matches the issue index
        assert sample['feature1'] == (i % 2 == 0)
        assert sample['feature2'] == (i % 3 == 0)
        assert sample['feature3'] == (i % 5 == 0)

        # Check token counts match
        assert embedding.prompt_tokens == 10 + i
        assert embedding.completion_tokens == 5 + i

    # Verify all issues were processed
    assert call_count == batch_size
    assert len(set(call_order)) == batch_size  # All unique indices
|
||||
|
||||
|
||||
def test_featurizer_embed_batch_exception_handling(
    featurizer, mock_llm_config, monkeypatch
):
    """A failure inside any single completion must propagate out of embed_batch."""
    from unittest.mock import MagicMock

    failing_indices = [2, 5, 7]

    def mock_completion(*args, **kwargs):
        # The issue index is recovered from the text of the second message.
        messages = kwargs.get('messages', args[0] if args else [])
        message_content = messages[1]['content']
        issue_index = int(message_content.split('Issue ')[-1])

        # A handful of issues are made to blow up.
        if issue_index in failing_indices:
            raise ValueError(f'Simulated error for issue {issue_index}')

        # Everything else receives an ordinary canned response.
        tool_call = MagicMock()
        tool_call.function.arguments = (
            '{"feature1": true, "feature2": false, "feature3": true}'
        )
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.tool_calls = [tool_call]
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 5
        return mock_response

    def mock_llm_class(*args, **kwargs):
        mock_llm_instance = MagicMock()
        mock_llm_instance.completion = mock_completion
        return mock_llm_instance

    monkeypatch.setattr(
        'integrations.solvability.models.featurizer.LLM', mock_llm_class
    )

    issues = [f'Issue {i}' for i in range(10)]

    # One failing task is enough for the whole batch call to raise.
    with pytest.raises(ValueError) as raised:
        featurizer.embed_batch(issues, llm_config=mock_llm_config, samples=1)

    # The error must be one of the simulated ones.
    assert 'Simulated error for issue' in str(raised.value)
|
||||
|
||||
|
||||
def test_featurizer_embed_batch_no_none_values(featurizer, mock_llm_config):
    """Every entry returned by embed_batch is a real FeatureEmbedding, never None."""
    # Several batch sizes are tried so that no None entry can slip through.
    for batch_size in (1, 5, 10, 20):
        batch = [f'Issue {i}' for i in range(batch_size)]
        embeddings = featurizer.embed_batch(
            batch, llm_config=mock_llm_config, samples=1
        )

        # First rule out None entries, then confirm the concrete type.
        for embedding in embeddings:
            assert embedding is not None
        for embedding in embeddings:
            assert isinstance(embedding, FeatureEmbedding)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user