mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7fbb48c406 |
12
.github/CODEOWNERS
vendored
12
.github/CODEOWNERS
vendored
@@ -1,8 +1,12 @@
|
||||
# CODEOWNERS file for OpenHands repository
|
||||
# See https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners
|
||||
|
||||
/frontend/ @amanape @hieptl
|
||||
/openhands-ui/ @amanape @hieptl
|
||||
/openhands/ @tofarr @malhotra5 @hieptl
|
||||
/enterprise/ @chuckbutkus @tofarr @malhotra5
|
||||
# Frontend code owners
|
||||
/frontend/ @amanape
|
||||
/openhands-ui/ @amanape
|
||||
|
||||
# Evaluation code owners
|
||||
/evaluation/ @xingyaoww @neubig
|
||||
|
||||
# Documentation code owners
|
||||
/docs/ @mamoodi
|
||||
|
||||
13
.github/scripts/update_pr_description.sh
vendored
13
.github/scripts/update_pr_description.sh
vendored
@@ -17,6 +17,9 @@ DOCKER_RUN_COMMAND="docker run -it --rm \
|
||||
--name openhands-app-${SHORT_SHA} \
|
||||
docker.openhands.dev/openhands/openhands:${SHORT_SHA}"
|
||||
|
||||
# Define the uvx command
|
||||
UVX_RUN_COMMAND="uvx --python 3.12 --from git+https://github.com/OpenHands/OpenHands@${BRANCH_NAME}#subdirectory=openhands-cli openhands"
|
||||
|
||||
# Get the current PR body
|
||||
PR_BODY=$(gh pr view "$PR_NUMBER" --json body --jq .body)
|
||||
|
||||
@@ -34,6 +37,11 @@ GUI with Docker:
|
||||
\`\`\`
|
||||
${DOCKER_RUN_COMMAND}
|
||||
\`\`\`
|
||||
|
||||
CLI with uvx:
|
||||
\`\`\`
|
||||
${UVX_RUN_COMMAND}
|
||||
\`\`\`
|
||||
EOF
|
||||
)
|
||||
else
|
||||
@@ -49,6 +57,11 @@ GUI with Docker:
|
||||
\`\`\`
|
||||
${DOCKER_RUN_COMMAND}
|
||||
\`\`\`
|
||||
|
||||
CLI with uvx:
|
||||
\`\`\`
|
||||
${UVX_RUN_COMMAND}
|
||||
\`\`\`
|
||||
EOF
|
||||
)
|
||||
fi
|
||||
|
||||
2
.github/workflows/check-package-versions.yml
vendored
2
.github/workflows/check-package-versions.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
|
||||
69
.github/workflows/clean-up.yml
vendored
Normal file
69
.github/workflows/clean-up.yml
vendored
Normal file
@@ -0,0 +1,69 @@
|
||||
# Workflow that cleans up outdated and old workflows to prevent out of disk issues
|
||||
name: Delete old workflow runs
|
||||
|
||||
# This workflow is currently only triggered manually
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
days:
|
||||
description: 'Days-worth of runs to keep for each workflow'
|
||||
required: true
|
||||
default: '30'
|
||||
minimum_runs:
|
||||
description: 'Minimum runs to keep for each workflow'
|
||||
required: true
|
||||
default: '10'
|
||||
delete_workflow_pattern:
|
||||
description: 'Name or filename of the workflow (if not set, all workflows are targeted)'
|
||||
required: false
|
||||
delete_workflow_by_state_pattern:
|
||||
description: 'Filter workflows by state: active, deleted, disabled_fork, disabled_inactivity, disabled_manually'
|
||||
required: true
|
||||
default: "ALL"
|
||||
type: choice
|
||||
options:
|
||||
- "ALL"
|
||||
- active
|
||||
- deleted
|
||||
- disabled_inactivity
|
||||
- disabled_manually
|
||||
delete_run_by_conclusion_pattern:
|
||||
description: 'Remove runs based on conclusion: action_required, cancelled, failure, skipped, success'
|
||||
required: true
|
||||
default: 'ALL'
|
||||
type: choice
|
||||
options:
|
||||
- 'ALL'
|
||||
- 'Unsuccessful: action_required,cancelled,failure,skipped'
|
||||
- action_required
|
||||
- cancelled
|
||||
- failure
|
||||
- skipped
|
||||
- success
|
||||
dry_run:
|
||||
description: 'Logs simulated changes, no deletions are performed'
|
||||
required: false
|
||||
|
||||
jobs:
|
||||
del_runs:
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
permissions:
|
||||
actions: write
|
||||
contents: read
|
||||
steps:
|
||||
- name: Delete workflow runs
|
||||
uses: Mattraks/delete-workflow-runs@v2
|
||||
with:
|
||||
token: ${{ github.token }}
|
||||
repository: ${{ github.repository }}
|
||||
retain_days: ${{ github.event.inputs.days }}
|
||||
keep_minimum_runs: ${{ github.event.inputs.minimum_runs }}
|
||||
delete_workflow_pattern: ${{ github.event.inputs.delete_workflow_pattern }}
|
||||
delete_workflow_by_state_pattern: ${{ github.event.inputs.delete_workflow_by_state_pattern }}
|
||||
delete_run_by_conclusion_pattern: >-
|
||||
${{
|
||||
startsWith(github.event.inputs.delete_run_by_conclusion_pattern, 'Unsuccessful:')
|
||||
&& 'action_required,cancelled,failure,skipped'
|
||||
|| github.event.inputs.delete_run_by_conclusion_pattern
|
||||
}}
|
||||
dry_run: ${{ github.event.inputs.dry_run }}
|
||||
122
.github/workflows/cli-build-binary-and-optionally-release.yml
vendored
Normal file
122
.github/workflows/cli-build-binary-and-optionally-release.yml
vendored
Normal file
@@ -0,0 +1,122 @@
|
||||
# Workflow that builds and tests the CLI binary executable
|
||||
name: CLI - Build binary and optionally release
|
||||
|
||||
# Run on pushes to main branch and CLI tags, and on pull requests when CLI files change
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
tags:
|
||||
- "*-cli"
|
||||
pull_request:
|
||||
paths:
|
||||
- "openhands-cli/**"
|
||||
|
||||
permissions:
|
||||
contents: write # needed to create releases or upload assets
|
||||
|
||||
# Cancel previous runs if a new commit is pushed
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ (github.head_ref && github.ref) || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build-binary:
|
||||
name: Build binary executable
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
# Build on Ubuntu 22.04 for maximum GLIBC compatibility (GLIBC 2.31)
|
||||
- os: ubuntu-22.04
|
||||
platform: linux
|
||||
artifact_name: openhands-cli-linux
|
||||
# Build on macOS for macOS users
|
||||
- os: macos-15
|
||||
platform: macos
|
||||
artifact_name: openhands-cli-macos
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.12
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v3
|
||||
with:
|
||||
version: "latest"
|
||||
|
||||
- name: Install dependencies
|
||||
working-directory: openhands-cli
|
||||
run: |
|
||||
uv sync
|
||||
|
||||
- name: Build binary executable
|
||||
working-directory: openhands-cli
|
||||
run: |
|
||||
./build.sh --install-pyinstaller | tee output.log
|
||||
echo "Full output:"
|
||||
cat output.log
|
||||
|
||||
if grep -q "❌" output.log; then
|
||||
echo "❌ Found failure marker in output"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "✅ Build & test finished without ❌ markers"
|
||||
|
||||
- name: Verify binary files exist
|
||||
run: |
|
||||
if ! ls openhands-cli/dist/openhands* 1> /dev/null 2>&1; then
|
||||
echo "❌ No binaries found to upload!"
|
||||
exit 1
|
||||
fi
|
||||
echo "✅ Found binaries to upload."
|
||||
|
||||
- name: Upload binary artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: ${{ matrix.artifact_name }}
|
||||
path: openhands-cli/dist/openhands*
|
||||
retention-days: 30
|
||||
|
||||
create-github-release:
|
||||
name: Create GitHub Release
|
||||
runs-on: ubuntu-latest
|
||||
needs: build-binary
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Download all artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
path: artifacts
|
||||
|
||||
- name: Prepare release assets
|
||||
run: |
|
||||
mkdir -p release-assets
|
||||
# Copy binaries with appropriate names for release
|
||||
if [ -f artifacts/openhands-cli-linux/openhands ]; then
|
||||
cp artifacts/openhands-cli-linux/openhands release-assets/openhands-linux
|
||||
fi
|
||||
if [ -f artifacts/openhands-cli-macos/openhands ]; then
|
||||
cp artifacts/openhands-cli-macos/openhands release-assets/openhands-macos
|
||||
fi
|
||||
ls -la release-assets/
|
||||
|
||||
- name: Create GitHub Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
with:
|
||||
files: release-assets/*
|
||||
draft: true
|
||||
prerelease: false
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
23
.github/workflows/dispatch-to-docs.yml
vendored
Normal file
23
.github/workflows/dispatch-to-docs.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
name: Dispatch to docs repo
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
paths:
|
||||
- 'docs/**'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
dispatch:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
repo: ["OpenHands/docs"]
|
||||
steps:
|
||||
- name: Push to docs repo
|
||||
uses: peter-evans/repository-dispatch@v3
|
||||
with:
|
||||
token: ${{ secrets.ALLHANDS_BOT_GITHUB_PAT }}
|
||||
repository: ${{ matrix.repo }}
|
||||
event-type: update
|
||||
client-payload: '{"ref": "${{ github.ref }}", "sha": "${{ github.sha }}", "module": "openhands", "branch": "main"}'
|
||||
8
.github/workflows/e2e-tests.yml
vendored
8
.github/workflows/e2e-tests.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
poetry-version: 2.1.3
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
cache: 'poetry'
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
sudo apt-get install -y libgtk-3-0 libnotify4 libnss3 libxss1 libxtst6 xauth xvfb libgbm1 libasound2t64 netcat-openbsd
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '22'
|
||||
cache: 'npm'
|
||||
@@ -192,7 +192,7 @@ jobs:
|
||||
|
||||
- name: Upload test results
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-report
|
||||
path: tests/e2e/test-results/
|
||||
@@ -200,7 +200,7 @@ jobs:
|
||||
|
||||
- name: Upload OpenHands logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: openhands-logs
|
||||
path: |
|
||||
|
||||
@@ -43,7 +43,7 @@ jobs:
|
||||
⚠️ This PR contains **migrations**
|
||||
|
||||
- name: Comment warning on PR
|
||||
uses: peter-evans/create-or-update-comment@v5
|
||||
uses: peter-evans/create-or-update-comment@v4
|
||||
with:
|
||||
issue-number: ${{ github.event.pull_request.number }}
|
||||
comment-id: ${{ steps.find-comment.outputs.comment-id }}
|
||||
|
||||
47
.github/workflows/fe-e2e-tests.yml
vendored
47
.github/workflows/fe-e2e-tests.yml
vendored
@@ -1,47 +0,0 @@
|
||||
# Workflow that runs frontend e2e tests with Playwright
|
||||
name: Run Frontend E2E Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
paths:
|
||||
- "frontend/**"
|
||||
- ".github/workflows/fe-e2e-tests.yml"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ (github.head_ref && github.ref) || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
fe-e2e-test:
|
||||
name: FE E2E Tests
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
strategy:
|
||||
matrix:
|
||||
node-version: [22]
|
||||
fail-fast: true
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up Node.js
|
||||
uses: useblacksmith/setup-node@v5
|
||||
with:
|
||||
node-version: ${{ matrix.node-version }}
|
||||
- name: Install dependencies
|
||||
working-directory: ./frontend
|
||||
run: npm ci
|
||||
- name: Install Playwright browsers
|
||||
working-directory: ./frontend
|
||||
run: npx playwright install --with-deps chromium
|
||||
- name: Run Playwright tests
|
||||
working-directory: ./frontend
|
||||
run: npx playwright test --project=chromium
|
||||
- name: Upload Playwright report
|
||||
uses: actions/upload-artifact@v6
|
||||
if: always()
|
||||
with:
|
||||
name: playwright-report
|
||||
path: frontend/playwright-report/
|
||||
retention-days: 30
|
||||
10
.github/workflows/ghcr-build.yml
vendored
10
.github/workflows/ghcr-build.yml
vendored
@@ -64,7 +64,7 @@ jobs:
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha }}
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3.7.0
|
||||
uses: docker/setup-qemu-action@v3.6.0
|
||||
with:
|
||||
image: tonistiigi/binfmt:latest
|
||||
- name: Login to GHCR
|
||||
@@ -102,7 +102,7 @@ jobs:
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha }}
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3.7.0
|
||||
uses: docker/setup-qemu-action@v3.6.0
|
||||
with:
|
||||
image: tonistiigi/binfmt:latest
|
||||
- name: Login to GHCR
|
||||
@@ -161,7 +161,7 @@ jobs:
|
||||
context: containers/runtime
|
||||
- name: Upload runtime source for fork
|
||||
if: github.event.pull_request.head.repo.fork
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: runtime-src-${{ matrix.base_image.tag }}
|
||||
path: containers/runtime
|
||||
@@ -268,7 +268,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Download runtime source for fork
|
||||
if: github.event.pull_request.head.repo.fork
|
||||
uses: actions/download-artifact@v6
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: runtime-src-${{ matrix.base_image.tag }}
|
||||
path: containers/runtime
|
||||
@@ -330,7 +330,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Download runtime source for fork
|
||||
if: github.event.pull_request.head.repo.fork
|
||||
uses: actions/download-artifact@v6
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: runtime-src-${{ matrix.base_image.tag }}
|
||||
path: containers/runtime
|
||||
|
||||
18
.github/workflows/lint.yml
vendored
18
.github/workflows/lint.yml
vendored
@@ -72,3 +72,21 @@ jobs:
|
||||
- name: Run pre-commit hooks
|
||||
working-directory: ./enterprise
|
||||
run: pre-commit run --all-files --show-diff-on-failure --config ./dev_config/python/.pre-commit-config.yaml
|
||||
|
||||
lint-cli-python:
|
||||
name: Lint CLI python
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Set up python
|
||||
uses: useblacksmith/setup-python@v6
|
||||
with:
|
||||
python-version: 3.12
|
||||
cache: "pip"
|
||||
- name: Install pre-commit
|
||||
run: pip install pre-commit==4.2.0
|
||||
- name: Run pre-commit hooks
|
||||
working-directory: ./openhands-cli
|
||||
run: pre-commit run --all-files --config ./dev_config/python/.pre-commit-config.yaml
|
||||
|
||||
70
.github/workflows/mdx-lint.yml
vendored
Normal file
70
.github/workflows/mdx-lint.yml
vendored
Normal file
@@ -0,0 +1,70 @@
|
||||
# Workflow that checks MDX format in docs/ folder
|
||||
name: MDX Lint
|
||||
|
||||
# Run on pushes to main and on pull requests that modify docs/ files
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- 'docs/**/*.mdx'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'docs/**/*.mdx'
|
||||
|
||||
# If triggered by a PR, it will be in the same group. However, each commit on main will be in its own unique group
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ (github.head_ref && github.ref) || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
mdx-lint:
|
||||
name: Lint MDX files
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Install Node.js 22
|
||||
uses: useblacksmith/setup-node@v5
|
||||
with:
|
||||
node-version: 22
|
||||
|
||||
- name: Install MDX dependencies
|
||||
run: |
|
||||
npm install @mdx-js/mdx@3 glob@10
|
||||
|
||||
- name: Validate MDX files
|
||||
run: |
|
||||
node -e "
|
||||
const {compile} = require('@mdx-js/mdx');
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const glob = require('glob');
|
||||
|
||||
async function validateMDXFiles() {
|
||||
const files = glob.sync('docs/**/*.mdx');
|
||||
console.log('Found', files.length, 'MDX files to validate');
|
||||
|
||||
let hasErrors = false;
|
||||
|
||||
for (const file of files) {
|
||||
try {
|
||||
const content = fs.readFileSync(file, 'utf8');
|
||||
await compile(content);
|
||||
console.log('✅ MDX parsing successful for', file);
|
||||
} catch (err) {
|
||||
console.error('❌ MDX parsing failed for', file, ':', err.message);
|
||||
hasErrors = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (hasErrors) {
|
||||
console.error('\\n❌ Some MDX files have parsing errors. Please fix them before merging.');
|
||||
process.exit(1);
|
||||
} else {
|
||||
console.log('\\n✅ All MDX files are valid!');
|
||||
}
|
||||
}
|
||||
|
||||
validateMDXFiles();
|
||||
"
|
||||
6
.github/workflows/openhands-resolver.yml
vendored
6
.github/workflows/openhands-resolver.yml
vendored
@@ -89,7 +89,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
- name: Upgrade pip
|
||||
@@ -118,7 +118,7 @@ jobs:
|
||||
contains(github.event.review.body, '@openhands-agent-exp')
|
||||
)
|
||||
)
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ env.pythonLocation }}/lib/python3.12/site-packages/*
|
||||
key: ${{ runner.os }}-pip-openhands-resolver-${{ hashFiles('/tmp/requirements.txt') }}
|
||||
@@ -269,7 +269,7 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Upload output.jsonl as artifact
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
if: always() # Upload even if the previous steps fail
|
||||
with:
|
||||
name: resolver-output
|
||||
|
||||
56
.github/workflows/py-tests.yml
vendored
56
.github/workflows/py-tests.yml
vendored
@@ -63,7 +63,7 @@ jobs:
|
||||
env:
|
||||
COVERAGE_FILE: ".coverage.runtime.${{ matrix.python_version }}"
|
||||
- name: Store coverage file
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: coverage-openhands
|
||||
path: |
|
||||
@@ -95,17 +95,62 @@ jobs:
|
||||
env:
|
||||
COVERAGE_FILE: ".coverage.enterprise.${{ matrix.python_version }}"
|
||||
- name: Store coverage file
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: coverage-enterprise
|
||||
path: ".coverage.enterprise.${{ matrix.python_version }}"
|
||||
include-hidden-files: true
|
||||
|
||||
# Run CLI unit tests
|
||||
test-cli-python:
|
||||
name: CLI Unit Tests
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2404
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.12"]
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python
|
||||
uses: useblacksmith/setup-python@v6
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v3
|
||||
with:
|
||||
version: "latest"
|
||||
|
||||
- name: Install dependencies
|
||||
working-directory: ./openhands-cli
|
||||
run: |
|
||||
uv sync --group dev
|
||||
|
||||
- name: Run CLI unit tests
|
||||
working-directory: ./openhands-cli
|
||||
env:
|
||||
# write coverage to repo root so the merge step finds it
|
||||
COVERAGE_FILE: "${{ github.workspace }}/.coverage.openhands-cli.${{ matrix.python-version }}"
|
||||
run: |
|
||||
uv run pytest --forked -n auto -s \
|
||||
-p no:ddtrace -p no:ddtrace.pytest_bdd -p no:ddtrace.pytest_benchmark \
|
||||
tests --cov=openhands_cli --cov-branch
|
||||
|
||||
- name: Store coverage file
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: coverage-openhands-cli
|
||||
path: ".coverage.openhands-cli.${{ matrix.python-version }}"
|
||||
include-hidden-files: true
|
||||
|
||||
coverage-comment:
|
||||
name: Coverage Comment
|
||||
if: github.event_name == 'pull_request'
|
||||
runs-on: ubuntu-latest
|
||||
needs: [test-on-linux, test-enterprise]
|
||||
needs: [test-on-linux, test-enterprise, test-cli-python]
|
||||
|
||||
permissions:
|
||||
pull-requests: write
|
||||
@@ -113,12 +158,15 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/download-artifact@v6
|
||||
- uses: actions/download-artifact@v5
|
||||
id: download
|
||||
with:
|
||||
pattern: coverage-*
|
||||
merge-multiple: true
|
||||
|
||||
- name: Create symlink for CLI source files
|
||||
run: ln -sf openhands-cli/openhands_cli openhands_cli
|
||||
|
||||
- name: Coverage comment
|
||||
id: coverage_comment
|
||||
uses: py-cov-action/python-coverage-comment-action@v3
|
||||
|
||||
34
.github/workflows/pypi-release.yml
vendored
34
.github/workflows/pypi-release.yml
vendored
@@ -10,6 +10,7 @@ on:
|
||||
type: choice
|
||||
options:
|
||||
- app server
|
||||
- cli
|
||||
default: app server
|
||||
push:
|
||||
tags:
|
||||
@@ -38,3 +39,36 @@ jobs:
|
||||
run: ./build.sh
|
||||
- name: publish
|
||||
run: poetry publish -u __token__ -p ${{ secrets.PYPI_TOKEN }}
|
||||
|
||||
release-cli:
|
||||
name: Publish CLI to PyPI
|
||||
runs-on: ubuntu-latest
|
||||
# Run when manually dispatched for "cli" OR for tag pushes that contain '-cli'
|
||||
if: |
|
||||
(github.event_name == 'workflow_dispatch' && github.event.inputs.reason == 'cli')
|
||||
|| (github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') && contains(github.ref, '-cli'))
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.12
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v3
|
||||
with:
|
||||
version: "latest"
|
||||
|
||||
- name: Build CLI package
|
||||
working-directory: openhands-cli
|
||||
run: |
|
||||
# Clean dist directory to avoid conflicts with binary builds
|
||||
rm -rf dist/
|
||||
uv build
|
||||
|
||||
- name: Publish CLI to PyPI
|
||||
working-directory: openhands-cli
|
||||
run: |
|
||||
uv publish --token ${{ secrets.PYPI_TOKEN_OPENHANDS }}
|
||||
|
||||
135
.github/workflows/run-eval.yml
vendored
Normal file
135
.github/workflows/run-eval.yml
vendored
Normal file
@@ -0,0 +1,135 @@
|
||||
# Run evaluation on a PR, after releases, or manually
|
||||
name: Run Eval
|
||||
|
||||
# Runs when a PR is labeled with one of the "run-eval-" labels, after releases, or manually triggered
|
||||
on:
|
||||
pull_request:
|
||||
types: [labeled]
|
||||
release:
|
||||
types: [published]
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
branch:
|
||||
description: 'Branch to evaluate'
|
||||
required: true
|
||||
default: 'main'
|
||||
eval_instances:
|
||||
description: 'Number of evaluation instances'
|
||||
required: true
|
||||
default: '50'
|
||||
type: choice
|
||||
options:
|
||||
- '1'
|
||||
- '2'
|
||||
- '50'
|
||||
- '100'
|
||||
reason:
|
||||
description: 'Reason for manual trigger'
|
||||
required: false
|
||||
default: ''
|
||||
|
||||
env:
|
||||
# Environment variable for the master GitHub issue number where all evaluation results will be commented
|
||||
# This should be set to the issue number where you want all evaluation results to be posted
|
||||
MASTER_EVAL_ISSUE_NUMBER: ${{ vars.MASTER_EVAL_ISSUE_NUMBER || '0' }}
|
||||
|
||||
jobs:
|
||||
trigger-job:
|
||||
name: Trigger remote eval job
|
||||
if: ${{ (github.event_name == 'pull_request' && (github.event.label.name == 'run-eval-1' || github.event.label.name == 'run-eval-2' || github.event.label.name == 'run-eval-50' || github.event.label.name == 'run-eval-100')) || github.event_name == 'release' || github.event_name == 'workflow_dispatch' }}
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
|
||||
steps:
|
||||
- name: Checkout branch
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event_name == 'pull_request' && github.head_ref || (github.event_name == 'workflow_dispatch' && github.event.inputs.branch) || github.ref }}
|
||||
|
||||
- name: Set evaluation parameters
|
||||
id: eval_params
|
||||
run: |
|
||||
REPO_URL="https://github.com/${{ github.repository }}"
|
||||
echo "Repository URL: $REPO_URL"
|
||||
|
||||
# Determine branch based on trigger type
|
||||
if [[ "${{ github.event_name }}" == "pull_request" ]]; then
|
||||
EVAL_BRANCH="${{ github.head_ref }}"
|
||||
echo "PR Branch: $EVAL_BRANCH"
|
||||
elif [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then
|
||||
EVAL_BRANCH="${{ github.event.inputs.branch }}"
|
||||
echo "Manual Branch: $EVAL_BRANCH"
|
||||
else
|
||||
# For release events, use the tag name or main branch
|
||||
EVAL_BRANCH="${{ github.ref_name }}"
|
||||
echo "Release Branch/Tag: $EVAL_BRANCH"
|
||||
fi
|
||||
|
||||
# Determine evaluation instances based on trigger type
|
||||
if [[ "${{ github.event_name }}" == "pull_request" ]]; then
|
||||
if [[ "${{ github.event.label.name }}" == "run-eval-1" ]]; then
|
||||
EVAL_INSTANCES="1"
|
||||
elif [[ "${{ github.event.label.name }}" == "run-eval-2" ]]; then
|
||||
EVAL_INSTANCES="2"
|
||||
elif [[ "${{ github.event.label.name }}" == "run-eval-50" ]]; then
|
||||
EVAL_INSTANCES="50"
|
||||
elif [[ "${{ github.event.label.name }}" == "run-eval-100" ]]; then
|
||||
EVAL_INSTANCES="100"
|
||||
fi
|
||||
elif [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then
|
||||
EVAL_INSTANCES="${{ github.event.inputs.eval_instances }}"
|
||||
else
|
||||
# For release events, default to 50 instances
|
||||
EVAL_INSTANCES="50"
|
||||
fi
|
||||
|
||||
echo "Evaluation instances: $EVAL_INSTANCES"
|
||||
echo "repo_url=$REPO_URL" >> $GITHUB_OUTPUT
|
||||
echo "eval_branch=$EVAL_BRANCH" >> $GITHUB_OUTPUT
|
||||
echo "eval_instances=$EVAL_INSTANCES" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Trigger remote job
|
||||
run: |
|
||||
# Determine PR number for the remote evaluation system
|
||||
if [[ "${{ github.event_name }}" == "pull_request" ]]; then
|
||||
PR_NUMBER="${{ github.event.pull_request.number }}"
|
||||
else
|
||||
# For non-PR triggers, use the master issue number as PR number
|
||||
PR_NUMBER="${{ env.MASTER_EVAL_ISSUE_NUMBER }}"
|
||||
fi
|
||||
|
||||
curl -X POST \
|
||||
-H "Authorization: Bearer ${{ secrets.PAT_TOKEN }}" \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-d "{\"ref\": \"main\", \"inputs\": {\"github-repo\": \"${{ steps.eval_params.outputs.repo_url }}\", \"github-branch\": \"${{ steps.eval_params.outputs.eval_branch }}\", \"pr-number\": \"${PR_NUMBER}\", \"eval-instances\": \"${{ steps.eval_params.outputs.eval_instances }}\"}}" \
|
||||
https://api.github.com/repos/OpenHands/evaluation/actions/workflows/create-branch.yml/dispatches
|
||||
|
||||
# Send Slack message
|
||||
if [[ "${{ github.event_name }}" == "pull_request" ]]; then
|
||||
TRIGGER_URL="https://github.com/${{ github.repository }}/pull/${{ github.event.pull_request.number }}"
|
||||
slack_text="PR $TRIGGER_URL has triggered evaluation on ${{ steps.eval_params.outputs.eval_instances }} instances..."
|
||||
elif [[ "${{ github.event_name }}" == "release" ]]; then
|
||||
TRIGGER_URL="https://github.com/${{ github.repository }}/releases/tag/${{ github.ref_name }}"
|
||||
slack_text="Release $TRIGGER_URL has triggered evaluation on ${{ steps.eval_params.outputs.eval_instances }} instances..."
|
||||
else
|
||||
TRIGGER_URL="https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"
|
||||
slack_text="Manual trigger (${{ github.event.inputs.reason || 'No reason provided' }}) has triggered evaluation on ${{ steps.eval_params.outputs.eval_instances }} instances for branch ${{ steps.eval_params.outputs.eval_branch }}..."
|
||||
fi
|
||||
|
||||
curl -X POST -H 'Content-type: application/json' --data '{"text":"'"$slack_text"'"}' \
|
||||
https://hooks.slack.com/services/${{ secrets.SLACK_TOKEN }}
|
||||
|
||||
- name: Comment on issue/PR
|
||||
uses: KeisukeYamashita/create-comment@v1
|
||||
with:
|
||||
# For PR triggers, comment on the PR. For other triggers, comment on the master issue
|
||||
number: ${{ github.event_name == 'pull_request' && github.event.pull_request.number || env.MASTER_EVAL_ISSUE_NUMBER }}
|
||||
unique: false
|
||||
comment: |
|
||||
**Evaluation Triggered**
|
||||
|
||||
**Trigger:** ${{ github.event_name == 'pull_request' && format('Pull Request #{0}', github.event.pull_request.number) || (github.event_name == 'release' && 'Release') || format('Manual Trigger: {0}', github.event.inputs.reason || 'No reason provided') }}
|
||||
**Branch:** ${{ steps.eval_params.outputs.eval_branch }}
|
||||
**Instances:** ${{ steps.eval_params.outputs.eval_instances }}
|
||||
**Commit:** ${{ github.sha }}
|
||||
|
||||
Running evaluation on the specified branch. Once eval is done, the results will be posted here.
|
||||
6
.github/workflows/vscode-extension-build.yml
vendored
6
.github/workflows/vscode-extension-build.yml
vendored
@@ -37,7 +37,7 @@ jobs:
|
||||
node-version: '22'
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
@@ -70,7 +70,7 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Upload VSCode extension artifact
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: vscode-extension
|
||||
path: openhands/integrations/vscode/openhands-vscode-0.0.1.vsix
|
||||
@@ -142,7 +142,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Download .vsix artifact
|
||||
uses: actions/download-artifact@v6
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: vscode-extension
|
||||
path: ./
|
||||
|
||||
@@ -63,7 +63,7 @@ Frontend:
|
||||
- We use TanStack Query (fka React Query) for data fetching and cache management
|
||||
- Data Access Layer: API client methods are located in `frontend/src/api` and should never be called directly from UI components - they must always be wrapped with TanStack Query
|
||||
- Custom hooks are located in `frontend/src/hooks/query/` and `frontend/src/hooks/mutation/`
|
||||
- Query hooks should follow the pattern use[Resource] (e.g., `useConversationSkills`)
|
||||
- Query hooks should follow the pattern use[Resource] (e.g., `useConversationMicroagents`)
|
||||
- Mutation hooks should follow the pattern use[Action] (e.g., `useDeleteConversation`)
|
||||
- Architecture rule: UI components → TanStack Query hooks → Data Access Layer (`frontend/src/api`) → API endpoints
|
||||
|
||||
|
||||
60
COMMUNITY.md
60
COMMUNITY.md
@@ -1,45 +1,43 @@
|
||||
# The OpenHands Community
|
||||
# 🙌 The OpenHands Community
|
||||
|
||||
OpenHands is a community of engineers, academics, and enthusiasts reimagining software development for an AI-powered world.
|
||||
The OpenHands community is built around the belief that (1) AI and AI agents are going to fundamentally change the way
|
||||
we build software, and (2) if this is true, we should do everything we can to make sure that the benefits provided by
|
||||
such powerful technology are accessible to everyone.
|
||||
|
||||
## Mission
|
||||
If this resonates with you, we'd love to have you join us in our quest!
|
||||
|
||||
It’s very clear that AI is changing software development. We want the developer community to drive that change organically, through open source.
|
||||
## 🤝 How to Join
|
||||
|
||||
So we’re not just building friendly interfaces for AI-driven development. We’re publishing _building blocks_ that empower developers to create new experiences, tailored to your own habits, needs, and imagination.
|
||||
Check out our [How to Join the Community section.](https://github.com/OpenHands/OpenHands?tab=readme-ov-file#-how-to-join-the-community)
|
||||
|
||||
## Ethos
|
||||
## 💪 Becoming a Contributor
|
||||
|
||||
We have two core values: **high openness** and **high agency**. While we don’t expect everyone in the community to embody these values, we want to establish them as norms.
|
||||
We welcome contributions from everyone! Whether you're a developer, a researcher, or simply enthusiastic about advancing
|
||||
the field of software engineering with AI, there are many ways to get involved:
|
||||
|
||||
### High Openness
|
||||
- **Code Contributions:** Help us develop new core functionality, improve our agents, improve the frontend and other
|
||||
interfaces, or anything else that would help make OpenHands better.
|
||||
- **Research and Evaluation:** Contribute to our understanding of LLMs in software engineering, participate in
|
||||
evaluating the models, or suggest improvements.
|
||||
- **Feedback and Testing:** Use the OpenHands toolset, report bugs, suggest features, or provide feedback on usability.
|
||||
|
||||
We welcome anyone and everyone into our community by default. You don’t have to be a software developer to help us build. You don’t have to be pro-AI to help us learn.
|
||||
For details, please check [CONTRIBUTING.md](./CONTRIBUTING.md).
|
||||
|
||||
Our plans, our work, our successes, and our failures are all public record. We want the world to see not just the fruits of our work, but the whole process of growing it.
|
||||
## Code of Conduct
|
||||
|
||||
We welcome thoughtful criticism, whether it’s a comment on a PR or feedback on the community as a whole.
|
||||
We have a [Code of Conduct](./CODE_OF_CONDUCT.md) that we expect all contributors to adhere to.
|
||||
Long story short, we are aiming for an open, welcoming, diverse, inclusive, and healthy community.
|
||||
All contributors are expected to contribute to building this sort of community.
|
||||
|
||||
### High Agency
|
||||
## 🛠️ Becoming a Maintainer
|
||||
|
||||
Everyone should feel empowered to contribute to OpenHands. Whether it’s by making a PR, hosting an event, sharing feedback, or just asking a question, don’t hold back!
|
||||
For contributors who have made significant and sustained contributions to the project, there is a possibility of joining
|
||||
the maintainer team. The process for this is as follows:
|
||||
|
||||
OpenHands gives everyone the building blocks to create state-of-the-art developer experiences. We experiment constantly and love building new things.
|
||||
1. Any contributor who has made sustained and high-quality contributions to the codebase can be nominated by any
|
||||
maintainer. If you feel that you may qualify you can reach out to any of the maintainers that have reviewed your PRs and ask if you can be nominated.
|
||||
2. Once a maintainer nominates a new maintainer, there will be a discussion period among the maintainers for at least 3 days.
|
||||
3. If no concerns are raised the nomination will be accepted by acclamation, and if concerns are raised there will be a discussion and possible vote.
|
||||
|
||||
Coding, development practices, and communities are changing rapidly. We won’t hesitate to change direction and make big bets.
|
||||
|
||||
## Relationship to All Hands
|
||||
|
||||
OpenHands is supported by the for-profit organization [All Hands AI, Inc](https://www.all-hands.dev/).
|
||||
|
||||
All Hands was founded by three of the first major contributors to OpenHands:
|
||||
|
||||
- Xingyao Wang, a UIUC PhD candidate who got OpenHands to the top of the SWE-bench leaderboards
|
||||
- Graham Neubig, a CMU Professor who rallied the academic community around OpenHands
|
||||
- Robert Brennan, a software engineer who architected the user-facing features of OpenHands
|
||||
|
||||
All Hands is an important part of the OpenHands ecosystem. We’ve raised over $20M--mainly to hire developers and researchers who can work on OpenHands full-time, and to provide them with expensive infrastructure. ([Join us!](https://allhandsai.applytojob.com/apply/))
|
||||
|
||||
But we see OpenHands as much larger, and ultimately more important, than All Hands. When our financial responsibility to investors is at odds with our social responsibility to the community—as it inevitably will be, from time to time—we promise to navigate that conflict thoughtfully and transparently.
|
||||
|
||||
At some point, we may transfer custody of OpenHands to an open source foundation. But for now, the [Benevolent Dictator approach](http://www.catb.org/~esr/writings/cathedral-bazaar/homesteading/ar01s16.html) helps us move forward with speed and intention. If we ever forget the “benevolent” part, please: fork us.
|
||||
Note that just making many PRs does not immediately imply that you will become a maintainer. We will be looking
|
||||
at sustained high-quality contributions over a period of time, as well as good teamwork and adherence to our [Code of Conduct](./CODE_OF_CONDUCT.md).
|
||||
|
||||
@@ -91,14 +91,14 @@ make run
|
||||
#### Option B: Individual Server Startup
|
||||
|
||||
- **Start the Backend Server:** If you prefer, you can start the backend server independently to focus on
|
||||
backend-related tasks or configurations.
|
||||
backend-related tasks or configurations.
|
||||
|
||||
```bash
|
||||
make start-backend
|
||||
```
|
||||
|
||||
- **Start the Frontend Server:** Similarly, you can start the frontend server on its own to work on frontend-related
|
||||
components or interface enhancements.
|
||||
components or interface enhancements.
|
||||
```bash
|
||||
make start-frontend
|
||||
```
|
||||
@@ -110,7 +110,6 @@ You can use OpenHands to develop and improve OpenHands itself! This is a powerfu
|
||||
#### Quick Start
|
||||
|
||||
1. **Build and run OpenHands:**
|
||||
|
||||
```bash
|
||||
export INSTALL_DOCKER=0
|
||||
export RUNTIME=local
|
||||
@@ -118,7 +117,6 @@ You can use OpenHands to develop and improve OpenHands itself! This is a powerfu
|
||||
```
|
||||
|
||||
2. **Access the interface:**
|
||||
|
||||
- Local development: http://localhost:3001
|
||||
- Remote/cloud environments: Use the appropriate external URL
|
||||
|
||||
@@ -161,7 +159,7 @@ poetry run pytest ./tests/unit/test_*.py
|
||||
To reduce build time (e.g., if no changes were made to the client-runtime component), you can use an existing Docker
|
||||
container image by setting the SANDBOX_RUNTIME_CONTAINER_IMAGE environment variable to the desired Docker image.
|
||||
|
||||
Example: `export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/openhands/runtime:1.1-nikolaik`
|
||||
Example: `export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/openhands/runtime:0.62-nikolaik`
|
||||
|
||||
## Develop inside Docker container
|
||||
|
||||
@@ -201,6 +199,6 @@ Here's a guide to the important documentation files in the repository:
|
||||
- [/containers/README.md](./containers/README.md): Information about Docker containers and deployment
|
||||
- [/tests/unit/README.md](./tests/unit/README.md): Guide to writing and running unit tests
|
||||
- [/evaluation/README.md](./evaluation/README.md): Documentation for the evaluation framework and benchmarks
|
||||
- [/skills/README.md](./skills/README.md): Information about the skills architecture and implementation
|
||||
- [/microagents/README.md](./microagents/README.md): Information about the microagents architecture and implementation
|
||||
- [/openhands/server/README.md](./openhands/server/README.md): Server implementation details and API documentation
|
||||
- [/openhands/runtime/README.md](./openhands/runtime/README.md): Documentation for the runtime environment and execution model
|
||||
|
||||
186
README.md
186
README.md
@@ -1,18 +1,22 @@
|
||||
<a name="readme-top"></a>
|
||||
|
||||
<div align="center">
|
||||
<img src="https://raw.githubusercontent.com/OpenHands/docs/main/openhands/static/img/logo.png" alt="Logo" width="200">
|
||||
<h1 align="center" style="border-bottom: none">OpenHands: AI-Driven Development</h1>
|
||||
<img src="https://raw.githubusercontent.com/All-Hands-AI/docs/main/openhands/static/img/logo.png" alt="Logo" width="200">
|
||||
<h1 align="center">OpenHands: Code Less, Make More</h1>
|
||||
</div>
|
||||
|
||||
|
||||
<div align="center">
|
||||
<a href="https://github.com/OpenHands/OpenHands/blob/main/LICENSE"><img src="https://img.shields.io/badge/LICENSE-MIT-20B2AA?style=for-the-badge" alt="MIT License"></a>
|
||||
<a href="https://docs.google.com/spreadsheets/d/1wOUdFCMyY6Nt0AIqF705KN4JKOWgeI4wUGUP60krXXs/edit?gid=811504672#gid=811504672"><img src="https://img.shields.io/badge/SWEBench-77.6-00cc00?logoColor=FFE165&style=for-the-badge" alt="Benchmark Score"></a>
|
||||
<a href="https://github.com/OpenHands/OpenHands/graphs/contributors"><img src="https://img.shields.io/github/contributors/OpenHands/OpenHands?style=for-the-badge&color=blue" alt="Contributors"></a>
|
||||
<a href="https://github.com/OpenHands/OpenHands/stargazers"><img src="https://img.shields.io/github/stars/OpenHands/OpenHands?style=for-the-badge&color=blue" alt="Stargazers"></a>
|
||||
<a href="https://github.com/OpenHands/OpenHands/blob/main/LICENSE"><img src="https://img.shields.io/github/license/OpenHands/OpenHands?style=for-the-badge&color=blue" alt="MIT License"></a>
|
||||
<br/>
|
||||
<a href="https://docs.openhands.dev/sdk"><img src="https://img.shields.io/badge/Documentation-000?logo=googledocs&logoColor=FFE165&style=for-the-badge" alt="Check out the documentation"></a>
|
||||
<a href="https://arxiv.org/abs/2511.03690"><img src="https://img.shields.io/badge/Paper-000?logoColor=FFE165&logo=arxiv&style=for-the-badge" alt="Tech Report"></a>
|
||||
|
||||
<a href="https://all-hands.dev/joinslack"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Join our Slack community"></a>
|
||||
<a href="https://github.com/OpenHands/OpenHands/blob/main/CREDITS.md"><img src="https://img.shields.io/badge/Project-Credits-blue?style=for-the-badge&color=FFE165&logo=github&logoColor=white" alt="Credits"></a>
|
||||
<br/>
|
||||
<a href="https://docs.all-hands.dev/usage/getting-started"><img src="https://img.shields.io/badge/Documentation-000?logo=googledocs&logoColor=FFE165&style=for-the-badge" alt="Check out the documentation"></a>
|
||||
<a href="https://arxiv.org/abs/2407.16741"><img src="https://img.shields.io/badge/Paper%20on%20Arxiv-000?logoColor=FFE165&logo=arxiv&style=for-the-badge" alt="Paper on Arxiv"></a>
|
||||
<a href="https://docs.google.com/spreadsheets/d/1wOUdFCMyY6Nt0AIqF705KN4JKOWgeI4wUGUP60krXXs/edit?gid=0#gid=0"><img src="https://img.shields.io/badge/Benchmark%20score-000?logoColor=FFE165&logo=huggingface&style=for-the-badge" alt="Evaluation Benchmark Score"></a>
|
||||
|
||||
<!-- Keep these links. Translations will automatically update with the README. -->
|
||||
<a href="https://www.readme-i18n.com/OpenHands/OpenHands?lang=de">Deutsch</a> |
|
||||
@@ -24,63 +28,157 @@
|
||||
<a href="https://www.readme-i18n.com/OpenHands/OpenHands?lang=ru">Русский</a> |
|
||||
<a href="https://www.readme-i18n.com/OpenHands/OpenHands?lang=zh">中文</a>
|
||||
|
||||
<hr>
|
||||
</div>
|
||||
|
||||
<hr>
|
||||
Welcome to OpenHands (formerly OpenDevin), a platform for software development agents powered by AI.
|
||||
|
||||
🙌 Welcome to OpenHands, a [community](COMMUNITY.md) focused on AI-driven development. We’d love for you to [join us on Slack](https://dub.sh/openhands).
|
||||
OpenHands agents can do anything a human developer can: modify code, run commands, browse the web,
|
||||
call APIs, and yes—even copy code snippets from StackOverflow.
|
||||
|
||||
There are a few ways to work with OpenHands:
|
||||
Learn more at [docs.all-hands.dev](https://docs.all-hands.dev), or [sign up for OpenHands Cloud](https://app.all-hands.dev) to get started.
|
||||
|
||||
### OpenHands Software Agent SDK
|
||||
The SDK is a composable Python library that contains all of our agentic tech. It's the engine that powers everything else below.
|
||||
|
||||
Define agents in code, then run them locally, or scale to 1000s of agents in the cloud.
|
||||
> [!IMPORTANT]
|
||||
> **Upcoming change**: We are renaming our GitHub Org from `All-Hands-AI` to `OpenHands` on October 20th, 2025.
|
||||
> Check the [tracking issue](https://github.com/All-Hands-AI/OpenHands/issues/11376) for more information.
|
||||
|
||||
[Check out the docs](https://docs.openhands.dev/sdk) or [view the source](https://github.com/OpenHands/software-agent-sdk/)
|
||||
|
||||
### OpenHands CLI
|
||||
The CLI is the easiest way to start using OpenHands. The experience will be familiar to anyone who has worked
|
||||
with e.g. Claude Code or Codex. You can power it with Claude, GPT, or any other LLM.
|
||||
> [!IMPORTANT]
|
||||
> Using OpenHands for work? We'd love to chat! Fill out
|
||||
> [this short form](https://docs.google.com/forms/d/e/1FAIpQLSet3VbGaz8z32gW9Wm-Grl4jpt5WgMXPgJ4EDPVmCETCBpJtQ/viewform)
|
||||
> to join our Design Partner program, where you'll get early access to commercial features and the opportunity to provide input on our product roadmap.
|
||||
|
||||
[Check out the docs](https://docs.openhands.dev/openhands/usage/run-openhands/cli-mode) or [view the source](https://github.com/OpenHands/OpenHands-CLI)
|
||||
## ☁️ OpenHands Cloud
|
||||
The easiest way to get started with OpenHands is on [OpenHands Cloud](https://app.all-hands.dev),
|
||||
which comes with $10 in free credits for new users.
|
||||
|
||||
### OpenHands Local GUI
|
||||
Use the Local GUI for running agents on your laptop. It comes with a REST API and a single-page React application.
|
||||
The experience will be familiar to anyone who has used Devin or Jules.
|
||||
## 💻 Running OpenHands Locally
|
||||
|
||||
[Check out the docs](https://docs.openhands.dev/openhands/usage/run-openhands/local-setup) or view the source in this repo.
|
||||
### Option 1: CLI Launcher (Recommended)
|
||||
|
||||
### OpenHands Cloud
|
||||
This is a deployment of OpenHands GUI, running on hosted infrastructure.
|
||||
The easiest way to run OpenHands locally is using the CLI launcher with [uv](https://docs.astral.sh/uv/). This provides better isolation from your current project's virtual environment and is required for OpenHands' default MCP servers.
|
||||
|
||||
You can try it with a free $10 credit by [signing in with your GitHub account](https://app.all-hands.dev).
|
||||
**Install uv** (if you haven't already):
|
||||
|
||||
OpenHands Cloud comes with source-available features and integrations:
|
||||
- Integrations with Slack, Jira, and Linear
|
||||
- Multi-user support
|
||||
- RBAC and permissions
|
||||
- Collaboration features (e.g., conversation sharing)
|
||||
See the [uv installation guide](https://docs.astral.sh/uv/getting-started/installation/) for the latest installation instructions for your platform.
|
||||
|
||||
### OpenHands Enterprise
|
||||
Large enterprises can work with us to self-host OpenHands Cloud in their own VPC, via Kubernetes.
|
||||
OpenHands Enterprise can also work with the CLI and SDK above.
|
||||
**Launch OpenHands**:
|
||||
```bash
|
||||
# Launch the GUI server
|
||||
uvx --python 3.12 openhands serve
|
||||
|
||||
OpenHands Enterprise is source-available--you can see all the source code here in the enterprise/ directory,
|
||||
but you'll need to purchase a license if you want to run it for more than one month.
|
||||
# Or launch the CLI
|
||||
uvx --python 3.12 openhands
|
||||
```
|
||||
|
||||
Enterprise contracts also come with extended support and access to our research team.
|
||||
You'll find OpenHands running at [http://localhost:3000](http://localhost:3000) (for GUI mode)!
|
||||
|
||||
Learn more at [openhands.dev/enterprise](https://openhands.dev/enterprise)
|
||||
### Option 2: Docker
|
||||
|
||||
### Everything Else
|
||||
<details>
|
||||
<summary>Click to expand Docker command</summary>
|
||||
|
||||
Check out our [Product Roadmap](https://github.com/orgs/openhands/projects/1), and feel free to
|
||||
[open up an issue](https://github.com/OpenHands/OpenHands/issues) if there's something you'd like to see!
|
||||
You can also run OpenHands directly with Docker:
|
||||
|
||||
You might also be interested in our [evaluation infrastructure](https://github.com/OpenHands/benchmarks), our [chrome extension](https://github.com/OpenHands/openhands-chrome-extension/), or our [Theory-of-Mind module](https://github.com/OpenHands/ToM-SWE).
|
||||
```bash
|
||||
docker pull docker.openhands.dev/openhands/runtime:0.62-nikolaik
|
||||
|
||||
All our work is available under the MIT license, except for the `enterprise/` directory in this repository (see the [enterprise license](enterprise/LICENSE) for details).
|
||||
The core `openhands` and `agent-server` Docker images are fully MIT-licensed as well.
|
||||
docker run -it --rm --pull=always \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.openhands.dev/openhands/runtime:0.62-nikolaik \
|
||||
-e LOG_ALL_EVENTS=true \
|
||||
-v /var/run/docker.sock:/var/run/docker.sock \
|
||||
-v ~/.openhands:/.openhands \
|
||||
-p 3000:3000 \
|
||||
--add-host host.docker.internal:host-gateway \
|
||||
--name openhands-app \
|
||||
docker.openhands.dev/openhands/openhands:0.62
|
||||
```
|
||||
|
||||
If you need help with anything, or just want to chat, [come find us on Slack](https://dub.sh/openhands).
|
||||
</details>
|
||||
|
||||
> **Note**: If you used OpenHands before version 0.44, you may want to run `mv ~/.openhands-state ~/.openhands` to migrate your conversation history to the new location.
|
||||
|
||||
> [!WARNING]
|
||||
> On a public network? See our [Hardened Docker Installation Guide](https://docs.all-hands.dev/usage/runtimes/docker#hardened-docker-installation)
|
||||
> to secure your deployment by restricting network binding and implementing additional security measures.
|
||||
|
||||
### Getting Started
|
||||
|
||||
When you open the application, you'll be asked to choose an LLM provider and add an API key.
|
||||
[Anthropic's Claude Sonnet 4.5](https://www.anthropic.com/api) (`anthropic/claude-sonnet-4-5-20250929`)
|
||||
works best, but you have [many options](https://docs.all-hands.dev/usage/llms).
|
||||
|
||||
See the [Running OpenHands](https://docs.all-hands.dev/usage/installation) guide for
|
||||
system requirements and more information.
|
||||
|
||||
## 💡 Other ways to run OpenHands
|
||||
|
||||
> [!WARNING]
|
||||
> OpenHands is meant to be run by a single user on their local workstation.
|
||||
> It is not appropriate for multi-tenant deployments where multiple users share the same instance. There is no built-in authentication, isolation, or scalability.
|
||||
>
|
||||
> If you're interested in running OpenHands in a multi-tenant environment, check out the source-available, commercially-licensed
|
||||
> [OpenHands Cloud Helm Chart](https://github.com/openHands/OpenHands-cloud)
|
||||
|
||||
You can [connect OpenHands to your local filesystem](https://docs.all-hands.dev/usage/runtimes/docker#connecting-to-your-filesystem),
|
||||
interact with it via a [friendly CLI](https://docs.all-hands.dev/usage/how-to/cli-mode),
|
||||
run OpenHands in a scriptable [headless mode](https://docs.all-hands.dev/usage/how-to/headless-mode),
|
||||
or run it on tagged issues with [a github action](https://docs.all-hands.dev/usage/how-to/github-action).
|
||||
|
||||
Visit [Running OpenHands](https://docs.all-hands.dev/usage/installation) for more information and setup instructions.
|
||||
|
||||
If you want to modify the OpenHands source code, check out [Development.md](https://github.com/OpenHands/OpenHands/blob/main/Development.md).
|
||||
|
||||
Having issues? The [Troubleshooting Guide](https://docs.all-hands.dev/usage/troubleshooting) can help.
|
||||
|
||||
## 📖 Documentation
|
||||
|
||||
To learn more about the project, and for tips on using OpenHands,
|
||||
check out our [documentation](https://docs.all-hands.dev/usage/getting-started).
|
||||
|
||||
There you'll find resources on how to use different LLM providers,
|
||||
troubleshooting resources, and advanced configuration options.
|
||||
|
||||
## 🤝 How to Join the Community
|
||||
|
||||
OpenHands is a community-driven project, and we welcome contributions from everyone. We do most of our communication
|
||||
through Slack, so this is the best place to start, but we also are happy to have you contact us on Github:
|
||||
|
||||
- [Join our Slack workspace](https://all-hands.dev/joinslack) - Here we talk about research, architecture, and future development.
|
||||
- [Read or post Github Issues](https://github.com/OpenHands/OpenHands/issues) - Check out the issues we're working on, or add your own ideas.
|
||||
|
||||
See more about the community in [COMMUNITY.md](./COMMUNITY.md) or find details on contributing in [CONTRIBUTING.md](./CONTRIBUTING.md).
|
||||
|
||||
## 📈 Progress
|
||||
|
||||
See the monthly OpenHands roadmap [here](https://github.com/orgs/OpenHands/projects/1) (updated at the maintainer's meeting at the end of each month).
|
||||
|
||||
<p align="center">
|
||||
<a href="https://star-history.com/#OpenHands/OpenHands&Date">
|
||||
<img src="https://api.star-history.com/svg?repos=OpenHands/OpenHands&type=Date" width="500" alt="Star History Chart">
|
||||
</a>
|
||||
</p>
|
||||
|
||||
## 📜 License
|
||||
|
||||
Distributed under the MIT License, with the exception of the `enterprise/` folder. See [`LICENSE`](./LICENSE) for more information.
|
||||
|
||||
## 🙏 Acknowledgements
|
||||
|
||||
OpenHands is built by a large number of contributors, and every contribution is greatly appreciated! We also build upon other open source projects, and we are deeply thankful for their work.
|
||||
|
||||
For a list of open source projects and licenses used in OpenHands, please see our [CREDITS.md](./CREDITS.md) file.
|
||||
|
||||
## 📚 Cite
|
||||
|
||||
```
|
||||
@inproceedings{
|
||||
wang2025openhands,
|
||||
title={OpenHands: An Open Platform for {AI} Software Developers as Generalist Agents},
|
||||
author={Xingyao Wang and Boxuan Li and Yufan Song and Frank F. Xu and Xiangru Tang and Mingchen Zhuge and Jiayi Pan and Yueqi Song and Bowen Li and Jaskirat Singh and Hoang H. Tran and Fuqiang Li and Ren Ma and Mingzhang Zheng and Bill Qian and Yanjun Shao and Niklas Muennighoff and Yizhe Zhang and Binyuan Hui and Junyang Lin and Robert Brennan and Hao Peng and Heng Ji and Graham Neubig},
|
||||
booktitle={The Thirteenth International Conference on Learning Representations},
|
||||
year={2025},
|
||||
url={https://openreview.net/forum?id=OJd3ayDDoF}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -73,7 +73,7 @@ ENV VIRTUAL_ENV=/app/.venv \
|
||||
|
||||
COPY --chown=openhands:openhands --chmod=770 --from=backend-builder ${VIRTUAL_ENV} ${VIRTUAL_ENV}
|
||||
|
||||
COPY --chown=openhands:openhands --chmod=770 ./skills ./skills
|
||||
COPY --chown=openhands:openhands --chmod=770 ./microagents ./microagents
|
||||
COPY --chown=openhands:openhands --chmod=770 ./openhands ./openhands
|
||||
COPY --chown=openhands:openhands --chmod=777 ./openhands/runtime/plugins ./openhands/runtime/plugins
|
||||
COPY --chown=openhands:openhands pyproject.toml poetry.lock README.md MANIFEST.in LICENSE ./
|
||||
|
||||
@@ -12,7 +12,7 @@ services:
|
||||
- SANDBOX_API_HOSTNAME=host.docker.internal
|
||||
- DOCKER_HOST_ADDR=host.docker.internal
|
||||
#
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-ghcr.io/openhands/runtime:1.1-nikolaik}
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-ghcr.io/openhands/runtime:0.62-nikolaik}
|
||||
- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234}
|
||||
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
|
||||
ports:
|
||||
|
||||
@@ -3,9 +3,9 @@ repos:
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
exclude: ^(docs/|modules/|python/|openhands-ui/|third_party/|enterprise/)
|
||||
exclude: ^(docs/|modules/|python/|openhands-ui/|third_party/|enterprise/|openhands-cli/)
|
||||
- id: end-of-file-fixer
|
||||
exclude: ^(docs/|modules/|python/|openhands-ui/|third_party/|enterprise/)
|
||||
exclude: ^(docs/|modules/|python/|openhands-ui/|third_party/|enterprise/|openhands-cli/)
|
||||
- id: check-yaml
|
||||
args: ["--allow-multiple-documents"]
|
||||
- id: debug-statements
|
||||
@@ -28,12 +28,12 @@ repos:
|
||||
entry: ruff check --config dev_config/python/ruff.toml
|
||||
types_or: [python, pyi, jupyter]
|
||||
args: [--fix, --unsafe-fixes]
|
||||
exclude: ^(third_party/|enterprise/)
|
||||
exclude: ^(third_party/|enterprise/|openhands-cli/)
|
||||
# Run the formatter.
|
||||
- id: ruff-format
|
||||
entry: ruff format --config dev_config/python/ruff.toml
|
||||
types_or: [python, pyi, jupyter]
|
||||
exclude: ^(third_party/|enterprise/)
|
||||
exclude: ^(third_party/|enterprise/|openhands-cli/)
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.15.0
|
||||
|
||||
@@ -7,7 +7,7 @@ services:
|
||||
image: openhands:latest
|
||||
container_name: openhands-app-${DATE:-}
|
||||
environment:
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-docker.openhands.dev/openhands/runtime:1.1-nikolaik}
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-docker.openhands.dev/openhands/runtime:0.62-nikolaik}
|
||||
#- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234} # enable this only if you want a specific non-root sandbox user but you will have to manually adjust permissions of ~/.openhands for this user
|
||||
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
|
||||
ports:
|
||||
|
||||
@@ -31,8 +31,9 @@ RUN pip install alembic psycopg2-binary cloud-sql-python-connector pg8000 gsprea
|
||||
"pillow>=11.3.0"
|
||||
|
||||
WORKDIR /app
|
||||
COPY --chown=openhands:openhands --chmod=770 enterprise .
|
||||
COPY enterprise .
|
||||
|
||||
RUN chown -R openhands:openhands /app && chmod -R 770 /app
|
||||
USER openhands
|
||||
|
||||
# Command will be overridden by Kubernetes deployment template
|
||||
|
||||
@@ -721,7 +721,6 @@
|
||||
"https://$WEB_HOST/oauth/keycloak/callback",
|
||||
"https://$WEB_HOST/oauth/keycloak/offline/callback",
|
||||
"https://$WEB_HOST/slack/keycloak-callback",
|
||||
"https://$WEB_HOST/oauth/device/keycloak-callback",
|
||||
"https://$WEB_HOST/api/email/verified",
|
||||
"/realms/$KEYCLOAK_REALM_NAME/$KEYCLOAK_CLIENT_ID/*"
|
||||
],
|
||||
|
||||
@@ -50,7 +50,7 @@ First run this to retrieve Github App secrets
|
||||
```
|
||||
gcloud auth application-default login
|
||||
gcloud config set project global-432717
|
||||
enterprise_local/decrypt_env.sh /path/to/root/of/deploy/repo
|
||||
local/decrypt_env.sh
|
||||
```
|
||||
|
||||
Now run this to generate a `.env` file, which will used to run SAAS locally
|
||||
|
||||
@@ -116,7 +116,7 @@ lines.append('POSTHOG_CLIENT_KEY=test')
|
||||
lines.append('ENABLE_PROACTIVE_CONVERSATION_STARTERS=true')
|
||||
lines.append('MAX_CONCURRENT_CONVERSATIONS=10')
|
||||
lines.append('LITE_LLM_API_URL=https://llm-proxy.eval.all-hands.dev')
|
||||
lines.append('LITELLM_DEFAULT_MODEL=litellm_proxy/claude-opus-4-5-20251101')
|
||||
lines.append('LITELLM_DEFAULT_MODEL=litellm_proxy/claude-sonnet-4-20250514')
|
||||
lines.append(f'LITE_LLM_API_KEY={lite_llm_api_key}')
|
||||
lines.append('LOCAL_DEPLOYMENT=true')
|
||||
lines.append('DB_HOST=localhost')
|
||||
|
||||
4
enterprise/enterprise_local/decrypt_env.sh
Executable file → Normal file
4
enterprise/enterprise_local/decrypt_env.sh
Executable file → Normal file
@@ -4,12 +4,12 @@ set -euo pipefail
|
||||
# Check if DEPLOY_DIR argument was provided
|
||||
if [ $# -lt 1 ]; then
|
||||
echo "Usage: $0 <DEPLOY_DIR>"
|
||||
echo "Example: $0 /path/to/root/of/deploy/repo"
|
||||
echo "Example: $0 /path/to/deploy"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Normalize path (remove trailing slash)
|
||||
DEPLOY_DIR="${1%/}"
|
||||
DEPLOY_DIR="${DEPLOY_DIR%/}"
|
||||
|
||||
# Function to decrypt and rename
|
||||
decrypt_and_move() {
|
||||
|
||||
@@ -5,8 +5,12 @@ from experiments.constants import (
|
||||
EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT,
|
||||
)
|
||||
from experiments.experiment_versions import (
|
||||
handle_condenser_max_step_experiment,
|
||||
handle_system_prompt_experiment,
|
||||
)
|
||||
from experiments.experiment_versions._004_condenser_max_step_experiment import (
|
||||
handle_condenser_max_step_experiment__v1,
|
||||
)
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -27,6 +31,10 @@ class SaaSExperimentManager(ExperimentManager):
|
||||
)
|
||||
return agent
|
||||
|
||||
agent = handle_condenser_max_step_experiment__v1(
|
||||
user_id, conversation_id, agent
|
||||
)
|
||||
|
||||
if EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT:
|
||||
agent = agent.model_copy(
|
||||
update={'system_prompt_filename': 'system_prompt_long_horizon.j2'}
|
||||
@@ -52,7 +60,20 @@ class SaaSExperimentManager(ExperimentManager):
|
||||
"""
|
||||
logger.debug(
|
||||
'experiment_manager:run_conversation_variant_test:started',
|
||||
extra={'user_id': user_id, 'conversation_id': conversation_id},
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
# Skip all experiment processing if the experiment manager is disabled
|
||||
if not ENABLE_EXPERIMENT_MANAGER:
|
||||
logger.info(
|
||||
'experiment_manager:run_conversation_variant_test:skipped',
|
||||
extra={'reason': 'experiment_manager_disabled'},
|
||||
)
|
||||
return conversation_settings
|
||||
|
||||
# Apply conversation-scoped experiments
|
||||
conversation_settings = handle_condenser_max_step_experiment(
|
||||
user_id, conversation_id, conversation_settings
|
||||
)
|
||||
|
||||
return conversation_settings
|
||||
|
||||
@@ -22,7 +22,6 @@ from integrations.utils import (
|
||||
HOST_URL,
|
||||
OPENHANDS_RESOLVER_TEMPLATES_DIR,
|
||||
)
|
||||
from integrations.v1_utils import get_saas_user_auth
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from pydantic import SecretStr
|
||||
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
@@ -165,13 +164,8 @@ class GithubManager(Manager):
|
||||
)
|
||||
|
||||
if await self.is_job_requested(message):
|
||||
payload = message.message.get('payload', {})
|
||||
user_id = payload['sender']['id']
|
||||
keycloak_user_id = await self.token_manager.get_user_id_from_idp_user_id(
|
||||
user_id, ProviderType.GITHUB
|
||||
)
|
||||
github_view = await GithubFactory.create_github_view_from_payload(
|
||||
message, keycloak_user_id
|
||||
message, self.token_manager
|
||||
)
|
||||
logger.info(
|
||||
f'[GitHub] Creating job for {github_view.user_info.username} in {github_view.full_repo_name}#{github_view.issue_number}'
|
||||
@@ -288,15 +282,8 @@ class GithubManager(Manager):
|
||||
f'[Github]: Error summarizing issue solvability: {str(e)}'
|
||||
)
|
||||
|
||||
saas_user_auth = await get_saas_user_auth(
|
||||
github_view.user_info.keycloak_user_id, self.token_manager
|
||||
)
|
||||
|
||||
await github_view.create_new_conversation(
|
||||
self.jinja_env,
|
||||
secret_store.provider_tokens,
|
||||
convo_metadata,
|
||||
saas_user_auth,
|
||||
self.jinja_env, secret_store.provider_tokens, convo_metadata
|
||||
)
|
||||
|
||||
conversation_id = github_view.conversation_id
|
||||
@@ -305,19 +292,18 @@ class GithubManager(Manager):
|
||||
f'[GitHub] Created conversation {conversation_id} for user {user_info.username}'
|
||||
)
|
||||
|
||||
if not github_view.v1:
|
||||
# Create a GithubCallbackProcessor
|
||||
processor = GithubCallbackProcessor(
|
||||
github_view=github_view,
|
||||
send_summary_instruction=True,
|
||||
)
|
||||
# Create a GithubCallbackProcessor
|
||||
processor = GithubCallbackProcessor(
|
||||
github_view=github_view,
|
||||
send_summary_instruction=True,
|
||||
)
|
||||
|
||||
# Register the callback processor
|
||||
register_callback_processor(conversation_id, processor)
|
||||
# Register the callback processor
|
||||
register_callback_processor(conversation_id, processor)
|
||||
|
||||
logger.info(
|
||||
f'[Github] Registered callback processor for conversation {conversation_id}'
|
||||
)
|
||||
logger.info(
|
||||
f'[Github] Registered callback processor for conversation {conversation_id}'
|
||||
)
|
||||
|
||||
# Send message with conversation link
|
||||
conversation_link = CONVERSATION_URL.format(conversation_id)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID, uuid4
|
||||
from uuid import uuid4
|
||||
|
||||
from github import Github, GithubIntegration
|
||||
from github.Issue import Issue
|
||||
@@ -9,17 +8,16 @@ from integrations.github.github_types import (
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
from integrations.models import Message
|
||||
from integrations.resolver_context import ResolverUserContext
|
||||
from integrations.types import ResolverViewInterface, UserData
|
||||
from integrations.utils import (
|
||||
ENABLE_PROACTIVE_CONVERSATION_STARTERS,
|
||||
ENABLE_V1_GITHUB_RESOLVER,
|
||||
HOST,
|
||||
HOST_URL,
|
||||
get_oh_labels,
|
||||
has_exact_mention,
|
||||
)
|
||||
from jinja2 import Environment
|
||||
from pydantic.dataclasses import dataclass
|
||||
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
@@ -28,24 +26,14 @@ from storage.proactive_conversation_store import ProactiveConversationStore
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
|
||||
from openhands.agent_server.models import SendMessageRequest
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartRequest,
|
||||
AppConversationStartTaskStatus,
|
||||
)
|
||||
from openhands.app_server.config import get_app_conversation_service
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
from openhands.app_server.user.specifiy_user_context import USER_CONTEXT_ATTR
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
|
||||
from openhands.integrations.service_types import Comment
|
||||
from openhands.sdk import TextContent
|
||||
from openhands.server.services.conversation_service import (
|
||||
initialize_conversation,
|
||||
start_conversation,
|
||||
)
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.data_models.conversation_metadata import (
|
||||
ConversationMetadata,
|
||||
ConversationTrigger,
|
||||
@@ -88,38 +76,6 @@ async def get_user_proactive_conversation_setting(user_id: str | None) -> bool:
|
||||
return settings.enable_proactive_conversation_starters
|
||||
|
||||
|
||||
async def get_user_v1_enabled_setting(user_id: str) -> bool:
|
||||
"""Get the user's V1 conversation API setting.
|
||||
|
||||
Args:
|
||||
user_id: The keycloak user ID
|
||||
|
||||
Returns:
|
||||
True if V1 conversations are enabled for this user, False otherwise
|
||||
|
||||
Note:
|
||||
This function checks both the global environment variable kill switch AND
|
||||
the user's individual setting. Both must be true for the function to return true.
|
||||
"""
|
||||
# Check the global environment variable first
|
||||
if not ENABLE_V1_GITHUB_RESOLVER:
|
||||
return False
|
||||
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
|
||||
settings = await call_sync_from_async(
|
||||
settings_store.get_user_settings_by_keycloak_id, user_id
|
||||
)
|
||||
|
||||
if not settings or settings.v1_enabled is None:
|
||||
return False
|
||||
|
||||
return settings.v1_enabled
|
||||
|
||||
|
||||
# =================================================
|
||||
# SECTION: Github view types
|
||||
# =================================================
|
||||
@@ -140,7 +96,6 @@ class GithubIssue(ResolverViewInterface):
|
||||
title: str
|
||||
description: str
|
||||
previous_comments: list[Comment]
|
||||
v1: bool
|
||||
|
||||
async def _load_resolver_context(self):
|
||||
github_service = GithubServiceImpl(
|
||||
@@ -187,19 +142,6 @@ class GithubIssue(ResolverViewInterface):
|
||||
|
||||
async def initialize_new_conversation(self) -> ConversationMetadata:
|
||||
# FIXME: Handle if initialize_conversation returns None
|
||||
|
||||
v1_enabled = await get_user_v1_enabled_setting(self.user_info.keycloak_user_id)
|
||||
logger.info(
|
||||
f'[GitHub V1]: User flag found for {self.user_info.keycloak_user_id} is {v1_enabled}'
|
||||
)
|
||||
if v1_enabled:
|
||||
# Create dummy conversationm metadata
|
||||
# Don't save to conversation store
|
||||
# V1 conversations are stored in a separate table
|
||||
return ConversationMetadata(
|
||||
conversation_id=uuid4().hex, selected_repository=self.full_repo_name
|
||||
)
|
||||
|
||||
conversation_metadata: ConversationMetadata = await initialize_conversation( # type: ignore[assignment]
|
||||
user_id=self.user_info.keycloak_user_id,
|
||||
conversation_id=None,
|
||||
@@ -216,36 +158,7 @@ class GithubIssue(ResolverViewInterface):
|
||||
jinja_env: Environment,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE,
|
||||
conversation_metadata: ConversationMetadata,
|
||||
saas_user_auth: UserAuth,
|
||||
):
|
||||
v1_enabled = await get_user_v1_enabled_setting(self.user_info.keycloak_user_id)
|
||||
logger.info(
|
||||
f'[GitHub V1]: User flag found for {self.user_info.keycloak_user_id} is {v1_enabled}'
|
||||
)
|
||||
if v1_enabled:
|
||||
try:
|
||||
# Use V1 app conversation service
|
||||
await self._create_v1_conversation(
|
||||
jinja_env, saas_user_auth, conversation_metadata
|
||||
)
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f'Error checking V1 settings, falling back to V0: {e}')
|
||||
|
||||
# Use existing V0 conversation service
|
||||
await self._create_v0_conversation(
|
||||
jinja_env, git_provider_tokens, conversation_metadata
|
||||
)
|
||||
|
||||
async def _create_v0_conversation(
|
||||
self,
|
||||
jinja_env: Environment,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE,
|
||||
conversation_metadata: ConversationMetadata,
|
||||
):
|
||||
"""Create conversation using the legacy V0 system."""
|
||||
logger.info('[GitHub]: Creating V0 conversation')
|
||||
custom_secrets = await self._get_user_secrets()
|
||||
|
||||
user_instructions, conversation_instructions = await self._get_instructions(
|
||||
@@ -264,78 +177,6 @@ class GithubIssue(ResolverViewInterface):
|
||||
conversation_instructions=conversation_instructions,
|
||||
)
|
||||
|
||||
async def _create_v1_conversation(
|
||||
self,
|
||||
jinja_env: Environment,
|
||||
saas_user_auth: UserAuth,
|
||||
conversation_metadata: ConversationMetadata,
|
||||
):
|
||||
"""Create conversation using the new V1 app conversation system."""
|
||||
logger.info('[GitHub V1]: Creating V1 conversation')
|
||||
|
||||
user_instructions, conversation_instructions = await self._get_instructions(
|
||||
jinja_env
|
||||
)
|
||||
|
||||
# Create the initial message request
|
||||
initial_message = SendMessageRequest(
|
||||
role='user', content=[TextContent(text=user_instructions)]
|
||||
)
|
||||
|
||||
# Create the GitHub V1 callback processor
|
||||
github_callback_processor = self._create_github_v1_callback_processor()
|
||||
|
||||
# Get the app conversation service and start the conversation
|
||||
injector_state = InjectorState()
|
||||
|
||||
# Create the V1 conversation start request with the callback processor
|
||||
start_request = AppConversationStartRequest(
|
||||
conversation_id=UUID(conversation_metadata.conversation_id),
|
||||
system_message_suffix=conversation_instructions,
|
||||
initial_message=initial_message,
|
||||
selected_repository=self.full_repo_name,
|
||||
git_provider=ProviderType.GITHUB,
|
||||
title=f'GitHub Issue #{self.issue_number}: {self.title}',
|
||||
trigger=ConversationTrigger.RESOLVER,
|
||||
processors=[
|
||||
github_callback_processor
|
||||
], # Pass the callback processor directly
|
||||
)
|
||||
|
||||
# Set up the GitHub user context for the V1 system
|
||||
github_user_context = ResolverUserContext(saas_user_auth=saas_user_auth)
|
||||
setattr(injector_state, USER_CONTEXT_ATTR, github_user_context)
|
||||
|
||||
async with get_app_conversation_service(
|
||||
injector_state
|
||||
) as app_conversation_service:
|
||||
async for task in app_conversation_service.start_app_conversation(
|
||||
start_request
|
||||
):
|
||||
if task.status == AppConversationStartTaskStatus.ERROR:
|
||||
logger.error(f'Failed to start V1 conversation: {task.detail}')
|
||||
raise RuntimeError(
|
||||
f'Failed to start V1 conversation: {task.detail}'
|
||||
)
|
||||
|
||||
self.v1 = True
|
||||
|
||||
def _create_github_v1_callback_processor(self):
|
||||
"""Create a V1 callback processor for GitHub integration."""
|
||||
from openhands.app_server.event_callback.github_v1_callback_processor import (
|
||||
GithubV1CallbackProcessor,
|
||||
)
|
||||
|
||||
# Create and return the GitHub V1 callback processor
|
||||
return GithubV1CallbackProcessor(
|
||||
github_view_data={
|
||||
'issue_number': self.issue_number,
|
||||
'full_repo_name': self.full_repo_name,
|
||||
'installation_id': self.installation_id,
|
||||
},
|
||||
send_summary_instruction=self.send_summary_instruction,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GithubIssueComment(GithubIssue):
|
||||
@@ -391,18 +232,7 @@ class GithubPRComment(GithubIssueComment):
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def initialize_new_conversation(self) -> ConversationMetadata:
|
||||
v1_enabled = await get_user_v1_enabled_setting(self.user_info.keycloak_user_id)
|
||||
logger.info(
|
||||
f'[GitHub V1]: User flag found for {self.user_info.keycloak_user_id} is {v1_enabled}'
|
||||
)
|
||||
if v1_enabled:
|
||||
# Create dummy conversationm metadata
|
||||
# Don't save to conversation store
|
||||
# V1 conversations are stored in a separate table
|
||||
return ConversationMetadata(
|
||||
conversation_id=uuid4().hex, selected_repository=self.full_repo_name
|
||||
)
|
||||
|
||||
# FIXME: Handle if initialize_conversation returns None
|
||||
conversation_metadata: ConversationMetadata = await initialize_conversation( # type: ignore[assignment]
|
||||
user_id=self.user_info.keycloak_user_id,
|
||||
conversation_id=None,
|
||||
@@ -462,24 +292,6 @@ class GithubInlinePRComment(GithubPRComment):
|
||||
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
def _create_github_v1_callback_processor(self):
|
||||
"""Create a V1 callback processor for GitHub integration."""
|
||||
from openhands.app_server.event_callback.github_v1_callback_processor import (
|
||||
GithubV1CallbackProcessor,
|
||||
)
|
||||
|
||||
# Create and return the GitHub V1 callback processor
|
||||
return GithubV1CallbackProcessor(
|
||||
github_view_data={
|
||||
'issue_number': self.issue_number,
|
||||
'full_repo_name': self.full_repo_name,
|
||||
'installation_id': self.installation_id,
|
||||
'comment_id': self.comment_id,
|
||||
},
|
||||
inline_pr_comment=True,
|
||||
send_summary_instruction=self.send_summary_instruction,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GithubFailingAction:
|
||||
@@ -793,7 +605,7 @@ class GithubFactory:
|
||||
|
||||
@staticmethod
|
||||
async def create_github_view_from_payload(
|
||||
message: Message, keycloak_user_id: str
|
||||
message: Message, token_manager: TokenManager
|
||||
) -> ResolverViewInterface:
|
||||
"""Create the appropriate class (GithubIssue or GithubPRComment) based on the payload.
|
||||
Also return metadata about the event (e.g., action type).
|
||||
@@ -803,10 +615,17 @@ class GithubFactory:
|
||||
user_id = payload['sender']['id']
|
||||
username = payload['sender']['login']
|
||||
|
||||
keyloak_user_id = await token_manager.get_user_id_from_idp_user_id(
|
||||
user_id, ProviderType.GITHUB
|
||||
)
|
||||
|
||||
if keyloak_user_id is None:
|
||||
logger.warning(f'Got invalid keyloak user id for GitHub User {user_id} ')
|
||||
|
||||
selected_repo = GithubFactory.get_full_repo_name(repo_obj)
|
||||
is_public_repo = not repo_obj.get('private', True)
|
||||
user_info = UserData(
|
||||
user_id=user_id, username=username, keycloak_user_id=keycloak_user_id
|
||||
user_id=user_id, username=username, keycloak_user_id=keyloak_user_id
|
||||
)
|
||||
|
||||
installation_id = message.message['installation']
|
||||
@@ -830,7 +649,6 @@ class GithubFactory:
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
v1=False,
|
||||
)
|
||||
|
||||
elif GithubFactory.is_issue_comment(message):
|
||||
@@ -856,7 +674,6 @@ class GithubFactory:
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
v1=False,
|
||||
)
|
||||
|
||||
elif GithubFactory.is_pr_comment(message):
|
||||
@@ -898,7 +715,6 @@ class GithubFactory:
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
v1=False,
|
||||
)
|
||||
|
||||
elif GithubFactory.is_inline_pr_comment(message):
|
||||
@@ -932,7 +748,6 @@ class GithubFactory:
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
v1=False,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
from openhands.app_server.user.user_context import UserContext
|
||||
from openhands.app_server.user.user_models import UserInfo
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.sdk.secret import SecretSource, StaticSecret
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
|
||||
|
||||
class ResolverUserContext(UserContext):
|
||||
"""User context for resolver operations that inherits from UserContext."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
saas_user_auth: UserAuth,
|
||||
):
|
||||
self.saas_user_auth = saas_user_auth
|
||||
|
||||
async def get_user_id(self) -> str | None:
|
||||
return await self.saas_user_auth.get_user_id()
|
||||
|
||||
async def get_user_info(self) -> UserInfo:
|
||||
user_settings = await self.saas_user_auth.get_user_settings()
|
||||
user_id = await self.saas_user_auth.get_user_id()
|
||||
if user_settings:
|
||||
return UserInfo(
|
||||
id=user_id,
|
||||
**user_settings.model_dump(context={'expose_secrets': True}),
|
||||
)
|
||||
|
||||
return UserInfo(id=user_id)
|
||||
|
||||
async def get_authenticated_git_url(self, repository: str) -> str:
|
||||
# This would need to be implemented based on the git provider tokens
|
||||
# For now, return a basic HTTPS URL
|
||||
return f'https://github.com/{repository}.git'
|
||||
|
||||
async def get_latest_token(self, provider_type: ProviderType) -> str | None:
|
||||
# Return the appropriate token from git_provider_tokens
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
if provider_tokens:
|
||||
return provider_tokens.get(provider_type)
|
||||
return None
|
||||
|
||||
async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None:
|
||||
return await self.saas_user_auth.get_provider_tokens()
|
||||
|
||||
async def get_secrets(self) -> dict[str, SecretSource]:
|
||||
"""Get secrets for the user, including custom secrets."""
|
||||
secrets = await self.saas_user_auth.get_secrets()
|
||||
if secrets:
|
||||
# Convert custom secrets to StaticSecret objects for SDK compatibility
|
||||
# secrets.custom_secrets is of type Mapping[str, CustomSecret]
|
||||
converted_secrets = {}
|
||||
for key, custom_secret in secrets.custom_secrets.items():
|
||||
# Extract the secret value from CustomSecret and convert to StaticSecret
|
||||
secret_value = custom_secret.secret.get_secret_value()
|
||||
converted_secrets[key] = StaticSecret(value=secret_value)
|
||||
return converted_secrets
|
||||
return {}
|
||||
|
||||
async def get_mcp_api_key(self) -> str | None:
|
||||
return await self.saas_user_auth.get_mcp_api_key()
|
||||
@@ -19,7 +19,7 @@ class PRStatus(Enum):
|
||||
class UserData(BaseModel):
|
||||
user_id: int
|
||||
username: str
|
||||
keycloak_user_id: str
|
||||
keycloak_user_id: str | None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -51,11 +51,6 @@ ENABLE_SOLVABILITY_ANALYSIS = (
|
||||
os.getenv('ENABLE_SOLVABILITY_ANALYSIS', 'false').lower() == 'true'
|
||||
)
|
||||
|
||||
# Toggle for V1 GitHub resolver feature
|
||||
ENABLE_V1_GITHUB_RESOLVER = (
|
||||
os.getenv('ENABLE_V1_GITHUB_RESOLVER', 'false').lower() == 'true'
|
||||
)
|
||||
|
||||
|
||||
OPENHANDS_RESOLVER_TEMPLATES_DIR = 'openhands/integrations/templates/resolver/'
|
||||
jinja_env = Environment(loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR))
|
||||
@@ -321,7 +316,7 @@ def append_conversation_footer(message: str, conversation_id: str) -> str:
|
||||
The message with the conversation footer appended
|
||||
"""
|
||||
conversation_link = CONVERSATION_URL.format(conversation_id)
|
||||
footer = f'\n\n[View full conversation]({conversation_link})'
|
||||
footer = f'\n\n<sub>[View full conversation]({conversation_link})</sub>'
|
||||
return message + footer
|
||||
|
||||
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
from pydantic import SecretStr
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.auth.token_manager import TokenManager
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
|
||||
|
||||
async def get_saas_user_auth(
|
||||
keycloak_user_id: str, token_manager: TokenManager
|
||||
) -> UserAuth:
|
||||
offline_token = await token_manager.load_offline_token(keycloak_user_id)
|
||||
if offline_token is None:
|
||||
logger.info('no_offline_token_found')
|
||||
|
||||
user_auth = SaasUserAuth(
|
||||
user_id=keycloak_user_id,
|
||||
refresh_token=SecretStr(offline_token),
|
||||
)
|
||||
return user_auth
|
||||
@@ -1,41 +0,0 @@
|
||||
"""add parent_conversation_id to conversation_metadata
|
||||
|
||||
Revision ID: 081
|
||||
Revises: 080
|
||||
Create Date: 2025-11-06 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '081'
|
||||
down_revision: Union[str, None] = '080'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
op.add_column(
|
||||
'conversation_metadata',
|
||||
sa.Column('parent_conversation_id', sa.String(), nullable=True),
|
||||
)
|
||||
op.create_index(
|
||||
op.f('ix_conversation_metadata_parent_conversation_id'),
|
||||
'conversation_metadata',
|
||||
['parent_conversation_id'],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
op.drop_index(
|
||||
op.f('ix_conversation_metadata_parent_conversation_id'),
|
||||
table_name='conversation_metadata',
|
||||
)
|
||||
op.drop_column('conversation_metadata', 'parent_conversation_id')
|
||||
@@ -1,51 +0,0 @@
|
||||
"""Add SETTING_UP_SKILLS to appconversationstarttaskstatus enum
|
||||
|
||||
Revision ID: 082
|
||||
Revises: 081
|
||||
Create Date: 2025-11-19 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
from sqlalchemy import text
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '082'
|
||||
down_revision: Union[str, Sequence[str], None] = '081'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add SETTING_UP_SKILLS enum value to appconversationstarttaskstatus."""
|
||||
# Check if the enum value already exists before adding it
|
||||
# This handles the case where the enum was created with the value already included
|
||||
connection = op.get_bind()
|
||||
result = connection.execute(
|
||||
text(
|
||||
"SELECT 1 FROM pg_enum WHERE enumlabel = 'SETTING_UP_SKILLS' "
|
||||
"AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'appconversationstarttaskstatus')"
|
||||
)
|
||||
)
|
||||
|
||||
if not result.fetchone():
|
||||
# Add the new enum value only if it doesn't already exist
|
||||
op.execute(
|
||||
"ALTER TYPE appconversationstarttaskstatus ADD VALUE 'SETTING_UP_SKILLS'"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove SETTING_UP_SKILLS enum value from appconversationstarttaskstatus.
|
||||
|
||||
Note: PostgreSQL doesn't support removing enum values directly.
|
||||
This would require recreating the enum type and updating all references.
|
||||
For safety, this downgrade is not implemented.
|
||||
"""
|
||||
# PostgreSQL doesn't support removing enum values directly
|
||||
# This would require a complex migration to recreate the enum
|
||||
# For now, we'll leave this as a no-op since removing enum values
|
||||
# is rarely needed and can be dangerous
|
||||
pass
|
||||
@@ -1,35 +0,0 @@
|
||||
"""Add v1_enabled column to user_settings
|
||||
|
||||
Revision ID: 083
|
||||
Revises: 082
|
||||
Create Date: 2025-11-18 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '083'
|
||||
down_revision: Union[str, None] = '082'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add v1_enabled column to user_settings table."""
|
||||
op.add_column(
|
||||
'user_settings',
|
||||
sa.Column(
|
||||
'v1_enabled',
|
||||
sa.Boolean(),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove v1_enabled column from user_settings table."""
|
||||
op.drop_column('user_settings', 'v1_enabled')
|
||||
@@ -1,49 +0,0 @@
|
||||
"""Create device_codes table for OAuth 2.0 Device Flow
|
||||
|
||||
Revision ID: 084
|
||||
Revises: 083
|
||||
Create Date: 2024-12-10 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '084'
|
||||
down_revision = '083'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
"""Create device_codes table for OAuth 2.0 Device Flow."""
|
||||
op.create_table(
|
||||
'device_codes',
|
||||
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column('device_code', sa.String(length=128), nullable=False),
|
||||
sa.Column('user_code', sa.String(length=16), nullable=False),
|
||||
sa.Column('status', sa.String(length=32), nullable=False),
|
||||
sa.Column('keycloak_user_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('authorized_at', sa.DateTime(timezone=True), nullable=True),
|
||||
# Rate limiting fields for RFC 8628 section 3.5 compliance
|
||||
sa.Column('last_poll_time', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('current_interval', sa.Integer(), nullable=False, default=5),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
)
|
||||
|
||||
# Create indexes for efficient lookups
|
||||
op.create_index(
|
||||
'ix_device_codes_device_code', 'device_codes', ['device_code'], unique=True
|
||||
)
|
||||
op.create_index(
|
||||
'ix_device_codes_user_code', 'device_codes', ['user_code'], unique=True
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
"""Drop device_codes table."""
|
||||
op.drop_index('ix_device_codes_user_code', table_name='device_codes')
|
||||
op.drop_index('ix_device_codes_device_code', table_name='device_codes')
|
||||
op.drop_table('device_codes')
|
||||
@@ -1,41 +0,0 @@
|
||||
"""add public column to conversation_metadata
|
||||
|
||||
Revision ID: 085
|
||||
Revises: 084
|
||||
Create Date: 2025-01-27 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '085'
|
||||
down_revision: Union[str, None] = '084'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
op.add_column(
|
||||
'conversation_metadata',
|
||||
sa.Column('public', sa.Boolean(), nullable=True),
|
||||
)
|
||||
op.create_index(
|
||||
op.f('ix_conversation_metadata_public'),
|
||||
'conversation_metadata',
|
||||
['public'],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
op.drop_index(
|
||||
op.f('ix_conversation_metadata_public'),
|
||||
table_name='conversation_metadata',
|
||||
)
|
||||
op.drop_column('conversation_metadata', 'public')
|
||||
322
enterprise/poetry.lock
generated
322
enterprise/poetry.lock
generated
@@ -201,14 +201,14 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "anthropic"
|
||||
version = "0.75.0"
|
||||
version = "0.72.0"
|
||||
description = "The official Python library for the anthropic API"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "anthropic-0.75.0-py3-none-any.whl", hash = "sha256:ea8317271b6c15d80225a9f3c670152746e88805a7a61e14d4a374577164965b"},
|
||||
{file = "anthropic-0.75.0.tar.gz", hash = "sha256:e8607422f4ab616db2ea5baacc215dd5f028da99ce2f022e33c7c535b29f3dfb"},
|
||||
{file = "anthropic-0.72.0-py3-none-any.whl", hash = "sha256:0e9f5a7582f038cab8efbb4c959e49ef654a56bfc7ba2da51b5a7b8a84de2e4d"},
|
||||
{file = "anthropic-0.72.0.tar.gz", hash = "sha256:8971fe76dcffc644f74ac3883069beb1527641115ae0d6eb8fa21c1ce4082f7a"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -682,37 +682,37 @@ crt = ["awscrt (==0.27.6)"]
|
||||
|
||||
[[package]]
|
||||
name = "browser-use"
|
||||
version = "0.10.1"
|
||||
version = "0.9.5"
|
||||
description = "Make websites accessible for AI agents"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.11"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "browser_use-0.10.1-py3-none-any.whl", hash = "sha256:96e603bfc71098175342cdcb0592519e6f244412e740f0254e4389fdd82a977f"},
|
||||
{file = "browser_use-0.10.1.tar.gz", hash = "sha256:5f211ecfdf1f9fd186160f10df70dedd661821231e30f1bce40939787abab223"},
|
||||
{file = "browser_use-0.9.5-py3-none-any.whl", hash = "sha256:4a2e92847204d1ded269026a99cb0cc0e60e38bd2751fa3f58aedd78f00b4e67"},
|
||||
{file = "browser_use-0.9.5.tar.gz", hash = "sha256:f8285fe253b149d01769a7084883b4cf4db351e2f38e26302c157bcbf14a703f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
aiohttp = "3.12.15"
|
||||
anthropic = ">=0.72.1,<1.0.0"
|
||||
anthropic = ">=0.68.1,<1.0.0"
|
||||
anyio = ">=4.9.0"
|
||||
authlib = ">=1.6.0"
|
||||
bubus = ">=1.5.6"
|
||||
cdp-use = ">=1.4.4"
|
||||
cdp-use = ">=1.4.0"
|
||||
click = ">=8.1.8"
|
||||
cloudpickle = ">=3.1.1"
|
||||
google-api-core = ">=2.25.0"
|
||||
google-api-python-client = ">=2.174.0"
|
||||
google-auth = ">=2.40.3"
|
||||
google-auth-oauthlib = ">=1.2.2"
|
||||
google-genai = ">=1.50.0,<2.0.0"
|
||||
google-genai = ">=1.29.0,<2.0.0"
|
||||
groq = ">=0.30.0"
|
||||
httpx = ">=0.28.1"
|
||||
inquirerpy = ">=0.3.4"
|
||||
markdownify = ">=1.2.0"
|
||||
mcp = ">=1.10.1"
|
||||
ollama = ">=0.5.1"
|
||||
openai = ">=2.7.2,<3.0.0"
|
||||
openai = ">=1.99.2,<2.0.0"
|
||||
pillow = ">=11.2.1"
|
||||
portalocker = ">=2.7.0,<3.0.0"
|
||||
posthog = ">=3.7.0"
|
||||
@@ -721,7 +721,6 @@ pydantic = ">=2.11.5"
|
||||
pyobjc = {version = ">=11.0", markers = "platform_system == \"darwin\""}
|
||||
pyotp = ">=2.9.0"
|
||||
pypdf = ">=5.7.0"
|
||||
python-docx = ">=1.2.0"
|
||||
python-dotenv = ">=1.0.1"
|
||||
reportlab = ">=4.0.0"
|
||||
requests = ">=2.32.3"
|
||||
@@ -851,14 +850,14 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "cdp-use"
|
||||
version = "1.4.4"
|
||||
version = "1.4.3"
|
||||
description = "Type safe generator/client library for CDP"
|
||||
optional = false
|
||||
python-versions = ">=3.11"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "cdp_use-1.4.4-py3-none-any.whl", hash = "sha256:e37e80e067db2653d6fdf953d4ff9e5d80d75daa27b7c6d48c0261cccbef73e1"},
|
||||
{file = "cdp_use-1.4.4.tar.gz", hash = "sha256:330a848b517006eb9ad1dc468aa6434d913cf0c6918610760c36c3fdfdba0fab"},
|
||||
{file = "cdp_use-1.4.3-py3-none-any.whl", hash = "sha256:c48664604470c2579aa1e677c3e3e7e24c4f300c54804c093d935abb50479ecd"},
|
||||
{file = "cdp_use-1.4.3.tar.gz", hash = "sha256:9029c04bdc49fbd3939d2bf1988ad8d88e260729c7d5e35c2f6c87591f5a10e9"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -2979,29 +2978,28 @@ testing = ["pytest"]
|
||||
|
||||
[[package]]
|
||||
name = "google-genai"
|
||||
version = "1.53.0"
|
||||
version = "1.32.0"
|
||||
description = "GenAI Python SDK"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "google_genai-1.53.0-py3-none-any.whl", hash = "sha256:65a3f99e5c03c372d872cda7419f5940e723374bb12a2f3ffd5e3e56e8eb2094"},
|
||||
{file = "google_genai-1.53.0.tar.gz", hash = "sha256:938a26d22f3fd32c6eeeb4276ef204ef82884e63af9842ce3eac05ceb39cbd8d"},
|
||||
{file = "google_genai-1.32.0-py3-none-any.whl", hash = "sha256:c0c4b1d45adf3aa99501050dd73da2f0dea09374002231052d81a6765d15e7f6"},
|
||||
{file = "google_genai-1.32.0.tar.gz", hash = "sha256:349da3f5ff0e981066bd508585fcdd308d28fc4646f318c8f6d1aa6041f4c7e3"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
anyio = ">=4.8.0,<5.0.0"
|
||||
google-auth = {version = ">=2.14.1,<3.0.0", extras = ["requests"]}
|
||||
google-auth = ">=2.14.1,<3.0.0"
|
||||
httpx = ">=0.28.1,<1.0.0"
|
||||
pydantic = ">=2.9.0,<3.0.0"
|
||||
pydantic = ">=2.0.0,<3.0.0"
|
||||
requests = ">=2.28.1,<3.0.0"
|
||||
tenacity = ">=8.2.3,<9.2.0"
|
||||
typing-extensions = ">=4.11.0,<5.0.0"
|
||||
websockets = ">=13.0.0,<15.1.0"
|
||||
|
||||
[package.extras]
|
||||
aiohttp = ["aiohttp (<3.13.3)"]
|
||||
local-tokenizer = ["protobuf", "sentencepiece (>=0.2.0)"]
|
||||
aiohttp = ["aiohttp (<4.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "google-resumable-media"
|
||||
@@ -3057,8 +3055,6 @@ files = [
|
||||
{file = "greenlet-3.2.4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c2ca18a03a8cfb5b25bc1cbe20f3d9a4c80d8c3b13ba3df49ac3961af0b1018d"},
|
||||
{file = "greenlet-3.2.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9fe0a28a7b952a21e2c062cd5756d34354117796c6d9215a87f55e38d15402c5"},
|
||||
{file = "greenlet-3.2.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8854167e06950ca75b898b104b63cc646573aa5fef1353d4508ecdd1ee76254f"},
|
||||
{file = "greenlet-3.2.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f47617f698838ba98f4ff4189aef02e7343952df3a615f847bb575c3feb177a7"},
|
||||
{file = "greenlet-3.2.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:af41be48a4f60429d5cad9d22175217805098a9ef7c40bfef44f7669fb9d74d8"},
|
||||
{file = "greenlet-3.2.4-cp310-cp310-win_amd64.whl", hash = "sha256:73f49b5368b5359d04e18d15828eecc1806033db5233397748f4ca813ff1056c"},
|
||||
{file = "greenlet-3.2.4-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:96378df1de302bc38e99c3a9aa311967b7dc80ced1dcc6f171e99842987882a2"},
|
||||
{file = "greenlet-3.2.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1ee8fae0519a337f2329cb78bd7a8e128ec0f881073d43f023c7b8d4831d5246"},
|
||||
@@ -3068,8 +3064,6 @@ files = [
|
||||
{file = "greenlet-3.2.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2523e5246274f54fdadbce8494458a2ebdcdbc7b802318466ac5606d3cded1f8"},
|
||||
{file = "greenlet-3.2.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1987de92fec508535687fb807a5cea1560f6196285a4cde35c100b8cd632cc52"},
|
||||
{file = "greenlet-3.2.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:55e9c5affaa6775e2c6b67659f3a71684de4c549b3dd9afca3bc773533d284fa"},
|
||||
{file = "greenlet-3.2.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c9c6de1940a7d828635fbd254d69db79e54619f165ee7ce32fda763a9cb6a58c"},
|
||||
{file = "greenlet-3.2.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:03c5136e7be905045160b1b9fdca93dd6727b180feeafda6818e6496434ed8c5"},
|
||||
{file = "greenlet-3.2.4-cp311-cp311-win_amd64.whl", hash = "sha256:9c40adce87eaa9ddb593ccb0fa6a07caf34015a29bf8d344811665b573138db9"},
|
||||
{file = "greenlet-3.2.4-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:3b67ca49f54cede0186854a008109d6ee71f66bd57bb36abd6d0a0267b540cdd"},
|
||||
{file = "greenlet-3.2.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ddf9164e7a5b08e9d22511526865780a576f19ddd00d62f8a665949327fde8bb"},
|
||||
@@ -3079,8 +3073,6 @@ files = [
|
||||
{file = "greenlet-3.2.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b3812d8d0c9579967815af437d96623f45c0f2ae5f04e366de62a12d83a8fb0"},
|
||||
{file = "greenlet-3.2.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:abbf57b5a870d30c4675928c37278493044d7c14378350b3aa5d484fa65575f0"},
|
||||
{file = "greenlet-3.2.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:20fb936b4652b6e307b8f347665e2c615540d4b42b3b4c8a321d8286da7e520f"},
|
||||
{file = "greenlet-3.2.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ee7a6ec486883397d70eec05059353b8e83eca9168b9f3f9a361971e77e0bcd0"},
|
||||
{file = "greenlet-3.2.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:326d234cbf337c9c3def0676412eb7040a35a768efc92504b947b3e9cfc7543d"},
|
||||
{file = "greenlet-3.2.4-cp312-cp312-win_amd64.whl", hash = "sha256:a7d4e128405eea3814a12cc2605e0e6aedb4035bf32697f72deca74de4105e02"},
|
||||
{file = "greenlet-3.2.4-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:1a921e542453fe531144e91e1feedf12e07351b1cf6c9e8a3325ea600a715a31"},
|
||||
{file = "greenlet-3.2.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cd3c8e693bff0fff6ba55f140bf390fa92c994083f838fece0f63be121334945"},
|
||||
@@ -3090,8 +3082,6 @@ files = [
|
||||
{file = "greenlet-3.2.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:23768528f2911bcd7e475210822ffb5254ed10d71f4028387e5a99b4c6699671"},
|
||||
{file = "greenlet-3.2.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:00fadb3fedccc447f517ee0d3fd8fe49eae949e1cd0f6a611818f4f6fb7dc83b"},
|
||||
{file = "greenlet-3.2.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:d25c5091190f2dc0eaa3f950252122edbbadbb682aa7b1ef2f8af0f8c0afefae"},
|
||||
{file = "greenlet-3.2.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6e343822feb58ac4d0a1211bd9399de2b3a04963ddeec21530fc426cc121f19b"},
|
||||
{file = "greenlet-3.2.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ca7f6f1f2649b89ce02f6f229d7c19f680a6238af656f61e0115b24857917929"},
|
||||
{file = "greenlet-3.2.4-cp313-cp313-win_amd64.whl", hash = "sha256:554b03b6e73aaabec3745364d6239e9e012d64c68ccd0b8430c64ccc14939a8b"},
|
||||
{file = "greenlet-3.2.4-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:49a30d5fda2507ae77be16479bdb62a660fa51b1eb4928b524975b3bde77b3c0"},
|
||||
{file = "greenlet-3.2.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:299fd615cd8fc86267b47597123e3f43ad79c9d8a22bebdce535e53550763e2f"},
|
||||
@@ -3099,8 +3089,6 @@ files = [
|
||||
{file = "greenlet-3.2.4-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:b4a1870c51720687af7fa3e7cda6d08d801dae660f75a76f3845b642b4da6ee1"},
|
||||
{file = "greenlet-3.2.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:061dc4cf2c34852b052a8620d40f36324554bc192be474b9e9770e8c042fd735"},
|
||||
{file = "greenlet-3.2.4-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:44358b9bf66c8576a9f57a590d5f5d6e72fa4228b763d0e43fee6d3b06d3a337"},
|
||||
{file = "greenlet-3.2.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:2917bdf657f5859fbf3386b12d68ede4cf1f04c90c3a6bc1f013dd68a22e2269"},
|
||||
{file = "greenlet-3.2.4-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:015d48959d4add5d6c9f6c5210ee3803a830dce46356e3bc326d6776bde54681"},
|
||||
{file = "greenlet-3.2.4-cp314-cp314-win_amd64.whl", hash = "sha256:e37ab26028f12dbb0ff65f29a8d3d44a765c61e729647bf2ddfbbed621726f01"},
|
||||
{file = "greenlet-3.2.4-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:b6a7c19cf0d2742d0809a4c05975db036fdff50cd294a93632d6a310bf9ac02c"},
|
||||
{file = "greenlet-3.2.4-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:27890167f55d2387576d1f41d9487ef171849ea0359ce1510ca6e06c8bece11d"},
|
||||
@@ -3110,8 +3098,6 @@ files = [
|
||||
{file = "greenlet-3.2.4-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c9913f1a30e4526f432991f89ae263459b1c64d1608c0d22a5c79c287b3c70df"},
|
||||
{file = "greenlet-3.2.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b90654e092f928f110e0007f572007c9727b5265f7632c2fa7415b4689351594"},
|
||||
{file = "greenlet-3.2.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:81701fd84f26330f0d5f4944d4e92e61afe6319dcd9775e39396e39d7c3e5f98"},
|
||||
{file = "greenlet-3.2.4-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:28a3c6b7cd72a96f61b0e4b2a36f681025b60ae4779cc73c1535eb5f29560b10"},
|
||||
{file = "greenlet-3.2.4-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:52206cd642670b0b320a1fd1cbfd95bca0e043179c1d8a045f2c6109dfe973be"},
|
||||
{file = "greenlet-3.2.4-cp39-cp39-win32.whl", hash = "sha256:65458b409c1ed459ea899e939f0e1cdb14f58dbc803f2f93c5eab5694d32671b"},
|
||||
{file = "greenlet-3.2.4-cp39-cp39-win_amd64.whl", hash = "sha256:d2e685ade4dafd447ede19c31277a224a239a0a1a4eca4e6390efedf20260cfb"},
|
||||
{file = "greenlet-3.2.4.tar.gz", hash = "sha256:0dca0d95ff849f9a364385f36ab49f50065d76964944638be9691e1832e9f86d"},
|
||||
@@ -3180,87 +3166,83 @@ protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4
|
||||
|
||||
[[package]]
|
||||
name = "grpcio"
|
||||
version = "1.67.1"
|
||||
version = "1.74.0"
|
||||
description = "HTTP/2-based RPC framework"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "grpcio-1.67.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:8b0341d66a57f8a3119b77ab32207072be60c9bf79760fa609c5609f2deb1f3f"},
|
||||
{file = "grpcio-1.67.1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:f5a27dddefe0e2357d3e617b9079b4bfdc91341a91565111a21ed6ebbc51b22d"},
|
||||
{file = "grpcio-1.67.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:43112046864317498a33bdc4797ae6a268c36345a910de9b9c17159d8346602f"},
|
||||
{file = "grpcio-1.67.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c9b929f13677b10f63124c1a410994a401cdd85214ad83ab67cc077fc7e480f0"},
|
||||
{file = "grpcio-1.67.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7d1797a8a3845437d327145959a2c0c47c05947c9eef5ff1a4c80e499dcc6fa"},
|
||||
{file = "grpcio-1.67.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:0489063974d1452436139501bf6b180f63d4977223ee87488fe36858c5725292"},
|
||||
{file = "grpcio-1.67.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9fd042de4a82e3e7aca44008ee2fb5da01b3e5adb316348c21980f7f58adc311"},
|
||||
{file = "grpcio-1.67.1-cp310-cp310-win32.whl", hash = "sha256:638354e698fd0c6c76b04540a850bf1db27b4d2515a19fcd5cf645c48d3eb1ed"},
|
||||
{file = "grpcio-1.67.1-cp310-cp310-win_amd64.whl", hash = "sha256:608d87d1bdabf9e2868b12338cd38a79969eaf920c89d698ead08f48de9c0f9e"},
|
||||
{file = "grpcio-1.67.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:7818c0454027ae3384235a65210bbf5464bd715450e30a3d40385453a85a70cb"},
|
||||
{file = "grpcio-1.67.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ea33986b70f83844cd00814cee4451055cd8cab36f00ac64a31f5bb09b31919e"},
|
||||
{file = "grpcio-1.67.1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:c7a01337407dd89005527623a4a72c5c8e2894d22bead0895306b23c6695698f"},
|
||||
{file = "grpcio-1.67.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:80b866f73224b0634f4312a4674c1be21b2b4afa73cb20953cbbb73a6b36c3cc"},
|
||||
{file = "grpcio-1.67.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9fff78ba10d4250bfc07a01bd6254a6d87dc67f9627adece85c0b2ed754fa96"},
|
||||
{file = "grpcio-1.67.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:8a23cbcc5bb11ea7dc6163078be36c065db68d915c24f5faa4f872c573bb400f"},
|
||||
{file = "grpcio-1.67.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1a65b503d008f066e994f34f456e0647e5ceb34cfcec5ad180b1b44020ad4970"},
|
||||
{file = "grpcio-1.67.1-cp311-cp311-win32.whl", hash = "sha256:e29ca27bec8e163dca0c98084040edec3bc49afd10f18b412f483cc68c712744"},
|
||||
{file = "grpcio-1.67.1-cp311-cp311-win_amd64.whl", hash = "sha256:786a5b18544622bfb1e25cc08402bd44ea83edfb04b93798d85dca4d1a0b5be5"},
|
||||
{file = "grpcio-1.67.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:267d1745894200e4c604958da5f856da6293f063327cb049a51fe67348e4f953"},
|
||||
{file = "grpcio-1.67.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:85f69fdc1d28ce7cff8de3f9c67db2b0ca9ba4449644488c1e0303c146135ddb"},
|
||||
{file = "grpcio-1.67.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:f26b0b547eb8d00e195274cdfc63ce64c8fc2d3e2d00b12bf468ece41a0423a0"},
|
||||
{file = "grpcio-1.67.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4422581cdc628f77302270ff839a44f4c24fdc57887dc2a45b7e53d8fc2376af"},
|
||||
{file = "grpcio-1.67.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d7616d2ded471231c701489190379e0c311ee0a6c756f3c03e6a62b95a7146e"},
|
||||
{file = "grpcio-1.67.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8a00efecde9d6fcc3ab00c13f816313c040a28450e5e25739c24f432fc6d3c75"},
|
||||
{file = "grpcio-1.67.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:699e964923b70f3101393710793289e42845791ea07565654ada0969522d0a38"},
|
||||
{file = "grpcio-1.67.1-cp312-cp312-win32.whl", hash = "sha256:4e7b904484a634a0fff132958dabdb10d63e0927398273917da3ee103e8d1f78"},
|
||||
{file = "grpcio-1.67.1-cp312-cp312-win_amd64.whl", hash = "sha256:5721e66a594a6c4204458004852719b38f3d5522082be9061d6510b455c90afc"},
|
||||
{file = "grpcio-1.67.1-cp313-cp313-linux_armv7l.whl", hash = "sha256:aa0162e56fd10a5547fac8774c4899fc3e18c1aa4a4759d0ce2cd00d3696ea6b"},
|
||||
{file = "grpcio-1.67.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:beee96c8c0b1a75d556fe57b92b58b4347c77a65781ee2ac749d550f2a365dc1"},
|
||||
{file = "grpcio-1.67.1-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:a93deda571a1bf94ec1f6fcda2872dad3ae538700d94dc283c672a3b508ba3af"},
|
||||
{file = "grpcio-1.67.1-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0e6f255980afef598a9e64a24efce87b625e3e3c80a45162d111a461a9f92955"},
|
||||
{file = "grpcio-1.67.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e838cad2176ebd5d4a8bb03955138d6589ce9e2ce5d51c3ada34396dbd2dba8"},
|
||||
{file = "grpcio-1.67.1-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:a6703916c43b1d468d0756c8077b12017a9fcb6a1ef13faf49e67d20d7ebda62"},
|
||||
{file = "grpcio-1.67.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:917e8d8994eed1d86b907ba2a61b9f0aef27a2155bca6cbb322430fc7135b7bb"},
|
||||
{file = "grpcio-1.67.1-cp313-cp313-win32.whl", hash = "sha256:e279330bef1744040db8fc432becc8a727b84f456ab62b744d3fdb83f327e121"},
|
||||
{file = "grpcio-1.67.1-cp313-cp313-win_amd64.whl", hash = "sha256:fa0c739ad8b1996bd24823950e3cb5152ae91fca1c09cc791190bf1627ffefba"},
|
||||
{file = "grpcio-1.67.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:178f5db771c4f9a9facb2ab37a434c46cb9be1a75e820f187ee3d1e7805c4f65"},
|
||||
{file = "grpcio-1.67.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:0f3e49c738396e93b7ba9016e153eb09e0778e776df6090c1b8c91877cc1c426"},
|
||||
{file = "grpcio-1.67.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:24e8a26dbfc5274d7474c27759b54486b8de23c709d76695237515bc8b5baeab"},
|
||||
{file = "grpcio-1.67.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3b6c16489326d79ead41689c4b84bc40d522c9a7617219f4ad94bc7f448c5085"},
|
||||
{file = "grpcio-1.67.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:60e6a4dcf5af7bbc36fd9f81c9f372e8ae580870a9e4b6eafe948cd334b81cf3"},
|
||||
{file = "grpcio-1.67.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:95b5f2b857856ed78d72da93cd7d09b6db8ef30102e5e7fe0961fe4d9f7d48e8"},
|
||||
{file = "grpcio-1.67.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b49359977c6ec9f5d0573ea4e0071ad278ef905aa74e420acc73fd28ce39e9ce"},
|
||||
{file = "grpcio-1.67.1-cp38-cp38-win32.whl", hash = "sha256:f5b76ff64aaac53fede0cc93abf57894ab2a7362986ba22243d06218b93efe46"},
|
||||
{file = "grpcio-1.67.1-cp38-cp38-win_amd64.whl", hash = "sha256:804c6457c3cd3ec04fe6006c739579b8d35c86ae3298ffca8de57b493524b771"},
|
||||
{file = "grpcio-1.67.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:a25bdea92b13ff4d7790962190bf6bf5c4639876e01c0f3dda70fc2769616335"},
|
||||
{file = "grpcio-1.67.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cdc491ae35a13535fd9196acb5afe1af37c8237df2e54427be3eecda3653127e"},
|
||||
{file = "grpcio-1.67.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:85f862069b86a305497e74d0dc43c02de3d1d184fc2c180993aa8aa86fbd19b8"},
|
||||
{file = "grpcio-1.67.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ec74ef02010186185de82cc594058a3ccd8d86821842bbac9873fd4a2cf8be8d"},
|
||||
{file = "grpcio-1.67.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01f616a964e540638af5130469451cf580ba8c7329f45ca998ab66e0c7dcdb04"},
|
||||
{file = "grpcio-1.67.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:299b3d8c4f790c6bcca485f9963b4846dd92cf6f1b65d3697145d005c80f9fe8"},
|
||||
{file = "grpcio-1.67.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:60336bff760fbb47d7e86165408126f1dded184448e9a4c892189eb7c9d3f90f"},
|
||||
{file = "grpcio-1.67.1-cp39-cp39-win32.whl", hash = "sha256:5ed601c4c6008429e3d247ddb367fe8c7259c355757448d7c1ef7bd4a6739e8e"},
|
||||
{file = "grpcio-1.67.1-cp39-cp39-win_amd64.whl", hash = "sha256:5db70d32d6703b89912af16d6d45d78406374a8b8ef0d28140351dd0ec610e98"},
|
||||
{file = "grpcio-1.67.1.tar.gz", hash = "sha256:3dc2ed4cabea4dc14d5e708c2b426205956077cc5de419b4d4079315017e9732"},
|
||||
{file = "grpcio-1.74.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:85bd5cdf4ed7b2d6438871adf6afff9af7096486fcf51818a81b77ef4dd30907"},
|
||||
{file = "grpcio-1.74.0-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:68c8ebcca945efff9d86d8d6d7bfb0841cf0071024417e2d7f45c5e46b5b08eb"},
|
||||
{file = "grpcio-1.74.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:e154d230dc1bbbd78ad2fdc3039fa50ad7ffcf438e4eb2fa30bce223a70c7486"},
|
||||
{file = "grpcio-1.74.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e8978003816c7b9eabe217f88c78bc26adc8f9304bf6a594b02e5a49b2ef9c11"},
|
||||
{file = "grpcio-1.74.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3d7bd6e3929fd2ea7fbc3f562e4987229ead70c9ae5f01501a46701e08f1ad9"},
|
||||
{file = "grpcio-1.74.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:136b53c91ac1d02c8c24201bfdeb56f8b3ac3278668cbb8e0ba49c88069e1bdc"},
|
||||
{file = "grpcio-1.74.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:fe0f540750a13fd8e5da4b3eaba91a785eea8dca5ccd2bc2ffe978caa403090e"},
|
||||
{file = "grpcio-1.74.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4e4181bfc24413d1e3a37a0b7889bea68d973d4b45dd2bc68bb766c140718f82"},
|
||||
{file = "grpcio-1.74.0-cp310-cp310-win32.whl", hash = "sha256:1733969040989f7acc3d94c22f55b4a9501a30f6aaacdbccfaba0a3ffb255ab7"},
|
||||
{file = "grpcio-1.74.0-cp310-cp310-win_amd64.whl", hash = "sha256:9e912d3c993a29df6c627459af58975b2e5c897d93287939b9d5065f000249b5"},
|
||||
{file = "grpcio-1.74.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:69e1a8180868a2576f02356565f16635b99088da7df3d45aaa7e24e73a054e31"},
|
||||
{file = "grpcio-1.74.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:8efe72fde5500f47aca1ef59495cb59c885afe04ac89dd11d810f2de87d935d4"},
|
||||
{file = "grpcio-1.74.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:a8f0302f9ac4e9923f98d8e243939a6fb627cd048f5cd38595c97e38020dffce"},
|
||||
{file = "grpcio-1.74.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2f609a39f62a6f6f05c7512746798282546358a37ea93c1fcbadf8b2fed162e3"},
|
||||
{file = "grpcio-1.74.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c98e0b7434a7fa4e3e63f250456eaef52499fba5ae661c58cc5b5477d11e7182"},
|
||||
{file = "grpcio-1.74.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:662456c4513e298db6d7bd9c3b8df6f75f8752f0ba01fb653e252ed4a59b5a5d"},
|
||||
{file = "grpcio-1.74.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:3d14e3c4d65e19d8430a4e28ceb71ace4728776fd6c3ce34016947474479683f"},
|
||||
{file = "grpcio-1.74.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1bf949792cee20d2078323a9b02bacbbae002b9e3b9e2433f2741c15bdeba1c4"},
|
||||
{file = "grpcio-1.74.0-cp311-cp311-win32.whl", hash = "sha256:55b453812fa7c7ce2f5c88be3018fb4a490519b6ce80788d5913f3f9d7da8c7b"},
|
||||
{file = "grpcio-1.74.0-cp311-cp311-win_amd64.whl", hash = "sha256:86ad489db097141a907c559988c29718719aa3e13370d40e20506f11b4de0d11"},
|
||||
{file = "grpcio-1.74.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:8533e6e9c5bd630ca98062e3a1326249e6ada07d05acf191a77bc33f8948f3d8"},
|
||||
{file = "grpcio-1.74.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:2918948864fec2a11721d91568effffbe0a02b23ecd57f281391d986847982f6"},
|
||||
{file = "grpcio-1.74.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:60d2d48b0580e70d2e1954d0d19fa3c2e60dd7cbed826aca104fff518310d1c5"},
|
||||
{file = "grpcio-1.74.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3601274bc0523f6dc07666c0e01682c94472402ac2fd1226fd96e079863bfa49"},
|
||||
{file = "grpcio-1.74.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:176d60a5168d7948539def20b2a3adcce67d72454d9ae05969a2e73f3a0feee7"},
|
||||
{file = "grpcio-1.74.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e759f9e8bc908aaae0412642afe5416c9f983a80499448fcc7fab8692ae044c3"},
|
||||
{file = "grpcio-1.74.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:9e7c4389771855a92934b2846bd807fc25a3dfa820fd912fe6bd8136026b2707"},
|
||||
{file = "grpcio-1.74.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:cce634b10aeab37010449124814b05a62fb5f18928ca878f1bf4750d1f0c815b"},
|
||||
{file = "grpcio-1.74.0-cp312-cp312-win32.whl", hash = "sha256:885912559974df35d92219e2dc98f51a16a48395f37b92865ad45186f294096c"},
|
||||
{file = "grpcio-1.74.0-cp312-cp312-win_amd64.whl", hash = "sha256:42f8fee287427b94be63d916c90399ed310ed10aadbf9e2e5538b3e497d269bc"},
|
||||
{file = "grpcio-1.74.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:2bc2d7d8d184e2362b53905cb1708c84cb16354771c04b490485fa07ce3a1d89"},
|
||||
{file = "grpcio-1.74.0-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:c14e803037e572c177ba54a3e090d6eb12efd795d49327c5ee2b3bddb836bf01"},
|
||||
{file = "grpcio-1.74.0-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:f6ec94f0e50eb8fa1744a731088b966427575e40c2944a980049798b127a687e"},
|
||||
{file = "grpcio-1.74.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:566b9395b90cc3d0d0c6404bc8572c7c18786ede549cdb540ae27b58afe0fb91"},
|
||||
{file = "grpcio-1.74.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1ea6176d7dfd5b941ea01c2ec34de9531ba494d541fe2057c904e601879f249"},
|
||||
{file = "grpcio-1.74.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:64229c1e9cea079420527fa8ac45d80fc1e8d3f94deaa35643c381fa8d98f362"},
|
||||
{file = "grpcio-1.74.0-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:0f87bddd6e27fc776aacf7ebfec367b6d49cad0455123951e4488ea99d9b9b8f"},
|
||||
{file = "grpcio-1.74.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:3b03d8f2a07f0fea8c8f74deb59f8352b770e3900d143b3d1475effcb08eec20"},
|
||||
{file = "grpcio-1.74.0-cp313-cp313-win32.whl", hash = "sha256:b6a73b2ba83e663b2480a90b82fdae6a7aa6427f62bf43b29912c0cfd1aa2bfa"},
|
||||
{file = "grpcio-1.74.0-cp313-cp313-win_amd64.whl", hash = "sha256:fd3c71aeee838299c5887230b8a1822795325ddfea635edd82954c1eaa831e24"},
|
||||
{file = "grpcio-1.74.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:4bc5fca10aaf74779081e16c2bcc3d5ec643ffd528d9e7b1c9039000ead73bae"},
|
||||
{file = "grpcio-1.74.0-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:6bab67d15ad617aff094c382c882e0177637da73cbc5532d52c07b4ee887a87b"},
|
||||
{file = "grpcio-1.74.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:655726919b75ab3c34cdad39da5c530ac6fa32696fb23119e36b64adcfca174a"},
|
||||
{file = "grpcio-1.74.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1a2b06afe2e50ebfd46247ac3ba60cac523f54ec7792ae9ba6073c12daf26f0a"},
|
||||
{file = "grpcio-1.74.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f251c355167b2360537cf17bea2cf0197995e551ab9da6a0a59b3da5e8704f9"},
|
||||
{file = "grpcio-1.74.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:8f7b5882fb50632ab1e48cb3122d6df55b9afabc265582808036b6e51b9fd6b7"},
|
||||
{file = "grpcio-1.74.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:834988b6c34515545b3edd13e902c1acdd9f2465d386ea5143fb558f153a7176"},
|
||||
{file = "grpcio-1.74.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:22b834cef33429ca6cc28303c9c327ba9a3fafecbf62fae17e9a7b7163cc43ac"},
|
||||
{file = "grpcio-1.74.0-cp39-cp39-win32.whl", hash = "sha256:7d95d71ff35291bab3f1c52f52f474c632db26ea12700c2ff0ea0532cb0b5854"},
|
||||
{file = "grpcio-1.74.0-cp39-cp39-win_amd64.whl", hash = "sha256:ecde9ab49f58433abe02f9ed076c7b5be839cf0153883a6d23995937a82392fa"},
|
||||
{file = "grpcio-1.74.0.tar.gz", hash = "sha256:80d1f4fbb35b0742d3e3d3bb654b7381cd5f015f8497279a1e9c21ba623e01b1"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
protobuf = ["grpcio-tools (>=1.67.1)"]
|
||||
protobuf = ["grpcio-tools (>=1.74.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "grpcio-status"
|
||||
version = "1.67.1"
|
||||
version = "1.71.2"
|
||||
description = "Status proto mapping for gRPC"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "grpcio_status-1.67.1-py3-none-any.whl", hash = "sha256:16e6c085950bdacac97c779e6a502ea671232385e6e37f258884d6883392c2bd"},
|
||||
{file = "grpcio_status-1.67.1.tar.gz", hash = "sha256:2bf38395e028ceeecfd8866b081f61628114b384da7d51ae064ddc8d766a5d11"},
|
||||
{file = "grpcio_status-1.71.2-py3-none-any.whl", hash = "sha256:803c98cb6a8b7dc6dbb785b1111aed739f241ab5e9da0bba96888aa74704cfd3"},
|
||||
{file = "grpcio_status-1.71.2.tar.gz", hash = "sha256:c7a97e176df71cdc2c179cd1847d7fc86cca5832ad12e9798d7fed6b7a1aab50"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
googleapis-common-protos = ">=1.5.5"
|
||||
grpcio = ">=1.67.1"
|
||||
grpcio = ">=1.71.2"
|
||||
protobuf = ">=5.26.1,<6.0dev"
|
||||
|
||||
[[package]]
|
||||
@@ -4558,39 +4540,42 @@ valkey = ["valkey (>=6)"]
|
||||
|
||||
[[package]]
|
||||
name = "litellm"
|
||||
version = "1.80.11"
|
||||
version = "1.77.7"
|
||||
description = "Library to easily interface with LLM API providers"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
python-versions = ">=3.8.1,<4.0, !=3.9.7"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "litellm-1.80.11-py3-none-any.whl", hash = "sha256:406283d66ead77dc7ff0e0b2559c80e9e497d8e7c2257efb1cb9210a20d09d54"},
|
||||
{file = "litellm-1.80.11.tar.gz", hash = "sha256:c9fc63e7acb6360363238fe291bcff1488c59ff66020416d8376c0ee56414a19"},
|
||||
]
|
||||
files = []
|
||||
develop = false
|
||||
|
||||
[package.dependencies]
|
||||
aiohttp = ">=3.10"
|
||||
click = "*"
|
||||
fastuuid = ">=0.13.0"
|
||||
grpcio = {version = ">=1.62.3,<1.68.0", markers = "python_version < \"3.14\""}
|
||||
httpx = ">=0.23.0"
|
||||
importlib-metadata = ">=6.8.0"
|
||||
jinja2 = ">=3.1.2,<4.0.0"
|
||||
jsonschema = ">=4.23.0,<5.0.0"
|
||||
openai = ">=2.8.0"
|
||||
pydantic = ">=2.5.0,<3.0.0"
|
||||
jinja2 = "^3.1.2"
|
||||
jsonschema = "^4.22.0"
|
||||
openai = ">=1.99.5"
|
||||
pydantic = "^2.5.0"
|
||||
python-dotenv = ">=0.2.0"
|
||||
tiktoken = ">=0.7.0"
|
||||
tokenizers = "*"
|
||||
|
||||
[package.extras]
|
||||
caching = ["diskcache (>=5.6.1,<6.0.0)"]
|
||||
extra-proxy = ["azure-identity (>=1.15.0,<2.0.0) ; python_version >= \"3.9\"", "azure-keyvault-secrets (>=4.8.0,<5.0.0)", "google-cloud-iam (>=2.19.1,<3.0.0)", "google-cloud-kms (>=2.21.3,<3.0.0)", "prisma (==0.11.0)", "redisvl (>=0.4.1,<0.5.0) ; python_version >= \"3.9\" and python_version < \"3.14\"", "resend (>=0.8.0)"]
|
||||
extra-proxy = ["azure-identity (>=1.15.0,<2.0.0)", "azure-keyvault-secrets (>=4.8.0,<5.0.0)", "google-cloud-iam (>=2.19.1,<3.0.0)", "google-cloud-kms (>=2.21.3,<3.0.0)", "prisma (==0.11.0)", "redisvl (>=0.4.1,<0.5.0) ; python_version >= \"3.9\" and python_version < \"3.14\"", "resend (>=0.8.0,<0.9.0)"]
|
||||
mlflow = ["mlflow (>3.1.4) ; python_version >= \"3.10\""]
|
||||
proxy = ["PyJWT (>=2.10.1,<3.0.0) ; python_version >= \"3.9\"", "apscheduler (>=3.10.4,<4.0.0)", "azure-identity (>=1.15.0,<2.0.0) ; python_version >= \"3.9\"", "azure-storage-blob (>=12.25.1,<13.0.0)", "backoff", "boto3 (==1.36.0)", "cryptography", "fastapi (>=0.120.1)", "fastapi-sso (>=0.16.0,<0.17.0)", "gunicorn (>=23.0.0,<24.0.0)", "litellm-enterprise (==0.1.27)", "litellm-proxy-extras (==0.4.16)", "mcp (>=1.21.2,<2.0.0) ; python_version >= \"3.10\"", "orjson (>=3.9.7,<4.0.0)", "polars (>=1.31.0,<2.0.0) ; python_version >= \"3.10\"", "pynacl (>=1.5.0,<2.0.0)", "python-multipart (>=0.0.18,<0.0.19)", "pyyaml (>=6.0.1,<7.0.0)", "rich (==13.7.1)", "rq", "soundfile (>=0.12.1,<0.13.0)", "uvicorn (>=0.31.1,<0.32.0)", "uvloop (>=0.21.0,<0.22.0) ; sys_platform != \"win32\"", "websockets (>=15.0.1,<16.0.0)"]
|
||||
semantic-router = ["semantic-router (>=0.1.12) ; python_version >= \"3.9\" and python_version < \"3.14\""]
|
||||
proxy = ["PyJWT (>=2.8.0,<3.0.0)", "apscheduler (>=3.10.4,<4.0.0)", "azure-identity (>=1.15.0,<2.0.0)", "azure-storage-blob (>=12.25.1,<13.0.0)", "backoff", "boto3 (==1.36.0)", "cryptography", "fastapi (>=0.115.5,<0.116.0)", "fastapi-sso (>=0.16.0,<0.17.0)", "gunicorn (>=23.0.0,<24.0.0)", "litellm-enterprise (==0.1.20)", "litellm-proxy-extras (==0.2.25)", "mcp (>=1.10.0,<2.0.0) ; python_version >= \"3.10\"", "orjson (>=3.9.7,<4.0.0)", "polars (>=1.31.0,<2.0.0) ; python_version >= \"3.10\"", "pynacl (>=1.5.0,<2.0.0)", "python-multipart (>=0.0.18,<0.0.19)", "pyyaml (>=6.0.1,<7.0.0)", "rich (==13.7.1)", "rq", "uvicorn (>=0.29.0,<0.30.0)", "uvloop (>=0.21.0,<0.22.0) ; sys_platform != \"win32\"", "websockets (>=13.1.0,<14.0.0)"]
|
||||
semantic-router = ["semantic-router ; python_version >= \"3.9\""]
|
||||
utils = ["numpydoc"]
|
||||
|
||||
[package.source]
|
||||
type = "git"
|
||||
url = "https://github.com/BerriAI/litellm.git"
|
||||
reference = "v1.77.7.dev9"
|
||||
resolved_reference = "763d2f8ccdd8412dbe6d4ac0e136d9ac34dcd4c0"
|
||||
|
||||
[[package]]
|
||||
name = "llvmlite"
|
||||
version = "0.44.0"
|
||||
@@ -4624,14 +4609,14 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "lmnr"
|
||||
version = "0.7.24"
|
||||
version = "0.7.20"
|
||||
description = "Python SDK for Laminar"
|
||||
optional = false
|
||||
python-versions = "<4,>=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "lmnr-0.7.24-py3-none-any.whl", hash = "sha256:ad780d4a62ece897048811f3368639c240a9329ab31027da8c96545137a3a08a"},
|
||||
{file = "lmnr-0.7.24.tar.gz", hash = "sha256:aa6973f46fc4ba95c9061c1feceb58afc02eb43c9376c21e32545371ff6123d7"},
|
||||
{file = "lmnr-0.7.20-py3-none-any.whl", hash = "sha256:5f9fa7444e6f96c25e097f66484ff29e632bdd1de0e9346948bf5595f4a8af38"},
|
||||
{file = "lmnr-0.7.20.tar.gz", hash = "sha256:1f484cd618db2d71af65f90a0b8b36d20d80dc91a5138b811575c8677bf7c4fd"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -4654,15 +4639,14 @@ tqdm = ">=4.0"
|
||||
|
||||
[package.extras]
|
||||
alephalpha = ["opentelemetry-instrumentation-alephalpha (>=0.47.1)"]
|
||||
all = ["opentelemetry-instrumentation-alephalpha (>=0.47.1)", "opentelemetry-instrumentation-bedrock (>=0.47.1)", "opentelemetry-instrumentation-chromadb (>=0.47.1)", "opentelemetry-instrumentation-cohere (>=0.47.1)", "opentelemetry-instrumentation-crewai (>=0.47.1)", "opentelemetry-instrumentation-haystack (>=0.47.1)", "opentelemetry-instrumentation-lancedb (>=0.47.1)", "opentelemetry-instrumentation-langchain (>=0.47.1,<0.48.0)", "opentelemetry-instrumentation-llamaindex (>=0.47.1)", "opentelemetry-instrumentation-marqo (>=0.47.1)", "opentelemetry-instrumentation-mcp (>=0.47.1)", "opentelemetry-instrumentation-milvus (>=0.47.1)", "opentelemetry-instrumentation-mistralai (>=0.47.1)", "opentelemetry-instrumentation-ollama (>=0.47.1)", "opentelemetry-instrumentation-pinecone (>=0.47.1)", "opentelemetry-instrumentation-qdrant (>=0.47.1)", "opentelemetry-instrumentation-replicate (>=0.47.1)", "opentelemetry-instrumentation-sagemaker (>=0.47.1)", "opentelemetry-instrumentation-together (>=0.47.1)", "opentelemetry-instrumentation-transformers (>=0.47.1)", "opentelemetry-instrumentation-vertexai (>=0.47.1)", "opentelemetry-instrumentation-watsonx (>=0.47.1)", "opentelemetry-instrumentation-weaviate (>=0.47.1)"]
|
||||
all = ["opentelemetry-instrumentation-alephalpha (>=0.47.1)", "opentelemetry-instrumentation-bedrock (>=0.47.1)", "opentelemetry-instrumentation-chromadb (>=0.47.1)", "opentelemetry-instrumentation-cohere (>=0.47.1)", "opentelemetry-instrumentation-crewai (>=0.47.1)", "opentelemetry-instrumentation-haystack (>=0.47.1)", "opentelemetry-instrumentation-lancedb (>=0.47.1)", "opentelemetry-instrumentation-langchain (>=0.47.1)", "opentelemetry-instrumentation-llamaindex (>=0.47.1)", "opentelemetry-instrumentation-marqo (>=0.47.1)", "opentelemetry-instrumentation-mcp (>=0.47.1)", "opentelemetry-instrumentation-milvus (>=0.47.1)", "opentelemetry-instrumentation-mistralai (>=0.47.1)", "opentelemetry-instrumentation-ollama (>=0.47.1)", "opentelemetry-instrumentation-pinecone (>=0.47.1)", "opentelemetry-instrumentation-qdrant (>=0.47.1)", "opentelemetry-instrumentation-replicate (>=0.47.1)", "opentelemetry-instrumentation-sagemaker (>=0.47.1)", "opentelemetry-instrumentation-together (>=0.47.1)", "opentelemetry-instrumentation-transformers (>=0.47.1)", "opentelemetry-instrumentation-vertexai (>=0.47.1)", "opentelemetry-instrumentation-watsonx (>=0.47.1)", "opentelemetry-instrumentation-weaviate (>=0.47.1)"]
|
||||
bedrock = ["opentelemetry-instrumentation-bedrock (>=0.47.1)"]
|
||||
chromadb = ["opentelemetry-instrumentation-chromadb (>=0.47.1)"]
|
||||
claude-agent-sdk = ["lmnr-claude-code-proxy (>=0.1.0a5)"]
|
||||
cohere = ["opentelemetry-instrumentation-cohere (>=0.47.1)"]
|
||||
crewai = ["opentelemetry-instrumentation-crewai (>=0.47.1)"]
|
||||
haystack = ["opentelemetry-instrumentation-haystack (>=0.47.1)"]
|
||||
lancedb = ["opentelemetry-instrumentation-lancedb (>=0.47.1)"]
|
||||
langchain = ["opentelemetry-instrumentation-langchain (>=0.47.1,<0.48.0)"]
|
||||
langchain = ["opentelemetry-instrumentation-langchain (>=0.47.1)"]
|
||||
llamaindex = ["opentelemetry-instrumentation-llamaindex (>=0.47.1)"]
|
||||
marqo = ["opentelemetry-instrumentation-marqo (>=0.47.1)"]
|
||||
mcp = ["opentelemetry-instrumentation-mcp (>=0.47.1)"]
|
||||
@@ -5660,28 +5644,28 @@ pydantic = ">=2.9"
|
||||
|
||||
[[package]]
|
||||
name = "openai"
|
||||
version = "2.8.0"
|
||||
version = "1.99.9"
|
||||
description = "The official Python library for the openai API"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main", "test"]
|
||||
files = [
|
||||
{file = "openai-2.8.0-py3-none-any.whl", hash = "sha256:ba975e347f6add2fe13529ccb94d54a578280e960765e5224c34b08d7e029ddf"},
|
||||
{file = "openai-2.8.0.tar.gz", hash = "sha256:4851908f6d6fcacbd47ba659c5ac084f7725b752b6bfa1e948b6fbfc111a6bad"},
|
||||
{file = "openai-1.99.9-py3-none-any.whl", hash = "sha256:9dbcdb425553bae1ac5d947147bebbd630d91bbfc7788394d4c4f3a35682ab3a"},
|
||||
{file = "openai-1.99.9.tar.gz", hash = "sha256:f2082d155b1ad22e83247c3de3958eb4255b20ccf4a1de2e6681b6957b554e92"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
anyio = ">=3.5.0,<5"
|
||||
distro = ">=1.7.0,<2"
|
||||
httpx = ">=0.23.0,<1"
|
||||
jiter = ">=0.10.0,<1"
|
||||
jiter = ">=0.4.0,<1"
|
||||
pydantic = ">=1.9.0,<3"
|
||||
sniffio = "*"
|
||||
tqdm = ">4"
|
||||
typing-extensions = ">=4.11,<5"
|
||||
|
||||
[package.extras]
|
||||
aiohttp = ["aiohttp", "httpx-aiohttp (>=0.1.9)"]
|
||||
aiohttp = ["aiohttp", "httpx-aiohttp (>=0.1.8)"]
|
||||
datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
|
||||
realtime = ["websockets (>=13,<16)"]
|
||||
voice-helpers = ["numpy (>=2.0.2)", "sounddevice (>=0.5.1)"]
|
||||
@@ -5836,31 +5820,35 @@ llama = ["llama-index (>=0.12.29,<0.13.0)", "llama-index-core (>=0.12.29,<0.13.0
|
||||
|
||||
[[package]]
|
||||
name = "openhands-agent-server"
|
||||
version = "1.7.1"
|
||||
version = "1.0.0a5"
|
||||
description = "OpenHands Agent Server - REST/WebSocket interface for OpenHands AI Agent"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_agent_server-1.7.1-py3-none-any.whl", hash = "sha256:e5c57f1b73293d00a68b77f9d290f59d9e2217d9df844fb01c7d2f929c3417f4"},
|
||||
{file = "openhands_agent_server-1.7.1.tar.gz", hash = "sha256:c82e1e6748ea3b4278ef2ee72f091dc37da6667c854b3aa3c0bc616086a82310"},
|
||||
]
|
||||
files = []
|
||||
develop = false
|
||||
|
||||
[package.dependencies]
|
||||
aiosqlite = ">=0.19"
|
||||
alembic = ">=1.13"
|
||||
docker = ">=7.1,<8"
|
||||
fastapi = ">=0.104"
|
||||
openhands-sdk = "*"
|
||||
pydantic = ">=2"
|
||||
sqlalchemy = ">=2"
|
||||
uvicorn = ">=0.31.1"
|
||||
websockets = ">=12"
|
||||
wsproto = ">=1.2.0"
|
||||
|
||||
[package.source]
|
||||
type = "git"
|
||||
url = "https://github.com/OpenHands/software-agent-sdk.git"
|
||||
reference = "d5995c31c55e488d4ab0372d292973bc6fad71f1"
|
||||
resolved_reference = "d5995c31c55e488d4ab0372d292973bc6fad71f1"
|
||||
subdirectory = "openhands-agent-server"
|
||||
|
||||
[[package]]
|
||||
name = "openhands-ai"
|
||||
version = "0.0.0-post.5742+ee50f333b"
|
||||
version = "0.0.0-post.5514+7c9e66194"
|
||||
description = "OpenHands: Code Less, Make More"
|
||||
optional = false
|
||||
python-versions = "^3.12,<3.14"
|
||||
@@ -5877,7 +5865,6 @@ bashlex = "^0.18"
|
||||
boto3 = "*"
|
||||
browsergym-core = "0.13.3"
|
||||
deprecated = "*"
|
||||
deprecation = "^2.1.0"
|
||||
dirhash = "*"
|
||||
docker = "*"
|
||||
fastapi = "*"
|
||||
@@ -5896,15 +5883,15 @@ json-repair = "*"
|
||||
jupyter_kernel_gateway = "*"
|
||||
kubernetes = "^33.1.0"
|
||||
libtmux = ">=0.46.2"
|
||||
litellm = ">=1.74.3, !=1.64.4, !=1.67.*"
|
||||
litellm = ">=1.74.3, <1.78.0, !=1.64.4, !=1.67.*"
|
||||
lmnr = "^0.7.20"
|
||||
memory-profiler = "^0.61.0"
|
||||
numpy = "*"
|
||||
openai = "2.8.0"
|
||||
openai = "1.99.9"
|
||||
openhands-aci = "0.3.2"
|
||||
openhands-agent-server = "1.7.1"
|
||||
openhands-sdk = "1.7.1"
|
||||
openhands-tools = "1.7.1"
|
||||
openhands-agent-server = {git = "https://github.com/OpenHands/software-agent-sdk.git", rev = "d5995c31c55e488d4ab0372d292973bc6fad71f1", subdirectory = "openhands-agent-server"}
|
||||
openhands-sdk = {git = "https://github.com/OpenHands/software-agent-sdk.git", rev = "d5995c31c55e488d4ab0372d292973bc6fad71f1", subdirectory = "openhands-sdk"}
|
||||
openhands-tools = {git = "https://github.com/OpenHands/software-agent-sdk.git", rev = "d5995c31c55e488d4ab0372d292973bc6fad71f1", subdirectory = "openhands-tools"}
|
||||
opentelemetry-api = "^1.33.1"
|
||||
opentelemetry-exporter-otlp-proto-grpc = "^1.33.1"
|
||||
pathspec = "^0.12.1"
|
||||
@@ -5960,22 +5947,19 @@ url = ".."
|
||||
|
||||
[[package]]
|
||||
name = "openhands-sdk"
|
||||
version = "1.7.1"
|
||||
version = "1.0.0a5"
|
||||
description = "OpenHands SDK - Core functionality for building AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_sdk-1.7.1-py3-none-any.whl", hash = "sha256:e097e34dfbd45f38225ae2ff4830702424bcf742bc197b5a811540a75265b135"},
|
||||
{file = "openhands_sdk-1.7.1.tar.gz", hash = "sha256:e13d1fe8bf14dffd91e9080608072a989132c981cf9bfcd124fa4f7a68a13691"},
|
||||
]
|
||||
files = []
|
||||
develop = false
|
||||
|
||||
[package.dependencies]
|
||||
deprecation = ">=2.1.0"
|
||||
fastmcp = ">=2.11.3"
|
||||
httpx = ">=0.27.0"
|
||||
litellm = ">=1.80.10"
|
||||
lmnr = ">=0.7.24"
|
||||
litellm = ">=1.77.7.dev9"
|
||||
lmnr = ">=0.7.20"
|
||||
pydantic = ">=2.11.7"
|
||||
python-frontmatter = ">=1.1.0"
|
||||
python-json-logger = ">=3.3.0"
|
||||
@@ -5985,17 +5969,22 @@ websockets = ">=12"
|
||||
[package.extras]
|
||||
boto3 = ["boto3 (>=1.35.0)"]
|
||||
|
||||
[package.source]
|
||||
type = "git"
|
||||
url = "https://github.com/OpenHands/software-agent-sdk.git"
|
||||
reference = "d5995c31c55e488d4ab0372d292973bc6fad71f1"
|
||||
resolved_reference = "d5995c31c55e488d4ab0372d292973bc6fad71f1"
|
||||
subdirectory = "openhands-sdk"
|
||||
|
||||
[[package]]
|
||||
name = "openhands-tools"
|
||||
version = "1.7.1"
|
||||
version = "1.0.0a5"
|
||||
description = "OpenHands Tools - Runtime tools for AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_tools-1.7.1-py3-none-any.whl", hash = "sha256:e25815f24925e94fbd4d8c3fd9b2147a0556fde595bf4f80a7dbba1014ea3c86"},
|
||||
{file = "openhands_tools-1.7.1.tar.gz", hash = "sha256:f3823f7bd302c78969c454730cf793eb63109ce2d986e78585989c53986cc966"},
|
||||
]
|
||||
files = []
|
||||
develop = false
|
||||
|
||||
[package.dependencies]
|
||||
bashlex = ">=0.18"
|
||||
@@ -6006,7 +5995,13 @@ func-timeout = ">=4.3.5"
|
||||
libtmux = ">=0.46.2"
|
||||
openhands-sdk = "*"
|
||||
pydantic = ">=2.11.7"
|
||||
tom-swe = ">=1.0.3"
|
||||
|
||||
[package.source]
|
||||
type = "git"
|
||||
url = "https://github.com/OpenHands/software-agent-sdk.git"
|
||||
reference = "d5995c31c55e488d4ab0372d292973bc6fad71f1"
|
||||
resolved_reference = "d5995c31c55e488d4ab0372d292973bc6fad71f1"
|
||||
subdirectory = "openhands-tools"
|
||||
|
||||
[[package]]
|
||||
name = "openpyxl"
|
||||
@@ -13323,31 +13318,6 @@ dev = ["tokenizers[testing]"]
|
||||
docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"]
|
||||
testing = ["black (==22.3)", "datasets", "numpy", "pytest", "pytest-asyncio", "requests", "ruff"]
|
||||
|
||||
[[package]]
|
||||
name = "tom-swe"
|
||||
version = "1.0.3"
|
||||
description = "Theory of Mind modeling for Software Engineering assistants"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "tom_swe-1.0.3-py3-none-any.whl", hash = "sha256:7b1172b29eb5c8fb7f1975016e7b6a238511b9ac2a7a980bd400dcb4e29773f2"},
|
||||
{file = "tom_swe-1.0.3.tar.gz", hash = "sha256:57c97d0104e563f15bd39edaf2aa6ac4c3e9444afd437fb92458700d22c6c0f5"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
jinja2 = ">=3.0.0"
|
||||
json-repair = ">=0.1.0"
|
||||
litellm = ">=1.0.0"
|
||||
pydantic = ">=2.0.0"
|
||||
python-dotenv = ">=1.0.0"
|
||||
tiktoken = ">=0.8.0"
|
||||
tqdm = ">=4.65.0"
|
||||
|
||||
[package.extras]
|
||||
dev = ["aiofiles (>=23.0.0)", "black (>=22.0.0)", "datasets (>=2.0.0)", "fastapi (>=0.104.0)", "httpx (>=0.25.0)", "huggingface-hub (>=0.0.0)", "isort (>=5.0.0)", "mypy (>=1.0.0)", "numpy (>=1.24.0)", "pandas (>=2.0.0)", "pre-commit (>=3.6.0)", "pytest (>=7.0.0)", "pytest-cov (>=6.2.1)", "rich (>=13.0.0)", "ruff (>=0.3.0)", "typing-extensions (>=4.0.0)", "uvicorn (>=0.24.0)"]
|
||||
search = ["bm25s (>=0.2.0)", "pystemmer (>=2.2.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "toml"
|
||||
version = "0.10.2"
|
||||
|
||||
@@ -34,15 +34,8 @@ from server.routes.integration.jira_dc import jira_dc_integration_router # noqa
|
||||
from server.routes.integration.linear import linear_integration_router # noqa: E402
|
||||
from server.routes.integration.slack import slack_router # noqa: E402
|
||||
from server.routes.mcp_patch import patch_mcp_server # noqa: E402
|
||||
from server.routes.oauth_device import oauth_device_router # noqa: E402
|
||||
from server.routes.readiness import readiness_router # noqa: E402
|
||||
from server.routes.user import saas_user_router # noqa: E402
|
||||
from server.sharing.shared_conversation_router import ( # noqa: E402
|
||||
router as shared_conversation_router,
|
||||
)
|
||||
from server.sharing.shared_event_router import ( # noqa: E402
|
||||
router as shared_event_router,
|
||||
)
|
||||
|
||||
from openhands.server.app import app as base_app # noqa: E402
|
||||
from openhands.server.listen_socket import sio # noqa: E402
|
||||
@@ -67,13 +60,10 @@ base_app.mount('/internal/metrics', metrics_app())
|
||||
base_app.include_router(readiness_router) # Add routes for readiness checks
|
||||
base_app.include_router(api_router) # Add additional route for github auth
|
||||
base_app.include_router(oauth_router) # Add additional route for oauth callback
|
||||
base_app.include_router(oauth_device_router) # Add OAuth 2.0 Device Flow routes
|
||||
base_app.include_router(saas_user_router) # Add additional route SAAS user calls
|
||||
base_app.include_router(
|
||||
billing_router
|
||||
) # Add routes for credit management and Stripe payment integration
|
||||
base_app.include_router(shared_conversation_router)
|
||||
base_app.include_router(shared_event_router)
|
||||
|
||||
# Add GitHub integration router only if GITHUB_APP_CLIENT_ID is set
|
||||
if GITHUB_APP_CLIENT_ID:
|
||||
@@ -107,7 +97,6 @@ base_app.include_router(
|
||||
event_webhook_router
|
||||
) # Add routes for Events in nested runtimes
|
||||
|
||||
|
||||
base_app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=PERMITTED_CORS_ORIGINS,
|
||||
|
||||
@@ -30,16 +30,3 @@ JIRA_DC_CLIENT_SECRET = os.getenv('JIRA_DC_CLIENT_SECRET', '').strip()
|
||||
JIRA_DC_BASE_URL = os.getenv('JIRA_DC_BASE_URL', '').strip()
|
||||
JIRA_DC_ENABLE_OAUTH = os.getenv('JIRA_DC_ENABLE_OAUTH', '1') in ('1', 'true')
|
||||
AUTH_URL = os.getenv('AUTH_URL', '').rstrip('/')
|
||||
ROLE_CHECK_ENABLED = os.getenv('ROLE_CHECK_ENABLED', 'false').lower() in (
|
||||
'1',
|
||||
'true',
|
||||
't',
|
||||
'yes',
|
||||
'y',
|
||||
'on',
|
||||
)
|
||||
BLOCKED_EMAIL_DOMAINS = [
|
||||
domain.strip().lower()
|
||||
for domain in os.getenv('BLOCKED_EMAIL_DOMAINS', '').split(',')
|
||||
if domain.strip()
|
||||
]
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
from server.auth.constants import BLOCKED_EMAIL_DOMAINS
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class DomainBlocker:
|
||||
def __init__(self) -> None:
|
||||
logger.debug('Initializing DomainBlocker')
|
||||
self.blocked_domains: list[str] = BLOCKED_EMAIL_DOMAINS
|
||||
if self.blocked_domains:
|
||||
logger.info(
|
||||
f'Successfully loaded {len(self.blocked_domains)} blocked email domains: {self.blocked_domains}'
|
||||
)
|
||||
|
||||
def is_active(self) -> bool:
|
||||
"""Check if domain blocking is enabled"""
|
||||
return bool(self.blocked_domains)
|
||||
|
||||
def _extract_domain(self, email: str) -> str | None:
|
||||
"""Extract and normalize email domain from email address"""
|
||||
if not email:
|
||||
return None
|
||||
try:
|
||||
# Extract domain part after @
|
||||
if '@' not in email:
|
||||
return None
|
||||
domain = email.split('@')[1].strip().lower()
|
||||
return domain if domain else None
|
||||
except Exception:
|
||||
logger.debug(f'Error extracting domain from email: {email}', exc_info=True)
|
||||
return None
|
||||
|
||||
def is_domain_blocked(self, email: str) -> bool:
|
||||
"""Check if email domain is blocked"""
|
||||
if not self.is_active():
|
||||
return False
|
||||
|
||||
if not email:
|
||||
logger.debug('No email provided for domain check')
|
||||
return False
|
||||
|
||||
domain = self._extract_domain(email)
|
||||
if not domain:
|
||||
logger.debug(f'Could not extract domain from email: {email}')
|
||||
return False
|
||||
|
||||
is_blocked = domain in self.blocked_domains
|
||||
if is_blocked:
|
||||
logger.warning(f'Email domain {domain} is blocked for email: {email}')
|
||||
else:
|
||||
logger.debug(f'Email domain {domain} is not blocked')
|
||||
|
||||
return is_blocked
|
||||
|
||||
|
||||
domain_blocker = DomainBlocker()
|
||||
@@ -1,109 +0,0 @@
|
||||
"""Email validation utilities for preventing duplicate signups with + modifier."""
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def extract_base_email(email: str) -> str | None:
|
||||
"""Extract base email from an email address.
|
||||
|
||||
For emails with + modifier, extracts the base email (local part before + and @, plus domain).
|
||||
For emails without + modifier, returns the email as-is.
|
||||
|
||||
Examples:
|
||||
extract_base_email("joe+test@example.com") -> "joe@example.com"
|
||||
extract_base_email("joe@example.com") -> "joe@example.com"
|
||||
extract_base_email("joe+openhands+test@example.com") -> "joe@example.com"
|
||||
|
||||
Args:
|
||||
email: The email address to process
|
||||
|
||||
Returns:
|
||||
The base email address, or None if email format is invalid
|
||||
"""
|
||||
if not email or '@' not in email:
|
||||
return None
|
||||
|
||||
try:
|
||||
local_part, domain = email.rsplit('@', 1)
|
||||
# Extract the part before + if it exists
|
||||
base_local = local_part.split('+', 1)[0]
|
||||
return f'{base_local}@{domain}'
|
||||
except (ValueError, AttributeError):
|
||||
return None
|
||||
|
||||
|
||||
def has_plus_modifier(email: str) -> bool:
|
||||
"""Check if an email address contains a + modifier.
|
||||
|
||||
Args:
|
||||
email: The email address to check
|
||||
|
||||
Returns:
|
||||
True if email contains + before @, False otherwise
|
||||
"""
|
||||
if not email or '@' not in email:
|
||||
return False
|
||||
|
||||
try:
|
||||
local_part, _ = email.rsplit('@', 1)
|
||||
return '+' in local_part
|
||||
except (ValueError, AttributeError):
|
||||
return False
|
||||
|
||||
|
||||
def matches_base_email(email: str, base_email: str) -> bool:
|
||||
"""Check if an email matches a base email pattern.
|
||||
|
||||
An email matches if:
|
||||
- It is exactly the base email (e.g., joe@example.com)
|
||||
- It has the same base local part and domain, with or without + modifier
|
||||
(e.g., joe+test@example.com matches base joe@example.com)
|
||||
|
||||
Args:
|
||||
email: The email address to check
|
||||
base_email: The base email to match against
|
||||
|
||||
Returns:
|
||||
True if email matches the base pattern, False otherwise
|
||||
"""
|
||||
if not email or not base_email:
|
||||
return False
|
||||
|
||||
# Extract base from both emails for comparison
|
||||
email_base = extract_base_email(email)
|
||||
base_email_normalized = extract_base_email(base_email)
|
||||
|
||||
if not email_base or not base_email_normalized:
|
||||
return False
|
||||
|
||||
# Emails match if they have the same base
|
||||
return email_base.lower() == base_email_normalized.lower()
|
||||
|
||||
|
||||
def get_base_email_regex_pattern(base_email: str) -> re.Pattern | None:
|
||||
"""Generate a regex pattern to match emails with the same base.
|
||||
|
||||
For base_email "joe@example.com", the pattern will match:
|
||||
- joe@example.com
|
||||
- joe+anything@example.com
|
||||
|
||||
Args:
|
||||
base_email: The base email address
|
||||
|
||||
Returns:
|
||||
A compiled regex pattern, or None if base_email is invalid
|
||||
"""
|
||||
base = extract_base_email(base_email)
|
||||
if not base:
|
||||
return None
|
||||
|
||||
try:
|
||||
local_part, domain = base.rsplit('@', 1)
|
||||
# Escape special regex characters in local part and domain
|
||||
escaped_local = re.escape(local_part)
|
||||
escaped_domain = re.escape(domain)
|
||||
# Pattern: joe@example.com OR joe+anything@example.com
|
||||
pattern = rf'^{escaped_local}(\+[^@\s]+)?@{escaped_domain}$'
|
||||
return re.compile(pattern, re.IGNORECASE)
|
||||
except (ValueError, AttributeError):
|
||||
return None
|
||||
@@ -13,7 +13,6 @@ from server.auth.auth_error import (
|
||||
ExpiredError,
|
||||
NoCredentialsError,
|
||||
)
|
||||
from server.auth.domain_blocker import domain_blocker
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from server.logger import logger
|
||||
@@ -154,10 +153,8 @@ class SaasUserAuth(UserAuth):
|
||||
try:
|
||||
# TODO: I think we can do this in a single request if we refactor
|
||||
with session_maker() as session:
|
||||
tokens = (
|
||||
session.query(AuthTokens)
|
||||
.where(AuthTokens.keycloak_user_id == self.user_id)
|
||||
.all()
|
||||
tokens = session.query(AuthTokens).where(
|
||||
AuthTokens.keycloak_user_id == self.user_id
|
||||
)
|
||||
|
||||
for token in tokens:
|
||||
@@ -206,15 +203,6 @@ class SaasUserAuth(UserAuth):
|
||||
self.settings_store = settings_store
|
||||
return settings_store
|
||||
|
||||
async def get_mcp_api_key(self) -> str:
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
mcp_api_key = api_key_store.retrieve_mcp_api_key(self.user_id)
|
||||
if not mcp_api_key:
|
||||
mcp_api_key = api_key_store.create_api_key(
|
||||
self.user_id, 'MCP_API_KEY', None
|
||||
)
|
||||
return mcp_api_key
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls, request: Request) -> UserAuth:
|
||||
logger.debug('saas_user_auth_get_instance')
|
||||
@@ -255,12 +243,7 @@ def get_api_key_from_header(request: Request):
|
||||
# This is a temp hack
|
||||
# Streamable HTTP MCP Client works via redirect requests, but drops the Authorization header for reason
|
||||
# We include `X-Session-API-Key` header by default due to nested runtimes, so it used as a drop in replacement here
|
||||
session_api_key = request.headers.get('X-Session-API-Key')
|
||||
if session_api_key:
|
||||
return session_api_key
|
||||
|
||||
# Fallback to X-Access-Token header as an additional option
|
||||
return request.headers.get('X-Access-Token')
|
||||
return request.headers.get('X-Session-API-Key')
|
||||
|
||||
|
||||
async def saas_user_auth_from_bearer(request: Request) -> SaasUserAuth | None:
|
||||
@@ -315,16 +298,6 @@ async def saas_user_auth_from_signed_token(signed_token: str) -> SaasUserAuth:
|
||||
user_id = access_token_payload['sub']
|
||||
email = access_token_payload['email']
|
||||
email_verified = access_token_payload['email_verified']
|
||||
|
||||
# Check if email domain is blocked
|
||||
if email and domain_blocker.is_active() and domain_blocker.is_domain_blocked(email):
|
||||
logger.warning(
|
||||
f'Blocked authentication attempt for existing user with email: {email}'
|
||||
)
|
||||
raise AuthError(
|
||||
'Access denied: Your email domain is not allowed to access this service'
|
||||
)
|
||||
|
||||
logger.debug('saas_user_auth_from_signed_token:return')
|
||||
|
||||
return SaasUserAuth(
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
@@ -26,11 +25,6 @@ from server.auth.constants import (
|
||||
KEYCLOAK_SERVER_URL,
|
||||
KEYCLOAK_SERVER_URL_EXT,
|
||||
)
|
||||
from server.auth.email_validation import (
|
||||
extract_base_email,
|
||||
get_base_email_regex_pattern,
|
||||
matches_base_email,
|
||||
)
|
||||
from server.auth.keycloak_manager import get_keycloak_admin, get_keycloak_openid
|
||||
from server.config import get_config
|
||||
from server.logger import logger
|
||||
@@ -515,183 +509,6 @@ class TokenManager:
|
||||
logger.info(f'Got user ID {keycloak_user_id} from email: {email}')
|
||||
return keycloak_user_id
|
||||
|
||||
async def _query_users_by_wildcard_pattern(
|
||||
self, local_part: str, domain: str
|
||||
) -> dict[str, dict]:
|
||||
"""Query Keycloak for users matching a wildcard email pattern.
|
||||
|
||||
Tries multiple query methods to find users with emails matching
|
||||
the pattern {local_part}*@{domain}. This catches the base email
|
||||
and all + modifier variants.
|
||||
|
||||
Args:
|
||||
local_part: The local part of the email (before @)
|
||||
domain: The domain part of the email (after @)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping user IDs to user objects
|
||||
"""
|
||||
keycloak_admin = get_keycloak_admin(self.external)
|
||||
all_users = {}
|
||||
|
||||
# Query for users with emails matching the base pattern using wildcard
|
||||
# Pattern: {local_part}*@{domain} - catches base email and all + variants
|
||||
# This may also catch unintended matches (e.g., joesmith@example.com), but
|
||||
# they will be filtered out by the regex pattern check later
|
||||
# Use 'search' parameter for Keycloak 26+ (better wildcard support)
|
||||
wildcard_queries = [
|
||||
{'search': f'{local_part}*@{domain}'}, # Try 'search' parameter first
|
||||
{'q': f'email:{local_part}*@{domain}'}, # Fallback to 'q' parameter
|
||||
]
|
||||
|
||||
for query_params in wildcard_queries:
|
||||
try:
|
||||
users = await keycloak_admin.a_get_users(query_params)
|
||||
for user in users:
|
||||
all_users[user.get('id')] = user
|
||||
break # Success, no need to try fallback
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f'Wildcard query failed with {list(query_params.keys())[0]}: {e}'
|
||||
)
|
||||
continue # Try next query method
|
||||
|
||||
return all_users
|
||||
|
||||
def _find_duplicate_in_users(
|
||||
self, users: dict[str, dict], base_email: str, current_user_id: str
|
||||
) -> bool:
|
||||
"""Check if any user in the provided list matches the base email pattern.
|
||||
|
||||
Filters users to find duplicates that match the base email pattern,
|
||||
excluding the current user.
|
||||
|
||||
Args:
|
||||
users: Dictionary mapping user IDs to user objects
|
||||
base_email: The base email to match against
|
||||
current_user_id: The user ID to exclude from the check
|
||||
|
||||
Returns:
|
||||
True if a duplicate is found, False otherwise
|
||||
"""
|
||||
regex_pattern = get_base_email_regex_pattern(base_email)
|
||||
if not regex_pattern:
|
||||
logger.warning(
|
||||
f'Could not generate regex pattern for base email: {base_email}'
|
||||
)
|
||||
# Fallback to simple matching
|
||||
for user in users.values():
|
||||
user_email = user.get('email', '').lower()
|
||||
if (
|
||||
user_email
|
||||
and user.get('id') != current_user_id
|
||||
and matches_base_email(user_email, base_email)
|
||||
):
|
||||
logger.info(
|
||||
f'Found duplicate email: {user_email} matches base {base_email}'
|
||||
)
|
||||
return True
|
||||
else:
|
||||
for user in users.values():
|
||||
user_email = user.get('email', '')
|
||||
if (
|
||||
user_email
|
||||
and user.get('id') != current_user_id
|
||||
and regex_pattern.match(user_email)
|
||||
):
|
||||
logger.info(
|
||||
f'Found duplicate email: {user_email} matches base {base_email}'
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(2),
|
||||
retry=retry_if_exception_type(KeycloakConnectionError),
|
||||
before_sleep=_before_sleep_callback,
|
||||
)
|
||||
async def check_duplicate_base_email(
|
||||
self, email: str, current_user_id: str
|
||||
) -> bool:
|
||||
"""Check if a user with the same base email already exists.
|
||||
|
||||
This method checks for duplicate signups using email + modifier.
|
||||
It checks if any user exists with the same base email, regardless of whether
|
||||
the provided email has a + modifier or not.
|
||||
|
||||
Examples:
|
||||
- If email is "joe+test@example.com", it checks for existing users with
|
||||
base email "joe@example.com" (e.g., "joe@example.com", "joe+1@example.com")
|
||||
- If email is "joe@example.com", it checks for existing users with
|
||||
base email "joe@example.com" (e.g., "joe+1@example.com", "joe+test@example.com")
|
||||
|
||||
Args:
|
||||
email: The email address to check (may or may not contain + modifier)
|
||||
current_user_id: The user ID of the current user (to exclude from check)
|
||||
|
||||
Returns:
|
||||
True if a duplicate is found (excluding current user), False otherwise
|
||||
"""
|
||||
if not email:
|
||||
return False
|
||||
|
||||
base_email = extract_base_email(email)
|
||||
if not base_email:
|
||||
logger.warning(f'Could not extract base email from: {email}')
|
||||
return False
|
||||
|
||||
try:
|
||||
local_part, domain = base_email.rsplit('@', 1)
|
||||
users = await self._query_users_by_wildcard_pattern(local_part, domain)
|
||||
return self._find_duplicate_in_users(users, base_email, current_user_id)
|
||||
|
||||
except KeycloakConnectionError:
|
||||
logger.exception('KeycloakConnectionError when checking duplicate email')
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f'Unexpected error checking duplicate email: {e}')
|
||||
# On any error, allow signup to proceed (fail open)
|
||||
return False
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(2),
|
||||
retry=retry_if_exception_type(KeycloakConnectionError),
|
||||
before_sleep=_before_sleep_callback,
|
||||
)
|
||||
async def delete_keycloak_user(self, user_id: str) -> bool:
|
||||
"""Delete a user from Keycloak.
|
||||
|
||||
This method is used to clean up user accounts that were created
|
||||
but should not exist (e.g., duplicate email signups).
|
||||
|
||||
Args:
|
||||
user_id: The Keycloak user ID to delete
|
||||
|
||||
Returns:
|
||||
True if deletion was successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
keycloak_admin = get_keycloak_admin(self.external)
|
||||
# Use the sync method (python-keycloak doesn't have async delete_user)
|
||||
# Run it in a thread executor to avoid blocking the event loop
|
||||
await asyncio.to_thread(keycloak_admin.delete_user, user_id)
|
||||
logger.info(f'Successfully deleted Keycloak user {user_id}')
|
||||
return True
|
||||
except KeycloakConnectionError:
|
||||
logger.exception(f'KeycloakConnectionError when deleting user {user_id}')
|
||||
raise
|
||||
except KeycloakError as e:
|
||||
# User might not exist or already deleted
|
||||
logger.warning(
|
||||
f'KeycloakError when deleting user {user_id}: {e}',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.exception(f'Unexpected error deleting Keycloak user {user_id}: {e}')
|
||||
return False
|
||||
|
||||
async def get_user_info_from_user_id(self, user_id: str) -> dict | None:
|
||||
keycloak_admin = get_keycloak_admin(self.external)
|
||||
user = await keycloak_admin.a_get_user(user_id)
|
||||
@@ -710,49 +527,6 @@ class TokenManager:
|
||||
github_id = github_ids[0]
|
||||
return github_id
|
||||
|
||||
async def disable_keycloak_user(
|
||||
self, user_id: str, email: str | None = None
|
||||
) -> None:
|
||||
"""Disable a Keycloak user account.
|
||||
|
||||
Args:
|
||||
user_id: The Keycloak user ID to disable
|
||||
email: Optional email address for logging purposes
|
||||
|
||||
This method attempts to disable the user account but will not raise exceptions.
|
||||
Errors are logged but do not prevent the operation from completing.
|
||||
"""
|
||||
try:
|
||||
keycloak_admin = get_keycloak_admin(self.external)
|
||||
# Get current user to preserve other fields
|
||||
user = await keycloak_admin.a_get_user(user_id)
|
||||
if user:
|
||||
# Update user with enabled=False to disable the account
|
||||
await keycloak_admin.a_update_user(
|
||||
user_id=user_id,
|
||||
payload={
|
||||
'enabled': False,
|
||||
'username': user.get('username', ''),
|
||||
'email': user.get('email', ''),
|
||||
'emailVerified': user.get('emailVerified', False),
|
||||
},
|
||||
)
|
||||
email_str = f', email: {email}' if email else ''
|
||||
logger.info(
|
||||
f'Disabled Keycloak account for user_id: {user_id}{email_str}'
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f'User not found in Keycloak when attempting to disable: {user_id}'
|
||||
)
|
||||
except Exception as e:
|
||||
# Log error but don't raise - the caller should handle the blocking regardless
|
||||
email_str = f', email: {email}' if email else ''
|
||||
logger.error(
|
||||
f'Failed to disable Keycloak account for user_id: {user_id}{email_str}: {str(e)}',
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
def store_org_token(self, installation_id: int, installation_token: str):
|
||||
"""Store a GitHub App installation token.
|
||||
|
||||
|
||||
@@ -25,7 +25,6 @@ USER_SETTINGS_VERSION_TO_MODEL = {
|
||||
2: 'claude-3-7-sonnet-20250219',
|
||||
3: 'claude-sonnet-4-20250514',
|
||||
4: 'claude-sonnet-4-20250514',
|
||||
5: 'claude-opus-4-5-20251101',
|
||||
}
|
||||
|
||||
LITELLM_DEFAULT_MODEL = os.getenv('LITELLM_DEFAULT_MODEL')
|
||||
|
||||
331
enterprise/server/legacy_conversation_manager.py
Normal file
331
enterprise/server/legacy_conversation_manager.py
Normal file
@@ -0,0 +1,331 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import socketio
|
||||
from server.clustered_conversation_manager import ClusteredConversationManager
|
||||
from server.saas_nested_conversation_manager import SaasNestedConversationManager
|
||||
|
||||
from openhands.core.config import LLMConfig, OpenHandsConfig
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.server.config.server_config import ServerConfig
|
||||
from openhands.server.conversation_manager.conversation_manager import (
|
||||
ConversationManager,
|
||||
)
|
||||
from openhands.server.data_models.agent_loop_info import AgentLoopInfo
|
||||
from openhands.server.monitoring import MonitoringListener
|
||||
from openhands.server.session.conversation import ServerConversation
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
from openhands.storage.files import FileStore
|
||||
from openhands.utils.async_utils import wait_all
|
||||
|
||||
_LEGACY_ENTRY_TIMEOUT_SECONDS = 3600
|
||||
|
||||
|
||||
@dataclass
|
||||
class LegacyCacheEntry:
|
||||
"""Cache entry for legacy mode status."""
|
||||
|
||||
is_legacy: bool
|
||||
timestamp: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class LegacyConversationManager(ConversationManager):
|
||||
"""
|
||||
Conversation manager for use while migrating - since existing conversations are not nested!
|
||||
Separate class from SaasNestedConversationManager so it can be easliy removed in a few weeks.
|
||||
(As of 2025-07-23)
|
||||
"""
|
||||
|
||||
sio: socketio.AsyncServer
|
||||
config: OpenHandsConfig
|
||||
server_config: ServerConfig
|
||||
file_store: FileStore
|
||||
conversation_manager: SaasNestedConversationManager
|
||||
legacy_conversation_manager: ClusteredConversationManager
|
||||
_legacy_cache: dict[str, LegacyCacheEntry] = field(default_factory=dict)
|
||||
|
||||
async def __aenter__(self):
|
||||
await wait_all(
|
||||
[
|
||||
self.conversation_manager.__aenter__(),
|
||||
self.legacy_conversation_manager.__aenter__(),
|
||||
]
|
||||
)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
await wait_all(
|
||||
[
|
||||
self.conversation_manager.__aexit__(exc_type, exc_value, traceback),
|
||||
self.legacy_conversation_manager.__aexit__(
|
||||
exc_type, exc_value, traceback
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
async def request_llm_completion(
|
||||
self,
|
||||
sid: str,
|
||||
service_id: str,
|
||||
llm_config: LLMConfig,
|
||||
messages: list[dict[str, str]],
|
||||
) -> str:
|
||||
session = self.get_agent_session(sid)
|
||||
llm_registry = session.llm_registry
|
||||
return llm_registry.request_extraneous_completion(
|
||||
service_id, llm_config, messages
|
||||
)
|
||||
|
||||
async def attach_to_conversation(
|
||||
self, sid: str, user_id: str | None = None
|
||||
) -> ServerConversation | None:
|
||||
if await self.should_start_in_legacy_mode(sid):
|
||||
return await self.legacy_conversation_manager.attach_to_conversation(
|
||||
sid, user_id
|
||||
)
|
||||
return await self.conversation_manager.attach_to_conversation(sid, user_id)
|
||||
|
||||
async def detach_from_conversation(self, conversation: ServerConversation):
|
||||
if await self.should_start_in_legacy_mode(conversation.sid):
|
||||
return await self.legacy_conversation_manager.detach_from_conversation(
|
||||
conversation
|
||||
)
|
||||
return await self.conversation_manager.detach_from_conversation(conversation)
|
||||
|
||||
async def join_conversation(
|
||||
self,
|
||||
sid: str,
|
||||
connection_id: str,
|
||||
settings: Settings,
|
||||
user_id: str | None,
|
||||
) -> AgentLoopInfo:
|
||||
if await self.should_start_in_legacy_mode(sid):
|
||||
return await self.legacy_conversation_manager.join_conversation(
|
||||
sid, connection_id, settings, user_id
|
||||
)
|
||||
return await self.conversation_manager.join_conversation(
|
||||
sid, connection_id, settings, user_id
|
||||
)
|
||||
|
||||
def get_agent_session(self, sid: str):
|
||||
session = self.legacy_conversation_manager.get_agent_session(sid)
|
||||
if session is None:
|
||||
session = self.conversation_manager.get_agent_session(sid)
|
||||
return session
|
||||
|
||||
async def get_running_agent_loops(
|
||||
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
|
||||
) -> set[str]:
|
||||
if filter_to_sids and len(filter_to_sids) == 1:
|
||||
sid = next(iter(filter_to_sids))
|
||||
if await self.should_start_in_legacy_mode(sid):
|
||||
return await self.legacy_conversation_manager.get_running_agent_loops(
|
||||
user_id, filter_to_sids
|
||||
)
|
||||
return await self.conversation_manager.get_running_agent_loops(
|
||||
user_id, filter_to_sids
|
||||
)
|
||||
|
||||
# Get all running agent loops from both managers
|
||||
agent_loops, legacy_agent_loops = await wait_all(
|
||||
[
|
||||
self.conversation_manager.get_running_agent_loops(
|
||||
user_id, filter_to_sids
|
||||
),
|
||||
self.legacy_conversation_manager.get_running_agent_loops(
|
||||
user_id, filter_to_sids
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# Combine the results
|
||||
result = set()
|
||||
for sid in legacy_agent_loops:
|
||||
if await self.should_start_in_legacy_mode(sid):
|
||||
result.add(sid)
|
||||
|
||||
for sid in agent_loops:
|
||||
if not await self.should_start_in_legacy_mode(sid):
|
||||
result.add(sid)
|
||||
|
||||
return result
|
||||
|
||||
async def is_agent_loop_running(self, sid: str) -> bool:
|
||||
return bool(await self.get_running_agent_loops(filter_to_sids={sid}))
|
||||
|
||||
async def get_connections(
|
||||
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
|
||||
) -> dict[str, str]:
|
||||
if filter_to_sids and len(filter_to_sids) == 1:
|
||||
sid = next(iter(filter_to_sids))
|
||||
if await self.should_start_in_legacy_mode(sid):
|
||||
return await self.legacy_conversation_manager.get_connections(
|
||||
user_id, filter_to_sids
|
||||
)
|
||||
return await self.conversation_manager.get_connections(
|
||||
user_id, filter_to_sids
|
||||
)
|
||||
agent_loops, legacy_agent_loops = await wait_all(
|
||||
[
|
||||
self.conversation_manager.get_connections(user_id, filter_to_sids),
|
||||
self.legacy_conversation_manager.get_connections(
|
||||
user_id, filter_to_sids
|
||||
),
|
||||
]
|
||||
)
|
||||
legacy_agent_loops.update(agent_loops)
|
||||
return legacy_agent_loops
|
||||
|
||||
async def maybe_start_agent_loop(
|
||||
self,
|
||||
sid: str,
|
||||
settings: Settings,
|
||||
user_id: str, # type: ignore[override]
|
||||
initial_user_msg: MessageAction | None = None,
|
||||
replay_json: str | None = None,
|
||||
) -> AgentLoopInfo:
|
||||
if await self.should_start_in_legacy_mode(sid):
|
||||
return await self.legacy_conversation_manager.maybe_start_agent_loop(
|
||||
sid, settings, user_id, initial_user_msg, replay_json
|
||||
)
|
||||
return await self.conversation_manager.maybe_start_agent_loop(
|
||||
sid, settings, user_id, initial_user_msg, replay_json
|
||||
)
|
||||
|
||||
async def send_to_event_stream(self, connection_id: str, data: dict):
|
||||
return await self.legacy_conversation_manager.send_to_event_stream(
|
||||
connection_id, data
|
||||
)
|
||||
|
||||
async def send_event_to_conversation(self, sid: str, data: dict):
|
||||
if await self.should_start_in_legacy_mode(sid):
|
||||
await self.legacy_conversation_manager.send_event_to_conversation(sid, data)
|
||||
await self.conversation_manager.send_event_to_conversation(sid, data)
|
||||
|
||||
async def disconnect_from_session(self, connection_id: str):
|
||||
return await self.legacy_conversation_manager.disconnect_from_session(
|
||||
connection_id
|
||||
)
|
||||
|
||||
async def close_session(self, sid: str):
|
||||
if await self.should_start_in_legacy_mode(sid):
|
||||
await self.legacy_conversation_manager.close_session(sid)
|
||||
await self.conversation_manager.close_session(sid)
|
||||
|
||||
async def get_agent_loop_info(
|
||||
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
|
||||
) -> list[AgentLoopInfo]:
|
||||
if filter_to_sids and len(filter_to_sids) == 1:
|
||||
sid = next(iter(filter_to_sids))
|
||||
if await self.should_start_in_legacy_mode(sid):
|
||||
return await self.legacy_conversation_manager.get_agent_loop_info(
|
||||
user_id, filter_to_sids
|
||||
)
|
||||
return await self.conversation_manager.get_agent_loop_info(
|
||||
user_id, filter_to_sids
|
||||
)
|
||||
agent_loops, legacy_agent_loops = await wait_all(
|
||||
[
|
||||
self.conversation_manager.get_agent_loop_info(user_id, filter_to_sids),
|
||||
self.legacy_conversation_manager.get_agent_loop_info(
|
||||
user_id, filter_to_sids
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# Combine results
|
||||
result = []
|
||||
legacy_sids = set()
|
||||
|
||||
# Add legacy agent loops
|
||||
for agent_loop in legacy_agent_loops:
|
||||
if await self.should_start_in_legacy_mode(agent_loop.conversation_id):
|
||||
result.append(agent_loop)
|
||||
legacy_sids.add(agent_loop.conversation_id)
|
||||
|
||||
# Add non-legacy agent loops
|
||||
for agent_loop in agent_loops:
|
||||
if (
|
||||
agent_loop.conversation_id not in legacy_sids
|
||||
and not await self.should_start_in_legacy_mode(
|
||||
agent_loop.conversation_id
|
||||
)
|
||||
):
|
||||
result.append(agent_loop)
|
||||
|
||||
return result
|
||||
|
||||
def _cleanup_expired_cache_entries(self):
|
||||
"""Remove expired entries from the local cache."""
|
||||
current_time = time.time()
|
||||
expired_keys = [
|
||||
key
|
||||
for key, entry in self._legacy_cache.items()
|
||||
if current_time - entry.timestamp > _LEGACY_ENTRY_TIMEOUT_SECONDS
|
||||
]
|
||||
for key in expired_keys:
|
||||
del self._legacy_cache[key]
|
||||
|
||||
async def should_start_in_legacy_mode(self, conversation_id: str) -> bool:
|
||||
"""
|
||||
Check if a conversation should run in legacy mode by directly checking the runtime.
|
||||
The /list method does not include stopped conversations even though the PVC for these
|
||||
may not yet have been deleted, so we need to check /sessions/{session_id} directly.
|
||||
"""
|
||||
# Clean up expired entries periodically
|
||||
self._cleanup_expired_cache_entries()
|
||||
|
||||
# First check the local cache
|
||||
if conversation_id in self._legacy_cache:
|
||||
cached_entry = self._legacy_cache[conversation_id]
|
||||
# Check if the cached value is still valid
|
||||
if time.time() - cached_entry.timestamp <= _LEGACY_ENTRY_TIMEOUT_SECONDS:
|
||||
return cached_entry.is_legacy
|
||||
|
||||
# If not in cache or expired, check the runtime directly
|
||||
runtime = await self.conversation_manager._get_runtime(conversation_id)
|
||||
is_legacy = self.is_legacy_runtime(runtime)
|
||||
|
||||
# Cache the result with current timestamp
|
||||
self._legacy_cache[conversation_id] = LegacyCacheEntry(is_legacy, time.time())
|
||||
|
||||
return is_legacy
|
||||
|
||||
def is_legacy_runtime(self, runtime: dict | None) -> bool:
|
||||
"""
|
||||
Determine if a runtime is a legacy runtime based on its command.
|
||||
|
||||
Args:
|
||||
runtime: The runtime dictionary or None if not found
|
||||
|
||||
Returns:
|
||||
bool: True if this is a legacy runtime, False otherwise
|
||||
"""
|
||||
if runtime is None:
|
||||
return False
|
||||
return 'openhands.server' not in runtime['command']
|
||||
|
||||
@classmethod
|
||||
def get_instance(
|
||||
cls,
|
||||
sio: socketio.AsyncServer,
|
||||
config: OpenHandsConfig,
|
||||
file_store: FileStore,
|
||||
server_config: ServerConfig,
|
||||
monitoring_listener: MonitoringListener,
|
||||
) -> ConversationManager:
|
||||
return LegacyConversationManager(
|
||||
sio=sio,
|
||||
config=config,
|
||||
server_config=server_config,
|
||||
file_store=file_store,
|
||||
conversation_manager=SaasNestedConversationManager.get_instance(
|
||||
sio, config, file_store, server_config, monitoring_listener
|
||||
),
|
||||
legacy_conversation_manager=ClusteredConversationManager.get_instance(
|
||||
sio, config, file_store, server_config, monitoring_listener
|
||||
),
|
||||
)
|
||||
@@ -152,23 +152,17 @@ class SetAuthCookieMiddleware:
|
||||
return False
|
||||
path = request.url.path
|
||||
|
||||
ignore_paths = (
|
||||
is_api_that_should_attach = path.startswith('/api') and path not in (
|
||||
'/api/options/config',
|
||||
'/api/keycloak/callback',
|
||||
'/api/billing/success',
|
||||
'/api/billing/cancel',
|
||||
'/api/billing/customer-setup-success',
|
||||
'/api/billing/stripe-webhook',
|
||||
'/api/email/resend',
|
||||
'/oauth/device/authorize',
|
||||
'/oauth/device/token',
|
||||
)
|
||||
if path in ignore_paths:
|
||||
return False
|
||||
|
||||
is_mcp = path.startswith('/mcp')
|
||||
is_api_route = path.startswith('/api')
|
||||
return is_api_route or is_mcp
|
||||
return is_api_that_should_attach or is_mcp
|
||||
|
||||
async def _logout(self, request: Request):
|
||||
# Log out of keycloak - this prevents issues where you did not log in with the idp you believe you used
|
||||
|
||||
@@ -12,9 +12,7 @@ from server.auth.constants import (
|
||||
KEYCLOAK_CLIENT_ID,
|
||||
KEYCLOAK_REALM_NAME,
|
||||
KEYCLOAK_SERVER_URL_EXT,
|
||||
ROLE_CHECK_ENABLED,
|
||||
)
|
||||
from server.auth.domain_blocker import domain_blocker
|
||||
from server.auth.gitlab_sync import schedule_gitlab_repo_sync
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.auth.token_manager import TokenManager
|
||||
@@ -134,88 +132,13 @@ async def keycloak_callback(
|
||||
|
||||
user_info = await token_manager.get_user_info(keycloak_access_token)
|
||||
logger.debug(f'user_info: {user_info}')
|
||||
if ROLE_CHECK_ENABLED and 'roles' not in user_info:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={'error': 'Missing required role'},
|
||||
)
|
||||
|
||||
if 'sub' not in user_info or 'preferred_username' not in user_info:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={'error': 'Missing user ID or username in response'},
|
||||
)
|
||||
|
||||
email = user_info.get('email')
|
||||
user_id = user_info['sub']
|
||||
|
||||
# Check if email domain is blocked
|
||||
email = user_info.get('email')
|
||||
if email and domain_blocker.is_active() and domain_blocker.is_domain_blocked(email):
|
||||
logger.warning(
|
||||
f'Blocked authentication attempt for email: {email}, user_id: {user_id}'
|
||||
)
|
||||
|
||||
# Disable the Keycloak account
|
||||
await token_manager.disable_keycloak_user(user_id, email)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
'error': 'Access denied: Your email domain is not allowed to access this service'
|
||||
},
|
||||
)
|
||||
|
||||
# Check for duplicate email with + modifier
|
||||
if email:
|
||||
try:
|
||||
has_duplicate = await token_manager.check_duplicate_base_email(
|
||||
email, user_id
|
||||
)
|
||||
if has_duplicate:
|
||||
logger.warning(
|
||||
f'Blocked signup attempt for email {email} - duplicate base email found',
|
||||
extra={'user_id': user_id, 'email': email},
|
||||
)
|
||||
|
||||
# Delete the Keycloak user that was automatically created during OAuth
|
||||
# This prevents orphaned accounts in Keycloak
|
||||
# The delete_keycloak_user method already handles all errors internally
|
||||
deletion_success = await token_manager.delete_keycloak_user(user_id)
|
||||
if deletion_success:
|
||||
logger.info(
|
||||
f'Deleted Keycloak user {user_id} after detecting duplicate email {email}'
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f'Failed to delete Keycloak user {user_id} after detecting duplicate email {email}. '
|
||||
f'User may need to be manually cleaned up.'
|
||||
)
|
||||
|
||||
# Redirect to home page with query parameter indicating the issue
|
||||
home_url = f'{request.base_url}?duplicated_email=true'
|
||||
return RedirectResponse(home_url, status_code=302)
|
||||
except Exception as e:
|
||||
# Log error but allow signup to proceed (fail open)
|
||||
logger.error(
|
||||
f'Error checking duplicate email for {email}: {e}',
|
||||
extra={'user_id': user_id, 'email': email},
|
||||
)
|
||||
|
||||
# Check email verification status
|
||||
email_verified = user_info.get('email_verified', False)
|
||||
if not email_verified:
|
||||
# Send verification email
|
||||
# Import locally to avoid circular import with email.py
|
||||
from server.routes.email import verify_email
|
||||
|
||||
await verify_email(request=request, user_id=user_id, is_auth_flow=True)
|
||||
redirect_url = (
|
||||
f'{request.base_url}?email_verification_required=true&user_id={user_id}'
|
||||
)
|
||||
response = RedirectResponse(redirect_url, status_code=302)
|
||||
return response
|
||||
|
||||
# default to github IDP for now.
|
||||
# TODO: remove default once Keycloak is updated universally with the new attribute.
|
||||
idp: str = user_info.get('identity_provider', ProviderType.GITHUB.value)
|
||||
|
||||
@@ -111,24 +111,10 @@ def calculate_credits(user_info: LiteLlmUserInfo) -> float:
|
||||
async def get_credits(user_id: str = Depends(get_user_id)) -> GetCreditsResponse:
|
||||
if not stripe_service.STRIPE_API_KEY:
|
||||
return GetCreditsResponse()
|
||||
try:
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
user_json = await _get_litellm_user(client, user_id)
|
||||
credits = calculate_credits(user_json['user_info'])
|
||||
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(
|
||||
f'litellm_get_user_failed: {type(e).__name__}: {e}',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'status_code': e.response.status_code,
|
||||
},
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to retrieve credit balance from billing service',
|
||||
)
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
user_json = await _get_litellm_user(client, user_id)
|
||||
credits = calculate_credits(user_json['user_info'])
|
||||
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
|
||||
|
||||
|
||||
# Endpoint to retrieve user's current subscription access
|
||||
|
||||
@@ -7,7 +7,6 @@ from server.auth.constants import KEYCLOAK_CLIENT_ID
|
||||
from server.auth.keycloak_manager import get_keycloak_admin
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.routes.auth import set_response_cookie
|
||||
from server.utils.rate_limit_utils import check_rate_limit_by_user_id
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
@@ -29,11 +28,6 @@ class EmailUpdate(BaseModel):
|
||||
return v
|
||||
|
||||
|
||||
class ResendEmailVerificationRequest(BaseModel):
|
||||
user_id: str | None = None
|
||||
is_auth_flow: bool = False
|
||||
|
||||
|
||||
@api_router.post('')
|
||||
async def update_email(
|
||||
email_data: EmailUpdate, request: Request, user_id: str = Depends(get_user_id)
|
||||
@@ -80,7 +74,7 @@ async def update_email(
|
||||
accepted_tos=user_auth.accepted_tos,
|
||||
)
|
||||
|
||||
await verify_email(request=request, user_id=user_id)
|
||||
await _verify_email(request=request, user_id=user_id)
|
||||
|
||||
logger.info(f'Updating email address for {user_id} to {email}')
|
||||
return response
|
||||
@@ -96,41 +90,9 @@ async def update_email(
|
||||
)
|
||||
|
||||
|
||||
@api_router.put('/resend')
|
||||
async def resend_email_verification(
|
||||
request: Request,
|
||||
body: ResendEmailVerificationRequest | None = None,
|
||||
):
|
||||
# Get user_id from body if provided, otherwise from auth
|
||||
user_id: str | None = None
|
||||
if body and body.user_id:
|
||||
user_id = body.user_id
|
||||
else:
|
||||
try:
|
||||
user_id = await get_user_id(request)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='user_id is required in request body or user must be authenticated',
|
||||
)
|
||||
|
||||
# Check rate limit (uses user_id if available, otherwise falls back to IP)
|
||||
# Use 30 seconds for user-based rate limiting to match frontend cooldown
|
||||
await check_rate_limit_by_user_id(
|
||||
request=request,
|
||||
key_prefix='email_resend',
|
||||
user_id=user_id,
|
||||
user_rate_limit_seconds=30,
|
||||
ip_rate_limit_seconds=60, # 1 minute for IP-based limiting (more lenient)
|
||||
)
|
||||
|
||||
# Get is_auth_flow from body if provided, default to False
|
||||
is_auth_flow = body.is_auth_flow if body else False
|
||||
|
||||
await verify_email(request=request, user_id=user_id, is_auth_flow=is_auth_flow)
|
||||
@api_router.put('/verify')
|
||||
async def verify_email(request: Request, user_id: str = Depends(get_user_id)):
|
||||
await _verify_email(request=request, user_id=user_id)
|
||||
|
||||
logger.info(f'Resending verification email for {user_id}')
|
||||
return JSONResponse(
|
||||
@@ -162,13 +124,10 @@ async def verified_email(request: Request):
|
||||
return response
|
||||
|
||||
|
||||
async def verify_email(request: Request, user_id: str, is_auth_flow: bool = False):
|
||||
async def _verify_email(request: Request, user_id: str):
|
||||
keycloak_admin = get_keycloak_admin()
|
||||
scheme = 'http' if request.url.hostname == 'localhost' else 'https'
|
||||
if is_auth_flow:
|
||||
redirect_uri = f'{scheme}://{request.url.netloc}?email_verified=true'
|
||||
else:
|
||||
redirect_uri = f'{scheme}://{request.url.netloc}/api/email/verified'
|
||||
redirect_uri = f'{scheme}://{request.url.netloc}/api/email/verified'
|
||||
logger.info(f'Redirect URI: {redirect_uri}')
|
||||
await keycloak_admin.a_send_verify_email(
|
||||
user_id=user_id,
|
||||
|
||||
@@ -134,12 +134,12 @@ async def _process_batch_operations_background(
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'error_processing_batch_operation: {type(e).__name__}: {e}',
|
||||
'error_processing_batch_operation',
|
||||
extra={
|
||||
'path': batch_op.path,
|
||||
'method': str(batch_op.method),
|
||||
'error': str(e),
|
||||
},
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
@@ -59,8 +58,7 @@ async def github_events(
|
||||
)
|
||||
|
||||
try:
|
||||
# Add timeout to prevent hanging on slow/stalled clients
|
||||
payload = await asyncio.wait_for(request.body(), timeout=15.0)
|
||||
payload = await request.body()
|
||||
verify_github_signature(payload, x_hub_signature_256)
|
||||
|
||||
payload_data = await request.json()
|
||||
@@ -80,12 +78,6 @@ async def github_events(
|
||||
status_code=200,
|
||||
content={'message': 'GitHub events endpoint reached successfully.'},
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning('GitHub webhook request timed out waiting for request body')
|
||||
return JSONResponse(
|
||||
status_code=408,
|
||||
content={'error': 'Request timeout - client took too long to send data.'},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f'Error processing GitHub event: {e}')
|
||||
return JSONResponse(status_code=400, content={'error': 'Invalid payload.'})
|
||||
|
||||
@@ -1,324 +0,0 @@
|
||||
"""OAuth 2.0 Device Flow endpoints for CLI authentication."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.database import session_maker
|
||||
from storage.device_code_store import DeviceCodeStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DEVICE_CODE_EXPIRES_IN = 600 # 10 minutes
|
||||
DEVICE_TOKEN_POLL_INTERVAL = 5 # seconds
|
||||
|
||||
API_KEY_NAME = 'Device Link Access Key'
|
||||
KEY_EXPIRATION_TIME = timedelta(days=1) # Key expires in 24 hours
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DeviceAuthorizationResponse(BaseModel):
|
||||
device_code: str
|
||||
user_code: str
|
||||
verification_uri: str
|
||||
verification_uri_complete: str
|
||||
expires_in: int
|
||||
interval: int
|
||||
|
||||
|
||||
class DeviceTokenResponse(BaseModel):
|
||||
access_token: str # This will be the user's API key
|
||||
token_type: str = 'Bearer'
|
||||
expires_in: Optional[int] = None # API keys may not have expiration
|
||||
|
||||
|
||||
class DeviceTokenErrorResponse(BaseModel):
|
||||
error: str
|
||||
error_description: Optional[str] = None
|
||||
interval: Optional[int] = None # Required for slow_down error
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Router + stores
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
oauth_device_router = APIRouter(prefix='/oauth/device')
|
||||
device_code_store = DeviceCodeStore(session_maker)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _oauth_error(
|
||||
status_code: int,
|
||||
error: str,
|
||||
description: str,
|
||||
interval: Optional[int] = None,
|
||||
) -> JSONResponse:
|
||||
"""Return a JSON OAuth-style error response."""
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content=DeviceTokenErrorResponse(
|
||||
error=error,
|
||||
error_description=description,
|
||||
interval=interval,
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@oauth_device_router.post('/authorize', response_model=DeviceAuthorizationResponse)
|
||||
async def device_authorization(
|
||||
http_request: Request,
|
||||
) -> DeviceAuthorizationResponse:
|
||||
"""Start device flow by generating device and user codes."""
|
||||
try:
|
||||
device_code_entry = device_code_store.create_device_code(
|
||||
expires_in=DEVICE_CODE_EXPIRES_IN,
|
||||
)
|
||||
|
||||
base_url = str(http_request.base_url).rstrip('/')
|
||||
verification_uri = f'{base_url}/oauth/device/verify'
|
||||
verification_uri_complete = (
|
||||
f'{verification_uri}?user_code={device_code_entry.user_code}'
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Device authorization initiated',
|
||||
extra={'user_code': device_code_entry.user_code},
|
||||
)
|
||||
|
||||
return DeviceAuthorizationResponse(
|
||||
device_code=device_code_entry.device_code,
|
||||
user_code=device_code_entry.user_code,
|
||||
verification_uri=verification_uri,
|
||||
verification_uri_complete=verification_uri_complete,
|
||||
expires_in=DEVICE_CODE_EXPIRES_IN,
|
||||
interval=device_code_entry.current_interval,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception('Error in device authorization: %s', str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Internal server error',
|
||||
) from e
|
||||
|
||||
|
||||
@oauth_device_router.post('/token')
|
||||
async def device_token(device_code: str = Form(...)):
|
||||
"""Poll for a token until the user authorizes or the code expires."""
|
||||
try:
|
||||
device_code_entry = device_code_store.get_by_device_code(device_code)
|
||||
|
||||
if not device_code_entry:
|
||||
return _oauth_error(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
'invalid_grant',
|
||||
'Invalid device code',
|
||||
)
|
||||
|
||||
# Check rate limiting (RFC 8628 section 3.5)
|
||||
is_too_fast, current_interval = device_code_entry.check_rate_limit()
|
||||
if is_too_fast:
|
||||
# Update poll time and increase interval
|
||||
device_code_store.update_poll_time(device_code, increase_interval=True)
|
||||
logger.warning(
|
||||
'Client polling too fast, returning slow_down error',
|
||||
extra={
|
||||
'device_code': device_code[:8] + '...', # Log partial for privacy
|
||||
'new_interval': current_interval,
|
||||
},
|
||||
)
|
||||
return _oauth_error(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
'slow_down',
|
||||
f'Polling too frequently. Wait at least {current_interval} seconds between requests.',
|
||||
interval=current_interval,
|
||||
)
|
||||
|
||||
# Update poll time for successful rate limit check
|
||||
device_code_store.update_poll_time(device_code, increase_interval=False)
|
||||
|
||||
if device_code_entry.is_expired():
|
||||
return _oauth_error(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
'expired_token',
|
||||
'Device code has expired',
|
||||
)
|
||||
|
||||
if device_code_entry.status == 'denied':
|
||||
return _oauth_error(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
'access_denied',
|
||||
'User denied the authorization request',
|
||||
)
|
||||
|
||||
if device_code_entry.status == 'pending':
|
||||
return _oauth_error(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
'authorization_pending',
|
||||
'User has not yet completed authorization',
|
||||
)
|
||||
|
||||
if device_code_entry.status == 'authorized':
|
||||
# Retrieve the specific API key for this device using the user_code
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
device_key_name = f'{API_KEY_NAME} ({device_code_entry.user_code})'
|
||||
device_api_key = api_key_store.retrieve_api_key_by_name(
|
||||
device_code_entry.keycloak_user_id, device_key_name
|
||||
)
|
||||
|
||||
if not device_api_key:
|
||||
logger.error(
|
||||
'No device API key found for authorized device',
|
||||
extra={
|
||||
'user_id': device_code_entry.keycloak_user_id,
|
||||
'user_code': device_code_entry.user_code,
|
||||
},
|
||||
)
|
||||
return _oauth_error(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
'server_error',
|
||||
'API key not found',
|
||||
)
|
||||
|
||||
# Return the API key as access_token
|
||||
return DeviceTokenResponse(
|
||||
access_token=device_api_key,
|
||||
)
|
||||
|
||||
# Fallback for unexpected status values
|
||||
logger.error(
|
||||
'Unknown device code status',
|
||||
extra={'status': device_code_entry.status},
|
||||
)
|
||||
return _oauth_error(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
'server_error',
|
||||
'Unknown device code status',
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception('Error in device token: %s', str(e))
|
||||
return _oauth_error(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
'server_error',
|
||||
'Internal server error',
|
||||
)
|
||||
|
||||
|
||||
@oauth_device_router.post('/verify-authenticated')
|
||||
async def device_verification_authenticated(
|
||||
user_code: str = Form(...),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
"""Process device verification for authenticated users (called by frontend)."""
|
||||
try:
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail='Authentication required',
|
||||
)
|
||||
|
||||
# Validate device code
|
||||
device_code_entry = device_code_store.get_by_user_code(user_code)
|
||||
if not device_code_entry:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='The device code is invalid or has expired.',
|
||||
)
|
||||
|
||||
if not device_code_entry.is_pending():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='This device code has already been processed.',
|
||||
)
|
||||
|
||||
# First, authorize the device code
|
||||
success = device_code_store.authorize_device_code(
|
||||
user_code=user_code,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
if not success:
|
||||
logger.error(
|
||||
'Failed to authorize device code',
|
||||
extra={'user_code': user_code, 'user_id': user_id},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to authorize the device. Please try again.',
|
||||
)
|
||||
|
||||
# Only create API key AFTER successful authorization
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
try:
|
||||
# Create a unique API key for this device using user_code in the name
|
||||
device_key_name = f'{API_KEY_NAME} ({user_code})'
|
||||
api_key_store.create_api_key(
|
||||
user_id,
|
||||
name=device_key_name,
|
||||
expires_at=datetime.now(UTC) + KEY_EXPIRATION_TIME,
|
||||
)
|
||||
logger.info(
|
||||
'Created new device API key for user after successful authorization',
|
||||
extra={'user_id': user_id, 'user_code': user_code},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Failed to create device API key after authorization: %s', str(e)
|
||||
)
|
||||
|
||||
# Clean up: revert the device authorization since API key creation failed
|
||||
# This prevents the device from being in an authorized state without an API key
|
||||
try:
|
||||
device_code_store.deny_device_code(user_code)
|
||||
logger.info(
|
||||
'Reverted device authorization due to API key creation failure',
|
||||
extra={'user_code': user_code, 'user_id': user_id},
|
||||
)
|
||||
except Exception as cleanup_error:
|
||||
logger.exception(
|
||||
'Failed to revert device authorization during cleanup: %s',
|
||||
str(cleanup_error),
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to create API key for device access.',
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Device code authorized with API key successfully',
|
||||
extra={'user_code': user_code, 'user_id': user_id},
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={'message': 'Device authorized successfully!'},
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception('Error in device verification: %s', str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred. Please try again.',
|
||||
)
|
||||
@@ -31,7 +31,6 @@ from openhands.events.event_store import EventStore
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderHandler
|
||||
from openhands.runtime.impl.remote.remote_runtime import RemoteRuntime
|
||||
from openhands.runtime.plugins.vscode import VSCodeRequirement
|
||||
from openhands.runtime.runtime_status import RuntimeStatus
|
||||
from openhands.server.config.server_config import ServerConfig
|
||||
from openhands.server.constants import ROOM_KEY
|
||||
@@ -71,14 +70,6 @@ RUNTIME_CONVERSATION_URL = RUNTIME_URL_PATTERN + (
|
||||
else '/api/conversations/{conversation_id}'
|
||||
)
|
||||
|
||||
RUNTIME_USERNAME = os.getenv('RUNTIME_USERNAME')
|
||||
|
||||
SU_TO_USER = os.getenv('SU_TO_USER', 'false')
|
||||
truthy = {'1', 'true', 't', 'yes', 'y', 'on'}
|
||||
SU_TO_USER = str(SU_TO_USER.lower() in truthy).lower()
|
||||
|
||||
DISABLE_VSCODE_PLUGIN = os.getenv('DISABLE_VSCODE_PLUGIN', 'false').lower() == 'true'
|
||||
|
||||
# Time in seconds before a Redis entry is considered expired if not refreshed
|
||||
_REDIS_ENTRY_TIMEOUT_SECONDS = 300
|
||||
|
||||
@@ -781,11 +772,7 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
env_vars['SERVE_FRONTEND'] = '0'
|
||||
env_vars['RUNTIME'] = 'local'
|
||||
# TODO: In the long term we may come up with a more secure strategy for user management within the nested runtime.
|
||||
env_vars['USER'] = (
|
||||
RUNTIME_USERNAME
|
||||
if RUNTIME_USERNAME
|
||||
else ('openhands' if config.run_as_openhands else 'root')
|
||||
)
|
||||
env_vars['USER'] = 'openhands' if config.run_as_openhands else 'root'
|
||||
env_vars['PERMITTED_CORS_ORIGINS'] = ','.join(PERMITTED_CORS_ORIGINS)
|
||||
env_vars['port'] = '60000'
|
||||
# TODO: These values are static in the runtime-api project, but do not get copied into the runtime ENV
|
||||
@@ -802,10 +789,6 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
env_vars['INITIAL_NUM_WARM_SERVERS'] = '1'
|
||||
env_vars['INIT_GIT_IN_EMPTY_WORKSPACE'] = '1'
|
||||
env_vars['ENABLE_V1'] = '0'
|
||||
env_vars['SU_TO_USER'] = SU_TO_USER
|
||||
env_vars['DISABLE_VSCODE_PLUGIN'] = str(DISABLE_VSCODE_PLUGIN).lower()
|
||||
env_vars['BROWSERGYM_DOWNLOAD_DIR'] = '/workspace/.downloads/'
|
||||
env_vars['PLAYWRIGHT_BROWSERS_PATH'] = '/opt/playwright-browsers'
|
||||
|
||||
# We need this for LLM traces tracking to identify the source of the LLM calls
|
||||
env_vars['WEB_HOST'] = WEB_HOST
|
||||
@@ -821,18 +804,11 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
if self._runtime_container_image:
|
||||
config.sandbox.runtime_container_image = self._runtime_container_image
|
||||
|
||||
plugins = [
|
||||
plugin
|
||||
for plugin in agent.sandbox_plugins
|
||||
if not (DISABLE_VSCODE_PLUGIN and isinstance(plugin, VSCodeRequirement))
|
||||
]
|
||||
logger.info(f'Loaded plugins for runtime {sid}: {plugins}')
|
||||
|
||||
runtime = RemoteRuntime(
|
||||
config=config,
|
||||
event_stream=None, # type: ignore[arg-type]
|
||||
sid=sid,
|
||||
plugins=plugins,
|
||||
plugins=agent.sandbox_plugins,
|
||||
# env_vars=env_vars,
|
||||
# status_callback: Callable[..., None] | None = None,
|
||||
attach_to_existing=False,
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
# Sharing Package
|
||||
|
||||
This package contains functionality for sharing conversations.
|
||||
|
||||
## Components
|
||||
|
||||
- **shared.py**: Data models for shared conversations
|
||||
- **shared_conversation_info_service.py**: Service interface for accessing shared conversation info
|
||||
- **sql_shared_conversation_info_service.py**: SQL implementation of the shared conversation info service
|
||||
- **shared_event_service.py**: Service interface for accessing shared events
|
||||
- **shared_event_service_impl.py**: Implementation of the shared event service
|
||||
- **shared_conversation_router.py**: REST API endpoints for shared conversations
|
||||
- **shared_event_router.py**: REST API endpoints for shared events
|
||||
|
||||
## Features
|
||||
|
||||
- Read-only access to shared conversations
|
||||
- Event access for shared conversations
|
||||
- Search and filtering capabilities
|
||||
- Pagination support
|
||||
@@ -1,142 +0,0 @@
|
||||
"""Implementation of SharedEventService.
|
||||
|
||||
This implementation provides read-only access to events from shared conversations:
|
||||
- Validates that the conversation is shared before returning events
|
||||
- Uses existing EventService for actual event retrieval
|
||||
- Uses SharedConversationInfoService for shared conversation validation
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Request
|
||||
from server.sharing.shared_conversation_info_service import (
|
||||
SharedConversationInfoService,
|
||||
)
|
||||
from server.sharing.shared_event_service import (
|
||||
SharedEventService,
|
||||
SharedEventServiceInjector,
|
||||
)
|
||||
from server.sharing.sql_shared_conversation_info_service import (
|
||||
SQLSharedConversationInfoService,
|
||||
)
|
||||
|
||||
from openhands.agent_server.models import EventPage, EventSortOrder
|
||||
from openhands.app_server.event.event_service import EventService
|
||||
from openhands.app_server.event_callback.event_callback_models import EventKind
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
from openhands.sdk import Event
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SharedEventServiceImpl(SharedEventService):
|
||||
"""Implementation of SharedEventService that validates shared access."""
|
||||
|
||||
shared_conversation_info_service: SharedConversationInfoService
|
||||
event_service: EventService
|
||||
|
||||
async def get_shared_event(
|
||||
self, conversation_id: UUID, event_id: str
|
||||
) -> Event | None:
|
||||
"""Given a conversation_id and event_id, retrieve an event if the conversation is shared."""
|
||||
# First check if the conversation is shared
|
||||
shared_conversation_info = (
|
||||
await self.shared_conversation_info_service.get_shared_conversation_info(
|
||||
conversation_id
|
||||
)
|
||||
)
|
||||
if shared_conversation_info is None:
|
||||
return None
|
||||
|
||||
# If conversation is shared, get the event
|
||||
return await self.event_service.get_event(event_id)
|
||||
|
||||
async def search_shared_events(
|
||||
self,
|
||||
conversation_id: UUID,
|
||||
kind__eq: EventKind | None = None,
|
||||
timestamp__gte: datetime | None = None,
|
||||
timestamp__lt: datetime | None = None,
|
||||
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> EventPage:
|
||||
"""Search events for a specific shared conversation."""
|
||||
# First check if the conversation is shared
|
||||
shared_conversation_info = (
|
||||
await self.shared_conversation_info_service.get_shared_conversation_info(
|
||||
conversation_id
|
||||
)
|
||||
)
|
||||
if shared_conversation_info is None:
|
||||
# Return empty page if conversation is not shared
|
||||
return EventPage(items=[], next_page_id=None)
|
||||
|
||||
# If conversation is shared, search events for this conversation
|
||||
return await self.event_service.search_events(
|
||||
conversation_id__eq=conversation_id,
|
||||
kind__eq=kind__eq,
|
||||
timestamp__gte=timestamp__gte,
|
||||
timestamp__lt=timestamp__lt,
|
||||
sort_order=sort_order,
|
||||
page_id=page_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
async def count_shared_events(
|
||||
self,
|
||||
conversation_id: UUID,
|
||||
kind__eq: EventKind | None = None,
|
||||
timestamp__gte: datetime | None = None,
|
||||
timestamp__lt: datetime | None = None,
|
||||
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
|
||||
) -> int:
|
||||
"""Count events for a specific shared conversation."""
|
||||
# First check if the conversation is shared
|
||||
shared_conversation_info = (
|
||||
await self.shared_conversation_info_service.get_shared_conversation_info(
|
||||
conversation_id
|
||||
)
|
||||
)
|
||||
if shared_conversation_info is None:
|
||||
return 0
|
||||
|
||||
# If conversation is shared, count events for this conversation
|
||||
return await self.event_service.count_events(
|
||||
conversation_id__eq=conversation_id,
|
||||
kind__eq=kind__eq,
|
||||
timestamp__gte=timestamp__gte,
|
||||
timestamp__lt=timestamp__lt,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
|
||||
|
||||
class SharedEventServiceImplInjector(SharedEventServiceInjector):
|
||||
async def inject(
|
||||
self, state: InjectorState, request: Request | None = None
|
||||
) -> AsyncGenerator[SharedEventService, None]:
|
||||
# Define inline to prevent circular lookup
|
||||
from openhands.app_server.config import (
|
||||
get_db_session,
|
||||
get_event_service,
|
||||
)
|
||||
|
||||
async with (
|
||||
get_db_session(state, request) as db_session,
|
||||
get_event_service(state, request) as event_service,
|
||||
):
|
||||
shared_conversation_info_service = SQLSharedConversationInfoService(
|
||||
db_session=db_session
|
||||
)
|
||||
service = SharedEventServiceImpl(
|
||||
shared_conversation_info_service=shared_conversation_info_service,
|
||||
event_service=event_service,
|
||||
)
|
||||
yield service
|
||||
@@ -1,66 +0,0 @@
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from server.sharing.shared_conversation_models import (
|
||||
SharedConversation,
|
||||
SharedConversationPage,
|
||||
SharedConversationSortOrder,
|
||||
)
|
||||
|
||||
from openhands.app_server.services.injector import Injector
|
||||
from openhands.sdk.utils.models import DiscriminatedUnionMixin
|
||||
|
||||
|
||||
class SharedConversationInfoService(ABC):
|
||||
"""Service for accessing shared conversation info without user restrictions."""
|
||||
|
||||
@abstractmethod
|
||||
async def search_shared_conversation_info(
|
||||
self,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
sort_order: SharedConversationSortOrder = SharedConversationSortOrder.CREATED_AT_DESC,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
include_sub_conversations: bool = False,
|
||||
) -> SharedConversationPage:
|
||||
"""Search for shared conversations."""
|
||||
|
||||
@abstractmethod
|
||||
async def count_shared_conversation_info(
|
||||
self,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
) -> int:
|
||||
"""Count shared conversations."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_shared_conversation_info(
|
||||
self, conversation_id: UUID
|
||||
) -> SharedConversation | None:
|
||||
"""Get a single shared conversation info, returning None if missing or not shared."""
|
||||
|
||||
async def batch_get_shared_conversation_info(
|
||||
self, conversation_ids: list[UUID]
|
||||
) -> list[SharedConversation | None]:
|
||||
"""Get a batch of shared conversation info, return None for any missing or non-shared."""
|
||||
return await asyncio.gather(
|
||||
*[
|
||||
self.get_shared_conversation_info(conversation_id)
|
||||
for conversation_id in conversation_ids
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class SharedConversationInfoServiceInjector(
|
||||
DiscriminatedUnionMixin, Injector[SharedConversationInfoService], ABC
|
||||
):
|
||||
pass
|
||||
@@ -1,56 +0,0 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
# Simplified imports to avoid dependency chain issues
|
||||
# from openhands.integrations.service_types import ProviderType
|
||||
# from openhands.sdk.llm import MetricsSnapshot
|
||||
# from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
# For now, use Any to avoid import issues
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from openhands.agent_server.utils import OpenHandsUUID, utc_now
|
||||
|
||||
ProviderType = Any
|
||||
MetricsSnapshot = Any
|
||||
ConversationTrigger = Any
|
||||
|
||||
|
||||
class SharedConversation(BaseModel):
|
||||
"""Shared conversation info model with all fields from AppConversationInfo."""
|
||||
|
||||
id: OpenHandsUUID = Field(default_factory=uuid4)
|
||||
|
||||
created_by_user_id: str | None
|
||||
sandbox_id: str
|
||||
|
||||
selected_repository: str | None = None
|
||||
selected_branch: str | None = None
|
||||
git_provider: ProviderType | None = None
|
||||
title: str | None = None
|
||||
pr_number: list[int] = Field(default_factory=list)
|
||||
llm_model: str | None = None
|
||||
|
||||
metrics: MetricsSnapshot | None = None
|
||||
|
||||
parent_conversation_id: OpenHandsUUID | None = None
|
||||
sub_conversation_ids: list[OpenHandsUUID] = Field(default_factory=list)
|
||||
|
||||
created_at: datetime = Field(default_factory=utc_now)
|
||||
updated_at: datetime = Field(default_factory=utc_now)
|
||||
|
||||
|
||||
class SharedConversationSortOrder(Enum):
|
||||
CREATED_AT = 'CREATED_AT'
|
||||
CREATED_AT_DESC = 'CREATED_AT_DESC'
|
||||
UPDATED_AT = 'UPDATED_AT'
|
||||
UPDATED_AT_DESC = 'UPDATED_AT_DESC'
|
||||
TITLE = 'TITLE'
|
||||
TITLE_DESC = 'TITLE_DESC'
|
||||
|
||||
|
||||
class SharedConversationPage(BaseModel):
|
||||
items: list[SharedConversation]
|
||||
next_page_id: str | None = None
|
||||
@@ -1,135 +0,0 @@
|
||||
"""Shared Conversation router for OpenHands Server."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from server.sharing.shared_conversation_info_service import (
|
||||
SharedConversationInfoService,
|
||||
)
|
||||
from server.sharing.shared_conversation_models import (
|
||||
SharedConversation,
|
||||
SharedConversationPage,
|
||||
SharedConversationSortOrder,
|
||||
)
|
||||
from server.sharing.sql_shared_conversation_info_service import (
|
||||
SQLSharedConversationInfoServiceInjector,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix='/api/shared-conversations', tags=['Sharing'])
|
||||
shared_conversation_info_service_dependency = Depends(
|
||||
SQLSharedConversationInfoServiceInjector().depends
|
||||
)
|
||||
|
||||
# Read methods
|
||||
|
||||
|
||||
@router.get('/search')
|
||||
async def search_shared_conversations(
|
||||
title__contains: Annotated[
|
||||
str | None,
|
||||
Query(title='Filter by title containing this string'),
|
||||
] = None,
|
||||
created_at__gte: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Filter by created_at greater than or equal to this datetime'),
|
||||
] = None,
|
||||
created_at__lt: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Filter by created_at less than this datetime'),
|
||||
] = None,
|
||||
updated_at__gte: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Filter by updated_at greater than or equal to this datetime'),
|
||||
] = None,
|
||||
updated_at__lt: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Filter by updated_at less than this datetime'),
|
||||
] = None,
|
||||
sort_order: Annotated[
|
||||
SharedConversationSortOrder,
|
||||
Query(title='Sort order for results'),
|
||||
] = SharedConversationSortOrder.CREATED_AT_DESC,
|
||||
page_id: Annotated[
|
||||
str | None,
|
||||
Query(title='Optional next_page_id from the previously returned page'),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(
|
||||
title='The max number of results in the page',
|
||||
gt=0,
|
||||
lte=100,
|
||||
),
|
||||
] = 100,
|
||||
include_sub_conversations: Annotated[
|
||||
bool,
|
||||
Query(
|
||||
title='If True, include sub-conversations in the results. If False (default), exclude all sub-conversations.'
|
||||
),
|
||||
] = False,
|
||||
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
|
||||
) -> SharedConversationPage:
|
||||
"""Search / List shared conversations."""
|
||||
assert limit > 0
|
||||
assert limit <= 100
|
||||
return await shared_conversation_service.search_shared_conversation_info(
|
||||
title__contains=title__contains,
|
||||
created_at__gte=created_at__gte,
|
||||
created_at__lt=created_at__lt,
|
||||
updated_at__gte=updated_at__gte,
|
||||
updated_at__lt=updated_at__lt,
|
||||
sort_order=sort_order,
|
||||
page_id=page_id,
|
||||
limit=limit,
|
||||
include_sub_conversations=include_sub_conversations,
|
||||
)
|
||||
|
||||
|
||||
@router.get('/count')
|
||||
async def count_shared_conversations(
|
||||
title__contains: Annotated[
|
||||
str | None,
|
||||
Query(title='Filter by title containing this string'),
|
||||
] = None,
|
||||
created_at__gte: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Filter by created_at greater than or equal to this datetime'),
|
||||
] = None,
|
||||
created_at__lt: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Filter by created_at less than this datetime'),
|
||||
] = None,
|
||||
updated_at__gte: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Filter by updated_at greater than or equal to this datetime'),
|
||||
] = None,
|
||||
updated_at__lt: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Filter by updated_at less than this datetime'),
|
||||
] = None,
|
||||
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
|
||||
) -> int:
|
||||
"""Count shared conversations matching the given filters."""
|
||||
return await shared_conversation_service.count_shared_conversation_info(
|
||||
title__contains=title__contains,
|
||||
created_at__gte=created_at__gte,
|
||||
created_at__lt=created_at__lt,
|
||||
updated_at__gte=updated_at__gte,
|
||||
updated_at__lt=updated_at__lt,
|
||||
)
|
||||
|
||||
|
||||
@router.get('')
|
||||
async def batch_get_shared_conversations(
|
||||
ids: Annotated[list[str], Query()],
|
||||
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
|
||||
) -> list[SharedConversation | None]:
|
||||
"""Get a batch of shared conversations given their ids. Return None for any missing or non-shared."""
|
||||
assert len(ids) <= 100
|
||||
uuids = [UUID(id_) for id_ in ids]
|
||||
shared_conversation_info = (
|
||||
await shared_conversation_service.batch_get_shared_conversation_info(uuids)
|
||||
)
|
||||
return shared_conversation_info
|
||||
@@ -1,126 +0,0 @@
|
||||
"""Shared Event router for OpenHands Server."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from server.sharing.filesystem_shared_event_service import (
|
||||
SharedEventServiceImplInjector,
|
||||
)
|
||||
from server.sharing.shared_event_service import SharedEventService
|
||||
|
||||
from openhands.agent_server.models import EventPage, EventSortOrder
|
||||
from openhands.app_server.event_callback.event_callback_models import EventKind
|
||||
from openhands.sdk import Event
|
||||
|
||||
router = APIRouter(prefix='/api/shared-events', tags=['Sharing'])
|
||||
shared_event_service_dependency = Depends(SharedEventServiceImplInjector().depends)
|
||||
|
||||
|
||||
# Read methods
|
||||
|
||||
|
||||
@router.get('/search')
|
||||
async def search_shared_events(
|
||||
conversation_id: Annotated[
|
||||
str,
|
||||
Query(title='Conversation ID to search events for'),
|
||||
],
|
||||
kind__eq: Annotated[
|
||||
EventKind | None,
|
||||
Query(title='Optional filter by event kind'),
|
||||
] = None,
|
||||
timestamp__gte: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Optional filter by timestamp greater than or equal to'),
|
||||
] = None,
|
||||
timestamp__lt: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Optional filter by timestamp less than'),
|
||||
] = None,
|
||||
sort_order: Annotated[
|
||||
EventSortOrder,
|
||||
Query(title='Sort order for results'),
|
||||
] = EventSortOrder.TIMESTAMP,
|
||||
page_id: Annotated[
|
||||
str | None,
|
||||
Query(title='Optional next_page_id from the previously returned page'),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(title='The max number of results in the page', gt=0, lte=100),
|
||||
] = 100,
|
||||
shared_event_service: SharedEventService = shared_event_service_dependency,
|
||||
) -> EventPage:
|
||||
"""Search / List events for a shared conversation."""
|
||||
assert limit > 0
|
||||
assert limit <= 100
|
||||
return await shared_event_service.search_shared_events(
|
||||
conversation_id=UUID(conversation_id),
|
||||
kind__eq=kind__eq,
|
||||
timestamp__gte=timestamp__gte,
|
||||
timestamp__lt=timestamp__lt,
|
||||
sort_order=sort_order,
|
||||
page_id=page_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
||||
@router.get('/count')
|
||||
async def count_shared_events(
|
||||
conversation_id: Annotated[
|
||||
str,
|
||||
Query(title='Conversation ID to count events for'),
|
||||
],
|
||||
kind__eq: Annotated[
|
||||
EventKind | None,
|
||||
Query(title='Optional filter by event kind'),
|
||||
] = None,
|
||||
timestamp__gte: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Optional filter by timestamp greater than or equal to'),
|
||||
] = None,
|
||||
timestamp__lt: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Optional filter by timestamp less than'),
|
||||
] = None,
|
||||
sort_order: Annotated[
|
||||
EventSortOrder,
|
||||
Query(title='Sort order for results'),
|
||||
] = EventSortOrder.TIMESTAMP,
|
||||
shared_event_service: SharedEventService = shared_event_service_dependency,
|
||||
) -> int:
|
||||
"""Count events for a shared conversation matching the given filters."""
|
||||
return await shared_event_service.count_shared_events(
|
||||
conversation_id=UUID(conversation_id),
|
||||
kind__eq=kind__eq,
|
||||
timestamp__gte=timestamp__gte,
|
||||
timestamp__lt=timestamp__lt,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
|
||||
|
||||
@router.get('')
|
||||
async def batch_get_shared_events(
|
||||
conversation_id: Annotated[
|
||||
UUID,
|
||||
Query(title='Conversation ID to get events for'),
|
||||
],
|
||||
id: Annotated[list[str], Query()],
|
||||
shared_event_service: SharedEventService = shared_event_service_dependency,
|
||||
) -> list[Event | None]:
|
||||
"""Get a batch of events for a shared conversation given their ids, returning null for any missing event."""
|
||||
assert len(id) <= 100
|
||||
events = await shared_event_service.batch_get_shared_events(conversation_id, id)
|
||||
return events
|
||||
|
||||
|
||||
@router.get('/{conversation_id}/{event_id}')
|
||||
async def get_shared_event(
|
||||
conversation_id: UUID,
|
||||
event_id: str,
|
||||
shared_event_service: SharedEventService = shared_event_service_dependency,
|
||||
) -> Event | None:
|
||||
"""Get a single event from a shared conversation by conversation_id and event_id."""
|
||||
return await shared_event_service.get_shared_event(conversation_id, event_id)
|
||||
@@ -1,64 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from openhands.agent_server.models import EventPage, EventSortOrder
|
||||
from openhands.app_server.event_callback.event_callback_models import EventKind
|
||||
from openhands.app_server.services.injector import Injector
|
||||
from openhands.sdk import Event
|
||||
from openhands.sdk.utils.models import DiscriminatedUnionMixin
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SharedEventService(ABC):
|
||||
"""Event Service for getting events from shared conversations only."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_shared_event(
|
||||
self, conversation_id: UUID, event_id: str
|
||||
) -> Event | None:
|
||||
"""Given a conversation_id and event_id, retrieve an event if the conversation is shared."""
|
||||
|
||||
@abstractmethod
|
||||
async def search_shared_events(
|
||||
self,
|
||||
conversation_id: UUID,
|
||||
kind__eq: EventKind | None = None,
|
||||
timestamp__gte: datetime | None = None,
|
||||
timestamp__lt: datetime | None = None,
|
||||
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> EventPage:
|
||||
"""Search events for a specific shared conversation."""
|
||||
|
||||
@abstractmethod
|
||||
async def count_shared_events(
|
||||
self,
|
||||
conversation_id: UUID,
|
||||
kind__eq: EventKind | None = None,
|
||||
timestamp__gte: datetime | None = None,
|
||||
timestamp__lt: datetime | None = None,
|
||||
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
|
||||
) -> int:
|
||||
"""Count events for a specific shared conversation."""
|
||||
|
||||
async def batch_get_shared_events(
|
||||
self, conversation_id: UUID, event_ids: list[str]
|
||||
) -> list[Event | None]:
|
||||
"""Given a conversation_id and list of event_ids, get events if the conversation is shared."""
|
||||
return await asyncio.gather(
|
||||
*[
|
||||
self.get_shared_event(conversation_id, event_id)
|
||||
for event_id in event_ids
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class SharedEventServiceInjector(
|
||||
DiscriminatedUnionMixin, Injector[SharedEventService], ABC
|
||||
):
|
||||
pass
|
||||
@@ -1,282 +0,0 @@
|
||||
"""SQL implementation of SharedConversationInfoService.
|
||||
|
||||
This implementation provides read-only access to shared conversations:
|
||||
- Direct database access without user permission checks
|
||||
- Filters only conversations marked as shared (currently public)
|
||||
- Full async/await support using SQL async db_sessions
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import AsyncGenerator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Request
|
||||
from server.sharing.shared_conversation_info_service import (
|
||||
SharedConversationInfoService,
|
||||
SharedConversationInfoServiceInjector,
|
||||
)
|
||||
from server.sharing.shared_conversation_models import (
|
||||
SharedConversation,
|
||||
SharedConversationPage,
|
||||
SharedConversationSortOrder,
|
||||
)
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata,
|
||||
)
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.sdk.llm import MetricsSnapshot
|
||||
from openhands.sdk.llm.utils.metrics import TokenUsage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SQLSharedConversationInfoService(SharedConversationInfoService):
|
||||
"""SQL implementation of SharedConversationInfoService for shared conversations only."""
|
||||
|
||||
db_session: AsyncSession
|
||||
|
||||
async def search_shared_conversation_info(
|
||||
self,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
sort_order: SharedConversationSortOrder = SharedConversationSortOrder.CREATED_AT_DESC,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
include_sub_conversations: bool = False,
|
||||
) -> SharedConversationPage:
|
||||
"""Search for shared conversations."""
|
||||
query = self._public_select()
|
||||
|
||||
# Conditionally exclude sub-conversations based on the parameter
|
||||
if not include_sub_conversations:
|
||||
# Exclude sub-conversations (only include top-level conversations)
|
||||
query = query.where(
|
||||
StoredConversationMetadata.parent_conversation_id.is_(None)
|
||||
)
|
||||
|
||||
query = self._apply_filters(
|
||||
query=query,
|
||||
title__contains=title__contains,
|
||||
created_at__gte=created_at__gte,
|
||||
created_at__lt=created_at__lt,
|
||||
updated_at__gte=updated_at__gte,
|
||||
updated_at__lt=updated_at__lt,
|
||||
)
|
||||
|
||||
# Add sort order
|
||||
if sort_order == SharedConversationSortOrder.CREATED_AT:
|
||||
query = query.order_by(StoredConversationMetadata.created_at)
|
||||
elif sort_order == SharedConversationSortOrder.CREATED_AT_DESC:
|
||||
query = query.order_by(StoredConversationMetadata.created_at.desc())
|
||||
elif sort_order == SharedConversationSortOrder.UPDATED_AT:
|
||||
query = query.order_by(StoredConversationMetadata.last_updated_at)
|
||||
elif sort_order == SharedConversationSortOrder.UPDATED_AT_DESC:
|
||||
query = query.order_by(StoredConversationMetadata.last_updated_at.desc())
|
||||
elif sort_order == SharedConversationSortOrder.TITLE:
|
||||
query = query.order_by(StoredConversationMetadata.title)
|
||||
elif sort_order == SharedConversationSortOrder.TITLE_DESC:
|
||||
query = query.order_by(StoredConversationMetadata.title.desc())
|
||||
|
||||
# Apply pagination
|
||||
if page_id is not None:
|
||||
try:
|
||||
offset = int(page_id)
|
||||
query = query.offset(offset)
|
||||
except ValueError:
|
||||
# If page_id is not a valid integer, start from beginning
|
||||
offset = 0
|
||||
else:
|
||||
offset = 0
|
||||
|
||||
# Apply limit and get one extra to check if there are more results
|
||||
query = query.limit(limit + 1)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
rows = result.scalars().all()
|
||||
|
||||
# Check if there are more results
|
||||
has_more = len(rows) > limit
|
||||
if has_more:
|
||||
rows = rows[:limit]
|
||||
|
||||
items = [self._to_shared_conversation(row) for row in rows]
|
||||
|
||||
# Calculate next page ID
|
||||
next_page_id = None
|
||||
if has_more:
|
||||
next_page_id = str(offset + limit)
|
||||
|
||||
return SharedConversationPage(items=items, next_page_id=next_page_id)
|
||||
|
||||
async def count_shared_conversation_info(
|
||||
self,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
) -> int:
|
||||
"""Count shared conversations matching the given filters."""
|
||||
from sqlalchemy import func
|
||||
|
||||
query = select(func.count(StoredConversationMetadata.conversation_id))
|
||||
# Only include shared conversations
|
||||
query = query.where(StoredConversationMetadata.public == True) # noqa: E712
|
||||
query = query.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
|
||||
query = self._apply_filters(
|
||||
query=query,
|
||||
title__contains=title__contains,
|
||||
created_at__gte=created_at__gte,
|
||||
created_at__lt=created_at__lt,
|
||||
updated_at__gte=updated_at__gte,
|
||||
updated_at__lt=updated_at__lt,
|
||||
)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def get_shared_conversation_info(
|
||||
self, conversation_id: UUID
|
||||
) -> SharedConversation | None:
|
||||
"""Get a single public conversation info, returning None if missing or not shared."""
|
||||
query = self._public_select().where(
|
||||
StoredConversationMetadata.conversation_id == str(conversation_id)
|
||||
)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
stored = result.scalar_one_or_none()
|
||||
|
||||
if stored is None:
|
||||
return None
|
||||
|
||||
return self._to_shared_conversation(stored)
|
||||
|
||||
def _public_select(self):
|
||||
"""Create a select query that only returns public conversations."""
|
||||
query = select(StoredConversationMetadata).where(
|
||||
StoredConversationMetadata.conversation_version == 'V1'
|
||||
)
|
||||
# Only include conversations marked as public
|
||||
query = query.where(StoredConversationMetadata.public == True) # noqa: E712
|
||||
return query
|
||||
|
||||
def _apply_filters(
|
||||
self,
|
||||
query,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
):
|
||||
"""Apply common filters to a query."""
|
||||
if title__contains is not None:
|
||||
query = query.where(
|
||||
StoredConversationMetadata.title.contains(title__contains)
|
||||
)
|
||||
|
||||
if created_at__gte is not None:
|
||||
query = query.where(
|
||||
StoredConversationMetadata.created_at >= created_at__gte
|
||||
)
|
||||
|
||||
if created_at__lt is not None:
|
||||
query = query.where(StoredConversationMetadata.created_at < created_at__lt)
|
||||
|
||||
if updated_at__gte is not None:
|
||||
query = query.where(
|
||||
StoredConversationMetadata.last_updated_at >= updated_at__gte
|
||||
)
|
||||
|
||||
if updated_at__lt is not None:
|
||||
query = query.where(
|
||||
StoredConversationMetadata.last_updated_at < updated_at__lt
|
||||
)
|
||||
|
||||
return query
|
||||
|
||||
def _to_shared_conversation(
|
||||
self,
|
||||
stored: StoredConversationMetadata,
|
||||
sub_conversation_ids: list[UUID] | None = None,
|
||||
) -> SharedConversation:
|
||||
"""Convert StoredConversationMetadata to SharedConversation."""
|
||||
# V1 conversations should always have a sandbox_id
|
||||
sandbox_id = stored.sandbox_id
|
||||
assert sandbox_id is not None
|
||||
|
||||
# Rebuild token usage
|
||||
token_usage = TokenUsage(
|
||||
prompt_tokens=stored.prompt_tokens,
|
||||
completion_tokens=stored.completion_tokens,
|
||||
cache_read_tokens=stored.cache_read_tokens,
|
||||
cache_write_tokens=stored.cache_write_tokens,
|
||||
context_window=stored.context_window,
|
||||
per_turn_token=stored.per_turn_token,
|
||||
)
|
||||
|
||||
# Rebuild metrics object
|
||||
metrics = MetricsSnapshot(
|
||||
accumulated_cost=stored.accumulated_cost,
|
||||
max_budget_per_task=stored.max_budget_per_task,
|
||||
accumulated_token_usage=token_usage,
|
||||
)
|
||||
|
||||
# Get timestamps
|
||||
created_at = self._fix_timezone(stored.created_at)
|
||||
updated_at = self._fix_timezone(stored.last_updated_at)
|
||||
|
||||
return SharedConversation(
|
||||
id=UUID(stored.conversation_id),
|
||||
created_by_user_id=stored.user_id if stored.user_id else None,
|
||||
sandbox_id=stored.sandbox_id,
|
||||
selected_repository=stored.selected_repository,
|
||||
selected_branch=stored.selected_branch,
|
||||
git_provider=(
|
||||
ProviderType(stored.git_provider) if stored.git_provider else None
|
||||
),
|
||||
title=stored.title,
|
||||
pr_number=stored.pr_number,
|
||||
llm_model=stored.llm_model,
|
||||
metrics=metrics,
|
||||
parent_conversation_id=(
|
||||
UUID(stored.parent_conversation_id)
|
||||
if stored.parent_conversation_id
|
||||
else None
|
||||
),
|
||||
sub_conversation_ids=sub_conversation_ids or [],
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
)
|
||||
|
||||
def _fix_timezone(self, value: datetime) -> datetime:
|
||||
"""Sqlite does not store timezones - and since we can't update the existing models
|
||||
we assume UTC if the timezone is missing."""
|
||||
if not value.tzinfo:
|
||||
value = value.replace(tzinfo=UTC)
|
||||
return value
|
||||
|
||||
|
||||
class SQLSharedConversationInfoServiceInjector(SharedConversationInfoServiceInjector):
|
||||
async def inject(
|
||||
self, state: InjectorState, request: Request | None = None
|
||||
) -> AsyncGenerator[SharedConversationInfoService, None]:
|
||||
# Define inline to prevent circular lookup
|
||||
from openhands.app_server.config import get_db_session
|
||||
|
||||
async with get_db_session(state, request) as db_session:
|
||||
service = SQLSharedConversationInfoService(db_session=db_session)
|
||||
yield service
|
||||
@@ -1,83 +0,0 @@
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.shared import sio
|
||||
|
||||
# Rate limiting constants
|
||||
RATE_LIMIT_USER_SECONDS = 120 # 2 minutes per user_id
|
||||
RATE_LIMIT_IP_SECONDS = 300 # 5 minutes per IP address
|
||||
|
||||
|
||||
async def check_rate_limit_by_user_id(
|
||||
request: Request,
|
||||
key_prefix: str,
|
||||
user_id: str | None,
|
||||
user_rate_limit_seconds: int = RATE_LIMIT_USER_SECONDS,
|
||||
ip_rate_limit_seconds: int = RATE_LIMIT_IP_SECONDS,
|
||||
) -> None:
|
||||
"""
|
||||
Check rate limit for requests, using user_id when available, falling back to IP address.
|
||||
|
||||
Uses Redis to store rate limit keys with expiration. If a key already exists,
|
||||
it means the rate limit is active and the request will be rejected.
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
key_prefix: Prefix for the Redis key (e.g., "email_resend")
|
||||
user_id: User ID if available, None otherwise
|
||||
user_rate_limit_seconds: Rate limit window in seconds for user_id-based limiting (default: 120)
|
||||
ip_rate_limit_seconds: Rate limit window in seconds for IP-based limiting (default: 300)
|
||||
|
||||
Raises:
|
||||
HTTPException: If rate limit is exceeded (429 status code)
|
||||
"""
|
||||
try:
|
||||
redis = sio.manager.redis
|
||||
if not redis:
|
||||
# If Redis is unavailable, log warning and allow request (fail open)
|
||||
logger.warning('Redis unavailable for rate limiting, allowing request')
|
||||
return
|
||||
|
||||
if user_id:
|
||||
# Rate limit by user_id (primary method)
|
||||
rate_limit_key = f'{key_prefix}:{user_id}'
|
||||
rate_limit_seconds = user_rate_limit_seconds
|
||||
else:
|
||||
# Fallback to IP address rate limiting
|
||||
client_ip = request.client.host if request.client else 'unknown'
|
||||
rate_limit_key = f'{key_prefix}:ip:{client_ip}'
|
||||
rate_limit_seconds = ip_rate_limit_seconds
|
||||
|
||||
# Try to set the key with expiration. If it already exists (nx=True fails),
|
||||
# it means the rate limit is active
|
||||
created = await redis.set(rate_limit_key, 1, nx=True, ex=rate_limit_seconds)
|
||||
|
||||
if not created:
|
||||
logger.info(
|
||||
f'Rate limit exceeded for {rate_limit_key}',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'ip': request.client.host if request.client else 'unknown',
|
||||
},
|
||||
)
|
||||
# Format error message based on duration
|
||||
if rate_limit_seconds < 60:
|
||||
wait_message = f'{rate_limit_seconds} seconds'
|
||||
elif rate_limit_seconds % 60 == 0:
|
||||
wait_message = f'{rate_limit_seconds // 60} minute{"s" if rate_limit_seconds // 60 != 1 else ""}'
|
||||
else:
|
||||
minutes = rate_limit_seconds // 60
|
||||
seconds = rate_limit_seconds % 60
|
||||
wait_message = f'{minutes} minute{"s" if minutes != 1 else ""} and {seconds} second{"s" if seconds != 1 else ""}'
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail=f'Too many requests. Please wait {wait_message} before trying again.',
|
||||
)
|
||||
except HTTPException:
|
||||
# Re-raise HTTPException (rate limit exceeded)
|
||||
raise
|
||||
except Exception as e:
|
||||
# Log error but allow request (fail open) to avoid blocking legitimate users
|
||||
logger.warning(f'Error checking rate limit: {e}', exc_info=True)
|
||||
return
|
||||
@@ -17,13 +17,10 @@ from openhands.core.logger import openhands_logger as logger
|
||||
class ApiKeyStore:
|
||||
session_maker: sessionmaker
|
||||
|
||||
API_KEY_PREFIX = 'sk-oh-'
|
||||
|
||||
def generate_api_key(self, length: int = 32) -> str:
|
||||
"""Generate a random API key with the sk-oh- prefix."""
|
||||
"""Generate a random API key."""
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
random_part = ''.join(secrets.choice(alphabet) for _ in range(length))
|
||||
return f'{self.API_KEY_PREFIX}{random_part}'
|
||||
return ''.join(secrets.choice(alphabet) for _ in range(length))
|
||||
|
||||
def create_api_key(
|
||||
self, user_id: str, name: str | None = None, expires_at: datetime | None = None
|
||||
@@ -60,15 +57,9 @@ class ApiKeyStore:
|
||||
return None
|
||||
|
||||
# Check if the key has expired
|
||||
if key_record.expires_at:
|
||||
# Handle timezone-naive datetime from database by assuming it's UTC
|
||||
expires_at = key_record.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
|
||||
if expires_at < now:
|
||||
logger.info(f'API key has expired: {key_record.id}')
|
||||
return None
|
||||
if key_record.expires_at and key_record.expires_at < now:
|
||||
logger.info(f'API key has expired: {key_record.id}')
|
||||
return None
|
||||
|
||||
# Update last_used_at timestamp
|
||||
session.execute(
|
||||
@@ -134,33 +125,6 @@ class ApiKeyStore:
|
||||
|
||||
return None
|
||||
|
||||
def retrieve_api_key_by_name(self, user_id: str, name: str) -> str | None:
|
||||
"""Retrieve an API key by name for a specific user."""
|
||||
with self.session_maker() as session:
|
||||
key_record = (
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id, ApiKey.name == name)
|
||||
.first()
|
||||
)
|
||||
return key_record.key if key_record else None
|
||||
|
||||
def delete_api_key_by_name(self, user_id: str, name: str) -> bool:
|
||||
"""Delete an API key by name for a specific user."""
|
||||
with self.session_maker() as session:
|
||||
key_record = (
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id, ApiKey.name == name)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not key_record:
|
||||
return False
|
||||
|
||||
session.delete(key_record)
|
||||
session.commit()
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> ApiKeyStore:
|
||||
"""Get an instance of the ApiKeyStore."""
|
||||
|
||||
@@ -19,23 +19,17 @@ GCP_REGION = os.environ.get('GCP_REGION')
|
||||
|
||||
POOL_SIZE = int(os.environ.get('DB_POOL_SIZE', '25'))
|
||||
MAX_OVERFLOW = int(os.environ.get('DB_MAX_OVERFLOW', '10'))
|
||||
POOL_RECYCLE = int(os.environ.get('DB_POOL_RECYCLE', '1800'))
|
||||
|
||||
# Initialize Cloud SQL Connector once at module level for GCP environments.
|
||||
_connector = None
|
||||
|
||||
|
||||
def _get_db_engine():
|
||||
if GCP_DB_INSTANCE: # GCP environments
|
||||
|
||||
def get_db_connection():
|
||||
global _connector
|
||||
from google.cloud.sql.connector import Connector
|
||||
|
||||
if not _connector:
|
||||
_connector = Connector()
|
||||
connector = Connector()
|
||||
instance_string = f'{GCP_PROJECT}:{GCP_REGION}:{GCP_DB_INSTANCE}'
|
||||
return _connector.connect(
|
||||
return connector.connect(
|
||||
instance_string, 'pg8000', user=DB_USER, password=DB_PASS, db=DB_NAME
|
||||
)
|
||||
|
||||
@@ -44,7 +38,6 @@ def _get_db_engine():
|
||||
creator=get_db_connection,
|
||||
pool_size=POOL_SIZE,
|
||||
max_overflow=MAX_OVERFLOW,
|
||||
pool_recycle=POOL_RECYCLE,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
else:
|
||||
@@ -55,7 +48,6 @@ def _get_db_engine():
|
||||
host_string,
|
||||
pool_size=POOL_SIZE,
|
||||
max_overflow=MAX_OVERFLOW,
|
||||
pool_recycle=POOL_RECYCLE,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,109 +0,0 @@
|
||||
"""Device code storage model for OAuth 2.0 Device Flow."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import Column, DateTime, Integer, String
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class DeviceCodeStatus(Enum):
|
||||
"""Status of a device code authorization request."""
|
||||
|
||||
PENDING = 'pending'
|
||||
AUTHORIZED = 'authorized'
|
||||
EXPIRED = 'expired'
|
||||
DENIED = 'denied'
|
||||
|
||||
|
||||
class DeviceCode(Base):
|
||||
"""Device code for OAuth 2.0 Device Flow.
|
||||
|
||||
This stores the device codes issued during the device authorization flow,
|
||||
along with their status and associated user information once authorized.
|
||||
"""
|
||||
|
||||
__tablename__ = 'device_codes'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
device_code = Column(String(128), unique=True, nullable=False, index=True)
|
||||
user_code = Column(String(16), unique=True, nullable=False, index=True)
|
||||
status = Column(String(32), nullable=False, default=DeviceCodeStatus.PENDING.value)
|
||||
|
||||
# Keycloak user ID who authorized the device (set during verification)
|
||||
keycloak_user_id = Column(String(255), nullable=True)
|
||||
|
||||
# Timestamps
|
||||
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||
authorized_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Rate limiting fields for RFC 8628 section 3.5 compliance
|
||||
last_poll_time = Column(DateTime(timezone=True), nullable=True)
|
||||
current_interval = Column(Integer, nullable=False, default=5)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<DeviceCode(user_code='{self.user_code}', status='{self.status}')>"
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the device code has expired."""
|
||||
now = datetime.now(timezone.utc)
|
||||
return now > self.expires_at
|
||||
|
||||
def is_pending(self) -> bool:
|
||||
"""Check if the device code is still pending authorization."""
|
||||
return self.status == DeviceCodeStatus.PENDING.value and not self.is_expired()
|
||||
|
||||
def is_authorized(self) -> bool:
|
||||
"""Check if the device code has been authorized."""
|
||||
return self.status == DeviceCodeStatus.AUTHORIZED.value
|
||||
|
||||
def authorize(self, user_id: str) -> None:
|
||||
"""Mark the device code as authorized."""
|
||||
self.status = DeviceCodeStatus.AUTHORIZED.value
|
||||
self.keycloak_user_id = user_id # Set the Keycloak user ID during authorization
|
||||
self.authorized_at = datetime.now(timezone.utc)
|
||||
|
||||
def deny(self) -> None:
|
||||
"""Mark the device code as denied."""
|
||||
self.status = DeviceCodeStatus.DENIED.value
|
||||
|
||||
def expire(self) -> None:
|
||||
"""Mark the device code as expired."""
|
||||
self.status = DeviceCodeStatus.EXPIRED.value
|
||||
|
||||
def check_rate_limit(self) -> tuple[bool, int]:
|
||||
"""Check if the client is polling too fast.
|
||||
|
||||
Returns:
|
||||
tuple: (is_too_fast, current_interval)
|
||||
- is_too_fast: True if client should receive slow_down error
|
||||
- current_interval: Current polling interval to use
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# If this is the first poll, allow it
|
||||
if self.last_poll_time is None:
|
||||
return False, self.current_interval
|
||||
|
||||
# Calculate time since last poll
|
||||
time_since_last_poll = (now - self.last_poll_time).total_seconds()
|
||||
|
||||
# Check if polling too fast
|
||||
if time_since_last_poll < self.current_interval:
|
||||
# Increase interval for slow_down (RFC 8628 section 3.5)
|
||||
new_interval = min(self.current_interval + 5, 60) # Cap at 60 seconds
|
||||
return True, new_interval
|
||||
|
||||
return False, self.current_interval
|
||||
|
||||
def update_poll_time(self, increase_interval: bool = False) -> None:
|
||||
"""Update the last poll time and optionally increase the interval.
|
||||
|
||||
Args:
|
||||
increase_interval: If True, increase the current interval for slow_down
|
||||
"""
|
||||
self.last_poll_time = datetime.now(timezone.utc)
|
||||
|
||||
if increase_interval:
|
||||
# Increase interval by 5 seconds, cap at 60 seconds (RFC 8628)
|
||||
self.current_interval = min(self.current_interval + 5, 60)
|
||||
@@ -1,167 +0,0 @@
|
||||
"""Device code store for OAuth 2.0 Device Flow."""
|
||||
|
||||
import secrets
|
||||
import string
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from storage.device_code import DeviceCode
|
||||
|
||||
|
||||
class DeviceCodeStore:
|
||||
"""Store for managing OAuth 2.0 device codes."""
|
||||
|
||||
def __init__(self, session_maker):
|
||||
self.session_maker = session_maker
|
||||
|
||||
def generate_user_code(self) -> str:
|
||||
"""Generate a human-readable user code (8 characters, uppercase letters and digits)."""
|
||||
# Use a mix of uppercase letters and digits, avoiding confusing characters
|
||||
alphabet = 'ABCDEFGHJKLMNPQRSTUVWXYZ23456789' # No I, O, 0, 1
|
||||
return ''.join(secrets.choice(alphabet) for _ in range(8))
|
||||
|
||||
def generate_device_code(self) -> str:
|
||||
"""Generate a secure device code (128 characters)."""
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
return ''.join(secrets.choice(alphabet) for _ in range(128))
|
||||
|
||||
def create_device_code(
|
||||
self,
|
||||
expires_in: int = 600, # 10 minutes default
|
||||
max_attempts: int = 10,
|
||||
) -> DeviceCode:
|
||||
"""Create a new device code entry.
|
||||
|
||||
Uses database constraints to ensure uniqueness, avoiding TOCTOU race conditions.
|
||||
Retries on constraint violations until unique codes are generated.
|
||||
|
||||
Args:
|
||||
expires_in: Expiration time in seconds
|
||||
max_attempts: Maximum number of attempts to generate unique codes
|
||||
|
||||
Returns:
|
||||
The created DeviceCode instance
|
||||
|
||||
Raises:
|
||||
RuntimeError: If unable to generate unique codes after max_attempts
|
||||
"""
|
||||
for attempt in range(max_attempts):
|
||||
user_code = self.generate_user_code()
|
||||
device_code = self.generate_device_code()
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
|
||||
|
||||
device_code_entry = DeviceCode(
|
||||
device_code=device_code,
|
||||
user_code=user_code,
|
||||
keycloak_user_id=None, # Will be set during authorization
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
try:
|
||||
with self.session_maker() as session:
|
||||
session.add(device_code_entry)
|
||||
session.commit()
|
||||
session.refresh(device_code_entry)
|
||||
session.expunge(device_code_entry) # Detach from session cleanly
|
||||
return device_code_entry
|
||||
except IntegrityError:
|
||||
# Constraint violation - codes already exist, retry with new codes
|
||||
continue
|
||||
|
||||
raise RuntimeError(
|
||||
f'Failed to generate unique device codes after {max_attempts} attempts'
|
||||
)
|
||||
|
||||
def get_by_device_code(self, device_code: str) -> DeviceCode | None:
|
||||
"""Get device code entry by device code."""
|
||||
with self.session_maker() as session:
|
||||
result = (
|
||||
session.query(DeviceCode).filter_by(device_code=device_code).first()
|
||||
)
|
||||
if result:
|
||||
session.expunge(result) # Detach from session cleanly
|
||||
return result
|
||||
|
||||
def get_by_user_code(self, user_code: str) -> DeviceCode | None:
|
||||
"""Get device code entry by user code."""
|
||||
with self.session_maker() as session:
|
||||
result = session.query(DeviceCode).filter_by(user_code=user_code).first()
|
||||
if result:
|
||||
session.expunge(result) # Detach from session cleanly
|
||||
return result
|
||||
|
||||
def authorize_device_code(self, user_code: str, user_id: str) -> bool:
|
||||
"""Authorize a device code.
|
||||
|
||||
Args:
|
||||
user_code: The user code to authorize
|
||||
user_id: The user ID from Keycloak
|
||||
|
||||
Returns:
|
||||
True if authorization was successful, False otherwise
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
device_code_entry = (
|
||||
session.query(DeviceCode).filter_by(user_code=user_code).first()
|
||||
)
|
||||
|
||||
if not device_code_entry:
|
||||
return False
|
||||
|
||||
if not device_code_entry.is_pending():
|
||||
return False
|
||||
|
||||
device_code_entry.authorize(user_id)
|
||||
session.commit()
|
||||
|
||||
return True
|
||||
|
||||
def deny_device_code(self, user_code: str) -> bool:
|
||||
"""Deny a device code authorization.
|
||||
|
||||
Args:
|
||||
user_code: The user code to deny
|
||||
|
||||
Returns:
|
||||
True if denial was successful, False otherwise
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
device_code_entry = (
|
||||
session.query(DeviceCode).filter_by(user_code=user_code).first()
|
||||
)
|
||||
|
||||
if not device_code_entry:
|
||||
return False
|
||||
|
||||
if not device_code_entry.is_pending():
|
||||
return False
|
||||
|
||||
device_code_entry.deny()
|
||||
session.commit()
|
||||
|
||||
return True
|
||||
|
||||
def update_poll_time(
|
||||
self, device_code: str, increase_interval: bool = False
|
||||
) -> bool:
|
||||
"""Update the poll time for a device code and optionally increase interval.
|
||||
|
||||
Args:
|
||||
device_code: The device code to update
|
||||
increase_interval: If True, increase the polling interval for slow_down
|
||||
|
||||
Returns:
|
||||
True if update was successful, False otherwise
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
device_code_entry = (
|
||||
session.query(DeviceCode).filter_by(device_code=device_code).first()
|
||||
)
|
||||
|
||||
if not device_code_entry:
|
||||
return False
|
||||
|
||||
device_code_entry.update_poll_time(increase_interval)
|
||||
session.commit()
|
||||
|
||||
return True
|
||||
@@ -60,8 +60,6 @@ class SaasConversationStore(ConversationStore):
|
||||
kwargs.pop('reasoning_tokens', None)
|
||||
kwargs.pop('context_window', None)
|
||||
kwargs.pop('per_turn_token', None)
|
||||
kwargs.pop('parent_conversation_id', None)
|
||||
kwargs.pop('public')
|
||||
|
||||
return ConversationMetadata(**kwargs)
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ from server.constants import (
|
||||
LITE_LLM_API_URL,
|
||||
LITE_LLM_TEAM_ID,
|
||||
REQUIRE_PAYMENT,
|
||||
USER_SETTINGS_VERSION_TO_MODEL,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.logger import logger
|
||||
@@ -95,14 +94,9 @@ class SaasSettingsStore(SettingsStore):
|
||||
}
|
||||
self._decrypt_kwargs(kwargs)
|
||||
settings = Settings(**kwargs)
|
||||
|
||||
return settings
|
||||
|
||||
async def store(self, item: Settings):
|
||||
# Check if provider is OpenHands and generate API key if needed
|
||||
if item and self._is_openhands_provider(item):
|
||||
await self._ensure_openhands_api_key(item)
|
||||
|
||||
with self.session_maker() as session:
|
||||
existing = None
|
||||
kwargs = {}
|
||||
@@ -203,53 +197,6 @@ class SaasSettingsStore(SettingsStore):
|
||||
)
|
||||
return None
|
||||
|
||||
def _has_custom_settings(
|
||||
self, settings: Settings, old_user_version: int | None
|
||||
) -> bool:
|
||||
"""
|
||||
Check if user has custom LLM settings that should be preserved.
|
||||
Returns True if user customized either model or base_url.
|
||||
|
||||
Args:
|
||||
settings: The user's current settings
|
||||
old_user_version: The user's old settings version, if any
|
||||
|
||||
Returns:
|
||||
True if user has custom settings, False if using old defaults
|
||||
"""
|
||||
# Normalize values
|
||||
user_model = (
|
||||
settings.llm_model.strip()
|
||||
if settings.llm_model and settings.llm_model.strip()
|
||||
else None
|
||||
)
|
||||
user_base_url = (
|
||||
settings.llm_base_url.strip()
|
||||
if settings.llm_base_url and settings.llm_base_url.strip()
|
||||
else None
|
||||
)
|
||||
|
||||
# Custom base_url = definitely custom settings (BYOK)
|
||||
if user_base_url and user_base_url != LITE_LLM_API_URL:
|
||||
return True
|
||||
|
||||
# No model set = using defaults
|
||||
if not user_model:
|
||||
return False
|
||||
|
||||
# Check if model matches old version's default
|
||||
if (
|
||||
old_user_version
|
||||
and old_user_version < CURRENT_USER_SETTINGS_VERSION
|
||||
and old_user_version in USER_SETTINGS_VERSION_TO_MODEL
|
||||
):
|
||||
old_default_base = USER_SETTINGS_VERSION_TO_MODEL[old_user_version]
|
||||
user_model_base = user_model.split('/')[-1]
|
||||
if user_model_base == old_default_base:
|
||||
return False # Matches old default
|
||||
|
||||
return True # Custom model
|
||||
|
||||
async def update_settings_with_litellm_default(
|
||||
self, settings: Settings
|
||||
) -> Settings | None:
|
||||
@@ -261,17 +208,6 @@ class SaasSettingsStore(SettingsStore):
|
||||
return None
|
||||
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
|
||||
key = LITE_LLM_API_KEY
|
||||
|
||||
# Check if user has custom settings
|
||||
has_custom = self._has_custom_settings(settings, settings.user_version)
|
||||
|
||||
# Determine model to use (needed before LiteLLM user creation)
|
||||
llm_model_to_use = (
|
||||
settings.llm_model
|
||||
if has_custom and settings.llm_model
|
||||
else get_default_litellm_model()
|
||||
)
|
||||
|
||||
if not local_deploy:
|
||||
# Get user info to add to litellm
|
||||
token_manager = TokenManager()
|
||||
@@ -335,7 +271,7 @@ class SaasSettingsStore(SettingsStore):
|
||||
|
||||
# Create the new litellm user
|
||||
response = await self._create_user_in_lite_llm(
|
||||
client, email, max_budget, spend, llm_model_to_use
|
||||
client, email, max_budget, spend
|
||||
)
|
||||
if not response.is_success:
|
||||
logger.warning(
|
||||
@@ -344,7 +280,7 @@ class SaasSettingsStore(SettingsStore):
|
||||
)
|
||||
# Litellm insists on unique email addresses - it is possible the email address was registered with a different user.
|
||||
response = await self._create_user_in_lite_llm(
|
||||
client, None, max_budget, spend, llm_model_to_use
|
||||
client, None, max_budget, spend
|
||||
)
|
||||
|
||||
# User failed to create in litellm - this is an unforseen error state...
|
||||
@@ -370,17 +306,11 @@ class SaasSettingsStore(SettingsStore):
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
|
||||
if has_custom:
|
||||
settings.llm_model = settings.llm_model or get_default_litellm_model()
|
||||
settings.llm_base_url = settings.llm_base_url or LITE_LLM_API_URL
|
||||
settings.llm_api_key = settings.llm_api_key or SecretStr(key)
|
||||
else:
|
||||
settings.llm_model = get_default_litellm_model()
|
||||
settings.llm_base_url = LITE_LLM_API_URL
|
||||
settings.llm_api_key = SecretStr(key)
|
||||
|
||||
settings.agent = 'CodeActAgent'
|
||||
|
||||
# Use the model corresponding to the current user settings version
|
||||
settings.llm_model = get_default_litellm_model()
|
||||
settings.llm_api_key = SecretStr(key)
|
||||
settings.llm_base_url = LITE_LLM_API_URL
|
||||
return settings
|
||||
|
||||
@classmethod
|
||||
@@ -438,37 +368,8 @@ class SaasSettingsStore(SettingsStore):
|
||||
def _should_encrypt(self, key: str) -> bool:
|
||||
return key in ('llm_api_key', 'llm_api_key_for_byor', 'search_api_key')
|
||||
|
||||
def _is_openhands_provider(self, item: Settings) -> bool:
|
||||
"""Check if the settings use the OpenHands provider."""
|
||||
return bool(item.llm_model and item.llm_model.startswith('openhands/'))
|
||||
|
||||
async def _ensure_openhands_api_key(self, item: Settings) -> None:
|
||||
"""Generate and set the OpenHands API key for the given settings.
|
||||
|
||||
First checks if an existing key with the OpenHands alias exists,
|
||||
and reuses it if found. Otherwise, generates a new key.
|
||||
"""
|
||||
# Generate new key if none exists
|
||||
generated_key = await self._generate_openhands_key()
|
||||
if generated_key:
|
||||
item.llm_api_key = SecretStr(generated_key)
|
||||
logger.info(
|
||||
'saas_settings_store:store:generated_openhands_key',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
'saas_settings_store:store:failed_to_generate_openhands_key',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
|
||||
async def _create_user_in_lite_llm(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
email: str | None,
|
||||
max_budget: int,
|
||||
spend: int,
|
||||
llm_model: str,
|
||||
self, client: httpx.AsyncClient, email: str | None, max_budget: int, spend: int
|
||||
):
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/new',
|
||||
@@ -483,61 +384,9 @@ class SaasSettingsStore(SettingsStore):
|
||||
'send_invite_email': False,
|
||||
'metadata': {
|
||||
'version': CURRENT_USER_SETTINGS_VERSION,
|
||||
'model': llm_model,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
'key_alias': f'OpenHands Cloud - user {self.user_id}',
|
||||
},
|
||||
)
|
||||
return response
|
||||
|
||||
async def _generate_openhands_key(self) -> str | None:
|
||||
"""Generate a new OpenHands provider key for a user."""
|
||||
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
||||
logger.warning(
|
||||
'saas_settings_store:_generate_openhands_key:litellm_config_not_found',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
) as client:
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/key/generate',
|
||||
json={
|
||||
'user_id': self.user_id,
|
||||
'metadata': {'type': 'openhands'},
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
key = response_json.get('key')
|
||||
|
||||
if key:
|
||||
logger.info(
|
||||
'saas_settings_store:_generate_openhands_key:success',
|
||||
extra={
|
||||
'user_id': self.user_id,
|
||||
'key_length': len(key) if key else 0,
|
||||
'key_prefix': (
|
||||
key[:10] + '...' if key and len(key) > 10 else key
|
||||
),
|
||||
},
|
||||
)
|
||||
return key
|
||||
else:
|
||||
logger.error(
|
||||
'saas_settings_store:_generate_openhands_key:no_key_in_response',
|
||||
extra={'user_id': self.user_id, 'response_json': response_json},
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'saas_settings_store:_generate_openhands_key:error',
|
||||
extra={'user_id': self.user_id, 'error': str(e)},
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -38,4 +38,3 @@ class UserSettings(Base): # type: ignore
|
||||
email_verified = Column(Boolean, nullable=True)
|
||||
git_user_name = Column(String, nullable=True)
|
||||
git_user_email = Column(String, nullable=True)
|
||||
v1_enabled = Column(Boolean, nullable=True)
|
||||
|
||||
@@ -4,8 +4,6 @@ from uuid import uuid4
|
||||
|
||||
from integrations.types import GitLabResourceType
|
||||
from integrations.utils import GITLAB_WEBHOOK_URL
|
||||
from sqlalchemy import text
|
||||
from storage.database import a_session_maker
|
||||
from storage.gitlab_webhook import GitlabWebhook, WebhookStatus
|
||||
from storage.gitlab_webhook_store import GitlabWebhookStore
|
||||
|
||||
@@ -260,25 +258,6 @@ class VerifyWebhookStatus:
|
||||
|
||||
from integrations.gitlab.gitlab_service import SaaSGitLabService
|
||||
|
||||
# Check if the table exists before proceeding
|
||||
# This handles cases where the CronJob runs before database migrations complete
|
||||
async with a_session_maker() as session:
|
||||
query = text("""
|
||||
SELECT EXISTS (
|
||||
SELECT FROM information_schema.tables
|
||||
WHERE table_name = 'gitlab_webhook'
|
||||
)
|
||||
""")
|
||||
result = await session.execute(query)
|
||||
table_exists = result.scalar() or False
|
||||
|
||||
if not table_exists:
|
||||
logger.info(
|
||||
'gitlab_webhook table does not exist yet, '
|
||||
'waiting for database migrations to complete'
|
||||
)
|
||||
return
|
||||
|
||||
# Get an instance of the webhook store
|
||||
webhook_store = await GitlabWebhookStore.get_instance()
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ from storage.base import Base
|
||||
# Anything not loaded here may not have a table created for it.
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.conversation_work import ConversationWork
|
||||
from storage.device_code import DeviceCode # noqa: F401
|
||||
from storage.feedback import Feedback
|
||||
from storage.github_app_installation import GithubAppInstallation
|
||||
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
|
||||
|
||||
@@ -92,8 +92,11 @@ def test_unknown_variant_returns_original_agent_without_changes(monkeypatch):
|
||||
assert getattr(result, 'condenser', None) is None
|
||||
|
||||
|
||||
@patch('experiments.experiment_manager.handle_condenser_max_step_experiment__v1')
|
||||
@patch('experiments.experiment_manager.ENABLE_EXPERIMENT_MANAGER', False)
|
||||
def test_run_agent_variant_tests_v1_noop_when_manager_disabled():
|
||||
def test_run_agent_variant_tests_v1_noop_when_manager_disabled(
|
||||
mock_handle_condenser,
|
||||
):
|
||||
"""If ENABLE_EXPERIMENT_MANAGER is False, the method returns the exact same agent and does not call the handler."""
|
||||
agent = make_agent()
|
||||
conv_id = uuid4()
|
||||
@@ -106,6 +109,8 @@ def test_run_agent_variant_tests_v1_noop_when_manager_disabled():
|
||||
|
||||
# Same object returned (no copy)
|
||||
assert result is agent
|
||||
# Handler should not have been called
|
||||
mock_handle_condenser.assert_not_called()
|
||||
|
||||
|
||||
@patch('experiments.experiment_manager.ENABLE_EXPERIMENT_MANAGER', True)
|
||||
@@ -126,3 +131,7 @@ def test_run_agent_variant_tests_v1_calls_handler_and_sets_system_prompt(monkeyp
|
||||
# Should be a different instance than the original (copied after handler runs)
|
||||
assert result is not agent
|
||||
assert result.system_prompt_filename == 'system_prompt_long_horizon.j2'
|
||||
|
||||
# The condenser returned by the handler must be preserved after the system-prompt override copy
|
||||
assert isinstance(result.condenser, LLMSummarizingCondenser)
|
||||
assert result.condenser.max_size == 80
|
||||
|
||||
@@ -1,133 +0,0 @@
|
||||
"""Test for ResolverUserContext get_secrets conversion logic.
|
||||
|
||||
This test focuses on testing the actual ResolverUserContext implementation.
|
||||
"""
|
||||
|
||||
from types import MappingProxyType
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from enterprise.integrations.resolver_context import ResolverUserContext
|
||||
|
||||
# Import the real classes we want to test
|
||||
from openhands.integrations.provider import CustomSecret
|
||||
|
||||
# Import the SDK types we need for testing
|
||||
from openhands.sdk.secret import SecretSource, StaticSecret
|
||||
from openhands.storage.data_models.secrets import Secrets
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_saas_user_auth():
|
||||
"""Mock SaasUserAuth for testing."""
|
||||
return AsyncMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def resolver_context(mock_saas_user_auth):
|
||||
"""Create a ResolverUserContext instance for testing."""
|
||||
return ResolverUserContext(saas_user_auth=mock_saas_user_auth)
|
||||
|
||||
|
||||
def create_custom_secret(value: str, description: str = 'Test secret') -> CustomSecret:
|
||||
"""Helper to create CustomSecret instances."""
|
||||
return CustomSecret(secret=SecretStr(value), description=description)
|
||||
|
||||
|
||||
def create_secrets(custom_secrets_dict: dict[str, CustomSecret]) -> Secrets:
|
||||
"""Helper to create Secrets instances."""
|
||||
return Secrets(custom_secrets=MappingProxyType(custom_secrets_dict))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_secrets_converts_custom_to_static(
|
||||
resolver_context, mock_saas_user_auth
|
||||
):
|
||||
"""Test that get_secrets correctly converts CustomSecret objects to StaticSecret objects."""
|
||||
# Arrange
|
||||
secrets = create_secrets(
|
||||
{
|
||||
'TEST_SECRET_1': create_custom_secret('secret_value_1'),
|
||||
'TEST_SECRET_2': create_custom_secret('secret_value_2'),
|
||||
}
|
||||
)
|
||||
mock_saas_user_auth.get_secrets.return_value = secrets
|
||||
|
||||
# Act
|
||||
result = await resolver_context.get_secrets()
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert all(isinstance(secret, StaticSecret) for secret in result.values())
|
||||
assert result['TEST_SECRET_1'].value.get_secret_value() == 'secret_value_1'
|
||||
assert result['TEST_SECRET_2'].value.get_secret_value() == 'secret_value_2'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_secrets_with_special_characters(
|
||||
resolver_context, mock_saas_user_auth
|
||||
):
|
||||
"""Test that secret values with special characters are preserved during conversion."""
|
||||
# Arrange
|
||||
special_value = 'very_secret_password_123!@#$%^&*()'
|
||||
secrets = create_secrets({'SPECIAL_SECRET': create_custom_secret(special_value)})
|
||||
mock_saas_user_auth.get_secrets.return_value = secrets
|
||||
|
||||
# Act
|
||||
result = await resolver_context.get_secrets()
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert isinstance(result['SPECIAL_SECRET'], StaticSecret)
|
||||
assert result['SPECIAL_SECRET'].value.get_secret_value() == special_value
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'secrets_input,expected_result',
|
||||
[
|
||||
(None, {}), # No secrets available
|
||||
(create_secrets({}), {}), # Empty custom secrets
|
||||
],
|
||||
)
|
||||
async def test_get_secrets_empty_cases(
|
||||
resolver_context, mock_saas_user_auth, secrets_input, expected_result
|
||||
):
|
||||
"""Test that get_secrets handles empty cases correctly."""
|
||||
# Arrange
|
||||
mock_saas_user_auth.get_secrets.return_value = secrets_input
|
||||
|
||||
# Act
|
||||
result = await resolver_context.get_secrets()
|
||||
|
||||
# Assert
|
||||
assert result == expected_result
|
||||
|
||||
|
||||
def test_static_secret_is_valid_secret_source():
|
||||
"""Test that StaticSecret is a valid SecretSource for SDK validation."""
|
||||
# Arrange & Act
|
||||
static_secret = StaticSecret(value='test_secret_123')
|
||||
|
||||
# Assert
|
||||
assert isinstance(static_secret, StaticSecret)
|
||||
assert isinstance(static_secret, SecretSource)
|
||||
assert static_secret.value.get_secret_value() == 'test_secret_123'
|
||||
|
||||
|
||||
def test_custom_to_static_conversion():
|
||||
"""Test the complete conversion flow from CustomSecret to StaticSecret."""
|
||||
# Arrange
|
||||
secret_value = 'conversion_test_secret'
|
||||
custom_secret = create_custom_secret(secret_value, 'Conversion test')
|
||||
|
||||
# Act - simulate the conversion logic from the actual method
|
||||
extracted_value = custom_secret.secret.get_secret_value()
|
||||
static_secret = StaticSecret(value=extracted_value)
|
||||
|
||||
# Assert
|
||||
assert isinstance(static_secret, StaticSecret)
|
||||
assert isinstance(static_secret, SecretSource)
|
||||
assert static_secret.value.get_secret_value() == secret_value
|
||||
@@ -1,12 +1,7 @@
|
||||
"""Tests for enterprise integrations utils module."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from integrations.utils import (
|
||||
append_conversation_footer,
|
||||
get_summary_for_agent_state,
|
||||
)
|
||||
from integrations.utils import get_summary_for_agent_state
|
||||
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
@@ -162,138 +157,3 @@ class TestGetSummaryForAgentState:
|
||||
assert 'try again later' in result.lower()
|
||||
# RATE_LIMITED doesn't include conversation link in response
|
||||
assert self.conversation_link not in result
|
||||
|
||||
|
||||
class TestAppendConversationFooter:
|
||||
"""Test cases for append_conversation_footer function."""
|
||||
|
||||
@patch(
|
||||
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
|
||||
)
|
||||
def test_appends_footer_with_markdown_link(self):
|
||||
"""Test that footer is appended with correct markdown link format."""
|
||||
# Arrange
|
||||
message = 'This is a test message'
|
||||
conversation_id = 'test-conv-123'
|
||||
|
||||
# Act
|
||||
result = append_conversation_footer(message, conversation_id)
|
||||
|
||||
# Assert
|
||||
assert result.startswith(message)
|
||||
assert (
|
||||
'[View full conversation](https://example.com/conversations/test-conv-123)'
|
||||
in result
|
||||
)
|
||||
assert result.endswith(
|
||||
'[View full conversation](https://example.com/conversations/test-conv-123)'
|
||||
)
|
||||
|
||||
@patch(
|
||||
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
|
||||
)
|
||||
def test_footer_does_not_contain_html_tags(self):
|
||||
"""Test that footer does not contain HTML tags like <sub>."""
|
||||
# Arrange
|
||||
message = 'Test message'
|
||||
conversation_id = 'test-conv-456'
|
||||
|
||||
# Act
|
||||
result = append_conversation_footer(message, conversation_id)
|
||||
|
||||
# Assert
|
||||
assert '<sub>' not in result
|
||||
assert '</sub>' not in result
|
||||
|
||||
@patch(
|
||||
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
|
||||
)
|
||||
def test_footer_format_with_newlines(self):
|
||||
"""Test that footer is properly separated with newlines."""
|
||||
# Arrange
|
||||
message = 'Original message content'
|
||||
conversation_id = 'test-conv-789'
|
||||
|
||||
# Act
|
||||
result = append_conversation_footer(message, conversation_id)
|
||||
|
||||
# Assert
|
||||
assert (
|
||||
result
|
||||
== 'Original message content\n\n[View full conversation](https://example.com/conversations/test-conv-789)'
|
||||
)
|
||||
|
||||
@patch(
|
||||
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
|
||||
)
|
||||
def test_empty_message_still_appends_footer(self):
|
||||
"""Test that footer is appended even when message is empty."""
|
||||
# Arrange
|
||||
message = ''
|
||||
conversation_id = 'empty-msg-conv'
|
||||
|
||||
# Act
|
||||
result = append_conversation_footer(message, conversation_id)
|
||||
|
||||
# Assert
|
||||
assert result.startswith('\n\n')
|
||||
assert (
|
||||
'[View full conversation](https://example.com/conversations/empty-msg-conv)'
|
||||
in result
|
||||
)
|
||||
|
||||
@patch(
|
||||
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
|
||||
)
|
||||
def test_conversation_id_with_special_characters(self):
|
||||
"""Test that footer handles conversation IDs with special characters."""
|
||||
# Arrange
|
||||
message = 'Test message'
|
||||
conversation_id = 'conv-123_abc-456'
|
||||
|
||||
# Act
|
||||
result = append_conversation_footer(message, conversation_id)
|
||||
|
||||
# Assert
|
||||
expected_url = 'https://example.com/conversations/conv-123_abc-456'
|
||||
assert expected_url in result
|
||||
assert '[View full conversation]' in result
|
||||
|
||||
@patch(
|
||||
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
|
||||
)
|
||||
def test_multiline_message_preserves_content(self):
|
||||
"""Test that multiline messages are preserved correctly."""
|
||||
# Arrange
|
||||
message = 'Line 1\nLine 2\nLine 3'
|
||||
conversation_id = 'multiline-conv'
|
||||
|
||||
# Act
|
||||
result = append_conversation_footer(message, conversation_id)
|
||||
|
||||
# Assert
|
||||
assert result.startswith('Line 1\nLine 2\nLine 3')
|
||||
assert '\n\n[View full conversation]' in result
|
||||
assert message in result
|
||||
|
||||
@patch(
|
||||
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
|
||||
)
|
||||
def test_footer_contains_only_markdown_syntax(self):
|
||||
"""Test that footer uses only markdown syntax, not HTML."""
|
||||
# Arrange
|
||||
message = 'Test message'
|
||||
conversation_id = 'markdown-test'
|
||||
|
||||
# Act
|
||||
result = append_conversation_footer(message, conversation_id)
|
||||
|
||||
# Assert
|
||||
footer_part = result[len(message) :]
|
||||
# Should only contain markdown link syntax: [text](url)
|
||||
assert footer_part.startswith('\n\n[')
|
||||
assert '](' in footer_part
|
||||
assert footer_part.endswith(')')
|
||||
# Should not contain any HTML tags (specifically <sub> tags that were removed)
|
||||
assert '<sub>' not in footer_part
|
||||
assert '</sub>' not in footer_part
|
||||
|
||||
@@ -1,361 +0,0 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from pydantic import SecretStr
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.routes.email import (
|
||||
ResendEmailVerificationRequest,
|
||||
resend_email_verification,
|
||||
verified_email,
|
||||
verify_email,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request():
|
||||
"""Create a mock request object."""
|
||||
request = MagicMock(spec=Request)
|
||||
request.url = MagicMock()
|
||||
request.url.hostname = 'localhost'
|
||||
request.url.netloc = 'localhost:8000'
|
||||
request.url.path = '/api/email/verified'
|
||||
request.base_url = 'http://localhost:8000/'
|
||||
request.headers = {}
|
||||
request.cookies = {}
|
||||
request.query_params = MagicMock()
|
||||
return request
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_auth():
|
||||
"""Create a mock SaasUserAuth object."""
|
||||
auth = MagicMock(spec=SaasUserAuth)
|
||||
auth.access_token = SecretStr('test_access_token')
|
||||
auth.refresh_token = SecretStr('test_refresh_token')
|
||||
auth.email = 'test@example.com'
|
||||
auth.email_verified = False
|
||||
auth.accepted_tos = True
|
||||
auth.refresh = AsyncMock()
|
||||
return auth
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_email_default_behavior(mock_request):
|
||||
"""Test verify_email with default is_auth_flow=False."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
mock_keycloak_admin = AsyncMock()
|
||||
mock_keycloak_admin.a_send_verify_email = AsyncMock()
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
|
||||
):
|
||||
await verify_email(request=mock_request, user_id=user_id)
|
||||
|
||||
# Assert
|
||||
mock_keycloak_admin.a_send_verify_email.assert_called_once()
|
||||
call_args = mock_keycloak_admin.a_send_verify_email.call_args
|
||||
assert call_args.kwargs['user_id'] == user_id
|
||||
assert (
|
||||
call_args.kwargs['redirect_uri'] == 'http://localhost:8000/api/email/verified'
|
||||
)
|
||||
assert 'client_id' in call_args.kwargs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_email_with_auth_flow(mock_request):
|
||||
"""Test verify_email with is_auth_flow=True."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
mock_keycloak_admin = AsyncMock()
|
||||
mock_keycloak_admin.a_send_verify_email = AsyncMock()
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
|
||||
):
|
||||
await verify_email(request=mock_request, user_id=user_id, is_auth_flow=True)
|
||||
|
||||
# Assert
|
||||
mock_keycloak_admin.a_send_verify_email.assert_called_once()
|
||||
call_args = mock_keycloak_admin.a_send_verify_email.call_args
|
||||
assert call_args.kwargs['user_id'] == user_id
|
||||
assert (
|
||||
call_args.kwargs['redirect_uri'] == 'http://localhost:8000?email_verified=true'
|
||||
)
|
||||
assert 'client_id' in call_args.kwargs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_email_https_scheme(mock_request):
|
||||
"""Test verify_email uses https scheme for non-localhost hosts."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
mock_request.url.hostname = 'example.com'
|
||||
mock_request.url.netloc = 'example.com'
|
||||
mock_keycloak_admin = AsyncMock()
|
||||
mock_keycloak_admin.a_send_verify_email = AsyncMock()
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
|
||||
):
|
||||
await verify_email(request=mock_request, user_id=user_id, is_auth_flow=True)
|
||||
|
||||
# Assert
|
||||
call_args = mock_keycloak_admin.a_send_verify_email.call_args
|
||||
assert call_args.kwargs['redirect_uri'].startswith('https://')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verified_email_default_redirect(mock_request, mock_user_auth):
|
||||
"""Test verified_email redirects to /settings/user by default."""
|
||||
# Arrange
|
||||
mock_request.query_params.get.return_value = None
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch('server.routes.email.get_user_auth', return_value=mock_user_auth),
|
||||
patch('server.routes.email.set_response_cookie') as mock_set_cookie,
|
||||
):
|
||||
result = await verified_email(mock_request)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
assert result.headers['location'] == 'http://localhost:8000/settings/user'
|
||||
mock_user_auth.refresh.assert_called_once()
|
||||
mock_set_cookie.assert_called_once()
|
||||
assert mock_user_auth.email_verified is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verified_email_https_scheme(mock_request, mock_user_auth):
|
||||
"""Test verified_email uses https scheme for non-localhost hosts."""
|
||||
# Arrange
|
||||
mock_request.url.hostname = 'example.com'
|
||||
mock_request.url.netloc = 'example.com'
|
||||
mock_request.query_params.get.return_value = None
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch('server.routes.email.get_user_auth', return_value=mock_user_auth),
|
||||
patch('server.routes.email.set_response_cookie') as mock_set_cookie,
|
||||
):
|
||||
result = await verified_email(mock_request)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.headers['location'].startswith('https://')
|
||||
mock_set_cookie.assert_called_once()
|
||||
# Verify secure flag is True for https
|
||||
call_kwargs = mock_set_cookie.call_args.kwargs
|
||||
assert call_kwargs['secure'] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resend_email_verification_with_user_id_from_body_succeeds(mock_request):
|
||||
"""Test resend_email_verification succeeds when user_id is provided in body."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
body = ResendEmailVerificationRequest(user_id=user_id, is_auth_flow=False)
|
||||
mock_keycloak_admin = AsyncMock()
|
||||
mock_keycloak_admin.a_send_verify_email = AsyncMock()
|
||||
|
||||
with (
|
||||
patch('server.routes.email.check_rate_limit_by_user_id') as mock_rate_limit,
|
||||
patch(
|
||||
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
|
||||
),
|
||||
patch('server.routes.email.logger') as mock_logger,
|
||||
):
|
||||
mock_rate_limit.return_value = None # Rate limit check passes
|
||||
|
||||
# Act
|
||||
result = await resend_email_verification(request=mock_request, body=body)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == status.HTTP_200_OK
|
||||
assert 'message' in result.body.decode()
|
||||
mock_rate_limit.assert_called_once_with(
|
||||
request=mock_request,
|
||||
key_prefix='email_resend',
|
||||
user_id=user_id,
|
||||
user_rate_limit_seconds=30,
|
||||
ip_rate_limit_seconds=60,
|
||||
)
|
||||
mock_keycloak_admin.a_send_verify_email.assert_called_once()
|
||||
# Logger is called multiple times (verify_email and resend_email_verification)
|
||||
# Check that the resend message was logged
|
||||
assert any(
|
||||
'Resending verification email for' in str(call)
|
||||
for call in mock_logger.info.call_args_list
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resend_email_verification_with_user_id_from_auth_succeeds(mock_request):
|
||||
"""Test resend_email_verification succeeds when user_id comes from authentication."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
mock_keycloak_admin = AsyncMock()
|
||||
mock_keycloak_admin.a_send_verify_email = AsyncMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.email.get_user_id', return_value=user_id
|
||||
) as mock_get_user_id,
|
||||
patch('server.routes.email.check_rate_limit_by_user_id') as mock_rate_limit,
|
||||
patch(
|
||||
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
|
||||
),
|
||||
):
|
||||
mock_rate_limit.return_value = None # Rate limit check passes
|
||||
|
||||
# Act
|
||||
result = await resend_email_verification(request=mock_request, body=None)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == status.HTTP_200_OK
|
||||
mock_get_user_id.assert_called_once_with(mock_request)
|
||||
mock_rate_limit.assert_called_once_with(
|
||||
request=mock_request,
|
||||
key_prefix='email_resend',
|
||||
user_id=user_id,
|
||||
user_rate_limit_seconds=30,
|
||||
ip_rate_limit_seconds=60,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resend_email_verification_without_user_id_returns_400(mock_request):
|
||||
"""Test resend_email_verification returns 400 when user_id is not available."""
|
||||
# Arrange
|
||||
with patch(
|
||||
'server.routes.email.get_user_id', side_effect=Exception('Not authenticated')
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await resend_email_verification(request=mock_request, body=None)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert 'user_id is required' in exc_info.value.detail
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resend_email_verification_rate_limit_exceeded_returns_429(mock_request):
|
||||
"""Test resend_email_verification returns 429 when rate limit is exceeded."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
body = ResendEmailVerificationRequest(user_id=user_id)
|
||||
|
||||
with (
|
||||
patch('server.routes.email.check_rate_limit_by_user_id') as mock_rate_limit,
|
||||
):
|
||||
mock_rate_limit.side_effect = HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail='Too many requests. Please wait 2 minutes before trying again.',
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await resend_email_verification(request=mock_request, body=body)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
assert 'Too many requests' in exc_info.value.detail
|
||||
mock_rate_limit.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resend_email_verification_with_is_auth_flow_true(mock_request):
|
||||
"""Test resend_email_verification passes is_auth_flow to verify_email."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
body = ResendEmailVerificationRequest(user_id=user_id, is_auth_flow=True)
|
||||
mock_keycloak_admin = AsyncMock()
|
||||
mock_keycloak_admin.a_send_verify_email = AsyncMock()
|
||||
|
||||
with (
|
||||
patch('server.routes.email.check_rate_limit_by_user_id') as mock_rate_limit,
|
||||
patch(
|
||||
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
|
||||
),
|
||||
):
|
||||
mock_rate_limit.return_value = None
|
||||
|
||||
# Act
|
||||
await resend_email_verification(request=mock_request, body=body)
|
||||
|
||||
# Assert
|
||||
mock_keycloak_admin.a_send_verify_email.assert_called_once()
|
||||
call_args = mock_keycloak_admin.a_send_verify_email.call_args
|
||||
# Verify that verify_email was called with is_auth_flow=True
|
||||
# We check this indirectly by verifying the redirect_uri
|
||||
assert 'email_verified=true' in call_args.kwargs['redirect_uri']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resend_email_verification_with_is_auth_flow_false(mock_request):
|
||||
"""Test resend_email_verification uses default is_auth_flow=False when not specified."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
body = ResendEmailVerificationRequest(user_id=user_id, is_auth_flow=False)
|
||||
mock_keycloak_admin = AsyncMock()
|
||||
mock_keycloak_admin.a_send_verify_email = AsyncMock()
|
||||
|
||||
with (
|
||||
patch('server.routes.email.check_rate_limit_by_user_id') as mock_rate_limit,
|
||||
patch(
|
||||
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
|
||||
),
|
||||
):
|
||||
mock_rate_limit.return_value = None
|
||||
|
||||
# Act
|
||||
await resend_email_verification(request=mock_request, body=body)
|
||||
|
||||
# Assert
|
||||
mock_keycloak_admin.a_send_verify_email.assert_called_once()
|
||||
call_args = mock_keycloak_admin.a_send_verify_email.call_args
|
||||
# Verify that verify_email was called with is_auth_flow=False
|
||||
assert '/api/email/verified' in call_args.kwargs['redirect_uri']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resend_email_verification_body_none_uses_auth(mock_request):
|
||||
"""Test resend_email_verification uses auth when body is None."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
mock_keycloak_admin = AsyncMock()
|
||||
mock_keycloak_admin.a_send_verify_email = AsyncMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.email.get_user_id', return_value=user_id
|
||||
) as mock_get_user_id,
|
||||
patch('server.routes.email.check_rate_limit_by_user_id') as mock_rate_limit,
|
||||
patch(
|
||||
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
|
||||
),
|
||||
):
|
||||
mock_rate_limit.return_value = None
|
||||
|
||||
# Act
|
||||
result = await resend_email_verification(request=mock_request, body=None)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == status.HTTP_200_OK
|
||||
mock_get_user_id.assert_called_once()
|
||||
mock_rate_limit.assert_called_once_with(
|
||||
request=mock_request,
|
||||
key_prefix='email_resend',
|
||||
user_id=user_id,
|
||||
user_rate_limit_seconds=30,
|
||||
ip_rate_limit_seconds=60,
|
||||
)
|
||||
@@ -1,610 +0,0 @@
|
||||
"""Unit tests for OAuth2 Device Flow endpoints."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from server.routes.oauth_device import (
|
||||
device_authorization,
|
||||
device_token,
|
||||
device_verification_authenticated,
|
||||
)
|
||||
from storage.device_code import DeviceCode
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_device_code_store():
|
||||
"""Mock device code store."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_api_key_store():
|
||||
"""Mock API key store."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_token_manager():
|
||||
"""Mock token manager."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request():
|
||||
"""Mock FastAPI request."""
|
||||
request = MagicMock(spec=Request)
|
||||
request.base_url = 'https://test.example.com/'
|
||||
return request
|
||||
|
||||
|
||||
class TestDeviceAuthorization:
|
||||
"""Test device authorization endpoint."""
|
||||
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_device_authorization_success(self, mock_store, mock_request):
|
||||
"""Test successful device authorization."""
|
||||
mock_device = DeviceCode(
|
||||
device_code='test-device-code-123',
|
||||
user_code='ABC12345',
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
current_interval=5, # Default interval
|
||||
)
|
||||
mock_store.create_device_code.return_value = mock_device
|
||||
|
||||
result = await device_authorization(mock_request)
|
||||
|
||||
assert result.device_code == 'test-device-code-123'
|
||||
assert result.user_code == 'ABC12345'
|
||||
assert result.expires_in == 600
|
||||
assert result.interval == 5 # Should match device's current_interval
|
||||
assert 'verify' in result.verification_uri
|
||||
assert 'ABC12345' in result.verification_uri_complete
|
||||
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_device_authorization_with_increased_interval(
|
||||
self, mock_store, mock_request
|
||||
):
|
||||
"""Test device authorization returns increased interval from rate limiting."""
|
||||
mock_device = DeviceCode(
|
||||
device_code='test-device-code-456',
|
||||
user_code='XYZ98765',
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
current_interval=15, # Increased interval from previous rate limiting
|
||||
)
|
||||
mock_store.create_device_code.return_value = mock_device
|
||||
|
||||
result = await device_authorization(mock_request)
|
||||
|
||||
assert result.device_code == 'test-device-code-456'
|
||||
assert result.user_code == 'XYZ98765'
|
||||
assert result.expires_in == 600
|
||||
assert result.interval == 15 # Should match device's increased current_interval
|
||||
assert 'verify' in result.verification_uri
|
||||
assert 'XYZ98765' in result.verification_uri_complete
|
||||
|
||||
|
||||
class TestDeviceToken:
|
||||
"""Test device token endpoint."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'device_exists,status,expected_error',
|
||||
[
|
||||
(False, None, 'invalid_grant'),
|
||||
(True, 'expired', 'expired_token'),
|
||||
(True, 'denied', 'access_denied'),
|
||||
(True, 'pending', 'authorization_pending'),
|
||||
],
|
||||
)
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_device_token_error_cases(
|
||||
self, mock_store, device_exists, status, expected_error
|
||||
):
|
||||
"""Test various error cases for device token endpoint."""
|
||||
device_code = 'test-device-code'
|
||||
|
||||
if device_exists:
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_expired.return_value = status == 'expired'
|
||||
mock_device.status = status
|
||||
# Mock rate limiting - return False (not too fast) and default interval
|
||||
mock_device.check_rate_limit.return_value = (False, 5)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
else:
|
||||
mock_store.get_by_device_code.return_value = None
|
||||
|
||||
result = await device_token(device_code=device_code)
|
||||
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 400
|
||||
# Check error in response content
|
||||
content = result.body.decode()
|
||||
assert expected_error in content
|
||||
|
||||
@patch('server.routes.oauth_device.ApiKeyStore')
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_device_token_success(self, mock_store, mock_api_key_class):
|
||||
"""Test successful device token retrieval."""
|
||||
device_code = 'test-device-code'
|
||||
|
||||
# Mock authorized device
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_expired.return_value = False
|
||||
mock_device.status = 'authorized'
|
||||
mock_device.keycloak_user_id = 'user-123'
|
||||
mock_device.user_code = (
|
||||
'ABC12345' # Add user_code for device-specific API key lookup
|
||||
)
|
||||
# Mock rate limiting - return False (not too fast) and default interval
|
||||
mock_device.check_rate_limit.return_value = (False, 5)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
|
||||
# Mock API key retrieval
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.retrieve_api_key_by_name.return_value = 'test-api-key'
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
result = await device_token(device_code=device_code)
|
||||
|
||||
# Check that result is a DeviceTokenResponse
|
||||
assert result.access_token == 'test-api-key'
|
||||
assert result.token_type == 'Bearer'
|
||||
|
||||
# Verify that the correct device-specific API key name was used
|
||||
mock_api_key_store.retrieve_api_key_by_name.assert_called_once_with(
|
||||
'user-123', 'Device Link Access Key (ABC12345)'
|
||||
)
|
||||
|
||||
|
||||
class TestDeviceVerificationAuthenticated:
|
||||
"""Test device verification authenticated endpoint."""
|
||||
|
||||
async def test_verification_unauthenticated_user(self):
|
||||
"""Test verification with unauthenticated user."""
|
||||
with pytest.raises(HTTPException):
|
||||
await device_verification_authenticated(user_code='ABC12345', user_id=None)
|
||||
|
||||
@patch('server.routes.oauth_device.ApiKeyStore')
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_verification_invalid_device_code(
|
||||
self, mock_store, mock_api_key_class
|
||||
):
|
||||
"""Test verification with invalid device code."""
|
||||
mock_store.get_by_user_code.return_value = None
|
||||
|
||||
with pytest.raises(HTTPException):
|
||||
await device_verification_authenticated(
|
||||
user_code='INVALID', user_id='user-123'
|
||||
)
|
||||
|
||||
@patch('server.routes.oauth_device.ApiKeyStore')
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_verification_already_processed(self, mock_store, mock_api_key_class):
|
||||
"""Test verification with already processed device code."""
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = False
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
|
||||
with pytest.raises(HTTPException):
|
||||
await device_verification_authenticated(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
|
||||
@patch('server.routes.oauth_device.ApiKeyStore')
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_verification_success(self, mock_store, mock_api_key_class):
|
||||
"""Test successful device verification."""
|
||||
# Mock device code
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.authorize_device_code.return_value = True
|
||||
|
||||
# Mock API key store
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
result = await device_verification_authenticated(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 200
|
||||
# Should NOT delete existing API keys (multiple devices allowed)
|
||||
mock_api_key_store.delete_api_key_by_name.assert_not_called()
|
||||
# Should create a new API key with device-specific name
|
||||
mock_api_key_store.create_api_key.assert_called_once()
|
||||
call_args = mock_api_key_store.create_api_key.call_args
|
||||
assert call_args[1]['name'] == 'Device Link Access Key (ABC12345)'
|
||||
mock_store.authorize_device_code.assert_called_once_with(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
|
||||
@patch('server.routes.oauth_device.ApiKeyStore')
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_multiple_device_authentication(self, mock_store, mock_api_key_class):
|
||||
"""Test that multiple devices can authenticate simultaneously."""
|
||||
# Mock API key store
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
# Simulate two different devices
|
||||
device1_code = 'ABC12345'
|
||||
device2_code = 'XYZ67890'
|
||||
user_id = 'user-123'
|
||||
|
||||
# Mock device codes
|
||||
mock_device1 = MagicMock()
|
||||
mock_device1.is_pending.return_value = True
|
||||
mock_device2 = MagicMock()
|
||||
mock_device2.is_pending.return_value = True
|
||||
|
||||
# Configure mock store to return appropriate device for each user_code
|
||||
def get_by_user_code_side_effect(user_code):
|
||||
if user_code == device1_code:
|
||||
return mock_device1
|
||||
elif user_code == device2_code:
|
||||
return mock_device2
|
||||
return None
|
||||
|
||||
mock_store.get_by_user_code.side_effect = get_by_user_code_side_effect
|
||||
mock_store.authorize_device_code.return_value = True
|
||||
|
||||
# Authenticate first device
|
||||
result1 = await device_verification_authenticated(
|
||||
user_code=device1_code, user_id=user_id
|
||||
)
|
||||
|
||||
# Authenticate second device
|
||||
result2 = await device_verification_authenticated(
|
||||
user_code=device2_code, user_id=user_id
|
||||
)
|
||||
|
||||
# Both should succeed
|
||||
assert isinstance(result1, JSONResponse)
|
||||
assert result1.status_code == 200
|
||||
assert isinstance(result2, JSONResponse)
|
||||
assert result2.status_code == 200
|
||||
|
||||
# Should create two separate API keys with different names
|
||||
assert mock_api_key_store.create_api_key.call_count == 2
|
||||
|
||||
# Check that each device got a unique API key name
|
||||
call_args_list = mock_api_key_store.create_api_key.call_args_list
|
||||
device1_name = call_args_list[0][1]['name']
|
||||
device2_name = call_args_list[1][1]['name']
|
||||
|
||||
assert device1_name == f'Device Link Access Key ({device1_code})'
|
||||
assert device2_name == f'Device Link Access Key ({device2_code})'
|
||||
assert device1_name != device2_name # Ensure they're different
|
||||
|
||||
# Should NOT delete any existing API keys
|
||||
mock_api_key_store.delete_api_key_by_name.assert_not_called()
|
||||
|
||||
|
||||
class TestDeviceTokenRateLimiting:
|
||||
"""Test rate limiting for device token polling (RFC 8628 section 3.5)."""
|
||||
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_first_poll_allowed(self, mock_store):
|
||||
"""Test that the first poll is always allowed."""
|
||||
# Create a device code with no previous poll time
|
||||
mock_device = DeviceCode(
|
||||
device_code='test_device_code',
|
||||
user_code='ABC123',
|
||||
status='pending',
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
last_poll_time=None, # First poll
|
||||
current_interval=5,
|
||||
)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
|
||||
# Should return authorization_pending, not slow_down
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 400
|
||||
content = result.body.decode()
|
||||
assert 'authorization_pending' in content
|
||||
assert 'slow_down' not in content
|
||||
|
||||
# Should update poll time without increasing interval
|
||||
mock_store.update_poll_time.assert_called_with(
|
||||
'test_device_code', increase_interval=False
|
||||
)
|
||||
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_normal_polling_allowed(self, mock_store):
|
||||
"""Test that normal polling (respecting interval) is allowed."""
|
||||
# Create a device code with last poll time 6 seconds ago (interval is 5)
|
||||
last_poll = datetime.now(UTC) - timedelta(seconds=6)
|
||||
mock_device = DeviceCode(
|
||||
device_code='test_device_code',
|
||||
user_code='ABC123',
|
||||
status='pending',
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
last_poll_time=last_poll,
|
||||
current_interval=5,
|
||||
)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
|
||||
# Should return authorization_pending, not slow_down
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 400
|
||||
content = result.body.decode()
|
||||
assert 'authorization_pending' in content
|
||||
assert 'slow_down' not in content
|
||||
|
||||
# Should update poll time without increasing interval
|
||||
mock_store.update_poll_time.assert_called_with(
|
||||
'test_device_code', increase_interval=False
|
||||
)
|
||||
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_fast_polling_returns_slow_down(self, mock_store):
|
||||
"""Test that polling too fast returns slow_down error."""
|
||||
# Create a device code with last poll time 2 seconds ago (interval is 5)
|
||||
last_poll = datetime.now(UTC) - timedelta(seconds=2)
|
||||
mock_device = DeviceCode(
|
||||
device_code='test_device_code',
|
||||
user_code='ABC123',
|
||||
status='pending',
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
last_poll_time=last_poll,
|
||||
current_interval=5,
|
||||
)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
|
||||
# Should return slow_down error
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 400
|
||||
content = result.body.decode()
|
||||
assert 'slow_down' in content
|
||||
assert 'interval' in content
|
||||
assert '10' in content # New interval should be 5 + 5 = 10
|
||||
|
||||
# Should update poll time and increase interval
|
||||
mock_store.update_poll_time.assert_called_with(
|
||||
'test_device_code', increase_interval=True
|
||||
)
|
||||
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_interval_increases_with_repeated_fast_polling(self, mock_store):
|
||||
"""Test that interval increases with repeated fast polling."""
|
||||
# Create a device code with higher current interval from previous slow_down
|
||||
last_poll = datetime.now(UTC) - timedelta(seconds=5) # 5 seconds ago
|
||||
mock_device = DeviceCode(
|
||||
device_code='test_device_code',
|
||||
user_code='ABC123',
|
||||
status='pending',
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
last_poll_time=last_poll,
|
||||
current_interval=15, # Already increased from previous slow_down
|
||||
)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
|
||||
# Should return slow_down error with increased interval
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 400
|
||||
content = result.body.decode()
|
||||
assert 'slow_down' in content
|
||||
assert '20' in content # New interval should be 15 + 5 = 20
|
||||
|
||||
# Should update poll time and increase interval
|
||||
mock_store.update_poll_time.assert_called_with(
|
||||
'test_device_code', increase_interval=True
|
||||
)
|
||||
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_interval_caps_at_maximum(self, mock_store):
|
||||
"""Test that interval is capped at maximum value."""
|
||||
# Create a device code with interval near maximum
|
||||
last_poll = datetime.now(UTC) - timedelta(seconds=30)
|
||||
mock_device = DeviceCode(
|
||||
device_code='test_device_code',
|
||||
user_code='ABC123',
|
||||
status='pending',
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
last_poll_time=last_poll,
|
||||
current_interval=58, # Near maximum of 60
|
||||
)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
|
||||
# Should return slow_down error with capped interval
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 400
|
||||
content = result.body.decode()
|
||||
assert 'slow_down' in content
|
||||
assert '60' in content # Should be capped at 60, not 63
|
||||
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_rate_limiting_with_authorized_device(self, mock_store):
|
||||
"""Test that rate limiting still applies to authorized devices."""
|
||||
# Create an authorized device code with recent poll
|
||||
last_poll = datetime.now(UTC) - timedelta(seconds=2)
|
||||
mock_device = DeviceCode(
|
||||
device_code='test_device_code',
|
||||
user_code='ABC123',
|
||||
status='authorized', # Device is authorized
|
||||
keycloak_user_id='user123',
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
last_poll_time=last_poll,
|
||||
current_interval=5,
|
||||
)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
|
||||
# Should still return slow_down error even for authorized device
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 400
|
||||
content = result.body.decode()
|
||||
assert 'slow_down' in content
|
||||
|
||||
# Should update poll time and increase interval
|
||||
mock_store.update_poll_time.assert_called_with(
|
||||
'test_device_code', increase_interval=True
|
||||
)
|
||||
|
||||
|
||||
class TestDeviceVerificationTransactionIntegrity:
|
||||
"""Test transaction integrity for device verification to prevent orphaned API keys."""
|
||||
|
||||
@patch('server.routes.oauth_device.ApiKeyStore')
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_authorization_failure_prevents_api_key_creation(
|
||||
self, mock_store, mock_api_key_class
|
||||
):
|
||||
"""Test that if device authorization fails, no API key is created."""
|
||||
# Mock device code
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.authorize_device_code.return_value = False # Authorization fails
|
||||
|
||||
# Mock API key store
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
# Should raise HTTPException due to authorization failure
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await device_verification_authenticated(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to authorize the device' in exc_info.value.detail
|
||||
|
||||
# API key should NOT be created since authorization failed
|
||||
mock_api_key_store.create_api_key.assert_not_called()
|
||||
mock_store.authorize_device_code.assert_called_once_with(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
|
||||
@patch('server.routes.oauth_device.ApiKeyStore')
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_api_key_creation_failure_reverts_authorization(
|
||||
self, mock_store, mock_api_key_class
|
||||
):
|
||||
"""Test that if API key creation fails after authorization, the authorization is reverted."""
|
||||
# Mock device code
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.authorize_device_code.return_value = True # Authorization succeeds
|
||||
mock_store.deny_device_code.return_value = True # Cleanup succeeds
|
||||
|
||||
# Mock API key store to fail on creation
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.create_api_key.side_effect = Exception('Database error')
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
# Should raise HTTPException due to API key creation failure
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await device_verification_authenticated(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to create API key for device access' in exc_info.value.detail
|
||||
|
||||
# Authorization should have been attempted first
|
||||
mock_store.authorize_device_code.assert_called_once_with(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
|
||||
# API key creation should have been attempted after authorization
|
||||
mock_api_key_store.create_api_key.assert_called_once()
|
||||
|
||||
# Authorization should be reverted due to API key creation failure
|
||||
mock_store.deny_device_code.assert_called_once_with('ABC12345')
|
||||
|
||||
@patch('server.routes.oauth_device.ApiKeyStore')
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_api_key_creation_failure_cleanup_failure_logged(
|
||||
self, mock_store, mock_api_key_class
|
||||
):
|
||||
"""Test that cleanup failure is logged but doesn't prevent the main error from being raised."""
|
||||
# Mock device code
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.authorize_device_code.return_value = True # Authorization succeeds
|
||||
mock_store.deny_device_code.side_effect = Exception(
|
||||
'Cleanup failed'
|
||||
) # Cleanup fails
|
||||
|
||||
# Mock API key store to fail on creation
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.create_api_key.side_effect = Exception('Database error')
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
# Should still raise HTTPException for the original API key creation failure
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await device_verification_authenticated(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to create API key for device access' in exc_info.value.detail
|
||||
|
||||
# Both operations should have been attempted
|
||||
mock_store.authorize_device_code.assert_called_once()
|
||||
mock_api_key_store.create_api_key.assert_called_once()
|
||||
mock_store.deny_device_code.assert_called_once_with('ABC12345')
|
||||
|
||||
@patch('server.routes.oauth_device.ApiKeyStore')
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_successful_flow_creates_api_key_after_authorization(
|
||||
self, mock_store, mock_api_key_class
|
||||
):
|
||||
"""Test that in the successful flow, API key is created only after authorization."""
|
||||
# Mock device code
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.authorize_device_code.return_value = True # Authorization succeeds
|
||||
|
||||
# Mock API key store
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
result = await device_verification_authenticated(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 200
|
||||
|
||||
# Verify the order: authorization first, then API key creation
|
||||
mock_store.authorize_device_code.assert_called_once_with(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
mock_api_key_store.create_api_key.assert_called_once()
|
||||
|
||||
# No cleanup should be needed in successful case
|
||||
mock_store.deny_device_code.assert_not_called()
|
||||
@@ -699,11 +699,12 @@ class TestProcessBatchOperationsBackground:
|
||||
# Should not raise exceptions
|
||||
await _process_batch_operations_background(batch_ops, 'test-api-key')
|
||||
|
||||
# Should log the error with exception type and message in the log message
|
||||
mock_logger.error.assert_called_once()
|
||||
call_args = mock_logger.error.call_args
|
||||
log_message = call_args[0][0]
|
||||
assert log_message.startswith('error_processing_batch_operation:')
|
||||
assert call_args[1]['extra']['path'] == 'invalid-path'
|
||||
assert call_args[1]['extra']['method'] == 'BatchMethod.POST'
|
||||
assert call_args[1]['exc_info'] is True
|
||||
# Should log the error
|
||||
mock_logger.error.assert_called_once_with(
|
||||
'error_processing_batch_operation',
|
||||
extra={
|
||||
'path': 'invalid-path',
|
||||
'method': 'BatchMethod.POST',
|
||||
'error': mock_logger.error.call_args[1]['extra']['error'],
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1,290 +0,0 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException, Request, status
|
||||
from server.utils.rate_limit_utils import (
|
||||
RATE_LIMIT_IP_SECONDS,
|
||||
RATE_LIMIT_USER_SECONDS,
|
||||
check_rate_limit_by_user_id,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request():
|
||||
"""Create a mock request object."""
|
||||
request = MagicMock(spec=Request)
|
||||
request.client = MagicMock()
|
||||
request.client.host = '192.168.1.1'
|
||||
return request
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis():
|
||||
"""Create a mock Redis client."""
|
||||
redis = AsyncMock()
|
||||
redis.set = AsyncMock(return_value=True) # First call succeeds (key doesn't exist)
|
||||
return redis
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_by_user_id_first_request_succeeds(mock_request, mock_redis):
|
||||
"""Test that first request with user_id succeeds and sets rate limit key."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
key_prefix = 'email_resend'
|
||||
|
||||
with (
|
||||
patch('server.utils.rate_limit_utils.sio') as mock_sio,
|
||||
patch('server.utils.rate_limit_utils.logger') as mock_logger,
|
||||
):
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
# Act
|
||||
await check_rate_limit_by_user_id(
|
||||
request=mock_request, key_prefix=key_prefix, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_redis.set.assert_called_once_with(
|
||||
f'{key_prefix}:{user_id}', 1, nx=True, ex=RATE_LIMIT_USER_SECONDS
|
||||
)
|
||||
mock_logger.warning.assert_not_called()
|
||||
mock_logger.info.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_by_user_id_second_request_within_window_fails(
|
||||
mock_request, mock_redis
|
||||
):
|
||||
"""Test that second request with same user_id within rate limit window fails."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
key_prefix = 'email_resend'
|
||||
mock_redis.set = AsyncMock(return_value=False) # Key already exists
|
||||
|
||||
with (
|
||||
patch('server.utils.rate_limit_utils.sio') as mock_sio,
|
||||
patch('server.utils.rate_limit_utils.logger') as mock_logger,
|
||||
):
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await check_rate_limit_by_user_id(
|
||||
request=mock_request, key_prefix=key_prefix, user_id=user_id
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
assert 'Too many requests' in exc_info.value.detail
|
||||
assert f'{RATE_LIMIT_USER_SECONDS // 60} minutes' in exc_info.value.detail
|
||||
mock_logger.info.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_by_ip_when_user_id_is_none(mock_request, mock_redis):
|
||||
"""Test that rate limiting falls back to IP address when user_id is None."""
|
||||
# Arrange
|
||||
key_prefix = 'email_resend'
|
||||
|
||||
with (
|
||||
patch('server.utils.rate_limit_utils.sio') as mock_sio,
|
||||
patch('server.utils.rate_limit_utils.logger') as mock_logger,
|
||||
):
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
# Act
|
||||
await check_rate_limit_by_user_id(
|
||||
request=mock_request, key_prefix=key_prefix, user_id=None
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_redis.set.assert_called_once_with(
|
||||
f'{key_prefix}:ip:{mock_request.client.host}',
|
||||
1,
|
||||
nx=True,
|
||||
ex=RATE_LIMIT_IP_SECONDS,
|
||||
)
|
||||
mock_logger.warning.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_by_ip_second_request_within_window_fails(
|
||||
mock_request, mock_redis
|
||||
):
|
||||
"""Test that second request from same IP within rate limit window fails."""
|
||||
# Arrange
|
||||
key_prefix = 'email_resend'
|
||||
mock_redis.set = AsyncMock(return_value=False) # Key already exists
|
||||
|
||||
with (
|
||||
patch('server.utils.rate_limit_utils.sio') as mock_sio,
|
||||
):
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await check_rate_limit_by_user_id(
|
||||
request=mock_request, key_prefix=key_prefix, user_id=None
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
assert f'{RATE_LIMIT_IP_SECONDS // 60} minutes' in exc_info.value.detail
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_redis_unavailable_fails_open(mock_request):
|
||||
"""Test that rate limiting fails open when Redis is unavailable."""
|
||||
# Arrange
|
||||
key_prefix = 'email_resend'
|
||||
user_id = 'test_user_id'
|
||||
|
||||
with (
|
||||
patch('server.utils.rate_limit_utils.sio') as mock_sio,
|
||||
patch('server.utils.rate_limit_utils.logger') as mock_logger,
|
||||
):
|
||||
mock_sio.manager.redis = None # Redis unavailable
|
||||
|
||||
# Act
|
||||
await check_rate_limit_by_user_id(
|
||||
request=mock_request, key_prefix=key_prefix, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
'Redis unavailable for rate limiting, allowing request'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_redis_exception_fails_open(mock_request, mock_redis):
|
||||
"""Test that rate limiting fails open when Redis raises an exception."""
|
||||
# Arrange
|
||||
key_prefix = 'email_resend'
|
||||
user_id = 'test_user_id'
|
||||
mock_redis.set = AsyncMock(side_effect=Exception('Redis connection error'))
|
||||
|
||||
with (
|
||||
patch('server.utils.rate_limit_utils.sio') as mock_sio,
|
||||
patch('server.utils.rate_limit_utils.logger') as mock_logger,
|
||||
):
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
# Act
|
||||
await check_rate_limit_by_user_id(
|
||||
request=mock_request, key_prefix=key_prefix, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_logger.warning.assert_called_once()
|
||||
assert 'Error checking rate limit' in str(mock_logger.warning.call_args[0][0])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_custom_key_prefix(mock_request, mock_redis):
|
||||
"""Test that different key prefixes create different rate limit keys."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
key_prefix = 'password_reset'
|
||||
|
||||
with patch('server.utils.rate_limit_utils.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
# Act
|
||||
await check_rate_limit_by_user_id(
|
||||
request=mock_request, key_prefix=key_prefix, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_redis.set.assert_called_once_with(
|
||||
f'{key_prefix}:{user_id}', 1, nx=True, ex=RATE_LIMIT_USER_SECONDS
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_custom_rate_limit_seconds(mock_request, mock_redis):
|
||||
"""Test that custom rate limit seconds are used correctly."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
key_prefix = 'email_resend'
|
||||
custom_user_seconds = 60
|
||||
custom_ip_seconds = 180
|
||||
|
||||
with patch('server.utils.rate_limit_utils.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
# Act
|
||||
await check_rate_limit_by_user_id(
|
||||
request=mock_request,
|
||||
key_prefix=key_prefix,
|
||||
user_id=user_id,
|
||||
user_rate_limit_seconds=custom_user_seconds,
|
||||
ip_rate_limit_seconds=custom_ip_seconds,
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_redis.set.assert_called_once_with(
|
||||
f'{key_prefix}:{user_id}', 1, nx=True, ex=custom_user_seconds
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_ip_with_unknown_client(mock_request, mock_redis):
|
||||
"""Test that rate limiting handles missing client host gracefully."""
|
||||
# Arrange
|
||||
key_prefix = 'email_resend'
|
||||
mock_request.client = None # No client information
|
||||
|
||||
with patch('server.utils.rate_limit_utils.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
# Act
|
||||
await check_rate_limit_by_user_id(
|
||||
request=mock_request, key_prefix=key_prefix, user_id=None
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_redis.set.assert_called_once_with(
|
||||
f'{key_prefix}:ip:unknown', 1, nx=True, ex=RATE_LIMIT_IP_SECONDS
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_different_users_have_separate_limits(
|
||||
mock_request, mock_redis
|
||||
):
|
||||
"""Test that different user_ids have separate rate limit keys."""
|
||||
# Arrange
|
||||
key_prefix = 'email_resend'
|
||||
user_id_1 = 'user_1'
|
||||
user_id_2 = 'user_2'
|
||||
|
||||
with patch('server.utils.rate_limit_utils.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
# Act
|
||||
await check_rate_limit_by_user_id(
|
||||
request=mock_request, key_prefix=key_prefix, user_id=user_id_1
|
||||
)
|
||||
await check_rate_limit_by_user_id(
|
||||
request=mock_request, key_prefix=key_prefix, user_id=user_id_2
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert mock_redis.set.call_count == 2
|
||||
# Extract call arguments properly
|
||||
call_args_list = [
|
||||
(call[0][0], call[0][1], call[1]['nx'], call[1]['ex'])
|
||||
for call in mock_redis.set.call_args_list
|
||||
]
|
||||
assert (
|
||||
f'{key_prefix}:{user_id_1}',
|
||||
1,
|
||||
True,
|
||||
RATE_LIMIT_USER_SECONDS,
|
||||
) in call_args_list
|
||||
assert (
|
||||
f'{key_prefix}:{user_id_2}',
|
||||
1,
|
||||
True,
|
||||
RATE_LIMIT_USER_SECONDS,
|
||||
) in call_args_list
|
||||
@@ -1,83 +0,0 @@
|
||||
"""Unit tests for DeviceCode model."""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
from storage.device_code import DeviceCode, DeviceCodeStatus
|
||||
|
||||
|
||||
class TestDeviceCode:
|
||||
"""Test cases for DeviceCode model."""
|
||||
|
||||
@pytest.fixture
|
||||
def device_code(self):
|
||||
"""Create a test device code."""
|
||||
return DeviceCode(
|
||||
device_code='test-device-code-123',
|
||||
user_code='ABC12345',
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(minutes=10),
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'expires_delta,expected',
|
||||
[
|
||||
(timedelta(minutes=5), False), # Future expiry
|
||||
(timedelta(minutes=-5), True), # Past expiry
|
||||
(timedelta(seconds=1), False), # Just future (not expired)
|
||||
],
|
||||
)
|
||||
def test_is_expired(self, expires_delta, expected):
|
||||
"""Test expiration check with various time deltas."""
|
||||
device_code = DeviceCode(
|
||||
device_code='test-device-code',
|
||||
user_code='ABC12345',
|
||||
expires_at=datetime.now(timezone.utc) + expires_delta,
|
||||
)
|
||||
assert device_code.is_expired() == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'status,expired,expected',
|
||||
[
|
||||
(DeviceCodeStatus.PENDING.value, False, True),
|
||||
(DeviceCodeStatus.PENDING.value, True, False),
|
||||
(DeviceCodeStatus.AUTHORIZED.value, False, False),
|
||||
(DeviceCodeStatus.DENIED.value, False, False),
|
||||
],
|
||||
)
|
||||
def test_is_pending(self, status, expired, expected):
|
||||
"""Test pending status check."""
|
||||
expires_at = (
|
||||
datetime.now(timezone.utc) - timedelta(minutes=1)
|
||||
if expired
|
||||
else datetime.now(timezone.utc) + timedelta(minutes=10)
|
||||
)
|
||||
device_code = DeviceCode(
|
||||
device_code='test-device-code',
|
||||
user_code='ABC12345',
|
||||
status=status,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
assert device_code.is_pending() == expected
|
||||
|
||||
def test_authorize(self, device_code):
|
||||
"""Test device authorization."""
|
||||
user_id = 'test-user-123'
|
||||
|
||||
device_code.authorize(user_id)
|
||||
|
||||
assert device_code.status == DeviceCodeStatus.AUTHORIZED.value
|
||||
assert device_code.keycloak_user_id == user_id
|
||||
assert device_code.authorized_at is not None
|
||||
assert isinstance(device_code.authorized_at, datetime)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'method,expected_status',
|
||||
[
|
||||
('deny', DeviceCodeStatus.DENIED.value),
|
||||
('expire', DeviceCodeStatus.EXPIRED.value),
|
||||
],
|
||||
)
|
||||
def test_status_changes(self, device_code, method, expected_status):
|
||||
"""Test status change methods."""
|
||||
getattr(device_code, method)()
|
||||
assert device_code.status == expected_status
|
||||
@@ -1,193 +0,0 @@
|
||||
"""Unit tests for DeviceCodeStore."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from storage.device_code import DeviceCode
|
||||
from storage.device_code_store import DeviceCodeStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
"""Mock database session."""
|
||||
session = MagicMock()
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_maker(mock_session):
|
||||
"""Mock session maker."""
|
||||
session_maker = MagicMock()
|
||||
session_maker.return_value.__enter__.return_value = mock_session
|
||||
session_maker.return_value.__exit__.return_value = None
|
||||
return session_maker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def device_code_store(mock_session_maker):
|
||||
"""Create DeviceCodeStore instance."""
|
||||
return DeviceCodeStore(mock_session_maker)
|
||||
|
||||
|
||||
class TestDeviceCodeStore:
|
||||
"""Test cases for DeviceCodeStore."""
|
||||
|
||||
def test_generate_user_code(self, device_code_store):
|
||||
"""Test user code generation."""
|
||||
code = device_code_store.generate_user_code()
|
||||
|
||||
assert len(code) == 8
|
||||
assert code.isupper()
|
||||
# Should not contain confusing characters
|
||||
assert not any(char in code for char in 'IO01')
|
||||
|
||||
def test_generate_device_code(self, device_code_store):
|
||||
"""Test device code generation."""
|
||||
code = device_code_store.generate_device_code()
|
||||
|
||||
assert len(code) == 128
|
||||
assert code.isalnum()
|
||||
|
||||
def test_create_device_code_success(self, device_code_store, mock_session):
|
||||
"""Test successful device code creation."""
|
||||
# Mock successful creation (no IntegrityError)
|
||||
mock_device_code = MagicMock(spec=DeviceCode)
|
||||
mock_device_code.device_code = 'test-device-code-123'
|
||||
mock_device_code.user_code = 'TESTCODE'
|
||||
|
||||
# Mock the session to return our mock device code after refresh
|
||||
def mock_refresh(obj):
|
||||
obj.device_code = mock_device_code.device_code
|
||||
obj.user_code = mock_device_code.user_code
|
||||
|
||||
mock_session.refresh.side_effect = mock_refresh
|
||||
|
||||
result = device_code_store.create_device_code(expires_in=600)
|
||||
|
||||
assert isinstance(result, DeviceCode)
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
mock_session.refresh.assert_called_once()
|
||||
mock_session.expunge.assert_called_once()
|
||||
|
||||
def test_create_device_code_with_retries(
|
||||
self, device_code_store, mock_session_maker
|
||||
):
|
||||
"""Test device code creation with constraint violation retries."""
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session_maker.return_value.__exit__.return_value = None
|
||||
|
||||
# First attempt fails with IntegrityError, second succeeds
|
||||
mock_session.commit.side_effect = [IntegrityError('', '', ''), None]
|
||||
|
||||
mock_device_code = MagicMock(spec=DeviceCode)
|
||||
mock_device_code.device_code = 'test-device-code-456'
|
||||
mock_device_code.user_code = 'TESTCD2'
|
||||
|
||||
def mock_refresh(obj):
|
||||
obj.device_code = mock_device_code.device_code
|
||||
obj.user_code = mock_device_code.user_code
|
||||
|
||||
mock_session.refresh.side_effect = mock_refresh
|
||||
|
||||
store = DeviceCodeStore(mock_session_maker)
|
||||
result = store.create_device_code(expires_in=600)
|
||||
|
||||
assert isinstance(result, DeviceCode)
|
||||
assert mock_session.add.call_count == 2 # Two attempts
|
||||
assert mock_session.commit.call_count == 2 # Two attempts
|
||||
|
||||
def test_create_device_code_max_attempts_exceeded(
|
||||
self, device_code_store, mock_session_maker
|
||||
):
|
||||
"""Test device code creation failure after max attempts."""
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session_maker.return_value.__exit__.return_value = None
|
||||
|
||||
# All attempts fail with IntegrityError
|
||||
mock_session.commit.side_effect = IntegrityError('', '', '')
|
||||
|
||||
store = DeviceCodeStore(mock_session_maker)
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match='Failed to generate unique device codes after 3 attempts',
|
||||
):
|
||||
store.create_device_code(expires_in=600, max_attempts=3)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'lookup_method,lookup_field',
|
||||
[
|
||||
('get_by_device_code', 'device_code'),
|
||||
('get_by_user_code', 'user_code'),
|
||||
],
|
||||
)
|
||||
def test_lookup_methods(
|
||||
self, device_code_store, mock_session, lookup_method, lookup_field
|
||||
):
|
||||
"""Test device code lookup methods."""
|
||||
test_code = 'test-code-123'
|
||||
mock_device_code = MagicMock()
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_device_code
|
||||
)
|
||||
|
||||
result = getattr(device_code_store, lookup_method)(test_code)
|
||||
|
||||
assert result == mock_device_code
|
||||
mock_session.query.assert_called_once_with(DeviceCode)
|
||||
mock_session.query.return_value.filter_by.assert_called_once_with(
|
||||
**{lookup_field: test_code}
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'device_exists,is_pending,expected_result',
|
||||
[
|
||||
(True, True, True), # Success case
|
||||
(False, True, False), # Device not found
|
||||
(True, False, False), # Device not pending
|
||||
],
|
||||
)
|
||||
def test_authorize_device_code(
|
||||
self,
|
||||
device_code_store,
|
||||
mock_session,
|
||||
device_exists,
|
||||
is_pending,
|
||||
expected_result,
|
||||
):
|
||||
"""Test device code authorization."""
|
||||
user_code = 'ABC12345'
|
||||
user_id = 'test-user-123'
|
||||
|
||||
if device_exists:
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = is_pending
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_device
|
||||
else:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
result = device_code_store.authorize_device_code(user_code, user_id)
|
||||
|
||||
assert result == expected_result
|
||||
if expected_result:
|
||||
mock_device.authorize.assert_called_once_with(user_id)
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_deny_device_code(self, device_code_store, mock_session):
|
||||
"""Test device code denial."""
|
||||
user_code = 'ABC12345'
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_device
|
||||
)
|
||||
|
||||
result = device_code_store.deny_device_code(user_code)
|
||||
|
||||
assert result is True
|
||||
mock_device.deny.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
@@ -25,12 +25,10 @@ def api_key_store(mock_session_maker):
|
||||
|
||||
|
||||
def test_generate_api_key(api_key_store):
|
||||
"""Test that generate_api_key returns a string with sk-oh- prefix and expected length."""
|
||||
"""Test that generate_api_key returns a string of the expected length."""
|
||||
key = api_key_store.generate_api_key(length=32)
|
||||
assert isinstance(key, str)
|
||||
assert key.startswith('sk-oh-')
|
||||
# Total length should be prefix (6 chars) + random part (32 chars) = 38 chars
|
||||
assert len(key) == len('sk-oh-') + 32
|
||||
assert len(key) == 32
|
||||
|
||||
|
||||
def test_create_api_key(api_key_store, mock_session):
|
||||
@@ -92,50 +90,6 @@ def test_validate_api_key_expired(api_key_store, mock_session):
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_validate_api_key_expired_timezone_naive(api_key_store, mock_session):
|
||||
"""Test validating an expired API key with timezone-naive datetime from database."""
|
||||
# Setup
|
||||
api_key = 'test-api-key'
|
||||
mock_key_record = MagicMock()
|
||||
# Simulate timezone-naive datetime as returned from database
|
||||
mock_key_record.expires_at = datetime.now() - timedelta(days=1) # No UTC timezone
|
||||
mock_key_record.id = 1
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = (
|
||||
mock_key_record
|
||||
)
|
||||
|
||||
# Execute
|
||||
result = api_key_store.validate_api_key(api_key)
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
mock_session.execute.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_validate_api_key_valid_timezone_naive(api_key_store, mock_session):
|
||||
"""Test validating a valid API key with timezone-naive datetime from database."""
|
||||
# Setup
|
||||
api_key = 'test-api-key'
|
||||
user_id = 'test-user-123'
|
||||
mock_key_record = MagicMock()
|
||||
mock_key_record.user_id = user_id
|
||||
# Simulate timezone-naive datetime as returned from database (future date)
|
||||
mock_key_record.expires_at = datetime.now() + timedelta(days=1) # No UTC timezone
|
||||
mock_key_record.id = 1
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = (
|
||||
mock_key_record
|
||||
)
|
||||
|
||||
# Execute
|
||||
result = api_key_store.validate_api_key(api_key)
|
||||
|
||||
# Verify
|
||||
assert result == user_id
|
||||
mock_session.execute.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_validate_api_key_not_found(api_key_store, mock_session):
|
||||
"""Test validating a non-existent API key."""
|
||||
# Setup
|
||||
|
||||
@@ -234,53 +234,3 @@ async def test_middleware_with_other_auth_error(middleware, mock_request):
|
||||
assert 'set-cookie' in result.headers
|
||||
# Logger should be called for non-NoCredentialsError
|
||||
mock_logger.warning.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_ignores_email_resend_path(
|
||||
middleware, mock_request, mock_response
|
||||
):
|
||||
"""Test middleware ignores /api/email/resend path and doesn't require authentication."""
|
||||
# Arrange
|
||||
mock_request.cookies = {}
|
||||
mock_request.url = MagicMock()
|
||||
mock_request.url.hostname = 'localhost'
|
||||
mock_request.url.path = '/api/email/resend'
|
||||
mock_call_next = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Act
|
||||
result = await middleware(mock_request, mock_call_next)
|
||||
|
||||
# Assert
|
||||
assert result == mock_response
|
||||
mock_call_next.assert_called_once_with(mock_request)
|
||||
# Should not raise NoCredentialsError even without auth cookie
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_ignores_email_resend_path_no_tos_check(
|
||||
middleware, mock_request, mock_response
|
||||
):
|
||||
"""Test middleware doesn't check TOS for /api/email/resend path."""
|
||||
# Arrange
|
||||
mock_request.cookies = {'keycloak_auth': 'test_cookie'}
|
||||
mock_request.url = MagicMock()
|
||||
mock_request.url.hostname = 'localhost'
|
||||
mock_request.url.path = '/api/email/resend'
|
||||
mock_call_next = AsyncMock(return_value=mock_response)
|
||||
|
||||
with (
|
||||
patch('server.middleware.jwt.decode') as mock_decode,
|
||||
patch('server.middleware.config') as mock_config,
|
||||
):
|
||||
# Even with accepted_tos=False, should not raise TosNotAcceptedError
|
||||
mock_decode.return_value = {'accepted_tos': False}
|
||||
mock_config.jwt_secret.get_secret_value.return_value = 'test_secret'
|
||||
|
||||
# Act
|
||||
result = await middleware(mock_request, mock_call_next)
|
||||
|
||||
# Assert
|
||||
assert result == mock_response
|
||||
mock_call_next.assert_called_once_with(mock_request)
|
||||
# Should not raise TosNotAcceptedError for this path
|
||||
|
||||
@@ -136,7 +136,6 @@ async def test_keycloak_callback_user_not_allowed(mock_request):
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
@@ -185,7 +184,6 @@ async def test_keycloak_callback_success_with_valid_offline_token(mock_request):
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
@@ -216,84 +214,6 @@ async def test_keycloak_callback_success_with_valid_offline_token(mock_request):
|
||||
mock_posthog.set.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_email_not_verified(mock_request):
|
||||
"""Test keycloak_callback when email is not verified."""
|
||||
# Arrange
|
||||
mock_verify_email = AsyncMock()
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.email.verify_email', mock_verify_email),
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': False,
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_verifier.is_active.return_value = False
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
assert 'email_verification_required=true' in result.headers['location']
|
||||
assert 'user_id=test_user_id' in result.headers['location']
|
||||
mock_verify_email.assert_called_once_with(
|
||||
request=mock_request, user_id='test_user_id', is_auth_flow=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_email_not_verified_missing_field(mock_request):
|
||||
"""Test keycloak_callback when email_verified field is missing (defaults to False)."""
|
||||
# Arrange
|
||||
mock_verify_email = AsyncMock()
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.email.verify_email', mock_verify_email),
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'identity_provider': 'github',
|
||||
# email_verified field is missing
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_verifier.is_active.return_value = False
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
assert 'email_verification_required=true' in result.headers['location']
|
||||
assert 'user_id=test_user_id' in result.headers['location']
|
||||
mock_verify_email.assert_called_once_with(
|
||||
request=mock_request, user_id='test_user_id', is_auth_flow=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_success_without_offline_token(mock_request):
|
||||
"""Test successful keycloak_callback without valid offline token."""
|
||||
@@ -328,7 +248,6 @@ async def test_keycloak_callback_success_without_offline_token(mock_request):
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
@@ -523,418 +442,3 @@ async def test_logout_without_refresh_token():
|
||||
|
||||
mock_token_manager.logout.assert_not_called()
|
||||
assert 'set-cookie' in result.headers
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_blocked_email_domain(mock_request):
|
||||
"""Test keycloak_callback when email domain is blocked."""
|
||||
# Arrange
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'user@colsch.us',
|
||||
'identity_provider': 'github',
|
||||
}
|
||||
)
|
||||
mock_token_manager.disable_keycloak_user = AsyncMock()
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert 'error' in result.body.decode()
|
||||
assert 'email domain is not allowed' in result.body.decode()
|
||||
mock_domain_blocker.is_domain_blocked.assert_called_once_with('user@colsch.us')
|
||||
mock_token_manager.disable_keycloak_user.assert_called_once_with(
|
||||
'test_user_id', 'user@colsch.us'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_allowed_email_domain(mock_request):
|
||||
"""Test keycloak_callback when email domain is not blocked."""
|
||||
# Arrange
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'user@example.com',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
mock_domain_blocker.is_domain_blocked.assert_called_once_with(
|
||||
'user@example.com'
|
||||
)
|
||||
mock_token_manager.disable_keycloak_user.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_domain_blocking_inactive(mock_request):
|
||||
"""Test keycloak_callback when domain blocking is not active."""
|
||||
# Arrange
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'user@colsch.us',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = False
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
mock_domain_blocker.is_domain_blocked.assert_not_called()
|
||||
mock_token_manager.disable_keycloak_user.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_missing_email(mock_request):
|
||||
"""Test keycloak_callback when user info does not contain email."""
|
||||
# Arrange
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
# No email field
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
mock_domain_blocker.is_domain_blocked.assert_not_called()
|
||||
mock_token_manager.disable_keycloak_user.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_duplicate_email_detected(mock_request):
|
||||
"""Test keycloak_callback when duplicate email is detected."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
):
|
||||
# Arrange
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'joe+test@example.com',
|
||||
'identity_provider': 'github',
|
||||
}
|
||||
)
|
||||
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=True)
|
||||
mock_token_manager.delete_keycloak_user = AsyncMock(return_value=True)
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
assert 'duplicated_email=true' in result.headers['location']
|
||||
mock_token_manager.check_duplicate_base_email.assert_called_once_with(
|
||||
'joe+test@example.com', 'test_user_id'
|
||||
)
|
||||
mock_token_manager.delete_keycloak_user.assert_called_once_with('test_user_id')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_duplicate_email_deletion_fails(mock_request):
|
||||
"""Test keycloak_callback when duplicate is detected but deletion fails."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
):
|
||||
# Arrange
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'joe+test@example.com',
|
||||
'identity_provider': 'github',
|
||||
}
|
||||
)
|
||||
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=True)
|
||||
mock_token_manager.delete_keycloak_user = AsyncMock(return_value=False)
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
assert 'duplicated_email=true' in result.headers['location']
|
||||
mock_token_manager.delete_keycloak_user.assert_called_once_with('test_user_id')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_duplicate_check_exception(mock_request):
|
||||
"""Test keycloak_callback when duplicate check raises exception."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
):
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'joe+test@example.com',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.check_duplicate_base_email = AsyncMock(
|
||||
side_effect=Exception('Check failed')
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Should proceed with normal flow despite exception (fail open)
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_no_duplicate_email(mock_request):
|
||||
"""Test keycloak_callback when no duplicate email is found."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
):
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'joe+test@example.com',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=False)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
mock_token_manager.check_duplicate_base_email.assert_called_once_with(
|
||||
'joe+test@example.com', 'test_user_id'
|
||||
)
|
||||
# Should not delete user when no duplicate found
|
||||
mock_token_manager.delete_keycloak_user.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_no_email_in_user_info(mock_request):
|
||||
"""Test keycloak_callback when email is not in user_info."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
):
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
# No email field
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
# Should not check for duplicate when email is missing
|
||||
mock_token_manager.check_duplicate_base_email.assert_not_called()
|
||||
|
||||
@@ -4,7 +4,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
import stripe
|
||||
from fastapi import HTTPException, Request, status
|
||||
from httpx import Response
|
||||
from httpx import HTTPStatusError, Response
|
||||
from integrations.stripe_service import has_payment_method
|
||||
from server.routes.billing import (
|
||||
CreateBillingSessionResponse,
|
||||
@@ -78,6 +78,8 @@ def mock_subscription_request():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_credits_lite_llm_error():
|
||||
mock_request = Request(scope={'type': 'http', 'state': {'user_id': 'mock_user'}})
|
||||
|
||||
mock_response = Response(
|
||||
status_code=500, json={'error': 'Internal Server Error'}, request=MagicMock()
|
||||
)
|
||||
@@ -86,12 +88,11 @@ async def test_get_credits_lite_llm_error():
|
||||
|
||||
with patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'):
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_credits('mock_user')
|
||||
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
with pytest.raises(HTTPStatusError) as exc_info:
|
||||
await get_credits(mock_request)
|
||||
assert (
|
||||
exc_info.value.detail
|
||||
== 'Failed to retrieve credit balance from billing service'
|
||||
exc_info.value.response.status_code
|
||||
== status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,181 +0,0 @@
|
||||
"""Unit tests for DomainBlocker class."""
|
||||
|
||||
import pytest
|
||||
from server.auth.domain_blocker import DomainBlocker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def domain_blocker():
|
||||
"""Create a DomainBlocker instance for testing."""
|
||||
return DomainBlocker()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'blocked_domains,expected',
|
||||
[
|
||||
(['colsch.us', 'other-domain.com'], True),
|
||||
(['example.com'], True),
|
||||
([], False),
|
||||
],
|
||||
)
|
||||
def test_is_active(domain_blocker, blocked_domains, expected):
|
||||
"""Test that is_active returns correct value based on blocked domains configuration."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = blocked_domains
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_active()
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'email,expected_domain',
|
||||
[
|
||||
('user@example.com', 'example.com'),
|
||||
('test@colsch.us', 'colsch.us'),
|
||||
('user.name@other-domain.com', 'other-domain.com'),
|
||||
('USER@EXAMPLE.COM', 'example.com'), # Case insensitive
|
||||
('user@EXAMPLE.COM', 'example.com'),
|
||||
(' user@example.com ', 'example.com'), # Whitespace handling
|
||||
],
|
||||
)
|
||||
def test_extract_domain_valid_emails(domain_blocker, email, expected_domain):
|
||||
"""Test that _extract_domain correctly extracts and normalizes domains from valid emails."""
|
||||
# Act
|
||||
result = domain_blocker._extract_domain(email)
|
||||
|
||||
# Assert
|
||||
assert result == expected_domain
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'email,expected',
|
||||
[
|
||||
(None, None),
|
||||
('', None),
|
||||
('invalid-email', None),
|
||||
('user@', None), # Empty domain after @
|
||||
('no-at-sign', None),
|
||||
],
|
||||
)
|
||||
def test_extract_domain_invalid_emails(domain_blocker, email, expected):
|
||||
"""Test that _extract_domain returns None for invalid email formats."""
|
||||
# Act
|
||||
result = domain_blocker._extract_domain(email)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_is_domain_blocked_when_inactive(domain_blocker):
|
||||
"""Test that is_domain_blocked returns False when blocking is not active."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = []
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@colsch.us')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_is_domain_blocked_with_none_email(domain_blocker):
|
||||
"""Test that is_domain_blocked returns False when email is None."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked(None)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_is_domain_blocked_with_empty_email(domain_blocker):
|
||||
"""Test that is_domain_blocked returns False when email is empty."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_is_domain_blocked_with_invalid_email(domain_blocker):
|
||||
"""Test that is_domain_blocked returns False when email format is invalid."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('invalid-email')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_is_domain_blocked_domain_not_blocked(domain_blocker):
|
||||
"""Test that is_domain_blocked returns False when domain is not in blocked list."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@example.com')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_is_domain_blocked_domain_blocked(domain_blocker):
|
||||
"""Test that is_domain_blocked returns True when domain is in blocked list."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@colsch.us')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_is_domain_blocked_case_insensitive(domain_blocker):
|
||||
"""Test that is_domain_blocked performs case-insensitive domain matching."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@COLSCH.US')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_is_domain_blocked_multiple_blocked_domains(domain_blocker):
|
||||
"""Test that is_domain_blocked correctly checks against multiple blocked domains."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com', 'blocked.org']
|
||||
|
||||
# Act
|
||||
result1 = domain_blocker.is_domain_blocked('user@other-domain.com')
|
||||
result2 = domain_blocker.is_domain_blocked('user@blocked.org')
|
||||
result3 = domain_blocker.is_domain_blocked('user@allowed.com')
|
||||
|
||||
# Assert
|
||||
assert result1 is True
|
||||
assert result2 is True
|
||||
assert result3 is False
|
||||
|
||||
|
||||
def test_is_domain_blocked_with_whitespace(domain_blocker):
|
||||
"""Test that is_domain_blocked handles emails with whitespace correctly."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked(' user@colsch.us ')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
@@ -1,294 +0,0 @@
|
||||
"""Tests for email validation utilities."""
|
||||
|
||||
import re
|
||||
|
||||
from server.auth.email_validation import (
|
||||
extract_base_email,
|
||||
get_base_email_regex_pattern,
|
||||
has_plus_modifier,
|
||||
matches_base_email,
|
||||
)
|
||||
|
||||
|
||||
class TestExtractBaseEmail:
|
||||
"""Test cases for extract_base_email function."""
|
||||
|
||||
def test_extract_base_email_with_plus_modifier(self):
|
||||
"""Test extracting base email from email with + modifier."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result == 'joe@example.com'
|
||||
|
||||
def test_extract_base_email_without_plus_modifier(self):
|
||||
"""Test that email without + modifier is returned as-is."""
|
||||
# Arrange
|
||||
email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result == 'joe@example.com'
|
||||
|
||||
def test_extract_base_email_multiple_plus_signs(self):
|
||||
"""Test extracting base email when multiple + signs exist."""
|
||||
# Arrange
|
||||
email = 'joe+openhands+test@example.com'
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result == 'joe@example.com'
|
||||
|
||||
def test_extract_base_email_invalid_no_at_symbol(self):
|
||||
"""Test that invalid email without @ returns None."""
|
||||
# Arrange
|
||||
email = 'invalid-email'
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_extract_base_email_empty_string(self):
|
||||
"""Test that empty string returns None."""
|
||||
# Arrange
|
||||
email = ''
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_extract_base_email_none(self):
|
||||
"""Test that None input returns None."""
|
||||
# Arrange
|
||||
email = None
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestHasPlusModifier:
|
||||
"""Test cases for has_plus_modifier function."""
|
||||
|
||||
def test_has_plus_modifier_true(self):
|
||||
"""Test detecting + modifier in email."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
|
||||
# Act
|
||||
result = has_plus_modifier(email)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_has_plus_modifier_false(self):
|
||||
"""Test that email without + modifier returns False."""
|
||||
# Arrange
|
||||
email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = has_plus_modifier(email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_has_plus_modifier_invalid_no_at_symbol(self):
|
||||
"""Test that invalid email without @ returns False."""
|
||||
# Arrange
|
||||
email = 'invalid-email'
|
||||
|
||||
# Act
|
||||
result = has_plus_modifier(email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_has_plus_modifier_empty_string(self):
|
||||
"""Test that empty string returns False."""
|
||||
# Arrange
|
||||
email = ''
|
||||
|
||||
# Act
|
||||
result = has_plus_modifier(email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestMatchesBaseEmail:
|
||||
"""Test cases for matches_base_email function."""
|
||||
|
||||
def test_matches_base_email_exact_match(self):
|
||||
"""Test that exact base email matches."""
|
||||
# Arrange
|
||||
email = 'joe@example.com'
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_matches_base_email_with_plus_variant(self):
|
||||
"""Test that email with + variant matches base email."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_matches_base_email_different_base(self):
|
||||
"""Test that different base emails do not match."""
|
||||
# Arrange
|
||||
email = 'jane@example.com'
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_matches_base_email_different_domain(self):
|
||||
"""Test that same local part but different domain does not match."""
|
||||
# Arrange
|
||||
email = 'joe@other.com'
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_matches_base_email_case_insensitive(self):
|
||||
"""Test that matching is case-insensitive."""
|
||||
# Arrange
|
||||
email = 'JOE+TEST@EXAMPLE.COM'
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_matches_base_email_empty_strings(self):
|
||||
"""Test that empty strings return False."""
|
||||
# Arrange
|
||||
email = ''
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestGetBaseEmailRegexPattern:
|
||||
"""Test cases for get_base_email_regex_pattern function."""
|
||||
|
||||
def test_get_base_email_regex_pattern_valid(self):
|
||||
"""Test generating valid regex pattern for base email."""
|
||||
# Arrange
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Assert
|
||||
assert pattern is not None
|
||||
assert isinstance(pattern, re.Pattern)
|
||||
assert pattern.match('joe@example.com') is not None
|
||||
assert pattern.match('joe+test@example.com') is not None
|
||||
assert pattern.match('joe+openhands@example.com') is not None
|
||||
|
||||
def test_get_base_email_regex_pattern_matches_plus_variant(self):
|
||||
"""Test that regex pattern matches + variant."""
|
||||
# Arrange
|
||||
base_email = 'joe@example.com'
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Act
|
||||
match = pattern.match('joe+test@example.com')
|
||||
|
||||
# Assert
|
||||
assert match is not None
|
||||
|
||||
def test_get_base_email_regex_pattern_rejects_different_base(self):
|
||||
"""Test that regex pattern rejects different base email."""
|
||||
# Arrange
|
||||
base_email = 'joe@example.com'
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Act
|
||||
match = pattern.match('jane@example.com')
|
||||
|
||||
# Assert
|
||||
assert match is None
|
||||
|
||||
def test_get_base_email_regex_pattern_rejects_different_domain(self):
|
||||
"""Test that regex pattern rejects different domain."""
|
||||
# Arrange
|
||||
base_email = 'joe@example.com'
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Act
|
||||
match = pattern.match('joe@other.com')
|
||||
|
||||
# Assert
|
||||
assert match is None
|
||||
|
||||
def test_get_base_email_regex_pattern_case_insensitive(self):
|
||||
"""Test that regex pattern is case-insensitive."""
|
||||
# Arrange
|
||||
base_email = 'joe@example.com'
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Act
|
||||
match = pattern.match('JOE+TEST@EXAMPLE.COM')
|
||||
|
||||
# Assert
|
||||
assert match is not None
|
||||
|
||||
def test_get_base_email_regex_pattern_special_characters(self):
|
||||
"""Test that regex pattern handles special characters in email."""
|
||||
# Arrange
|
||||
base_email = 'user.name+tag@example-site.com'
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Act
|
||||
match = pattern.match('user.name+test@example-site.com')
|
||||
|
||||
# Assert
|
||||
assert match is not None
|
||||
|
||||
def test_get_base_email_regex_pattern_invalid_base_email(self):
|
||||
"""Test that invalid base email returns None."""
|
||||
# Arrange
|
||||
base_email = 'invalid-email'
|
||||
|
||||
# Act
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Assert
|
||||
assert pattern is None
|
||||
@@ -1,132 +0,0 @@
|
||||
"""Unit tests for get_user_v1_enabled_setting function."""
|
||||
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from integrations.github.github_view import get_user_v1_enabled_setting
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_settings():
|
||||
"""Create a mock user settings object."""
|
||||
settings = MagicMock()
|
||||
settings.v1_enabled = True # Default to True, can be overridden in tests
|
||||
return settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings_store(mock_user_settings):
|
||||
"""Create a mock settings store."""
|
||||
store = MagicMock()
|
||||
store.get_user_settings_by_keycloak_id = AsyncMock(return_value=mock_user_settings)
|
||||
return store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config():
|
||||
"""Create a mock config object."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_maker():
|
||||
"""Create a mock session maker."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(
|
||||
mock_settings_store, mock_config, mock_session_maker, mock_user_settings
|
||||
):
|
||||
"""Fixture that patches all the common dependencies."""
|
||||
with patch(
|
||||
'integrations.github.github_view.SaasSettingsStore',
|
||||
return_value=mock_settings_store,
|
||||
) as mock_store_class, patch(
|
||||
'integrations.github.github_view.get_config', return_value=mock_config
|
||||
) as mock_get_config, patch(
|
||||
'integrations.github.github_view.session_maker', mock_session_maker
|
||||
), patch(
|
||||
'integrations.github.github_view.call_sync_from_async',
|
||||
return_value=mock_user_settings,
|
||||
) as mock_call_sync:
|
||||
yield {
|
||||
'store_class': mock_store_class,
|
||||
'get_config': mock_get_config,
|
||||
'session_maker': mock_session_maker,
|
||||
'call_sync': mock_call_sync,
|
||||
'settings_store': mock_settings_store,
|
||||
'user_settings': mock_user_settings,
|
||||
}
|
||||
|
||||
|
||||
class TestGetUserV1EnabledSetting:
|
||||
"""Test cases for get_user_v1_enabled_setting function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'env_var_enabled,user_setting_enabled,expected_result',
|
||||
[
|
||||
(False, True, False), # Env var disabled, user enabled -> False
|
||||
(True, False, False), # Env var enabled, user disabled -> False
|
||||
(True, True, True), # Both enabled -> True
|
||||
(False, False, False), # Both disabled -> False
|
||||
],
|
||||
)
|
||||
async def test_v1_enabled_combinations(
|
||||
self, mock_dependencies, env_var_enabled, user_setting_enabled, expected_result
|
||||
):
|
||||
"""Test all combinations of environment variable and user setting values."""
|
||||
mock_dependencies['user_settings'].v1_enabled = user_setting_enabled
|
||||
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_V1_GITHUB_RESOLVER', env_var_enabled
|
||||
):
|
||||
result = await get_user_v1_enabled_setting('test_user_id')
|
||||
assert result is expected_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'env_var_value,env_var_bool,expected_result',
|
||||
[
|
||||
('false', False, False), # Environment variable 'false' -> False
|
||||
('true', True, True), # Environment variable 'true' -> True
|
||||
],
|
||||
)
|
||||
async def test_environment_variable_integration(
|
||||
self, mock_dependencies, env_var_value, env_var_bool, expected_result
|
||||
):
|
||||
"""Test that the function properly reads the ENABLE_V1_GITHUB_RESOLVER environment variable."""
|
||||
mock_dependencies['user_settings'].v1_enabled = True
|
||||
|
||||
with patch.dict(
|
||||
os.environ, {'ENABLE_V1_GITHUB_RESOLVER': env_var_value}
|
||||
), patch('integrations.utils.os.getenv', return_value=env_var_value), patch(
|
||||
'integrations.github.github_view.ENABLE_V1_GITHUB_RESOLVER', env_var_bool
|
||||
):
|
||||
result = await get_user_v1_enabled_setting('test_user_id')
|
||||
assert result is expected_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_function_calls_correct_methods(self, mock_dependencies):
|
||||
"""Test that the function calls the correct methods with correct parameters."""
|
||||
mock_dependencies['user_settings'].v1_enabled = True
|
||||
|
||||
with patch('integrations.github.github_view.ENABLE_V1_GITHUB_RESOLVER', True):
|
||||
result = await get_user_v1_enabled_setting('test_user_123')
|
||||
|
||||
# Verify the result
|
||||
assert result is True
|
||||
|
||||
# Verify correct methods were called with correct parameters
|
||||
mock_dependencies['get_config'].assert_called_once()
|
||||
mock_dependencies['store_class'].assert_called_once_with(
|
||||
user_id='test_user_123',
|
||||
session_maker=mock_dependencies['session_maker'],
|
||||
config=mock_dependencies['get_config'].return_value,
|
||||
)
|
||||
mock_dependencies['call_sync'].assert_called_once_with(
|
||||
mock_dependencies['settings_store'].get_user_settings_by_keycloak_id,
|
||||
'test_user_123',
|
||||
)
|
||||
@@ -1,10 +1,7 @@
|
||||
from unittest import TestCase, mock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from integrations.github.github_view import GithubFactory, GithubIssue, get_oh_labels
|
||||
from integrations.github.github_view import GithubFactory, get_oh_labels
|
||||
from integrations.models import Message, SourceType
|
||||
from integrations.types import UserData
|
||||
|
||||
|
||||
class TestGithubLabels(TestCase):
|
||||
@@ -78,132 +75,3 @@ class TestGithubCommentCaseInsensitivity(TestCase):
|
||||
self.assertTrue(GithubFactory.is_issue_comment(message_lower))
|
||||
self.assertTrue(GithubFactory.is_issue_comment(message_upper))
|
||||
self.assertTrue(GithubFactory.is_issue_comment(message_mixed))
|
||||
|
||||
|
||||
class TestGithubV1ConversationRouting(TestCase):
|
||||
"""Test V1 conversation routing logic in GitHub integration."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
# Create a proper UserData instance instead of MagicMock
|
||||
user_data = UserData(
|
||||
user_id=123, username='testuser', keycloak_user_id='test-keycloak-id'
|
||||
)
|
||||
|
||||
# Create a mock raw_payload
|
||||
raw_payload = Message(
|
||||
source=SourceType.GITHUB,
|
||||
message={
|
||||
'payload': {
|
||||
'action': 'opened',
|
||||
'issue': {'number': 123},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
self.github_issue = GithubIssue(
|
||||
user_info=user_data,
|
||||
full_repo_name='test/repo',
|
||||
issue_number=123,
|
||||
installation_id=456,
|
||||
conversation_id='test-conversation-id',
|
||||
should_extract=True,
|
||||
send_summary_instruction=False,
|
||||
is_public_repo=True,
|
||||
raw_payload=raw_payload,
|
||||
uuid='test-uuid',
|
||||
title='Test Issue',
|
||||
description='Test issue description',
|
||||
previous_comments=[],
|
||||
v1=False,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('integrations.github.github_view.get_user_v1_enabled_setting')
|
||||
@patch.object(GithubIssue, '_create_v0_conversation')
|
||||
@patch.object(GithubIssue, '_create_v1_conversation')
|
||||
async def test_create_new_conversation_routes_to_v0_when_disabled(
|
||||
self, mock_create_v1, mock_create_v0, mock_get_v1_setting
|
||||
):
|
||||
"""Test that conversation creation routes to V0 when v1_enabled is False."""
|
||||
# Mock v1_enabled as False
|
||||
mock_get_v1_setting.return_value = False
|
||||
mock_create_v0.return_value = None
|
||||
mock_create_v1.return_value = None
|
||||
|
||||
# Mock parameters
|
||||
jinja_env = MagicMock()
|
||||
git_provider_tokens = MagicMock()
|
||||
conversation_metadata = MagicMock()
|
||||
|
||||
# Call the method
|
||||
await self.github_issue.create_new_conversation(
|
||||
jinja_env, git_provider_tokens, conversation_metadata
|
||||
)
|
||||
|
||||
# Verify V0 was called and V1 was not
|
||||
mock_create_v0.assert_called_once_with(
|
||||
jinja_env, git_provider_tokens, conversation_metadata
|
||||
)
|
||||
mock_create_v1.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('integrations.github.github_view.get_user_v1_enabled_setting')
|
||||
@patch.object(GithubIssue, '_create_v0_conversation')
|
||||
@patch.object(GithubIssue, '_create_v1_conversation')
|
||||
async def test_create_new_conversation_routes_to_v1_when_enabled(
|
||||
self, mock_create_v1, mock_create_v0, mock_get_v1_setting
|
||||
):
|
||||
"""Test that conversation creation routes to V1 when v1_enabled is True."""
|
||||
# Mock v1_enabled as True
|
||||
mock_get_v1_setting.return_value = True
|
||||
mock_create_v0.return_value = None
|
||||
mock_create_v1.return_value = None
|
||||
|
||||
# Mock parameters
|
||||
jinja_env = MagicMock()
|
||||
git_provider_tokens = MagicMock()
|
||||
conversation_metadata = MagicMock()
|
||||
|
||||
# Call the method
|
||||
await self.github_issue.create_new_conversation(
|
||||
jinja_env, git_provider_tokens, conversation_metadata
|
||||
)
|
||||
|
||||
# Verify V1 was called and V0 was not
|
||||
mock_create_v1.assert_called_once_with(
|
||||
jinja_env, git_provider_tokens, conversation_metadata
|
||||
)
|
||||
mock_create_v0.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('integrations.github.github_view.get_user_v1_enabled_setting')
|
||||
@patch.object(GithubIssue, '_create_v0_conversation')
|
||||
@patch.object(GithubIssue, '_create_v1_conversation')
|
||||
async def test_create_new_conversation_fallback_on_v1_setting_error(
|
||||
self, mock_create_v1, mock_create_v0, mock_get_v1_setting
|
||||
):
|
||||
"""Test that conversation creation falls back to V0 when _create_v1_conversation fails."""
|
||||
# Mock v1_enabled as True so V1 is attempted
|
||||
mock_get_v1_setting.return_value = True
|
||||
# Mock _create_v1_conversation to raise an exception
|
||||
mock_create_v1.side_effect = Exception('V1 conversation creation failed')
|
||||
mock_create_v0.return_value = None
|
||||
|
||||
# Mock parameters
|
||||
jinja_env = MagicMock()
|
||||
git_provider_tokens = MagicMock()
|
||||
conversation_metadata = MagicMock()
|
||||
|
||||
# Call the method
|
||||
await self.github_issue.create_new_conversation(
|
||||
jinja_env, git_provider_tokens, conversation_metadata
|
||||
)
|
||||
|
||||
# Verify V1 was attempted first, then V0 was called as fallback
|
||||
mock_create_v1.assert_called_once_with(
|
||||
jinja_env, git_provider_tokens, conversation_metadata
|
||||
)
|
||||
mock_create_v0.assert_called_once_with(
|
||||
jinja_env, git_provider_tokens, conversation_metadata
|
||||
)
|
||||
|
||||
485
enterprise/tests/unit/test_legacy_conversation_manager.py
Normal file
485
enterprise/tests/unit/test_legacy_conversation_manager.py
Normal file
@@ -0,0 +1,485 @@
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from server.legacy_conversation_manager import (
|
||||
_LEGACY_ENTRY_TIMEOUT_SECONDS,
|
||||
LegacyCacheEntry,
|
||||
LegacyConversationManager,
|
||||
)
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.server.config.server_config import ServerConfig
|
||||
from openhands.server.monitoring import MonitoringListener
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_sio():
|
||||
"""Create a mock SocketIO server."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config():
|
||||
"""Create a mock OpenHands config."""
|
||||
return MagicMock(spec=OpenHandsConfig)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_server_config():
|
||||
"""Create a mock server config."""
|
||||
return MagicMock(spec=ServerConfig)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_file_store():
|
||||
"""Create a mock file store."""
|
||||
return MagicMock(spec=InMemoryFileStore)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_monitoring_listener():
|
||||
"""Create a mock monitoring listener."""
|
||||
return MagicMock(spec=MonitoringListener)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_conversation_manager():
|
||||
"""Create a mock SaasNestedConversationManager."""
|
||||
mock_cm = MagicMock()
|
||||
mock_cm._get_runtime = AsyncMock()
|
||||
return mock_cm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_legacy_conversation_manager():
|
||||
"""Create a mock ClusteredConversationManager."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def legacy_manager(
|
||||
mock_sio,
|
||||
mock_config,
|
||||
mock_server_config,
|
||||
mock_file_store,
|
||||
mock_conversation_manager,
|
||||
mock_legacy_conversation_manager,
|
||||
):
|
||||
"""Create a LegacyConversationManager instance for testing."""
|
||||
return LegacyConversationManager(
|
||||
sio=mock_sio,
|
||||
config=mock_config,
|
||||
server_config=mock_server_config,
|
||||
file_store=mock_file_store,
|
||||
conversation_manager=mock_conversation_manager,
|
||||
legacy_conversation_manager=mock_legacy_conversation_manager,
|
||||
)
|
||||
|
||||
|
||||
class TestLegacyCacheEntry:
|
||||
"""Test the LegacyCacheEntry dataclass."""
|
||||
|
||||
def test_cache_entry_creation(self):
|
||||
"""Test creating a cache entry."""
|
||||
timestamp = time.time()
|
||||
entry = LegacyCacheEntry(is_legacy=True, timestamp=timestamp)
|
||||
|
||||
assert entry.is_legacy is True
|
||||
assert entry.timestamp == timestamp
|
||||
|
||||
def test_cache_entry_false(self):
|
||||
"""Test creating a cache entry with False value."""
|
||||
timestamp = time.time()
|
||||
entry = LegacyCacheEntry(is_legacy=False, timestamp=timestamp)
|
||||
|
||||
assert entry.is_legacy is False
|
||||
assert entry.timestamp == timestamp
|
||||
|
||||
|
||||
class TestLegacyConversationManagerCacheCleanup:
|
||||
"""Test cache cleanup functionality."""
|
||||
|
||||
def test_cleanup_expired_cache_entries_removes_expired(self, legacy_manager):
|
||||
"""Test that expired entries are removed from cache."""
|
||||
current_time = time.time()
|
||||
expired_time = current_time - _LEGACY_ENTRY_TIMEOUT_SECONDS - 1
|
||||
valid_time = current_time - 100 # Well within timeout
|
||||
|
||||
# Add both expired and valid entries
|
||||
legacy_manager._legacy_cache = {
|
||||
'expired_conversation': LegacyCacheEntry(True, expired_time),
|
||||
'valid_conversation': LegacyCacheEntry(False, valid_time),
|
||||
'another_expired': LegacyCacheEntry(True, expired_time - 100),
|
||||
}
|
||||
|
||||
legacy_manager._cleanup_expired_cache_entries()
|
||||
|
||||
# Only valid entry should remain
|
||||
assert len(legacy_manager._legacy_cache) == 1
|
||||
assert 'valid_conversation' in legacy_manager._legacy_cache
|
||||
assert 'expired_conversation' not in legacy_manager._legacy_cache
|
||||
assert 'another_expired' not in legacy_manager._legacy_cache
|
||||
|
||||
def test_cleanup_expired_cache_entries_keeps_valid(self, legacy_manager):
|
||||
"""Test that valid entries are kept during cleanup."""
|
||||
current_time = time.time()
|
||||
valid_time = current_time - 100 # Well within timeout
|
||||
|
||||
legacy_manager._legacy_cache = {
|
||||
'valid_conversation_1': LegacyCacheEntry(True, valid_time),
|
||||
'valid_conversation_2': LegacyCacheEntry(False, valid_time - 50),
|
||||
}
|
||||
|
||||
legacy_manager._cleanup_expired_cache_entries()
|
||||
|
||||
# Both entries should remain
|
||||
assert len(legacy_manager._legacy_cache) == 2
|
||||
assert 'valid_conversation_1' in legacy_manager._legacy_cache
|
||||
assert 'valid_conversation_2' in legacy_manager._legacy_cache
|
||||
|
||||
def test_cleanup_expired_cache_entries_empty_cache(self, legacy_manager):
|
||||
"""Test cleanup with empty cache."""
|
||||
legacy_manager._legacy_cache = {}
|
||||
|
||||
legacy_manager._cleanup_expired_cache_entries()
|
||||
|
||||
assert len(legacy_manager._legacy_cache) == 0
|
||||
|
||||
|
||||
class TestIsLegacyRuntime:
|
||||
"""Test the is_legacy_runtime method."""
|
||||
|
||||
def test_is_legacy_runtime_none(self, legacy_manager):
|
||||
"""Test with None runtime."""
|
||||
result = legacy_manager.is_legacy_runtime(None)
|
||||
assert result is False
|
||||
|
||||
def test_is_legacy_runtime_legacy_command(self, legacy_manager):
|
||||
"""Test with legacy runtime command."""
|
||||
runtime = {'command': 'some_old_legacy_command'}
|
||||
result = legacy_manager.is_legacy_runtime(runtime)
|
||||
assert result is True
|
||||
|
||||
def test_is_legacy_runtime_new_command(self, legacy_manager):
|
||||
"""Test with new runtime command containing openhands.server."""
|
||||
runtime = {'command': 'python -m openhands.server.listen'}
|
||||
result = legacy_manager.is_legacy_runtime(runtime)
|
||||
assert result is False
|
||||
|
||||
def test_is_legacy_runtime_partial_match(self, legacy_manager):
|
||||
"""Test with command that partially matches but is still legacy."""
|
||||
runtime = {'command': 'openhands.client.start'}
|
||||
result = legacy_manager.is_legacy_runtime(runtime)
|
||||
assert result is True
|
||||
|
||||
def test_is_legacy_runtime_empty_command(self, legacy_manager):
|
||||
"""Test with empty command."""
|
||||
runtime = {'command': ''}
|
||||
result = legacy_manager.is_legacy_runtime(runtime)
|
||||
assert result is True
|
||||
|
||||
def test_is_legacy_runtime_missing_command_key(self, legacy_manager):
|
||||
"""Test with runtime missing command key."""
|
||||
runtime = {'other_key': 'value'}
|
||||
# This should raise a KeyError
|
||||
with pytest.raises(KeyError):
|
||||
legacy_manager.is_legacy_runtime(runtime)
|
||||
|
||||
|
||||
class TestShouldStartInLegacyMode:
|
||||
"""Test the should_start_in_legacy_mode method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_hit_valid_entry_legacy(self, legacy_manager):
|
||||
"""Test cache hit with valid legacy entry."""
|
||||
conversation_id = 'test_conversation'
|
||||
current_time = time.time()
|
||||
|
||||
# Add valid cache entry
|
||||
legacy_manager._legacy_cache[conversation_id] = LegacyCacheEntry(
|
||||
True, current_time - 100
|
||||
)
|
||||
|
||||
result = await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
|
||||
assert result is True
|
||||
# Should not call _get_runtime since we hit cache
|
||||
legacy_manager.conversation_manager._get_runtime.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_hit_valid_entry_non_legacy(self, legacy_manager):
|
||||
"""Test cache hit with valid non-legacy entry."""
|
||||
conversation_id = 'test_conversation'
|
||||
current_time = time.time()
|
||||
|
||||
# Add valid cache entry
|
||||
legacy_manager._legacy_cache[conversation_id] = LegacyCacheEntry(
|
||||
False, current_time - 100
|
||||
)
|
||||
|
||||
result = await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
|
||||
assert result is False
|
||||
# Should not call _get_runtime since we hit cache
|
||||
legacy_manager.conversation_manager._get_runtime.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_miss_legacy_runtime(self, legacy_manager):
|
||||
"""Test cache miss with legacy runtime."""
|
||||
conversation_id = 'test_conversation'
|
||||
runtime = {'command': 'old_command'}
|
||||
|
||||
legacy_manager.conversation_manager._get_runtime.return_value = runtime
|
||||
|
||||
result = await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
|
||||
assert result is True
|
||||
# Should call _get_runtime
|
||||
legacy_manager.conversation_manager._get_runtime.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
# Should cache the result
|
||||
assert conversation_id in legacy_manager._legacy_cache
|
||||
assert legacy_manager._legacy_cache[conversation_id].is_legacy is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_miss_non_legacy_runtime(self, legacy_manager):
|
||||
"""Test cache miss with non-legacy runtime."""
|
||||
conversation_id = 'test_conversation'
|
||||
runtime = {'command': 'python -m openhands.server.listen'}
|
||||
|
||||
legacy_manager.conversation_manager._get_runtime.return_value = runtime
|
||||
|
||||
result = await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
|
||||
assert result is False
|
||||
# Should call _get_runtime
|
||||
legacy_manager.conversation_manager._get_runtime.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
# Should cache the result
|
||||
assert conversation_id in legacy_manager._legacy_cache
|
||||
assert legacy_manager._legacy_cache[conversation_id].is_legacy is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_expired_entry(self, legacy_manager):
|
||||
"""Test with expired cache entry."""
|
||||
conversation_id = 'test_conversation'
|
||||
expired_time = time.time() - _LEGACY_ENTRY_TIMEOUT_SECONDS - 1
|
||||
runtime = {'command': 'python -m openhands.server.listen'}
|
||||
|
||||
# Add expired cache entry
|
||||
legacy_manager._legacy_cache[conversation_id] = LegacyCacheEntry(
|
||||
True,
|
||||
expired_time, # This should be considered expired
|
||||
)
|
||||
|
||||
legacy_manager.conversation_manager._get_runtime.return_value = runtime
|
||||
|
||||
result = await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
|
||||
assert result is False # Runtime indicates non-legacy
|
||||
# Should call _get_runtime since cache is expired
|
||||
legacy_manager.conversation_manager._get_runtime.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
# Should update cache with new result
|
||||
assert legacy_manager._legacy_cache[conversation_id].is_legacy is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_exactly_at_timeout(self, legacy_manager):
|
||||
"""Test with cache entry exactly at timeout boundary."""
|
||||
conversation_id = 'test_conversation'
|
||||
timeout_time = time.time() - _LEGACY_ENTRY_TIMEOUT_SECONDS
|
||||
runtime = {'command': 'python -m openhands.server.listen'}
|
||||
|
||||
# Add cache entry exactly at timeout
|
||||
legacy_manager._legacy_cache[conversation_id] = LegacyCacheEntry(
|
||||
True, timeout_time
|
||||
)
|
||||
|
||||
legacy_manager.conversation_manager._get_runtime.return_value = runtime
|
||||
|
||||
result = await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
|
||||
# Should treat as expired and fetch from runtime
|
||||
assert result is False
|
||||
legacy_manager.conversation_manager._get_runtime.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_returns_none(self, legacy_manager):
|
||||
"""Test when runtime returns None."""
|
||||
conversation_id = 'test_conversation'
|
||||
|
||||
legacy_manager.conversation_manager._get_runtime.return_value = None
|
||||
|
||||
result = await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
|
||||
assert result is False
|
||||
# Should cache the result
|
||||
assert conversation_id in legacy_manager._legacy_cache
|
||||
assert legacy_manager._legacy_cache[conversation_id].is_legacy is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_called_on_each_invocation(self, legacy_manager):
|
||||
"""Test that cleanup is called on each invocation."""
|
||||
conversation_id = 'test_conversation'
|
||||
runtime = {'command': 'test'}
|
||||
|
||||
legacy_manager.conversation_manager._get_runtime.return_value = runtime
|
||||
|
||||
# Mock the cleanup method to verify it's called
|
||||
with patch.object(
|
||||
legacy_manager, '_cleanup_expired_cache_entries'
|
||||
) as mock_cleanup:
|
||||
await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
mock_cleanup.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_conversations_cached_independently(self, legacy_manager):
|
||||
"""Test that multiple conversations are cached independently."""
|
||||
conv1 = 'conversation_1'
|
||||
conv2 = 'conversation_2'
|
||||
|
||||
runtime1 = {'command': 'old_command'} # Legacy
|
||||
runtime2 = {'command': 'python -m openhands.server.listen'} # Non-legacy
|
||||
|
||||
# Mock to return different runtimes based on conversation_id
|
||||
def mock_get_runtime(conversation_id):
|
||||
if conversation_id == conv1:
|
||||
return runtime1
|
||||
return runtime2
|
||||
|
||||
legacy_manager.conversation_manager._get_runtime.side_effect = mock_get_runtime
|
||||
|
||||
result1 = await legacy_manager.should_start_in_legacy_mode(conv1)
|
||||
result2 = await legacy_manager.should_start_in_legacy_mode(conv2)
|
||||
|
||||
assert result1 is True
|
||||
assert result2 is False
|
||||
|
||||
# Both should be cached
|
||||
assert conv1 in legacy_manager._legacy_cache
|
||||
assert conv2 in legacy_manager._legacy_cache
|
||||
assert legacy_manager._legacy_cache[conv1].is_legacy is True
|
||||
assert legacy_manager._legacy_cache[conv2].is_legacy is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_timestamp_updated_on_refresh(self, legacy_manager):
|
||||
"""Test that cache timestamp is updated when entry is refreshed."""
|
||||
conversation_id = 'test_conversation'
|
||||
old_time = time.time() - _LEGACY_ENTRY_TIMEOUT_SECONDS - 1
|
||||
runtime = {'command': 'test'}
|
||||
|
||||
# Add expired entry
|
||||
legacy_manager._legacy_cache[conversation_id] = LegacyCacheEntry(True, old_time)
|
||||
legacy_manager.conversation_manager._get_runtime.return_value = runtime
|
||||
|
||||
# Record time before call
|
||||
before_call = time.time()
|
||||
await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
after_call = time.time()
|
||||
|
||||
# Timestamp should be updated
|
||||
cached_entry = legacy_manager._legacy_cache[conversation_id]
|
||||
assert cached_entry.timestamp >= before_call
|
||||
assert cached_entry.timestamp <= after_call
|
||||
|
||||
|
||||
class TestLegacyConversationManagerIntegration:
|
||||
"""Integration tests for LegacyConversationManager."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_instance_creates_proper_manager(
|
||||
self,
|
||||
mock_sio,
|
||||
mock_config,
|
||||
mock_file_store,
|
||||
mock_server_config,
|
||||
mock_monitoring_listener,
|
||||
):
|
||||
"""Test that get_instance creates a properly configured manager."""
|
||||
with patch(
|
||||
'server.legacy_conversation_manager.SaasNestedConversationManager'
|
||||
) as mock_saas, patch(
|
||||
'server.legacy_conversation_manager.ClusteredConversationManager'
|
||||
) as mock_clustered:
|
||||
mock_saas.get_instance.return_value = MagicMock()
|
||||
mock_clustered.get_instance.return_value = MagicMock()
|
||||
|
||||
manager = LegacyConversationManager.get_instance(
|
||||
mock_sio,
|
||||
mock_config,
|
||||
mock_file_store,
|
||||
mock_server_config,
|
||||
mock_monitoring_listener,
|
||||
)
|
||||
|
||||
assert isinstance(manager, LegacyConversationManager)
|
||||
assert manager.sio == mock_sio
|
||||
assert manager.config == mock_config
|
||||
assert manager.file_store == mock_file_store
|
||||
assert manager.server_config == mock_server_config
|
||||
|
||||
# Verify that both nested managers are created
|
||||
mock_saas.get_instance.assert_called_once()
|
||||
mock_clustered.get_instance.assert_called_once()
|
||||
|
||||
def test_legacy_cache_initialized_empty(self, legacy_manager):
|
||||
"""Test that legacy cache is initialized as empty dict."""
|
||||
assert isinstance(legacy_manager._legacy_cache, dict)
|
||||
assert len(legacy_manager._legacy_cache) == 0
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Test edge cases and error scenarios."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_runtime_raises_exception(self, legacy_manager):
|
||||
"""Test behavior when _get_runtime raises an exception."""
|
||||
conversation_id = 'test_conversation'
|
||||
|
||||
legacy_manager.conversation_manager._get_runtime.side_effect = Exception(
|
||||
'Runtime error'
|
||||
)
|
||||
|
||||
# Should propagate the exception
|
||||
with pytest.raises(Exception, match='Runtime error'):
|
||||
await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_very_large_cache(self, legacy_manager):
|
||||
"""Test behavior with a large number of cache entries."""
|
||||
current_time = time.time()
|
||||
|
||||
# Add many cache entries
|
||||
for i in range(1000):
|
||||
legacy_manager._legacy_cache[f'conversation_{i}'] = LegacyCacheEntry(
|
||||
i % 2 == 0, current_time - i
|
||||
)
|
||||
|
||||
# This should work without issues
|
||||
await legacy_manager.should_start_in_legacy_mode('new_conversation')
|
||||
|
||||
# Should have added one more entry
|
||||
assert len(legacy_manager._legacy_cache) == 1001
|
||||
|
||||
def test_cleanup_with_concurrent_modifications(self, legacy_manager):
|
||||
"""Test cleanup behavior when cache is modified during cleanup."""
|
||||
current_time = time.time()
|
||||
expired_time = current_time - _LEGACY_ENTRY_TIMEOUT_SECONDS - 1
|
||||
|
||||
# Add expired entries
|
||||
legacy_manager._legacy_cache = {
|
||||
f'conversation_{i}': LegacyCacheEntry(True, expired_time) for i in range(10)
|
||||
}
|
||||
|
||||
# This should work without raising exceptions
|
||||
legacy_manager._cleanup_expired_cache_entries()
|
||||
|
||||
# All entries should be removed
|
||||
assert len(legacy_manager._legacy_cache) == 0
|
||||
@@ -6,7 +6,6 @@ from server.constants import (
|
||||
CURRENT_USER_SETTINGS_VERSION,
|
||||
LITE_LLM_API_URL,
|
||||
LITE_LLM_TEAM_ID,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.user_settings import UserSettings
|
||||
@@ -394,11 +393,10 @@ async def test_create_user_in_lite_llm(settings_store):
|
||||
mock_response = AsyncMock()
|
||||
mock_response.is_success = True
|
||||
mock_client.post.return_value = mock_response
|
||||
test_model = 'custom-model/test-model'
|
||||
|
||||
# Test with email
|
||||
await settings_store._create_user_in_lite_llm(
|
||||
mock_client, 'test@example.com', 50, 10, test_model
|
||||
mock_client, 'test@example.com', 50, 10
|
||||
)
|
||||
|
||||
# Get the actual call arguments
|
||||
@@ -414,11 +412,11 @@ async def test_create_user_in_lite_llm(settings_store):
|
||||
assert call_args['json']['auto_create_key'] is True
|
||||
assert call_args['json']['send_invite_email'] is False
|
||||
assert call_args['json']['metadata']['version'] == CURRENT_USER_SETTINGS_VERSION
|
||||
assert call_args['json']['metadata']['model'] == test_model
|
||||
assert 'model' in call_args['json']['metadata']
|
||||
|
||||
# Test with None email
|
||||
mock_client.post.reset_mock()
|
||||
await settings_store._create_user_in_lite_llm(mock_client, None, 25, 15, test_model)
|
||||
await settings_store._create_user_in_lite_llm(mock_client, None, 25, 15)
|
||||
|
||||
# Get the actual call arguments
|
||||
call_args = mock_client.post.call_args[1]
|
||||
@@ -433,12 +431,12 @@ async def test_create_user_in_lite_llm(settings_store):
|
||||
assert call_args['json']['auto_create_key'] is True
|
||||
assert call_args['json']['send_invite_email'] is False
|
||||
assert call_args['json']['metadata']['version'] == CURRENT_USER_SETTINGS_VERSION
|
||||
assert call_args['json']['metadata']['model'] == test_model
|
||||
assert 'model' in call_args['json']['metadata']
|
||||
|
||||
# Verify response is returned correctly
|
||||
assert (
|
||||
await settings_store._create_user_in_lite_llm(
|
||||
mock_client, 'email@test.com', 30, 7, test_model
|
||||
mock_client, 'email@test.com', 30, 7
|
||||
)
|
||||
== mock_response
|
||||
)
|
||||
@@ -466,808 +464,3 @@ async def test_encryption(settings_store):
|
||||
# But we should be able to decrypt it when loading
|
||||
loaded_settings = await settings_store.load()
|
||||
assert loaded_settings.llm_api_key.get_secret_value() == 'secret_key'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default_preserves_custom_model(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User has a custom LLM model set
|
||||
custom_model = 'anthropic/claude-3-5-sonnet-20241022'
|
||||
settings = Settings(llm_model=custom_model)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
# Act: Update settings with LiteLLM defaults
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Custom model is preserved
|
||||
assert updated_settings is not None
|
||||
assert updated_settings.llm_model == custom_model
|
||||
assert updated_settings.agent == 'CodeActAgent'
|
||||
assert updated_settings.llm_api_key is not None
|
||||
|
||||
# Assert: LiteLLM metadata contains user's custom model
|
||||
call_args = mock_litellm_api.return_value.__aenter__.return_value.post.call_args[1]
|
||||
assert call_args['json']['metadata']['model'] == custom_model
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default_uses_default_when_no_model(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User has no model set (new user scenario)
|
||||
settings = Settings()
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'newuser@example.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
# Act: Update settings with LiteLLM defaults
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Default model is assigned
|
||||
assert updated_settings is not None
|
||||
expected_default = get_default_litellm_model()
|
||||
assert updated_settings.llm_model == expected_default
|
||||
assert updated_settings.agent == 'CodeActAgent'
|
||||
|
||||
# Assert: LiteLLM metadata contains default model
|
||||
call_args = mock_litellm_api.return_value.__aenter__.return_value.post.call_args[1]
|
||||
assert call_args['json']['metadata']['model'] == expected_default
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default_handles_empty_string_model(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User has empty string as model (edge case)
|
||||
settings = Settings(llm_model='')
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
# Act: Update settings with LiteLLM defaults
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Default model is used (empty string treated as no model)
|
||||
assert updated_settings is not None
|
||||
expected_default = get_default_litellm_model()
|
||||
assert updated_settings.llm_model == expected_default
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default_handles_whitespace_model(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User has whitespace-only model (edge case)
|
||||
settings = Settings(llm_model=' ')
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
# Act: Update settings with LiteLLM defaults
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Default model is used (whitespace treated as no model)
|
||||
assert updated_settings is not None
|
||||
expected_default = get_default_litellm_model()
|
||||
assert updated_settings.llm_model == expected_default
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default_preserves_custom_api_key(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User has a custom API key and custom model (so has_custom=True)
|
||||
custom_api_key = 'sk-custom-user-api-key-12345'
|
||||
custom_model = 'gpt-4'
|
||||
settings = Settings(llm_model=custom_model, llm_api_key=SecretStr(custom_api_key))
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
# Act: Update settings with LiteLLM defaults
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Custom API key is preserved when user has custom settings
|
||||
assert updated_settings is not None
|
||||
assert updated_settings.llm_api_key.get_secret_value() == custom_api_key
|
||||
assert updated_settings.llm_api_key.get_secret_value() != 'test_api_key'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default_preserves_custom_base_url(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User has a custom base URL
|
||||
custom_base_url = 'https://api.custom-llm-provider.com/v1'
|
||||
settings = Settings(llm_base_url=custom_base_url)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
# Act: Update settings with LiteLLM defaults
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Custom base URL is preserved
|
||||
assert updated_settings is not None
|
||||
assert updated_settings.llm_base_url == custom_base_url
|
||||
assert updated_settings.llm_base_url != LITE_LLM_API_URL
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default_preserves_custom_api_key_and_base_url(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User has both custom API key and base URL
|
||||
custom_api_key = 'sk-custom-user-api-key-67890'
|
||||
custom_base_url = 'https://api.another-llm-provider.com/v1'
|
||||
custom_model = 'openai/gpt-4'
|
||||
settings = Settings(
|
||||
llm_model=custom_model,
|
||||
llm_api_key=SecretStr(custom_api_key),
|
||||
llm_base_url=custom_base_url,
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
# Act: Update settings with LiteLLM defaults
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: All custom settings are preserved
|
||||
assert updated_settings is not None
|
||||
assert updated_settings.llm_model == custom_model
|
||||
assert updated_settings.llm_api_key.get_secret_value() == custom_api_key
|
||||
assert updated_settings.llm_base_url == custom_base_url
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default_uses_default_api_key_when_none(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User has no API key set
|
||||
settings = Settings(llm_api_key=None)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
# Act: Update settings with LiteLLM defaults
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Default LiteLLM API key is assigned
|
||||
assert updated_settings is not None
|
||||
assert updated_settings.llm_api_key is not None
|
||||
assert updated_settings.llm_api_key.get_secret_value() == 'test_api_key'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default_uses_default_base_url_when_none(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User has no base URL set
|
||||
settings = Settings(llm_base_url=None)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://test.url'),
|
||||
):
|
||||
# Act: Update settings with LiteLLM defaults
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Default LiteLLM base URL is assigned (using mocked value)
|
||||
assert updated_settings is not None
|
||||
assert updated_settings.llm_base_url == 'http://test.url'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default_handles_empty_api_key(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User has empty string as API key (edge case)
|
||||
settings = Settings(llm_api_key=SecretStr(''))
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
# Act: Update settings with LiteLLM defaults
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Default API key is used (empty string treated as no key)
|
||||
assert updated_settings is not None
|
||||
assert updated_settings.llm_api_key.get_secret_value() == 'test_api_key'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default_handles_empty_base_url(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User has empty string as base URL (edge case)
|
||||
settings = Settings(llm_base_url='')
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://test.url'),
|
||||
):
|
||||
# Act: Update settings with LiteLLM defaults
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Default base URL is used (empty string treated as no URL)
|
||||
assert updated_settings is not None
|
||||
assert updated_settings.llm_base_url == 'http://test.url'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default_handles_whitespace_api_key(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User has whitespace-only API key (edge case)
|
||||
settings = Settings(llm_api_key=SecretStr(' '))
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
# Act: Update settings with LiteLLM defaults
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Default API key is used (whitespace treated as no key)
|
||||
assert updated_settings is not None
|
||||
assert updated_settings.llm_api_key.get_secret_value() == 'test_api_key'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default_handles_whitespace_base_url(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User has whitespace-only base URL (edge case)
|
||||
settings = Settings(llm_base_url=' ')
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://test.url'),
|
||||
):
|
||||
# Act: Update settings with LiteLLM defaults
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Default base URL is used (whitespace treated as no URL)
|
||||
assert updated_settings is not None
|
||||
assert updated_settings.llm_base_url == 'http://test.url'
|
||||
|
||||
|
||||
# Tests for version migration and helper methods
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_custom_settings_with_custom_base_url(settings_store):
|
||||
# Arrange: User with custom base URL (BYOR)
|
||||
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
|
||||
settings = Settings(llm_base_url='http://custom.url')
|
||||
|
||||
# Act: Check if has custom settings
|
||||
has_custom = settings_store._has_custom_settings(settings, None)
|
||||
|
||||
# Assert: Custom base URL detected
|
||||
assert has_custom is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_custom_settings_with_default_base_url(settings_store):
|
||||
# Arrange: User with default base URL
|
||||
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
|
||||
settings = Settings(llm_base_url='http://default.url')
|
||||
|
||||
# Act: Check if has custom settings
|
||||
has_custom = settings_store._has_custom_settings(settings, None)
|
||||
|
||||
# Assert: No custom settings (no model set)
|
||||
assert has_custom is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_custom_settings_with_no_model(settings_store):
|
||||
# Arrange: User with no model set
|
||||
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
|
||||
settings = Settings(llm_model=None, llm_base_url='http://default.url')
|
||||
|
||||
# Act: Check if has custom settings
|
||||
has_custom = settings_store._has_custom_settings(settings, None)
|
||||
|
||||
# Assert: No custom settings (using defaults)
|
||||
assert has_custom is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_custom_settings_with_empty_model(settings_store):
|
||||
# Arrange: User with empty model
|
||||
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
|
||||
settings = Settings(llm_model='', llm_base_url='http://default.url')
|
||||
|
||||
# Act: Check if has custom settings
|
||||
has_custom = settings_store._has_custom_settings(settings, None)
|
||||
|
||||
# Assert: No custom settings (empty treated as no model)
|
||||
assert has_custom is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_custom_settings_with_whitespace_model(settings_store):
|
||||
# Arrange: User with whitespace-only model
|
||||
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
|
||||
settings = Settings(llm_model=' ', llm_base_url='http://default.url')
|
||||
|
||||
# Act: Check if has custom settings
|
||||
has_custom = settings_store._has_custom_settings(settings, None)
|
||||
|
||||
# Assert: No custom settings (whitespace treated as no model)
|
||||
assert has_custom is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_custom_settings_with_custom_model(settings_store):
|
||||
# Arrange: User with custom model
|
||||
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
|
||||
settings = Settings(llm_model='gpt-4', llm_base_url='http://default.url')
|
||||
|
||||
# Act: Check if has custom settings
|
||||
has_custom = settings_store._has_custom_settings(settings, None)
|
||||
|
||||
# Assert: Custom model detected
|
||||
assert has_custom is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_custom_settings_matches_old_default_model(settings_store):
|
||||
# Arrange: User with old version and model matching old default
|
||||
with (
|
||||
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'),
|
||||
patch('server.constants.CURRENT_USER_SETTINGS_VERSION', 5),
|
||||
patch(
|
||||
'server.constants.USER_SETTINGS_VERSION_TO_MODEL',
|
||||
{1: 'claude-3-5-sonnet-20241022'},
|
||||
),
|
||||
):
|
||||
settings = Settings(
|
||||
llm_model='litellm_proxy/prod/claude-3-5-sonnet-20241022',
|
||||
llm_base_url='http://default.url',
|
||||
)
|
||||
|
||||
# Act: Check if has custom settings
|
||||
has_custom = settings_store._has_custom_settings(settings, 1)
|
||||
|
||||
# Assert: Matches old default, so not custom
|
||||
assert has_custom is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_custom_settings_matches_old_default_by_base_name(settings_store):
|
||||
# Arrange: User with old version and model matching old default by base name
|
||||
with (
|
||||
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'),
|
||||
patch('server.constants.CURRENT_USER_SETTINGS_VERSION', 5),
|
||||
patch(
|
||||
'server.constants.USER_SETTINGS_VERSION_TO_MODEL',
|
||||
{1: 'claude-3-5-sonnet-20241022'},
|
||||
),
|
||||
):
|
||||
settings = Settings(
|
||||
llm_model='anthropic/claude-3-5-sonnet-20241022',
|
||||
llm_base_url='http://default.url',
|
||||
)
|
||||
|
||||
# Act: Check if has custom settings
|
||||
has_custom = settings_store._has_custom_settings(settings, 1)
|
||||
|
||||
# Assert: Matches old default by base name, so not custom
|
||||
assert has_custom is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_custom_settings_with_old_version_but_custom_model(settings_store):
|
||||
# Arrange: User with old version but custom model
|
||||
with (
|
||||
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'),
|
||||
patch('server.constants.CURRENT_USER_SETTINGS_VERSION', 5),
|
||||
patch(
|
||||
'server.constants.USER_SETTINGS_VERSION_TO_MODEL',
|
||||
{1: 'claude-3-5-sonnet-20241022'},
|
||||
),
|
||||
):
|
||||
settings = Settings(llm_model='gpt-4', llm_base_url='http://default.url')
|
||||
|
||||
# Act: Check if has custom settings
|
||||
has_custom = settings_store._has_custom_settings(settings, 1)
|
||||
|
||||
# Assert: Custom model detected
|
||||
assert has_custom is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_custom_settings_with_current_version(settings_store):
|
||||
# Arrange: User with current version
|
||||
with (
|
||||
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'),
|
||||
patch('server.constants.CURRENT_USER_SETTINGS_VERSION', 5),
|
||||
patch(
|
||||
'server.constants.USER_SETTINGS_VERSION_TO_MODEL',
|
||||
{1: 'claude-3-5-sonnet-20241022', 5: 'claude-opus-4-5-20251101'},
|
||||
),
|
||||
):
|
||||
settings = Settings(
|
||||
llm_model='claude-3-5-sonnet-20241022', llm_base_url='http://default.url'
|
||||
)
|
||||
|
||||
# Act: Check if has custom settings
|
||||
has_custom = settings_store._has_custom_settings(settings, 5)
|
||||
|
||||
# Assert: Current version, so model is custom (not old default)
|
||||
assert has_custom is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_custom_settings_with_none_version(settings_store):
|
||||
# Arrange: User with no version
|
||||
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
|
||||
settings = Settings(
|
||||
llm_model='claude-3-5-sonnet-20241022', llm_base_url='http://default.url'
|
||||
)
|
||||
|
||||
# Act: Check if has custom settings
|
||||
has_custom = settings_store._has_custom_settings(settings, None)
|
||||
|
||||
# Assert: No version, so model is custom
|
||||
assert has_custom is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_custom_settings_with_invalid_version(settings_store):
|
||||
# Arrange: User with invalid version
|
||||
with (
|
||||
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'),
|
||||
patch('server.constants.CURRENT_USER_SETTINGS_VERSION', 5),
|
||||
patch(
|
||||
'server.constants.USER_SETTINGS_VERSION_TO_MODEL',
|
||||
{1: 'claude-3-5-sonnet-20241022'},
|
||||
),
|
||||
):
|
||||
settings = Settings(
|
||||
llm_model='claude-3-5-sonnet-20241022', llm_base_url='http://default.url'
|
||||
)
|
||||
|
||||
# Act: Check if has custom settings
|
||||
has_custom = settings_store._has_custom_settings(settings, 99)
|
||||
|
||||
# Assert: Invalid version, so model is custom
|
||||
assert has_custom is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_custom_settings_normalizes_whitespace(settings_store):
|
||||
# Arrange: Settings with whitespace in values
|
||||
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
|
||||
settings = Settings(
|
||||
llm_model=' claude-3-5-sonnet-20241022 ',
|
||||
llm_base_url=' http://default.url ',
|
||||
)
|
||||
|
||||
# Act: Check if has custom settings
|
||||
has_custom = settings_store._has_custom_settings(settings, None)
|
||||
|
||||
# Assert: Whitespace is normalized, custom model detected
|
||||
assert has_custom is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_upgrades_user_from_old_defaults(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User with old version using old defaults
|
||||
old_version = 1
|
||||
old_model = 'litellm_proxy/prod/claude-3-5-sonnet-20241022'
|
||||
settings = Settings(llm_model=old_model, llm_base_url=LITE_LLM_API_URL)
|
||||
|
||||
# Use a consistent test URL
|
||||
test_base_url = 'http://test.url'
|
||||
|
||||
with (
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
patch(
|
||||
'server.constants.USER_SETTINGS_VERSION_TO_MODEL',
|
||||
{1: 'claude-3-5-sonnet-20241022', 5: 'claude-opus-4-5-20251101'},
|
||||
),
|
||||
patch(
|
||||
'storage.saas_settings_store.USER_SETTINGS_VERSION_TO_MODEL',
|
||||
{1: 'claude-3-5-sonnet-20241022', 5: 'claude-opus-4-5-20251101'},
|
||||
),
|
||||
patch('server.constants.CURRENT_USER_SETTINGS_VERSION', 5),
|
||||
patch('storage.saas_settings_store.CURRENT_USER_SETTINGS_VERSION', 5),
|
||||
patch('storage.saas_settings_store.LITE_LLM_API_URL', test_base_url),
|
||||
patch(
|
||||
'storage.saas_settings_store.get_default_litellm_model',
|
||||
return_value='litellm_proxy/prod/claude-opus-4-5-20251101',
|
||||
),
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
):
|
||||
# Create existing user settings with old version
|
||||
with session_maker() as session:
|
||||
existing_settings = UserSettings(
|
||||
keycloak_user_id=settings_store.user_id,
|
||||
user_version=old_version,
|
||||
llm_model=old_model,
|
||||
llm_base_url=test_base_url,
|
||||
)
|
||||
session.add(existing_settings)
|
||||
session.commit()
|
||||
|
||||
# Update settings to use test_base_url
|
||||
# Set user_version to match the database so _has_custom_settings can detect old defaults
|
||||
settings = Settings(
|
||||
llm_model=old_model, llm_base_url=test_base_url, user_version=old_version
|
||||
)
|
||||
|
||||
# Act: Update settings
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Settings upgraded to new defaults
|
||||
assert updated_settings is not None
|
||||
assert (
|
||||
updated_settings.llm_model == 'litellm_proxy/prod/claude-opus-4-5-20251101'
|
||||
)
|
||||
assert updated_settings.llm_base_url == test_base_url
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_preserves_custom_settings_during_upgrade(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User with old version but custom settings
|
||||
old_version = 1
|
||||
custom_model = 'gpt-4'
|
||||
custom_base_url = 'http://custom.url'
|
||||
settings = Settings(llm_model=custom_model, llm_base_url=custom_base_url)
|
||||
|
||||
with (
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
patch(
|
||||
'server.constants.USER_SETTINGS_VERSION_TO_MODEL',
|
||||
{1: 'claude-3-5-sonnet-20241022'},
|
||||
),
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
):
|
||||
# Create existing user settings with old version
|
||||
with session_maker() as session:
|
||||
existing_settings = UserSettings(
|
||||
keycloak_user_id=settings_store.user_id,
|
||||
user_version=old_version,
|
||||
llm_model=custom_model,
|
||||
llm_base_url=custom_base_url,
|
||||
)
|
||||
session.add(existing_settings)
|
||||
session.commit()
|
||||
|
||||
# Act: Update settings
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Custom settings preserved
|
||||
assert updated_settings is not None
|
||||
assert updated_settings.llm_model == custom_model
|
||||
assert updated_settings.llm_base_url == custom_base_url
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_migrates_billing_margin_v3_to_v4(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User with version 3 and billing margin
|
||||
old_version = 3
|
||||
billing_margin = 2.0
|
||||
max_budget = 10.0
|
||||
spend = 5.0
|
||||
|
||||
settings = Settings()
|
||||
|
||||
mock_get_response = AsyncMock()
|
||||
mock_get_response.is_success = True
|
||||
mock_get_response.json = MagicMock(
|
||||
return_value={'user_info': {'max_budget': max_budget, 'spend': spend}}
|
||||
)
|
||||
|
||||
mock_post_response = AsyncMock()
|
||||
mock_post_response.is_success = True
|
||||
mock_post_response.json = MagicMock(return_value={'key': 'test_api_key'})
|
||||
|
||||
with (
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
patch('httpx.AsyncClient') as mock_client,
|
||||
):
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = (
|
||||
mock_get_response
|
||||
)
|
||||
mock_client.return_value.__aenter__.return_value.post.return_value = (
|
||||
mock_post_response
|
||||
)
|
||||
|
||||
# Create existing user settings with version 3 and billing margin
|
||||
with session_maker() as session:
|
||||
existing_settings = UserSettings(
|
||||
keycloak_user_id=settings_store.user_id,
|
||||
user_version=old_version,
|
||||
billing_margin=billing_margin,
|
||||
)
|
||||
session.add(existing_settings)
|
||||
session.commit()
|
||||
|
||||
# Act: Update settings
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Settings updated
|
||||
assert updated_settings is not None
|
||||
|
||||
# Assert: Billing margin applied to budget
|
||||
call_args = mock_client.return_value.__aenter__.return_value.post.call_args[1]
|
||||
assert call_args['json']['max_budget'] == max_budget * billing_margin
|
||||
assert call_args['json']['spend'] == spend * billing_margin
|
||||
|
||||
# Assert: Billing margin reset to 1.0
|
||||
with session_maker() as session:
|
||||
updated_user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == settings_store.user_id)
|
||||
.first()
|
||||
)
|
||||
assert updated_user_settings.billing_margin == 1.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_skips_billing_margin_migration_when_already_v4(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User with version 4
|
||||
version = 4
|
||||
billing_margin = 2.0
|
||||
max_budget = 10.0
|
||||
spend = 5.0
|
||||
|
||||
settings = Settings()
|
||||
|
||||
mock_get_response = AsyncMock()
|
||||
mock_get_response.is_success = True
|
||||
mock_get_response.json = MagicMock(
|
||||
return_value={'user_info': {'max_budget': max_budget, 'spend': spend}}
|
||||
)
|
||||
|
||||
mock_post_response = AsyncMock()
|
||||
mock_post_response.is_success = True
|
||||
mock_post_response.json = MagicMock(return_value={'key': 'test_api_key'})
|
||||
|
||||
with (
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
patch('httpx.AsyncClient') as mock_client,
|
||||
):
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = (
|
||||
mock_get_response
|
||||
)
|
||||
mock_client.return_value.__aenter__.return_value.post.return_value = (
|
||||
mock_post_response
|
||||
)
|
||||
|
||||
# Create existing user settings with version 4
|
||||
with session_maker() as session:
|
||||
existing_settings = UserSettings(
|
||||
keycloak_user_id=settings_store.user_id,
|
||||
user_version=version,
|
||||
billing_margin=billing_margin,
|
||||
)
|
||||
session.add(existing_settings)
|
||||
session.commit()
|
||||
|
||||
# Act: Update settings
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Settings updated
|
||||
assert updated_settings is not None
|
||||
|
||||
# Assert: Billing margin NOT applied (version >= 4)
|
||||
call_args = mock_client.return_value.__aenter__.return_value.post.call_args[1]
|
||||
assert call_args['json']['max_budget'] == max_budget
|
||||
assert call_args['json']['spend'] == spend
|
||||
|
||||
@@ -5,12 +5,7 @@ import jwt
|
||||
import pytest
|
||||
from fastapi import Request
|
||||
from pydantic import SecretStr
|
||||
from server.auth.auth_error import (
|
||||
AuthError,
|
||||
BearerTokenError,
|
||||
CookieError,
|
||||
NoCredentialsError,
|
||||
)
|
||||
from server.auth.auth_error import BearerTokenError, CookieError, NoCredentialsError
|
||||
from server.auth.saas_user_auth import (
|
||||
SaasUserAuth,
|
||||
get_api_key_from_header,
|
||||
@@ -540,209 +535,3 @@ def test_get_api_key_from_header_with_invalid_authorization_format():
|
||||
|
||||
# Assert that None was returned
|
||||
assert api_key is None
|
||||
|
||||
|
||||
def test_get_api_key_from_header_with_x_access_token():
|
||||
"""Test that get_api_key_from_header extracts API key from X-Access-Token header."""
|
||||
# Create a mock request with X-Access-Token header
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers = {'X-Access-Token': 'access_token_key'}
|
||||
|
||||
# Call the function
|
||||
api_key = get_api_key_from_header(mock_request)
|
||||
|
||||
# Assert that the API key was correctly extracted
|
||||
assert api_key == 'access_token_key'
|
||||
|
||||
|
||||
def test_get_api_key_from_header_priority_authorization_over_x_access_token():
|
||||
"""Test that Authorization header takes priority over X-Access-Token header."""
|
||||
# Create a mock request with both headers
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers = {
|
||||
'Authorization': 'Bearer auth_api_key',
|
||||
'X-Access-Token': 'access_token_key',
|
||||
}
|
||||
|
||||
# Call the function
|
||||
api_key = get_api_key_from_header(mock_request)
|
||||
|
||||
# Assert that the API key from Authorization header was used
|
||||
assert api_key == 'auth_api_key'
|
||||
|
||||
|
||||
def test_get_api_key_from_header_priority_x_session_over_x_access_token():
|
||||
"""Test that X-Session-API-Key header takes priority over X-Access-Token header."""
|
||||
# Create a mock request with both headers
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers = {
|
||||
'X-Session-API-Key': 'session_api_key',
|
||||
'X-Access-Token': 'access_token_key',
|
||||
}
|
||||
|
||||
# Call the function
|
||||
api_key = get_api_key_from_header(mock_request)
|
||||
|
||||
# Assert that the API key from X-Session-API-Key header was used
|
||||
assert api_key == 'session_api_key'
|
||||
|
||||
|
||||
def test_get_api_key_from_header_all_three_headers():
|
||||
"""Test header priority when all three headers are present."""
|
||||
# Create a mock request with all three headers
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers = {
|
||||
'Authorization': 'Bearer auth_api_key',
|
||||
'X-Session-API-Key': 'session_api_key',
|
||||
'X-Access-Token': 'access_token_key',
|
||||
}
|
||||
|
||||
# Call the function
|
||||
api_key = get_api_key_from_header(mock_request)
|
||||
|
||||
# Assert that the API key from Authorization header was used (highest priority)
|
||||
assert api_key == 'auth_api_key'
|
||||
|
||||
|
||||
def test_get_api_key_from_header_invalid_authorization_fallback_to_x_access_token():
|
||||
"""Test that invalid Authorization header falls back to X-Access-Token."""
|
||||
# Create a mock request with invalid Authorization header and X-Access-Token
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers = {
|
||||
'Authorization': 'InvalidFormat api_key',
|
||||
'X-Access-Token': 'access_token_key',
|
||||
}
|
||||
|
||||
# Call the function
|
||||
api_key = get_api_key_from_header(mock_request)
|
||||
|
||||
# Assert that the API key from X-Access-Token header was used
|
||||
assert api_key == 'access_token_key'
|
||||
|
||||
|
||||
def test_get_api_key_from_header_empty_headers():
|
||||
"""Test that empty header values are handled correctly."""
|
||||
# Create a mock request with empty header values
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers = {
|
||||
'Authorization': '',
|
||||
'X-Session-API-Key': '',
|
||||
'X-Access-Token': 'access_token_key',
|
||||
}
|
||||
|
||||
# Call the function
|
||||
api_key = get_api_key_from_header(mock_request)
|
||||
|
||||
# Assert that the API key from X-Access-Token header was used
|
||||
assert api_key == 'access_token_key'
|
||||
|
||||
|
||||
def test_get_api_key_from_header_bearer_with_empty_token():
|
||||
"""Test that Bearer header with empty token falls back to other headers."""
|
||||
# Create a mock request with Bearer header with empty token
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers = {
|
||||
'Authorization': 'Bearer ',
|
||||
'X-Access-Token': 'access_token_key',
|
||||
}
|
||||
|
||||
# Call the function
|
||||
api_key = get_api_key_from_header(mock_request)
|
||||
|
||||
# Assert that empty string from Bearer is returned (current behavior)
|
||||
# This tests the current implementation behavior
|
||||
assert api_key == ''
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_saas_user_auth_from_signed_token_blocked_domain(mock_config):
|
||||
"""Test that saas_user_auth_from_signed_token raises AuthError when email domain is blocked."""
|
||||
# Arrange
|
||||
access_payload = {
|
||||
'sub': 'test_user_id',
|
||||
'exp': int(time.time()) + 3600,
|
||||
'email': 'user@colsch.us',
|
||||
'email_verified': True,
|
||||
}
|
||||
access_token = jwt.encode(access_payload, 'access_secret', algorithm='HS256')
|
||||
|
||||
token_payload = {
|
||||
'access_token': access_token,
|
||||
'refresh_token': 'test_refresh_token',
|
||||
}
|
||||
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
|
||||
|
||||
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = True
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(AuthError) as exc_info:
|
||||
await saas_user_auth_from_signed_token(signed_token)
|
||||
|
||||
assert 'email domain is not allowed' in str(exc_info.value)
|
||||
mock_domain_blocker.is_domain_blocked.assert_called_once_with('user@colsch.us')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_saas_user_auth_from_signed_token_allowed_domain(mock_config):
|
||||
"""Test that saas_user_auth_from_signed_token succeeds when email domain is not blocked."""
|
||||
# Arrange
|
||||
access_payload = {
|
||||
'sub': 'test_user_id',
|
||||
'exp': int(time.time()) + 3600,
|
||||
'email': 'user@example.com',
|
||||
'email_verified': True,
|
||||
}
|
||||
access_token = jwt.encode(access_payload, 'access_secret', algorithm='HS256')
|
||||
|
||||
token_payload = {
|
||||
'access_token': access_token,
|
||||
'refresh_token': 'test_refresh_token',
|
||||
}
|
||||
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
|
||||
|
||||
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
# Act
|
||||
result = await saas_user_auth_from_signed_token(signed_token)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, SaasUserAuth)
|
||||
assert result.user_id == 'test_user_id'
|
||||
assert result.email == 'user@example.com'
|
||||
mock_domain_blocker.is_domain_blocked.assert_called_once_with(
|
||||
'user@example.com'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_saas_user_auth_from_signed_token_domain_blocking_inactive(mock_config):
|
||||
"""Test that saas_user_auth_from_signed_token succeeds when domain blocking is not active."""
|
||||
# Arrange
|
||||
access_payload = {
|
||||
'sub': 'test_user_id',
|
||||
'exp': int(time.time()) + 3600,
|
||||
'email': 'user@colsch.us',
|
||||
'email_verified': True,
|
||||
}
|
||||
access_token = jwt.encode(access_payload, 'access_secret', algorithm='HS256')
|
||||
|
||||
token_payload = {
|
||||
'access_token': access_token,
|
||||
'refresh_token': 'test_refresh_token',
|
||||
}
|
||||
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
|
||||
|
||||
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
|
||||
mock_domain_blocker.is_active.return_value = False
|
||||
|
||||
# Act
|
||||
result = await saas_user_auth_from_signed_token(signed_token)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, SaasUserAuth)
|
||||
assert result.user_id == 'test_user_id'
|
||||
mock_domain_blocker.is_domain_blocked.assert_not_called()
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""Tests for sharing package."""
|
||||
@@ -1,91 +0,0 @@
|
||||
"""Tests for public conversation models."""
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from server.sharing.shared_conversation_models import (
|
||||
SharedConversation,
|
||||
SharedConversationPage,
|
||||
SharedConversationSortOrder,
|
||||
)
|
||||
|
||||
|
||||
def test_public_conversation_creation():
|
||||
"""Test that SharedConversation can be created with all required fields."""
|
||||
conversation_id = uuid4()
|
||||
now = datetime.utcnow()
|
||||
|
||||
conversation = SharedConversation(
|
||||
id=conversation_id,
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox',
|
||||
title='Test Conversation',
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
selected_repository=None,
|
||||
parent_conversation_id=None,
|
||||
)
|
||||
|
||||
assert conversation.id == conversation_id
|
||||
assert conversation.title == 'Test Conversation'
|
||||
assert conversation.created_by_user_id == 'test_user'
|
||||
assert conversation.sandbox_id == 'test_sandbox'
|
||||
|
||||
|
||||
def test_public_conversation_page_creation():
|
||||
"""Test that SharedConversationPage can be created."""
|
||||
conversation_id = uuid4()
|
||||
now = datetime.utcnow()
|
||||
|
||||
conversation = SharedConversation(
|
||||
id=conversation_id,
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox',
|
||||
title='Test Conversation',
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
selected_repository=None,
|
||||
parent_conversation_id=None,
|
||||
)
|
||||
|
||||
page = SharedConversationPage(
|
||||
items=[conversation],
|
||||
next_page_id='next_page',
|
||||
)
|
||||
|
||||
assert len(page.items) == 1
|
||||
assert page.items[0].id == conversation_id
|
||||
assert page.next_page_id == 'next_page'
|
||||
|
||||
|
||||
def test_public_conversation_sort_order_enum():
|
||||
"""Test that SharedConversationSortOrder enum has expected values."""
|
||||
assert hasattr(SharedConversationSortOrder, 'CREATED_AT')
|
||||
assert hasattr(SharedConversationSortOrder, 'CREATED_AT_DESC')
|
||||
assert hasattr(SharedConversationSortOrder, 'UPDATED_AT')
|
||||
assert hasattr(SharedConversationSortOrder, 'UPDATED_AT_DESC')
|
||||
assert hasattr(SharedConversationSortOrder, 'TITLE')
|
||||
assert hasattr(SharedConversationSortOrder, 'TITLE_DESC')
|
||||
|
||||
|
||||
def test_public_conversation_optional_fields():
|
||||
"""Test that SharedConversation works with optional fields."""
|
||||
conversation_id = uuid4()
|
||||
parent_id = uuid4()
|
||||
now = datetime.utcnow()
|
||||
|
||||
conversation = SharedConversation(
|
||||
id=conversation_id,
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox',
|
||||
title='Test Conversation',
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
selected_repository='owner/repo',
|
||||
parent_conversation_id=parent_id,
|
||||
llm_model='gpt-4',
|
||||
)
|
||||
|
||||
assert conversation.selected_repository == 'owner/repo'
|
||||
assert conversation.parent_conversation_id == parent_id
|
||||
assert conversation.llm_model == 'gpt-4'
|
||||
@@ -1,430 +0,0 @@
|
||||
"""Tests for SharedConversationInfoService."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import AsyncGenerator
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from server.sharing.shared_conversation_models import (
|
||||
SharedConversationSortOrder,
|
||||
)
|
||||
from server.sharing.sql_shared_conversation_info_service import (
|
||||
SQLSharedConversationInfoService,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationInfo,
|
||||
)
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
SQLAppConversationInfoService,
|
||||
)
|
||||
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
|
||||
from openhands.app_server.utils.sql_utils import Base
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.sdk.llm import MetricsSnapshot
|
||||
from openhands.sdk.llm.utils.metrics import TokenUsage
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""Create an async SQLite engine for testing."""
|
||||
engine = create_async_engine(
|
||||
'sqlite+aiosqlite:///:memory:',
|
||||
poolclass=StaticPool,
|
||||
connect_args={'check_same_thread': False},
|
||||
echo=False,
|
||||
)
|
||||
|
||||
# Create all tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Create an async session for testing."""
|
||||
async_session_maker = async_sessionmaker(
|
||||
async_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async with async_session_maker() as db_session:
|
||||
yield db_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def shared_conversation_info_service(async_session):
|
||||
"""Create a SharedConversationInfoService for testing."""
|
||||
return SQLSharedConversationInfoService(db_session=async_session)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def app_conversation_service(async_session):
|
||||
"""Create an AppConversationInfoService for creating test data."""
|
||||
return SQLAppConversationInfoService(
|
||||
db_session=async_session, user_context=SpecifyUserContext(user_id=None)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_conversation_info():
|
||||
"""Create a sample conversation info for testing."""
|
||||
return AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox',
|
||||
selected_repository='test/repo',
|
||||
selected_branch='main',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
title='Test Conversation',
|
||||
trigger=ConversationTrigger.GUI,
|
||||
pr_number=[123],
|
||||
llm_model='gpt-4',
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=1.5,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
cache_read_tokens=0,
|
||||
cache_write_tokens=0,
|
||||
context_window=4096,
|
||||
per_turn_token=150,
|
||||
),
|
||||
),
|
||||
parent_conversation_id=None,
|
||||
sub_conversation_ids=[],
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
public=True, # Make it public for testing
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_private_conversation_info():
|
||||
"""Create a sample private conversation info for testing."""
|
||||
return AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox_private',
|
||||
selected_repository='test/private_repo',
|
||||
selected_branch='main',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
title='Private Conversation',
|
||||
trigger=ConversationTrigger.GUI,
|
||||
pr_number=[124],
|
||||
llm_model='gpt-4',
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=2.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(
|
||||
prompt_tokens=200,
|
||||
completion_tokens=100,
|
||||
cache_read_tokens=0,
|
||||
cache_write_tokens=0,
|
||||
context_window=4096,
|
||||
per_turn_token=300,
|
||||
),
|
||||
),
|
||||
parent_conversation_id=None,
|
||||
sub_conversation_ids=[],
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
public=False, # Make it private
|
||||
)
|
||||
|
||||
|
||||
class TestSharedConversationInfoService:
|
||||
"""Test cases for SharedConversationInfoService."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_shared_conversation_info_returns_public_conversation(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
sample_conversation_info,
|
||||
):
|
||||
"""Test that get_shared_conversation_info returns a public conversation."""
|
||||
# Create a public conversation
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_conversation_info
|
||||
)
|
||||
|
||||
# Retrieve it via public service
|
||||
result = await shared_conversation_info_service.get_shared_conversation_info(
|
||||
sample_conversation_info.id
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.id == sample_conversation_info.id
|
||||
assert result.title == sample_conversation_info.title
|
||||
assert result.created_by_user_id == sample_conversation_info.created_by_user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_shared_conversation_info_returns_none_for_private_conversation(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
sample_private_conversation_info,
|
||||
):
|
||||
"""Test that get_shared_conversation_info returns None for private conversations."""
|
||||
# Create a private conversation
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_private_conversation_info
|
||||
)
|
||||
|
||||
# Try to retrieve it via public service
|
||||
result = await shared_conversation_info_service.get_shared_conversation_info(
|
||||
sample_private_conversation_info.id
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_shared_conversation_info_returns_none_for_nonexistent_conversation(
|
||||
self, shared_conversation_info_service
|
||||
):
|
||||
"""Test that get_shared_conversation_info returns None for nonexistent conversations."""
|
||||
nonexistent_id = uuid4()
|
||||
result = await shared_conversation_info_service.get_shared_conversation_info(
|
||||
nonexistent_id
|
||||
)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_shared_conversation_info_returns_only_public_conversations(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
sample_conversation_info,
|
||||
sample_private_conversation_info,
|
||||
):
|
||||
"""Test that search only returns public conversations."""
|
||||
# Create both public and private conversations
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_conversation_info
|
||||
)
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_private_conversation_info
|
||||
)
|
||||
|
||||
# Search for all conversations
|
||||
result = (
|
||||
await shared_conversation_info_service.search_shared_conversation_info()
|
||||
)
|
||||
|
||||
# Should only return the public conversation
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].id == sample_conversation_info.id
|
||||
assert result.items[0].title == sample_conversation_info.title
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_shared_conversation_info_with_title_filter(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
sample_conversation_info,
|
||||
):
|
||||
"""Test searching with title filter."""
|
||||
# Create a public conversation
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_conversation_info
|
||||
)
|
||||
|
||||
# Search with matching title
|
||||
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||
title__contains='Test'
|
||||
)
|
||||
assert len(result.items) == 1
|
||||
|
||||
# Search with non-matching title
|
||||
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||
title__contains='NonExistent'
|
||||
)
|
||||
assert len(result.items) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_shared_conversation_info_with_sort_order(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
):
|
||||
"""Test searching with different sort orders."""
|
||||
# Create multiple public conversations with different titles and timestamps
|
||||
conv1 = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox_1',
|
||||
title='A First Conversation',
|
||||
created_at=datetime(2023, 1, 1, tzinfo=UTC),
|
||||
updated_at=datetime(2023, 1, 1, tzinfo=UTC),
|
||||
public=True,
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=0.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(),
|
||||
),
|
||||
)
|
||||
conv2 = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox_2',
|
||||
title='B Second Conversation',
|
||||
created_at=datetime(2023, 1, 2, tzinfo=UTC),
|
||||
updated_at=datetime(2023, 1, 2, tzinfo=UTC),
|
||||
public=True,
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=0.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(),
|
||||
),
|
||||
)
|
||||
|
||||
await app_conversation_service.save_app_conversation_info(conv1)
|
||||
await app_conversation_service.save_app_conversation_info(conv2)
|
||||
|
||||
# Test sort by title ascending
|
||||
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||
sort_order=SharedConversationSortOrder.TITLE
|
||||
)
|
||||
assert len(result.items) == 2
|
||||
assert result.items[0].title == 'A First Conversation'
|
||||
assert result.items[1].title == 'B Second Conversation'
|
||||
|
||||
# Test sort by title descending
|
||||
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||
sort_order=SharedConversationSortOrder.TITLE_DESC
|
||||
)
|
||||
assert len(result.items) == 2
|
||||
assert result.items[0].title == 'B Second Conversation'
|
||||
assert result.items[1].title == 'A First Conversation'
|
||||
|
||||
# Test sort by created_at ascending
|
||||
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||
sort_order=SharedConversationSortOrder.CREATED_AT
|
||||
)
|
||||
assert len(result.items) == 2
|
||||
assert result.items[0].id == conv1.id
|
||||
assert result.items[1].id == conv2.id
|
||||
|
||||
# Test sort by created_at descending (default)
|
||||
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||
sort_order=SharedConversationSortOrder.CREATED_AT_DESC
|
||||
)
|
||||
assert len(result.items) == 2
|
||||
assert result.items[0].id == conv2.id
|
||||
assert result.items[1].id == conv1.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_shared_conversation_info(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
sample_conversation_info,
|
||||
sample_private_conversation_info,
|
||||
):
|
||||
"""Test counting public conversations."""
|
||||
# Initially should be 0
|
||||
count = await shared_conversation_info_service.count_shared_conversation_info()
|
||||
assert count == 0
|
||||
|
||||
# Create a public conversation
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_conversation_info
|
||||
)
|
||||
count = await shared_conversation_info_service.count_shared_conversation_info()
|
||||
assert count == 1
|
||||
|
||||
# Create a private conversation - count should remain 1
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_private_conversation_info
|
||||
)
|
||||
count = await shared_conversation_info_service.count_shared_conversation_info()
|
||||
assert count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_get_shared_conversation_info(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
sample_conversation_info,
|
||||
sample_private_conversation_info,
|
||||
):
|
||||
"""Test batch getting public conversations."""
|
||||
# Create both public and private conversations
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_conversation_info
|
||||
)
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_private_conversation_info
|
||||
)
|
||||
|
||||
# Batch get both conversations
|
||||
result = (
|
||||
await shared_conversation_info_service.batch_get_shared_conversation_info(
|
||||
[sample_conversation_info.id, sample_private_conversation_info.id]
|
||||
)
|
||||
)
|
||||
|
||||
# Should return the public one and None for the private one
|
||||
assert len(result) == 2
|
||||
assert result[0] is not None
|
||||
assert result[0].id == sample_conversation_info.id
|
||||
assert result[1] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_with_pagination(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
):
|
||||
"""Test search with pagination."""
|
||||
# Create multiple public conversations
|
||||
conversations = []
|
||||
for i in range(5):
|
||||
conv = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id=f'test_sandbox_{i}',
|
||||
title=f'Conversation {i}',
|
||||
created_at=datetime(2023, 1, i + 1, tzinfo=UTC),
|
||||
updated_at=datetime(2023, 1, i + 1, tzinfo=UTC),
|
||||
public=True,
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=0.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(),
|
||||
),
|
||||
)
|
||||
conversations.append(conv)
|
||||
await app_conversation_service.save_app_conversation_info(conv)
|
||||
|
||||
# Get first page with limit 2
|
||||
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||
limit=2, sort_order=SharedConversationSortOrder.CREATED_AT
|
||||
)
|
||||
assert len(result.items) == 2
|
||||
assert result.next_page_id is not None
|
||||
|
||||
# Get next page
|
||||
result2 = (
|
||||
await shared_conversation_info_service.search_shared_conversation_info(
|
||||
limit=2,
|
||||
page_id=result.next_page_id,
|
||||
sort_order=SharedConversationSortOrder.CREATED_AT,
|
||||
)
|
||||
)
|
||||
assert len(result2.items) == 2
|
||||
assert result2.next_page_id is not None
|
||||
|
||||
# Verify no overlap between pages
|
||||
page1_ids = {item.id for item in result.items}
|
||||
page2_ids = {item.id for item in result2.items}
|
||||
assert page1_ids.isdisjoint(page2_ids)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user