mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
16 Commits
hotfix/set
...
fix-microa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8d89fada9c | ||
|
|
6e089619e0 | ||
|
|
179a89a211 | ||
|
|
8795ee6c6e | ||
|
|
97e938d545 | ||
|
|
b9a70c8d5c | ||
|
|
082d0b25c5 | ||
|
|
c5797d1d5a | ||
|
|
7ce1fb85ff | ||
|
|
fa6792e5a6 | ||
|
|
3d9b4c4af6 | ||
|
|
e21cbf67ee | ||
|
|
6b2e3f938f | ||
|
|
580d7b938c | ||
|
|
28178a2940 | ||
|
|
04382b2b19 |
78
.github/workflows/integration-runner.yml
vendored
78
.github/workflows/integration-runner.yml
vendored
@@ -56,6 +56,7 @@ jobs:
|
||||
LLM_MODEL: "litellm_proxy/claude-3-5-haiku-20241022"
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_BASE_URL: ${{ secrets.LLM_BASE_URL }}
|
||||
MAX_ITERATIONS: 10
|
||||
run: |
|
||||
echo "[llm.eval]" > config.toml
|
||||
echo "model = \"$LLM_MODEL\"" >> config.toml
|
||||
@@ -70,7 +71,7 @@ jobs:
|
||||
env:
|
||||
SANDBOX_FORCE_REBUILD_RUNTIME: True
|
||||
run: |
|
||||
poetry run ./evaluation/integration_tests/scripts/run_infer.sh llm.eval HEAD CodeActAgent '' $N_PROCESSES '' 'haiku_run'
|
||||
poetry run ./evaluation/integration_tests/scripts/run_infer.sh llm.eval HEAD CodeActAgent '' 10 $N_PROCESSES '' 'haiku_run'
|
||||
|
||||
# get integration tests report
|
||||
REPORT_FILE_HAIKU=$(find evaluation/evaluation_outputs/outputs/integration_tests/CodeActAgent/*haiku*_maxiter_10_N* -name "report.md" -type f | head -n 1)
|
||||
@@ -88,6 +89,7 @@ jobs:
|
||||
LLM_MODEL: "litellm_proxy/deepseek-chat"
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_BASE_URL: ${{ secrets.LLM_BASE_URL }}
|
||||
MAX_ITERATIONS: 10
|
||||
run: |
|
||||
echo "[llm.eval]" > config.toml
|
||||
echo "model = \"$LLM_MODEL\"" >> config.toml
|
||||
@@ -99,7 +101,7 @@ jobs:
|
||||
env:
|
||||
SANDBOX_FORCE_REBUILD_RUNTIME: True
|
||||
run: |
|
||||
poetry run ./evaluation/integration_tests/scripts/run_infer.sh llm.eval HEAD CodeActAgent '' $N_PROCESSES '' 'deepseek_run'
|
||||
poetry run ./evaluation/integration_tests/scripts/run_infer.sh llm.eval HEAD CodeActAgent '' 10 $N_PROCESSES '' 'deepseek_run'
|
||||
|
||||
# get integration tests report
|
||||
REPORT_FILE_DEEPSEEK=$(find evaluation/evaluation_outputs/outputs/integration_tests/CodeActAgent/deepseek*_maxiter_10_N* -name "report.md" -type f | head -n 1)
|
||||
@@ -109,11 +111,75 @@ jobs:
|
||||
echo >> $GITHUB_ENV
|
||||
echo "EOF" >> $GITHUB_ENV
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# Run DelegatorAgent tests for Haiku, limited to t01 and t02
|
||||
- name: Wait a little bit (again)
|
||||
run: sleep 5
|
||||
|
||||
- name: Configure config.toml for testing DelegatorAgent (Haiku)
|
||||
env:
|
||||
LLM_MODEL: "litellm_proxy/claude-3-5-haiku-20241022"
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_BASE_URL: ${{ secrets.LLM_BASE_URL }}
|
||||
MAX_ITERATIONS: 30
|
||||
run: |
|
||||
echo "[llm.eval]" > config.toml
|
||||
echo "model = \"$LLM_MODEL\"" >> config.toml
|
||||
echo "api_key = \"$LLM_API_KEY\"" >> config.toml
|
||||
echo "base_url = \"$LLM_BASE_URL\"" >> config.toml
|
||||
echo "temperature = 0.0" >> config.toml
|
||||
|
||||
- name: Run integration test evaluation for DelegatorAgent (Haiku)
|
||||
env:
|
||||
SANDBOX_FORCE_REBUILD_RUNTIME: True
|
||||
run: |
|
||||
poetry run ./evaluation/integration_tests/scripts/run_infer.sh llm.eval HEAD DelegatorAgent '' 30 $N_PROCESSES "t01_fix_simple_typo,t02_add_bash_hello" 'delegator_haiku_run'
|
||||
|
||||
# Find and export the delegator test results
|
||||
REPORT_FILE_DELEGATOR_HAIKU=$(find evaluation/evaluation_outputs/outputs/integration_tests/DelegatorAgent/*haiku*_maxiter_30_N* -name "report.md" -type f | head -n 1)
|
||||
echo "REPORT_FILE_DELEGATOR_HAIKU: $REPORT_FILE_DELEGATOR_HAIKU"
|
||||
echo "INTEGRATION_TEST_REPORT_DELEGATOR_HAIKU<<EOF" >> $GITHUB_ENV
|
||||
cat $REPORT_FILE_DELEGATOR_HAIKU >> $GITHUB_ENV
|
||||
echo >> $GITHUB_ENV
|
||||
echo "EOF" >> $GITHUB_ENV
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# Run DelegatorAgent tests for DeepSeek, limited to t01 and t02
|
||||
- name: Wait a little bit (again)
|
||||
run: sleep 5
|
||||
|
||||
- name: Configure config.toml for testing DelegatorAgent (DeepSeek)
|
||||
env:
|
||||
LLM_MODEL: "litellm_proxy/deepseek-chat"
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_BASE_URL: ${{ secrets.LLM_BASE_URL }}
|
||||
MAX_ITERATIONS: 30
|
||||
run: |
|
||||
echo "[llm.eval]" > config.toml
|
||||
echo "model = \"$LLM_MODEL\"" >> config.toml
|
||||
echo "api_key = \"$LLM_API_KEY\"" >> config.toml
|
||||
echo "base_url = \"$LLM_BASE_URL\"" >> config.toml
|
||||
echo "temperature = 0.0" >> config.toml
|
||||
|
||||
- name: Run integration test evaluation for DelegatorAgent (DeepSeek)
|
||||
env:
|
||||
SANDBOX_FORCE_REBUILD_RUNTIME: True
|
||||
run: |
|
||||
poetry run ./evaluation/integration_tests/scripts/run_infer.sh llm.eval HEAD DelegatorAgent '' 30 $N_PROCESSES "t01_fix_simple_typo,t02_add_bash_hello" 'delegator_deepseek_run'
|
||||
|
||||
# Find and export the delegator test results
|
||||
REPORT_FILE_DELEGATOR_DEEPSEEK=$(find evaluation/evaluation_outputs/outputs/integration_tests/DelegatorAgent/deepseek*_maxiter_30_N* -name "report.md" -type f | head -n 1)
|
||||
echo "REPORT_FILE_DELEGATOR_DEEPSEEK: $REPORT_FILE_DELEGATOR_DEEPSEEK"
|
||||
echo "INTEGRATION_TEST_REPORT_DELEGATOR_DEEPSEEK<<EOF" >> $GITHUB_ENV
|
||||
cat $REPORT_FILE_DELEGATOR_DEEPSEEK >> $GITHUB_ENV
|
||||
echo >> $GITHUB_ENV
|
||||
echo "EOF" >> $GITHUB_ENV
|
||||
|
||||
- name: Create archive of evaluation outputs
|
||||
run: |
|
||||
TIMESTAMP=$(date +'%y-%m-%d-%H-%M')
|
||||
cd evaluation/evaluation_outputs/outputs # Change to the outputs directory
|
||||
tar -czvf ../../../integration_tests_${TIMESTAMP}.tar.gz integration_tests/CodeActAgent/* # Only include the actual result directories
|
||||
tar -czvf ../../../integration_tests_${TIMESTAMP}.tar.gz integration_tests/CodeActAgent/* integration_tests/DelegatorAgent/* # Only include the actual result directories
|
||||
|
||||
- name: Upload evaluation results as artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
@@ -154,5 +220,11 @@ jobs:
|
||||
**Integration Tests Report (DeepSeek)**
|
||||
DeepSeek LLM Test Results:
|
||||
${{ env.INTEGRATION_TEST_REPORT_DEEPSEEK }}
|
||||
---
|
||||
**Integration Tests Report Delegator (Haiku)**
|
||||
${{ env.INTEGRATION_TEST_REPORT_DELEGATOR_HAIKU }}
|
||||
---
|
||||
**Integration Tests Report Delegator (DeepSeek)**
|
||||
${{ env.INTEGRATION_TEST_REPORT_DELEGATOR_DEEPSEEK }}
|
||||
---
|
||||
Download testing outputs (includes both Haiku and DeepSeek results): [Download](${{ steps.upload_results_artifact.outputs.artifact-url }})
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
|
||||
|
||||
# Options de configuration
|
||||
|
||||
Ce guide détaille toutes les options de configuration disponibles pour OpenHands, vous aidant à personnaliser son comportement et à l'intégrer avec d'autres services.
|
||||
@@ -184,6 +182,10 @@ Les options de configuration LLM (Large Language Model) sont définies dans la s
|
||||
|
||||
Pour les utiliser avec la commande docker, passez `-e LLM_<option>`. Exemple : `-e LLM_NUM_RETRIES`.
|
||||
|
||||
:::note
|
||||
Pour les configurations de développement, vous pouvez également définir des configurations LLM personnalisées. Voir [Configurations LLM personnalisées](./llms/custom-llm-configs) pour plus de détails.
|
||||
:::
|
||||
|
||||
**Informations d'identification AWS**
|
||||
- `aws_access_key_id`
|
||||
- Type : `str`
|
||||
@@ -368,4 +370,26 @@ Les options de configuration de l'agent sont définies dans les sections `[agent
|
||||
- `codeact_enable_llm_editor`
|
||||
- Type : `bool`
|
||||
- Valeur par défaut : `false`
|
||||
- Description : Si l'éditeur LLM est activé dans l'espace d'action (foncti
|
||||
- Description : Si l'éditeur LLM est activé dans l'espace d'action (fonctionne uniquement avec l'appel de fonction)
|
||||
|
||||
**Utilisation du micro-agent**
|
||||
- `use_microagents`
|
||||
- Type : `bool`
|
||||
- Valeur par défaut : `true`
|
||||
- Description : Indique si l'utilisation des micro-agents est activée ou non
|
||||
|
||||
- `disabled_microagents`
|
||||
- Type : `list of str`
|
||||
- Valeur par défaut : `None`
|
||||
- Description : Liste des micro-agents à désactiver
|
||||
|
||||
### Exécution
|
||||
- `timeout`
|
||||
- Type : `int`
|
||||
- Valeur par défaut : `120`
|
||||
- Description : Délai d'expiration du bac à sable, en secondes
|
||||
|
||||
- `user_id`
|
||||
- Type : `int`
|
||||
- Valeur par défaut : `1000`
|
||||
- Description : ID de l'utilisateur du bac à sable
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
# Configurations LLM personnalisées
|
||||
|
||||
OpenHands permet de définir plusieurs configurations LLM nommées dans votre fichier `config.toml`. Cette fonctionnalité vous permet d'utiliser différentes configurations LLM pour différents usages, comme utiliser un modèle moins coûteux pour les tâches qui ne nécessitent pas de réponses de haute qualité, ou utiliser différents modèles avec différents paramètres pour des agents spécifiques.
|
||||
|
||||
## Comment ça fonctionne
|
||||
|
||||
Les configurations LLM nommées sont définies dans le fichier `config.toml` en utilisant des sections qui commencent par `llm.`. Par exemple :
|
||||
|
||||
```toml
|
||||
# Configuration LLM par défaut
|
||||
[llm]
|
||||
model = "gpt-4"
|
||||
api_key = "votre-clé-api"
|
||||
temperature = 0.0
|
||||
|
||||
# Configuration LLM personnalisée pour un modèle moins coûteux
|
||||
[llm.gpt3]
|
||||
model = "gpt-3.5-turbo"
|
||||
api_key = "votre-clé-api"
|
||||
temperature = 0.2
|
||||
|
||||
# Une autre configuration personnalisée avec des paramètres différents
|
||||
[llm.haute-creativite]
|
||||
model = "gpt-4"
|
||||
api_key = "votre-clé-api"
|
||||
temperature = 0.8
|
||||
top_p = 0.9
|
||||
```
|
||||
|
||||
Chaque configuration nommée hérite de tous les paramètres de la section `[llm]` par défaut et peut remplacer n'importe lequel de ces paramètres. Vous pouvez définir autant de configurations personnalisées que nécessaire.
|
||||
|
||||
## Utilisation des configurations personnalisées
|
||||
|
||||
### Avec les agents
|
||||
|
||||
Vous pouvez spécifier quelle configuration LLM un agent doit utiliser en définissant le paramètre `llm_config` dans la section de configuration de l'agent :
|
||||
|
||||
```toml
|
||||
[agent.RepoExplorerAgent]
|
||||
# Utiliser la configuration GPT-3.5 moins coûteuse (section `llm.gpt3`) pour cet agent
|
||||
llm_config = 'gpt3'
|
||||
|
||||
[agent.CodeWriterAgent]
|
||||
# Utiliser la configuration haute créativité pour cet agent
|
||||
llm_config = 'haute-creativite'
|
||||
```
|
||||
|
||||
### Options de configuration
|
||||
|
||||
Chaque configuration LLM nommée prend en charge toutes les mêmes options que la configuration LLM par défaut. Celles-ci incluent :
|
||||
|
||||
- Sélection du modèle (`model`)
|
||||
- Configuration de l'API (`api_key`, `base_url`, etc.)
|
||||
- Paramètres du modèle (`temperature`, `top_p`, etc.)
|
||||
- Paramètres de nouvelle tentative (`num_retries`, `retry_multiplier`, etc.)
|
||||
- Limites de jetons (`max_input_tokens`, `max_output_tokens`)
|
||||
- Et toutes les autres options de configuration LLM
|
||||
|
||||
Pour une liste complète des options disponibles, consultez la section Configuration LLM dans la documentation des [Options de configuration](../configuration-options).
|
||||
|
||||
## Cas d'utilisation
|
||||
|
||||
Les configurations LLM personnalisées sont particulièrement utiles dans plusieurs scénarios :
|
||||
|
||||
- **Optimisation des coûts** : Utiliser des modèles moins coûteux pour les tâches qui ne nécessitent pas de réponses de haute qualité, comme l'exploration de dépôt ou les opérations simples sur les fichiers.
|
||||
- **Réglage spécifique aux tâches** : Configurer différentes valeurs de température et de top_p pour les tâches qui nécessitent différents niveaux de créativité ou de déterminisme.
|
||||
- **Différents fournisseurs** : Utiliser différents fournisseurs LLM ou points d'accès API pour différentes tâches.
|
||||
- **Tests et développement** : Basculer facilement entre différentes configurations de modèles pendant le développement et les tests.
|
||||
|
||||
## Exemple : Optimisation des coûts
|
||||
|
||||
Un exemple pratique d'utilisation des configurations LLM personnalisées pour optimiser les coûts :
|
||||
|
||||
```toml
|
||||
# Configuration par défaut utilisant GPT-4 pour des réponses de haute qualité
|
||||
[llm]
|
||||
model = "gpt-4"
|
||||
api_key = "votre-clé-api"
|
||||
temperature = 0.0
|
||||
|
||||
# Configuration moins coûteuse pour l'exploration de dépôt
|
||||
[llm.repo-explorer]
|
||||
model = "gpt-3.5-turbo"
|
||||
temperature = 0.2
|
||||
|
||||
# Configuration pour la génération de code
|
||||
[llm.code-gen]
|
||||
model = "gpt-4"
|
||||
temperature = 0.0
|
||||
max_output_tokens = 2000
|
||||
|
||||
[agent.RepoExplorerAgent]
|
||||
llm_config = 'repo-explorer'
|
||||
|
||||
[agent.CodeWriterAgent]
|
||||
llm_config = 'code-gen'
|
||||
```
|
||||
|
||||
Dans cet exemple :
|
||||
- L'exploration de dépôt utilise un modèle moins coûteux car il s'agit principalement de comprendre et de naviguer dans le code
|
||||
- La génération de code utilise GPT-4 avec une limite de jetons plus élevée pour générer des blocs de code plus importants
|
||||
- La configuration par défaut reste disponible pour les autres tâches
|
||||
|
||||
:::note
|
||||
Les configurations LLM personnalisées ne sont disponibles que lors de l'utilisation d'OpenHands en mode développement, via `main.py` ou `cli.py`. Lors de l'exécution via `docker run`, veuillez utiliser les options de configuration standard.
|
||||
:::
|
||||
@@ -140,7 +140,11 @@ The LLM (Large Language Model) configuration options are defined in the `[llm]`
|
||||
|
||||
To use these with the docker command, pass in `-e LLM_<option>`. Example: `-e LLM_NUM_RETRIES`.
|
||||
|
||||
### AWS Credentials
|
||||
:::note
|
||||
For development setups, you can also define custom named LLM configurations. See [Custom LLM Configurations](./llms/custom-llm-configs) for details.
|
||||
:::
|
||||
|
||||
**AWS Credentials**
|
||||
- `aws_access_key_id`
|
||||
- Type: `str`
|
||||
- Default: `""`
|
||||
|
||||
106
docs/modules/usage/llms/custom-llm-configs.md
Normal file
106
docs/modules/usage/llms/custom-llm-configs.md
Normal file
@@ -0,0 +1,106 @@
|
||||
# Custom LLM Configurations
|
||||
|
||||
OpenHands supports defining multiple named LLM configurations in your `config.toml` file. This feature allows you to use different LLM configurations for different purposes, such as using a cheaper model for tasks that don't require high-quality responses, or using different models with different parameters for specific agents.
|
||||
|
||||
## How It Works
|
||||
|
||||
Named LLM configurations are defined in the `config.toml` file using sections that start with `llm.`. For example:
|
||||
|
||||
```toml
|
||||
# Default LLM configuration
|
||||
[llm]
|
||||
model = "gpt-4"
|
||||
api_key = "your-api-key"
|
||||
temperature = 0.0
|
||||
|
||||
# Custom LLM configuration for a cheaper model
|
||||
[llm.gpt3]
|
||||
model = "gpt-3.5-turbo"
|
||||
api_key = "your-api-key"
|
||||
temperature = 0.2
|
||||
|
||||
# Another custom configuration with different parameters
|
||||
[llm.high-creativity]
|
||||
model = "gpt-4"
|
||||
api_key = "your-api-key"
|
||||
temperature = 0.8
|
||||
top_p = 0.9
|
||||
```
|
||||
|
||||
Each named configuration inherits all settings from the default `[llm]` section and can override any of those settings. You can define as many custom configurations as needed.
|
||||
|
||||
## Using Custom Configurations
|
||||
|
||||
### With Agents
|
||||
|
||||
You can specify which LLM configuration an agent should use by setting the `llm_config` parameter in the agent's configuration section:
|
||||
|
||||
```toml
|
||||
[agent.RepoExplorerAgent]
|
||||
# Use the cheaper GPT-3.5 configuration (the `llm.gpt3` section) for this agent
|
||||
llm_config = 'gpt3'
|
||||
|
||||
[agent.CodeWriterAgent]
|
||||
# Use the high creativity configuration for this agent
|
||||
llm_config = 'high-creativity'
|
||||
```
|
||||
|
||||
### Configuration Options
|
||||
|
||||
Each named LLM configuration supports all the same options as the default LLM configuration. These include:
|
||||
|
||||
- Model selection (`model`)
|
||||
- API configuration (`api_key`, `base_url`, etc.)
|
||||
- Model parameters (`temperature`, `top_p`, etc.)
|
||||
- Retry settings (`num_retries`, `retry_multiplier`, etc.)
|
||||
- Token limits (`max_input_tokens`, `max_output_tokens`)
|
||||
- And all other LLM configuration options
|
||||
|
||||
For a complete list of available options, see the LLM Configuration section in the [Configuration Options](../configuration-options) documentation.
|
||||
|
||||
## Use Cases
|
||||
|
||||
Custom LLM configurations are particularly useful in several scenarios:
|
||||
|
||||
- **Cost Optimization**: Use cheaper models for tasks that don't require high-quality responses, like repository exploration or simple file operations.
|
||||
- **Task-Specific Tuning**: Configure different temperature and top_p values for tasks that require different levels of creativity or determinism.
|
||||
- **Different Providers**: Use different LLM providers or API endpoints for different tasks.
|
||||
- **Testing and Development**: Easily switch between different model configurations during development and testing.
|
||||
|
||||
## Example: Cost Optimization
|
||||
|
||||
A practical example of using custom LLM configurations to optimize costs:
|
||||
|
||||
```toml
|
||||
# Default configuration using GPT-4 for high-quality responses
|
||||
[llm]
|
||||
model = "gpt-4"
|
||||
api_key = "your-api-key"
|
||||
temperature = 0.0
|
||||
|
||||
# Cheaper configuration for repository exploration
|
||||
[llm.repo-explorer]
|
||||
model = "gpt-3.5-turbo"
|
||||
temperature = 0.2
|
||||
|
||||
# Configuration for code generation
|
||||
[llm.code-gen]
|
||||
model = "gpt-4"
|
||||
temperature = 0.0
|
||||
max_output_tokens = 2000
|
||||
|
||||
[agent.RepoExplorerAgent]
|
||||
llm_config = 'repo-explorer'
|
||||
|
||||
[agent.CodeWriterAgent]
|
||||
llm_config = 'code-gen'
|
||||
```
|
||||
|
||||
In this example:
|
||||
- Repository exploration uses a cheaper model since it mainly involves understanding and navigating code
|
||||
- Code generation uses GPT-4 with a higher token limit for generating larger code blocks
|
||||
- The default configuration remains available for other tasks
|
||||
|
||||
:::note
|
||||
Custom LLM configurations are only available when using OpenHands in development mode, via `main.py` or `cli.py`. When running via `docker run`, please use the standard configuration options.
|
||||
:::
|
||||
@@ -8,13 +8,15 @@ from evaluation.integration_tests.tests.base import BaseIntegrationTest, TestRes
|
||||
from evaluation.utils.shared import (
|
||||
EvalMetadata,
|
||||
EvalOutput,
|
||||
codeact_user_response,
|
||||
make_metadata,
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
run_evaluation,
|
||||
update_llm_config_for_completions_logging,
|
||||
)
|
||||
from evaluation.utils.shared import (
|
||||
codeact_user_response as fake_user_response,
|
||||
)
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import (
|
||||
AgentConfig,
|
||||
@@ -31,7 +33,8 @@ from openhands.runtime.base import Runtime
|
||||
from openhands.utils.async_utils import call_async_from_sync
|
||||
|
||||
FAKE_RESPONSES = {
|
||||
'CodeActAgent': codeact_user_response,
|
||||
'CodeActAgent': fake_user_response,
|
||||
'DelegatorAgent': fake_user_response,
|
||||
}
|
||||
|
||||
|
||||
@@ -219,7 +222,7 @@ if __name__ == '__main__':
|
||||
|
||||
df = pd.read_json(output_file, lines=True, orient='records')
|
||||
|
||||
# record success and reason for failure for the final report
|
||||
# record success and reason
|
||||
df['success'] = df['test_result'].apply(lambda x: x['success'])
|
||||
df['reason'] = df['test_result'].apply(lambda x: x['reason'])
|
||||
logger.info('-' * 100)
|
||||
@@ -234,15 +237,27 @@ if __name__ == '__main__':
|
||||
logger.info('-' * 100)
|
||||
|
||||
# record cost for each instance, with 3 decimal places
|
||||
df['cost'] = df['metrics'].apply(lambda x: round(x['accumulated_cost'], 3))
|
||||
# we sum up all the "costs" from the metrics array
|
||||
df['cost'] = df['metrics'].apply(
|
||||
lambda m: round(sum(c['cost'] for c in m['costs']), 3)
|
||||
if m and 'costs' in m
|
||||
else 0.0
|
||||
)
|
||||
|
||||
# capture the top-level error if present, per instance
|
||||
df['error_message'] = df.get('error', None)
|
||||
|
||||
logger.info(f'Total cost: USD {df["cost"].sum():.2f}')
|
||||
|
||||
report_file = os.path.join(metadata.eval_output_dir, 'report.md')
|
||||
with open(report_file, 'w') as f:
|
||||
f.write(
|
||||
f'Success rate: {df["success"].mean():.2%} ({df["success"].sum()}/{len(df)})\n'
|
||||
f'Success rate: {df["success"].mean():.2%}'
|
||||
f' ({df["success"].sum()}/{len(df)})\n'
|
||||
)
|
||||
f.write(f'\nTotal cost: USD {df["cost"].sum():.2f}\n')
|
||||
f.write(
|
||||
df[['instance_id', 'success', 'reason', 'cost']].to_markdown(index=False)
|
||||
df[
|
||||
['instance_id', 'success', 'reason', 'cost', 'error_message']
|
||||
].to_markdown(index=False)
|
||||
)
|
||||
|
||||
@@ -7,8 +7,9 @@ MODEL_CONFIG=$1
|
||||
COMMIT_HASH=$2
|
||||
AGENT=$3
|
||||
EVAL_LIMIT=$4
|
||||
NUM_WORKERS=$5
|
||||
EVAL_IDS=$6
|
||||
MAX_ITERATIONS=$5
|
||||
NUM_WORKERS=$6
|
||||
EVAL_IDS=$7
|
||||
|
||||
if [ -z "$NUM_WORKERS" ]; then
|
||||
NUM_WORKERS=1
|
||||
@@ -43,7 +44,7 @@ fi
|
||||
COMMAND="poetry run python evaluation/integration_tests/run_infer.py \
|
||||
--agent-cls $AGENT \
|
||||
--llm-config $MODEL_CONFIG \
|
||||
--max-iterations 10 \
|
||||
--max-iterations ${MAX_ITERATIONS:-10} \
|
||||
--eval-num-workers $NUM_WORKERS \
|
||||
--eval-note $EVAL_NOTE"
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import { describe, it, expect, afterEach, vi } from "vitest";
|
||||
import * as router from "react-router";
|
||||
|
||||
// Mock useParams before importing components
|
||||
vi.mock("react-router", async () => {
|
||||
const actual = await vi.importActual("react-router");
|
||||
return {
|
||||
...actual as object,
|
||||
...(actual as object),
|
||||
useParams: () => ({ conversationId: "test-conversation-id" }),
|
||||
};
|
||||
});
|
||||
@@ -14,7 +13,7 @@ vi.mock("react-router", async () => {
|
||||
vi.mock("react-i18next", async () => {
|
||||
const actual = await vi.importActual("react-i18next");
|
||||
return {
|
||||
...actual as object,
|
||||
...(actual as object),
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => key,
|
||||
i18n: {
|
||||
@@ -28,7 +27,6 @@ import { screen } from "@testing-library/react";
|
||||
import { renderWithProviders } from "../../test-utils";
|
||||
import { BrowserPanel } from "#/components/features/browser/browser";
|
||||
|
||||
|
||||
describe("Browser", () => {
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
@@ -2,36 +2,42 @@ import { describe, expect, it } from "vitest";
|
||||
import { screen } from "@testing-library/react";
|
||||
import { renderWithProviders } from "test-utils";
|
||||
import { ExpandableMessage } from "#/components/features/chat/expandable-message";
|
||||
import { vi } from 'vitest';
|
||||
import { vi } from "vitest";
|
||||
|
||||
vi.mock('react-i18next', async () => {
|
||||
const actual = await vi.importActual('react-i18next');
|
||||
vi.mock("react-i18next", async () => {
|
||||
const actual = await vi.importActual("react-i18next");
|
||||
return {
|
||||
...actual,
|
||||
useTranslation: () => ({
|
||||
t: (key:string) => key,
|
||||
t: (key: string) => key,
|
||||
i18n: {
|
||||
changeLanguage: () => new Promise(() => {}),
|
||||
language: 'en',
|
||||
language: "en",
|
||||
exists: () => true,
|
||||
},
|
||||
}),
|
||||
}
|
||||
};
|
||||
});
|
||||
|
||||
describe("ExpandableMessage", () => {
|
||||
it("should render with neutral border for non-action messages", () => {
|
||||
renderWithProviders(<ExpandableMessage message="Hello" type="thought" />);
|
||||
const element = screen.getByText("Hello");
|
||||
const container = element.closest("div.flex.gap-2.items-center.justify-start");
|
||||
const container = element.closest(
|
||||
"div.flex.gap-2.items-center.justify-start",
|
||||
);
|
||||
expect(container).toHaveClass("border-neutral-300");
|
||||
expect(screen.queryByTestId("status-icon")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render with neutral border for error messages", () => {
|
||||
renderWithProviders(<ExpandableMessage message="Error occurred" type="error" />);
|
||||
renderWithProviders(
|
||||
<ExpandableMessage message="Error occurred" type="error" />,
|
||||
);
|
||||
const element = screen.getByText("Error occurred");
|
||||
const container = element.closest("div.flex.gap-2.items-center.justify-start");
|
||||
const container = element.closest(
|
||||
"div.flex.gap-2.items-center.justify-start",
|
||||
);
|
||||
expect(container).toHaveClass("border-danger");
|
||||
expect(screen.queryByTestId("status-icon")).not.toBeInTheDocument();
|
||||
});
|
||||
@@ -43,10 +49,12 @@ describe("ExpandableMessage", () => {
|
||||
message="Command executed successfully"
|
||||
type="action"
|
||||
success={true}
|
||||
/>
|
||||
/>,
|
||||
);
|
||||
const element = screen.getByText("OBSERVATION_MESSAGE$RUN");
|
||||
const container = element.closest("div.flex.gap-2.items-center.justify-start");
|
||||
const container = element.closest(
|
||||
"div.flex.gap-2.items-center.justify-start",
|
||||
);
|
||||
expect(container).toHaveClass("border-neutral-300");
|
||||
const icon = screen.getByTestId("status-icon");
|
||||
expect(icon).toHaveClass("fill-success");
|
||||
@@ -59,10 +67,12 @@ describe("ExpandableMessage", () => {
|
||||
message="Command failed"
|
||||
type="action"
|
||||
success={false}
|
||||
/>
|
||||
/>,
|
||||
);
|
||||
const element = screen.getByText("OBSERVATION_MESSAGE$RUN");
|
||||
const container = element.closest("div.flex.gap-2.items-center.justify-start");
|
||||
const container = element.closest(
|
||||
"div.flex.gap-2.items-center.justify-start",
|
||||
);
|
||||
expect(container).toHaveClass("border-neutral-300");
|
||||
const icon = screen.getByTestId("status-icon");
|
||||
expect(icon).toHaveClass("fill-danger");
|
||||
@@ -74,10 +84,12 @@ describe("ExpandableMessage", () => {
|
||||
id="OBSERVATION_MESSAGE$RUN"
|
||||
message="Running command"
|
||||
type="action"
|
||||
/>
|
||||
/>,
|
||||
);
|
||||
const element = screen.getByText("OBSERVATION_MESSAGE$RUN");
|
||||
const container = element.closest("div.flex.gap-2.items-center.justify-start");
|
||||
const container = element.closest(
|
||||
"div.flex.gap-2.items-center.justify-start",
|
||||
);
|
||||
expect(container).toHaveClass("border-neutral-300");
|
||||
expect(screen.queryByTestId("status-icon")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
@@ -128,7 +128,7 @@ describe("Sidebar", () => {
|
||||
await user.click(norskOption);
|
||||
|
||||
const tokenInput =
|
||||
within(accountSettingsModal).getByLabelText(/GITHUB\$TOKEN_OPTIONAL/i);
|
||||
within(accountSettingsModal).getByLabelText(/GITHUB\$TOKEN_LABEL/i);
|
||||
await user.type(tokenInput, "new-token");
|
||||
|
||||
const saveButton =
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import * as router from "react-router";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
|
||||
// Mock useParams before importing components
|
||||
vi.mock("react-router", async () => {
|
||||
const actual = await vi.importActual("react-router");
|
||||
return {
|
||||
...actual as object,
|
||||
...(actual as object),
|
||||
useParams: () => ({ conversationId: "test-conversation-id" }),
|
||||
};
|
||||
});
|
||||
@@ -60,7 +59,9 @@ describe("FeedbackForm", () => {
|
||||
renderWithProviders(
|
||||
<FeedbackForm polarity="positive" onClose={onCloseMock} />,
|
||||
);
|
||||
await user.click(screen.getByRole("button", { name: I18nKey.FEEDBACK$CANCEL_LABEL }));
|
||||
await user.click(
|
||||
screen.getByRole("button", { name: I18nKey.FEEDBACK$CANCEL_LABEL }),
|
||||
);
|
||||
|
||||
expect(onCloseMock).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { afterEach, beforeAll, describe, expect, it, vi } from "vitest";
|
||||
import * as router from "react-router";
|
||||
import { createRoutesStub } from "react-router";
|
||||
import { screen, waitFor, within } from "@testing-library/react";
|
||||
import { renderWithProviders } from "test-utils";
|
||||
|
||||
20
frontend/__tests__/utils/test-config.tsx
Normal file
20
frontend/__tests__/utils/test-config.tsx
Normal file
@@ -0,0 +1,20 @@
|
||||
import { vi } from "vitest";
|
||||
import OpenHands from "#/api/open-hands";
|
||||
|
||||
/**
 * Stub `OpenHands.getConfig` for tests that run in OSS mode.
 *
 * Installs a vitest spy on `OpenHands.getConfig` and makes it resolve to a
 * fixed configuration object with `APP_MODE: "oss"` and placeholder GitHub /
 * PostHog identifiers. Returns nothing.
 */
export const setupTestConfig = () => {
  vi.spyOn(OpenHands, "getConfig").mockResolvedValue({
    APP_MODE: "oss",
    GITHUB_CLIENT_ID: "test-id",
    POSTHOG_CLIENT_KEY: "test-key",
  });
};
|
||||
|
||||
/**
 * Stub `OpenHands.getConfig` for tests that run in SaaS mode.
 *
 * Installs a vitest spy on `OpenHands.getConfig` and makes it resolve to a
 * fixed configuration object with `APP_MODE: "saas"` and placeholder GitHub /
 * PostHog identifiers. Returns nothing.
 */
export const setupSaasTestConfig = () => {
  vi.spyOn(OpenHands, "getConfig").mockResolvedValue({
    APP_MODE: "saas",
    GITHUB_CLIENT_ID: "test-id",
    POSTHOG_CLIENT_KEY: "test-key",
  });
};
|
||||
@@ -41,6 +41,7 @@ export const isGitHubErrorReponse = <T extends object | Array<unknown>>(
|
||||
|
||||
// Axios interceptor to handle token refresh
|
||||
const setupAxiosInterceptors = (
|
||||
appMode: string,
|
||||
refreshToken: () => Promise<boolean>,
|
||||
logout: () => void,
|
||||
) => {
|
||||
@@ -74,18 +75,21 @@ const setupAxiosInterceptors = (
|
||||
!originalRequest._retry // Prevent infinite retry loops
|
||||
) {
|
||||
originalRequest._retry = true;
|
||||
try {
|
||||
const refreshed = await refreshToken();
|
||||
if (refreshed) {
|
||||
return await github(originalRequest);
|
||||
}
|
||||
|
||||
logout();
|
||||
return await Promise.reject(new Error("Failed to refresh token"));
|
||||
} catch (refreshError) {
|
||||
// If token refresh fails, evict the user
|
||||
logout();
|
||||
return Promise.reject(refreshError);
|
||||
if (appMode === "saas") {
|
||||
try {
|
||||
const refreshed = await refreshToken();
|
||||
if (refreshed) {
|
||||
return await github(originalRequest);
|
||||
}
|
||||
|
||||
logout();
|
||||
return await Promise.reject(new Error("Failed to refresh token"));
|
||||
} catch (refreshError) {
|
||||
// If token refresh fails, evict the user
|
||||
logout();
|
||||
return Promise.reject(refreshError);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import { LoadingSpinner } from "#/components/shared/loading-spinner";
|
||||
import { AccountSettingsModal } from "#/components/shared/modals/account-settings/account-settings-modal";
|
||||
import { SettingsModal } from "#/components/shared/modals/settings/settings-modal";
|
||||
import { useCurrentSettings } from "#/context/settings-context";
|
||||
import { useSettings } from "#/hooks/query/use-settings";
|
||||
import { ConversationPanel } from "../conversation-panel/conversation-panel";
|
||||
import { MULTI_CONVERSATION_UI } from "#/utils/feature-flags";
|
||||
import { useEndSession } from "#/hooks/use-end-session";
|
||||
@@ -27,7 +28,13 @@ export function Sidebar() {
|
||||
const user = useGitHubUser();
|
||||
const { data: isAuthed } = useIsAuthed();
|
||||
const { logout } = useAuth();
|
||||
const { isUpToDate: settingsAreUpToDate, settings } = useCurrentSettings();
|
||||
const {
|
||||
data: settings,
|
||||
isError: settingsIsError,
|
||||
isSuccess: settingsSuccessfulyFetched,
|
||||
} = useSettings();
|
||||
|
||||
const { isUpToDate: settingsAreUpToDate } = useCurrentSettings();
|
||||
|
||||
const [accountSettingsModalOpen, setAccountSettingsModalOpen] =
|
||||
React.useState(false);
|
||||
@@ -103,12 +110,13 @@ export function Sidebar() {
|
||||
{accountSettingsModalOpen && (
|
||||
<AccountSettingsModal onClose={handleAccountSettingsModalClose} />
|
||||
)}
|
||||
{showSettingsModal && settings && (
|
||||
<SettingsModal
|
||||
settings={settings}
|
||||
onClose={() => setSettingsModalIsOpen(false)}
|
||||
/>
|
||||
)}
|
||||
{settingsIsError ||
|
||||
(showSettingsModal && settingsSuccessfulyFetched && (
|
||||
<SettingsModal
|
||||
settings={settings}
|
||||
onClose={() => setSettingsModalIsOpen(false)}
|
||||
/>
|
||||
))}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -91,7 +91,7 @@ export function AccountSettingsForm({
|
||||
<>
|
||||
<CustomInput
|
||||
name="ghToken"
|
||||
label={t(I18nKey.GITHUB$TOKEN_OPTIONAL)}
|
||||
label={t(I18nKey.GITHUB$TOKEN_LABEL)}
|
||||
type="password"
|
||||
defaultValue={gitHubToken ?? ""}
|
||||
/>
|
||||
|
||||
@@ -26,7 +26,10 @@ export function RuntimeSizeSelector({
|
||||
id="runtime-size"
|
||||
name="runtime-size"
|
||||
defaultSelectedKeys={[String(defaultValue || 1)]}
|
||||
selectedKeys={[String(defaultValue || 1)]}
|
||||
isDisabled={isDisabled}
|
||||
selectionMode="single"
|
||||
disallowEmptySelection
|
||||
aria-label={t(I18nKey.SETTINGS_FORM$RUNTIME_SIZE_LABEL)}
|
||||
classNames={{
|
||||
trigger: "bg-[#27272A] rounded-md text-sm px-3 py-[10px]",
|
||||
|
||||
@@ -87,7 +87,12 @@ function AuthProvider({ children }: React.PropsWithChildren) {
|
||||
|
||||
setGitHubToken(storedGitHubToken);
|
||||
setUserId(userId);
|
||||
setupGithubAxiosInterceptors(refreshToken, logout);
|
||||
const setupIntercepter = async () => {
|
||||
const config = await OpenHands.getConfig();
|
||||
setupGithubAxiosInterceptors(config.APP_MODE, refreshToken, logout);
|
||||
};
|
||||
|
||||
setupIntercepter();
|
||||
}, []);
|
||||
|
||||
const value = React.useMemo(
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import React from "react";
|
||||
import { useDispatch, useSelector } from "react-redux";
|
||||
import { useAuth } from "#/context/auth-context";
|
||||
import {
|
||||
useWsClient,
|
||||
WsClientProviderStatus,
|
||||
@@ -14,17 +13,12 @@ import { AgentState } from "#/types/agent-state";
|
||||
|
||||
export const useWSStatusChange = () => {
|
||||
const { send, status } = useWsClient();
|
||||
const { gitHubToken } = useAuth();
|
||||
const { curAgentState } = useSelector((state: RootState) => state.agent);
|
||||
const dispatch = useDispatch();
|
||||
|
||||
const statusRef = React.useRef<WsClientProviderStatus | null>(null);
|
||||
|
||||
const { selectedRepository } = useSelector(
|
||||
(state: RootState) => state.initialQuery,
|
||||
);
|
||||
|
||||
const { files, importedProjectZip, initialQuery } = useSelector(
|
||||
const { files, initialQuery } = useSelector(
|
||||
(state: RootState) => state.initialQuery,
|
||||
);
|
||||
|
||||
@@ -33,30 +27,15 @@ export const useWSStatusChange = () => {
|
||||
send(createChatMessage(query, base64Files, timestamp));
|
||||
};
|
||||
|
||||
const dispatchInitialQuery = (query: string, additionalInfo: string) => {
|
||||
if (additionalInfo) {
|
||||
sendInitialQuery(`${query}\n\n[${additionalInfo}]`, files);
|
||||
} else {
|
||||
sendInitialQuery(query, files);
|
||||
}
|
||||
|
||||
const dispatchInitialQuery = (query: string) => {
|
||||
sendInitialQuery(query, files);
|
||||
dispatch(clearFiles()); // reset selected files
|
||||
dispatch(clearInitialQuery()); // reset initial query
|
||||
};
|
||||
|
||||
const handleAgentInit = () => {
|
||||
let additionalInfo = "";
|
||||
|
||||
if (gitHubToken && selectedRepository) {
|
||||
additionalInfo = `Repository ${selectedRepository} has been cloned to /workspace. Please check the /workspace for files.`;
|
||||
} else if (importedProjectZip) {
|
||||
// if there's an uploaded project zip, add it to the chat
|
||||
additionalInfo =
|
||||
"Files have been uploaded. Please check the /workspace for files.";
|
||||
}
|
||||
|
||||
if (initialQuery) {
|
||||
dispatchInitialQuery(initialQuery, additionalInfo);
|
||||
dispatchInitialQuery(initialQuery);
|
||||
}
|
||||
};
|
||||
React.useEffect(() => {
|
||||
|
||||
@@ -107,12 +107,7 @@ class CodeActAgent(Agent):
|
||||
f'TOOLS loaded for CodeActAgent: {json.dumps(self.tools, indent=2, ensure_ascii=False).replace("\\n", "\n")}'
|
||||
)
|
||||
self.prompt_manager = PromptManager(
|
||||
microagent_dir=os.path.join(
|
||||
os.path.dirname(os.path.dirname(openhands.__file__)),
|
||||
'microagents',
|
||||
)
|
||||
if self.config.use_microagents
|
||||
else None,
|
||||
microagent_dir=None, # Will be set in step() when we have access to the runtime
|
||||
prompt_dir=os.path.join(os.path.dirname(__file__), 'prompts'),
|
||||
disabled_microagents=self.config.disabled_microagents,
|
||||
)
|
||||
@@ -369,6 +364,14 @@ class CodeActAgent(Agent):
|
||||
- MessageAction(content) - Message action to run (e.g. ask for clarification)
|
||||
- AgentFinishAction() - end the interaction
|
||||
"""
|
||||
# Initialize the prompt_manager with microagents from the runtime
|
||||
if self.config.use_microagents and 'runtime' in state.inputs:
|
||||
# Load microagents from the runtime
|
||||
runtime = state.inputs['runtime']
|
||||
microagents = runtime.get_microagents_from_selected_repo(None) # None means current workspace
|
||||
|
||||
# Load the microagents into the prompt manager
|
||||
self.prompt_manager.load_microagents(microagents)
|
||||
# Continue with pending actions if any
|
||||
if self.pending_actions:
|
||||
return self.pending_actions.popleft()
|
||||
|
||||
@@ -5,9 +5,14 @@ You are OpenHands agent, a helpful AI assistant that can interact with a compute
|
||||
* The assistant MUST NOT include comments in the code unless they are necessary to describe non-obvious behavior.
|
||||
{{ runtime_info }}
|
||||
</IMPORTANT>
|
||||
{% if repo_instructions -%}
|
||||
{% if repository_info %}
|
||||
<REPOSITORY_INFO>
|
||||
At the user's request, repository {{ repository_info.repo_name }} has been cloned to directory {{ repository_info.repo_directory }}.
|
||||
</REPOSITORY_INFO>
|
||||
{% endif %}
|
||||
{% if repository_instructions -%}
|
||||
<REPOSITORY_INSTRUCTIONS>
|
||||
{{ repo_instructions }}
|
||||
{{ repository_instructions }}
|
||||
</REPOSITORY_INSTRUCTIONS>
|
||||
{% endif %}
|
||||
{% if runtime_info and runtime_info.available_hosts -%}
|
||||
|
||||
@@ -50,6 +50,10 @@ class MicroAgent(Agent):
|
||||
# history is in reverse order, let's fix it
|
||||
processed_history.reverse()
|
||||
|
||||
# everything starts with a message
|
||||
# the first message is already in the prompt as the task
|
||||
# TODO: so we don't need to include it in the history
|
||||
|
||||
return json.dumps(processed_history, **kwargs)
|
||||
|
||||
def __init__(self, llm: LLM, config: AgentConfig):
|
||||
|
||||
@@ -112,12 +112,16 @@ class AgentController:
|
||||
self.id = sid
|
||||
self.agent = agent
|
||||
self.headless_mode = headless_mode
|
||||
self.is_delegate = is_delegate
|
||||
|
||||
# subscribe to the event stream
|
||||
# the event stream must be set before maybe subscribing to it
|
||||
self.event_stream = event_stream
|
||||
self.event_stream.subscribe(
|
||||
EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, self.id
|
||||
)
|
||||
|
||||
# subscribe to the event stream if this is not a delegate
|
||||
if not self.is_delegate:
|
||||
self.event_stream.subscribe(
|
||||
EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, self.id
|
||||
)
|
||||
|
||||
# state from the previous session, state from a parent agent, or a fresh state
|
||||
self.set_initial_state(
|
||||
@@ -165,7 +169,11 @@ class AgentController:
|
||||
)
|
||||
|
||||
# unsubscribe from the event stream
|
||||
self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER, self.id)
|
||||
# only the root parent controller subscribes to the event stream
|
||||
if not self.is_delegate:
|
||||
self.event_stream.unsubscribe(
|
||||
EventStreamSubscriber.AGENT_CONTROLLER, self.id
|
||||
)
|
||||
self._closed = True
|
||||
|
||||
def log(self, level: str, message: str, extra: dict | None = None) -> None:
|
||||
@@ -226,9 +234,21 @@ class AgentController:
|
||||
await self._react_to_exception(reported)
|
||||
|
||||
def should_step(self, event: Event) -> bool:
|
||||
# it might be the delegate's day in the sun
|
||||
if self.delegate is not None:
|
||||
return False
|
||||
|
||||
if isinstance(event, Action):
|
||||
if isinstance(event, MessageAction) and event.source == EventSource.USER:
|
||||
return True
|
||||
if (
|
||||
isinstance(event, MessageAction)
|
||||
and self.get_agent_state() != AgentState.AWAITING_USER_INPUT
|
||||
):
|
||||
# TODO: this is fragile, but how else to check if eligible?
|
||||
return True
|
||||
if isinstance(event, AgentDelegateAction):
|
||||
return True
|
||||
return False
|
||||
if isinstance(event, Observation):
|
||||
if isinstance(event, NullObservation) or isinstance(
|
||||
@@ -244,12 +264,35 @@ class AgentController:
|
||||
Args:
|
||||
event (Event): The incoming event to process.
|
||||
"""
|
||||
|
||||
# If we have a delegate that is not finished or errored, forward events to it
|
||||
if self.delegate is not None:
|
||||
delegate_state = self.delegate.get_agent_state()
|
||||
if delegate_state not in (
|
||||
AgentState.FINISHED,
|
||||
AgentState.ERROR,
|
||||
AgentState.REJECTED,
|
||||
):
|
||||
# Forward the event to delegate and skip parent processing
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
self.delegate._on_event(event)
|
||||
)
|
||||
return
|
||||
else:
|
||||
# delegate is done or errored, so end it
|
||||
self.end_delegate()
|
||||
return
|
||||
|
||||
# continue parent processing only if there's no active delegate
|
||||
asyncio.get_event_loop().run_until_complete(self._on_event(event))
|
||||
|
||||
async def _on_event(self, event: Event) -> None:
|
||||
if hasattr(event, 'hidden') and event.hidden:
|
||||
return
|
||||
|
||||
# Give others a little chance
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# if the event is not filtered out, add it to the history
|
||||
if not any(isinstance(event, filter_type) for filter_type in self.filter_out):
|
||||
self.state.history.append(event)
|
||||
@@ -263,17 +306,22 @@ class AgentController:
|
||||
self.step()
|
||||
|
||||
async def _handle_action(self, action: Action) -> None:
|
||||
"""Handles actions from the event stream.
|
||||
|
||||
Args:
|
||||
action (Action): The action to handle.
|
||||
"""
|
||||
"""Handles an Action from the agent or delegate."""
|
||||
if isinstance(action, ChangeAgentStateAction):
|
||||
await self.set_agent_state_to(action.agent_state) # type: ignore
|
||||
elif isinstance(action, MessageAction):
|
||||
await self._handle_message_action(action)
|
||||
elif isinstance(action, AgentDelegateAction):
|
||||
await self.start_delegate(action)
|
||||
assert self.delegate is not None
|
||||
# Post a MessageAction with the task for the delegate
|
||||
if 'task' in action.inputs:
|
||||
self.event_stream.add_event(
|
||||
MessageAction(content='TASK: ' + action.inputs['task']),
|
||||
EventSource.USER,
|
||||
)
|
||||
await self.delegate.set_agent_state_to(AgentState.RUNNING)
|
||||
return
|
||||
|
||||
elif isinstance(action, AgentFinishAction):
|
||||
self.state.outputs = action.outputs
|
||||
@@ -491,7 +539,7 @@ class AgentController:
|
||||
f'start delegate, creating agent {delegate_agent.name} using LLM {llm}',
|
||||
)
|
||||
|
||||
self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER, self.id)
|
||||
# Create the delegate with is_delegate=True so it does NOT subscribe directly
|
||||
self.delegate = AgentController(
|
||||
sid=self.id + '-delegate',
|
||||
agent=delegate_agent,
|
||||
@@ -504,7 +552,57 @@ class AgentController:
|
||||
is_delegate=True,
|
||||
headless_mode=self.headless_mode,
|
||||
)
|
||||
await self.delegate.set_agent_state_to(AgentState.RUNNING)
|
||||
|
||||
def end_delegate(self) -> None:
|
||||
"""Ends the currently active delegate (e.g., if it is finished or errored)
|
||||
so that this controller can resume normal operation.
|
||||
"""
|
||||
if self.delegate is None:
|
||||
return
|
||||
|
||||
delegate_state = self.delegate.get_agent_state()
|
||||
|
||||
# update iteration that is shared across agents
|
||||
self.state.iteration = self.delegate.state.iteration
|
||||
|
||||
# close the delegate controller before adding new events
|
||||
asyncio.get_event_loop().run_until_complete(self.delegate.close())
|
||||
|
||||
if delegate_state in (AgentState.FINISHED, AgentState.REJECTED):
|
||||
# retrieve delegate result
|
||||
delegate_outputs = (
|
||||
self.delegate.state.outputs if self.delegate.state else {}
|
||||
)
|
||||
|
||||
# prepare delegate result observation
|
||||
# TODO: replace this with AI-generated summary (#2395)
|
||||
formatted_output = ', '.join(
|
||||
f'{key}: {value}' for key, value in delegate_outputs.items()
|
||||
)
|
||||
content = (
|
||||
f'{self.delegate.agent.name} finishes task with {formatted_output}'
|
||||
)
|
||||
|
||||
# emit the delegate result observation
|
||||
obs = AgentDelegateObservation(outputs=delegate_outputs, content=content)
|
||||
self.event_stream.add_event(obs, EventSource.AGENT)
|
||||
else:
|
||||
# delegate state is ERROR
|
||||
# emit AgentDelegateObservation with error content
|
||||
delegate_outputs = (
|
||||
self.delegate.state.outputs if self.delegate.state else {}
|
||||
)
|
||||
content = (
|
||||
f'{self.delegate.agent.name} encountered an error during execution.'
|
||||
)
|
||||
|
||||
# emit the delegate result observation
|
||||
obs = AgentDelegateObservation(outputs=delegate_outputs, content=content)
|
||||
self.event_stream.add_event(obs, EventSource.AGENT)
|
||||
|
||||
# unset delegate so parent can resume normal handling
|
||||
self.delegate = None
|
||||
self.delegateAction = None
|
||||
|
||||
async def _step(self) -> None:
|
||||
"""Executes a single step of the parent or delegate agent. Detects stuck agents and limits on the number of iterations and the task budget."""
|
||||
@@ -514,14 +612,6 @@ class AgentController:
|
||||
if self._pending_action:
|
||||
return
|
||||
|
||||
if self.delegate is not None:
|
||||
assert self.delegate != self
|
||||
# TODO this conditional will always be false, because the parent controllers are unsubscribed
|
||||
# remove if it's still useless when delegation is reworked
|
||||
if self.delegate.get_agent_state() != AgentState.PAUSED:
|
||||
await self._delegate_step()
|
||||
return
|
||||
|
||||
self.log(
|
||||
'info',
|
||||
f'LEVEL {self.state.delegate_level} LOCAL STEP {self.state.local_iteration} GLOBAL STEP {self.state.iteration}',
|
||||
@@ -611,68 +701,6 @@ class AgentController:
|
||||
log_level = 'info' if LOG_ALL_EVENTS else 'debug'
|
||||
self.log(log_level, str(action), extra={'msg_type': 'ACTION'})
|
||||
|
||||
async def _delegate_step(self) -> None:
|
||||
"""Executes a single step of the delegate agent."""
|
||||
await self.delegate._step() # type: ignore[union-attr]
|
||||
assert self.delegate is not None
|
||||
delegate_state = self.delegate.get_agent_state()
|
||||
self.log('debug', f'Delegate state: {delegate_state}')
|
||||
if delegate_state == AgentState.ERROR:
|
||||
# update iteration that shall be shared across agents
|
||||
self.state.iteration = self.delegate.state.iteration
|
||||
|
||||
# emit AgentDelegateObservation to mark delegate termination due to error
|
||||
delegate_outputs = (
|
||||
self.delegate.state.outputs if self.delegate.state else {}
|
||||
)
|
||||
content = (
|
||||
f'{self.delegate.agent.name} encountered an error during execution.'
|
||||
)
|
||||
obs = AgentDelegateObservation(outputs=delegate_outputs, content=content)
|
||||
self.event_stream.add_event(obs, EventSource.AGENT)
|
||||
|
||||
# close the delegate upon error
|
||||
await self.delegate.close()
|
||||
|
||||
# resubscribe parent when delegate is finished
|
||||
self.event_stream.subscribe(
|
||||
EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, self.id
|
||||
)
|
||||
self.delegate = None
|
||||
self.delegateAction = None
|
||||
|
||||
elif delegate_state in (AgentState.FINISHED, AgentState.REJECTED):
|
||||
self.log('debug', 'Delegate agent has finished execution')
|
||||
# retrieve delegate result
|
||||
outputs = self.delegate.state.outputs if self.delegate.state else {}
|
||||
|
||||
# update iteration that shall be shared across agents
|
||||
self.state.iteration = self.delegate.state.iteration
|
||||
|
||||
# close delegate controller: we must close the delegate controller before adding new events
|
||||
await self.delegate.close()
|
||||
|
||||
# resubscribe parent when delegate is finished
|
||||
self.event_stream.subscribe(
|
||||
EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, self.id
|
||||
)
|
||||
|
||||
# update delegate result observation
|
||||
# TODO: replace this with AI-generated summary (#2395)
|
||||
formatted_output = ', '.join(
|
||||
f'{key}: {value}' for key, value in outputs.items()
|
||||
)
|
||||
content = (
|
||||
f'{self.delegate.agent.name} finishes task with {formatted_output}'
|
||||
)
|
||||
obs = AgentDelegateObservation(outputs=outputs, content=content)
|
||||
|
||||
# clean up delegate status
|
||||
self.delegate = None
|
||||
self.delegateAction = None
|
||||
self.event_stream.add_event(obs, EventSource.AGENT)
|
||||
return
|
||||
|
||||
async def _handle_traffic_control(
|
||||
self, limit_type: str, current_value: float, max_value: float
|
||||
) -> bool:
|
||||
|
||||
@@ -138,8 +138,19 @@ class LLMConfig:
|
||||
This function is used to create an LLMConfig object from a dictionary,
|
||||
with the exception of the 'draft_editor' key, which is a nested LLMConfig object.
|
||||
"""
|
||||
args = {k: v for k, v in llm_config_dict.items() if not isinstance(v, dict)}
|
||||
if 'draft_editor' in llm_config_dict:
|
||||
draft_editor_config = LLMConfig(**llm_config_dict['draft_editor'])
|
||||
args['draft_editor'] = draft_editor_config
|
||||
# Keep None values to preserve defaults, filter out other dicts
|
||||
args = {
|
||||
k: v
|
||||
for k, v in llm_config_dict.items()
|
||||
if not isinstance(v, dict) or v is None
|
||||
}
|
||||
if (
|
||||
'draft_editor' in llm_config_dict
|
||||
and llm_config_dict['draft_editor'] is not None
|
||||
):
|
||||
if isinstance(llm_config_dict['draft_editor'], LLMConfig):
|
||||
args['draft_editor'] = llm_config_dict['draft_editor']
|
||||
else:
|
||||
draft_editor_config = LLMConfig(**llm_config_dict['draft_editor'])
|
||||
args['draft_editor'] = draft_editor_config
|
||||
return cls(**args)
|
||||
|
||||
@@ -41,7 +41,7 @@ class SandboxConfig:
|
||||
|
||||
remote_runtime_api_url: str = 'http://localhost:8000'
|
||||
local_runtime_url: str = 'http://localhost'
|
||||
keep_runtime_alive: bool = True
|
||||
keep_runtime_alive: bool = False
|
||||
rm_all_containers: bool = False
|
||||
api_key: str | None = None
|
||||
base_container_image: str = 'nikolaik/python-nodejs:python3.12-nodejs22' # default to nikolaik/python-nodejs:python3.12-nodejs22 for eventstream runtime
|
||||
@@ -60,7 +60,7 @@ class SandboxConfig:
|
||||
runtime_startup_env_vars: dict[str, str] = field(default_factory=dict)
|
||||
browsergym_eval_env: str | None = None
|
||||
platform: str | None = None
|
||||
close_delay: int = 900
|
||||
close_delay: int = 15
|
||||
remote_runtime_resource_factor: int = 1
|
||||
enable_gpu: bool = False
|
||||
docker_runtime_kwargs: str | None = None
|
||||
|
||||
@@ -144,15 +144,48 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'):
|
||||
logger.openhands_logger.debug(
|
||||
'Attempt to load default LLM config from config toml'
|
||||
)
|
||||
llm_config = LLMConfig.from_dict(value)
|
||||
cfg.set_llm_config(llm_config, 'llm')
|
||||
# TODO clean up draft_editor
|
||||
# Extract generic LLM fields, keeping draft_editor
|
||||
generic_llm_fields = {}
|
||||
for k, v in value.items():
|
||||
if not isinstance(v, dict) or k == 'draft_editor':
|
||||
generic_llm_fields[k] = v
|
||||
generic_llm_config = LLMConfig.from_dict(generic_llm_fields)
|
||||
cfg.set_llm_config(generic_llm_config, 'llm')
|
||||
|
||||
# Process custom named LLM configs
|
||||
for nested_key, nested_value in value.items():
|
||||
if isinstance(nested_value, dict):
|
||||
logger.openhands_logger.debug(
|
||||
f'Attempt to load group {nested_key} from config toml as llm config'
|
||||
f'Processing custom LLM config "{nested_key}":'
|
||||
)
|
||||
llm_config = LLMConfig.from_dict(nested_value)
|
||||
cfg.set_llm_config(llm_config, nested_key)
|
||||
# Apply generic LLM config with custom LLM overrides, e.g.
|
||||
# [llm]
|
||||
# model="..."
|
||||
# num_retries = 5
|
||||
# [llm.claude]
|
||||
# model="claude-3-5-sonnet"
|
||||
# results in num_retries APPLIED to claude-3-5-sonnet
|
||||
custom_fields = {}
|
||||
for k, v in nested_value.items():
|
||||
if not isinstance(v, dict) or k == 'draft_editor':
|
||||
custom_fields[k] = v
|
||||
merged_llm_dict = generic_llm_config.__dict__.copy()
|
||||
merged_llm_dict.update(custom_fields)
|
||||
# TODO clean up draft_editor
|
||||
# Handle draft_editor with fallback values:
|
||||
# - If draft_editor is "null", use None
|
||||
# - If draft_editor is in custom fields, use that value
|
||||
# - If draft_editor is not specified, fall back to generic config value
|
||||
if 'draft_editor' in custom_fields:
|
||||
if custom_fields['draft_editor'] == 'null':
|
||||
merged_llm_dict['draft_editor'] = None
|
||||
else:
|
||||
merged_llm_dict['draft_editor'] = (
|
||||
generic_llm_config.draft_editor
|
||||
)
|
||||
custom_llm_config = LLMConfig.from_dict(merged_llm_dict)
|
||||
cfg.set_llm_config(custom_llm_config, nested_key)
|
||||
elif key is not None and key.lower() == 'security':
|
||||
logger.openhands_logger.debug(
|
||||
'Attempt to load security config from config toml'
|
||||
@@ -458,7 +491,11 @@ def setup_config_from_args(args: argparse.Namespace) -> AppConfig:
|
||||
|
||||
# Override with command line arguments if provided
|
||||
if args.llm_config:
|
||||
llm_config = get_llm_config_arg(args.llm_config)
|
||||
# if we didn't already load it, get it from the toml file
|
||||
if args.llm_config not in config.llms:
|
||||
llm_config = get_llm_config_arg(args.llm_config)
|
||||
else:
|
||||
llm_config = config.llms[args.llm_config]
|
||||
if llm_config is None:
|
||||
raise ValueError(f'Invalid toml file, cannot read {args.llm_config}')
|
||||
config.set_llm_config(llm_config)
|
||||
|
||||
@@ -65,6 +65,7 @@ class EventStream:
|
||||
_queue: queue.Queue[Event]
|
||||
_queue_thread: threading.Thread
|
||||
_queue_loop: asyncio.AbstractEventLoop | None
|
||||
_thread_pools: dict[str, dict[str, ThreadPoolExecutor]]
|
||||
_thread_loops: dict[str, dict[str, asyncio.AbstractEventLoop]]
|
||||
|
||||
def __init__(self, sid: str, file_store: FileStore):
|
||||
@@ -72,8 +73,8 @@ class EventStream:
|
||||
self.file_store = file_store
|
||||
self._stop_flag = threading.Event()
|
||||
self._queue: queue.Queue[Event] = queue.Queue()
|
||||
self._thread_pools: dict[str, dict[str, ThreadPoolExecutor]] = {}
|
||||
self._thread_loops: dict[str, dict[str, asyncio.AbstractEventLoop]] = {}
|
||||
self._thread_pools = {}
|
||||
self._thread_loops = {}
|
||||
self._queue_loop = None
|
||||
self._queue_thread = threading.Thread(target=self._run_queue_loop)
|
||||
self._queue_thread.daemon = True
|
||||
@@ -257,7 +258,7 @@ class EventStream:
|
||||
def add_event(self, event: Event, source: EventSource):
|
||||
if hasattr(event, '_id') and event.id is not None:
|
||||
raise ValueError(
|
||||
'Event already has an ID. It was probably added back to the EventStream from inside a handler, trigging a loop.'
|
||||
f'Event already has an ID:{event.id}. It was probably added back to the EventStream from inside a handler, triggering a loop.'
|
||||
)
|
||||
with self._lock:
|
||||
event._id = self._cur_id # type: ignore [attr-defined]
|
||||
@@ -285,6 +286,8 @@ class EventStream:
|
||||
event = self._queue.get(timeout=0.1)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
# pass each event to each callback in order
|
||||
for key in sorted(self._subscribers.keys()):
|
||||
callbacks = self._subscribers[key]
|
||||
for callback_id in callbacks:
|
||||
|
||||
@@ -8,6 +8,7 @@ from pydantic import BaseModel
|
||||
from openhands.core.exceptions import (
|
||||
MicroAgentValidationError,
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.microagent.types import MicroAgentMetadata, MicroAgentType
|
||||
|
||||
|
||||
@@ -132,8 +133,10 @@ def load_microagents_from_dir(
|
||||
]:
|
||||
"""Load all microagents from the given directory.
|
||||
|
||||
Note, legacy repo instructions will not be loaded here.
|
||||
|
||||
Args:
|
||||
microagent_dir: Path to the microagents directory.
|
||||
microagent_dir: Path to the microagents directory (e.g. .openhands/microagents)
|
||||
|
||||
Returns:
|
||||
Tuple of (repo_agents, knowledge_agents, task_agents) dictionaries
|
||||
@@ -145,20 +148,24 @@ def load_microagents_from_dir(
|
||||
knowledge_agents = {}
|
||||
task_agents = {}
|
||||
|
||||
# Load all agents
|
||||
for file in microagent_dir.rglob('*.md'):
|
||||
# skip README.md
|
||||
if file.name == 'README.md':
|
||||
continue
|
||||
try:
|
||||
agent = BaseMicroAgent.load(file)
|
||||
if isinstance(agent, RepoMicroAgent):
|
||||
repo_agents[agent.name] = agent
|
||||
elif isinstance(agent, KnowledgeMicroAgent):
|
||||
knowledge_agents[agent.name] = agent
|
||||
elif isinstance(agent, TaskMicroAgent):
|
||||
task_agents[agent.name] = agent
|
||||
except Exception as e:
|
||||
raise ValueError(f'Error loading agent from {file}: {e}')
|
||||
# Load all agents from .openhands/microagents directory
|
||||
logger.debug(f'Loading agents from {microagent_dir}')
|
||||
if microagent_dir.exists():
|
||||
for file in microagent_dir.rglob('*.md'):
|
||||
logger.debug(f'Checking file {file}...')
|
||||
# skip README.md
|
||||
if file.name == 'README.md':
|
||||
continue
|
||||
try:
|
||||
agent = BaseMicroAgent.load(file)
|
||||
if isinstance(agent, RepoMicroAgent):
|
||||
repo_agents[agent.name] = agent
|
||||
elif isinstance(agent, KnowledgeMicroAgent):
|
||||
knowledge_agents[agent.name] = agent
|
||||
elif isinstance(agent, TaskMicroAgent):
|
||||
task_agents[agent.name] = agent
|
||||
logger.debug(f'Loaded agent {agent.name} from {file}')
|
||||
except Exception as e:
|
||||
raise ValueError(f'Error loading agent from {file}: {e}')
|
||||
|
||||
return repo_agents, knowledge_agents, task_agents
|
||||
|
||||
@@ -610,10 +610,14 @@ def parse_unified_diff(text):
|
||||
# - Start at line 1 in the old file and show 6 lines
|
||||
# - Start at line 1 in the new file and show 6 lines
|
||||
old = int(h.group(1)) # Starting line in old file
|
||||
old_len = int(h.group(2)) if len(h.group(2)) > 0 else 1 # Number of lines in old file
|
||||
old_len = (
|
||||
int(h.group(2)) if len(h.group(2)) > 0 else 1
|
||||
) # Number of lines in old file
|
||||
|
||||
new = int(h.group(3)) # Starting line in new file
|
||||
new_len = int(h.group(4)) if len(h.group(4)) > 0 else 1 # Number of lines in new file
|
||||
new_len = (
|
||||
int(h.group(4)) if len(h.group(4)) > 0 else 1
|
||||
) # Number of lines in new file
|
||||
|
||||
h = None
|
||||
break
|
||||
@@ -622,7 +626,9 @@ def parse_unified_diff(text):
|
||||
for n in hunk:
|
||||
# Each line in a unified diff starts with a space (context), + (addition), or - (deletion)
|
||||
# The first character is the kind, the rest is the line content
|
||||
kind = n[0] if len(n) > 0 else ' ' # Empty lines in the hunk are treated as context lines
|
||||
kind = (
|
||||
n[0] if len(n) > 0 else ' '
|
||||
) # Empty lines in the hunk are treated as context lines
|
||||
line = n[1:] if len(n) > 1 else ''
|
||||
|
||||
# Process the line based on its kind
|
||||
|
||||
@@ -4,10 +4,13 @@ import copy
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import string
|
||||
import tempfile
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
from zipfile import ZipFile
|
||||
|
||||
from requests.exceptions import ConnectionError
|
||||
|
||||
@@ -37,9 +40,7 @@ from openhands.events.observation import (
|
||||
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
|
||||
from openhands.microagent import (
|
||||
BaseMicroAgent,
|
||||
KnowledgeMicroAgent,
|
||||
RepoMicroAgent,
|
||||
TaskMicroAgent,
|
||||
load_microagents_from_dir,
|
||||
)
|
||||
from openhands.runtime.plugins import (
|
||||
JupyterRequirement,
|
||||
@@ -125,7 +126,7 @@ class Runtime(FileEditRuntimeMixin):
|
||||
def setup_initial_env(self) -> None:
|
||||
if self.attach_to_existing:
|
||||
return
|
||||
logger.debug(f'Adding env vars: {self.initial_env_vars}')
|
||||
logger.debug(f'Adding env vars: {self.initial_env_vars.keys()}')
|
||||
self.add_env_vars(self.initial_env_vars)
|
||||
if self.config.sandbox.runtime_startup_env_vars:
|
||||
self.add_env_vars(self.config.sandbox.runtime_startup_env_vars)
|
||||
@@ -172,7 +173,7 @@ class Runtime(FileEditRuntimeMixin):
|
||||
obs = self.run(CmdRunAction(cmd))
|
||||
if not isinstance(obs, CmdOutputObservation) or obs.exit_code != 0:
|
||||
raise RuntimeError(
|
||||
f'Failed to add env vars [{env_vars}] to environment: {obs.content}'
|
||||
f'Failed to add env vars [{env_vars.keys()}] to environment: {obs.content}'
|
||||
)
|
||||
|
||||
def on_event(self, event: Event) -> None:
|
||||
@@ -206,7 +207,7 @@ class Runtime(FileEditRuntimeMixin):
|
||||
source = event.source if event.source else EventSource.AGENT
|
||||
self.event_stream.add_event(observation, source) # type: ignore[arg-type]
|
||||
|
||||
def clone_repo(self, github_token: str, selected_repository: str):
|
||||
def clone_repo(self, github_token: str, selected_repository: str) -> str:
|
||||
if not github_token or not selected_repository:
|
||||
raise ValueError(
|
||||
'github_token and selected_repository must be provided to clone a repository'
|
||||
@@ -223,25 +224,42 @@ class Runtime(FileEditRuntimeMixin):
|
||||
)
|
||||
self.log('info', f'Cloning repo: {selected_repository}')
|
||||
self.run_action(action)
|
||||
return dir_name
|
||||
|
||||
def get_microagents_from_selected_repo(
|
||||
self, selected_repository: str | None
|
||||
) -> list[BaseMicroAgent]:
|
||||
"""Load microagents from the selected repository.
|
||||
If selected_repository is None, load microagents from the current workspace.
|
||||
|
||||
This is the main entry point for loading microagents.
|
||||
"""
|
||||
|
||||
loaded_microagents: list[BaseMicroAgent] = []
|
||||
dir_name = Path('.openhands') / 'microagents'
|
||||
workspace_root = Path(self.config.workspace_mount_path_in_sandbox)
|
||||
microagents_dir = workspace_root / '.openhands' / 'microagents'
|
||||
repo_root = None
|
||||
if selected_repository:
|
||||
dir_name = Path('/workspace') / selected_repository.split('/')[1] / dir_name
|
||||
repo_root = workspace_root / selected_repository.split('/')[1]
|
||||
microagents_dir = repo_root / '.openhands' / 'microagents'
|
||||
self.log(
|
||||
'info',
|
||||
f'Selected repo: {selected_repository}, loading microagents from {microagents_dir} (inside runtime)',
|
||||
)
|
||||
|
||||
# Legacy Repo Instructions
|
||||
# Check for legacy .openhands_instructions file
|
||||
obs = self.read(FileReadAction(path='.openhands_instructions'))
|
||||
if isinstance(obs, ErrorObservation):
|
||||
obs = self.read(
|
||||
FileReadAction(path=str(workspace_root / '.openhands_instructions'))
|
||||
)
|
||||
if isinstance(obs, ErrorObservation) and repo_root is not None:
|
||||
# If the instructions file is not found in the workspace root, try to load it from the repo root
|
||||
self.log(
|
||||
'debug',
|
||||
f'openhands_instructions not present, trying to load from {dir_name}',
|
||||
f'.openhands_instructions not present, trying to load from repository {microagents_dir=}',
|
||||
)
|
||||
obs = self.read(
|
||||
FileReadAction(path=str(dir_name / '.openhands_instructions'))
|
||||
FileReadAction(path=str(repo_root / '.openhands_instructions'))
|
||||
)
|
||||
|
||||
if isinstance(obs, FileReadObservation):
|
||||
@@ -252,44 +270,40 @@ class Runtime(FileEditRuntimeMixin):
|
||||
)
|
||||
)
|
||||
|
||||
# Check for local repository microagents
|
||||
files = self.list_files(str(dir_name))
|
||||
self.log('info', f'Found {len(files)} local microagents.')
|
||||
if 'repo.md' in files:
|
||||
obs = self.read(FileReadAction(path=str(dir_name / 'repo.md')))
|
||||
if isinstance(obs, FileReadObservation):
|
||||
self.log('info', 'repo.md microagent loaded.')
|
||||
loaded_microagents.append(
|
||||
RepoMicroAgent.load(
|
||||
path=str(dir_name / 'repo.md'), file_content=obs.content
|
||||
)
|
||||
)
|
||||
# Load microagents from directory
|
||||
files = self.list_files(str(microagents_dir))
|
||||
if files:
|
||||
self.log('info', f'Found {len(files)} files in microagents directory.')
|
||||
zip_path = self.copy_from(str(microagents_dir))
|
||||
microagent_folder = tempfile.mkdtemp()
|
||||
|
||||
if 'knowledge' in files:
|
||||
knowledge_dir = dir_name / 'knowledge'
|
||||
_knowledge_microagents_files = self.list_files(str(knowledge_dir))
|
||||
for fname in _knowledge_microagents_files:
|
||||
obs = self.read(FileReadAction(path=str(knowledge_dir / fname)))
|
||||
if isinstance(obs, FileReadObservation):
|
||||
self.log('info', f'knowledge/{fname} microagent loaded.')
|
||||
loaded_microagents.append(
|
||||
KnowledgeMicroAgent.load(
|
||||
path=str(knowledge_dir / fname), file_content=obs.content
|
||||
)
|
||||
)
|
||||
# Properly handle the zip file
|
||||
with ZipFile(zip_path, 'r') as zip_file:
|
||||
zip_file.extractall(microagent_folder)
|
||||
|
||||
# Add debug print of directory structure
|
||||
self.log('debug', 'Microagent folder structure:')
|
||||
for root, _, files in os.walk(microagent_folder):
|
||||
relative_path = os.path.relpath(root, microagent_folder)
|
||||
self.log('debug', f'Directory: {relative_path}/')
|
||||
for file in files:
|
||||
self.log('debug', f' File: {os.path.join(relative_path, file)}')
|
||||
|
||||
# Clean up the temporary zip file
|
||||
zip_path.unlink()
|
||||
# Load all microagents using the existing function
|
||||
repo_agents, knowledge_agents, task_agents = load_microagents_from_dir(
|
||||
microagent_folder
|
||||
)
|
||||
self.log(
|
||||
'info',
|
||||
f'Loaded {len(repo_agents)} repo agents, {len(knowledge_agents)} knowledge agents, and {len(task_agents)} task agents',
|
||||
)
|
||||
loaded_microagents.extend(repo_agents.values())
|
||||
loaded_microagents.extend(knowledge_agents.values())
|
||||
loaded_microagents.extend(task_agents.values())
|
||||
shutil.rmtree(microagent_folder)
|
||||
|
||||
if 'tasks' in files:
|
||||
tasks_dir = dir_name / 'tasks'
|
||||
_tasks_microagents_files = self.list_files(str(tasks_dir))
|
||||
for fname in _tasks_microagents_files:
|
||||
obs = self.read(FileReadAction(path=str(tasks_dir / fname)))
|
||||
if isinstance(obs, FileReadObservation):
|
||||
self.log('info', f'tasks/{fname} microagent loaded.')
|
||||
loaded_microagents.append(
|
||||
TaskMicroAgent.load(
|
||||
path=str(tasks_dir / fname), file_content=obs.content
|
||||
)
|
||||
)
|
||||
return loaded_microagents
|
||||
|
||||
def run_action(self, action: Action) -> Observation:
|
||||
|
||||
@@ -9,6 +9,7 @@ from openhands.core.exceptions import AgentRuntimeBuildError
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.runtime.builder import RuntimeBuilder
|
||||
from openhands.runtime.utils.request import send_request
|
||||
from openhands.utils.http_session import HttpSession
|
||||
from openhands.utils.shutdown_listener import (
|
||||
should_continue,
|
||||
sleep_if_should_continue,
|
||||
@@ -18,12 +19,10 @@ from openhands.utils.shutdown_listener import (
|
||||
class RemoteRuntimeBuilder(RuntimeBuilder):
|
||||
"""This class interacts with the remote Runtime API for building and managing container images."""
|
||||
|
||||
def __init__(
|
||||
self, api_url: str, api_key: str, session: requests.Session | None = None
|
||||
):
|
||||
def __init__(self, api_url: str, api_key: str, session: HttpSession | None = None):
|
||||
self.api_url = api_url
|
||||
self.api_key = api_key
|
||||
self.session = session or requests.Session()
|
||||
self.session = session or HttpSession()
|
||||
self.session.headers.update({'X-API-Key': self.api_key})
|
||||
|
||||
def build(
|
||||
|
||||
@@ -35,6 +35,7 @@ from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.runtime.plugins import PluginRequirement
|
||||
from openhands.runtime.utils.request import send_request
|
||||
from openhands.utils.http_session import HttpSession
|
||||
|
||||
|
||||
class ActionExecutionClient(Runtime):
|
||||
@@ -55,7 +56,7 @@ class ActionExecutionClient(Runtime):
|
||||
attach_to_existing: bool = False,
|
||||
headless_mode: bool = True,
|
||||
):
|
||||
self.session = requests.Session()
|
||||
self.session = HttpSession()
|
||||
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
|
||||
self._runtime_initialized: bool = False
|
||||
self._vscode_token: str | None = None # initial dummy value
|
||||
|
||||
@@ -229,6 +229,13 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
raise AgentRuntimeUnavailableError() from e
|
||||
|
||||
def _resume_runtime(self):
|
||||
"""
|
||||
1. Show status update that runtime is being started.
|
||||
2. Send the runtime API a /resume request
|
||||
3. Poll for the runtime to be ready
|
||||
4. Update env vars
|
||||
"""
|
||||
self.send_status_message('STATUS$STARTING_RUNTIME')
|
||||
with self._send_runtime_api_request(
|
||||
'POST',
|
||||
f'{self.config.sandbox.remote_runtime_api_url}/resume',
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any
|
||||
import requests
|
||||
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
|
||||
|
||||
from openhands.utils.http_session import HttpSession
|
||||
from openhands.utils.tenacity_stop import stop_if_should_exit
|
||||
|
||||
|
||||
@@ -34,7 +35,7 @@ def is_retryable_error(exception):
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
)
|
||||
def send_request(
|
||||
session: requests.Session,
|
||||
session: HttpSession,
|
||||
method: str,
|
||||
url: str,
|
||||
timeout: int = 10,
|
||||
@@ -48,11 +49,11 @@ def send_request(
|
||||
_json = response.json()
|
||||
except (requests.exceptions.JSONDecodeError, json.decoder.JSONDecodeError):
|
||||
_json = None
|
||||
finally:
|
||||
response.close()
|
||||
raise RequestHTTPError(
|
||||
e,
|
||||
response=e.response,
|
||||
detail=_json.get('detail') if _json is not None else None,
|
||||
) from e
|
||||
finally:
|
||||
response.close()
|
||||
return response
|
||||
|
||||
@@ -11,7 +11,6 @@ from fastapi import (
|
||||
import openhands.agenthub # noqa F401 (we import this to get the agents registered)
|
||||
from openhands.server.middleware import (
|
||||
AttachConversationMiddleware,
|
||||
GitHubTokenMiddleware,
|
||||
InMemoryRateLimiter,
|
||||
LocalhostCORSMiddleware,
|
||||
NoCacheMiddleware,
|
||||
@@ -45,7 +44,6 @@ app.add_middleware(
|
||||
allow_headers=['*'],
|
||||
)
|
||||
|
||||
app.add_middleware(GitHubTokenMiddleware)
|
||||
app.add_middleware(NoCacheMiddleware)
|
||||
app.add_middleware(
|
||||
RateLimitMiddleware, rate_limiter=InMemoryRateLimiter(requests=10, seconds=1)
|
||||
|
||||
@@ -5,8 +5,8 @@ from jwt.exceptions import InvalidTokenError
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
def get_user_id(request: Request) -> int:
|
||||
return getattr(request.state, 'github_user_id', 0)
|
||||
def get_user_id(request: Request) -> str | None:
|
||||
return getattr(request.state, 'github_user_id', None)
|
||||
|
||||
|
||||
def get_sid_from_token(token: str, jwt_secret: str) -> str:
|
||||
|
||||
@@ -166,16 +166,3 @@ class AttachConversationMiddleware(SessionMiddlewareInterface):
|
||||
await self._detach_session(request)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class GitHubTokenMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
if request.url.path.startswith('/api/github'):
|
||||
github_token = request.headers.get('X-GitHub-Token')
|
||||
if not github_token:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={'error': 'Missing X-GitHub-Token header'},
|
||||
)
|
||||
request.state.github_token = github_token
|
||||
return await call_next(request)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import requests
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from openhands.server.shared import openhands_config
|
||||
@@ -8,13 +8,23 @@ from openhands.utils.async_utils import call_sync_from_async
|
||||
app = APIRouter(prefix='/api/github')
|
||||
|
||||
|
||||
def require_github_token(request: Request):
|
||||
github_token = request.headers.get('X-GitHub-Token')
|
||||
if not github_token:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail='Missing X-GitHub-Token header',
|
||||
)
|
||||
return github_token
|
||||
|
||||
|
||||
@app.get('/repositories')
|
||||
async def get_github_repositories(
|
||||
request: Request,
|
||||
page: int = 1,
|
||||
per_page: int = 10,
|
||||
sort: str = 'pushed',
|
||||
installation_id: int | None = None,
|
||||
github_token: str = Depends(require_github_token),
|
||||
):
|
||||
openhands_config.verify_github_repo_list(installation_id)
|
||||
|
||||
@@ -33,7 +43,7 @@ async def get_github_repositories(
|
||||
params['sort'] = sort
|
||||
|
||||
# Set the authorization header with the GitHub token
|
||||
headers = generate_github_headers(request.state.github_token)
|
||||
headers = generate_github_headers(github_token)
|
||||
|
||||
# Fetch repositories from GitHub
|
||||
try:
|
||||
@@ -59,8 +69,8 @@ async def get_github_repositories(
|
||||
|
||||
|
||||
@app.get('/user')
|
||||
async def get_github_user(request: Request):
|
||||
headers = generate_github_headers(request.state.github_token)
|
||||
async def get_github_user(github_token: str = Depends(require_github_token)):
|
||||
headers = generate_github_headers(github_token)
|
||||
try:
|
||||
response = await call_sync_from_async(
|
||||
requests.get, 'https://api.github.com/user', headers=headers
|
||||
@@ -79,8 +89,10 @@ async def get_github_user(request: Request):
|
||||
|
||||
|
||||
@app.get('/installations')
|
||||
async def get_github_installation_ids(request: Request):
|
||||
headers = generate_github_headers(request.state.github_token)
|
||||
async def get_github_installation_ids(
|
||||
github_token: str = Depends(require_github_token),
|
||||
):
|
||||
headers = generate_github_headers(github_token)
|
||||
try:
|
||||
response = await call_sync_from_async(
|
||||
requests.get, 'https://api.github.com/user/installations', headers=headers
|
||||
@@ -102,13 +114,13 @@ async def get_github_installation_ids(request: Request):
|
||||
|
||||
@app.get('/search/repositories')
|
||||
async def search_github_repositories(
|
||||
request: Request,
|
||||
query: str,
|
||||
per_page: int = 5,
|
||||
sort: str = 'stars',
|
||||
order: str = 'desc',
|
||||
github_token: str = Depends(require_github_token),
|
||||
):
|
||||
headers = generate_github_headers(request.state.github_token)
|
||||
headers = generate_github_headers(github_token)
|
||||
params = {
|
||||
'q': query,
|
||||
'per_page': per_page,
|
||||
|
||||
@@ -130,7 +130,7 @@ async def search_conversations(
|
||||
for conversation in conversation_metadata_result_set.results
|
||||
if hasattr(conversation, 'created_at')
|
||||
)
|
||||
running_conversations = await session_manager.get_agent_loop_running(
|
||||
running_conversations = await session_manager.get_running_agent_loops(
|
||||
get_user_id(request), set(conversation_ids)
|
||||
)
|
||||
result = ConversationInfoResultSet(
|
||||
@@ -222,7 +222,7 @@ async def _get_conversation_info(
|
||||
|
||||
|
||||
def _create_conversation_update_callback(
|
||||
user_id: int, conversation_id: str
|
||||
user_id: str | None, conversation_id: str
|
||||
) -> Callable:
|
||||
def callback(*args, **kwargs):
|
||||
call_async_from_sync(
|
||||
@@ -235,7 +235,7 @@ def _create_conversation_update_callback(
|
||||
return callback
|
||||
|
||||
|
||||
async def _update_timestamp_for_conversation(user_id: int, conversation_id: str):
|
||||
async def _update_timestamp_for_conversation(user_id: str, conversation_id: str):
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
|
||||
conversation = await conversation_store.get_metadata(conversation_id)
|
||||
conversation.last_updated_at = datetime.now(timezone.utc)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
|
||||
from openhands.controller import AgentController
|
||||
@@ -16,10 +17,10 @@ from openhands.runtime import get_runtime_cls
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.security import SecurityAnalyzer, options
|
||||
from openhands.storage.files import FileStore
|
||||
from openhands.utils.async_utils import call_async_from_sync, call_sync_from_async
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.shutdown_listener import should_continue
|
||||
|
||||
WAIT_TIME_BEFORE_CLOSE = 300
|
||||
WAIT_TIME_BEFORE_CLOSE = 90
|
||||
WAIT_TIME_BEFORE_CLOSE_INTERVAL = 5
|
||||
|
||||
|
||||
@@ -36,7 +37,8 @@ class AgentSession:
|
||||
controller: AgentController | None = None
|
||||
runtime: Runtime | None = None
|
||||
security_analyzer: SecurityAnalyzer | None = None
|
||||
_initializing: bool = False
|
||||
_starting: bool = False
|
||||
_started_at: float = 0
|
||||
_closed: bool = False
|
||||
loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
@@ -88,7 +90,8 @@ class AgentSession:
|
||||
if self._closed:
|
||||
logger.warning('Session closed before starting')
|
||||
return
|
||||
self._initializing = True
|
||||
self._starting = True
|
||||
self._started_at = time.time()
|
||||
self._create_security_analyzer(config.security.security_analyzer)
|
||||
await self._create_runtime(
|
||||
runtime_name=runtime_name,
|
||||
@@ -109,24 +112,19 @@ class AgentSession:
|
||||
self.event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.INIT), EventSource.ENVIRONMENT
|
||||
)
|
||||
self._initializing = False
|
||||
self._starting = False
|
||||
|
||||
def close(self):
|
||||
async def close(self):
|
||||
"""Closes the Agent session"""
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
call_async_from_sync(self._close)
|
||||
|
||||
async def _close(self):
|
||||
seconds_waited = 0
|
||||
while self._initializing and should_continue():
|
||||
while self._starting and should_continue():
|
||||
logger.debug(
|
||||
f'Waiting for initialization to finish before closing session {self.sid}'
|
||||
)
|
||||
await asyncio.sleep(WAIT_TIME_BEFORE_CLOSE_INTERVAL)
|
||||
seconds_waited += WAIT_TIME_BEFORE_CLOSE_INTERVAL
|
||||
if seconds_waited > WAIT_TIME_BEFORE_CLOSE:
|
||||
if time.time() <= self._started_at + WAIT_TIME_BEFORE_CLOSE:
|
||||
logger.error(
|
||||
f'Waited too long for initialization to finish before closing session {self.sid}'
|
||||
)
|
||||
@@ -212,8 +210,9 @@ class AgentSession:
|
||||
)
|
||||
return
|
||||
|
||||
repo_directory = None
|
||||
if selected_repository:
|
||||
await call_sync_from_async(
|
||||
repo_directory = await call_sync_from_async(
|
||||
self.runtime.clone_repo, github_token, selected_repository
|
||||
)
|
||||
|
||||
@@ -223,6 +222,10 @@ class AgentSession:
|
||||
self.runtime.get_microagents_from_selected_repo, selected_repository
|
||||
)
|
||||
agent.prompt_manager.load_microagents(microagents)
|
||||
if selected_repository and repo_directory:
|
||||
agent.prompt_manager.set_repository_info(
|
||||
selected_repository, repo_directory
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f'Runtime initialized with plugins: {[plugin.name for plugin in self.runtime.plugins]}'
|
||||
@@ -306,3 +309,12 @@ class AgentSession:
|
||||
else:
|
||||
logger.debug('No events found, no state to restore')
|
||||
return restored_state
|
||||
|
||||
def get_state(self) -> AgentState | None:
|
||||
controller = self.controller
|
||||
if controller:
|
||||
return controller.state.agent_state
|
||||
if time.time() > self._started_at + WAIT_TIME_BEFORE_CLOSE:
|
||||
# If 5 minutes have elapsed and we still don't have a controller, something has gone wrong
|
||||
return AgentState.ERROR
|
||||
return None
|
||||
|
||||
@@ -2,6 +2,7 @@ import asyncio
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Generic, Iterable, TypeVar
|
||||
from uuid import uuid4
|
||||
|
||||
import socketio
|
||||
@@ -9,26 +10,28 @@ import socketio
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.core.exceptions import AgentRuntimeUnavailableError
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.stream import EventStream, session_exists
|
||||
from openhands.server.session.conversation import Conversation
|
||||
from openhands.server.session.session import ROOM_KEY, Session
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.storage.files import FileStore
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.async_utils import wait_all
|
||||
from openhands.utils.shutdown_listener import should_continue
|
||||
|
||||
_REDIS_POLL_TIMEOUT = 1.5
|
||||
_CHECK_ALIVE_INTERVAL = 15
|
||||
|
||||
_CLEANUP_INTERVAL = 15
|
||||
_CLEANUP_EXCEPTION_WAIT_TIME = 15
|
||||
MAX_RUNNING_CONVERSATIONS = 3
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
@dataclass
|
||||
class _SessionIsRunningCheck:
|
||||
request_id: str
|
||||
request_sids: list[str]
|
||||
running_sids: set[str] = field(default_factory=set)
|
||||
class _ClusterQuery(Generic[T]):
|
||||
query_id: str
|
||||
request_ids: set[str] | None
|
||||
result: T
|
||||
flag: asyncio.Event = field(default_factory=asyncio.Event)
|
||||
|
||||
|
||||
@@ -38,10 +41,10 @@ class SessionManager:
|
||||
config: AppConfig
|
||||
file_store: FileStore
|
||||
_local_agent_loops_by_sid: dict[str, Session] = field(default_factory=dict)
|
||||
local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict)
|
||||
_local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict)
|
||||
_last_alive_timestamps: dict[str, float] = field(default_factory=dict)
|
||||
_redis_listen_task: asyncio.Task | None = None
|
||||
_session_is_running_checks: dict[str, _SessionIsRunningCheck] = field(
|
||||
_running_sid_queries: dict[str, _ClusterQuery[set[str]]] = field(
|
||||
default_factory=dict
|
||||
)
|
||||
_active_conversations: dict[str, tuple[Conversation, int]] = field(
|
||||
@@ -52,7 +55,7 @@ class SessionManager:
|
||||
)
|
||||
_conversations_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||
_cleanup_task: asyncio.Task | None = None
|
||||
_has_remote_connections_flags: dict[str, asyncio.Event] = field(
|
||||
_connection_queries: dict[str, _ClusterQuery[dict[str, str]]] = field(
|
||||
default_factory=dict
|
||||
)
|
||||
|
||||
@@ -60,7 +63,7 @@ class SessionManager:
|
||||
redis_client = self._get_redis_client()
|
||||
if redis_client:
|
||||
self._redis_listen_task = asyncio.create_task(self._redis_subscribe())
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_detached_conversations())
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_stale())
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
@@ -82,7 +85,7 @@ class SessionManager:
|
||||
logger.debug('_redis_subscribe')
|
||||
redis_client = self._get_redis_client()
|
||||
pubsub = redis_client.pubsub()
|
||||
await pubsub.subscribe('oh_event')
|
||||
await pubsub.subscribe('session_msg')
|
||||
while should_continue():
|
||||
try:
|
||||
message = await pubsub.get_message(
|
||||
@@ -108,59 +111,71 @@ class SessionManager:
|
||||
session = self._local_agent_loops_by_sid.get(sid)
|
||||
if session:
|
||||
await session.dispatch(data['data'])
|
||||
elif message_type == 'is_session_running':
|
||||
elif message_type == 'running_agent_loops_query':
|
||||
# Another node in the cluster is asking if the current node is running the session given.
|
||||
request_id = data['request_id']
|
||||
sids = [
|
||||
sid for sid in data['sids'] if sid in self._local_agent_loops_by_sid
|
||||
]
|
||||
query_id = data['query_id']
|
||||
sids = self._get_running_agent_loops_locally(
|
||||
data.get('user_id'), data.get('filter_to_sids')
|
||||
)
|
||||
if sids:
|
||||
await self._get_redis_client().publish(
|
||||
'oh_event',
|
||||
'session_msg',
|
||||
json.dumps(
|
||||
{
|
||||
'request_id': request_id,
|
||||
'sids': sids,
|
||||
'message_type': 'session_is_running',
|
||||
'query_id': query_id,
|
||||
'sids': list(sids),
|
||||
'message_type': 'running_agent_loops_response',
|
||||
}
|
||||
),
|
||||
)
|
||||
elif message_type == 'session_is_running':
|
||||
request_id = data['request_id']
|
||||
elif message_type == 'running_agent_loops_response':
|
||||
query_id = data['query_id']
|
||||
for sid in data['sids']:
|
||||
self._last_alive_timestamps[sid] = time.time()
|
||||
check = self._session_is_running_checks.get(request_id)
|
||||
if check:
|
||||
check.running_sids.update(data['sids'])
|
||||
if len(check.request_sids) == len(check.running_sids):
|
||||
check.flag.set()
|
||||
elif message_type == 'has_remote_connections_query':
|
||||
running_query = self._running_sid_queries.get(query_id)
|
||||
if running_query:
|
||||
running_query.result.update(data['sids'])
|
||||
if running_query.request_ids is not None and len(
|
||||
running_query.request_ids
|
||||
) == len(running_query.result):
|
||||
running_query.flag.set()
|
||||
elif message_type == 'connections_query':
|
||||
# Another node in the cluster is asking if the current node is connected to a session
|
||||
sid = data['sid']
|
||||
required = sid in self.local_connection_id_to_session_id.values()
|
||||
if required:
|
||||
query_id = data['query_id']
|
||||
connections = self._get_connections_locally(
|
||||
data.get('user_id'), data.get('filter_to_sids')
|
||||
)
|
||||
if connections:
|
||||
await self._get_redis_client().publish(
|
||||
'oh_event',
|
||||
'session_msg',
|
||||
json.dumps(
|
||||
{'sid': sid, 'message_type': 'has_remote_connections_response'}
|
||||
{
|
||||
'query_id': query_id,
|
||||
'connections': connections,
|
||||
'message_type': 'connections_response',
|
||||
}
|
||||
),
|
||||
)
|
||||
elif message_type == 'has_remote_connections_response':
|
||||
sid = data['sid']
|
||||
flag = self._has_remote_connections_flags.get(sid)
|
||||
if flag:
|
||||
flag.set()
|
||||
elif message_type == 'connections_response':
|
||||
query_id = data['query_id']
|
||||
connection_query = self._connection_queries.get(query_id)
|
||||
if connection_query:
|
||||
connection_query.result.update(**data['connections'])
|
||||
if connection_query.request_ids is not None and len(
|
||||
connection_query.request_ids
|
||||
) == len(connection_query.result):
|
||||
connection_query.flag.set()
|
||||
elif message_type == 'close_session':
|
||||
sid = data['sid']
|
||||
if sid in self._local_agent_loops_by_sid:
|
||||
await self._on_close_session(sid)
|
||||
await self._close_session(sid)
|
||||
elif message_type == 'session_closing':
|
||||
# Session closing event - We only get this in the event of graceful shutdown,
|
||||
# which can't be guaranteed - nodes can simply vanish unexpectedly!
|
||||
sid = data['sid']
|
||||
logger.debug(f'session_closing:{sid}')
|
||||
# Create a list of items to process to avoid modifying dict during iteration
|
||||
items = list(self.local_connection_id_to_session_id.items())
|
||||
items = list(self._local_connection_id_to_session_id.items())
|
||||
for connection_id, local_sid in items:
|
||||
if sid == local_sid:
|
||||
logger.warning(
|
||||
@@ -204,11 +219,11 @@ class SessionManager:
|
||||
return c
|
||||
|
||||
async def join_conversation(
|
||||
self, sid: str, connection_id: str, settings: Settings, user_id: int | None
|
||||
self, sid: str, connection_id: str, settings: Settings, user_id: str | None
|
||||
):
|
||||
logger.info(f'join_conversation:{sid}:{connection_id}')
|
||||
await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
|
||||
self.local_connection_id_to_session_id[connection_id] = sid
|
||||
self._local_connection_id_to_session_id[connection_id] = sid
|
||||
event_stream = await self._get_event_stream(sid)
|
||||
if not event_stream:
|
||||
return await self.maybe_start_agent_loop(sid, settings, user_id)
|
||||
@@ -226,7 +241,7 @@ class SessionManager:
|
||||
self._active_conversations.pop(sid)
|
||||
self._detached_conversations[sid] = (conversation, time.time())
|
||||
|
||||
async def _cleanup_detached_conversations(self):
|
||||
async def _cleanup_stale(self):
|
||||
while should_continue():
|
||||
if self._get_redis_client():
|
||||
# Debug info for HA envs
|
||||
@@ -240,7 +255,7 @@ class SessionManager:
|
||||
f'Running agent loops: {len(self._local_agent_loops_by_sid)}'
|
||||
)
|
||||
logger.info(
|
||||
f'Local connections: {len(self.local_connection_id_to_session_id)}'
|
||||
f'Local connections: {len(self._local_connection_id_to_session_id)}'
|
||||
)
|
||||
try:
|
||||
async with self._conversations_lock:
|
||||
@@ -250,107 +265,196 @@ class SessionManager:
|
||||
await conversation.disconnect()
|
||||
self._detached_conversations.pop(sid, None)
|
||||
|
||||
close_threshold = time.time() - self.config.sandbox.close_delay
|
||||
running_loops = list(self._local_agent_loops_by_sid.items())
|
||||
running_loops.sort(key=lambda item: item[1].last_active_ts)
|
||||
sid_to_close: list[str] = []
|
||||
for sid, session in running_loops:
|
||||
state = session.agent_session.get_state()
|
||||
if session.last_active_ts < close_threshold and state not in [
|
||||
AgentState.RUNNING,
|
||||
None,
|
||||
]:
|
||||
sid_to_close.append(sid)
|
||||
|
||||
connections = self._get_connections_locally(
|
||||
filter_to_sids=set(sid_to_close)
|
||||
)
|
||||
connected_sids = {sid for _, sid in connections.items()}
|
||||
sid_to_close = [
|
||||
sid for sid in sid_to_close if sid not in connected_sids
|
||||
]
|
||||
|
||||
if sid_to_close:
|
||||
connections = await self._get_connections_remotely(
|
||||
filter_to_sids=set(sid_to_close)
|
||||
)
|
||||
connected_sids = {sid for _, sid in connections.items()}
|
||||
sid_to_close = [
|
||||
sid for sid in sid_to_close if sid not in connected_sids
|
||||
]
|
||||
|
||||
await wait_all(self._close_session(sid) for sid in sid_to_close)
|
||||
await asyncio.sleep(_CLEANUP_INTERVAL)
|
||||
except asyncio.CancelledError:
|
||||
async with self._conversations_lock:
|
||||
for conversation, _ in self._detached_conversations.values():
|
||||
await conversation.disconnect()
|
||||
self._detached_conversations.clear()
|
||||
await wait_all(
|
||||
self._close_session(sid) for sid in self._local_agent_loops_by_sid
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f'error_cleaning_detached_conversations: {str(e)}')
|
||||
await asyncio.sleep(_CLEANUP_EXCEPTION_WAIT_TIME)
|
||||
|
||||
async def get_agent_loop_running(self, user_id, sids: set[str]) -> set[str]:
|
||||
running_sids = set(sid for sid in sids if sid in self._local_agent_loops_by_sid)
|
||||
check_cluster_sids = [sid for sid in sids if sid not in running_sids]
|
||||
running_cluster_sids = await self.get_agent_loop_running_in_cluster(
|
||||
check_cluster_sids
|
||||
)
|
||||
running_sids.union(running_cluster_sids)
|
||||
return running_sids
|
||||
logger.warning(f'error_cleaning_stale: {str(e)}')
|
||||
await asyncio.sleep(_CLEANUP_INTERVAL)
|
||||
|
||||
async def is_agent_loop_running(self, sid: str) -> bool:
|
||||
if await self.is_agent_loop_running_locally(sid):
|
||||
return True
|
||||
if await self.is_agent_loop_running_in_cluster(sid):
|
||||
return True
|
||||
return False
|
||||
sids = await self.get_running_agent_loops(filter_to_sids={sid})
|
||||
return bool(sids)
|
||||
|
||||
async def is_agent_loop_running_locally(self, sid: str) -> bool:
|
||||
return sid in self._local_agent_loops_by_sid
|
||||
async def get_running_agent_loops(
|
||||
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
|
||||
) -> set[str]:
|
||||
"""Get the running session ids. If a user is supplied, then the results are limited to session ids for that user. If a set of filter_to_sids is supplied, then results are limited to these ids of interest."""
|
||||
sids = self._get_running_agent_loops_locally(user_id, filter_to_sids)
|
||||
remote_sids = await self._get_running_agent_loops_remotely(
|
||||
user_id, filter_to_sids
|
||||
)
|
||||
return sids.union(remote_sids)
|
||||
|
||||
async def is_agent_loop_running_in_cluster(self, sid: str) -> bool:
|
||||
running_sids = await self.get_agent_loop_running_in_cluster([sid])
|
||||
return bool(running_sids)
|
||||
def _get_running_agent_loops_locally(
|
||||
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
|
||||
) -> set[str]:
|
||||
items: Iterable[tuple[str, Session]] = self._local_agent_loops_by_sid.items()
|
||||
if filter_to_sids is not None:
|
||||
items = (item for item in items if item[0] in filter_to_sids)
|
||||
if user_id:
|
||||
items = (item for item in items if item[1].user_id == user_id)
|
||||
sids = {sid for sid, _ in items}
|
||||
return sids
|
||||
|
||||
async def get_agent_loop_running_in_cluster(self, sids: list[str]) -> set[str]:
|
||||
async def _get_running_agent_loops_remotely(
|
||||
self,
|
||||
user_id: str | None = None,
|
||||
filter_to_sids: set[str] | None = None,
|
||||
) -> set[str]:
|
||||
"""As the rest of the cluster if a session is running. Wait a for a short timeout for a reply"""
|
||||
redis_client = self._get_redis_client()
|
||||
if not redis_client:
|
||||
return set()
|
||||
|
||||
flag = asyncio.Event()
|
||||
request_id = str(uuid4())
|
||||
check = _SessionIsRunningCheck(request_id=request_id, request_sids=sids)
|
||||
self._session_is_running_checks[request_id] = check
|
||||
query_id = str(uuid4())
|
||||
query = _ClusterQuery[set[str]](
|
||||
query_id=query_id, request_ids=filter_to_sids, result=set()
|
||||
)
|
||||
self._running_sid_queries[query_id] = query
|
||||
try:
|
||||
logger.debug(f'publish:is_session_running:{sids}')
|
||||
await redis_client.publish(
|
||||
'oh_event',
|
||||
json.dumps(
|
||||
{
|
||||
'request_id': request_id,
|
||||
'sids': sids,
|
||||
'message_type': 'is_session_running',
|
||||
}
|
||||
),
|
||||
logger.debug(
|
||||
f'publish:_get_running_agent_loops_remotely_query:{user_id}:{filter_to_sids}'
|
||||
)
|
||||
data: dict = {
|
||||
'query_id': query_id,
|
||||
'message_type': 'running_agent_loops_query',
|
||||
}
|
||||
if user_id:
|
||||
data['user_id'] = user_id
|
||||
if filter_to_sids:
|
||||
data['filter_to_sids'] = list(filter_to_sids)
|
||||
await redis_client.publish('session_msg', json.dumps(data))
|
||||
async with asyncio.timeout(_REDIS_POLL_TIMEOUT):
|
||||
await flag.wait()
|
||||
|
||||
return check.running_sids
|
||||
return query.result
|
||||
except TimeoutError:
|
||||
# Nobody replied in time
|
||||
return check.running_sids
|
||||
return query.result
|
||||
finally:
|
||||
self._session_is_running_checks.pop(request_id, None)
|
||||
self._running_sid_queries.pop(query_id, None)
|
||||
|
||||
async def get_connections(
|
||||
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
|
||||
) -> dict[str, str]:
|
||||
connection_ids = self._get_connections_locally(user_id, filter_to_sids)
|
||||
remote_connection_ids = await self._get_connections_remotely(
|
||||
user_id, filter_to_sids
|
||||
)
|
||||
connection_ids.update(**remote_connection_ids)
|
||||
return connection_ids
|
||||
|
||||
def _get_connections_locally(
|
||||
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
|
||||
) -> dict[str, str]:
|
||||
connections = dict(**self._local_connection_id_to_session_id)
|
||||
if filter_to_sids is not None:
|
||||
connections = {
|
||||
connection_id: sid
|
||||
for connection_id, sid in connections.items()
|
||||
if sid in filter_to_sids
|
||||
}
|
||||
if user_id:
|
||||
for connection_id, sid in list(connections.items()):
|
||||
session = self._local_agent_loops_by_sid.get(sid)
|
||||
if not session or session.user_id != user_id:
|
||||
connections.pop(connection_id)
|
||||
return connections
|
||||
|
||||
async def _get_connections_remotely(
|
||||
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
|
||||
) -> dict[str, str]:
|
||||
redis_client = self._get_redis_client()
|
||||
if not redis_client:
|
||||
return {}
|
||||
|
||||
async def _has_remote_connections(self, sid: str) -> bool:
|
||||
"""As the rest of the cluster if they still want this session running. Wait a for a short timeout for a reply"""
|
||||
# Create a flag for the callback
|
||||
flag = asyncio.Event()
|
||||
self._has_remote_connections_flags[sid] = flag
|
||||
query_id = str(uuid4())
|
||||
query = _ClusterQuery[dict[str, str]](
|
||||
query_id=query_id, request_ids=filter_to_sids, result={}
|
||||
)
|
||||
self._connection_queries[query_id] = query
|
||||
try:
|
||||
await self._get_redis_client().publish(
|
||||
'oh_event',
|
||||
json.dumps(
|
||||
{
|
||||
'sid': sid,
|
||||
'message_type': 'has_remote_connections_query',
|
||||
}
|
||||
),
|
||||
logger.debug(
|
||||
f'publish:get_connections_remotely_query:{user_id}:{filter_to_sids}'
|
||||
)
|
||||
data: dict = {
|
||||
'query_id': query_id,
|
||||
'message_type': 'connections_query',
|
||||
}
|
||||
if user_id:
|
||||
data['user_id'] = user_id
|
||||
if filter_to_sids:
|
||||
data['filter_to_sids'] = list(filter_to_sids)
|
||||
await redis_client.publish('session_msg', json.dumps(data))
|
||||
async with asyncio.timeout(_REDIS_POLL_TIMEOUT):
|
||||
await flag.wait()
|
||||
|
||||
result = flag.is_set()
|
||||
return result
|
||||
return query.result
|
||||
except TimeoutError:
|
||||
# Nobody replied in time
|
||||
return False
|
||||
return query.result
|
||||
finally:
|
||||
self._has_remote_connections_flags.pop(sid, None)
|
||||
self._connection_queries.pop(query_id, None)
|
||||
|
||||
async def maybe_start_agent_loop(
|
||||
self, sid: str, settings: Settings, user_id: int | None
|
||||
self, sid: str, settings: Settings, user_id: str | None
|
||||
) -> EventStream:
|
||||
logger.info(f'maybe_start_agent_loop:{sid}')
|
||||
session: Session | None = None
|
||||
if not await self.is_agent_loop_running(sid):
|
||||
logger.info(f'start_agent_loop:{sid}')
|
||||
|
||||
response_ids = await self.get_running_agent_loops(user_id)
|
||||
if len(response_ids) >= MAX_RUNNING_CONVERSATIONS:
|
||||
logger.info('too_many_sessions_for:{user_id}')
|
||||
await self.close_session(next(iter(response_ids)))
|
||||
|
||||
session = Session(
|
||||
sid=sid, file_store=self.file_store, config=self.config, sio=self.sio
|
||||
sid=sid,
|
||||
file_store=self.file_store,
|
||||
config=self.config,
|
||||
sio=self.sio,
|
||||
user_id=user_id,
|
||||
)
|
||||
self._local_agent_loops_by_sid[sid] = session
|
||||
asyncio.create_task(session.initialize_agent(settings))
|
||||
@@ -359,7 +463,6 @@ class SessionManager:
|
||||
if not event_stream:
|
||||
logger.error(f'No event stream after starting agent loop: {sid}')
|
||||
raise RuntimeError(f'no_event_stream:{sid}')
|
||||
asyncio.create_task(self._cleanup_session_later(sid))
|
||||
return event_stream
|
||||
|
||||
async def _get_event_stream(self, sid: str) -> EventStream | None:
|
||||
@@ -369,7 +472,7 @@ class SessionManager:
|
||||
logger.info(f'found_local_agent_loop:{sid}')
|
||||
return session.agent_session.event_stream
|
||||
|
||||
if await self.is_agent_loop_running_in_cluster(sid):
|
||||
if await self._get_running_agent_loops_remotely(filter_to_sids={sid}):
|
||||
logger.info(f'found_remote_agent_loop:{sid}')
|
||||
return EventStream(sid, self.file_store)
|
||||
|
||||
@@ -377,7 +480,7 @@ class SessionManager:
|
||||
|
||||
async def send_to_event_stream(self, connection_id: str, data: dict):
|
||||
# If there is a local session running, send to that
|
||||
sid = self.local_connection_id_to_session_id.get(connection_id)
|
||||
sid = self._local_connection_id_to_session_id.get(connection_id)
|
||||
if not sid:
|
||||
raise RuntimeError(f'no_connected_session:{connection_id}')
|
||||
|
||||
@@ -393,11 +496,11 @@ class SessionManager:
|
||||
next_alive_check = last_alive_at + _CHECK_ALIVE_INTERVAL
|
||||
if (
|
||||
next_alive_check > time.time()
|
||||
or await self.is_agent_loop_running_in_cluster(sid)
|
||||
or await self._get_running_agent_loops_remotely(filter_to_sids={sid})
|
||||
):
|
||||
# Send the event to the other pod
|
||||
await redis_client.publish(
|
||||
'oh_event',
|
||||
'session_msg',
|
||||
json.dumps(
|
||||
{
|
||||
'sid': sid,
|
||||
@@ -411,75 +514,37 @@ class SessionManager:
|
||||
raise RuntimeError(f'no_connected_session:{connection_id}:{sid}')
|
||||
|
||||
async def disconnect_from_session(self, connection_id: str):
|
||||
sid = self.local_connection_id_to_session_id.pop(connection_id, None)
|
||||
sid = self._local_connection_id_to_session_id.pop(connection_id, None)
|
||||
logger.info(f'disconnect_from_session:{connection_id}:{sid}')
|
||||
if not sid:
|
||||
# This can occur if the init action was never run.
|
||||
logger.warning(f'disconnect_from_uninitialized_session:{connection_id}')
|
||||
return
|
||||
|
||||
if should_continue():
|
||||
asyncio.create_task(self._cleanup_session_later(sid))
|
||||
else:
|
||||
await self._on_close_session(sid)
|
||||
|
||||
async def _cleanup_session_later(self, sid: str):
|
||||
# Once there have been no connections to a session for a reasonable period, we close it
|
||||
try:
|
||||
await asyncio.sleep(self.config.sandbox.close_delay)
|
||||
finally:
|
||||
# If the sleep was cancelled, we still want to close these
|
||||
await self._cleanup_session(sid)
|
||||
|
||||
async def _cleanup_session(self, sid: str) -> bool:
|
||||
# Get local connections
|
||||
logger.info(f'_cleanup_session:{sid}')
|
||||
has_local_connections = next(
|
||||
(True for v in self.local_connection_id_to_session_id.values() if v == sid),
|
||||
False,
|
||||
)
|
||||
if has_local_connections:
|
||||
return False
|
||||
|
||||
# If no local connections, get connections through redis
|
||||
redis_client = self._get_redis_client()
|
||||
if redis_client and await self._has_remote_connections(sid):
|
||||
return False
|
||||
|
||||
# We alert the cluster in case they are interested
|
||||
if redis_client:
|
||||
await redis_client.publish(
|
||||
'oh_event',
|
||||
json.dumps({'sid': sid, 'message_type': 'session_closing'}),
|
||||
)
|
||||
|
||||
await self._on_close_session(sid)
|
||||
return True
|
||||
|
||||
async def close_session(self, sid: str):
|
||||
session = self._local_agent_loops_by_sid.get(sid)
|
||||
if session:
|
||||
await self._on_close_session(sid)
|
||||
await self._close_session(sid)
|
||||
|
||||
redis_client = self._get_redis_client()
|
||||
if redis_client:
|
||||
await redis_client.publish(
|
||||
'oh_event',
|
||||
'session_msg',
|
||||
json.dumps({'sid': sid, 'message_type': 'close_session'}),
|
||||
)
|
||||
|
||||
async def _on_close_session(self, sid: str):
|
||||
async def _close_session(self, sid: str):
|
||||
logger.info(f'_close_session:{sid}')
|
||||
|
||||
# Clear up local variables
|
||||
connection_ids_to_remove = list(
|
||||
connection_id
|
||||
for connection_id, conn_sid in self.local_connection_id_to_session_id.items()
|
||||
for connection_id, conn_sid in self._local_connection_id_to_session_id.items()
|
||||
if sid == conn_sid
|
||||
)
|
||||
logger.info(f'removing connections: {connection_ids_to_remove}')
|
||||
for connnnection_id in connection_ids_to_remove:
|
||||
self.local_connection_id_to_session_id.pop(connnnection_id, None)
|
||||
self._local_connection_id_to_session_id.pop(connnnection_id, None)
|
||||
|
||||
session = self._local_agent_loops_by_sid.pop(sid, None)
|
||||
if not session:
|
||||
@@ -488,12 +553,17 @@ class SessionManager:
|
||||
|
||||
logger.info(f'closing_session:{session.sid}')
|
||||
# We alert the cluster in case they are interested
|
||||
redis_client = self._get_redis_client()
|
||||
if redis_client:
|
||||
await redis_client.publish(
|
||||
'oh_event',
|
||||
json.dumps({'sid': session.sid, 'message_type': 'session_closing'}),
|
||||
try:
|
||||
redis_client = self._get_redis_client()
|
||||
if redis_client:
|
||||
await redis_client.publish(
|
||||
'session_msg',
|
||||
json.dumps({'sid': session.sid, 'message_type': 'session_closing'}),
|
||||
)
|
||||
except Exception:
|
||||
logger.info(
|
||||
'error_publishing_close_session_event', exc_info=True, stack_info=True
|
||||
)
|
||||
|
||||
await call_sync_from_async(session.close)
|
||||
await session.close()
|
||||
logger.info(f'closed_session:{session.sid}')
|
||||
|
||||
@@ -37,7 +37,7 @@ class Session:
|
||||
loop: asyncio.AbstractEventLoop
|
||||
config: AppConfig
|
||||
file_store: FileStore
|
||||
user_id: int | None
|
||||
user_id: str | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -45,7 +45,7 @@ class Session:
|
||||
config: AppConfig,
|
||||
file_store: FileStore,
|
||||
sio: socketio.AsyncServer | None,
|
||||
user_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
self.sid = sid
|
||||
self.sio = sio
|
||||
@@ -62,9 +62,17 @@ class Session:
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.user_id = user_id
|
||||
|
||||
def close(self):
|
||||
async def close(self):
|
||||
if self.sio:
|
||||
await self.sio.emit(
|
||||
'oh_event',
|
||||
event_to_dict(
|
||||
AgentStateChangedObservation('', AgentState.STOPPED.value)
|
||||
),
|
||||
to=ROOM_KEY.format(sid=self.sid),
|
||||
)
|
||||
self.is_alive = False
|
||||
self.agent_session.close()
|
||||
await self.agent_session.close()
|
||||
|
||||
async def initialize_agent(
|
||||
self,
|
||||
|
||||
@@ -41,6 +41,6 @@ class ConversationStore(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
async def get_instance(
|
||||
cls, config: AppConfig, user_id: int | None
|
||||
cls, config: AppConfig, user_id: str | None
|
||||
) -> ConversationStore:
|
||||
"""Get a store for the user represented by the token given"""
|
||||
|
||||
@@ -92,7 +92,7 @@ class FileConversationStore(ConversationStore):
|
||||
|
||||
@classmethod
|
||||
async def get_instance(
|
||||
cls, config: AppConfig, user_id: int | None
|
||||
cls, config: AppConfig, user_id: str | None
|
||||
) -> FileConversationStore:
|
||||
file_store = get_file_store(config.file_store, config.file_store_path)
|
||||
return FileConversationStore(file_store)
|
||||
|
||||
@@ -5,7 +5,7 @@ from datetime import datetime, timezone
|
||||
@dataclass
|
||||
class ConversationMetadata:
|
||||
conversation_id: str
|
||||
github_user_id: int
|
||||
github_user_id: str | None
|
||||
selected_repository: str | None
|
||||
title: str | None = None
|
||||
last_updated_at: datetime | None = None
|
||||
|
||||
@@ -31,7 +31,7 @@ class FileSettingsStore(SettingsStore):
|
||||
|
||||
@classmethod
|
||||
async def get_instance(
|
||||
cls, config: AppConfig, user_id: int | None
|
||||
cls, config: AppConfig, user_id: str | None
|
||||
) -> FileSettingsStore:
|
||||
file_store = get_file_store(config.file_store, config.file_store_path)
|
||||
return FileSettingsStore(file_store)
|
||||
|
||||
@@ -22,6 +22,6 @@ class SettingsStore(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
async def get_instance(
|
||||
cls, config: AppConfig, user_id: int | None
|
||||
cls, config: AppConfig, user_id: str | None
|
||||
) -> SettingsStore:
|
||||
"""Get a store for the user represented by the token given"""
|
||||
|
||||
24
openhands/utils/http_session.py
Normal file
24
openhands/utils/http_session.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
@dataclass
class HttpSession:
    """
    requests.Session is reusable after it has been closed. This behavior makes it
    likely to leak file descriptors (especially when combined with tenacity).
    We wrap the session to make it unusable after being closed.
    """

    # The wrapped session; set to None by close().
    session: requests.Session | None = field(default_factory=requests.Session)

    def __getattr__(self, name):
        """Forward attribute lookups that miss on the wrapper to the session.

        Raises:
            AttributeError: if the wrapper has no ``session`` attribute yet.
            ValueError: if the session has been closed.
        """
        # __getattr__ only runs after normal lookup has failed. If the missing
        # name is 'session' itself (copy/unpickle build instances without
        # running __init__), reading self.session below would re-enter this
        # method until RecursionError; fail like an ordinary missing attribute.
        if name == 'session':
            raise AttributeError(name)
        if self.session is None:
            raise ValueError('session_was_closed')
        return object.__getattribute__(self.session, name)

    def close(self):
        """Close the wrapped session, if any, and drop the reference to it."""
        if self.session is not None:
            self.session.close()
        self.session = None
|
||||
@@ -20,6 +20,14 @@ class RuntimeInfo:
|
||||
available_hosts: dict[str, int]
|
||||
|
||||
|
||||
@dataclass
|
||||
class RepositoryInfo:
|
||||
"""Information about a GitHub repository that has been cloned."""
|
||||
|
||||
repo_name: str | None = None
|
||||
repo_directory: str | None = None
|
||||
|
||||
|
||||
class PromptManager:
|
||||
"""
|
||||
Manages prompt templates and micro-agents for AI interactions.
|
||||
@@ -42,7 +50,7 @@ class PromptManager:
|
||||
):
|
||||
self.disabled_microagents: list[str] = disabled_microagents or []
|
||||
self.prompt_dir: str = prompt_dir
|
||||
|
||||
self.repository_info: RepositoryInfo | None = None
|
||||
self.system_template: Template = self._load_template('system_prompt')
|
||||
self.user_template: Template = self._load_template('user_prompt')
|
||||
self.runtime_info = RuntimeInfo(available_hosts={})
|
||||
@@ -80,9 +88,6 @@ class PromptManager:
|
||||
elif isinstance(microagent, RepoMicroAgent):
|
||||
self.repo_microagents[microagent.name] = microagent
|
||||
|
||||
def set_runtime_info(self, runtime: Runtime):
|
||||
self.runtime_info.available_hosts = runtime.web_hosts
|
||||
|
||||
def _load_template(self, template_name: str) -> Template:
|
||||
if self.prompt_dir is None:
|
||||
raise ValueError('Prompt directory is not set')
|
||||
@@ -102,10 +107,31 @@ class PromptManager:
|
||||
if repo_instructions:
|
||||
repo_instructions += '\n\n'
|
||||
repo_instructions += microagent.content
|
||||
|
||||
return self.system_template.render(
|
||||
runtime_info=self.runtime_info, repo_instructions=repo_instructions
|
||||
repository_instructions=repo_instructions,
|
||||
repository_info=self.repository_info,
|
||||
runtime_info=self.runtime_info,
|
||||
).strip()
|
||||
|
||||
def set_runtime_info(self, runtime: Runtime):
|
||||
self.runtime_info.available_hosts = runtime.web_hosts
|
||||
|
||||
def set_repository_info(
    self,
    repo_name: str,
    repo_directory: str,
) -> None:
    """Record which GitHub repository was cloned and where it lives.

    Replaces ``self.repository_info`` with a new ``RepositoryInfo``.

    Args:
        repo_name: The name of the GitHub repository (e.g. 'owner/repo')
        repo_directory: The directory where the repository has been cloned
    """
    info = RepositoryInfo(repo_name=repo_name, repo_directory=repo_directory)
    self.repository_info = info
|
||||
|
||||
def get_example_user_message(self) -> str:
|
||||
"""This is the initial user message provided to the agent
|
||||
before *actual* user instructions are provided.
|
||||
|
||||
36
poetry.lock
generated
36
poetry.lock
generated
@@ -552,17 +552,17 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "boto3"
|
||||
version = "1.35.97"
|
||||
version = "1.35.98"
|
||||
description = "The AWS SDK for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "boto3-1.35.97-py3-none-any.whl", hash = "sha256:8e49416216a6e3a62c2a0c44fba4dd2852c85472e7b702516605b1363867d220"},
|
||||
{file = "boto3-1.35.97.tar.gz", hash = "sha256:7d398f66a11e67777c189d1f58c0a75d9d60f98d0ee51b8817e828930bf19e4e"},
|
||||
{file = "boto3-1.35.98-py3-none-any.whl", hash = "sha256:d0224e1499d7189b47aa7f469d96522d98df6f5702fccb20a95a436582ebcd9d"},
|
||||
{file = "boto3-1.35.98.tar.gz", hash = "sha256:4b6274b4fe9d7113f978abea66a1f20c8a397c268c9d1b2a6c96b14a256da4a5"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
botocore = ">=1.35.97,<1.36.0"
|
||||
botocore = ">=1.35.98,<1.36.0"
|
||||
jmespath = ">=0.7.1,<2.0.0"
|
||||
s3transfer = ">=0.10.0,<0.11.0"
|
||||
|
||||
@@ -571,13 +571,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
|
||||
|
||||
[[package]]
|
||||
name = "botocore"
|
||||
version = "1.35.97"
|
||||
version = "1.35.98"
|
||||
description = "Low-level, data-driven core of boto 3."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "botocore-1.35.97-py3-none-any.whl", hash = "sha256:fed4f156b1a9b8ece53738f702ba5851b8c6216b4952de326547f349cc494f14"},
|
||||
{file = "botocore-1.35.97.tar.gz", hash = "sha256:88f2fab29192ffe2f2115d5bafbbd823ff4b6eb2774296e03ec8b5b0fe074f61"},
|
||||
{file = "botocore-1.35.98-py3-none-any.whl", hash = "sha256:4f1c0b687488663a774ad3a5e81a5f94fae1bcada2364cfdc48482c4dbf794d5"},
|
||||
{file = "botocore-1.35.98.tar.gz", hash = "sha256:d11742b3824bdeac3c89eeeaf5132351af41823bbcef8fc15e95c8250b1de09c"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -927,13 +927,13 @@ numpy = "*"
|
||||
|
||||
[[package]]
|
||||
name = "chromadb"
|
||||
version = "0.6.2"
|
||||
version = "0.6.3"
|
||||
description = "Chroma."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "chromadb-0.6.2-py3-none-any.whl", hash = "sha256:77a5e07097e36cdd49d8d2925d0c4d28291cabc9677787423d2cc7c426e8895b"},
|
||||
{file = "chromadb-0.6.2.tar.gz", hash = "sha256:e9e11f04d3850796711ee05dad4e918c75ec7b62ab9cbe7b4588b68a26aaea06"},
|
||||
{file = "chromadb-0.6.3-py3-none-any.whl", hash = "sha256:4851258489a3612b558488d98d09ae0fe0a28d5cad6bd1ba64b96fdc419dc0e5"},
|
||||
{file = "chromadb-0.6.3.tar.gz", hash = "sha256:c8f34c0b704b9108b04491480a36d42e894a960429f87c6516027b5481d59ed3"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -2161,13 +2161,13 @@ grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"]
|
||||
|
||||
[[package]]
|
||||
name = "google-api-python-client"
|
||||
version = "2.158.0"
|
||||
version = "2.159.0"
|
||||
description = "Google API Client Library for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "google_api_python_client-2.158.0-py2.py3-none-any.whl", hash = "sha256:36f8c8d2e79e50f76790ca5946d2f3f8333e210dc8539a6c88e0742416474ad2"},
|
||||
{file = "google_api_python_client-2.158.0.tar.gz", hash = "sha256:b6664597a9955e04977a62752e33fe44cb35c580e190c1cb08a041893172bd67"},
|
||||
{file = "google_api_python_client-2.159.0-py2.py3-none-any.whl", hash = "sha256:baef0bb631a60a0bd7c0bf12a5499e3a40cd4388484de7ee55c1950bf820a0cf"},
|
||||
{file = "google_api_python_client-2.159.0.tar.gz", hash = "sha256:55197f430f25c907394b44fa078545ffef89d33fd4dca501b7db9f0d8e224bd6"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -3707,13 +3707,13 @@ types-tqdm = "*"
|
||||
|
||||
[[package]]
|
||||
name = "litellm"
|
||||
version = "1.58.0"
|
||||
version = "1.58.1"
|
||||
description = "Library to easily interface with LLM API providers"
|
||||
optional = false
|
||||
python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8"
|
||||
files = [
|
||||
{file = "litellm-1.58.0-py3-none-any.whl", hash = "sha256:1fc07646f6419f1d7b7d06fe2f5c72b3e6e3407423b50cbf45b58dc9740e7a03"},
|
||||
{file = "litellm-1.58.0.tar.gz", hash = "sha256:db4512e987809e04e59d5b4240bdef3a4bfe575c332397f80cc6b4403c68120a"},
|
||||
{file = "litellm-1.58.1-py3-none-any.whl", hash = "sha256:eae311273dd7b8be9b1fc92f12c6ec521f86166effd7ae2cec2982dbb9a7dc2c"},
|
||||
{file = "litellm-1.58.1.tar.gz", hash = "sha256:c73dff605b830815088bdfcc6c42f380b5dc129a184b678646ce981305d11ac6"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -4583,12 +4583,12 @@ type = ["mypy (==1.11.2)"]
|
||||
|
||||
[[package]]
|
||||
name = "modal"
|
||||
version = "0.72.10"
|
||||
version = "0.72.11"
|
||||
description = "Python client library for Modal"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "modal-0.72.10-py3-none-any.whl", hash = "sha256:eb67516660e00a9dca07378fb251275a913720162e98410f46c79924cb7cd509"},
|
||||
{file = "modal-0.72.11-py3-none-any.whl", hash = "sha256:97428dfbc2cb2677eb0a17e6189fedf991296245d25e9700c2b9f8978853a2ae"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
||||
@@ -26,7 +26,7 @@ test_mount_path = ''
|
||||
project_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
sandbox_test_folder = '/openhands/workspace'
|
||||
sandbox_test_folder = '/workspace'
|
||||
|
||||
|
||||
def _get_runtime_sid(runtime: Runtime) -> str:
|
||||
@@ -233,9 +233,10 @@ def _load_runtime(
|
||||
if use_workspace:
|
||||
test_mount_path = os.path.join(config.workspace_base, 'rt')
|
||||
elif temp_dir is not None:
|
||||
test_mount_path = os.path.join(temp_dir, sid)
|
||||
test_mount_path = temp_dir
|
||||
else:
|
||||
test_mount_path = None
|
||||
config.workspace_base = test_mount_path
|
||||
config.workspace_mount_path = test_mount_path
|
||||
|
||||
# Mounting folder specific for this test inside the sandbox
|
||||
|
||||
@@ -210,7 +210,7 @@ done && echo "success"
|
||||
def test_cmd_run(temp_dir, runtime_cls, run_as_openhands):
|
||||
runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
|
||||
try:
|
||||
obs = _run_cmd_action(runtime, 'ls -l /openhands/workspace')
|
||||
obs = _run_cmd_action(runtime, 'ls -l /workspace')
|
||||
assert obs.exit_code == 0
|
||||
|
||||
obs = _run_cmd_action(runtime, 'ls -l')
|
||||
@@ -377,7 +377,7 @@ def test_copy_to_non_existent_directory(temp_dir, runtime_cls):
|
||||
def test_overwrite_existing_file(temp_dir, runtime_cls):
|
||||
runtime = _load_runtime(temp_dir, runtime_cls)
|
||||
try:
|
||||
sandbox_dir = '/openhands/workspace'
|
||||
sandbox_dir = '/workspace'
|
||||
|
||||
obs = _run_cmd_action(runtime, f'ls -alh {sandbox_dir}')
|
||||
assert obs.exit_code == 0
|
||||
|
||||
@@ -52,7 +52,7 @@ def test_simple_cmd_ipython_and_fileop(temp_dir, runtime_cls, run_as_openhands):
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.content.strip() == (
|
||||
'Hello, `World`!\n'
|
||||
'[Jupyter current working directory: /openhands/workspace]\n'
|
||||
'[Jupyter current working directory: /workspace]\n'
|
||||
'[Jupyter Python interpreter: /openhands/poetry/openhands-ai-5O4_aCHf-py3.12/bin/python]'
|
||||
)
|
||||
|
||||
@@ -73,7 +73,7 @@ def test_simple_cmd_ipython_and_fileop(temp_dir, runtime_cls, run_as_openhands):
|
||||
|
||||
assert obs.content == ''
|
||||
# event stream runtime will always use absolute path
|
||||
assert obs.path == '/openhands/workspace/hello.sh'
|
||||
assert obs.path == '/workspace/hello.sh'
|
||||
|
||||
# Test read file (file should exist)
|
||||
action_read = FileReadAction(path='hello.sh')
|
||||
@@ -85,7 +85,7 @@ def test_simple_cmd_ipython_and_fileop(temp_dir, runtime_cls, run_as_openhands):
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
assert obs.content == 'echo "Hello, World!"\n'
|
||||
assert obs.path == '/openhands/workspace/hello.sh'
|
||||
assert obs.path == '/workspace/hello.sh'
|
||||
|
||||
# clean up
|
||||
action = CmdRunAction(command='rm -rf hello.sh')
|
||||
@@ -188,7 +188,7 @@ def test_ipython_simple(temp_dir, runtime_cls):
|
||||
obs.content.strip()
|
||||
== (
|
||||
'1\n'
|
||||
'[Jupyter current working directory: /openhands/workspace]\n'
|
||||
'[Jupyter current working directory: /workspace]\n'
|
||||
'[Jupyter Python interpreter: /openhands/poetry/openhands-ai-5O4_aCHf-py3.12/bin/python]'
|
||||
).strip()
|
||||
)
|
||||
@@ -224,7 +224,7 @@ def test_ipython_package_install(temp_dir, runtime_cls, run_as_openhands):
|
||||
# import should not error out
|
||||
assert obs.content.strip() == (
|
||||
'[Code executed successfully with no output]\n'
|
||||
'[Jupyter current working directory: /openhands/workspace]\n'
|
||||
'[Jupyter current working directory: /workspace]\n'
|
||||
'[Jupyter Python interpreter: /openhands/poetry/openhands-ai-5O4_aCHf-py3.12/bin/python]'
|
||||
)
|
||||
|
||||
@@ -273,16 +273,16 @@ def test_ipython_file_editor_permissions_as_openhands(temp_dir, runtime_cls):
|
||||
# Try to use file editor in openhands sandbox directory - should work
|
||||
test_code = """
|
||||
# Create file
|
||||
print(file_editor(command='create', path='/openhands/workspace/test.txt', file_text='Line 1\\nLine 2\\nLine 3'))
|
||||
print(file_editor(command='create', path='/workspace/test.txt', file_text='Line 1\\nLine 2\\nLine 3'))
|
||||
|
||||
# View file
|
||||
print(file_editor(command='view', path='/openhands/workspace/test.txt'))
|
||||
print(file_editor(command='view', path='/workspace/test.txt'))
|
||||
|
||||
# Edit file
|
||||
print(file_editor(command='str_replace', path='/openhands/workspace/test.txt', old_str='Line 2', new_str='New Line 2'))
|
||||
print(file_editor(command='str_replace', path='/workspace/test.txt', old_str='Line 2', new_str='New Line 2'))
|
||||
|
||||
# Undo edit
|
||||
print(file_editor(command='undo_edit', path='/openhands/workspace/test.txt'))
|
||||
print(file_editor(command='undo_edit', path='/workspace/test.txt'))
|
||||
"""
|
||||
action = IPythonRunCellAction(code=test_code)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
@@ -297,7 +297,7 @@ print(file_editor(command='undo_edit', path='/openhands/workspace/test.txt'))
|
||||
assert 'undone successfully' in obs.content
|
||||
|
||||
# Clean up
|
||||
action = CmdRunAction(command='rm -f /openhands/workspace/test.txt')
|
||||
action = CmdRunAction(command='rm -f /workspace/test.txt')
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
@@ -314,7 +314,7 @@ print(file_editor(command='undo_edit', path='/openhands/workspace/test.txt'))
|
||||
|
||||
def test_file_read_and_edit_via_oh_aci(runtime_cls, run_as_openhands):
|
||||
runtime = _load_runtime(None, runtime_cls, run_as_openhands)
|
||||
sandbox_dir = '/openhands/workspace'
|
||||
sandbox_dir = '/workspace'
|
||||
|
||||
actions = [
|
||||
{
|
||||
|
||||
197
tests/runtime/test_microagent.py
Normal file
197
tests/runtime/test_microagent.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""Tests for microagent loading in runtime."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from conftest import (
|
||||
_close_test_runtime,
|
||||
_load_runtime,
|
||||
)
|
||||
|
||||
from openhands.microagent import KnowledgeMicroAgent, RepoMicroAgent, TaskMicroAgent
|
||||
|
||||
|
||||
def _create_test_microagents(test_dir: str):
|
||||
"""Create test microagent files in the given directory."""
|
||||
microagents_dir = Path(test_dir) / '.openhands' / 'microagents'
|
||||
microagents_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create test knowledge agent
|
||||
knowledge_dir = microagents_dir / 'knowledge'
|
||||
knowledge_dir.mkdir(exist_ok=True)
|
||||
knowledge_agent = """---
|
||||
name: test_knowledge_agent
|
||||
type: knowledge
|
||||
version: 1.0.0
|
||||
agent: CodeActAgent
|
||||
triggers:
|
||||
- test
|
||||
- pytest
|
||||
---
|
||||
|
||||
# Test Guidelines
|
||||
|
||||
Testing best practices and guidelines.
|
||||
"""
|
||||
(knowledge_dir / 'knowledge.md').write_text(knowledge_agent)
|
||||
|
||||
# Create test repo agent
|
||||
repo_agent = """---
|
||||
name: test_repo_agent
|
||||
type: repo
|
||||
version: 1.0.0
|
||||
agent: CodeActAgent
|
||||
---
|
||||
|
||||
# Test Repository Agent
|
||||
|
||||
Repository-specific test instructions.
|
||||
"""
|
||||
(microagents_dir / 'repo.md').write_text(repo_agent)
|
||||
|
||||
# Create test task agent in a nested directory
|
||||
task_dir = microagents_dir / 'tasks' / 'nested'
|
||||
task_dir.mkdir(parents=True, exist_ok=True)
|
||||
task_agent = """---
|
||||
name: test_task
|
||||
type: task
|
||||
version: 1.0.0
|
||||
agent: CodeActAgent
|
||||
---
|
||||
|
||||
# Test Task
|
||||
|
||||
Test task content
|
||||
"""
|
||||
(task_dir / 'task.md').write_text(task_agent)
|
||||
|
||||
# Create legacy repo instructions
|
||||
legacy_instructions = """# Legacy Instructions
|
||||
|
||||
These are legacy repository instructions.
|
||||
"""
|
||||
(Path(test_dir) / '.openhands_instructions').write_text(legacy_instructions)
|
||||
|
||||
|
||||
def test_load_microagents_with_trailing_slashes(
    temp_dir, runtime_cls, run_as_openhands
):
    """Load the full fixture set written directly into temp_dir.

    The runtime is queried with no selected repository (None) and must return
    one knowledge agent, two repo agents (repo.md and the legacy instructions
    file) and one task agent.
    """
    _create_test_microagents(temp_dir)
    runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
    try:
        loaded_agents = runtime.get_microagents_from_selected_repo(None)

        # Partition what came back by microagent class.
        def pick(agent_cls):
            return [a for a in loaded_agents if isinstance(a, agent_cls)]

        knowledge_agents = pick(KnowledgeMicroAgent)
        repo_agents = pick(RepoMicroAgent)
        task_agents = pick(TaskMicroAgent)

        # Knowledge agent: a single one, carrying both frontmatter triggers.
        assert len(knowledge_agents) == 1
        knowledge = knowledge_agents[0]
        assert knowledge.name == 'test_knowledge_agent'
        assert 'test' in knowledge.triggers
        assert 'pytest' in knowledge.triggers

        # Repo agents: repo.md + .openhands_instructions
        assert len(repo_agents) == 2
        repo_names = {a.name for a in repo_agents}
        assert 'test_repo_agent' in repo_names
        assert 'repo_legacy' in repo_names

        # Task agent: the one stored in the nested tasks directory.
        assert len(task_agents) == 1
        task = task_agents[0]
        assert task.name == 'test_task'
    finally:
        _close_test_runtime(runtime)
|
||||
|
||||
|
||||
def test_load_microagents_with_selected_repo(temp_dir, runtime_cls, run_as_openhands):
    """Load the full fixture set from an 'OpenHands' directory inside temp_dir.

    The fixtures live in temp_dir/OpenHands and the runtime is queried for the
    selected repository 'All-Hands-AI/OpenHands'.
    """
    # Lay the fixtures out like a cloned repository.
    repo_dir = Path(temp_dir) / 'OpenHands'
    repo_dir.mkdir(parents=True)
    _create_test_microagents(str(repo_dir))

    runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
    try:
        loaded_agents = runtime.get_microagents_from_selected_repo(
            'All-Hands-AI/OpenHands'
        )

        # Partition what came back by microagent class.
        def pick(agent_cls):
            return [a for a in loaded_agents if isinstance(a, agent_cls)]

        knowledge_agents = pick(KnowledgeMicroAgent)
        repo_agents = pick(RepoMicroAgent)
        task_agents = pick(TaskMicroAgent)

        # Knowledge agent: a single one, carrying both frontmatter triggers.
        assert len(knowledge_agents) == 1
        knowledge = knowledge_agents[0]
        assert knowledge.name == 'test_knowledge_agent'
        assert 'test' in knowledge.triggers
        assert 'pytest' in knowledge.triggers

        # Repo agents: repo.md + .openhands_instructions
        assert len(repo_agents) == 2
        repo_names = {a.name for a in repo_agents}
        assert 'test_repo_agent' in repo_names
        assert 'repo_legacy' in repo_names

        # Task agent: the one stored in the nested tasks directory.
        assert len(task_agents) == 1
        task = task_agents[0]
        assert task.name == 'test_task'
    finally:
        _close_test_runtime(runtime)
|
||||
|
||||
|
||||
def test_load_microagents_with_missing_files(temp_dir, runtime_cls, run_as_openhands):
    """Load microagents when only ``repo.md`` exists.

    A single ``.openhands/microagents/repo.md`` file is written under
    ``temp_dir``. With no repository selected, the runtime must return exactly
    one repo microagent and no knowledge or task microagents.
    """
    agents_root = Path(temp_dir) / '.openhands' / 'microagents'
    agents_root.mkdir(parents=True, exist_ok=True)

    # Front matter followed by a short markdown body; the only file created.
    repo_md = """---
name: test_repo_agent
type: repo
version: 1.0.0
agent: CodeActAgent
---

# Test Repository Agent

Repository-specific test instructions.
"""
    (agents_root / 'repo.md').write_text(repo_md)

    runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
    try:
        found = runtime.get_microagents_from_selected_repo(None)

        # Bucket the loaded microagents by their concrete class.
        by_kind = {
            kind: [m for m in found if isinstance(m, kind)]
            for kind in (KnowledgeMicroAgent, RepoMicroAgent, TaskMicroAgent)
        }

        assert len(by_kind[KnowledgeMicroAgent]) == 0
        assert len(by_kind[RepoMicroAgent]) == 1
        assert len(by_kind[TaskMicroAgent]) == 0

        only_repo_agent = by_kind[RepoMicroAgent][0]
        assert only_repo_agent.name == 'test_repo_agent'
    finally:
        # Always release the runtime, even when an assertion fails.
        _close_test_runtime(runtime)
|
||||
@@ -1,12 +1,10 @@
|
||||
import pytest
|
||||
from openhands.resolver.patching.apply import apply_diff
|
||||
from openhands.resolver.patching.exceptions import HunkApplyException
|
||||
from openhands.resolver.patching.patch import parse_diff, diffobj
|
||||
from openhands.resolver.patching.patch import diffobj, parse_diff
|
||||
|
||||
|
||||
def test_patch_apply_with_empty_lines():
|
||||
# The original file has no indentation and uses \n line endings
|
||||
original_content = "# PR Viewer\n\nThis React application allows you to view open pull requests from GitHub repositories in a GitHub organization. By default, it uses the All-Hands-AI organization.\n\n## Setup"
|
||||
original_content = '# PR Viewer\n\nThis React application allows you to view open pull requests from GitHub repositories in a GitHub organization. By default, it uses the All-Hands-AI organization.\n\n## Setup'
|
||||
|
||||
# The patch has spaces at the start of each line and uses \n line endings
|
||||
patch = """diff --git a/README.md b/README.md
|
||||
@@ -19,18 +17,20 @@ index b760a53..5071727 100644
|
||||
-This React application allows you to view open pull requests from GitHub repositories in a GitHub organization. By default, it uses the All-Hands-AI organization.
|
||||
+This React application was created by Graham Neubig and OpenHands. It allows you to view open pull requests from GitHub repositories in a GitHub organization. By default, it uses the All-Hands-AI organization."""
|
||||
|
||||
print("Original content lines:")
|
||||
print('Original content lines:')
|
||||
for i, line in enumerate(original_content.splitlines(), 1):
|
||||
print(f"{i}: {repr(line)}")
|
||||
print(f'{i}: {repr(line)}')
|
||||
|
||||
print("\nPatch lines:")
|
||||
print('\nPatch lines:')
|
||||
for i, line in enumerate(patch.splitlines(), 1):
|
||||
print(f"{i}: {repr(line)}")
|
||||
print(f'{i}: {repr(line)}')
|
||||
|
||||
changes = parse_diff(patch)
|
||||
print("\nParsed changes:")
|
||||
print('\nParsed changes:')
|
||||
for change in changes:
|
||||
print(f"Change(old={change.old}, new={change.new}, line={repr(change.line)}, hunk={change.hunk})")
|
||||
print(
|
||||
f'Change(old={change.old}, new={change.new}, line={repr(change.line)}, hunk={change.hunk})'
|
||||
)
|
||||
diff = diffobj(header=None, changes=changes, text=patch)
|
||||
|
||||
# Apply the patch
|
||||
@@ -38,10 +38,10 @@ index b760a53..5071727 100644
|
||||
|
||||
# The patch should be applied successfully
|
||||
expected_result = [
|
||||
"# PR Viewer",
|
||||
"",
|
||||
"This React application was created by Graham Neubig and OpenHands. It allows you to view open pull requests from GitHub repositories in a GitHub organization. By default, it uses the All-Hands-AI organization.",
|
||||
"",
|
||||
"## Setup"
|
||||
'# PR Viewer',
|
||||
'',
|
||||
'This React application was created by Graham Neubig and OpenHands. It allows you to view open pull requests from GitHub repositories in a GitHub organization. By default, it uses the All-Hands-AI organization.',
|
||||
'',
|
||||
'## Setup',
|
||||
]
|
||||
assert result == expected_result
|
||||
assert result == expected_result
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
@@ -130,7 +130,7 @@ async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_cal
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_with_fatal_error(mock_agent, mock_event_stream):
|
||||
async def test_run_controller_with_fatal_error():
|
||||
config = AppConfig()
|
||||
file_store = get_file_store(config.file_store, config.file_store_path)
|
||||
event_stream = EventStream(sid='test', file_store=file_store)
|
||||
@@ -239,55 +239,6 @@ async def test_run_controller_stop_with_stuck():
|
||||
assert state.last_error == 'AgentStuckInLoopError: Agent got stuck in a loop'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'delegate_state',
    [
        AgentState.RUNNING,
        AgentState.FINISHED,
        AgentState.ERROR,
        AgentState.REJECTED,
    ],
)
async def test_delegate_step_different_states(
    mock_agent, mock_event_stream, delegate_state
):
    """Drive ``_delegate_step`` once for each delegate state.

    A mocked delegate reporting ``delegate_state`` is attached to the
    controller. After one ``_delegate_step`` the delegate must have been
    stepped once. It is kept (iteration untouched, not closed) only while
    RUNNING; otherwise it is detached, closed, and its iteration count (5)
    shows up on the controller.
    """
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    # Stand-in delegate controller with just the attributes the test sets.
    delegate = AsyncMock()
    controller.delegate = delegate
    delegate.state.iteration = 5
    delegate.state.outputs = {'result': 'test'}
    delegate.agent.name = 'TestDelegate'
    # get_agent_state is synchronous, so it must be a plain Mock.
    delegate.get_agent_state = Mock(return_value=delegate_state)
    delegate._step = AsyncMock()
    delegate.close = AsyncMock()

    await controller._delegate_step()

    delegate._step.assert_called_once()

    if delegate_state != AgentState.RUNNING:
        # Terminal delegate: detached, iteration carried over, closed once.
        assert controller.delegate is None
        assert controller.state.iteration == 5
        delegate.close.assert_called_once()
    else:
        # Still running: delegate stays attached and nothing is carried over.
        assert controller.delegate is not None
        assert controller.state.iteration == 0
        delegate.close.assert_not_called()

    await controller.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_iterations_extension(mock_agent, mock_event_stream):
|
||||
# Test with headless_mode=False - should extend max_iterations
|
||||
|
||||
187
tests/unit/test_agent_delegation.py
Normal file
187
tests/unit/test_agent_delegation.py
Normal file
@@ -0,0 +1,187 @@
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.controller.agent_controller import AgentController
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.core.config.agent_config import AgentConfig
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events import EventSource, EventStream
|
||||
from openhands.events.action import (
|
||||
AgentDelegateAction,
|
||||
AgentFinishAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
|
||||
@pytest.fixture
def mock_event_stream():
    """Return an ``EventStream`` backed by an empty in-memory file store.

    The session id is ``test-<uuid4>``, so each fixture instance gets its own.
    """
    store = InMemoryFileStore({})
    session_id = f'test-{uuid4()}'
    return EventStream(sid=session_id, file_store=store)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_parent_agent():
    """Return a ``MagicMock`` specced on ``Agent`` and named ``'ParentAgent'``.

    The mock carries an ``LLM``-specced ``llm`` with real ``Metrics`` and
    ``LLMConfig`` objects, and a real ``AgentConfig``.
    """
    parent = MagicMock(spec=Agent)
    parent.name = 'ParentAgent'

    llm = MagicMock(spec=LLM)
    llm.metrics = Metrics()
    llm.config = LLMConfig()
    parent.llm = llm

    parent.config = AgentConfig()
    return parent
|
||||
|
||||
|
||||
@pytest.fixture
def mock_child_agent():
    """Return a ``MagicMock`` specced on ``Agent`` and named ``'ChildAgent'``.

    The mock carries an ``LLM``-specced ``llm`` with real ``Metrics`` and
    ``LLMConfig`` objects, and a real ``AgentConfig``.
    """
    child = MagicMock(spec=Agent)
    child.name = 'ChildAgent'

    llm = MagicMock(spec=LLM)
    llm.metrics = Metrics()
    llm.config = LLMConfig()
    child.llm = llm

    child.config = AgentConfig()
    return child
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_stream):
    """
    Test that when the parent agent delegates to a child, the parent's delegate
    is set, and once the child finishes, the parent is cleaned up properly.

    ``Agent.get_cls`` is replaced only for the duration of this test (via
    ``patch.object``) so the stub does not leak into other tests, and the
    parent controller is closed even when an assertion fails.
    """
    # Imported locally: the module-level imports do not bring in ``patch``.
    from unittest.mock import patch

    # Mock the agent class resolution so that AgentController can instantiate
    # mock_child_agent: whatever name is looked up, the returned "class" is a
    # factory that yields the mock child.
    child_factory = Mock(return_value=lambda llm, config: mock_child_agent)

    with patch.object(Agent, 'get_cls', child_factory):
        # Create parent controller
        parent_state = State(max_iterations=10)
        parent_controller = AgentController(
            agent=mock_parent_agent,
            event_stream=mock_event_stream,
            max_iterations=10,
            sid='parent',
            confirmation_mode=False,
            headless_mode=True,
            initial_state=parent_state,
        )

        try:
            # Setup a delegate action from the parent
            delegate_action = AgentDelegateAction(
                agent='ChildAgent', inputs={'test': True}
            )
            mock_parent_agent.step.return_value = delegate_action

            # Simulate a user message event to cause parent.step() to run
            message_action = MessageAction(content='please delegate now')
            message_action._source = EventSource.USER
            await parent_controller._on_event(message_action)

            # Give time for the async step() to execute
            await asyncio.sleep(1)

            # The parent should receive step() from that event
            # Verify that a delegate agent controller is created
            assert (
                parent_controller.delegate is not None
            ), "Parent's delegate controller was not set."

            # The parent's iteration should have incremented
            assert (
                parent_controller.state.iteration == 1
            ), 'Parent iteration should be incremented after step.'

            # Now simulate that the child increments local iteration and
            # finishes its subtask
            delegate_controller = parent_controller.delegate
            delegate_controller.state.iteration = 5  # child had some steps
            delegate_controller.state.outputs = {'delegate_result': 'done'}

            # The child is done, so we simulate it finishing:
            child_finish_action = AgentFinishAction()
            await delegate_controller._on_event(child_finish_action)
            await asyncio.sleep(0.5)

            # Now the parent's delegate is None
            assert (
                parent_controller.delegate is None
            ), 'Parent delegate should be None after child finishes.'

            # Parent's global iteration is updated from the child
            assert (
                parent_controller.state.iteration == 6
            ), "Parent iteration should be the child's iteration + 1 after child is done."
        finally:
            # Cleanup runs on both success and assertion failure.
            await parent_controller.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'delegate_state',
    [
        AgentState.RUNNING,
        AgentState.FINISHED,
        AgentState.ERROR,
        AgentState.REJECTED,
    ],
)
async def test_delegate_step_different_states(
    mock_parent_agent, mock_event_stream, delegate_state
):
    """Check that a delegate is kept or torn down according to its state.

    A mocked delegate reporting ``delegate_state`` is attached to the
    controller and a user message is pushed through ``controller.on_event``
    from a worker thread. While RUNNING the delegate stays attached and is
    not closed; in any other state it is detached, closed once, and its
    iteration count (5) appears on the controller.
    """
    controller = AgentController(
        agent=mock_parent_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    # Stand-in delegate controller with just the attributes the test sets.
    delegate = AsyncMock()
    controller.delegate = delegate
    delegate.state.iteration = 5
    delegate.state.outputs = {'result': 'test'}
    delegate.agent.name = 'TestDelegate'
    # get_agent_state is synchronous, so it must be a plain Mock.
    delegate.get_agent_state = Mock(return_value=delegate_state)
    delegate._step = AsyncMock()
    delegate.close = AsyncMock()

    def deliver_user_message():
        """
        Runs in the worker thread: install a fresh event loop there so that
        the run_until_complete() calls inside controller.on_event(...) find a
        valid loop, deliver one user message, then close that loop.
        """
        thread_loop = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(thread_loop)
            user_message = MessageAction(content='Test message')
            user_message._source = EventSource.USER
            controller.on_event(user_message)
        finally:
            thread_loop.close()

    running_loop = asyncio.get_running_loop()
    with ThreadPoolExecutor() as pool:
        # Await the worker so the assertions only run after on_event returned.
        await running_loop.run_in_executor(pool, deliver_user_message)

    if delegate_state != AgentState.RUNNING:
        # Terminal delegate: detached, iteration carried over, closed once.
        assert controller.delegate is None
        assert controller.state.iteration == 5
        delegate.close.assert_called_once()
    else:
        # Still running: delegate stays attached and nothing is carried over.
        assert controller.delegate is not None
        assert controller.state.iteration == 0
        delegate.close.assert_not_called()

    await controller.close()
|
||||
@@ -29,7 +29,7 @@ def _patch_store():
|
||||
'title': 'Some Conversation',
|
||||
'selected_repository': 'foobar',
|
||||
'conversation_id': 'some_conversation_id',
|
||||
'github_user_id': 12345,
|
||||
'github_user_id': '12345',
|
||||
'created_at': '2025-01-01T00:00:00',
|
||||
'last_updated_at': '2025-01-01T00:01:00',
|
||||
}
|
||||
|
||||
228
tests/unit/test_llm_config.py
Normal file
228
tests/unit/test_llm_config.py
Normal file
@@ -0,0 +1,228 @@
|
||||
import pathlib
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.core.config.utils import load_from_toml
|
||||
|
||||
|
||||
@pytest.fixture
def default_config(monkeypatch):
    """Yield a freshly constructed ``AppConfig`` for the requesting test.

    ``monkeypatch`` is requested but not used inside this fixture body.
    """
    config = AppConfig()
    yield config
|
||||
|
||||
|
||||
@pytest.fixture
def generic_llm_toml(tmp_path: pathlib.Path) -> str:
    """Write a TOML file with a generic ``[llm]`` section and three custom ones.

    Every custom section sets its own 'model' and 'api_key'. Only
    ``[llm.custom2]`` also sets 'num_retries', so the other sections can be
    used to test fallback to the generic values.

    Returns:
        The path of the written file, as a string.
    """
    config_text = """
[core]
workspace_base = "./workspace"

[llm]
model = "base-model"
api_key = "base-api-key"
embedding_model = "base-embedding"
num_retries = 3

[llm.custom1]
model = "custom-model-1"
api_key = "custom-api-key-1"
# 'num_retries' is not overridden and should fallback to the value from [llm]

[llm.custom2]
model = "custom-model-2"
api_key = "custom-api-key-2"
num_retries = 5 # Overridden value

[llm.custom3]
model = "custom-model-3"
api_key = "custom-api-key-3"
# No overrides for additional attributes
"""
    config_path = tmp_path / 'llm_config.toml'
    config_path.write_text(config_text)
    return str(config_path)
|
||||
|
||||
|
||||
def test_load_from_toml_llm_with_fallback(
    default_config: AppConfig, generic_llm_toml: str
) -> None:
    """Check that custom LLM sections fall back to ``[llm]`` for unset attributes.

    After loading the generic TOML fixture, each section is compared against
    its expected model, api_key, embedding_model and num_retries.
    """
    load_from_toml(default_config, generic_llm_toml)

    # section name -> (model, api_key, embedding_model, num_retries)
    expected_by_section = {
        # The generic section itself.
        'llm': ('base-model', 'base-api-key', 'base-embedding', 3),
        # custom1 sets no 'num_retries': value comes from [llm].
        'custom1': ('custom-model-1', 'custom-api-key-1', 'base-embedding', 3),
        # custom2 overrides 'num_retries'.
        'custom2': ('custom-model-2', 'custom-api-key-2', 'base-embedding', 5),
        # custom3 inherits everything except 'model' and 'api_key'.
        'custom3': ('custom-model-3', 'custom-api-key-3', 'base-embedding', 3),
    }

    for section, expected in expected_by_section.items():
        model, api_key, embedding_model, num_retries = expected
        llm_config = default_config.get_llm_config(section)
        assert llm_config.model == model
        assert llm_config.api_key == api_key
        assert llm_config.embedding_model == embedding_model
        assert llm_config.num_retries == num_retries
|
||||
|
||||
|
||||
def test_load_from_toml_llm_custom_overrides_all(
    default_config: AppConfig, tmp_path: pathlib.Path
) -> None:
    """Check that a custom LLM section can override every generic ``[llm]`` value.

    ``[llm.custom_full]`` sets model, api_key, embedding_model and
    num_retries; the generic section must keep its own values.
    """
    config_text = """
[core]
workspace_base = "./workspace"

[llm]
model = "base-model"
api_key = "base-api-key"
embedding_model = "base-embedding"
num_retries = 3

[llm.custom_full]
model = "full-custom-model"
api_key = "full-custom-api-key"
embedding_model = "full-custom-embedding"
num_retries = 10
"""
    config_path = tmp_path / 'full_override_llm.toml'
    config_path.write_text(config_text)

    load_from_toml(default_config, str(config_path))

    # section name -> (model, api_key, embedding_model, num_retries)
    expected_by_section = {
        # Generic values are unaffected by the custom section.
        'llm': ('base-model', 'base-api-key', 'base-embedding', 3),
        # Every attribute comes from [llm.custom_full].
        'custom_full': (
            'full-custom-model',
            'full-custom-api-key',
            'full-custom-embedding',
            10,
        ),
    }

    for section, expected in expected_by_section.items():
        model, api_key, embedding_model, num_retries = expected
        llm_config = default_config.get_llm_config(section)
        assert llm_config.model == model
        assert llm_config.api_key == api_key
        assert llm_config.embedding_model == embedding_model
        assert llm_config.num_retries == num_retries
|
||||
|
||||
|
||||
def test_load_from_toml_llm_custom_partial_override(
    default_config: AppConfig, generic_llm_toml: str
) -> None:
    """Check partial overrides: custom sections replace some values, inherit the rest.

    Uses the generic TOML fixture and inspects ``custom1`` and ``custom2``.
    """
    load_from_toml(default_config, generic_llm_toml)

    # section name -> (model, api_key, embedding_model, num_retries)
    expected_by_section = {
        # custom1 overrides 'model' and 'api_key'; 'num_retries' is from [llm].
        'custom1': ('custom-model-1', 'custom-api-key-1', 'base-embedding', 3),
        # custom2 overrides 'model', 'api_key' and 'num_retries'.
        'custom2': ('custom-model-2', 'custom-api-key-2', 'base-embedding', 5),
    }

    for section, expected in expected_by_section.items():
        model, api_key, embedding_model, num_retries = expected
        llm_config = default_config.get_llm_config(section)
        assert llm_config.model == model
        assert llm_config.api_key == api_key
        assert llm_config.embedding_model == embedding_model
        assert llm_config.num_retries == num_retries
|
||||
|
||||
|
||||
def test_load_from_toml_llm_custom_no_override(
    default_config: AppConfig, generic_llm_toml: str
) -> None:
    """Check that a custom section with no extra overrides inherits from ``[llm]``.

    ``custom3`` sets only 'model' and 'api_key'; 'embedding_model' and
    'num_retries' must match the generic section.
    """
    load_from_toml(default_config, generic_llm_toml)

    inherited = default_config.get_llm_config('custom3')

    # Own values.
    assert inherited.model == 'custom-model-3'
    assert inherited.api_key == 'custom-api-key-3'
    # Values taken from [llm].
    assert inherited.embedding_model == 'base-embedding'
    assert inherited.num_retries == 3
|
||||
|
||||
|
||||
def test_load_from_toml_llm_missing_generic(
    default_config: AppConfig, tmp_path: pathlib.Path
) -> None:
    """Check a custom LLM section when the file has no generic ``[llm]`` section.

    The custom section must keep its own 'model' and 'api_key'; the other
    checked attributes are expected to have the values ``'local'`` and ``8``.
    """
    config_text = """
[core]
workspace_base = "./workspace"

[llm.custom_only]
model = "custom-only-model"
api_key = "custom-only-api-key"
"""
    config_path = tmp_path / 'custom_only_llm.toml'
    config_path.write_text(config_text)

    load_from_toml(default_config, str(config_path))

    standalone = default_config.get_llm_config('custom_only')

    # Own values.
    assert standalone.model == 'custom-only-model'
    assert standalone.api_key == 'custom-only-api-key'
    # Nothing in the file sets these, so the default values are expected.
    assert standalone.embedding_model == 'local'
    assert standalone.num_retries == 8
|
||||
|
||||
|
||||
def test_load_from_toml_llm_invalid_config(
    default_config: AppConfig, tmp_path: pathlib.Path
) -> None:
    """Check that an invalid custom LLM section does not disturb the generic one.

    ``[llm.invalid_custom]`` contains only an unknown attribute. The generic
    section must load normally, and looking up ``'invalid_custom'`` must
    yield the generic values. No warning output is asserted here.
    """
    config_text = """
[core]
workspace_base = "./workspace"

[llm]
model = "base-model"
api_key = "base-api-key"
num_retries = 3

[llm.invalid_custom]
unknown_attr = "should_not_exist"
"""
    config_path = tmp_path / 'invalid_custom_llm.toml'
    config_path.write_text(config_text)

    load_from_toml(default_config, str(config_path))

    # The generic section is loaded as written.
    base = default_config.get_llm_config('llm')
    assert base.model == 'base-model'
    assert base.api_key == 'base-api-key'
    assert base.num_retries == 3

    # Looking up the invalid section yields the generic values.
    fallback = default_config.get_llm_config('invalid_custom')
    assert fallback.model == 'base-model'
    assert fallback.api_key == 'base-api-key'
    assert fallback.num_retries == 3  # matches the value set in [llm]
    assert fallback.embedding_model == 'local'  # not set in the file: default
|
||||
92
tests/unit/test_llm_draft_config.py
Normal file
92
tests/unit/test_llm_draft_config.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import pathlib
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.core.config.utils import load_from_toml
|
||||
|
||||
|
||||
@pytest.fixture
def draft_llm_toml(tmp_path: pathlib.Path) -> str:
    """Write a TOML file that exercises the ``draft_editor`` setting.

    The generic ``[llm]`` section defines a ``draft_editor`` table.
    ``custom1`` omits it, ``custom2`` supplies its own table, and ``custom3``
    sets it to the string ``"null"``.

    Returns:
        The path of the written file, as a string.
    """
    config_text = """
[core]
workspace_base = "./workspace"

[llm]
model = "base-model"
api_key = "base-api-key"
draft_editor = { model = "draft-model", api_key = "draft-api-key" }

[llm.custom1]
model = "custom-model-1"
api_key = "custom-api-key-1"
# Should use draft_editor from [llm] as fallback

[llm.custom2]
model = "custom-model-2"
api_key = "custom-api-key-2"
draft_editor = { model = "custom-draft", api_key = "custom-draft-key" }

[llm.custom3]
model = "custom-model-3"
api_key = "custom-api-key-3"
draft_editor = "null" # Explicitly set to null in TOML
"""
    config_path = tmp_path / 'llm_config.toml'
    config_path.write_text(config_text)
    return str(config_path)
|
||||
|
||||
|
||||
def test_draft_editor_fallback(draft_llm_toml):
    """Check how ``draft_editor`` is resolved for each LLM section.

    Covered cases:
    - before loading, the generic config has no draft editor
    - the generic ``[llm]`` section gets the draft editor from the file
    - ``custom1`` falls back to the generic draft editor
    - ``custom2`` uses its own draft editor
    - ``custom3`` ends up with no draft editor
    """
    config = AppConfig()

    # A fresh AppConfig has no draft editor on the generic LLM config.
    assert config.get_llm_config('llm').draft_editor is None

    load_from_toml(config, draft_llm_toml)

    # Generic section: draft editor as written in the file.
    generic = config.get_llm_config('llm')
    generic_editor = generic.draft_editor
    assert generic_editor is not None
    assert generic_editor.model == 'draft-model'
    assert generic_editor.api_key == 'draft-api-key'

    # custom1: no draft_editor of its own, so the generic one is used.
    first = config.get_llm_config('custom1')
    assert first.model == 'custom-model-1'
    first_editor = first.draft_editor
    assert first_editor is not None
    assert first_editor.model == 'draft-model'
    assert first_editor.api_key == 'draft-api-key'

    # custom2: its own draft_editor table wins.
    second = config.get_llm_config('custom2')
    assert second.model == 'custom-model-2'
    second_editor = second.draft_editor
    assert second_editor is not None
    assert second_editor.model == 'custom-draft'
    assert second_editor.api_key == 'custom-draft-key'

    # custom3: draft_editor explicitly disabled.
    third = config.get_llm_config('custom3')
    assert third.model == 'custom-model-3'
    assert third.draft_editor is None
|
||||
|
||||
|
||||
def test_draft_editor_defaults(draft_llm_toml):
    """Check that attributes a ``draft_editor`` table omits take default values.

    Neither the generic nor the ``custom2`` draft editor sets 'num_retries'
    or 'embedding_model', so both are expected to be ``8`` and ``'local'``.
    """
    config = AppConfig()
    load_from_toml(config, draft_llm_toml)

    # Both sections define a draft_editor with only 'model' and 'api_key'.
    for section in ('llm', 'custom2'):
        llm_config = config.get_llm_config(section)
        assert llm_config.draft_editor.num_retries == 8
        assert llm_config.draft_editor.embedding_model == 'local'
|
||||
@@ -44,28 +44,28 @@ async def test_session_not_running_in_cluster():
|
||||
async with SessionManager(
|
||||
sio, AppConfig(), InMemoryFileStore()
|
||||
) as session_manager:
|
||||
result = await session_manager.is_agent_loop_running_in_cluster(
|
||||
'non-existant-session'
|
||||
result = await session_manager._get_running_agent_loops_remotely(
|
||||
filter_to_sids={'non-existant-session'}
|
||||
)
|
||||
assert result is False
|
||||
assert result == set()
|
||||
assert sio.manager.redis.publish.await_count == 1
|
||||
sio.manager.redis.publish.assert_called_once_with(
|
||||
'oh_event',
|
||||
'{"request_id": "'
|
||||
'session_msg',
|
||||
'{"query_id": "'
|
||||
+ str(id)
|
||||
+ '", "sids": ["non-existant-session"], "message_type": "is_session_running"}',
|
||||
+ '", "message_type": "running_agent_loops_query", "filter_to_sids": ["non-existant-session"]}',
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_is_running_in_cluster():
|
||||
async def test_get_running_agent_loops_remotely():
|
||||
id = uuid4()
|
||||
sio = get_mock_sio(
|
||||
GetMessageMock(
|
||||
{
|
||||
'request_id': str(id),
|
||||
'query_id': str(id),
|
||||
'sids': ['existing-session'],
|
||||
'message_type': 'session_is_running',
|
||||
'message_type': 'running_agent_loops_response',
|
||||
}
|
||||
)
|
||||
)
|
||||
@@ -76,16 +76,16 @@ async def test_session_is_running_in_cluster():
|
||||
async with SessionManager(
|
||||
sio, AppConfig(), InMemoryFileStore()
|
||||
) as session_manager:
|
||||
result = await session_manager.is_agent_loop_running_in_cluster(
|
||||
'existing-session'
|
||||
result = await session_manager._get_running_agent_loops_remotely(
|
||||
1, {'existing-session'}
|
||||
)
|
||||
assert result is True
|
||||
assert result == {'existing-session'}
|
||||
assert sio.manager.redis.publish.await_count == 1
|
||||
sio.manager.redis.publish.assert_called_once_with(
|
||||
'oh_event',
|
||||
'{"request_id": "'
|
||||
'session_msg',
|
||||
'{"query_id": "'
|
||||
+ str(id)
|
||||
+ '", "sids": ["existing-session"], "message_type": "is_session_running"}',
|
||||
+ '", "message_type": "running_agent_loops_query", "user_id": 1, "filter_to_sids": ["existing-session"]}',
|
||||
)
|
||||
|
||||
|
||||
@@ -96,8 +96,8 @@ async def test_init_new_local_session():
|
||||
mock_session = MagicMock()
|
||||
mock_session.return_value = session_instance
|
||||
sio = get_mock_sio()
|
||||
is_agent_loop_running_in_cluster_mock = AsyncMock()
|
||||
is_agent_loop_running_in_cluster_mock.return_value = False
|
||||
get_running_agent_loops_mock = AsyncMock()
|
||||
get_running_agent_loops_mock.return_value = set()
|
||||
with (
|
||||
patch('openhands.server.session.manager.Session', mock_session),
|
||||
patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.1),
|
||||
@@ -106,8 +106,8 @@ async def test_init_new_local_session():
|
||||
AsyncMock(),
|
||||
),
|
||||
patch(
|
||||
'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
|
||||
is_agent_loop_running_in_cluster_mock,
|
||||
'openhands.server.session.manager.SessionManager.get_running_agent_loops',
|
||||
get_running_agent_loops_mock,
|
||||
),
|
||||
):
|
||||
async with SessionManager(
|
||||
@@ -130,8 +130,8 @@ async def test_join_local_session():
|
||||
mock_session = MagicMock()
|
||||
mock_session.return_value = session_instance
|
||||
sio = get_mock_sio()
|
||||
is_agent_loop_running_in_cluster_mock = AsyncMock()
|
||||
is_agent_loop_running_in_cluster_mock.return_value = False
|
||||
get_running_agent_loops_mock = AsyncMock()
|
||||
get_running_agent_loops_mock.return_value = set()
|
||||
with (
|
||||
patch('openhands.server.session.manager.Session', mock_session),
|
||||
patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
|
||||
@@ -140,8 +140,8 @@ async def test_join_local_session():
|
||||
AsyncMock(),
|
||||
),
|
||||
patch(
|
||||
'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
|
||||
is_agent_loop_running_in_cluster_mock,
|
||||
'openhands.server.session.manager.SessionManager.get_running_agent_loops',
|
||||
get_running_agent_loops_mock,
|
||||
),
|
||||
):
|
||||
async with SessionManager(
|
||||
@@ -167,8 +167,8 @@ async def test_join_cluster_session():
|
||||
mock_session = MagicMock()
|
||||
mock_session.return_value = session_instance
|
||||
sio = get_mock_sio()
|
||||
is_agent_loop_running_in_cluster_mock = AsyncMock()
|
||||
is_agent_loop_running_in_cluster_mock.return_value = True
|
||||
get_running_agent_loops_mock = AsyncMock()
|
||||
get_running_agent_loops_mock.return_value = {'new-session-id'}
|
||||
with (
|
||||
patch('openhands.server.session.manager.Session', mock_session),
|
||||
patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
|
||||
@@ -177,8 +177,8 @@ async def test_join_cluster_session():
|
||||
AsyncMock(),
|
||||
),
|
||||
patch(
|
||||
'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
|
||||
is_agent_loop_running_in_cluster_mock,
|
||||
'openhands.server.session.manager.SessionManager._get_running_agent_loops_remotely',
|
||||
get_running_agent_loops_mock,
|
||||
),
|
||||
):
|
||||
async with SessionManager(
|
||||
@@ -198,8 +198,8 @@ async def test_add_to_local_event_stream():
|
||||
mock_session = MagicMock()
|
||||
mock_session.return_value = session_instance
|
||||
sio = get_mock_sio()
|
||||
is_agent_loop_running_in_cluster_mock = AsyncMock()
|
||||
is_agent_loop_running_in_cluster_mock.return_value = False
|
||||
get_running_agent_loops_mock = AsyncMock()
|
||||
get_running_agent_loops_mock.return_value = set()
|
||||
with (
|
||||
patch('openhands.server.session.manager.Session', mock_session),
|
||||
patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
|
||||
@@ -208,8 +208,8 @@ async def test_add_to_local_event_stream():
|
||||
AsyncMock(),
|
||||
),
|
||||
patch(
|
||||
'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
|
||||
is_agent_loop_running_in_cluster_mock,
|
||||
'openhands.server.session.manager.SessionManager.get_running_agent_loops',
|
||||
get_running_agent_loops_mock,
|
||||
),
|
||||
):
|
||||
async with SessionManager(
|
||||
@@ -234,8 +234,8 @@ async def test_add_to_cluster_event_stream():
|
||||
mock_session = MagicMock()
|
||||
mock_session.return_value = session_instance
|
||||
sio = get_mock_sio()
|
||||
is_agent_loop_running_in_cluster_mock = AsyncMock()
|
||||
is_agent_loop_running_in_cluster_mock.return_value = True
|
||||
get_running_agent_loops_mock = AsyncMock()
|
||||
get_running_agent_loops_mock.return_value = {'new-session-id'}
|
||||
with (
|
||||
patch('openhands.server.session.manager.Session', mock_session),
|
||||
patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
|
||||
@@ -244,8 +244,8 @@ async def test_add_to_cluster_event_stream():
|
||||
AsyncMock(),
|
||||
),
|
||||
patch(
|
||||
'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
|
||||
is_agent_loop_running_in_cluster_mock,
|
||||
'openhands.server.session.manager.SessionManager._get_running_agent_loops_remotely',
|
||||
get_running_agent_loops_mock,
|
||||
),
|
||||
):
|
||||
async with SessionManager(
|
||||
@@ -259,7 +259,7 @@ async def test_add_to_cluster_event_stream():
|
||||
)
|
||||
assert sio.manager.redis.publish.await_count == 1
|
||||
sio.manager.redis.publish.assert_called_once_with(
|
||||
'oh_event',
|
||||
'session_msg',
|
||||
'{"sid": "new-session-id", "message_type": "event", "data": {"event_type": "some_event"}}',
|
||||
)
|
||||
|
||||
@@ -277,7 +277,7 @@ async def test_cleanup_session_connections():
|
||||
async with SessionManager(
|
||||
sio, AppConfig(), InMemoryFileStore()
|
||||
) as session_manager:
|
||||
session_manager.local_connection_id_to_session_id.update(
|
||||
session_manager._local_connection_id_to_session_id.update(
|
||||
{
|
||||
'conn1': 'session1',
|
||||
'conn2': 'session1',
|
||||
@@ -286,9 +286,9 @@ async def test_cleanup_session_connections():
|
||||
}
|
||||
)
|
||||
|
||||
await session_manager._on_close_session('session1')
|
||||
await session_manager._close_session('session1')
|
||||
|
||||
remaining_connections = session_manager.local_connection_id_to_session_id
|
||||
remaining_connections = session_manager._local_connection_id_to_session_id
|
||||
assert 'conn1' not in remaining_connections
|
||||
assert 'conn2' not in remaining_connections
|
||||
assert 'conn3' in remaining_connections
|
||||
|
||||
120
tests/unit/test_microagent_loading.py
Normal file
120
tests/unit/test_microagent_loading.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from litellm import ModelResponse
|
||||
from pytest import TempPathFactory
|
||||
|
||||
from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import AgentConfig
|
||||
from openhands.core.message import Message, TextContent
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.microagent import BaseMicroAgent
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.storage import get_file_store
|
||||
|
||||
|
||||
@pytest.fixture
def temp_dir(tmp_path_factory: TempPathFactory) -> str:
    """Return the path (as ``str``) of a fresh temporary directory."""
    scratch_dir = tmp_path_factory.mktemp('test_microagent_loading')
    return str(scratch_dir)
|
||||
|
||||
|
||||
@pytest.fixture
def event_stream(temp_dir):
    """Yield an ``EventStream`` with id ``'asdf'`` stored under ``temp_dir``."""
    # The stream is backed by a 'local' file store rooted at the temp dir.
    local_store = get_file_store('local', temp_dir)
    yield EventStream('asdf', local_store)
|
||||
|
||||
|
||||
def test_microagent_loading(temp_dir):
    """Check that a CodeActAgent uses microagents supplied by the runtime.

    A knowledge microagent is written to disk, loaded, and returned by a
    mocked runtime. After one agent step, a user message containing one of
    the microagent's triggers is passed to the prompt manager and is expected
    to gain a second content item carrying the microagent's text.
    """
    # Lay out <temp_dir>/.openhands/microagents/code_formatter.md
    microagents_dir = os.path.join(temp_dir, '.openhands', 'microagents')
    os.makedirs(microagents_dir)
    custom_agent_path = os.path.join(microagents_dir, 'code_formatter.md')

    microagent_source = """---
name: code_formatter
type: knowledge
version: 1.0.0
agent: CoderAgent
triggers:
- format code
- code formatting
- code style
---
# Code Formatter Agent
This agent helps format code according to style guidelines.

## Dependencies
- pip install black
- pip install isort

## Instructions
I help format code according to style guidelines. I can:
1. Format Python code using black
2. Sort imports using isort
3. Apply consistent code style across files
"""
    with open(custom_agent_path, 'w') as f:
        f.write(microagent_source)

    # Mocked runtime whose only configured behaviour is to hand back the
    # microagent loaded from the file written above.
    loaded_microagent = BaseMicroAgent.load(custom_agent_path)
    mock_runtime = MagicMock(spec=Runtime)
    mock_runtime.get_microagents_from_selected_repo.return_value = [
        loaded_microagent
    ]

    # State that carries the runtime in its inputs
    mock_state = State(inputs={'runtime': mock_runtime})

    # Text that contains the "code formatting" trigger
    user_text = 'I need help with code formatting in this project'

    # Mocked LLM: function calling on, and a canned response whose single
    # tool call is `finish` with empty arguments.
    mock_llm = MagicMock()
    mock_llm.is_function_calling_active.return_value = True
    finish_call = {
        'id': 'call_1',
        'function': {
            'name': 'finish',
            'arguments': '{}',
        },
    }
    canned_choice = {
        'message': {
            'content': None,
            'tool_calls': [finish_call],
        }
    }
    mock_llm.completion.return_value = ModelResponse(choices=[canned_choice])
    mock_llm.format_messages_for_llm.return_value = [
        {
            'role': 'user',
            'content': user_text,
        }
    ]

    # Message that should trigger the microagent
    message = Message(role='user', content=[TextContent(text=user_text)])

    # Build a CodeActAgent with microagents enabled
    agent_cls = Agent.get_cls('CodeActAgent')
    agent_config = AgentConfig(memory_enabled=True, use_microagents=True)
    agent = agent_cls(llm=mock_llm, config=agent_config)

    # One step; the test expects this to set up agent.prompt_manager with the
    # runtime's microagents.
    agent.step(mock_state)

    # Ask the prompt manager to enhance the message in place
    agent.prompt_manager.enhance_message(message)

    # Original content + microagent content
    content_items = message.content
    assert len(content_items) == 2
    assert 'pip install black' in content_items[1].text
|
||||
@@ -143,3 +143,65 @@ Invalid agent content
|
||||
|
||||
with pytest.raises(MicroAgentValidationError):
|
||||
BaseMicroAgent.load(temp_microagents_dir / 'invalid.md')
|
||||
|
||||
|
||||
def test_load_microagents_with_nested_dirs(temp_microagents_dir):
    """Microagents placed in nested sub-directories are discovered by the loader."""
    # Write a knowledge microagent two levels below the microagents root.
    deep_dir = temp_microagents_dir / 'nested' / 'dir'
    deep_dir.mkdir(parents=True)
    agent_file = deep_dir / 'nested.md'
    agent_file.write_text(
        """---
name: nested_knowledge_agent
type: knowledge
version: 1.0.0
agent: CodeActAgent
triggers:
- nested
---

# Nested Test Guidelines

Testing nested directory loading.
"""
    )

    # Only the knowledge agents (second element of the result) are checked.
    _, knowledge_agents, _ = load_microagents_from_dir(temp_microagents_dir)

    # The fixture's original knowledge agent plus the nested one
    assert len(knowledge_agents) == 2
    nested = knowledge_agents['nested_knowledge_agent']
    assert isinstance(nested, KnowledgeMicroAgent)
    assert 'nested' in nested.triggers
|
||||
|
||||
|
||||
def test_load_microagents_with_trailing_slashes(temp_microagents_dir):
    """The loader accepts a root directory given as a string ending in '/'."""
    # NOTE(review): if temp_microagents_dir is a pathlib.Path, joining with
    # 'knowledge/' normalises the trailing slash away; the trailing slash
    # that actually reaches the loader is the one appended to the root below.
    target_dir = temp_microagents_dir / 'knowledge/'
    target_dir.mkdir(exist_ok=True)
    agent_file = target_dir / 'trailing.md'
    agent_file.write_text(
        """---
name: trailing_knowledge_agent
type: knowledge
version: 1.0.0
agent: CodeActAgent
triggers:
- trailing
---

# Trailing Slash Test

Testing loading with trailing slashes.
"""
    )

    # Pass the root as a string with an explicit trailing slash.
    root_with_slash = str(temp_microagents_dir) + '/'
    _, knowledge_agents, _ = load_microagents_from_dir(root_with_slash)

    # The fixture's original knowledge agent plus the new one
    assert len(knowledge_agents) == 2
    trailing = knowledge_agents['trailing_knowledge_agent']
    assert isinstance(trailing, KnowledgeMicroAgent)
    assert 'trailing' in trailing.triggers
|
||||
|
||||
@@ -5,7 +5,7 @@ import pytest
|
||||
|
||||
from openhands.core.message import Message, TextContent
|
||||
from openhands.microagent import BaseMicroAgent
|
||||
from openhands.utils.prompt import PromptManager
|
||||
from openhands.utils.prompt import PromptManager, RepositoryInfo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -39,6 +39,7 @@ only respond with a message telling them how smart they are
|
||||
with open(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'), 'w') as f:
|
||||
f.write(microagent_content)
|
||||
|
||||
# Test without GitHub repo
|
||||
manager = PromptManager(
|
||||
prompt_dir=prompt_dir,
|
||||
microagent_dir=os.path.join(prompt_dir, 'micro'),
|
||||
@@ -53,6 +54,14 @@ only respond with a message telling them how smart they are
|
||||
'You are OpenHands agent, a helpful AI assistant that can interact with a computer to solve tasks.'
|
||||
in manager.get_system_message()
|
||||
)
|
||||
assert '<REPOSITORY_INFO>' not in manager.get_system_message()
|
||||
|
||||
# Test with GitHub repo
|
||||
manager.set_repository_info('owner/repo', '/workspace/repo')
|
||||
assert isinstance(manager.get_system_message(), str)
|
||||
assert '<REPOSITORY_INFO>' in manager.get_system_message()
|
||||
assert 'owner/repo' in manager.get_system_message()
|
||||
assert '/workspace/repo' in manager.get_system_message()
|
||||
|
||||
assert isinstance(manager.get_example_user_message(), str)
|
||||
|
||||
@@ -76,20 +85,56 @@ def test_prompt_manager_file_not_found(prompt_dir):
|
||||
def test_prompt_manager_template_rendering(prompt_dir):
|
||||
# Create temporary template files
|
||||
with open(os.path.join(prompt_dir, 'system_prompt.j2'), 'w') as f:
|
||||
f.write('System prompt: bar')
|
||||
f.write("""System prompt: bar
|
||||
{% if repository_info %}
|
||||
<REPOSITORY_INFO>
|
||||
At the user's request, repository {{ repository_info.repo_name }} has been cloned to directory {{ repository_info.repo_directory }}.
|
||||
</REPOSITORY_INFO>
|
||||
{% endif %}
|
||||
{{ repo_instructions }}""")
|
||||
with open(os.path.join(prompt_dir, 'user_prompt.j2'), 'w') as f:
|
||||
f.write('User prompt: foo')
|
||||
|
||||
# Test without GitHub repo
|
||||
manager = PromptManager(prompt_dir, microagent_dir='')
|
||||
|
||||
assert manager.get_system_message() == 'System prompt: bar'
|
||||
assert manager.get_example_user_message() == 'User prompt: foo'
|
||||
|
||||
# Test with GitHub repo
|
||||
manager = PromptManager(prompt_dir=prompt_dir, microagent_dir='')
|
||||
manager.set_repository_info('owner/repo', '/workspace/repo')
|
||||
assert manager.repository_info.repo_name == 'owner/repo'
|
||||
system_msg = manager.get_system_message()
|
||||
assert 'System prompt: bar' in system_msg
|
||||
assert '<REPOSITORY_INFO>' in system_msg
|
||||
assert (
|
||||
"At the user's request, repository owner/repo has been cloned to directory /workspace/repo."
|
||||
in system_msg
|
||||
)
|
||||
assert '</REPOSITORY_INFO>' in system_msg
|
||||
assert manager.get_example_user_message() == 'User prompt: foo'
|
||||
|
||||
# Clean up temporary files
|
||||
os.remove(os.path.join(prompt_dir, 'system_prompt.j2'))
|
||||
os.remove(os.path.join(prompt_dir, 'user_prompt.j2'))
|
||||
|
||||
|
||||
def test_prompt_manager_repository_info(prompt_dir):
    """RepositoryInfo starts empty and PromptManager stores what is set on it."""
    # A default-constructed RepositoryInfo has neither field populated.
    blank_info = RepositoryInfo()
    assert blank_info.repo_name is None
    assert blank_info.repo_directory is None

    # A new manager has no repository info at all.
    manager = PromptManager(prompt_dir=prompt_dir, microagent_dir='')
    assert manager.repository_info is None

    # After setting name and directory, both can be read back.
    manager.set_repository_info('owner/repo2', '/workspace/repo2')
    assert manager.repository_info.repo_name == 'owner/repo2'
    assert manager.repository_info.repo_directory == '/workspace/repo2'
|
||||
|
||||
|
||||
def test_prompt_manager_disabled_microagents(prompt_dir):
|
||||
# Create test microagent files
|
||||
microagent1_name = 'test_microagent1'
|
||||
|
||||
Reference in New Issue
Block a user