diff --git a/openhands_resolver/resolve_issue.py b/openhands_resolver/resolve_issue.py index d6d39ba..9429cd7 100644 --- a/openhands_resolver/resolve_issue.py +++ b/openhands_resolver/resolve_issue.py @@ -309,6 +309,7 @@ async def resolve_issue( repo_instruction: str | None, issue_number: int, comment_id: int | None, + branch: str = "main", reset_logger: bool = False, ) -> None: """Resolve a single github issue. @@ -372,6 +373,15 @@ async def resolve_issue( if "fatal" in checkout_output: raise RuntimeError(f"Failed to clone repository: {checkout_output}") + # checkout the specified branch + if branch != "main": + checkout_branch = subprocess.check_output( + ["git", "checkout", branch], + cwd=repo_dir + ).decode("utf-8") + if "fatal" in checkout_branch: + raise RuntimeError(f"Failed to checkout branch {branch}: {checkout_branch}") + # get the commit id of current repo for reproducibility base_commit = ( subprocess.check_output( @@ -543,6 +553,12 @@ def int_or_none(value): choices=["issue", "pr"], help="Type of issue to resolve, either open issue or pr comments.", ) + parser.add_argument( + "--branch", + type=str, + default="main", + help="Branch to checkout and use as base (default: main)", + ) my_args = parser.parse_args() @@ -601,6 +617,7 @@ def int_or_none(value): repo_instruction=repo_instruction, issue_number=my_args.issue_number, comment_id=my_args.comment_id, + branch=my_args.branch, ) ) diff --git a/openhands_resolver/send_pull_request.py b/openhands_resolver/send_pull_request.py index 206b303..89c390c 100644 --- a/openhands_resolver/send_pull_request.py +++ b/openhands_resolver/send_pull_request.py @@ -201,6 +201,7 @@ def send_pull_request( pr_type: str, fork_owner: str | None = None, additional_message: str | None = None, + target_branch: str | None = None, ) -> str: if pr_type not in ["branch", "draft", "ready"]: raise ValueError(f"Invalid pr_type: {pr_type}") @@ -212,6 +213,14 @@ def send_pull_request( } base_url = f"https://api.github.com/repos/{github_issue.owner}/{github_issue.repo}" + # Use target_branch if specified, otherwise get the default branch + if not target_branch: + print("Getting default branch...") + response = requests.get(f"{base_url}", headers=headers) + response.raise_for_status() + target_branch = response.json()["default_branch"] + print(f"Target branch: {target_branch}") + # Create a new branch with a unique name base_branch_name = f"openhands-fix-issue-{github_issue.number}" branch_name = base_branch_name @@ -222,13 +231,6 @@ def send_pull_request( attempt += 1 branch_name = f"{base_branch_name}-try{attempt}" - # Get the default branch - print("Getting default branch...") - response = requests.get(f"{base_url}", headers=headers) - response.raise_for_status() - default_branch = response.json()["default_branch"] - print(f"Default branch: {default_branch}") - # Create and checkout the new branch print("Creating new branch...") result = subprocess.run( @@ -268,13 +270,13 @@ def send_pull_request( # If we are not sending a PR, we can finish early and return the # URL for the user to open a PR manually if pr_type == "branch": - url = f"https://github.com/{push_owner}/{github_issue.repo}/compare/{branch_name}?expand=1" + url = f"https://github.com/{push_owner}/{github_issue.repo}/compare/{target_branch}...{branch_name}?expand=1" else: data = { "title": pr_title, # No need to escape title for GitHub API "body": pr_body, "head": branch_name, - "base": default_branch, + "base": target_branch, "draft": pr_type == "draft", } response = requests.post(f"{base_url}/pulls", headers=headers, json=data) @@ -421,6 +423,7 @@ def process_single_issue( llm_config: LLMConfig, fork_owner: str | None, send_on_failure: bool, + target_branch: str | None = None, ) -> None: if not resolver_output.success and not send_on_failure: print( @@ -473,6 +476,7 @@ def process_single_issue( llm_config=llm_config, fork_owner=fork_owner, additional_message=resolver_output.success_explanation, + target_branch=target_branch, ) @@ -497,6 +501,7 @@ def process_all_successful_issues( llm_config, fork_owner, False, + None, ) @@ -562,6 +567,12 @@ def main(): default=None, help="Base URL for the LLM model.", ) + parser.add_argument( + "--target-branch", + type=str, + default=None, + help="Target branch to create the pull request against (if not specified, uses repo's default branch)", + ) my_args = parser.parse_args() github_token = ( @@ -610,6 +621,7 @@ def main(): llm_config, my_args.fork_owner, my_args.send_on_failure, + my_args.target_branch, ) if __name__ == "__main__":