diff --git a/.gitignore b/.gitignore index d652e03..d8e4034 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ dist/ rockset_stacky.egg-info/ src/rockset_stacky.egg-info/ build/ -src/stacky/__pycache__ +__pycache__/ bazel-* .mypy_cache diff --git a/BUILD.bazel b/BUILD.bazel index 7a1e4aa..c85d8b9 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -16,6 +16,7 @@ py_binary( requirement("ansicolors"), requirement("simple-term-menu"), requirement("asciitree"), + requirement("argcomplete"), ] ) diff --git a/README.md b/README.md index 65bd579..8e68e53 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,8 @@ There is also a [xar](https://github.com/facebookincubator/xar/) version it shou ### Pip ``` -pip3 install rockset-stacky +1. Clone this repository +2. From this repository root run `pip install -e .` ``` ### Manual @@ -20,12 +21,47 @@ pip3 install rockset-stacky 1. asciitree 2. ansicolors 3. simple-term-menu +4. argcomplete (for tab completion) ``` -pip3 install asciitree ansicolors simple-term-menu +pip3 install asciitree ansicolors simple-term-menu argcomplete ``` After which `stacky` can be directly run with `./src/stacky/stacky.py`. We would recommend symlinking `stacky.py` into your path so you can use it anywhere +## Tab Completion + +Stacky supports tab completion for branch names in bash and zsh. To enable it: + +### One-time setup +```bash +# Install argcomplete +pip3 install argcomplete + +# Enable global completion (recommended) +activate-global-python-argcomplete +``` + +### Per-session setup (alternative) +If you prefer not to use global completion, you can enable it per session: +```bash +# For bash/zsh +eval "$(register-python-argcomplete stacky)" +``` + +### Permanent setup (alternative) +Add the completion to your shell config: +```bash +# For bash - add to ~/.bashrc +eval "$(register-python-argcomplete stacky)" + +# For zsh - add to ~/.zshrc +eval "$(register-python-argcomplete stacky)" +``` + +After setup, you can use tab completion with commands like: +- `stacky checkout ` - completes branch names +- `stacky adopt ` - completes branch names +- `stacky branch checkout ` - completes branch names ## Accessing Github Stacky doesn't use any git or Github APIs. It expects `git` and `gh` cli commands to work and be properly configured. For instructions on installing the github cli `gh` please read their [documentation](https://cli.github.com/manual/). @@ -33,14 +69,18 @@ Stacky doesn't use any git or Github APIs. It expects `git` and `gh` cli command ## Usage `stacky` stores all information locally, within your git repository Syntax is as follows: -- `stacky info`: show all stacks , add `-pr` if you want to see GitHub PR numbers (slows things down a bit) +- `stacky info`: show all stacks , add `-pr` if you want to see GitHub PR numbers (slows things down a bit) +- `stacky inbox [--compact]`: show all active GitHub pull requests for the current user, organized by status (waiting on you, waiting on review, approved, and PRs awaiting your review). Use `--compact` or `-c` for a condensed one-line-per-PR view with clickable PR numbers. +- `stacky prs`: interactive PR management tool that allows you to select and edit PR descriptions. Shows a simple menu of all your open PRs and PRs awaiting your review, then opens your preferred editor (from `$EDITOR` environment variable) to modify the selected PR's description. - `stacky branch`: per branch commands (shortcut: `stacky b`) - `stacky branch up` (`stacky b u`): move down the stack (towards `master`) - - `stacky branch down` (`stacky b d`): move down the stack (towards `master`) + - `stacky branch down` (`stacky b d`): move down the stack (towards `master`) - `stacky branch new `: create a new branch on top of the current one -- `stacky commit [-m ] [--amend] [--allow-empty]`: wrapper around `git commit` that syncs everything upstack + - `stacky branch commit [-m ] [-a]`: create a new branch and commit changes in one command +- `stacky commit [-m ] [--amend] [--allow-empty] [-a]`: wrapper around `git commit` that syncs everything upstack - `stacky amend`: will amend currently tracked changes to top commit -- Based on the first argument (`stack` vs `upstack` vs `downstack`), the following commands operate on the entire current stack, everything upstack from the current PR (inclusive), or everything downstack from the current PR: +- `stacky fold [--allow-empty]`: fold current branch into its parent branch and delete the current branch. Any children of the current branch become children of the parent branch. Uses cherry-pick by default, or merge if `use_merge` is enabled in config. Use `--allow-empty` to allow empty commits during cherry-pick. +- Based on the first argument (`stack` vs `upstack` vs `downstack`), the following commands operate on the entire current stack, everything upstack from the current PR (inclusive), or everything downstack from the current PR: - `stacky stack info [--pr]` - `stacky stack sync`: sync (rebase) branches in the stack on top of their parents - `stacky stack push [--no-pr]`: push to origin, optionally not creating PRs if they don’t exist @@ -56,12 +96,12 @@ The indicators (`*`, `~`, `!`) mean: ``` $ stacky --help usage: stacky [-h] [--color {always,auto,never}] - {continue,info,commit,amend,branch,b,stack,s,upstack,us,downstack,ds,update,import,adopt,land,push,sync,checkout,co,sco} ... + {continue,info,commit,amend,branch,b,stack,s,upstack,us,downstack,ds,update,import,adopt,land,push,sync,checkout,co,sco,inbox,prs,fold} ... Handle git stacks positional arguments: - {continue,info,commit,amend,branch,b,stack,s,upstack,us,downstack,ds,update,import,adopt,land,push,sync,checkout,co,sco} + {continue,info,commit,amend,branch,b,stack,s,upstack,us,downstack,ds,update,import,adopt,land,push,sync,checkout,co,sco,inbox,prs,fold} continue Continue previously interrupted command info Stack info commit Commit @@ -71,12 +111,16 @@ positional arguments: upstack (us) Operations on the current upstack downstack (ds) Operations on the current downstack update Update repo + import Import Graphite stack adopt Adopt one branch land Land bottom-most PR on current stack push Alias for downstack push sync Alias for stack sync checkout (co) Checkout a branch sco Checkout a branch in this stack + inbox List all active GitHub pull requests for the current user + prs Interactive PR management - select and edit PR descriptions + fold Fold current branch into parent branch and delete current branch optional arguments: -h, --help show this help message and exit @@ -166,6 +210,7 @@ In the file you have sections and each sections define some parameters. We currently have the following sections: * UI + * GIT List of parameters for each sections: @@ -174,6 +219,44 @@ List of parameters for each sections: * change_to_main: boolean with a default value of `False`, by default `stacky` will stop doing action is you are not in a valid stack (ie. a branch that was created or adopted by stacky), when set to `True` `stacky` will first change to `main` or `master` *when* the current branch is not a valid stack. * change_to_adopted: boolean with a default value of `False`, when set to `True` `stacky` will change the current branch to the adopted one. * share_ssh_session: boolean with a default value of `False`, when set to `True` `stacky` will create a shared `ssh` session to the `github.com` server. This is useful when you are pushing a stack of diff and you have some kind of 2FA on your ssh key like the ed25519-sk. + * compact_pr_display: boolean with a default value of `False`, when set to `True` `stacky info --pr` will show a compact format displaying only the PR number and status emoji (✅ approved, ❌ changes requested, 🔄 waiting for review, 🚧 draft) without the PR title. Both compact and full formats include clickable links to the PRs. + * enable_stack_comment: boolean with a default value of `True`, when set to `False` `stacky` will not post stack comments to GitHub PRs showing the entire stack structure. Disable this if you don't want automated stack comments in your PR descriptions. + +### GIT + * use_merge: boolean with a default value of `False`, when set to `True` `stacky` will use `git merge` instead of `git rebase` for sync operations and `stacky fold` will merge the child branch into the parent instead of cherry-picking individual commits. + * use_force_push: boolean with a default value of `True`, controls whether `stacky` can use force push when pushing branches. + +### Example Configuration + +Here's a complete example of a `.stackyconfig` file with all available options: + +```ini +[UI] +# Skip confirmation prompts (useful for automation) +skip_confirm = False + +# Automatically change to main/master when not in a valid stack +change_to_main = False + +# Change to the adopted branch after running 'stacky adopt' +change_to_adopted = False + +# Create shared SSH session for multiple operations (helpful with 2FA) +share_ssh_session = False + +# Show compact format for 'stacky info --pr' (just number and emoji) +compact_pr_display = False + +# Enable posting stack comments to GitHub PRs +enable_stack_comment = True + +[GIT] +# Use git merge instead of rebase for sync operations +use_merge = False + +# Allow force push when pushing branches +use_force_push = True +``` ## License diff --git a/setup.py b/setup.py index 675b840..644eccc 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ package_dir={"": "src"}, packages=find_packages(where="src"), python_requires=">=3.8, <4", - install_requires=["asciitree", "ansicolors", "simple-term-menu"], + install_requires=["asciitree", "ansicolors", "simple-term-menu", "argcomplete"], entry_points={ "console_scripts": [ "stacky=stacky:main", diff --git a/src/stacky/__init__.py b/src/stacky/__init__.py index 1ba03f4..cce717e 100644 --- a/src/stacky/__init__.py +++ b/src/stacky/__init__.py @@ -1,4 +1,37 @@ -from .stacky import main +"""Stacky - GitHub helper for stacked diffs.""" + +from .main import main + +# Re-exports for backward compatibility with tests +from .utils.shell import _check_returncode, run, run_always_return, run_multiline +from .utils.logging import ( + die, cout, debug, info, warning, error, fmt, + COLOR_STDOUT, COLOR_STDERR, ExitException +) +from .utils.types import BranchName, Commit, CmdArgs, STACK_BOTTOMS +from .utils.config import StackyConfig, get_config, read_config +from .utils.ui import confirm, prompt + +from .git.branch import ( + get_current_branch, get_all_branches, get_top_level_dir, + get_stack_parent_branch, checkout, create_branch +) +from .git.remote import ( + get_remote_info, get_remote_type, gen_ssh_mux_cmd, + start_muxed_ssh, stop_muxed_ssh +) +from .git.refs import get_stack_parent_commit, set_parent_commit, get_commit + +from .stack.models import PRInfo, PRInfos, StackBranch, StackBranchSet +from .stack.tree import ( + get_all_stacks_as_forest, get_current_stack_as_forest, + get_current_downstack_as_forest, get_current_upstack_as_forest, + print_tree, print_forest, format_tree +) + +from .pr.github import find_issue_marker, get_pr_info, create_gh_pr + +from .commands.land import cmd_land def runner(): diff --git a/src/stacky/commands/__init__.py b/src/stacky/commands/__init__.py new file mode 100644 index 0000000..e630d94 --- /dev/null +++ b/src/stacky/commands/__init__.py @@ -0,0 +1 @@ +# Commands module - command handlers for stacky diff --git a/src/stacky/commands/branch.py b/src/stacky/commands/branch.py new file mode 100644 index 0000000..99418df --- /dev/null +++ b/src/stacky/commands/branch.py @@ -0,0 +1,57 @@ +"""Branch commands - new, commit, checkout.""" + +from stacky.commands.commit import do_commit +from stacky.git.branch import checkout, create_branch, get_current_branch_name, set_current_branch +from stacky.git.refs import get_commit +from stacky.stack.models import StackBranchSet +from stacky.stack.operations import load_stack_for_given_branch +from stacky.stack.tree import get_all_stacks_as_forest +from stacky.utils.shell import run +from stacky.utils.types import BranchName, CmdArgs +from stacky.utils.ui import menu_choose_branch + + +def cmd_branch_new(stack: StackBranchSet, args): + """Create a new branch on top of the current branch.""" + current_branch = get_current_branch_name() + b = stack.stack[current_branch] + assert b.commit + name = args.name + create_branch(name) + run(CmdArgs(["git", "update-ref", "refs/stack-parent/{}".format(name), b.commit, ""])) + + +def cmd_branch_commit(stack: StackBranchSet, args): + """Create a new branch and commit all changes with the provided message.""" + current_branch = get_current_branch_name() + b = stack.stack[current_branch] + assert b.commit + name = args.name + create_branch(name) + run(CmdArgs(["git", "update-ref", "refs/stack-parent/{}".format(name), b.commit, ""])) + + # Update global CURRENT_BRANCH since we just checked out the new branch + set_current_branch(BranchName(name)) + + # Reload the stack to include the new branch + load_stack_for_given_branch(stack, BranchName(name)) + + # Now commit all changes with the provided message + do_commit( + stack, + message=args.message, + amend=False, + allow_empty=False, + edit=True, + add_all=args.add_all, + no_verify=args.no_verify, + ) + + +def cmd_branch_checkout(stack: StackBranchSet, args): + """Checkout a branch (with menu if no name provided).""" + branch_name = args.name + if branch_name is None: + forest = get_all_stacks_as_forest(stack) + branch_name = menu_choose_branch(forest).name + checkout(branch_name) diff --git a/src/stacky/commands/commit.py b/src/stacky/commands/commit.py new file mode 100644 index 0000000..7a25eee --- /dev/null +++ b/src/stacky/commands/commit.py @@ -0,0 +1,70 @@ +"""Commit commands - commit, amend.""" + +from stacky.git.branch import get_current_branch_name +from stacky.git.refs import get_commit +from stacky.stack.models import StackBranchSet +from stacky.stack.operations import do_sync +from stacky.stack.tree import get_current_upstack_as_forest +from stacky.utils.config import get_config +from stacky.utils.logging import die +from stacky.utils.shell import run +from stacky.utils.types import CmdArgs + + +def do_commit(stack: StackBranchSet, *, message=None, amend=False, allow_empty=False, + edit=True, add_all=False, no_verify=False): + """Perform a commit operation.""" + current_branch = get_current_branch_name() + b = stack.stack[current_branch] + if not b.parent: + die("Do not commit directly on {}", b.name) + if not b.is_synced_with_parent(): + die( + "Branch {} is not synced with parent {}, sync before committing", + b.name, b.parent.name, + ) + + if amend and (get_config().use_merge or not get_config().use_force_push): + die("Amending is not allowed if using git merge or if force pushing is disallowed") + + if amend and b.commit == b.parent.commit: + die("Branch {} has no commits, may not amend", b.name) + + cmd = ["git", "commit"] + if add_all: + cmd += ["-a"] + if allow_empty: + cmd += ["--allow-empty"] + if no_verify: + cmd += ["--no-verify"] + if amend: + cmd += ["--amend"] + if not edit: + cmd += ["--no-edit"] + elif not edit: + die("--no-edit is only supported with --amend") + if message: + cmd += ["-m", message] + run(CmdArgs(cmd), out=True) + + # Sync everything upstack + b.commit = get_commit(b.name) + do_sync(get_current_upstack_as_forest(stack)) + + +def cmd_commit(stack: StackBranchSet, args): + """Commit command handler.""" + do_commit( + stack, + message=args.message, + amend=args.amend, + allow_empty=args.allow_empty, + edit=not args.no_edit, + add_all=args.add_all, + no_verify=args.no_verify, + ) + + +def cmd_amend(stack: StackBranchSet, args): + """Amend last commit (shortcut).""" + do_commit(stack, amend=True, edit=False, no_verify=args.no_verify) diff --git a/src/stacky/commands/downstack.py b/src/stacky/commands/downstack.py new file mode 100644 index 0000000..dd1d5f4 --- /dev/null +++ b/src/stacky/commands/downstack.py @@ -0,0 +1,30 @@ +"""Downstack commands - info, push, sync.""" + +from stacky.stack.models import StackBranchSet +from stacky.stack.operations import do_push, do_sync +from stacky.stack.tree import ( + get_current_downstack_as_forest, load_pr_info_for_forest, print_forest +) + + +def cmd_downstack_info(stack: StackBranchSet, args): + """Show info for current downstack.""" + forest = get_current_downstack_as_forest(stack) + if args.pr: + load_pr_info_for_forest(forest) + print_forest(forest) + + +def cmd_downstack_push(stack: StackBranchSet, args): + """Push current downstack.""" + do_push( + get_current_downstack_as_forest(stack), + force=args.force, + pr=args.pr, + remote_name=args.remote_name, + ) + + +def cmd_downstack_sync(stack: StackBranchSet, args): + """Sync current downstack.""" + do_sync(get_current_downstack_as_forest(stack)) diff --git a/src/stacky/commands/fold.py b/src/stacky/commands/fold.py new file mode 100644 index 0000000..e5b93dc --- /dev/null +++ b/src/stacky/commands/fold.py @@ -0,0 +1,199 @@ +"""Fold command - fold branch into parent.""" + +import json +import os +from typing import List + +from stacky.git.branch import checkout, get_current_branch_name, set_current_branch +from stacky.git.refs import get_commit, get_commits_between, set_parent, set_parent_commit +from stacky.stack.models import StackBranch, StackBranchSet +from stacky.utils.config import get_config +from stacky.utils.logging import cout, die, info +from stacky.utils.shell import run +from stacky.utils.types import BranchName, CmdArgs, STACK_BOTTOMS, STATE_FILE, TMP_STATE_FILE + + +def cmd_fold(stack: StackBranchSet, args): + """Fold current branch into parent branch and delete current branch.""" + current_branch = get_current_branch_name() + + if current_branch not in stack.stack: + die("Current branch {} is not in a stack", current_branch) + + b = stack.stack[current_branch] + + if not b.parent: + die("Cannot fold stack bottom branch {}", current_branch) + + if b.parent.name in STACK_BOTTOMS: + die("Cannot fold into stack bottom branch {}", b.parent.name) + + if not b.is_synced_with_parent(): + die( + "Branch {} is not synced with parent {}, sync before folding", + b.name, b.parent.name, + ) + + commits_to_apply = get_commits_between(b.parent_commit, b.commit) + if not commits_to_apply: + info("No commits to fold from {} into {}", b.name, b.parent.name) + else: + cout("Folding {} commits from {} into {}\n", len(commits_to_apply), b.name, b.parent.name, fg="green") + + children = list(b.children) + if children: + cout("Reparenting {} children to {}\n", len(children), b.parent.name, fg="yellow") + for child in children: + cout(" {} -> {}\n", child.name, b.parent.name, fg="gray") + + checkout(b.parent.name) + set_current_branch(b.parent.name) + + if get_config().use_merge: + inner_do_merge_fold(stack, b.name, b.parent.name, [child.name for child in children]) + else: + if commits_to_apply: + commits_to_apply = list(reversed(commits_to_apply)) + inner_do_fold(stack, b.name, b.parent.name, commits_to_apply, [child.name for child in children], args.allow_empty) + else: + finish_fold_operation(stack, b.name, b.parent.name, [child.name for child in children]) + + +def inner_do_merge_fold(stack: StackBranchSet, fold_branch_name: BranchName, parent_branch_name: BranchName, + children_names: List[BranchName]): + """Perform merge-based fold operation.""" + print() + current_branch = get_current_branch_name() + + with open(TMP_STATE_FILE, "w") as f: + json.dump({ + "branch": current_branch, + "merge_fold": { + "fold_branch": fold_branch_name, + "parent_branch": parent_branch_name, + "children": children_names, + } + }, f) + os.replace(TMP_STATE_FILE, STATE_FILE) + + cout("Merging {} into {}\n", fold_branch_name, parent_branch_name, fg="green") + result = run(CmdArgs(["git", "merge", fold_branch_name]), check=False) + if result is None: + die("Merge failed for branch {}. Please resolve conflicts and run `stacky continue`", fold_branch_name) + + finish_merge_fold_operation(stack, fold_branch_name, parent_branch_name, children_names) + + +def finish_merge_fold_operation(stack: StackBranchSet, fold_branch_name: BranchName, + parent_branch_name: BranchName, children_names: List[BranchName]): + """Complete merge-based fold operation.""" + fold_branch = stack.stack.get(fold_branch_name) + parent_branch = stack.stack[parent_branch_name] + + if not fold_branch: + cout("✓ Merge fold operation completed\n", fg="green") + return + + parent_branch.commit = get_commit(parent_branch_name) + + for child_name in children_names: + if child_name in stack.stack: + child = stack.stack[child_name] + info("Reparenting {} from {} to {}", child.name, fold_branch.name, parent_branch.name) + child.parent = parent_branch + parent_branch.children.add(child) + fold_branch.children.discard(child) + set_parent(child.name, parent_branch.name) + set_parent_commit(child.name, parent_branch.commit, child.parent_commit) + child.parent_commit = parent_branch.commit + + parent_branch.children.discard(fold_branch) + + info("Deleting branch {}", fold_branch.name) + run(CmdArgs(["git", "branch", "-D", fold_branch.name])) + run(CmdArgs(["git", "update-ref", "-d", "refs/stack-parent/{}".format(fold_branch.name)])) + stack.remove(fold_branch.name) + + cout("✓ Successfully merged and folded {} into {}\n", fold_branch.name, parent_branch.name, fg="green") + + +def inner_do_fold(stack: StackBranchSet, fold_branch_name: BranchName, parent_branch_name: BranchName, + commits_to_apply: List[str], children_names: List[BranchName], allow_empty: bool): + """Cherry-pick based fold operation.""" + print() + current_branch = get_current_branch_name() + + if not commits_to_apply: + finish_fold_operation(stack, fold_branch_name, parent_branch_name, children_names) + return + + while commits_to_apply: + with open(TMP_STATE_FILE, "w") as f: + json.dump({ + "branch": current_branch, + "fold": { + "fold_branch": fold_branch_name, + "parent_branch": parent_branch_name, + "commits": commits_to_apply, + "children": children_names, + "allow_empty": allow_empty + } + }, f) + os.replace(TMP_STATE_FILE, STATE_FILE) + + commit = commits_to_apply.pop() + + # Check if commit would be empty + dry_run_result = run(CmdArgs(["git", "cherry-pick", "--no-commit", commit]), check=False) + if dry_run_result is not None: + has_changes = run(CmdArgs(["git", "diff", "--cached", "--quiet"]), check=False) is None + run(CmdArgs(["git", "reset", "--hard", "HEAD"])) + if not has_changes: + cout("Skipping empty commit {}\n", commit[:8], fg="yellow") + continue + else: + run(CmdArgs(["git", "reset", "--hard", "HEAD"]), check=False) + + cout("Cherry-picking commit {}\n", commit[:8], fg="green") + cherry_pick_cmd = ["git", "cherry-pick"] + if allow_empty: + cherry_pick_cmd.append("--allow-empty") + cherry_pick_cmd.append(commit) + result = run(CmdArgs(cherry_pick_cmd), check=False) + if result is None: + die("Cherry-pick failed for commit {}. Please resolve conflicts and run `stacky continue`", commit) + + finish_fold_operation(stack, fold_branch_name, parent_branch_name, children_names) + + +def finish_fold_operation(stack: StackBranchSet, fold_branch_name: BranchName, + parent_branch_name: BranchName, children_names: List[BranchName]): + """Complete fold operation after commits applied.""" + fold_branch = stack.stack.get(fold_branch_name) + parent_branch = stack.stack[parent_branch_name] + + if not fold_branch: + cout("✓ Fold operation completed\n", fg="green") + return + + parent_branch.commit = get_commit(parent_branch_name) + + for child_name in children_names: + if child_name in stack.stack: + child = stack.stack[child_name] + info("Reparenting {} from {} to {}", child.name, fold_branch.name, parent_branch.name) + child.parent = parent_branch + parent_branch.children.add(child) + fold_branch.children.discard(child) + set_parent(child.name, parent_branch.name) + set_parent_commit(child.name, parent_branch.commit, child.parent_commit) + child.parent_commit = parent_branch.commit + + parent_branch.children.discard(fold_branch) + + info("Deleting branch {}", fold_branch.name) + run(CmdArgs(["git", "branch", "-D", fold_branch.name])) + run(CmdArgs(["git", "update-ref", "-d", "refs/stack-parent/{}".format(fold_branch.name)])) + stack.remove(fold_branch.name) + + cout("✓ Successfully folded {} into {}\n", fold_branch.name, parent_branch.name, fg="green") diff --git a/src/stacky/commands/inbox.py b/src/stacky/commands/inbox.py new file mode 100644 index 0000000..59e1b50 --- /dev/null +++ b/src/stacky/commands/inbox.py @@ -0,0 +1,171 @@ +"""Inbox commands - inbox, prs.""" + +import json + +from simple_term_menu import TerminalMenu # type: ignore + +from stacky.pr.github import edit_pr_description +from stacky.stack.models import StackBranchSet +from stacky.utils.logging import IS_TERMINAL, cout, die +from stacky.utils.shell import run_always_return +from stacky.utils.types import CmdArgs + + +def cmd_inbox(stack: StackBranchSet, args): + """List all active GitHub pull requests for the current user.""" + fields = [ + "number", "title", "headRefName", "baseRefName", "state", "url", + "createdAt", "updatedAt", "author", "reviewDecision", "reviewRequests", + "mergeable", "mergeStateStatus", "statusCheckRollup", "isDraft", "body" + ] + + my_prs_data = json.loads( + run_always_return(CmdArgs([ + "gh", "pr", "list", "--json", ",".join(fields), + "--state", "open", "--author", "@me" + ])) + ) + + review_prs_data = json.loads( + run_always_return(CmdArgs([ + "gh", "pr", "list", "--json", ",".join(fields), + "--state", "open", "--search", "review-requested:@me" + ])) + ) + + # Categorize PRs + waiting_on_me = [] + waiting_on_review = [] + approved = [] + + for pr in my_prs_data: + if pr.get("isDraft", False): + waiting_on_me.append(pr) + elif pr["reviewDecision"] == "APPROVED": + approved.append(pr) + elif pr["reviewRequests"] and len(pr["reviewRequests"]) > 0: + waiting_on_review.append(pr) + else: + waiting_on_me.append(pr) + + # Sort by updatedAt + for lst in [waiting_on_me, waiting_on_review, approved, review_prs_data]: + lst.sort(key=lambda pr: pr["updatedAt"], reverse=True) + + def get_check_status(pr): + if not pr.get("statusCheckRollup") or len(pr.get("statusCheckRollup")) == 0: + return "", "gray" + rollup = pr["statusCheckRollup"] + states = [check["state"] for check in rollup if isinstance(check, dict) and "state" in check] + if not states: + return "", "gray" + if "FAILURE" in states or "ERROR" in states: + return "✗ Checks failed", "red" + elif "PENDING" in states or "QUEUED" in states: + return "⏳ Checks running", "yellow" + elif all(state == "SUCCESS" for state in states): + return "✓ Checks passed", "green" + return "Checks mixed", "yellow" + + def display_pr_compact(pr, show_author=False): + check_text, check_color = get_check_status(pr) + pr_number_text = f"#{pr['number']}" + clickable_number = f"\033]8;;{pr['url']}\033\\\033[96m{pr_number_text}\033[0m\033]8;;\033\\" + cout("{} ", clickable_number) + cout("{} ", pr["title"], fg="white") + cout("({}) ", pr["headRefName"], fg="gray") + if show_author: + cout("by {} ", pr["author"]["login"], fg="gray") + if pr.get("isDraft", False): + cout("[DRAFT] ", fg="orange") + if check_text: + cout("{} ", check_text, fg=check_color) + cout("Updated: {}\n", pr["updatedAt"][:10], fg="gray") + + def display_pr_full(pr, show_author=False): + check_text, check_color = get_check_status(pr) + pr_number_text = f"#{pr['number']}" + clickable_number = f"\033]8;;{pr['url']}\033\\\033[96m{pr_number_text}\033[0m\033]8;;\033\\" + cout("{} ", clickable_number) + cout("{}\n", pr["title"], fg="white") + cout(" {} -> {}\n", pr["headRefName"], pr["baseRefName"], fg="gray") + if show_author: + cout(" Author: {}\n", pr["author"]["login"], fg="gray") + if pr.get("isDraft", False): + cout(" [DRAFT]\n", fg="orange") + if check_text: + cout(" {}\n", check_text, fg=check_color) + cout(" {}\n", pr["url"], fg="blue") + cout(" Updated: {}, Created: {}\n\n", pr["updatedAt"][:10], pr["createdAt"][:10], fg="gray") + + def display_pr_list(prs, show_author=False): + for pr in prs: + if args.compact: + display_pr_compact(pr, show_author) + else: + display_pr_full(pr, show_author) + + if waiting_on_me: + cout("Your PRs - Waiting on You:\n", fg="red") + display_pr_list(waiting_on_me) + cout("\n") + if waiting_on_review: + cout("Your PRs - Waiting on Review:\n", fg="yellow") + display_pr_list(waiting_on_review) + cout("\n") + if approved: + cout("Your PRs - Approved:\n", fg="green") + display_pr_list(approved) + cout("\n") + if not my_prs_data: + cout("No active pull requests authored by you.\n", fg="green") + if review_prs_data: + cout("Pull Requests Awaiting Your Review:\n", fg="yellow") + display_pr_list(review_prs_data, show_author=True) + else: + cout("No pull requests awaiting your review.\n", fg="yellow") + + +def cmd_prs(stack: StackBranchSet, args): + """Interactive PR management - select and edit PR descriptions.""" + fields = [ + "number", "title", "headRefName", "baseRefName", "state", "url", + "createdAt", "updatedAt", "author", "reviewDecision", "reviewRequests", + "mergeable", "mergeStateStatus", "statusCheckRollup", "isDraft", "body" + ] + + my_prs_data = json.loads( + run_always_return(CmdArgs([ + "gh", "pr", "list", "--json", ",".join(fields), + "--state", "open", "--author", "@me" + ])) + ) + + review_prs_data = json.loads( + run_always_return(CmdArgs([ + "gh", "pr", "list", "--json", ",".join(fields), + "--state", "open", "--search", "review-requested:@me" + ])) + ) + + all_prs = my_prs_data + review_prs_data + if not all_prs: + cout("No active pull requests found.\n", fg="green") + return + + if not IS_TERMINAL: + die("Interactive PR management requires a terminal") + + menu_options = [f"#{pr['number']} {pr['title']}" for pr in all_prs] + menu_options.append("Exit") + + while True: + cout("\nSelect a PR to edit its description:\n", fg="cyan") + menu = TerminalMenu(menu_options, cursor_index=0) + idx = menu.show() + + if idx is None or idx == len(menu_options) - 1: + break + + selected_pr = all_prs[idx] + edit_pr_description(selected_pr) diff --git a/src/stacky/commands/land.py b/src/stacky/commands/land.py new file mode 100644 index 0000000..bf86d81 --- /dev/null +++ b/src/stacky/commands/land.py @@ -0,0 +1,77 @@ +"""Land command - land a PR.""" + +import sys + +from stacky.git.branch import get_current_branch_name +from stacky.stack.models import StackBranchSet +from stacky.stack.tree import get_current_downstack_as_forest +from stacky.utils.logging import COLOR_STDOUT, cout, die, fmt +from stacky.utils.shell import run +from stacky.utils.types import CmdArgs, Commit +from stacky.utils.ui import confirm + + +def cmd_land(stack: StackBranchSet, args): + """Land bottom-most PR on current stack.""" + current_branch = get_current_branch_name() + forest = get_current_downstack_as_forest(stack) + assert len(forest) == 1 + branches = [] + p = forest[0] + while p: + assert len(p) == 1 + _, (b, p) = next(iter(p.items())) + branches.append(b) + assert branches + assert branches[0] in stack.bottoms + if len(branches) == 1: + die("May not land {}", branches[0].name) + + b = branches[1] + if not b.is_synced_with_parent(): + die( + "Branch {} is not synced with parent {}, sync before landing", + b.name, b.parent.name, + ) + if not b.is_synced_with_remote(): + die( + "Branch {} is not synced with remote branch, push local changes before landing", + b.name, + ) + + b.load_pr_info() + pr = b.open_pr_info + if not pr: + die("Branch {} does not have an open PR", b.name) + assert pr is not None + + if pr["mergeable"] != "MERGEABLE": + die( + "PR #{} for branch {} is not mergeable: {}", + pr["number"], b.name, pr["mergeable"], + ) + + if len(branches) > 2: + cout( + "The `land` command only lands the bottom-most branch {}; " + "the current stack has {} branches, ending with {}\n", + b.name, len(branches) - 1, current_branch, fg="yellow", + ) + + msg = fmt("- Will land PR #{} (", pr["number"], color=COLOR_STDOUT) + msg += fmt("{}", pr["url"], color=COLOR_STDOUT, fg="blue") + msg += fmt(") for branch {}", b.name, color=COLOR_STDOUT) + msg += fmt(" into branch {}\n", b.parent.name, color=COLOR_STDOUT) + sys.stdout.write(msg) + + if not args.force: + confirm() + + v = run(CmdArgs(["git", "rev-parse", b.name])) + assert v is not None + head_commit = Commit(v) + cmd = CmdArgs(["gh", "pr", "merge", b.name, "--squash", "--match-head-commit", head_commit]) + if args.auto: + cmd.append("--auto") + run(cmd, out=True) + cout("\n✓ Success! Run `stacky update` to update local state.\n", fg="green") diff --git a/src/stacky/commands/navigation.py b/src/stacky/commands/navigation.py new file mode 100644 index 0000000..29b5a3e --- /dev/null +++ b/src/stacky/commands/navigation.py @@ -0,0 +1,66 @@ +"""Navigation commands - info, log, up, down.""" + +from stacky.git.branch import checkout, get_current_branch_name +from stacky.stack.models import StackBranchSet +from stacky.stack.tree import ( + get_all_stacks_as_forest, load_pr_info_for_forest, print_forest +) +from stacky.utils.config import get_config +from stacky.utils.logging import IS_TERMINAL, cout, die, info +from stacky.utils.shell import run +from stacky.utils.types import BranchesTreeForest, BranchName, BranchesTree +from stacky.utils.ui import menu_choose_branch + + +def cmd_info(stack: StackBranchSet, args): + """Show info for all stacks.""" + forest = get_all_stacks_as_forest(stack) + if args.pr: + load_pr_info_for_forest(forest) + print_forest(forest) + + +def cmd_log(stack: StackBranchSet, args): + """Show git log with conditional merge handling.""" + config = get_config() + if config.use_merge: + run(["git", "log", "--no-merges", "--first-parent"], out=True) + else: + run(["git", "log"], out=True) + + +def cmd_branch_up(stack: StackBranchSet, args): + """Move up in the stack (away from master/main).""" + current_branch = get_current_branch_name() + b = stack.stack[current_branch] + if not b.children: + info("Branch {} is already at the top of the stack", current_branch) + return + if len(b.children) > 1: + if not IS_TERMINAL: + die( + "Branch {} has multiple children: {}", + current_branch, ", ".join(c.name for c in b.children), + ) + cout( + "Branch {} has {} children, choose one\n", + current_branch, len(b.children), fg="green", + ) + forest = BranchesTreeForest([ + BranchesTree({BranchName(c.name): (c, BranchesTree({}))}) + for c in b.children + ]) + child = menu_choose_branch(forest).name + else: + child = next(iter(b.children)).name + checkout(child) + + +def cmd_branch_down(stack: StackBranchSet, args): + """Move down in the stack (towards master/main).""" + current_branch = get_current_branch_name() + b = stack.stack[current_branch] + if not b.parent: + info("Branch {} is already at the bottom of the stack", current_branch) + return + checkout(b.parent.name) diff --git a/src/stacky/commands/stack.py b/src/stacky/commands/stack.py new file mode 100644 index 0000000..bf49a58 --- /dev/null +++ b/src/stacky/commands/stack.py @@ -0,0 +1,39 @@ +"""Stack commands - stack info, push, sync, checkout.""" + +from stacky.git.branch import checkout +from stacky.stack.models import StackBranchSet +from stacky.stack.operations import do_push, do_sync +from stacky.stack.tree import ( + get_current_stack_as_forest, load_pr_info_for_forest, print_forest +) +from stacky.utils.ui import menu_choose_branch + + +def cmd_stack_info(stack: StackBranchSet, args): + """Show info for current stack.""" + forest = get_current_stack_as_forest(stack) + if args.pr: + load_pr_info_for_forest(forest) + print_forest(forest) + + +def cmd_stack_push(stack: StackBranchSet, args): + """Push current stack.""" + do_push( + get_current_stack_as_forest(stack), + force=args.force, + pr=args.pr, + remote_name=args.remote_name, + ) + + +def cmd_stack_sync(stack: StackBranchSet, args): + """Sync current stack.""" + do_sync(get_current_stack_as_forest(stack)) + + +def cmd_stack_checkout(stack: StackBranchSet, args): + """Checkout a branch in current stack.""" + forest = get_current_stack_as_forest(stack) + branch_name = menu_choose_branch(forest).name + checkout(branch_name) diff --git a/src/stacky/commands/update.py b/src/stacky/commands/update.py new file mode 100644 index 0000000..702e017 --- /dev/null +++ b/src/stacky/commands/update.py @@ -0,0 +1,129 @@ +"""Update commands - update, import, adopt.""" + +from stacky.git.branch import get_current_branch_name, get_real_stack_bottom, set_current_branch +from stacky.git.refs import get_merge_base, set_parent, set_parent_commit +from stacky.git.remote import start_muxed_ssh, stop_muxed_ssh +from stacky.pr.github import get_pr_info +from stacky.stack.models import StackBranch, StackBranchSet +from stacky.stack.operations import cleanup_unused_refs, delete_branches, get_branches_to_delete +from stacky.stack.tree import get_bottom_level_branches_as_forest, load_pr_info_for_forest +from stacky.utils.config import get_config +from stacky.utils.logging import cout, die, info +from stacky.utils.shell import run, run_always_return +from stacky.utils.types import BranchName, CmdArgs, Commit, FROZEN_STACK_BOTTOMS, STACK_BOTTOMS +from stacky.utils.ui import confirm + + +def cmd_update(stack: StackBranchSet, args): + """Update repo from remote.""" + remote = "origin" + start_muxed_ssh(remote) + info("Fetching from {}", remote) + run(CmdArgs(["git", "fetch", remote])) + + current_branch = get_current_branch_name() + for b in stack.bottoms: + run( + CmdArgs([ + "git", "update-ref", + "refs/heads/{}".format(b.name), + "refs/remotes/{}/{}".format(remote, b.remote_branch), + ]) + ) + if b.name == current_branch: + run(CmdArgs(["git", "reset", "--hard", "HEAD"])) + + info("Checking if any PRs have been merged and can be deleted") + forest = get_bottom_level_branches_as_forest(stack) + load_pr_info_for_forest(forest) + + deletes = get_branches_to_delete(forest) + if deletes and not args.force: + confirm() + + delete_branches(stack, deletes) + stop_muxed_ssh(remote) + + info("Cleaning up refs for non-existent branches") + cleanup_unused_refs(stack) + + +def cmd_import(stack: StackBranchSet, args): + """Import Graphite stack.""" + branch = args.name + branches = [] + bottoms = set(b.name for b in stack.bottoms) + while branch not in bottoms: + pr_info = get_pr_info(branch, full=True) + open_pr = pr_info.open + info("Getting PR information for {}", branch) + if open_pr is None: + die("Branch {} has no open PR", branch) + assert open_pr is not None + if open_pr["headRefName"] != branch: + die( + "Branch {} is misconfigured: PR #{} head is {}", + branch, open_pr["number"], open_pr["headRefName"], + ) + if not open_pr["commits"]: + die("PR #{} has no commits", open_pr["number"]) + first_commit = open_pr["commits"][0]["oid"] + parent_commit = Commit(run_always_return(CmdArgs(["git", "rev-parse", "{}^".format(first_commit)]))) + next_branch = open_pr["baseRefName"] + info( + "Branch {}: PR #{}, parent is {} at commit {}", + branch, open_pr["number"], next_branch, parent_commit, + ) + branches.append((branch, parent_commit)) + branch = next_branch + + if not branches: + return + + base_branch = branch + branches.reverse() + + for b, parent_commit in branches: + cout("- Will set parent of {} to {} at commit {}\n", b, branch, parent_commit) + branch = b + + if not args.force: + confirm() + + branch = base_branch + for b, parent_commit in branches: + set_parent(b, branch, set_origin=True) + set_parent_commit(b, parent_commit) + branch = b + + +def cmd_adopt(stack: StackBranch, args): + """Adopt a branch onto current stack bottom.""" + branch = args.name + current_branch = get_current_branch_name() + + if branch == current_branch: + die("A branch cannot adopt itself") + + if current_branch not in STACK_BOTTOMS: + main_branch = get_real_stack_bottom() + if get_config().change_to_main and main_branch is not None: + run(CmdArgs(["git", "checkout", main_branch])) + set_current_branch(main_branch) + current_branch = main_branch + else: + die( + "The current branch {} must be a valid stack bottom: {}", + current_branch, ", ".join(sorted(STACK_BOTTOMS)), + ) + + if branch in STACK_BOTTOMS: + if branch in FROZEN_STACK_BOTTOMS: + die("Cannot adopt frozen stack bottoms {}".format(FROZEN_STACK_BOTTOMS)) + run(CmdArgs(["git", "update-ref", "-d", "refs/stacky-bottom-branch/{}".format(branch)])) + + parent_commit = get_merge_base(current_branch, branch) + set_parent(branch, current_branch, set_origin=True) + set_parent_commit(branch, parent_commit) + if get_config().change_to_adopted: + run(CmdArgs(["git", "checkout", branch])) diff --git a/src/stacky/commands/upstack.py b/src/stacky/commands/upstack.py new file mode 100644 index 0000000..73f1e53 --- /dev/null +++ b/src/stacky/commands/upstack.py @@ -0,0 +1,76 @@ +"""Upstack commands - info, push, sync, onto, as.""" + +from stacky.git.branch import get_current_branch_name +from stacky.git.refs import set_parent +from stacky.stack.models import StackBranchSet +from stacky.stack.operations import do_push, do_sync +from stacky.stack.tree import ( + forest_depth_first, get_current_upstack_as_forest, + load_pr_info_for_forest, print_forest +) +from stacky.utils.logging import die, info +from stacky.utils.shell import run +from stacky.utils.types import CmdArgs + + +def cmd_upstack_info(stack: StackBranchSet, args): + """Show info for current upstack.""" + forest = get_current_upstack_as_forest(stack) + if args.pr: + load_pr_info_for_forest(forest) + print_forest(forest) + + +def cmd_upstack_push(stack: StackBranchSet, args): + """Push current upstack.""" + do_push( + get_current_upstack_as_forest(stack), + force=args.force, + pr=args.pr, + remote_name=args.remote_name, + ) + + +def cmd_upstack_sync(stack: StackBranchSet, args): + """Sync current upstack.""" + do_sync(get_current_upstack_as_forest(stack)) + + +def cmd_upstack_onto(stack: StackBranchSet, args): + """Move current upstack onto a different parent.""" + current_branch = get_current_branch_name() + b = stack.stack[current_branch] + if not b.parent: + die("may not upstack a stack bottom, use stacky adopt") + target = stack.stack[args.target] + upstack = get_current_upstack_as_forest(stack) + for ub in forest_depth_first(upstack): + if ub == target: + die("Target branch {} is upstack of {}", target.name, b.name) + b.parent = target + set_parent(b.name, target.name) + do_sync(upstack) + + +def cmd_upstack_as_base(stack: StackBranchSet): + """Set current branch as a new stack bottom.""" + current_branch = get_current_branch_name() + b = stack.stack[current_branch] + if not b.parent: + die("Branch {} is already a stack bottom", b.name) + + b.parent = None # type: ignore + stack.remove(b.name) + stack.addStackBranch(b) + set_parent(b.name, None) + + run(CmdArgs(["git", "update-ref", "refs/stacky-bottom-branch/{}".format(b.name), b.commit, ""])) + info("Set {} as new bottom branch".format(b.name)) + + +def cmd_upstack_as(stack: StackBranchSet, args): + """Upstack branch as something (e.g., bottom).""" + if args.target == "bottom": + cmd_upstack_as_base(stack) + else: + die("Invalid target {}, acceptable targets are [base]", args.target) diff --git a/src/stacky/git/__init__.py b/src/stacky/git/__init__.py new file mode 100644 index 0000000..c237e48 --- /dev/null +++ b/src/stacky/git/__init__.py @@ -0,0 +1 @@ +# Git module - git operations for stacky diff --git a/src/stacky/git/branch.py b/src/stacky/git/branch.py new file mode 100644 index 0000000..f68aacf --- /dev/null +++ b/src/stacky/git/branch.py @@ -0,0 +1,104 @@ +"""Branch operations for stacky.""" + +from typing import List, Optional + +from stacky.utils.logging import info +from stacky.utils.shell import remove_prefix, run, run_always_return, run_multiline +from stacky.utils.types import BranchName, CmdArgs, PathName, STACK_BOTTOMS + +# Global current branch - set by init_git() +CURRENT_BRANCH: BranchName = BranchName("") + + +def get_current_branch_name() -> BranchName: + """Get the current branch name (from global state).""" + return CURRENT_BRANCH + + +def set_current_branch(branch: BranchName): + """Set the current branch (global state).""" + global CURRENT_BRANCH + CURRENT_BRANCH = branch + + +def get_current_branch() -> Optional[BranchName]: + """Get the current branch from git.""" + s = run(CmdArgs(["git", "symbolic-ref", "-q", "HEAD"])) + if s is not None: + return BranchName(remove_prefix(s, "refs/heads/")) + return None + + +def get_all_branches() -> List[BranchName]: + """Get all local branches.""" + branches = run_multiline(CmdArgs(["git", "for-each-ref", "--format", "%(refname:short)", "refs/heads"])) + assert branches is not None + return [BranchName(b) for b in branches.split("\n") if b] + + +def branch_name_completer(prefix, parsed_args, **kwargs): + """Argcomplete completer function for branch names.""" + try: + branches = get_all_branches() + return [branch for branch in branches if branch.startswith(prefix)] + except Exception: + return [] + + +def get_real_stack_bottom() -> Optional[BranchName]: + """Return the actual stack bottom for this current repo.""" + branches = get_all_branches() + candidates = set() + for b in branches: + if b in STACK_BOTTOMS: + candidates.add(b) + + if len(candidates) == 1: + return candidates.pop() + return None + + +def get_stack_parent_branch(branch: BranchName) -> Optional[BranchName]: + """Get the parent branch of a stack branch.""" + if branch in STACK_BOTTOMS: + return None + p = run(CmdArgs(["git", "config", "branch.{}.merge".format(branch)]), check=False) + if p is not None: + p = remove_prefix(p, "refs/heads/") + if BranchName(p) == branch: + return None + return BranchName(p) + return None + + +def get_top_level_dir() -> PathName: + """Get the top-level directory of the git repository.""" + p = run_always_return(CmdArgs(["git", "rev-parse", "--show-toplevel"])) + return PathName(p) + + +def checkout(branch: BranchName): + """Checkout a branch.""" + info("Checking out branch {}", branch) + run(["git", "checkout", branch], out=True) + + +def create_branch(branch: BranchName): + """Create a new branch tracking current branch.""" + run(["git", "checkout", "-b", branch, "--track"], out=True) + + +def init_git(): + """Initialize git state for stacky.""" + from stacky.utils.logging import die + + push_default = run(["git", "config", "remote.pushDefault"], check=False) + if push_default is not None: + die("`git config remote.pushDefault` may not be set") + auth_status = run(["gh", "auth", "status"], check=False) + if auth_status is None: + die("`gh` authentication failed") + global CURRENT_BRANCH + current = get_current_branch() + if current is not None: + CURRENT_BRANCH = current diff --git a/src/stacky/git/refs.py b/src/stacky/git/refs.py new file mode 100644 index 0000000..e4d63ed --- /dev/null +++ b/src/stacky/git/refs.py @@ -0,0 +1,107 @@ +"""Git ref operations for stacky.""" + +from typing import List, Optional + +from stacky.utils.logging import die +from stacky.utils.shell import run, run_multiline +from stacky.utils.types import BranchName, CmdArgs, Commit + + +def get_stack_parent_commit(branch: BranchName) -> Optional[Commit]: + """Get the parent commit of a stack branch.""" + c = run( + CmdArgs(["git", "rev-parse", "refs/stack-parent/{}".format(branch)]), + check=False, + ) + if c is not None: + return Commit(c) + return None + + +def get_commit(branch: BranchName) -> Commit: + """Get the current commit of a branch.""" + c = run(CmdArgs(["git", "rev-parse", "refs/heads/{}".format(branch)]), check=False) + assert c is not None + return Commit(c) + + +def set_parent_commit(branch: BranchName, new_commit: Commit, prev_commit: Optional[str] = None): + """Set the parent commit ref for a branch.""" + cmd = [ + "git", + "update-ref", + "refs/stack-parent/{}".format(branch), + new_commit, + ] + if prev_commit is not None: + cmd.append(prev_commit) + run(CmdArgs(cmd)) + + +def set_parent(branch: BranchName, target: Optional[BranchName], *, set_origin: bool = False): + """Set the parent branch for a stack branch.""" + if set_origin: + run(CmdArgs(["git", "config", "branch.{}.remote".format(branch), "."])) + + # If target is none this becomes a new stack bottom + run( + CmdArgs( + [ + "git", + "config", + "branch.{}.merge".format(branch), + "refs/heads/{}".format(target if target is not None else branch), + ] + ) + ) + + if target is None: + run( + CmdArgs( + [ + "git", + "update-ref", + "-d", + "refs/stack-parent/{}".format(branch), + ] + ) + ) + + +def get_branch_name_from_short_ref(ref: str) -> BranchName: + """Extract branch name from a short ref like 'stack-parent/branch'.""" + parts = ref.split("/", 1) + if len(parts) != 2: + die("invalid ref: {}".format(ref)) + return BranchName(parts[1]) + + +def get_all_stack_bottoms() -> List[BranchName]: + """Get all custom stack bottom branches.""" + branches = run_multiline( + CmdArgs(["git", "for-each-ref", "--format", "%(refname:short)", "refs/stacky-bottom-branch"]) + ) + if branches: + return [get_branch_name_from_short_ref(b) for b in branches.split("\n") if b] + return [] + + +def get_all_stack_parent_refs() -> List[BranchName]: + """Get all branches that have stack-parent refs.""" + branches = run_multiline(CmdArgs(["git", "for-each-ref", "--format", "%(refname:short)", "refs/stack-parent"])) + if branches: + return [get_branch_name_from_short_ref(b) for b in branches.split("\n") if b] + return [] + + +def get_commits_between(a: Commit, b: Commit) -> List[str]: + """Get list of commits between two refs.""" + lines = run_multiline(CmdArgs(["git", "rev-list", "{}..{}".format(a, b)])) + assert lines is not None + # Have to strip the last element because it's empty, rev list includes a new line at the end + return [x.strip() for x in lines.split("\n")][:-1] + + +def get_merge_base(b1: BranchName, b2: BranchName) -> Optional[str]: + """Get the merge base of two branches.""" + return run(CmdArgs(["git", "merge-base", str(b1), str(b2)])) diff --git a/src/stacky/git/remote.py b/src/stacky/git/remote.py new file mode 100644 index 0000000..4e53f5a --- /dev/null +++ b/src/stacky/git/remote.py @@ -0,0 +1,100 @@ +"""Remote and SSH operations for stacky.""" + +import os +import re +import subprocess +import time +from typing import Optional, Tuple + +from stacky.utils.config import get_config +from stacky.utils.logging import die, error, info +from stacky.utils.shell import run, run_always_return +from stacky.utils.types import BranchName, CmdArgs, Commit, MAX_SSH_MUX_LIFETIME, STACK_BOTTOMS + + +def get_remote_info(branch: BranchName) -> Tuple[str, BranchName, Optional[Commit]]: + """Get remote info for a branch: (remote, remote_branch, remote_branch_commit).""" + if branch not in STACK_BOTTOMS: + remote = run(CmdArgs(["git", "config", "branch.{}.remote".format(branch)]), check=False) + if remote != ".": + die("Misconfigured branch {}: remote {}", branch, remote) + + # TODO(tudor): Maybe add a way to change these. + remote = "origin" + remote_branch = branch + + remote_commit = run( + CmdArgs(["git", "rev-parse", "refs/remotes/{}/{}".format(remote, remote_branch)]), + check=False, + ) + + commit = None + if remote_commit is not None: + commit = Commit(remote_commit) + + return (remote, BranchName(remote_branch), commit) + + +def get_remote_type(remote: str = "origin") -> Optional[str]: + """Get the SSH host type for a remote.""" + out = run_always_return(CmdArgs(["git", "remote", "-v"])) + for l in out.split("\n"): + match = re.match(r"^{}\s+(?:ssh://)?([^/]*):(?!//).*\s+\(push\)$".format(remote), l) + if match: + sshish_host = match.group(1) + return sshish_host + + return None + + +def gen_ssh_mux_cmd() -> list[str]: + """Generate SSH multiplexing command arguments.""" + args = [ + "ssh", + "-o", + "ControlMaster=auto", + "-o", + f"ControlPersist={MAX_SSH_MUX_LIFETIME}", + "-o", + "ControlPath=~/.ssh/stacky-%C", + ] + return args + + +def start_muxed_ssh(remote: str = "origin"): + """Start a multiplexed SSH connection.""" + if not get_config().share_ssh_session: + return + hostish = get_remote_type(remote) + if hostish is not None: + info("Creating a muxed ssh connection") + cmd = gen_ssh_mux_cmd() + os.environ["GIT_SSH_COMMAND"] = " ".join(cmd) + cmd.append("-MNf") + cmd.append(hostish) + # We don't want to use the run() wrapper because + # we don't want to wait for the process to finish + + p = subprocess.Popen(cmd, stderr=subprocess.PIPE) + # Wait a little bit for the connection to establish + # before carrying on + while p.poll() is None: + time.sleep(1) + if p.returncode != 0: + if p.stderr is not None: + err = p.stderr.read() + else: + err = b"unknown" + die(f"Failed to start ssh muxed connection, error was: {err.decode('utf-8').strip()}") + + +def stop_muxed_ssh(remote: str = "origin"): + """Stop a multiplexed SSH connection.""" + if get_config().share_ssh_session: + hostish = get_remote_type(remote) + if hostish is not None: + cmd = gen_ssh_mux_cmd() + cmd.append("-O") + cmd.append("exit") + cmd.append(hostish) + subprocess.Popen(cmd, stderr=subprocess.DEVNULL) diff --git a/src/stacky/main.py b/src/stacky/main.py new file mode 100644 index 0000000..1536959 --- /dev/null +++ b/src/stacky/main.py @@ -0,0 +1,324 @@ +"""Main entry point for stacky.""" + +import json +import logging +import os +import sys +from argparse import ArgumentParser + +import argcomplete # type: ignore + +from stacky.git.branch import ( + branch_name_completer, get_current_branch_name, get_real_stack_bottom, + init_git, set_current_branch +) +from stacky.stack.models import StackBranchSet +from stacky.stack.operations import inner_do_sync, load_all_stacks, load_stack_for_given_branch +from stacky.stack.tree import get_current_stack_as_forest +from stacky.utils.config import get_config +from stacky.utils.logging import ( + ExitException, _LOGGING_FORMAT, error, set_color_mode +) +from stacky.utils.shell import run +from stacky.utils.types import BranchName, LOGLEVELS, STATE_FILE + +# Import all command handlers +from stacky.commands.navigation import cmd_info, cmd_log, cmd_branch_up, cmd_branch_down +from stacky.commands.branch import cmd_branch_new, cmd_branch_commit, cmd_branch_checkout +from stacky.commands.commit import cmd_commit, cmd_amend +from stacky.commands.stack import cmd_stack_info, cmd_stack_push, cmd_stack_sync, cmd_stack_checkout +from stacky.commands.upstack import ( + cmd_upstack_info, cmd_upstack_push, cmd_upstack_sync, cmd_upstack_onto, cmd_upstack_as +) +from stacky.commands.downstack import cmd_downstack_info, cmd_downstack_push, cmd_downstack_sync +from stacky.commands.update import cmd_update, cmd_import, cmd_adopt +from stacky.commands.land import cmd_land +from stacky.commands.inbox import cmd_inbox, cmd_prs +from stacky.commands.fold import ( + cmd_fold, inner_do_fold, finish_merge_fold_operation +) + + +def main(): + """Main entry point for stacky.""" + logging.basicConfig(format=_LOGGING_FORMAT, level=logging.INFO) + try: + parser = ArgumentParser(description="Handle git stacks") + parser.add_argument( + "--log-level", default="info", choices=LOGLEVELS.keys(), + help="Set the log level", + ) + parser.add_argument( + "--color", default="auto", choices=["always", "auto", "never"], + help="Colorize output and error", + ) + parser.add_argument( + "--remote-name", "-r", default="origin", + help="name of the git remote where branches will be pushed", + ) + + subparsers = parser.add_subparsers(required=True, dest="command") + + # continue + continue_parser = subparsers.add_parser("continue", help="Continue previously interrupted command") + continue_parser.set_defaults(func=None) + + # down / up + down_parser = subparsers.add_parser("down", help="Go down in the current stack (towards master/main)") + down_parser.set_defaults(func=cmd_branch_down) + up_parser = subparsers.add_parser("up", help="Go up in the current stack (away master/main)") + up_parser.set_defaults(func=cmd_branch_up) + + # info + info_parser = subparsers.add_parser("info", help="Stack info") + info_parser.add_argument("--pr", action="store_true", help="Get PR info (slow)") + info_parser.set_defaults(func=cmd_info) + + # log + log_parser = subparsers.add_parser("log", help="Show git log with conditional merge handling") + log_parser.set_defaults(func=cmd_log) + + # commit + commit_parser = subparsers.add_parser("commit", help="Commit") + commit_parser.add_argument("-m", help="Commit message", dest="message") + commit_parser.add_argument("--amend", action="store_true", help="Amend last commit") + commit_parser.add_argument("--allow-empty", action="store_true", help="Allow empty commit") + commit_parser.add_argument("--no-edit", action="store_true", help="Skip editor") + commit_parser.add_argument("-a", action="store_true", help="Add all files to commit", dest="add_all") + commit_parser.add_argument("--no-verify", action="store_true", help="Bypass pre-commit and commit-msg hooks") + commit_parser.set_defaults(func=cmd_commit) + + # amend + amend_parser = subparsers.add_parser("amend", help="Shortcut for amending last commit") + amend_parser.add_argument("--no-verify", action="store_true", help="Bypass pre-commit and commit-msg hooks") + amend_parser.set_defaults(func=cmd_amend) + + _setup_branch_subcommands(subparsers) + _setup_stack_subcommands(subparsers) + _setup_upstack_subcommands(subparsers) + _setup_downstack_subcommands(subparsers) + _setup_other_commands(subparsers) + + argcomplete.autocomplete(parser) + args = parser.parse_args() + logging.basicConfig(format=_LOGGING_FORMAT, level=LOGLEVELS[args.log_level], force=True) + set_color_mode(args.color) + + init_git() + stack = StackBranchSet() + load_all_stacks(stack) + + current_branch = get_current_branch_name() + if args.command == "continue": + _handle_continue(stack, current_branch) + else: + if current_branch not in stack.stack: + main_branch = get_real_stack_bottom() + if get_config().change_to_main and main_branch is not None: + run(["git", "checkout", main_branch]) + set_current_branch(main_branch) + else: + from stacky.utils.logging import die + die("Current branch {} is not in a stack", current_branch) + + get_current_stack_as_forest(stack) + args.func(stack, args) + + # Success, delete the state file + try: + os.remove(STATE_FILE) + except FileNotFoundError: + pass + except ExitException as e: + error("{}", e.args[0]) + sys.exit(1) + + +def _handle_continue(stack: StackBranchSet, current_branch: BranchName): + """Handle the 'continue' command for interrupted operations.""" + from stacky.utils.logging import die + + try: + with open(STATE_FILE) as f: + state = json.load(f) + except FileNotFoundError: + die("No previous command in progress") + + branch = state["branch"] + run(["git", "checkout", branch]) + set_current_branch(branch) + + if branch not in stack.stack: + die("Current branch {} is not in a stack", branch) + + if "sync" in state: + sync_names = state["sync"] + syncs = [stack.stack[n] for n in sync_names] + inner_do_sync(syncs, sync_names) + elif "fold" in state: + fold_state = state["fold"] + inner_do_fold( + stack, + fold_state["fold_branch"], + fold_state["parent_branch"], + fold_state["commits"], + fold_state["children"], + fold_state["allow_empty"] + ) + elif "merge_fold" in state: + merge_fold_state = state["merge_fold"] + finish_merge_fold_operation( + stack, + merge_fold_state["fold_branch"], + merge_fold_state["parent_branch"], + merge_fold_state["children"] + ) + else: + die("Unknown operation in progress") + + +def _setup_branch_subcommands(subparsers): + """Setup branch subcommands.""" + branch_parser = subparsers.add_parser("branch", aliases=["b"], help="Operations on branches") + branch_subparsers = branch_parser.add_subparsers(required=True, dest="branch_command") + + branch_up_parser = branch_subparsers.add_parser("up", aliases=["u"], help="Move upstack") + branch_up_parser.set_defaults(func=cmd_branch_up) + + branch_down_parser = branch_subparsers.add_parser("down", aliases=["d"], help="Move downstack") + branch_down_parser.set_defaults(func=cmd_branch_down) + + branch_new_parser = branch_subparsers.add_parser("new", aliases=["create"], help="Create a new branch") + branch_new_parser.add_argument("name", help="Branch name") + branch_new_parser.set_defaults(func=cmd_branch_new) + + branch_commit_parser = branch_subparsers.add_parser("commit", help="Create a new branch and commit all changes") + branch_commit_parser.add_argument("name", help="Branch name") + branch_commit_parser.add_argument("-m", help="Commit message", dest="message") + branch_commit_parser.add_argument("-a", action="store_true", help="Add all files to commit", dest="add_all") + branch_commit_parser.add_argument("--no-verify", action="store_true", help="Bypass pre-commit and commit-msg hooks") + branch_commit_parser.set_defaults(func=cmd_branch_commit) + + branch_checkout_parser = branch_subparsers.add_parser("checkout", aliases=["co"], help="Checkout a branch") + branch_checkout_parser.add_argument("name", help="Branch name", nargs="?").completer = branch_name_completer + branch_checkout_parser.set_defaults(func=cmd_branch_checkout) + + +def _setup_stack_subcommands(subparsers): + """Setup stack subcommands.""" + stack_parser = subparsers.add_parser("stack", aliases=["s"], help="Operations on the full current stack") + stack_subparsers = stack_parser.add_subparsers(required=True, dest="stack_command") + + stack_info_parser = stack_subparsers.add_parser("info", aliases=["i"], help="Info for current stack") + stack_info_parser.add_argument("--pr", action="store_true", help="Get PR info (slow)") + stack_info_parser.set_defaults(func=cmd_stack_info) + + stack_push_parser = stack_subparsers.add_parser("push", help="Push") + stack_push_parser.add_argument("--force", "-f", action="store_true", help="Bypass confirmation") + stack_push_parser.add_argument("--no-pr", dest="pr", action="store_false", help="Skip Create PRs") + stack_push_parser.set_defaults(func=cmd_stack_push) + + stack_sync_parser = stack_subparsers.add_parser("sync", help="Sync") + stack_sync_parser.set_defaults(func=cmd_stack_sync) + + stack_checkout_parser = stack_subparsers.add_parser("checkout", aliases=["co"], help="Checkout a branch in this stack") + stack_checkout_parser.set_defaults(func=cmd_stack_checkout) + + +def _setup_upstack_subcommands(subparsers): + """Setup upstack subcommands.""" + upstack_parser = subparsers.add_parser("upstack", aliases=["us"], help="Operations on the current upstack") + upstack_subparsers = upstack_parser.add_subparsers(required=True, dest="upstack_command") + + upstack_info_parser = upstack_subparsers.add_parser("info", aliases=["i"], help="Info for current upstack") + upstack_info_parser.add_argument("--pr", action="store_true", help="Get PR info (slow)") + upstack_info_parser.set_defaults(func=cmd_upstack_info) + + upstack_push_parser = upstack_subparsers.add_parser("push", help="Push") + upstack_push_parser.add_argument("--force", "-f", action="store_true", help="Bypass confirmation") + upstack_push_parser.add_argument("--no-pr", dest="pr", action="store_false", help="Skip Create PRs") + upstack_push_parser.set_defaults(func=cmd_upstack_push) + + upstack_sync_parser = upstack_subparsers.add_parser("sync", help="Sync") + upstack_sync_parser.set_defaults(func=cmd_upstack_sync) + + upstack_onto_parser = upstack_subparsers.add_parser("onto", aliases=["restack"], help="Restack") + upstack_onto_parser.add_argument("target", help="New parent") + upstack_onto_parser.set_defaults(func=cmd_upstack_onto) + + upstack_as_parser = upstack_subparsers.add_parser("as", help="Upstack branch this as a new stack bottom") + upstack_as_parser.add_argument("target", help="bottom, restack this branch as a new stack bottom").completer = branch_name_completer + upstack_as_parser.set_defaults(func=cmd_upstack_as) + + +def _setup_downstack_subcommands(subparsers): + """Setup downstack subcommands.""" + downstack_parser = subparsers.add_parser("downstack", aliases=["ds"], help="Operations on the current downstack") + downstack_subparsers = downstack_parser.add_subparsers(required=True, dest="downstack_command") + + downstack_info_parser = downstack_subparsers.add_parser("info", aliases=["i"], help="Info for current downstack") + downstack_info_parser.add_argument("--pr", action="store_true", help="Get PR info (slow)") + downstack_info_parser.set_defaults(func=cmd_downstack_info) + + downstack_push_parser = downstack_subparsers.add_parser("push", help="Push") + downstack_push_parser.add_argument("--force", "-f", action="store_true", help="Bypass confirmation") + downstack_push_parser.add_argument("--no-pr", dest="pr", action="store_false", help="Skip Create PRs") + downstack_push_parser.set_defaults(func=cmd_downstack_push) + + downstack_sync_parser = downstack_subparsers.add_parser("sync", help="Sync") + downstack_sync_parser.set_defaults(func=cmd_downstack_sync) + + +def _setup_other_commands(subparsers): + """Setup other commands (update, import, adopt, land, shortcuts, etc.).""" + # update + update_parser = subparsers.add_parser("update", help="Update repo, all bottom branches must exist in remote") + update_parser.add_argument("--force", "-f", action="store_true", help="Bypass confirmation") + update_parser.set_defaults(func=cmd_update) + + # import + import_parser = subparsers.add_parser("import", help="Import Graphite stack") + import_parser.add_argument("--force", "-f", action="store_true", help="Bypass confirmation") + import_parser.add_argument("name", help="Foreign stack top").completer = branch_name_completer + import_parser.set_defaults(func=cmd_import) + + # adopt + adopt_parser = subparsers.add_parser("adopt", help="Adopt one branch") + adopt_parser.add_argument("name", help="Branch name").completer = branch_name_completer + adopt_parser.set_defaults(func=cmd_adopt) + + # land + land_parser = subparsers.add_parser("land", help="Land bottom-most PR on current stack") + land_parser.add_argument("--force", "-f", action="store_true", help="Bypass confirmation") + land_parser.add_argument("--auto", "-a", action="store_true", help="Automatically merge after all checks pass") + land_parser.set_defaults(func=cmd_land) + + # shortcuts + push_parser = subparsers.add_parser("push", help="Alias for downstack push") + push_parser.add_argument("--force", "-f", action="store_true", help="Bypass confirmation") + push_parser.add_argument("--no-pr", dest="pr", action="store_false", help="Skip Create PRs") + push_parser.set_defaults(func=cmd_downstack_push) + + sync_parser = subparsers.add_parser("sync", help="Alias for stack sync") + sync_parser.set_defaults(func=cmd_stack_sync) + + checkout_parser = subparsers.add_parser("checkout", aliases=["co"], help="Checkout a branch") + checkout_parser.add_argument("name", help="Branch name", nargs="?").completer = branch_name_completer + checkout_parser.set_defaults(func=cmd_branch_checkout) + + sco_parser = subparsers.add_parser("sco", help="Checkout a branch in this stack") + sco_parser.set_defaults(func=cmd_stack_checkout) + + # inbox + inbox_parser = subparsers.add_parser("inbox", help="List all active GitHub pull requests for the current user") + inbox_parser.add_argument("--compact", "-c", action="store_true", help="Show compact view") + inbox_parser.set_defaults(func=cmd_inbox) + + # prs + prs_parser = subparsers.add_parser("prs", help="Interactive PR management - select and edit PR descriptions") + prs_parser.set_defaults(func=cmd_prs) + + # fold + fold_parser = subparsers.add_parser("fold", help="Fold current branch into parent branch and delete current branch") + fold_parser.add_argument("--allow-empty", action="store_true", help="Allow empty commits during cherry-pick") + fold_parser.set_defaults(func=cmd_fold) diff --git a/src/stacky/pr/__init__.py b/src/stacky/pr/__init__.py new file mode 100644 index 0000000..c390929 --- /dev/null +++ b/src/stacky/pr/__init__.py @@ -0,0 +1 @@ +# PR module - GitHub PR operations diff --git a/src/stacky/pr/github.py b/src/stacky/pr/github.py new file mode 100644 index 0000000..d7ea68e --- /dev/null +++ b/src/stacky/pr/github.py @@ -0,0 +1,255 @@ +"""GitHub PR operations for stacky.""" + +import json +import logging +import os +import re +import subprocess +import tempfile +from typing import Dict, List, Optional, TYPE_CHECKING + +from stacky.stack.models import PRInfo, PRInfos +from stacky.stack.tree import get_pr_status_emoji +from stacky.utils.config import get_config +from stacky.utils.logging import COLOR_STDOUT, cout, fmt +from stacky.utils.shell import run, run_always_return, run_multiline +from stacky.utils.types import BranchesTreeForest, BranchName, CmdArgs, STACK_BOTTOMS + +if TYPE_CHECKING: + from stacky.stack.models import StackBranch + + +def get_pr_info(branch: BranchName, *, full: bool = False) -> PRInfos: + """Get PR information for a branch.""" + from stacky.utils.logging import die + + fields = [ + "id", "number", "state", "mergeable", "url", "title", + "baseRefName", "headRefName", "reviewDecision", "reviewRequests", "isDraft", + ] + if full: + fields += ["commits"] + data = json.loads( + run_always_return( + CmdArgs([ + "gh", "pr", "list", "--json", ",".join(fields), + "--state", "all", "--head", branch, + ]) + ) + ) + raw_infos: List[PRInfo] = data + + infos: Dict[str, PRInfo] = {info["id"]: info for info in raw_infos} + open_prs: List[PRInfo] = [info for info in infos.values() if info["state"] == "OPEN"] + if len(open_prs) > 1: + die( + "Branch {} has more than one open PR: {}", + branch, ", ".join([str(pr) for pr in open_prs]), + ) + return PRInfos(infos, open_prs[0] if open_prs else None) + + +def find_reviewers(b: "StackBranch") -> Optional[List[str]]: + """Find reviewers from commit message.""" + out = run_multiline( + CmdArgs(["git", "log", "--pretty=format:%b", "-1", f"{b.name}"]), + ) + assert out is not None + for l in out.split("\n"): + reviewer_match = re.match(r"^reviewers?\s*:\s*(.*)", l, re.I) + if reviewer_match: + reviewers = reviewer_match.group(1).split(",") + logging.debug(f"Found the following reviewers: {', '.join(reviewers)}") + return reviewers + return None + + +def find_issue_marker(name: str) -> Optional[str]: + """Find issue marker (e.g. SRE-123) in branch name.""" + match = re.search(r"(?:^|[_-])([A-Z]{3,}[_-]?\d{2,})($|[_-].*)", name) + if match: + res = match.group(1) + if "_" in res: + return res.replace("_", "-") + if "-" not in res: + newmatch = re.match(r"(...)(\d+)", res) + assert newmatch is not None + return f"{newmatch.group(1)}-{newmatch.group(2)}" + return res + return None + + +def create_gh_pr(b: "StackBranch", prefix: str): + """Create a GitHub PR for a branch.""" + from stacky.utils.ui import prompt + + cout("Creating PR for {}\n", b.name, fg="green") + parent_prefix = "" + if b.parent.name not in STACK_BOTTOMS: + prefix = "" + cmd = [ + "gh", "pr", "create", + "--head", f"{prefix}{b.name}", + "--base", f"{parent_prefix}{b.parent.name}", + ] + reviewers = find_reviewers(b) + issue_id = find_issue_marker(b.name) + if issue_id: + out = run_multiline( + CmdArgs(["git", "log", "--pretty=oneline", f"{b.parent.name}..{b.name}"]), + ) + title = f"[{issue_id}] " + if out is not None and len(out.split("\n")) == 2: + out = run( + CmdArgs(["git", "log", "--pretty=format:%s", "-1", f"{b.name}"]), + out=False, + ) + if out is None: + out = "" + if b.name not in out: + title += out + else: + title = out + + title = prompt( + (fmt("? ", color=COLOR_STDOUT, fg="green") + + fmt("Title ", color=COLOR_STDOUT, style="bold", fg="white")), + title, + ) + cmd.extend(["--title", title.strip()]) + if reviewers: + logging.debug(f"Adding {len(reviewers)} reviewer(s) to the review") + for r in reviewers: + r = r.strip() + r = r.replace("#", "rockset/") + if len(r) > 0: + cmd.extend(["--reviewer", r]) + + run(CmdArgs(cmd), out=True) + + +def generate_stack_string(forest: BranchesTreeForest, current_branch: "StackBranch") -> str: + """Generate a string representation of the PR stack.""" + from stacky.stack.tree import BranchesTree + + stack_lines = [] + + def add_branch_to_stack(b: "StackBranch", depth: int): + if b.name in STACK_BOTTOMS: + return + indent = " " * depth + pr_info = "" + if b.open_pr_info: + pr_number = b.open_pr_info['number'] + status_emoji = get_pr_status_emoji(b.open_pr_info) + pr_info = f" (#{pr_number}{status_emoji})" + current_indicator = " ← (CURRENT PR)" if b.name == current_branch.name else "" + stack_lines.append(f"{indent}- {b.name}{pr_info}{current_indicator}") + + def traverse_tree(tree: BranchesTree, depth: int): + for _, (branch, children) in tree.items(): + add_branch_to_stack(branch, depth) + traverse_tree(children, depth + 1) + + for tree in forest: + traverse_tree(tree, 0) + + if not stack_lines: + return "" + + return "\n".join([ + "", + "**Stack:**", + *stack_lines, + "" + ]) + + +def extract_stack_comment(body: str) -> str: + """Extract existing stack comment from PR body.""" + if not body: + return "" + pattern = r'.*?' + match = re.search(pattern, body, re.DOTALL) + if match: + return match.group(0).strip() + return "" + + +def add_or_update_stack_comment(branch: "StackBranch", complete_forest: BranchesTreeForest): + """Add or update stack comment in PR body.""" + if not branch.open_pr_info: + return + + pr_number = branch.open_pr_info["number"] + pr_data = json.loads( + run_always_return(CmdArgs(["gh", "pr", "view", str(pr_number), "--json", "body"])) + ) + + current_body = pr_data.get("body", "") + stack_string = generate_stack_string(complete_forest, branch) + + if not stack_string: + return + + existing_stack = extract_stack_comment(current_body) + + if not existing_stack: + if current_body: + new_body = f"{current_body}\n\n{stack_string}" + else: + new_body = stack_string + cout("Adding stack comment to PR #{}\n", pr_number, fg="green") + run(CmdArgs(["gh", "pr", "edit", str(pr_number), "--body", new_body]), out=True) + else: + if existing_stack != stack_string: + updated_body = current_body.replace(existing_stack, stack_string) + cout("Updating stack comment in PR #{}\n", pr_number, fg="yellow") + run(CmdArgs(["gh", "pr", "edit", str(pr_number), "--body", updated_body]), out=True) + else: + cout("✓ Stack comment in PR #{} is already correct\n", pr_number, fg="green") + + +def edit_pr_description(pr): + """Edit a PR's description using the user's default editor.""" + cout("Editing PR #{} - {}\n", pr["number"], pr["title"], fg="green") + cout("Current description:\n", fg="yellow") + current_body = pr.get("body", "") + if current_body: + cout("{}\n\n", current_body, fg="gray") + else: + cout("(No description)\n\n", fg="gray") + + with tempfile.NamedTemporaryFile(mode='w+', suffix='.md', delete=False) as temp_file: + temp_file.write(current_body or "") + temp_file_path = temp_file.name + + try: + editor = os.environ.get('EDITOR', 'vim') + result = subprocess.run([editor, temp_file_path]) + if result.returncode != 0: + cout("Editor exited with error, not updating PR description.\n", fg="red") + return + + with open(temp_file_path, 'r') as temp_file: + new_body = temp_file.read().strip() + + original_content = (current_body or "").strip() + new_content = new_body.strip() + + if new_content == original_content: + cout("No changes made to PR description.\n", fg="yellow") + return + + cout("Updating PR description...\n", fg="green") + run(CmdArgs(["gh", "pr", "edit", str(pr["number"]), "--body", new_body]), out=True) + cout("✓ Successfully updated PR #{} description\n", pr["number"], fg="green") + pr["body"] = new_body + + except Exception as e: + cout("Error editing PR description: {}\n", str(e), fg="red") + finally: + try: + os.unlink(temp_file_path) + except OSError: + pass diff --git a/src/stacky/stack/__init__.py b/src/stacky/stack/__init__.py new file mode 100644 index 0000000..2665fb6 --- /dev/null +++ b/src/stacky/stack/__init__.py @@ -0,0 +1 @@ +# Stack module - stack data structures and operations diff --git a/src/stacky/stack/models.py b/src/stacky/stack/models.py new file mode 100644 index 0000000..760b69a --- /dev/null +++ b/src/stacky/stack/models.py @@ -0,0 +1,140 @@ +"""Stack data models for stacky.""" + +import dataclasses +from typing import Dict, List, Optional, TypedDict + +from stacky.git.refs import get_commit +from stacky.git.remote import get_remote_info +from stacky.utils.logging import die +from stacky.utils.types import BranchName, Commit + + +class PRInfo(TypedDict): + """Type definition for PR information from GitHub.""" + id: str + number: int + state: str + mergeable: str + url: str + title: str + baseRefName: str + headRefName: str + commits: List[Dict[str, str]] + + +@dataclasses.dataclass +class PRInfos: + """Container for all PRs and the open PR for a branch.""" + all: Dict[str, PRInfo] + open: Optional[PRInfo] + + +@dataclasses.dataclass +class BranchNCommit: + """Branch name with its parent commit.""" + branch: BranchName + parent_commit: Optional[str] + + +class StackBranch: + """Represents a branch in a stack.""" + + def __init__( + self, + name: BranchName, + parent: "StackBranch", + parent_commit: Commit, + ): + self.name = name + self.parent = parent + self.parent_commit = parent_commit + self.children: set["StackBranch"] = set() + self.commit = get_commit(name) + self.remote, self.remote_branch, self.remote_commit = get_remote_info(name) + self.pr_info: Dict[str, PRInfo] = {} + self.open_pr_info: Optional[PRInfo] = None + self._pr_info_loaded = False + + def is_synced_with_parent(self): + """Check if branch is synced with its parent.""" + return self.parent is None or self.parent_commit == self.parent.commit + + def is_synced_with_remote(self): + """Check if branch is synced with remote.""" + return self.commit == self.remote_commit + + def __repr__(self): + return f"StackBranch: {self.name} {len(self.children)} {self.commit}" + + def load_pr_info(self): + """Load PR info from GitHub (lazy loading).""" + if not self._pr_info_loaded: + self._pr_info_loaded = True + from stacky.pr.github import get_pr_info + pr_infos = get_pr_info(self.name) + self.pr_info, self.open_pr_info = ( + pr_infos.all, + pr_infos.open, + ) + + +class StackBranchSet: + """Collection of stack branches.""" + + def __init__(self: "StackBranchSet"): + self.stack: Dict[BranchName, StackBranch] = {} + self.tops: set[StackBranch] = set() + self.bottoms: set[StackBranch] = set() + + def add(self, name: BranchName, **kwargs) -> StackBranch: + """Add a branch to the stack.""" + if name in self.stack: + s = self.stack[name] + assert s.name == name + for k, v in kwargs.items(): + if getattr(s, k) != v: + die( + "Mismatched stack: {}: {}={}, expected {}", + name, + k, + getattr(s, k), + v, + ) + else: + s = StackBranch(name, **kwargs) + self.stack[name] = s + if s.parent is None: + self.bottoms.add(s) + self.tops.add(s) + return s + + def addStackBranch(self, s: StackBranch): + """Add an existing StackBranch object to the set.""" + if s.name not in self.stack: + self.stack[s.name] = s + if s.parent is None: + self.bottoms.add(s) + if len(s.children) == 0: + self.tops.add(s) + return s + + def remove(self, name: BranchName) -> Optional[StackBranch]: + """Remove a branch from the stack.""" + if name in self.stack: + s = self.stack[name] + assert s.name == name + del self.stack[name] + if s in self.tops: + self.tops.remove(s) + if s in self.bottoms: + self.bottoms.remove(s) + return s + return None + + def __repr__(self) -> str: + return f"StackBranchSet: {self.stack}" + + def add_child(self, s: StackBranch, child: StackBranch): + """Add a child branch to a parent.""" + s.children.add(child) + self.tops.discard(s) diff --git a/src/stacky/stack/operations.py b/src/stacky/stack/operations.py new file mode 100644 index 0000000..ed24bc2 --- /dev/null +++ b/src/stacky/stack/operations.py @@ -0,0 +1,349 @@ +"""Stack operations for stacky - loading, syncing, pushing.""" + +import json +import os +from typing import List, Optional, Tuple, TYPE_CHECKING + +from stacky.git.branch import ( + get_all_branches, get_current_branch_name, get_stack_parent_branch, set_current_branch +) +from stacky.git.refs import ( + get_all_stack_bottoms, get_commit, get_commits_between, + get_stack_parent_commit, set_parent_commit +) +from stacky.git.remote import start_muxed_ssh, stop_muxed_ssh +from stacky.stack.models import BranchNCommit, StackBranch, StackBranchSet +from stacky.stack.tree import ( + forest_depth_first, get_complete_stack_forest_for_branch, + load_pr_info_for_forest, print_forest +) +from stacky.utils.config import get_config +from stacky.utils.logging import cout, die, info, warning +from stacky.utils.shell import run, run_always_return +from stacky.utils.types import ( + BranchesTreeForest, BranchName, CmdArgs, Commit, + STACK_BOTTOMS, STATE_FILE, TMP_STATE_FILE +) + +if TYPE_CHECKING: + pass + + +def load_all_stack_bottoms(): + """Load all custom stack bottoms into STACK_BOTTOMS.""" + STACK_BOTTOMS.update(get_all_stack_bottoms()) + + +def load_stack_for_given_branch( + stack: StackBranchSet, branch: BranchName, *, check: bool = True +) -> Tuple[Optional[StackBranch], List[BranchName]]: + """Load stack for a branch, returns (top_branch, list_of_branches).""" + branches: List[BranchNCommit] = [] + while branch not in STACK_BOTTOMS: + parent = get_stack_parent_branch(branch) + parent_commit = get_stack_parent_commit(branch) + branches.append(BranchNCommit(branch, parent_commit)) + if not parent or not parent_commit: + if check: + die("Branch is not in a stack: {}", branch) + return None, [b.branch for b in branches] + branch = parent + + branches.append(BranchNCommit(branch, None)) + top = None + for b in reversed(branches): + n = stack.add( + b.branch, + parent=top, + parent_commit=b.parent_commit, + ) + if top: + stack.add_child(top, n) + top = n + + return top, [b.branch for b in branches] + + +def load_all_stacks(stack: StackBranchSet) -> Optional[StackBranch]: + """Load all stacks, return top of current branch's stack.""" + load_all_stack_bottoms() + all_branches = set(get_all_branches()) + current_branch = get_current_branch_name() + current_branch_top = None + while all_branches: + b = all_branches.pop() + top, branches = load_stack_for_given_branch(stack, b, check=False) + all_branches -= set(branches) + if top is None: + if len(branches) > 1: + warning("Broken stack: {}", " -> ".join(branches)) + continue + if b == current_branch: + current_branch_top = top + return current_branch_top + + +def inner_do_sync(syncs: List[StackBranch], sync_names: List[BranchName]): + """Execute sync operations on branches.""" + print() + current_branch = get_current_branch_name() + sync_type = "merge" if get_config().use_merge else "rebase" + while syncs: + with open(TMP_STATE_FILE, "w") as f: + json.dump({"branch": current_branch, "sync": sync_names}, f) + os.replace(TMP_STATE_FILE, STATE_FILE) + + b = syncs.pop() + sync_names.pop() + if b.is_synced_with_parent(): + cout("{} is already synced on top of {}\n", b.name, b.parent.name) + continue + if b.parent.commit in get_commits_between(b.parent_commit, b.commit): + cout( + "Recording complete {} of {} on top of {}\n", + sync_type, b.name, b.parent.name, fg="green", + ) + else: + r = None + if get_config().use_merge: + cout("Merging {} into {}\n", b.parent.name, b.name, fg="green") + run(CmdArgs(["git", "checkout", str(b.name)])) + r = run(CmdArgs(["git", "merge", b.parent.name]), out=True, check=False) + else: + cout("Rebasing {} on top of {}\n", b.name, b.parent.name, fg="green") + r = run( + CmdArgs(["git", "rebase", "--onto", b.parent.name, b.parent_commit, b.name]), + out=True, check=False, + ) + + if r is None: + print() + die( + "Automatic {0} failed. Please complete the {0} (fix conflicts; " + "`git {0} --continue`), then run `stacky continue`".format(sync_type) + ) + b.commit = get_commit(b.name) + set_parent_commit(b.name, b.parent.commit, b.parent_commit) + b.parent_commit = b.parent.commit + run(CmdArgs(["git", "checkout", str(current_branch)])) + + +def do_sync(forest: BranchesTreeForest): + """Sync a forest of branches.""" + print_forest(forest) + + syncs: List[StackBranch] = [] + sync_names: List[BranchName] = [] + syncs_set: set[StackBranch] = set() + for b in forest_depth_first(forest): + if not b.parent: + cout("✓ Not syncing base branch {}\n", b.name, fg="green") + continue + if b.is_synced_with_parent() and b.parent not in syncs_set: + cout( + "✓ Not syncing branch {}, already synced with parent {}\n", + b.name, b.parent.name, fg="green", + ) + continue + syncs.append(b) + syncs_set.add(b) + sync_names.append(b.name) + cout("- Will sync branch {} on top of {}\n", b.name, b.parent.name) + + if not syncs: + return + + syncs.reverse() + sync_names.reverse() + inner_do_sync(syncs, sync_names) + + +def do_push( + forest: BranchesTreeForest, + *, + force: bool = False, + pr: bool = False, + remote_name: str = "origin", +): + """Push branches in a forest.""" + from stacky.pr.github import add_or_update_stack_comment, create_gh_pr + from stacky.utils.ui import confirm + + if pr: + load_pr_info_for_forest(forest) + print_forest(forest) + for b in forest_depth_first(forest): + if not b.is_synced_with_parent(): + die( + "Branch {} is not synced with parent {}, sync first", + b.name, b.parent.name, + ) + + PR_NONE = 0 + PR_FIX_BASE = 1 + PR_CREATE = 2 + actions = [] + for b in forest_depth_first(forest): + if not b.parent: + cout("✓ Not pushing base branch {}\n", b.name, fg="green") + continue + + push = False + if b.is_synced_with_remote(): + cout( + "✓ Not pushing branch {}, synced with remote {}/{}\n", + b.name, b.remote, b.remote_branch, fg="green", + ) + else: + cout("- Will push branch {} to {}/{}\n", b.name, b.remote, b.remote_branch) + push = True + + pr_action = PR_NONE + if pr: + if b.open_pr_info: + expected_base = b.parent.name + if b.open_pr_info["baseRefName"] != expected_base: + cout( + "- Branch {} already has open PR #{}; will change PR base from {} to {}\n", + b.name, b.open_pr_info["number"], + b.open_pr_info["baseRefName"], expected_base, + ) + pr_action = PR_FIX_BASE + else: + cout( + "✓ Branch {} already has open PR #{}\n", + b.name, b.open_pr_info["number"], fg="green", + ) + else: + cout("- Will create PR for branch {}\n", b.name) + pr_action = PR_CREATE + + if not push and pr_action == PR_NONE: + continue + actions.append((b, push, pr_action)) + + if actions and not force: + confirm() + + # Figure out prefix for branch (e.g. user:branch for forks) + val = run(CmdArgs(["git", "config", f"remote.{remote_name}.gh-resolved"]), check=False) + if val is not None and "/" in val: + val = run_always_return(CmdArgs(["git", "config", f"remote.{remote_name}.url"])) + prefix = f'{val.split(":")[1].split("/")[0]}:' + else: + prefix = "" + + muxed = False + for b, push, pr_action in actions: + if push: + if not muxed: + start_muxed_ssh(remote_name) + muxed = True + cout("Pushing {}\n", b.name, fg="green") + cmd_args = ["git", "push"] + if get_config().use_force_push: + cmd_args.append("-f") + cmd_args.extend([b.remote, "{}:{}".format(b.name, b.remote_branch)]) + run(CmdArgs(cmd_args), out=True) + if pr_action == PR_FIX_BASE: + cout("Fixing PR base for {}\n", b.name, fg="green") + assert b.open_pr_info is not None + run( + CmdArgs([ + "gh", "pr", "edit", str(b.open_pr_info["number"]), + "--base", b.parent.name, + ]), + out=True, + ) + elif pr_action == PR_CREATE: + create_gh_pr(b, prefix) + + # Handle stack comments for PRs + if pr and get_config().enable_stack_comment: + load_pr_info_for_forest(forest) + complete_forests_by_root = {} + branches_with_prs = [b for b in forest_depth_first(forest) if b.open_pr_info] + + for b in branches_with_prs: + root = b + while root.parent and root.parent.name not in STACK_BOTTOMS: + root = root.parent + root_name = root.name + if root_name not in complete_forests_by_root: + complete_forest = get_complete_stack_forest_for_branch(b) + load_pr_info_for_forest(complete_forest) + complete_forests_by_root[root_name] = complete_forest + + for b in branches_with_prs: + root = b + while root.parent and root.parent.name not in STACK_BOTTOMS: + root = root.parent + complete_forest = complete_forests_by_root[root.name] + add_or_update_stack_comment(b, complete_forest) + + stop_muxed_ssh(remote_name) + + +def get_branches_to_delete(forest: BranchesTreeForest) -> List[StackBranch]: + """Get branches that can be deleted (PRs merged).""" + deletes = [] + for b in forest_depth_first(forest): + if not b.parent or b.open_pr_info: + continue + for pr_info in b.pr_info.values(): + if pr_info["state"] != "MERGED": + continue + cout( + "- Will delete branch {}, PR #{} merged into {}\n", + b.name, pr_info["number"], b.parent.name, + ) + deletes.append(b) + for c in b.children: + cout("- Will reparent branch {} onto {}\n", c.name, b.parent.name) + break + return deletes + + +def delete_branches(stack: StackBranchSet, deletes: List[StackBranch]): + """Delete merged branches and reparent their children.""" + from stacky.git.refs import set_parent + + current_branch = get_current_branch_name() + for b in deletes: + for c in b.children: + info("Reparenting {} onto {}", c.name, b.parent.name) + c.parent = b.parent + set_parent(c.name, b.parent.name) + info("Deleting {}", b.name) + if b.name == current_branch: + new_branch = next(iter(stack.bottoms)) + info("About to delete current branch, switching to {}", new_branch.name) + run(CmdArgs(["git", "checkout", new_branch.name])) + set_current_branch(new_branch.name) + run(CmdArgs(["git", "branch", "-D", b.name])) + + +def cleanup_unused_refs(stack: StackBranchSet): + """Clean up refs for non-existent branches.""" + from stacky.git.refs import get_all_stack_parent_refs + + info("Cleaning up unused refs") + existing_branches = set(get_all_branches()) + + stack_bottoms = get_all_stack_bottoms() + for bottom in stack_bottoms: + if bottom not in stack.stack or bottom not in existing_branches: + ref = "refs/stacky-bottom-branch/{}".format(bottom) + info("Deleting ref {} (branch {} no longer exists)".format(ref, bottom)) + run(CmdArgs(["git", "update-ref", "-d", ref])) + + stack_parent_refs = get_all_stack_parent_refs() + for br in stack_parent_refs: + if br not in stack.stack or br not in existing_branches: + ref = "refs/stack-parent/{}".format(br) + old_value = run(CmdArgs(["git", "show-ref", ref]), check=False) + if old_value: + info("Deleting ref {} (branch {} no longer exists)".format(old_value, br)) + else: + info("Deleting ref refs/stack-parent/{} (branch {} no longer exists)".format(br, br)) + run(CmdArgs(["git", "update-ref", "-d", ref])) diff --git a/src/stacky/stack/tree.py b/src/stacky/stack/tree.py new file mode 100644 index 0000000..9417f39 --- /dev/null +++ b/src/stacky/stack/tree.py @@ -0,0 +1,193 @@ +"""Tree formatting and traversal for stacky stacks.""" + +from typing import Generator, List, TYPE_CHECKING + +from stacky.git.branch import get_current_branch_name +from stacky.utils.config import get_config +from stacky.utils.logging import COLOR_STDOUT, fmt +from stacky.utils.types import BranchesTree, BranchesTreeForest, BranchName, TreeNode + +if TYPE_CHECKING: + from stacky.stack.models import StackBranch, StackBranchSet + + +def get_pr_status_emoji(pr_info) -> str: + """Get the status emoji for a PR based on review state.""" + if not pr_info: + return "" + + review_decision = pr_info.get('reviewDecision') + review_requests = pr_info.get('reviewRequests', []) + is_draft = pr_info.get('isDraft', False) + + if is_draft: + # Draft PRs are waiting on author + return " 🚧" + elif review_decision == "APPROVED": + return " ✅" + elif review_requests and len(review_requests) > 0: + # Has pending review requests - waiting on review + return " 🔄" + else: + # No pending review requests, likely needs changes or author action + return " ❌" + + +def make_tree_node(b: "StackBranch") -> TreeNode: + """Create a tree node for a branch.""" + return (b.name, (b, make_subtree(b))) + + +def make_subtree(b: "StackBranch") -> BranchesTree: + """Create a subtree for a branch's children.""" + return BranchesTree(dict(make_tree_node(c) for c in sorted(b.children, key=lambda x: x.name))) + + +def make_tree(b: "StackBranch") -> BranchesTree: + """Create a tree rooted at a branch.""" + return BranchesTree(dict([make_tree_node(b)])) + + +def format_name(b: "StackBranch", *, colorize: bool) -> str: + """Format a branch name with status indicators.""" + current_branch = get_current_branch_name() + prefix = "" + severity = 0 + # TODO: Align things so that we have the same prefix length? + if not b.is_synced_with_parent(): + prefix += fmt("!", color=colorize, fg="yellow") + severity = max(severity, 2) + if not b.is_synced_with_remote(): + prefix += fmt("~", color=colorize, fg="yellow") + if b.name == current_branch: + prefix += fmt("*", color=colorize, fg="cyan") + else: + severity = max(severity, 1) + if prefix: + prefix += " " + fg = ["cyan", "green", "yellow", "red"][severity] + suffix = "" + if b.open_pr_info: + suffix += " " + # Make the PR info a clickable link + pr_url = b.open_pr_info["url"] + pr_number = b.open_pr_info["number"] + status_emoji = get_pr_status_emoji(b.open_pr_info) + + if get_config().compact_pr_display: + # Compact: just number and emoji + suffix += fmt("(\033]8;;{}\033\\#{}{}\033]8;;\033\\)", pr_url, pr_number, status_emoji, color=colorize, fg="blue") + else: + # Full: number, emoji, and title + pr_title = b.open_pr_info["title"] + suffix += fmt("(\033]8;;{}\033\\#{}{} {}\033]8;;\033\\)", pr_url, pr_number, status_emoji, pr_title, color=colorize, fg="blue") + return prefix + fmt("{}", b.name, color=colorize, fg=fg) + suffix + + +def format_tree(tree: BranchesTree, *, colorize: bool = False): + """Format a tree for display.""" + return { + format_name(branch, colorize=colorize): format_tree(children, colorize=colorize) + for branch, children in tree.values() + } + + +def print_tree(tree: BranchesTree): + """Print a tree (upside down to match upstack/downstack nomenclature).""" + from stacky.utils.ui import ASCII_TREE + s = ASCII_TREE(format_tree(tree, colorize=COLOR_STDOUT)) + lines = s.split("\n") + print("\n".join(reversed(lines))) + + +def print_forest(trees: List[BranchesTree]): + """Print multiple trees.""" + for i, t in enumerate(trees): + if i != 0: + print() + print_tree(t) + + +def forest_depth_first(forest: BranchesTreeForest) -> Generator["StackBranch", None, None]: + """Iterate over a forest in depth-first order.""" + for tree in forest: + for b in depth_first(tree): + yield b + + +def depth_first(tree: BranchesTree) -> Generator["StackBranch", None, None]: + """Iterate over a tree in depth-first order.""" + for _, (branch, children) in tree.items(): + yield branch + for b in depth_first(children): + yield b + + +def get_all_stacks_as_forest(stack: "StackBranchSet") -> BranchesTreeForest: + """Get all stacks as a forest.""" + return BranchesTreeForest([make_tree(b) for b in stack.bottoms]) + + +def get_current_stack_as_forest(stack: "StackBranchSet") -> BranchesTreeForest: + """Get the current stack as a forest.""" + current_branch = get_current_branch_name() + b = stack.stack[current_branch] + d: BranchesTree = make_tree(b) + b = b.parent + while b: + d = BranchesTree({b.name: (b, d)}) + b = b.parent + return [d] + + +def get_current_upstack_as_forest(stack: "StackBranchSet") -> BranchesTreeForest: + """Get the current upstack (current branch and above) as a forest.""" + current_branch = get_current_branch_name() + b = stack.stack[current_branch] + return BranchesTreeForest([make_tree(b)]) + + +def get_current_downstack_as_forest(stack: "StackBranchSet") -> BranchesTreeForest: + """Get the current downstack (current branch and below) as a forest.""" + current_branch = get_current_branch_name() + b = stack.stack[current_branch] + d: BranchesTree = BranchesTree({}) + while b: + d = BranchesTree({b.name: (b, d)}) + b = b.parent + return BranchesTreeForest([d]) + + +def get_bottom_level_branches_as_forest(stack: "StackBranchSet") -> BranchesTreeForest: + """Get bottom level branches (stack bottoms and their direct children) as a forest.""" + return BranchesTreeForest( + [ + BranchesTree( + { + bottom.name: ( + bottom, + BranchesTree({b.name: (b, BranchesTree({})) for b in bottom.children}), + ) + } + ) + for bottom in stack.bottoms + ] + ) + + +def load_pr_info_for_forest(forest: BranchesTreeForest): + """Load PR info for all branches in a forest.""" + for b in forest_depth_first(forest): + b.load_pr_info() + + +def get_complete_stack_forest_for_branch(branch: "StackBranch") -> BranchesTreeForest: + """Get the complete stack forest containing the given branch.""" + from stacky.utils.types import STACK_BOTTOMS + # Find the root of the stack + root = branch + while root.parent and root.parent.name not in STACK_BOTTOMS: + root = root.parent + + # Create a forest with just this root's complete tree + return BranchesTreeForest([make_tree(root)]) diff --git a/src/stacky/stacky.py b/src/stacky/stacky.py deleted file mode 100755 index 9da39aa..0000000 --- a/src/stacky/stacky.py +++ /dev/null @@ -1,1897 +0,0 @@ -#!/usr/bin/env python3 - -# GitHub helper for stacked diffs. -# -# Git maintains all metadata locally. Does everything by forking "git" and "gh" -# commands. -# -# Theory of operation: -# -# Each entry in a stack is a branch, set to track its parent (that is, `git -# config branch..remote` is ".", and `git config branch..merge` is -# "refs/heads/") -# -# For each branch, we maintain a ref (call it PC, for "parent commit") pointing -# to the commit at the tip of the parent branch, as `git update-ref -# refs/stack-parent/`. -# -# For all bottom branches we maintain a ref, labeling it a bottom_branch refs/stacky-bottom-branch/branch-name -# -# When rebasing or restacking, we proceed in depth-first order (from "master" -# onwards). After updating a parent branch P, given a child branch C, -# we rebase everything from C's PC until C's tip onto P. -# -# -# That's all there is to it. - -import configparser -import dataclasses -import json -import logging -import os -import re -import shlex -import subprocess -import sys -import time -from argparse import ArgumentParser -from typing import Dict, FrozenSet, Generator, List, NewType, Optional, Tuple, TypedDict, Union - -import asciitree # type: ignore -import colors # type: ignore -from simple_term_menu import TerminalMenu # type: ignore - -BranchName = NewType("BranchName", str) -PathName = NewType("PathName", str) -Commit = NewType("Commit", str) -CmdArgs = NewType("CmdArgs", List[str]) -StackSubTree = Tuple["StackBranch", "BranchesTree"] -TreeNode = Tuple[BranchName, StackSubTree] -BranchesTree = NewType("BranchesTree", Dict[BranchName, StackSubTree]) -BranchesTreeForest = NewType("BranchesTreeForest", List[BranchesTree]) - -JSON = Union[Dict[str, "JSON"], List["JSON"], str, int, float, bool, None] - - -class PRInfo(TypedDict): - id: str - number: int - state: str - mergeable: str - url: str - title: str - baseRefName: str - headRefName: str - commits: List[Dict[str, str]] - - -@dataclasses.dataclass -class PRInfos: - all: Dict[str, PRInfo] - open: Optional[PRInfo] - - -@dataclasses.dataclass -class BranchNCommit: - branch: BranchName - parent_commit: Optional[str] - - -_LOGGING_FORMAT = "%(asctime)s %(module)s %(levelname)s: %(message)s" - -# 2 minutes ought to be enough for anybody ;-) -MAX_SSH_MUX_LIFETIME = 120 -COLOR_STDOUT: bool = os.isatty(1) -COLOR_STDERR: bool = os.isatty(2) -IS_TERMINAL: bool = os.isatty(1) and os.isatty(2) -CURRENT_BRANCH: BranchName -STACK_BOTTOMS: set[BranchName] = set([BranchName("master"), BranchName("main")]) -FROZEN_STACK_BOTTOMS: FrozenSet[BranchName] = frozenset([BranchName("master"), BranchName("main")]) -STATE_FILE = os.path.expanduser("~/.stacky.state") -TMP_STATE_FILE = STATE_FILE + ".tmp" - -LOGLEVELS = { - "critical": logging.CRITICAL, - "error": logging.ERROR, - "warn": logging.WARNING, - "warning": logging.WARNING, - "info": logging.INFO, - "debug": logging.DEBUG, -} - - -@dataclasses.dataclass -class StackyConfig: - skip_confirm: bool = False - change_to_main: bool = False - change_to_adopted: bool = False - share_ssh_session: bool = False - use_merge: bool = False - use_force_push: bool = True - - def read_one_config(self, config_path: str): - rawconfig = configparser.ConfigParser() - rawconfig.read(config_path) - if rawconfig.has_section("UI"): - self.skip_confirm = bool(rawconfig.get("UI", "skip_confirm", fallback=self.skip_confirm)) - self.change_to_main = bool(rawconfig.get("UI", "change_to_main", fallback=self.change_to_main)) - self.change_to_adopted = bool(rawconfig.get("UI", "change_to_adopted", fallback=self.change_to_adopted)) - self.share_ssh_session = bool(rawconfig.get("UI", "share_ssh_session", fallback=self.share_ssh_session)) - - if rawconfig.has_section("GIT"): - self.use_merge = bool(rawconfig.get("GIT", "use_merge", fallback=self.use_merge)) - self.use_merge = bool(rawconfig.get("GIT", "use_force_push", fallback=self.use_force_push)) - - -CONFIG: Optional[StackyConfig] = None - - -def get_config() -> StackyConfig: - global CONFIG - if CONFIG is None: - CONFIG = read_config() - return CONFIG - - -def read_config() -> StackyConfig: - root_dir = get_top_level_dir() - config = StackyConfig() - config_paths = [f"{root_dir}/.stackyconfig", os.path.expanduser("~/.stackyconfig")] - - for p in config_paths: - if os.path.exists(p): - config.read_one_config(p) - - return config - - -def fmt(s: str, *args, color: bool = False, fg=None, bg=None, style=None, **kwargs) -> str: - s = colors.color(s, fg=fg, bg=bg, style=style) if color else s - return s.format(*args, **kwargs) - - -def cout(*args, **kwargs): - return sys.stdout.write(fmt(*args, color=COLOR_STDOUT, **kwargs)) - - -def _log(fn, *args, **kwargs): - return fn("%s", fmt(*args, color=COLOR_STDERR, **kwargs)) - - -def debug(*args, **kwargs): - return _log(logging.debug, *args, fg="green", **kwargs) - - -def info(*args, **kwargs): - return _log(logging.info, *args, fg="green", **kwargs) - - -def warning(*args, **kwargs): - return _log(logging.warning, *args, fg="yellow", **kwargs) - - -def error(*args, **kwargs): - return _log(logging.error, *args, fg="red", **kwargs) - - -class ExitException(BaseException): - def __init__(self, fmt, *args, **kwargs): - super().__init__(fmt.format(*args, **kwargs)) - - -def stop_muxed_ssh(remote: str = "origin"): - if get_config().share_ssh_session: - hostish = get_remote_type(remote) - if hostish is not None: - cmd = gen_ssh_mux_cmd() - cmd.append("-O") - cmd.append("exit") - cmd.append(hostish) - subprocess.Popen(cmd, stderr=subprocess.DEVNULL) - - -def die(*args, **kwargs): - # We are taking a wild guess at what is the remote ... - # TODO (mpatou) fix the assumption about the remote - stop_muxed_ssh() - raise ExitException(*args, **kwargs) - - -def _check_returncode(sp: subprocess.CompletedProcess, cmd: CmdArgs): - rc = sp.returncode - if rc == 0: - return - stderr = sp.stderr.decode("UTF-8") - if rc < 0: - die("Killed by signal {}: {}. Stderr was:\n{}", -rc, shlex.join(cmd), stderr) - else: - die("Exited with status {}: {}. Stderr was:\n{}", rc, shlex.join(cmd), stderr) - - -def run_multiline(cmd: CmdArgs, *, check: bool = True, null: bool = True, out: bool = False) -> Optional[str]: - debug("Running: {}", shlex.join(cmd)) - sys.stdout.flush() - sys.stderr.flush() - sp = subprocess.run( - cmd, - stdout=1 if out else subprocess.PIPE, - stderr=subprocess.PIPE, - ) - if check: - _check_returncode(sp, cmd) - rc = sp.returncode - if rc != 0: - return None - if sp.stdout is None: - return "" - return sp.stdout.decode("UTF-8") - - -def run_always_return(cmd: CmdArgs, **kwargs) -> str: - out = run(cmd, **kwargs) - assert out is not None - return out - - -def run(cmd: CmdArgs, **kwargs) -> Optional[str]: - out = run_multiline(cmd, **kwargs) - return None if out is None else out.strip() - - -def remove_prefix(s: str, prefix: str) -> str: - if not s.startswith(prefix): - die('Invalid string "{}": expected prefix "{}"', s, prefix) - return s[len(prefix) :] # noqa: E203 - - -def get_current_branch() -> Optional[BranchName]: - s = run(CmdArgs(["git", "symbolic-ref", "-q", "HEAD"])) - if s is not None: - return BranchName(remove_prefix(s, "refs/heads/")) - return None - - -def get_all_branches() -> List[BranchName]: - branches = run_multiline(CmdArgs(["git", "for-each-ref", "--format", "%(refname:short)", "refs/heads"])) - assert branches is not None - return [BranchName(b) for b in branches.split("\n") if b] - - -def get_real_stack_bottom() -> Optional[BranchName]: # type: ignore [return] - """ - return the actual stack bottom for this current repo - """ - branches = get_all_branches() - candiates = set() - for b in branches: - if b in STACK_BOTTOMS: - candiates.add(b) - - if len(candiates) == 1: - return candiates.pop() - - -def get_stack_parent_branch(branch: BranchName) -> Optional[BranchName]: # type: ignore [return] - if branch in STACK_BOTTOMS: - return None - p = run(CmdArgs(["git", "config", "branch.{}.merge".format(branch)]), check=False) - if p is not None: - p = remove_prefix(p, "refs/heads/") - if BranchName(p) == branch: - return None - return BranchName(p) - - -def get_top_level_dir() -> PathName: - p = run_always_return(CmdArgs(["git", "rev-parse", "--show-toplevel"])) - return PathName(p) - - -def get_stack_parent_commit(branch: BranchName) -> Optional[Commit]: # type: ignore [return] - c = run( - CmdArgs(["git", "rev-parse", "refs/stack-parent/{}".format(branch)]), - check=False, - ) - - if c is not None: - return Commit(c) - - -def get_commit(branch: BranchName) -> Commit: # type: ignore [return] - c = run_always_return(CmdArgs(["git", "rev-parse", "refs/heads/{}".format(branch)]), check=False) - return Commit(c) - - -def get_pr_info(branch: BranchName, *, full: bool = False) -> PRInfos: - fields = [ - "id", - "number", - "state", - "mergeable", - "url", - "title", - "baseRefName", - "headRefName", - ] - if full: - fields += ["commits"] - data = json.loads( - run_always_return( - CmdArgs( - [ - "gh", - "pr", - "list", - "--json", - ",".join(fields), - "--state", - "all", - "--head", - branch, - ] - ) - ) - ) - raw_infos: List[PRInfo] = data - - infos: Dict[str, PRInfo] = {info["id"]: info for info in raw_infos} - open_prs: List[PRInfo] = [info for info in infos.values() if info["state"] == "OPEN"] - if len(open_prs) > 1: - die( - "Branch {} has more than one open PR: {}", - branch, - ", ".join([str(pr) for pr in open_prs]), - ) # type: ignore[arg-type] - return PRInfos(infos, open_prs[0] if open_prs else None) - - -# (remote, remote_branch, remote_branch_commit) -def get_remote_info(branch: BranchName) -> Tuple[str, BranchName, Optional[Commit]]: - if branch not in STACK_BOTTOMS: - remote = run(CmdArgs(["git", "config", "branch.{}.remote".format(branch)]), check=False) - if remote != ".": - die("Misconfigured branch {}: remote {}", branch, remote) - - # TODO(tudor): Maybe add a way to change these. - remote = "origin" - remote_branch = branch - - remote_commit = run( - CmdArgs(["git", "rev-parse", "refs/remotes/{}/{}".format(remote, remote_branch)]), - check=False, - ) - - # TODO(mpatou): do something when remote_commit is none - commit = None - if remote_commit is not None: - commit = Commit(remote_commit) - - return (remote, BranchName(remote_branch), commit) - - -class StackBranch: - def __init__( - self, - name: BranchName, - parent: "StackBranch", - parent_commit: Commit, - ): - self.name = name - self.parent = parent - self.parent_commit = parent_commit - self.children: set["StackBranch"] = set() - self.commit = get_commit(name) - self.remote, self.remote_branch, self.remote_commit = get_remote_info(name) - self.pr_info: Dict[str, PRInfo] = {} - self.open_pr_info: Optional[PRInfo] = None - self._pr_info_loaded = False - - def is_synced_with_parent(self): - return self.parent is None or self.parent_commit == self.parent.commit - - def is_synced_with_remote(self): - return self.commit == self.remote_commit - - def __repr__(self): - return f"StackBranch: {self.name} {len(self.children)} {self.commit}" - - def load_pr_info(self): - if not self._pr_info_loaded: - self._pr_info_loaded = True - pr_infos = get_pr_info(self.name) - # FIXME maybe store the whole object and use it elsewhere - self.pr_info, self.open_pr_info = ( - pr_infos.all, - pr_infos.open, - ) - - -class StackBranchSet: - def __init__(self: "StackBranchSet"): - self.stack: Dict[BranchName, StackBranch] = {} - self.tops: set[StackBranch] = set() - self.bottoms: set[StackBranch] = set() - - def add(self, name: BranchName, **kwargs) -> StackBranch: - if name in self.stack: - s = self.stack[name] - assert s.name == name - for k, v in kwargs.items(): - if getattr(s, k) != v: - die( - "Mismatched stack: {}: {}={}, expected {}", - name, - k, - getattr(s, k), - v, - ) - else: - s = StackBranch(name, **kwargs) - self.stack[name] = s - if s.parent is None: - self.bottoms.add(s) - self.tops.add(s) - return s - - def addStackBranch(self, s: StackBranch): - if s.name not in self.stack: - self.stack[s.name] = s - if s.parent is None: - self.bottoms.add(s) - if len(s.children) == 0: - self.tops.add(s) - - return s - - def remove(self, name: BranchName) -> Optional[StackBranch]: - if name in self.stack: - s = self.stack[name] - assert s.name == name - del self.stack[name] - if s in self.tops: - self.tops.remove(s) - if s in self.bottoms: - self.bottoms.remove(s) - return s - - return None - - def __repr__(self) -> str: - out = f"StackBranchSet: {self.stack}" - return out - - def add_child(self, s: StackBranch, child: StackBranch): - s.children.add(child) - self.tops.discard(s) - - -def load_stack_for_given_branch( - stack: StackBranchSet, branch: BranchName, *, check: bool = True -) -> Tuple[Optional[StackBranch], List[BranchName]]: - """Given a stack of branch and a branch name, - update the stack with all the parents of the specified branch - if the branch is part of an existing stack. - Return also a list of BranchName of all the branch bellow the specified one - """ - branches: List[BranchNCommit] = [] - while branch not in STACK_BOTTOMS: - parent = get_stack_parent_branch(branch) - parent_commit = get_stack_parent_commit(branch) - branches.append(BranchNCommit(branch, parent_commit)) - if not parent or not parent_commit: - if check: - die("Branch is not in a stack: {}", branch) - return None, [b.branch for b in branches] - branch = parent - - branches.append(BranchNCommit(branch, None)) - top = None - for b in reversed(branches): - n = stack.add( - b.branch, - parent=top, - parent_commit=b.parent_commit, - ) - if top: - stack.add_child(top, n) - top = n - - return top, [b.branch for b in branches] - - -def get_branch_name_from_short_ref(ref: str) -> BranchName: - parts = ref.split("/", 1) - if len(parts) != 2: - die("invalid ref: {}".format(ref)) - - return BranchName(parts[1]) - - -def get_all_stack_bottoms() -> List[BranchName]: - branches = run_multiline( - CmdArgs(["git", "for-each-ref", "--format", "%(refname:short)", "refs/stacky-bottom-branch"]) - ) - if branches: - return [get_branch_name_from_short_ref(b) for b in branches.split("\n") if b] - return [] - - -def get_all_stack_parent_refs() -> List[BranchName]: - branches = run_multiline(CmdArgs(["git", "for-each-ref", "--format", "%(refname:short)", "refs/stack-parent"])) - if branches: - return [get_branch_name_from_short_ref(b) for b in branches.split("\n") if b] - return [] - - -def load_all_stack_bottoms(): - branches = run_multiline( - CmdArgs(["git", "for-each-ref", "--format", "%(refname:short)", "refs/stacky-bottom-branch"]) - ) - STACK_BOTTOMS.update(get_all_stack_bottoms()) - - -def load_all_stacks(stack: StackBranchSet) -> Optional[StackBranch]: - """Given a stack return the top of it, aka the bottom of the tree""" - load_all_stack_bottoms() - all_branches = set(get_all_branches()) - current_branch_top = None - while all_branches: - b = all_branches.pop() - top, branches = load_stack_for_given_branch(stack, b, check=False) - all_branches -= set(branches) - if top is None: - if len(branches) > 1: - # Incomplete (broken) stack - warning("Broken stack: {}", " -> ".join(branches)) - continue - if b == CURRENT_BRANCH: - current_branch_top = top - return current_branch_top - - -def make_tree_node(b: StackBranch) -> TreeNode: - return (b.name, (b, make_subtree(b))) - - -def make_subtree(b) -> BranchesTree: - return BranchesTree(dict(make_tree_node(c) for c in sorted(b.children, key=lambda x: x.name))) - - -def make_tree(b: StackBranch) -> BranchesTree: - return BranchesTree(dict([make_tree_node(b)])) - - -def format_name(b: StackBranch, *, colorize: bool) -> str: - prefix = "" - severity = 0 - # TODO: Align things so that we have the same prefix length ? - if not b.is_synced_with_parent(): - prefix += fmt("!", color=colorize, fg="yellow") - severity = max(severity, 2) - if not b.is_synced_with_remote(): - prefix += fmt("~", color=colorize, fg="yellow") - if b.name == CURRENT_BRANCH: - prefix += fmt("*", color=colorize, fg="cyan") - else: - severity = max(severity, 1) - if prefix: - prefix += " " - fg = ["cyan", "green", "yellow", "red"][severity] - suffix = "" - if b.open_pr_info: - suffix += " " - suffix += fmt("(#{})", b.open_pr_info["number"], color=colorize, fg="blue") - suffix += " " - suffix += fmt("{}", b.open_pr_info["title"], color=colorize, fg="blue") - return prefix + fmt("{}", b.name, color=colorize, fg=fg) + suffix - - -def format_tree(tree: BranchesTree, *, colorize: bool = False): - return { - format_name(branch, colorize=colorize): format_tree(children, colorize=colorize) - for branch, children in tree.values() - } - - -# Print upside down, to match our "upstack" / "downstack" nomenclature -_ASCII_TREE_BOX = { - "UP_AND_RIGHT": "\u250c", - "HORIZONTAL": "\u2500", - "VERTICAL": "\u2502", - "VERTICAL_AND_RIGHT": "\u251c", -} -_ASCII_TREE_STYLE = asciitree.drawing.BoxStyle(gfx=_ASCII_TREE_BOX) -ASCII_TREE = asciitree.LeftAligned(draw=_ASCII_TREE_STYLE) - - -def print_tree(tree: BranchesTree): - global ASCII_TREE - s = ASCII_TREE(format_tree(tree, colorize=COLOR_STDOUT)) - lines = s.split("\n") - print("\n".join(reversed(lines))) - - -def print_forest(trees: List[BranchesTree]): - for i, t in enumerate(trees): - if i != 0: - print() - print_tree(t) - - -def get_all_stacks_as_forest(stack: StackBranchSet) -> BranchesTreeForest: - return BranchesTreeForest([make_tree(b) for b in stack.bottoms]) - - -def get_current_stack_as_forest(stack: StackBranchSet): - b = stack.stack[CURRENT_BRANCH] - d: BranchesTree = make_tree(b) - b = b.parent - while b: - d = BranchesTree({b.name: (b, d)}) - b = b.parent - return [d] - - -def get_current_upstack_as_forest(stack: StackBranchSet) -> BranchesTreeForest: - b = stack.stack[CURRENT_BRANCH] - return BranchesTreeForest([make_tree(b)]) - - -def get_current_downstack_as_forest(stack: StackBranchSet) -> BranchesTreeForest: - b = stack.stack[CURRENT_BRANCH] - d: BranchesTree = BranchesTree({}) - while b: - d = BranchesTree({b.name: (b, d)}) - b = b.parent - return BranchesTreeForest([d]) - - -def init_git(): - push_default = run(["git", "config", "remote.pushDefault"], check=False) - if push_default is not None: - die("`git config remote.pushDefault` may not be set") - auth_status = run(["gh", "auth", "status"], check=False) - if auth_status is None: - die("`gh` authentication failed") - global CURRENT_BRANCH - CURRENT_BRANCH = get_current_branch() - - -def forest_depth_first( - forest: BranchesTreeForest, -) -> Generator[StackBranch, None, None]: - for tree in forest: - for b in depth_first(tree): - yield b - - -def depth_first(tree: BranchesTree) -> Generator[StackBranch, None, None]: - # This is for the regular forest - for _, (branch, children) in tree.items(): - yield branch - for b in depth_first(children): - yield b - - -def menu_choose_branch(forest: BranchesTreeForest): - if not IS_TERMINAL: - die("May only choose from menu when using a terminal") - - global ASCII_TREE - s = "" - lines = [] - for tree in forest: - s = ASCII_TREE(format_tree(tree)) - lines += [l.rstrip() for l in s.split("\n")] - lines.reverse() - - initial_index = 0 - for i, l in enumerate(lines): - if "*" in l: # lol - initial_index = i - break - - menu = TerminalMenu(lines, cursor_index=initial_index) - idx = menu.show() - if idx is None: - die("Aborted") - - branches = list(forest_depth_first(forest)) - branches.reverse() - return branches[idx] - - -def load_pr_info_for_forest(forest: BranchesTreeForest): - for b in forest_depth_first(forest): - b.load_pr_info() - - -def cmd_info(stack: StackBranchSet, args): - forest = get_all_stacks_as_forest(stack) - if args.pr: - load_pr_info_for_forest(forest) - print_forest(forest) - - -def checkout(branch): - info("Checking out branch {}", branch) - run(["git", "checkout", branch], out=True) - - -def cmd_branch_up(stack: StackBranchSet, args): - b = stack.stack[CURRENT_BRANCH] - if not b.children: - info("Branch {} is already at the top of the stack", CURRENT_BRANCH) - return - if len(b.children) > 1: - if not IS_TERMINAL: - die( - "Branch {} has multiple children: {}", - CURRENT_BRANCH, - ", ".join(c.name for c in b.children), - ) - cout( - "Branch {} has {} children, choose one\n", - CURRENT_BRANCH, - len(b.children), - fg="green", - ) - forest = BranchesTreeForest([BranchesTree({BranchName(c.name): (c, BranchesTree({}))}) for c in b.children]) - child = menu_choose_branch(forest).name - else: - child = next(iter(b.children)).name - checkout(child) - - -def cmd_branch_down(stack: StackBranchSet, args): - b = stack.stack[CURRENT_BRANCH] - if not b.parent: - info("Branch {} is already at the bottom of the stack", CURRENT_BRANCH) - return - checkout(b.parent.name) - - -def create_branch(branch): - run(["git", "checkout", "-b", branch, "--track"], out=True) - - -def cmd_branch_new(stack: StackBranchSet, args): - b = stack.stack[CURRENT_BRANCH] - assert b.commit - name = args.name - create_branch(name) - run(CmdArgs(["git", "update-ref", "refs/stack-parent/{}".format(name), b.commit, ""])) - - -def cmd_branch_checkout(stack: StackBranchSet, args): - branch_name = args.name - if branch_name is None: - forest = get_all_stacks_as_forest(stack) - branch_name = menu_choose_branch(forest).name - checkout(branch_name) - - -def cmd_stack_info(stack: StackBranchSet, args): - forest = get_current_stack_as_forest(stack) - if args.pr: - load_pr_info_for_forest(forest) - print_forest(forest) - - -def cmd_stack_checkout(stack: StackBranchSet, args): - forest = get_current_stack_as_forest(stack) - branch_name = menu_choose_branch(forest).name - checkout(branch_name) - - -def prompt(message: str, default_value: Optional[str]) -> str: - cout(message) - if default_value is not None: - cout("({})", default_value, fg="gray") - cout(" ") - while True: - sys.stderr.flush() - r = input().strip() - - if len(r) > 0: - return r - if default_value: - return default_value - - -def confirm(msg: str = "Proceed?"): - if get_config().skip_confirm: - return - if not os.isatty(0): - die("Standard input is not a terminal, use --force option to force action") - print() - while True: - cout("{} [yes/no] ", msg, fg="yellow") - sys.stderr.flush() - r = input().strip().lower() - if r == "yes": - break - if r == "no": - die("Not confirmed") - cout("Please answer yes or no\n", fg="red") - - -def find_reviewers(b: StackBranch) -> Optional[List[str]]: - out = run_multiline( - CmdArgs( - [ - "git", - "log", - "--pretty=format:%b", - "-1", - f"{b.name}", - ] - ), - ) - assert out is not None - for l in out.split("\n"): - reviewer_match = re.match(r"^reviewers?\s*:\s*(.*)", l, re.I) - if reviewer_match: - reviewers = reviewer_match.group(1).split(",") - logging.debug(f"Found the following reviewers: {', '.join(reviewers)}") - return reviewers - return None - - -def find_issue_marker(name: str) -> Optional[str]: - match = re.search(r"(?:^|[_-])([A-Z]{3,}[_-]?\d{2,})($|[_-].*)", name) - if match: - res = match.group(1) - if "_" in res: - return res.replace("_", "-") - if not "-" in res: - newmatch = re.match(r"(...)(\d+)", res) - assert newmatch is not None - return f"{newmatch.group(1)}-{newmatch.group(2)}" - return res - - return None - - -def create_gh_pr(b: StackBranch, prefix: str): - cout("Creating PR for {}\n", b.name, fg="green") - parent_prefix = "" - if b.parent.name not in STACK_BOTTOMS: - # you are pushing a sub stack, there is no way we can make it work - # accross repo so we will push within your own clone - prefix = "" - cmd = [ - "gh", - "pr", - "create", - "--head", - f"{prefix}{b.name}", - "--base", - f"{parent_prefix}{b.parent.name}", - ] - reviewers = find_reviewers(b) - issue_id = find_issue_marker(b.name) - if issue_id: - out = run_multiline( - CmdArgs(["git", "log", "--pretty=oneline", f"{b.parent.name}..{b.name}"]), - ) - title = f"[{issue_id}] " - # Just one line (hence 2 elements with the last one being an empty string when we - # split on "\"n ? - # Then use the title of the commit as the title of the PR - if out is not None and len(out.split("\n")) == 2: - out = run( - CmdArgs( - [ - "git", - "log", - "--pretty=format:%s", - "-1", - f"{b.name}", - ] - ), - out=False, - ) - if out is None: - out = "" - if b.name not in out: - title += out - else: - title = out - - title = prompt( - (fmt("? ", color=COLOR_STDOUT, fg="green") + fmt("Title ", color=COLOR_STDOUT, style="bold", fg="white")), - title, - ) - cmd.extend(["--title", title.strip()]) - if reviewers: - logging.debug(f"Adding {len(reviewers)} reviewer(s) to the review") - for r in reviewers: - r = r.strip() - r = r.replace("#", "rockset/") - if len(r) > 0: - cmd.extend(["--reviewer", r]) - - run( - CmdArgs(cmd), - out=True, - ) - - -def do_push( - forest: BranchesTreeForest, - *, - force: bool = False, - pr: bool = False, - remote_name: str = "origin", -): - if pr: - load_pr_info_for_forest(forest) - print_forest(forest) - for b in forest_depth_first(forest): - if not b.is_synced_with_parent(): - die( - "Branch {} is not synced with parent {}, sync first", - b.name, - b.parent.name, - ) - - # (branch, push, pr_action) - PR_NONE = 0 - PR_FIX_BASE = 1 - PR_CREATE = 2 - actions = [] - for b in forest_depth_first(forest): - if not b.parent: - cout("✓ Not pushing base branch {}\n", b.name, fg="green") - continue - - push = False - if b.is_synced_with_remote(): - cout( - "✓ Not pushing branch {}, synced with remote {}/{}\n", - b.name, - b.remote, - b.remote_branch, - fg="green", - ) - else: - cout("- Will push branch {} to {}/{}\n", b.name, b.remote, b.remote_branch) - push = True - - pr_action = PR_NONE - if pr: - if b.open_pr_info: - expected_base = b.parent.name - if b.open_pr_info["baseRefName"] != expected_base: - cout( - "- Branch {} already has open PR #{}; will change PR base from {} to {}\n", - b.name, - b.open_pr_info["number"], - b.open_pr_info["baseRefName"], - expected_base, - ) - pr_action = PR_FIX_BASE - else: - cout( - "✓ Branch {} already has open PR #{}\n", - b.name, - b.open_pr_info["number"], - fg="green", - ) - else: - cout("- Will create PR for branch {}\n", b.name) - pr_action = PR_CREATE - - if not push and pr_action == PR_NONE: - continue - - actions.append((b, push, pr_action)) - - if actions and not force: - confirm() - - # Figure out if we need to add a prefix to the branch - # ie. user:foo - # We should call gh repo set-default before doing that - val = run(CmdArgs(["git", "config", f"remote.{remote_name}.gh-resolved"]), check=False) - if val is not None and "/" in val: - # If there is a "/" in the gh-resolved it means that the repo where - # the should be created is not the same as the one where the push will - # be made, we need to add a prefix to the branch in the gh pr command - val = run_always_return(CmdArgs(["git", "config", f"remote.{remote_name}.url"])) - prefix = f'{val.split(":")[1].split("/")[0]}:' - else: - prefix = "" - muxed = False - for b, push, pr_action in actions: - if push: - if not muxed: - start_muxed_ssh(remote_name) - muxed = True - # Try to run pre-push before muxing ... - # To do so we need to pickup the current commit of the branch, the branch name, the - # parent branch and it's parent commit and call .git/hooks/pre-push - cout("Pushing {}\n", b.name, fg="green") - run( - CmdArgs( - [ - "git", - "push", - "-f" if get_config().use_force_push else "", - b.remote, - "{}:{}".format(b.name, b.remote_branch), - ] - ), - out=True, - ) - if pr_action == PR_FIX_BASE: - cout("Fixing PR base for {}\n", b.name, fg="green") - assert b.open_pr_info is not None - run( - CmdArgs( - [ - "gh", - "pr", - "edit", - str(b.open_pr_info["number"]), - "--base", - b.parent.name, - ] - ), - out=True, - ) - elif pr_action == PR_CREATE: - create_gh_pr(b, prefix) - - stop_muxed_ssh(remote_name) - - -def cmd_stack_push(stack: StackBranchSet, args): - do_push( - get_current_stack_as_forest(stack), - force=args.force, - pr=args.pr, - remote_name=args.remote_name, - ) - - -def do_sync(forest: BranchesTreeForest): - print_forest(forest) - - syncs: List[StackBranch] = [] - sync_names: List[BranchName] = [] - syncs_set: set[StackBranch] = set() - for b in forest_depth_first(forest): - if not b.parent: - cout("✓ Not syncing base branch {}\n", b.name, fg="green") - continue - if b.is_synced_with_parent() and not b.parent in syncs_set: - cout( - "✓ Not syncing branch {}, already synced with parent {}\n", - b.name, - b.parent.name, - fg="green", - ) - continue - syncs.append(b) - syncs_set.add(b) - sync_names.append(b.name) - cout("- Will sync branch {} on top of {}\n", b.name, b.parent.name) - - if not syncs: - return - - syncs.reverse() - sync_names.reverse() - # TODO: use list(syncs_set).reverse() ? - inner_do_sync(syncs, sync_names) - - -def set_parent_commit(branch: BranchName, new_commit: Commit, prev_commit: Optional[str] = None): - cmd = [ - "git", - "update-ref", - "refs/stack-parent/{}".format(branch), - new_commit, - ] - if prev_commit is not None: - cmd.append(prev_commit) - run(CmdArgs(cmd)) - - -def get_commits_between(a: Commit, b: Commit): - lines = run_multiline(CmdArgs(["git", "rev-list", "{}..{}".format(a, b)])) - assert lines is not None - return [x.strip() for x in lines.split("\n")] - - -def inner_do_sync(syncs: List[StackBranch], sync_names: List[BranchName]): - print() - sync_type = "merge" if get_config().use_merge else "rebase" - while syncs: - with open(TMP_STATE_FILE, "w") as f: - json.dump({"branch": CURRENT_BRANCH, "sync": sync_names}, f) - os.replace(TMP_STATE_FILE, STATE_FILE) # make the write atomic - - b = syncs.pop() - sync_names.pop() - if b.is_synced_with_parent(): - cout("{} is already synced on top of {}\n", b.name, b.parent.name) - continue - if b.parent.commit in get_commits_between(b.parent_commit, b.commit): - cout( - "Recording complete {} of {} on top of {}\n", - sync_type, - b.name, - b.parent.name, - fg="green", - ) - else: - r = None - if get_config().use_merge: - cout("Merging {} into {}\n", b.parent.name, b.name, fg="green") - run(CmdArgs(["git", "checkout", str(b.name)])) - r = run( - CmdArgs(["git", "merge", b.parent.name]), - out=True, - check=False, - ) - else: - cout("Rebasing {} on top of {}\n", b.name, b.parent.name, fg="green") - r = run( - CmdArgs(["git", "rebase", "--onto", b.parent.name, b.parent_commit, b.name]), - out=True, - check=False, - ) - - if r is None: - print() - die( - "Automatic {0} failed. Please complete the {0} (fix conflicts; `git {0} --continue`), then run `stacky continue`".format( - sync_type - ) - ) - b.commit = get_commit(b.name) - set_parent_commit(b.name, b.parent.commit, b.parent_commit) - b.parent_commit = b.parent.commit - run(CmdArgs(["git", "checkout", str(CURRENT_BRANCH)])) - - -def cmd_stack_sync(stack: StackBranchSet, args): - do_sync(get_current_stack_as_forest(stack)) - - -def do_commit(stack: StackBranchSet, *, message=None, amend=False, allow_empty=False, edit=True): - b = stack.stack[CURRENT_BRANCH] - if not b.parent: - die("Do not commit directly on {}", b.name) - if not b.is_synced_with_parent(): - die( - "Branch {} is not synced with parent {}, sync before committing", - b.name, - b.parent.name, - ) - - if amend and (get_config().use_merge or not get_config().use_force_push): - die("Amending is not allowed if using git merge or if force pushing is disallowed") - - if amend and b.commit == b.parent.commit: - die("Branch {} has no commits, may not amend", b.name) - - cmd = ["git", "commit"] - if allow_empty: - cmd += ["--allow-empty"] - if amend: - cmd += ["--amend"] - if not edit: - cmd += ["--no-edit"] - elif not edit: - die("--no-edit is only supported with --amend") - if message: - cmd += ["-m", message] - run(CmdArgs(cmd), out=True) - - # Sync everything upstack - b.commit = get_commit(b.name) - do_sync(get_current_upstack_as_forest(stack)) - - -def cmd_commit(stack: StackBranchSet, args): - do_commit( - stack, - message=args.message, - amend=args.amend, - allow_empty=args.allow_empty, - edit=not args.no_edit, - ) - - -def cmd_amend(stack: StackBranchSet, args): - do_commit(stack, amend=True, edit=False) - - -def cmd_upstack_info(stack: StackBranchSet, args): - forest = get_current_upstack_as_forest(stack) - if args.pr: - load_pr_info_for_forest(forest) - print_forest(forest) - - -def cmd_upstack_push(stack: StackBranchSet, args): - do_push( - get_current_upstack_as_forest(stack), - force=args.force, - pr=args.pr, - remote_name=args.remote_name, - ) - - -def cmd_upstack_sync(stack: StackBranchSet, args): - do_sync(get_current_upstack_as_forest(stack)) - - -def set_parent(branch: BranchName, target: Optional[BranchName], *, set_origin: bool = False): - if set_origin: - run(CmdArgs(["git", "config", "branch.{}.remote".format(branch), "."])) - - ## If target is none this becomes a new stack bottom - run( - CmdArgs( - [ - "git", - "config", - "branch.{}.merge".format(branch), - "refs/heads/{}".format(target if target is not None else branch), - ] - ) - ) - - if target is None: - run( - CmdArgs( - [ - "git", - "update-ref", - "-d", - "refs/stack-parent/{}".format(branch), - ] - ) - ) - - -def cmd_upstack_onto(stack: StackBranchSet, args): - b = stack.stack[CURRENT_BRANCH] - if not b.parent: - die("may not upstack a stack bottom, use stacky adopt") - target = stack.stack[args.target] - upstack = get_current_upstack_as_forest(stack) - for ub in forest_depth_first(upstack): - if ub == target: - die("Target branch {} is upstack of {}", target.name, b.name) - b.parent = target - set_parent(b.name, target.name) - - do_sync(upstack) - - -def cmd_upstack_as_base(stack: StackBranchSet): - b = stack.stack[CURRENT_BRANCH] - if not b.parent: - die("Branch {} is already a stack bottom", b.name) - - b.parent = None # type: ignore - stack.remove(b.name) - stack.addStackBranch(b) - set_parent(b.name, None) - - run(CmdArgs(["git", "update-ref", "refs/stacky-bottom-branch/{}".format(b.name), b.commit, ""])) - info("Set {} as new bottom branch".format(b.name)) - - -def cmd_upstack_as(stack: StackBranchSet, args): - if args.target == "bottom": - cmd_upstack_as_base(stack) - else: - die("Invalid target {}, acceptable targets are [base]", args.target) - - -def cmd_downstack_info(stack, args): - forest = get_current_downstack_as_forest(stack) - if args.pr: - load_pr_info_for_forest(forest) - print_forest(forest) - - -def cmd_downstack_push(stack: StackBranchSet, args): - do_push( - get_current_downstack_as_forest(stack), - force=args.force, - pr=args.pr, - remote_name=args.remote_name, - ) - - -def cmd_downstack_sync(stack: StackBranchSet, args): - do_sync(get_current_downstack_as_forest(stack)) - - -def get_bottom_level_branches_as_forest(stack: StackBranchSet) -> BranchesTreeForest: - return BranchesTreeForest( - [ - BranchesTree( - { - bottom.name: ( - bottom, - BranchesTree({b.name: (b, BranchesTree({})) for b in bottom.children}), - ) - } - ) - for bottom in stack.bottoms - ] - ) - - -def get_remote_type(remote: str = "origin") -> Optional[str]: - out = run_always_return(CmdArgs(["git", "remote", "-v"])) - for l in out.split("\n"): - match = re.match(r"^{}\s+(?:ssh://)?([^/]*):(?!//).*\s+\(push\)$".format(remote), l) - if match: - sshish_host = match.group(1) - return sshish_host - - return None - - -def gen_ssh_mux_cmd() -> List[str]: - args = [ - "ssh", - "-o", - "ControlMaster=auto", - "-o", - f"ControlPersist={MAX_SSH_MUX_LIFETIME}", - "-o", - "ControlPath=~/.ssh/stacky-%C", - ] - - return args - - -def start_muxed_ssh(remote: str = "origin"): - if not get_config().share_ssh_session: - return - hostish = get_remote_type(remote) - if hostish is not None: - info("Creating a muxed ssh connection") - cmd = gen_ssh_mux_cmd() - os.environ["GIT_SSH_COMMAND"] = " ".join(cmd) - cmd.append("-MNf") - cmd.append(hostish) - # We don't want to use the run() wrapper because - # we don't want to wait for the process to finish - - p = subprocess.Popen(cmd, stderr=subprocess.PIPE) - # Wait a little bit for the connection to establish - # before carrying on - while p.poll() is None: - time.sleep(1) - if p.returncode != 0: - if p.stderr is not None: - error = p.stderr.read() - else: - error = b"unknown" - die(f"Failed to start ssh muxed connection, error was: {error.decode('utf-8').strip()}") - - -def get_branches_to_delete(forest: BranchesTreeForest) -> List[StackBranch]: - deletes = [] - for b in forest_depth_first(forest): - if not b.parent or b.open_pr_info: - continue - for pr_info in b.pr_info.values(): - if pr_info["state"] != "MERGED": - continue - cout( - "- Will delete branch {}, PR #{} merged into {}\n", - b.name, - pr_info["number"], - b.parent.name, - ) - deletes.append(b) - for c in b.children: - cout( - "- Will reparent branch {} onto {}\n", - c.name, - b.parent.name, - ) - break - return deletes - - -def delete_branches(stack: StackBranchSet, deletes: List[StackBranch]): - global CURRENT_BRANCH - # Make sure we're not trying to delete the current branch - for b in deletes: - for c in b.children: - info("Reparenting {} onto {}", c.name, b.parent.name) - c.parent = b.parent - set_parent(c.name, b.parent.name) - info("Deleting {}", b.name) - if b.name == CURRENT_BRANCH: - new_branch = next(iter(stack.bottoms)) - info("About to delete current branch, switching to {}", new_branch.name) - run(CmdArgs(["git", "checkout", new_branch.name])) - CURRENT_BRANCH = new_branch.name - run(CmdArgs(["git", "branch", "-D", b.name])) - - -def cleanup_unused_refs(stack: StackBranchSet): - # Clean up stacky bottom branch refs - info("Cleaning up unused refs") - stack_bottoms = get_all_stack_bottoms() - for bottom in stack_bottoms: - if not bottom in stack.stack: - ref = "refs/stacky-bottom-branch/{}".format(bottom) - info("Deleting ref {}".format(ref)) - run(CmdArgs(["git", "update-ref", "-d", ref])) - - stack_parent_refs = get_all_stack_parent_refs() - for br in stack_parent_refs: - if br not in stack.stack: - ref = "refs/stack-parent/{}".format(br) - old_value = run(CmdArgs(["git", "show-ref", ref])) - info("Deleting ref {}".format(old_value)) - run(CmdArgs(["git", "update-ref", "-d", ref])) - - -def cmd_update(stack: StackBranchSet, args): - remote = "origin" - start_muxed_ssh(remote) - info("Fetching from {}", remote) - run(CmdArgs(["git", "fetch", remote])) - - # TODO(tudor): We should rebase instead of silently dropping - # everything you have on local master. Oh well. - global CURRENT_BRANCH - for b in stack.bottoms: - run( - CmdArgs( - [ - "git", - "update-ref", - "refs/heads/{}".format(b.name), - "refs/remotes/{}/{}".format(remote, b.remote_branch), - ] - ) - ) - if b.name == CURRENT_BRANCH: - run(CmdArgs(["git", "reset", "--hard", "HEAD"])) - - # We treat origin as the source of truth for bottom branches (master), and - # the local repo as the source of truth for everything else. So we can only - # track PR closure for branches that are direct descendants of master. - - info("Checking if any PRs have been merged and can be deleted") - forest = get_bottom_level_branches_as_forest(stack) - load_pr_info_for_forest(forest) - - deletes = get_branches_to_delete(forest) - if deletes and not args.force: - confirm() - - delete_branches(stack, deletes) - stop_muxed_ssh(remote) - - cleanup_unused_refs(stack) - - -def cmd_import(stack: StackBranchSet, args): - # Importing has to happen based on PR info, rather than local branch - # relationships, as that's the only place Graphite populates. - branch = args.name - branches = [] - bottoms = set(b.name for b in stack.bottoms) - while branch not in bottoms: - pr_info = get_pr_info(branch, full=True) - open_pr = pr_info.open - info("Getting PR information for {}", branch) - if open_pr is None: - die("Branch {} has no open PR", branch) - # Never reached because the die but makes mypy happy - assert open_pr is not None - if open_pr["headRefName"] != branch: - die( - "Branch {} is misconfigured: PR #{} head is {}", - branch, - open_pr["number"], - open_pr["headRefName"], - ) - if not open_pr["commits"]: - die("PR #{} has no commits", open_pr["number"]) - first_commit = open_pr["commits"][0]["oid"] - parent_commit = Commit(run_always_return(CmdArgs(["git", "rev-parse", "{}^".format(first_commit)]))) - next_branch = open_pr["baseRefName"] - info( - "Branch {}: PR #{}, parent is {} at commit {}", - branch, - open_pr["number"], - next_branch, - parent_commit, - ) - branches.append((branch, parent_commit)) - branch = next_branch - - if not branches: - return - - base_branch = branch - branches.reverse() - - for b, parent_commit in branches: - cout( - "- Will set parent of {} to {} at commit {}\n", - b, - branch, - parent_commit, - ) - branch = b - - if not args.force: - confirm() - - branch = base_branch - for b, parent_commit in branches: - set_parent(b, branch, set_origin=True) - set_parent_commit(b, parent_commit) - branch = b - - -def get_merge_base(b1: BranchName, b2: BranchName): - return run(CmdArgs(["git", "merge-base", str(b1), str(b2)])) - - -def cmd_adopt(stack: StackBranch, args): - """ - Adopt a branch that is based on the current branch (which must be a - valid stack bottom or the stack bottom (master or main) will be used - if change_to_main option is set in the config file - """ - branch = args.name - global CURRENT_BRANCH - - if branch == CURRENT_BRANCH: - die("A branch cannot adopt itself") - - if CURRENT_BRANCH not in STACK_BOTTOMS: - # TODO remove that, the initialisation code is already dealing with that in fact - main_branch = get_real_stack_bottom() - - if get_config().change_to_main and main_branch is not None: - run(CmdArgs(["git", "checkout", main_branch])) - CURRENT_BRANCH = main_branch - else: - die( - "The current branch {} must be a valid stack bottom: {}", - CURRENT_BRANCH, - ", ".join(sorted(STACK_BOTTOMS)), - ) - if branch in STACK_BOTTOMS: - if branch in FROZEN_STACK_BOTTOMS: - die("Cannot adopt frozen stack bottoms {}".format(FROZEN_STACK_BOTTOMS)) - # Remove the ref that this is a stack bottom - run(CmdArgs(["git", "update-ref", "-d", "refs/stacky-bottom-branch/{}".format(branch)])) - - parent_commit = get_merge_base(CURRENT_BRANCH, branch) - set_parent(branch, CURRENT_BRANCH, set_origin=True) - set_parent_commit(branch, parent_commit) - if get_config().change_to_adopted: - run(CmdArgs(["git", "checkout", branch])) - - -def cmd_land(stack: StackBranchSet, args): - forest = get_current_downstack_as_forest(stack) - assert len(forest) == 1 - branches = [] - p = forest[0] - while p: - assert len(p) == 1 - _, (b, p) = next(iter(p.items())) - branches.append(b) - assert branches - assert branches[0] in stack.bottoms - if len(branches) == 1: - die("May not land {}", branches[0].name) - - b = branches[1] - if not b.is_synced_with_parent(): - die( - "Branch {} is not synced with parent {}, sync before landing", - b.name, - b.parent.name, - ) - if not b.is_synced_with_remote(): - die( - "Branch {} is not synced with remote branch, push local changes before landing", - b.name, - ) - - b.load_pr_info() - pr = b.open_pr_info - if not pr: - die("Branch {} does not have an open PR", b.name) - assert pr is not None - - if pr["mergeable"] != "MERGEABLE": - die( - "PR #{} for branch {} is not mergeable: {}", - pr["number"], - b.name, - pr["mergeable"], - ) - - if len(branches) > 2: - cout( - "The `land` command only lands the bottom-most branch {}; the current stack has {} branches, ending with {}\n", - b.name, - len(branches) - 1, - CURRENT_BRANCH, - fg="yellow", - ) - - msg = fmt("- Will land PR #{} (", pr["number"], color=COLOR_STDOUT) - msg += fmt("{}", pr["url"], color=COLOR_STDOUT, fg="blue") - msg += fmt(") for branch {}", b.name, color=COLOR_STDOUT) - msg += fmt(" into branch {}\n", b.parent.name, color=COLOR_STDOUT) - sys.stdout.write(msg) - - if not args.force: - confirm() - - v = run(CmdArgs(["git", "rev-parse", b.name])) - assert v is not None - head_commit = Commit(v) - cmd = CmdArgs(["gh", "pr", "merge", b.name, "--squash", "--match-head-commit", head_commit]) - if args.auto: - cmd.append("--auto") - run(cmd, out=True) - cout("\n✓ Success! Run `stacky update` to update local state.\n", fg="green") - - -def main(): - logging.basicConfig(format=_LOGGING_FORMAT, level=logging.INFO) - try: - parser = ArgumentParser(description="Handle git stacks") - parser.add_argument( - "--log-level", - default="info", - choices=LOGLEVELS.keys(), - help="Set the log level", - ) - parser.add_argument( - "--color", - default="auto", - choices=["always", "auto", "never"], - help="Colorize output and error", - ) - parser.add_argument( - "--remote-name", - "-r", - default="origin", - help="name of the git remote where branches will be pushed", - ) - - subparsers = parser.add_subparsers(required=True, dest="command") - - # continue - continue_parser = subparsers.add_parser("continue", help="Continue previously interrupted command") - continue_parser.set_defaults(func=None) - - # down - down_parser = subparsers.add_parser("down", help="Go down in the current stack (towards master/main)") - down_parser.set_defaults(func=cmd_branch_down) - # up - up_parser = subparsers.add_parser("up", help="Go up in the current stack (away master/main)") - up_parser.set_defaults(func=cmd_branch_up) - # info - info_parser = subparsers.add_parser("info", help="Stack info") - info_parser.add_argument("--pr", action="store_true", help="Get PR info (slow)") - info_parser.set_defaults(func=cmd_info) - - # commit - commit_parser = subparsers.add_parser("commit", help="Commit") - commit_parser.add_argument("-m", help="Commit message", dest="message") - commit_parser.add_argument("--amend", action="store_true", help="Amend last commit") - commit_parser.add_argument("--allow-empty", action="store_true", help="Allow empty commit") - commit_parser.add_argument("--no-edit", action="store_true", help="Skip editor") - commit_parser.set_defaults(func=cmd_commit) - - # amend - amend_parser = subparsers.add_parser("amend", help="Shortcut for amending last commit") - amend_parser.set_defaults(func=cmd_amend) - - # branch - branch_parser = subparsers.add_parser("branch", aliases=["b"], help="Operations on branches") - branch_subparsers = branch_parser.add_subparsers(required=True, dest="branch_command") - branch_up_parser = branch_subparsers.add_parser("up", aliases=["u"], help="Move upstack") - branch_up_parser.set_defaults(func=cmd_branch_up) - - branch_down_parser = branch_subparsers.add_parser("down", aliases=["d"], help="Move downstack") - branch_down_parser.set_defaults(func=cmd_branch_down) - - branch_new_parser = branch_subparsers.add_parser("new", aliases=["create"], help="Create a new branch") - branch_new_parser.add_argument("name", help="Branch name") - branch_new_parser.set_defaults(func=cmd_branch_new) - - branch_checkout_parser = branch_subparsers.add_parser("checkout", aliases=["co"], help="Checkout a branch") - branch_checkout_parser.add_argument("name", help="Branch name", nargs="?") - branch_checkout_parser.set_defaults(func=cmd_branch_checkout) - - # stack - stack_parser = subparsers.add_parser("stack", aliases=["s"], help="Operations on the full current stack") - stack_subparsers = stack_parser.add_subparsers(required=True, dest="stack_command") - - stack_info_parser = stack_subparsers.add_parser("info", aliases=["i"], help="Info for current stack") - stack_info_parser.add_argument("--pr", action="store_true", help="Get PR info (slow)") - stack_info_parser.set_defaults(func=cmd_stack_info) - - stack_push_parser = stack_subparsers.add_parser("push", help="Push") - stack_push_parser.add_argument("--force", "-f", action="store_true", help="Bypass confirmation") - stack_push_parser.add_argument("--no-pr", dest="pr", action="store_false", help="Skip Create PRs") - stack_push_parser.set_defaults(func=cmd_stack_push) - - stack_sync_parser = stack_subparsers.add_parser("sync", help="Sync") - stack_sync_parser.set_defaults(func=cmd_stack_sync) - - stack_checkout_parser = stack_subparsers.add_parser( - "checkout", aliases=["co"], help="Checkout a branch in this stack" - ) - stack_checkout_parser.set_defaults(func=cmd_stack_checkout) - - # upstack - upstack_parser = subparsers.add_parser("upstack", aliases=["us"], help="Operations on the current upstack") - upstack_subparsers = upstack_parser.add_subparsers(required=True, dest="upstack_command") - - upstack_info_parser = upstack_subparsers.add_parser("info", aliases=["i"], help="Info for current upstack") - upstack_info_parser.add_argument("--pr", action="store_true", help="Get PR info (slow)") - upstack_info_parser.set_defaults(func=cmd_upstack_info) - - upstack_push_parser = upstack_subparsers.add_parser("push", help="Push") - upstack_push_parser.add_argument("--force", "-f", action="store_true", help="Bypass confirmation") - upstack_push_parser.add_argument("--no-pr", dest="pr", action="store_false", help="Skip Create PRs") - upstack_push_parser.set_defaults(func=cmd_upstack_push) - - upstack_sync_parser = upstack_subparsers.add_parser("sync", help="Sync") - upstack_sync_parser.set_defaults(func=cmd_upstack_sync) - - upstack_onto_parser = upstack_subparsers.add_parser("onto", aliases=["restack"], help="Restack") - upstack_onto_parser.add_argument("target", help="New parent") - upstack_onto_parser.set_defaults(func=cmd_upstack_onto) - - upstack_as_parser = upstack_subparsers.add_parser("as", help="Upstack branch this as a new stack bottom") - upstack_as_parser.add_argument("target", help="bottom, restack this branch as a new stack bottom") - upstack_as_parser.set_defaults(func=cmd_upstack_as) - - # downstack - downstack_parser = subparsers.add_parser( - "downstack", aliases=["ds"], help="Operations on the current downstack" - ) - downstack_subparsers = downstack_parser.add_subparsers(required=True, dest="downstack_command") - - downstack_info_parser = downstack_subparsers.add_parser( - "info", aliases=["i"], help="Info for current downstack" - ) - downstack_info_parser.add_argument("--pr", action="store_true", help="Get PR info (slow)") - downstack_info_parser.set_defaults(func=cmd_downstack_info) - - downstack_push_parser = downstack_subparsers.add_parser("push", help="Push") - downstack_push_parser.add_argument("--force", "-f", action="store_true", help="Bypass confirmation") - downstack_push_parser.add_argument("--no-pr", dest="pr", action="store_false", help="Skip Create PRs") - downstack_push_parser.set_defaults(func=cmd_downstack_push) - - downstack_sync_parser = downstack_subparsers.add_parser("sync", help="Sync") - downstack_sync_parser.set_defaults(func=cmd_downstack_sync) - - # update - update_parser = subparsers.add_parser("update", help="Update repo, all bottom branches must exist in remote") - update_parser.add_argument("--force", "-f", action="store_true", help="Bypass confirmation") - update_parser.set_defaults(func=cmd_update) - - # import - import_parser = subparsers.add_parser("import", help="Import Graphite stack") - import_parser.add_argument("--force", "-f", action="store_true", help="Bypass confirmation") - import_parser.add_argument("name", help="Foreign stack top") - import_parser.set_defaults(func=cmd_import) - - # adopt - adopt_parser = subparsers.add_parser("adopt", help="Adopt one branch") - adopt_parser.add_argument("name", help="Branch name") - adopt_parser.set_defaults(func=cmd_adopt) - - # land - land_parser = subparsers.add_parser("land", help="Land bottom-most PR on current stack") - land_parser.add_argument("--force", "-f", action="store_true", help="Bypass confirmation") - land_parser.add_argument( - "--auto", - "-a", - action="store_true", - help="Automatically merge after all checks pass", - ) - land_parser.set_defaults(func=cmd_land) - - # shortcuts - push_parser = subparsers.add_parser("push", help="Alias for downstack push") - push_parser.add_argument("--force", "-f", action="store_true", help="Bypass confirmation") - push_parser.add_argument("--no-pr", dest="pr", action="store_false", help="Skip Create PRs") - push_parser.set_defaults(func=cmd_downstack_push) - - sync_parser = subparsers.add_parser("sync", help="Alias for stack sync") - sync_parser.set_defaults(func=cmd_stack_sync) - - checkout_parser = subparsers.add_parser("checkout", aliases=["co"], help="Checkout a branch") - checkout_parser.add_argument("name", help="Branch name", nargs="?") - checkout_parser.set_defaults(func=cmd_branch_checkout) - - checkout_parser = subparsers.add_parser("sco", help="Checkout a branch in this stack") - checkout_parser.set_defaults(func=cmd_stack_checkout) - - args = parser.parse_args() - logging.basicConfig(format=_LOGGING_FORMAT, level=LOGLEVELS[args.log_level], force=True) - - global COLOR_STDERR - global COLOR_STDOUT - if args.color == "always": - COLOR_STDERR = True - COLOR_STDOUT = True - elif args.color == "never": - COLOR_STDERR = False - COLOR_STDOUT = False - - init_git() - - stack = StackBranchSet() - load_all_stacks(stack) - - global CURRENT_BRANCH - if args.command == "continue": - try: - with open(STATE_FILE) as f: - state = json.load(f) - except FileNotFoundError as e: # noqa: F841 - die("No previous command in progress") - branch = state["branch"] - run(["git", "checkout", branch]) - CURRENT_BRANCH = branch - if CURRENT_BRANCH not in stack.stack: - die("Current branch {} is not in a stack", CURRENT_BRANCH) - - sync_names = state["sync"] - syncs = [stack.stack[n] for n in sync_names] - - inner_do_sync(syncs, sync_names) - else: - # TODO restore the current branch after changing the branch on some commands for - # instance `info` - if CURRENT_BRANCH not in stack.stack: - main_branch = get_real_stack_bottom() - - if get_config().change_to_main and main_branch is not None: - run(["git", "checkout", main_branch]) - CURRENT_BRANCH = main_branch - else: - die("Current branch {} is not in a stack", CURRENT_BRANCH) - - get_current_stack_as_forest(stack) - args.func(stack, args) - - # Success, delete the state file - try: - os.remove(STATE_FILE) - except FileNotFoundError: - pass - except ExitException as e: - error("{}", e.args[0]) - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/src/stacky/stacky_test.py b/src/stacky/stacky_test.py deleted file mode 100755 index 24e1b98..0000000 --- a/src/stacky/stacky_test.py +++ /dev/null @@ -1,180 +0,0 @@ -#!/usr/bin/env python3 -import os -import shlex -import subprocess -import unittest -from unittest import mock -from unittest.mock import MagicMock, patch - -from stacky import ( - PRInfos, - _check_returncode, - cmd_land, - find_issue_marker, - get_top_level_dir, - read_config, - stop_muxed_ssh, -) - - -class TestCheckReturnCode(unittest.TestCase): - @patch("stacky.die") - def test_check_returncode_zero(self, mock_die): - sp = subprocess.CompletedProcess(args=["ls"], returncode=0) - _check_returncode(sp, ["ls"]) - mock_die.assert_not_called() - - @patch("stacky.die") - def test_check_returncode_negative(self, mock_die): - sp = subprocess.CompletedProcess(args=["ls"], returncode=-1, stderr=b"error") - _check_returncode(sp, ["ls"]) - mock_die.assert_called_once_with("Killed by signal {}: {}. Stderr was:\n{}", 1, shlex.join(["ls"]), "error") - - @patch("stacky.die") - def test_check_returncode_positive(self, mock_die): - sp = subprocess.CompletedProcess(args=["ls"], returncode=1, stderr=b"error") - _check_returncode(sp, ["ls"]) - mock_die.assert_called_once_with("Exited with status {}: {}. Stderr was:\n{}", 1, shlex.join(["ls"]), "error") - - -class TestStringMethods(unittest.TestCase): - def test_find_issue_marker(self): - out = find_issue_marker("SRE-12") - self.assertTrue(out is not None) - self.assertEqual("SRE-12", out) - - out = find_issue_marker("SRE-12-find-things") - self.assertTrue(out is not None) - self.assertEqual("SRE-12", out) - - out = find_issue_marker("SRE_12") - self.assertTrue(out is not None) - self.assertEqual("SRE-12", out) - - out = find_issue_marker("SRE_12-find-things") - self.assertTrue(out is not None) - self.assertEqual("SRE-12", out) - - out = find_issue_marker("john_SRE_12") - self.assertTrue(out is not None) - self.assertEqual("SRE-12", out) - - out = find_issue_marker("john_SRE_12-find-things") - self.assertTrue(out is not None) - self.assertEqual("SRE-12", out) - - out = find_issue_marker("john_SRE12-find-things") - self.assertTrue(out is not None) - self.assertEqual("SRE-12", out) - - out = find_issue_marker("anna_01_01_SRE-12") - self.assertTrue(out is not None) - self.assertEqual("SRE-12", out) - - out = find_issue_marker("anna_01_01_SRE12") - self.assertTrue(out is not None) - self.assertEqual("SRE-12", out) - - out = find_issue_marker("john_test_12") - self.assertTrue(out is None) - - out = find_issue_marker("john_test12") - self.assertTrue(out is None) - - -class TestCmdLand(unittest.TestCase): - @patch("stacky.COLOR_STDOUT", True) - @patch("sys.stdout.write") - @patch("stacky.get_current_downstack_as_forest") - @patch("stacky.die") - @patch("stacky.cout") - @patch("stacky.confirm") - @patch("stacky.run") - @patch("stacky.CmdArgs") - @patch("stacky.Commit") - def test_cmd_land( - self, - mock_Commit, - mock_CmdArgs, - mock_run, - mock_confirm, - mock_cout, - mock_die, - mock_get_current_downstack_as_forest, - mock_write, - ): - # Mock the args - args = MagicMock() - args.force = False - args.auto = False - - bottom_branch = MagicMock() - bottom_branch.name = "bottom_branch" - - # Mock the stack - stack = MagicMock() - stack.bottoms = [bottom_branch] - - # Mock the branch - branch = MagicMock() - branch.is_synced_with_parent.return_value = True - branch.is_synced_with_remote.return_value = True - branch.load_pr_info.return_value = None - branch.open_pr_info = {"mergeable": "MERGEABLE", "number": 1, "url": "http://example.com"} - branch.name = "branch_name" - branch.parent.name = "parent_name" - - # Mock the forest and branches - mock_get_current_downstack_as_forest.return_value = [ - {"bottom_branch": (bottom_branch, {"branch": (branch, None)})} - ] - - # Mock the CmdArgs - mock_CmdArgs.return_value = ["cmd_args"] - - # Mock the Commit - mock_Commit.return_value = "commit" - - # Call the function - cmd_land(stack, args) - - # Assert the mocks were called correctly - mock_get_current_downstack_as_forest.assert_called_once_with(stack) - branch.is_synced_with_parent.assert_called_once() - branch.is_synced_with_remote.assert_called_once() - branch.load_pr_info.assert_called_once() - mock_write.assert_called_with( - "- Will land PR #1 (\x1b[34mhttp://example.com\x1b[0m) for branch branch_name into branch parent_name\n" - ) - mock_run.assert_called_with(["cmd_args"], out=True) - mock_cout.assert_called_with("\n✓ Success! Run `stacky update` to update local state.\n", fg="green") - - -class TestStopMuxedSsh(unittest.TestCase): - @patch("stacky.get_config", return_value=MagicMock(share_ssh_session=True)) - @patch("stacky.get_remote_type", return_value="host") - @patch("stacky.gen_ssh_mux_cmd", return_value=["ssh", "-S"]) - @patch("subprocess.Popen") - def test_stop_muxed_ssh(self, mock_popen, mock_gen_ssh_mux_cmd, mock_get_remote_type, mock_get_config): - stop_muxed_ssh() - mock_popen.assert_called_once_with(["ssh", "-S", "-O", "exit", "host"], stderr=subprocess.DEVNULL) - - @patch("stacky.get_config", return_value=MagicMock(share_ssh_session=False)) - @patch("stacky.get_remote_type", return_value="host") - @patch("stacky.gen_ssh_mux_cmd", return_value=["ssh", "-S"]) - @patch("subprocess.Popen") - def test_stop_muxed_ssh_no_share(self, mock_popen, mock_gen_ssh_mux_cmd, mock_get_remote_type, mock_get_config): - stop_muxed_ssh() - mock_popen.assert_not_called() - - @patch("stacky.get_config", return_value=MagicMock(share_ssh_session=True)) - @patch("stacky.get_remote_type", return_value=None) - @patch("stacky.gen_ssh_mux_cmd", return_value=["ssh", "-S"]) - @patch("subprocess.Popen") - def test_stop_muxed_ssh_no_host(self, mock_popen, mock_gen_ssh_mux_cmd, mock_get_remote_type, mock_get_config): - stop_muxed_ssh() - mock_popen.assert_not_called() - - -if __name__ == "__main__": - unittest.main() diff --git a/src/stacky/tests/__init__.py b/src/stacky/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/stacky/tests/__pycache__/test_integration.cpython-311.pyc b/src/stacky/tests/__pycache__/test_integration.cpython-311.pyc new file mode 100644 index 0000000..27dd1a5 Binary files /dev/null and b/src/stacky/tests/__pycache__/test_integration.cpython-311.pyc differ diff --git a/src/stacky/tests/test_commands/__init__.py b/src/stacky/tests/test_commands/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/stacky/tests/test_commands/__pycache__/__init__.cpython-311.pyc b/src/stacky/tests/test_commands/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..3793a9d Binary files /dev/null and b/src/stacky/tests/test_commands/__pycache__/__init__.cpython-311.pyc differ diff --git a/src/stacky/tests/test_commands/__pycache__/test_land.cpython-311.pyc b/src/stacky/tests/test_commands/__pycache__/test_land.cpython-311.pyc new file mode 100644 index 0000000..5a4b0db Binary files /dev/null and b/src/stacky/tests/test_commands/__pycache__/test_land.cpython-311.pyc differ diff --git a/src/stacky/tests/test_commands/test_land.py b/src/stacky/tests/test_commands/test_land.py new file mode 100644 index 0000000..b2e0b45 --- /dev/null +++ b/src/stacky/tests/test_commands/test_land.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 +"""Tests for stacky.commands.land module.""" + +import unittest +from unittest.mock import patch, MagicMock + +from stacky.commands.land import cmd_land + + +class TestCmdLand(unittest.TestCase): + """Tests for cmd_land function.""" + + @patch("stacky.commands.land.COLOR_STDOUT", True) + @patch("sys.stdout.write") + @patch("stacky.commands.land.get_current_branch_name") + @patch("stacky.commands.land.get_current_downstack_as_forest") + @patch("stacky.commands.land.die") + @patch("stacky.commands.land.cout") + @patch("stacky.commands.land.confirm") + @patch("stacky.commands.land.run") + def test_cmd_land_success( + self, + mock_run, + mock_confirm, + mock_cout, + mock_die, + mock_forest, + mock_current_branch, + mock_write, + ): + """Test successful land command.""" + args = MagicMock() + args.force = False + args.auto = False + + bottom_branch = MagicMock() + bottom_branch.name = "main" + + stack = MagicMock() + stack.bottoms = {bottom_branch} + + branch = MagicMock() + branch.name = "feature" + branch.is_synced_with_parent.return_value = True + branch.is_synced_with_remote.return_value = True + branch.load_pr_info.return_value = None + branch.open_pr_info = { + "mergeable": "MERGEABLE", + "number": 1, + "url": "http://example.com" + } + branch.parent = MagicMock() + branch.parent.name = "main" + + mock_current_branch.return_value = "feature" + mock_forest.return_value = [ + {"main": (bottom_branch, {"feature": (branch, None)})} + ] + mock_run.return_value = "abc123" + + cmd_land(stack, args) + + mock_forest.assert_called_once_with(stack) + branch.is_synced_with_parent.assert_called_once() + branch.is_synced_with_remote.assert_called_once() + branch.load_pr_info.assert_called_once() + mock_confirm.assert_called_once() + + @patch("stacky.commands.land.get_current_branch_name") + @patch("stacky.commands.land.get_current_downstack_as_forest") + @patch("stacky.commands.land.die") + def test_cmd_land_not_synced_parent( + self, + mock_die, + mock_forest, + mock_current_branch, + ): + """Test land fails when not synced with parent.""" + from stacky.utils.logging import ExitException + mock_die.side_effect = ExitException("Not synced with parent") + + args = MagicMock() + args.force = False + + bottom_branch = MagicMock() + bottom_branch.name = "main" + + stack = MagicMock() + stack.bottoms = {bottom_branch} + + branch = MagicMock() + branch.name = "feature" + branch.is_synced_with_parent.return_value = False + branch.parent = MagicMock() + branch.parent.name = "main" + + mock_current_branch.return_value = "feature" + mock_forest.return_value = [ + {"main": (bottom_branch, {"feature": (branch, None)})} + ] + + with self.assertRaises(ExitException): + cmd_land(stack, args) + mock_die.assert_called() + + @patch("stacky.commands.land.get_current_branch_name") + @patch("stacky.commands.land.get_current_downstack_as_forest") + @patch("stacky.commands.land.die") + def test_cmd_land_not_synced_remote( + self, + mock_die, + mock_forest, + mock_current_branch, + ): + """Test land fails when not synced with remote.""" + from stacky.utils.logging import ExitException + mock_die.side_effect = ExitException("Not synced with remote") + + args = MagicMock() + args.force = False + + bottom_branch = MagicMock() + bottom_branch.name = "main" + + stack = MagicMock() + stack.bottoms = {bottom_branch} + + branch = MagicMock() + branch.name = "feature" + branch.is_synced_with_parent.return_value = True + branch.is_synced_with_remote.return_value = False + branch.parent = MagicMock() + branch.parent.name = "main" + + mock_current_branch.return_value = "feature" + mock_forest.return_value = [ + {"main": (bottom_branch, {"feature": (branch, None)})} + ] + + with self.assertRaises(ExitException): + cmd_land(stack, args) + mock_die.assert_called() + + @patch("stacky.commands.land.get_current_branch_name") + @patch("stacky.commands.land.get_current_downstack_as_forest") + @patch("stacky.commands.land.die") + def test_cmd_land_no_open_pr( + self, + mock_die, + mock_forest, + mock_current_branch, + ): + """Test land fails when no open PR.""" + from stacky.utils.logging import ExitException + mock_die.side_effect = ExitException("No open PR") + + args = MagicMock() + args.force = False + + bottom_branch = MagicMock() + bottom_branch.name = "main" + + stack = MagicMock() + stack.bottoms = {bottom_branch} + + branch = MagicMock() + branch.name = "feature" + branch.is_synced_with_parent.return_value = True + branch.is_synced_with_remote.return_value = True + branch.load_pr_info.return_value = None + branch.open_pr_info = None + branch.parent = MagicMock() + branch.parent.name = "main" + + mock_current_branch.return_value = "feature" + mock_forest.return_value = [ + {"main": (bottom_branch, {"feature": (branch, None)})} + ] + + with self.assertRaises(ExitException): + cmd_land(stack, args) + mock_die.assert_called() + + +if __name__ == "__main__": + unittest.main() diff --git a/src/stacky/tests/test_git/__init__.py b/src/stacky/tests/test_git/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/stacky/tests/test_git/__pycache__/__init__.cpython-311.pyc b/src/stacky/tests/test_git/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..6ebd72a Binary files /dev/null and b/src/stacky/tests/test_git/__pycache__/__init__.cpython-311.pyc differ diff --git a/src/stacky/tests/test_git/__pycache__/test_branch.cpython-311.pyc b/src/stacky/tests/test_git/__pycache__/test_branch.cpython-311.pyc new file mode 100644 index 0000000..d9da36b Binary files /dev/null and b/src/stacky/tests/test_git/__pycache__/test_branch.cpython-311.pyc differ diff --git a/src/stacky/tests/test_git/__pycache__/test_refs.cpython-311.pyc b/src/stacky/tests/test_git/__pycache__/test_refs.cpython-311.pyc new file mode 100644 index 0000000..8848208 Binary files /dev/null and b/src/stacky/tests/test_git/__pycache__/test_refs.cpython-311.pyc differ diff --git a/src/stacky/tests/test_git/__pycache__/test_remote.cpython-311.pyc b/src/stacky/tests/test_git/__pycache__/test_remote.cpython-311.pyc new file mode 100644 index 0000000..b8a9914 Binary files /dev/null and b/src/stacky/tests/test_git/__pycache__/test_remote.cpython-311.pyc differ diff --git a/src/stacky/tests/test_git/test_branch.py b/src/stacky/tests/test_git/test_branch.py new file mode 100644 index 0000000..69cbdd4 --- /dev/null +++ b/src/stacky/tests/test_git/test_branch.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python3 +"""Tests for stacky.git.branch module.""" + +import unittest +from unittest.mock import patch, MagicMock + +from stacky.git.branch import ( + get_current_branch, get_all_branches, get_stack_parent_branch, + get_real_stack_bottom, checkout, create_branch +) +from stacky.utils.types import BranchName + + +class TestGetCurrentBranch(unittest.TestCase): + """Tests for get_current_branch function.""" + + @patch("stacky.git.branch.run") + def test_get_current_branch_success(self, mock_run): + """Test get_current_branch returns branch name.""" + mock_run.return_value = "refs/heads/feature-branch" + result = get_current_branch() + self.assertEqual(result, "feature-branch") + + @patch("stacky.git.branch.run") + def test_get_current_branch_detached_head(self, mock_run): + """Test get_current_branch returns None on detached HEAD.""" + mock_run.return_value = None + result = get_current_branch() + self.assertIsNone(result) + + +class TestGetAllBranches(unittest.TestCase): + """Tests for get_all_branches function.""" + + @patch("stacky.git.branch.run_multiline") + def test_get_all_branches(self, mock_run_multiline): + """Test get_all_branches returns list of branch names.""" + mock_run_multiline.return_value = "main\nfeature-1\nfeature-2\n" + result = get_all_branches() + self.assertEqual(result, [BranchName("main"), BranchName("feature-1"), BranchName("feature-2")]) + + @patch("stacky.git.branch.run_multiline") + def test_get_all_branches_empty(self, mock_run_multiline): + """Test get_all_branches with no branches.""" + mock_run_multiline.return_value = "" + result = get_all_branches() + self.assertEqual(result, []) + + +class TestGetStackParentBranch(unittest.TestCase): + """Tests for get_stack_parent_branch function.""" + + @patch("stacky.git.branch.run") + def test_get_stack_parent_branch_success(self, mock_run): + """Test getting parent branch.""" + mock_run.return_value = "refs/heads/parent-branch" + result = get_stack_parent_branch(BranchName("child-branch")) + self.assertEqual(result, "parent-branch") + + @patch("stacky.git.branch.run") + def test_get_stack_parent_branch_no_parent(self, mock_run): + """Test getting parent when no parent configured.""" + mock_run.return_value = None + result = get_stack_parent_branch(BranchName("orphan-branch")) + self.assertIsNone(result) + + def test_get_stack_parent_branch_is_bottom(self): + """Test getting parent of stack bottom returns None.""" + result = get_stack_parent_branch(BranchName("master")) + self.assertIsNone(result) + + +class TestGetRealStackBottom(unittest.TestCase): + """Tests for get_real_stack_bottom function.""" + + @patch("stacky.git.branch.get_all_branches") + def test_get_real_stack_bottom_master(self, mock_get_all): + """Test finding master as stack bottom.""" + mock_get_all.return_value = [BranchName("master"), BranchName("feature")] + result = get_real_stack_bottom() + self.assertEqual(result, "master") + + @patch("stacky.git.branch.get_all_branches") + def test_get_real_stack_bottom_main(self, mock_get_all): + """Test finding main as stack bottom.""" + mock_get_all.return_value = [BranchName("main"), BranchName("feature")] + result = get_real_stack_bottom() + self.assertEqual(result, "main") + + @patch("stacky.git.branch.get_all_branches") + def test_get_real_stack_bottom_none(self, mock_get_all): + """Test no stack bottom found.""" + mock_get_all.return_value = [BranchName("feature")] + result = get_real_stack_bottom() + self.assertIsNone(result) + + +class TestCheckout(unittest.TestCase): + """Tests for checkout function.""" + + @patch("stacky.git.branch.run") + @patch("stacky.git.branch.info") + def test_checkout(self, mock_info, mock_run): + """Test checkout calls git checkout.""" + checkout(BranchName("feature")) + mock_run.assert_called_once_with(["git", "checkout", "feature"], out=True) + + +class TestCreateBranch(unittest.TestCase): + """Tests for create_branch function.""" + + @patch("stacky.git.branch.run") + def test_create_branch(self, mock_run): + """Test create_branch calls git checkout -b with track.""" + create_branch(BranchName("new-feature")) + mock_run.assert_called_once_with( + ["git", "checkout", "-b", "new-feature", "--track"], out=True + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/stacky/tests/test_git/test_refs.py b/src/stacky/tests/test_git/test_refs.py new file mode 100644 index 0000000..b2a8e3d --- /dev/null +++ b/src/stacky/tests/test_git/test_refs.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +"""Tests for stacky.git.refs module.""" + +import unittest +from unittest.mock import patch + +from stacky.git.refs import ( + get_stack_parent_commit, get_commit, set_parent_commit, + get_branch_name_from_short_ref, get_all_stack_bottoms, + get_commits_between, get_merge_base +) +from stacky.utils.types import BranchName, Commit + + +class TestGetStackParentCommit(unittest.TestCase): + """Tests for get_stack_parent_commit function.""" + + @patch("stacky.git.refs.run") + def test_get_stack_parent_commit_success(self, mock_run): + """Test getting parent commit.""" + mock_run.return_value = "abc123" + result = get_stack_parent_commit(BranchName("feature")) + self.assertEqual(result, Commit("abc123")) + + @patch("stacky.git.refs.run") + def test_get_stack_parent_commit_none(self, mock_run): + """Test getting parent commit when not set.""" + mock_run.return_value = None + result = get_stack_parent_commit(BranchName("feature")) + self.assertIsNone(result) + + +class TestGetCommit(unittest.TestCase): + """Tests for get_commit function.""" + + @patch("stacky.git.refs.run") + def test_get_commit(self, mock_run): + """Test getting branch commit.""" + mock_run.return_value = "def456" + result = get_commit(BranchName("main")) + self.assertEqual(result, Commit("def456")) + + +class TestSetParentCommit(unittest.TestCase): + """Tests for set_parent_commit function.""" + + @patch("stacky.git.refs.run") + def test_set_parent_commit(self, mock_run): + """Test setting parent commit.""" + set_parent_commit(BranchName("feature"), Commit("abc123")) + mock_run.assert_called_once() + call_args = mock_run.call_args[0][0] + self.assertIn("update-ref", call_args) + self.assertIn("refs/stack-parent/feature", call_args) + self.assertIn("abc123", call_args) + + @patch("stacky.git.refs.run") + def test_set_parent_commit_with_prev(self, mock_run): + """Test setting parent commit with previous value.""" + set_parent_commit(BranchName("feature"), Commit("abc123"), "old123") + mock_run.assert_called_once() + call_args = mock_run.call_args[0][0] + self.assertIn("old123", call_args) + + +class TestGetBranchNameFromShortRef(unittest.TestCase): + """Tests for get_branch_name_from_short_ref function.""" + + def test_get_branch_name_from_short_ref(self): + """Test extracting branch name from short ref.""" + result = get_branch_name_from_short_ref("stack-parent/feature") + self.assertEqual(result, BranchName("feature")) + + def test_get_branch_name_from_short_ref_invalid(self): + """Test invalid ref format raises error.""" + from stacky.utils.logging import ExitException + # The function will raise ExitException via die() for invalid refs + with self.assertRaises(ExitException): + get_branch_name_from_short_ref("invalid") + + +class TestGetAllStackBottoms(unittest.TestCase): + """Tests for get_all_stack_bottoms function.""" + + @patch("stacky.git.refs.run_multiline") + def test_get_all_stack_bottoms(self, mock_run): + """Test getting all stack bottom branches.""" + mock_run.return_value = "stacky-bottom-branch/feature-1\nstacky-bottom-branch/feature-2\n" + result = get_all_stack_bottoms() + self.assertEqual(result, [BranchName("feature-1"), BranchName("feature-2")]) + + @patch("stacky.git.refs.run_multiline") + def test_get_all_stack_bottoms_empty(self, mock_run): + """Test no stack bottoms.""" + mock_run.return_value = "" + result = get_all_stack_bottoms() + self.assertEqual(result, []) + + +class TestGetCommitsBetween(unittest.TestCase): + """Tests for get_commits_between function.""" + + @patch("stacky.git.refs.run_multiline") + def test_get_commits_between(self, mock_run): + """Test getting commits between two refs.""" + mock_run.return_value = "abc123\ndef456\n" + result = get_commits_between(Commit("start"), Commit("end")) + self.assertEqual(result, ["abc123", "def456"]) + + @patch("stacky.git.refs.run_multiline") + def test_get_commits_between_empty(self, mock_run): + """Test no commits between refs.""" + mock_run.return_value = "" + result = get_commits_between(Commit("same"), Commit("same")) + self.assertEqual(result, []) + + +class TestGetMergeBase(unittest.TestCase): + """Tests for get_merge_base function.""" + + @patch("stacky.git.refs.run") + def test_get_merge_base(self, mock_run): + """Test getting merge base.""" + mock_run.return_value = "abc123" + result = get_merge_base(BranchName("main"), BranchName("feature")) + self.assertEqual(result, "abc123") + + +if __name__ == "__main__": + unittest.main() diff --git a/src/stacky/tests/test_git/test_remote.py b/src/stacky/tests/test_git/test_remote.py new file mode 100644 index 0000000..bc24c3a --- /dev/null +++ b/src/stacky/tests/test_git/test_remote.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python3 +"""Tests for stacky.git.remote module.""" + +import subprocess +import unittest +from unittest.mock import patch, MagicMock + +from stacky.git.remote import ( + get_remote_type, gen_ssh_mux_cmd, stop_muxed_ssh, start_muxed_ssh +) + + +class TestGetRemoteType(unittest.TestCase): + """Tests for get_remote_type function.""" + + @patch("stacky.git.remote.run_always_return") + def test_get_remote_type_ssh(self, mock_run): + """Test getting SSH remote type.""" + mock_run.return_value = "origin\tgit@github.com:user/repo.git (push)" + result = get_remote_type("origin") + self.assertEqual(result, "git@github.com") + + @patch("stacky.git.remote.run_always_return") + def test_get_remote_type_https(self, mock_run): + """Test getting HTTPS remote type returns None.""" + mock_run.return_value = "origin\thttps://github.com/user/repo.git (push)" + result = get_remote_type("origin") + self.assertIsNone(result) + + +class TestGenSshMuxCmd(unittest.TestCase): + """Tests for gen_ssh_mux_cmd function.""" + + def test_gen_ssh_mux_cmd(self): + """Test SSH mux command generation.""" + cmd = gen_ssh_mux_cmd() + self.assertEqual(cmd[0], "ssh") + self.assertIn("-o", cmd) + self.assertIn("ControlMaster=auto", cmd) + self.assertIn("ControlPath=~/.ssh/stacky-%C", cmd) + + +class TestStopMuxedSsh(unittest.TestCase): + """Tests for stop_muxed_ssh function.""" + + @patch("stacky.git.remote.get_config") + @patch("stacky.git.remote.get_remote_type") + @patch("stacky.git.remote.gen_ssh_mux_cmd") + @patch("subprocess.Popen") + def test_stop_muxed_ssh(self, mock_popen, mock_gen_cmd, mock_get_remote, mock_get_config): + """Test stopping muxed SSH connection.""" + mock_get_config.return_value = MagicMock(share_ssh_session=True) + mock_get_remote.return_value = "git@github.com" + mock_gen_cmd.return_value = ["ssh", "-S"] + + stop_muxed_ssh() + + mock_popen.assert_called_once_with( + ["ssh", "-S", "-O", "exit", "git@github.com"], + stderr=subprocess.DEVNULL + ) + + @patch("stacky.git.remote.get_config") + @patch("subprocess.Popen") + def test_stop_muxed_ssh_disabled(self, mock_popen, mock_get_config): + """Test stop_muxed_ssh does nothing when disabled.""" + mock_get_config.return_value = MagicMock(share_ssh_session=False) + stop_muxed_ssh() + mock_popen.assert_not_called() + + @patch("stacky.git.remote.get_config") + @patch("stacky.git.remote.get_remote_type") + @patch("subprocess.Popen") + def test_stop_muxed_ssh_no_host(self, mock_popen, mock_get_remote, mock_get_config): + """Test stop_muxed_ssh does nothing when no SSH host.""" + mock_get_config.return_value = MagicMock(share_ssh_session=True) + mock_get_remote.return_value = None + stop_muxed_ssh() + mock_popen.assert_not_called() + + +class TestStartMuxedSsh(unittest.TestCase): + """Tests for start_muxed_ssh function.""" + + @patch("stacky.git.remote.get_config") + def test_start_muxed_ssh_disabled(self, mock_get_config): + """Test start_muxed_ssh does nothing when disabled.""" + mock_get_config.return_value = MagicMock(share_ssh_session=False) + # Should not raise any errors + start_muxed_ssh() + + +if __name__ == "__main__": + unittest.main() diff --git a/src/stacky/tests/test_integration.py b/src/stacky/tests/test_integration.py new file mode 100644 index 0000000..bbec1f3 --- /dev/null +++ b/src/stacky/tests/test_integration.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python3 +"""Integration tests for stacky workflow.""" + +import unittest +from unittest.mock import patch, MagicMock + +from stacky.stack.models import StackBranch, StackBranchSet +from stacky.utils.types import BranchName, Commit + + +class TestStackWorkflow(unittest.TestCase): + """Integration tests for stack workflow.""" + + @patch("stacky.stack.models.get_remote_info") + @patch("stacky.stack.models.get_commit") + def test_build_simple_stack(self, mock_get_commit, mock_get_remote): + """Test building a simple two-branch stack.""" + # Mock git operations + mock_get_commit.return_value = Commit("abc123") + mock_get_remote.return_value = ("origin", BranchName("main"), Commit("abc123")) + + # Create stack + stack = StackBranchSet() + + # Add main branch (bottom) + main = stack.add( + BranchName("main"), + parent=None, + parent_commit=None + ) + + # Add feature branch + mock_get_commit.return_value = Commit("def456") + mock_get_remote.return_value = ("origin", BranchName("feature"), Commit("def456")) + feature = stack.add( + BranchName("feature"), + parent=main, + parent_commit=Commit("abc123") + ) + stack.add_child(main, feature) + + # Verify stack structure + self.assertEqual(len(stack.stack), 2) + self.assertIn(main, stack.bottoms) + self.assertIn(feature, stack.tops) + self.assertIn(feature, main.children) + self.assertEqual(feature.parent, main) + + @patch("stacky.stack.models.get_remote_info") + @patch("stacky.stack.models.get_commit") + def test_build_branching_stack(self, mock_get_commit, mock_get_remote): + """Test building a stack with multiple branches from one parent.""" + mock_get_commit.return_value = Commit("abc123") + mock_get_remote.return_value = ("origin", BranchName("main"), Commit("abc123")) + + stack = StackBranchSet() + + # Add main + main = stack.add(BranchName("main"), parent=None, parent_commit=None) + + # Add feature-1 + mock_get_commit.return_value = Commit("def456") + mock_get_remote.return_value = ("origin", BranchName("feature-1"), Commit("def456")) + feature1 = stack.add(BranchName("feature-1"), parent=main, parent_commit=Commit("abc123")) + stack.add_child(main, feature1) + + # Add feature-2 (also from main) + mock_get_commit.return_value = Commit("ghi789") + mock_get_remote.return_value = ("origin", BranchName("feature-2"), Commit("ghi789")) + feature2 = stack.add(BranchName("feature-2"), parent=main, parent_commit=Commit("abc123")) + stack.add_child(main, feature2) + + # Verify structure + self.assertEqual(len(main.children), 2) + self.assertIn(feature1, main.children) + self.assertIn(feature2, main.children) + self.assertEqual(len(stack.tops), 2) + + @patch("stacky.stack.models.get_remote_info") + @patch("stacky.stack.models.get_commit") + def test_sync_status_detection(self, mock_get_commit, mock_get_remote): + """Test that sync status is correctly detected.""" + mock_get_commit.return_value = Commit("abc123") + mock_get_remote.return_value = ("origin", BranchName("main"), Commit("abc123")) + + stack = StackBranchSet() + main = stack.add(BranchName("main"), parent=None, parent_commit=None) + + # Add feature synced with parent + mock_get_commit.return_value = Commit("def456") + mock_get_remote.return_value = ("origin", BranchName("feature"), Commit("def456")) + feature = stack.add(BranchName("feature"), parent=main, parent_commit=Commit("abc123")) + + # Initially synced + self.assertTrue(feature.is_synced_with_parent()) + self.assertTrue(feature.is_synced_with_remote()) + + # Simulate parent moving ahead + main.commit = Commit("new123") + self.assertFalse(feature.is_synced_with_parent()) + + # Simulate remote moving ahead + feature.remote_commit = Commit("remote456") + self.assertFalse(feature.is_synced_with_remote()) + + +class TestTreeTraversal(unittest.TestCase): + """Integration tests for tree traversal.""" + + @patch("stacky.stack.models.get_remote_info") + @patch("stacky.stack.models.get_commit") + def test_forest_traversal_order(self, mock_get_commit, mock_get_remote): + """Test that forest traversal visits branches in correct order.""" + from stacky.stack.tree import depth_first, forest_depth_first, make_tree + from stacky.utils.types import BranchesTreeForest + + mock_get_commit.return_value = Commit("abc123") + mock_get_remote.return_value = ("origin", BranchName("main"), Commit("abc123")) + + stack = StackBranchSet() + main = stack.add(BranchName("main"), parent=None, parent_commit=None) + + mock_get_commit.return_value = Commit("def456") + mock_get_remote.return_value = ("origin", BranchName("feature"), Commit("def456")) + feature = stack.add(BranchName("feature"), parent=main, parent_commit=Commit("abc123")) + stack.add_child(main, feature) + + tree = make_tree(main) + branches = list(depth_first(tree)) + + self.assertEqual(len(branches), 2) + self.assertEqual(branches[0].name, "main") + self.assertEqual(branches[1].name, "feature") + + +class TestPRInfoLoading(unittest.TestCase): + """Integration tests for PR info loading.""" + + @patch("stacky.stack.models.get_remote_info") + @patch("stacky.stack.models.get_commit") + @patch("stacky.pr.github.get_pr_info") + def test_lazy_pr_loading(self, mock_get_pr_info, mock_get_commit, mock_get_remote): + """Test that PR info is loaded lazily.""" + from stacky.stack.models import PRInfos + + mock_get_commit.return_value = Commit("abc123") + mock_get_remote.return_value = ("origin", BranchName("feature"), Commit("abc123")) + mock_get_pr_info.return_value = PRInfos( + all={"pr1": {"id": "pr1", "state": "OPEN", "number": 1}}, + open={"id": "pr1", "state": "OPEN", "number": 1} + ) + + stack = StackBranchSet() + feature = stack.add(BranchName("feature"), parent=None, parent_commit=None) + + # PR info not loaded yet + self.assertFalse(feature._pr_info_loaded) + self.assertEqual(feature.pr_info, {}) + + # Load PR info + feature.load_pr_info() + + # Now loaded + self.assertTrue(feature._pr_info_loaded) + mock_get_pr_info.assert_called_once_with(BranchName("feature")) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/stacky/tests/test_pr/__init__.py b/src/stacky/tests/test_pr/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/stacky/tests/test_pr/__pycache__/__init__.cpython-311.pyc b/src/stacky/tests/test_pr/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..94b5cb3 Binary files /dev/null and b/src/stacky/tests/test_pr/__pycache__/__init__.cpython-311.pyc differ diff --git a/src/stacky/tests/test_pr/__pycache__/test_github.cpython-311.pyc b/src/stacky/tests/test_pr/__pycache__/test_github.cpython-311.pyc new file mode 100644 index 0000000..0c8aba0 Binary files /dev/null and b/src/stacky/tests/test_pr/__pycache__/test_github.cpython-311.pyc differ diff --git a/src/stacky/tests/test_pr/test_github.py b/src/stacky/tests/test_pr/test_github.py new file mode 100644 index 0000000..6c7b76b --- /dev/null +++ b/src/stacky/tests/test_pr/test_github.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 +"""Tests for stacky.pr.github module.""" + +import unittest +from unittest.mock import patch, MagicMock + +from stacky.pr.github import ( + find_issue_marker, find_reviewers, extract_stack_comment, + generate_stack_string +) +from stacky.utils.types import BranchName, BranchesTreeForest, BranchesTree + + +class TestFindIssueMarker(unittest.TestCase): + """Tests for find_issue_marker function.""" + + def test_simple_issue_marker(self): + """Test finding simple issue marker.""" + result = find_issue_marker("SRE-12") + self.assertEqual(result, "SRE-12") + + def test_issue_marker_with_suffix(self): + """Test finding issue marker with suffix.""" + result = find_issue_marker("SRE-12-find-things") + self.assertEqual(result, "SRE-12") + + def test_issue_marker_underscore(self): + """Test finding issue marker with underscore.""" + result = find_issue_marker("SRE_12") + self.assertEqual(result, "SRE-12") + + def test_issue_marker_underscore_with_suffix(self): + """Test finding issue marker with underscore and suffix.""" + result = find_issue_marker("SRE_12-find-things") + self.assertEqual(result, "SRE-12") + + def test_issue_marker_with_prefix(self): + """Test finding issue marker with prefix.""" + result = find_issue_marker("john_SRE_12") + self.assertEqual(result, "SRE-12") + + def test_issue_marker_with_prefix_and_suffix(self): + """Test finding issue marker with prefix and suffix.""" + result = find_issue_marker("john_SRE_12-find-things") + self.assertEqual(result, "SRE-12") + + def test_issue_marker_no_separator(self): + """Test finding issue marker without separator.""" + result = find_issue_marker("john_SRE12-find-things") + self.assertEqual(result, "SRE-12") + + def test_issue_marker_date_prefix(self): + """Test finding issue marker with date prefix.""" + result = find_issue_marker("anna_01_01_SRE-12") + self.assertEqual(result, "SRE-12") + + def test_issue_marker_date_prefix_no_separator(self): + """Test finding issue marker with date prefix no separator.""" + result = find_issue_marker("anna_01_01_SRE12") + self.assertEqual(result, "SRE-12") + + def test_no_issue_marker(self): + """Test no issue marker found.""" + result = find_issue_marker("john_test_12") + self.assertIsNone(result) + + def test_no_issue_marker_no_separator(self): + """Test no issue marker found without separator.""" + result = find_issue_marker("john_test12") + self.assertIsNone(result) + + +class TestFindReviewers(unittest.TestCase): + """Tests for find_reviewers function.""" + + @patch("stacky.pr.github.run_multiline") + def test_find_reviewers_single(self, mock_run): + """Test finding single reviewer.""" + mock_run.return_value = "Some commit message\n\nReviewer: alice\n" + branch = MagicMock() + branch.name = BranchName("feature") + result = find_reviewers(branch) + self.assertEqual(result, ["alice"]) + + @patch("stacky.pr.github.run_multiline") + def test_find_reviewers_multiple(self, mock_run): + """Test finding multiple reviewers.""" + mock_run.return_value = "Some commit message\n\nReviewers: alice, bob\n" + branch = MagicMock() + branch.name = BranchName("feature") + result = find_reviewers(branch) + self.assertEqual(result, ["alice", " bob"]) + + @patch("stacky.pr.github.run_multiline") + def test_find_reviewers_none(self, mock_run): + """Test no reviewers found.""" + mock_run.return_value = "Some commit message\n" + branch = MagicMock() + branch.name = BranchName("feature") + result = find_reviewers(branch) + self.assertIsNone(result) + + +class TestExtractStackComment(unittest.TestCase): + """Tests for extract_stack_comment function.""" + + def test_extract_existing_comment(self): + """Test extracting existing stack comment.""" + body = """Some PR description + + +**Stack:** +- branch1 (#1) +- branch2 (#2) + + +More description""" + result = extract_stack_comment(body) + self.assertIn("Stacky Stack Info", result) + self.assertIn("branch1", result) + + def test_extract_no_comment(self): + """Test extracting when no comment exists.""" + body = "Just a regular PR description" + result = extract_stack_comment(body) + self.assertEqual(result, "") + + def test_extract_empty_body(self): + """Test extracting from empty body.""" + result = extract_stack_comment("") + self.assertEqual(result, "") + + def test_extract_none_body(self): + """Test extracting from None body.""" + result = extract_stack_comment(None) + self.assertEqual(result, "") + + +class TestGenerateStackString(unittest.TestCase): + """Tests for generate_stack_string function.""" + + def test_generate_empty_forest(self): + """Test generating stack string for empty forest.""" + forest = BranchesTreeForest([]) + branch = MagicMock() + branch.name = BranchName("feature") + result = generate_stack_string(forest, branch) + self.assertEqual(result, "") + + def test_generate_with_branches(self): + """Test generating stack string with branches.""" + branch1 = MagicMock() + branch1.name = BranchName("feature-1") + branch1.open_pr_info = {"number": 1} + + branch2 = MagicMock() + branch2.name = BranchName("feature-2") + branch2.open_pr_info = {"number": 2} + + tree = BranchesTree({ + "feature-1": (branch1, BranchesTree({ + "feature-2": (branch2, BranchesTree({})) + })) + }) + forest = BranchesTreeForest([tree]) + + result = generate_stack_string(forest, branch2) + self.assertIn("Stacky Stack Info", result) + self.assertIn("feature-1", result) + self.assertIn("feature-2", result) + self.assertIn("CURRENT PR", result) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/stacky/tests/test_stack/__init__.py b/src/stacky/tests/test_stack/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/stacky/tests/test_stack/__pycache__/__init__.cpython-311.pyc b/src/stacky/tests/test_stack/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..48099e2 Binary files /dev/null and b/src/stacky/tests/test_stack/__pycache__/__init__.cpython-311.pyc differ diff --git a/src/stacky/tests/test_stack/__pycache__/test_models.cpython-311.pyc b/src/stacky/tests/test_stack/__pycache__/test_models.cpython-311.pyc new file mode 100644 index 0000000..d46c866 Binary files /dev/null and b/src/stacky/tests/test_stack/__pycache__/test_models.cpython-311.pyc differ diff --git a/src/stacky/tests/test_stack/__pycache__/test_tree.cpython-311.pyc b/src/stacky/tests/test_stack/__pycache__/test_tree.cpython-311.pyc new file mode 100644 index 0000000..1b3e1c3 Binary files /dev/null and b/src/stacky/tests/test_stack/__pycache__/test_tree.cpython-311.pyc differ diff --git a/src/stacky/tests/test_stack/test_models.py b/src/stacky/tests/test_stack/test_models.py new file mode 100644 index 0000000..36d254e --- /dev/null +++ b/src/stacky/tests/test_stack/test_models.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 +"""Tests for stacky.stack.models module.""" + +import unittest +from unittest.mock import patch, MagicMock + +from stacky.stack.models import PRInfo, PRInfos, StackBranch, StackBranchSet +from stacky.utils.types import BranchName, Commit + + +class TestPRInfos(unittest.TestCase): + """Tests for PRInfos dataclass.""" + + def test_prinfos_creation(self): + """Test PRInfos creation.""" + all_prs = {"pr1": {"id": "pr1", "state": "OPEN"}} + open_pr = {"id": "pr1", "state": "OPEN"} + pr_infos = PRInfos(all=all_prs, open=open_pr) + self.assertEqual(pr_infos.all, all_prs) + self.assertEqual(pr_infos.open, open_pr) + + def test_prinfos_no_open_pr(self): + """Test PRInfos with no open PR.""" + pr_infos = PRInfos(all={}, open=None) + self.assertIsNone(pr_infos.open) + + +class TestStackBranch(unittest.TestCase): + """Tests for StackBranch class.""" + + @patch("stacky.stack.models.get_remote_info") + @patch("stacky.stack.models.get_commit") + def test_stack_branch_creation(self, mock_get_commit, mock_get_remote): + """Test StackBranch creation.""" + mock_get_commit.return_value = Commit("abc123") + mock_get_remote.return_value = ("origin", BranchName("feature"), Commit("abc123")) + + parent = MagicMock() + parent.commit = Commit("parent123") + + branch = StackBranch( + name=BranchName("feature"), + parent=parent, + parent_commit=Commit("parent123") + ) + + self.assertEqual(branch.name, "feature") + self.assertEqual(branch.parent, parent) + self.assertEqual(branch.parent_commit, Commit("parent123")) + self.assertEqual(branch.commit, Commit("abc123")) + + @patch("stacky.stack.models.get_remote_info") + @patch("stacky.stack.models.get_commit") + def test_is_synced_with_parent(self, mock_get_commit, mock_get_remote): + """Test is_synced_with_parent method.""" + mock_get_commit.return_value = Commit("abc123") + mock_get_remote.return_value = ("origin", BranchName("feature"), Commit("abc123")) + + parent = MagicMock() + parent.commit = Commit("parent123") + + branch = StackBranch( + name=BranchName("feature"), + parent=parent, + parent_commit=Commit("parent123") + ) + + self.assertTrue(branch.is_synced_with_parent()) + + # Unsynced case + parent.commit = Commit("different") + self.assertFalse(branch.is_synced_with_parent()) + + @patch("stacky.stack.models.get_remote_info") + @patch("stacky.stack.models.get_commit") + def test_is_synced_with_remote(self, mock_get_commit, mock_get_remote): + """Test is_synced_with_remote method.""" + mock_get_commit.return_value = Commit("abc123") + mock_get_remote.return_value = ("origin", BranchName("feature"), Commit("abc123")) + + branch = StackBranch( + name=BranchName("feature"), + parent=None, + parent_commit=None + ) + + self.assertTrue(branch.is_synced_with_remote()) + + # Unsynced case + branch.remote_commit = Commit("different") + self.assertFalse(branch.is_synced_with_remote()) + + +class TestStackBranchSet(unittest.TestCase): + """Tests for StackBranchSet class.""" + + def test_stack_branch_set_creation(self): + """Test StackBranchSet creation.""" + stack_set = StackBranchSet() + self.assertEqual(stack_set.stack, {}) + self.assertEqual(stack_set.tops, set()) + self.assertEqual(stack_set.bottoms, set()) + + @patch("stacky.stack.models.get_remote_info") + @patch("stacky.stack.models.get_commit") + def test_add_branch(self, mock_get_commit, mock_get_remote): + """Test adding a branch to the set.""" + mock_get_commit.return_value = Commit("abc123") + mock_get_remote.return_value = ("origin", BranchName("main"), Commit("abc123")) + + stack_set = StackBranchSet() + branch = stack_set.add( + BranchName("main"), + parent=None, + parent_commit=None + ) + + self.assertIn(BranchName("main"), stack_set.stack) + self.assertIn(branch, stack_set.bottoms) + self.assertIn(branch, stack_set.tops) + + @patch("stacky.stack.models.get_remote_info") + @patch("stacky.stack.models.get_commit") + def test_remove_branch(self, mock_get_commit, mock_get_remote): + """Test removing a branch from the set.""" + mock_get_commit.return_value = Commit("abc123") + mock_get_remote.return_value = ("origin", BranchName("main"), Commit("abc123")) + + stack_set = StackBranchSet() + branch = stack_set.add( + BranchName("main"), + parent=None, + parent_commit=None + ) + + removed = stack_set.remove(BranchName("main")) + self.assertEqual(removed, branch) + self.assertNotIn(BranchName("main"), stack_set.stack) + + @patch("stacky.stack.models.get_remote_info") + @patch("stacky.stack.models.get_commit") + def test_add_child(self, mock_get_commit, mock_get_remote): + """Test adding child relationship.""" + mock_get_commit.return_value = Commit("abc123") + mock_get_remote.return_value = ("origin", BranchName("main"), Commit("abc123")) + + stack_set = StackBranchSet() + parent = stack_set.add( + BranchName("main"), + parent=None, + parent_commit=None + ) + + child = stack_set.add( + BranchName("feature"), + parent=parent, + parent_commit=Commit("abc123") + ) + + stack_set.add_child(parent, child) + self.assertIn(child, parent.children) + self.assertNotIn(parent, stack_set.tops) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/stacky/tests/test_stack/test_tree.py b/src/stacky/tests/test_stack/test_tree.py new file mode 100644 index 0000000..4bcca27 --- /dev/null +++ b/src/stacky/tests/test_stack/test_tree.py @@ -0,0 +1,137 @@ +#!/usr/bin/env python3 +"""Tests for stacky.stack.tree module.""" + +import unittest +from unittest.mock import patch, MagicMock + +from stacky.stack.tree import ( + get_pr_status_emoji, make_tree_node, make_subtree, make_tree, + format_name, depth_first, forest_depth_first +) +from stacky.utils.types import BranchName, BranchesTree, BranchesTreeForest + + +class TestGetPrStatusEmoji(unittest.TestCase): + """Tests for get_pr_status_emoji function.""" + + def test_no_pr_info(self): + """Test emoji for no PR info.""" + result = get_pr_status_emoji(None) + self.assertEqual(result, "") + + def test_draft_pr(self): + """Test emoji for draft PR.""" + pr_info = {"isDraft": True} + result = get_pr_status_emoji(pr_info) + self.assertEqual(result, " 🚧") + + def test_approved_pr(self): + """Test emoji for approved PR.""" + pr_info = {"isDraft": False, "reviewDecision": "APPROVED", "reviewRequests": []} + result = get_pr_status_emoji(pr_info) + self.assertEqual(result, " ✅") + + def test_pending_review_pr(self): + """Test emoji for PR waiting on review.""" + pr_info = {"isDraft": False, "reviewDecision": None, "reviewRequests": [{"login": "reviewer"}]} + result = get_pr_status_emoji(pr_info) + self.assertEqual(result, " 🔄") + + def test_needs_changes_pr(self): + """Test emoji for PR needing changes.""" + pr_info = {"isDraft": False, "reviewDecision": "CHANGES_REQUESTED", "reviewRequests": []} + result = get_pr_status_emoji(pr_info) + self.assertEqual(result, " ❌") + + +class TestMakeTree(unittest.TestCase): + """Tests for tree building functions.""" + + def test_make_subtree_no_children(self): + """Test make_subtree with no children.""" + branch = MagicMock() + branch.children = set() + result = make_subtree(branch) + self.assertEqual(result, {}) + + def test_make_tree(self): + """Test make_tree creates correct structure.""" + branch = MagicMock() + branch.name = BranchName("feature") + branch.children = set() + result = make_tree(branch) + self.assertIn("feature", result) + + +class TestFormatName(unittest.TestCase): + """Tests for format_name function.""" + + @patch("stacky.stack.tree.get_current_branch_name") + @patch("stacky.stack.tree.get_config") + def test_format_name_current_branch(self, mock_config, mock_current): + """Test format_name marks current branch.""" + mock_current.return_value = BranchName("feature") + mock_config.return_value = MagicMock(compact_pr_display=False) + + branch = MagicMock() + branch.name = BranchName("feature") + branch.is_synced_with_parent.return_value = True + branch.is_synced_with_remote.return_value = True + branch.open_pr_info = None + + result = format_name(branch, colorize=False) + self.assertIn("*", result) + self.assertIn("feature", result) + + @patch("stacky.stack.tree.get_current_branch_name") + @patch("stacky.stack.tree.get_config") + def test_format_name_not_synced_parent(self, mock_config, mock_current): + """Test format_name shows ! when not synced with parent.""" + mock_current.return_value = BranchName("other") + mock_config.return_value = MagicMock(compact_pr_display=False) + + branch = MagicMock() + branch.name = BranchName("feature") + branch.is_synced_with_parent.return_value = False + branch.is_synced_with_remote.return_value = True + branch.open_pr_info = None + + result = format_name(branch, colorize=False) + self.assertIn("!", result) + + +class TestDepthFirst(unittest.TestCase): + """Tests for depth-first traversal functions.""" + + def test_depth_first_empty(self): + """Test depth_first with empty tree.""" + tree = BranchesTree({}) + result = list(depth_first(tree)) + self.assertEqual(result, []) + + def test_depth_first_single_branch(self): + """Test depth_first with single branch.""" + branch = MagicMock() + branch.name = BranchName("feature") + tree = BranchesTree({"feature": (branch, BranchesTree({}))}) + result = list(depth_first(tree)) + self.assertEqual(result, [branch]) + + def test_forest_depth_first_empty(self): + """Test forest_depth_first with empty forest.""" + forest = BranchesTreeForest([]) + result = list(forest_depth_first(forest)) + self.assertEqual(result, []) + + def test_forest_depth_first_single_tree(self): + """Test forest_depth_first with single tree.""" + branch = MagicMock() + branch.name = BranchName("feature") + tree = BranchesTree({"feature": (branch, BranchesTree({}))}) + forest = BranchesTreeForest([tree]) + result = list(forest_depth_first(forest)) + self.assertEqual(result, [branch]) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/stacky/tests/test_utils/__init__.py b/src/stacky/tests/test_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/stacky/tests/test_utils/__pycache__/__init__.cpython-311.pyc b/src/stacky/tests/test_utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..7d635cb Binary files /dev/null and b/src/stacky/tests/test_utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/src/stacky/tests/test_utils/__pycache__/test_config.cpython-311.pyc b/src/stacky/tests/test_utils/__pycache__/test_config.cpython-311.pyc new file mode 100644 index 0000000..1f80239 Binary files /dev/null and b/src/stacky/tests/test_utils/__pycache__/test_config.cpython-311.pyc differ diff --git a/src/stacky/tests/test_utils/__pycache__/test_shell.cpython-311.pyc b/src/stacky/tests/test_utils/__pycache__/test_shell.cpython-311.pyc new file mode 100644 index 0000000..c86e792 Binary files /dev/null and b/src/stacky/tests/test_utils/__pycache__/test_shell.cpython-311.pyc differ diff --git a/src/stacky/tests/test_utils/test_config.py b/src/stacky/tests/test_utils/test_config.py new file mode 100644 index 0000000..4357c2c --- /dev/null +++ b/src/stacky/tests/test_utils/test_config.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +"""Tests for stacky.utils.config module.""" + +import os +import tempfile +import unittest +from unittest.mock import patch + +from stacky.utils.config import StackyConfig, read_config, get_config + + +class TestStackyConfig(unittest.TestCase): + """Tests for StackyConfig dataclass.""" + + def test_default_values(self): + """Test StackyConfig has correct default values.""" + config = StackyConfig() + self.assertFalse(config.skip_confirm) + self.assertFalse(config.change_to_main) + self.assertFalse(config.change_to_adopted) + self.assertFalse(config.share_ssh_session) + self.assertFalse(config.use_merge) + self.assertTrue(config.use_force_push) + self.assertFalse(config.compact_pr_display) + self.assertTrue(config.enable_stack_comment) + + def test_read_one_config_ui_section(self): + """Test reading UI section from config file.""" + config = StackyConfig() + with tempfile.NamedTemporaryFile(mode='w', suffix='.ini', delete=False) as f: + f.write("[UI]\n") + f.write("skip_confirm = true\n") + f.write("change_to_main = true\n") + f.write("compact_pr_display = true\n") + f.name + try: + config.read_one_config(f.name) + self.assertTrue(config.skip_confirm) + self.assertTrue(config.change_to_main) + self.assertTrue(config.compact_pr_display) + finally: + os.unlink(f.name) + + def test_read_one_config_git_section(self): + """Test reading GIT section from config file.""" + config = StackyConfig() + with tempfile.NamedTemporaryFile(mode='w', suffix='.ini', delete=False) as f: + f.write("[GIT]\n") + f.write("use_merge = true\n") + f.write("use_force_push = false\n") + f.name + try: + config.read_one_config(f.name) + self.assertTrue(config.use_merge) + self.assertFalse(config.use_force_push) + finally: + os.unlink(f.name) + + +class TestReadConfig(unittest.TestCase): + """Tests for read_config function.""" + + @patch("os.path.exists", return_value=False) + @patch("stacky.utils.config.debug") + def test_read_config_no_files(self, mock_debug, mock_exists): + """Test read_config returns defaults when no config files exist.""" + config = read_config() + self.assertIsInstance(config, StackyConfig) + self.assertFalse(config.skip_confirm) + + @patch("os.path.exists", return_value=False) + def test_read_config_with_no_files(self, mock_exists): + """Test read_config returns defaults when no config files exist.""" + # Mock the get_top_level_dir to raise an exception (not in git repo) + with patch("stacky.git.branch.get_top_level_dir", side_effect=Exception("Not in git repo")): + config = read_config() + self.assertIsInstance(config, StackyConfig) + # Should have default values + self.assertFalse(config.skip_confirm) + + +class TestGetConfig(unittest.TestCase): + """Tests for get_config singleton function.""" + + def setUp(self): + """Reset global CONFIG before each test.""" + import stacky.utils.config as config_module + config_module.CONFIG = None + + @patch("stacky.utils.config.read_config") + def test_get_config_caches_result(self, mock_read_config): + """Test get_config caches the config and only reads once.""" + mock_config = StackyConfig(skip_confirm=True) + mock_read_config.return_value = mock_config + + result1 = get_config() + result2 = get_config() + + self.assertEqual(result1, result2) + mock_read_config.assert_called_once() + + +if __name__ == "__main__": + unittest.main() diff --git a/src/stacky/tests/test_utils/test_shell.py b/src/stacky/tests/test_utils/test_shell.py new file mode 100644 index 0000000..3605a60 --- /dev/null +++ b/src/stacky/tests/test_utils/test_shell.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python3 +"""Tests for stacky.utils.shell module.""" + +import shlex +import subprocess +import unittest +from unittest.mock import patch, MagicMock + +from stacky.utils.shell import ( + _check_returncode, run, run_multiline, run_always_return, remove_prefix +) + + +class TestCheckReturnCode(unittest.TestCase): + """Tests for _check_returncode function.""" + + @patch("stacky.utils.shell.die") + def test_check_returncode_zero(self, mock_die): + """Test that zero return code does not call die.""" + sp = subprocess.CompletedProcess(args=["ls"], returncode=0) + _check_returncode(sp, ["ls"]) + mock_die.assert_not_called() + + @patch("stacky.utils.shell.die") + def test_check_returncode_negative(self, mock_die): + """Test that negative return code (signal) calls die with signal info.""" + sp = subprocess.CompletedProcess(args=["ls"], returncode=-1, stderr=b"error") + _check_returncode(sp, ["ls"]) + mock_die.assert_called_once_with( + "Killed by signal {}: {}. Stderr was:\n{}", + 1, shlex.join(["ls"]), "error" + ) + + @patch("stacky.utils.shell.die") + def test_check_returncode_positive(self, mock_die): + """Test that positive return code calls die with exit status.""" + sp = subprocess.CompletedProcess(args=["ls"], returncode=1, stderr=b"error") + _check_returncode(sp, ["ls"]) + mock_die.assert_called_once_with( + "Exited with status {}: {}. Stderr was:\n{}", + 1, shlex.join(["ls"]), "error" + ) + + +class TestRun(unittest.TestCase): + """Tests for run functions.""" + + @patch("subprocess.run") + @patch("stacky.utils.shell.debug") + def test_run_success(self, mock_debug, mock_subprocess_run): + """Test run returns stripped output on success.""" + mock_subprocess_run.return_value = subprocess.CompletedProcess( + args=["echo", "hello"], + returncode=0, + stdout=b" hello world \n", + stderr=b"" + ) + result = run(["echo", "hello"]) + self.assertEqual(result, "hello world") + + @patch("subprocess.run") + @patch("stacky.utils.shell.debug") + def test_run_failure_check_false(self, mock_debug, mock_subprocess_run): + """Test run returns None on failure when check=False.""" + mock_subprocess_run.return_value = subprocess.CompletedProcess( + args=["false"], + returncode=1, + stdout=b"", + stderr=b"error" + ) + result = run(["false"], check=False) + self.assertIsNone(result) + + @patch("subprocess.run") + @patch("stacky.utils.shell.debug") + def test_run_multiline_preserves_newlines(self, mock_debug, mock_subprocess_run): + """Test run_multiline preserves newlines in output.""" + mock_subprocess_run.return_value = subprocess.CompletedProcess( + args=["echo", "-e", "line1\\nline2"], + returncode=0, + stdout=b"line1\nline2\n", + stderr=b"" + ) + result = run_multiline(["echo", "-e", "line1\\nline2"]) + self.assertEqual(result, "line1\nline2\n") + + @patch("subprocess.run") + @patch("stacky.utils.shell.debug") + def test_run_always_return_asserts_not_none(self, mock_debug, mock_subprocess_run): + """Test run_always_return returns output (asserts not None).""" + mock_subprocess_run.return_value = subprocess.CompletedProcess( + args=["echo", "test"], + returncode=0, + stdout=b"test", + stderr=b"" + ) + result = run_always_return(["echo", "test"]) + self.assertEqual(result, "test") + + +class TestRemovePrefix(unittest.TestCase): + """Tests for remove_prefix function.""" + + def test_remove_prefix_success(self): + """Test remove_prefix removes prefix correctly.""" + result = remove_prefix("refs/heads/main", "refs/heads/") + self.assertEqual(result, "main") + + def test_remove_prefix_full_match(self): + """Test remove_prefix with exact match returns empty string.""" + result = remove_prefix("prefix", "prefix") + self.assertEqual(result, "") + + @patch("stacky.utils.shell.die") + def test_remove_prefix_no_match(self, mock_die): + """Test remove_prefix dies when prefix not found.""" + remove_prefix("other/path", "refs/heads/") + mock_die.assert_called_once() + + +if __name__ == "__main__": + unittest.main() diff --git a/src/stacky/utils/__init__.py b/src/stacky/utils/__init__.py new file mode 100644 index 0000000..50ab185 --- /dev/null +++ b/src/stacky/utils/__init__.py @@ -0,0 +1 @@ +# Utils module - shared utilities for stacky diff --git a/src/stacky/utils/config.py b/src/stacky/utils/config.py new file mode 100644 index 0000000..1764d89 --- /dev/null +++ b/src/stacky/utils/config.py @@ -0,0 +1,71 @@ +"""Configuration management for stacky.""" + +import configparser +import dataclasses +import os +from typing import Optional + +from stacky.utils.logging import debug + + +@dataclasses.dataclass +class StackyConfig: + """Configuration options for stacky.""" + skip_confirm: bool = False + change_to_main: bool = False + change_to_adopted: bool = False + share_ssh_session: bool = False + use_merge: bool = False + use_force_push: bool = True + compact_pr_display: bool = False + enable_stack_comment: bool = True + + def read_one_config(self, config_path: str): + """Read configuration from a single file.""" + rawconfig = configparser.ConfigParser() + rawconfig.read(config_path) + if rawconfig.has_section("UI"): + self.skip_confirm = rawconfig.getboolean("UI", "skip_confirm", fallback=self.skip_confirm) + self.change_to_main = rawconfig.getboolean("UI", "change_to_main", fallback=self.change_to_main) + self.change_to_adopted = rawconfig.getboolean("UI", "change_to_adopted", fallback=self.change_to_adopted) + self.share_ssh_session = rawconfig.getboolean("UI", "share_ssh_session", fallback=self.share_ssh_session) + self.compact_pr_display = rawconfig.getboolean("UI", "compact_pr_display", fallback=self.compact_pr_display) + self.enable_stack_comment = rawconfig.getboolean("UI", "enable_stack_comment", fallback=self.enable_stack_comment) + + if rawconfig.has_section("GIT"): + self.use_merge = rawconfig.getboolean("GIT", "use_merge", fallback=self.use_merge) + self.use_force_push = rawconfig.getboolean("GIT", "use_force_push", fallback=self.use_force_push) + + +# Global config singleton +CONFIG: Optional[StackyConfig] = None + + +def get_config() -> StackyConfig: + """Get the global configuration, loading it if necessary.""" + global CONFIG + if CONFIG is None: + CONFIG = read_config() + return CONFIG + + +def read_config() -> StackyConfig: + """Read configuration from config files.""" + config = StackyConfig() + config_paths = [os.path.expanduser("~/.stackyconfig")] + + try: + from stacky.git.branch import get_top_level_dir + root_dir = get_top_level_dir() + config_paths.append(f"{root_dir}/.stackyconfig") + except Exception: + # Not in a git repository, skip the repo-level config + debug("Not in a git repository, skipping repo-level config") + pass + + for p in config_paths: + # Root dir config overwrites home directory config + if os.path.exists(p): + config.read_one_config(p) + + return config diff --git a/src/stacky/utils/logging.py b/src/stacky/utils/logging.py new file mode 100644 index 0000000..cbce0b0 --- /dev/null +++ b/src/stacky/utils/logging.py @@ -0,0 +1,77 @@ +"""Logging and output utilities for stacky.""" + +import logging +import os +import sys + +import colors # type: ignore + +_LOGGING_FORMAT = "%(asctime)s %(module)s %(levelname)s: %(message)s" + +# Terminal state - can be modified by main() +COLOR_STDOUT: bool = os.isatty(1) +COLOR_STDERR: bool = os.isatty(2) +IS_TERMINAL: bool = os.isatty(1) and os.isatty(2) + + +def set_color_mode(mode: str): + """Set color mode: 'always', 'auto', or 'never'.""" + global COLOR_STDOUT, COLOR_STDERR + if mode == "always": + COLOR_STDOUT = True + COLOR_STDERR = True + elif mode == "never": + COLOR_STDOUT = False + COLOR_STDERR = False + # 'auto' keeps the default based on isatty + + +def fmt(s: str, *args, color: bool = False, fg=None, bg=None, style=None, **kwargs) -> str: + """Format a string with optional color.""" + s = colors.color(s, fg=fg, bg=bg, style=style) if color else s + return s.format(*args, **kwargs) + + +def cout(*args, **kwargs): + """Write colored output to stdout.""" + return sys.stdout.write(fmt(*args, color=COLOR_STDOUT, **kwargs)) + + +def _log(fn, *args, **kwargs): + """Internal log helper.""" + return fn("%s", fmt(*args, color=COLOR_STDERR, **kwargs)) + + +def debug(*args, **kwargs): + """Log debug message.""" + return _log(logging.debug, *args, fg="green", **kwargs) + + +def info(*args, **kwargs): + """Log info message.""" + return _log(logging.info, *args, fg="green", **kwargs) + + +def warning(*args, **kwargs): + """Log warning message.""" + return _log(logging.warning, *args, fg="yellow", **kwargs) + + +def error(*args, **kwargs): + """Log error message.""" + return _log(logging.error, *args, fg="red", **kwargs) + + +class ExitException(BaseException): + """Exception raised when the program should exit.""" + def __init__(self, fmt, *args, **kwargs): + super().__init__(fmt.format(*args, **kwargs)) + + +def die(*args, **kwargs): + """Exit with an error message. Stops SSH mux if active.""" + # Import here to avoid circular dependency + from stacky.git.remote import stop_muxed_ssh + # We are taking a wild guess at what is the remote ... + stop_muxed_ssh() + raise ExitException(*args, **kwargs) diff --git a/src/stacky/utils/shell.py b/src/stacky/utils/shell.py new file mode 100644 index 0000000..6025e81 --- /dev/null +++ b/src/stacky/utils/shell.py @@ -0,0 +1,61 @@ +"""Shell execution utilities for stacky.""" + +import shlex +import subprocess +import sys +from typing import Optional + +from stacky.utils.logging import debug, die +from stacky.utils.types import CmdArgs + + +def _check_returncode(sp: subprocess.CompletedProcess, cmd: CmdArgs): + """Check the return code of a subprocess and die if non-zero.""" + rc = sp.returncode + if rc == 0: + return + stderr = sp.stderr.decode("UTF-8") + if rc < 0: + die("Killed by signal {}: {}. Stderr was:\n{}", -rc, shlex.join(cmd), stderr) + else: + die("Exited with status {}: {}. Stderr was:\n{}", rc, shlex.join(cmd), stderr) + + +def run_multiline(cmd: CmdArgs, *, check: bool = True, null: bool = True, out: bool = False) -> Optional[str]: + """Run a command and return its output (with newlines preserved).""" + debug("Running: {}", shlex.join(cmd)) + sys.stdout.flush() + sys.stderr.flush() + sp = subprocess.run( + cmd, + stdout=1 if out else subprocess.PIPE, + stderr=subprocess.PIPE, + ) + if check: + _check_returncode(sp, cmd) + rc = sp.returncode + if rc != 0: + return None + if sp.stdout is None: + return "" + return sp.stdout.decode("UTF-8") + + +def run_always_return(cmd: CmdArgs, **kwargs) -> str: + """Run a command and always return output (asserts it's not None).""" + out = run(cmd, **kwargs) + assert out is not None + return out + + +def run(cmd: CmdArgs, **kwargs) -> Optional[str]: + """Run a command and return stripped output.""" + out = run_multiline(cmd, **kwargs) + return None if out is None else out.strip() + + +def remove_prefix(s: str, prefix: str) -> str: + """Remove a prefix from a string, dying if not present.""" + if not s.startswith(prefix): + die('Invalid string "{}": expected prefix "{}"', s, prefix) + return s[len(prefix):] diff --git a/src/stacky/utils/types.py b/src/stacky/utils/types.py new file mode 100644 index 0000000..224038c --- /dev/null +++ b/src/stacky/utils/types.py @@ -0,0 +1,39 @@ +"""Type aliases and constants for stacky.""" + +import logging +import os +from typing import Dict, FrozenSet, List, NewType, Tuple, Union + +# Type aliases +BranchName = NewType("BranchName", str) +PathName = NewType("PathName", str) +Commit = NewType("Commit", str) +CmdArgs = NewType("CmdArgs", List[str]) + +# Forward reference types (actual types defined in stack/models.py) +# These are used for type hints only +StackSubTree = Tuple["StackBranch", "BranchesTree"] # type: ignore +TreeNode = Tuple[BranchName, StackSubTree] +BranchesTree = NewType("BranchesTree", Dict[BranchName, StackSubTree]) +BranchesTreeForest = NewType("BranchesTreeForest", List[BranchesTree]) + +JSON = Union[Dict[str, "JSON"], List["JSON"], str, int, float, bool, None] + +# Constants +MAX_SSH_MUX_LIFETIME = 120 # 2 minutes ought to be enough for anybody ;-) +STATE_FILE = os.path.expanduser("~/.stacky.state") +TMP_STATE_FILE = STATE_FILE + ".tmp" + +# Stack bottoms - mutable set that can be extended +STACK_BOTTOMS: set[BranchName] = set([BranchName("master"), BranchName("main")]) +FROZEN_STACK_BOTTOMS: FrozenSet[BranchName] = frozenset([BranchName("master"), BranchName("main")]) + +# Log levels +LOGLEVELS = { + "critical": logging.CRITICAL, + "error": logging.ERROR, + "warn": logging.WARNING, + "warning": logging.WARNING, + "info": logging.INFO, + "debug": logging.DEBUG, +} diff --git a/src/stacky/utils/ui.py b/src/stacky/utils/ui.py new file mode 100644 index 0000000..c24e039 --- /dev/null +++ b/src/stacky/utils/ui.py @@ -0,0 +1,94 @@ +"""User interface utilities for stacky.""" + +import os +import sys +from typing import TYPE_CHECKING + +import asciitree # type: ignore +from simple_term_menu import TerminalMenu # type: ignore + +from stacky.utils.config import get_config +from stacky.utils.logging import IS_TERMINAL, cout, die + +if TYPE_CHECKING: + from stacky.stack.models import StackBranch + from stacky.utils.types import BranchesTreeForest + + +def prompt(message: str, default_value: str | None) -> str: + """Prompt the user for input.""" + cout(message) + if default_value is not None: + cout("({})", default_value, fg="gray") + cout(" ") + while True: + sys.stderr.flush() + r = input().strip() + + if len(r) > 0: + return r + if default_value: + return default_value + + +def confirm(msg: str = "Proceed?"): + """Ask for confirmation. Skips if skip_confirm is set.""" + if get_config().skip_confirm: + return + if not os.isatty(0): + die("Standard input is not a terminal, use --force option to force action") + print() + while True: + cout("{} [yes/no] ", msg, fg="yellow") + sys.stderr.flush() + r = input().strip().lower() + if r == "yes" or r == "y": + break + if r == "no": + die("Not confirmed") + cout("Please answer yes or no\n", fg="red") + + +# Print upside down, to match our "upstack" / "downstack" nomenclature +_ASCII_TREE_BOX = { + "UP_AND_RIGHT": "\u250c", + "HORIZONTAL": "\u2500", + "VERTICAL": "\u2502", + "VERTICAL_AND_RIGHT": "\u251c", +} +_ASCII_TREE_STYLE = asciitree.drawing.BoxStyle(gfx=_ASCII_TREE_BOX) +ASCII_TREE = asciitree.LeftAligned(draw=_ASCII_TREE_STYLE) + + +def menu_choose_branch(forest: "BranchesTreeForest") -> "StackBranch": + """Display a menu for choosing a branch from the forest.""" + # Import here to avoid circular dependency + from stacky.stack.tree import forest_depth_first, format_tree + + if not IS_TERMINAL: + die("May only choose from menu when using a terminal") + + s = "" + lines = [] + for tree in forest: + s = ASCII_TREE(format_tree(tree)) + lines += [l.rstrip() for l in s.split("\n")] + lines.reverse() + + # Find current branch marker + from stacky.git.branch import get_current_branch_name + current = get_current_branch_name() + initial_index = 0 + for i, l in enumerate(lines): + if "*" in l: # lol + initial_index = i + break + + menu = TerminalMenu(lines, cursor_index=initial_index) + idx = menu.show() + if idx is None: + die("Aborted") + + branches = list(forest_depth_first(forest)) + branches.reverse() + return branches[idx]