Compare commits

...

1 Commits

Author SHA1 Message Date
amanape
14b0eae9e8 fix: Prevent git changes pollution when local branch includes merged main commits 2025-07-23 19:18:20 +04:00

View File

@@ -4,8 +4,7 @@ from typing import Callable
@dataclass
class CommandResult:
"""
Represents the result of a shell command execution.
"""Represents the result of a shell command execution.
Attributes:
content (str): The output content of the command.
@@ -17,9 +16,7 @@ class CommandResult:
class GitHandler:
"""
A handler for executing Git-related operations via shell commands.
"""
"""A handler for executing Git-related operations via shell commands."""
def __init__(
self,
@@ -29,8 +26,7 @@ class GitHandler:
self.cwd: str | None = None
def set_cwd(self, cwd: str) -> None:
"""
Sets the current working directory for Git operations.
"""Sets the current working directory for Git operations.
Args:
cwd (str): The directory path.
@@ -38,8 +34,7 @@ class GitHandler:
self.cwd = cwd
def _is_git_repo(self) -> bool:
"""
Checks if the current directory is a Git repository.
"""Checks if the current directory is a Git repository.
Returns:
bool: True if inside a Git repository, otherwise False.
@@ -49,8 +44,7 @@ class GitHandler:
return output.content.strip() == 'true'
def _get_current_file_content(self, file_path: str) -> str:
"""
Retrieves the current content of a given file.
"""Retrieves the current content of a given file.
Args:
file_path (str): Path to the file.
@@ -62,8 +56,7 @@ class GitHandler:
return output.content
def _verify_ref_exists(self, ref: str) -> bool:
"""
Verifies whether a specific Git reference exists.
"""Verifies whether a specific Git reference exists.
Args:
ref (str): The Git reference to check.
@@ -75,10 +68,71 @@ class GitHandler:
output = self.execute(cmd, self.cwd)
return output.exit_code == 0
def _get_valid_ref(self) -> str | None:
"""
Determines a valid Git reference for comparison.
def _is_ahead_of_remote_branch(self, remote_branch: str) -> bool:
    """Checks if the current branch is ahead of the specified remote branch.

    Counts the commits reachable from HEAD but not from ``remote_branch``
    using ``git rev-list --count <remote_branch>..HEAD``.

    Args:
        remote_branch (str): The remote branch reference (e.g., 'origin/feature-branch').

    Returns:
        bool: True if HEAD has at least one commit that is not in
            ``remote_branch``. False otherwise, including when the git
            command fails or its output is not an integer.
    """
    cmd = f'git --no-pager rev-list --count {remote_branch}..HEAD'
    output = self.execute(cmd, self.cwd)
    if output.exit_code != 0:
        return False

    # A zero exit code does not guarantee numeric output (e.g. an empty
    # string from a wrapped shell); treat anything unparsable as "not ahead"
    # rather than letting ValueError escape to the caller.
    try:
        return int(output.content.strip()) > 0
    except ValueError:
        return False
def _includes_merged_main_commits(self, remote_branch: str, default_branch: str) -> bool:
    """Checks if the local branch includes commits merged from the default branch.

    Compares the commits that are in HEAD but not in ``remote_branch`` with
    the commits that are in ``origin/<default_branch>`` but not in
    ``remote_branch``. A sufficiently large overlap between the two sets is
    taken as a sign that the default branch was merged locally since the
    remote branch was last updated.

    Args:
        remote_branch (str): The remote branch reference (e.g., 'origin/feature-branch').
        default_branch (str): The default branch name (e.g., 'main').

    Returns:
        bool: True if merged main commits are included in the diff. False
            if there is no qualifying overlap, if either commit list is
            empty, if ``origin/<default_branch>`` does not exist, or if a
            git command fails.
    """
    # Get commits that are in HEAD but not in remote_branch
    cmd = f'git --no-pager log --oneline {remote_branch}..HEAD'
    output = self.execute(cmd, self.cwd)
    if output.exit_code != 0:
        return False

    local_commits = output.content.strip().splitlines()
    if not local_commits:
        return False

    # Get commits that are in origin/default_branch but not in remote_branch
    origin_default = f'origin/{default_branch}'
    if not self._verify_ref_exists(origin_default):
        return False

    cmd = f'git --no-pager log --oneline {remote_branch}..{origin_default}'
    output = self.execute(cmd, self.cwd)
    if output.exit_code != 0:
        return False

    main_commits = output.content.strip().splitlines()
    if not main_commits:
        return False

    # Extract the (abbreviated) commit hash, the first token of each line
    local_hashes = {line.split()[0] for line in local_commits if line.strip()}
    main_hashes = {line.split()[0] for line in main_commits if line.strip()}

    # Require up to two shared commits, scaled down when main has few new
    # commits, but never fewer than one: with a single main commit
    # `len(main_hashes) // 2` is 0, and a zero threshold would report
    # "merged main commits" even when nothing overlaps at all.
    overlap = local_hashes & main_hashes
    required_overlap = max(1, min(2, len(main_hashes) // 2))
    return len(overlap) >= required_overlap
def _get_valid_ref(self) -> str | None:
"""Determines a valid Git reference for comparison using a hybrid approach.
- Uses origin/current_branch when it's the best representation of push status
- Falls back to merge-base when origin/current_branch includes merged main commits
Returns:
str | None: A valid Git reference or None if no valid reference is found.
"""
@@ -90,8 +144,19 @@ class GitHandler:
ref_default_branch = 'origin/' + default_branch
ref_new_repo = '$(git --no-pager rev-parse --verify 4b825dc642cb6eb9a060e54bf8d69288fbee4904)' # compares with empty tree
# Hybrid logic: check if origin/current_branch exists and causes pollution
if self._verify_ref_exists(ref_current_branch):
# If we're ahead of remote and it includes merged main commits, use merge-base instead
if (self._is_ahead_of_remote_branch(ref_current_branch) and
self._includes_merged_main_commits(ref_current_branch, default_branch)):
# Try merge-base first to avoid pollution
if self._verify_ref_exists(ref_non_default_branch):
return ref_non_default_branch
# Otherwise use origin/current_branch for normal push workflow
return ref_current_branch
# Fallback to original logic
refs = [
ref_current_branch,
ref_non_default_branch,
ref_default_branch,
ref_new_repo,
@@ -103,8 +168,7 @@ class GitHandler:
return None
def _get_ref_content(self, file_path: str) -> str:
"""
Retrieves the content of a file from a valid Git reference.
"""Retrieves the content of a file from a valid Git reference.
Args:
file_path (str): The file path in the repository.
@@ -121,8 +185,7 @@ class GitHandler:
return output.content if output.exit_code == 0 else ''
def _get_default_branch(self) -> str:
"""
Retrieves the primary Git branch name of the repository.
"""Retrieves the primary Git branch name of the repository.
Returns:
str: The name of the primary branch.
@@ -132,8 +195,7 @@ class GitHandler:
return output.content.split()[-1].strip()
def _get_current_branch(self) -> str:
"""
Retrieves the currently selected Git branch.
"""Retrieves the currently selected Git branch.
Returns:
str: The name of the current branch.
@@ -143,8 +205,7 @@ class GitHandler:
return output.content.strip()
def _get_changed_files(self) -> list[str]:
"""
Retrieves a list of changed files compared to a valid Git reference.
"""Retrieves a list of changed files compared to a valid Git reference.
Returns:
list[str]: A list of changed file paths.
@@ -162,8 +223,7 @@ class GitHandler:
return output.content.splitlines()
def _get_untracked_files(self) -> list[dict[str, str]]:
"""
Retrieves a list of untracked files in the repository. This is useful for detecting new files.
"""Retrieves a list of untracked files in the repository. This is useful for detecting new files.
Returns:
list[dict[str, str]]: A list of dictionaries containing file paths and statuses.
@@ -178,8 +238,7 @@ class GitHandler:
)
def get_git_changes(self) -> list[dict[str, str]] | None:
"""
Retrieves the list of changed files in the Git repository.
"""Retrieves the list of changed files in the Git repository.
Returns:
list[dict[str, str]] | None: A list of dictionaries containing file paths and statuses. None if not a git repository.
@@ -195,8 +254,7 @@ class GitHandler:
return result
def get_git_diff(self, file_path: str) -> dict[str, str]:
"""
Retrieves the original and modified content of a file in the repository.
"""Retrieves the original and modified content of a file in the repository.
Args:
file_path (str): Path to the file.
@@ -214,8 +272,7 @@ class GitHandler:
def parse_git_changes(changes_list: list[str]) -> list[dict[str, str]]:
"""
Parses the list of changed files and extracts their statuses and paths.
"""Parses the list of changed files and extracts their statuses and paths.
Args:
changes_list (list[str]): List of changed file entries.