Skip to content

Commit 5cc8e5c

Browse files
committed
improve default branch detection
1 parent 85d950a commit 5cc8e5c

File tree

4 files changed

+272
-1
lines changed

4 files changed

+272
-1
lines changed

README.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,22 @@ The pre-push hook:
888888
- For existing branches: scans only the new commits since the last push
889889
- Runs the same comprehensive scanning as other Cycode scan modes
890890
891+
#### Smart Default Branch Detection
892+
893+
The pre-push hook intelligently detects the default branch for merge base calculation using this priority order:
894+
895+
1. **Environment Variable**: `CYCODE_DEFAULT_BRANCH` - allows manual override
896+
2. **Git Remote HEAD**: Uses `git symbolic-ref refs/remotes/origin/HEAD` to detect the actual remote default branch
897+
3. **Git Remote Info**: Falls back to `git remote show origin` if symbolic-ref fails
898+
4. **Hardcoded Fallbacks**: Uses common default branch names (origin/main, origin/master, main, master)
899+
900+
**Setting a Custom Default Branch:**
901+
```bash
902+
export CYCODE_DEFAULT_BRANCH=origin/develop
903+
```
904+
905+
This smart detection ensures the pre-push hook works correctly regardless of whether your repository uses `main`, `master`, `develop`, or any other default branch name.
906+
891907
#### Skipping Pre-Push Scans
892908
893909
To skip the pre-push scan for a specific push operation, use:

