Compare commits

...

3 Commits

Author SHA1 Message Date
openhands
1385d92447 refactor: Set git_dir during clone_or_init_repo and remove cwd params 2025-04-25 20:02:05 +00:00
openhands
74f5b24144 refactor: Simplify GitHandler to always use async functions 2025-04-25 19:44:46 +00:00
openhands
e8d51a0878 feat: Make git functions in runtime/base.py async 2025-04-25 19:42:13 +00:00
4 changed files with 122 additions and 152 deletions

View File

@@ -98,6 +98,7 @@ class Runtime(FileEditRuntimeMixin):
initial_env_vars: dict[str, str]
attach_to_existing: bool
status_callback: Callable | None
git_dir: str | None
def __init__(
self,
@@ -112,9 +113,11 @@ class Runtime(FileEditRuntimeMixin):
user_id: str | None = None,
git_provider_tokens: PROVIDER_TOKEN_TYPE | None = None,
):
# GitHandler will be initialized with an async function
self.git_handler = GitHandler(
execute_shell_fn=self._execute_shell_fn_git_handler
)
self.git_dir = None
self.sid = sid
self.event_stream = event_stream
self.event_stream.subscribe(
@@ -316,6 +319,9 @@ class Runtime(FileEditRuntimeMixin):
selected_branch: str | None,
repository_provider: ProviderType = ProviderType.GITHUB,
) -> str:
# Set the git_dir to the workspace mount path by default
self.git_dir = self.config.workspace_mount_path_in_sandbox
if not selected_repository:
# In SaaS mode (indicated by user_id being set), always run git init
# In OSS mode, only run git init if workspace_base is not set
@@ -327,6 +333,7 @@ class Runtime(FileEditRuntimeMixin):
command='git init',
)
self.run_action(action)
# git_dir is already set to workspace mount path
else:
logger.info(
'In workspace mount mode, not initializing a new git repository.'
@@ -395,6 +402,13 @@ class Runtime(FileEditRuntimeMixin):
)
self.log('info', f'Cloning repo: {selected_repository}')
self.run_action(action)
# Update git_dir to point to the cloned repository directory
self.git_dir = os.path.join(
self.config.workspace_mount_path_in_sandbox, dir_name
)
self.git_handler.set_cwd(self.git_dir)
return dir_name
def maybe_run_setup_script(self):
@@ -612,13 +626,15 @@ class Runtime(FileEditRuntimeMixin):
# Git
# ====================================================================
def _execute_shell_fn_git_handler(
async def _execute_shell_fn_git_handler(
self, command: str, cwd: str | None
) -> CommandResult:
"""
This function is used by the GitHandler to execute shell commands.
"""
obs = self.run(CmdRunAction(command=command, is_static=True, cwd=cwd))
obs = await call_sync_from_async(
self.run, CmdRunAction(command=command, is_static=True, cwd=cwd)
)
exit_code = 0
content = ''
@@ -629,13 +645,15 @@ class Runtime(FileEditRuntimeMixin):
return CommandResult(content=content, exit_code=exit_code)
def get_git_changes(self, cwd: str) -> list[dict[str, str]] | None:
self.git_handler.set_cwd(cwd)
return self.git_handler.get_git_changes()
async def get_git_changes(self) -> list[dict[str, str]] | None:
if self.git_dir:
self.git_handler.set_cwd(self.git_dir)
return await call_sync_from_async(self.git_handler.get_git_changes)
def get_git_diff(self, file_path: str, cwd: str) -> dict[str, str]:
self.git_handler.set_cwd(cwd)
return self.git_handler.get_git_diff(file_path)
async def get_git_diff(self, file_path: str) -> dict[str, str]:
if self.git_dir:
self.git_handler.set_cwd(self.git_dir)
return await call_sync_from_async(self.git_handler.get_git_diff, file_path)
@property
def additional_agent_instructions(self) -> str:

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Callable
from typing import Awaitable, Callable
@dataclass
@@ -23,7 +23,7 @@ class GitHandler:
def __init__(
self,
execute_shell_fn: Callable[[str, str | None], CommandResult],
execute_shell_fn: Callable[[str, str | None], Awaitable[CommandResult]],
):
self.execute = execute_shell_fn
self.cwd: str | None = None
@@ -37,7 +37,11 @@ class GitHandler:
"""
self.cwd = cwd
def _is_git_repo(self) -> bool:
async def _execute_async(self, cmd: str, cwd: str | None) -> CommandResult:
"""Execute the command asynchronously."""
return await self.execute(cmd, cwd)
async def _is_git_repo(self) -> bool:
"""
Checks if the current directory is a Git repository.
@@ -45,10 +49,10 @@ class GitHandler:
bool: True if inside a Git repository, otherwise False.
"""
cmd = 'git rev-parse --is-inside-work-tree'
output = self.execute(cmd, self.cwd)
output = await self._execute_async(cmd, self.cwd)
return output.content.strip() == 'true'
def _get_current_file_content(self, file_path: str) -> str:
async def _get_current_file_content(self, file_path: str) -> str:
"""
Retrieves the current content of a given file.
@@ -58,10 +62,10 @@ class GitHandler:
Returns:
str: The file content.
"""
output = self.execute(f'cat {file_path}', self.cwd)
output = await self._execute_async(f'cat {file_path}', self.cwd)
return output.content
def _verify_ref_exists(self, ref: str) -> bool:
async def _verify_ref_exists(self, ref: str) -> bool:
"""
Verifies whether a specific Git reference exists.
@@ -72,18 +76,18 @@ class GitHandler:
bool: True if the reference exists, otherwise False.
"""
cmd = f'git rev-parse --verify {ref}'
output = self.execute(cmd, self.cwd)
output = await self._execute_async(cmd, self.cwd)
return output.exit_code == 0
def _get_valid_ref(self) -> str | None:
async def _get_valid_ref(self) -> str | None:
"""
Determines a valid Git reference for comparison.
Returns:
str | None: A valid Git reference or None if no valid reference is found.
"""
current_branch = self._get_current_branch()
default_branch = self._get_default_branch()
current_branch = await self._get_current_branch()
default_branch = await self._get_default_branch()
ref_current_branch = f'origin/{current_branch}'
ref_non_default_branch = f'$(git merge-base HEAD "$(git rev-parse --abbrev-ref origin/{default_branch})")'
@@ -97,12 +101,12 @@ class GitHandler:
ref_new_repo,
]
for ref in refs:
if self._verify_ref_exists(ref):
if await self._verify_ref_exists(ref):
return ref
return None
def _get_ref_content(self, file_path: str) -> str:
async def _get_ref_content(self, file_path: str) -> str:
"""
Retrieves the content of a file from a valid Git reference.
@@ -112,15 +116,15 @@ class GitHandler:
Returns:
str: The content of the file from the reference, or an empty string if unavailable.
"""
ref = self._get_valid_ref()
ref = await self._get_valid_ref()
if not ref:
return ''
cmd = f'git show {ref}:{file_path}'
output = self.execute(cmd, self.cwd)
output = await self._execute_async(cmd, self.cwd)
return output.content if output.exit_code == 0 else ''
def _get_default_branch(self) -> str:
async def _get_default_branch(self) -> str:
"""
Retrieves the primary Git branch name of the repository.
@@ -128,10 +132,10 @@ class GitHandler:
str: The name of the primary branch.
"""
cmd = 'git remote show origin | grep "HEAD branch"'
output = self.execute(cmd, self.cwd)
output = await self._execute_async(cmd, self.cwd)
return output.content.split()[-1].strip()
def _get_current_branch(self) -> str:
async def _get_current_branch(self) -> str:
"""
Retrieves the currently selected Git branch.
@@ -139,25 +143,25 @@ class GitHandler:
str: The name of the current branch.
"""
cmd = 'git rev-parse --abbrev-ref HEAD'
output = self.execute(cmd, self.cwd)
output = await self._execute_async(cmd, self.cwd)
return output.content.strip()
def _get_changed_files(self) -> list[str]:
async def _get_changed_files(self) -> list[str]:
"""
Retrieves a list of changed files compared to a valid Git reference.
Returns:
list[str]: A list of changed file paths.
"""
ref = self._get_valid_ref()
ref = await self._get_valid_ref()
if not ref:
return []
diff_cmd = f'git diff --name-status {ref}'
output = self.execute(diff_cmd, self.cwd)
output = await self._execute_async(diff_cmd, self.cwd)
return output.content.splitlines()
def _get_untracked_files(self) -> list[dict[str, str]]:
async def _get_untracked_files(self) -> list[dict[str, str]]:
"""
Retrieves a list of untracked files in the repository. This is useful for detecting new files.
@@ -165,7 +169,7 @@ class GitHandler:
list[dict[str, str]]: A list of dictionaries containing file paths and statuses.
"""
cmd = 'git ls-files --others --exclude-standard'
output = self.execute(cmd, self.cwd)
output = await self._execute_async(cmd, self.cwd)
obs_list = output.content.splitlines()
return (
[{'status': 'A', 'path': path} for path in obs_list]
@@ -173,24 +177,24 @@ class GitHandler:
else []
)
def get_git_changes(self) -> list[dict[str, str]] | None:
async def get_git_changes(self) -> list[dict[str, str]] | None:
"""
Retrieves the list of changed files in the Git repository.
Returns:
list[dict[str, str]] | None: A list of dictionaries containing file paths and statuses. None if not a git repository.
"""
if not self._is_git_repo():
if not await self._is_git_repo():
return None
changes_list = self._get_changed_files()
changes_list = await self._get_changed_files()
result = parse_git_changes(changes_list)
# join with any untracked files
result += self._get_untracked_files()
result += await self._get_untracked_files()
return result
def get_git_diff(self, file_path: str) -> dict[str, str]:
async def get_git_diff(self, file_path: str) -> dict[str, str]:
"""
Retrieves the original and modified content of a file in the repository.
@@ -200,8 +204,8 @@ class GitHandler:
Returns:
dict[str, str]: A dictionary containing the original and modified content.
"""
modified = self._get_current_file_content(file_path)
original = self._get_ref_content(file_path)
modified = await self._get_current_file_content(file_path)
original = await self._get_ref_content(file_path)
return {
'modified': modified,

View File

@@ -22,20 +22,10 @@ from openhands.events.observation import (
FileReadObservation,
)
from openhands.runtime.base import Runtime
from openhands.server.data_models.conversation_info import ConversationInfo
from openhands.server.file_config import (
FILES_TO_IGNORE,
)
from openhands.server.shared import (
ConversationStoreImpl,
config,
conversation_manager,
)
from openhands.server.user_auth import get_user_id
from openhands.server.utils import get_conversation_store
from openhands.storage.conversation.conversation_store import ConversationStore
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
from openhands.storage.data_models.conversation_status import ConversationStatus
from openhands.utils.async_utils import call_sync_from_async
app = APIRouter(prefix='/api/conversations/{conversation_id}')
@@ -195,20 +185,10 @@ async def git_changes(
user_id: str = Depends(get_user_id),
):
runtime: Runtime = request.state.conversation.runtime
conversation_store = await ConversationStoreImpl.get_instance(
config,
user_id,
)
cwd = await get_cwd(
conversation_store,
conversation_id,
runtime.config.workspace_mount_path_in_sandbox,
)
logger.info(f'Getting git changes in {cwd}')
logger.info(f'Getting git changes in {runtime.git_dir}')
try:
changes = await call_sync_from_async(runtime.get_git_changes, cwd)
changes = await call_sync_from_async(runtime.get_git_changes)
if changes is None:
return JSONResponse(
status_code=404,
@@ -234,18 +214,12 @@ async def git_diff(
request: Request,
path: str,
conversation_id: str,
conversation_store = Depends(get_conversation_store),
):
runtime: Runtime = request.state.conversation.runtime
cwd = await get_cwd(
conversation_store,
conversation_id,
runtime.config.workspace_mount_path_in_sandbox,
)
logger.info(f'Getting git diff for {path} in {runtime.git_dir}')
try:
diff = await call_sync_from_async(runtime.get_git_diff, path, cwd)
diff = await call_sync_from_async(runtime.get_git_diff, path)
return diff
except AgentRuntimeUnavailableError as e:
logger.error(f'Error getting diff: {e}')
@@ -253,46 +227,3 @@ async def git_diff(
status_code=500,
content={'error': f'Error getting diff: {e}'},
)
async def get_cwd(
conversation_store: ConversationStore,
conversation_id: str,
workspace_mount_path_in_sandbox: str,
):
metadata = await conversation_store.get_metadata(conversation_id)
is_running = await conversation_manager.is_agent_loop_running(conversation_id)
conversation_info = await _get_conversation_info(metadata, is_running)
cwd = workspace_mount_path_in_sandbox
if conversation_info and conversation_info.selected_repository:
repo_dir = conversation_info.selected_repository.split('/')[-1]
cwd = os.path.join(cwd, repo_dir)
return cwd
async def _get_conversation_info(
conversation: ConversationMetadata,
is_running: bool,
) -> ConversationInfo | None:
try:
title = conversation.title
if not title:
title = f'Conversation {conversation.conversation_id[:5]}'
return ConversationInfo(
conversation_id=conversation.conversation_id,
title=title,
last_updated_at=conversation.last_updated_at,
created_at=conversation.created_at,
selected_repository=conversation.selected_repository,
status=ConversationStatus.RUNNING
if is_running
else ConversationStatus.STOPPED,
)
except Exception as e:
logger.error(
f'Error loading conversation {conversation.conversation_id}: {str(e)}',
extra={'session_id': conversation.conversation_id},
)
return None

View File

@@ -1,11 +1,17 @@
import asyncio
import os
import shutil
import subprocess
import tempfile
import unittest
import pytest
from openhands.runtime.utils.git_handler import CommandResult, GitHandler
# Mark all test methods as asyncio tests
pytestmark = pytest.mark.asyncio
class TestGitHandler(unittest.TestCase):
def setUp(self):
@@ -104,9 +110,9 @@ class TestGitHandler(unittest.TestCase):
# Push the feature branch to origin
self._execute_command('git push -u origin feature-branch', self.local_dir)
def test_is_git_repo(self):
async def test_is_git_repo(self):
"""Test that _is_git_repo returns True for a git repository."""
self.assertTrue(self.git_handler._is_git_repo())
self.assertTrue(await self.git_handler._is_git_repo())
# Verify the command was executed
self.assertTrue(
@@ -116,9 +122,9 @@ class TestGitHandler(unittest.TestCase):
)
)
def test_get_default_branch(self):
async def test_get_default_branch(self):
"""Test that _get_default_branch returns the correct branch name."""
branch = self.git_handler._get_default_branch()
branch = await self.git_handler._get_default_branch()
self.assertEqual(branch, 'main')
# Verify the command was executed
@@ -129,9 +135,9 @@ class TestGitHandler(unittest.TestCase):
)
)
def test_get_current_branch(self):
async def test_get_current_branch(self):
"""Test that _get_current_branch returns the correct branch name."""
branch = self.git_handler._get_current_branch()
branch = await self.git_handler._get_current_branch()
self.assertEqual(branch, 'feature-branch')
# Verify the command was executed
@@ -142,10 +148,10 @@ class TestGitHandler(unittest.TestCase):
)
)
def test_get_valid_ref_with_origin_current_branch(self):
async def test_get_valid_ref_with_origin_current_branch(self):
"""Test that _get_valid_ref returns the current branch in origin when it exists."""
# This test uses the setup from setUp where the current branch exists in origin
ref = self.git_handler._get_valid_ref()
ref = await self.git_handler._get_valid_ref()
self.assertIsNotNone(ref)
# Check that the refs were checked in the correct order
@@ -165,7 +171,7 @@ class TestGitHandler(unittest.TestCase):
result = self._execute_command(f'git rev-parse --verify {ref}', self.local_dir)
self.assertEqual(result.exit_code, 0)
def test_get_valid_ref_without_origin_current_branch(self):
async def test_get_valid_ref_without_origin_current_branch(self):
"""Test that _get_valid_ref falls back to default branch when current branch doesn't exist in origin."""
# Create a new branch that doesn't exist in origin
self._execute_command('git checkout -b new-local-branch', self.local_dir)
@@ -173,7 +179,7 @@ class TestGitHandler(unittest.TestCase):
# Clear the executed commands to start fresh
self.executed_commands = []
ref = self.git_handler._get_valid_ref()
ref = await self.git_handler._get_valid_ref()
self.assertIsNotNone(ref)
# Check that the refs were checked in the correct order
@@ -196,7 +202,7 @@ class TestGitHandler(unittest.TestCase):
result = self._execute_command(f'git rev-parse --verify {ref}', self.local_dir)
self.assertEqual(result.exit_code, 0)
def test_get_valid_ref_without_origin(self):
async def test_get_valid_ref_without_origin(self):
"""Test that _get_valid_ref falls back to empty tree ref when there's no origin."""
# Create a new directory with a git repo but no origin
no_origin_dir = os.path.join(self.test_dir, 'no-origin')
@@ -207,18 +213,20 @@ class TestGitHandler(unittest.TestCase):
self._execute_command("git config user.email 'test@example.com'", no_origin_dir)
self._execute_command("git config user.name 'Test User'", no_origin_dir)
# Create a file and commit it
with open(os.path.join(no_origin_dir, 'file1.txt'), 'w') as f:
f.write('Content in repo without origin')
# Create a file and commit it using subprocess
file_path = os.path.join(no_origin_dir, 'file1.txt')
self._execute_command(
f'echo "Content in repo without origin" > {file_path}', no_origin_dir
)
self._execute_command('git add file1.txt', no_origin_dir)
self._execute_command("git commit -m 'Initial commit'", no_origin_dir)
# Create a custom GitHandler with a modified _get_default_branch method for this test
class TestGitHandler(GitHandler):
def _get_default_branch(self) -> str:
async def _get_default_branch(self) -> str:
# Override to handle repos without origin
try:
return super()._get_default_branch()
return await super()._get_default_branch()
except IndexError:
return 'main' # Default fallback
@@ -229,7 +237,7 @@ class TestGitHandler(unittest.TestCase):
# Clear the executed commands to start fresh
self.executed_commands = []
ref = no_origin_handler._get_valid_ref()
ref = await no_origin_handler._get_valid_ref()
# Verify that git commands were executed
self.assertTrue(
@@ -251,9 +259,9 @@ class TestGitHandler(unittest.TestCase):
)
self.assertEqual(result.exit_code, 0)
def test_get_ref_content(self):
async def test_get_ref_content(self):
"""Test that _get_ref_content returns the content from a valid ref."""
content = self.git_handler._get_ref_content('file1.txt')
content = await self.git_handler._get_ref_content('file1.txt')
self.assertEqual(content.strip(), 'Modified content')
# Should have called _get_valid_ref and then git show
@@ -262,9 +270,9 @@ class TestGitHandler(unittest.TestCase):
]
self.assertTrue(any('file1.txt' in cmd for cmd in show_commands))
def test_get_current_file_content(self):
async def test_get_current_file_content(self):
"""Test that _get_current_file_content returns the current content of a file."""
content = self.git_handler._get_current_file_content('file1.txt')
content = await self.git_handler._get_current_file_content('file1.txt')
self.assertEqual(content.strip(), 'Modified content again')
# Verify the command was executed
@@ -272,14 +280,15 @@ class TestGitHandler(unittest.TestCase):
any(cmd == 'cat file1.txt' for cmd, _ in self.executed_commands)
)
def test_get_changed_files(self):
async def test_get_changed_files(self):
"""Test that _get_changed_files returns the list of changed files."""
# Let's create a new file to ensure it shows up in the diff
with open(os.path.join(self.local_dir, 'new_file.txt'), 'w') as f:
f.write('New file content')
# Use subprocess directly to create and add the file
file_path = os.path.join(self.local_dir, 'new_file.txt')
self._execute_command(f'echo "New file content" > {file_path}', self.local_dir)
self._execute_command('git add new_file.txt', self.local_dir)
files = self.git_handler._get_changed_files()
files = await self.git_handler._get_changed_files()
self.assertTrue(files)
# Should include file1.txt (modified) and file3.txt (deleted)
@@ -295,13 +304,15 @@ class TestGitHandler(unittest.TestCase):
]
self.assertTrue(diff_commands)
def test_get_untracked_files(self):
async def test_get_untracked_files(self):
"""Test that _get_untracked_files returns the list of untracked files."""
# Create an untracked file
with open(os.path.join(self.local_dir, 'untracked.txt'), 'w') as f:
f.write('Untracked file content')
# Create an untracked file using subprocess
file_path = os.path.join(self.local_dir, 'untracked.txt')
self._execute_command(
f'echo "Untracked file content" > {file_path}', self.local_dir
)
files = self.git_handler._get_untracked_files()
files = await self.git_handler._get_untracked_files()
self.assertEqual(len(files), 1)
self.assertEqual(files[0]['path'], 'untracked.txt')
self.assertEqual(files[0]['status'], 'A')
@@ -314,18 +325,22 @@ class TestGitHandler(unittest.TestCase):
)
)
def test_get_git_changes(self):
async def test_get_git_changes(self):
"""Test that get_git_changes returns the combined list of changed and untracked files."""
# Create an untracked file
with open(os.path.join(self.local_dir, 'untracked.txt'), 'w') as f:
f.write('Untracked file content')
# Create an untracked file using subprocess
file_path = os.path.join(self.local_dir, 'untracked.txt')
self._execute_command(
f'echo "Untracked file content" > {file_path}', self.local_dir
)
# Create a new file and stage it
with open(os.path.join(self.local_dir, 'new_file2.txt'), 'w') as f:
f.write('New file 2 content')
file_path2 = os.path.join(self.local_dir, 'new_file2.txt')
self._execute_command(
f'echo "New file 2 content" > {file_path2}', self.local_dir
)
self._execute_command('git add new_file2.txt', self.local_dir)
changes = self.git_handler.get_git_changes()
changes = await self.git_handler.get_git_changes()
self.assertIsNotNone(changes)
# Should include file1.txt (modified), file3.txt (deleted), new_file2.txt (added), and untracked.txt (untracked)
@@ -341,9 +356,9 @@ class TestGitHandler(unittest.TestCase):
self.assertIn('A', statuses) # Added
self.assertIn('D', statuses) # Deleted
def test_get_git_diff(self):
async def test_get_git_diff(self):
"""Test that get_git_diff returns the original and modified content of a file."""
diff = self.git_handler.get_git_diff('file1.txt')
diff = await self.git_handler.get_git_diff('file1.txt')
self.assertEqual(diff['modified'].strip(), 'Modified content again')
self.assertEqual(diff['original'].strip(), 'Modified content')
@@ -360,4 +375,6 @@ class TestGitHandler(unittest.TestCase):
if __name__ == '__main__':
unittest.main()
import asyncio
asyncio.run(unittest.main())