Mirror of https://github.com/All-Hands-AI/OpenHands.git (synced 2026-01-09 14:57:59 -05:00)

Commit: Enterprise code and docker build (#10770)
.github/workflows/ghcr-build.yml (vendored, 95 changes)
@@ -10,14 +10,14 @@ on:
|
||||
branches:
|
||||
- main
|
||||
tags:
|
||||
- '*'
|
||||
- "*"
|
||||
pull_request:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
reason:
|
||||
description: 'Reason for manual trigger'
|
||||
description: "Reason for manual trigger"
|
||||
required: true
|
||||
default: ''
|
||||
default: ""
|
||||
|
||||
# If triggered by a PR, it will be in the same group. However, each commit on main will be in its own unique group
|
||||
concurrency:
|
||||
@@ -120,7 +120,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: useblacksmith/setup-python@v6
|
||||
with:
|
||||
python-version: '3.12'
|
||||
python-version: "3.12"
|
||||
cache: poetry
|
||||
- name: Install Python dependencies using Poetry
|
||||
run: make install-python-dependencies POETRY_GROUP=main INSTALL_PLAYWRIGHT=0
|
||||
@@ -166,6 +166,89 @@ jobs:
|
||||
name: runtime-src-${{ matrix.base_image.tag }}
|
||||
path: containers/runtime
|
||||
|
||||
ghcr_build_enterprise:
|
||||
name: Push Enterprise Image
|
||||
runs-on: blacksmith-8vcpu-ubuntu-2204
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
needs: [define-matrix, ghcr_build_app]
|
||||
# Do not build enterprise in forks
|
||||
if: github.event.pull_request.head.repo.fork != true
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
# Set up Docker Buildx for better performance
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
driver-opts: network=host
|
||||
|
||||
- name: Login to GHCR
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.repository_owner }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Extract metadata (tags, labels) for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ghcr.io/all-hands-ai/enterprise-server
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=ref,event=pr
|
||||
type=sha
|
||||
type=sha,format=long
|
||||
type=semver,pattern={{version}}
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
type=semver,pattern={{major}}
|
||||
flavor: |
|
||||
latest=auto
|
||||
prefix=
|
||||
suffix=
|
||||
- name: Determine app image tag
|
||||
shell: bash
|
||||
run: |
|
||||
# Duplicated with build.sh
|
||||
sanitized_ref_name=$(echo "$GITHUB_REF_NAME" | sed 's/[^a-zA-Z0-9.-]\+/-/g')
|
||||
OPENHANDS_BUILD_VERSION=$sanitized_ref_name
|
||||
sanitized_ref_name=$(echo "$sanitized_ref_name" | tr '[:upper:]' '[:lower:]') # lower case is required in tagging
|
||||
echo "OPENHANDS_DOCKER_TAG=${sanitized_ref_name}" >> $GITHUB_ENV
|
||||
- name: Build and push Docker image
|
||||
uses: useblacksmith/build-push-action@v1
|
||||
with:
|
||||
context: .
|
||||
file: enterprise/Dockerfile
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
build-args: |
|
||||
OPENHANDS_VERSION=${{ env.OPENHANDS_DOCKER_TAG }}
|
||||
platforms: linux/amd64
|
||||
# Add build provenance
|
||||
provenance: true
|
||||
# Add build attestations for better security
|
||||
sbom: true
|
||||
|
||||
enterprise-preview:
|
||||
name: Enterprise preview
|
||||
if: |
|
||||
(github.event_name == 'pull_request' && github.event.action == 'labeled' && github.event.label.name == 'deploy') ||
|
||||
(github.event_name == 'pull_request' && github.event.action != 'labeled' && contains(github.event.pull_request.labels.*.name, 'deploy'))
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
needs: [ghcr_build_enterprise]
|
||||
steps:
|
||||
- name: Trigger remote job
|
||||
run: |
|
||||
curl --fail-with-body -sS -X POST \
|
||||
-H "Authorization: Bearer ${{ secrets.PAT_TOKEN }}" \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-d "{\"ref\": \"main\", \"inputs\": {\"openhandsPrNumber\": \"${{ github.event.pull_request.number }}\", \"deployEnvironment\": \"feature\", \"enterpriseImageTag\": \"pr-${{ github.event.pull_request.number }}\" }}" \
|
||||
https://api.github.com/repos/All-Hands-AI/deploy/actions/workflows/deploy.yaml/dispatches
|
||||
|
||||
# Run unit tests with the Docker runtime Docker images as root
|
||||
test_runtime_root:
|
||||
name: RT Unit Tests (Root)
|
||||
@@ -202,7 +285,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: useblacksmith/setup-python@v6
|
||||
with:
|
||||
python-version: '3.12'
|
||||
python-version: "3.12"
|
||||
cache: poetry
|
||||
- name: Install Python dependencies using Poetry
|
||||
run: make install-python-dependencies INSTALL_PLAYWRIGHT=0
|
||||
@@ -264,7 +347,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: useblacksmith/setup-python@v6
|
||||
with:
|
||||
python-version: '3.12'
|
||||
python-version: "3.12"
|
||||
cache: poetry
|
||||
- name: Install Python dependencies using Poetry
|
||||
run: make install-python-dependencies POETRY_GROUP=main,test,runtime INSTALL_PLAYWRIGHT=0
|
||||
|
||||
.github/workflows/lint.yml (vendored, 18 changes)
@@ -55,6 +55,24 @@ jobs:
|
||||
- name: Run pre-commit hooks
|
||||
run: pre-commit run --all-files --show-diff-on-failure --config ./dev_config/python/.pre-commit-config.yaml
|
||||
|
||||
lint-enterprise-python:
|
||||
name: Lint enterprise python
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Set up python
|
||||
uses: useblacksmith/setup-python@v6
|
||||
with:
|
||||
python-version: 3.12
|
||||
cache: "pip"
|
||||
- name: Install pre-commit
|
||||
run: pip install pre-commit==4.2.0
|
||||
- name: Run pre-commit hooks
|
||||
working-directory: ./enterprise
|
||||
run: pre-commit run --all-files --config ./dev_config/python/.pre-commit-config.yaml
|
||||
|
||||
# Check version consistency across documentation
|
||||
check-version-consistency:
|
||||
name: Check version consistency
|
||||
|
||||
.github/workflows/py-tests.yml (vendored, 33 changes)
@@ -21,10 +21,10 @@ jobs:
|
||||
name: Python Tests on Linux
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
env:
|
||||
INSTALL_DOCKER: '0' # Set to '0' to skip Docker installation
|
||||
INSTALL_DOCKER: "0" # Set to '0' to skip Docker installation
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.12']
|
||||
python-version: ["3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Docker Buildx
|
||||
@@ -35,14 +35,14 @@ jobs:
|
||||
- name: Setup Node.js
|
||||
uses: useblacksmith/setup-node@v5
|
||||
with:
|
||||
node-version: '22.x'
|
||||
node-version: "22.x"
|
||||
- name: Install poetry via pipx
|
||||
run: pipx install poetry
|
||||
- name: Set up Python
|
||||
uses: useblacksmith/setup-python@v6
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: 'poetry'
|
||||
cache: "poetry"
|
||||
- name: Install Python dependencies using Poetry
|
||||
run: poetry install --with dev,test,runtime
|
||||
- name: Build Environment
|
||||
@@ -58,7 +58,7 @@ jobs:
|
||||
runs-on: windows-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.12']
|
||||
python-version: ["3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Install pipx
|
||||
@@ -69,7 +69,7 @@ jobs:
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: 'poetry'
|
||||
cache: "poetry"
|
||||
- name: Install Python dependencies using Poetry
|
||||
run: poetry install --with dev,test,runtime
|
||||
- name: Run Windows unit tests
|
||||
@@ -83,3 +83,24 @@ jobs:
|
||||
PYTHONPATH: ".;$env:PYTHONPATH"
|
||||
TEST_RUNTIME: local
|
||||
DEBUG: "1"
|
||||
test-enterprise:
|
||||
name: Enterprise Python Unit Tests
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Install poetry via pipx
|
||||
run: pipx install poetry
|
||||
- name: Set up Python
|
||||
uses: useblacksmith/setup-python@v6
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: "poetry"
|
||||
- name: Install Python dependencies using Poetry
|
||||
working-directory: ./enterprise
|
||||
run: poetry install --with dev,test
|
||||
- name: Run Unit Tests
|
||||
working-directory: ./enterprise
|
||||
run: PYTHONPATH=".:$PYTHONPATH" poetry run pytest --forked -n auto -svv -p no:ddtrace -p no:ddtrace.pytest_bdd -p no:ddtrace.pytest_benchmark ./tests/unit
|
||||
|
||||
dev_config/python/.pre-commit-config.yaml
@@ -3,9 +3,9 @@ repos:
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
exclude: ^(docs/|modules/|python/|openhands-ui/|third_party/)
|
||||
exclude: ^(docs/|modules/|python/|openhands-ui/|third_party/|enterprise/)
|
||||
- id: end-of-file-fixer
|
||||
exclude: ^(docs/|modules/|python/|openhands-ui/|third_party/)
|
||||
exclude: ^(docs/|modules/|python/|openhands-ui/|third_party/|enterprise/)
|
||||
- id: check-yaml
|
||||
args: ["--allow-multiple-documents"]
|
||||
- id: debug-statements
|
||||
@@ -28,19 +28,28 @@ repos:
|
||||
entry: ruff check --config dev_config/python/ruff.toml
|
||||
types_or: [python, pyi, jupyter]
|
||||
args: [--fix, --unsafe-fixes]
|
||||
exclude: third_party/
|
||||
exclude: ^(third_party/|enterprise/)
|
||||
# Run the formatter.
|
||||
- id: ruff-format
|
||||
entry: ruff format --config dev_config/python/ruff.toml
|
||||
types_or: [python, pyi, jupyter]
|
||||
exclude: third_party/
|
||||
exclude: ^(third_party/|enterprise/)
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.15.0
|
||||
hooks:
|
||||
- id: mypy
|
||||
additional_dependencies:
|
||||
[types-requests, types-setuptools, types-pyyaml, types-toml, types-docker, types-Markdown, pydantic, lxml]
|
||||
[
|
||||
types-requests,
|
||||
types-setuptools,
|
||||
types-pyyaml,
|
||||
types-toml,
|
||||
types-docker,
|
||||
types-Markdown,
|
||||
pydantic,
|
||||
lxml,
|
||||
]
|
||||
# To see gaps add `--html-report mypy-report/`
|
||||
entry: mypy --config-file dev_config/python/mypy.ini openhands/
|
||||
always_run: true
|
||||
|
||||
dev_config/python/mypy.ini
@@ -9,7 +9,7 @@ no_implicit_optional = True
|
||||
strict_optional = True
|
||||
|
||||
# Exclude third-party runtime directory from type checking
|
||||
exclude = third_party/
|
||||
exclude = (third_party/|enterprise/)
|
||||
|
||||
[mypy-openhands.memory.condenser.impl.*]
|
||||
disable_error_code = override
|
||||
|
||||
dev_config/python/ruff.toml
@@ -1,5 +1,5 @@
|
||||
# Exclude third-party runtime directory from linting
|
||||
exclude = ["third_party/"]
|
||||
exclude = ["third_party/", "enterprise/"]
|
||||
|
||||
[lint]
|
||||
select = [
|
||||
|
||||
enterprise/Dockerfile (new file, 26 lines)
@@ -0,0 +1,26 @@
ARG OPENHANDS_VERSION=latest
ARG BASE="ghcr.io/all-hands-ai/openhands"
FROM ${BASE}:${OPENHANDS_VERSION}

# Datadog labels
LABEL com.datadoghq.tags.service="deploy"
LABEL com.datadoghq.tags.env="${DD_ENV}"

# Install Node.js v20+ and npm (which includes npx)
RUN apt-get update && \
    apt-get install -y curl && \
    curl -fsSL https://deb.nodesource.com/setup_20.x | bash - && \
    apt-get install -y nodejs && \
    apt-get install -y jq gettext && \
    apt-get clean

RUN pip install alembic psycopg2-binary cloud-sql-python-connector pg8000 gspread stripe python-keycloak asyncpg sqlalchemy[asyncio] resend tenacity slack-sdk ddtrace posthog "limits==5.2.0" coredis prometheus-client shap scikit-learn pandas numpy

WORKDIR /app
COPY enterprise .

RUN chown -R openhands:openhands /app && chmod -R 770 /app
USER openhands

# Command will be overridden by Kubernetes deployment template
CMD ["uvicorn", "saas_server:app", "--host", "0.0.0.0", "--port", "3000"]
enterprise/Makefile (new file, 42 lines)
@@ -0,0 +1,42 @@
|
||||
BACKEND_HOST ?= "127.0.0.1"
|
||||
BACKEND_PORT = 3000
|
||||
BACKEND_HOST_PORT = "$(BACKEND_HOST):$(BACKEND_PORT)"
|
||||
FRONTEND_PORT = 3001
|
||||
OPENHANDS_PATH ?= "../../OpenHands"
|
||||
OPENHANDS := $(OPENHANDS_PATH)
|
||||
OPENHANDS_FRONTEND_PATH = $(OPENHANDS)/frontend/build
|
||||
|
||||
# ANSI color codes
|
||||
GREEN=$(shell tput -Txterm setaf 2)
|
||||
YELLOW=$(shell tput -Txterm setaf 3)
|
||||
RED=$(shell tput -Txterm setaf 1)
|
||||
BLUE=$(shell tput -Txterm setaf 6)
|
||||
RESET=$(shell tput -Txterm sgr0)
|
||||
|
||||
build:
|
||||
@poetry install
|
||||
@cd $(OPENHANDS) && $(MAKE) build
|
||||
|
||||
|
||||
_run_setup:
|
||||
@echo "$(YELLOW)Starting backend server...$(RESET)"
|
||||
@cd app && FRONTEND_DIRECTORY=$(OPENHANDS_FRONTEND_PATH) poetry run uvicorn saas_server:app --host $(BACKEND_HOST) --port $(BACKEND_PORT) &
|
||||
@echo "$(YELLOW)Waiting for the backend to start...$(RESET)"
|
||||
@until nc -z localhost $(BACKEND_PORT); do sleep 0.1; done
|
||||
@echo "$(GREEN)Backend started successfully.$(RESET)"
|
||||
|
||||
run:
|
||||
@echo "$(YELLOW)Running the app...$(RESET)"
|
||||
@$(MAKE) -s _run_setup
|
||||
@cd $(OPENHANDS) && $(MAKE) -s start-frontend
|
||||
@echo "$(GREEN)Application started successfully.$(RESET)"
|
||||
|
||||
# Start backend
|
||||
start-backend:
|
||||
@echo "$(YELLOW)Starting backend...$(RESET)"
|
||||
@echo "$(OPENHANDS_FRONTEND_PATH)"
|
||||
@cd app && FRONTEND_DIRECTORY=$(OPENHANDS_FRONTEND_PATH) poetry run uvicorn saas_server:app --host $(BACKEND_HOST) --port $(BACKEND_PORT) --reload-dir $(OPENHANDS_PATH) --reload --reload-dir ./ --reload-exclude "./workspace"
|
||||
|
||||
|
||||
lint:
|
||||
@poetry run pre-commit run --all-files --show-diff-on-failure --config ./dev_config/python/.pre-commit-config.yaml
|
||||
enterprise/README.md (new file, 44 lines)
@@ -0,0 +1,44 @@
# Closed Source extension of OpenHands proper (OSS)

The closed source (CSS) code in the `/app` directory builds on top of the open source (OSS) code, extending its functionality. The CSS code is entangled with the OSS code in two ways:

- CSS stacks on top of OSS. For example, the middleware in CSS is stacked right on top of the middleware in OSS. In `SAAS`, the middleware from BOTH repos will be present and running (which can sometimes cause conflicts)

- CSS overrides the implementation in OSS (only one is present at a time). For example, the server config [`SaasServerConfig`](https://github.com/All-Hands-AI/deploy/blob/main/app/server/config.py#L43) overrides [`ServerConfig`](https://github.com/All-Hands-AI/OpenHands/blob/main/openhands/server/config/server_config.py#L8) in OSS. This is done through dynamic imports ([see here](https://github.com/All-Hands-AI/OpenHands/blob/main/openhands/server/config/server_config.py#L37-#L45)); a minimal sketch of this mechanism follows this list.
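
The dynamic-import override can be illustrated with a short sketch. This is not the actual OpenHands loader; the `OPENHANDS_CONFIG_CLS` environment variable name and the helper below are assumptions used only to show the idea of resolving a class from a dotted path at runtime.

```python
import importlib
import os


def load_server_config_class(
    default: str = 'openhands.server.config.server_config.ServerConfig',
):
    """Resolve a server-config class from a dotted path at runtime.

    Illustrative only: OSS ships a default class, and the CSS deployment
    points an environment variable (hypothetical name) at its own subclass,
    so that only one implementation is active at a time.
    """
    dotted_path = os.environ.get('OPENHANDS_CONFIG_CLS', default)
    module_path, class_name = dotted_path.rsplit('.', 1)
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


# A CSS deployment would set something like:
#   OPENHANDS_CONFIG_CLS=server.config.SaasServerConfig
# while OSS code keeps instantiating whatever class this helper returns.
server_config = load_server_config_class()()
```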

Key areas that change on `SAAS` are:

- Authentication
- User settings
- etc

## Authentication

| Aspect                    | OSS                                                     | CSS                                                                                                                                  |
| ------------------------- | ------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------ |
| **Authentication Method** | User adds a personal access token (PAT) through the UI  | User performs OAuth through the UI. The GitHub App provides a short-lived access token and a refresh token                            |
| **Token Storage**         | PAT is stored in **Settings**                           | Token is stored in **GithubTokenManager** (a file store in our backend)                                                               |
| **Authenticated status**  | We simply check whether a token exists in `Settings`    | We issue a signed cookie with `github_user_id` during OAuth, so subsequent requests with the cookie can be considered authenticated   |

Note that in the future, authentication will happen via Keycloak. All modifications for authentication will happen in CSS. A minimal sketch of the signed-cookie idea follows.
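
For illustration, here is a minimal sketch of how a signed cookie carrying `github_user_id` could be issued and verified. It uses only the Python standard library; the secret handling and helper names are assumptions and do not reflect the actual CSS implementation.

```python
import base64
import hashlib
import hmac
import json

SECRET = b'replace-with-a-server-side-secret'  # assumption: loaded from config in practice


def sign_cookie(github_user_id: str) -> str:
    """Serialize the payload and append an HMAC so it cannot be forged."""
    payload = base64.urlsafe_b64encode(json.dumps({'github_user_id': github_user_id}).encode())
    signature = hmac.new(SECRET, payload, hashlib.sha256).hexdigest()
    return f'{payload.decode()}.{signature}'


def verify_cookie(cookie: str) -> str | None:
    """Return the github_user_id if the signature checks out, else None."""
    payload, _, signature = cookie.rpartition('.')
    expected = hmac.new(SECRET, payload.encode(), hashlib.sha256).hexdigest()
    if not hmac.compare_digest(expected, signature):
        return None
    return json.loads(base64.urlsafe_b64decode(payload))['github_user_id']


cookie = sign_cookie('12345')
assert verify_cookie(cookie) == '12345'
```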

## GitHub Service

The GitHub service is responsible for interacting with GitHub APIs. As a consequence, it uses the user's token and refreshes it if need be.

| Aspect                    | OSS                                          | CSS                                            |
| ------------------------- | -------------------------------------------- | ---------------------------------------------- |
| **Class used**            | `GitHubService`                              | `SaaSGitHubService`                            |
| **Token used**            | User's PAT fetched from `Settings`           | User's token fetched from `GitHubTokenManager` |
| **Refresh functionality** | **N/A**; the user provides a PAT for the app | Uses the `GitHubTokenManager` to refresh       |

NOTE: in the future we will simply replace the `GithubTokenManager` with Keycloak. The `SaaSGithubService` should interact with Keycloak instead.

# Areas that are BRITTLE!

## User ID vs User Token

- On OSS, the entire APP revolves around the GitHub token the user sets. `openhands/server` uses `request.state.github_token` for the entire app
- On CSS, the entire APP revolves around the GitHub User ID. This is because the cookie sets it, so `openhands/server` AND `deploy/app/server` depend on it and completely ignore `request.state.github_token` (the token is fetched from `GithubTokenManager` instead)

Note that introducing the GitHub User ID on OSS, for instance, will cause large breakages.
enterprise/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
# App package for OpenHands
enterprise/alembic.ini (new file, 79 lines)
@@ -0,0 +1,79 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
script_location = migrations
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory.
|
||||
prepend_sys_path = .
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the python>=3.9 or backports.zoneinfo library.
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version path separator; As mentioned above, this is the character used to split
|
||||
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
|
||||
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# Logging configuration
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = DEBUG
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = DEBUG
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = DEBUG
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
enterprise/allhands-realm-github-provider.json.tmpl (new file, 2770 lines); diff suppressed because it is too large
enterprise/dev_config/python/.pre-commit-config.yaml (new file, 56 lines)
@@ -0,0 +1,56 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.5.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
exclude: docs/modules/python
|
||||
files: ^enterprise/
|
||||
- id: end-of-file-fixer
|
||||
exclude: docs/modules/python
|
||||
files: ^enterprise/
|
||||
- id: check-yaml
|
||||
files: ^enterprise/
|
||||
- id: debug-statements
|
||||
files: ^enterprise/
|
||||
- repo: https://github.com/abravalheri/validate-pyproject
|
||||
rev: v0.16
|
||||
hooks:
|
||||
- id: validate-pyproject
|
||||
types: [toml]
|
||||
files: ^enterprise/pyproject\.toml$
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
rev: v0.4.1
|
||||
hooks:
|
||||
# Run the linter.
|
||||
- id: ruff
|
||||
entry: ruff check --config enterprise/dev_config/python/ruff.toml
|
||||
types_or: [python, pyi, jupyter]
|
||||
args: [--fix]
|
||||
files: ^enterprise/
|
||||
# Run the formatter.
|
||||
- id: ruff-format
|
||||
entry: ruff format --config enterprise/dev_config/python/ruff.toml
|
||||
types_or: [python, pyi, jupyter]
|
||||
files: ^enterprise/
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.9.0
|
||||
hooks:
|
||||
- id: mypy
|
||||
additional_dependencies:
|
||||
- types-requests
|
||||
- types-setuptools
|
||||
- types-pyyaml
|
||||
- types-toml
|
||||
- types-redis
|
||||
- lxml
|
||||
# TODO: Add OpenHands in parent
|
||||
- stripe==11.5.0
|
||||
- pygithub==2.6.1
|
||||
# To see gaps add `--html-report mypy-report/`
|
||||
entry: mypy --config-file enterprise/dev_config/python/mypy.ini enterprise/
|
||||
always_run: true
|
||||
pass_filenames: false
|
||||
files: ^enterprise/
|
||||
enterprise/dev_config/python/mypy.ini (new file, 21 lines)
@@ -0,0 +1,21 @@
|
||||
[mypy]
|
||||
warn_unused_configs = True
|
||||
ignore_missing_imports = True
|
||||
check_untyped_defs = True
|
||||
explicit_package_bases = True
|
||||
warn_unreachable = True
|
||||
warn_redundant_casts = True
|
||||
no_implicit_optional = True
|
||||
strict_optional = True
|
||||
exclude = (^enterprise/migrations/.*|^openhands/.*)
|
||||
|
||||
[mypy-enterprise.tests.unit.test_auth_routes.*]
|
||||
disable_error_code = union-attr
|
||||
|
||||
[mypy-enterprise.sync.install_gitlab_webhooks.*]
|
||||
disable_error_code = redundant-cast
|
||||
|
||||
# Let the other config check base openhands packages
|
||||
[mypy-openhands.*]
|
||||
follow_imports = skip
|
||||
ignore_missing_imports = True
|
||||
enterprise/dev_config/python/ruff.toml (new file, 31 lines)
@@ -0,0 +1,31 @@
|
||||
[lint]
|
||||
select = [
|
||||
"E",
|
||||
"W",
|
||||
"F",
|
||||
"I",
|
||||
"Q",
|
||||
"B",
|
||||
]
|
||||
|
||||
ignore = [
|
||||
"E501",
|
||||
"B003",
|
||||
"B007",
|
||||
"B008", # Allow function calls in argument defaults (FastAPI Query pattern)
|
||||
"B009",
|
||||
"B010",
|
||||
"B904",
|
||||
"B018",
|
||||
]
|
||||
|
||||
exclude = [
|
||||
"app/migrations/*"
|
||||
]
|
||||
|
||||
[lint.flake8-quotes]
|
||||
docstring-quotes = "double"
|
||||
inline-quotes = "single"
|
||||
|
||||
[format]
|
||||
quote-style = "single"
|
||||
enterprise/experiments/__init__.py (new empty file)
enterprise/experiments/constants.py (new file, 47 lines)
@@ -0,0 +1,47 @@
|
||||
import os
|
||||
|
||||
import posthog
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
# Initialize PostHog
|
||||
posthog.api_key = os.environ.get('POSTHOG_CLIENT_KEY', 'phc_placeholder')
|
||||
posthog.host = os.environ.get('POSTHOG_HOST', 'https://us.i.posthog.com')
|
||||
|
||||
# Log PostHog configuration with masked API key for security
|
||||
api_key = posthog.api_key
|
||||
if api_key and len(api_key) > 8:
|
||||
masked_key = f'{api_key[:4]}...{api_key[-4:]}'
|
||||
else:
|
||||
masked_key = 'not_set_or_too_short'
|
||||
logger.info('posthog_configuration', extra={'posthog_api_key_masked': masked_key})
|
||||
|
||||
# Global toggle for the experiment manager
|
||||
ENABLE_EXPERIMENT_MANAGER = (
|
||||
os.environ.get('ENABLE_EXPERIMENT_MANAGER', 'false').lower() == 'true'
|
||||
)
|
||||
|
||||
# Get the current experiment type from environment variable
|
||||
# If None, no experiment is running
|
||||
EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT = os.environ.get(
|
||||
'EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT', ''
|
||||
)
|
||||
# System prompt experiment toggle
|
||||
EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT = os.environ.get(
|
||||
'EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT', ''
|
||||
)
|
||||
|
||||
EXPERIMENT_CLAUDE4_VS_GPT5 = os.environ.get('EXPERIMENT_CLAUDE4_VS_GPT5', '')
|
||||
|
||||
EXPERIMENT_CONDENSER_MAX_STEP = os.environ.get('EXPERIMENT_CONDENSER_MAX_STEP', '')
|
||||
|
||||
logger.info(
|
||||
'experiment_manager:run_conversation_variant_test:experiment_config',
|
||||
extra={
|
||||
'enable_experiment_manager': ENABLE_EXPERIMENT_MANAGER,
|
||||
'experiment_litellm_default_model_experiment': EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT,
|
||||
'experiment_system_prompt_experiment': EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT,
|
||||
'experiment_claude4_vs_gpt5_experiment': EXPERIMENT_CLAUDE4_VS_GPT5,
|
||||
'experiment_condenser_max_step': EXPERIMENT_CONDENSER_MAX_STEP,
|
||||
},
|
||||
)
|
||||
enterprise/experiments/experiment_manager.py (new file, 90 lines)
@@ -0,0 +1,90 @@
|
||||
from experiments.constants import (
|
||||
ENABLE_EXPERIMENT_MANAGER,
|
||||
)
|
||||
from experiments.experiment_versions import (
|
||||
handle_claude4_vs_gpt5_experiment,
|
||||
handle_condenser_max_step_experiment,
|
||||
handle_system_prompt_experiment,
|
||||
)
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.experiments.experiment_manager import ExperimentManager
|
||||
|
||||
|
||||
class SaaSExperimentManager(ExperimentManager):
|
||||
@staticmethod
|
||||
def run_conversation_variant_test(user_id, conversation_id, conversation_settings):
|
||||
"""
|
||||
Run conversation variant test and potentially modify the conversation settings
|
||||
based on the PostHog feature flags.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
conversation_id: The conversation ID
|
||||
conversation_settings: The conversation settings that may include convo_id and llm_model
|
||||
|
||||
Returns:
|
||||
The modified conversation settings
|
||||
"""
|
||||
logger.debug(
|
||||
'experiment_manager:run_conversation_variant_test:started',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
# Skip all experiment processing if the experiment manager is disabled
|
||||
if not ENABLE_EXPERIMENT_MANAGER:
|
||||
logger.info(
|
||||
'experiment_manager:run_conversation_variant_test:skipped',
|
||||
extra={'reason': 'experiment_manager_disabled'},
|
||||
)
|
||||
return conversation_settings
|
||||
|
||||
# Apply conversation-scoped experiments
|
||||
conversation_settings = handle_claude4_vs_gpt5_experiment(
|
||||
user_id, conversation_id, conversation_settings
|
||||
)
|
||||
conversation_settings = handle_condenser_max_step_experiment(
|
||||
user_id, conversation_id, conversation_settings
|
||||
)
|
||||
|
||||
return conversation_settings
|
||||
|
||||
@staticmethod
|
||||
def run_config_variant_test(
|
||||
user_id: str, conversation_id: str, config: OpenHandsConfig
|
||||
):
|
||||
"""
|
||||
Run agent config variant test and potentially modify the OpenHands config
|
||||
based on the current experiment type and PostHog feature flags.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
conversation_id: The conversation ID
|
||||
config: The OpenHands configuration
|
||||
|
||||
Returns:
|
||||
The modified OpenHands configuration
|
||||
"""
|
||||
logger.info(
|
||||
'experiment_manager:run_config_variant_test:started',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
# Skip all experiment processing if the experiment manager is disabled
|
||||
if not ENABLE_EXPERIMENT_MANAGER:
|
||||
logger.info(
|
||||
'experiment_manager:run_config_variant_test:skipped',
|
||||
extra={'reason': 'experiment_manager_disabled'},
|
||||
)
|
||||
return config
|
||||
|
||||
# Pass the entire OpenHands config to the system prompt experiment
|
||||
# Let the experiment handler directly modify the config as needed
|
||||
modified_config = handle_system_prompt_experiment(
|
||||
user_id, conversation_id, config
|
||||
)
|
||||
|
||||
# Condenser max step experiment is applied via conversation variant test,
|
||||
# not config variant test. Return modified config from system prompt only.
|
||||
return modified_config
|
||||
enterprise/experiments/experiment_versions/_001_litellm_default_model_experiment.py (new file)
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
LiteLLM model experiment handler.
|
||||
|
||||
This module contains the handler for the LiteLLM model experiment.
|
||||
"""
|
||||
|
||||
import posthog
|
||||
from experiments.constants import EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT
|
||||
from server.constants import (
|
||||
IS_FEATURE_ENV,
|
||||
build_litellm_proxy_model_path,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
def handle_litellm_default_model_experiment(
|
||||
user_id, conversation_id, conversation_settings
|
||||
):
|
||||
"""
|
||||
Handle the LiteLLM model experiment.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
conversation_id: The conversation ID
|
||||
conversation_settings: The conversation settings
|
||||
|
||||
Returns:
|
||||
Modified conversation settings
|
||||
"""
|
||||
# No-op if the specific experiment is not enabled
|
||||
if not EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT:
|
||||
logger.info(
|
||||
'experiment_manager:ab_testing:skipped',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'reason': 'experiment_not_enabled',
|
||||
'experiment': EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT,
|
||||
},
|
||||
)
|
||||
return conversation_settings
|
||||
|
||||
# Use experiment name as the flag key
|
||||
try:
|
||||
enabled_variant = posthog.get_feature_flag(
|
||||
EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT, conversation_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:get_feature_flag:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return conversation_settings
|
||||
|
||||
# Log the experiment event
|
||||
# If this is a feature environment, add "FEATURE_" prefix to user_id for PostHog
|
||||
posthog_user_id = f'FEATURE_{user_id}' if IS_FEATURE_ENV else user_id
|
||||
|
||||
try:
|
||||
posthog.capture(
|
||||
distinct_id=posthog_user_id,
|
||||
event='model_set',
|
||||
properties={
|
||||
'conversation_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'original_user_id': user_id,
|
||||
'is_feature_env': IS_FEATURE_ENV,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:posthog_capture:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
# Continue execution as this is not critical
|
||||
|
||||
logger.info(
|
||||
'posthog_capture',
|
||||
extra={
|
||||
'event': 'model_set',
|
||||
'posthog_user_id': posthog_user_id,
|
||||
'is_feature_env': IS_FEATURE_ENV,
|
||||
'conversation_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
},
|
||||
)
|
||||
|
||||
# Set the model based on the feature flag variant
|
||||
if enabled_variant == 'claude37':
|
||||
# Use the shared utility to construct the LiteLLM proxy model path
|
||||
model = build_litellm_proxy_model_path('claude-3-7-sonnet-20250219')
|
||||
# Update the conversation settings with the selected model
|
||||
conversation_settings.llm_model = model
|
||||
else:
|
||||
# Update the conversation settings with the default model for the current version
|
||||
conversation_settings.llm_model = get_default_litellm_model()
|
||||
|
||||
return conversation_settings
|
||||
enterprise/experiments/experiment_versions/_002_system_prompt_experiment.py (new file)
@@ -0,0 +1,181 @@
|
||||
"""
|
||||
System prompt experiment handler.
|
||||
|
||||
This module contains the handler for the system prompt experiment that uses
|
||||
the PostHog variant as the system prompt filename.
|
||||
"""
|
||||
|
||||
import copy
|
||||
|
||||
import posthog
|
||||
from experiments.constants import EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT
|
||||
from server.constants import IS_FEATURE_ENV
|
||||
from storage.experiment_assignment_store import ExperimentAssignmentStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
def _get_system_prompt_variant(user_id, conversation_id):
|
||||
"""
|
||||
Get the system prompt variant for the experiment.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
conversation_id: The conversation ID
|
||||
|
||||
Returns:
|
||||
str or None: The PostHog variant name or None if experiment is not enabled or error occurs
|
||||
"""
|
||||
# No-op if the specific experiment is not enabled
|
||||
if not EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT:
|
||||
logger.info(
|
||||
'experiment_manager_002:ab_testing:skipped',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'reason': 'experiment_not_enabled',
|
||||
'experiment': EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT,
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
# Use experiment name as the flag key
|
||||
try:
|
||||
enabled_variant = posthog.get_feature_flag(
|
||||
EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT, conversation_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:get_feature_flag:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
# Store the experiment assignment in the database
|
||||
try:
|
||||
experiment_store = ExperimentAssignmentStore()
|
||||
experiment_store.update_experiment_variant(
|
||||
conversation_id=conversation_id,
|
||||
experiment_name='system_prompt_experiment',
|
||||
variant=enabled_variant,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:store_assignment:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT,
|
||||
'variant': enabled_variant,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
# Fail the experiment if we cannot track the splits - results would not be explainable
|
||||
return None
|
||||
|
||||
# Log the experiment event
|
||||
# If this is a feature environment, add "FEATURE_" prefix to user_id for PostHog
|
||||
posthog_user_id = f'FEATURE_{user_id}' if IS_FEATURE_ENV else user_id
|
||||
|
||||
try:
|
||||
posthog.capture(
|
||||
distinct_id=posthog_user_id,
|
||||
event='system_prompt_set',
|
||||
properties={
|
||||
'conversation_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'original_user_id': user_id,
|
||||
'is_feature_env': IS_FEATURE_ENV,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:posthog_capture:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
# Continue execution as this is not critical
|
||||
|
||||
logger.info(
|
||||
'posthog_capture',
|
||||
extra={
|
||||
'event': 'system_prompt_set',
|
||||
'posthog_user_id': posthog_user_id,
|
||||
'is_feature_env': IS_FEATURE_ENV,
|
||||
'conversation_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
},
|
||||
)
|
||||
|
||||
return enabled_variant
|
||||
|
||||
|
||||
def handle_system_prompt_experiment(
|
||||
user_id, conversation_id, config: OpenHandsConfig
|
||||
) -> OpenHandsConfig:
|
||||
"""
|
||||
Handle the system prompt experiment for OpenHands config.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
conversation_id: The conversation ID
|
||||
config: The OpenHands configuration
|
||||
|
||||
Returns:
|
||||
Modified OpenHands configuration
|
||||
"""
|
||||
enabled_variant = _get_system_prompt_variant(user_id, conversation_id)
|
||||
|
||||
# If variant is None, experiment is not enabled or there was an error
|
||||
if enabled_variant is None:
|
||||
return config
|
||||
|
||||
# Deep copy the config to avoid modifying the original
|
||||
modified_config = copy.deepcopy(config)
|
||||
|
||||
# Set the system prompt filename based on the variant
|
||||
if enabled_variant == 'control':
|
||||
# Use the long-horizon system prompt for the control variant
|
||||
agent_config = modified_config.get_agent_config(modified_config.default_agent)
|
||||
agent_config.system_prompt_filename = 'system_prompt_long_horizon.j2'
|
||||
agent_config.enable_plan_mode = True
|
||||
elif enabled_variant == 'interactive':
|
||||
modified_config.get_agent_config(
|
||||
modified_config.default_agent
|
||||
).system_prompt_filename = 'system_prompt_interactive.j2'
|
||||
elif enabled_variant == 'no_tools':
|
||||
modified_config.get_agent_config(
|
||||
modified_config.default_agent
|
||||
).system_prompt_filename = 'system_prompt.j2'
|
||||
else:
|
||||
logger.error(
|
||||
'system_prompt_experiment:unknown_variant',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'convo_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'reason': 'no explicit mapping; returning original config',
|
||||
},
|
||||
)
|
||||
return config
|
||||
|
||||
# Log which prompt is being used
|
||||
logger.info(
|
||||
'system_prompt_experiment:prompt_selected',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'convo_id': conversation_id,
|
||||
'system_prompt_filename': modified_config.get_agent_config(
|
||||
modified_config.default_agent
|
||||
).system_prompt_filename,
|
||||
'variant': enabled_variant,
|
||||
},
|
||||
)
|
||||
|
||||
return modified_config
|
||||
enterprise/experiments/experiment_versions/_003_llm_claude4_vs_gpt5_experiment.py (new file)
@@ -0,0 +1,132 @@
|
||||
"""
|
||||
LiteLLM model experiment handler.
|
||||
|
||||
This module contains the handler for the LiteLLM model experiment.
|
||||
"""
|
||||
|
||||
import posthog
|
||||
from experiments.constants import EXPERIMENT_CLAUDE4_VS_GPT5
|
||||
from server.constants import (
|
||||
IS_FEATURE_ENV,
|
||||
build_litellm_proxy_model_path,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from storage.experiment_assignment_store import ExperimentAssignmentStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
def _get_model_variant(user_id, conversation_id) -> str | None:
|
||||
if not EXPERIMENT_CLAUDE4_VS_GPT5:
|
||||
logger.info(
|
||||
'experiment_manager:ab_testing:skipped',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'reason': 'experiment_not_enabled',
|
||||
'experiment': EXPERIMENT_CLAUDE4_VS_GPT5,
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
enabled_variant = posthog.get_feature_flag(
|
||||
EXPERIMENT_CLAUDE4_VS_GPT5, conversation_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:get_feature_flag:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_CLAUDE4_VS_GPT5,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
# Store the experiment assignment in the database
|
||||
try:
|
||||
experiment_store = ExperimentAssignmentStore()
|
||||
experiment_store.update_experiment_variant(
|
||||
conversation_id=conversation_id,
|
||||
experiment_name='claude4_vs_gpt5_experiment',
|
||||
variant=enabled_variant,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:store_assignment:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_CLAUDE4_VS_GPT5,
|
||||
'variant': enabled_variant,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
# Fail the experiment if we cannot track the splits - results would not be explainable
|
||||
return None
|
||||
|
||||
# Log the experiment event
|
||||
# If this is a feature environment, add "FEATURE_" prefix to user_id for PostHog
|
||||
posthog_user_id = f'FEATURE_{user_id}' if IS_FEATURE_ENV else user_id
|
||||
|
||||
try:
|
||||
posthog.capture(
|
||||
distinct_id=posthog_user_id,
|
||||
event='claude4_or_gpt5_set',
|
||||
properties={
|
||||
'conversation_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'original_user_id': user_id,
|
||||
'is_feature_env': IS_FEATURE_ENV,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:posthog_capture:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_CLAUDE4_VS_GPT5,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
# Continue execution as this is not critical
|
||||
|
||||
logger.info(
|
||||
'posthog_capture',
|
||||
extra={
|
||||
'event': 'claude4_or_gpt5_set',
|
||||
'posthog_user_id': posthog_user_id,
|
||||
'is_feature_env': IS_FEATURE_ENV,
|
||||
'conversation_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
},
|
||||
)
|
||||
|
||||
return enabled_variant
|
||||
|
||||
|
||||
def handle_claude4_vs_gpt5_experiment(user_id, conversation_id, conversation_settings):
|
||||
"""
|
||||
Handle the LiteLLM model experiment.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
conversation_id: The conversation ID
|
||||
conversation_settings: The conversation settings
|
||||
|
||||
Returns:
|
||||
Modified conversation settings
|
||||
"""
|
||||
|
||||
enabled_variant = _get_model_variant(user_id, conversation_id)
|
||||
|
||||
if not enabled_variant:
|
||||
return None
|
||||
|
||||
# Set the model based on the feature flag variant
|
||||
if enabled_variant == 'gpt5':
|
||||
model = build_litellm_proxy_model_path('gpt-5-2025-08-07')
|
||||
conversation_settings.llm_model = model
|
||||
else:
|
||||
conversation_settings.llm_model = get_default_litellm_model()
|
||||
|
||||
return conversation_settings
|
||||
enterprise/experiments/experiment_versions/_004_condenser_max_step_experiment.py (new file)
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
Condenser max step experiment handler.
|
||||
|
||||
This module contains the handler for the condenser max step experiment that tests
|
||||
different max_size values for the condenser configuration.
|
||||
"""
|
||||
|
||||
import posthog
|
||||
from experiments.constants import EXPERIMENT_CONDENSER_MAX_STEP
|
||||
from server.constants import IS_FEATURE_ENV
|
||||
from storage.experiment_assignment_store import ExperimentAssignmentStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
def _get_condenser_max_step_variant(user_id, conversation_id):
|
||||
"""
|
||||
Get the condenser max step variant for the experiment.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
conversation_id: The conversation ID
|
||||
|
||||
Returns:
|
||||
str or None: The PostHog variant name or None if experiment is not enabled or error occurs
|
||||
"""
|
||||
# No-op if the specific experiment is not enabled
|
||||
if not EXPERIMENT_CONDENSER_MAX_STEP:
|
||||
logger.info(
|
||||
'experiment_manager_004:ab_testing:skipped',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'reason': 'experiment_not_enabled',
|
||||
'experiment': EXPERIMENT_CONDENSER_MAX_STEP,
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
# Use experiment name as the flag key
|
||||
try:
|
||||
enabled_variant = posthog.get_feature_flag(
|
||||
EXPERIMENT_CONDENSER_MAX_STEP, conversation_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:get_feature_flag:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_CONDENSER_MAX_STEP,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
# Store the experiment assignment in the database
|
||||
try:
|
||||
experiment_store = ExperimentAssignmentStore()
|
||||
experiment_store.update_experiment_variant(
|
||||
conversation_id=conversation_id,
|
||||
experiment_name='condenser_max_step_experiment',
|
||||
variant=enabled_variant,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:store_assignment:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_CONDENSER_MAX_STEP,
|
||||
'variant': enabled_variant,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
# Fail the experiment if we cannot track the splits - results would not be explainable
|
||||
return None
|
||||
|
||||
# Log the experiment event
|
||||
# If this is a feature environment, add "FEATURE_" prefix to user_id for PostHog
|
||||
posthog_user_id = f'FEATURE_{user_id}' if IS_FEATURE_ENV else user_id
|
||||
|
||||
try:
|
||||
posthog.capture(
|
||||
distinct_id=posthog_user_id,
|
||||
event='condenser_max_step_set',
|
||||
properties={
|
||||
'conversation_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'original_user_id': user_id,
|
||||
'is_feature_env': IS_FEATURE_ENV,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:posthog_capture:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_CONDENSER_MAX_STEP,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
# Continue execution as this is not critical
|
||||
|
||||
logger.info(
|
||||
'posthog_capture',
|
||||
extra={
|
||||
'event': 'condenser_max_step_set',
|
||||
'posthog_user_id': posthog_user_id,
|
||||
'is_feature_env': IS_FEATURE_ENV,
|
||||
'conversation_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
},
|
||||
)
|
||||
|
||||
return enabled_variant
|
||||
|
||||
|
||||
def handle_condenser_max_step_experiment(
|
||||
user_id: str, conversation_id: str, conversation_settings
|
||||
):
|
||||
"""
|
||||
Handle the condenser max step experiment for conversation settings.
|
||||
|
||||
We should not modify persistent user settings. Instead, apply the experiment
|
||||
variant to the conversation's in-memory settings object for this session only.
|
||||
|
||||
Variants:
|
||||
- control -> condenser_max_size = 120
|
||||
- treatment -> condenser_max_size = 80
|
||||
|
||||
Returns the (potentially) modified conversation_settings.
|
||||
"""
|
||||
|
||||
enabled_variant = _get_condenser_max_step_variant(user_id, conversation_id)
|
||||
|
||||
if enabled_variant is None:
|
||||
return conversation_settings
|
||||
|
||||
if enabled_variant == 'control':
|
||||
condenser_max_size = 120
|
||||
elif enabled_variant == 'treatment':
|
||||
condenser_max_size = 80
|
||||
else:
|
||||
logger.error(
|
||||
'condenser_max_step_experiment:unknown_variant',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'convo_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'reason': 'unknown variant; returning original conversation settings',
|
||||
},
|
||||
)
|
||||
return conversation_settings
|
||||
|
||||
try:
|
||||
# Apply the variant to this conversation only; do not persist to DB.
|
||||
# Not all OpenHands versions expose `condenser_max_size` on settings.
|
||||
if hasattr(conversation_settings, 'condenser_max_size'):
|
||||
conversation_settings.condenser_max_size = condenser_max_size
|
||||
logger.info(
|
||||
'condenser_max_step_experiment:conversation_settings_applied',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'convo_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'condenser_max_size': condenser_max_size,
|
||||
},
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
'condenser_max_step_experiment:field_missing_on_settings',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'convo_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'reason': 'condenser_max_size not present on ConversationInitData',
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'condenser_max_step_experiment:apply_failed',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'convo_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return conversation_settings
|
||||
|
||||
return conversation_settings
|
||||
enterprise/experiments/experiment_versions/__init__.py (new file, 25 lines)
@@ -0,0 +1,25 @@
|
||||
"""
|
||||
Experiment versions package.
|
||||
|
||||
This package contains handlers for different experiment versions.
|
||||
"""
|
||||
|
||||
from experiments.experiment_versions._001_litellm_default_model_experiment import (
|
||||
handle_litellm_default_model_experiment,
|
||||
)
|
||||
from experiments.experiment_versions._002_system_prompt_experiment import (
|
||||
handle_system_prompt_experiment,
|
||||
)
|
||||
from experiments.experiment_versions._003_llm_claude4_vs_gpt5_experiment import (
|
||||
handle_claude4_vs_gpt5_experiment,
|
||||
)
|
||||
from experiments.experiment_versions._004_condenser_max_step_experiment import (
|
||||
handle_condenser_max_step_experiment,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'handle_litellm_default_model_experiment',
|
||||
'handle_system_prompt_experiment',
|
||||
'handle_claude4_vs_gpt5_experiment',
|
||||
'handle_condenser_max_step_experiment',
|
||||
]
|
||||
enterprise/integrations/__init__.py (new empty file)
enterprise/integrations/bitbucket/bitbucket_service.py (new file, 70 lines)
@@ -0,0 +1,70 @@
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.bitbucket.bitbucket_service import BitBucketService
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
|
||||
|
||||
class SaaSBitBucketService(BitBucketService):
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str | None = None,
|
||||
external_auth_token: SecretStr | None = None,
|
||||
external_auth_id: str | None = None,
|
||||
token: SecretStr | None = None,
|
||||
external_token_manager: bool = False,
|
||||
base_domain: str | None = None,
|
||||
):
|
||||
logger.info(
|
||||
f'SaaSBitBucketService created with user_id {user_id}, external_auth_id {external_auth_id}, external_auth_token {'set' if external_auth_token else 'None'}, bitbucket_token {'set' if token else 'None'}, external_token_manager {external_token_manager}'
|
||||
)
|
||||
super().__init__(
|
||||
user_id=user_id,
|
||||
external_auth_token=external_auth_token,
|
||||
external_auth_id=external_auth_id,
|
||||
token=token,
|
||||
external_token_manager=external_token_manager,
|
||||
base_domain=base_domain,
|
||||
)
|
||||
|
||||
self.external_auth_token = external_auth_token
|
||||
self.external_auth_id = external_auth_id
|
||||
self.token_manager = TokenManager(external=external_token_manager)
|
||||
|
||||
async def get_latest_token(self) -> SecretStr | None:
|
||||
bitbucket_token = None
|
||||
if self.external_auth_token:
|
||||
bitbucket_token = SecretStr(
|
||||
await self.token_manager.get_idp_token(
|
||||
self.external_auth_token.get_secret_value(),
|
||||
idp=ProviderType.BITBUCKET,
|
||||
)
|
||||
)
|
||||
logger.debug(
|
||||
f'Got BitBucket token {bitbucket_token} from access token: {self.external_auth_token}'
|
||||
)
|
||||
elif self.external_auth_id:
|
||||
offline_token = await self.token_manager.load_offline_token(
|
||||
self.external_auth_id
|
||||
)
|
||||
bitbucket_token = SecretStr(
|
||||
await self.token_manager.get_idp_token_from_offline_token(
|
||||
offline_token, ProviderType.BITBUCKET
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
f'Got BitBucket token {bitbucket_token.get_secret_value()} from external auth user ID: {self.external_auth_id}'
|
||||
)
|
||||
elif self.user_id:
|
||||
bitbucket_token = SecretStr(
|
||||
await self.token_manager.get_idp_token_from_idp_user_id(
|
||||
self.user_id, ProviderType.BITBUCKET
|
||||
)
|
||||
)
|
||||
logger.debug(
|
||||
f'Got BitBucket token {bitbucket_token} from user ID: {self.user_id}'
|
||||
)
|
||||
else:
|
||||
logger.warning('external_auth_token and user_id not set!')
|
||||
return bitbucket_token
|
||||
enterprise/integrations/github/data_collector.py (new file, 692 lines)
@@ -0,0 +1,692 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from github import Github, GithubIntegration
|
||||
from integrations.github.github_view import (
|
||||
GithubIssue,
|
||||
)
|
||||
from integrations.github.queries import PR_QUERY_BY_NODE_ID
|
||||
from integrations.models import Message
|
||||
from integrations.types import PRStatus, ResolverViewInterface
|
||||
from integrations.utils import HOST
|
||||
from pydantic import SecretStr
|
||||
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
from storage.openhands_pr import OpenhandsPR
|
||||
from storage.openhands_pr_store import OpenhandsPRStore
|
||||
|
||||
from openhands.core.config import load_openhands_config
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.storage import get_file_store
|
||||
from openhands.storage.locations import get_conversation_dir
|
||||
|
||||
config = load_openhands_config()
|
||||
file_store = get_file_store(config.file_store, config.file_store_path)
|
||||
|
||||
|
||||
COLLECT_GITHUB_INTERACTIONS = (
|
||||
os.getenv('COLLECT_GITHUB_INTERACTIONS', 'false') == 'true'
|
||||
)
|
||||
|
||||
|
||||
class TriggerType(str, Enum):
|
||||
ISSUE_LABEL = 'issue-label'
|
||||
ISSUE_COMMENT = 'issue-coment'
|
||||
PR_COMMENT_MACRO = 'label'
|
||||
INLINE_PR_COMMENT_MACRO = 'inline-label'
|
||||
|
||||
|
||||
class GitHubDataCollector:
|
||||
"""
|
||||
Saves data on Cloud Resolver Interactions
|
||||
|
||||
1. We always save
|
||||
- Resolver trigger (comment or label)
|
||||
- Metadata (who started the job, repo name, issue number)
|
||||
|
||||
2. We save data for the type of interaction
|
||||
a. For labelled issues, we save
|
||||
- {conversation_dir}/{conversation_id}/github_data/issue__{repo_name}_{issue_number}.json
|
||||
- issue number
|
||||
- trigger
|
||||
- metadata
|
||||
- body
|
||||
- title
|
||||
- comments
|
||||
|
||||
- {conversation_dir}/{conversation_id}/github_data/pr__{repo_name}_{pr_number}.json
|
||||
- pr_number
|
||||
- metadata
|
||||
- body
|
||||
- title
|
||||
- commits/authors
|
||||
|
||||
3. For all PRs that were opened with the resolver, we save
|
||||
- github_data/prs/{repo_name}_{pr_number}/data.json
|
||||
- pr_number
|
||||
- title
|
||||
- body
|
||||
- commits/authors
|
||||
- code diffs
|
||||
- merge status (either merged/closed)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.file_store = file_store
|
||||
self.issues_path = 'github_data/issue-{}-{}/data.json'
|
||||
self.matching_pr_path = 'github_data/pr-{}-{}/data.json'
|
||||
# self.full_saved_pr_path = 'github_data/prs/{}-{}/data.json'
|
||||
self.full_saved_pr_path = 'prs/github/{}-{}/data.json'
|
||||
self.github_integration = GithubIntegration(
|
||||
GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
)
|
||||
self.conversation_id = None
|
||||
|
||||
async def _get_repo_node_id(self, repo_id: str, gh_client) -> str:
|
||||
"""
|
||||
Get the new GitHub GraphQL node ID for a repository using the GitHub client.
|
||||
|
||||
Args:
|
||||
repo_id: Numeric repository ID as string (e.g., "123456789")
|
||||
gh_client: SaaSGitHubService client with authentication
|
||||
|
||||
Returns:
|
||||
New format node ID for GraphQL queries (e.g., "R_kgDOLfkiww")
|
||||
"""
|
||||
try:
|
||||
return await gh_client.get_repository_node_id(repo_id)
|
||||
except Exception:
|
||||
# Fallback to old format if REST API fails
|
||||
node_string = f'010:Repository{repo_id}'
|
||||
return base64.b64encode(node_string.encode()).decode()
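A standalone sketch (not part of the commit) of the legacy node-ID fallback used above; the repository id is made up.

import base64

def legacy_repo_node_id(repo_id: str) -> str:
    # Base64 of the old-style '010:Repository<id>' string, mirroring the fallback above.
    return base64.b64encode(f'010:Repository{repo_id}'.encode()).decode()

print(legacy_repo_node_id('123456789'))  # -> 'MDEwOlJlcG9zaXRvcnkxMjM0NTY3ODk='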
|
||||
|
||||
def _create_file_name(
|
||||
self, path: str, repo_id: str, number: int, conversation_id: str | None
|
||||
):
|
||||
suffix = path.format(repo_id, number)
|
||||
|
||||
if conversation_id:
|
||||
return f'{get_conversation_dir(conversation_id)}{suffix}'
|
||||
|
||||
return suffix
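A minimal standalone sketch of how the path templates above expand; the repo id, issue number, and conversation directory are hypothetical.

issues_path = 'github_data/issue-{}-{}/data.json'  # same template as self.issues_path

def build_issue_path(repo_id: str, number: int, conversation_dir: str | None) -> str:
    suffix = issues_path.format(repo_id, number)
    return f'{conversation_dir}{suffix}' if conversation_dir else suffix

print(build_issue_path('123', 42, None))                  # github_data/issue-123-42/data.json
print(build_issue_path('123', 42, 'conversations/abc/'))  # conversations/abc/github_data/issue-123-42/data.json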
|
||||
|
||||
def _get_installation_access_token(self, installation_id: str) -> str:
|
||||
token_data = self.github_integration.get_access_token(
|
||||
installation_id # type: ignore[arg-type]
|
||||
)
|
||||
return token_data.token
|
||||
|
||||
def _check_openhands_author(self, name, login) -> bool:
|
||||
return (
|
||||
name == 'openhands'
|
||||
or login == 'openhands'
|
||||
or login == 'openhands-agent'
|
||||
or login == 'openhands-ai'
|
||||
or login == 'openhands-staging'
|
||||
or login == 'openhands-exp'
|
||||
or (login and 'openhands' in login.lower())
|
||||
)
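A standalone restatement of the matcher above with a couple of sample checks; the logins are invented.

def is_openhands_author(name: str, login: str) -> bool:
    # Same rule as _check_openhands_author: exact name, known bot logins,
    # or any login containing 'openhands'.
    known_logins = {'openhands', 'openhands-agent', 'openhands-ai', 'openhands-staging', 'openhands-exp'}
    return name == 'openhands' or login in known_logins or (bool(login) and 'openhands' in login.lower())

print(is_openhands_author('', 'openhands-agent'))  # True
print(is_openhands_author('', 'octocat'))          # False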
|
||||
|
||||
def _get_issue_comments(
|
||||
self, installation_id: str, repo_name: str, issue_number: int, conversation_id
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Retrieve all comments from an issue until a comment with conversation_id is found
|
||||
"""
|
||||
|
||||
try:
|
||||
installation_token = self._get_installation_access_token(installation_id)
|
||||
|
||||
with Github(installation_token) as github_client:
|
||||
repo = github_client.get_repo(repo_name)
|
||||
issue = repo.get_issue(issue_number)
|
||||
comments = []
|
||||
|
||||
for comment in issue.get_comments():
|
||||
comment_data = {
|
||||
'id': comment.id,
|
||||
'body': comment.body,
|
||||
'created_at': comment.created_at.isoformat(),
|
||||
'user': comment.user.login,
|
||||
}
|
||||
|
||||
# If we find a comment containing conversation_id, stop collecting comments
|
||||
if conversation_id in comment.body:
|
||||
break
|
||||
|
||||
comments.append(comment_data)
|
||||
|
||||
return comments
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def _save_data(self, path: str, data: dict[str, Any]):
|
||||
"""Save data to a path"""
|
||||
self.file_store.write(path, json.dumps(data))
|
||||
|
||||
def _save_issue(
|
||||
self,
|
||||
github_view: GithubIssue,
|
||||
trigger_type: TriggerType,
|
||||
) -> None:
|
||||
"""
|
||||
Save issue data when it's labeled with openhands
|
||||
|
||||
1. Save under {conversation_dir}/{conversation_id}/github_data/issue_{issue_number}.json
|
||||
2. Save issue snapshot (title, body, comments)
|
||||
3. Save trigger type (label)
|
||||
4. Save PR opened (if exists, this information comes later when agent has finished its task)
|
||||
- Save commit shas
|
||||
- Save author info
|
||||
5. Was PR merged or closed
|
||||
"""
|
||||
|
||||
conversation_id = github_view.conversation_id
|
||||
|
||||
if not conversation_id:
|
||||
return
|
||||
|
||||
issue_number = github_view.issue_number
|
||||
file_name = self._create_file_name(
|
||||
path=self.issues_path,
|
||||
repo_id=github_view.full_repo_name,
|
||||
number=issue_number,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
        payload_data = github_view.raw_payload.message.get('payload', {})
        issue_details = payload_data.get('issue', {})
        is_repo_private = payload_data.get('repository', {}).get('private', 'true')
        title = issue_details.get('title', '')
        body = issue_details.get('body', '')
|
||||
|
||||
# Get comments for the issue
|
||||
comments = self._get_issue_comments(
|
||||
github_view.installation_id,
|
||||
github_view.full_repo_name,
|
||||
issue_number,
|
||||
conversation_id,
|
||||
)
|
||||
|
||||
data = {
|
||||
'trigger': trigger_type,
|
||||
'metadata': {
|
||||
'user': github_view.user_info.username,
|
||||
'repo_name': github_view.full_repo_name,
|
||||
'is_repo_private': is_repo_private,
|
||||
'number': issue_number,
|
||||
},
|
||||
'contents': {
|
||||
'title': title,
|
||||
'body': body,
|
||||
'comments': comments,
|
||||
},
|
||||
}
|
||||
|
||||
self._save_data(file_name, data)
|
||||
logger.info(
|
||||
f'[Github]: Saved issue #{issue_number} for {github_view.full_repo_name}'
|
||||
)
|
||||
|
||||
def _get_pr_commits(self, installation_id: str, repo_name: str, pr_number: int):
|
||||
commits = []
|
||||
installation_token = self._get_installation_access_token(installation_id)
|
||||
with Github(installation_token) as github_client:
|
||||
repo = github_client.get_repo(repo_name)
|
||||
pr = repo.get_pull(pr_number)
|
||||
|
||||
for commit in pr.get_commits():
|
||||
commit_data = {
|
||||
'sha': commit.sha,
|
||||
'authors': commit.author.login if commit.author else None,
|
||||
'committed_date': commit.commit.committer.date.isoformat()
|
||||
if commit.commit and commit.commit.committer
|
||||
else None,
|
||||
}
|
||||
commits.append(commit_data)
|
||||
|
||||
return commits
|
||||
|
||||
def _extract_repo_metadata(self, repo_data: dict) -> dict:
|
||||
"""Extract repository metadata from GraphQL response"""
|
||||
return {
|
||||
'name': repo_data.get('name'),
|
||||
'owner': repo_data.get('owner', {}).get('login'),
|
||||
'languages': [
|
||||
lang['name'] for lang in repo_data.get('languages', {}).get('nodes', [])
|
||||
],
|
||||
}
|
||||
|
||||
def _process_commits_page(self, pr_data: dict, commits: list) -> None:
|
||||
"""Process commits from a single GraphQL page"""
|
||||
commit_nodes = pr_data.get('commits', {}).get('nodes', [])
|
||||
for commit_node in commit_nodes:
|
||||
commit = commit_node['commit']
|
||||
author_info = commit.get('author', {})
|
||||
commit_data = {
|
||||
'sha': commit['oid'],
|
||||
'message': commit['message'],
|
||||
'committed_date': commit.get('committedDate'),
|
||||
'author': {
|
||||
'name': author_info.get('name'),
|
||||
'email': author_info.get('email'),
|
||||
'github_login': author_info.get('user', {}).get('login'),
|
||||
},
|
||||
'stats': {
|
||||
'additions': commit.get('additions', 0),
|
||||
'deletions': commit.get('deletions', 0),
|
||||
'changed_files': commit.get('changedFiles', 0),
|
||||
},
|
||||
}
|
||||
commits.append(commit_data)
|
||||
|
||||
def _process_pr_comments_page(self, pr_data: dict, pr_comments: list) -> None:
|
||||
"""Process PR comments from a single GraphQL page"""
|
||||
comment_nodes = pr_data.get('comments', {}).get('nodes', [])
|
||||
for comment in comment_nodes:
|
||||
comment_data = {
|
||||
'author': comment.get('author', {}).get('login'),
|
||||
'body': comment.get('body'),
|
||||
'created_at': comment.get('createdAt'),
|
||||
'type': 'pr_comment',
|
||||
}
|
||||
pr_comments.append(comment_data)
|
||||
|
||||
def _process_review_comments_page(
|
||||
self, pr_data: dict, review_comments: list
|
||||
) -> None:
|
||||
"""Process reviews and review comments from a single GraphQL page"""
|
||||
review_nodes = pr_data.get('reviews', {}).get('nodes', [])
|
||||
for review in review_nodes:
|
||||
# Add the review itself if it has a body
|
||||
if review.get('body', '').strip():
|
||||
review_data = {
|
||||
'author': review.get('author', {}).get('login'),
|
||||
'body': review.get('body'),
|
||||
'created_at': review.get('createdAt'),
|
||||
'state': review.get('state'),
|
||||
'type': 'review',
|
||||
}
|
||||
review_comments.append(review_data)
|
||||
|
||||
# Add individual review comments
|
||||
review_comment_nodes = review.get('comments', {}).get('nodes', [])
|
||||
for review_comment in review_comment_nodes:
|
||||
review_comment_data = {
|
||||
'author': review_comment.get('author', {}).get('login'),
|
||||
'body': review_comment.get('body'),
|
||||
'created_at': review_comment.get('createdAt'),
|
||||
'type': 'review_comment',
|
||||
}
|
||||
review_comments.append(review_comment_data)
|
||||
|
||||
def _count_openhands_activity(
|
||||
self, commits: list, review_comments: list, pr_comments: list
|
||||
) -> tuple[int, int, int]:
|
||||
"""Count OpenHands commits, review comments, and general PR comments"""
|
||||
openhands_commit_count = 0
|
||||
openhands_review_comment_count = 0
|
||||
openhands_general_comment_count = 0
|
||||
|
||||
# Count commits by OpenHands (check both name and login)
|
||||
for commit in commits:
|
||||
author = commit.get('author', {})
|
||||
author_name = author.get('name', '').lower()
|
||||
author_login = (
|
||||
author.get('github_login', '').lower()
|
||||
if author.get('github_login')
|
||||
else ''
|
||||
)
|
||||
|
||||
if self._check_openhands_author(author_name, author_login):
|
||||
openhands_commit_count += 1
|
||||
|
||||
# Count review comments by OpenHands
|
||||
for review_comment in review_comments:
|
||||
author_login = (
|
||||
review_comment.get('author', '').lower()
|
||||
if review_comment.get('author')
|
||||
else ''
|
||||
)
|
||||
author_name = '' # Initialize to avoid reference before assignment
|
||||
if self._check_openhands_author(author_name, author_login):
|
||||
openhands_review_comment_count += 1
|
||||
|
||||
# Count general PR comments by OpenHands
|
||||
for pr_comment in pr_comments:
|
||||
author_login = (
|
||||
pr_comment.get('author', '').lower() if pr_comment.get('author') else ''
|
||||
)
|
||||
author_name = '' # Initialize to avoid reference before assignment
|
||||
if self._check_openhands_author(author_name, author_login):
|
||||
openhands_general_comment_count += 1
|
||||
|
||||
return (
|
||||
openhands_commit_count,
|
||||
openhands_review_comment_count,
|
||||
openhands_general_comment_count,
|
||||
)
|
||||
|
||||
def _build_final_data_structure(
|
||||
self,
|
||||
repo_data: dict,
|
||||
pr_data: dict,
|
||||
commits: list,
|
||||
pr_comments: list,
|
||||
review_comments: list,
|
||||
openhands_commit_count: int,
|
||||
openhands_review_comment_count: int,
|
||||
openhands_general_comment_count: int = 0,
|
||||
) -> dict:
|
||||
"""Build the final data structure for JSON storage"""
|
||||
|
||||
is_merged = pr_data['merged']
|
||||
merged_by = None
|
||||
merge_commit_sha = None
|
||||
if is_merged:
|
||||
merged_by = pr_data.get('mergedBy', {}).get('login')
|
||||
merge_commit_sha = pr_data.get('mergeCommit', {}).get('oid')
|
||||
|
||||
return {
|
||||
'repo_metadata': self._extract_repo_metadata(repo_data),
|
||||
'pr_metadata': {
|
||||
'username': pr_data.get('author', {}).get('login'),
|
||||
'number': pr_data['number'],
|
||||
'title': pr_data['title'],
|
||||
'body': pr_data['body'],
|
||||
'comments': pr_comments,
|
||||
},
|
||||
'commits': commits,
|
||||
'review_comments': review_comments,
|
||||
'merge_status': {
|
||||
'merged': pr_data['merged'],
|
||||
'merged_by': merged_by,
|
||||
'state': pr_data['state'],
|
||||
'merge_commit_sha': merge_commit_sha,
|
||||
},
|
||||
'openhands_stats': {
|
||||
'num_commits': openhands_commit_count,
|
||||
'num_review_comments': openhands_review_comment_count,
|
||||
'num_general_comments': openhands_general_comment_count,
|
||||
'helped_author': openhands_commit_count > 0,
|
||||
},
|
||||
}
|
||||
|
||||
async def save_full_pr(self, openhands_pr: OpenhandsPR) -> None:
|
||||
"""
|
||||
Save PR information including metadata and commit details using GraphQL
|
||||
|
||||
Saves:
|
||||
- Repo metadata (repo name, languages, contributors)
|
||||
- PR metadata (number, title, body, author, comments)
|
||||
- Commit information (sha, authors, message, stats)
|
||||
- Merge status
|
||||
- Num openhands commits
|
||||
- Num openhands review comments
|
||||
"""
|
||||
pr_number = openhands_pr.pr_number
|
||||
installation_id = openhands_pr.installation_id
|
||||
repo_id = openhands_pr.repo_id
|
||||
|
||||
# Get installation token and create Github client
|
||||
# This will fail if the user decides to revoke OpenHands' access to their repo
|
||||
# In this case, we will simply return when the exception occurs
|
||||
# This will not lead to infinite loops when processing PRs as we log number of attempts and cap max attempts independently from this
|
||||
try:
|
||||
installation_token = self._get_installation_access_token(installation_id)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f'Failed to generate token for {openhands_pr.repo_name}: {e}'
|
||||
)
|
||||
return
|
||||
|
||||
gh_client = GithubServiceImpl(token=SecretStr(installation_token))
|
||||
|
||||
# Get the new format GraphQL node ID
|
||||
node_id = await self._get_repo_node_id(repo_id, gh_client)
|
||||
|
||||
# Initialize data structures
|
||||
commits: list[dict] = []
|
||||
pr_comments: list[dict] = []
|
||||
review_comments: list[dict] = []
|
||||
pr_data = None
|
||||
repo_data = None
|
||||
|
||||
# Pagination cursors
|
||||
commits_after = None
|
||||
comments_after = None
|
||||
reviews_after = None
|
||||
|
||||
# Fetch all data with pagination
|
||||
while True:
|
||||
variables = {
|
||||
'nodeId': node_id,
|
||||
'pr_number': pr_number,
|
||||
'commits_after': commits_after,
|
||||
'comments_after': comments_after,
|
||||
'reviews_after': reviews_after,
|
||||
}
|
||||
|
||||
try:
|
||||
result = await gh_client.execute_graphql_query(
|
||||
PR_QUERY_BY_NODE_ID, variables
|
||||
)
|
||||
if not result.get('data', {}).get('node', {}).get('pullRequest'):
|
||||
break
|
||||
|
||||
pr_data = result['data']['node']['pullRequest']
|
||||
repo_data = result['data']['node']
|
||||
|
||||
# Process data from this page using modular methods
|
||||
self._process_commits_page(pr_data, commits)
|
||||
self._process_pr_comments_page(pr_data, pr_comments)
|
||||
self._process_review_comments_page(pr_data, review_comments)
|
||||
|
||||
# Check pagination for all three types
|
||||
has_more_commits = (
|
||||
pr_data.get('commits', {})
|
||||
.get('pageInfo', {})
|
||||
.get('hasNextPage', False)
|
||||
)
|
||||
has_more_comments = (
|
||||
pr_data.get('comments', {})
|
||||
.get('pageInfo', {})
|
||||
.get('hasNextPage', False)
|
||||
)
|
||||
has_more_reviews = (
|
||||
pr_data.get('reviews', {})
|
||||
.get('pageInfo', {})
|
||||
.get('hasNextPage', False)
|
||||
)
|
||||
|
||||
# Update cursors
|
||||
if has_more_commits:
|
||||
commits_after = (
|
||||
pr_data.get('commits', {}).get('pageInfo', {}).get('endCursor')
|
||||
)
|
||||
else:
|
||||
commits_after = None
|
||||
|
||||
if has_more_comments:
|
||||
comments_after = (
|
||||
pr_data.get('comments', {}).get('pageInfo', {}).get('endCursor')
|
||||
)
|
||||
else:
|
||||
comments_after = None
|
||||
|
||||
if has_more_reviews:
|
||||
reviews_after = (
|
||||
pr_data.get('reviews', {}).get('pageInfo', {}).get('endCursor')
|
||||
)
|
||||
else:
|
||||
reviews_after = None
|
||||
|
||||
# Continue if there's more data to fetch
|
||||
if not (has_more_commits or has_more_comments or has_more_reviews):
|
||||
break
|
||||
|
||||
except Exception:
|
||||
logger.warning('Error fetching PR data', exc_info=True)
|
||||
return
|
||||
|
||||
if not pr_data or not repo_data:
|
||||
return
|
||||
|
||||
# Count OpenHands activity using modular method
|
||||
(
|
||||
openhands_commit_count,
|
||||
openhands_review_comment_count,
|
||||
openhands_general_comment_count,
|
||||
) = self._count_openhands_activity(commits, review_comments, pr_comments)
|
||||
|
||||
logger.info(
|
||||
f'[Github]: PR #{pr_number} - OpenHands commits: {openhands_commit_count}, review comments: {openhands_review_comment_count}, general comments: {openhands_general_comment_count}'
|
||||
)
|
||||
logger.info(
|
||||
f'[Github]: PR #{pr_number} - Total collected: {len(commits)} commits, {len(pr_comments)} PR comments, {len(review_comments)} review comments'
|
||||
)
|
||||
|
||||
# Build final data structure using modular method
|
||||
data = self._build_final_data_structure(
|
||||
repo_data,
|
||||
pr_data or {},
|
||||
commits,
|
||||
pr_comments,
|
||||
review_comments,
|
||||
openhands_commit_count,
|
||||
openhands_review_comment_count,
|
||||
openhands_general_comment_count,
|
||||
)
|
||||
|
||||
# Update the OpenhandsPR object with OpenHands statistics
|
||||
store = OpenhandsPRStore.get_instance()
|
||||
openhands_helped_author = openhands_commit_count > 0
|
||||
|
||||
# Update the PR with OpenHands statistics
|
||||
update_success = store.update_pr_openhands_stats(
|
||||
repo_id=repo_id,
|
||||
pr_number=pr_number,
|
||||
original_updated_at=openhands_pr.updated_at,
|
||||
openhands_helped_author=openhands_helped_author,
|
||||
num_openhands_commits=openhands_commit_count,
|
||||
num_openhands_review_comments=openhands_review_comment_count,
|
||||
num_openhands_general_comments=openhands_general_comment_count,
|
||||
)
|
||||
|
||||
if not update_success:
|
||||
logger.warning(
|
||||
f'[Github]: Failed to update OpenHands stats for PR #{pr_number} in repo {repo_id} - PR may have been modified concurrently'
|
||||
)
|
||||
|
||||
# Save to file
|
||||
file_name = self._create_file_name(
|
||||
path=self.full_saved_pr_path,
|
||||
repo_id=repo_id,
|
||||
number=pr_number,
|
||||
conversation_id=None,
|
||||
)
|
||||
self._save_data(file_name, data)
|
||||
logger.info(
|
||||
f'[Github]: Saved full PR #{pr_number} for repo {repo_id} with OpenHands stats: commits={openhands_commit_count}, reviews={openhands_review_comment_count}, general_comments={openhands_general_comment_count}, helped={openhands_helped_author}'
|
||||
)
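The loop above is an instance of the usual GraphQL cursor-pagination pattern; here is a generic, self-contained sketch of that pattern with a stand-in fetch function.

def collect_all(fetch_page):
    """Drain a cursor-paginated connection; fetch_page is a stand-in, not a real API call."""
    items, cursor = [], None
    while True:
        page = fetch_page(after=cursor)  # expected shape: {'nodes': [...], 'pageInfo': {...}}
        items.extend(page.get('nodes', []))
        info = page.get('pageInfo', {})
        if not info.get('hasNextPage', False):
            return items
        cursor = info.get('endCursor')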
|
||||
|
||||
def _check_for_conversation_url(self, body):
|
||||
conversation_pattern = re.search(
|
||||
rf'https://{HOST}/conversations/([a-zA-Z0-9-]+)(?:\s|[.,;!?)]|$)', body
|
||||
)
|
||||
if conversation_pattern:
|
||||
return conversation_pattern.group(1)
|
||||
|
||||
return None
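A small standalone example of the URL extraction above; the host value and conversation id are assumptions for illustration only.

import re

HOST = 'app.all-hands.dev'  # assumed value, only for this example
body = 'Working on it: https://app.all-hands.dev/conversations/abc-123. Updates to follow.'
match = re.search(rf'https://{HOST}/conversations/([a-zA-Z0-9-]+)(?:\s|[.,;!?)]|$)', body)
print(match.group(1) if match else None)  # -> 'abc-123'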
|
||||
|
||||
def _is_pr_closed_or_merged(self, payload):
|
||||
"""
|
||||
Check if PR was closed (regardless of conversation URL)
|
||||
"""
|
||||
action = payload.get('action', '')
|
||||
return action == 'closed' and 'pull_request' in payload
|
||||
|
||||
def _track_closed_or_merged_pr(self, payload):
|
||||
"""
|
||||
Track PR closed/merged event
|
||||
"""
|
||||
|
||||
repo_id = str(payload['repository']['id'])
|
||||
pr_number = payload['number']
|
||||
installation_id = str(payload['installation']['id'])
|
||||
private = payload['repository']['private']
|
||||
repo_name = payload['repository']['full_name']
|
||||
|
||||
pr_data = payload['pull_request']
|
||||
|
||||
# Extract PR metrics
|
||||
num_reviewers = len(pr_data.get('requested_reviewers', []))
|
||||
num_commits = pr_data.get('commits', 0)
|
||||
num_review_comments = pr_data.get('review_comments', 0)
|
||||
num_general_comments = pr_data.get('comments', 0)
|
||||
num_changed_files = pr_data.get('changed_files', 0)
|
||||
num_additions = pr_data.get('additions', 0)
|
||||
num_deletions = pr_data.get('deletions', 0)
|
||||
merged = pr_data.get('merged', False)
|
||||
|
||||
# Extract closed_at timestamp
|
||||
# Example: "closed_at":"2025-06-19T21:19:36Z"
|
||||
closed_at_str = pr_data.get('closed_at')
|
||||
created_at = pr_data.get('created_at')
|
||||
|
||||
closed_at = datetime.fromisoformat(closed_at_str.replace('Z', '+00:00'))
|
||||
|
||||
# Determine status based on whether it was merged
|
||||
status = PRStatus.MERGED if merged else PRStatus.CLOSED
|
||||
|
||||
store = OpenhandsPRStore.get_instance()
|
||||
|
||||
pr = OpenhandsPR(
|
||||
repo_name=repo_name,
|
||||
repo_id=repo_id,
|
||||
pr_number=pr_number,
|
||||
status=status,
|
||||
provider=ProviderType.GITHUB.value,
|
||||
installation_id=installation_id,
|
||||
private=private,
|
||||
num_reviewers=num_reviewers,
|
||||
num_commits=num_commits,
|
||||
num_review_comments=num_review_comments,
|
||||
num_changed_files=num_changed_files,
|
||||
num_additions=num_additions,
|
||||
num_deletions=num_deletions,
|
||||
merged=merged,
|
||||
created_at=created_at,
|
||||
closed_at=closed_at,
|
||||
# These properties will be enriched later
|
||||
openhands_helped_author=None,
|
||||
num_openhands_commits=None,
|
||||
num_openhands_review_comments=None,
|
||||
num_general_comments=num_general_comments,
|
||||
)
|
||||
|
||||
store.insert_pr(pr)
|
||||
logger.info(f'Tracked PR {status}: {repo_id}#{pr_number}')
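The closed_at handling above converts GitHub's 'Z' suffix into an explicit offset before parsing; a quick standalone check using the sample value from the comment:

from datetime import datetime

closed_at_str = '2025-06-19T21:19:36Z'
closed_at = datetime.fromisoformat(closed_at_str.replace('Z', '+00:00'))
print(closed_at.isoformat())  # -> '2025-06-19T21:19:36+00:00'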
|
||||
|
||||
def process_payload(self, message: Message):
|
||||
if not COLLECT_GITHUB_INTERACTIONS:
|
||||
return
|
||||
|
||||
raw_payload = message.message.get('payload', {})
|
||||
|
||||
if self._is_pr_closed_or_merged(raw_payload):
|
||||
self._track_closed_or_merged_pr(raw_payload)
|
||||
|
||||
async def save_data(self, github_view: ResolverViewInterface):
|
||||
if not COLLECT_GITHUB_INTERACTIONS:
|
||||
return
|
||||
|
||||
return
|
||||
|
||||
# TODO: track issue metadata in DB and save comments to filestore
enterprise/integrations/github/github_manager.py (new file, 344 lines)
@@ -0,0 +1,344 @@
|
||||
from types import MappingProxyType
|
||||
|
||||
from github import Github, GithubIntegration
|
||||
from integrations.github.data_collector import GitHubDataCollector
|
||||
from integrations.github.github_solvability import summarize_issue_solvability
|
||||
from integrations.github.github_view import (
|
||||
GithubFactory,
|
||||
GithubFailingAction,
|
||||
GithubInlinePRComment,
|
||||
GithubIssue,
|
||||
GithubIssueComment,
|
||||
GithubPRComment,
|
||||
)
|
||||
from integrations.manager import Manager
|
||||
from integrations.models import (
|
||||
Message,
|
||||
SourceType,
|
||||
)
|
||||
from integrations.types import ResolverViewInterface
|
||||
from integrations.utils import (
|
||||
CONVERSATION_URL,
|
||||
HOST_URL,
|
||||
OPENHANDS_RESOLVER_TEMPLATES_DIR,
|
||||
)
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from pydantic import SecretStr
|
||||
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.utils.conversation_callback_utils import register_callback_processor
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
|
||||
from openhands.storage.data_models.user_secrets import UserSecrets
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
|
||||
class GithubManager(Manager):
|
||||
def __init__(
|
||||
self, token_manager: TokenManager, data_collector: GitHubDataCollector
|
||||
):
|
||||
self.token_manager = token_manager
|
||||
self.data_collector = data_collector
|
||||
self.github_integration = GithubIntegration(
|
||||
GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
)
|
||||
|
||||
self.jinja_env = Environment(
|
||||
loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR + 'github')
|
||||
)
|
||||
|
||||
def _confirm_incoming_source_type(self, message: Message):
|
||||
if message.source != SourceType.GITHUB:
|
||||
raise ValueError(f'Unexpected message source {message.source}')
|
||||
|
||||
def _get_full_repo_name(self, repo_obj: dict) -> str:
|
||||
owner = repo_obj['owner']['login']
|
||||
repo_name = repo_obj['name']
|
||||
|
||||
return f'{owner}/{repo_name}'
|
||||
|
||||
def _get_installation_access_token(self, installation_id: str) -> str:
|
||||
# get_access_token is typed to only accept int, but it can handle str.
|
||||
token_data = self.github_integration.get_access_token(
|
||||
installation_id # type: ignore[arg-type]
|
||||
)
|
||||
return token_data.token
|
||||
|
||||
def _add_reaction(
|
||||
self, github_view: ResolverViewInterface, reaction: str, installation_token: str
|
||||
):
|
||||
"""Add a reaction to the GitHub issue, PR, or comment.
|
||||
|
||||
Args:
|
||||
github_view: The GitHub view object containing issue/PR/comment info
|
||||
reaction: The reaction to add (e.g. "eyes", "+1", "-1", "laugh", "confused", "heart", "hooray", "rocket")
|
||||
installation_token: GitHub installation access token for API access
|
||||
"""
|
||||
with Github(installation_token) as github_client:
|
||||
repo = github_client.get_repo(github_view.full_repo_name)
|
||||
# Add reaction based on view type
|
||||
if isinstance(github_view, GithubInlinePRComment):
|
||||
pr = repo.get_pull(github_view.issue_number)
|
||||
inline_comment = pr.get_review_comment(github_view.comment_id)
|
||||
inline_comment.create_reaction(reaction)
|
||||
|
||||
elif isinstance(github_view, (GithubIssueComment, GithubPRComment)):
|
||||
issue = repo.get_issue(github_view.issue_number)
|
||||
comment = issue.get_comment(github_view.comment_id)
|
||||
comment.create_reaction(reaction)
|
||||
else:
|
||||
issue = repo.get_issue(github_view.issue_number)
|
||||
issue.create_reaction(reaction)
|
||||
|
||||
def _user_has_write_access_to_repo(
|
||||
self, installation_id: str, full_repo_name: str, username: str
|
||||
) -> bool:
|
||||
"""Check if the user is an owner, collaborator, or member of the repository."""
|
||||
with self.github_integration.get_github_for_installation(
|
||||
installation_id, # type: ignore[arg-type]
|
||||
{},
|
||||
) as repos:
|
||||
repository = repos.get_repo(full_repo_name)
|
||||
|
||||
# Check if the user is a collaborator
|
||||
try:
|
||||
collaborator = repository.get_collaborator_permission(username)
|
||||
if collaborator in ['admin', 'write']:
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# If the above fails, check if the user is an owner or member
|
||||
org = repository.organization
|
||||
if org:
|
||||
user = org.get_members(username)
|
||||
return user is not None
|
||||
|
||||
return False
|
||||
|
||||
async def is_job_requested(self, message: Message) -> bool:
|
||||
self._confirm_incoming_source_type(message)
|
||||
|
||||
installation_id = message.message['installation']
|
||||
payload = message.message.get('payload', {})
|
||||
repo_obj = payload.get('repository')
|
||||
if not repo_obj:
|
||||
return False
|
||||
username = payload.get('sender', {}).get('login')
|
||||
repo_name = self._get_full_repo_name(repo_obj)
|
||||
|
||||
# Suggestions contain `@openhands` macro; avoid kicking off jobs for system recommendations
|
||||
if GithubFactory.is_pr_comment(
|
||||
message
|
||||
) and GithubFailingAction.unqiue_suggestions_header in payload.get(
|
||||
'comment', {}
|
||||
).get('body', ''):
|
||||
return False
|
||||
|
||||
if GithubFactory.is_eligible_for_conversation_starter(
|
||||
message
|
||||
) and self._user_has_write_access_to_repo(installation_id, repo_name, username):
|
||||
await GithubFactory.trigger_conversation_starter(message)
|
||||
|
||||
if not (
|
||||
GithubFactory.is_labeled_issue(message)
|
||||
or GithubFactory.is_issue_comment(message)
|
||||
or GithubFactory.is_pr_comment(message)
|
||||
or GithubFactory.is_inline_pr_comment(message)
|
||||
):
|
||||
return False
|
||||
|
||||
logger.info(f'[GitHub] Checking permissions for {username} in {repo_name}')
|
||||
|
||||
return self._user_has_write_access_to_repo(installation_id, repo_name, username)
|
||||
|
||||
async def receive_message(self, message: Message):
|
||||
self._confirm_incoming_source_type(message)
|
||||
try:
|
||||
await call_sync_from_async(self.data_collector.process_payload, message)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
'[Github]: Error processing payload for gh interaction', exc_info=True
|
||||
)
|
||||
|
||||
if await self.is_job_requested(message):
|
||||
github_view = await GithubFactory.create_github_view_from_payload(
|
||||
message, self.token_manager
|
||||
)
|
||||
logger.info(
|
||||
f'[GitHub] Creating job for {github_view.user_info.username} in {github_view.full_repo_name}#{github_view.issue_number}'
|
||||
)
|
||||
# Get the installation token
|
||||
installation_token = self._get_installation_access_token(
|
||||
github_view.installation_id
|
||||
)
|
||||
# Store the installation token
|
||||
self.token_manager.store_org_token(
|
||||
github_view.installation_id, installation_token
|
||||
)
|
||||
# Add eyes reaction to acknowledge we've read the request
|
||||
self._add_reaction(github_view, 'eyes', installation_token)
|
||||
await self.start_job(github_view)
|
||||
|
||||
async def send_message(self, message: Message, github_view: ResolverViewInterface):
|
||||
installation_token = self.token_manager.load_org_token(
|
||||
github_view.installation_id
|
||||
)
|
||||
if not installation_token:
|
||||
logger.warning('Missing installation token')
|
||||
return
|
||||
|
||||
outgoing_message = message.message
|
||||
|
||||
if isinstance(github_view, GithubInlinePRComment):
|
||||
with Github(installation_token) as github_client:
|
||||
repo = github_client.get_repo(github_view.full_repo_name)
|
||||
pr = repo.get_pull(github_view.issue_number)
|
||||
pr.create_review_comment_reply(
|
||||
comment_id=github_view.comment_id, body=outgoing_message
|
||||
)
|
||||
|
||||
elif (
|
||||
isinstance(github_view, GithubPRComment)
|
||||
or isinstance(github_view, GithubIssueComment)
|
||||
or isinstance(github_view, GithubIssue)
|
||||
):
|
||||
with Github(installation_token) as github_client:
|
||||
repo = github_client.get_repo(github_view.full_repo_name)
|
||||
issue = repo.get_issue(number=github_view.issue_number)
|
||||
issue.create_comment(outgoing_message)
|
||||
|
||||
else:
|
||||
logger.warning('Unsupported location')
|
||||
return
|
||||
|
||||
async def start_job(self, github_view: ResolverViewInterface):
|
||||
"""Kick off a job with openhands agent.
|
||||
|
||||
1. Get user credential
|
||||
2. Initialize new conversation with repo
|
||||
3. Save interaction data
|
||||
"""
|
||||
# Importing here prevents circular import
|
||||
from server.conversation_callback_processor.github_callback_processor import (
|
||||
GithubCallbackProcessor,
|
||||
)
|
||||
|
||||
try:
|
||||
msg_info = None
|
||||
|
||||
try:
|
||||
user_info = github_view.user_info
|
||||
logger.info(
|
||||
f'[GitHub] Starting job for user {user_info.username} (id={user_info.user_id})'
|
||||
)
|
||||
|
||||
# Create conversation
|
||||
user_token = await self.token_manager.get_idp_token_from_idp_user_id(
|
||||
str(user_info.user_id), ProviderType.GITHUB
|
||||
)
|
||||
|
||||
if not user_token:
|
||||
logger.warning(
|
||||
f'[GitHub] No token found for user {user_info.username} (id={user_info.user_id})'
|
||||
)
|
||||
raise MissingSettingsError('Missing settings')
|
||||
|
||||
logger.info(
|
||||
f'[GitHub] Creating new conversation for user {user_info.username}'
|
||||
)
|
||||
|
||||
secret_store = UserSecrets(
|
||||
provider_tokens=MappingProxyType(
|
||||
{
|
||||
ProviderType.GITHUB: ProviderToken(
|
||||
token=SecretStr(user_token),
|
||||
user_id=str(user_info.user_id),
|
||||
)
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# We first initialize a conversation and generate the solvability report BEFORE starting the conversation runtime
|
||||
            # This helps us accumulate LLM spend without requiring a running runtime. This sets us up for:
|
||||
# 1. If there is a problem starting the runtime we still have accumulated total conversation cost
|
||||
# 2. In the future, based on the report confidence we can conditionally start the conversation
|
||||
# 3. Once the conversation is started, its base cost will include the report's spend as well which allows us to control max budget per resolver task
|
||||
convo_metadata = await github_view.initialize_new_conversation()
|
||||
solvability_summary = None
|
||||
try:
|
||||
if user_token:
|
||||
solvability_summary = await summarize_issue_solvability(
|
||||
github_view, user_token
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
'[Github]: No user token available for solvability analysis'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f'[Github]: Error summarizing issue solvability: {str(e)}'
|
||||
)
|
||||
|
||||
await github_view.create_new_conversation(
|
||||
self.jinja_env, secret_store.provider_tokens, convo_metadata
|
||||
)
|
||||
|
||||
conversation_id = github_view.conversation_id
|
||||
|
||||
logger.info(
|
||||
f'[GitHub] Created conversation {conversation_id} for user {user_info.username}'
|
||||
)
|
||||
|
||||
# Create a GithubCallbackProcessor
|
||||
processor = GithubCallbackProcessor(
|
||||
github_view=github_view,
|
||||
send_summary_instruction=True,
|
||||
)
|
||||
|
||||
# Register the callback processor
|
||||
register_callback_processor(conversation_id, processor)
|
||||
|
||||
logger.info(
|
||||
f'[Github] Registered callback processor for conversation {conversation_id}'
|
||||
)
|
||||
|
||||
# Send message with conversation link
|
||||
conversation_link = CONVERSATION_URL.format(conversation_id)
|
||||
base_msg = f"I'm on it! {user_info.username} can [track my progress at all-hands.dev]({conversation_link})"
|
||||
# Combine messages: include solvability report with "I'm on it!" if successful
|
||||
if solvability_summary:
|
||||
msg_info = f'{base_msg}\n\n{solvability_summary}'
|
||||
else:
|
||||
msg_info = base_msg
|
||||
|
||||
except MissingSettingsError as e:
|
||||
logger.warning(
|
||||
f'[GitHub] Missing settings error for user {user_info.username}: {str(e)}'
|
||||
)
|
||||
|
||||
msg_info = f'@{user_info.username} please re-login into [OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
|
||||
except LLMAuthenticationError as e:
|
||||
logger.warning(
|
||||
f'[GitHub] LLM authentication error for user {user_info.username}: {str(e)}'
|
||||
)
|
||||
|
||||
msg_info = f'@{user_info.username} please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
|
||||
msg = self.create_outgoing_message(msg_info)
|
||||
await self.send_message(msg, github_view)
|
||||
|
||||
except Exception:
|
||||
logger.exception('[Github]: Error starting job')
|
||||
msg = self.create_outgoing_message(
|
||||
msg='Uh oh! There was an unexpected error starting the job :('
|
||||
)
|
||||
await self.send_message(msg, github_view)
|
||||
|
||||
try:
|
||||
await self.data_collector.save_data(github_view)
|
||||
except Exception:
|
||||
logger.warning('[Github]: Error saving interaction data', exc_info=True)
enterprise/integrations/github/github_service.py (new file, 143 lines)
@@ -0,0 +1,143 @@
|
||||
import asyncio
|
||||
|
||||
from integrations.utils import store_repositories_in_db
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.github.github_service import GitHubService
|
||||
from openhands.integrations.service_types import ProviderType, Repository
|
||||
from openhands.server.types import AppMode
|
||||
|
||||
|
||||
class SaaSGitHubService(GitHubService):
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str | None = None,
|
||||
external_auth_token: SecretStr | None = None,
|
||||
external_auth_id: str | None = None,
|
||||
token: SecretStr | None = None,
|
||||
external_token_manager: bool = False,
|
||||
base_domain: str | None = None,
|
||||
):
|
||||
logger.debug(
|
||||
f'SaaSGitHubService created with user_id {user_id}, external_auth_id {external_auth_id}, external_auth_token {'set' if external_auth_token else 'None'}, github_token {'set' if token else 'None'}, external_token_manager {external_token_manager}'
|
||||
)
|
||||
super().__init__(
|
||||
user_id=user_id,
|
||||
external_auth_token=external_auth_token,
|
||||
external_auth_id=external_auth_id,
|
||||
token=token,
|
||||
external_token_manager=external_token_manager,
|
||||
base_domain=base_domain,
|
||||
)
|
||||
|
||||
self.external_auth_token = external_auth_token
|
||||
self.external_auth_id = external_auth_id
|
||||
self.token_manager = TokenManager(external=external_token_manager)
|
||||
|
||||
async def get_latest_token(self) -> SecretStr | None:
|
||||
github_token = None
|
||||
if self.external_auth_token:
|
||||
github_token = SecretStr(
|
||||
await self.token_manager.get_idp_token(
|
||||
self.external_auth_token.get_secret_value(), ProviderType.GITHUB
|
||||
)
|
||||
)
|
||||
logger.debug(
|
||||
f'Got GitHub token {github_token} from access token: {self.external_auth_token}'
|
||||
)
|
||||
elif self.external_auth_id:
|
||||
offline_token = await self.token_manager.load_offline_token(
|
||||
self.external_auth_id
|
||||
)
|
||||
github_token = SecretStr(
|
||||
await self.token_manager.get_idp_token_from_offline_token(
|
||||
offline_token, ProviderType.GITHUB
|
||||
)
|
||||
)
|
||||
logger.debug(
|
||||
f'Got GitHub token {github_token} from external auth user ID: {self.external_auth_id}'
|
||||
)
|
||||
elif self.user_id:
|
||||
github_token = SecretStr(
|
||||
await self.token_manager.get_idp_token_from_idp_user_id(
|
||||
self.user_id, ProviderType.GITHUB
|
||||
)
|
||||
)
|
||||
logger.debug(
|
||||
f'Got GitHub token {github_token} from user ID: {self.user_id}'
|
||||
)
|
||||
else:
|
||||
logger.warning('external_auth_token and user_id not set!')
|
||||
return github_token
|
||||
|
||||
async def get_pr_patches(
|
||||
self, owner: str, repo: str, pr_number: int, per_page: int = 30, page: int = 1
|
||||
):
|
||||
"""Get patches for files changed in a PR with pagination support.
|
||||
|
||||
Args:
|
||||
owner: Repository owner
|
||||
repo: Repository name
|
||||
pr_number: Pull request number
|
||||
per_page: Number of files per page (default: 30, max: 100)
|
||||
page: Page number to fetch (default: 1)
|
||||
"""
|
||||
url = f'https://api.github.com/repos/{owner}/{repo}/pulls/{pr_number}/files'
|
||||
params = {'per_page': min(per_page, 100), 'page': page} # GitHub max is 100
|
||||
response, headers = await self._make_request(url, params)
|
||||
|
||||
# Parse pagination info from headers
|
||||
has_next_page = 'next' in headers.get('link', '')
|
||||
total_count = int(headers.get('total', 0))
|
||||
|
||||
return {
|
||||
'files': response,
|
||||
'pagination': {
|
||||
'has_next_page': has_next_page,
|
||||
'total_count': total_count,
|
||||
'current_page': page,
|
||||
'per_page': per_page,
|
||||
},
|
||||
}
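A minimal sketch of the Link-header check used above; the header value is a fabricated example of GitHub's pagination format.

link_header = (
    '<https://api.github.com/repos/o/r/pulls/1/files?page=2>; rel="next", '
    '<https://api.github.com/repos/o/r/pulls/1/files?page=5>; rel="last"'
)
has_next_page = 'next' in link_header  # same substring test as above
print(has_next_page)  # -> True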
|
||||
|
||||
async def get_repository_node_id(self, repo_id: str) -> str:
|
||||
"""
|
||||
Get the new GitHub GraphQL node ID for a repository using REST API.
|
||||
|
||||
Args:
|
||||
repo_id: Numeric repository ID as string (e.g., "123456789")
|
||||
|
||||
Returns:
|
||||
New format node ID for GraphQL queries (e.g., "R_kgDOLfkiww")
|
||||
|
||||
Raises:
|
||||
Exception: If the API request fails or node_id is not found
|
||||
"""
|
||||
url = f'https://api.github.com/repositories/{repo_id}'
|
||||
response, _ = await self._make_request(url)
|
||||
node_id = response.get('node_id')
|
||||
if not node_id:
|
||||
raise Exception(f'No node_id found for repository {repo_id}')
|
||||
return node_id
|
||||
|
||||
async def get_paginated_repos(self, page, per_page, sort, installation_id):
|
||||
repositories = await super().get_paginated_repos(
|
||||
page, per_page, sort, installation_id
|
||||
)
|
||||
asyncio.create_task(
|
||||
store_repositories_in_db(repositories, self.external_auth_id)
|
||||
)
|
||||
return repositories
|
||||
|
||||
async def get_all_repositories(
|
||||
self, sort: str, app_mode: AppMode
|
||||
) -> list[Repository]:
|
||||
repositories = await super().get_all_repositories(sort, app_mode)
|
||||
# Schedule the background task without awaiting it
|
||||
asyncio.create_task(
|
||||
store_repositories_in_db(repositories, self.external_auth_id)
|
||||
)
|
||||
# Return repositories immediately
|
||||
return repositories
enterprise/integrations/github/github_solvability.py (new file, 183 lines)
@@ -0,0 +1,183 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from github import Github
|
||||
from integrations.github.github_view import (
|
||||
GithubInlinePRComment,
|
||||
GithubIssueComment,
|
||||
GithubPRComment,
|
||||
GithubViewType,
|
||||
)
|
||||
from integrations.solvability.data import load_classifier
|
||||
from integrations.solvability.models.report import SolvabilityReport
|
||||
from integrations.solvability.models.summary import SolvabilitySummary
|
||||
from integrations.utils import ENABLE_SOLVABILITY_ANALYSIS
|
||||
from pydantic import ValidationError
|
||||
from server.auth.token_manager import get_config
|
||||
from storage.database import session_maker
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.utils import create_registry_and_conversation_stats
|
||||
|
||||
|
||||
def fetch_github_issue_context(
|
||||
github_view: GithubViewType,
|
||||
user_token: str,
|
||||
) -> str:
|
||||
"""Fetch full GitHub issue/PR context including title, body, and comments.
|
||||
|
||||
    Args:
        github_view: The resolver view for the issue or PR (e.g. GithubIssue, GithubPRComment)
        user_token: GitHub user access token
|
||||
|
||||
Returns:
|
||||
A comprehensive string containing the issue/PR context
|
||||
"""
|
||||
|
||||
# Build context string
|
||||
context_parts = []
|
||||
|
||||
# Add title and body
|
||||
context_parts.append(f'Title: {github_view.title}')
|
||||
context_parts.append(f'Description:\n{github_view.description}')
|
||||
|
||||
with Github(user_token) as github_client:
|
||||
repo = github_client.get_repo(github_view.full_repo_name)
|
||||
issue = repo.get_issue(github_view.issue_number)
|
||||
if issue.labels:
|
||||
labels = [label.name for label in issue.labels]
|
||||
context_parts.append(f"Labels: {', '.join(labels)}")
|
||||
|
||||
for comment in github_view.previous_comments:
|
||||
context_parts.append(f'- {comment.author}: {comment.body}')
|
||||
|
||||
return '\n\n'.join(context_parts)
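For reference, a standalone sketch of the context string this helper assembles; the title, description, labels, and comment are invented.

context_parts = [
    'Title: Fix login crash',
    'Description:\nThe app crashes when the auth token is empty.',
    'Labels: bug, backend',
    '- alice: I can reproduce this on main.',
]
print('\n\n'.join(context_parts))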
|
||||
|
||||
|
||||
async def summarize_issue_solvability(
|
||||
github_view: GithubViewType,
|
||||
user_token: str,
|
||||
timeout: float = 60.0 * 5,
|
||||
) -> str:
|
||||
"""Generate a solvability summary for an issue using the resolver view interface.
|
||||
|
||||
    Args:
        github_view: A resolver view interface instance (e.g., GithubIssue, GithubPRComment)
        user_token: GitHub user access token for API access
        timeout: Maximum time in seconds to wait for the result (default: 300.0)
|
||||
|
||||
Returns:
|
||||
The solvability summary as a string
|
||||
|
||||
Raises:
|
||||
ValueError: If LLM settings cannot be found for the user
|
||||
asyncio.TimeoutError: If the operation exceeds the specified timeout
|
||||
"""
|
||||
if not ENABLE_SOLVABILITY_ANALYSIS:
|
||||
raise ValueError('Solvability report feature is disabled')
|
||||
|
||||
if github_view.user_info.keycloak_user_id is None:
|
||||
raise ValueError(
|
||||
f'[Solvability] No user ID found for user {github_view.user_info.username}'
|
||||
)
|
||||
|
||||
# Grab the user's information so we can load their LLM configuration
|
||||
store = SaasSettingsStore(
|
||||
user_id=github_view.user_info.keycloak_user_id,
|
||||
session_maker=session_maker,
|
||||
config=get_config(),
|
||||
)
|
||||
|
||||
user_settings = await store.load()
|
||||
|
||||
if user_settings is None:
|
||||
raise ValueError(
|
||||
f'[Solvability] No user settings found for user ID {github_view.user_info.user_id}'
|
||||
)
|
||||
|
||||
# Check if solvability analysis is enabled for this user, exit early if
|
||||
# needed
|
||||
if not getattr(user_settings, 'enable_solvability_analysis', False):
|
||||
raise ValueError(
|
||||
f'Solvability analysis disabled for user {github_view.user_info.user_id}'
|
||||
)
|
||||
|
||||
try:
|
||||
llm_config = LLMConfig(
|
||||
model=user_settings.llm_model,
|
||||
api_key=user_settings.llm_api_key.get_secret_value(),
|
||||
base_url=user_settings.llm_base_url,
|
||||
)
|
||||
except ValidationError as e:
|
||||
raise ValueError(
|
||||
f'[Solvability] Invalid LLM configuration for user {github_view.user_info.user_id}: {str(e)}'
|
||||
)
|
||||
|
||||
# Fetch the full GitHub issue/PR context using the GitHub API
|
||||
start_time = time.time()
|
||||
issue_context = fetch_github_issue_context(github_view, user_token)
|
||||
logger.info(
|
||||
f'[Solvability] Grabbed issue context for {github_view.conversation_id}',
|
||||
extra={
|
||||
'conversation_id': github_view.conversation_id,
|
||||
'response_latency': time.time() - start_time,
|
||||
'full_repo_name': github_view.full_repo_name,
|
||||
'issue_number': github_view.issue_number,
|
||||
},
|
||||
)
|
||||
|
||||
# For comment-based triggers, also include the specific comment that triggered the action
|
||||
if isinstance(
|
||||
github_view, (GithubIssueComment, GithubPRComment, GithubInlinePRComment)
|
||||
):
|
||||
issue_context += f'\n\nTriggering Comment:\n{github_view.comment_body}'
|
||||
|
||||
solvability_classifier = load_classifier('default-classifier')
|
||||
|
||||
async with asyncio.timeout(timeout):
|
||||
solvability_report: SolvabilityReport = await call_sync_from_async(
|
||||
lambda: solvability_classifier.solvability_report(
|
||||
issue_context, llm_config=llm_config
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'[Solvability] Generated report for {github_view.conversation_id}',
|
||||
extra={
|
||||
'conversation_id': github_view.conversation_id,
|
||||
'report': solvability_report.model_dump(exclude=['issue']),
|
||||
},
|
||||
)
|
||||
|
||||
llm_registry, conversation_stats, _ = create_registry_and_conversation_stats(
|
||||
get_config(),
|
||||
github_view.conversation_id,
|
||||
github_view.user_info.keycloak_user_id,
|
||||
None,
|
||||
)
|
||||
|
||||
solvability_summary = await call_sync_from_async(
|
||||
lambda: SolvabilitySummary.from_report(
|
||||
solvability_report,
|
||||
llm=llm_registry.get_llm(
|
||||
service_id='solvability_analysis', config=llm_config
|
||||
),
|
||||
)
|
||||
)
|
||||
conversation_stats.save_metrics()
|
||||
|
||||
logger.info(
|
||||
f'[Solvability] Generated summary for {github_view.conversation_id}',
|
||||
extra={
|
||||
'conversation_id': github_view.conversation_id,
|
||||
'summary': solvability_summary.model_dump(exclude=['content']),
|
||||
},
|
||||
)
|
||||
|
||||
return solvability_summary.format_as_markdown()
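The report generation above is bounded with asyncio.timeout; a self-contained sketch of that pattern, with a dummy coroutine standing in for the classifier call:

import asyncio

async def slow_task() -> str:
    await asyncio.sleep(0.1)  # stand-in for the classifier work
    return 'report'

async def main() -> None:
    async with asyncio.timeout(5.0):  # raises TimeoutError if the block overruns
        print(await slow_task())

asyncio.run(main())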
enterprise/integrations/github/github_types.py (new file, 26 lines)
@@ -0,0 +1,26 @@
from enum import Enum

from pydantic import BaseModel


class WorkflowRunStatus(Enum):
    FAILURE = 'failure'
    COMPLETED = 'completed'
    PENDING = 'pending'

    def __eq__(self, other):
        if isinstance(other, str):
            return self.value == other
        return super().__eq__(other)


class WorkflowRun(BaseModel):
    id: str
    name: str
    status: WorkflowRunStatus

    model_config = {'use_enum_values': True}


class WorkflowRunGroup(BaseModel):
    runs: dict[str, WorkflowRun]
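A short standalone check of the __eq__ override and use_enum_values behaviour defined above (assumes these classes are importable as written):

print(WorkflowRunStatus.FAILURE == 'failure')    # -> True
print(WorkflowRunStatus.PENDING == 'completed')  # -> False

run = WorkflowRun(id='1', name='CI', status=WorkflowRunStatus.FAILURE)
print(run.status)  # -> 'failure' (stored as the plain string because of use_enum_values)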
enterprise/integrations/github/github_view.py (new file, 756 lines)
@@ -0,0 +1,756 @@
|
||||
from uuid import uuid4
|
||||
|
||||
from github import Github, GithubIntegration
|
||||
from github.Issue import Issue
|
||||
from integrations.github.github_types import (
|
||||
WorkflowRun,
|
||||
WorkflowRunGroup,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
from integrations.models import Message
|
||||
from integrations.types import ResolverViewInterface, UserData
|
||||
from integrations.utils import (
|
||||
ENABLE_PROACTIVE_CONVERSATION_STARTERS,
|
||||
HOST,
|
||||
HOST_URL,
|
||||
get_oh_labels,
|
||||
has_exact_mention,
|
||||
)
|
||||
from jinja2 import Environment
|
||||
from pydantic.dataclasses import dataclass
|
||||
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
from server.auth.token_manager import TokenManager, get_config
|
||||
from storage.database import session_maker
|
||||
from storage.proactive_conversation_store import ProactiveConversationStore
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
|
||||
from openhands.integrations.service_types import Comment
|
||||
from openhands.server.services.conversation_service import (
|
||||
initialize_conversation,
|
||||
start_conversation,
|
||||
)
|
||||
from openhands.storage.data_models.conversation_metadata import (
|
||||
ConversationMetadata,
|
||||
ConversationTrigger,
|
||||
)
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
OH_LABEL, INLINE_OH_LABEL = get_oh_labels(HOST)
|
||||
|
||||
|
||||
async def get_user_proactive_conversation_setting(user_id: str | None) -> bool:
|
||||
"""Get the user's proactive conversation setting.
|
||||
|
||||
Args:
|
||||
user_id: The keycloak user ID
|
||||
|
||||
Returns:
|
||||
True if proactive conversations are enabled for this user, False otherwise
|
||||
|
||||
Note:
|
||||
This function checks both the global environment variable kill switch AND
|
||||
the user's individual setting. Both must be true for the function to return true.
|
||||
"""
|
||||
|
||||
# If no user ID is provided, we can't check user settings
|
||||
if not user_id:
|
||||
return False
|
||||
|
||||
def _get_setting():
|
||||
with session_maker() as session:
|
||||
settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not settings or settings.enable_proactive_conversation_starters is None:
|
||||
return False
|
||||
|
||||
return settings.enable_proactive_conversation_starters
|
||||
|
||||
return await call_sync_from_async(_get_setting)
|
||||
|
||||
|
||||
# =================================================
|
||||
# SECTION: Github view types
|
||||
# =================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class GithubIssue(ResolverViewInterface):
|
||||
issue_number: int
|
||||
installation_id: int
|
||||
full_repo_name: str
|
||||
is_public_repo: bool
|
||||
user_info: UserData
|
||||
raw_payload: Message
|
||||
conversation_id: str
|
||||
uuid: str | None
|
||||
should_extract: bool
|
||||
send_summary_instruction: bool
|
||||
title: str
|
||||
description: str
|
||||
previous_comments: list[Comment]
|
||||
|
||||
async def _load_resolver_context(self):
|
||||
github_service = GithubServiceImpl(
|
||||
external_auth_id=self.user_info.keycloak_user_id
|
||||
)
|
||||
|
||||
self.previous_comments = await github_service.get_issue_or_pr_comments(
|
||||
self.full_repo_name, self.issue_number
|
||||
)
|
||||
|
||||
(
|
||||
self.title,
|
||||
self.description,
|
||||
) = await github_service.get_issue_or_pr_title_and_body(
|
||||
self.full_repo_name, self.issue_number
|
||||
)
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
user_instructions_template = jinja_env.get_template('issue_prompt.j2')
|
||||
|
||||
user_instructions = user_instructions_template.render(
|
||||
issue_number=self.issue_number,
|
||||
)
|
||||
|
||||
await self._load_resolver_context()
|
||||
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'issue_conversation_instructions.j2'
|
||||
)
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
issue_title=self.title,
|
||||
issue_body=self.description,
|
||||
previous_comments=self.previous_comments,
|
||||
)
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def _get_user_secrets(self):
|
||||
secrets_store = SaasSecretsStore(
|
||||
self.user_info.keycloak_user_id, session_maker, get_config()
|
||||
)
|
||||
user_secrets = await secrets_store.load()
|
||||
|
||||
return user_secrets.custom_secrets if user_secrets else None
|
||||
|
||||
async def initialize_new_conversation(self) -> ConversationMetadata:
|
||||
# FIXME: Handle if initialize_conversation returns None
|
||||
conversation_metadata: ConversationMetadata = await initialize_conversation( # type: ignore[assignment]
|
||||
user_id=self.user_info.keycloak_user_id,
|
||||
conversation_id=None,
|
||||
selected_repository=self.full_repo_name,
|
||||
selected_branch=None,
|
||||
conversation_trigger=ConversationTrigger.RESOLVER,
|
||||
git_provider=ProviderType.GITHUB,
|
||||
)
|
||||
self.conversation_id = conversation_metadata.conversation_id
|
||||
return conversation_metadata
|
||||
|
||||
async def create_new_conversation(
|
||||
self,
|
||||
jinja_env: Environment,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE,
|
||||
conversation_metadata: ConversationMetadata,
|
||||
):
|
||||
custom_secrets = await self._get_user_secrets()
|
||||
|
||||
user_instructions, conversation_instructions = await self._get_instructions(
|
||||
jinja_env
|
||||
)
|
||||
|
||||
await start_conversation(
|
||||
user_id=self.user_info.keycloak_user_id,
|
||||
git_provider_tokens=git_provider_tokens,
|
||||
custom_secrets=custom_secrets,
|
||||
initial_user_msg=user_instructions,
|
||||
image_urls=None,
|
||||
replay_json=None,
|
||||
conversation_id=conversation_metadata.conversation_id,
|
||||
conversation_metadata=conversation_metadata,
|
||||
conversation_instructions=conversation_instructions,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GithubIssueComment(GithubIssue):
|
||||
comment_body: str
|
||||
comment_id: int
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
user_instructions_template = jinja_env.get_template('issue_prompt.j2')
|
||||
|
||||
await self._load_resolver_context()
|
||||
|
||||
user_instructions = user_instructions_template.render(
|
||||
issue_comment=self.comment_body
|
||||
)
|
||||
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'issue_conversation_instructions.j2'
|
||||
)
|
||||
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
issue_number=self.issue_number,
|
||||
issue_title=self.title,
|
||||
issue_body=self.description,
|
||||
previous_comments=self.previous_comments,
|
||||
)
|
||||
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
|
||||
@dataclass
|
||||
class GithubPRComment(GithubIssueComment):
|
||||
branch_name: str
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
user_instructions_template = jinja_env.get_template('pr_update_prompt.j2')
|
||||
await self._load_resolver_context()
|
||||
|
||||
user_instructions = user_instructions_template.render(
|
||||
pr_comment=self.comment_body,
|
||||
)
|
||||
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'pr_update_conversation_instructions.j2'
|
||||
)
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
pr_number=self.issue_number,
|
||||
branch_name=self.branch_name,
|
||||
pr_title=self.title,
|
||||
pr_body=self.description,
|
||||
comments=self.previous_comments,
|
||||
)
|
||||
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def initialize_new_conversation(self) -> ConversationMetadata:
|
||||
# FIXME: Handle if initialize_conversation returns None
|
||||
conversation_metadata: ConversationMetadata = await initialize_conversation( # type: ignore[assignment]
|
||||
user_id=self.user_info.keycloak_user_id,
|
||||
conversation_id=None,
|
||||
selected_repository=self.full_repo_name,
|
||||
selected_branch=self.branch_name,
|
||||
conversation_trigger=ConversationTrigger.RESOLVER,
|
||||
git_provider=ProviderType.GITHUB,
|
||||
)
|
||||
|
||||
self.conversation_id = conversation_metadata.conversation_id
|
||||
return conversation_metadata
|
||||
|
||||
|
||||
@dataclass
|
||||
class GithubInlinePRComment(GithubPRComment):
|
||||
file_location: str
|
||||
line_number: int
|
||||
comment_node_id: str
|
||||
|
||||
async def _load_resolver_context(self):
|
||||
github_service = GithubServiceImpl(
|
||||
external_auth_id=self.user_info.keycloak_user_id
|
||||
)
|
||||
|
||||
(
|
||||
self.title,
|
||||
self.description,
|
||||
) = await github_service.get_issue_or_pr_title_and_body(
|
||||
self.full_repo_name, self.issue_number
|
||||
)
|
||||
|
||||
self.previous_comments = await github_service.get_review_thread_comments(
|
||||
self.comment_node_id, self.full_repo_name, self.issue_number
|
||||
)
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
user_instructions_template = jinja_env.get_template('pr_update_prompt.j2')
|
||||
await self._load_resolver_context()
|
||||
|
||||
user_instructions = user_instructions_template.render(
|
||||
pr_comment=self.comment_body,
|
||||
)
|
||||
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'pr_update_conversation_instructions.j2'
|
||||
)
|
||||
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
pr_number=self.issue_number,
|
||||
pr_title=self.title,
|
||||
pr_body=self.description,
|
||||
branch_name=self.branch_name,
|
||||
file_location=self.file_location,
|
||||
line_number=self.line_number,
|
||||
comments=self.previous_comments,
|
||||
)
|
||||
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
|
||||
@dataclass
|
||||
class GithubFailingAction:
|
||||
    unique_suggestions_header: str = (
|
||||
'Looks like there are a few issues preventing this PR from being merged!'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_latest_sha(pr: Issue) -> str:
|
||||
pr_obj = pr.as_pull_request()
|
||||
return pr_obj.head.sha
|
||||
|
||||
@staticmethod
|
||||
def create_retrieve_workflows_callback(pr: Issue, head_sha: str):
|
||||
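        # Returned as a closure so the proactive conversation store can re-fetch the
        # workflow runs for this head SHA on demand.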
def get_all_workflows():
|
||||
repo = pr.repository
|
||||
workflows = repo.get_workflow_runs(head_sha=head_sha)
|
||||
|
||||
runs = {}
|
||||
|
||||
for workflow in workflows:
|
||||
conclusion = workflow.conclusion
|
||||
workflow_conclusion = WorkflowRunStatus.COMPLETED
|
||||
if conclusion is None:
|
||||
workflow_conclusion = WorkflowRunStatus.PENDING # type: ignore[unreachable]
|
||||
elif conclusion == WorkflowRunStatus.FAILURE.value:
|
||||
workflow_conclusion = WorkflowRunStatus.FAILURE
|
||||
|
||||
runs[str(workflow.id)] = WorkflowRun(
|
||||
id=str(workflow.id), name=workflow.name, status=workflow_conclusion
|
||||
)
|
||||
|
||||
return WorkflowRunGroup(runs=runs)
|
||||
|
||||
return get_all_workflows
|
||||
|
||||
@staticmethod
|
||||
def delete_old_comment_if_exists(pr: Issue):
|
||||
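        # Remove any previous suggestion comment (identified by the unique header)
        # so only the most recent status comment remains on the PR.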
paginated_comments = pr.get_comments()
|
||||
for page in range(paginated_comments.totalCount):
|
||||
comments = paginated_comments.get_page(page)
|
||||
for comment in comments:
|
||||
                if GithubFailingAction.unique_suggestions_header in comment.body:
|
||||
comment.delete()
|
||||
|
||||
@staticmethod
|
||||
def get_suggestions(
|
||||
failed_jobs: dict, pr_number: int, branch_name: str | None = None
|
||||
) -> str:
|
||||
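        # Illustrative output shape (action names here are hypothetical): a bulleted list
        # of detected problems followed by ready-to-copy "@OpenHands please fix ..." commands, e.g.
        #   - GitHub Actions are failing:
        #     - lint
        #   - There are merge conflicts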
issues = []
|
||||
|
||||
# Collect failing actions with their specific names
|
||||
if failed_jobs['actions']:
|
||||
failing_actions = failed_jobs['actions']
|
||||
issues.append(('GitHub Actions are failing:', False))
|
||||
for action in failing_actions:
|
||||
issues.append((action, True))
|
||||
|
||||
if any(failed_jobs['merge conflict']):
|
||||
issues.append(('There are merge conflicts', False))
|
||||
|
||||
# Format each line with proper indentation and dashes
|
||||
formatted_issues = []
|
||||
for issue, is_nested in issues:
|
||||
if is_nested:
|
||||
formatted_issues.append(f' - {issue}')
|
||||
else:
|
||||
formatted_issues.append(f'- {issue}')
|
||||
issues_text = '\n'.join(formatted_issues)
|
||||
|
||||
# Build list of possible suggestions based on actual issues
|
||||
suggestions = []
|
||||
branch_info = f' at branch `{branch_name}`' if branch_name else ''
|
||||
|
||||
if any(failed_jobs['merge conflict']):
|
||||
suggestions.append(
|
||||
f'@OpenHands please fix the merge conflicts on PR #{pr_number}{branch_info}'
|
||||
)
|
||||
if any(failed_jobs['actions']):
|
||||
suggestions.append(
|
||||
f'@OpenHands please fix the failing actions on PR #{pr_number}{branch_info}'
|
||||
)
|
||||
|
||||
# Take at most 2 suggestions
|
||||
suggestions = suggestions[:2]
|
||||
|
||||
help_text = """If you'd like me to help, just leave a comment, like
|
||||
|
||||
```
|
||||
{}
|
||||
```
|
||||
|
||||
Feel free to include any additional details that might help me get this PR into a better state.
|
||||
|
||||
<sub><sup>You can manage your notification [settings]({})</sup></sub>""".format(
|
||||
'\n```\n\nor\n\n```\n'.join(suggestions), f'{HOST_URL}/settings/app'
|
||||
)
|
||||
|
||||
        return f'{GithubFailingAction.unique_suggestions_header}\n\n{issues_text}\n\n{help_text}'
|
||||
|
||||
@staticmethod
|
||||
def leave_requesting_comment(pr: Issue, failed_runs: WorkflowRunGroup):
|
||||
failed_jobs: dict = {'actions': [], 'merge conflict': []}
|
||||
|
||||
pr_obj = pr.as_pull_request()
|
||||
if not pr_obj.mergeable:
|
||||
failed_jobs['merge conflict'].append('Merge conflict detected')
|
||||
|
||||
for _, workflow_run in failed_runs.runs.items():
|
||||
if workflow_run.status == WorkflowRunStatus.FAILURE:
|
||||
failed_jobs['actions'].append(workflow_run.name)
|
||||
|
||||
logger.info(f'[GitHub] Found failing jobs for PR #{pr.number}: {failed_jobs}')
|
||||
|
||||
# Get the branch name
|
||||
branch_name = pr_obj.head.ref
|
||||
|
||||
# Get suggestions with branch name included
|
||||
suggestions = GithubFailingAction.get_suggestions(
|
||||
failed_jobs, pr.number, branch_name
|
||||
)
|
||||
|
||||
GithubFailingAction.delete_old_comment_if_exists(pr)
|
||||
pr.create_comment(suggestions)
|
||||
|
||||
|
||||
GithubViewType = (
|
||||
GithubInlinePRComment | GithubPRComment | GithubIssueComment | GithubIssue
|
||||
)
|
||||
|
||||
|
||||
# =================================================
|
||||
# SECTION: Factory to create appropriate GitHub view
|
||||
# =================================================
|
||||
|
||||
|
||||
class GithubFactory:
|
||||
@staticmethod
|
||||
def is_labeled_issue(message: Message):
|
||||
payload = message.message.get('payload', {})
|
||||
action = payload.get('action', '')
|
||||
|
||||
if action == 'labeled' and 'label' in payload and 'issue' in payload:
|
||||
label_name = payload['label'].get('name', '')
|
||||
if label_name == OH_LABEL:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def is_issue_comment(message: Message):
|
||||
payload = message.message.get('payload', {})
|
||||
action = payload.get('action', '')
|
||||
|
||||
if (
|
||||
action == 'created'
|
||||
and 'comment' in payload
|
||||
and 'issue' in payload
|
||||
and 'pull_request' not in payload['issue']
|
||||
):
|
||||
comment_body = payload['comment']['body']
|
||||
if has_exact_mention(comment_body, INLINE_OH_LABEL):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def is_pr_comment(message: Message):
|
||||
payload = message.message.get('payload', {})
|
||||
action = payload.get('action', '')
|
||||
|
||||
if (
|
||||
action == 'created'
|
||||
and 'comment' in payload
|
||||
and 'issue' in payload
|
||||
and 'pull_request' in payload['issue']
|
||||
):
|
||||
comment_body = payload['comment'].get('body', '')
|
||||
if has_exact_mention(comment_body, INLINE_OH_LABEL):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def is_inline_pr_comment(message: Message):
|
||||
payload = message.message.get('payload', {})
|
||||
action = payload.get('action', '')
|
||||
|
||||
if action == 'created' and 'comment' in payload and 'pull_request' in payload:
|
||||
comment_body = payload['comment'].get('body', '')
|
||||
if has_exact_mention(comment_body, INLINE_OH_LABEL):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def is_eligible_for_conversation_starter(message: Message):
|
||||
if not ENABLE_PROACTIVE_CONVERSATION_STARTERS:
|
||||
return False
|
||||
|
||||
payload = message.message.get('payload', {})
|
||||
action = payload.get('action', '')
|
||||
|
||||
if not (action == 'completed' and 'workflow_run' in payload):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def trigger_conversation_starter(message: Message):
|
||||
"""Trigger a conversation starter when a workflow fails.
|
||||
|
||||
        Only proceeds when the user has proactive conversation starters enabled in their settings.
|
||||
"""
|
||||
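        # Flow: map the workflow conclusion to a WorkflowRunStatus, confirm the run belongs
        # to the latest commit of an open PR, record it in the proactive conversation store,
        # and, when the store returns a run group, leave a suggestion comment on the PR.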
payload = message.message.get('payload', {})
|
||||
workflow_payload = payload['workflow_run']
|
||||
status = WorkflowRunStatus.COMPLETED
|
||||
|
||||
if workflow_payload['conclusion'] == 'failure':
|
||||
status = WorkflowRunStatus.FAILURE
|
||||
elif workflow_payload['conclusion'] is None:
|
||||
status = WorkflowRunStatus.PENDING
|
||||
|
||||
workflow_run = WorkflowRun(
|
||||
id=str(workflow_payload['id']), name=workflow_payload['name'], status=status
|
||||
)
|
||||
|
||||
selected_repo = GithubFactory.get_full_repo_name(payload['repository'])
|
||||
head_branch = payload['workflow_run']['head_branch']
|
||||
|
||||
# Get the user ID to check their settings
|
||||
user_id = None
|
||||
try:
|
||||
sender_id = payload['sender']['id']
|
||||
token_manager = TokenManager()
|
||||
user_id = await token_manager.get_user_id_from_idp_user_id(
|
||||
sender_id, ProviderType.GITHUB
|
||||
)
|
||||
        except Exception as e:
|
||||
logger.warning(
|
||||
f'Failed to get user ID for proactive conversation check: {str(e)}'
|
||||
)
|
||||
|
||||
# Check if proactive conversations are enabled for this user
|
||||
if not await get_user_proactive_conversation_setting(user_id):
|
||||
return False
|
||||
|
||||
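        # PyGithub is synchronous, so the repository/PR lookups are wrapped in a helper
        # and executed via call_sync_from_async below.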
def _interact_with_github() -> Issue | None:
|
||||
with GithubIntegration(
|
||||
GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
) as integration:
|
||||
access_token = integration.get_access_token(
|
||||
payload['installation']['id']
|
||||
).token
|
||||
|
||||
with Github(access_token) as gh:
|
||||
repo = gh.get_repo(selected_repo)
|
||||
login = (
|
||||
payload['organization']['login']
|
||||
if 'organization' in payload
|
||||
else payload['sender']['login']
|
||||
)
|
||||
|
||||
# See if a pull request is open
|
||||
open_pulls = repo.get_pulls(state='open', head=f'{login}:{head_branch}')
|
||||
if open_pulls.totalCount > 0:
|
||||
prs = open_pulls.get_page(0)
|
||||
relevant_pr = prs[0]
|
||||
issue = repo.get_issue(number=relevant_pr.number)
|
||||
return issue
|
||||
|
||||
return None
|
||||
|
||||
issue: Issue | None = await call_sync_from_async(_interact_with_github)
|
||||
if not issue:
|
||||
return False
|
||||
|
||||
incoming_commit = payload['workflow_run']['head_sha']
|
||||
latest_sha = GithubFailingAction.get_latest_sha(issue)
|
||||
if latest_sha != incoming_commit:
|
||||
# Return as this commit is not the latest
|
||||
return False
|
||||
|
||||
convo_store = ProactiveConversationStore()
|
||||
workflow_group = await convo_store.store_workflow_information(
|
||||
provider=ProviderType.GITHUB,
|
||||
repo_id=payload['repository']['id'],
|
||||
incoming_commit=incoming_commit,
|
||||
workflow=workflow_run,
|
||||
pr_number=issue.number,
|
||||
get_all_workflows=GithubFailingAction.create_retrieve_workflows_callback(
|
||||
issue, incoming_commit
|
||||
),
|
||||
)
|
||||
|
||||
if not workflow_group:
|
||||
return False
|
||||
|
||||
logger.info(
|
||||
f'[GitHub] Workflow completed for {selected_repo}#{issue.number} on branch {head_branch}'
|
||||
)
|
||||
GithubFailingAction.leave_requesting_comment(issue, workflow_group)
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_full_repo_name(repo_obj: dict) -> str:
|
||||
owner = repo_obj['owner']['login']
|
||||
repo_name = repo_obj['name']
|
||||
return f'{owner}/{repo_name}'
|
||||
|
||||
@staticmethod
|
||||
async def create_github_view_from_payload(
|
||||
message: Message, token_manager: TokenManager
|
||||
) -> ResolverViewInterface:
|
||||
"""Create the appropriate class (GithubIssue or GithubPRComment) based on the payload.
|
||||
Also return metadata about the event (e.g., action type).
|
||||
"""
|
||||
payload = message.message.get('payload', {})
|
||||
repo_obj = payload['repository']
|
||||
user_id = payload['sender']['id']
|
||||
username = payload['sender']['login']
|
||||
|
||||
        keycloak_user_id = await token_manager.get_user_id_from_idp_user_id(
|
||||
user_id, ProviderType.GITHUB
|
||||
)
|
||||
|
||||
        if keycloak_user_id is None:
            logger.warning(f'Got invalid keycloak user id for GitHub User {user_id}')
|
||||
|
||||
selected_repo = GithubFactory.get_full_repo_name(repo_obj)
|
||||
is_public_repo = not repo_obj.get('private', True)
|
||||
user_info = UserData(
|
||||
            user_id=user_id, username=username, keycloak_user_id=keycloak_user_id
|
||||
)
|
||||
|
||||
installation_id = message.message['installation']
|
||||
|
||||
if GithubFactory.is_labeled_issue(message):
|
||||
issue_number = payload['issue']['number']
|
||||
logger.info(
|
||||
f'[GitHub] Creating view for labeled issue from {username} in {selected_repo}#{issue_number}'
|
||||
)
|
||||
return GithubIssue(
|
||||
issue_number=issue_number,
|
||||
installation_id=installation_id,
|
||||
full_repo_name=selected_repo,
|
||||
is_public_repo=is_public_repo,
|
||||
raw_payload=message,
|
||||
user_info=user_info,
|
||||
conversation_id='',
|
||||
uuid=str(uuid4()),
|
||||
should_extract=True,
|
||||
send_summary_instruction=True,
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
)
|
||||
|
||||
elif GithubFactory.is_issue_comment(message):
|
||||
issue_number = payload['issue']['number']
|
||||
comment_body = payload['comment']['body']
|
||||
comment_id = payload['comment']['id']
|
||||
logger.info(
|
||||
f'[GitHub] Creating view for issue comment from {username} in {selected_repo}#{issue_number}'
|
||||
)
|
||||
return GithubIssueComment(
|
||||
issue_number=issue_number,
|
||||
comment_body=comment_body,
|
||||
comment_id=comment_id,
|
||||
installation_id=installation_id,
|
||||
full_repo_name=selected_repo,
|
||||
is_public_repo=is_public_repo,
|
||||
raw_payload=message,
|
||||
user_info=user_info,
|
||||
conversation_id='',
|
||||
uuid=None,
|
||||
should_extract=True,
|
||||
send_summary_instruction=True,
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
)
|
||||
|
||||
elif GithubFactory.is_pr_comment(message):
|
||||
issue_number = payload['issue']['number']
|
||||
logger.info(
|
||||
f'[GitHub] Creating view for PR comment from {username} in {selected_repo}#{issue_number}'
|
||||
)
|
||||
|
||||
access_token = ''
|
||||
with GithubIntegration(
|
||||
GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
) as integration:
|
||||
access_token = integration.get_access_token(installation_id).token
|
||||
|
||||
head_ref = None
|
||||
with Github(access_token) as gh:
|
||||
repo = gh.get_repo(selected_repo)
|
||||
pull_request = repo.get_pull(issue_number)
|
||||
head_ref = pull_request.head.ref
|
||||
logger.info(
|
||||
f'[GitHub] Found PR branch {head_ref} for {selected_repo}#{issue_number}'
|
||||
)
|
||||
|
||||
comment_id = payload['comment']['id']
|
||||
return GithubPRComment(
|
||||
issue_number=issue_number,
|
||||
branch_name=head_ref,
|
||||
comment_body=payload['comment']['body'],
|
||||
comment_id=comment_id,
|
||||
installation_id=installation_id,
|
||||
full_repo_name=selected_repo,
|
||||
is_public_repo=is_public_repo,
|
||||
raw_payload=message,
|
||||
user_info=user_info,
|
||||
conversation_id='',
|
||||
uuid=None,
|
||||
should_extract=True,
|
||||
send_summary_instruction=True,
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
)
|
||||
|
||||
elif GithubFactory.is_inline_pr_comment(message):
|
||||
pr_number = payload['pull_request']['number']
|
||||
branch_name = payload['pull_request']['head']['ref']
|
||||
comment_id = payload['comment']['id']
|
||||
comment_node_id = payload['comment']['node_id']
|
||||
file_path = payload['comment']['path']
|
||||
line_number = payload['comment']['line']
|
||||
logger.info(
|
||||
f'[GitHub] Creating view for inline PR comment from {username} in {selected_repo}#{pr_number} at {file_path}'
|
||||
)
|
||||
|
||||
return GithubInlinePRComment(
|
||||
issue_number=pr_number,
|
||||
branch_name=branch_name,
|
||||
comment_body=payload['comment']['body'],
|
||||
comment_node_id=comment_node_id,
|
||||
comment_id=comment_id,
|
||||
file_location=file_path,
|
||||
line_number=line_number,
|
||||
installation_id=installation_id,
|
||||
full_repo_name=selected_repo,
|
||||
is_public_repo=is_public_repo,
|
||||
raw_payload=message,
|
||||
user_info=user_info,
|
||||
conversation_id='',
|
||||
uuid=None,
|
||||
should_extract=True,
|
||||
send_summary_instruction=True,
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid payload: must contain either 'issue' or 'pull_request'"
|
||||
)
|
||||
102
enterprise/integrations/github/queries.py
Normal file
@@ -0,0 +1,102 @@
|
||||
PR_QUERY_BY_NODE_ID = """
|
||||
query($nodeId: ID!, $pr_number: Int!, $commits_after: String, $comments_after: String, $reviews_after: String) {
|
||||
node(id: $nodeId) {
|
||||
... on Repository {
|
||||
name
|
||||
owner {
|
||||
login
|
||||
}
|
||||
languages(first: 10, orderBy: {field: SIZE, direction: DESC}) {
|
||||
nodes {
|
||||
name
|
||||
}
|
||||
}
|
||||
pullRequest(number: $pr_number) {
|
||||
number
|
||||
title
|
||||
body
|
||||
author {
|
||||
login
|
||||
}
|
||||
merged
|
||||
mergedAt
|
||||
mergedBy {
|
||||
login
|
||||
}
|
||||
state
|
||||
mergeCommit {
|
||||
oid
|
||||
}
|
||||
comments(first: 50, after: $comments_after) {
|
||||
pageInfo {
|
||||
hasNextPage
|
||||
endCursor
|
||||
}
|
||||
nodes {
|
||||
author {
|
||||
login
|
||||
}
|
||||
body
|
||||
createdAt
|
||||
}
|
||||
}
|
||||
commits(first: 50, after: $commits_after) {
|
||||
pageInfo {
|
||||
hasNextPage
|
||||
endCursor
|
||||
}
|
||||
nodes {
|
||||
commit {
|
||||
oid
|
||||
message
|
||||
committedDate
|
||||
author {
|
||||
name
|
||||
email
|
||||
user {
|
||||
login
|
||||
}
|
||||
}
|
||||
additions
|
||||
deletions
|
||||
changedFiles
|
||||
}
|
||||
}
|
||||
}
|
||||
reviews(first: 50, after: $reviews_after) {
|
||||
pageInfo {
|
||||
hasNextPage
|
||||
endCursor
|
||||
}
|
||||
nodes {
|
||||
author {
|
||||
login
|
||||
}
|
||||
body
|
||||
state
|
||||
createdAt
|
||||
comments(first: 50) {
|
||||
pageInfo {
|
||||
hasNextPage
|
||||
endCursor
|
||||
}
|
||||
nodes {
|
||||
author {
|
||||
login
|
||||
}
|
||||
body
|
||||
createdAt
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
rateLimit {
|
||||
remaining
|
||||
limit
|
||||
resetAt
|
||||
}
|
||||
}
|
||||
"""
|
||||
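# Illustrative usage (an assumption, not part of the original module): the query above is
# intended for the GitHub GraphQL API, with variables along the lines of
#   {"nodeId": "<repository node id>", "pr_number": 123,
#    "commits_after": None, "comments_after": None, "reviews_after": None}
# The `pageInfo.endCursor` values returned for comments, commits and reviews can be fed
# back into the corresponding *_after variables to paginate each list independently.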
249
enterprise/integrations/gitlab/gitlab_manager.py
Normal file
@@ -0,0 +1,249 @@
|
||||
from types import MappingProxyType
|
||||
|
||||
from integrations.gitlab.gitlab_view import (
|
||||
GitlabFactory,
|
||||
GitlabInlineMRComment,
|
||||
GitlabIssue,
|
||||
GitlabIssueComment,
|
||||
GitlabMRComment,
|
||||
GitlabViewType,
|
||||
)
|
||||
from integrations.manager import Manager
|
||||
from integrations.models import Message, SourceType
|
||||
from integrations.types import ResolverViewInterface
|
||||
from integrations.utils import (
|
||||
CONVERSATION_URL,
|
||||
HOST_URL,
|
||||
OPENHANDS_RESOLVER_TEMPLATES_DIR,
|
||||
)
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.utils.conversation_callback_utils import register_callback_processor
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
|
||||
from openhands.storage.data_models.user_secrets import UserSecrets
|
||||
|
||||
|
||||
class GitlabManager(Manager):
|
||||
def __init__(self, token_manager: TokenManager, data_collector: None = None):
|
||||
self.token_manager = token_manager
|
||||
|
||||
self.jinja_env = Environment(
|
||||
loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR + 'gitlab')
|
||||
)
|
||||
|
||||
def _confirm_incoming_source_type(self, message: Message):
|
||||
if message.source != SourceType.GITLAB:
|
||||
raise ValueError(f'Unexpected message source {message.source}')
|
||||
|
||||
async def _user_has_write_access_to_repo(
|
||||
self, project_id: str, user_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the user has write access to the repository (can pull/push changes and open merge requests).
|
||||
|
||||
Args:
|
||||
project_id: The ID of the GitLab project
|
||||
|
||||
user_id: The GitLab user ID
|
||||
|
||||
Returns:
|
||||
bool: True if the user has write access to the repository, False otherwise
|
||||
"""
|
||||
|
||||
keycloak_user_id = await self.token_manager.get_user_id_from_idp_user_id(
|
||||
user_id, ProviderType.GITLAB
|
||||
)
|
||||
if keycloak_user_id is None:
|
||||
            logger.warning(f'Got invalid keycloak user id for GitLab User {user_id}')
|
||||
return False
|
||||
|
||||
gitlab_service = GitLabServiceImpl(external_auth_id=keycloak_user_id)
|
||||
return await gitlab_service.user_has_write_access(project_id)
|
||||
|
||||
async def receive_message(self, message: Message):
|
||||
self._confirm_incoming_source_type(message)
|
||||
if await self.is_job_requested(message):
|
||||
gitlab_view = await GitlabFactory.create_gitlab_view_from_payload(
|
||||
message, self.token_manager
|
||||
)
|
||||
logger.info(
|
||||
f'[GitLab] Creating job for {gitlab_view.user_info.username} in {gitlab_view.full_repo_name}#{gitlab_view.issue_number}'
|
||||
)
|
||||
|
||||
await self.start_job(gitlab_view)
|
||||
|
||||
async def is_job_requested(self, message) -> bool:
|
||||
self._confirm_incoming_source_type(message)
|
||||
if not (
|
||||
GitlabFactory.is_labeled_issue(message)
|
||||
or GitlabFactory.is_issue_comment(message)
|
||||
or GitlabFactory.is_mr_comment(message)
|
||||
or GitlabFactory.is_mr_comment(message, inline=True)
|
||||
):
|
||||
return False
|
||||
|
||||
payload = message.message['payload']
|
||||
|
||||
repo_obj = payload['project']
|
||||
project_id = repo_obj['id']
|
||||
selected_project = repo_obj['path_with_namespace']
|
||||
user = payload['user']
|
||||
user_id = user['id']
|
||||
username = user['username']
|
||||
|
||||
logger.info(
|
||||
f'[GitLab] Checking permissions for {username} in {selected_project}'
|
||||
)
|
||||
|
||||
has_write_access = await self._user_has_write_access_to_repo(
|
||||
project_id=str(project_id), user_id=user_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'[GitLab]: {username} access in {selected_project}: {has_write_access}'
|
||||
)
|
||||
# Check if the user has write access to the repository
|
||||
return has_write_access
|
||||
|
||||
async def send_message(self, message: Message, gitlab_view: ResolverViewInterface):
|
||||
"""
|
||||
Send a message to GitLab based on the view type.
|
||||
|
||||
Args:
|
||||
message: The message to send
|
||||
gitlab_view: The GitLab view object containing issue/PR/comment info
|
||||
"""
|
||||
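        # Dispatch on the concrete view type: MR comments (inline or not) reply to the
        # merge request discussion, issue comments reply to the issue discussion, and a
        # labeled issue gets a new top-level comment (no discussion id).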
keycloak_user_id = gitlab_view.user_info.keycloak_user_id
|
||||
gitlab_service = GitLabServiceImpl(external_auth_id=keycloak_user_id)
|
||||
|
||||
outgoing_message = message.message
|
||||
|
||||
if isinstance(gitlab_view, GitlabInlineMRComment) or isinstance(
|
||||
gitlab_view, GitlabMRComment
|
||||
):
|
||||
await gitlab_service.reply_to_mr(
|
||||
gitlab_view.project_id,
|
||||
gitlab_view.issue_number,
|
||||
gitlab_view.discussion_id,
|
||||
message.message,
|
||||
)
|
||||
|
||||
elif isinstance(gitlab_view, GitlabIssueComment):
|
||||
await gitlab_service.reply_to_issue(
|
||||
gitlab_view.project_id,
|
||||
gitlab_view.issue_number,
|
||||
gitlab_view.discussion_id,
|
||||
outgoing_message,
|
||||
)
|
||||
elif isinstance(gitlab_view, GitlabIssue):
|
||||
await gitlab_service.reply_to_issue(
|
||||
gitlab_view.project_id,
|
||||
gitlab_view.issue_number,
|
||||
None, # no discussion id, issue is tagged
|
||||
outgoing_message,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f'[GitLab] Unsupported view type: {type(gitlab_view).__name__}'
|
||||
)
|
||||
|
||||
async def start_job(self, gitlab_view: GitlabViewType):
|
||||
"""
|
||||
Start a job for the GitLab view.
|
||||
|
||||
Args:
|
||||
gitlab_view: The GitLab view object containing issue/PR/comment info
|
||||
"""
|
||||
# Importing here prevents circular import
|
||||
from server.conversation_callback_processor.gitlab_callback_processor import (
|
||||
GitlabCallbackProcessor,
|
||||
)
|
||||
|
||||
try:
|
||||
try:
|
||||
user_info = gitlab_view.user_info
|
||||
|
||||
logger.info(
|
||||
f'[GitLab] Starting job for {user_info.username} in {gitlab_view.full_repo_name}#{gitlab_view.issue_number}'
|
||||
)
|
||||
|
||||
user_token = await self.token_manager.get_idp_token_from_idp_user_id(
|
||||
str(user_info.user_id), ProviderType.GITLAB
|
||||
)
|
||||
|
||||
if not user_token:
|
||||
logger.warning(
|
||||
f'[GitLab] No token found for user {user_info.username} (id={user_info.user_id})'
|
||||
)
|
||||
raise MissingSettingsError('Missing settings')
|
||||
|
||||
logger.info(
|
||||
f'[GitLab] Creating new conversation for user {user_info.username}'
|
||||
)
|
||||
|
||||
secret_store = UserSecrets(
|
||||
provider_tokens=MappingProxyType(
|
||||
{
|
||||
ProviderType.GITLAB: ProviderToken(
|
||||
token=SecretStr(user_token),
|
||||
user_id=str(user_info.user_id),
|
||||
)
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
await gitlab_view.create_new_conversation(
|
||||
self.jinja_env, secret_store.provider_tokens
|
||||
)
|
||||
|
||||
conversation_id = gitlab_view.conversation_id
|
||||
|
||||
logger.info(
|
||||
f'[GitLab] Created conversation {conversation_id} for user {user_info.username}'
|
||||
)
|
||||
|
||||
# Create a GitlabCallbackProcessor for this conversation
|
||||
processor = GitlabCallbackProcessor(
|
||||
gitlab_view=gitlab_view,
|
||||
send_summary_instruction=True,
|
||||
)
|
||||
|
||||
# Register the callback processor
|
||||
register_callback_processor(conversation_id, processor)
|
||||
|
||||
logger.info(
|
||||
f'[GitLab] Created callback processor for conversation {conversation_id}'
|
||||
)
|
||||
|
||||
conversation_link = CONVERSATION_URL.format(conversation_id)
|
||||
msg_info = f"I'm on it! {user_info.username} can [track my progress at all-hands.dev]({conversation_link})"
|
||||
|
||||
except MissingSettingsError as e:
|
||||
logger.warning(
|
||||
f'[GitLab] Missing settings error for user {user_info.username}: {str(e)}'
|
||||
)
|
||||
|
||||
msg_info = f'@{user_info.username} please re-login into [OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
|
||||
except LLMAuthenticationError as e:
|
||||
logger.warning(
|
||||
f'[GitLab] LLM authentication error for user {user_info.username}: {str(e)}'
|
||||
)
|
||||
|
||||
msg_info = f'@{user_info.username} please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
|
||||
# Send the acknowledgment message
|
||||
msg = self.create_outgoing_message(msg_info)
|
||||
await self.send_message(msg, gitlab_view)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f'[GitLab] Error starting job: {str(e)}')
|
||||
msg = self.create_outgoing_message(
|
||||
msg='Uh oh! There was an unexpected error starting the job :('
|
||||
)
|
||||
await self.send_message(msg, gitlab_view)
|
||||
529
enterprise/integrations/gitlab/gitlab_service.py
Normal file
@@ -0,0 +1,529 @@
|
||||
import asyncio
|
||||
|
||||
from integrations.types import GitLabResourceType
|
||||
from integrations.utils import store_repositories_in_db
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
from storage.gitlab_webhook import GitlabWebhook, WebhookStatus
|
||||
from storage.gitlab_webhook_store import GitlabWebhookStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.gitlab.gitlab_service import GitLabService
|
||||
from openhands.integrations.service_types import (
|
||||
ProviderType,
|
||||
RateLimitError,
|
||||
Repository,
|
||||
RequestMethod,
|
||||
)
|
||||
from openhands.server.types import AppMode
|
||||
|
||||
|
||||
class SaaSGitLabService(GitLabService):
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str | None = None,
|
||||
external_auth_token: SecretStr | None = None,
|
||||
external_auth_id: str | None = None,
|
||||
token: SecretStr | None = None,
|
||||
external_token_manager: bool = False,
|
||||
base_domain: str | None = None,
|
||||
):
|
||||
logger.info(
|
||||
            f'SaaSGitLabService created with user_id {user_id}, external_auth_id {external_auth_id}, external_auth_token {"set" if external_auth_token else "None"}, gitlab_token {"set" if token else "None"}, external_token_manager {external_token_manager}'
|
||||
)
|
||||
super().__init__(
|
||||
user_id=user_id,
|
||||
external_auth_token=external_auth_token,
|
||||
external_auth_id=external_auth_id,
|
||||
token=token,
|
||||
external_token_manager=external_token_manager,
|
||||
base_domain=base_domain,
|
||||
)
|
||||
|
||||
self.external_auth_token = external_auth_token
|
||||
self.external_auth_id = external_auth_id
|
||||
self.token_manager = TokenManager(external=external_token_manager)
|
||||
|
||||
async def get_latest_token(self) -> SecretStr | None:
|
||||
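        # Resolve a GitLab token from whichever credential is available, in priority
        # order: an external auth access token, an offline token looked up via the
        # external auth ID, or the IdP user ID. Returns None when none of these are set.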
gitlab_token = None
|
||||
if self.external_auth_token:
|
||||
gitlab_token = SecretStr(
|
||||
await self.token_manager.get_idp_token(
|
||||
self.external_auth_token.get_secret_value(), idp=ProviderType.GITLAB
|
||||
)
|
||||
)
|
||||
logger.debug(
|
||||
f'Got GitLab token {gitlab_token} from access token: {self.external_auth_token}'
|
||||
)
|
||||
elif self.external_auth_id:
|
||||
offline_token = await self.token_manager.load_offline_token(
|
||||
self.external_auth_id
|
||||
)
|
||||
gitlab_token = SecretStr(
|
||||
await self.token_manager.get_idp_token_from_offline_token(
|
||||
offline_token, ProviderType.GITLAB
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
f'Got GitLab token {gitlab_token.get_secret_value()} from external auth user ID: {self.external_auth_id}'
|
||||
)
|
||||
elif self.user_id:
|
||||
gitlab_token = SecretStr(
|
||||
await self.token_manager.get_idp_token_from_idp_user_id(
|
||||
self.user_id, ProviderType.GITLAB
|
||||
)
|
||||
)
|
||||
logger.debug(
|
||||
f'Got Gitlab token {gitlab_token} from user ID: {self.user_id}'
|
||||
)
|
||||
else:
|
||||
logger.warning('external_auth_token and user_id not set!')
|
||||
return gitlab_token
|
||||
|
||||
async def get_owned_groups(self) -> list[dict]:
|
||||
"""
|
||||
Get all groups for which the current user is the owner.
|
||||
|
||||
Returns:
|
||||
list[dict]: A list of groups owned by the current user.
|
||||
"""
|
||||
url = f'{self.BASE_URL}/groups'
|
||||
params = {'owned': 'true', 'per_page': 100, 'top_level_only': 'true'}
|
||||
|
||||
try:
|
||||
response, headers = await self._make_request(url, params)
|
||||
return response
|
||||
except Exception:
|
||||
logger.warning('Error fetching owned groups', exc_info=True)
|
||||
return []
|
||||
|
||||
async def add_owned_projects_and_groups_to_db(self, owned_personal_projects):
|
||||
"""
|
||||
Add owned projects and groups to the database for webhook tracking.
|
||||
|
||||
Args:
|
||||
owned_personal_projects: List of personal projects owned by the user
|
||||
"""
|
||||
owned_groups = await self.get_owned_groups()
|
||||
webhooks = []
|
||||
|
||||
def build_group_webhook_entries(groups):
|
||||
return [
|
||||
GitlabWebhook(
|
||||
group_id=str(group['id']),
|
||||
project_id=None,
|
||||
user_id=self.external_auth_id,
|
||||
webhook_exists=False,
|
||||
)
|
||||
for group in groups
|
||||
]
|
||||
|
||||
def build_project_webhook_entries(projects):
|
||||
return [
|
||||
GitlabWebhook(
|
||||
group_id=None,
|
||||
project_id=str(project['id']),
|
||||
user_id=self.external_auth_id,
|
||||
webhook_exists=False,
|
||||
)
|
||||
for project in projects
|
||||
]
|
||||
|
||||
# Collect all webhook entries
|
||||
webhooks.extend(build_group_webhook_entries(owned_groups))
|
||||
webhooks.extend(build_project_webhook_entries(owned_personal_projects))
|
||||
|
||||
# Store webhooks in the database
|
||||
if webhooks:
|
||||
try:
|
||||
webhook_store = GitlabWebhookStore()
|
||||
await webhook_store.store_webhooks(webhooks)
|
||||
logger.info(
|
||||
f'Added GitLab webhooks to db for user {self.external_auth_id}'
|
||||
)
|
||||
except Exception:
|
||||
logger.warning('Failed to add Gitlab webhooks to db', exc_info=True)
|
||||
|
||||
async def store_repository_data(
|
||||
self, users_personal_projects: list[dict], repositories: list[Repository]
|
||||
) -> None:
|
||||
"""
|
||||
Store repository data in the database.
|
||||
This function combines the functionality of add_owned_projects_and_groups_to_db and store_repositories_in_db.
|
||||
|
||||
Args:
|
||||
users_personal_projects: List of personal projects owned by the user
|
||||
repositories: List of Repository objects to store
|
||||
"""
|
||||
try:
|
||||
# First, add owned projects and groups to the database
|
||||
await self.add_owned_projects_and_groups_to_db(users_personal_projects)
|
||||
|
||||
# Then, store repositories in the database
|
||||
await store_repositories_in_db(repositories, self.external_auth_id)
|
||||
|
||||
logger.info(
|
||||
f'Successfully stored repository data for user {self.external_auth_id}'
|
||||
)
|
||||
except Exception:
|
||||
logger.warning('Error storing repository data', exc_info=True)
|
||||
|
||||
async def get_all_repositories(
|
||||
self, sort: str, app_mode: AppMode, store_in_background: bool = True
|
||||
) -> list[Repository]:
|
||||
"""
|
||||
Get repositories for the authenticated user, including information about the kind of project.
|
||||
Also collects repositories where the kind is "user" and the user is the owner.
|
||||
|
||||
Args:
|
||||
sort: The field to sort repositories by
|
||||
app_mode: The application mode (OSS or SAAS)
|
||||
|
||||
Returns:
|
||||
List[Repository]: A list of repositories for the authenticated user
|
||||
"""
|
||||
MAX_REPOS = 1000
|
||||
PER_PAGE = 100 # Maximum allowed by GitLab API
|
||||
all_repos: list[dict] = []
|
||||
users_personal_projects: list[dict] = []
|
||||
page = 1
|
||||
|
||||
url = f'{self.BASE_URL}/projects'
|
||||
# Map GitHub's sort values to GitLab's order_by values
|
||||
order_by = {
|
||||
'pushed': 'last_activity_at',
|
||||
'updated': 'last_activity_at',
|
||||
'created': 'created_at',
|
||||
'full_name': 'name',
|
||||
}.get(sort, 'last_activity_at')
|
||||
|
||||
user_id = None
|
||||
try:
|
||||
user_info = await self.get_user()
|
||||
user_id = user_info.id
|
||||
except Exception as e:
|
||||
logger.warning(f'Could not fetch user id: {e}')
|
||||
|
||||
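        # Page through the /projects endpoint (membership=1, PER_PAGE per request) until
        # the API returns no results, the Link header has no rel="next", or MAX_REPOS is reached.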
while len(all_repos) < MAX_REPOS:
|
||||
params = {
|
||||
'page': str(page),
|
||||
'per_page': str(PER_PAGE),
|
||||
'order_by': order_by,
|
||||
'sort': 'desc', # GitLab uses sort for direction (asc/desc)
|
||||
'membership': 1, # Use 1 instead of True
|
||||
}
|
||||
|
||||
try:
|
||||
response, headers = await self._make_request(url, params)
|
||||
|
||||
if not response: # No more repositories
|
||||
break
|
||||
|
||||
# Process each repository to identify user-owned ones
|
||||
for repo in response:
|
||||
namespace = repo.get('namespace', {})
|
||||
kind = namespace.get('kind')
|
||||
owner_id = repo.get('owner', {}).get('id')
|
||||
|
||||
# Collect user owned personal projects
|
||||
if kind == 'user' and str(owner_id) == str(user_id):
|
||||
users_personal_projects.append(repo)
|
||||
|
||||
# Add to all repos regardless
|
||||
all_repos.append(repo)
|
||||
|
||||
page += 1
|
||||
|
||||
# Check if we've reached the last page
|
||||
link_header = headers.get('Link', '')
|
||||
if 'rel="next"' not in link_header:
|
||||
break
|
||||
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f'Error fetching repositories on page {page}', exc_info=True
|
||||
)
|
||||
break
|
||||
|
||||
# Trim to MAX_REPOS if needed and convert to Repository objects
|
||||
all_repos = all_repos[:MAX_REPOS]
|
||||
repositories = [
|
||||
Repository(
|
||||
id=str(repo.get('id')),
|
||||
full_name=str(repo.get('path_with_namespace')),
|
||||
stargazers_count=repo.get('star_count'),
|
||||
git_provider=ProviderType.GITLAB,
|
||||
is_public=repo.get('visibility') == 'public',
|
||||
)
|
||||
for repo in all_repos
|
||||
]
|
||||
|
||||
# Store webhook and repository info
|
||||
if store_in_background:
|
||||
asyncio.create_task(
|
||||
self.store_repository_data(users_personal_projects, repositories)
|
||||
)
|
||||
else:
|
||||
await self.store_repository_data(users_personal_projects, repositories)
|
||||
return repositories
|
||||
|
||||
async def check_resource_exists(
|
||||
self, resource_type: GitLabResourceType, resource_id: str
|
||||
) -> tuple[bool, WebhookStatus | None]:
|
||||
"""
|
||||
Check if resource exists and the user has access to it.
|
||||
|
||||
Args:
|
||||
resource_type: The type of resource
|
||||
resource_id: The ID of resource to check
|
||||
|
||||
Returns:
|
||||
            tuple[bool, WebhookStatus | None]: A tuple containing:
                - bool: True if the resource exists and the user has access to it, False otherwise
                - WebhookStatus | None: A failure status (rate limited or invalid) if the check fails, otherwise None
|
||||
"""
|
||||
|
||||
if resource_type == GitLabResourceType.GROUP:
|
||||
url = f'{self.BASE_URL}/groups/{resource_id}'
|
||||
else:
|
||||
url = f'{self.BASE_URL}/projects/{resource_id}'
|
||||
|
||||
try:
|
||||
response, _ = await self._make_request(url)
|
||||
# If we get a response, the resource exists and the user has access to it
|
||||
return bool(response and 'id' in response), None
|
||||
except RateLimitError:
|
||||
return False, WebhookStatus.RATE_LIMITED
|
||||
except Exception:
|
||||
logger.warning('Resource existence check failed', exc_info=True)
|
||||
return False, WebhookStatus.INVALID
|
||||
|
||||
async def check_webhook_exists_on_resource(
|
||||
self, resource_type: GitLabResourceType, resource_id: str, webhook_url: str
|
||||
) -> tuple[bool, WebhookStatus | None]:
|
||||
"""
|
||||
Check if a webhook already exists for resource with a specific URL.
|
||||
|
||||
Args:
|
||||
resource_type: The type of resource
|
||||
resource_id: The ID of the resource to check
|
||||
webhook_url: The URL of the webhook to check for
|
||||
|
||||
Returns:
|
||||
            tuple[bool, WebhookStatus | None]: A tuple containing:
                - bool: True if the webhook exists, False otherwise
                - WebhookStatus | None: A failure status (rate limited or invalid) if the check fails, otherwise None
|
||||
"""
|
||||
|
||||
# Construct the URL based on the resource type
|
||||
if resource_type == GitLabResourceType.GROUP:
|
||||
url = f'{self.BASE_URL}/groups/{resource_id}/hooks'
|
||||
else:
|
||||
url = f'{self.BASE_URL}/projects/{resource_id}/hooks'
|
||||
|
||||
try:
|
||||
# Get all webhooks for the resource
|
||||
response, _ = await self._make_request(url)
|
||||
|
||||
# Check if any webhook has the specified URL
|
||||
exists = False
|
||||
if response:
|
||||
for webhook in response:
|
||||
if webhook.get('url') == webhook_url:
|
||||
exists = True
|
||||
|
||||
return exists, None
|
||||
|
||||
except RateLimitError:
|
||||
return False, WebhookStatus.RATE_LIMITED
|
||||
except Exception:
|
||||
logger.warning('Webhook existence check failed', exc_info=True)
|
||||
return False, WebhookStatus.INVALID
|
||||
|
||||
async def check_user_has_admin_access_to_resource(
|
||||
self, resource_type: GitLabResourceType, resource_id: str
|
||||
) -> tuple[bool, WebhookStatus | None]:
|
||||
"""
|
||||
Check if the user has admin access to resource (is either an owner or maintainer)
|
||||
|
||||
Args:
|
||||
resource_type: The type of resource
|
||||
resource_id: The ID of the resource to check
|
||||
|
||||
Returns:
|
||||
            tuple[bool, WebhookStatus | None]: A tuple containing:
                - bool: True if the user has admin access to the resource (owner or maintainer), False otherwise
                - WebhookStatus | None: A failure status (rate limited or invalid) if the check fails, otherwise None
|
||||
"""
|
||||
|
||||
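        # GitLab access levels for reference: 10 Guest, 20 Reporter, 30 Developer,
        # 40 Maintainer, 50 Owner. "Admin access" here means level >= 40.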
# For groups, we need to check if the user is an owner or maintainer
|
||||
if resource_type == GitLabResourceType.GROUP:
|
||||
url = f'{self.BASE_URL}/groups/{resource_id}/members/all'
|
||||
try:
|
||||
response, _ = await self._make_request(url)
|
||||
# Check if the current user is in the members list with access level >= 40 (Maintainer or Owner)
|
||||
|
||||
exists = False
|
||||
if response:
|
||||
current_user = await self.get_user()
|
||||
user_id = current_user.id
|
||||
for member in response:
|
||||
if (
|
||||
str(member.get('id')) == str(user_id)
|
||||
and member.get('access_level', 0) >= 40
|
||||
):
|
||||
exists = True
|
||||
return exists, None
|
||||
except RateLimitError:
|
||||
return False, WebhookStatus.RATE_LIMITED
|
||||
except Exception:
|
||||
return False, WebhookStatus.INVALID
|
||||
|
||||
# For projects, we need to check if the user has maintainer or owner access
|
||||
else:
|
||||
url = f'{self.BASE_URL}/projects/{resource_id}/members/all'
|
||||
try:
|
||||
response, _ = await self._make_request(url)
|
||||
exists = False
|
||||
# Check if the current user is in the members list with access level >= 40 (Maintainer)
|
||||
if response:
|
||||
current_user = await self.get_user()
|
||||
user_id = current_user.id
|
||||
for member in response:
|
||||
if (
|
||||
str(member.get('id')) == str(user_id)
|
||||
and member.get('access_level', 0) >= 40
|
||||
):
|
||||
exists = True
|
||||
return exists, None
|
||||
except RateLimitError:
|
||||
return False, WebhookStatus.RATE_LIMITED
|
||||
except Exception:
|
||||
logger.warning('Admin access check failed', exc_info=True)
|
||||
return False, WebhookStatus.INVALID
|
||||
|
||||
async def install_webhook(
|
||||
self,
|
||||
resource_type: GitLabResourceType,
|
||||
resource_id: str,
|
||||
webhook_name: str,
|
||||
webhook_url: str,
|
||||
webhook_secret: str,
|
||||
webhook_uuid: str,
|
||||
scopes: list[str],
|
||||
) -> tuple[str | None, WebhookStatus | None]:
|
||||
"""
|
||||
Install webhook for user's group or project
|
||||
|
||||
Args:
|
||||
resource_type: The type of resource
|
||||
resource_id: The ID of the resource to check
|
||||
webhook_secret: Webhook secret that is used to verify payload
|
||||
webhook_name: Name of webhook
|
||||
webhook_url: Webhook URL
|
||||
            webhook_uuid: Unique identifier attached to the webhook via a custom header
            scopes: Event flags the webhook listens for
|
||||
|
||||
Returns:
|
||||
            tuple[str | None, WebhookStatus | None]: A tuple containing:
                - str | None: The ID of the created webhook if installation succeeded, otherwise None
                - WebhookStatus | None: A failure status (rate limited or invalid) if installation fails, otherwise None
|
||||
"""
|
||||
|
||||
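        # Each entry in `scopes` is expected to be a GitLab hook event flag
        # (e.g. 'issues_events' or 'note_events') and is switched on in webhook_data below.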
description = 'Cloud OpenHands Resolver'
|
||||
|
||||
# Set up webhook parameters
|
||||
webhook_data = {
|
||||
'url': webhook_url,
|
||||
'name': webhook_name,
|
||||
'enable_ssl_verification': True,
|
||||
'token': webhook_secret,
|
||||
'description': description,
|
||||
}
|
||||
|
||||
for scope in scopes:
|
||||
webhook_data[scope] = True
|
||||
|
||||
# Add custom headers with user id
|
||||
if self.external_auth_id:
|
||||
webhook_data['custom_headers'] = [
|
||||
{'key': 'X-OpenHands-User-ID', 'value': self.external_auth_id},
|
||||
{'key': 'X-OpenHands-Webhook-ID', 'value': webhook_uuid},
|
||||
]
|
||||
|
||||
if resource_type == GitLabResourceType.GROUP:
|
||||
url = f'{self.BASE_URL}/groups/{resource_id}/hooks'
|
||||
else:
|
||||
url = f'{self.BASE_URL}/projects/{resource_id}/hooks'
|
||||
|
||||
try:
|
||||
# Make the API request
|
||||
response, _ = await self._make_request(
|
||||
url=url, params=webhook_data, method=RequestMethod.POST
|
||||
)
|
||||
|
||||
if response and 'id' in response:
|
||||
return str(response['id']), None
|
||||
|
||||
# Check if the webhook was created successfully
|
||||
return None, None
|
||||
|
||||
except RateLimitError:
|
||||
return None, WebhookStatus.RATE_LIMITED
|
||||
except Exception:
|
||||
logger.warning('Webhook installation failed', exc_info=True)
|
||||
return None, WebhookStatus.INVALID
|
||||
|
||||
async def user_has_write_access(self, project_id: str) -> bool:
|
||||
url = f'{self.BASE_URL}/projects/{project_id}'
|
||||
try:
|
||||
response, _ = await self._make_request(url)
|
||||
            # The project payload includes the caller's effective permissions;
            # access level >= 30 (Developer) counts as write access
|
||||
|
||||
if 'permissions' not in response:
|
||||
logger.info('permissions not found', extra={'response': response})
|
||||
return False
|
||||
|
||||
permissions = response['permissions']
|
||||
if permissions['project_access']:
|
||||
logger.info('[GitLab]: Checking project access')
|
||||
return permissions['project_access']['access_level'] >= 30
|
||||
|
||||
if permissions['group_access']:
|
||||
logger.info('[GitLab]: Checking group access')
|
||||
return permissions['group_access']['access_level'] >= 30
|
||||
|
||||
return False
|
||||
except Exception:
|
||||
logger.warning('Access check failed', exc_info=True)
|
||||
return False
|
||||
|
||||
async def reply_to_issue(
|
||||
self, project_id: str, issue_number: str, discussion_id: str | None, body: str
|
||||
):
|
||||
"""
|
||||
Either create new comment thread, or reply to comment thread (depending on discussion_id param)
|
||||
"""
|
||||
try:
|
||||
if discussion_id:
|
||||
url = f'{self.BASE_URL}/projects/{project_id}/issues/{issue_number}/discussions/{discussion_id}/notes'
|
||||
else:
|
||||
url = f'{self.BASE_URL}/projects/{project_id}/issues/{issue_number}/discussions'
|
||||
params = {'body': body}
|
||||
|
||||
await self._make_request(url=url, params=params, method=RequestMethod.POST)
|
||||
except Exception as e:
|
||||
logger.exception(f'[GitLab]: Reply to issue failed {e}')
|
||||
|
||||
async def reply_to_mr(
|
||||
self, project_id: str, merge_request_iid: str, discussion_id: str, body: str
|
||||
):
|
||||
"""
|
||||
Reply to comment thread on MR
|
||||
"""
|
||||
try:
|
||||
url = f'{self.BASE_URL}/projects/{project_id}/merge_requests/{merge_request_iid}/discussions/{discussion_id}/notes'
|
||||
params = {'body': body}
|
||||
|
||||
await self._make_request(url=url, params=params, method=RequestMethod.POST)
|
||||
except Exception as e:
|
||||
logger.exception(f'[GitLab]: Reply to MR failed {e}')
|
||||
450
enterprise/integrations/gitlab/gitlab_view.py
Normal file
@@ -0,0 +1,450 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from integrations.models import Message
|
||||
from integrations.types import ResolverViewInterface, UserData
|
||||
from integrations.utils import HOST, get_oh_labels, has_exact_mention
|
||||
from jinja2 import Environment
|
||||
from server.auth.token_manager import TokenManager, get_config
|
||||
from storage.database import session_maker
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
|
||||
from openhands.integrations.service_types import Comment
|
||||
from openhands.server.services.conversation_service import create_new_conversation
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
|
||||
OH_LABEL, INLINE_OH_LABEL = get_oh_labels(HOST)
|
||||
CONFIDENTIAL_NOTE = 'confidential_note'
|
||||
NOTE_TYPES = ['note', CONFIDENTIAL_NOTE]
|
||||
|
||||
# =================================================
|
||||
# SECTION: Factory to create appropriate GitLab view
|
||||
# =================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class GitlabIssue(ResolverViewInterface):
|
||||
installation_id: str # Webhook installation ID for Gitlab (comes from our DB)
|
||||
issue_number: int
|
||||
project_id: int
|
||||
full_repo_name: str
|
||||
is_public_repo: bool
|
||||
user_info: UserData
|
||||
raw_payload: Message
|
||||
conversation_id: str
|
||||
should_extract: bool
|
||||
send_summary_instruction: bool
|
||||
title: str
|
||||
description: str
|
||||
previous_comments: list[Comment]
|
||||
is_mr: bool
|
||||
|
||||
async def _load_resolver_context(self):
|
||||
gitlab_service = GitLabServiceImpl(
|
||||
external_auth_id=self.user_info.keycloak_user_id
|
||||
)
|
||||
|
||||
self.previous_comments = await gitlab_service.get_issue_or_mr_comments(
|
||||
self.project_id, self.issue_number, is_mr=self.is_mr
|
||||
)
|
||||
|
||||
(
|
||||
self.title,
|
||||
self.description,
|
||||
) = await gitlab_service.get_issue_or_mr_title_and_body(
|
||||
self.project_id, self.issue_number, is_mr=self.is_mr
|
||||
)
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
user_instructions_template = jinja_env.get_template('issue_prompt.j2')
|
||||
await self._load_resolver_context()
|
||||
|
||||
user_instructions = user_instructions_template.render(
|
||||
issue_number=self.issue_number,
|
||||
)
|
||||
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'issue_conversation_instructions.j2'
|
||||
)
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
issue_title=self.title,
|
||||
issue_body=self.description,
|
||||
comments=self.previous_comments,
|
||||
)
|
||||
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def _get_user_secrets(self):
|
||||
secrets_store = SaasSecretsStore(
|
||||
self.user_info.keycloak_user_id, session_maker, get_config()
|
||||
)
|
||||
user_secrets = await secrets_store.load()
|
||||
|
||||
return user_secrets.custom_secrets if user_secrets else None
|
||||
|
||||
async def create_new_conversation(
|
||||
self, jinja_env: Environment, git_provider_tokens: PROVIDER_TOKEN_TYPE
|
||||
):
|
||||
custom_secrets = await self._get_user_secrets()
|
||||
|
||||
user_instructions, conversation_instructions = await self._get_instructions(
|
||||
jinja_env
|
||||
)
|
||||
agent_loop_info = await create_new_conversation(
|
||||
user_id=self.user_info.keycloak_user_id,
|
||||
git_provider_tokens=git_provider_tokens,
|
||||
custom_secrets=custom_secrets,
|
||||
selected_repository=self.full_repo_name,
|
||||
selected_branch=None,
|
||||
initial_user_msg=user_instructions,
|
||||
conversation_instructions=conversation_instructions,
|
||||
image_urls=None,
|
||||
conversation_trigger=ConversationTrigger.RESOLVER,
|
||||
replay_json=None,
|
||||
)
|
||||
self.conversation_id = agent_loop_info.conversation_id
|
||||
return self.conversation_id
|
||||
|
||||
|
||||
@dataclass
|
||||
class GitlabIssueComment(GitlabIssue):
|
||||
comment_body: str
|
||||
discussion_id: str
|
||||
confidential: bool
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
user_instructions_template = jinja_env.get_template('issue_prompt.j2')
|
||||
await self._load_resolver_context()
|
||||
|
||||
user_instructions = user_instructions_template.render(
|
||||
issue_comment=self.comment_body
|
||||
)
|
||||
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'issue_conversation_instructions.j2'
|
||||
)
|
||||
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
issue_number=self.issue_number,
|
||||
issue_title=self.title,
|
||||
issue_body=self.description,
|
||||
comments=self.previous_comments,
|
||||
)
|
||||
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
|
||||
@dataclass
|
||||
class GitlabMRComment(GitlabIssueComment):
|
||||
branch_name: str
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
user_instructions_template = jinja_env.get_template('mr_update_prompt.j2')
|
||||
await self._load_resolver_context()
|
||||
|
||||
user_instructions = user_instructions_template.render(
|
||||
mr_comment=self.comment_body,
|
||||
)
|
||||
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'mr_update_conversation_instructions.j2'
|
||||
)
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
mr_number=self.issue_number,
|
||||
branch_name=self.branch_name,
|
||||
mr_title=self.title,
|
||||
mr_body=self.description,
|
||||
comments=self.previous_comments,
|
||||
)
|
||||
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def create_new_conversation(
|
||||
self, jinja_env: Environment, git_provider_tokens: PROVIDER_TOKEN_TYPE
|
||||
):
|
||||
custom_secrets = await self._get_user_secrets()
|
||||
|
||||
user_instructions, conversation_instructions = await self._get_instructions(
|
||||
jinja_env
|
||||
)
|
||||
agent_loop_info = await create_new_conversation(
|
||||
user_id=self.user_info.keycloak_user_id,
|
||||
git_provider_tokens=git_provider_tokens,
|
||||
custom_secrets=custom_secrets,
|
||||
selected_repository=self.full_repo_name,
|
||||
selected_branch=self.branch_name,
|
||||
initial_user_msg=user_instructions,
|
||||
conversation_instructions=conversation_instructions,
|
||||
image_urls=None,
|
||||
conversation_trigger=ConversationTrigger.RESOLVER,
|
||||
replay_json=None,
|
||||
)
|
||||
self.conversation_id = agent_loop_info.conversation_id
|
||||
return self.conversation_id
|
||||
|
||||
|
||||
@dataclass
|
||||
class GitlabInlineMRComment(GitlabMRComment):
|
||||
file_location: str
|
||||
line_number: int
|
||||
|
||||
async def _load_resolver_context(self):
|
||||
gitlab_service = GitLabServiceImpl(
|
||||
external_auth_id=self.user_info.keycloak_user_id
|
||||
)
|
||||
|
||||
(
|
||||
self.title,
|
||||
self.description,
|
||||
) = await gitlab_service.get_issue_or_mr_title_and_body(
|
||||
self.project_id, self.issue_number, is_mr=self.is_mr
|
||||
)
|
||||
|
||||
self.previous_comments = await gitlab_service.get_review_thread_comments(
|
||||
self.project_id, self.issue_number, self.discussion_id
|
||||
)
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
user_instructions_template = jinja_env.get_template('mr_update_prompt.j2')
|
||||
await self._load_resolver_context()
|
||||
|
||||
user_instructions = user_instructions_template.render(
|
||||
mr_comment=self.comment_body,
|
||||
)
|
||||
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'mr_update_conversation_instructions.j2'
|
||||
)
|
||||
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
mr_number=self.issue_number,
|
||||
mr_title=self.title,
|
||||
mr_body=self.description,
|
||||
branch_name=self.branch_name,
|
||||
file_location=self.file_location,
|
||||
line_number=self.line_number,
|
||||
comments=self.previous_comments,
|
||||
)
|
||||
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
|
||||
GitlabViewType = (
|
||||
GitlabInlineMRComment | GitlabMRComment | GitlabIssueComment | GitlabIssue
|
||||
)
|
||||
|
||||
|
||||
class GitlabFactory:
|
||||
@staticmethod
|
||||
def is_labeled_issue(message: Message) -> bool:
|
||||
payload = message.message['payload']
|
||||
object_kind = payload.get('object_kind')
|
||||
event_type = payload.get('event_type')
|
||||
|
||||
if object_kind == 'issue' and event_type == 'issue':
|
||||
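            # Fire only on the transition where the OpenHands label was newly added:
            # absent from the previous labels and present in the current ones.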
changes = payload.get('changes', {})
|
||||
labels = changes.get('labels', {})
|
||||
previous = labels.get('previous', [])
|
||||
current = labels.get('current', [])
|
||||
|
||||
previous_labels = [obj['title'] for obj in previous]
|
||||
current_labels = [obj['title'] for obj in current]
|
||||
|
||||
if OH_LABEL not in previous_labels and OH_LABEL in current_labels:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def is_issue_comment(message: Message) -> bool:
|
||||
payload = message.message['payload']
|
||||
object_kind = payload.get('object_kind')
|
||||
event_type = payload.get('event_type')
|
||||
issue = payload.get('issue')
|
||||
|
||||
if object_kind == 'note' and event_type in NOTE_TYPES and issue:
|
||||
comment_body = payload.get('object_attributes', {}).get('note', '')
|
||||
return has_exact_mention(comment_body, INLINE_OH_LABEL)
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def is_mr_comment(message: Message, inline=False) -> bool:
|
||||
payload = message.message['payload']
|
||||
object_kind = payload.get('object_kind')
|
||||
event_type = payload.get('event_type')
|
||||
merge_request = payload.get('merge_request')
|
||||
|
||||
if not (object_kind == 'note' and event_type in NOTE_TYPES and merge_request):
|
||||
return False
|
||||
|
||||
# Check that the note belongs to a merge request
|
||||
object_attributes = payload.get('object_attributes', {})
|
||||
noteable_type = object_attributes.get('noteable_type')
|
||||
|
||||
if noteable_type != 'MergeRequest':
|
||||
return False
|
||||
|
||||
# Check whether comment is inline
|
||||
change_position = object_attributes.get('change_position')
|
||||
if inline and not change_position:
|
||||
return False
|
||||
if not inline and change_position:
|
||||
return False
|
||||
|
||||
# Check body
|
||||
comment_body = object_attributes.get('note', '')
|
||||
return has_exact_mention(comment_body, INLINE_OH_LABEL)
|
||||
|
||||
@staticmethod
|
||||
def determine_if_confidential(event_type: str):
|
||||
return event_type == CONFIDENTIAL_NOTE
|
||||
|
||||
@staticmethod
|
||||
async def create_gitlab_view_from_payload(
|
||||
message: Message, token_manager: TokenManager
|
||||
) -> ResolverViewInterface:
|
||||
payload = message.message['payload']
|
||||
installation_id = message.message['installation_id']
|
||||
user = payload['user']
|
||||
user_id = user['id']
|
||||
username = user['username']
|
||||
repo_obj = payload['project']
|
||||
selected_project = repo_obj['path_with_namespace']
|
||||
is_public_repo = repo_obj['visibility_level'] == 0
|
||||
project_id = payload['object_attributes']['project_id']
|
||||
|
||||
keycloak_user_id = await token_manager.get_user_id_from_idp_user_id(
|
||||
user_id, ProviderType.GITLAB
|
||||
)
|
||||
|
||||
user_info = UserData(
|
||||
user_id=user_id, username=username, keycloak_user_id=keycloak_user_id
|
||||
)
|
||||
|
||||
if GitlabFactory.is_labeled_issue(message):
|
||||
issue_iid = payload['object_attributes']['iid']
|
||||
|
||||
logger.info(
|
||||
f'[GitLab] Creating view for labeled issue from {username} in {selected_project}#{issue_iid}'
|
||||
)
|
||||
return GitlabIssue(
|
||||
installation_id=installation_id,
|
||||
issue_number=issue_iid,
|
||||
project_id=project_id,
|
||||
full_repo_name=selected_project,
|
||||
is_public_repo=is_public_repo,
|
||||
user_info=user_info,
|
||||
raw_payload=message,
|
||||
conversation_id='',
|
||||
should_extract=True,
|
||||
send_summary_instruction=True,
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
is_mr=False,
|
||||
)
|
||||
|
||||
elif GitlabFactory.is_issue_comment(message):
|
||||
event_type = payload['event_type']
|
||||
issue_iid = payload['issue']['iid']
|
||||
object_attributes = payload['object_attributes']
|
||||
discussion_id = object_attributes['discussion_id']
|
||||
comment_body = object_attributes['note']
|
||||
logger.info(
|
||||
f'[GitLab] Creating view for issue comment from {username} in {selected_project}#{issue_iid}'
|
||||
)
|
||||
|
||||
return GitlabIssueComment(
|
||||
installation_id=installation_id,
|
||||
comment_body=comment_body,
|
||||
issue_number=issue_iid,
|
||||
discussion_id=discussion_id,
|
||||
project_id=project_id,
|
||||
confidential=GitlabFactory.determine_if_confidential(event_type),
|
||||
full_repo_name=selected_project,
|
||||
is_public_repo=is_public_repo,
|
||||
user_info=user_info,
|
||||
raw_payload=message,
|
||||
conversation_id='',
|
||||
should_extract=True,
|
||||
send_summary_instruction=True,
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
is_mr=False,
|
||||
)
|
||||
|
||||
elif GitlabFactory.is_mr_comment(message):
|
||||
event_type = payload['event_type']
|
||||
merge_request_iid = payload['merge_request']['iid']
|
||||
branch_name = payload['merge_request']['source_branch']
|
||||
object_attributes = payload['object_attributes']
|
||||
discussion_id = object_attributes['discussion_id']
|
||||
comment_body = object_attributes['note']
|
||||
logger.info(
|
||||
f'[GitLab] Creating view for merge request comment from {username} in {selected_project}#{merge_request_iid}'
|
||||
)
|
||||
|
||||
return GitlabMRComment(
|
||||
installation_id=installation_id,
|
||||
comment_body=comment_body,
|
||||
issue_number=merge_request_iid, # Using issue_number as mr_number for compatibility
|
||||
discussion_id=discussion_id,
|
||||
project_id=project_id,
|
||||
full_repo_name=selected_project,
|
||||
is_public_repo=is_public_repo,
|
||||
user_info=user_info,
|
||||
raw_payload=message,
|
||||
conversation_id='',
|
||||
should_extract=True,
|
||||
send_summary_instruction=True,
|
||||
confidential=GitlabFactory.determine_if_confidential(event_type),
|
||||
branch_name=branch_name,
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
is_mr=True,
|
||||
)
|
||||
|
||||
elif GitlabFactory.is_mr_comment(message, inline=True):
|
||||
event_type = payload['event_type']
|
||||
merge_request_iid = payload['merge_request']['iid']
|
||||
branch_name = payload['merge_request']['source_branch']
|
||||
object_attributes = payload['object_attributes']
|
||||
comment_body = object_attributes['note']
|
||||
position_info = object_attributes['position']
|
||||
discussion_id = object_attributes['discussion_id']
|
||||
file_location = position_info['new_path']
|
||||
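# Prefer the line number in the new version of the file; fall back to the old line (or 0 when unavailable)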
line_number = (
|
||||
position_info.get('new_line') or position_info.get('old_line') or 0
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'[GitLab] Creating view for inline merge request comment from {username} in {selected_project}#{merge_request_iid}'
|
||||
)
|
||||
|
||||
return GitlabInlineMRComment(
|
||||
installation_id=installation_id,
|
||||
issue_number=merge_request_iid, # Using issue_number as mr_number for compatibility
|
||||
discussion_id=discussion_id,
|
||||
project_id=project_id,
|
||||
full_repo_name=selected_project,
|
||||
is_public_repo=is_public_repo,
|
||||
user_info=user_info,
|
||||
raw_payload=message,
|
||||
conversation_id='',
|
||||
should_extract=True,
|
||||
send_summary_instruction=True,
|
||||
confidential=GitlabFactory.determine_if_confidential(event_type),
|
||||
branch_name=branch_name,
|
||||
file_location=file_location,
|
||||
line_number=line_number,
|
||||
comment_body=comment_body,
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
is_mr=True,
|
||||
)
|
||||
503
enterprise/integrations/jira/jira_manager.py
Normal file
@@ -0,0 +1,503 @@
|
||||
import hashlib
|
||||
import hmac
|
||||
from typing import Dict, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from fastapi import Request
|
||||
from integrations.jira.jira_types import JiraViewInterface
|
||||
from integrations.jira.jira_view import (
|
||||
JiraExistingConversationView,
|
||||
JiraFactory,
|
||||
JiraNewConversationView,
|
||||
)
|
||||
from integrations.manager import Manager
|
||||
from integrations.models import JobContext, Message
|
||||
from integrations.utils import (
|
||||
HOST_URL,
|
||||
OPENHANDS_RESOLVER_TEMPLATES_DIR,
|
||||
filter_potential_repos_by_user_msg,
|
||||
)
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from server.auth.saas_user_auth import get_user_auth_from_keycloak_id
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.utils.conversation_callback_utils import register_callback_processor
|
||||
from storage.jira_integration_store import JiraIntegrationStore
|
||||
from storage.jira_user import JiraUser
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
from openhands.integrations.service_types import Repository
|
||||
from openhands.server.shared import server_config
|
||||
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
|
||||
JIRA_CLOUD_API_URL = 'https://api.atlassian.com/ex/jira'
|
||||
|
||||
|
||||
class JiraManager(Manager):
|
||||
def __init__(self, token_manager: TokenManager):
|
||||
self.token_manager = token_manager
|
||||
self.integration_store = JiraIntegrationStore.get_instance()
|
||||
self.jinja_env = Environment(
|
||||
loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR + 'jira')
|
||||
)
|
||||
|
||||
async def authenticate_user(
|
||||
self, jira_user_id: str, workspace_id: int
|
||||
) -> tuple[JiraUser | None, UserAuth | None]:
|
||||
"""Authenticate Jira user and get their OpenHands user auth."""
|
||||
|
||||
# Find active Jira user by Keycloak user ID and workspace ID
|
||||
jira_user = await self.integration_store.get_active_user(
|
||||
jira_user_id, workspace_id
|
||||
)
|
||||
|
||||
if not jira_user:
|
||||
logger.warning(
|
||||
f'[Jira] No active Jira user found for {jira_user_id} in workspace {workspace_id}'
|
||||
)
|
||||
return None, None
|
||||
|
||||
saas_user_auth = await get_user_auth_from_keycloak_id(
|
||||
jira_user.keycloak_user_id
|
||||
)
|
||||
return jira_user, saas_user_auth
|
||||
|
||||
async def _get_repositories(self, user_auth: UserAuth) -> list[Repository]:
|
||||
"""Get repositories that the user has access to."""
|
||||
provider_tokens = await user_auth.get_provider_tokens()
|
||||
if provider_tokens is None:
|
||||
return []
|
||||
access_token = await user_auth.get_access_token()
|
||||
user_id = await user_auth.get_user_id()
|
||||
client = ProviderHandler(
|
||||
provider_tokens=provider_tokens,
|
||||
external_auth_token=access_token,
|
||||
external_auth_id=user_id,
|
||||
)
|
||||
repos: list[Repository] = await client.get_repositories(
|
||||
'pushed', server_config.app_mode, None, None, None, None
|
||||
)
|
||||
return repos
|
||||
|
||||
async def validate_request(
|
||||
self, request: Request
|
||||
) -> Tuple[bool, Optional[str], Optional[Dict]]:
|
||||
"""Verify Jira webhook signature."""
|
||||
signature_header = request.headers.get('x-hub-signature')
|
||||
signature = signature_header.split('=')[1] if signature_header else None
|
||||
body = await request.body()
|
||||
payload = await request.json()
|
||||
workspace_name = ''
|
||||
|
||||
if payload.get('webhookEvent') == 'comment_created':
|
||||
selfUrl = payload.get('comment', {}).get('author', {}).get('self')
|
||||
elif payload.get('webhookEvent') == 'jira:issue_updated':
|
||||
selfUrl = payload.get('user', {}).get('self')
|
||||
else:
|
||||
selfUrl = ''
|
||||
|
||||
parsedUrl = urlparse(selfUrl)
|
||||
if parsedUrl.hostname:
|
||||
workspace_name = parsedUrl.hostname
|
||||
|
||||
if not workspace_name:
|
||||
logger.warning('[Jira] No workspace name found in webhook payload')
|
||||
return False, None, None
|
||||
|
||||
if not signature:
|
||||
logger.warning('[Jira] No signature found in webhook headers')
|
||||
return False, None, None
|
||||
|
||||
workspace = await self.integration_store.get_workspace_by_name(workspace_name)
|
||||
|
||||
if not workspace:
|
||||
logger.warning('[Jira] Could not identify workspace for webhook')
|
||||
return False, None, None
|
||||
|
||||
if workspace.status != 'active':
|
||||
logger.warning(f'[Jira] Workspace {workspace.id} is not active')
|
||||
return False, None, None
|
||||
|
||||
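# Recompute the expected signature: HMAC-SHA256 over the raw request body, keyed with the workspace's decrypted webhook secret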
webhook_secret = self.token_manager.decrypt_text(workspace.webhook_secret)
|
||||
digest = hmac.new(webhook_secret.encode(), body, hashlib.sha256).hexdigest()
|
||||
|
||||
if hmac.compare_digest(signature, digest):
|
||||
logger.info('[Jira] Webhook signature verified successfully')
|
||||
return True, signature, payload
|
||||
|
||||
return False, None, None
|
||||
|
||||
def parse_webhook(self, payload: Dict) -> JobContext | None:
|
||||
event_type = payload.get('webhookEvent')
|
||||
|
||||
if event_type == 'comment_created':
|
||||
comment_data = payload.get('comment', {})
|
||||
comment = comment_data.get('body', '')
|
||||
|
||||
if '@openhands' not in comment:
|
||||
return None
|
||||
|
||||
issue_data = payload.get('issue', {})
|
||||
issue_id = issue_data.get('id')
|
||||
issue_key = issue_data.get('key')
|
||||
base_api_url = issue_data.get('self', '').split('/rest/')[0]
|
||||
|
||||
user_data = comment_data.get('author', {})
|
||||
user_email = user_data.get('emailAddress')
|
||||
display_name = user_data.get('displayName')
|
||||
account_id = user_data.get('accountId')
|
||||
elif event_type == 'jira:issue_updated':
|
||||
changelog = payload.get('changelog', {})
|
||||
items = changelog.get('items', [])
|
||||
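# Collect the label values applied in this update; the resolver only runs when the 'openhands' label is among them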
labels = [
|
||||
item.get('toString', '')
|
||||
for item in items
|
||||
if item.get('field') == 'labels' and 'toString' in item
|
||||
]
|
||||
|
||||
if 'openhands' not in labels:
|
||||
return None
|
||||
|
||||
issue_data = payload.get('issue', {})
|
||||
issue_id = issue_data.get('id')
|
||||
issue_key = issue_data.get('key')
|
||||
base_api_url = issue_data.get('self', '').split('/rest/')[0]
|
||||
|
||||
user_data = payload.get('user', {})
|
||||
user_email = user_data.get('emailAddress')
|
||||
display_name = user_data.get('displayName')
|
||||
account_id = user_data.get('accountId')
|
||||
comment = ''
|
||||
else:
|
||||
return None
|
||||
|
||||
workspace_name = ''
|
||||
|
||||
parsedUrl = urlparse(base_api_url)
|
||||
if parsedUrl.hostname:
|
||||
workspace_name = parsedUrl.hostname
|
||||
|
||||
if not all(
|
||||
[
|
||||
issue_id,
|
||||
issue_key,
|
||||
user_email,
|
||||
display_name,
|
||||
account_id,
|
||||
workspace_name,
|
||||
base_api_url,
|
||||
]
|
||||
):
|
||||
return None
|
||||
|
||||
return JobContext(
|
||||
issue_id=issue_id,
|
||||
issue_key=issue_key,
|
||||
user_msg=comment,
|
||||
user_email=user_email,
|
||||
display_name=display_name,
|
||||
platform_user_id=account_id,
|
||||
workspace_name=workspace_name,
|
||||
base_api_url=base_api_url,
|
||||
)
|
||||
|
||||
async def receive_message(self, message: Message):
|
||||
"""Process incoming Jira webhook message."""
|
||||
|
||||
payload = message.message.get('payload', {})
|
||||
job_context = self.parse_webhook(payload)
|
||||
|
||||
if not job_context:
|
||||
logger.info('[Jira] Webhook does not match trigger conditions')
|
||||
return
|
||||
|
||||
# Look up the workspace by the name parsed from the webhook payload
|
||||
workspace = await self.integration_store.get_workspace_by_name(
|
||||
job_context.workspace_name
|
||||
)
|
||||
if not workspace:
|
||||
logger.warning(
|
||||
f'[Jira] No workspace found for workspace name: {job_context.workspace_name}'
|
||||
)
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
'Your workspace is not configured with Jira integration.',
|
||||
None,
|
||||
)
|
||||
return
|
||||
|
||||
# Prevent any recursive triggers from the service account
|
||||
if job_context.user_email == workspace.svc_acc_email:
|
||||
return
|
||||
|
||||
if workspace.status != 'active':
|
||||
logger.warning(f'[Jira] Workspace {workspace.id} is not active')
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
'Jira integration is not active for your workspace.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
# Authenticate user
|
||||
jira_user, saas_user_auth = await self.authenticate_user(
|
||||
job_context.platform_user_id, workspace.id
|
||||
)
|
||||
if not jira_user or not saas_user_auth:
|
||||
logger.warning(
|
||||
f'[Jira] User authentication failed for {job_context.user_email}'
|
||||
)
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
f'User {job_context.user_email} is not authenticated or active in the Jira integration.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
# Get issue details
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
issue_title, issue_description = await self.get_issue_details(
|
||||
job_context, workspace.jira_cloud_id, workspace.svc_acc_email, api_key
|
||||
)
|
||||
job_context.issue_title = issue_title
|
||||
job_context.issue_description = issue_description
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira] Failed to get issue context: {str(e)}')
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
'Failed to retrieve issue details. Please check the issue key and try again.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# Create Jira view
|
||||
jira_view = await JiraFactory.create_jira_view_from_payload(
|
||||
job_context,
|
||||
saas_user_auth,
|
||||
jira_user,
|
||||
workspace,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira] Failed to create jira view: {str(e)}', exc_info=True)
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
'Failed to initialize conversation. Please try again.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
if not await self.is_job_requested(message, jira_view):
|
||||
return
|
||||
|
||||
await self.start_job(jira_view)
|
||||
|
||||
async def is_job_requested(
|
||||
self, message: Message, jira_view: JiraViewInterface
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a job is requested and handle repository selection.
|
||||
"""
|
||||
|
||||
if isinstance(jira_view, JiraExistingConversationView):
|
||||
return True
|
||||
|
||||
try:
|
||||
# Get user repositories
|
||||
user_repos: list[Repository] = await self._get_repositories(
|
||||
jira_view.saas_user_auth
|
||||
)
|
||||
|
||||
target_str = f'{jira_view.job_context.issue_description}\n{jira_view.job_context.user_msg}'
|
||||
|
||||
# Try to infer repository from issue description
|
||||
match, repos = filter_potential_repos_by_user_msg(target_str, user_repos)
|
||||
|
||||
if match:
|
||||
# Found exact repository match
|
||||
jira_view.selected_repo = repos[0].full_name
|
||||
logger.info(f'[Jira] Inferred repository: {repos[0].full_name}')
|
||||
return True
|
||||
else:
|
||||
# No clear match - send repository selection comment
|
||||
await self._send_repo_selection_comment(jira_view)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira] Error in is_job_requested: {str(e)}')
|
||||
return False
|
||||
|
||||
async def start_job(self, jira_view: JiraViewInterface):
|
||||
"""Start a Jira job/conversation."""
|
||||
# Import here to prevent circular import
|
||||
from server.conversation_callback_processor.jira_callback_processor import (
|
||||
JiraCallbackProcessor,
|
||||
)
|
||||
|
||||
try:
|
||||
user_info: JiraUser = jira_view.jira_user
|
||||
logger.info(
|
||||
f'[Jira] Starting job for user {user_info.keycloak_user_id} '
|
||||
f'issue {jira_view.job_context.issue_key}',
|
||||
)
|
||||
|
||||
# Create conversation
|
||||
conversation_id = await jira_view.create_or_update_conversation(
|
||||
self.jinja_env
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'[Jira] Created/Updated conversation {conversation_id} for issue {jira_view.job_context.issue_key}'
|
||||
)
|
||||
|
||||
# Register callback processor for updates
|
||||
if isinstance(jira_view, JiraNewConversationView):
|
||||
processor = JiraCallbackProcessor(
|
||||
issue_key=jira_view.job_context.issue_key,
|
||||
workspace_name=jira_view.jira_workspace.name,
|
||||
)
|
||||
|
||||
# Register the callback processor
|
||||
register_callback_processor(conversation_id, processor)
|
||||
|
||||
logger.info(
|
||||
f'[Jira] Created callback processor for conversation {conversation_id}'
|
||||
)
|
||||
|
||||
# Send initial response
|
||||
msg_info = jira_view.get_response_msg()
|
||||
|
||||
except MissingSettingsError as e:
|
||||
logger.warning(f'[Jira] Missing settings error: {str(e)}')
|
||||
msg_info = f'Please log in to [OpenHands Cloud]({HOST_URL}) again before starting a job.'
|
||||
|
||||
except LLMAuthenticationError as e:
|
||||
logger.warning(f'[Jira] LLM authentication error: {str(e)}')
|
||||
msg_info = f'Please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Jira] Unexpected error starting job: {str(e)}', exc_info=True
|
||||
)
|
||||
msg_info = 'Sorry, there was an unexpected error starting the job. Please try again.'
|
||||
|
||||
# Send response comment
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(
|
||||
jira_view.jira_workspace.svc_acc_api_key
|
||||
)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=msg_info),
|
||||
issue_key=jira_view.job_context.issue_key,
|
||||
jira_cloud_id=jira_view.jira_workspace.jira_cloud_id,
|
||||
svc_acc_email=jira_view.jira_workspace.svc_acc_email,
|
||||
svc_acc_api_key=api_key,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira] Failed to send response message: {str(e)}')
|
||||
|
||||
async def get_issue_details(
|
||||
self,
|
||||
job_context: JobContext,
|
||||
jira_cloud_id: str,
|
||||
svc_acc_email: str,
|
||||
svc_acc_api_key: str,
|
||||
) -> Tuple[str, str]:
|
||||
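# Fetch the issue via the Jira Cloud REST API v2, authenticating with the service account email and API token (basic auth)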
url = f'{JIRA_CLOUD_API_URL}/{jira_cloud_id}/rest/api/2/issue/{job_context.issue_key}'
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url, auth=(svc_acc_email, svc_acc_api_key))
|
||||
response.raise_for_status()
|
||||
issue_payload = response.json()
|
||||
|
||||
if not issue_payload:
|
||||
raise ValueError(f'Issue with key {job_context.issue_key} not found.')
|
||||
|
||||
title = issue_payload.get('fields', {}).get('summary', '')
|
||||
description = issue_payload.get('fields', {}).get('description', '')
|
||||
|
||||
if not title:
|
||||
raise ValueError(
|
||||
f'Issue with key {job_context.issue_key} does not have a title.'
|
||||
)
|
||||
|
||||
if not description:
|
||||
raise ValueError(
|
||||
f'Issue with key {job_context.issue_key} does not have a description.'
|
||||
)
|
||||
|
||||
return title, description
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
message: Message,
|
||||
issue_key: str,
|
||||
jira_cloud_id: str,
|
||||
svc_acc_email: str,
|
||||
svc_acc_api_key: str,
|
||||
):
|
||||
url = (
|
||||
f'{JIRA_CLOUD_API_URL}/{jira_cloud_id}/rest/api/2/issue/{issue_key}/comment'
|
||||
)
|
||||
data = {'body': message.message}
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
url, auth=(svc_acc_email, svc_acc_api_key), json=data
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def _send_error_comment(
|
||||
self,
|
||||
job_context: JobContext,
|
||||
error_msg: str,
|
||||
workspace: JiraWorkspace | None,
|
||||
):
|
||||
"""Send error comment to Jira issue."""
|
||||
if not workspace:
|
||||
logger.error('[Jira] Cannot send error comment - no workspace available')
|
||||
return
|
||||
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=error_msg),
|
||||
issue_key=job_context.issue_key,
|
||||
jira_cloud_id=workspace.jira_cloud_id,
|
||||
svc_acc_email=workspace.svc_acc_email,
|
||||
svc_acc_api_key=api_key,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira] Failed to send error comment: {str(e)}')
|
||||
|
||||
async def _send_repo_selection_comment(self, jira_view: JiraViewInterface):
|
||||
"""Send a comment with repository options for the user to choose."""
|
||||
try:
|
||||
comment_msg = (
|
||||
'I need to know which repository to work with. '
|
||||
'Please add it to your issue description or send a followup comment.'
|
||||
)
|
||||
|
||||
api_key = self.token_manager.decrypt_text(
|
||||
jira_view.jira_workspace.svc_acc_api_key
|
||||
)
|
||||
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=comment_msg),
|
||||
issue_key=jira_view.job_context.issue_key,
|
||||
jira_cloud_id=jira_view.jira_workspace.jira_cloud_id,
|
||||
svc_acc_email=jira_view.jira_workspace.svc_acc_email,
|
||||
svc_acc_api_key=api_key,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'[Jira] Sent repository selection comment for issue {jira_view.job_context.issue_key}'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Jira] Failed to send repository selection comment: {str(e)}'
|
||||
)
|
||||
40
enterprise/integrations/jira/jira_types.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from integrations.models import JobContext
|
||||
from jinja2 import Environment
|
||||
from storage.jira_user import JiraUser
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
|
||||
|
||||
class JiraViewInterface(ABC):
|
||||
"""Interface for Jira views that handle different types of Jira interactions."""
|
||||
|
||||
job_context: JobContext
|
||||
saas_user_auth: UserAuth
|
||||
jira_user: JiraUser
|
||||
jira_workspace: JiraWorkspace
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
@abstractmethod
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Get initial instructions for the conversation."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
|
||||
"""Create or update a conversation and return the conversation ID."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_response_msg(self) -> str:
|
||||
"""Get the response message to send back to Jira."""
|
||||
pass
|
||||
|
||||
|
||||
class StartingConvoException(Exception):
|
||||
"""Exception raised when starting a conversation fails."""
|
||||
|
||||
pass
|
||||
222
enterprise/integrations/jira/jira_view.py
Normal file
@@ -0,0 +1,222 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from integrations.jira.jira_types import JiraViewInterface, StartingConvoException
|
||||
from integrations.models import JobContext
|
||||
from integrations.utils import CONVERSATION_URL, get_final_agent_observation
|
||||
from jinja2 import Environment
|
||||
from storage.jira_conversation import JiraConversation
|
||||
from storage.jira_integration_store import JiraIntegrationStore
|
||||
from storage.jira_user import JiraUser
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.server.services.conversation_service import (
|
||||
create_new_conversation,
|
||||
setup_init_conversation_settings,
|
||||
)
|
||||
from openhands.server.shared import ConversationStoreImpl, config, conversation_manager
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
|
||||
integration_store = JiraIntegrationStore.get_instance()
|
||||
|
||||
|
||||
@dataclass
|
||||
class JiraNewConversationView(JiraViewInterface):
|
||||
job_context: JobContext
|
||||
saas_user_auth: UserAuth
|
||||
jira_user: JiraUser
|
||||
jira_workspace: JiraWorkspace
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
instructions_template = jinja_env.get_template('jira_instructions.j2')
|
||||
instructions = instructions_template.render()
|
||||
|
||||
user_msg_template = jinja_env.get_template('jira_new_conversation.j2')
|
||||
|
||||
user_msg = user_msg_template.render(
|
||||
issue_key=self.job_context.issue_key,
|
||||
issue_title=self.job_context.issue_title,
|
||||
issue_description=self.job_context.issue_description,
|
||||
user_message=self.job_context.user_msg or '',
|
||||
)
|
||||
|
||||
return instructions, user_msg
|
||||
|
||||
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
|
||||
"""Create a new Jira conversation"""
|
||||
|
||||
if not self.selected_repo:
|
||||
raise StartingConvoException('No repository selected for this conversation')
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
user_secrets = await self.saas_user_auth.get_user_secrets()
|
||||
instructions, user_msg = self._get_instructions(jinja_env)
|
||||
|
||||
try:
|
||||
agent_loop_info = await create_new_conversation(
|
||||
user_id=self.jira_user.keycloak_user_id,
|
||||
git_provider_tokens=provider_tokens,
|
||||
selected_repository=self.selected_repo,
|
||||
selected_branch=None,
|
||||
initial_user_msg=user_msg,
|
||||
conversation_instructions=instructions,
|
||||
image_urls=None,
|
||||
replay_json=None,
|
||||
conversation_trigger=ConversationTrigger.JIRA,
|
||||
custom_secrets=user_secrets.custom_secrets if user_secrets else None,
|
||||
)
|
||||
|
||||
self.conversation_id = agent_loop_info.conversation_id
|
||||
|
||||
logger.info(f'[Jira] Created conversation {self.conversation_id}')
|
||||
|
||||
# Store Jira conversation mapping
|
||||
jira_conversation = JiraConversation(
|
||||
conversation_id=self.conversation_id,
|
||||
issue_id=self.job_context.issue_id,
|
||||
issue_key=self.job_context.issue_key,
|
||||
jira_user_id=self.jira_user.id,
|
||||
)
|
||||
|
||||
await integration_store.create_conversation(jira_conversation)
|
||||
|
||||
return self.conversation_id
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Jira] Failed to create conversation: {str(e)}', exc_info=True
|
||||
)
|
||||
raise StartingConvoException(f'Failed to create conversation: {str(e)}')
|
||||
|
||||
def get_response_msg(self) -> str:
|
||||
"""Get the response message to send back to Jira"""
|
||||
conversation_link = CONVERSATION_URL.format(self.conversation_id)
|
||||
return f"I'm on it! {self.job_context.display_name} can [track my progress here|{conversation_link}]."
|
||||
|
||||
|
||||
@dataclass
|
||||
class JiraExistingConversationView(JiraViewInterface):
|
||||
job_context: JobContext
|
||||
saas_user_auth: UserAuth
|
||||
jira_user: JiraUser
|
||||
jira_workspace: JiraWorkspace
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
user_msg_template = jinja_env.get_template('jira_existing_conversation.j2')
|
||||
user_msg = user_msg_template.render(
|
||||
issue_key=self.job_context.issue_key,
|
||||
user_message=self.job_context.user_msg or '',
|
||||
issue_title=self.job_context.issue_title,
|
||||
issue_description=self.job_context.issue_description,
|
||||
)
|
||||
|
||||
return '', user_msg
|
||||
|
||||
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
|
||||
"""Update an existing Jira conversation"""
|
||||
|
||||
user_id = self.jira_user.keycloak_user_id
|
||||
|
||||
try:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, user_id
|
||||
)
|
||||
metadata = await conversation_store.get_metadata(self.conversation_id)
|
||||
if not metadata:
|
||||
raise StartingConvoException('Conversation no longer exists.')
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
# Should we raise here if there are no providers?
|
||||
providers_set = list(provider_tokens.keys()) if provider_tokens else []
|
||||
|
||||
conversation_init_data = await setup_init_conversation_settings(
|
||||
user_id, self.conversation_id, providers_set
|
||||
)
|
||||
|
||||
# Either join ongoing conversation, or restart the conversation
|
||||
agent_loop_info = await conversation_manager.maybe_start_agent_loop(
|
||||
self.conversation_id, conversation_init_data, user_id
|
||||
)
|
||||
|
||||
final_agent_observation = get_final_agent_observation(
|
||||
agent_loop_info.event_store
|
||||
)
|
||||
agent_state = (
|
||||
None
|
||||
if len(final_agent_observation) == 0
|
||||
else final_agent_observation[0].agent_state
|
||||
)
|
||||
|
||||
if not agent_state or agent_state == AgentState.LOADING:
|
||||
raise StartingConvoException('Conversation is still starting')
|
||||
|
||||
_, user_msg = self._get_instructions(jinja_env)
|
||||
user_message_event = MessageAction(content=user_msg)
|
||||
await conversation_manager.send_event_to_conversation(
|
||||
self.conversation_id, event_to_dict(user_message_event)
|
||||
)
|
||||
|
||||
return self.conversation_id
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Jira] Failed to update conversation: {str(e)}', exc_info=True
|
||||
)
|
||||
raise StartingConvoException(f'Failed to update conversation: {str(e)}')
|
||||
|
||||
def get_response_msg(self) -> str:
|
||||
"""Get the response message to send back to Jira"""
|
||||
conversation_link = CONVERSATION_URL.format(self.conversation_id)
|
||||
return f"I'm on it! {self.job_context.display_name} can [continue tracking my progress here|{conversation_link}]."
|
||||
|
||||
|
||||
class JiraFactory:
|
||||
"""Factory for creating Jira views based on message content"""
|
||||
|
||||
@staticmethod
|
||||
async def create_jira_view_from_payload(
|
||||
job_context: JobContext,
|
||||
saas_user_auth: UserAuth,
|
||||
jira_user: JiraUser,
|
||||
jira_workspace: JiraWorkspace,
|
||||
) -> JiraViewInterface:
|
||||
"""Create appropriate Jira view based on the message and user state"""
|
||||
|
||||
if not jira_user or not saas_user_auth or not jira_workspace:
|
||||
raise StartingConvoException('User not authenticated with Jira integration')
|
||||
|
||||
conversation = await integration_store.get_user_conversations_by_issue_id(
|
||||
job_context.issue_id, jira_user.id
|
||||
)
|
||||
|
||||
if conversation:
|
||||
logger.info(
|
||||
f'[Jira] Found existing conversation for issue {job_context.issue_id}'
|
||||
)
|
||||
return JiraExistingConversationView(
|
||||
job_context=job_context,
|
||||
saas_user_auth=saas_user_auth,
|
||||
jira_user=jira_user,
|
||||
jira_workspace=jira_workspace,
|
||||
selected_repo=None,
|
||||
conversation_id=conversation.conversation_id,
|
||||
)
|
||||
|
||||
return JiraNewConversationView(
|
||||
job_context=job_context,
|
||||
saas_user_auth=saas_user_auth,
|
||||
jira_user=jira_user,
|
||||
jira_workspace=jira_workspace,
|
||||
selected_repo=None, # Will be set later after repo inference
|
||||
conversation_id='', # Will be set when conversation is created
|
||||
)
|
||||
508
enterprise/integrations/jira_dc/jira_dc_manager.py
Normal file
@@ -0,0 +1,508 @@
|
||||
import hashlib
|
||||
import hmac
|
||||
from typing import Dict, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from fastapi import Request
|
||||
from integrations.jira_dc.jira_dc_types import (
|
||||
JiraDcViewInterface,
|
||||
)
|
||||
from integrations.jira_dc.jira_dc_view import (
|
||||
JiraDcExistingConversationView,
|
||||
JiraDcFactory,
|
||||
JiraDcNewConversationView,
|
||||
)
|
||||
from integrations.manager import Manager
|
||||
from integrations.models import JobContext, Message
|
||||
from integrations.utils import (
|
||||
HOST_URL,
|
||||
OPENHANDS_RESOLVER_TEMPLATES_DIR,
|
||||
filter_potential_repos_by_user_msg,
|
||||
)
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from server.auth.saas_user_auth import get_user_auth_from_keycloak_id
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.utils.conversation_callback_utils import register_callback_processor
|
||||
from storage.jira_dc_integration_store import JiraDcIntegrationStore
|
||||
from storage.jira_dc_user import JiraDcUser
|
||||
from storage.jira_dc_workspace import JiraDcWorkspace
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
from openhands.integrations.service_types import Repository
|
||||
from openhands.server.shared import server_config
|
||||
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
|
||||
|
||||
class JiraDcManager(Manager):
|
||||
def __init__(self, token_manager: TokenManager):
|
||||
self.token_manager = token_manager
|
||||
self.integration_store = JiraDcIntegrationStore.get_instance()
|
||||
self.jinja_env = Environment(
|
||||
loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR + 'jira_dc')
|
||||
)
|
||||
|
||||
async def authenticate_user(
|
||||
self, user_email: str, jira_dc_user_id: str, workspace_id: int
|
||||
) -> tuple[JiraDcUser | None, UserAuth | None]:
|
||||
"""Authenticate Jira DC user and get their OpenHands user auth."""
|
||||
|
||||
if not jira_dc_user_id or jira_dc_user_id == 'none':
|
||||
# Get Keycloak user ID from email
|
||||
keycloak_user_id = await self.token_manager.get_user_id_from_user_email(
|
||||
user_email
|
||||
)
|
||||
if not keycloak_user_id:
|
||||
logger.warning(
|
||||
f'[Jira DC] No Keycloak user found for email: {user_email}'
|
||||
)
|
||||
return None, None
|
||||
|
||||
# Find active Jira DC user by Keycloak user ID and organization
|
||||
jira_dc_user = await self.integration_store.get_active_user_by_keycloak_id_and_workspace(
|
||||
keycloak_user_id, workspace_id
|
||||
)
|
||||
else:
|
||||
jira_dc_user = await self.integration_store.get_active_user(
|
||||
jira_dc_user_id, workspace_id
|
||||
)
|
||||
|
||||
if not jira_dc_user:
|
||||
logger.warning(
|
||||
f'[Jira DC] No active Jira DC user found for {user_email} in workspace {workspace_id}'
|
||||
)
|
||||
return None, None
|
||||
|
||||
saas_user_auth = await get_user_auth_from_keycloak_id(
|
||||
jira_dc_user.keycloak_user_id
|
||||
)
|
||||
return jira_dc_user, saas_user_auth
|
||||
|
||||
async def _get_repositories(self, user_auth: UserAuth) -> list[Repository]:
|
||||
"""Get repositories that the user has access to."""
|
||||
provider_tokens = await user_auth.get_provider_tokens()
|
||||
if provider_tokens is None:
|
||||
return []
|
||||
access_token = await user_auth.get_access_token()
|
||||
user_id = await user_auth.get_user_id()
|
||||
client = ProviderHandler(
|
||||
provider_tokens=provider_tokens,
|
||||
external_auth_token=access_token,
|
||||
external_auth_id=user_id,
|
||||
)
|
||||
repos: list[Repository] = await client.get_repositories(
|
||||
'pushed', server_config.app_mode, None, None, None, None
|
||||
)
|
||||
return repos
|
||||
|
||||
async def validate_request(
|
||||
self, request: Request
|
||||
) -> Tuple[bool, Optional[str], Optional[Dict]]:
|
||||
"""Verify Jira DC webhook signature."""
|
||||
signature_header = request.headers.get('x-hub-signature')
|
||||
signature = signature_header.split('=')[1] if signature_header else None
|
||||
body = await request.body()
|
||||
payload = await request.json()
|
||||
workspace_name = ''
|
||||
|
||||
if payload.get('webhookEvent') == 'comment_created':
|
||||
selfUrl = payload.get('comment', {}).get('author', {}).get('self')
|
||||
elif payload.get('webhookEvent') == 'jira:issue_updated':
|
||||
selfUrl = payload.get('user', {}).get('self')
|
||||
else:
|
||||
selfUrl = ''
|
||||
|
||||
parsedUrl = urlparse(selfUrl)
|
||||
if parsedUrl.hostname:
|
||||
workspace_name = parsedUrl.hostname
|
||||
|
||||
if not workspace_name:
|
||||
logger.warning('[Jira DC] No workspace name found in webhook payload')
|
||||
return False, None, None
|
||||
|
||||
if not signature:
|
||||
logger.warning('[Jira DC] No signature found in webhook headers')
|
||||
return False, None, None
|
||||
|
||||
workspace = await self.integration_store.get_workspace_by_name(workspace_name)
|
||||
|
||||
if not workspace:
|
||||
logger.warning('[Jira DC] Could not identify workspace for webhook')
|
||||
return False, None, None
|
||||
|
||||
if workspace.status != 'active':
|
||||
logger.warning(f'[Jira DC] Workspace {workspace.id} is not active')
|
||||
return False, None, None
|
||||
|
||||
webhook_secret = self.token_manager.decrypt_text(workspace.webhook_secret)
|
||||
digest = hmac.new(webhook_secret.encode(), body, hashlib.sha256).hexdigest()
|
||||
|
||||
if hmac.compare_digest(signature, digest):
|
||||
logger.info('[Jira DC] Webhook signature verified successfully')
|
||||
return True, signature, payload
|
||||
|
||||
return False, None, None
|
||||
|
||||
def parse_webhook(self, payload: Dict) -> JobContext | None:
|
||||
event_type = payload.get('webhookEvent')
|
||||
|
||||
if event_type == 'comment_created':
|
||||
comment_data = payload.get('comment', {})
|
||||
comment = comment_data.get('body', '')
|
||||
|
||||
if '@openhands' not in comment:
|
||||
return None
|
||||
|
||||
issue_data = payload.get('issue', {})
|
||||
issue_id = issue_data.get('id')
|
||||
issue_key = issue_data.get('key')
|
||||
base_api_url = issue_data.get('self', '').split('/rest/')[0]
|
||||
|
||||
user_data = comment_data.get('author', {})
|
||||
user_email = user_data.get('emailAddress')
|
||||
display_name = user_data.get('displayName')
|
||||
user_key = user_data.get('key')
|
||||
elif event_type == 'jira:issue_updated':
|
||||
changelog = payload.get('changelog', {})
|
||||
items = changelog.get('items', [])
|
||||
labels = [
|
||||
item.get('toString', '')
|
||||
for item in items
|
||||
if item.get('field') == 'labels' and 'toString' in item
|
||||
]
|
||||
|
||||
if 'openhands' not in labels:
|
||||
return None
|
||||
|
||||
issue_data = payload.get('issue', {})
|
||||
issue_id = issue_data.get('id')
|
||||
issue_key = issue_data.get('key')
|
||||
base_api_url = issue_data.get('self', '').split('/rest/')[0]
|
||||
|
||||
user_data = payload.get('user', {})
|
||||
user_email = user_data.get('emailAddress')
|
||||
display_name = user_data.get('displayName')
|
||||
user_key = user_data.get('key')
|
||||
comment = ''
|
||||
else:
|
||||
return None
|
||||
|
||||
workspace_name = ''
|
||||
|
||||
parsedUrl = urlparse(base_api_url)
|
||||
if parsedUrl.hostname:
|
||||
workspace_name = parsedUrl.hostname
|
||||
|
||||
if not all(
|
||||
[
|
||||
issue_id,
|
||||
issue_key,
|
||||
user_email,
|
||||
display_name,
|
||||
user_key,
|
||||
workspace_name,
|
||||
base_api_url,
|
||||
]
|
||||
):
|
||||
return None
|
||||
|
||||
return JobContext(
|
||||
issue_id=issue_id,
|
||||
issue_key=issue_key,
|
||||
user_msg=comment,
|
||||
user_email=user_email,
|
||||
display_name=display_name,
|
||||
platform_user_id=user_key,
|
||||
workspace_name=workspace_name,
|
||||
base_api_url=base_api_url,
|
||||
)
|
||||
|
||||
async def receive_message(self, message: Message):
|
||||
"""Process incoming Jira DC webhook message."""
|
||||
|
||||
payload = message.message.get('payload', {})
|
||||
job_context = self.parse_webhook(payload)
|
||||
|
||||
if not job_context:
|
||||
logger.info('[Jira DC] Webhook does not match trigger conditions')
|
||||
return
|
||||
|
||||
workspace = await self.integration_store.get_workspace_by_name(
|
||||
job_context.workspace_name
|
||||
)
|
||||
if not workspace:
|
||||
logger.warning(
|
||||
f'[Jira DC] No workspace found for workspace name: {job_context.workspace_name}'
|
||||
)
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
'Your workspace is not configured with Jira DC integration.',
|
||||
None,
|
||||
)
|
||||
return
|
||||
|
||||
# Prevent any recursive triggers from the service account
|
||||
if job_context.user_email == workspace.svc_acc_email:
|
||||
return
|
||||
|
||||
if workspace.status != 'active':
|
||||
logger.warning(f'[Jira DC] Workspace {workspace.id} is not active')
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
'Jira DC integration is not active for your workspace.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
# Authenticate user
|
||||
jira_dc_user, saas_user_auth = await self.authenticate_user(
|
||||
job_context.user_email, job_context.platform_user_id, workspace.id
|
||||
)
|
||||
if not jira_dc_user or not saas_user_auth:
|
||||
logger.warning(
|
||||
f'[Jira DC] User authentication failed for {job_context.user_email}'
|
||||
)
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
f'User {job_context.user_email} is not authenticated or active in the Jira DC integration.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
# Get issue details
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
issue_title, issue_description = await self.get_issue_details(
|
||||
job_context, api_key
|
||||
)
|
||||
job_context.issue_title = issue_title
|
||||
job_context.issue_description = issue_description
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira DC] Failed to get issue context: {str(e)}')
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
'Failed to retrieve issue details. Please check the issue key and try again.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# Create Jira DC view
|
||||
jira_dc_view = await JiraDcFactory.create_jira_dc_view_from_payload(
|
||||
job_context,
|
||||
saas_user_auth,
|
||||
jira_dc_user,
|
||||
workspace,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Jira DC] Failed to create jira dc view: {str(e)}', exc_info=True
|
||||
)
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
'Failed to initialize conversation. Please try again.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
if not await self.is_job_requested(message, jira_dc_view):
|
||||
return
|
||||
|
||||
await self.start_job(jira_dc_view)
|
||||
|
||||
async def is_job_requested(
|
||||
self, message: Message, jira_dc_view: JiraDcViewInterface
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a job is requested and handle repository selection.
|
||||
"""
|
||||
|
||||
if isinstance(jira_dc_view, JiraDcExistingConversationView):
|
||||
return True
|
||||
|
||||
try:
|
||||
# Get user repositories
|
||||
user_repos: list[Repository] = await self._get_repositories(
|
||||
jira_dc_view.saas_user_auth
|
||||
)
|
||||
|
||||
target_str = f'{jira_dc_view.job_context.issue_description}\n{jira_dc_view.job_context.user_msg}'
|
||||
|
||||
# Try to infer repository from issue description
|
||||
match, repos = filter_potential_repos_by_user_msg(target_str, user_repos)
|
||||
|
||||
if match:
|
||||
# Found exact repository match
|
||||
jira_dc_view.selected_repo = repos[0].full_name
|
||||
logger.info(f'[Jira DC] Inferred repository: {repos[0].full_name}')
|
||||
return True
|
||||
else:
|
||||
# No clear match - send repository selection comment
|
||||
await self._send_repo_selection_comment(jira_dc_view)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira DC] Error in is_job_requested: {str(e)}')
|
||||
return False
|
||||
|
||||
async def start_job(self, jira_dc_view: JiraDcViewInterface):
|
||||
"""Start a Jira DC job/conversation."""
|
||||
# Import here to prevent circular import
|
||||
from server.conversation_callback_processor.jira_dc_callback_processor import (
|
||||
JiraDcCallbackProcessor,
|
||||
)
|
||||
|
||||
try:
|
||||
user_info: JiraDcUser = jira_dc_view.jira_dc_user
|
||||
logger.info(
|
||||
f'[Jira DC] Starting job for user {user_info.keycloak_user_id} '
|
||||
f'issue {jira_dc_view.job_context.issue_key}',
|
||||
)
|
||||
|
||||
# Create conversation
|
||||
conversation_id = await jira_dc_view.create_or_update_conversation(
|
||||
self.jinja_env
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'[Jira DC] Created/Updated conversation {conversation_id} for issue {jira_dc_view.job_context.issue_key}'
|
||||
)
|
||||
|
||||
if isinstance(jira_dc_view, JiraDcNewConversationView):
|
||||
# Register callback processor for updates
|
||||
processor = JiraDcCallbackProcessor(
|
||||
issue_key=jira_dc_view.job_context.issue_key,
|
||||
workspace_name=jira_dc_view.jira_dc_workspace.name,
|
||||
base_api_url=jira_dc_view.job_context.base_api_url,
|
||||
)
|
||||
|
||||
# Register the callback processor
|
||||
register_callback_processor(conversation_id, processor)
|
||||
|
||||
logger.info(
|
||||
f'[Jira DC] Created callback processor for conversation {conversation_id}'
|
||||
)
|
||||
|
||||
# Send initial response
|
||||
msg_info = jira_dc_view.get_response_msg()
|
||||
|
||||
except MissingSettingsError as e:
|
||||
logger.warning(f'[Jira DC] Missing settings error: {str(e)}')
|
||||
msg_info = f'Please log in to [OpenHands Cloud]({HOST_URL}) again before starting a job.'
|
||||
|
||||
except LLMAuthenticationError as e:
|
||||
logger.warning(f'[Jira DC] LLM authentication error: {str(e)}')
|
||||
msg_info = f'Please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Jira DC] Unexpected error starting job: {str(e)}', exc_info=True
|
||||
)
|
||||
msg_info = 'Sorry, there was an unexpected error starting the job. Please try again.'
|
||||
|
||||
# Send response comment
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(
|
||||
jira_dc_view.jira_dc_workspace.svc_acc_api_key
|
||||
)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=msg_info),
|
||||
issue_key=jira_dc_view.job_context.issue_key,
|
||||
base_api_url=jira_dc_view.job_context.base_api_url,
|
||||
svc_acc_api_key=api_key,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira DC] Failed to send response message: {str(e)}')
|
||||
|
||||
async def get_issue_details(
|
||||
self, job_context: JobContext, svc_acc_api_key: str
|
||||
) -> Tuple[str, str]:
|
||||
"""Get issue details from Jira DC API."""
|
||||
url = f'{job_context.base_api_url}/rest/api/2/issue/{job_context.issue_key}'
|
||||
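# Jira Data Center authenticates with a bearer token (the service account's personal access token) rather than basic auth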
headers = {'Authorization': f'Bearer {svc_acc_api_key}'}
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
issue_payload = response.json()
|
||||
|
||||
if not issue_payload:
|
||||
raise ValueError(f'Issue with key {job_context.issue_key} not found.')
|
||||
|
||||
title = issue_payload.get('fields', {}).get('summary', '')
|
||||
description = issue_payload.get('fields', {}).get('description', '')
|
||||
|
||||
if not title:
|
||||
raise ValueError(
|
||||
f'Issue with key {job_context.issue_key} does not have a title.'
|
||||
)
|
||||
|
||||
if not description:
|
||||
raise ValueError(
|
||||
f'Issue with key {job_context.issue_key} does not have a description.'
|
||||
)
|
||||
|
||||
return title, description
|
||||
|
||||
async def send_message(
|
||||
self, message: Message, issue_key: str, base_api_url: str, svc_acc_api_key: str
|
||||
):
|
||||
"""Send message/comment to Jira DC issue."""
|
||||
url = f'{base_api_url}/rest/api/2/issue/{issue_key}/comment'
|
||||
headers = {'Authorization': f'Bearer {svc_acc_api_key}'}
|
||||
data = {'body': message.message}
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def _send_error_comment(
|
||||
self,
|
||||
job_context: JobContext,
|
||||
error_msg: str,
|
||||
workspace: JiraDcWorkspace | None,
|
||||
):
|
||||
"""Send error comment to Jira DC issue."""
|
||||
if not workspace:
|
||||
logger.error('[Jira DC] Cannot send error comment - no workspace available')
|
||||
return
|
||||
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=error_msg),
|
||||
issue_key=job_context.issue_key,
|
||||
base_api_url=job_context.base_api_url,
|
||||
svc_acc_api_key=api_key,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira DC] Failed to send error comment: {str(e)}')
|
||||
|
||||
async def _send_repo_selection_comment(self, jira_dc_view: JiraDcViewInterface):
|
||||
"""Send a comment with repository options for the user to choose."""
|
||||
try:
|
||||
comment_msg = (
|
||||
'I need to know which repository to work with. '
|
||||
'Please add it to your issue description or send a followup comment.'
|
||||
)
|
||||
|
||||
api_key = self.token_manager.decrypt_text(
|
||||
jira_dc_view.jira_dc_workspace.svc_acc_api_key
|
||||
)
|
||||
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=comment_msg),
|
||||
issue_key=jira_dc_view.job_context.issue_key,
|
||||
base_api_url=jira_dc_view.job_context.base_api_url,
|
||||
svc_acc_api_key=api_key,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'[Jira DC] Sent repository selection comment for issue {jira_dc_view.job_context.issue_key}'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Jira DC] Failed to send repository selection comment: {str(e)}'
|
||||
)
|
||||
40
enterprise/integrations/jira_dc/jira_dc_types.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from integrations.models import JobContext
|
||||
from jinja2 import Environment
|
||||
from storage.jira_dc_user import JiraDcUser
|
||||
from storage.jira_dc_workspace import JiraDcWorkspace
|
||||
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
|
||||
|
||||
class JiraDcViewInterface(ABC):
|
||||
"""Interface for Jira DC views that handle different types of Jira DC interactions."""
|
||||
|
||||
job_context: JobContext
|
||||
saas_user_auth: UserAuth
|
||||
jira_dc_user: JiraDcUser
|
||||
jira_dc_workspace: JiraDcWorkspace
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
@abstractmethod
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Get initial instructions for the conversation."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
|
||||
"""Create or update a conversation and return the conversation ID."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_response_msg(self) -> str:
|
||||
"""Get the response message to send back to Jira DC."""
|
||||
pass
|
||||
|
||||
|
||||
class StartingConvoException(Exception):
|
||||
"""Exception raised when starting a conversation fails."""
|
||||
|
||||
pass
|
||||
223
enterprise/integrations/jira_dc/jira_dc_view.py
Normal file
@@ -0,0 +1,223 @@
|
||||
from dataclasses import dataclass

from integrations.jira_dc.jira_dc_types import (
    JiraDcViewInterface,
    StartingConvoException,
)
from integrations.models import JobContext
from integrations.utils import CONVERSATION_URL, get_final_agent_observation
from jinja2 import Environment
from storage.jira_dc_conversation import JiraDcConversation
from storage.jira_dc_integration_store import JiraDcIntegrationStore
from storage.jira_dc_user import JiraDcUser
from storage.jira_dc_workspace import JiraDcWorkspace

from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState
from openhands.events.action import MessageAction
from openhands.events.serialization.event import event_to_dict
from openhands.server.services.conversation_service import (
    create_new_conversation,
    setup_init_conversation_settings,
)
from openhands.server.shared import ConversationStoreImpl, config, conversation_manager
from openhands.server.user_auth.user_auth import UserAuth
from openhands.storage.data_models.conversation_metadata import ConversationTrigger

integration_store = JiraDcIntegrationStore.get_instance()


@dataclass
class JiraDcNewConversationView(JiraDcViewInterface):
    job_context: JobContext
    saas_user_auth: UserAuth
    jira_dc_user: JiraDcUser
    jira_dc_workspace: JiraDcWorkspace
    selected_repo: str | None
    conversation_id: str

    def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        """Instructions passed when the conversation is first initialized"""

        instructions_template = jinja_env.get_template('jira_dc_instructions.j2')
        instructions = instructions_template.render()

        user_msg_template = jinja_env.get_template('jira_dc_new_conversation.j2')

        user_msg = user_msg_template.render(
            issue_key=self.job_context.issue_key,
            issue_title=self.job_context.issue_title,
            issue_description=self.job_context.issue_description,
            user_message=self.job_context.user_msg or '',
        )

        return instructions, user_msg

    async def create_or_update_conversation(self, jinja_env: Environment) -> str:
        """Create a new Jira DC conversation"""

        if not self.selected_repo:
            raise StartingConvoException('No repository selected for this conversation')

        provider_tokens = await self.saas_user_auth.get_provider_tokens()
        user_secrets = await self.saas_user_auth.get_user_secrets()
        instructions, user_msg = self._get_instructions(jinja_env)

        try:
            agent_loop_info = await create_new_conversation(
                user_id=self.jira_dc_user.keycloak_user_id,
                git_provider_tokens=provider_tokens,
                selected_repository=self.selected_repo,
                selected_branch=None,
                initial_user_msg=user_msg,
                conversation_instructions=instructions,
                image_urls=None,
                replay_json=None,
                conversation_trigger=ConversationTrigger.JIRA_DC,
                custom_secrets=user_secrets.custom_secrets if user_secrets else None,
            )

            self.conversation_id = agent_loop_info.conversation_id

            logger.info(f'[Jira DC] Created conversation {self.conversation_id}')

            # Store Jira DC conversation mapping
            jira_dc_conversation = JiraDcConversation(
                conversation_id=self.conversation_id,
                issue_id=self.job_context.issue_id,
                issue_key=self.job_context.issue_key,
                jira_dc_user_id=self.jira_dc_user.id,
            )

            await integration_store.create_conversation(jira_dc_conversation)

            return self.conversation_id
        except Exception as e:
            logger.error(
                f'[Jira DC] Failed to create conversation: {str(e)}', exc_info=True
            )
            raise StartingConvoException(f'Failed to create conversation: {str(e)}')

    def get_response_msg(self) -> str:
        """Get the response message to send back to Jira DC"""
        conversation_link = CONVERSATION_URL.format(self.conversation_id)
        return f"I'm on it! {self.job_context.display_name} can [track my progress here|{conversation_link}]."


@dataclass
class JiraDcExistingConversationView(JiraDcViewInterface):
    job_context: JobContext
    saas_user_auth: UserAuth
    jira_dc_user: JiraDcUser
    jira_dc_workspace: JiraDcWorkspace
    selected_repo: str | None
    conversation_id: str

    def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        """Instructions passed when the conversation is first initialized"""

        user_msg_template = jinja_env.get_template('jira_dc_existing_conversation.j2')
        user_msg = user_msg_template.render(
            issue_key=self.job_context.issue_key,
            user_message=self.job_context.user_msg or '',
            issue_title=self.job_context.issue_title,
            issue_description=self.job_context.issue_description,
        )

        return '', user_msg

    async def create_or_update_conversation(self, jinja_env: Environment) -> str:
        """Update an existing Jira DC conversation"""

        user_id = self.jira_dc_user.keycloak_user_id

        try:
            conversation_store = await ConversationStoreImpl.get_instance(
                config, user_id
            )
            metadata = await conversation_store.get_metadata(self.conversation_id)
            if not metadata:
                raise StartingConvoException('Conversation no longer exists.')

            provider_tokens = await self.saas_user_auth.get_provider_tokens()
            if provider_tokens is None:
                raise ValueError('Could not load provider tokens')
            providers_set = list(provider_tokens.keys())

            conversation_init_data = await setup_init_conversation_settings(
                user_id, self.conversation_id, providers_set
            )

            # Either join the ongoing conversation, or restart the conversation
            agent_loop_info = await conversation_manager.maybe_start_agent_loop(
                self.conversation_id, conversation_init_data, user_id
            )

            final_agent_observation = get_final_agent_observation(
                agent_loop_info.event_store
            )
            agent_state = (
                None
                if len(final_agent_observation) == 0
                else final_agent_observation[0].agent_state
            )

            if not agent_state or agent_state == AgentState.LOADING:
                raise StartingConvoException('Conversation is still starting')

            _, user_msg = self._get_instructions(jinja_env)
            user_message_event = MessageAction(content=user_msg)
            await conversation_manager.send_event_to_conversation(
                self.conversation_id, event_to_dict(user_message_event)
            )

            return self.conversation_id
        except Exception as e:
            logger.error(
                f'[Jira DC] Failed to create conversation: {str(e)}', exc_info=True
            )
            raise StartingConvoException(f'Failed to create conversation: {str(e)}')

    def get_response_msg(self) -> str:
        """Get the response message to send back to Jira DC"""
        conversation_link = CONVERSATION_URL.format(self.conversation_id)
        return f"I'm on it! {self.job_context.display_name} can [continue tracking my progress here|{conversation_link}]."


class JiraDcFactory:
    """Factory class for creating Jira DC views based on message type."""

    @staticmethod
    async def create_jira_dc_view_from_payload(
        job_context: JobContext,
        saas_user_auth: UserAuth,
        jira_dc_user: JiraDcUser,
        jira_dc_workspace: JiraDcWorkspace,
    ) -> JiraDcViewInterface:
        """Create the appropriate Jira DC view based on the payload."""

        if not jira_dc_user or not saas_user_auth or not jira_dc_workspace:
            raise StartingConvoException('User not authenticated with Jira integration')

        conversation = await integration_store.get_user_conversations_by_issue_id(
            job_context.issue_id, jira_dc_user.id
        )

        if conversation:
            return JiraDcExistingConversationView(
                job_context=job_context,
                saas_user_auth=saas_user_auth,
                jira_dc_user=jira_dc_user,
                jira_dc_workspace=jira_dc_workspace,
                selected_repo=None,
                conversation_id=conversation.conversation_id,
            )

        return JiraDcNewConversationView(
            job_context=job_context,
            saas_user_auth=saas_user_auth,
            jira_dc_user=jira_dc_user,
            jira_dc_workspace=jira_dc_workspace,
            selected_repo=None,  # Will be set later after repo inference
            conversation_id='',  # Will be set when conversation is created
        )
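For orientation, here is a minimal sketch (not part of the diff) of how a manager is expected to drive the factory and view classes above; the handler name and the way jinja_env is passed in are assumptions for illustration only.

# Illustrative sketch only; assumes the imports already present in jira_dc_view.py.
async def handle_jira_dc_job(job_context, saas_user_auth, jira_dc_user, jira_dc_workspace, jinja_env):
    view = await JiraDcFactory.create_jira_dc_view_from_payload(
        job_context, saas_user_auth, jira_dc_user, jira_dc_workspace
    )
    try:
        conversation_id = await view.create_or_update_conversation(jinja_env)
    except StartingConvoException as exc:
        return str(exc)  # surfaced back to the issue as an error comment
    return view.get_response_msg()  # e.g. "I'm on it! ... [track my progress here|...]"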
522 enterprise/integrations/linear/linear_manager.py (new file)
@@ -0,0 +1,522 @@
import hashlib
import hmac
from typing import Dict, Optional, Tuple

import httpx
from fastapi import Request
from integrations.linear.linear_types import LinearViewInterface
from integrations.linear.linear_view import (
    LinearExistingConversationView,
    LinearFactory,
    LinearNewConversationView,
)
from integrations.manager import Manager
from integrations.models import JobContext, Message
from integrations.utils import (
    HOST_URL,
    OPENHANDS_RESOLVER_TEMPLATES_DIR,
    filter_potential_repos_by_user_msg,
)
from jinja2 import Environment, FileSystemLoader
from server.auth.saas_user_auth import get_user_auth_from_keycloak_id
from server.auth.token_manager import TokenManager
from server.utils.conversation_callback_utils import register_callback_processor
from storage.linear_integration_store import LinearIntegrationStore
from storage.linear_user import LinearUser
from storage.linear_workspace import LinearWorkspace

from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderHandler
from openhands.integrations.service_types import Repository
from openhands.server.shared import server_config
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
from openhands.server.user_auth.user_auth import UserAuth


class LinearManager(Manager):
    def __init__(self, token_manager: TokenManager):
        self.token_manager = token_manager
        self.integration_store = LinearIntegrationStore.get_instance()
        self.api_url = 'https://api.linear.app/graphql'
        self.jinja_env = Environment(
            loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR + 'linear')
        )

    async def authenticate_user(
        self, linear_user_id: str, workspace_id: int
    ) -> tuple[LinearUser | None, UserAuth | None]:
        """Authenticate the Linear user and get their OpenHands user auth."""

        # Find active Linear user by Linear user ID and workspace ID
        linear_user = await self.integration_store.get_active_user(
            linear_user_id, workspace_id
        )

        if not linear_user:
            logger.warning(
                f'[Linear] No active Linear user found for {linear_user_id} in workspace {workspace_id}'
            )
            return None, None

        saas_user_auth = await get_user_auth_from_keycloak_id(
            linear_user.keycloak_user_id
        )
        return linear_user, saas_user_auth

    async def _get_repositories(self, user_auth: UserAuth) -> list[Repository]:
        """Get repositories that the user has access to."""
        provider_tokens = await user_auth.get_provider_tokens()
        if provider_tokens is None:
            return []
        access_token = await user_auth.get_access_token()
        user_id = await user_auth.get_user_id()
        client = ProviderHandler(
            provider_tokens=provider_tokens,
            external_auth_token=access_token,
            external_auth_id=user_id,
        )
        repos: list[Repository] = await client.get_repositories(
            'pushed', server_config.app_mode, None, None, None, None
        )
        return repos

    async def validate_request(
        self, request: Request
    ) -> Tuple[bool, Optional[str], Optional[Dict]]:
        """Verify the Linear webhook signature."""
        signature = request.headers.get('linear-signature')
        body = await request.body()
        payload = await request.json()
        actor_url = payload.get('actor', {}).get('url', '')
        workspace_name = ''

        # Extract workspace name from actor URL
        # Format: https://linear.app/{workspace}/profiles/{user}
        if actor_url.startswith('https://linear.app/'):
            url_parts = actor_url.split('/')
            if len(url_parts) >= 4:
                workspace_name = url_parts[3]  # Extract workspace name
            else:
                logger.warning(f'[Linear] Invalid actor URL format: {actor_url}')
                return False, None, None
        else:
            logger.warning(
                f'[Linear] Actor URL does not match expected format: {actor_url}'
            )
            return False, None, None

        if not workspace_name:
            logger.warning('[Linear] No workspace name found in webhook payload')
            return False, None, None

        if not signature:
            logger.warning('[Linear] No signature found in webhook headers')
            return False, None, None

        workspace = await self.integration_store.get_workspace_by_name(workspace_name)

        if not workspace:
            logger.warning('[Linear] Could not identify workspace for webhook')
            return False, None, None

        if workspace.status != 'active':
            logger.warning(f'[Linear] Workspace {workspace.id} is not active')
            return False, None, None

        webhook_secret = self.token_manager.decrypt_text(workspace.webhook_secret)
        digest = hmac.new(webhook_secret.encode(), body, hashlib.sha256).hexdigest()

        if hmac.compare_digest(signature, digest):
            logger.info('[Linear] Webhook signature verified successfully')
            return True, signature, payload

        return False, None, None

    def parse_webhook(self, payload: Dict) -> JobContext | None:
        action = payload.get('action')
        type = payload.get('type')

        if action == 'create' and type == 'Comment':
            data = payload.get('data', {})
            comment = data.get('body', '')

            if '@openhands' not in comment:
                return None

            issue_data = data.get('issue', {})
            issue_id = issue_data.get('id', '')
            issue_key = issue_data.get('identifier', '')
        elif action == 'update' and type == 'Issue':
            data = payload.get('data', {})
            labels = data.get('labels', [])

            has_openhands_label = False
            label_id = ''
            for label in labels:
                if label.get('name') == 'openhands':
                    label_id = label.get('id', '')
                    has_openhands_label = True
                    break

            if not has_openhands_label and not label_id:
                return None

            labelIdChanges = data.get('updatedFrom', {}).get('labelIds', [])

            if labelIdChanges and label_id in labelIdChanges:
                return None  # Label was added previously, ignore this webhook

            issue_id = data.get('id', '')
            issue_key = data.get('identifier', '')
            comment = ''

        else:
            return None

        actor = payload.get('actor', {})
        display_name = actor.get('name', '')
        user_email = actor.get('email', '')
        actor_url = actor.get('url', '')
        actor_id = actor.get('id', '')
        workspace_name = ''

        if actor_url.startswith('https://linear.app/'):
            url_parts = actor_url.split('/')
            if len(url_parts) >= 4:
                workspace_name = url_parts[3]  # Extract workspace name
            else:
                logger.warning(f'[Linear] Invalid actor URL format: {actor_url}')
                return None
        else:
            logger.warning(
                f'[Linear] Actor URL does not match expected format: {actor_url}'
            )
            return None

        if not all(
            [issue_id, issue_key, display_name, user_email, actor_id, workspace_name]
        ):
            logger.warning('[Linear] Missing required fields in webhook payload')
            return None

        return JobContext(
            issue_id=issue_id,
            issue_key=issue_key,
            user_msg=comment,
            user_email=user_email,
            platform_user_id=actor_id,
            workspace_name=workspace_name,
            display_name=display_name,
        )

    async def receive_message(self, message: Message):
        """Process an incoming Linear webhook message."""
        payload = message.message.get('payload', {})
        job_context = self.parse_webhook(payload)

        if not job_context:
            logger.info('[Linear] Webhook does not match trigger conditions')
            return

        # Get workspace by workspace name
        workspace = await self.integration_store.get_workspace_by_name(
            job_context.workspace_name
        )
        if not workspace:
            logger.warning(
                f'[Linear] No workspace found for name: {job_context.workspace_name}'
            )
            await self._send_error_comment(
                job_context.issue_id,
                'Your workspace is not configured with Linear integration.',
                None,
            )
            return

        # Prevent any recursive triggers from the service account
        if job_context.user_email == workspace.svc_acc_email:
            return

        if workspace.status != 'active':
            logger.warning(f'[Linear] Workspace {workspace.id} is not active')
            await self._send_error_comment(
                job_context.issue_id,
                'Linear integration is not active for your workspace.',
                workspace,
            )
            return

        # Authenticate user
        linear_user, saas_user_auth = await self.authenticate_user(
            job_context.platform_user_id, workspace.id
        )
        if not linear_user or not saas_user_auth:
            logger.warning(
                f'[Linear] User authentication failed for {job_context.user_email}'
            )
            await self._send_error_comment(
                job_context.issue_id,
                f'User {job_context.user_email} is not authenticated or active in the Linear integration.',
                workspace,
            )
            return

        # Get issue details
        try:
            api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
            issue_title, issue_description = await self.get_issue_details(
                job_context.issue_id, api_key
            )
            job_context.issue_title = issue_title
            job_context.issue_description = issue_description
        except Exception as e:
            logger.error(f'[Linear] Failed to get issue context: {str(e)}')
            await self._send_error_comment(
                job_context.issue_id,
                'Failed to retrieve issue details. Please check the issue ID and try again.',
                workspace,
            )
            return

        try:
            # Create Linear view
            linear_view = await LinearFactory.create_linear_view_from_payload(
                job_context,
                saas_user_auth,
                linear_user,
                workspace,
            )
        except Exception as e:
            logger.error(
                f'[Linear] Failed to create linear view: {str(e)}', exc_info=True
            )
            await self._send_error_comment(
                job_context.issue_id,
                'Failed to initialize conversation. Please try again.',
                workspace,
            )
            return

        if not await self.is_job_requested(message, linear_view):
            return

        await self.start_job(linear_view)

    async def is_job_requested(
        self, message: Message, linear_view: LinearViewInterface
    ) -> bool:
        """
        Check if a job is requested and handle repository selection.
        """

        if isinstance(linear_view, LinearExistingConversationView):
            return True

        try:
            # Get user repositories
            user_repos: list[Repository] = await self._get_repositories(
                linear_view.saas_user_auth
            )

            target_str = f'{linear_view.job_context.issue_description}\n{linear_view.job_context.user_msg}'

            # Try to infer repository from issue description
            match, repos = filter_potential_repos_by_user_msg(target_str, user_repos)

            if match:
                # Found exact repository match
                linear_view.selected_repo = repos[0].full_name
                logger.info(f'[Linear] Inferred repository: {repos[0].full_name}')
                return True
            else:
                # No clear match - send repository selection comment
                await self._send_repo_selection_comment(linear_view)
                return False

        except Exception as e:
            logger.error(f'[Linear] Error in is_job_requested: {str(e)}')
            return False

    async def start_job(self, linear_view: LinearViewInterface):
        """Start a Linear job/conversation."""
        # Import here to prevent circular import
        from server.conversation_callback_processor.linear_callback_processor import (
            LinearCallbackProcessor,
        )

        try:
            user_info: LinearUser = linear_view.linear_user
            logger.info(
                f'[Linear] Starting job for user {user_info.keycloak_user_id} '
                f'issue {linear_view.job_context.issue_key}',
            )

            # Create conversation
            conversation_id = await linear_view.create_or_update_conversation(
                self.jinja_env
            )

            logger.info(
                f'[Linear] Created/Updated conversation {conversation_id} for issue {linear_view.job_context.issue_key}'
            )

            if isinstance(linear_view, LinearNewConversationView):
                # Register callback processor for updates
                processor = LinearCallbackProcessor(
                    issue_id=linear_view.job_context.issue_id,
                    issue_key=linear_view.job_context.issue_key,
                    workspace_name=linear_view.linear_workspace.name,
                )

                # Register the callback processor
                register_callback_processor(conversation_id, processor)

                logger.info(
                    f'[Linear] Created callback processor for conversation {conversation_id}'
                )

            # Send initial response
            msg_info = linear_view.get_response_msg()

        except MissingSettingsError as e:
            logger.warning(f'[Linear] Missing settings error: {str(e)}')
            msg_info = f'Please re-login into [OpenHands Cloud]({HOST_URL}) before starting a job.'

        except LLMAuthenticationError as e:
            logger.warning(f'[Linear] LLM authentication error: {str(e)}')
            msg_info = f'Please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'

        except Exception as e:
            logger.error(
                f'[Linear] Unexpected error starting job: {str(e)}', exc_info=True
            )
            msg_info = 'Sorry, there was an unexpected error starting the job. Please try again.'

        # Send response comment
        try:
            api_key = self.token_manager.decrypt_text(
                linear_view.linear_workspace.svc_acc_api_key
            )
            await self.send_message(
                self.create_outgoing_message(msg=msg_info),
                linear_view.job_context.issue_id,
                api_key,
            )
        except Exception as e:
            logger.error(f'[Linear] Failed to send response message: {str(e)}')

    async def _query_api(self, query: str, variables: Dict, api_key: str) -> Dict:
        """Query the Linear GraphQL API."""
        headers = {'Authorization': api_key}
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.api_url,
                headers=headers,
                json={'query': query, 'variables': variables},
            )
            response.raise_for_status()
            return response.json()

    async def get_issue_details(self, issue_id: str, api_key: str) -> Tuple[str, str]:
        """Get issue details from the Linear API."""
        query = """
        query Issue($issueId: String!) {
          issue(id: $issueId) {
            id
            identifier
            title
            description
            syncedWith {
              metadata {
                ... on ExternalEntityInfoGithubMetadata {
                  owner
                  repo
                }
              }
            }
          }
        }
        """
        issue_payload = await self._query_api(query, {'issueId': issue_id}, api_key)

        if not issue_payload:
            raise ValueError(f'Issue with ID {issue_id} not found.')

        issue_data = issue_payload.get('data', {}).get('issue', {})
        title = issue_data.get('title', '')
        description = issue_data.get('description', '')
        synced_with = issue_data.get('syncedWith', [])
        owner = ''
        repo = ''
        if synced_with:
            owner = synced_with[0].get('metadata', {}).get('owner', '')
            repo = synced_with[0].get('metadata', {}).get('repo', '')

        if not title:
            raise ValueError(f'Issue with ID {issue_id} does not have a title.')

        if not description:
            raise ValueError(f'Issue with ID {issue_id} does not have a description.')

        if owner and repo:
            description += f'\n\nGit Repo: {owner}/{repo}'

        return title, description

    async def send_message(self, message: Message, issue_id: str, api_key: str):
        """Send a message/comment to a Linear issue."""
        query = """
        mutation CommentCreate($input: CommentCreateInput!) {
          commentCreate(input: $input) {
            success
            comment {
              id
            }
          }
        }
        """
        variables = {'input': {'issueId': issue_id, 'body': message.message}}
        return await self._query_api(query, variables, api_key)

    async def _send_error_comment(
        self, issue_id: str, error_msg: str, workspace: LinearWorkspace | None
    ):
        """Send an error comment to a Linear issue."""
        if not workspace:
            logger.error('[Linear] Cannot send error comment - no workspace available')
            return

        try:
            api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
            await self.send_message(
                self.create_outgoing_message(msg=error_msg), issue_id, api_key
            )
        except Exception as e:
            logger.error(f'[Linear] Failed to send error comment: {str(e)}')

    async def _send_repo_selection_comment(self, linear_view: LinearViewInterface):
        """Send a comment with repository options for the user to choose."""
        try:
            comment_msg = (
                'I need to know which repository to work with. '
                'Please add it to your issue description or send a followup comment.'
            )

            api_key = self.token_manager.decrypt_text(
                linear_view.linear_workspace.svc_acc_api_key
            )

            await self.send_message(
                self.create_outgoing_message(msg=comment_msg),
                linear_view.job_context.issue_id,
                api_key,
            )

            logger.info(
                f'[Linear] Sent repository selection comment for issue {linear_view.job_context.issue_key}'
            )

        except Exception as e:
            logger.error(
                f'[Linear] Failed to send repository selection comment: {str(e)}'
            )
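The webhook check in validate_request above is a plain HMAC-SHA256 over the raw request body, hex-encoded and compared against the `linear-signature` header. A minimal sketch of how the matching signature can be produced, for example in a test; the secret value here is an assumption:

import hashlib
import hmac

def sign_linear_webhook(body: bytes, webhook_secret: str) -> str:
    # Same digest that validate_request() compares via hmac.compare_digest().
    return hmac.new(webhook_secret.encode(), body, hashlib.sha256).hexdigest()

# headers = {'linear-signature': sign_linear_webhook(raw_body, 'test-secret')}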
40 enterprise/integrations/linear/linear_types.py (new file)
@@ -0,0 +1,40 @@
from abc import ABC, abstractmethod

from integrations.models import JobContext
from jinja2 import Environment
from storage.linear_user import LinearUser
from storage.linear_workspace import LinearWorkspace

from openhands.server.user_auth.user_auth import UserAuth


class LinearViewInterface(ABC):
    """Interface for Linear views that handle different types of Linear interactions."""

    job_context: JobContext
    saas_user_auth: UserAuth
    linear_user: LinearUser
    linear_workspace: LinearWorkspace
    selected_repo: str | None
    conversation_id: str

    @abstractmethod
    def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        """Get initial instructions for the conversation."""
        pass

    @abstractmethod
    async def create_or_update_conversation(self, jinja_env: Environment) -> str:
        """Create or update a conversation and return the conversation ID."""
        pass

    @abstractmethod
    def get_response_msg(self) -> str:
        """Get the response message to send back to Linear."""
        pass


class StartingConvoException(Exception):
    """Exception raised when starting a conversation fails."""

    pass
224 enterprise/integrations/linear/linear_view.py (new file)
@@ -0,0 +1,224 @@
from dataclasses import dataclass

from integrations.linear.linear_types import LinearViewInterface, StartingConvoException
from integrations.models import JobContext
from integrations.utils import CONVERSATION_URL, get_final_agent_observation
from jinja2 import Environment
from storage.linear_conversation import LinearConversation
from storage.linear_integration_store import LinearIntegrationStore
from storage.linear_user import LinearUser
from storage.linear_workspace import LinearWorkspace

from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState
from openhands.events.action import MessageAction
from openhands.events.serialization.event import event_to_dict
from openhands.server.services.conversation_service import (
    create_new_conversation,
    setup_init_conversation_settings,
)
from openhands.server.shared import ConversationStoreImpl, config, conversation_manager
from openhands.server.user_auth.user_auth import UserAuth
from openhands.storage.data_models.conversation_metadata import ConversationTrigger

integration_store = LinearIntegrationStore.get_instance()


@dataclass
class LinearNewConversationView(LinearViewInterface):
    job_context: JobContext
    saas_user_auth: UserAuth
    linear_user: LinearUser
    linear_workspace: LinearWorkspace
    selected_repo: str | None
    conversation_id: str

    def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        """Instructions passed when the conversation is first initialized"""

        instructions_template = jinja_env.get_template('linear_instructions.j2')
        instructions = instructions_template.render()

        user_msg_template = jinja_env.get_template('linear_new_conversation.j2')

        user_msg = user_msg_template.render(
            issue_key=self.job_context.issue_key,
            issue_title=self.job_context.issue_title,
            issue_description=self.job_context.issue_description,
            user_message=self.job_context.user_msg or '',
        )

        return instructions, user_msg

    async def create_or_update_conversation(self, jinja_env: Environment) -> str:
        """Create a new Linear conversation"""

        if not self.selected_repo:
            raise StartingConvoException('No repository selected for this conversation')

        provider_tokens = await self.saas_user_auth.get_provider_tokens()
        user_secrets = await self.saas_user_auth.get_user_secrets()
        instructions, user_msg = self._get_instructions(jinja_env)

        try:
            agent_loop_info = await create_new_conversation(
                user_id=self.linear_user.keycloak_user_id,
                git_provider_tokens=provider_tokens,
                selected_repository=self.selected_repo,
                selected_branch=None,
                initial_user_msg=user_msg,
                conversation_instructions=instructions,
                image_urls=None,
                replay_json=None,
                conversation_trigger=ConversationTrigger.LINEAR,
                custom_secrets=user_secrets.custom_secrets if user_secrets else None,
            )

            self.conversation_id = agent_loop_info.conversation_id

            logger.info(f'[Linear] Created conversation {self.conversation_id}')

            # Store Linear conversation mapping
            linear_conversation = LinearConversation(
                conversation_id=self.conversation_id,
                issue_id=self.job_context.issue_id,
                issue_key=self.job_context.issue_key,
                linear_user_id=self.linear_user.id,
            )

            await integration_store.create_conversation(linear_conversation)

            return self.conversation_id
        except Exception as e:
            logger.error(
                f'[Linear] Failed to create conversation: {str(e)}', exc_info=True
            )
            raise StartingConvoException(f'Failed to create conversation: {str(e)}')

    def get_response_msg(self) -> str:
        """Get the response message to send back to Linear"""
        conversation_link = CONVERSATION_URL.format(self.conversation_id)
        return f"I'm on it! {self.job_context.display_name} can [track my progress here]({conversation_link})."


@dataclass
class LinearExistingConversationView(LinearViewInterface):
    job_context: JobContext
    saas_user_auth: UserAuth
    linear_user: LinearUser
    linear_workspace: LinearWorkspace
    selected_repo: str | None
    conversation_id: str

    def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        """Instructions passed when the conversation is first initialized"""

        user_msg_template = jinja_env.get_template('linear_existing_conversation.j2')
        user_msg = user_msg_template.render(
            issue_key=self.job_context.issue_key,
            user_message=self.job_context.user_msg or '',
            issue_title=self.job_context.issue_title,
            issue_description=self.job_context.issue_description,
        )

        return '', user_msg

    async def create_or_update_conversation(self, jinja_env: Environment) -> str:
        """Update an existing Linear conversation"""

        user_id = self.linear_user.keycloak_user_id

        try:
            conversation_store = await ConversationStoreImpl.get_instance(
                config, user_id
            )
            metadata = await conversation_store.get_metadata(self.conversation_id)
            if not metadata:
                raise StartingConvoException('Conversation no longer exists.')

            provider_tokens = await self.saas_user_auth.get_provider_tokens()
            if provider_tokens is None:
                raise ValueError('Could not load provider tokens')
            providers_set = list(provider_tokens.keys())

            conversation_init_data = await setup_init_conversation_settings(
                user_id, self.conversation_id, providers_set
            )

            # Either join the ongoing conversation, or restart the conversation
            agent_loop_info = await conversation_manager.maybe_start_agent_loop(
                self.conversation_id, conversation_init_data, user_id
            )

            final_agent_observation = get_final_agent_observation(
                agent_loop_info.event_store
            )
            agent_state = (
                None
                if len(final_agent_observation) == 0
                else final_agent_observation[0].agent_state
            )

            if not agent_state or agent_state == AgentState.LOADING:
                raise StartingConvoException('Conversation is still starting')

            _, user_msg = self._get_instructions(jinja_env)
            user_message_event = MessageAction(content=user_msg)
            await conversation_manager.send_event_to_conversation(
                self.conversation_id, event_to_dict(user_message_event)
            )

            return self.conversation_id
        except Exception as e:
            logger.error(
                f'[Linear] Failed to create conversation: {str(e)}', exc_info=True
            )
            raise StartingConvoException(f'Failed to create conversation: {str(e)}')

    def get_response_msg(self) -> str:
        """Get the response message to send back to Linear"""
        conversation_link = CONVERSATION_URL.format(self.conversation_id)
        return f"I'm on it! {self.job_context.display_name} can [continue tracking my progress here]({conversation_link})."


class LinearFactory:
    """Factory for creating Linear views based on message content"""

    @staticmethod
    async def create_linear_view_from_payload(
        job_context: JobContext,
        saas_user_auth: UserAuth,
        linear_user: LinearUser,
        linear_workspace: LinearWorkspace,
    ) -> LinearViewInterface:
        """Create the appropriate Linear view based on the message and user state"""

        if not linear_user or not saas_user_auth or not linear_workspace:
            raise StartingConvoException(
                'User not authenticated with Linear integration'
            )

        conversation = await integration_store.get_user_conversations_by_issue_id(
            job_context.issue_id, linear_user.id
        )
        if conversation:
            logger.info(
                f'[Linear] Found existing conversation for issue {job_context.issue_id}'
            )
            return LinearExistingConversationView(
                job_context=job_context,
                saas_user_auth=saas_user_auth,
                linear_user=linear_user,
                linear_workspace=linear_workspace,
                selected_repo=None,
                conversation_id=conversation.conversation_id,
            )

        return LinearNewConversationView(
            job_context=job_context,
            saas_user_auth=saas_user_auth,
            linear_user=linear_user,
            linear_workspace=linear_workspace,
            selected_repo=None,  # Will be set later after repo inference
            conversation_id='',  # Will be set when conversation is created
        )
30 enterprise/integrations/manager.py (new file)
@@ -0,0 +1,30 @@
from abc import ABC, abstractmethod

from integrations.models import Message, SourceType


class Manager(ABC):
    manager_type: SourceType

    @abstractmethod
    async def receive_message(self, message: Message):
        "Receive message from integration"
        raise NotImplementedError

    @abstractmethod
    def send_message(self, message: Message):
        "Send message to integration from OpenHands server"
        raise NotImplementedError

    @abstractmethod
    async def is_job_requested(self, message: Message) -> bool:
        "Confirm that a job is being requested"
        raise NotImplementedError

    @abstractmethod
    def start_job(self):
        "Kick off a job with the OpenHands agent"
        raise NotImplementedError

    def create_outgoing_message(self, msg: str | dict, ephemeral: bool = False):
        return Message(source=SourceType.OPENHANDS, message=msg, ephemeral=ephemeral)
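A minimal sketch of what a concrete subclass of this base class has to provide; the class name and the trivial method bodies are illustrative assumptions, not part of the diff:

class EchoManager(Manager):
    """Toy example that only demonstrates the abstract surface of Manager."""

    manager_type = SourceType.OPENHANDS

    async def receive_message(self, message: Message):
        if await self.is_job_requested(message):
            self.start_job()

    def send_message(self, message: Message):
        print(f'outgoing: {message.message}')

    async def is_job_requested(self, message: Message) -> bool:
        return isinstance(message.message, str) and '@openhands' in message.message

    def start_job(self):
        print('starting job')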
52 enterprise/integrations/models.py (new file)
@@ -0,0 +1,52 @@
from enum import Enum

from pydantic import BaseModel

from openhands.core.schema import AgentState


class SourceType(str, Enum):
    GITHUB = 'github'
    GITLAB = 'gitlab'
    OPENHANDS = 'openhands'
    SLACK = 'slack'
    JIRA = 'jira'
    JIRA_DC = 'jira_dc'
    LINEAR = 'linear'


class Message(BaseModel):
    source: SourceType
    message: str | dict
    ephemeral: bool = False


class JobContext(BaseModel):
    issue_id: str
    issue_key: str
    user_msg: str
    user_email: str
    display_name: str
    platform_user_id: str = ''
    workspace_name: str
    base_api_url: str = ''
    issue_title: str = ''
    issue_description: str = ''


class JobResult:
    result: str
    explanation: str


class GithubResolverJob:
    type: SourceType
    status: AgentState
    result: JobResult
    owner: str
    repo: str
    installation_token: str
    issue_number: int
    runtime_id: int
    created_at: int
    completed_at: int
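For reference, a small example of how the webhook managers construct these models; every field value below is made up for illustration:

job = JobContext(
    issue_id='abc-123',
    issue_key='ENG-42',
    user_msg='@openhands please fix the flaky login test',
    user_email='dev@example.com',
    display_name='Dev Example',
    platform_user_id='usr_789',
    workspace_name='example-workspace',
)
outgoing = Message(source=SourceType.OPENHANDS, message="I'm on it!")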
363 enterprise/integrations/slack/slack_manager.py (new file)
@@ -0,0 +1,363 @@
import re

import jwt
from integrations.manager import Manager
from integrations.models import Message, SourceType
from integrations.slack.slack_types import SlackViewInterface, StartingConvoException
from integrations.slack.slack_view import (
    SlackFactory,
    SlackNewConversationFromRepoFormView,
    SlackNewConversationView,
    SlackUnkownUserView,
    SlackUpdateExistingConversationView,
)
from integrations.utils import (
    HOST_URL,
    OPENHANDS_RESOLVER_TEMPLATES_DIR,
)
from jinja2 import Environment, FileSystemLoader
from pydantic import SecretStr
from server.auth.saas_user_auth import SaasUserAuth
from server.constants import SLACK_CLIENT_ID
from server.utils.conversation_callback_utils import register_callback_processor
from slack_sdk.oauth import AuthorizeUrlGenerator
from slack_sdk.web.async_client import AsyncWebClient
from storage.database import session_maker
from storage.slack_user import SlackUser

from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderHandler
from openhands.integrations.service_types import Repository
from openhands.server.shared import config, server_config
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
from openhands.server.user_auth.user_auth import UserAuth

authorize_url_generator = AuthorizeUrlGenerator(
    client_id=SLACK_CLIENT_ID,
    scopes=['app_mentions:read', 'chat:write'],
    user_scopes=['search:read'],
)


class SlackManager(Manager):
    def __init__(self, token_manager):
        self.token_manager = token_manager
        self.login_link = (
            'User has not yet authenticated: [Click here to Login to OpenHands]({}).'
        )

        self.jinja_env = Environment(
            loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR + 'slack')
        )

    def _confirm_incoming_source_type(self, message: Message):
        if message.source != SourceType.SLACK:
            raise ValueError(f'Unexpected message source {message.source}')

    async def _get_user_auth(self, keycloak_user_id: str) -> UserAuth:
        offline_token = await self.token_manager.load_offline_token(keycloak_user_id)
        if offline_token is None:
            logger.info('no_offline_token_found')

        user_auth = SaasUserAuth(
            user_id=keycloak_user_id,
            refresh_token=SecretStr(offline_token),
        )
        return user_auth

    async def authenticate_user(
        self, slack_user_id: str
    ) -> tuple[SlackUser | None, UserAuth | None]:
        # We get the user and correlate them back to a user in OpenHands - if we can
        slack_user = None
        with session_maker() as session:
            slack_user = (
                session.query(SlackUser)
                .filter(SlackUser.slack_user_id == slack_user_id)
                .first()
            )

        # slack_view.slack_to_openhands_user = slack_user # attach user auth info to view

        saas_user_auth = None
        if slack_user:
            saas_user_auth = await self._get_user_auth(slack_user.keycloak_user_id)
            # slack_view.saas_user_auth = await self._get_user_auth(slack_view.slack_to_openhands_user.keycloak_user_id)

        return slack_user, saas_user_auth

    def _infer_repo_from_message(self, user_msg: str) -> str | None:
        # Regular expression to match patterns like "All-Hands-AI/OpenHands" or "deploy repo"
        pattern = r'([a-zA-Z0-9_-]+/[a-zA-Z0-9_-]+)|([a-zA-Z0-9_-]+)(?=\s+repo)'
        match = re.search(pattern, user_msg)

        if match:
            repo = match.group(1) if match.group(1) else match.group(2)
            return repo

        return None

    async def _get_repositories(self, user_auth: UserAuth) -> list[Repository]:
        provider_tokens = await user_auth.get_provider_tokens()
        if provider_tokens is None:
            return []
        access_token = await user_auth.get_access_token()
        user_id = await user_auth.get_user_id()
        client = ProviderHandler(
            provider_tokens=provider_tokens,
            external_auth_token=access_token,
            external_auth_id=user_id,
        )
        repos: list[Repository] = await client.get_repositories(
            'pushed', server_config.app_mode, None, None, None, None
        )
        return repos

    def _generate_repo_selection_form(
        self, repo_list: list[Repository], message_ts: str, thread_ts: str | None
    ):
        options = [
            {
                'text': {'type': 'plain_text', 'text': 'No Repository'},
                'value': '-',
            }
        ]
        options.extend(
            {
                'text': {
                    'type': 'plain_text',
                    'text': repo.full_name,
                },
                'value': repo.full_name,
            }
            for repo in repo_list
        )

        return [
            {
                'type': 'header',
                'text': {
                    'type': 'plain_text',
                    'text': 'Choose a repository',
                    'emoji': True,
                },
            },
            {
                'type': 'actions',
                'elements': [
                    {
                        'type': 'static_select',
                        'action_id': f'repository_select:{message_ts}:{thread_ts}',
                        'options': options,
                    }
                ],
            },
        ]

    def filter_potential_repos_by_user_msg(
        self, user_msg: str, user_repos: list[Repository]
    ) -> tuple[bool, list[Repository]]:
        inferred_repo = self._infer_repo_from_message(user_msg)
        if not inferred_repo:
            return False, user_repos[0:99]

        final_repos = []
        for repo in user_repos:
            if inferred_repo.lower() in repo.full_name.lower():
                final_repos.append(repo)

        # no repos matched, return original list
        if len(final_repos) == 0:
            return False, user_repos[0:99]

        # Found exact match
        elif len(final_repos) == 1:
            return True, final_repos

        # Found partial matches
        return False, final_repos[0:99]

    async def receive_message(self, message: Message):
        self._confirm_incoming_source_type(message)

        slack_user, saas_user_auth = await self.authenticate_user(
            slack_user_id=message.message['slack_user_id']
        )

        try:
            slack_view = SlackFactory.create_slack_view_from_payload(
                message, slack_user, saas_user_auth
            )
        except Exception as e:
            logger.error(
                f'[Slack]: Failed to create slack view: {e}',
                exc_info=True,
                stack_info=True,
            )
            return

        if isinstance(slack_view, SlackUnkownUserView):
            jwt_secret = config.jwt_secret
            if not jwt_secret:
                raise ValueError('Must configure jwt_secret')
            state = jwt.encode(
                message.message, jwt_secret.get_secret_value(), algorithm='HS256'
            )
            link = authorize_url_generator.generate(state)
            msg = self.login_link.format(link)

            logger.info('slack_not_yet_authenticated')
            await self.send_message(
                self.create_outgoing_message(msg, ephemeral=True), slack_view
            )
            return

        if not await self.is_job_requested(message, slack_view):
            return

        await self.start_job(slack_view)

    async def send_message(self, message: Message, slack_view: SlackViewInterface):
        client = AsyncWebClient(token=slack_view.bot_access_token)
        if message.ephemeral and isinstance(message.message, str):
            await client.chat_postEphemeral(
                channel=slack_view.channel_id,
                markdown_text=message.message,
                user=slack_view.slack_user_id,
                thread_ts=slack_view.thread_ts,
            )
        elif message.ephemeral and isinstance(message.message, dict):
            await client.chat_postEphemeral(
                channel=slack_view.channel_id,
                user=slack_view.slack_user_id,
                thread_ts=slack_view.thread_ts,
                text=message.message['text'],
                blocks=message.message['blocks'],
            )
        else:
            await client.chat_postMessage(
                channel=slack_view.channel_id,
                markdown_text=message.message,
                thread_ts=slack_view.message_ts,
            )

    async def is_job_requested(
        self, message: Message, slack_view: SlackViewInterface
    ) -> bool:
        """
        A job is always requested; we only receive webhooks for events associated with the Slack bot.
        This method really just checks:
        1. Is the user authenticated
        2. Do we have the necessary information to start a job (either by inferring the selected repo, otherwise asking the user)
        """

        # Inferring the repo from the user message is not needed; the user selected a repo from the form or is updating an existing convo
        if isinstance(slack_view, SlackUpdateExistingConversationView):
            return True
        elif isinstance(slack_view, SlackNewConversationFromRepoFormView):
            return True
        elif isinstance(slack_view, SlackNewConversationView):
            user = slack_view.slack_to_openhands_user
            user_repos: list[Repository] = await self._get_repositories(
                slack_view.saas_user_auth
            )
            match, repos = self.filter_potential_repos_by_user_msg(
                slack_view.user_msg, user_repos
            )

            # User mentioned a matching repo in their message, start job without repo selection form
            if match:
                slack_view.selected_repo = repos[0].full_name
                return True

            logger.info(
                'render_repository_selector',
                extra={
                    'slack_user_id': user.slack_user_id,
                    'keycloak_user_id': user.keycloak_user_id,
                    'message_ts': slack_view.message_ts,
                    'thread_ts': slack_view.thread_ts,
                },
            )

            repo_selection_msg = {
                'text': 'Choose a Repository:',
                'blocks': self._generate_repo_selection_form(
                    repos, slack_view.message_ts, slack_view.thread_ts
                ),
            }
            await self.send_message(
                self.create_outgoing_message(repo_selection_msg, ephemeral=True),
                slack_view,
            )

            return False

        return True

    async def start_job(self, slack_view: SlackViewInterface):
        # Importing here prevents circular import
        from server.conversation_callback_processor.slack_callback_processor import (
            SlackCallbackProcessor,
        )

        try:
            msg_info = None
            user_info: SlackUser = slack_view.slack_to_openhands_user
            try:
                logger.info(
                    f'[Slack] Starting job for user {user_info.slack_display_name} (id={user_info.slack_user_id})',
                    extra={'keycloak_user_id': user_info.keycloak_user_id},
                )
                conversation_id = await slack_view.create_or_update_conversation(
                    self.jinja_env
                )

                logger.info(
                    f'[Slack] Created conversation {conversation_id} for user {user_info.slack_display_name}'
                )

                if not isinstance(slack_view, SlackUpdateExistingConversationView):
                    # We don't re-subscribe for follow-up messages from Slack.
                    # Summaries are generated for every message anyway; we only need to do
                    # this subscription once for the event which kicked off the job.
                    processor = SlackCallbackProcessor(
                        slack_user_id=slack_view.slack_user_id,
                        channel_id=slack_view.channel_id,
                        message_ts=slack_view.message_ts,
                        thread_ts=slack_view.thread_ts,
                        team_id=slack_view.team_id,
                    )

                    # Register the callback processor
                    register_callback_processor(conversation_id, processor)

                    logger.info(
                        f'[Slack] Created callback processor for conversation {conversation_id}'
                    )

                msg_info = slack_view.get_response_msg()

            except MissingSettingsError as e:
                logger.warning(
                    f'[Slack] Missing settings error for user {user_info.slack_display_name}: {str(e)}'
                )

                msg_info = f'{user_info.slack_display_name} please re-login into [OpenHands Cloud]({HOST_URL}) before starting a job.'

            except LLMAuthenticationError as e:
                logger.warning(
                    f'[Slack] LLM authentication error for user {user_info.slack_display_name}: {str(e)}'
                )

                msg_info = f'@{user_info.slack_display_name} please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'

            except StartingConvoException as e:
                msg_info = str(e)

            await self.send_message(self.create_outgoing_message(msg_info), slack_view)

        except Exception:
            logger.exception('[Slack]: Error starting job')
            msg = 'Uh oh! There was an unexpected error starting the job :('
            await self.send_message(self.create_outgoing_message(msg), slack_view)
48 enterprise/integrations/slack/slack_types.py (new file)
@@ -0,0 +1,48 @@
from abc import ABC, abstractmethod

from integrations.types import SummaryExtractionTracker
from jinja2 import Environment
from storage.slack_user import SlackUser

from openhands.server.user_auth.user_auth import UserAuth


class SlackViewInterface(SummaryExtractionTracker, ABC):
    bot_access_token: str
    user_msg: str | None
    slack_user_id: str
    slack_to_openhands_user: SlackUser | None
    saas_user_auth: UserAuth | None
    channel_id: str
    message_ts: str
    thread_ts: str | None
    selected_repo: str | None
    should_extract: bool
    send_summary_instruction: bool
    conversation_id: str
    team_id: str

    @abstractmethod
    def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        "Instructions passed when the conversation is first initialized"
        pass

    @abstractmethod
    async def create_or_update_conversation(self, jinja_env: Environment):
        "Create a new conversation"
        pass

    @abstractmethod
    def get_callback_id(self) -> str:
        "Unique callback id for the subscription made to the EventStream for fetching the agent summary"
        pass

    @abstractmethod
    def get_response_msg(self) -> str:
        pass


class StartingConvoException(Exception):
    """
    Raised when trying to send a message to a conversation that is still starting up
    """
435 enterprise/integrations/slack/slack_view.py (new file)
@@ -0,0 +1,435 @@
from dataclasses import dataclass
|
||||
|
||||
from integrations.models import Message
|
||||
from integrations.slack.slack_types import SlackViewInterface, StartingConvoException
|
||||
from integrations.utils import CONVERSATION_URL, get_final_agent_observation
|
||||
from jinja2 import Environment
|
||||
from slack_sdk import WebClient
|
||||
from storage.slack_conversation import SlackConversation
|
||||
from storage.slack_conversation_store import SlackConversationStore
|
||||
from storage.slack_team_store import SlackTeamStore
|
||||
from storage.slack_user import SlackUser
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.server.services.conversation_service import (
|
||||
create_new_conversation,
|
||||
setup_init_conversation_settings,
|
||||
)
|
||||
from openhands.server.shared import ConversationStoreImpl, config, conversation_manager
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync
|
||||
|
||||
# =================================================
|
||||
# SECTION: Github view types
|
||||
# =================================================
|
||||
|
||||
|
||||
CONTEXT_LIMIT = 21
|
||||
slack_conversation_store = SlackConversationStore.get_instance()
|
||||
slack_team_store = SlackTeamStore.get_instance()
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlackUnkownUserView(SlackViewInterface):
|
||||
bot_access_token: str
|
||||
user_msg: str | None
|
||||
slack_user_id: str
|
||||
slack_to_openhands_user: SlackUser | None
|
||||
saas_user_auth: UserAuth | None
|
||||
channel_id: str
|
||||
message_ts: str
|
||||
thread_ts: str | None
|
||||
selected_repo: str | None
|
||||
should_extract: bool
|
||||
send_summary_instruction: bool
|
||||
conversation_id: str
|
||||
team_id: str
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_or_update_conversation(self, jinja_env: Environment):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_callback_id(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_response_msg(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlackNewConversationView(SlackViewInterface):
|
||||
bot_access_token: str
|
||||
user_msg: str | None
|
||||
slack_user_id: str
|
||||
slack_to_openhands_user: SlackUser
|
||||
saas_user_auth: UserAuth
|
||||
channel_id: str
|
||||
message_ts: str
|
||||
thread_ts: str | None
|
||||
selected_repo: str | None
|
||||
should_extract: bool
|
||||
send_summary_instruction: bool
|
||||
conversation_id: str
|
||||
team_id: str
|
||||
|
||||
def _get_initial_prompt(self, text: str, blocks: list[dict]):
|
||||
bot_id = self._get_bot_id(blocks)
|
||||
text = text.replace(f'<@{bot_id}>', '').strip()
|
||||
return text
|
||||
|
||||
def _get_bot_id(self, blocks: list[dict]) -> str:
|
||||
for block in blocks:
|
||||
type_ = block['type']
|
||||
if type_ in ('rich_text', 'rich_text_section'):
|
||||
bot_id = self._get_bot_id(block['elements'])
|
||||
if bot_id:
|
||||
return bot_id
|
||||
if type_ == 'user':
|
||||
return block['user_id']
|
||||
return ''
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"Instructions passed when conversation is first initialized"
|
||||
|
||||
user_info: SlackUser = self.slack_to_openhands_user
|
||||
|
||||
messages = []
|
||||
if self.thread_ts:
|
||||
client = WebClient(token=self.bot_access_token)
|
||||
result = client.conversations_replies(
|
||||
channel=self.channel_id,
|
||||
ts=self.thread_ts,
|
||||
inclusive=True,
|
||||
latest=self.message_ts,
|
||||
limit=CONTEXT_LIMIT,  # In the future we can be smarter about fetching more context and condensing it
|
||||
)
|
||||
|
||||
messages = result['messages']
|
||||
|
||||
else:
|
||||
client = WebClient(token=self.bot_access_token)
|
||||
result = client.conversations_history(
|
||||
channel=self.channel_id,
|
||||
inclusive=True,
|
||||
latest=self.message_ts,
|
||||
limit=CONTEXT_LIMIT,
|
||||
)
|
||||
|
||||
messages = result['messages']
|
||||
messages.reverse()
|
||||
|
||||
if not messages:
|
||||
raise ValueError('Failed to fetch information from slack API')
|
||||
|
||||
logger.info('got_messages_from_slack', extra={'messages': messages})
|
||||
|
||||
trigger_msg = messages[-1]
|
||||
user_message = self._get_initial_prompt(
|
||||
trigger_msg['text'], trigger_msg['blocks']
|
||||
)
|
||||
|
||||
conversation_instructions = ''
|
||||
|
||||
if len(messages) > 1:
|
||||
messages.pop()
|
||||
text_messages = [m['text'] for m in messages if m.get('text')]
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'user_message_conversation_instructions.j2'
|
||||
)
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
messages=text_messages,
|
||||
username=user_info.slack_display_name,
|
||||
conversation_url=CONVERSATION_URL,
|
||||
)
|
||||
|
||||
return user_message, conversation_instructions
|
||||
|
||||
def _verify_necessary_values_are_set(self):
|
||||
if not self.selected_repo:
|
||||
raise ValueError(
|
||||
'Attempting to start conversation without confirming selected repo from user'
|
||||
)
|
||||
|
||||
async def save_slack_convo(self):
|
||||
if self.slack_to_openhands_user:
|
||||
user_info: SlackUser = self.slack_to_openhands_user
|
||||
|
||||
logger.info(
|
||||
'Create slack conversation',
|
||||
extra={
|
||||
'channel_id': self.channel_id,
|
||||
'conversation_id': self.conversation_id,
|
||||
'keycloak_user_id': user_info.keycloak_user_id,
|
||||
'parent_id': self.thread_ts or self.message_ts,
|
||||
},
|
||||
)
|
||||
slack_conversation = SlackConversation(
|
||||
conversation_id=self.conversation_id,
|
||||
channel_id=self.channel_id,
|
||||
keycloak_user_id=user_info.keycloak_user_id,
|
||||
parent_id=self.thread_ts
|
||||
or self.message_ts, # conversations can start in a thread reply as well; we should always reference the parent's (root level msg's) message ID
|
||||
)
|
||||
await slack_conversation_store.create_slack_conversation(slack_conversation)
|
||||
|
||||
async def create_or_update_conversation(self, jinja: Environment) -> str:
|
||||
"""
|
||||
Only creates a new conversation
|
||||
"""
|
||||
self._verify_necessary_values_are_set()
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
user_secrets = await self.saas_user_auth.get_user_secrets()
|
||||
user_instructions, conversation_instructions = self._get_instructions(jinja)
|
||||
|
||||
agent_loop_info = await create_new_conversation(
|
||||
user_id=self.slack_to_openhands_user.keycloak_user_id,
|
||||
git_provider_tokens=provider_tokens,
|
||||
selected_repository=self.selected_repo,
|
||||
selected_branch=None,
|
||||
initial_user_msg=user_instructions,
|
||||
conversation_instructions=conversation_instructions
|
||||
if conversation_instructions
|
||||
else None,
|
||||
image_urls=None,
|
||||
replay_json=None,
|
||||
conversation_trigger=ConversationTrigger.SLACK,
|
||||
custom_secrets=user_secrets.custom_secrets if user_secrets else None,
|
||||
)
|
||||
|
||||
self.conversation_id = agent_loop_info.conversation_id
|
||||
await self.save_slack_convo()
|
||||
return self.conversation_id
|
||||
|
||||
def get_callback_id(self) -> str:
|
||||
return f'slack_{self.channel_id}_{self.message_ts}'
|
||||
|
||||
def get_response_msg(self) -> str:
|
||||
user_info: SlackUser = self.slack_to_openhands_user
|
||||
conversation_link = CONVERSATION_URL.format(self.conversation_id)
|
||||
return f"I'm on it! {user_info.slack_display_name} can [track my progress here]({conversation_link})."
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlackNewConversationFromRepoFormView(SlackNewConversationView):
|
||||
def _verify_necessary_values_are_set(self):
|
||||
# Exclude selected repo check from parent
|
||||
# User can start conversations without a repo when specified via the repo selection form
|
||||
return
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlackUpdateExistingConversationView(SlackNewConversationView):
|
||||
slack_conversation: SlackConversation
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
client = WebClient(token=self.bot_access_token)
|
||||
result = client.conversations_replies(
|
||||
channel=self.channel_id,
|
||||
ts=self.message_ts,
|
||||
inclusive=True,
|
||||
latest=self.message_ts,
|
||||
limit=1, # Get exact user message, in future we can be smarter with collecting additional context
|
||||
)
|
||||
|
||||
user_message = result['messages'][0]
|
||||
user_message = self._get_initial_prompt(
|
||||
user_message['text'], user_message['blocks']
|
||||
)
|
||||
|
||||
return user_message, ''
|
||||
|
||||
async def create_or_update_conversation(self, jinja: Environment) -> str:
|
||||
"""
|
||||
Send a new user message to the conversation
|
||||
"""
|
||||
user_info: SlackUser = self.slack_to_openhands_user
|
||||
saas_user_auth: UserAuth = self.saas_user_auth
|
||||
user_id = user_info.keycloak_user_id
|
||||
|
||||
# Org management in the future will get rid of this
|
||||
# For now, only the user that created the conversation can send follow-up messages to it
|
||||
if user_id != self.slack_conversation.keycloak_user_id:
|
||||
raise StartingConvoException(
|
||||
f'{user_info.slack_display_name} is not authorized to send messages to this conversation.'
|
||||
)
|
||||
|
||||
# Check if conversation has been deleted
|
||||
# Update logic when soft delete is implemented
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
|
||||
metadata = await conversation_store.get_metadata(self.conversation_id)
|
||||
if not metadata:
|
||||
raise StartingConvoException('Conversation no longer exists.')
|
||||
|
||||
provider_tokens = await saas_user_auth.get_provider_tokens()
|
||||
|
||||
# Should we raise here if there are no provider tokens?
|
||||
providers_set = list(provider_tokens.keys()) if provider_tokens else []
|
||||
|
||||
conversation_init_data = await setup_init_conversation_settings(
|
||||
user_id, self.conversation_id, providers_set
|
||||
)
|
||||
|
||||
# Either join ongoing conversation, or restart the conversation
|
||||
agent_loop_info = await conversation_manager.maybe_start_agent_loop(
|
||||
self.conversation_id, conversation_init_data, user_id
|
||||
)
|
||||
|
||||
final_agent_observation = get_final_agent_observation(
|
||||
agent_loop_info.event_store
|
||||
)
|
||||
agent_state = (
|
||||
None
|
||||
if len(final_agent_observation) == 0
|
||||
else final_agent_observation[0].agent_state
|
||||
)
|
||||
|
||||
if not agent_state or agent_state == AgentState.LOADING:
|
||||
raise StartingConvoException('Conversation is still starting')
|
||||
|
||||
user_msg, _ = self._get_instructions(jinja)
|
||||
user_msg_action = MessageAction(content=user_msg)
|
||||
await conversation_manager.send_event_to_conversation(
|
||||
self.conversation_id, event_to_dict(user_msg_action)
|
||||
)
|
||||
|
||||
return self.conversation_id
|
||||
|
||||
def get_response_msg(self):
|
||||
user_info: SlackUser = self.slack_to_openhands_user
|
||||
conversation_link = CONVERSATION_URL.format(self.conversation_id)
|
||||
return f"I'm on it! {user_info.slack_display_name} can [continue tracking my progress here]({conversation_link})."
|
||||
|
||||
|
||||
class SlackFactory:
|
||||
@staticmethod
|
||||
def did_user_select_repo_from_form(message: Message):
|
||||
payload = message.message
|
||||
return 'selected_repo' in payload
|
||||
|
||||
@staticmethod
|
||||
async def determine_if_updating_existing_conversation(
|
||||
message: Message,
|
||||
) -> SlackConversation | None:
|
||||
payload = message.message
|
||||
channel_id = payload.get('channel_id')
|
||||
thread_ts = payload.get('thread_ts')
|
||||
|
||||
# Follow-up conversations must be contained in-thread
|
||||
if not thread_ts:
|
||||
return None
|
||||
|
||||
# thread_ts in slack payloads is the parent's (root-level msg's) message ID
|
||||
return await slack_conversation_store.get_slack_conversation(
|
||||
channel_id, thread_ts
|
||||
)
|
||||
|
||||
@staticmethod
def create_slack_view_from_payload(
|
||||
message: Message, slack_user: SlackUser | None, saas_user_auth: UserAuth | None
|
||||
):
|
||||
payload = message.message
|
||||
slack_user_id = payload['slack_user_id']
|
||||
channel_id = payload.get('channel_id')
|
||||
message_ts = payload.get('message_ts')
|
||||
thread_ts = payload.get('thread_ts')
|
||||
team_id = payload['team_id']
|
||||
user_msg = payload.get('user_msg')
|
||||
|
||||
bot_access_token = slack_team_store.get_team_bot_token(team_id)
|
||||
if not bot_access_token:
|
||||
logger.error(
|
||||
'Did not find slack team',
|
||||
extra={
|
||||
'slack_user_id': slack_user_id,
|
||||
'channel_id': channel_id,
|
||||
},
|
||||
)
|
||||
raise Exception('Did not find slack team')
|
||||
|
||||
# Determine whether this slack user is known to openhands
|
||||
if not slack_user or not saas_user_auth or not channel_id:
|
||||
return SlackUnkownUserView(
|
||||
bot_access_token=bot_access_token,
|
||||
user_msg=user_msg,
|
||||
slack_user_id=slack_user_id,
|
||||
slack_to_openhands_user=slack_user,
|
||||
saas_user_auth=saas_user_auth,
|
||||
channel_id=channel_id,
|
||||
message_ts=message_ts,
|
||||
thread_ts=thread_ts,
|
||||
selected_repo=None,
|
||||
should_extract=False,
|
||||
send_summary_instruction=False,
|
||||
conversation_id='',
|
||||
team_id=team_id,
|
||||
)
|
||||
|
||||
conversation: SlackConversation | None = call_async_from_sync(
|
||||
SlackFactory.determine_if_updating_existing_conversation,
|
||||
GENERAL_TIMEOUT,
|
||||
message,
|
||||
)
|
||||
if conversation:
|
||||
logger.info(
|
||||
'Found existing slack conversation',
|
||||
extra={
|
||||
'conversation_id': conversation.conversation_id,
|
||||
'parent_id': conversation.parent_id,
|
||||
},
|
||||
)
|
||||
return SlackUpdateExistingConversationView(
|
||||
bot_access_token=bot_access_token,
|
||||
user_msg=user_msg,
|
||||
slack_user_id=slack_user_id,
|
||||
slack_to_openhands_user=slack_user,
|
||||
saas_user_auth=saas_user_auth,
|
||||
channel_id=channel_id,
|
||||
message_ts=message_ts,
|
||||
thread_ts=thread_ts,
|
||||
selected_repo=None,
|
||||
should_extract=True,
|
||||
send_summary_instruction=True,
|
||||
conversation_id=conversation.conversation_id,
|
||||
slack_conversation=conversation,
|
||||
team_id=team_id,
|
||||
)
|
||||
|
||||
elif SlackFactory.did_user_select_repo_from_form(message):
|
||||
return SlackNewConversationFromRepoFormView(
|
||||
bot_access_token=bot_access_token,
|
||||
user_msg=user_msg,
|
||||
slack_user_id=slack_user_id,
|
||||
slack_to_openhands_user=slack_user,
|
||||
saas_user_auth=saas_user_auth,
|
||||
channel_id=channel_id,
|
||||
message_ts=message_ts,
|
||||
thread_ts=thread_ts,
|
||||
selected_repo=payload['selected_repo'],
|
||||
should_extract=True,
|
||||
send_summary_instruction=True,
|
||||
conversation_id='',
|
||||
team_id=team_id,
|
||||
)
|
||||
|
||||
else:
|
||||
return SlackNewConversationView(
|
||||
bot_access_token=bot_access_token,
|
||||
user_msg=user_msg,
|
||||
slack_user_id=slack_user_id,
|
||||
slack_to_openhands_user=slack_user,
|
||||
saas_user_auth=saas_user_auth,
|
||||
channel_id=channel_id,
|
||||
message_ts=message_ts,
|
||||
thread_ts=thread_ts,
|
||||
selected_repo=None,
|
||||
should_extract=True,
|
||||
send_summary_instruction=True,
|
||||
conversation_id='',
|
||||
team_id=team_id,
|
||||
)
|
||||
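A hedged sketch of how the factory above is typically driven from a Slack event handler; handle_slack_message and the message/slack_user/saas_user_auth/jinja_env arguments are illustrative names supplied by the surrounding listener, not part of this module.

# Dispatch summary: unknown user or missing channel -> SlackUnkownUserView,
# message in an already-tracked thread -> SlackUpdateExistingConversationView,
# payload carrying 'selected_repo' -> SlackNewConversationFromRepoFormView,
# anything else -> SlackNewConversationView.
async def handle_slack_message(message, slack_user, saas_user_auth, jinja_env):
    view = SlackFactory.create_slack_view_from_payload(message, slack_user, saas_user_auth)
    if isinstance(view, SlackUnkownUserView):
        return None  # the caller would ask the user to link their OpenHands account first
    conversation_id = await view.create_or_update_conversation(jinja_env)
    return conversation_id, view.get_response_msg()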
0
enterprise/integrations/solvability/__init__.py
Normal file
41
enterprise/integrations/solvability/data/__init__.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
Utilities for loading and managing pre-trained classifiers.
|
||||
|
||||
Assumes that classifiers are stored adjacent to this file in the `solvability/data` directory, using a simple
|
||||
`name + .json` pattern.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from integrations.solvability.models.classifier import SolvabilityClassifier
|
||||
|
||||
|
||||
def load_classifier(name: str) -> SolvabilityClassifier:
|
||||
"""
|
||||
Load a classifier by name.
|
||||
|
||||
Args:
|
||||
name (str): The name of the classifier to load.
|
||||
|
||||
Returns:
|
||||
SolvabilityClassifier: The loaded classifier instance.
|
||||
"""
|
||||
data_dir = Path(__file__).parent
|
||||
classifier_path = data_dir / f'{name}.json'
|
||||
|
||||
if not classifier_path.exists():
|
||||
raise FileNotFoundError(f"Classifier '{name}' not found at {classifier_path}")
|
||||
|
||||
with classifier_path.open('r') as f:
|
||||
return SolvabilityClassifier.model_validate_json(f.read())
|
||||
|
||||
|
||||
def available_classifiers() -> list[str]:
|
||||
"""
|
||||
List all available classifiers in the data directory.
|
||||
|
||||
Returns:
|
||||
list[str]: A list of classifier names (without the .json extension).
|
||||
"""
|
||||
data_dir = Path(__file__).parent
|
||||
return [f.stem for f in data_dir.glob('*.json') if f.is_file()]
|
||||
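For illustration, a minimal usage sketch of the loaders above; the classifier name 'default' is a hypothetical example, not a guaranteed bundled artifact.

from integrations.solvability.data import available_classifiers, load_classifier

# 'default' is an assumed name -- any <name>.json placed next to data/__init__.py works.
names = available_classifiers()
if 'default' in names:
    clf = load_classifier('default')  # deserializes a SolvabilityClassifier from default.json
    print(clf.identifier)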
File diff suppressed because one or more lines are too long
38
enterprise/integrations/solvability/models/__init__.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
Solvability Models Package
|
||||
|
||||
This package contains the core machine learning models and components for predicting
|
||||
the solvability of GitHub issues and similar technical problems.
|
||||
|
||||
The solvability prediction system works by:
|
||||
1. Using a Featurizer to extract semantic features from issue descriptions via LLM calls
|
||||
2. Training a RandomForestClassifier on these features to predict solvability
|
||||
3. Generating detailed reports with feature importance analysis
|
||||
|
||||
Key Components:
|
||||
- Feature: Defines individual features that can be extracted from issues
|
||||
- Featurizer: Orchestrates LLM-based feature extraction with sampling and batching
|
||||
- SolvabilityClassifier: Main ML pipeline combining featurization and classification
|
||||
- SolvabilityReport: Comprehensive output with predictions, feature analysis, and metadata
|
||||
- ImportanceStrategy: Configurable methods for calculating feature importance (SHAP, permutation, impurity)
|
||||
"""
|
||||
|
||||
from integrations.solvability.models.classifier import SolvabilityClassifier
|
||||
from integrations.solvability.models.featurizer import (
|
||||
EmbeddingDimension,
|
||||
Feature,
|
||||
FeatureEmbedding,
|
||||
Featurizer,
|
||||
)
|
||||
from integrations.solvability.models.importance_strategy import ImportanceStrategy
|
||||
from integrations.solvability.models.report import SolvabilityReport
|
||||
|
||||
__all__ = [
|
||||
'Feature',
|
||||
'EmbeddingDimension',
|
||||
'FeatureEmbedding',
|
||||
'Featurizer',
|
||||
'ImportanceStrategy',
|
||||
'SolvabilityClassifier',
|
||||
'SolvabilityReport',
|
||||
]
|
||||
433
enterprise/integrations/solvability/models/classifier.py
Normal file
@@ -0,0 +1,433 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import pickle
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import shap
|
||||
from integrations.solvability.models.featurizer import Feature, Featurizer
|
||||
from integrations.solvability.models.importance_strategy import ImportanceStrategy
|
||||
from integrations.solvability.models.report import SolvabilityReport
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
PrivateAttr,
|
||||
field_serializer,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.exceptions import NotFittedError
|
||||
from sklearn.inspection import permutation_importance
|
||||
from sklearn.utils.validation import check_is_fitted
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
|
||||
|
||||
class SolvabilityClassifier(BaseModel):
|
||||
"""
|
||||
Machine learning pipeline for predicting the solvability of GitHub issues and similar problems.
|
||||
|
||||
This classifier combines LLM-based feature extraction with traditional ML classification:
|
||||
1. Uses a Featurizer to extract semantic boolean features from issue descriptions via LLM calls
|
||||
2. Trains a RandomForestClassifier on these features to predict solvability scores
|
||||
3. Provides feature importance analysis using configurable strategies (SHAP, permutation, impurity)
|
||||
4. Generates comprehensive reports with predictions, feature analysis, and cost metrics
|
||||
|
||||
The classifier supports both training on labeled data and inference on new issues, with built-in
|
||||
support for batch processing and concurrent feature extraction.
|
||||
"""
|
||||
|
||||
identifier: str
|
||||
"""
|
||||
The identifier for the classifier.
|
||||
"""
|
||||
|
||||
featurizer: Featurizer
|
||||
"""
|
||||
The featurizer to use for transforming the input data.
|
||||
"""
|
||||
|
||||
classifier: RandomForestClassifier
|
||||
"""
|
||||
The RandomForestClassifier used for predicting solvability from extracted features.
|
||||
|
||||
This ensemble model provides robust predictions and built-in feature importance metrics.
|
||||
"""
|
||||
|
||||
importance_strategy: ImportanceStrategy = ImportanceStrategy.IMPURITY
|
||||
"""
|
||||
Strategy to use for calculating feature importance.
|
||||
"""
|
||||
|
||||
samples: int = 10
|
||||
"""
|
||||
Number of samples to use for calculating feature embedding coefficients.
|
||||
"""
|
||||
|
||||
random_state: int | None = None
|
||||
"""
|
||||
Random state for reproducibility.
|
||||
"""
|
||||
|
||||
_classifier_attrs: dict[str, Any] = PrivateAttr(default_factory=dict)
|
||||
"""
|
||||
Private dictionary storing cached results from feature extraction and importance calculations.
|
||||
|
||||
Contains keys like 'features_', 'cost_', 'feature_importances_', and 'labels_' that are populated
|
||||
during transform(), fit(), and predict() operations. Access these via the corresponding properties.
|
||||
|
||||
This field is never serialized, so cached values will not persist across model save/load cycles.
|
||||
"""
|
||||
|
||||
model_config = {
|
||||
'arbitrary_types_allowed': True,
|
||||
}
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_random_state(self) -> SolvabilityClassifier:
|
||||
"""
|
||||
Validate the random state configuration between this object and the classifier.
|
||||
"""
|
||||
# If both random states are set, they definitely need to agree.
|
||||
if self.random_state is not None and self.classifier.random_state is not None:
|
||||
if self.random_state != self.classifier.random_state:
|
||||
raise ValueError(
|
||||
'The random state of the classifier and the top-level classifier must agree.'
|
||||
)
|
||||
|
||||
# Otherwise, we'll always set the classifier's random state to the top-level one.
|
||||
self.classifier.random_state = self.random_state
|
||||
|
||||
return self
|
||||
|
||||
@property
|
||||
def features_(self) -> pd.DataFrame:
|
||||
"""
|
||||
Get the features used by the classifier for the most recent inputs.
|
||||
"""
|
||||
if 'features_' not in self._classifier_attrs:
|
||||
raise ValueError(
|
||||
'SolvabilityClassifier.transform() has not yet been called.'
|
||||
)
|
||||
return self._classifier_attrs['features_']
|
||||
|
||||
@property
|
||||
def cost_(self) -> pd.DataFrame:
|
||||
"""
|
||||
Get the cost of the classifier for the most recent inputs.
|
||||
"""
|
||||
if 'cost_' not in self._classifier_attrs:
|
||||
raise ValueError(
|
||||
'SolvabilityClassifier.transform() has not yet been called.'
|
||||
)
|
||||
return self._classifier_attrs['cost_']
|
||||
|
||||
@property
|
||||
def feature_importances_(self) -> np.ndarray:
|
||||
"""
|
||||
Get the feature importances for the most recent inputs.
|
||||
"""
|
||||
if 'feature_importances_' not in self._classifier_attrs:
|
||||
raise ValueError(
|
||||
'No SolvabilityClassifier methods that produce feature importances (.fit(), .predict_proba(), and '
|
||||
'.predict()) have been called.'
|
||||
)
|
||||
return self._classifier_attrs['feature_importances_'] # type: ignore[no-any-return]
|
||||
|
||||
@property
|
||||
def is_fitted(self) -> bool:
|
||||
"""
|
||||
Check if the classifier is fitted.
|
||||
"""
|
||||
try:
|
||||
check_is_fitted(self.classifier)
|
||||
return True
|
||||
except NotFittedError:
|
||||
return False
|
||||
|
||||
def transform(self, issues: pd.Series, llm_config: LLMConfig) -> pd.DataFrame:
|
||||
"""
|
||||
Transform the input issues using the featurizer to extract features.
|
||||
|
||||
This method orchestrates the feature extraction pipeline:
|
||||
1. Uses the featurizer to generate embeddings for all issues
|
||||
2. Converts embeddings to a structured DataFrame
|
||||
3. Separates feature columns from metadata columns
|
||||
4. Stores results for later access via properties
|
||||
|
||||
Args:
|
||||
issues: A pandas Series containing the issue descriptions.
|
||||
llm_config: LLM configuration to use for feature extraction.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: A DataFrame containing only the feature columns (no metadata).
|
||||
"""
|
||||
# Generate feature embeddings for all issues using batch processing
|
||||
feature_embeddings = self.featurizer.embed_batch(
|
||||
issues, samples=self.samples, llm_config=llm_config
|
||||
)
|
||||
df = pd.DataFrame(embedding.to_row() for embedding in feature_embeddings)
|
||||
|
||||
# Split into feature columns (used by classifier) and cost columns (metadata)
|
||||
feature_columns = [feature.identifier for feature in self.featurizer.features]
|
||||
cost_columns = [col for col in df.columns if col not in feature_columns]
|
||||
|
||||
# Store both sets for access via properties
|
||||
self._classifier_attrs['features_'] = df[feature_columns]
|
||||
self._classifier_attrs['cost_'] = df[cost_columns]
|
||||
|
||||
return self.features_
|
||||
|
||||
def fit(
|
||||
self, issues: pd.Series, labels: pd.Series, llm_config: LLMConfig
|
||||
) -> SolvabilityClassifier:
|
||||
"""
|
||||
Fit the classifier to the input issues and labels.
|
||||
|
||||
Args:
|
||||
issues: A pandas Series containing the issue descriptions.
|
||||
|
||||
labels: A pandas Series containing the labels (0 or 1) for each issue.
|
||||
|
||||
llm_config: LLM configuration to use for feature extraction.
|
||||
|
||||
Returns:
|
||||
SolvabilityClassifier: The fitted classifier.
|
||||
"""
|
||||
features = self.transform(issues, llm_config=llm_config)
|
||||
self.classifier.fit(features, labels)
|
||||
|
||||
# Store labels for permutation importance calculation
|
||||
self._classifier_attrs['labels_'] = labels
|
||||
self._classifier_attrs['feature_importances_'] = self._importance(
|
||||
features, self.classifier.predict_proba(features), labels
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def predict_proba(self, issues: pd.Series, llm_config: LLMConfig) -> np.ndarray:
|
||||
"""
|
||||
Predict the solvability probabilities for the input issues.
|
||||
|
||||
Returns class probabilities where the second column represents the probability
|
||||
of the issue being solvable (positive class).
|
||||
|
||||
Args:
|
||||
issues: A pandas Series containing the issue descriptions.
|
||||
llm_config: LLM configuration to use for feature extraction.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Array of shape (n_samples, 2) with probabilities for each class.
|
||||
Column 0: probability of not solvable, Column 1: probability of solvable.
|
||||
"""
|
||||
features = self.transform(issues, llm_config=llm_config)
|
||||
scores = self.classifier.predict_proba(features)
|
||||
|
||||
# Calculate feature importances based on the configured strategy
|
||||
# For permutation importance, we need ground truth labels if available
|
||||
labels = self._classifier_attrs.get('labels_')
|
||||
if (
|
||||
self.importance_strategy == ImportanceStrategy.PERMUTATION
|
||||
and labels is not None
|
||||
):
|
||||
self._classifier_attrs['feature_importances_'] = self._importance(
|
||||
features, scores, labels
|
||||
)
|
||||
else:
|
||||
self._classifier_attrs['feature_importances_'] = self._importance(
|
||||
features, scores
|
||||
)
|
||||
|
||||
return scores # type: ignore[no-any-return]
|
||||
|
||||
def predict(self, issues: pd.Series, llm_config: LLMConfig) -> np.ndarray:
|
||||
"""
|
||||
Predict the solvability of the input issues by returning binary labels.
|
||||
|
||||
Uses a 0.5 probability threshold to convert probabilities to binary predictions.
|
||||
|
||||
Args:
|
||||
issues: A pandas Series containing the issue descriptions.
|
||||
llm_config: LLM configuration to use for feature extraction.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Boolean array where True indicates the issue is predicted as solvable.
|
||||
"""
|
||||
probabilities = self.predict_proba(issues, llm_config=llm_config)
|
||||
# Apply 0.5 threshold to convert probabilities to binary predictions
|
||||
labels = probabilities[:, 1] >= 0.5
|
||||
return labels
|
||||
|
||||
def _importance(
|
||||
self,
|
||||
features: pd.DataFrame,
|
||||
scores: np.ndarray,
|
||||
labels: np.ndarray | None = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Calculate feature importance scores using the configured strategy.
|
||||
|
||||
Different strategies provide different interpretations:
|
||||
- SHAP: Shapley values indicating contribution to individual predictions
|
||||
- PERMUTATION: Decrease in model performance when feature is shuffled
|
||||
- IMPURITY: Gini impurity decrease from splits on each feature
|
||||
|
||||
Args:
|
||||
features: Feature matrix used for predictions.
|
||||
scores: Model prediction scores (unused for some strategies).
|
||||
labels: Ground truth labels (required for permutation importance).
|
||||
|
||||
Returns:
|
||||
np.ndarray: Feature importance scores, one per feature.
|
||||
"""
|
||||
match self.importance_strategy:
|
||||
case ImportanceStrategy.SHAP:
|
||||
# Use SHAP TreeExplainer for tree-based models
|
||||
explainer = shap.TreeExplainer(self.classifier)
|
||||
shap_values = explainer.shap_values(features)
|
||||
# Return mean SHAP values for the positive class (solvable)
|
||||
return shap_values.mean(axis=0)[:, 1] # type: ignore[no-any-return]
|
||||
|
||||
case ImportanceStrategy.PERMUTATION:
|
||||
# Permutation importance requires ground truth labels
|
||||
if labels is None:
|
||||
raise ValueError('Labels are required for permutation importance')
|
||||
result = permutation_importance(
|
||||
self.classifier,
|
||||
features,
|
||||
labels,
|
||||
n_repeats=10, # Number of permutation rounds for stability
|
||||
random_state=self.random_state,
|
||||
)
|
||||
return result.importances_mean # type: ignore[no-any-return]
|
||||
|
||||
case ImportanceStrategy.IMPURITY:
|
||||
# Use built-in feature importances from RandomForest
|
||||
return self.classifier.feature_importances_ # type: ignore[no-any-return]
|
||||
|
||||
case _:
|
||||
raise ValueError(
|
||||
f'Unknown importance strategy: {self.importance_strategy}'
|
||||
)
|
||||
|
||||
def add_features(self, features: list[Feature]) -> SolvabilityClassifier:
|
||||
"""
|
||||
Add new features to the classifier's featurizer.
|
||||
|
||||
Note: Adding features after training requires retraining the classifier
|
||||
since the feature space will have changed.
|
||||
|
||||
Args:
|
||||
features: List of Feature objects to add.
|
||||
|
||||
Returns:
|
||||
SolvabilityClassifier: Self for method chaining.
|
||||
"""
|
||||
for feature in features:
|
||||
if feature not in self.featurizer.features:
|
||||
self.featurizer.features.append(feature)
|
||||
return self
|
||||
|
||||
def forget_features(self, features: list[Feature]) -> SolvabilityClassifier:
|
||||
"""
|
||||
Remove features from the classifier's featurizer.
|
||||
|
||||
Note: Removing features after training requires retraining the classifier
|
||||
since the feature space will have changed.
|
||||
|
||||
Args:
|
||||
features: List of Feature objects to remove.
|
||||
|
||||
Returns:
|
||||
SolvabilityClassifier: Self for method chaining.
|
||||
"""
|
||||
for feature in features:
|
||||
try:
|
||||
self.featurizer.features.remove(feature)
|
||||
except ValueError:
|
||||
# Feature not in list, continue with others
|
||||
continue
|
||||
return self
|
||||
|
||||
@field_serializer('classifier')
|
||||
@staticmethod
|
||||
def _rfc_to_json(rfc: RandomForestClassifier) -> str:
|
||||
"""
|
||||
Convert a RandomForestClassifier to a JSON-compatible value (a string).
|
||||
"""
|
||||
return base64.b64encode(pickle.dumps(rfc)).decode('utf-8')
|
||||
|
||||
@field_validator('classifier', mode='before')
|
||||
@staticmethod
|
||||
def _json_to_rfc(value: str | RandomForestClassifier) -> RandomForestClassifier:
|
||||
"""
|
||||
Convert a JSON-compatible value (a string) back to a RandomForestClassifier.
|
||||
"""
|
||||
if isinstance(value, RandomForestClassifier):
|
||||
return value
|
||||
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
model = pickle.loads(base64.b64decode(value))
|
||||
if isinstance(model, RandomForestClassifier):
|
||||
return model
|
||||
except Exception as e:
|
||||
raise ValueError(f'Failed to decode the classifier: {e}')
|
||||
|
||||
raise ValueError(
|
||||
'The classifier must be a RandomForestClassifier or a JSON-compatible dictionary.'
|
||||
)
|
||||
|
||||
def solvability_report(
|
||||
self, issue: str, llm_config: LLMConfig, **kwargs: Any
|
||||
) -> SolvabilityReport:
|
||||
"""
|
||||
Generate a solvability report for the given issue.
|
||||
|
||||
Args:
|
||||
issue: The issue description for which to generate the report.
|
||||
llm_config: LLM configuration to use for feature extraction.
|
||||
kwargs: Additional metadata to include in the report.
|
||||
|
||||
Returns:
|
||||
SolvabilityReport: The generated solvability report.
|
||||
"""
|
||||
if not self.is_fitted:
|
||||
raise ValueError(
|
||||
'The classifier must be fitted before generating a report.'
|
||||
)
|
||||
|
||||
scores = self.predict_proba(pd.Series([issue]), llm_config=llm_config)
|
||||
|
||||
return SolvabilityReport(
|
||||
identifier=self.identifier,
|
||||
issue=issue,
|
||||
score=scores[0, 1],
|
||||
features=self.features_.iloc[0].to_dict(),
|
||||
samples=self.samples,
|
||||
importance_strategy=self.importance_strategy,
|
||||
# Unlike the features, the importances are just a series with no link
|
||||
# to the actual feature names. For that we have to recombine with the
|
||||
# feature identifiers.
|
||||
feature_importances=dict(
|
||||
zip(
|
||||
self.featurizer.feature_identifiers(),
|
||||
self.feature_importances_.tolist(),
|
||||
)
|
||||
),
|
||||
random_state=self.random_state,
|
||||
metadata=dict(kwargs) if kwargs else None,
|
||||
# Both cost and response_latency are columns in the cost_ DataFrame,
|
||||
# so we can get both by just unpacking the first row.
|
||||
**self.cost_.iloc[0].to_dict(),
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self, issue: str, llm_config: LLMConfig, **kwargs: Any
|
||||
) -> SolvabilityReport:
|
||||
"""
|
||||
Generate a solvability report for the given issue.
|
||||
"""
|
||||
return self.solvability_report(issue, llm_config=llm_config, **kwargs)
|
||||
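A minimal end-to-end sketch of the pipeline above. The feature set, prompts, model name, and training data are illustrative assumptions, and every fit/predict call below issues real LLM requests through the featurizer.

import pandas as pd
from sklearn.ensemble import RandomForestClassifier

from integrations.solvability.models.classifier import SolvabilityClassifier
from integrations.solvability.models.featurizer import Feature, Featurizer
from openhands.core.config import LLMConfig

llm_config = LLMConfig(model='gpt-4o')  # assumed model name; credentials configured as usual

clf = SolvabilityClassifier(
    identifier='example-classifier',
    featurizer=Featurizer(
        system_prompt='You evaluate properties of GitHub issues.',  # illustrative prompt
        message_prefix='Issue:\n',
        features=[
            Feature(identifier='has_reproduction_steps', description='The issue includes steps to reproduce.'),
            Feature(identifier='names_specific_files', description='The issue names specific files or modules.'),
        ],
    ),
    classifier=RandomForestClassifier(n_estimators=100),
    samples=3,
    random_state=42,
)

issues = pd.Series(['Crash on save; repro steps attached ...', 'Something is broken'])
labels = pd.Series([1, 0])
clf.fit(issues, labels, llm_config=llm_config)

report = clf.solvability_report('New issue text goes here', llm_config=llm_config)
print(report.score, report.feature_importances)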
38
enterprise/integrations/solvability/models/difficulty_level.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class DifficultyLevel(Enum):
|
||||
"""Enum representing the difficulty level based on solvability score."""
|
||||
|
||||
EASY = ('EASY', 0.7, '🟢')
|
||||
MEDIUM = ('MEDIUM', 0.4, '🟡')
|
||||
HARD = ('HARD', 0.0, '🔴')
|
||||
|
||||
def __init__(self, label: str, threshold: float, emoji: str):
|
||||
self.label = label
|
||||
self.threshold = threshold
|
||||
self.emoji = emoji
|
||||
|
||||
@classmethod
|
||||
def from_score(cls, score: float) -> DifficultyLevel:
|
||||
"""Get difficulty level from a solvability score.
|
||||
|
||||
Returns the difficulty level with the highest threshold that is less than or equal to the given score.
|
||||
"""
|
||||
# Sort enum values by threshold in descending order
|
||||
sorted_levels = sorted(cls, key=lambda x: x.threshold, reverse=True)
|
||||
|
||||
# Find the first level where score meets the threshold
|
||||
for level in sorted_levels:
|
||||
if score >= level.threshold:
|
||||
return level
|
||||
|
||||
# This should never happen if thresholds are set correctly,
|
||||
# but return the lowest threshold level as fallback
|
||||
return sorted_levels[-1]
|
||||
|
||||
def format_display(self) -> str:
|
||||
"""Format the difficulty level for display."""
|
||||
return f'{self.emoji} **Solvability: {self.label}**'
|
||||
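A few worked values for from_score, following the thresholds defined above:

# from_score returns the level with the highest threshold the score still meets.
assert DifficultyLevel.from_score(0.85) is DifficultyLevel.EASY    # 0.85 >= 0.7
assert DifficultyLevel.from_score(0.55) is DifficultyLevel.MEDIUM  # 0.4 <= 0.55 < 0.7
assert DifficultyLevel.from_score(0.10) is DifficultyLevel.HARD    # 0.10 < 0.4
print(DifficultyLevel.from_score(0.55).format_display())           # 🟡 **Solvability: MEDIUM**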
368
enterprise/integrations/solvability/models/featurizer.py
Normal file
@@ -0,0 +1,368 @@
|
||||
import json
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.llm.llm import LLM
|
||||
|
||||
|
||||
class Feature(BaseModel):
|
||||
"""
|
||||
Represents a single boolean feature that can be extracted from issue descriptions.
|
||||
|
||||
Features are semantic properties of issues (e.g., "has_code_example", "requires_debugging")
|
||||
that are evaluated by LLMs and used as input to the solvability classifier.
|
||||
"""
|
||||
|
||||
identifier: str
|
||||
"""Unique identifier for the feature, used as column name in feature matrices."""
|
||||
|
||||
description: str
|
||||
"""Human-readable description of what the feature represents, used in LLM prompts."""
|
||||
|
||||
@property
|
||||
def to_tool_description_field(self) -> dict[str, Any]:
|
||||
"""
|
||||
Convert this feature to a JSON schema field for LLM tool calling.
|
||||
|
||||
Returns:
|
||||
dict: JSON schema field definition for this feature.
|
||||
"""
|
||||
return {
|
||||
'type': 'boolean',
|
||||
'description': self.description,
|
||||
}
|
||||
|
||||
|
||||
class EmbeddingDimension(BaseModel):
|
||||
"""
|
||||
Represents a single dimension (feature evaluation) within a feature embedding sample.
|
||||
|
||||
Each dimension corresponds to one feature being evaluated as true/false for a given issue.
|
||||
"""
|
||||
|
||||
feature_id: str
|
||||
"""Identifier of the feature being evaluated."""
|
||||
|
||||
result: bool
|
||||
"""Boolean result of the feature evaluation for this sample."""
|
||||
|
||||
|
||||
# Type alias for a single embedding sample - maps feature identifiers to boolean values
|
||||
EmbeddingSample = dict[str, bool]
|
||||
"""
|
||||
A single sample from the LLM evaluation of features for an issue.
|
||||
Maps feature identifiers to their boolean evaluations.
|
||||
"""
|
||||
|
||||
|
||||
class FeatureEmbedding(BaseModel):
|
||||
"""
|
||||
Represents the complete feature embedding for a single issue, including multiple samples
|
||||
and associated metadata about the LLM calls used to generate it.
|
||||
|
||||
Multiple samples are collected to account for LLM variability and provide more robust
|
||||
feature estimates through averaging.
|
||||
"""
|
||||
|
||||
samples: list[EmbeddingSample]
|
||||
"""List of individual feature evaluation samples from the LLM."""
|
||||
|
||||
prompt_tokens: int | None = None
|
||||
"""Total prompt tokens consumed across all LLM calls for this embedding."""
|
||||
|
||||
completion_tokens: int | None = None
|
||||
"""Total completion tokens generated across all LLM calls for this embedding."""
|
||||
|
||||
response_latency: float | None = None
|
||||
"""Total response latency (seconds) across all LLM calls for this embedding."""
|
||||
|
||||
@property
|
||||
def dimensions(self) -> list[str]:
|
||||
"""
|
||||
Get all unique feature identifiers present across all samples.
|
||||
|
||||
Returns:
|
||||
list[str]: List of feature identifiers that appear in at least one sample.
|
||||
"""
|
||||
dims: set[str] = set()
|
||||
for sample in self.samples:
|
||||
dims.update(sample.keys())
|
||||
return list(dims)
|
||||
|
||||
def coefficient(self, dimension: str) -> float | None:
|
||||
"""
|
||||
Calculate the average coefficient (0-1) for a specific feature dimension.
|
||||
|
||||
This computes the proportion of samples where the feature was evaluated as True,
|
||||
providing a continuous feature value for the classifier.
|
||||
|
||||
Args:
|
||||
dimension: Feature identifier to calculate coefficient for.
|
||||
|
||||
Returns:
|
||||
float | None: Average coefficient (0.0-1.0), or None if dimension not found.
|
||||
"""
|
||||
# Extract boolean values for this dimension, converting to 0/1
|
||||
values = [
|
||||
1 if v else 0
|
||||
for v in [sample.get(dimension) for sample in self.samples]
|
||||
if v is not None
|
||||
]
|
||||
if values:
|
||||
return sum(values) / len(values)
|
||||
return None
|
||||
|
||||
def to_row(self) -> dict[str, Any]:
|
||||
"""
|
||||
Convert the embedding to a flat dictionary suitable for DataFrame construction.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: Dictionary with metadata fields and feature coefficients.
|
||||
"""
|
||||
return {
|
||||
'response_latency': self.response_latency,
|
||||
'prompt_tokens': self.prompt_tokens,
|
||||
'completion_tokens': self.completion_tokens,
|
||||
**{dimension: self.coefficient(dimension) for dimension in self.dimensions},
|
||||
}
|
||||
|
||||
def sample_entropy(self) -> dict[str, float]:
|
||||
"""
|
||||
Calculate the Shannon entropy of feature evaluations across samples.
|
||||
|
||||
Higher entropy indicates more variability in LLM responses for a feature,
|
||||
which may suggest ambiguity in the feature definition or issue description.
|
||||
|
||||
Returns:
|
||||
dict[str, float]: Mapping of feature identifiers to their entropy values (0-1).
|
||||
"""
|
||||
from collections import Counter
|
||||
from math import log2
|
||||
|
||||
entropy = {}
|
||||
for dimension in self.dimensions:
|
||||
# Count True/False occurrences for this feature across samples
|
||||
counts = Counter(sample.get(dimension, False) for sample in self.samples)
|
||||
total = sum(counts.values())
|
||||
if total == 0:
|
||||
entropy[dimension] = 0.0
|
||||
continue
|
||||
# Calculate Shannon entropy: -Σ(p * log2(p))
|
||||
entropy_value = -sum(
|
||||
(count / total) * log2(count / total)
|
||||
for count in counts.values()
|
||||
if count > 0
|
||||
)
|
||||
entropy[dimension] = entropy_value
|
||||
return entropy
|
||||
|
||||
|
||||
class Featurizer(BaseModel):
|
||||
"""
|
||||
Orchestrates LLM-based feature extraction from issue descriptions.
|
||||
|
||||
The Featurizer uses structured LLM tool calling to evaluate boolean features
|
||||
for issue descriptions. It handles prompt construction, tool schema generation,
|
||||
and batch processing with concurrency.
|
||||
"""
|
||||
|
||||
system_prompt: str
|
||||
"""System prompt that provides context and instructions to the LLM."""
|
||||
|
||||
message_prefix: str
|
||||
"""Prefix added to user messages before the issue description."""
|
||||
|
||||
features: list[Feature]
|
||||
"""List of features to extract from each issue description."""
|
||||
|
||||
def system_message(self) -> dict[str, Any]:
|
||||
"""
|
||||
Construct the system message for LLM conversations.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: System message dictionary for LLM API calls.
|
||||
"""
|
||||
return {
|
||||
'role': 'system',
|
||||
'content': self.system_prompt,
|
||||
}
|
||||
|
||||
def user_message(
|
||||
self, issue_description: str, set_cache: bool = True
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Construct the user message containing the issue description.
|
||||
|
||||
Args:
|
||||
issue_description: The description of the issue to analyze.
|
||||
set_cache: Whether to enable ephemeral caching for this message.
|
||||
Should be False for single samples to avoid cache overhead.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: User message dictionary for LLM API calls.
|
||||
"""
|
||||
message: dict[str, Any] = {
|
||||
'role': 'user',
|
||||
'content': f'{self.message_prefix}{issue_description}',
|
||||
}
|
||||
if set_cache:
|
||||
message['cache_control'] = {'type': 'ephemeral'}
|
||||
return message
|
||||
|
||||
@property
|
||||
def tool_choice(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get the tool choice configuration for forcing LLM to use the featurizer tool.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: Tool choice configuration for LLM API calls.
|
||||
"""
|
||||
return {
|
||||
'type': 'function',
|
||||
'function': {'name': 'call_featurizer'},
|
||||
}
|
||||
|
||||
@property
|
||||
def tool_description(self) -> dict[str, Any]:
|
||||
"""
|
||||
Generate the tool schema for the featurizer function.
|
||||
|
||||
Creates a JSON schema that describes the featurizer tool with all configured
|
||||
features as boolean parameters.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: Complete tool description for LLM API calls.
|
||||
"""
|
||||
return {
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': 'call_featurizer',
|
||||
'description': 'Record the features present in the issue.',
|
||||
'parameters': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
feature.identifier: feature.to_tool_description_field
|
||||
for feature in self.features
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def embed(
|
||||
self,
|
||||
issue_description: str,
|
||||
llm_config: LLMConfig,
|
||||
temperature: float = 1.0,
|
||||
samples: int = 10,
|
||||
) -> FeatureEmbedding:
|
||||
"""
|
||||
Generate a feature embedding for a single issue description.
|
||||
|
||||
Makes multiple LLM calls to collect samples and reduce variance in feature evaluations.
|
||||
Each call uses tool calling to extract structured boolean feature values.
|
||||
|
||||
Args:
|
||||
issue_description: The description of the issue to analyze.
|
||||
llm_config: Configuration for the LLM to use.
|
||||
temperature: Sampling temperature for the model. Higher values increase randomness.
|
||||
samples: Number of samples to generate for averaging.
|
||||
|
||||
Returns:
|
||||
FeatureEmbedding: Complete embedding with samples and metadata.
|
||||
"""
|
||||
embedding_samples: list[dict[str, Any]] = []
|
||||
response_latency: float = 0.0
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
|
||||
# TODO: use llm registry
|
||||
llm = LLM(llm_config, service_id='solvability')
|
||||
|
||||
# Generate multiple samples to account for LLM variability
|
||||
for _ in range(samples):
|
||||
start_time = time.time()
|
||||
response = llm.completion(
|
||||
messages=[
|
||||
self.system_message(),
|
||||
self.user_message(issue_description, set_cache=(samples > 1)),
|
||||
],
|
||||
tools=[self.tool_description],
|
||||
tool_choice=self.tool_choice,
|
||||
temperature=temperature,
|
||||
)
|
||||
stop_time = time.time()
|
||||
|
||||
# Extract timing and token usage metrics
|
||||
latency = stop_time - start_time
|
||||
# Parse the structured tool call response containing feature evaluations
|
||||
features = response.choices[0].message.tool_calls[0].function.arguments # type: ignore[index, union-attr]
|
||||
embedding = json.loads(features)
|
||||
|
||||
# Accumulate results and metrics
|
||||
embedding_samples.append(embedding)
|
||||
prompt_tokens += response.usage.prompt_tokens # type: ignore[union-attr, attr-defined]
|
||||
completion_tokens += response.usage.completion_tokens # type: ignore[union-attr, attr-defined]
|
||||
response_latency += latency
|
||||
|
||||
return FeatureEmbedding(
|
||||
samples=embedding_samples,
|
||||
response_latency=response_latency,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
|
||||
def embed_batch(
|
||||
self,
|
||||
issue_descriptions: list[str],
|
||||
llm_config: LLMConfig,
|
||||
temperature: float = 1.0,
|
||||
samples: int = 10,
|
||||
) -> list[FeatureEmbedding]:
|
||||
"""
|
||||
Generate embeddings for a batch of issue descriptions using concurrent processing.
|
||||
|
||||
Processes multiple issues in parallel to improve throughput while maintaining
|
||||
result ordering.
|
||||
|
||||
Args:
|
||||
issue_descriptions: List of issue descriptions to analyze.
|
||||
llm_config: Configuration for the LLM to use.
|
||||
temperature: Sampling temperature for the model.
|
||||
samples: Number of samples to generate per issue.
|
||||
|
||||
Returns:
|
||||
list[FeatureEmbedding]: List of embeddings in the same order as input.
|
||||
"""
|
||||
with ThreadPoolExecutor() as executor:
|
||||
# Submit all embedding tasks concurrently
|
||||
future_to_desc = {
|
||||
executor.submit(
|
||||
self.embed,
|
||||
desc,
|
||||
llm_config,
|
||||
temperature=temperature,
|
||||
samples=samples,
|
||||
): i
|
||||
for i, desc in enumerate(issue_descriptions)
|
||||
}
|
||||
|
||||
# Collect results in original order to maintain consistency
|
||||
results: list[FeatureEmbedding] = [None] * len(issue_descriptions) # type: ignore[list-item]
|
||||
for future in as_completed(future_to_desc):
|
||||
index = future_to_desc[future]
|
||||
results[index] = future.result()
|
||||
|
||||
return results
|
||||
|
||||
def feature_identifiers(self) -> list[str]:
|
||||
"""
|
||||
Get the identifiers of all configured features.
|
||||
|
||||
Returns:
|
||||
list[str]: List of feature identifiers in the order they were defined.
|
||||
"""
|
||||
return [feature.identifier for feature in self.features]
|
||||
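To illustrate how repeated samples become the coefficients consumed by the classifier, a small sketch that builds a FeatureEmbedding directly; the sample values are invented and no LLM calls are involved.

emb = FeatureEmbedding(
    samples=[
        {'has_stack_trace': True},
        {'has_stack_trace': True},
        {'has_stack_trace': False},
    ],
    prompt_tokens=1200,
    completion_tokens=90,
    response_latency=2.4,
)
print(emb.coefficient('has_stack_trace'))  # 0.666... -> proportion of samples answering True
print(emb.sample_entropy())                # {'has_stack_trace': 0.918...} Shannon entropy in bits
print(emb.to_row())                        # metadata columns plus one column per feature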
23
enterprise/integrations/solvability/models/importance_strategy.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ImportanceStrategy(str, Enum):
|
||||
"""
|
||||
Strategy to use for calculating feature importances, which are used to estimate the predictive power of each feature
|
||||
in training loops and explanations.
|
||||
"""
|
||||
|
||||
SHAP = 'shap'
|
||||
"""
|
||||
Use SHAP (SHapley Additive exPlanations) to calculate feature importances.
|
||||
"""
|
||||
|
||||
PERMUTATION = 'permutation'
|
||||
"""
|
||||
Use the permutation-based feature importances.
|
||||
"""
|
||||
|
||||
IMPURITY = 'impurity'
|
||||
"""
|
||||
Use the impurity-based feature importances from the RandomForestClassifier.
|
||||
"""
|
||||
87
enterprise/integrations/solvability/models/report.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from integrations.solvability.models.importance_strategy import ImportanceStrategy
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SolvabilityReport(BaseModel):
|
||||
"""
|
||||
Comprehensive report containing solvability predictions and analysis for a single issue.
|
||||
|
||||
This report includes the solvability score, extracted feature values, feature importance analysis,
|
||||
cost metrics (tokens and latency), and metadata about the prediction process. It serves as the
|
||||
primary output format for solvability analysis and can be used for logging, debugging, and
|
||||
generating human-readable summaries.
|
||||
"""
|
||||
|
||||
identifier: str
|
||||
"""
|
||||
The identifier of the solvability model used to generate the report.
|
||||
"""
|
||||
|
||||
issue: str
|
||||
"""
|
||||
The issue description for which the solvability is predicted.
|
||||
|
||||
This field is exactly the input to the solvability model.
|
||||
"""
|
||||
|
||||
score: float
|
||||
"""
|
||||
[0, 1]-valued score indicating the likelihood of the issue being solvable.
|
||||
"""
|
||||
|
||||
prompt_tokens: int
|
||||
"""
|
||||
Total number of prompt tokens used in API calls made to generate the features.
|
||||
"""
|
||||
|
||||
completion_tokens: int
|
||||
"""
|
||||
Total number of completion tokens used in API calls made to generate the features.
|
||||
"""
|
||||
|
||||
response_latency: float
|
||||
"""
|
||||
Total response latency of API calls made to generate the features.
|
||||
"""
|
||||
|
||||
features: dict[str, float]
|
||||
"""
|
||||
[0, 1]-valued scores for each feature in the model.
|
||||
|
||||
These are the values fed to the random forest classifier to generate the solvability score.
|
||||
"""
|
||||
|
||||
samples: int
|
||||
"""
|
||||
Number of samples used to compute the feature embedding coefficients.
|
||||
"""
|
||||
|
||||
importance_strategy: ImportanceStrategy
|
||||
"""
|
||||
Strategy used to calculate feature importances.
|
||||
"""
|
||||
|
||||
feature_importances: dict[str, float]
|
||||
"""
|
||||
Importance scores for each feature in the model.
|
||||
|
||||
Interpretation of these scores depends on the importance strategy used.
|
||||
"""
|
||||
|
||||
created_at: datetime = Field(default_factory=datetime.now)
|
||||
"""
|
||||
Datetime when the report was created.
|
||||
"""
|
||||
|
||||
random_state: int | None = None
|
||||
"""
|
||||
Classifier random state used when generating this report.
|
||||
"""
|
||||
|
||||
metadata: dict[str, Any] | None = None
|
||||
"""
|
||||
Metadata for logging and debugging purposes.
|
||||
"""
|
||||
172
enterprise/integrations/solvability/models/summary.py
Normal file
@@ -0,0 +1,172 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from integrations.solvability.models.difficulty_level import DifficultyLevel
|
||||
from integrations.solvability.models.report import SolvabilityReport
|
||||
from integrations.solvability.prompts import load_prompt
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from openhands.llm import LLM
|
||||
|
||||
|
||||
class SolvabilitySummary(BaseModel):
|
||||
"""Summary of the solvability analysis in human-readable format."""
|
||||
|
||||
score: float
|
||||
"""
|
||||
Solvability score indicating the likelihood of the issue being solvable.
|
||||
"""
|
||||
|
||||
summary: str
|
||||
"""
|
||||
The executive summary content generated by the LLM.
|
||||
"""
|
||||
|
||||
actionable_feedback: str
|
||||
"""
|
||||
Actionable feedback content generated by the LLM.
|
||||
"""
|
||||
|
||||
positive_feedback: str
|
||||
"""
|
||||
Positive feedback content generated by the LLM, highlighting what is good about the issue.
|
||||
"""
|
||||
|
||||
prompt_tokens: int
|
||||
"""
|
||||
Number of prompt tokens used in the API call to generate the summary.
|
||||
"""
|
||||
|
||||
completion_tokens: int
|
||||
"""
|
||||
Number of completion tokens used in the API call to generate the summary.
|
||||
"""
|
||||
|
||||
response_latency: float
|
||||
"""
|
||||
Response latency of the API call to generate the summary.
|
||||
"""
|
||||
|
||||
created_at: datetime = Field(default_factory=datetime.now)
|
||||
"""
|
||||
Datetime when the summary was created.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def tool_description() -> dict[str, Any]:
|
||||
"""Get the tool description for the LLM."""
|
||||
return {
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': 'solvability_summary',
|
||||
'description': 'Generate a human-readable summary of the solvability analysis.',
|
||||
'parameters': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'summary': {
|
||||
'type': 'string',
|
||||
'description': 'A high-level (at most two sentences) summary of the solvability report.',
|
||||
},
|
||||
'actionable_feedback': {
|
||||
'type': 'string',
|
||||
'description': (
|
||||
'Bullet list of 1-3 pieces of actionable feedback on how the user can address the lowest scoring relevant features.'
|
||||
),
|
||||
},
|
||||
'positive_feedback': {
|
||||
'type': 'string',
|
||||
'description': (
|
||||
'Bullet list of 1-3 pieces of positive feedback on the issue, highlighting what is good about it.'
|
||||
),
|
||||
},
|
||||
},
|
||||
'required': ['summary', 'actionable_feedback'],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def tool_choice() -> dict[str, Any]:
|
||||
"""Get the tool choice for the LLM."""
|
||||
return {
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': 'solvability_summary',
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def system_message() -> dict[str, Any]:
|
||||
"""Get the system message for the LLM."""
|
||||
return {
|
||||
'role': 'system',
|
||||
'content': load_prompt('summary_system_message'),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def user_message(report: SolvabilityReport) -> dict[str, Any]:
|
||||
"""Get the user message for the LLM."""
|
||||
return {
|
||||
'role': 'user',
|
||||
'content': load_prompt(
|
||||
'summary_user_message',
|
||||
report=report.model_dump(),
|
||||
difficulty_level=DifficultyLevel.from_score(report.score).value[0],
|
||||
),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_report(report: SolvabilityReport, llm: LLM) -> SolvabilitySummary:
|
||||
"""Create a SolvabilitySummary from a SolvabilityReport."""
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
response = llm.completion(
|
||||
messages=[
|
||||
SolvabilitySummary.system_message(),
|
||||
SolvabilitySummary.user_message(report),
|
||||
],
|
||||
tools=[SolvabilitySummary.tool_description()],
|
||||
tool_choice=SolvabilitySummary.tool_choice(),
|
||||
)
|
||||
response_latency = time.time() - start_time
|
||||
|
||||
# Grab the arguments from the forced function call
|
||||
arguments = json.loads(
|
||||
response.choices[0].message.tool_calls[0].function.arguments
|
||||
)
|
||||
|
||||
return SolvabilitySummary(
|
||||
# The score is copied directly from the report
|
||||
score=report.score,
|
||||
# Performance and usage metrics are pulled from the response
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
completion_tokens=response.usage.completion_tokens,
|
||||
response_latency=response_latency,
|
||||
# Every other field should be taken from the forced function call
|
||||
**arguments,
|
||||
)
|
||||
|
||||
def format_as_markdown(self) -> str:
|
||||
"""Format the summary content as Markdown."""
|
||||
# Convert score to difficulty level enum
|
||||
difficulty_level = DifficultyLevel.from_score(self.score)
|
||||
|
||||
# Create the main difficulty display
|
||||
result = f'{difficulty_level.format_display()}\n\n{self.summary}'
|
||||
|
||||
# If not easy, show the three features with lowest importance scores
|
||||
if difficulty_level != DifficultyLevel.EASY:
|
||||
# Add dropdown with lowest importance features
|
||||
result += '\n\nYou can make the issue easier to resolve by addressing these concerns in the conversation:\n\n'
|
||||
result += self.actionable_feedback
|
||||
|
||||
# If the difficulty isn't hard, add some positive feedback
|
||||
if difficulty_level != DifficultyLevel.HARD:
|
||||
result += '\n\nPositive feedback:\n\n'
|
||||
result += self.positive_feedback
|
||||
|
||||
return result
|
||||
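A short usage sketch tying the report and summary together; report and llm are assumed to be an existing SolvabilityReport and an openhands LLM instance.

summary = SolvabilitySummary.from_report(report, llm)  # one forced tool call to the LLM
print(summary.score)                  # copied straight from the report
print(summary.format_as_markdown())   # difficulty banner plus actionable/positive feedback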
13
enterprise/integrations/solvability/prompts/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from pathlib import Path
|
||||
|
||||
import jinja2
|
||||
|
||||
|
||||
def load_prompt(prompt: str, **kwargs) -> str:
|
||||
"""Load a prompt by name. Passes all the keyword arguments to the prompt template."""
|
||||
env = jinja2.Environment(loader=jinja2.FileSystemLoader(Path(__file__).parent))
|
||||
template = env.get_template(f'{prompt}.j2')
|
||||
return template.render(**kwargs)
|
||||
|
||||
|
||||
__all__ = ['load_prompt']
|
||||
10
enterprise/integrations/solvability/prompts/summary_system_message.j2
Normal file
@@ -0,0 +1,10 @@
You are a helpful assistant that generates human-readable summaries of solvability reports.
The report predicts how likely it is that the issue can be resolved, and is produced purely based on the information provided in the issue description and comments.
The report explains which features are present in the issue and how impactful they are to the solvability score (using SHAP values).
Your task is to create a concise, high-level summary of the solvability analysis,
with an emphasis on the key factors that make the issue easy or hard to resolve.
Focus on the features with extreme scores, BUT ONLY if they are related to the issue at hand after careful consideration.
You should NEVER mention: SHAP, scores, feature names, or technical metrics.
You will also be given the expected difficulty of the issue, as EASY/MEDIUM/HARD.
Be sure to frame your responses with that difficulty in mind.
For example, if the issue is HARD you should not describe it as "straightforward".
@@ -0,0 +1,9 @@
Generate a high-level summary of the solvability report:

{{ report }}

We estimate the issue is {{ difficulty_level }}.
The summary should be concise (at most two sentences) and describe the primary characteristics of this issue.
Focus on what information is present and what factors are most relevant to resolution.
Actionable feedback should be something that can be addressed by the user purely by providing more information.
Positive feedback should explain the features that are positively contributing to the solvability score.
0
enterprise/integrations/solvability/py.typed
Normal file
73
enterprise/integrations/stripe_service.py
Normal file
@@ -0,0 +1,73 @@
import stripe
from server.auth.token_manager import TokenManager
from server.constants import STRIPE_API_KEY
from server.logger import logger
from storage.database import session_maker
from storage.stripe_customer import StripeCustomer

stripe.api_key = STRIPE_API_KEY


async def find_customer_id_by_user_id(user_id: str) -> str | None:
    # First search our own DB...
    with session_maker() as session:
        stripe_customer = (
            session.query(StripeCustomer)
            .filter(StripeCustomer.keycloak_user_id == user_id)
            .first()
        )
        if stripe_customer:
            return stripe_customer.stripe_customer_id

    # If that fails, fallback to stripe
    search_result = await stripe.Customer.search_async(
        query=f"metadata['user_id']:'{user_id}'",
    )
    data = search_result.data
    if not data:
        logger.info('no_customer_for_user_id', extra={'user_id': user_id})
        return None
    return data[0].id  # type: ignore [attr-defined]


async def find_or_create_customer(user_id: str) -> str:
    customer_id = await find_customer_id_by_user_id(user_id)
    if customer_id:
        return customer_id
    logger.info('creating_customer', extra={'user_id': user_id})

    # Get the user info from keycloak
    token_manager = TokenManager()
    user_info = await token_manager.get_user_info_from_user_id(user_id) or {}

    # Create the customer in stripe
    customer = await stripe.Customer.create_async(
        email=str(user_info.get('email', '')),
        metadata={'user_id': user_id},
    )

    # Save the stripe customer in the local db
    with session_maker() as session:
        session.add(
            StripeCustomer(keycloak_user_id=user_id, stripe_customer_id=customer.id)
        )
        session.commit()

    logger.info(
        'created_customer',
        extra={'user_id': user_id, 'stripe_customer_id': customer.id},
    )
    return customer.id


async def has_payment_method(user_id: str) -> bool:
    customer_id = await find_customer_id_by_user_id(user_id)
    if customer_id is None:
        return False
    payment_methods = await stripe.Customer.list_payment_methods_async(
        customer_id,
    )
    logger.info(
        f'has_payment_method:{user_id}:{customer_id}:{bool(payment_methods.data)}'
    )
    return bool(payment_methods.data)
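A minimal sketch of how these helpers could be wired together from an async caller; the import path and the user id below are assumptions for illustration only:

import asyncio

from integrations import stripe_service  # import path assumed from the file location


async def ensure_billing_ready(user_id: str) -> bool:
    # Make sure a Stripe customer exists, then check whether a payment method is on file.
    await stripe_service.find_or_create_customer(user_id)
    return await stripe_service.has_payment_method(user_id)


if __name__ == '__main__':
    print(asyncio.run(ensure_billing_ready('keycloak-user-123')))  # hypothetical user id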
51
enterprise/integrations/types.py
Normal file
@@ -0,0 +1,51 @@
from dataclasses import dataclass
from enum import Enum

from jinja2 import Environment
from pydantic import BaseModel


class GitLabResourceType(Enum):
    GROUP = 'group'
    SUBGROUP = 'subgroup'
    PROJECT = 'project'


class PRStatus(Enum):
    CLOSED = 'CLOSED'
    MERGED = 'MERGED'


class UserData(BaseModel):
    user_id: int
    username: str
    keycloak_user_id: str | None


@dataclass
class SummaryExtractionTracker:
    conversation_id: str
    should_extract: bool
    send_summary_instruction: bool


@dataclass
class ResolverViewInterface(SummaryExtractionTracker):
    installation_id: int
    user_info: UserData
    issue_number: int
    full_repo_name: str
    is_public_repo: bool
    raw_payload: dict

    def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        "Instructions passed when conversation is first initialized"
        raise NotImplementedError()

    async def create_new_conversation(self, jinja_env: Environment, token: str):
        "Create a new conversation"
        raise NotImplementedError()

    def get_callback_id(self) -> str:
        "Unique callback id for subscription made to EventStream for fetching agent summary"
        raise NotImplementedError()
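A concrete resolver view would subclass this interface and implement the three hooks; a minimal sketch, with a hypothetical class name and template file, might look like:

@dataclass
class IssueCommentView(ResolverViewInterface):  # hypothetical example, not part of this change
    def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        # Render the initial instructions for the agent from a template file (name assumed).
        template = jinja_env.get_template('issue_instructions.j2')
        rendered = template.render(
            issue_number=self.issue_number, repo=self.full_repo_name
        )
        return rendered, rendered

    async def create_new_conversation(self, jinja_env: Environment, token: str):
        instructions, _ = self._get_instructions(jinja_env)
        # Hand the instructions and the provider token to the conversation service here.
        ...

    def get_callback_id(self) -> str:
        return f'resolver-{self.full_repo_name}-{self.issue_number}'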
546
enterprise/integrations/utils.py
Normal file
@@ -0,0 +1,546 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from server.constants import WEB_HOST
|
||||
from storage.repository_store import RepositoryStore
|
||||
from storage.stored_repository import StoredRepository
|
||||
from storage.user_repo_map import UserRepositoryMap
|
||||
from storage.user_repo_map_store import UserRepositoryMapStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events import Event, EventSource
|
||||
from openhands.events.action import (
|
||||
AgentFinishAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.event_store_abc import EventStoreABC
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
from openhands.integrations.service_types import Repository
|
||||
from openhands.storage.data_models.conversation_status import ConversationStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openhands.server.conversation_manager.conversation_manager import (
|
||||
ConversationManager,
|
||||
)
|
||||
|
||||
# ---- DO NOT REMOVE ----
|
||||
# WARNING: Langfuse depends on the WEB_HOST environment variable being set to track events.
|
||||
HOST = WEB_HOST
|
||||
# ---- DO NOT REMOVE ----
|
||||
|
||||
HOST_URL = f'https://{HOST}'
|
||||
GITHUB_WEBHOOK_URL = f'{HOST_URL}/integration/github/events'
|
||||
GITLAB_WEBHOOK_URL = f'{HOST_URL}/integration/gitlab/events'
|
||||
conversation_prefix = 'conversations/{}'
|
||||
CONVERSATION_URL = f'{HOST_URL}/{conversation_prefix}'
|
||||
|
||||
# Toggle for auto-response feature that proactively starts conversations with users when workflow tests fail
|
||||
ENABLE_PROACTIVE_CONVERSATION_STARTERS = (
|
||||
os.getenv('ENABLE_PROACTIVE_CONVERSATION_STARTERS', 'false').lower() == 'true'
|
||||
)
|
||||
|
||||
# Toggle for solvability report feature
|
||||
ENABLE_SOLVABILITY_ANALYSIS = (
|
||||
os.getenv('ENABLE_SOLVABILITY_ANALYSIS', 'false').lower() == 'true'
|
||||
)
|
||||
|
||||
|
||||
OPENHANDS_RESOLVER_TEMPLATES_DIR = 'openhands/integrations/templates/resolver/'
|
||||
jinja_env = Environment(loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR))
|
||||
|
||||
|
||||
def get_oh_labels(web_host: str) -> tuple[str, str]:
|
||||
"""Get the OpenHands labels based on the web host.
|
||||
|
||||
Args:
|
||||
web_host: The web host string to check
|
||||
|
||||
Returns:
|
||||
A tuple of (oh_label, inline_oh_label) where:
|
||||
- oh_label is 'openhands-exp' for staging/local hosts, 'openhands' otherwise
|
||||
- inline_oh_label is '@openhands-exp' for staging/local hosts, '@openhands' otherwise
|
||||
"""
|
||||
web_host = web_host.strip()
|
||||
is_staging_or_local = 'staging' in web_host or 'local' in web_host
|
||||
oh_label = 'openhands-exp' if is_staging_or_local else 'openhands'
|
||||
inline_oh_label = '@openhands-exp' if is_staging_or_local else '@openhands'
|
||||
return oh_label, inline_oh_label
|
||||
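# For example (hostnames are illustrative): any host containing 'staging' or 'local'
# gets the experimental label pair, every other host keeps the production labels:
#   get_oh_labels('app.staging.example.com') -> ('openhands-exp', '@openhands-exp')
#   get_oh_labels('app.all-hands.dev')       -> ('openhands', '@openhands')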
|
||||
|
||||
def get_summary_instruction():
|
||||
summary_instruction_template = jinja_env.get_template('summary_prompt.j2')
|
||||
summary_instruction = summary_instruction_template.render()
|
||||
return summary_instruction
|
||||
|
||||
|
||||
def has_exact_mention(text: str, mention: str) -> bool:
|
||||
"""Check if the text contains an exact mention (not part of a larger word).
|
||||
|
||||
Args:
|
||||
text: The text to check for mentions
|
||||
mention: The mention to look for (e.g. "@openhands")
|
||||
|
||||
Returns:
|
||||
bool: True if the exact mention is found, False otherwise
|
||||
|
||||
Example:
|
||||
>>> has_exact_mention("Hello @openhands!", "@openhands") # True
|
||||
>>> has_exact_mention("Hello @openhands-agent!", "@openhands") # False
|
||||
>>> has_exact_mention("(@openhands)", "@openhands") # True
|
||||
>>> has_exact_mention("user@openhands.com", "@openhands") # False
|
||||
>>> has_exact_mention("Hello @OpenHands!", "@openhands") # True (case-insensitive)
|
||||
"""
|
||||
# Convert both text and mention to lowercase for case-insensitive matching
|
||||
text_lower = text.lower()
|
||||
mention_lower = mention.lower()
|
||||
|
||||
pattern = re.escape(mention_lower)
|
||||
# Match mention that is not part of a larger word
|
||||
return bool(re.search(rf'(?:^|[^\w@]){pattern}(?![\w-])', text_lower))
|
||||
|
||||
|
||||
def confirm_event_type(event: Event):
|
||||
return isinstance(event, AgentStateChangedObservation) and not (
|
||||
event.agent_state == AgentState.REJECTED
|
||||
or event.agent_state == AgentState.USER_CONFIRMED
|
||||
or event.agent_state == AgentState.USER_REJECTED
|
||||
or event.agent_state == AgentState.LOADING
|
||||
or event.agent_state == AgentState.RUNNING
|
||||
)
|
||||
|
||||
|
||||
def get_readable_error_reason(reason: str):
|
||||
if reason == 'STATUS$ERROR_LLM_AUTHENTICATION':
|
||||
reason = 'Authentication with the LLM provider failed. Please check your API key or credentials'
|
||||
elif reason == 'STATUS$ERROR_LLM_SERVICE_UNAVAILABLE':
|
||||
reason = 'The LLM service is temporarily unavailable. Please try again later'
|
||||
elif reason == 'STATUS$ERROR_LLM_INTERNAL_SERVER_ERROR':
|
||||
reason = 'The LLM provider encountered an internal error. Please try again soon'
|
||||
elif reason == 'STATUS$ERROR_LLM_OUT_OF_CREDITS':
|
||||
reason = "You've run out of credits. Please top up to continue"
|
||||
elif reason == 'STATUS$ERROR_LLM_CONTENT_POLICY_VIOLATION':
|
||||
reason = 'Content policy violation. The output was blocked by content filtering policy'
|
||||
return reason
|
||||
|
||||
|
||||
def get_summary_for_agent_state(
|
||||
observations: list[AgentStateChangedObservation], conversation_link: str
|
||||
) -> str:
|
||||
unknown_error_msg = f'OpenHands encountered an unknown error. [See the conversation]({conversation_link}) for more information, or try again'
|
||||
|
||||
if len(observations) == 0:
|
||||
logger.error(
|
||||
'Unknown error: No agent state observations found',
|
||||
extra={'conversation_link': conversation_link},
|
||||
)
|
||||
return unknown_error_msg
|
||||
|
||||
observation: AgentStateChangedObservation = observations[0]
|
||||
state = observation.agent_state
|
||||
|
||||
if state == AgentState.RATE_LIMITED:
|
||||
logger.warning(
|
||||
'Agent was rate limited',
|
||||
extra={
|
||||
'agent_state': state.value,
|
||||
'conversation_link': conversation_link,
|
||||
'observation_reason': getattr(observation, 'reason', None),
|
||||
},
|
||||
)
|
||||
return 'OpenHands was rate limited by the LLM provider. Please try again later.'
|
||||
|
||||
if state == AgentState.ERROR:
|
||||
reason = observation.reason
|
||||
reason = get_readable_error_reason(reason)
|
||||
|
||||
logger.error(
|
||||
'Agent encountered an error',
|
||||
extra={
|
||||
'agent_state': state.value,
|
||||
'conversation_link': conversation_link,
|
||||
'observation_reason': observation.reason,
|
||||
'readable_reason': reason,
|
||||
},
|
||||
)
|
||||
|
||||
return f'OpenHands encountered an error: **{reason}**.\n\n[See the conversation]({conversation_link}) for more information.'
|
||||
|
||||
# Log unknown agent state as error
|
||||
logger.error(
|
||||
'Unknown error: Unhandled agent state',
|
||||
extra={
|
||||
'agent_state': state.value if hasattr(state, 'value') else str(state),
|
||||
'conversation_link': conversation_link,
|
||||
'observation_reason': getattr(observation, 'reason', None),
|
||||
},
|
||||
)
|
||||
return unknown_error_msg
|
||||
|
||||
|
||||
def get_final_agent_observation(
|
||||
event_store: EventStoreABC,
|
||||
) -> list[AgentStateChangedObservation]:
|
||||
return event_store.get_matching_events(
|
||||
source=EventSource.ENVIRONMENT,
|
||||
event_types=(AgentStateChangedObservation,),
|
||||
limit=1,
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
|
||||
def get_last_user_msg(event_store: EventStoreABC) -> list[MessageAction]:
|
||||
return event_store.get_matching_events(
|
||||
source=EventSource.USER, event_types=(MessageAction,), limit=1, reverse=True
|
||||
)
|
||||
|
||||
|
||||
def extract_summary_from_event_store(
|
||||
event_store: EventStoreABC, conversation_id: str
|
||||
) -> str:
|
||||
"""
|
||||
Get agent summary or alternative message depending on current AgentState
|
||||
"""
|
||||
conversation_link = CONVERSATION_URL.format(conversation_id)
|
||||
summary_instruction = get_summary_instruction()
|
||||
|
||||
instruction_event: list[MessageAction] = event_store.get_matching_events(
|
||||
query=json.dumps(summary_instruction),
|
||||
source=EventSource.USER,
|
||||
event_types=(MessageAction,),
|
||||
limit=1,
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
final_agent_observation = get_final_agent_observation(event_store)
|
||||
|
||||
# Find summary instruction event ID
|
||||
if len(instruction_event) == 0:
|
||||
logger.warning(
|
||||
'no_instruction_event_found', extra={'conversation_id': conversation_id}
|
||||
)
|
||||
return get_summary_for_agent_state(
|
||||
final_agent_observation, conversation_link
|
||||
) # Agent did not receive summary instruction
|
||||
|
||||
event_id: int = instruction_event[0].id
|
||||
|
||||
agent_messages: list[MessageAction | AgentFinishAction] = (
|
||||
event_store.get_matching_events(
|
||||
start_id=event_id,
|
||||
source=EventSource.AGENT,
|
||||
event_types=(MessageAction, AgentFinishAction),
|
||||
reverse=True,
|
||||
limit=1,
|
||||
)
|
||||
)
|
||||
|
||||
if len(agent_messages) == 0:
|
||||
logger.warning(
|
||||
'no_agent_messages_found', extra={'conversation_id': conversation_id}
|
||||
)
|
||||
return get_summary_for_agent_state(
|
||||
final_agent_observation, conversation_link
|
||||
) # Agent failed to generate summary
|
||||
|
||||
summary_event: MessageAction | AgentFinishAction = agent_messages[0]
|
||||
if isinstance(summary_event, MessageAction):
|
||||
return summary_event.content
|
||||
|
||||
return summary_event.final_thought
|
||||
|
||||
|
||||
async def get_event_store_from_conversation_manager(
|
||||
conversation_manager: ConversationManager, conversation_id: str
|
||||
) -> EventStoreABC:
|
||||
agent_loop_infos = await conversation_manager.get_agent_loop_info(
|
||||
filter_to_sids={conversation_id}
|
||||
)
|
||||
if not agent_loop_infos or agent_loop_infos[0].status != ConversationStatus.RUNNING:
|
||||
raise RuntimeError(f'conversation_not_running:{conversation_id}')
|
||||
event_store = agent_loop_infos[0].event_store
|
||||
if not event_store:
|
||||
raise RuntimeError(f'event_store_missing:{conversation_id}')
|
||||
return event_store
|
||||
|
||||
|
||||
async def get_last_user_msg_from_conversation_manager(
|
||||
conversation_manager: ConversationManager, conversation_id: str
|
||||
):
|
||||
event_store = await get_event_store_from_conversation_manager(
|
||||
conversation_manager, conversation_id
|
||||
)
|
||||
return get_last_user_msg(event_store)
|
||||
|
||||
|
||||
async def extract_summary_from_conversation_manager(
|
||||
conversation_manager: ConversationManager, conversation_id: str
|
||||
) -> str:
|
||||
"""
|
||||
Get agent summary or alternative message depending on current AgentState
|
||||
"""
|
||||
|
||||
event_store = await get_event_store_from_conversation_manager(
|
||||
conversation_manager, conversation_id
|
||||
)
|
||||
summary = extract_summary_from_event_store(event_store, conversation_id)
|
||||
return append_conversation_footer(summary, conversation_id)
|
||||
|
||||
|
||||
def append_conversation_footer(message: str, conversation_id: str) -> str:
|
||||
"""
|
||||
Append a small footer with the conversation URL to a message.
|
||||
|
||||
Args:
|
||||
message: The original message content
|
||||
conversation_id: The conversation ID to link to
|
||||
|
||||
Returns:
|
||||
The message with the conversation footer appended
|
||||
"""
|
||||
conversation_link = CONVERSATION_URL.format(conversation_id)
|
||||
footer = f'\n\n<sub>[View full conversation]({conversation_link})</sub>'
|
||||
return message + footer
|
||||
|
||||
|
||||
async def store_repositories_in_db(repos: list[Repository], user_id: str) -> None:
|
||||
"""
|
||||
Store repositories in DB and create user-repository mappings
|
||||
|
||||
Args:
|
||||
repos: List of Repository objects to store
|
||||
user_id: User ID associated with these repositories
|
||||
"""
|
||||
|
||||
# Convert Repository objects to StoredRepository objects
|
||||
# Convert Repository objects to UserRepositoryMap objects
|
||||
stored_repos = []
|
||||
user_repos = []
|
||||
for repo in repos:
|
||||
repo_id = f'{repo.git_provider.value}##{str(repo.id)}'
|
||||
stored_repo = StoredRepository(
|
||||
repo_name=repo.full_name,
|
||||
repo_id=repo_id,
|
||||
is_public=repo.is_public,
|
||||
# Optional fields set to None by default
|
||||
has_microagent=None,
|
||||
has_setup_script=None,
|
||||
)
|
||||
stored_repos.append(stored_repo)
|
||||
user_repo_map = UserRepositoryMap(user_id=user_id, repo_id=repo_id, admin=None)
|
||||
|
||||
user_repos.append(user_repo_map)
|
||||
|
||||
# Get config instance
|
||||
config = OpenHandsConfig()
|
||||
|
||||
try:
|
||||
# Store repositories in the repos table
|
||||
repo_store = RepositoryStore.get_instance(config)
|
||||
repo_store.store_projects(stored_repos)
|
||||
|
||||
# Store user-repository mappings in the user-repos table
|
||||
user_repo_store = UserRepositoryMapStore.get_instance(config)
|
||||
user_repo_store.store_user_repo_mappings(user_repos)
|
||||
|
||||
logger.info(f'Saved repos for user {user_id}')
|
||||
except Exception:
|
||||
logger.warning('Failed to save repos', exc_info=True)
|
||||
|
||||
|
||||
def infer_repo_from_message(user_msg: str) -> list[str]:
|
||||
"""
|
||||
Extract all repository names in the format 'owner/repo' from various Git provider URLs
|
||||
and direct mentions in text. Supports GitHub, GitLab, and BitBucket.
|
||||
Args:
|
||||
user_msg: Input message that may contain repository references
|
||||
Returns:
|
||||
List of repository names in 'owner/repo' format, empty list if none found
|
||||
"""
|
||||
# Normalize the message by removing extra whitespace and newlines
|
||||
normalized_msg = re.sub(r'\s+', ' ', user_msg.strip())
|
||||
|
||||
# Pattern to match Git URLs from GitHub, GitLab, and BitBucket
|
||||
# Captures: protocol, domain, owner, repo (with optional .git extension)
|
||||
git_url_pattern = r'https?://(?:github\.com|gitlab\.com|bitbucket\.org)/([a-zA-Z0-9_.-]+)/([a-zA-Z0-9_.-]+?)(?:\.git)?(?:[/?#].*?)?(?=\s|$|[^\w.-])'
|
||||
|
||||
# Pattern to match direct owner/repo mentions (e.g., "All-Hands-AI/OpenHands")
|
||||
# Must be surrounded by word boundaries or specific characters to avoid false positives
|
||||
direct_pattern = (
|
||||
r'(?:^|\s|[\[\(\'"])([a-zA-Z0-9_.-]+)/([a-zA-Z0-9_.-]+)(?=\s|$|[\]\)\'",.])'
|
||||
)
|
||||
|
||||
matches = []
|
||||
|
||||
# First, find all Git URLs (highest priority)
|
||||
git_matches = re.findall(git_url_pattern, normalized_msg)
|
||||
for owner, repo in git_matches:
|
||||
# Remove .git extension if present
|
||||
repo = re.sub(r'\.git$', '', repo)
|
||||
matches.append(f'{owner}/{repo}')
|
||||
|
||||
# Second, find all direct owner/repo mentions
|
||||
direct_matches = re.findall(direct_pattern, normalized_msg)
|
||||
for owner, repo in direct_matches:
|
||||
full_match = f'{owner}/{repo}'
|
||||
|
||||
# Skip if it looks like a version number, date, or file path
|
||||
if (
|
||||
re.match(r'^\d+\.\d+/\d+\.\d+$', full_match) # version numbers
|
||||
or re.match(r'^\d{1,2}/\d{1,2}$', full_match) # dates
|
||||
or re.match(r'^[A-Z]/[A-Z]$', full_match) # single letters
|
||||
or repo.endswith('.txt')
|
||||
or repo.endswith('.md') # file extensions
|
||||
or repo.endswith('.py')
|
||||
or repo.endswith('.js')
|
||||
or '.' in repo
|
||||
and len(repo.split('.')) > 2
|
||||
): # complex file paths
|
||||
continue
|
||||
|
||||
# Avoid duplicates from Git URLs already found
|
||||
if full_match not in matches:
|
||||
matches.append(full_match)
|
||||
|
||||
return matches
|
||||
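# A quick illustration of the extraction rules above (hypothetical messages; the
# expected results follow from the URL/direct patterns and the false-positive filters):
#   infer_repo_from_message('Fix https://github.com/All-Hands-AI/OpenHands/issues/123')
#       -> ['All-Hands-AI/OpenHands']
#   infer_repo_from_message('Compare All-Hands-AI/OpenHands with docs/readme.md')
#       -> ['All-Hands-AI/OpenHands']   # 'docs/readme.md' is dropped as a file path
#   infer_repo_from_message('released 1.2/3.4 today')
#       -> []                           # version-like strings are filtered out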
|
||||
|
||||
def filter_potential_repos_by_user_msg(
|
||||
user_msg: str, user_repos: list[Repository]
|
||||
) -> tuple[bool, list[Repository]]:
|
||||
"""Filter repositories based on user message inference."""
|
||||
inferred_repos = infer_repo_from_message(user_msg)
|
||||
if not inferred_repos:
|
||||
return False, user_repos[0:99]
|
||||
|
||||
final_repos = []
|
||||
for repo in user_repos:
|
||||
# Check if the repo matches any of the inferred repositories
|
||||
for inferred_repo in inferred_repos:
|
||||
if inferred_repo.lower() in repo.full_name.lower():
|
||||
final_repos.append(repo)
|
||||
break # Avoid adding the same repo multiple times
|
||||
|
||||
# no repos matched, return original list
|
||||
if len(final_repos) == 0:
|
||||
return False, user_repos[0:99]
|
||||
|
||||
# Found exact match
|
||||
elif len(final_repos) == 1:
|
||||
return True, final_repos
|
||||
|
||||
# Found partial matches
|
||||
return False, final_repos[0:99]
|
||||
|
||||
|
||||
def markdown_to_jira_markup(markdown_text: str) -> str:
|
||||
"""
|
||||
Convert markdown text to Jira Wiki Markup format.
|
||||
This function handles common markdown elements and converts them to their
|
||||
Jira Wiki Markup equivalents. It's designed to be exception-safe.
|
||||
Args:
|
||||
markdown_text: The markdown text to convert
|
||||
Returns:
|
||||
str: The converted Jira Wiki Markup text
|
||||
"""
|
||||
if not markdown_text or not isinstance(markdown_text, str):
|
||||
return ''
|
||||
|
||||
try:
|
||||
# Work with a copy to avoid modifying the original
|
||||
text = markdown_text
|
||||
|
||||
# Convert headers (# ## ### #### ##### ######)
|
||||
text = re.sub(r'^#{6}\s+(.*?)$', r'h6. \1', text, flags=re.MULTILINE)
|
||||
text = re.sub(r'^#{5}\s+(.*?)$', r'h5. \1', text, flags=re.MULTILINE)
|
||||
text = re.sub(r'^#{4}\s+(.*?)$', r'h4. \1', text, flags=re.MULTILINE)
|
||||
text = re.sub(r'^#{3}\s+(.*?)$', r'h3. \1', text, flags=re.MULTILINE)
|
||||
text = re.sub(r'^#{2}\s+(.*?)$', r'h2. \1', text, flags=re.MULTILINE)
|
||||
text = re.sub(r'^#{1}\s+(.*?)$', r'h1. \1', text, flags=re.MULTILINE)
|
||||
|
||||
# Convert code blocks first (before other formatting)
|
||||
text = re.sub(
|
||||
r'```(\w+)\n(.*?)\n```', r'{code:\1}\n\2\n{code}', text, flags=re.DOTALL
|
||||
)
|
||||
text = re.sub(r'```\n(.*?)\n```', r'{code}\n\1\n{code}', text, flags=re.DOTALL)
|
||||
|
||||
# Convert inline code (`code`)
|
||||
text = re.sub(r'`([^`]+)`', r'{{\1}}', text)
|
||||
|
||||
# Convert markdown formatting to Jira formatting
|
||||
# Use temporary placeholders to avoid conflicts between bold and italic conversion
|
||||
|
||||
# First convert bold (double markers) to temporary placeholders
|
||||
text = re.sub(r'\*\*(.*?)\*\*', r'JIRA_BOLD_START\1JIRA_BOLD_END', text)
|
||||
text = re.sub(r'__(.*?)__', r'JIRA_BOLD_START\1JIRA_BOLD_END', text)
|
||||
|
||||
# Now convert single asterisk italics
|
||||
text = re.sub(r'\*([^*]+?)\*', r'_\1_', text)
|
||||
|
||||
# Convert underscore italics
|
||||
text = re.sub(r'(?<!_)_([^_]+?)_(?!_)', r'_\1_', text)
|
||||
|
||||
# Finally, restore bold markers
|
||||
text = text.replace('JIRA_BOLD_START', '*')
|
||||
text = text.replace('JIRA_BOLD_END', '*')
|
||||
|
||||
# Convert links [text](url)
|
||||
text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', r'[\1|\2]', text)
|
||||
|
||||
# Convert unordered lists (- or * or +)
|
||||
text = re.sub(r'^[\s]*[-*+]\s+(.*?)$', r'* \1', text, flags=re.MULTILINE)
|
||||
|
||||
# Convert ordered lists (1. 2. etc.)
|
||||
text = re.sub(r'^[\s]*\d+\.\s+(.*?)$', r'# \1', text, flags=re.MULTILINE)
|
||||
|
||||
# Convert strikethrough (~~text~~)
|
||||
text = re.sub(r'~~(.*?)~~', r'-\1-', text)
|
||||
|
||||
# Convert horizontal rules (---, ***, ___)
|
||||
text = re.sub(r'^[\s]*[-*_]{3,}[\s]*$', r'----', text, flags=re.MULTILINE)
|
||||
|
||||
# Convert blockquotes (> text)
|
||||
text = re.sub(r'^>\s+(.*?)$', r'bq. \1', text, flags=re.MULTILINE)
|
||||
|
||||
# Convert tables (basic support)
|
||||
# This is a simplified table conversion - Jira tables are quite different
|
||||
lines = text.split('\n')
|
||||
in_table = False
|
||||
converted_lines = []
|
||||
|
||||
for line in lines:
|
||||
if (
|
||||
'|' in line
|
||||
and line.strip().startswith('|')
|
||||
and line.strip().endswith('|')
|
||||
):
|
||||
# Skip markdown table separator lines (contain ---)
|
||||
if '---' in line:
|
||||
continue
|
||||
if not in_table:
|
||||
in_table = True
|
||||
# Convert markdown table row to Jira table row
|
||||
cells = [cell.strip() for cell in line.split('|')[1:-1]]
|
||||
converted_line = '|' + '|'.join(cells) + '|'
|
||||
converted_lines.append(converted_line)
|
||||
elif in_table and line.strip() and '|' not in line:
|
||||
in_table = False
|
||||
converted_lines.append(line)
|
||||
else:
|
||||
in_table = False
|
||||
converted_lines.append(line)
|
||||
|
||||
text = '\n'.join(converted_lines)
|
||||
|
||||
return text
|
||||
|
||||
except Exception as e:
|
||||
# Log the error but don't raise it - return original text as fallback
|
||||
print(f'Error converting markdown to Jira markup: {str(e)}')
|
||||
return markdown_text or ''
|
||||
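As a rough illustration of the rules above (whitespace details aside, and assuming the function is imported from this module), a small Markdown snippet converts like this:

sample = (
    '## Fix summary\n'
    'The bug was in `parse()`. See [the PR](https://example.com/pr/1).\n'
    '- item one\n'
    '- item two\n'
)
print(markdown_to_jira_markup(sample))
# Expected output, per the substitutions above:
# h2. Fix summary
# The bug was in {{parse()}}. See [the PR|https://example.com/pr/1].
# * item one
# * item two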
114
enterprise/migrations/env.py
Normal file
@@ -0,0 +1,114 @@
import os
from logging.config import fileConfig

from alembic import context
from google.cloud.sql.connector import Connector
from sqlalchemy import create_engine
from storage.base import Base

target_metadata = Base.metadata

DB_USER = os.getenv('DB_USER', 'postgres')
DB_PASS = os.getenv('DB_PASS', 'postgres')
DB_HOST = os.getenv('DB_HOST', 'localhost')
DB_PORT = os.getenv('DB_PORT', '5432')
DB_NAME = os.getenv('DB_NAME', 'openhands')

GCP_DB_INSTANCE = os.getenv('GCP_DB_INSTANCE')
GCP_PROJECT = os.getenv('GCP_PROJECT')
GCP_REGION = os.getenv('GCP_REGION')

POOL_SIZE = int(os.getenv('DB_POOL_SIZE', '25'))
MAX_OVERFLOW = int(os.getenv('DB_MAX_OVERFLOW', '10'))


def get_engine(database_name=DB_NAME):
    """Create SQLAlchemy engine with optional database name."""
    if GCP_DB_INSTANCE:

        def get_db_connection():
            connector = Connector()
            instance_string = f'{GCP_PROJECT}:{GCP_REGION}:{GCP_DB_INSTANCE}'
            return connector.connect(
                instance_string,
                'pg8000',
                user=DB_USER,
                password=DB_PASS.strip(),
                db=database_name,
            )

        return create_engine(
            'postgresql+pg8000://',
            creator=get_db_connection,
            pool_size=POOL_SIZE,
            max_overflow=MAX_OVERFLOW,
            pool_pre_ping=True,
        )
    else:
        url = f'postgresql://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{database_name}'
        return create_engine(
            url,
            pool_size=POOL_SIZE,
            max_overflow=MAX_OVERFLOW,
            pool_pre_ping=True,
        )


engine = get_engine()

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config

# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well. By skipping the Engine creation
    we don't even need a DBAPI to be available.

    Calls to context.execute() here emit the given string to the
    script output.
    """
    url = config.get_main_option('sqlalchemy.url')
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={'paramstyle': 'named'},
    )

    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    In this scenario we need to create an Engine
    and associate a connection with the context.
    """
    connectable = engine

    with connectable.connect() as connection:
        context.configure(
            connection=connection,
            target_metadata=target_metadata,
            version_table_schema=target_metadata.schema,
        )

        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
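Assuming a standard alembic.ini sits alongside this migrations/ directory (a layout assumption, not shown in this change), the revisions below would typically be applied with the equivalent of `alembic upgrade head`, for example programmatically:

from alembic.config import main as alembic_main

# Runs the same code path as the `alembic upgrade head` CLI invocation, which
# imports this env.py and ends up in run_migrations_online() when a database is reachable.
alembic_main(argv=['upgrade', 'head'])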
26
enterprise/migrations/script.py.mako
Normal file
@@ -0,0 +1,26 @@
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}


def upgrade() -> None:
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    ${downgrades if downgrades else "pass"}
45
enterprise/migrations/versions/001_create_feedback_table.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Create feedback table
|
||||
|
||||
Revision ID: 001
|
||||
Revises:
|
||||
Create Date: 2024-03-19 10:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '001'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
'feedback',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('version', sa.String(), nullable=False),
|
||||
sa.Column('email', sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
'polarity',
|
||||
sa.Enum('positive', 'negative', name='polarity_enum'),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
'permissions',
|
||||
sa.Enum('public', 'private', name='permissions_enum'),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column('trajectory', sa.JSON(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table('feedback')
|
||||
op.execute('DROP TYPE polarity_enum')
|
||||
op.execute('DROP TYPE permissions_enum')
|
||||
@@ -0,0 +1,45 @@
|
||||
"""create saas settings table
|
||||
|
||||
Revision ID: 002
|
||||
Revises: 001
|
||||
Create Date: 2025-01-27 20:08:58.360566
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '002'
|
||||
down_revision: Union[str, None] = '001'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# This was created to match the settings object - in future some of these strings should probably
|
||||
# be replaced with enum types.
|
||||
op.create_table(
|
||||
'settings',
|
||||
sa.Column('id', sa.String(), nullable=False, primary_key=True),
|
||||
sa.Column('language', sa.String(), nullable=True),
|
||||
sa.Column('agent', sa.String(), nullable=True),
|
||||
sa.Column('max_iterations', sa.Integer(), nullable=True),
|
||||
sa.Column('security_analyzer', sa.String(), nullable=True),
|
||||
sa.Column('confirmation_mode', sa.Boolean(), nullable=True, default=False),
|
||||
sa.Column('llm_model', sa.String(), nullable=True),
|
||||
sa.Column('llm_api_key', sa.String(), nullable=True),
|
||||
sa.Column('llm_base_url', sa.String(), nullable=True),
|
||||
sa.Column('remote_runtime_resource_factor', sa.Integer(), nullable=True),
|
||||
sa.Column('github_token', sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
'enable_default_condenser', sa.Boolean(), nullable=False, default=False
|
||||
),
|
||||
sa.Column('user_consents_to_analytics', sa.Boolean(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table('settings')
|
||||
@@ -0,0 +1,35 @@
|
||||
"""create saas conversations table
|
||||
|
||||
Revision ID: 003
|
||||
Revises: 002
|
||||
Create Date: 2025-01-29 09:36:49.475467
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '003'
|
||||
down_revision: Union[str, None] = '002'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
'conversation_metadata',
|
||||
sa.Column('conversation_id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', sa.String(), nullable=False, index=True),
|
||||
sa.Column('selected_repository', sa.String(), nullable=True),
|
||||
sa.Column('title', sa.String(), nullable=True),
|
||||
sa.Column('last_updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False, index=True),
|
||||
sa.PrimaryKeyConstraint('conversation_id'),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table('conversation_metadata')
|
||||
@@ -0,0 +1,47 @@
|
||||
"""create saas conversations table
|
||||
|
||||
Revision ID: 004
|
||||
Revises: 003
|
||||
Create Date: 2025-01-29 09:36:49.475467
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '004'
|
||||
down_revision: Union[str, None] = '003'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
'billing_sessions',
|
||||
sa.Column('id', sa.String(), nullable=False, primary_key=True),
|
||||
sa.Column('user_id', sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
'status',
|
||||
sa.Enum(
|
||||
'in_progress',
|
||||
'completed',
|
||||
'cancelled',
|
||||
'error',
|
||||
name='billing_session_status_enum',
|
||||
),
|
||||
nullable=False,
|
||||
default='in_progress',
|
||||
),
|
||||
sa.Column('price', sa.DECIMAL(19, 4), nullable=False),
|
||||
sa.Column('price_code', sa.String(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table('billing_sessions')
|
||||
op.execute('DROP TYPE billing_session_status_enum')
|
||||
26
enterprise/migrations/versions/005_add_margin_column.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""add margin column
|
||||
|
||||
Revision ID: 005
|
||||
Revises: 004
|
||||
Create Date: 2025-02-10 08:36:49.475467
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '005'
|
||||
down_revision: Union[str, None] = '004'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column('settings', sa.Column('margin', sa.Float(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('settings', 'margin')
|
||||
@@ -0,0 +1,29 @@
|
||||
"""add branch column to convo metadata table
|
||||
|
||||
Revision ID: 006
|
||||
Revises: 005
|
||||
Create Date: 2025-02-11 14:59:09.415
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '006'
|
||||
down_revision: Union[str, None] = '005'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
'conversation_metadata',
|
||||
sa.Column('selected_branch', sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('conversation_metadata', 'selected_branch')
|
||||
@@ -0,0 +1,31 @@
|
||||
"""add enable_sound_notifications column to settings table
|
||||
|
||||
Revision ID: 007
|
||||
Revises: 006
|
||||
Create Date: 2025-05-01 10:00:00.000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '007'
|
||||
down_revision: Union[str, None] = '006'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
'settings',
|
||||
sa.Column(
|
||||
'enable_sound_notifications', sa.Boolean(), nullable=True, default=False
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('settings', 'enable_sound_notifications')
|
||||
@@ -0,0 +1,36 @@
|
||||
"""fix enable_sound_notifications settings to not be nullable
|
||||
|
||||
Revision ID: 008
|
||||
Revises: 007
|
||||
Create Date: 2025-02-28 18:28:00.000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '008'
|
||||
down_revision: Union[str, None] = '007'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column(
|
||||
'settings',
|
||||
sa.Column(
|
||||
'enable_sound_notifications', sa.Boolean(), nullable=False, default=False
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column(
|
||||
'settings',
|
||||
sa.Column(
|
||||
'enable_sound_notifications', sa.Boolean(), nullable=True, default=False
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,39 @@
|
||||
"""fix enable_sound_notifications settings to not be nullable
|
||||
|
||||
Revision ID: 009
|
||||
Revises: 008
|
||||
Create Date: 2025-02-28 18:28:00.000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '009'
|
||||
down_revision: Union[str, None] = '008'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
'UPDATE settings SET enable_sound_notifications=FALSE where enable_sound_notifications IS NULL'
|
||||
)
|
||||
op.alter_column(
|
||||
'settings',
|
||||
sa.Column(
|
||||
'enable_sound_notifications', sa.Boolean(), nullable=False, default=False
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column(
|
||||
'settings',
|
||||
sa.Column(
|
||||
'enable_sound_notifications', sa.Boolean(), nullable=True, default=False
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,40 @@
|
||||
"""create offline tokens table.
|
||||
|
||||
Revision ID: 010
|
||||
Revises: 009_fix_enable_sound_notifications_column
|
||||
Create Date: 2024-03-11
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '010'
|
||||
down_revision: Union[str, None] = '009'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
'offline_tokens',
|
||||
sa.Column('user_id', sa.String(length=255), primary_key=True),
|
||||
sa.Column('offline_token', sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
'created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False
|
||||
),
|
||||
sa.Column(
|
||||
'updated_at',
|
||||
sa.DateTime(),
|
||||
server_default=sa.text('now()'),
|
||||
onupdate=sa.text('now()'),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table('offline_tokens')
|
||||
@@ -0,0 +1,50 @@
|
||||
"""create user settings table
|
||||
|
||||
Revision ID: 011
|
||||
Revises: 010
|
||||
Create Date: 2024-03-11 23:39:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '011'
|
||||
down_revision: Union[str, None] = '010'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
'user_settings',
|
||||
sa.Column('id', sa.Integer(), sa.Identity(), nullable=False, primary_key=True),
|
||||
sa.Column('keycloak_user_id', sa.String(), nullable=True),
|
||||
sa.Column('language', sa.String(), nullable=True),
|
||||
sa.Column('agent', sa.String(), nullable=True),
|
||||
sa.Column('max_iterations', sa.Integer(), nullable=True),
|
||||
sa.Column('security_analyzer', sa.String(), nullable=True),
|
||||
sa.Column('confirmation_mode', sa.Boolean(), nullable=True, default=False),
|
||||
sa.Column('llm_model', sa.String(), nullable=True),
|
||||
sa.Column('llm_api_key', sa.String(), nullable=True),
|
||||
sa.Column('llm_base_url', sa.String(), nullable=True),
|
||||
sa.Column('remote_runtime_resource_factor', sa.Integer(), nullable=True),
|
||||
sa.Column(
|
||||
'enable_default_condenser', sa.Boolean(), nullable=False, default=False
|
||||
),
|
||||
sa.Column('user_consents_to_analytics', sa.Boolean(), nullable=True),
|
||||
sa.Column('billing_margin', sa.Float(), nullable=True),
|
||||
sa.Column(
|
||||
'enable_sound_notifications', sa.Boolean(), nullable=True, default=False
|
||||
),
|
||||
)
|
||||
# Create indexes for faster lookups
|
||||
op.create_index('idx_keycloak_user_id', 'user_settings', ['keycloak_user_id'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index('idx_keycloak_user_id', 'user_settings')
|
||||
op.drop_table('user_settings')
|
||||
@@ -0,0 +1,26 @@
|
||||
"""add secret_store column to settings table
|
||||
Revision ID: 012
|
||||
Revises: 011
|
||||
Create Date: 2025-05-01 10:00:00.000
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '012'
|
||||
down_revision: Union[str, None] = '011'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
'settings', sa.Column('secrets_store', sa.JSON(), nullable=True, default=False)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('settings', 'secrets_store')
|
||||
@@ -0,0 +1,53 @@
|
||||
"""create user settings table
|
||||
|
||||
Revision ID: 013
|
||||
Revises: 012
|
||||
Create Date: 2024-03-12 23:39:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '013'
|
||||
down_revision: Union[str, None] = '012'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
'github_app_installations',
|
||||
sa.Column('id', sa.Integer(), sa.Identity(), primary_key=True),
|
||||
sa.Column('installation_id', sa.String(), nullable=False),
|
||||
sa.Column('encrypted_token', sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
'created_at',
|
||||
sa.DateTime(),
|
||||
server_default=sa.text('now()'),
|
||||
onupdate=sa.text('now()'),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
'updated_at',
|
||||
sa.DateTime(),
|
||||
server_default=sa.text('now()'),
|
||||
onupdate=sa.text('now()'),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
# Create indexes for faster lookups
|
||||
op.create_index(
|
||||
'idx_installation_id',
|
||||
'github_app_installations',
|
||||
['installation_id'],
|
||||
unique=True,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index('idx_installation_id', 'github_app_installations')
|
||||
op.drop_table('github_app_installations')
|
||||
40
enterprise/migrations/versions/014_add_github_user_id.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""Add github_user_id field and rename user_id to github_user_id.
|
||||
|
||||
This migration:
|
||||
1. Renames the existing user_id column to github_user_id
|
||||
2. Creates a new user_id column
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
from sqlalchemy import Column, String
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '014'
|
||||
down_revision: Union[str, None] = '013'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# First rename the existing user_id column to github_user_id
|
||||
op.alter_column(
|
||||
'conversation_metadata',
|
||||
'user_id',
|
||||
nullable=True,
|
||||
new_column_name='github_user_id',
|
||||
)
|
||||
|
||||
# Then add the new user_id column
|
||||
op.add_column('conversation_metadata', Column('user_id', String, nullable=True))
|
||||
|
||||
|
||||
def downgrade():
|
||||
# Drop the new user_id column
|
||||
op.drop_column('conversation_metadata', 'user_id')
|
||||
|
||||
# Rename github_user_id back to user_id
|
||||
op.alter_column(
|
||||
'conversation_metadata', 'github_user_id', new_column_name='user_id'
|
||||
)
|
||||
@@ -0,0 +1,50 @@
|
||||
"""add sandbox_base_container_image and sandbox_runtime_container_image columns
|
||||
|
||||
Revision ID: 015
|
||||
Revises: 014
|
||||
Create Date: 2025-03-19 19:30:00.000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '015'
|
||||
down_revision: Union[str, None] = '014'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add columns to settings table
|
||||
op.add_column(
|
||||
'settings',
|
||||
sa.Column('sandbox_base_container_image', sa.String(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
'settings',
|
||||
sa.Column('sandbox_runtime_container_image', sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
# Add columns to user_settings table
|
||||
op.add_column(
|
||||
'user_settings',
|
||||
sa.Column('sandbox_base_container_image', sa.String(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
'user_settings',
|
||||
sa.Column('sandbox_runtime_container_image', sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop columns from settings table
|
||||
op.drop_column('settings', 'sandbox_base_container_image')
|
||||
op.drop_column('settings', 'sandbox_runtime_container_image')
|
||||
|
||||
# Drop columns from user_settings table
|
||||
op.drop_column('user_settings', 'sandbox_base_container_image')
|
||||
op.drop_column('user_settings', 'sandbox_runtime_container_image')
|
||||
@@ -0,0 +1,29 @@
|
||||
"""Add user settings version which acts as a hint of external db state
|
||||
|
||||
Revision ID: 016
|
||||
Revises: 015
|
||||
Create Date: 2025-03-20 16:30:00.000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '016'
|
||||
down_revision: Union[str, None] = '015'
|
||||
branch_labels: Union[str, sa.Sequence[str], None] = None
|
||||
depends_on: Union[str, sa.Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
'user_settings',
|
||||
sa.Column('user_version', sa.Integer(), nullable=False, server_default='0'),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('user_settings', 'user_version')
|
||||
@@ -0,0 +1,55 @@
|
||||
"""Add a stripe customers table
|
||||
|
||||
Revision ID: 017
|
||||
Revises: 016
|
||||
Create Date: 2025-03-20 16:30:00.000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '017'
|
||||
down_revision: Union[str, None] = '016'
|
||||
branch_labels: Union[str, sa.Sequence[str], None] = None
|
||||
depends_on: Union[str, sa.Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
'stripe_customers',
|
||||
sa.Column('id', sa.Integer(), sa.Identity(), nullable=False, primary_key=True),
|
||||
sa.Column('keycloak_user_id', sa.String(), nullable=False),
|
||||
sa.Column('stripe_customer_id', sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
'created_at',
|
||||
sa.DateTime(),
|
||||
server_default=sa.text('now()'),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
'updated_at',
|
||||
sa.DateTime(),
|
||||
server_default=sa.text('now()'),
|
||||
onupdate=sa.text('now()'),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
# Create indexes for faster lookups
|
||||
op.create_index(
|
||||
'idx_stripe_customers_keycloak_user_id',
|
||||
'stripe_customers',
|
||||
['keycloak_user_id'],
|
||||
)
|
||||
op.create_index(
|
||||
'idx_stripe_customers_stripe_customer_id',
|
||||
'stripe_customers',
|
||||
['stripe_customer_id'],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table('stripe_customers')
|
||||
@@ -0,0 +1,36 @@
|
||||
"""Add a table for tracking output from maintainance scripts. These are basically migrations that are not sql centric.
|
||||
Revision ID: 018
|
||||
Revises: 017
|
||||
Create Date: 2025-03-26 19:45:00.000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '018'
|
||||
down_revision: Union[str, None] = '017'
|
||||
branch_labels: Union[str, sa.Sequence[str], None] = None
|
||||
depends_on: Union[str, sa.Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
'script_results',
|
||||
sa.Column('id', sa.Integer(), sa.Identity(), nullable=False, primary_key=True),
|
||||
sa.Column('revision', sa.String(), nullable=False, index=True),
|
||||
sa.Column('data', sa.JSON()),
|
||||
sa.Column(
|
||||
'created_at',
|
||||
sa.DateTime(),
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table('script_results')
|
||||
@@ -0,0 +1,93 @@
"""Remove duplicate customers from Stripe. This is a non-standard Alembic migration for non-SQL resources.

Revision ID: 019
Revises: 018
Create Date: 2025-03-20 16:30:00.000

"""

import json
import os
from collections import defaultdict
from typing import Sequence, Union

import sqlalchemy as sa
import stripe
from alembic import op
from sqlalchemy.sql import text

# revision identifiers, used by Alembic.
revision: str = '019'
down_revision: Union[str, None] = '018'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # Skip migration if STRIPE_API_KEY is not set
    if 'STRIPE_API_KEY' not in os.environ:
        print('Skipping migration 019: STRIPE_API_KEY not set')
        return

    stripe.api_key = os.environ['STRIPE_API_KEY']

    # Get all customers from Stripe, grouped by the user they belong to
    user_id_to_customer_ids = defaultdict(list)
    customers = stripe.Customer.list()
    for customer in customers.auto_paging_iter():
        user_id = customer.metadata.get('user_id')
        if user_id:
            user_id_to_customer_ids[user_id].append(customer.id)

    # Canonical customer IDs already recorded in the database
    stripe_customers = {
        row[0]: row[1]
        for row in op.get_bind().execute(
            text('SELECT keycloak_user_id, stripe_customer_id FROM stripe_customers')
        )
    }

    to_delete = []
    for user_id, customer_ids in user_id_to_customer_ids.items():
        if len(customer_ids) == 1:
            continue
        canonical_customer_id = stripe_customers.get(user_id)
        if canonical_customer_id:
            for customer_id in customer_ids:
                if customer_id != canonical_customer_id:
                    to_delete.append({'user_id': user_id, 'customer_id': customer_id})
        else:
            # Prioritize deletion of customers that don't have payment methods
            to_delete_for_customer = []
            for customer_id in customer_ids:
                payment_methods = stripe.Customer.list_payment_methods(customer_id)
                to_delete_for_customer.append(
                    {
                        'user_id': user_id,
                        'customer_id': customer_id,
                        'num_payment_methods': len(payment_methods),
                    }
                )
            # Keep the customer with the most payment methods; delete the rest
            to_delete_for_customer.sort(
                key=lambda c: c['num_payment_methods'], reverse=True
            )
            to_delete.extend(to_delete_for_customer[1:])

    for item in to_delete:
        # Record each deletion in script_results so the change is auditable
        op.get_bind().execute(
            text(
                'INSERT INTO script_results (revision, data) VALUES (:revision, :data)'
            ),
            {
                'revision': revision,
                'data': json.dumps(item),
            },
        )
        stripe.Customer.delete(item['customer_id'])


def downgrade() -> None:
    op.get_bind().execute(
        text('DELETE FROM script_results WHERE revision=:revision'),
        {'revision': revision},
    )
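Because migration 019 writes one script_results row per deleted Stripe customer, its effects can be audited after the fact. Below is a minimal sketch of such an audit query; the DATABASE_URL environment variable and the ad-hoc script itself are assumptions for illustration, not part of this PR.

import json
import os

import sqlalchemy as sa

# Hypothetical: connect with whatever database URL the deployment uses.
engine = sa.create_engine(os.environ['DATABASE_URL'])

with engine.connect() as conn:
    rows = conn.execute(
        sa.text('SELECT data, created_at FROM script_results WHERE revision = :rev'),
        {'rev': '019'},
    )
    for data, created_at in rows:
        # JSON columns may come back as dicts or as serialized strings.
        item = json.loads(data) if isinstance(data, str) else data
        print(created_at, item['user_id'], item['customer_id'])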
40 enterprise/migrations/versions/020_set_condenser_to_false.py Normal file
@@ -0,0 +1,40 @@
"""set condenser to false for all users

Revision ID: 020
Revises: 019
Create Date: 2025-04-02 12:45:00.000000

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op
from sqlalchemy.sql import column, table

# revision identifiers, used by Alembic.
revision: str = '020'
down_revision: Union[str, None] = '019'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # Define tables for update operations
    settings_table = table('settings', column('enable_default_condenser', sa.Boolean))

    user_settings_table = table(
        'user_settings', column('enable_default_condenser', sa.Boolean)
    )

    # Update the enable_default_condenser column to False for all users in the settings table
    op.execute(settings_table.update().values(enable_default_condenser=False))

    # Update the enable_default_condenser column to False for all users in the user_settings table
    op.execute(user_settings_table.update().values(enable_default_condenser=False))


def downgrade() -> None:
    # No downgrade operation needed as we're just setting a value
    # and not changing schema structure
    pass
@@ -0,0 +1,46 @@
"""create auth tokens table

Revision ID: 021
Revises: 020
Create Date: 2025-03-30 20:15:00.000000

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '021'
down_revision: Union[str, None] = '020'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.create_table(
        'auth_tokens',
        sa.Column('id', sa.Integer(), sa.Identity(), nullable=False, primary_key=True),
        sa.Column('keycloak_user_id', sa.String(), nullable=False),
        sa.Column('identity_provider', sa.String(), nullable=False),
        sa.Column('access_token', sa.String(), nullable=False),
        sa.Column('refresh_token', sa.String(), nullable=False),
        sa.Column('access_token_expires_at', sa.BigInteger(), nullable=False),
        sa.Column('refresh_token_expires_at', sa.BigInteger(), nullable=False),
    )
    op.create_index(
        'idx_auth_tokens_keycloak_user_id', 'auth_tokens', ['keycloak_user_id']
    )
    op.create_index(
        'idx_auth_tokens_keycloak_user_identity_provider',
        'auth_tokens',
        ['keycloak_user_id', 'identity_provider'],
        unique=True,
    )


def downgrade() -> None:
    op.drop_index('idx_auth_tokens_keycloak_user_identity_provider', 'auth_tokens')
    op.drop_index('idx_auth_tokens_keycloak_user_id', 'auth_tokens')
    op.drop_table('auth_tokens')
44 enterprise/migrations/versions/022_create_api_keys_table.py Normal file
@@ -0,0 +1,44 @@
"""Create API keys table

Revision ID: 022
Revises: 021
Create Date: 2025-04-03

"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = '022'
down_revision = '021'
branch_labels = None
depends_on = None


def upgrade():
    op.create_table(
        'api_keys',
        sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
        sa.Column('key', sa.String(length=255), nullable=False),
        sa.Column('user_id', sa.String(length=255), nullable=False),
        sa.Column('name', sa.String(length=255), nullable=True),
        sa.Column(
            'created_at',
            sa.DateTime(),
            server_default=sa.text('CURRENT_TIMESTAMP'),
            nullable=False,
        ),
        sa.Column('last_used_at', sa.DateTime(), nullable=True),
        sa.Column('expires_at', sa.DateTime(), nullable=True),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('key'),
    )
    op.create_index(op.f('ix_api_keys_key'), 'api_keys', ['key'], unique=True)
    op.create_index(op.f('ix_api_keys_user_id'), 'api_keys', ['user_id'], unique=False)


def downgrade():
    op.drop_index(op.f('ix_api_keys_user_id'), table_name='api_keys')
    op.drop_index(op.f('ix_api_keys_key'), table_name='api_keys')
    op.drop_table('api_keys')
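The api_keys schema implies a straightforward validation flow: look a key up via its unique index, reject it if expires_at has passed, and stamp last_used_at. The sketch below illustrates that flow with SQLAlchemy Core; the validate_api_key helper and the connection argument are illustrative assumptions, not code from this PR.

from datetime import datetime

import sqlalchemy as sa

api_keys = sa.table(
    'api_keys',
    sa.column('key', sa.String()),
    sa.column('user_id', sa.String()),
    sa.column('expires_at', sa.DateTime()),
    sa.column('last_used_at', sa.DateTime()),
)


def validate_api_key(conn, key: str) -> str | None:
    """Return the owning user_id if the key is valid, otherwise None."""
    # Assumes the DateTime columns hold naive UTC timestamps.
    now = datetime.utcnow()
    row = conn.execute(
        sa.select(api_keys.c.user_id, api_keys.c.expires_at).where(api_keys.c.key == key)
    ).first()
    if row is None:
        return None
    if row.expires_at is not None and row.expires_at < now:
        return None
    # Record the successful use of the key.
    conn.execute(
        api_keys.update().where(api_keys.c.key == key).values(last_used_at=now)
    )
    return row.user_id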
@@ -0,0 +1,44 @@
"""Add cost and token metrics columns to conversation_metadata table.

Revision ID: 023
Revises: 022
Create Date: 2025-04-07

"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = '023'
down_revision = '022'
branch_labels = None
depends_on = None


def upgrade():
    # Add cost and token metrics columns to conversation_metadata table
    op.add_column(
        'conversation_metadata',
        sa.Column('accumulated_cost', sa.Float(), nullable=True, server_default='0.0'),
    )
    op.add_column(
        'conversation_metadata',
        sa.Column('prompt_tokens', sa.Integer(), nullable=True, server_default='0'),
    )
    op.add_column(
        'conversation_metadata',
        sa.Column('completion_tokens', sa.Integer(), nullable=True, server_default='0'),
    )
    op.add_column(
        'conversation_metadata',
        sa.Column('total_tokens', sa.Integer(), nullable=True, server_default='0'),
    )


def downgrade():
    # Remove cost and token metrics columns from conversation_metadata table
    op.drop_column('conversation_metadata', 'accumulated_cost')
    op.drop_column('conversation_metadata', 'prompt_tokens')
    op.drop_column('conversation_metadata', 'completion_tokens')
    op.drop_column('conversation_metadata', 'total_tokens')
@@ -0,0 +1,72 @@
"""update enable_default_condenser default to True

Revision ID: 024
Revises: 023
Create Date: 2024-04-08 15:30:00.000000

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op
from sqlalchemy.sql import column, table

# revision identifiers, used by Alembic.
revision: str = '024'
down_revision: Union[str, None] = '023'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # Update existing rows in settings table
    settings_table = table('settings', column('enable_default_condenser', sa.Boolean))
    op.execute(settings_table.update().values(enable_default_condenser=True))

    # Update existing rows in user_settings table
    user_settings_table = table(
        'user_settings', column('enable_default_condenser', sa.Boolean)
    )
    op.execute(user_settings_table.update().values(enable_default_condenser=True))

    # Alter the default value for settings table
    op.alter_column(
        'settings',
        'enable_default_condenser',
        existing_type=sa.Boolean(),
        server_default=sa.true(),
        existing_nullable=False,
    )

    # Alter the default value for user_settings table
    op.alter_column(
        'user_settings',
        'enable_default_condenser',
        existing_type=sa.Boolean(),
        server_default=sa.true(),
        existing_nullable=False,
    )


def downgrade() -> None:
    # Revert the default value for settings table
    op.alter_column(
        'settings',
        'enable_default_condenser',
        existing_type=sa.Boolean(),
        server_default=sa.false(),
        existing_nullable=False,
    )

    # Revert the default value for user_settings table
    op.alter_column(
        'user_settings',
        'enable_default_condenser',
        existing_type=sa.Boolean(),
        server_default=sa.false(),
        existing_nullable=False,
    )

    # Note: We don't revert the data changes in the downgrade function
    # as it would be arbitrary which rows to change back
@@ -0,0 +1,40 @@
"""Revert user_version from 3 to 2

Revision ID: 025
Revises: 024
Create Date: 2025-04-09

"""

from typing import Sequence, Union

from alembic import op

# revision identifiers, used by Alembic.
revision: str = '025'
down_revision: Union[str, None] = '024'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # Update user_version from 3 to 2 for all users who have version 3
    op.execute(
        """
        UPDATE user_settings
        SET user_version = 2
        WHERE user_version = 3
        """
    )


def downgrade() -> None:
    # Revert back to version 3 for users who have version 2
    # Note: This is not a perfect downgrade as we can't know which users originally had version 3
    op.execute(
        """
        UPDATE user_settings
        SET user_version = 3
        WHERE user_version = 2
        """
    )
@@ -0,0 +1,29 @@
"""add trigger column to conversation_metadata table

Revision ID: 026
Revises: 025
Create Date: 2025-04-16 14:59:09.415

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '026'
down_revision: Union[str, None] = '025'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.add_column(
        'conversation_metadata',
        sa.Column('trigger', sa.String(), nullable=True),
    )


def downgrade() -> None:
    op.drop_column('conversation_metadata', 'trigger')
@@ -0,0 +1,60 @@
"""create gitlab webhook table

Revision ID: 027
Revises: 026
Create Date: 2025-01-27 20:08:58.360566

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '027'
down_revision: Union[str, None] = '026'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # This was created to match the settings object - in future some of these strings
    # should probably be replaced with enum types.
    op.create_table(
        'gitlab-webhook',
        sa.Column(
            'id', sa.Integer(), nullable=False, primary_key=True, autoincrement=True
        ),
        sa.Column('group_id', sa.String(), nullable=True),
        sa.Column('project_id', sa.String(), nullable=True),
        sa.Column('user_id', sa.String(), nullable=False),
        sa.Column('webhook_exists', sa.Boolean(), nullable=False),
        sa.Column('webhook_name', sa.Boolean(), nullable=True),
        sa.Column('webhook_url', sa.String(), nullable=True),
        sa.Column('webhook_secret', sa.String(), nullable=True),
        sa.Column('scopes', sa.String(), nullable=True),
    )

    # Create indexes for faster lookups
    op.create_index('ix_gitlab_webhook_user_id', 'gitlab-webhook', ['user_id'])
    op.create_index('ix_gitlab_webhook_group_id', 'gitlab-webhook', ['group_id'])
    op.create_index('ix_gitlab_webhook_project_id', 'gitlab-webhook', ['project_id'])

    # Add unique constraints on group_id and project_id to support UPSERT operations
    op.create_unique_constraint(
        'uq_gitlab_webhook_group_id', 'gitlab-webhook', ['group_id']
    )
    op.create_unique_constraint(
        'uq_gitlab_webhook_project_id', 'gitlab-webhook', ['project_id']
    )


def downgrade() -> None:
    # Drop the constraints and indexes first before dropping the table
    op.drop_constraint('uq_gitlab_webhook_group_id', 'gitlab-webhook', type_='unique')
    op.drop_constraint('uq_gitlab_webhook_project_id', 'gitlab-webhook', type_='unique')
    op.drop_index('ix_gitlab_webhook_user_id', table_name='gitlab-webhook')
    op.drop_index('ix_gitlab_webhook_group_id', table_name='gitlab-webhook')
    op.drop_index('ix_gitlab_webhook_project_id', table_name='gitlab-webhook')
    op.drop_table('gitlab-webhook')
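The unique constraints on group_id and project_id are what make an UPSERT against this table possible in PostgreSQL. A minimal sketch of such an upsert using SQLAlchemy's postgresql dialect is shown below; the upsert_project_webhook helper and its arguments are illustrative assumptions, not code from this PR.

import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import insert

gitlab_webhook = sa.table(
    'gitlab-webhook',
    sa.column('project_id', sa.String()),
    sa.column('user_id', sa.String()),
    sa.column('webhook_exists', sa.Boolean()),
    sa.column('webhook_url', sa.String()),
)


def upsert_project_webhook(conn, project_id: str, user_id: str, webhook_url: str) -> None:
    # Insert a row, or update the existing one when the project is already tracked.
    stmt = insert(gitlab_webhook).values(
        project_id=project_id,
        user_id=user_id,
        webhook_exists=True,
        webhook_url=webhook_url,
    )
    stmt = stmt.on_conflict_do_update(
        constraint='uq_gitlab_webhook_project_id',
        set_={'user_id': user_id, 'webhook_exists': True, 'webhook_url': webhook_url},
    )
    conn.execute(stmt)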
@@ -0,0 +1,63 @@
"""create user-repos table

Revision ID: 028
Revises: 027
Create Date: 2025-04-14

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '028'
down_revision: Union[str, None] = '027'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.create_table(
        'repos',
        sa.Column('id', sa.Integer(), sa.Identity(), primary_key=True),
        sa.Column('repo_name', sa.String(), nullable=False),
        sa.Column('repo_id', sa.String(), nullable=False),
        sa.Column('is_public', sa.Boolean(), nullable=False),
        sa.Column('has_microagent', sa.Boolean(), nullable=True),
        sa.Column('has_setup_script', sa.Boolean(), nullable=True),
    )

    op.create_index(
        'idx_repos_repo_id',
        'repos',
        ['repo_id'],
    )

    op.create_table(
        'user-repos',
        sa.Column('id', sa.Integer(), sa.Identity(), primary_key=True),
        sa.Column('user_id', sa.String(), nullable=False),
        sa.Column('repo_id', sa.String(), nullable=False),
        sa.Column('admin', sa.Boolean(), nullable=True),
    )

    op.create_index(
        'idx_user_repos_repo_id',
        'user-repos',
        ['repo_id'],
    )
    op.create_index(
        'idx_user_repos_user_id',
        'user-repos',
        ['user_id'],
    )


def downgrade() -> None:
    op.drop_index('idx_repos_repo_id', 'repos')
    op.drop_index('idx_user_repos_repo_id', 'user-repos')
    op.drop_index('idx_user_repos_user_id', 'user-repos')
    op.drop_table('repos')
    op.drop_table('user-repos')
@@ -0,0 +1,28 @@
"""add accepted_tos to user_settings

Revision ID: 029
Revises: 028
Create Date: 2025-04-23

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '029'
down_revision: Union[str, None] = '028'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.add_column(
        'user_settings', sa.Column('accepted_tos', sa.DateTime(), nullable=True)
    )


def downgrade() -> None:
    op.drop_column('user_settings', 'accepted_tos')
@@ -0,0 +1,33 @@
"""add proactive conversation starters column

Revision ID: 030
Revises: 029
Create Date: 2025-04-30

"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = '030'
down_revision = '029'
branch_labels = None
depends_on = None


def upgrade():
    op.add_column(
        'user_settings',
        sa.Column(
            'enable_proactive_conversation_starters',
            sa.Boolean(),
            nullable=False,
            default=True,
            server_default='TRUE',
        ),
    )


def downgrade():
    op.drop_column('user_settings', 'enable_proactive_conversation_starters')
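Migration 030 sets both default=True and server_default='TRUE'. The two act at different layers: default is applied by SQLAlchemy when an INSERT it issues omits the value, while server_default becomes the column DEFAULT in the database and is also what lets a NOT NULL column be added to a table that already has rows. A minimal sketch of a matching model declaration, assuming the project exposes this column through a SQLAlchemy declarative model (the class below is a hypothetical excerpt, not code from this PR), is:

import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase


class Base(DeclarativeBase):
    pass


class UserSettings(Base):
    # Hypothetical excerpt; only the id and the new column are shown.
    __tablename__ = 'user_settings'

    id = sa.Column(sa.Integer, primary_key=True)
    enable_proactive_conversation_starters = sa.Column(
        sa.Boolean,
        nullable=False,
        default=True,  # applied client-side on INSERTs issued by SQLAlchemy
        server_default=sa.text('TRUE'),  # applied by the database for any other INSERT
    )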
36 enterprise/migrations/versions/031_add_user_secrets_store.py Normal file
@@ -0,0 +1,36 @@
"""create user secrets table

Revision ID: 031
Revises: 030
Create Date: 2024-03-11 23:39:00.000000

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '031'
down_revision: Union[str, None] = '030'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.create_table(
        'user_secrets',
        sa.Column('id', sa.Integer(), sa.Identity(), nullable=False, primary_key=True),
        sa.Column('keycloak_user_id', sa.String(), nullable=True),
        sa.Column('custom_secrets', sa.JSON(), nullable=True),
    )
    # Create indexes for faster lookups
    op.create_index(
        'idx_user_secrets_keycloak_user_id', 'user_secrets', ['keycloak_user_id']
    )


def downgrade() -> None:
    op.drop_index('idx_user_secrets_keycloak_user_id', 'user_secrets')
    op.drop_table('user_secrets')
@@ -0,0 +1,56 @@
"""rename gitlab-webhook table to gitlab_webhook and add last_synced column

Revision ID: 032
Revises: 031
Create Date: 2025-04-21

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '032'
down_revision: Union[str, None] = '031'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.rename_table('gitlab-webhook', 'gitlab_webhook')

    op.add_column(
        'gitlab_webhook',
        sa.Column(
            'last_synced',
            sa.DateTime(),
            server_default=sa.text('now()'),
            onupdate=sa.text('now()'),
            nullable=True,
        ),
    )

    op.drop_column('gitlab_webhook', 'webhook_name')

    op.alter_column(
        'gitlab_webhook',
        'scopes',
        existing_type=sa.String,
        type_=sa.ARRAY(sa.Text()),
        existing_nullable=True,
        postgresql_using='ARRAY[]::text[]',
    )


def downgrade() -> None:
    op.add_column(
        'gitlab_webhook', sa.Column('webhook_name', sa.Boolean(), nullable=True)
    )

    # Drop the new column from the renamed table
    op.drop_column('gitlab_webhook', 'last_synced')

    # Rename the table back
    op.rename_table('gitlab_webhook', 'gitlab-webhook')
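The scopes conversion relies on PostgreSQL's ALTER COLUMN ... USING clause: the postgresql_using='ARRAY[]::text[]' argument tells Alembic to emit that clause, and because the expression is a constant empty array it resets every existing row rather than parsing the old string values. Roughly, the statement Alembic generates is equivalent to the hand-written sketch below (shown via op.execute for illustration only; the migration above already does this through op.alter_column):

from alembic import op

# Equivalent hand-written form of the scopes type change in migration 032.
# ARRAY[]::text[] discards whatever string was stored before; an expression
# such as string_to_array(scopes, ' ') would be needed if old values had to be kept.
op.execute(
    'ALTER TABLE gitlab_webhook '
    'ALTER COLUMN scopes TYPE text[] '
    'USING ARRAY[]::text[]'
)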
Some files were not shown because too many files have changed in this diff.