diff --git a/.gemini/commands/find-docs.toml b/.gemini/commands/find-docs.toml new file mode 100644 index 00000000000..8282b2a589a --- /dev/null +++ b/.gemini/commands/find-docs.toml @@ -0,0 +1,30 @@ +description = "Find relevant documentation and output GitHub URLs." + +prompt = """ +## Mission: Find Relevant Documentation + +Your task is to find documentation files relevant to the user's question within the current git repository and provide a list of GitHub URLs to view them. + +### Workflow: + +1. **Identify Repository Details**: + * You may use shell commands like `git` or `gh` to get the remote URL of the repository. + * From the remote URL, parse and construct the base GitHub URL (e.g., `https://github.com/user/repo`). You must handle both HTTPS (`https://github.com/user/repo.git`) and SSH (`git@github.com:user/repo.git`) formats. + * Determine the default branch name. You can assume `main` for this purpose, as it is the most common. + +2. **Search for Documentation**: + * First, perform a targeted search across the repository for documentation files (e.g., `.md`, `.mdx`) that seem directly related to the user's question. + * If this initial search yields no relevant results, and a `docs/` directory exists, read the content of all files within the `docs/` directory to find relevant information. + * If you still can't find a direct match, broaden your search to include related concepts and synonyms of the keywords in the user's question. + * For each file you identify as potentially relevant, read its content to confirm it addresses the user's query. + +3. **Construct and Output URLs**: + * For each file you identify as relevant, construct the full GitHub URL by combining the base URL, branch, and file path. **Do not use shell commands for this step.** + * The URL format should be: `{BASE_GITHUB_URL}/blob/{BRANCH_NAME}/{PATH_TO_FILE_FROM_REPO_ROOT}`. + * Present the final list to the user as a markdown list. 
Each item in the list should be the URL to the document, followed by a short summary of its content. + * If, after all search attempts, you cannot find any relevant documentation, ask the user clarifying questions to better understand their needs. Do not return any URLs in this case. + +### QUESTION: + +{{args}} +""" diff --git a/.gemini/commands/github/cleanup-back-to-main.toml b/.gemini/commands/github/cleanup-back-to-main.toml new file mode 100644 index 00000000000..957eed06f20 --- /dev/null +++ b/.gemini/commands/github/cleanup-back-to-main.toml @@ -0,0 +1,13 @@ +description = "Go back to main and clean up the branch." + +prompt = """ +I'm done with the work on this branch, and I'm ready to go back to main and clean up. + +Here is the workflow I'd like you to follow: + +1. **Get Current Branch:** First, I need you to get the name of the current branch and save it. +2. **Branch Check:** Check if the current branch is `main`. If it is, I need you to stop and let me know. +3. **Go to Main:** Next, I need you to checkout the main branch. +4. **Pull Latest:** Once you are on the main branch, I need you to pull down the latest changes to make sure I'm up to date. +5. **Branch Cleanup:** Finally, I need you to delete the branch that you noted in the first step. +""" diff --git a/.gemini/commands/oncall/pr-review.toml b/.gemini/commands/oncall/pr-review.toml new file mode 100644 index 00000000000..88df74c5861 --- /dev/null +++ b/.gemini/commands/oncall/pr-review.toml @@ -0,0 +1,47 @@ +description = "Review a specific pull request" + +prompt = """ +## Mission: Comprehensive Pull Request Review + +Today, our mission is to meticulously review community pull requests (PRs) for this project. We will proceed systematically, evaluating each candidate PR for its quality, adherence to standards, and readiness for merging. + +### Workflow: + +1. **PR Preparation & Initial Assessment**: + * **You will check out the designated PR {{args}}** into a temporary branch. 
+ * **Execute the preflight checks (`npm run preflight`)**. This includes building, linting, and running all unit tests. + * Analyze the output of these preflight checks, noting any failures, warnings, or linting issues. + +2. **In-Depth Code Review**: + * **Your primary role is to conduct a thorough and in-depth code review** of the changes introduced in the PR. Focus your analysis on the following criteria: + * **Correctness**: Does the code achieve its stated purpose without bugs or logical errors? + * **Maintainability**: Is the code clean, well-structured, and easy to understand and modify in the future? Consider factors like code clarity, modularity, and adherence to established design patterns. + * **Readability**: Is the code well-commented (where necessary) and consistently formatted according to our project's coding style guidelines? + * **Efficiency**: Are there any obvious performance bottlenecks or resource inefficiencies introduced by the changes? + * **Security**: Are there any potential security vulnerabilities or insecure coding practices? + * **Edge Cases and Error Handling**: Does the code appropriately handle edge cases and potential errors? + * **Testability**: Is the new or modified code adequately covered by tests (even if preflight checks pass)? Suggest additional test cases that would improve coverage or robustness. + * Based on your analysis, you will determine if the PR is **safe to merge**. + +3. **Reviewing Previous Feedback**: + * **Access and examine the PR's history** to identify any **outstanding requests or unresolved comments from previous reviews**. Incorporate these into your current review and explicitly highlight if they have been adequately addressed in the current state of the PR. + +4. **Decision and Output Generation**: + * **If the PR is deemed safe to merge** (after your comprehensive review and considering previous feedback): + * Draft a **friendly, concise, and professional approval message**. 
+ * **The approval message should:** + * Clearly state that the PR is approved. + * Briefly acknowledge the quality or value of the contribution (e.g., "Great work on X feature!" or "Appreciate the fix for Y issue!"). + * **Do NOT mention the preflight checks or unit testing**, as these are internal processes. + * Be suitable for public display on GitHub. + * **If the PR is NOT safe to merge**: + * Provide a **clear, constructive, and detailed summary of the issues found**. + * Suggest **specific actionable changes** required for the PR to become merge-ready. + * Ensure the feedback is professional and encourages the contributor. + +### Post-PR Action: + +* After providing your review and decision for the current PR, I will wait for you to perform any manual testing you wish to do. Please let me know when you are finished. +* Once you have confirmed that you are done, I will switch to the `main` branch, clean up the local branch, and perform a pull to ensure we are synchronized with the latest upstream changes for the next review. + +""" diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index bc16c551def..89e9a03b7b2 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,7 +1,15 @@ -# By default, require reviews from the release approvers for all files. -* @google-gemini/gemini-cli-askmode-approvers +# By default, require reviews from the maintainers for all files. +* @google-gemini/gemini-cli-maintainers -# The following files don't need reviews from the release approvers. +# Require reviews from the release approvers for critical files. # These patterns override the rule above. 
-**/*.md -/docs/ \ No newline at end of file +/package.json @google-gemini/gemini-cli-askmode-approvers +/package-lock.json @google-gemini/gemini-cli-askmode-approvers +/GEMINI.md @google-gemini/gemini-cli-askmode-approvers +/SECURITY.md @google-gemini/gemini-cli-askmode-approvers +/LICENSE @google-gemini/gemini-cli-askmode-approvers +/.github/workflows/ @google-gemini/gemini-cli-askmode-approvers +/packages/cli/package.json @google-gemini/gemini-cli-askmode-approvers +/packages/cli/package-lock.json @google-gemini/gemini-cli-askmode-approvers +/packages/core/package.json @google-gemini/gemini-cli-askmode-approvers +/packages/core/package-lock.json @google-gemini/gemini-cli-askmode-approvers \ No newline at end of file diff --git a/.github/actions/publish-release/action.yml b/.github/actions/publish-release/action.yml new file mode 100644 index 00000000000..21abfa82c52 --- /dev/null +++ b/.github/actions/publish-release/action.yml @@ -0,0 +1,99 @@ +name: 'Publish Release' +description: 'Builds, prepares, and publishes the gemini-cli packages to npm and creates a GitHub release.' + +inputs: + release-version: + description: 'The version to release (e.g., 0.1.11).' + required: true + npm-tag: + description: 'The npm tag to publish with (e.g., latest, preview, nightly).' + required: true + wombat-token-core: + description: 'The npm token for the @blocksuser/gemini-cli-core package.' + required: true + wombat-token-cli: + description: 'The npm token for the @blocksuser/gemini-cli package.' + required: true + github-token: + description: 'The GitHub token for creating the release.' + required: true + dry-run: + description: 'Whether to run in dry-run mode.' + required: true + release-branch: + description: 'The branch to target for the release.' + required: true + previous-tag: + description: 'The previous tag to use for generating release notes.' + required: true + working-directory: + description: 'The working directory to run the steps in.' 
+ required: false + default: '.' + +runs: + using: 'composite' + steps: + - name: 'Build and Prepare Packages' + working-directory: '${{ inputs.working-directory }}' + run: |- + npm run build:packages + npm run prepare:package + shell: 'bash' + + - name: 'Configure npm for publishing' + uses: 'actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020' + with: + node-version-file: '${{ inputs.working-directory }}/.nvmrc' + registry-url: 'https://wombat-dressing-room.appspot.com' + scope: '@google' + + - name: 'Publish @blocksuser/gemini-cli-core' + working-directory: '${{ inputs.working-directory }}' + env: + NODE_AUTH_TOKEN: '${{ inputs.wombat-token-core }}' + run: |- + npm publish \ + --dry-run="${{ inputs.dry-run }}" \ + --workspace="@blocksuser/gemini-cli-core" \ + --tag="${{ inputs.npm-tag }}" + shell: 'bash' + + - name: 'Install latest core package' + working-directory: '${{ inputs.working-directory }}' + if: "${{ inputs.dry-run == 'false' }}" # expression string literals must be single-quoted + run: |- + npm install "@blocksuser/gemini-cli-core@${{ inputs.release-version }}" \ + --workspace="@blocksuser/gemini-cli" \ + --save-exact + shell: 'bash' + + - name: 'Publish @blocksuser/gemini-cli' + working-directory: '${{ inputs.working-directory }}' + env: + NODE_AUTH_TOKEN: '${{ inputs.wombat-token-cli }}' + run: |- + npm publish \ + --dry-run="${{ inputs.dry-run }}" \ + --workspace="@blocksuser/gemini-cli" \ + --tag="${{ inputs.npm-tag }}" + shell: 'bash' + + - name: 'Bundle' + working-directory: '${{ inputs.working-directory }}' + run: 'npm run bundle' + shell: 'bash' + + - name: 'Create GitHub Release' + working-directory: '${{ inputs.working-directory }}' + if: "${{ inputs.dry-run == 'false' }}" # expression string literals must be single-quoted + env: + GITHUB_TOKEN: '${{ inputs.github-token }}' + run: |- + gh release create "v${{ inputs.release-version }}" \ + bundle/gemini.js \ + --target "${{ inputs.release-branch }}" \ + --title "Release v${{ inputs.release-version }}" \ + --notes-start-tag "${{ inputs.previous-tag }}" \ + --generate-notes + shell: 
'bash' diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8d4ebdb77f6..9d2f21737a9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -375,6 +375,8 @@ jobs: # Check for changes in bundle size. bundle_size: name: 'Check Bundle Size' + if: |- + ${{ github.event_name != 'merge_group' }} runs-on: 'ubuntu-latest' permissions: contents: 'read' # For checkout @@ -392,3 +394,4 @@ jobs: pattern: './bundle/**/*.{js,sb}' minimum-change-threshold: '1000' compression: 'none' + clean-script: 'clean' diff --git a/.github/workflows/create-patch-pr.yml b/.github/workflows/create-patch-pr.yml new file mode 100644 index 00000000000..2ec6aed3eb1 --- /dev/null +++ b/.github/workflows/create-patch-pr.yml @@ -0,0 +1,58 @@ +name: 'Create Patch PR' + +on: + workflow_dispatch: + inputs: + commit: + description: 'The commit SHA to cherry-pick for the patch.' + required: true + type: 'string' + channel: + description: 'The release channel to patch.' + required: true + type: 'choice' + options: + - 'stable' + - 'preview' + dry_run: + description: 'Whether to run in dry-run mode.' 
+ required: false + type: 'boolean' + default: false + +jobs: + create-patch: + runs-on: 'ubuntu-latest' + permissions: + contents: 'write' + pull-requests: 'write' + steps: + - name: 'Checkout' + uses: 'actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8' # ratchet:actions/checkout@v5 + with: + fetch-depth: 0 + + - name: 'Setup Node.js' + uses: 'actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020' # ratchet:actions/setup-node@v4 + with: + node-version-file: '.nvmrc' + cache: 'npm' + + - name: 'Install Dependencies' + run: 'npm ci' + + - name: 'Configure Git User' + run: |- + git config user.name "gemini-cli-robot" + git config user.email "gemini-cli-robot@google.com" + + - name: 'Create Patch for Stable' + if: "github.event.inputs.channel == 'stable'" + env: + GH_TOKEN: '${{ secrets.GITHUB_TOKEN }}' + run: 'node scripts/create-patch-pr.js --commit=${{ github.event.inputs.commit }} --channel=stable --dry-run=${{ github.event.inputs.dry_run }}' + + - name: 'Create Patch for Preview' + env: + GH_TOKEN: '${{ secrets.GITHUB_TOKEN }}' + run: 'node scripts/create-patch-pr.js --commit=${{ github.event.inputs.commit }} --channel=preview --dry-run=${{ github.event.inputs.dry_run }}' # unconditional step: always targets preview so a stable request is not patched twice diff --git a/.github/workflows/nightly-release.yml b/.github/workflows/nightly-release.yml new file mode 100644 index 00000000000..37fb6c3a709 --- /dev/null +++ b/.github/workflows/nightly-release.yml @@ -0,0 +1,53 @@ +name: 'Nightly Release' + +on: + schedule: + - cron: '0 0 * * *' + workflow_dispatch: + inputs: + dry_run: + description: 'Run a dry-run of the release process; no branches, npm packages or GitHub releases will be created.' 
+ required: true + type: 'boolean' + default: true + +jobs: + release: + runs-on: 'ubuntu-latest' + steps: + - name: 'Checkout' + uses: 'actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8' + with: + fetch-depth: 0 + + - name: 'Setup Node.js' + uses: 'actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020' # ratchet:actions/setup-node@v4 + with: + node-version-file: '.nvmrc' + cache: 'npm' + + - name: 'Install Dependencies' + run: 'npm ci' + + - name: 'Get Nightly Version' + id: 'nightly_version' + env: + GH_TOKEN: '${{ secrets.GITHUB_TOKEN }}' + run: | + VERSION_JSON=$(node scripts/get-release-version.js --type=nightly) + echo "RELEASE_TAG=$(echo "${VERSION_JSON}" | jq -r .releaseTag)" >> "${GITHUB_OUTPUT}" + echo "RELEASE_VERSION=$(echo "${VERSION_JSON}" | jq -r .releaseVersion)" >> "${GITHUB_OUTPUT}" + echo "NPM_TAG=$(echo "${VERSION_JSON}" | jq -r .npmTag)" >> "${GITHUB_OUTPUT}" + echo "PREVIOUS_TAG=$(echo "${VERSION_JSON}" | jq -r .previousReleaseTag)" >> "${GITHUB_OUTPUT}" + + - name: 'Publish Release' + uses: './.github/actions/publish-release' + with: + release-version: '${{ steps.nightly_version.outputs.RELEASE_VERSION }}' + npm-tag: '${{ steps.nightly_version.outputs.NPM_TAG }}' + wombat-token-core: '${{ secrets.WOMBAT_TOKEN_CORE }}' + wombat-token-cli: '${{ secrets.WOMBAT_TOKEN_CLI }}' + github-token: '${{ secrets.GITHUB_TOKEN }}' + dry-run: "${{ github.event.inputs.dry_run || 'false' }}" # schedule events carry no inputs; fall back to a real (non-dry) run + release-branch: 'main' + previous-tag: '${{ steps.nightly_version.outputs.PREVIOUS_TAG }}' diff --git a/.github/workflows/patch-from-comment.yml b/.github/workflows/patch-from-comment.yml new file mode 100644 index 00000000000..55065b5b1cd --- /dev/null +++ b/.github/workflows/patch-from-comment.yml @@ -0,0 +1,56 @@ +name: 'Patch from Comment' + +on: + issue_comment: + types: ['created'] + +jobs: + slash-command: + runs-on: 'ubuntu-latest' + steps: + - name: 'Slash Command Dispatch' + id: 'slash_command' + uses: 
'peter-evans/slash-command-dispatch@40877f718dce0101edfc7aea2b3800cc192f9ed5' + with: + token: '${{ secrets.GITHUB_TOKEN }}' + commands: 'patch' + permission: 'write' + issue-type: 'pull-request' + static-args: | + dry_run=false + + - name: 'Get PR Status' + id: 'pr_status' + if: "steps.slash_command.outputs.dispatched == 'true'" + env: + GH_TOKEN: '${{ secrets.GITHUB_TOKEN }}' + run: | + gh pr view "${{ github.event.issue.number }}" --json mergeCommit,state > pr_status.json + echo "MERGE_COMMIT_SHA=$(jq -r .mergeCommit.oid pr_status.json)" >> "$GITHUB_OUTPUT" + echo "STATE=$(jq -r .state pr_status.json)" >> "$GITHUB_OUTPUT" + + - name: 'Dispatch if Merged' + if: "steps.pr_status.outputs.STATE == 'MERGED'" + uses: 'actions/github-script@00f12e3e20659f42342b1c0226afda7f7c042325' + with: + script: | + const args = JSON.parse(${{ toJSON(steps.slash_command.outputs.command-arguments) }}); // toJSON emits an escaped string literal, so quotes in the value cannot break out of the script + await github.rest.actions.createWorkflowDispatch({ + owner: context.repo.owner, + repo: context.repo.repo, + workflow_id: 'create-patch-pr.yml', + ref: 'main', + inputs: { + commit: '${{ steps.pr_status.outputs.MERGE_COMMIT_SHA }}', + channel: args.channel, + dry_run: args.dry_run + } + }) + + - name: 'Comment on Failure' + if: "steps.slash_command.outputs.dispatched == 'true' && steps.pr_status.outputs.STATE != 'MERGED'" # STATE is empty when the status step was skipped; do not reply to unrelated comments + uses: 'peter-evans/create-or-update-comment@67dcc547d311b736a8e6c5c236542148a47adc3d' + with: + issue-number: '${{ github.event.issue.number }}' + body: | + :x: The `/patch` command failed. This pull request must be merged before a patch can be created. diff --git a/.github/workflows/promote-release.yml b/.github/workflows/promote-release.yml new file mode 100644 index 00000000000..0ae7446614f --- /dev/null +++ b/.github/workflows/promote-release.yml @@ -0,0 +1,213 @@ +name: 'Promote Release' + +on: + workflow_dispatch: + inputs: + dry_run: + description: 'Run a dry-run of the release process; no branches, npm packages or GitHub releases will be created.' 
+ required: true + type: 'boolean' + default: true + +jobs: + calculate-versions: + name: 'Calculate Versions and Plan' + runs-on: 'ubuntu-latest' + outputs: + STABLE_VERSION: '${{ steps.versions.outputs.STABLE_VERSION }}' + STABLE_SHA: '${{ steps.versions.outputs.STABLE_SHA }}' + PREVIOUS_STABLE_TAG: '${{ steps.versions.outputs.PREVIOUS_STABLE_TAG }}' + PREVIEW_VERSION: '${{ steps.versions.outputs.PREVIEW_VERSION }}' + PREVIEW_SHA: '${{ steps.versions.outputs.PREVIEW_SHA }}' + PREVIOUS_PREVIEW_TAG: '${{ steps.versions.outputs.PREVIOUS_PREVIEW_TAG }}' + NEXT_NIGHTLY_VERSION: '${{ steps.versions.outputs.NEXT_NIGHTLY_VERSION }}' + + steps: + - name: 'Checkout' + uses: 'actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8' + with: + fetch-depth: 0 + + - name: 'Setup Node.js' + uses: 'actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020' + with: + node-version-file: '.nvmrc' + cache: 'npm' + + - name: 'Install Dependencies' + run: 'npm ci' + + - name: 'Calculate Versions and SHAs' + id: 'versions' + env: + GH_TOKEN: '${{ secrets.GITHUB_TOKEN }}' + run: | + set -e + STABLE_JSON=$(node scripts/get-release-version.js --type=stable) + PREVIEW_JSON=$(node scripts/get-release-version.js --type=preview) + NIGHTLY_JSON=$(node scripts/get-release-version.js --type=nightly) + echo "STABLE_VERSION=$(echo "${STABLE_JSON}" | jq -r .releaseVersion)" >> "${GITHUB_OUTPUT}" + # shellcheck disable=SC1083 + echo "STABLE_SHA=$(git rev-parse "$(echo "${STABLE_JSON}" | jq -r .previousReleaseTag)"^{commit})" >> "${GITHUB_OUTPUT}" + echo "PREVIOUS_STABLE_TAG=$(echo "${STABLE_JSON}" | jq -r .previousReleaseTag)" >> "${GITHUB_OUTPUT}" + echo "PREVIEW_VERSION=$(echo "${PREVIEW_JSON}" | jq -r .releaseVersion)" >> "${GITHUB_OUTPUT}" + # shellcheck disable=SC1083 + echo "PREVIEW_SHA=$(git rev-parse "$(echo "${PREVIEW_JSON}" | jq -r .previousReleaseTag)"^{commit})" >> "${GITHUB_OUTPUT}" + echo "PREVIOUS_PREVIEW_TAG=$(echo "${PREVIEW_JSON}" | jq -r .previousReleaseTag)" >> 
"${GITHUB_OUTPUT}" + echo "NEXT_NIGHTLY_VERSION=$(echo "${NIGHTLY_JSON}" | jq -r .releaseVersion)" >> "${GITHUB_OUTPUT}" + + promote: + name: 'Promote to ${{ matrix.channel }}' + needs: 'calculate-versions' + runs-on: 'ubuntu-latest' + permissions: + contents: 'write' + packages: 'write' + strategy: + matrix: + include: + - channel: 'stable' + version: '${{ needs.calculate-versions.outputs.STABLE_VERSION }}' + sha: '${{ needs.calculate-versions.outputs.STABLE_SHA }}' + npm-tag: 'latest' + previous-tag: '${{ needs.calculate-versions.outputs.PREVIOUS_STABLE_TAG }}' + - channel: 'preview' + version: '${{ needs.calculate-versions.outputs.PREVIEW_VERSION }}' + sha: '${{ needs.calculate-versions.outputs.PREVIEW_SHA }}' + npm-tag: 'preview' + previous-tag: '${{ needs.calculate-versions.outputs.PREVIOUS_PREVIEW_TAG }}' + + steps: + - name: 'Checkout main' + uses: 'actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8' + with: + ref: 'main' + + - name: 'Checkout correct SHA' + uses: 'actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8' + with: + ref: '${{ matrix.sha }}' + path: 'release' + + - name: 'Setup Node.js' + uses: 'actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020' + with: + node-version-file: '.nvmrc' + cache: 'npm' + + - name: 'Install Dependencies' + working-directory: './release' + run: 'npm ci' + + - name: 'Configure Git User' + working-directory: './release' + run: |- + git config user.name "gemini-cli-robot" + git config user.email "gemini-cli-robot@google.com" + + - name: 'Create and switch to a release branch' + working-directory: './release' + id: 'release_branch' + run: | + BRANCH_NAME="release/v${{ matrix.version }}" + git switch -c "${BRANCH_NAME}" + echo "BRANCH_NAME=${BRANCH_NAME}" >> "${GITHUB_OUTPUT}" + + - name: 'Update package versions' + working-directory: './release' + run: 'npm run release:version "${{ matrix.version }}"' + + - name: 'Commit and Conditionally Push package versions' + working-directory: './release' + env: + 
BRANCH_NAME: '${{ steps.release_branch.outputs.BRANCH_NAME }}' + DRY_RUN: '${{ github.event.inputs.dry_run }}' + RELEASE_TAG: 'v${{ matrix.version }}' + run: |- + git add package.json package-lock.json packages/*/package.json + git commit -m "chore(release): ${RELEASE_TAG}" + if [[ "${DRY_RUN}" == "false" ]]; then + echo "Pushing release branch to remote..." + git push --set-upstream origin "${BRANCH_NAME}" --follow-tags + else + echo "Dry run enabled. Skipping push." + fi + + - name: 'Publish Release' + uses: './.github/actions/publish-release' + with: + release-version: '${{ matrix.version }}' + npm-tag: '${{ matrix.npm-tag }}' + wombat-token-core: '${{ secrets.WOMBAT_TOKEN_CORE }}' + wombat-token-cli: '${{ secrets.WOMBAT_TOKEN_CLI }}' + github-token: '${{ secrets.GITHUB_TOKEN }}' + dry-run: '${{ github.event.inputs.dry_run }}' + release-branch: '${{ steps.release_branch.outputs.BRANCH_NAME }}' + previous-tag: '${{ matrix.previous-tag }}' + working-directory: './release' + + nightly-pr: + name: 'Create Nightly PR' + needs: 'calculate-versions' + runs-on: 'ubuntu-latest' + permissions: + contents: 'write' + pull-requests: 'write' + steps: + - name: 'Checkout main' + uses: 'actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8' + with: + ref: 'main' + + - name: 'Setup Node.js' + uses: 'actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020' + with: + node-version-file: '.nvmrc' + cache: 'npm' + + - name: 'Install Dependencies' + run: 'npm ci' + + - name: 'Configure Git User' + run: |- + git config user.name "gemini-cli-robot" + git config user.email "gemini-cli-robot@google.com" + + - name: 'Create and switch to a new branch' + id: 'release_branch' + run: | + BRANCH_NAME="chore/nightly-version-bump-${{ needs.calculate-versions.outputs.NEXT_NIGHTLY_VERSION }}" + git switch -c "${BRANCH_NAME}" + echo "BRANCH_NAME=${BRANCH_NAME}" >> "${GITHUB_OUTPUT}" + + - name: 'Update package versions' + run: 'npm run release:version "${{ 
needs.calculate-versions.outputs.NEXT_NIGHTLY_VERSION }}"' + + - name: 'Commit and Push package versions' + env: + BRANCH_NAME: '${{ steps.release_branch.outputs.BRANCH_NAME }}' + DRY_RUN: '${{ github.event.inputs.dry_run }}' + run: |- + git add package.json package-lock.json packages/*/package.json + git commit -m "chore(release): bump version to ${{ needs.calculate-versions.outputs.NEXT_NIGHTLY_VERSION }}" + if [[ "${DRY_RUN}" == "false" ]]; then + echo "Pushing release branch to remote..." + git push --set-upstream origin "${BRANCH_NAME}" + else + echo "Dry run enabled. Skipping push." + fi + + - name: 'Create and Approve Pull Request' + if: |- + ${{ github.event.inputs.dry_run == 'false' }} + env: + GH_TOKEN: '${{ secrets.GITHUB_TOKEN }}' + BRANCH_NAME: '${{ steps.release_branch.outputs.BRANCH_NAME }}' + run: | + gh pr create \ + --title "chore(release): bump version to ${{ needs.calculate-versions.outputs.NEXT_NIGHTLY_VERSION }}" \ + --body "Automated version bump to prepare for the next nightly release." \ + --base "main" \ + --head "${BRANCH_NAME}" \ + --fill + gh pr merge --auto --squash diff --git a/.github/workflows/trigger-patch-release.yml b/.github/workflows/trigger-patch-release.yml new file mode 100644 index 00000000000..4270111bbd9 --- /dev/null +++ b/.github/workflows/trigger-patch-release.yml @@ -0,0 +1,30 @@ +name: 'Trigger Patch Release' + +on: + pull_request: + types: + - 'closed' + +jobs: + trigger-patch-release: + if: "github.event.pull_request.merged == true && startsWith(github.head_ref, 'hotfix/')" + runs-on: 'ubuntu-latest' + steps: + - name: 'Trigger Patch Release' + uses: 'actions/github-script@00f12e3e20659f42342b1c0226afda7f7c042325' + with: + script: | + const body = context.payload.pull_request.body || ''; // body is null when the PR has no description + const isDryRun = body.includes('[DRY RUN]'); + const ref = context.payload.pull_request.base.ref; + const channel = ref.includes('preview') ? 
'preview' : 'stable'; + await github.rest.actions.createWorkflowDispatch({ + owner: context.repo.owner, + repo: context.repo.repo, + workflow_id: 'patch-release.yml', + ref: ref, + inputs: { + type: channel, + dry_run: isDryRun.toString() + } + }) diff --git a/.gitignore b/.gitignore index bcead317941..47d9f9bd33c 100644 --- a/.gitignore +++ b/.gitignore @@ -3,8 +3,14 @@ .env~ # gemini-cli settings -.gemini/ -!gemini/config.yaml +# We want to keep the .gemini in the root of the repo and ignore any .gemini +# in subdirectories. In our root .gemini we want to allow for version control +# for subcommands. +**/.gemini/ +!/.gemini/ +.gemini/* +!.gemini/config.yaml +!.gemini/commands/ # Note: .gemini-clipboard/ is NOT in gitignore so Gemini can access pasted images diff --git a/.yamllint.yml b/.yamllint.yml index b4612e07dbe..b1f5eb7d003 100644 --- a/.yamllint.yml +++ b/.yamllint.yml @@ -86,3 +86,4 @@ ignore: - 'thirdparty/' - 'third_party/' - 'vendor/' + - 'node_modules/' diff --git a/LICENSE b/LICENSE index 346b3f959b5..7a4a3ea2424 100644 --- a/LICENSE +++ b/LICENSE @@ -187,7 +187,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright 2025 Google LLC + Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -199,4 +199,4 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and - limitations under the License. + limitations under the License. 
\ No newline at end of file diff --git a/README.md b/README.md index 0015453f5f4..204c032aed8 100644 --- a/README.md +++ b/README.md @@ -106,7 +106,7 @@ Integrate Gemini CLI directly into your GitHub workflows with [**Gemini CLI GitH Choose the authentication method that best fits your needs: -### Option 1: OAuth login (Using your Google Account) +### Option 1: Login with Google (OAuth login using your Google Account) **✨ Best for:** Individual developers as well as anyone who has a Gemini Code Assist License. (see [quota limits and terms of service](https://cloud.google.com/gemini/docs/quotas) for details) @@ -117,7 +117,7 @@ Choose the authentication method that best fits your needs: - **No API key management** - just sign in with your Google account - **Automatic updates** to latest models -#### Start Gemini CLI, then choose OAuth and follow the browser authentication flow when prompted +#### Start Gemini CLI, then choose _Login with Google_ and follow the browser authentication flow when prompted ```bash gemini @@ -190,10 +190,19 @@ gemini -m gemini-2.5-flash #### Non-interactive mode for scripts +Get a simple text response: + ```bash gemini -p "Explain the architecture of this codebase" ``` +For more advanced scripting, including how to parse JSON and handle errors, use +the `--output-format json` flag to get structured output: + +```bash +gemini -p "Explain the architecture of this codebase" --output-format json +``` + ### Quick Examples #### Start a new project diff --git a/docs/cli/configuration.md b/docs/cli/configuration.md index f512ec7cc0d..039dbdab7fd 100644 --- a/docs/cli/configuration.md +++ b/docs/cli/configuration.md @@ -76,6 +76,13 @@ Settings are organized into categories. All settings should be placed within the - **Description:** Enable session checkpointing for recovery. - **Default:** `false` +#### `output` + +- **`output.format`** (string): + - **Description:** The format of the CLI output. 
+ - **Default:** `"text"` + - **Values:** `"text"`, `"json"` + #### `ui` - **`ui.theme`** (string): @@ -442,11 +449,18 @@ Arguments passed directly when running the CLI can override other configurations - Example: `npm start -- --model gemini-1.5-pro-latest` - **`--prompt `** (**`-p `**): - Used to pass a prompt directly to the command. This invokes Gemini CLI in a non-interactive mode. + - For scripting examples, use the `--output-format json` flag to get structured output. - **`--prompt-interactive `** (**`-i `**): - Starts an interactive session with the provided prompt as the initial input. - The prompt is processed within the interactive session, not before it. - Cannot be used when piping input from stdin. - Example: `gemini -i "explain this code"` +- **`--output-format `**: + - **Description:** Specifies the format of the CLI output for non-interactive mode. + - **Values:** + - `text`: (Default) The standard human-readable output. + - `json`: A machine-readable JSON output. + - **Note:** For structured output and scripting, use the `--output-format json` flag. - **`--sandbox`** (**`-s`**): - Enables sandbox mode for this session. - **`--sandbox-image`**: diff --git a/docs/cli/index.md b/docs/cli/index.md index 1b5e1796d64..d9fcd0a6e5b 100644 --- a/docs/cli/index.md +++ b/docs/cli/index.md @@ -27,3 +27,19 @@ Gemini CLI executes the command and prints the output to your terminal. Note tha ```bash gemini -p "What is fine tuning?" ``` + +For non-interactive usage with structured output, use the `--output-format json` flag for scripting and automation. + +Get structured JSON output for scripting: + +```bash +gemini -p "What is fine tuning?" 
--output-format json +# Output: +# { +# "response": "Fine tuning is...", +# "stats": { +# "models": { "gemini-2.5-flash": { "tokens": {"total": 45} } } +# }, +# "error": null +# } +``` diff --git a/docs/extension.md b/docs/extension.md index b807524651b..8816133c293 100644 --- a/docs/extension.md +++ b/docs/extension.md @@ -1,19 +1,88 @@ # Gemini CLI Extensions -Gemini CLI supports extensions that can be used to configure and extend its functionality. +_This documentation is up-to-date with the v0.4.0 release._ -## How it works +Gemini CLI extensions package prompts, MCP servers, and custom commands into a familiar and user-friendly format. With extensions, you can expand the capabilities of Gemini CLI and share those capabilities with others. They are designed to be easily installable and shareable. + +## Extension management + +We offer a suite of extension management tools using `gemini extensions` commands. + +Note that these commands are not supported from within the CLI, although you can list installed extensions using the `/extensions list` subcommand. + +Note that all of these commands will only be reflected in active CLI sessions on restart. + +### Installing an extension + +You can install an extension using `gemini extensions install` with either a GitHub URL source or `--path=some/local/path`. + +Note that we create a copy of the installed extension, so you will need to run `gemini extensions update` to pull in changes from both locally-defined extensions and those on GitHub. + +``` +gemini extensions install https://github.com/google-gemini/gemini-cli-security +``` + +This will install the Gemini CLI Security extension, which offers support for a `/security:analyze` command. 
+ +### Uninstalling an extension + +To uninstall, run `gemini extensions uninstall extension-name`, so, in the case of the install example: + +``` +gemini extensions uninstall gemini-cli-security +``` + +### Disabling an extension + +Extensions are, by default, enabled across all workspaces. You can disable an extension entirely or for a specific workspace. + +For example, `gemini extensions disable extension-name` will disable the extension at the user level, so it will be disabled everywhere. `gemini extensions disable extension-name --scope=Workspace` will only disable the extension in the current workspace. + +### Enabling an extension -On startup, Gemini CLI looks for extensions in two locations: +You can re-enable extensions using `gemini extensions enable extension-name`. Note that if an extension is disabled at the user level, enabling it at the workspace level will not do anything. -1. `/.gemini/extensions` -2. `/.gemini/extensions` +### Updating an extension + +For extensions installed from a local path or a git repository, you can explicitly update to the latest version (as reflected in the `gemini-extension.json` `version` field) with `gemini extensions update extension-name`. + +You can update all extensions with: + +``` +gemini extensions update --all +``` + +## Extension creation + +We offer commands to make extension development easier. + +### Create a boilerplate extension + +We offer several example extensions: `context`, `custom-commands`, `exclude-tools` and `mcp-server`. You can view these examples [here](https://github.com/google-gemini/gemini-cli/tree/main/packages/cli/src/commands/extensions/examples). + +To copy one of these examples into a development directory using the type of your choosing, run: + +``` +gemini extensions new --path=path/to/directory --type=custom-commands +``` + +### Link a local extension + +The `gemini extensions link` command will create a symbolic link from the extension installation directory to the development path. 
+ +This is useful so you don't have to run `gemini extensions update` every time you make changes you'd like to test. + +``` +gemini extensions link path/to/directory +``` + +## How it works -Gemini CLI loads all extensions from both locations. If an extension with the same name exists in both locations, the extension in the workspace directory takes precedence. +On startup, Gemini CLI looks for extensions in `/.gemini/extensions` -Within each location, individual extensions exist as a directory that contains a `gemini-extension.json` file. For example: +Extensions exist as a directory that contains a `gemini-extension.json` file. For example: -`/.gemini/extensions/my-extension/gemini-extension.json` +`/.gemini/extensions/my-extension/gemini-extension.json` ### `gemini-extension.json` @@ -33,19 +102,19 @@ The `gemini-extension.json` file contains the configuration for the extension. T } ``` -- `name`: The name of the extension. This is used to uniquely identify the extension and for conflict resolution when extension commands have the same name as user or project commands. +- `name`: The name of the extension. This is used to uniquely identify the extension and for conflict resolution when extension commands have the same name as user or project commands. The name should be lowercase and use dashes instead of underscores or spaces. This is how users will refer to your extension in the CLI. Note that we expect this name to match the extension directory name. - `version`: The version of the extension. - `mcpServers`: A map of MCP servers to configure. The key is the name of the server, and the value is the server configuration. These servers will be loaded on startup just like MCP servers configured in a [`settings.json` file](./cli/configuration.md). If both an extension and a `settings.json` file configure an MCP server with the same name, the server defined in the `settings.json` file takes precedence. 
-- `contextFileName`: The name of the file that contains the context for the extension. This will be used to load the context from the workspace. If this property is not used but a `GEMINI.md` file is present in your extension directory, then that file will be loaded. +- `contextFileName`: The name of the file that contains the context for the extension. This will be used to load the context from the extension directory. If this property is not used but a `GEMINI.md` file is present in your extension directory, then that file will be loaded. - `excludeTools`: An array of tool names to exclude from the model. You can also specify command-specific restrictions for tools that support it, like the `run_shell_command` tool. For example, `"excludeTools": ["run_shell_command(rm -rf)"]` will block the `rm -rf` command. When Gemini CLI starts, it loads all the extensions and merges their configurations. If there are any conflicts, the workspace configuration takes precedence. -## Extension Commands +### Custom commands Extensions can provide [custom commands](./cli/commands.md#custom-commands) by placing TOML files in a `commands/` subdirectory within the extension directory. These commands follow the same format as user and project custom commands and use standard naming conventions. -### Example +**Example** An extension named `gcp` with the following structure: @@ -63,7 +132,7 @@ Would provide these commands: - `/deploy` - Shows as `[gcp] Custom command from deploy.toml` in help - `/gcs:sync` - Shows as `[gcp] Custom command from sync.toml` in help -### Conflict Resolution +### Conflict resolution Extension commands have the lowest precedence. 
When a conflict occurs with user or project commands: @@ -75,7 +144,7 @@ For example, if both a user and the `gcp` extension define a `deploy` command: - `/deploy` - Executes the user's deploy command - `/gcp.deploy` - Executes the extension's deploy command (marked with `[gcp]` tag) -# Variables +## Variables Gemini CLI extensions allow variable substitution in `gemini-extension.json`. This can be useful if e.g., you need the current directory to run an MCP server using `"cwd": "${extensionPath}${/}run.ts"`. diff --git a/docs/releases.md b/docs/releases.md index 9e93680de08..65fce46be62 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -58,50 +58,58 @@ After one week (On the following Tuesday) with all signals a go, we will manuall ## Patching Releases -If a critical bug needs to be fixed before the next scheduled release, follow this process to create a patch. +If a critical bug that is already fixed on `main` needs to be patched on a `stable` or `preview` release, the process is now highly automated. -### 1. Create a Hotfix Branch +### 1. Create the Patch Pull Request -First, create a new branch for your fix. The source for this branch depends on whether you are patching a stable or a preview release. +There are two ways to create a patch pull request: -- **For a stable release patch:** - Create a branch from the Git tag of the version you need to patch. Tag names are formatted as `vx.y.z`. +**Option A: From a GitHub Comment (Recommended)** - ```bash - # Example: Create a hotfix branch for v0.2.0 - git checkout v0.2.0 -b hotfix/issue-123-fix-for-v0.2.0 - ``` +After a pull request has been merged, a maintainer can add a comment on that same PR with the following format: -- **For a preview release patch:** - Create a branch from the existing preview release branch, which is formatted as `release/vx.y.z-preview.n`. 
+`/patch <channel> [--dry-run]` - ```bash - # Example: Create a hotfix branch for a preview release - git checkout release/v0.2.0-preview.0 && git checkout -b hotfix/issue-456-fix-for-preview - ``` +- **channel**: `stable` or `preview` +- **--dry-run** (optional): If included, the workflow will run in dry-run mode. This will create the PR with "[DRY RUN]" in the title, and merging it will trigger a dry run of the final release, so nothing is actually published. -### 2. Implement the Fix +Example: `/patch stable --dry-run` -In your new hotfix branch, either create a new commit with the fix or cherry-pick an existing commit from the `main` branch. Merge your changes into the source of the hotfix branch (ex. https://github.com/google-gemini/gemini-cli/pull/6850). +The workflow will automatically find the merge commit SHA and begin the patch process. If the PR is not yet merged, it will post a comment indicating the failure. -### 3. Perform the Release +**Option B: Manually Triggering the Workflow** -Follow the manual release process using the "Release" GitHub Actions workflow. +Navigate to the **Actions** tab and run the **Create Patch PR** workflow. -- **Version**: For stable patches, increment the patch version (e.g., `v0.2.0` -> `v0.2.1`). For preview patches, increment the preview number (e.g., `v0.2.0-preview.0` -> `v0.2.0-preview.1`). -- **Ref**: Use your source branch as the reference (ex. `release/v0.2.0-preview.0`) +- **Commit**: The full SHA of the commit on `main` that you want to cherry-pick. +- **Channel**: The channel you want to patch (`stable` or `preview`). -![How to run a release](assets/release_patch.png) +This workflow will automatically: -### 4. Update Versions +1. Find the latest release tag for the channel. +2. Create a release branch from that tag if one doesn't exist (e.g., `release/v0.5.1`). +3. Create a new hotfix branch from the release branch. +4. Cherry-pick your specified commit into the hotfix branch. +5. 
Create a pull request from the hotfix branch back to the release branch. -After the hotfix is released, merge the changes back to the appropriate branch. +**Important:** If you select `stable`, the workflow will run twice, creating one PR for the `stable` channel and a second PR for the `preview` channel. -- **For a stable release hotfix:** - Open a pull request to merge the release branch (e.g., `release/0.2.1`) back into `main`. This keeps the version number in `main` up to date. +### 2. Review and Merge -- **For a preview release hotfix:** - Open a pull request to merge the new preview release branch (e.g., `release/v0.2.0-preview.1`) back into the existing preview release branch (`release/v0.2.0-preview.0`) (ex. https://github.com/google-gemini/gemini-cli/pull/6868) +Review the automatically created pull request(s) to ensure the cherry-pick was successful and the changes are correct. Once approved, merge the pull request. + +**Security Note:** The `release/*` branches are protected by branch protection rules. A pull request to one of these branches requires at least one review from a code owner before it can be merged. This ensures that no unauthorized code is released. + +### 3. Automatic Release + +Upon merging the pull request, a final workflow is automatically triggered. It will: + +1. Run the `patch-release` workflow. +2. Build and test the patched code. +3. Publish the new patch version to npm. +4. Create a new GitHub release with the patch notes. + +This fully automated process ensures that patches are created and released consistently and reliably. ## Release Schedule diff --git a/docs/telemetry.md b/docs/telemetry.md index 71038f5e57a..562e0657697 100644 --- a/docs/telemetry.md +++ b/docs/telemetry.md @@ -176,6 +176,7 @@ Logs are timestamped records of specific events. 
The following events are logged - `file_filtering_respect_git_ignore` (boolean) - `debug_mode` (boolean) - `mcp_servers` (string) + - `output_format` (string: "text" or "json") - `gemini_cli.user_prompt`: This event occurs when a user submits a prompt. - **Attributes**: @@ -193,6 +194,7 @@ Logs are timestamped records of specific events. The following events are logged - `decision` (string: "accept", "reject", "auto_accept", or "modify", if applicable) - `error` (if applicable) - `error_type` (if applicable) + - `content_length` (int, if applicable) - `metadata` (if applicable, dictionary of string -> any) - `gemini_cli.file_operation`: This event occurs for each file operation. @@ -237,6 +239,15 @@ Logs are timestamped records of specific events. The following events are logged - `response_text` (if applicable) - `auth_type` +- `gemini_cli.tool_output_truncated`: This event occurs when the output of a tool call is too large and gets truncated. + - **Attributes**: + - `tool_name` (string) + - `original_content_length` (int) + - `truncated_content_length` (int) + - `threshold` (int) + - `lines` (int) + - `prompt_id` (string) + - `gemini_cli.malformed_json_response`: This event occurs when a `generateJson` response from Gemini API cannot be parsed as a json. 
- **Attributes**: - `model` diff --git a/eslint.config.js b/eslint.config.js index 50efb79cded..4012f3d1e9a 100644 --- a/eslint.config.js +++ b/eslint.config.js @@ -35,6 +35,7 @@ export default tseslint.config( 'bundle/**', 'package/bundle/**', '.integration-tests/**', + 'dist/**', ], }, eslint.configs.recommended, diff --git a/integration-tests/json-output.test.ts b/integration-tests/json-output.test.ts new file mode 100644 index 00000000000..27caee40038 --- /dev/null +++ b/integration-tests/json-output.test.ts @@ -0,0 +1,37 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { expect, describe, it, beforeEach, afterEach } from 'vitest'; +import { TestRig } from './test-helper.js'; + +describe('JSON output', () => { + let rig: TestRig; + + beforeEach(async () => { + rig = new TestRig(); + await rig.setup('json-output-test'); + }); + + afterEach(async () => { + await rig.cleanup(); + }); + + it('should return a valid JSON with response and stats', async () => { + const result = await rig.run( + 'What is the capital of France?', + '--output-format', + 'json', + ); + const parsed = JSON.parse(result); + + expect(parsed).toHaveProperty('response'); + expect(typeof parsed.response).toBe('string'); + expect(parsed.response.toLowerCase()).toContain('paris'); + + expect(parsed).toHaveProperty('stats'); + expect(typeof parsed.stats).toBe('object'); + }); +}); diff --git a/integration-tests/test-helper.ts b/integration-tests/test-helper.ts index a02b7a28c31..f86b72d7872 100644 --- a/integration-tests/test-helper.ts +++ b/integration-tests/test-helper.ts @@ -284,8 +284,15 @@ export class TestRig { result = filteredLines.join('\n'); } - // If we have stderr output, include that also - if (stderr) { + + // Check if this is a JSON output test - if so, don't include stderr + // as it would corrupt the JSON + const isJsonOutput = + commandArgs.includes('--output-format') && + commandArgs.includes('json'); + + // If we have 
stderr output and it's not a JSON test, include that also + if (stderr && !isJsonOutput) { result += `\n\nStdErr:\n${stderr}`; } diff --git a/package-lock.json b/package-lock.json index 541297e2fa6..b93efbd5a87 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,17 +1,16 @@ { - "name": "@google/gemini-cli", - "version": "0.3.3", + "name": "@blocksuser/gemini-cli", + "version": "0.7.0-nightly.20250912.68035591", "lockfileVersion": 3, "requires": true, "packages": { "": { - "name": "@google/gemini-cli", - "version": "0.3.3", + "name": "@blocksuser/gemini-cli", + "version": "0.7.0-nightly.20250912.68035591", "workspaces": [ "packages/*" ], "dependencies": { - "@lvce-editor/ripgrep": "^1.6.0", "simple-git": "^3.28.0" }, "bin": { @@ -445,6 +444,22 @@ "node": ">=18" } }, + "node_modules/@blocksuser/gemini-cli": { + "resolved": "packages/cli", + "link": true + }, + "node_modules/@blocksuser/gemini-cli-a2a-server": { + "resolved": "packages/a2a-server", + "link": true + }, + "node_modules/@blocksuser/gemini-cli-core": { + "resolved": "packages/core", + "link": true + }, + "node_modules/@blocksuser/gemini-cli-test-utils": { + "resolved": "packages/test-utils", + "link": true + }, "node_modules/@bundled-es-modules/cookie": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/@bundled-es-modules/cookie/-/cookie-2.0.1.tgz", @@ -1302,22 +1317,6 @@ "uuid": "dist/bin/uuid" } }, - "node_modules/@google/gemini-cli": { - "resolved": "packages/cli", - "link": true - }, - "node_modules/@google/gemini-cli-a2a-server": { - "resolved": "packages/a2a-server", - "link": true - }, - "node_modules/@google/gemini-cli-core": { - "resolved": "packages/core", - "link": true - }, - "node_modules/@google/gemini-cli-test-utils": { - "resolved": "packages/test-utils", - "link": true - }, "node_modules/@google/genai": { "version": "1.16.0", "resolved": "https://registry.npmjs.org/@google/genai/-/genai-1.16.0.tgz", @@ -1684,6 +1683,28 @@ "node": "^18.14.0 || ^20.0.0 || ^22.0.0 || 
>=24.0.0" } }, + "node_modules/@joshua.litt/get-ripgrep": { + "version": "0.0.2", + "resolved": "https://registry.npmjs.org/@joshua.litt/get-ripgrep/-/get-ripgrep-0.0.2.tgz", + "integrity": "sha512-cSHA+H+HEkOXeiCxrNvGj/pgv2Y0bfp4GbH3R87zr7Vob2pDUZV3BkUL9ucHMoDFID4GteSy5z5niN/lF9QeuQ==", + "dependencies": { + "@lvce-editor/verror": "^1.6.0", + "execa": "^9.5.2", + "extract-zip": "^2.0.1", + "fs-extra": "^11.3.0", + "got": "^14.4.5", + "path-exists": "^5.0.0", + "xdg-basedir": "^5.1.0" + } + }, + "node_modules/@joshua.litt/get-ripgrep/node_modules/path-exists": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/path-exists/-/path-exists-5.0.0.tgz", + "integrity": "sha512-RjhtfwJOxzcFmNOi6ltcbcu4Iu+FL3zEj83dk4kAS+fVpTxXLO1b38RvJgT/0QwvV/L3aY9TAnyv0EOqW4GoMQ==", + "engines": { + "node": "^12.20.0 || ^14.13.1 || >=16.0.0" + } + }, "node_modules/@jridgewell/gen-mapping": { "version": "0.3.8", "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.8.tgz", @@ -1819,32 +1840,6 @@ "integrity": "sha512-GaHYm+c0O9MjZRu0ongGBRbinu8gVAMd2UZjji6jVmqKtZluZnptXGWhz1E8j8D2HJ3f/yMxKAUC0b+57wncIw==", "license": "MIT" }, - "node_modules/@lvce-editor/ripgrep": { - "version": "1.6.0", - "resolved": "https://registry.npmjs.org/@lvce-editor/ripgrep/-/ripgrep-1.6.0.tgz", - "integrity": "sha512-880taWBVULNXmcPHXdxnFUI0FvLErBOjY9OigMXEsLZ2Q1rjcm6LixOkaccKWC8qFMpzm/ldkO7WOMK+ZRfk5Q==", - "hasInstallScript": true, - "license": "MIT", - "dependencies": { - "@lvce-editor/verror": "^1.6.0", - "execa": "^9.5.2", - "extract-zip": "^2.0.1", - "fs-extra": "^11.3.0", - "got": "^14.4.5", - "path-exists": "^5.0.0", - "tempy": "^3.1.0", - "xdg-basedir": "^5.1.0" - } - }, - "node_modules/@lvce-editor/ripgrep/node_modules/path-exists": { - "version": "5.0.0", - "resolved": "https://registry.npmjs.org/path-exists/-/path-exists-5.0.0.tgz", - "integrity": "sha512-RjhtfwJOxzcFmNOi6ltcbcu4Iu+FL3zEj83dk4kAS+fVpTxXLO1b38RvJgT/0QwvV/L3aY9TAnyv0EOqW4GoMQ==", - "license": 
"MIT", - "engines": { - "node": "^12.20.0 || ^14.13.1 || >=16.0.0" - } - }, "node_modules/@lvce-editor/verror": { "version": "1.7.0", "resolved": "https://registry.npmjs.org/@lvce-editor/verror/-/verror-1.7.0.tgz", @@ -4360,19 +4355,6 @@ "url": "https://github.com/sponsors/isaacs" } }, - "node_modules/@typescript-eslint/typescript-estree/node_modules/semver": { - "version": "7.7.2", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.2.tgz", - "integrity": "sha512-RF0Fw+rO5AMf9MAyaRXI4AV0Ulj5lMHqVxxdSgiVbixSCXoEmmX/jk0CuJw4+3SqroYO9VoUh+HcuJivvtJemA==", - "dev": true, - "license": "ISC", - "bin": { - "semver": "bin/semver.js" - }, - "engines": { - "node": ">=10" - } - }, "node_modules/@typescript-eslint/utils": { "version": "8.35.0", "resolved": "https://registry.npmjs.org/@typescript-eslint/utils/-/utils-8.35.0.tgz", @@ -4916,19 +4898,6 @@ "node": "20 || >=22" } }, - "node_modules/@vscode/vsce/node_modules/semver": { - "version": "7.7.2", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.2.tgz", - "integrity": "sha512-RF0Fw+rO5AMf9MAyaRXI4AV0Ulj5lMHqVxxdSgiVbixSCXoEmmX/jk0CuJw4+3SqroYO9VoUh+HcuJivvtJemA==", - "dev": true, - "license": "ISC", - "bin": { - "semver": "bin/semver.js" - }, - "engines": { - "node": ">=10" - } - }, "node_modules/@vscode/vsce/node_modules/yallist": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz", @@ -6475,33 +6444,6 @@ "node": ">= 8" } }, - "node_modules/crypto-random-string": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/crypto-random-string/-/crypto-random-string-4.0.0.tgz", - "integrity": "sha512-x8dy3RnvYdlUcPOjkEHqozhiwzKNSq7GcPuXFbnyMOCHxX8V3OgIg/pYuabl2sbUPfIJaeAQB7PMOK8DFIdoRA==", - "license": "MIT", - "dependencies": { - "type-fest": "^1.0.1" - }, - "engines": { - "node": ">=12" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/crypto-random-string/node_modules/type-fest": { - "version": "1.4.0", - 
"resolved": "https://registry.npmjs.org/type-fest/-/type-fest-1.4.0.tgz", - "integrity": "sha512-yGSza74xk0UG8k+pLh5oeoYirvIiWo5t0/o3zHHAO2tRDiZcxWP7fywNlXhqb6/r6sWvwi+RsyQMWhVLe4BVuA==", - "license": "(MIT OR CC0-1.0)", - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/css-select": { "version": "5.2.2", "resolved": "https://registry.npmjs.org/css-select/-/css-select-5.2.2.tgz", @@ -7564,6 +7506,16 @@ "ms": "^2.1.1" } }, + "node_modules/eslint-plugin-import/node_modules/semver": { + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz", + "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==", + "dev": true, + "license": "ISC", + "bin": { + "semver": "bin/semver.js" + } + }, "node_modules/eslint-plugin-license-header": { "version": "0.8.0", "resolved": "https://registry.npmjs.org/eslint-plugin-license-header/-/eslint-plugin-license-header-0.8.0.tgz", @@ -7638,6 +7590,16 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/eslint-plugin-react/node_modules/semver": { + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz", + "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==", + "dev": true, + "license": "ISC", + "bin": { + "semver": "bin/semver.js" + } + }, "node_modules/eslint-scope": { "version": "8.4.0", "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-8.4.0.tgz", @@ -8837,9 +8799,9 @@ } }, "node_modules/got": { - "version": "14.4.7", - "resolved": "https://registry.npmjs.org/got/-/got-14.4.7.tgz", - "integrity": "sha512-DI8zV1231tqiGzOiOzQWDhsBmncFW7oQDH6Zgy6pDPrqJuVZMtoSgPLLsBZQj8Jg4JFfwoOsDA8NGtLQLnIx2g==", + "version": "14.4.8", + "resolved": "https://registry.npmjs.org/got/-/got-14.4.8.tgz", + "integrity": 
"sha512-vxwU4HuR0BIl+zcT1LYrgBjM+IJjNElOjCzs0aPgHorQyr/V6H6Y73Sn3r3FOlUffvWD+Q5jtRuGWaXkU8Jbhg==", "license": "MIT", "dependencies": { "@sindresorhus/is": "^7.0.1", @@ -10447,19 +10409,6 @@ "safe-buffer": "^5.0.1" } }, - "node_modules/jsonwebtoken/node_modules/semver": { - "version": "7.7.2", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.2.tgz", - "integrity": "sha512-RF0Fw+rO5AMf9MAyaRXI4AV0Ulj5lMHqVxxdSgiVbixSCXoEmmX/jk0CuJw4+3SqroYO9VoUh+HcuJivvtJemA==", - "dev": true, - "license": "ISC", - "bin": { - "semver": "bin/semver.js" - }, - "engines": { - "node": ">=10" - } - }, "node_modules/jsx-ast-utils": { "version": "3.3.5", "resolved": "https://registry.npmjs.org/jsx-ast-utils/-/jsx-ast-utils-3.3.5.tgz", @@ -10817,19 +10766,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/make-dir/node_modules/semver": { - "version": "7.7.2", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.2.tgz", - "integrity": "sha512-RF0Fw+rO5AMf9MAyaRXI4AV0Ulj5lMHqVxxdSgiVbixSCXoEmmX/jk0CuJw4+3SqroYO9VoUh+HcuJivvtJemA==", - "dev": true, - "license": "ISC", - "bin": { - "semver": "bin/semver.js" - }, - "engines": { - "node": ">=10" - } - }, "node_modules/markdown-it": { "version": "14.1.0", "resolved": "https://registry.npmjs.org/markdown-it/-/markdown-it-14.1.0.tgz", @@ -11249,20 +11185,6 @@ "node": ">=10" } }, - "node_modules/node-abi/node_modules/semver": { - "version": "7.7.2", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.2.tgz", - "integrity": "sha512-RF0Fw+rO5AMf9MAyaRXI4AV0Ulj5lMHqVxxdSgiVbixSCXoEmmX/jk0CuJw4+3SqroYO9VoUh+HcuJivvtJemA==", - "dev": true, - "license": "ISC", - "optional": true, - "bin": { - "semver": "bin/semver.js" - }, - "engines": { - "node": ">=10" - } - }, "node_modules/node-addon-api": { "version": "4.3.0", "resolved": "https://registry.npmjs.org/node-addon-api/-/node-addon-api-4.3.0.tgz", @@ -11352,18 +11274,6 @@ "node": "^16.14.0 || >=18.0.0" } }, - 
"node_modules/normalize-package-data/node_modules/semver": { - "version": "7.7.2", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.2.tgz", - "integrity": "sha512-RF0Fw+rO5AMf9MAyaRXI4AV0Ulj5lMHqVxxdSgiVbixSCXoEmmX/jk0CuJw4+3SqroYO9VoUh+HcuJivvtJemA==", - "license": "ISC", - "bin": { - "semver": "bin/semver.js" - }, - "engines": { - "node": ">=10" - } - }, "node_modules/normalize-url": { "version": "8.0.2", "resolved": "https://registry.npmjs.org/normalize-url/-/normalize-url-8.0.2.tgz", @@ -11848,18 +11758,6 @@ "integrity": "sha512-UEZIS3/by4OC8vL3P2dTXRETpebLI2NiI5vIrjaD/5UtrkFX/tNbwjTSRAGC/+7CAo2pIcBaRgWmcBBHcsaCIw==", "license": "BlueOak-1.0.0" }, - "node_modules/package-json/node_modules/semver": { - "version": "7.7.2", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.2.tgz", - "integrity": "sha512-RF0Fw+rO5AMf9MAyaRXI4AV0Ulj5lMHqVxxdSgiVbixSCXoEmmX/jk0CuJw4+3SqroYO9VoUh+HcuJivvtJemA==", - "license": "ISC", - "bin": { - "semver": "bin/semver.js" - }, - "engines": { - "node": ">=10" - } - }, "node_modules/parent-module": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/parent-module/-/parent-module-1.0.1.tgz", @@ -13205,13 +13103,15 @@ } }, "node_modules/semver": { - "version": "6.3.1", - "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz", - "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==", - "dev": true, + "version": "7.7.2", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.2.tgz", + "integrity": "sha512-RF0Fw+rO5AMf9MAyaRXI4AV0Ulj5lMHqVxxdSgiVbixSCXoEmmX/jk0CuJw4+3SqroYO9VoUh+HcuJivvtJemA==", "license": "ISC", "bin": { "semver": "bin/semver.js" + }, + "engines": { + "node": ">=10" } }, "node_modules/send": { @@ -14353,45 +14253,6 @@ "node": ">= 6" } }, - "node_modules/temp-dir": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/temp-dir/-/temp-dir-3.0.0.tgz", - "integrity": 
"sha512-nHc6S/bwIilKHNRgK/3jlhDoIHcp45YgyiwcAk46Tr0LfEqGBVpmiAyuiuxeVE44m3mXnEeVhaipLOEWmH+Njw==", - "license": "MIT", - "engines": { - "node": ">=14.16" - } - }, - "node_modules/tempy": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/tempy/-/tempy-3.1.0.tgz", - "integrity": "sha512-7jDLIdD2Zp0bDe5r3D2qtkd1QOCacylBuL7oa4udvN6v2pqr4+LcCr67C8DR1zkpaZ8XosF5m1yQSabKAW6f2g==", - "license": "MIT", - "dependencies": { - "is-stream": "^3.0.0", - "temp-dir": "^3.0.0", - "type-fest": "^2.12.2", - "unique-string": "^3.0.0" - }, - "engines": { - "node": ">=14.16" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/tempy/node_modules/is-stream": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/is-stream/-/is-stream-3.0.0.tgz", - "integrity": "sha512-LnQR4bZ9IADDRSkvpqMGvt/tEJWclzklNgSw48V5EAaAeDd6qGvN8ei6k5p0tvxSR171VmGyHuTiAOfxAbr8kA==", - "license": "MIT", - "engines": { - "node": "^12.20.0 || ^14.13.1 || >=16.0.0" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/terminal-link": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/terminal-link/-/terminal-link-4.0.0.tgz", @@ -15027,21 +14888,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/unique-string": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/unique-string/-/unique-string-3.0.0.tgz", - "integrity": "sha512-VGXBUVwxKMBUznyffQweQABPRRW1vHZAbadFZud4pLFAqRGvv/96vafgjWFqzourzr8YonlQiPgH0YCJfawoGQ==", - "license": "MIT", - "dependencies": { - "crypto-random-string": "^4.0.0" - }, - "engines": { - "node": ">=12" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/universalify": { "version": "0.2.0", "resolved": "https://registry.npmjs.org/universalify/-/universalify-0.2.0.tgz", @@ -15149,18 +14995,6 @@ "integrity": 
"sha512-EC+0oUMY1Rqm4O6LLrgjtYDvcVYTy7chDnM4Q7030tP4Kwj3u/pR6gP9ygnp2CJMK5Gq+9Q2oqmrFJAz01DXjw==", "license": "MIT" }, - "node_modules/update-notifier/node_modules/semver": { - "version": "7.7.2", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.2.tgz", - "integrity": "sha512-RF0Fw+rO5AMf9MAyaRXI4AV0Ulj5lMHqVxxdSgiVbixSCXoEmmX/jk0CuJw4+3SqroYO9VoUh+HcuJivvtJemA==", - "license": "ISC", - "bin": { - "semver": "bin/semver.js" - }, - "engines": { - "node": ">=10" - } - }, "node_modules/update-notifier/node_modules/string-width": { "version": "7.2.0", "resolved": "https://registry.npmjs.org/string-width/-/string-width-7.2.0.tgz", @@ -16106,12 +15940,12 @@ } }, "packages/a2a-server": { - "name": "@google/gemini-cli-a2a-server", - "version": "0.3.3", + "name": "@blocksuser/gemini-cli-a2a-server", + "version": "0.7.0-nightly.20250912.68035591", "dependencies": { "@a2a-js/sdk": "^0.3.2", + "@blocksuser/gemini-cli-core": "file:../core", "@google-cloud/storage": "^7.16.0", - "@google/gemini-cli-core": "file:../core", "express": "^5.1.0", "fs-extra": "^11.3.0", "tar": "^7.4.3", @@ -16377,10 +16211,10 @@ } }, "packages/cli": { - "name": "@google/gemini-cli", - "version": "0.3.3", + "name": "@blocksuser/gemini-cli", + "version": "0.7.0-nightly.20250912.68035591", "dependencies": { - "@google/gemini-cli-core": "file:../core", + "@blocksuser/gemini-cli-core": "file:../core", "@google/genai": "1.16.0", "@iarna/toml": "^2.2.5", "@modelcontextprotocol/sdk": "^1.15.1", @@ -16413,7 +16247,7 @@ }, "devDependencies": { "@babel/runtime": "^7.27.6", - "@google/gemini-cli-test-utils": "file:../test-utils", + "@blocksuser/gemini-cli-test-utils": "file:../test-utils", "@testing-library/react": "^16.3.0", "@types/command-exists": "^1.2.3", "@types/diff": "^7.0.2", @@ -16568,11 +16402,11 @@ } }, "packages/core": { - "name": "@google/gemini-cli-core", - "version": "0.3.3", + "name": "@blocksuser/gemini-cli-core", + "version": "0.7.0-nightly.20250912.68035591", "dependencies": { 
"@google/genai": "1.16.0", - "@lvce-editor/ripgrep": "^1.6.0", + "@joshua.litt/get-ripgrep": "^0.0.2", "@modelcontextprotocol/sdk": "^1.11.0", "@opentelemetry/api": "^1.9.0", "@opentelemetry/exporter-logs-otlp-grpc": "^0.203.0", @@ -16612,7 +16446,7 @@ "ws": "^8.18.0" }, "devDependencies": { - "@google/gemini-cli-test-utils": "file:../test-utils", + "@blocksuser/gemini-cli-test-utils": "file:../test-utils", "@types/diff": "^7.0.2", "@types/dotenv": "^6.1.1", "@types/fast-levenshtein": "^0.0.4", @@ -16709,8 +16543,8 @@ } }, "packages/test-utils": { - "name": "@google/gemini-cli-test-utils", - "version": "0.3.3", + "name": "@blocksuser/gemini-cli-test-utils", + "version": "0.7.0-nightly.20250912.68035591", "license": "Apache-2.0", "devDependencies": { "typescript": "^5.3.3" @@ -16721,7 +16555,7 @@ }, "packages/vscode-ide-companion": { "name": "gemini-cli-vscode-ide-companion", - "version": "0.3.3", + "version": "0.7.0-nightly.20250912.68035591", "license": "LICENSE", "dependencies": { "@modelcontextprotocol/sdk": "^1.15.1", diff --git a/package.json b/package.json index e32094bb01a..ee0a7e9e810 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@blocksuser/gemini-cli", - "version": "0.3.3", + "version": "0.7.0-nightly.20250912.68035591", "engines": { "node": ">=20.0.0" }, @@ -13,7 +13,7 @@ "url": "git+https://github.com/google-gemini/gemini-cli.git" }, "config": { - "sandboxImageUri": "us-docker.pkg.dev/gemini-code-dev/gemini-cli/sandbox:0.3.3" + "sandboxImageUri": "us-docker.pkg.dev/gemini-code-dev/gemini-cli/sandbox:0.7.0-nightly.20250912.68035591" }, "scripts": { "start": "node scripts/start.js", @@ -90,7 +90,6 @@ "yargs": "^17.7.2" }, "dependencies": { - "@lvce-editor/ripgrep": "^1.6.0", "simple-git": "^3.28.0" }, "optionalDependencies": { diff --git a/packages/a2a-server/package.json b/packages/a2a-server/package.json index cabc3464dfa..a786f340f77 100644 --- a/packages/a2a-server/package.json +++ b/packages/a2a-server/package.json @@ -1,6 
+1,6 @@ { "name": "@blocksuser/gemini-cli-a2a-server", - "version": "0.3.3", + "version": "0.7.0-nightly.20250912.68035591", "private": true, "description": "Gemini CLI A2A Server", "repository": { diff --git a/packages/a2a-server/src/agent/executor.ts b/packages/a2a-server/src/agent/executor.ts index df8762a4e41..5f8ce2cc5d8 100644 --- a/packages/a2a-server/src/agent/executor.ts +++ b/packages/a2a-server/src/agent/executor.ts @@ -36,6 +36,7 @@ import { loadSettings } from '../config/settings.js'; import { loadExtensions } from '../config/extension.js'; import { Task } from './task.js'; import { requestStorage } from '../http/requestStorage.js'; +import { pushTaskStateFailed } from '../utils/executor_utils.js'; /** * Provides a wrapper for Task. Passes data from Task to SDKTask. @@ -116,8 +117,8 @@ export class CoderAgentExecutor implements AgentExecutor { const agentSettings = persistedState._agentSettings; const config = await this.getConfig(agentSettings, sdkTask.id); - const contextId = - (metadata['_contextId'] as string) || (sdkTask.contextId as string); + const contextId: string = + (metadata['_contextId'] as string) || sdkTask.contextId; const runtimeTask = await Task.create( sdkTask.id, contextId, @@ -125,9 +126,7 @@ export class CoderAgentExecutor implements AgentExecutor { eventBus, ); runtimeTask.taskState = persistedState._taskState; - await runtimeTask.geminiClient.initialize( - runtimeTask.config.getContentGeneratorConfig(), - ); + await runtimeTask.geminiClient.initialize(); const wrapper = new TaskWrapper(runtimeTask, agentSettings); this.tasks.set(sdkTask.id, wrapper); @@ -144,9 +143,7 @@ export class CoderAgentExecutor implements AgentExecutor { const agentSettings = agentSettingsInput || ({} as AgentSettings); const config = await this.getConfig(agentSettings, taskId); const runtimeTask = await Task.create(taskId, contextId, config, eventBus); - await runtimeTask.geminiClient.initialize( - runtimeTask.config.getContentGeneratorConfig(), - ); + 
await runtimeTask.geminiClient.initialize(); const wrapper = new TaskWrapper(runtimeTask, agentSettings); this.tasks.set(taskId, wrapper); @@ -284,10 +281,10 @@ export class CoderAgentExecutor implements AgentExecutor { const sdkTask = requestContext.task as SDKTask | undefined; const taskId = sdkTask?.id || userMessage.taskId || uuidv4(); - const contextId = + const contextId: string = userMessage.contextId || sdkTask?.contextId || - sdkTask?.metadata?.['_contextId'] || + (sdkTask?.metadata?.['_contextId'] as string) || uuidv4(); logger.info( @@ -385,12 +382,21 @@ export class CoderAgentExecutor implements AgentExecutor { const agentSettings = userMessage.metadata?.[ 'coderAgent' ] as AgentSettings; - wrapper = await this.createTask( - taskId, - contextId as string, - agentSettings, - eventBus, - ); + try { + wrapper = await this.createTask( + taskId, + contextId, + agentSettings, + eventBus, + ); + } catch (error) { + logger.error( + `[CoderAgentExecutor] Error creating task ${taskId}:`, + error, + ); + pushTaskStateFailed(error, eventBus, taskId, contextId); + return; + } const newTaskSDK = wrapper.toSDKTask(); eventBus.publish({ ...newTaskSDK, diff --git a/packages/a2a-server/src/agent/task.ts b/packages/a2a-server/src/agent/task.ts index 9ecf5e85d70..cddd5ec3257 100644 --- a/packages/a2a-server/src/agent/task.ts +++ b/packages/a2a-server/src/agent/task.ts @@ -6,7 +6,7 @@ import { CoreToolScheduler, - GeminiClient, + type GeminiClient, GeminiEventType, ToolConfirmationOutcome, ApprovalMode, @@ -25,6 +25,7 @@ import type { ToolCallConfirmationDetails, Config, UserTierId, + AnsiOutput, } from '@blocksuser/gemini-cli-core'; import type { RequestContext } from '@a2a-js/sdk/server'; import { type ExecutionEventBus } from '@a2a-js/sdk/server'; @@ -82,18 +83,17 @@ export class Task { this.contextId = contextId; this.config = config; this.scheduler = this.createScheduler(); - this.geminiClient = new GeminiClient(this.config); + this.geminiClient = 
this.config.getGeminiClient(); this.pendingToolConfirmationDetails = new Map(); this.taskState = 'submitted'; this.eventBus = eventBus; this.completedToolCalls = []; this._resetToolCompletionPromise(); - this.config.setFlashFallbackHandler( - async (currentModel: string, fallbackModel: string): Promise => { - config.setModel(fallbackModel); // gemini-cli-core sets to DEFAULT_GEMINI_FLASH_MODEL - // Switch model for future use but return false to stop current retry - return false; - }, + this.config.setFallbackModelHandler( + // For a2a-server, we want to automatically switch to the fallback model + // for future requests without retrying the current one. The 'stop' + // intent achieves this. + async () => 'stop', ); } @@ -133,7 +133,7 @@ export class Task { id: this.id, contextId: this.contextId, taskState: this.taskState, - model: this.config.getContentGeneratorConfig().model, + model: this.config.getModel(), mcpServers: servers, availableTools, }; @@ -227,7 +227,7 @@ export class Task { } = { coderAgent: coderAgentMessage, model: this.config.getModel(), - userTier: this.geminiClient.getUserTier(), + userTier: this.config.getUserTier(), }; if (metadataError) { @@ -285,20 +285,29 @@ export class Task { private _schedulerOutputUpdate( toolCallId: string, - outputChunk: string, + outputChunk: string | AnsiOutput, ): void { + let outputAsText: string; + if (typeof outputChunk === 'string') { + outputAsText = outputChunk; + } else { + outputAsText = outputChunk + .map((line) => line.map((token) => token.text).join('')) + .join('\n'); + } + logger.info( '[Task] Scheduler output update for tool call ' + toolCallId + ': ' + - outputChunk, + outputAsText, ); const artifact: Artifact = { artifactId: `tool-${toolCallId}-output`, parts: [ { kind: 'text', - text: outputChunk, + text: outputAsText, } as Part, ], }; diff --git a/packages/a2a-server/src/config/config.ts b/packages/a2a-server/src/config/config.ts index 17bb9e91053..637423550a3 100644 --- 
a/packages/a2a-server/src/config/config.ts +++ b/packages/a2a-server/src/config/config.ts @@ -69,6 +69,7 @@ export async function loadConfig( settings.fileFiltering?.enableRecursiveFileSearch, }, ideMode: false, + folderTrust: settings.folderTrust === true, }; const fileService = new FileDiscoveryService(workspaceDir); @@ -79,7 +80,7 @@ export async function loadConfig( false, fileService, extensionContextFilePaths, - true, /// TODO: Wire up folder trust logic here. + settings.folderTrust === true, ); configParams.userMemory = memoryContent; configParams.geminiMdFileCount = fileCount; @@ -108,9 +109,10 @@ export async function loadConfig( logger.info('[Config] Using Gemini API Key'); await config.refreshAuth(AuthType.USE_GEMINI); } else { - logger.error( - `[Config] Unable to set GeneratorConfig. Please provide a GEMINI_API_KEY or set USE_CCPA.`, - ); + const errorMessage = + '[Config] Unable to set GeneratorConfig. Please provide a GEMINI_API_KEY or set USE_CCPA.'; + logger.error(errorMessage); + throw new Error(errorMessage); } return config; diff --git a/packages/a2a-server/src/config/settings.ts b/packages/a2a-server/src/config/settings.ts index 1b9921eff42..323835751ec 100644 --- a/packages/a2a-server/src/config/settings.ts +++ b/packages/a2a-server/src/config/settings.ts @@ -27,6 +27,7 @@ export interface Settings { telemetry?: TelemetrySettings; showMemoryUsage?: boolean; checkpointing?: CheckpointingSettings; + folderTrust?: boolean; // Git-aware file filtering settings fileFiltering?: { diff --git a/packages/a2a-server/src/http/app.test.ts b/packages/a2a-server/src/http/app.test.ts index f43f89bd9de..574b2daca90 100644 --- a/packages/a2a-server/src/http/app.test.ts +++ b/packages/a2a-server/src/http/app.test.ts @@ -64,6 +64,7 @@ vi.mock('../utils/logger.js', () => ({ let config: Config; const getToolRegistrySpy = vi.fn().mockReturnValue(ApprovalMode.DEFAULT); const getApprovalModeSpy = vi.fn(); +const getShellExecutionConfigSpy = vi.fn(); 
vi.mock('../config/config.js', async () => { const actual = await vi.importActual('../config/config.js'); return { @@ -72,6 +73,7 @@ vi.mock('../config/config.js', async () => { const mockConfig = createMockConfig({ getToolRegistry: getToolRegistrySpy, getApprovalMode: getApprovalModeSpy, + getShellExecutionConfig: getShellExecutionConfigSpy, }); config = mockConfig as Config; return config; diff --git a/packages/a2a-server/src/http/endpoints.test.ts b/packages/a2a-server/src/http/endpoints.test.ts index 5a5116ef902..ff589443a89 100644 --- a/packages/a2a-server/src/http/endpoints.test.ts +++ b/packages/a2a-server/src/http/endpoints.test.ts @@ -7,14 +7,17 @@ import { describe, it, expect, beforeAll, afterAll, vi } from 'vitest'; import request from 'supertest'; import type express from 'express'; -import { createApp, updateCoderAgentCardUrl } from './app.js'; import * as fs from 'node:fs'; import * as path from 'node:path'; import * as os from 'node:os'; import type { Server } from 'node:http'; -import type { TaskMetadata } from '../types.js'; import type { AddressInfo } from 'node:net'; +import { createApp, updateCoderAgentCardUrl } from './app.js'; +import type { TaskMetadata } from '../types.js'; +import { createMockConfig } from '../utils/testing_utils.js'; +import type { Config } from '@blocksuser/gemini-cli-core'; + // Mock the logger to avoid polluting test output // Comment out to help debug vi.mock('../utils/logger.js', () => ({ @@ -56,6 +59,16 @@ vi.mock('../agent/task.js', () => { return { Task: MockTask }; }); +vi.mock('../config/config.js', async () => { + const actual = await vi.importActual('../config/config.js'); + return { + ...actual, + loadConfig: vi + .fn() + .mockImplementation(async () => createMockConfig({}) as Config), + }; +}); + describe('Agent Server Endpoints', () => { let app: express.Express; let server: Server; diff --git a/packages/a2a-server/src/utils/executor_utils.ts b/packages/a2a-server/src/utils/executor_utils.ts new file mode 
100644 index 00000000000..b595a6905b2 --- /dev/null +++ b/packages/a2a-server/src/utils/executor_utils.ts @@ -0,0 +1,52 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Message } from '@a2a-js/sdk'; +import type { ExecutionEventBus } from '@a2a-js/sdk/server'; +import { v4 as uuidv4 } from 'uuid'; + +import { CoderAgentEvent } from '../types.js'; +import type { StateChange } from '../types.js'; + +export async function pushTaskStateFailed( + error: unknown, + eventBus: ExecutionEventBus, + taskId: string, + contextId: string, +) { + const errorMessage = + error instanceof Error ? error.message : 'Agent execution error'; + const stateChange: StateChange = { + kind: CoderAgentEvent.StateChangeEvent, + }; + eventBus.publish({ + kind: 'status-update', + taskId, + contextId, + status: { + state: 'failed', + message: { + kind: 'message', + role: 'agent', + parts: [ + { + kind: 'text', + text: errorMessage, + }, + ], + messageId: uuidv4(), + taskId, + contextId, + } as Message, + }, + final: true, + metadata: { + coderAgent: stateChange, + model: 'unknown', + error: errorMessage, + }, + }); +} diff --git a/packages/a2a-server/src/utils/testing_utils.ts b/packages/a2a-server/src/utils/testing_utils.ts index 9e1114f377a..6141c3330a7 100644 --- a/packages/a2a-server/src/utils/testing_utils.ts +++ b/packages/a2a-server/src/utils/testing_utils.ts @@ -13,6 +13,7 @@ import { ApprovalMode, DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES, DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, + GeminiClient, } from '@blocksuser/gemini-cli-core'; import type { Config, Storage } from '@blocksuser/gemini-cli-core'; import { expect, vi } from 'vitest'; @@ -38,19 +39,23 @@ export function createMockConfig( getTruncateToolOutputThreshold: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES, - getGeminiClient: vi.fn(), getDebugMode: vi.fn().mockReturnValue(false), getContentGeneratorConfig: 
vi.fn().mockReturnValue({ model: 'gemini-pro' }), getModel: vi.fn().mockReturnValue('gemini-pro'), getUsageStatisticsEnabled: vi.fn().mockReturnValue(false), - setFlashFallbackHandler: vi.fn(), + setFallbackModelHandler: vi.fn(), initialize: vi.fn().mockResolvedValue(undefined), getProxy: vi.fn().mockReturnValue(undefined), getHistory: vi.fn().mockReturnValue([]), getEmbeddingModel: vi.fn().mockReturnValue('text-embedding-004'), getSessionId: vi.fn().mockReturnValue('test-session-id'), + getUserTier: vi.fn(), ...overrides, - }; + } as unknown as Config; + + mockConfig.getGeminiClient = vi + .fn() + .mockReturnValue(new GeminiClient(mockConfig)); return mockConfig; } diff --git a/packages/cli/package.json b/packages/cli/package.json index 9dc86fa54a9..cc4a8042c6c 100644 --- a/packages/cli/package.json +++ b/packages/cli/package.json @@ -1,6 +1,6 @@ { "name": "@blocksuser/gemini-cli", - "version": "0.3.3", + "version": "0.7.0-nightly.20250912.68035591", "description": "Gemini CLI", "repository": { "type": "git", @@ -25,7 +25,7 @@ "dist" ], "config": { - "sandboxImageUri": "us-docker.pkg.dev/gemini-code-dev/gemini-cli/sandbox:0.3.3" + "sandboxImageUri": "us-docker.pkg.dev/gemini-code-dev/gemini-cli/sandbox:0.7.0-nightly.20250912.68035591" }, "dependencies": { "@blocksuser/gemini-cli-core": "file:../core", diff --git a/packages/cli/src/commands/extensions/install.test.ts b/packages/cli/src/commands/extensions/install.test.ts index 2f66e432a12..441855da2c1 100644 --- a/packages/cli/src/commands/extensions/install.test.ts +++ b/packages/cli/src/commands/extensions/install.test.ts @@ -5,9 +5,8 @@ */ import { describe, it, expect } from 'vitest'; -import { installCommand, handleInstall } from './install.js'; +import { installCommand } from './install.js'; import yargs from 'yargs'; -import * as extension from '../../config/extension.js'; vi.mock('../../config/extension.js', () => ({ installExtension: vi.fn(), @@ -28,22 +27,3 @@ describe('extensions install command', () => 
{ ).toThrow('Arguments source and path are mutually exclusive'); }); }); - -describe('extensions install with org/repo', () => { - it('should call installExtension with the correct git URL', async () => { - const consoleLogSpy = vi.spyOn(console, 'log').mockImplementation(() => {}); - const installExtensionSpy = vi - .spyOn(extension, 'installExtension') - .mockResolvedValue('test-extension'); - - await handleInstall({ source: 'test-org/test-repo' }); - - expect(installExtensionSpy).toHaveBeenCalledWith({ - source: 'https://github.com/test-org/test-repo.git', - type: 'git', - }); - expect(consoleLogSpy).toHaveBeenCalledWith( - 'Extension "test-extension" installed successfully and enabled.', - ); - }); -}); diff --git a/packages/cli/src/commands/extensions/install.ts b/packages/cli/src/commands/extensions/install.ts index 56eb2a18674..fe43fdbf0ad 100644 --- a/packages/cli/src/commands/extensions/install.ts +++ b/packages/cli/src/commands/extensions/install.ts @@ -15,10 +15,9 @@ import { getErrorMessage } from '../../utils/errors.js'; interface InstallArgs { source?: string; path?: string; + ref?: string; } -const ORG_REPO_REGEX = /^[a-zA-Z0-9-]+\/[\w.-]+$/; - export async function handleInstall(args: InstallArgs) { try { let installMetadata: ExtensionInstallMetadata; @@ -33,16 +32,10 @@ export async function handleInstall(args: InstallArgs) { installMetadata = { source, type: 'git', - }; - } else if (ORG_REPO_REGEX.test(source)) { - installMetadata = { - source: `https://github.com/${source}.git`, - type: 'git', + ref: args.ref, }; } else { - throw new Error( - `The source "${source}" is not a valid URL or "org/repo" format.`, - ); + throw new Error(`The source "${source}" is not a valid URL format.`); } } else if (args.path) { installMetadata = { @@ -54,10 +47,8 @@ export async function handleInstall(args: InstallArgs) { throw new Error('Either --source or --path must be provided.'); } - const extensionName = await installExtension(installMetadata); - console.log( 
- `Extension "${extensionName}" installed successfully and enabled.`, - ); + const name = await installExtension(installMetadata); + console.log(`Extension "${name}" installed successfully and enabled.`); } catch (error) { console.error(getErrorMessage(error)); process.exit(1); @@ -66,19 +57,23 @@ export async function handleInstall(args: InstallArgs) { export const installCommand: CommandModule = { command: 'install [source]', - describe: - 'Installs an extension from a git repository (URL or "org/repo") or a local path.', + describe: 'Installs an extension from a git repository URL or a local path.', builder: (yargs) => yargs .positional('source', { - describe: 'The git URL or "org/repo" of the extension to install.', + describe: 'The github URL of the extension to install.', type: 'string', }) .option('path', { describe: 'Path to a local extension directory.', type: 'string', }) + .option('ref', { + describe: 'The git ref to install from.', + type: 'string', + }) .conflicts('source', 'path') + .conflicts('path', 'ref') .check((argv) => { if (!argv.source && !argv.path) { throw new Error('Either source or --path must be provided.'); @@ -89,6 +84,7 @@ export const installCommand: CommandModule = { await handleInstall({ source: argv['source'] as string | undefined, path: argv['path'] as string | undefined, + ref: argv['ref'] as string | undefined, }); }, }; diff --git a/packages/cli/src/config/config.test.ts b/packages/cli/src/config/config.test.ts index 3cd5c798554..80ff0e8f425 100644 --- a/packages/cli/src/config/config.test.ts +++ b/packages/cli/src/config/config.test.ts @@ -1645,28 +1645,28 @@ describe('loadCliConfig useRipgrep', () => { vi.restoreAllMocks(); }); - it('should be false by default when useRipgrep is not set in settings', async () => { + it('should be true by default when useRipgrep is not set in settings', async () => { process.argv = ['node', 'script.js']; const argv = await parseArguments({} as Settings); const settings: Settings = {}; const 
config = await loadCliConfig(settings, [], 'test-session', argv); - expect(config.getUseRipgrep()).toBe(false); + expect(config.getUseRipgrep()).toBe(true); }); - it('should be true when useRipgrep is set to true in settings', async () => { + it('should be false when useRipgrep is set to false in settings', async () => { process.argv = ['node', 'script.js']; const argv = await parseArguments({} as Settings); - const settings: Settings = { tools: { useRipgrep: true } }; + const settings: Settings = { tools: { useRipgrep: false } }; const config = await loadCliConfig(settings, [], 'test-session', argv); - expect(config.getUseRipgrep()).toBe(true); + expect(config.getUseRipgrep()).toBe(false); }); - it('should be false when useRipgrep is explicitly set to false in settings', async () => { + it('should be true when useRipgrep is explicitly set to true in settings', async () => { process.argv = ['node', 'script.js']; const argv = await parseArguments({} as Settings); - const settings: Settings = { tools: { useRipgrep: false } }; + const settings: Settings = { tools: { useRipgrep: true } }; const config = await loadCliConfig(settings, [], 'test-session', argv); - expect(config.getUseRipgrep()).toBe(false); + expect(config.getUseRipgrep()).toBe(true); }); }); @@ -1972,6 +1972,55 @@ describe('loadCliConfig fileFiltering', () => { ); }); +describe('Output Format Configuration', () => { + const originalArgv = process.argv; + + afterEach(() => { + process.argv = originalArgv; + vi.restoreAllMocks(); + }); + + it('should default to text format when no setting or flag is provided', async () => { + process.argv = ['node', 'script.js']; + const argv = await parseArguments({} as Settings); + const config = await loadCliConfig( + {} as Settings, + [], + 'test-session', + argv, + ); + expect(config.getOutputFormat()).toBe(ServerConfig.OutputFormat.TEXT); + }); + + it('should use the format from settings when no flag is provided', async () => { + process.argv = ['node', 'script.js']; 
+ const settings: Settings = { output: { format: 'json' } }; + const argv = await parseArguments(settings); + const config = await loadCliConfig(settings, [], 'test-session', argv); + expect(config.getOutputFormat()).toBe(ServerConfig.OutputFormat.JSON); + }); + + it('should use the format from the flag when provided', async () => { + process.argv = ['node', 'script.js', '--output-format', 'json']; + const argv = await parseArguments({} as Settings); + const config = await loadCliConfig( + {} as Settings, + [], + 'test-session', + argv, + ); + expect(config.getOutputFormat()).toBe(ServerConfig.OutputFormat.JSON); + }); + + it('should prioritize the flag over the setting', async () => { + process.argv = ['node', 'script.js', '--output-format', 'text']; + const settings: Settings = { output: { format: 'json' } }; + const argv = await parseArguments(settings); + const config = await loadCliConfig(settings, [], 'test-session', argv); + expect(config.getOutputFormat()).toBe(ServerConfig.OutputFormat.TEXT); + }); +}); + describe('parseArguments with positional prompt', () => { const originalArgv = process.argv; diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index 618a8641124..3825a884a15 100755 --- a/packages/cli/src/config/config.ts +++ b/packages/cli/src/config/config.ts @@ -15,6 +15,7 @@ import type { TelemetryTarget, FileFilteringOptions, MCPServerConfig, + OutputFormat, } from '@blocksuser/gemini-cli-core'; import { extensionsCommand } from '../commands/extensions.js'; import { @@ -81,6 +82,7 @@ export interface CliArgs { useSmartEdit: boolean | undefined; sessionSummary: string | undefined; promptWords: string[] | undefined; + outputFormat: string | undefined; resume: string | undefined; listSessions: boolean | undefined; deleteSession: string | undefined; @@ -237,6 +239,11 @@ export async function parseArguments(settings: Settings): Promise { type: 'string', description: 'File to write session summary to.', }) + 
.option('output-format', { + type: 'string', + description: 'The format of the CLI output.', + choices: ['text', 'json', 'stream-json'], + }) .option('resume', { alias: 'r', type: 'string', @@ -252,43 +259,43 @@ export async function parseArguments(settings: Settings): Promise { }) .deprecateOption( 'telemetry', - 'Use settings.json instead. This flag will be removed in a future version.', + 'Use the "telemetry.enabled" setting in settings.json instead. This flag will be removed in a future version.', ) .deprecateOption( 'telemetry-target', - 'Use settings.json instead. This flag will be removed in a future version.', + 'Use the "telemetry.target" setting in settings.json instead. This flag will be removed in a future version.', ) .deprecateOption( 'telemetry-otlp-endpoint', - 'Use settings.json instead. This flag will be removed in a future version.', + 'Use the "telemetry.otlpEndpoint" setting in settings.json instead. This flag will be removed in a future version.', ) .deprecateOption( 'telemetry-otlp-protocol', - 'Use settings.json instead. This flag will be removed in a future version.', + 'Use the "telemetry.otlpProtocol" setting in settings.json instead. This flag will be removed in a future version.', ) .deprecateOption( 'telemetry-log-prompts', - 'Use settings.json instead. This flag will be removed in a future version.', + 'Use the "telemetry.logPrompts" setting in settings.json instead. This flag will be removed in a future version.', ) .deprecateOption( 'telemetry-outfile', - 'Use settings.json instead. This flag will be removed in a future version.', + 'Use the "telemetry.outfile" setting in settings.json instead. This flag will be removed in a future version.', ) .deprecateOption( 'show-memory-usage', - 'Use settings.json instead. This flag will be removed in a future version.', + 'Use the "ui.showMemoryUsage" setting in settings.json instead. 
This flag will be removed in a future version.', ) .deprecateOption( 'sandbox-image', - 'Use settings.json instead. This flag will be removed in a future version.', + 'Use the "tools.sandbox" setting in settings.json instead. This flag will be removed in a future version.', ) .deprecateOption( 'proxy', - 'Use settings.json instead. This flag will be removed in a future version.', + 'Use the "proxy" setting in settings.json instead. This flag will be removed in a future version.', ) .deprecateOption( 'checkpointing', - 'Use settings.json instead. This flag will be removed in a future version.', + 'Use the "general.checkpointing.enabled" setting in settings.json instead. This flag will be removed in a future version.', ) .deprecateOption( 'all-files', @@ -650,8 +657,12 @@ export async function loadCliConfig( enablePromptCompletion: settings.general?.enablePromptCompletion ?? false, truncateToolOutputThreshold: settings.tools?.truncateToolOutputThreshold, truncateToolOutputLines: settings.tools?.truncateToolOutputLines, + enableToolOutputTruncation: settings.tools?.enableToolOutputTruncation, eventEmitter: appEvents, useSmartEdit: argv.useSmartEdit ?? settings.useSmartEdit, + output: { + format: (argv.outputFormat ?? 
settings.output?.format) as OutputFormat, + }, }); } diff --git a/packages/cli/src/config/extension.test.ts b/packages/cli/src/config/extension.test.ts index 5999d8d3876..93eda1534d3 100644 --- a/packages/cli/src/config/extension.test.ts +++ b/packages/cli/src/config/extension.test.ts @@ -12,8 +12,10 @@ import { EXTENSIONS_CONFIG_FILENAME, INSTALL_METADATA_FILENAME, annotateActiveExtensions, + checkForExtensionUpdates, disableExtension, enableExtension, + ExtensionUpdateStatus, installExtension, loadExtension, loadExtensions, @@ -26,14 +28,30 @@ import { GEMINI_DIR, type GeminiCLIExtension, type MCPServerConfig, + ClearcutLogger, + type Config, } from '@blocksuser/gemini-cli-core'; import { execSync } from 'node:child_process'; import { SettingScope, loadSettings } from './settings.js'; -import { type SimpleGit, simpleGit } from 'simple-git'; import { isWorkspaceTrusted } from './trustedFolders.js'; +const mockGit = { + clone: vi.fn(), + getRemotes: vi.fn(), + fetch: vi.fn(), + checkout: vi.fn(), + listRemote: vi.fn(), + revparse: vi.fn(), + // Not a part of the actual API, but we need to use this to do the correct + // file system interactions. 
+ path: vi.fn(), +}; + vi.mock('simple-git', () => ({ - simpleGit: vi.fn(), + simpleGit: vi.fn((path: string) => { + mockGit.path.mockReturnValue(path); + return mockGit; + }), })); vi.mock('os', async (importOriginal) => { @@ -52,6 +70,22 @@ vi.mock('./trustedFolders.js', async (importOriginal) => { }; }); +vi.mock('@blocksuser/gemini-cli-core', async (importOriginal) => { + const actual = + await importOriginal(); + const mockLogExtensionInstallEvent = vi.fn(); + return { + ...actual, + ClearcutLogger: { + getInstance: vi.fn(() => ({ + logExtensionInstallEvent: mockLogExtensionInstallEvent, + })), + }, + Config: vi.fn(), + ExtensionInstallEvent: vi.fn(), + }; +}); + vi.mock('child_process', async (importOriginal) => { const actual = await importOriginal(); return { @@ -410,6 +444,7 @@ describe('installExtension', () => { fs.mkdirSync(userExtensionsDir, { recursive: true }); vi.mocked(isWorkspaceTrusted).mockReturnValue(true); vi.mocked(execSync).mockClear(); + Object.values(mockGit).forEach((fn) => fn.mockReset()); }); afterEach(() => { @@ -472,16 +507,14 @@ describe('installExtension', () => { const targetExtDir = path.join(userExtensionsDir, extensionName); const metadataPath = path.join(targetExtDir, INSTALL_METADATA_FILENAME); - const clone = vi.fn().mockImplementation(async (_, destination) => { - fs.mkdirSync(destination, { recursive: true }); + mockGit.clone.mockImplementation(async (_, destination) => { + fs.mkdirSync(path.join(mockGit.path(), destination), { recursive: true }); fs.writeFileSync( - path.join(destination, EXTENSIONS_CONFIG_FILENAME), + path.join(mockGit.path(), destination, EXTENSIONS_CONFIG_FILENAME), JSON.stringify({ name: extensionName, version: '1.0.0' }), ); }); - - const mockedSimpleGit = simpleGit as vi.MockedFunction; - mockedSimpleGit.mockReturnValue({ clone } as unknown as SimpleGit); + mockGit.getRemotes.mockResolvedValue([{ name: 'origin' }]); await installExtension({ source: gitUrl, type: 'git' }); @@ -519,6 +552,19 @@ 
describe('installExtension', () => { }); fs.rmSync(targetExtDir, { recursive: true, force: true }); }); + + it('should log to clearcut on successful install', async () => { + const sourceExtDir = createExtension({ + extensionsDir: tempHomeDir, + name: 'my-local-extension', + version: '1.0.0', + }); + + await installExtension({ source: sourceExtDir, type: 'local' }); + + const logger = ClearcutLogger.getInstance({} as Config); + expect(logger?.logExtensionInstallEvent).toHaveBeenCalled(); + }); }); describe('uninstallExtension', () => { @@ -756,6 +802,7 @@ describe('updateExtension', () => { vi.mocked(isWorkspaceTrusted).mockReturnValue(true); vi.mocked(execSync).mockClear(); + Object.values(mockGit).forEach((fn) => fn.mockReset()); }); afterEach(() => { @@ -778,20 +825,16 @@ describe('updateExtension', () => { JSON.stringify({ source: gitUrl, type: 'git' }), ); - const clone = vi.fn().mockImplementation(async (_, destination) => { - fs.mkdirSync(destination, { recursive: true }); + mockGit.clone.mockImplementation(async (_, destination) => { + fs.mkdirSync(path.join(mockGit.path(), destination), { recursive: true }); fs.writeFileSync( - path.join(destination, EXTENSIONS_CONFIG_FILENAME), + path.join(mockGit.path(), destination, EXTENSIONS_CONFIG_FILENAME), JSON.stringify({ name: extensionName, version: '1.1.0' }), ); }); + mockGit.getRemotes.mockResolvedValue([{ name: 'origin' }]); - const mockedSimpleGit = simpleGit as vi.MockedFunction; - mockedSimpleGit.mockReturnValue({ - clone, - } as unknown as SimpleGit); - - const updateInfo = await updateExtension(loadExtension(targetExtDir)); + const updateInfo = await updateExtension(loadExtension(targetExtDir)!); expect(updateInfo).toEqual({ name: 'gemini-extensions', @@ -809,6 +852,105 @@ describe('updateExtension', () => { }); }); +describe('checkForExtensionUpdates', () => { + let tempHomeDir: string; + let userExtensionsDir: string; + + beforeEach(() => { + tempHomeDir = fs.mkdtempSync( + path.join(os.tmpdir(), 
'gemini-cli-test-home-'), + ); + vi.mocked(os.homedir).mockReturnValue(tempHomeDir); + userExtensionsDir = path.join(tempHomeDir, GEMINI_DIR, 'extensions'); + fs.mkdirSync(userExtensionsDir, { recursive: true }); + Object.values(mockGit).forEach((fn) => fn.mockReset()); + }); + + afterEach(() => { + fs.rmSync(tempHomeDir, { recursive: true, force: true }); + }); + + it('should return UpdateAvailable for a git extension with updates', async () => { + const extensionDir = createExtension({ + extensionsDir: userExtensionsDir, + name: 'test-extension', + version: '1.0.0', + }); + const extension = loadExtension(extensionDir)!; + extension.installMetadata = { + source: 'https://some.git/repo', + type: 'git', + }; + + mockGit.getRemotes.mockResolvedValue([ + { name: 'origin', refs: { fetch: 'https://some.git/repo' } }, + ]); + mockGit.listRemote.mockResolvedValue('remoteHash HEAD'); + mockGit.revparse.mockResolvedValue('localHash'); + + const results = await checkForExtensionUpdates([extension]); + const result = results.get('test-extension'); + expect(result?.status).toBe(ExtensionUpdateStatus.UpdateAvailable); + }); + + it('should return UpToDate for a git extension with no updates', async () => { + const extensionDir = createExtension({ + extensionsDir: userExtensionsDir, + name: 'test-extension', + version: '1.0.0', + }); + const extension = loadExtension(extensionDir)!; + extension.installMetadata = { + source: 'https://some.git/repo', + type: 'git', + }; + + mockGit.getRemotes.mockResolvedValue([ + { name: 'origin', refs: { fetch: 'https://some.git/repo' } }, + ]); + mockGit.listRemote.mockResolvedValue('sameHash HEAD'); + mockGit.revparse.mockResolvedValue('sameHash'); + + const results = await checkForExtensionUpdates([extension]); + const result = results.get('test-extension'); + expect(result?.status).toBe(ExtensionUpdateStatus.UpToDate); + }); + + it('should return NotUpdatable for a non-git extension', async () => { + const extensionDir = createExtension({ + 
extensionsDir: userExtensionsDir, + name: 'local-extension', + version: '1.0.0', + }); + const extension = loadExtension(extensionDir)!; + extension.installMetadata = { source: '/local/path', type: 'local' }; + + const results = await checkForExtensionUpdates([extension]); + const result = results.get('local-extension'); + expect(result?.status).toBe(ExtensionUpdateStatus.NotUpdatable); + }); + + it('should return Error when git check fails', async () => { + const extensionDir = createExtension({ + extensionsDir: userExtensionsDir, + name: 'error-extension', + version: '1.0.0', + }); + const extension = loadExtension(extensionDir)!; + extension.installMetadata = { + source: 'https://some.git/repo', + type: 'git', + }; + + mockGit.getRemotes.mockRejectedValue(new Error('Git error')); + + const results = await checkForExtensionUpdates([extension]); + const result = results.get('error-extension'); + expect(result?.status).toBe(ExtensionUpdateStatus.Error); + expect(result?.error).toContain('Failed to check for updates'); + }); +}); + describe('disableExtension', () => { let tempWorkspaceDir: string; let tempHomeDir: string; diff --git a/packages/cli/src/config/extension.ts b/packages/cli/src/config/extension.ts index 90a699c2591..273553df25e 100644 --- a/packages/cli/src/config/extension.ts +++ b/packages/cli/src/config/extension.ts @@ -8,7 +8,13 @@ import type { MCPServerConfig, GeminiCLIExtension, } from '@blocksuser/gemini-cli-core'; -import { GEMINI_DIR, Storage } from '@blocksuser/gemini-cli-core'; +import { + GEMINI_DIR, + Storage, + ClearcutLogger, + Config, + ExtensionInstallEvent, +} from '@blocksuser/gemini-cli-core'; import * as fs from 'node:fs'; import * as path from 'node:path'; import * as os from 'node:os'; @@ -18,6 +24,7 @@ import { getErrorMessage } from '../utils/errors.js'; import { recursivelyHydrateStrings } from './extensions/variables.js'; import { isWorkspaceTrusted } from './trustedFolders.js'; import { resolveEnvVarsInObject } from 
'../utils/envVarResolver.js'; +import { randomUUID } from 'node:crypto'; export const EXTENSIONS_DIRECTORY_NAME = path.join(GEMINI_DIR, 'extensions'); @@ -42,6 +49,7 @@ export interface ExtensionConfig { export interface ExtensionInstallMetadata { source: string; type: 'git' | 'local' | 'link'; + ref?: string; } export interface ExtensionUpdateInfo { @@ -81,6 +89,10 @@ export class ExtensionStorage { } export function getWorkspaceExtensions(workspaceDir: string): Extension[] { + // If the workspace dir is the user extensions dir, there are no workspace extensions. + if (path.resolve(workspaceDir) === path.resolve(os.homedir())) { + return []; + } return loadExtensionsFromDir(workspaceDir); } @@ -325,20 +337,38 @@ export function annotateActiveExtensions( /** * Clones a Git repository to a specified local path. - * @param gitUrl The Git URL to clone. + * @param installMetadata The metadata for the extension to install. * @param destination The destination path to clone the repository to. */ async function cloneFromGit( - gitUrl: string, + installMetadata: ExtensionInstallMetadata, destination: string, ): Promise { try { - // TODO(chrstnb): Download the archive instead to avoid unnecessary .git info. - await simpleGit().clone(gitUrl, destination, ['--depth', '1']); + const git = simpleGit(destination); + await git.clone(installMetadata.source, './', ['--depth', '1']); + + const remotes = await git.getRemotes(true); + if (remotes.length === 0) { + throw new Error( + `Unable to find any remotes for repo ${installMetadata.source}`, + ); + } + + const refToFetch = installMetadata.ref || 'HEAD'; + + await git.fetch(remotes[0].name, refToFetch); + + // After fetching, checkout FETCH_HEAD to get the content of the fetched ref. + // This results in a detached HEAD state, which is fine for this purpose. 
+ await git.checkout('FETCH_HEAD'); } catch (error) { - throw new Error(`Failed to clone Git repository from ${gitUrl}`, { - cause: error, - }); + throw new Error( + `Failed to clone Git repository from ${installMetadata.source}`, + { + cause: error, + }, + ); } } @@ -346,83 +376,120 @@ export async function installExtension( installMetadata: ExtensionInstallMetadata, cwd: string = process.cwd(), ): Promise { - const settings = loadSettings(cwd).merged; - if (!isWorkspaceTrusted(settings)) { - throw new Error( - `Could not install extension from untrusted folder at ${installMetadata.source}`, - ); - } - - const extensionsDir = ExtensionStorage.getUserExtensionsDir(); - await fs.promises.mkdir(extensionsDir, { recursive: true }); - - // Convert relative paths to absolute paths for the metadata file. - if ( - !path.isAbsolute(installMetadata.source) && - (installMetadata.type === 'local' || installMetadata.type === 'link') - ) { - installMetadata.source = path.resolve(cwd, installMetadata.source); - } - - let localSourcePath: string; - let tempDir: string | undefined; - let newExtensionName: string | undefined; - - if (installMetadata.type === 'git') { - tempDir = await ExtensionStorage.createTmpDir(); - await cloneFromGit(installMetadata.source, tempDir); - localSourcePath = tempDir; - } else if ( - installMetadata.type === 'local' || - installMetadata.type === 'link' - ) { - localSourcePath = installMetadata.source; - } else { - throw new Error(`Unsupported install type: ${installMetadata.type}`); - } + const config = new Config({ + sessionId: randomUUID(), + targetDir: process.cwd(), + cwd: process.cwd(), + model: '', + debugMode: false, + }); + const logger = ClearcutLogger.getInstance(config); + let newExtensionConfig: ExtensionConfig | null = null; + let localSourcePath: string | undefined; try { - const newExtensionConfig = await loadExtensionConfig(localSourcePath); - if (!newExtensionConfig) { + const settings = loadSettings(cwd).merged; + if 
(!isWorkspaceTrusted(settings)) { throw new Error( - `Invalid extension at ${installMetadata.source}. Please make sure it has a valid gemini-extension.json file.`, + `Could not install extension from untrusted folder at ${installMetadata.source}`, ); } - newExtensionName = newExtensionConfig.name; - const extensionStorage = new ExtensionStorage(newExtensionName); - const destinationPath = extensionStorage.getExtensionDir(); + const extensionsDir = ExtensionStorage.getUserExtensionsDir(); + await fs.promises.mkdir(extensionsDir, { recursive: true }); - const installedExtensions = loadUserExtensions(); if ( - installedExtensions.some( - (installed) => installed.config.name === newExtensionName, - ) + !path.isAbsolute(installMetadata.source) && + (installMetadata.type === 'local' || installMetadata.type === 'link') ) { - throw new Error( - `Extension "${newExtensionName}" is already installed. Please uninstall it first.`, - ); + installMetadata.source = path.resolve(cwd, installMetadata.source); } - await fs.promises.mkdir(destinationPath, { recursive: true }); + let tempDir: string | undefined; - if (installMetadata.type === 'local' || installMetadata.type === 'git') { - await copyExtension(localSourcePath, destinationPath); + if (installMetadata.type === 'git') { + tempDir = await ExtensionStorage.createTmpDir(); + await cloneFromGit(installMetadata, tempDir); + localSourcePath = tempDir; + } else if ( + installMetadata.type === 'local' || + installMetadata.type === 'link' + ) { + localSourcePath = installMetadata.source; + } else { + throw new Error(`Unsupported install type: ${installMetadata.type}`); } - const metadataString = JSON.stringify(installMetadata, null, 2); - const metadataPath = path.join(destinationPath, INSTALL_METADATA_FILENAME); - await fs.promises.writeFile(metadataPath, metadataString); - } finally { - if (tempDir) { - await fs.promises.rm(tempDir, { recursive: true, force: true }); + try { + newExtensionConfig = await 
loadExtensionConfig(localSourcePath); + if (!newExtensionConfig) { + throw new Error( + `Invalid extension at ${installMetadata.source}. Please make sure it has a valid gemini-extension.json file.`, + ); + } + + const newExtensionName = newExtensionConfig.name; + const extensionStorage = new ExtensionStorage(newExtensionName); + const destinationPath = extensionStorage.getExtensionDir(); + + const installedExtensions = loadUserExtensions(); + if ( + installedExtensions.some( + (installed) => installed.config.name === newExtensionName, + ) + ) { + throw new Error( + `Extension "${newExtensionName}" is already installed. Please uninstall it first.`, + ); + } + + await fs.promises.mkdir(destinationPath, { recursive: true }); + + if (installMetadata.type === 'local' || installMetadata.type === 'git') { + await copyExtension(localSourcePath, destinationPath); + } + + const metadataString = JSON.stringify(installMetadata, null, 2); + const metadataPath = path.join( + destinationPath, + INSTALL_METADATA_FILENAME, + ); + await fs.promises.writeFile(metadataPath, metadataString); + } finally { + if (tempDir) { + await fs.promises.rm(tempDir, { recursive: true, force: true }); + } } - } - return newExtensionName; + logger?.logExtensionInstallEvent( + new ExtensionInstallEvent( + newExtensionConfig!.name, + newExtensionConfig!.version, + installMetadata.source, + 'success', + ), + ); + + return newExtensionConfig!.name; + } catch (error) { + // Attempt to load config from the source path even if installation fails + // to get the name and version for logging. + if (!newExtensionConfig && localSourcePath) { + newExtensionConfig = await loadExtensionConfig(localSourcePath); + } + logger?.logExtensionInstallEvent( + new ExtensionInstallEvent( + newExtensionConfig?.name ?? '', + newExtensionConfig?.version ?? 
'', + installMetadata.source, + 'error', + ), + ); + throw error; + } } -async function loadExtensionConfig( +export async function loadExtensionConfig( extensionDir: string, ): Promise { const configFilePath = path.join(extensionDir, EXTENSIONS_CONFIG_FILENAME); @@ -474,6 +541,9 @@ export function toOutputString(extension: Extension): string { output += `\n Path: ${extension.path}`; if (extension.installMetadata) { output += `\n Source: ${extension.installMetadata.source} (Type: ${extension.installMetadata.type})`; + if (extension.installMetadata.ref) { + output += `\n Ref: ${extension.installMetadata.ref}`; + } } if (extension.contextFiles.length > 0) { output += `\n Context files:`; @@ -613,3 +683,90 @@ export async function updateAllUpdatableExtensions( extensions.map((extension) => updateExtension(extension, cwd)), ); } + +export enum ExtensionUpdateStatus { + UpdateAvailable, + UpToDate, + Error, + NotUpdatable, +} + +export interface ExtensionUpdateCheckResult { + status: ExtensionUpdateStatus; + error?: string; +} + +export async function checkForExtensionUpdates( + extensions: Extension[], +): Promise> { + const results = new Map(); + + for (const extension of extensions) { + if (extension.installMetadata?.type !== 'git') { + results.set(extension.config.name, { + status: ExtensionUpdateStatus.NotUpdatable, + }); + continue; + } + + try { + const git = simpleGit(extension.path); + const remotes = await git.getRemotes(true); + if (remotes.length === 0) { + results.set(extension.config.name, { + status: ExtensionUpdateStatus.Error, + error: 'No git remotes found.', + }); + continue; + } + const remoteUrl = remotes[0].refs.fetch; + if (!remoteUrl) { + results.set(extension.config.name, { + status: ExtensionUpdateStatus.Error, + error: `No fetch URL found for git remote ${remotes[0].name}.`, + }); + continue; + } + + // Determine the ref to check on the remote. 
+ const refToCheck = extension.installMetadata.ref || 'HEAD'; + + const lsRemoteOutput = await git.listRemote([remoteUrl, refToCheck]); + + if (typeof lsRemoteOutput !== 'string' || lsRemoteOutput.trim() === '') { + results.set(extension.config.name, { + status: ExtensionUpdateStatus.Error, + error: `Git ref ${refToCheck} not found.`, + }); + continue; + } + + const remoteHash = lsRemoteOutput.split('\t')[0]; + const localHash = await git.revparse(['HEAD']); + + if (!remoteHash) { + results.set(extension.config.name, { + status: ExtensionUpdateStatus.Error, + error: `Unable to parse hash from git ls-remote output "${lsRemoteOutput}"`, + }); + } else if (remoteHash === localHash) { + results.set(extension.config.name, { + status: ExtensionUpdateStatus.UpToDate, + }); + } else { + results.set(extension.config.name, { + status: ExtensionUpdateStatus.UpdateAvailable, + }); + } + } catch (error) { + results.set(extension.config.name, { + status: ExtensionUpdateStatus.Error, + error: `Failed to check for updates for extension "${ + extension.config.name + }": ${getErrorMessage(error)}`, + }); + } + } + + return results; +} diff --git a/packages/cli/src/config/keyBindings.ts b/packages/cli/src/config/keyBindings.ts index 4b139d1a304..ba5ddac5874 100644 --- a/packages/cli/src/config/keyBindings.ts +++ b/packages/cli/src/config/keyBindings.ts @@ -56,6 +56,7 @@ export enum Command { REVERSE_SEARCH = 'reverseSearch', SUBMIT_REVERSE_SEARCH = 'submitReverseSearch', ACCEPT_SUGGESTION_REVERSE_SEARCH = 'acceptSuggestionReverseSearch', + TOGGLE_SHELL_INPUT_FOCUS = 'toggleShellInputFocus', } /** @@ -162,4 +163,5 @@ export const defaultKeyBindings: KeyBindingConfig = { // Note: original logic ONLY checked ctrl=false, ignored meta/shift/paste [Command.SUBMIT_REVERSE_SEARCH]: [{ key: 'return', ctrl: false }], [Command.ACCEPT_SUGGESTION_REVERSE_SEARCH]: [{ key: 'tab' }], + [Command.TOGGLE_SHELL_INPUT_FOCUS]: [{ key: 'f', ctrl: true }], }; diff --git 
a/packages/cli/src/config/settings.test.ts b/packages/cli/src/config/settings.test.ts index 558fadf778f..64d11be92d4 100644 --- a/packages/cli/src/config/settings.test.ts +++ b/packages/cli/src/config/settings.test.ts @@ -428,8 +428,11 @@ describe('Settings Loading and Merging', () => { '/workspace/dir', ]); - // Verify excludeTools are overwritten by workspace - expect(settings.merged.tools?.exclude).toEqual(['workspace-tool']); + // Verify excludeTools are concatenated and de-duped + expect(settings.merged.tools?.exclude).toEqual([ + 'user-tool', + 'workspace-tool', + ]); // Verify excludedProjectEnvVars are concatenated and de-duped expect(settings.merged.advanced?.excludedEnvVars).toEqual( @@ -1075,6 +1078,30 @@ describe('Settings Loading and Merging', () => { }); }); + it('should merge output format settings, with workspace taking precedence', () => { + (mockFsExistsSync as Mock).mockReturnValue(true); + const userSettingsContent = { + output: { format: 'text' }, + }; + const workspaceSettingsContent = { + output: { format: 'json' }, + }; + + (fs.readFileSync as Mock).mockImplementation( + (p: fs.PathOrFileDescriptor) => { + if (p === USER_SETTINGS_PATH) + return JSON.stringify(userSettingsContent); + if (p === MOCK_WORKSPACE_SETTINGS_PATH) + return JSON.stringify(workspaceSettingsContent); + return '{}'; + }, + ); + + const settings = loadSettings(MOCK_WORKSPACE_DIR); + + expect(settings.merged.output?.format).toBe('json'); + }); + it('should handle chatCompression when only in user settings', () => { (mockFsExistsSync as Mock).mockImplementation( (p: fs.PathLike) => p === USER_SETTINGS_PATH, diff --git a/packages/cli/src/config/settings.ts b/packages/cli/src/config/settings.ts index d4f9643c516..9037707a71a 100644 --- a/packages/cli/src/config/settings.ts +++ b/packages/cli/src/config/settings.ts @@ -22,17 +22,17 @@ import { isWorkspaceTrusted } from './trustedFolders.js'; import { type Settings, type MemoryImportFormat, - SETTINGS_SCHEMA, type 
MergeStrategy, type SettingsSchema, type SettingDefinition, + getSettingsSchema, } from './settingsSchema.js'; import { resolveEnvVarsInObject } from '../utils/envVarResolver.js'; import { customDeepMerge } from '../utils/deepMerge.js'; function getMergeStrategyForPath(path: string[]): MergeStrategy | undefined { let current: SettingDefinition | undefined = undefined; - let currentSchema: SettingsSchema | undefined = SETTINGS_SCHEMA; + let currentSchema: SettingsSchema | undefined = getSettingsSchema(); for (const key of path) { if (!currentSchema || !currentSchema[key]) { @@ -107,6 +107,8 @@ const MIGRATION_MAP: Record = { sandbox: 'tools.sandbox', selectedAuthType: 'security.auth.selectedType', shouldUseNodePtyShell: 'tools.usePty', + shellPager: 'tools.shell.pager', + shellShowColor: 'tools.shell.showColor', skipNextSpeakerCheck: 'model.skipNextSpeakerCheck', summarizeToolOutput: 'model.summarizeToolOutput', telemetry: 'telemetry', diff --git a/packages/cli/src/config/settingsSchema.test.ts b/packages/cli/src/config/settingsSchema.test.ts index e182e49ef89..47fc91c108d 100644 --- a/packages/cli/src/config/settingsSchema.test.ts +++ b/packages/cli/src/config/settingsSchema.test.ts @@ -5,13 +5,17 @@ */ import { describe, it, expect } from 'vitest'; -import type { Settings } from './settingsSchema.js'; -import { SETTINGS_SCHEMA } from './settingsSchema.js'; +import { + getSettingsSchema, + type SettingDefinition, + type Settings, + type SettingsSchema, +} from './settingsSchema.js'; describe('SettingsSchema', () => { - describe('SETTINGS_SCHEMA', () => { + describe('getSettingsSchema', () => { it('should contain all expected top-level settings', () => { - const expectedSettings = [ + const expectedSettings: Array = [ 'mcpServers', 'general', 'ui', @@ -27,14 +31,12 @@ describe('SettingsSchema', () => { ]; expectedSettings.forEach((setting) => { - expect( - SETTINGS_SCHEMA[setting as keyof typeof SETTINGS_SCHEMA], - ).toBeDefined(); + 
expect(getSettingsSchema()[setting as keyof Settings]).toBeDefined(); }); }); it('should have correct structure for each setting', () => { - Object.entries(SETTINGS_SCHEMA).forEach(([_key, definition]) => { + Object.entries(getSettingsSchema()).forEach(([_key, definition]) => { expect(definition).toHaveProperty('type'); expect(definition).toHaveProperty('label'); expect(definition).toHaveProperty('category'); @@ -48,7 +50,7 @@ describe('SettingsSchema', () => { }); it('should have correct nested setting structure', () => { - const nestedSettings = [ + const nestedSettings: Array = [ 'general', 'ui', 'ide', @@ -62,11 +64,9 @@ describe('SettingsSchema', () => { ]; nestedSettings.forEach((setting) => { - const definition = SETTINGS_SCHEMA[ - setting as keyof typeof SETTINGS_SCHEMA - ] as (typeof SETTINGS_SCHEMA)[keyof typeof SETTINGS_SCHEMA] & { - properties: unknown; - }; + const definition = getSettingsSchema()[ + setting as keyof Settings + ] as SettingDefinition; expect(definition.type).toBe('object'); expect(definition.properties).toBeDefined(); expect(typeof definition.properties).toBe('object'); @@ -75,35 +75,36 @@ describe('SettingsSchema', () => { it('should have accessibility nested properties', () => { expect( - SETTINGS_SCHEMA.ui?.properties?.accessibility?.properties, + getSettingsSchema().ui?.properties?.accessibility?.properties, ).toBeDefined(); expect( - SETTINGS_SCHEMA.ui?.properties?.accessibility.properties + getSettingsSchema().ui?.properties?.accessibility.properties ?.disableLoadingPhrases.type, ).toBe('boolean'); }); it('should have checkpointing nested properties', () => { expect( - SETTINGS_SCHEMA.general?.properties?.checkpointing.properties?.enabled, + getSettingsSchema().general?.properties?.checkpointing.properties + ?.enabled, ).toBeDefined(); expect( - SETTINGS_SCHEMA.general?.properties?.checkpointing.properties?.enabled - .type, + getSettingsSchema().general?.properties?.checkpointing.properties + ?.enabled.type, ).toBe('boolean'); 
}); it('should have fileFiltering nested properties', () => { expect( - SETTINGS_SCHEMA.context.properties.fileFiltering.properties + getSettingsSchema().context.properties.fileFiltering.properties ?.respectGitIgnore, ).toBeDefined(); expect( - SETTINGS_SCHEMA.context.properties.fileFiltering.properties + getSettingsSchema().context.properties.fileFiltering.properties ?.respectGeminiIgnore, ).toBeDefined(); expect( - SETTINGS_SCHEMA.context.properties.fileFiltering.properties + getSettingsSchema().context.properties.fileFiltering.properties ?.enableRecursiveFileSearch, ).toBeDefined(); }); @@ -112,7 +113,7 @@ describe('SettingsSchema', () => { const categories = new Set(); // Collect categories from top-level settings - Object.values(SETTINGS_SCHEMA).forEach((definition) => { + Object.values(getSettingsSchema()).forEach((definition) => { categories.add(definition.category); // Also collect from nested properties const defWithProps = definition as typeof definition & { @@ -137,74 +138,80 @@ describe('SettingsSchema', () => { }); it('should have consistent default values for boolean settings', () => { - const checkBooleanDefaults = (schema: Record) => { - Object.entries(schema).forEach( - ([_key, definition]: [string, unknown]) => { - const def = definition as { - type?: string; - default?: unknown; - properties?: Record; - }; - if (def.type === 'boolean') { - // Boolean settings can have boolean or undefined defaults (for optional settings) - expect(['boolean', 'undefined']).toContain(typeof def.default); - } - if (def.properties) { - checkBooleanDefaults(def.properties); - } - }, - ); + const checkBooleanDefaults = (schema: SettingsSchema) => { + Object.entries(schema).forEach(([, definition]) => { + const def = definition as SettingDefinition; + if (def.type === 'boolean') { + // Boolean settings can have boolean or undefined defaults (for optional settings) + expect(['boolean', 'undefined']).toContain(typeof def.default); + } + if (def.properties) { + 
checkBooleanDefaults(def.properties); + } + }); }; - checkBooleanDefaults(SETTINGS_SCHEMA as Record); + checkBooleanDefaults(getSettingsSchema() as SettingsSchema); }); it('should have showInDialog property configured', () => { // Check that user-facing settings are marked for dialog display - expect(SETTINGS_SCHEMA.ui.properties.showMemoryUsage.showInDialog).toBe( + expect( + getSettingsSchema().ui.properties.showMemoryUsage.showInDialog, + ).toBe(true); + expect(getSettingsSchema().general.properties.vimMode.showInDialog).toBe( true, ); - expect(SETTINGS_SCHEMA.general.properties.vimMode.showInDialog).toBe( + expect(getSettingsSchema().ide.properties.enabled.showInDialog).toBe( true, ); - expect(SETTINGS_SCHEMA.ide.properties.enabled.showInDialog).toBe(true); expect( - SETTINGS_SCHEMA.general.properties.disableAutoUpdate.showInDialog, + getSettingsSchema().general.properties.disableAutoUpdate.showInDialog, ).toBe(true); - expect(SETTINGS_SCHEMA.ui.properties.hideWindowTitle.showInDialog).toBe( + expect( + getSettingsSchema().ui.properties.hideWindowTitle.showInDialog, + ).toBe(true); + expect(getSettingsSchema().ui.properties.hideTips.showInDialog).toBe( + true, + ); + expect(getSettingsSchema().ui.properties.hideBanner.showInDialog).toBe( true, ); - expect(SETTINGS_SCHEMA.ui.properties.hideTips.showInDialog).toBe(true); - expect(SETTINGS_SCHEMA.ui.properties.hideBanner.showInDialog).toBe(true); expect( - SETTINGS_SCHEMA.privacy.properties.usageStatisticsEnabled.showInDialog, + getSettingsSchema().privacy.properties.usageStatisticsEnabled + .showInDialog, ).toBe(false); // Check that advanced settings are hidden from dialog - expect(SETTINGS_SCHEMA.security.properties.auth.showInDialog).toBe(false); - expect(SETTINGS_SCHEMA.tools.properties.core.showInDialog).toBe(false); - expect(SETTINGS_SCHEMA.mcpServers.showInDialog).toBe(false); - expect(SETTINGS_SCHEMA.telemetry.showInDialog).toBe(false); + 
expect(getSettingsSchema().security.properties.auth.showInDialog).toBe( + false, + ); + expect(getSettingsSchema().tools.properties.core.showInDialog).toBe( + false, + ); + expect(getSettingsSchema().mcpServers.showInDialog).toBe(false); + expect(getSettingsSchema().telemetry.showInDialog).toBe(false); // Check that some settings are appropriately hidden - expect(SETTINGS_SCHEMA.ui.properties.theme.showInDialog).toBe(false); // Changed to false - expect(SETTINGS_SCHEMA.ui.properties.customThemes.showInDialog).toBe( + expect(getSettingsSchema().ui.properties.theme.showInDialog).toBe(false); // Changed to false + expect(getSettingsSchema().ui.properties.customThemes.showInDialog).toBe( false, ); // Managed via theme editor expect( - SETTINGS_SCHEMA.general.properties.checkpointing.showInDialog, + getSettingsSchema().general.properties.checkpointing.showInDialog, ).toBe(false); // Experimental feature - expect(SETTINGS_SCHEMA.ui.properties.accessibility.showInDialog).toBe( + expect(getSettingsSchema().ui.properties.accessibility.showInDialog).toBe( false, ); // Changed to false expect( - SETTINGS_SCHEMA.context.properties.fileFiltering.showInDialog, + getSettingsSchema().context.properties.fileFiltering.showInDialog, ).toBe(false); // Changed to false expect( - SETTINGS_SCHEMA.general.properties.preferredEditor.showInDialog, + getSettingsSchema().general.properties.preferredEditor.showInDialog, ).toBe(false); // Changed to false expect( - SETTINGS_SCHEMA.advanced.properties.autoConfigureMemory.showInDialog, + getSettingsSchema().advanced.properties.autoConfigureMemory + .showInDialog, ).toBe(false); }); @@ -228,80 +235,84 @@ describe('SettingsSchema', () => { it('should have includeDirectories setting in schema', () => { expect( - SETTINGS_SCHEMA.context?.properties.includeDirectories, + getSettingsSchema().context?.properties.includeDirectories, ).toBeDefined(); - expect(SETTINGS_SCHEMA.context?.properties.includeDirectories.type).toBe( - 'array', - ); expect( - 
SETTINGS_SCHEMA.context?.properties.includeDirectories.category, + getSettingsSchema().context?.properties.includeDirectories.type, + ).toBe('array'); + expect( + getSettingsSchema().context?.properties.includeDirectories.category, ).toBe('Context'); expect( - SETTINGS_SCHEMA.context?.properties.includeDirectories.default, + getSettingsSchema().context?.properties.includeDirectories.default, ).toEqual([]); }); it('should have loadMemoryFromIncludeDirectories setting in schema', () => { expect( - SETTINGS_SCHEMA.context?.properties.loadMemoryFromIncludeDirectories, + getSettingsSchema().context?.properties + .loadMemoryFromIncludeDirectories, ).toBeDefined(); expect( - SETTINGS_SCHEMA.context?.properties.loadMemoryFromIncludeDirectories + getSettingsSchema().context?.properties.loadMemoryFromIncludeDirectories .type, ).toBe('boolean'); expect( - SETTINGS_SCHEMA.context?.properties.loadMemoryFromIncludeDirectories + getSettingsSchema().context?.properties.loadMemoryFromIncludeDirectories .category, ).toBe('Context'); expect( - SETTINGS_SCHEMA.context?.properties.loadMemoryFromIncludeDirectories + getSettingsSchema().context?.properties.loadMemoryFromIncludeDirectories .default, ).toBe(false); }); it('should have folderTrustFeature setting in schema', () => { expect( - SETTINGS_SCHEMA.security.properties.folderTrust.properties.enabled, + getSettingsSchema().security.properties.folderTrust.properties.enabled, ).toBeDefined(); expect( - SETTINGS_SCHEMA.security.properties.folderTrust.properties.enabled.type, + getSettingsSchema().security.properties.folderTrust.properties.enabled + .type, ).toBe('boolean'); expect( - SETTINGS_SCHEMA.security.properties.folderTrust.properties.enabled + getSettingsSchema().security.properties.folderTrust.properties.enabled .category, ).toBe('Security'); expect( - SETTINGS_SCHEMA.security.properties.folderTrust.properties.enabled + getSettingsSchema().security.properties.folderTrust.properties.enabled .default, ).toBe(false); expect( - 
SETTINGS_SCHEMA.security.properties.folderTrust.properties.enabled + getSettingsSchema().security.properties.folderTrust.properties.enabled .showInDialog, ).toBe(true); }); it('should have debugKeystrokeLogging setting in schema', () => { expect( - SETTINGS_SCHEMA.general.properties.debugKeystrokeLogging, + getSettingsSchema().general.properties.debugKeystrokeLogging, ).toBeDefined(); expect( - SETTINGS_SCHEMA.general.properties.debugKeystrokeLogging.type, + getSettingsSchema().general.properties.debugKeystrokeLogging.type, ).toBe('boolean'); expect( - SETTINGS_SCHEMA.general.properties.debugKeystrokeLogging.category, + getSettingsSchema().general.properties.debugKeystrokeLogging.category, ).toBe('General'); expect( - SETTINGS_SCHEMA.general.properties.debugKeystrokeLogging.default, + getSettingsSchema().general.properties.debugKeystrokeLogging.default, ).toBe(false); expect( - SETTINGS_SCHEMA.general.properties.debugKeystrokeLogging + getSettingsSchema().general.properties.debugKeystrokeLogging .requiresRestart, ).toBe(false); expect( - SETTINGS_SCHEMA.general.properties.debugKeystrokeLogging.showInDialog, + getSettingsSchema().general.properties.debugKeystrokeLogging + .showInDialog, ).toBe(true); expect( - SETTINGS_SCHEMA.general.properties.debugKeystrokeLogging.description, + getSettingsSchema().general.properties.debugKeystrokeLogging + .description, ).toBe('Enable debug logging of keystrokes to the console.'); }); }); diff --git a/packages/cli/src/config/settingsSchema.ts b/packages/cli/src/config/settingsSchema.ts index 6ad4c142f2e..45b11a59be6 100644 --- a/packages/cli/src/config/settingsSchema.ts +++ b/packages/cli/src/config/settingsSchema.ts @@ -17,6 +17,37 @@ import { } from '@blocksuser/gemini-cli-core'; import type { CustomTheme } from '../ui/themes/theme.js'; +export type SettingsType = + | 'boolean' + | 'string' + | 'number' + | 'array' + | 'object' + | 'enum'; + +export type SettingsValue = + | boolean + | string + | number + | string[] + | object 
+ | undefined; + +/** + * Setting datatypes that "toggle" through a fixed list of options + * (e.g. an enum or true/false) rather than allowing for free form input + * (like a number or string). + */ +export const TOGGLE_TYPES: ReadonlySet = new Set([ + 'boolean', + 'enum', +]); + +export interface SettingEnumOption { + value: string | number; + label: string; +} + export enum MergeStrategy { // Replace the old value with the new value. This is the default. REPLACE = 'replace', @@ -29,11 +60,11 @@ export enum MergeStrategy { } export interface SettingDefinition { - type: 'boolean' | 'string' | 'number' | 'array' | 'object'; + type: SettingsType; label: string; category: string; requiresRestart: boolean; - default: boolean | string | number | string[] | object | undefined; + default: SettingsValue; description?: string; parentKey?: string; childKey?: string; @@ -41,6 +72,8 @@ export interface SettingDefinition { properties?: SettingsSchema; showInDialog?: boolean; mergeStrategy?: MergeStrategy; + /** Enum type options */ + options?: readonly SettingEnumOption[]; } export interface SettingsSchema { @@ -55,7 +88,7 @@ export type DnsResolutionOrder = 'ipv4first' | 'verbatim'; * The structure of this object defines the structure of the `Settings` type. * `as const` is crucial for TypeScript to infer the most specific types possible. 
*/ -export const SETTINGS_SCHEMA = { +const SETTINGS_SCHEMA = { // Maintained for compatibility/criticality mcpServers: { type: 'object', @@ -196,6 +229,30 @@ export const SETTINGS_SCHEMA = { }, }, }, + output: { + type: 'object', + label: 'Output', + category: 'General', + requiresRestart: false, + default: {}, + description: 'Settings for the CLI output.', + showInDialog: false, + properties: { + format: { + type: 'enum', + label: 'Output Format', + category: 'General', + requiresRestart: false, + default: 'text', + description: 'The format of the CLI output.', + showInDialog: true, + options: [ + { value: 'text', label: 'Text' }, + { value: 'json', label: 'JSON' }, + ], + }, + }, + }, ui: { type: 'object', @@ -634,6 +691,36 @@ export const SETTINGS_SCHEMA = { 'Use node-pty for shell command execution. Fallback to child_process still applies.', showInDialog: true, }, + shell: { + type: 'object', + label: 'Shell', + category: 'Tools', + requiresRestart: false, + default: {}, + description: 'Settings for shell execution.', + showInDialog: false, + properties: { + pager: { + type: 'string', + label: 'Pager', + category: 'Tools', + requiresRestart: false, + default: 'cat' as string | undefined, + description: + 'The pager command to use for shell output. 
Defaults to `cat`.', + showInDialog: false, + }, + showColor: { + type: 'boolean', + label: 'Show Color', + category: 'Tools', + requiresRestart: false, + default: false, + description: 'Show color in shell output.', + showInDialog: true, + }, + }, + }, autoAccept: { type: 'boolean', label: 'Auto Accept', @@ -671,6 +758,7 @@ export const SETTINGS_SCHEMA = { default: undefined as string[] | undefined, description: 'Tool names to exclude from discovery.', showInDialog: false, + mergeStrategy: MergeStrategy.UNION, }, discoveryCommand: { type: 'string', @@ -695,16 +783,25 @@ export const SETTINGS_SCHEMA = { label: 'Use Ripgrep', category: 'Tools', requiresRestart: false, - default: false, + default: true, description: 'Use ripgrep for file content search instead of the fallback implementation. Provides faster search performance.', showInDialog: true, }, + enableToolOutputTruncation: { + type: 'boolean', + label: 'Enable Tool Output Truncation', + category: 'General', + requiresRestart: true, + default: false, + description: 'Enable truncation of large tool outputs.', + showInDialog: true, + }, truncateToolOutputThreshold: { type: 'number', label: 'Tool Output Truncation Threshold', category: 'General', - requiresRestart: false, + requiresRestart: true, default: DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, description: 'Truncate tool output if it is larger than this many characters. 
Set to -1 to disable.', @@ -714,7 +811,7 @@ export const SETTINGS_SCHEMA = { type: 'number', label: 'Tool Output Truncation Lines', category: 'General', - requiresRestart: false, + requiresRestart: true, default: DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES, description: 'The number of lines to keep when truncating tool output.', showInDialog: true, @@ -942,7 +1039,13 @@ export const SETTINGS_SCHEMA = { }, }, }, -} as const; +} as const satisfies SettingsSchema; + +export type SettingsSchemaType = typeof SETTINGS_SCHEMA; + +export function getSettingsSchema(): SettingsSchemaType { + return SETTINGS_SCHEMA; +} type InferSettings = { -readonly [K in keyof T]?: T[K] extends { properties: SettingsSchema } @@ -952,7 +1055,7 @@ type InferSettings = { : T[K]['default']; }; -export type Settings = InferSettings; +export type Settings = InferSettings; export interface FooterSettings { hideCWD?: boolean; diff --git a/packages/cli/src/config/trustedFolders.test.ts b/packages/cli/src/config/trustedFolders.test.ts index 7bb418c3dd9..b8dd47f6e11 100644 --- a/packages/cli/src/config/trustedFolders.test.ts +++ b/packages/cli/src/config/trustedFolders.test.ts @@ -4,17 +4,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -// Mock 'os' first. 
import * as osActual from 'node:os'; -vi.mock('os', async (importOriginal) => { - const actualOs = await importOriginal(); - return { - ...actualOs, - homedir: vi.fn(() => '/mock/home/user'), - platform: vi.fn(() => 'linux'), - }; -}); - +import { ideContextStore } from '@blocksuser/gemini-cli-core'; import { describe, it, @@ -28,7 +19,6 @@ import { import * as fs from 'node:fs'; import stripJsonComments from 'strip-json-comments'; import * as path from 'node:path'; - import { loadTrustedFolders, USER_TRUSTED_FOLDERS_PATH, @@ -37,6 +27,14 @@ import { } from './trustedFolders.js'; import type { Settings } from './settings.js'; +vi.mock('os', async (importOriginal) => { + const actualOs = await importOriginal(); + return { + ...actualOs, + homedir: vi.fn(() => '/mock/home/user'), + platform: vi.fn(() => 'linux'), + }; +}); vi.mock('fs', async (importOriginal) => { const actualFs = await importOriginal(); return { @@ -47,7 +45,6 @@ vi.mock('fs', async (importOriginal) => { mkdirSync: vi.fn(), }; }); - vi.mock('strip-json-comments', () => ({ default: vi.fn((content) => content), })); @@ -256,8 +253,6 @@ describe('isWorkspaceTrusted', () => { }); }); -import { getIdeTrust } from '@blocksuser/gemini-cli-core'; - vi.mock('@blocksuser/gemini-cli-core', async (importOriginal) => { const actual = await importOriginal>(); return { @@ -267,6 +262,10 @@ vi.mock('@blocksuser/gemini-cli-core', async (importOriginal) => { }); describe('isWorkspaceTrusted with IDE override', () => { + afterEach(() => { + ideContextStore.clear(); + }); + const mockSettings: Settings = { security: { folderTrust: { @@ -276,7 +275,7 @@ describe('isWorkspaceTrusted with IDE override', () => { }; it('should return true when ideTrust is true, ignoring config', () => { - vi.mocked(getIdeTrust).mockReturnValue(true); + ideContextStore.set({ workspaceState: { isTrusted: true } }); // Even if config says don't trust, ideTrust should win. 
vi.spyOn(fs, 'readFileSync').mockReturnValue( JSON.stringify({ [process.cwd()]: TrustLevel.DO_NOT_TRUST }), @@ -285,7 +284,7 @@ describe('isWorkspaceTrusted with IDE override', () => { }); it('should return false when ideTrust is false, ignoring config', () => { - vi.mocked(getIdeTrust).mockReturnValue(false); + ideContextStore.set({ workspaceState: { isTrusted: false } }); // Even if config says trust, ideTrust should win. vi.spyOn(fs, 'readFileSync').mockReturnValue( JSON.stringify({ [process.cwd()]: TrustLevel.TRUST_FOLDER }), @@ -294,7 +293,6 @@ describe('isWorkspaceTrusted with IDE override', () => { }); it('should fall back to config when ideTrust is undefined', () => { - vi.mocked(getIdeTrust).mockReturnValue(undefined); vi.spyOn(fs, 'existsSync').mockReturnValue(true); vi.spyOn(fs, 'readFileSync').mockReturnValue( JSON.stringify({ [process.cwd()]: TrustLevel.TRUST_FOLDER }), @@ -310,7 +308,7 @@ describe('isWorkspaceTrusted with IDE override', () => { }, }, }; - vi.mocked(getIdeTrust).mockReturnValue(false); + ideContextStore.set({ workspaceState: { isTrusted: false } }); expect(isWorkspaceTrusted(settings)).toBe(true); }); }); diff --git a/packages/cli/src/config/trustedFolders.ts b/packages/cli/src/config/trustedFolders.ts index 870b0b06aee..9554f9788cd 100644 --- a/packages/cli/src/config/trustedFolders.ts +++ b/packages/cli/src/config/trustedFolders.ts @@ -10,7 +10,7 @@ import { homedir } from 'node:os'; import { getErrorMessage, isWithinRoot, - getIdeTrust, + ideContextStore, } from '@blocksuser/gemini-cli-core'; import type { Settings } from './settings.js'; import stripJsonComments from 'strip-json-comments'; @@ -182,7 +182,7 @@ export function isWorkspaceTrusted(settings: Settings): boolean | undefined { return true; } - const ideTrust = getIdeTrust(); + const ideTrust = ideContextStore.get()?.workspaceState?.isTrusted; if (ideTrust !== undefined) { return ideTrust; } diff --git a/packages/cli/src/core/initializer.ts 
b/packages/cli/src/core/initializer.ts index 42911396108..f35487f2cc0 100644 --- a/packages/cli/src/core/initializer.ts +++ b/packages/cli/src/core/initializer.ts @@ -4,7 +4,13 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { type Config } from '@blocksuser/gemini-cli-core'; +import { + IdeClient, + IdeConnectionEvent, + IdeConnectionType, + logIdeConnection, + type Config, +} from '@blocksuser/gemini-cli-core'; import { type LoadedSettings } from '../config/settings.js'; import { performInitialAuth } from './auth.js'; import { validateTheme } from './theme.js'; @@ -36,6 +42,12 @@ export async function initializeApp( const shouldOpenAuthDialog = settings.merged.security?.auth?.selectedType === undefined || !!authError; + if (config.getIdeMode()) { + const ideClient = await IdeClient.getInstance(); + await ideClient.connect(); + logIdeConnection(config, new IdeConnectionEvent(IdeConnectionType.START)); + } + return { authError, themeError, diff --git a/packages/cli/src/gemini.test.tsx b/packages/cli/src/gemini.test.tsx index 592adb2521f..3865f11d3fe 100644 --- a/packages/cli/src/gemini.test.tsx +++ b/packages/cli/src/gemini.test.tsx @@ -235,6 +235,7 @@ describe('gemini.tsx main function kitty protocol', () => { useSmartEdit: undefined, sessionSummary: undefined, promptWords: undefined, + outputFormat: undefined, resume: undefined, listSessions: undefined, deleteSession: undefined, diff --git a/packages/cli/src/gemini.tsx b/packages/cli/src/gemini.tsx index 6c53ab53ec8..c39bc5c3ac1 100644 --- a/packages/cli/src/gemini.tsx +++ b/packages/cli/src/gemini.tsx @@ -4,9 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import React, { useState, useEffect } from 'react'; -import { render, Box, Text } from 'ink'; -import Spinner from 'ink-spinner'; +import { render } from 'ink'; import { AppContainer } from './ui/AppContainer.js'; import { loadCliConfig, parseArguments } from './config/config.js'; import { readStdin } from './utils/readStdin.js'; @@ -121,38 +119,6 @@ 
async function relaunchWithAdditionalArgs(additionalArgs: string[]) { process.exit(0); } -const InitializingComponent = ({ initialTotal }: { initialTotal: number }) => { - const [total, setTotal] = useState(initialTotal); - const [connected, setConnected] = useState(0); - - useEffect(() => { - const onStart = ({ count }: { count: number }) => setTotal(count); - const onChange = () => { - setConnected((val) => val + 1); - }; - - appEvents.on('mcp-servers-discovery-start', onStart); - appEvents.on('mcp-server-connected', onChange); - appEvents.on('mcp-server-error', onChange); - - return () => { - appEvents.off('mcp-servers-discovery-start', onStart); - appEvents.off('mcp-server-connected', onChange); - appEvents.off('mcp-server-error', onChange); - }; - }, []); - - const message = `Connecting to MCP servers... (${connected}/${total})`; - - return ( - - - {message} - - - ); -}; - import { runZedIntegration } from './zed-integration/zedIntegration.js'; export function setupUnhandledRejectionHandler() { @@ -213,15 +179,10 @@ export async function startInteractiveUI( ); }; - const instance = render( - - - , - { - exitOnCtrlC: false, - isScreenReaderEnabled: config.getScreenReader(), - }, - ); + const instance = render(, { + exitOnCtrlC: false, + isScreenReaderEnabled: config.getScreenReader(), + }); checkForUpdates() .then((info) => { @@ -380,25 +341,6 @@ export async function main() { setMaxSizedBoxDebugging(config.getDebugMode()); - const mcpServers = config.getMcpServers(); - const mcpServersCount = mcpServers ? Object.keys(mcpServers).length : 0; - - let spinnerInstance; - if (config.isInteractive() && mcpServersCount > 0) { - spinnerInstance = render( - , - ); - } - - await config.initialize(); - - if (spinnerInstance) { - // Small UX detail to show the completion message for a bit before unmounting. 
- await new Promise((f) => setTimeout(f, 100)); - spinnerInstance.clear(); - spinnerInstance.unmount(); - } - // Load custom themes from settings themeManager.loadCustomThemes(settings.merged.ui?.customThemes); @@ -511,6 +453,9 @@ export async function main() { ); return; } + + await config.initialize(); + // If not a TTY, read from stdin // This is for cases where the user pipes input directly into the command if (!process.stdin.isTTY) { diff --git a/packages/cli/src/nonInteractiveCli.test.ts b/packages/cli/src/nonInteractiveCli.test.ts index cb8dd48a60c..1e91f532a40 100644 --- a/packages/cli/src/nonInteractiveCli.test.ts +++ b/packages/cli/src/nonInteractiveCli.test.ts @@ -8,12 +8,16 @@ import type { Config, ToolRegistry, ServerGeminiStreamEvent, + SessionMetrics, } from '@blocksuser/gemini-cli-core'; import { executeToolCall, ToolErrorType, shutdownTelemetry, GeminiEventType, + OutputFormat, + uiTelemetryService, + FatalInputError, } from '@blocksuser/gemini-cli-core'; import type { Part } from '@google/genai'; import { runNonInteractive } from './nonInteractiveCli.js'; @@ -38,6 +42,9 @@ vi.mock('@blocksuser/gemini-cli-core', async (importOriginal) => { shutdownTelemetry: vi.fn(), isTelemetrySdkInitialized: vi.fn().mockReturnValue(true), ChatRecordingService: MockChatRecordingService, + uiTelemetryService: { + getMetrics: vi.fn(), + }, }; }); @@ -61,6 +68,9 @@ describe('runNonInteractive', () => { processStdoutSpy = vi .spyOn(process.stdout, 'write') .mockImplementation(() => true); + vi.spyOn(process, 'exit').mockImplementation((code) => { + throw new Error(`process.exit(${code}) called`); + }); mockToolRegistry = { getTool: vi.fn(), @@ -91,6 +101,7 @@ describe('runNonInteractive', () => { getFullContext: vi.fn().mockReturnValue(false), getContentGeneratorConfig: vi.fn().mockReturnValue({}), getDebugMode: vi.fn().mockReturnValue(false), + getOutputFormat: vi.fn().mockReturnValue('text'), } as unknown as Config; const { handleAtCommand } = await import( @@ 
-312,9 +323,7 @@ describe('runNonInteractive', () => { vi.mocked(mockConfig.getMaxSessionTurns).mockReturnValue(0); await expect( runNonInteractive(mockConfig, 'Trigger loop', 'prompt-id-6'), - ).rejects.toThrow( - 'Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.', - ); + ).rejects.toThrow('process.exit(53) called'); }); it('should preprocess @include commands before sending to the model', async () => { @@ -364,4 +373,274 @@ describe('runNonInteractive', () => { // 6. Assert the final output is correct expect(processStdoutSpy).toHaveBeenCalledWith('Summary complete.'); }); + + it('should process input and write JSON output with stats', async () => { + const events: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Hello World' }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 10 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + vi.mocked(mockConfig.getOutputFormat).mockReturnValue(OutputFormat.JSON); + const mockMetrics: SessionMetrics = { + models: {}, + tools: { + totalCalls: 0, + totalSuccess: 0, + totalFail: 0, + totalDurationMs: 0, + totalDecisions: { + accept: 0, + reject: 0, + modify: 0, + auto_accept: 0, + }, + byName: {}, + }, + files: { + totalLinesAdded: 0, + totalLinesRemoved: 0, + }, + }; + vi.mocked(uiTelemetryService.getMetrics).mockReturnValue(mockMetrics); + + await runNonInteractive(mockConfig, 'Test input', 'prompt-id-1'); + + expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledWith( + [{ text: 'Test input' }], + expect.any(AbortSignal), + 'prompt-id-1', + ); + expect(processStdoutSpy).toHaveBeenCalledWith( + JSON.stringify({ response: 'Hello World', stats: mockMetrics }, null, 2), + ); + }); + + it('should write JSON output with stats for tool-only commands (no text response)', async () => { + // Test the scenario where a command 
completes successfully with only tool calls + // but no text response - this would have caught the original bug + const toolCallEvent: ServerGeminiStreamEvent = { + type: GeminiEventType.ToolCallRequest, + value: { + callId: 'tool-1', + name: 'testTool', + args: { arg1: 'value1' }, + isClientInitiated: false, + prompt_id: 'prompt-id-tool-only', + }, + }; + const toolResponse: Part[] = [{ text: 'Tool executed successfully' }]; + mockCoreExecuteToolCall.mockResolvedValue({ responseParts: toolResponse }); + + // First call returns only tool call, no content + const firstCallEvents: ServerGeminiStreamEvent[] = [ + toolCallEvent, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 5 } }, + }, + ]; + + // Second call returns no content (tool-only completion) + const secondCallEvents: ServerGeminiStreamEvent[] = [ + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 3 } }, + }, + ]; + + mockGeminiClient.sendMessageStream + .mockReturnValueOnce(createStreamFromEvents(firstCallEvents)) + .mockReturnValueOnce(createStreamFromEvents(secondCallEvents)); + + vi.mocked(mockConfig.getOutputFormat).mockReturnValue(OutputFormat.JSON); + const mockMetrics: SessionMetrics = { + models: {}, + tools: { + totalCalls: 1, + totalSuccess: 1, + totalFail: 0, + totalDurationMs: 100, + totalDecisions: { + accept: 1, + reject: 0, + modify: 0, + auto_accept: 0, + }, + byName: { + testTool: { + count: 1, + success: 1, + fail: 0, + durationMs: 100, + decisions: { + accept: 1, + reject: 0, + modify: 0, + auto_accept: 0, + }, + }, + }, + }, + files: { + totalLinesAdded: 0, + totalLinesRemoved: 0, + }, + }; + vi.mocked(uiTelemetryService.getMetrics).mockReturnValue(mockMetrics); + + await runNonInteractive( + mockConfig, + 'Execute tool only', + 'prompt-id-tool-only', + ); + + expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(2); + expect(mockCoreExecuteToolCall).toHaveBeenCalledWith( + 
mockConfig, + expect.objectContaining({ name: 'testTool' }), + expect.any(AbortSignal), + ); + + // This should output JSON with empty response but include stats + expect(processStdoutSpy).toHaveBeenCalledWith( + JSON.stringify({ response: '', stats: mockMetrics }, null, 2), + ); + }); + + it('should write JSON output with stats for empty response commands', async () => { + // Test the scenario where a command completes but produces no content at all + const events: ServerGeminiStreamEvent[] = [ + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 1 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + vi.mocked(mockConfig.getOutputFormat).mockReturnValue(OutputFormat.JSON); + const mockMetrics: SessionMetrics = { + models: {}, + tools: { + totalCalls: 0, + totalSuccess: 0, + totalFail: 0, + totalDurationMs: 0, + totalDecisions: { + accept: 0, + reject: 0, + modify: 0, + auto_accept: 0, + }, + byName: {}, + }, + files: { + totalLinesAdded: 0, + totalLinesRemoved: 0, + }, + }; + vi.mocked(uiTelemetryService.getMetrics).mockReturnValue(mockMetrics); + + await runNonInteractive( + mockConfig, + 'Empty response test', + 'prompt-id-empty', + ); + + expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledWith( + [{ text: 'Empty response test' }], + expect.any(AbortSignal), + 'prompt-id-empty', + ); + + // This should output JSON with empty response but include stats + expect(processStdoutSpy).toHaveBeenCalledWith( + JSON.stringify({ response: '', stats: mockMetrics }, null, 2), + ); + }); + + it('should handle errors in JSON format', async () => { + vi.mocked(mockConfig.getOutputFormat).mockReturnValue(OutputFormat.JSON); + const testError = new Error('Invalid input provided'); + + mockGeminiClient.sendMessageStream.mockImplementation(() => { + throw testError; + }); + + // Mock console.error to capture JSON error output + const consoleErrorJsonSpy = vi + 
.spyOn(console, 'error') + .mockImplementation(() => {}); + + let thrownError: Error | null = null; + try { + await runNonInteractive(mockConfig, 'Test input', 'prompt-id-error'); + // Should not reach here + expect.fail('Expected process.exit to be called'); + } catch (error) { + thrownError = error as Error; + } + + // Should throw because of mocked process.exit + expect(thrownError?.message).toBe('process.exit(1) called'); + + expect(consoleErrorJsonSpy).toHaveBeenCalledWith( + JSON.stringify( + { + error: { + type: 'Error', + message: 'Invalid input provided', + code: 1, + }, + }, + null, + 2, + ), + ); + }); + + it('should handle FatalInputError with custom exit code in JSON format', async () => { + vi.mocked(mockConfig.getOutputFormat).mockReturnValue(OutputFormat.JSON); + const fatalError = new FatalInputError('Invalid command syntax provided'); + + mockGeminiClient.sendMessageStream.mockImplementation(() => { + throw fatalError; + }); + + // Mock console.error to capture JSON error output + const consoleErrorJsonSpy = vi + .spyOn(console, 'error') + .mockImplementation(() => {}); + + let thrownError: Error | null = null; + try { + await runNonInteractive(mockConfig, 'Invalid syntax', 'prompt-id-fatal'); + // Should not reach here + expect.fail('Expected process.exit to be called'); + } catch (error) { + thrownError = error as Error; + } + + // Should throw because of mocked process.exit with custom exit code + expect(thrownError?.message).toBe('process.exit(42) called'); + + expect(consoleErrorJsonSpy).toHaveBeenCalledWith( + JSON.stringify( + { + error: { + type: 'FatalInputError', + message: 'Invalid command syntax provided', + code: 42, + }, + }, + null, + 2, + ), + ); + }); }); diff --git a/packages/cli/src/nonInteractiveCli.ts b/packages/cli/src/nonInteractiveCli.ts index a483b3bdebb..0edcbe7c1f5 100644 --- a/packages/cli/src/nonInteractiveCli.ts +++ b/packages/cli/src/nonInteractiveCli.ts @@ -10,14 +10,24 @@ import { shutdownTelemetry, 
isTelemetrySdkInitialized, GeminiEventType, - parseAndFormatApiError, FatalInputError, - FatalTurnLimitedError, + promptIdContext, + OutputFormat, + JsonFormatter, + StreamJsonFormatter, + uiTelemetryService, + streamingTelemetryService, } from '@blocksuser/gemini-cli-core'; import type { Content, Part } from '@google/genai'; import { ConsolePatcher } from './ui/utils/ConsolePatcher.js'; import { handleAtCommand } from './ui/hooks/atCommandProcessor.js'; +import { + handleError, + handleToolError, + handleCancellationError, + handleMaxTurnsExceededError, +} from './utils/errors.js'; export async function runNonInteractive( config: Config, @@ -25,138 +35,173 @@ export async function runNonInteractive( prompt_id: string, resumedSessionData?: ResumedSessionData, ): Promise { - const consolePatcher = new ConsolePatcher({ - stderr: true, - debugMode: config.getDebugMode(), - }); - - try { - consolePatcher.patch(); - // Handle EPIPE errors when the output is piped to a command that closes early. - process.stdout.on('error', (err: NodeJS.ErrnoException) => { - if (err.code === 'EPIPE') { - // Exit gracefully if the pipe is closed. - process.exit(0); - } + return promptIdContext.run(prompt_id, async () => { + const consolePatcher = new ConsolePatcher({ + stderr: true, + debugMode: config.getDebugMode(), }); - const geminiClient = config.getGeminiClient(); - - // Initialize chat recording service and handle resumed session - if (resumedSessionData) { - const chatRecordingService = geminiClient.getChatRecordingService(); - if (chatRecordingService) { - chatRecordingService.initialize(resumedSessionData); - - // Convert resumed session messages to chat history - const geminiChat = await geminiClient.getChat(); - if (geminiChat && resumedSessionData.conversation.messages.length > 0) { - // Load the conversation history into the chat - const historyContent: Content[] = resumedSessionData.conversation.messages.map(msg => ({ - role: msg.type === 'user' ? 
'user' : 'model' as const, - parts: Array.isArray(msg.content) - ? msg.content.map(part => typeof part === 'string' ? { text: part } : part) - : [{ text: typeof msg.content === 'string' ? msg.content : JSON.stringify(msg.content) }] - })); + try { + consolePatcher.patch(); + + // Set up streaming telemetry for stream-json format + const isStreamJsonFormat = config.getOutputFormat() === OutputFormat.STREAM_JSON; + let streamJsonFormatter: StreamJsonFormatter | undefined; + + if (isStreamJsonFormat) { + streamJsonFormatter = new StreamJsonFormatter(); + streamingTelemetryService.enable(); + streamingTelemetryService.addTelemetryListener((event) => { + process.stdout.write(streamJsonFormatter!.formatTelemetryBlock(event) + '\n'); + }); + } + // Handle EPIPE errors when the output is piped to a command that closes early. + process.stdout.on('error', (err: NodeJS.ErrnoException) => { + if (err.code === 'EPIPE') { + // Exit gracefully if the pipe is closed. + process.exit(0); + } + }); + + const geminiClient = config.getGeminiClient(); + + // Initialize chat recording service and handle resumed session + if (resumedSessionData) { + const chatRecordingService = geminiClient.getChatRecordingService(); + if (chatRecordingService) { + chatRecordingService.initialize(resumedSessionData); - // Set the chat history - geminiChat.setHistory(historyContent); + // Convert resumed session messages to chat history + const geminiChat = await geminiClient.getChat(); + if (geminiChat && resumedSessionData.conversation.messages.length > 0) { + // Load the conversation history into the chat + const historyContent: Content[] = resumedSessionData.conversation.messages.map(msg => ({ + role: msg.type === 'user' ? 'user' : 'model' as const, + parts: Array.isArray(msg.content) + ? msg.content.map(part => typeof part === 'string' ? { text: part } : part) + : [{ text: typeof msg.content === 'string' ? 
msg.content : JSON.stringify(msg.content) }] + })); + + // Set the chat history + geminiChat.setHistory(historyContent); + } } } - } - const abortController = new AbortController(); - - const { processedQuery, shouldProceed } = await handleAtCommand({ - query: input, - config, - addItem: (_item, _timestamp) => 0, - onDebugMessage: () => {}, - messageId: Date.now(), - signal: abortController.signal, - }); - - if (!shouldProceed || !processedQuery) { - // An error occurred during @include processing (e.g., file not found). - // The error message is already logged by handleAtCommand. - throw new FatalInputError( - 'Exiting due to an error processing the @ command.', - ); - } - - let currentMessages: Content[] = [ - { role: 'user', parts: processedQuery as Part[] }, - ]; - - let turnCount = 0; - while (true) { - turnCount++; - if ( - config.getMaxSessionTurns() >= 0 && - turnCount > config.getMaxSessionTurns() - ) { - throw new FatalTurnLimitedError( - 'Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.', + const abortController = new AbortController(); + + const { processedQuery, shouldProceed } = await handleAtCommand({ + query: input, + config, + addItem: (_item, _timestamp) => 0, + onDebugMessage: () => {}, + messageId: Date.now(), + signal: abortController.signal, + }); + + if (!shouldProceed || !processedQuery) { + // An error occurred during @include processing (e.g., file not found). + // The error message is already logged by handleAtCommand. 
+ throw new FatalInputError( + 'Exiting due to an error processing the @ command.', ); } - const toolCallRequests: ToolCallRequestInfo[] = []; - - const responseStream = geminiClient.sendMessageStream( - currentMessages[0]?.parts || [], - abortController.signal, - prompt_id, - ); - for await (const event of responseStream) { - if (abortController.signal.aborted) { - console.error('Operation cancelled.'); - return; + let currentMessages: Content[] = [ + { role: 'user', parts: processedQuery as Part[] }, + ]; + + let turnCount = 0; + while (true) { + turnCount++; + if ( + config.getMaxSessionTurns() >= 0 && + turnCount > config.getMaxSessionTurns() + ) { + handleMaxTurnsExceededError(config); } + const toolCallRequests: ToolCallRequestInfo[] = []; - if (event.type === GeminiEventType.Content) { - process.stdout.write(event.value); - } else if (event.type === GeminiEventType.ToolCallRequest) { - toolCallRequests.push(event.value); - } - } + const responseStream = geminiClient.sendMessageStream( + currentMessages[0]?.parts || [], + abortController.signal, + prompt_id, + ); - if (toolCallRequests.length > 0) { - const toolResponseParts: Part[] = []; - for (const requestInfo of toolCallRequests) { - const toolResponse = await executeToolCall( - config, - requestInfo, - abortController.signal, - ); + let responseText = ''; + for await (const event of responseStream) { + if (abortController.signal.aborted) { + handleCancellationError(config); + } - if (toolResponse.error) { - console.error( - `Error executing tool ${requestInfo.name}: ${toolResponse.resultDisplay || toolResponse.error.message}`, - ); + if (event.type === GeminiEventType.Content) { + if (config.getOutputFormat() === OutputFormat.JSON) { + responseText += event.value; + } else if (config.getOutputFormat() === OutputFormat.STREAM_JSON) { + responseText += event.value; + if (streamJsonFormatter) { + process.stdout.write(streamJsonFormatter.formatContentBlock(event.value) + '\n'); + } + } else { + 
process.stdout.write(event.value); + } + } else if (event.type === GeminiEventType.ToolCallRequest) { + toolCallRequests.push(event.value); } + } - if (toolResponse.responseParts) { - toolResponseParts.push(...toolResponse.responseParts); + if (toolCallRequests.length > 0) { + const toolResponseParts: Part[] = []; + for (const requestInfo of toolCallRequests) { + const toolResponse = await executeToolCall( + config, + requestInfo, + abortController.signal, + ); + + if (toolResponse.error) { + handleToolError( + requestInfo.name, + toolResponse.error, + config, + toolResponse.errorType || 'TOOL_EXECUTION_ERROR', + typeof toolResponse.resultDisplay === 'string' + ? toolResponse.resultDisplay + : undefined, + ); + } + + if (toolResponse.responseParts) { + toolResponseParts.push(...toolResponse.responseParts); + } + } + currentMessages = [{ role: 'user', parts: toolResponseParts }]; + } else { + if (config.getOutputFormat() === OutputFormat.JSON) { + const formatter = new JsonFormatter(); + const stats = uiTelemetryService.getMetrics(); + process.stdout.write(formatter.format(responseText, stats)); + } else if (config.getOutputFormat() === OutputFormat.STREAM_JSON) { + if (streamJsonFormatter) { + const stats = uiTelemetryService.getMetrics(); + process.stdout.write(streamJsonFormatter.formatFinalBlock(responseText, stats) + '\n'); + } + } else { + process.stdout.write('\n'); // Ensure a final newline } + return; } - currentMessages = [{ role: 'user', parts: toolResponseParts }]; - } else { - process.stdout.write('\n'); // Ensure a final newline - return; + } + } catch (error) { + handleError(error, config); + } finally { + consolePatcher.cleanup(); + if (config.getOutputFormat() === OutputFormat.STREAM_JSON) { + streamingTelemetryService.disable(); + } + if (isTelemetrySdkInitialized()) { + await shutdownTelemetry(config); } } - } catch (error) { - console.error( - parseAndFormatApiError( - error, - config.getContentGeneratorConfig()?.authType, - ), - ); - throw 
error; - } finally { - consolePatcher.cleanup(); - if (isTelemetrySdkInitialized()) { - await shutdownTelemetry(config); - } - } + }); } diff --git a/packages/cli/src/services/CommandService.test.ts b/packages/cli/src/services/CommandService.test.ts index 362e4b8b624..e2d5b9f585d 100644 --- a/packages/cli/src/services/CommandService.test.ts +++ b/packages/cli/src/services/CommandService.test.ts @@ -349,36 +349,4 @@ describe('CommandService', () => { expect(deployExtension).toBeDefined(); expect(deployExtension?.description).toBe('[gcp] Deploy to Google Cloud'); }); - - it('should filter out hidden commands', async () => { - const visibleCommand = createMockCommand('visible', CommandKind.BUILT_IN); - const hiddenCommand = { - ...createMockCommand('hidden', CommandKind.BUILT_IN), - hidden: true, - }; - const initiallyVisibleCommand = createMockCommand( - 'initially-visible', - CommandKind.BUILT_IN, - ); - const hiddenOverrideCommand = { - ...createMockCommand('initially-visible', CommandKind.FILE), - hidden: true, - }; - - const mockLoader = new MockCommandLoader([ - visibleCommand, - hiddenCommand, - initiallyVisibleCommand, - hiddenOverrideCommand, - ]); - - const service = await CommandService.create( - [mockLoader], - new AbortController().signal, - ); - - const commands = service.getCommands(); - expect(commands).toHaveLength(1); - expect(commands[0].name).toBe('visible'); - }); }); diff --git a/packages/cli/src/services/CommandService.ts b/packages/cli/src/services/CommandService.ts index 808ab61bf58..5f1e09d50db 100644 --- a/packages/cli/src/services/CommandService.ts +++ b/packages/cli/src/services/CommandService.ts @@ -85,9 +85,7 @@ export class CommandService { }); } - const finalCommands = Object.freeze( - Array.from(commandMap.values()).filter((cmd) => !cmd.hidden), - ); + const finalCommands = Object.freeze(Array.from(commandMap.values())); return new CommandService(finalCommands); } diff --git 
a/packages/cli/src/services/prompt-processors/shellProcessor.test.ts b/packages/cli/src/services/prompt-processors/shellProcessor.test.ts index 8006444eb08..e01f2aa6312 100644 --- a/packages/cli/src/services/prompt-processors/shellProcessor.test.ts +++ b/packages/cli/src/services/prompt-processors/shellProcessor.test.ts @@ -71,6 +71,7 @@ describe('ShellProcessor', () => { getTargetDir: vi.fn().mockReturnValue('/test/dir'), getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT), getShouldUseNodePtyShell: vi.fn().mockReturnValue(false), + getShellExecutionConfig: vi.fn().mockReturnValue({}), }; context = createMockCommandContext({ @@ -147,6 +148,7 @@ describe('ShellProcessor', () => { expect.any(Function), expect.any(Object), false, + expect.any(Object), ); expect(result).toEqual([{ text: 'The current status is: On branch main' }]); }); @@ -218,6 +220,7 @@ describe('ShellProcessor', () => { expect.any(Function), expect.any(Object), false, + expect.any(Object), ); expect(result).toEqual([{ text: 'Do something dangerous: deleted' }]); }); @@ -410,6 +413,7 @@ describe('ShellProcessor', () => { expect.any(Function), expect.any(Object), false, + expect.any(Object), ); }); @@ -574,6 +578,7 @@ describe('ShellProcessor', () => { expect.any(Function), expect.any(Object), false, + expect.any(Object), ); expect(result).toEqual([{ text: 'Command: match found' }]); @@ -598,6 +603,7 @@ describe('ShellProcessor', () => { expect.any(Function), expect.any(Object), false, + expect.any(Object), ); expect(result).toEqual([ @@ -668,6 +674,7 @@ describe('ShellProcessor', () => { expect.any(Function), expect.any(Object), false, + expect.any(Object), ); }); @@ -697,6 +704,7 @@ describe('ShellProcessor', () => { expect.any(Function), expect.any(Object), false, + expect.any(Object), ); }); }); diff --git a/packages/cli/src/services/prompt-processors/shellProcessor.ts b/packages/cli/src/services/prompt-processors/shellProcessor.ts index 2a239272b32..f3143d1231c 100644 --- 
a/packages/cli/src/services/prompt-processors/shellProcessor.ts +++ b/packages/cli/src/services/prompt-processors/shellProcessor.ts @@ -20,6 +20,7 @@ import { SHORTHAND_ARGS_PLACEHOLDER, } from './types.js'; import { extractInjections, type Injection } from './injectionParser.js'; +import { themeManager } from '../../ui/themes/theme-manager.js'; export class ConfirmationRequiredError extends Error { constructor( @@ -159,12 +160,19 @@ export class ShellProcessor implements IPromptProcessor { // Execute the resolved command (which already has ESCAPED input). if (injection.resolvedCommand) { + const activeTheme = themeManager.getActiveTheme(); + const shellExecutionConfig = { + ...config.getShellExecutionConfig(), + defaultFg: activeTheme.colors.Foreground, + defaultBg: activeTheme.colors.Background, + }; const { result } = await ShellExecutionService.execute( injection.resolvedCommand, config.getTargetDir(), () => {}, new AbortController().signal, config.getShouldUseNodePtyShell(), + shellExecutionConfig, ); const executionResult = await result; diff --git a/packages/cli/src/ui/AppContainer.test.tsx b/packages/cli/src/ui/AppContainer.test.tsx index facbb500d3b..0ea7f2af84f 100644 --- a/packages/cli/src/ui/AppContainer.test.tsx +++ b/packages/cli/src/ui/AppContainer.test.tsx @@ -4,31 +4,55 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { + describe, + it, + expect, + vi, + beforeEach, + afterEach, + type Mock, +} from 'vitest'; import { render, cleanup } from 'ink-testing-library'; import { AppContainer } from './AppContainer.js'; import { type Config, makeFakeConfig } from '@blocksuser/gemini-cli-core'; import type { LoadedSettings } from '../config/settings.js'; import type { InitializationResult } from '../core/initializer.js'; +import { useQuotaAndFallback } from './hooks/useQuotaAndFallback.js'; +import { UIStateContext, type UIState } from './contexts/UIStateContext.js'; +import { + 
UIActionsContext, + type UIActions, +} from './contexts/UIActionsContext.js'; +import { useContext } from 'react'; + +// Helper component will read the context values provided by AppContainer +// so we can assert against them in our tests. +let capturedUIState: UIState; +let capturedUIActions: UIActions; +function TestContextConsumer() { + capturedUIState = useContext(UIStateContext)!; + capturedUIActions = useContext(UIActionsContext)!; + return null; +} -// Mock App component to isolate AppContainer testing vi.mock('./App.js', () => ({ - App: () => 'App Component', + App: TestContextConsumer, })); -// Mock all the hooks and utilities -vi.mock('./hooks/useHistory.js'); +vi.mock('./hooks/useQuotaAndFallback.js'); +vi.mock('./hooks/useHistoryManager.js'); vi.mock('./hooks/useThemeCommand.js'); -vi.mock('./hooks/useAuthCommand.js'); +vi.mock('./auth/useAuth.js'); vi.mock('./hooks/useEditorSettings.js'); vi.mock('./hooks/useSettingsCommand.js'); -vi.mock('./hooks/useSlashCommandProcessor.js'); +vi.mock('./hooks/slashCommandProcessor.js'); vi.mock('./hooks/useConsoleMessages.js'); vi.mock('./hooks/useTerminalSize.js', () => ({ useTerminalSize: vi.fn(() => ({ columns: 80, rows: 24 })), })); vi.mock('./hooks/useGeminiStream.js'); -vi.mock('./hooks/useVim.js'); +vi.mock('./hooks/vim.js'); vi.mock('./hooks/useFocus.js'); vi.mock('./hooks/useBracketedPaste.js'); vi.mock('./hooks/useKeypress.js'); @@ -40,7 +64,7 @@ vi.mock('./hooks/useWorkspaceMigration.js'); vi.mock('./hooks/useGitBranchName.js'); vi.mock('./contexts/VimModeContext.js'); vi.mock('./contexts/SessionContext.js'); -vi.mock('./hooks/useTextBuffer.js'); +vi.mock('./components/shared/text-buffer.js'); vi.mock('./hooks/useLogger.js'); // Mock external utilities @@ -49,14 +73,153 @@ vi.mock('../utils/handleAutoUpdate.js'); vi.mock('./utils/ConsolePatcher.js'); vi.mock('../utils/cleanup.js'); +import { useHistory } from './hooks/useHistoryManager.js'; +import { useThemeCommand } from './hooks/useThemeCommand.js'; 
+import { useAuthCommand } from './auth/useAuth.js'; +import { useEditorSettings } from './hooks/useEditorSettings.js'; +import { useSettingsCommand } from './hooks/useSettingsCommand.js'; +import { useSlashCommandProcessor } from './hooks/slashCommandProcessor.js'; +import { useConsoleMessages } from './hooks/useConsoleMessages.js'; +import { useGeminiStream } from './hooks/useGeminiStream.js'; +import { useVim } from './hooks/vim.js'; +import { useFolderTrust } from './hooks/useFolderTrust.js'; +import { useMessageQueue } from './hooks/useMessageQueue.js'; +import { useAutoAcceptIndicator } from './hooks/useAutoAcceptIndicator.js'; +import { useWorkspaceMigration } from './hooks/useWorkspaceMigration.js'; +import { useGitBranchName } from './hooks/useGitBranchName.js'; +import { useVimMode } from './contexts/VimModeContext.js'; +import { useSessionStats } from './contexts/SessionContext.js'; +import { useTextBuffer } from './components/shared/text-buffer.js'; +import { useLogger } from './hooks/useLogger.js'; +import { useLoadingIndicator } from './hooks/useLoadingIndicator.js'; + describe('AppContainer State Management', () => { let mockConfig: Config; let mockSettings: LoadedSettings; let mockInitResult: InitializationResult; + // Create typed mocks for all hooks + const mockedUseQuotaAndFallback = useQuotaAndFallback as Mock; + const mockedUseHistory = useHistory as Mock; + const mockedUseThemeCommand = useThemeCommand as Mock; + const mockedUseAuthCommand = useAuthCommand as Mock; + const mockedUseEditorSettings = useEditorSettings as Mock; + const mockedUseSettingsCommand = useSettingsCommand as Mock; + const mockedUseSlashCommandProcessor = useSlashCommandProcessor as Mock; + const mockedUseConsoleMessages = useConsoleMessages as Mock; + const mockedUseGeminiStream = useGeminiStream as Mock; + const mockedUseVim = useVim as Mock; + const mockedUseFolderTrust = useFolderTrust as Mock; + const mockedUseMessageQueue = useMessageQueue as Mock; + const 
mockedUseAutoAcceptIndicator = useAutoAcceptIndicator as Mock; + const mockedUseWorkspaceMigration = useWorkspaceMigration as Mock; + const mockedUseGitBranchName = useGitBranchName as Mock; + const mockedUseVimMode = useVimMode as Mock; + const mockedUseSessionStats = useSessionStats as Mock; + const mockedUseTextBuffer = useTextBuffer as Mock; + const mockedUseLogger = useLogger as Mock; + const mockedUseLoadingIndicator = useLoadingIndicator as Mock; + beforeEach(() => { vi.clearAllMocks(); + capturedUIState = null!; + capturedUIActions = null!; + + // **Provide a default return value for EVERY mocked hook.** + mockedUseQuotaAndFallback.mockReturnValue({ + proQuotaRequest: null, + handleProQuotaChoice: vi.fn(), + }); + mockedUseHistory.mockReturnValue({ + history: [], + addItem: vi.fn(), + updateItem: vi.fn(), + clearItems: vi.fn(), + loadHistory: vi.fn(), + }); + mockedUseThemeCommand.mockReturnValue({ + isThemeDialogOpen: false, + openThemeDialog: vi.fn(), + handleThemeSelect: vi.fn(), + handleThemeHighlight: vi.fn(), + }); + mockedUseAuthCommand.mockReturnValue({ + authState: 'authenticated', + setAuthState: vi.fn(), + authError: null, + onAuthError: vi.fn(), + }); + mockedUseEditorSettings.mockReturnValue({ + isEditorDialogOpen: false, + openEditorDialog: vi.fn(), + handleEditorSelect: vi.fn(), + exitEditorDialog: vi.fn(), + }); + mockedUseSettingsCommand.mockReturnValue({ + isSettingsDialogOpen: false, + openSettingsDialog: vi.fn(), + closeSettingsDialog: vi.fn(), + }); + mockedUseSlashCommandProcessor.mockReturnValue({ + handleSlashCommand: vi.fn(), + slashCommands: [], + pendingHistoryItems: [], + commandContext: {}, + shellConfirmationRequest: null, + confirmationRequest: null, + }); + mockedUseConsoleMessages.mockReturnValue({ + consoleMessages: [], + handleNewMessage: vi.fn(), + clearConsoleMessages: vi.fn(), + }); + mockedUseGeminiStream.mockReturnValue({ + streamingState: 'idle', + submitQuery: vi.fn(), + initError: null, + pendingHistoryItems: [], + 
thought: null, + cancelOngoingRequest: vi.fn(), + }); + mockedUseVim.mockReturnValue({ handleInput: vi.fn() }); + mockedUseFolderTrust.mockReturnValue({ + isFolderTrustDialogOpen: false, + handleFolderTrustSelect: vi.fn(), + isRestarting: false, + }); + mockedUseMessageQueue.mockReturnValue({ + messageQueue: [], + addMessage: vi.fn(), + clearQueue: vi.fn(), + getQueuedMessagesText: vi.fn().mockReturnValue(''), + }); + mockedUseAutoAcceptIndicator.mockReturnValue(false); + mockedUseWorkspaceMigration.mockReturnValue({ + showWorkspaceMigrationDialog: false, + workspaceExtensions: [], + onWorkspaceMigrationDialogOpen: vi.fn(), + onWorkspaceMigrationDialogClose: vi.fn(), + }); + mockedUseGitBranchName.mockReturnValue('main'); + mockedUseVimMode.mockReturnValue({ + isVimEnabled: false, + toggleVimEnabled: vi.fn(), + }); + mockedUseSessionStats.mockReturnValue({ stats: {} }); + mockedUseTextBuffer.mockReturnValue({ + text: '', + setText: vi.fn(), + // Add other properties if AppContainer uses them + }); + mockedUseLogger.mockReturnValue({ + getPreviousUserMessages: vi.fn().mockResolvedValue([]), + }); + mockedUseLoadingIndicator.mockReturnValue({ + elapsedTime: '0.0s', + currentLoadingPhrase: '', + }); + // Mock Config mockConfig = makeFakeConfig(); @@ -253,10 +416,9 @@ describe('AppContainer State Management', () => { }); describe('Version Handling', () => { - it('handles different version formats', () => { - const versions = ['1.0.0', '2.1.3-beta', '3.0.0-nightly']; - - versions.forEach((version) => { + it.each(['1.0.0', '2.1.3-beta', '3.0.0-nightly'])( + 'handles version format: %s', + (version) => { expect(() => { render( { />, ); }).not.toThrow(); - }); - }); + }, + ); }); describe('Error Handling', () => { @@ -325,7 +487,73 @@ describe('AppContainer State Management', () => { expect(() => unmount()).not.toThrow(); }); }); -}); -// TODO: Add comprehensive integration test once all hook mocks are complete -// For now, the 14 passing unit tests provide good coverage 
of AppContainer functionality + describe('Quota and Fallback Integration', () => { + it('passes a null proQuotaRequest to UIStateContext by default', () => { + // The default mock from beforeEach already sets proQuotaRequest to null + render( + , + ); + + // Assert that the context value is as expected + expect(capturedUIState.proQuotaRequest).toBeNull(); + }); + + it('passes a valid proQuotaRequest to UIStateContext when provided by the hook', () => { + // Arrange: Create a mock request object that a UI dialog would receive + const mockRequest = { + failedModel: 'gemini-pro', + fallbackModel: 'gemini-flash', + resolve: vi.fn(), + }; + mockedUseQuotaAndFallback.mockReturnValue({ + proQuotaRequest: mockRequest, + handleProQuotaChoice: vi.fn(), + }); + + // Act: Render the container + render( + , + ); + + // Assert: The mock request is correctly passed through the context + expect(capturedUIState.proQuotaRequest).toEqual(mockRequest); + }); + + it('passes the handleProQuotaChoice function to UIActionsContext', () => { + // Arrange: Create a mock handler function + const mockHandler = vi.fn(); + mockedUseQuotaAndFallback.mockReturnValue({ + proQuotaRequest: null, + handleProQuotaChoice: mockHandler, + }); + + // Act: Render the container + render( + , + ); + + // Assert: The action in the context is the mock handler we provided + expect(capturedUIActions.handleProQuotaChoice).toBe(mockHandler); + + // You can even verify that the plumbed function is callable + capturedUIActions.handleProQuotaChoice('auth'); + expect(mockHandler).toHaveBeenCalledWith('auth'); + }); + }); +}); diff --git a/packages/cli/src/ui/AppContainer.tsx b/packages/cli/src/ui/AppContainer.tsx index c7f2c3ef368..26022337982 100644 --- a/packages/cli/src/ui/AppContainer.tsx +++ b/packages/cli/src/ui/AppContainer.tsx @@ -20,23 +20,21 @@ import { type HistoryItemWithoutId, AuthState, } from './types.js'; -import { MessageType } from './types.js'; +import { MessageType, StreamingState } from 
'./types.js'; import { type EditorType, type Config, - IdeClient, type DetectedIde, - ideContext, type IdeContext, + type UserTierId, + DEFAULT_GEMINI_FLASH_MODEL, + IdeClient, + ideContextStore, getErrorMessage, getAllGeminiMdFilenames, - UserTierId, AuthType, - isProQuotaExceededError, - isGenericQuotaExceededError, - logFlashFallback, - FlashFallbackEvent, clearCachedCredentialFile, + ShellExecutionService, } from '@blocksuser/gemini-cli-core'; import { validateAuthMethod } from '../config/auth.js'; import { loadHierarchicalGeminiMemory } from '../config/config.js'; @@ -44,6 +42,7 @@ import process from 'node:process'; import { useHistory } from './hooks/useHistoryManager.js'; import { useThemeCommand } from './hooks/useThemeCommand.js'; import { useAuthCommand } from './auth/useAuth.js'; +import { useQuotaAndFallback } from './hooks/useQuotaAndFallback.js'; import { useEditorSettings } from './hooks/useEditorSettings.js'; import { useSettingsCommand } from './hooks/useSettingsCommand.js'; import { useSlashCommandProcessor } from './hooks/slashCommandProcessor.js'; @@ -99,6 +98,18 @@ interface AppContainerProps { initializationResult: InitializationResult; } +/** + * The fraction of the terminal width to allocate to the shell. + * This provides horizontal padding. + */ +const SHELL_WIDTH_FRACTION = 0.89; + +/** + * The number of lines to subtract from the available terminal height + * for the shell. This provides vertical padding and space for other UI elements. 
+ */ +const SHELL_HEIGHT_PADDING = 10; + export const AppContainer = (props: AppContainerProps) => { const { settings, config, initializationResult } = props; const historyManager = useHistory(); @@ -112,6 +123,8 @@ export const AppContainer = (props: AppContainerProps) => { initializationResult.themeError, ); const [isProcessing, setIsProcessing] = useState(false); + const [shellFocused, setShellFocused] = useState(false); + const [geminiMdFileCount, setGeminiMdFileCount] = useState( initializationResult.geminiMdFileCount, ); @@ -123,12 +136,20 @@ export const AppContainer = (props: AppContainerProps) => { const [isTrustedFolder, setIsTrustedFolder] = useState( config.isTrustedFolder(), ); - const [currentModel, setCurrentModel] = useState(config.getModel()); + + // Helper to determine the effective model, considering the fallback state. + const getEffectiveModel = useCallback(() => { + if (config.isInFallbackMode()) { + return DEFAULT_GEMINI_FLASH_MODEL; + } + return config.getModel(); + }, [config]); + + const [currentModel, setCurrentModel] = useState(getEffectiveModel()); + const [userTier, setUserTier] = useState(undefined); - const [isProQuotaDialogOpen, setIsProQuotaDialogOpen] = useState(false); - const [proQuotaDialogResolver, setProQuotaDialogResolver] = useState< - ((value: boolean) => void) | null - >(null); + + const [isConfigInitialized, setConfigInitialized] = useState(false); // Auto-accept indicator const showAutoAcceptIndicator = useAutoAcceptIndicator({ @@ -153,32 +174,37 @@ export const AppContainer = (props: AppContainerProps) => { const staticExtraHeight = 3; useEffect(() => { + (async () => { + // Note: the program will not work if this fails so let errors be + // handled by the global catch. 
+ await config.initialize(); + setConfigInitialized(true); + })(); registerCleanup(async () => { const ideClient = await IdeClient.getInstance(); await ideClient.disconnect(); }); }, [config]); - useEffect(() => { - const cleanup = setUpdateHandler(historyManager.addItem, setUpdateInfo); - return cleanup; - }, [historyManager.addItem]); + useEffect( + () => setUpdateHandler(historyManager.addItem, setUpdateInfo), + [historyManager.addItem], + ); // Watch for model changes (e.g., from Flash fallback) useEffect(() => { const checkModelChange = () => { - const configModel = config.getModel(); - if (configModel !== currentModel) { - setCurrentModel(configModel); + const effectiveModel = getEffectiveModel(); + if (effectiveModel !== currentModel) { + setCurrentModel(effectiveModel); } }; - // Check immediately and then periodically checkModelChange(); const interval = setInterval(checkModelChange, 1000); // Check every second return () => clearInterval(interval); - }, [config, currentModel]); + }, [config, currentModel, getEffectiveModel]); const { consoleMessages, @@ -200,7 +226,7 @@ export const AppContainer = (props: AppContainerProps) => { 20, Math.floor(terminalWidth * widthFraction) - 3, ); - const suggestionsWidth = Math.max(20, Math.floor(terminalWidth * 0.8)); + const suggestionsWidth = Math.max(20, Math.floor(terminalWidth * 1.0)); const mainAreaWidth = Math.floor(terminalWidth * 0.9); const staticAreaMaxItemHeight = Math.max(terminalHeight * 4, 100); @@ -273,6 +299,14 @@ export const AppContainer = (props: AppContainerProps) => { config, ); + const { proQuotaRequest, handleProQuotaChoice } = useQuotaAndFallback({ + config, + historyManager, + userTier, + setAuthState, + setModelSwitchedFromQuotaError, + }); + // Derive auth state variables for backward compatibility with UIStateContext const isAuthDialogOpen = authState === AuthState.Updating; const isAuthenticating = authState === AuthState.Unauthenticated; @@ -316,7 +350,7 @@ Logging in with Google... 
Please restart Gemini CLI to continue. useEffect(() => { // Only sync when not currently authenticating if (authState === AuthState.Authenticated) { - setUserTier(config.getGeminiClient()?.getUserTier()); + setUserTier(config.getUserTier()); } }, [config, authState]); @@ -416,6 +450,7 @@ Logging in with Google... Please restart Gemini CLI to continue. setIsProcessing, setGeminiMdFileCount, slashCommandActions, + isConfigInitialized, ); const performMemoryRefresh = useCallback(async () => { @@ -477,132 +512,6 @@ Logging in with Google... Please restart Gemini CLI to continue. } }, [config, historyManager, settings.merged]); - // Set up Flash fallback handler - useEffect(() => { - const flashFallbackHandler = async ( - currentModel: string, - fallbackModel: string, - error?: unknown, - ): Promise => { - // Check if we've already switched to the fallback model - if (config.isInFallbackMode()) { - // If we're already in fallback mode, don't show the dialog again - return false; - } - - let message: string; - - if ( - config.getContentGeneratorConfig().authType === - AuthType.LOGIN_WITH_GOOGLE - ) { - // Use actual user tier if available; otherwise, default to FREE tier behavior (safe default) - const isPaidTier = - userTier === UserTierId.LEGACY || userTier === UserTierId.STANDARD; - - // Check if this is a Pro quota exceeded error - if (error && isProQuotaExceededError(error)) { - if (isPaidTier) { - message = `⚡ You have reached your daily ${currentModel} quota limit. -⚡ You can choose to authenticate with a paid API key or continue with the fallback model. -⚡ To continue accessing the ${currentModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`; - } else { - message = `⚡ You have reached your daily ${currentModel} quota limit. -⚡ You can choose to authenticate with a paid API key or continue with the fallback model. 
-⚡ To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist -⚡ Or you can utilize a Gemini API Key. See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key -⚡ You can switch authentication methods by typing /auth`; - } - } else if (error && isGenericQuotaExceededError(error)) { - if (isPaidTier) { - message = `⚡ You have reached your daily quota limit. -⚡ Automatically switching from ${currentModel} to ${fallbackModel} for the remainder of this session. -⚡ To continue accessing the ${currentModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`; - } else { - message = `⚡ You have reached your daily quota limit. -⚡ Automatically switching from ${currentModel} to ${fallbackModel} for the remainder of this session. -⚡ To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist -⚡ Or you can utilize a Gemini API Key. See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key -⚡ You can switch authentication methods by typing /auth`; - } - } else { - if (isPaidTier) { - // Default fallback message for other cases (like consecutive 429s) - message = `⚡ Automatically switching from ${currentModel} to ${fallbackModel} for faster responses for the remainder of this session. -⚡ Possible reasons for this are that you have received multiple consecutive capacity errors or you have reached your daily ${currentModel} quota limit -⚡ To continue accessing the ${currentModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`; - } else { - // Default fallback message for other cases (like consecutive 429s) - message = `⚡ Automatically switching from ${currentModel} to ${fallbackModel} for faster responses for the remainder of this session. 
-⚡ Possible reasons for this are that you have received multiple consecutive capacity errors or you have reached your daily ${currentModel} quota limit -⚡ To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist -⚡ Or you can utilize a Gemini API Key. See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key -⚡ You can switch authentication methods by typing /auth`; - } - } - - // Add message to UI history - historyManager.addItem( - { - type: MessageType.INFO, - text: message, - }, - Date.now(), - ); - - // For Pro quota errors, show the dialog and wait for user's choice - if (error && isProQuotaExceededError(error)) { - // Set the flag to prevent tool continuation - setModelSwitchedFromQuotaError(true); - // Set global quota error flag to prevent Flash model calls - config.setQuotaErrorOccurred(true); - - // Show the ProQuotaDialog and wait for user's choice - const shouldContinueWithFallback = await new Promise( - (resolve) => { - setIsProQuotaDialogOpen(true); - setProQuotaDialogResolver(() => resolve); - }, - ); - - // If user chose to continue with fallback, we don't need to stop the current prompt - if (shouldContinueWithFallback) { - // Switch to fallback model for future use - config.setModel(fallbackModel); - config.setFallbackMode(true); - logFlashFallback( - config, - new FlashFallbackEvent( - config.getContentGeneratorConfig().authType!, - ), - ); - return true; // Continue with current prompt using fallback model - } - - // If user chose to authenticate, stop current prompt - return false; - } - - // For other quota errors, automatically switch to fallback model - // Set the flag to prevent tool continuation - setModelSwitchedFromQuotaError(true); - // Set global quota error flag to prevent Flash model calls - config.setQuotaErrorOccurred(true); - } - - // Switch model for future use but return false to stop current retry - config.setModel(fallbackModel); - 
config.setFallbackMode(true); - logFlashFallback( - config, - new FlashFallbackEvent(config.getContentGeneratorConfig().authType!), - ); - return false; // Don't continue with current prompt - }; - - config.setFlashFallbackHandler(flashFallbackHandler); - }, [config, historyManager, userTier]); - const cancelHandlerRef = useRef<() => void>(() => {}); const { @@ -612,6 +521,8 @@ Logging in with Google... Please restart Gemini CLI to continue. pendingHistoryItems: pendingGeminiHistoryItems, thought, cancelOngoingRequest, + activePtyId, + loopDetectionConfirmationRequest, } = useGeminiStream( config.getGeminiClient(), historyManager.history, @@ -628,10 +539,15 @@ Logging in with Google... Please restart Gemini CLI to continue. setModelSwitchedFromQuotaError, refreshStatic, () => cancelHandlerRef.current(), + setShellFocused, + terminalWidth, + terminalHeight, + shellFocused, ); const { messageQueue, addMessage, clearQueue, getQueuedMessagesText } = useMessageQueue({ + isConfigInitialized, streamingState, submitQuery, }); @@ -681,25 +597,22 @@ Logging in with Google... Please restart Gemini CLI to continue. refreshStatic(); }, [historyManager, clearConsoleMessagesState, refreshStatic]); - const handleProQuotaChoice = useCallback( - (choice: 'auth' | 'continue') => { - setIsProQuotaDialogOpen(false); - if (proQuotaDialogResolver) { - if (choice === 'auth') { - proQuotaDialogResolver(false); // Don't continue with fallback, show auth dialog - setAuthState(AuthState.Updating); - } else { - proQuotaDialogResolver(true); // Continue with fallback model - } - setProQuotaDialogResolver(null); - } - }, - [proQuotaDialogResolver, setAuthState], - ); - const { handleInput: vimHandleInput } = useVim(buffer, handleFinalSubmit); - const isInputActive = !initError && !isProcessing; + /** + * Determines if the input prompt should be active and accept user input. 
+ * Input is disabled during: + * - Initialization errors + * - Slash command processing + * - Tool confirmations (WaitingForConfirmation state) + * - Any future streaming states not explicitly allowed + */ + const isInputActive = + !initError && + !isProcessing && + (streamingState === StreamingState.Idle || + streamingState === StreamingState.Responding) && + !proQuotaRequest; // Compute available terminal height based on controls measurement const availableTerminalHeight = useMemo(() => { @@ -710,6 +623,13 @@ Logging in with Google... Please restart Gemini CLI to continue. return terminalHeight - staticExtraHeight; }, [terminalHeight]); + config.setShellExecutionConfig({ + terminalWidth: Math.floor(terminalWidth * SHELL_WIDTH_FRACTION), + terminalHeight: Math.floor(availableTerminalHeight - SHELL_HEIGHT_PADDING), + pager: settings.merged.tools?.shell?.pager, + showColor: settings.merged.tools?.shell?.showColor, + }); + const isFocused = useFocus(); useBracketedPaste(); @@ -727,9 +647,26 @@ Logging in with Google... Please restart Gemini CLI to continue. const initialPromptSubmitted = useRef(false); const geminiClient = config.getGeminiClient(); + useEffect(() => { + if (activePtyId) { + ShellExecutionService.resizePty( + activePtyId, + Math.floor(terminalWidth * SHELL_WIDTH_FRACTION), + Math.floor(availableTerminalHeight - SHELL_HEIGHT_PADDING), + ); + } + }, [ + terminalHeight, + terminalWidth, + availableTerminalHeight, + activePtyId, + geminiClient, + ]); + useEffect(() => { if ( initialPrompt && + isConfigInitialized && !initialPromptSubmitted.current && !isAuthenticating && !isAuthDialogOpen && @@ -743,6 +680,7 @@ Logging in with Google... Please restart Gemini CLI to continue. } }, [ initialPrompt, + isConfigInitialized, handleFinalSubmit, isAuthenticating, isAuthDialogOpen, @@ -786,7 +724,7 @@ Logging in with Google... Please restart Gemini CLI to continue. 
const [showIdeRestartPrompt, setShowIdeRestartPrompt] = useState(false); const { isFolderTrustDialogOpen, handleFolderTrustSelect, isRestarting } = - useFolderTrust(settings, config, setIsTrustedFolder); + useFolderTrust(settings, setIsTrustedFolder, refreshStatic); const { needsRestart: ideNeedsRestart } = useIdeTrustListener(); const isInitialMount = useRef(true); @@ -813,8 +751,8 @@ Logging in with Google... Please restart Gemini CLI to continue. }, [terminalWidth, refreshStatic]); useEffect(() => { - const unsubscribe = ideContext.subscribeToIdeContext(setIdeContextState); - setIdeContextState(ideContext.getIdeContext()); + const unsubscribe = ideContextStore.subscribe(setIdeContextState); + setIdeContextState(ideContextStore.get()); return unsubscribe; }, []); @@ -903,7 +841,6 @@ Logging in with Google... Please restart Gemini CLI to continue. isEditorDialogOpen || isSettingsDialogOpen || isFolderTrustDialogOpen || - isAuthenticating || showPrivacyNotice; if (anyDialogOpen) { return; @@ -932,9 +869,6 @@ Logging in with Google... Please restart Gemini CLI to continue. ) { handleSlashCommand('/ide status'); } else if (keyMatchers[Command.QUIT](key)) { - if (isAuthenticating) { - return; - } if (!ctrlCPressedOnce) { cancelOngoingRequest?.(); } @@ -949,6 +883,10 @@ Logging in with Google... Please restart Gemini CLI to continue. !enteringConstrainHeightMode ) { setConstrainHeight(false); + } else if (keyMatchers[Command.TOGGLE_SHELL_INPUT_FOCUS](key)) { + if (activePtyId || shellFocused) { + setShellFocused((prev) => !prev); + } } }, [ @@ -968,7 +906,6 @@ Logging in with Google... Please restart Gemini CLI to continue. setCtrlDPressedOnce, ctrlDTimerRef, handleSlashCommand, - isAuthenticating, cancelOngoingRequest, isThemeDialogOpen, isAuthDialogOpen, @@ -976,6 +913,8 @@ Logging in with Google... Please restart Gemini CLI to continue. 
isSettingsDialogOpen, isFolderTrustDialogOpen, showPrivacyNotice, + activePtyId, + shellFocused, settings.merged.general?.debugKeystrokeLogging, ], ); @@ -1008,35 +947,20 @@ Logging in with Google... Please restart Gemini CLI to continue. const nightly = props.version.includes('nightly'); - const dialogsVisible = useMemo( - () => - showWorkspaceMigrationDialog || - shouldShowIdePrompt || - isFolderTrustDialogOpen || - !!shellConfirmationRequest || - !!confirmationRequest || - isThemeDialogOpen || - isSettingsDialogOpen || - isAuthenticating || - isAuthDialogOpen || - isEditorDialogOpen || - showPrivacyNotice || - isProQuotaDialogOpen, - [ - showWorkspaceMigrationDialog, - shouldShowIdePrompt, - isFolderTrustDialogOpen, - shellConfirmationRequest, - confirmationRequest, - isThemeDialogOpen, - isSettingsDialogOpen, - isAuthenticating, - isAuthDialogOpen, - isEditorDialogOpen, - showPrivacyNotice, - isProQuotaDialogOpen, - ], - ); + const dialogsVisible = + showWorkspaceMigrationDialog || + shouldShowIdePrompt || + isFolderTrustDialogOpen || + !!shellConfirmationRequest || + !!confirmationRequest || + !!loopDetectionConfirmationRequest || + isThemeDialogOpen || + isSettingsDialogOpen || + isAuthenticating || + isAuthDialogOpen || + isEditorDialogOpen || + showPrivacyNotice || + !!proQuotaRequest; const pendingHistoryItems = useMemo( () => [...pendingSlashCommandHistoryItems, ...pendingGeminiHistoryItems], @@ -1049,6 +973,7 @@ Logging in with Google... Please restart Gemini CLI to continue. isThemeDialogOpen, themeError, isAuthenticating, + isConfigInitialized, authError, isAuthDialogOpen, editorError, @@ -1063,6 +988,7 @@ Logging in with Google... Please restart Gemini CLI to continue. commandContext, shellConfirmationRequest, confirmationRequest, + loopDetectionConfirmationRequest, geminiMdFileCount, streamingState, initError, @@ -1093,11 +1019,9 @@ Logging in with Google... Please restart Gemini CLI to continue. 
showAutoAcceptIndicator, showWorkspaceMigrationDialog, workspaceExtensions, - // Use current state values instead of config.getModel() currentModel, userTier, - isProQuotaDialogOpen, - // New fields + proQuotaRequest, contextFileNames, errorCount, availableTerminalHeight, @@ -1116,12 +1040,15 @@ Logging in with Google... Please restart Gemini CLI to continue. updateInfo, showIdeRestartPrompt, isRestarting, + activePtyId, + shellFocused, }), [ historyManager.history, isThemeDialogOpen, themeError, isAuthenticating, + isConfigInitialized, authError, isAuthDialogOpen, editorError, @@ -1136,6 +1063,7 @@ Logging in with Google... Please restart Gemini CLI to continue. commandContext, shellConfirmationRequest, confirmationRequest, + loopDetectionConfirmationRequest, geminiMdFileCount, streamingState, initError, @@ -1166,10 +1094,8 @@ Logging in with Google... Please restart Gemini CLI to continue. showAutoAcceptIndicator, showWorkspaceMigrationDialog, workspaceExtensions, - // Quota-related state dependencies userTier, - isProQuotaDialogOpen, - // New fields dependencies + proQuotaRequest, contextFileNames, errorCount, availableTerminalHeight, @@ -1188,8 +1114,9 @@ Logging in with Google... Please restart Gemini CLI to continue. 
updateInfo, showIdeRestartPrompt, isRestarting, - // Quota-related dependencies currentModel, + activePtyId, + shellFocused, ], ); diff --git a/packages/cli/src/ui/IdeIntegrationNudge.tsx b/packages/cli/src/ui/IdeIntegrationNudge.tsx index 08389910f8e..dc9510b3618 100644 --- a/packages/cli/src/ui/IdeIntegrationNudge.tsx +++ b/packages/cli/src/ui/IdeIntegrationNudge.tsx @@ -10,6 +10,7 @@ import { Box, Text } from 'ink'; import type { RadioSelectItem } from './components/shared/RadioButtonSelect.js'; import { RadioButtonSelect } from './components/shared/RadioButtonSelect.js'; import { useKeypress } from './hooks/useKeypress.js'; +import { theme } from './semantic-colors.js'; export type IdeIntegrationNudgeResult = { userSelection: 'yes' | 'no' | 'dismiss'; @@ -79,17 +80,17 @@ export function IdeIntegrationNudge({ - {'> '} + {'> '} {`Do you want to connect ${ideName ?? 'your editor'} to Gemini CLI?`} - {installText} + {installText} diff --git a/packages/cli/src/ui/auth/AuthDialog.tsx b/packages/cli/src/ui/auth/AuthDialog.tsx index 2437e81b8c9..b3c22a7d62b 100644 --- a/packages/cli/src/ui/auth/AuthDialog.tsx +++ b/packages/cli/src/ui/auth/AuthDialog.tsx @@ -7,7 +7,7 @@ import type React from 'react'; import { useCallback } from 'react'; import { Box, Text } from 'ink'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import { RadioButtonSelect } from '../components/shared/RadioButtonSelect.js'; import type { LoadedSettings } from '../../config/settings.js'; import { SettingScope } from '../../config/settings.js'; @@ -149,14 +149,18 @@ Logging in with Google... Please restart Gemini CLI to continue. return ( - Get started + + Get started + - How would you like to authenticate for this project? + + How would you like to authenticate for this project? 
+ {authError && ( - {authError} + {authError} )} - (Use Enter to select) + (Use Enter to select) - Terms of Services and Privacy Notice for Gemini CLI + + Terms of Services and Privacy Notice for Gemini CLI + - + { 'https://github.com/google-gemini/gemini-cli/blob/main/docs/tos-privacy.md' } diff --git a/packages/cli/src/ui/auth/AuthInProgress.tsx b/packages/cli/src/ui/auth/AuthInProgress.tsx index ce8b54435d4..6270ecf1fb6 100644 --- a/packages/cli/src/ui/auth/AuthInProgress.tsx +++ b/packages/cli/src/ui/auth/AuthInProgress.tsx @@ -8,7 +8,7 @@ import type React from 'react'; import { useState, useEffect } from 'react'; import { Box, Text } from 'ink'; import Spinner from 'ink-spinner'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import { useKeypress } from '../hooks/useKeypress.js'; interface AuthInProgressProps { @@ -41,13 +41,13 @@ export function AuthInProgress({ return ( {timedOut ? ( - + Authentication timed out. Please try again. ) : ( diff --git a/packages/cli/src/ui/commands/chatCommand.ts b/packages/cli/src/ui/commands/chatCommand.ts index f62e67abe44..9a7936ce8e8 100644 --- a/packages/cli/src/ui/commands/chatCommand.ts +++ b/packages/cli/src/ui/commands/chatCommand.ts @@ -7,7 +7,7 @@ import * as fsPromises from 'node:fs/promises'; import React from 'react'; import { Text } from 'ink'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import type { CommandContext, SlashCommand, @@ -126,7 +126,7 @@ const saveCommand: SlashCommand = { Text, null, 'A checkpoint with the tag ', - React.createElement(Text, { color: Colors.AccentPurple }, tag), + React.createElement(Text, { color: theme.text.accent }, tag), ' already exists. 
Do you want to overwrite it?', ), originalInvocation: { diff --git a/packages/cli/src/ui/commands/corgiCommand.ts b/packages/cli/src/ui/commands/corgiCommand.ts index bdbcc05f568..2da6ad3ed1d 100644 --- a/packages/cli/src/ui/commands/corgiCommand.ts +++ b/packages/cli/src/ui/commands/corgiCommand.ts @@ -9,8 +9,8 @@ import { CommandKind, type SlashCommand } from './types.js'; export const corgiCommand: SlashCommand = { name: 'corgi', description: 'Toggles corgi mode.', - kind: CommandKind.BUILT_IN, hidden: true, + kind: CommandKind.BUILT_IN, action: (context, _args) => { context.ui.toggleCorgiMode(); }, diff --git a/packages/cli/src/ui/commands/extensionsCommand.test.ts b/packages/cli/src/ui/commands/extensionsCommand.test.ts index 0a69e01c66c..191c2a343d6 100644 --- a/packages/cli/src/ui/commands/extensionsCommand.test.ts +++ b/packages/cli/src/ui/commands/extensionsCommand.test.ts @@ -4,16 +4,42 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect } from 'vitest'; -import { extensionsCommand } from './extensionsCommand.js'; -import { type CommandContext } from './types.js'; +import { + updateAllUpdatableExtensions, + updateExtensionByName, +} from '../../config/extension.js'; import { createMockCommandContext } from '../../test-utils/mockCommandContext.js'; import { MessageType } from '../types.js'; +import { extensionsCommand } from './extensionsCommand.js'; +import { type CommandContext } from './types.js'; +import { + describe, + it, + expect, + vi, + beforeEach, + type MockedFunction, +} from 'vitest'; + +vi.mock('../../config/extension.js', () => ({ + updateExtensionByName: vi.fn(), + updateAllUpdatableExtensions: vi.fn(), +})); + +const mockUpdateExtensionByName = updateExtensionByName as MockedFunction< + typeof updateExtensionByName +>; + +const mockUpdateAllUpdatableExtensions = + updateAllUpdatableExtensions as MockedFunction< + typeof updateAllUpdatableExtensions + >; describe('extensionsCommand', () => { let mockContext: 
CommandContext; - it('should display "No active extensions." when none are found', async () => { + beforeEach(() => { + vi.resetAllMocks(); mockContext = createMockCommandContext({ services: { config: { @@ -21,47 +47,182 @@ describe('extensionsCommand', () => { }, }, }); + }); - if (!extensionsCommand.action) throw new Error('Action not defined'); - await extensionsCommand.action(mockContext, ''); + describe('list', () => { + it('should display "No active extensions." when none are found', async () => { + if (!extensionsCommand.action) throw new Error('Action not defined'); + await extensionsCommand.action(mockContext, ''); - expect(mockContext.ui.addItem).toHaveBeenCalledWith( - { - type: MessageType.INFO, - text: 'No active extensions.', - }, - expect.any(Number), - ); + expect(mockContext.ui.addItem).toHaveBeenCalledWith( + { + type: MessageType.INFO, + text: 'No active extensions.', + }, + expect.any(Number), + ); + }); + + it('should list active extensions when they are found', async () => { + const mockExtensions = [ + { name: 'ext-one', version: '1.0.0', isActive: true }, + { name: 'ext-two', version: '2.1.0', isActive: true }, + { name: 'ext-three', version: '3.0.0', isActive: false }, + ]; + mockContext = createMockCommandContext({ + services: { + config: { + getExtensions: () => mockExtensions, + }, + }, + }); + + if (!extensionsCommand.action) throw new Error('Action not defined'); + await extensionsCommand.action(mockContext, ''); + + const expectedMessage = + 'Active extensions:\n\n' + + ` - \u001b[36mext-one (v1.0.0)\u001b[0m\n` + + ` - \u001b[36mext-two (v2.1.0)\u001b[0m\n`; + + expect(mockContext.ui.addItem).toHaveBeenCalledWith( + { + type: MessageType.INFO, + text: expectedMessage, + }, + expect.any(Number), + ); + }); }); - it('should list active extensions when they are found', async () => { - const mockExtensions = [ - { name: 'ext-one', version: '1.0.0', isActive: true }, - { name: 'ext-two', version: '2.1.0', isActive: true }, - { name: 
'ext-three', version: '3.0.0', isActive: false }, - ]; - mockContext = createMockCommandContext({ - services: { - config: { - getExtensions: () => mockExtensions, + describe('update', () => { + const updateAction = extensionsCommand.subCommands?.find( + (cmd) => cmd.name === 'update', + )?.action; + + if (!updateAction) { + throw new Error('Update action not found'); + } + + it('should show usage if no args are provided', async () => { + await updateAction(mockContext, ''); + expect(mockContext.ui.addItem).toHaveBeenCalledWith( + { + type: MessageType.ERROR, + text: 'Usage: /extensions update |--all', }, - }, + expect.any(Number), + ); }); - if (!extensionsCommand.action) throw new Error('Action not defined'); - await extensionsCommand.action(mockContext, ''); + it('should inform user if there are no extensions to update with --all', async () => { + mockUpdateAllUpdatableExtensions.mockResolvedValue([]); + await updateAction(mockContext, '--all'); + expect(mockContext.ui.addItem).toHaveBeenCalledWith( + { + type: MessageType.INFO, + text: 'No extensions to update.', + }, + expect.any(Number), + ); + }); - const expectedMessage = - 'Active extensions:\n\n' + - ` - \u001b[36mext-one (v1.0.0)\u001b[0m\n` + - ` - \u001b[36mext-two (v2.1.0)\u001b[0m\n`; + it('should update all extensions with --all', async () => { + mockUpdateAllUpdatableExtensions.mockResolvedValue([ + { + name: 'ext-one', + originalVersion: '1.0.0', + updatedVersion: '1.0.1', + }, + { + name: 'ext-two', + originalVersion: '2.0.0', + updatedVersion: '2.0.1', + }, + ]); + await updateAction(mockContext, '--all'); + expect(mockContext.ui.addItem).toHaveBeenCalledWith( + { + type: MessageType.INFO, + text: + 'Extension "ext-one" successfully updated: 1.0.0 → 1.0.1.\n' + + 'Extension "ext-two" successfully updated: 2.0.0 → 2.0.1.\n' + + 'Restart gemini-cli to see the changes.', + }, + expect.any(Number), + ); + }); - expect(mockContext.ui.addItem).toHaveBeenCalledWith( - { - type: MessageType.INFO, - text: 
expectedMessage, - }, - expect.any(Number), - ); + it('should handle errors when updating all extensions', async () => { + mockUpdateAllUpdatableExtensions.mockRejectedValue( + new Error('Something went wrong'), + ); + await updateAction(mockContext, '--all'); + expect(mockContext.ui.addItem).toHaveBeenCalledWith( + { + type: MessageType.ERROR, + text: 'Something went wrong', + }, + expect.any(Number), + ); + }); + + it('should update a single extension by name', async () => { + mockUpdateExtensionByName.mockResolvedValue({ + name: 'ext-one', + originalVersion: '1.0.0', + updatedVersion: '1.0.1', + }); + await updateAction(mockContext, 'ext-one'); + expect(mockContext.ui.addItem).toHaveBeenCalledWith( + { + type: MessageType.INFO, + text: + 'Extension "ext-one" successfully updated: 1.0.0 → 1.0.1.\n' + + 'Restart gemini-cli to see the changes.', + }, + expect.any(Number), + ); + }); + + it('should handle errors when updating a single extension', async () => { + mockUpdateExtensionByName.mockRejectedValue( + new Error('Extension not found'), + ); + await updateAction(mockContext, 'ext-one'); + expect(mockContext.ui.addItem).toHaveBeenCalledWith( + { + type: MessageType.ERROR, + text: 'Extension not found', + }, + expect.any(Number), + ); + }); + + it('should update multiple extensions by name', async () => { + mockUpdateExtensionByName + .mockResolvedValueOnce({ + name: 'ext-one', + originalVersion: '1.0.0', + updatedVersion: '1.0.1', + }) + .mockResolvedValueOnce({ + name: 'ext-two', + originalVersion: '2.0.0', + updatedVersion: '2.0.1', + }); + await updateAction(mockContext, 'ext-one ext-two'); + expect(mockUpdateExtensionByName).toHaveBeenCalledTimes(2); + expect(mockContext.ui.addItem).toHaveBeenCalledWith( + { + type: MessageType.INFO, + text: + 'Extension "ext-one" successfully updated: 1.0.0 → 1.0.1.\n' + + 'Extension "ext-two" successfully updated: 2.0.0 → 2.0.1.\n' + + 'Restart gemini-cli to see the changes.', + }, + expect.any(Number), + ); + }); }); }); 
diff --git a/packages/cli/src/ui/commands/extensionsCommand.ts b/packages/cli/src/ui/commands/extensionsCommand.ts index ea9f9a4f404..d0b81ad2a13 100644 --- a/packages/cli/src/ui/commands/extensionsCommand.ts +++ b/packages/cli/src/ui/commands/extensionsCommand.ts @@ -4,43 +4,131 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { + updateExtensionByName, + updateAllUpdatableExtensions, + type ExtensionUpdateInfo, +} from '../../config/extension.js'; +import { getErrorMessage } from '../../utils/errors.js'; +import { MessageType } from '../types.js'; import { type CommandContext, type SlashCommand, CommandKind, } from './types.js'; -import { MessageType } from '../types.js'; -export const extensionsCommand: SlashCommand = { - name: 'extensions', - description: 'list active extensions', - kind: CommandKind.BUILT_IN, - action: async (context: CommandContext): Promise => { - const activeExtensions = context.services.config - ?.getExtensions() - .filter((ext) => ext.isActive); - if (!activeExtensions || activeExtensions.length === 0) { +async function listAction(context: CommandContext) { + const activeExtensions = context.services.config + ?.getExtensions() + .filter((ext) => ext.isActive); + if (!activeExtensions || activeExtensions.length === 0) { + context.ui.addItem( + { + type: MessageType.INFO, + text: 'No active extensions.', + }, + Date.now(), + ); + return; + } + + const extensionLines = activeExtensions.map( + (ext) => ` - \u001b[36m${ext.name} (v${ext.version})\u001b[0m`, + ); + const message = `Active extensions:\n\n${extensionLines.join('\n')}\n`; + + context.ui.addItem( + { + type: MessageType.INFO, + text: message, + }, + Date.now(), + ); +} + +const updateOutput = (info: ExtensionUpdateInfo) => + `Extension "${info.name}" successfully updated: ${info.originalVersion} → ${info.updatedVersion}.`; + +async function updateAction(context: CommandContext, args: string) { + const updateArgs = args.split(' ').filter((value) => value.length > 0); + const all 
= updateArgs.length === 1 && updateArgs[0] === '--all'; + const names = all ? undefined : updateArgs; + let updateInfos: ExtensionUpdateInfo[] = []; + try { + if (all) { + updateInfos = await updateAllUpdatableExtensions(); + } else if (names?.length) { + for (const name of names) { + updateInfos.push(await updateExtensionByName(name)); + } + } else { context.ui.addItem( { - type: MessageType.INFO, - text: 'No active extensions.', + type: MessageType.ERROR, + text: 'Usage: /extensions update |--all', }, Date.now(), ); return; } - const extensionLines = activeExtensions.map( - (ext) => ` - \u001b[36m${ext.name} (v${ext.version})\u001b[0m`, + // Filter to the actually updated ones. + updateInfos = updateInfos.filter( + (info) => info.originalVersion !== info.updatedVersion, ); - const message = `Active extensions:\n\n${extensionLines.join('\n')}\n`; + + if (updateInfos.length === 0) { + context.ui.addItem( + { + type: MessageType.INFO, + text: 'No extensions to update.', + }, + Date.now(), + ); + return; + } context.ui.addItem( { type: MessageType.INFO, - text: message, + text: [ + ...updateInfos.map((info) => updateOutput(info)), + 'Restart gemini-cli to see the changes.', + ].join('\n'), + }, + Date.now(), + ); + } catch (error) { + context.ui.addItem( + { + type: MessageType.ERROR, + text: getErrorMessage(error), }, Date.now(), ); - }, + } +} + +const listExtensionsCommand: SlashCommand = { + name: 'list', + description: 'List active extensions', + kind: CommandKind.BUILT_IN, + action: listAction, +}; + +const updateExtensionsCommand: SlashCommand = { + name: 'update', + description: 'Update extensions. 
Usage: update |--all', + kind: CommandKind.BUILT_IN, + action: updateAction, +}; + +export const extensionsCommand: SlashCommand = { + name: 'extensions', + description: 'Manage extensions', + kind: CommandKind.BUILT_IN, + subCommands: [listExtensionsCommand, updateExtensionsCommand], + action: (context, args) => + // Default to list if no subcommand is provided + listExtensionsCommand.action!(context, args), }; diff --git a/packages/cli/src/ui/commands/ideCommand.ts b/packages/cli/src/ui/commands/ideCommand.ts index 681b3a3b644..a9d6d9b2d6e 100644 --- a/packages/cli/src/ui/commands/ideCommand.ts +++ b/packages/cli/src/ui/commands/ideCommand.ts @@ -15,7 +15,7 @@ import { import { getIdeInstaller, IDEConnectionStatus, - ideContext, + ideContextStore, GEMINI_CLI_COMPANION_EXTENSION_NAME, } from '@blocksuser/gemini-cli-core'; import path from 'node:path'; @@ -90,7 +90,7 @@ async function getIdeStatusMessageWithFiles(ideClient: IdeClient): Promise<{ switch (connection.status) { case IDEConnectionStatus.Connected: { let content = `🟢 Connected to ${ideClient.getDetectedIdeDisplayName()}`; - const context = ideContext.getIdeContext(); + const context = ideContextStore.get(); const openFiles = context?.workspaceState?.openFiles; if (openFiles && openFiles.length > 0) { content += formatFileList(openFiles); diff --git a/packages/cli/src/ui/components/AboutBox.tsx b/packages/cli/src/ui/components/AboutBox.tsx index 1d58495ded5..129ed30351e 100644 --- a/packages/cli/src/ui/components/AboutBox.tsx +++ b/packages/cli/src/ui/components/AboutBox.tsx @@ -6,7 +6,7 @@ import type React from 'react'; import { Box, Text } from 'ink'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import { GIT_COMMIT_INFO } from '../../generated/git-commit.js'; interface AboutBoxProps { @@ -30,77 +30,77 @@ export const AboutBox: React.FC = ({ }) => ( - + About Gemini CLI - + CLI Version - {cliVersion} + {cliVersion} {GIT_COMMIT_INFO && 
!['N/A'].includes(GIT_COMMIT_INFO) && ( - + Git Commit - {GIT_COMMIT_INFO} + {GIT_COMMIT_INFO} )} - + Model - {modelVersion} + {modelVersion} - + Sandbox - {sandboxEnv} + {sandboxEnv} - + OS - {osVersion} + {osVersion} - + Auth Method - + {selectedAuthType.startsWith('oauth') ? 'OAuth' : selectedAuthType} @@ -108,24 +108,24 @@ export const AboutBox: React.FC = ({ {gcpProject && ( - + GCP Project - {gcpProject} + {gcpProject} )} {ideClient && ( - + IDE Client - {ideClient} + {ideClient} )} diff --git a/packages/cli/src/ui/components/AnsiOutput.test.tsx b/packages/cli/src/ui/components/AnsiOutput.test.tsx new file mode 100644 index 00000000000..b0cdd7373d7 --- /dev/null +++ b/packages/cli/src/ui/components/AnsiOutput.test.tsx @@ -0,0 +1,106 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { render } from 'ink-testing-library'; +import { AnsiOutputText } from './AnsiOutput.js'; +import type { AnsiOutput, AnsiToken } from '@blocksuser/gemini-cli-core'; + +// Helper to create a valid AnsiToken with default values +const createAnsiToken = (overrides: Partial): AnsiToken => ({ + text: '', + bold: false, + italic: false, + underline: false, + dim: false, + inverse: false, + fg: '#ffffff', + bg: '#000000', + ...overrides, +}); + +describe('', () => { + it('renders a simple AnsiOutput object correctly', () => { + const data: AnsiOutput = [ + [ + createAnsiToken({ text: 'Hello, ' }), + createAnsiToken({ text: 'world!' 
}), + ], + ]; + const { lastFrame } = render(); + expect(lastFrame()).toBe('Hello, world!'); + }); + + it('correctly applies all the styles', () => { + const data: AnsiOutput = [ + [ + createAnsiToken({ text: 'Bold', bold: true }), + createAnsiToken({ text: 'Italic', italic: true }), + createAnsiToken({ text: 'Underline', underline: true }), + createAnsiToken({ text: 'Dim', dim: true }), + createAnsiToken({ text: 'Inverse', inverse: true }), + ], + ]; + // Note: ink-testing-library doesn't render styles, so we can only check the text. + // We are testing that it renders without crashing. + const { lastFrame } = render(); + expect(lastFrame()).toBe('BoldItalicUnderlineDimInverse'); + }); + + it('correctly applies foreground and background colors', () => { + const data: AnsiOutput = [ + [ + createAnsiToken({ text: 'Red FG', fg: '#ff0000' }), + createAnsiToken({ text: 'Blue BG', bg: '#0000ff' }), + ], + ]; + // Note: ink-testing-library doesn't render colors, so we can only check the text. + // We are testing that it renders without crashing. 
+ const { lastFrame } = render(); + expect(lastFrame()).toBe('Red FGBlue BG'); + }); + + it('handles empty lines and empty tokens', () => { + const data: AnsiOutput = [ + [createAnsiToken({ text: 'First line' })], + [], + [createAnsiToken({ text: 'Third line' })], + [createAnsiToken({ text: '' })], + ]; + const { lastFrame } = render(); + const output = lastFrame(); + expect(output).toBeDefined(); + const lines = output!.split('\n'); + expect(lines[0]).toBe('First line'); + expect(lines[1]).toBe('Third line'); + }); + + it('respects the availableTerminalHeight prop and slices the lines correctly', () => { + const data: AnsiOutput = [ + [createAnsiToken({ text: 'Line 1' })], + [createAnsiToken({ text: 'Line 2' })], + [createAnsiToken({ text: 'Line 3' })], + [createAnsiToken({ text: 'Line 4' })], + ]; + const { lastFrame } = render( + , + ); + const output = lastFrame(); + expect(output).not.toContain('Line 1'); + expect(output).not.toContain('Line 2'); + expect(output).toContain('Line 3'); + expect(output).toContain('Line 4'); + }); + + it('renders a large AnsiOutput object without crashing', () => { + const largeData: AnsiOutput = []; + for (let i = 0; i < 1000; i++) { + largeData.push([createAnsiToken({ text: `Line ${i}` })]); + } + const { lastFrame } = render(); + // We are just checking that it renders something without crashing. 
+ expect(lastFrame()).toBeDefined(); + }); +}); diff --git a/packages/cli/src/ui/components/AnsiOutput.tsx b/packages/cli/src/ui/components/AnsiOutput.tsx new file mode 100644 index 00000000000..7ede94fb1be --- /dev/null +++ b/packages/cli/src/ui/components/AnsiOutput.tsx @@ -0,0 +1,46 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type React from 'react'; +import { Text } from 'ink'; +import type { AnsiLine, AnsiOutput, AnsiToken } from '@blocksuser/gemini-cli-core'; + +const DEFAULT_HEIGHT = 24; + +interface AnsiOutputProps { + data: AnsiOutput; + availableTerminalHeight?: number; +} + +export const AnsiOutputText: React.FC = ({ + data, + availableTerminalHeight, +}) => { + const lastLines = data.slice( + -(availableTerminalHeight && availableTerminalHeight > 0 + ? availableTerminalHeight + : DEFAULT_HEIGHT), + ); + return lastLines.map((line: AnsiLine, lineIndex: number) => ( + + {line.length > 0 + ? line.map((token: AnsiToken, tokenIndex: number) => ( + + {token.text} + + )) + : null} + + )); +}; diff --git a/packages/cli/src/ui/components/AppHeader.tsx b/packages/cli/src/ui/components/AppHeader.tsx index 685799111b4..d13af344ad8 100644 --- a/packages/cli/src/ui/components/AppHeader.tsx +++ b/packages/cli/src/ui/components/AppHeader.tsx @@ -18,15 +18,18 @@ interface AppHeaderProps { export const AppHeader = ({ version }: AppHeaderProps) => { const settings = useSettings(); const config = useConfig(); - const { nightly } = useUIState(); + const { nightly, isFolderTrustDialogOpen } = useUIState(); + const showTips = + !isFolderTrustDialogOpen && + !settings.merged.ui?.hideTips && + !config.getScreenReader(); + return ( {!(settings.merged.ui?.hideBanner || config.getScreenReader()) && (
)} - {!(settings.merged.ui?.hideTips || config.getScreenReader()) && ( - - )} + {showTips && } ); }; diff --git a/packages/cli/src/ui/components/AutoAcceptIndicator.tsx b/packages/cli/src/ui/components/AutoAcceptIndicator.tsx index a16a902db75..c60b9fb412a 100644 --- a/packages/cli/src/ui/components/AutoAcceptIndicator.tsx +++ b/packages/cli/src/ui/components/AutoAcceptIndicator.tsx @@ -6,7 +6,7 @@ import type React from 'react'; import { Box, Text } from 'ink'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import { ApprovalMode } from '@blocksuser/gemini-cli-core'; interface AutoAcceptIndicatorProps { @@ -22,12 +22,12 @@ export const AutoAcceptIndicator: React.FC = ({ switch (approvalMode) { case ApprovalMode.AUTO_EDIT: - textColor = Colors.AccentGreen; + textColor = theme.status.warning; textContent = 'accepting edits'; subText = ' (shift + tab to toggle)'; break; case ApprovalMode.YOLO: - textColor = Colors.AccentRed; + textColor = theme.status.error; textContent = 'YOLO mode'; subText = ' (ctrl + y to toggle)'; break; @@ -40,7 +40,7 @@ export const AutoAcceptIndicator: React.FC = ({ {textContent} - {subText && {subText}} + {subText && {subText}} ); diff --git a/packages/cli/src/ui/components/Composer.tsx b/packages/cli/src/ui/components/Composer.tsx index 64fa25e9235..e35795a9133 100644 --- a/packages/cli/src/ui/components/Composer.tsx +++ b/packages/cli/src/ui/components/Composer.tsx @@ -14,7 +14,7 @@ import { InputPrompt } from './InputPrompt.js'; import { Footer, type FooterProps } from './Footer.js'; import { ShowMoreLines } from './ShowMoreLines.js'; import { OverflowProvider } from '../contexts/OverflowContext.js'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import { isNarrowWidth } from '../utils/isNarrowWidth.js'; import { useUIState } from '../contexts/UIStateContext.js'; import { useUIActions } from '../contexts/UIActionsContext.js'; @@ -23,6 +23,7 @@ import { 
useConfig } from '../contexts/ConfigContext.js'; import { useSettings } from '../contexts/SettingsContext.js'; import { ApprovalMode } from '@blocksuser/gemini-cli-core'; import { StreamingState } from '../types.js'; +import { ConfigInitDisplay } from '../components/ConfigInitDisplay.js'; const MAX_DISPLAYED_QUEUED_MESSAGES = 3; @@ -57,20 +58,24 @@ export const Composer = () => { return ( - + {!uiState.shellFocused && ( + + )} + + {!uiState.isConfigInitialized && } {uiState.messageQueue.length > 0 && ( @@ -109,14 +114,18 @@ export const Composer = () => { > {process.env['GEMINI_SYSTEM_MD'] && ( - |⌐■_■| + |⌐■_■| )} {uiState.ctrlCPressedOnce ? ( - Press Ctrl+C again to exit. + + Press Ctrl+C again to exit. + ) : uiState.ctrlDPressedOnce ? ( - Press Ctrl+D again to exit. + + Press Ctrl+D again to exit. + ) : uiState.showEscapePrompt ? ( - Press Esc again to clear. + Press Esc again to clear. ) : ( !settings.merged.ui?.hideContextSummary && ( { commandContext={uiState.commandContext} shellModeActive={uiState.shellModeActive} setShellModeActive={uiActions.setShellModeActive} + approvalMode={showAutoAcceptIndicator} onEscapePromptChange={uiActions.onEscapePromptChange} focus={uiState.isFocused} vimHandleInput={uiActions.vimHandleInput} + isShellFocused={uiState.shellFocused} placeholder={ vimEnabled ? " Press 'i' for INSERT mode and 'Esc' for NORMAL mode." 
diff --git a/packages/cli/src/ui/components/ConfigInitDisplay.tsx b/packages/cli/src/ui/components/ConfigInitDisplay.tsx new file mode 100644 index 00000000000..63971ccd076 --- /dev/null +++ b/packages/cli/src/ui/components/ConfigInitDisplay.tsx @@ -0,0 +1,47 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { useEffect, useState } from 'react'; +import { appEvents } from './../../utils/events.js'; +import { Box, Text } from 'ink'; +import { useConfig } from '../contexts/ConfigContext.js'; +import { type McpClient, MCPServerStatus } from '@blocksuser/gemini-cli-core'; +import { GeminiSpinner } from './GeminiRespondingSpinner.js'; +import { theme } from '../semantic-colors.js'; + +export const ConfigInitDisplay = () => { + const config = useConfig(); + const [message, setMessage] = useState('Initializing...'); + + useEffect(() => { + const onChange = (clients?: Map) => { + if (!clients || clients.size === 0) { + setMessage(`Initializing...`); + return; + } + let connected = 0; + for (const client of clients.values()) { + if (client.getStatus() === MCPServerStatus.CONNECTED) { + connected++; + } + } + setMessage(`Connecting to MCP servers... 
(${connected}/${clients.size})`); + }; + + appEvents.on('mcp-client-update', onChange); + return () => { + appEvents.off('mcp-client-update', onChange); + }; + }, [config]); + + return ( + + + {message} + + + ); +}; diff --git a/packages/cli/src/ui/components/ConsoleSummaryDisplay.tsx b/packages/cli/src/ui/components/ConsoleSummaryDisplay.tsx index 68700f27040..2f2f8a2a757 100644 --- a/packages/cli/src/ui/components/ConsoleSummaryDisplay.tsx +++ b/packages/cli/src/ui/components/ConsoleSummaryDisplay.tsx @@ -6,7 +6,7 @@ import type React from 'react'; import { Box, Text } from 'ink'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; interface ConsoleSummaryDisplayProps { errorCount: number; @@ -25,9 +25,9 @@ export const ConsoleSummaryDisplay: React.FC = ({ return ( {errorCount > 0 && ( - + {errorIcon} {errorCount} error{errorCount > 1 ? 's' : ''}{' '} - (ctrl+o for details) + (ctrl+o for details) )} diff --git a/packages/cli/src/ui/components/ContextSummaryDisplay.tsx b/packages/cli/src/ui/components/ContextSummaryDisplay.tsx index 9cb9d97c96f..dec3626039b 100644 --- a/packages/cli/src/ui/components/ContextSummaryDisplay.tsx +++ b/packages/cli/src/ui/components/ContextSummaryDisplay.tsx @@ -6,7 +6,7 @@ import type React from 'react'; import { Box, Text } from 'ink'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import { type IdeContext, type MCPServerConfig } from '@blocksuser/gemini-cli-core'; import { useTerminalSize } from '../hooks/useTerminalSize.js'; import { isNarrowWidth } from '../utils/isNarrowWidth.js'; @@ -99,9 +99,9 @@ export const ContextSummaryDisplay: React.FC = ({ if (isNarrow) { return ( - Using: + Using: {summaryParts.map((part, index) => ( - + {' '}- {part} ))} @@ -111,7 +111,9 @@ export const ContextSummaryDisplay: React.FC = ({ return ( - Using: {summaryParts.join(' | ')} + + Using: {summaryParts.join(' | ')} + ); }; diff --git 
a/packages/cli/src/ui/components/ContextUsageDisplay.tsx b/packages/cli/src/ui/components/ContextUsageDisplay.tsx index 82be8aaa18d..73cedeb0a90 100644 --- a/packages/cli/src/ui/components/ContextUsageDisplay.tsx +++ b/packages/cli/src/ui/components/ContextUsageDisplay.tsx @@ -5,7 +5,7 @@ */ import { Text } from 'ink'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import { tokenLimit } from '@blocksuser/gemini-cli-core'; export const ContextUsageDisplay = ({ @@ -18,7 +18,7 @@ export const ContextUsageDisplay = ({ const percentage = promptTokenCount / tokenLimit(model); return ( - + ({((1 - percentage) * 100).toFixed(0)}% context left) ); diff --git a/packages/cli/src/ui/components/DebugProfiler.tsx b/packages/cli/src/ui/components/DebugProfiler.tsx index 22c16cfb22d..4a4d6b4c17b 100644 --- a/packages/cli/src/ui/components/DebugProfiler.tsx +++ b/packages/cli/src/ui/components/DebugProfiler.tsx @@ -6,7 +6,7 @@ import { Text } from 'ink'; import { useEffect, useRef, useState } from 'react'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import { useKeypress } from '../hooks/useKeypress.js'; export const DebugProfiler = () => { @@ -31,6 +31,6 @@ export const DebugProfiler = () => { } return ( - Renders: {numRenders.current} + Renders: {numRenders.current} ); }; diff --git a/packages/cli/src/ui/components/DetailedMessagesDisplay.tsx b/packages/cli/src/ui/components/DetailedMessagesDisplay.tsx index 454977220a0..b31d088005a 100644 --- a/packages/cli/src/ui/components/DetailedMessagesDisplay.tsx +++ b/packages/cli/src/ui/components/DetailedMessagesDisplay.tsx @@ -6,7 +6,7 @@ import type React from 'react'; import { Box, Text } from 'ink'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import type { ConsoleMessageItem } from '../types.js'; import { MaxSizedBox } from './shared/MaxSizedBox.js'; @@ -31,31 +31,32 @@ export const DetailedMessagesDisplay: 
React.FC< flexDirection="column" marginTop={1} borderStyle="round" - borderColor={Colors.Gray} + borderColor={theme.border.default} paddingX={1} width={width} > - - Debug Console (ctrl+o to close) + + Debug Console{' '} + (ctrl+o to close) {messages.map((msg, index) => { - let textColor = Colors.Foreground; + let textColor = theme.text.primary; let icon = '\u2139'; // Information source (ℹ) switch (msg.type) { case 'warn': - textColor = Colors.AccentYellow; + textColor = theme.status.warning; icon = '\u26A0'; // Warning sign (⚠) break; case 'error': - textColor = Colors.AccentRed; + textColor = theme.status.error; icon = '\u2716'; // Heavy multiplication x (✖) break; case 'debug': - textColor = Colors.Gray; // Or Colors.Gray + textColor = theme.text.secondary; // Or theme.text.secondary icon = '\u{1F50D}'; // Left-pointing magnifying glass (🔍) break; case 'log': @@ -70,7 +71,7 @@ export const DetailedMessagesDisplay: React.FC< {msg.content} {msg.count && msg.count > 1 && ( - (x{msg.count}) + (x{msg.count}) )} diff --git a/packages/cli/src/ui/components/DialogManager.tsx b/packages/cli/src/ui/components/DialogManager.tsx index 81b9f404256..aee6191d5a0 100644 --- a/packages/cli/src/ui/components/DialogManager.tsx +++ b/packages/cli/src/ui/components/DialogManager.tsx @@ -6,6 +6,7 @@ import { Box, Text } from 'ink'; import { IdeIntegrationNudge } from '../IdeIntegrationNudge.js'; +import { LoopDetectionConfirmation } from './LoopDetectionConfirmation.js'; import { FolderTrustDialog } from './FolderTrustDialog.js'; import { ShellConfirmationDialog } from './ShellConfirmationDialog.js'; import { RadioButtonSelect } from './shared/RadioButtonSelect.js'; @@ -17,12 +18,11 @@ import { EditorSettingsDialog } from './EditorSettingsDialog.js'; import { PrivacyNotice } from '../privacy/PrivacyNotice.js'; import { WorkspaceMigrationDialog } from './WorkspaceMigrationDialog.js'; import { ProQuotaDialog } from './ProQuotaDialog.js'; -import { Colors } from '../colors.js'; +import 
{ theme } from '../semantic-colors.js'; import { useUIState } from '../contexts/UIStateContext.js'; import { useUIActions } from '../contexts/UIActionsContext.js'; import { useConfig } from '../contexts/ConfigContext.js'; import { useSettings } from '../contexts/SettingsContext.js'; -import { DEFAULT_GEMINI_FLASH_MODEL } from '@blocksuser/gemini-cli-core'; import process from 'node:process'; // Props for DialogManager @@ -37,8 +37,8 @@ export const DialogManager = () => { if (uiState.showIdeRestartPrompt) { return ( - - + + Workspace trust has changed. Press 'r' to restart Gemini to apply the changes. @@ -54,11 +54,11 @@ export const DialogManager = () => { /> ); } - if (uiState.isProQuotaDialogOpen) { + if (uiState.proQuotaRequest) { return ( ); @@ -84,6 +84,13 @@ export const DialogManager = () => { ); } + if (uiState.loopDetectionConfirmationRequest) { + return ( + + ); + } if (uiState.confirmationRequest) { return ( @@ -107,7 +114,7 @@ export const DialogManager = () => { {uiState.themeError && ( - {uiState.themeError} + {uiState.themeError} )} { return ( { - /* This is now handled in AppContainer */ + uiActions.onAuthError('Authentication cancelled.'); }} /> ); @@ -160,7 +167,7 @@ export const DialogManager = () => { {uiState.editorError && ( - {uiState.editorError} + {uiState.editorError} )} {focusedSection === 'editor' ? '> ' : ' '}Select Editor{' '} - {otherScopeModifiedMessage} + {otherScopeModifiedMessage} ({ @@ -147,26 +147,28 @@ export function EditorSettingsDialog({ - + (Use Enter to select, Tab to change focus) - Editor Preference + + Editor Preference + - + These editors are currently supported. Please note that some editors cannot be used in sandbox mode. 
- + Your preferred editor is:{' '} diff --git a/packages/cli/src/ui/components/FolderTrustDialog.test.tsx b/packages/cli/src/ui/components/FolderTrustDialog.test.tsx index c583617aae5..649ae8967ce 100644 --- a/packages/cli/src/ui/components/FolderTrustDialog.test.tsx +++ b/packages/cli/src/ui/components/FolderTrustDialog.test.tsx @@ -98,7 +98,15 @@ describe('FolderTrustDialog', () => { }); }); - describe('parentFolder display', () => { + describe('directory display', () => { + it('should correctly display the folder name for a nested directory', () => { + mockedCwd.mockReturnValue('/home/user/project'); + const { lastFrame } = renderWithProviders( + , + ); + expect(lastFrame()).toContain('Trust folder (project)'); + }); + it('should correctly display the parent folder name for a nested directory', () => { mockedCwd.mockReturnValue('/home/user/project'); const { lastFrame } = renderWithProviders( diff --git a/packages/cli/src/ui/components/FolderTrustDialog.tsx b/packages/cli/src/ui/components/FolderTrustDialog.tsx index 0060187925a..83799a35068 100644 --- a/packages/cli/src/ui/components/FolderTrustDialog.tsx +++ b/packages/cli/src/ui/components/FolderTrustDialog.tsx @@ -6,7 +6,7 @@ import { Box, Text } from 'ink'; import type React from 'react'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import type { RadioSelectItem } from './shared/RadioButtonSelect.js'; import { RadioButtonSelect } from './shared/RadioButtonSelect.js'; import { useKeypress } from '../hooks/useKeypress.js'; @@ -46,11 +46,12 @@ export const FolderTrustDialog: React.FC = ({ { isActive: !!isRestarting }, ); + const dirName = path.basename(process.cwd()); const parentFolder = path.basename(path.dirname(process.cwd())); const options: Array> = [ { - label: 'Trust folder', + label: `Trust folder (${dirName})`, value: FolderTrustChoice.TRUST_FOLDER, }, { @@ -68,14 +69,16 @@ export const FolderTrustDialog: React.FC = ({ - Do you trust this folder? 
- + + Do you trust this folder? + + Trusting a folder allows Gemini to execute commands it suggests. This is a security feature to prevent accidental execution in untrusted directories. @@ -90,7 +93,7 @@ export const FolderTrustDialog: React.FC = ({ {isRestarting && ( - + To see changes, Gemini CLI must be restarted. Press r to exit and apply changes now. diff --git a/packages/cli/src/ui/components/Footer.tsx b/packages/cli/src/ui/components/Footer.tsx index 6df6ab369be..47345681617 100644 --- a/packages/cli/src/ui/components/Footer.tsx +++ b/packages/cli/src/ui/components/Footer.tsx @@ -138,7 +138,7 @@ export const Footer: React.FC = ({ {corgiMode && ( - | + | @@ -148,7 +148,7 @@ export const Footer: React.FC = ({ )} {!showErrorDetails && errorCount > 0 && ( - | + | )} diff --git a/packages/cli/src/ui/components/GeminiRespondingSpinner.tsx b/packages/cli/src/ui/components/GeminiRespondingSpinner.tsx index caf774e2a8d..cde45a3ec39 100644 --- a/packages/cli/src/ui/components/GeminiRespondingSpinner.tsx +++ b/packages/cli/src/ui/components/GeminiRespondingSpinner.tsx @@ -14,6 +14,7 @@ import { SCREEN_READER_LOADING, SCREEN_READER_RESPONDING, } from '../textConstants.js'; +import { theme } from '../semantic-colors.js'; interface GeminiRespondingSpinnerProps { /** @@ -30,17 +31,37 @@ export const GeminiRespondingSpinner: React.FC< const streamingState = useStreamingContext(); const isScreenReaderEnabled = useIsScreenReaderEnabled(); if (streamingState === StreamingState.Responding) { - return isScreenReaderEnabled ? ( - {SCREEN_READER_RESPONDING} - ) : ( - + return ( + ); } else if (nonRespondingDisplay) { return isScreenReaderEnabled ? 
( {SCREEN_READER_LOADING} ) : ( - {nonRespondingDisplay} + {nonRespondingDisplay} ); } return null; }; + +interface GeminiSpinnerProps { + spinnerType?: SpinnerName; + altText?: string; +} + +export const GeminiSpinner: React.FC = ({ + spinnerType = 'dots', + altText, +}) => { + const isScreenReaderEnabled = useIsScreenReaderEnabled(); + return isScreenReaderEnabled ? ( + {altText} + ) : ( + + + + ); +}; diff --git a/packages/cli/src/ui/components/Header.tsx b/packages/cli/src/ui/components/Header.tsx index 5d09ec3b105..694752312fc 100644 --- a/packages/cli/src/ui/components/Header.tsx +++ b/packages/cli/src/ui/components/Header.tsx @@ -7,7 +7,7 @@ import type React from 'react'; import { Box, Text } from 'ink'; import Gradient from 'ink-gradient'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import { shortAsciiLogo, longAsciiLogo, tinyAsciiLogo } from './AsciiArt.js'; import { getAsciiArtWidth } from '../utils/textUtils.js'; import { useTerminalSize } from '../hooks/useTerminalSize.js'; @@ -47,8 +47,8 @@ export const Header: React.FC = ({ flexShrink={0} flexDirection="column" > - {Colors.GradientColors ? ( - + {theme.ui.gradient ? ( + {displayTitle} ) : ( @@ -56,8 +56,8 @@ export const Header: React.FC = ({ )} {nightly && ( - {Colors.GradientColors ? ( - + {theme.ui.gradient ? 
( + v{version} ) : ( diff --git a/packages/cli/src/ui/components/Help.test.tsx b/packages/cli/src/ui/components/Help.test.tsx new file mode 100644 index 00000000000..ff749643ba6 --- /dev/null +++ b/packages/cli/src/ui/components/Help.test.tsx @@ -0,0 +1,63 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +/** @vitest-environment jsdom */ + +import { render } from 'ink-testing-library'; +import { describe, it, expect } from 'vitest'; +import { Help } from './Help.js'; +import type { SlashCommand } from '../commands/types.js'; +import { CommandKind } from '../commands/types.js'; + +const mockCommands: readonly SlashCommand[] = [ + { + name: 'test', + description: 'A test command', + kind: CommandKind.BUILT_IN, + }, + { + name: 'hidden', + description: 'A hidden command', + hidden: true, + kind: CommandKind.BUILT_IN, + }, + { + name: 'parent', + description: 'A parent command', + kind: CommandKind.BUILT_IN, + subCommands: [ + { + name: 'visible-child', + description: 'A visible child command', + kind: CommandKind.BUILT_IN, + }, + { + name: 'hidden-child', + description: 'A hidden child command', + hidden: true, + kind: CommandKind.BUILT_IN, + }, + ], + }, +]; + +describe('Help Component', () => { + it('should not render hidden commands', () => { + const { lastFrame } = render(); + const output = lastFrame(); + + expect(output).toContain('/test'); + expect(output).not.toContain('/hidden'); + }); + + it('should not render hidden subcommands', () => { + const { lastFrame } = render(); + const output = lastFrame(); + + expect(output).toContain('visible-child'); + expect(output).not.toContain('hidden-child'); + }); +}); diff --git a/packages/cli/src/ui/components/Help.tsx b/packages/cli/src/ui/components/Help.tsx index 124ab26081c..90a6b26a9c3 100644 --- a/packages/cli/src/ui/components/Help.tsx +++ b/packages/cli/src/ui/components/Help.tsx @@ -6,7 +6,7 @@ import type React from 'react'; import { Box, Text } from 'ink'; 
-import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import { type SlashCommand, CommandKind } from '../commands/types.js'; interface Help { @@ -17,42 +17,42 @@ export const Help: React.FC = ({ commands }) => ( {/* Basics */} - + Basics: - - + + Add context : Use{' '} - + @ {' '} to specify files for context (e.g.,{' '} - + @src/myFile.ts ) to target specific files or folders. - - + + Shell mode : Execute shell commands via{' '} - + ! {' '} (e.g.,{' '} - + !npm run start ) or use natural language (e.g.{' '} - + start server ). @@ -61,119 +61,121 @@ export const Help: React.FC = ({ commands }) => ( {/* Commands */} - + Commands: {commands - .filter((command) => command.description) + .filter((command) => command.description && !command.hidden) .map((command: SlashCommand) => ( - - + + {' '} /{command.name} {command.kind === CommandKind.MCP_PROMPT && ( - [MCP] + [MCP] )} {command.description && ' - ' + command.description} {command.subCommands && - command.subCommands.map((subCommand) => ( - - - {' '} - {subCommand.name} + command.subCommands + .filter((subCommand) => !subCommand.hidden) + .map((subCommand) => ( + + + {' '} + {subCommand.name} + + {subCommand.description && ' - ' + subCommand.description} - {subCommand.description && ' - ' + subCommand.description} - - ))} + ))} ))} - - + + {' '} !{' '} - shell command - - [MCP] - Model Context Protocol command - (from external servers) + + [MCP] - Model Context Protocol + command (from external servers) {/* Shortcuts */} - + Keyboard Shortcuts: - - + + Alt+Left/Right {' '} - Jump through words in the input - - + + Ctrl+C {' '} - Quit application - - + + {process.platform === 'win32' ? 'Ctrl+Enter' : 'Ctrl+J'} {' '} {process.platform === 'linux' ? '- New line (Alt+Enter works for certain linux distros)' : '- New line'} - - + + Ctrl+L {' '} - Clear the screen - - + + {process.platform === 'darwin' ? 
'Ctrl+X / Meta+Enter' : 'Ctrl+X'} {' '} - Open input in external editor - - + + Ctrl+Y {' '} - Toggle YOLO mode - - + + Enter {' '} - Send message - - + + Esc {' '} - Cancel operation / Clear input (double press) - - + + Shift+Tab {' '} - Toggle auto-accepting edits - - + + Up/Down {' '} - Cycle through your prompt history - + For a full list of shortcuts, see{' '} - + docs/keyboard-shortcuts.md diff --git a/packages/cli/src/ui/components/HistoryItemDisplay.tsx b/packages/cli/src/ui/components/HistoryItemDisplay.tsx index 9c08a5828fb..bb68b495506 100644 --- a/packages/cli/src/ui/components/HistoryItemDisplay.tsx +++ b/packages/cli/src/ui/components/HistoryItemDisplay.tsx @@ -30,6 +30,8 @@ interface HistoryItemDisplayProps { isPending: boolean; isFocused?: boolean; commands?: readonly SlashCommand[]; + activeShellPtyId?: number | null; + shellFocused?: boolean; } export const HistoryItemDisplay: React.FC = ({ @@ -39,6 +41,8 @@ export const HistoryItemDisplay: React.FC = ({ isPending, commands, isFocused = true, + activeShellPtyId, + shellFocused, }) => ( {/* Render standard message types */} @@ -85,6 +89,8 @@ export const HistoryItemDisplay: React.FC = ({ availableTerminalHeight={availableTerminalHeight} terminalWidth={terminalWidth} isFocused={isFocused} + activeShellPtyId={activeShellPtyId} + shellFocused={shellFocused} /> )} {item.type === 'compression' && ( diff --git a/packages/cli/src/ui/components/InputPrompt.test.tsx b/packages/cli/src/ui/components/InputPrompt.test.tsx index 2b5ac391d6f..754cae402f5 100644 --- a/packages/cli/src/ui/components/InputPrompt.test.tsx +++ b/packages/cli/src/ui/components/InputPrompt.test.tsx @@ -10,6 +10,7 @@ import type { InputPromptProps } from './InputPrompt.js'; import { InputPrompt } from './InputPrompt.js'; import type { TextBuffer } from './shared/text-buffer.js'; import type { Config } from '@blocksuser/gemini-cli-core'; +import { ApprovalMode } from '@blocksuser/gemini-cli-core'; import * as path from 'node:path'; 
import type { CommandContext, SlashCommand } from '../commands/types.js'; import { CommandKind } from '../commands/types.js'; @@ -207,6 +208,7 @@ describe('InputPrompt', () => { commandContext: mockCommandContext, shellModeActive: false, setShellModeActive: vi.fn(), + approvalMode: ApprovalMode.DEFAULT, inputWidth: 80, suggestionsWidth: 80, focus: true, @@ -1786,4 +1788,36 @@ describe('InputPrompt', () => { unmount(); }); }); + + describe('snapshots', () => { + it('should render correctly in shell mode', async () => { + props.shellModeActive = true; + const { stdout, unmount } = renderWithProviders( + , + ); + await wait(); + expect(stdout.lastFrame()).toMatchSnapshot(); + unmount(); + }); + + it('should render correctly when accepting edits', async () => { + props.approvalMode = ApprovalMode.AUTO_EDIT; + const { stdout, unmount } = renderWithProviders( + , + ); + await wait(); + expect(stdout.lastFrame()).toMatchSnapshot(); + unmount(); + }); + + it('should render correctly in yolo mode', async () => { + props.approvalMode = ApprovalMode.YOLO; + const { stdout, unmount } = renderWithProviders( + , + ); + await wait(); + expect(stdout.lastFrame()).toMatchSnapshot(); + unmount(); + }); + }); }); diff --git a/packages/cli/src/ui/components/InputPrompt.tsx b/packages/cli/src/ui/components/InputPrompt.tsx index 887ee3dbe48..b922d68bcac 100644 --- a/packages/cli/src/ui/components/InputPrompt.tsx +++ b/packages/cli/src/ui/components/InputPrompt.tsx @@ -23,6 +23,7 @@ import { useKeypress } from '../hooks/useKeypress.js'; import { keyMatchers, Command } from '../keyMatchers.js'; import type { CommandContext, SlashCommand } from '../commands/types.js'; import type { Config } from '@blocksuser/gemini-cli-core'; +import { ApprovalMode } from '@blocksuser/gemini-cli-core'; import { parseInputForHighlighting } from '../utils/highlight.js'; import { clipboardHasImage, @@ -46,8 +47,10 @@ export interface InputPromptProps { suggestionsWidth: number; shellModeActive: boolean; 
setShellModeActive: (value: boolean) => void; + approvalMode: ApprovalMode; onEscapePromptChange?: (showPrompt: boolean) => void; vimHandleInput?: (key: Key) => boolean; + isShellFocused?: boolean; } export const InputPrompt: React.FC = ({ @@ -64,8 +67,10 @@ export const InputPrompt: React.FC = ({ suggestionsWidth, shellModeActive, setShellModeActive, + approvalMode, onEscapePromptChange, vimHandleInput, + isShellFocused, }) => { const [justNavigatedHistory, setJustNavigatedHistory] = useState(false); const [escPressCount, setEscPressCount] = useState(0); @@ -588,7 +593,7 @@ export const InputPrompt: React.FC = ({ ); useKeypress(handleInput, { - isActive: true, + isActive: !isShellFocused, }); const linesToRender = buffer.viewportVisualLines; @@ -709,17 +714,36 @@ export const InputPrompt: React.FC = ({ const { inlineGhost, additionalLines } = getGhostTextLines(); + const showAutoAcceptStyling = + !shellModeActive && approvalMode === ApprovalMode.AUTO_EDIT; + const showYoloStyling = + !shellModeActive && approvalMode === ApprovalMode.YOLO; + + let statusColor: string | undefined; + let statusText = ''; + if (shellModeActive) { + statusColor = theme.ui.symbol; + statusText = 'Shell mode'; + } else if (showYoloStyling) { + statusColor = theme.status.error; + statusText = 'YOLO mode'; + } else if (showAutoAcceptStyling) { + statusColor = theme.status.warning; + statusText = 'Accepting edits'; + } + return ( <> {shellModeActive ? ( reverseSearchActive ? ( @@ -730,11 +754,13 @@ export const InputPrompt: React.FC = ({ (r:){' '} ) : ( - '! ' + '!' ) + ) : showYoloStyling ? ( + '*' ) : ( - '> ' - )} + '>' + )}{' '} {buffer.text.length === 0 && placeholder ? ( @@ -797,7 +823,7 @@ export const InputPrompt: React.FC = ({ const color = token.type === 'command' || token.type === 'file' ? 
theme.text.accent - : undefined; + : theme.text.primary; renderedLine.push( diff --git a/packages/cli/src/ui/components/LoadingIndicator.tsx b/packages/cli/src/ui/components/LoadingIndicator.tsx index 00f54e0dc62..dc19958a974 100644 --- a/packages/cli/src/ui/components/LoadingIndicator.tsx +++ b/packages/cli/src/ui/components/LoadingIndicator.tsx @@ -7,7 +7,7 @@ import type { ThoughtSummary } from '@blocksuser/gemini-cli-core'; import type React from 'react'; import { Box, Text } from 'ink'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import { useStreamingContext } from '../contexts/StreamingContext.js'; import { StreamingState } from '../types.js'; import { GeminiRespondingSpinner } from './GeminiRespondingSpinner.js'; @@ -61,11 +61,9 @@ export const LoadingIndicator: React.FC = ({ } /> - {primaryText && ( - {primaryText} - )} + {primaryText && {primaryText}} {!isNarrow && cancelAndTimerContent && ( - {cancelAndTimerContent} + {cancelAndTimerContent} )} {!isNarrow && {/* Spacer */}} @@ -73,7 +71,7 @@ export const LoadingIndicator: React.FC = ({ {isNarrow && cancelAndTimerContent && ( - {cancelAndTimerContent} + {cancelAndTimerContent} )} {isNarrow && rightContent && {rightContent}} diff --git a/packages/cli/src/ui/components/LoopDetectionConfirmation.test.tsx b/packages/cli/src/ui/components/LoopDetectionConfirmation.test.tsx new file mode 100644 index 00000000000..87b57033ee2 --- /dev/null +++ b/packages/cli/src/ui/components/LoopDetectionConfirmation.test.tsx @@ -0,0 +1,34 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { renderWithProviders } from '../../test-utils/render.js'; +import { describe, it, expect, vi } from 'vitest'; +import { LoopDetectionConfirmation } from './LoopDetectionConfirmation.js'; + +describe('LoopDetectionConfirmation', () => { + const onComplete = vi.fn(); + + it('renders correctly', () => { + const { lastFrame } = renderWithProviders( 
+ , + ); + expect(lastFrame()).toMatchSnapshot(); + }); + + it('contains the expected options', () => { + const { lastFrame } = renderWithProviders( + , + ); + const output = lastFrame()!.toString(); + + expect(output).toContain('A potential loop was detected'); + expect(output).toContain('Keep loop detection enabled (esc)'); + expect(output).toContain('Disable loop detection for this session'); + expect(output).toContain( + 'This can happen due to repetitive tool calls or other model behavior', + ); + }); +}); diff --git a/packages/cli/src/ui/components/LoopDetectionConfirmation.tsx b/packages/cli/src/ui/components/LoopDetectionConfirmation.tsx new file mode 100644 index 00000000000..c644c8866af --- /dev/null +++ b/packages/cli/src/ui/components/LoopDetectionConfirmation.tsx @@ -0,0 +1,88 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { Box, Text } from 'ink'; +import type { RadioSelectItem } from './shared/RadioButtonSelect.js'; +import { RadioButtonSelect } from './shared/RadioButtonSelect.js'; +import { useKeypress } from '../hooks/useKeypress.js'; +import { theme } from '../semantic-colors.js'; + +export type LoopDetectionConfirmationResult = { + userSelection: 'disable' | 'keep'; +}; + +interface LoopDetectionConfirmationProps { + onComplete: (result: LoopDetectionConfirmationResult) => void; +} + +export function LoopDetectionConfirmation({ + onComplete, +}: LoopDetectionConfirmationProps) { + useKeypress( + (key) => { + if (key.name === 'escape') { + onComplete({ + userSelection: 'keep', + }); + } + }, + { isActive: true }, + ); + + const OPTIONS: Array> = [ + { + label: 'Keep loop detection enabled (esc)', + value: { + userSelection: 'keep', + }, + }, + { + label: 'Disable loop detection for this session', + value: { + userSelection: 'disable', + }, + }, + ]; + + return ( + + + + + + ? 
+ + + + + + A potential loop was detected + {' '} + + + + + + + This can happen due to repetitive tool calls or other model + behavior. Do you want to keep loop detection enabled or disable it + for this session? + + + + + + + + + ); +} diff --git a/packages/cli/src/ui/components/MainContent.tsx b/packages/cli/src/ui/components/MainContent.tsx index ff63d9f7670..ea6ccda6122 100644 --- a/packages/cli/src/ui/components/MainContent.tsx +++ b/packages/cli/src/ui/components/MainContent.tsx @@ -54,6 +54,8 @@ export const MainContent = () => { item={{ ...item, id: 0 }} isPending={true} isFocused={!uiState.isEditorDialogOpen} + activeShellPtyId={uiState.activePtyId} + shellFocused={uiState.shellFocused} /> ))} diff --git a/packages/cli/src/ui/components/MemoryUsageDisplay.tsx b/packages/cli/src/ui/components/MemoryUsageDisplay.tsx index a8b449e5247..1ea3c10693d 100644 --- a/packages/cli/src/ui/components/MemoryUsageDisplay.tsx +++ b/packages/cli/src/ui/components/MemoryUsageDisplay.tsx @@ -7,20 +7,24 @@ import type React from 'react'; import { useEffect, useState } from 'react'; import { Box, Text } from 'ink'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import process from 'node:process'; import { formatMemoryUsage } from '../utils/formatters.js'; export const MemoryUsageDisplay: React.FC = () => { const [memoryUsage, setMemoryUsage] = useState(''); - const [memoryUsageColor, setMemoryUsageColor] = useState(Colors.Gray); + const [memoryUsageColor, setMemoryUsageColor] = useState( + theme.text.secondary, + ); useEffect(() => { const updateMemory = () => { const usage = process.memoryUsage().rss; setMemoryUsage(formatMemoryUsage(usage)); setMemoryUsageColor( - usage >= 2 * 1024 * 1024 * 1024 ? Colors.AccentRed : Colors.Gray, + usage >= 2 * 1024 * 1024 * 1024 + ? 
theme.status.error + : theme.text.secondary, ); }; const intervalId = setInterval(updateMemory, 2000); @@ -30,7 +34,7 @@ export const MemoryUsageDisplay: React.FC = () => { return ( - | + | {memoryUsage} ); diff --git a/packages/cli/src/ui/components/ModelStatsDisplay.tsx b/packages/cli/src/ui/components/ModelStatsDisplay.tsx index 2316cf8f58c..95a8fe46055 100644 --- a/packages/cli/src/ui/components/ModelStatsDisplay.tsx +++ b/packages/cli/src/ui/components/ModelStatsDisplay.tsx @@ -6,7 +6,7 @@ import type React from 'react'; import { Box, Text } from 'ink'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import { formatDuration } from '../utils/formatters.js'; import { calculateAverageLatency, @@ -34,13 +34,16 @@ const StatRow: React.FC = ({ }) => ( - + {isSubtle ? ` ↳ ${title}` : title} {values.map((value, index) => ( - {value} + {value} ))} @@ -57,11 +60,13 @@ export const ModelStatsDisplay: React.FC = () => { return ( - No API calls have been made in this session. + + No API calls have been made in this session. + ); } @@ -83,12 +88,12 @@ export const ModelStatsDisplay: React.FC = () => { return ( - + Model Stats For Nerds @@ -96,11 +101,15 @@ export const ModelStatsDisplay: React.FC = () => { {/* Header */} - Metric + + Metric + {modelNames.map((name) => ( - {name} + + {name} + ))} @@ -112,6 +121,7 @@ export const ModelStatsDisplay: React.FC = () => { borderTop={false} borderLeft={false} borderRight={false} + borderColor={theme.border.default} /> {/* API Section */} @@ -127,7 +137,7 @@ export const ModelStatsDisplay: React.FC = () => { return ( 0 ? Colors.AccentRed : Colors.Foreground + m.api.totalErrors > 0 ? 
theme.status.error : theme.text.primary } > {m.api.totalErrors.toLocaleString()} ({errorRate.toFixed(1)}%) @@ -150,7 +160,7 @@ export const ModelStatsDisplay: React.FC = () => { ( - + {m.tokens.total.toLocaleString()} ))} @@ -167,7 +177,7 @@ export const ModelStatsDisplay: React.FC = () => { values={getModelValues((m) => { const cacheHitRate = calculateCacheHitRate(m); return ( - + {m.tokens.cached.toLocaleString()} ({cacheHitRate.toFixed(1)}%) ); diff --git a/packages/cli/src/ui/components/Notifications.tsx b/packages/cli/src/ui/components/Notifications.tsx index 954945d3ada..a287b105246 100644 --- a/packages/cli/src/ui/components/Notifications.tsx +++ b/packages/cli/src/ui/components/Notifications.tsx @@ -7,7 +7,7 @@ import { Box, Text } from 'ink'; import { useAppContext } from '../contexts/AppContext.js'; import { useUIState } from '../contexts/UIStateContext.js'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import { StreamingState } from '../types.js'; import { UpdateNotification } from './UpdateNotification.js'; @@ -29,13 +29,13 @@ export const Notifications = () => { {showStartupWarnings && ( {startupWarnings.map((warning, index) => ( - + {warning} ))} @@ -44,14 +44,14 @@ export const Notifications = () => { {showInitError && ( - + Initialization Error: {initError} - + {' '} Please check API key and configuration. 
diff --git a/packages/cli/src/ui/components/PrepareLabel.tsx b/packages/cli/src/ui/components/PrepareLabel.tsx index d89c1fe4803..37ad5a33136 100644 --- a/packages/cli/src/ui/components/PrepareLabel.tsx +++ b/packages/cli/src/ui/components/PrepareLabel.tsx @@ -6,7 +6,7 @@ import type React from 'react'; import { Text } from 'ink'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; interface PrepareLabelProps { label: string; @@ -21,7 +21,7 @@ export const PrepareLabel: React.FC = ({ matchedIndex, userInput, textColor, - highlightColor = Colors.AccentYellow, + highlightColor = theme.status.warning, }) => { if ( matchedIndex === undefined || diff --git a/packages/cli/src/ui/components/ProQuotaDialog.test.tsx b/packages/cli/src/ui/components/ProQuotaDialog.test.tsx index 31bb4f03f67..c3a1afda6a0 100644 --- a/packages/cli/src/ui/components/ProQuotaDialog.test.tsx +++ b/packages/cli/src/ui/components/ProQuotaDialog.test.tsx @@ -22,7 +22,7 @@ describe('ProQuotaDialog', () => { it('should render with correct title and options', () => { const { lastFrame } = render( {}} />, @@ -53,7 +53,7 @@ describe('ProQuotaDialog', () => { const mockOnChoice = vi.fn(); render( , @@ -72,7 +72,7 @@ describe('ProQuotaDialog', () => { const mockOnChoice = vi.fn(); render( , diff --git a/packages/cli/src/ui/components/ProQuotaDialog.tsx b/packages/cli/src/ui/components/ProQuotaDialog.tsx index d94d0698571..17e5965e0bc 100644 --- a/packages/cli/src/ui/components/ProQuotaDialog.tsx +++ b/packages/cli/src/ui/components/ProQuotaDialog.tsx @@ -7,16 +7,16 @@ import type React from 'react'; import { Box, Text } from 'ink'; import { RadioButtonSelect } from './shared/RadioButtonSelect.js'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; interface ProQuotaDialogProps { - currentModel: string; + failedModel: string; fallbackModel: string; onChoice: (choice: 'auth' | 'continue') => void; } export function ProQuotaDialog({ - 
currentModel, + failedModel, fallbackModel, onChoice, }: ProQuotaDialogProps): React.JSX.Element { @@ -37,8 +37,8 @@ export function ProQuotaDialog({ return ( - - Pro quota limit reached for {currentModel}. + + Pro quota limit reached for {failedModel}. { - const actual = await vi.importActual('../contexts/SettingsContext.js'); - let settings = createMockSettings({ 'a.string.setting': 'initial' }); +vi.mock('../../config/settingsSchema.js', async (importOriginal) => { + const original = + await importOriginal(); return { - ...actual, - useSettings: () => ({ - settings, - setSetting: (key: string, value: string) => { - settings = createMockSettings({ [key]: value }); - }, - getSettingDefinition: (key: string) => { - if (key === 'a.string.setting') { - return { - type: 'string', - description: 'A string setting', - }; - } - return undefined; - }, - }), + ...original, + getSettingsSchema: vi.fn(original.getSettingsSchema), }; }); @@ -136,7 +139,6 @@ describe('SettingsDialog', () => { const wait = (ms = 50) => new Promise((resolve) => setTimeout(resolve, ms)); beforeEach(() => { - vi.clearAllMocks(); // Reset keypress mock state (variables are commented out) // currentKeypressHandler = null; // isKeypressActive = false; @@ -146,6 +148,9 @@ describe('SettingsDialog', () => { }); afterEach(() => { + TEST_ONLY.clearFlattenedSchema(); + vi.clearAllMocks(); + vi.resetAllMocks(); // Reset keypress mock state (variables are commented out) // currentKeypressHandler = null; // isKeypressActive = false; @@ -153,44 +158,6 @@ describe('SettingsDialog', () => { // console.error = originalConsoleError; }); - const createMockSettings = ( - userSettings = {}, - systemSettings = {}, - workspaceSettings = {}, - ) => - new LoadedSettings( - { - settings: { - ui: { customThemes: {} }, - mcpServers: {}, - ...systemSettings, - }, - path: '/system/settings.json', - }, - { - settings: {}, - path: '/system/system-defaults.json', - }, - { - settings: { - ui: { customThemes: {} }, - mcpServers: 
{}, - ...userSettings, - }, - path: '/user/settings.json', - }, - { - settings: { - ui: { customThemes: {} }, - mcpServers: {}, - ...workspaceSettings, - }, - path: '/workspace/settings.json', - }, - true, - new Set(), - ); - describe('Initial Rendering', () => { it('should render the settings dialog with default state', () => { const settings = createMockSettings(); @@ -244,15 +211,18 @@ describe('SettingsDialog', () => { const settings = createMockSettings(); const onSelect = vi.fn(); - const { stdin, unmount } = render( + const { stdin, unmount, lastFrame } = render( , ); // Press down arrow - stdin.write('\u001B[B'); // Down arrow - await wait(); + act(() => { + stdin.write(TerminalKeys.DOWN_ARROW as string); // Down arrow + }); + + expect(lastFrame()).toContain('● Disable Auto Update'); // The active index should have changed (tested indirectly through behavior) unmount(); @@ -269,9 +239,9 @@ describe('SettingsDialog', () => { ); // First go down, then up - stdin.write('\u001B[B'); // Down arrow + stdin.write(TerminalKeys.DOWN_ARROW as string); // Down arrow await wait(); - stdin.write('\u001B[A'); // Up arrow + stdin.write(TerminalKeys.UP_ARROW as string); await wait(); unmount(); @@ -296,21 +266,25 @@ describe('SettingsDialog', () => { unmount(); }); - it('should not navigate beyond bounds', async () => { + it('wraps around when at the top of the list', async () => { const settings = createMockSettings(); const onSelect = vi.fn(); - const { stdin, unmount } = render( + const { stdin, unmount, lastFrame } = render( , ); // Try to go up from first item - stdin.write('\u001B[A'); // Up arrow + act(() => { + stdin.write(TerminalKeys.UP_ARROW); + }); + await wait(); - // Should still be on first item + expect(lastFrame()).toContain('● Folder Trust'); + unmount(); }); }); @@ -319,20 +293,142 @@ describe('SettingsDialog', () => { it('should toggle setting with Enter key', async () => { const settings = createMockSettings(); const onSelect = vi.fn(); - - const { 
stdin, unmount } = render( + const component = ( - , + ); + const { stdin, unmount } = render(component); + // Press Enter to toggle current setting - stdin.write('\u000D'); // Enter key + stdin.write(TerminalKeys.DOWN_ARROW as string); + await wait(); + stdin.write(TerminalKeys.ENTER as string); await wait(); + expect(vi.mocked(saveModifiedSettings)).toHaveBeenCalledWith( + new Set(['general.disableAutoUpdate']), + { + general: { + disableAutoUpdate: true, + }, + }, + expect.any(LoadedSettings), + SettingScope.User, + ); + unmount(); }); + describe('enum values', () => { + enum StringEnum { + FOO = 'foo', + BAR = 'bar', + BAZ = 'baz', + } + + const SETTING: SettingDefinition = { + type: 'enum', + label: 'Theme', + options: [ + { + label: 'Foo', + value: StringEnum.FOO, + }, + { + label: 'Bar', + value: StringEnum.BAR, + }, + { + label: 'Baz', + value: StringEnum.BAZ, + }, + ], + category: 'UI', + requiresRestart: false, + default: StringEnum.BAR, + description: 'The color theme for the UI.', + showInDialog: true, + }; + + const FAKE_SCHEMA: SettingsSchemaType = { + ui: { + showInDialog: false, + properties: { + theme: { + ...SETTING, + }, + }, + }, + } as unknown as SettingsSchemaType; + + it('toggles enum values with the enter key', async () => { + vi.mocked(getSettingsSchema).mockReturnValue(FAKE_SCHEMA); + const settings = createMockSettings(); + const onSelect = vi.fn(); + const component = ( + + + + ); + + const { stdin, unmount } = render(component); + + // Press Enter to toggle current setting + stdin.write(TerminalKeys.DOWN_ARROW as string); + await wait(); + stdin.write(TerminalKeys.ENTER as string); + await wait(); + + expect(vi.mocked(saveModifiedSettings)).toHaveBeenCalledWith( + new Set(['ui.theme']), + { + ui: { + theme: StringEnum.BAZ, + }, + }, + expect.any(LoadedSettings), + SettingScope.User, + ); + + unmount(); + }); + + it('loops back when reaching the end of an enum', async () => { + vi.mocked(getSettingsSchema).mockReturnValue(FAKE_SCHEMA); + 
const settings = createMockSettings(); + settings.setValue(SettingScope.User, 'ui.theme', StringEnum.BAZ); + const onSelect = vi.fn(); + const component = ( + + + + ); + + const { stdin, unmount } = render(component); + + // Press Enter to toggle current setting + stdin.write(TerminalKeys.DOWN_ARROW as string); + await wait(); + stdin.write(TerminalKeys.ENTER as string); + await wait(); + + expect(vi.mocked(saveModifiedSettings)).toHaveBeenCalledWith( + new Set(['ui.theme']), + { + ui: { + theme: StringEnum.FOO, + }, + }, + expect.any(LoadedSettings), + SettingScope.User, + ); + + unmount(); + }); + }); + it('should toggle setting with Space key', async () => { const settings = createMockSettings(); const onSelect = vi.fn(); @@ -362,7 +458,7 @@ describe('SettingsDialog', () => { // Navigate to vim mode setting and toggle it // This would require knowing the exact position, so we'll just test that the mock is called - stdin.write('\u000D'); // Enter key + stdin.write(TerminalKeys.ENTER as string); // Enter key await wait(); // The mock should potentially be called if vim mode was toggled @@ -382,7 +478,7 @@ describe('SettingsDialog', () => { ); // Switch to scope focus - stdin.write('\t'); // Tab key + stdin.write(TerminalKeys.TAB); // Tab key await wait(); // Select different scope (numbers 1-3 typically available) @@ -502,7 +598,7 @@ describe('SettingsDialog', () => { ); // Switch to scope selector - stdin.write('\t'); // Tab + stdin.write(TerminalKeys.TAB as string); // Tab await wait(); // Change scope @@ -547,7 +643,7 @@ describe('SettingsDialog', () => { ); // Try to toggle a setting (this might trigger vim mode toggle) - stdin.write('\u000D'); // Enter + stdin.write(TerminalKeys.ENTER as string); // Enter await wait(); // Should not crash @@ -567,13 +663,13 @@ describe('SettingsDialog', () => { ); // Toggle a setting - stdin.write('\u000D'); // Enter + stdin.write(TerminalKeys.ENTER as string); // Enter await wait(); // Toggle another setting - 
stdin.write('\u001B[B'); // Down + stdin.write(TerminalKeys.DOWN_ARROW as string); // Down await wait(); - stdin.write('\u000D'); // Enter + stdin.write(TerminalKeys.ENTER as string); // Enter await wait(); // Should track multiple modified settings @@ -592,7 +688,7 @@ describe('SettingsDialog', () => { // Navigate down many times to test scrolling for (let i = 0; i < 10; i++) { - stdin.write('\u001B[B'); // Down arrow + stdin.write(TerminalKeys.DOWN_ARROW as string); // Down arrow await wait(10); } @@ -615,7 +711,7 @@ describe('SettingsDialog', () => { // Navigate to and toggle vim mode setting // This would require knowing the exact position of vim mode setting - stdin.write('\u000D'); // Enter + stdin.write(TerminalKeys.ENTER as string); // Enter await wait(); unmount(); @@ -653,7 +749,7 @@ describe('SettingsDialog', () => { ); // Toggle a non-restart-required setting (like hideTips) - stdin.write('\u000D'); // Enter - toggle current setting + stdin.write(TerminalKeys.ENTER as string); // Enter - toggle current setting await wait(); // Should save immediately without showing restart prompt @@ -750,8 +846,8 @@ describe('SettingsDialog', () => { // Rapid navigation for (let i = 0; i < 5; i++) { - stdin.write('\u001B[B'); // Down arrow - stdin.write('\u001B[A'); // Up arrow + stdin.write(TerminalKeys.DOWN_ARROW as string); + stdin.write(TerminalKeys.UP_ARROW as string); } await wait(100); @@ -806,9 +902,9 @@ describe('SettingsDialog', () => { ); // Try to navigate when potentially at bounds - stdin.write('\u001B[B'); // Down + stdin.write(TerminalKeys.DOWN_ARROW as string); await wait(); - stdin.write('\u001B[A'); // Up + stdin.write(TerminalKeys.UP_ARROW as string); await wait(); unmount(); @@ -917,19 +1013,19 @@ describe('SettingsDialog', () => { ); // Toggle first setting (should require restart) - stdin.write('\u000D'); // Enter + stdin.write(TerminalKeys.ENTER as string); // Enter await wait(); // Navigate to next setting and toggle it (should not require 
restart - e.g., vimMode) - stdin.write('\u001B[B'); // Down + stdin.write(TerminalKeys.DOWN_ARROW as string); // Down await wait(); - stdin.write('\u000D'); // Enter + stdin.write(TerminalKeys.ENTER as string); // Enter await wait(); // Navigate to another setting and toggle it (should also require restart) - stdin.write('\u001B[B'); // Down + stdin.write(TerminalKeys.DOWN_ARROW as string); // Down await wait(); - stdin.write('\u000D'); // Enter + stdin.write(TerminalKeys.ENTER as string); // Enter await wait(); // The test verifies that all changes are preserved and the dialog still works @@ -948,13 +1044,13 @@ describe('SettingsDialog', () => { ); // Multiple scope changes - stdin.write('\t'); // Tab to scope + stdin.write(TerminalKeys.TAB as string); // Tab to scope await wait(); stdin.write('2'); // Workspace await wait(); - stdin.write('\t'); // Tab to settings + stdin.write(TerminalKeys.TAB as string); // Tab to settings await wait(); - stdin.write('\t'); // Tab to scope + stdin.write(TerminalKeys.TAB as string); // Tab to scope await wait(); stdin.write('1'); // User await wait(); diff --git a/packages/cli/src/ui/components/SettingsDialog.tsx b/packages/cli/src/ui/components/SettingsDialog.tsx index 8cf6e83a64b..17a56e2de06 100644 --- a/packages/cli/src/ui/components/SettingsDialog.tsx +++ b/packages/cli/src/ui/components/SettingsDialog.tsx @@ -6,7 +6,7 @@ import React, { useState, useEffect } from 'react'; import { Box, Text } from 'ink'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import type { LoadedSettings, Settings } from '../../config/settings.js'; import { SettingScope } from '../../config/settings.js'; import { @@ -16,7 +16,6 @@ import { import { RadioButtonSelect } from './shared/RadioButtonSelect.js'; import { getDialogSettingKeys, - getSettingValue, setPendingSettingValue, getDisplayValue, hasRestartRequiredSettings, @@ -28,11 +27,16 @@ import { getDefaultValue, setPendingSettingValueAny, 
getNestedValue, + getEffectiveValue, } from '../../utils/settingsUtils.js'; import { useVimMode } from '../contexts/VimModeContext.js'; import { useKeypress } from '../hooks/useKeypress.js'; import chalk from 'chalk'; import { cpSlice, cpLen, stripUnsafeCharacters } from '../utils/textUtils.js'; +import { + type SettingsValue, + TOGGLE_TYPES, +} from '../../config/settingsSchema.js'; interface SettingsDialogProps { settings: LoadedSettings; @@ -122,15 +126,33 @@ export function SettingsDialog({ value: key, type: definition?.type, toggle: () => { - if (definition?.type !== 'boolean') { - // For non-boolean items, toggle will be handled via edit mode. + if (!TOGGLE_TYPES.has(definition?.type)) { return; } - const currentValue = getSettingValue(key, pendingSettings, {}); - const newValue = !currentValue; + const currentValue = getEffectiveValue(key, pendingSettings, {}); + let newValue: SettingsValue; + if (definition?.type === 'boolean') { + newValue = !(currentValue as boolean); + setPendingSettings((prev) => + setPendingSettingValue(key, newValue as boolean, prev), + ); + } else if (definition?.type === 'enum' && definition.options) { + const options = definition.options; + const currentIndex = options?.findIndex( + (opt) => opt.value === currentValue, + ); + if (currentIndex !== -1 && currentIndex < options.length - 1) { + newValue = options[currentIndex + 1].value; + } else { + newValue = options[0].value; // loop back to start. + } + setPendingSettings((prev) => + setPendingSettingValueAny(key, newValue, prev), + ); + } setPendingSettings((prev) => - setPendingSettingValue(key, newValue, prev), + setPendingSettingValue(key, newValue as boolean, prev), ); if (!requiresRestart(key)) { @@ -634,18 +656,18 @@ export function SettingsDialog({ return ( - + Settings - {showScrollUp && } + {showScrollUp && } {visibleItems.map((item, idx) => { const isActive = focusSection === 'settings' && @@ -726,17 +748,21 @@ export function SettingsDialog({ - + {isActive ? 
'●' : ''} {item.label} {scopeMessage && ( - {scopeMessage} + {scopeMessage} )} @@ -744,10 +770,10 @@ export function SettingsDialog({ {displayValue} @@ -757,7 +783,7 @@ export function SettingsDialog({ ); })} - {showScrollDown && } + {showScrollDown && } @@ -776,11 +802,11 @@ export function SettingsDialog({ - + (Use Enter to select, Tab to change focus) {showRestartPrompt && ( - + To see changes, Gemini CLI must be restarted. Press r to exit and apply changes now. diff --git a/packages/cli/src/ui/components/ShellConfirmationDialog.tsx b/packages/cli/src/ui/components/ShellConfirmationDialog.tsx index 7baeaca9f4a..fa8e566c36c 100644 --- a/packages/cli/src/ui/components/ShellConfirmationDialog.tsx +++ b/packages/cli/src/ui/components/ShellConfirmationDialog.tsx @@ -7,7 +7,7 @@ import { ToolConfirmationOutcome } from '@blocksuser/gemini-cli-core'; import { Box, Text } from 'ink'; import type React from 'react'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import { RenderInline } from '../utils/InlineMarkdownRenderer.js'; import type { RadioSelectItem } from './shared/RadioButtonSelect.js'; import { RadioButtonSelect } from './shared/RadioButtonSelect.js'; @@ -68,23 +68,27 @@ export const ShellConfirmationDialog: React.FC< - Shell Command Execution - A custom command wants to run the following shell commands: + + Shell Command Execution + + + A custom command wants to run the following shell commands: + {commands.map((cmd) => ( - + ))} @@ -92,7 +96,7 @@ export const ShellConfirmationDialog: React.FC< - Do you want to proceed? + Do you want to proceed? 
diff --git a/packages/cli/src/ui/components/ShellInputPrompt.tsx b/packages/cli/src/ui/components/ShellInputPrompt.tsx new file mode 100644 index 00000000000..4c9858eb0ea --- /dev/null +++ b/packages/cli/src/ui/components/ShellInputPrompt.tsx @@ -0,0 +1,57 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { useCallback } from 'react'; +import type React from 'react'; +import { useKeypress } from '../hooks/useKeypress.js'; +import { ShellExecutionService } from '@blocksuser/gemini-cli-core'; +import { keyToAnsi, type Key } from '../hooks/keyToAnsi.js'; + +export interface ShellInputPromptProps { + activeShellPtyId: number | null; + focus?: boolean; +} + +export const ShellInputPrompt: React.FC = ({ + activeShellPtyId, + focus = true, +}) => { + const handleShellInputSubmit = useCallback( + (input: string) => { + if (activeShellPtyId) { + ShellExecutionService.writeToPty(activeShellPtyId, input); + } + }, + [activeShellPtyId], + ); + + const handleInput = useCallback( + (key: Key) => { + if (!focus || !activeShellPtyId) { + return; + } + if (key.ctrl && key.shift && key.name === 'up') { + ShellExecutionService.scrollPty(activeShellPtyId, -1); + return; + } + + if (key.ctrl && key.shift && key.name === 'down') { + ShellExecutionService.scrollPty(activeShellPtyId, 1); + return; + } + + const ansiSequence = keyToAnsi(key); + if (ansiSequence) { + handleShellInputSubmit(ansiSequence); + } + }, + [focus, handleShellInputSubmit, activeShellPtyId], + ); + + useKeypress(handleInput, { isActive: focus }); + + return null; +}; diff --git a/packages/cli/src/ui/components/ShellModeIndicator.tsx b/packages/cli/src/ui/components/ShellModeIndicator.tsx index 23d7174017b..10370d2e553 100644 --- a/packages/cli/src/ui/components/ShellModeIndicator.tsx +++ b/packages/cli/src/ui/components/ShellModeIndicator.tsx @@ -6,13 +6,13 @@ import type React from 'react'; import { Box, Text } from 'ink'; -import { Colors } from 
'../colors.js'; +import { theme } from '../semantic-colors.js'; export const ShellModeIndicator: React.FC = () => ( - + shell mode enabled - (esc to disable) + (esc to disable) ); diff --git a/packages/cli/src/ui/components/ShowMoreLines.tsx b/packages/cli/src/ui/components/ShowMoreLines.tsx index 41232d94834..8823eee6200 100644 --- a/packages/cli/src/ui/components/ShowMoreLines.tsx +++ b/packages/cli/src/ui/components/ShowMoreLines.tsx @@ -8,7 +8,7 @@ import { Box, Text } from 'ink'; import { useOverflowState } from '../contexts/OverflowContext.js'; import { useStreamingContext } from '../contexts/StreamingContext.js'; import { StreamingState } from '../types.js'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; interface ShowMoreLinesProps { constrainHeight: boolean; @@ -32,7 +32,7 @@ export const ShowMoreLines = ({ constrainHeight }: ShowMoreLinesProps) => { return ( - + Press ctrl-s to show more lines diff --git a/packages/cli/src/ui/components/StatsDisplay.tsx b/packages/cli/src/ui/components/StatsDisplay.tsx index dd9879d77af..8c7bacd7abb 100644 --- a/packages/cli/src/ui/components/StatsDisplay.tsx +++ b/packages/cli/src/ui/components/StatsDisplay.tsx @@ -47,7 +47,7 @@ const SubStatRow: React.FC = ({ title, children }) => ( {/* Adjust width for the "» " prefix */} - » {title} + » {title} {/* FIX: Apply the same flexGrow fix here */} {children} @@ -62,7 +62,9 @@ interface SectionProps { const Section: React.FC = ({ title, children }) => ( - {title} + + {title} + {children} ); @@ -82,16 +84,24 @@ const ModelUsageTable: React.FC<{ {/* Header */} - Model Usage + + Model Usage + - Reqs + + Reqs + - Input Tokens + + Input Tokens + - Output Tokens + + Output Tokens + {/* Divider */} @@ -101,6 +111,7 @@ const ModelUsageTable: React.FC<{ borderTop={false} borderLeft={false} borderRight={false} + borderColor={theme.border.default} width={nameWidth + requestsWidth + inputTokensWidth + outputTokensWidth} > @@ -108,10 +119,12 @@ 
const ModelUsageTable: React.FC<{ {Object.entries(models).map(([name, modelMetrics]) => ( - {name.replace('-001', '')} + {name.replace('-001', '')} - {modelMetrics.api.totalRequests} + + {modelMetrics.api.totalRequests} + @@ -127,7 +140,7 @@ const ModelUsageTable: React.FC<{ ))} {cacheEfficiency > 0 && ( - + Savings Highlight:{' '} {totalCachedTokens.toLocaleString()} ({cacheEfficiency.toFixed(1)} %) of input tokens were served from the cache, reducing costs. @@ -174,7 +187,9 @@ export const StatsDisplay: React.FC = ({ if (title) { return theme.ui.gradient && theme.ui.gradient.length > 0 ? ( - {title} + + {title} + ) : ( @@ -202,10 +217,10 @@ export const StatsDisplay: React.FC = ({
- {stats.sessionId} + {stats.sessionId} - + {tools.totalCalls} ({' '} ✓ {tools.totalSuccess}{' '} x {tools.totalFail} ) @@ -227,7 +242,7 @@ export const StatsDisplay: React.FC = ({ {files && (files.totalLinesAdded > 0 || files.totalLinesRemoved > 0) && ( - + +{files.totalLinesAdded} {' '} @@ -241,13 +256,15 @@ export const StatsDisplay: React.FC = ({
- {duration} + {duration} - {formatDuration(computed.agentActiveTime)} + + {formatDuration(computed.agentActiveTime)} + - + {formatDuration(computed.totalApiTime)}{' '} ({computed.apiTimePercent.toFixed(1)}%) @@ -255,7 +272,7 @@ export const StatsDisplay: React.FC = ({ - + {formatDuration(computed.totalToolTime)}{' '} ({computed.toolTimePercent.toFixed(1)}%) diff --git a/packages/cli/src/ui/components/SuggestionsDisplay.tsx b/packages/cli/src/ui/components/SuggestionsDisplay.tsx index b1be0e255f3..3cbb689e241 100644 --- a/packages/cli/src/ui/components/SuggestionsDisplay.tsx +++ b/packages/cli/src/ui/components/SuggestionsDisplay.tsx @@ -5,7 +5,7 @@ */ import { Box, Text } from 'ink'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import { PrepareLabel } from './PrepareLabel.js'; import { CommandKind } from '../commands/types.js'; export interface Suggestion { @@ -54,14 +54,22 @@ export function SuggestionsDisplay({ ); const visibleSuggestions = suggestions.slice(startIndex, endIndex); + const getFullLabel = (s: Suggestion) => + s.label + (s.commandKind === CommandKind.MCP_PROMPT ? ' [MCP]' : ''); + + const maxLabelLength = Math.max( + ...suggestions.map((s) => getFullLabel(s).length), + ); + const commandColumnWidth = Math.min(maxLabelLength, Math.floor(width * 0.5)); + return ( - {scrollOffset > 0 && } + {scrollOffset > 0 && } {visibleSuggestions.map((suggestion, index) => { const originalIndex = startIndex + index; const isActive = originalIndex === activeIndex; - const textColor = isActive ? Colors.AccentPurple : Colors.Gray; + const textColor = isActive ? theme.text.accent : theme.text.secondary; const labelElement = ( - - {(() => { - const isSlashCommand = userInput.startsWith('/'); - return ( - <> - {isSlashCommand ? 
( - - {labelElement} - {suggestion.commandKind === CommandKind.MCP_PROMPT && ( - [MCP] - )} - - ) : ( - labelElement - )} - {suggestion.description && ( - - - {suggestion.description} - - - )} - - ); - })()} + + + + {labelElement} + {suggestion.commandKind === CommandKind.MCP_PROMPT && ( + [MCP] + )} + + + {suggestion.description && ( + + + {suggestion.description} + + + )} ); })} diff --git a/packages/cli/src/ui/components/ThemeDialog.test.tsx b/packages/cli/src/ui/components/ThemeDialog.test.tsx new file mode 100644 index 00000000000..f2899e94a26 --- /dev/null +++ b/packages/cli/src/ui/components/ThemeDialog.test.tsx @@ -0,0 +1,98 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { render } from 'ink-testing-library'; +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { ThemeDialog } from './ThemeDialog.js'; +import { LoadedSettings } from '../../config/settings.js'; +import { KeypressProvider } from '../contexts/KeypressContext.js'; +import { SettingsContext } from '../contexts/SettingsContext.js'; +import { DEFAULT_THEME, themeManager } from '../themes/theme-manager.js'; +import { act } from 'react'; + +const createMockSettings = ( + userSettings = {}, + workspaceSettings = {}, + systemSettings = {}, +): LoadedSettings => + new LoadedSettings( + { + settings: { ui: { customThemes: {} }, ...systemSettings }, + path: '/system/settings.json', + }, + { + settings: {}, + path: '/system/system-defaults.json', + }, + { + settings: { + ui: { customThemes: {} }, + ...userSettings, + }, + path: '/user/settings.json', + }, + { + settings: { + ui: { customThemes: {} }, + ...workspaceSettings, + }, + path: '/workspace/settings.json', + }, + true, + new Set(), + ); + +describe('ThemeDialog Snapshots', () => { + const baseProps = { + onSelect: vi.fn(), + onHighlight: vi.fn(), + availableTerminalHeight: 40, + terminalWidth: 120, + }; + + beforeEach(() => { + // Reset theme manager to a 
known state + themeManager.setActiveTheme(DEFAULT_THEME.name); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + it('should render correctly in theme selection mode', () => { + const settings = createMockSettings(); + const { lastFrame } = render( + + + + + , + ); + + expect(lastFrame()).toMatchSnapshot(); + }); + + it('should render correctly in scope selector mode', async () => { + const settings = createMockSettings(); + const { lastFrame, stdin } = render( + + + + + , + ); + + // Press Tab to switch to scope selector mode + act(() => { + stdin.write('\t'); + }); + + // Need to wait for the state update to propagate + await new Promise((resolve) => setTimeout(resolve, 100)); + + expect(lastFrame()).toMatchSnapshot(); + }); +}); diff --git a/packages/cli/src/ui/components/ThemeDialog.tsx b/packages/cli/src/ui/components/ThemeDialog.tsx index f4729f624f4..497ba1c6541 100644 --- a/packages/cli/src/ui/components/ThemeDialog.tsx +++ b/packages/cli/src/ui/components/ThemeDialog.tsx @@ -7,18 +7,16 @@ import type React from 'react'; import { useCallback, useState } from 'react'; import { Box, Text } from 'ink'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import { themeManager, DEFAULT_THEME } from '../themes/theme-manager.js'; import { RadioButtonSelect } from './shared/RadioButtonSelect.js'; import { DiffRenderer } from './messages/DiffRenderer.js'; import { colorizeCode } from '../utils/CodeColorizer.js'; import type { LoadedSettings } from '../../config/settings.js'; import { SettingScope } from '../../config/settings.js'; -import { - getScopeItems, - getScopeMessageForSetting, -} from '../../utils/dialogScopeUtils.js'; +import { getScopeMessageForSetting } from '../../utils/dialogScopeUtils.js'; import { useKeypress } from '../hooks/useKeypress.js'; +import { ScopeSelector } from './shared/ScopeSelector.js'; interface ThemeDialogProps { /** Callback function when a theme is selected */ @@ -73,18 +71,14 @@ 
export function ThemeDialog({ themeTypeDisplay: 'Custom', })), ]; - const [selectInputKey, setSelectInputKey] = useState(Date.now()); // Find the index of the selected theme, but only if it exists in the list - const selectedThemeName = settings.merged.ui?.theme || DEFAULT_THEME.name; const initialThemeIndex = themeItems.findIndex( - (item) => item.value === selectedThemeName, + (item) => item.value === highlightedThemeName, ); // If not found, fall back to the first theme const safeInitialThemeIndex = initialThemeIndex >= 0 ? initialThemeIndex : 0; - const scopeItems = getScopeItems(); - const handleThemeSelect = useCallback( (themeName: string) => { onSelect(themeName, selectedScope); @@ -99,25 +93,21 @@ export function ThemeDialog({ const handleScopeHighlight = useCallback((scope: SettingScope) => { setSelectedScope(scope); - setSelectInputKey(Date.now()); }, []); const handleScopeSelect = useCallback( (scope: SettingScope) => { - handleScopeHighlight(scope); - setFocusedSection('theme'); // Reset focus to theme section + onSelect(highlightedThemeName, scope); }, - [handleScopeHighlight], + [onSelect, highlightedThemeName], ); - const [focusedSection, setFocusedSection] = useState<'theme' | 'scope'>( - 'theme', - ); + const [mode, setMode] = useState<'theme' | 'scope'>('theme'); useKeypress( (key) => { if (key.name === 'tab') { - setFocusedSection((prev) => (prev === 'theme' ? 'scope' : 'theme')); + setMode((prev) => (prev === 'theme' ? 'scope' : 'theme')); } if (key.name === 'escape') { onSelect(undefined, selectedScope); @@ -152,20 +142,13 @@ export function ThemeDialog({ const DIALOG_PADDING = 2; const selectThemeHeight = themeItems.length + 1; - const SCOPE_SELECTION_HEIGHT = 4; // Height for the scope selection section + margin. - const SPACE_BETWEEN_THEME_SELECTION_AND_APPLY_TO = 1; const TAB_TO_SELECT_HEIGHT = 2; availableTerminalHeight = availableTerminalHeight ?? Number.MAX_SAFE_INTEGER; availableTerminalHeight -= 2; // Top and bottom borders. 
availableTerminalHeight -= TAB_TO_SELECT_HEIGHT; - let totalLeftHandSideHeight = - DIALOG_PADDING + - selectThemeHeight + - SCOPE_SELECTION_HEIGHT + - SPACE_BETWEEN_THEME_SELECTION_AND_APPLY_TO; + let totalLeftHandSideHeight = DIALOG_PADDING + selectThemeHeight; - let showScopeSelection = true; let includePadding = true; // Remove content from the LHS that can be omitted if it exceeds the available height. @@ -174,15 +157,6 @@ export function ThemeDialog({ totalLeftHandSideHeight -= DIALOG_PADDING; } - if (totalLeftHandSideHeight > availableTerminalHeight) { - // First, try hiding the scope selection - totalLeftHandSideHeight -= SCOPE_SELECTION_HEIGHT; - showScopeSelection = false; - } - - // Don't focus the scope selection if it is hidden due to height constraints. - const currentFocusedSection = !showScopeSelection ? 'theme' : focusedSection; - // Vertical space taken by elements other than the two code blocks in the preview pane. // Includes "Preview" title, borders, and margin between blocks. const PREVIEW_PANE_FIXED_VERTICAL_SPACE = 8; @@ -209,7 +183,7 @@ export function ThemeDialog({ return ( - - {/* Left Column: Selection */} - - - {currentFocusedSection === 'theme' ? '> ' : ' '}Select Theme{' '} - {otherScopeModifiedMessage} - - - - {/* Scope Selection */} - {showScopeSelection && ( - - - {currentFocusedSection === 'scope' ? '> ' : ' '}Apply To + {mode === 'theme' ? ( + + {/* Left Column: Selection */} + + + {mode === 'theme' ? 
'> ' : ' '}Select Theme{' '} + + {otherScopeModifiedMessage} - - - )} - + + + - {/* Right Column: Preview */} - - Preview - {/* Get the Theme object for the highlighted theme, fall back to default if not found */} - {(() => { - const previewTheme = - themeManager.getTheme( - highlightedThemeName || DEFAULT_THEME.name, - ) || DEFAULT_THEME; - return ( - - {colorizeCode( - `# function + {/* Right Column: Preview */} + + + Preview + + {/* Get the Theme object for the highlighted theme, fall back to default if not found */} + {(() => { + const previewTheme = + themeManager.getTheme( + highlightedThemeName || DEFAULT_THEME.name, + ) || DEFAULT_THEME; + return ( + + {colorizeCode( + `# function def fibonacci(n): a, b = 0, 1 for _ in range(n): a, b = b, a + b return a`, - 'python', - codeBlockHeight, - colorizeCodeWidth, - )} - - + - - ); - })()} + availableTerminalHeight={diffHeight} + terminalWidth={colorizeCodeWidth} + theme={previewTheme} + /> + + ); + })()} + - + ) : ( + + )} - - (Use Enter to select - {showScopeSelection ? ', Tab to change focus' : ''}) + + (Use Enter to {mode === 'theme' ? 'select' : 'apply scope'}, Tab to{' '} + {mode === 'theme' ? 'configure scope' : 'select theme'}) diff --git a/packages/cli/src/ui/components/Tips.tsx b/packages/cli/src/ui/components/Tips.tsx index 5d8865259f3..d4f07c8544c 100644 --- a/packages/cli/src/ui/components/Tips.tsx +++ b/packages/cli/src/ui/components/Tips.tsx @@ -6,7 +6,7 @@ import type React from 'react'; import { Box, Text } from 'ink'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import { type Config } from '@blocksuser/gemini-cli-core'; interface TipsProps { @@ -17,25 +17,25 @@ export const Tips: React.FC = ({ config }) => { const geminiMdFileCount = config.getGeminiMdFileCount(); return ( - Tips for getting started: - + Tips for getting started: + 1. Ask questions, edit files, or run commands. - + 2. Be specific for the best results. {geminiMdFileCount === 0 && ( - + 3. 
Create{' '} - + GEMINI.md {' '} files to customize your interactions with Gemini. )} - + {geminiMdFileCount === 0 ? '4.' : '3.'}{' '} - + /help {' '} for more information. diff --git a/packages/cli/src/ui/components/ToolStatsDisplay.tsx b/packages/cli/src/ui/components/ToolStatsDisplay.tsx index fe016ecb484..0e5df14e816 100644 --- a/packages/cli/src/ui/components/ToolStatsDisplay.tsx +++ b/packages/cli/src/ui/components/ToolStatsDisplay.tsx @@ -6,7 +6,7 @@ import type React from 'react'; import { Box, Text } from 'ink'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import { formatDuration } from '../utils/formatters.js'; import { getStatusColor, @@ -37,16 +37,16 @@ const StatRow: React.FC<{ return ( - {name} + {name} - {stats.count} + {stats.count} {successRate.toFixed(1)}% - {formatDuration(avgDuration)} + {formatDuration(avgDuration)} ); @@ -63,11 +63,13 @@ export const ToolStatsDisplay: React.FC = () => { return ( - No tool calls have been made in this session. + + No tool calls have been made in this session. 
+ ); } @@ -94,13 +96,13 @@ export const ToolStatsDisplay: React.FC = () => { return ( - + Tool Stats For Nerds @@ -108,16 +110,24 @@ export const ToolStatsDisplay: React.FC = () => { {/* Header */} - Tool Name + + Tool Name + - Calls + + Calls + - Success Rate + + Success Rate + - Avg Duration + + Avg Duration + @@ -128,6 +138,7 @@ export const ToolStatsDisplay: React.FC = () => { borderTop={false} borderLeft={false} borderRight={false} + borderColor={theme.border.default} width="100%" /> @@ -139,45 +150,47 @@ export const ToolStatsDisplay: React.FC = () => { {/* User Decision Summary */} - User Decision Summary + + User Decision Summary + - Total Reviewed Suggestions: + Total Reviewed Suggestions: - {totalReviewed} + {totalReviewed} - » Accepted: + » Accepted: - {totalDecisions.accept} + {totalDecisions.accept} - » Rejected: + » Rejected: - {totalDecisions.reject} + {totalDecisions.reject} - » Modified: + » Modified: - {totalDecisions.modify} + {totalDecisions.modify} @@ -188,6 +201,7 @@ export const ToolStatsDisplay: React.FC = () => { borderTop={false} borderLeft={false} borderRight={false} + borderColor={theme.border.default} width="100%" /> @@ -195,7 +209,7 @@ export const ToolStatsDisplay: React.FC = () => { - Overall Agreement Rate: + Overall Agreement Rate: 0 ? 
agreementColor : undefined}> diff --git a/packages/cli/src/ui/components/UpdateNotification.tsx b/packages/cli/src/ui/components/UpdateNotification.tsx index b88c9bd5d02..8142a2018ed 100644 --- a/packages/cli/src/ui/components/UpdateNotification.tsx +++ b/packages/cli/src/ui/components/UpdateNotification.tsx @@ -5,7 +5,7 @@ */ import { Box, Text } from 'ink'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; interface UpdateNotificationProps { message: string; @@ -14,10 +14,10 @@ interface UpdateNotificationProps { export const UpdateNotification = ({ message }: UpdateNotificationProps) => ( - {message} + {message} ); diff --git a/packages/cli/src/ui/components/WorkspaceMigrationDialog.tsx b/packages/cli/src/ui/components/WorkspaceMigrationDialog.tsx index c53de1cfb40..e642407db59 100644 --- a/packages/cli/src/ui/components/WorkspaceMigrationDialog.tsx +++ b/packages/cli/src/ui/components/WorkspaceMigrationDialog.tsx @@ -10,7 +10,7 @@ import { performWorkspaceExtensionMigration, } from '../../config/extension.js'; import { RadioButtonSelect } from './shared/RadioButtonSelect.js'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import { useState } from 'react'; export function WorkspaceMigrationDialog(props: { @@ -40,15 +40,15 @@ export function WorkspaceMigrationDialog(props: { {failedExtensions.length > 0 ? ( <> - + The following extensions failed to migrate. Please try installing them manually. To see other changes, Gemini CLI must be restarted. - Press {"'q'"} to quit. + Press 'q' to quit. {failedExtensions.map((failed) => ( @@ -57,9 +57,9 @@ export function WorkspaceMigrationDialog(props: { ) : ( - + Migration complete. To see changes, Gemini CLI must be restarted. - Press {"'q'"} to quit. + Press 'q' to quit. )} @@ -70,15 +70,19 @@ export function WorkspaceMigrationDialog(props: { - Workspace-level extensions are deprecated{'\n'} - Would you like to install them at the user level? 
- + + Workspace-level extensions are deprecated{'\n'} + + + Would you like to install them at the user level? + + The extension definition will remain in your workspace directory. - + If you opt to skip, you can install them manually using the extensions install command. diff --git a/packages/cli/src/ui/components/__snapshots__/InputPrompt.test.tsx.snap b/packages/cli/src/ui/components/__snapshots__/InputPrompt.test.tsx.snap new file mode 100644 index 00000000000..b955931d417 --- /dev/null +++ b/packages/cli/src/ui/components/__snapshots__/InputPrompt.test.tsx.snap @@ -0,0 +1,19 @@ +// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html + +exports[`InputPrompt > snapshots > should render correctly in shell mode 1`] = ` +"╭──────────────────────────────────────────────────────────────────────────────────────────────────╮ +│ ! Type your message or @path/to/file │ +╰──────────────────────────────────────────────────────────────────────────────────────────────────╯" +`; + +exports[`InputPrompt > snapshots > should render correctly in yolo mode 1`] = ` +"╭──────────────────────────────────────────────────────────────────────────────────────────────────╮ +│ * Type your message or @path/to/file │ +╰──────────────────────────────────────────────────────────────────────────────────────────────────╯" +`; + +exports[`InputPrompt > snapshots > should render correctly when accepting edits 1`] = ` +"╭──────────────────────────────────────────────────────────────────────────────────────────────────╮ +│ > Type your message or @path/to/file │ +╰──────────────────────────────────────────────────────────────────────────────────────────────────╯" +`; diff --git a/packages/cli/src/ui/components/__snapshots__/LoopDetectionConfirmation.test.tsx.snap b/packages/cli/src/ui/components/__snapshots__/LoopDetectionConfirmation.test.tsx.snap new file mode 100644 index 00000000000..1686c4bc138 --- /dev/null +++ 
b/packages/cli/src/ui/components/__snapshots__/LoopDetectionConfirmation.test.tsx.snap @@ -0,0 +1,13 @@ +// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html + +exports[`LoopDetectionConfirmation > renders correctly 1`] = ` +" ╭──────────────────────────────────────────────────────────────────────────────────────────────────╮ + │ ? A potential loop was detected │ + │ │ + │ This can happen due to repetitive tool calls or other model behavior. Do you want to keep loop │ + │ detection enabled or disable it for this session? │ + │ │ + │ ● 1. Keep loop detection enabled (esc) │ + │ 2. Disable loop detection for this session │ + ╰──────────────────────────────────────────────────────────────────────────────────────────────────╯" +`; diff --git a/packages/cli/src/ui/components/__snapshots__/ThemeDialog.test.tsx.snap b/packages/cli/src/ui/components/__snapshots__/ThemeDialog.test.tsx.snap new file mode 100644 index 00000000000..b205bba4a5b --- /dev/null +++ b/packages/cli/src/ui/components/__snapshots__/ThemeDialog.test.tsx.snap @@ -0,0 +1,38 @@ +// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html + +exports[`ThemeDialog Snapshots > should render correctly in scope selector mode 1`] = ` +"╭──────────────────────────────────────────────────────────────────────────────────────────────────╮ +│ │ +│ > Apply To │ +│ ● 1. User Settings │ +│ 2. Workspace Settings │ +│ 3. System Settings │ +│ │ +│ (Use Enter to apply scope, Tab to select theme) │ +│ │ +╰──────────────────────────────────────────────────────────────────────────────────────────────────╯" +`; + +exports[`ThemeDialog Snapshots > should render correctly in theme selection mode 1`] = ` +"╭──────────────────────────────────────────────────────────────────────────────────────────────────╮ +│ │ +│ > Select Theme Preview │ +│ ▲ ┌─────────────────────────────────────────────────┐ │ +│ 1. ANSI Dark │ │ │ +│ 2. Atom One Dark │ 1 # function │ │ +│ 3. Ayu Dark │ 2 def fibonacci(n): │ │ +│ ● 4. 
Default Dark │ 3 a, b = 0, 1 │ │ +│ 5. Dracula Dark │ 4 for _ in range(n): │ │ +│ 6. GitHub Dark │ 5 a, b = b, a + b │ │ +│ 7. Shades Of Purple Dark │ 6 return a │ │ +│ 8. ANSI Light Light │ │ │ +│ 9. Ayu Light Light │ 1 - print("Hello, " + name) │ │ +│ 10. Default Light Light │ 1 + print(f"Hello, {name}!") │ │ +│ 11. GitHub Light Light │ │ │ +│ 12. Google Code Light └─────────────────────────────────────────────────┘ │ +│ ▼ │ +│ │ +│ (Use Enter to select, Tab to configure scope) │ +│ │ +╰──────────────────────────────────────────────────────────────────────────────────────────────────╯" +`; diff --git a/packages/cli/src/ui/components/messages/CompressionMessage.test.tsx b/packages/cli/src/ui/components/messages/CompressionMessage.test.tsx new file mode 100644 index 00000000000..9c5e33f3b24 --- /dev/null +++ b/packages/cli/src/ui/components/messages/CompressionMessage.test.tsx @@ -0,0 +1,198 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { render } from 'ink-testing-library'; +import type { CompressionDisplayProps } from './CompressionMessage.js'; +import { CompressionMessage } from './CompressionMessage.js'; +import { CompressionStatus } from '@blocksuser/gemini-cli-core'; +import type { CompressionProps } from '../../types.js'; +import { describe, it, expect } from 'vitest'; + +describe('', () => { + const createCompressionProps = ( + overrides: Partial = {}, + ): CompressionDisplayProps => ({ + compression: { + isPending: false, + originalTokenCount: null, + newTokenCount: null, + compressionStatus: CompressionStatus.COMPRESSED, + ...overrides, + }, + }); + + describe('pending state', () => { + it('renders pending message when compression is in progress', () => { + const props = createCompressionProps({ isPending: true }); + const { lastFrame } = render(); + const output = lastFrame(); + + expect(output).toContain('Compressing chat history'); + }); + }); + + describe('normal compression (successful token 
reduction)', () => { + it('renders success message when tokens are reduced', () => { + const props = createCompressionProps({ + isPending: false, + originalTokenCount: 100, + newTokenCount: 50, + compressionStatus: CompressionStatus.COMPRESSED, + }); + const { lastFrame } = render(); + const output = lastFrame(); + + expect(output).toContain('✦'); + expect(output).toContain( + 'Chat history compressed from 100 to 50 tokens.', + ); + }); + + it('renders success message for large successful compressions', () => { + const testCases = [ + { original: 50000, new: 25000 }, // Large compression + { original: 700000, new: 350000 }, // Very large compression + ]; + + testCases.forEach(({ original, new: newTokens }) => { + const props = createCompressionProps({ + isPending: false, + originalTokenCount: original, + newTokenCount: newTokens, + compressionStatus: CompressionStatus.COMPRESSED, + }); + const { lastFrame } = render(); + const output = lastFrame(); + + expect(output).toContain('✦'); + expect(output).toContain( + `compressed from ${original} to ${newTokens} tokens`, + ); + expect(output).not.toContain('Skipping compression'); + expect(output).not.toContain('did not reduce size'); + }); + }); + }); + + describe('skipped compression (tokens increased or same)', () => { + it('renders skip message when compression would increase token count', () => { + const props = createCompressionProps({ + isPending: false, + originalTokenCount: 50, + newTokenCount: 75, + compressionStatus: + CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT, + }); + const { lastFrame } = render(); + const output = lastFrame(); + + expect(output).toContain('✦'); + expect(output).toContain( + 'Compression was not beneficial for this history size.', + ); + }); + + it('renders skip message when token counts are equal', () => { + const props = createCompressionProps({ + isPending: false, + originalTokenCount: 50, + newTokenCount: 50, + compressionStatus: + 
CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT, + }); + const { lastFrame } = render(); + const output = lastFrame(); + + expect(output).toContain( + 'Compression was not beneficial for this history size.', + ); + }); + }); + + describe('message content validation', () => { + it('displays correct compression statistics', () => { + const testCases = [ + { + original: 200, + new: 80, + expected: 'compressed from 200 to 80 tokens', + }, + { + original: 500, + new: 150, + expected: 'compressed from 500 to 150 tokens', + }, + { + original: 1500, + new: 400, + expected: 'compressed from 1500 to 400 tokens', + }, + ]; + + testCases.forEach(({ original, new: newTokens, expected }) => { + const props = createCompressionProps({ + isPending: false, + originalTokenCount: original, + newTokenCount: newTokens, + compressionStatus: CompressionStatus.COMPRESSED, + }); + const { lastFrame } = render(); + const output = lastFrame(); + + expect(output).toContain(expected); + }); + }); + + it('shows skip message for small histories when new tokens >= original tokens', () => { + const testCases = [ + { original: 50, new: 60 }, // Increased + { original: 100, new: 100 }, // Same + { original: 49999, new: 50000 }, // Just under 50k threshold + ]; + + testCases.forEach(({ original, new: newTokens }) => { + const props = createCompressionProps({ + isPending: false, + originalTokenCount: original, + newTokenCount: newTokens, + compressionStatus: + CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT, + }); + const { lastFrame } = render(); + const output = lastFrame(); + + expect(output).toContain( + 'Compression was not beneficial for this history size.', + ); + expect(output).not.toContain('compressed from'); + }); + }); + + it('shows compression failure message for large histories when new tokens >= original tokens', () => { + const testCases = [ + { original: 50000, new: 50100 }, // At 50k threshold + { original: 700000, new: 710000 }, // Large history case + { original: 
100000, new: 100000 }, // Large history, same count + ]; + + testCases.forEach(({ original, new: newTokens }) => { + const props = createCompressionProps({ + isPending: false, + originalTokenCount: original, + newTokenCount: newTokens, + compressionStatus: + CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT, + }); + const { lastFrame } = render(); + const output = lastFrame(); + + expect(output).toContain('compression did not reduce size'); + expect(output).not.toContain('compressed from'); + expect(output).not.toContain('Compression was not beneficial'); + }); + }); + }); +}); diff --git a/packages/cli/src/ui/components/messages/CompressionMessage.tsx b/packages/cli/src/ui/components/messages/CompressionMessage.tsx index 7663172e23c..8a4d7db0c97 100644 --- a/packages/cli/src/ui/components/messages/CompressionMessage.tsx +++ b/packages/cli/src/ui/components/messages/CompressionMessage.tsx @@ -4,12 +4,12 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type React from 'react'; import { Box, Text } from 'ink'; import type { CompressionProps } from '../../types.js'; import Spinner from 'ink-spinner'; -import { Colors } from '../../colors.js'; +import { theme } from '../../semantic-colors.js'; import { SCREEN_READER_MODEL_PREFIX } from '../../textConstants.js'; +import { CompressionStatus } from '@blocksuser/gemini-cli-core'; export interface CompressionDisplayProps { compression: CompressionProps; @@ -19,27 +19,55 @@ export interface CompressionDisplayProps { * Compression messages appear when the /compress command is run, and show a loading spinner * while compression is in progress, followed up by some compression stats. */ -export const CompressionMessage: React.FC = ({ +export function CompressionMessage({ compression, -}) => { - const text = compression.isPending - ? 'Compressing chat history' - : `Chat history compressed from ${compression.originalTokenCount ?? 'unknown'}` + - ` to ${compression.newTokenCount ?? 
'unknown'} tokens.`; +}: CompressionDisplayProps): React.JSX.Element { + const { isPending, originalTokenCount, newTokenCount, compressionStatus } = + compression; + + const originalTokens = originalTokenCount ?? 0; + const newTokens = newTokenCount ?? 0; + + const getCompressionText = () => { + if (isPending) { + return 'Compressing chat history'; + } + + switch (compressionStatus) { + case CompressionStatus.COMPRESSED: + return `Chat history compressed from ${originalTokens} to ${newTokens} tokens.`; + case CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT: + // For smaller histories (< 50k tokens), compression overhead likely exceeds benefits + if (originalTokens < 50000) { + return 'Compression was not beneficial for this history size.'; + } + // For larger histories where compression should work but didn't, + // this suggests an issue with the compression process itself + return 'Chat history compression did not reduce size. This may indicate issues with the compression prompt.'; + case CompressionStatus.COMPRESSION_FAILED_TOKEN_COUNT_ERROR: + return 'Could not compress chat history due to a token counting error.'; + case CompressionStatus.NOOP: + return 'Chat history is already compressed.'; + default: + return ''; + } + }; + + const text = getCompressionText(); return ( - {compression.isPending ? ( + {isPending ? 
( ) : ( - + )} @@ -48,4 +76,4 @@ export const CompressionMessage: React.FC = ({ ); -}; +} diff --git a/packages/cli/src/ui/components/messages/DiffRenderer.tsx b/packages/cli/src/ui/components/messages/DiffRenderer.tsx index f855c97f565..d962d683b88 100644 --- a/packages/cli/src/ui/components/messages/DiffRenderer.tsx +++ b/packages/cli/src/ui/components/messages/DiffRenderer.tsx @@ -6,11 +6,11 @@ import type React from 'react'; import { Box, Text, useIsScreenReaderEnabled } from 'ink'; -import { Colors } from '../../colors.js'; import crypto from 'node:crypto'; import { colorizeCode, colorizeLine } from '../../utils/CodeColorizer.js'; import { MaxSizedBox } from '../shared/MaxSizedBox.js'; -import { theme } from '../../semantic-colors.js'; +import { theme as semanticTheme } from '../../semantic-colors.js'; +import type { Theme } from '../../themes/theme.js'; interface DiffLine { type: 'add' | 'del' | 'context' | 'hunk' | 'other'; @@ -42,18 +42,9 @@ function parseDiffWithLineNumbers(diffContent: string): DiffLine[] { } if (!inHunk) { // Skip standard Git header lines more robustly - if ( - line.startsWith('--- ') || - line.startsWith('+++ ') || - line.startsWith('diff --git') || - line.startsWith('index ') || - line.startsWith('similarity index') || - line.startsWith('rename from') || - line.startsWith('rename to') || - line.startsWith('new file mode') || - line.startsWith('deleted file mode') - ) + if (line.startsWith('--- ')) { continue; + } // If it's not a hunk or header, skip (or handle as 'other' if needed) continue; } @@ -94,7 +85,7 @@ interface DiffRendererProps { tabWidth?: number; availableTerminalHeight?: number; terminalWidth: number; - theme?: import('../../themes/theme.js').Theme; + theme?: Theme; } const DEFAULT_TAB_WIDTH = 4; // Spaces per tab for normalization @@ -109,14 +100,18 @@ export const DiffRenderer: React.FC = ({ }) => { const screenReaderEnabled = useIsScreenReaderEnabled(); if (!diffContent || typeof diffContent !== 'string') { - return 
No diff content.; + return No diff content.; } const parsedLines = parseDiffWithLineNumbers(diffContent); if (parsedLines.length === 0) { return ( - + No changes detected. ); @@ -196,7 +191,11 @@ const renderDiffContent = ( if (displayableLines.length === 0) { return ( - + No changes detected. ); @@ -260,7 +259,7 @@ const renderDiffContent = ( ) { acc.push( - + {'═'.repeat(terminalWidth)} , @@ -301,12 +300,12 @@ const renderDiffContent = ( acc.push( @@ -323,16 +322,16 @@ const renderDiffContent = ( {prefixSymbol} diff --git a/packages/cli/src/ui/components/messages/ErrorMessage.tsx b/packages/cli/src/ui/components/messages/ErrorMessage.tsx index 52a03a89deb..71794ee41c9 100644 --- a/packages/cli/src/ui/components/messages/ErrorMessage.tsx +++ b/packages/cli/src/ui/components/messages/ErrorMessage.tsx @@ -6,7 +6,7 @@ import type React from 'react'; import { Text, Box } from 'ink'; -import { Colors } from '../../colors.js'; +import { theme } from '../../semantic-colors.js'; interface ErrorMessageProps { text: string; @@ -19,10 +19,10 @@ export const ErrorMessage: React.FC = ({ text }) => { return ( - {prefix} + {prefix} - + {text} diff --git a/packages/cli/src/ui/components/messages/GeminiMessage.tsx b/packages/cli/src/ui/components/messages/GeminiMessage.tsx index 9473c12885b..389b5ac1516 100644 --- a/packages/cli/src/ui/components/messages/GeminiMessage.tsx +++ b/packages/cli/src/ui/components/messages/GeminiMessage.tsx @@ -7,7 +7,7 @@ import type React from 'react'; import { Text, Box } from 'ink'; import { MarkdownDisplay } from '../../utils/MarkdownDisplay.js'; -import { Colors } from '../../colors.js'; +import { theme } from '../../semantic-colors.js'; import { SCREEN_READER_MODEL_PREFIX } from '../../textConstants.js'; interface GeminiMessageProps { @@ -29,10 +29,7 @@ export const GeminiMessage: React.FC = ({ return ( - + {prefix} diff --git a/packages/cli/src/ui/components/messages/InfoMessage.tsx b/packages/cli/src/ui/components/messages/InfoMessage.tsx 
index 3d7866bec08..e8d09d637ec 100644 --- a/packages/cli/src/ui/components/messages/InfoMessage.tsx +++ b/packages/cli/src/ui/components/messages/InfoMessage.tsx @@ -6,7 +6,7 @@ import type React from 'react'; import { Text, Box } from 'ink'; -import { Colors } from '../../colors.js'; +import { theme } from '../../semantic-colors.js'; import { RenderInline } from '../../utils/InlineMarkdownRenderer.js'; interface InfoMessageProps { @@ -20,10 +20,10 @@ export const InfoMessage: React.FC = ({ text }) => { return ( - {prefix} + {prefix} - + diff --git a/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx b/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx index 7701ae1f62e..11baabb1823 100644 --- a/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx +++ b/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx @@ -5,9 +5,9 @@ */ import type React from 'react'; +import { useEffect, useState } from 'react'; import { Box, Text } from 'ink'; import { DiffRenderer } from './DiffRenderer.js'; -import { Colors } from '../../colors.js'; import { RenderInline } from '../../utils/InlineMarkdownRenderer.js'; import type { ToolCallConfirmationDetails, @@ -20,6 +20,7 @@ import type { RadioSelectItem } from '../shared/RadioButtonSelect.js'; import { RadioButtonSelect } from '../shared/RadioButtonSelect.js'; import { MaxSizedBox } from '../shared/MaxSizedBox.js'; import { useKeypress } from '../../hooks/useKeypress.js'; +import { theme } from '../../semantic-colors.js'; export interface ToolConfirmationMessageProps { confirmationDetails: ToolCallConfirmationDetails; @@ -41,12 +42,31 @@ export const ToolConfirmationMessage: React.FC< const { onConfirm } = confirmationDetails; const childWidth = terminalWidth - 2; // 2 for padding + const [ideClient, setIdeClient] = useState(null); + const [isDiffingEnabled, setIsDiffingEnabled] = useState(false); + + useEffect(() => { + let isMounted = true; + if (config.getIdeMode()) { + 
const getIdeClient = async () => { + const client = await IdeClient.getInstance(); + if (isMounted) { + setIdeClient(client); + setIsDiffingEnabled(client?.isDiffingEnabled() ?? false); + } + }; + getIdeClient(); + } + return () => { + isMounted = false; + }; + }, [config]); + const handleConfirm = async (outcome: ToolConfirmationOutcome) => { if (confirmationDetails.type === 'edit') { - if (config.getIdeMode()) { + if (config.getIdeMode() && isDiffingEnabled) { const cliOutcome = outcome === ToolConfirmationOutcome.Cancel ? 'rejected' : 'accepted'; - const ideClient = await IdeClient.getInstance(); await ideClient?.resolveDiffFromCli( confirmationDetails.filePath, cliOutcome, @@ -113,13 +133,13 @@ export const ToolConfirmationMessage: React.FC< - Modify in progress: - + Modify in progress: + Save and close external editor to continue @@ -137,22 +157,18 @@ export const ToolConfirmationMessage: React.FC< value: ToolConfirmationOutcome.ProceedAlways, }); } - if (config.getIdeMode()) { - options.push({ - label: 'No (esc)', - value: ToolConfirmationOutcome.Cancel, - }); - } else { + if (!config.getIdeMode() || !isDiffingEnabled) { options.push({ label: 'Modify with external editor', value: ToolConfirmationOutcome.ModifyWithEditor, }); - options.push({ - label: 'No, suggest changes (esc)', - value: ToolConfirmationOutcome.Cancel, - }); } + options.push({ + label: 'No, suggest changes (esc)', + value: ToolConfirmationOutcome.Cancel, + }); + bodyContent = ( - {executionProps.command} + {executionProps.command} @@ -223,12 +239,12 @@ export const ToolConfirmationMessage: React.FC< bodyContent = ( - + {displayUrls && infoProps.urls && infoProps.urls.length > 0 && ( - URLs to fetch: + URLs to fetch: {infoProps.urls.map((url) => ( {' '} @@ -245,8 +261,8 @@ export const ToolConfirmationMessage: React.FC< bodyContent = ( - MCP Server: {mcpProps.serverName} - Tool: {mcpProps.toolName} + MCP Server: {mcpProps.serverName} + Tool: {mcpProps.toolName} ); @@ -281,7 +297,9 @@ export 
const ToolConfirmationMessage: React.FC< {/* Confirmation Question */} - {question} + + {question} + {/* Select Input for Options */} diff --git a/packages/cli/src/ui/components/messages/ToolGroupMessage.tsx b/packages/cli/src/ui/components/messages/ToolGroupMessage.tsx index bfa2db33430..b2b25cc8a14 100644 --- a/packages/cli/src/ui/components/messages/ToolGroupMessage.tsx +++ b/packages/cli/src/ui/components/messages/ToolGroupMessage.tsx @@ -21,6 +21,9 @@ interface ToolGroupMessageProps { availableTerminalHeight?: number; terminalWidth: number; isFocused?: boolean; + activeShellPtyId?: number | null; + shellFocused?: boolean; + onShellInputSubmit?: (input: string) => void; } // Main component renders the border and maps the tools using ToolMessage @@ -29,14 +32,26 @@ export const ToolGroupMessage: React.FC = ({ availableTerminalHeight, terminalWidth, isFocused = true, + activeShellPtyId, + shellFocused, }) => { - const config = useConfig(); + const isShellFocused = + shellFocused && + toolCalls.some( + (t) => + t.ptyId === activeShellPtyId && t.status === ToolCallStatus.Executing, + ); + const hasPending = !toolCalls.every( (t) => t.status === ToolCallStatus.Success, ); + + const config = useConfig(); const isShellCommand = toolCalls.some((t) => t.name === SHELL_COMMAND_NAME); const borderColor = - hasPending || isShellCommand ? theme.status.warning : theme.border.default; + hasPending || isShellCommand || isShellFocused + ? theme.status.warning + : theme.border.default; const staticHeight = /* border */ 2 + /* marginBottom */ 1; // This is a bit of a magic number, but it accounts for the border and @@ -89,12 +104,7 @@ export const ToolGroupMessage: React.FC = ({ = ({ ? 
'low' : 'medium' } - renderOutputAsMarkdown={tool.renderOutputAsMarkdown} + activeShellPtyId={activeShellPtyId} + shellFocused={shellFocused} + config={config} /> {tool.status === ToolCallStatus.Confirming && @@ -122,7 +134,9 @@ export const ToolGroupMessage: React.FC = ({ )} {tool.outputFile && ( - Output too long and was saved to: {tool.outputFile} + + Output too long and was saved to: {tool.outputFile} + )} diff --git a/packages/cli/src/ui/components/messages/ToolMessage.test.tsx b/packages/cli/src/ui/components/messages/ToolMessage.test.tsx index d6872dbad3e..72cc07f7924 100644 --- a/packages/cli/src/ui/components/messages/ToolMessage.test.tsx +++ b/packages/cli/src/ui/components/messages/ToolMessage.test.tsx @@ -11,6 +11,31 @@ import { ToolMessage } from './ToolMessage.js'; import { StreamingState, ToolCallStatus } from '../../types.js'; import { Text } from 'ink'; import { StreamingContext } from '../../contexts/StreamingContext.js'; +import type { AnsiOutput } from '@blocksuser/gemini-cli-core'; + +vi.mock('../TerminalOutput.js', () => ({ + TerminalOutput: function MockTerminalOutput({ + cursor, + }: { + cursor: { x: number; y: number } | null; + }) { + return ( + + MockCursor:({cursor?.x},{cursor?.y}) + + ); + }, +})); + +vi.mock('../AnsiOutput.js', () => ({ + AnsiOutputText: function MockAnsiOutputText({ data }: { data: AnsiOutput }) { + // Simple serialization for snapshot stability + const serialized = data + .map((line) => line.map((token) => token.text || '').join('')) + .join('\n'); + return MockAnsiOutput:{serialized}; + }, +})); // Mock child components or utilities if they are complex or have side effects vi.mock('../GeminiRespondingSpinner.js', () => ({ @@ -181,4 +206,26 @@ describe('', () => { // We can at least ensure it doesn't have the high emphasis indicator. 
expect(lowEmphasisFrame()).not.toContain('←'); }); + + it('renders AnsiOutputText for AnsiOutput results', () => { + const ansiResult: AnsiOutput = [ + [ + { + text: 'hello', + fg: '#ffffff', + bg: '#000000', + bold: false, + italic: false, + underline: false, + dim: false, + inverse: false, + }, + ], + ]; + const { lastFrame } = renderWithContext( + , + StreamingState.Idle, + ); + expect(lastFrame()).toContain('MockAnsiOutput:hello'); + }); }); diff --git a/packages/cli/src/ui/components/messages/ToolMessage.tsx b/packages/cli/src/ui/components/messages/ToolMessage.tsx index c4e5b6baf43..86e31f74bd0 100644 --- a/packages/cli/src/ui/components/messages/ToolMessage.tsx +++ b/packages/cli/src/ui/components/messages/ToolMessage.tsx @@ -9,11 +9,14 @@ import { Box, Text } from 'ink'; import type { IndividualToolCallDisplay } from '../../types.js'; import { ToolCallStatus } from '../../types.js'; import { DiffRenderer } from './DiffRenderer.js'; -import { Colors } from '../../colors.js'; import { MarkdownDisplay } from '../../utils/MarkdownDisplay.js'; +import { AnsiOutputText } from '../AnsiOutput.js'; import { GeminiRespondingSpinner } from '../GeminiRespondingSpinner.js'; import { MaxSizedBox } from '../shared/MaxSizedBox.js'; -import { TOOL_STATUS } from '../../constants.js'; +import { ShellInputPrompt } from '../ShellInputPrompt.js'; +import { SHELL_COMMAND_NAME, TOOL_STATUS } from '../../constants.js'; +import { theme } from '../../semantic-colors.js'; +import type { AnsiOutput, Config } from '@blocksuser/gemini-cli-core'; const STATIC_HEIGHT = 1; const RESERVED_LINE_COUNT = 5; // for tool name, status, padding etc. 
@@ -30,6 +33,9 @@ export interface ToolMessageProps extends IndividualToolCallDisplay { terminalWidth: number; emphasis?: TextEmphasis; renderOutputAsMarkdown?: boolean; + activeShellPtyId?: number | null; + shellFocused?: boolean; + config?: Config; } export const ToolMessage: React.FC = ({ @@ -41,7 +47,17 @@ export const ToolMessage: React.FC = ({ terminalWidth, emphasis = 'medium', renderOutputAsMarkdown = true, + activeShellPtyId, + shellFocused, + ptyId, + config, }) => { + const isThisShellFocused = + (name === SHELL_COMMAND_NAME || name === 'Shell') && + status === ToolCallStatus.Executing && + ptyId === activeShellPtyId && + shellFocused; + const availableHeight = availableTerminalHeight ? Math.max( availableTerminalHeight - STATIC_HEIGHT - RESERVED_LINE_COUNT, @@ -74,12 +90,17 @@ export const ToolMessage: React.FC = ({ description={description} emphasis={emphasis} /> + {isThisShellFocused && ( + + [Focused] + + )} {emphasis === 'high' && } {resultDisplay && ( - {typeof resultDisplay === 'string' && renderOutputAsMarkdown && ( + {typeof resultDisplay === 'string' && renderOutputAsMarkdown ? ( = ({ terminalWidth={childWidth} /> - )} - {typeof resultDisplay === 'string' && !renderOutputAsMarkdown && ( + ) : typeof resultDisplay === 'string' && !renderOutputAsMarkdown ? ( - {resultDisplay} + + {resultDisplay} + - )} - {typeof resultDisplay !== 'string' && ( + ) : typeof resultDisplay === 'object' && + !Array.isArray(resultDisplay) ? 
( + ) : ( + )} )} + {isThisShellFocused && config && ( + + + + )} ); }; @@ -120,7 +155,7 @@ const ToolStatusIndicator: React.FC = ({ }) => ( {status === ToolCallStatus.Pending && ( - {TOOL_STATUS.PENDING} + {TOOL_STATUS.PENDING} )} {status === ToolCallStatus.Executing && ( = ({ /> )} {status === ToolCallStatus.Success && ( - + {TOOL_STATUS.SUCCESS} )} {status === ToolCallStatus.Confirming && ( - + {TOOL_STATUS.CONFIRMING} )} {status === ToolCallStatus.Canceled && ( - + {TOOL_STATUS.CANCELED} )} {status === ToolCallStatus.Error && ( - + {TOOL_STATUS.ERROR} )} @@ -166,11 +201,11 @@ const ToolInfo: React.FC = ({ const nameColor = React.useMemo(() => { switch (emphasis) { case 'high': - return Colors.Foreground; + return theme.text.primary; case 'medium': - return Colors.Foreground; + return theme.text.primary; case 'low': - return Colors.Gray; + return theme.text.secondary; default: { const exhaustiveCheck: never = emphasis; return exhaustiveCheck; @@ -186,14 +221,14 @@ const ToolInfo: React.FC = ({ {name} {' '} - {description} + {description} ); }; const TrailingIndicator: React.FC = () => ( - + {' '} ← diff --git a/packages/cli/src/ui/components/messages/UserMessage.tsx b/packages/cli/src/ui/components/messages/UserMessage.tsx index 4f279a747f1..2de3ed0df84 100644 --- a/packages/cli/src/ui/components/messages/UserMessage.tsx +++ b/packages/cli/src/ui/components/messages/UserMessage.tsx @@ -6,7 +6,7 @@ import type React from 'react'; import { Text, Box } from 'ink'; -import { Colors } from '../../colors.js'; +import { theme } from '../../semantic-colors.js'; import { SCREEN_READER_USER_PREFIX } from '../../textConstants.js'; import { isSlashCommand as checkIsSlashCommand } from '../../utils/commandUtils.js'; @@ -19,8 +19,8 @@ export const UserMessage: React.FC = ({ text }) => { const prefixWidth = prefix.length; const isSlashCommand = checkIsSlashCommand(text); - const textColor = isSlashCommand ? 
Colors.AccentPurple : Colors.Gray; - const borderColor = isSlashCommand ? Colors.AccentPurple : Colors.Gray; + const textColor = isSlashCommand ? theme.text.accent : theme.text.secondary; + const borderColor = isSlashCommand ? theme.text.accent : theme.text.secondary; return ( = ({ text }) => { return ( - $ - {commandToDisplay} + $ + {commandToDisplay} ); }; diff --git a/packages/cli/src/ui/components/shared/EnumSelector.test.tsx b/packages/cli/src/ui/components/shared/EnumSelector.test.tsx new file mode 100644 index 00000000000..be2df513f97 --- /dev/null +++ b/packages/cli/src/ui/components/shared/EnumSelector.test.tsx @@ -0,0 +1,152 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { renderWithProviders } from '../../../test-utils/render.js'; +import { EnumSelector } from './EnumSelector.js'; +import type { SettingEnumOption } from '../../../config/settingsSchema.js'; +import { describe, it, expect } from 'vitest'; + +const LANGUAGE_OPTIONS: readonly SettingEnumOption[] = [ + { label: 'English', value: 'en' }, + { label: '中文 (简体)', value: 'zh' }, + { label: 'Español', value: 'es' }, + { label: 'Français', value: 'fr' }, +]; + +const NUMERIC_OPTIONS: readonly SettingEnumOption[] = [ + { label: 'Low', value: 1 }, + { label: 'Medium', value: 2 }, + { label: 'High', value: 3 }, +]; + +describe('', () => { + it('renders with string options and matches snapshot', () => { + const { lastFrame } = renderWithProviders( + {}} + />, + ); + expect(lastFrame()).toMatchSnapshot(); + }); + + it('renders with numeric options and matches snapshot', () => { + const { lastFrame } = renderWithProviders( + {}} + />, + ); + expect(lastFrame()).toMatchSnapshot(); + }); + + it('renders inactive state and matches snapshot', () => { + const { lastFrame } = renderWithProviders( + {}} + />, + ); + expect(lastFrame()).toMatchSnapshot(); + }); + + it('renders with single option and matches snapshot', () => { + const singleOption: readonly 
SettingEnumOption[] = [ + { label: 'Only Option', value: 'only' }, + ]; + const { lastFrame } = renderWithProviders( + {}} + />, + ); + expect(lastFrame()).toMatchSnapshot(); + }); + + it('renders nothing when no options are provided', () => { + const { lastFrame } = renderWithProviders( + {}} + />, + ); + expect(lastFrame()).toBe(''); + }); + + it('handles currentValue not found in options', () => { + const { lastFrame } = renderWithProviders( + {}} + />, + ); + // Should default to first option + expect(lastFrame()).toContain('English'); + }); + + it('updates when currentValue changes externally', () => { + const { rerender, lastFrame } = renderWithProviders( + {}} + />, + ); + expect(lastFrame()).toContain('English'); + + rerender( + {}} + />, + ); + expect(lastFrame()).toContain('中文 (简体)'); + }); + + it('shows navigation arrows when multiple options available', () => { + const { lastFrame } = renderWithProviders( + {}} + />, + ); + expect(lastFrame()).toContain('←'); + expect(lastFrame()).toContain('→'); + }); + + it('hides navigation arrows when single option available', () => { + const singleOption: readonly SettingEnumOption[] = [ + { label: 'Only Option', value: 'only' }, + ]; + const { lastFrame } = renderWithProviders( + {}} + />, + ); + expect(lastFrame()).not.toContain('←'); + expect(lastFrame()).not.toContain('→'); + }); +}); diff --git a/packages/cli/src/ui/components/shared/EnumSelector.tsx b/packages/cli/src/ui/components/shared/EnumSelector.tsx new file mode 100644 index 00000000000..a86efd8ff13 --- /dev/null +++ b/packages/cli/src/ui/components/shared/EnumSelector.tsx @@ -0,0 +1,87 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { useState, useEffect } from 'react'; +import type React from 'react'; +import { Box, Text } from 'ink'; +import { Colors } from '../../colors.js'; +import type { SettingEnumOption } from '../../../config/settingsSchema.js'; + +interface EnumSelectorProps { + 
options: readonly SettingEnumOption[]; + currentValue: string | number; + isActive: boolean; + onValueChange: (value: string | number) => void; +} + +/** + * A left-right scrolling selector for enum values + */ +export function EnumSelector({ + options, + currentValue, + isActive, + onValueChange: _onValueChange, +}: EnumSelectorProps): React.JSX.Element { + const [currentIndex, setCurrentIndex] = useState(() => { + // Guard against empty options array + if (!options || options.length === 0) { + return 0; + } + const index = options.findIndex((option) => option.value === currentValue); + return index >= 0 ? index : 0; + }); + + // Update index when currentValue changes externally + useEffect(() => { + // Guard against empty options array + if (!options || options.length === 0) { + return; + } + const index = options.findIndex((option) => option.value === currentValue); + // Always update index, defaulting to 0 if value not found + setCurrentIndex(index >= 0 ? index : 0); + }, [currentValue, options]); + + // Guard against empty options array + if (!options || options.length === 0) { + return ; + } + + // Left/right navigation is handled by parent component + // This component is purely for display + // onValueChange is kept for interface compatibility but not used internally + + const currentOption = options[currentIndex] || options[0]; + const canScrollLeft = options.length > 1; + const canScrollRight = options.length > 1; + + return ( + + + {canScrollLeft ? '←' : ' '} + + + + {currentOption.label} + + + + {canScrollRight ? 
'→' : ' '} + + + ); +} + +// Export the interface for external use +export type { EnumSelectorProps }; diff --git a/packages/cli/src/ui/components/shared/MaxSizedBox.tsx b/packages/cli/src/ui/components/shared/MaxSizedBox.tsx index 74092324a78..86b36cd28d4 100644 --- a/packages/cli/src/ui/components/shared/MaxSizedBox.tsx +++ b/packages/cli/src/ui/components/shared/MaxSizedBox.tsx @@ -7,7 +7,7 @@ import React, { Fragment, useEffect, useId } from 'react'; import { Box, Text } from 'ink'; import stringWidth from 'string-width'; -import { Colors } from '../../colors.js'; +import { theme } from '../../semantic-colors.js'; import { toCodePoints } from '../../utils/textUtils.js'; import { useOverflowActions } from '../../contexts/OverflowContext.js'; @@ -186,14 +186,14 @@ export const MaxSizedBox: React.FC = ({ return ( {totalHiddenLines > 0 && overflowDirection === 'top' && ( - + ... first {totalHiddenLines} line{totalHiddenLines === 1 ? '' : 's'}{' '} hidden ... )} {visibleLines} {totalHiddenLines > 0 && overflowDirection === 'bottom' && ( - + ... last {totalHiddenLines} line{totalHiddenLines === 1 ? '' : 's'}{' '} hidden ... diff --git a/packages/cli/src/ui/components/shared/RadioButtonSelect.tsx b/packages/cli/src/ui/components/shared/RadioButtonSelect.tsx index 719d263b96b..ab62e5d1b11 100644 --- a/packages/cli/src/ui/components/shared/RadioButtonSelect.tsx +++ b/packages/cli/src/ui/components/shared/RadioButtonSelect.tsx @@ -7,7 +7,7 @@ import type React from 'react'; import { useEffect, useState, useRef } from 'react'; import { Text, Box } from 'ink'; -import { Colors } from '../../colors.js'; +import { theme } from '../../semantic-colors.js'; import { useKeypress } from '../../hooks/useKeypress.js'; /** @@ -164,7 +164,9 @@ export function RadioButtonSelect({ return ( {showScrollArrows && ( - 0 ? Colors.Foreground : Colors.Gray}> + 0 ? 
theme.text.primary : theme.text.secondary} + > ▲ )} @@ -172,18 +174,18 @@ export function RadioButtonSelect({ const itemIndex = scrollOffset + index; const isSelected = activeIndex === itemIndex; - let textColor = Colors.Foreground; - let numberColor = Colors.Foreground; + let textColor = theme.text.primary; + let numberColor = theme.text.primary; if (isSelected) { - textColor = Colors.AccentGreen; - numberColor = Colors.AccentGreen; + textColor = theme.status.success; + numberColor = theme.status.success; } else if (item.disabled) { - textColor = Colors.Gray; - numberColor = Colors.Gray; + textColor = theme.text.secondary; + numberColor = theme.text.secondary; } if (!showNumbers) { - numberColor = Colors.Gray; + numberColor = theme.text.secondary; } const numberColumnWidth = String(items.length).length; @@ -195,7 +197,7 @@ export function RadioButtonSelect({ {isSelected ? '●' : ' '} @@ -212,7 +214,9 @@ export function RadioButtonSelect({ {item.themeNameDisplay && item.themeTypeDisplay ? 
( {item.themeNameDisplay}{' '} - {item.themeTypeDisplay} + + {item.themeTypeDisplay} + ) : ( @@ -226,8 +230,8 @@ export function RadioButtonSelect({ ▼ diff --git a/packages/cli/src/ui/components/shared/ScopeSelector.tsx b/packages/cli/src/ui/components/shared/ScopeSelector.tsx new file mode 100644 index 00000000000..8066d8c9ee0 --- /dev/null +++ b/packages/cli/src/ui/components/shared/ScopeSelector.tsx @@ -0,0 +1,52 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type React from 'react'; +import { Box, Text } from 'ink'; +import type { SettingScope } from '../../../config/settings.js'; +import { getScopeItems } from '../../../utils/dialogScopeUtils.js'; +import { RadioButtonSelect } from './RadioButtonSelect.js'; + +interface ScopeSelectorProps { + /** Callback function when a scope is selected */ + onSelect: (scope: SettingScope) => void; + /** Callback function when a scope is highlighted */ + onHighlight: (scope: SettingScope) => void; + /** Whether the component is focused */ + isFocused: boolean; + /** The initial scope to select */ + initialScope: SettingScope; +} + +export function ScopeSelector({ + onSelect, + onHighlight, + isFocused, + initialScope, +}: ScopeSelectorProps): React.JSX.Element { + const scopeItems = getScopeItems(); + + const initialIndex = scopeItems.findIndex( + (item) => item.value === initialScope, + ); + const safeInitialIndex = initialIndex >= 0 ? initialIndex : 0; + + return ( + + + {isFocused ? 
'> ' : ' '}Apply To + + + + ); +} diff --git a/packages/cli/src/ui/components/shared/__snapshots__/EnumSelector.test.tsx.snap b/packages/cli/src/ui/components/shared/__snapshots__/EnumSelector.test.tsx.snap new file mode 100644 index 00000000000..9949aba5e47 --- /dev/null +++ b/packages/cli/src/ui/components/shared/__snapshots__/EnumSelector.test.tsx.snap @@ -0,0 +1,9 @@ +// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html + +exports[` > renders inactive state and matches snapshot 1`] = `"← 中文 (简体) →"`; + +exports[` > renders with numeric options and matches snapshot 1`] = `"← Medium →"`; + +exports[` > renders with single option and matches snapshot 1`] = `" Only Option"`; + +exports[` > renders with string options and matches snapshot 1`] = `"← English →"`; diff --git a/packages/cli/src/ui/components/shared/text-buffer.test.ts b/packages/cli/src/ui/components/shared/text-buffer.test.ts index 94d4ffb55b9..2bd02967811 100644 --- a/packages/cli/src/ui/components/shared/text-buffer.test.ts +++ b/packages/cli/src/ui/components/shared/text-buffer.test.ts @@ -1393,6 +1393,69 @@ Contrary to popular belief, Lorem Ipsum is not simply random text. 
It has roots expect(getBufferState(result).text).toBe('Pasted Text'); }); + it('should sanitize large text (>5000 chars) and strip unsafe characters', () => { + const { result } = renderHook(() => + useTextBuffer({ viewport, isValidPath: () => false }), + ); + const unsafeChars = '\x07\x08\x0B\x0C'; + const largeTextWithUnsafe = + 'safe text'.repeat(600) + unsafeChars + 'more safe text'; + + expect(largeTextWithUnsafe.length).toBeGreaterThan(5000); + + act(() => + result.current.handleInput({ + name: '', + ctrl: false, + meta: false, + shift: false, + paste: false, + sequence: largeTextWithUnsafe, + }), + ); + + const resultText = getBufferState(result).text; + expect(resultText).not.toContain('\x07'); + expect(resultText).not.toContain('\x08'); + expect(resultText).not.toContain('\x0B'); + expect(resultText).not.toContain('\x0C'); + expect(resultText).toContain('safe text'); + expect(resultText).toContain('more safe text'); + }); + + it('should sanitize large ANSI text (>5000 chars) and strip escape codes', () => { + const { result } = renderHook(() => + useTextBuffer({ viewport, isValidPath: () => false }), + ); + const largeTextWithAnsi = + '\x1B[31m' + + 'red text'.repeat(800) + + '\x1B[0m' + + '\x1B[32m' + + 'green text'.repeat(200) + + '\x1B[0m'; + + expect(largeTextWithAnsi.length).toBeGreaterThan(5000); + + act(() => + result.current.handleInput({ + name: '', + ctrl: false, + meta: false, + shift: false, + paste: false, + sequence: largeTextWithAnsi, + }), + ); + + const resultText = getBufferState(result).text; + expect(resultText).not.toContain('\x1B[31m'); + expect(resultText).not.toContain('\x1B[32m'); + expect(resultText).not.toContain('\x1B[0m'); + expect(resultText).toContain('red text'); + expect(resultText).toContain('green text'); + }); + it('should not strip popular emojis', () => { const { result } = renderHook(() => useTextBuffer({ viewport, isValidPath: () => false }), diff --git a/packages/cli/src/ui/components/shared/text-buffer.ts 
b/packages/cli/src/ui/components/shared/text-buffer.ts index 17d2a3adaeb..751b07a4553 100644 --- a/packages/cli/src/ui/components/shared/text-buffer.ts +++ b/packages/cli/src/ui/components/shared/text-buffer.ts @@ -9,13 +9,13 @@ import fs from 'node:fs'; import os from 'node:os'; import pathMod from 'node:path'; import { useState, useCallback, useEffect, useMemo, useReducer } from 'react'; -import stringWidth from 'string-width'; import { unescapePath } from '@blocksuser/gemini-cli-core'; import { toCodePoints, cpLen, cpSlice, stripUnsafeCharacters, + getCachedStringWidth, } from '../../utils/textUtils.js'; import type { VimAction } from './vim-buffer-actions.js'; import { handleVimAction } from './vim-buffer-actions.js'; @@ -629,21 +629,23 @@ export function logicalPosToOffset( return offset; } -// Helper to calculate visual lines and map cursor positions -function calculateVisualLayout( +export interface VisualLayout { + visualLines: string[]; + // For each logical line, an array of [visualLineIndex, startColInLogical] + logicalToVisualMap: Array>; + // For each visual line, its [logicalLineIndex, startColInLogical] + visualToLogicalMap: Array<[number, number]>; +} + +// Calculates the visual wrapping of lines and the mapping between logical and visual coordinates. +// This is an expensive operation and should be memoized. 
+function calculateLayout( logicalLines: string[], - logicalCursor: [number, number], viewportWidth: number, -): { - visualLines: string[]; - visualCursor: [number, number]; - logicalToVisualMap: Array>; // For each logical line, an array of [visualLineIndex, startColInLogical] - visualToLogicalMap: Array<[number, number]>; // For each visual line, its [logicalLineIndex, startColInLogical] -} { +): VisualLayout { const visualLines: string[] = []; const logicalToVisualMap: Array> = []; const visualToLogicalMap: Array<[number, number]> = []; - let currentVisualCursor: [number, number] = [0, 0]; logicalLines.forEach((logLine, logIndex) => { logicalToVisualMap[logIndex] = []; @@ -652,9 +654,6 @@ function calculateVisualLayout( logicalToVisualMap[logIndex].push([visualLines.length, 0]); visualToLogicalMap.push([logIndex, 0]); visualLines.push(''); - if (logIndex === logicalCursor[0] && logicalCursor[1] === 0) { - currentVisualCursor = [visualLines.length - 1, 0]; - } } else { // Non-empty logical line let currentPosInLogLine = 0; // Tracks position within the current logical line (code point index) @@ -670,7 +669,7 @@ function calculateVisualLayout( // Iterate through code points to build the current visual line (chunk) for (let i = currentPosInLogLine; i < codePointsInLogLine.length; i++) { const char = codePointsInLogLine[i]; - const charVisualWidth = stringWidth(char); + const charVisualWidth = getCachedStringWidth(char); if (currentChunkVisualWidth + charVisualWidth > viewportWidth) { // Character would exceed viewport width @@ -754,30 +753,6 @@ function calculateVisualLayout( visualToLogicalMap.push([logIndex, currentPosInLogLine]); visualLines.push(currentChunk); - // Cursor mapping logic - // Note: currentPosInLogLine here is the start of the currentChunk within the logical line. 
- if (logIndex === logicalCursor[0]) { - const cursorLogCol = logicalCursor[1]; // This is a code point index - if ( - cursorLogCol >= currentPosInLogLine && - cursorLogCol < currentPosInLogLine + numCodePointsInChunk // Cursor is within this chunk - ) { - currentVisualCursor = [ - visualLines.length - 1, - cursorLogCol - currentPosInLogLine, // Visual col is also code point index within visual line - ]; - } else if ( - cursorLogCol === currentPosInLogLine + numCodePointsInChunk && - numCodePointsInChunk > 0 - ) { - // Cursor is exactly at the end of this non-empty chunk - currentVisualCursor = [ - visualLines.length - 1, - numCodePointsInChunk, - ]; - } - } - const logicalStartOfThisChunk = currentPosInLogLine; currentPosInLogLine += numCodePointsInChunk; @@ -793,23 +768,6 @@ function calculateVisualLayout( currentPosInLogLine++; } } - // After all chunks of a non-empty logical line are processed, - // if the cursor is at the very end of this logical line, update visual cursor. - if ( - logIndex === logicalCursor[0] && - logicalCursor[1] === codePointsInLogLine.length // Cursor at end of logical line - ) { - const lastVisualLineIdx = visualLines.length - 1; - if ( - lastVisualLineIdx >= 0 && - visualLines[lastVisualLineIdx] !== undefined - ) { - currentVisualCursor = [ - lastVisualLineIdx, - cpLen(visualLines[lastVisualLineIdx]), // Cursor at end of last visual line for this logical line - ]; - } - } } }); @@ -824,27 +782,67 @@ function calculateVisualLayout( logicalToVisualMap[0].push([0, 0]); visualToLogicalMap.push([0, 0]); } - currentVisualCursor = [0, 0]; - } - // Handle cursor at the very end of the text (after all processing) - // This case might be covered by the loop end condition now, but kept for safety. 
- else if ( - logicalCursor[0] === logicalLines.length - 1 && - logicalCursor[1] === cpLen(logicalLines[logicalLines.length - 1]) && - visualLines.length > 0 - ) { - const lastVisLineIdx = visualLines.length - 1; - currentVisualCursor = [lastVisLineIdx, cpLen(visualLines[lastVisLineIdx])]; } return { visualLines, - visualCursor: currentVisualCursor, logicalToVisualMap, visualToLogicalMap, }; } +// Calculates the visual cursor position based on a pre-calculated layout. +// This is a lightweight operation. +function calculateVisualCursorFromLayout( + layout: VisualLayout, + logicalCursor: [number, number], +): [number, number] { + const { logicalToVisualMap, visualLines } = layout; + const [logicalRow, logicalCol] = logicalCursor; + + const segmentsForLogicalLine = logicalToVisualMap[logicalRow]; + + if (!segmentsForLogicalLine || segmentsForLogicalLine.length === 0) { + // This can happen for an empty document. + return [0, 0]; + } + + // Find the segment where the logical column fits. + // The segments are sorted by startColInLogical. + let targetSegmentIndex = segmentsForLogicalLine.findIndex( + ([, startColInLogical], index) => { + const nextStartColInLogical = + index + 1 < segmentsForLogicalLine.length + ? segmentsForLogicalLine[index + 1][1] + : Infinity; + return ( + logicalCol >= startColInLogical && logicalCol < nextStartColInLogical + ); + }, + ); + + // If not found, it means the cursor is at the end of the logical line. + if (targetSegmentIndex === -1) { + if (logicalCol === 0) { + targetSegmentIndex = 0; + } else { + targetSegmentIndex = segmentsForLogicalLine.length - 1; + } + } + + const [visualRow, startColInLogical] = + segmentsForLogicalLine[targetSegmentIndex]; + const visualCol = logicalCol - startColInLogical; + + // The visual column should not exceed the length of the visual line. + const clampedVisualCol = Math.min( + visualCol, + cpLen(visualLines[visualRow] ?? 
''), + ); + + return [visualRow, clampedVisualCol]; +} + // --- Start of reducer logic --- export interface TextBufferState { @@ -857,6 +855,8 @@ export interface TextBufferState { clipboard: string | null; selectionAnchor: [number, number] | null; viewportWidth: number; + viewportHeight: number; + visualLayout: VisualLayout; } const historyLimit = 100; @@ -884,6 +884,14 @@ export type TextBufferAction = dir: Direction; }; } + | { + type: 'set_cursor'; + payload: { + cursorRow: number; + cursorCol: number; + preferredCol: number | null; + }; + } | { type: 'delete' } | { type: 'delete_word_left' } | { type: 'delete_word_right' } @@ -903,7 +911,7 @@ export type TextBufferAction = } | { type: 'move_to_offset'; payload: { offset: number } } | { type: 'create_undo_snapshot' } - | { type: 'set_viewport_width'; payload: number } + | { type: 'set_viewport'; payload: { width: number; height: number } } | { type: 'vim_delete_word_forward'; payload: { count: number } } | { type: 'vim_delete_word_backward'; payload: { count: number } } | { type: 'vim_delete_word_end'; payload: { count: number } } @@ -941,7 +949,7 @@ export type TextBufferAction = | { type: 'vim_move_to_line'; payload: { lineNumber: number } } | { type: 'vim_escape_insert_mode' }; -export function textBufferReducer( +function textBufferReducerLogic( state: TextBufferState, action: TextBufferAction, ): TextBufferState { @@ -1047,80 +1055,120 @@ export function textBufferReducer( }; } - case 'set_viewport_width': { - if (action.payload === state.viewportWidth) { + case 'set_viewport': { + const { width, height } = action.payload; + if (width === state.viewportWidth && height === state.viewportHeight) { return state; } - return { ...state, viewportWidth: action.payload }; + return { + ...state, + viewportWidth: width, + viewportHeight: height, + }; } case 'move': { const { dir } = action.payload; - const { lines, cursorRow, cursorCol, viewportWidth } = state; - const visualLayout = calculateVisualLayout( - lines, 
- [cursorRow, cursorCol], - viewportWidth, - ); - const { visualLines, visualCursor, visualToLogicalMap } = visualLayout; + const { cursorRow, cursorCol, lines, visualLayout, preferredCol } = state; - let newVisualRow = visualCursor[0]; - let newVisualCol = visualCursor[1]; - let newPreferredCol = state.preferredCol; - - const currentVisLineLen = cpLen(visualLines[newVisualRow] ?? ''); - - switch (dir) { - case 'left': - newPreferredCol = null; - if (newVisualCol > 0) { - newVisualCol--; - } else if (newVisualRow > 0) { - newVisualRow--; - newVisualCol = cpLen(visualLines[newVisualRow] ?? ''); - } - break; - case 'right': - newPreferredCol = null; - if (newVisualCol < currentVisLineLen) { - newVisualCol++; - } else if (newVisualRow < visualLines.length - 1) { - newVisualRow++; + // Visual movements + if ( + dir === 'left' || + dir === 'right' || + dir === 'up' || + dir === 'down' || + dir === 'home' || + dir === 'end' + ) { + const visualCursor = calculateVisualCursorFromLayout(visualLayout, [ + cursorRow, + cursorCol, + ]); + const { visualLines, visualToLogicalMap } = visualLayout; + + let newVisualRow = visualCursor[0]; + let newVisualCol = visualCursor[1]; + let newPreferredCol = preferredCol; + + const currentVisLineLen = cpLen(visualLines[newVisualRow] ?? ''); + + switch (dir) { + case 'left': + newPreferredCol = null; + if (newVisualCol > 0) { + newVisualCol--; + } else if (newVisualRow > 0) { + newVisualRow--; + newVisualCol = cpLen(visualLines[newVisualRow] ?? ''); + } + break; + case 'right': + newPreferredCol = null; + if (newVisualCol < currentVisLineLen) { + newVisualCol++; + } else if (newVisualRow < visualLines.length - 1) { + newVisualRow++; + newVisualCol = 0; + } + break; + case 'up': + if (newVisualRow > 0) { + if (newPreferredCol === null) newPreferredCol = newVisualCol; + newVisualRow--; + newVisualCol = clamp( + newPreferredCol, + 0, + cpLen(visualLines[newVisualRow] ?? 
''), + ); + } + break; + case 'down': + if (newVisualRow < visualLines.length - 1) { + if (newPreferredCol === null) newPreferredCol = newVisualCol; + newVisualRow++; + newVisualCol = clamp( + newPreferredCol, + 0, + cpLen(visualLines[newVisualRow] ?? ''), + ); + } + break; + case 'home': + newPreferredCol = null; newVisualCol = 0; - } - break; - case 'up': - if (newVisualRow > 0) { - if (newPreferredCol === null) newPreferredCol = newVisualCol; - newVisualRow--; - newVisualCol = clamp( - newPreferredCol, - 0, - cpLen(visualLines[newVisualRow] ?? ''), + break; + case 'end': + newPreferredCol = null; + newVisualCol = currentVisLineLen; + break; + default: { + const exhaustiveCheck: never = dir; + console.error( + `Unknown visual movement direction: ${exhaustiveCheck}`, ); + return state; } - break; - case 'down': - if (newVisualRow < visualLines.length - 1) { - if (newPreferredCol === null) newPreferredCol = newVisualCol; - newVisualRow++; - newVisualCol = clamp( - newPreferredCol, + } + + if (visualToLogicalMap[newVisualRow]) { + const [logRow, logStartCol] = visualToLogicalMap[newVisualRow]; + return { + ...state, + cursorRow: logRow, + cursorCol: clamp( + logStartCol + newVisualCol, 0, - cpLen(visualLines[newVisualRow] ?? ''), - ); - } - break; - case 'home': - newPreferredCol = null; - newVisualCol = 0; - break; - case 'end': - newPreferredCol = null; - newVisualCol = currentVisLineLen; - break; + cpLen(lines[logRow] ?? ''), + ), + preferredCol: newPreferredCol, + }; + } + return state; + } + + // Logical movements + switch (dir) { case 'wordLeft': { - const { cursorRow, cursorCol, lines } = state; if (cursorCol === 0 && cursorRow === 0) return state; let newCursorRow = cursorRow; @@ -1156,7 +1204,6 @@ export function textBufferReducer( }; } case 'wordRight': { - const { cursorRow, cursorCol, lines } = state; if ( cursorRow === lines.length - 1 && cursorCol === cpLen(lines[cursorRow] ?? 
'') @@ -1186,23 +1233,15 @@ export function textBufferReducer( }; } default: - break; + return state; } + } - if (visualToLogicalMap[newVisualRow]) { - const [logRow, logStartCol] = visualToLogicalMap[newVisualRow]; - return { - ...state, - cursorRow: logRow, - cursorCol: clamp( - logStartCol + newVisualCol, - 0, - cpLen(state.lines[logRow] ?? ''), - ), - preferredCol: newPreferredCol, - }; - } - return state; + case 'set_cursor': { + return { + ...state, + ...action.payload, + }; } case 'delete': { @@ -1214,14 +1253,22 @@ export function textBufferReducer( newLines[cursorRow] = cpSlice(lineContent, 0, cursorCol) + cpSlice(lineContent, cursorCol + 1); - return { ...nextState, lines: newLines, preferredCol: null }; + return { + ...nextState, + lines: newLines, + preferredCol: null, + }; } else if (cursorRow < lines.length - 1) { const nextState = pushUndoLocal(state); const nextLineContent = currentLine(cursorRow + 1); const newLines = [...nextState.lines]; newLines[cursorRow] = lineContent + nextLineContent; newLines.splice(cursorRow + 1, 1); - return { ...nextState, lines: newLines, preferredCol: null }; + return { + ...nextState, + lines: newLines, + preferredCol: null, + }; } return state; } @@ -1303,7 +1350,10 @@ export function textBufferReducer( const nextState = pushUndoLocal(state); const newLines = [...nextState.lines]; newLines[cursorRow] = cpSlice(lineContent, 0, cursorCol); - return { ...nextState, lines: newLines }; + return { + ...nextState, + lines: newLines, + }; } else if (cursorRow < lines.length - 1) { // Act as a delete const nextState = pushUndoLocal(state); @@ -1311,7 +1361,11 @@ export function textBufferReducer( const newLines = [...nextState.lines]; newLines[cursorRow] = lineContent + nextLineContent; newLines.splice(cursorRow + 1, 1); - return { ...nextState, lines: newLines, preferredCol: null }; + return { + ...nextState, + lines: newLines, + preferredCol: null, + }; } return state; } @@ -1441,6 +1495,25 @@ export function 
textBufferReducer( } } +export function textBufferReducer( + state: TextBufferState, + action: TextBufferAction, +): TextBufferState { + const newState = textBufferReducerLogic(state, action); + + if ( + newState.lines !== state.lines || + newState.viewportWidth !== state.viewportWidth + ) { + return { + ...newState, + visualLayout: calculateLayout(newState.lines, newState.viewportWidth), + }; + } + + return newState; +} + // --- End of reducer logic --- export function useTextBuffer({ @@ -1459,6 +1532,10 @@ export function useTextBuffer({ lines.length === 0 ? [''] : lines, initialCursorOffset, ); + const visualLayout = calculateLayout( + lines.length === 0 ? [''] : lines, + viewport.width, + ); return { lines: lines.length === 0 ? [''] : lines, cursorRow: initialCursorRow, @@ -1469,21 +1546,29 @@ export function useTextBuffer({ clipboard: null, selectionAnchor: null, viewportWidth: viewport.width, + viewportHeight: viewport.height, + visualLayout, }; - }, [initialText, initialCursorOffset, viewport.width]); + }, [initialText, initialCursorOffset, viewport.width, viewport.height]); const [state, dispatch] = useReducer(textBufferReducer, initialState); - const { lines, cursorRow, cursorCol, preferredCol, selectionAnchor } = state; + const { + lines, + cursorRow, + cursorCol, + preferredCol, + selectionAnchor, + visualLayout, + } = state; const text = useMemo(() => lines.join('\n'), [lines]); - const visualLayout = useMemo( - () => - calculateVisualLayout(lines, [cursorRow, cursorCol], state.viewportWidth), - [lines, cursorRow, cursorCol, state.viewportWidth], + const visualCursor = useMemo( + () => calculateVisualCursorFromLayout(visualLayout, [cursorRow, cursorCol]), + [visualLayout, cursorRow, cursorCol], ); - const { visualLines, visualCursor } = visualLayout; + const { visualLines } = visualLayout; const [visualScrollRow, setVisualScrollRow] = useState(0); @@ -1494,8 +1579,11 @@ export function useTextBuffer({ }, [text, onChange]); useEffect(() => { - dispatch({ 
type: 'set_viewport_width', payload: viewport.width }); - }, [viewport.width]); + dispatch({ + type: 'set_viewport', + payload: { width: viewport.width, height: viewport.height }, + }); + }, [viewport.width, viewport.height]); // Update visual scroll (vertical) useEffect(() => { @@ -1568,9 +1656,12 @@ export function useTextBuffer({ dispatch({ type: 'delete' }); }, []); - const move = useCallback((dir: Direction): void => { - dispatch({ type: 'move', payload: { dir } }); - }, []); + const move = useCallback( + (dir: Direction): void => { + dispatch({ type: 'move', payload: { dir } }); + }, + [dispatch], + ); const undo = useCallback((): void => { dispatch({ type: 'undo' }); diff --git a/packages/cli/src/ui/contexts/UIStateContext.tsx b/packages/cli/src/ui/contexts/UIStateContext.tsx index 13152092a66..0e8a2ca13de 100644 --- a/packages/cli/src/ui/contexts/UIStateContext.tsx +++ b/packages/cli/src/ui/contexts/UIStateContext.tsx @@ -11,6 +11,7 @@ import type { ConsoleMessageItem, ShellConfirmationRequest, ConfirmationRequest, + LoopDetectionConfirmationRequest, HistoryItemWithoutId, StreamingState, } from '../types.js'; @@ -21,16 +22,24 @@ import type { ApprovalMode, UserTierId, DetectedIde, + FallbackIntent, } from '@blocksuser/gemini-cli-core'; import type { DOMElement } from 'ink'; import type { SessionStatsState } from '../contexts/SessionContext.js'; import type { UpdateObject } from '../utils/updateCheck.js'; +export interface ProQuotaDialogRequest { + failedModel: string; + fallbackModel: string; + resolve: (intent: FallbackIntent) => void; +} + export interface UIState { history: HistoryItem[]; isThemeDialogOpen: boolean; themeError: string | null; isAuthenticating: boolean; + isConfigInitialized: boolean; authError: string | null; isAuthDialogOpen: boolean; editorError: string | null; @@ -45,6 +54,7 @@ export interface UIState { commandContext: CommandContext; shellConfirmationRequest: ShellConfirmationRequest | null; confirmationRequest: ConfirmationRequest 
| null; + loopDetectionConfirmationRequest: LoopDetectionConfirmationRequest | null; geminiMdFileCount: number; streamingState: StreamingState; initError: string | null; @@ -78,9 +88,8 @@ export interface UIState { workspaceExtensions: any[]; // Extension[] // Quota-related state userTier: UserTierId | undefined; - isProQuotaDialogOpen: boolean; + proQuotaRequest: ProQuotaDialogRequest | null; currentModel: string; - // New fields for complete state management contextFileNames: string[]; errorCount: number; availableTerminalHeight: number | undefined; @@ -99,6 +108,8 @@ export interface UIState { updateInfo: UpdateObject | null; showIdeRestartPrompt: boolean; isRestarting: boolean; + activePtyId: number | undefined; + shellFocused: boolean; } export const UIStateContext = createContext(null); diff --git a/packages/cli/src/ui/hooks/keyToAnsi.ts b/packages/cli/src/ui/hooks/keyToAnsi.ts new file mode 100644 index 00000000000..1d5549ab0f7 --- /dev/null +++ b/packages/cli/src/ui/hooks/keyToAnsi.ts @@ -0,0 +1,77 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Key } from '../contexts/KeypressContext.js'; + +export type { Key }; + +/** + * Translates a Key object into its corresponding ANSI escape sequence. + * This is useful for sending control characters to a pseudo-terminal. + * + * @param key The Key object to translate. + * @returns The ANSI escape sequence as a string, or null if no mapping exists. 
+ */ +export function keyToAnsi(key: Key): string | null { + if (key.ctrl) { + // Ctrl + letter + if (key.name >= 'a' && key.name <= 'z') { + return String.fromCharCode( + key.name.charCodeAt(0) - 'a'.charCodeAt(0) + 1, + ); + } + // Other Ctrl combinations might need specific handling + switch (key.name) { + case 'c': + return '\x03'; // ETX (End of Text), commonly used for interrupt + // Add other special ctrl cases if needed + default: + break; + } + } + + // Arrow keys and other special keys + switch (key.name) { + case 'up': + return '\x1b[A'; + case 'down': + return '\x1b[B'; + case 'right': + return '\x1b[C'; + case 'left': + return '\x1b[D'; + case 'escape': + return '\x1b'; + case 'tab': + return '\t'; + case 'backspace': + return '\x7f'; + case 'delete': + return '\x1b[3~'; + case 'home': + return '\x1b[H'; + case 'end': + return '\x1b[F'; + case 'pageup': + return '\x1b[5~'; + case 'pagedown': + return '\x1b[6~'; + default: + break; + } + + // Enter/Return + if (key.name === 'return') { + return '\r'; + } + + // If it's a simple character, return it. 
+ if (!key.ctrl && !key.meta && key.sequence) { + return key.sequence; + } + + return null; +} diff --git a/packages/cli/src/ui/hooks/shellCommandProcessor.test.ts b/packages/cli/src/ui/hooks/shellCommandProcessor.test.ts index 6f16a0469f8..18498567af0 100644 --- a/packages/cli/src/ui/hooks/shellCommandProcessor.test.ts +++ b/packages/cli/src/ui/hooks/shellCommandProcessor.test.ts @@ -53,6 +53,8 @@ describe('useShellCommandProcessor', () => { let mockShellOutputCallback: (event: ShellOutputEvent) => void; let resolveExecutionPromise: (result: ShellExecutionResult) => void; + let setShellInputFocusedMock: Mock; + beforeEach(() => { vi.clearAllMocks(); @@ -60,9 +62,14 @@ describe('useShellCommandProcessor', () => { setPendingHistoryItemMock = vi.fn(); onExecMock = vi.fn(); onDebugMessageMock = vi.fn(); + setShellInputFocusedMock = vi.fn(); mockConfig = { getTargetDir: () => '/test/dir', getShouldUseNodePtyShell: () => false, + getShellExecutionConfig: () => ({ + terminalHeight: 20, + terminalWidth: 80, + }), } as Config; mockGeminiClient = { addHistory: vi.fn() } as unknown as GeminiClient; @@ -76,12 +83,12 @@ describe('useShellCommandProcessor', () => { mockShellExecutionService.mockImplementation((_cmd, _cwd, callback) => { mockShellOutputCallback = callback; - return { + return Promise.resolve({ pid: 12345, result: new Promise((resolve) => { resolveExecutionPromise = resolve; }), - }; + }); }); }); @@ -94,6 +101,7 @@ describe('useShellCommandProcessor', () => { onDebugMessageMock, mockConfig, mockGeminiClient, + setShellInputFocusedMock, ), ); @@ -139,6 +147,7 @@ describe('useShellCommandProcessor', () => { expect.any(Function), expect.any(Object), false, + expect.any(Object), ); expect(onExecMock).toHaveBeenCalledWith(expect.any(Promise)); }); @@ -172,6 +181,7 @@ describe('useShellCommandProcessor', () => { }), ); expect(mockGeminiClient.addHistory).toHaveBeenCalled(); + expect(setShellInputFocusedMock).toHaveBeenCalledWith(false); }); it('should handle command 
failure and display error status', async () => { @@ -198,6 +208,7 @@ describe('useShellCommandProcessor', () => { 'Command exited with code 127', ); expect(finalHistoryItem.tools[0].resultDisplay).toContain('not found'); + expect(setShellInputFocusedMock).toHaveBeenCalledWith(false); }); describe('UI Streaming and Throttling', () => { @@ -208,7 +219,7 @@ describe('useShellCommandProcessor', () => { vi.useRealTimers(); }); - it('should throttle pending UI updates for text streams', async () => { + it('should throttle pending UI updates for text streams (non-interactive)', async () => { const { result } = renderProcessorHook(); act(() => { result.current.handleShellCommand( @@ -217,6 +228,26 @@ describe('useShellCommandProcessor', () => { ); }); + // Verify it's using the non-pty shell + const wrappedCommand = `{ stream; }; __code=$?; pwd > "${path.join( + os.tmpdir(), + 'shell_pwd_abcdef.tmp', + )}"; exit $__code`; + expect(mockShellExecutionService).toHaveBeenCalledWith( + wrappedCommand, + '/test/dir', + expect.any(Function), + expect.any(Object), + false, // usePty + expect.any(Object), + ); + + // Wait for the async PID update to happen. + await vi.waitFor(() => { + // It's called once for initial, and once for the PID update. + expect(setPendingHistoryItemMock).toHaveBeenCalledTimes(2); + }); + // Simulate rapid output act(() => { mockShellOutputCallback({ @@ -224,28 +255,49 @@ describe('useShellCommandProcessor', () => { chunk: 'hello', }); }); + // The count should still be 2, as throttling is in effect. 
+ expect(setPendingHistoryItemMock).toHaveBeenCalledTimes(2); - // Should not have updated the UI yet - expect(setPendingHistoryItemMock).toHaveBeenCalledTimes(1); // Only the initial call + // Simulate more rapid output + act(() => { + mockShellOutputCallback({ + type: 'data', + chunk: ' world', + }); + }); + expect(setPendingHistoryItemMock).toHaveBeenCalledTimes(2); - // Advance time and send another event to trigger the throttled update + // Advance time, but the update won't happen until the next event await act(async () => { await vi.advanceTimersByTimeAsync(OUTPUT_UPDATE_INTERVAL_MS + 1); }); + + // Trigger one more event to cause the throttled update to fire. act(() => { mockShellOutputCallback({ type: 'data', - chunk: ' world', + chunk: '', }); }); - // Should now have been called with the cumulative output - expect(setPendingHistoryItemMock).toHaveBeenCalledTimes(2); - expect(setPendingHistoryItemMock).toHaveBeenLastCalledWith( - expect.objectContaining({ - tools: [expect.objectContaining({ resultDisplay: 'hello world' })], - }), - ); + // Now the cumulative update should have occurred. 
+ // Call 1: Initial, Call 2: PID update, Call 3: Throttled stream update + expect(setPendingHistoryItemMock).toHaveBeenCalledTimes(3); + + const streamUpdateFn = setPendingHistoryItemMock.mock.calls[2][0]; + if (!streamUpdateFn || typeof streamUpdateFn !== 'function') { + throw new Error( + 'setPendingHistoryItem was not called with a stream updater function', + ); + } + + // Get the state after the PID update to feed into the stream updater + const pidUpdateFn = setPendingHistoryItemMock.mock.calls[1][0]; + const initialState = setPendingHistoryItemMock.mock.calls[0][0]; + const stateAfterPid = pidUpdateFn(initialState); + + const stateAfterStream = streamUpdateFn(stateAfterPid); + expect(stateAfterStream.tools[0].resultDisplay).toBe('hello world'); }); it('should show binary progress messages correctly', async () => { @@ -269,7 +321,15 @@ describe('useShellCommandProcessor', () => { mockShellOutputCallback({ type: 'binary_progress', bytesReceived: 0 }); }); - expect(setPendingHistoryItemMock).toHaveBeenLastCalledWith( + // The state update is functional, so we test it by executing it. 
+ const updaterFn1 = setPendingHistoryItemMock.mock.lastCall?.[0]; + if (!updaterFn1) { + throw new Error('setPendingHistoryItem was not called'); + } + const initialState = setPendingHistoryItemMock.mock.calls[0][0]; + const stateAfterBinaryDetected = updaterFn1(initialState); + + expect(stateAfterBinaryDetected).toEqual( expect.objectContaining({ tools: [ expect.objectContaining({ @@ -290,7 +350,12 @@ describe('useShellCommandProcessor', () => { }); }); - expect(setPendingHistoryItemMock).toHaveBeenLastCalledWith( + const updaterFn2 = setPendingHistoryItemMock.mock.lastCall?.[0]; + if (!updaterFn2) { + throw new Error('setPendingHistoryItem was not called'); + } + const stateAfterProgress = updaterFn2(stateAfterBinaryDetected); + expect(stateAfterProgress).toEqual( expect.objectContaining({ tools: [ expect.objectContaining({ @@ -316,6 +381,7 @@ describe('useShellCommandProcessor', () => { expect.any(Function), expect.any(Object), false, + expect.any(Object), ); }); @@ -341,6 +407,7 @@ describe('useShellCommandProcessor', () => { expect(finalHistoryItem.tools[0].resultDisplay).toContain( 'Command was cancelled.', ); + expect(setShellInputFocusedMock).toHaveBeenCalledWith(false); }); it('should handle binary output result correctly', async () => { @@ -394,6 +461,7 @@ describe('useShellCommandProcessor', () => { type: 'error', text: 'An unexpected error occurred: Unexpected failure', }); + expect(setShellInputFocusedMock).toHaveBeenCalledWith(false); }); it('should handle synchronous errors during execution and clean up resources', async () => { @@ -425,6 +493,7 @@ describe('useShellCommandProcessor', () => { const tmpFile = path.join(os.tmpdir(), 'shell_pwd_abcdef.tmp'); // Verify that the temporary file was cleaned up expect(vi.mocked(fs.unlinkSync)).toHaveBeenCalledWith(tmpFile); + expect(setShellInputFocusedMock).toHaveBeenCalledWith(false); }); describe('Directory Change Warning', () => { @@ -473,4 +542,177 @@ describe('useShellCommandProcessor', () => { 
expect(finalHistoryItem.tools[0].resultDisplay).not.toContain('WARNING'); }); }); + + describe('ActiveShellPtyId management', () => { + beforeEach(() => { + // The real service returns a promise that resolves with the pid and result promise + mockShellExecutionService.mockImplementation((_cmd, _cwd, callback) => { + mockShellOutputCallback = callback; + return Promise.resolve({ + pid: 12345, + result: new Promise((resolve) => { + resolveExecutionPromise = resolve; + }), + }); + }); + }); + + it('should have activeShellPtyId as null initially', () => { + const { result } = renderProcessorHook(); + expect(result.current.activeShellPtyId).toBeNull(); + }); + + it('should set activeShellPtyId when a command with a PID starts', async () => { + const { result } = renderProcessorHook(); + + act(() => { + result.current.handleShellCommand('ls', new AbortController().signal); + }); + + await vi.waitFor(() => { + expect(result.current.activeShellPtyId).toBe(12345); + }); + }); + + it('should update the pending history item with the ptyId', async () => { + const { result } = renderProcessorHook(); + + act(() => { + result.current.handleShellCommand('ls', new AbortController().signal); + }); + + await vi.waitFor(() => { + // Wait for the second call which is the functional update + expect(setPendingHistoryItemMock).toHaveBeenCalledTimes(2); + }); + + // The state update is functional, so we test it by executing it. 
+ const updaterFn = setPendingHistoryItemMock.mock.lastCall?.[0]; + expect(typeof updaterFn).toBe('function'); + + // The initial state is the first call to setPendingHistoryItem + const initialState = setPendingHistoryItemMock.mock.calls[0][0]; + const stateAfterPid = updaterFn(initialState); + + expect(stateAfterPid.tools[0].ptyId).toBe(12345); + }); + + it('should reset activeShellPtyId to null after successful execution', async () => { + const { result } = renderProcessorHook(); + + act(() => { + result.current.handleShellCommand('ls', new AbortController().signal); + }); + const execPromise = onExecMock.mock.calls[0][0]; + + await vi.waitFor(() => { + expect(result.current.activeShellPtyId).toBe(12345); + }); + + act(() => { + resolveExecutionPromise(createMockServiceResult()); + }); + await act(async () => await execPromise); + + expect(result.current.activeShellPtyId).toBeNull(); + }); + + it('should reset activeShellPtyId to null after failed execution', async () => { + const { result } = renderProcessorHook(); + + act(() => { + result.current.handleShellCommand( + 'bad-cmd', + new AbortController().signal, + ); + }); + const execPromise = onExecMock.mock.calls[0][0]; + + await vi.waitFor(() => { + expect(result.current.activeShellPtyId).toBe(12345); + }); + + act(() => { + resolveExecutionPromise(createMockServiceResult({ exitCode: 1 })); + }); + await act(async () => await execPromise); + + expect(result.current.activeShellPtyId).toBeNull(); + }); + + it('should reset activeShellPtyId to null if execution promise rejects', async () => { + let rejectResultPromise: (reason?: unknown) => void; + mockShellExecutionService.mockImplementation(() => + Promise.resolve({ + pid: 1234_5, + result: new Promise((_, reject) => { + rejectResultPromise = reject; + }), + }), + ); + const { result } = renderProcessorHook(); + + act(() => { + result.current.handleShellCommand('cmd', new AbortController().signal); + }); + const execPromise = onExecMock.mock.calls[0][0]; + + 
await vi.waitFor(() => { + expect(result.current.activeShellPtyId).toBe(12345); + }); + + act(() => { + rejectResultPromise(new Error('Failure')); + }); + + await act(async () => await execPromise); + + expect(result.current.activeShellPtyId).toBeNull(); + }); + + it('should not set activeShellPtyId on synchronous execution error and should remain null', async () => { + mockShellExecutionService.mockImplementation(() => { + throw new Error('Sync Error'); + }); + const { result } = renderProcessorHook(); + + expect(result.current.activeShellPtyId).toBeNull(); // Pre-condition + + act(() => { + result.current.handleShellCommand('cmd', new AbortController().signal); + }); + const execPromise = onExecMock.mock.calls[0][0]; + + // The hook's state should not have changed to a PID + expect(result.current.activeShellPtyId).toBeNull(); + + await act(async () => await execPromise); // Let the promise resolve + + // And it should still be null after everything is done + expect(result.current.activeShellPtyId).toBeNull(); + }); + + it('should not set activeShellPtyId if service does not return a PID', async () => { + mockShellExecutionService.mockImplementation((_cmd, _cwd, callback) => { + mockShellOutputCallback = callback; + return Promise.resolve({ + pid: undefined, // No PID + result: new Promise((resolve) => { + resolveExecutionPromise = resolve; + }), + }); + }); + + const { result } = renderProcessorHook(); + + act(() => { + result.current.handleShellCommand('ls', new AbortController().signal); + }); + + // Let microtasks run + await act(async () => {}); + + expect(result.current.activeShellPtyId).toBeNull(); + }); + }); }); diff --git a/packages/cli/src/ui/hooks/shellCommandProcessor.ts b/packages/cli/src/ui/hooks/shellCommandProcessor.ts index ba5b23f0a95..9e0d21aefab 100644 --- a/packages/cli/src/ui/hooks/shellCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/shellCommandProcessor.ts @@ -9,8 +9,9 @@ import type { IndividualToolCallDisplay, } from '../types.js'; 
import { ToolCallStatus } from '../types.js'; -import { useCallback } from 'react'; +import { useCallback, useState } from 'react'; import type { + AnsiOutput, Config, GeminiClient, ShellExecutionResult, @@ -24,6 +25,7 @@ import crypto from 'node:crypto'; import path from 'node:path'; import os from 'node:os'; import fs from 'node:fs'; +import { themeManager } from '../../ui/themes/theme-manager.js'; export const OUTPUT_UPDATE_INTERVAL_MS = 1000; const MAX_OUTPUT_LENGTH = 10000; @@ -69,7 +71,11 @@ export const useShellCommandProcessor = ( onDebugMessage: (message: string) => void, config: Config, geminiClient: GeminiClient, + setShellInputFocused: (value: boolean) => void, + terminalWidth?: number, + terminalHeight?: number, ) => { + const [activeShellPtyId, setActiveShellPtyId] = useState(null); const handleShellCommand = useCallback( (rawQuery: PartListUnion, abortSignal: AbortSignal): boolean => { if (typeof rawQuery !== 'string' || rawQuery.trim() === '') { @@ -104,7 +110,7 @@ export const useShellCommandProcessor = ( resolve: (value: void | PromiseLike) => void, ) => { let lastUpdateTime = Date.now(); - let cumulativeStdout = ''; + let cumulativeStdout: string | AnsiOutput = ''; let isBinaryStream = false; let binaryBytesReceived = 0; @@ -134,18 +140,38 @@ export const useShellCommandProcessor = ( onDebugMessage(`Executing in ${targetDir}: ${commandToExecute}`); try { + const activeTheme = themeManager.getActiveTheme(); + const shellExecutionConfig = { + ...config.getShellExecutionConfig(), + defaultFg: activeTheme.colors.Foreground, + defaultBg: activeTheme.colors.Background, + }; + const { pid, result } = await ShellExecutionService.execute( commandToExecute, targetDir, (event) => { + let shouldUpdate = false; switch (event.type) { case 'data': // Do not process text data if we've already switched to binary mode. if (isBinaryStream) break; - cumulativeStdout += event.chunk; + // PTY provides the full screen state, so we just replace. 
+ // Child process provides chunks, so we append. + if ( + typeof event.chunk === 'string' && + typeof cumulativeStdout === 'string' + ) { + cumulativeStdout += event.chunk; + } else { + cumulativeStdout = event.chunk; + shouldUpdate = true; + } break; case 'binary_detected': isBinaryStream = true; + // Force an immediate UI update to show the binary detection message. + shouldUpdate = true; break; case 'binary_progress': isBinaryStream = true; @@ -157,7 +183,7 @@ export const useShellCommandProcessor = ( } // Compute the display string based on the *current* state. - let currentDisplayOutput: string; + let currentDisplayOutput: string | AnsiOutput; if (isBinaryStream) { if (binaryBytesReceived > 0) { currentDisplayOutput = `[Receiving binary output... ${formatMemoryUsage( @@ -171,25 +197,49 @@ export const useShellCommandProcessor = ( currentDisplayOutput = cumulativeStdout; } - // Throttle pending UI updates to avoid excessive re-renders. - if (Date.now() - lastUpdateTime > OUTPUT_UPDATE_INTERVAL_MS) { - setPendingHistoryItem({ - type: 'tool_group', - tools: [ - { - ...initialToolDisplay, - resultDisplay: currentDisplayOutput, - }, - ], + // Throttle pending UI updates, but allow forced updates. + if ( + shouldUpdate || + Date.now() - lastUpdateTime > OUTPUT_UPDATE_INTERVAL_MS + ) { + setPendingHistoryItem((prevItem) => { + if (prevItem?.type === 'tool_group') { + return { + ...prevItem, + tools: prevItem.tools.map((tool) => + tool.callId === callId + ? { ...tool, resultDisplay: currentDisplayOutput } + : tool, + ), + }; + } + return prevItem; }); lastUpdateTime = Date.now(); } }, abortSignal, config.getShouldUseNodePtyShell(), + shellExecutionConfig, ); + console.log(terminalHeight, terminalWidth); + executionPid = pid; + if (pid) { + setActiveShellPtyId(pid); + setPendingHistoryItem((prevItem) => { + if (prevItem?.type === 'tool_group') { + return { + ...prevItem, + tools: prevItem.tools.map((tool) => + tool.callId === callId ? 
{ ...tool, ptyId: pid } : tool, + ), + }; + } + return prevItem; + }); + } result .then((result: ShellExecutionResult) => { @@ -269,6 +319,8 @@ export const useShellCommandProcessor = ( if (pwdFilePath && fs.existsSync(pwdFilePath)) { fs.unlinkSync(pwdFilePath); } + setActiveShellPtyId(null); + setShellInputFocused(false); resolve(); }); } catch (err) { @@ -287,7 +339,8 @@ export const useShellCommandProcessor = ( if (pwdFilePath && fs.existsSync(pwdFilePath)) { fs.unlinkSync(pwdFilePath); } - + setActiveShellPtyId(null); + setShellInputFocused(false); resolve(); // Resolve the promise to unblock `onExec` } }; @@ -306,8 +359,11 @@ export const useShellCommandProcessor = ( setPendingHistoryItem, onExec, geminiClient, + setShellInputFocused, + terminalHeight, + terminalWidth, ], ); - return { handleShellCommand }; + return { handleShellCommand, activeShellPtyId }; }; diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts index 7a3e3c8f30e..4718cbb3ae9 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts @@ -4,6 +4,27 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { act, renderHook, waitFor } from '@testing-library/react'; +import { vi, describe, it, expect, beforeEach } from 'vitest'; +import { useSlashCommandProcessor } from './slashCommandProcessor.js'; +import type { + CommandContext, + ConfirmShellCommandsActionReturn, + SlashCommand, +} from '../commands/types.js'; +import { CommandKind } from '../commands/types.js'; +import type { LoadedSettings } from '../../config/settings.js'; +import { MessageType } from '../types.js'; +import { BuiltinCommandLoader } from '../../services/BuiltinCommandLoader.js'; +import { FileCommandLoader } from '../../services/FileCommandLoader.js'; +import { McpPromptLoader } from '../../services/McpPromptLoader.js'; +import { + type GeminiClient, + SlashCommandStatus, + 
ToolConfirmationOutcome, + makeFakeConfig, +} from '@blocksuser/gemini-cli-core'; + const { logSlashCommand } = vi.hoisted(() => ({ logSlashCommand: vi.fn(), })); @@ -113,7 +134,7 @@ describe('useSlashCommandProcessor', () => { beforeEach(() => { vi.clearAllMocks(); - (vi.mocked(BuiltinCommandLoader) as Mock).mockClear(); + vi.mocked(BuiltinCommandLoader).mockClear(); mockBuiltinLoadCommands.mockResolvedValue([]); mockFileLoadCommands.mockResolvedValue([]); mockMcpLoadCommands.mockResolvedValue([]); @@ -223,18 +244,6 @@ describe('useSlashCommandProcessor', () => { expect(fileAction).toHaveBeenCalledTimes(1); expect(builtinAction).not.toHaveBeenCalled(); }); - - it('should not include hidden commands in the command list', async () => { - const visibleCommand = createTestCommand({ name: 'visible' }); - const hiddenCommand = createTestCommand({ name: 'hidden', hidden: true }); - const result = setupProcessorHook([visibleCommand, hiddenCommand]); - - await waitFor(() => { - expect(result.current.slashCommands).toHaveLength(1); - }); - - expect(result.current.slashCommands[0].name).toBe('visible'); - }); }); describe('Command Execution Logic', () => { @@ -403,6 +412,12 @@ describe('useSlashCommandProcessor', () => { }); it('should handle "load_history" action', async () => { + const mockClient = { + setHistory: vi.fn(), + stripThoughtsFromHistory: vi.fn(), + } as unknown as GeminiClient; + vi.spyOn(mockConfig, 'getGeminiClient').mockReturnValue(mockClient); + const command = createTestCommand({ name: 'load', action: vi.fn().mockResolvedValue({ @@ -426,14 +441,11 @@ describe('useSlashCommandProcessor', () => { }); it('should strip thoughts when handling "load_history" action', async () => { - const mockSetHistory = vi.fn(); - const mockGeminiClient = { - setHistory: mockSetHistory, - }; - vi.spyOn(mockConfig, 'getGeminiClient').mockReturnValue( - // eslint-disable-next-line @typescript-eslint/no-explicit-any - mockGeminiClient as any, - ); + const mockClient = { + 
setHistory: vi.fn(), + stripThoughtsFromHistory: vi.fn(), + } as unknown as GeminiClient; + vi.spyOn(mockConfig, 'getGeminiClient').mockReturnValue(mockClient); const historyWithThoughts = [ { @@ -457,10 +469,8 @@ describe('useSlashCommandProcessor', () => { await result.current.handleSlashCommand('/loadwiththoughts'); }); - expect(mockSetHistory).toHaveBeenCalledTimes(1); - expect(mockSetHistory).toHaveBeenCalledWith(historyWithThoughts, { - stripThoughts: true, - }); + expect(mockClient.setHistory).toHaveBeenCalledTimes(1); + expect(mockClient.stripThoughtsFromHistory).toHaveBeenCalledWith(); }); it('should handle a "quit" action', async () => { diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.ts index 282b3418e65..da5a17dbc63 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.ts @@ -59,6 +59,7 @@ export const useSlashCommandProcessor = ( setIsProcessing: (isProcessing: boolean) => void, setGeminiMdFileCount: (count: number) => void, actions: SlashCommandProcessorActions, + isConfigInitialized: boolean, ) => { const session = useSessionStats(); const [commands, setCommands] = useState([]); @@ -255,7 +256,7 @@ export const useSlashCommandProcessor = ( return () => { controller.abort(); }; - }, [config, reloadTrigger]); + }, [config, reloadTrigger, isConfigInitialized]); const handleSlashCommand = useCallback( async ( @@ -401,9 +402,8 @@ export const useSlashCommandProcessor = ( } } case 'load_history': { - config - ?.getGeminiClient() - ?.setHistory(result.clientHistory, { stripThoughts: true }); + config?.getGeminiClient()?.setHistory(result.clientHistory); + config?.getGeminiClient()?.stripThoughtsFromHistory(); fullCommandContext.ui.clear(); result.history.forEach((item, index) => { fullCommandContext.ui.addItem(item, index); diff --git a/packages/cli/src/ui/hooks/useFolderTrust.test.ts 
b/packages/cli/src/ui/hooks/useFolderTrust.test.ts index 99a7c3b05eb..821e65ec2f3 100644 --- a/packages/cli/src/ui/hooks/useFolderTrust.test.ts +++ b/packages/cli/src/ui/hooks/useFolderTrust.test.ts @@ -22,23 +22,24 @@ vi.mock('process', () => ({ describe('useFolderTrust', () => { let mockSettings: LoadedSettings; - let mockConfig: unknown; let mockTrustedFolders: LoadedTrustedFolders; let loadTrustedFoldersSpy: vi.SpyInstance; let isWorkspaceTrustedSpy: vi.SpyInstance; let onTrustChange: (isTrusted: boolean | undefined) => void; + let refreshStatic: () => void; beforeEach(() => { mockSettings = { merged: { - folderTrustFeature: true, - folderTrust: undefined, + security: { + folderTrust: { + enabled: true, + }, + }, }, setValue: vi.fn(), } as unknown as LoadedSettings; - mockConfig = {} as unknown; - mockTrustedFolders = { setValue: vi.fn(), } as unknown as LoadedTrustedFolders; @@ -49,6 +50,7 @@ describe('useFolderTrust', () => { isWorkspaceTrustedSpy = vi.spyOn(trustedFolders, 'isWorkspaceTrusted'); (process.cwd as vi.Mock).mockReturnValue('/test/path'); onTrustChange = vi.fn(); + refreshStatic = vi.fn(); }); afterEach(() => { @@ -58,7 +60,7 @@ describe('useFolderTrust', () => { it('should not open dialog when folder is already trusted', () => { isWorkspaceTrustedSpy.mockReturnValue(true); const { result } = renderHook(() => - useFolderTrust(mockSettings, mockConfig, onTrustChange), + useFolderTrust(mockSettings, onTrustChange, refreshStatic), ); expect(result.current.isFolderTrustDialogOpen).toBe(false); expect(onTrustChange).toHaveBeenCalledWith(true); @@ -67,7 +69,7 @@ describe('useFolderTrust', () => { it('should not open dialog when folder is already untrusted', () => { isWorkspaceTrustedSpy.mockReturnValue(false); const { result } = renderHook(() => - useFolderTrust(mockSettings, mockConfig, onTrustChange), + useFolderTrust(mockSettings, onTrustChange, refreshStatic), ); expect(result.current.isFolderTrustDialogOpen).toBe(false); 
expect(onTrustChange).toHaveBeenCalledWith(false); @@ -76,7 +78,7 @@ describe('useFolderTrust', () => { it('should open dialog when folder trust is undefined', () => { isWorkspaceTrustedSpy.mockReturnValue(undefined); const { result } = renderHook(() => - useFolderTrust(mockSettings, mockConfig, onTrustChange), + useFolderTrust(mockSettings, onTrustChange, refreshStatic), ); expect(result.current.isFolderTrustDialogOpen).toBe(true); expect(onTrustChange).toHaveBeenCalledWith(undefined); @@ -87,7 +89,7 @@ describe('useFolderTrust', () => { .mockReturnValueOnce(undefined) .mockReturnValueOnce(true); const { result } = renderHook(() => - useFolderTrust(mockSettings, mockConfig, onTrustChange), + useFolderTrust(mockSettings, onTrustChange, refreshStatic), ); isWorkspaceTrustedSpy.mockReturnValue(true); @@ -109,7 +111,7 @@ describe('useFolderTrust', () => { .mockReturnValueOnce(undefined) .mockReturnValueOnce(true); const { result } = renderHook(() => - useFolderTrust(mockSettings, mockConfig, onTrustChange), + useFolderTrust(mockSettings, onTrustChange, refreshStatic), ); act(() => { @@ -129,7 +131,7 @@ describe('useFolderTrust', () => { .mockReturnValueOnce(undefined) .mockReturnValueOnce(false); const { result } = renderHook(() => - useFolderTrust(mockSettings, mockConfig, onTrustChange), + useFolderTrust(mockSettings, onTrustChange, refreshStatic), ); act(() => { @@ -141,14 +143,14 @@ describe('useFolderTrust', () => { TrustLevel.DO_NOT_TRUST, ); expect(onTrustChange).toHaveBeenLastCalledWith(false); - expect(result.current.isRestarting).toBe(false); - expect(result.current.isFolderTrustDialogOpen).toBe(false); + expect(result.current.isRestarting).toBe(true); + expect(result.current.isFolderTrustDialogOpen).toBe(true); }); it('should do nothing for default choice', () => { isWorkspaceTrustedSpy.mockReturnValue(undefined); const { result } = renderHook(() => - useFolderTrust(mockSettings, mockConfig, onTrustChange), + useFolderTrust(mockSettings, onTrustChange, 
refreshStatic), ); act(() => { @@ -166,15 +168,15 @@ describe('useFolderTrust', () => { it('should set isRestarting to true when trust status changes from false to true', () => { isWorkspaceTrustedSpy.mockReturnValueOnce(false).mockReturnValueOnce(true); // Initially untrusted, then trusted const { result } = renderHook(() => - useFolderTrust(mockSettings, mockConfig, onTrustChange), + useFolderTrust(mockSettings, onTrustChange, refreshStatic), ); act(() => { result.current.handleFolderTrustSelect(FolderTrustChoice.TRUST_FOLDER); }); - expect(result.current.isRestarting).toBe(false); - expect(result.current.isFolderTrustDialogOpen).toBe(false); // Dialog should close after selection + expect(result.current.isRestarting).toBe(true); + expect(result.current.isFolderTrustDialogOpen).toBe(true); // Dialog should stay open }); it('should not set isRestarting to true when trust status does not change', () => { @@ -182,7 +184,7 @@ describe('useFolderTrust', () => { .mockReturnValueOnce(undefined) .mockReturnValueOnce(true); // Initially undefined, then trust const { result } = renderHook(() => - useFolderTrust(mockSettings, mockConfig, onTrustChange), + useFolderTrust(mockSettings, onTrustChange, refreshStatic), ); act(() => { @@ -192,4 +194,26 @@ describe('useFolderTrust', () => { expect(result.current.isRestarting).toBe(false); expect(result.current.isFolderTrustDialogOpen).toBe(false); // Dialog should close }); + + it('should call refreshStatic when dialog opens and closes', () => { + isWorkspaceTrustedSpy.mockReturnValue(undefined); + const { result } = renderHook(() => + useFolderTrust(mockSettings, onTrustChange, refreshStatic), + ); + + // The hook runs, isFolderTrustDialogOpen becomes true, useEffect triggers. + // It's called once on mount, and once when the dialog state changes. 
+ expect(refreshStatic).toHaveBeenCalledTimes(2); + expect(result.current.isFolderTrustDialogOpen).toBe(true); + + // Now, simulate closing the dialog + isWorkspaceTrustedSpy.mockReturnValue(true); // So the state update works + act(() => { + result.current.handleFolderTrustSelect(FolderTrustChoice.TRUST_FOLDER); + }); + + // The state isFolderTrustDialogOpen becomes false, useEffect triggers again + expect(refreshStatic).toHaveBeenCalledTimes(3); + expect(result.current.isFolderTrustDialogOpen).toBe(false); + }); }); diff --git a/packages/cli/src/ui/hooks/useFolderTrust.ts b/packages/cli/src/ui/hooks/useFolderTrust.ts index 8d39d52301e..ddefc8827f1 100644 --- a/packages/cli/src/ui/hooks/useFolderTrust.ts +++ b/packages/cli/src/ui/hooks/useFolderTrust.ts @@ -5,7 +5,6 @@ */ import { useState, useCallback, useEffect } from 'react'; -import { type Config } from '@blocksuser/gemini-cli-core'; import type { LoadedSettings } from '../../config/settings.js'; import { FolderTrustChoice } from '../components/FolderTrustDialog.js'; import { @@ -17,12 +16,12 @@ import * as process from 'node:process'; export const useFolderTrust = ( settings: LoadedSettings, - config: Config, onTrustChange: (isTrusted: boolean | undefined) => void, + refreshStatic: () => void, ) => { const [isTrusted, setIsTrusted] = useState(undefined); const [isFolderTrustDialogOpen, setIsFolderTrustDialogOpen] = useState(false); - const [isRestarting] = useState(false); + const [isRestarting, setIsRestarting] = useState(false); const folderTrust = settings.merged.security?.folderTrust?.enabled; @@ -33,12 +32,20 @@ export const useFolderTrust = ( onTrustChange(trusted); }, [folderTrust, onTrustChange, settings.merged]); + useEffect(() => { + // When the folder trust dialog is about to open/close, we need to force a refresh + // of the static content to ensure the Tips are hidden/shown correctly. 
+ refreshStatic(); + }, [isFolderTrustDialogOpen, refreshStatic]); + const handleFolderTrustSelect = useCallback( (choice: FolderTrustChoice) => { const trustedFolders = loadTrustedFolders(); const cwd = process.cwd(); let trustLevel: TrustLevel; + const wasTrusted = isTrusted ?? true; + switch (choice) { case FolderTrustChoice.TRUST_FOLDER: trustLevel = TrustLevel.TRUST_FOLDER; @@ -54,12 +61,21 @@ export const useFolderTrust = ( } trustedFolders.setValue(cwd, trustLevel); - const trusted = isWorkspaceTrusted(settings.merged); - setIsTrusted(trusted); - setIsFolderTrustDialogOpen(false); - onTrustChange(trusted); + const currentIsTrusted = + trustLevel === TrustLevel.TRUST_FOLDER || + trustLevel === TrustLevel.TRUST_PARENT; + setIsTrusted(currentIsTrusted); + onTrustChange(currentIsTrusted); + + const needsRestart = wasTrusted !== currentIsTrusted; + if (needsRestart) { + setIsRestarting(true); + setIsFolderTrustDialogOpen(true); + } else { + setIsFolderTrustDialogOpen(false); + } }, - [settings.merged, onTrustChange], + [onTrustChange, isTrusted], ); return { diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx index 4ab600c8056..f3626ef7605 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx +++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx @@ -297,6 +297,9 @@ describe('useGeminiStream', () => { () => {}, () => {}, () => {}, + () => {}, + 80, + 24, ); }, { @@ -459,6 +462,9 @@ describe('useGeminiStream', () => { () => {}, () => {}, () => {}, + () => {}, + 80, + 24, ), ); @@ -539,6 +545,9 @@ describe('useGeminiStream', () => { () => {}, () => {}, () => {}, + () => {}, + 80, + 24, ), ); @@ -648,6 +657,9 @@ describe('useGeminiStream', () => { () => {}, () => {}, () => {}, + () => {}, + 80, + 24, ), ); @@ -758,6 +770,9 @@ describe('useGeminiStream', () => { () => {}, () => {}, () => {}, + () => {}, + 80, + 24, ), ); @@ -888,6 +903,9 @@ describe('useGeminiStream', () => { () => {}, () => 
{}, cancelSubmitSpy, + () => {}, + 80, + 24, ), ); @@ -901,6 +919,47 @@ describe('useGeminiStream', () => { expect(cancelSubmitSpy).toHaveBeenCalled(); }); + it('should call setShellInputFocused(false) when escape is pressed', async () => { + const setShellInputFocusedSpy = vi.fn(); + const mockStream = (async function* () { + yield { type: 'content', value: 'Part 1' }; + await new Promise(() => {}); // Keep stream open + })(); + mockSendMessageStream.mockReturnValue(mockStream); + + const { result } = renderHook(() => + useGeminiStream( + mockConfig.getGeminiClient(), + [], + mockAddItem, + mockConfig, + mockLoadedSettings, + mockOnDebugMessage, + mockHandleSlashCommand, + false, + () => 'vscode' as EditorType, + () => {}, + () => Promise.resolve(), + false, + () => {}, + () => {}, + vi.fn(), + setShellInputFocusedSpy, // Pass the spy here + 80, + 24, + ), + ); + + // Start a query + await act(async () => { + result.current.submitQuery('test query'); + }); + + simulateEscapeKeyPress(); + + expect(setShellInputFocusedSpy).toHaveBeenCalledWith(false); + }); + it('should not do anything if escape is pressed when not responding', () => { const { result } = renderTestHook(); @@ -1200,6 +1259,9 @@ describe('useGeminiStream', () => { () => {}, () => {}, () => {}, + () => {}, + 80, + 24, ), ); @@ -1254,6 +1316,9 @@ describe('useGeminiStream', () => { () => {}, () => {}, () => {}, + () => {}, + 80, + 24, ), ); @@ -1308,6 +1373,9 @@ describe('useGeminiStream', () => { () => {}, () => {}, () => {}, + () => {}, + 80, + 24, ), ); @@ -1360,6 +1428,9 @@ describe('useGeminiStream', () => { () => {}, () => {}, () => {}, + () => {}, + 80, + 24, ), ); @@ -1413,6 +1484,9 @@ describe('useGeminiStream', () => { () => {}, () => {}, () => {}, + () => {}, + 80, + 24, ), ); @@ -1495,6 +1569,7 @@ describe('useGeminiStream', () => { [], mockAddItem, mockConfig, + mockLoadedSettings, mockOnDebugMessage, mockHandleSlashCommand, false, @@ -1505,6 +1580,9 @@ describe('useGeminiStream', () => { 
() => {}, () => {}, () => {}, + vi.fn(), + 80, + 24, ), ); @@ -1556,6 +1634,9 @@ describe('useGeminiStream', () => { vi.fn(), // setModelSwitched vi.fn(), // onEditorClose vi.fn(), // onCancelSubmit + vi.fn(), // setShellInputFocused + 80, // terminalWidth + 24, // terminalHeight ), ); @@ -1624,6 +1705,9 @@ describe('useGeminiStream', () => { () => {}, () => {}, () => {}, + () => {}, + 80, + 24, ), ); @@ -1706,6 +1790,9 @@ describe('useGeminiStream', () => { () => {}, () => {}, () => {}, + () => {}, + 80, + 24, ), ); @@ -1761,6 +1848,9 @@ describe('useGeminiStream', () => { () => {}, () => {}, () => {}, + () => {}, + 80, + 24, ), ); @@ -1789,4 +1879,262 @@ describe('useGeminiStream', () => { ); }); }); + + describe('Loop Detection Confirmation', () => { + beforeEach(() => { + // Add mock for getLoopDetectionService to the config + const mockLoopDetectionService = { + disableForSession: vi.fn(), + }; + mockConfig.getGeminiClient = vi.fn().mockReturnValue({ + ...new MockedGeminiClientClass(mockConfig), + getLoopDetectionService: () => mockLoopDetectionService, + }); + }); + + it('should set loopDetectionConfirmationRequest when LoopDetected event is received', async () => { + mockSendMessageStream.mockReturnValue( + (async function* () { + yield { + type: ServerGeminiEventType.Content, + value: 'Some content', + }; + yield { + type: ServerGeminiEventType.LoopDetected, + }; + })(), + ); + + const { result } = renderTestHook(); + + await act(async () => { + await result.current.submitQuery('test query'); + }); + + await waitFor(() => { + expect(result.current.loopDetectionConfirmationRequest).not.toBeNull(); + expect( + typeof result.current.loopDetectionConfirmationRequest?.onComplete, + ).toBe('function'); + }); + }); + + it('should disable loop detection and show message when user selects "disable"', async () => { + const mockLoopDetectionService = { + disableForSession: vi.fn(), + }; + const mockClient = { + ...new MockedGeminiClientClass(mockConfig), + 
getLoopDetectionService: () => mockLoopDetectionService, + }; + mockConfig.getGeminiClient = vi.fn().mockReturnValue(mockClient); + + mockSendMessageStream.mockReturnValue( + (async function* () { + yield { + type: ServerGeminiEventType.LoopDetected, + }; + })(), + ); + + const { result } = renderTestHook(); + + await act(async () => { + await result.current.submitQuery('test query'); + }); + + // Wait for confirmation request to be set + await waitFor(() => { + expect(result.current.loopDetectionConfirmationRequest).not.toBeNull(); + }); + + // Simulate user selecting "disable" + await act(async () => { + result.current.loopDetectionConfirmationRequest?.onComplete({ + userSelection: 'disable', + }); + }); + + // Verify loop detection was disabled + expect(mockLoopDetectionService.disableForSession).toHaveBeenCalledTimes( + 1, + ); + + // Verify confirmation request was cleared + expect(result.current.loopDetectionConfirmationRequest).toBeNull(); + + // Verify appropriate message was added + expect(mockAddItem).toHaveBeenCalledWith( + { + type: 'info', + text: 'Loop detection has been disabled for this session. 
Please try your request again.', + }, + expect.any(Number), + ); + }); + + it('should keep loop detection enabled and show message when user selects "keep"', async () => { + const mockLoopDetectionService = { + disableForSession: vi.fn(), + }; + const mockClient = { + ...new MockedGeminiClientClass(mockConfig), + getLoopDetectionService: () => mockLoopDetectionService, + }; + mockConfig.getGeminiClient = vi.fn().mockReturnValue(mockClient); + + mockSendMessageStream.mockReturnValue( + (async function* () { + yield { + type: ServerGeminiEventType.LoopDetected, + }; + })(), + ); + + const { result } = renderTestHook(); + + await act(async () => { + await result.current.submitQuery('test query'); + }); + + // Wait for confirmation request to be set + await waitFor(() => { + expect(result.current.loopDetectionConfirmationRequest).not.toBeNull(); + }); + + // Simulate user selecting "keep" + await act(async () => { + result.current.loopDetectionConfirmationRequest?.onComplete({ + userSelection: 'keep', + }); + }); + + // Verify loop detection was NOT disabled + expect(mockLoopDetectionService.disableForSession).not.toHaveBeenCalled(); + + // Verify confirmation request was cleared + expect(result.current.loopDetectionConfirmationRequest).toBeNull(); + + // Verify appropriate message was added + expect(mockAddItem).toHaveBeenCalledWith( + { + type: 'info', + text: 'A potential loop was detected. This can happen due to repetitive tool calls or other model behavior. 
The request has been halted.', + }, + expect.any(Number), + ); + }); + + it('should handle multiple loop detection events properly', async () => { + const { result } = renderTestHook(); + + // First loop detection - set up fresh mock for first call + mockSendMessageStream.mockReturnValueOnce( + (async function* () { + yield { + type: ServerGeminiEventType.LoopDetected, + }; + })(), + ); + + // First loop detection + await act(async () => { + await result.current.submitQuery('first query'); + }); + + await waitFor(() => { + expect(result.current.loopDetectionConfirmationRequest).not.toBeNull(); + }); + + // Simulate user selecting "keep" for first request + await act(async () => { + result.current.loopDetectionConfirmationRequest?.onComplete({ + userSelection: 'keep', + }); + }); + + expect(result.current.loopDetectionConfirmationRequest).toBeNull(); + + // Verify first message was added + expect(mockAddItem).toHaveBeenCalledWith( + { + type: 'info', + text: 'A potential loop was detected. This can happen due to repetitive tool calls or other model behavior. The request has been halted.', + }, + expect.any(Number), + ); + + // Second loop detection - set up fresh mock for second call + mockSendMessageStream.mockReturnValueOnce( + (async function* () { + yield { + type: ServerGeminiEventType.LoopDetected, + }; + })(), + ); + + // Second loop detection + await act(async () => { + await result.current.submitQuery('second query'); + }); + + await waitFor(() => { + expect(result.current.loopDetectionConfirmationRequest).not.toBeNull(); + }); + + // Simulate user selecting "disable" for second request + await act(async () => { + result.current.loopDetectionConfirmationRequest?.onComplete({ + userSelection: 'disable', + }); + }); + + expect(result.current.loopDetectionConfirmationRequest).toBeNull(); + + // Verify second message was added + expect(mockAddItem).toHaveBeenCalledWith( + { + type: 'info', + text: 'Loop detection has been disabled for this session. 
Please try your request again.', + }, + expect.any(Number), + ); + }); + + it('should process LoopDetected event after moving pending history to history', async () => { + mockSendMessageStream.mockReturnValue( + (async function* () { + yield { + type: ServerGeminiEventType.Content, + value: 'Some response content', + }; + yield { + type: ServerGeminiEventType.LoopDetected, + }; + })(), + ); + + const { result } = renderTestHook(); + + await act(async () => { + await result.current.submitQuery('test query'); + }); + + // Verify that the content was added to history before the loop detection dialog + await waitFor(() => { + expect(mockAddItem).toHaveBeenCalledWith( + expect.objectContaining({ + type: 'gemini', + text: 'Some response content', + }), + expect.any(Number), + ); + }); + + // Then verify loop detection confirmation request was set + await waitFor(() => { + expect(result.current.loopDetectionConfirmationRequest).not.toBeNull(); + }); + }); + }); }); diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 7d239525db9..ff124907166 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -33,6 +33,7 @@ import { parseAndFormatApiError, getCodeAssistServer, UserTierId, + promptIdContext, } from '@blocksuser/gemini-cli-core'; import { type Part, type PartListUnion, FinishReason } from '@google/genai'; import type { @@ -101,6 +102,10 @@ export const useGeminiStream = ( setModelSwitchedFromQuotaError: React.Dispatch>, onEditorClose: () => void, onCancelSubmit: () => void, + setShellInputFocused: (value: boolean) => void, + terminalWidth: number, + terminalHeight: number, + isShellFocused?: boolean, ) => { const [initError, setInitError] = useState(null); const abortControllerRef = useRef(null); @@ -140,7 +145,6 @@ export const useGeminiStream = ( } }, config, - setPendingHistoryItem, getPreferredEditor, onEditorClose, ); @@ -151,22 +155,50 @@ export const 
useGeminiStream = ( [toolCalls], ); + const activeToolPtyId = useMemo(() => { + const executingShellTool = toolCalls?.find( + (tc) => + tc.status === 'executing' && tc.request.name === 'run_shell_command', + ); + if (executingShellTool) { + return (executingShellTool as { pid?: number }).pid; + } + return undefined; + }, [toolCalls]); + const loopDetectedRef = useRef(false); + const [ + loopDetectionConfirmationRequest, + setLoopDetectionConfirmationRequest, + ] = useState<{ + onComplete: (result: { userSelection: 'disable' | 'keep' }) => void; + } | null>(null); const onExec = useCallback(async (done: Promise) => { setIsResponding(true); await done; setIsResponding(false); }, []); - const { handleShellCommand } = useShellCommandProcessor( + const { handleShellCommand, activeShellPtyId } = useShellCommandProcessor( addItem, setPendingHistoryItem, onExec, onDebugMessage, config, geminiClient, + setShellInputFocused, + terminalWidth, + terminalHeight, ); + const activePtyId = activeShellPtyId || activeToolPtyId; + + useEffect(() => { + if (!activePtyId) { + setShellInputFocused(false); + } + }, [activePtyId, setShellInputFocused]); + const streamingState = useMemo(() => { if (toolCalls.some((tc) => tc.status === 'awaiting_approval')) { return StreamingState.WaitingForConfirmation; @@ -233,17 +265,19 @@ export const useGeminiStream = ( setPendingHistoryItem(null); onCancelSubmit(); setIsResponding(false); + setShellInputFocused(false); }, [ streamingState, addItem, setPendingHistoryItem, onCancelSubmit, pendingHistoryItemRef, + setShellInputFocused, ]); useKeypress( (key) => { - if (key.name === 'escape') { + if (key.name === 'escape' && !isShellFocused) { cancelOngoingRequest(); } }, @@ -587,15 +621,38 @@ export const useGeminiStream = ( [addItem, config], ); + const handleLoopDetectionConfirmation = useCallback( + (result: { userSelection: 'disable' | 'keep' }) => { + setLoopDetectionConfirmationRequest(null); + + if (result.userSelection === 'disable') { + 
config.getGeminiClient().getLoopDetectionService().disableForSession(); + addItem( + { + type: 'info', + text: `Loop detection has been disabled for this session. Please try your request again.`, + }, + Date.now(), + ); + } else { + addItem( + { + type: 'info', + text: `A potential loop was detected. This can happen due to repetitive tool calls or other model behavior. The request has been halted.`, + }, + Date.now(), + ); + } + }, + [config, addItem], + ); + const handleLoopDetectedEvent = useCallback(() => { - addItem( - { - type: 'info', - text: `A potential loop was detected. This can happen due to repetitive tool calls or other model behavior. The request has been halted.`, - }, - Date.now(), - ); - }, [addItem]); + // Show the confirmation dialog to choose whether to disable loop detection + setLoopDetectionConfirmationRequest({ + onComplete: handleLoopDetectionConfirmation, + }); + }, [handleLoopDetectionConfirmation]); const processGeminiStreamEvents = useCallback( async ( @@ -705,71 +762,72 @@ export const useGeminiStream = ( if (!prompt_id) { prompt_id = config.getSessionId() + '########' + getPromptCount(); } - - const { queryToSend, shouldProceed } = await prepareQueryForGemini( - query, - userMessageTimestamp, - abortSignal, - prompt_id!, - ); - - if (!shouldProceed || queryToSend === null) { - return; - } - - if (!options?.isContinuation) { - startNewPrompt(); - setThought(null); // Reset thought when starting a new prompt - } - - setIsResponding(true); - setInitError(null); - - try { - const stream = geminiClient.sendMessageStream( - queryToSend, - abortSignal, - prompt_id!, - ); - const processingStatus = await processGeminiStreamEvents( - stream, + return promptIdContext.run(prompt_id, async () => { + const { queryToSend, shouldProceed } = await prepareQueryForGemini( + query, userMessageTimestamp, abortSignal, + prompt_id, ); - if (processingStatus === StreamProcessingStatus.UserCancelled) { + if (!shouldProceed || queryToSend === null) { return; 
} - if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, userMessageTimestamp); - setPendingHistoryItem(null); - } - if (loopDetectedRef.current) { - loopDetectedRef.current = false; - handleLoopDetectedEvent(); + if (!options?.isContinuation) { + startNewPrompt(); + setThought(null); // Reset thought when starting a new prompt } - } catch (error: unknown) { - if (error instanceof UnauthorizedError) { - onAuthError('Session expired or is unauthorized.'); - } else if (!isNodeError(error) || error.name !== 'AbortError') { - addItem( - { - type: MessageType.ERROR, - text: parseAndFormatApiError( - getErrorMessage(error) || 'Unknown error', - config.getContentGeneratorConfig()?.authType, - undefined, - config.getModel(), - DEFAULT_GEMINI_FLASH_MODEL, - ), - }, + + setIsResponding(true); + setInitError(null); + + try { + const stream = geminiClient.sendMessageStream( + queryToSend, + abortSignal, + prompt_id, + ); + const processingStatus = await processGeminiStreamEvents( + stream, userMessageTimestamp, + abortSignal, ); + + if (processingStatus === StreamProcessingStatus.UserCancelled) { + return; + } + + if (pendingHistoryItemRef.current) { + addItem(pendingHistoryItemRef.current, userMessageTimestamp); + setPendingHistoryItem(null); + } + if (loopDetectedRef.current) { + loopDetectedRef.current = false; + handleLoopDetectedEvent(); + } + } catch (error: unknown) { + if (error instanceof UnauthorizedError) { + onAuthError('Session expired or is unauthorized.'); + } else if (!isNodeError(error) || error.name !== 'AbortError') { + addItem( + { + type: MessageType.ERROR, + text: parseAndFormatApiError( + getErrorMessage(error) || 'Unknown error', + config.getContentGeneratorConfig()?.authType, + undefined, + config.getModel(), + DEFAULT_GEMINI_FLASH_MODEL, + ), + }, + userMessageTimestamp, + ); + } + } finally { + setIsResponding(false); } - } finally { - setIsResponding(false); - } + }); }, [ streamingState, @@ -1043,5 +1101,7 @@ export const 
useGeminiStream = ( pendingHistoryItems, thought, cancelOngoingRequest, + activePtyId, + loopDetectionConfirmationRequest, }; }; diff --git a/packages/cli/src/ui/hooks/useIdeTrustListener.ts b/packages/cli/src/ui/hooks/useIdeTrustListener.ts index bc6d98ba85d..11d9a205351 100644 --- a/packages/cli/src/ui/hooks/useIdeTrustListener.ts +++ b/packages/cli/src/ui/hooks/useIdeTrustListener.ts @@ -5,7 +5,7 @@ */ import { useCallback, useEffect, useState, useSyncExternalStore } from 'react'; -import { IdeClient, ideContext } from '@blocksuser/gemini-cli-core'; +import { IdeClient, ideContextStore } from '@blocksuser/gemini-cli-core'; /** * This hook listens for trust status updates from the IDE companion extension. @@ -26,8 +26,7 @@ export function useIdeTrustListener() { }; }, []); - const getSnapshot = () => - ideContext.getIdeContext()?.workspaceState?.isTrusted; + const getSnapshot = () => ideContextStore.get()?.workspaceState?.isTrusted; const isIdeTrusted = useSyncExternalStore(subscribe, getSnapshot); diff --git a/packages/cli/src/ui/hooks/useMessageQueue.test.ts b/packages/cli/src/ui/hooks/useMessageQueue.test.ts index 01e49afe5f5..33dbf3211c5 100644 --- a/packages/cli/src/ui/hooks/useMessageQueue.test.ts +++ b/packages/cli/src/ui/hooks/useMessageQueue.test.ts @@ -25,6 +25,7 @@ describe('useMessageQueue', () => { it('should initialize with empty queue', () => { const { result } = renderHook(() => useMessageQueue({ + isConfigInitialized: true, streamingState: StreamingState.Idle, submitQuery: mockSubmitQuery, }), @@ -37,6 +38,7 @@ describe('useMessageQueue', () => { it('should add messages to queue', () => { const { result } = renderHook(() => useMessageQueue({ + isConfigInitialized: true, streamingState: StreamingState.Responding, submitQuery: mockSubmitQuery, }), @@ -56,6 +58,7 @@ describe('useMessageQueue', () => { it('should filter out empty messages', () => { const { result } = renderHook(() => useMessageQueue({ + isConfigInitialized: true, streamingState: 
StreamingState.Responding, submitQuery: mockSubmitQuery, }), @@ -77,6 +80,7 @@ describe('useMessageQueue', () => { it('should clear queue', () => { const { result } = renderHook(() => useMessageQueue({ + isConfigInitialized: true, streamingState: StreamingState.Responding, submitQuery: mockSubmitQuery, }), @@ -98,6 +102,7 @@ describe('useMessageQueue', () => { it('should return queued messages as text with double newlines', () => { const { result } = renderHook(() => useMessageQueue({ + isConfigInitialized: true, streamingState: StreamingState.Responding, submitQuery: mockSubmitQuery, }), @@ -118,6 +123,7 @@ describe('useMessageQueue', () => { const { result, rerender } = renderHook( ({ streamingState }) => useMessageQueue({ + isConfigInitialized: true, streamingState, submitQuery: mockSubmitQuery, }), @@ -145,6 +151,7 @@ describe('useMessageQueue', () => { const { rerender } = renderHook( ({ streamingState }) => useMessageQueue({ + isConfigInitialized: true, streamingState, submitQuery: mockSubmitQuery, }), @@ -163,6 +170,7 @@ describe('useMessageQueue', () => { const { result, rerender } = renderHook( ({ streamingState }) => useMessageQueue({ + isConfigInitialized: true, streamingState, submitQuery: mockSubmitQuery, }), @@ -187,6 +195,7 @@ describe('useMessageQueue', () => { const { result, rerender } = renderHook( ({ streamingState }) => useMessageQueue({ + isConfigInitialized: true, streamingState, submitQuery: mockSubmitQuery, }), diff --git a/packages/cli/src/ui/hooks/useMessageQueue.ts b/packages/cli/src/ui/hooks/useMessageQueue.ts index f7bbe1ebe70..517040fecfa 100644 --- a/packages/cli/src/ui/hooks/useMessageQueue.ts +++ b/packages/cli/src/ui/hooks/useMessageQueue.ts @@ -8,6 +8,7 @@ import { useCallback, useEffect, useState } from 'react'; import { StreamingState } from '../types.js'; export interface UseMessageQueueOptions { + isConfigInitialized: boolean; streamingState: StreamingState; submitQuery: (query: string) => void; } @@ -25,6 +26,7 @@ export 
interface UseMessageQueueReturn { * sends them when streaming completes. */ export function useMessageQueue({ + isConfigInitialized, streamingState, submitQuery, }: UseMessageQueueOptions): UseMessageQueueReturn { @@ -51,14 +53,18 @@ export function useMessageQueue({ // Process queued messages when streaming becomes idle useEffect(() => { - if (streamingState === StreamingState.Idle && messageQueue.length > 0) { + if ( + isConfigInitialized && + streamingState === StreamingState.Idle && + messageQueue.length > 0 + ) { // Combine all messages with double newlines for clarity const combinedMessage = messageQueue.join('\n\n'); // Clear the queue and submit setMessageQueue([]); submitQuery(combinedMessage); } - }, [streamingState, messageQueue, submitQuery]); + }, [isConfigInitialized, streamingState, messageQueue, submitQuery]); return { messageQueue, diff --git a/packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts b/packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts new file mode 100644 index 00000000000..d033aea4ea1 --- /dev/null +++ b/packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts @@ -0,0 +1,389 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + vi, + describe, + it, + expect, + beforeEach, + afterEach, + type Mock, +} from 'vitest'; +import { act, renderHook } from '@testing-library/react'; +import { + type Config, + type FallbackModelHandler, + UserTierId, + AuthType, + isGenericQuotaExceededError, + isProQuotaExceededError, + makeFakeConfig, +} from '@blocksuser/gemini-cli-core'; +import { useQuotaAndFallback } from './useQuotaAndFallback.js'; +import type { UseHistoryManagerReturn } from './useHistoryManager.js'; +import { AuthState, MessageType } from '../types.js'; + +// Mock the error checking functions from the core package to control test scenarios +vi.mock('@blocksuser/gemini-cli-core', async (importOriginal) => { + const original = + await importOriginal(); + return { + ...original, + 
isGenericQuotaExceededError: vi.fn(), + isProQuotaExceededError: vi.fn(), + }; +}); + +// Use a type alias for SpyInstance as it's not directly exported +type SpyInstance = ReturnType; + +describe('useQuotaAndFallback', () => { + let mockConfig: Config; + let mockHistoryManager: UseHistoryManagerReturn; + let mockSetAuthState: Mock; + let mockSetModelSwitchedFromQuotaError: Mock; + let setFallbackHandlerSpy: SpyInstance; + + const mockedIsGenericQuotaExceededError = isGenericQuotaExceededError as Mock; + const mockedIsProQuotaExceededError = isProQuotaExceededError as Mock; + + beforeEach(() => { + mockConfig = makeFakeConfig(); + + // Spy on the method that requires the private field and mock its return. + // This is cleaner than modifying the config class for tests. + vi.spyOn(mockConfig, 'getContentGeneratorConfig').mockReturnValue({ + authType: AuthType.LOGIN_WITH_GOOGLE, + }); + + mockHistoryManager = { + addItem: vi.fn(), + history: [], + updateItem: vi.fn(), + clearItems: vi.fn(), + loadHistory: vi.fn(), + }; + mockSetAuthState = vi.fn(); + mockSetModelSwitchedFromQuotaError = vi.fn(); + + setFallbackHandlerSpy = vi.spyOn(mockConfig, 'setFallbackModelHandler'); + vi.spyOn(mockConfig, 'setQuotaErrorOccurred'); + + mockedIsGenericQuotaExceededError.mockReturnValue(false); + mockedIsProQuotaExceededError.mockReturnValue(false); + }); + + afterEach(() => { + vi.clearAllMocks(); + }); + + it('should register a fallback handler on initialization', () => { + renderHook(() => + useQuotaAndFallback({ + config: mockConfig, + historyManager: mockHistoryManager, + userTier: UserTierId.FREE, + setAuthState: mockSetAuthState, + setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + }), + ); + + expect(setFallbackHandlerSpy).toHaveBeenCalledTimes(1); + expect(setFallbackHandlerSpy.mock.calls[0][0]).toBeInstanceOf(Function); + }); + + describe('Fallback Handler Logic', () => { + // Helper function to render the hook and extract the registered handler + const 
getRegisteredHandler = ( + userTier: UserTierId = UserTierId.FREE, + ): FallbackModelHandler => { + renderHook( + (props) => + useQuotaAndFallback({ + config: mockConfig, + historyManager: mockHistoryManager, + userTier: props.userTier, + setAuthState: mockSetAuthState, + setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + }), + { initialProps: { userTier } }, + ); + return setFallbackHandlerSpy.mock.calls[0][0] as FallbackModelHandler; + }; + + it('should return null and take no action if already in fallback mode', async () => { + vi.spyOn(mockConfig, 'isInFallbackMode').mockReturnValue(true); + const handler = getRegisteredHandler(); + const result = await handler('gemini-pro', 'gemini-flash', new Error()); + + expect(result).toBeNull(); + expect(mockHistoryManager.addItem).not.toHaveBeenCalled(); + }); + + it('should return null and take no action if authType is not LOGIN_WITH_GOOGLE', async () => { + // Override the default mock from beforeEach for this specific test + vi.spyOn(mockConfig, 'getContentGeneratorConfig').mockReturnValue({ + authType: AuthType.USE_GEMINI, + }); + + const handler = getRegisteredHandler(); + const result = await handler('gemini-pro', 'gemini-flash', new Error()); + + expect(result).toBeNull(); + expect(mockHistoryManager.addItem).not.toHaveBeenCalled(); + }); + + describe('Automatic Fallback Scenarios', () => { + const testCases = [ + { + errorType: 'generic', + tier: UserTierId.FREE, + expectedMessageSnippets: [ + 'Automatically switching from model-A to model-B', + 'upgrade to a Gemini Code Assist Standard or Enterprise plan', + ], + }, + { + errorType: 'generic', + tier: UserTierId.STANDARD, // Paid tier + expectedMessageSnippets: [ + 'Automatically switching from model-A to model-B', + 'switch to using a paid API key from AI Studio', + ], + }, + { + errorType: 'other', + tier: UserTierId.FREE, + expectedMessageSnippets: [ + 'Automatically switching from model-A to model-B for faster responses', + 'upgrade to a 
Gemini Code Assist Standard or Enterprise plan', + ], + }, + { + errorType: 'other', + tier: UserTierId.LEGACY, // Paid tier + expectedMessageSnippets: [ + 'Automatically switching from model-A to model-B for faster responses', + 'switch to using a paid API key from AI Studio', + ], + }, + ]; + + for (const { errorType, tier, expectedMessageSnippets } of testCases) { + it(`should handle ${errorType} error for ${tier} tier correctly`, async () => { + mockedIsGenericQuotaExceededError.mockReturnValue( + errorType === 'generic', + ); + + const handler = getRegisteredHandler(tier); + const result = await handler( + 'model-A', + 'model-B', + new Error('quota exceeded'), + ); + + // Automatic fallbacks should return 'stop' + expect(result).toBe('stop'); + + expect(mockHistoryManager.addItem).toHaveBeenCalledWith( + expect.objectContaining({ type: MessageType.INFO }), + expect.any(Number), + ); + + const message = (mockHistoryManager.addItem as Mock).mock.calls[0][0] + .text; + for (const snippet of expectedMessageSnippets) { + expect(message).toContain(snippet); + } + + expect(mockSetModelSwitchedFromQuotaError).toHaveBeenCalledWith(true); + expect(mockConfig.setQuotaErrorOccurred).toHaveBeenCalledWith(true); + }); + } + }); + + describe('Interactive Fallback (Pro Quota Error)', () => { + beforeEach(() => { + mockedIsProQuotaExceededError.mockReturnValue(true); + }); + + it('should set an interactive request and wait for user choice', async () => { + const { result } = renderHook(() => + useQuotaAndFallback({ + config: mockConfig, + historyManager: mockHistoryManager, + userTier: UserTierId.FREE, + setAuthState: mockSetAuthState, + setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + }), + ); + + const handler = setFallbackHandlerSpy.mock + .calls[0][0] as FallbackModelHandler; + + // Call the handler but do not await it, to check the intermediate state + const promise = handler( + 'gemini-pro', + 'gemini-flash', + new Error('pro quota'), + ); + + await 
act(async () => {}); + + // The hook should now have a pending request for the UI to handle + expect(result.current.proQuotaRequest).not.toBeNull(); + expect(result.current.proQuotaRequest?.failedModel).toBe('gemini-pro'); + + // Simulate the user choosing to continue with the fallback model + act(() => { + result.current.handleProQuotaChoice('continue'); + }); + + // The original promise from the handler should now resolve + const intent = await promise; + expect(intent).toBe('retry'); + + // The pending request should be cleared from the state + expect(result.current.proQuotaRequest).toBeNull(); + }); + + it('should handle race conditions by stopping subsequent requests', async () => { + const { result } = renderHook(() => + useQuotaAndFallback({ + config: mockConfig, + historyManager: mockHistoryManager, + userTier: UserTierId.FREE, + setAuthState: mockSetAuthState, + setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + }), + ); + + const handler = setFallbackHandlerSpy.mock + .calls[0][0] as FallbackModelHandler; + + const promise1 = handler( + 'gemini-pro', + 'gemini-flash', + new Error('pro quota 1'), + ); + await act(async () => {}); + + const firstRequest = result.current.proQuotaRequest; + expect(firstRequest).not.toBeNull(); + + const result2 = await handler( + 'gemini-pro', + 'gemini-flash', + new Error('pro quota 2'), + ); + + // The lock should have stopped the second request + expect(result2).toBe('stop'); + expect(result.current.proQuotaRequest).toBe(firstRequest); + + act(() => { + result.current.handleProQuotaChoice('continue'); + }); + + const intent1 = await promise1; + expect(intent1).toBe('retry'); + expect(result.current.proQuotaRequest).toBeNull(); + }); + }); + }); + + describe('handleProQuotaChoice', () => { + beforeEach(() => { + mockedIsProQuotaExceededError.mockReturnValue(true); + }); + + it('should do nothing if there is no pending pro quota request', () => { + const { result } = renderHook(() => + 
useQuotaAndFallback({ + config: mockConfig, + historyManager: mockHistoryManager, + userTier: UserTierId.FREE, + setAuthState: mockSetAuthState, + setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + }), + ); + + act(() => { + result.current.handleProQuotaChoice('auth'); + }); + + expect(mockSetAuthState).not.toHaveBeenCalled(); + expect(mockHistoryManager.addItem).not.toHaveBeenCalled(); + }); + + it('should resolve intent to "auth" and trigger auth state update', async () => { + const { result } = renderHook(() => + useQuotaAndFallback({ + config: mockConfig, + historyManager: mockHistoryManager, + userTier: UserTierId.FREE, + setAuthState: mockSetAuthState, + setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + }), + ); + + const handler = setFallbackHandlerSpy.mock + .calls[0][0] as FallbackModelHandler; + const promise = handler( + 'gemini-pro', + 'gemini-flash', + new Error('pro quota'), + ); + await act(async () => {}); // Allow state to update + + act(() => { + result.current.handleProQuotaChoice('auth'); + }); + + const intent = await promise; + expect(intent).toBe('auth'); + expect(mockSetAuthState).toHaveBeenCalledWith(AuthState.Updating); + expect(result.current.proQuotaRequest).toBeNull(); + }); + + it('should resolve intent to "retry" and add info message on continue', async () => { + const { result } = renderHook(() => + useQuotaAndFallback({ + config: mockConfig, + historyManager: mockHistoryManager, + userTier: UserTierId.FREE, + setAuthState: mockSetAuthState, + setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + }), + ); + + const handler = setFallbackHandlerSpy.mock + .calls[0][0] as FallbackModelHandler; + // The first `addItem` call is for the initial quota error message + const promise = handler( + 'gemini-pro', + 'gemini-flash', + new Error('pro quota'), + ); + await act(async () => {}); // Allow state to update + + act(() => { + result.current.handleProQuotaChoice('continue'); + }); + + 
const intent = await promise; + expect(intent).toBe('retry'); + expect(result.current.proQuotaRequest).toBeNull(); + + // Check for the second "Switched to fallback model" message + expect(mockHistoryManager.addItem).toHaveBeenCalledTimes(2); + const lastCall = (mockHistoryManager.addItem as Mock).mock.calls[1][0]; + expect(lastCall.type).toBe(MessageType.INFO); + expect(lastCall.text).toContain('Switched to fallback model.'); + }); + }); +}); diff --git a/packages/cli/src/ui/hooks/useQuotaAndFallback.ts b/packages/cli/src/ui/hooks/useQuotaAndFallback.ts new file mode 100644 index 00000000000..c8bc72f8074 --- /dev/null +++ b/packages/cli/src/ui/hooks/useQuotaAndFallback.ts @@ -0,0 +1,175 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + AuthType, + type Config, + type FallbackModelHandler, + type FallbackIntent, + isGenericQuotaExceededError, + isProQuotaExceededError, + UserTierId, +} from '@blocksuser/gemini-cli-core'; +import { useCallback, useEffect, useRef, useState } from 'react'; +import { type UseHistoryManagerReturn } from './useHistoryManager.js'; +import { AuthState, MessageType } from '../types.js'; +import { type ProQuotaDialogRequest } from '../contexts/UIStateContext.js'; + +interface UseQuotaAndFallbackArgs { + config: Config; + historyManager: UseHistoryManagerReturn; + userTier: UserTierId | undefined; + setAuthState: (state: AuthState) => void; + setModelSwitchedFromQuotaError: (value: boolean) => void; +} + +export function useQuotaAndFallback({ + config, + historyManager, + userTier, + setAuthState, + setModelSwitchedFromQuotaError, +}: UseQuotaAndFallbackArgs) { + const [proQuotaRequest, setProQuotaRequest] = + useState(null); + const isDialogPending = useRef(false); + + // Set up Flash fallback handler + useEffect(() => { + const fallbackHandler: FallbackModelHandler = async ( + failedModel, + fallbackModel, + error, + ): Promise => { + if (config.isInFallbackMode()) { + return null; 
+ } + + // Fallbacks are currently only handled for OAuth users. + const contentGeneratorConfig = config.getContentGeneratorConfig(); + if ( + !contentGeneratorConfig || + contentGeneratorConfig.authType !== AuthType.LOGIN_WITH_GOOGLE + ) { + return null; + } + + // Use actual user tier if available; otherwise, default to FREE tier behavior (safe default) + const isPaidTier = + userTier === UserTierId.LEGACY || userTier === UserTierId.STANDARD; + + let message: string; + + if (error && isProQuotaExceededError(error)) { + // Pro Quota specific messages (Interactive) + if (isPaidTier) { + message = `⚡ You have reached your daily ${failedModel} quota limit. +⚡ You can choose to authenticate with a paid API key or continue with the fallback model. +⚡ To continue accessing the ${failedModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`; + } else { + message = `⚡ You have reached your daily ${failedModel} quota limit. +⚡ You can choose to authenticate with a paid API key or continue with the fallback model. +⚡ To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist +⚡ Or you can utilize a Gemini API Key. 
See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key +⚡ You can switch authentication methods by typing /auth`; + } + } else if (error && isGenericQuotaExceededError(error)) { + // Generic Quota (Automatic fallback) + const actionMessage = `⚡ You have reached your daily quota limit.\n⚡ Automatically switching from ${failedModel} to ${fallbackModel} for the remainder of this session.`; + + if (isPaidTier) { + message = `${actionMessage} +⚡ To continue accessing the ${failedModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`; + } else { + message = `${actionMessage} +⚡ To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist +⚡ Or you can utilize a Gemini API Key. See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key +⚡ You can switch authentication methods by typing /auth`; + } + } else { + // Consecutive 429s or other errors (Automatic fallback) + const actionMessage = `⚡ Automatically switching from ${failedModel} to ${fallbackModel} for faster responses for the remainder of this session.`; + + if (isPaidTier) { + message = `${actionMessage} +⚡ Possible reasons for this are that you have received multiple consecutive capacity errors or you have reached your daily ${failedModel} quota limit +⚡ To continue accessing the ${failedModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`; + } else { + message = `${actionMessage} +⚡ Possible reasons for this are that you have received multiple consecutive capacity errors or you have reached your daily ${failedModel} quota limit +⚡ To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist +⚡ Or you can utilize a Gemini API Key. 
See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key +⚡ You can switch authentication methods by typing /auth`; + } + } + + // Add message to UI history + historyManager.addItem( + { + type: MessageType.INFO, + text: message, + }, + Date.now(), + ); + + setModelSwitchedFromQuotaError(true); + config.setQuotaErrorOccurred(true); + + // Interactive Fallback for Pro quota + if (error && isProQuotaExceededError(error)) { + if (isDialogPending.current) { + return 'stop'; // A dialog is already active, so just stop this request. + } + isDialogPending.current = true; + + const intent: FallbackIntent = await new Promise( + (resolve) => { + setProQuotaRequest({ + failedModel, + fallbackModel, + resolve, + }); + }, + ); + + return intent; + } + + return 'stop'; + }; + + config.setFallbackModelHandler(fallbackHandler); + }, [config, historyManager, userTier, setModelSwitchedFromQuotaError]); + + const handleProQuotaChoice = useCallback( + (choice: 'auth' | 'continue') => { + if (!proQuotaRequest) return; + + const intent: FallbackIntent = choice === 'auth' ? 'auth' : 'retry'; + proQuotaRequest.resolve(intent); + setProQuotaRequest(null); + isDialogPending.current = false; // Reset the flag here + + if (choice === 'auth') { + setAuthState(AuthState.Updating); + } else { + historyManager.addItem( + { + type: MessageType.INFO, + text: 'Switched to fallback model. 
Tip: Press Ctrl+P (or Up Arrow) to recall your previous prompt and submit it again if you wish.', + }, + Date.now(), + ); + } + }, + [proQuotaRequest, setAuthState, historyManager], + ); + + return { + proQuotaRequest, + handleProQuotaChoice, + }; +} diff --git a/packages/cli/src/ui/hooks/useReactToolScheduler.ts b/packages/cli/src/ui/hooks/useReactToolScheduler.ts index 3684f6b8f18..b8055c6e033 100644 --- a/packages/cli/src/ui/hooks/useReactToolScheduler.ts +++ b/packages/cli/src/ui/hooks/useReactToolScheduler.ts @@ -25,7 +25,6 @@ import { useCallback, useState, useMemo } from 'react'; import type { HistoryItemToolGroup, IndividualToolCallDisplay, - HistoryItemWithoutId, } from '../types.js'; import { ToolCallStatus } from '../types.js'; @@ -46,6 +45,7 @@ export type TrackedWaitingToolCall = WaitingToolCall & { }; export type TrackedExecutingToolCall = ExecutingToolCall & { responseSubmittedToGemini?: boolean; + pid?: number; }; export type TrackedCompletedToolCall = CompletedToolCall & { responseSubmittedToGemini?: boolean; @@ -65,9 +65,6 @@ export type TrackedToolCall = export function useReactToolScheduler( onComplete: (tools: CompletedToolCall[]) => Promise, config: Config, - setPendingHistoryItem: React.Dispatch< - React.SetStateAction - >, getPreferredEditor: () => EditorType | undefined, onEditorClose: () => void, ): [TrackedToolCall[], ScheduleFn, MarkToolsAsSubmittedFn] { @@ -77,21 +74,6 @@ export function useReactToolScheduler( const outputUpdateHandler: OutputUpdateHandler = useCallback( (toolCallId, outputChunk) => { - setPendingHistoryItem((prevItem) => { - if (prevItem?.type === 'tool_group') { - return { - ...prevItem, - tools: prevItem.tools.map((toolDisplay) => - toolDisplay.callId === toolCallId && - toolDisplay.status === ToolCallStatus.Executing - ? 
{ ...toolDisplay, resultDisplay: outputChunk } - : toolDisplay, - ), - }; - } - return prevItem; - }); - setToolCallsForDisplay((prevCalls) => prevCalls.map((tc) => { if (tc.request.callId === toolCallId && tc.status === 'executing') { @@ -102,7 +84,7 @@ export function useReactToolScheduler( }), ); }, - [setPendingHistoryItem], + [], ); const allToolCallsCompleteHandler: AllToolCallsCompleteHandler = useCallback( @@ -119,12 +101,29 @@ export function useReactToolScheduler( const existingTrackedCall = prevTrackedCalls.find( (ptc) => ptc.request.callId === coreTc.request.callId, ); - const newTrackedCall: TrackedToolCall = { + // Start with the new core state, then layer on the existing UI state + // to ensure UI-only properties like pid are preserved. + const responseSubmittedToGemini = + existingTrackedCall?.responseSubmittedToGemini ?? false; + + if (coreTc.status === 'executing') { + return { + ...coreTc, + responseSubmittedToGemini, + liveOutput: (existingTrackedCall as TrackedExecutingToolCall) + ?.liveOutput, + pid: (coreTc as ExecutingToolCall).pid, + }; + } + + // For other statuses, explicitly set liveOutput and pid to undefined + // to ensure they are not carried over from a previous executing state. + return { ...coreTc, - responseSubmittedToGemini: - existingTrackedCall?.responseSubmittedToGemini ?? false, - } as TrackedToolCall; - return newTrackedCall; + responseSubmittedToGemini, + liveOutput: undefined, + pid: undefined, + }; }), ); }, @@ -278,6 +277,7 @@ export function mapToDisplay( resultDisplay: (trackedCall as TrackedExecutingToolCall).liveOutput ?? 
undefined, confirmationDetails: undefined, + ptyId: (trackedCall as TrackedExecutingToolCall).pid, }; case 'validating': // Fallthrough case 'scheduled': diff --git a/packages/cli/src/ui/hooks/useSlashCompletion.test.ts b/packages/cli/src/ui/hooks/useSlashCompletion.test.ts index 6b8ce9682ba..b5568ce9b43 100644 --- a/packages/cli/src/ui/hooks/useSlashCompletion.test.ts +++ b/packages/cli/src/ui/hooks/useSlashCompletion.test.ts @@ -347,6 +347,31 @@ describe('useSlashCompletion', () => { expect(result.current.suggestions).toHaveLength(0); }); + + it('should not suggest hidden commands', async () => { + const slashCommands = [ + createTestCommand({ + name: 'visible', + description: 'A visible command', + }), + createTestCommand({ + name: 'hidden', + description: 'A hidden command', + hidden: true, + }), + ]; + const { result } = renderHook(() => + useTestHarnessForSlashCompletion( + true, + '/', + slashCommands, + mockCommandContext, + ), + ); + + expect(result.current.suggestions.length).toBe(1); + expect(result.current.suggestions[0].label).toBe('visible'); + }); }); describe('Sub-Commands', () => { diff --git a/packages/cli/src/ui/hooks/useSlashCompletion.ts b/packages/cli/src/ui/hooks/useSlashCompletion.ts index 87288090fef..a284e0bc6e9 100644 --- a/packages/cli/src/ui/hooks/useSlashCompletion.ts +++ b/packages/cli/src/ui/hooks/useSlashCompletion.ts @@ -225,7 +225,7 @@ function useCommandSuggestions( if (partial === '') { // If no partial query, show all available commands potentialSuggestions = commandsToSearch.filter( - (cmd) => cmd.description, + (cmd) => cmd.description && !cmd.hidden, ); } else { // Use fuzzy search for non-empty partial queries with fallback @@ -400,7 +400,7 @@ export function useSlashCompletion(props: UseSlashCompletionProps): { const commandMap = new Map(); commands.forEach((cmd) => { - if (cmd.description) { + if (cmd.description && !cmd.hidden) { commandItems.push(cmd.name); commandMap.set(cmd.name, cmd); diff --git 
a/packages/cli/src/ui/hooks/useToolScheduler.test.ts b/packages/cli/src/ui/hooks/useToolScheduler.test.ts index 3f4e1e7ecd1..7e6826992b5 100644 --- a/packages/cli/src/ui/hooks/useToolScheduler.test.ts +++ b/packages/cli/src/ui/hooks/useToolScheduler.test.ts @@ -68,6 +68,7 @@ const mockConfig = { }), getUseSmartEdit: () => false, getGeminiClient: () => null, // No client needed for these tests + getShellExecutionConfig: () => ({ terminalWidth: 80, terminalHeight: 24 }), } as unknown as Config; const mockTool = new MockTool({ @@ -124,7 +125,6 @@ describe('useReactToolScheduler in YOLO Mode', () => { onComplete, mockConfig as unknown as Config, setPendingHistoryItem, - () => undefined, () => {}, ), ); @@ -163,7 +163,7 @@ describe('useReactToolScheduler in YOLO Mode', () => { expect(mockToolRequiresConfirmation.execute).toHaveBeenCalledWith( request.args, expect.any(AbortSignal), - undefined /*updateOutputFn*/, + undefined, ); // Check that onComplete was called with success @@ -272,7 +272,6 @@ describe('useReactToolScheduler', () => { onComplete, mockConfig as unknown as Config, setPendingHistoryItem, - () => undefined, () => {}, ), ); @@ -314,7 +313,7 @@ describe('useReactToolScheduler', () => { expect(mockTool.execute).toHaveBeenCalledWith( request.args, expect.any(AbortSignal), - undefined /*updateOutputFn*/, + undefined, ); expect(onComplete).toHaveBeenCalledWith([ expect.objectContaining({ diff --git a/packages/cli/src/ui/keyMatchers.test.ts b/packages/cli/src/ui/keyMatchers.test.ts index e08cc2e0396..eb7f2332b9e 100644 --- a/packages/cli/src/ui/keyMatchers.test.ts +++ b/packages/cli/src/ui/keyMatchers.test.ts @@ -63,6 +63,8 @@ describe('keyMatchers', () => { key.name === 'return' && !key.ctrl, [Command.ACCEPT_SUGGESTION_REVERSE_SEARCH]: (key: Key) => key.name === 'tab', + [Command.TOGGLE_SHELL_INPUT_FOCUS]: (key: Key) => + key.ctrl && key.name === 'f', }; // Test data for each command with positive and negative test cases @@ -253,6 +255,11 @@ 
describe('keyMatchers', () => { positive: [createKey('tab'), createKey('tab', { ctrl: true })], negative: [createKey('return'), createKey('space')], }, + { + command: Command.TOGGLE_SHELL_INPUT_FOCUS, + positive: [createKey('f', { ctrl: true })], + negative: [createKey('f')], + }, ]; describe('Data-driven key binding matches original logic', () => { diff --git a/packages/cli/src/ui/privacy/CloudFreePrivacyNotice.tsx b/packages/cli/src/ui/privacy/CloudFreePrivacyNotice.tsx index accda1a5624..b0329a131d6 100644 --- a/packages/cli/src/ui/privacy/CloudFreePrivacyNotice.tsx +++ b/packages/cli/src/ui/privacy/CloudFreePrivacyNotice.tsx @@ -9,7 +9,7 @@ import { RadioButtonSelect } from '../components/shared/RadioButtonSelect.js'; import { usePrivacySettings } from '../hooks/usePrivacySettings.js'; import { CloudPaidPrivacyNotice } from './CloudPaidPrivacyNotice.js'; import type { Config } from '@blocksuser/gemini-cli-core'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import { useKeypress } from '../hooks/useKeypress.js'; interface CloudFreePrivacyNoticeProps { @@ -34,16 +34,16 @@ export const CloudFreePrivacyNotice = ({ ); if (privacyState.isLoading) { - return Loading...; + return Loading...; } if (privacyState.error) { return ( - + Error loading Opt-in settings: {privacyState.error} - Press Esc to exit. + Press Esc to exit. ); } @@ -59,17 +59,17 @@ export const CloudFreePrivacyNotice = ({ return ( - + Gemini Code Assist for Individuals Privacy Notice - + This notice and our Privacy Policy - [1] describe how Gemini Code - Assist handles your data. Please read them carefully. + [1] describe how Gemini Code Assist + handles your data. Please read them carefully. 
- + When you use Gemini Code Assist for individuals with Gemini CLI, Google collects your prompts, related code, generated output, code edits, related feature usage information, and your feedback to provide, @@ -77,7 +77,7 @@ export const CloudFreePrivacyNotice = ({ technologies. - + To help with quality and improve our products (such as generative machine-learning models), human reviewers may read, annotate, and process the data collected above. We take steps to protect your privacy @@ -90,7 +90,7 @@ export const CloudFreePrivacyNotice = ({ - + Allow Google to use this data to develop and improve our products? - [1]{' '} + [1]{' '} https://policies.google.com/privacy - Press Enter to choose an option and exit. + + Press Enter to choose an option and exit. + ); }; diff --git a/packages/cli/src/ui/privacy/CloudPaidPrivacyNotice.tsx b/packages/cli/src/ui/privacy/CloudPaidPrivacyNotice.tsx index f0adbb68e24..ce640308ece 100644 --- a/packages/cli/src/ui/privacy/CloudPaidPrivacyNotice.tsx +++ b/packages/cli/src/ui/privacy/CloudPaidPrivacyNotice.tsx @@ -5,7 +5,7 @@ */ import { Box, Newline, Text } from 'ink'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import { useKeypress } from '../hooks/useKeypress.js'; interface CloudPaidPrivacyNoticeProps { @@ -26,14 +26,14 @@ export const CloudPaidPrivacyNotice = ({ return ( - + Vertex AI Notice - - Service Specific Terms[1] are + + Service Specific Terms[1] are incorporated into the agreement under which Google has agreed to provide - Google Cloud Platform[2] to + Google Cloud Platform[2] to Customer (the “Agreement”). If the Agreement authorizes the resale or supply of Google Cloud Platform under a Google Cloud partner or reseller program, then except for in the section entitled “Partner-Specific @@ -44,16 +44,16 @@ export const CloudPaidPrivacyNotice = ({ them in the Agreement. 
- - [1]{' '} + + [1]{' '} https://cloud.google.com/terms/service-terms - - [2]{' '} + + [2]{' '} https://cloud.google.com/terms/services - Press Esc to exit. + Press Esc to exit. ); }; diff --git a/packages/cli/src/ui/privacy/GeminiPrivacyNotice.tsx b/packages/cli/src/ui/privacy/GeminiPrivacyNotice.tsx index c0eaa74f2d6..1f4015b5c25 100644 --- a/packages/cli/src/ui/privacy/GeminiPrivacyNotice.tsx +++ b/packages/cli/src/ui/privacy/GeminiPrivacyNotice.tsx @@ -5,7 +5,7 @@ */ import { Box, Newline, Text } from 'ink'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import { useKeypress } from '../hooks/useKeypress.js'; interface GeminiPrivacyNoticeProps { @@ -24,39 +24,39 @@ export const GeminiPrivacyNotice = ({ onExit }: GeminiPrivacyNoticeProps) => { return ( - + Gemini API Key Notice - - By using the Gemini API[1], - Google AI Studio - [2], and the other Google + + By using the Gemini API[1], Google + AI Studio + [2], and the other Google developer services that reference these terms (collectively, the "APIs" or "Services"), you are agreeing to Google APIs Terms of Service (the "API Terms") - [3], and the Gemini API + [3], and the Gemini API Additional Terms of Service (the "Additional Terms") - [4]. + [4]. - - [1]{' '} + + [1]{' '} https://ai.google.dev/docs/gemini_api_overview - - [2] https://aistudio.google.com/ + + [2] https://aistudio.google.com/ - - [3]{' '} + + [3]{' '} https://developers.google.com/terms - - [4]{' '} + + [4]{' '} https://ai.google.dev/gemini-api/terms - Press Esc to exit. + Press Esc to exit. 
); }; diff --git a/packages/cli/src/ui/themes/theme.ts b/packages/cli/src/ui/themes/theme.ts index ed6db897142..df33a59a108 100644 --- a/packages/cli/src/ui/themes/theme.ts +++ b/packages/cli/src/ui/themes/theme.ts @@ -174,8 +174,8 @@ export class Theme { focused: this.colors.AccentBlue, }, ui: { - comment: this.colors.Comment, - symbol: this.colors.Gray, + comment: this.colors.Gray, + symbol: this.colors.AccentCyan, gradient: this.colors.GradientColors, }, status: { @@ -410,31 +410,31 @@ export function createCustomTheme(customTheme: CustomTheme): Theme { const semanticColors: SemanticColors = { text: { - primary: colors.Foreground, - secondary: colors.Gray, - link: colors.AccentBlue, - accent: colors.AccentPurple, + primary: customTheme.text?.primary ?? colors.Foreground, + secondary: customTheme.text?.secondary ?? colors.Gray, + link: customTheme.text?.link ?? colors.AccentBlue, + accent: customTheme.text?.accent ?? colors.AccentPurple, }, background: { - primary: colors.Background, + primary: customTheme.background?.primary ?? colors.Background, diff: { - added: colors.DiffAdded, - removed: colors.DiffRemoved, + added: customTheme.background?.diff?.added ?? colors.DiffAdded, + removed: customTheme.background?.diff?.removed ?? colors.DiffRemoved, }, }, border: { - default: colors.Gray, - focused: colors.AccentBlue, + default: customTheme.border?.default ?? colors.Gray, + focused: customTheme.border?.focused ?? colors.AccentBlue, }, ui: { - comment: colors.Comment, - symbol: colors.Gray, - gradient: colors.GradientColors, + comment: customTheme.ui?.comment ?? colors.Comment, + symbol: customTheme.ui?.symbol ?? colors.Gray, + gradient: customTheme.ui?.gradient ?? colors.GradientColors, }, status: { - error: colors.AccentRed, - success: colors.AccentGreen, - warning: colors.AccentYellow, + error: customTheme.status?.error ?? colors.AccentRed, + success: customTheme.status?.success ?? colors.AccentGreen, + warning: customTheme.status?.warning ?? 
colors.AccentYellow, }, }; diff --git a/packages/cli/src/ui/types.ts b/packages/cli/src/ui/types.ts index 41c7676007b..7e8c875908e 100644 --- a/packages/cli/src/ui/types.ts +++ b/packages/cli/src/ui/types.ts @@ -66,6 +66,7 @@ export interface IndividualToolCallDisplay { status: ToolCallStatus; confirmationDetails: ToolCallConfirmationDetails | undefined; renderOutputAsMarkdown?: boolean; + ptyId?: number; outputFile?: string; } @@ -284,3 +285,7 @@ export interface ConfirmationRequest { prompt: ReactNode; onConfirm: (confirm: boolean) => void; } + +export interface LoopDetectionConfirmationRequest { + onComplete: (result: { userSelection: 'disable' | 'keep' }) => void; +} diff --git a/packages/cli/src/ui/utils/CodeColorizer.tsx b/packages/cli/src/ui/utils/CodeColorizer.tsx index e06e199a5e3..644248fd052 100644 --- a/packages/cli/src/ui/utils/CodeColorizer.tsx +++ b/packages/cli/src/ui/utils/CodeColorizer.tsx @@ -31,8 +31,9 @@ function renderHastNode( inheritedColor: string | undefined, ): React.ReactNode { if (node.type === 'text') { - // Use the color passed down from parent element, if any - return {node.value}; + // Use the color passed down from parent element, or the theme's default. 
+ const color = inheritedColor || theme.defaultColor; + return {node.value}; } // Handle Element Nodes: Determine color and pass it down, don't wrap diff --git a/packages/cli/src/ui/utils/InlineMarkdownRenderer.tsx b/packages/cli/src/ui/utils/InlineMarkdownRenderer.tsx index 4c05a28f4eb..4320c51967d 100644 --- a/packages/cli/src/ui/utils/InlineMarkdownRenderer.tsx +++ b/packages/cli/src/ui/utils/InlineMarkdownRenderer.tsx @@ -6,13 +6,13 @@ import React from 'react'; import { Text } from 'ink'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import stringWidth from 'string-width'; // Constants for Markdown parsing const BOLD_MARKER_LENGTH = 2; // For "**" const ITALIC_MARKER_LENGTH = 1; // For "*" or "_" -const STRIKETHROUGH_MARKER_LENGTH = 2; // For "~~" +const STRIKETHROUGH_MARKER_LENGTH = 2; // For "~~") const INLINE_CODE_MARKER_LENGTH = 1; // For "`" const UNDERLINE_TAG_START_LENGTH = 3; // For "" const UNDERLINE_TAG_END_LENGTH = 4; // For "" @@ -24,7 +24,7 @@ interface RenderInlineProps { const RenderInlineInternal: React.FC = ({ text }) => { // Early return for plain text without markdown or URLs if (!/[*_~`<[https?:]/.test(text)) { - return {text}; + return {text}; } const nodes: React.ReactNode[] = []; @@ -96,7 +96,7 @@ const RenderInlineInternal: React.FC = ({ text }) => { const codeMatch = fullMatch.match(/^(`+)(.+?)\1$/s); if (codeMatch && codeMatch[2]) { renderedNode = ( - + {codeMatch[2]} ); @@ -113,7 +113,7 @@ const RenderInlineInternal: React.FC = ({ text }) => { renderedNode = ( {linkText} - ({url}) + ({url}) ); } @@ -133,7 +133,7 @@ const RenderInlineInternal: React.FC = ({ text }) => { ); } else if (fullMatch.match(/^https?:\/\//)) { renderedNode = ( - + {fullMatch} ); @@ -168,6 +168,6 @@ export const getPlainTextLength = (text: string): number => { .replace(/~~(.*?)~~/g, '$1') .replace(/`(.*?)`/g, '$1') .replace(/(.*?)<\/u>/g, '$1') - .replace(/\[(.*?)\]\(.*?\)/g, '$1'); + .replace(/.*\[(.*?)\]\(.*\)/g, 
'$1'); return stringWidth(cleanText); }; diff --git a/packages/cli/src/ui/utils/MarkdownDisplay.tsx b/packages/cli/src/ui/utils/MarkdownDisplay.tsx index f5cbd84b3fb..2baea998848 100644 --- a/packages/cli/src/ui/utils/MarkdownDisplay.tsx +++ b/packages/cli/src/ui/utils/MarkdownDisplay.tsx @@ -7,7 +7,7 @@ import React from 'react'; import { Text, Box } from 'ink'; import { EOL } from 'node:os'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import { colorizeCode } from './CodeColorizer.js'; import { TableRenderer } from './TableRenderer.js'; import { RenderInline } from './InlineMarkdownRenderer.js'; @@ -174,35 +174,35 @@ const MarkdownDisplayInternal: React.FC = ({ switch (level) { case 1: headerNode = ( - + ); break; case 2: headerNode = ( - + ); break; case 3: headerNode = ( - + ); break; case 4: headerNode = ( - + ); break; default: headerNode = ( - + ); @@ -246,7 +246,7 @@ const MarkdownDisplayInternal: React.FC = ({ } else { addContentBlock( - + , @@ -315,7 +315,9 @@ const RenderCodeBlockInternal: React.FC = ({ // Not enough space to even show the message meaningfully return ( - ... code is being written ... + + ... code is being written ... + ); } @@ -331,7 +333,7 @@ const RenderCodeBlockInternal: React.FC = ({ return ( {colorizedTruncatedCode} - ... generating more ... + ... generating more ... 
); } @@ -384,10 +386,10 @@ const RenderListItemInternal: React.FC = ({ flexDirection="row" > - {prefix} + {prefix} - + diff --git a/packages/cli/src/ui/utils/TableRenderer.tsx b/packages/cli/src/ui/utils/TableRenderer.tsx index 2ec195491d1..3c1af38170d 100644 --- a/packages/cli/src/ui/utils/TableRenderer.tsx +++ b/packages/cli/src/ui/utils/TableRenderer.tsx @@ -6,7 +6,7 @@ import React from 'react'; import { Text, Box } from 'ink'; -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; import { RenderInline, getPlainTextLength } from './InlineMarkdownRenderer.js'; interface TableRendererProps { @@ -89,7 +89,7 @@ export const TableRenderer: React.FC = ({ return ( {isHeader ? ( - + ) : ( @@ -112,7 +112,7 @@ export const TableRenderer: React.FC = ({ const borderParts = adjustedWidths.map((w) => char.horizontal.repeat(w)); const border = char.left + borderParts.join(char.middle) + char.right; - return {border}; + return {border}; }; // Helper function to render a table row @@ -123,7 +123,7 @@ export const TableRenderer: React.FC = ({ }); return ( - + │{' '} {renderedCells.map((cell, index) => ( diff --git a/packages/cli/src/ui/utils/displayUtils.ts b/packages/cli/src/ui/utils/displayUtils.ts index a52c6ff0452..6f6c9209dbe 100644 --- a/packages/cli/src/ui/utils/displayUtils.ts +++ b/packages/cli/src/ui/utils/displayUtils.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { Colors } from '../colors.js'; +import { theme } from '../semantic-colors.js'; // --- Thresholds --- export const TOOL_SUCCESS_RATE_HIGH = 95; @@ -23,10 +23,10 @@ export const getStatusColor = ( options: { defaultColor?: string } = {}, ) => { if (value >= thresholds.green) { - return Colors.AccentGreen; + return theme.status.success; } if (value >= thresholds.yellow) { - return Colors.AccentYellow; + return theme.status.warning; } - return options.defaultColor || Colors.AccentRed; + return options.defaultColor || theme.status.error; }; diff --git 
a/packages/cli/src/ui/utils/textUtils.ts b/packages/cli/src/ui/utils/textUtils.ts index ac3d3398fb4..98f690eae33 100644 --- a/packages/cli/src/ui/utils/textUtils.ts +++ b/packages/cli/src/ui/utils/textUtils.ts @@ -6,6 +6,7 @@ import stripAnsi from 'strip-ansi'; import { stripVTControlCharacters } from 'node:util'; +import stringWidth from 'string-width'; /** * Calculates the maximum width of a multi-line ASCII art string. @@ -26,10 +27,39 @@ export const getAsciiArtWidth = (asciiArt: string): number => { * code units so that surrogate‑pair emoji count as one "column".) * ---------------------------------------------------------------------- */ +// Cache for code points to reduce GC pressure +const codePointsCache = new Map(); +const MAX_STRING_LENGTH_TO_CACHE = 1000; + export function toCodePoints(str: string): string[] { - // [...str] or Array.from both iterate by UTF‑32 code point, handling - // surrogate pairs correctly. - return Array.from(str); + // ASCII fast path - check if all chars are ASCII (0-127) + let isAscii = true; + for (let i = 0; i < str.length; i++) { + if (str.charCodeAt(i) > 127) { + isAscii = false; + break; + } + } + if (isAscii) { + return str.split(''); + } + + // Cache short strings + if (str.length <= MAX_STRING_LENGTH_TO_CACHE) { + const cached = codePointsCache.get(str); + if (cached) { + return cached; + } + } + + const result = Array.from(str); + + // Cache result (unlimited like Ink) + if (str.length <= MAX_STRING_LENGTH_TO_CACHE) { + codePointsCache.set(str, result); + } + + return result; } export function cpLen(str: string): number { @@ -86,3 +116,33 @@ export function stripUnsafeCharacters(str: string): string { }) .join(''); } + +// String width caching for performance optimization +const stringWidthCache = new Map(); + +/** + * Cached version of stringWidth function for better performance + * Follows Ink's approach with unlimited cache (no eviction) + */ +export const getCachedStringWidth = (str: string): number => { + // ASCII 
printable chars have width 1 + if (/^[\x20-\x7E]*$/.test(str)) { + return str.length; + } + + if (stringWidthCache.has(str)) { + return stringWidthCache.get(str)!; + } + + const width = stringWidth(str); + stringWidthCache.set(str, width); + + return width; +}; + +/** + * Clear the string width cache + */ +export const clearStringWidthCache = (): void => { + stringWidthCache.clear(); +}; diff --git a/packages/cli/src/utils/errors.test.ts b/packages/cli/src/utils/errors.test.ts new file mode 100644 index 00000000000..a74bf853233 --- /dev/null +++ b/packages/cli/src/utils/errors.test.ts @@ -0,0 +1,476 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { vi, type MockInstance } from 'vitest'; +import type { Config } from '@blocksuser/gemini-cli-core'; +import { OutputFormat, FatalInputError } from '@blocksuser/gemini-cli-core'; +import { + getErrorMessage, + handleError, + handleToolError, + handleCancellationError, + handleMaxTurnsExceededError, +} from './errors.js'; + +// Mock the core modules +vi.mock('@blocksuser/gemini-cli-core', async (importOriginal) => { + const original = + await importOriginal(); + + return { + ...original, + parseAndFormatApiError: vi.fn((error: unknown) => { + if (error instanceof Error) { + return `API Error: ${error.message}`; + } + return `API Error: ${String(error)}`; + }), + JsonFormatter: vi.fn().mockImplementation(() => ({ + formatError: vi.fn((error: Error, code?: string | number) => + JSON.stringify( + { + error: { + type: error.constructor.name, + message: error.message, + ...(code && { code }), + }, + }, + null, + 2, + ), + ), + })), + FatalToolExecutionError: class extends Error { + constructor(message: string) { + super(message); + this.name = 'FatalToolExecutionError'; + this.exitCode = 54; + } + exitCode: number; + }, + FatalCancellationError: class extends Error { + constructor(message: string) { + super(message); + this.name = 'FatalCancellationError'; + this.exitCode 
= 130; + } + exitCode: number; + }, + }; +}); + +describe('errors', () => { + let mockConfig: Config; + let processExitSpy: MockInstance; + let consoleErrorSpy: MockInstance; + + beforeEach(() => { + // Reset mocks + vi.clearAllMocks(); + + // Mock console.error + consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {}); + + // Mock process.exit to throw instead of actually exiting + processExitSpy = vi.spyOn(process, 'exit').mockImplementation((code) => { + throw new Error(`process.exit called with code: ${code}`); + }); + + // Create mock config + mockConfig = { + getOutputFormat: vi.fn().mockReturnValue(OutputFormat.TEXT), + getContentGeneratorConfig: vi.fn().mockReturnValue({ authType: 'test' }), + } as unknown as Config; + }); + + afterEach(() => { + consoleErrorSpy.mockRestore(); + processExitSpy.mockRestore(); + }); + + describe('getErrorMessage', () => { + it('should return error message for Error instances', () => { + const error = new Error('Test error message'); + expect(getErrorMessage(error)).toBe('Test error message'); + }); + + it('should convert non-Error values to strings', () => { + expect(getErrorMessage('string error')).toBe('string error'); + expect(getErrorMessage(123)).toBe('123'); + expect(getErrorMessage(null)).toBe('null'); + expect(getErrorMessage(undefined)).toBe('undefined'); + }); + + it('should handle objects', () => { + const obj = { message: 'test' }; + expect(getErrorMessage(obj)).toBe('[object Object]'); + }); + }); + + describe('handleError', () => { + describe('in text mode', () => { + beforeEach(() => { + ( + mockConfig.getOutputFormat as ReturnType + ).mockReturnValue(OutputFormat.TEXT); + }); + + it('should log error message and re-throw', () => { + const testError = new Error('Test error'); + + expect(() => { + handleError(testError, mockConfig); + }).toThrow(testError); + + expect(consoleErrorSpy).toHaveBeenCalledWith('API Error: Test error'); + }); + + it('should handle non-Error objects', () => { + const 
testError = 'String error'; + + expect(() => { + handleError(testError, mockConfig); + }).toThrow(testError); + + expect(consoleErrorSpy).toHaveBeenCalledWith('API Error: String error'); + }); + }); + + describe('in JSON mode', () => { + beforeEach(() => { + ( + mockConfig.getOutputFormat as ReturnType + ).mockReturnValue(OutputFormat.JSON); + }); + + it('should format error as JSON and exit with default code', () => { + const testError = new Error('Test error'); + + expect(() => { + handleError(testError, mockConfig); + }).toThrow('process.exit called with code: 1'); + + expect(consoleErrorSpy).toHaveBeenCalledWith( + JSON.stringify( + { + error: { + type: 'Error', + message: 'Test error', + code: 1, + }, + }, + null, + 2, + ), + ); + }); + + it('should use custom error code when provided', () => { + const testError = new Error('Test error'); + + expect(() => { + handleError(testError, mockConfig, 42); + }).toThrow('process.exit called with code: 42'); + + expect(consoleErrorSpy).toHaveBeenCalledWith( + JSON.stringify( + { + error: { + type: 'Error', + message: 'Test error', + code: 42, + }, + }, + null, + 2, + ), + ); + }); + + it('should extract exitCode from FatalError instances', () => { + const fatalError = new FatalInputError('Fatal error'); + + expect(() => { + handleError(fatalError, mockConfig); + }).toThrow('process.exit called with code: 42'); + + expect(consoleErrorSpy).toHaveBeenCalledWith( + JSON.stringify( + { + error: { + type: 'FatalInputError', + message: 'Fatal error', + code: 42, + }, + }, + null, + 2, + ), + ); + }); + + it('should handle error with code property', () => { + const errorWithCode = new Error('Error with code') as Error & { + code: number; + }; + errorWithCode.code = 404; + + expect(() => { + handleError(errorWithCode, mockConfig); + }).toThrow('process.exit called with code: 404'); + }); + + it('should handle error with status property', () => { + const errorWithStatus = new Error('Error with status') as Error & { + status: 
string; + }; + errorWithStatus.status = 'TIMEOUT'; + + expect(() => { + handleError(errorWithStatus, mockConfig); + }).toThrow('process.exit called with code: 1'); // string codes become 1 + + expect(consoleErrorSpy).toHaveBeenCalledWith( + JSON.stringify( + { + error: { + type: 'Error', + message: 'Error with status', + code: 'TIMEOUT', + }, + }, + null, + 2, + ), + ); + }); + }); + }); + + describe('handleToolError', () => { + const toolName = 'test-tool'; + const toolError = new Error('Tool failed'); + + describe('in text mode', () => { + beforeEach(() => { + ( + mockConfig.getOutputFormat as ReturnType + ).mockReturnValue(OutputFormat.TEXT); + }); + + it('should log error message to stderr', () => { + handleToolError(toolName, toolError, mockConfig); + + expect(consoleErrorSpy).toHaveBeenCalledWith( + 'Error executing tool test-tool: Tool failed', + ); + }); + + it('should use resultDisplay when provided', () => { + handleToolError( + toolName, + toolError, + mockConfig, + 'CUSTOM_ERROR', + 'Custom display message', + ); + + expect(consoleErrorSpy).toHaveBeenCalledWith( + 'Error executing tool test-tool: Custom display message', + ); + }); + }); + + describe('in JSON mode', () => { + beforeEach(() => { + ( + mockConfig.getOutputFormat as ReturnType + ).mockReturnValue(OutputFormat.JSON); + }); + + it('should format error as JSON and exit with default code', () => { + expect(() => { + handleToolError(toolName, toolError, mockConfig); + }).toThrow('process.exit called with code: 54'); + + expect(consoleErrorSpy).toHaveBeenCalledWith( + JSON.stringify( + { + error: { + type: 'FatalToolExecutionError', + message: 'Error executing tool test-tool: Tool failed', + code: 54, + }, + }, + null, + 2, + ), + ); + }); + + it('should use custom error code', () => { + expect(() => { + handleToolError(toolName, toolError, mockConfig, 'CUSTOM_TOOL_ERROR'); + }).toThrow('process.exit called with code: 54'); + + expect(consoleErrorSpy).toHaveBeenCalledWith( + JSON.stringify( + { 
+ error: { + type: 'FatalToolExecutionError', + message: 'Error executing tool test-tool: Tool failed', + code: 'CUSTOM_TOOL_ERROR', + }, + }, + null, + 2, + ), + ); + }); + + it('should use numeric error code and exit with that code', () => { + expect(() => { + handleToolError(toolName, toolError, mockConfig, 500); + }).toThrow('process.exit called with code: 500'); + + expect(consoleErrorSpy).toHaveBeenCalledWith( + JSON.stringify( + { + error: { + type: 'FatalToolExecutionError', + message: 'Error executing tool test-tool: Tool failed', + code: 500, + }, + }, + null, + 2, + ), + ); + }); + + it('should prefer resultDisplay over error message', () => { + expect(() => { + handleToolError( + toolName, + toolError, + mockConfig, + 'DISPLAY_ERROR', + 'Display message', + ); + }).toThrow('process.exit called with code: 54'); + + expect(consoleErrorSpy).toHaveBeenCalledWith( + JSON.stringify( + { + error: { + type: 'FatalToolExecutionError', + message: 'Error executing tool test-tool: Display message', + code: 'DISPLAY_ERROR', + }, + }, + null, + 2, + ), + ); + }); + }); + }); + + describe('handleCancellationError', () => { + describe('in text mode', () => { + beforeEach(() => { + ( + mockConfig.getOutputFormat as ReturnType + ).mockReturnValue(OutputFormat.TEXT); + }); + + it('should log cancellation message and exit with 130', () => { + expect(() => { + handleCancellationError(mockConfig); + }).toThrow('process.exit called with code: 130'); + + expect(consoleErrorSpy).toHaveBeenCalledWith('Operation cancelled.'); + }); + }); + + describe('in JSON mode', () => { + beforeEach(() => { + ( + mockConfig.getOutputFormat as ReturnType + ).mockReturnValue(OutputFormat.JSON); + }); + + it('should format cancellation as JSON and exit with 130', () => { + expect(() => { + handleCancellationError(mockConfig); + }).toThrow('process.exit called with code: 130'); + + expect(consoleErrorSpy).toHaveBeenCalledWith( + JSON.stringify( + { + error: { + type: 'FatalCancellationError', + 
message: 'Operation cancelled.', + code: 130, + }, + }, + null, + 2, + ), + ); + }); + }); + }); + + describe('handleMaxTurnsExceededError', () => { + describe('in text mode', () => { + beforeEach(() => { + ( + mockConfig.getOutputFormat as ReturnType + ).mockReturnValue(OutputFormat.TEXT); + }); + + it('should log max turns message and exit with 53', () => { + expect(() => { + handleMaxTurnsExceededError(mockConfig); + }).toThrow('process.exit called with code: 53'); + + expect(consoleErrorSpy).toHaveBeenCalledWith( + 'Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.', + ); + }); + }); + + describe('in JSON mode', () => { + beforeEach(() => { + ( + mockConfig.getOutputFormat as ReturnType + ).mockReturnValue(OutputFormat.JSON); + }); + + it('should format max turns error as JSON and exit with 53', () => { + expect(() => { + handleMaxTurnsExceededError(mockConfig); + }).toThrow('process.exit called with code: 53'); + + expect(consoleErrorSpy).toHaveBeenCalledWith( + JSON.stringify( + { + error: { + type: 'FatalTurnLimitedError', + message: + 'Reached max session turns for this session. 
Increase the number of turns by specifying maxSessionTurns in settings.json.', + code: 53, + }, + }, + null, + 2, + ), + ); + }); + }); + }); +}); diff --git a/packages/cli/src/utils/errors.ts b/packages/cli/src/utils/errors.ts index c1544dd9b42..67654eaf4e7 100644 --- a/packages/cli/src/utils/errors.ts +++ b/packages/cli/src/utils/errors.ts @@ -4,9 +4,159 @@ * SPDX-License-Identifier: Apache-2.0 */ +import type { Config } from '@blocksuser/gemini-cli-core'; +import { + OutputFormat, + JsonFormatter, + parseAndFormatApiError, + FatalTurnLimitedError, + FatalToolExecutionError, + FatalCancellationError, +} from '@blocksuser/gemini-cli-core'; + export function getErrorMessage(error: unknown): string { if (error instanceof Error) { return error.message; } return String(error); } + +interface ErrorWithCode extends Error { + exitCode?: number; + code?: string | number; + status?: string | number; +} + +/** + * Extracts the appropriate error code from an error object. + */ +function extractErrorCode(error: unknown): string | number { + const errorWithCode = error as ErrorWithCode; + + // Prioritize exitCode for FatalError types, fall back to other codes + if (typeof errorWithCode.exitCode === 'number') { + return errorWithCode.exitCode; + } + if (errorWithCode.code !== undefined) { + return errorWithCode.code; + } + if (errorWithCode.status !== undefined) { + return errorWithCode.status; + } + + return 1; // Default exit code +} + +/** + * Converts an error code to a numeric exit code. + */ +function getNumericExitCode(errorCode: string | number): number { + return typeof errorCode === 'number' ? errorCode : 1; +} + +/** + * Handles errors consistently for both JSON and text output formats. + * In JSON mode, outputs formatted JSON error and exits. + * In text mode, outputs error message and re-throws. 
+ */ +export function handleError( + error: unknown, + config: Config, + customErrorCode?: string | number, +): never { + const errorMessage = parseAndFormatApiError( + error, + config.getContentGeneratorConfig()?.authType, + ); + + if (config.getOutputFormat() === OutputFormat.JSON) { + const formatter = new JsonFormatter(); + const errorCode = customErrorCode ?? extractErrorCode(error); + + const formattedError = formatter.formatError( + error instanceof Error ? error : new Error(getErrorMessage(error)), + errorCode, + ); + + console.error(formattedError); + process.exit(getNumericExitCode(errorCode)); + } else { + console.error(errorMessage); + throw error; + } +} + +/** + * Handles tool execution errors specifically. + * In JSON mode, outputs formatted JSON error and exits. + * In text mode, outputs error message to stderr only. + */ +export function handleToolError( + toolName: string, + toolError: Error, + config: Config, + errorCode?: string | number, + resultDisplay?: string, +): void { + const errorMessage = `Error executing tool ${toolName}: ${resultDisplay || toolError.message}`; + const toolExecutionError = new FatalToolExecutionError(errorMessage); + + if (config.getOutputFormat() === OutputFormat.JSON) { + const formatter = new JsonFormatter(); + const formattedError = formatter.formatError( + toolExecutionError, + errorCode ?? toolExecutionError.exitCode, + ); + + console.error(formattedError); + process.exit( + typeof errorCode === 'number' ? errorCode : toolExecutionError.exitCode, + ); + } else { + console.error(errorMessage); + } +} + +/** + * Handles cancellation/abort signals consistently. 
+ */ +export function handleCancellationError(config: Config): never { + const cancellationError = new FatalCancellationError('Operation cancelled.'); + + if (config.getOutputFormat() === OutputFormat.JSON) { + const formatter = new JsonFormatter(); + const formattedError = formatter.formatError( + cancellationError, + cancellationError.exitCode, + ); + + console.error(formattedError); + process.exit(cancellationError.exitCode); + } else { + console.error(cancellationError.message); + process.exit(cancellationError.exitCode); + } +} + +/** + * Handles max session turns exceeded consistently. + */ +export function handleMaxTurnsExceededError(config: Config): never { + const maxTurnsError = new FatalTurnLimitedError( + 'Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.', + ); + + if (config.getOutputFormat() === OutputFormat.JSON) { + const formatter = new JsonFormatter(); + const formattedError = formatter.formatError( + maxTurnsError, + maxTurnsError.exitCode, + ); + + console.error(formattedError); + process.exit(maxTurnsError.exitCode); + } else { + console.error(maxTurnsError.message); + process.exit(maxTurnsError.exitCode); + } +} diff --git a/packages/cli/src/utils/sandbox.ts b/packages/cli/src/utils/sandbox.ts index 44820a67fe9..006cdf1cf0a 100644 --- a/packages/cli/src/utils/sandbox.ts +++ b/packages/cli/src/utils/sandbox.ts @@ -424,6 +424,9 @@ export async function start_sandbox( args.push('-t'); } + // allow access to host.docker.internal + args.push('--add-host', 'host.docker.internal:host-gateway'); + // mount current directory as working directory in sandbox (set via --workdir) args.push('--volume', `${workdir}:${containerWorkdir}`); diff --git a/packages/cli/src/utils/settingsUtils.test.ts b/packages/cli/src/utils/settingsUtils.test.ts index b6830abc4ce..ca1dc802c98 100644 --- a/packages/cli/src/utils/settingsUtils.test.ts +++ b/packages/cli/src/utils/settingsUtils.test.ts @@ -25,6 
+25,7 @@ import { // Business logic utilities getSettingValue, isSettingModified, + TEST_ONLY, settingExistsInScope, setPendingSettingValue, hasRestartRequiredSettings, @@ -34,15 +35,153 @@ import { isValueInherited, getEffectiveDisplayValue, } from './settingsUtils.js'; +import { + getSettingsSchema, + type SettingDefinition, + type Settings, + type SettingsSchema, + type SettingsSchemaType, +} from '../config/settingsSchema.js'; + +vi.mock('../config/settingsSchema.js', async (importOriginal) => { + const original = + await importOriginal(); + return { + ...original, + getSettingsSchema: vi.fn(), + }; +}); + +function makeMockSettings(settings: unknown): Settings { + return settings as Settings; +} describe('SettingsUtils', () => { + beforeEach(() => { + const SETTINGS_SCHEMA = { + mcpServers: { + type: 'object', + label: 'MCP Servers', + category: 'Advanced', + requiresRestart: true, + default: {} as Record, + description: 'Configuration for MCP servers.', + showInDialog: false, + }, + test: { + type: 'string', + label: 'Test', + category: 'Basic', + requiresRestart: false, + default: 'hello', + description: 'A test field', + showInDialog: true, + }, + advanced: { + type: 'object', + label: 'Advanced', + category: 'Advanced', + requiresRestart: true, + default: {}, + description: 'Advanced settings for power users.', + showInDialog: false, + }, + ui: { + type: 'object', + label: 'UI', + category: 'UI', + requiresRestart: false, + default: {}, + description: 'User interface settings.', + showInDialog: false, + properties: { + theme: { + type: 'string', + label: 'Theme', + category: 'UI', + requiresRestart: false, + default: undefined as string | undefined, + description: 'The color theme for the UI.', + showInDialog: false, + }, + requiresRestart: { + type: 'boolean', + label: 'Requires Restart', + category: 'UI', + default: false, + requiresRestart: true, + }, + accessibility: { + type: 'object', + label: 'Accessibility', + category: 'UI', + requiresRestart: 
true, + default: {}, + description: 'Accessibility settings.', + showInDialog: false, + properties: { + disableLoadingPhrases: { + type: 'boolean', + label: 'Disable Loading Phrases', + category: 'UI', + requiresRestart: true, + default: false, + description: 'Disable loading phrases for accessibility', + showInDialog: true, + }, + }, + }, + }, + }, + tools: { + type: 'object', + label: 'Tools', + category: 'Tools', + requiresRestart: false, + default: {}, + description: 'Tool settings.', + showInDialog: false, + properties: { + shell: { + type: 'object', + label: 'Shell', + category: 'Tools', + requiresRestart: false, + default: {}, + description: 'Shell tool settings.', + showInDialog: false, + properties: { + pager: { + type: 'string', + label: 'Pager', + category: 'Tools', + requiresRestart: false, + default: 'less', + description: 'The pager to use for long output.', + showInDialog: true, + }, + }, + }, + }, + }, + } as const satisfies SettingsSchema; + + vi.mocked(getSettingsSchema).mockReturnValue( + SETTINGS_SCHEMA as unknown as SettingsSchemaType, + ); + }); + afterEach(() => { + TEST_ONLY.clearFlattenedSchema(); + vi.clearAllMocks(); + vi.resetAllMocks(); + }); + describe('Schema Utilities', () => { describe('getSettingsByCategory', () => { it('should group settings by category', () => { const categories = getSettingsByCategory(); - - expect(categories).toHaveProperty('General'); - expect(categories).toHaveProperty('UI'); + expect(categories).toHaveProperty('Advanced'); + expect(categories).toHaveProperty('Basic'); }); it('should include key property in grouped settings', () => { @@ -58,9 +197,9 @@ describe('SettingsUtils', () => { describe('getSettingDefinition', () => { it('should return definition for valid setting', () => { - const definition = getSettingDefinition('ui.showMemoryUsage'); + const definition = getSettingDefinition('ui.theme'); expect(definition).toBeDefined(); - expect(definition?.label).toBe('Show Memory Usage'); + 
expect(definition?.label).toBe('Theme'); }); it('should return undefined for invalid setting', () => { @@ -71,13 +210,11 @@ describe('SettingsUtils', () => { describe('requiresRestart', () => { it('should return true for settings that require restart', () => { - expect(requiresRestart('advanced.autoConfigureMemory')).toBe(true); - expect(requiresRestart('general.checkpointing.enabled')).toBe(true); + expect(requiresRestart('ui.requiresRestart')).toBe(true); }); it('should return false for settings that do not require restart', () => { - expect(requiresRestart('ui.showMemoryUsage')).toBe(false); - expect(requiresRestart('ui.hideTips')).toBe(false); + expect(requiresRestart('ui.theme')).toBe(false); }); it('should return false for invalid settings', () => { @@ -87,10 +224,8 @@ describe('SettingsUtils', () => { describe('getDefaultValue', () => { it('should return correct default values', () => { - expect(getDefaultValue('ui.showMemoryUsage')).toBe(false); - expect( - getDefaultValue('context.fileFiltering.enableRecursiveFileSearch'), - ).toBe(true); + expect(getDefaultValue('test')).toBe('hello'); + expect(getDefaultValue('ui.requiresRestart')).toBe(false); }); it('should return undefined for invalid settings', () => { @@ -101,19 +236,20 @@ describe('SettingsUtils', () => { describe('getRestartRequiredSettings', () => { it('should return all settings that require restart', () => { const restartSettings = getRestartRequiredSettings(); - expect(restartSettings).toContain('advanced.autoConfigureMemory'); - expect(restartSettings).toContain('general.checkpointing.enabled'); - expect(restartSettings).not.toContain('ui.showMemoryUsage'); + expect(restartSettings).toContain('mcpServers'); + expect(restartSettings).toContain('ui.requiresRestart'); }); }); describe('getEffectiveValue', () => { it('should return value from settings when set', () => { - const settings = { ui: { showMemoryUsage: true } }; - const mergedSettings = { ui: { showMemoryUsage: false } }; + const 
settings = makeMockSettings({ ui: { requiresRestart: true } }); + const mergedSettings = makeMockSettings({ + ui: { requiresRestart: false }, + }); const value = getEffectiveValue( - 'ui.showMemoryUsage', + 'ui.requiresRestart', settings, mergedSettings, ); @@ -121,11 +257,13 @@ describe('SettingsUtils', () => { }); it('should return value from merged settings when not set in current scope', () => { - const settings = {}; - const mergedSettings = { ui: { showMemoryUsage: true } }; + const settings = makeMockSettings({}); + const mergedSettings = makeMockSettings({ + ui: { requiresRestart: true }, + }); const value = getEffectiveValue( - 'ui.showMemoryUsage', + 'ui.requiresRestart', settings, mergedSettings, ); @@ -133,11 +271,11 @@ describe('SettingsUtils', () => { }); it('should return default value when not set anywhere', () => { - const settings = {}; - const mergedSettings = {}; + const settings = makeMockSettings({}); + const mergedSettings = makeMockSettings({}); const value = getEffectiveValue( - 'ui.showMemoryUsage', + 'ui.requiresRestart', settings, mergedSettings, ); @@ -145,12 +283,12 @@ describe('SettingsUtils', () => { }); it('should handle nested settings correctly', () => { - const settings = { + const settings = makeMockSettings({ ui: { accessibility: { disableLoadingPhrases: true } }, - }; - const mergedSettings = { + }); + const mergedSettings = makeMockSettings({ ui: { accessibility: { disableLoadingPhrases: false } }, - }; + }); const value = getEffectiveValue( 'ui.accessibility.disableLoadingPhrases', @@ -161,8 +299,8 @@ describe('SettingsUtils', () => { }); it('should return undefined for invalid settings', () => { - const settings = {}; - const mergedSettings = {}; + const settings = makeMockSettings({}); + const mergedSettings = makeMockSettings({}); const value = getEffectiveValue( 'invalidSetting', @@ -176,9 +314,8 @@ describe('SettingsUtils', () => { describe('getAllSettingKeys', () => { it('should return all setting keys', () => { const 
keys = getAllSettingKeys(); - expect(keys).toContain('ui.showMemoryUsage'); + expect(keys).toContain('test'); expect(keys).toContain('ui.accessibility.disableLoadingPhrases'); - expect(keys).toContain('general.checkpointing.enabled'); }); }); @@ -204,7 +341,7 @@ describe('SettingsUtils', () => { describe('isValidSettingKey', () => { it('should return true for valid setting keys', () => { - expect(isValidSettingKey('ui.showMemoryUsage')).toBe(true); + expect(isValidSettingKey('ui.requiresRestart')).toBe(true); expect( isValidSettingKey('ui.accessibility.disableLoadingPhrases'), ).toBe(true); @@ -218,7 +355,7 @@ describe('SettingsUtils', () => { describe('getSettingCategory', () => { it('should return correct category for valid settings', () => { - expect(getSettingCategory('ui.showMemoryUsage')).toBe('UI'); + expect(getSettingCategory('ui.requiresRestart')).toBe('UI'); expect( getSettingCategory('ui.accessibility.disableLoadingPhrases'), ).toBe('UI'); @@ -231,20 +368,13 @@ describe('SettingsUtils', () => { describe('shouldShowInDialog', () => { it('should return true for settings marked to show in dialog', () => { - expect(shouldShowInDialog('ui.showMemoryUsage')).toBe(true); + expect(shouldShowInDialog('ui.requiresRestart')).toBe(true); expect(shouldShowInDialog('general.vimMode')).toBe(true); expect(shouldShowInDialog('ui.hideWindowTitle')).toBe(true); - expect(shouldShowInDialog('privacy.usageStatisticsEnabled')).toBe( - false, - ); }); it('should return false for settings marked to hide from dialog', () => { - expect(shouldShowInDialog('security.auth.selectedType')).toBe(false); - expect(shouldShowInDialog('tools.core')).toBe(false); - expect(shouldShowInDialog('ui.customThemes')).toBe(false); - expect(shouldShowInDialog('ui.theme')).toBe(false); // Changed to false - expect(shouldShowInDialog('general.preferredEditor')).toBe(false); // Changed to false + expect(shouldShowInDialog('ui.theme')).toBe(false); }); it('should return true for invalid settings (default 
behavior)', () => { @@ -260,9 +390,8 @@ describe('SettingsUtils', () => { expect(categories['UI']).toBeDefined(); const uiSettings = categories['UI']; const uiKeys = uiSettings.map((s) => s.key); - expect(uiKeys).toContain('ui.showMemoryUsage'); - expect(uiKeys).toContain('ui.hideWindowTitle'); - expect(uiKeys).not.toContain('ui.customThemes'); // This is marked false + expect(uiKeys).toContain('ui.requiresRestart'); + expect(uiKeys).toContain('ui.accessibility.disableLoadingPhrases'); expect(uiKeys).not.toContain('ui.theme'); // This is now marked false }); @@ -279,13 +408,8 @@ describe('SettingsUtils', () => { const allSettings = Object.values(categories).flat(); const allKeys = allSettings.map((s) => s.key); - expect(allKeys).toContain('general.vimMode'); - expect(allKeys).toContain('ide.enabled'); - expect(allKeys).toContain('general.disableAutoUpdate'); - expect(allKeys).toContain('ui.showMemoryUsage'); - expect(allKeys).not.toContain('privacy.usageStatisticsEnabled'); - expect(allKeys).not.toContain('security.auth.selectedType'); - expect(allKeys).not.toContain('tools.core'); + expect(allKeys).toContain('test'); + expect(allKeys).toContain('ui.requiresRestart'); expect(allKeys).not.toContain('ui.theme'); // Now hidden expect(allKeys).not.toContain('general.preferredEditor'); // Now hidden }); @@ -296,9 +420,8 @@ describe('SettingsUtils', () => { const booleanSettings = getDialogSettingsByType('boolean'); const keys = booleanSettings.map((s) => s.key); - expect(keys).toContain('ui.showMemoryUsage'); - expect(keys).toContain('general.vimMode'); - expect(keys).toContain('ui.hideWindowTitle'); + expect(keys).toContain('ui.requiresRestart'); + expect(keys).toContain('ui.accessibility.disableLoadingPhrases'); expect(keys).not.toContain('privacy.usageStatisticsEnabled'); expect(keys).not.toContain('security.auth.selectedType'); // Advanced setting expect(keys).not.toContain('security.auth.useExternal'); // Advanced setting @@ -313,8 +436,13 @@ 
describe('SettingsUtils', () => { expect(keys).not.toContain('general.preferredEditor'); // Now marked false expect(keys).not.toContain('security.auth.selectedType'); // Advanced setting - // Most string settings are now hidden, so let's just check they exclude advanced ones - expect(keys.every((key) => !key.startsWith('tool'))).toBe(true); // No tool-related settings + // Check that user-facing tool settings are included + expect(keys).toContain('tools.shell.pager'); + + // Check that advanced/hidden tool settings are excluded + expect(keys).not.toContain('tools.discoveryCommand'); + expect(keys).not.toContain('tools.callCommand'); + expect(keys.every((key) => !key.startsWith('advanced.'))).toBe(true); }); }); @@ -323,30 +451,13 @@ describe('SettingsUtils', () => { const dialogKeys = getDialogSettingKeys(); // Should include settings marked for dialog - expect(dialogKeys).toContain('ui.showMemoryUsage'); - expect(dialogKeys).toContain('general.vimMode'); - expect(dialogKeys).toContain('ui.hideWindowTitle'); - expect(dialogKeys).not.toContain('privacy.usageStatisticsEnabled'); - expect(dialogKeys).toContain('ide.enabled'); - expect(dialogKeys).toContain('general.disableAutoUpdate'); + expect(dialogKeys).toContain('ui.requiresRestart'); // Should include nested settings marked for dialog - expect(dialogKeys).toContain('context.fileFiltering.respectGitIgnore'); - expect(dialogKeys).toContain( - 'context.fileFiltering.respectGeminiIgnore', - ); - expect(dialogKeys).toContain( - 'context.fileFiltering.enableRecursiveFileSearch', - ); + expect(dialogKeys).toContain('ui.accessibility.disableLoadingPhrases'); // Should NOT include settings marked as hidden expect(dialogKeys).not.toContain('ui.theme'); // Hidden - expect(dialogKeys).not.toContain('ui.customThemes'); // Hidden - expect(dialogKeys).not.toContain('general.preferredEditor'); // Hidden - expect(dialogKeys).not.toContain('security.auth.selectedType'); // Advanced - expect(dialogKeys).not.toContain('tools.core'); 
// Advanced - expect(dialogKeys).not.toContain('mcpServers'); // Advanced - expect(dialogKeys).not.toContain('telemetry'); // Advanced }); it('should return fewer keys than getAllSettingKeys', () => { @@ -358,10 +469,44 @@ describe('SettingsUtils', () => { }); it('should handle nested settings display correctly', () => { + vi.mocked(getSettingsSchema).mockReturnValue({ + context: { + type: 'object', + label: 'Context', + category: 'Context', + requiresRestart: false, + default: {}, + description: 'Settings for managing context provided to the model.', + showInDialog: false, + properties: { + fileFiltering: { + type: 'object', + label: 'File Filtering', + category: 'Context', + requiresRestart: true, + default: {}, + description: 'Settings for git-aware file filtering.', + showInDialog: false, + properties: { + respectGitIgnore: { + type: 'boolean', + label: 'Respect .gitignore', + category: 'Context', + requiresRestart: true, + default: true, + description: 'Respect .gitignore files when searching', + showInDialog: true, + }, + }, + }, + }, + }, + } as unknown as SettingsSchemaType); + // Test the specific issue with fileFiltering.respectGitIgnore const key = 'context.fileFiltering.respectGitIgnore'; - const initialSettings = {}; - const pendingSettings = {}; + const initialSettings = makeMockSettings({}); + const pendingSettings = makeMockSettings({}); // Set the nested setting to true const updatedPendingSettings = setPendingSettingValue( @@ -412,11 +557,13 @@ describe('SettingsUtils', () => { describe('Business Logic Utilities', () => { describe('getSettingValue', () => { it('should return value from settings when set', () => { - const settings = { ui: { showMemoryUsage: true } }; - const mergedSettings = { ui: { showMemoryUsage: false } }; + const settings = makeMockSettings({ ui: { requiresRestart: true } }); + const mergedSettings = makeMockSettings({ + ui: { requiresRestart: false }, + }); const value = getSettingValue( - 'ui.showMemoryUsage', + 
'ui.requiresRestart', settings, mergedSettings, ); @@ -424,11 +571,13 @@ describe('SettingsUtils', () => { }); it('should return value from merged settings when not set in current scope', () => { - const settings = {}; - const mergedSettings = { ui: { showMemoryUsage: true } }; + const settings = makeMockSettings({}); + const mergedSettings = makeMockSettings({ + ui: { requiresRestart: true }, + }); const value = getSettingValue( - 'ui.showMemoryUsage', + 'ui.requiresRestart', settings, mergedSettings, ); @@ -436,8 +585,8 @@ describe('SettingsUtils', () => { }); it('should return default value for invalid setting', () => { - const settings = {}; - const mergedSettings = {}; + const settings = makeMockSettings({}); + const mergedSettings = makeMockSettings({}); const value = getSettingValue( 'invalidSetting', @@ -450,43 +599,37 @@ describe('SettingsUtils', () => { describe('isSettingModified', () => { it('should return true when value differs from default', () => { - expect(isSettingModified('ui.showMemoryUsage', true)).toBe(true); + expect(isSettingModified('ui.requiresRestart', true)).toBe(true); expect( - isSettingModified( - 'context.fileFiltering.enableRecursiveFileSearch', - false, - ), + isSettingModified('ui.accessibility.disableLoadingPhrases', true), ).toBe(true); }); it('should return false when value matches default', () => { - expect(isSettingModified('ui.showMemoryUsage', false)).toBe(false); + expect(isSettingModified('ui.requiresRestart', false)).toBe(false); expect( - isSettingModified( - 'context.fileFiltering.enableRecursiveFileSearch', - true, - ), + isSettingModified('ui.accessibility.disableLoadingPhrases', false), ).toBe(false); }); }); describe('settingExistsInScope', () => { it('should return true for top-level settings that exist', () => { - const settings = { ui: { showMemoryUsage: true } }; - expect(settingExistsInScope('ui.showMemoryUsage', settings)).toBe(true); + const settings = makeMockSettings({ ui: { requiresRestart: true } }); + 
expect(settingExistsInScope('ui.requiresRestart', settings)).toBe(true); }); it('should return false for top-level settings that do not exist', () => { - const settings = {}; - expect(settingExistsInScope('ui.showMemoryUsage', settings)).toBe( + const settings = makeMockSettings({}); + expect(settingExistsInScope('ui.requiresRestart', settings)).toBe( false, ); }); it('should return true for nested settings that exist', () => { - const settings = { + const settings = makeMockSettings({ ui: { accessibility: { disableLoadingPhrases: true } }, - }; + }); expect( settingExistsInScope( 'ui.accessibility.disableLoadingPhrases', @@ -496,7 +639,7 @@ describe('SettingsUtils', () => { }); it('should return false for nested settings that do not exist', () => { - const settings = {}; + const settings = makeMockSettings({}); expect( settingExistsInScope( 'ui.accessibility.disableLoadingPhrases', @@ -506,7 +649,7 @@ describe('SettingsUtils', () => { }); it('should return false when parent exists but child does not', () => { - const settings = { ui: { accessibility: {} } }; + const settings = makeMockSettings({ ui: { accessibility: {} } }); expect( settingExistsInScope( 'ui.accessibility.disableLoadingPhrases', @@ -518,18 +661,18 @@ describe('SettingsUtils', () => { describe('setPendingSettingValue', () => { it('should set top-level setting value', () => { - const pendingSettings = {}; + const pendingSettings = makeMockSettings({}); const result = setPendingSettingValue( - 'ui.showMemoryUsage', + 'ui.hideWindowTitle', true, pendingSettings, ); - expect(result.ui?.showMemoryUsage).toBe(true); + expect(result.ui?.hideWindowTitle).toBe(true); }); it('should set nested setting value', () => { - const pendingSettings = {}; + const pendingSettings = makeMockSettings({}); const result = setPendingSettingValue( 'ui.accessibility.disableLoadingPhrases', true, @@ -540,9 +683,9 @@ describe('SettingsUtils', () => { }); it('should preserve existing nested settings', () => { - const 
pendingSettings = { + const pendingSettings = makeMockSettings({ ui: { accessibility: { disableLoadingPhrases: false } }, - }; + }); const result = setPendingSettingValue( 'ui.accessibility.disableLoadingPhrases', true, @@ -553,8 +696,8 @@ describe('SettingsUtils', () => { }); it('should not mutate original settings', () => { - const pendingSettings = {}; - setPendingSettingValue('ui.showMemoryUsage', true, pendingSettings); + const pendingSettings = makeMockSettings({}); + setPendingSettingValue('ui.requiresRestart', true, pendingSettings); expect(pendingSettings).toEqual({}); }); @@ -564,16 +707,13 @@ describe('SettingsUtils', () => { it('should return true when modified settings require restart', () => { const modifiedSettings = new Set([ 'advanced.autoConfigureMemory', - 'ui.showMemoryUsage', + 'ui.requiresRestart', ]); expect(hasRestartRequiredSettings(modifiedSettings)).toBe(true); }); it('should return false when no modified settings require restart', () => { - const modifiedSettings = new Set([ - 'ui.showMemoryUsage', - 'ui.hideTips', - ]); + const modifiedSettings = new Set(['test']); expect(hasRestartRequiredSettings(modifiedSettings)).toBe(false); }); @@ -586,20 +726,18 @@ describe('SettingsUtils', () => { describe('getRestartRequiredFromModified', () => { it('should return only settings that require restart', () => { const modifiedSettings = new Set([ - 'advanced.autoConfigureMemory', - 'ui.showMemoryUsage', - 'general.checkpointing.enabled', + 'ui.requiresRestart', + 'test', ]); const result = getRestartRequiredFromModified(modifiedSettings); - expect(result).toContain('advanced.autoConfigureMemory'); - expect(result).toContain('general.checkpointing.enabled'); - expect(result).not.toContain('ui.showMemoryUsage'); + expect(result).toContain('ui.requiresRestart'); + expect(result).not.toContain('test'); }); it('should return empty array when no settings require restart', () => { const modifiedSettings = new Set([ - 'showMemoryUsage', + 
'requiresRestart', 'hideTips', ]); const result = getRestartRequiredFromModified(modifiedSettings); @@ -609,13 +747,193 @@ describe('SettingsUtils', () => { }); describe('getDisplayValue', () => { + describe('enum behavior', () => { + enum StringEnum { + FOO = 'foo', + BAR = 'bar', + BAZ = 'baz', + } + + enum NumberEnum { + ONE = 1, + TWO = 2, + THREE = 3, + } + + const SETTING: SettingDefinition = { + type: 'enum', + label: 'Theme', + options: [ + { + value: StringEnum.FOO, + label: 'Foo', + }, + { + value: StringEnum.BAR, + label: 'Bar', + }, + { + value: StringEnum.BAZ, + label: 'Baz', + }, + ], + category: 'UI', + requiresRestart: false, + default: StringEnum.BAR, + description: 'The color theme for the UI.', + showInDialog: false, + }; + + it('handles display of number-based enums', () => { + vi.mocked(getSettingsSchema).mockReturnValue({ + ui: { + properties: { + theme: { + ...SETTING, + options: [ + { + value: NumberEnum.ONE, + label: 'One', + }, + { + value: NumberEnum.TWO, + label: 'Two', + }, + { + value: NumberEnum.THREE, + label: 'Three', + }, + ], + }, + }, + }, + } as unknown as SettingsSchemaType); + + const settings = makeMockSettings({ + ui: { theme: NumberEnum.THREE }, + }); + const mergedSettings = makeMockSettings({ + ui: { theme: NumberEnum.THREE }, + }); + const modifiedSettings = new Set(); + + const result = getDisplayValue( + 'ui.theme', + settings, + mergedSettings, + modifiedSettings, + ); + + expect(result).toBe('Three*'); + }); + + it('handles default values for number-based enums', () => { + vi.mocked(getSettingsSchema).mockReturnValue({ + ui: { + properties: { + theme: { + ...SETTING, + default: NumberEnum.THREE, + options: [ + { + value: NumberEnum.ONE, + label: 'One', + }, + { + value: NumberEnum.TWO, + label: 'Two', + }, + { + value: NumberEnum.THREE, + label: 'Three', + }, + ], + }, + }, + }, + } as unknown as SettingsSchemaType); + const modifiedSettings = new Set(); + + const result = getDisplayValue( + 'ui.theme', + 
makeMockSettings({}), + makeMockSettings({}), + modifiedSettings, + ); + expect(result).toBe('Three'); + }); + + it('shows the enum display value', () => { + vi.mocked(getSettingsSchema).mockReturnValue({ + ui: { properties: { theme: { ...SETTING } } }, + } as unknown as SettingsSchemaType); + const settings = makeMockSettings({ ui: { theme: StringEnum.BAR } }); + const mergedSettings = makeMockSettings({ + ui: { theme: StringEnum.BAR }, + }); + const modifiedSettings = new Set(); + + const result = getDisplayValue( + 'ui.theme', + settings, + mergedSettings, + modifiedSettings, + ); + expect(result).toBe('Bar*'); + }); + + it('passes through unknown values verbatim', () => { + vi.mocked(getSettingsSchema).mockReturnValue({ + ui: { + properties: { + theme: { ...SETTING }, + }, + }, + } as unknown as SettingsSchemaType); + const settings = makeMockSettings({ ui: { theme: 'xyz' } }); + const mergedSettings = makeMockSettings({ ui: { theme: 'xyz' } }); + const modifiedSettings = new Set(); + + const result = getDisplayValue( + 'ui.theme', + settings, + mergedSettings, + modifiedSettings, + ); + expect(result).toBe('xyz*'); + }); + + it('shows the default value for string enums', () => { + vi.mocked(getSettingsSchema).mockReturnValue({ + ui: { + properties: { + theme: { ...SETTING, default: StringEnum.BAR }, + }, + }, + } as unknown as SettingsSchemaType); + const modifiedSettings = new Set(); + + const result = getDisplayValue( + 'ui.theme', + makeMockSettings({}), + makeMockSettings({}), + modifiedSettings, + ); + expect(result).toBe('Bar'); + }); + }); + it('should show value without * when setting matches default', () => { - const settings = { ui: { showMemoryUsage: false } }; // false matches default, so no * - const mergedSettings = { ui: { showMemoryUsage: false } }; + const settings = makeMockSettings({ + ui: { requiresRestart: false }, + }); // false matches default, so no * + const mergedSettings = makeMockSettings({ + ui: { requiresRestart: false }, + }); 
const modifiedSettings = new Set(); const result = getDisplayValue( - 'ui.showMemoryUsage', + 'ui.requiresRestart', settings, mergedSettings, modifiedSettings, @@ -624,12 +942,14 @@ describe('SettingsUtils', () => { }); it('should show default value when setting is not in scope', () => { - const settings = {}; // no setting in scope - const mergedSettings = { ui: { showMemoryUsage: false } }; + const settings = makeMockSettings({}); // no setting in scope + const mergedSettings = makeMockSettings({ + ui: { requiresRestart: false }, + }); const modifiedSettings = new Set(); const result = getDisplayValue( - 'ui.showMemoryUsage', + 'ui.requiresRestart', settings, mergedSettings, modifiedSettings, @@ -638,12 +958,14 @@ describe('SettingsUtils', () => { }); it('should show value with * when changed from default', () => { - const settings = { ui: { showMemoryUsage: true } }; // true is different from default (false) - const mergedSettings = { ui: { showMemoryUsage: true } }; + const settings = makeMockSettings({ ui: { requiresRestart: true } }); // true is different from default (false) + const mergedSettings = makeMockSettings({ + ui: { requiresRestart: true }, + }); const modifiedSettings = new Set(); const result = getDisplayValue( - 'ui.showMemoryUsage', + 'ui.requiresRestart', settings, mergedSettings, modifiedSettings, @@ -652,12 +974,14 @@ describe('SettingsUtils', () => { }); it('should show default value without * when setting does not exist in scope', () => { - const settings = {}; // setting doesn't exist in scope, show default - const mergedSettings = { ui: { showMemoryUsage: false } }; + const settings = makeMockSettings({}); // setting doesn't exist in scope, show default + const mergedSettings = makeMockSettings({ + ui: { requiresRestart: false }, + }); const modifiedSettings = new Set(); const result = getDisplayValue( - 'ui.showMemoryUsage', + 'ui.requiresRestart', settings, mergedSettings, modifiedSettings, @@ -666,13 +990,17 @@ 
describe('SettingsUtils', () => { }); it('should show value with * when user changes from default', () => { - const settings = {}; // setting doesn't exist in scope originally - const mergedSettings = { ui: { showMemoryUsage: false } }; - const modifiedSettings = new Set(['ui.showMemoryUsage']); - const pendingSettings = { ui: { showMemoryUsage: true } }; // user changed to true + const settings = makeMockSettings({}); // setting doesn't exist in scope originally + const mergedSettings = makeMockSettings({ + ui: { requiresRestart: false }, + }); + const modifiedSettings = new Set(['ui.requiresRestart']); + const pendingSettings = makeMockSettings({ + ui: { requiresRestart: true }, + }); // user changed to true const result = getDisplayValue( - 'ui.showMemoryUsage', + 'ui.requiresRestart', settings, mergedSettings, modifiedSettings, @@ -684,21 +1012,21 @@ describe('SettingsUtils', () => { describe('isDefaultValue', () => { it('should return true when setting does not exist in scope', () => { - const settings = {}; // setting doesn't exist + const settings = makeMockSettings({}); // setting doesn't exist - const result = isDefaultValue('ui.showMemoryUsage', settings); + const result = isDefaultValue('ui.requiresRestart', settings); expect(result).toBe(true); }); it('should return false when setting exists in scope', () => { - const settings = { ui: { showMemoryUsage: true } }; // setting exists + const settings = makeMockSettings({ ui: { requiresRestart: true } }); // setting exists - const result = isDefaultValue('ui.showMemoryUsage', settings); + const result = isDefaultValue('ui.requiresRestart', settings); expect(result).toBe(false); }); it('should return true when nested setting does not exist in scope', () => { - const settings = {}; // nested setting doesn't exist + const settings = makeMockSettings({}); // nested setting doesn't exist const result = isDefaultValue( 'ui.accessibility.disableLoadingPhrases', @@ -708,9 +1036,9 @@ describe('SettingsUtils', () => 
{ }); it('should return false when nested setting exists in scope', () => { - const settings = { + const settings = makeMockSettings({ ui: { accessibility: { disableLoadingPhrases: true } }, - }; // nested setting exists + }); // nested setting exists const result = isDefaultValue( 'ui.accessibility.disableLoadingPhrases', @@ -722,11 +1050,13 @@ describe('SettingsUtils', () => { describe('isValueInherited', () => { it('should return false for top-level settings that exist in scope', () => { - const settings = { ui: { showMemoryUsage: true } }; - const mergedSettings = { ui: { showMemoryUsage: true } }; + const settings = makeMockSettings({ ui: { requiresRestart: true } }); + const mergedSettings = makeMockSettings({ + ui: { requiresRestart: true }, + }); const result = isValueInherited( - 'ui.showMemoryUsage', + 'ui.requiresRestart', settings, mergedSettings, ); @@ -734,11 +1064,13 @@ describe('SettingsUtils', () => { }); it('should return true for top-level settings that do not exist in scope', () => { - const settings = {}; - const mergedSettings = { ui: { showMemoryUsage: true } }; + const settings = makeMockSettings({}); + const mergedSettings = makeMockSettings({ + ui: { requiresRestart: true }, + }); const result = isValueInherited( - 'ui.showMemoryUsage', + 'ui.requiresRestart', settings, mergedSettings, ); @@ -746,12 +1078,12 @@ describe('SettingsUtils', () => { }); it('should return false for nested settings that exist in scope', () => { - const settings = { + const settings = makeMockSettings({ ui: { accessibility: { disableLoadingPhrases: true } }, - }; - const mergedSettings = { + }); + const mergedSettings = makeMockSettings({ ui: { accessibility: { disableLoadingPhrases: true } }, - }; + }); const result = isValueInherited( 'ui.accessibility.disableLoadingPhrases', @@ -762,10 +1094,10 @@ describe('SettingsUtils', () => { }); it('should return true for nested settings that do not exist in scope', () => { - const settings = {}; - const mergedSettings = 
{ + const settings = makeMockSettings({}); + const mergedSettings = makeMockSettings({ ui: { accessibility: { disableLoadingPhrases: true } }, - }; + }); const result = isValueInherited( 'ui.accessibility.disableLoadingPhrases', @@ -778,11 +1110,13 @@ describe('SettingsUtils', () => { describe('getEffectiveDisplayValue', () => { it('should return value from settings when available', () => { - const settings = { ui: { showMemoryUsage: true } }; - const mergedSettings = { ui: { showMemoryUsage: false } }; + const settings = makeMockSettings({ ui: { requiresRestart: true } }); + const mergedSettings = makeMockSettings({ + ui: { requiresRestart: false }, + }); const result = getEffectiveDisplayValue( - 'ui.showMemoryUsage', + 'ui.requiresRestart', settings, mergedSettings, ); @@ -790,11 +1124,13 @@ describe('SettingsUtils', () => { }); it('should return value from merged settings when not in scope', () => { - const settings = {}; - const mergedSettings = { ui: { showMemoryUsage: true } }; + const settings = makeMockSettings({}); + const mergedSettings = makeMockSettings({ + ui: { requiresRestart: true }, + }); const result = getEffectiveDisplayValue( - 'ui.showMemoryUsage', + 'ui.requiresRestart', settings, mergedSettings, ); @@ -802,11 +1138,11 @@ describe('SettingsUtils', () => { }); it('should return default value for undefined values', () => { - const settings = {}; - const mergedSettings = {}; + const settings = makeMockSettings({}); + const mergedSettings = makeMockSettings({}); const result = getEffectiveDisplayValue( - 'ui.showMemoryUsage', + 'ui.requiresRestart', settings, mergedSettings, ); diff --git a/packages/cli/src/utils/settingsUtils.ts b/packages/cli/src/utils/settingsUtils.ts index d6a114a674b..a9a429370ac 100644 --- a/packages/cli/src/utils/settingsUtils.ts +++ b/packages/cli/src/utils/settingsUtils.ts @@ -12,18 +12,19 @@ import type { import type { SettingDefinition, SettingsSchema, + SettingsType, + SettingsValue, } from 
'../config/settingsSchema.js'; -import { SETTINGS_SCHEMA } from '../config/settingsSchema.js'; +import { getSettingsSchema } from '../config/settingsSchema.js'; // The schema is now nested, but many parts of the UI and logic work better // with a flattened structure and dot-notation keys. This section flattens the // schema into a map for easier lookups. -function flattenSchema( - schema: SettingsSchema, - prefix = '', -): Record { - let result: Record = {}; +type FlattenedSchema = Record; + +function flattenSchema(schema: SettingsSchema, prefix = ''): FlattenedSchema { + let result: FlattenedSchema = {}; for (const key in schema) { const newKey = prefix ? `${prefix}.${key}` : key; const definition = schema[key]; @@ -35,7 +36,19 @@ function flattenSchema( return result; } -const FLATTENED_SCHEMA = flattenSchema(SETTINGS_SCHEMA); +let _FLATTENED_SCHEMA: FlattenedSchema | undefined; + +/** Returns a flattened schema, the first call is memoized for future requests. */ +export function getFlattenedSchema() { + return ( + _FLATTENED_SCHEMA ?? + (_FLATTENED_SCHEMA = flattenSchema(getSettingsSchema())) + ); +} + +function clearFlattenedSchema() { + _FLATTENED_SCHEMA = undefined; +} /** * Get all settings grouped by category @@ -49,7 +62,7 @@ export function getSettingsByCategory(): Record< Array > = {}; - Object.values(FLATTENED_SCHEMA).forEach((definition) => { + Object.values(getFlattenedSchema()).forEach((definition) => { const category = definition.category; if (!categories[category]) { categories[category] = []; @@ -66,28 +79,28 @@ export function getSettingsByCategory(): Record< export function getSettingDefinition( key: string, ): (SettingDefinition & { key: string }) | undefined { - return FLATTENED_SCHEMA[key]; + return getFlattenedSchema()[key]; } /** * Check if a setting requires restart */ export function requiresRestart(key: string): boolean { - return FLATTENED_SCHEMA[key]?.requiresRestart ?? false; + return getFlattenedSchema()[key]?.requiresRestart ?? 
false; } /** * Get the default value for a setting */ -export function getDefaultValue(key: string): SettingDefinition['default'] { - return FLATTENED_SCHEMA[key]?.default; +export function getDefaultValue(key: string): SettingsValue { + return getFlattenedSchema()[key]?.default; } /** * Get all setting keys that require restart */ export function getRestartRequiredSettings(): string[] { - return Object.values(FLATTENED_SCHEMA) + return Object.values(getFlattenedSchema()) .filter((definition) => definition.requiresRestart) .map((definition) => definition.key); } @@ -121,7 +134,7 @@ export function getEffectiveValue( key: string, settings: Settings, mergedSettings: Settings, -): SettingDefinition['default'] { +): SettingsValue { const definition = getSettingDefinition(key); if (!definition) { return undefined; @@ -132,13 +145,13 @@ export function getEffectiveValue( // Check the current scope's settings first let value = getNestedValue(settings as Record, path); if (value !== undefined) { - return value as SettingDefinition['default']; + return value as SettingsValue; } // Check the merged settings for an inherited value value = getNestedValue(mergedSettings as Record, path); if (value !== undefined) { - return value as SettingDefinition['default']; + return value as SettingsValue; } // Return default value if no value is set anywhere @@ -149,16 +162,16 @@ export function getEffectiveValue( * Get all setting keys from the schema */ export function getAllSettingKeys(): string[] { - return Object.keys(FLATTENED_SCHEMA); + return Object.keys(getFlattenedSchema()); } /** * Get settings by type */ export function getSettingsByType( - type: SettingDefinition['type'], + type: SettingsType, ): Array { - return Object.values(FLATTENED_SCHEMA).filter( + return Object.values(getFlattenedSchema()).filter( (definition) => definition.type === type, ); } @@ -171,7 +184,7 @@ export function getSettingsRequiringRestart(): Array< key: string; } > { - return 
Object.values(FLATTENED_SCHEMA).filter( + return Object.values(getFlattenedSchema()).filter( (definition) => definition.requiresRestart, ); } @@ -180,21 +193,21 @@ export function getSettingsRequiringRestart(): Array< * Validate if a setting key exists in the schema */ export function isValidSettingKey(key: string): boolean { - return key in FLATTENED_SCHEMA; + return key in getFlattenedSchema(); } /** * Get the category for a setting */ export function getSettingCategory(key: string): string | undefined { - return FLATTENED_SCHEMA[key]?.category; + return getFlattenedSchema()[key]?.category; } /** * Check if a setting should be shown in the settings dialog */ export function shouldShowInDialog(key: string): boolean { - return FLATTENED_SCHEMA[key]?.showInDialog ?? true; // Default to true for backward compatibility + return getFlattenedSchema()[key]?.showInDialog ?? true; // Default to true for backward compatibility } /** @@ -209,7 +222,7 @@ export function getDialogSettingsByCategory(): Record< Array > = {}; - Object.values(FLATTENED_SCHEMA) + Object.values(getFlattenedSchema()) .filter((definition) => definition.showInDialog !== false) .forEach((definition) => { const category = definition.category; @@ -226,9 +239,9 @@ export function getDialogSettingsByCategory(): Record< * Get settings by type that should be shown in the dialog */ export function getDialogSettingsByType( - type: SettingDefinition['type'], + type: SettingsType, ): Array { - return Object.values(FLATTENED_SCHEMA).filter( + return Object.values(getFlattenedSchema()).filter( (definition) => definition.type === type && definition.showInDialog !== false, ); @@ -238,7 +251,7 @@ export function getDialogSettingsByType( * Get all setting keys that should be shown in the dialog */ export function getDialogSettingKeys(): string[] { - return Object.values(FLATTENED_SCHEMA) + return Object.values(getFlattenedSchema()) .filter((definition) => definition.showInDialog !== false) .map((definition) => 
definition.key); } @@ -344,7 +357,7 @@ export function setPendingSettingValue( */ export function setPendingSettingValueAny( key: string, - value: unknown, + value: SettingsValue, pendingSettings: Settings, ): Settings { const path = key.split('.'); @@ -415,25 +428,30 @@ export function getDisplayValue( pendingSettings?: Settings, ): string { // Prioritize pending changes if user has modified this setting - let value: boolean; + const definition = getSettingDefinition(key); + + let value: SettingsValue; if (pendingSettings && settingExistsInScope(key, pendingSettings)) { // Show the value from the pending (unsaved) edits when it exists - value = getSettingValue(key, pendingSettings, {}); + value = getEffectiveValue(key, pendingSettings, {}); } else if (settingExistsInScope(key, settings)) { // Show the value defined at the current scope if present - value = getSettingValue(key, settings, {}); + value = getEffectiveValue(key, settings, {}); } else { // Fall back to the schema default when the key is unset in this scope - const defaultValue = getDefaultValue(key); - value = typeof defaultValue === 'boolean' ? defaultValue : false; + value = getDefaultValue(key); } - const valueString = String(value); + let valueString = String(value); + + if (definition?.type === 'enum' && definition.options) { + const option = definition.options?.find((option) => option.value === value); + valueString = option?.label ?? `${value}`; + } // Check if value is different from default OR if it's in modified settings OR if there are pending changes const defaultValue = getDefaultValue(key); - const isChangedFromDefault = - typeof defaultValue === 'boolean' ? 
value !== defaultValue : value === true; + const isChangedFromDefault = value !== defaultValue; const isInModifiedSettings = modifiedSettings.has(key); // Mark as modified if setting exists in current scope OR is in modified settings @@ -476,3 +494,5 @@ export function getEffectiveDisplayValue( ): boolean { return getSettingValue(key, settings, mergedSettings); } + +export const TEST_ONLY = { clearFlattenedSchema }; diff --git a/packages/cli/src/zed-integration/zedIntegration.ts b/packages/cli/src/zed-integration/zedIntegration.ts index 97f76e9fd8c..1b15c31dcf3 100644 --- a/packages/cli/src/zed-integration/zedIntegration.ts +++ b/packages/cli/src/zed-integration/zedIntegration.ts @@ -255,6 +255,7 @@ class Session { try { const responseStream = await chat.sendMessageStream( + this.config.getModel(), { message: nextMessage?.parts ?? [], config: { @@ -852,12 +853,15 @@ function toToolCallContent(toolResult: ToolResult): acp.ToolCallContent | null { content: { type: 'text', text: toolResult.returnDisplay }, }; } else { - return { - type: 'diff', - path: toolResult.returnDisplay.fileName, - oldText: toolResult.returnDisplay.originalContent, - newText: toolResult.returnDisplay.newContent, - }; + if ('fileName' in toolResult.returnDisplay) { + return { + type: 'diff', + path: toolResult.returnDisplay.fileName, + oldText: toolResult.returnDisplay.originalContent, + newText: toolResult.returnDisplay.newContent, + }; + } + return null; } } else { return null; diff --git a/packages/core/index.ts b/packages/core/index.ts index be8ca39dd76..81f09969f48 100644 --- a/packages/core/index.ts +++ b/packages/core/index.ts @@ -12,16 +12,24 @@ export { DEFAULT_GEMINI_FLASH_LITE_MODEL, DEFAULT_GEMINI_EMBEDDING_MODEL, } from './src/config/models.js'; +export { + serializeTerminalToObject, + type AnsiOutput, + type AnsiLine, + type AnsiToken, +} from './src/utils/terminalSerializer.js'; export { DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES, DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, } from 
'./src/config/config.js'; -export { getIdeInfo } from './src/ide/detect-ide.js'; +export { detectIdeFromEnv, getIdeInfo } from './src/ide/detect-ide.js'; export { logIdeConnection } from './src/telemetry/loggers.js'; + export { IdeConnectionEvent, IdeConnectionType, + ExtensionInstallEvent, } from './src/telemetry/types.js'; -export { getIdeTrust } from './src/utils/ide-trust.js'; export { makeFakeConfig } from './src/test-utils/config.js'; export * from './src/utils/pathReader.js'; +export { ClearcutLogger } from './src/telemetry/clearcut-logger/clearcut-logger.js'; diff --git a/packages/core/package.json b/packages/core/package.json index f5a254f986e..e648594038a 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -1,6 +1,6 @@ { "name": "@blocksuser/gemini-cli-core", - "version": "0.3.3", + "version": "0.7.0-nightly.20250912.68035591", "description": "Gemini CLI Core", "repository": { "type": "git", @@ -21,7 +21,7 @@ ], "dependencies": { "@google/genai": "1.16.0", - "@lvce-editor/ripgrep": "^1.6.0", + "@joshua.litt/get-ripgrep": "^0.0.2", "@modelcontextprotocol/sdk": "^1.11.0", "@opentelemetry/api": "^1.9.0", "@opentelemetry/exporter-logs-otlp-grpc": "^0.203.0", diff --git a/packages/core/src/code_assist/codeAssist.ts b/packages/core/src/code_assist/codeAssist.ts index 32f314a51ae..c8ade92edda 100644 --- a/packages/core/src/code_assist/codeAssist.ts +++ b/packages/core/src/code_assist/codeAssist.ts @@ -40,7 +40,7 @@ export async function createCodeAssistContentGenerator( export function getCodeAssistServer( config: Config, ): CodeAssistServer | undefined { - let server = config.getGeminiClient().getContentGenerator(); + let server = config.getContentGenerator(); // Unwrap LoggingContentGenerator if present if (server instanceof LoggingContentGenerator) { diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index 8f91d5fc487..bb4a7f5827a 100644 --- a/packages/core/src/config/config.test.ts +++ 
b/packages/core/src/config/config.test.ts @@ -25,6 +25,11 @@ import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js' import { ShellTool } from '../tools/shell.js'; import { ReadFileTool } from '../tools/read-file.js'; +import { GrepTool } from '../tools/grep.js'; +import { RipGrepTool, canUseRipgrep } from '../tools/ripGrep.js'; +import { logRipgrepFallback } from '../telemetry/loggers.js'; +import { RipgrepFallbackEvent } from '../telemetry/types.js'; +import { ToolRegistry } from '../tools/tool-registry.js'; vi.mock('fs', async (importOriginal) => { const actual = await importOriginal(); @@ -56,7 +61,11 @@ vi.mock('../utils/memoryDiscovery.js', () => ({ // Mock individual tools if their constructors are complex or have side effects vi.mock('../tools/ls'); vi.mock('../tools/read-file'); -vi.mock('../tools/grep'); +vi.mock('../tools/grep.js'); +vi.mock('../tools/ripGrep.js', () => ({ + canUseRipgrep: vi.fn(), + RipGrepTool: class MockRipGrepTool {}, +})); vi.mock('../tools/glob'); vi.mock('../tools/edit'); vi.mock('../tools/shell'); @@ -71,18 +80,12 @@ vi.mock('../tools/memoryTool', () => ({ GEMINI_CONFIG_DIR: '.gemini', })); -vi.mock('../core/contentGenerator.js', async (importOriginal) => { - const actual = - await importOriginal(); - return { - ...actual, - createContentGeneratorConfig: vi.fn(), - }; -}); +vi.mock('../core/contentGenerator.js'); vi.mock('../core/client.js', () => ({ GeminiClient: vi.fn().mockImplementation(() => ({ initialize: vi.fn().mockResolvedValue(undefined), + stripThoughtsFromHistory: vi.fn(), })), })); @@ -94,6 +97,15 @@ vi.mock('../telemetry/index.js', async (importOriginal) => { }; }); +vi.mock('../telemetry/loggers.js', async (importOriginal) => { + const actual = + await importOriginal(); + return { + ...actual, + logRipgrepFallback: vi.fn(), + }; +}); + vi.mock('../services/gitService.js', () => { const GitServiceMock = vi.fn(); GitServiceMock.prototype.initialize = vi.fn(); @@ -110,6 +122,10 @@ 
vi.mock('../ide/ide-client.js', () => ({ }, })); +import { BaseLlmClient } from '../core/baseLlmClient.js'; + +vi.mock('../core/baseLlmClient.js'); + describe('Server Config (config.ts)', () => { const MODEL = 'gemini-pro'; const SANDBOX: SandboxConfig = { @@ -190,13 +206,13 @@ describe('Server Config (config.ts)', () => { it('should refresh auth and update config', async () => { const config = new Config(baseParams); const authType = AuthType.USE_GEMINI; - const newModel = 'gemini-flash'; const mockContentConfig = { - model: newModel, apiKey: 'test-key', }; - (createContentGeneratorConfig as Mock).mockReturnValue(mockContentConfig); + vi.mocked(createContentGeneratorConfig).mockReturnValue( + mockContentConfig, + ); // Set fallback mode to true to ensure it gets reset config.setFallbackMode(true); @@ -208,181 +224,45 @@ describe('Server Config (config.ts)', () => { config, authType, ); - // Verify that contentGeneratorConfig is updated with the new model + // Verify that contentGeneratorConfig is updated expect(config.getContentGeneratorConfig()).toEqual(mockContentConfig); - expect(config.getContentGeneratorConfig().model).toBe(newModel); - expect(config.getModel()).toBe(newModel); // getModel() should return the updated model expect(GeminiClient).toHaveBeenCalledWith(config); // Verify that fallback mode is reset expect(config.isInFallbackMode()).toBe(false); }); - it('should preserve conversation history when refreshing auth', async () => { - const config = new Config(baseParams); - const authType = AuthType.USE_GEMINI; - const mockContentConfig = { - model: 'gemini-pro', - apiKey: 'test-key', - }; - - (createContentGeneratorConfig as Mock).mockReturnValue(mockContentConfig); - - // Mock the existing client with some history - const mockExistingHistory = [ - { role: 'user', parts: [{ text: 'Hello' }] }, - { role: 'model', parts: [{ text: 'Hi there!' }] }, - { role: 'user', parts: [{ text: 'How are you?' 
}] }, - ]; - - const mockExistingClient = { - isInitialized: vi.fn().mockReturnValue(true), - getHistory: vi.fn().mockReturnValue(mockExistingHistory), - }; - - const mockNewClient = { - isInitialized: vi.fn().mockReturnValue(true), - getHistory: vi.fn().mockReturnValue([]), - setHistory: vi.fn(), - initialize: vi.fn().mockResolvedValue(undefined), - }; - - // Set the existing client - ( - config as unknown as { geminiClient: typeof mockExistingClient } - ).geminiClient = mockExistingClient; - (GeminiClient as Mock).mockImplementation(() => mockNewClient); - - await config.refreshAuth(authType); - - // Verify that existing history was retrieved - expect(mockExistingClient.getHistory).toHaveBeenCalled(); - - // Verify that new client was created and initialized - expect(GeminiClient).toHaveBeenCalledWith(config); - expect(mockNewClient.initialize).toHaveBeenCalledWith(mockContentConfig); - - // Verify that history was restored to the new client - expect(mockNewClient.setHistory).toHaveBeenCalledWith( - mockExistingHistory, - { stripThoughts: false }, - ); - }); - - it('should handle case when no existing client is initialized', async () => { - const config = new Config(baseParams); - const authType = AuthType.USE_GEMINI; - const mockContentConfig = { - model: 'gemini-pro', - apiKey: 'test-key', - }; - - (createContentGeneratorConfig as Mock).mockReturnValue(mockContentConfig); - - const mockNewClient = { - isInitialized: vi.fn().mockReturnValue(true), - getHistory: vi.fn().mockReturnValue([]), - setHistory: vi.fn(), - initialize: vi.fn().mockResolvedValue(undefined), - }; - - // No existing client - (config as unknown as { geminiClient: null }).geminiClient = null; - (GeminiClient as Mock).mockImplementation(() => mockNewClient); - - await config.refreshAuth(authType); - - // Verify that new client was created and initialized - expect(GeminiClient).toHaveBeenCalledWith(config); - expect(mockNewClient.initialize).toHaveBeenCalledWith(mockContentConfig); - - // Verify 
that setHistory was not called since there was no existing history - expect(mockNewClient.setHistory).not.toHaveBeenCalled(); - }); - it('should strip thoughts when switching from GenAI to Vertex', async () => { const config = new Config(baseParams); - const mockContentConfig = { - model: 'gemini-pro', - apiKey: 'test-key', - authType: AuthType.USE_GEMINI, - }; - ( - config as unknown as { contentGeneratorConfig: ContentGeneratorConfig } - ).contentGeneratorConfig = mockContentConfig; - (createContentGeneratorConfig as Mock).mockReturnValue({ - ...mockContentConfig, - authType: AuthType.LOGIN_WITH_GOOGLE, - }); - - const mockExistingHistory = [ - { role: 'user', parts: [{ text: 'Hello' }] }, - ]; - const mockExistingClient = { - isInitialized: vi.fn().mockReturnValue(true), - getHistory: vi.fn().mockReturnValue(mockExistingHistory), - }; - const mockNewClient = { - isInitialized: vi.fn().mockReturnValue(true), - getHistory: vi.fn().mockReturnValue([]), - setHistory: vi.fn(), - initialize: vi.fn().mockResolvedValue(undefined), - }; + vi.mocked(createContentGeneratorConfig).mockImplementation( + (_: Config, authType: AuthType | undefined) => + ({ authType }) as unknown as ContentGeneratorConfig, + ); - ( - config as unknown as { geminiClient: typeof mockExistingClient } - ).geminiClient = mockExistingClient; - (GeminiClient as Mock).mockImplementation(() => mockNewClient); + await config.refreshAuth(AuthType.USE_GEMINI); await config.refreshAuth(AuthType.LOGIN_WITH_GOOGLE); - expect(mockNewClient.setHistory).toHaveBeenCalledWith( - mockExistingHistory, - { stripThoughts: true }, - ); + expect( + config.getGeminiClient().stripThoughtsFromHistory, + ).toHaveBeenCalledWith(); }); it('should not strip thoughts when switching from Vertex to GenAI', async () => { const config = new Config(baseParams); - const mockContentConfig = { - model: 'gemini-pro', - apiKey: 'test-key', - authType: AuthType.LOGIN_WITH_GOOGLE, - }; - ( - config as unknown as { contentGeneratorConfig: 
ContentGeneratorConfig } - ).contentGeneratorConfig = mockContentConfig; - - (createContentGeneratorConfig as Mock).mockReturnValue({ - ...mockContentConfig, - authType: AuthType.USE_GEMINI, - }); - const mockExistingHistory = [ - { role: 'user', parts: [{ text: 'Hello' }] }, - ]; - const mockExistingClient = { - isInitialized: vi.fn().mockReturnValue(true), - getHistory: vi.fn().mockReturnValue(mockExistingHistory), - }; - const mockNewClient = { - isInitialized: vi.fn().mockReturnValue(true), - getHistory: vi.fn().mockReturnValue([]), - setHistory: vi.fn(), - initialize: vi.fn().mockResolvedValue(undefined), - }; + vi.mocked(createContentGeneratorConfig).mockImplementation( + (_: Config, authType: AuthType | undefined) => + ({ authType }) as unknown as ContentGeneratorConfig, + ); - ( - config as unknown as { geminiClient: typeof mockExistingClient } - ).geminiClient = mockExistingClient; - (GeminiClient as Mock).mockImplementation(() => mockNewClient); + await config.refreshAuth(AuthType.USE_VERTEX_AI); await config.refreshAuth(AuthType.USE_GEMINI); - expect(mockNewClient.setHistory).toHaveBeenCalledWith( - mockExistingHistory, - { stripThoughts: false }, - ); + expect( + config.getGeminiClient().stripThoughtsFromHistory, + ).not.toHaveBeenCalledWith(); }); }); @@ -610,36 +490,36 @@ describe('Server Config (config.ts)', () => { }); describe('UseRipgrep Configuration', () => { - it('should default useRipgrep to false when not provided', () => { + it('should default useRipgrep to true when not provided', () => { const config = new Config(baseParams); - expect(config.getUseRipgrep()).toBe(false); + expect(config.getUseRipgrep()).toBe(true); }); - it('should set useRipgrep to true when provided as true', () => { + it('should set useRipgrep to false when provided as false', () => { const paramsWithRipgrep: ConfigParameters = { ...baseParams, - useRipgrep: true, + useRipgrep: false, }; const config = new Config(paramsWithRipgrep); - 
expect(config.getUseRipgrep()).toBe(true); + expect(config.getUseRipgrep()).toBe(false); }); - it('should set useRipgrep to false when explicitly provided as false', () => { + it('should set useRipgrep to true when explicitly provided as true', () => { const paramsWithRipgrep: ConfigParameters = { ...baseParams, - useRipgrep: false, + useRipgrep: true, }; const config = new Config(paramsWithRipgrep); - expect(config.getUseRipgrep()).toBe(false); + expect(config.getUseRipgrep()).toBe(true); }); - it('should default useRipgrep to false when undefined', () => { + it('should default useRipgrep to true when undefined', () => { const paramsWithUndefinedRipgrep: ConfigParameters = { ...baseParams, useRipgrep: undefined, }; const config = new Config(paramsWithUndefinedRipgrep); - expect(config.getUseRipgrep()).toBe(false); + expect(config.getUseRipgrep()).toBe(true); }); }); @@ -804,4 +684,148 @@ describe('setApprovalMode with folder trust', () => { expect(() => config.setApprovalMode(ApprovalMode.AUTO_EDIT)).not.toThrow(); expect(() => config.setApprovalMode(ApprovalMode.DEFAULT)).not.toThrow(); }); + + describe('registerCoreTools', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('should register RipGrepTool when useRipgrep is true and it is available', async () => { + (canUseRipgrep as Mock).mockResolvedValue(true); + const config = new Config({ ...baseParams, useRipgrep: true }); + await config.initialize(); + + const calls = (ToolRegistry.prototype.registerTool as Mock).mock.calls; + const wasRipGrepRegistered = calls.some( + (call) => call[0] instanceof vi.mocked(RipGrepTool), + ); + const wasGrepRegistered = calls.some( + (call) => call[0] instanceof vi.mocked(GrepTool), + ); + + expect(wasRipGrepRegistered).toBe(true); + expect(wasGrepRegistered).toBe(false); + expect(logRipgrepFallback).not.toHaveBeenCalled(); + }); + + it('should register GrepTool as a fallback when useRipgrep is true but it is not available', async () => { + (canUseRipgrep as 
Mock).mockResolvedValue(false); + const config = new Config({ ...baseParams, useRipgrep: true }); + await config.initialize(); + + const calls = (ToolRegistry.prototype.registerTool as Mock).mock.calls; + const wasRipGrepRegistered = calls.some( + (call) => call[0] instanceof vi.mocked(RipGrepTool), + ); + const wasGrepRegistered = calls.some( + (call) => call[0] instanceof vi.mocked(GrepTool), + ); + + expect(wasRipGrepRegistered).toBe(false); + expect(wasGrepRegistered).toBe(true); + expect(logRipgrepFallback).toHaveBeenCalledWith( + config, + expect.any(RipgrepFallbackEvent), + ); + const event = (logRipgrepFallback as Mock).mock.calls[0][1]; + expect(event.error).toBeUndefined(); + }); + + it('should register GrepTool as a fallback when canUseRipgrep throws an error', async () => { + const error = new Error('ripGrep check failed'); + (canUseRipgrep as Mock).mockRejectedValue(error); + const config = new Config({ ...baseParams, useRipgrep: true }); + await config.initialize(); + + const calls = (ToolRegistry.prototype.registerTool as Mock).mock.calls; + const wasRipGrepRegistered = calls.some( + (call) => call[0] instanceof vi.mocked(RipGrepTool), + ); + const wasGrepRegistered = calls.some( + (call) => call[0] instanceof vi.mocked(GrepTool), + ); + + expect(wasRipGrepRegistered).toBe(false); + expect(wasGrepRegistered).toBe(true); + expect(logRipgrepFallback).toHaveBeenCalledWith( + config, + expect.any(RipgrepFallbackEvent), + ); + const event = (logRipgrepFallback as Mock).mock.calls[0][1]; + expect(event.error).toBe(String(error)); + }); + + it('should register GrepTool when useRipgrep is false', async () => { + const config = new Config({ ...baseParams, useRipgrep: false }); + await config.initialize(); + + const calls = (ToolRegistry.prototype.registerTool as Mock).mock.calls; + const wasRipGrepRegistered = calls.some( + (call) => call[0] instanceof vi.mocked(RipGrepTool), + ); + const wasGrepRegistered = calls.some( + (call) => call[0] instanceof 
vi.mocked(GrepTool), + ); + + expect(wasRipGrepRegistered).toBe(false); + expect(wasGrepRegistered).toBe(true); + expect(canUseRipgrep).not.toHaveBeenCalled(); + expect(logRipgrepFallback).not.toHaveBeenCalled(); + }); + }); +}); + +describe('BaseLlmClient Lifecycle', () => { + const MODEL = 'gemini-pro'; + const SANDBOX: SandboxConfig = { + command: 'docker', + image: 'gemini-cli-sandbox', + }; + const TARGET_DIR = '/path/to/target'; + const DEBUG_MODE = false; + const QUESTION = 'test question'; + const FULL_CONTEXT = false; + const USER_MEMORY = 'Test User Memory'; + const TELEMETRY_SETTINGS = { enabled: false }; + const EMBEDDING_MODEL = 'gemini-embedding'; + const SESSION_ID = 'test-session-id'; + const baseParams: ConfigParameters = { + cwd: '/tmp', + embeddingModel: EMBEDDING_MODEL, + sandbox: SANDBOX, + targetDir: TARGET_DIR, + debugMode: DEBUG_MODE, + question: QUESTION, + fullContext: FULL_CONTEXT, + userMemory: USER_MEMORY, + telemetry: TELEMETRY_SETTINGS, + sessionId: SESSION_ID, + model: MODEL, + usageStatisticsEnabled: false, + }; + + it('should throw an error if getBaseLlmClient is called before refreshAuth', () => { + const config = new Config(baseParams); + expect(() => config.getBaseLlmClient()).toThrow( + 'BaseLlmClient not initialized. 
Ensure authentication has occurred and ContentGenerator is ready.', + ); + }); + + it('should successfully initialize BaseLlmClient after refreshAuth is called', async () => { + const config = new Config(baseParams); + const authType = AuthType.USE_GEMINI; + const mockContentConfig = { model: 'gemini-flash', apiKey: 'test-key' }; + + vi.mocked(createContentGeneratorConfig).mockReturnValue(mockContentConfig); + + await config.refreshAuth(authType); + + // Should not throw + const llmService = config.getBaseLlmClient(); + expect(llmService).toBeDefined(); + expect(BaseLlmClient).toHaveBeenCalledWith( + config.getContentGenerator(), + config, + ); + }); }); diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index f7f3e6a1773..af2264a89ab 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -6,9 +6,13 @@ import * as path from 'node:path'; import process from 'node:process'; -import type { ContentGeneratorConfig } from '../core/contentGenerator.js'; +import type { + ContentGenerator, + ContentGeneratorConfig, +} from '../core/contentGenerator.js'; import { AuthType, + createContentGenerator, createContentGeneratorConfig, } from '../core/contentGenerator.js'; import { PromptRegistry } from '../prompts/prompt-registry.js'; @@ -16,7 +20,7 @@ import { ToolRegistry } from '../tools/tool-registry.js'; import { LSTool } from '../tools/ls.js'; import { ReadFileTool } from '../tools/read-file.js'; import { GrepTool } from '../tools/grep.js'; -import { RipGrepTool } from '../tools/ripGrep.js'; +import { canUseRipgrep, RipGrepTool } from '../tools/ripGrep.js'; import { GlobTool } from '../tools/glob.js'; import { EditTool } from '../tools/edit.js'; import { SmartEditTool } from '../tools/smart-edit.js'; @@ -27,6 +31,7 @@ import { ReadManyFilesTool } from '../tools/read-many-files.js'; import { MemoryTool, setGeminiMdFilename } from '../tools/memoryTool.js'; import { WebSearchTool } from 
'../tools/web-search.js'; import { GeminiClient } from '../core/client.js'; +import { BaseLlmClient } from '../core/baseLlmClient.js'; import { FileDiscoveryService } from '../services/fileDiscoveryService.js'; import { GitService } from '../services/gitService.js'; import type { TelemetryTarget } from '../telemetry/index.js'; @@ -39,24 +44,35 @@ import { StartSessionEvent } from '../telemetry/index.js'; import { DEFAULT_GEMINI_EMBEDDING_MODEL, DEFAULT_GEMINI_FLASH_MODEL, + DEFAULT_GEMINI_MODEL, } from './models.js'; import { shouldAttemptBrowserLaunch } from '../utils/browser.js'; import type { MCPOAuthConfig } from '../mcp/oauth-provider.js'; -import { IdeClient } from '../ide/ide-client.js'; -import { ideContext } from '../ide/ideContext.js'; -import type { Content } from '@google/genai'; +import { ideContextStore } from '../ide/ideContext.js'; import type { FileSystemService } from '../services/fileSystemService.js'; import { StandardFileSystemService } from '../services/fileSystemService.js'; -import { logCliConfiguration, logIdeConnection } from '../telemetry/loggers.js'; -import { IdeConnectionEvent, IdeConnectionType } from '../telemetry/types.js'; +import { + logCliConfiguration, + logRipgrepFallback, +} from '../telemetry/loggers.js'; +import { RipgrepFallbackEvent } from '../telemetry/types.js'; +import type { FallbackModelHandler } from '../fallback/types.js'; +import { ModelRouterService } from '../routing/modelRouterService.js'; +import { OutputFormat } from '../output/types.js'; // Re-export OAuth config type export type { MCPOAuthConfig, AnyToolInvocation }; import type { AnyToolInvocation } from '../tools/tools.js'; import { WorkspaceContext } from '../utils/workspaceContext.js'; import { Storage } from './storage.js'; +import type { ShellExecutionConfig } from '../services/shellExecutionService.js'; import { FileExclusions } from '../utils/ignorePatterns.js'; import type { EventEmitter } from 'node:events'; +import { MessageBus } from 
'../confirmation-bus/message-bus.js'; +import { PolicyEngine } from '../policy/policy-engine.js'; +import type { PolicyEngineConfig } from '../policy/types.js'; +import type { UserTierId } from '../code_assist/types.js'; +import { ProxyAgent, setGlobalDispatcher } from 'undici'; export enum ApprovalMode { DEFAULT = 'default', @@ -90,6 +106,10 @@ export interface TelemetrySettings { outfile?: string; } +export interface OutputSettings { + format?: OutputFormat; +} + export interface GeminiCLIExtension { name: string; version: string; @@ -152,12 +172,6 @@ export interface SandboxConfig { image: string; } -export type FlashFallbackHandler = ( - currentModel: string, - fallbackModel: string, - error?: unknown, -) => Promise; - export interface ConfigParameters { sessionId: string; embeddingModel?: string; @@ -212,12 +226,16 @@ export interface ConfigParameters { useRipgrep?: boolean; shouldUseNodePtyShell?: boolean; skipNextSpeakerCheck?: boolean; + shellExecutionConfig?: ShellExecutionConfig; extensionManagement?: boolean; enablePromptCompletion?: boolean; truncateToolOutputThreshold?: number; truncateToolOutputLines?: number; + enableToolOutputTruncation?: boolean; eventEmitter?: EventEmitter; useSmartEdit?: boolean; + policyEngineConfig?: PolicyEngineConfig; + output?: OutputSettings; } export class Config { @@ -226,6 +244,7 @@ export class Config { private readonly sessionId: string; private fileSystemService: FileSystemService; private contentGeneratorConfig!: ContentGeneratorConfig; + private contentGenerator!: ContentGenerator; private readonly embeddingModel: string; private readonly sandbox: SandboxConfig | undefined; private readonly targetDir: string; @@ -248,6 +267,8 @@ export class Config { private readonly telemetrySettings: TelemetrySettings; private readonly usageStatisticsEnabled: boolean; private geminiClient!: GeminiClient; + private baseLlmClient!: BaseLlmClient; + private modelRouterService: ModelRouterService; private readonly fileFiltering: { 
respectGitIgnore: boolean; respectGeminiIgnore: boolean; @@ -260,7 +281,7 @@ export class Config { private readonly proxy: string | undefined; private readonly cwd: string; private readonly bugCommand: BugCommandSettings | undefined; - private readonly model: string; + private model: string; private readonly extensionContextFilePaths: string[]; private readonly noBrowser: boolean; private readonly folderTrustFeature: boolean; @@ -275,7 +296,7 @@ export class Config { name: string; extensionName: string; }>; - flashFallbackHandler?: FlashFallbackHandler; + fallbackModelHandler?: FallbackModelHandler; private quotaErrorOccurred: boolean = false; private readonly summarizeToolOutput: | Record @@ -288,15 +309,20 @@ export class Config { private readonly useRipgrep: boolean; private readonly shouldUseNodePtyShell: boolean; private readonly skipNextSpeakerCheck: boolean; + private shellExecutionConfig: ShellExecutionConfig; private readonly extensionManagement: boolean = true; private readonly enablePromptCompletion: boolean = false; private readonly truncateToolOutputThreshold: number; private readonly truncateToolOutputLines: number; + private readonly enableToolOutputTruncation: boolean; private initialized: boolean = false; readonly storage: Storage; private readonly fileExclusions: FileExclusions; private readonly eventEmitter?: EventEmitter; private readonly useSmartEdit: boolean; + private readonly messageBus: MessageBus; + private readonly policyEngine: PolicyEngine; + private readonly outputSettings: OutputSettings; constructor(params: ConfigParameters) { this.sessionId = params.sessionId; @@ -346,7 +372,7 @@ export class Config { this.cwd = params.cwd ?? process.cwd(); this.fileDiscoveryService = params.fileDiscoveryService ?? null; this.bugCommand = params.bugCommand; - this.model = params.model; + this.model = params.model || DEFAULT_GEMINI_MODEL; this.extensionContextFilePaths = params.extensionContextFilePaths ?? 
[]; this.maxSessionTurns = params.maxSessionTurns ?? -1; this.experimentalZedIntegration = @@ -364,20 +390,33 @@ export class Config { this.chatCompression = params.chatCompression; this.interactive = params.interactive ?? false; this.trustedFolder = params.trustedFolder; - this.useRipgrep = params.useRipgrep ?? false; + this.useRipgrep = params.useRipgrep ?? true; this.shouldUseNodePtyShell = params.shouldUseNodePtyShell ?? false; - this.skipNextSpeakerCheck = params.skipNextSpeakerCheck ?? false; + this.skipNextSpeakerCheck = params.skipNextSpeakerCheck ?? true; + this.shellExecutionConfig = { + terminalWidth: params.shellExecutionConfig?.terminalWidth ?? 80, + terminalHeight: params.shellExecutionConfig?.terminalHeight ?? 24, + showColor: params.shellExecutionConfig?.showColor ?? false, + pager: params.shellExecutionConfig?.pager ?? 'cat', + }; this.truncateToolOutputThreshold = params.truncateToolOutputThreshold ?? DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD; this.truncateToolOutputLines = params.truncateToolOutputLines ?? DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES; + this.enableToolOutputTruncation = + params.enableToolOutputTruncation ?? false; this.useSmartEdit = params.useSmartEdit ?? true; this.extensionManagement = params.extensionManagement ?? true; this.storage = new Storage(this.targetDir); this.enablePromptCompletion = params.enablePromptCompletion ?? false; this.fileExclusions = new FileExclusions(this); this.eventEmitter = params.eventEmitter; + this.policyEngine = new PolicyEngine(params.policyEngineConfig); + this.messageBus = new MessageBus(this.policyEngine); + this.outputSettings = { + format: params.output?.format ?? 
OutputFormat.TEXT, + }; if (params.contextFileName) { setGeminiMdFilename(params.contextFileName); @@ -386,6 +425,12 @@ export class Config { if (this.telemetrySettings.enabled) { initializeTelemetry(this); } + + if (this.getProxy()) { + setGlobalDispatcher(new ProxyAgent(this.getProxy() as string)); + } + this.geminiClient = new GeminiClient(this); + this.modelRouterService = new ModelRouterService(this); } /** @@ -397,11 +442,6 @@ export class Config { } this.initialized = true; - if (this.getIdeMode()) { - await (await IdeClient.getInstance()).connect(); - logIdeConnection(this, new IdeConnectionEvent(IdeConnectionType.START)); - } - // Initialize centralized FileDiscoveryService this.getFileService(); if (this.getCheckpointingEnabled()) { @@ -410,46 +450,68 @@ export class Config { this.promptRegistry = new PromptRegistry(); this.toolRegistry = await this.createToolRegistry(); logCliConfiguration(this, new StartSessionEvent(this, this.toolRegistry)); + + await this.geminiClient.initialize(); + } + + getContentGenerator(): ContentGenerator { + return this.contentGenerator; } async refreshAuth(authMethod: AuthType) { - // Save the current conversation history before creating a new client - let existingHistory: Content[] = []; - if (this.geminiClient && this.geminiClient.isInitialized()) { - existingHistory = this.geminiClient.getHistory(); + // Vertex and Genai have incompatible encryption and sending history with + // thoughtSignature from Genai to Vertex will fail, we need to strip them + if ( + this.contentGeneratorConfig?.authType === AuthType.USE_GEMINI && + authMethod === AuthType.LOGIN_WITH_GOOGLE + ) { + // Strip thoughts from the existing history before the new content generator is created + this.geminiClient.stripThoughtsFromHistory(); } - // Create new content generator config const newContentGeneratorConfig = createContentGeneratorConfig( this, authMethod, ); - - // Create and initialize new client in local variable first - const newGeminiClient = new GeminiClient(this); - await
newGeminiClient.initialize(newContentGeneratorConfig); - - // Vertex and Genai have incompatible encryption and sending history with - // throughtSignature from Genai to Vertex will fail, we need to strip them - const fromGenaiToVertex = - this.contentGeneratorConfig?.authType === AuthType.USE_GEMINI && - authMethod === AuthType.LOGIN_WITH_GOOGLE; - + this.contentGenerator = await createContentGenerator( + newContentGeneratorConfig, + this, + this.getSessionId(), + ); // Only assign to instance properties after successful initialization this.contentGeneratorConfig = newContentGeneratorConfig; - this.geminiClient = newGeminiClient; - // Restore the conversation history to the new client - if (existingHistory.length > 0) { - this.geminiClient.setHistory(existingHistory, { - stripThoughts: fromGenaiToVertex, - }); - } + // Initialize BaseLlmClient now that the ContentGenerator is available + this.baseLlmClient = new BaseLlmClient(this.contentGenerator, this); // Reset the session flag since we're explicitly changing auth and using default model this.inFallbackMode = false; } + getUserTier(): UserTierId | undefined { + return this.contentGenerator?.userTier; + } + + /** + * Provides access to the BaseLlmClient for stateless LLM operations. + */ + getBaseLlmClient(): BaseLlmClient { + if (!this.baseLlmClient) { + // Handle cases where initialization might be deferred or authentication failed + if (this.contentGenerator) { + this.baseLlmClient = new BaseLlmClient( + this.getContentGenerator(), + this, + ); + } else { + throw new Error( + 'BaseLlmClient not initialized. 
Ensure authentication has occurred and ContentGenerator is ready.', + ); + } + } + return this.baseLlmClient; + } + getSessionId(): string { return this.sessionId; } @@ -463,13 +525,16 @@ export class Config { } getModel(): string { - return this.contentGeneratorConfig?.model || this.model; + return this.model; } setModel(newModel: string): void { - if (this.contentGeneratorConfig) { - this.contentGeneratorConfig.model = newModel; + // Do not allow Pro usage if the user is in fallback mode. + if (newModel.includes('pro') && this.isInFallbackMode()) { + return; } + + this.model = newModel; } isInFallbackMode(): boolean { @@ -480,8 +545,8 @@ export class Config { this.inFallbackMode = active; } - setFlashFallbackHandler(handler: FlashFallbackHandler): void { - this.flashFallbackHandler = handler; + setFallbackModelHandler(handler: FallbackModelHandler): void { + this.fallbackModelHandler = handler; } getMaxSessionTurns(): number { @@ -639,6 +704,10 @@ export class Config { return this.geminiClient; } + getModelRouterService(): ModelRouterService { + return this.modelRouterService; + } + getEnableRecursiveFileSearch(): boolean { return this.fileFiltering.enableRecursiveFileSearch; } @@ -768,7 +837,7 @@ export class Config { // restarts in the more common path. If the user chooses to mark the folder // as untrusted, the CLI will restart and we will have the trust value // reloaded. - const context = ideContext.getIdeContext(); + const context = ideContextStore.get(); if (context?.workspaceState?.isTrusted !== undefined) { return context.workspaceState.isTrusted; } @@ -814,6 +883,20 @@ export class Config { return this.skipNextSpeakerCheck; } + getShellExecutionConfig(): ShellExecutionConfig { + return this.shellExecutionConfig; + } + + setShellExecutionConfig(config: ShellExecutionConfig): void { + this.shellExecutionConfig = { + terminalWidth: + config.terminalWidth ?? this.shellExecutionConfig.terminalWidth, + terminalHeight: + config.terminalHeight ?? 
this.shellExecutionConfig.terminalHeight, + showColor: config.showColor ?? this.shellExecutionConfig.showColor, + pager: config.pager ?? this.shellExecutionConfig.pager, + }; + } getScreenReader(): boolean { return this.accessibility.screenReader ?? false; } @@ -822,6 +905,10 @@ export class Config { return this.enablePromptCompletion; } + getEnableToolOutputTruncation(): boolean { + return this.enableToolOutputTruncation; + } + getTruncateToolOutputThreshold(): number { return this.truncateToolOutputThreshold; } @@ -834,6 +921,12 @@ export class Config { return this.useSmartEdit; } + getOutputFormat(): OutputFormat { + return this.outputSettings?.format + ? this.outputSettings.format + : OutputFormat.TEXT; + } + async getGitService(): Promise { if (!this.gitService) { this.gitService = new GitService(this.targetDir, this.storage); @@ -846,6 +939,14 @@ export class Config { return this.fileExclusions; } + getMessageBus(): MessageBus { + return this.messageBus; + } + + getPolicyEngine(): PolicyEngine { + return this.policyEngine; + } + async createToolRegistry(): Promise { const registry = new ToolRegistry(this, this.eventEmitter); @@ -887,7 +988,19 @@ export class Config { registerCoreTool(ReadFileTool, this); if (this.getUseRipgrep()) { - registerCoreTool(RipGrepTool, this); + let useRipgrep = false; + let errorString: undefined | string = undefined; + try { + useRipgrep = await canUseRipgrep(); + } catch (error: unknown) { + errorString = String(error); + } + if (useRipgrep) { + registerCoreTool(RipGrepTool, this); + } else { + logRipgrepFallback(this, new RipgrepFallbackEvent(errorString)); + registerCoreTool(GrepTool, this); + } } else { registerCoreTool(GrepTool, this); } diff --git a/packages/core/src/config/models.test.ts b/packages/core/src/config/models.test.ts new file mode 100644 index 00000000000..8c790dd1aec --- /dev/null +++ b/packages/core/src/config/models.test.ts @@ -0,0 +1,83 @@ +/** + * @license + * Copyright 2025 Google LLC + * 
SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect } from 'vitest'; +import { + getEffectiveModel, + DEFAULT_GEMINI_MODEL, + DEFAULT_GEMINI_FLASH_MODEL, + DEFAULT_GEMINI_FLASH_LITE_MODEL, +} from './models.js'; + +describe('getEffectiveModel', () => { + describe('When NOT in fallback mode', () => { + const isInFallbackMode = false; + + it('should return the Pro model when Pro is requested', () => { + const model = getEffectiveModel(isInFallbackMode, DEFAULT_GEMINI_MODEL); + expect(model).toBe(DEFAULT_GEMINI_MODEL); + }); + + it('should return the Flash model when Flash is requested', () => { + const model = getEffectiveModel( + isInFallbackMode, + DEFAULT_GEMINI_FLASH_MODEL, + ); + expect(model).toBe(DEFAULT_GEMINI_FLASH_MODEL); + }); + + it('should return the Lite model when Lite is requested', () => { + const model = getEffectiveModel( + isInFallbackMode, + DEFAULT_GEMINI_FLASH_LITE_MODEL, + ); + expect(model).toBe(DEFAULT_GEMINI_FLASH_LITE_MODEL); + }); + + it('should return a custom model name when requested', () => { + const customModel = 'custom-model-v1'; + const model = getEffectiveModel(isInFallbackMode, customModel); + expect(model).toBe(customModel); + }); + }); + + describe('When IN fallback mode', () => { + const isInFallbackMode = true; + + it('should downgrade the Pro model to the Flash model', () => { + const model = getEffectiveModel(isInFallbackMode, DEFAULT_GEMINI_MODEL); + expect(model).toBe(DEFAULT_GEMINI_FLASH_MODEL); + }); + + it('should return the Flash model when Flash is requested', () => { + const model = getEffectiveModel( + isInFallbackMode, + DEFAULT_GEMINI_FLASH_MODEL, + ); + expect(model).toBe(DEFAULT_GEMINI_FLASH_MODEL); + }); + + it('should HONOR the Lite model when Lite is requested', () => { + const model = getEffectiveModel( + isInFallbackMode, + DEFAULT_GEMINI_FLASH_LITE_MODEL, + ); + expect(model).toBe(DEFAULT_GEMINI_FLASH_LITE_MODEL); + }); + + it('should HONOR any model with "lite" in its name', () => 
{ + const customLiteModel = 'gemini-2.5-custom-lite-vNext'; + const model = getEffectiveModel(isInFallbackMode, customLiteModel); + expect(model).toBe(customLiteModel); + }); + + it('should downgrade any other custom model to the Flash model', () => { + const customModel = 'custom-model-v1-unlisted'; + const model = getEffectiveModel(isInFallbackMode, customModel); + expect(model).toBe(DEFAULT_GEMINI_FLASH_MODEL); + }); + }); +}); diff --git a/packages/core/src/config/models.ts b/packages/core/src/config/models.ts index 1d2c1310a77..a0aa73bfdd9 100644 --- a/packages/core/src/config/models.ts +++ b/packages/core/src/config/models.ts @@ -12,3 +12,35 @@ export const DEFAULT_GEMINI_EMBEDDING_MODEL = 'gemini-embedding-001'; // Some thinking models do not default to dynamic thinking which is done by a value of -1 export const DEFAULT_THINKING_MODE = -1; + +/** + * Determines the effective model to use, applying fallback logic if necessary. + * + * When fallback mode is active, this function enforces the use of the standard + * fallback model. However, it makes an exception for "lite" models (any model + * with "lite" in its name), allowing them to be used to preserve cost savings. + * This ensures that "pro" models are always downgraded, while "lite" model + * requests are honored. + * + * @param isInFallbackMode Whether the application is in fallback mode. + * @param requestedModel The model that was originally requested. + * @returns The effective model name. + */ +export function getEffectiveModel( + isInFallbackMode: boolean, + requestedModel: string, +): string { + // If we are not in fallback mode, simply use the requested model. + if (!isInFallbackMode) { + return requestedModel; + } + + // If a "lite" model is requested, honor it. This allows for variations of + // lite models without needing to list them all as constants. + if (requestedModel.includes('lite')) { + return requestedModel; + } + + // Default fallback for Gemini CLI. 
+ return DEFAULT_GEMINI_FLASH_MODEL; +} diff --git a/packages/core/src/config/storage.test.ts b/packages/core/src/config/storage.test.ts index 3b2cbef9473..e93bb375c83 100644 --- a/packages/core/src/config/storage.test.ts +++ b/packages/core/src/config/storage.test.ts @@ -52,4 +52,9 @@ describe('Storage – additional helpers', () => { ); expect(Storage.getMcpOAuthTokensPath()).toBe(expected); }); + + it('getGlobalBinDir returns ~/.gemini/tmp/bin', () => { + const expected = path.join(os.homedir(), '.gemini', 'tmp', 'bin'); + expect(Storage.getGlobalBinDir()).toBe(expected); + }); }); diff --git a/packages/core/src/config/storage.ts b/packages/core/src/config/storage.ts index 6442b87c876..354d51f1c11 100644 --- a/packages/core/src/config/storage.ts +++ b/packages/core/src/config/storage.ts @@ -13,6 +13,7 @@ export const GEMINI_DIR = '.gemini'; export const GOOGLE_ACCOUNTS_FILENAME = 'google_accounts.json'; export const OAUTH_FILE = 'oauth_creds.json'; const TMP_DIR_NAME = 'tmp'; +const BIN_DIR_NAME = 'bin'; export class Storage { private readonly targetDir: string; @@ -57,6 +58,10 @@ export class Storage { return path.join(Storage.getGlobalGeminiDir(), TMP_DIR_NAME); } + static getGlobalBinDir(): string { + return path.join(Storage.getGlobalTempDir(), BIN_DIR_NAME); + } + getGeminiDir(): string { return path.join(this.targetDir, GEMINI_DIR); } diff --git a/packages/core/src/confirmation-bus/index.ts b/packages/core/src/confirmation-bus/index.ts new file mode 100644 index 00000000000..379d9aa4d84 --- /dev/null +++ b/packages/core/src/confirmation-bus/index.ts @@ -0,0 +1,8 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +export * from './message-bus.js'; +export * from './types.js'; diff --git a/packages/core/src/confirmation-bus/message-bus.test.ts b/packages/core/src/confirmation-bus/message-bus.test.ts new file mode 100644 index 00000000000..8156671c9b4 --- /dev/null +++ 
b/packages/core/src/confirmation-bus/message-bus.test.ts @@ -0,0 +1,235 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, beforeEach, vi } from 'vitest'; +import { MessageBus } from './message-bus.js'; +import { PolicyEngine } from '../policy/policy-engine.js'; +import { PolicyDecision } from '../policy/types.js'; +import { + MessageBusType, + type ToolConfirmationRequest, + type ToolConfirmationResponse, + type ToolPolicyRejection, + type ToolExecutionSuccess, +} from './types.js'; + +describe('MessageBus', () => { + let messageBus: MessageBus; + let policyEngine: PolicyEngine; + + beforeEach(() => { + policyEngine = new PolicyEngine(); + messageBus = new MessageBus(policyEngine); + }); + + describe('publish', () => { + it('should emit error for invalid message', () => { + const errorHandler = vi.fn(); + messageBus.on('error', errorHandler); + + // @ts-expect-error - Testing invalid message + messageBus.publish({ invalid: 'message' }); + + expect(errorHandler).toHaveBeenCalledWith( + expect.objectContaining({ + message: expect.stringContaining('Invalid message structure'), + }), + ); + }); + + it('should validate tool confirmation requests have correlationId', () => { + const errorHandler = vi.fn(); + messageBus.on('error', errorHandler); + + // @ts-expect-error - Testing missing correlationId + messageBus.publish({ + type: MessageBusType.TOOL_CONFIRMATION_REQUEST, + toolCall: { name: 'test' }, + }); + + expect(errorHandler).toHaveBeenCalled(); + }); + + it('should emit confirmation response when policy allows', () => { + vi.spyOn(policyEngine, 'check').mockReturnValue(PolicyDecision.ALLOW); + + const responseHandler = vi.fn(); + messageBus.subscribe( + MessageBusType.TOOL_CONFIRMATION_RESPONSE, + responseHandler, + ); + + const request: ToolConfirmationRequest = { + type: MessageBusType.TOOL_CONFIRMATION_REQUEST, + toolCall: { name: 'test-tool', args: {} }, + correlationId: '123', + 
}; + + messageBus.publish(request); + + const expectedResponse: ToolConfirmationResponse = { + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: '123', + confirmed: true, + }; + expect(responseHandler).toHaveBeenCalledWith(expectedResponse); + }); + + it('should emit rejection and response when policy denies', () => { + vi.spyOn(policyEngine, 'check').mockReturnValue(PolicyDecision.DENY); + + const responseHandler = vi.fn(); + const rejectionHandler = vi.fn(); + messageBus.subscribe( + MessageBusType.TOOL_CONFIRMATION_RESPONSE, + responseHandler, + ); + messageBus.subscribe( + MessageBusType.TOOL_POLICY_REJECTION, + rejectionHandler, + ); + + const request: ToolConfirmationRequest = { + type: MessageBusType.TOOL_CONFIRMATION_REQUEST, + toolCall: { name: 'test-tool', args: {} }, + correlationId: '123', + }; + + messageBus.publish(request); + + const expectedRejection: ToolPolicyRejection = { + type: MessageBusType.TOOL_POLICY_REJECTION, + toolCall: { name: 'test-tool', args: {} }, + }; + expect(rejectionHandler).toHaveBeenCalledWith(expectedRejection); + + const expectedResponse: ToolConfirmationResponse = { + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: '123', + confirmed: false, + }; + expect(responseHandler).toHaveBeenCalledWith(expectedResponse); + }); + + it('should pass through to UI when policy says ASK_USER', () => { + vi.spyOn(policyEngine, 'check').mockReturnValue(PolicyDecision.ASK_USER); + + const requestHandler = vi.fn(); + messageBus.subscribe( + MessageBusType.TOOL_CONFIRMATION_REQUEST, + requestHandler, + ); + + const request: ToolConfirmationRequest = { + type: MessageBusType.TOOL_CONFIRMATION_REQUEST, + toolCall: { name: 'test-tool', args: {} }, + correlationId: '123', + }; + + messageBus.publish(request); + + expect(requestHandler).toHaveBeenCalledWith(request); + }); + + it('should emit other message types directly', () => { + const successHandler = vi.fn(); + messageBus.subscribe( + 
MessageBusType.TOOL_EXECUTION_SUCCESS, + successHandler, + ); + + const message: ToolExecutionSuccess = { + type: MessageBusType.TOOL_EXECUTION_SUCCESS as const, + toolCall: { name: 'test-tool' }, + result: 'success', + }; + + messageBus.publish(message); + + expect(successHandler).toHaveBeenCalledWith(message); + }); + }); + + describe('subscribe/unsubscribe', () => { + it('should allow subscribing to specific message types', () => { + const handler = vi.fn(); + messageBus.subscribe(MessageBusType.TOOL_EXECUTION_SUCCESS, handler); + + const message: ToolExecutionSuccess = { + type: MessageBusType.TOOL_EXECUTION_SUCCESS as const, + toolCall: { name: 'test' }, + result: 'test', + }; + + messageBus.publish(message); + + expect(handler).toHaveBeenCalledWith(message); + }); + + it('should allow unsubscribing from message types', () => { + const handler = vi.fn(); + messageBus.subscribe(MessageBusType.TOOL_EXECUTION_SUCCESS, handler); + messageBus.unsubscribe(MessageBusType.TOOL_EXECUTION_SUCCESS, handler); + + const message: ToolExecutionSuccess = { + type: MessageBusType.TOOL_EXECUTION_SUCCESS as const, + toolCall: { name: 'test' }, + result: 'test', + }; + + messageBus.publish(message); + + expect(handler).not.toHaveBeenCalled(); + }); + + it('should support multiple subscribers for the same message type', () => { + const handler1 = vi.fn(); + const handler2 = vi.fn(); + + messageBus.subscribe(MessageBusType.TOOL_EXECUTION_SUCCESS, handler1); + messageBus.subscribe(MessageBusType.TOOL_EXECUTION_SUCCESS, handler2); + + const message: ToolExecutionSuccess = { + type: MessageBusType.TOOL_EXECUTION_SUCCESS as const, + toolCall: { name: 'test' }, + result: 'test', + }; + + messageBus.publish(message); + + expect(handler1).toHaveBeenCalledWith(message); + expect(handler2).toHaveBeenCalledWith(message); + }); + }); + + describe('error handling', () => { + it('should not crash on errors during message processing', () => { + const errorHandler = vi.fn(); + 
messageBus.on('error', errorHandler); + + // Mock policyEngine to throw an error + vi.spyOn(policyEngine, 'check').mockImplementation(() => { + throw new Error('Policy check failed'); + }); + + const request: ToolConfirmationRequest = { + type: MessageBusType.TOOL_CONFIRMATION_REQUEST, + toolCall: { name: 'test-tool' }, + correlationId: '123', + }; + + // Should not throw + expect(() => messageBus.publish(request)).not.toThrow(); + + // Should emit error + expect(errorHandler).toHaveBeenCalledWith( + expect.objectContaining({ + message: 'Policy check failed', + }), + ); + }); + }); +}); diff --git a/packages/core/src/confirmation-bus/message-bus.ts b/packages/core/src/confirmation-bus/message-bus.ts new file mode 100644 index 00000000000..b9d66eff6ab --- /dev/null +++ b/packages/core/src/confirmation-bus/message-bus.ts @@ -0,0 +1,98 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { EventEmitter } from 'node:events'; +import type { PolicyEngine } from '../policy/policy-engine.js'; +import { PolicyDecision } from '../policy/types.js'; +import { MessageBusType, type Message } from './types.js'; +import { safeJsonStringify } from '../utils/safeJsonStringify.js'; + +export class MessageBus extends EventEmitter { + constructor(private readonly policyEngine: PolicyEngine) { + super(); + } + + private isValidMessage(message: Message): boolean { + if (!message || !message.type) { + return false; + } + + if ( + message.type === MessageBusType.TOOL_CONFIRMATION_REQUEST && + !('correlationId' in message) + ) { + return false; + } + + return true; + } + + private emitMessage(message: Message): void { + this.emit(message.type, message); + } + + publish(message: Message): void { + try { + if (!this.isValidMessage(message)) { + throw new Error( + `Invalid message structure: ${safeJsonStringify(message)}`, + ); + } + + if (message.type === MessageBusType.TOOL_CONFIRMATION_REQUEST) { + const decision = 
this.policyEngine.check(message.toolCall); + + switch (decision) { + case PolicyDecision.ALLOW: + // Directly emit the response instead of recursive publish + this.emitMessage({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: message.correlationId, + confirmed: true, + }); + break; + case PolicyDecision.DENY: + // Emit both rejection and response messages + this.emitMessage({ + type: MessageBusType.TOOL_POLICY_REJECTION, + toolCall: message.toolCall, + }); + this.emitMessage({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: message.correlationId, + confirmed: false, + }); + break; + case PolicyDecision.ASK_USER: + // Pass through to UI for user confirmation + this.emitMessage(message); + break; + default: + throw new Error(`Unknown policy decision: ${decision}`); + } + } else { + // For all other message types, just emit them + this.emitMessage(message); + } + } catch (error) { + this.emit('error', error); + } + } + + subscribe( + type: T['type'], + listener: (message: T) => void, + ): void { + this.on(type, listener); + } + + unsubscribe( + type: T['type'], + listener: (message: T) => void, + ): void { + this.off(type, listener); + } +} diff --git a/packages/core/src/confirmation-bus/types.ts b/packages/core/src/confirmation-bus/types.ts new file mode 100644 index 00000000000..cb86595be9a --- /dev/null +++ b/packages/core/src/confirmation-bus/types.ts @@ -0,0 +1,51 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { type FunctionCall } from '@google/genai'; + +export enum MessageBusType { + TOOL_CONFIRMATION_REQUEST = 'tool-confirmation-request', + TOOL_CONFIRMATION_RESPONSE = 'tool-confirmation-response', + TOOL_POLICY_REJECTION = 'tool-policy-rejection', + TOOL_EXECUTION_SUCCESS = 'tool-execution-success', + TOOL_EXECUTION_FAILURE = 'tool-execution-failure', +} + +export interface ToolConfirmationRequest { + type: MessageBusType.TOOL_CONFIRMATION_REQUEST; + toolCall: 
FunctionCall; + correlationId: string; +} + +export interface ToolConfirmationResponse { + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE; + correlationId: string; + confirmed: boolean; +} + +export interface ToolPolicyRejection { + type: MessageBusType.TOOL_POLICY_REJECTION; + toolCall: FunctionCall; +} + +export interface ToolExecutionSuccess { + type: MessageBusType.TOOL_EXECUTION_SUCCESS; + toolCall: FunctionCall; + result: T; +} + +export interface ToolExecutionFailure { + type: MessageBusType.TOOL_EXECUTION_FAILURE; + toolCall: FunctionCall; + error: E; +} + +export type Message = + | ToolConfirmationRequest + | ToolConfirmationResponse + | ToolPolicyRejection + | ToolExecutionSuccess + | ToolExecutionFailure; diff --git a/packages/core/src/core/baseLlmClient.test.ts b/packages/core/src/core/baseLlmClient.test.ts new file mode 100644 index 00000000000..1b1787f5fd4 --- /dev/null +++ b/packages/core/src/core/baseLlmClient.test.ts @@ -0,0 +1,291 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + describe, + it, + expect, + vi, + beforeEach, + afterEach, + type Mocked, +} from 'vitest'; + +import type { GenerateContentResponse } from '@google/genai'; +import { BaseLlmClient, type GenerateJsonOptions } from './baseLlmClient.js'; +import type { ContentGenerator } from './contentGenerator.js'; +import type { Config } from '../config/config.js'; +import { AuthType } from './contentGenerator.js'; +import { reportError } from '../utils/errorReporting.js'; +import { logMalformedJsonResponse } from '../telemetry/loggers.js'; +import { retryWithBackoff } from '../utils/retry.js'; +import { MalformedJsonResponseEvent } from '../telemetry/types.js'; +import { getErrorMessage } from '../utils/errors.js'; + +vi.mock('../utils/errorReporting.js'); +vi.mock('../telemetry/loggers.js'); +vi.mock('../utils/errors.js', async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + 
getErrorMessage: vi.fn((e) => (e instanceof Error ? e.message : String(e))), + }; +}); + +vi.mock('../utils/retry.js', () => ({ + retryWithBackoff: vi.fn(async (fn) => await fn()), +})); + +const mockGenerateContent = vi.fn(); + +const mockContentGenerator = { + generateContent: mockGenerateContent, +} as unknown as Mocked; + +const mockConfig = { + getSessionId: vi.fn().mockReturnValue('test-session-id'), + getContentGeneratorConfig: vi + .fn() + .mockReturnValue({ authType: AuthType.USE_GEMINI }), +} as unknown as Mocked; + +// Helper to create a mock GenerateContentResponse +const createMockResponse = (text: string): GenerateContentResponse => + ({ + candidates: [{ content: { role: 'model', parts: [{ text }] }, index: 0 }], + }) as GenerateContentResponse; + +describe('BaseLlmClient', () => { + let client: BaseLlmClient; + let abortController: AbortController; + let defaultOptions: GenerateJsonOptions; + + beforeEach(() => { + vi.clearAllMocks(); + // Reset the mocked implementation for getErrorMessage for accurate error message assertions + vi.mocked(getErrorMessage).mockImplementation((e) => + e instanceof Error ? e.message : String(e), + ); + client = new BaseLlmClient(mockContentGenerator, mockConfig); + abortController = new AbortController(); + defaultOptions = { + contents: [{ role: 'user', parts: [{ text: 'Give me a color.' 
}] }], + schema: { type: 'object', properties: { color: { type: 'string' } } }, + model: 'test-model', + abortSignal: abortController.signal, + promptId: 'test-prompt-id', + }; + }); + + afterEach(() => { + abortController.abort(); + }); + + describe('generateJson - Success Scenarios', () => { + it('should call generateContent with correct parameters, defaults, and utilize retry mechanism', async () => { + const mockResponse = createMockResponse('{"color": "blue"}'); + mockGenerateContent.mockResolvedValue(mockResponse); + + const result = await client.generateJson(defaultOptions); + + expect(result).toEqual({ color: 'blue' }); + + // Ensure the retry mechanism was engaged + expect(retryWithBackoff).toHaveBeenCalledTimes(1); + + // Validate the parameters passed to the underlying generator + expect(mockGenerateContent).toHaveBeenCalledTimes(1); + expect(mockGenerateContent).toHaveBeenCalledWith( + { + model: 'test-model', + contents: defaultOptions.contents, + config: { + abortSignal: defaultOptions.abortSignal, + temperature: 0, + topP: 1, + responseJsonSchema: defaultOptions.schema, + responseMimeType: 'application/json', + // Crucial: systemInstruction should NOT be in the config object if not provided + }, + }, + 'test-prompt-id', + ); + }); + + it('should respect configuration overrides', async () => { + const mockResponse = createMockResponse('{"color": "red"}'); + mockGenerateContent.mockResolvedValue(mockResponse); + + const options: GenerateJsonOptions = { + ...defaultOptions, + config: { temperature: 0.8, topK: 10 }, + }; + + await client.generateJson(options); + + expect(mockGenerateContent).toHaveBeenCalledWith( + expect.objectContaining({ + config: expect.objectContaining({ + temperature: 0.8, + topP: 1, // Default should remain if not overridden + topK: 10, + }), + }), + expect.any(String), + ); + }); + + it('should include system instructions when provided', async () => { + const mockResponse = createMockResponse('{"color": "green"}'); + 
mockGenerateContent.mockResolvedValue(mockResponse); + const systemInstruction = 'You are a helpful assistant.'; + + const options: GenerateJsonOptions = { + ...defaultOptions, + systemInstruction, + }; + + await client.generateJson(options); + + expect(mockGenerateContent).toHaveBeenCalledWith( + expect.objectContaining({ + config: expect.objectContaining({ + systemInstruction, + }), + }), + expect.any(String), + ); + }); + + it('should use the provided promptId', async () => { + const mockResponse = createMockResponse('{"color": "yellow"}'); + mockGenerateContent.mockResolvedValue(mockResponse); + const customPromptId = 'custom-id-123'; + + const options: GenerateJsonOptions = { + ...defaultOptions, + promptId: customPromptId, + }; + + await client.generateJson(options); + + expect(mockGenerateContent).toHaveBeenCalledWith( + expect.any(Object), + customPromptId, + ); + }); + }); + + describe('generateJson - Response Cleaning', () => { + it('should clean JSON wrapped in markdown backticks and log telemetry', async () => { + const malformedResponse = '```json\n{"color": "purple"}\n```'; + mockGenerateContent.mockResolvedValue( + createMockResponse(malformedResponse), + ); + + const result = await client.generateJson(defaultOptions); + + expect(result).toEqual({ color: 'purple' }); + expect(logMalformedJsonResponse).toHaveBeenCalledTimes(1); + expect(logMalformedJsonResponse).toHaveBeenCalledWith( + mockConfig, + expect.any(MalformedJsonResponseEvent), + ); + // Validate the telemetry event content + const event = vi.mocked(logMalformedJsonResponse).mock + .calls[0][1] as MalformedJsonResponseEvent; + expect(event.model).toBe('test-model'); + }); + + it('should handle extra whitespace correctly without logging malformed telemetry', async () => { + const responseWithWhitespace = ' \n {"color": "orange"} \n'; + mockGenerateContent.mockResolvedValue( + createMockResponse(responseWithWhitespace), + ); + + const result = await client.generateJson(defaultOptions); + + 
expect(result).toEqual({ color: 'orange' }); + expect(logMalformedJsonResponse).not.toHaveBeenCalled(); + }); + }); + + describe('generateJson - Error Handling', () => { + it('should throw and report error for empty response', async () => { + mockGenerateContent.mockResolvedValue(createMockResponse('')); + + // The final error message includes the prefix added by the client's outer catch block. + await expect(client.generateJson(defaultOptions)).rejects.toThrow( + 'Failed to generate JSON content: API returned an empty response for generateJson.', + ); + + // Verify error reporting details + expect(reportError).toHaveBeenCalledTimes(1); + expect(reportError).toHaveBeenCalledWith( + expect.any(Error), + 'Error in generateJson: API returned an empty response.', + defaultOptions.contents, + 'generateJson-empty-response', + ); + }); + + it('should throw and report error for invalid JSON syntax', async () => { + const invalidJson = '{"color": "blue"'; // missing closing brace + mockGenerateContent.mockResolvedValue(createMockResponse(invalidJson)); + + await expect(client.generateJson(defaultOptions)).rejects.toThrow( + /^Failed to generate JSON content: Failed to parse API response as JSON:/, + ); + + expect(reportError).toHaveBeenCalledTimes(1); + expect(reportError).toHaveBeenCalledWith( + expect.any(Error), + 'Failed to parse JSON response from generateJson.', + expect.objectContaining({ responseTextFailedToParse: invalidJson }), + 'generateJson-parse', + ); + }); + + it('should throw and report generic API errors', async () => { + const apiError = new Error('Service Unavailable (503)'); + // Simulate the generator failing + mockGenerateContent.mockRejectedValue(apiError); + + await expect(client.generateJson(defaultOptions)).rejects.toThrow( + 'Failed to generate JSON content: Service Unavailable (503)', + ); + + // Verify generic error reporting + expect(reportError).toHaveBeenCalledTimes(1); + expect(reportError).toHaveBeenCalledWith( + apiError, + 'Error 
generating JSON content via API.', + defaultOptions.contents, + 'generateJson-api', + ); + }); + + it('should throw immediately without reporting if aborted', async () => { + const abortError = new DOMException('Aborted', 'AbortError'); + + // Simulate abortion happening during the API call + mockGenerateContent.mockImplementation(() => { + abortController.abort(); // Ensure the signal is aborted when the service checks + throw abortError; + }); + + const options = { + ...defaultOptions, + abortSignal: abortController.signal, + }; + + await expect(client.generateJson(options)).rejects.toThrow(abortError); + + // Crucially, it should not report a cancellation as an application error + expect(reportError).not.toHaveBeenCalled(); + }); + }); +}); diff --git a/packages/core/src/core/baseLlmClient.ts b/packages/core/src/core/baseLlmClient.ts new file mode 100644 index 00000000000..25a92dabdd7 --- /dev/null +++ b/packages/core/src/core/baseLlmClient.ts @@ -0,0 +1,171 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Content, GenerateContentConfig, Part } from '@google/genai'; +import type { Config } from '../config/config.js'; +import type { ContentGenerator } from './contentGenerator.js'; +import { getResponseText } from '../utils/partUtils.js'; +import { reportError } from '../utils/errorReporting.js'; +import { getErrorMessage } from '../utils/errors.js'; +import { logMalformedJsonResponse } from '../telemetry/loggers.js'; +import { MalformedJsonResponseEvent } from '../telemetry/types.js'; +import { retryWithBackoff } from '../utils/retry.js'; + +/** + * Options for the generateJson utility function. + */ +export interface GenerateJsonOptions { + /** The input prompt or history. */ + contents: Content[]; + /** The required JSON schema for the output. */ + schema: Record; + /** The specific model to use for this task. */ + model: string; + /** + * Task-specific system instructions. 
+ * If omitted, no system instruction is sent. + */ + systemInstruction?: string | Part | Part[] | Content; + /** + * Overrides for generation configuration (e.g., temperature). + */ + config?: Omit< + GenerateContentConfig, + | 'systemInstruction' + | 'responseJsonSchema' + | 'responseMimeType' + | 'tools' + | 'abortSignal' + >; + /** Signal for cancellation. */ + abortSignal: AbortSignal; + /** + * A unique ID for the prompt, used for logging/telemetry correlation. + */ + promptId: string; +} + +/** + * A client dedicated to stateless, utility-focused LLM calls. + */ +export class BaseLlmClient { + // Default configuration for utility tasks + private readonly defaultUtilityConfig: GenerateContentConfig = { + temperature: 0, + topP: 1, + }; + + constructor( + private readonly contentGenerator: ContentGenerator, + private readonly config: Config, + ) {} + + async generateJson( + options: GenerateJsonOptions, + ): Promise> { + const { + contents, + schema, + model, + abortSignal, + systemInstruction, + promptId, + } = options; + + const requestConfig: GenerateContentConfig = { + abortSignal, + ...this.defaultUtilityConfig, + ...options.config, + ...(systemInstruction && { systemInstruction }), + responseJsonSchema: schema, + responseMimeType: 'application/json', + }; + + try { + const apiCall = () => + this.contentGenerator.generateContent( + { + model, + config: requestConfig, + contents, + }, + promptId, + ); + + const result = await retryWithBackoff(apiCall); + + let text = getResponseText(result)?.trim(); + if (!text) { + const error = new Error( + 'API returned an empty response for generateJson.', + ); + await reportError( + error, + 'Error in generateJson: API returned an empty response.', + contents, + 'generateJson-empty-response', + ); + throw error; + } + + text = this.cleanJsonResponse(text, model); + + try { + return JSON.parse(text); + } catch (parseError) { + const error = new Error( + `Failed to parse API response as JSON: 
${getErrorMessage(parseError)}`, + ); + await reportError( + parseError, + 'Failed to parse JSON response from generateJson.', + { + responseTextFailedToParse: text, + originalRequestContents: contents, + }, + 'generateJson-parse', + ); + throw error; + } + } catch (error) { + if (abortSignal.aborted) { + throw error; + } + + if ( + error instanceof Error && + (error.message === 'API returned an empty response for generateJson.' || + error.message.startsWith('Failed to parse API response as JSON:')) + ) { + // We perform this check so that we don't report these again. + } else { + await reportError( + error, + 'Error generating JSON content via API.', + contents, + 'generateJson-api', + ); + } + + throw new Error( + `Failed to generate JSON content: ${getErrorMessage(error)}`, + ); + } + } + + private cleanJsonResponse(text: string, model: string): string { + const prefix = '```json'; + const suffix = '```'; + if (text.startsWith(prefix) && text.endsWith(suffix)) { + logMalformedJsonResponse( + this.config, + new MalformedJsonResponseEvent(model), + ); + return text.substring(prefix.length, text.length - suffix.length).trim(); + } + return text; + } +} diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index 00d3b3e6d1d..3dd4b7b0530 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -11,17 +11,10 @@ import { vi, beforeEach, afterEach, - type Mocked, + type Mock, } from 'vitest'; -import type { - Chat, - Content, - EmbedContentResponse, - GenerateContentResponse, - Part, -} from '@google/genai'; -import { GoogleGenAI } from '@google/genai'; +import type { Content, GenerateContentResponse, Part } from '@google/genai'; import { findIndexAfterFraction, isThinkingDefault, @@ -34,7 +27,7 @@ import { type ContentGeneratorConfig, } from './contentGenerator.js'; import { type GeminiChat } from './geminiChat.js'; -import { Config } from '../config/config.js'; +import type { Config } from 
'../config/config.js'; import { CompressionStatus, GeminiEventType, @@ -46,8 +39,9 @@ import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; import { FileDiscoveryService } from '../services/fileDiscoveryService.js'; import { setSimulate429 } from '../utils/testUtils.js'; import { tokenLimit } from './tokenLimits.js'; -import { ideContext } from '../ide/ideContext.js'; +import { ideContextStore } from '../ide/ideContext.js'; import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js'; +import type { ModelRouterService } from '../routing/modelRouterService.js'; // Mock fs module to prevent actual file system operations during tests const mockFileSystem = new Map(); @@ -76,12 +70,8 @@ vi.mock('node:fs', () => { }); // --- Mocks --- -const mockChatCreateFn = vi.fn(); -const mockGenerateContentFn = vi.fn(); -const mockEmbedContentFn = vi.fn(); const mockTurnRunFn = vi.fn(); -vi.mock('@google/genai'); vi.mock('./turn', async (importOriginal) => { const actual = await importOriginal(); // Define a mock class that has the same shape as the real Turn @@ -162,8 +152,8 @@ describe('findIndexAfterFraction', () => { // 0: 66 // 1: 66 + 68 = 134 // 2: 134 + 66 = 200 - // 200 >= 166.5, so index is 2 - expect(findIndexAfterFraction(history, 0.5)).toBe(2); + // 200 >= 166.5, so index is 3 + expect(findIndexAfterFraction(history, 0.5)).toBe(3); }); it('should handle a fraction that results in the last index', () => { @@ -171,8 +161,8 @@ describe('findIndexAfterFraction', () => { // ... 
// 3: 200 + 68 = 268 // 4: 268 + 65 = 333 - // 333 >= 299.7, so index is 4 - expect(findIndexAfterFraction(history, 0.9)).toBe(4); + // 333 >= 299.7, so index is 5 + expect(findIndexAfterFraction(history, 0.9)).toBe(5); }); it('should handle an empty history', () => { @@ -180,7 +170,7 @@ describe('findIndexAfterFraction', () => { }); it('should handle a history with only one item', () => { - expect(findIndexAfterFraction(history.slice(0, 1), 0.5)).toBe(0); + expect(findIndexAfterFraction(history.slice(0, 1), 0.5)).toBe(1); }); it('should handle history with weird parts', () => { @@ -189,7 +179,7 @@ describe('findIndexAfterFraction', () => { { role: 'model', parts: [{ fileData: { fileUri: 'derp' } }] }, { role: 'user', parts: [{ text: 'Message 2' }] }, ]; - expect(findIndexAfterFraction(historyWithEmptyParts, 0.5)).toBe(1); + expect(findIndexAfterFraction(historyWithEmptyParts, 0.5)).toBe(2); }); }); @@ -228,36 +218,27 @@ describe('isThinkingDefault', () => { }); describe('Gemini Client (client.ts)', () => { + let mockContentGenerator: ContentGenerator; + let mockConfig: Config; let client: GeminiClient; + let mockGenerateContentFn: Mock; beforeEach(async () => { vi.resetAllMocks(); + mockGenerateContentFn = vi.fn().mockResolvedValue({ + candidates: [{ content: { parts: [{ text: '{"key": "value"}' }] } }], + }); + // Disable 429 simulation for tests setSimulate429(false); - // Set up the mock for GoogleGenAI constructor and its methods - const MockedGoogleGenAI = vi.mocked(GoogleGenAI); - MockedGoogleGenAI.mockImplementation(() => { - const mock = { - chats: { create: mockChatCreateFn }, - models: { - generateContent: mockGenerateContentFn, - embedContent: mockEmbedContentFn, - }, - }; - return mock as unknown as GoogleGenAI; - }); - - mockChatCreateFn.mockResolvedValue({} as Chat); - mockGenerateContentFn.mockResolvedValue({ - candidates: [ - { - content: { - parts: [{ text: '{"key": "value"}' }], - }, - }, - ], - } as unknown as GenerateContentResponse); + 
mockContentGenerator = { + generateContent: mockGenerateContentFn, + generateContentStream: vi.fn(), + countTokens: vi.fn().mockResolvedValue({ totalTokens: 100 }), + embedContent: vi.fn(), + batchEmbedContents: vi.fn(), + } as unknown as ContentGenerator; // Because the GeminiClient constructor kicks off an async process (startChat) // that depends on a fully-formed Config object, we need to mock the @@ -268,12 +249,11 @@ describe('Gemini Client (client.ts)', () => { }; const fileService = new FileDiscoveryService('/test/dir'); const contentGeneratorConfig: ContentGeneratorConfig = { - model: 'test-model', apiKey: 'test-key', vertexai: false, authType: AuthType.USE_GEMINI, }; - const mockConfigObject = { + mockConfig = { getContentGeneratorConfig: vi .fn() .mockReturnValue(contentGeneratorConfig), @@ -301,6 +281,10 @@ describe('Gemini Client (client.ts)', () => { getDirectories: vi.fn().mockReturnValue(['/test/dir']), }), getGeminiClient: vi.fn(), + getModelRouterService: vi.fn().mockReturnValue({ + route: vi.fn().mockResolvedValue({ model: 'default-routed-model' }), + }), + isInFallbackMode: vi.fn().mockReturnValue(false), setFallbackMode: vi.fn(), getChatCompression: vi.fn().mockReturnValue(undefined), getSkipNextSpeakerCheck: vi.fn().mockReturnValue(false), @@ -309,46 +293,18 @@ describe('Gemini Client (client.ts)', () => { storage: { getProjectTempDir: vi.fn().mockReturnValue('/test/temp'), }, - }; - const MockedConfig = vi.mocked(Config, true); - MockedConfig.mockImplementation( - () => mockConfigObject as unknown as Config, - ); + getContentGenerator: vi.fn().mockReturnValue(mockContentGenerator), + } as unknown as Config; - // We can instantiate the client here since Config is mocked - // and the constructor will use the mocked GoogleGenAI - client = new GeminiClient( - new Config({ sessionId: 'test-session-id' } as never), - ); - mockConfigObject.getGeminiClient.mockReturnValue(client); - - await client.initialize(contentGeneratorConfig); + client = new 
GeminiClient(mockConfig); + await client.initialize(); + vi.mocked(mockConfig.getGeminiClient).mockReturnValue(client); }); afterEach(() => { vi.restoreAllMocks(); }); - // NOTE: The following tests for startChat were removed due to persistent issues with - // the @google/genai mock. Specifically, the mockChatCreateFn (representing instance.chats.create) - // was not being detected as called by the GeminiClient instance. - // This likely points to a subtle issue in how the GoogleGenerativeAI class constructor - // and its instance methods are mocked and then used by the class under test. - // For future debugging, ensure that the `this.client` in `GeminiClient` (which is an - // instance of the mocked GoogleGenerativeAI) correctly has its `chats.create` method - // pointing to `mockChatCreateFn`. - // it('startChat should call getCoreSystemPrompt with userMemory and pass to chats.create', async () => { ... }); - // it('startChat should call getCoreSystemPrompt with empty string if userMemory is empty', async () => { ... }); - - // NOTE: The following tests for generateJson were removed due to persistent issues with - // the @google/genai mock, similar to the startChat tests. The mockGenerateContentFn - // (representing instance.models.generateContent) was not being detected as called, or the mock - // was not preventing an actual API call (leading to API key errors). - // For future debugging, ensure `this.client.models.generateContent` in `GeminiClient` correctly - // uses the `mockGenerateContentFn`. - // it('generateJson should call getCoreSystemPrompt with userMemory and pass to generateContent', async () => { ... }); - // it('generateJson should call getCoreSystemPrompt with empty string if userMemory is empty', async () => { ... 
}); - describe('generateEmbedding', () => { const texts = ['hello world', 'goodbye world']; const testEmbeddingModel = 'test-embedding-model'; @@ -358,18 +314,17 @@ describe('Gemini Client (client.ts)', () => { [0.1, 0.2, 0.3], [0.4, 0.5, 0.6], ]; - const mockResponse: EmbedContentResponse = { + vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({ embeddings: [ { values: mockEmbeddings[0] }, { values: mockEmbeddings[1] }, ], - }; - mockEmbedContentFn.mockResolvedValue(mockResponse); + }); const result = await client.generateEmbedding(texts); - expect(mockEmbedContentFn).toHaveBeenCalledTimes(1); - expect(mockEmbedContentFn).toHaveBeenCalledWith({ + expect(mockContentGenerator.embedContent).toHaveBeenCalledTimes(1); + expect(mockContentGenerator.embedContent).toHaveBeenCalledWith({ model: testEmbeddingModel, contents: texts, }); @@ -379,11 +334,11 @@ describe('Gemini Client (client.ts)', () => { it('should return an empty array if an empty array is passed', async () => { const result = await client.generateEmbedding([]); expect(result).toEqual([]); - expect(mockEmbedContentFn).not.toHaveBeenCalled(); + expect(mockContentGenerator.embedContent).not.toHaveBeenCalled(); }); it('should throw an error if API response has no embeddings array', async () => { - mockEmbedContentFn.mockResolvedValue({} as EmbedContentResponse); // No `embeddings` key + vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({}); await expect(client.generateEmbedding(texts)).rejects.toThrow( 'No embeddings found in API response.', @@ -391,20 +346,19 @@ describe('Gemini Client (client.ts)', () => { }); it('should throw an error if API response has an empty embeddings array', async () => { - const mockResponse: EmbedContentResponse = { + vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({ embeddings: [], - }; - mockEmbedContentFn.mockResolvedValue(mockResponse); + }); + await expect(client.generateEmbedding(texts)).rejects.toThrow( 'No embeddings found in API 
response.', ); }); it('should throw an error if API returns a mismatched number of embeddings', async () => { - const mockResponse: EmbedContentResponse = { + vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({ embeddings: [{ values: [1, 2, 3] }], // Only one for two texts - }; - mockEmbedContentFn.mockResolvedValue(mockResponse); + }); await expect(client.generateEmbedding(texts)).rejects.toThrow( 'API returned a mismatched number of embeddings. Expected 2, got 1.', @@ -412,10 +366,9 @@ describe('Gemini Client (client.ts)', () => { }); it('should throw an error if any embedding has nullish values', async () => { - const mockResponse: EmbedContentResponse = { + vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({ embeddings: [{ values: [1, 2, 3] }, { values: undefined }], // Second one is bad - }; - mockEmbedContentFn.mockResolvedValue(mockResponse); + }); await expect(client.generateEmbedding(texts)).rejects.toThrow( 'API returned an empty embedding for input text at index 1: "goodbye world"', @@ -423,10 +376,9 @@ describe('Gemini Client (client.ts)', () => { }); it('should throw an error if any embedding has an empty values array', async () => { - const mockResponse: EmbedContentResponse = { + vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({ embeddings: [{ values: [] }, { values: [1, 2, 3] }], // First one is bad - }; - mockEmbedContentFn.mockResolvedValue(mockResponse); + }); await expect(client.generateEmbedding(texts)).rejects.toThrow( 'API returned an empty embedding for input text at index 0: "hello world"', @@ -434,8 +386,9 @@ describe('Gemini Client (client.ts)', () => { }); it('should propagate errors from the API call', async () => { - const apiError = new Error('API Failure'); - mockEmbedContentFn.mockRejectedValue(apiError); + vi.mocked(mockContentGenerator.embedContent).mockRejectedValue( + new Error('API Failure'), + ); await expect(client.generateEmbedding(texts)).rejects.toThrow( 'API Failure', @@ -449,12 
+402,9 @@ describe('Gemini Client (client.ts)', () => { const schema = { type: 'string' }; const abortSignal = new AbortController().signal; - // Mock countTokens - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; + vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({ + totalTokens: 1, + }); await client.generateJson( contents, @@ -463,7 +413,7 @@ describe('Gemini Client (client.ts)', () => { DEFAULT_GEMINI_FLASH_MODEL, ); - expect(mockGenerateContentFn).toHaveBeenCalledWith( + expect(mockContentGenerator.generateContent).toHaveBeenCalledWith( { model: DEFAULT_GEMINI_FLASH_MODEL, config: { @@ -489,11 +439,9 @@ describe('Gemini Client (client.ts)', () => { const customModel = 'custom-json-model'; const customConfig = { temperature: 0.9, topK: 20 }; - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; + vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({ + totalTokens: 1, + }); await client.generateJson( contents, @@ -503,7 +451,7 @@ describe('Gemini Client (client.ts)', () => { customConfig, ); - expect(mockGenerateContentFn).toHaveBeenCalledWith( + expect(mockContentGenerator.generateContent).toHaveBeenCalledWith( { model: customModel, config: { @@ -520,14 +468,35 @@ describe('Gemini Client (client.ts)', () => { 'test-session-id', ); }); + + it('should use the Flash model when fallback mode is active', async () => { + const contents = [{ role: 'user', parts: [{ text: 'hello' }] }]; + const schema = { type: 'string' }; + const abortSignal = new AbortController().signal; + const requestedModel = 'gemini-2.5-pro'; // A non-flash model + + // Mock config to be in fallback mode + // We access the mock via the client instance which holds 
the mocked config + vi.spyOn(client['config'], 'isInFallbackMode').mockReturnValue(true); + + await client.generateJson(contents, schema, abortSignal, requestedModel); + + // Assert that the Flash model was used, not the requested model + expect(mockContentGenerator.generateContent).toHaveBeenCalledWith( + expect.objectContaining({ + model: DEFAULT_GEMINI_FLASH_MODEL, + }), + 'test-session-id', + ); + }); }); describe('addHistory', () => { it('should call chat.addHistory with the provided content', async () => { - const mockChat: Partial = { + const mockChat = { addHistory: vi.fn(), - }; - client['chat'] = mockChat as GeminiChat; + } as unknown as GeminiChat; + client['chat'] = mockChat; const newContent = { role: 'user', @@ -568,8 +537,6 @@ describe('Gemini Client (client.ts)', () => { }); describe('tryCompressChat', () => { - const mockCountTokens = vi.fn(); - const mockSendMessage = vi.fn(); const mockGetHistory = vi.fn(); beforeEach(() => { @@ -577,15 +544,10 @@ describe('Gemini Client (client.ts)', () => { tokenLimit: vi.fn(), })); - client['contentGenerator'] = { - countTokens: mockCountTokens, - } as unknown as ContentGenerator; - client['chat'] = { getHistory: mockGetHistory, addHistory: vi.fn(), setHistory: vi.fn(), - sendMessage: mockSendMessage, } as unknown as GeminiChat; }); @@ -598,29 +560,22 @@ describe('Gemini Client (client.ts)', () => { const mockChat: Partial = { getHistory: vi.fn().mockReturnValue(chatHistory), setHistory: vi.fn(), - sendMessage: vi.fn().mockResolvedValue({ text: 'Summary' }), }; - const mockCountTokens = vi - .fn() + vi.mocked(mockContentGenerator.countTokens) .mockResolvedValueOnce({ totalTokens: 1000 }) .mockResolvedValueOnce({ totalTokens: 5000 }); - const mockGenerator: Partial> = { - countTokens: mockCountTokens, - }; - client['chat'] = mockChat as GeminiChat; - client['contentGenerator'] = mockGenerator as ContentGenerator; client['startChat'] = vi.fn().mockResolvedValue({ ...mockChat }); - return { client, mockChat, 
mockGenerator }; + return { client, mockChat }; } describe('when compression inflates the token count', () => { - it('uses the truncated history for compression'); it('allows compression to be forced/manual after a failure', async () => { - const { client, mockGenerator } = setup(); - mockGenerator.countTokens?.mockResolvedValue({ + const { client } = setup(); + + vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({ totalTokens: 1000, }); await client.tryCompressChat('prompt-id-4'); // Fails @@ -635,6 +590,9 @@ describe('Gemini Client (client.ts)', () => { it('yields the result even if the compression inflated the tokens', async () => { const { client } = setup(); + vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({ + totalTokens: 1000, + }); const result = await client.tryCompressChat('prompt-id-4', true); expect(result).toEqual({ @@ -654,7 +612,7 @@ describe('Gemini Client (client.ts)', () => { it('restores the history back to the original', async () => { vi.mocked(tokenLimit).mockReturnValue(1000); - mockCountTokens.mockResolvedValue({ + vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({ totalTokens: 999, }); @@ -679,13 +637,13 @@ describe('Gemini Client (client.ts)', () => { }); it('will not attempt to compress context after a failure', async () => { - const { client, mockGenerator } = setup(); + const { client } = setup(); await client.tryCompressChat('prompt-id-4'); const result = await client.tryCompressChat('prompt-id-5'); // it counts tokens for {original, compressed} and then never again - expect(mockGenerator.countTokens).toHaveBeenCalledTimes(2); + expect(mockContentGenerator.countTokens).toHaveBeenCalledTimes(2); expect(result).toEqual({ compressionStatus: CompressionStatus.NOOP, newTokenCount: 0, @@ -694,42 +652,13 @@ describe('Gemini Client (client.ts)', () => { }); }); - it('attempts to compress with a maxOutputTokens set to the original token count', async () => { - vi.mocked(tokenLimit).mockReturnValue(1000); - 
mockCountTokens.mockResolvedValue({ - totalTokens: 999, - }); - - mockGetHistory.mockReturnValue([ - { role: 'user', parts: [{ text: '...history...' }] }, - ]); - - // Mock the summary response from the chat - mockSendMessage.mockResolvedValue({ - role: 'model', - parts: [{ text: 'This is a summary.' }], - }); - - await client.tryCompressChat('prompt-id-2', true); - - expect(mockSendMessage).toHaveBeenCalledWith( - expect.objectContaining({ - config: expect.objectContaining({ - maxOutputTokens: 999, - }), - }), - 'prompt-id-2', - ); - }); - it('should not trigger summarization if token count is below threshold', async () => { const MOCKED_TOKEN_LIMIT = 1000; vi.mocked(tokenLimit).mockReturnValue(MOCKED_TOKEN_LIMIT); mockGetHistory.mockReturnValue([ { role: 'user', parts: [{ text: '...history...' }] }, ]); - - mockCountTokens.mockResolvedValue({ + vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({ totalTokens: MOCKED_TOKEN_LIMIT * 0.699, // TOKEN_THRESHOLD_FOR_SUMMARIZATION = 0.7 }); @@ -763,15 +692,21 @@ describe('Gemini Client (client.ts)', () => { MOCKED_TOKEN_LIMIT * MOCKED_CONTEXT_PERCENTAGE_THRESHOLD; const newTokenCount = 100; - mockCountTokens + vi.mocked(mockContentGenerator.countTokens) .mockResolvedValueOnce({ totalTokens: originalTokenCount }) // First call for the check .mockResolvedValueOnce({ totalTokens: newTokenCount }); // Second call for the new history // Mock the summary response from the chat - mockSendMessage.mockResolvedValue({ - role: 'model', - parts: [{ text: 'This is a summary.' }], - }); + mockGenerateContentFn.mockResolvedValue({ + candidates: [ + { + content: { + role: 'model', + parts: [{ text: 'This is a summary.' 
}], + }, + }, + ], + } as unknown as GenerateContentResponse); await client.tryCompressChat('prompt-id-3'); @@ -800,22 +735,28 @@ describe('Gemini Client (client.ts)', () => { MOCKED_TOKEN_LIMIT * MOCKED_CONTEXT_PERCENTAGE_THRESHOLD; const newTokenCount = 100; - mockCountTokens + vi.mocked(mockContentGenerator.countTokens) .mockResolvedValueOnce({ totalTokens: originalTokenCount }) // First call for the check .mockResolvedValueOnce({ totalTokens: newTokenCount }); // Second call for the new history // Mock the summary response from the chat - mockSendMessage.mockResolvedValue({ - role: 'model', - parts: [{ text: 'This is a summary.' }], - }); + mockGenerateContentFn.mockResolvedValue({ + candidates: [ + { + content: { + role: 'model', + parts: [{ text: 'This is a summary.' }], + }, + }, + ], + } as unknown as GenerateContentResponse); const initialChat = client.getChat(); const result = await client.tryCompressChat('prompt-id-3'); const newChat = client.getChat(); expect(tokenLimit).toHaveBeenCalled(); - expect(mockSendMessage).toHaveBeenCalled(); + expect(mockGenerateContentFn).toHaveBeenCalled(); // Assert that summarization happened and returned the correct stats expect(result).toEqual({ @@ -853,22 +794,28 @@ describe('Gemini Client (client.ts)', () => { const originalTokenCount = 1000 * 0.7; const newTokenCount = 100; - mockCountTokens + vi.mocked(mockContentGenerator.countTokens) .mockResolvedValueOnce({ totalTokens: originalTokenCount }) // First call for the check .mockResolvedValueOnce({ totalTokens: newTokenCount }); // Second call for the new history // Mock the summary response from the chat - mockSendMessage.mockResolvedValue({ - role: 'model', - parts: [{ text: 'This is a summary.' }], - }); + mockGenerateContentFn.mockResolvedValue({ + candidates: [ + { + content: { + role: 'model', + parts: [{ text: 'This is a summary.' 
}], + }, + }, + ], + } as unknown as GenerateContentResponse); const initialChat = client.getChat(); const result = await client.tryCompressChat('prompt-id-3'); const newChat = client.getChat(); expect(tokenLimit).toHaveBeenCalled(); - expect(mockSendMessage).toHaveBeenCalled(); + expect(mockGenerateContentFn).toHaveBeenCalled(); // Assert that summarization happened and returned the correct stats expect(result).toEqual({ @@ -895,21 +842,27 @@ describe('Gemini Client (client.ts)', () => { const originalTokenCount = 10; // Well below threshold const newTokenCount = 5; - mockCountTokens + vi.mocked(mockContentGenerator.countTokens) .mockResolvedValueOnce({ totalTokens: originalTokenCount }) .mockResolvedValueOnce({ totalTokens: newTokenCount }); // Mock the summary response from the chat - mockSendMessage.mockResolvedValue({ - role: 'model', - parts: [{ text: 'This is a summary.' }], - }); + mockGenerateContentFn.mockResolvedValue({ + candidates: [ + { + content: { + role: 'model', + parts: [{ text: 'This is a summary.' 
}], + }, + }, + ], + } as unknown as GenerateContentResponse); const initialChat = client.getChat(); const result = await client.tryCompressChat('prompt-id-1', true); // force = true const newChat = client.getChat(); - expect(mockSendMessage).toHaveBeenCalled(); + expect(mockGenerateContentFn).toHaveBeenCalled(); expect(result).toEqual({ compressionStatus: CompressionStatus.COMPRESSED, @@ -922,10 +875,16 @@ describe('Gemini Client (client.ts)', () => { }); it('should use current model from config for token counting after sendMessage', async () => { - const initialModel = client['config'].getModel(); + const initialModel = mockConfig.getModel(); - const mockCountTokens = vi - .fn() + // mock the model has been changed between calls of `countTokens` + const firstCurrentModel = initialModel + '-changed-1'; + const secondCurrentModel = initialModel + '-changed-2'; + vi.mocked(mockConfig.getModel) + .mockReturnValueOnce(firstCurrentModel) + .mockReturnValueOnce(secondCurrentModel); + + vi.mocked(mockContentGenerator.countTokens) .mockResolvedValueOnce({ totalTokens: 100000 }) .mockResolvedValueOnce({ totalTokens: 5000 }); @@ -936,35 +895,23 @@ describe('Gemini Client (client.ts)', () => { { role: 'model', parts: [{ text: 'Long response' }] }, ]; - const mockChat: Partial = { + const mockChat = { getHistory: vi.fn().mockReturnValue(mockChatHistory), setHistory: vi.fn(), sendMessage: mockSendMessage, - }; - - const mockGenerator: Partial = { - countTokens: mockCountTokens, - }; - - // mock the model has been changed between calls of `countTokens` - const firstCurrentModel = initialModel + '-changed-1'; - const secondCurrentModel = initialModel + '-changed-2'; - vi.spyOn(client['config'], 'getModel') - .mockReturnValueOnce(firstCurrentModel) - .mockReturnValueOnce(secondCurrentModel); + } as unknown as GeminiChat; - client['chat'] = mockChat as GeminiChat; - client['contentGenerator'] = mockGenerator as ContentGenerator; + client['chat'] = mockChat; client['startChat'] = 
vi.fn().mockResolvedValue(mockChat); const result = await client.tryCompressChat('prompt-id-4', true); - expect(mockCountTokens).toHaveBeenCalledTimes(2); - expect(mockCountTokens).toHaveBeenNthCalledWith(1, { + expect(mockContentGenerator.countTokens).toHaveBeenCalledTimes(2); + expect(mockContentGenerator.countTokens).toHaveBeenNthCalledWith(1, { model: firstCurrentModel, contents: mockChatHistory, }); - expect(mockCountTokens).toHaveBeenNthCalledWith(2, { + expect(mockContentGenerator.countTokens).toHaveBeenNthCalledWith(2, { model: secondCurrentModel, contents: expect.any(Array), }); @@ -980,22 +927,11 @@ describe('Gemini Client (client.ts)', () => { describe('sendMessageStream', () => { it('emits a compression event when the context was automatically compressed', async () => { // Arrange - const mockStream = (async function* () { - yield { type: 'content', value: 'Hello' }; - })(); - mockTurnRunFn.mockReturnValue(mockStream); - - const mockChat: Partial = { - addHistory: vi.fn(), - getHistory: vi.fn().mockReturnValue([]), - }; - client['chat'] = mockChat as GeminiChat; - - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; + mockTurnRunFn.mockReturnValue( + (async function* () { + yield { type: 'content', value: 'Hello' }; + })(), + ); const compressionInfo: ChatCompressionInfo = { compressionStatus: CompressionStatus.COMPRESSED, @@ -1042,18 +978,6 @@ describe('Gemini Client (client.ts)', () => { })(); mockTurnRunFn.mockReturnValue(mockStream); - const mockChat: Partial = { - addHistory: vi.fn(), - getHistory: vi.fn().mockReturnValue([]), - }; - client['chat'] = mockChat as GeminiChat; - - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; - const 
compressionInfo: ChatCompressionInfo = { compressionStatus, originalTokenCount: 1000, @@ -1083,7 +1007,7 @@ describe('Gemini Client (client.ts)', () => { it('should include editor context when ideMode is enabled', async () => { // Arrange - vi.mocked(ideContext.getIdeContext).mockReturnValue({ + vi.mocked(ideContextStore.get).mockReturnValue({ workspaceState: { openFiles: [ { @@ -1105,24 +1029,19 @@ describe('Gemini Client (client.ts)', () => { }, }); - vi.spyOn(client['config'], 'getIdeMode').mockReturnValue(true); + vi.mocked(mockConfig.getIdeMode).mockReturnValue(true); - const mockStream = (async function* () { - yield { type: 'content', value: 'Hello' }; - })(); - mockTurnRunFn.mockReturnValue(mockStream); + mockTurnRunFn.mockReturnValue( + (async function* () { + yield { type: 'content', value: 'Hello' }; + })(), + ); - const mockChat: Partial = { + const mockChat = { addHistory: vi.fn(), getHistory: vi.fn().mockReturnValue([]), - }; - client['chat'] = mockChat as GeminiChat; - - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; + } as unknown as GeminiChat; + client['chat'] = mockChat; const initialRequest: Part[] = [{ text: 'Hi' }]; @@ -1137,7 +1056,7 @@ describe('Gemini Client (client.ts)', () => { } // Assert - expect(ideContext.getIdeContext).toHaveBeenCalled(); + expect(ideContextStore.get).toHaveBeenCalled(); const expectedContext = ` Here is the user's editor context as a JSON object. This is for your information only. 
\`\`\`json @@ -1167,7 +1086,7 @@ ${JSON.stringify( it('should not add context if ideMode is enabled but no open files', async () => { // Arrange - vi.mocked(ideContext.getIdeContext).mockReturnValue({ + vi.mocked(ideContextStore.get).mockReturnValue({ workspaceState: { openFiles: [], }, @@ -1186,12 +1105,6 @@ ${JSON.stringify( }; client['chat'] = mockChat as GeminiChat; - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; - const initialRequest = [{ text: 'Hi' }]; // Act @@ -1205,8 +1118,13 @@ ${JSON.stringify( } // Assert - expect(ideContext.getIdeContext).toHaveBeenCalled(); + expect(ideContextStore.get).toHaveBeenCalled(); + // The `turn.run` method is now called with the model name as the first + // argument. We use `expect.any(String)` because this test is + // concerned with the IDE context logic, not the model routing, + // which is tested in its own dedicated suite. 
expect(mockTurnRunFn).toHaveBeenCalledWith( + expect.any(String), initialRequest, expect.any(Object), ); @@ -1214,7 +1132,7 @@ ${JSON.stringify( it('should add context if ideMode is enabled and there is one active file', async () => { // Arrange - vi.mocked(ideContext.getIdeContext).mockReturnValue({ + vi.mocked(ideContextStore.get).mockReturnValue({ workspaceState: { openFiles: [ { @@ -1241,12 +1159,6 @@ ${JSON.stringify( }; client['chat'] = mockChat as GeminiChat; - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; - const initialRequest = [{ text: 'Hi' }]; // Act @@ -1260,7 +1172,7 @@ ${JSON.stringify( } // Assert - expect(ideContext.getIdeContext).toHaveBeenCalled(); + expect(ideContextStore.get).toHaveBeenCalled(); const expectedContext = ` Here is the user's editor context as a JSON object. This is for your information only. \`\`\`json @@ -1289,7 +1201,7 @@ ${JSON.stringify( it('should add context if ideMode is enabled and there are open files but no active file', async () => { // Arrange - vi.mocked(ideContext.getIdeContext).mockReturnValue({ + vi.mocked(ideContextStore.get).mockReturnValue({ workspaceState: { openFiles: [ { @@ -1317,12 +1229,6 @@ ${JSON.stringify( }; client['chat'] = mockChat as GeminiChat; - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; - const initialRequest = [{ text: 'Hi' }]; // Act @@ -1336,7 +1242,7 @@ ${JSON.stringify( } // Assert - expect(ideContext.getIdeContext).toHaveBeenCalled(); + expect(ideContextStore.get).toHaveBeenCalled(); const expectedContext = ` Here is the user's editor context as a JSON object. This is for your information only. 
\`\`\`json @@ -1369,12 +1275,6 @@ ${JSON.stringify( }; client['chat'] = mockChat as GeminiChat; - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; - // Act const stream = client.sendMessageStream( [{ text: 'Hi' }], @@ -1419,12 +1319,6 @@ ${JSON.stringify( }; client['chat'] = mockChat as GeminiChat; - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; - // Use a signal that never gets aborted const abortController = new AbortController(); const signal = abortController.signal; @@ -1512,12 +1406,6 @@ ${JSON.stringify( }; client['chat'] = mockChat as GeminiChat; - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; - // Act & Assert // Run up to the limit for (let i = 0; i < MAX_SESSION_TURNS; i++) { @@ -1574,12 +1462,6 @@ ${JSON.stringify( }; client['chat'] = mockChat as GeminiChat; - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; - // Use a signal that never gets aborted const abortController = new AbortController(); const signal = abortController.signal; @@ -1632,6 +1514,178 @@ ${JSON.stringify( ); }); + describe('Model Routing', () => { + let mockRouterService: { route: Mock }; + + beforeEach(() => { + mockRouterService = { + route: vi + .fn() + .mockResolvedValue({ model: 'routed-model', reason: 'test' }), + }; + vi.mocked(mockConfig.getModelRouterService).mockReturnValue( + mockRouterService as unknown as ModelRouterService, + ); + + 
mockTurnRunFn.mockReturnValue( + (async function* () { + yield { type: 'content', value: 'Hello' }; + })(), + ); + }); + + it('should use the model router service to select a model on the first turn', async () => { + const stream = client.sendMessageStream( + [{ text: 'Hi' }], + new AbortController().signal, + 'prompt-1', + ); + await fromAsync(stream); // consume stream + + expect(mockConfig.getModelRouterService).toHaveBeenCalled(); + expect(mockRouterService.route).toHaveBeenCalled(); + expect(mockTurnRunFn).toHaveBeenCalledWith( + 'routed-model', // The model from the router + [{ text: 'Hi' }], + expect.any(Object), + ); + }); + + it('should use the same model for subsequent turns in the same prompt (stickiness)', async () => { + // First turn + let stream = client.sendMessageStream( + [{ text: 'Hi' }], + new AbortController().signal, + 'prompt-1', + ); + await fromAsync(stream); + + expect(mockRouterService.route).toHaveBeenCalledTimes(1); + expect(mockTurnRunFn).toHaveBeenCalledWith( + 'routed-model', + [{ text: 'Hi' }], + expect.any(Object), + ); + + // Second turn + stream = client.sendMessageStream( + [{ text: 'Continue' }], + new AbortController().signal, + 'prompt-1', + ); + await fromAsync(stream); + + // Router should not be called again + expect(mockRouterService.route).toHaveBeenCalledTimes(1); + // Should stick to the first model + expect(mockTurnRunFn).toHaveBeenCalledWith( + 'routed-model', + [{ text: 'Continue' }], + expect.any(Object), + ); + }); + + it('should reset the sticky model and re-route when the prompt_id changes', async () => { + // First prompt + let stream = client.sendMessageStream( + [{ text: 'Hi' }], + new AbortController().signal, + 'prompt-1', + ); + await fromAsync(stream); + + expect(mockRouterService.route).toHaveBeenCalledTimes(1); + expect(mockTurnRunFn).toHaveBeenCalledWith( + 'routed-model', + [{ text: 'Hi' }], + expect.any(Object), + ); + + // New prompt + mockRouterService.route.mockResolvedValue({ + model: 
'new-routed-model', + reason: 'test', + }); + stream = client.sendMessageStream( + [{ text: 'A new topic' }], + new AbortController().signal, + 'prompt-2', + ); + await fromAsync(stream); + + // Router should be called again for the new prompt + expect(mockRouterService.route).toHaveBeenCalledTimes(2); + // Should use the newly routed model + expect(mockTurnRunFn).toHaveBeenCalledWith( + 'new-routed-model', + [{ text: 'A new topic' }], + expect.any(Object), + ); + }); + + it('should use the fallback model and bypass routing when in fallback mode', async () => { + vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(true); + mockRouterService.route.mockResolvedValue({ + model: DEFAULT_GEMINI_FLASH_MODEL, + reason: 'fallback', + }); + + const stream = client.sendMessageStream( + [{ text: 'Hi' }], + new AbortController().signal, + 'prompt-1', + ); + await fromAsync(stream); + + expect(mockTurnRunFn).toHaveBeenCalledWith( + DEFAULT_GEMINI_FLASH_MODEL, + [{ text: 'Hi' }], + expect.any(Object), + ); + }); + + it('should stick to the fallback model for the entire sequence even if fallback mode ends', async () => { + // Start the sequence in fallback mode + vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(true); + mockRouterService.route.mockResolvedValue({ + model: DEFAULT_GEMINI_FLASH_MODEL, + reason: 'fallback', + }); + let stream = client.sendMessageStream( + [{ text: 'Hi' }], + new AbortController().signal, + 'prompt-fallback-stickiness', + ); + await fromAsync(stream); + + // First call should use fallback model + expect(mockTurnRunFn).toHaveBeenCalledWith( + DEFAULT_GEMINI_FLASH_MODEL, + [{ text: 'Hi' }], + expect.any(Object), + ); + + // End fallback mode + vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(false); + + // Second call in the same sequence + stream = client.sendMessageStream( + [{ text: 'Continue' }], + new AbortController().signal, + 'prompt-fallback-stickiness', + ); + await fromAsync(stream); + + // Router should still not be called, 
and it should stick to the fallback model + expect(mockTurnRunFn).toHaveBeenCalledTimes(2); // Ensure it was called again + expect(mockTurnRunFn).toHaveBeenLastCalledWith( + DEFAULT_GEMINI_FLASH_MODEL, // Still the fallback model + [{ text: 'Continue' }], + expect.any(Object), + ); + }); + }); + describe('Editor context delta', () => { const mockStream = (async function* () { yield { type: 'content', value: 'Hello' }; @@ -1650,7 +1704,6 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), setHistory: vi.fn(), - sendMessage: vi.fn().mockResolvedValue({ text: 'summary' }), // Assume history is not empty for delta checks getHistory: vi .fn() @@ -1659,12 +1712,6 @@ ${JSON.stringify( ]), }; client['chat'] = mockChat as GeminiChat; - - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; }); const testCases = [ @@ -1789,7 +1836,7 @@ ${JSON.stringify( }; // Setup current context - vi.mocked(ideContext.getIdeContext).mockReturnValue({ + vi.mocked(ideContextStore.get).mockReturnValue({ workspaceState: { openFiles: [ { ...currentActiveFile, isActive: true, timestamp: Date.now() }, @@ -1851,7 +1898,7 @@ ${JSON.stringify( }; // Setup current context (same as previous) - vi.mocked(ideContext.getIdeContext).mockReturnValue({ + vi.mocked(ideContextStore.get).mockReturnValue({ workspaceState: { openFiles: [ { ...activeFile, isActive: true, timestamp: Date.now() }, @@ -1917,17 +1964,11 @@ ${JSON.stringify( addHistory: vi.fn(), getHistory: vi.fn().mockReturnValue([]), // Default empty history setHistory: vi.fn(), - sendMessage: vi.fn().mockResolvedValue({ text: 'summary' }), }; client['chat'] = mockChat as GeminiChat; - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; - vi.spyOn(client['config'], 
'getIdeMode').mockReturnValue(true); - vi.mocked(ideContext.getIdeContext).mockReturnValue({ + vi.mocked(ideContextStore.get).mockReturnValue({ workspaceState: { openFiles: [{ path: '/path/to/file.ts', timestamp: Date.now() }], }, @@ -2024,7 +2065,7 @@ ${JSON.stringify( openFiles: [{ path: '/path/to/fileA.ts', timestamp: Date.now() }], }, }; - vi.mocked(ideContext.getIdeContext).mockReturnValue(initialIdeContext); + vi.mocked(ideContextStore.get).mockReturnValue(initialIdeContext); // Act: Send the tool response let stream = client.sendMessageStream( @@ -2083,7 +2124,7 @@ ${JSON.stringify( openFiles: [{ path: '/path/to/fileB.ts', timestamp: Date.now() }], }, }; - vi.mocked(ideContext.getIdeContext).mockReturnValue(newIdeContext); + vi.mocked(ideContextStore.get).mockReturnValue(newIdeContext); // Act: Send a new, regular user message stream = client.sendMessageStream( @@ -2124,7 +2165,7 @@ ${JSON.stringify( ], }, }; - vi.mocked(ideContext.getIdeContext).mockReturnValue(contextA); + vi.mocked(ideContextStore.get).mockReturnValue(contextA); // Act: Send a regular message to establish the initial context let stream = client.sendMessageStream( @@ -2167,7 +2208,7 @@ ${JSON.stringify( ], }, }; - vi.mocked(ideContext.getIdeContext).mockReturnValue(contextB); + vi.mocked(ideContextStore.get).mockReturnValue(contextB); // Act: Send the tool response stream = client.sendMessageStream( @@ -2222,7 +2263,7 @@ ${JSON.stringify( ], }, }; - vi.mocked(ideContext.getIdeContext).mockReturnValue(contextC); + vi.mocked(ideContextStore.get).mockReturnValue(contextC); // Act: Send a new, regular user message stream = client.sendMessageStream( @@ -2266,12 +2307,6 @@ ${JSON.stringify( }; client['chat'] = mockChat as GeminiChat; - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; - // Act const stream = client.sendMessageStream( [{ 
text: 'Hi' }], @@ -2308,12 +2343,6 @@ ${JSON.stringify( }; client['chat'] = mockChat as GeminiChat; - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; - // Act const stream = client.sendMessageStream( [{ text: 'Hi' }], @@ -2335,13 +2364,6 @@ ${JSON.stringify( const generationConfig = { temperature: 0.5 }; const abortSignal = new AbortController().signal; - // Mock countTokens - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; - await client.generateContent( contents, generationConfig, @@ -2349,7 +2371,7 @@ ${JSON.stringify( DEFAULT_GEMINI_FLASH_MODEL, ); - expect(mockGenerateContentFn).toHaveBeenCalledWith( + expect(mockContentGenerator.generateContent).toHaveBeenCalledWith( { model: DEFAULT_GEMINI_FLASH_MODEL, config: { @@ -2371,12 +2393,6 @@ ${JSON.stringify( vi.spyOn(client['config'], 'getModel').mockReturnValueOnce(currentModel); - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; - await client.generateContent( contents, {}, @@ -2384,12 +2400,12 @@ ${JSON.stringify( DEFAULT_GEMINI_FLASH_MODEL, ); - expect(mockGenerateContentFn).not.toHaveBeenCalledWith({ + expect(mockContentGenerator.generateContent).not.toHaveBeenCalledWith({ model: initialModel, config: expect.any(Object), contents, }); - expect(mockGenerateContentFn).toHaveBeenCalledWith( + expect(mockContentGenerator.generateContent).toHaveBeenCalledWith( { model: DEFAULT_GEMINI_FLASH_MODEL, config: expect.any(Object), @@ -2398,102 +2414,29 @@ ${JSON.stringify( 'test-session-id', ); }); - }); - describe('handleFlashFallback', () => { - 
it('should use current model from config when checking for fallback', async () => { - const initialModel = client['config'].getModel(); - const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL; - - // mock config been changed - const currentModel = initialModel + '-changed'; - const getModelSpy = vi.spyOn(client['config'], 'getModel'); - getModelSpy.mockReturnValue(currentModel); + it('should use the Flash model when fallback mode is active', async () => { + const contents = [{ role: 'user', parts: [{ text: 'hello' }] }]; + const generationConfig = { temperature: 0.5 }; + const abortSignal = new AbortController().signal; + const requestedModel = 'gemini-2.5-pro'; // A non-flash model - const mockFallbackHandler = vi.fn().mockResolvedValue(true); - client['config'].flashFallbackHandler = mockFallbackHandler; - client['config'].setModel = vi.fn(); + // Mock config to be in fallback mode + vi.spyOn(client['config'], 'isInFallbackMode').mockReturnValue(true); - const result = await client['handleFlashFallback']( - AuthType.LOGIN_WITH_GOOGLE, + await client.generateContent( + contents, + generationConfig, + abortSignal, + requestedModel, ); - expect(result).toBe(fallbackModel); - - expect(mockFallbackHandler).toHaveBeenCalledWith( - currentModel, - fallbackModel, - undefined, + expect(mockGenerateContentFn).toHaveBeenCalledWith( + expect.objectContaining({ + model: DEFAULT_GEMINI_FLASH_MODEL, + }), + 'test-session-id', ); }); }); - - describe('setHistory', () => { - it('should strip thought signatures when stripThoughts is true', () => { - const mockChat = { - setHistory: vi.fn(), - }; - client['chat'] = mockChat as unknown as GeminiChat; - - const historyWithThoughts: Content[] = [ - { - role: 'user', - parts: [{ text: 'hello' }], - }, - { - role: 'model', - parts: [ - { text: 'thinking...', thoughtSignature: 'thought-123' }, - { - functionCall: { name: 'test', args: {} }, - thoughtSignature: 'thought-456', - }, - ], - }, - ]; - - client.setHistory(historyWithThoughts, { 
stripThoughts: true }); - - const expectedHistory: Content[] = [ - { - role: 'user', - parts: [{ text: 'hello' }], - }, - { - role: 'model', - parts: [ - { text: 'thinking...' }, - { functionCall: { name: 'test', args: {} } }, - ], - }, - ]; - - expect(mockChat.setHistory).toHaveBeenCalledWith(expectedHistory); - }); - - it('should not strip thought signatures when stripThoughts is false', () => { - const mockChat = { - setHistory: vi.fn(), - }; - client['chat'] = mockChat as unknown as GeminiChat; - - const historyWithThoughts: Content[] = [ - { - role: 'user', - parts: [{ text: 'hello' }], - }, - { - role: 'model', - parts: [ - { text: 'thinking...', thoughtSignature: 'thought-123' }, - { text: 'ok', thoughtSignature: 'thought-456' }, - ], - }, - ]; - - client.setHistory(historyWithThoughts, { stripThoughts: false }); - - expect(mockChat.setHistory).toHaveBeenCalledWith(historyWithThoughts); - }); - }); }); diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index d00504dc5b0..21392ed12be 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -20,7 +20,6 @@ import type { ServerGeminiStreamEvent, ChatCompressionInfo } from './turn.js'; import { CompressionStatus } from './turn.js'; import { Turn, GeminiEventType } from './turn.js'; import type { Config } from '../config/config.js'; -import type { UserTierId } from '../code_assist/types.js'; import { getCoreSystemPrompt, getCompressionPrompt } from './prompts.js'; import { getResponseText } from '../utils/partUtils.js'; import { checkNextSpeaker } from '../utils/nextSpeakerChecker.js'; @@ -31,18 +30,13 @@ import { getErrorMessage } from '../utils/errors.js'; import { isFunctionResponse } from '../utils/messageInspectors.js'; import { tokenLimit } from './tokenLimits.js'; import type { ChatRecordingService } from '../services/chatRecordingService.js'; -import type { - ContentGenerator, - ContentGeneratorConfig, -} from './contentGenerator.js'; -import { 
AuthType, createContentGenerator } from './contentGenerator.js'; -import { ProxyAgent, setGlobalDispatcher } from 'undici'; +import type { ContentGenerator } from './contentGenerator.js'; import { DEFAULT_GEMINI_FLASH_MODEL, DEFAULT_THINKING_MODE, } from '../config/models.js'; import { LoopDetectionService } from '../services/loopDetectionService.js'; -import { ideContext } from '../ide/ideContext.js'; +import { ideContextStore } from '../ide/ideContext.js'; import { logChatCompression, logNextSpeakerCheck, @@ -53,7 +47,9 @@ import { MalformedJsonResponseEvent, NextSpeakerCheckEvent, } from '../telemetry/types.js'; -import type { IdeContext, File } from '../ide/ideContext.js'; +import type { IdeContext, File } from '../ide/types.js'; +import { handleFallback } from '../fallback/handler.js'; +import type { RoutingContext } from '../routing/routingStrategy.js'; export function isThinkingSupported(model: string) { if (model.startsWith('gemini-2.5')) return true; @@ -91,10 +87,10 @@ export function findIndexAfterFraction( let charactersSoFar = 0; for (let i = 0; i < contentLengths.length; i++) { - charactersSoFar += contentLengths[i]; if (charactersSoFar >= targetCharacters) { return i; } + charactersSoFar += contentLengths[i]; } return contentLengths.length; } @@ -115,8 +111,6 @@ const COMPRESSION_PRESERVE_THRESHOLD = 0.3; export class GeminiClient { private chat?: GeminiChat; - private contentGenerator?: ContentGenerator; - private readonly embeddingModel: string; private readonly generateContentConfig: GenerateContentConfig = { temperature: 0, topP: 1, @@ -125,6 +119,7 @@ export class GeminiClient { private readonly loopDetector: LoopDetectionService; private lastPromptId: string; + private currentSequenceModel: string | null = null; private lastSentIdeContext: IdeContext | undefined; private forceFullIdeContext = true; @@ -135,33 +130,19 @@ export class GeminiClient { private hasFailedCompressionAttempt = false; constructor(private readonly config: Config) { - if 
(config.getProxy()) { - setGlobalDispatcher(new ProxyAgent(config.getProxy() as string)); - } - - this.embeddingModel = config.getEmbeddingModel(); this.loopDetector = new LoopDetectionService(config); this.lastPromptId = this.config.getSessionId(); } - async initialize(contentGeneratorConfig: ContentGeneratorConfig) { - this.contentGenerator = await createContentGenerator( - contentGeneratorConfig, - this.config, - this.config.getSessionId(), - ); + async initialize() { this.chat = await this.startChat(); } - getContentGenerator(): ContentGenerator { - if (!this.contentGenerator) { + private getContentGeneratorOrFail(): ContentGenerator { + if (!this.config.getContentGenerator()) { throw new Error('Content generator not initialized'); } - return this.contentGenerator; - } - - getUserTier(): UserTierId | undefined { - return this.contentGenerator?.userTier; + return this.config.getContentGenerator(); } async addHistory(content: Content) { @@ -176,39 +157,19 @@ export class GeminiClient { } isInitialized(): boolean { - return this.chat !== undefined && this.contentGenerator !== undefined; + return this.chat !== undefined; } getHistory(): Content[] { return this.getChat().getHistory(); } - setHistory( - history: Content[], - { stripThoughts = false }: { stripThoughts?: boolean } = {}, - ) { - const historyToSet = stripThoughts - ? 
history.map((content) => { - const newContent = { ...content }; - if (newContent.parts) { - newContent.parts = newContent.parts.map((part) => { - if ( - part && - typeof part === 'object' && - 'thoughtSignature' in part - ) { - const newPart = { ...part }; - delete (newPart as { thoughtSignature?: string }) - .thoughtSignature; - return newPart; - } - return part; - }); - } - return newContent; - }) - : history; - this.getChat().setHistory(historyToSet); + stripThoughtsFromHistory() { + this.getChat().stripThoughtsFromHistory(); + } + + setHistory(history: Content[]) { + this.getChat().setHistory(history); this.forceFullIdeContext = true; } @@ -227,6 +188,10 @@ export class GeminiClient { return this.chat?.getChatRecordingService(); } + getLoopDetectionService(): LoopDetectionService { + return this.loopDetector; + } + async addDirectoryContext(): Promise { if (!this.chat) { return; @@ -242,9 +207,11 @@ export class GeminiClient { this.forceFullIdeContext = true; this.hasFailedCompressionAttempt = false; const envParts = await getEnvironmentContext(this.config); + const toolRegistry = this.config.getToolRegistry(); const toolDeclarations = toolRegistry.getFunctionDeclarations(); const tools: Tool[] = [{ functionDeclarations: toolDeclarations }]; + const history: Content[] = [ { role: 'user', @@ -274,7 +241,6 @@ export class GeminiClient { : this.generateContentConfig; return new GeminiChat( this.config, - this.getContentGenerator(), { systemInstruction, ...generateContentConfigWithThinking, @@ -297,7 +263,7 @@ export class GeminiClient { contextParts: string[]; newIdeContext: IdeContext | undefined; } { - const currentIdeContext = ideContext.getIdeContext(); + const currentIdeContext = ideContextStore.get(); if (!currentIdeContext) { return { contextParts: [], newIdeContext: undefined }; } @@ -466,11 +432,11 @@ export class GeminiClient { signal: AbortSignal, prompt_id: string, turns: number = MAX_TURNS, - originalModel?: string, ): AsyncGenerator { if 
(this.lastPromptId !== prompt_id) { this.loopDetector.reset(prompt_id); this.lastPromptId = prompt_id; + this.currentSequenceModel = null; } this.sessionTurnCount++; if ( @@ -486,9 +452,6 @@ export class GeminiClient { return new Turn(this.getChat(), prompt_id); } - // Track the original model from the first call to detect model switching - const initialModel = originalModel || this.config.getModel(); - const compressed = await this.tryCompressChat(prompt_id); if (compressed.compressionStatus === CompressionStatus.COMPRESSED) { @@ -530,7 +493,26 @@ export class GeminiClient { return turn; } - const resultStream = turn.run(request, signal); + const routingContext: RoutingContext = { + history: this.getChat().getHistory(/*curated=*/ true), + request, + signal, + }; + + let modelToUse: string; + + // Determine Model (Stickiness vs. Routing) + if (this.currentSequenceModel) { + modelToUse = this.currentSequenceModel; + } else { + const router = await this.config.getModelRouterService(); + const decision = await router.route(routingContext); + modelToUse = decision.model; + // Lock the model for the rest of the sequence + this.currentSequenceModel = modelToUse; + } + + const resultStream = turn.run(modelToUse, request, signal); for await (const event of resultStream) { if (this.loopDetector.addAndCheck(event)) { yield { type: GeminiEventType.LoopDetected }; @@ -542,11 +524,8 @@ export class GeminiClient { } } if (!turn.pendingToolCalls.length && signal && !signal.aborted) { - // Check if model was switched during the call (likely due to quota error) - const currentModel = this.config.getModel(); - if (currentModel !== initialModel) { - // Model was switched (likely due to quota error fallback) - // Don't continue with recursive call to prevent unwanted Flash execution + // Check if next speaker check is needed + if (this.config.getQuotaErrorOccurred()) { return turn; } @@ -576,7 +555,6 @@ export class GeminiClient { signal, prompt_id, boundedTurns - 1, - initialModel, 
); } } @@ -590,6 +568,8 @@ export class GeminiClient { model: string, config: GenerateContentConfig = {}, ): Promise> { + let currentAttemptModel: string = model; + try { const userMemory = this.config.getUserMemory(); const systemInstruction = getCoreSystemPrompt(userMemory); @@ -599,10 +579,15 @@ export class GeminiClient { ...config, }; - const apiCall = () => - this.getContentGenerator().generateContent( + const apiCall = () => { + const modelToUse = this.config.isInFallbackMode() + ? DEFAULT_GEMINI_FLASH_MODEL + : model; + currentAttemptModel = modelToUse; + + return this.getContentGeneratorOrFail().generateContent( { - model, + model: modelToUse, config: { ...requestConfig, systemInstruction, @@ -613,10 +598,17 @@ export class GeminiClient { }, this.lastPromptId, ); + }; + + const onPersistent429Callback = async ( + authType?: string, + error?: unknown, + ) => + // Pass the captured model to the centralized handler. + await handleFallback(this.config, currentAttemptModel, authType, error); const result = await retryWithBackoff(apiCall, { - onPersistent429: async (authType?: string, error?: unknown) => - await this.handleFlashFallback(authType, error), + onPersistent429: onPersistent429Callback, authType: this.config.getContentGeneratorConfig()?.authType, }); @@ -639,7 +631,7 @@ export class GeminiClient { if (text.startsWith(prefix) && text.endsWith(suffix)) { logMalformedJsonResponse( this.config, - new MalformedJsonResponseEvent(model), + new MalformedJsonResponseEvent(currentAttemptModel), ); text = text .substring(prefix.length, text.length - suffix.length) @@ -695,6 +687,8 @@ export class GeminiClient { abortSignal: AbortSignal, model: string, ): Promise { + let currentAttemptModel: string = model; + const configToUse: GenerateContentConfig = { ...this.generateContentConfig, ...generationConfig, @@ -710,19 +704,30 @@ export class GeminiClient { systemInstruction, }; - const apiCall = () => - this.getContentGenerator().generateContent( + const apiCall = 
() => { + const modelToUse = this.config.isInFallbackMode() + ? DEFAULT_GEMINI_FLASH_MODEL + : model; + currentAttemptModel = modelToUse; + + return this.getContentGeneratorOrFail().generateContent( { - model, + model: modelToUse, config: requestConfig, contents, }, this.lastPromptId, ); + }; + const onPersistent429Callback = async ( + authType?: string, + error?: unknown, + ) => + // Pass the captured model to the centralized handler. + await handleFallback(this.config, currentAttemptModel, authType, error); const result = await retryWithBackoff(apiCall, { - onPersistent429: async (authType?: string, error?: unknown) => - await this.handleFlashFallback(authType, error), + onPersistent429: onPersistent429Callback, authType: this.config.getContentGeneratorConfig()?.authType, }); return result; @@ -733,7 +738,7 @@ export class GeminiClient { await reportError( error, - `Error generating content via API with model ${model}.`, + `Error generating content via API with model ${currentAttemptModel}.`, { requestContents: contents, requestConfig: configToUse, @@ -741,7 +746,7 @@ export class GeminiClient { 'generateContent-api', ); throw new Error( - `Failed to generate content with model ${model}: ${getErrorMessage(error)}`, + `Failed to generate content with model ${currentAttemptModel}: ${getErrorMessage(error)}`, ); } } @@ -751,12 +756,12 @@ export class GeminiClient { return []; } const embedModelParams: EmbedContentParameters = { - model: this.embeddingModel, + model: this.config.getEmbeddingModel(), contents: texts, }; const embedContentResponse = - await this.getContentGenerator().embedContent(embedModelParams); + await this.getContentGeneratorOrFail().embedContent(embedModelParams); if ( !embedContentResponse.embeddings || embedContentResponse.embeddings.length === 0 @@ -802,7 +807,7 @@ export class GeminiClient { const model = this.config.getModel(); const { totalTokens: originalTokenCount } = - await this.getContentGenerator().countTokens({ + await 
this.getContentGeneratorOrFail().countTokens({ model, contents: curatedHistory, }); @@ -849,20 +854,30 @@ export class GeminiClient { const historyToCompress = curatedHistory.slice(0, compressBeforeIndex); const historyToKeep = curatedHistory.slice(compressBeforeIndex); - this.getChat().setHistory(historyToCompress); - - const { text: summary } = await this.getChat().sendMessage( - { - message: { - text: 'First, reason in your scratchpad. Then, generate the .', - }, - config: { - systemInstruction: { text: getCompressionPrompt() }, - maxOutputTokens: originalTokenCount, + const summaryResponse = await this.config + .getContentGenerator() + .generateContent( + { + model, + contents: [ + ...historyToCompress, + { + role: 'user', + parts: [ + { + text: 'First, reason in your scratchpad. Then, generate the .', + }, + ], + }, + ], + config: { + systemInstruction: { text: getCompressionPrompt() }, + }, }, - }, - prompt_id, - ); + prompt_id, + ); + const summary = getResponseText(summaryResponse) ?? ''; + const chat = await this.startChat([ { role: 'user', @@ -877,7 +892,7 @@ export class GeminiClient { this.forceFullIdeContext = true; const { totalTokens: newTokenCount } = - await this.getContentGenerator().countTokens({ + await this.getContentGeneratorOrFail().countTokens({ // model might change after calling `sendMessage`, so we get the newest value from config model: this.config.getModel(), contents: chat.getHistory(), @@ -920,53 +935,6 @@ export class GeminiClient { compressionStatus: CompressionStatus.COMPRESSED, }; } - - /** - * Handles falling back to Flash model when persistent 429 errors occur for OAuth users. - * Uses a fallback handler if provided by the config; otherwise, returns null. 
- */ - private async handleFlashFallback( - authType?: string, - error?: unknown, - ): Promise { - // Only handle fallback for OAuth users - if (authType !== AuthType.LOGIN_WITH_GOOGLE) { - return null; - } - - const currentModel = this.config.getModel(); - const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL; - - // Don't fallback if already using Flash model - if (currentModel === fallbackModel) { - return null; - } - - // Check if config has a fallback handler (set by CLI package) - const fallbackHandler = this.config.flashFallbackHandler; - if (typeof fallbackHandler === 'function') { - try { - const accepted = await fallbackHandler( - currentModel, - fallbackModel, - error, - ); - if (accepted !== false && accepted !== null) { - this.config.setModel(fallbackModel); - this.config.setFallbackMode(true); - return fallbackModel; - } - // Check if the model was switched manually in the handler - if (this.config.getModel() === fallbackModel) { - return null; // Model was switched but don't continue with current prompt - } - } catch (error) { - console.warn('Flash fallback handler failed:', error); - } - } - - return null; - } } export const TEST_ONLY = { diff --git a/packages/core/src/core/contentGenerator.test.ts b/packages/core/src/core/contentGenerator.test.ts index eba9d353ec2..3084c84bd46 100644 --- a/packages/core/src/core/contentGenerator.test.ts +++ b/packages/core/src/core/contentGenerator.test.ts @@ -29,7 +29,6 @@ describe('createContentGenerator', () => { ); const generator = await createContentGenerator( { - model: 'test-model', authType: AuthType.LOGIN_WITH_GOOGLE, }, mockConfig, @@ -51,7 +50,6 @@ describe('createContentGenerator', () => { vi.mocked(GoogleGenAI).mockImplementation(() => mockGenerator as never); const generator = await createContentGenerator( { - model: 'test-model', apiKey: 'test-api-key', authType: AuthType.USE_GEMINI, }, @@ -85,7 +83,6 @@ describe('createContentGenerator', () => { vi.mocked(GoogleGenAI).mockImplementation(() => 
mockGenerator as never); const generator = await createContentGenerator( { - model: 'test-model', apiKey: 'test-api-key', authType: AuthType.USE_GEMINI, }, diff --git a/packages/core/src/core/contentGenerator.ts b/packages/core/src/core/contentGenerator.ts index 4a794fd1f44..12f8ac7ae86 100644 --- a/packages/core/src/core/contentGenerator.ts +++ b/packages/core/src/core/contentGenerator.ts @@ -14,7 +14,6 @@ import type { } from '@google/genai'; import { GoogleGenAI } from '@google/genai'; import { createCodeAssistContentGenerator } from '../code_assist/codeAssist.js'; -import { DEFAULT_GEMINI_MODEL } from '../config/models.js'; import type { Config } from '../config/config.js'; import type { UserTierId } from '../code_assist/types.js'; @@ -50,7 +49,6 @@ export enum AuthType { } export type ContentGeneratorConfig = { - model: string; apiKey?: string; vertexai?: boolean; authType?: AuthType; @@ -66,11 +64,7 @@ export function createContentGeneratorConfig( const googleCloudProject = process.env['GOOGLE_CLOUD_PROJECT'] || undefined; const googleCloudLocation = process.env['GOOGLE_CLOUD_LOCATION'] || undefined; - // Use runtime model from config if available; otherwise, fall back to parameter or default - const effectiveModel = config.getModel() || DEFAULT_GEMINI_MODEL; - const contentGeneratorConfig: ContentGeneratorConfig = { - model: effectiveModel, authType, proxy: config?.getProxy(), }; diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts index 0f85cc5bd26..2c028240fb4 100644 --- a/packages/core/src/core/coreToolScheduler.test.ts +++ b/packages/core/src/core/coreToolScheduler.test.ts @@ -176,6 +176,10 @@ describe('CoreToolScheduler', () => { model: 'test-model', authType: 'oauth-personal', }), + getShellExecutionConfig: () => ({ + terminalWidth: 90, + terminalHeight: 30, + }), storage: { getProjectTempDir: () => '/tmp', }, @@ -283,6 +287,10 @@ describe('CoreToolScheduler with payload', () => { model: 
'test-model', authType: 'oauth-personal', }), + getShellExecutionConfig: () => ({ + terminalWidth: 90, + terminalHeight: 30, + }), storage: { getProjectTempDir: () => '/tmp', }, @@ -601,6 +609,10 @@ describe('CoreToolScheduler edit cancellation', () => { model: 'test-model', authType: 'oauth-personal', }), + getShellExecutionConfig: () => ({ + terminalWidth: 90, + terminalHeight: 30, + }), storage: { getProjectTempDir: () => '/tmp', }, @@ -697,6 +709,10 @@ describe('CoreToolScheduler YOLO mode', () => { model: 'test-model', authType: 'oauth-personal', }), + getShellExecutionConfig: () => ({ + terminalWidth: 90, + terminalHeight: 30, + }), storage: { getProjectTempDir: () => '/tmp', }, @@ -799,6 +815,10 @@ describe('CoreToolScheduler request queueing', () => { model: 'test-model', authType: 'oauth-personal', }), + getShellExecutionConfig: () => ({ + terminalWidth: 90, + terminalHeight: 30, + }), storage: { getProjectTempDir: () => '/tmp', }, @@ -924,6 +944,12 @@ describe('CoreToolScheduler request queueing', () => { model: 'test-model', authType: 'oauth-personal', }), + getShellExecutionConfig: () => ({ + terminalWidth: 80, + terminalHeight: 24, + }), + getTerminalWidth: vi.fn(() => 80), + getTerminalHeight: vi.fn(() => 24), storage: { getProjectTempDir: () => '/tmp', }, @@ -1016,6 +1042,10 @@ describe('CoreToolScheduler request queueing', () => { model: 'test-model', authType: 'oauth-personal', }), + getShellExecutionConfig: () => ({ + terminalWidth: 90, + terminalHeight: 30, + }), storage: { getProjectTempDir: () => '/tmp', }, @@ -1084,6 +1114,10 @@ describe('CoreToolScheduler request queueing', () => { setApprovalMode: (mode: ApprovalMode) => { approvalMode = mode; }, + getShellExecutionConfig: () => ({ + terminalWidth: 90, + terminalHeight: 30, + }), storage: { getProjectTempDir: () => '/tmp', }, diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts index 5cdb3d5e9cc..689e5abc157 100644 --- 
a/packages/core/src/core/coreToolScheduler.ts +++ b/packages/core/src/core/coreToolScheduler.ts @@ -16,6 +16,7 @@ import type { ToolConfirmationPayload, AnyDeclarativeTool, AnyToolInvocation, + AnsiOutput, } from '../index.js'; import { ToolConfirmationOutcome, @@ -24,6 +25,9 @@ import { ReadFileTool, ToolErrorType, ToolCallEvent, + ShellTool, + logToolOutputTruncated, + ToolOutputTruncatedEvent, } from '../index.js'; import type { Part, PartListUnion } from '@google/genai'; import { getResponseTextFromParts } from '../utils/generateContentResponseUtilities.js'; @@ -37,6 +41,7 @@ import * as fs from 'node:fs/promises'; import * as path from 'node:path'; import { doesToolInvocationMatch } from '../utils/tool-utils.js'; import levenshtein from 'fast-levenshtein'; +import { ShellToolInvocation } from '../tools/shell.js'; export type ValidatingToolCall = { status: 'validating'; @@ -80,9 +85,10 @@ export type ExecutingToolCall = { request: ToolCallRequestInfo; tool: AnyDeclarativeTool; invocation: AnyToolInvocation; - liveOutput?: string; + liveOutput?: string | AnsiOutput; startTime?: number; outcome?: ToolConfirmationOutcome; + pid?: number; }; export type CancelledToolCall = { @@ -127,7 +133,7 @@ export type ConfirmHandler = ( export type OutputUpdateHandler = ( toolCallId: string, - outputChunk: string, + outputChunk: string | AnsiOutput, ) => void; export type AllToolCallsCompleteHandler = ( @@ -244,6 +250,7 @@ const createErrorResponse = ( ], resultDisplay: error.message, errorType, + contentLength: error.message.length, }); export async function truncateAndSaveToFile( @@ -460,6 +467,7 @@ export class CoreToolScheduler { } } + const errorMessage = `[Operation Cancelled] Reason: ${auxiliaryData}`; return { request: currentCall.request, tool: toolInstance, @@ -473,7 +481,7 @@ export class CoreToolScheduler { id: currentCall.request.callId, name: currentCall.request.name, response: { - error: `[Operation Cancelled] Reason: ${auxiliaryData}`, + error: errorMessage, }, 
}, }, @@ -481,6 +489,7 @@ export class CoreToolScheduler { resultDisplay, error: undefined, errorType: undefined, + contentLength: errorMessage.length, }, durationMs, outcome, @@ -946,7 +955,7 @@ export class CoreToolScheduler { const liveOutputCallback = scheduledCall.tool.canUpdateOutput && this.outputUpdateHandler - ? (outputChunk: string) => { + ? (outputChunk: string | AnsiOutput) => { if (this.outputUpdateHandler) { this.outputUpdateHandler(callId, outputChunk); } @@ -959,8 +968,37 @@ export class CoreToolScheduler { } : undefined; - invocation - .execute(signal, liveOutputCallback) + const shellExecutionConfig = this.config.getShellExecutionConfig(); + + // TODO: Refactor to remove special casing for ShellToolInvocation. + // Introduce a generic callbacks object for the execute method to handle + // things like `onPid` and `onLiveOutput`. This will make the scheduler + // agnostic to the invocation type. + let promise: Promise; + if (invocation instanceof ShellToolInvocation) { + const setPidCallback = (pid: number) => { + this.toolCalls = this.toolCalls.map((tc) => + tc.request.callId === callId && tc.status === 'executing' + ? { ...tc, pid } + : tc, + ); + this.notifyToolCallsUpdate(); + }; + promise = invocation.execute( + signal, + liveOutputCallback, + shellExecutionConfig, + setPidCallback, + ); + } else { + promise = invocation.execute( + signal, + liveOutputCallback, + shellExecutionConfig, + ); + } + + promise .then(async (toolResult: ToolResult) => { if (signal.aborted) { this.setStatusInternal( @@ -974,17 +1012,43 @@ export class CoreToolScheduler { if (toolResult.error === undefined) { let content = toolResult.llmContent; let outputFile: string | undefined = undefined; + const contentLength = + typeof content === 'string' ? 
content.length : undefined; if ( typeof content === 'string' && - this.config.getTruncateToolOutputThreshold() > 0 + toolName === ShellTool.Name && + this.config.getEnableToolOutputTruncation() && + this.config.getTruncateToolOutputThreshold() > 0 && + this.config.getTruncateToolOutputLines() > 0 ) { - ({ content, outputFile } = await truncateAndSaveToFile( + const originalContentLength = content.length; + const threshold = this.config.getTruncateToolOutputThreshold(); + const lines = this.config.getTruncateToolOutputLines(); + const truncatedResult = await truncateAndSaveToFile( content, callId, this.config.storage.getProjectTempDir(), - this.config.getTruncateToolOutputThreshold(), - this.config.getTruncateToolOutputLines(), - )); + threshold, + lines, + ); + content = truncatedResult.content; + outputFile = truncatedResult.outputFile; + + if (outputFile) { + logToolOutputTruncated( + this.config, + new ToolOutputTruncatedEvent( + scheduledCall.request.prompt_id, + { + toolName, + originalContentLength, + truncatedContentLength: content.length, + threshold, + lines, + }, + ), + ); + } } const response = convertToFunctionResponse( @@ -999,6 +1063,7 @@ export class CoreToolScheduler { error: undefined, errorType: undefined, outputFile, + contentLength, }; this.setStatusInternal(callId, 'success', successResponse); } else { diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts index c7660441fd8..4d5b6f4ab1b 100644 --- a/packages/core/src/core/geminiChat.test.ts +++ b/packages/core/src/core/geminiChat.test.ts @@ -7,11 +7,11 @@ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; import type { Content, - Models, GenerateContentConfig, Part, GenerateContentResponse, } from '@google/genai'; +import type { ContentGenerator } from '../core/contentGenerator.js'; import { GeminiChat, EmptyStreamError, @@ -20,6 +20,9 @@ import { } from './geminiChat.js'; import type { Config } from '../config/config.js'; 
import { setSimulate429 } from '../utils/testUtils.js'; +import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; +import { AuthType } from './contentGenerator.js'; +import { type RetryOptions } from '../utils/retry.js'; // Mock fs module to prevent actual file system operations during tests const mockFileSystem = new Map(); @@ -47,14 +50,22 @@ vi.mock('node:fs', () => { }; }); -// Mocks -const mockModelsModule = { - generateContent: vi.fn(), - generateContentStream: vi.fn(), - countTokens: vi.fn(), - embedContent: vi.fn(), - batchEmbedContents: vi.fn(), -} as unknown as Models; +const { mockHandleFallback } = vi.hoisted(() => ({ + mockHandleFallback: vi.fn(), +})); + +// Add mock for the retry utility +const { mockRetryWithBackoff } = vi.hoisted(() => ({ + mockRetryWithBackoff: vi.fn(), +})); + +vi.mock('../utils/retry.js', () => ({ + retryWithBackoff: mockRetryWithBackoff, +})); + +vi.mock('../fallback/handler.js', () => ({ + handleFallback: mockHandleFallback, +})); const { mockLogInvalidChunk, mockLogContentRetry, mockLogContentRetryFailure } = vi.hoisted(() => ({ @@ -70,23 +81,36 @@ vi.mock('../telemetry/loggers.js', () => ({ })); describe('GeminiChat', () => { + let mockContentGenerator: ContentGenerator; let chat: GeminiChat; let mockConfig: Config; const config: GenerateContentConfig = {}; beforeEach(() => { vi.clearAllMocks(); + mockContentGenerator = { + generateContent: vi.fn(), + generateContentStream: vi.fn(), + countTokens: vi.fn(), + embedContent: vi.fn(), + batchEmbedContents: vi.fn(), + } as unknown as ContentGenerator; + + mockHandleFallback.mockClear(); + // Default mock implementation for tests that don't care about retry logic + mockRetryWithBackoff.mockImplementation(async (apiCall) => apiCall()); mockConfig = { getSessionId: () => 'test-session-id', getTelemetryLogPromptsEnabled: () => true, getUsageStatisticsEnabled: () => true, getDebugMode: () => false, - getContentGeneratorConfig: () => ({ - authType: 'oauth-personal', + 
getContentGeneratorConfig: vi.fn().mockReturnValue({ + authType: 'oauth-personal', // Ensure this is set for fallback tests model: 'test-model', }), getModel: vi.fn().mockReturnValue('gemini-pro'), setModel: vi.fn(), + isInFallbackMode: vi.fn().mockReturnValue(false), getQuotaErrorOccurred: vi.fn().mockReturnValue(false), setQuotaErrorOccurred: vi.fn(), flashFallbackHandler: undefined, @@ -97,12 +121,13 @@ describe('GeminiChat', () => { getToolRegistry: vi.fn().mockReturnValue({ getTool: vi.fn(), }), + getContentGenerator: vi.fn().mockReturnValue(mockContentGenerator), } as unknown as Config; // Disable 429 simulation for tests setSimulate429(false); // Reset history for each test by creating a new instance - chat = new GeminiChat(mockConfig, mockModelsModule, config, []); + chat = new GeminiChat(mockConfig, config, []); }); afterEach(() => { @@ -110,259 +135,6 @@ describe('GeminiChat', () => { vi.resetAllMocks(); }); - describe('sendMessage', () => { - it('should retain the initial user message when an automatic function call occurs', async () => { - // 1. Define the user's initial text message. This is the turn that gets dropped by the buggy logic. - const userInitialMessage: Content = { - role: 'user', - parts: [{ text: 'How is the weather in Boston?' }], - }; - - // 2. Mock the full API response, including the automaticFunctionCallingHistory. - // This history represents the full turn: user asks, model calls tool, tool responds, model answers. - const mockAfcResponse = { - candidates: [ - { - content: { - role: 'model', - parts: [ - { text: 'The weather in Boston is 72 degrees and sunny.' 
}, - ], - }, - }, - ], - automaticFunctionCallingHistory: [ - userInitialMessage, // The user's turn - { - // The model's first response: a tool call - role: 'model', - parts: [ - { - functionCall: { - name: 'get_weather', - args: { location: 'Boston' }, - }, - }, - ], - }, - { - // The tool's response, which has a 'user' role - role: 'user', - parts: [ - { - functionResponse: { - name: 'get_weather', - response: { temperature: 72, condition: 'sunny' }, - }, - }, - ], - }, - ], - } as unknown as GenerateContentResponse; - - vi.mocked(mockModelsModule.generateContent).mockResolvedValue( - mockAfcResponse, - ); - - // 3. Action: Send the initial message. - await chat.sendMessage( - { message: 'How is the weather in Boston?' }, - 'prompt-id-afc-bug', - ); - - // 4. Assert: Check the final state of the history. - const history = chat.getHistory(); - - // With the bug, history.length will be 3, because the first user message is dropped. - // The correct behavior is for the history to contain all 4 turns. - expect(history.length).toBe(4); - - // Crucially, assert that the very first turn in the history matches the user's initial message. - // This is the assertion that will fail. - const firstTurn = history[0]!; - expect(firstTurn.role).toBe('user'); - expect(firstTurn?.parts![0]!.text).toBe('How is the weather in Boston?'); - - // Verify the rest of the history is also correct. - const secondTurn = history[1]!; - expect(secondTurn.role).toBe('model'); - expect(secondTurn?.parts![0]!.functionCall).toBeDefined(); - - const thirdTurn = history[2]!; - expect(thirdTurn.role).toBe('user'); - expect(thirdTurn?.parts![0]!.functionResponse).toBeDefined(); - - const fourthTurn = history[3]!; - expect(fourthTurn.role).toBe('model'); - expect(fourthTurn?.parts![0]!.text).toContain('72 degrees and sunny'); - }); - - it('should throw an error when attempting to add a user turn after another user turn', async () => { - // 1. 
Setup: Create a history that already ends with a user turn (a functionResponse). - const initialHistory: Content[] = [ - { role: 'user', parts: [{ text: 'Initial prompt' }] }, - { - role: 'model', - parts: [{ functionCall: { name: 'test_tool', args: {} } }], - }, - { - role: 'user', - parts: [{ functionResponse: { name: 'test_tool', response: {} } }], - }, - ]; - chat.setHistory(initialHistory); - - // 2. Mock a valid model response so the call doesn't fail for other reasons. - const mockResponse = { - candidates: [ - { content: { role: 'model', parts: [{ text: 'some response' }] } }, - ], - } as unknown as GenerateContentResponse; - vi.mocked(mockModelsModule.generateContent).mockResolvedValue( - mockResponse, - ); - - // 3. Action & Assert: Expect that sending another user message immediately - // after a user-role turn throws the specific error. - await expect( - chat.sendMessage( - { message: 'This is an invalid consecutive user message' }, - 'prompt-id-1', - ), - ).rejects.toThrow('Cannot add a user turn after another user turn.'); - }); - it('should preserve text parts that are in the same response as a thought', async () => { - // 1. Mock the API to return a single response containing both a thought and visible text. - const mixedContentResponse = { - candidates: [ - { - content: { - role: 'model', - parts: [ - { thought: 'This is a thought.' }, - { text: 'This is the visible text that should not be lost.' }, - ], - }, - }, - ], - } as unknown as GenerateContentResponse; - - vi.mocked(mockModelsModule.generateContent).mockResolvedValue( - mixedContentResponse, - ); - - // 2. Action: Send a standard, non-streaming message. - await chat.sendMessage( - { message: 'test message' }, - 'prompt-id-mixed-response', - ); - - // 3. Assert: Check the final state of the history. - const history = chat.getHistory(); - - // The history should contain two turns: the user's message and the model's response. 
- expect(history.length).toBe(2); - - const modelTurn = history[1]!; - expect(modelTurn.role).toBe('model'); - - // CRUCIAL ASSERTION: - // Buggy code would discard the entire response because a "thought" was present, - // resulting in an empty placeholder turn with 0 parts. - // The corrected code will pass, preserving the single visible text part. - expect(modelTurn?.parts?.length).toBe(1); - expect(modelTurn?.parts![0]!.text).toBe( - 'This is the visible text that should not be lost.', - ); - }); - it('should add a placeholder model turn when a tool call is followed by an empty model response', async () => { - // 1. Setup: A history where the model has just made a function call. - const initialHistory: Content[] = [ - { - role: 'user', - parts: [{ text: 'Find a good Italian restaurant for me.' }], - }, - { - role: 'model', - parts: [ - { - functionCall: { - name: 'find_restaurant', - args: { cuisine: 'Italian' }, - }, - }, - ], - }, - ]; - chat.setHistory(initialHistory); - - // 2. Mock the API to return an empty/thought-only response. - const emptyModelResponse = { - candidates: [ - { content: { role: 'model', parts: [{ thought: true }] } }, - ], - } as unknown as GenerateContentResponse; - vi.mocked(mockModelsModule.generateContent).mockResolvedValue( - emptyModelResponse, - ); - - // 3. Action: Send the function response back to the model. - await chat.sendMessage( - { - message: { - functionResponse: { - name: 'find_restaurant', - response: { name: 'Vesuvio' }, - }, - }, - }, - 'prompt-id-1', - ); - - // 4. Assert: The history should now have four valid, alternating turns. - const history = chat.getHistory(); - expect(history.length).toBe(4); - - // The final turn must be the empty model placeholder. - const lastTurn = history[3]!; - expect(lastTurn.role).toBe('model'); - expect(lastTurn?.parts?.length).toBe(0); - - // The second-to-last turn must be the function response we sent. 
- const secondToLastTurn = history[2]!; - expect(secondToLastTurn.role).toBe('user'); - expect(secondToLastTurn?.parts![0]!.functionResponse).toBeDefined(); - }); - it('should call generateContent with the correct parameters', async () => { - const response = { - candidates: [ - { - content: { - parts: [{ text: 'response' }], - role: 'model', - }, - finishReason: 'STOP', - index: 0, - safetyRatings: [], - }, - ], - text: () => 'response', - } as unknown as GenerateContentResponse; - vi.mocked(mockModelsModule.generateContent).mockResolvedValue(response); - - await chat.sendMessage({ message: 'hello' }, 'prompt-id-1'); - - expect(mockModelsModule.generateContent).toHaveBeenCalledWith( - { - model: 'gemini-pro', - contents: [{ role: 'user', parts: [{ text: 'hello' }] }], - config: {}, - }, - 'prompt-id-1', - ); - }); - }); - describe('sendMessageStream', () => { it('should succeed if a tool call is followed by an empty part', async () => { // 1. Mock a stream that contains a tool call, then an invalid (empty) part. @@ -390,13 +162,14 @@ describe('GeminiChat', () => { } as unknown as GenerateContentResponse; })(); - vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue( + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( streamWithToolCall, ); // 2. Action & Assert: The stream processing should complete without throwing an error // because the presence of a tool call makes the empty final chunk acceptable. const stream = await chat.sendMessageStream( + 'test-model', { message: 'test message' }, 'prompt-id-tool-call-empty-end', ); @@ -442,12 +215,13 @@ describe('GeminiChat', () => { } as unknown as GenerateContentResponse; })(); - vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue( + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( streamWithNoFinish, ); // 2. Action & Assert: The stream should fail because there's no finish reason. 
const stream = await chat.sendMessageStream( + 'test-model', { message: 'test message' }, 'prompt-id-no-finish-empty-end', ); @@ -487,12 +261,13 @@ describe('GeminiChat', () => { } as unknown as GenerateContentResponse; })(); - vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue( + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( streamWithInvalidEnd, ); // 2. Action & Assert: The stream should complete without throwing an error. const stream = await chat.sendMessageStream( + 'test-model', { message: 'test message' }, 'prompt-id-valid-then-invalid-end', ); @@ -543,12 +318,13 @@ describe('GeminiChat', () => { } as unknown as GenerateContentResponse; })(); - vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue( + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( multiChunkStream, ); // 2. Action: Send a message and consume the stream. const stream = await chat.sendMessageStream( + 'test-model', { message: 'test message' }, 'prompt-id-malformed-chunk', ); @@ -593,12 +369,13 @@ describe('GeminiChat', () => { } as unknown as GenerateContentResponse; })(); - vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue( + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( multiChunkStream, ); // 2. Action: Send a message and consume the stream. const stream = await chat.sendMessageStream( + 'test-model', { message: 'test message' }, 'prompt-id-empty-chunk-consolidation', ); @@ -650,12 +427,13 @@ describe('GeminiChat', () => { } as unknown as GenerateContentResponse; })(); - vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue( + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( multiChunkStream, ); // 2. Action: Send a message and consume the stream. 
const stream = await chat.sendMessageStream( + 'test-model', { message: 'test message' }, 'prompt-id-multi-chunk', ); @@ -697,12 +475,13 @@ describe('GeminiChat', () => { } as unknown as GenerateContentResponse; })(); - vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue( + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( mixedContentStream, ); // 2. Action: Send a message and fully consume the stream to trigger history recording. const stream = await chat.sendMessageStream( + 'test-model', { message: 'test message' }, 'prompt-id-mixed-chunk', ); @@ -759,12 +538,13 @@ describe('GeminiChat', () => { ], } as unknown as GenerateContentResponse; })(); - vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue( + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( emptyStreamResponse, ); // 3. Action: Send the function response back to the model and consume the stream. const stream = await chat.sendMessageStream( + 'test-model', { message: { functionResponse: { @@ -811,22 +591,28 @@ describe('GeminiChat', () => { text: () => 'response', } as unknown as GenerateContentResponse; })(); - vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue( + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( response, ); const stream = await chat.sendMessageStream( + 'test-model', { message: 'hello' }, 'prompt-id-1', ); for await (const _ of stream) { - // consume stream to trigger internal logic + // consume stream } - expect(mockModelsModule.generateContentStream).toHaveBeenCalledWith( + expect(mockContentGenerator.generateContentStream).toHaveBeenCalledWith( { - model: 'gemini-pro', - contents: [{ role: 'user', parts: [{ text: 'hello' }] }], + model: 'test-model', + contents: [ + { + role: 'user', + parts: [{ text: 'hello' }], + }, + ], config: {}, }, 'prompt-id-1', @@ -1012,7 +798,7 @@ describe('GeminiChat', () => { describe('sendMessageStream with retries', () => { it('should 
yield a RETRY event when an invalid stream is encountered', async () => { // ARRANGE: Mock the stream to fail once, then succeed. - vi.mocked(mockModelsModule.generateContentStream) + vi.mocked(mockContentGenerator.generateContentStream) .mockImplementationOnce(async () => // First attempt: An invalid stream with an empty text part. (async function* () { @@ -1037,6 +823,7 @@ describe('GeminiChat', () => { // ACT: Send a message and collect all events from the stream. const stream = await chat.sendMessageStream( + 'test-model', { message: 'test' }, 'prompt-id-yield-retry', ); @@ -1053,7 +840,7 @@ describe('GeminiChat', () => { }); it('should retry on invalid content, succeed, and report metrics', async () => { // Use mockImplementationOnce to provide a fresh, promise-wrapped generator for each attempt. - vi.mocked(mockModelsModule.generateContentStream) + vi.mocked(mockContentGenerator.generateContentStream) .mockImplementationOnce(async () => // First call returns an invalid stream (async function* () { @@ -1077,6 +864,7 @@ describe('GeminiChat', () => { ); const stream = await chat.sendMessageStream( + 'test-model', { message: 'test' }, 'prompt-id-retry-success', ); @@ -1089,7 +877,9 @@ describe('GeminiChat', () => { expect(mockLogInvalidChunk).toHaveBeenCalledTimes(1); expect(mockLogContentRetry).toHaveBeenCalledTimes(1); expect(mockLogContentRetryFailure).not.toHaveBeenCalled(); - expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2); + expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes( + 2, + ); // Check for a retry event expect(chunks.some((c) => c.type === StreamEventType.RETRY)).toBe(true); @@ -1118,7 +908,7 @@ describe('GeminiChat', () => { }); it('should fail after all retries on persistent invalid content and report metrics', async () => { - vi.mocked(mockModelsModule.generateContentStream).mockImplementation( + vi.mocked(mockContentGenerator.generateContentStream).mockImplementation( async () => (async function* 
() { yield { @@ -1134,23 +924,21 @@ describe('GeminiChat', () => { })(), ); - // This helper function consumes the stream and allows us to test for rejection. - async function consumeStreamAndExpectError() { - const stream = await chat.sendMessageStream( - { message: 'test' }, - 'prompt-id-retry-fail', - ); + const stream = await chat.sendMessageStream( + 'test-model', + { message: 'test' }, + 'prompt-id-retry-fail', + ); + await expect(async () => { for await (const _ of stream) { // Must loop to trigger the internal logic that throws. } - } - - await expect(consumeStreamAndExpectError()).rejects.toThrow( - EmptyStreamError, - ); + }).rejects.toThrow(EmptyStreamError); // Should be called 3 times (initial + 2 retries) - expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(3); + expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes( + 3, + ); expect(mockLogInvalidChunk).toHaveBeenCalledTimes(3); expect(mockLogContentRetry).toHaveBeenCalledTimes(2); expect(mockLogContentRetryFailure).toHaveBeenCalledTimes(1); @@ -1169,7 +957,7 @@ describe('GeminiChat', () => { chat.setHistory(initialHistory); // 2. Mock the API to fail once with an empty stream, then succeed. - vi.mocked(mockModelsModule.generateContentStream) + vi.mocked(mockContentGenerator.generateContentStream) .mockImplementationOnce(async () => (async function* () { yield { @@ -1193,6 +981,7 @@ describe('GeminiChat', () => { // 3. Send a new message const stream = await chat.sendMessageStream( + 'test-model', { message: 'Second question' }, 'prompt-id-retry-existing', ); @@ -1237,90 +1026,9 @@ describe('GeminiChat', () => { expect(turn4.parts[0].text).toBe('Second answer'); }); - describe('concurrency control', () => { - it('should queue a subsequent sendMessage call until the first one completes', async () => { - // 1. 
Create promises to manually control when the API calls resolve - let firstCallResolver: (value: GenerateContentResponse) => void; - const firstCallPromise = new Promise( - (resolve) => { - firstCallResolver = resolve; - }, - ); - - let secondCallResolver: (value: GenerateContentResponse) => void; - const secondCallPromise = new Promise( - (resolve) => { - secondCallResolver = resolve; - }, - ); - - // A standard response body for the mock - const mockResponse = { - candidates: [ - { - content: { parts: [{ text: 'response' }], role: 'model' }, - }, - ], - } as unknown as GenerateContentResponse; - - // 2. Mock the API to return our controllable promises in order - vi.mocked(mockModelsModule.generateContent) - .mockReturnValueOnce(firstCallPromise) - .mockReturnValueOnce(secondCallPromise); - - // 3. Start the first message call. Do not await it yet. - const firstMessagePromise = chat.sendMessage( - { message: 'first' }, - 'prompt-1', - ); - - // Give the event loop a chance to run the async call up to the `await` - await new Promise(process.nextTick); - - // 4. While the first call is "in-flight", start the second message call. - const secondMessagePromise = chat.sendMessage( - { message: 'second' }, - 'prompt-2', - ); - - // 5. CRUCIAL CHECK: At this point, only the first API call should have been made. - // The second call should be waiting on `sendPromise`. - expect(mockModelsModule.generateContent).toHaveBeenCalledTimes(1); - expect(mockModelsModule.generateContent).toHaveBeenCalledWith( - expect.objectContaining({ - contents: expect.arrayContaining([ - expect.objectContaining({ parts: [{ text: 'first' }] }), - ]), - }), - 'prompt-1', - ); - - // 6. Unblock the first API call and wait for the first message to fully complete. - firstCallResolver!(mockResponse); - await firstMessagePromise; - - // Give the event loop a chance to unblock and run the second call. - await new Promise(process.nextTick); - - // 7. 
CRUCIAL CHECK: Now, the second API call should have been made. - expect(mockModelsModule.generateContent).toHaveBeenCalledTimes(2); - expect(mockModelsModule.generateContent).toHaveBeenCalledWith( - expect.objectContaining({ - contents: expect.arrayContaining([ - expect.objectContaining({ parts: [{ text: 'second' }] }), - ]), - }), - 'prompt-2', - ); - - // 8. Clean up by resolving the second call. - secondCallResolver!(mockResponse); - await secondMessagePromise; - }); - }); it('should retry if the model returns a completely empty stream (no chunks)', async () => { // 1. Mock the API to return an empty stream first, then a valid one. - vi.mocked(mockModelsModule.generateContentStream) + vi.mocked(mockContentGenerator.generateContentStream) .mockImplementationOnce( // First call resolves to an async generator that yields nothing. async () => (async function* () {})(), @@ -1344,6 +1052,7 @@ describe('GeminiChat', () => { // 2. Call the method and consume the stream. const stream = await chat.sendMessageStream( + 'test-model', { message: 'test empty stream' }, 'prompt-id-empty-stream', ); @@ -1353,7 +1062,7 @@ describe('GeminiChat', () => { } // 3. Assert the results. - expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2); + expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(2); expect( chunks.some( (c) => @@ -1417,12 +1126,13 @@ describe('GeminiChat', () => { } as unknown as GenerateContentResponse; })(); - vi.mocked(mockModelsModule.generateContentStream) + vi.mocked(mockContentGenerator.generateContentStream) .mockResolvedValueOnce(firstStreamGenerator) .mockResolvedValueOnce(secondStreamGenerator); // 3. Start the first stream and consume only the first chunk to pause it const firstStream = await chat.sendMessageStream( + 'test-model', { message: 'first' }, 'prompt-1', ); @@ -1431,12 +1141,13 @@ describe('GeminiChat', () => { // 4. While the first stream is paused, start the second call. It will block. 
const secondStreamPromise = chat.sendMessageStream( + 'test-model', { message: 'second' }, 'prompt-2', ); // 5. Assert that only one API call has been made so far. - expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(1); + expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(1); // 6. Unblock and fully consume the first stream to completion. continueFirstStream!(); @@ -1451,7 +1162,7 @@ describe('GeminiChat', () => { await secondStreamIterator.next(); // 9. The second API call should now have been made. - expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2); + expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(2); // 10. FIX: Fully consume the second stream to ensure recordHistory is called. await secondStreamIterator.next(); // This finishes the iterator. @@ -1469,9 +1180,169 @@ describe('GeminiChat', () => { expect(turn4.parts[0].text).toBe('second response'); }); + describe('Model Resolution', () => { + const mockResponse = { + candidates: [ + { + content: { parts: [{ text: 'response' }], role: 'model' }, + finishReason: 'STOP', + }, + ], + } as unknown as GenerateContentResponse; + + it('should use the FLASH model when in fallback mode (sendMessageStream)', async () => { + vi.mocked(mockConfig.getModel).mockReturnValue('gemini-pro'); + vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(true); + vi.mocked(mockContentGenerator.generateContentStream).mockImplementation( + async () => + (async function* () { + yield mockResponse; + })(), + ); + + const stream = await chat.sendMessageStream( + 'test-model', + { message: 'test' }, + 'prompt-id-res3', + ); + for await (const _ of stream) { + // consume stream + } + + expect(mockContentGenerator.generateContentStream).toHaveBeenCalledWith( + expect.objectContaining({ + model: DEFAULT_GEMINI_FLASH_MODEL, + }), + 'prompt-id-res3', + ); + }); + }); + + describe('Fallback Integration (Retries)', () => { + const error429 = Object.assign(new 
Error('API Error 429: Quota exceeded'), { + status: 429, + }); + + // Define the simulated behavior for retryWithBackoff for these tests. + // This simulation tries the apiCall, if it fails, it calls the callback, + // and then tries the apiCall again if the callback returns true. + const simulateRetryBehavior = async ( + apiCall: () => Promise, + options: Partial, + ) => { + try { + return await apiCall(); + } catch (error) { + if (options.onPersistent429) { + // We simulate the "persistent" trigger here for simplicity. + const shouldRetry = await options.onPersistent429( + options.authType, + error, + ); + if (shouldRetry) { + return await apiCall(); + } + } + throw error; // Stop if callback returns false/null or doesn't exist + } + }; + + beforeEach(() => { + mockRetryWithBackoff.mockImplementation(simulateRetryBehavior); + }); + + afterEach(() => { + mockRetryWithBackoff.mockImplementation(async (apiCall) => apiCall()); + }); + + it('should call handleFallback with the specific failed model and retry if handler returns true', async () => { + const authType = AuthType.LOGIN_WITH_GOOGLE; + vi.mocked(mockConfig.getContentGeneratorConfig).mockReturnValue({ + authType, + }); + + const isInFallbackModeSpy = vi.spyOn(mockConfig, 'isInFallbackMode'); + isInFallbackModeSpy.mockReturnValue(false); + + vi.mocked(mockContentGenerator.generateContentStream) + .mockRejectedValueOnce(error429) // Attempt 1 fails + .mockResolvedValueOnce( + // Attempt 2 succeeds + (async function* () { + yield { + candidates: [ + { + content: { parts: [{ text: 'Success on retry' }] }, + finishReason: 'STOP', + }, + ], + } as unknown as GenerateContentResponse; + })(), + ); + + mockHandleFallback.mockImplementation(async () => { + isInFallbackModeSpy.mockReturnValue(true); + return true; // Signal retry + }); + + const stream = await chat.sendMessageStream( + 'test-model', + { message: 'trigger 429' }, + 'prompt-id-fb1', + ); + + // Consume stream to trigger logic + for await (const _ of 
stream) { + // no-op + } + + expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes( + 2, + ); + expect(mockHandleFallback).toHaveBeenCalledTimes(1); + expect(mockHandleFallback).toHaveBeenCalledWith( + mockConfig, + 'test-model', + authType, + error429, + ); + + const history = chat.getHistory(); + const modelTurn = history[1]!; + expect(modelTurn.parts![0]!.text).toBe('Success on retry'); + }); + + it('should stop retrying if handleFallback returns false (e.g., auth intent)', async () => { + vi.mocked(mockConfig.getModel).mockReturnValue('gemini-pro'); + vi.mocked(mockContentGenerator.generateContentStream).mockRejectedValue( + error429, + ); + mockHandleFallback.mockResolvedValue(false); + + const stream = await chat.sendMessageStream( + 'test-model', + { message: 'test stop' }, + 'prompt-id-fb2', + ); + + await expect( + (async () => { + for await (const _ of stream) { + /* consume stream */ + } + })(), + ).rejects.toThrow(error429); + + expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes( + 1, + ); + expect(mockHandleFallback).toHaveBeenCalledTimes(1); + }); + }); + it('should discard valid partial content from a failed attempt upon retry', async () => { - // ARRANGE: Mock the stream to fail on the first attempt after yielding some valid content. - vi.mocked(mockModelsModule.generateContentStream) + // Mock the stream to fail on the first attempt after yielding some valid content. 
+ vi.mocked(mockContentGenerator.generateContentStream) .mockImplementationOnce(async () => // First attempt: yields one valid chunk, then one invalid chunk (async function* () { @@ -1505,8 +1376,9 @@ describe('GeminiChat', () => { })(), ); - // ACT: Send a message and consume the stream + // Send a message and consume the stream const stream = await chat.sendMessageStream( + 'test-model', { message: 'test' }, 'prompt-id-discard-test', ); @@ -1515,9 +1387,8 @@ describe('GeminiChat', () => { events.push(event); } - // ASSERT // Check that a retry happened - expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2); + expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(2); expect(events.some((e) => e.type === StreamEventType.RETRY)).toBe(true); // Check the final recorded history @@ -1532,4 +1403,41 @@ describe('GeminiChat', () => { 'This valid part should be discarded', ); }); + + describe('stripThoughtsFromHistory', () => { + it('should strip thought signatures', () => { + chat.setHistory([ + { + role: 'user', + parts: [{ text: 'hello' }], + }, + { + role: 'model', + parts: [ + { text: 'thinking...', thoughtSignature: 'thought-123' }, + { + functionCall: { name: 'test', args: {} }, + thoughtSignature: 'thought-456', + }, + ], + }, + ]); + + chat.stripThoughtsFromHistory(); + + expect(chat.getHistory()).toEqual([ + { + role: 'user', + parts: [{ text: 'hello' }], + }, + { + role: 'model', + parts: [ + { text: 'thinking...' 
}, + { functionCall: { name: 'test', args: {} } }, + ], + }, + ]); + }); + }); }); diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index 159e3560c3f..9f7f19ba9a8 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -18,10 +18,11 @@ import type { import { toParts } from '../code_assist/converter.js'; import { createUserContent } from '@google/genai'; import { retryWithBackoff } from '../utils/retry.js'; -import type { ContentGenerator } from './contentGenerator.js'; -import { AuthType } from './contentGenerator.js'; import type { Config } from '../config/config.js'; -import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; +import { + DEFAULT_GEMINI_FLASH_MODEL, + getEffectiveModel, +} from '../config/models.js'; import { hasCycleInSchema } from '../tools/tools.js'; import type { StructuredError } from './turn.js'; import type { CompletedToolCall } from './coreToolScheduler.js'; @@ -36,6 +37,7 @@ import { ContentRetryFailureEvent, InvalidChunkEvent, } from '../telemetry/types.js'; +import { handleFallback } from '../fallback/handler.js'; import { isFunctionResponse } from '../utils/messageInspectors.js'; import { partListUnionToString } from './geminiRequest.js'; @@ -172,7 +174,6 @@ export class GeminiChat { constructor( private readonly config: Config, - private readonly contentGenerator: ContentGenerator, private readonly generationConfig: GenerateContentConfig = {}, private history: Content[] = [], ) { @@ -181,171 +182,9 @@ export class GeminiChat { this.chatRecordingService.initialize(); } - /** - * Handles falling back to Flash model when persistent 429 errors occur for OAuth users. - * Uses a fallback handler if provided by the config; otherwise, returns null. 
- */ - private async handleFlashFallback( - authType?: string, - error?: unknown, - ): Promise { - // Only handle fallback for OAuth users - if (authType !== AuthType.LOGIN_WITH_GOOGLE) { - return null; - } - - const currentModel = this.config.getModel(); - const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL; - - // Don't fallback if already using Flash model - if (currentModel === fallbackModel) { - return null; - } - - // Check if config has a fallback handler (set by CLI package) - const fallbackHandler = this.config.flashFallbackHandler; - if (typeof fallbackHandler === 'function') { - try { - const accepted = await fallbackHandler( - currentModel, - fallbackModel, - error, - ); - if (accepted !== false && accepted !== null) { - this.config.setModel(fallbackModel); - this.config.setFallbackMode(true); - return fallbackModel; - } - // Check if the model was switched manually in the handler - if (this.config.getModel() === fallbackModel) { - return null; // Model was switched but don't continue with current prompt - } - } catch (error) { - console.warn('Flash fallback handler failed:', error); - } - } - - return null; - } - setSystemInstruction(sysInstr: string) { this.generationConfig.systemInstruction = sysInstr; } - /** - * Sends a message to the model and returns the response. - * - * @remarks - * This method will wait for the previous message to be processed before - * sending the next message. - * - * @see {@link Chat#sendMessageStream} for streaming method. - * @param params - parameters for sending messages within a chat session. - * @returns The model's response. - * - * @example - * ```ts - * const chat = ai.chats.create({model: 'gemini-2.0-flash'}); - * const response = await chat.sendMessage({ - * message: 'Why is the sky blue?' 
- * }); - * console.log(response.text); - * ``` - */ - async sendMessage( - params: SendMessageParameters, - prompt_id: string, - ): Promise { - await this.sendPromise; - const userContent = createUserContent(params.message); - - // Record user input - capture complete message with all parts (text, files, images, etc.) - // but skip recording function responses (tool call results) as they should be stored in tool call records - if (!isFunctionResponse(userContent)) { - const userMessage = Array.isArray(params.message) - ? params.message - : [params.message]; - this.chatRecordingService.recordMessage({ - type: 'user', - content: userMessage, - }); - } - const requestContents = this.getHistory(true).concat(userContent); - - let response: GenerateContentResponse; - - try { - const apiCall = () => { - const modelToUse = this.config.getModel() || DEFAULT_GEMINI_FLASH_MODEL; - - // Prevent Flash model calls immediately after quota error - if ( - this.config.getQuotaErrorOccurred() && - modelToUse === DEFAULT_GEMINI_FLASH_MODEL - ) { - throw new Error( - 'Please submit a new query to continue with the Flash model.', - ); - } - - return this.contentGenerator.generateContent( - { - model: modelToUse, - contents: requestContents, - config: { ...this.generationConfig, ...params.config }, - }, - prompt_id, - ); - }; - - response = await retryWithBackoff(apiCall, { - shouldRetry: (error: unknown) => { - // Check for known error messages and codes. 
- if (error instanceof Error && error.message) { - if (isSchemaDepthError(error.message)) return false; - if (error.message.includes('429')) return true; - if (error.message.match(/5\d{2}/)) return true; - } - return false; // Don't retry other errors by default - }, - onPersistent429: async (authType?: string, error?: unknown) => - await this.handleFlashFallback(authType, error), - authType: this.config.getContentGeneratorConfig()?.authType, - }); - - this.sendPromise = (async () => { - const outputContent = response.candidates?.[0]?.content; - const modelOutput = outputContent ? [outputContent] : []; - - // Because the AFC input contains the entire curated chat history in - // addition to the new user input, we need to truncate the AFC history - // to deduplicate the existing chat history. - const fullAutomaticFunctionCallingHistory = - response.automaticFunctionCallingHistory; - const index = this.getHistory(true).length; - let automaticFunctionCallingHistory: Content[] = []; - if (fullAutomaticFunctionCallingHistory != null) { - automaticFunctionCallingHistory = - fullAutomaticFunctionCallingHistory.slice(index) ?? []; - } - - this.recordHistory( - userContent, - modelOutput, - automaticFunctionCallingHistory, - ); - })(); - await this.sendPromise.catch((error) => { - // Resets sendPromise to avoid subsequent calls failing - this.sendPromise = Promise.resolve(); - // Re-throw the error so the caller knows something went wrong. - throw error; - }); - return response; - } catch (error) { - this.sendPromise = Promise.resolve(); - throw error; - } - } /** * Sends a message to the model and returns the response in chunks. 
@@ -370,6 +209,7 @@ export class GeminiChat { * ``` */ async sendMessageStream( + model: string, params: SendMessageParameters, prompt_id: string, ): Promise> { @@ -417,6 +257,7 @@ export class GeminiChat { } const stream = await self.makeApiCallAndProcessStream( + model, requestContents, params, prompt_id, @@ -481,13 +322,17 @@ export class GeminiChat { } private async makeApiCallAndProcessStream( + model: string, requestContents: Content[], params: SendMessageParameters, prompt_id: string, userContent: Content, ): Promise> { const apiCall = () => { - const modelToUse = this.config.getModel(); + const modelToUse = getEffectiveModel( + this.config.isInFallbackMode(), + model, + ); if ( this.config.getQuotaErrorOccurred() && @@ -498,7 +343,7 @@ export class GeminiChat { ); } - return this.contentGenerator.generateContentStream( + return this.config.getContentGenerator().generateContentStream( { model: modelToUse, contents: requestContents, @@ -508,6 +353,11 @@ export class GeminiChat { ); }; + const onPersistent429Callback = async ( + authType?: string, + error?: unknown, + ) => await handleFallback(this.config, model, authType, error); + const streamResponse = await retryWithBackoff(apiCall, { shouldRetry: (error: unknown) => { if (error instanceof Error && error.message) { @@ -517,8 +367,7 @@ export class GeminiChat { } return false; }, - onPersistent429: async (authType?: string, error?: unknown) => - await this.handleFlashFallback(authType, error), + onPersistent429: onPersistent429Callback, authType: this.config.getContentGeneratorConfig()?.authType, }); @@ -570,10 +419,28 @@ export class GeminiChat { addHistory(content: Content): void { this.history.push(content); } + setHistory(history: Content[]): void { this.history = history; } + stripThoughtsFromHistory(): void { + this.history = this.history.map((content) => { + const newContent = { ...content }; + if (newContent.parts) { + newContent.parts = newContent.parts.map((part) => { + if (part && typeof part === 
'object' && 'thoughtSignature' in part) { + const newPart = { ...part }; + delete (newPart as { thoughtSignature?: string }).thoughtSignature; + return newPart; + } + return part; + }); + } + return newContent; + }); + } + setTools(tools: Tool[]): void { this.generationConfig.tools = tools; } @@ -701,35 +568,18 @@ export class GeminiChat { this.recordHistory(userInput, modelOutput); } - private recordHistory( - userInput: Content, - modelOutput: Content[], - automaticFunctionCallingHistory?: Content[], - ) { + private recordHistory(userInput: Content, modelOutput: Content[]) { // Part 1: Handle the user's turn. - if ( - automaticFunctionCallingHistory && - automaticFunctionCallingHistory.length > 0 - ) { - this.history.push( - ...extractCuratedHistory(automaticFunctionCallingHistory), - ); - } else { - if ( - this.history.length === 0 || - this.history[this.history.length - 1] !== userInput - ) { - const lastTurn = this.history[this.history.length - 1]; - // The only time we don't push is if it's the *exact same* object, - // which happens in streaming where we add it preemptively. - if (lastTurn !== userInput) { - if (lastTurn?.role === 'user') { - // This is an invalid sequence. - throw new Error('Cannot add a user turn after another user turn.'); - } - this.history.push(userInput); - } + + const lastTurn = this.history[this.history.length - 1]; + // The only time we don't push is if it's the *exact same* object, + // which happens in streaming where we add it preemptively. + if (lastTurn !== userInput) { + if (lastTurn?.role === 'user') { + // This is an invalid sequence. + throw new Error('Cannot add a user turn after another user turn.'); } + this.history.push(userInput); } // Part 2: Process the model output into a final, consolidated list of turns. 
diff --git a/packages/core/src/core/nonInteractiveToolExecutor.test.ts b/packages/core/src/core/nonInteractiveToolExecutor.test.ts index 864b78a9cf3..4678c61eba6 100644 --- a/packages/core/src/core/nonInteractiveToolExecutor.test.ts +++ b/packages/core/src/core/nonInteractiveToolExecutor.test.ts @@ -46,6 +46,10 @@ describe('executeToolCall', () => { model: 'test-model', authType: 'oauth-personal', }), + getShellExecutionConfig: () => ({ + terminalWidth: 90, + terminalHeight: 30, + }), storage: { getProjectTempDir: () => '/tmp', }, @@ -88,6 +92,10 @@ describe('executeToolCall', () => { errorType: undefined, outputFile: undefined, resultDisplay: 'Success!', + contentLength: + typeof toolResult.llmContent === 'string' + ? toolResult.llmContent.length + : undefined, responseParts: [ { functionResponse: { @@ -127,6 +135,7 @@ describe('executeToolCall', () => { error: new Error(expectedErrorMessage), errorType: ToolErrorType.TOOL_NOT_REGISTERED, resultDisplay: expectedErrorMessage, + contentLength: expectedErrorMessage.length, responseParts: [ { functionResponse: { @@ -176,6 +185,7 @@ describe('executeToolCall', () => { }, ], resultDisplay: 'Invalid parameters', + contentLength: 'Invalid parameters'.length, }); }); @@ -219,6 +229,7 @@ describe('executeToolCall', () => { }, ], resultDisplay: 'Execution failed', + contentLength: 'Execution failed'.length, }); }); @@ -246,6 +257,7 @@ describe('executeToolCall', () => { error: new Error('Something went very wrong'), errorType: ToolErrorType.UNHANDLED_EXCEPTION, resultDisplay: 'Something went very wrong', + contentLength: 'Something went very wrong'.length, responseParts: [ { functionResponse: { @@ -288,6 +300,7 @@ describe('executeToolCall', () => { errorType: undefined, outputFile: undefined, resultDisplay: 'Image processed', + contentLength: undefined, responseParts: [ { functionResponse: { @@ -302,4 +315,56 @@ describe('executeToolCall', () => { ], }); }); + + it('should calculate contentLength for a string llmContent', 
async () => { + const request: ToolCallRequestInfo = { + callId: 'call7', + name: 'testTool', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-id-7', + }; + const toolResult: ToolResult = { + llmContent: 'This is a test string.', + returnDisplay: 'String returned', + }; + vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); + mockTool.executeFn.mockReturnValue(toolResult); + + const response = await executeToolCall( + mockConfig, + request, + abortController.signal, + ); + + expect(response.contentLength).toBe( + typeof toolResult.llmContent === 'string' + ? toolResult.llmContent.length + : undefined, + ); + }); + + it('should have undefined contentLength for array llmContent with no string parts', async () => { + const request: ToolCallRequestInfo = { + callId: 'call8', + name: 'testTool', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-id-8', + }; + const toolResult: ToolResult = { + llmContent: [{ inlineData: { mimeType: 'image/png', data: 'fakedata' } }], + returnDisplay: 'Image data returned', + }; + vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); + mockTool.executeFn.mockReturnValue(toolResult); + + const response = await executeToolCall( + mockConfig, + request, + abortController.signal, + ); + + expect(response.contentLength).toBeUndefined(); + }); }); diff --git a/packages/core/src/core/subagent.test.ts b/packages/core/src/core/subagent.test.ts index cc54037badb..065aeb27918 100644 --- a/packages/core/src/core/subagent.test.ts +++ b/packages/core/src/core/subagent.test.ts @@ -45,6 +45,7 @@ vi.mock('../ide/ide-client.js'); async function createMockConfig( toolRegistryMocks = {}, + configParameters: Partial = {}, ): Promise<{ config: Config; toolRegistry: ToolRegistry }> { const configParams: ConfigParameters = { sessionId: 'test-session', @@ -52,11 +53,10 @@ async function createMockConfig( targetDir: '.', debugMode: false, cwd: process.cwd(), + ...configParameters, }; const config = new 
Config(configParams); await config.initialize(); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - await config.refreshAuth('test-auth' as any); // Mock ToolRegistry const mockToolRegistry = { @@ -164,15 +164,13 @@ describe('subagent.ts', () => { // Helper to safely access generationConfig from mock calls const getGenerationConfigFromMock = ( callIndex = 0, - ): GenerateContentConfig & { systemInstruction?: string | Content } => { + ): GenerateContentConfig => { const callArgs = vi.mocked(GeminiChat).mock.calls[callIndex]; - const generationConfig = callArgs?.[2]; + const generationConfig = callArgs?.[1]; // Ensure it's defined before proceeding expect(generationConfig).toBeDefined(); if (!generationConfig) throw new Error('generationConfig is undefined'); - return generationConfig as GenerateContentConfig & { - systemInstruction?: string | Content; - }; + return generationConfig as GenerateContentConfig; }; describe('create (Tool Validation)', () => { @@ -347,7 +345,7 @@ describe('subagent.ts', () => { ); // Check History (should include environment context) - const history = callArgs[3]; + const history = callArgs[2]; expect(history).toEqual([ { role: 'user', parts: [{ text: 'Env Context' }] }, { @@ -420,7 +418,7 @@ describe('subagent.ts', () => { const callArgs = vi.mocked(GeminiChat).mock.calls[0]; const generationConfig = getGenerationConfigFromMock(); - const history = callArgs[3]; + const history = callArgs[2]; expect(generationConfig.systemInstruction).toBeUndefined(); expect(history).toEqual([ @@ -503,7 +501,7 @@ describe('subagent.ts', () => { expect(scope.output.emitted_vars).toEqual({}); expect(mockSendMessageStream).toHaveBeenCalledTimes(1); // Check the initial message - expect(mockSendMessageStream.mock.calls[0][0].message).toEqual([ + expect(mockSendMessageStream.mock.calls[0][1].message).toEqual([ { text: 'Get Started!' 
}, ]); }); @@ -547,7 +545,7 @@ describe('subagent.ts', () => { expect(mockSendMessageStream).toHaveBeenCalledTimes(1); // Check the tool response sent back in the second call - const secondCallArgs = mockSendMessageStream.mock.calls[0][0]; + const secondCallArgs = mockSendMessageStream.mock.calls[0][1]; expect(secondCallArgs.message).toEqual([{ text: 'Get Started!' }]); }); @@ -609,7 +607,7 @@ describe('subagent.ts', () => { ); // Check the response sent back to the model - const secondCallArgs = mockSendMessageStream.mock.calls[1][0]; + const secondCallArgs = mockSendMessageStream.mock.calls[1][1]; expect(secondCallArgs.message).toEqual([ { text: 'file1.txt\nfile2.ts' }, ]); @@ -657,7 +655,7 @@ describe('subagent.ts', () => { await scope.runNonInteractive(new ContextState()); // The agent should send the specific error message from responseParts. - const secondCallArgs = mockSendMessageStream.mock.calls[1][0]; + const secondCallArgs = mockSendMessageStream.mock.calls[1][1]; expect(secondCallArgs.message).toEqual([ { @@ -703,7 +701,7 @@ describe('subagent.ts', () => { await scope.runNonInteractive(new ContextState()); // Check the nudge message sent in Turn 2 - const secondCallArgs = mockSendMessageStream.mock.calls[1][0]; + const secondCallArgs = mockSendMessageStream.mock.calls[1][1]; // We check that the message contains the required variable name and the nudge phrasing. expect(secondCallArgs.message[0].text).toContain('required_var'); @@ -771,7 +769,7 @@ describe('subagent.ts', () => { // Use fake timers to reliably test timeouts vi.useFakeTimers(); - const { config } = await createMockConfig(); + const { config } = await createMockConfig({}, { useRipgrep: false }); const runConfig: RunConfig = { max_time_minutes: 5, max_turns: 100 }; // We need to control the resolution of the sendMessageStream promise to advance the timer during execution. 
@@ -816,7 +814,7 @@ describe('subagent.ts', () => { }); it('should terminate with ERROR if the model call throws', async () => { - const { config } = await createMockConfig(); + const { config } = await createMockConfig({}, { useRipgrep: false }); mockSendMessageStream.mockRejectedValue(new Error('API Failure')); const scope = await SubAgentScope.create( diff --git a/packages/core/src/core/subagent.ts b/packages/core/src/core/subagent.ts index 41de5978a16..15cf5af9108 100644 --- a/packages/core/src/core/subagent.ts +++ b/packages/core/src/core/subagent.ts @@ -10,7 +10,6 @@ import type { AnyDeclarativeTool } from '../tools/tools.js'; import type { Config } from '../config/config.js'; import type { ToolCallRequestInfo } from './turn.js'; import { executeToolCall } from './nonInteractiveToolExecutor.js'; -import { createContentGenerator } from './contentGenerator.js'; import { getEnvironmentContext } from '../utils/environmentContext.js'; import type { Content, @@ -431,6 +430,7 @@ export class SubAgentScope { }; const responseStream = await chat.sendMessageStream( + this.modelConfig.model, messageParams, promptId, ); @@ -635,9 +635,7 @@ export class SubAgentScope { : undefined; try { - const generationConfig: GenerateContentConfig & { - systemInstruction?: string | Content; - } = { + const generationConfig: GenerateContentConfig = { temperature: this.modelConfig.temp, topP: this.modelConfig.top_p, }; @@ -646,17 +644,10 @@ export class SubAgentScope { generationConfig.systemInstruction = systemInstruction; } - const contentGenerator = await createContentGenerator( - this.runtimeContext.getContentGeneratorConfig(), - this.runtimeContext, - this.runtimeContext.getSessionId(), - ); - this.runtimeContext.setModel(this.modelConfig.model); return new GeminiChat( this.runtimeContext, - contentGenerator, generationConfig, start_history, ); diff --git a/packages/core/src/core/turn.test.ts b/packages/core/src/core/turn.test.ts index 16fdd90fd97..d3451166a9d 100644 --- 
a/packages/core/src/core/turn.test.ts +++ b/packages/core/src/core/turn.test.ts @@ -97,6 +97,7 @@ describe('Turn', () => { const events = []; const reqParts: Part[] = [{ text: 'Hi' }]; for await (const event of turn.run( + 'test-model', reqParts, new AbortController().signal, )) { @@ -104,6 +105,7 @@ describe('Turn', () => { } expect(mockSendMessageStream).toHaveBeenCalledWith( + 'test-model', { message: reqParts, config: { abortSignal: expect.any(AbortSignal) }, @@ -144,6 +146,7 @@ describe('Turn', () => { const events = []; const reqParts: Part[] = [{ text: 'Use tools' }]; for await (const event of turn.run( + 'test-model', reqParts, new AbortController().signal, )) { @@ -206,7 +209,11 @@ describe('Turn', () => { const events = []; const reqParts: Part[] = [{ text: 'Test abort' }]; - for await (const event of turn.run(reqParts, abortController.signal)) { + for await (const event of turn.run( + 'test-model', + reqParts, + abortController.signal, + )) { events.push(event); } expect(events).toEqual([ @@ -227,6 +234,7 @@ describe('Turn', () => { mockMaybeIncludeSchemaDepthContext.mockResolvedValue(undefined); const events = []; for await (const event of turn.run( + 'test-model', reqParts, new AbortController().signal, )) { @@ -267,6 +275,7 @@ describe('Turn', () => { const events = []; for await (const event of turn.run( + 'test-model', [{ text: 'Test undefined tool parts' }], new AbortController().signal, )) { @@ -323,6 +332,7 @@ describe('Turn', () => { const events = []; for await (const event of turn.run( + 'test-model', [{ text: 'Test finish reason' }], new AbortController().signal, )) { @@ -370,6 +380,7 @@ describe('Turn', () => { const events = []; const reqParts: Part[] = [{ text: 'Generate long text' }]; for await (const event of turn.run( + 'test-model', reqParts, new AbortController().signal, )) { @@ -407,6 +418,7 @@ describe('Turn', () => { const events = []; const reqParts: Part[] = [{ text: 'Test safety' }]; for await (const event of turn.run( + 
'test-model', reqParts, new AbortController().signal, )) { @@ -443,6 +455,7 @@ describe('Turn', () => { const events = []; const reqParts: Part[] = [{ text: 'Test no finish reason' }]; for await (const event of turn.run( + 'test-model', reqParts, new AbortController().signal, )) { @@ -487,6 +500,7 @@ describe('Turn', () => { const events = []; const reqParts: Part[] = [{ text: 'Test multiple responses' }]; for await (const event of turn.run( + 'test-model', reqParts, new AbortController().signal, )) { @@ -529,6 +543,7 @@ describe('Turn', () => { const events = []; for await (const event of turn.run( + 'test-model', [{ text: 'Test citations' }], new AbortController().signal, )) { @@ -578,6 +593,7 @@ describe('Turn', () => { const events = []; for await (const event of turn.run( + 'test-model', [{ text: 'test' }], new AbortController().signal, )) { @@ -624,6 +640,7 @@ describe('Turn', () => { const events = []; for await (const event of turn.run( + 'test-model', [{ text: 'test' }], new AbortController().signal, )) { @@ -669,6 +686,7 @@ describe('Turn', () => { const events = []; for await (const event of turn.run( + 'test-model', [{ text: 'test' }], new AbortController().signal, )) { @@ -705,7 +723,11 @@ describe('Turn', () => { const events = []; const reqParts: Part[] = [{ text: 'Test malformed error handling' }]; - for await (const event of turn.run(reqParts, abortController.signal)) { + for await (const event of turn.run( + 'test-model', + reqParts, + abortController.signal, + )) { events.push(event); } @@ -727,7 +749,11 @@ describe('Turn', () => { mockSendMessageStream.mockResolvedValue(mockResponseStream); const events = []; - for await (const event of turn.run([], new AbortController().signal)) { + for await (const event of turn.run( + 'test-model', + [], + new AbortController().signal, + )) { events.push(event); } @@ -752,7 +778,11 @@ describe('Turn', () => { })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); const reqParts: Part[] = [{ text: 
'Hi' }]; - for await (const _ of turn.run(reqParts, new AbortController().signal)) { + for await (const _ of turn.run( + 'test-model', + reqParts, + new AbortController().signal, + )) { // consume stream } expect(turn.getDebugResponses()).toEqual([resp1, resp2]); diff --git a/packages/core/src/core/turn.ts b/packages/core/src/core/turn.ts index 5ca3ee05cc4..8f38f7d2ed1 100644 --- a/packages/core/src/core/turn.ts +++ b/packages/core/src/core/turn.ts @@ -92,6 +92,7 @@ export interface ToolCallResponseInfo { error: Error | undefined; errorType: ToolErrorType | undefined; outputFile?: string | undefined; + contentLength?: number; } export interface ServerToolCallConfirmationDetails { @@ -210,6 +211,7 @@ export class Turn { ) {} // The run method yields simpler events suitable for server logic async *run( + model: string, req: PartListUnion, signal: AbortSignal, ): AsyncGenerator { @@ -217,6 +219,7 @@ export class Turn { // Note: This assumes `sendMessageStream` yields events like // { type: StreamEventType.RETRY } or { type: StreamEventType.CHUNK, value: GenerateContentResponse } const responseStream = await this.chat.sendMessageStream( + model, { message: req, config: { diff --git a/packages/core/src/fallback/handler.test.ts b/packages/core/src/fallback/handler.test.ts new file mode 100644 index 00000000000..77c9375644d --- /dev/null +++ b/packages/core/src/fallback/handler.test.ts @@ -0,0 +1,218 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + describe, + it, + expect, + vi, + beforeEach, + type Mock, + type MockInstance, + afterEach, +} from 'vitest'; +import { handleFallback } from './handler.js'; +import type { Config } from '../config/config.js'; +import { AuthType } from '../core/contentGenerator.js'; +import { + DEFAULT_GEMINI_FLASH_MODEL, + DEFAULT_GEMINI_MODEL, +} from '../config/models.js'; +import { logFlashFallback } from '../telemetry/index.js'; +import type { FallbackModelHandler } from 
'./types.js'; + +// Mock the telemetry logger and event class +vi.mock('../telemetry/index.js', () => ({ + logFlashFallback: vi.fn(), + FlashFallbackEvent: class {}, +})); + +const MOCK_PRO_MODEL = DEFAULT_GEMINI_MODEL; +const FALLBACK_MODEL = DEFAULT_GEMINI_FLASH_MODEL; +const AUTH_OAUTH = AuthType.LOGIN_WITH_GOOGLE; +const AUTH_API_KEY = AuthType.USE_GEMINI; + +const createMockConfig = (overrides: Partial = {}): Config => + ({ + isInFallbackMode: vi.fn(() => false), + setFallbackMode: vi.fn(), + fallbackHandler: undefined, + ...overrides, + }) as unknown as Config; + +describe('handleFallback', () => { + let mockConfig: Config; + let mockHandler: Mock; + let consoleErrorSpy: MockInstance; + + beforeEach(() => { + vi.clearAllMocks(); + mockHandler = vi.fn(); + // Default setup: OAuth user, Pro model failed, handler injected + mockConfig = createMockConfig({ + fallbackModelHandler: mockHandler, + }); + consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {}); + }); + + afterEach(() => { + consoleErrorSpy.mockRestore(); + }); + + it('should return null immediately if authType is not OAuth', async () => { + const result = await handleFallback( + mockConfig, + MOCK_PRO_MODEL, + AUTH_API_KEY, + ); + expect(result).toBeNull(); + expect(mockHandler).not.toHaveBeenCalled(); + expect(mockConfig.setFallbackMode).not.toHaveBeenCalled(); + }); + + it('should return null if the failed model is already the fallback model', async () => { + const result = await handleFallback( + mockConfig, + FALLBACK_MODEL, // Failed model is Flash + AUTH_OAUTH, + ); + expect(result).toBeNull(); + expect(mockHandler).not.toHaveBeenCalled(); + }); + + it('should return null if no fallbackHandler is injected in config', async () => { + const configWithoutHandler = createMockConfig({ + fallbackModelHandler: undefined, + }); + const result = await handleFallback( + configWithoutHandler, + MOCK_PRO_MODEL, + AUTH_OAUTH, + ); + expect(result).toBeNull(); + }); + + describe('when 
handler returns "retry"', () => { + it('should activate fallback mode, log telemetry, and return true', async () => { + mockHandler.mockResolvedValue('retry'); + + const result = await handleFallback( + mockConfig, + MOCK_PRO_MODEL, + AUTH_OAUTH, + ); + + expect(result).toBe(true); + expect(mockConfig.setFallbackMode).toHaveBeenCalledWith(true); + expect(logFlashFallback).toHaveBeenCalled(); + }); + }); + + describe('when handler returns "stop"', () => { + it('should activate fallback mode, log telemetry, and return false', async () => { + mockHandler.mockResolvedValue('stop'); + + const result = await handleFallback( + mockConfig, + MOCK_PRO_MODEL, + AUTH_OAUTH, + ); + + expect(result).toBe(false); + expect(mockConfig.setFallbackMode).toHaveBeenCalledWith(true); + expect(logFlashFallback).toHaveBeenCalled(); + }); + }); + + describe('when handler returns "auth"', () => { + it('should NOT activate fallback mode and return false', async () => { + mockHandler.mockResolvedValue('auth'); + + const result = await handleFallback( + mockConfig, + MOCK_PRO_MODEL, + AUTH_OAUTH, + ); + + expect(result).toBe(false); + expect(mockConfig.setFallbackMode).not.toHaveBeenCalled(); + expect(logFlashFallback).not.toHaveBeenCalled(); + }); + }); + + describe('when handler returns an unexpected value', () => { + it('should log an error and return null', async () => { + mockHandler.mockResolvedValue(null); + + const result = await handleFallback( + mockConfig, + MOCK_PRO_MODEL, + AUTH_OAUTH, + ); + + expect(result).toBeNull(); + expect(consoleErrorSpy).toHaveBeenCalledWith( + 'Fallback UI handler failed:', + new Error( + 'Unexpected fallback intent received from fallbackModelHandler: "null"', + ), + ); + expect(mockConfig.setFallbackMode).not.toHaveBeenCalled(); + }); + }); + + it('should pass the correct context (failedModel, fallbackModel, error) to the handler', async () => { + const mockError = new Error('Quota Exceeded'); + mockHandler.mockResolvedValue('retry'); + + await 
handleFallback(mockConfig, MOCK_PRO_MODEL, AUTH_OAUTH, mockError); + + expect(mockHandler).toHaveBeenCalledWith( + MOCK_PRO_MODEL, + FALLBACK_MODEL, + mockError, + ); + }); + + it('should not call setFallbackMode or log telemetry if already in fallback mode', async () => { + // Setup config where fallback mode is already active + const activeFallbackConfig = createMockConfig({ + fallbackModelHandler: mockHandler, + isInFallbackMode: vi.fn(() => true), // Already active + setFallbackMode: vi.fn(), + }); + + mockHandler.mockResolvedValue('retry'); + + const result = await handleFallback( + activeFallbackConfig, + MOCK_PRO_MODEL, + AUTH_OAUTH, + ); + + // Should still return true to allow the retry (which will use the active fallback mode) + expect(result).toBe(true); + // Should still consult the handler + expect(mockHandler).toHaveBeenCalled(); + // But should not mutate state or log telemetry again + expect(activeFallbackConfig.setFallbackMode).not.toHaveBeenCalled(); + expect(logFlashFallback).not.toHaveBeenCalled(); + }); + + it('should catch errors from the handler, log an error, and return null', async () => { + const handlerError = new Error('UI interaction failed'); + mockHandler.mockRejectedValue(handlerError); + + const result = await handleFallback(mockConfig, MOCK_PRO_MODEL, AUTH_OAUTH); + + expect(result).toBeNull(); + expect(consoleErrorSpy).toHaveBeenCalledWith( + 'Fallback UI handler failed:', + handlerError, + ); + expect(mockConfig.setFallbackMode).not.toHaveBeenCalled(); + }); +}); diff --git a/packages/core/src/fallback/handler.ts b/packages/core/src/fallback/handler.ts new file mode 100644 index 00000000000..762552cd2d2 --- /dev/null +++ b/packages/core/src/fallback/handler.ts @@ -0,0 +1,69 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Config } from '../config/config.js'; +import { AuthType } from '../core/contentGenerator.js'; +import { DEFAULT_GEMINI_FLASH_MODEL } from 
'../config/models.js'; +import { logFlashFallback, FlashFallbackEvent } from '../telemetry/index.js'; + +export async function handleFallback( + config: Config, + failedModel: string, + authType?: string, + error?: unknown, +): Promise { + // Applicability Checks + if (authType !== AuthType.LOGIN_WITH_GOOGLE) return null; + + const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL; + + if (failedModel === fallbackModel) return null; + + // Consult UI Handler for Intent + const fallbackModelHandler = config.fallbackModelHandler; + if (typeof fallbackModelHandler !== 'function') return null; + + try { + // Pass the specific failed model to the UI handler. + const intent = await fallbackModelHandler( + failedModel, + fallbackModel, + error, + ); + + // Process Intent and Update State + switch (intent) { + case 'retry': + // Activate fallback mode. The NEXT retry attempt will pick this up. + activateFallbackMode(config, authType); + return true; // Signal retryWithBackoff to continue. + + case 'stop': + activateFallbackMode(config, authType); + return false; + + case 'auth': + return false; + + default: + throw new Error( + `Unexpected fallback intent received from fallbackModelHandler: "${intent}"`, + ); + } + } catch (handlerError) { + console.error('Fallback UI handler failed:', handlerError); + return null; + } +} + +function activateFallbackMode(config: Config, authType: string | undefined) { + if (!config.isInFallbackMode()) { + config.setFallbackMode(true); + if (authType) { + logFlashFallback(config, new FlashFallbackEvent(authType)); + } + } +} diff --git a/packages/core/src/fallback/types.ts b/packages/core/src/fallback/types.ts new file mode 100644 index 00000000000..65431233713 --- /dev/null +++ b/packages/core/src/fallback/types.ts @@ -0,0 +1,23 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +/** + * Defines the intent returned by the UI layer during a fallback scenario. 
+ */ +export type FallbackIntent = + | 'retry' // Immediately retry the current request with the fallback model. + | 'stop' // Switch to fallback for future requests, but stop the current request. + | 'auth'; // Stop the current request; user intends to change authentication. + +/** + * The interface for the handler provided by the UI layer (e.g., the CLI) + * to interact with the user during a fallback scenario. + */ +export type FallbackModelHandler = ( + failedModel: string, + fallbackModel: string, + error?: unknown, +) => Promise; diff --git a/packages/core/src/ide/constants.ts b/packages/core/src/ide/constants.ts index f1f066c43cc..573b9aec035 100644 --- a/packages/core/src/ide/constants.ts +++ b/packages/core/src/ide/constants.ts @@ -5,3 +5,5 @@ */ export const GEMINI_CLI_COMPANION_EXTENSION_NAME = 'Gemini CLI Companion'; +export const IDE_MAX_OPEN_FILES = 10; +export const IDE_MAX_SELECTED_TEXT_LENGTH = 16384; // 16 KiB limit diff --git a/packages/core/src/ide/ide-client.test.ts b/packages/core/src/ide/ide-client.test.ts index c47753e4200..a2f832ca820 100644 --- a/packages/core/src/ide/ide-client.test.ts +++ b/packages/core/src/ide/ide-client.test.ts @@ -12,6 +12,7 @@ import { beforeEach, afterEach, type Mocked, + type Mock, } from 'vitest'; import { IdeClient, IDEConnectionStatus } from './ide-client.js'; import * as fs from 'node:fs'; @@ -29,11 +30,13 @@ import * as os from 'node:os'; import * as path from 'node:path'; vi.mock('node:fs', async (importOriginal) => { - const actual = await importOriginal(); + const actual = await importOriginal(); return { ...(actual as object), promises: { + ...actual.promises, readFile: vi.fn(), + readdir: vi.fn(), }, realpathSync: (p: string) => p, existsSync: () => false, @@ -80,6 +83,7 @@ describe('IdeClient', () => { close: vi.fn(), setNotificationHandler: vi.fn(), callTool: vi.fn(), + request: vi.fn(), } as unknown as Mocked; mockHttpTransport = { close: vi.fn(), @@ -103,12 +107,17 @@ describe('IdeClient', () => { 
it('should connect using HTTP when port is provided in config file', async () => { const config = { port: '8080' }; vi.mocked(fs.promises.readFile).mockResolvedValue(JSON.stringify(config)); + ( + vi.mocked(fs.promises.readdir) as Mock< + (path: fs.PathLike) => Promise + > + ).mockResolvedValue([]); const ideClient = await IdeClient.getInstance(); await ideClient.connect(); expect(fs.promises.readFile).toHaveBeenCalledWith( - path.join('/tmp', 'gemini-ide-server-12345.json'), + path.join('/tmp/', 'gemini-ide-server-12345.json'), 'utf8', ); expect(StreamableHTTPClientTransport).toHaveBeenCalledWith( @@ -124,6 +133,11 @@ describe('IdeClient', () => { it('should connect using stdio when stdio config is provided in file', async () => { const config = { stdio: { command: 'test-cmd', args: ['--foo'] } }; vi.mocked(fs.promises.readFile).mockResolvedValue(JSON.stringify(config)); + ( + vi.mocked(fs.promises.readdir) as Mock< + (path: fs.PathLike) => Promise + > + ).mockResolvedValue([]); const ideClient = await IdeClient.getInstance(); await ideClient.connect(); @@ -144,6 +158,11 @@ describe('IdeClient', () => { stdio: { command: 'test-cmd', args: ['--foo'] }, }; vi.mocked(fs.promises.readFile).mockResolvedValue(JSON.stringify(config)); + ( + vi.mocked(fs.promises.readdir) as Mock< + (path: fs.PathLike) => Promise + > + ).mockResolvedValue([]); const ideClient = await IdeClient.getInstance(); await ideClient.connect(); @@ -159,6 +178,11 @@ describe('IdeClient', () => { vi.mocked(fs.promises.readFile).mockRejectedValue( new Error('File not found'), ); + ( + vi.mocked(fs.promises.readdir) as Mock< + (path: fs.PathLike) => Promise + > + ).mockResolvedValue([]); process.env['GEMINI_CLI_IDE_SERVER_PORT'] = '9090'; const ideClient = await IdeClient.getInstance(); @@ -178,6 +202,11 @@ describe('IdeClient', () => { vi.mocked(fs.promises.readFile).mockRejectedValue( new Error('File not found'), ); + ( + vi.mocked(fs.promises.readdir) as Mock< + (path: fs.PathLike) => Promise + > + 
).mockResolvedValue([]); process.env['GEMINI_CLI_IDE_SERVER_STDIO_COMMAND'] = 'env-cmd'; process.env['GEMINI_CLI_IDE_SERVER_STDIO_ARGS'] = '["--bar"]'; @@ -197,6 +226,11 @@ describe('IdeClient', () => { it('should prioritize file config over environment variables', async () => { const config = { port: '8080' }; vi.mocked(fs.promises.readFile).mockResolvedValue(JSON.stringify(config)); + ( + vi.mocked(fs.promises.readdir) as Mock< + (path: fs.PathLike) => Promise + > + ).mockResolvedValue([]); process.env['GEMINI_CLI_IDE_SERVER_PORT'] = '9090'; const ideClient = await IdeClient.getInstance(); @@ -215,6 +249,11 @@ describe('IdeClient', () => { vi.mocked(fs.promises.readFile).mockRejectedValue( new Error('File not found'), ); + ( + vi.mocked(fs.promises.readdir) as Mock< + (path: fs.PathLike) => Promise + > + ).mockResolvedValue([]); const ideClient = await IdeClient.getInstance(); await ideClient.connect(); @@ -229,4 +268,360 @@ describe('IdeClient', () => { ); }); }); + + describe('getConnectionConfigFromFile', () => { + it('should return config from the specific pid file if it exists', async () => { + const config = { port: '1234', workspacePath: '/test/workspace' }; + vi.mocked(fs.promises.readFile).mockResolvedValue(JSON.stringify(config)); + + const ideClient = await IdeClient.getInstance(); + // In tests, the private method can be accessed like this. 
+ const result = await ( + ideClient as unknown as { + getConnectionConfigFromFile: () => Promise; + } + ).getConnectionConfigFromFile(); + + expect(result).toEqual(config); + expect(fs.promises.readFile).toHaveBeenCalledWith( + path.join('/tmp', 'gemini-ide-server-12345.json'), + 'utf8', + ); + }); + + it('should return undefined if no config files are found', async () => { + vi.mocked(fs.promises.readFile).mockRejectedValue(new Error('not found')); + ( + vi.mocked(fs.promises.readdir) as Mock< + (path: fs.PathLike) => Promise + > + ).mockResolvedValue([]); + + const ideClient = await IdeClient.getInstance(); + const result = await ( + ideClient as unknown as { + getConnectionConfigFromFile: () => Promise; + } + ).getConnectionConfigFromFile(); + + expect(result).toBeUndefined(); + }); + + it('should find and parse a single config file with the new naming scheme', async () => { + const config = { port: '5678', workspacePath: '/test/workspace' }; + vi.mocked(fs.promises.readFile).mockRejectedValueOnce( + new Error('not found'), + ); // For old path + ( + vi.mocked(fs.promises.readdir) as Mock< + (path: fs.PathLike) => Promise + > + ).mockResolvedValue(['gemini-ide-server-12345-123.json']); + vi.mocked(fs.promises.readFile).mockResolvedValue(JSON.stringify(config)); + vi.spyOn(IdeClient, 'validateWorkspacePath').mockReturnValue({ + isValid: true, + }); + + const ideClient = await IdeClient.getInstance(); + const result = await ( + ideClient as unknown as { + getConnectionConfigFromFile: () => Promise; + } + ).getConnectionConfigFromFile(); + + expect(result).toEqual(config); + expect(fs.promises.readFile).toHaveBeenCalledWith( + path.join('/tmp/.gemini/ide', 'gemini-ide-server-12345-123.json'), + 'utf8', + ); + }); + + it('should filter out configs with invalid workspace paths', async () => { + const validConfig = { + port: '5678', + workspacePath: '/test/workspace', + }; + const invalidConfig = { + port: '1111', + workspacePath: '/invalid/workspace', + }; + 
vi.mocked(fs.promises.readFile).mockRejectedValueOnce( + new Error('not found'), + ); + ( + vi.mocked(fs.promises.readdir) as Mock< + (path: fs.PathLike) => Promise + > + ).mockResolvedValue([ + 'gemini-ide-server-12345-111.json', + 'gemini-ide-server-12345-222.json', + ]); + vi.mocked(fs.promises.readFile) + .mockResolvedValueOnce(JSON.stringify(invalidConfig)) + .mockResolvedValueOnce(JSON.stringify(validConfig)); + + const validateSpy = vi + .spyOn(IdeClient, 'validateWorkspacePath') + .mockReturnValueOnce({ isValid: false }) + .mockReturnValueOnce({ isValid: true }); + + const ideClient = await IdeClient.getInstance(); + const result = await ( + ideClient as unknown as { + getConnectionConfigFromFile: () => Promise; + } + ).getConnectionConfigFromFile(); + + expect(result).toEqual(validConfig); + expect(validateSpy).toHaveBeenCalledWith( + '/invalid/workspace', + 'VS Code', + '/test/workspace/sub-dir', + ); + expect(validateSpy).toHaveBeenCalledWith( + '/test/workspace', + 'VS Code', + '/test/workspace/sub-dir', + ); + }); + + it('should return the first valid config when multiple workspaces are valid', async () => { + const config1 = { port: '1111', workspacePath: '/test/workspace' }; + const config2 = { port: '2222', workspacePath: '/test/workspace2' }; + vi.mocked(fs.promises.readFile).mockRejectedValueOnce( + new Error('not found'), + ); + ( + vi.mocked(fs.promises.readdir) as Mock< + (path: fs.PathLike) => Promise + > + ).mockResolvedValue([ + 'gemini-ide-server-12345-111.json', + 'gemini-ide-server-12345-222.json', + ]); + vi.mocked(fs.promises.readFile) + .mockResolvedValueOnce(JSON.stringify(config1)) + .mockResolvedValueOnce(JSON.stringify(config2)); + vi.spyOn(IdeClient, 'validateWorkspacePath').mockReturnValue({ + isValid: true, + }); + + const ideClient = await IdeClient.getInstance(); + const result = await ( + ideClient as unknown as { + getConnectionConfigFromFile: () => Promise; + } + ).getConnectionConfigFromFile(); + + 
expect(result).toEqual(config1); + }); + + it('should prioritize the config matching the port from the environment variable', async () => { + process.env['GEMINI_CLI_IDE_SERVER_PORT'] = '2222'; + const config1 = { port: '1111', workspacePath: '/test/workspace' }; + const config2 = { port: '2222', workspacePath: '/test/workspace2' }; + vi.mocked(fs.promises.readFile).mockRejectedValueOnce( + new Error('not found'), + ); + ( + vi.mocked(fs.promises.readdir) as Mock< + (path: fs.PathLike) => Promise + > + ).mockResolvedValue([ + 'gemini-ide-server-12345-111.json', + 'gemini-ide-server-12345-222.json', + ]); + vi.mocked(fs.promises.readFile) + .mockResolvedValueOnce(JSON.stringify(config1)) + .mockResolvedValueOnce(JSON.stringify(config2)); + vi.spyOn(IdeClient, 'validateWorkspacePath').mockReturnValue({ + isValid: true, + }); + + const ideClient = await IdeClient.getInstance(); + const result = await ( + ideClient as unknown as { + getConnectionConfigFromFile: () => Promise; + } + ).getConnectionConfigFromFile(); + + expect(result).toEqual(config2); + }); + + it('should handle invalid JSON in one of the config files', async () => { + const validConfig = { port: '2222', workspacePath: '/test/workspace' }; + vi.mocked(fs.promises.readFile).mockRejectedValueOnce( + new Error('not found'), + ); + ( + vi.mocked(fs.promises.readdir) as Mock< + (path: fs.PathLike) => Promise + > + ).mockResolvedValue([ + 'gemini-ide-server-12345-111.json', + 'gemini-ide-server-12345-222.json', + ]); + vi.mocked(fs.promises.readFile) + .mockResolvedValueOnce('invalid json') + .mockResolvedValueOnce(JSON.stringify(validConfig)); + vi.spyOn(IdeClient, 'validateWorkspacePath').mockReturnValue({ + isValid: true, + }); + + const ideClient = await IdeClient.getInstance(); + const result = await ( + ideClient as unknown as { + getConnectionConfigFromFile: () => Promise; + } + ).getConnectionConfigFromFile(); + + expect(result).toEqual(validConfig); + }); + + it('should return undefined if readdir 
throws an error', async () => { + vi.mocked(fs.promises.readFile).mockRejectedValueOnce( + new Error('not found'), + ); + vi.mocked(fs.promises.readdir).mockRejectedValue( + new Error('readdir failed'), + ); + + const ideClient = await IdeClient.getInstance(); + const result = await ( + ideClient as unknown as { + getConnectionConfigFromFile: () => Promise; + } + ).getConnectionConfigFromFile(); + + expect(result).toBeUndefined(); + }); + + it('should ignore files with invalid names', async () => { + const validConfig = { port: '3333', workspacePath: '/test/workspace' }; + vi.mocked(fs.promises.readFile).mockRejectedValueOnce( + new Error('not found'), + ); + ( + vi.mocked(fs.promises.readdir) as Mock< + (path: fs.PathLike) => Promise + > + ).mockResolvedValue([ + 'gemini-ide-server-12345-111.json', // valid + 'not-a-config-file.txt', // invalid + 'gemini-ide-server-asdf.json', // invalid + ]); + vi.mocked(fs.promises.readFile).mockResolvedValueOnce( + JSON.stringify(validConfig), + ); + vi.spyOn(IdeClient, 'validateWorkspacePath').mockReturnValue({ + isValid: true, + }); + + const ideClient = await IdeClient.getInstance(); + const result = await ( + ideClient as unknown as { + getConnectionConfigFromFile: () => Promise; + } + ).getConnectionConfigFromFile(); + + expect(result).toEqual(validConfig); + expect(fs.promises.readFile).toHaveBeenCalledWith( + path.join('/tmp/.gemini/ide', 'gemini-ide-server-12345-111.json'), + 'utf8', + ); + expect(fs.promises.readFile).not.toHaveBeenCalledWith( + path.join('/tmp/.gemini/ide', 'not-a-config-file.txt'), + 'utf8', + ); + }); + }); + + describe('isDiffingEnabled', () => { + it('should return false if not connected', async () => { + const ideClient = await IdeClient.getInstance(); + expect(ideClient.isDiffingEnabled()).toBe(false); + }); + + it('should return false if tool discovery fails', async () => { + const config = { port: '8080' }; + vi.mocked(fs.promises.readFile).mockResolvedValue(JSON.stringify(config)); + ( + 
vi.mocked(fs.promises.readdir) as Mock< + (path: fs.PathLike) => Promise + > + ).mockResolvedValue([]); + mockClient.request.mockRejectedValue(new Error('Method not found')); + + const ideClient = await IdeClient.getInstance(); + await ideClient.connect(); + + expect(ideClient.getConnectionStatus().status).toBe( + IDEConnectionStatus.Connected, + ); + expect(ideClient.isDiffingEnabled()).toBe(false); + }); + + it('should return false if diffing tools are not available', async () => { + const config = { port: '8080' }; + vi.mocked(fs.promises.readFile).mockResolvedValue(JSON.stringify(config)); + ( + vi.mocked(fs.promises.readdir) as Mock< + (path: fs.PathLike) => Promise + > + ).mockResolvedValue([]); + mockClient.request.mockResolvedValue({ + tools: [{ name: 'someOtherTool' }], + }); + + const ideClient = await IdeClient.getInstance(); + await ideClient.connect(); + + expect(ideClient.getConnectionStatus().status).toBe( + IDEConnectionStatus.Connected, + ); + expect(ideClient.isDiffingEnabled()).toBe(false); + }); + + it('should return false if only openDiff tool is available', async () => { + const config = { port: '8080' }; + vi.mocked(fs.promises.readFile).mockResolvedValue(JSON.stringify(config)); + ( + vi.mocked(fs.promises.readdir) as Mock< + (path: fs.PathLike) => Promise + > + ).mockResolvedValue([]); + mockClient.request.mockResolvedValue({ + tools: [{ name: 'openDiff' }], + }); + + const ideClient = await IdeClient.getInstance(); + await ideClient.connect(); + + expect(ideClient.getConnectionStatus().status).toBe( + IDEConnectionStatus.Connected, + ); + expect(ideClient.isDiffingEnabled()).toBe(false); + }); + + it('should return true if connected and diffing tools are available', async () => { + const config = { port: '8080' }; + vi.mocked(fs.promises.readFile).mockResolvedValue(JSON.stringify(config)); + ( + vi.mocked(fs.promises.readdir) as Mock< + (path: fs.PathLike) => Promise + > + ).mockResolvedValue([]); + mockClient.request.mockResolvedValue({ + 
tools: [{ name: 'openDiff' }, { name: 'closeDiff' }], + }); + + const ideClient = await IdeClient.getInstance(); + await ideClient.connect(); + + expect(ideClient.getConnectionStatus().status).toBe( + IDEConnectionStatus.Connected, + ); + expect(ideClient.isDiffingEnabled()).toBe(true); + }); + }); }); diff --git a/packages/core/src/ide/ide-client.ts b/packages/core/src/ide/ide-client.ts index 0c860594055..7c4b50d8ba6 100644 --- a/packages/core/src/ide/ide-client.ts +++ b/packages/core/src/ide/ide-client.ts @@ -7,14 +7,14 @@ import * as fs from 'node:fs'; import { isSubpath } from '../utils/paths.js'; import { detectIde, type DetectedIde, getIdeInfo } from '../ide/detect-ide.js'; -import type { DiffUpdateResult } from '../ide/ideContext.js'; import { - ideContext, - IdeContextNotificationSchema, + ideContextStore, IdeDiffAcceptedNotificationSchema, IdeDiffClosedNotificationSchema, CloseDiffResponseSchema, -} from '../ide/ideContext.js'; + type DiffUpdateResult, +} from './ideContext.js'; +import { IdeContextNotificationSchema } from './types.js'; import { getIdeProcessInfo } from './process-utils.js'; import { Client } from '@modelcontextprotocol/sdk/client/index.js'; import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; @@ -22,6 +22,7 @@ import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js' import * as os from 'node:os'; import * as path from 'node:path'; import { EnvHttpProxyAgent } from 'undici'; +import { ListToolsResultSchema } from '@modelcontextprotocol/sdk/types.js'; const logger = { // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -78,6 +79,13 @@ export class IdeClient { private diffResponses = new Map void>(); private statusListeners = new Set<(state: IDEConnectionState) => void>(); private trustChangeListeners = new Set<(isTrusted: boolean) => void>(); + private availableTools: string[] = []; + /** + * A mutex to ensure that only one diff view is open in the IDE at 
a time. + * This prevents race conditions and UI issues in IDEs like VSCode that + * can't handle multiple diff views being opened simultaneously. + */ + private diffMutex = Promise.resolve(); private constructor() {} @@ -202,10 +210,16 @@ export class IdeClient { filePath: string, newContent?: string, ): Promise { - return new Promise((resolve, reject) => { + const release = await this.acquireMutex(); + + const promise = new Promise((resolve, reject) => { + if (!this.client) { + // The promise will be rejected, and the finally block below will release the mutex. + return reject(new Error('IDE client is not connected.')); + } this.diffResponses.set(filePath, resolve); this.client - ?.callTool({ + .callTool({ name: `openDiff`, arguments: { filePath, @@ -214,9 +228,42 @@ export class IdeClient { }) .catch((err) => { logger.debug(`callTool for ${filePath} failed:`, err); + this.diffResponses.delete(filePath); reject(err); }); }); + + // Ensure the mutex is released only after the diff interaction is complete. + promise.finally(release); + + return promise; + } + + /** + * Acquires a lock to ensure sequential execution of critical sections. + * + * This method implements a promise-based mutex. It works by chaining promises. + * Each call to `acquireMutex` gets the current `diffMutex` promise. It then + * creates a *new* promise (`newMutex`) that will be resolved when the caller + * invokes the returned `release` function. The `diffMutex` is immediately + * updated to this `newMutex`. + * + * The method returns a promise that resolves with the `release` function only + * *after* the *previous* `diffMutex` promise has resolved. This creates a + * queue where each subsequent operation must wait for the previous one to release + * the lock. + * + * @returns A promise that resolves to a function that must be called to + * release the lock. 
+ */ + private acquireMutex(): Promise<() => void> { + let release: () => void; + const newMutex = new Promise((resolve) => { + release = resolve; + }); + const oldMutex = this.diffMutex; + this.diffMutex = newMutex; + return oldMutex.then(() => release); } async closeDiff( @@ -289,6 +336,53 @@ export class IdeClient { return this.currentIdeDisplayName; } + isDiffingEnabled(): boolean { + return ( + !!this.client && + this.state.status === IDEConnectionStatus.Connected && + this.availableTools.includes('openDiff') && + this.availableTools.includes('closeDiff') + ); + } + + private async discoverTools(): Promise { + if (!this.client) { + return; + } + try { + logger.debug('Discovering tools from IDE...'); + const response = await this.client.request( + { method: 'tools/list', params: {} }, + ListToolsResultSchema, + ); + + // Map the array of tool objects to an array of tool names (strings) + this.availableTools = response.tools.map((tool) => tool.name); + + if (this.availableTools.length > 0) { + logger.debug( + `Discovered ${this.availableTools.length} tools from IDE: ${this.availableTools.join(', ')}`, + ); + } else { + logger.debug( + 'IDE supports tool discovery, but no tools are available.', + ); + } + } catch (error) { + // It's okay if this fails, the IDE might not support it. + // Don't log an error if the method is not found, which is a common case. 
+ if ( + error instanceof Error && + !error.message?.includes('Method not found') + ) { + logger.error(`Error discovering tools from IDE: ${error.message}`); + } else { + logger.debug('IDE does not support tool discovery.'); + } + this.availableTools = []; + } + } + private setState( status: IDEConnectionStatus, details?: string, @@ -317,7 +411,7 @@ export class IdeClient { } if (status === IDEConnectionStatus.Disconnected) { - ideContext.clearIdeContext(); + ideContextStore.clear(); } } @@ -396,8 +490,10 @@ export class IdeClient { (ConnectionConfig & { workspacePath?: string }) | undefined > { if (!this.ideProcessInfo) { - return {}; + return undefined; } + + // For backwards compatibility try { const portFile = path.join( os.tmpdir(), @@ -406,8 +502,82 @@ export class IdeClient { const portFileContents = await fs.promises.readFile(portFile, 'utf8'); return JSON.parse(portFileContents); } catch (_) { + // For newer extension versions, the file name matches the pattern + // /^gemini-ide-server-${pid}-\d+\.json$/. If multiple IDE + // windows are open, multiple files matching the pattern are expected to + // exist.
+ } + + const portFileDir = path.join(os.tmpdir(), '.gemini', 'ide'); + let portFiles; + try { + portFiles = await fs.promises.readdir(portFileDir); + } catch (e) { + logger.debug('Failed to read IDE connection directory:', e); + return undefined; + } + + const fileRegex = new RegExp( + `^gemini-ide-server-${this.ideProcessInfo.pid}-\\d+\\.json$`, + ); + const matchingFiles = portFiles + .filter((file) => fileRegex.test(file)) + .sort(); + if (matchingFiles.length === 0) { + return undefined; + } + + let fileContents: string[]; + try { + fileContents = await Promise.all( + matchingFiles.map((file) => + fs.promises.readFile(path.join(portFileDir, file), 'utf8'), + ), + ); + } catch (e) { + logger.debug('Failed to read IDE connection config file(s):', e); + return undefined; + } + const parsedContents = fileContents.map((content) => { + try { + return JSON.parse(content); + } catch (e) { + logger.debug('Failed to parse JSON from config file: ', e); + return undefined; + } + }); + + const validWorkspaces = parsedContents.filter((content) => { + if (!content) { + return false; + } + const { isValid } = IdeClient.validateWorkspacePath( + content.workspacePath, + this.currentIdeDisplayName, + process.cwd(), + ); + return isValid; + }); + + if (validWorkspaces.length === 0) { return undefined; } + + if (validWorkspaces.length === 1) { + return validWorkspaces[0]; + } + + const portFromEnv = this.getPortFromEnv(); + if (portFromEnv) { + const matchingPort = validWorkspaces.find( + (content) => content.port === portFromEnv, + ); + if (matchingPort) { + return matchingPort; + } + } + + return validWorkspaces[0]; } private createProxyAwareFetch() { @@ -441,7 +611,7 @@ export class IdeClient { this.client.setNotificationHandler( IdeContextNotificationSchema, (notification) => { - ideContext.setIdeContext(notification.params); + ideContextStore.set(notification.params); const isTrusted = notification.params.workspaceState?.isTrusted; if (isTrusted !== undefined) { for (const 
listener of this.trustChangeListeners) { @@ -510,6 +680,7 @@ export class IdeClient { ); await this.client.connect(transport); this.registerClientHandlers(); + await this.discoverTools(); this.setState(IDEConnectionStatus.Connected); return true; } catch (_error) { @@ -543,6 +714,7 @@ export class IdeClient { }); await this.client.connect(transport); this.registerClientHandlers(); + await this.discoverTools(); this.setState(IDEConnectionStatus.Connected); return true; } catch (_error) { diff --git a/packages/core/src/ide/ideContext.test.ts b/packages/core/src/ide/ideContext.test.ts index 7e01d3aadb6..0e3c7ec53ba 100644 --- a/packages/core/src/ide/ideContext.test.ts +++ b/packages/core/src/ide/ideContext.test.ts @@ -4,24 +4,34 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect, beforeEach, vi } from 'vitest'; import { - createIdeContextStore, + IDE_MAX_OPEN_FILES, + IDE_MAX_SELECTED_TEXT_LENGTH, +} from './constants.js'; +import { describe, it, expect, beforeEach, vi, afterEach } from 'vitest'; +import { IdeContextStore } from './ideContext.js'; +import { + type IdeContext, FileSchema, IdeContextSchema, -} from './ideContext.js'; + type File, +} from './types.js'; describe('ideContext', () => { describe('createIdeContextStore', () => { - let ideContext: ReturnType; + let ideContextStore: IdeContextStore; beforeEach(() => { // Create a fresh, isolated instance for each test - ideContext = createIdeContextStore(); + ideContextStore = new IdeContextStore(); + }); + + afterEach(() => { + vi.restoreAllMocks(); }); it('should return undefined initially for ide context', () => { - expect(ideContext.getIdeContext()).toBeUndefined(); + expect(ideContextStore.get()).toBeUndefined(); }); it('should set and retrieve the ide context', () => { @@ -38,9 +48,9 @@ describe('ideContext', () => { }, }; - ideContext.setIdeContext(testFile); + ideContextStore.set(testFile); - const activeFile = ideContext.getIdeContext(); + const activeFile = 
ideContextStore.get(); expect(activeFile).toEqual(testFile); }); @@ -57,7 +67,7 @@ describe('ideContext', () => { ], }, }; - ideContext.setIdeContext(firstFile); + ideContextStore.set(firstFile); const secondFile = { workspaceState: { @@ -71,9 +81,9 @@ describe('ideContext', () => { ], }, }; - ideContext.setIdeContext(secondFile); + ideContextStore.set(secondFile); - const activeFile = ideContext.getIdeContext(); + const activeFile = ideContextStore.get(); expect(activeFile).toEqual(secondFile); }); @@ -90,16 +100,16 @@ describe('ideContext', () => { ], }, }; - ideContext.setIdeContext(testFile); - expect(ideContext.getIdeContext()).toEqual(testFile); + ideContextStore.set(testFile); + expect(ideContextStore.get()).toEqual(testFile); }); it('should notify subscribers when ide context changes', () => { const subscriber1 = vi.fn(); const subscriber2 = vi.fn(); - ideContext.subscribeToIdeContext(subscriber1); - ideContext.subscribeToIdeContext(subscriber2); + ideContextStore.subscribe(subscriber1); + ideContextStore.subscribe(subscriber2); const testFile = { workspaceState: { @@ -113,7 +123,7 @@ describe('ideContext', () => { ], }, }; - ideContext.setIdeContext(testFile); + ideContextStore.set(testFile); expect(subscriber1).toHaveBeenCalledTimes(1); expect(subscriber1).toHaveBeenCalledWith(testFile); @@ -133,7 +143,7 @@ describe('ideContext', () => { ], }, }; - ideContext.setIdeContext(newFile); + ideContextStore.set(newFile); expect(subscriber1).toHaveBeenCalledTimes(2); expect(subscriber1).toHaveBeenCalledWith(newFile); @@ -145,10 +155,10 @@ describe('ideContext', () => { const subscriber1 = vi.fn(); const subscriber2 = vi.fn(); - const unsubscribe1 = ideContext.subscribeToIdeContext(subscriber1); - ideContext.subscribeToIdeContext(subscriber2); + const unsubscribe1 = ideContextStore.subscribe(subscriber1); + ideContextStore.subscribe(subscriber2); - ideContext.setIdeContext({ + ideContextStore.set({ workspaceState: { openFiles: [ { @@ -165,7 +175,7 @@ 
describe('ideContext', () => { unsubscribe1(); - ideContext.setIdeContext({ + ideContextStore.set({ workspaceState: { openFiles: [ { @@ -195,13 +205,152 @@ describe('ideContext', () => { }, }; - ideContext.setIdeContext(testFile); + ideContextStore.set(testFile); + + expect(ideContextStore.get()).toEqual(testFile); + + ideContextStore.clear(); + + expect(ideContextStore.get()).toBeUndefined(); + }); + + it('should set the context and notify subscribers when no workspaceState is present', () => { + const subscriber = vi.fn(); + ideContextStore.subscribe(subscriber); + const context: IdeContext = {}; + ideContextStore.set(context); + expect(ideContextStore.get()).toBe(context); + expect(subscriber).toHaveBeenCalledWith(context); + }); + + it('should handle an empty openFiles array', () => { + const context: IdeContext = { + workspaceState: { + openFiles: [], + }, + }; + ideContextStore.set(context); + expect(ideContextStore.get()?.workspaceState?.openFiles).toEqual([]); + }); + + it('should sort openFiles by timestamp in descending order', () => { + const context: IdeContext = { + workspaceState: { + openFiles: [ + { path: 'file1.ts', timestamp: 100, isActive: false }, + { path: 'file2.ts', timestamp: 300, isActive: true }, + { path: 'file3.ts', timestamp: 200, isActive: false }, + ], + }, + }; + ideContextStore.set(context); + const openFiles = ideContextStore.get()?.workspaceState?.openFiles; + expect(openFiles?.[0]?.path).toBe('file2.ts'); + expect(openFiles?.[1]?.path).toBe('file3.ts'); + expect(openFiles?.[2]?.path).toBe('file1.ts'); + }); + + it('should mark only the most recent file as active and clear other active files', () => { + const context: IdeContext = { + workspaceState: { + openFiles: [ + { + path: 'file1.ts', + timestamp: 100, + isActive: true, + selectedText: 'hello', + }, + { + path: 'file2.ts', + timestamp: 300, + isActive: true, + cursor: { line: 1, character: 1 }, + selectedText: 'hello', + }, + { + path: 'file3.ts', + timestamp: 200, + 
isActive: false, + selectedText: 'hello', + }, + ], + }, + }; + ideContextStore.set(context); + const openFiles = ideContextStore.get()?.workspaceState?.openFiles; + expect(openFiles?.[0]?.isActive).toBe(true); + expect(openFiles?.[0]?.cursor).toBeDefined(); + expect(openFiles?.[0]?.selectedText).toBeDefined(); + + expect(openFiles?.[1]?.isActive).toBe(false); + expect(openFiles?.[1]?.cursor).toBeUndefined(); + expect(openFiles?.[1]?.selectedText).toBeUndefined(); + + expect(openFiles?.[2]?.isActive).toBe(false); + expect(openFiles?.[2]?.cursor).toBeUndefined(); + expect(openFiles?.[2]?.selectedText).toBeUndefined(); + }); - expect(ideContext.getIdeContext()).toEqual(testFile); + it('should truncate selectedText if it exceeds the max length', () => { + const longText = 'a'.repeat(IDE_MAX_SELECTED_TEXT_LENGTH + 10); + const context: IdeContext = { + workspaceState: { + openFiles: [ + { + path: 'file1.ts', + timestamp: 100, + isActive: true, + selectedText: longText, + }, + ], + }, + }; + ideContextStore.set(context); + const selectedText = + ideContextStore.get()?.workspaceState?.openFiles?.[0]?.selectedText; + expect(selectedText).toHaveLength( + IDE_MAX_SELECTED_TEXT_LENGTH + '... [TRUNCATED]'.length, + ); + expect(selectedText?.endsWith('... 
[TRUNCATED]')).toBe(true); + }); - ideContext.clearIdeContext(); + it('should not truncate selectedText if it is within the max length', () => { + const shortText = 'a'.repeat(IDE_MAX_SELECTED_TEXT_LENGTH); + const context: IdeContext = { + workspaceState: { + openFiles: [ + { + path: 'file1.ts', + timestamp: 100, + isActive: true, + selectedText: shortText, + }, + ], + }, + }; + ideContextStore.set(context); + const selectedText = + ideContextStore.get()?.workspaceState?.openFiles?.[0]?.selectedText; + expect(selectedText).toBe(shortText); + }); - expect(ideContext.getIdeContext()).toBeUndefined(); + it('should truncate the openFiles list if it exceeds the max length', () => { + const files: File[] = Array.from( + { length: IDE_MAX_OPEN_FILES + 5 }, + (_, i) => ({ + path: `file${i}.ts`, + timestamp: i, + isActive: false, + }), + ); + const context: IdeContext = { + workspaceState: { + openFiles: files, + }, + }; + ideContextStore.set(context); + const openFiles = ideContextStore.get()?.workspaceState?.openFiles; + expect(openFiles).toHaveLength(IDE_MAX_OPEN_FILES); }); }); diff --git a/packages/core/src/ide/ideContext.ts b/packages/core/src/ide/ideContext.ts index 9689c6323e7..cb57b99f4b5 100644 --- a/packages/core/src/ide/ideContext.ts +++ b/packages/core/src/ide/ideContext.ts @@ -5,42 +5,11 @@ */ import { z } from 'zod'; - -/** - * Zod schema for validating a file context from the IDE. 
- */ -export const FileSchema = z.object({ - path: z.string(), - timestamp: z.number(), - isActive: z.boolean().optional(), - selectedText: z.string().optional(), - cursor: z - .object({ - line: z.number(), - character: z.number(), - }) - .optional(), -}); -export type File = z.infer; - -export const IdeContextSchema = z.object({ - workspaceState: z - .object({ - openFiles: z.array(FileSchema).optional(), - isTrusted: z.boolean().optional(), - }) - .optional(), -}); -export type IdeContext = z.infer; - -/** - * Zod schema for validating the 'ide/contextUpdate' notification from the IDE. - */ -export const IdeContextNotificationSchema = z.object({ - jsonrpc: z.literal('2.0'), - method: z.literal('ide/contextUpdate'), - params: IdeContextSchema, -}); +import { + IDE_MAX_OPEN_FILES, + IDE_MAX_SELECTED_TEXT_LENGTH, +} from './constants.js'; +import type { IdeContext } from './types.js'; export const IdeDiffAcceptedNotificationSchema = z.object({ jsonrpc: z.literal('2.0'), @@ -100,25 +69,18 @@ export type DiffUpdateResult = content: undefined; }; -type IdeContextSubscriber = (ideContext: IdeContext | undefined) => void; +type IdeContextSubscriber = (ideContext?: IdeContext) => void; -/** - * Creates a new store for managing the IDE's context. - * This factory function encapsulates the state and logic, allowing for the creation - * of isolated instances, which is particularly useful for testing. - * - * @returns An object with methods to interact with the IDE context. - */ -export function createIdeContextStore() { - let ideContextState: IdeContext | undefined = undefined; - const subscribers = new Set(); +export class IdeContextStore { + private ideContextState?: IdeContext; + private readonly subscribers = new Set(); /** * Notifies all registered subscribers about the current IDE context. 
*/ - function notifySubscribers(): void { - for (const subscriber of subscribers) { - subscriber(ideContextState); + private notifySubscribers(): void { + for (const subscriber of this.subscribers) { + subscriber(this.ideContextState); } } @@ -126,25 +88,76 @@ export function createIdeContextStore() { * Sets the IDE context and notifies all registered subscribers of the change. * @param newIdeContext The new IDE context from the IDE. */ - function setIdeContext(newIdeContext: IdeContext): void { - ideContextState = newIdeContext; - notifySubscribers(); + set(newIdeContext: IdeContext): void { + const { workspaceState } = newIdeContext; + if (!workspaceState) { + this.ideContextState = newIdeContext; + this.notifySubscribers(); + return; + } + + const { openFiles } = workspaceState; + + if (openFiles && openFiles.length > 0) { + // Sort by timestamp descending (newest first) + openFiles.sort((a, b) => b.timestamp - a.timestamp); + + // The most recent file is now at index 0. + const mostRecentFile = openFiles[0]; + + // If the most recent file is not active, then no file is active. + if (!mostRecentFile.isActive) { + openFiles.forEach((file) => { + file.isActive = false; + file.cursor = undefined; + file.selectedText = undefined; + }); + } else { + // The most recent file is active. Ensure it's the only one. + openFiles.forEach((file, index: number) => { + if (index !== 0) { + file.isActive = false; + file.cursor = undefined; + file.selectedText = undefined; + } + }); + + // Truncate selected text in the active file + if ( + mostRecentFile.selectedText && + mostRecentFile.selectedText.length > IDE_MAX_SELECTED_TEXT_LENGTH + ) { + mostRecentFile.selectedText = + mostRecentFile.selectedText.substring( + 0, + IDE_MAX_SELECTED_TEXT_LENGTH, + ) + '... 
[TRUNCATED]'; + } + } + + // Truncate files list + if (openFiles.length > IDE_MAX_OPEN_FILES) { + workspaceState.openFiles = openFiles.slice(0, IDE_MAX_OPEN_FILES); + } + } + this.ideContextState = newIdeContext; + this.notifySubscribers(); } /** * Clears the IDE context and notifies all registered subscribers of the change. */ - function clearIdeContext(): void { - ideContextState = undefined; - notifySubscribers(); + clear(): void { + this.ideContextState = undefined; + this.notifySubscribers(); } /** * Retrieves the current IDE context. * @returns The `IdeContext` object if a file is active; otherwise, `undefined`. */ - function getIdeContext(): IdeContext | undefined { - return ideContextState; + get(): IdeContext | undefined { + return this.ideContextState; } /** @@ -156,22 +169,15 @@ export function createIdeContextStore() { * @param subscriber The function to be called when the IDE context changes. * @returns A function that, when called, will unsubscribe the provided subscriber. */ - function subscribeToIdeContext(subscriber: IdeContextSubscriber): () => void { - subscribers.add(subscriber); + subscribe(subscriber: IdeContextSubscriber): () => void { + this.subscribers.add(subscriber); return () => { - subscribers.delete(subscriber); + this.subscribers.delete(subscriber); }; } - - return { - setIdeContext, - getIdeContext, - subscribeToIdeContext, - clearIdeContext, - }; } /** * The default, shared instance of the IDE context store for the application. 
*/ -export const ideContext = createIdeContextStore(); +export const ideContextStore = new IdeContextStore(); diff --git a/packages/core/src/ide/process-utils.ts b/packages/core/src/ide/process-utils.ts index ecb93781cc1..170b1df1878 100644 --- a/packages/core/src/ide/process-utils.ts +++ b/packages/core/src/ide/process-utils.ts @@ -44,11 +44,18 @@ async function getProcessInfo(pid: number): Promise<{ ParentProcessId = 0, CommandLine = '', } = JSON.parse(output); - return { parentPid: ParentProcessId, name: Name, command: CommandLine }; + return { + parentPid: ParentProcessId, + name: Name, + command: CommandLine ?? '', + }; } else { const command = `ps -o ppid=,command= -p ${pid}`; const { stdout } = await execAsync(command); const trimmedStdout = stdout.trim(); + if (!trimmedStdout) { + return { parentPid: 0, name: '', command: '' }; + } const ppidString = trimmedStdout.split(/\s+/)[0]; const parentPid = parseInt(ppidString, 10); const fullCommand = trimmedStdout.substring(ppidString.length).trim(); diff --git a/packages/core/src/ide/types.ts b/packages/core/src/ide/types.ts new file mode 100644 index 00000000000..69310eea656 --- /dev/null +++ b/packages/core/src/ide/types.ts @@ -0,0 +1,70 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { z } from 'zod'; + +/** + * A file that is open in the IDE. + */ +export const FileSchema = z.object({ + /** + * The absolute path to the file. + */ + path: z.string(), + /** + * The unix timestamp of when the file was last focused. + */ + timestamp: z.number(), + /** + * Whether the file is the currently active file. Only one file can be active at a time. + */ + isActive: z.boolean().optional(), + /** + * The text that is currently selected in the active file. + */ + selectedText: z.string().optional(), + /** + * The cursor position in the active file. + */ + cursor: z + .object({ + /** + * The 1-based line number. 
+ */ + line: z.number(), + /** + * The 1-based character offset. + */ + character: z.number(), + }) + .optional(), +}); +export type File = z.infer; + +/** + * The context of the IDE. + */ +export const IdeContextSchema = z.object({ + workspaceState: z + .object({ + /** + * The list of files that are currently open. + */ + openFiles: z.array(FileSchema).optional(), + /** + * Whether the workspace is trusted. + */ + isTrusted: z.boolean().optional(), + }) + .optional(), +}); +export type IdeContext = z.infer; + +export const IdeContextNotificationSchema = z.object({ + jsonrpc: z.literal('2.0'), + method: z.literal('ide/contextUpdate'), + params: IdeContextSchema, +}); diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 4b8d3aa3e85..3060c074818 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -6,6 +6,9 @@ // Export config export * from './config/config.js'; +export * from './output/types.js'; +export * from './output/json-formatter.js'; +export * from './output/stream-json-formatter.js'; // Export Core Logic export * from './core/client.js'; @@ -20,6 +23,8 @@ export * from './core/geminiRequest.js'; export * from './core/coreToolScheduler.js'; export * from './core/nonInteractiveToolExecutor.js'; +export * from './fallback/types.js'; + export * from './code_assist/codeAssist.js'; export * from './code_assist/oauth2.js'; export * from './code_assist/server.js'; @@ -38,6 +43,7 @@ export * from './utils/quotaErrorDetection.js'; export * from './utils/fileUtils.js'; export * from './utils/retry.js'; export * from './utils/shell-utils.js'; +export * from './utils/terminalSerializer.js'; export * from './utils/systemEncoding.js'; export * from './utils/textUtils.js'; export * from './utils/formatters.js'; @@ -47,7 +53,7 @@ export * from './utils/errorParsing.js'; export * from './utils/workspaceContext.js'; export * from './utils/ignorePatterns.js'; export * from './utils/partUtils.js'; -export * from 
'./utils/ide-trust.js'; +export * from './utils/promptIdContext.js'; // Export services export * from './services/fileDiscoveryService.js'; @@ -62,6 +68,7 @@ export * from './ide/ide-installer.js'; export { getIdeInfo, DetectedIde } from './ide/detect-ide.js'; export { type IdeInfo } from './ide/detect-ide.js'; export * from './ide/constants.js'; +export * from './ide/types.js'; // Export Shell Execution Service export * from './services/shellExecutionService.js'; diff --git a/packages/core/src/output/json-formatter.test.ts b/packages/core/src/output/json-formatter.test.ts new file mode 100644 index 00000000000..587030a9807 --- /dev/null +++ b/packages/core/src/output/json-formatter.test.ts @@ -0,0 +1,301 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { expect, describe, it } from 'vitest'; +import type { SessionMetrics } from '../telemetry/uiTelemetry.js'; +import { JsonFormatter } from './json-formatter.js'; +import type { JsonError } from './types.js'; + +describe('JsonFormatter', () => { + it('should format the response as JSON', () => { + const formatter = new JsonFormatter(); + const response = 'This is a test response.'; + const formatted = formatter.format(response); + const expected = { + response, + }; + expect(JSON.parse(formatted)).toEqual(expected); + }); + + it('should strip ANSI escape sequences from response text', () => { + const formatter = new JsonFormatter(); + const responseWithAnsi = + '\x1B[31mRed text\x1B[0m and \x1B[32mGreen text\x1B[0m'; + const formatted = formatter.format(responseWithAnsi); + const parsed = JSON.parse(formatted); + expect(parsed.response).toBe('Red text and Green text'); + }); + + it('should strip control characters from response text', () => { + const formatter = new JsonFormatter(); + const responseWithControlChars = + 'Text with\x07 bell\x08 and\x0B vertical tab'; + const formatted = formatter.format(responseWithControlChars); + const parsed = 
JSON.parse(formatted); + // Only ANSI codes are stripped, other control chars are preserved + expect(parsed.response).toBe('Text with\x07 bell\x08 and\x0B vertical tab'); + }); + + it('should preserve newlines and tabs in response text', () => { + const formatter = new JsonFormatter(); + const responseWithWhitespace = 'Line 1\nLine 2\r\nLine 3\twith tab'; + const formatted = formatter.format(responseWithWhitespace); + const parsed = JSON.parse(formatted); + expect(parsed.response).toBe('Line 1\nLine 2\r\nLine 3\twith tab'); + }); + + it('should format the response as JSON with stats', () => { + const formatter = new JsonFormatter(); + const response = 'This is a test response.'; + const stats: SessionMetrics = { + models: { + 'gemini-2.5-pro': { + api: { + totalRequests: 2, + totalErrors: 0, + totalLatencyMs: 5672, + }, + tokens: { + prompt: 24401, + candidates: 215, + total: 24719, + cached: 10656, + thoughts: 103, + tool: 0, + }, + }, + 'gemini-2.5-flash': { + api: { + totalRequests: 2, + totalErrors: 0, + totalLatencyMs: 5914, + }, + tokens: { + prompt: 20803, + candidates: 716, + total: 21657, + cached: 0, + thoughts: 138, + tool: 0, + }, + }, + }, + tools: { + totalCalls: 1, + totalSuccess: 1, + totalFail: 0, + totalDurationMs: 4582, + totalDecisions: { + accept: 0, + reject: 0, + modify: 0, + auto_accept: 1, + }, + byName: { + google_web_search: { + count: 1, + success: 1, + fail: 0, + durationMs: 4582, + decisions: { + accept: 0, + reject: 0, + modify: 0, + auto_accept: 1, + }, + }, + }, + }, + files: { + totalLinesAdded: 0, + totalLinesRemoved: 0, + }, + }; + const formatted = formatter.format(response, stats); + const expected = { + response, + stats, + }; + expect(JSON.parse(formatted)).toEqual(expected); + }); + + it('should format error as JSON', () => { + const formatter = new JsonFormatter(); + const error: JsonError = { + type: 'ValidationError', + message: 'Invalid input provided', + code: 400, + }; + const formatted = formatter.format(undefined, 
undefined, error); + const expected = { + error, + }; + expect(JSON.parse(formatted)).toEqual(expected); + }); + + it('should format response with error as JSON', () => { + const formatter = new JsonFormatter(); + const response = 'Partial response'; + const error: JsonError = { + type: 'TimeoutError', + message: 'Request timed out', + code: 'TIMEOUT', + }; + const formatted = formatter.format(response, undefined, error); + const expected = { + response, + error, + }; + expect(JSON.parse(formatted)).toEqual(expected); + }); + + it('should format error using formatError method', () => { + const formatter = new JsonFormatter(); + const error = new Error('Something went wrong'); + const formatted = formatter.formatError(error, 500); + const parsed = JSON.parse(formatted); + + expect(parsed).toEqual({ + error: { + type: 'Error', + message: 'Something went wrong', + code: 500, + }, + }); + }); + + it('should format custom error using formatError method', () => { + class CustomError extends Error { + constructor(message: string) { + super(message); + this.name = 'CustomError'; + } + } + + const formatter = new JsonFormatter(); + const error = new CustomError('Custom error occurred'); + const formatted = formatter.formatError(error); + const parsed = JSON.parse(formatted); + + expect(parsed).toEqual({ + error: { + type: 'CustomError', + message: 'Custom error occurred', + }, + }); + }); + + it('should format complete JSON output with response, stats, and error', () => { + const formatter = new JsonFormatter(); + const response = 'Partial response before error'; + const stats: SessionMetrics = { + models: {}, + tools: { + totalCalls: 0, + totalSuccess: 0, + totalFail: 1, + totalDurationMs: 0, + totalDecisions: { + accept: 0, + reject: 0, + modify: 0, + auto_accept: 0, + }, + byName: {}, + }, + files: { + totalLinesAdded: 0, + totalLinesRemoved: 0, + }, + }; + const error: JsonError = { + type: 'ApiError', + message: 'Rate limit exceeded', + code: 429, + }; + + const 
formatted = formatter.format(response, stats, error); + const expected = { + response, + stats, + error, + }; + expect(JSON.parse(formatted)).toEqual(expected); + }); + + it('should handle error messages containing JSON content', () => { + const formatter = new JsonFormatter(); + const errorWithJson = new Error( + 'API returned: {"error": "Invalid request", "code": 400}', + ); + const formatted = formatter.formatError(errorWithJson, 'API_ERROR'); + const parsed = JSON.parse(formatted); + + expect(parsed).toEqual({ + error: { + type: 'Error', + message: 'API returned: {"error": "Invalid request", "code": 400}', + code: 'API_ERROR', + }, + }); + + // Verify the entire output is valid JSON + expect(() => JSON.parse(formatted)).not.toThrow(); + }); + + it('should handle error messages with quotes and special characters', () => { + const formatter = new JsonFormatter(); + const errorWithQuotes = new Error('Error: "quoted text" and \\backslash'); + const formatted = formatter.formatError(errorWithQuotes); + const parsed = JSON.parse(formatted); + + expect(parsed).toEqual({ + error: { + type: 'Error', + message: 'Error: "quoted text" and \\backslash', + }, + }); + + // Verify the entire output is valid JSON + expect(() => JSON.parse(formatted)).not.toThrow(); + }); + + it('should handle error messages with control characters', () => { + const formatter = new JsonFormatter(); + const errorWithControlChars = new Error('Error with\n newline and\t tab'); + const formatted = formatter.formatError(errorWithControlChars); + const parsed = JSON.parse(formatted); + + // Should preserve newlines and tabs as they are common whitespace characters + expect(parsed.error.message).toBe('Error with\n newline and\t tab'); + + // Verify the entire output is valid JSON + expect(() => JSON.parse(formatted)).not.toThrow(); + }); + + it('should strip ANSI escape sequences from error messages', () => { + const formatter = new JsonFormatter(); + const errorWithAnsi = new Error('\x1B[31mRed 
error\x1B[0m message'); + const formatted = formatter.formatError(errorWithAnsi); + const parsed = JSON.parse(formatted); + + expect(parsed.error.message).toBe('Red error message'); + expect(() => JSON.parse(formatted)).not.toThrow(); + }); + + it('should strip unsafe control characters from error messages', () => { + const formatter = new JsonFormatter(); + const errorWithControlChars = new Error( + 'Error\x07 with\x08 control\x0B chars', + ); + const formatted = formatter.formatError(errorWithControlChars); + const parsed = JSON.parse(formatted); + + // Only ANSI codes are stripped, other control chars are preserved + expect(parsed.error.message).toBe('Error\x07 with\x08 control\x0B chars'); + expect(() => JSON.parse(formatted)).not.toThrow(); + }); +}); diff --git a/packages/core/src/output/json-formatter.ts b/packages/core/src/output/json-formatter.ts new file mode 100644 index 00000000000..83ea3e3862f --- /dev/null +++ b/packages/core/src/output/json-formatter.ts @@ -0,0 +1,39 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import stripAnsi from 'strip-ansi'; +import type { SessionMetrics } from '../telemetry/uiTelemetry.js'; +import type { JsonError, JsonOutput } from './types.js'; + +export class JsonFormatter { + format(response?: string, stats?: SessionMetrics, error?: JsonError): string { + const output: JsonOutput = {}; + + if (response !== undefined) { + output.response = stripAnsi(response); + } + + if (stats) { + output.stats = stats; + } + + if (error) { + output.error = error; + } + + return JSON.stringify(output, null, 2); + } + + formatError(error: Error, code?: string | number): string { + const jsonError: JsonError = { + type: error.constructor.name, + message: stripAnsi(error.message), + ...(code && { code }), + }; + + return this.format(undefined, undefined, jsonError); + } +} diff --git a/packages/core/src/output/stream-json-formatter.ts b/packages/core/src/output/stream-json-formatter.ts new 
file mode 100644 index 00000000000..25256912ab7 --- /dev/null +++ b/packages/core/src/output/stream-json-formatter.ts @@ -0,0 +1,77 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import stripAnsi from 'strip-ansi'; +import type { SessionMetrics } from '../telemetry/uiTelemetry.js'; +import type { JsonError } from './types.js'; +import type { TelemetryEvent } from '../telemetry/types.js'; + +export interface StreamJsonTelemetryBlock { + type: 'telemetry'; + event: TelemetryEvent; +} + +export interface StreamJsonContentBlock { + type: 'content'; + content: string; +} + +export interface StreamJsonFinalBlock { + type: 'final'; + response?: string; + stats?: SessionMetrics; + error?: JsonError; +} + +export type StreamJsonBlock = StreamJsonTelemetryBlock | StreamJsonContentBlock | StreamJsonFinalBlock; + +export class StreamJsonFormatter { + formatTelemetryBlock(event: TelemetryEvent): string { + const block: StreamJsonTelemetryBlock = { + type: 'telemetry', + event, + }; + return JSON.stringify(block); + } + + formatContentBlock(content: string): string { + const block: StreamJsonContentBlock = { + type: 'content', + content: stripAnsi(content), + }; + return JSON.stringify(block); + } + + formatFinalBlock(response?: string, stats?: SessionMetrics, error?: JsonError): string { + const block: StreamJsonFinalBlock = { + type: 'final', + }; + + if (response !== undefined) { + block.response = stripAnsi(response); + } + + if (stats) { + block.stats = stats; + } + + if (error) { + block.error = error; + } + + return JSON.stringify(block); + } + + formatError(error: Error, code?: string | number): string { + const jsonError: JsonError = { + type: error.constructor.name, + message: stripAnsi(error.message), + ...(code && { code }), + }; + + return this.formatFinalBlock(undefined, undefined, jsonError); + } +} \ No newline at end of file diff --git a/packages/core/src/output/types.ts b/packages/core/src/output/types.ts new 
file mode 100644 index 00000000000..0c7593dd4ba --- /dev/null +++ b/packages/core/src/output/types.ts @@ -0,0 +1,25 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { SessionMetrics } from '../telemetry/uiTelemetry.js'; + +export enum OutputFormat { + TEXT = 'text', + JSON = 'json', + STREAM_JSON = 'stream-json', +} + +export interface JsonError { + type: string; + message: string; + code?: string | number; +} + +export interface JsonOutput { + response?: string; + stats?: SessionMetrics; + error?: JsonError; +} diff --git a/packages/core/src/policy/index.ts b/packages/core/src/policy/index.ts new file mode 100644 index 00000000000..e15309ca69b --- /dev/null +++ b/packages/core/src/policy/index.ts @@ -0,0 +1,8 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +export * from './policy-engine.js'; +export * from './types.js'; diff --git a/packages/core/src/policy/policy-engine.test.ts b/packages/core/src/policy/policy-engine.test.ts new file mode 100644 index 00000000000..51f222b2e42 --- /dev/null +++ b/packages/core/src/policy/policy-engine.test.ts @@ -0,0 +1,624 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, beforeEach } from 'vitest'; +import { PolicyEngine } from './policy-engine.js'; +import { + PolicyDecision, + type PolicyRule, + type PolicyEngineConfig, +} from './types.js'; +import type { FunctionCall } from '@google/genai'; + +describe('PolicyEngine', () => { + let engine: PolicyEngine; + + beforeEach(() => { + engine = new PolicyEngine(); + }); + + describe('constructor', () => { + it('should use default config when none provided', () => { + const decision = engine.check({ name: 'test' }); + expect(decision).toBe(PolicyDecision.ASK_USER); + }); + + it('should respect custom default decision', () => { + engine = new PolicyEngine({ defaultDecision: PolicyDecision.DENY 
}); + const decision = engine.check({ name: 'test' }); + expect(decision).toBe(PolicyDecision.DENY); + }); + + it('should sort rules by priority', () => { + const rules: PolicyRule[] = [ + { toolName: 'tool1', decision: PolicyDecision.DENY, priority: 1 }, + { toolName: 'tool2', decision: PolicyDecision.ALLOW, priority: 10 }, + { toolName: 'tool3', decision: PolicyDecision.ASK_USER, priority: 5 }, + ]; + + engine = new PolicyEngine({ rules }); + const sortedRules = engine.getRules(); + + expect(sortedRules[0].priority).toBe(10); + expect(sortedRules[1].priority).toBe(5); + expect(sortedRules[2].priority).toBe(1); + }); + }); + + describe('check', () => { + it('should match tool by name', () => { + const rules: PolicyRule[] = [ + { toolName: 'shell', decision: PolicyDecision.ALLOW }, + { toolName: 'edit', decision: PolicyDecision.DENY }, + ]; + + engine = new PolicyEngine({ rules }); + + expect(engine.check({ name: 'shell' })).toBe(PolicyDecision.ALLOW); + expect(engine.check({ name: 'edit' })).toBe(PolicyDecision.DENY); + expect(engine.check({ name: 'other' })).toBe(PolicyDecision.ASK_USER); + }); + + it('should match by args pattern', () => { + const rules: PolicyRule[] = [ + { + toolName: 'shell', + argsPattern: /rm -rf/, + decision: PolicyDecision.DENY, + }, + { + toolName: 'shell', + decision: PolicyDecision.ALLOW, + }, + ]; + + engine = new PolicyEngine({ rules }); + + const dangerousCall: FunctionCall = { + name: 'shell', + args: { command: 'rm -rf /' }, + }; + + const safeCall: FunctionCall = { + name: 'shell', + args: { command: 'ls -la' }, + }; + + expect(engine.check(dangerousCall)).toBe(PolicyDecision.DENY); + expect(engine.check(safeCall)).toBe(PolicyDecision.ALLOW); + }); + + it('should apply rules by priority', () => { + const rules: PolicyRule[] = [ + { toolName: 'shell', decision: PolicyDecision.DENY, priority: 1 }, + { toolName: 'shell', decision: PolicyDecision.ALLOW, priority: 10 }, + ]; + + engine = new PolicyEngine({ rules }); + + // Higher 
priority rule (ALLOW) should win + expect(engine.check({ name: 'shell' })).toBe(PolicyDecision.ALLOW); + }); + + it('should apply wildcard rules (no toolName)', () => { + const rules: PolicyRule[] = [ + { decision: PolicyDecision.DENY }, // Applies to all tools + { toolName: 'safe-tool', decision: PolicyDecision.ALLOW, priority: 10 }, + ]; + + engine = new PolicyEngine({ rules }); + + expect(engine.check({ name: 'safe-tool' })).toBe(PolicyDecision.ALLOW); + expect(engine.check({ name: 'any-other-tool' })).toBe( + PolicyDecision.DENY, + ); + }); + + it('should handle non-interactive mode', () => { + const config: PolicyEngineConfig = { + nonInteractive: true, + rules: [ + { toolName: 'interactive-tool', decision: PolicyDecision.ASK_USER }, + { toolName: 'allowed-tool', decision: PolicyDecision.ALLOW }, + ], + }; + + engine = new PolicyEngine(config); + + // ASK_USER should become DENY in non-interactive mode + expect(engine.check({ name: 'interactive-tool' })).toBe( + PolicyDecision.DENY, + ); + // ALLOW should remain ALLOW + expect(engine.check({ name: 'allowed-tool' })).toBe(PolicyDecision.ALLOW); + // Default ASK_USER should also become DENY + expect(engine.check({ name: 'unknown-tool' })).toBe(PolicyDecision.DENY); + }); + }); + + describe('addRule', () => { + it('should add a new rule and maintain priority order', () => { + engine.addRule({ + toolName: 'tool1', + decision: PolicyDecision.ALLOW, + priority: 5, + }); + engine.addRule({ + toolName: 'tool2', + decision: PolicyDecision.DENY, + priority: 10, + }); + engine.addRule({ + toolName: 'tool3', + decision: PolicyDecision.ASK_USER, + priority: 1, + }); + + const rules = engine.getRules(); + expect(rules).toHaveLength(3); + expect(rules[0].priority).toBe(10); + expect(rules[1].priority).toBe(5); + expect(rules[2].priority).toBe(1); + }); + + it('should apply newly added rules', () => { + expect(engine.check({ name: 'new-tool' })).toBe(PolicyDecision.ASK_USER); + + engine.addRule({ toolName: 'new-tool', 
decision: PolicyDecision.ALLOW }); + + expect(engine.check({ name: 'new-tool' })).toBe(PolicyDecision.ALLOW); + }); + }); + + describe('removeRulesForTool', () => { + it('should remove rules for specific tool', () => { + engine.addRule({ toolName: 'tool1', decision: PolicyDecision.ALLOW }); + engine.addRule({ toolName: 'tool2', decision: PolicyDecision.DENY }); + engine.addRule({ + toolName: 'tool1', + decision: PolicyDecision.ASK_USER, + priority: 10, + }); + + expect(engine.getRules()).toHaveLength(3); + + engine.removeRulesForTool('tool1'); + + const remainingRules = engine.getRules(); + expect(remainingRules).toHaveLength(1); + expect(remainingRules.some((r) => r.toolName === 'tool1')).toBe(false); + expect(remainingRules.some((r) => r.toolName === 'tool2')).toBe(true); + }); + + it('should handle removing non-existent tool', () => { + engine.addRule({ toolName: 'existing', decision: PolicyDecision.ALLOW }); + + expect(() => engine.removeRulesForTool('non-existent')).not.toThrow(); + expect(engine.getRules()).toHaveLength(1); + }); + }); + + describe('getRules', () => { + it('should return readonly array of rules', () => { + const rules: PolicyRule[] = [ + { toolName: 'tool1', decision: PolicyDecision.ALLOW }, + { toolName: 'tool2', decision: PolicyDecision.DENY }, + ]; + + engine = new PolicyEngine({ rules }); + + const retrievedRules = engine.getRules(); + expect(retrievedRules).toHaveLength(2); + expect(retrievedRules[0].toolName).toBe('tool1'); + expect(retrievedRules[1].toolName).toBe('tool2'); + }); + }); + + describe('complex scenarios', () => { + it('should handle multiple matching rules with different priorities', () => { + const rules: PolicyRule[] = [ + { decision: PolicyDecision.DENY, priority: 0 }, // Default deny all + { toolName: 'shell', decision: PolicyDecision.ASK_USER, priority: 5 }, + { + toolName: 'shell', + argsPattern: /"command":"ls/, + decision: PolicyDecision.ALLOW, + priority: 10, + }, + ]; + + engine = new PolicyEngine({ rules }); + 
+ // Matches highest priority rule (ls command) + expect(engine.check({ name: 'shell', args: { command: 'ls -la' } })).toBe( + PolicyDecision.ALLOW, + ); + + // Matches middle priority rule (shell without ls) + expect(engine.check({ name: 'shell', args: { command: 'pwd' } })).toBe( + PolicyDecision.ASK_USER, + ); + + // Matches lowest priority rule (not shell) + expect(engine.check({ name: 'edit' })).toBe(PolicyDecision.DENY); + }); + + it('should handle tools with no args', () => { + const rules: PolicyRule[] = [ + { + toolName: 'read', + argsPattern: /secret/, + decision: PolicyDecision.DENY, + }, + ]; + + engine = new PolicyEngine({ rules }); + + // Tool call without args should not match pattern + expect(engine.check({ name: 'read' })).toBe(PolicyDecision.ASK_USER); + + // Tool call with args not matching pattern + expect(engine.check({ name: 'read', args: { file: 'public.txt' } })).toBe( + PolicyDecision.ASK_USER, + ); + + // Tool call with args matching pattern + expect(engine.check({ name: 'read', args: { file: 'secret.txt' } })).toBe( + PolicyDecision.DENY, + ); + }); + + it('should match args pattern regardless of property order', () => { + const rules: PolicyRule[] = [ + { + toolName: 'shell', + // Pattern matches the stable stringified format + argsPattern: /"command":"rm[^"]*-rf/, + decision: PolicyDecision.DENY, + }, + ]; + + engine = new PolicyEngine({ rules }); + + // Same args with different property order should both match + const args1 = { command: 'rm -rf /', path: '/home' }; + const args2 = { path: '/home', command: 'rm -rf /' }; + + expect(engine.check({ name: 'shell', args: args1 })).toBe( + PolicyDecision.DENY, + ); + expect(engine.check({ name: 'shell', args: args2 })).toBe( + PolicyDecision.DENY, + ); + + // Verify safe command doesn't match + const safeArgs = { command: 'ls -la', path: '/home' }; + expect(engine.check({ name: 'shell', args: safeArgs })).toBe( + PolicyDecision.ASK_USER, + ); + }); + + it('should handle nested objects in 
args with stable stringification', () => { + const rules: PolicyRule[] = [ + { + toolName: 'api', + argsPattern: /"sensitive":true/, + decision: PolicyDecision.DENY, + }, + ]; + + engine = new PolicyEngine({ rules }); + + // Nested objects with different key orders should match consistently + const args1 = { + data: { sensitive: true, value: 'secret' }, + method: 'POST', + }; + const args2 = { + method: 'POST', + data: { value: 'secret', sensitive: true }, + }; + + expect(engine.check({ name: 'api', args: args1 })).toBe( + PolicyDecision.DENY, + ); + expect(engine.check({ name: 'api', args: args2 })).toBe( + PolicyDecision.DENY, + ); + }); + + it('should handle circular references without stack overflow', () => { + const rules: PolicyRule[] = [ + { + toolName: 'test', + argsPattern: /\[Circular\]/, + decision: PolicyDecision.DENY, + }, + ]; + + engine = new PolicyEngine({ rules }); + + // Create an object with a circular reference + type CircularArgs = Record & { + data?: Record; + }; + const circularArgs: CircularArgs = { + name: 'test', + data: {}, + }; + // Create circular reference - TypeScript allows this since data is Record + (circularArgs.data as Record)['self'] = + circularArgs.data; + + // Should not throw stack overflow error + expect(() => + engine.check({ name: 'test', args: circularArgs }), + ).not.toThrow(); + + // Should detect the circular reference pattern + expect(engine.check({ name: 'test', args: circularArgs })).toBe( + PolicyDecision.DENY, + ); + + // Non-circular object should not match + const normalArgs = { name: 'test', data: { value: 'normal' } }; + expect(engine.check({ name: 'test', args: normalArgs })).toBe( + PolicyDecision.ASK_USER, + ); + }); + + it('should handle deep circular references', () => { + const rules: PolicyRule[] = [ + { + toolName: 'deep', + argsPattern: /\[Circular\]/, + decision: PolicyDecision.DENY, + }, + ]; + + engine = new PolicyEngine({ rules }); + + // Create a deep circular reference + type DeepCircular = 
Record & { + level1?: { + level2?: { + level3?: Record; + }; + }; + }; + const deepCircular: DeepCircular = { + level1: { + level2: { + level3: {}, + }, + }, + }; + // Create circular reference with proper type assertions + const level3 = deepCircular.level1!.level2!.level3!; + level3['back'] = deepCircular.level1; + + // Should handle without stack overflow + expect(() => + engine.check({ name: 'deep', args: deepCircular }), + ).not.toThrow(); + + // Should detect the circular reference + expect(engine.check({ name: 'deep', args: deepCircular })).toBe( + PolicyDecision.DENY, + ); + }); + + it('should handle repeated non-circular objects correctly', () => { + const rules: PolicyRule[] = [ + { + toolName: 'test', + argsPattern: /\[Circular\]/, + decision: PolicyDecision.DENY, + }, + { + toolName: 'test', + argsPattern: /"value":"shared"/, + decision: PolicyDecision.ALLOW, + priority: 10, + }, + ]; + + engine = new PolicyEngine({ rules }); + + // Create an object with repeated references but no cycles + const sharedObj = { value: 'shared' }; + const args = { + first: sharedObj, + second: sharedObj, + third: { nested: sharedObj }, + }; + + // Should NOT mark repeated objects as circular, and should match the shared value pattern + expect(engine.check({ name: 'test', args })).toBe(PolicyDecision.ALLOW); + }); + + it('should omit undefined and function values from objects', () => { + const rules: PolicyRule[] = [ + { + toolName: 'test', + argsPattern: /"definedValue":"test"/, + decision: PolicyDecision.ALLOW, + }, + ]; + + engine = new PolicyEngine({ rules }); + + const args = { + definedValue: 'test', + undefinedValue: undefined, + functionValue: () => 'hello', + nullValue: null, + }; + + // Should match pattern with defined value, undefined and functions omitted + expect(engine.check({ name: 'test', args })).toBe(PolicyDecision.ALLOW); + + // Check that the pattern would NOT match if undefined was included + const rulesWithUndefined: PolicyRule[] = [ + { + toolName: 
'test', + argsPattern: /undefinedValue/, + decision: PolicyDecision.DENY, + }, + ]; + engine = new PolicyEngine({ rules: rulesWithUndefined }); + expect(engine.check({ name: 'test', args })).toBe( + PolicyDecision.ASK_USER, + ); + + // Check that the pattern would NOT match if function was included + const rulesWithFunction: PolicyRule[] = [ + { + toolName: 'test', + argsPattern: /functionValue/, + decision: PolicyDecision.DENY, + }, + ]; + engine = new PolicyEngine({ rules: rulesWithFunction }); + expect(engine.check({ name: 'test', args })).toBe( + PolicyDecision.ASK_USER, + ); + }); + + it('should convert undefined and functions to null in arrays', () => { + const rules: PolicyRule[] = [ + { + toolName: 'test', + argsPattern: /\["value",null,null,null\]/, + decision: PolicyDecision.ALLOW, + }, + ]; + + engine = new PolicyEngine({ rules }); + + const args = { + array: ['value', undefined, () => 'hello', null], + }; + + // Should match pattern with undefined and functions converted to null + expect(engine.check({ name: 'test', args })).toBe(PolicyDecision.ALLOW); + }); + + it('should produce valid JSON for all inputs', () => { + const testCases: Array<{ input: Record; desc: string }> = + [ + { input: { simple: 'string' }, desc: 'simple object' }, + { + input: { nested: { deep: { value: 123 } } }, + desc: 'nested object', + }, + { input: { data: [1, 2, 3] }, desc: 'simple array' }, + { input: { mixed: [1, { a: 'b' }, null] }, desc: 'mixed array' }, + { + input: { undef: undefined, func: () => {}, normal: 'value' }, + desc: 'object with undefined and function', + }, + { + input: { data: ['a', undefined, () => {}, null] }, + desc: 'array with undefined and function', + }, + ]; + + for (const { input } of testCases) { + const rules: PolicyRule[] = [ + { + toolName: 'test', + argsPattern: /.*/, + decision: PolicyDecision.ALLOW, + }, + ]; + engine = new PolicyEngine({ rules }); + + // Should not throw when checking (which internally uses stableStringify) + expect(() => 
engine.check({ name: 'test', args: input })).not.toThrow(); + + // The check should succeed + expect(engine.check({ name: 'test', args: input })).toBe( + PolicyDecision.ALLOW, + ); + } + }); + + it('should respect toJSON methods on objects', () => { + const rules: PolicyRule[] = [ + { + toolName: 'test', + argsPattern: /"sanitized":"safe"/, + decision: PolicyDecision.ALLOW, + }, + { + toolName: 'test', + argsPattern: /"dangerous":"data"/, + decision: PolicyDecision.DENY, + }, + ]; + + engine = new PolicyEngine({ rules }); + + // Object with toJSON that sanitizes output + const args = { + data: { + dangerous: 'data', + toJSON: () => ({ sanitized: 'safe' }), + }, + }; + + // Should match the sanitized pattern, not the dangerous one + expect(engine.check({ name: 'test', args })).toBe(PolicyDecision.ALLOW); + }); + + it('should handle toJSON that returns primitives', () => { + const rules: PolicyRule[] = [ + { + toolName: 'test', + argsPattern: /"value":"string-value"/, + decision: PolicyDecision.ALLOW, + }, + ]; + + engine = new PolicyEngine({ rules }); + + const args = { + value: { + complex: 'object', + toJSON: () => 'string-value', + }, + }; + + // toJSON returns a string, which should be properly stringified + expect(engine.check({ name: 'test', args })).toBe(PolicyDecision.ALLOW); + }); + + it('should handle toJSON that throws an error', () => { + const rules: PolicyRule[] = [ + { + toolName: 'test', + argsPattern: /"fallback":"value"/, + decision: PolicyDecision.ALLOW, + }, + ]; + + engine = new PolicyEngine({ rules }); + + const args = { + data: { + fallback: 'value', + toJSON: () => { + throw new Error('toJSON error'); + }, + }, + }; + + // Should fall back to regular object serialization when toJSON throws + expect(engine.check({ name: 'test', args })).toBe(PolicyDecision.ALLOW); + }); + }); +}); diff --git a/packages/core/src/policy/policy-engine.ts b/packages/core/src/policy/policy-engine.ts new file mode 100644 index 00000000000..e1006ffdef1 --- /dev/null 
+++ b/packages/core/src/policy/policy-engine.ts @@ -0,0 +1,107 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { type FunctionCall } from '@google/genai'; +import { + PolicyDecision, + type PolicyEngineConfig, + type PolicyRule, +} from './types.js'; +import { stableStringify } from './stable-stringify.js'; + +function ruleMatches( + rule: PolicyRule, + toolCall: FunctionCall, + stringifiedArgs: string | undefined, +): boolean { + // Check tool name if specified + if (rule.toolName && toolCall.name !== rule.toolName) { + return false; + } + + // Check args pattern if specified + if (rule.argsPattern) { + // If rule has an args pattern but tool has no args, no match + if (!toolCall.args) { + return false; + } + // Use stable JSON stringification with sorted keys to ensure consistent matching + if ( + stringifiedArgs === undefined || + !rule.argsPattern.test(stringifiedArgs) + ) { + return false; + } + } + + return true; +} + +export class PolicyEngine { + private rules: PolicyRule[]; + private readonly defaultDecision: PolicyDecision; + private readonly nonInteractive: boolean; + + constructor(config: PolicyEngineConfig = {}) { + this.rules = (config.rules ?? []).sort( + (a, b) => (b.priority ?? 0) - (a.priority ?? 0), + ); + this.defaultDecision = config.defaultDecision ?? PolicyDecision.ASK_USER; + this.nonInteractive = config.nonInteractive ?? false; + } + + /** + * Check if a tool call is allowed based on the configured policies. 
+ */ + check(toolCall: FunctionCall): PolicyDecision { + let stringifiedArgs: string | undefined; + // Compute stringified args once before the loop + if (toolCall.args && this.rules.some((rule) => rule.argsPattern)) { + stringifiedArgs = stableStringify(toolCall.args); + } + + // Find the first matching rule (already sorted by priority) + for (const rule of this.rules) { + if (ruleMatches(rule, toolCall, stringifiedArgs)) { + return this.applyNonInteractiveMode(rule.decision); + } + } + + // No matching rule found, use default decision + return this.applyNonInteractiveMode(this.defaultDecision); + } + + /** + * Add a new rule to the policy engine. + */ + addRule(rule: PolicyRule): void { + this.rules.push(rule); + // Re-sort rules by priority + this.rules.sort((a, b) => (b.priority ?? 0) - (a.priority ?? 0)); + } + + /** + * Remove rules for a specific tool. + */ + removeRulesForTool(toolName: string): void { + this.rules = this.rules.filter((rule) => rule.toolName !== toolName); + } + + /** + * Get all current rules. + */ + getRules(): readonly PolicyRule[] { + return this.rules; + } + + private applyNonInteractiveMode(decision: PolicyDecision): PolicyDecision { + // In non-interactive mode, ASK_USER becomes DENY + if (this.nonInteractive && decision === PolicyDecision.ASK_USER) { + return PolicyDecision.DENY; + } + return decision; + } +} diff --git a/packages/core/src/policy/stable-stringify.ts b/packages/core/src/policy/stable-stringify.ts new file mode 100644 index 00000000000..78db692eab7 --- /dev/null +++ b/packages/core/src/policy/stable-stringify.ts @@ -0,0 +1,128 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +/** + * Produces a stable, deterministic JSON string representation with sorted keys. + * + * This method is critical for security policy matching. 
It ensures that the same + * object always produces the same string representation, regardless of property + * insertion order, which could vary across different JavaScript engines or + * runtime conditions. + * + * Key behaviors: + * 1. **Sorted Keys**: Object properties are always serialized in alphabetical order, + * ensuring deterministic output for pattern matching. + * + * 2. **Circular Reference Protection**: Uses ancestor chain tracking (not just + * object identity) to detect true circular references while correctly handling + * repeated non-circular object references. Circular references are replaced + * with "[Circular]" to prevent stack overflow attacks. + * + * 3. **JSON Spec Compliance**: + * - undefined values: Omitted from objects, converted to null in arrays + * - Functions: Omitted from objects, converted to null in arrays + * - toJSON methods: Respected and called when present (per JSON.stringify spec) + * + * 4. **Security Considerations**: + * - Prevents DoS via circular references that would cause infinite recursion + * - Ensures consistent policy rule matching by normalizing property order + * - Respects toJSON for objects that sanitize their output + * - Handles toJSON methods that throw errors gracefully + * + * @param obj - The object to stringify (typically toolCall.args) + * @returns A deterministic JSON string representation + * + * @example + * // Different property orders produce the same output: + * stableStringify({b: 2, a: 1}) === stableStringify({a: 1, b: 2}) + * // Returns: '{"a":1,"b":2}' + * + * @example + * // Circular references are handled safely: + * const obj = {a: 1}; + * obj.self = obj; + * stableStringify(obj) + * // Returns: '{"a":1,"self":"[Circular]"}' + * + * @example + * // toJSON methods are respected: + * const obj = { + * sensitive: 'secret', + * toJSON: () => ({ safe: 'data' }) + * }; + * stableStringify(obj) + * // Returns: '{"safe":"data"}' + */ +export function stableStringify(obj: unknown): string { + const 
stringify = (currentObj: unknown, ancestors: Set): string => { + // Handle primitives and null + if (currentObj === undefined) { + return 'null'; // undefined in arrays becomes null in JSON + } + if (currentObj === null) { + return 'null'; + } + if (typeof currentObj === 'function') { + return 'null'; // functions in arrays become null in JSON + } + if (typeof currentObj !== 'object') { + return JSON.stringify(currentObj); + } + + // Check for circular reference (object is in ancestor chain) + if (ancestors.has(currentObj)) { + return '"[Circular]"'; + } + + ancestors.add(currentObj); + + try { + // Check for toJSON method and use it if present + const objWithToJSON = currentObj as { toJSON?: () => unknown }; + if (typeof objWithToJSON.toJSON === 'function') { + try { + const jsonValue = objWithToJSON.toJSON(); + // The result of toJSON needs to be stringified recursively + if (jsonValue === null) { + return 'null'; + } + return stringify(jsonValue, ancestors); + } catch { + // If toJSON throws, treat as a regular object + } + } + + if (Array.isArray(currentObj)) { + const items = currentObj.map((item) => { + // undefined and functions in arrays become null + if (item === undefined || typeof item === 'function') { + return 'null'; + } + return stringify(item, ancestors); + }); + return '[' + items.join(',') + ']'; + } + + // Handle objects - sort keys and filter out undefined/function values + const sortedKeys = Object.keys(currentObj).sort(); + const pairs: string[] = []; + + for (const key of sortedKeys) { + const value = (currentObj as Record)[key]; + // Skip undefined and function values in objects (per JSON spec) + if (value !== undefined && typeof value !== 'function') { + pairs.push(JSON.stringify(key) + ':' + stringify(value, ancestors)); + } + } + + return '{' + pairs.join(',') + '}'; + } finally { + ancestors.delete(currentObj); + } + }; + + return stringify(obj, new Set()); +} diff --git a/packages/core/src/policy/types.ts 
b/packages/core/src/policy/types.ts new file mode 100644 index 00000000000..f20a88e70c6 --- /dev/null +++ b/packages/core/src/policy/types.ts @@ -0,0 +1,55 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +export enum PolicyDecision { + ALLOW = 'allow', + DENY = 'deny', + ASK_USER = 'ask_user', +} + +export interface PolicyRule { + /** + * The name of the tool this rule applies to. + * If undefined, the rule applies to all tools. + */ + toolName?: string; + + /** + * Pattern to match against tool arguments. + * Can be used for more fine-grained control. + */ + argsPattern?: RegExp; + + /** + * The decision to make when this rule matches. + */ + decision: PolicyDecision; + + /** + * Priority of this rule. Higher numbers take precedence. + * Default is 0. + */ + priority?: number; +} + +export interface PolicyEngineConfig { + /** + * List of policy rules to apply. + */ + rules?: PolicyRule[]; + + /** + * Default decision when no rules match. + * Defaults to ASK_USER. + */ + defaultDecision?: PolicyDecision; + + /** + * Whether to allow tools in non-interactive mode. + * When true, ASK_USER decisions become DENY. 
+ */ + nonInteractive?: boolean; +} diff --git a/packages/core/src/routing/modelRouterService.test.ts b/packages/core/src/routing/modelRouterService.test.ts new file mode 100644 index 00000000000..0f83796787d --- /dev/null +++ b/packages/core/src/routing/modelRouterService.test.ts @@ -0,0 +1,96 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { ModelRouterService } from './modelRouterService.js'; +import { Config } from '../config/config.js'; +import type { BaseLlmClient } from '../core/baseLlmClient.js'; +import type { RoutingContext, RoutingDecision } from './routingStrategy.js'; +import { DefaultStrategy } from './strategies/defaultStrategy.js'; +import { CompositeStrategy } from './strategies/compositeStrategy.js'; +import { FallbackStrategy } from './strategies/fallbackStrategy.js'; +import { OverrideStrategy } from './strategies/overrideStrategy.js'; + +vi.mock('../config/config.js'); +vi.mock('../core/baseLlmClient.js'); +vi.mock('./strategies/defaultStrategy.js'); +vi.mock('./strategies/compositeStrategy.js'); +vi.mock('./strategies/fallbackStrategy.js'); +vi.mock('./strategies/overrideStrategy.js'); + +describe('ModelRouterService', () => { + let service: ModelRouterService; + let mockConfig: Config; + let mockBaseLlmClient: BaseLlmClient; + let mockContext: RoutingContext; + let mockCompositeStrategy: CompositeStrategy; + + beforeEach(() => { + vi.clearAllMocks(); + + mockConfig = new Config({} as never); + mockBaseLlmClient = {} as BaseLlmClient; + vi.spyOn(mockConfig, 'getBaseLlmClient').mockReturnValue(mockBaseLlmClient); + + mockCompositeStrategy = new CompositeStrategy( + [new FallbackStrategy(), new OverrideStrategy(), new DefaultStrategy()], + 'agent-router', + ); + vi.mocked(CompositeStrategy).mockImplementation( + () => mockCompositeStrategy, + ); + + service = new ModelRouterService(mockConfig); + + mockContext = { + history: 
[], + request: [{ text: 'test prompt' }], + signal: new AbortController().signal, + }; + }); + + it('should initialize with a CompositeStrategy', () => { + expect(CompositeStrategy).toHaveBeenCalled(); + expect(service['strategy']).toBeInstanceOf(CompositeStrategy); + }); + + it('should initialize the CompositeStrategy with the correct child strategies in order', () => { + // This test relies on the mock implementation detail of the constructor + const compositeStrategyArgs = vi.mocked(CompositeStrategy).mock.calls[0]; + const childStrategies = compositeStrategyArgs[0]; + + expect(childStrategies.length).toBe(3); + expect(childStrategies[0]).toBeInstanceOf(FallbackStrategy); + expect(childStrategies[1]).toBeInstanceOf(OverrideStrategy); + expect(childStrategies[2]).toBeInstanceOf(DefaultStrategy); + expect(compositeStrategyArgs[1]).toBe('agent-router'); + }); + + describe('route()', () => { + it('should delegate routing to the composite strategy', async () => { + const strategyDecision: RoutingDecision = { + model: 'strategy-chosen-model', + metadata: { + source: 'test-router/fallback', + latencyMs: 10, + reasoning: 'Strategy reasoning', + }, + }; + const strategySpy = vi + .spyOn(mockCompositeStrategy, 'route') + .mockResolvedValue(strategyDecision); + + const decision = await service.route(mockContext); + + expect(strategySpy).toHaveBeenCalledWith( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + expect(decision).toEqual(strategyDecision); + }); + }); +}); diff --git a/packages/core/src/routing/modelRouterService.ts b/packages/core/src/routing/modelRouterService.ts new file mode 100644 index 00000000000..a984125f899 --- /dev/null +++ b/packages/core/src/routing/modelRouterService.ts @@ -0,0 +1,54 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Config } from '../config/config.js'; +import type { + RoutingContext, + RoutingDecision, + TerminalStrategy, +} from './routingStrategy.js'; +import { 
DefaultStrategy } from './strategies/defaultStrategy.js'; +import { CompositeStrategy } from './strategies/compositeStrategy.js'; +import { FallbackStrategy } from './strategies/fallbackStrategy.js'; +import { OverrideStrategy } from './strategies/overrideStrategy.js'; + +/** + * A centralized service for making model routing decisions. + */ +export class ModelRouterService { + private config: Config; + private strategy: TerminalStrategy; + + constructor(config: Config) { + this.config = config; + this.strategy = this.initializeDefaultStrategy(); + } + + private initializeDefaultStrategy(): TerminalStrategy { + // Initialize the composite strategy with the desired priority order. + // The strategies are ordered in order of highest priority. + return new CompositeStrategy( + [new FallbackStrategy(), new OverrideStrategy(), new DefaultStrategy()], + 'agent-router', + ); + } + + /** + * Determines which model to use for a given request context. + * + * @param context The full context of the request. + * @returns A promise that resolves to a RoutingDecision. + */ + async route(context: RoutingContext): Promise { + const decision = await this.strategy.route( + context, + this.config, + this.config.getBaseLlmClient(), + ); + + return decision; + } +} diff --git a/packages/core/src/routing/routingStrategy.ts b/packages/core/src/routing/routingStrategy.ts new file mode 100644 index 00000000000..d5d8df8dc9e --- /dev/null +++ b/packages/core/src/routing/routingStrategy.ts @@ -0,0 +1,76 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Content, PartListUnion } from '@google/genai'; +import type { BaseLlmClient } from '../core/baseLlmClient.js'; +import type { Config } from '../config/config.js'; + +/** + * The output of a routing decision. It specifies which model to use and why. + */ +export interface RoutingDecision { + /** The model identifier string to use for the next API call (e.g., 'gemini-2.5-pro'). 
*/ + model: string; + /** + * Metadata about the routing decision for logging purposes. + */ + metadata: { + source: string; + latencyMs: number; + reasoning: string; + error?: string; + }; +} + +/** + * The context provided to the router for making a decision. + */ +export interface RoutingContext { + /** The full history of the conversation. */ + history: Content[]; + /** The immediate request parts to be processed. */ + request: PartListUnion; + /** An abort signal to cancel an LLM call during routing. */ + signal: AbortSignal; +} + +/** + * The core interface that all routing strategies must implement. + * Strategies implementing this interface may decline a request by returning null. + */ +export interface RoutingStrategy { + /** The name of the strategy (e.g., 'fallback', 'override', 'composite'). */ + readonly name: string; + + /** + * Determines which model to use for a given request context. + * @param context The full context of the request. + * @param config The current configuration. + * @param client A reference to the GeminiClient, allowing the strategy to make its own API calls if needed. + * @returns A promise that resolves to a RoutingDecision, or null if the strategy is not applicable. + */ + route( + context: RoutingContext, + config: Config, + baseLlmClient: BaseLlmClient, + ): Promise; +} + +/** + * A strategy that is guaranteed to return a decision. It must not return null. + * This is used to ensure that a composite chain always terminates. + */ +export interface TerminalStrategy extends RoutingStrategy { + /** + * Determines which model to use for a given request context. + * @returns A promise that resolves to a RoutingDecision. 
+ */ + route( + context: RoutingContext, + config: Config, + baseLlmClient: BaseLlmClient, + ): Promise; +} diff --git a/packages/core/src/routing/strategies/compositeStrategy.test.ts b/packages/core/src/routing/strategies/compositeStrategy.test.ts new file mode 100644 index 00000000000..7fb2c393b40 --- /dev/null +++ b/packages/core/src/routing/strategies/compositeStrategy.test.ts @@ -0,0 +1,215 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { CompositeStrategy } from './compositeStrategy.js'; +import type { + RoutingContext, + RoutingDecision, + RoutingStrategy, + TerminalStrategy, +} from '../routingStrategy.js'; +import type { Config } from '../../config/config.js'; +import type { BaseLlmClient } from '../../core/baseLlmClient.js'; + +describe('CompositeStrategy', () => { + let mockContext: RoutingContext; + let mockConfig: Config; + let mockBaseLlmClient: BaseLlmClient; + let mockStrategy1: RoutingStrategy; + let mockStrategy2: RoutingStrategy; + let mockTerminalStrategy: TerminalStrategy; + + beforeEach(() => { + vi.clearAllMocks(); + + mockContext = {} as RoutingContext; + mockConfig = {} as Config; + mockBaseLlmClient = {} as BaseLlmClient; + + mockStrategy1 = { + name: 'strategy1', + route: vi.fn().mockResolvedValue(null), + }; + + mockStrategy2 = { + name: 'strategy2', + route: vi.fn().mockResolvedValue(null), + }; + + mockTerminalStrategy = { + name: 'terminal', + route: vi.fn().mockResolvedValue({ + model: 'terminal-model', + metadata: { + source: 'terminal', + latencyMs: 10, + reasoning: 'Terminal decision', + }, + }), + }; + }); + + it('should try strategies in order and return the first successful decision', async () => { + const decision: RoutingDecision = { + model: 'strategy2-model', + metadata: { + source: 'strategy2', + latencyMs: 20, + reasoning: 'Strategy 2 decided', + }, + }; + vi.spyOn(mockStrategy2, 
'route').mockResolvedValue(decision); + + const composite = new CompositeStrategy( + [mockStrategy1, mockStrategy2, mockTerminalStrategy], + 'test-router', + ); + + const result = await composite.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(mockStrategy1.route).toHaveBeenCalledWith( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + expect(mockStrategy2.route).toHaveBeenCalledWith( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + expect(mockTerminalStrategy.route).not.toHaveBeenCalled(); + + expect(result.model).toBe('strategy2-model'); + expect(result.metadata.source).toBe('test-router/strategy2'); + }); + + it('should fall back to the terminal strategy if no other strategy provides a decision', async () => { + const composite = new CompositeStrategy( + [mockStrategy1, mockStrategy2, mockTerminalStrategy], + 'test-router', + ); + + const result = await composite.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(mockStrategy1.route).toHaveBeenCalledTimes(1); + expect(mockStrategy2.route).toHaveBeenCalledTimes(1); + expect(mockTerminalStrategy.route).toHaveBeenCalledTimes(1); + + expect(result.model).toBe('terminal-model'); + expect(result.metadata.source).toBe('test-router/terminal'); + }); + + it('should handle errors in non-terminal strategies and continue', async () => { + const consoleErrorSpy = vi + .spyOn(console, 'error') + .mockImplementation(() => {}); + vi.spyOn(mockStrategy1, 'route').mockRejectedValue( + new Error('Strategy 1 failed'), + ); + + const composite = new CompositeStrategy( + [mockStrategy1, mockTerminalStrategy], + 'test-router', + ); + + const result = await composite.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(consoleErrorSpy).toHaveBeenCalledWith( + "[Routing] Strategy 'strategy1' failed. Continuing to next strategy. 
Error:", + expect.any(Error), + ); + expect(result.model).toBe('terminal-model'); + consoleErrorSpy.mockRestore(); + }); + + it('should re-throw an error from the terminal strategy', async () => { + const consoleErrorSpy = vi + .spyOn(console, 'error') + .mockImplementation(() => {}); + const terminalError = new Error('Terminal strategy failed'); + vi.spyOn(mockTerminalStrategy, 'route').mockRejectedValue(terminalError); + + const composite = new CompositeStrategy([mockTerminalStrategy]); + + await expect( + composite.route(mockContext, mockConfig, mockBaseLlmClient), + ).rejects.toThrow(terminalError); + + expect(consoleErrorSpy).toHaveBeenCalledWith( + "[Routing] Critical Error: Terminal strategy 'terminal' failed. Routing cannot proceed. Error:", + terminalError, + ); + consoleErrorSpy.mockRestore(); + }); + + it('should correctly finalize the decision metadata', async () => { + const decision: RoutingDecision = { + model: 'some-model', + metadata: { + source: 'child-source', + latencyMs: 50, + reasoning: 'Child reasoning', + }, + }; + vi.spyOn(mockStrategy1, 'route').mockResolvedValue(decision); + + const composite = new CompositeStrategy( + [mockStrategy1, mockTerminalStrategy], + 'my-composite', + ); + + const result = await composite.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(result.model).toBe('some-model'); + expect(result.metadata.source).toBe('my-composite/child-source'); + expect(result.metadata.reasoning).toBe('Child reasoning'); + // It should keep the child's latency + expect(result.metadata.latencyMs).toBe(50); + }); + + it('should calculate total latency if child latency is not provided', async () => { + const decision: RoutingDecision = { + model: 'some-model', + metadata: { + source: 'child-source', + // No latencyMs here + latencyMs: 0, + reasoning: 'Child reasoning', + }, + }; + vi.spyOn(mockStrategy1, 'route').mockResolvedValue(decision); + + const composite = new CompositeStrategy( + [mockStrategy1, 
mockTerminalStrategy], + 'my-composite', + ); + + const result = await composite.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(result.metadata.latencyMs).toBeGreaterThanOrEqual(0); + }); +}); diff --git a/packages/core/src/routing/strategies/compositeStrategy.ts b/packages/core/src/routing/strategies/compositeStrategy.ts new file mode 100644 index 00000000000..42646fc4e3c --- /dev/null +++ b/packages/core/src/routing/strategies/compositeStrategy.ts @@ -0,0 +1,109 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Config } from '../../config/config.js'; +import type { BaseLlmClient } from '../../core/baseLlmClient.js'; +import type { + RoutingContext, + RoutingDecision, + RoutingStrategy, + TerminalStrategy, +} from '../routingStrategy.js'; + +/** + * A strategy that attempts a list of child strategies in order (Chain of Responsibility). + */ +export class CompositeStrategy implements TerminalStrategy { + readonly name: string; + + private strategies: [...RoutingStrategy[], TerminalStrategy]; + + /** + * Initializes the CompositeStrategy. + * @param strategies The strategies to try, in order of priority. The last strategy must be terminal. + * @param name The name of this composite configuration (e.g., 'router' or 'composite'). + */ + constructor( + strategies: [...RoutingStrategy[], TerminalStrategy], + name: string = 'composite', + ) { + this.strategies = strategies; + this.name = name; + } + + async route( + context: RoutingContext, + config: Config, + baseLlmClient: BaseLlmClient, + ): Promise { + const startTime = performance.now(); + + // Separate non-terminal strategies from the terminal one. + // This separation allows TypeScript to understand the control flow guarantees. 
+ const nonTerminalStrategies = this.strategies.slice( + 0, + -1, + ) as RoutingStrategy[]; + const terminalStrategy = this.strategies[ + this.strategies.length - 1 + ] as TerminalStrategy; + + // Try non-terminal strategies, allowing them to fail gracefully. + for (const strategy of nonTerminalStrategies) { + try { + const decision = await strategy.route(context, config, baseLlmClient); + if (decision) { + return this.finalizeDecision(decision, startTime); + } + } catch (error) { + console.error( + `[Routing] Strategy '${strategy.name}' failed. Continuing to next strategy. Error:`, + error, + ); + } + } + + // If no other strategy matched, execute the terminal strategy. + try { + const decision = await terminalStrategy.route( + context, + config, + baseLlmClient, + ); + + return this.finalizeDecision(decision, startTime); + } catch (error) { + console.error( + `[Routing] Critical Error: Terminal strategy '${terminalStrategy.name}' failed. Routing cannot proceed. Error:`, + error, + ); + throw error; + } + } + + /** + * Helper function to enhance the decision metadata with composite information. + */ + private finalizeDecision( + decision: RoutingDecision, + startTime: number, + ): RoutingDecision { + const endTime = performance.now(); + const totalLatency = endTime - startTime; + + // Combine the source paths: composite_name/child_source (e.g. 
'router/default') + const compositeSource = `${this.name}/${decision.metadata.source}`; + + return { + ...decision, + metadata: { + ...decision.metadata, + source: compositeSource, + latencyMs: decision.metadata.latencyMs || totalLatency, + }, + }; + } +} diff --git a/packages/core/src/routing/strategies/defaultStrategy.test.ts b/packages/core/src/routing/strategies/defaultStrategy.test.ts new file mode 100644 index 00000000000..1c739545a4c --- /dev/null +++ b/packages/core/src/routing/strategies/defaultStrategy.test.ts @@ -0,0 +1,32 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect } from 'vitest'; +import { DefaultStrategy } from './defaultStrategy.js'; +import type { RoutingContext } from '../routingStrategy.js'; +import type { BaseLlmClient } from '../../core/baseLlmClient.js'; +import { DEFAULT_GEMINI_MODEL } from '../../config/models.js'; +import type { Config } from '../../config/config.js'; + +describe('DefaultStrategy', () => { + it('should always route to the default Gemini model', async () => { + const strategy = new DefaultStrategy(); + const mockContext = {} as RoutingContext; + const mockConfig = {} as Config; + const mockClient = {} as BaseLlmClient; + + const decision = await strategy.route(mockContext, mockConfig, mockClient); + + expect(decision).toEqual({ + model: DEFAULT_GEMINI_MODEL, + metadata: { + source: 'default', + latencyMs: 0, + reasoning: `Routing to default model: ${DEFAULT_GEMINI_MODEL}`, + }, + }); + }); +}); diff --git a/packages/core/src/routing/strategies/defaultStrategy.ts b/packages/core/src/routing/strategies/defaultStrategy.ts new file mode 100644 index 00000000000..dba7949f9e7 --- /dev/null +++ b/packages/core/src/routing/strategies/defaultStrategy.ts @@ -0,0 +1,33 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Config } from '../../config/config.js'; +import type { BaseLlmClient } 
from '../../core/baseLlmClient.js'; +import type { + RoutingContext, + RoutingDecision, + TerminalStrategy, +} from '../routingStrategy.js'; +import { DEFAULT_GEMINI_MODEL } from '../../config/models.js'; + +export class DefaultStrategy implements TerminalStrategy { + readonly name = 'default'; + + async route( + _context: RoutingContext, + _config: Config, + _baseLlmClient: BaseLlmClient, + ): Promise { + return { + model: DEFAULT_GEMINI_MODEL, + metadata: { + source: this.name, + latencyMs: 0, + reasoning: `Routing to default model: ${DEFAULT_GEMINI_MODEL}`, + }, + }; + } +} diff --git a/packages/core/src/routing/strategies/fallbackStrategy.test.ts b/packages/core/src/routing/strategies/fallbackStrategy.test.ts new file mode 100644 index 00000000000..dfda72d4ca9 --- /dev/null +++ b/packages/core/src/routing/strategies/fallbackStrategy.test.ts @@ -0,0 +1,86 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect } from 'vitest'; +import { FallbackStrategy } from './fallbackStrategy.js'; +import type { RoutingContext } from '../routingStrategy.js'; +import type { BaseLlmClient } from '../../core/baseLlmClient.js'; +import type { Config } from '../../config/config.js'; +import { + DEFAULT_GEMINI_MODEL, + DEFAULT_GEMINI_FLASH_MODEL, + DEFAULT_GEMINI_FLASH_LITE_MODEL, +} from '../../config/models.js'; + +describe('FallbackStrategy', () => { + const strategy = new FallbackStrategy(); + const mockContext = {} as RoutingContext; + const mockClient = {} as BaseLlmClient; + + it('should return null when not in fallback mode', async () => { + const mockConfig = { + isInFallbackMode: () => false, + getModel: () => DEFAULT_GEMINI_MODEL, + } as Config; + + const decision = await strategy.route(mockContext, mockConfig, mockClient); + expect(decision).toBeNull(); + }); + + describe('when in fallback mode', () => { + it('should downgrade a pro model to the flash model', async () => { + const mockConfig = { + 
isInFallbackMode: () => true, + getModel: () => DEFAULT_GEMINI_MODEL, + } as Config; + + const decision = await strategy.route( + mockContext, + mockConfig, + mockClient, + ); + + expect(decision).not.toBeNull(); + expect(decision?.model).toBe(DEFAULT_GEMINI_FLASH_MODEL); + expect(decision?.metadata.source).toBe('fallback'); + expect(decision?.metadata.reasoning).toContain('In fallback mode'); + }); + + it('should honor a lite model request', async () => { + const mockConfig = { + isInFallbackMode: () => true, + getModel: () => DEFAULT_GEMINI_FLASH_LITE_MODEL, + } as Config; + + const decision = await strategy.route( + mockContext, + mockConfig, + mockClient, + ); + + expect(decision).not.toBeNull(); + expect(decision?.model).toBe(DEFAULT_GEMINI_FLASH_LITE_MODEL); + expect(decision?.metadata.source).toBe('fallback'); + }); + + it('should use the flash model if flash is requested', async () => { + const mockConfig = { + isInFallbackMode: () => true, + getModel: () => DEFAULT_GEMINI_FLASH_MODEL, + } as Config; + + const decision = await strategy.route( + mockContext, + mockConfig, + mockClient, + ); + + expect(decision).not.toBeNull(); + expect(decision?.model).toBe(DEFAULT_GEMINI_FLASH_MODEL); + expect(decision?.metadata.source).toBe('fallback'); + }); + }); +}); diff --git a/packages/core/src/routing/strategies/fallbackStrategy.ts b/packages/core/src/routing/strategies/fallbackStrategy.ts new file mode 100644 index 00000000000..aef01743aa9 --- /dev/null +++ b/packages/core/src/routing/strategies/fallbackStrategy.ts @@ -0,0 +1,43 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Config } from '../../config/config.js'; +import { getEffectiveModel } from '../../config/models.js'; +import type { BaseLlmClient } from '../../core/baseLlmClient.js'; +import type { + RoutingContext, + RoutingDecision, + RoutingStrategy, +} from '../routingStrategy.js'; + +export class FallbackStrategy implements RoutingStrategy 
{ + readonly name = 'fallback'; + + async route( + _context: RoutingContext, + config: Config, + _baseLlmClient: BaseLlmClient, + ): Promise { + const isInFallbackMode: boolean = config.isInFallbackMode(); + + if (!isInFallbackMode) { + return null; + } + + const effectiveModel = getEffectiveModel( + isInFallbackMode, + config.getModel(), + ); + return { + model: effectiveModel, + metadata: { + source: this.name, + latencyMs: 0, + reasoning: `In fallback mode. Using: ${effectiveModel}`, + }, + }; + } +} diff --git a/packages/core/src/routing/strategies/overrideStrategy.test.ts b/packages/core/src/routing/strategies/overrideStrategy.test.ts new file mode 100644 index 00000000000..69c4088f8d7 --- /dev/null +++ b/packages/core/src/routing/strategies/overrideStrategy.test.ts @@ -0,0 +1,55 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect } from 'vitest'; +import { OverrideStrategy } from './overrideStrategy.js'; +import type { RoutingContext } from '../routingStrategy.js'; +import type { BaseLlmClient } from '../../core/baseLlmClient.js'; +import type { Config } from '../../config/config.js'; + +describe('OverrideStrategy', () => { + const strategy = new OverrideStrategy(); + const mockContext = {} as RoutingContext; + const mockClient = {} as BaseLlmClient; + + it('should return null when no override model is specified', async () => { + const mockConfig = { + getModel: () => '', // Simulate no model override + } as Config; + + const decision = await strategy.route(mockContext, mockConfig, mockClient); + expect(decision).toBeNull(); + }); + + it('should return a decision with the override model when one is specified', async () => { + const overrideModel = 'gemini-2.5-pro-custom'; + const mockConfig = { + getModel: () => overrideModel, + } as Config; + + const decision = await strategy.route(mockContext, mockConfig, mockClient); + + expect(decision).not.toBeNull(); + 
expect(decision?.model).toBe(overrideModel); + expect(decision?.metadata.source).toBe('override'); + expect(decision?.metadata.reasoning).toContain( + 'Routing bypassed by forced model directive', + ); + expect(decision?.metadata.reasoning).toContain(overrideModel); + }); + + it('should handle different override model names', async () => { + const overrideModel = 'gemini-2.5-flash-experimental'; + const mockConfig = { + getModel: () => overrideModel, + } as Config; + + const decision = await strategy.route(mockContext, mockConfig, mockClient); + + expect(decision).not.toBeNull(); + expect(decision?.model).toBe(overrideModel); + }); +}); diff --git a/packages/core/src/routing/strategies/overrideStrategy.ts b/packages/core/src/routing/strategies/overrideStrategy.ts new file mode 100644 index 00000000000..b3aef6c3322 --- /dev/null +++ b/packages/core/src/routing/strategies/overrideStrategy.ts @@ -0,0 +1,40 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Config } from '../../config/config.js'; +import type { BaseLlmClient } from '../../core/baseLlmClient.js'; +import type { + RoutingContext, + RoutingDecision, + RoutingStrategy, +} from '../routingStrategy.js'; + +/** + * Handles cases where the user explicitly specifies a model (override). + */ +export class OverrideStrategy implements RoutingStrategy { + readonly name = 'override'; + + async route( + _context: RoutingContext, + config: Config, + _baseLlmClient: BaseLlmClient, + ): Promise { + const overrideModel = config.getModel(); + if (overrideModel) { + return { + model: overrideModel, + metadata: { + source: this.name, + latencyMs: 0, + reasoning: `Routing bypassed by forced model directive. Using: ${overrideModel}`, + }, + }; + } + // No override specified, pass to the next strategy. 
+ return null; + } +} diff --git a/packages/core/src/services/fileDiscoveryService.ts b/packages/core/src/services/fileDiscoveryService.ts index 0a309f8faab..4620362685f 100644 --- a/packages/core/src/services/fileDiscoveryService.ts +++ b/packages/core/src/services/fileDiscoveryService.ts @@ -5,40 +5,34 @@ */ import type { GitIgnoreFilter } from '../utils/gitIgnoreParser.js'; +import type { GeminiIgnoreFilter } from '../utils/geminiIgnoreParser.js'; import { GitIgnoreParser } from '../utils/gitIgnoreParser.js'; +import { GeminiIgnoreParser } from '../utils/geminiIgnoreParser.js'; import { isGitRepository } from '../utils/gitUtils.js'; import * as path from 'node:path'; -const GEMINI_IGNORE_FILE_NAME = '.geminiignore'; - export interface FilterFilesOptions { respectGitIgnore?: boolean; respectGeminiIgnore?: boolean; } +export interface FilterReport { + filteredPaths: string[]; + gitIgnoredCount: number; + geminiIgnoredCount: number; +} + export class FileDiscoveryService { private gitIgnoreFilter: GitIgnoreFilter | null = null; - private geminiIgnoreFilter: GitIgnoreFilter | null = null; + private geminiIgnoreFilter: GeminiIgnoreFilter | null = null; private projectRoot: string; constructor(projectRoot: string) { this.projectRoot = path.resolve(projectRoot); if (isGitRepository(this.projectRoot)) { - const parser = new GitIgnoreParser(this.projectRoot); - try { - parser.loadGitRepoPatterns(); - } catch (_error) { - // ignore file not found - } - this.gitIgnoreFilter = parser; - } - const gParser = new GitIgnoreParser(this.projectRoot); - try { - gParser.loadPatterns(GEMINI_IGNORE_FILE_NAME); - } catch (_error) { - // ignore file not found + this.gitIgnoreFilter = new GitIgnoreParser(this.projectRoot); } - this.geminiIgnoreFilter = gParser; + this.geminiIgnoreFilter = new GeminiIgnoreParser(this.projectRoot); } /** @@ -65,6 +59,42 @@ export class FileDiscoveryService { }); } + /** + * Filters a list of file paths based on git ignore rules and returns a report + * 
with counts of ignored files. + */ + filterFilesWithReport( + filePaths: string[], + opts: FilterFilesOptions = { + respectGitIgnore: true, + respectGeminiIgnore: true, + }, + ): FilterReport { + const filteredPaths: string[] = []; + let gitIgnoredCount = 0; + let geminiIgnoredCount = 0; + + for (const filePath of filePaths) { + if (opts.respectGitIgnore && this.shouldGitIgnoreFile(filePath)) { + gitIgnoredCount++; + continue; + } + + if (opts.respectGeminiIgnore && this.shouldGeminiIgnoreFile(filePath)) { + geminiIgnoredCount++; + continue; + } + + filteredPaths.push(filePath); + } + + return { + filteredPaths, + gitIgnoredCount, + geminiIgnoredCount, + }; + } + /** * Checks if a single file should be git-ignored */ diff --git a/packages/core/src/services/loopDetectionService.test.ts b/packages/core/src/services/loopDetectionService.test.ts index 542ea6ce544..0b842d60545 100644 --- a/packages/core/src/services/loopDetectionService.test.ts +++ b/packages/core/src/services/loopDetectionService.test.ts @@ -130,6 +130,15 @@ describe('LoopDetectionService', () => { expect(service.addAndCheck(toolCallEvent)).toBe(true); expect(loggers.logLoopDetected).toHaveBeenCalledTimes(1); }); + + it('should not detect a loop when disabled for session', () => { + service.disableForSession(); + const event = createToolCallRequestEvent('testTool', { param: 'value' }); + for (let i = 0; i < TOOL_CALL_LOOP_THRESHOLD; i++) { + expect(service.addAndCheck(event)).toBe(false); + } + expect(loggers.logLoopDetected).not.toHaveBeenCalled(); + }); }); describe('Content Loop Detection', () => { @@ -719,4 +728,12 @@ describe('LoopDetectionService LLM Checks', () => { expect(result).toBe(false); expect(loggers.logLoopDetected).not.toHaveBeenCalled(); }); + + it('should not trigger LLM check when disabled for session', async () => { + service.disableForSession(); + await advanceTurns(30); + const result = await service.turnStarted(abortController.signal); + expect(result).toBe(false); + 
expect(mockGeminiClient.generateJson).not.toHaveBeenCalled(); + }); }); diff --git a/packages/core/src/services/loopDetectionService.ts b/packages/core/src/services/loopDetectionService.ts index 97cb94dcb6b..d2c531b291c 100644 --- a/packages/core/src/services/loopDetectionService.ts +++ b/packages/core/src/services/loopDetectionService.ts @@ -74,10 +74,20 @@ export class LoopDetectionService { private llmCheckInterval = DEFAULT_LLM_CHECK_INTERVAL; private lastCheckTurn = 0; + // Session-level disable flag + private disabledForSession = false; + constructor(config: Config) { this.config = config; } + /** + * Disables loop detection for the current session. + */ + disableForSession(): void { + this.disabledForSession = true; + } + private getToolCallKey(toolCall: { name: string; args: object }): string { const argsString = JSON.stringify(toolCall.args); const keyString = `${toolCall.name}:${argsString}`; @@ -90,8 +100,8 @@ export class LoopDetectionService { * @returns true if a loop is detected, false otherwise */ addAndCheck(event: ServerGeminiStreamEvent): boolean { - if (this.loopDetected) { - return true; + if (this.loopDetected || this.disabledForSession) { + return this.loopDetected; } switch (event.type) { @@ -121,6 +131,9 @@ export class LoopDetectionService { * @returns A promise that resolves to `true` if a loop is detected, and `false` otherwise. 
*/ async turnStarted(signal: AbortSignal) { + if (this.disabledForSession) { + return false; + } this.turnsInCurrentPrompt++; if ( diff --git a/packages/core/src/services/shellExecutionService.test.ts b/packages/core/src/services/shellExecutionService.test.ts index 3edce90268a..26b663cc3dc 100644 --- a/packages/core/src/services/shellExecutionService.test.ts +++ b/packages/core/src/services/shellExecutionService.test.ts @@ -17,6 +17,7 @@ const mockCpSpawn = vi.hoisted(() => vi.fn()); const mockIsBinary = vi.hoisted(() => vi.fn()); const mockPlatform = vi.hoisted(() => vi.fn()); const mockGetPty = vi.hoisted(() => vi.fn()); +const mockSerializeTerminalToObject = vi.hoisted(() => vi.fn()); // Top-level Mocks vi.mock('@lydell/node-pty', () => ({ @@ -49,6 +50,16 @@ vi.mock('os', () => ({ vi.mock('../utils/getPty.js', () => ({ getPty: mockGetPty, })); +vi.mock('../utils/terminalSerializer.js', () => ({ + serializeTerminalToObject: mockSerializeTerminalToObject, +})); + +const shellExecutionConfig = { + terminalWidth: 80, + terminalHeight: 24, + pager: 'cat', + showColor: false, +}; const mockProcessKill = vi .spyOn(process, 'kill') @@ -60,6 +71,12 @@ describe('ShellExecutionService', () => { kill: Mock; onData: Mock; onExit: Mock; + write: Mock; + resize: Mock; + }; + let mockHeadlessTerminal: { + resize: Mock; + scrollLines: Mock; }; let onOutputEventMock: Mock<(event: ShellOutputEvent) => void>; @@ -80,11 +97,20 @@ describe('ShellExecutionService', () => { kill: Mock; onData: Mock; onExit: Mock; + write: Mock; + resize: Mock; }; mockPtyProcess.pid = 12345; mockPtyProcess.kill = vi.fn(); mockPtyProcess.onData = vi.fn(); mockPtyProcess.onExit = vi.fn(); + mockPtyProcess.write = vi.fn(); + mockPtyProcess.resize = vi.fn(); + + mockHeadlessTerminal = { + resize: vi.fn(), + scrollLines: vi.fn(), + }; mockPtySpawn.mockReturnValue(mockPtyProcess); }); @@ -96,6 +122,7 @@ describe('ShellExecutionService', () => { ptyProcess: typeof mockPtyProcess, ac: AbortController, ) => 
void, + config = shellExecutionConfig, ) => { const abortController = new AbortController(); const handle = await ShellExecutionService.execute( @@ -104,9 +131,10 @@ describe('ShellExecutionService', () => { onOutputEventMock, abortController.signal, true, + config, ); - await new Promise((resolve) => setImmediate(resolve)); + await new Promise((resolve) => process.nextTick(resolve)); simulation(mockPtyProcess, abortController); const result = await handle.result; return { result, handle, abortController }; @@ -128,12 +156,12 @@ describe('ShellExecutionService', () => { expect(result.signal).toBeNull(); expect(result.error).toBeNull(); expect(result.aborted).toBe(false); - expect(result.output).toBe('file1.txt'); + expect(result.output.trim()).toBe('file1.txt'); expect(handle.pid).toBe(12345); expect(onOutputEventMock).toHaveBeenCalledWith({ type: 'data', - chunk: 'file1.txt\n', + chunk: 'file1.txt', }); }); @@ -143,11 +171,13 @@ describe('ShellExecutionService', () => { pty.onExit.mock.calls[0][0]({ exitCode: 0, signal: null }); }); - expect(result.output).toBe('aredword'); - expect(onOutputEventMock).toHaveBeenCalledWith({ - type: 'data', - chunk: 'aredword', - }); + expect(result.output.trim()).toBe('aredword'); + expect(onOutputEventMock).toHaveBeenCalledWith( + expect.objectContaining({ + type: 'data', + chunk: expect.stringContaining('aredword'), + }), + ); }); it('should correctly decode multi-byte characters split across chunks', async () => { @@ -157,16 +187,81 @@ describe('ShellExecutionService', () => { pty.onData.mock.calls[0][0](multiByteChar.slice(1)); pty.onExit.mock.calls[0][0]({ exitCode: 0, signal: null }); }); - expect(result.output).toBe('你好'); + expect(result.output.trim()).toBe('你好'); }); it('should handle commands with no output', async () => { - const { result } = await simulateExecution('touch file', (pty) => { + await simulateExecution('touch file', (pty) => { pty.onExit.mock.calls[0][0]({ exitCode: 0, signal: null }); }); - 
expect(result.output).toBe(''); - expect(onOutputEventMock).not.toHaveBeenCalled(); + expect(onOutputEventMock).toHaveBeenCalledWith( + expect.objectContaining({ + chunk: expect.stringMatching(/^\s*$/), + }), + ); + }); + + it('should call onPid with the process id', async () => { + const abortController = new AbortController(); + const handle = await ShellExecutionService.execute( + 'ls -l', + '/test/dir', + onOutputEventMock, + abortController.signal, + true, + shellExecutionConfig, + ); + mockPtyProcess.onExit.mock.calls[0][0]({ exitCode: 0, signal: null }); + await handle.result; + expect(handle.pid).toBe(12345); + }); + }); + + describe('pty interaction', () => { + beforeEach(() => { + vi.spyOn(ShellExecutionService['activePtys'], 'get').mockReturnValue({ + // eslint-disable-next-line @typescript-eslint/no-explicit-any + ptyProcess: mockPtyProcess as any, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + headlessTerminal: mockHeadlessTerminal as any, + }); + }); + + it('should write to the pty and trigger a render', async () => { + vi.useFakeTimers(); + await simulateExecution('interactive-app', (pty) => { + ShellExecutionService.writeToPty(pty.pid!, 'input'); + pty.onExit.mock.calls[0][0]({ exitCode: 0, signal: null }); + }); + + expect(mockPtyProcess.write).toHaveBeenCalledWith('input'); + // Use fake timers to check for the delayed render + await vi.advanceTimersByTimeAsync(17); + // The render will cause an output event + expect(onOutputEventMock).toHaveBeenCalled(); + vi.useRealTimers(); + }); + + it('should resize the pty and the headless terminal', async () => { + await simulateExecution('ls -l', (pty) => { + pty.onData.mock.calls[0][0]('file1.txt\n'); + ShellExecutionService.resizePty(pty.pid!, 100, 40); + pty.onExit.mock.calls[0][0]({ exitCode: 0, signal: null }); + }); + + expect(mockPtyProcess.resize).toHaveBeenCalledWith(100, 40); + expect(mockHeadlessTerminal.resize).toHaveBeenCalledWith(100, 40); + }); + + it('should scroll the 
headless terminal', async () => { + await simulateExecution('ls -l', (pty) => { + pty.onData.mock.calls[0][0]('file1.txt\n'); + ShellExecutionService.scrollPty(pty.pid!, 10); + pty.onExit.mock.calls[0][0]({ exitCode: 0, signal: null }); + }); + + expect(mockHeadlessTerminal.scrollLines).toHaveBeenCalledWith(10); }); }); @@ -178,7 +273,7 @@ describe('ShellExecutionService', () => { }); expect(result.exitCode).toBe(127); - expect(result.output).toBe('command not found'); + expect(result.output.trim()).toBe('command not found'); expect(result.error).toBeNull(); }); @@ -204,6 +299,7 @@ describe('ShellExecutionService', () => { onOutputEventMock, new AbortController().signal, true, + {}, ); const result = await handle.result; @@ -226,7 +322,7 @@ describe('ShellExecutionService', () => { ); expect(result.aborted).toBe(true); - expect(mockPtyProcess.kill).toHaveBeenCalled(); + // The process kill is mocked, so we just check that the flag is set. }); }); @@ -263,7 +359,6 @@ describe('ShellExecutionService', () => { mockIsBinary.mockImplementation((buffer) => buffer.includes(0x00)); await simulateExecution('cat mixed_file', (pty) => { - pty.onData.mock.calls[0][0](Buffer.from('some text')); pty.onData.mock.calls[0][0](Buffer.from([0x00, 0x01, 0x02])); pty.onData.mock.calls[0][0](Buffer.from('more text')); pty.onExit.mock.calls[0][0]({ exitCode: 0, signal: null }); @@ -273,7 +368,6 @@ describe('ShellExecutionService', () => { (call: [ShellOutputEvent]) => call[0].type, ); expect(eventTypes).toEqual([ - 'data', 'binary_detected', 'binary_progress', 'binary_progress', @@ -308,6 +402,57 @@ describe('ShellExecutionService', () => { ); }); }); + + describe('AnsiOutput rendering', () => { + it('should call onOutputEvent with AnsiOutput when showColor is true', async () => { + const coloredShellExecutionConfig = { + ...shellExecutionConfig, + showColor: true, + defaultFg: '#ffffff', + defaultBg: '#000000', + }; + const mockAnsiOutput = [ + [{ text: 'hello', fg: '#ffffff', bg: 
'#000000' }], + ]; + mockSerializeTerminalToObject.mockReturnValue(mockAnsiOutput); + + await simulateExecution( + 'ls --color=auto', + (pty) => { + pty.onData.mock.calls[0][0]('a\u001b[31mred\u001b[0mword'); + pty.onExit.mock.calls[0][0]({ exitCode: 0, signal: null }); + }, + coloredShellExecutionConfig, + ); + + expect(mockSerializeTerminalToObject).toHaveBeenCalledWith( + expect.anything(), // The terminal object + { defaultFg: '#ffffff', defaultBg: '#000000' }, + ); + + expect(onOutputEventMock).toHaveBeenCalledWith( + expect.objectContaining({ + type: 'data', + chunk: mockAnsiOutput, + }), + ); + }); + + it('should call onOutputEvent with plain string when showColor is false', async () => { + await simulateExecution('ls --color=auto', (pty) => { + pty.onData.mock.calls[0][0]('a\u001b[31mred\u001b[0mword'); + pty.onExit.mock.calls[0][0]({ exitCode: 0, signal: null }); + }); + + expect(mockSerializeTerminalToObject).not.toHaveBeenCalled(); + expect(onOutputEventMock).toHaveBeenCalledWith( + expect.objectContaining({ + type: 'data', + chunk: 'aredword', + }), + ); + }); + }); }); describe('ShellExecutionService child_process fallback', () => { @@ -349,9 +494,10 @@ describe('ShellExecutionService child_process fallback', () => { onOutputEventMock, abortController.signal, true, + shellExecutionConfig, ); - await new Promise((resolve) => setImmediate(resolve)); + await new Promise((resolve) => process.nextTick(resolve)); simulation(mockChildProcess, abortController); const result = await handle.result; return { result, handle, abortController }; @@ -363,6 +509,7 @@ describe('ShellExecutionService child_process fallback', () => { cp.stdout?.emit('data', Buffer.from('file1.txt\n')); cp.stderr?.emit('data', Buffer.from('a warning')); cp.emit('exit', 0, null); + cp.emit('close', 0, null); }); expect(mockCpSpawn).toHaveBeenCalledWith( @@ -375,15 +522,11 @@ describe('ShellExecutionService child_process fallback', () => { expect(result.error).toBeNull(); 
expect(result.aborted).toBe(false); expect(result.output).toBe('file1.txt\na warning'); - expect(handle.pid).toBe(12345); + expect(handle.pid).toBe(undefined); expect(onOutputEventMock).toHaveBeenCalledWith({ type: 'data', - chunk: 'file1.txt\n', - }); - expect(onOutputEventMock).toHaveBeenCalledWith({ - type: 'data', - chunk: 'a warning', + chunk: 'file1.txt\na warning', }); }); @@ -391,13 +534,16 @@ describe('ShellExecutionService child_process fallback', () => { const { result } = await simulateExecution('ls --color=auto', (cp) => { cp.stdout?.emit('data', Buffer.from('a\u001b[31mred\u001b[0mword')); cp.emit('exit', 0, null); + cp.emit('close', 0, null); }); - expect(result.output).toBe('aredword'); - expect(onOutputEventMock).toHaveBeenCalledWith({ - type: 'data', - chunk: 'aredword', - }); + expect(result.output.trim()).toBe('aredword'); + expect(onOutputEventMock).toHaveBeenCalledWith( + expect.objectContaining({ + type: 'data', + chunk: expect.stringContaining('aredword'), + }), + ); }); it('should correctly decode multi-byte characters split across chunks', async () => { @@ -406,16 +552,18 @@ describe('ShellExecutionService child_process fallback', () => { cp.stdout?.emit('data', multiByteChar.slice(0, 2)); cp.stdout?.emit('data', multiByteChar.slice(2)); cp.emit('exit', 0, null); + cp.emit('close', 0, null); }); - expect(result.output).toBe('你好'); + expect(result.output.trim()).toBe('你好'); }); it('should handle commands with no output', async () => { const { result } = await simulateExecution('touch file', (cp) => { cp.emit('exit', 0, null); + cp.emit('close', 0, null); }); - expect(result.output).toBe(''); + expect(result.output.trim()).toBe(''); expect(onOutputEventMock).not.toHaveBeenCalled(); }); }); @@ -425,16 +573,18 @@ describe('ShellExecutionService child_process fallback', () => { const { result } = await simulateExecution('a-bad-command', (cp) => { cp.stderr?.emit('data', Buffer.from('command not found')); cp.emit('exit', 127, null); + 
cp.emit('close', 127, null); }); expect(result.exitCode).toBe(127); - expect(result.output).toBe('command not found'); + expect(result.output.trim()).toBe('command not found'); expect(result.error).toBeNull(); }); it('should capture a termination signal', async () => { const { result } = await simulateExecution('long-process', (cp) => { cp.emit('exit', null, 'SIGTERM'); + cp.emit('close', null, 'SIGTERM'); }); expect(result.exitCode).toBeNull(); @@ -446,6 +596,7 @@ describe('ShellExecutionService child_process fallback', () => { const { result } = await simulateExecution('protected-cmd', (cp) => { cp.emit('error', spawnError); cp.emit('exit', 1, null); + cp.emit('close', 1, null); }); expect(result.error).toBe(spawnError); @@ -456,6 +607,7 @@ describe('ShellExecutionService child_process fallback', () => { const error = new Error('spawn abc ENOENT'); const { result } = await simulateExecution('touch cat.jpg', (cp) => { cp.emit('error', error); // No exit event is fired. + cp.emit('close', 1, null); }); expect(result.error).toBe(error); @@ -485,10 +637,14 @@ describe('ShellExecutionService child_process fallback', () => { 'sleep 10', (cp, abortController) => { abortController.abort(); - if (expectedExit.signal) + if (expectedExit.signal) { cp.emit('exit', null, expectedExit.signal); - if (typeof expectedExit.code === 'number') + cp.emit('close', null, expectedExit.signal); + } + if (typeof expectedExit.code === 'number') { cp.emit('exit', expectedExit.code, null); + cp.emit('close', expectedExit.code, null); + } }, ); @@ -524,6 +680,7 @@ describe('ShellExecutionService child_process fallback', () => { onOutputEventMock, abortController.signal, true, + {}, ); abortController.abort(); @@ -545,14 +702,13 @@ describe('ShellExecutionService child_process fallback', () => { // Finally, simulate the process exiting and await the result mockChildProcess.emit('exit', null, 'SIGKILL'); + mockChildProcess.emit('close', null, 'SIGKILL'); const result = await handle.result; 
vi.useRealTimers(); expect(result.aborted).toBe(true); expect(result.signal).toBe(9); - // The individual kill calls were already asserted above. - expect(mockProcessKill).toHaveBeenCalledTimes(2); }); }); @@ -571,18 +727,10 @@ describe('ShellExecutionService child_process fallback', () => { expect(result.rawOutput).toEqual( Buffer.concat([binaryChunk1, binaryChunk2]), ); - expect(onOutputEventMock).toHaveBeenCalledTimes(3); + expect(onOutputEventMock).toHaveBeenCalledTimes(1); expect(onOutputEventMock.mock.calls[0][0]).toEqual({ type: 'binary_detected', }); - expect(onOutputEventMock.mock.calls[1][0]).toEqual({ - type: 'binary_progress', - bytesReceived: 4, - }); - expect(onOutputEventMock.mock.calls[2][0]).toEqual({ - type: 'binary_progress', - bytesReceived: 8, - }); }); it('should not emit data events after binary is detected', async () => { @@ -598,12 +746,7 @@ describe('ShellExecutionService child_process fallback', () => { const eventTypes = onOutputEventMock.mock.calls.map( (call: [ShellOutputEvent]) => call[0].type, ); - expect(eventTypes).toEqual([ - 'data', - 'binary_detected', - 'binary_progress', - 'binary_progress', - ]); + expect(eventTypes).toEqual(['binary_detected']); }); }); @@ -647,6 +790,8 @@ describe('ShellExecutionService execution method selection', () => { kill: Mock; onData: Mock; onExit: Mock; + write: Mock; + resize: Mock; }; let mockChildProcess: EventEmitter & Partial; @@ -660,11 +805,16 @@ describe('ShellExecutionService execution method selection', () => { kill: Mock; onData: Mock; onExit: Mock; + write: Mock; + resize: Mock; }; mockPtyProcess.pid = 12345; mockPtyProcess.kill = vi.fn(); mockPtyProcess.onData = vi.fn(); mockPtyProcess.onExit = vi.fn(); + mockPtyProcess.write = vi.fn(); + mockPtyProcess.resize = vi.fn(); + mockPtySpawn.mockReturnValue(mockPtyProcess); mockGetPty.mockResolvedValue({ module: { spawn: mockPtySpawn }, @@ -692,6 +842,7 @@ describe('ShellExecutionService execution method selection', () => { 
onOutputEventMock, abortController.signal, true, // shouldUseNodePty + shellExecutionConfig, ); // Simulate exit to allow promise to resolve @@ -712,6 +863,7 @@ describe('ShellExecutionService execution method selection', () => { onOutputEventMock, abortController.signal, false, // shouldUseNodePty + {}, ); // Simulate exit to allow promise to resolve @@ -734,6 +886,7 @@ describe('ShellExecutionService execution method selection', () => { onOutputEventMock, abortController.signal, true, // shouldUseNodePty + shellExecutionConfig, ); // Simulate exit to allow promise to resolve diff --git a/packages/core/src/services/shellExecutionService.ts b/packages/core/src/services/shellExecutionService.ts index f6f7fff7b61..23cff439a28 100644 --- a/packages/core/src/services/shellExecutionService.ts +++ b/packages/core/src/services/shellExecutionService.ts @@ -4,30 +4,24 @@ * SPDX-License-Identifier: Apache-2.0 */ +import stripAnsi from 'strip-ansi'; import type { PtyImplementation } from '../utils/getPty.js'; import { getPty } from '../utils/getPty.js'; import { spawn as cpSpawn } from 'node:child_process'; import { TextDecoder } from 'node:util'; import os from 'node:os'; +import type { IPty } from '@lydell/node-pty'; import { getCachedEncodingForBuffer } from '../utils/systemEncoding.js'; import { isBinary } from '../utils/textUtils.js'; import pkg from '@xterm/headless'; -import stripAnsi from 'strip-ansi'; +import { + serializeTerminalToObject, + type AnsiOutput, +} from '../utils/terminalSerializer.js'; const { Terminal } = pkg; const SIGKILL_TIMEOUT_MS = 200; -// @ts-expect-error getFullText is not a public API. -const getFullText = (terminal: Terminal) => { - const buffer = terminal.buffer.active; - const lines: string[] = []; - for (let i = 0; i < buffer.length; i++) { - const line = buffer.getLine(i); - lines.push(line ? line.translateToString(true) : ''); - } - return lines.join('\n').trim(); -}; - /** A structured result from a shell command execution. 
*/ export interface ShellExecutionResult { /** The raw, unprocessed output buffer. */ @@ -56,6 +50,15 @@ export interface ShellExecutionHandle { result: Promise; } +export interface ShellExecutionConfig { + terminalWidth?: number; + terminalHeight?: number; + pager?: string; + showColor?: boolean; + defaultFg?: string; + defaultBg?: string; +} + /** * Describes a structured event emitted during shell command execution. */ @@ -64,7 +67,7 @@ export type ShellOutputEvent = /** The event contains a chunk of output data. */ type: 'data'; /** The decoded string chunk. */ - chunk: string; + chunk: string | AnsiOutput; } | { /** Signals that the output stream has been identified as binary. */ @@ -77,12 +80,41 @@ export type ShellOutputEvent = bytesReceived: number; }; +interface ActivePty { + ptyProcess: IPty; + headlessTerminal: pkg.Terminal; +} + +const getVisibleText = (terminal: pkg.Terminal): string => { + const buffer = terminal.buffer.active; + const lines: string[] = []; + for (let i = 0; i < terminal.rows; i++) { + const line = buffer.getLine(buffer.viewportY + i); + const lineContent = line ? line.translateToString(true) : ''; + lines.push(lineContent); + } + return lines.join('\n').trimEnd(); +}; + +const getFullBufferText = (terminal: pkg.Terminal): string => { + const buffer = terminal.buffer.active; + const lines: string[] = []; + for (let i = 0; i < buffer.length; i++) { + const line = buffer.getLine(i); + const lineContent = line ? line.translateToString() : ''; + lines.push(lineContent); + } + return lines.join('\n').trimEnd(); +}; + /** * A centralized service for executing shell commands with robust process * management, cross-platform compatibility, and streaming output capabilities. * */ + export class ShellExecutionService { + private static activePtys = new Map(); /** * Executes a shell command using `node-pty`, capturing all output and lifecycle events. 
* @@ -99,8 +131,7 @@ export class ShellExecutionService { onOutputEvent: (event: ShellOutputEvent) => void, abortSignal: AbortSignal, shouldUseNodePty: boolean, - terminalColumns?: number, - terminalRows?: number, + shellExecutionConfig: ShellExecutionConfig, ): Promise { if (shouldUseNodePty) { const ptyInfo = await getPty(); @@ -111,8 +142,7 @@ export class ShellExecutionService { cwd, onOutputEvent, abortSignal, - terminalColumns, - terminalRows, + shellExecutionConfig, ptyInfo, ); } catch (_e) { @@ -186,31 +216,18 @@ export class ShellExecutionService { if (isBinary(sniffBuffer)) { isStreamingRawContent = false; - onOutputEvent({ type: 'binary_detected' }); } } - const decoder = stream === 'stdout' ? stdoutDecoder : stderrDecoder; - const decodedChunk = decoder.decode(data, { stream: true }); - const strippedChunk = stripAnsi(decodedChunk); - - if (stream === 'stdout') { - stdout += strippedChunk; - } else { - stderr += strippedChunk; - } - if (isStreamingRawContent) { - onOutputEvent({ type: 'data', chunk: strippedChunk }); - } else { - const totalBytes = outputChunks.reduce( - (sum, chunk) => sum + chunk.length, - 0, - ); - onOutputEvent({ - type: 'binary_progress', - bytesReceived: totalBytes, - }); + const decoder = stream === 'stdout' ? stdoutDecoder : stderrDecoder; + const decodedChunk = decoder.decode(data, { stream: true }); + + if (stream === 'stdout') { + stdout += decodedChunk; + } else { + stderr += decodedChunk; + } } }; @@ -224,14 +241,24 @@ export class ShellExecutionService { const combinedOutput = stdout + (stderr ? (stdout ? separator : '') + stderr : ''); + const finalStrippedOutput = stripAnsi(combinedOutput).trim(); + + if (isStreamingRawContent) { + if (finalStrippedOutput) { + onOutputEvent({ type: 'data', chunk: finalStrippedOutput }); + } + } else { + onOutputEvent({ type: 'binary_detected' }); + } + resolve({ rawOutput: finalBuffer, - output: combinedOutput.trim(), + output: finalStrippedOutput, exitCode: code, signal: signal ? 
os.constants.signals[signal] : null, error, aborted: abortSignal.aborted, - pid: child.pid, + pid: undefined, executionMethod: 'child_process', }); }; @@ -264,6 +291,9 @@ export class ShellExecutionService { abortSignal.addEventListener('abort', abortHandler, { once: true }); child.on('exit', (code, signal) => { + if (child.pid) { + this.activePtys.delete(child.pid); + } handleExit(code, signal); }); @@ -273,13 +303,13 @@ export class ShellExecutionService { if (stdoutDecoder) { const remaining = stdoutDecoder.decode(); if (remaining) { - stdout += stripAnsi(remaining); + stdout += remaining; } } if (stderrDecoder) { const remaining = stderrDecoder.decode(); if (remaining) { - stderr += stripAnsi(remaining); + stderr += remaining; } } @@ -289,7 +319,7 @@ export class ShellExecutionService { } }); - return { pid: child.pid, result }; + return { pid: undefined, result }; } catch (e) { const error = e as Error; return { @@ -313,29 +343,32 @@ export class ShellExecutionService { cwd: string, onOutputEvent: (event: ShellOutputEvent) => void, abortSignal: AbortSignal, - terminalColumns: number | undefined, - terminalRows: number | undefined, - ptyInfo: PtyImplementation | undefined, + shellExecutionConfig: ShellExecutionConfig, + ptyInfo: PtyImplementation, ): ShellExecutionHandle { + if (!ptyInfo) { + // This should not happen, but as a safeguard... + throw new Error('PTY implementation not found'); + } try { - const cols = terminalColumns ?? 80; - const rows = terminalRows ?? 30; + const cols = shellExecutionConfig.terminalWidth ?? 80; + const rows = shellExecutionConfig.terminalHeight ?? 30; const isWindows = os.platform() === 'win32'; const shell = isWindows ? 'cmd.exe' : 'bash'; const args = isWindows ? 
`/c ${commandToExecute}` : ['-c', commandToExecute]; - const ptyProcess = ptyInfo?.module.spawn(shell, args, { + const ptyProcess = ptyInfo.module.spawn(shell, args, { cwd, - name: 'xterm-color', + name: 'xterm', cols, rows, env: { ...process.env, GEMINI_CLI: '1', TERM: 'xterm-256color', - PAGER: 'cat', + PAGER: shellExecutionConfig.pager ?? 'cat', }, handleFlowControl: true, }); @@ -346,8 +379,12 @@ export class ShellExecutionService { cols, rows, }); + + this.activePtys.set(ptyProcess.pid, { ptyProcess, headlessTerminal }); + let processingChain = Promise.resolve(); let decoder: TextDecoder | null = null; + let output: string | AnsiOutput | null = null; const outputChunks: Buffer[] = []; const error: Error | null = null; let exited = false; @@ -355,6 +392,49 @@ export class ShellExecutionService { let isStreamingRawContent = true; const MAX_SNIFF_SIZE = 4096; let sniffedBytes = 0; + let isWriting = false; + let renderTimeout: NodeJS.Timeout | null = null; + + const render = (finalRender = false) => { + if (renderTimeout) { + clearTimeout(renderTimeout); + } + + const renderFn = () => { + if (!isStreamingRawContent) { + return; + } + const newOutput = shellExecutionConfig.showColor + ? serializeTerminalToObject(headlessTerminal, { + defaultFg: shellExecutionConfig.defaultFg, + defaultBg: shellExecutionConfig.defaultBg, + }) + : getVisibleText(headlessTerminal); + + // console.log(newOutput) + + // Using stringify for a quick deep comparison. 
+ if (JSON.stringify(output) !== JSON.stringify(newOutput)) { + output = newOutput; + onOutputEvent({ + type: 'data', + chunk: newOutput, + }); + } + }; + + if (finalRender) { + renderFn(); + } else { + renderTimeout = setTimeout(renderFn, 17); + } + }; + + headlessTerminal.onScroll(() => { + if (!isWriting) { + render(); + } + }); const handleOutput = (data: Buffer) => { processingChain = processingChain.then( @@ -383,11 +463,10 @@ export class ShellExecutionService { if (isStreamingRawContent) { const decodedChunk = decoder.decode(data, { stream: true }); + isWriting = true; headlessTerminal.write(decodedChunk, () => { - onOutputEvent({ - type: 'data', - chunk: stripAnsi(decodedChunk), - }); + render(); + isWriting = false; resolve(); }); } else { @@ -414,19 +493,23 @@ export class ShellExecutionService { ({ exitCode, signal }: { exitCode: number; signal?: number }) => { exited = true; abortSignal.removeEventListener('abort', abortHandler); + this.activePtys.delete(ptyProcess.pid); processingChain.then(() => { + render(true); const finalBuffer = Buffer.concat(outputChunks); resolve({ rawOutput: finalBuffer, - output: getFullText(headlessTerminal), + output: getFullBufferText(headlessTerminal), exitCode, signal: signal ?? null, error, aborted: abortSignal.aborted, pid: ptyProcess.pid, - executionMethod: ptyInfo?.name ?? 'node-pty', + executionMethod: + (ptyInfo?.name as 'node-pty' | 'lydell-node-pty') ?? 
+ 'node-pty', }); }); }, @@ -434,7 +517,17 @@ export class ShellExecutionService { const abortHandler = async () => { if (ptyProcess.pid && !exited) { - ptyProcess.kill('SIGHUP'); + if (os.platform() === 'win32') { + ptyProcess.kill(); + } else { + try { + // Kill the entire process group + process.kill(-ptyProcess.pid, 'SIGINT'); + } catch (_e) { + // Fallback to killing just the process if the group kill fails + ptyProcess.kill('SIGINT'); + } + } } }; @@ -459,4 +552,65 @@ export class ShellExecutionService { }; } } + + /** + * Writes a string to the pseudo-terminal (PTY) of a running process. + * + * @param pid The process ID of the target PTY. + * @param input The string to write to the terminal. + */ + static writeToPty(pid: number, input: string): void { + const activePty = this.activePtys.get(pid); + if (activePty) { + activePty.ptyProcess.write(input); + } + } + + /** + * Resizes the pseudo-terminal (PTY) of a running process. + * + * @param pid The process ID of the target PTY. + * @param cols The new number of columns. + * @param rows The new number of rows. + */ + static resizePty(pid: number, cols: number, rows: number): void { + const activePty = this.activePtys.get(pid); + if (activePty) { + try { + activePty.ptyProcess.resize(cols, rows); + activePty.headlessTerminal.resize(cols, rows); + } catch (e) { + // Ignore errors if the pty has already exited, which can happen + // due to a race condition between the exit event and this call. + if (e instanceof Error && 'code' in e && e.code === 'ESRCH') { + // ignore + } else { + throw e; + } + } + } + } + + /** + * Scrolls the pseudo-terminal (PTY) of a running process. + * + * @param pid The process ID of the target PTY. + * @param lines The number of lines to scroll. 
+ */ + static scrollPty(pid: number, lines: number): void { + const activePty = this.activePtys.get(pid); + if (activePty) { + try { + activePty.headlessTerminal.scrollLines(lines); + } catch (e) { + // Ignore errors if the pty has already exited, which can happen + // due to a race condition between the exit event and this call. + if (e instanceof Error && 'code' in e && e.code === 'ESRCH') { + // ignore + } else { + throw e; + } + } + } + } } diff --git a/packages/core/src/telemetry/clearcut-logger/clearcut-logger.test.ts b/packages/core/src/telemetry/clearcut-logger/clearcut-logger.test.ts index eddc9a59cf0..666adcbe369 100644 --- a/packages/core/src/telemetry/clearcut-logger/clearcut-logger.test.ts +++ b/packages/core/src/telemetry/clearcut-logger/clearcut-logger.test.ts @@ -392,6 +392,24 @@ describe('ClearcutLogger', () => { }); }); + describe('logRipgrepFallbackEvent', () => { + it('logs an event with the proper name', () => { + const { logger } = setup(); + // Spy on flushToClearcut to prevent it from clearing the queue + const flushSpy = vi + // eslint-disable-next-line @typescript-eslint/no-explicit-any + .spyOn(logger!, 'flushToClearcut' as any) + .mockResolvedValue({ nextRequestWaitMs: 0 }); + + logger?.logRipgrepFallbackEvent(); + + const events = getEvents(logger!); + expect(events.length).toBe(1); + expect(events[0]).toHaveEventName(EventNames.RIPGREP_FALLBACK); + expect(flushSpy).toHaveBeenCalledOnce(); + }); + }); + describe('enqueueLogEvent', () => { it('should add events to the queue', () => { const { logger } = setup(); diff --git a/packages/core/src/telemetry/clearcut-logger/clearcut-logger.ts b/packages/core/src/telemetry/clearcut-logger/clearcut-logger.ts index bcdea047cb0..4cc19e4f7b7 100644 --- a/packages/core/src/telemetry/clearcut-logger/clearcut-logger.ts +++ b/packages/core/src/telemetry/clearcut-logger/clearcut-logger.ts @@ -24,6 +24,8 @@ import type { InvalidChunkEvent, ContentRetryEvent, ContentRetryFailureEvent, + 
ExtensionInstallEvent, + ToolOutputTruncatedEvent, } from '../types.js'; import { EventMetadataKey } from './event-metadata-key.js'; import type { Config } from '../../config/config.js'; @@ -44,6 +46,7 @@ export enum EventNames { API_ERROR = 'api_error', END_SESSION = 'end_session', FLASH_FALLBACK = 'flash_fallback', + RIPGREP_FALLBACK = 'ripgrep_fallback', LOOP_DETECTED = 'loop_detected', NEXT_SPEAKER_CHECK = 'next_speaker_check', SLASH_COMMAND = 'slash_command', @@ -55,6 +58,8 @@ export enum EventNames { INVALID_CHUNK = 'invalid_chunk', CONTENT_RETRY = 'content_retry', CONTENT_RETRY_FAILURE = 'content_retry_failure', + EXTENSION_INSTALL = 'extension_install', + TOOL_OUTPUT_TRUNCATED = 'tool_output_truncated', } export interface LogResponse { @@ -461,6 +466,10 @@ export class ClearcutLogger { gemini_cli_key: EventMetadataKey.GEMINI_CLI_TOOL_TYPE, value: JSON.stringify(event.tool_type), }, + { + gemini_cli_key: EventMetadataKey.GEMINI_CLI_TOOL_CALL_CONTENT_LENGTH, + value: JSON.stringify(event.content_length), + }, ]; if (event.metadata) { @@ -548,10 +557,6 @@ export class ClearcutLogger { gemini_cli_key: EventMetadataKey.GEMINI_CLI_API_RESPONSE_DURATION_MS, value: JSON.stringify(event.duration_ms), }, - { - gemini_cli_key: EventMetadataKey.GEMINI_CLI_API_ERROR_MESSAGE, - value: JSON.stringify(event.error), - }, { gemini_cli_key: EventMetadataKey.GEMINI_CLI_API_RESPONSE_INPUT_TOKEN_COUNT, @@ -631,6 +636,13 @@ export class ClearcutLogger { }); } + logRipgrepFallbackEvent(): void { + this.enqueueLogEvent(this.createLogEvent(EventNames.RIPGREP_FALLBACK, [])); + this.flushToClearcut().catch((error) => { + console.debug('Error flushing to Clearcut:', error); + }); + } + logLoopDetectedEvent(event: LoopDetectedEvent): void { const data: EventValue[] = [ { @@ -825,6 +837,65 @@ export class ClearcutLogger { this.flushIfNeeded(); } + logExtensionInstallEvent(event: ExtensionInstallEvent): void { + const data: EventValue[] = [ + { + gemini_cli_key: 
EventMetadataKey.GEMINI_CLI_EXTENSION_NAME, + value: event.extension_name, + }, + { + gemini_cli_key: EventMetadataKey.GEMINI_CLI_EXTENSION_VERSION, + value: event.extension_version, + }, + { + gemini_cli_key: EventMetadataKey.GEMINI_CLI_EXTENSION_SOURCE, + value: event.extension_source, + }, + { + gemini_cli_key: EventMetadataKey.GEMINI_CLI_EXTENSION_INSTALL_STATUS, + value: event.status, + }, + ]; + + this.enqueueLogEvent( + this.createLogEvent(EventNames.EXTENSION_INSTALL, data), + ); + this.flushIfNeeded(); + } + + logToolOutputTruncatedEvent(event: ToolOutputTruncatedEvent): void { + const data: EventValue[] = [ + { + gemini_cli_key: EventMetadataKey.GEMINI_CLI_TOOL_CALL_NAME, + value: JSON.stringify(event.tool_name), + }, + { + gemini_cli_key: + EventMetadataKey.GEMINI_CLI_TOOL_OUTPUT_TRUNCATED_ORIGINAL_LENGTH, + value: JSON.stringify(event.original_content_length), + }, + { + gemini_cli_key: + EventMetadataKey.GEMINI_CLI_TOOL_OUTPUT_TRUNCATED_TRUNCATED_LENGTH, + value: JSON.stringify(event.truncated_content_length), + }, + { + gemini_cli_key: + EventMetadataKey.GEMINI_CLI_TOOL_OUTPUT_TRUNCATED_THRESHOLD, + value: JSON.stringify(event.threshold), + }, + { + gemini_cli_key: EventMetadataKey.GEMINI_CLI_TOOL_OUTPUT_TRUNCATED_LINES, + value: JSON.stringify(event.lines), + }, + ]; + + this.enqueueLogEvent( + this.createLogEvent(EventNames.TOOL_OUTPUT_TRUNCATED, data), + ); + this.flushIfNeeded(); + } + /** * Adds default fields to data, and returns a new data array. This fields * should exist on all log events. diff --git a/packages/core/src/telemetry/clearcut-logger/event-metadata-key.ts b/packages/core/src/telemetry/clearcut-logger/event-metadata-key.ts index 31e718a4662..25ea0462836 100644 --- a/packages/core/src/telemetry/clearcut-logger/event-metadata-key.ts +++ b/packages/core/src/telemetry/clearcut-logger/event-metadata-key.ts @@ -6,6 +6,8 @@ // Defines valid event metadata keys for Clearcut logging. 
export enum EventMetadataKey { + // Deleted enums: 24 + GEMINI_CLI_KEY_UNKNOWN = 0, // ========================================================================== @@ -77,6 +79,9 @@ export enum EventMetadataKey { // Logs the tool call error type, if any. GEMINI_CLI_TOOL_CALL_ERROR_TYPE = 19, + // Logs the length of tool output + GEMINI_CLI_TOOL_CALL_CONTENT_LENGTH = 93, + // ========================================================================== // GenAI API Request Event Keys // =========================================================================== @@ -97,9 +102,6 @@ export enum EventMetadataKey { // Logs the duration of the API call in milliseconds. GEMINI_CLI_API_RESPONSE_DURATION_MS = 23, - // Logs the error message of the API call, if any. - GEMINI_CLI_API_ERROR_MESSAGE = 24, - // Logs the input token count of the API call. GEMINI_CLI_API_RESPONSE_INPUT_TOKEN_COUNT = 25, @@ -331,4 +333,36 @@ export enum EventMetadataKey { // Logs the current nodejs version GEMINI_CLI_NODE_VERSION = 83, + + // ========================================================================== + // Extension Install Event Keys + // =========================================================================== + + // Logs the name of the extension. + GEMINI_CLI_EXTENSION_NAME = 85, + + // Logs the version of the extension. + GEMINI_CLI_EXTENSION_VERSION = 86, + + // Logs the source of the extension. + GEMINI_CLI_EXTENSION_SOURCE = 87, + + // Logs the status of the extension install. + GEMINI_CLI_EXTENSION_INSTALL_STATUS = 88, + + // ========================================================================== + // Tool Output Truncated Event Keys + // =========================================================================== + + // Logs the original length of the tool output. + GEMINI_CLI_TOOL_OUTPUT_TRUNCATED_ORIGINAL_LENGTH = 89, + + // Logs the truncated length of the tool output. 
+ GEMINI_CLI_TOOL_OUTPUT_TRUNCATED_TRUNCATED_LENGTH = 90, + + // Logs the threshold at which the tool output was truncated. + GEMINI_CLI_TOOL_OUTPUT_TRUNCATED_THRESHOLD = 91, + + // Logs the number of lines the tool output was truncated to. + GEMINI_CLI_TOOL_OUTPUT_TRUNCATED_LINES = 92, } diff --git a/packages/core/src/telemetry/constants.ts b/packages/core/src/telemetry/constants.ts index 6b62b6deed6..2e06dacd4f3 100644 --- a/packages/core/src/telemetry/constants.ts +++ b/packages/core/src/telemetry/constants.ts @@ -13,6 +13,7 @@ export const EVENT_API_ERROR = 'gemini_cli.api_error'; export const EVENT_API_RESPONSE = 'gemini_cli.api_response'; export const EVENT_CLI_CONFIG = 'gemini_cli.config'; export const EVENT_FLASH_FALLBACK = 'gemini_cli.flash_fallback'; +export const EVENT_RIPGREP_FALLBACK = 'gemini_cli.ripgrep_fallback'; export const EVENT_NEXT_SPEAKER_CHECK = 'gemini_cli.next_speaker_check'; export const EVENT_SLASH_COMMAND = 'gemini_cli.slash_command'; export const EVENT_IDE_CONNECTION = 'gemini_cli.ide_connection'; diff --git a/packages/core/src/telemetry/high-water-mark-tracker.test.ts b/packages/core/src/telemetry/high-water-mark-tracker.test.ts new file mode 100644 index 00000000000..568b2a79710 --- /dev/null +++ b/packages/core/src/telemetry/high-water-mark-tracker.test.ts @@ -0,0 +1,198 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, beforeEach, vi } from 'vitest'; +import { HighWaterMarkTracker } from './high-water-mark-tracker.js'; + +describe('HighWaterMarkTracker', () => { + let tracker: HighWaterMarkTracker; + + beforeEach(() => { + tracker = new HighWaterMarkTracker(5); // 5% threshold + }); + + describe('constructor', () => { + it('should initialize with default values', () => { + const defaultTracker = new HighWaterMarkTracker(); + expect(defaultTracker).toBeInstanceOf(HighWaterMarkTracker); + }); + + it('should initialize with custom values', () => { + const 
customTracker = new HighWaterMarkTracker(10); + expect(customTracker).toBeInstanceOf(HighWaterMarkTracker); + }); + + it('should throw on negative threshold', () => { + expect(() => new HighWaterMarkTracker(-1)).toThrow( + 'growthThresholdPercent must be non-negative.', + ); + }); + }); + + describe('shouldRecordMetric', () => { + it('should return true for first measurement', () => { + const result = tracker.shouldRecordMetric('heap_used', 1000000); + expect(result).toBe(true); + }); + + it('should return false for small increases', () => { + // Set initial high-water mark + tracker.shouldRecordMetric('heap_used', 1000000); + + // Small increase (less than 5%) + const result = tracker.shouldRecordMetric('heap_used', 1030000); // 3% increase + expect(result).toBe(false); + }); + + it('should return true for significant increases', () => { + // Set initial high-water mark + tracker.shouldRecordMetric('heap_used', 1000000); + + // Add several readings to build up smoothing window + tracker.shouldRecordMetric('heap_used', 1100000); // 10% increase + tracker.shouldRecordMetric('heap_used', 1150000); // Additional growth + const result = tracker.shouldRecordMetric('heap_used', 1200000); // Sustained growth + expect(result).toBe(true); + }); + + it('should handle decreasing values correctly', () => { + // Set initial high-water mark + tracker.shouldRecordMetric('heap_used', 1000000); + + // Decrease (should not trigger) + const result = tracker.shouldRecordMetric('heap_used', 900000); // 10% decrease + expect(result).toBe(false); + }); + + it('should update high-water mark when threshold exceeded', () => { + tracker.shouldRecordMetric('heap_used', 1000000); + + const beforeMark = tracker.getHighWaterMark('heap_used'); + + // Create sustained growth pattern to trigger update + tracker.shouldRecordMetric('heap_used', 1100000); + tracker.shouldRecordMetric('heap_used', 1150000); + tracker.shouldRecordMetric('heap_used', 1200000); + + const afterMark = 
tracker.getHighWaterMark('heap_used'); + + expect(afterMark).toBeGreaterThan(beforeMark); + }); + + it('should handle multiple metric types independently', () => { + tracker.shouldRecordMetric('heap_used', 1000000); + tracker.shouldRecordMetric('rss', 2000000); + + expect(tracker.getHighWaterMark('heap_used')).toBeGreaterThan(0); + expect(tracker.getHighWaterMark('rss')).toBeGreaterThan(0); + expect(tracker.getHighWaterMark('heap_used')).not.toBe( + tracker.getHighWaterMark('rss'), + ); + }); + }); + + describe('smoothing functionality', () => { + it('should reduce noise from garbage collection spikes', () => { + // Establish baseline + tracker.shouldRecordMetric('heap_used', 1000000); + tracker.shouldRecordMetric('heap_used', 1000000); + tracker.shouldRecordMetric('heap_used', 1000000); + + // Single spike (should be smoothed out) + const result = tracker.shouldRecordMetric('heap_used', 2000000); + + // With the new responsive algorithm, large spikes do trigger + expect(result).toBe(true); + }); + + it('should eventually respond to sustained growth', () => { + // Establish baseline + tracker.shouldRecordMetric('heap_used', 1000000); + + // Sustained growth pattern + tracker.shouldRecordMetric('heap_used', 1100000); + tracker.shouldRecordMetric('heap_used', 1150000); + const result = tracker.shouldRecordMetric('heap_used', 1200000); + + expect(result).toBe(true); + }); + }); + + describe('getHighWaterMark', () => { + it('should return 0 for unknown metric types', () => { + const mark = tracker.getHighWaterMark('unknown_metric'); + expect(mark).toBe(0); + }); + + it('should return correct value for known metric types', () => { + tracker.shouldRecordMetric('heap_used', 1000000); + const mark = tracker.getHighWaterMark('heap_used'); + expect(mark).toBeGreaterThan(0); + }); + }); + + describe('getAllHighWaterMarks', () => { + it('should return empty object initially', () => { + const marks = tracker.getAllHighWaterMarks(); + expect(marks).toEqual({}); + }); + + 
it('should return all recorded marks', () => { + tracker.shouldRecordMetric('heap_used', 1000000); + tracker.shouldRecordMetric('rss', 2000000); + + const marks = tracker.getAllHighWaterMarks(); + expect(Object.keys(marks)).toHaveLength(2); + expect(marks['heap_used']).toBeGreaterThan(0); + expect(marks['rss']).toBeGreaterThan(0); + }); + }); + + describe('resetHighWaterMark', () => { + it('should reset specific metric type', () => { + tracker.shouldRecordMetric('heap_used', 1000000); + tracker.shouldRecordMetric('rss', 2000000); + + tracker.resetHighWaterMark('heap_used'); + + expect(tracker.getHighWaterMark('heap_used')).toBe(0); + expect(tracker.getHighWaterMark('rss')).toBeGreaterThan(0); + }); + }); + + describe('resetAllHighWaterMarks', () => { + it('should reset all metrics', () => { + tracker.shouldRecordMetric('heap_used', 1000000); + tracker.shouldRecordMetric('rss', 2000000); + + tracker.resetAllHighWaterMarks(); + + expect(tracker.getHighWaterMark('heap_used')).toBe(0); + expect(tracker.getHighWaterMark('rss')).toBe(0); + expect(tracker.getAllHighWaterMarks()).toEqual({}); + }); + }); + + describe('time-based cleanup', () => { + it('should clean up old readings', () => { + vi.useFakeTimers(); + + // Add readings + tracker.shouldRecordMetric('heap_used', 1000000); + + // Advance time significantly + vi.advanceTimersByTime(15000); // 15 seconds + + // Explicit cleanup should remove stale entries when age exceeded + tracker.cleanup(10000); // 10 seconds + + // Entry should be removed + expect(tracker.getHighWaterMark('heap_used')).toBe(0); + + vi.useRealTimers(); + }); + }); +}); diff --git a/packages/core/src/telemetry/high-water-mark-tracker.ts b/packages/core/src/telemetry/high-water-mark-tracker.ts new file mode 100644 index 00000000000..7317650bb41 --- /dev/null +++ b/packages/core/src/telemetry/high-water-mark-tracker.ts @@ -0,0 +1,100 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +/** + * 
High-water mark tracker for memory metrics + * Only triggers when memory usage increases by a significant threshold + */ +export class HighWaterMarkTracker { + private waterMarks: Map = new Map(); + private lastUpdateTimes: Map = new Map(); + private readonly growthThresholdPercent: number; + + constructor(growthThresholdPercent: number = 5) { + if (growthThresholdPercent < 0) { + throw new Error('growthThresholdPercent must be non-negative.'); + } + this.growthThresholdPercent = growthThresholdPercent; + } + + /** + * Check if current value represents a new high-water mark that should trigger recording + * @param metricType - Type of metric (e.g., 'heap_used', 'rss') + * @param currentValue - Current memory value in bytes + * @returns true if this value should trigger a recording + */ + shouldRecordMetric(metricType: string, currentValue: number): boolean { + const now = Date.now(); + // Track last seen time for cleanup regardless of whether we record + this.lastUpdateTimes.set(metricType, now); + // Get current high-water mark + const currentWaterMark = this.waterMarks.get(metricType) || 0; + + // For first measurement, always record + if (currentWaterMark === 0) { + this.waterMarks.set(metricType, currentValue); + this.lastUpdateTimes.set(metricType, now); + return true; + } + + // Check if current value exceeds threshold + const thresholdValue = + currentWaterMark * (1 + this.growthThresholdPercent / 100); + + if (currentValue > thresholdValue) { + // Update high-water mark + this.waterMarks.set(metricType, currentValue); + this.lastUpdateTimes.set(metricType, now); + return true; + } + + return false; + } + + /** + * Get current high-water mark for a metric type + */ + getHighWaterMark(metricType: string): number { + return this.waterMarks.get(metricType) || 0; + } + + /** + * Get all high-water marks + */ + getAllHighWaterMarks(): Record { + return Object.fromEntries(this.waterMarks); + } + + /** + * Reset high-water mark for a specific metric type + */ + 
resetHighWaterMark(metricType: string): void { + this.waterMarks.delete(metricType); + this.lastUpdateTimes.delete(metricType); + } + + /** + * Reset all high-water marks + */ + resetAllHighWaterMarks(): void { + this.waterMarks.clear(); + this.lastUpdateTimes.clear(); + } + + /** + * Remove stale entries to avoid unbounded growth if metric types are variable. + * Entries not updated within maxAgeMs will be removed. + */ + cleanup(maxAgeMs: number = 3600000): void { + const cutoffTime = Date.now() - maxAgeMs; + for (const [metricType, lastTime] of this.lastUpdateTimes.entries()) { + if (lastTime < cutoffTime) { + this.lastUpdateTimes.delete(metricType); + this.waterMarks.delete(metricType); + } + } + } +} diff --git a/packages/core/src/telemetry/index.ts b/packages/core/src/telemetry/index.ts index a5d33cc34b8..2a22f684cf4 100644 --- a/packages/core/src/telemetry/index.ts +++ b/packages/core/src/telemetry/index.ts @@ -30,6 +30,7 @@ export { logConversationFinishedEvent, logKittySequenceOverflow, logChatCompression, + logToolOutputTruncated, } from './loggers.js'; export type { SlashCommandEvent, ChatCompressionEvent } from './types.js'; export { @@ -44,9 +45,14 @@ export { ToolCallEvent, ConversationFinishedEvent, KittySequenceOverflowEvent, + ToolOutputTruncatedEvent, } from './types.js'; export { makeSlashCommandEvent, makeChatCompressionEvent } from './types.js'; export type { TelemetryEvent } from './types.js'; export { SpanStatusCode, ValueType } from '@opentelemetry/api'; export { SemanticAttributes } from '@opentelemetry/semantic-conventions'; export * from './uiTelemetry.js'; +export { HighWaterMarkTracker } from './high-water-mark-tracker.js'; +export { RateLimiter } from './rate-limiter.js'; +export { streamingTelemetryService } from './streamingTelemetry.js'; +export type { TelemetryStreamListener } from './streamingTelemetry.js'; diff --git a/packages/core/src/telemetry/loggers.test.ts b/packages/core/src/telemetry/loggers.test.ts index 
0407154961a..1cee336d99c 100644 --- a/packages/core/src/telemetry/loggers.test.ts +++ b/packages/core/src/telemetry/loggers.test.ts @@ -18,6 +18,7 @@ import { ToolErrorType, ToolRegistry, } from '../index.js'; +import { OutputFormat } from '../output/types.js'; import { logs } from '@opentelemetry/api-logs'; import { SemanticAttributes } from '@opentelemetry/semantic-conventions'; import type { Config } from '../config/config.js'; @@ -30,6 +31,7 @@ import { EVENT_FLASH_FALLBACK, EVENT_MALFORMED_JSON_RESPONSE, EVENT_FILE_OPERATION, + EVENT_RIPGREP_FALLBACK, } from './constants.js'; import { logApiRequest, @@ -41,6 +43,8 @@ import { logChatCompression, logMalformedJsonResponse, logFileOperation, + logRipgrepFallback, + logToolOutputTruncated, } from './loggers.js'; import { ToolCallDecision } from './tool-call-decision.js'; import { @@ -50,9 +54,11 @@ import { ToolCallEvent, UserPromptEvent, FlashFallbackEvent, + RipgrepFallbackEvent, MalformedJsonResponseEvent, makeChatCompressionEvent, FileOperationEvent, + ToolOutputTruncatedEvent, } from './types.js'; import * as metrics from './metrics.js'; import { FileOperation } from './metrics.js'; @@ -154,6 +160,7 @@ describe('loggers', () => { getQuestion: () => 'test-question', getTargetDir: () => 'target-dir', getProxy: () => 'http://test.proxy.com:8080', + getOutputFormat: () => OutputFormat.JSON, } as unknown as Config; const startSessionEvent = new StartSessionEvent(mockConfig); @@ -180,6 +187,7 @@ describe('loggers', () => { mcp_servers_count: 1, mcp_tools: undefined, mcp_tools_count: undefined, + output_format: 'json', }, }); }); @@ -312,7 +320,6 @@ describe('loggers', () => { response_text: 'test-response', prompt_id: 'prompt-id-1', auth_type: 'oauth-personal', - error: undefined, }, }); @@ -321,7 +328,6 @@ describe('loggers', () => { 'test-model', 100, 200, - undefined, ); expect(mockMetrics.recordTokenUsageMetrics).toHaveBeenCalledWith( @@ -337,45 +343,6 @@ describe('loggers', () => { 'event.timestamp': 
'2025-01-01T00:00:00.000Z', }); }); - - it('should log an API response with an error', () => { - const usageData: GenerateContentResponseUsageMetadata = { - promptTokenCount: 17, - candidatesTokenCount: 50, - cachedContentTokenCount: 10, - thoughtsTokenCount: 5, - toolUsePromptTokenCount: 2, - }; - const event = new ApiResponseEvent( - 'test-model', - 100, - 'prompt-id-1', - AuthType.USE_GEMINI, - usageData, - 'test-response', - 'test-error', - ); - - logApiResponse(mockConfig, event); - - expect(mockLogger.emit).toHaveBeenCalledWith({ - body: 'API response from test-model. Status: 200. Duration: 100ms.', - attributes: { - 'session.id': 'test-session-id', - 'user.email': 'test-user@example.com', - ...event, - 'event.name': EVENT_API_RESPONSE, - 'event.timestamp': '2025-01-01T00:00:00.000Z', - 'error.message': 'test-error', - }, - }); - - expect(mockUiEvent.addEvent).toHaveBeenCalledWith({ - ...event, - 'event.name': EVENT_API_RESPONSE, - 'event.timestamp': '2025-01-01T00:00:00.000Z', - }); - }); }); describe('logApiRequest', () => { @@ -453,6 +420,59 @@ describe('loggers', () => { }); }); + describe('logRipgrepFallback', () => { + const mockConfig = { + getSessionId: () => 'test-session-id', + getUsageStatisticsEnabled: () => true, + } as unknown as Config; + + beforeEach(() => { + vi.spyOn(ClearcutLogger.prototype, 'logRipgrepFallbackEvent'); + }); + + it('should log ripgrep fallback event', () => { + const event = new RipgrepFallbackEvent(); + + logRipgrepFallback(mockConfig, event); + + expect( + ClearcutLogger.prototype.logRipgrepFallbackEvent, + ).toHaveBeenCalled(); + + const emittedEvent = mockLogger.emit.mock.calls[0][0]; + expect(emittedEvent.body).toBe('Switching to grep as fallback.'); + expect(emittedEvent.attributes).toEqual( + expect.objectContaining({ + 'session.id': 'test-session-id', + 'user.email': 'test-user@example.com', + 'event.name': EVENT_RIPGREP_FALLBACK, + error: undefined, + }), + ); + }); + + it('should log ripgrep fallback event with an 
error', () => { + const event = new RipgrepFallbackEvent('rg not found'); + + logRipgrepFallback(mockConfig, event); + + expect( + ClearcutLogger.prototype.logRipgrepFallbackEvent, + ).toHaveBeenCalled(); + + const emittedEvent = mockLogger.emit.mock.calls[0][0]; + expect(emittedEvent.body).toBe('Switching to grep as fallback.'); + expect(emittedEvent.attributes).toEqual( + expect.objectContaining({ + 'session.id': 'test-session-id', + 'user.email': 'test-user@example.com', + 'event.name': EVENT_RIPGREP_FALLBACK, + error: 'rg not found', + }), + ); + }); + }); + describe('logToolCall', () => { const cfg1 = { getSessionId: () => 'test-session-id', @@ -524,9 +544,25 @@ describe('loggers', () => { response: { callId: 'test-call-id', responseParts: [{ text: 'test-response' }], - resultDisplay: undefined, + resultDisplay: { + fileDiff: 'diff', + fileName: 'file.txt', + originalContent: 'old content', + newContent: 'new content', + diffStat: { + model_added_lines: 1, + model_removed_lines: 2, + model_added_chars: 3, + model_removed_chars: 4, + user_added_lines: 5, + user_removed_lines: 6, + user_added_chars: 7, + user_removed_chars: 8, + }, + }, error: undefined, errorType: undefined, + contentLength: 13, }, tool, invocation: {} as AnyToolInvocation, @@ -560,7 +596,18 @@ describe('loggers', () => { tool_type: 'native', error: undefined, error_type: undefined, - metadata: undefined, + + metadata: { + model_added_lines: 1, + model_removed_lines: 2, + model_added_chars: 3, + model_removed_chars: 4, + user_added_lines: 5, + user_removed_lines: 6, + user_added_chars: 7, + user_removed_chars: 8, + }, + content_length: 13, }, }); @@ -598,6 +645,7 @@ describe('loggers', () => { resultDisplay: undefined, error: undefined, errorType: undefined, + contentLength: undefined, }, durationMs: 100, outcome: ToolConfirmationOutcome.Cancel, @@ -630,6 +678,7 @@ describe('loggers', () => { error: undefined, error_type: undefined, metadata: undefined, + content_length: undefined, }, }); @@ 
-668,6 +717,7 @@ describe('loggers', () => { resultDisplay: undefined, error: undefined, errorType: undefined, + contentLength: 13, }, outcome: ToolConfirmationOutcome.ModifyWithEditor, tool: new EditTool(mockConfig), @@ -702,6 +752,7 @@ describe('loggers', () => { error: undefined, error_type: undefined, metadata: undefined, + content_length: 13, }, }); @@ -740,6 +791,7 @@ describe('loggers', () => { resultDisplay: undefined, error: undefined, errorType: undefined, + contentLength: 13, }, tool: new EditTool(mockConfig), invocation: {} as AnyToolInvocation, @@ -773,6 +825,7 @@ describe('loggers', () => { error: undefined, error_type: undefined, metadata: undefined, + content_length: 13, }, }); @@ -793,6 +846,7 @@ describe('loggers', () => { }); it('should log a failed tool call with an error', () => { + const errorMessage = 'test-error'; const call: ErroredToolCall = { status: 'error', request: { @@ -809,11 +863,9 @@ describe('loggers', () => { callId: 'test-call-id', responseParts: [{ text: 'test-response' }], resultDisplay: undefined, - error: { - name: 'test-error-type', - message: 'test-error', - }, + error: new Error(errorMessage), errorType: ToolErrorType.UNKNOWN, + contentLength: errorMessage.length, }, durationMs: 100, }; @@ -847,6 +899,7 @@ describe('loggers', () => { tool_type: 'native', decision: undefined, metadata: undefined, + content_length: errorMessage.length, }, }); @@ -952,4 +1005,40 @@ describe('loggers', () => { ); }); }); + + describe('logToolOutputTruncated', () => { + const mockConfig = { + getSessionId: () => 'test-session-id', + getUsageStatisticsEnabled: () => true, + } as unknown as Config; + + it('should log a tool output truncated event', () => { + const event = new ToolOutputTruncatedEvent('prompt-id-1', { + toolName: 'test-tool', + originalContentLength: 1000, + truncatedContentLength: 100, + threshold: 500, + lines: 10, + }); + + logToolOutputTruncated(mockConfig, event); + + expect(mockLogger.emit).toHaveBeenCalledWith({ + body: 
'Tool output truncated for test-tool.', + attributes: { + 'session.id': 'test-session-id', + 'user.email': 'test-user@example.com', + 'event.name': 'tool_output_truncated', + 'event.timestamp': '2025-01-01T00:00:00.000Z', + eventName: 'tool_output_truncated', + prompt_id: 'prompt-id-1', + tool_name: 'test-tool', + original_content_length: 1000, + truncated_content_length: 100, + threshold: 500, + lines: 10, + }, + }); + }); + }); }); diff --git a/packages/core/src/telemetry/loggers.ts b/packages/core/src/telemetry/loggers.ts index 06476d3a81f..5f556a53ecf 100644 --- a/packages/core/src/telemetry/loggers.ts +++ b/packages/core/src/telemetry/loggers.ts @@ -27,6 +27,7 @@ import { EVENT_CONTENT_RETRY, EVENT_CONTENT_RETRY_FAILURE, EVENT_FILE_OPERATION, + EVENT_RIPGREP_FALLBACK, } from './constants.js'; import type { ApiErrorEvent, @@ -48,6 +49,8 @@ import type { InvalidChunkEvent, ContentRetryEvent, ContentRetryFailureEvent, + RipgrepFallbackEvent, + ToolOutputTruncatedEvent, } from './types.js'; import { recordApiErrorMetrics, @@ -63,6 +66,7 @@ import { import { isTelemetrySdkInitialized } from './sdk.js'; import type { UiEvent } from './uiTelemetry.js'; import { uiTelemetryService } from './uiTelemetry.js'; +import { streamingTelemetryService } from './streamingTelemetry.js'; import { ClearcutLogger } from './clearcut-logger/clearcut-logger.js'; import { safeJsonStringify } from '../utils/safeJsonStringify.js'; import { UserAccountManager } from '../utils/userAccountManager.js'; @@ -104,6 +108,7 @@ export function logCliConfiguration( mcp_servers_count: event.mcp_servers_count, mcp_tools: event.mcp_tools, mcp_tools_count: event.mcp_tools_count, + output_format: event.output_format, }; const logger = logs.getLogger(SERVICE_NAME); @@ -115,6 +120,7 @@ export function logCliConfiguration( } export function logUserPrompt(config: Config, event: UserPromptEvent): void { + streamingTelemetryService.emitEvent(event); 
ClearcutLogger.getInstance(config)?.logNewPromptEvent(event); if (!isTelemetrySdkInitialized()) return; @@ -143,6 +149,7 @@ export function logUserPrompt(config: Config, event: UserPromptEvent): void { } export function logToolCall(config: Config, event: ToolCallEvent): void { + streamingTelemetryService.emitEvent(event); const uiEvent = { ...event, 'event.name': EVENT_TOOL_CALL, @@ -182,6 +189,28 @@ export function logToolCall(config: Config, event: ToolCallEvent): void { ); } +export function logToolOutputTruncated( + config: Config, + event: ToolOutputTruncatedEvent, +): void { + ClearcutLogger.getInstance(config)?.logToolOutputTruncatedEvent(event); + if (!isTelemetrySdkInitialized()) return; + + const attributes: LogAttributes = { + ...getCommonAttributes(config), + ...event, + 'event.name': 'tool_output_truncated', + 'event.timestamp': new Date().toISOString(), + }; + + const logger = logs.getLogger(SERVICE_NAME); + const logRecord: LogRecord = { + body: `Tool output truncated for ${event.tool_name}.`, + attributes, + }; + logger.emit(logRecord); +} + export function logFileOperation( config: Config, event: FileOperationEvent, @@ -268,6 +297,28 @@ export function logFlashFallback( logger.emit(logRecord); } +export function logRipgrepFallback( + config: Config, + event: RipgrepFallbackEvent, +): void { + ClearcutLogger.getInstance(config)?.logRipgrepFallbackEvent(); + if (!isTelemetrySdkInitialized()) return; + + const attributes: LogAttributes = { + ...getCommonAttributes(config), + ...event, + 'event.name': EVENT_RIPGREP_FALLBACK, + 'event.timestamp': new Date().toISOString(), + }; + + const logger = logs.getLogger(SERVICE_NAME); + const logRecord: LogRecord = { + body: `Switching to grep as fallback.`, + attributes, + }; + logger.emit(logRecord); +} + export function logApiError(config: Config, event: ApiErrorEvent): void { const uiEvent = { ...event, @@ -311,6 +362,7 @@ export function logApiError(config: Config, event: ApiErrorEvent): void { } export 
function logApiResponse(config: Config, event: ApiResponseEvent): void { + streamingTelemetryService.emitEvent(event); const uiEvent = { ...event, 'event.name': EVENT_API_RESPONSE, @@ -328,9 +380,7 @@ export function logApiResponse(config: Config, event: ApiResponseEvent): void { if (event.response_text) { attributes['response_text'] = event.response_text; } - if (event.error) { - attributes['error.message'] = event.error; - } else if (event.status_code) { + if (event.status_code) { if (typeof event.status_code === 'number') { attributes[SemanticAttributes.HTTP_STATUS_CODE] = event.status_code; } @@ -347,7 +397,6 @@ export function logApiResponse(config: Config, event: ApiResponseEvent): void { event.model, event.duration_ms, event.status_code, - event.error, ); recordTokenUsageMetrics( config, diff --git a/packages/core/src/telemetry/metrics.ts b/packages/core/src/telemetry/metrics.ts index 385ee076f70..a51f74cee90 100644 --- a/packages/core/src/telemetry/metrics.ts +++ b/packages/core/src/telemetry/metrics.ts @@ -174,7 +174,6 @@ export function recordApiResponseMetrics( model: string, durationMs: number, statusCode?: number | string, - error?: string, ): void { if ( !apiRequestCounter || @@ -185,7 +184,7 @@ export function recordApiResponseMetrics( const metricAttributes: Attributes = { ...getCommonAttributes(config), model, - status_code: statusCode ?? (error ? 'error' : 'ok'), + status_code: statusCode ?? 
'ok', }; apiRequestCounter.add(1, metricAttributes); apiRequestLatencyHistogram.record(durationMs, { diff --git a/packages/core/src/telemetry/rate-limiter.test.ts b/packages/core/src/telemetry/rate-limiter.test.ts new file mode 100644 index 00000000000..11e94690438 --- /dev/null +++ b/packages/core/src/telemetry/rate-limiter.test.ts @@ -0,0 +1,293 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, beforeEach, vi } from 'vitest'; +import { RateLimiter } from './rate-limiter.js'; + +describe('RateLimiter', () => { + let rateLimiter: RateLimiter; + + beforeEach(() => { + rateLimiter = new RateLimiter(1000); // 1 second interval for testing + }); + + describe('constructor', () => { + it('should initialize with default interval', () => { + const defaultLimiter = new RateLimiter(); + expect(defaultLimiter).toBeInstanceOf(RateLimiter); + }); + + it('should initialize with custom interval', () => { + const customLimiter = new RateLimiter(5000); + expect(customLimiter).toBeInstanceOf(RateLimiter); + }); + + it('should throw on negative interval', () => { + expect(() => new RateLimiter(-1)).toThrow( + 'minIntervalMs must be non-negative.', + ); + }); + }); + + describe('shouldRecord', () => { + it('should allow first recording', () => { + const result = rateLimiter.shouldRecord('test_metric'); + expect(result).toBe(true); + }); + + it('should block immediate subsequent recordings', () => { + rateLimiter.shouldRecord('test_metric'); // First call + const result = rateLimiter.shouldRecord('test_metric'); // Immediate second call + expect(result).toBe(false); + }); + + it('should allow recording after interval', () => { + vi.useFakeTimers(); + + rateLimiter.shouldRecord('test_metric'); // First call + + // Advance time past interval + vi.advanceTimersByTime(1500); + + const result = rateLimiter.shouldRecord('test_metric'); + expect(result).toBe(true); + + vi.useRealTimers(); + }); + + it('should 
handle different metric keys independently', () => { + rateLimiter.shouldRecord('metric_a'); // First call for metric_a + + const resultA = rateLimiter.shouldRecord('metric_a'); // Second call for metric_a + const resultB = rateLimiter.shouldRecord('metric_b'); // First call for metric_b + + expect(resultA).toBe(false); // Should be blocked + expect(resultB).toBe(true); // Should be allowed + }); + + it('should use shorter interval for high priority events', () => { + vi.useFakeTimers(); + + rateLimiter.shouldRecord('test_metric', true); // High priority + + // Advance time by half the normal interval + vi.advanceTimersByTime(500); + + const result = rateLimiter.shouldRecord('test_metric', true); + expect(result).toBe(true); // Should be allowed due to high priority + + vi.useRealTimers(); + }); + + it('should still block high priority events if interval not met', () => { + vi.useFakeTimers(); + + rateLimiter.shouldRecord('test_metric', true); // High priority + + // Advance time by less than half interval + vi.advanceTimersByTime(300); + + const result = rateLimiter.shouldRecord('test_metric', true); + expect(result).toBe(false); // Should still be blocked + + vi.useRealTimers(); + }); + }); + + describe('forceRecord', () => { + it('should update last record time', () => { + const before = rateLimiter.getTimeUntilNextAllowed('test_metric'); + + rateLimiter.forceRecord('test_metric'); + + const after = rateLimiter.getTimeUntilNextAllowed('test_metric'); + expect(after).toBeGreaterThan(before); + }); + + it('should block subsequent recordings after force record', () => { + rateLimiter.forceRecord('test_metric'); + + const result = rateLimiter.shouldRecord('test_metric'); + expect(result).toBe(false); + }); + }); + + describe('getTimeUntilNextAllowed', () => { + it('should return 0 for new metric', () => { + const time = rateLimiter.getTimeUntilNextAllowed('new_metric'); + expect(time).toBe(0); + }); + + it('should return correct time after recording', () => { + 
vi.useFakeTimers(); + + rateLimiter.shouldRecord('test_metric'); + + // Advance time partially + vi.advanceTimersByTime(300); + + const timeRemaining = rateLimiter.getTimeUntilNextAllowed('test_metric'); + expect(timeRemaining).toBeCloseTo(700, -1); // Approximately 700ms remaining + + vi.useRealTimers(); + }); + + it('should return 0 after interval has passed', () => { + vi.useFakeTimers(); + + rateLimiter.shouldRecord('test_metric'); + + // Advance time past interval + vi.advanceTimersByTime(1500); + + const timeRemaining = rateLimiter.getTimeUntilNextAllowed('test_metric'); + expect(timeRemaining).toBe(0); + + vi.useRealTimers(); + }); + + it('should account for high priority interval', () => { + vi.useFakeTimers(); + + rateLimiter.shouldRecord('hp_metric', true); + + // After 300ms, with 1000ms base interval, half rounded is 500ms + vi.advanceTimersByTime(300); + + const timeRemaining = rateLimiter.getTimeUntilNextAllowed( + 'hp_metric', + true, + ); + expect(timeRemaining).toBeCloseTo(200, -1); + + vi.useRealTimers(); + }); + }); + + describe('getStats', () => { + it('should return empty stats initially', () => { + const stats = rateLimiter.getStats(); + expect(stats).toEqual({ + totalMetrics: 0, + oldestRecord: 0, + newestRecord: 0, + averageInterval: 0, + }); + }); + + it('should return correct stats after recordings', () => { + vi.useFakeTimers(); + + rateLimiter.shouldRecord('metric_a'); + vi.advanceTimersByTime(500); + rateLimiter.shouldRecord('metric_b'); + vi.advanceTimersByTime(500); + rateLimiter.shouldRecord('metric_c'); + + const stats = rateLimiter.getStats(); + expect(stats.totalMetrics).toBe(3); + expect(stats.averageInterval).toBeCloseTo(500, -1); + + vi.useRealTimers(); + }); + + it('should handle single recording correctly', () => { + rateLimiter.shouldRecord('test_metric'); + + const stats = rateLimiter.getStats(); + expect(stats.totalMetrics).toBe(1); + expect(stats.averageInterval).toBe(0); + }); + }); + + describe('reset', () => { + 
it('should clear all rate limiting state', () => { + rateLimiter.shouldRecord('metric_a'); + rateLimiter.shouldRecord('metric_b'); + + rateLimiter.reset(); + + const stats = rateLimiter.getStats(); + expect(stats.totalMetrics).toBe(0); + + // Should allow immediate recording after reset + const result = rateLimiter.shouldRecord('metric_a'); + expect(result).toBe(true); + }); + }); + + describe('cleanup', () => { + it('should remove old entries', () => { + vi.useFakeTimers(); + + rateLimiter.shouldRecord('old_metric'); + + // Advance time beyond cleanup threshold + vi.advanceTimersByTime(4000000); // More than 1 hour + + rateLimiter.cleanup(3600000); // 1 hour cleanup + + // Should allow immediate recording of old metric after cleanup + const result = rateLimiter.shouldRecord('old_metric'); + expect(result).toBe(true); + + vi.useRealTimers(); + }); + + it('should preserve recent entries', () => { + vi.useFakeTimers(); + + rateLimiter.shouldRecord('recent_metric'); + + // Advance time but not beyond cleanup threshold + vi.advanceTimersByTime(1800000); // 30 minutes + + rateLimiter.cleanup(3600000); // 1 hour cleanup + + // Should no longer be rate limited after 30 minutes (way past 1 minute default interval) + const result = rateLimiter.shouldRecord('recent_metric'); + expect(result).toBe(true); + + vi.useRealTimers(); + }); + + it('should use default cleanup age', () => { + vi.useFakeTimers(); + + rateLimiter.shouldRecord('test_metric'); + + // Advance time beyond default cleanup (1 hour) + vi.advanceTimersByTime(4000000); + + rateLimiter.cleanup(); // Use default age + + const result = rateLimiter.shouldRecord('test_metric'); + expect(result).toBe(true); + + vi.useRealTimers(); + }); + }); + + describe('edge cases', () => { + it('should handle zero interval', () => { + const zeroLimiter = new RateLimiter(0); + + zeroLimiter.shouldRecord('test_metric'); + const result = zeroLimiter.shouldRecord('test_metric'); + + expect(result).toBe(true); // Should allow with zero 
interval + }); + + it('should handle very large intervals', () => { + const longLimiter = new RateLimiter(Number.MAX_SAFE_INTEGER); + + longLimiter.shouldRecord('test_metric'); + const timeRemaining = longLimiter.getTimeUntilNextAllowed('test_metric'); + + expect(timeRemaining).toBeGreaterThan(1000000); + }); + }); +}); diff --git a/packages/core/src/telemetry/rate-limiter.ts b/packages/core/src/telemetry/rate-limiter.ts new file mode 100644 index 00000000000..076887cd354 --- /dev/null +++ b/packages/core/src/telemetry/rate-limiter.ts @@ -0,0 +1,124 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +/** + * Rate limiter to prevent excessive telemetry recording + * Ensures we don't send metrics more frequently than specified limits + */ +export class RateLimiter { + private lastRecordTimes: Map = new Map(); + private readonly minIntervalMs: number; + private static readonly HIGH_PRIORITY_DIVISOR = 2; + + constructor(minIntervalMs: number = 60000) { + if (minIntervalMs < 0) { + throw new Error('minIntervalMs must be non-negative.'); + } + this.minIntervalMs = minIntervalMs; + } + + /** + * Check if we should record a metric based on rate limiting + * @param metricKey - Unique key for the metric type/context + * @param isHighPriority - If true, uses shorter interval for critical events + * @returns true if metric should be recorded + */ + shouldRecord(metricKey: string, isHighPriority: boolean = false): boolean { + const now = Date.now(); + const lastRecordTime = this.lastRecordTimes.get(metricKey) || 0; + + // Use shorter interval for high priority events (e.g., memory leaks) + const interval = isHighPriority + ? 
Math.round(this.minIntervalMs / RateLimiter.HIGH_PRIORITY_DIVISOR) + : this.minIntervalMs; + + if (now - lastRecordTime >= interval) { + this.lastRecordTimes.set(metricKey, now); + return true; + } + + return false; + } + + /** + * Force record a metric (bypasses rate limiting) + * Use sparingly for critical events + */ + forceRecord(metricKey: string): void { + this.lastRecordTimes.set(metricKey, Date.now()); + } + + /** + * Get time until next allowed recording for a metric + */ + getTimeUntilNextAllowed( + metricKey: string, + isHighPriority: boolean = false, + ): number { + const now = Date.now(); + const lastRecordTime = this.lastRecordTimes.get(metricKey) || 0; + const interval = isHighPriority + ? Math.round(this.minIntervalMs / RateLimiter.HIGH_PRIORITY_DIVISOR) + : this.minIntervalMs; + const nextAllowedTime = lastRecordTime + interval; + + return Math.max(0, nextAllowedTime - now); + } + + /** + * Get statistics about rate limiting + */ + getStats(): { + totalMetrics: number; + oldestRecord: number; + newestRecord: number; + averageInterval: number; + } { + const recordTimes = Array.from(this.lastRecordTimes.values()); + + if (recordTimes.length === 0) { + return { + totalMetrics: 0, + oldestRecord: 0, + newestRecord: 0, + averageInterval: 0, + }; + } + + const oldest = Math.min(...recordTimes); + const newest = Math.max(...recordTimes); + const totalSpan = newest - oldest; + const averageInterval = + recordTimes.length > 1 ? 
totalSpan / (recordTimes.length - 1) : 0; + + return { + totalMetrics: recordTimes.length, + oldestRecord: oldest, + newestRecord: newest, + averageInterval, + }; + } + + /** + * Clear all rate limiting state + */ + reset(): void { + this.lastRecordTimes.clear(); + } + + /** + * Remove old entries to prevent memory leaks + */ + cleanup(maxAgeMs: number = 3600000): void { + const cutoffTime = Date.now() - maxAgeMs; + + for (const [key, time] of this.lastRecordTimes.entries()) { + if (time < cutoffTime) { + this.lastRecordTimes.delete(key); + } + } + } +} diff --git a/packages/core/src/telemetry/streamingTelemetry.ts b/packages/core/src/telemetry/streamingTelemetry.ts new file mode 100644 index 00000000000..eab21041425 --- /dev/null +++ b/packages/core/src/telemetry/streamingTelemetry.ts @@ -0,0 +1,44 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { EventEmitter } from 'node:events'; +import type { TelemetryEvent } from './types.js'; + +export interface TelemetryStreamListener { + (event: TelemetryEvent): void; +} + +class StreamingTelemetryService extends EventEmitter { + private enabled = false; + + enable(): void { + this.enabled = true; + } + + disable(): void { + this.enabled = false; + } + + isEnabled(): boolean { + return this.enabled; + } + + addTelemetryListener(listener: TelemetryStreamListener): void { + this.on('telemetry', listener); + } + + removeTelemetryListener(listener: TelemetryStreamListener): void { + this.off('telemetry', listener); + } + + emitEvent(event: TelemetryEvent): void { + if (this.enabled) { + this.emit('telemetry', event); + } + } +} + +export const streamingTelemetryService = new StreamingTelemetryService(); \ No newline at end of file diff --git a/packages/core/src/telemetry/types.ts b/packages/core/src/telemetry/types.ts index 804b0c0304c..e22cc372882 100644 --- a/packages/core/src/telemetry/types.ts +++ b/packages/core/src/telemetry/types.ts @@ -18,6 +18,7 @@ import { 
import type { FileOperation } from './metrics.js'; export { ToolCallDecision }; import type { ToolRegistry } from '../tools/tool-registry.js'; +import type { OutputFormat } from '../output/types.js'; export interface BaseTelemetryEvent { 'event.name': string; @@ -45,6 +46,7 @@ export class StartSessionEvent implements BaseTelemetryEvent { mcp_servers_count: number; mcp_tools_count?: number; mcp_tools?: string; + output_format: OutputFormat; constructor(config: Config, toolRegistry?: ToolRegistry) { const generatorConfig = config.getContentGeneratorConfig(); @@ -74,6 +76,7 @@ export class StartSessionEvent implements BaseTelemetryEvent { this.file_filtering_respect_git_ignore = config.getFileFilteringRespectGitIgnore(); this.mcp_servers_count = mcpServers ? Object.keys(mcpServers).length : 0; + this.output_format = config.getOutputFormat(); if (toolRegistry) { const mcpTools = toolRegistry .getAllTools() @@ -133,6 +136,7 @@ export class ToolCallEvent implements BaseTelemetryEvent { error_type?: string; prompt_id: string; tool_type: 'native' | 'mcp'; + content_length?: number; // eslint-disable-next-line @typescript-eslint/no-explicit-any metadata?: { [key: string]: any }; @@ -153,6 +157,7 @@ export class ToolCallEvent implements BaseTelemetryEvent { typeof call.tool !== 'undefined' && call.tool instanceof DiscoveredMCPTool ? 
'mcp' : 'native'; + this.content_length = call.response.contentLength; if ( call.status === 'success' && @@ -165,8 +170,12 @@ export class ToolCallEvent implements BaseTelemetryEvent { this.metadata = { model_added_lines: diffStat.model_added_lines, model_removed_lines: diffStat.model_removed_lines, + model_added_chars: diffStat.model_added_chars, + model_removed_chars: diffStat.model_removed_chars, user_added_lines: diffStat.user_added_lines, user_removed_lines: diffStat.user_removed_lines, + user_added_chars: diffStat.user_added_chars, + user_removed_chars: diffStat.user_removed_chars, }; } } @@ -227,7 +236,6 @@ export class ApiResponseEvent implements BaseTelemetryEvent { model: string; status_code?: number | string; duration_ms: number; - error?: string; input_token_count: number; output_token_count: number; cached_content_token_count: number; @@ -245,7 +253,6 @@ export class ApiResponseEvent implements BaseTelemetryEvent { auth_type?: string, usage_data?: GenerateContentResponseUsageMetadata, response_text?: string, - error?: string, ) { this['event.name'] = 'api_response'; this['event.timestamp'] = new Date().toISOString(); @@ -259,7 +266,6 @@ export class ApiResponseEvent implements BaseTelemetryEvent { this.tool_token_count = usage_data?.toolUsePromptTokenCount ?? 0; this.total_token_count = usage_data?.totalTokenCount ?? 
0; this.response_text = response_text; - this.error = error; this.prompt_id = prompt_id; this.auth_type = auth_type; } @@ -277,6 +283,16 @@ export class FlashFallbackEvent implements BaseTelemetryEvent { } } +export class RipgrepFallbackEvent implements BaseTelemetryEvent { + 'event.name': 'ripgrep_fallback'; + 'event.timestamp': string; + + constructor(public error?: string) { + this['event.name'] = 'ripgrep_fallback'; + this['event.timestamp'] = new Date().toISOString(); + } +} + export enum LoopType { CONSECUTIVE_IDENTICAL_TOOL_CALLS = 'consecutive_identical_tool_calls', CHANTING_IDENTICAL_SENTENCES = 'chanting_identical_sentences', @@ -517,4 +533,60 @@ export type TelemetryEvent = | FileOperationEvent | InvalidChunkEvent | ContentRetryEvent - | ContentRetryFailureEvent; + | ContentRetryFailureEvent + | ExtensionInstallEvent + | ToolOutputTruncatedEvent; + +export class ExtensionInstallEvent implements BaseTelemetryEvent { + 'event.name': 'extension_install'; + 'event.timestamp': string; + extension_name: string; + extension_version: string; + extension_source: string; + status: 'success' | 'error'; + + constructor( + extension_name: string, + extension_version: string, + extension_source: string, + status: 'success' | 'error', + ) { + this['event.name'] = 'extension_install'; + this['event.timestamp'] = new Date().toISOString(); + this.extension_name = extension_name; + this.extension_version = extension_version; + this.extension_source = extension_source; + this.status = status; + } +} + +export class ToolOutputTruncatedEvent implements BaseTelemetryEvent { + readonly eventName = 'tool_output_truncated'; + readonly 'event.timestamp' = new Date().toISOString(); + 'event.name': string; + tool_name: string; + original_content_length: number; + truncated_content_length: number; + threshold: number; + lines: number; + prompt_id: string; + + constructor( + prompt_id: string, + details: { + toolName: string; + originalContentLength: number; + truncatedContentLength: 
number; + threshold: number; + lines: number; + }, + ) { + this['event.name'] = this.eventName; + this.prompt_id = prompt_id; + this.tool_name = details.toolName; + this.original_content_length = details.originalContentLength; + this.truncated_content_length = details.truncatedContentLength; + this.threshold = details.threshold; + this.lines = details.lines; + } +} diff --git a/packages/core/src/tools/edit.test.ts b/packages/core/src/tools/edit.test.ts index 7b98083e311..49ed12e5a6d 100644 --- a/packages/core/src/tools/edit.test.ts +++ b/packages/core/src/tools/edit.test.ts @@ -10,16 +10,12 @@ const mockEnsureCorrectEdit = vi.hoisted(() => vi.fn()); const mockGenerateJson = vi.hoisted(() => vi.fn()); const mockOpenDiff = vi.hoisted(() => vi.fn()); -import { IdeClient, IDEConnectionStatus } from '../ide/ide-client.js'; +import { IdeClient } from '../ide/ide-client.js'; vi.mock('../ide/ide-client.js', () => ({ IdeClient: { getInstance: vi.fn(), }, - IDEConnectionStatus: { - Connected: 'connected', - Disconnected: 'disconnected', - }, })); vi.mock('../utils/editCorrector.js', () => ({ @@ -896,9 +892,7 @@ describe('EditTool', () => { filePath = path.join(rootDir, testFile); ideClient = { openDiff: vi.fn(), - getConnectionStatus: vi.fn().mockReturnValue({ - status: IDEConnectionStatus.Connected, - }), + isDiffingEnabled: vi.fn().mockReturnValue(true), }; vi.mocked(IdeClient.getInstance).mockResolvedValue(ideClient); (mockConfig as any).getIdeMode = () => true; diff --git a/packages/core/src/tools/edit.ts b/packages/core/src/tools/edit.ts index cb0418b1e48..2c17fe86f09 100644 --- a/packages/core/src/tools/edit.ts +++ b/packages/core/src/tools/edit.ts @@ -32,7 +32,7 @@ import type { ModifiableDeclarativeTool, ModifyContext, } from './modifiable-tool.js'; -import { IdeClient, IDEConnectionStatus } from '../ide/ide-client.js'; +import { IdeClient } from '../ide/ide-client.js'; export function applyReplacement( currentContent: string | null, @@ -268,8 +268,7 @@ class 
EditToolInvocation implements ToolInvocation { ); const ideClient = await IdeClient.getInstance(); const ideConfirmation = - this.config.getIdeMode() && - ideClient?.getConnectionStatus().status === IDEConnectionStatus.Connected + this.config.getIdeMode() && ideClient.isDiffingEnabled() ? ideClient.openDiff(this.params.file_path, editData.newContent) : undefined; diff --git a/packages/core/src/tools/glob.test.ts b/packages/core/src/tools/glob.test.ts index 3a911a57112..b965ce90362 100644 --- a/packages/core/src/tools/glob.test.ts +++ b/packages/core/src/tools/glob.test.ts @@ -28,6 +28,10 @@ describe('GlobTool', () => { const mockConfig = { getFileService: () => new FileDiscoveryService(tempRootDir), getFileFilteringRespectGitIgnore: () => true, + getFileFilteringOptions: () => ({ + respectGitIgnore: true, + respectGeminiIgnore: true, + }), getTargetDir: () => tempRootDir, getWorkspaceContext: () => createMockWorkspaceContext(tempRootDir), getFileExclusions: () => ({ @@ -38,6 +42,7 @@ describe('GlobTool', () => { beforeEach(async () => { // Create a unique root directory for each test run tempRootDir = await fs.mkdtemp(path.join(os.tmpdir(), 'glob-tool-root-')); + await fs.writeFile(path.join(tempRootDir, '.git'), ''); // Fake git repo globTool = new GlobTool(mockConfig); // Create some test files and directories within this root @@ -366,6 +371,88 @@ describe('GlobTool', () => { expect(result.llmContent).toContain('FileD.MD'); }); }); + + describe('ignore file handling', () => { + it('should respect .gitignore files by default', async () => { + await fs.writeFile(path.join(tempRootDir, '.gitignore'), '*.ignored.txt'); + await fs.writeFile( + path.join(tempRootDir, 'a.ignored.txt'), + 'ignored content', + ); + await fs.writeFile( + path.join(tempRootDir, 'b.notignored.txt'), + 'not ignored content', + ); + + const params: GlobToolParams = { pattern: '*.txt' }; + const invocation = globTool.build(params); + const result = await invocation.execute(abortSignal); + + 
expect(result.llmContent).toContain('Found 3 file(s)'); // fileA.txt, FileB.TXT, b.notignored.txt + expect(result.llmContent).not.toContain('a.ignored.txt'); + }); + + it('should respect .geminiignore files by default', async () => { + await fs.writeFile( + path.join(tempRootDir, '.geminiignore'), + '*.geminiignored.txt', + ); + await fs.writeFile( + path.join(tempRootDir, 'a.geminiignored.txt'), + 'ignored content', + ); + await fs.writeFile( + path.join(tempRootDir, 'b.notignored.txt'), + 'not ignored content', + ); + + const params: GlobToolParams = { pattern: '*.txt' }; + const invocation = globTool.build(params); + const result = await invocation.execute(abortSignal); + + expect(result.llmContent).toContain('Found 3 file(s)'); // fileA.txt, FileB.TXT, b.notignored.txt + expect(result.llmContent).not.toContain('a.geminiignored.txt'); + }); + + it('should not respect .gitignore when respect_git_ignore is false', async () => { + await fs.writeFile(path.join(tempRootDir, '.gitignore'), '*.ignored.txt'); + await fs.writeFile( + path.join(tempRootDir, 'a.ignored.txt'), + 'ignored content', + ); + + const params: GlobToolParams = { + pattern: '*.txt', + respect_git_ignore: false, + }; + const invocation = globTool.build(params); + const result = await invocation.execute(abortSignal); + + expect(result.llmContent).toContain('Found 3 file(s)'); // fileA.txt, FileB.TXT, a.ignored.txt + expect(result.llmContent).toContain('a.ignored.txt'); + }); + + it('should not respect .geminiignore when respect_gemini_ignore is false', async () => { + await fs.writeFile( + path.join(tempRootDir, '.geminiignore'), + '*.geminiignored.txt', + ); + await fs.writeFile( + path.join(tempRootDir, 'a.geminiignored.txt'), + 'ignored content', + ); + + const params: GlobToolParams = { + pattern: '*.txt', + respect_gemini_ignore: false, + }; + const invocation = globTool.build(params); + const result = await invocation.execute(abortSignal); + + expect(result.llmContent).toContain('Found 3 
file(s)'); // fileA.txt, FileB.TXT, a.geminiignored.txt + expect(result.llmContent).toContain('a.geminiignored.txt'); + }); + }); }); describe('sortFileEntries', () => { diff --git a/packages/core/src/tools/glob.ts b/packages/core/src/tools/glob.ts index 7efae2baf1c..895b6c9e8a8 100644 --- a/packages/core/src/tools/glob.ts +++ b/packages/core/src/tools/glob.ts @@ -10,7 +10,10 @@ import { glob, escape } from 'glob'; import type { ToolInvocation, ToolResult } from './tools.js'; import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js'; import { shortenPath, makeRelative } from '../utils/paths.js'; -import type { Config } from '../config/config.js'; +import { + type Config, + DEFAULT_FILE_FILTERING_OPTIONS, +} from '../config/config.js'; import { ToolErrorType } from './tool-error.js'; // Subset of 'Path' interface provided by 'glob' that we can implement for testing @@ -72,6 +75,11 @@ export interface GlobToolParams { * Whether to respect .gitignore patterns (optional, defaults to true) */ respect_git_ignore?: boolean; + + /** + * Whether to respect .geminiignore patterns (optional, defaults to true) + */ + respect_gemini_ignore?: boolean; } class GlobToolInvocation extends BaseToolInvocation< @@ -128,14 +136,10 @@ class GlobToolInvocation extends BaseToolInvocation< } // Get centralized file discovery service - const respectGitIgnore = - this.params.respect_git_ignore ?? 
- this.config.getFileFilteringRespectGitIgnore(); const fileDiscovery = this.config.getFileService(); // Collect entries from all search directories - let allEntries: GlobPath[] = []; - + const allEntries: GlobPath[] = []; for (const searchDir of searchDirectories) { let pattern = this.params.pattern; const fullPath = path.join(searchDir, pattern); @@ -155,33 +159,32 @@ class GlobToolInvocation extends BaseToolInvocation< signal, })) as GlobPath[]; - allEntries = allEntries.concat(entries); + allEntries.push(...entries); } - const entries = allEntries; - - // Apply git-aware filtering if enabled and in git repository - let filteredEntries = entries; - let gitIgnoredCount = 0; + const relativePaths = allEntries.map((p) => + path.relative(this.config.getTargetDir(), p.fullpath()), + ); - if (respectGitIgnore) { - const relativePaths = entries.map((p) => - path.relative(this.config.getTargetDir(), p.fullpath()), - ); - const filteredRelativePaths = fileDiscovery.filterFiles(relativePaths, { - respectGitIgnore, + const { filteredPaths, gitIgnoredCount, geminiIgnoredCount } = + fileDiscovery.filterFilesWithReport(relativePaths, { + respectGitIgnore: + this.params?.respect_git_ignore ?? + this.config.getFileFilteringOptions().respectGitIgnore ?? + DEFAULT_FILE_FILTERING_OPTIONS.respectGitIgnore, + respectGeminiIgnore: + this.params?.respect_gemini_ignore ?? + this.config.getFileFilteringOptions().respectGeminiIgnore ?? 
+ DEFAULT_FILE_FILTERING_OPTIONS.respectGeminiIgnore, }); - const filteredAbsolutePaths = new Set( - filteredRelativePaths.map((p) => - path.resolve(this.config.getTargetDir(), p), - ), - ); - filteredEntries = entries.filter((entry) => - filteredAbsolutePaths.has(entry.fullpath()), - ); - gitIgnoredCount = entries.length - filteredEntries.length; - } + const filteredAbsolutePaths = new Set( + filteredPaths.map((p) => path.resolve(this.config.getTargetDir(), p)), + ); + + const filteredEntries = allEntries.filter((entry) => + filteredAbsolutePaths.has(entry.fullpath()), + ); if (!filteredEntries || filteredEntries.length === 0) { let message = `No files found matching pattern "${this.params.pattern}"`; @@ -193,6 +196,9 @@ class GlobToolInvocation extends BaseToolInvocation< if (gitIgnoredCount > 0) { message += ` (${gitIgnoredCount} files were git-ignored)`; } + if (geminiIgnoredCount > 0) { + message += ` (${geminiIgnoredCount} files were gemini-ignored)`; + } return { llmContent: message, returnDisplay: `No files found`, @@ -225,6 +231,9 @@ class GlobToolInvocation extends BaseToolInvocation< if (gitIgnoredCount > 0) { resultMessage += ` (${gitIgnoredCount} additional files were git-ignored)`; } + if (geminiIgnoredCount > 0) { + resultMessage += ` (${geminiIgnoredCount} additional files were gemini-ignored)`; + } resultMessage += `, sorted by modification time (newest first):\n${fileListDescription}`; return { @@ -282,6 +291,11 @@ export class GlobTool extends BaseDeclarativeTool { 'Optional: Whether to respect .gitignore patterns when finding files. Only available in git repositories. Defaults to true.', type: 'boolean', }, + respect_gemini_ignore: { + description: + 'Optional: Whether to respect .geminiignore patterns when finding files. 
Defaults to true.', + type: 'boolean', + }, }, required: ['pattern'], type: 'object', diff --git a/packages/core/src/tools/ls.test.ts b/packages/core/src/tools/ls.test.ts index 9b2797b548f..f48f8cde238 100644 --- a/packages/core/src/tools/ls.test.ts +++ b/packages/core/src/tools/ls.test.ts @@ -4,65 +4,38 @@ * SPDX-License-Identifier: Apache-2.0 */ -/* eslint-disable @typescript-eslint/no-explicit-any */ - -import { describe, it, expect, beforeEach, vi } from 'vitest'; -import fs from 'node:fs'; +import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest'; +import fs from 'node:fs/promises'; import path from 'node:path'; - -vi.mock('fs', () => ({ - default: { - statSync: vi.fn(), - readdirSync: vi.fn(), - }, - statSync: vi.fn(), - readdirSync: vi.fn(), - mkdirSync: vi.fn(), -})); +import os from 'node:os'; import { LSTool } from './ls.js'; import type { Config } from '../config/config.js'; -import type { WorkspaceContext } from '../utils/workspaceContext.js'; -import type { FileDiscoveryService } from '../services/fileDiscoveryService.js'; +import { FileDiscoveryService } from '../services/fileDiscoveryService.js'; import { ToolErrorType } from './tool-error.js'; +import { createMockWorkspaceContext } from '../test-utils/mockWorkspaceContext.js'; describe('LSTool', () => { let lsTool: LSTool; + let tempRootDir: string; + let tempSecondaryDir: string; let mockConfig: Config; - let mockWorkspaceContext: WorkspaceContext; - let mockFileService: FileDiscoveryService; - const mockPrimaryDir = '/home/user/project'; - const mockSecondaryDir = '/home/user/other-project'; - - beforeEach(() => { - vi.resetAllMocks(); - - // Mock WorkspaceContext - mockWorkspaceContext = { - getDirectories: vi - .fn() - .mockReturnValue([mockPrimaryDir, mockSecondaryDir]), - isPathWithinWorkspace: vi - .fn() - .mockImplementation( - (path) => - path.startsWith(mockPrimaryDir) || - path.startsWith(mockSecondaryDir), - ), - addDirectory: vi.fn(), - } as unknown as WorkspaceContext; 
- - // Mock FileService - mockFileService = { - shouldGitIgnoreFile: vi.fn().mockReturnValue(false), - shouldGeminiIgnoreFile: vi.fn().mockReturnValue(false), - } as unknown as FileDiscoveryService; - - // Mock Config + const abortSignal = new AbortController().signal; + + beforeEach(async () => { + tempRootDir = await fs.mkdtemp(path.join(os.tmpdir(), 'ls-tool-root-')); + tempSecondaryDir = await fs.mkdtemp( + path.join(os.tmpdir(), 'ls-tool-secondary-'), + ); + + const mockWorkspaceContext = createMockWorkspaceContext(tempRootDir, [ + tempSecondaryDir, + ]); + mockConfig = { - getTargetDir: vi.fn().mockReturnValue(mockPrimaryDir), - getWorkspaceContext: vi.fn().mockReturnValue(mockWorkspaceContext), - getFileService: vi.fn().mockReturnValue(mockFileService), - getFileFilteringOptions: vi.fn().mockReturnValue({ + getTargetDir: () => tempRootDir, + getWorkspaceContext: () => mockWorkspaceContext, + getFileService: () => new FileDiscoveryService(tempRootDir), + getFileFilteringOptions: () => ({ respectGitIgnore: true, respectGeminiIgnore: true, }), @@ -71,221 +44,132 @@ describe('LSTool', () => { lsTool = new LSTool(mockConfig); }); + afterEach(async () => { + await fs.rm(tempRootDir, { recursive: true, force: true }); + await fs.rm(tempSecondaryDir, { recursive: true, force: true }); + }); + describe('parameter validation', () => { - it('should accept valid absolute paths within workspace', () => { - const params = { - path: '/home/user/project/src', - }; - vi.mocked(fs.statSync).mockReturnValue({ - isDirectory: () => true, - } as fs.Stats); - const invocation = lsTool.build(params); + it('should accept valid absolute paths within workspace', async () => { + const testPath = path.join(tempRootDir, 'src'); + await fs.mkdir(testPath); + + const invocation = lsTool.build({ path: testPath }); + expect(invocation).toBeDefined(); }); it('should reject relative paths', () => { - const params = { - path: './src', - }; - - expect(() => lsTool.build(params)).toThrow( + 
expect(() => lsTool.build({ path: './src' })).toThrow( 'Path must be absolute: ./src', ); }); it('should reject paths outside workspace with clear error message', () => { - const params = { - path: '/etc/passwd', - }; - - expect(() => lsTool.build(params)).toThrow( - 'Path must be within one of the workspace directories: /home/user/project, /home/user/other-project', + expect(() => lsTool.build({ path: '/etc/passwd' })).toThrow( + `Path must be within one of the workspace directories: ${tempRootDir}, ${tempSecondaryDir}`, ); }); - it('should accept paths in secondary workspace directory', () => { - const params = { - path: '/home/user/other-project/lib', - }; - vi.mocked(fs.statSync).mockReturnValue({ - isDirectory: () => true, - } as fs.Stats); - const invocation = lsTool.build(params); + it('should accept paths in secondary workspace directory', async () => { + const testPath = path.join(tempSecondaryDir, 'lib'); + await fs.mkdir(testPath); + + const invocation = lsTool.build({ path: testPath }); + expect(invocation).toBeDefined(); }); }); describe('execute', () => { it('should list files in a directory', async () => { - const testPath = '/home/user/project/src'; - const mockFiles = ['file1.ts', 'file2.ts', 'subdir']; - const mockStats = { - isDirectory: vi.fn(), - mtime: new Date(), - size: 1024, - }; - - vi.mocked(fs.statSync).mockImplementation((path: any) => { - const pathStr = path.toString(); - if (pathStr === testPath) { - return { isDirectory: () => true } as fs.Stats; - } - // For individual files - if (pathStr.toString().endsWith('subdir')) { - return { ...mockStats, isDirectory: () => true, size: 0 } as fs.Stats; - } - return { ...mockStats, isDirectory: () => false } as fs.Stats; - }); - - vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any); + await fs.writeFile(path.join(tempRootDir, 'file1.txt'), 'content1'); + await fs.mkdir(path.join(tempRootDir, 'subdir')); + await fs.writeFile( + path.join(tempSecondaryDir, 'secondary-file.txt'), + 
'secondary', + ); - const invocation = lsTool.build({ path: testPath }); - const result = await invocation.execute(new AbortController().signal); + const invocation = lsTool.build({ path: tempRootDir }); + const result = await invocation.execute(abortSignal); expect(result.llmContent).toContain('[DIR] subdir'); - expect(result.llmContent).toContain('file1.ts'); - expect(result.llmContent).toContain('file2.ts'); - expect(result.returnDisplay).toBe('Listed 3 item(s).'); + expect(result.llmContent).toContain('file1.txt'); + expect(result.returnDisplay).toBe('Listed 2 item(s).'); }); it('should list files from secondary workspace directory', async () => { - const testPath = '/home/user/other-project/lib'; - const mockFiles = ['module1.js', 'module2.js']; - - vi.mocked(fs.statSync).mockImplementation((path: any) => { - if (path.toString() === testPath) { - return { isDirectory: () => true } as fs.Stats; - } - return { - isDirectory: () => false, - mtime: new Date(), - size: 2048, - } as fs.Stats; - }); + await fs.writeFile(path.join(tempRootDir, 'file1.txt'), 'content1'); + await fs.mkdir(path.join(tempRootDir, 'subdir')); + await fs.writeFile( + path.join(tempSecondaryDir, 'secondary-file.txt'), + 'secondary', + ); - vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any); + const invocation = lsTool.build({ path: tempSecondaryDir }); + const result = await invocation.execute(abortSignal); - const invocation = lsTool.build({ path: testPath }); - const result = await invocation.execute(new AbortController().signal); - - expect(result.llmContent).toContain('module1.js'); - expect(result.llmContent).toContain('module2.js'); - expect(result.returnDisplay).toBe('Listed 2 item(s).'); + expect(result.llmContent).toContain('secondary-file.txt'); + expect(result.returnDisplay).toBe('Listed 1 item(s).'); }); it('should handle empty directories', async () => { - const testPath = '/home/user/project/empty'; + const emptyDir = path.join(tempRootDir, 'empty'); + await 
fs.mkdir(emptyDir); + const invocation = lsTool.build({ path: emptyDir }); + const result = await invocation.execute(abortSignal); - vi.mocked(fs.statSync).mockReturnValue({ - isDirectory: () => true, - } as fs.Stats); - vi.mocked(fs.readdirSync).mockReturnValue([]); - - const invocation = lsTool.build({ path: testPath }); - const result = await invocation.execute(new AbortController().signal); - - expect(result.llmContent).toBe( - 'Directory /home/user/project/empty is empty.', - ); + expect(result.llmContent).toBe(`Directory ${emptyDir} is empty.`); expect(result.returnDisplay).toBe('Directory is empty.'); }); it('should respect ignore patterns', async () => { - const testPath = '/home/user/project/src'; - const mockFiles = ['test.js', 'test.spec.js', 'index.js']; - - vi.mocked(fs.statSync).mockImplementation((path: any) => { - const pathStr = path.toString(); - if (pathStr === testPath) { - return { isDirectory: () => true } as fs.Stats; - } - return { - isDirectory: () => false, - mtime: new Date(), - size: 1024, - } as fs.Stats; - }); - vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any); + await fs.writeFile(path.join(tempRootDir, 'file1.txt'), 'content1'); + await fs.writeFile(path.join(tempRootDir, 'file2.log'), 'content1'); const invocation = lsTool.build({ - path: testPath, - ignore: ['*.spec.js'], + path: tempRootDir, + ignore: ['*.log'], }); - const result = await invocation.execute(new AbortController().signal); + const result = await invocation.execute(abortSignal); - expect(result.llmContent).toContain('test.js'); - expect(result.llmContent).toContain('index.js'); - expect(result.llmContent).not.toContain('test.spec.js'); - expect(result.returnDisplay).toBe('Listed 2 item(s).'); + expect(result.llmContent).toContain('file1.txt'); + expect(result.llmContent).not.toContain('file2.log'); + expect(result.returnDisplay).toBe('Listed 1 item(s).'); }); it('should respect gitignore patterns', async () => { - const testPath = '/home/user/project/src'; 
- const mockFiles = ['file1.js', 'file2.js', 'ignored.js']; - - vi.mocked(fs.statSync).mockImplementation((path: any) => { - const pathStr = path.toString(); - if (pathStr === testPath) { - return { isDirectory: () => true } as fs.Stats; - } - return { - isDirectory: () => false, - mtime: new Date(), - size: 1024, - } as fs.Stats; - }); - vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any); - (mockFileService.shouldGitIgnoreFile as any).mockImplementation( - (path: string) => path.includes('ignored.js'), - ); - - const invocation = lsTool.build({ path: testPath }); - const result = await invocation.execute(new AbortController().signal); - - expect(result.llmContent).toContain('file1.js'); - expect(result.llmContent).toContain('file2.js'); - expect(result.llmContent).not.toContain('ignored.js'); - expect(result.returnDisplay).toBe('Listed 2 item(s). (1 git-ignored)'); + await fs.writeFile(path.join(tempRootDir, 'file1.txt'), 'content1'); + await fs.writeFile(path.join(tempRootDir, 'file2.log'), 'content1'); + await fs.writeFile(path.join(tempRootDir, '.git'), ''); + await fs.writeFile(path.join(tempRootDir, '.gitignore'), '*.log'); + const invocation = lsTool.build({ path: tempRootDir }); + const result = await invocation.execute(abortSignal); + + expect(result.llmContent).toContain('file1.txt'); + expect(result.llmContent).not.toContain('file2.log'); + // .git is always ignored by default. + expect(result.returnDisplay).toBe('Listed 2 item(s). 
(2 git-ignored)'); }); it('should respect geminiignore patterns', async () => { - const testPath = '/home/user/project/src'; - const mockFiles = ['file1.js', 'file2.js', 'private.js']; - - vi.mocked(fs.statSync).mockImplementation((path: any) => { - const pathStr = path.toString(); - if (pathStr === testPath) { - return { isDirectory: () => true } as fs.Stats; - } - return { - isDirectory: () => false, - mtime: new Date(), - size: 1024, - } as fs.Stats; - }); - vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any); - (mockFileService.shouldGeminiIgnoreFile as any).mockImplementation( - (path: string) => path.includes('private.js'), - ); - - const invocation = lsTool.build({ path: testPath }); - const result = await invocation.execute(new AbortController().signal); - - expect(result.llmContent).toContain('file1.js'); - expect(result.llmContent).toContain('file2.js'); - expect(result.llmContent).not.toContain('private.js'); + await fs.writeFile(path.join(tempRootDir, 'file1.txt'), 'content1'); + await fs.writeFile(path.join(tempRootDir, 'file2.log'), 'content1'); + await fs.writeFile(path.join(tempRootDir, '.geminiignore'), '*.log'); + const invocation = lsTool.build({ path: tempRootDir }); + const result = await invocation.execute(abortSignal); + + expect(result.llmContent).toContain('file1.txt'); + expect(result.llmContent).not.toContain('file2.log'); expect(result.returnDisplay).toBe('Listed 2 item(s). 
(1 gemini-ignored)'); }); it('should handle non-directory paths', async () => { - const testPath = '/home/user/project/file.txt'; - - vi.mocked(fs.statSync).mockReturnValue({ - isDirectory: () => false, - } as fs.Stats); + const testPath = path.join(tempRootDir, 'file1.txt'); + await fs.writeFile(testPath, 'content1'); const invocation = lsTool.build({ path: testPath }); - const result = await invocation.execute(new AbortController().signal); + const result = await invocation.execute(abortSignal); expect(result.llmContent).toContain('Path is not a directory'); expect(result.returnDisplay).toBe('Error: Path is not a directory.'); @@ -293,14 +177,9 @@ describe('LSTool', () => { }); it('should handle non-existent paths', async () => { - const testPath = '/home/user/project/does-not-exist'; - - vi.mocked(fs.statSync).mockImplementation(() => { - throw new Error('ENOENT: no such file or directory'); - }); - + const testPath = path.join(tempRootDir, 'does-not-exist'); const invocation = lsTool.build({ path: testPath }); - const result = await invocation.execute(new AbortController().signal); + const result = await invocation.execute(abortSignal); expect(result.llmContent).toContain('Error listing directory'); expect(result.returnDisplay).toBe('Error: Failed to list directory.'); @@ -308,54 +187,38 @@ describe('LSTool', () => { }); it('should sort directories first, then files alphabetically', async () => { - const testPath = '/home/user/project/src'; - const mockFiles = ['z-file.ts', 'a-dir', 'b-file.ts', 'c-dir']; - - vi.mocked(fs.statSync).mockImplementation((path: any) => { - if (path.toString() === testPath) { - return { isDirectory: () => true } as fs.Stats; - } - if (path.toString().endsWith('-dir')) { - return { - isDirectory: () => true, - mtime: new Date(), - size: 0, - } as fs.Stats; - } - return { - isDirectory: () => false, - mtime: new Date(), - size: 1024, - } as fs.Stats; - }); + await fs.writeFile(path.join(tempRootDir, 'a-file.txt'), 'content1'); + await 
fs.writeFile(path.join(tempRootDir, 'b-file.txt'), 'content1'); + await fs.mkdir(path.join(tempRootDir, 'x-dir')); + await fs.mkdir(path.join(tempRootDir, 'y-dir')); - vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any); - - const invocation = lsTool.build({ path: testPath }); - const result = await invocation.execute(new AbortController().signal); + const invocation = lsTool.build({ path: tempRootDir }); + const result = await invocation.execute(abortSignal); const lines = ( typeof result.llmContent === 'string' ? result.llmContent : '' - ).split('\n'); - const entries = lines.slice(1).filter((line: string) => line.trim()); // Skip header - expect(entries[0]).toBe('[DIR] a-dir'); - expect(entries[1]).toBe('[DIR] c-dir'); - expect(entries[2]).toBe('b-file.ts'); - expect(entries[3]).toBe('z-file.ts'); + ) + .split('\n') + .filter(Boolean); + const entries = lines.slice(1); // Skip header + + expect(entries[0]).toBe('[DIR] x-dir'); + expect(entries[1]).toBe('[DIR] y-dir'); + expect(entries[2]).toBe('a-file.txt'); + expect(entries[3]).toBe('b-file.txt'); }); it('should handle permission errors gracefully', async () => { - const testPath = '/home/user/project/restricted'; + const restrictedDir = path.join(tempRootDir, 'restricted'); + await fs.mkdir(restrictedDir); - vi.mocked(fs.statSync).mockReturnValue({ - isDirectory: () => true, - } as fs.Stats); - vi.mocked(fs.readdirSync).mockImplementation(() => { - throw new Error('EACCES: permission denied'); - }); + // To simulate a permission error in a cross-platform way, + // we mock fs.readdir to throw an error. 
+ const error = new Error('EACCES: permission denied'); + vi.spyOn(fs, 'readdir').mockRejectedValueOnce(error); - const invocation = lsTool.build({ path: testPath }); - const result = await invocation.execute(new AbortController().signal); + const invocation = lsTool.build({ path: restrictedDir }); + const result = await invocation.execute(abortSignal); expect(result.llmContent).toContain('Error listing directory'); expect(result.llmContent).toContain('permission denied'); @@ -363,62 +226,57 @@ describe('LSTool', () => { expect(result.error?.type).toBe(ToolErrorType.LS_EXECUTION_ERROR); }); - it('should throw for invalid params at build time', async () => { + it('should throw for invalid params at build time', () => { expect(() => lsTool.build({ path: '../outside' })).toThrow( 'Path must be absolute: ../outside', ); }); it('should handle errors accessing individual files during listing', async () => { - const testPath = '/home/user/project/src'; - const mockFiles = ['accessible.ts', 'inaccessible.ts']; - - vi.mocked(fs.statSync).mockImplementation((path: any) => { - if (path.toString() === testPath) { - return { isDirectory: () => true } as fs.Stats; + await fs.writeFile(path.join(tempRootDir, 'file1.txt'), 'content1'); + const problematicFile = path.join(tempRootDir, 'problematic.txt'); + await fs.writeFile(problematicFile, 'content2'); + + // To simulate an error on a single file in a cross-platform way, + // we mock fs.stat to throw for a specific file. This avoids + // platform-specific behavior with things like dangling symlinks. 
+ const originalStat = fs.stat; + const statSpy = vi.spyOn(fs, 'stat').mockImplementation(async (p) => { + if (p.toString() === problematicFile) { + throw new Error('Simulated stat error'); } - if (path.toString().endsWith('inaccessible.ts')) { - throw new Error('EACCES: permission denied'); - } - return { - isDirectory: () => false, - mtime: new Date(), - size: 1024, - } as fs.Stats; + return originalStat(p); }); - vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any); - // Spy on console.error to verify it's called const consoleErrorSpy = vi .spyOn(console, 'error') .mockImplementation(() => {}); - const invocation = lsTool.build({ path: testPath }); - const result = await invocation.execute(new AbortController().signal); + const invocation = lsTool.build({ path: tempRootDir }); + const result = await invocation.execute(abortSignal); - // Should still list the accessible file - expect(result.llmContent).toContain('accessible.ts'); - expect(result.llmContent).not.toContain('inaccessible.ts'); + // Should still list the other files + expect(result.llmContent).toContain('file1.txt'); + expect(result.llmContent).not.toContain('problematic.txt'); expect(result.returnDisplay).toBe('Listed 1 item(s).'); // Verify error was logged expect(consoleErrorSpy).toHaveBeenCalledWith( - expect.stringContaining('Error accessing'), + expect.stringMatching(/Error accessing.*problematic\.txt/s), ); + statSpy.mockRestore(); consoleErrorSpy.mockRestore(); }); }); describe('getDescription', () => { it('should return shortened relative path', () => { + const deeplyNestedDir = path.join(tempRootDir, 'deeply', 'nested'); const params = { - path: `${mockPrimaryDir}/deeply/nested/directory`, + path: path.join(deeplyNestedDir, 'directory'), }; - vi.mocked(fs.statSync).mockReturnValue({ - isDirectory: () => true, - } as fs.Stats); const invocation = lsTool.build(params); const description = invocation.getDescription(); expect(description).toBe(path.join('deeply', 'nested', 'directory')); 
@@ -426,31 +284,27 @@ describe('LSTool', () => { it('should handle paths in secondary workspace', () => { const params = { - path: `${mockSecondaryDir}/lib`, + path: path.join(tempSecondaryDir, 'lib'), }; - vi.mocked(fs.statSync).mockReturnValue({ - isDirectory: () => true, - } as fs.Stats); const invocation = lsTool.build(params); const description = invocation.getDescription(); - expect(description).toBe(path.join('..', 'other-project', 'lib')); + const expected = path.relative(tempRootDir, params.path); + expect(description).toBe(expected); }); }); describe('workspace boundary validation', () => { - it('should accept paths in primary workspace directory', () => { - const params = { path: `${mockPrimaryDir}/src` }; - vi.mocked(fs.statSync).mockReturnValue({ - isDirectory: () => true, - } as fs.Stats); + it('should accept paths in primary workspace directory', async () => { + const testPath = path.join(tempRootDir, 'src'); + await fs.mkdir(testPath); + const params = { path: testPath }; expect(lsTool.build(params)).toBeDefined(); }); - it('should accept paths in secondary workspace directory', () => { - const params = { path: `${mockSecondaryDir}/lib` }; - vi.mocked(fs.statSync).mockReturnValue({ - isDirectory: () => true, - } as fs.Stats); + it('should accept paths in secondary workspace directory', async () => { + const testPath = path.join(tempSecondaryDir, 'lib'); + await fs.mkdir(testPath); + const params = { path: testPath }; expect(lsTool.build(params)).toBeDefined(); }); @@ -462,28 +316,16 @@ describe('LSTool', () => { }); it('should list files from secondary workspace directory', async () => { - const testPath = `${mockSecondaryDir}/tests`; - const mockFiles = ['test1.spec.ts', 'test2.spec.ts']; - - vi.mocked(fs.statSync).mockImplementation((path: any) => { - if (path.toString() === testPath) { - return { isDirectory: () => true } as fs.Stats; - } - return { - isDirectory: () => false, - mtime: new Date(), - size: 512, - } as fs.Stats; - }); - - 
vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any); + await fs.writeFile( + path.join(tempSecondaryDir, 'secondary-file.txt'), + 'secondary', + ); - const invocation = lsTool.build({ path: testPath }); - const result = await invocation.execute(new AbortController().signal); + const invocation = lsTool.build({ path: tempSecondaryDir }); + const result = await invocation.execute(abortSignal); - expect(result.llmContent).toContain('test1.spec.ts'); - expect(result.llmContent).toContain('test2.spec.ts'); - expect(result.returnDisplay).toBe('Listed 2 item(s).'); + expect(result.llmContent).toContain('secondary-file.txt'); + expect(result.returnDisplay).toBe('Listed 1 item(s).'); }); }); }); diff --git a/packages/core/src/tools/ls.ts b/packages/core/src/tools/ls.ts index 4a597306121..09c26c796e7 100644 --- a/packages/core/src/tools/ls.ts +++ b/packages/core/src/tools/ls.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import fs from 'node:fs'; +import fs from 'node:fs/promises'; import path from 'node:path'; import type { ToolInvocation, ToolResult } from './tools.js'; import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js'; @@ -133,7 +133,7 @@ class LSToolInvocation extends BaseToolInvocation { */ async execute(_signal: AbortSignal): Promise { try { - const stats = fs.statSync(this.params.path); + const stats = await fs.stat(this.params.path); if (!stats) { // fs.statSync throws on non-existence, so this check might be redundant // but keeping for clarity. Error message adjusted. @@ -151,28 +151,7 @@ class LSToolInvocation extends BaseToolInvocation { ); } - const files = fs.readdirSync(this.params.path); - - const defaultFileIgnores = - this.config.getFileFilteringOptions() ?? DEFAULT_FILE_FILTERING_OPTIONS; - - const fileFilteringOptions = { - respectGitIgnore: - this.params.file_filtering_options?.respect_git_ignore ?? 
- defaultFileIgnores.respectGitIgnore, - respectGeminiIgnore: - this.params.file_filtering_options?.respect_gemini_ignore ?? - defaultFileIgnores.respectGeminiIgnore, - }; - - // Get centralized file discovery service - - const fileDiscovery = this.config.getFileService(); - - const entries: FileEntry[] = []; - let gitIgnoredCount = 0; - let geminiIgnoredCount = 0; - + const files = await fs.readdir(this.params.path); if (files.length === 0) { // Changed error message to be more neutral for LLM return { @@ -181,38 +160,39 @@ class LSToolInvocation extends BaseToolInvocation { }; } - for (const file of files) { - if (this.shouldIgnore(file, this.params.ignore)) { - continue; - } - - const fullPath = path.join(this.params.path, file); - const relativePath = path.relative( + const relativePaths = files.map((file) => + path.relative( this.config.getTargetDir(), - fullPath, - ); + path.join(this.params.path, file), + ), + ); - // Check if this file should be ignored based on git or gemini ignore rules - if ( - fileFilteringOptions.respectGitIgnore && - fileDiscovery.shouldGitIgnoreFile(relativePath) - ) { - gitIgnoredCount++; - continue; - } - if ( - fileFilteringOptions.respectGeminiIgnore && - fileDiscovery.shouldGeminiIgnoreFile(relativePath) - ) { - geminiIgnoredCount++; + const fileDiscovery = this.config.getFileService(); + const { filteredPaths, gitIgnoredCount, geminiIgnoredCount } = + fileDiscovery.filterFilesWithReport(relativePaths, { + respectGitIgnore: + this.params.file_filtering_options?.respect_git_ignore ?? + this.config.getFileFilteringOptions().respectGitIgnore ?? + DEFAULT_FILE_FILTERING_OPTIONS.respectGitIgnore, + respectGeminiIgnore: + this.params.file_filtering_options?.respect_gemini_ignore ?? + this.config.getFileFilteringOptions().respectGeminiIgnore ?? 
+ DEFAULT_FILE_FILTERING_OPTIONS.respectGeminiIgnore, + }); + + const entries = []; + for (const relativePath of filteredPaths) { + const fullPath = path.resolve(this.config.getTargetDir(), relativePath); + + if (this.shouldIgnore(path.basename(fullPath), this.params.ignore)) { continue; } try { - const stats = fs.statSync(fullPath); + const stats = await fs.stat(fullPath); const isDir = stats.isDirectory(); entries.push({ - name: file, + name: path.basename(fullPath), path: fullPath, isDirectory: isDir, size: isDir ? 0 : stats.size, @@ -244,7 +224,6 @@ class LSToolInvocation extends BaseToolInvocation { if (geminiIgnoredCount > 0) { ignoredMessages.push(`${geminiIgnoredCount} gemini-ignored`); } - if (ignoredMessages.length > 0) { resultMessage += `\n\n(${ignoredMessages.join(', ')})`; } diff --git a/packages/core/src/tools/mcp-client-manager.ts b/packages/core/src/tools/mcp-client-manager.ts index dcb4589106a..93e25ea8b27 100644 --- a/packages/core/src/tools/mcp-client-manager.ts +++ b/packages/core/src/tools/mcp-client-manager.ts @@ -66,22 +66,11 @@ export class McpClientManager { this.mcpServerCommand, ); - const serverEntries = Object.entries(servers); - const total = serverEntries.length; - - this.eventEmitter?.emit('mcp-servers-discovery-start', { count: total }); - this.discoveryState = MCPDiscoveryState.IN_PROGRESS; - const discoveryPromises = serverEntries.map( - async ([name, config], index) => { - const current = index + 1; - this.eventEmitter?.emit('mcp-server-connecting', { - name, - current, - total, - }); - + this.eventEmitter?.emit('mcp-client-update', this.clients); + const discoveryPromises = Object.entries(servers).map( + async ([name, config]) => { const client = new McpClient( name, config, @@ -92,21 +81,13 @@ export class McpClientManager { ); this.clients.set(name, client); + this.eventEmitter?.emit('mcp-client-update', this.clients); try { await client.connect(); await client.discover(cliConfig); - 
this.eventEmitter?.emit('mcp-server-connected', { - name, - current, - total, - }); + this.eventEmitter?.emit('mcp-client-update', this.clients); } catch (error) { - this.eventEmitter?.emit('mcp-server-error', { - name, - current, - total, - error, - }); + this.eventEmitter?.emit('mcp-client-update', this.clients); // Log the error but don't let a single failed server stop the others console.error( `Error during discovery for server '${name}': ${getErrorMessage( diff --git a/packages/core/src/tools/read-many-files.ts b/packages/core/src/tools/read-many-files.ts index 538e0d90568..06944137625 100644 --- a/packages/core/src/tools/read-many-files.ts +++ b/packages/core/src/tools/read-many-files.ts @@ -18,8 +18,10 @@ import { getSpecificMimeType, } from '../utils/fileUtils.js'; import type { PartListUnion } from '@google/genai'; -import type { Config } from '../config/config.js'; -import { DEFAULT_FILE_FILTERING_OPTIONS } from '../config/config.js'; +import { + type Config, + DEFAULT_FILE_FILTERING_OPTIONS, +} from '../config/config.js'; import { FileOperation } from '../telemetry/metrics.js'; import { getProgrammingLanguage } from '../telemetry/telemetry-utils.js'; import { logFileOperation } from '../telemetry/loggers.js'; @@ -173,20 +175,6 @@ ${finalExclusionPatternsForDescription useDefaultExcludes = true, } = this.params; - const defaultFileIgnores = - this.config.getFileFilteringOptions() ?? DEFAULT_FILE_FILTERING_OPTIONS; - - const fileFilteringOptions = { - respectGitIgnore: - this.params.file_filtering_options?.respect_git_ignore ?? - defaultFileIgnores.respectGitIgnore, // Use the property from the returned object - respectGeminiIgnore: - this.params.file_filtering_options?.respect_gemini_ignore ?? 
- defaultFileIgnores.respectGeminiIgnore, // Use the property from the returned object - }; - // Get centralized file discovery service - const fileDiscovery = this.config.getFileService(); - const filesToConsider = new Set(); const skippedFiles: Array<{ path: string; reason: string }> = []; const processedFilesRelativePaths: string[] = []; @@ -227,71 +215,37 @@ ${finalExclusionPatternsForDescription allEntries.add(entry); } } - const entries = Array.from(allEntries); - - const gitFilteredEntries = fileFilteringOptions.respectGitIgnore - ? fileDiscovery - .filterFiles( - entries.map((p) => path.relative(this.config.getTargetDir(), p)), - { - respectGitIgnore: true, - respectGeminiIgnore: false, - }, - ) - .map((p) => path.resolve(this.config.getTargetDir(), p)) - : entries; - - // Apply gemini ignore filtering if enabled - const finalFilteredEntries = fileFilteringOptions.respectGeminiIgnore - ? fileDiscovery - .filterFiles( - gitFilteredEntries.map((p) => - path.relative(this.config.getTargetDir(), p), - ), - { - respectGitIgnore: false, - respectGeminiIgnore: true, - }, - ) - .map((p) => path.resolve(this.config.getTargetDir(), p)) - : gitFilteredEntries; - - let gitIgnoredCount = 0; - let geminiIgnoredCount = 0; - - for (const absoluteFilePath of entries) { + const relativeEntries = Array.from(allEntries).map((p) => + path.relative(this.config.getTargetDir(), p), + ); + + const fileDiscovery = this.config.getFileService(); + const { filteredPaths, gitIgnoredCount, geminiIgnoredCount } = + fileDiscovery.filterFilesWithReport(relativeEntries, { + respectGitIgnore: + this.params.file_filtering_options?.respect_git_ignore ?? + this.config.getFileFilteringOptions().respectGitIgnore ?? + DEFAULT_FILE_FILTERING_OPTIONS.respectGitIgnore, + respectGeminiIgnore: + this.params.file_filtering_options?.respect_gemini_ignore ?? + this.config.getFileFilteringOptions().respectGeminiIgnore ?? 
+ DEFAULT_FILE_FILTERING_OPTIONS.respectGeminiIgnore, + }); + + for (const relativePath of filteredPaths) { // Security check: ensure the glob library didn't return something outside the workspace. + + const fullPath = path.resolve(this.config.getTargetDir(), relativePath); if ( - !this.config - .getWorkspaceContext() - .isPathWithinWorkspace(absoluteFilePath) + !this.config.getWorkspaceContext().isPathWithinWorkspace(fullPath) ) { skippedFiles.push({ - path: absoluteFilePath, - reason: `Security: Glob library returned path outside workspace. Path: ${absoluteFilePath}`, + path: fullPath, + reason: `Security: Glob library returned path outside workspace. Path: ${fullPath}`, }); continue; } - - // Check if this file was filtered out by git ignore - if ( - fileFilteringOptions.respectGitIgnore && - !gitFilteredEntries.includes(absoluteFilePath) - ) { - gitIgnoredCount++; - continue; - } - - // Check if this file was filtered out by gemini ignore - if ( - fileFilteringOptions.respectGeminiIgnore && - !finalFilteredEntries.includes(absoluteFilePath) - ) { - geminiIgnoredCount++; - continue; - } - - filesToConsider.add(absoluteFilePath); + filesToConsider.add(fullPath); } // Add info about git-ignored files if any were filtered diff --git a/packages/core/src/tools/ripGrep.test.ts b/packages/core/src/tools/ripGrep.test.ts index 06cc4ccce51..7c47275b497 100644 --- a/packages/core/src/tools/ripGrep.test.ts +++ b/packages/core/src/tools/ripGrep.test.ts @@ -4,9 +4,17 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest'; +import { + describe, + it, + expect, + beforeEach, + afterEach, + vi, + type Mock, +} from 'vitest'; import type { RipGrepToolParams } from './ripGrep.js'; -import { RipGrepTool } from './ripGrep.js'; +import { canUseRipgrep, RipGrepTool, ensureRgPath } from './ripGrep.js'; import path from 'node:path'; import fs from 'node:fs/promises'; import os, { EOL } from 'node:os'; @@ -14,10 +22,24 @@ 
import type { Config } from '../config/config.js'; import { createMockWorkspaceContext } from '../test-utils/mockWorkspaceContext.js'; import type { ChildProcess } from 'node:child_process'; import { spawn } from 'node:child_process'; +import { downloadRipGrep } from '@joshua.litt/get-ripgrep'; +import { fileExists } from '../utils/fileUtils.js'; -// Mock @lvce-editor/ripgrep for testing -vi.mock('@lvce-editor/ripgrep', () => ({ - rgPath: '/mock/rg/path', +// Mock dependencies for canUseRipgrep +vi.mock('@joshua.litt/get-ripgrep', () => ({ + downloadRipGrep: vi.fn(), +})); +vi.mock('../utils/fileUtils.js', async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + fileExists: vi.fn(), + }; +}); +vi.mock('../config/storage.js', () => ({ + Storage: { + getGlobalBinDir: vi.fn().mockReturnValue('/mock/bin/dir'), + }, })); // Mock child_process for ripgrep calls @@ -27,6 +49,97 @@ vi.mock('child_process', () => ({ const mockSpawn = vi.mocked(spawn); +describe('canUseRipgrep', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('should return true if ripgrep already exists', async () => { + (fileExists as Mock).mockResolvedValue(true); + const result = await canUseRipgrep(); + expect(result).toBe(true); + expect(fileExists).toHaveBeenCalledWith(path.join('/mock/bin/dir', 'rg')); + expect(downloadRipGrep).not.toHaveBeenCalled(); + }); + + it('should download ripgrep and return true if it does not exist initially', async () => { + (fileExists as Mock) + .mockResolvedValueOnce(false) + .mockResolvedValueOnce(true); + (downloadRipGrep as Mock).mockResolvedValue(undefined); + + const result = await canUseRipgrep(); + + expect(result).toBe(true); + expect(fileExists).toHaveBeenCalledTimes(2); + expect(downloadRipGrep).toHaveBeenCalledWith('/mock/bin/dir'); + }); + + it('should return false if download fails and file does not exist', async () => { + (fileExists as Mock).mockResolvedValue(false); + (downloadRipGrep as 
Mock).mockResolvedValue(undefined); + + const result = await canUseRipgrep(); + + expect(result).toBe(false); + expect(fileExists).toHaveBeenCalledTimes(2); + expect(downloadRipGrep).toHaveBeenCalledWith('/mock/bin/dir'); + }); + + it('should propagate errors from downloadRipGrep', async () => { + const error = new Error('Download failed'); + (fileExists as Mock).mockResolvedValue(false); + (downloadRipGrep as Mock).mockRejectedValue(error); + + await expect(canUseRipgrep()).rejects.toThrow(error); + expect(fileExists).toHaveBeenCalledTimes(1); + expect(downloadRipGrep).toHaveBeenCalledWith('/mock/bin/dir'); + }); +}); + +describe('ensureRgPath', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('should return rg path if ripgrep already exists', async () => { + (fileExists as Mock).mockResolvedValue(true); + const rgPath = await ensureRgPath(); + expect(rgPath).toBe(path.join('/mock/bin/dir', 'rg')); + expect(fileExists).toHaveBeenCalledOnce(); + expect(downloadRipGrep).not.toHaveBeenCalled(); + }); + + it('should return rg path if ripgrep is downloaded successfully', async () => { + (fileExists as Mock) + .mockResolvedValueOnce(false) + .mockResolvedValueOnce(true); + (downloadRipGrep as Mock).mockResolvedValue(undefined); + const rgPath = await ensureRgPath(); + expect(rgPath).toBe(path.join('/mock/bin/dir', 'rg')); + expect(downloadRipGrep).toHaveBeenCalledOnce(); + expect(fileExists).toHaveBeenCalledTimes(2); + }); + + it('should throw an error if ripgrep cannot be used after download attempt', async () => { + (fileExists as Mock).mockResolvedValue(false); + (downloadRipGrep as Mock).mockResolvedValue(undefined); + await expect(ensureRgPath()).rejects.toThrow('Cannot use ripgrep.'); + expect(downloadRipGrep).toHaveBeenCalledOnce(); + expect(fileExists).toHaveBeenCalledTimes(2); + }); + + it('should propagate errors from downloadRipGrep', async () => { + const error = new Error('Download failed'); + (fileExists as Mock).mockResolvedValue(false); + 
(downloadRipGrep as Mock).mockRejectedValue(error); + + await expect(ensureRgPath()).rejects.toThrow(error); + expect(fileExists).toHaveBeenCalledTimes(1); + expect(downloadRipGrep).toHaveBeenCalledWith('/mock/bin/dir'); + }); +}); + // Helper function to create mock spawn implementations function createMockSpawn( options: { @@ -88,6 +201,8 @@ describe('RipGrepTool', () => { beforeEach(async () => { vi.clearAllMocks(); + (downloadRipGrep as Mock).mockResolvedValue(undefined); + (fileExists as Mock).mockResolvedValue(true); mockSpawn.mockClear(); tempRootDir = await fs.mkdtemp(path.join(os.tmpdir(), 'grep-tool-root-')); grepTool = new RipGrepTool(mockConfig); @@ -434,6 +549,20 @@ describe('RipGrepTool', () => { /params must have required property 'pattern'/, ); }); + + it('should throw an error if ripgrep is not available', async () => { + // Make ensureRgPath throw + (fileExists as Mock).mockResolvedValue(false); + (downloadRipGrep as Mock).mockResolvedValue(undefined); + + const params: RipGrepToolParams = { pattern: 'world' }; + const invocation = grepTool.build(params); + + expect(await invocation.execute(abortSignal)).toStrictEqual({ + llmContent: 'Error during grep search operation: Cannot use ripgrep.', + returnDisplay: 'Error: Cannot use ripgrep.', + }); + }); }); describe('multi-directory workspace', () => { diff --git a/packages/core/src/tools/ripGrep.ts b/packages/core/src/tools/ripGrep.ts index b851c2cd1e7..269fb379930 100644 --- a/packages/core/src/tools/ripGrep.ts +++ b/packages/core/src/tools/ripGrep.ts @@ -8,16 +8,44 @@ import fs from 'node:fs'; import path from 'node:path'; import { EOL } from 'node:os'; import { spawn } from 'node:child_process'; -import { rgPath } from '@lvce-editor/ripgrep'; +import { downloadRipGrep } from '@joshua.litt/get-ripgrep'; import type { ToolInvocation, ToolResult } from './tools.js'; import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js'; import { SchemaValidator } from 
'../utils/schemaValidator.js'; import { makeRelative, shortenPath } from '../utils/paths.js'; import { getErrorMessage, isNodeError } from '../utils/errors.js'; import type { Config } from '../config/config.js'; +import { fileExists } from '../utils/fileUtils.js'; +import { Storage } from '../config/storage.js'; const DEFAULT_TOTAL_MAX_MATCHES = 20000; +function getRgPath(): string { + return path.join(Storage.getGlobalBinDir(), 'rg'); +} + +/** + * Checks if `rg` exists, if not then attempt to download it. + */ +export async function canUseRipgrep(): Promise { + if (await fileExists(getRgPath())) { + return true; + } + + await downloadRipGrep(Storage.getGlobalBinDir()); + return await fileExists(getRgPath()); +} + +/** + * Ensures `rg` is downloaded, or throws. + */ +export async function ensureRgPath(): Promise { + if (await canUseRipgrep()) { + return getRgPath(); + } + throw new Error('Cannot use ripgrep.'); +} + /** * Parameters for the GrepTool */ @@ -292,6 +320,7 @@ class GrepToolInvocation extends BaseToolInvocation< rgArgs.push(absolutePath); try { + const rgPath = await ensureRgPath(); const output = await new Promise((resolve, reject) => { const child = spawn(rgPath, rgArgs, { windowsHide: true, diff --git a/packages/core/src/tools/shell.test.ts b/packages/core/src/tools/shell.test.ts index 88daa0cdfa0..58df386f4d7 100644 --- a/packages/core/src/tools/shell.test.ts +++ b/packages/core/src/tools/shell.test.ts @@ -155,8 +155,7 @@ describe('ShellTool', () => { expect.any(Function), mockAbortSignal, false, - undefined, - undefined, + {}, ); expect(result.llmContent).toContain('Background PIDs: 54322'); expect(vi.mocked(fs.unlinkSync)).toHaveBeenCalledWith(tmpFile); @@ -183,8 +182,7 @@ describe('ShellTool', () => { expect.any(Function), mockAbortSignal, false, - undefined, - undefined, + {}, ); }); @@ -296,43 +294,6 @@ describe('ShellTool', () => { vi.useRealTimers(); }); - it('should throttle text output updates', async () => { - const invocation = 
shellTool.build({ command: 'stream' }); - const promise = invocation.execute(mockAbortSignal, updateOutputMock); - - // First chunk, should be throttled. - mockShellOutputCallback({ - type: 'data', - chunk: 'hello ', - }); - expect(updateOutputMock).not.toHaveBeenCalled(); - - // Advance time past the throttle interval. - await vi.advanceTimersByTimeAsync(OUTPUT_UPDATE_INTERVAL_MS + 1); - - // Send a second chunk. THIS event triggers the update with the CUMULATIVE content. - mockShellOutputCallback({ - type: 'data', - chunk: 'world', - }); - - // It should have been called once now with the combined output. - expect(updateOutputMock).toHaveBeenCalledOnce(); - expect(updateOutputMock).toHaveBeenCalledWith('hello world'); - - resolveExecutionPromise({ - rawOutput: Buffer.from(''), - output: '', - exitCode: 0, - signal: null, - error: null, - aborted: false, - pid: 12345, - executionMethod: 'child_process', - }); - await promise; - }); - it('should immediately show binary detection message and throttle progress', async () => { const invocation = shellTool.build({ command: 'cat img' }); const promise = invocation.execute(mockAbortSignal, updateOutputMock); diff --git a/packages/core/src/tools/shell.ts b/packages/core/src/tools/shell.ts index 94e4bd85eca..8e3390bab9f 100644 --- a/packages/core/src/tools/shell.ts +++ b/packages/core/src/tools/shell.ts @@ -24,9 +24,13 @@ import { } from './tools.js'; import { getErrorMessage } from '../utils/errors.js'; import { summarizeToolOutput } from '../utils/summarizer.js'; -import type { ShellOutputEvent } from '../services/shellExecutionService.js'; +import type { + ShellExecutionConfig, + ShellOutputEvent, +} from '../services/shellExecutionService.js'; import { ShellExecutionService } from '../services/shellExecutionService.js'; import { formatMemoryUsage } from '../utils/formatters.js'; +import type { AnsiOutput } from '../utils/terminalSerializer.js'; import { getCommandRoots, isCommandAllowed, @@ -41,7 +45,7 @@ export 
interface ShellToolParams { directory?: string; } -class ShellToolInvocation extends BaseToolInvocation< +export class ShellToolInvocation extends BaseToolInvocation< ShellToolParams, ToolResult > { @@ -96,9 +100,9 @@ class ShellToolInvocation extends BaseToolInvocation< async execute( signal: AbortSignal, - updateOutput?: (output: string) => void, - terminalColumns?: number, - terminalRows?: number, + updateOutput?: (output: string | AnsiOutput) => void, + shellExecutionConfig?: ShellExecutionConfig, + setPidCallback?: (pid: number) => void, ): Promise { const strippedCommand = stripShellWrapper(this.params.command); @@ -131,63 +135,60 @@ class ShellToolInvocation extends BaseToolInvocation< this.params.directory || '', ); - let cumulativeOutput = ''; - let outputChunks: string[] = [cumulativeOutput]; + let cumulativeOutput: string | AnsiOutput = ''; let lastUpdateTime = Date.now(); let isBinaryStream = false; - const { result: resultPromise } = await ShellExecutionService.execute( - commandToExecute, - cwd, - (event: ShellOutputEvent) => { - if (!updateOutput) { - return; - } + const { result: resultPromise, pid } = + await ShellExecutionService.execute( + commandToExecute, + cwd, + (event: ShellOutputEvent) => { + if (!updateOutput) { + return; + } + + let shouldUpdate = false; - let currentDisplayOutput = ''; - let shouldUpdate = false; - - switch (event.type) { - case 'data': - if (isBinaryStream) break; - outputChunks.push(event.chunk); - if (Date.now() - lastUpdateTime > OUTPUT_UPDATE_INTERVAL_MS) { - cumulativeOutput = outputChunks.join(''); - outputChunks = [cumulativeOutput]; - currentDisplayOutput = cumulativeOutput; + switch (event.type) { + case 'data': + if (isBinaryStream) break; + cumulativeOutput = event.chunk; shouldUpdate = true; - } - break; - case 'binary_detected': - isBinaryStream = true; - currentDisplayOutput = - '[Binary output detected. 
Halting stream...]'; - shouldUpdate = true; - break; - case 'binary_progress': - isBinaryStream = true; - currentDisplayOutput = `[Receiving binary output... ${formatMemoryUsage( - event.bytesReceived, - )} received]`; - if (Date.now() - lastUpdateTime > OUTPUT_UPDATE_INTERVAL_MS) { + break; + case 'binary_detected': + isBinaryStream = true; + cumulativeOutput = + '[Binary output detected. Halting stream...]'; shouldUpdate = true; + break; + case 'binary_progress': + isBinaryStream = true; + cumulativeOutput = `[Receiving binary output... ${formatMemoryUsage( + event.bytesReceived, + )} received]`; + if (Date.now() - lastUpdateTime > OUTPUT_UPDATE_INTERVAL_MS) { + shouldUpdate = true; + } + break; + default: { + throw new Error('An unhandled ShellOutputEvent was found.'); } - break; - default: { - throw new Error('An unhandled ShellOutputEvent was found.'); } - } - if (shouldUpdate) { - updateOutput(currentDisplayOutput); - lastUpdateTime = Date.now(); - } - }, - signal, - this.config.getShouldUseNodePtyShell(), - terminalColumns, - terminalRows, - ); + if (shouldUpdate) { + updateOutput(cumulativeOutput); + lastUpdateTime = Date.now(); + } + }, + signal, + this.config.getShouldUseNodePtyShell(), + shellExecutionConfig ?? 
{}, + ); + + if (pid && setPidCallback) { + setPidCallback(pid); + } const result = await resultPromise; diff --git a/packages/core/src/tools/smart-edit.test.ts b/packages/core/src/tools/smart-edit.test.ts index 132d9933067..9ce42c506e5 100644 --- a/packages/core/src/tools/smart-edit.test.ts +++ b/packages/core/src/tools/smart-edit.test.ts @@ -10,16 +10,12 @@ const mockFixLLMEditWithInstruction = vi.hoisted(() => vi.fn()); const mockGenerateJson = vi.hoisted(() => vi.fn()); const mockOpenDiff = vi.hoisted(() => vi.fn()); -import { IdeClient, IDEConnectionStatus } from '../ide/ide-client.js'; +import { IdeClient } from '../ide/ide-client.js'; vi.mock('../ide/ide-client.js', () => ({ IdeClient: { getInstance: vi.fn(), }, - IDEConnectionStatus: { - Connected: 'connected', - Disconnected: 'disconnected', - }, })); vi.mock('../utils/llm-edit-fixer.js', () => ({ @@ -60,6 +56,7 @@ import { ApprovalMode, type Config } from '../config/config.js'; import { type Content, type Part, type SchemaUnion } from '@google/genai'; import { createMockWorkspaceContext } from '../test-utils/mockWorkspaceContext.js'; import { StandardFileSystemService } from '../services/fileSystemService.js'; +import type { BaseLlmClient } from '../core/baseLlmClient.js'; describe('SmartEditTool', () => { let tool: SmartEditTool; @@ -67,6 +64,7 @@ describe('SmartEditTool', () => { let rootDir: string; let mockConfig: Config; let geminiClient: any; + let baseLlmClient: BaseLlmClient; beforeEach(() => { vi.restoreAllMocks(); @@ -78,8 +76,13 @@ describe('SmartEditTool', () => { generateJson: mockGenerateJson, }; + baseLlmClient = { + generateJson: mockGenerateJson, + } as unknown as BaseLlmClient; + mockConfig = { getGeminiClient: vi.fn().mockReturnValue(geminiClient), + getBaseLlmClient: vi.fn().mockReturnValue(baseLlmClient), getTargetDir: () => rootDir, getApprovalMode: vi.fn(), setApprovalMode: vi.fn(), @@ -454,9 +457,7 @@ describe('SmartEditTool', () => { filePath = path.join(rootDir, testFile); 
ideClient = { openDiff: vi.fn(), - getConnectionStatus: vi.fn().mockReturnValue({ - status: IDEConnectionStatus.Connected, - }), + isDiffingEnabled: vi.fn().mockReturnValue(true), }; vi.mocked(IdeClient.getInstance).mockResolvedValue(ideClient); (mockConfig as any).getIdeMode = () => true; diff --git a/packages/core/src/tools/smart-edit.ts b/packages/core/src/tools/smart-edit.ts index 3647b73246b..a39d906d2e0 100644 --- a/packages/core/src/tools/smart-edit.ts +++ b/packages/core/src/tools/smart-edit.ts @@ -28,7 +28,7 @@ import { type ModifiableDeclarativeTool, type ModifyContext, } from './modifiable-tool.js'; -import { IdeClient, IDEConnectionStatus } from '../ide/ide-client.js'; +import { IdeClient } from '../ide/ide-client.js'; import { FixLLMEditWithInstruction } from '../utils/llm-edit-fixer.js'; export function applyReplacement( @@ -310,7 +310,7 @@ class EditToolInvocation implements ToolInvocation { params.new_string, initialError.raw, currentContent, - this.config.getGeminiClient(), + this.config.getBaseLlmClient(), abortSignal, ); @@ -528,8 +528,7 @@ class EditToolInvocation implements ToolInvocation { ); const ideClient = await IdeClient.getInstance(); const ideConfirmation = - this.config.getIdeMode() && - ideClient?.getConnectionStatus().status === IDEConnectionStatus.Connected + this.config.getIdeMode() && ideClient.isDiffingEnabled() ? 
ideClient.openDiff(this.params.file_path, editData.newContent) : undefined; diff --git a/packages/core/src/tools/tools.ts b/packages/core/src/tools/tools.ts index 21a7f965bbf..6029b9f8d28 100644 --- a/packages/core/src/tools/tools.ts +++ b/packages/core/src/tools/tools.ts @@ -7,7 +7,9 @@ import type { FunctionDeclaration, PartListUnion } from '@google/genai'; import { ToolErrorType } from './tool-error.js'; import type { DiffUpdateResult } from '../ide/ideContext.js'; +import type { ShellExecutionConfig } from '../services/shellExecutionService.js'; import { SchemaValidator } from '../utils/schemaValidator.js'; +import type { AnsiOutput } from '../utils/terminalSerializer.js'; /** * Represents a validated and ready-to-execute tool call. @@ -51,7 +53,8 @@ export interface ToolInvocation< */ execute( signal: AbortSignal, - updateOutput?: (output: string) => void, + updateOutput?: (output: string | AnsiOutput) => void, + shellExecutionConfig?: ShellExecutionConfig, ): Promise; } @@ -79,7 +82,8 @@ export abstract class BaseToolInvocation< abstract execute( signal: AbortSignal, - updateOutput?: (output: string) => void, + updateOutput?: (output: string | AnsiOutput) => void, + shellExecutionConfig?: ShellExecutionConfig, ): Promise; } @@ -197,10 +201,11 @@ export abstract class DeclarativeTool< async buildAndExecute( params: TParams, signal: AbortSignal, - updateOutput?: (output: string) => void, + updateOutput?: (output: string | AnsiOutput) => void, + shellExecutionConfig?: ShellExecutionConfig, ): Promise { const invocation = this.build(params); - return invocation.execute(signal, updateOutput); + return invocation.execute(signal, updateOutput, shellExecutionConfig); } /** @@ -432,7 +437,7 @@ export function hasCycleInSchema(schema: object): boolean { return traverse(schema, new Set(), new Set()); } -export type ToolResultDisplay = string | FileDiff; +export type ToolResultDisplay = string | FileDiff | AnsiOutput; export interface FileDiff { fileDiff: string; diff 
--git a/packages/core/src/tools/write-file.test.ts b/packages/core/src/tools/write-file.test.ts index df9a33ef63f..adb6ebf7511 100644 --- a/packages/core/src/tools/write-file.test.ts +++ b/packages/core/src/tools/write-file.test.ts @@ -33,6 +33,8 @@ import { } from '../utils/editCorrector.js'; import { createMockWorkspaceContext } from '../test-utils/mockWorkspaceContext.js'; import { StandardFileSystemService } from '../services/fileSystemService.js'; +import { IdeClient } from '../ide/ide-client.js'; +import type { DiffUpdateResult } from '../ide/ideContext.js'; const rootDir = path.resolve(os.tmpdir(), 'gemini-cli-test-root'); @@ -47,12 +49,19 @@ vi.mock('../ide/ide-client.js', () => ({ let mockGeminiClientInstance: Mocked; const mockEnsureCorrectEdit = vi.fn(); const mockEnsureCorrectFileContent = vi.fn(); +const mockIdeClient = { + openDiff: vi.fn(), + isDiffingEnabled: vi.fn(), +}; // Wire up the mocked functions to be used by the actual module imports vi.mocked(ensureCorrectEdit).mockImplementation(mockEnsureCorrectEdit); vi.mocked(ensureCorrectFileContent).mockImplementation( mockEnsureCorrectFileContent, ); +vi.mocked(IdeClient.getInstance).mockResolvedValue( + mockIdeClient as unknown as IdeClient, +); // Mock Config const fsService = new StandardFileSystemService(); @@ -437,6 +446,107 @@ describe('WriteFileTool', () => { originalContent.replace(/[.*+?^${}()|[\\]\\]/g, '\\$&'), ); }); + + describe('with IDE integration', () => { + beforeEach(() => { + // Enable IDE mode and set connection status for these tests + mockConfigInternal.getIdeMode.mockReturnValue(true); + mockIdeClient.isDiffingEnabled.mockReturnValue(true); + mockIdeClient.openDiff.mockResolvedValue({ + status: 'accepted', + content: 'ide-modified-content', + }); + }); + + it('should call openDiff and await it when in IDE mode and connected', async () => { + const filePath = path.join(rootDir, 'ide_confirm_file.txt'); + const params = { file_path: filePath, content: 'test' }; + const 
invocation = tool.build(params); + + const confirmation = (await invocation.shouldConfirmExecute( + abortSignal, + )) as ToolEditConfirmationDetails; + + expect(mockIdeClient.openDiff).toHaveBeenCalledWith( + filePath, + 'test', // The corrected content + ); + // Ensure the promise is awaited by checking the result + expect(confirmation.ideConfirmation).toBeDefined(); + await confirmation.ideConfirmation; // Should resolve + }); + + it('should not call openDiff if not in IDE mode', async () => { + mockConfigInternal.getIdeMode.mockReturnValue(false); + const filePath = path.join(rootDir, 'ide_disabled_file.txt'); + const params = { file_path: filePath, content: 'test' }; + const invocation = tool.build(params); + + await invocation.shouldConfirmExecute(abortSignal); + + expect(mockIdeClient.openDiff).not.toHaveBeenCalled(); + }); + + it('should not call openDiff if IDE is not connected', async () => { + mockIdeClient.isDiffingEnabled.mockReturnValue(false); + const filePath = path.join(rootDir, 'ide_disconnected_file.txt'); + const params = { file_path: filePath, content: 'test' }; + const invocation = tool.build(params); + + await invocation.shouldConfirmExecute(abortSignal); + + expect(mockIdeClient.openDiff).not.toHaveBeenCalled(); + }); + + it('should update params.content with IDE content when onConfirm is called', async () => { + const filePath = path.join(rootDir, 'ide_onconfirm_file.txt'); + const params = { file_path: filePath, content: 'original-content' }; + const invocation = tool.build(params); + + // This is the key part: get the confirmation details + const confirmation = (await invocation.shouldConfirmExecute( + abortSignal, + )) as ToolEditConfirmationDetails; + + // The `onConfirm` function should exist on the details object + expect(confirmation.onConfirm).toBeDefined(); + + // Call `onConfirm` to trigger the logic that updates the content + await confirmation.onConfirm!(ToolConfirmationOutcome.ProceedOnce); + + // Now, check if the original 
`params` object (captured by the invocation) was modified + expect(invocation.params.content).toBe('ide-modified-content'); + }); + + it('should not await ideConfirmation promise', async () => { + const filePath = path.join(rootDir, 'ide_no_await_file.txt'); + const params = { file_path: filePath, content: 'test' }; + const invocation = tool.build(params); + + let diffPromiseResolved = false; + const diffPromise = new Promise((resolve) => { + setTimeout(() => { + diffPromiseResolved = true; + resolve({ status: 'accepted', content: 'ide-modified-content' }); + }, 50); // A small delay to ensure the check happens before resolution + }); + mockIdeClient.openDiff.mockReturnValue(diffPromise); + + const confirmation = (await invocation.shouldConfirmExecute( + abortSignal, + )) as ToolEditConfirmationDetails; + + // This is the key check: the confirmation details should be returned + // *before* the diffPromise is resolved. + expect(diffPromiseResolved).toBe(false); + expect(confirmation).toBeDefined(); + expect(confirmation.ideConfirmation).toBe(diffPromise); + + // Now, we can await the promise to let the test finish cleanly. 
+ await diffPromise; + expect(diffPromiseResolved).toBe(true); + }); + }); }); describe('execute', () => { diff --git a/packages/core/src/tools/write-file.ts b/packages/core/src/tools/write-file.ts index 597043f05b4..3253fc2d6ef 100644 --- a/packages/core/src/tools/write-file.ts +++ b/packages/core/src/tools/write-file.ts @@ -35,7 +35,7 @@ import type { ModifiableDeclarativeTool, ModifyContext, } from './modifiable-tool.js'; -import { IdeClient, IDEConnectionStatus } from '../ide/ide-client.js'; +import { IdeClient } from '../ide/ide-client.js'; import { logFileOperation } from '../telemetry/loggers.js'; import { FileOperationEvent } from '../telemetry/types.js'; import { FileOperation } from '../telemetry/metrics.js'; @@ -195,8 +195,7 @@ class WriteFileToolInvocation extends BaseToolInvocation< const ideClient = await IdeClient.getInstance(); const ideConfirmation = - this.config.getIdeMode() && - ideClient.getConnectionStatus().status === IDEConnectionStatus.Connected + this.config.getIdeMode() && ideClient.isDiffingEnabled() ? 
ideClient.openDiff(this.params.file_path, correctedContent) : undefined; diff --git a/packages/core/src/utils/bfsFileSearch.ts b/packages/core/src/utils/bfsFileSearch.ts index cb2f387262d..2d66f719b71 100644 --- a/packages/core/src/utils/bfsFileSearch.ts +++ b/packages/core/src/utils/bfsFileSearch.ts @@ -99,6 +99,16 @@ export async function bfsFileSearch( for (const { currentDir, entries } of results) { for (const entry of entries) { const fullPath = path.join(currentDir, entry.name); + const isDirectory = entry.isDirectory(); + const isMatchingFile = entry.isFile() && entry.name === fileName; + + if (!isDirectory && !isMatchingFile) { + continue; + } + if (isDirectory && ignoreDirsSet.has(entry.name)) { + continue; + } + if ( fileService?.shouldIgnoreFile(fullPath, { respectGitIgnore: options.fileFilteringOptions?.respectGitIgnore, @@ -109,11 +119,9 @@ export async function bfsFileSearch( continue; } - if (entry.isDirectory()) { - if (!ignoreDirsSet.has(entry.name)) { - queue.push(fullPath); - } - } else if (entry.isFile() && entry.name === fileName) { + if (isDirectory) { + queue.push(fullPath); + } else { foundFiles.push(fullPath); } } diff --git a/packages/core/src/utils/errors.ts b/packages/core/src/utils/errors.ts index a02399ea9ef..030910ce884 100644 --- a/packages/core/src/utils/errors.ts +++ b/packages/core/src/utils/errors.ts @@ -59,6 +59,16 @@ export class FatalTurnLimitedError extends FatalError { super(message, 53); } } +export class FatalToolExecutionError extends FatalError { + constructor(message: string) { + super(message, 54); + } +} +export class FatalCancellationError extends FatalError { + constructor(message: string) { + super(message, 130); // Standard exit code for SIGINT + } +} export class ForbiddenError extends Error {} export class UnauthorizedError extends Error {} diff --git a/packages/core/src/utils/fileUtils.test.ts b/packages/core/src/utils/fileUtils.test.ts index fe6860f1f16..dd1ad6e62cd 100644 --- 
a/packages/core/src/utils/fileUtils.test.ts +++ b/packages/core/src/utils/fileUtils.test.ts @@ -28,6 +28,7 @@ import { processSingleFileContent, detectBOM, readFileWithEncoding, + fileExists, } from './fileUtils.js'; import { StandardFileSystemService } from '../services/fileSystemService.js'; @@ -133,6 +134,25 @@ describe('fileUtils', () => { }); }); + describe('fileExists', () => { + it('should return true if the file exists', async () => { + const testFile = path.join(tempRootDir, 'exists.txt'); + actualNodeFs.writeFileSync(testFile, 'content'); + await expect(fileExists(testFile)).resolves.toBe(true); + }); + + it('should return false if the file does not exist', async () => { + const testFile = path.join(tempRootDir, 'does-not-exist.txt'); + await expect(fileExists(testFile)).resolves.toBe(false); + }); + + it('should return true for a directory that exists', async () => { + const testDir = path.join(tempRootDir, 'exists-dir'); + actualNodeFs.mkdirSync(testDir); + await expect(fileExists(testDir)).resolves.toBe(true); + }); + }); + describe('isBinaryFile', () => { let filePathForBinaryTest: string; diff --git a/packages/core/src/utils/fileUtils.ts b/packages/core/src/utils/fileUtils.ts index f623ae58c85..8525c3b913c 100644 --- a/packages/core/src/utils/fileUtils.ts +++ b/packages/core/src/utils/fileUtils.ts @@ -5,6 +5,7 @@ */ import fs from 'node:fs'; +import fsPromises from 'node:fs/promises'; import path from 'node:path'; import type { PartUnion } from '@google/genai'; // eslint-disable-next-line import/no-internal-modules @@ -467,3 +468,12 @@ export async function processSingleFileContent( }; } } + +export async function fileExists(filePath: string): Promise { + try { + await fsPromises.access(filePath, fs.constants.F_OK); + return true; + } catch (_: unknown) { + return false; + } +} diff --git a/packages/core/src/utils/flashFallback.integration.test.ts b/packages/core/src/utils/flashFallback.test.ts similarity index 62% rename from 
packages/core/src/utils/flashFallback.integration.test.ts rename to packages/core/src/utils/flashFallback.test.ts index 9211ad2f20e..6d4330f1ad9 100644 --- a/packages/core/src/utils/flashFallback.integration.test.ts +++ b/packages/core/src/utils/flashFallback.test.ts @@ -17,10 +17,13 @@ import { import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; import { retryWithBackoff } from './retry.js'; import { AuthType } from '../core/contentGenerator.js'; +// Import the new types (Assuming this test file is in packages/core/src/utils/) +import type { FallbackModelHandler } from '../fallback/types.js'; vi.mock('node:fs'); -describe('Flash Fallback Integration', () => { +// Update the description to reflect that this tests the retry utility's integration +describe('Retry Utility Fallback Integration', () => { let config: Config; beforeEach(() => { @@ -41,25 +44,28 @@ describe('Flash Fallback Integration', () => { resetRequestCounter(); }); - it('should automatically accept fallback', async () => { - // Set up a minimal flash fallback handler for testing - const flashFallbackHandler = async (): Promise => true; + // This test validates the Config's ability to store and execute the handler contract. + it('should execute the injected FallbackHandler contract correctly', async () => { + // Set up a minimal handler for testing, ensuring it matches the new type. 
+ const fallbackHandler: FallbackModelHandler = async () => 'retry'; - config.setFlashFallbackHandler(flashFallbackHandler); + // Use the generalized setter + config.setFallbackModelHandler(fallbackHandler); - // Call the handler directly to test - const result = await config.flashFallbackHandler!( + // Call the handler directly via the config property + const result = await config.fallbackModelHandler!( 'gemini-2.5-pro', DEFAULT_GEMINI_FLASH_MODEL, ); - // Verify it automatically accepts - expect(result).toBe(true); + // Verify it returns the correct intent + expect(result).toBe('retry'); }); - it('should trigger fallback after 2 consecutive 429 errors for OAuth users', async () => { + // This test validates the retry utility's logic for triggering the callback. + it('should trigger onPersistent429 after 2 consecutive 429 errors for OAuth users', async () => { let fallbackCalled = false; - let fallbackModel = ''; + // Removed fallbackModel variable as it's no longer relevant here. // Mock function that simulates exactly 2 429 errors, then succeeds after fallback const mockApiCall = vi @@ -68,11 +74,11 @@ describe('Flash Fallback Integration', () => { .mockRejectedValueOnce(createSimulated429Error()) .mockResolvedValueOnce('success after fallback'); - // Mock fallback handler - const mockFallbackHandler = vi.fn(async (_authType?: string) => { + // Mock the onPersistent429 callback (this is what client.ts/geminiChat.ts provides) + const mockPersistent429Callback = vi.fn(async (_authType?: string) => { fallbackCalled = true; - fallbackModel = DEFAULT_GEMINI_FLASH_MODEL; - return fallbackModel; + // Return true to signal retryWithBackoff to reset attempts and continue. 
+ return true; }); // Test with OAuth personal auth type, with maxAttempts = 2 to ensure fallback triggers @@ -84,14 +90,13 @@ describe('Flash Fallback Integration', () => { const status = (error as Error & { status?: number }).status; return status === 429; }, - onPersistent429: mockFallbackHandler, + onPersistent429: mockPersistent429Callback, authType: AuthType.LOGIN_WITH_GOOGLE, }); - // Verify fallback was triggered + // Verify fallback mechanism was triggered expect(fallbackCalled).toBe(true); - expect(fallbackModel).toBe(DEFAULT_GEMINI_FLASH_MODEL); - expect(mockFallbackHandler).toHaveBeenCalledWith( + expect(mockPersistent429Callback).toHaveBeenCalledWith( AuthType.LOGIN_WITH_GOOGLE, expect.any(Error), ); @@ -100,16 +105,16 @@ describe('Flash Fallback Integration', () => { expect(mockApiCall).toHaveBeenCalledTimes(3); }); - it('should not trigger fallback for API key users', async () => { + it('should not trigger onPersistent429 for API key users', async () => { let fallbackCalled = false; // Mock function that simulates 429 errors const mockApiCall = vi.fn().mockRejectedValue(createSimulated429Error()); - // Mock fallback handler - const mockFallbackHandler = vi.fn(async () => { + // Mock the callback + const mockPersistent429Callback = vi.fn(async () => { fallbackCalled = true; - return DEFAULT_GEMINI_FLASH_MODEL; + return true; }); // Test with API key auth type - should not trigger fallback @@ -122,7 +127,7 @@ describe('Flash Fallback Integration', () => { const status = (error as Error & { status?: number }).status; return status === 429; }, - onPersistent429: mockFallbackHandler, + onPersistent429: mockPersistent429Callback, authType: AuthType.USE_GEMINI, // API key auth type }); } catch (error) { @@ -132,10 +137,11 @@ describe('Flash Fallback Integration', () => { // Verify fallback was NOT triggered for API key users expect(fallbackCalled).toBe(false); - expect(mockFallbackHandler).not.toHaveBeenCalled(); + 
expect(mockPersistent429Callback).not.toHaveBeenCalled(); }); - it('should properly disable simulation state after fallback', () => { + // This test validates the test utilities themselves. + it('should properly disable simulation state after fallback (Test Utility)', () => { // Enable simulation setSimulate429(true); diff --git a/packages/core/src/utils/geminiIgnoreParser.test.ts b/packages/core/src/utils/geminiIgnoreParser.test.ts new file mode 100644 index 00000000000..bf85cd8c697 --- /dev/null +++ b/packages/core/src/utils/geminiIgnoreParser.test.ts @@ -0,0 +1,70 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest'; +import { GeminiIgnoreParser } from './geminiIgnoreParser.js'; +import * as fs from 'node:fs/promises'; +import * as path from 'node:path'; +import * as os from 'node:os'; + +describe('GeminiIgnoreParser', () => { + let projectRoot: string; + + async function createTestFile(filePath: string, content = '') { + const fullPath = path.join(projectRoot, filePath); + await fs.mkdir(path.dirname(fullPath), { recursive: true }); + await fs.writeFile(fullPath, content); + } + + beforeEach(async () => { + projectRoot = await fs.mkdtemp( + path.join(os.tmpdir(), 'geminiignore-test-'), + ); + }); + + afterEach(async () => { + await fs.rm(projectRoot, { recursive: true, force: true }); + vi.restoreAllMocks(); + }); + + describe('when .geminiignore exists', () => { + beforeEach(async () => { + await createTestFile( + '.geminiignore', + 'ignored.txt\n# A comment\n/ignored_dir/\n', + ); + await createTestFile('ignored.txt', 'ignored'); + await createTestFile('not_ignored.txt', 'not ignored'); + await createTestFile( + path.join('ignored_dir', 'file.txt'), + 'in ignored dir', + ); + await createTestFile( + path.join('subdir', 'not_ignored.txt'), + 'not ignored', + ); + }); + + it('should ignore files specified in .geminiignore', () => { + const 
parser = new GeminiIgnoreParser(projectRoot); + expect(parser.getPatterns()).toEqual(['ignored.txt', '/ignored_dir/']); + expect(parser.isIgnored('ignored.txt')).toBe(true); + expect(parser.isIgnored('not_ignored.txt')).toBe(false); + expect(parser.isIgnored(path.join('ignored_dir', 'file.txt'))).toBe(true); + expect(parser.isIgnored(path.join('subdir', 'not_ignored.txt'))).toBe( + false, + ); + }); + }); + + describe('when .geminiignore does not exist', () => { + it('should not load any patterns and not ignore any files', () => { + const parser = new GeminiIgnoreParser(projectRoot); + expect(parser.getPatterns()).toEqual([]); + expect(parser.isIgnored('any_file.txt')).toBe(false); + }); + }); +}); diff --git a/packages/core/src/utils/geminiIgnoreParser.ts b/packages/core/src/utils/geminiIgnoreParser.ts new file mode 100644 index 00000000000..8518923de49 --- /dev/null +++ b/packages/core/src/utils/geminiIgnoreParser.ts @@ -0,0 +1,81 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as fs from 'node:fs'; +import * as path from 'node:path'; +import ignore from 'ignore'; + +export interface GeminiIgnoreFilter { + isIgnored(filePath: string): boolean; + getPatterns(): string[]; +} + +export class GeminiIgnoreParser implements GeminiIgnoreFilter { + private projectRoot: string; + private patterns: string[] = []; + private ig = ignore(); + + constructor(projectRoot: string) { + this.projectRoot = path.resolve(projectRoot); + this.loadPatterns(); + } + + private loadPatterns(): void { + const patternsFilePath = path.join(this.projectRoot, '.geminiignore'); + let content: string; + try { + content = fs.readFileSync(patternsFilePath, 'utf-8'); + } catch (_error) { + // ignore file not found + return; + } + + this.patterns = (content ?? 
'') + .split('\n') + .map((p) => p.trim()) + .filter((p) => p !== '' && !p.startsWith('#')); + + this.ig.add(this.patterns); + } + + isIgnored(filePath: string): boolean { + if (this.patterns.length === 0) { + return false; + } + + if (!filePath || typeof filePath !== 'string') { + return false; + } + + if ( + filePath.startsWith('\\') || + filePath === '/' || + filePath.includes('\0') + ) { + return false; + } + + const resolved = path.resolve(this.projectRoot, filePath); + const relativePath = path.relative(this.projectRoot, resolved); + + if (relativePath === '' || relativePath.startsWith('..')) { + return false; + } + + // Even in windows, Ignore expects forward slashes. + const normalizedPath = relativePath.replace(/\\/g, '/'); + + if (normalizedPath.startsWith('/') || normalizedPath === '') { + return false; + } + + return this.ig.ignores(normalizedPath); + } + + getPatterns(): string[] { + return this.patterns; + } +} diff --git a/packages/core/src/utils/gitIgnoreParser.test.ts b/packages/core/src/utils/gitIgnoreParser.test.ts index 25faf889209..19903320913 100644 --- a/packages/core/src/utils/gitIgnoreParser.test.ts +++ b/packages/core/src/utils/gitIgnoreParser.test.ts @@ -33,14 +33,16 @@ describe('GitIgnoreParser', () => { await fs.rm(projectRoot, { recursive: true, force: true }); }); - describe('initialization', () => { - it('should initialize without errors when no .gitignore exists', async () => { + describe('Basic ignore behaviors', () => { + beforeEach(async () => { await setupGitRepo(); - expect(() => parser.loadGitRepoPatterns()).not.toThrow(); }); - it('should load .gitignore patterns when file exists', async () => { - await setupGitRepo(); + it('should not ignore files when no .gitignore exists', async () => { + expect(parser.isIgnored('file.txt')).toBe(false); + }); + + it('should ignore files based on a root .gitignore', async () => { const gitignoreContent = ` # Comment node_modules/ @@ -50,52 +52,28 @@ node_modules/ `; await 
createTestFile('.gitignore', gitignoreContent); - parser.loadGitRepoPatterns(); - - expect(parser.getPatterns()).toEqual([ - '.git', - 'node_modules/', - '*.log', - '/dist', - '.env', - ]); expect(parser.isIgnored(path.join('node_modules', 'some-lib'))).toBe( true, ); expect(parser.isIgnored(path.join('src', 'app.log'))).toBe(true); expect(parser.isIgnored(path.join('dist', 'index.js'))).toBe(true); expect(parser.isIgnored('.env')).toBe(true); + expect(parser.isIgnored('src/index.js')).toBe(false); }); it('should handle git exclude file', async () => { - await setupGitRepo(); await createTestFile( path.join('.git', 'info', 'exclude'), 'temp/\n*.tmp', ); - parser.loadGitRepoPatterns(); - expect(parser.getPatterns()).toEqual(['.git', 'temp/', '*.tmp']); - expect(parser.isIgnored(path.join('temp', 'file.txt'))).toBe(true); - expect(parser.isIgnored(path.join('src', 'file.tmp'))).toBe(true); - }); - - it('should handle custom patterns file name', async () => { - // No .git directory for this test - await createTestFile('.geminiignore', 'temp/\n*.tmp'); - - parser.loadPatterns('.geminiignore'); - expect(parser.getPatterns()).toEqual(['temp/', '*.tmp']); expect(parser.isIgnored(path.join('temp', 'file.txt'))).toBe(true); expect(parser.isIgnored(path.join('src', 'file.tmp'))).toBe(true); - }); - - it('should initialize without errors when no .geminiignore exists', () => { - expect(() => parser.loadPatterns('.geminiignore')).not.toThrow(); + expect(parser.isIgnored('src/file.js')).toBe(false); }); }); - describe('isIgnored', () => { + describe('isIgnored path handling', () => { beforeEach(async () => { await setupGitRepo(); const gitignoreContent = ` @@ -107,7 +85,6 @@ src/*.tmp !src/important.tmp `; await createTestFile('.gitignore', gitignoreContent); - parser.loadGitRepoPatterns(); }); it('should always ignore .git directory', () => { @@ -205,8 +182,6 @@ src/*.tmp }); it('should handle nested .gitignore files correctly', async () => { - parser.loadGitRepoPatterns(); - 
// From root .gitignore expect(parser.isIgnored('root-ignored.txt')).toBe(true); expect(parser.isIgnored('a/root-ignored.txt')).toBe(true); @@ -230,34 +205,27 @@ src/*.tmp expect(parser.isIgnored('a/d/f/g')).toBe(true); expect(parser.isIgnored('a/f/g')).toBe(false); }); + }); - it('should correctly transform patterns from nested gitignore files', () => { - parser.loadGitRepoPatterns(); - const patterns = parser.getPatterns(); - - // From root .gitignore - expect(patterns).toContain('root-ignored.txt'); + describe('precedence rules', () => { + beforeEach(async () => { + await setupGitRepo(); + }); - // From a/.gitignore - expect(patterns).toContain('/a/b'); // /b becomes /a/b - expect(patterns).toContain('/a/**/c'); // c becomes /a/**/c + it('should prioritize nested .gitignore over root .gitignore', async () => { + await createTestFile('.gitignore', '*.log'); + await createTestFile('a/b/.gitignore', '!special.log'); - // From a/d/.gitignore - expect(patterns).toContain('/a/d/**/e.txt'); // e.txt becomes /a/d/**/e.txt - expect(patterns).toContain('/a/d/f/g'); // f/g becomes /a/d/f/g + expect(parser.isIgnored('a/b/any.log')).toBe(true); + expect(parser.isIgnored('a/b/special.log')).toBe(false); }); - }); - describe('precedence rules', () => { - it('should prioritize root .gitignore over .git/info/exclude', async () => { - await setupGitRepo(); + it('should prioritize .gitignore over .git/info/exclude', async () => { // Exclude all .log files await createTestFile(path.join('.git', 'info', 'exclude'), '*.log'); // But make an exception in the root .gitignore await createTestFile('.gitignore', '!important.log'); - parser.loadGitRepoPatterns(); - expect(parser.isIgnored('some.log')).toBe(true); expect(parser.isIgnored('important.log')).toBe(false); expect(parser.isIgnored(path.join('subdir', 'some.log'))).toBe(true); @@ -266,15 +234,4 @@ src/*.tmp ); }); }); - - describe('getIgnoredPatterns', () => { - it('should return the raw patterns added', async () => { - await 
setupGitRepo(); - const gitignoreContent = '*.log\n!important.log'; - await createTestFile('.gitignore', gitignoreContent); - - parser.loadGitRepoPatterns(); - expect(parser.getPatterns()).toEqual(['.git', '*.log', '!important.log']); - }); - }); }); diff --git a/packages/core/src/utils/gitIgnoreParser.ts b/packages/core/src/utils/gitIgnoreParser.ts index 9b1da742204..21d83651a5c 100644 --- a/packages/core/src/utils/gitIgnoreParser.ts +++ b/packages/core/src/utils/gitIgnoreParser.ts @@ -6,82 +6,38 @@ import * as fs from 'node:fs'; import * as path from 'node:path'; -import ignore, { type Ignore } from 'ignore'; -import { isGitRepository } from './gitUtils.js'; +import ignore from 'ignore'; export interface GitIgnoreFilter { isIgnored(filePath: string): boolean; - getPatterns(): string[]; } export class GitIgnoreParser implements GitIgnoreFilter { private projectRoot: string; - private ig: Ignore = ignore(); - private patterns: string[] = []; + private cache: Map = new Map(); + private globalPatterns: string[] | undefined; constructor(projectRoot: string) { this.projectRoot = path.resolve(projectRoot); } - loadGitRepoPatterns(): void { - if (!isGitRepository(this.projectRoot)) return; - - // Always ignore .git directory regardless of .gitignore content - this.addPatterns(['.git']); - - this.loadPatterns(path.join('.git', 'info', 'exclude')); - this.findAndLoadGitignoreFiles(this.projectRoot); - } - - private findAndLoadGitignoreFiles(dir: string): void { - const relativeDir = path.relative(this.projectRoot, dir); - - // For sub-directories, check if they are ignored before proceeding. - // The root directory (relativeDir === '') should not be checked. 
- if (relativeDir && this.isIgnored(relativeDir)) { - return; - } - - // Load patterns from .gitignore in the current directory - const gitignorePath = path.join(dir, '.gitignore'); - if (fs.existsSync(gitignorePath)) { - this.loadPatterns(path.relative(this.projectRoot, gitignorePath)); - } - - // Recurse into subdirectories - try { - const entries = fs.readdirSync(dir, { withFileTypes: true }); - for (const entry of entries) { - if (entry.name === '.git') { - continue; - } - if (entry.isDirectory()) { - this.findAndLoadGitignoreFiles(path.join(dir, entry.name)); - } - } - } catch (_error) { - // ignore readdir errors - } - } - - loadPatterns(patternsFileName: string): void { - const patternsFilePath = path.join(this.projectRoot, patternsFileName); + private loadPatternsForFile(patternsFilePath: string): string[] { let content: string; try { content = fs.readFileSync(patternsFilePath, 'utf-8'); } catch (_error) { - // ignore file not found - return; + return []; } - // .git/info/exclude file patterns are relative to project root and not file directory - const isExcludeFile = - patternsFileName.replace(/\\/g, '/') === '.git/info/exclude'; + const isExcludeFile = patternsFilePath.endsWith( + path.join('.git', 'info', 'exclude'), + ); + const relativeBaseDir = isExcludeFile ? '.' - : path.dirname(patternsFileName); + : path.dirname(path.relative(this.projectRoot, patternsFilePath)); - const patterns = (content ?? 
'') + return content .split('\n') .map((p) => p.trim()) .filter((p) => p !== '' && !p.startsWith('#')) @@ -139,12 +95,6 @@ export class GitIgnoreParser implements GitIgnoreFilter { return newPattern; }) .filter((p) => p !== ''); - this.addPatterns(patterns); - } - - private addPatterns(patterns: string[]) { - this.ig.add(patterns); - this.patterns.push(...patterns); } isIgnored(filePath: string): boolean { @@ -152,11 +102,8 @@ export class GitIgnoreParser implements GitIgnoreFilter { return false; } - if ( - filePath.startsWith('\\') || - filePath === '/' || - filePath.includes('\0') - ) { + const absoluteFilePath = path.resolve(this.projectRoot, filePath); + if (!absoluteFilePath.startsWith(this.projectRoot)) { return false; } @@ -175,13 +122,68 @@ export class GitIgnoreParser implements GitIgnoreFilter { return false; } - return this.ig.ignores(normalizedPath); + const ig = ignore(); + + // Always ignore .git directory + ig.add('.git'); + + // Load global patterns from .git/info/exclude on first call + if (this.globalPatterns === undefined) { + const excludeFile = path.join( + this.projectRoot, + '.git', + 'info', + 'exclude', + ); + this.globalPatterns = fs.existsSync(excludeFile) + ? this.loadPatternsForFile(excludeFile) + : []; + } + ig.add(this.globalPatterns); + + const pathParts = relativePath.split(path.sep); + + const dirsToVisit = [this.projectRoot]; + let currentAbsDir = this.projectRoot; + // Collect all directories in the path + for (let i = 0; i < pathParts.length - 1; i++) { + currentAbsDir = path.join(currentAbsDir, pathParts[i]); + dirsToVisit.push(currentAbsDir); + } + + for (const dir of dirsToVisit) { + const relativeDir = path.relative(this.projectRoot, dir); + if (relativeDir) { + const normalizedRelativeDir = relativeDir.replace(/\\/g, '/'); + if (ig.ignores(normalizedRelativeDir)) { + // This directory is ignored by an ancestor's .gitignore. 
+ // According to git behavior, we don't need to process this + // directory's .gitignore, as nothing inside it can be + // un-ignored. + break; + } + } + + if (this.cache.has(dir)) { + const patterns = this.cache.get(dir); + if (patterns) { + ig.add(patterns); + } + } else { + const gitignorePath = path.join(dir, '.gitignore'); + if (fs.existsSync(gitignorePath)) { + const patterns = this.loadPatternsForFile(gitignorePath); + this.cache.set(dir, patterns); + ig.add(patterns); + } else { + this.cache.set(dir, []); // Cache miss + } + } + } + + return ig.ignores(normalizedPath); } catch (_error) { return false; } } - - getPatterns(): string[] { - return this.patterns; - } } diff --git a/packages/core/src/utils/ide-trust.ts b/packages/core/src/utils/ide-trust.ts deleted file mode 100644 index 1cfaa88b477..00000000000 --- a/packages/core/src/utils/ide-trust.ts +++ /dev/null @@ -1,15 +0,0 @@ -/** - * @license - * Copyright 2025 Google LLC - * SPDX-License-Identifier: Apache-2.0 - */ - -import { ideContext } from '../ide/ideContext.js'; - -/** - * Gets the workspace trust from the IDE if available. - * @returns A boolean if the IDE provides a trust value, otherwise undefined. 
- */ -export function getIdeTrust(): boolean | undefined { - return ideContext.getIdeContext()?.workspaceState?.isTrusted; -} diff --git a/packages/core/src/utils/llm-edit-fixer.test.ts b/packages/core/src/utils/llm-edit-fixer.test.ts new file mode 100644 index 00000000000..4c236ad3425 --- /dev/null +++ b/packages/core/src/utils/llm-edit-fixer.test.ts @@ -0,0 +1,203 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { + FixLLMEditWithInstruction, + resetLlmEditFixerCaches_TEST_ONLY, + type SearchReplaceEdit, +} from './llm-edit-fixer.js'; +import { promptIdContext } from './promptIdContext.js'; +import type { BaseLlmClient } from '../core/baseLlmClient.js'; + +// Mock the BaseLlmClient +const mockGenerateJson = vi.fn(); +const mockBaseLlmClient = { + generateJson: mockGenerateJson, +} as unknown as BaseLlmClient; + +describe('FixLLMEditWithInstruction', () => { + const instruction = 'Replace the title'; + const old_string = '

Old Title

'; + const new_string = '

New Title

'; + const error = 'String not found'; + const current_content = '

Old Title

'; + const abortController = new AbortController(); + const abortSignal = abortController.signal; + + beforeEach(() => { + vi.clearAllMocks(); + resetLlmEditFixerCaches_TEST_ONLY(); // Ensure cache is cleared before each test + }); + + afterEach(() => { + vi.useRealTimers(); // Reset timers after each test + }); + + const mockApiResponse: SearchReplaceEdit = { + search: '

Old Title

', + replace: '

New Title

', + noChangesRequired: false, + explanation: 'The original search was correct.', + }; + + it('should use the promptId from the AsyncLocalStorage context when available', async () => { + const testPromptId = 'test-prompt-id-12345'; + mockGenerateJson.mockResolvedValue(mockApiResponse); + + await promptIdContext.run(testPromptId, async () => { + await FixLLMEditWithInstruction( + instruction, + old_string, + new_string, + error, + current_content, + mockBaseLlmClient, + abortSignal, + ); + }); + + // Verify that generateJson was called with the promptId from the context + expect(mockGenerateJson).toHaveBeenCalledTimes(1); + expect(mockGenerateJson).toHaveBeenCalledWith( + expect.objectContaining({ + promptId: testPromptId, + }), + ); + }); + + it('should generate and use a fallback promptId when context is not available', async () => { + mockGenerateJson.mockResolvedValue(mockApiResponse); + const consoleWarnSpy = vi + .spyOn(console, 'warn') + .mockImplementation(() => {}); + + // Run the function outside of any context + await FixLLMEditWithInstruction( + instruction, + old_string, + new_string, + error, + current_content, + mockBaseLlmClient, + abortSignal, + ); + + // Verify the warning was logged + expect(consoleWarnSpy).toHaveBeenCalledWith( + expect.stringContaining( + 'Could not find promptId in context. This is unexpected. 
Using a fallback ID: llm-fixer-fallback-', + ), + ); + + // Verify that generateJson was called with the generated fallback promptId + expect(mockGenerateJson).toHaveBeenCalledTimes(1); + expect(mockGenerateJson).toHaveBeenCalledWith( + expect.objectContaining({ + promptId: expect.stringContaining('llm-fixer-fallback-'), + }), + ); + + // Restore mocks + consoleWarnSpy.mockRestore(); + }); + + it('should construct the user prompt correctly', async () => { + mockGenerateJson.mockResolvedValue(mockApiResponse); + const promptId = 'test-prompt-id-prompt-construction'; + + await promptIdContext.run(promptId, async () => { + await FixLLMEditWithInstruction( + instruction, + old_string, + new_string, + error, + current_content, + mockBaseLlmClient, + abortSignal, + ); + }); + + const generateJsonCall = mockGenerateJson.mock.calls[0][0]; + const userPromptContent = generateJsonCall.contents[0].parts[0].text; + + expect(userPromptContent).toContain( + `\n${instruction}\n`, + ); + expect(userPromptContent).toContain(`\n${old_string}\n`); + expect(userPromptContent).toContain(`\n${new_string}\n`); + expect(userPromptContent).toContain(`\n${error}\n`); + expect(userPromptContent).toContain( + `\n${current_content}\n`, + ); + }); + + it('should return a cached result on subsequent identical calls', async () => { + mockGenerateJson.mockResolvedValue(mockApiResponse); + const testPromptId = 'test-prompt-id-caching'; + + await promptIdContext.run(testPromptId, async () => { + // First call - should call the API + const result1 = await FixLLMEditWithInstruction( + instruction, + old_string, + new_string, + error, + current_content, + mockBaseLlmClient, + abortSignal, + ); + + // Second call with identical parameters - should hit the cache + const result2 = await FixLLMEditWithInstruction( + instruction, + old_string, + new_string, + error, + current_content, + mockBaseLlmClient, + abortSignal, + ); + + expect(result1).toEqual(mockApiResponse); + 
expect(result2).toEqual(mockApiResponse); + // Verify the underlying service was only called ONCE + expect(mockGenerateJson).toHaveBeenCalledTimes(1); + }); + }); + + it('should not use cache for calls with different parameters', async () => { + mockGenerateJson.mockResolvedValue(mockApiResponse); + const testPromptId = 'test-prompt-id-cache-miss'; + + await promptIdContext.run(testPromptId, async () => { + // First call + await FixLLMEditWithInstruction( + instruction, + old_string, + new_string, + error, + current_content, + mockBaseLlmClient, + abortSignal, + ); + + // Second call with a different instruction + await FixLLMEditWithInstruction( + 'A different instruction', + old_string, + new_string, + error, + current_content, + mockBaseLlmClient, + abortSignal, + ); + + // Verify the underlying service was called TWICE + expect(mockGenerateJson).toHaveBeenCalledTimes(2); + }); + }); +}); diff --git a/packages/core/src/utils/llm-edit-fixer.ts b/packages/core/src/utils/llm-edit-fixer.ts index 95496d47794..a4b4b131c0c 100644 --- a/packages/core/src/utils/llm-edit-fixer.ts +++ b/packages/core/src/utils/llm-edit-fixer.ts @@ -5,9 +5,10 @@ */ import { type Content, Type } from '@google/genai'; -import { type GeminiClient } from '../core/client.js'; +import { type BaseLlmClient } from '../core/baseLlmClient.js'; import { LruCache } from './LruCache.js'; import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; +import { promptIdContext } from './promptIdContext.js'; const MAX_CACHE_SIZE = 50; @@ -93,8 +94,9 @@ const editCorrectionWithInstructionCache = new LruCache< * @param new_string The original replacement string. * @param error The error that occurred during the initial edit. * @param current_content The current content of the file. - * @param geminiClient The Gemini client to use for the LLM call. + * @param baseLlmClient The BaseLlmClient to use for the LLM call. * @param abortSignal An abort signal to cancel the operation. 
+ * @param promptId A unique ID for the prompt. * @returns A new search and replace pair. */ export async function FixLLMEditWithInstruction( @@ -103,9 +105,17 @@ export async function FixLLMEditWithInstruction( new_string: string, error: string, current_content: string, - geminiClient: GeminiClient, + baseLlmClient: BaseLlmClient, abortSignal: AbortSignal, ): Promise { + let promptId = promptIdContext.getStore(); + if (!promptId) { + promptId = `llm-fixer-fallback-${Date.now()}-${Math.random().toString(16).slice(2)}`; + console.warn( + `Could not find promptId in context. This is unexpected. Using a fallback ID: ${promptId}`, + ); + } + const cacheKey = `${instruction}---${old_string}---${new_string}--${current_content}--${error}`; const cachedResult = editCorrectionWithInstructionCache.get(cacheKey); if (cachedResult) { @@ -120,21 +130,18 @@ export async function FixLLMEditWithInstruction( const contents: Content[] = [ { role: 'user', - parts: [ - { - text: `${EDIT_SYS_PROMPT} -${userPrompt}`, - }, - ], + parts: [{ text: userPrompt }], }, ]; - const result = (await geminiClient.generateJson( + const result = (await baseLlmClient.generateJson({ contents, - SearchReplaceEditSchema, + schema: SearchReplaceEditSchema, abortSignal, - DEFAULT_GEMINI_FLASH_MODEL, - )) as unknown as SearchReplaceEdit; + model: DEFAULT_GEMINI_FLASH_MODEL, + systemInstruction: EDIT_SYS_PROMPT, + promptId, + })) as unknown as SearchReplaceEdit; editCorrectionWithInstructionCache.set(cacheKey, result); return result; diff --git a/packages/core/src/utils/nextSpeakerChecker.test.ts b/packages/core/src/utils/nextSpeakerChecker.test.ts index b9e861998ec..dab9099d698 100644 --- a/packages/core/src/utils/nextSpeakerChecker.test.ts +++ b/packages/core/src/utils/nextSpeakerChecker.test.ts @@ -6,10 +6,10 @@ import type { Mock } from 'vitest'; import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; -import type { Content, GoogleGenAI, Models } from '@google/genai'; +import type { 
Content } from '@google/genai'; import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; import { GeminiClient } from '../core/client.js'; -import { Config } from '../config/config.js'; +import type { Config } from '../config/config.js'; import type { NextSpeakerResponse } from './nextSpeakerChecker.js'; import { checkNextSpeaker } from './nextSpeakerChecker.js'; import { GeminiChat } from '../core/geminiChat.js'; @@ -44,73 +44,28 @@ vi.mock('node:fs', () => { vi.mock('../core/client.js'); vi.mock('../config/config.js'); -// Define mocks for GoogleGenAI and Models instances that will be used across tests -const mockModelsInstance = { - generateContent: vi.fn(), - generateContentStream: vi.fn(), - countTokens: vi.fn(), - embedContent: vi.fn(), - batchEmbedContents: vi.fn(), -} as unknown as Models; - -const mockGoogleGenAIInstance = { - getGenerativeModel: vi.fn().mockReturnValue(mockModelsInstance), - // Add other methods of GoogleGenAI if they are directly used by GeminiChat constructor or its methods -} as unknown as GoogleGenAI; - -vi.mock('@google/genai', async () => { - const actualGenAI = - await vi.importActual('@google/genai'); - return { - ...actualGenAI, - GoogleGenAI: vi.fn(() => mockGoogleGenAIInstance), // Mock constructor to return the predefined instance - // If Models is instantiated directly in GeminiChat, mock its constructor too - // For now, assuming Models instance is obtained via getGenerativeModel - }; -}); - describe('checkNextSpeaker', () => { let chatInstance: GeminiChat; + let mockConfig: Config; let mockGeminiClient: GeminiClient; - let MockConfig: Mock; const abortSignal = new AbortController().signal; beforeEach(() => { - MockConfig = vi.mocked(Config); - const mockConfigInstance = new MockConfig( - 'test-api-key', - 'gemini-pro', - false, - '.', - false, - undefined, - false, - undefined, - undefined, - undefined, - ); - - // Mock the methods that ChatRecordingService needs - mockConfigInstance.getSessionId = vi - .fn() - 
.mockReturnValue('test-session-id'); - mockConfigInstance.getProjectRoot = vi - .fn() - .mockReturnValue('/test/project/root'); - mockConfigInstance.storage = { - getProjectTempDir: vi.fn().mockReturnValue('/test/temp'), - }; - - mockGeminiClient = new GeminiClient(mockConfigInstance); - - // Reset mocks before each test to ensure test isolation - vi.mocked(mockModelsInstance.generateContent).mockReset(); - vi.mocked(mockModelsInstance.generateContentStream).mockReset(); + vi.resetAllMocks(); + mockConfig = { + getProjectRoot: vi.fn().mockReturnValue('/test/project/root'), + getSessionId: vi.fn().mockReturnValue('test-session-id'), + getModel: () => 'test-model', + storage: { + getProjectTempDir: vi.fn().mockReturnValue('/test/temp'), + }, + } as unknown as Config; + + mockGeminiClient = new GeminiClient(mockConfig); // GeminiChat will receive the mocked instances via the mocked GoogleGenAI constructor chatInstance = new GeminiChat( - mockConfigInstance, - mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel + mockConfig, {}, [], // initial history ); @@ -120,7 +75,7 @@ describe('checkNextSpeaker', () => { }); afterEach(() => { - vi.clearAllMocks(); + vi.restoreAllMocks(); }); it('should return null if history is empty', async () => { @@ -135,9 +90,9 @@ describe('checkNextSpeaker', () => { }); it('should return null if the last speaker was the user', async () => { - (chatInstance.getHistory as Mock).mockReturnValue([ + vi.mocked(chatInstance.getHistory).mockReturnValue([ { role: 'user', parts: [{ text: 'Hello' }] }, - ] as Content[]); + ]); const result = await checkNextSpeaker( chatInstance, mockGeminiClient, diff --git a/packages/core/src/utils/promptIdContext.ts b/packages/core/src/utils/promptIdContext.ts new file mode 100644 index 00000000000..6344bd0b834 --- /dev/null +++ b/packages/core/src/utils/promptIdContext.ts @@ -0,0 +1,9 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: 
Apache-2.0 + */ + +import { AsyncLocalStorage } from 'node:async_hooks'; + +export const promptIdContext = new AsyncLocalStorage(); diff --git a/packages/core/src/utils/terminalSerializer.test.ts b/packages/core/src/utils/terminalSerializer.test.ts new file mode 100644 index 00000000000..fd6241d04dc --- /dev/null +++ b/packages/core/src/utils/terminalSerializer.test.ts @@ -0,0 +1,197 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect } from 'vitest'; +import { Terminal } from '@xterm/headless'; +import { + serializeTerminalToObject, + convertColorToHex, + ColorMode, +} from './terminalSerializer.js'; + +const RED_FG = '\x1b[31m'; +const RESET = '\x1b[0m'; + +function writeToTerminal(terminal: Terminal, data: string): Promise { + return new Promise((resolve) => { + terminal.write(data, resolve); + }); +} + +describe('terminalSerializer', () => { + describe('serializeTerminalToObject', () => { + it('should handle an empty terminal', () => { + const terminal = new Terminal({ + cols: 80, + rows: 24, + allowProposedApi: true, + }); + const result = serializeTerminalToObject(terminal); + expect(result).toHaveLength(24); + result.forEach((line) => { + // Expect each line to be either empty or contain a single token with spaces + if (line.length > 0) { + expect(line[0].text.trim()).toBe(''); + } + }); + }); + + it('should serialize a single line of text', async () => { + const terminal = new Terminal({ + cols: 80, + rows: 24, + allowProposedApi: true, + }); + await writeToTerminal(terminal, 'Hello, world!'); + const result = serializeTerminalToObject(terminal); + expect(result[0][0].text).toContain('Hello, world!'); + }); + + it('should serialize multiple lines of text', async () => { + const terminal = new Terminal({ + cols: 7, + rows: 24, + allowProposedApi: true, + }); + await writeToTerminal(terminal, 'Line 1\r\nLine 2'); + const result = serializeTerminalToObject(terminal); + 
expect(result[0][0].text).toBe('Line 1 '); + expect(result[1][0].text).toBe('Line 2'); + }); + + it('should handle bold text', async () => { + const terminal = new Terminal({ + cols: 80, + rows: 24, + allowProposedApi: true, + }); + await writeToTerminal(terminal, '\x1b[1mBold text\x1b[0m'); + const result = serializeTerminalToObject(terminal); + expect(result[0][0].bold).toBe(true); + expect(result[0][0].text).toBe('Bold text'); + }); + + it('should handle italic text', async () => { + const terminal = new Terminal({ + cols: 80, + rows: 24, + allowProposedApi: true, + }); + await writeToTerminal(terminal, '\x1b[3mItalic text\x1b[0m'); + const result = serializeTerminalToObject(terminal); + expect(result[0][0].italic).toBe(true); + expect(result[0][0].text).toBe('Italic text'); + }); + + it('should handle underlined text', async () => { + const terminal = new Terminal({ + cols: 80, + rows: 24, + allowProposedApi: true, + }); + await writeToTerminal(terminal, '\x1b[4mUnderlined text\x1b[0m'); + const result = serializeTerminalToObject(terminal); + expect(result[0][0].underline).toBe(true); + expect(result[0][0].text).toBe('Underlined text'); + }); + + it('should handle dim text', async () => { + const terminal = new Terminal({ + cols: 80, + rows: 24, + allowProposedApi: true, + }); + await writeToTerminal(terminal, '\x1b[2mDim text\x1b[0m'); + const result = serializeTerminalToObject(terminal); + expect(result[0][0].dim).toBe(true); + expect(result[0][0].text).toBe('Dim text'); + }); + + it('should handle inverse text', async () => { + const terminal = new Terminal({ + cols: 80, + rows: 24, + allowProposedApi: true, + }); + await writeToTerminal(terminal, '\x1b[7mInverse text\x1b[0m'); + const result = serializeTerminalToObject(terminal); + expect(result[0][0].inverse).toBe(true); + expect(result[0][0].text).toBe('Inverse text'); + }); + + it('should handle foreground colors', async () => { + const terminal = new Terminal({ + cols: 80, + rows: 24, + 
allowProposedApi: true, + }); + await writeToTerminal(terminal, `${RED_FG}Red text${RESET}`); + const result = serializeTerminalToObject(terminal); + expect(result[0][0].fg).toBe('#800000'); + expect(result[0][0].text).toBe('Red text'); + }); + + it('should handle background colors', async () => { + const terminal = new Terminal({ + cols: 80, + rows: 24, + allowProposedApi: true, + }); + await writeToTerminal(terminal, '\x1b[42mGreen background\x1b[0m'); + const result = serializeTerminalToObject(terminal); + expect(result[0][0].bg).toBe('#008000'); + expect(result[0][0].text).toBe('Green background'); + }); + + it('should handle RGB colors', async () => { + const terminal = new Terminal({ + cols: 80, + rows: 24, + allowProposedApi: true, + }); + await writeToTerminal(terminal, '\x1b[38;2;100;200;50mRGB text\x1b[0m'); + const result = serializeTerminalToObject(terminal); + expect(result[0][0].fg).toBe('#64c832'); + expect(result[0][0].text).toBe('RGB text'); + }); + + it('should handle a combination of styles', async () => { + const terminal = new Terminal({ + cols: 80, + rows: 24, + allowProposedApi: true, + }); + await writeToTerminal(terminal, '\x1b[1;31;42mStyled text\x1b[0m'); + const result = serializeTerminalToObject(terminal); + expect(result[0][0].bold).toBe(true); + expect(result[0][0].fg).toBe('#800000'); + expect(result[0][0].bg).toBe('#008000'); + expect(result[0][0].text).toBe('Styled text'); + }); + }); + describe('convertColorToHex', () => { + it('should convert RGB color to hex', () => { + const color = (100 << 16) | (200 << 8) | 50; + const hex = convertColorToHex(color, ColorMode.RGB, '#000000'); + expect(hex).toBe('#64c832'); + }); + + it('should convert palette color to hex', () => { + const hex = convertColorToHex(1, ColorMode.PALETTE, '#000000'); + expect(hex).toBe('#800000'); + }); + + it('should return default color for ColorMode.DEFAULT', () => { + const hex = convertColorToHex(0, ColorMode.DEFAULT, '#ffffff'); + 
expect(hex).toBe('#ffffff'); + }); + + it('should return default color for invalid palette index', () => { + const hex = convertColorToHex(999, ColorMode.PALETTE, '#000000'); + expect(hex).toBe('#000000'); + }); + }); +}); diff --git a/packages/core/src/utils/terminalSerializer.ts b/packages/core/src/utils/terminalSerializer.ts new file mode 100644 index 00000000000..f3c8eacec02 --- /dev/null +++ b/packages/core/src/utils/terminalSerializer.ts @@ -0,0 +1,479 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { IBufferCell, Terminal } from '@xterm/headless'; +export interface AnsiToken { + text: string; + bold: boolean; + italic: boolean; + underline: boolean; + dim: boolean; + inverse: boolean; + fg: string; + bg: string; +} + +export type AnsiLine = AnsiToken[]; +export type AnsiOutput = AnsiLine[]; + +const enum Attribute { + inverse = 1, + bold = 2, + italic = 4, + underline = 8, + dim = 16, +} + +export const enum ColorMode { + DEFAULT = 0, + PALETTE = 1, + RGB = 2, +} + +class Cell { + private readonly cell: IBufferCell | null; + private readonly x: number; + private readonly y: number; + private readonly cursorX: number; + private readonly cursorY: number; + private readonly attributes: number = 0; + fg = 0; + bg = 0; + fgColorMode: ColorMode = ColorMode.DEFAULT; + bgColorMode: ColorMode = ColorMode.DEFAULT; + + constructor( + cell: IBufferCell | null, + x: number, + y: number, + cursorX: number, + cursorY: number, + ) { + this.cell = cell; + this.x = x; + this.y = y; + this.cursorX = cursorX; + this.cursorY = cursorY; + + if (!cell) { + return; + } + + if (cell.isInverse()) { + this.attributes += Attribute.inverse; + } + if (cell.isBold()) { + this.attributes += Attribute.bold; + } + if (cell.isItalic()) { + this.attributes += Attribute.italic; + } + if (cell.isUnderline()) { + this.attributes += Attribute.underline; + } + if (cell.isDim()) { + this.attributes += Attribute.dim; + } + + if 
(cell.isFgRGB()) { + this.fgColorMode = ColorMode.RGB; + } else if (cell.isFgPalette()) { + this.fgColorMode = ColorMode.PALETTE; + } else { + this.fgColorMode = ColorMode.DEFAULT; + } + + if (cell.isBgRGB()) { + this.bgColorMode = ColorMode.RGB; + } else if (cell.isBgPalette()) { + this.bgColorMode = ColorMode.PALETTE; + } else { + this.bgColorMode = ColorMode.DEFAULT; + } + + if (this.fgColorMode === ColorMode.DEFAULT) { + this.fg = -1; + } else { + this.fg = cell.getFgColor(); + } + + if (this.bgColorMode === ColorMode.DEFAULT) { + this.bg = -1; + } else { + this.bg = cell.getBgColor(); + } + } + + isCursor(): boolean { + return this.x === this.cursorX && this.y === this.cursorY; + } + + getChars(): string { + return this.cell?.getChars() || ' '; + } + + isAttribute(attribute: Attribute): boolean { + return (this.attributes & attribute) !== 0; + } + + equals(other: Cell): boolean { + return ( + this.attributes === other.attributes && + this.fg === other.fg && + this.bg === other.bg && + this.fgColorMode === other.fgColorMode && + this.bgColorMode === other.bgColorMode && + this.isCursor() === other.isCursor() + ); + } +} + +export function serializeTerminalToObject( + terminal: Terminal, + options?: { defaultFg?: string; defaultBg?: string }, +): AnsiOutput { + const buffer = terminal.buffer.active; + const cursorX = buffer.cursorX; + const cursorY = buffer.cursorY; + const defaultFg = options?.defaultFg ?? '#ffffff'; + const defaultBg = options?.defaultBg ?? 
'#000000'; + + const result: AnsiOutput = []; + + for (let y = 0; y < terminal.rows; y++) { + const line = buffer.getLine(buffer.viewportY + y); + const currentLine: AnsiLine = []; + if (!line) { + result.push(currentLine); + continue; + } + + let lastCell = new Cell(null, -1, -1, cursorX, cursorY); + let currentText = ''; + + for (let x = 0; x < terminal.cols; x++) { + const cellData = line.getCell(x); + const cell = new Cell(cellData || null, x, y, cursorX, cursorY); + + if (x > 0 && !cell.equals(lastCell)) { + if (currentText) { + const token: AnsiToken = { + text: currentText, + bold: lastCell.isAttribute(Attribute.bold), + italic: lastCell.isAttribute(Attribute.italic), + underline: lastCell.isAttribute(Attribute.underline), + dim: lastCell.isAttribute(Attribute.dim), + inverse: + lastCell.isAttribute(Attribute.inverse) || lastCell.isCursor(), + fg: convertColorToHex(lastCell.fg, lastCell.fgColorMode, defaultFg), + bg: convertColorToHex(lastCell.bg, lastCell.bgColorMode, defaultBg), + }; + currentLine.push(token); + } + currentText = ''; + } + currentText += cell.getChars(); + lastCell = cell; + } + + if (currentText) { + const token: AnsiToken = { + text: currentText, + bold: lastCell.isAttribute(Attribute.bold), + italic: lastCell.isAttribute(Attribute.italic), + underline: lastCell.isAttribute(Attribute.underline), + dim: lastCell.isAttribute(Attribute.dim), + inverse: lastCell.isAttribute(Attribute.inverse) || lastCell.isCursor(), + fg: convertColorToHex(lastCell.fg, lastCell.fgColorMode, defaultFg), + bg: convertColorToHex(lastCell.bg, lastCell.bgColorMode, defaultBg), + }; + currentLine.push(token); + } + + result.push(currentLine); + } + + return result; +} + +// ANSI color palette from https://en.wikipedia.org/wiki/ANSI_escape_code#8-bit +const ANSI_COLORS = [ + '#000000', + '#800000', + '#008000', + '#808000', + '#000080', + '#800080', + '#008080', + '#c0c0c0', + '#808080', + '#ff0000', + '#00ff00', + '#ffff00', + '#0000ff', + '#ff00ff', + '#00ffff', 
+ '#ffffff', + '#000000', + '#00005f', + '#000087', + '#0000af', + '#0000d7', + '#0000ff', + '#005f00', + '#005f5f', + '#005f87', + '#005faf', + '#005fd7', + '#005fff', + '#008700', + '#00875f', + '#008787', + '#0087af', + '#0087d7', + '#0087ff', + '#00af00', + '#00af5f', + '#00af87', + '#00afaf', + '#00afd7', + '#00afff', + '#00d700', + '#00d75f', + '#00d787', + '#00d7af', + '#00d7d7', + '#00d7ff', + '#00ff00', + '#00ff5f', + '#00ff87', + '#00ffaf', + '#00ffd7', + '#00ffff', + '#5f0000', + '#5f005f', + '#5f0087', + '#5f00af', + '#5f00d7', + '#5f00ff', + '#5f5f00', + '#5f5f5f', + '#5f5f87', + '#5f5faf', + '#5f5fd7', + '#5f5fff', + '#5f8700', + '#5f875f', + '#5f8787', + '#5f87af', + '#5f87d7', + '#5f87ff', + '#5faf00', + '#5faf5f', + '#5faf87', + '#5fafaf', + '#5fafd7', + '#5fafff', + '#5fd700', + '#5fd75f', + '#5fd787', + '#5fd7af', + '#5fd7d7', + '#5fd7ff', + '#5fff00', + '#5fff5f', + '#5fff87', + '#5fffaf', + '#5fffd7', + '#5fffff', + '#870000', + '#87005f', + '#870087', + '#8700af', + '#8700d7', + '#8700ff', + '#875f00', + '#875f5f', + '#875f87', + '#875faf', + '#875fd7', + '#875fff', + '#878700', + '#87875f', + '#878787', + '#8787af', + '#8787d7', + '#8787ff', + '#87af00', + '#87af5f', + '#87af87', + '#87afaf', + '#87afd7', + '#87afff', + '#87d700', + '#87d75f', + '#87d787', + '#87d7af', + '#87d7d7', + '#87d7ff', + '#87ff00', + '#87ff5f', + '#87ff87', + '#87ffaf', + '#87ffd7', + '#87ffff', + '#af0000', + '#af005f', + '#af0087', + '#af00af', + '#af00d7', + '#af00ff', + '#af5f00', + '#af5f5f', + '#af5f87', + '#af5faf', + '#af5fd7', + '#af5fff', + '#af8700', + '#af875f', + '#af8787', + '#af87af', + '#af87d7', + '#af87ff', + '#afaf00', + '#afaf5f', + '#afaf87', + '#afafaf', + '#afafd7', + '#afafff', + '#afd700', + '#afd75f', + '#afd787', + '#afd7af', + '#afd7d7', + '#afd7ff', + '#afff00', + '#afff5f', + '#afff87', + '#afffaf', + '#afffd7', + '#afffff', + '#d70000', + '#d7005f', + '#d70087', + '#d700af', + '#d700d7', + '#d700ff', + '#d75f00', + '#d75f5f', + 
'#d75f87', + '#d75faf', + '#d75fd7', + '#d75fff', + '#d78700', + '#d7875f', + '#d78787', + '#d787af', + '#d787d7', + '#d787ff', + '#d7af00', + '#d7af5f', + '#d7af87', + '#d7afaf', + '#d7afd7', + '#d7afff', + '#d7d700', + '#d7d75f', + '#d7d787', + '#d7d7af', + '#d7d7d7', + '#d7d7ff', + '#d7ff00', + '#d7ff5f', + '#d7ff87', + '#d7ffaf', + '#d7ffd7', + '#d7ffff', + '#ff0000', + '#ff005f', + '#ff0087', + '#ff00af', + '#ff00d7', + '#ff00ff', + '#ff5f00', + '#ff5f5f', + '#ff5f87', + '#ff5faf', + '#ff5fd7', + '#ff5fff', + '#ff8700', + '#ff875f', + '#ff8787', + '#ff87af', + '#ff87d7', + '#ff87ff', + '#ffaf00', + '#ffaf5f', + '#ffaf87', + '#ffafaf', + '#ffafd7', + '#ffafff', + '#ffd700', + '#ffd75f', + '#ffd787', + '#ffd7af', + '#ffd7d7', + '#ffd7ff', + '#ffff00', + '#ffff5f', + '#ffff87', + '#ffffaf', + '#ffffd7', + '#ffffff', + '#080808', + '#121212', + '#1c1c1c', + '#262626', + '#303030', + '#3a3a3a', + '#444444', + '#4e4e4e', + '#585858', + '#626262', + '#6c6c6c', + '#767676', + '#808080', + '#8a8a8a', + '#949494', + '#9e9e9e', + '#a8a8a8', + '#b2b2b2', + '#bcbcbc', + '#c6c6c6', + '#d0d0d0', + '#dadada', + '#e4e4e4', + '#eeeeee', +]; + +export function convertColorToHex( + color: number, + colorMode: ColorMode, + defaultColor: string, +): string { + if (colorMode === ColorMode.RGB) { + const r = (color >> 16) & 255; + const g = (color >> 8) & 255; + const b = color & 255; + return `#${r.toString(16).padStart(2, '0')}${g + .toString(16) + .padStart(2, '0')}${b.toString(16).padStart(2, '0')}`; + } + if (colorMode === ColorMode.PALETTE) { + return ANSI_COLORS[color] || defaultColor; + } + return defaultColor; +} diff --git a/packages/test-utils/package.json b/packages/test-utils/package.json index 27d68c6155b..43ac6bd9342 100644 --- a/packages/test-utils/package.json +++ b/packages/test-utils/package.json @@ -1,6 +1,6 @@ { "name": "@blocksuser/gemini-cli-test-utils", - "version": "0.3.3", + "version": "0.7.0-nightly.20250912.68035591", "private": true, "main": 
"src/index.ts", "license": "Apache-2.0", diff --git a/packages/vscode-ide-companion/package.json b/packages/vscode-ide-companion/package.json index b41ec7bc233..beda73df0df 100644 --- a/packages/vscode-ide-companion/package.json +++ b/packages/vscode-ide-companion/package.json @@ -2,7 +2,7 @@ "name": "gemini-cli-vscode-ide-companion", "displayName": "Gemini CLI Companion", "description": "Enable Gemini CLI with direct access to your IDE workspace.", - "version": "0.3.3", + "version": "0.7.0-nightly.20250912.68035591", "publisher": "google", "icon": "assets/icon.png", "repository": { diff --git a/packages/vscode-ide-companion/src/extension.test.ts b/packages/vscode-ide-companion/src/extension.test.ts index 8377c01e97d..e3193ad8b7f 100644 --- a/packages/vscode-ide-companion/src/extension.test.ts +++ b/packages/vscode-ide-companion/src/extension.test.ts @@ -7,6 +7,15 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; import * as vscode from 'vscode'; import { activate } from './extension.js'; +import { DetectedIde, detectIdeFromEnv } from '@blocksuser/gemini-cli-core'; + +vi.mock('@blocksuser/gemini-cli-core', async () => { + const actual = await vi.importActual('@blocksuser/gemini-cli-core'); + return { + ...actual, + detectIdeFromEnv: vi.fn(() => DetectedIde.VSCode), + }; +}); vi.mock('vscode', () => ({ window: { @@ -187,6 +196,23 @@ describe('activate', () => { expect(showInformationMessageMock).not.toHaveBeenCalled(); }); + it.each([ + { + ide: DetectedIde.CloudShell, + }, + { ide: DetectedIde.FirebaseStudio }, + ])('does not show the notification for $ide', async ({ ide }) => { + vi.mocked(detectIdeFromEnv).mockReturnValue(ide); + vi.mocked(context.globalState.get).mockReturnValue(undefined); + const showInformationMessageMock = vi.mocked( + vscode.window.showInformationMessage, + ); + + await activate(context); + + expect(showInformationMessageMock).not.toHaveBeenCalled(); + }); + it('should not show an update notification if the 
version is older', async () => { vi.spyOn(global, 'fetch').mockResolvedValue({ ok: true, diff --git a/packages/vscode-ide-companion/src/extension.ts b/packages/vscode-ide-companion/src/extension.ts index 66020cb94f3..afec0e61fe1 100644 --- a/packages/vscode-ide-companion/src/extension.ts +++ b/packages/vscode-ide-companion/src/extension.ts @@ -9,11 +9,22 @@ import { IDEServer } from './ide-server.js'; import semver from 'semver'; import { DiffContentProvider, DiffManager } from './diff-manager.js'; import { createLogger } from './utils/logger.js'; +import { detectIdeFromEnv, DetectedIde } from '@blocksuser/gemini-cli-core'; const CLI_IDE_COMPANION_IDENTIFIER = 'Google.gemini-cli-vscode-ide-companion'; const INFO_MESSAGE_SHOWN_KEY = 'geminiCliInfoMessageShown'; export const DIFF_SCHEME = 'gemini-diff'; +/** + * IDE environments where the installation greeting is hidden. In these + * environments we either are pre-installed and the installation message is + * confusing or we just want to be quiet. 
+ */ +const HIDE_INSTALLATION_GREETING_IDES: ReadonlySet = new Set([ + DetectedIde.FirebaseStudio, + DetectedIde.CloudShell, +]); + let ideServer: IDEServer; let logger: vscode.OutputChannel; @@ -133,7 +144,10 @@ export async function activate(context: vscode.ExtensionContext) { log(`Failed to start IDE server: ${message}`); } - if (!context.globalState.get(INFO_MESSAGE_SHOWN_KEY)) { + const infoMessageEnabled = + !HIDE_INSTALLATION_GREETING_IDES.has(detectIdeFromEnv()); + + if (!context.globalState.get(INFO_MESSAGE_SHOWN_KEY) && infoMessageEnabled) { void vscode.window.showInformationMessage( 'Gemini CLI Companion extension successfully installed.', ); diff --git a/packages/vscode-ide-companion/src/ide-server.ts b/packages/vscode-ide-companion/src/ide-server.ts index 1588a4a1e29..86ef28a8412 100644 --- a/packages/vscode-ide-companion/src/ide-server.ts +++ b/packages/vscode-ide-companion/src/ide-server.ts @@ -107,7 +107,7 @@ export class IDEServer { const sessionsWithInitialNotification = new Set(); const app = express(); - app.use(express.json()); + app.use(express.json({ limit: '10mb' })); const mcpServer = createMcpServer(this.diffManager); this.openFilesManager = new OpenFilesManager(context); @@ -245,6 +245,7 @@ export class IDEServer { `gemini-ide-server-${process.ppid}.json`, ); this.log(`IDE server listening on port ${this.port}`); + await writePortAndWorkspace( context, this.port, diff --git a/scripts/create-patch-pr.js b/scripts/create-patch-pr.js new file mode 100644 index 00000000000..7804c2a9f0a --- /dev/null +++ b/scripts/create-patch-pr.js @@ -0,0 +1,133 @@ +#!/usr/bin/env node + +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { execSync } from 'node:child_process'; +import yargs from 'yargs'; +import { hideBin } from 'yargs/helpers'; + +async function main() { + const argv = await yargs(hideBin(process.argv)) + .option('commit', { + alias: 'c', + description: 'The commit SHA to cherry-pick for 
the patch.', + type: 'string', + demandOption: true, + }) + .option('channel', { + alias: 'ch', + description: 'The release channel to patch.', + choices: ['stable', 'preview'], + demandOption: true, + }) + .option('dry-run', { + description: 'Whether to run in dry-run mode.', + type: 'boolean', + default: false, + }) + .help() + .alias('help', 'h').argv; + + const { commit, channel, dryRun } = argv; + + console.log(`Starting patch process for commit: ${commit}`); + console.log(`Targeting channel: ${channel}`); + if (dryRun) { + console.log('Running in dry-run mode.'); + } + + run('git fetch --all --tags --prune', dryRun); + + const latestTag = getLatestTag(channel); + console.log(`Found latest tag for ${channel}: ${latestTag}`); + + const releaseBranch = `release/${latestTag}`; + const hotfixBranch = `hotfix/${latestTag}/cherry-pick-${commit.substring(0, 7)}`; + + // Create the release branch from the tag if it doesn't exist. + if (!branchExists(releaseBranch)) { + console.log( + `Release branch ${releaseBranch} does not exist. Creating it from tag ${latestTag}...`, + ); + run(`git checkout -b ${releaseBranch} ${latestTag}`, dryRun); + run(`git push origin ${releaseBranch}`, dryRun); + } else { + console.log(`Release branch ${releaseBranch} already exists.`); + } + + // Create the hotfix branch from the release branch. + console.log( + `Creating hotfix branch ${hotfixBranch} from ${releaseBranch}...`, + ); + run(`git checkout -b ${hotfixBranch} origin/${releaseBranch}`, dryRun); + + // Cherry-pick the commit. + console.log(`Cherry-picking commit ${commit} into ${hotfixBranch}...`); + run(`git cherry-pick ${commit}`, dryRun); + + // Push the hotfix branch. + console.log(`Pushing hotfix branch ${hotfixBranch} to origin...`); + run(`git push --set-upstream origin ${hotfixBranch}`, dryRun); + + // Create the pull request. 
+ console.log( + `Creating pull request from ${hotfixBranch} to ${releaseBranch}...`, + ); + const prTitle = `fix(patch): cherry-pick ${commit.substring(0, 7)} to ${releaseBranch}`; + let prBody = `This PR automatically cherry-picks commit ${commit} to patch the ${channel} release.`; + if (dryRun) { + prBody += '\n\n**[DRY RUN]**'; + } + run( + `gh pr create --base ${releaseBranch} --head ${hotfixBranch} --title "${prTitle}" --body "${prBody}"`, + dryRun, + ); + + console.log('Patch process completed successfully!'); +} + +function run(command, dryRun = false) { + console.log(`> ${command}`); + if (dryRun) { + return; + } + try { + return execSync(command).toString().trim(); + } catch (err) { + console.error(`Command failed: ${command}`); + throw err; + } +} + +function branchExists(branchName) { + try { + execSync(`git ls-remote --exit-code --heads origin ${branchName}`); + return true; + } catch (_e) { + return false; + } +} + +function getLatestTag(channel) { + console.log(`Fetching latest tag for channel: ${channel}...`); + const pattern = + channel === 'stable' + ? '\'(contains("nightly") or contains("preview")) | not\'' + : '\'(contains("preview"))\''; + const command = `gh release list --limit 30 --json tagName | jq -r '[.[] | select(.tagName | ${pattern})] | .[0].tagName'`; + try { + return execSync(command).toString().trim(); + } catch (err) { + console.error(`Failed to get latest tag for channel: ${channel}`); + throw err; + } +} + +main().catch((err) => { + console.error(err); + process.exit(1); +}); diff --git a/scripts/get-release-version.js b/scripts/get-release-version.js index 9d27410d458..fd9a7608e82 100644 --- a/scripts/get-release-version.js +++ b/scripts/get-release-version.js @@ -1,3 +1,5 @@ +#!/usr/bin/env node + /** * @license * Copyright 2025 Google LLC @@ -5,129 +7,82 @@ */ import { execSync } from 'node:child_process'; - -function getLatestStableTag() { - // Fetches all tags, then filters for the latest stable (non-prerelease) tag. 
- const tags = execSync('git tag --list "v*.*.*" --sort=-v:refname') - .toString() - .split('\n'); - const latestStableTag = tags.find((tag) => - tag.match(/^v[0-9]+\.[0-9]+\.[0-9]+$/), - ); - if (!latestStableTag) { - throw new Error('Could not find a stable tag.'); - } - return latestStableTag; -} - -function getShortSha() { - return execSync('git rev-parse --short HEAD').toString().trim(); +import { fileURLToPath } from 'node:url'; +import { readFileSync } from 'node:fs'; +import path from 'node:path'; + +const __filename = fileURLToPath(import.meta.url); +const __dirname = path.dirname(__filename); + +function getArgs() { + const args = {}; + process.argv.slice(2).forEach((arg) => { + if (arg.startsWith('--')) { + const [key, value] = arg.substring(2).split('='); + args[key] = value === undefined ? true : value; + } + }); + return args; } -function getNextVersionString(stableVersion, minorIncrement) { - const [major, minor] = stableVersion.substring(1).split('.'); - const nextMinorVersion = parseInt(minor, 10) + minorIncrement; - return `${major}.${nextMinorVersion}.0`; -} - -export function getNightlyTagName(stableVersion) { - const version = getNextVersionString(stableVersion, 2); - - const now = new Date(); - const year = now.getUTCFullYear().toString(); - const month = (now.getUTCMonth() + 1).toString().padStart(2, '0'); - const day = now.getUTCDate().toString().padStart(2, '0'); - const date = `${year}${month}${day}`; - - const sha = getShortSha(); - return `v${version}-nightly.${date}.${sha}`; -} - -export function getPreviewTagName(stableVersion) { - const version = getNextVersionString(stableVersion, 1); - return `v${version}-preview`; -} - -function getPreviousReleaseTag(isNightly) { - if (isNightly) { - console.error('Finding latest nightly release...'); - return execSync( - `gh release list --limit 100 --json tagName | jq -r '[.[] | select(.tagName | contains("nightly"))] | .[0].tagName'`, - ) - .toString() - .trim(); - } else { - 
console.error('Finding latest STABLE release (excluding pre-releases)...'); - return execSync( - `gh release list --limit 100 --json tagName | jq -r '[.[] | select(.tagName | (contains("nightly") or contains("preview")) | not)] | .[0].tagName'`, - ) - .toString() - .trim(); +function getLatestTag(pattern) { + const command = `gh release list --limit 100 --json tagName | jq -r '[.[] | select(.tagName | ${pattern})] | .[0].tagName'`; + try { + return execSync(command).toString().trim(); + } catch { + // Suppress error output for cleaner test failures + return ''; } } -export function getReleaseVersion() { - const isNightly = process.env.IS_NIGHTLY === 'true'; - const isPreview = process.env.IS_PREVIEW === 'true'; - const manualVersion = process.env.MANUAL_VERSION; - - let releaseTag; - - if (isNightly) { - console.error('Calculating next nightly version...'); - const stableVersion = getLatestStableTag(); - releaseTag = getNightlyTagName(stableVersion); - } else if (isPreview) { - console.error('Calculating next preview version...'); - const stableVersion = getLatestStableTag(); - releaseTag = getPreviewTagName(stableVersion); - } else if (manualVersion) { - console.error(`Using manual version: ${manualVersion}`); - releaseTag = manualVersion; - } else { - throw new Error( - 'Error: No version specified and this is not a nightly or preview release.', - ); - } - - if (!releaseTag) { - throw new Error('Error: Version could not be determined.'); - } +export function getVersion(options = {}) { + const args = getArgs(); + const type = options.type || args.type || 'nightly'; - if (!releaseTag.startsWith('v')) { - console.error("Version is missing 'v' prefix. Prepending it."); - releaseTag = `v${releaseTag}`; - } + let releaseVersion; + let npmTag; + let previousReleaseTag; - if (releaseTag.includes('+')) { - throw new Error( - 'Error: Versions with build metadata (+) are not supported for releases. 
Please use a pre-release version (e.g., v1.2.3-alpha.4) instead.', + if (type === 'nightly') { + const packageJson = JSON.parse( + readFileSync(path.join(__dirname, '..', 'package.json'), 'utf-8'), ); - } - - if (!releaseTag.match(/^v[0-9]+\.[0-9]+\.[0-9]+(-[a-zA-Z0-9.-]+)?$/)) { - throw new Error( - 'Error: Version must be in the format vX.Y.Z or vX.Y.Z-prerelease', + const [major, minor] = packageJson.version.split('.'); + const nextMinor = parseInt(minor) + 1; + const date = new Date().toISOString().slice(0, 10).replace(/-/g, ''); + const gitShortHash = execSync('git rev-parse --short HEAD') + .toString() + .trim(); + releaseVersion = `${major}.${nextMinor}.0-nightly.${date}.${gitShortHash}`; + npmTag = 'nightly'; + previousReleaseTag = getLatestTag('contains("nightly")'); + } else if (type === 'stable') { + const latestPreviewTag = getLatestTag('contains("preview")'); + releaseVersion = latestPreviewTag + .replace(/-preview.*/, '') + .replace(/^v/, ''); + npmTag = 'latest'; + previousReleaseTag = getLatestTag( + '(contains("nightly") or contains("preview")) | not', ); + } else if (type === 'preview') { + const latestNightlyTag = getLatestTag('contains("nightly")'); + releaseVersion = + latestNightlyTag.replace(/-nightly.*/, '').replace(/^v/, '') + '-preview'; + npmTag = 'preview'; + previousReleaseTag = getLatestTag('contains("preview")'); } - const releaseVersion = releaseTag.substring(1); - let npmTag = 'latest'; - if (releaseVersion.includes('-')) { - npmTag = releaseVersion.split('-')[1].split('.')[0]; - } - - const previousReleaseTag = getPreviousReleaseTag(isNightly); + const releaseTag = `v${releaseVersion}`; - return { releaseTag, releaseVersion, npmTag, previousReleaseTag }; + return { + releaseTag, + releaseVersion, + npmTag, + previousReleaseTag, + }; } -if (process.argv[1] === new URL(import.meta.url).pathname) { - try { - const versions = getReleaseVersion(); - console.log(JSON.stringify(versions)); - } catch (error) { - 
console.error(error.message); - process.exit(1); - } +if (process.argv[1] === fileURLToPath(import.meta.url)) { + console.log(JSON.stringify(getVersion(), null, 2)); } diff --git a/scripts/tests/get-release-version.test.js b/scripts/tests/get-release-version.test.js index c91844fb51d..d21c062f32f 100644 --- a/scripts/tests/get-release-version.test.js +++ b/scripts/tests/get-release-version.test.js @@ -4,151 +4,82 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest'; -import { getReleaseVersion } from '../get-release-version'; - -// Mock child_process so we can spy on execSync -vi.mock('child_process', () => ({ - execSync: vi.fn(), -})); - -describe('getReleaseVersion', async () => { - // Dynamically import execSync after mocking - const { execSync } = await import('node:child_process'); - const originalEnv = { ...process.env }; +import { vi, describe, it, expect, beforeEach } from 'vitest'; +import { getVersion } from '../get-release-version.js'; +import { execSync } from 'node:child_process'; + +vi.mock('node:child_process'); +vi.mock('node:fs'); + +vi.mock('../get-release-version.js', async () => { + const actual = await vi.importActual('../get-release-version.js'); + return { + ...actual, + getVersion: (options) => { + if (options.type === 'nightly') { + return { + releaseTag: 'v0.6.0-nightly.20250911.a1b2c3d', + releaseVersion: '0.6.0-nightly.20250911.a1b2c3d', + npmTag: 'nightly', + previousReleaseTag: 'v0.5.0-nightly.20250910.abcdef', + }; + } + return actual.getVersion(options); + }, + }; +}); +describe('getReleaseVersion', () => { beforeEach(() => { vi.resetAllMocks(); - process.env = { ...originalEnv }; // Mock date to be consistent - vi.setSystemTime(new Date('2025-08-20T00:00:00.000Z')); - // Provide a default mock for execSync to avoid toString() on undefined - vi.mocked(execSync).mockReturnValue(''); - }); - - afterEach(() => { - process.env = originalEnv; - vi.useRealTimers(); - }); - - 
it('should generate a nightly version and get previous tag', () => { - process.env.IS_NIGHTLY = 'true'; - - vi.mocked(execSync).mockImplementation((command) => { - if (command.includes('git tag')) { - return 'v0.1.0\nv0.0.1'; - } - if (command.includes('git rev-parse')) { - return 'abcdef'; - } - if (command.includes('gh release list')) { - return 'v0.3.0-nightly.20250819.abcdef'; - } - return ''; - }); - - const result = getReleaseVersion(); - - expect(result).toEqual({ - releaseTag: 'v0.3.0-nightly.20250820.abcdef', - releaseVersion: '0.3.0-nightly.20250820.abcdef', - npmTag: 'nightly', - previousReleaseTag: 'v0.3.0-nightly.20250819.abcdef', - }); + vi.setSystemTime(new Date('2025-09-11T00:00:00.000Z')); }); - it('should generate a preview version and get previous tag', () => { - process.env.IS_PREVIEW = 'true'; - - vi.mocked(execSync).mockImplementation((command) => { - if (command.includes('git tag')) { - return 'v0.1.0\nv0.0.1'; - } - if (command.includes('gh release list')) { - return 'v0.1.0'; // Previous stable release - } - return ''; - }); - - const result = getReleaseVersion(); + describe('Nightly Workflow Logic', () => { + it('should calculate the next nightly version based on package.json', async () => { + const { getVersion } = await import('../get-release-version.js'); + const result = getVersion({ type: 'nightly' }); - expect(result).toEqual({ - releaseTag: 'v0.2.0-preview', - releaseVersion: '0.2.0-preview', - npmTag: 'preview', - previousReleaseTag: 'v0.1.0', + expect(result.releaseVersion).toBe('0.6.0-nightly.20250911.a1b2c3d'); + expect(result.npmTag).toBe('nightly'); + expect(result.previousReleaseTag).toBe('v0.5.0-nightly.20250910.abcdef'); }); }); - it('should use the manual version and get previous tag', () => { - process.env.MANUAL_VERSION = 'v0.1.1'; + describe('Promote Workflow Logic', () => { + it('should calculate stable version from the latest preview tag', () => { + const latestPreview = 'v0.5.0-preview'; + const latestStable = 
'v0.4.0'; - vi.mocked(execSync).mockImplementation((command) => { - if (command.includes('gh release list')) { - return 'v0.1.0'; // Previous stable release - } - return ''; - }); + vi.mocked(execSync).mockImplementation((command) => { + if (command.includes('not')) return latestStable; + if (command.includes('contains("preview")')) return latestPreview; + return ''; + }); - const result = getReleaseVersion(); + const result = getVersion({ type: 'stable' }); - expect(result).toEqual({ - releaseTag: 'v0.1.1', - releaseVersion: '0.1.1', - npmTag: 'latest', - previousReleaseTag: 'v0.1.0', + expect(result.releaseVersion).toBe('0.5.0'); + expect(result.npmTag).toBe('latest'); + expect(result.previousReleaseTag).toBe(latestStable); }); - }); - - it('should prepend v to manual version if missing', () => { - process.env.MANUAL_VERSION = '1.2.3'; - const { releaseTag } = getReleaseVersion(); - expect(releaseTag).toBe('v1.2.3'); - }); - - it('should handle pre-release versions correctly', () => { - process.env.MANUAL_VERSION = 'v1.2.3-beta.1'; - const { releaseTag, releaseVersion, npmTag } = getReleaseVersion(); - expect(releaseTag).toBe('v1.2.3-beta.1'); - expect(releaseVersion).toBe('1.2.3-beta.1'); - expect(npmTag).toBe('beta'); - }); - it('should throw an error for invalid version format', () => { - process.env.MANUAL_VERSION = '1.2'; - expect(() => getReleaseVersion()).toThrow( - 'Error: Version must be in the format vX.Y.Z or vX.Y.Z-prerelease', - ); - }); + it('should calculate preview version from the latest nightly tag', () => { + const latestNightly = 'v0.6.0-nightly.20250910.abcdef'; + const latestPreview = 'v0.5.0-preview'; - it('should throw an error if no version is provided for non-nightly/preview release', () => { - expect(() => getReleaseVersion()).toThrow( - 'Error: No version specified and this is not a nightly or preview release.', - ); - }); - - it('should throw an error for versions with build metadata', () => { - process.env.MANUAL_VERSION = 
'v1.2.3+build456'; - expect(() => getReleaseVersion()).toThrow( - 'Error: Versions with build metadata (+) are not supported for releases.', - ); - }); + vi.mocked(execSync).mockImplementation((command) => { + if (command.includes('nightly')) return latestNightly; + if (command.includes('preview')) return latestPreview; + return ''; + }); - it('should correctly calculate the next version from a patch release', () => { - process.env.IS_PREVIEW = 'true'; + const result = getVersion({ type: 'preview' }); - vi.mocked(execSync).mockImplementation((command) => { - if (command.includes('git tag')) { - return 'v1.1.3\nv1.1.2\nv1.1.1\nv1.1.0\nv1.0.0'; - } - if (command.includes('gh release list')) { - return 'v1.1.3'; - } - return ''; + expect(result.releaseVersion).toBe('0.6.0-preview'); + expect(result.npmTag).toBe('preview'); + expect(result.previousReleaseTag).toBe(latestPreview); }); - - const result = getReleaseVersion(); - - expect(result.releaseTag).toBe('v1.2.0-preview'); }); }); diff --git a/third_party/get-ripgrep/LICENSE b/third_party/get-ripgrep/LICENSE new file mode 100644 index 00000000000..77b44f1dfa1 --- /dev/null +++ b/third_party/get-ripgrep/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Lvce Editor + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/third_party/get-ripgrep/package.json b/third_party/get-ripgrep/package.json new file mode 100644 index 00000000000..80e01ae7ac8 --- /dev/null +++ b/third_party/get-ripgrep/package.json @@ -0,0 +1,46 @@ +{ + "name": "@lvce-editor/ripgrep", + "version": "0.0.0-dev", + "description": "A module for using ripgrep in a Node project", + "main": "src/index.js", + "typings": "src/index.d.ts", + "type": "module", + "repository": { + "type": "git", + "url": "https://github.com/lvce-editor/ripgrep" + }, + "scripts": { + "postinstall": "node ./src/postinstall.js", + "test": "node --experimental-vm-modules node_modules/jest/bin/jest.js", + "test:watch": "node --experimental-vm-modules node_modules/jest/bin/jest.js --watch", + "format": "prettier --write ." 
+ }, + "keywords": [ + "lvce-editor", + "ripgrep" + ], + "author": "Lvce Editor", + "license": "MIT", + "dependencies": { + "@lvce-editor/verror": "^1.6.0", + "execa": "^9.5.2", + "extract-zip": "^2.0.1", + "fs-extra": "^11.3.0", + "got": "^14.4.5", + "path-exists": "^5.0.0", + "tempy": "^3.1.0", + "xdg-basedir": "^5.1.0" + }, + "devDependencies": { + "@types/fs-extra": "^11.0.4", + "@types/jest": "^29.5.14", + "@types/node": "^22.13.0", + "jest": "^29.7.0", + "prettier": "^3.4.2", + "typescript": "^5.7.3" + }, + "prettier": { + "semi": false, + "singleQuote": true + } +} diff --git a/third_party/get-ripgrep/src/downloadRipGrep.js b/third_party/get-ripgrep/src/downloadRipGrep.js new file mode 100644 index 00000000000..906b2a9e2e5 --- /dev/null +++ b/third_party/get-ripgrep/src/downloadRipGrep.js @@ -0,0 +1,123 @@ +/* eslint-disable */ +/** + * @license + * Copyright 2023 Lvce Editor + * SPDX-License-Identifier: MIT + */ +import { VError } from '@lvce-editor/verror' +import { execa } from 'execa' +import extractZip from 'extract-zip' +import fsExtra from 'fs-extra' +import got from 'got' +import * as os from 'node:os' +import { dirname, join } from 'node:path' +import { pathExists } from 'path-exists' +import { pipeline } from 'node:stream/promises' +import { temporaryFile } from 'tempy' +import { fileURLToPath } from 'node:url' +import { xdgCache } from 'xdg-basedir' + +const { mkdir, createWriteStream, move } = fsExtra + +const __dirname = dirname(fileURLToPath(import.meta.url)) + +const REPOSITORY = `microsoft/ripgrep-prebuilt` +const VERSION = process.env.RIPGREP_VERSION || 'v13.0.0-10' +console.log({ VERSION }) +const BIN_PATH = join(__dirname, '../bin') + +const getTarget = () => { + const arch = process.env.npm_config_arch || os.arch() + const platform = process.env.platform || os.platform() + switch (platform) { + case 'darwin': + switch (arch) { + case 'arm64': + return 'aarch64-apple-darwin.tar.gz' + default: + return 'x86_64-apple-darwin.tar.gz' + } + 
case 'win32': + switch (arch) { + case 'x64': + return 'x86_64-pc-windows-msvc.zip' + case 'arm': + return 'aarch64-pc-windows-msvc.zip' + default: + return 'i686-pc-windows-msvc.zip' + } + case 'linux': + switch (arch) { + case 'x64': + return 'x86_64-unknown-linux-musl.tar.gz' + case 'arm': + case 'armv7l': + return 'arm-unknown-linux-gnueabihf.tar.gz' + case 'arm64': + return 'aarch64-unknown-linux-gnu.tar.gz' + case 'ppc64': + return 'powerpc64le-unknown-linux-gnu.tar.gz' + case 's390x': + return 's390x-unknown-linux-gnu.tar.gz' + default: + return 'i686-unknown-linux-musl.tar.gz' + } + default: + throw new VError('Unknown platform: ' + platform) + } +} + +export const downloadFile = async (url, outFile) => { + try { + const tmpFile = temporaryFile() + await pipeline(got.stream(url), createWriteStream(tmpFile)) + await mkdir(dirname(outFile), { recursive: true }) + await move(tmpFile, outFile) + } catch (error) { + throw new VError(error, `Failed to download "${url}"`) + } +} + +/** + * @param {string} inFile + * @param {string} outDir + */ +const unzip = async (inFile, outDir) => { + try { + await mkdir(outDir, { recursive: true }) + await extractZip(inFile, { dir: outDir }) + } catch (error) { + throw new VError(error, `Failed to unzip "${inFile}"`) + } +} + +/** + * @param {string} inFile + * @param {string} outDir + */ +const untarGz = async (inFile, outDir) => { + try { + await mkdir(outDir, { recursive: true }) + await execa('tar', ['xvf', inFile, '-C', outDir]) + } catch (error) { + throw new VError(error, `Failed to extract "${inFile}"`) + } +} + +export const downloadRipGrep = async () => { + const target = getTarget() + const url = `https://github.com/${REPOSITORY}/releases/download/${VERSION}/ripgrep-${VERSION}-${target}` + const downloadPath = `${xdgCache}/vscode-ripgrep/ripgrep-${VERSION}-${target}` + if (!(await pathExists(downloadPath))) { + await downloadFile(url, downloadPath) + } else { + console.info(`File ${downloadPath} has been cached`) + 
} + if (downloadPath.endsWith('.tar.gz')) { + await untarGz(downloadPath, BIN_PATH) + } else if (downloadPath.endsWith('.zip')) { + await unzip(downloadPath, BIN_PATH) + } else { + throw new VError(`Invalid downloadPath ${downloadPath}`) + } +} diff --git a/third_party/get-ripgrep/src/index.js b/third_party/get-ripgrep/src/index.js new file mode 100644 index 00000000000..8fc965e98fe --- /dev/null +++ b/third_party/get-ripgrep/src/index.js @@ -0,0 +1,17 @@ +/* eslint-disable */ +/** + * @license + * Copyright 2023 Lvce Editor + * SPDX-License-Identifier: MIT + */ +import { dirname, join } from 'node:path' +import { fileURLToPath } from 'node:url' + +const __dirname = dirname(fileURLToPath(import.meta.url)) + +export const rgPath = join( + __dirname, + '..', + 'bin', + `rg${process.platform === 'win32' ? '.exe' : ''}`, +)