cycode/cli/consts.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@
245245
DEFAULT_PRE_PUSH_MAX_COMMITS_TO_SCAN_COUNT = 50
246246
PRE_PUSH_COMMAND_TIMEOUT_ENV_VAR_NAME = 'PRE_PUSH_COMMAND_TIMEOUT'
247247
DEFAULT_PRE_PUSH_COMMAND_TIMEOUT_IN_SECONDS = 60
248+
CYCODE_DEFAULT_BRANCH_ENV_VAR_NAME = 'CYCODE_DEFAULT_BRANCH'
248249
# pre push and pre receive common
249250
PRE_RECEIVE_AND_PUSH_REMEDIATION_MESSAGE = """
250251
Cycode Secrets Push Protection

cycode/cli/files_collector/commit_range_documents.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,62 @@ def parse_pre_push_input() -> str:
232232
return pre_push_input.splitlines()[0]
233233

234234

235+
def _get_default_branches_for_merge_base(repo: 'Repo') -> list[str]:
236+
"""Get a list of default branches to try for merge base calculation.
237+
238+
Priority order:
239+
1. Environment variable CYCODE_DEFAULT_BRANCH
240+
2. Git remote HEAD (git symbolic-ref refs/remotes/origin/HEAD)
241+
3. Fallback to common default branch names
242+
243+
Args:
244+
repo: Git repository object
245+
246+
Returns:
247+
List of branch names to try for merge base calculation
248+
"""
249+
default_branches = []
250+
251+
# 1. Check environment variable first
252+
env_default_branch = os.getenv(consts.CYCODE_DEFAULT_BRANCH_ENV_VAR_NAME)
253+
if env_default_branch:
254+
logger.debug('Using default branch from environment variable: %s', env_default_branch)
255+
default_branches.append(env_default_branch)
256+
257+
# 2. Try to get the actual default branch from remote HEAD
258+
try:
259+
remote_head = repo.git.symbolic_ref('refs/remotes/origin/HEAD')
260+
# symbolic-ref returns something like "refs/remotes/origin/main"
261+
if remote_head.startswith('refs/remotes/origin/'):
262+
default_branch = remote_head.replace('refs/remotes/origin/', '')
263+
logger.debug('Found remote default branch: %s', default_branch)
264+
# Add both the remote tracking branch and local branch variants
265+
default_branches.extend([f'origin/{default_branch}', default_branch])
266+
except Exception as e:
267+
logger.debug('Failed to get remote HEAD via symbolic-ref: %s', exc_info=e)
268+
269+
# Try an alternative method: git remote show origin
270+
try:
271+
remote_info = repo.git.remote('show', 'origin')
272+
for line in remote_info.splitlines():
273+
if 'HEAD branch:' in line:
274+
default_branch = line.split('HEAD branch:')[1].strip()
275+
logger.debug('Found default branch via remote show: %s', default_branch)
276+
default_branches.extend([f'origin/{default_branch}', default_branch])
277+
break
278+
except Exception as e2:
279+
logger.debug('Failed to get remote info via remote show: %s', exc_info=e2)
280+
281+
# 3. Add fallback branches (avoiding duplicates)
282+
fallback_branches = ['origin/main', 'origin/master', 'main', 'master']
283+
for branch in fallback_branches:
284+
if branch not in default_branches:
285+
default_branches.append(branch)
286+
287+
logger.debug('Default branches to try: %s', default_branches)
288+
return default_branches
289+
290+
235291
def calculate_pre_push_commit_range(push_update_details: str) -> Optional[str]:
236292
"""Calculate the commit range for pre-push hook scanning.
237293
@@ -240,18 +296,22 @@ def calculate_pre_push_commit_range(push_update_details: str) -> Optional[str]:
240296
241297
Returns:
242298
Commit range string for scanning, or None if no scanning is needed
299+
300+
Environment Variables:
301+
CYCODE_DEFAULT_BRANCH: Override the default branch for merge base calculation
243302
"""
244303
local_ref, local_object_name, remote_ref, remote_object_name = push_update_details.split()
245304

246305
if remote_object_name == consts.EMPTY_COMMIT_SHA:
247306
try:
248307
repo = git_proxy.get_repo(os.getcwd())
249-
default_branches = ['origin/main', 'origin/master', 'main', 'master']
308+
default_branches = _get_default_branches_for_merge_base(repo)
250309

251310
merge_base = None
252311
for default_branch in default_branches:
253312
try:
254313
merge_base = repo.git.merge_base(local_object_name, default_branch)
314+
logger.debug('Found merge base %s with branch %s', merge_base, default_branch)
255315
break
256316
except Exception as e:
257317
logger.debug('Failed to find merge base with %s: %s', default_branch, exc_info=e)

tests/cli/files_collector/test_commit_range_documents.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from cycode.cli import consts
1212
from cycode.cli.files_collector.commit_range_documents import (
13+
_get_default_branches_for_merge_base,
1314
calculate_pre_push_commit_range,
1415
get_diff_file_path,
1516
get_safe_head_reference_for_diff,
@@ -390,6 +391,114 @@ def test_parse_whitespace_only_input_raises_error(self) -> None:
390391
parse_pre_push_input()
391392

392393

394+
class TestGetDefaultBranchesForMergeBase:
395+
"""Test the _get_default_branches_for_merge_base function with various scenarios."""
396+
397+
def test_environment_variable_override(self) -> None:
398+
"""Test that the environment variable takes precedence."""
399+
with (
400+
temporary_git_repository() as (temp_dir, repo),
401+
patch.dict(os.environ, {consts.CYCODE_DEFAULT_BRANCH_ENV_VAR_NAME: 'custom-main'}),
402+
):
403+
branches = _get_default_branches_for_merge_base(repo)
404+
assert branches[0] == 'custom-main'
405+
assert 'origin/main' in branches # Fallbacks should still be included
406+
407+
def test_git_symbolic_ref_success(self) -> None:
408+
"""Test getting default branch via git symbolic-ref."""
409+
with temporary_git_repository() as (temp_dir, repo):
410+
# Create a mock repo with a git interface that returns origin/main
411+
mock_repo = Mock()
412+
mock_repo.git.symbolic_ref.return_value = 'refs/remotes/origin/main'
413+
414+
branches = _get_default_branches_for_merge_base(mock_repo)
415+
assert 'origin/main' in branches
416+
assert 'main' in branches
417+
418+
def test_git_symbolic_ref_with_master(self) -> None:
419+
"""Test getting default branch via git symbolic-ref when it's master."""
420+
with temporary_git_repository() as (temp_dir, repo):
421+
# Create a mock repo with a git interface that returns origin/master
422+
mock_repo = Mock()
423+
mock_repo.git.symbolic_ref.return_value = 'refs/remotes/origin/master'
424+
425+
branches = _get_default_branches_for_merge_base(mock_repo)
426+
assert 'origin/master' in branches
427+
assert 'master' in branches
428+
429+
def test_git_remote_show_fallback(self) -> None:
430+
"""Test fallback to git remote show when symbolic-ref fails."""
431+
with temporary_git_repository() as (temp_dir, repo):
432+
# Create a mock repo where symbolic-ref fails but the remote show succeeds
433+
mock_repo = Mock()
434+
mock_repo.git.symbolic_ref.side_effect = Exception('symbolic-ref failed')
435+
remote_output = """* remote origin
436+
Fetch URL: https://github.com/user/repo.git
437+
Push URL: https://github.com/user/repo.git
438+
HEAD branch: develop
439+
Remote branches:
440+
develop tracked
441+
main tracked"""
442+
mock_repo.git.remote.return_value = remote_output
443+
444+
branches = _get_default_branches_for_merge_base(mock_repo)
445+
assert 'origin/develop' in branches
446+
assert 'develop' in branches
447+
448+
def test_both_git_methods_fail_fallback_to_hardcoded(self) -> None:
449+
"""Test fallback to hardcoded branches when both Git methods fail."""
450+
with temporary_git_repository() as (temp_dir, repo):
451+
# Create a mock repo where both Git methods fail
452+
mock_repo = Mock()
453+
mock_repo.git.symbolic_ref.side_effect = Exception('symbolic-ref failed')
454+
mock_repo.git.remote.side_effect = Exception('remote show failed')
455+
456+
branches = _get_default_branches_for_merge_base(mock_repo)
457+
# Should contain fallback branches
458+
assert 'origin/main' in branches
459+
assert 'origin/master' in branches
460+
assert 'main' in branches
461+
assert 'master' in branches
462+
463+
def test_no_duplicates_in_branch_list(self) -> None:
464+
"""Test that duplicate branches are not added to the list."""
465+
with temporary_git_repository() as (temp_dir, repo):
466+
# Create a mock repo that returns main (which is also in fallback list)
467+
mock_repo = Mock()
468+
mock_repo.git.symbolic_ref.return_value = 'refs/remotes/origin/main'
469+
470+
branches = _get_default_branches_for_merge_base(mock_repo)
471+
# Count occurrences of origin/main - should be exactly 1
472+
assert branches.count('origin/main') == 1
473+
assert branches.count('main') == 1
474+
475+
def test_env_var_plus_git_detection(self) -> None:
476+
"""Test combination of environment variable and git detection."""
477+
with temporary_git_repository() as (temp_dir, repo):
478+
mock_repo = Mock()
479+
mock_repo.git.symbolic_ref.return_value = 'refs/remotes/origin/develop'
480+
481+
with patch.dict(os.environ, {consts.CYCODE_DEFAULT_BRANCH_ENV_VAR_NAME: 'origin/custom'}):
482+
branches = _get_default_branches_for_merge_base(mock_repo)
483+
# Env var should be first
484+
assert branches[0] == 'origin/custom'
485+
# Git detected branches should also be present
486+
assert 'origin/develop' in branches
487+
assert 'develop' in branches
488+
489+
def test_malformed_symbolic_ref_response(self) -> None:
490+
"""Test handling of malformed symbolic-ref response."""
491+
with temporary_git_repository() as (temp_dir, repo):
492+
# Create a mock repo that returns a malformed response
493+
mock_repo = Mock()
494+
mock_repo.git.symbolic_ref.return_value = 'malformed-response'
495+
496+
branches = _get_default_branches_for_merge_base(mock_repo)
497+
# Should fall back to hardcoded branches
498+
assert 'origin/main' in branches
499+
assert 'origin/master' in branches
500+
501+
393502
class TestCalculatePrePushCommitRange:
394503
"""Test the calculate_pre_push_commit_range function with various Git repository scenarios."""
395504

