Compare commits

..

16 Commits

Author SHA1 Message Date
openhands
8d89fada9c fix: Use runtime to load microagents
- Update CodeActAgent to use runtime.get_microagents_from_selected_repo
- Add test to verify microagent loading from runtime
- Fix #6304
2025-01-16 14:01:43 +00:00
dependabot[bot]
6e089619e0 chore(deps-dev): bump chromadb from 0.6.2 to 0.6.3 in the chromadb group (#6289)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-01-16 00:37:42 +01:00
Xingyao Wang
179a89a211 Fix microagent loading with trailing slashes and nested directories (#6239)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-01-15 17:07:40 +00:00
tofarr
8795ee6c6e Fix closing sessions (#6114) 2025-01-15 10:04:22 -07:00
Engel Nyst
97e938d545 Fix French doc (#6283) 2025-01-15 04:25:47 +00:00
Engel Nyst
b9a70c8d5c Delegation fixes (#6165) 2025-01-15 03:24:39 +00:00
Ray Myers
082d0b25c5 Send status message on runtime restart (#6275) 2025-01-15 03:21:06 +01:00
Engel Nyst
c5797d1d5a Fix llm_config fallback (#4415)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-01-15 01:17:37 +00:00
Xingyao Wang
7ce1fb85ff chore: remove repo info from initial query for #6057 (#6279) 2025-01-15 00:40:54 +00:00
Robert Brennan
fa6792e5a6 Add GitHub repository information to system prompt (#6057)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-01-15 08:02:07 +08:00
dependabot[bot]
3d9b4c4af6 chore(deps): bump the version-all group with 4 updates (#6267)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-01-14 21:30:56 +01:00
tofarr
e21cbf67ee Feat: User id should be a str (Because it will probably be a UUID) (#6251) 2025-01-14 12:39:51 -07:00
Xingyao Wang
6b2e3f938f fix: prevent runtime size deselection (#6119)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: mamoodi <mamoodiha@gmail.com>
2025-01-14 17:53:51 +00:00
Rohit Malhotra
580d7b938c Fix: Don't refresh github token on local (#5880) 2025-01-14 17:48:33 +00:00
mamoodi
28178a2940 Remove extra optional for github token (#6270) 2025-01-14 17:44:28 +00:00
sp.wack
04382b2b19 hotfix(backend): Remove GH header token middleware (#6269) 2025-01-14 12:07:13 -05:00
64 changed files with 2124 additions and 621 deletions

View File

@@ -56,6 +56,7 @@ jobs:
LLM_MODEL: "litellm_proxy/claude-3-5-haiku-20241022"
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_BASE_URL: ${{ secrets.LLM_BASE_URL }}
MAX_ITERATIONS: 10
run: |
echo "[llm.eval]" > config.toml
echo "model = \"$LLM_MODEL\"" >> config.toml
@@ -70,7 +71,7 @@ jobs:
env:
SANDBOX_FORCE_REBUILD_RUNTIME: True
run: |
poetry run ./evaluation/integration_tests/scripts/run_infer.sh llm.eval HEAD CodeActAgent '' $N_PROCESSES '' 'haiku_run'
poetry run ./evaluation/integration_tests/scripts/run_infer.sh llm.eval HEAD CodeActAgent '' 10 $N_PROCESSES '' 'haiku_run'
# get integration tests report
REPORT_FILE_HAIKU=$(find evaluation/evaluation_outputs/outputs/integration_tests/CodeActAgent/*haiku*_maxiter_10_N* -name "report.md" -type f | head -n 1)
@@ -88,6 +89,7 @@ jobs:
LLM_MODEL: "litellm_proxy/deepseek-chat"
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_BASE_URL: ${{ secrets.LLM_BASE_URL }}
MAX_ITERATIONS: 10
run: |
echo "[llm.eval]" > config.toml
echo "model = \"$LLM_MODEL\"" >> config.toml
@@ -99,7 +101,7 @@ jobs:
env:
SANDBOX_FORCE_REBUILD_RUNTIME: True
run: |
poetry run ./evaluation/integration_tests/scripts/run_infer.sh llm.eval HEAD CodeActAgent '' $N_PROCESSES '' 'deepseek_run'
poetry run ./evaluation/integration_tests/scripts/run_infer.sh llm.eval HEAD CodeActAgent '' 10 $N_PROCESSES '' 'deepseek_run'
# get integration tests report
REPORT_FILE_DEEPSEEK=$(find evaluation/evaluation_outputs/outputs/integration_tests/CodeActAgent/deepseek*_maxiter_10_N* -name "report.md" -type f | head -n 1)
@@ -109,11 +111,75 @@ jobs:
echo >> $GITHUB_ENV
echo "EOF" >> $GITHUB_ENV
# -------------------------------------------------------------
# Run DelegatorAgent tests for Haiku, limited to t01 and t02
- name: Wait a little bit (again)
run: sleep 5
- name: Configure config.toml for testing DelegatorAgent (Haiku)
env:
LLM_MODEL: "litellm_proxy/claude-3-5-haiku-20241022"
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_BASE_URL: ${{ secrets.LLM_BASE_URL }}
MAX_ITERATIONS: 30
run: |
echo "[llm.eval]" > config.toml
echo "model = \"$LLM_MODEL\"" >> config.toml
echo "api_key = \"$LLM_API_KEY\"" >> config.toml
echo "base_url = \"$LLM_BASE_URL\"" >> config.toml
echo "temperature = 0.0" >> config.toml
- name: Run integration test evaluation for DelegatorAgent (Haiku)
env:
SANDBOX_FORCE_REBUILD_RUNTIME: True
run: |
poetry run ./evaluation/integration_tests/scripts/run_infer.sh llm.eval HEAD DelegatorAgent '' 30 $N_PROCESSES "t01_fix_simple_typo,t02_add_bash_hello" 'delegator_haiku_run'
# Find and export the delegator test results
REPORT_FILE_DELEGATOR_HAIKU=$(find evaluation/evaluation_outputs/outputs/integration_tests/DelegatorAgent/*haiku*_maxiter_30_N* -name "report.md" -type f | head -n 1)
echo "REPORT_FILE_DELEGATOR_HAIKU: $REPORT_FILE_DELEGATOR_HAIKU"
echo "INTEGRATION_TEST_REPORT_DELEGATOR_HAIKU<<EOF" >> $GITHUB_ENV
cat $REPORT_FILE_DELEGATOR_HAIKU >> $GITHUB_ENV
echo >> $GITHUB_ENV
echo "EOF" >> $GITHUB_ENV
# -------------------------------------------------------------
# Run DelegatorAgent tests for DeepSeek, limited to t01 and t02
- name: Wait a little bit (again)
run: sleep 5
- name: Configure config.toml for testing DelegatorAgent (DeepSeek)
env:
LLM_MODEL: "litellm_proxy/deepseek-chat"
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_BASE_URL: ${{ secrets.LLM_BASE_URL }}
MAX_ITERATIONS: 30
run: |
echo "[llm.eval]" > config.toml
echo "model = \"$LLM_MODEL\"" >> config.toml
echo "api_key = \"$LLM_API_KEY\"" >> config.toml
echo "base_url = \"$LLM_BASE_URL\"" >> config.toml
echo "temperature = 0.0" >> config.toml
- name: Run integration test evaluation for DelegatorAgent (DeepSeek)
env:
SANDBOX_FORCE_REBUILD_RUNTIME: True
run: |
poetry run ./evaluation/integration_tests/scripts/run_infer.sh llm.eval HEAD DelegatorAgent '' 30 $N_PROCESSES "t01_fix_simple_typo,t02_add_bash_hello" 'delegator_deepseek_run'
# Find and export the delegator test results
REPORT_FILE_DELEGATOR_DEEPSEEK=$(find evaluation/evaluation_outputs/outputs/integration_tests/DelegatorAgent/deepseek*_maxiter_30_N* -name "report.md" -type f | head -n 1)
echo "REPORT_FILE_DELEGATOR_DEEPSEEK: $REPORT_FILE_DELEGATOR_DEEPSEEK"
echo "INTEGRATION_TEST_REPORT_DELEGATOR_DEEPSEEK<<EOF" >> $GITHUB_ENV
cat $REPORT_FILE_DELEGATOR_DEEPSEEK >> $GITHUB_ENV
echo >> $GITHUB_ENV
echo "EOF" >> $GITHUB_ENV
- name: Create archive of evaluation outputs
run: |
TIMESTAMP=$(date +'%y-%m-%d-%H-%M')
cd evaluation/evaluation_outputs/outputs # Change to the outputs directory
tar -czvf ../../../integration_tests_${TIMESTAMP}.tar.gz integration_tests/CodeActAgent/* # Only include the actual result directories
tar -czvf ../../../integration_tests_${TIMESTAMP}.tar.gz integration_tests/CodeActAgent/* integration_tests/DelegatorAgent/* # Only include the actual result directories
- name: Upload evaluation results as artifact
uses: actions/upload-artifact@v4
@@ -154,5 +220,11 @@ jobs:
**Integration Tests Report (DeepSeek)**
DeepSeek LLM Test Results:
${{ env.INTEGRATION_TEST_REPORT_DEEPSEEK }}
---
**Integration Tests Report Delegator (Haiku)**
${{ env.INTEGRATION_TEST_REPORT_DELEGATOR_HAIKU }}
---
**Integration Tests Report Delegator (DeepSeek)**
${{ env.INTEGRATION_TEST_REPORT_DELEGATOR_DEEPSEEK }}
---
Download testing outputs (includes both Haiku and DeepSeek results): [Download](${{ steps.upload_results_artifact.outputs.artifact-url }})

View File

@@ -1,5 +1,3 @@
# Options de configuration
Ce guide détaille toutes les options de configuration disponibles pour OpenHands, vous aidant à personnaliser son comportement et à l'intégrer avec d'autres services.
@@ -184,6 +182,10 @@ Les options de configuration LLM (Large Language Model) sont définies dans la s
Pour les utiliser avec la commande docker, passez `-e LLM_<option>`. Exemple : `-e LLM_NUM_RETRIES`.
:::note
Pour les configurations de développement, vous pouvez également définir des configurations LLM personnalisées. Voir [Configurations LLM personnalisées](./llms/custom-llm-configs) pour plus de détails.
:::
**Informations d'identification AWS**
- `aws_access_key_id`
- Type : `str`
@@ -368,4 +370,26 @@ Les options de configuration de l'agent sont définies dans les sections `[agent
- `codeact_enable_llm_editor`
- Type : `bool`
- Valeur par défaut : `false`
- Description : Si l'éditeur LLM est activé dans l'espace d'action (foncti
- Description : Si l'éditeur LLM est activé dans l'espace d'action (fonctionne uniquement avec l'appel de fonction)
**Utilisation du micro-agent**
- `use_microagents`
- Type : `bool`
- Valeur par défaut : `true`
- Description : Indique si l'utilisation des micro-agents est activée ou non
- `disabled_microagents`
- Type : `list of str`
- Valeur par défaut : `None`
- Description : Liste des micro-agents à désactiver
### Exécution
- `timeout`
- Type : `int`
- Valeur par défaut : `120`
- Description : Délai d'expiration du bac à sable, en secondes
- `user_id`
- Type : `int`
- Valeur par défaut : `1000`
- Description : ID de l'utilisateur du bac à sable

View File

@@ -0,0 +1,106 @@
# Configurations LLM personnalisées
OpenHands permet de définir plusieurs configurations LLM nommées dans votre fichier `config.toml`. Cette fonctionnalité vous permet d'utiliser différentes configurations LLM pour différents usages, comme utiliser un modèle moins coûteux pour les tâches qui ne nécessitent pas de réponses de haute qualité, ou utiliser différents modèles avec différents paramètres pour des agents spécifiques.
## Comment ça fonctionne
Les configurations LLM nommées sont définies dans le fichier `config.toml` en utilisant des sections qui commencent par `llm.`. Par exemple :
```toml
# Configuration LLM par défaut
[llm]
model = "gpt-4"
api_key = "votre-clé-api"
temperature = 0.0
# Configuration LLM personnalisée pour un modèle moins coûteux
[llm.gpt3]
model = "gpt-3.5-turbo"
api_key = "votre-clé-api"
temperature = 0.2
# Une autre configuration personnalisée avec des paramètres différents
[llm.haute-creativite]
model = "gpt-4"
api_key = "votre-clé-api"
temperature = 0.8
top_p = 0.9
```
Chaque configuration nommée hérite de tous les paramètres de la section `[llm]` par défaut et peut remplacer n'importe lequel de ces paramètres. Vous pouvez définir autant de configurations personnalisées que nécessaire.
## Utilisation des configurations personnalisées
### Avec les agents
Vous pouvez spécifier quelle configuration LLM un agent doit utiliser en définissant le paramètre `llm_config` dans la section de configuration de l'agent :
```toml
[agent.RepoExplorerAgent]
# Utiliser la configuration GPT-3 moins coûteuse pour cet agent
llm_config = 'gpt3'
[agent.CodeWriterAgent]
# Utiliser la configuration haute créativité pour cet agent
llm_config = 'haute-creativite'
```
### Options de configuration
Chaque configuration LLM nommée prend en charge toutes les mêmes options que la configuration LLM par défaut. Celles-ci incluent :
- Sélection du modèle (`model`)
- Configuration de l'API (`api_key`, `base_url`, etc.)
- Paramètres du modèle (`temperature`, `top_p`, etc.)
- Paramètres de nouvelle tentative (`num_retries`, `retry_multiplier`, etc.)
- Limites de jetons (`max_input_tokens`, `max_output_tokens`)
- Et toutes les autres options de configuration LLM
Pour une liste complète des options disponibles, consultez la section Configuration LLM dans la documentation des [Options de configuration](../configuration-options).
## Cas d'utilisation
Les configurations LLM personnalisées sont particulièrement utiles dans plusieurs scénarios :
- **Optimisation des coûts** : Utiliser des modèles moins coûteux pour les tâches qui ne nécessitent pas de réponses de haute qualité, comme l'exploration de dépôt ou les opérations simples sur les fichiers.
- **Réglage spécifique aux tâches** : Configurer différentes valeurs de température et de top_p pour les tâches qui nécessitent différents niveaux de créativité ou de déterminisme.
- **Différents fournisseurs** : Utiliser différents fournisseurs LLM ou points d'accès API pour différentes tâches.
- **Tests et développement** : Basculer facilement entre différentes configurations de modèles pendant le développement et les tests.
## Exemple : Optimisation des coûts
Un exemple pratique d'utilisation des configurations LLM personnalisées pour optimiser les coûts :
```toml
# Configuration par défaut utilisant GPT-4 pour des réponses de haute qualité
[llm]
model = "gpt-4"
api_key = "votre-clé-api"
temperature = 0.0
# Configuration moins coûteuse pour l'exploration de dépôt
[llm.repo-explorer]
model = "gpt-3.5-turbo"
temperature = 0.2
# Configuration pour la génération de code
[llm.code-gen]
model = "gpt-4"
temperature = 0.0
max_output_tokens = 2000
[agent.RepoExplorerAgent]
llm_config = 'repo-explorer'
[agent.CodeWriterAgent]
llm_config = 'code-gen'
```
Dans cet exemple :
- L'exploration de dépôt utilise un modèle moins coûteux car il s'agit principalement de comprendre et de naviguer dans le code
- La génération de code utilise GPT-4 avec une limite de jetons plus élevée pour générer des blocs de code plus importants
- La configuration par défaut reste disponible pour les autres tâches
:::note
Les configurations LLM personnalisées ne sont disponibles que lors de l'utilisation d'OpenHands en mode développement, via `main.py` ou `cli.py`. Lors de l'exécution via `docker run`, veuillez utiliser les options de configuration standard.
:::

View File

@@ -140,7 +140,11 @@ The LLM (Large Language Model) configuration options are defined in the `[llm]`
To use these with the docker command, pass in `-e LLM_<option>`. Example: `-e LLM_NUM_RETRIES`.
### AWS Credentials
:::note
For development setups, you can also define custom named LLM configurations. See [Custom LLM Configurations](./llms/custom-llm-configs) for details.
:::
**AWS Credentials**
- `aws_access_key_id`
- Type: `str`
- Default: `""`

View File

@@ -0,0 +1,106 @@
# Custom LLM Configurations
OpenHands supports defining multiple named LLM configurations in your `config.toml` file. This feature allows you to use different LLM configurations for different purposes, such as using a cheaper model for tasks that don't require high-quality responses, or using different models with different parameters for specific agents.
## How It Works
Named LLM configurations are defined in the `config.toml` file using sections that start with `llm.`. For example:
```toml
# Default LLM configuration
[llm]
model = "gpt-4"
api_key = "your-api-key"
temperature = 0.0
# Custom LLM configuration for a cheaper model
[llm.gpt3]
model = "gpt-3.5-turbo"
api_key = "your-api-key"
temperature = 0.2
# Another custom configuration with different parameters
[llm.high-creativity]
model = "gpt-4"
api_key = "your-api-key"
temperature = 0.8
top_p = 0.9
```
Each named configuration inherits all settings from the default `[llm]` section and can override any of those settings. You can define as many custom configurations as needed.
## Using Custom Configurations
### With Agents
You can specify which LLM configuration an agent should use by setting the `llm_config` parameter in the agent's configuration section:
```toml
[agent.RepoExplorerAgent]
# Use the cheaper GPT-3 configuration for this agent
llm_config = 'gpt3'
[agent.CodeWriterAgent]
# Use the high creativity configuration for this agent
llm_config = 'high-creativity'
```
### Configuration Options
Each named LLM configuration supports all the same options as the default LLM configuration. These include:
- Model selection (`model`)
- API configuration (`api_key`, `base_url`, etc.)
- Model parameters (`temperature`, `top_p`, etc.)
- Retry settings (`num_retries`, `retry_multiplier`, etc.)
- Token limits (`max_input_tokens`, `max_output_tokens`)
- And all other LLM configuration options
For a complete list of available options, see the LLM Configuration section in the [Configuration Options](../configuration-options) documentation.
## Use Cases
Custom LLM configurations are particularly useful in several scenarios:
- **Cost Optimization**: Use cheaper models for tasks that don't require high-quality responses, like repository exploration or simple file operations.
- **Task-Specific Tuning**: Configure different temperature and top_p values for tasks that require different levels of creativity or determinism.
- **Different Providers**: Use different LLM providers or API endpoints for different tasks.
- **Testing and Development**: Easily switch between different model configurations during development and testing.
## Example: Cost Optimization
A practical example of using custom LLM configurations to optimize costs:
```toml
# Default configuration using GPT-4 for high-quality responses
[llm]
model = "gpt-4"
api_key = "your-api-key"
temperature = 0.0
# Cheaper configuration for repository exploration
[llm.repo-explorer]
model = "gpt-3.5-turbo"
temperature = 0.2
# Configuration for code generation
[llm.code-gen]
model = "gpt-4"
temperature = 0.0
max_output_tokens = 2000
[agent.RepoExplorerAgent]
llm_config = 'repo-explorer'
[agent.CodeWriterAgent]
llm_config = 'code-gen'
```
In this example:
- Repository exploration uses a cheaper model since it mainly involves understanding and navigating code
- Code generation uses GPT-4 with a higher token limit for generating larger code blocks
- The default configuration remains available for other tasks
:::note
Custom LLM configurations are only available when using OpenHands in development mode, via `main.py` or `cli.py`. When running via `docker run`, please use the standard configuration options.
:::

View File

@@ -8,13 +8,15 @@ from evaluation.integration_tests.tests.base import BaseIntegrationTest, TestRes
from evaluation.utils.shared import (
EvalMetadata,
EvalOutput,
codeact_user_response,
make_metadata,
prepare_dataset,
reset_logger_for_multiprocessing,
run_evaluation,
update_llm_config_for_completions_logging,
)
from evaluation.utils.shared import (
codeact_user_response as fake_user_response,
)
from openhands.controller.state.state import State
from openhands.core.config import (
AgentConfig,
@@ -31,7 +33,8 @@ from openhands.runtime.base import Runtime
from openhands.utils.async_utils import call_async_from_sync
FAKE_RESPONSES = {
'CodeActAgent': codeact_user_response,
'CodeActAgent': fake_user_response,
'DelegatorAgent': fake_user_response,
}
@@ -219,7 +222,7 @@ if __name__ == '__main__':
df = pd.read_json(output_file, lines=True, orient='records')
# record success and reason for failure for the final report
# record success and reason
df['success'] = df['test_result'].apply(lambda x: x['success'])
df['reason'] = df['test_result'].apply(lambda x: x['reason'])
logger.info('-' * 100)
@@ -234,15 +237,27 @@ if __name__ == '__main__':
logger.info('-' * 100)
# record cost for each instance, with 3 decimal places
df['cost'] = df['metrics'].apply(lambda x: round(x['accumulated_cost'], 3))
# we sum up all the "costs" from the metrics array
df['cost'] = df['metrics'].apply(
lambda m: round(sum(c['cost'] for c in m['costs']), 3)
if m and 'costs' in m
else 0.0
)
# capture the top-level error if present, per instance
df['error_message'] = df.get('error', None)
logger.info(f'Total cost: USD {df["cost"].sum():.2f}')
report_file = os.path.join(metadata.eval_output_dir, 'report.md')
with open(report_file, 'w') as f:
f.write(
f'Success rate: {df["success"].mean():.2%} ({df["success"].sum()}/{len(df)})\n'
f'Success rate: {df["success"].mean():.2%}'
f' ({df["success"].sum()}/{len(df)})\n'
)
f.write(f'\nTotal cost: USD {df["cost"].sum():.2f}\n')
f.write(
df[['instance_id', 'success', 'reason', 'cost']].to_markdown(index=False)
df[
['instance_id', 'success', 'reason', 'cost', 'error_message']
].to_markdown(index=False)
)

View File

@@ -7,8 +7,9 @@ MODEL_CONFIG=$1
COMMIT_HASH=$2
AGENT=$3
EVAL_LIMIT=$4
NUM_WORKERS=$5
EVAL_IDS=$6
MAX_ITERATIONS=$5
NUM_WORKERS=$6
EVAL_IDS=$7
if [ -z "$NUM_WORKERS" ]; then
NUM_WORKERS=1
@@ -43,7 +44,7 @@ fi
COMMAND="poetry run python evaluation/integration_tests/run_infer.py \
--agent-cls $AGENT \
--llm-config $MODEL_CONFIG \
--max-iterations 10 \
--max-iterations ${MAX_ITERATIONS:-10} \
--eval-num-workers $NUM_WORKERS \
--eval-note $EVAL_NOTE"

View File

@@ -1,11 +1,10 @@
import { describe, it, expect, afterEach, vi } from "vitest";
import * as router from "react-router";
// Mock useParams before importing components
vi.mock("react-router", async () => {
const actual = await vi.importActual("react-router");
return {
...actual as object,
...(actual as object),
useParams: () => ({ conversationId: "test-conversation-id" }),
};
});
@@ -14,7 +13,7 @@ vi.mock("react-router", async () => {
vi.mock("react-i18next", async () => {
const actual = await vi.importActual("react-i18next");
return {
...actual as object,
...(actual as object),
useTranslation: () => ({
t: (key: string) => key,
i18n: {
@@ -28,7 +27,6 @@ import { screen } from "@testing-library/react";
import { renderWithProviders } from "../../test-utils";
import { BrowserPanel } from "#/components/features/browser/browser";
describe("Browser", () => {
afterEach(() => {
vi.clearAllMocks();

View File

@@ -2,36 +2,42 @@ import { describe, expect, it } from "vitest";
import { screen } from "@testing-library/react";
import { renderWithProviders } from "test-utils";
import { ExpandableMessage } from "#/components/features/chat/expandable-message";
import { vi } from 'vitest';
import { vi } from "vitest";
vi.mock('react-i18next', async () => {
const actual = await vi.importActual('react-i18next');
vi.mock("react-i18next", async () => {
const actual = await vi.importActual("react-i18next");
return {
...actual,
useTranslation: () => ({
t: (key:string) => key,
t: (key: string) => key,
i18n: {
changeLanguage: () => new Promise(() => {}),
language: 'en',
language: "en",
exists: () => true,
},
}),
}
};
});
describe("ExpandableMessage", () => {
it("should render with neutral border for non-action messages", () => {
renderWithProviders(<ExpandableMessage message="Hello" type="thought" />);
const element = screen.getByText("Hello");
const container = element.closest("div.flex.gap-2.items-center.justify-start");
const container = element.closest(
"div.flex.gap-2.items-center.justify-start",
);
expect(container).toHaveClass("border-neutral-300");
expect(screen.queryByTestId("status-icon")).not.toBeInTheDocument();
});
it("should render with neutral border for error messages", () => {
renderWithProviders(<ExpandableMessage message="Error occurred" type="error" />);
renderWithProviders(
<ExpandableMessage message="Error occurred" type="error" />,
);
const element = screen.getByText("Error occurred");
const container = element.closest("div.flex.gap-2.items-center.justify-start");
const container = element.closest(
"div.flex.gap-2.items-center.justify-start",
);
expect(container).toHaveClass("border-danger");
expect(screen.queryByTestId("status-icon")).not.toBeInTheDocument();
});
@@ -43,10 +49,12 @@ describe("ExpandableMessage", () => {
message="Command executed successfully"
type="action"
success={true}
/>
/>,
);
const element = screen.getByText("OBSERVATION_MESSAGE$RUN");
const container = element.closest("div.flex.gap-2.items-center.justify-start");
const container = element.closest(
"div.flex.gap-2.items-center.justify-start",
);
expect(container).toHaveClass("border-neutral-300");
const icon = screen.getByTestId("status-icon");
expect(icon).toHaveClass("fill-success");
@@ -59,10 +67,12 @@ describe("ExpandableMessage", () => {
message="Command failed"
type="action"
success={false}
/>
/>,
);
const element = screen.getByText("OBSERVATION_MESSAGE$RUN");
const container = element.closest("div.flex.gap-2.items-center.justify-start");
const container = element.closest(
"div.flex.gap-2.items-center.justify-start",
);
expect(container).toHaveClass("border-neutral-300");
const icon = screen.getByTestId("status-icon");
expect(icon).toHaveClass("fill-danger");
@@ -74,10 +84,12 @@ describe("ExpandableMessage", () => {
id="OBSERVATION_MESSAGE$RUN"
message="Running command"
type="action"
/>
/>,
);
const element = screen.getByText("OBSERVATION_MESSAGE$RUN");
const container = element.closest("div.flex.gap-2.items-center.justify-start");
const container = element.closest(
"div.flex.gap-2.items-center.justify-start",
);
expect(container).toHaveClass("border-neutral-300");
expect(screen.queryByTestId("status-icon")).not.toBeInTheDocument();
});

View File

@@ -128,7 +128,7 @@ describe("Sidebar", () => {
await user.click(norskOption);
const tokenInput =
within(accountSettingsModal).getByLabelText(/GITHUB\$TOKEN_OPTIONAL/i);
within(accountSettingsModal).getByLabelText(/GITHUB\$TOKEN_LABEL/i);
await user.type(tokenInput, "new-token");
const saveButton =

View File

@@ -1,11 +1,10 @@
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import * as router from "react-router";
import { afterEach, describe, expect, it, vi } from "vitest";
// Mock useParams before importing components
vi.mock("react-router", async () => {
const actual = await vi.importActual("react-router");
return {
...actual as object,
...(actual as object),
useParams: () => ({ conversationId: "test-conversation-id" }),
};
});
@@ -60,7 +59,9 @@ describe("FeedbackForm", () => {
renderWithProviders(
<FeedbackForm polarity="positive" onClose={onCloseMock} />,
);
await user.click(screen.getByRole("button", { name: I18nKey.FEEDBACK$CANCEL_LABEL }));
await user.click(
screen.getByRole("button", { name: I18nKey.FEEDBACK$CANCEL_LABEL }),
);
expect(onCloseMock).toHaveBeenCalled();
});

View File

@@ -1,5 +1,4 @@
import { afterEach, beforeAll, describe, expect, it, vi } from "vitest";
import * as router from "react-router";
import { createRoutesStub } from "react-router";
import { screen, waitFor, within } from "@testing-library/react";
import { renderWithProviders } from "test-utils";

View File

@@ -0,0 +1,20 @@
import { vi } from "vitest";
import OpenHands from "#/api/open-hands";
/**
 * Test helper: spies on `OpenHands.getConfig` and makes it resolve to a
 * fixed config object whose `APP_MODE` is "oss".
 */
export const setupTestConfig = () => {
  vi.spyOn(OpenHands, "getConfig").mockResolvedValue({
    APP_MODE: "oss",
    GITHUB_CLIENT_ID: "test-id",
    POSTHOG_CLIENT_KEY: "test-key",
  });
};
/**
 * Test helper: spies on `OpenHands.getConfig` and makes it resolve to a
 * fixed config object whose `APP_MODE` is "saas".
 */
export const setupSaasTestConfig = () => {
  vi.spyOn(OpenHands, "getConfig").mockResolvedValue({
    APP_MODE: "saas",
    GITHUB_CLIENT_ID: "test-id",
    POSTHOG_CLIENT_KEY: "test-key",
  });
};

View File

@@ -41,6 +41,7 @@ export const isGitHubErrorReponse = <T extends object | Array<unknown>>(
// Axios interceptor to handle token refresh
const setupAxiosInterceptors = (
appMode: string,
refreshToken: () => Promise<boolean>,
logout: () => void,
) => {
@@ -74,18 +75,21 @@ const setupAxiosInterceptors = (
!originalRequest._retry // Prevent infinite retry loops
) {
originalRequest._retry = true;
try {
const refreshed = await refreshToken();
if (refreshed) {
return await github(originalRequest);
}
logout();
return await Promise.reject(new Error("Failed to refresh token"));
} catch (refreshError) {
// If token refresh fails, evict the user
logout();
return Promise.reject(refreshError);
if (appMode === "saas") {
try {
const refreshed = await refreshToken();
if (refreshed) {
return await github(originalRequest);
}
logout();
return await Promise.reject(new Error("Failed to refresh token"));
} catch (refreshError) {
// If token refresh fails, evict the user
logout();
return Promise.reject(refreshError);
}
}
}

View File

@@ -13,6 +13,7 @@ import { LoadingSpinner } from "#/components/shared/loading-spinner";
import { AccountSettingsModal } from "#/components/shared/modals/account-settings/account-settings-modal";
import { SettingsModal } from "#/components/shared/modals/settings/settings-modal";
import { useCurrentSettings } from "#/context/settings-context";
import { useSettings } from "#/hooks/query/use-settings";
import { ConversationPanel } from "../conversation-panel/conversation-panel";
import { MULTI_CONVERSATION_UI } from "#/utils/feature-flags";
import { useEndSession } from "#/hooks/use-end-session";
@@ -27,7 +28,13 @@ export function Sidebar() {
const user = useGitHubUser();
const { data: isAuthed } = useIsAuthed();
const { logout } = useAuth();
const { isUpToDate: settingsAreUpToDate, settings } = useCurrentSettings();
const {
data: settings,
isError: settingsIsError,
isSuccess: settingsSuccessfulyFetched,
} = useSettings();
const { isUpToDate: settingsAreUpToDate } = useCurrentSettings();
const [accountSettingsModalOpen, setAccountSettingsModalOpen] =
React.useState(false);
@@ -103,12 +110,13 @@ export function Sidebar() {
{accountSettingsModalOpen && (
<AccountSettingsModal onClose={handleAccountSettingsModalClose} />
)}
{showSettingsModal && settings && (
<SettingsModal
settings={settings}
onClose={() => setSettingsModalIsOpen(false)}
/>
)}
{settingsIsError ||
(showSettingsModal && settingsSuccessfulyFetched && (
<SettingsModal
settings={settings}
onClose={() => setSettingsModalIsOpen(false)}
/>
))}
</>
);
}

View File

@@ -91,7 +91,7 @@ export function AccountSettingsForm({
<>
<CustomInput
name="ghToken"
label={t(I18nKey.GITHUB$TOKEN_OPTIONAL)}
label={t(I18nKey.GITHUB$TOKEN_LABEL)}
type="password"
defaultValue={gitHubToken ?? ""}
/>

View File

@@ -26,7 +26,10 @@ export function RuntimeSizeSelector({
id="runtime-size"
name="runtime-size"
defaultSelectedKeys={[String(defaultValue || 1)]}
selectedKeys={[String(defaultValue || 1)]}
isDisabled={isDisabled}
selectionMode="single"
disallowEmptySelection
aria-label={t(I18nKey.SETTINGS_FORM$RUNTIME_SIZE_LABEL)}
classNames={{
trigger: "bg-[#27272A] rounded-md text-sm px-3 py-[10px]",

View File

@@ -87,7 +87,12 @@ function AuthProvider({ children }: React.PropsWithChildren) {
setGitHubToken(storedGitHubToken);
setUserId(userId);
setupGithubAxiosInterceptors(refreshToken, logout);
const setupIntercepter = async () => {
const config = await OpenHands.getConfig();
setupGithubAxiosInterceptors(config.APP_MODE, refreshToken, logout);
};
setupIntercepter();
}, []);
const value = React.useMemo(

View File

@@ -1,6 +1,5 @@
import React from "react";
import { useDispatch, useSelector } from "react-redux";
import { useAuth } from "#/context/auth-context";
import {
useWsClient,
WsClientProviderStatus,
@@ -14,17 +13,12 @@ import { AgentState } from "#/types/agent-state";
export const useWSStatusChange = () => {
const { send, status } = useWsClient();
const { gitHubToken } = useAuth();
const { curAgentState } = useSelector((state: RootState) => state.agent);
const dispatch = useDispatch();
const statusRef = React.useRef<WsClientProviderStatus | null>(null);
const { selectedRepository } = useSelector(
(state: RootState) => state.initialQuery,
);
const { files, importedProjectZip, initialQuery } = useSelector(
const { files, initialQuery } = useSelector(
(state: RootState) => state.initialQuery,
);
@@ -33,30 +27,15 @@ export const useWSStatusChange = () => {
send(createChatMessage(query, base64Files, timestamp));
};
const dispatchInitialQuery = (query: string, additionalInfo: string) => {
if (additionalInfo) {
sendInitialQuery(`${query}\n\n[${additionalInfo}]`, files);
} else {
sendInitialQuery(query, files);
}
const dispatchInitialQuery = (query: string) => {
sendInitialQuery(query, files);
dispatch(clearFiles()); // reset selected files
dispatch(clearInitialQuery()); // reset initial query
};
const handleAgentInit = () => {
let additionalInfo = "";
if (gitHubToken && selectedRepository) {
additionalInfo = `Repository ${selectedRepository} has been cloned to /workspace. Please check the /workspace for files.`;
} else if (importedProjectZip) {
// if there's an uploaded project zip, add it to the chat
additionalInfo =
"Files have been uploaded. Please check the /workspace for files.";
}
if (initialQuery) {
dispatchInitialQuery(initialQuery, additionalInfo);
dispatchInitialQuery(initialQuery);
}
};
React.useEffect(() => {

View File

@@ -107,12 +107,7 @@ class CodeActAgent(Agent):
f'TOOLS loaded for CodeActAgent: {json.dumps(self.tools, indent=2, ensure_ascii=False).replace("\\n", "\n")}'
)
self.prompt_manager = PromptManager(
microagent_dir=os.path.join(
os.path.dirname(os.path.dirname(openhands.__file__)),
'microagents',
)
if self.config.use_microagents
else None,
microagent_dir=None, # Will be set in step() when we have access to the runtime
prompt_dir=os.path.join(os.path.dirname(__file__), 'prompts'),
disabled_microagents=self.config.disabled_microagents,
)
@@ -369,6 +364,14 @@ class CodeActAgent(Agent):
- MessageAction(content) - Message action to run (e.g. ask for clarification)
- AgentFinishAction() - end the interaction
"""
# Initialize the prompt_manager with microagents from the runtime
if self.config.use_microagents and 'runtime' in state.inputs:
# Load microagents from the runtime
runtime = state.inputs['runtime']
microagents = runtime.get_microagents_from_selected_repo(None) # None means current workspace
# Load the microagents into the prompt manager
self.prompt_manager.load_microagents(microagents)
# Continue with pending actions if any
if self.pending_actions:
return self.pending_actions.popleft()

View File

@@ -5,9 +5,14 @@ You are OpenHands agent, a helpful AI assistant that can interact with a compute
* The assistant MUST NOT include comments in the code unless they are necessary to describe non-obvious behavior.
{{ runtime_info }}
</IMPORTANT>
{% if repo_instructions -%}
{% if repository_info %}
<REPOSITORY_INFO>
At the user's request, repository {{ repository_info.repo_name }} has been cloned to directory {{ repository_info.repo_directory }}.
</REPOSITORY_INFO>
{% endif %}
{% if repository_instructions -%}
<REPOSITORY_INSTRUCTIONS>
{{ repo_instructions }}
{{ repository_instructions }}
</REPOSITORY_INSTRUCTIONS>
{% endif %}
{% if runtime_info and runtime_info.available_hosts -%}

View File

@@ -50,6 +50,10 @@ class MicroAgent(Agent):
# history is in reverse order, let's fix it
processed_history.reverse()
# everything starts with a message
# the first message is already in the prompt as the task
# TODO: so we don't need to include it in the history
return json.dumps(processed_history, **kwargs)
def __init__(self, llm: LLM, config: AgentConfig):

View File

@@ -112,12 +112,16 @@ class AgentController:
self.id = sid
self.agent = agent
self.headless_mode = headless_mode
self.is_delegate = is_delegate
# subscribe to the event stream
# the event stream must be set before maybe subscribing to it
self.event_stream = event_stream
self.event_stream.subscribe(
EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, self.id
)
# subscribe to the event stream if this is not a delegate
if not self.is_delegate:
self.event_stream.subscribe(
EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, self.id
)
# state from the previous session, state from a parent agent, or a fresh state
self.set_initial_state(
@@ -165,7 +169,11 @@ class AgentController:
)
# unsubscribe from the event stream
self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER, self.id)
# only the root parent controller subscribes to the event stream
if not self.is_delegate:
self.event_stream.unsubscribe(
EventStreamSubscriber.AGENT_CONTROLLER, self.id
)
self._closed = True
def log(self, level: str, message: str, extra: dict | None = None) -> None:
@@ -226,9 +234,21 @@ class AgentController:
await self._react_to_exception(reported)
def should_step(self, event: Event) -> bool:
# it might be the delegate's day in the sun
if self.delegate is not None:
return False
if isinstance(event, Action):
if isinstance(event, MessageAction) and event.source == EventSource.USER:
return True
if (
isinstance(event, MessageAction)
and self.get_agent_state() != AgentState.AWAITING_USER_INPUT
):
# TODO: this is fragile, but how else to check if eligible?
return True
if isinstance(event, AgentDelegateAction):
return True
return False
if isinstance(event, Observation):
if isinstance(event, NullObservation) or isinstance(
@@ -244,12 +264,35 @@ class AgentController:
Args:
event (Event): The incoming event to process.
"""
# If we have a delegate that is not finished or errored, forward events to it
if self.delegate is not None:
delegate_state = self.delegate.get_agent_state()
if delegate_state not in (
AgentState.FINISHED,
AgentState.ERROR,
AgentState.REJECTED,
):
# Forward the event to delegate and skip parent processing
asyncio.get_event_loop().run_until_complete(
self.delegate._on_event(event)
)
return
else:
# delegate is done or errored, so end it
self.end_delegate()
return
# continue parent processing only if there's no active delegate
asyncio.get_event_loop().run_until_complete(self._on_event(event))
async def _on_event(self, event: Event) -> None:
if hasattr(event, 'hidden') and event.hidden:
return
# Give others a little chance
await asyncio.sleep(0.01)
# if the event is not filtered out, add it to the history
if not any(isinstance(event, filter_type) for filter_type in self.filter_out):
self.state.history.append(event)
@@ -263,17 +306,22 @@ class AgentController:
self.step()
async def _handle_action(self, action: Action) -> None:
"""Handles actions from the event stream.
Args:
action (Action): The action to handle.
"""
"""Handles an Action from the agent or delegate."""
if isinstance(action, ChangeAgentStateAction):
await self.set_agent_state_to(action.agent_state) # type: ignore
elif isinstance(action, MessageAction):
await self._handle_message_action(action)
elif isinstance(action, AgentDelegateAction):
await self.start_delegate(action)
assert self.delegate is not None
# Post a MessageAction with the task for the delegate
if 'task' in action.inputs:
self.event_stream.add_event(
MessageAction(content='TASK: ' + action.inputs['task']),
EventSource.USER,
)
await self.delegate.set_agent_state_to(AgentState.RUNNING)
return
elif isinstance(action, AgentFinishAction):
self.state.outputs = action.outputs
@@ -491,7 +539,7 @@ class AgentController:
f'start delegate, creating agent {delegate_agent.name} using LLM {llm}',
)
self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER, self.id)
# Create the delegate with is_delegate=True so it does NOT subscribe directly
self.delegate = AgentController(
sid=self.id + '-delegate',
agent=delegate_agent,
@@ -504,7 +552,57 @@ class AgentController:
is_delegate=True,
headless_mode=self.headless_mode,
)
await self.delegate.set_agent_state_to(AgentState.RUNNING)
def end_delegate(self) -> None:
    """Tear down the active delegate so this controller can resume.

    Does nothing when no delegate is set. Otherwise the shared iteration
    counter is copied from the delegate, the delegate controller is closed,
    an ``AgentDelegateObservation`` describing the outcome is added to the
    event stream, and the delegate references are cleared.
    """
    delegate = self.delegate
    if delegate is None:
        return

    # read the delegate's state before it gets closed
    final_state = delegate.get_agent_state()

    # the iteration counter is shared across agents
    self.state.iteration = delegate.state.iteration

    # the delegate controller must be closed before new events are added
    asyncio.get_event_loop().run_until_complete(delegate.close())

    # whatever the delegate produced is attached to the observation
    outputs = delegate.state.outputs if delegate.state else {}

    if final_state in (AgentState.FINISHED, AgentState.REJECTED):
        # TODO: replace this with AI-generated summary (#2395)
        summary = ', '.join(f'{key}: {value}' for key, value in outputs.items())
        content = f'{delegate.agent.name} finishes task with {summary}'
    else:
        # any other state is reported as an error
        content = f'{delegate.agent.name} encountered an error during execution.'

    # emit the delegate result observation
    self.event_stream.add_event(
        AgentDelegateObservation(outputs=outputs, content=content),
        EventSource.AGENT,
    )

    # unset delegate so parent can resume normal handling
    self.delegate = None
    self.delegateAction = None
async def _step(self) -> None:
"""Executes a single step of the parent or delegate agent. Detects stuck agents and limits on the number of iterations and the task budget."""
@@ -514,14 +612,6 @@ class AgentController:
if self._pending_action:
return
if self.delegate is not None:
assert self.delegate != self
# TODO this conditional will always be false, because the parent controllers are unsubscribed
# remove if it's still useless when delegation is reworked
if self.delegate.get_agent_state() != AgentState.PAUSED:
await self._delegate_step()
return
self.log(
'info',
f'LEVEL {self.state.delegate_level} LOCAL STEP {self.state.local_iteration} GLOBAL STEP {self.state.iteration}',
@@ -611,68 +701,6 @@ class AgentController:
log_level = 'info' if LOG_ALL_EVENTS else 'debug'
self.log(log_level, str(action), extra={'msg_type': 'ACTION'})
async def _delegate_step(self) -> None:
    """Executes a single step of the delegate agent.

    After the delegate has stepped, its agent state is inspected:

    - ERROR: an AgentDelegateObservation with an error message is added to
      the event stream, then the delegate is closed, this controller
      re-subscribes to the event stream and the delegate is cleared.
    - FINISHED / REJECTED: the delegate is closed and this controller
      re-subscribes first; the delegate is cleared and only then is the
      result observation added to the event stream.
    - any other state: nothing further happens.

    In both terminal cases the delegate's iteration counter is copied to
    this controller's state.
    """
    await self.delegate._step()  # type: ignore[union-attr]

    assert self.delegate is not None
    delegate_state = self.delegate.get_agent_state()
    self.log('debug', f'Delegate state: {delegate_state}')
    if delegate_state == AgentState.ERROR:
        # update iteration that shall be shared across agents
        self.state.iteration = self.delegate.state.iteration

        # emit AgentDelegateObservation to mark delegate termination due to error
        delegate_outputs = (
            self.delegate.state.outputs if self.delegate.state else {}
        )
        content = (
            f'{self.delegate.agent.name} encountered an error during execution.'
        )
        obs = AgentDelegateObservation(outputs=delegate_outputs, content=content)
        # NOTE: in this branch the observation is added BEFORE the delegate
        # is closed, unlike the FINISHED/REJECTED branch below
        self.event_stream.add_event(obs, EventSource.AGENT)

        # close the delegate upon error
        await self.delegate.close()

        # resubscribe parent when delegate is finished
        self.event_stream.subscribe(
            EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, self.id
        )
        self.delegate = None
        self.delegateAction = None

    elif delegate_state in (AgentState.FINISHED, AgentState.REJECTED):
        self.log('debug', 'Delegate agent has finished execution')
        # retrieve delegate result
        outputs = self.delegate.state.outputs if self.delegate.state else {}

        # update iteration that shall be shared across agents
        self.state.iteration = self.delegate.state.iteration

        # close delegate controller: we must close the delegate controller before adding new events
        await self.delegate.close()

        # resubscribe parent when delegate is finished
        self.event_stream.subscribe(
            EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, self.id
        )

        # update delegate result observation
        # TODO: replace this with AI-generated summary (#2395)
        formatted_output = ', '.join(
            f'{key}: {value}' for key, value in outputs.items()
        )
        content = (
            f'{self.delegate.agent.name} finishes task with {formatted_output}'
        )
        obs = AgentDelegateObservation(outputs=outputs, content=content)

        # clean up delegate status
        # (the delegate references are cleared before the observation is added)
        self.delegate = None
        self.delegateAction = None
        self.event_stream.add_event(obs, EventSource.AGENT)
    return
async def _handle_traffic_control(
self, limit_type: str, current_value: float, max_value: float
) -> bool:

View File

@@ -138,8 +138,19 @@ class LLMConfig:
This function is used to create an LLMConfig object from a dictionary,
with the exception of the 'draft_editor' key, which is a nested LLMConfig object.
"""
args = {k: v for k, v in llm_config_dict.items() if not isinstance(v, dict)}
if 'draft_editor' in llm_config_dict:
draft_editor_config = LLMConfig(**llm_config_dict['draft_editor'])
args['draft_editor'] = draft_editor_config
# Keep None values to preserve defaults, filter out other dicts
args = {
k: v
for k, v in llm_config_dict.items()
if not isinstance(v, dict) or v is None
}
if (
'draft_editor' in llm_config_dict
and llm_config_dict['draft_editor'] is not None
):
if isinstance(llm_config_dict['draft_editor'], LLMConfig):
args['draft_editor'] = llm_config_dict['draft_editor']
else:
draft_editor_config = LLMConfig(**llm_config_dict['draft_editor'])
args['draft_editor'] = draft_editor_config
return cls(**args)

View File

@@ -41,7 +41,7 @@ class SandboxConfig:
remote_runtime_api_url: str = 'http://localhost:8000'
local_runtime_url: str = 'http://localhost'
keep_runtime_alive: bool = True
keep_runtime_alive: bool = False
rm_all_containers: bool = False
api_key: str | None = None
base_container_image: str = 'nikolaik/python-nodejs:python3.12-nodejs22' # default to nikolaik/python-nodejs:python3.12-nodejs22 for eventstream runtime
@@ -60,7 +60,7 @@ class SandboxConfig:
runtime_startup_env_vars: dict[str, str] = field(default_factory=dict)
browsergym_eval_env: str | None = None
platform: str | None = None
close_delay: int = 900
close_delay: int = 15
remote_runtime_resource_factor: int = 1
enable_gpu: bool = False
docker_runtime_kwargs: str | None = None

View File

@@ -144,15 +144,48 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'):
logger.openhands_logger.debug(
'Attempt to load default LLM config from config toml'
)
llm_config = LLMConfig.from_dict(value)
cfg.set_llm_config(llm_config, 'llm')
# TODO clean up draft_editor
# Extract generic LLM fields, keeping draft_editor
generic_llm_fields = {}
for k, v in value.items():
if not isinstance(v, dict) or k == 'draft_editor':
generic_llm_fields[k] = v
generic_llm_config = LLMConfig.from_dict(generic_llm_fields)
cfg.set_llm_config(generic_llm_config, 'llm')
# Process custom named LLM configs
for nested_key, nested_value in value.items():
if isinstance(nested_value, dict):
logger.openhands_logger.debug(
f'Attempt to load group {nested_key} from config toml as llm config'
f'Processing custom LLM config "{nested_key}":'
)
llm_config = LLMConfig.from_dict(nested_value)
cfg.set_llm_config(llm_config, nested_key)
# Apply generic LLM config with custom LLM overrides, e.g.
# [llm]
# model="..."
# num_retries = 5
# [llm.claude]
# model="claude-3-5-sonnet"
# results in num_retries APPLIED to claude-3-5-sonnet
custom_fields = {}
for k, v in nested_value.items():
if not isinstance(v, dict) or k == 'draft_editor':
custom_fields[k] = v
merged_llm_dict = generic_llm_config.__dict__.copy()
merged_llm_dict.update(custom_fields)
# TODO clean up draft_editor
# Handle draft_editor with fallback values:
# - If draft_editor is "null", use None
# - If draft_editor is in custom fields, use that value
# - If draft_editor is not specified, fall back to generic config value
if 'draft_editor' in custom_fields:
if custom_fields['draft_editor'] == 'null':
merged_llm_dict['draft_editor'] = None
else:
merged_llm_dict['draft_editor'] = (
generic_llm_config.draft_editor
)
custom_llm_config = LLMConfig.from_dict(merged_llm_dict)
cfg.set_llm_config(custom_llm_config, nested_key)
elif key is not None and key.lower() == 'security':
logger.openhands_logger.debug(
'Attempt to load security config from config toml'
@@ -458,7 +491,11 @@ def setup_config_from_args(args: argparse.Namespace) -> AppConfig:
# Override with command line arguments if provided
if args.llm_config:
llm_config = get_llm_config_arg(args.llm_config)
# if we didn't already load it, get it from the toml file
if args.llm_config not in config.llms:
llm_config = get_llm_config_arg(args.llm_config)
else:
llm_config = config.llms[args.llm_config]
if llm_config is None:
raise ValueError(f'Invalid toml file, cannot read {args.llm_config}')
config.set_llm_config(llm_config)

View File

@@ -65,6 +65,7 @@ class EventStream:
_queue: queue.Queue[Event]
_queue_thread: threading.Thread
_queue_loop: asyncio.AbstractEventLoop | None
_thread_pools: dict[str, dict[str, ThreadPoolExecutor]]
_thread_loops: dict[str, dict[str, asyncio.AbstractEventLoop]]
def __init__(self, sid: str, file_store: FileStore):
@@ -72,8 +73,8 @@ class EventStream:
self.file_store = file_store
self._stop_flag = threading.Event()
self._queue: queue.Queue[Event] = queue.Queue()
self._thread_pools: dict[str, dict[str, ThreadPoolExecutor]] = {}
self._thread_loops: dict[str, dict[str, asyncio.AbstractEventLoop]] = {}
self._thread_pools = {}
self._thread_loops = {}
self._queue_loop = None
self._queue_thread = threading.Thread(target=self._run_queue_loop)
self._queue_thread.daemon = True
@@ -257,7 +258,7 @@ class EventStream:
def add_event(self, event: Event, source: EventSource):
if hasattr(event, '_id') and event.id is not None:
raise ValueError(
'Event already has an ID. It was probably added back to the EventStream from inside a handler, trigging a loop.'
f'Event already has an ID:{event.id}. It was probably added back to the EventStream from inside a handler, triggering a loop.'
)
with self._lock:
event._id = self._cur_id # type: ignore [attr-defined]
@@ -285,6 +286,8 @@ class EventStream:
event = self._queue.get(timeout=0.1)
except queue.Empty:
continue
# pass each event to each callback in order
for key in sorted(self._subscribers.keys()):
callbacks = self._subscribers[key]
for callback_id in callbacks:

View File

@@ -8,6 +8,7 @@ from pydantic import BaseModel
from openhands.core.exceptions import (
MicroAgentValidationError,
)
from openhands.core.logger import openhands_logger as logger
from openhands.microagent.types import MicroAgentMetadata, MicroAgentType
@@ -132,8 +133,10 @@ def load_microagents_from_dir(
]:
"""Load all microagents from the given directory.
Note, legacy repo instructions will not be loaded here.
Args:
microagent_dir: Path to the microagents directory.
microagent_dir: Path to the microagents directory (e.g. .openhands/microagents)
Returns:
Tuple of (repo_agents, knowledge_agents, task_agents) dictionaries
@@ -145,20 +148,24 @@ def load_microagents_from_dir(
knowledge_agents = {}
task_agents = {}
# Load all agents
for file in microagent_dir.rglob('*.md'):
# skip README.md
if file.name == 'README.md':
continue
try:
agent = BaseMicroAgent.load(file)
if isinstance(agent, RepoMicroAgent):
repo_agents[agent.name] = agent
elif isinstance(agent, KnowledgeMicroAgent):
knowledge_agents[agent.name] = agent
elif isinstance(agent, TaskMicroAgent):
task_agents[agent.name] = agent
except Exception as e:
raise ValueError(f'Error loading agent from {file}: {e}')
# Load all agents from .openhands/microagents directory
logger.debug(f'Loading agents from {microagent_dir}')
if microagent_dir.exists():
for file in microagent_dir.rglob('*.md'):
logger.debug(f'Checking file {file}...')
# skip README.md
if file.name == 'README.md':
continue
try:
agent = BaseMicroAgent.load(file)
if isinstance(agent, RepoMicroAgent):
repo_agents[agent.name] = agent
elif isinstance(agent, KnowledgeMicroAgent):
knowledge_agents[agent.name] = agent
elif isinstance(agent, TaskMicroAgent):
task_agents[agent.name] = agent
logger.debug(f'Loaded agent {agent.name} from {file}')
except Exception as e:
raise ValueError(f'Error loading agent from {file}: {e}')
return repo_agents, knowledge_agents, task_agents

View File

@@ -610,10 +610,14 @@ def parse_unified_diff(text):
# - Start at line 1 in the old file and show 6 lines
# - Start at line 1 in the new file and show 6 lines
old = int(h.group(1)) # Starting line in old file
old_len = int(h.group(2)) if len(h.group(2)) > 0 else 1 # Number of lines in old file
old_len = (
int(h.group(2)) if len(h.group(2)) > 0 else 1
) # Number of lines in old file
new = int(h.group(3)) # Starting line in new file
new_len = int(h.group(4)) if len(h.group(4)) > 0 else 1 # Number of lines in new file
new_len = (
int(h.group(4)) if len(h.group(4)) > 0 else 1
) # Number of lines in new file
h = None
break
@@ -622,7 +626,9 @@ def parse_unified_diff(text):
for n in hunk:
# Each line in a unified diff starts with a space (context), + (addition), or - (deletion)
# The first character is the kind, the rest is the line content
kind = n[0] if len(n) > 0 else ' ' # Empty lines in the hunk are treated as context lines
kind = (
n[0] if len(n) > 0 else ' '
) # Empty lines in the hunk are treated as context lines
line = n[1:] if len(n) > 1 else ''
# Process the line based on its kind

View File

@@ -4,10 +4,13 @@ import copy
import json
import os
import random
import shutil
import string
import tempfile
from abc import abstractmethod
from pathlib import Path
from typing import Callable
from zipfile import ZipFile
from requests.exceptions import ConnectionError
@@ -37,9 +40,7 @@ from openhands.events.observation import (
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
from openhands.microagent import (
BaseMicroAgent,
KnowledgeMicroAgent,
RepoMicroAgent,
TaskMicroAgent,
load_microagents_from_dir,
)
from openhands.runtime.plugins import (
JupyterRequirement,
@@ -125,7 +126,7 @@ class Runtime(FileEditRuntimeMixin):
def setup_initial_env(self) -> None:
if self.attach_to_existing:
return
logger.debug(f'Adding env vars: {self.initial_env_vars}')
logger.debug(f'Adding env vars: {self.initial_env_vars.keys()}')
self.add_env_vars(self.initial_env_vars)
if self.config.sandbox.runtime_startup_env_vars:
self.add_env_vars(self.config.sandbox.runtime_startup_env_vars)
@@ -172,7 +173,7 @@ class Runtime(FileEditRuntimeMixin):
obs = self.run(CmdRunAction(cmd))
if not isinstance(obs, CmdOutputObservation) or obs.exit_code != 0:
raise RuntimeError(
f'Failed to add env vars [{env_vars}] to environment: {obs.content}'
f'Failed to add env vars [{env_vars.keys()}] to environment: {obs.content}'
)
def on_event(self, event: Event) -> None:
@@ -206,7 +207,7 @@ class Runtime(FileEditRuntimeMixin):
source = event.source if event.source else EventSource.AGENT
self.event_stream.add_event(observation, source) # type: ignore[arg-type]
def clone_repo(self, github_token: str, selected_repository: str):
def clone_repo(self, github_token: str, selected_repository: str) -> str:
if not github_token or not selected_repository:
raise ValueError(
'github_token and selected_repository must be provided to clone a repository'
@@ -223,25 +224,42 @@ class Runtime(FileEditRuntimeMixin):
)
self.log('info', f'Cloning repo: {selected_repository}')
self.run_action(action)
return dir_name
def get_microagents_from_selected_repo(
self, selected_repository: str | None
) -> list[BaseMicroAgent]:
"""Load microagents from the selected repository.
If selected_repository is None, load microagents from the current workspace.
This is the main entry point for loading microagents.
"""
loaded_microagents: list[BaseMicroAgent] = []
dir_name = Path('.openhands') / 'microagents'
workspace_root = Path(self.config.workspace_mount_path_in_sandbox)
microagents_dir = workspace_root / '.openhands' / 'microagents'
repo_root = None
if selected_repository:
dir_name = Path('/workspace') / selected_repository.split('/')[1] / dir_name
repo_root = workspace_root / selected_repository.split('/')[1]
microagents_dir = repo_root / '.openhands' / 'microagents'
self.log(
'info',
f'Selected repo: {selected_repository}, loading microagents from {microagents_dir} (inside runtime)',
)
# Legacy Repo Instructions
# Check for legacy .openhands_instructions file
obs = self.read(FileReadAction(path='.openhands_instructions'))
if isinstance(obs, ErrorObservation):
obs = self.read(
FileReadAction(path=str(workspace_root / '.openhands_instructions'))
)
if isinstance(obs, ErrorObservation) and repo_root is not None:
# If the instructions file is not found in the workspace root, try to load it from the repo root
self.log(
'debug',
f'openhands_instructions not present, trying to load from {dir_name}',
f'.openhands_instructions not present, trying to load from repository {microagents_dir=}',
)
obs = self.read(
FileReadAction(path=str(dir_name / '.openhands_instructions'))
FileReadAction(path=str(repo_root / '.openhands_instructions'))
)
if isinstance(obs, FileReadObservation):
@@ -252,44 +270,40 @@ class Runtime(FileEditRuntimeMixin):
)
)
# Check for local repository microagents
files = self.list_files(str(dir_name))
self.log('info', f'Found {len(files)} local microagents.')
if 'repo.md' in files:
obs = self.read(FileReadAction(path=str(dir_name / 'repo.md')))
if isinstance(obs, FileReadObservation):
self.log('info', 'repo.md microagent loaded.')
loaded_microagents.append(
RepoMicroAgent.load(
path=str(dir_name / 'repo.md'), file_content=obs.content
)
)
# Load microagents from directory
files = self.list_files(str(microagents_dir))
if files:
self.log('info', f'Found {len(files)} files in microagents directory.')
zip_path = self.copy_from(str(microagents_dir))
microagent_folder = tempfile.mkdtemp()
if 'knowledge' in files:
knowledge_dir = dir_name / 'knowledge'
_knowledge_microagents_files = self.list_files(str(knowledge_dir))
for fname in _knowledge_microagents_files:
obs = self.read(FileReadAction(path=str(knowledge_dir / fname)))
if isinstance(obs, FileReadObservation):
self.log('info', f'knowledge/{fname} microagent loaded.')
loaded_microagents.append(
KnowledgeMicroAgent.load(
path=str(knowledge_dir / fname), file_content=obs.content
)
)
# Properly handle the zip file
with ZipFile(zip_path, 'r') as zip_file:
zip_file.extractall(microagent_folder)
# Add debug print of directory structure
self.log('debug', 'Microagent folder structure:')
for root, _, files in os.walk(microagent_folder):
relative_path = os.path.relpath(root, microagent_folder)
self.log('debug', f'Directory: {relative_path}/')
for file in files:
self.log('debug', f' File: {os.path.join(relative_path, file)}')
# Clean up the temporary zip file
zip_path.unlink()
# Load all microagents using the existing function
repo_agents, knowledge_agents, task_agents = load_microagents_from_dir(
microagent_folder
)
self.log(
'info',
f'Loaded {len(repo_agents)} repo agents, {len(knowledge_agents)} knowledge agents, and {len(task_agents)} task agents',
)
loaded_microagents.extend(repo_agents.values())
loaded_microagents.extend(knowledge_agents.values())
loaded_microagents.extend(task_agents.values())
shutil.rmtree(microagent_folder)
if 'tasks' in files:
tasks_dir = dir_name / 'tasks'
_tasks_microagents_files = self.list_files(str(tasks_dir))
for fname in _tasks_microagents_files:
obs = self.read(FileReadAction(path=str(tasks_dir / fname)))
if isinstance(obs, FileReadObservation):
self.log('info', f'tasks/{fname} microagent loaded.')
loaded_microagents.append(
TaskMicroAgent.load(
path=str(tasks_dir / fname), file_content=obs.content
)
)
return loaded_microagents
def run_action(self, action: Action) -> Observation:

View File

@@ -9,6 +9,7 @@ from openhands.core.exceptions import AgentRuntimeBuildError
from openhands.core.logger import openhands_logger as logger
from openhands.runtime.builder import RuntimeBuilder
from openhands.runtime.utils.request import send_request
from openhands.utils.http_session import HttpSession
from openhands.utils.shutdown_listener import (
should_continue,
sleep_if_should_continue,
@@ -18,12 +19,10 @@ from openhands.utils.shutdown_listener import (
class RemoteRuntimeBuilder(RuntimeBuilder):
"""This class interacts with the remote Runtime API for building and managing container images."""
def __init__(
self, api_url: str, api_key: str, session: requests.Session | None = None
):
def __init__(self, api_url: str, api_key: str, session: HttpSession | None = None):
self.api_url = api_url
self.api_key = api_key
self.session = session or requests.Session()
self.session = session or HttpSession()
self.session.headers.update({'X-API-Key': self.api_key})
def build(

View File

@@ -35,6 +35,7 @@ from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
from openhands.runtime.base import Runtime
from openhands.runtime.plugins import PluginRequirement
from openhands.runtime.utils.request import send_request
from openhands.utils.http_session import HttpSession
class ActionExecutionClient(Runtime):
@@ -55,7 +56,7 @@ class ActionExecutionClient(Runtime):
attach_to_existing: bool = False,
headless_mode: bool = True,
):
self.session = requests.Session()
self.session = HttpSession()
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
self._runtime_initialized: bool = False
self._vscode_token: str | None = None # initial dummy value

View File

@@ -229,6 +229,13 @@ class RemoteRuntime(ActionExecutionClient):
raise AgentRuntimeUnavailableError() from e
def _resume_runtime(self):
"""
1. Show status update that runtime is being started.
2. Send the runtime API a /resume request
3. Poll for the runtime to be ready
4. Update env vars
"""
self.send_status_message('STATUS$STARTING_RUNTIME')
with self._send_runtime_api_request(
'POST',
f'{self.config.sandbox.remote_runtime_api_url}/resume',

View File

@@ -4,6 +4,7 @@ from typing import Any
import requests
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
from openhands.utils.http_session import HttpSession
from openhands.utils.tenacity_stop import stop_if_should_exit
@@ -34,7 +35,7 @@ def is_retryable_error(exception):
wait=wait_exponential(multiplier=1, min=4, max=60),
)
def send_request(
session: requests.Session,
session: HttpSession,
method: str,
url: str,
timeout: int = 10,
@@ -48,11 +49,11 @@ def send_request(
_json = response.json()
except (requests.exceptions.JSONDecodeError, json.decoder.JSONDecodeError):
_json = None
finally:
response.close()
raise RequestHTTPError(
e,
response=e.response,
detail=_json.get('detail') if _json is not None else None,
) from e
finally:
response.close()
return response

View File

@@ -11,7 +11,6 @@ from fastapi import (
import openhands.agenthub # noqa F401 (we import this to get the agents registered)
from openhands.server.middleware import (
AttachConversationMiddleware,
GitHubTokenMiddleware,
InMemoryRateLimiter,
LocalhostCORSMiddleware,
NoCacheMiddleware,
@@ -45,7 +44,6 @@ app.add_middleware(
allow_headers=['*'],
)
app.add_middleware(GitHubTokenMiddleware)
app.add_middleware(NoCacheMiddleware)
app.add_middleware(
RateLimitMiddleware, rate_limiter=InMemoryRateLimiter(requests=10, seconds=1)

View File

@@ -5,8 +5,8 @@ from jwt.exceptions import InvalidTokenError
from openhands.core.logger import openhands_logger as logger
def get_user_id(request: Request) -> int:
return getattr(request.state, 'github_user_id', 0)
def get_user_id(request: Request) -> str | None:
return getattr(request.state, 'github_user_id', None)
def get_sid_from_token(token: str, jwt_secret: str) -> str:

View File

@@ -166,16 +166,3 @@ class AttachConversationMiddleware(SessionMiddlewareInterface):
await self._detach_session(request)
return response
class GitHubTokenMiddleware(BaseHTTPMiddleware):
    """Middleware that requires an ``X-GitHub-Token`` header on GitHub API routes.

    Requests whose path starts with ``/api/github`` must carry the header;
    its value is stored on ``request.state.github_token``. Requests to any
    other path are passed through untouched.
    """

    async def dispatch(self, request: Request, call_next):
        # only the GitHub API routes are subject to the header check
        targets_github_api = request.url.path.startswith('/api/github')
        if not targets_github_api:
            return await call_next(request)

        token = request.headers.get('X-GitHub-Token')
        if token:
            # make the token available to downstream handlers
            request.state.github_token = token
            return await call_next(request)

        # missing or empty header: reject without calling the next handler
        return JSONResponse(
            status_code=400,
            content={'error': 'Missing X-GitHub-Token header'},
        )

View File

@@ -1,5 +1,5 @@
import requests
from fastapi import APIRouter, HTTPException, Request
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse
from openhands.server.shared import openhands_config
@@ -8,13 +8,23 @@ from openhands.utils.async_utils import call_sync_from_async
app = APIRouter(prefix='/api/github')
def require_github_token(request: Request):
    """Return the GitHub token carried in the ``X-GitHub-Token`` request header.

    Args:
        request: The incoming request whose headers are inspected.

    Returns:
        The value of the ``X-GitHub-Token`` header.

    Raises:
        HTTPException: With status code 400 when the header is missing or empty.
    """
    token = request.headers.get('X-GitHub-Token')
    if token:
        return token
    raise HTTPException(
        status_code=400,
        detail='Missing X-GitHub-Token header',
    )
@app.get('/repositories')
async def get_github_repositories(
request: Request,
page: int = 1,
per_page: int = 10,
sort: str = 'pushed',
installation_id: int | None = None,
github_token: str = Depends(require_github_token),
):
openhands_config.verify_github_repo_list(installation_id)
@@ -33,7 +43,7 @@ async def get_github_repositories(
params['sort'] = sort
# Set the authorization header with the GitHub token
headers = generate_github_headers(request.state.github_token)
headers = generate_github_headers(github_token)
# Fetch repositories from GitHub
try:
@@ -59,8 +69,8 @@ async def get_github_repositories(
@app.get('/user')
async def get_github_user(request: Request):
headers = generate_github_headers(request.state.github_token)
async def get_github_user(github_token: str = Depends(require_github_token)):
headers = generate_github_headers(github_token)
try:
response = await call_sync_from_async(
requests.get, 'https://api.github.com/user', headers=headers
@@ -79,8 +89,10 @@ async def get_github_user(request: Request):
@app.get('/installations')
async def get_github_installation_ids(request: Request):
headers = generate_github_headers(request.state.github_token)
async def get_github_installation_ids(
github_token: str = Depends(require_github_token),
):
headers = generate_github_headers(github_token)
try:
response = await call_sync_from_async(
requests.get, 'https://api.github.com/user/installations', headers=headers
@@ -102,13 +114,13 @@ async def get_github_installation_ids(request: Request):
@app.get('/search/repositories')
async def search_github_repositories(
request: Request,
query: str,
per_page: int = 5,
sort: str = 'stars',
order: str = 'desc',
github_token: str = Depends(require_github_token),
):
headers = generate_github_headers(request.state.github_token)
headers = generate_github_headers(github_token)
params = {
'q': query,
'per_page': per_page,

View File

@@ -130,7 +130,7 @@ async def search_conversations(
for conversation in conversation_metadata_result_set.results
if hasattr(conversation, 'created_at')
)
running_conversations = await session_manager.get_agent_loop_running(
running_conversations = await session_manager.get_running_agent_loops(
get_user_id(request), set(conversation_ids)
)
result = ConversationInfoResultSet(
@@ -222,7 +222,7 @@ async def _get_conversation_info(
def _create_conversation_update_callback(
user_id: int, conversation_id: str
user_id: str | None, conversation_id: str
) -> Callable:
def callback(*args, **kwargs):
call_async_from_sync(
@@ -235,7 +235,7 @@ def _create_conversation_update_callback(
return callback
async def _update_timestamp_for_conversation(user_id: int, conversation_id: str):
async def _update_timestamp_for_conversation(user_id: str, conversation_id: str):
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
conversation = await conversation_store.get_metadata(conversation_id)
conversation.last_updated_at = datetime.now(timezone.utc)

View File

@@ -1,4 +1,5 @@
import asyncio
import time
from typing import Callable, Optional
from openhands.controller import AgentController
@@ -16,10 +17,10 @@ from openhands.runtime import get_runtime_cls
from openhands.runtime.base import Runtime
from openhands.security import SecurityAnalyzer, options
from openhands.storage.files import FileStore
from openhands.utils.async_utils import call_async_from_sync, call_sync_from_async
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.shutdown_listener import should_continue
WAIT_TIME_BEFORE_CLOSE = 300
WAIT_TIME_BEFORE_CLOSE = 90
WAIT_TIME_BEFORE_CLOSE_INTERVAL = 5
@@ -36,7 +37,8 @@ class AgentSession:
controller: AgentController | None = None
runtime: Runtime | None = None
security_analyzer: SecurityAnalyzer | None = None
_initializing: bool = False
_starting: bool = False
_started_at: float = 0
_closed: bool = False
loop: asyncio.AbstractEventLoop | None = None
@@ -88,7 +90,8 @@ class AgentSession:
if self._closed:
logger.warning('Session closed before starting')
return
self._initializing = True
self._starting = True
self._started_at = time.time()
self._create_security_analyzer(config.security.security_analyzer)
await self._create_runtime(
runtime_name=runtime_name,
@@ -109,24 +112,19 @@ class AgentSession:
self.event_stream.add_event(
ChangeAgentStateAction(AgentState.INIT), EventSource.ENVIRONMENT
)
self._initializing = False
self._starting = False
def close(self):
async def close(self):
"""Closes the Agent session"""
if self._closed:
return
self._closed = True
call_async_from_sync(self._close)
async def _close(self):
seconds_waited = 0
while self._initializing and should_continue():
while self._starting and should_continue():
logger.debug(
f'Waiting for initialization to finish before closing session {self.sid}'
)
await asyncio.sleep(WAIT_TIME_BEFORE_CLOSE_INTERVAL)
seconds_waited += WAIT_TIME_BEFORE_CLOSE_INTERVAL
if seconds_waited > WAIT_TIME_BEFORE_CLOSE:
if time.time() <= self._started_at + WAIT_TIME_BEFORE_CLOSE:
logger.error(
f'Waited too long for initialization to finish before closing session {self.sid}'
)
@@ -212,8 +210,9 @@ class AgentSession:
)
return
repo_directory = None
if selected_repository:
await call_sync_from_async(
repo_directory = await call_sync_from_async(
self.runtime.clone_repo, github_token, selected_repository
)
@@ -223,6 +222,10 @@ class AgentSession:
self.runtime.get_microagents_from_selected_repo, selected_repository
)
agent.prompt_manager.load_microagents(microagents)
if selected_repository and repo_directory:
agent.prompt_manager.set_repository_info(
selected_repository, repo_directory
)
logger.debug(
f'Runtime initialized with plugins: {[plugin.name for plugin in self.runtime.plugins]}'
@@ -306,3 +309,12 @@ class AgentSession:
else:
logger.debug('No events found, no state to restore')
return restored_state
def get_state(self) -> AgentState | None:
    """Report the agent state for this session.

    Returns:
        The controller's current agent state when a controller exists;
        ``AgentState.ERROR`` when there is no controller and more than
        ``WAIT_TIME_BEFORE_CLOSE`` seconds have passed since ``_started_at``;
        ``None`` otherwise.
    """
    # Read the attribute once so the check and the use see the same object.
    active_controller = self.controller
    if not active_controller:
        # Still no controller once the startup window has elapsed means
        # something has gone wrong.
        startup_deadline = self._started_at + WAIT_TIME_BEFORE_CLOSE
        return AgentState.ERROR if time.time() > startup_deadline else None
    return active_controller.state.agent_state

View File

@@ -2,6 +2,7 @@ import asyncio
import json
import time
from dataclasses import dataclass, field
from typing import Generic, Iterable, TypeVar
from uuid import uuid4
import socketio
@@ -9,26 +10,28 @@ import socketio
from openhands.core.config import AppConfig
from openhands.core.exceptions import AgentRuntimeUnavailableError
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState
from openhands.events.stream import EventStream, session_exists
from openhands.server.session.conversation import Conversation
from openhands.server.session.session import ROOM_KEY, Session
from openhands.server.settings import Settings
from openhands.storage.files import FileStore
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.async_utils import wait_all
from openhands.utils.shutdown_listener import should_continue
_REDIS_POLL_TIMEOUT = 1.5
_CHECK_ALIVE_INTERVAL = 15
_CLEANUP_INTERVAL = 15
_CLEANUP_EXCEPTION_WAIT_TIME = 15
MAX_RUNNING_CONVERSATIONS = 3
T = TypeVar('T')
@dataclass
class _SessionIsRunningCheck:
request_id: str
request_sids: list[str]
running_sids: set[str] = field(default_factory=set)
class _ClusterQuery(Generic[T]):
query_id: str
request_ids: set[str] | None
result: T
flag: asyncio.Event = field(default_factory=asyncio.Event)
@@ -38,10 +41,10 @@ class SessionManager:
config: AppConfig
file_store: FileStore
_local_agent_loops_by_sid: dict[str, Session] = field(default_factory=dict)
local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict)
_local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict)
_last_alive_timestamps: dict[str, float] = field(default_factory=dict)
_redis_listen_task: asyncio.Task | None = None
_session_is_running_checks: dict[str, _SessionIsRunningCheck] = field(
_running_sid_queries: dict[str, _ClusterQuery[set[str]]] = field(
default_factory=dict
)
_active_conversations: dict[str, tuple[Conversation, int]] = field(
@@ -52,7 +55,7 @@ class SessionManager:
)
_conversations_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
_cleanup_task: asyncio.Task | None = None
_has_remote_connections_flags: dict[str, asyncio.Event] = field(
_connection_queries: dict[str, _ClusterQuery[dict[str, str]]] = field(
default_factory=dict
)
@@ -60,7 +63,7 @@ class SessionManager:
redis_client = self._get_redis_client()
if redis_client:
self._redis_listen_task = asyncio.create_task(self._redis_subscribe())
self._cleanup_task = asyncio.create_task(self._cleanup_detached_conversations())
self._cleanup_task = asyncio.create_task(self._cleanup_stale())
return self
async def __aexit__(self, exc_type, exc_value, traceback):
@@ -82,7 +85,7 @@ class SessionManager:
logger.debug('_redis_subscribe')
redis_client = self._get_redis_client()
pubsub = redis_client.pubsub()
await pubsub.subscribe('oh_event')
await pubsub.subscribe('session_msg')
while should_continue():
try:
message = await pubsub.get_message(
@@ -108,59 +111,71 @@ class SessionManager:
session = self._local_agent_loops_by_sid.get(sid)
if session:
await session.dispatch(data['data'])
elif message_type == 'is_session_running':
elif message_type == 'running_agent_loops_query':
# Another node in the cluster is asking if the current node is running the session given.
request_id = data['request_id']
sids = [
sid for sid in data['sids'] if sid in self._local_agent_loops_by_sid
]
query_id = data['query_id']
sids = self._get_running_agent_loops_locally(
data.get('user_id'), data.get('filter_to_sids')
)
if sids:
await self._get_redis_client().publish(
'oh_event',
'session_msg',
json.dumps(
{
'request_id': request_id,
'sids': sids,
'message_type': 'session_is_running',
'query_id': query_id,
'sids': list(sids),
'message_type': 'running_agent_loops_response',
}
),
)
elif message_type == 'session_is_running':
request_id = data['request_id']
elif message_type == 'running_agent_loops_response':
query_id = data['query_id']
for sid in data['sids']:
self._last_alive_timestamps[sid] = time.time()
check = self._session_is_running_checks.get(request_id)
if check:
check.running_sids.update(data['sids'])
if len(check.request_sids) == len(check.running_sids):
check.flag.set()
elif message_type == 'has_remote_connections_query':
running_query = self._running_sid_queries.get(query_id)
if running_query:
running_query.result.update(data['sids'])
if running_query.request_ids is not None and len(
running_query.request_ids
) == len(running_query.result):
running_query.flag.set()
elif message_type == 'connections_query':
# Another node in the cluster is asking if the current node is connected to a session
sid = data['sid']
required = sid in self.local_connection_id_to_session_id.values()
if required:
query_id = data['query_id']
connections = self._get_connections_locally(
data.get('user_id'), data.get('filter_to_sids')
)
if connections:
await self._get_redis_client().publish(
'oh_event',
'session_msg',
json.dumps(
{'sid': sid, 'message_type': 'has_remote_connections_response'}
{
'query_id': query_id,
'connections': connections,
'message_type': 'connections_response',
}
),
)
elif message_type == 'has_remote_connections_response':
sid = data['sid']
flag = self._has_remote_connections_flags.get(sid)
if flag:
flag.set()
elif message_type == 'connections_response':
query_id = data['query_id']
connection_query = self._connection_queries.get(query_id)
if connection_query:
connection_query.result.update(**data['connections'])
if connection_query.request_ids is not None and len(
connection_query.request_ids
) == len(connection_query.result):
connection_query.flag.set()
elif message_type == 'close_session':
sid = data['sid']
if sid in self._local_agent_loops_by_sid:
await self._on_close_session(sid)
await self._close_session(sid)
elif message_type == 'session_closing':
# Session closing event - We only get this in the event of graceful shutdown,
# which can't be guaranteed - nodes can simply vanish unexpectedly!
sid = data['sid']
logger.debug(f'session_closing:{sid}')
# Create a list of items to process to avoid modifying dict during iteration
items = list(self.local_connection_id_to_session_id.items())
items = list(self._local_connection_id_to_session_id.items())
for connection_id, local_sid in items:
if sid == local_sid:
logger.warning(
@@ -204,11 +219,11 @@ class SessionManager:
return c
async def join_conversation(
self, sid: str, connection_id: str, settings: Settings, user_id: int | None
self, sid: str, connection_id: str, settings: Settings, user_id: str | None
):
logger.info(f'join_conversation:{sid}:{connection_id}')
await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
self.local_connection_id_to_session_id[connection_id] = sid
self._local_connection_id_to_session_id[connection_id] = sid
event_stream = await self._get_event_stream(sid)
if not event_stream:
return await self.maybe_start_agent_loop(sid, settings, user_id)
@@ -226,7 +241,7 @@ class SessionManager:
self._active_conversations.pop(sid)
self._detached_conversations[sid] = (conversation, time.time())
async def _cleanup_detached_conversations(self):
async def _cleanup_stale(self):
while should_continue():
if self._get_redis_client():
# Debug info for HA envs
@@ -240,7 +255,7 @@ class SessionManager:
f'Running agent loops: {len(self._local_agent_loops_by_sid)}'
)
logger.info(
f'Local connections: {len(self.local_connection_id_to_session_id)}'
f'Local connections: {len(self._local_connection_id_to_session_id)}'
)
try:
async with self._conversations_lock:
@@ -250,107 +265,196 @@ class SessionManager:
await conversation.disconnect()
self._detached_conversations.pop(sid, None)
close_threshold = time.time() - self.config.sandbox.close_delay
running_loops = list(self._local_agent_loops_by_sid.items())
running_loops.sort(key=lambda item: item[1].last_active_ts)
sid_to_close: list[str] = []
for sid, session in running_loops:
state = session.agent_session.get_state()
if session.last_active_ts < close_threshold and state not in [
AgentState.RUNNING,
None,
]:
sid_to_close.append(sid)
connections = self._get_connections_locally(
filter_to_sids=set(sid_to_close)
)
connected_sids = {sid for _, sid in connections.items()}
sid_to_close = [
sid for sid in sid_to_close if sid not in connected_sids
]
if sid_to_close:
connections = await self._get_connections_remotely(
filter_to_sids=set(sid_to_close)
)
connected_sids = {sid for _, sid in connections.items()}
sid_to_close = [
sid for sid in sid_to_close if sid not in connected_sids
]
await wait_all(self._close_session(sid) for sid in sid_to_close)
await asyncio.sleep(_CLEANUP_INTERVAL)
except asyncio.CancelledError:
async with self._conversations_lock:
for conversation, _ in self._detached_conversations.values():
await conversation.disconnect()
self._detached_conversations.clear()
await wait_all(
self._close_session(sid) for sid in self._local_agent_loops_by_sid
)
return
except Exception as e:
logger.warning(f'error_cleaning_detached_conversations: {str(e)}')
await asyncio.sleep(_CLEANUP_EXCEPTION_WAIT_TIME)
async def get_agent_loop_running(self, user_id, sids: set[str]) -> set[str]:
running_sids = set(sid for sid in sids if sid in self._local_agent_loops_by_sid)
check_cluster_sids = [sid for sid in sids if sid not in running_sids]
running_cluster_sids = await self.get_agent_loop_running_in_cluster(
check_cluster_sids
)
running_sids.union(running_cluster_sids)
return running_sids
logger.warning(f'error_cleaning_stale: {str(e)}')
await asyncio.sleep(_CLEANUP_INTERVAL)
async def is_agent_loop_running(self, sid: str) -> bool:
if await self.is_agent_loop_running_locally(sid):
return True
if await self.is_agent_loop_running_in_cluster(sid):
return True
return False
sids = await self.get_running_agent_loops(filter_to_sids={sid})
return bool(sids)
async def is_agent_loop_running_locally(self, sid: str) -> bool:
return sid in self._local_agent_loops_by_sid
async def get_running_agent_loops(
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
) -> set[str]:
"""Get the running session ids. If a user is supplied, then the results are limited to session ids for that user. If a set of filter_to_sids is supplied, then results are limited to these ids of interest."""
sids = self._get_running_agent_loops_locally(user_id, filter_to_sids)
remote_sids = await self._get_running_agent_loops_remotely(
user_id, filter_to_sids
)
return sids.union(remote_sids)
async def is_agent_loop_running_in_cluster(self, sid: str) -> bool:
running_sids = await self.get_agent_loop_running_in_cluster([sid])
return bool(running_sids)
def _get_running_agent_loops_locally(
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
) -> set[str]:
items: Iterable[tuple[str, Session]] = self._local_agent_loops_by_sid.items()
if filter_to_sids is not None:
items = (item for item in items if item[0] in filter_to_sids)
if user_id:
items = (item for item in items if item[1].user_id == user_id)
sids = {sid for sid, _ in items}
return sids
async def get_agent_loop_running_in_cluster(self, sids: list[str]) -> set[str]:
async def _get_running_agent_loops_remotely(
self,
user_id: str | None = None,
filter_to_sids: set[str] | None = None,
) -> set[str]:
"""As the rest of the cluster if a session is running. Wait a for a short timeout for a reply"""
redis_client = self._get_redis_client()
if not redis_client:
return set()
flag = asyncio.Event()
request_id = str(uuid4())
check = _SessionIsRunningCheck(request_id=request_id, request_sids=sids)
self._session_is_running_checks[request_id] = check
query_id = str(uuid4())
query = _ClusterQuery[set[str]](
query_id=query_id, request_ids=filter_to_sids, result=set()
)
self._running_sid_queries[query_id] = query
try:
logger.debug(f'publish:is_session_running:{sids}')
await redis_client.publish(
'oh_event',
json.dumps(
{
'request_id': request_id,
'sids': sids,
'message_type': 'is_session_running',
}
),
logger.debug(
f'publish:_get_running_agent_loops_remotely_query:{user_id}:{filter_to_sids}'
)
data: dict = {
'query_id': query_id,
'message_type': 'running_agent_loops_query',
}
if user_id:
data['user_id'] = user_id
if filter_to_sids:
data['filter_to_sids'] = list(filter_to_sids)
await redis_client.publish('session_msg', json.dumps(data))
async with asyncio.timeout(_REDIS_POLL_TIMEOUT):
await flag.wait()
return check.running_sids
return query.result
except TimeoutError:
# Nobody replied in time
return check.running_sids
return query.result
finally:
self._session_is_running_checks.pop(request_id, None)
self._running_sid_queries.pop(query_id, None)
async def get_connections(
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
) -> dict[str, str]:
connection_ids = self._get_connections_locally(user_id, filter_to_sids)
remote_connection_ids = await self._get_connections_remotely(
user_id, filter_to_sids
)
connection_ids.update(**remote_connection_ids)
return connection_ids
def _get_connections_locally(
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
) -> dict[str, str]:
connections = dict(**self._local_connection_id_to_session_id)
if filter_to_sids is not None:
connections = {
connection_id: sid
for connection_id, sid in connections.items()
if sid in filter_to_sids
}
if user_id:
for connection_id, sid in list(connections.items()):
session = self._local_agent_loops_by_sid.get(sid)
if not session or session.user_id != user_id:
connections.pop(connection_id)
return connections
async def _get_connections_remotely(
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
) -> dict[str, str]:
redis_client = self._get_redis_client()
if not redis_client:
return {}
async def _has_remote_connections(self, sid: str) -> bool:
"""As the rest of the cluster if they still want this session running. Wait a for a short timeout for a reply"""
# Create a flag for the callback
flag = asyncio.Event()
self._has_remote_connections_flags[sid] = flag
query_id = str(uuid4())
query = _ClusterQuery[dict[str, str]](
query_id=query_id, request_ids=filter_to_sids, result={}
)
self._connection_queries[query_id] = query
try:
await self._get_redis_client().publish(
'oh_event',
json.dumps(
{
'sid': sid,
'message_type': 'has_remote_connections_query',
}
),
logger.debug(
f'publish:get_connections_remotely_query:{user_id}:{filter_to_sids}'
)
data: dict = {
'query_id': query_id,
'message_type': 'connections_query',
}
if user_id:
data['user_id'] = user_id
if filter_to_sids:
data['filter_to_sids'] = list(filter_to_sids)
await redis_client.publish('session_msg', json.dumps(data))
async with asyncio.timeout(_REDIS_POLL_TIMEOUT):
await flag.wait()
result = flag.is_set()
return result
return query.result
except TimeoutError:
# Nobody replied in time
return False
return query.result
finally:
self._has_remote_connections_flags.pop(sid, None)
self._connection_queries.pop(query_id, None)
async def maybe_start_agent_loop(
self, sid: str, settings: Settings, user_id: int | None
self, sid: str, settings: Settings, user_id: str | None
) -> EventStream:
logger.info(f'maybe_start_agent_loop:{sid}')
session: Session | None = None
if not await self.is_agent_loop_running(sid):
logger.info(f'start_agent_loop:{sid}')
response_ids = await self.get_running_agent_loops(user_id)
if len(response_ids) >= MAX_RUNNING_CONVERSATIONS:
logger.info('too_many_sessions_for:{user_id}')
await self.close_session(next(iter(response_ids)))
session = Session(
sid=sid, file_store=self.file_store, config=self.config, sio=self.sio
sid=sid,
file_store=self.file_store,
config=self.config,
sio=self.sio,
user_id=user_id,
)
self._local_agent_loops_by_sid[sid] = session
asyncio.create_task(session.initialize_agent(settings))
@@ -359,7 +463,6 @@ class SessionManager:
if not event_stream:
logger.error(f'No event stream after starting agent loop: {sid}')
raise RuntimeError(f'no_event_stream:{sid}')
asyncio.create_task(self._cleanup_session_later(sid))
return event_stream
async def _get_event_stream(self, sid: str) -> EventStream | None:
@@ -369,7 +472,7 @@ class SessionManager:
logger.info(f'found_local_agent_loop:{sid}')
return session.agent_session.event_stream
if await self.is_agent_loop_running_in_cluster(sid):
if await self._get_running_agent_loops_remotely(filter_to_sids={sid}):
logger.info(f'found_remote_agent_loop:{sid}')
return EventStream(sid, self.file_store)
@@ -377,7 +480,7 @@ class SessionManager:
async def send_to_event_stream(self, connection_id: str, data: dict):
# If there is a local session running, send to that
sid = self.local_connection_id_to_session_id.get(connection_id)
sid = self._local_connection_id_to_session_id.get(connection_id)
if not sid:
raise RuntimeError(f'no_connected_session:{connection_id}')
@@ -393,11 +496,11 @@ class SessionManager:
next_alive_check = last_alive_at + _CHECK_ALIVE_INTERVAL
if (
next_alive_check > time.time()
or await self.is_agent_loop_running_in_cluster(sid)
or await self._get_running_agent_loops_remotely(filter_to_sids={sid})
):
# Send the event to the other pod
await redis_client.publish(
'oh_event',
'session_msg',
json.dumps(
{
'sid': sid,
@@ -411,75 +514,37 @@ class SessionManager:
raise RuntimeError(f'no_connected_session:{connection_id}:{sid}')
async def disconnect_from_session(self, connection_id: str):
sid = self.local_connection_id_to_session_id.pop(connection_id, None)
sid = self._local_connection_id_to_session_id.pop(connection_id, None)
logger.info(f'disconnect_from_session:{connection_id}:{sid}')
if not sid:
# This can occur if the init action was never run.
logger.warning(f'disconnect_from_uninitialized_session:{connection_id}')
return
if should_continue():
asyncio.create_task(self._cleanup_session_later(sid))
else:
await self._on_close_session(sid)
async def _cleanup_session_later(self, sid: str):
# Once there have been no connections to a session for a reasonable period, we close it
try:
await asyncio.sleep(self.config.sandbox.close_delay)
finally:
# If the sleep was cancelled, we still want to close these
await self._cleanup_session(sid)
async def _cleanup_session(self, sid: str) -> bool:
# Get local connections
logger.info(f'_cleanup_session:{sid}')
has_local_connections = next(
(True for v in self.local_connection_id_to_session_id.values() if v == sid),
False,
)
if has_local_connections:
return False
# If no local connections, get connections through redis
redis_client = self._get_redis_client()
if redis_client and await self._has_remote_connections(sid):
return False
# We alert the cluster in case they are interested
if redis_client:
await redis_client.publish(
'oh_event',
json.dumps({'sid': sid, 'message_type': 'session_closing'}),
)
await self._on_close_session(sid)
return True
async def close_session(self, sid: str):
session = self._local_agent_loops_by_sid.get(sid)
if session:
await self._on_close_session(sid)
await self._close_session(sid)
redis_client = self._get_redis_client()
if redis_client:
await redis_client.publish(
'oh_event',
'session_msg',
json.dumps({'sid': sid, 'message_type': 'close_session'}),
)
async def _on_close_session(self, sid: str):
async def _close_session(self, sid: str):
logger.info(f'_close_session:{sid}')
# Clear up local variables
connection_ids_to_remove = list(
connection_id
for connection_id, conn_sid in self.local_connection_id_to_session_id.items()
for connection_id, conn_sid in self._local_connection_id_to_session_id.items()
if sid == conn_sid
)
logger.info(f'removing connections: {connection_ids_to_remove}')
for connnnection_id in connection_ids_to_remove:
self.local_connection_id_to_session_id.pop(connnnection_id, None)
self._local_connection_id_to_session_id.pop(connnnection_id, None)
session = self._local_agent_loops_by_sid.pop(sid, None)
if not session:
@@ -488,12 +553,17 @@ class SessionManager:
logger.info(f'closing_session:{session.sid}')
# We alert the cluster in case they are interested
redis_client = self._get_redis_client()
if redis_client:
await redis_client.publish(
'oh_event',
json.dumps({'sid': session.sid, 'message_type': 'session_closing'}),
try:
redis_client = self._get_redis_client()
if redis_client:
await redis_client.publish(
'session_msg',
json.dumps({'sid': session.sid, 'message_type': 'session_closing'}),
)
except Exception:
logger.info(
'error_publishing_close_session_event', exc_info=True, stack_info=True
)
await call_sync_from_async(session.close)
await session.close()
logger.info(f'closed_session:{session.sid}')

View File

@@ -37,7 +37,7 @@ class Session:
loop: asyncio.AbstractEventLoop
config: AppConfig
file_store: FileStore
user_id: int | None
user_id: str | None
def __init__(
self,
@@ -45,7 +45,7 @@ class Session:
config: AppConfig,
file_store: FileStore,
sio: socketio.AsyncServer | None,
user_id: int | None = None,
user_id: str | None = None,
):
self.sid = sid
self.sio = sio
@@ -62,9 +62,17 @@ class Session:
self.loop = asyncio.get_event_loop()
self.user_id = user_id
def close(self):
async def close(self):
if self.sio:
await self.sio.emit(
'oh_event',
event_to_dict(
AgentStateChangedObservation('', AgentState.STOPPED.value)
),
to=ROOM_KEY.format(sid=self.sid),
)
self.is_alive = False
self.agent_session.close()
await self.agent_session.close()
async def initialize_agent(
self,

View File

@@ -41,6 +41,6 @@ class ConversationStore(ABC):
@classmethod
@abstractmethod
async def get_instance(
cls, config: AppConfig, user_id: int | None
cls, config: AppConfig, user_id: str | None
) -> ConversationStore:
"""Get a store for the user represented by the token given"""

View File

@@ -92,7 +92,7 @@ class FileConversationStore(ConversationStore):
@classmethod
async def get_instance(
cls, config: AppConfig, user_id: int | None
cls, config: AppConfig, user_id: str | None
) -> FileConversationStore:
file_store = get_file_store(config.file_store, config.file_store_path)
return FileConversationStore(file_store)

View File

@@ -5,7 +5,7 @@ from datetime import datetime, timezone
@dataclass
class ConversationMetadata:
conversation_id: str
github_user_id: int
github_user_id: str | None
selected_repository: str | None
title: str | None = None
last_updated_at: datetime | None = None

View File

@@ -31,7 +31,7 @@ class FileSettingsStore(SettingsStore):
@classmethod
async def get_instance(
cls, config: AppConfig, user_id: int | None
cls, config: AppConfig, user_id: str | None
) -> FileSettingsStore:
file_store = get_file_store(config.file_store, config.file_store_path)
return FileSettingsStore(file_store)

View File

@@ -22,6 +22,6 @@ class SettingsStore(ABC):
@classmethod
@abstractmethod
async def get_instance(
cls, config: AppConfig, user_id: int | None
cls, config: AppConfig, user_id: str | None
) -> SettingsStore:
"""Get a store for the user represented by the token given"""

View File

@@ -0,0 +1,24 @@
from dataclasses import dataclass, field
import requests
@dataclass
class HttpSession:
    """Wrapper that makes a ``requests.Session`` unusable once it is closed.

    requests.Session is reusable after it has been closed. This behavior makes
    it likely to leak file descriptors (especially when combined with
    tenacity). We wrap the session to make it unusable after being closed.
    """

    # The wrapped session; replaced with None by close().
    session: requests.Session | None = field(default_factory=requests.Session)

    def __getattr__(self, name):
        """Forward unknown attribute lookups to the wrapped session.

        Raises:
            AttributeError: If ``session`` itself is not set on the instance.
            ValueError: If the session has already been closed.
        """
        # __getattr__ only runs after normal lookup fails. If ``session`` is
        # what is missing (e.g. a bare instance during copy/unpickle), reading
        # ``self.session`` below would re-enter this method without end.
        if name == 'session':
            raise AttributeError(name)
        if self.session is None:
            raise ValueError('session_was_closed')
        # getattr honours any dynamic attributes of the wrapped object.
        return getattr(self.session, name)

    def close(self):
        """Close the wrapped session and drop it; repeated calls do nothing."""
        if self.session is not None:
            self.session.close()
            self.session = None

View File

@@ -20,6 +20,14 @@ class RuntimeInfo:
available_hosts: dict[str, int]
@dataclass
class RepositoryInfo:
"""Information about a GitHub repository that has been cloned."""
repo_name: str | None = None
repo_directory: str | None = None
class PromptManager:
"""
Manages prompt templates and micro-agents for AI interactions.
@@ -42,7 +50,7 @@ class PromptManager:
):
self.disabled_microagents: list[str] = disabled_microagents or []
self.prompt_dir: str = prompt_dir
self.repository_info: RepositoryInfo | None = None
self.system_template: Template = self._load_template('system_prompt')
self.user_template: Template = self._load_template('user_prompt')
self.runtime_info = RuntimeInfo(available_hosts={})
@@ -80,9 +88,6 @@ class PromptManager:
elif isinstance(microagent, RepoMicroAgent):
self.repo_microagents[microagent.name] = microagent
def set_runtime_info(self, runtime: Runtime):
self.runtime_info.available_hosts = runtime.web_hosts
def _load_template(self, template_name: str) -> Template:
if self.prompt_dir is None:
raise ValueError('Prompt directory is not set')
@@ -102,10 +107,31 @@ class PromptManager:
if repo_instructions:
repo_instructions += '\n\n'
repo_instructions += microagent.content
return self.system_template.render(
runtime_info=self.runtime_info, repo_instructions=repo_instructions
repository_instructions=repo_instructions,
repository_info=self.repository_info,
runtime_info=self.runtime_info,
).strip()
def set_runtime_info(self, runtime: Runtime):
    """Copy the runtime's ``web_hosts`` into ``runtime_info.available_hosts``.

    Args:
        runtime: The runtime whose ``web_hosts`` attribute is read.
    """
    hosts = runtime.web_hosts
    self.runtime_info.available_hosts = hosts
def set_repository_info(
    self,
    repo_name: str,
    repo_directory: str,
) -> None:
    """Record which GitHub repository has been cloned and where.

    The values are stored on ``self.repository_info`` as a new
    ``RepositoryInfo``, replacing any previous value.

    Args:
        repo_name: The name of the GitHub repository (e.g. 'owner/repo')
        repo_directory: The directory where the repository has been cloned
    """
    info = RepositoryInfo(repo_directory=repo_directory, repo_name=repo_name)
    self.repository_info = info
def get_example_user_message(self) -> str:
"""This is the initial user message provided to the agent
before *actual* user instructions are provided.

36
poetry.lock generated
View File

@@ -552,17 +552,17 @@ files = [
[[package]]
name = "boto3"
version = "1.35.97"
version = "1.35.98"
description = "The AWS SDK for Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "boto3-1.35.97-py3-none-any.whl", hash = "sha256:8e49416216a6e3a62c2a0c44fba4dd2852c85472e7b702516605b1363867d220"},
{file = "boto3-1.35.97.tar.gz", hash = "sha256:7d398f66a11e67777c189d1f58c0a75d9d60f98d0ee51b8817e828930bf19e4e"},
{file = "boto3-1.35.98-py3-none-any.whl", hash = "sha256:d0224e1499d7189b47aa7f469d96522d98df6f5702fccb20a95a436582ebcd9d"},
{file = "boto3-1.35.98.tar.gz", hash = "sha256:4b6274b4fe9d7113f978abea66a1f20c8a397c268c9d1b2a6c96b14a256da4a5"},
]
[package.dependencies]
botocore = ">=1.35.97,<1.36.0"
botocore = ">=1.35.98,<1.36.0"
jmespath = ">=0.7.1,<2.0.0"
s3transfer = ">=0.10.0,<0.11.0"
@@ -571,13 +571,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
[[package]]
name = "botocore"
version = "1.35.97"
version = "1.35.98"
description = "Low-level, data-driven core of boto 3."
optional = false
python-versions = ">=3.8"
files = [
{file = "botocore-1.35.97-py3-none-any.whl", hash = "sha256:fed4f156b1a9b8ece53738f702ba5851b8c6216b4952de326547f349cc494f14"},
{file = "botocore-1.35.97.tar.gz", hash = "sha256:88f2fab29192ffe2f2115d5bafbbd823ff4b6eb2774296e03ec8b5b0fe074f61"},
{file = "botocore-1.35.98-py3-none-any.whl", hash = "sha256:4f1c0b687488663a774ad3a5e81a5f94fae1bcada2364cfdc48482c4dbf794d5"},
{file = "botocore-1.35.98.tar.gz", hash = "sha256:d11742b3824bdeac3c89eeeaf5132351af41823bbcef8fc15e95c8250b1de09c"},
]
[package.dependencies]
@@ -927,13 +927,13 @@ numpy = "*"
[[package]]
name = "chromadb"
version = "0.6.2"
version = "0.6.3"
description = "Chroma."
optional = false
python-versions = ">=3.9"
files = [
{file = "chromadb-0.6.2-py3-none-any.whl", hash = "sha256:77a5e07097e36cdd49d8d2925d0c4d28291cabc9677787423d2cc7c426e8895b"},
{file = "chromadb-0.6.2.tar.gz", hash = "sha256:e9e11f04d3850796711ee05dad4e918c75ec7b62ab9cbe7b4588b68a26aaea06"},
{file = "chromadb-0.6.3-py3-none-any.whl", hash = "sha256:4851258489a3612b558488d98d09ae0fe0a28d5cad6bd1ba64b96fdc419dc0e5"},
{file = "chromadb-0.6.3.tar.gz", hash = "sha256:c8f34c0b704b9108b04491480a36d42e894a960429f87c6516027b5481d59ed3"},
]
[package.dependencies]
@@ -2161,13 +2161,13 @@ grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"]
[[package]]
name = "google-api-python-client"
version = "2.158.0"
version = "2.159.0"
description = "Google API Client Library for Python"
optional = false
python-versions = ">=3.7"
files = [
{file = "google_api_python_client-2.158.0-py2.py3-none-any.whl", hash = "sha256:36f8c8d2e79e50f76790ca5946d2f3f8333e210dc8539a6c88e0742416474ad2"},
{file = "google_api_python_client-2.158.0.tar.gz", hash = "sha256:b6664597a9955e04977a62752e33fe44cb35c580e190c1cb08a041893172bd67"},
{file = "google_api_python_client-2.159.0-py2.py3-none-any.whl", hash = "sha256:baef0bb631a60a0bd7c0bf12a5499e3a40cd4388484de7ee55c1950bf820a0cf"},
{file = "google_api_python_client-2.159.0.tar.gz", hash = "sha256:55197f430f25c907394b44fa078545ffef89d33fd4dca501b7db9f0d8e224bd6"},
]
[package.dependencies]
@@ -3707,13 +3707,13 @@ types-tqdm = "*"
[[package]]
name = "litellm"
version = "1.58.0"
version = "1.58.1"
description = "Library to easily interface with LLM API providers"
optional = false
python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8"
files = [
{file = "litellm-1.58.0-py3-none-any.whl", hash = "sha256:1fc07646f6419f1d7b7d06fe2f5c72b3e6e3407423b50cbf45b58dc9740e7a03"},
{file = "litellm-1.58.0.tar.gz", hash = "sha256:db4512e987809e04e59d5b4240bdef3a4bfe575c332397f80cc6b4403c68120a"},
{file = "litellm-1.58.1-py3-none-any.whl", hash = "sha256:eae311273dd7b8be9b1fc92f12c6ec521f86166effd7ae2cec2982dbb9a7dc2c"},
{file = "litellm-1.58.1.tar.gz", hash = "sha256:c73dff605b830815088bdfcc6c42f380b5dc129a184b678646ce981305d11ac6"},
]
[package.dependencies]
@@ -4583,12 +4583,12 @@ type = ["mypy (==1.11.2)"]
[[package]]
name = "modal"
version = "0.72.10"
version = "0.72.11"
description = "Python client library for Modal"
optional = false
python-versions = ">=3.9"
files = [
{file = "modal-0.72.10-py3-none-any.whl", hash = "sha256:eb67516660e00a9dca07378fb251275a913720162e98410f46c79924cb7cd509"},
{file = "modal-0.72.11-py3-none-any.whl", hash = "sha256:97428dfbc2cb2677eb0a17e6189fedf991296245d25e9700c2b9f8978853a2ae"},
]
[package.dependencies]

View File

@@ -26,7 +26,7 @@ test_mount_path = ''
project_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
sandbox_test_folder = '/openhands/workspace'
sandbox_test_folder = '/workspace'
def _get_runtime_sid(runtime: Runtime) -> str:
@@ -233,9 +233,10 @@ def _load_runtime(
if use_workspace:
test_mount_path = os.path.join(config.workspace_base, 'rt')
elif temp_dir is not None:
test_mount_path = os.path.join(temp_dir, sid)
test_mount_path = temp_dir
else:
test_mount_path = None
config.workspace_base = test_mount_path
config.workspace_mount_path = test_mount_path
# Mounting folder specific for this test inside the sandbox

View File

@@ -210,7 +210,7 @@ done && echo "success"
def test_cmd_run(temp_dir, runtime_cls, run_as_openhands):
runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
try:
obs = _run_cmd_action(runtime, 'ls -l /openhands/workspace')
obs = _run_cmd_action(runtime, 'ls -l /workspace')
assert obs.exit_code == 0
obs = _run_cmd_action(runtime, 'ls -l')
@@ -377,7 +377,7 @@ def test_copy_to_non_existent_directory(temp_dir, runtime_cls):
def test_overwrite_existing_file(temp_dir, runtime_cls):
runtime = _load_runtime(temp_dir, runtime_cls)
try:
sandbox_dir = '/openhands/workspace'
sandbox_dir = '/workspace'
obs = _run_cmd_action(runtime, f'ls -alh {sandbox_dir}')
assert obs.exit_code == 0

View File

@@ -52,7 +52,7 @@ def test_simple_cmd_ipython_and_fileop(temp_dir, runtime_cls, run_as_openhands):
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert obs.content.strip() == (
'Hello, `World`!\n'
'[Jupyter current working directory: /openhands/workspace]\n'
'[Jupyter current working directory: /workspace]\n'
'[Jupyter Python interpreter: /openhands/poetry/openhands-ai-5O4_aCHf-py3.12/bin/python]'
)
@@ -73,7 +73,7 @@ def test_simple_cmd_ipython_and_fileop(temp_dir, runtime_cls, run_as_openhands):
assert obs.content == ''
# event stream runtime will always use absolute path
assert obs.path == '/openhands/workspace/hello.sh'
assert obs.path == '/workspace/hello.sh'
# Test read file (file should exist)
action_read = FileReadAction(path='hello.sh')
@@ -85,7 +85,7 @@ def test_simple_cmd_ipython_and_fileop(temp_dir, runtime_cls, run_as_openhands):
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert obs.content == 'echo "Hello, World!"\n'
assert obs.path == '/openhands/workspace/hello.sh'
assert obs.path == '/workspace/hello.sh'
# clean up
action = CmdRunAction(command='rm -rf hello.sh')
@@ -188,7 +188,7 @@ def test_ipython_simple(temp_dir, runtime_cls):
obs.content.strip()
== (
'1\n'
'[Jupyter current working directory: /openhands/workspace]\n'
'[Jupyter current working directory: /workspace]\n'
'[Jupyter Python interpreter: /openhands/poetry/openhands-ai-5O4_aCHf-py3.12/bin/python]'
).strip()
)
@@ -224,7 +224,7 @@ def test_ipython_package_install(temp_dir, runtime_cls, run_as_openhands):
# import should not error out
assert obs.content.strip() == (
'[Code executed successfully with no output]\n'
'[Jupyter current working directory: /openhands/workspace]\n'
'[Jupyter current working directory: /workspace]\n'
'[Jupyter Python interpreter: /openhands/poetry/openhands-ai-5O4_aCHf-py3.12/bin/python]'
)
@@ -273,16 +273,16 @@ def test_ipython_file_editor_permissions_as_openhands(temp_dir, runtime_cls):
# Try to use file editor in openhands sandbox directory - should work
test_code = """
# Create file
print(file_editor(command='create', path='/openhands/workspace/test.txt', file_text='Line 1\\nLine 2\\nLine 3'))
print(file_editor(command='create', path='/workspace/test.txt', file_text='Line 1\\nLine 2\\nLine 3'))
# View file
print(file_editor(command='view', path='/openhands/workspace/test.txt'))
print(file_editor(command='view', path='/workspace/test.txt'))
# Edit file
print(file_editor(command='str_replace', path='/openhands/workspace/test.txt', old_str='Line 2', new_str='New Line 2'))
print(file_editor(command='str_replace', path='/workspace/test.txt', old_str='Line 2', new_str='New Line 2'))
# Undo edit
print(file_editor(command='undo_edit', path='/openhands/workspace/test.txt'))
print(file_editor(command='undo_edit', path='/workspace/test.txt'))
"""
action = IPythonRunCellAction(code=test_code)
logger.info(action, extra={'msg_type': 'ACTION'})
@@ -297,7 +297,7 @@ print(file_editor(command='undo_edit', path='/openhands/workspace/test.txt'))
assert 'undone successfully' in obs.content
# Clean up
action = CmdRunAction(command='rm -f /openhands/workspace/test.txt')
action = CmdRunAction(command='rm -f /workspace/test.txt')
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
@@ -314,7 +314,7 @@ print(file_editor(command='undo_edit', path='/openhands/workspace/test.txt'))
def test_file_read_and_edit_via_oh_aci(runtime_cls, run_as_openhands):
runtime = _load_runtime(None, runtime_cls, run_as_openhands)
sandbox_dir = '/openhands/workspace'
sandbox_dir = '/workspace'
actions = [
{

View File

@@ -0,0 +1,197 @@
"""Tests for microagent loading in runtime."""
from pathlib import Path
from conftest import (
_close_test_runtime,
_load_runtime,
)
from openhands.microagent import KnowledgeMicroAgent, RepoMicroAgent, TaskMicroAgent
def _create_test_microagents(test_dir: str):
"""Create test microagent files in the given directory."""
microagents_dir = Path(test_dir) / '.openhands' / 'microagents'
microagents_dir.mkdir(parents=True, exist_ok=True)
# Create test knowledge agent
knowledge_dir = microagents_dir / 'knowledge'
knowledge_dir.mkdir(exist_ok=True)
knowledge_agent = """---
name: test_knowledge_agent
type: knowledge
version: 1.0.0
agent: CodeActAgent
triggers:
- test
- pytest
---
# Test Guidelines
Testing best practices and guidelines.
"""
(knowledge_dir / 'knowledge.md').write_text(knowledge_agent)
# Create test repo agent
repo_agent = """---
name: test_repo_agent
type: repo
version: 1.0.0
agent: CodeActAgent
---
# Test Repository Agent
Repository-specific test instructions.
"""
(microagents_dir / 'repo.md').write_text(repo_agent)
# Create test task agent in a nested directory
task_dir = microagents_dir / 'tasks' / 'nested'
task_dir.mkdir(parents=True, exist_ok=True)
task_agent = """---
name: test_task
type: task
version: 1.0.0
agent: CodeActAgent
---
# Test Task
Test task content
"""
(task_dir / 'task.md').write_text(task_agent)
# Create legacy repo instructions
legacy_instructions = """# Legacy Instructions
These are legacy repository instructions.
"""
(Path(test_dir) / '.openhands_instructions').write_text(legacy_instructions)
def test_load_microagents_with_trailing_slashes(
    temp_dir, runtime_cls, run_as_openhands
):
    """Load all microagent kinds from the workspace when no repository is selected.

    NOTE(review): the test name refers to trailing slashes, but no path with a
    trailing slash is built here; presumably the runtime produces one
    internally -- confirm against the runtime implementation.
    """
    _create_test_microagents(temp_dir)
    runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
    try:
        agents = runtime.get_microagents_from_selected_repo(None)

        # Bucket the loaded agents by concrete microagent type.
        knowledge = [a for a in agents if isinstance(a, KnowledgeMicroAgent)]
        repos = [a for a in agents if isinstance(a, RepoMicroAgent)]
        tasks = [a for a in agents if isinstance(a, TaskMicroAgent)]

        # Exactly one knowledge agent, carrying both configured triggers.
        assert len(knowledge) == 1
        knowledge_agent = knowledge[0]
        assert knowledge_agent.name == 'test_knowledge_agent'
        assert 'test' in knowledge_agent.triggers
        assert 'pytest' in knowledge_agent.triggers

        # Two repo agents: repo.md plus the legacy .openhands_instructions file.
        assert len(repos) == 2
        names = {repo.name for repo in repos}
        assert 'test_repo_agent' in names
        assert 'repo_legacy' in names

        # The nested task file yields a single task agent.
        assert len(tasks) == 1
        task_agent = tasks[0]
        assert task_agent.name == 'test_task'
    finally:
        _close_test_runtime(runtime)
def test_load_microagents_with_selected_repo(temp_dir, runtime_cls, run_as_openhands):
    """Load microagents for ``All-Hands-AI/OpenHands`` from ``<temp_dir>/OpenHands``."""
    # Fixture files live in a sub-directory named after the repository.
    checkout = Path(temp_dir) / 'OpenHands'
    checkout.mkdir(parents=True)
    _create_test_microagents(str(checkout))

    runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
    try:
        agents = runtime.get_microagents_from_selected_repo('All-Hands-AI/OpenHands')

        # Bucket the loaded agents by concrete microagent type.
        knowledge = [a for a in agents if isinstance(a, KnowledgeMicroAgent)]
        repos = [a for a in agents if isinstance(a, RepoMicroAgent)]
        tasks = [a for a in agents if isinstance(a, TaskMicroAgent)]

        # Exactly one knowledge agent, carrying both configured triggers.
        assert len(knowledge) == 1
        knowledge_agent = knowledge[0]
        assert knowledge_agent.name == 'test_knowledge_agent'
        assert 'test' in knowledge_agent.triggers
        assert 'pytest' in knowledge_agent.triggers

        # Two repo agents: repo.md plus the legacy .openhands_instructions file.
        assert len(repos) == 2
        names = {repo.name for repo in repos}
        assert 'test_repo_agent' in names
        assert 'repo_legacy' in names

        # The nested task file yields a single task agent.
        assert len(tasks) == 1
        task_agent = tasks[0]
        assert task_agent.name == 'test_task'
    finally:
        _close_test_runtime(runtime)
def test_load_microagents_with_missing_files(temp_dir, runtime_cls, run_as_openhands):
    """With only ``repo.md`` on disk, exactly one repo microagent is loaded."""
    # No knowledge/, tasks/ or legacy instructions file is created here.
    agents_root = Path(temp_dir) / '.openhands' / 'microagents'
    agents_root.mkdir(parents=True, exist_ok=True)
    repo_agent = """---
name: test_repo_agent
type: repo
version: 1.0.0
agent: CodeActAgent
---
# Test Repository Agent
Repository-specific test instructions.
"""
    (agents_root / 'repo.md').write_text(repo_agent)

    runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
    try:
        agents = runtime.get_microagents_from_selected_repo(None)

        knowledge = [a for a in agents if isinstance(a, KnowledgeMicroAgent)]
        repos = [a for a in agents if isinstance(a, RepoMicroAgent)]
        tasks = [a for a in agents if isinstance(a, TaskMicroAgent)]

        # Only the repo bucket may be populated.
        assert len(knowledge) == 0
        assert len(repos) == 1
        assert len(tasks) == 0

        only_repo = repos[0]
        assert only_repo.name == 'test_repo_agent'
    finally:
        _close_test_runtime(runtime)

View File

@@ -1,12 +1,10 @@
import pytest
from openhands.resolver.patching.apply import apply_diff
from openhands.resolver.patching.exceptions import HunkApplyException
from openhands.resolver.patching.patch import parse_diff, diffobj
from openhands.resolver.patching.patch import diffobj, parse_diff
def test_patch_apply_with_empty_lines():
# The original file has no indentation and uses \n line endings
original_content = "# PR Viewer\n\nThis React application allows you to view open pull requests from GitHub repositories in a GitHub organization. By default, it uses the All-Hands-AI organization.\n\n## Setup"
original_content = '# PR Viewer\n\nThis React application allows you to view open pull requests from GitHub repositories in a GitHub organization. By default, it uses the All-Hands-AI organization.\n\n## Setup'
# The patch has spaces at the start of each line and uses \n line endings
patch = """diff --git a/README.md b/README.md
@@ -19,18 +17,20 @@ index b760a53..5071727 100644
-This React application allows you to view open pull requests from GitHub repositories in a GitHub organization. By default, it uses the All-Hands-AI organization.
+This React application was created by Graham Neubig and OpenHands. It allows you to view open pull requests from GitHub repositories in a GitHub organization. By default, it uses the All-Hands-AI organization."""
print("Original content lines:")
print('Original content lines:')
for i, line in enumerate(original_content.splitlines(), 1):
print(f"{i}: {repr(line)}")
print(f'{i}: {repr(line)}')
print("\nPatch lines:")
print('\nPatch lines:')
for i, line in enumerate(patch.splitlines(), 1):
print(f"{i}: {repr(line)}")
print(f'{i}: {repr(line)}')
changes = parse_diff(patch)
print("\nParsed changes:")
print('\nParsed changes:')
for change in changes:
print(f"Change(old={change.old}, new={change.new}, line={repr(change.line)}, hunk={change.hunk})")
print(
f'Change(old={change.old}, new={change.new}, line={repr(change.line)}, hunk={change.hunk})'
)
diff = diffobj(header=None, changes=changes, text=patch)
# Apply the patch
@@ -38,10 +38,10 @@ index b760a53..5071727 100644
# The patch should be applied successfully
expected_result = [
"# PR Viewer",
"",
"This React application was created by Graham Neubig and OpenHands. It allows you to view open pull requests from GitHub repositories in a GitHub organization. By default, it uses the All-Hands-AI organization.",
"",
"## Setup"
'# PR Viewer',
'',
'This React application was created by Graham Neubig and OpenHands. It allows you to view open pull requests from GitHub repositories in a GitHub organization. By default, it uses the All-Hands-AI organization.',
'',
'## Setup',
]
assert result == expected_result
assert result == expected_result

View File

@@ -1,5 +1,5 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock, Mock
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4
import pytest
@@ -130,7 +130,7 @@ async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_cal
@pytest.mark.asyncio
async def test_run_controller_with_fatal_error(mock_agent, mock_event_stream):
async def test_run_controller_with_fatal_error():
config = AppConfig()
file_store = get_file_store(config.file_store, config.file_store_path)
event_stream = EventStream(sid='test', file_store=file_store)
@@ -239,55 +239,6 @@ async def test_run_controller_stop_with_stuck():
assert state.last_error == 'AgentStuckInLoopError: Agent got stuck in a loop'
@pytest.mark.asyncio
@pytest.mark.parametrize(
'delegate_state',
[
AgentState.RUNNING,
AgentState.FINISHED,
AgentState.ERROR,
AgentState.REJECTED,
],
)
async def test_delegate_step_different_states(
mock_agent, mock_event_stream, delegate_state
):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)
mock_delegate = AsyncMock()
controller.delegate = mock_delegate
mock_delegate.state.iteration = 5
mock_delegate.state.outputs = {'result': 'test'}
mock_delegate.agent.name = 'TestDelegate'
mock_delegate.get_agent_state = Mock(return_value=delegate_state)
mock_delegate._step = AsyncMock()
mock_delegate.close = AsyncMock()
await controller._delegate_step()
mock_delegate._step.assert_called_once()
if delegate_state == AgentState.RUNNING:
assert controller.delegate is not None
assert controller.state.iteration == 0
mock_delegate.close.assert_not_called()
else:
assert controller.delegate is None
assert controller.state.iteration == 5
mock_delegate.close.assert_called_once()
await controller.close()
@pytest.mark.asyncio
async def test_max_iterations_extension(mock_agent, mock_event_stream):
# Test with headless_mode=False - should extend max_iterations

View File

@@ -0,0 +1,187 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import AsyncMock, MagicMock, Mock
from uuid import uuid4
import pytest
from openhands.controller.agent import Agent
from openhands.controller.agent_controller import AgentController
from openhands.controller.state.state import State
from openhands.core.config import LLMConfig
from openhands.core.config.agent_config import AgentConfig
from openhands.core.schema import AgentState
from openhands.events import EventSource, EventStream
from openhands.events.action import (
AgentDelegateAction,
AgentFinishAction,
MessageAction,
)
from openhands.llm.llm import LLM
from openhands.llm.metrics import Metrics
from openhands.storage.memory import InMemoryFileStore
@pytest.fixture
def mock_event_stream():
    """Return an ``EventStream`` with a unique ``test-<uuid4>`` sid over an empty in-memory store."""
    return EventStream(sid=f'test-{uuid4()}', file_store=InMemoryFileStore({}))
@pytest.fixture
def mock_parent_agent():
    """Return an ``Agent``-spec'd mock named ``ParentAgent`` with real metrics and configs."""
    parent = MagicMock(spec=Agent)
    parent.name = 'ParentAgent'

    # The LLM itself is mocked; its metrics and config are real default objects.
    llm = MagicMock(spec=LLM)
    llm.metrics = Metrics()
    llm.config = LLMConfig()
    parent.llm = llm

    parent.config = AgentConfig()
    return parent
@pytest.fixture
def mock_child_agent():
    """Return an ``Agent``-spec'd mock named ``ChildAgent`` with real metrics and configs."""
    child = MagicMock(spec=Agent)
    child.name = 'ChildAgent'

    # The LLM itself is mocked; its metrics and config are real default objects.
    llm = MagicMock(spec=LLM)
    llm.metrics = Metrics()
    llm.config = LLMConfig()
    child.llm = llm

    child.config = AgentConfig()
    return child
@pytest.mark.asyncio
async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_stream):
    """Check the delegate lifecycle on a parent ``AgentController``.

    When the parent agent returns an ``AgentDelegateAction``, the parent
    controller must get a delegate controller; once the child receives an
    ``AgentFinishAction`` the delegate must be cleared and the parent's
    iteration counter must reflect the child's iterations.
    """
    # Imported locally so this test stays self-contained.
    from unittest.mock import patch

    # Make agent class resolution hand back mock_child_agent. patch.object
    # restores the real Agent.get_cls on exit, so the replacement cannot leak
    # into other tests in the same session.
    with patch.object(
        Agent, 'get_cls', Mock(return_value=lambda llm, config: mock_child_agent)
    ):
        # Create parent controller
        parent_state = State(max_iterations=10)
        parent_controller = AgentController(
            agent=mock_parent_agent,
            event_stream=mock_event_stream,
            max_iterations=10,
            sid='parent',
            confirmation_mode=False,
            headless_mode=True,
            initial_state=parent_state,
        )
        try:
            # Setup a delegate action from the parent
            delegate_action = AgentDelegateAction(
                agent='ChildAgent', inputs={'test': True}
            )
            mock_parent_agent.step.return_value = delegate_action

            # Simulate a user message event to cause parent.step() to run
            message_action = MessageAction(content='please delegate now')
            message_action._source = EventSource.USER
            await parent_controller._on_event(message_action)

            # Give time for the async step() to execute
            await asyncio.sleep(1)

            # Verify that a delegate agent controller is created
            assert (
                parent_controller.delegate is not None
            ), "Parent's delegate controller was not set."

            # The parent's iteration should have incremented
            assert (
                parent_controller.state.iteration == 1
            ), 'Parent iteration should be incremented after step.'

            # Simulate the child doing some steps and producing outputs
            delegate_controller = parent_controller.delegate
            delegate_controller.state.iteration = 5  # child had some steps
            delegate_controller.state.outputs = {'delegate_result': 'done'}

            # The child is done, so we simulate it finishing
            child_finish_action = AgentFinishAction()
            await delegate_controller._on_event(child_finish_action)
            await asyncio.sleep(0.5)

            # Now the parent's delegate is None
            assert (
                parent_controller.delegate is None
            ), 'Parent delegate should be None after child finishes.'

            # Parent's global iteration is updated from the child
            assert (
                parent_controller.state.iteration == 6
            ), "Parent iteration should be the child's iteration + 1 after child is done."
        finally:
            # Close the controller even when an assertion above fails.
            await parent_controller.close()
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'delegate_state',
    [AgentState.RUNNING, AgentState.FINISHED, AgentState.ERROR, AgentState.REJECTED],
)
async def test_delegate_step_different_states(
    mock_parent_agent, mock_event_stream, delegate_state
):
    """A RUNNING delegate stays attached; any other listed state detaches and closes it."""
    controller = AgentController(
        agent=mock_parent_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    # Attach a fully mocked delegate that reports the parametrized state.
    delegate = AsyncMock()
    controller.delegate = delegate
    delegate.state.iteration = 5
    delegate.state.outputs = {'result': 'test'}
    delegate.agent.name = 'TestDelegate'
    delegate.get_agent_state = Mock(return_value=delegate_state)
    delegate._step = AsyncMock()
    delegate.close = AsyncMock()

    def dispatch_from_worker_thread():
        """Deliver a user message via ``controller.on_event`` from a thread with its own loop.

        A fresh event loop is installed for this thread before the call and
        closed afterwards. Presumably ``on_event`` needs a usable loop in the
        calling thread (e.g. for ``run_until_complete``) -- confirm.
        """
        worker_loop = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(worker_loop)
            user_message = MessageAction(content='Test message')
            user_message._source = EventSource.USER
            controller.on_event(user_message)
        finally:
            worker_loop.close()

    running_loop = asyncio.get_running_loop()
    with ThreadPoolExecutor() as pool:
        await running_loop.run_in_executor(pool, dispatch_from_worker_thread)

    if delegate_state != AgentState.RUNNING:
        # Terminal states: delegate detached, its iteration count adopted, closed once.
        assert controller.delegate is None
        assert controller.state.iteration == 5
        delegate.close.assert_called_once()
    else:
        # Still running: delegate kept, parent iteration untouched, nothing closed.
        assert controller.delegate is not None
        assert controller.state.iteration == 0
        delegate.close.assert_not_called()

    await controller.close()

View File

@@ -29,7 +29,7 @@ def _patch_store():
'title': 'Some Conversation',
'selected_repository': 'foobar',
'conversation_id': 'some_conversation_id',
'github_user_id': 12345,
'github_user_id': '12345',
'created_at': '2025-01-01T00:00:00',
'last_updated_at': '2025-01-01T00:01:00',
}

View File

@@ -0,0 +1,228 @@
import pathlib
import pytest
from openhands.core.config import AppConfig
from openhands.core.config.utils import load_from_toml
@pytest.fixture
def default_config(monkeypatch):
    """Yield a newly constructed ``AppConfig``.

    ``monkeypatch`` is requested by the signature but not used in the body.
    """
    config = AppConfig()
    yield config
@pytest.fixture
def generic_llm_toml(tmp_path: pathlib.Path) -> str:
    """Write ``llm_config.toml`` under ``tmp_path`` and return its path as a string.

    The file has a generic ``[llm]`` section (model, api_key, embedding_model,
    num_retries) and three custom sections that each set ``model`` and
    ``api_key``; only ``[llm.custom2]`` also sets ``num_retries``.
    """
    toml_content = """
[core]
workspace_base = "./workspace"
[llm]
model = "base-model"
api_key = "base-api-key"
embedding_model = "base-embedding"
num_retries = 3
[llm.custom1]
model = "custom-model-1"
api_key = "custom-api-key-1"
# 'num_retries' is not overridden and should fallback to the value from [llm]
[llm.custom2]
model = "custom-model-2"
api_key = "custom-api-key-2"
num_retries = 5 # Overridden value
[llm.custom3]
model = "custom-model-3"
api_key = "custom-api-key-3"
# No overrides for additional attributes
"""
    config_path = tmp_path / 'llm_config.toml'
    config_path.write_text(toml_content)
    return str(config_path)
def test_load_from_toml_llm_with_fallback(
    default_config: AppConfig, generic_llm_toml: str
) -> None:
    """Custom LLM sections take attributes they do not set from the generic ``[llm]`` section."""
    load_from_toml(default_config, generic_llm_toml)

    # section -> (model, api_key, embedding_model, num_retries)
    expected = {
        # Generic section, exactly as written in the file.
        'llm': ('base-model', 'base-api-key', 'base-embedding', 3),
        # num_retries and embedding_model come from [llm].
        'custom1': ('custom-model-1', 'custom-api-key-1', 'base-embedding', 3),
        # num_retries is overridden; embedding_model comes from [llm].
        'custom2': ('custom-model-2', 'custom-api-key-2', 'base-embedding', 5),
        # Everything except model and api_key comes from [llm].
        'custom3': ('custom-model-3', 'custom-api-key-3', 'base-embedding', 3),
    }
    for section, (model, api_key, embedding_model, num_retries) in expected.items():
        llm_config = default_config.get_llm_config(section)
        assert llm_config.model == model
        assert llm_config.api_key == api_key
        assert llm_config.embedding_model == embedding_model
        assert llm_config.num_retries == num_retries
def test_load_from_toml_llm_custom_overrides_all(
    default_config: AppConfig, tmp_path: pathlib.Path
) -> None:
    """A custom section that sets every attribute overrides all generic ``[llm]`` values."""
    toml_content = """
[core]
workspace_base = "./workspace"
[llm]
model = "base-model"
api_key = "base-api-key"
embedding_model = "base-embedding"
num_retries = 3
[llm.custom_full]
model = "full-custom-model"
api_key = "full-custom-api-key"
embedding_model = "full-custom-embedding"
num_retries = 10
"""
    config_path = tmp_path / 'full_override_llm.toml'
    config_path.write_text(toml_content)
    load_from_toml(default_config, str(config_path))

    # section -> (model, api_key, embedding_model, num_retries)
    expected = {
        # The generic section keeps its own values.
        'llm': ('base-model', 'base-api-key', 'base-embedding', 3),
        # The custom section replaces all four.
        'custom_full': (
            'full-custom-model',
            'full-custom-api-key',
            'full-custom-embedding',
            10,
        ),
    }
    for section, (model, api_key, embedding_model, num_retries) in expected.items():
        llm_config = default_config.get_llm_config(section)
        assert llm_config.model == model
        assert llm_config.api_key == api_key
        assert llm_config.embedding_model == embedding_model
        assert llm_config.num_retries == num_retries
def test_load_from_toml_llm_custom_partial_override(
    default_config: AppConfig, generic_llm_toml: str
) -> None:
    """Custom sections override what they set and inherit the rest from ``[llm]``."""
    load_from_toml(default_config, generic_llm_toml)

    # section -> (model, api_key, embedding_model, num_retries)
    expected = {
        # Overrides model and api_key; num_retries comes from [llm].
        'custom1': ('custom-model-1', 'custom-api-key-1', 'base-embedding', 3),
        # Overrides model, api_key and num_retries.
        'custom2': ('custom-model-2', 'custom-api-key-2', 'base-embedding', 5),
    }
    for section, (model, api_key, embedding_model, num_retries) in expected.items():
        llm_config = default_config.get_llm_config(section)
        assert llm_config.model == model
        assert llm_config.api_key == api_key
        assert llm_config.embedding_model == embedding_model
        assert llm_config.num_retries == num_retries
def test_load_from_toml_llm_custom_no_override(
    default_config: AppConfig, generic_llm_toml: str
) -> None:
    """A custom section setting only model and api_key inherits the rest from ``[llm]``."""
    load_from_toml(default_config, generic_llm_toml)

    inherited = default_config.get_llm_config('custom3')
    # Set by [llm.custom3] itself.
    assert inherited.model == 'custom-model-3'
    assert inherited.api_key == 'custom-api-key-3'
    # Not set by [llm.custom3]; both values come from [llm].
    assert inherited.embedding_model == 'base-embedding'
    assert inherited.num_retries == 3
def test_load_from_toml_llm_missing_generic(
    default_config: AppConfig, tmp_path: pathlib.Path
) -> None:
    """Without a generic ``[llm]`` section, unset attributes take the built-in defaults."""
    toml_content = """
[core]
workspace_base = "./workspace"
[llm.custom_only]
model = "custom-only-model"
api_key = "custom-only-api-key"
"""
    config_path = tmp_path / 'custom_only_llm.toml'
    config_path.write_text(toml_content)
    load_from_toml(default_config, str(config_path))

    standalone = default_config.get_llm_config('custom_only')
    # Values written in [llm.custom_only].
    assert standalone.model == 'custom-only-model'
    assert standalone.api_key == 'custom-only-api-key'
    # Nothing to inherit from, so these are the default values.
    assert standalone.embedding_model == 'local'
    assert standalone.num_retries == 8
def test_load_from_toml_llm_invalid_config(
    default_config: AppConfig, tmp_path: pathlib.Path
) -> None:
    """A custom section with only an unknown attribute resolves to the generic ``[llm]`` values.

    Only the resulting configuration values are asserted; no warning output is
    checked here.
    """
    toml_content = """
[core]
workspace_base = "./workspace"
[llm]
model = "base-model"
api_key = "base-api-key"
num_retries = 3
[llm.invalid_custom]
unknown_attr = "should_not_exist"
"""
    config_path = tmp_path / 'invalid_custom_llm.toml'
    config_path.write_text(toml_content)
    load_from_toml(default_config, str(config_path))

    # The generic section is loaded as written.
    base_llm = default_config.get_llm_config('llm')
    assert base_llm.model == 'base-model'
    assert base_llm.api_key == 'base-api-key'
    assert base_llm.num_retries == 3

    # The invalid custom section ends up with the same values as [llm].
    invalid_llm = default_config.get_llm_config('invalid_custom')
    assert invalid_llm.model == 'base-model'
    assert invalid_llm.api_key == 'base-api-key'
    assert invalid_llm.num_retries == 3  # value set in [llm]
    assert invalid_llm.embedding_model == 'local'  # not set in the file: default value

View File

@@ -0,0 +1,92 @@
import pathlib
import pytest
from openhands.core.config import AppConfig
from openhands.core.config.utils import load_from_toml
@pytest.fixture
def draft_llm_toml(tmp_path: pathlib.Path) -> str:
    """Write ``llm_config.toml`` with ``draft_editor`` variations and return its path.

    ``[llm]`` defines a ``draft_editor`` inline table, ``[llm.custom1]`` omits
    it, ``[llm.custom2]`` defines its own, and ``[llm.custom3]`` sets it to the
    string ``"null"``.
    """
    toml_content = """
[core]
workspace_base = "./workspace"
[llm]
model = "base-model"
api_key = "base-api-key"
draft_editor = { model = "draft-model", api_key = "draft-api-key" }
[llm.custom1]
model = "custom-model-1"
api_key = "custom-api-key-1"
# Should use draft_editor from [llm] as fallback
[llm.custom2]
model = "custom-model-2"
api_key = "custom-api-key-2"
draft_editor = { model = "custom-draft", api_key = "custom-draft-key" }
[llm.custom3]
model = "custom-model-3"
api_key = "custom-api-key-3"
draft_editor = "null" # Explicitly set to null in TOML
"""
    config_path = tmp_path / 'llm_config.toml'
    config_path.write_text(toml_content)
    return str(config_path)
def test_draft_editor_fallback(draft_llm_toml):
    """Exercise draft_editor resolution for the generic and custom LLM configs.

    Covers four situations from the fixture file:
    - the generic [llm] section defines a draft_editor
    - custom1 defines none and is expected to report the generic one
    - custom2 defines its own
    - custom3 sets it to "null" and is expected to end up with None
    """
    app_config = AppConfig()

    # A freshly constructed AppConfig has no draft editor on the generic LLM.
    assert app_config.get_llm_config('llm').draft_editor is None

    load_from_toml(app_config, draft_llm_toml)

    # Generic section: draft_editor comes straight from [llm].
    base = app_config.get_llm_config('llm')
    assert base.draft_editor is not None
    assert base.draft_editor.model == 'draft-model'
    assert base.draft_editor.api_key == 'draft-api-key'

    # custom1: no draft_editor of its own, so the generic one is expected.
    first = app_config.get_llm_config('custom1')
    assert first.model == 'custom-model-1'
    assert first.draft_editor is not None
    assert first.draft_editor.model == 'draft-model'
    assert first.draft_editor.api_key == 'draft-api-key'

    # custom2: its own draft_editor table takes precedence.
    second = app_config.get_llm_config('custom2')
    assert second.model == 'custom-model-2'
    assert second.draft_editor is not None
    assert second.draft_editor.model == 'custom-draft'
    assert second.draft_editor.api_key == 'custom-draft-key'

    # custom3: draft_editor = "null" in the TOML must resolve to None.
    third = app_config.get_llm_config('custom3')
    assert third.model == 'custom-model-3'
    assert third.draft_editor is None
def test_draft_editor_defaults(draft_llm_toml):
    """Check that a draft_editor reports num_retries 8 and embedding_model
    'local' when the TOML sets neither of them on it.
    """
    config = AppConfig()
    load_from_toml(config, draft_llm_toml)

    # Both sections define draft_editor with only model and api_key, so the
    # remaining fields are expected to hold the LLMConfig defaults.
    for section in ('llm', 'custom2'):
        llm_config = config.get_llm_config(section)
        assert llm_config.draft_editor.num_retries == 8
        assert llm_config.draft_editor.embedding_model == 'local'

View File

@@ -44,28 +44,28 @@ async def test_session_not_running_in_cluster():
async with SessionManager(
sio, AppConfig(), InMemoryFileStore()
) as session_manager:
result = await session_manager.is_agent_loop_running_in_cluster(
'non-existant-session'
result = await session_manager._get_running_agent_loops_remotely(
filter_to_sids={'non-existant-session'}
)
assert result is False
assert result == set()
assert sio.manager.redis.publish.await_count == 1
sio.manager.redis.publish.assert_called_once_with(
'oh_event',
'{"request_id": "'
'session_msg',
'{"query_id": "'
+ str(id)
+ '", "sids": ["non-existant-session"], "message_type": "is_session_running"}',
+ '", "message_type": "running_agent_loops_query", "filter_to_sids": ["non-existant-session"]}',
)
@pytest.mark.asyncio
async def test_session_is_running_in_cluster():
async def test_get_running_agent_loops_remotely():
id = uuid4()
sio = get_mock_sio(
GetMessageMock(
{
'request_id': str(id),
'query_id': str(id),
'sids': ['existing-session'],
'message_type': 'session_is_running',
'message_type': 'running_agent_loops_response',
}
)
)
@@ -76,16 +76,16 @@ async def test_session_is_running_in_cluster():
async with SessionManager(
sio, AppConfig(), InMemoryFileStore()
) as session_manager:
result = await session_manager.is_agent_loop_running_in_cluster(
'existing-session'
result = await session_manager._get_running_agent_loops_remotely(
1, {'existing-session'}
)
assert result is True
assert result == {'existing-session'}
assert sio.manager.redis.publish.await_count == 1
sio.manager.redis.publish.assert_called_once_with(
'oh_event',
'{"request_id": "'
'session_msg',
'{"query_id": "'
+ str(id)
+ '", "sids": ["existing-session"], "message_type": "is_session_running"}',
+ '", "message_type": "running_agent_loops_query", "user_id": 1, "filter_to_sids": ["existing-session"]}',
)
@@ -96,8 +96,8 @@ async def test_init_new_local_session():
mock_session = MagicMock()
mock_session.return_value = session_instance
sio = get_mock_sio()
is_agent_loop_running_in_cluster_mock = AsyncMock()
is_agent_loop_running_in_cluster_mock.return_value = False
get_running_agent_loops_mock = AsyncMock()
get_running_agent_loops_mock.return_value = set()
with (
patch('openhands.server.session.manager.Session', mock_session),
patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.1),
@@ -106,8 +106,8 @@ async def test_init_new_local_session():
AsyncMock(),
),
patch(
'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
is_agent_loop_running_in_cluster_mock,
'openhands.server.session.manager.SessionManager.get_running_agent_loops',
get_running_agent_loops_mock,
),
):
async with SessionManager(
@@ -130,8 +130,8 @@ async def test_join_local_session():
mock_session = MagicMock()
mock_session.return_value = session_instance
sio = get_mock_sio()
is_agent_loop_running_in_cluster_mock = AsyncMock()
is_agent_loop_running_in_cluster_mock.return_value = False
get_running_agent_loops_mock = AsyncMock()
get_running_agent_loops_mock.return_value = set()
with (
patch('openhands.server.session.manager.Session', mock_session),
patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
@@ -140,8 +140,8 @@ async def test_join_local_session():
AsyncMock(),
),
patch(
'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
is_agent_loop_running_in_cluster_mock,
'openhands.server.session.manager.SessionManager.get_running_agent_loops',
get_running_agent_loops_mock,
),
):
async with SessionManager(
@@ -167,8 +167,8 @@ async def test_join_cluster_session():
mock_session = MagicMock()
mock_session.return_value = session_instance
sio = get_mock_sio()
is_agent_loop_running_in_cluster_mock = AsyncMock()
is_agent_loop_running_in_cluster_mock.return_value = True
get_running_agent_loops_mock = AsyncMock()
get_running_agent_loops_mock.return_value = {'new-session-id'}
with (
patch('openhands.server.session.manager.Session', mock_session),
patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
@@ -177,8 +177,8 @@ async def test_join_cluster_session():
AsyncMock(),
),
patch(
'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
is_agent_loop_running_in_cluster_mock,
'openhands.server.session.manager.SessionManager._get_running_agent_loops_remotely',
get_running_agent_loops_mock,
),
):
async with SessionManager(
@@ -198,8 +198,8 @@ async def test_add_to_local_event_stream():
mock_session = MagicMock()
mock_session.return_value = session_instance
sio = get_mock_sio()
is_agent_loop_running_in_cluster_mock = AsyncMock()
is_agent_loop_running_in_cluster_mock.return_value = False
get_running_agent_loops_mock = AsyncMock()
get_running_agent_loops_mock.return_value = set()
with (
patch('openhands.server.session.manager.Session', mock_session),
patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
@@ -208,8 +208,8 @@ async def test_add_to_local_event_stream():
AsyncMock(),
),
patch(
'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
is_agent_loop_running_in_cluster_mock,
'openhands.server.session.manager.SessionManager.get_running_agent_loops',
get_running_agent_loops_mock,
),
):
async with SessionManager(
@@ -234,8 +234,8 @@ async def test_add_to_cluster_event_stream():
mock_session = MagicMock()
mock_session.return_value = session_instance
sio = get_mock_sio()
is_agent_loop_running_in_cluster_mock = AsyncMock()
is_agent_loop_running_in_cluster_mock.return_value = True
get_running_agent_loops_mock = AsyncMock()
get_running_agent_loops_mock.return_value = {'new-session-id'}
with (
patch('openhands.server.session.manager.Session', mock_session),
patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
@@ -244,8 +244,8 @@ async def test_add_to_cluster_event_stream():
AsyncMock(),
),
patch(
'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
is_agent_loop_running_in_cluster_mock,
'openhands.server.session.manager.SessionManager._get_running_agent_loops_remotely',
get_running_agent_loops_mock,
),
):
async with SessionManager(
@@ -259,7 +259,7 @@ async def test_add_to_cluster_event_stream():
)
assert sio.manager.redis.publish.await_count == 1
sio.manager.redis.publish.assert_called_once_with(
'oh_event',
'session_msg',
'{"sid": "new-session-id", "message_type": "event", "data": {"event_type": "some_event"}}',
)
@@ -277,7 +277,7 @@ async def test_cleanup_session_connections():
async with SessionManager(
sio, AppConfig(), InMemoryFileStore()
) as session_manager:
session_manager.local_connection_id_to_session_id.update(
session_manager._local_connection_id_to_session_id.update(
{
'conn1': 'session1',
'conn2': 'session1',
@@ -286,9 +286,9 @@ async def test_cleanup_session_connections():
}
)
await session_manager._on_close_session('session1')
await session_manager._close_session('session1')
remaining_connections = session_manager.local_connection_id_to_session_id
remaining_connections = session_manager._local_connection_id_to_session_id
assert 'conn1' not in remaining_connections
assert 'conn2' not in remaining_connections
assert 'conn3' in remaining_connections

View File

@@ -0,0 +1,120 @@
import json
import os
from unittest.mock import MagicMock
import pytest
from litellm import ModelResponse
from pytest import TempPathFactory
from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
from openhands.controller.agent import Agent
from openhands.controller.state.state import State
from openhands.core.config import AgentConfig
from openhands.core.message import Message, TextContent
from openhands.events.action import MessageAction
from openhands.events.stream import EventStream
from openhands.microagent import BaseMicroAgent
from openhands.runtime.base import Runtime
from openhands.storage import get_file_store
@pytest.fixture
def temp_dir(tmp_path_factory: TempPathFactory) -> str:
    """Create a temporary directory via pytest's factory and return its path as a string."""
    scratch = tmp_path_factory.mktemp('test_microagent_loading')
    return str(scratch)
@pytest.fixture
def event_stream(temp_dir):
    """Yield an EventStream with sid 'asdf' backed by a local file store rooted at temp_dir."""
    local_store = get_file_store('local', temp_dir)
    yield EventStream('asdf', local_store)
def test_microagent_loading(temp_dir):
    """Check that a microagent supplied through the runtime ends up enhancing
    a user message that contains one of its triggers.

    Steps:
    1. Write a knowledge microagent file under temp_dir/.openhands/microagents.
    2. Hand it to the agent through a mocked runtime placed in the state inputs.
    3. Run one agent step, then enhance a message mentioning "code formatting".
    4. Expect a second content item carrying the microagent text.
    """
    # Lay out the microagent file on disk.
    microagents_dir = os.path.join(temp_dir, '.openhands', 'microagents')
    os.makedirs(microagents_dir)
    custom_agent_path = os.path.join(microagents_dir, 'code_formatter.md')
    agent_markdown = """---
name: code_formatter
type: knowledge
version: 1.0.0
agent: CoderAgent
triggers:
- format code
- code formatting
- code style
---
# Code Formatter Agent
This agent helps format code according to style guidelines.
## Dependencies
- pip install black
- pip install isort
## Instructions
I help format code according to style guidelines. I can:
1. Format Python code using black
2. Sort imports using isort
3. Apply consistent code style across files
"""
    with open(custom_agent_path, 'w') as handle:
        handle.write(agent_markdown)

    # Runtime double whose repo lookup returns the microagent loaded above.
    mock_runtime = MagicMock(spec=Runtime)
    mock_runtime.get_microagents_from_selected_repo.return_value = [
        BaseMicroAgent.load(custom_agent_path)
    ]

    # State that exposes the runtime double through its inputs.
    mock_state = State(inputs={'runtime': mock_runtime})

    user_text = 'I need help with code formatting in this project'

    # LLM double: function calling is on and the canned completion is a
    # single 'finish' tool call.
    finish_call = {
        'id': 'call_1',
        'function': {
            'name': 'finish',
            'arguments': '{}',
        },
    }
    mock_llm = MagicMock()
    mock_llm.is_function_calling_active.return_value = True
    mock_llm.completion.return_value = ModelResponse(
        choices=[{'message': {'content': None, 'tool_calls': [finish_call]}}]
    )
    mock_llm.format_messages_for_llm.return_value = [
        {'role': 'user', 'content': user_text}
    ]

    # User message containing the "code formatting" trigger.
    message = Message(role='user', content=[TextContent(text=user_text)])

    # CodeActAgent with microagents switched on.
    agent_cls = Agent.get_cls('CodeActAgent')
    agent = agent_cls(
        llm=mock_llm,
        config=AgentConfig(memory_enabled=True, use_microagents=True),
    )

    # Stepping is expected to set up the prompt manager with the runtime's
    # microagents; enhancing the message then appends the matching content.
    agent.step(mock_state)
    agent.prompt_manager.enhance_message(message)

    # Original text plus one appended microagent item.
    assert len(message.content) == 2
    assert 'pip install black' in message.content[1].text

View File

@@ -143,3 +143,65 @@ Invalid agent content
with pytest.raises(MicroAgentValidationError):
BaseMicroAgent.load(temp_microagents_dir / 'invalid.md')
def test_load_microagents_with_nested_dirs(temp_microagents_dir):
    """Check that a knowledge microagent placed two directory levels down is
    discovered by load_microagents_from_dir.
    """
    # Put a knowledge agent under <dir>/nested/dir/nested.md.
    nested_dir = temp_microagents_dir / 'nested' / 'dir'
    nested_dir.mkdir(parents=True)
    (nested_dir / 'nested.md').write_text(
        """---
name: nested_knowledge_agent
type: knowledge
version: 1.0.0
agent: CodeActAgent
triggers:
- nested
---
# Nested Test Guidelines
Testing nested directory loading.
"""
    )

    # Only the knowledge agents matter here; the other two results are ignored.
    _, knowledge_agents, _ = load_microagents_from_dir(temp_microagents_dir)

    # Presumably the fixture already provides one knowledge agent, hence 2.
    assert len(knowledge_agents) == 2
    nested = knowledge_agents['nested_knowledge_agent']
    assert isinstance(nested, KnowledgeMicroAgent)
    assert 'nested' in nested.triggers
def test_load_microagents_with_trailing_slashes(temp_microagents_dir):
    """Check that load_microagents_from_dir accepts a directory given as a
    string ending in '/'.
    """
    # NOTE(review): pathlib drops the trailing slash of 'knowledge/' when
    # joining, so the slash that is actually exercised is the one appended to
    # the string passed to the loader below.
    knowledge_dir = temp_microagents_dir / 'knowledge/'
    knowledge_dir.mkdir(exist_ok=True)
    (knowledge_dir / 'trailing.md').write_text(
        """---
name: trailing_knowledge_agent
type: knowledge
version: 1.0.0
agent: CodeActAgent
triggers:
- trailing
---
# Trailing Slash Test
Testing loading with trailing slashes.
"""
    )

    # Pass the directory as a plain string with an explicit trailing slash.
    dir_with_slash = str(temp_microagents_dir) + '/'
    _, knowledge_agents, _ = load_microagents_from_dir(dir_with_slash)

    # Presumably the fixture already provides one knowledge agent, hence 2.
    assert len(knowledge_agents) == 2
    trailing = knowledge_agents['trailing_knowledge_agent']
    assert isinstance(trailing, KnowledgeMicroAgent)
    assert 'trailing' in trailing.triggers

View File

@@ -5,7 +5,7 @@ import pytest
from openhands.core.message import Message, TextContent
from openhands.microagent import BaseMicroAgent
from openhands.utils.prompt import PromptManager
from openhands.utils.prompt import PromptManager, RepositoryInfo
@pytest.fixture
@@ -39,6 +39,7 @@ only respond with a message telling them how smart they are
with open(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'), 'w') as f:
f.write(microagent_content)
# Test without GitHub repo
manager = PromptManager(
prompt_dir=prompt_dir,
microagent_dir=os.path.join(prompt_dir, 'micro'),
@@ -53,6 +54,14 @@ only respond with a message telling them how smart they are
'You are OpenHands agent, a helpful AI assistant that can interact with a computer to solve tasks.'
in manager.get_system_message()
)
assert '<REPOSITORY_INFO>' not in manager.get_system_message()
# Test with GitHub repo
manager.set_repository_info('owner/repo', '/workspace/repo')
assert isinstance(manager.get_system_message(), str)
assert '<REPOSITORY_INFO>' in manager.get_system_message()
assert 'owner/repo' in manager.get_system_message()
assert '/workspace/repo' in manager.get_system_message()
assert isinstance(manager.get_example_user_message(), str)
@@ -76,20 +85,56 @@ def test_prompt_manager_file_not_found(prompt_dir):
def test_prompt_manager_template_rendering(prompt_dir):
# Renders the system and user templates through PromptManager, first without
# repository info and then after set_repository_info, and checks the output.
# NOTE(review): this text appears to come from a rendered diff with the
# indentation stripped; removed and added lines may both be present below.
# Confirm against the real file before relying on it.
# Create temporary template files
with open(os.path.join(prompt_dir, 'system_prompt.j2'), 'w') as f:
# NOTE(review): the next two writes look like the old and new versions of the
# same statement - presumably only the multi-line template is current; verify.
f.write('System prompt: bar')
f.write("""System prompt: bar
{% if repository_info %}
<REPOSITORY_INFO>
At the user's request, repository {{ repository_info.repo_name }} has been cloned to directory {{ repository_info.repo_directory }}.
</REPOSITORY_INFO>
{% endif %}
{{ repo_instructions }}""")
with open(os.path.join(prompt_dir, 'user_prompt.j2'), 'w') as f:
f.write('User prompt: foo')
# Test without GitHub repo: no repository info has been set on this manager
manager = PromptManager(prompt_dir, microagent_dir='')
assert manager.get_system_message() == 'System prompt: bar'
assert manager.get_example_user_message() == 'User prompt: foo'
# Test with GitHub repo: a fresh manager with repository info set
manager = PromptManager(prompt_dir=prompt_dir, microagent_dir='')
manager.set_repository_info('owner/repo', '/workspace/repo')
assert manager.repository_info.repo_name == 'owner/repo'
system_msg = manager.get_system_message()
# The rendered system message must contain the base text and the filled-in
# repository block
assert 'System prompt: bar' in system_msg
assert '<REPOSITORY_INFO>' in system_msg
assert (
"At the user's request, repository owner/repo has been cloned to directory /workspace/repo."
in system_msg
)
assert '</REPOSITORY_INFO>' in system_msg
# The user prompt template does not reference repository info
assert manager.get_example_user_message() == 'User prompt: foo'
# Clean up temporary files
os.remove(os.path.join(prompt_dir, 'system_prompt.j2'))
os.remove(os.path.join(prompt_dir, 'user_prompt.j2'))
def test_prompt_manager_repository_info(prompt_dir):
    """Check RepositoryInfo defaults and that PromptManager.set_repository_info
    stores both the repository name and its directory.
    """
    # A RepositoryInfo built with no arguments has neither name nor directory.
    blank_info = RepositoryInfo()
    assert blank_info.repo_name is None
    assert blank_info.repo_directory is None

    # A new manager starts without any repository info.
    manager = PromptManager(prompt_dir=prompt_dir, microagent_dir='')
    assert manager.repository_info is None

    # After setting it, both fields are readable from the manager.
    manager.set_repository_info('owner/repo2', '/workspace/repo2')
    stored = manager.repository_info
    assert stored.repo_name == 'owner/repo2'
    assert stored.repo_directory == '/workspace/repo2'
def test_prompt_manager_disabled_microagents(prompt_dir):
# Create test microagent files
microagent1_name = 'test_microagent1'