diff --git a/frontend/src/components/features/chat/action-suggestions.tsx b/frontend/src/components/features/chat/action-suggestions.tsx index 91f85a4a87..197225900d 100644 --- a/frontend/src/components/features/chat/action-suggestions.tsx +++ b/frontend/src/components/features/chat/action-suggestions.tsx @@ -38,7 +38,7 @@ export function ActionSuggestions({ pr, prShort, pushToBranch: `Please push the changes to a remote branch on ${getProviderName()}, but do NOT create a ${pr}. Check your current branch name first - if it's main, master, deploy, or another common default branch name, create a new branch with a descriptive name related to your changes. Otherwise, use the exact SAME branch name as the one you are currently on.`, - createPR: `Please push the changes to ${getProviderName()} and open a ${pr}. Please create a meaningful branch name that describes the changes. If a ${pr} template exists in the repository, please follow it when creating the ${prShort} description.`, + createPR: `Please push the changes to ${getProviderName()} and open a ${pr}. If you're on a default branch (e.g., main, master, deploy), create a new branch with a descriptive name otherwise use the current branch. If a ${pr} template exists in the repository, please follow it when creating the ${prShort} description.`, pushToPR: `Please push the latest changes to the existing ${pr}.`, }; diff --git a/openhands/agenthub/codeact_agent/prompts/additional_info.j2 b/openhands/agenthub/codeact_agent/prompts/additional_info.j2 index 50315d11ec..5e70d3c7ac 100644 --- a/openhands/agenthub/codeact_agent/prompts/additional_info.j2 +++ b/openhands/agenthub/codeact_agent/prompts/additional_info.j2 @@ -1,6 +1,12 @@ {% if repository_info %} At the user's request, repository {{ repository_info.repo_name }} has been cloned to {{ repository_info.repo_directory }} in the current working directory. +{% if repository_info.branch_name %}The repository has been checked out to branch "{{ repository_info.branch_name }}". + +IMPORTANT: You should work within the current branch "{{ repository_info.branch_name }}" unless + 1. the user explicitly instructs otherwise + 2. if the current branch is "main", "master", or another default branch where direct pushes may be unsafe +{% endif %} {% endif %} {% if repository_instructions -%} diff --git a/openhands/agenthub/readonly_agent/prompts/additional_info.j2 b/openhands/agenthub/readonly_agent/prompts/additional_info.j2 index 12d6976665..d709733f25 100644 --- a/openhands/agenthub/readonly_agent/prompts/additional_info.j2 +++ b/openhands/agenthub/readonly_agent/prompts/additional_info.j2 @@ -1,6 +1,12 @@ {% if repository_info %} At the user's request, repository {{ repository_info.repo_name }} has been cloned to the current working directory {{ repository_info.repo_directory }}. +{% if repository_info.branch_name %}The repository has been checked out to branch "{{ repository_info.branch_name }}". + +IMPORTANT: You should work within the current branch "{{ repository_info.branch_name }}" unless + 1. the user explicitly instructs otherwise + 2. if the current branch is "main", "master", or another default branch where direct pushes may be unsafe +{% endif %} {% endif %} {% if repository_instructions -%} diff --git a/openhands/events/observation/agent.py b/openhands/events/observation/agent.py index d7252f334a..fc668e9df2 100644 --- a/openhands/events/observation/agent.py +++ b/openhands/events/observation/agent.py @@ -70,6 +70,7 @@ class RecallObservation(Observation): # workspace context repo_name: str = '' repo_directory: str = '' + repo_branch: str = '' repo_instructions: str = '' runtime_hosts: dict[str, int] = field(default_factory=dict) additional_agent_instructions: str = '' diff --git a/openhands/memory/conversation_memory.py b/openhands/memory/conversation_memory.py index 61ceaaa327..3a767a433e 100644 --- a/openhands/memory/conversation_memory.py +++ b/openhands/memory/conversation_memory.py @@ -512,6 +512,7 @@ class ConversationMemory: repo_info = RepositoryInfo( repo_name=obs.repo_name or '', repo_directory=obs.repo_directory or '', + branch_name=obs.repo_branch or None, ) else: repo_info = None diff --git a/openhands/memory/memory.py b/openhands/memory/memory.py index 1e75ded246..de1818d11a 100644 --- a/openhands/memory/memory.py +++ b/openhands/memory/memory.py @@ -181,6 +181,9 @@ class Memory: if self.repository_info and self.repository_info.repo_directory is not None else '', + repo_branch=self.repository_info.branch_name + if self.repository_info and self.repository_info.branch_name is not None + else '', repo_instructions=repo_instructions if repo_instructions else '', runtime_hosts=self.runtime_info.available_hosts if self.runtime_info and self.runtime_info.available_hosts is not None @@ -322,10 +325,14 @@ class Memory: return mcp_configs - def set_repository_info(self, repo_name: str, repo_directory: str) -> None: + def set_repository_info( + self, repo_name: str, repo_directory: str, branch_name: str | None = None + ) -> None: """Store repository info so we can reference it in an observation.""" if repo_name or repo_directory: - self.repository_info = RepositoryInfo(repo_name, repo_directory) + self.repository_info = RepositoryInfo( + repo_name, repo_directory, branch_name + ) else: self.repository_info = None diff --git a/openhands/server/session/agent_session.py b/openhands/server/session/agent_session.py index af4496a3f2..df2d0c15ab 100644 --- a/openhands/server/session/agent_session.py +++ b/openhands/server/session/agent_session.py @@ -152,6 +152,7 @@ class AgentSession: self.memory = await self._create_memory( selected_repository=selected_repository, repo_directory=repo_directory, + selected_branch=selected_branch, conversation_instructions=conversation_instructions, custom_secrets_descriptions=custom_secrets_handler.get_custom_secrets_descriptions(), working_dir=config.workspace_mount_path_in_sandbox, @@ -463,6 +464,7 @@ class AgentSession: self, selected_repository: str | None, repo_directory: str | None, + selected_branch: str | None, conversation_instructions: str | None, custom_secrets_descriptions: dict[str, str], working_dir: str, @@ -488,7 +490,9 @@ class AgentSession: memory.load_user_workspace_microagents(microagents) if selected_repository and repo_directory: - memory.set_repository_info(selected_repository, repo_directory) + memory.set_repository_info( + selected_repository, repo_directory, selected_branch + ) return memory def _maybe_restore_state(self) -> State | None: diff --git a/openhands/utils/prompt.py b/openhands/utils/prompt.py index 467e8eb160..db4eb75c21 100644 --- a/openhands/utils/prompt.py +++ b/openhands/utils/prompt.py @@ -24,6 +24,7 @@ class RepositoryInfo: repo_name: str | None = None repo_directory: str | None = None + branch_name: str | None = None @dataclass diff --git a/tests/unit/test_memory.py b/tests/unit/test_memory.py index 19b2eb96d8..91e6dc261a 100644 --- a/tests/unit/test_memory.py +++ b/tests/unit/test_memory.py @@ -453,7 +453,9 @@ def test_custom_secrets_descriptions_serialization(prompt_dir): # Create a RepositoryInfo repository_info = RepositoryInfo( - repo_name='test-owner/test-repo', repo_directory='/workspace/test-repo' + repo_name='test-owner/test-repo', + repo_directory='/workspace/test-repo', + branch_name='main', ) conversation_instructions = ConversationInstructions( diff --git a/tests/unit/test_observation_serialization.py b/tests/unit/test_observation_serialization.py index f7cdb3ca8c..5897d34b76 100644 --- a/tests/unit/test_observation_serialization.py +++ b/tests/unit/test_observation_serialization.py @@ -296,6 +296,7 @@ def test_microagent_observation_serialization(): 'recall_type': 'workspace_context', 'repo_name': 'some_repo_name', 'repo_directory': 'some_repo_directory', + 'repo_branch': '', 'working_dir': '', 'runtime_hosts': {'host1': 8080, 'host2': 8081}, 'repo_instructions': 'complex_repo_instructions', @@ -318,6 +319,7 @@ def test_microagent_observation_microagent_knowledge_serialization(): 'recall_type': 'knowledge', 'repo_name': '', 'repo_directory': '', + 'repo_branch': '', 'repo_instructions': '', 'runtime_hosts': {}, 'working_dir': '', @@ -348,6 +350,7 @@ def test_microagent_observation_knowledge_microagent_serialization(): original = RecallObservation( content='Knowledge microagent information', recall_type=RecallType.KNOWLEDGE, + repo_branch='', microagent_knowledge=[ MicroagentKnowledge( name='python_best_practices', @@ -395,6 +398,7 @@ def test_microagent_observation_environment_serialization(): recall_type=RecallType.WORKSPACE_CONTEXT, repo_name='OpenHands', repo_directory='/workspace/openhands', + repo_branch='main', repo_instructions="Follow the project's coding style guide.", runtime_hosts={'127.0.0.1': 8080, 'localhost': 5000}, additional_agent_instructions='You know it all about this runtime', @@ -444,6 +448,7 @@ def test_microagent_observation_combined_serialization(): # Environment info repo_name='OpenHands', repo_directory='/workspace/openhands', + repo_branch='main', repo_instructions="Follow the project's coding style guide.", runtime_hosts={'127.0.0.1': 8080}, additional_agent_instructions='You know it all about this runtime', diff --git a/tests/unit/test_prompt_manager.py b/tests/unit/test_prompt_manager.py index 45d0894a42..7cee16914f 100644 --- a/tests/unit/test_prompt_manager.py +++ b/tests/unit/test_prompt_manager.py @@ -50,7 +50,9 @@ At the user's request, repository {{ repository_info.repo_name }} has been clone # Test with GitHub repo manager = PromptManager(prompt_dir=prompt_dir) - repo_info = RepositoryInfo(repo_name='owner/repo', repo_directory='/workspace/repo') + repo_info = RepositoryInfo( + repo_name='owner/repo', repo_directory='/workspace/repo', branch_name='main' + ) # verify its parts are rendered system_msg = manager.get_system_message() @@ -231,7 +233,9 @@ Today's date is {{ runtime_info.date }} manager = PromptManager(prompt_dir=prompt_dir) # Create repository and runtime information - repo_info = RepositoryInfo(repo_name='owner/repo', repo_directory='/workspace/repo') + repo_info = RepositoryInfo( + repo_name='owner/repo', repo_directory='/workspace/repo', branch_name='main' + ) runtime_info = RuntimeInfo( date='02/12/1232', available_hosts={'example.com': 8080},