@@ -501,6 +610,91 @@ def test_calculate_range_with_origin_main_as_merge_base(self) -> None:
501610
result = calculate_pre_push_commit_range(push_details)
502611
assert result == f'{main_commit.hexsha}..{feature_commit.hexsha}'
503612

613+
def test_calculate_range_with_environment_variable_override(self) -> None:
614+
"""Test that environment variable override works for commit range calculation."""
615+
with temporary_git_repository() as (temp_dir, repo):
616+
# Create custom default branch
617+
custom_file = os.path.join(temp_dir, 'custom.py')
618+
with open(custom_file, 'w') as f:
619+
f.write("print('custom')")
620+
621+
repo.index.add(['custom.py'])
622+
custom_commit = repo.index.commit('Custom branch commit')
623+
624+
# Create a custom branch
625+
repo.create_head('custom-main', custom_commit)
626+
627+
# Create a feature branch from custom
628+
feature_branch = repo.create_head('feature', custom_commit)
629+
feature_branch.checkout()
630+
631+
# Add feature commits
632+
feature_file = os.path.join(temp_dir, 'feature.py')
633+
with open(feature_file, 'w') as f:
634+
f.write("print('feature')")
635+
636+
repo.index.add(['feature.py'])
637+
feature_commit = repo.index.commit('Feature commit')
638+
639+
# Test new branch push with custom default branch
640+
push_details = f'refs/heads/feature {feature_commit.hexsha} refs/heads/feature {consts.EMPTY_COMMIT_SHA}'
641+
642+
with (
643+
patch('os.getcwd', return_value=temp_dir),
644+
patch.dict(os.environ, {consts.CYCODE_DEFAULT_BRANCH_ENV_VAR_NAME: 'custom-main'}),
645+
):
646+
result = calculate_pre_push_commit_range(push_details)
647+
assert result == f'{custom_commit.hexsha}..{feature_commit.hexsha}'
648+
649+
def test_calculate_range_with_git_symbolic_ref_detection(self) -> None:
650+
"""Test commit range calculation with Git symbolic-ref detection."""
651+
with temporary_git_repository() as (temp_dir, repo):
652+
# Create develop branch and commits
653+
develop_file = os.path.join(temp_dir, 'develop.py')
654+
with open(develop_file, 'w') as f:
655+
f.write("print('develop')")
656+
657+
repo.index.add(['develop.py'])
658+
develop_commit = repo.index.commit('Develop commit')
659+
660+
# Create origin/develop reference
661+
repo.create_head('origin/develop', develop_commit)
662+
repo.create_head('develop', develop_commit)
663+
664+
# Create a feature branch
665+
feature_branch = repo.create_head('feature', develop_commit)
666+
feature_branch.checkout()
667+
668+
# Add feature commits
669+
feature_file = os.path.join(temp_dir, 'feature.py')
670+
with open(feature_file, 'w') as f:
671+
f.write("print('feature')")
672+
673+
repo.index.add(['feature.py'])
674+
feature_commit = repo.index.commit('Feature commit')
675+
676+
# Test a new branch push with mocked default branch detection
677+
push_details = f'refs/heads/feature {feature_commit.hexsha} refs/heads/feature {consts.EMPTY_COMMIT_SHA}'
678+
679+
# # Mock the default branch detection to return origin/develop first
680+
with (
681+
patch('os.getcwd', return_value=temp_dir),
682+
patch(
683+
'cycode.cli.files_collector.commit_range_documents._get_default_branches_for_merge_base'
684+
) as mock_get_branches,
685+
):
686+
mock_get_branches.return_value = [
687+
'origin/develop',
688+
'develop',
689+
'origin/main',
690+
'main',
691+
'origin/master',
692+
'master',
693+
]
694+
with patch('cycode.cli.files_collector.commit_range_documents.git_proxy.get_repo', return_value=repo):
695+
result = calculate_pre_push_commit_range(push_details)
696+
assert result == f'{develop_commit.hexsha}..{feature_commit.hexsha}'
697+
504698
def test_calculate_range_with_origin_master_as_merge_base(self) -> None:
505699
"""Test calculating commit range using origin/master as a merge base."""
506700
with temporary_git_repository() as (temp_dir, repo):

0 commit comments

Comments
 (0)