Merge branch 'main' into win-legacy-compat

This commit is contained in:
kevin-mai 2026-05-28 19:09:06 +08:00 committed by GitHub
commit 70d099fa0a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
137 changed files with 11672 additions and 990 deletions

View file

@ -0,0 +1,160 @@
---
name: gpui-test
description: >-
Use when writing, debugging, or reproducing GPUI tests in Zed, including
gpui::test arguments, TestAppContext parameters, scheduler seeds,
ITERATIONS/SEED reproduction, parking failures, and pending task traces.
---
# GPUI Test Debugging
Use this skill when the user asks about `#[gpui::test]`, GPUI test seeds or iterations, deterministic scheduler failures, parking/pending task failures, or how to reproduce a flaky GPUI test.
## What `#[gpui::test]` does
`#[gpui::test]` expands to a normal Rust `#[test]`, so it runs under standard Rust test runners such as `cargo test` and `cargo nextest`.
It wraps the body in GPUI's deterministic test dispatcher/scheduler and can run the same test multiple times with different seeds. The seed controls scheduler task interleavings and any `StdRng` argument injected into the test.
The macro supports both synchronous and asynchronous tests.
### Supported function arguments
The macro recognizes arguments by type name:
| Test kind | Supported arguments |
| --- | --- |
| Sync and async | `&TestAppContext`, `&mut TestAppContext`, `StdRng` |
| Async only | `BackgroundExecutor` |
| Sync only | `&App`, `&mut App` |
`StdRng` is seeded from the current GPUI test seed, and `BackgroundExecutor` is backed by the same deterministic test dispatcher.
### Attribute arguments
Use these forms on `#[gpui::test(arguments)]`:
- No arguments: runs once with seed `0`, unless `SEED` is set.
- `seed = N`: adds a single explicit seed.
- `seeds(...)`: adds multiple explicit seeds.
- `iterations = N`: runs sequential seeds starting at `0` by default.
- `retries = N`: retries a failing run up to `N` times before surfacing the failure.
- `on_failure = "path::to::function"`: calls the function after final failure, before resuming the panic.
- `iterations` can be combined with explicit `seed` / `seeds`; explicit seeds are appended to the `0..iterations` range.
- If the `SEED` environment variable is set, it takes precedence over explicit seeds.
- With `SEED=N` and `ITERATIONS=M` or `iterations = M`, the harness runs seeds `N..N+M`.
## Environment variables
### GPUI test macro / scheduler execution
- `SEED=<u64>` — chooses the scheduler seed. Use this to reproduce a failure printed as `failing seed: N`. It also seeds injected `StdRng` arguments. For `#[gpui::property_test]`, it controls the scheduler seed and GPUI applies it to the proptest config for deterministic case generation.
- `ITERATIONS=<usize>` — overrides the `iterations = ...` value at runtime. Use to sweep many seeds without editing the test.
- `PENDING_TRACES=1` or `PENDING_TRACES=true` — captures and prints pending task traces when the test scheduler panics with `Parking forbidden`. Use this when `run_until_parked()` or teardown reports pending work.
- `GPUI_RUN_UNTIL_PARKED_LOG=1` — logs when `allow_parking()` is enabled. Use to find tests that explicitly permit parking/pending work.
- `DEBUG_SCHEDULER=1` — prints scheduler clock/timer debugging from `scheduler::TestScheduler`.
### Lower-level scheduler tests
- `SCHEDULER_NONINTERACTIVE=1` — suppresses interactive seed progress output in `scheduler::TestScheduler::many`. This does not affect the `#[gpui::test]` harness path.
### General Rust test debugging vars often useful with GPUI tests
- `RUST_BACKTRACE=1` or `RUST_BACKTRACE=full` — show panic backtraces.
- `RUST_LOG=<filter>` — enable logs when the test initializes logging.
- `ZED_HEADLESS=1` — forces GPUI platform guessing toward headless mode; useful for tests that otherwise interact with platform/window setup.
Prefer env vars over editing the test when narrowing a reproduction.
## Reproducing a specific GPUI test
1. Identify the crate/package and test name.
2. Run the narrowest test filter first, skip to 3. if a failing seed is known.
```sh
cargo -q test -p <crate-name> <test_name> -- --nocapture
```
3. If the failure mentions a seed, rerun exactly that seed.
```sh
SEED=<seed> cargo -q test -p <crate-name> <test_name> -- --nocapture
```
4. If the failure is flaky and no seed is known, sweep seeds.
```sh
ITERATIONS=100 cargo -q test -p <crate-name> <test_name> -- --nocapture
```
When the harness prints `failing seed: <seed>`, switch to `SEED=<seed>` for all future debugging.
5. If the failure is `Parking forbidden`, rerun with pending traces.
```sh
PENDING_TRACES=1 cargo -q test -p <crate-name> <test_name> -- --nocapture
```
If a failing seed was printed or is already known, include it too:
```sh
SEED=<seed> PENDING_TRACES=1 cargo -q test -p <crate-name> <test_name> -- --nocapture
```
Inspect the pending traces for a task that was spawned but not awaited, detached, completed, or intentionally allowed to park.
6. If timing or timer advancement is involved, prefer GPUI scheduler timers in tests:
```rust
cx.background_executor().timer(duration).await;
```
Avoid `smol::Timer::after(...)` in GPUI tests that rely on `run_until_parked()`, because GPUI's scheduler may not track it.
7. Minimize the reproduction.
- Keep the failing `SEED` fixed.
- Reduce `ITERATIONS` to `1` or remove it once a seed is known.
- Remove unrelated setup only after confirming the same seed still fails.
- Preserve scheduler-sensitive awaits/yields; removing them can mask the bug.
- If randomness is test-controlled via `StdRng`, log or assert the generated scenario after fixing the scheduler seed.
8. Validate the fix.
- Run the fixed seed.
- Run a modest seed sweep, e.g. `ITERATIONS=20`, if the failure was scheduler-sensitive.
- Run the relevant crate's test filter or broader suite if the touched code has shared behavior.
## Common diagnosis patterns
### Seed-dependent assertion failure
Likely caused by a scheduler interleaving or by `StdRng`-driven test data. Fix `SEED`, reproduce, and inspect which task or generated scenario differs.
### `Parking forbidden`
Usually means a foreground/background task is still pending when the scheduler expected the test to make progress or finish. Look for:
- A task that should be awaited but was dropped.
- A task that should be detached with error logging.
- A timer or receiver that is waiting forever.
- A missing `cx.run_until_parked()` after triggering async work in a test.
- A missing `cx.advance_clock(...)` to wait for debounced work in a test.
- Use of non-GPUI timers or executors that the test scheduler cannot drive.
Rerun with `PENDING_TRACES=1` before changing code.
### Non-determinism / wrong thread
The scheduler can report activity from an unexpected thread. Look for work escaping GPUI's foreground/background executors, direct thread spawns, or external async runtimes not controlled by the test dispatcher.
### Tests pass alone but fail in sweeps
Use the failing seed from sweep output. Avoid assuming test order unless the runner is explicitly serial. Check globals, leaked entities/tasks, and state not reset by test initialization.
## Writing GPUI tests
- Prefer `#[gpui::test]` for tests that need `TestAppContext`, deterministic executors, fake time, or scheduler interleaving coverage.
- Add `iterations = N` when the test is intentionally checking interleavings.
- Use `StdRng` as a test argument when randomized test data should follow the same seed as the scheduler.
- Use `cx.background_executor().timer(duration).await` for delays/timeouts in GPUI tests.
- Do not add or increase `retries` while fixing a test unless the user explicitly asks or the test already documents why probabilistic tolerance is intentional. Retries can mask the failure instead of fixing it.

View file

@ -0,0 +1,175 @@
---
name: zed-cherry-pick
description: Cherry-pick one or more merged PRs and/or commits into Zed's `preview` or `stable` release branch. Use this whenever the user mentions cherry-picking to preview/stable, a failed cherry-pick run, or wants to manually port fix(es) into a release branch.
---
# Zed Cherry-Pick
Zed ships from two long-lived release branches that live on `origin`:
- `preview` channel → branch like `v1.4.x`
- `stable` channel → branch like `v1.3.x`
The version numbers change with each release. **Never hardcode them — always discover the current mapping** (see [Finding the target branch](#finding-the-target-branch)).
A merged PR on `main` gets ported to a release branch by `script/cherry-pick`, normally driven by the `cherry_pick` GitHub Actions workflow. When that workflow fails (almost always a merge conflict), use this skill to finish the job locally and open the cherry-pick PR by hand.
## When to use
Use this when the user asks to cherry-pick one or more commits and/or Pull Requests (by number or URL) to `preview` or `stable`.
Optionally, the user may specify whether to resolve merge conflicts; if unspecified, attempt the cherry-pick, and then if there are merge conflicts in practice, stop and inform the user that there are merge conflicts and offer to resolve them. (Users may prefer to resolve the merge conflicts themselves before continuing.)
## The script you're emulating
The canonical procedure lives in `script/cherry-pick` and the `cherry_pick` GitHub Actions workflow. Read the script first if anything looks off — your local steps must produce the same branch name, PR title, and PR body it would.
Signature: `script/cherry-pick <branch-name> <commit-sha> <channel>`
- `<branch-name>` is the release branch (e.g. `v1.4.x`), **not** the channel name.
- `<channel>` is `preview` or `stable`, used only for display text in the PR title/body.
It creates a local branch named `cherry-pick-<branch-name>-<short-sha>` (the short SHA is the first 8 chars of the commit), force-pushes it to `origin`, and opens a PR.
## Finding the target branch
The channel→branch mapping changes every release. Find the current one by inspecting the most recent `cherry_pick` workflow runs:
```
gh run list --workflow=cherry_pick.yml --limit 30 --json displayTitle,databaseId
# pick a recent run for the channel you want, then:
gh run view <id> --log 2>&1 | grep -E "BRANCH:|CHANNEL:"
```
A successful run prints both `BRANCH:` and `CHANNEL:` env vars; that's your mapping.
## Procedure
### 1. Gather context
You need three things: the **merge commit SHA**, the **target branch**, and the **channel name**.
If the user requested multiple PRs and/or commits, gather the metadata for all of them first and cherry-pick them in the order they landed on `main`, oldest to newest. For PRs, order by `mergedAt`; for raw commits, use their order on `main` when available, otherwise commit date. This tends to reduce avoidable conflicts because later changes may depend on earlier ones, but it does not guarantee a conflict-free cherry-pick when the release branch has diverged.
```
gh pr view <PR_NUMBER> --json title,number,mergeCommit,mergedAt,url
```
If the user said the workflow failed, fetch its log to see exactly which command failed and which file conflicted:
```
gh run list --workflow=cherry_pick.yml --limit 10 --json databaseId,displayTitle,status,conclusion
gh run view <failed_run_id> --log-failed
```
The failed-run log also confirms the `BRANCH` and `COMMIT` the workflow used — handy if there's any ambiguity.
### 2. Reproduce the script's setup locally
The repository may be a worktree (check `.git` — if it's a file, you're in a worktree pointing at a shared gitdir). That's fine; just operate normally.
```
git --no-pager fetch origin <branch-name> <commit-sha>
git checkout --force origin/<branch-name> -B cherry-pick-<branch-name>-<short-sha>
git cherry-pick <commit-sha>
```
The branch name **must** match `cherry-pick-<branch-name>-<short-sha>` exactly (script convention; reviewers and tooling expect it).
### 3. Check for missing prerequisite cherry-picks
If the cherry-pick conflicts, do not immediately resolve the conflicts manually.
First determine whether the conflict is likely caused by other PRs or commits that are already on `main` but missing from the release branch. If so, point out those candidate prerequisite PRs/commits to the user, including PR links, and offer to either resolve the conflicts manually or let the user run the GitHub cherry-pick workflow for those commits first.
If the user wants to run the workflow for the missing prerequisites, stop here. This often keeps cherry-picks clean and eligible for automatic approval.
Only resolve conflicts manually if:
- no likely missing prerequisites are found, or
- the user chooses manual conflict resolution instead of cherry-picking the prerequisites first.
### 4. Resolve the conflicts manually
Do this only after checking for missing prerequisite cherry-picks.
- Inspect every conflicted file with `grep -n '<<<<<<<\\|>>>>>>>\\|=======' <path>` to find the markers.
- Conflicts are usually `diff3` style with three sections: HEAD (release branch), `||||||| parent of <sha>` (merge base on `main`), and the incoming change.
- Read the **original commit** (`git --no-pager show <commit-sha> -- <path>`) to understand the author's intent, then pick the resolution that produces the equivalent end state on the release branch.
- Don't grab unrelated changes from `main` that happen to surround the conflict — keep the cherry-pick minimal.
### 5. Validate
Always build and (if reasonable) test the affected crate(s) before continuing the cherry-pick.
```
cargo check -p <affected_crate>
cargo test -p <affected_crate>
```
If validation fails, fix the resolution — do **not** continue with a broken build. If you can't reach a clean state, abort with `git cherry-pick --abort` and report back to the user.
### 6. Finish the cherry-pick
`git cherry-pick --continue` opens an editor by default. Prevent that:
```
git add <resolved_files>
GIT_EDITOR=true git cherry-pick --continue
```
This preserves the original commit message verbatim, which is what the script does.
### 7. Push and open the PR
```
git push origin -f cherry-pick-<branch-name>-<short-sha>
```
Then create the PR with the **exact** title and body format `script/cherry-pick` uses, so it's indistinguishable from an automated one.
**Title:**
```
<original commit subject> (cherry-pick to <channel>)
```
The original commit subject already ends in ` (#<original_pr_number>)`; keep it.
**Body** (when the original commit title ends in `(#<N>)`, which is the normal case):
```
Cherry-pick of #<original_pr_number> to <channel>
----
<original commit body, verbatim>
```
Create it with `gh pr create`, writing the body to a temp file to keep formatting intact:
```
git --no-pager log -1 --pretty=format:"%b" > /tmp/cp-body-tail.md
printf 'Cherry-pick of #%s to %s\n\n----\n' <PR_NUMBER> <channel> | cat - /tmp/cp-body-tail.md > /tmp/cp-body.md
gh pr create --base <branch-name> --head cherry-pick-<branch-name>-<short-sha> \\
--title "<commit subject> (cherry-pick to <channel>)" \\
--body-file /tmp/cp-body.md
```
Do **not** add a `Release Notes:` section — the original commit body already has one (or already says `N/A`), and you don't want it duplicated.
## Final report to the user
Tell the user:
- The new PR URL.
- A one-line summary of the conflict and how you resolved it.
- What validation you ran (commands + result).
- That their local branch is now `cherry-pick-<branch-name>-<short-sha>`, in case they want you to switch back.
## Gotchas
- **`--no-pager` and `GIT_EDITOR=true`**: required for non-interactive git in this environment. Forgetting `GIT_EDITOR=true` on `cherry-pick --continue` hangs the terminal.
- **Worktree index lock**: if a previous git command was interrupted, you may see `index.lock` errors. The lock lives at `<gitdir>/index.lock` where `<gitdir>` is what `cat .git` points to (for a worktree). Remove it only if you're sure no git process is running.
- **Don't expand the cherry-pick's scope**: when resolving conflicts, never pull in unrelated changes from `main` just because they sit next to the conflict region. The PR should be the smallest diff that reproduces the original commit's intent on the release branch.
- **Channel branches are not called `preview`/`stable`**: don't try to `git fetch origin preview`. Look up the actual `vX.Y.x` branch name first.
## When Finished
After everything is finished, the last thing to do is to provide a link to the opened pull request(s) for the cherry-pick(s).

3
.gitignore vendored
View file

@ -55,3 +55,6 @@ crates/docs_preprocessor/actions.json
# Local documentation audit files
/december-2025-releases.md
/docs/december-2025-documentation-gaps.md
# NixOS integration test state
.nixos-test-history

1199
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -130,6 +130,7 @@ members = [
"crates/lsp",
"crates/markdown",
"crates/markdown_preview",
"crates/mermaid_render",
"crates/media",
"crates/menu",
"crates/migrator",
@ -172,6 +173,7 @@ members = [
"crates/rope",
"crates/rpc",
"crates/rules_library",
"crates/sandbox",
"crates/skill_creator",
"crates/scheduler",
"crates/schema_generator",
@ -389,10 +391,10 @@ lmstudio = { path = "crates/lmstudio" }
lsp = { path = "crates/lsp" }
markdown = { path = "crates/markdown" }
markdown_preview = { path = "crates/markdown_preview" }
mermaid_render = { path = "crates/mermaid_render" }
svg_preview = { path = "crates/svg_preview" }
media = { path = "crates/media" }
menu = { path = "crates/menu" }
mermaid-rs-renderer = { git = "https://github.com/zed-industries/mermaid-rs-renderer", rev = "782b89a7da3f0e91e51f98d00a93acba679be6fb", default-features = false }
migrator = { path = "crates/migrator" }
mistral = { path = "crates/mistral" }
multi_buffer = { path = "crates/multi_buffer" }
@ -435,6 +437,7 @@ rpc = { path = "crates/rpc" }
rules_library = { path = "crates/rules_library" }
skill_creator = { path = "crates/skill_creator" }
scheduler = { path = "crates/scheduler" }
sandbox = { path = "crates/sandbox" }
search = { path = "crates/search" }
session = { path = "crates/session" }
sidebar = { path = "crates/sidebar" }
@ -502,6 +505,10 @@ ztracing_macro = { path = "crates/ztracing_macro" }
# External crates
#
accesskit = "0.24.0"
accesskit_macos = "0.26.0"
accesskit_unix = "0.21.0"
accesskit_windows = "0.32.1"
agent-client-protocol = { version = "=0.12.1", features = ["unstable"] }
aho-corasick = "1.1"
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty", rev = "9d9640d4" }
@ -592,7 +599,7 @@ futures = "0.3.32"
futures-concurrency = "7.7.1"
futures-lite = "1.13"
gh-workflow = { git = "https://github.com/zed-industries/gh-workflow", rev = "37f3c0575d379c218a9c455ee67585184e40d43f" }
git2 = { version = "0.20.1", default-features = false, features = ["vendored-libgit2"] }
git2 = { version = "0.21.0", default-features = false, features = ["vendored-libgit2", "unstable-sha256"] }
globset = "0.4"
heapless = "0.9.2"
handlebars = "4.3"

View file

@ -13,7 +13,7 @@ On macOS, Linux, and Windows you can [download Zed directly](https://zed.dev/dow
Other platforms are not yet available:
- Web ([tracking issue](https://github.com/zed-industries/zed/issues/5396))
- Web ([tracking discussion](https://github.com/zed-industries/zed/discussions/26195))
### Developing Zed

View file

@ -1563,6 +1563,7 @@
"context": "SkillCreator",
"bindings": {
"ctrl-w": "workspace::CloseWindow",
"ctrl-enter": "skill_creator::SaveSkill",
"tab": "skill_creator::FocusNextField",
"shift-tab": "skill_creator::FocusPreviousField",
},
@ -1571,6 +1572,7 @@
"context": "SkillCreator > Editor",
"bindings": {
"ctrl-w": "workspace::CloseWindow",
"ctrl-enter": "skill_creator::SaveSkill",
"tab": "skill_creator::FocusNextField",
"shift-tab": "skill_creator::FocusPreviousField",
},

View file

@ -1657,6 +1657,7 @@
"use_key_equivalents": true,
"bindings": {
"cmd-w": "workspace::CloseWindow",
"cmd-enter": "skill_creator::SaveSkill",
"tab": "skill_creator::FocusNextField",
"shift-tab": "skill_creator::FocusPreviousField",
},
@ -1666,6 +1667,7 @@
"use_key_equivalents": true,
"bindings": {
"cmd-w": "workspace::CloseWindow",
"cmd-enter": "skill_creator::SaveSkill",
"tab": "skill_creator::FocusNextField",
"shift-tab": "skill_creator::FocusPreviousField",
},

View file

@ -1583,6 +1583,7 @@
"use_key_equivalents": true,
"bindings": {
"ctrl-w": "workspace::CloseWindow",
"ctrl-enter": "skill_creator::SaveSkill",
"tab": "skill_creator::FocusNextField",
"shift-tab": "skill_creator::FocusPreviousField",
},
@ -1592,6 +1593,7 @@
"use_key_equivalents": true,
"bindings": {
"ctrl-w": "workspace::CloseWindow",
"ctrl-enter": "skill_creator::SaveSkill",
"tab": "skill_creator::FocusNextField",
"shift-tab": "skill_creator::FocusPreviousField",
},

View file

@ -338,7 +338,7 @@
"ctrl-x": "vim::Decrement",
"shift-j": "vim::JoinLines",
"i": "vim::InsertBefore",
"a": "vim::InsertAfter",
"a": "vim::HelixAppend",
"o": "vim::InsertLineBelow",
"shift-o": "vim::InsertLineAbove",
"p": "vim::Paste",

View file

@ -39,6 +39,7 @@ image.workspace = true
portable-pty.workspace = true
project.workspace = true
prompt_store.workspace = true
sandbox.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true

View file

@ -2955,6 +2955,7 @@ impl AcpThread {
extra_env: Vec<acp::EnvVariable>,
cwd: Option<PathBuf>,
output_byte_limit: Option<u64>,
sandbox_wrap: Option<SandboxWrap>,
cx: &mut Context<Self>,
) -> Task<Result<Entity<Terminal>>> {
let env = match &cwd {
@ -2995,6 +2996,8 @@ impl AcpThread {
ShellBuilder::new(&Shell::Program(shell), is_windows)
.redirect_stdin_to_dev_null()
.build(Some(command.clone()), &args);
let (task_command, task_args, sandbox_config) =
apply_sandbox_wrap(task_command, task_args, sandbox_wrap)?;
let terminal = project
.update(cx, |project, cx| {
project.create_terminal_task(
@ -3018,6 +3021,7 @@ impl AcpThread {
output_byte_limit.map(|l| l as usize),
terminal,
language_registry,
sandbox_config,
cx,
)
}))
@ -3097,6 +3101,9 @@ impl AcpThread {
output_byte_limit.map(|l| l as usize),
terminal,
language_registry,
// External terminal providers manage their own sandboxing
// (if any). We don't wrap their commands.
None,
cx,
)
});

View file

@ -17,6 +17,82 @@ use std::{
use task::Shell;
use util::get_default_system_shell_preferring_bash;
/// Request to run a terminal command inside an OS-level sandbox.
///
/// Passed to [`super::AcpThread::create_terminal`]. The actual sandboxing
/// mechanism is platform-specific (today: macOS Seatbelt; nothing on other
/// platforms — the wrap is silently a no-op there), so callers describe the
/// *intent* with plain data here rather than constructing platform-specific
/// types directly.
///
/// All-zero defaults are the fully-sandboxed run. Setting `allow_network` /
/// `allow_fs_write` requests a relaxation; the caller is responsible for
/// having obtained user approval before reaching this point.
#[derive(Clone, Debug, Default)]
pub struct SandboxWrap {
/// Directory subtrees the sandbox should allow writes to. Pass the
/// project's worktree paths (and any per-command scratch directory)
/// here — *not* the command's working directory, which is model-
/// controlled and would let the model widen its own writable scope.
pub writable_paths: Vec<PathBuf>,
/// Allow outbound network access for this command.
pub allow_network: bool,
/// Allow unrestricted filesystem writes (ignores `writable_paths`).
pub allow_fs_write: bool,
}
/// Opaque RAII handle the sandbox implementation hands back to keep its
/// per-command resources (e.g. an on-disk Seatbelt config file) alive for
/// the duration of the spawned command. `Terminal` holds it in a field
/// whose only job is to drop with the entity.
pub type SandboxConfigHandle = Box<dyn std::any::Any + Send>;
/// Apply a [`SandboxWrap`] to a `(program, args)` pair, substituting the
/// platform's sandbox-launcher invocation in place of the original. The
/// returned `SandboxConfigHandle` (when `Some`) must be kept alive for the
/// duration of the spawned command — dropping it deletes any on-disk
/// config the launcher reads at startup.
///
/// On non-macOS hosts this is a no-op: the inputs pass through unchanged
/// and the returned handle is `None`. (We don't yet have a sandbox
/// integration for other platforms.)
pub(crate) fn apply_sandbox_wrap(
program: String,
args: Vec<String>,
sandbox_wrap: Option<SandboxWrap>,
) -> anyhow::Result<(String, Vec<String>, Option<SandboxConfigHandle>)> {
let Some(sandbox_wrap) = sandbox_wrap else {
return Ok((program, args, None));
};
#[cfg(target_os = "macos")]
{
let writable: Vec<&std::path::Path> = sandbox_wrap
.writable_paths
.iter()
.map(|p| p.as_path())
.collect();
let permissions = sandbox::macos_seatbelt::SandboxPermissions {
allow_network: sandbox_wrap.allow_network,
allow_fs_write: sandbox_wrap.allow_fs_write,
};
let (new_program, new_args, config_file) =
sandbox::macos_seatbelt::wrap_invocation(&program, &args, &writable, permissions)?;
Ok((
new_program,
new_args,
Some(Box::new(config_file) as SandboxConfigHandle),
))
}
#[cfg(not(target_os = "macos"))]
{
// No sandbox integration available; ignore the wrap request and
// let the command run with the agent's ambient permissions.
let _ = sandbox_wrap;
Ok((program, args, None))
}
}
pub struct Terminal {
id: acp::TerminalId,
command: Entity<Markdown>,
@ -30,6 +106,10 @@ pub struct Terminal {
/// (e.g., clicking the Stop button). This is set before kill() is called
/// so that code awaiting wait_for_exit() can check it deterministically.
user_stopped: Arc<AtomicBool>,
/// RAII handle kept alive for the duration of the sandboxed command.
/// `None` when the command isn't sandboxed (the common case for
/// terminals not created by the agent).
_sandbox_config: Option<SandboxConfigHandle>,
}
pub struct TerminalOutput {
@ -48,11 +128,13 @@ impl Terminal {
output_byte_limit: Option<usize>,
terminal: Entity<terminal::Terminal>,
language_registry: Arc<LanguageRegistry>,
sandbox_config: Option<SandboxConfigHandle>,
cx: &mut Context<Self>,
) -> Self {
let command_task = terminal.read(cx).wait_for_completed_task(cx);
Self {
id,
_sandbox_config: sandbox_config,
command: cx.new(|cx| {
Markdown::new(
format!("```\n{}\n```", command_label).into(),

View file

@ -65,6 +65,7 @@ streaming_diff.workspace = true
strsim.workspace = true
task.workspace = true
telemetry.workspace = true
tempfile.workspace = true
text.workspace = true
thiserror.workspace = true
ui.workspace = true
@ -82,6 +83,7 @@ agent_servers = { workspace = true, "features" = ["test-support"] }
client = { workspace = true, "features" = ["test-support"] }
clock = { workspace = true, "features" = ["test-support"] }
context_server = { workspace = true, "features" = ["test-support"] }
criterion.workspace = true
ctor.workspace = true
db = { workspace = true, "features" = ["test-support"] }
editor = { workspace = true, "features" = ["test-support"] }
@ -99,10 +101,14 @@ project = { workspace = true, "features" = ["test-support"] }
rand.workspace = true
reqwest_client.workspace = true
settings = { workspace = true, "features" = ["test-support"] }
tempfile.workspace = true
theme = { workspace = true, "features" = ["test-support"] }
unindent = { workspace = true }
zlog.workspace = true
[[bench]]
name = "edit_file_tool"
harness = false
required-features = ["test-support"]

View file

@ -0,0 +1,519 @@
use std::{
future::Future,
path::Path,
sync::Arc,
task::{Context, Poll},
};
use action_log::ActionLog;
use agent::{
AgentTool, ContextServerRegistry, EditFileTool, EditFileToolInput, EditFileToolOutput,
Templates, Thread, ToolCallEventStream, ToolInput,
};
use agent_settings::{AgentSettings, ToolRules};
use criterion::{
BatchSize, BenchmarkId, Criterion, Throughput, black_box, criterion_group, criterion_main,
};
use futures::{pin_mut, task::noop_waker};
use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext, UpdateGlobal as _};
use language_model::fake_provider::FakeLanguageModel;
use project::{FakeFs, Project};
use prompt_store::ProjectContext;
use rand::{Rng as _, SeedableRng as _, rngs::StdRng};
use serde_json::{Value, json};
use settings::{Settings as _, SettingsStore};
const SEED: u64 = 0x5EED_5EED;
const OLD_TEXT_CHUNK_SIZE: usize = 512;
const NEW_TEXT_CHUNK_SIZE: usize = 512;
#[derive(Clone)]
struct EditFixture {
name: &'static str,
old_file_text: String,
expected_file_text: String,
old_text: String,
new_text: String,
}
struct BenchmarkHarness {
cx: Option<TestAppContext>,
edit_tool: Option<Arc<EditFileTool>>,
thread: Option<Entity<Thread>>,
partial_payloads: Vec<Value>,
final_payload: Value,
expected_file_text: String,
}
impl Drop for BenchmarkHarness {
fn drop(&mut self) {
// Release our handles to the entities first.
self.edit_tool.take();
self.thread.take();
if let Some(cx) = self.cx.take() {
// `ActionLog` holds buffers strongly via `tracked_buffers`, and spawns a background
// diff-maintenance task that also captures a strong `Entity<Buffer>`. Releasing the
// last handle to the action log only marks its entity for deferred release; the
// entity's value (and the buffer handles inside) is not actually dropped until
// `flush_effects` runs `release_dropped_entities`. Even then, the cancelled task's
// captured handle does not drop until the executor pumps the cancellation through.
//
// Without this two-step teardown, GPUI's test leak detector panics on
// `TestAppContext` drop because the buffer still appears alive. See
// `ActionLog::track_buffer_internal` and `LeakDetector::drop` in
// `crates/gpui/src/app/entity_map.rs`.
cx.update(|_| {});
cx.executor().run_until_parked();
cx.quit();
}
}
}
fn edit_file_tool_streaming(c: &mut Criterion) {
let fixtures = fixtures();
let mut group = c.benchmark_group("edit_file_tool_streaming");
group.sample_size(10);
for fixture in fixtures {
group.throughput(Throughput::Bytes(fixture.new_text.len() as u64));
group.bench_with_input(
BenchmarkId::new(fixture.name, fixture.old_text.len()),
&fixture,
|bench, fixture| {
bench.iter_batched(
|| setup_harness(fixture.clone()),
|mut harness| {
let output = run_streamed_edit(&mut harness);
let EditFileToolOutput::Success { new_text, .. } = &output else {
panic!("expected edit_file tool to succeed");
};
assert_eq!(new_text, &harness.expected_file_text);
// Return the harness as part of the output so its teardown (which has
// to pump the executor to release `Entity<Buffer>` handles captured by
// background tasks) runs in criterion's drop phase after the timer has
// stopped, rather than inside the timed region.
(black_box(output), harness)
},
BatchSize::SmallInput,
);
},
);
}
group.finish();
}
fn setup_harness(fixture: EditFixture) -> BenchmarkHarness {
let mut cx = init_context();
let executor = cx.executor();
let (edit_tool, thread) = block_on_executor(
&executor,
setup_edit_tool(&mut cx, fixture.old_file_text.clone()),
);
let partial_payloads = streamed_partial_payloads(&fixture.old_text, &fixture.new_text);
let final_payload = json!({
"path": "root/src/workspace_snapshot.rs",
"edits": [{
"old_text": fixture.old_text,
"new_text": fixture.new_text,
}],
});
BenchmarkHarness {
cx: Some(cx),
edit_tool: Some(edit_tool),
thread: Some(thread),
partial_payloads,
final_payload,
expected_file_text: fixture.expected_file_text,
}
}
fn init_context() -> TestAppContext {
let cx = TestAppContext::single();
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
SettingsStore::update_global(cx, |store: &mut SettingsStore, cx| {
store.update_user_settings(cx, |settings| {
settings
.project
.all_languages
.defaults
.ensure_final_newline_on_save = Some(false);
});
});
let mut agent_settings = AgentSettings::get_global(cx).clone();
agent_settings.tool_permissions.tools.insert(
EditFileTool::NAME.into(),
ToolRules {
default: Some(settings::ToolPermissionMode::Allow),
always_allow: vec![],
always_deny: vec![],
always_confirm: vec![],
invalid_patterns: vec![],
},
);
AgentSettings::override_global(agent_settings, cx);
});
cx
}
async fn setup_edit_tool(
cx: &mut TestAppContext,
file_text: String,
) -> (Arc<EditFileTool>, Entity<Thread>) {
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
json!({
"src": {
"workspace_snapshot.rs": file_text,
},
}),
)
.await;
let project = Project::test(fs, [Path::new("/root")], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
Thread::new(
project.clone(),
cx.new(|_cx| ProjectContext::default()),
context_server_registry,
Templates::new(),
Some(model),
cx,
)
});
let action_log: Entity<ActionLog> =
thread.read_with(cx, |thread, _cx| thread.action_log().clone());
let edit_tool = Arc::new(EditFileTool::new(
project,
thread.downgrade(),
action_log,
language_registry,
));
(edit_tool, thread)
}
fn run_streamed_edit(harness: &mut BenchmarkHarness) -> EditFileToolOutput {
let (mut sender, input): (_, ToolInput<EditFileToolInput>) = ToolInput::test();
for payload in &harness.partial_payloads {
sender.send_partial(payload.clone());
}
sender.send_full(harness.final_payload.clone());
let (event_stream, _event_rx) = ToolCallEventStream::test();
let cx = harness
.cx
.as_ref()
.expect("benchmark harness should have a cx");
let task = cx.update(|cx| {
harness
.edit_tool
.as_ref()
.expect("benchmark harness should have an edit tool")
.clone()
.run(input, event_stream, cx)
});
let executor = harness
.cx
.as_ref()
.expect("benchmark harness should have a cx")
.executor();
block_on_executor(&executor, task).unwrap()
}
fn block_on_executor<R>(executor: &BackgroundExecutor, future: impl Future<Output = R>) -> R {
pin_mut!(future);
let waker = noop_waker();
let mut task_context = Context::from_waker(&waker);
for _ in 0..10_000 {
if let Poll::Ready(output) = future.as_mut().poll(&mut task_context) {
return output;
}
executor.run_until_parked();
}
panic!("future did not complete while running edit_file_tool benchmark");
}
fn streamed_partial_payloads(old_text: &str, new_text: &str) -> Vec<Value> {
let path = "root/src/workspace_snapshot.rs";
let mut payloads = Vec::new();
payloads.push(json!({ "path": path }));
payloads.push(json!({ "path": path }));
for old_end in chunk_ends(old_text, OLD_TEXT_CHUNK_SIZE) {
payloads.push(json!({
"path": path,
"edits": [{ "old_text": &old_text[..old_end] }],
}));
}
payloads.push(json!({
"path": path,
"edits": [{ "old_text": old_text, "new_text": "" }],
}));
for new_end in chunk_ends(new_text, NEW_TEXT_CHUNK_SIZE) {
payloads.push(json!({
"path": path,
"edits": [{
"old_text": old_text,
"new_text": &new_text[..new_end],
}],
}));
}
payloads
}
fn chunk_ends(text: &str, chunk_size: usize) -> impl Iterator<Item = usize> + '_ {
let mut end = 0;
std::iter::from_fn(move || {
if end == text.len() {
return None;
}
end = (end + chunk_size).min(text.len());
while !text.is_char_boundary(end) {
end -= 1;
}
Some(end)
})
}
fn fixtures() -> Vec<EditFixture> {
vec![
make_fixture(
"tiny_function_rewrite",
2,
EditPattern::LocalizedRewrite {
start_line: 12,
line_count: 6,
},
SEED,
),
make_fixture(
"small_function_rewrite",
5,
EditPattern::LocalizedRewrite {
start_line: 22,
line_count: 12,
},
SEED + 1,
),
make_fixture(
"medium_many_small_changes",
8,
EditPattern::ManySmallChanges { every_nth_line: 7 },
SEED + 2,
),
make_fixture(
"medium_insertions",
8,
EditPattern::InsertHelperBlocks { every_nth_line: 9 },
SEED + 3,
),
]
}
enum EditPattern {
LocalizedRewrite {
start_line: usize,
line_count: usize,
},
ManySmallChanges {
every_nth_line: usize,
},
InsertHelperBlocks {
every_nth_line: usize,
},
}
fn make_fixture(
name: &'static str,
function_count: usize,
pattern: EditPattern,
seed: u64,
) -> EditFixture {
let mut rng = StdRng::seed_from_u64(seed);
let old_lines = random_rust_module(&mut rng, function_count);
let edit_range = edit_range(&old_lines, &pattern);
let old_text = old_lines[edit_range.clone()].join("\n");
let mut new_lines = old_lines.clone();
match pattern {
EditPattern::LocalizedRewrite { .. } => {
rewrite_local_block(&mut new_lines[edit_range.clone()], &mut rng)
}
EditPattern::ManySmallChanges { every_nth_line } => {
rewrite_many_small_lines(&mut new_lines[edit_range.clone()], every_nth_line, &mut rng)
}
EditPattern::InsertHelperBlocks { every_nth_line } => {
insert_helper_blocks(&mut new_lines, edit_range.clone(), every_nth_line, &mut rng)
}
}
let new_text_end = edit_range.end + new_lines.len().saturating_sub(old_lines.len());
let old_file_text = old_lines.join("\n");
let expected_file_text = new_lines.join("\n");
let new_text = new_lines[edit_range.start..new_text_end].join("\n");
EditFixture {
name,
old_file_text,
expected_file_text,
old_text,
new_text,
}
}
fn edit_range(lines: &[String], pattern: &EditPattern) -> std::ops::Range<usize> {
let mut range = match pattern {
EditPattern::LocalizedRewrite {
start_line,
line_count,
} => *start_line..(*start_line + *line_count).min(lines.len()),
EditPattern::ManySmallChanges { .. } | EditPattern::InsertHelperBlocks { .. } => {
10..lines.len().saturating_sub(5)
}
};
while range.end > range.start && lines[range.end - 1].is_empty() {
range.end -= 1;
}
range
}
fn random_rust_module(rng: &mut StdRng, function_count: usize) -> Vec<String> {
let mut lines = vec![
"use anyhow::{Context as _, Result};".to_string(),
"use collections::HashMap;".to_string(),
"".to_string(),
"#[derive(Clone, Debug)]".to_string(),
"pub struct WorkspaceSnapshot {".to_string(),
" buffers: HashMap<String, usize>,".to_string(),
" version: usize,".to_string(),
"}".to_string(),
"".to_string(),
"impl WorkspaceSnapshot {".to_string(),
];
for function_index in 0..function_count {
let function_name = identifier(rng, function_index);
let argument_name = identifier(rng, function_index + 1_000);
let local_name = identifier(rng, function_index + 2_000);
let branch_name = identifier(rng, function_index + 3_000);
let multiplier = rng.random_range(2..17);
let offset = rng.random_range(1..128);
lines.extend([
format!(
" pub fn {function_name}(&mut self, {argument_name}: usize) -> Result<usize> {{"
),
format!(" let mut {local_name} = {argument_name}.saturating_mul({multiplier});"),
format!(" if {local_name} % 2 == 0 {{"),
format!(
" {local_name} = {local_name}.saturating_add(self.version + {offset});"
),
" } else {".to_string(),
format!(" {local_name} = {local_name}.saturating_sub({offset});"),
" }".to_string(),
format!(" let {branch_name} = self.buffers.len().saturating_add({local_name});"),
format!(" self.version = self.version.saturating_add({branch_name});"),
format!(" Ok({branch_name})"),
" }".to_string(),
"".to_string(),
]);
}
lines.push("}".to_string());
lines.push("".to_string());
lines.push("pub fn normalize_path(path: &str) -> String {".to_string());
lines.push(" path.replace('\\\\', \"/\")".to_string());
lines.push("}".to_string());
lines
}
fn rewrite_local_block(lines: &mut [String], rng: &mut StdRng) {
for (line_index, line) in lines.iter_mut().enumerate() {
let suffix = identifier(rng, line_index + 10_000);
if line.contains("saturating_add") {
*line = format!(
" let {suffix} = self.version.checked_add({line_index}).context(\"version overflow\")?;"
);
} else if line.contains("saturating_sub") {
*line = format!(
" {suffix}.saturating_sub({});",
rng.random_range(8..256)
);
} else if line.trim().is_empty() {
*line =
format!(" tracing::trace!(target: \"agent_bench\", value = {line_index});");
} else {
*line = format!("{line} // updated {suffix}");
}
}
}
fn rewrite_many_small_lines(lines: &mut [String], every_nth_line: usize, rng: &mut StdRng) {
for (line_index, line) in lines.iter_mut().enumerate() {
if line_index.is_multiple_of(every_nth_line) || line.trim().is_empty() {
continue;
}
let suffix = identifier(rng, line_index + 20_000);
*line = format!("{line} // audited {suffix}");
}
}
fn insert_helper_blocks(
lines: &mut Vec<String>,
range: std::ops::Range<usize>,
every_nth_line: usize,
rng: &mut StdRng,
) {
let mut line_index = range.start;
while line_index < range.end.min(lines.len()) {
if line_index.is_multiple_of(every_nth_line) && !lines[line_index].trim().is_empty() {
let suffix = identifier(rng, line_index + 30_000);
lines.splice(
line_index..line_index,
[
format!(" let {suffix}_before = self.version;"),
format!(" tracing::debug!(version = {suffix}_before);"),
],
);
line_index += 2;
}
line_index += 1;
}
}
fn identifier(rng: &mut StdRng, salt: usize) -> String {
const PARTS: &[&str] = &[
"alpha", "beta", "gamma", "delta", "epsilon", "zeta", "theta", "lambda", "sigma", "omega",
];
format!(
"{}_{}_{}",
PARTS[rng.random_range(0..PARTS.len())],
salt,
rng.random_range(0..10_000)
)
}
criterion_group!(benches, edit_file_tool_streaming);
criterion_main!(benches);

View file

@ -3,6 +3,7 @@ mod legacy_thread;
mod native_agent_server;
pub mod outline;
mod pattern_extraction;
mod sandboxing;
mod templates;
#[cfg(test)]
mod tests;
@ -10,7 +11,6 @@ mod thread;
mod thread_store;
mod tool_permissions;
mod tools;
mod user_agents_md;
use context_server::ContextServerId;
pub use db::*;
@ -23,7 +23,6 @@ pub use thread::*;
pub use thread_store::*;
pub use tool_permissions::*;
pub use tools::*;
pub use user_agents_md::{UserAgentsMd, UserAgentsMdState, init as init_user_agents_md};
use acp_thread::{
AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
@ -51,10 +50,7 @@ use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageMo
use project::{
AgentId, Project, ProjectItem, ProjectPath, Worktree, trusted_worktrees::TrustedWorktrees,
};
use prompt_store::{
ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
WorktreeContext,
};
use prompt_store::{ProjectContext, RULES_FILE_NAMES, RulesFileContext, WorktreeContext};
use serde::{Deserialize, Serialize};
use settings::{LanguageModelSelection, Settings as _, update_settings_file};
use std::any::Any;
@ -308,7 +304,6 @@ pub struct NativeAgent {
templates: Arc<Templates>,
/// Cached model information
models: LanguageModels,
prompt_store: Option<Entity<PromptStore>>,
fs: Arc<dyn Fs>,
_subscriptions: Vec<Subscription>,
/// Tracks the lifecycle of global skills directory observation. We
@ -355,20 +350,16 @@ impl NativeAgent {
pub fn new(
thread_store: Entity<ThreadStore>,
templates: Arc<Templates>,
prompt_store: Option<Entity<PromptStore>>,
fs: Arc<dyn Fs>,
cx: &mut App,
) -> Entity<NativeAgent> {
log::debug!("Creating new NativeAgent");
cx.new(|cx| {
let mut subscriptions = vec![cx.subscribe(
let subscriptions = vec![cx.subscribe(
&LanguageModelRegistry::global(cx),
Self::handle_models_updated_event,
)];
if let Some(prompt_store) = prompt_store.as_ref() {
subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
}
if !cx.has_global::<SkillIndex>() {
cx.set_global(SkillIndex::default());
@ -381,7 +372,6 @@ impl NativeAgent {
projects: HashMap::default(),
templates,
models: LanguageModels::new(cx),
prompt_store,
fs,
_subscriptions: subscriptions,
skills_state: SkillsState::default(),
@ -640,7 +630,7 @@ impl NativeAgent {
return project_id;
}
let project_context = cx.new(|_| ProjectContext::new(vec![], vec![]));
let project_context = cx.new(|_| ProjectContext::new(vec![]));
self.register_project_with_initial_context(project.clone(), project_context, cx);
if let Some(state) = self.projects.get_mut(&project_id) {
state.project_context_needs_refresh.send(()).ok();
@ -733,7 +723,6 @@ impl NativeAgent {
.context("project state not found")?;
anyhow::Ok(Self::build_project_context(
&state.project,
this.prompt_store.as_ref(),
this.fs.clone(),
cx,
))
@ -805,7 +794,6 @@ impl NativeAgent {
fn build_project_context(
project: &Entity<Project>,
prompt_store: Option<&Entity<PromptStore>>,
fs: Arc<dyn Fs>,
cx: &mut App,
) -> Task<(ProjectContext, Vec<Skill>, Vec<SkillLoadError>)> {
@ -887,22 +875,8 @@ impl NativeAgent {
.collect();
cx.background_spawn(async move { future::join_all(project_skills_futures).await })
};
let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
prompt_store.read_with(cx, |prompt_store, cx| {
let prompts = prompt_store.default_prompt_metadata();
let load_tasks = prompts.into_iter().map(|prompt_metadata| {
let contents = prompt_store.load(prompt_metadata.id, cx);
async move { (contents.await, prompt_metadata) }
});
cx.background_spawn(future::join_all(load_tasks))
})
} else {
Task::ready(vec![])
};
cx.spawn(async move |_cx| {
let (worktrees, default_user_rules) =
future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
let worktrees = future::join_all(worktree_tasks).await;
let worktrees = worktrees
.into_iter()
@ -915,27 +889,6 @@ impl NativeAgent {
})
.collect::<Vec<_>>();
let default_user_rules = default_user_rules
.into_iter()
.flat_map(|(contents, prompt_metadata)| match contents {
Ok(contents) => Some(UserRulesContext {
uuid: prompt_metadata.id.as_user()?,
title: prompt_metadata.title.map(|title| title.to_string()),
contents,
}),
Err(_err) => {
// TODO: show error message
// this.update(cx, |_, cx| {
// cx.emit(RulesLoadingError {
// message: format!("{err:?}").into(),
// });
// })
// .ok();
None
}
})
.collect::<Vec<_>>();
// Load and combine skills. `combine_skills` deliberately
// does NOT deduplicate — the autocomplete popup needs to
// see every entry so users can disambiguate same-named
@ -959,8 +912,7 @@ impl NativeAgent {
let (catalog_skills, budget_errors) = select_catalog_skills(&overridden);
skill_errors.extend(budget_errors);
let project_context =
ProjectContext::new(worktrees, default_user_rules).with_skills(catalog_skills);
let project_context = ProjectContext::new(worktrees).with_skills(catalog_skills);
(project_context, skills, skill_errors)
})
}
@ -1121,17 +1073,6 @@ impl NativeAgent {
}
}
fn handle_prompts_updated_event(
&mut self,
_prompt_store: Entity<PromptStore>,
_event: &prompt_store::PromptsUpdatedEvent,
_cx: &mut Context<Self>,
) {
for state in self.projects.values_mut() {
state.project_context_needs_refresh.send(()).ok();
}
}
fn handle_models_updated_event(
&mut self,
_registry: Entity<LanguageModelRegistry>,
@ -2669,12 +2610,54 @@ impl ThreadEnvironment for NativeThreadEnvironment {
fn create_terminal(
&self,
command: String,
extra_env: Vec<acp::EnvVariable>,
cwd: Option<PathBuf>,
output_byte_limit: Option<u64>,
sandbox_wrap: Option<acp_thread::SandboxWrap>,
cx: &mut AsyncApp,
) -> Task<Result<Rc<dyn TerminalHandle>>> {
// Use a per-thread temp directory for all terminal commands, even when
// sandboxing is disabled, so the model can't infer sandbox state from
// `$TMPDIR` changing between conversations.
let mut extra_env = extra_env;
let mut sandbox_wrap = sandbox_wrap;
match self
.thread
.update(cx, |thread, cx| thread.sandboxed_terminal_temp_dir(cx))
{
Ok(Ok(temp_dir)) => {
// Canonicalize so the path matches what the sandbox resolves
// symlinks to (e.g. `/var` -> `/private/var` on macOS).
// `$TMPDIR` and the writable-scope entry below must agree, and
// they must agree with the path the kernel actually checks.
let temp_dir = temp_dir.canonicalize().unwrap_or(temp_dir);
let temp_dir_string = temp_dir.to_string_lossy().into_owned();
extra_env.extend([
acp::EnvVariable::new("TMPDIR", &temp_dir_string),
acp::EnvVariable::new("TMP", &temp_dir_string),
acp::EnvVariable::new("TEMP", &temp_dir_string),
]);
// The command's `$TMPDIR` must live inside the sandbox's
// writable scope. The per-thread temp directory is owned here
// (not in the terminal tool that assembles the rest of the
// writable set), so add it whenever the command is sandboxed.
if let Some(sandbox_wrap) = &mut sandbox_wrap {
sandbox_wrap.writable_paths.push(temp_dir);
}
}
Ok(Err(error)) => return Task::ready(Err(error)),
Err(error) => return Task::ready(Err(error)),
};
let task = self.acp_thread.update(cx, |thread, cx| {
thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
thread.create_terminal(
command,
vec![],
extra_env,
cwd,
output_byte_limit,
sandbox_wrap,
cx,
)
});
let acp_thread = self.acp_thread.clone();
@ -3501,7 +3484,7 @@ mod internal_tests {
let project = Project::test(fs.clone(), [], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
let agent =
cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), fs.clone(), cx));
// Creating a session registers the project and triggers context building.
let connection = NativeAgentConnection(agent.clone());
@ -3592,7 +3575,7 @@ mod internal_tests {
let project = Project::test(fs.clone(), [], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
let agent =
cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), fs.clone(), cx));
// Simulate the user-interaction trigger that the agent panel
// fires (input focus, slash autocomplete, or submit). In tests
@ -3655,7 +3638,7 @@ mod internal_tests {
let project = Project::test(fs.clone(), [], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
let agent =
cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), fs.clone(), cx));
// First scan trigger: nothing on disk yet, state stays idle.
cx.update(|cx| {
@ -3732,7 +3715,7 @@ mod internal_tests {
let project = Project::test(fs.clone(), [], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
let agent =
cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), fs.clone(), cx));
// First scan trigger: nothing on disk yet.
cx.update(|cx| {
@ -3878,7 +3861,7 @@ mod internal_tests {
let project = Project::test(fs.clone(), [], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
let agent =
cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), fs.clone(), cx));
// Open a parent session through the connection, the same way
// production does. This triggers project-context refresh which
@ -3991,7 +3974,7 @@ mod internal_tests {
let project = Project::test(fs.clone(), [], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
let agent =
cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), fs.clone(), cx));
let connection = NativeAgentConnection(agent.clone());
let acp_thread = cx
@ -4096,7 +4079,7 @@ mod internal_tests {
Project::test_with_worktree_trust(fs.clone(), [Path::new("/project")], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
let agent =
cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), fs.clone(), cx));
let connection = NativeAgentConnection(agent.clone());
let acp_thread = cx
@ -4177,10 +4160,9 @@ mod internal_tests {
fs.insert_tree("/", json!({ "a": {} })).await;
let project = Project::test(fs.clone(), [], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
let connection =
NativeAgentConnection(cx.update(|cx| {
NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx)
}));
let connection = NativeAgentConnection(
cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), fs.clone(), cx)),
);
// Create a thread/session
let acp_thread = cx
@ -4254,7 +4236,7 @@ mod internal_tests {
// Create the agent and connection
let agent =
cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), fs.clone(), cx));
let connection = NativeAgentConnection(agent.clone());
// Create a thread/session
@ -4351,7 +4333,7 @@ mod internal_tests {
let thread_store = cx.new(|cx| ThreadStore::new(cx));
let agent =
cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), fs.clone(), cx));
let connection = NativeAgentConnection(agent.clone());
let acp_thread = cx
@ -4442,7 +4424,7 @@ mod internal_tests {
let thread_store = cx.new(|cx| ThreadStore::new(cx));
let agent =
cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), fs.clone(), cx));
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = cx
@ -4493,9 +4475,8 @@ mod internal_tests {
fs.insert_tree("/", json!({ "a": {} })).await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
let agent = cx.update(|cx| {
NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
});
let agent = cx
.update(|cx| NativeAgent::new(thread_store.clone(), Templates::new(), fs.clone(), cx));
let connection = Rc::new(NativeAgentConnection(agent.clone()));
// Register a thinking model.
@ -4596,9 +4577,8 @@ mod internal_tests {
fs.insert_tree("/", json!({ "a": {} })).await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
let agent = cx.update(|cx| {
NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
});
let agent = cx
.update(|cx| NativeAgent::new(thread_store.clone(), Templates::new(), fs.clone(), cx));
let connection = Rc::new(NativeAgentConnection(agent.clone()));
// Register a model where id() != name(), like real Anthropic models
@ -4712,9 +4692,8 @@ mod internal_tests {
.await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
let agent = cx.update(|cx| {
NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
});
let agent = cx
.update(|cx| NativeAgent::new(thread_store.clone(), Templates::new(), fs.clone(), cx));
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = cx
@ -4894,9 +4873,8 @@ mod internal_tests {
.await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
let agent = cx.update(|cx| {
NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
});
let agent = cx
.update(|cx| NativeAgent::new(thread_store.clone(), Templates::new(), fs.clone(), cx));
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = cx
@ -4975,9 +4953,8 @@ mod internal_tests {
.await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
let agent = cx.update(|cx| {
NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
});
let agent = cx
.update(|cx| NativeAgent::new(thread_store.clone(), Templates::new(), fs.clone(), cx));
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = cx
@ -5059,9 +5036,8 @@ mod internal_tests {
.await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
let agent = cx.update(|cx| {
NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
});
let agent = cx
.update(|cx| NativeAgent::new(thread_store.clone(), Templates::new(), fs.clone(), cx));
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = cx
@ -5204,9 +5180,8 @@ mod internal_tests {
fs.insert_tree("/", json!({ "a": {} })).await;
let project = Project::test(fs.clone(), [], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
let agent = cx.update(|cx| {
NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
});
let agent = cx
.update(|cx| NativeAgent::new(thread_store.clone(), Templates::new(), fs.clone(), cx));
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = cx

View file

@ -16,7 +16,7 @@ use sqlez::{
connection::Connection,
statement::Statement,
};
use std::sync::Arc;
use std::{io::ErrorKind, path::PathBuf, sync::Arc};
use ui::{App, SharedString};
use util::path_list::PathList;
use zed_env_vars::ZED_STATELESS;
@ -81,6 +81,8 @@ pub struct DbThread {
pub draft_prompt: Option<Vec<acp::ContentBlock>>,
#[serde(default)]
pub ui_scroll_position: Option<SerializedScrollPosition>,
#[serde(default)]
pub sandboxed_terminal_temp_dir: Option<PathBuf>,
}
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
@ -130,6 +132,7 @@ impl SharedThread {
thinking_effort: None,
draft_prompt: None,
ui_scroll_position: None,
sandboxed_terminal_temp_dir: None,
}
}
@ -309,6 +312,7 @@ impl DbThread {
thinking_effort: None,
draft_prompt: None,
ui_scroll_position: None,
sandboxed_terminal_temp_dir: None,
})
}
}
@ -569,15 +573,7 @@ impl ThreadsDatabase {
let rows = select(id.0)?;
if let Some((data_type, data)) = rows.into_iter().next() {
let json_data = match data_type {
DataType::Zstd => {
let decompressed = zstd::decode_all(&data[..])?;
String::from_utf8(decompressed)?
}
DataType::Json => String::from_utf8(data)?,
};
let thread = DbThread::from_json(json_data.as_bytes())?;
Ok(Some(thread))
Ok(Some(Self::deserialize_thread(data_type, data)?))
} else {
Ok(None)
}
@ -596,17 +592,71 @@ impl ThreadsDatabase {
.spawn(async move { Self::save_thread_sync(&connection, id, thread, &folder_paths) })
}
fn deserialize_thread(data_type: DataType, data: Vec<u8>) -> Result<DbThread> {
let json_data = match data_type {
DataType::Zstd => {
let decompressed = zstd::decode_all(&data[..])?;
String::from_utf8(decompressed)?
}
DataType::Json => String::from_utf8(data)?,
};
DbThread::from_json(json_data.as_bytes())
}
fn sandboxed_terminal_temp_dir(data_type: DataType, data: Vec<u8>) -> Option<PathBuf> {
match Self::deserialize_thread(data_type, data) {
Ok(thread) => thread.sandboxed_terminal_temp_dir,
Err(error) => {
log::warn!("failed to deserialize thread before deleting it: {error:#}");
None
}
}
}
fn remove_sandboxed_terminal_temp_dir(temp_dir: PathBuf) {
match std::fs::remove_dir_all(&temp_dir) {
Ok(()) => {}
Err(error) if error.kind() == ErrorKind::NotFound => {}
Err(error) => {
log::warn!(
"failed to remove sandboxed terminal temp directory {}: {error}",
temp_dir.display()
);
}
}
}
pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
let connection = self.connection.clone();
self.executor.spawn(async move {
let connection = connection.lock();
let sandboxed_terminal_temp_dir = {
let connection = connection.lock();
let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
DELETE FROM threads WHERE id = ?
"})?;
let mut select =
connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
"})?;
delete(id.0)?;
let sandboxed_terminal_temp_dir = select(id.0.clone())?
.into_iter()
.next()
.and_then(|(data_type, data)| {
Self::sandboxed_terminal_temp_dir(data_type, data)
});
let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
DELETE FROM threads WHERE id = ?
"})?;
delete(id.0)?;
sandboxed_terminal_temp_dir
};
if let Some(temp_dir) = sandboxed_terminal_temp_dir {
Self::remove_sandboxed_terminal_temp_dir(temp_dir);
}
Ok(())
})
@ -616,13 +666,32 @@ impl ThreadsDatabase {
let connection = self.connection.clone();
self.executor.spawn(async move {
let connection = connection.lock();
let sandboxed_terminal_temp_dirs = {
let connection = connection.lock();
let mut delete = connection.exec_bound::<()>(indoc! {"
DELETE FROM threads
"})?;
let mut select = connection.select_bound::<(), (DataType, Vec<u8>)>(indoc! {"
SELECT data_type, data FROM threads
"})?;
delete(())?;
let sandboxed_terminal_temp_dirs = select(())?
.into_iter()
.filter_map(|(data_type, data)| {
Self::sandboxed_terminal_temp_dir(data_type, data)
})
.collect::<Vec<_>>();
let mut delete = connection.exec_bound::<()>(indoc! {"
DELETE FROM threads
"})?;
delete(())?;
sandboxed_terminal_temp_dirs
};
for temp_dir in sandboxed_terminal_temp_dirs {
Self::remove_sandboxed_terminal_temp_dir(temp_dir);
}
Ok(())
})
@ -694,6 +763,7 @@ mod tests {
thinking_effort: None,
draft_prompt: None,
ui_scroll_position: None,
sandboxed_terminal_temp_dir: None,
}
}
@ -797,6 +867,78 @@ mod tests {
);
}
#[test]
fn test_sandboxed_terminal_temp_dir_defaults_to_none() {
let json = r#"{
"title": "Old Thread",
"messages": [],
"updated_at": "2024-01-01T00:00:00Z"
}"#;
let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
assert!(
db_thread.sandboxed_terminal_temp_dir.is_none(),
"Legacy threads without sandboxed_terminal_temp_dir should default to None"
);
}
#[gpui::test]
async fn test_sandboxed_terminal_temp_dir_roundtrips_through_save_load(
cx: &mut TestAppContext,
) {
let database = ThreadsDatabase::new(cx.executor()).unwrap();
let thread_id = session_id("sandbox-temp-dir-thread");
let temp_dir = tempfile::Builder::new()
.prefix("zed-agent-terminal-test-")
.tempdir()
.unwrap()
.keep();
let mut thread = make_thread(
"Sandbox Temp Dir Thread",
Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
);
thread.sandboxed_terminal_temp_dir = Some(temp_dir.clone());
database
.save_thread(thread_id.clone(), thread, PathList::default())
.await
.unwrap();
let loaded = database
.load_thread(thread_id)
.await
.unwrap()
.expect("thread should exist");
assert_eq!(loaded.sandboxed_terminal_temp_dir, Some(temp_dir.clone()));
std::fs::remove_dir_all(temp_dir).unwrap();
}
#[gpui::test]
async fn test_delete_thread_removes_sandboxed_terminal_temp_dir(cx: &mut TestAppContext) {
let database = ThreadsDatabase::new(cx.executor()).unwrap();
let thread_id = session_id("sandbox-temp-dir-delete-thread");
let temp_dir = tempfile::Builder::new()
.prefix("zed-agent-terminal-test-")
.tempdir()
.unwrap()
.keep();
std::fs::write(temp_dir.join("sentinel"), b"content").unwrap();
let mut thread = make_thread(
"Sandbox Temp Dir Delete Thread",
Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
);
thread.sandboxed_terminal_temp_dir = Some(temp_dir.clone());
database
.save_thread(thread_id.clone(), thread, PathList::default())
.await
.unwrap();
database.delete_thread(thread_id).await.unwrap();
assert!(!temp_dir.exists());
}
#[gpui::test]
async fn test_subagent_context_roundtrips_through_save_load(cx: &mut TestAppContext) {
let database = ThreadsDatabase::new(cx.executor()).unwrap();

View file

@ -9,9 +9,7 @@ use fs::Fs;
use gpui::{App, Entity, Task};
use language_model::{LanguageModelId, LanguageModelProviderId, LanguageModelRegistry};
use project::{AgentId, Project};
use prompt_store::PromptStore;
use settings::{LanguageModelSelection, Settings as _, update_settings_file};
use util::ResultExt as _;
use crate::{NativeAgent, NativeAgentConnection, ThreadStore, templates::Templates};
@ -45,15 +43,12 @@ impl AgentServer for NativeAgentServer {
log::debug!("NativeAgentServer::connect");
let fs = self.fs.clone();
let thread_store = self.thread_store.clone();
let prompt_store = PromptStore::global(cx);
cx.spawn(async move |cx| {
log::debug!("Creating templates for native agent");
let templates = Templates::new();
let prompt_store = prompt_store.await.log_err();
log::debug!("Creating native agent entity");
let agent =
cx.update(|cx| NativeAgent::new(thread_store, templates, prompt_store, fs, cx));
let agent = cx.update(|cx| NativeAgent::new(thread_store, templates, fs, cx));
// Create the connection wrapper
let connection = NativeAgentConnection(agent);

View file

@ -0,0 +1,25 @@
//! Agent-side glue for the [`sandbox`] crate.
//!
//! Centralizes the "should agent-run terminal commands be sandboxed for this
//! process?" check so the system prompt, the terminal tool, and any other
//! caller see the same answer (and so the `target_os` gate lives in one
//! place instead of scattered across the agent crate).
//!
//! The current policy is: enabled iff we're on macOS *and* the user has the
//! `sandboxing` feature flag turned on. There's deliberately no settings or
//! env-var override yet — the flag is the only switch.
//!
//! On non-macOS hosts we don't have a sandbox integration today, so this
//! returns `false` regardless of the flag.
//!
//! Naming note: this module is about agent terminal sandboxing specifically.
//! Other agent operations (e.g. file edits) are gated separately.
use feature_flags::{FeatureFlagAppExt as _, SandboxingFeatureFlag};
use gpui::App;
/// Whether agent-run terminal commands should be wrapped in an OS-level
/// sandbox for this process. See module docs for the policy.
pub(crate) fn sandboxing_enabled(cx: &App) -> bool {
cfg!(target_os = "macos") && cx.has_flag::<SandboxingFeatureFlag>()
}

View file

@ -43,6 +43,12 @@ pub struct SystemPromptTemplate<'a> {
/// Contents of the user-global `~/.config/zed/AGENTS.md` file (or the
/// platform equivalent), if present and non-empty.
pub user_agents_md: Option<SharedString>,
/// Whether agent-run terminal commands are wrapped in an OS-level
/// sandbox for this conversation. When `true`, the rendered prompt
/// describes the sandbox's read/write/network rules and the
/// per-command flags the model can request to relax them. When
/// `false`, the prompt omits the sandbox section entirely.
pub sandboxing: bool,
}
impl Template for SystemPromptTemplate<'_> {
@ -87,6 +93,7 @@ mod tests {
model_name: Some("test-model".to_string()),
date: "2026-01-01".to_string(),
user_agents_md: None,
sandboxing: false,
};
let templates = Templates::new();
let rendered = template.render(&templates).unwrap();
@ -112,13 +119,14 @@ mod tests {
project_entry_id: 1,
}),
}];
let project = ProjectContext::new(worktrees, Vec::new());
let project = ProjectContext::new(worktrees);
let template = SystemPromptTemplate {
project: &project,
available_tools: vec!["echo".into()],
model_name: Some("test-model".to_string()),
date: "2026-01-01".to_string(),
user_agents_md: Some("always be concise".into()),
sandboxing: false,
};
let templates = Templates::new();
let rendered = template.render(&templates).unwrap();
@ -136,6 +144,78 @@ mod tests {
);
}
#[test]
fn test_system_prompt_omits_sandbox_section_when_sandboxing_disabled() {
let project = prompt_store::ProjectContext::default();
let template = SystemPromptTemplate {
project: &project,
available_tools: vec!["echo".into()],
model_name: Some("test-model".to_string()),
date: "2026-01-01".to_string(),
user_agents_md: None,
sandboxing: false,
};
let templates = Templates::new();
let rendered = template.render(&templates).unwrap();
assert!(!rendered.contains("## Terminal sandbox"));
assert!(!rendered.contains("allow_network"));
}
#[test]
fn test_system_prompt_renders_sandbox_section_with_worktrees_when_enabled() {
use prompt_store::{ProjectContext, WorktreeContext};
let worktrees = vec![
WorktreeContext {
root_name: "alpha".to_string(),
abs_path: std::path::Path::new("/tmp/alpha").into(),
rules_file: None,
},
WorktreeContext {
root_name: "beta".to_string(),
abs_path: std::path::Path::new("/tmp/beta").into(),
rules_file: None,
},
];
let project = ProjectContext::new(worktrees);
let template = SystemPromptTemplate {
project: &project,
available_tools: vec!["echo".into()],
model_name: Some("test-model".to_string()),
date: "2026-01-01".to_string(),
user_agents_md: None,
sandboxing: true,
};
let templates = Templates::new();
let rendered = template.render(&templates).unwrap();
assert!(rendered.contains("## Terminal sandbox"));
assert!(rendered.contains("`/tmp/alpha`"));
assert!(rendered.contains("`/tmp/beta`"));
assert!(rendered.contains("allow_network: true"));
assert!(rendered.contains("allow_fs_write: true"));
assert!(rendered.contains("unsandboxed: true"));
assert!(rendered.contains("remain in effect for the entire duration"));
}
#[test]
fn test_system_prompt_sandbox_section_handles_zero_worktrees() {
let project = prompt_store::ProjectContext::default();
let template = SystemPromptTemplate {
project: &project,
available_tools: vec!["echo".into()],
model_name: Some("test-model".to_string()),
date: "2026-01-01".to_string(),
user_agents_md: None,
sandboxing: true,
};
let templates = Templates::new();
let rendered = template.render(&templates).unwrap();
assert!(rendered.contains("## Terminal sandbox"));
assert!(rendered.contains("No project directories are currently writable"));
}
#[test]
fn test_system_prompt_omits_user_agents_md_section_when_absent() {
let project = prompt_store::ProjectContext::default();
@ -145,9 +225,28 @@ mod tests {
model_name: Some("test-model".to_string()),
date: "2026-01-01".to_string(),
user_agents_md: None,
sandboxing: false,
};
let templates = Templates::new();
let rendered = template.render(&templates).unwrap();
assert!(!rendered.contains("### Personal `AGENTS.md`"));
}
#[test]
fn test_system_prompt_does_not_render_legacy_zed_rules_section() {
let project = prompt_store::ProjectContext::default();
let template = SystemPromptTemplate {
project: &project,
available_tools: vec!["echo".into()],
model_name: Some("test-model".to_string()),
date: "2026-01-01".to_string(),
user_agents_md: None,
sandboxing: false,
};
let templates = Templates::new();
let rendered = template.render(&templates).unwrap();
assert!(!rendered.contains("The user has specified the following rules"));
assert!(!rendered.contains("Rules title:"));
}
}

View file

@ -173,12 +173,11 @@ The current project contains the following root directories:
You are powered by the model named {{model_name}}.
{{/if}}
{{#if (or has_rules has_user_rules)}}
{{#if has_rules}}
## User's Custom Instructions
The following additional instructions are provided by the user and should be followed to the best of your ability{{#if (gt (len available_tools) 0)}} without interfering with the tool use guidelines{{/if}}.
{{#if has_rules}}
There are project rules that apply to these root directories:
{{#each worktrees}}
{{#if rules_file}}
@ -189,17 +188,3 @@ There are project rules that apply to these root directories:
{{/if}}
{{/each}}
{{/if}}
{{#if has_user_rules}}
The user has specified the following rules that should be applied:
{{#each user_rules}}
{{#if title}}
Rules title: {{title}}
{{/if}}
``````
{{contents}}
``````
{{/each}}
{{/if}}
{{/if}}

View file

@ -23,15 +23,13 @@ graph TD
A[Start] --> B[End]
```
The renderer supports the following diagram types: flowchart, sequence, class, state, ER, gantt, pie, gitgraph, mindmap, timeline, quadrant chart, xy chart, and journey. Other diagram types will only show as code.
Mermaid diagrams are automatically themed to match the user's editor theme. Do not include `%%{init}%%` directives or define your own `classDef` styles.
Do *NOT* include inline HTML elements in mermaid diagrams, as they cannot be rendered. It is better to simply skip formatting (e.g. bold/italic/etc.).
When you need accent colors for emphasis (e.g. color-coding layers, categories, or states), use the pre-defined classes `accent0` through `accent7` with the `:::` syntax:
A:::accent0 --> B:::accent1 --> C:::accent2
These classes automatically match the user's theme. Do not hardcode hex color values unless an exact color match is specifically required. Note that the rendered view may be narrow, so try to prioritize generating taller diagrams over wider ones.
Mermaid diagrams are automatically color-coded using the user's theme accent palette. Do not hardcode hex color values unless an exact color match is specifically required. Note that the rendered view may be narrow, so try to prioritize generating taller diagrams over wider ones.
{{#if (gt (len available_tools) 0)}}
## Tool Use
@ -189,6 +187,24 @@ The current project contains the following root directories:
- `{{abs_path}}`
{{/each}}
{{#if sandboxing}}
## Terminal sandbox
The `terminal` tool runs commands inside a sandbox with these permissions:
- Reads: any path on the filesystem is readable.
- Writes: a per-thread temporary directory exposed via `$TMPDIR`, `$TMP`, and `$TEMP` is writable and persists across `terminal` calls in this conversation{{#if worktrees}}, along with these project directories:
{{#each worktrees}}
- `{{abs_path}}`
{{/each}}
Writes anywhere else on the filesystem are blocked.{{else}}. No project directories are currently writable.{{/if}}
- Network: outbound network access is blocked.
You can request elevated permissions on individual `terminal` calls by setting `allow_network: true`, `allow_fs_write: true`, or `unsandboxed: true`. The user will be prompted to approve before the command runs.
These sandbox settings are guaranteed to remain in effect for the entire duration of this conversation. If they ever change, you will be told.
{{/if}}
{{#if model_name}}
## Model Information
@ -223,7 +239,7 @@ To use a Skill:
4. If the Skill references additional files, use `read_file` to access them. Paths inside a Skill resolve relative to that Skill's directory (the parent of its `SKILL.md`).
{{/if}}
{{#if (or user_agents_md has_rules has_user_rules)}}
{{#if (or user_agents_md has_rules)}}
## User's Custom Instructions
The following additional instructions are provided by the user and should be followed to the best of your ability{{#if (gt (len available_tools) 0)}} without interfering with the tool use guidelines{{/if}}.
@ -254,16 +270,4 @@ There are project rules that apply to these root directories:
{{/each}}
{{/if}}
{{#if has_user_rules}}
The user has specified the following rules that should be applied:
{{#each user_rules}}
{{#if title}}
Rules title: {{title}}
{{/if}}
``````
{{contents}}
``````
{{/each}}
{{/if}}
{{/if}}

View file

@ -200,8 +200,10 @@ impl crate::ThreadEnvironment for FakeThreadEnvironment {
fn create_terminal(
&self,
_command: String,
_extra_env: Vec<acp::EnvVariable>,
_cwd: Option<std::path::PathBuf>,
_output_byte_limit: Option<u64>,
_sandbox_wrap: Option<acp_thread::SandboxWrap>,
_cx: &mut AsyncApp,
) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
self.terminal_creations.fetch_add(1, Ordering::SeqCst);
@ -242,8 +244,10 @@ impl crate::ThreadEnvironment for MultiTerminalEnvironment {
fn create_terminal(
&self,
_command: String,
_extra_env: Vec<acp::EnvVariable>,
_cwd: Option<std::path::PathBuf>,
_output_byte_limit: Option<u64>,
_sandbox_wrap: Option<acp_thread::SandboxWrap>,
cx: &mut AsyncApp,
) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
@ -320,6 +324,7 @@ async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
command: "sleep 1000".to_string(),
cd: ".".to_string(),
timeout_ms: Some(5),
..Default::default()
}),
event_stream,
cx,
@ -387,6 +392,7 @@ async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAp
command: "sleep 1000".to_string(),
cd: ".".to_string(),
timeout_ms: None,
..Default::default()
}),
event_stream,
cx,
@ -3520,8 +3526,8 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
let thread_store = cx.new(|cx| ThreadStore::new(cx));
// Create agent and connection
let agent = cx
.update(|cx| NativeAgent::new(thread_store, templates.clone(), None, fake_fs.clone(), cx));
let agent =
cx.update(|cx| NativeAgent::new(thread_store, templates.clone(), fake_fs.clone(), cx));
let connection = NativeAgentConnection(agent.clone());
// Create a thread using new_thread
@ -4892,6 +4898,7 @@ async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
command: "rm -rf /".to_string(),
cd: ".".to_string(),
timeout_ms: None,
..Default::default()
}),
event_stream,
cx,
@ -4944,6 +4951,7 @@ async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
command: "echo hello".to_string(),
cd: ".".to_string(),
timeout_ms: None,
..Default::default()
}),
event_stream,
cx,
@ -5002,6 +5010,7 @@ async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
command: "sudo rm file".to_string(),
cd: ".".to_string(),
timeout_ms: None,
..Default::default()
}),
event_stream,
cx,
@ -5049,6 +5058,7 @@ async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
command: "echo hello".to_string(),
cd: ".".to_string(),
timeout_ms: None,
..Default::default()
}),
event_stream,
cx,
@ -5091,9 +5101,8 @@ async fn test_subagent_tool_call_end_to_end(cx: &mut TestAppContext) {
.await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
let agent = cx.update(|cx| {
NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
});
let agent =
cx.update(|cx| NativeAgent::new(thread_store.clone(), Templates::new(), fs.clone(), cx));
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = cx
@ -5226,9 +5235,8 @@ async fn test_subagent_tool_output_does_not_include_thinking(cx: &mut TestAppCon
.await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
let agent = cx.update(|cx| {
NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
});
let agent =
cx.update(|cx| NativeAgent::new(thread_store.clone(), Templates::new(), fs.clone(), cx));
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = cx
@ -5374,9 +5382,8 @@ async fn test_subagent_tool_call_cancellation_during_task_prompt(cx: &mut TestAp
.await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
let agent = cx.update(|cx| {
NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
});
let agent =
cx.update(|cx| NativeAgent::new(thread_store.clone(), Templates::new(), fs.clone(), cx));
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = cx
@ -5504,9 +5511,8 @@ async fn test_subagent_tool_resume_session(cx: &mut TestAppContext) {
.await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
let agent = cx.update(|cx| {
NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
});
let agent =
cx.update(|cx| NativeAgent::new(thread_store.clone(), Templates::new(), fs.clone(), cx));
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = cx
@ -6037,9 +6043,8 @@ async fn test_subagent_context_window_warning(cx: &mut TestAppContext) {
.await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
let agent = cx.update(|cx| {
NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
});
let agent =
cx.update(|cx| NativeAgent::new(thread_store.clone(), Templates::new(), fs.clone(), cx));
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = cx
@ -6163,9 +6168,8 @@ async fn test_subagent_no_context_window_warning_when_already_at_warning(cx: &mu
.await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
let agent = cx.update(|cx| {
NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
});
let agent =
cx.update(|cx| NativeAgent::new(thread_store.clone(), Templates::new(), fs.clone(), cx));
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = cx
@ -6337,9 +6341,8 @@ async fn test_subagent_error_propagation(cx: &mut TestAppContext) {
.await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
let agent = cx.update(|cx| {
NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
});
let agent =
cx.update(|cx| NativeAgent::new(thread_store.clone(), Templates::new(), fs.clone(), cx));
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = cx

View file

@ -4,11 +4,11 @@ use crate::{
FindPathTool, FindReferencesTool, GetCodeActionsTool, GoToDefinitionTool, GrepTool,
ListDirectoryTool, MovePathTool, ProjectSnapshot, ReadFileTool, RenameTool, SpawnAgentTool,
SystemPromptTemplate, Template, Templates, TerminalTool, ToolPermissionDecision,
UpdatePlanTool, UpdateTitleTool, UserAgentsMd, WebSearchTool, WriteFileTool,
decide_permission_from_settings,
UpdatePlanTool, UpdateTitleTool, WebSearchTool, WriteFileTool, decide_permission_from_settings,
};
use acp_thread::{MentionUri, UserMessageId};
use action_log::ActionLog;
use agent_settings::UserAgentsMd;
use feature_flags::{
FeatureFlagAppExt as _, LspToolFeatureFlag, RenameToolFeatureFlag, UpdatePlanToolFeatureFlag,
UpdateTitleToolFeatureFlag,
@ -51,16 +51,16 @@ use serde::{Deserialize, Serialize};
use settings::{
LanguageModelSelection, Settings, SettingsStore, ToolPermissionMode, update_settings_file,
};
use std::fmt::Write;
use std::{
collections::BTreeMap,
marker::PhantomData,
ops::RangeInclusive,
path::Path,
path::{Path, PathBuf},
rc::Rc,
sync::Arc,
time::{Duration, Instant},
};
use std::{fmt::Write, path::PathBuf};
use util::{ResultExt, debug_panic, markdown::MarkdownCodeBlock, paths::PathStyle};
use uuid::Uuid;
@ -668,8 +668,10 @@ pub trait ThreadEnvironment {
fn create_terminal(
&self,
command: String,
extra_env: Vec<acp::EnvVariable>,
cwd: Option<PathBuf>,
output_byte_limit: Option<u64>,
sandbox_wrap: Option<acp_thread::SandboxWrap>,
cx: &mut AsyncApp,
) -> Task<Result<Rc<dyn TerminalHandle>>>;
@ -1005,6 +1007,7 @@ pub struct Thread {
/// Weak references to running subagent threads for cancellation propagation
running_subagents: Vec<WeakEntity<Thread>>,
inherits_parent_model_settings: bool,
sandboxed_terminal_temp_dir: Option<PathBuf>,
}
impl Thread {
@ -1131,6 +1134,7 @@ impl Thread {
ui_scroll_position: None,
running_subagents: Vec::new(),
inherits_parent_model_settings: true,
sandboxed_terminal_temp_dir: None,
}
}
@ -1174,6 +1178,30 @@ impl Thread {
&self.id
}
pub(crate) fn sandboxed_terminal_temp_dir(
&mut self,
cx: &mut Context<Self>,
) -> Result<PathBuf> {
if let Some(temp_dir) = &self.sandboxed_terminal_temp_dir {
std::fs::create_dir_all(temp_dir).with_context(|| {
format!(
"failed to recreate sandboxed terminal temp directory {}",
temp_dir.display()
)
})?;
return Ok(temp_dir.clone());
}
let temp_dir = tempfile::Builder::new()
.prefix("zed-agent-terminal-")
.tempdir()
.context("failed to create sandboxed terminal temp directory")?;
let temp_dir = temp_dir.keep();
self.sandboxed_terminal_temp_dir = Some(temp_dir.clone());
cx.notify();
Ok(temp_dir)
}
/// Returns true if this thread was imported from a shared thread.
pub fn is_imported(&self) -> bool {
self.imported
@ -1449,6 +1477,7 @@ impl Thread {
}),
running_subagents: Vec::new(),
inherits_parent_model_settings: true,
sandboxed_terminal_temp_dir: db_thread.sandboxed_terminal_temp_dir,
}
}
@ -1479,6 +1508,7 @@ impl Thread {
offset_in_item: lo.offset_in_item.as_f32(),
}
}),
sandboxed_terminal_temp_dir: self.sandboxed_terminal_temp_dir.clone(),
};
cx.background_spawn(async move {
@ -3171,6 +3201,7 @@ impl Thread {
model_name: self.model.as_ref().map(|m| m.name().0.to_string()),
date: Local::now().format("%Y-%m-%d").to_string(),
user_agents_md,
sandboxing: crate::sandboxing::sandboxing_enabled(cx),
}
.render(&self.templates)
.context("failed to build system prompt")

View file

@ -167,6 +167,7 @@ mod tests {
thinking_effort: None,
draft_prompt: None,
ui_scroll_position: None,
sandboxed_terminal_temp_dir: None,
}
}

View file

@ -1,5 +1,6 @@
use super::tool_permissions::{
authorize_symlink_escapes, canonicalize_worktree_roots, collect_symlink_escapes,
resolve_creatable_global_skill_descendant_path, resolve_global_skill_descendant_path,
sensitive_settings_kind,
};
use crate::{
@ -23,6 +24,7 @@ use util::markdown::MarkdownInlineCode;
///
/// This tool should be used when it's desirable to create a copy of a file or directory without modifying the original.
/// It's much more efficient than doing this by separately reading and then writing the file or directory's contents, so this tool should be preferred over that approach whenever copying is the goal.
/// The only supported paths outside the project are descendants of `~/.agents/skills`, for global agent skills.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct CopyPathToolInput {
/// The source path of the file or directory to copy.
@ -100,6 +102,15 @@ impl AgentTool for CopyPathTool {
let fs = project.read_with(cx, |project, _cx| project.fs().clone());
let canonical_roots = canonicalize_worktree_roots(&project, &fs, cx).await;
let global_source_path =
resolve_global_skill_descendant_path(Path::new(&input.source_path), fs.as_ref())
.await;
let global_destination_path = resolve_creatable_global_skill_descendant_path(
Path::new(&input.destination_path),
fs.as_ref(),
)
.await;
let symlink_escapes: Vec<(&str, std::path::PathBuf)> =
project.read_with(cx, |project, cx| {
collect_symlink_escapes(
@ -160,6 +171,63 @@ impl AgentTool for CopyPathTool {
authorize.await.map_err(|e| e.to_string())?;
}
if global_source_path.is_some() || global_destination_path.is_some() {
let source_path = if let Some(global_source_path) = global_source_path {
global_source_path
} else {
project.read_with(cx, |project, cx| {
let project_path = project.find_project_path(&input.source_path, cx).ok_or_else(|| {
format!("Source path {} was not found in the project.", input.source_path)
})?;
project.entry_for_path(&project_path, cx).ok_or_else(|| {
format!("Source path {} was not found in the project.", input.source_path)
})?;
project.absolute_path(&project_path, cx).ok_or_else(|| {
format!("Source path {} could not be resolved.", input.source_path)
})
})?
};
let destination_path = if let Some(global_destination_path) = global_destination_path
{
global_destination_path
} else {
project.read_with(cx, |project, cx| {
let project_path = project.find_project_path(&input.destination_path, cx).ok_or_else(|| {
format!(
"Destination path {} was outside the project.",
input.destination_path
)
})?;
project.absolute_path(&project_path, cx).ok_or_else(|| {
format!(
"Destination path {} could not be resolved.",
input.destination_path
)
})
})?
};
futures::select! {
result = fs::copy_recursive(
fs.as_ref(),
&source_path,
&destination_path,
fs::CopyOptions::default(),
).fuse() => {
result.map_err(|e| format!("Copying {} to {}: {e}", input.source_path, input.destination_path))?;
}
_ = event_stream.cancelled_by_user().fuse() => {
return Err("Copy cancelled by user".to_string());
}
}
return Ok(format!(
"Copied {} to {}",
input.source_path, input.destination_path
));
}
let copy_task = project.update(cx, |project, cx| {
match project
.find_project_path(&input.source_path, cx)
@ -222,6 +290,124 @@ mod tests {
});
}
#[gpui::test]
async fn test_copy_path_global_skill_directory_to_project(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/root/project"), json!({})).await;
let skill_dir = agent_skills::global_skills_dir().join("my-skill");
fs.insert_tree(&skill_dir, json!({ "SKILL.md": "content" }))
.await;
let project = Project::test(fs.clone(), [path!("/root/project").as_ref()], cx).await;
cx.executor().run_until_parked();
let tool = Arc::new(CopyPathTool::new(project));
let input_path = PathBuf::from("~")
.join(".agents")
.join("skills")
.join("my-skill")
.to_string_lossy()
.into_owned();
let (event_stream, mut event_rx) = ToolCallEventStream::test();
let task = cx.update(|cx| {
tool.run(
ToolInput::resolved(CopyPathToolInput {
source_path: input_path,
destination_path: path!("/root/project/my-skill").to_string(),
}),
event_stream,
cx,
)
});
let auth = event_rx.expect_authorization().await;
let title = auth.tool_call.fields.title.as_deref().unwrap_or("");
assert!(
title.contains("agent skills"),
"Authorization title should mention agent skills, got: {title}",
);
auth.response
.send(acp_thread::SelectedPermissionOutcome::new(
acp::PermissionOptionId::new("allow"),
acp::PermissionOptionKind::AllowOnce,
))
.expect("authorization response should send");
let result = task.await;
assert!(result.is_ok(), "should copy after approval: {result:?}");
assert!(fs.is_dir(&skill_dir).await);
assert_eq!(
fs.load(path!("/root/project/my-skill/SKILL.md").as_ref())
.await
.unwrap(),
"content"
);
}
#[gpui::test]
async fn test_copy_path_project_directory_to_global_skill_directory(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/root/project"),
json!({ "exported-skill": { "SKILL.md": "content" } }),
)
.await;
let skills_dir = agent_skills::global_skills_dir();
fs.create_dir(&skills_dir).await.unwrap();
let project = Project::test(fs.clone(), [path!("/root/project").as_ref()], cx).await;
cx.executor().run_until_parked();
let tool = Arc::new(CopyPathTool::new(project));
let destination_path = PathBuf::from("~")
.join(".agents")
.join("skills")
.join("exported-skill")
.to_string_lossy()
.into_owned();
let (event_stream, mut event_rx) = ToolCallEventStream::test();
let task = cx.update(|cx| {
tool.run(
ToolInput::resolved(CopyPathToolInput {
source_path: path!("/root/project/exported-skill").to_string(),
destination_path,
}),
event_stream,
cx,
)
});
let auth = event_rx.expect_authorization().await;
let title = auth.tool_call.fields.title.as_deref().unwrap_or("");
assert!(
title.contains("agent skills"),
"Authorization title should mention agent skills, got: {title}",
);
auth.response
.send(acp_thread::SelectedPermissionOutcome::new(
acp::PermissionOptionId::new("allow"),
acp::PermissionOptionKind::AllowOnce,
))
.expect("authorization response should send");
let result = task.await;
assert!(result.is_ok(), "should copy after approval: {result:?}");
assert!(
fs.is_dir(path!("/root/project/exported-skill").as_ref())
.await
);
assert_eq!(
fs.load(skills_dir.join("exported-skill").join("SKILL.md").as_ref())
.await
.unwrap(),
"content"
);
}
#[gpui::test]
async fn test_copy_path_symlink_escape_source_requests_authorization(cx: &mut TestAppContext) {
init_test(cx);

View file

@ -1,6 +1,6 @@
use super::tool_permissions::{
authorize_symlink_access, canonicalize_worktree_roots, detect_symlink_escape,
sensitive_settings_kind,
resolve_global_skill_descendant_path, resolves_to_global_skills_dir, sensitive_settings_kind,
};
use crate::{
AgentTool, ToolCallEventStream, ToolInput, ToolPermissionDecision,
@ -20,6 +20,8 @@ use std::sync::Arc;
use util::markdown::MarkdownInlineCode;
/// Deletes the file or directory (and the directory's contents, recursively) at the specified path in the project, and returns confirmation of the deletion.
///
/// The only supported paths outside the project are descendants of `~/.agents/skills`, for global agent skills.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct DeletePathToolInput {
/// The path of the file or directory to delete.
@ -95,6 +97,16 @@ impl AgentTool for DeletePathTool {
let fs = project.read_with(cx, |project, _cx| project.fs().clone());
let canonical_roots = canonicalize_worktree_roots(&project, &fs, cx).await;
if resolves_to_global_skills_dir(Path::new(&path), fs.as_ref()).await {
return Err(
"Cannot delete the global agent skills directory itself. Delete a skill directory or file beneath it instead."
.to_string(),
);
}
let global_skill_path =
resolve_global_skill_descendant_path(Path::new(&path), fs.as_ref()).await;
let symlink_escape_target = project.read_with(cx, |project, cx| {
detect_symlink_escape(project, &path, &canonical_roots, cx)
.map(|(_, target)| target)
@ -147,6 +159,38 @@ impl AgentTool for DeletePathTool {
authorize.await.map_err(|e| e.to_string())?;
}
if let Some(global_skill_path) = global_skill_path {
let metadata = fs
.metadata(&global_skill_path)
.await
.map_err(|e| format!("Deleting {path}: {e}"))?
.ok_or_else(|| format!("Deleting {path}: path not found"))?;
futures::select! {
result = async {
if metadata.is_dir {
fs.remove_dir(
&global_skill_path,
fs::RemoveOptions {
recursive: true,
..fs::RemoveOptions::default()
},
)
.await
} else {
fs.remove_file(&global_skill_path, fs::RemoveOptions::default()).await
}
}.fuse() => {
result.map_err(|e| format!("Deleting {path}: {e}"))?;
}
_ = event_stream.cancelled_by_user().fuse() => {
return Err("Delete cancelled by user".to_string());
}
}
return Ok(format!("Deleted {path}"));
}
let (project_path, worktree_snapshot) = project.read_with(cx, |project, cx| {
let project_path = project.find_project_path(&path, cx).ok_or_else(|| {
format!("Couldn't delete {path} because that path isn't in this project.")
@ -248,6 +292,145 @@ mod tests {
});
}
#[gpui::test]
async fn test_delete_path_global_skill_directory(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/root/project"), json!({})).await;
let skills_dir = agent_skills::global_skills_dir();
let skill_dir = skills_dir.join("my-skill");
fs.insert_tree(&skill_dir, json!({ "SKILL.md": "content" }))
.await;
let project = Project::test(fs.clone(), [path!("/root/project").as_ref()], cx).await;
cx.executor().run_until_parked();
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(DeletePathTool::new(project, action_log));
let input_path = PathBuf::from("~")
.join(".agents")
.join("skills")
.join("my-skill")
.to_string_lossy()
.into_owned();
let (event_stream, mut event_rx) = ToolCallEventStream::test();
let task = cx.update(|cx| {
tool.run(
ToolInput::resolved(DeletePathToolInput { path: input_path }),
event_stream,
cx,
)
});
let auth = event_rx.expect_authorization().await;
let title = auth.tool_call.fields.title.as_deref().unwrap_or("");
assert!(
title.contains("agent skills"),
"Authorization title should mention agent skills, got: {title}",
);
auth.response
.send(acp_thread::SelectedPermissionOutcome::new(
acp::PermissionOptionId::new("allow"),
acp::PermissionOptionKind::AllowOnce,
))
.expect("authorization response should send");
let result = task.await;
assert!(result.is_ok(), "should delete after approval: {result:?}");
assert!(fs.is_dir(&skills_dir).await);
assert!(!fs.is_dir(&skill_dir).await);
}
#[gpui::test]
async fn test_delete_path_global_skill_file(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/root/project"), json!({})).await;
let skill_file = agent_skills::global_skills_dir()
.join("my-skill")
.join("references")
.join("notes.md");
fs.create_dir(skill_file.parent().unwrap()).await.unwrap();
fs.insert_file(&skill_file, b"notes".to_vec()).await;
let project = Project::test(fs.clone(), [path!("/root/project").as_ref()], cx).await;
cx.executor().run_until_parked();
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(DeletePathTool::new(project, action_log));
let input_path = PathBuf::from("~")
.join(".agents")
.join("skills")
.join("my-skill")
.join("references")
.join("notes.md")
.to_string_lossy()
.into_owned();
let (event_stream, mut event_rx) = ToolCallEventStream::test();
let task = cx.update(|cx| {
tool.run(
ToolInput::resolved(DeletePathToolInput { path: input_path }),
event_stream,
cx,
)
});
let auth = event_rx.expect_authorization().await;
auth.response
.send(acp_thread::SelectedPermissionOutcome::new(
acp::PermissionOptionId::new("allow"),
acp::PermissionOptionKind::AllowOnce,
))
.expect("authorization response should send");
let result = task.await;
assert!(result.is_ok(), "should delete after approval: {result:?}");
assert!(!fs.is_file(&skill_file).await);
}
#[gpui::test]
async fn test_delete_path_rejects_global_skills_root(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/root/project"), json!({})).await;
let skills_dir = agent_skills::global_skills_dir();
fs.create_dir(&skills_dir).await.unwrap();
let project = Project::test(fs.clone(), [path!("/root/project").as_ref()], cx).await;
cx.executor().run_until_parked();
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(DeletePathTool::new(project, action_log));
let input_path = PathBuf::from("~")
.join(".agents")
.join("skills")
.to_string_lossy()
.into_owned();
let (event_stream, mut event_rx) = ToolCallEventStream::test();
let result = cx
.update(|cx| {
tool.run(
ToolInput::resolved(DeletePathToolInput { path: input_path }),
event_stream,
cx,
)
})
.await;
assert!(result.is_err(), "should reject deleting skills root");
assert!(fs.is_dir(&skills_dir).await);
assert!(
!matches!(
event_rx.try_recv(),
Ok(Ok(crate::ThreadEvent::ToolCallAuthorization(_)))
),
"Deleting the skills root should fail before requesting authorization",
);
}
#[gpui::test]
async fn test_delete_path_symlink_escape_requests_authorization(cx: &mut TestAppContext) {
init_test(cx);

View file

@ -361,7 +361,7 @@ impl EditToolTest {
abs_path: Path::new("/path/to/root").into(),
rules_file: None,
}];
let project_context = ProjectContext::new(worktrees, Vec::default());
let project_context = ProjectContext::new(worktrees);
let tool_names = tools
.iter()
.map(|tool| tool.name.clone().into())
@ -372,6 +372,7 @@ impl EditToolTest {
model_name: None,
date: chrono::Local::now().format("%Y-%m-%d").to_string(),
user_agents_md: None,
sandboxing: false,
};
let templates = Templates::new();
template.render(&templates)?

View file

@ -220,7 +220,7 @@ impl TerminalToolTest {
abs_path: Path::new("/path/to/root").into(),
rules_file: None,
}];
let project_context = ProjectContext::new(worktrees, Vec::default());
let project_context = ProjectContext::new(worktrees);
let tool_names = tools
.iter()
.map(|tool| tool.name.clone().into())
@ -231,6 +231,7 @@ impl TerminalToolTest {
model_name: None,
date: chrono::Local::now().format("%Y-%m-%d").to_string(),
user_agents_md: None,
sandboxing: false,
};
template.render(&Templates::new())?
};

View file

@ -191,7 +191,7 @@ impl WriteToolTest {
abs_path: Path::new("/path/to/root").into(),
rules_file: None,
}];
let project_context = ProjectContext::new(worktrees, Vec::default());
let project_context = ProjectContext::new(worktrees);
let tool_names = tools
.iter()
.map(|tool| tool.name.clone().into())
@ -202,6 +202,7 @@ impl WriteToolTest {
model_name: None,
date: chrono::Local::now().format("%Y-%m-%d").to_string(),
user_agents_md: None,
sandboxing: false,
};
let templates = Templates::new();
template.render(&templates)?

View file

@ -1,6 +1,7 @@
use super::tool_permissions::{
authorize_symlink_escapes, canonicalize_worktree_roots, collect_symlink_escapes,
sensitive_settings_kind,
resolve_creatable_global_skill_descendant_path, resolve_global_skill_descendant_path,
resolves_to_global_skills_dir, sensitive_settings_kind,
};
use crate::{
AgentTool, ToolCallEventStream, ToolInput, ToolPermissionDecision,
@ -22,6 +23,7 @@ use util::markdown::MarkdownInlineCode;
/// If the source and destination directories are the same, but the filename is different, this performs a rename. Otherwise, it performs a move.
///
/// This tool should be used when it's desirable to move or rename a file or directory without changing its contents at all.
/// The only supported paths outside the project are descendants of `~/.agents/skills`, for global agent skills.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct MovePathToolInput {
/// The source path of the file or directory to move/rename.
@ -116,6 +118,28 @@ impl AgentTool for MovePathTool {
let fs = project.read_with(cx, |project, _cx| project.fs().clone());
let canonical_roots = canonicalize_worktree_roots(&project, &fs, cx).await;
if resolves_to_global_skills_dir(Path::new(&input.source_path), fs.as_ref()).await
|| resolves_to_global_skills_dir(
Path::new(&input.destination_path),
fs.as_ref(),
)
.await
{
return Err(
"Cannot move the global agent skills directory itself. Move a skill directory or file beneath it instead."
.to_string(),
);
}
let global_source_path =
resolve_global_skill_descendant_path(Path::new(&input.source_path), fs.as_ref())
.await;
let global_destination_path = resolve_creatable_global_skill_descendant_path(
Path::new(&input.destination_path),
fs.as_ref(),
)
.await;
let symlink_escapes: Vec<(&str, std::path::PathBuf)> =
project.read_with(cx, |project, cx| {
collect_symlink_escapes(
@ -176,6 +200,65 @@ impl AgentTool for MovePathTool {
authorize.await.map_err(|e| e.to_string())?;
}
if global_source_path.is_some() || global_destination_path.is_some() {
let source_path = if let Some(global_source_path) = global_source_path {
global_source_path
} else {
project.read_with(cx, |project, cx| {
let project_path = project.find_project_path(&input.source_path, cx).ok_or_else(|| {
format!("Source path {} was not found in the project.", input.source_path)
})?;
project.entry_for_path(&project_path, cx).ok_or_else(|| {
format!("Source path {} was not found in the project.", input.source_path)
})?;
project.absolute_path(&project_path, cx).ok_or_else(|| {
format!("Source path {} could not be resolved.", input.source_path)
})
})?
};
let destination_path = if let Some(global_destination_path) = global_destination_path
{
global_destination_path
} else {
project.read_with(cx, |project, cx| {
let project_path = project.find_project_path(&input.destination_path, cx).ok_or_else(|| {
format!(
"Destination path {} was outside the project.",
input.destination_path
)
})?;
project.absolute_path(&project_path, cx).ok_or_else(|| {
format!(
"Destination path {} could not be resolved.",
input.destination_path
)
})
})?
};
futures::select! {
result = fs.rename(
&source_path,
&destination_path,
fs::RenameOptions {
create_parents: true,
..fs::RenameOptions::default()
},
).fuse() => {
result.map_err(|e| format!("Moving {} to {}: {e}", input.source_path, input.destination_path))?;
}
_ = event_stream.cancelled_by_user().fuse() => {
return Err("Move cancelled by user".to_string());
}
}
return Ok(format!(
"Moved {} to {}",
input.source_path, input.destination_path
));
}
let rename_task = project.update(cx, |project, cx| {
match project
.find_project_path(&input.source_path, cx)
@ -232,6 +315,125 @@ mod tests {
});
}
#[gpui::test]
async fn test_move_path_global_skill_directory_to_project(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/root/project"), json!({})).await;
let skill_dir = agent_skills::global_skills_dir().join("my-skill");
fs.insert_tree(&skill_dir, json!({ "SKILL.md": "content" }))
.await;
let project = Project::test(fs.clone(), [path!("/root/project").as_ref()], cx).await;
cx.executor().run_until_parked();
let tool = Arc::new(MovePathTool::new(project));
let input_path = PathBuf::from("~")
.join(".agents")
.join("skills")
.join("my-skill")
.to_string_lossy()
.into_owned();
let destination_path = path!("/root/project/my-skill").to_string();
let (event_stream, mut event_rx) = ToolCallEventStream::test();
let task = cx.update(|cx| {
tool.run(
ToolInput::resolved(MovePathToolInput {
source_path: input_path,
destination_path,
}),
event_stream,
cx,
)
});
let auth = event_rx.expect_authorization().await;
let title = auth.tool_call.fields.title.as_deref().unwrap_or("");
assert!(
title.contains("agent skills"),
"Authorization title should mention agent skills, got: {title}",
);
auth.response
.send(acp_thread::SelectedPermissionOutcome::new(
acp::PermissionOptionId::new("allow"),
acp::PermissionOptionKind::AllowOnce,
))
.expect("authorization response should send");
let result = task.await;
assert!(result.is_ok(), "should move after approval: {result:?}");
assert!(!fs.is_dir(&skill_dir).await);
assert_eq!(
fs.load(path!("/root/project/my-skill/SKILL.md").as_ref())
.await
.unwrap(),
"content"
);
}
#[gpui::test]
async fn test_move_path_project_directory_to_global_skill_directory(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/root/project"),
json!({ "exported-skill": { "SKILL.md": "content" } }),
)
.await;
let skills_dir = agent_skills::global_skills_dir();
fs.create_dir(&skills_dir).await.unwrap();
let project = Project::test(fs.clone(), [path!("/root/project").as_ref()], cx).await;
cx.executor().run_until_parked();
let tool = Arc::new(MovePathTool::new(project));
let destination_path = PathBuf::from("~")
.join(".agents")
.join("skills")
.join("exported-skill")
.to_string_lossy()
.into_owned();
let (event_stream, mut event_rx) = ToolCallEventStream::test();
let task = cx.update(|cx| {
tool.run(
ToolInput::resolved(MovePathToolInput {
source_path: path!("/root/project/exported-skill").to_string(),
destination_path,
}),
event_stream,
cx,
)
});
let auth = event_rx.expect_authorization().await;
let title = auth.tool_call.fields.title.as_deref().unwrap_or("");
assert!(
title.contains("agent skills"),
"Authorization title should mention agent skills, got: {title}",
);
auth.response
.send(acp_thread::SelectedPermissionOutcome::new(
acp::PermissionOptionId::new("allow"),
acp::PermissionOptionKind::AllowOnce,
))
.expect("authorization response should send");
let result = task.await;
assert!(result.is_ok(), "should move after approval: {result:?}");
assert!(
!fs.is_dir(path!("/root/project/exported-skill").as_ref())
.await
);
assert_eq!(
fs.load(skills_dir.join("exported-skill").join("SKILL.md").as_ref())
.await
.unwrap(),
"content"
);
}
#[gpui::test]
async fn test_move_path_symlink_escape_source_requests_authorization(cx: &mut TestAppContext) {
init_test(cx);

View file

@ -14,6 +14,7 @@ use std::{
time::Duration,
};
use crate::sandboxing::sandboxing_enabled;
use crate::{AgentTool, ThreadEnvironment, ToolCallEventStream, ToolInput};
const COMMAND_OUTPUT_LIMIT: u64 = 16 * 1024;
@ -39,7 +40,7 @@ const COMMAND_OUTPUT_LIMIT: u64 = 16 * 1024;
/// - Always insert `--no-pager` immediately after `git` for any read-only git command, including `git log`, `git diff`, `git show`, `git blame`, and `git stash show`. Example: `git --no-pager log -n 5` (NOT `git log -n 5`).
/// - Always prepend `GIT_EDITOR=true ` to any git command that may invoke an editor, including `git rebase`, `git commit`, `git merge`, and `git tag`. Example: `GIT_EDITOR=true git rebase origin/main` (NOT `git rebase origin/main`).
/// - For other commands that may open a pager or editor, set `PAGER=cat` and/or `EDITOR=true` similarly.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema)]
pub struct TerminalToolInput {
/// The one-liner command to execute. Do not include shell substitutions or interpolations such as `$VAR`, `${VAR}`, `$(...)`, backticks, `$((...))`, `<(...)`, or `>(...)`; resolve those values first or ask the user.
///
@ -49,6 +50,35 @@ pub struct TerminalToolInput {
pub cd: String,
/// Optional maximum runtime (in milliseconds). If exceeded, the running terminal task is killed.
pub timeout_ms: Option<u64>,
/// Request network access for this command.
///
/// Only meaningful when the system prompt's "Terminal sandbox" section
/// is present — ignored otherwise. By default sandboxed commands
/// cannot make outbound network connections; set this to `true` only
/// when the command needs network access. The user will be prompted
/// to approve before the command runs.
#[serde(default)]
pub allow_network: Option<bool>,
/// Request unrestricted filesystem-write access for this command.
///
/// Only meaningful when the system prompt's "Terminal sandbox" section
/// is present — ignored otherwise. By default sandboxed commands can
/// only write to the project worktree directories and a per-command
/// temporary directory; set this to `true` only when the command
/// needs to write elsewhere. The user will be prompted to approve
/// before the command runs.
#[serde(default)]
pub allow_fs_write: Option<bool>,
/// Request to run this command outside the sandbox entirely.
///
/// Only meaningful when the system prompt's "Terminal sandbox" section
/// is present — ignored otherwise. Prefer `allow_network: true` or
/// `allow_fs_write: true` when one of those is enough. Set this to
/// `true` ONLY when the command needs behavior that the sandbox can't
/// grant on a per-permission basis. The user will be prompted to
/// approve before the command runs without sandbox restrictions.
#[serde(default)]
pub unsandboxed: Option<bool>,
}
pub struct TerminalTool {
@ -96,24 +126,100 @@ impl AgentTool for TerminalTool {
cx.spawn(async move |cx| {
let input = input.recv().await.map_err(|e| e.to_string())?;
let (working_dir, authorize) = cx.update(|cx| {
let (working_dir, authorize, sandboxing) = cx.update(|cx| {
let working_dir =
working_dir(&input, &self.project, cx).map_err(|err| err.to_string())?;
let context =
crate::ToolPermissionContext::new(Self::NAME, vec![input.command.clone()]);
let authorize =
event_stream.authorize(self.initial_title(Ok(input.clone()), cx), context, cx);
Result::<_, String>::Ok((working_dir, authorize))
let sandboxing = sandboxing_enabled(cx);
Result::<_, String>::Ok((working_dir, authorize, sandboxing))
})?;
authorize.await.map_err(|e| e.to_string())?;
// Sandbox flags only do anything when sandboxing is on. When
// off, we treat them as `None` so the model can't surreptitiously
// change runtime behavior by setting flags described as a no-op
// in the system prompt.
let want_network = sandboxing && input.allow_network == Some(true);
let want_fs_write = sandboxing && input.allow_fs_write == Some(true);
let want_unsandboxed = sandboxing && input.unsandboxed == Some(true);
// `unsandboxed: true` bypasses the wrap entirely; per-permission
// requests are only meaningful when the command is still being
// sandboxed.
let escalate = !want_unsandboxed && (want_network || want_fs_write);
if want_unsandboxed || escalate {
let title = sandbox_approval_title(want_network, want_fs_write, want_unsandboxed);
let approve = cx.update(|cx| {
let context = crate::ToolPermissionContext::new(
Self::NAME,
vec![input.command.clone()],
);
// Sandbox escalations always prompt, even if the user
// has `always_allow` rules for this command — the
// escalation is a stronger trust boundary than the
// baseline command approval.
event_stream.authorize_always_prompt(title, context, cx)
});
if let Err(error) = approve.await {
return Ok(if want_unsandboxed {
format!(
"Command cancelled: user denied permission to run outside the sandbox ({error})."
)
} else {
format!(
"Command cancelled: user denied the requested sandbox permissions ({error})."
)
});
}
}
// The per-thread scratch directory (and the `$TMPDIR`/`TMP`/
// `TEMP` environment variables pointing at it) is provisioned by
// the thread environment in `create_terminal`, which also adds it
// to the sandbox's writable scope. We must not set `$TMPDIR` here:
// the environment overrides it with the per-thread directory, so a
// per-command directory set here would never be the `$TMPDIR` the
// command actually sees and would be left out of the writable
// scope, breaking writes into `$TMPDIR`.
let extra_env = Vec::new();
// Build the writable scope from the project's worktrees. The
// per-thread temp directory is appended by the thread environment
// (which owns it and points `$TMPDIR` at it). Crucially we do
// *not* include the resolved `cd` working directory — that's
// model-controlled, and using it as the writable scope would
// let the model widen its own write permissions outside the
// project.
let sandbox_wrap = if sandboxing && !want_unsandboxed {
let writable_paths: Vec<PathBuf> = cx.update(|cx| {
self.project
.read(cx)
.worktrees(cx)
.map(|w| w.read(cx).abs_path().to_path_buf())
.collect::<Vec<_>>()
});
Some(acp_thread::SandboxWrap {
writable_paths,
allow_network: want_network,
allow_fs_write: want_fs_write,
})
} else {
None
};
let terminal = self
.environment
.create_terminal(
input.command.clone(),
extra_env,
working_dir,
Some(COMMAND_OUTPUT_LIMIT),
sandbox_wrap,
cx,
)
.await
@ -182,6 +288,29 @@ impl AgentTool for TerminalTool {
}
}
/// User-facing title for the sandbox-escalation approval prompt.
///
/// `want_unsandboxed` wins over the per-permission flags because
/// `unsandboxed: true` bypasses the per-permission machinery entirely.
fn sandbox_approval_title(
want_network: bool,
want_fs_write: bool,
want_unsandboxed: bool,
) -> &'static str {
if want_unsandboxed {
"Allow this command to run outside the sandbox?"
} else {
match (want_network, want_fs_write) {
(true, true) => "Allow network access and arbitrary filesystem writes?",
(true, false) => "Allow network access?",
(false, true) => "Allow arbitrary filesystem writes?",
// Caller only invokes this when at least one flag is set, so
// this fallback is unreachable in practice.
(false, false) => "Allow this command to run?",
}
}
}
fn process_content(
output: acp::TerminalOutputResponse,
command: &str,
@ -310,6 +439,7 @@ mod tests {
.to_string(),
cd: ".".to_string(),
timeout_ms: None,
..Default::default()
};
let title = format_initial_title(Ok(input));
@ -369,6 +499,7 @@ mod tests {
command: cmd.to_string(),
cd: ".".to_string(),
timeout_ms: None,
..Default::default()
};
let title = format_initial_title(Ok(input));
@ -406,6 +537,7 @@ mod tests {
command: "echo 'hello world'".to_string(),
cd: ".".to_string(),
timeout_ms: None,
..Default::default()
};
let title = format_initial_title(Ok(input));
@ -435,6 +567,7 @@ mod tests {
command: long_command,
cd: ".".to_string(),
timeout_ms: None,
..Default::default()
};
let title = format_initial_title(Ok(input));
@ -641,6 +774,7 @@ mod tests {
command: "echo $HOME".to_string(),
cd: "root".to_string(),
timeout_ms: None,
..Default::default()
}),
event_stream,
cx,
@ -708,6 +842,7 @@ mod tests {
command: "echo $HOME".to_string(),
cd: "root".to_string(),
timeout_ms: None,
..Default::default()
}),
event_stream,
cx,
@ -769,6 +904,7 @@ mod tests {
command: "echo $(rm -rf /)".to_string(),
cd: "root".to_string(),
timeout_ms: None,
..Default::default()
}),
event_stream,
cx,
@ -838,6 +974,7 @@ mod tests {
command: "PAGER=blah git log --oneline".to_string(),
cd: "root".to_string(),
timeout_ms: None,
..Default::default()
}),
event_stream,
cx,
@ -911,6 +1048,7 @@ mod tests {
command: "PAGER=blah git log".to_string(),
cd: "root".to_string(),
timeout_ms: None,
..Default::default()
}),
event_stream,
cx,
@ -1018,6 +1156,7 @@ mod tests {
command: command.to_string(),
cd: "root".to_string(),
timeout_ms: None,
..Default::default()
}),
event_stream,
cx,
@ -1185,6 +1324,7 @@ mod tests {
command: "echo $(whoami)".to_string(),
cd: "root".to_string(),
timeout_ms: None,
..Default::default()
}),
event_stream,
cx,
@ -1257,6 +1397,7 @@ mod tests {
command: "PAGER=other git log".to_string(),
cd: "root".to_string(),
timeout_ms: None,
..Default::default()
}),
event_stream,
cx,
@ -1323,6 +1464,7 @@ mod tests {
command: "A=1 B=2 git log".to_string(),
cd: "root".to_string(),
timeout_ms: None,
..Default::default()
}),
event_stream,
cx,
@ -1400,6 +1542,7 @@ mod tests {
command: "PAGER=\"less -R\" git log".to_string(),
cd: "root".to_string(),
timeout_ms: None,
..Default::default()
}),
event_stream,
cx,
@ -1428,4 +1571,72 @@ mod tests {
"unexpected terminal result: {result}"
);
}
#[test]
fn test_sandbox_approval_title_unsandboxed_wins() {
// `unsandboxed: true` skips the sandbox entirely, so the title should
// reflect that even when other flags are also set — they're moot.
assert_eq!(
sandbox_approval_title(true, true, true),
"Allow this command to run outside the sandbox?"
);
assert_eq!(
sandbox_approval_title(false, false, true),
"Allow this command to run outside the sandbox?"
);
}
#[test]
fn test_sandbox_approval_title_per_permission_flags() {
assert_eq!(
sandbox_approval_title(true, true, false),
"Allow network access and arbitrary filesystem writes?"
);
assert_eq!(
sandbox_approval_title(true, false, false),
"Allow network access?"
);
assert_eq!(
sandbox_approval_title(false, true, false),
"Allow arbitrary filesystem writes?"
);
}
#[test]
fn test_input_schema_includes_sandbox_flags() {
// The model only sees these fields when the sandboxing prompt
// section is rendered, but they're always present in the schema so
// input validation doesn't reject them when sent. Guard against
// accidentally renaming or removing them.
let schema = serde_json::to_string(&schemars::schema_for!(TerminalToolInput))
.expect("input schema should serialize");
assert!(
schema.contains("allow_network"),
"schema should advertise allow_network: {schema}"
);
assert!(
schema.contains("allow_fs_write"),
"schema should advertise allow_fs_write: {schema}"
);
assert!(
schema.contains("unsandboxed"),
"schema should advertise unsandboxed: {schema}"
);
}
#[test]
fn test_sandbox_flags_default_to_none_when_absent() {
// The model is expected to omit the sandbox fields entirely on most
// calls. Make sure deserialization doesn't reject the minimal
// payload and that the fields default to `None` (which the tool
// interprets as "no escalation requested").
let input: TerminalToolInput = serde_json::from_value(serde_json::json!({
"command": "echo hi",
"cd": ".",
}))
.expect("minimal input should deserialize");
assert_eq!(input.allow_network, None);
assert_eq!(input.allow_fs_write, None);
assert_eq!(input.unsandboxed, None);
}
}

View file

@ -199,6 +199,56 @@ pub async fn resolve_creatable_global_skill_path(path: &Path, fs: &dyn Fs) -> Op
}
}
fn is_strict_descendant(path: &Path, ancestor: &Path) -> bool {
path != ancestor && path.starts_with(ancestor)
}
/// Returns whether `path` resolves to the global agent skills directory itself.
///
/// This is used by destructive tools to reject operations targeting the root
/// `~/.agents/skills` directory while still allowing operations on individual
/// skills or resources beneath it.
pub async fn resolves_to_global_skills_dir(path: &Path, fs: &dyn Fs) -> bool {
let Some(normalized_path) = resolve_lexical_global_skill_path(path) else {
return false;
};
let Some(canonical_path) = canonicalize_with_ancestors(&normalized_path, fs).await else {
return false;
};
let Some(canonical_skills_dir) = canonical_global_skills_dir(fs).await else {
return false;
};
canonical_path == canonical_skills_dir
}
/// Filters a previously-resolved global skills path so that callers which
/// must never act on `~/.agents/skills` itself (move, delete) only see paths
/// that point strictly below the skills root.
async fn restrict_to_skill_descendant(
canonical_path: Option<PathBuf>,
fs: &dyn Fs,
) -> Option<PathBuf> {
let canonical_path = canonical_path?;
let canonical_skills_dir = canonical_global_skills_dir(fs).await?;
is_strict_descendant(&canonical_path, &canonical_skills_dir).then_some(canonical_path)
}
/// Like [`resolve_global_skill_path`], but only succeeds for paths strictly
/// below `~/.agents/skills`, not the skills directory itself.
pub async fn resolve_global_skill_descendant_path(path: &Path, fs: &dyn Fs) -> Option<PathBuf> {
restrict_to_skill_descendant(resolve_global_skill_path(path, fs).await, fs).await
}
/// Like [`resolve_creatable_global_skill_path`], but only succeeds for paths
/// strictly below `~/.agents/skills`, not the skills directory itself.
pub async fn resolve_creatable_global_skill_descendant_path(
path: &Path,
fs: &dyn Fs,
) -> Option<PathBuf> {
restrict_to_skill_descendant(resolve_creatable_global_skill_path(path, fs).await, fs).await
}
/// Returns the kind of sensitive settings or agent skills location this path targets, if any:
/// either inside a `.zed/` local-settings directory, inside `.agents/skills/`, or inside
/// the global config dir.

View file

@ -21,6 +21,7 @@ futures.workspace = true
gpui.workspace = true
language_model.workspace = true
log.workspace = true
paths.workspace = true
project.workspace = true
regex.workspace = true
schemars.workspace = true
@ -31,7 +32,6 @@ util.workspace = true
[dev-dependencies]
fs.workspace = true
gpui = { workspace = true, features = ["test-support"] }
paths.workspace = true
serde_json_lenient.workspace = true
serde_json.workspace = true

View file

@ -1,4 +1,5 @@
mod agent_profile;
mod user_agents_md;
use std::path::{Component, Path};
use std::sync::{Arc, LazyLock};
@ -20,6 +21,7 @@ use settings::{
};
pub use crate::agent_profile::*;
pub use crate::user_agents_md::{UserAgentsMd, UserAgentsMdState, init as init_user_agents_md};
pub const SUMMARIZE_THREAD_PROMPT: &str = include_str!("prompts/summarize_thread_prompt.txt");
pub const SUMMARIZE_THREAD_DETAILED_PROMPT: &str =

View file

@ -4,7 +4,6 @@ use fs::Fs;
use futures::StreamExt;
use gpui::{Global, SharedString};
use serde::{Deserialize, Serialize};
use std::io::{self, Read};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use util::paths::component_matches_ignore_ascii_case;
@ -41,7 +40,6 @@ pub struct SkillScopeId(pub usize);
/// entries would fan out an equally large number of concurrent OS-level I/O
/// operations, potentially exhausting file descriptors or stalling the app.
const SKILL_IO_CONCURRENCY: usize = 16;
const SKILL_READ_CHUNK_SIZE: usize = 4096;
/// Maximum size for a single SKILL.md file (100KB)
pub const MAX_SKILL_FILE_SIZE: usize = 100 * 1024;
@ -558,53 +556,15 @@ async fn find_skill_files(fs: &Arc<dyn Fs>, directory: &Path) -> Vec<PathBuf> {
.await
}
/// Returns the byte index ONE PAST the end of the closing frontmatter
/// delimiter line in `bytes`, or `None` if no closing delimiter has been
/// seen yet. Used by the chunked reader to know when it has enough
/// bytes to stop pulling from disk.
/// Read `skill_file_path` from disk and parse its frontmatter. The
/// SKILL.md body is parsed away by `parse_skill_frontmatter` and not
/// surfaced here; it's re-read on demand via `read_skill_body` when a
/// skill is actually being loaded for the model.
///
/// Scans for the first `\n---` line followed by `\n`, `\r\n`, or EOF
/// (excluding the opening line itself, which sits at byte 0 and is
/// naturally skipped because we only consider lines following a `\n`).
/// This may overshoot in pathological cases (e.g. `---` inside a quoted
/// YAML string), but `parse_skill_frontmatter`'s candidate-and-validate
/// logic still produces a correct result or a YAML parse error.
fn closing_delimiter_end(bytes: &[u8]) -> Option<usize> {
for (i, &b) in bytes.iter().enumerate() {
if b != b'\n' {
continue;
}
let line_start = i + 1;
if line_start + 3 > bytes.len() {
continue;
}
if &bytes[line_start..line_start + 3] != b"---" {
continue;
}
let after_dashes = line_start + 3;
if after_dashes == bytes.len() {
return Some(after_dashes);
}
if bytes[after_dashes] == b'\n' {
return Some(after_dashes + 1);
}
if after_dashes + 1 < bytes.len()
&& bytes[after_dashes] == b'\r'
&& bytes[after_dashes + 1] == b'\n'
{
return Some(after_dashes + 2);
}
// Line is `---trailing` or `----`; keep scanning.
}
None
}
/// Read just enough of `skill_file_path` from disk to parse its
/// frontmatter. The SKILL.md body is NOT loaded — that's deferred to
/// `read_skill_body`, called only when a skill is actually being
/// materialized for the model. Reading in 4KB chunks keeps the peak
/// memory cost of loading N skills proportional to total frontmatter
/// size, not total file size.
/// We load the whole file in one go rather than streaming up to the
/// closing `---`. `MAX_SKILL_FILE_SIZE` is 100KB and the metadata check
/// below caps the worst case at that, so the peak transient cost is
/// trivially small (≤ `MAX_SKILL_FILE_SIZE` × `SKILL_IO_CONCURRENCY`).
pub async fn load_skill_frontmatter(
fs: Arc<dyn Fs>,
skill_file_path: PathBuf,
@ -612,10 +572,15 @@ pub async fn load_skill_frontmatter(
) -> Result<Skill, SkillLoadError> {
// Short-circuit on oversized files before reading any of their
// contents, so a stray multi-GB file named `SKILL.md` can't OOM the
// app. We only act on a positive signal that the file is too large;
// if metadata fails or is unavailable, we fall through to the read
// loop, which is itself capped at `MAX_SKILL_FILE_SIZE`.
if let Ok(Some(metadata)) = fs.metadata(&skill_file_path).await
// app. If metadata is unavailable, refuse to read.
let metadata = fs
.metadata(&skill_file_path)
.await
.map_err(|e| SkillLoadError {
path: skill_file_path.clone(),
message: format!("Failed to read SKILL.md metadata: {}", e),
})?;
if let Some(metadata) = metadata
&& metadata.len > MAX_SKILL_FILE_SIZE as u64
{
return Err(SkillLoadError {
@ -627,54 +592,15 @@ pub async fn load_skill_frontmatter(
});
}
let mut reader = fs
.open_sync(&skill_file_path)
let content = fs
.load(&skill_file_path)
.await
.map_err(|e| SkillLoadError {
path: skill_file_path.clone(),
message: format!("Failed to open file: {}", e),
message: format!("Failed to read file: {}", e),
})?;
// The chunked read is intentionally synchronous: `Fs::open_sync`
// returns a synchronous `Read` (RealFs uses `std::fs::File`), and
// production callers already wrap `load_skills_from_directory` in
// `cx.background_spawn`, so any blocking happens on the background
// executor — not on the foreground thread. Routing through
// `smol::unblock` instead would schedule the work on smol's blocking
// pool, whose wakeups don't drive GPUI's test scheduler and therefore
// panic with "Parking forbidden" under `TestAppContext`.
let read_result: Result<Vec<u8>, io::Error> = (|| {
let mut accumulated: Vec<u8> = Vec::new();
let mut chunk = [0u8; SKILL_READ_CHUNK_SIZE];
loop {
let n = reader.read(&mut chunk)?;
if n == 0 {
break;
}
accumulated.extend_from_slice(&chunk[..n]);
if let Some(end) = closing_delimiter_end(&accumulated) {
// Discard body bytes swept up in the last chunk so that e.g. multi-byte
// graphemes split at the boundary won't cause `str::from_utf8` to fail.
accumulated.truncate(end);
break;
}
if accumulated.len() > MAX_SKILL_FILE_SIZE {
break;
}
}
Ok(accumulated)
})();
let accumulated = read_result.map_err(|e| SkillLoadError {
path: skill_file_path.clone(),
message: format!("Failed to read file: {}", e),
})?;
let content = std::str::from_utf8(&accumulated).map_err(|e| SkillLoadError {
path: skill_file_path.clone(),
message: format!("SKILL.md is not valid UTF-8: {}", e),
})?;
parse_skill_frontmatter(&skill_file_path, content, source).map_err(|e| SkillLoadError {
parse_skill_frontmatter(&skill_file_path, &content, source).map_err(|e| SkillLoadError {
path: skill_file_path.clone(),
message: e.to_string(),
})
@ -1836,46 +1762,6 @@ description: A skill with no body content
assert_eq!(skill.directory_path, PathBuf::from("/skills/my-skill"));
}
#[gpui::test]
async fn test_load_skill_frontmatter_with_emoji_at_chunk_boundary(cx: &mut TestAppContext) {
// We must be able to load skill frontmatter even when a
// multipoint grapheme crosses the chunk read boundary.
let fs = FakeFs::new(cx.executor());
let frontmatter = "---\nname: my-skill\ndescription: Example skill testing multipoint graphemes at chunk boundary\n---\n";
// Pad contents so that the emoji's first byte lands
// at the last byte of the first read chunk.
let padding = "a".repeat(SKILL_READ_CHUNK_SIZE - frontmatter.len() - 1);
let content = format!("{frontmatter}{padding}");
assert!(
(frontmatter.len() + padding.len()) < SKILL_READ_CHUNK_SIZE,
"emoji must start before the second chunk"
);
assert!(
content.len() > SKILL_READ_CHUNK_SIZE,
"skill is longer than a chunk, so we know that the emoji crosses chunk boundaries"
);
fs.insert_tree(
"/skills",
serde_json::json!({
"my-skill": {
"SKILL.md": content,
}
}),
)
.await;
load_skill_frontmatter(
fs as Arc<dyn Fs>,
PathBuf::from("/skills/my-skill/SKILL.md"),
SkillSource::Global,
)
.await
.expect("frontmatter should parse even when a multipoint grapheme such as an emoji crosses the byte chunk boundary");
}
#[gpui::test]
async fn test_read_skill_body_returns_trimmed_body(cx: &mut TestAppContext) {
let fs = FakeFs::new(cx.executor());

View file

@ -10,9 +10,10 @@ use std::{
};
use acp_thread::{AcpThread, AcpThreadEvent, MentionUri, ThreadStatus};
use agent::{ContextServerRegistry, SharedThread, ThreadStore, UserAgentsMd};
use agent::{ContextServerRegistry, SharedThread, ThreadStore};
use agent_client_protocol::schema as acp;
use agent_servers::AgentServer;
use agent_settings::UserAgentsMd;
use collections::HashSet;
use db::kvp::{Dismissable, KeyValueStore};
use itertools::Itertools;
@ -28,7 +29,8 @@ use zed_actions::{
ResolveConflictsWithAgent, ReviewBranchDiff,
},
assistant::{
CreateSkillFromUrl, FocusAgent, OpenRulesLibrary, OpenSkillCreator, Toggle, ToggleFocus,
CreateSkillFromUrl, FocusAgent, OpenGlobalAgentsMdRules, OpenProjectAgentsMdRules,
OpenRulesLibrary, OpenSkillCreator, Toggle, ToggleFocus,
},
};
@ -57,7 +59,7 @@ use anyhow::Result;
#[cfg(feature = "audio")]
use audio::{Audio, Sound};
use chrono::{DateTime, Utc};
use client::UserStore;
use client::{UserStore, zed_urls};
use cloud_api_types::Plan;
use collections::HashMap;
use editor::{Editor, MultiBuffer};
@ -97,6 +99,45 @@ const MIN_PANEL_WIDTH: Pixels = px(300.);
const LAST_USED_AGENT_KEY: &str = "agent_panel__last_used_external_agent";
const LAST_CREATED_ENTRY_KIND_KEY: &str = "agent_panel__last_created_entry_kind";
const TERMINAL_AGENT_TELEMETRY_ID: &str = "terminal";
const KNOWN_TERMINAL_AGENT_COMMANDS: &[&str] = &[
"agent", // Unfortunately, both Cursor cli + grok
"agy",
"aider",
"amp",
"claude",
"codex",
"copilot",
"crush",
"devin",
"droid",
"gemini",
"goose",
"grok",
"openhands",
"opencode",
"pi",
"qwen",
];
fn is_known_terminal_agent_command(command: &str) -> bool {
KNOWN_TERMINAL_AGENT_COMMANDS.contains(&command)
}
fn terminal_program_to_report(
last_observed_program: &mut Option<String>,
current_program: Option<String>,
) -> Option<String> {
let current_program =
current_program.filter(|program| is_known_terminal_agent_command(program));
let program_to_report =
if current_program.is_some() && current_program != *last_observed_program {
current_program.clone()
} else {
None
};
*last_observed_program = current_program;
program_to_report
}
/// Maximum number of idle threads kept in the agent panel's retained list.
/// Set as a GPUI global to override; otherwise defaults to 5.
@ -178,6 +219,60 @@ fn read_global_last_created_entry_kind(kvp: &KeyValueStore) -> Option<AgentPanel
.map(|entry| entry.entry_kind)
}
fn project_agents_md_path(
project: &Entity<Project>,
require_existing_file: bool,
cx: &App,
) -> Option<PathBuf> {
let rel_path = util::rel_path::RelPath::unix("AGENTS.md").ok()?;
project
.read(cx)
.visible_worktrees(cx)
.next()
.and_then(|worktree| {
let worktree = worktree.read(cx);
if require_existing_file {
let entry = worktree.entry_for_path(rel_path)?;
if !entry.is_file() {
return None;
}
}
Some(worktree.absolutize(rel_path))
})
}
fn open_global_rules(workspace: &mut Workspace, window: &mut Window, cx: &mut Context<Workspace>) {
workspace
.open_abs_path(
paths::agents_file().clone(),
workspace::OpenOptions {
focus: Some(true),
..Default::default()
},
window,
cx,
)
.detach_and_log_err(cx);
}
fn open_project_rules(workspace: &mut Workspace, window: &mut Window, cx: &mut Context<Workspace>) {
if let Some(path) = project_agents_md_path(workspace.project(), false, cx) {
workspace
.open_abs_path(
path,
workspace::OpenOptions {
focus: Some(true),
..Default::default()
},
window,
cx,
)
.detach_and_log_err(cx);
}
}
async fn write_global_last_created_entry_kind(kvp: KeyValueStore, entry_kind: AgentPanelEntryKind) {
if let Some(json) = serde_json::to_string(&LastCreatedEntryKind { entry_kind }).log_err() {
kvp.write_kvp(LAST_CREATED_ENTRY_KIND_KEY.to_string(), json)
@ -314,6 +409,12 @@ pub fn init(cx: &mut App) {
});
}
})
.register_action(|workspace, _: &OpenGlobalAgentsMdRules, window, cx| {
open_global_rules(workspace, window, cx);
})
.register_action(|workspace, _: &OpenProjectAgentsMdRules, window, cx| {
open_project_rules(workspace, window, cx);
})
.register_action(|workspace, action: &OpenSkillCreator, window, cx| {
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
workspace.focus_panel::<AgentPanel>(window, cx);
@ -769,6 +870,7 @@ struct AgentTerminal {
title_editor_initial_title: Option<String>,
title_editor_subscription: Option<Subscription>,
last_known_title: String,
last_observed_program: Option<String>,
working_directory: Option<PathBuf>,
created_at: DateTime<Utc>,
has_notification: bool,
@ -827,6 +929,34 @@ impl AgentTerminal {
fn custom_title(&self, cx: &App) -> Option<SharedString> {
self.view.read(cx).custom_title().map(SharedString::from)
}
fn report_started_terminal_program(
&mut self,
terminal_id: TerminalId,
source: AgentThreadSource,
cx: &App,
) {
let current_program = self
.view
.read(cx)
.terminal()
.read(cx)
.foreground_process_command_name();
if let Some(program) =
terminal_program_to_report(&mut self.last_observed_program, current_program)
{
telemetry::event!(
"Agent Terminal Program Started",
agent = TERMINAL_AGENT_TELEMETRY_ID,
terminal_id = terminal_id.to_key_string(),
program = program,
source = source.as_str(),
side = crate::agent_sidebar_side(cx),
thread_location = "current_worktree",
);
}
}
}
enum BaseView {
@ -1838,6 +1968,7 @@ impl AgentPanel {
| TerminalEvent::Wakeup
| TerminalEvent::BreadcrumbsChanged => {
this.refresh_terminal_metadata(terminal_id, cx);
this.report_terminal_program(terminal_id, source, cx);
}
TerminalEvent::Bell => this.mark_terminal_notification(terminal_id, window, cx),
TerminalEvent::CloseTerminal => {
@ -1858,6 +1989,7 @@ impl AgentPanel {
last_known_title: initial_title
.map(|title| title.to_string())
.unwrap_or_default(),
last_observed_program: None,
working_directory,
created_at: created_at.unwrap_or_else(Utc::now),
has_notification: false,
@ -1869,9 +2001,10 @@ impl AgentPanel {
self.pending_terminal_spawn = None;
}
terminal.refresh_metadata(cx);
terminal.report_started_terminal_program(terminal_id, source, cx);
self.terminals.insert(terminal_id, terminal);
self.persist_terminal_metadata(terminal_id, cx);
self.emit_terminal_thread_started(source, cx);
self.emit_terminal_thread_started(terminal_id, source, cx);
if select {
self.set_base_view(BaseView::Terminal { terminal_id }, focus, window, cx);
}
@ -1966,10 +2099,16 @@ impl AgentPanel {
self.close_terminal_internal(terminal_id, false, metadata, window, cx);
}
fn emit_terminal_thread_started(&self, source: AgentThreadSource, cx: &App) {
fn emit_terminal_thread_started(
&self,
terminal_id: TerminalId,
source: AgentThreadSource,
cx: &App,
) {
telemetry::event!(
"Agent Thread Started",
agent = TERMINAL_AGENT_TELEMETRY_ID,
terminal_id = terminal_id.to_key_string(),
source = source.as_str(),
side = crate::agent_sidebar_side(cx),
thread_location = "current_worktree",
@ -1986,6 +2125,17 @@ impl AgentPanel {
}
}
fn report_terminal_program(
&mut self,
terminal_id: TerminalId,
source: AgentThreadSource,
cx: &mut Context<Self>,
) {
if let Some(terminal) = self.terminals.get_mut(&terminal_id) {
terminal.report_started_terminal_program(terminal_id, source, cx);
}
}
fn persist_all_terminal_metadata(&self, cx: &mut Context<Self>) {
let terminal_ids = self.terminals.keys().copied().collect::<Vec<_>>();
for terminal_id in terminal_ids {
@ -4861,21 +5011,7 @@ impl AgentPanel {
.active_conversation_view()
.is_some_and(|conversation_view| conversation_view.read(cx).supports_logout());
let project_agents_md_path: Option<PathBuf> = self
.project
.read(cx)
.visible_worktrees(cx)
.next()
.and_then(|worktree| {
let worktree = worktree.read(cx);
let rel_path = util::rel_path::RelPath::unix("AGENTS.md").ok()?;
let entry = worktree.entry_for_path(rel_path)?;
if entry.is_file() {
Some(worktree.absolutize(rel_path))
} else {
None
}
});
let project_agents_md_path = project_agents_md_path(&self.project, true, cx);
let global_agents_md_loaded = UserAgentsMd::global(cx)
.and_then(|md| md.content())
@ -4958,55 +5094,59 @@ impl AgentPanel {
if global_agents_md_loaded {
let workspace = workspace.clone();
menu = menu.entry(
"Open Global AGENTS.md",
None,
menu = menu.custom_entry(
|_window, _cx| {
h_flex()
.w_full()
.gap_1()
.child(Label::new("Open Global Rules"))
.child(
Label::new("(AGENTS.md)")
.color(Color::Muted)
.size(LabelSize::Small),
)
.into_any_element()
},
move |window, cx| {
workspace
.update(cx, |workspace, cx| {
workspace
.open_abs_path(
paths::agents_file().clone(),
workspace::OpenOptions {
focus: Some(true),
..Default::default()
},
window,
cx,
)
.detach_and_log_err(cx);
open_global_rules(workspace, window, cx);
})
.log_err();
},
);
}
if let Some(path) = project_agents_md_path.clone() {
if project_agents_md_path.is_some() {
let workspace = workspace.clone();
menu = menu.entry(
"Open Project AGENTS.md",
None,
menu = menu.custom_entry(
|_window, _cx| {
h_flex()
.w_full()
.gap_1()
.child(Label::new("Open Project Rules"))
.child(
Label::new("(AGENTS.md)")
.color(Color::Muted)
.size(LabelSize::Small),
)
.into_any_element()
},
move |window, cx| {
let path = path.clone();
workspace
.update(cx, |workspace, cx| {
workspace
.open_abs_path(
path,
workspace::OpenOptions {
focus: Some(true),
..Default::default()
},
window,
cx,
)
.detach_and_log_err(cx);
open_project_rules(workspace, window, cx);
})
.log_err();
},
);
}
menu = menu.entry("Rules Library", None, |_window, cx| {
cx.open_url(&zed_urls::rules_docs(cx));
});
menu = menu.separator();
}
@ -6188,6 +6328,51 @@ mod tests {
use std::sync::Arc;
use std::time::Instant;
#[test]
fn test_is_known_terminal_agent_command() {
assert!(is_known_terminal_agent_command("claude"));
assert!(is_known_terminal_agent_command("codex"));
assert!(!is_known_terminal_agent_command("cargo"));
assert!(!is_known_terminal_agent_command("internal-agent"));
}
#[test]
fn test_terminal_program_reports_known_agent_transitions() {
let mut last_observed_program = None;
assert_eq!(
terminal_program_to_report(&mut last_observed_program, Some("codex".to_string())),
Some("codex".to_string())
);
assert_eq!(
terminal_program_to_report(&mut last_observed_program, Some("codex".to_string())),
None
);
assert_eq!(
terminal_program_to_report(&mut last_observed_program, Some("zsh".to_string())),
None
);
assert_eq!(
terminal_program_to_report(
&mut last_observed_program,
Some("customer-data-export".to_string())
),
None
);
assert_eq!(
terminal_program_to_report(&mut last_observed_program, Some("codex".to_string())),
Some("codex".to_string())
);
assert_eq!(
terminal_program_to_report(&mut last_observed_program, None),
None
);
assert_eq!(
terminal_program_to_report(&mut last_observed_program, Some("codex".to_string())),
Some("codex".to_string())
);
}
#[derive(Clone, Default)]
struct SessionTrackingConnection {
next_session_number: Arc<Mutex<usize>>,
@ -9696,6 +9881,7 @@ mod tests {
thinking_effort: None,
draft_prompt: None,
ui_scroll_position: None,
sandboxed_terminal_temp_dir: None,
};
let thread_store = cx.update(|cx| ThreadStore::global(cx));

View file

@ -688,7 +688,6 @@ fn update_command_palette_filter(cx: &mut App) {
TypeId::of::<AcceptEditPrediction>(),
TypeId::of::<AcceptNextWordEditPrediction>(),
TypeId::of::<AcceptNextLineEditPrediction>(),
TypeId::of::<AcceptEditPrediction>(),
TypeId::of::<ShowEditPrediction>(),
TypeId::of::<NextEditPrediction>(),
TypeId::of::<PreviousEditPrediction>(),
@ -910,6 +909,14 @@ mod tests {
!filter.is_hidden(&zed_actions::assistant::CreateSkillFromUrl),
"CreateSkillFromUrl should be visible by default"
);
assert!(
!filter.is_hidden(&zed_actions::assistant::OpenGlobalAgentsMdRules),
"OpenGlobalAgentsMdRules should be visible by default"
);
assert!(
!filter.is_hidden(&zed_actions::assistant::OpenProjectAgentsMdRules),
"OpenProjectAgentsMdRules should be visible by default"
);
});
// Disable agent
@ -933,6 +940,14 @@ mod tests {
filter.is_hidden(&NewTerminalThread),
"NewTerminalThread should be hidden when agent is disabled"
);
assert!(
filter.is_hidden(&zed_actions::assistant::OpenGlobalAgentsMdRules),
"OpenGlobalAgentsMdRules should be hidden when agent is disabled"
);
assert!(
filter.is_hidden(&zed_actions::assistant::OpenProjectAgentsMdRules),
"OpenProjectAgentsMdRules should be hidden when agent is disabled"
);
});
// Test EditPredictionProvider

View file

@ -8,7 +8,8 @@ use agent_client_protocol::schema as acp;
use std::cell::RefCell;
use acp_thread::{ContentBlock, PlanEntry};
use agent::{SkillLoadingError, SkillLoadingErrorsUpdated, UserAgentsMd};
use agent::{SkillLoadingError, SkillLoadingErrorsUpdated};
use agent_settings::UserAgentsMd;
use cloud_api_types::{SubmitAgentThreadFeedbackBody, SubmitAgentThreadFeedbackCommentsBody};
use editor::actions::OpenExcerpts;
use feature_flags::AcpBetaFeatureFlag;
@ -9463,17 +9464,15 @@ impl ThreadView {
.map(|name| name.to_string_lossy().to_string())
.unwrap_or_else(|| "one folder".to_string());
let description = format!(
"This agent only operates on \"{}\". Other folders in this workspace are not accessible to it.",
active_dir
);
Some(
Callout::new()
.severity(Severity::Warning)
.icon(IconName::Warning)
.title("External Agents currently don't support multi-root workspaces")
.description(description)
.title("This agent doesn't currently support multi-root workspaces")
.description(format!(
"It currently only operates by default on \"{}\".",
active_dir
))
.border_position(ui::BorderPosition::Bottom)
.dismiss_action(
IconButton::new("dismiss-multi-root-callout", IconName::Close)

View file

@ -1831,6 +1831,7 @@ mod tests {
thinking_effort: None,
draft_prompt: None,
ui_scroll_position: None,
sandboxed_terminal_temp_dir: None,
}
}

View file

@ -333,6 +333,9 @@ impl Status {
struct ClientState {
credentials: Option<Credentials>,
status: (watch::Sender<Status>, watch::Receiver<Status>),
/// Bumped each time the cloud websocket finishes its handshake. Starts at `0` so
/// subscribers can distinguish "no connection yet" from a real reconnect.
cloud_connection_id: (watch::Sender<u64>, watch::Receiver<u64>),
_reconnect_task: Option<Task<()>>,
_cloud_connection_task: Option<Task<()>>,
}
@ -435,6 +438,7 @@ impl Default for ClientState {
Self {
credentials: None,
status: watch::channel_with(Status::SignedOut),
cloud_connection_id: watch::channel_with(0),
_reconnect_task: None,
_cloud_connection_task: None,
}
@ -668,6 +672,14 @@ impl Client {
self.state.read().status.1.clone()
}
/// Watches successful cloud websocket reconnections.
///
/// The value is bumped each time the websocket handshake completes. The
/// initial `0` means no reconnection yet.
pub fn cloud_connection_id(&self) -> watch::Receiver<u64> {
self.state.read().cloud_connection_id.1.clone()
}
fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncApp) {
log::info!("set status on client {}: {:?}", self.id(), status);
let mut state = self.state.write();
@ -1006,6 +1018,12 @@ impl Client {
let (mut messages, _cloud_io_task) = cx.update(|cx| connection.spawn(cx));
{
let mut state = self.state.write();
let mut cloud_connection_id = state.cloud_connection_id.0.borrow_mut();
*cloud_connection_id = cloud_connection_id.saturating_add(1);
}
while let Some(message) = messages.next().await {
if let Some(message) = message.log_err() {
self.handle_message_to_client(message, cx);

View file

@ -56,6 +56,10 @@ pub fn skills_docs(cx: &App) -> String {
format!("{server_url}/docs/ai/skills", server_url = server_url(cx))
}
pub fn rules_docs(cx: &App) -> String {
format!("{server_url}/docs/ai/rules", server_url = server_url(cx))
}
/// Returns the URL to Zed's ACP registry blog post.
pub fn acp_registry_blog(cx: &App) -> String {
format!(

View file

@ -695,26 +695,69 @@ impl PickerDelegate for CommandPaletteDelegate {
}
pub fn humanize_action_name(name: &str) -> String {
let capacity = name.len() + name.chars().filter(|c| c.is_uppercase()).count();
let chars = name.chars().collect::<Vec<_>>();
let capacity = name.len() + chars.iter().filter(|c| c.is_uppercase()).count();
let mut result = String::with_capacity(capacity);
for char in name.chars() {
let mut index = 0;
while index < chars.len() {
let char = chars[index];
if char == ':' {
if result.ends_with(':') {
result.push(' ');
} else {
result.push(':');
}
index += 1;
} else if char == '_' {
result.push(' ');
index += 1;
} else if char.is_uppercase() {
if !result.ends_with(' ') {
result.push(' ');
let start = index;
index += 1;
while chars
.get(index)
.is_some_and(|next_char| next_char.is_uppercase())
{
index += 1;
}
let uppercase_run = &chars[start..index];
if uppercase_run.len() > 1 {
let split_before_last = chars
.get(index)
.is_some_and(|next_char| next_char.is_lowercase());
let acronym_end = if split_before_last {
uppercase_run.len() - 1
} else {
uppercase_run.len()
};
if acronym_end > 0 {
if !result.ends_with(' ') {
result.push(' ');
}
result.extend(&uppercase_run[..acronym_end]);
}
if split_before_last {
if !result.ends_with(' ') {
result.push(' ');
}
result.extend(uppercase_run[acronym_end].to_lowercase());
}
} else {
if !result.ends_with(' ') {
result.push(' ');
}
result.extend(char.to_lowercase());
}
result.extend(char.to_lowercase());
} else {
result.push(char);
index += 1;
}
}
result
}
@ -753,6 +796,19 @@ mod tests {
humanize_action_name("go_to_line::Deploy"),
"go to line: deploy"
);
assert_eq!(
humanize_action_name("agent::OpenGlobalAGENTS.mdRules"),
"agent: open global AGENTS.md rules"
);
assert_eq!(
humanize_action_name("agent::OpenProjectAGENTS.mdRules"),
"agent: open project AGENTS.md rules"
);
assert_eq!(humanize_action_name("editor::OpenURL"), "editor: open URL");
assert_eq!(
humanize_action_name("editor::OpenURLParser"),
"editor: open URL parser"
);
}
#[test]

View file

@ -31,7 +31,7 @@ credentials_provider.workspace = true
db.workspace = true
edit_prediction_types.workspace = true
edit_prediction_context.workspace = true
edit_prediction_metrics.workspace = true
edit_prediction_metrics = { workspace = true, features = ["tree-sitter"] }
feature_flags.workspace = true
fs.workspace = true
futures.workspace = true

View file

@ -61,7 +61,7 @@ terminal_view.workspace = true
util.workspace = true
watch.workspace = true
edit_prediction = { workspace = true, features = ["cli-support"] }
edit_prediction_metrics.workspace = true
edit_prediction_metrics = { workspace = true, features = ["tree-sitter"] }
telemetry_events.workspace = true
wasmtime.workspace = true
zeta_prompt.workspace = true

View file

@ -11,12 +11,15 @@ workspace = true
[lib]
path = "src/edit_prediction_metrics.rs"
[features]
tree-sitter = ["dep:tree-sitter"]
[dependencies]
imara-diff.workspace = true
serde.workspace = true
serde_json = "1.0"
similar = "2.7.0"
tree-sitter.workspace = true
tree-sitter = { workspace = true, optional = true }
zeta_prompt.workspace = true
[dev-dependencies]

View file

@ -4,6 +4,7 @@ mod prediction_score;
mod reversal;
mod summary;
mod tokenize;
#[cfg(feature = "tree-sitter")]
mod tree_sitter;
pub use kept_rate::AnnotatedToken;
@ -30,4 +31,5 @@ pub use prediction_score::{
};
pub use reversal::compute_prediction_reversal_ratio_from_history;
pub use summary::{PredictionSummaryInput, QaSummaryData, SummaryJson, compute_summary};
#[cfg(feature = "tree-sitter")]
pub use tree_sitter::count_tree_sitter_errors;

View file

@ -1088,6 +1088,7 @@ impl InfoPopover {
.track_scroll(&self.scroll_handle)
.child(
MarkdownElement::new(markdown, hover_markdown_style(window, cx))
.scroll_handle(self.scroll_handle.clone())
.code_block_renderer(markdown::CodeBlockRenderer::Default {
copy_button_visibility: CopyButtonVisibility::Hidden,
wrap_button_visibility: markdown::WrapButtonVisibility::Hidden,

View file

@ -34,6 +34,12 @@ impl Editor {
cx.emit(EditorEvent::InputIgnored { text: text.into() });
return;
}
cx.emit(EditorEvent::InputHandled {
utf16_range_to_replace: relative_utf16_range.clone(),
text: text.into(),
});
if let Some(relative_utf16_range) = relative_utf16_range {
let selections = self
.selections

View file

@ -560,13 +560,7 @@ async fn run_agent(
let agent = cx.update(|cx| {
let thread_store = cx.new(|cx| ThreadStore::new(cx));
NativeAgent::new(
thread_store,
Templates::new(),
None,
app_state.fs.clone(),
cx,
)
NativeAgent::new(thread_store, Templates::new(), app_state.fs.clone(), cx)
});
let connection = Rc::new(NativeAgentConnection(agent.clone()));

View file

@ -215,6 +215,10 @@ pub trait FeatureFlagAppExt {
fn flag_value<T: FeatureFlag>(&self) -> T::Value;
fn is_staff(&self) -> bool;
/// Whether feature flag overrides from settings are honored for the
/// current user. Overrides are a staff-only affordance.
fn feature_flag_overrides_enabled(&self) -> bool;
fn on_flags_ready<F>(&mut self, callback: F) -> Subscription
where
F: FnMut(OnFlagsReady, &mut App) + 'static;
@ -253,6 +257,11 @@ impl FeatureFlagAppExt for App {
.unwrap_or(false)
}
fn feature_flag_overrides_enabled(&self) -> bool {
self.try_global::<FeatureFlagStore>()
.map_or(false, |store| store.overrides_enabled())
}
fn on_flags_ready<F>(&mut self, mut callback: F) -> Subscription
where
F: FnMut(OnFlagsReady, &mut App) + 'static,

View file

@ -135,3 +135,18 @@ impl FeatureFlag for AutoWatchFeatureFlag {
type Value = PresenceFlag;
}
register_feature_flag!(AutoWatchFeatureFlag);
/// Wraps agent-run terminal commands in an OS-level sandbox where supported
/// (currently macOS Seatbelt only). When off, terminal commands run with the
/// agent's full ambient permissions, as they always have.
pub struct SandboxingFeatureFlag;
impl FeatureFlag for SandboxingFeatureFlag {
const NAME: &'static str = "sandboxing";
type Value = PresenceFlag;
fn enabled_for_staff() -> bool {
false
}
}
register_feature_flag!(SandboxingFeatureFlag);

View file

@ -96,6 +96,16 @@ impl FeatureFlagStore {
self.staff
}
/// Whether feature flag overrides from settings should be honored.
///
/// Overrides are a staff-only affordance, so non-staff users in release
/// builds can't flip flags through `settings.json` or the settings UI.
/// Debug builds are always treated as staff, and `ZED_DISABLE_STAFF`
/// forces the user to be treated as non-staff for testing.
pub fn overrides_enabled(&self) -> bool {
(cfg!(debug_assertions) || self.staff) && !*ZED_DISABLE_STAFF
}
pub fn server_flags_received(&self) -> bool {
self.server_flags_received
}
@ -158,8 +168,12 @@ impl FeatureFlagStore {
return Some(T::Value::on_variant());
}
if let Some(override_key) = FeatureFlagsSettings::get_global(cx).overrides.get(T::NAME) {
return variant_from_key::<T::Value>(override_key);
// Only apply overrides when they are specifically enabled.
if self.overrides_enabled() {
if let Some(override_key) = FeatureFlagsSettings::get_global(cx).overrides.get(T::NAME)
{
return variant_from_key::<T::Value>(override_key);
}
}
// Staff default: resolve to the enabled variant.
@ -194,15 +208,18 @@ impl FeatureFlagStore {
return on_variant_key;
}
if let Some(requested) = FeatureFlagsSettings::get_global(cx)
.overrides
.get(descriptor.name)
{
if let Some(variant) = (descriptor.variants)()
.into_iter()
.find(|v| v.override_key == requested.as_str())
// Only apply overrides when they are specifically enabled.
if self.overrides_enabled() {
if let Some(requested) = FeatureFlagsSettings::get_global(cx)
.overrides
.get(descriptor.name)
{
return variant.override_key;
if let Some(variant) = (descriptor.variants)()
.into_iter()
.find(|v| v.override_key == requested.as_str())
{
return variant.override_key;
}
}
}

View file

@ -6,7 +6,7 @@ use std::{
ops::DerefMut,
path::Path,
sync::{Arc, LazyLock, OnceLock},
time::Duration,
time::{Duration, Instant},
};
use util::{ResultExt, paths::SanitizedPath};
@ -50,11 +50,13 @@ impl FsWatcher {
fn add_existing_path(&self, path: Arc<Path>) -> anyhow::Result<()> {
let registration_path = path.clone();
let registration =
register_existing_path(path, self.tx.clone(), self.pending_path_events.clone())?;
self.registrations
.lock()
.insert(registration_path, registration);
if let Some(registration) =
register_existing_path(path, self.tx.clone(), self.pending_path_events.clone())?
{
self.registrations
.lock()
.insert(registration_path, registration);
}
Ok(())
}
@ -182,7 +184,7 @@ fn register_existing_path(
path: Arc<Path>,
tx: async_channel::Sender<()>,
pending_path_events: Arc<Mutex<Vec<PathEvent>>>,
) -> anyhow::Result<FsWatcherRegistration> {
) -> anyhow::Result<Option<FsWatcherRegistration>> {
let mode = if requires_poll_watcher(path.as_ref()) {
log::info!(
"Using poll watcher ({}ms interval) for {}",
@ -196,20 +198,24 @@ fn register_existing_path(
};
let root_path = SanitizedPath::new_arc(path.as_ref());
let path_for_callback = path.clone();
let registration_id = global_watcher().add(path, mode, move |event: &notify::Event| {
log::trace!("watcher received event: {event:?}");
push_notify_event(
&tx,
&pending_path_events,
&root_path,
path_for_callback.as_ref(),
event,
);
})?;
Ok(FsWatcherRegistration {
let Some(registration_id) =
global_watcher().add(path, mode, move |event: &notify::Event| {
log::trace!("watcher received event: {event:?}");
push_notify_event(
&tx,
&pending_path_events,
&root_path,
path_for_callback.as_ref(),
event,
);
})?
else {
return Ok(None);
};
Ok(Some(FsWatcherRegistration {
id: registration_id,
mode,
})
}))
}
#[cfg(target_os = "linux")]
@ -345,7 +351,7 @@ async fn poll_path_until_created(
}
match register_existing_path(path.clone(), tx.clone(), pending_path_events.clone()) {
Ok(registration) => {
Ok(Some(registration)) => {
{
let mut pending_registrations = pending_registrations.lock();
if pending_registrations.remove(path.as_ref()).is_none() {
@ -370,6 +376,7 @@ async fn poll_path_until_created(
);
return;
}
Ok(None) => {}
Err(error) => {
log::warn!("failed to watch newly-created path {path:?}: {error}; retrying");
}
@ -501,10 +508,16 @@ struct WatcherState {
watchers: HashMap<WatcherRegistrationId, WatcherRegistrationState>,
native_path_registrations: HashMap<Arc<std::path::Path>, PathRegistrationState>,
poll_path_registrations: HashMap<Arc<std::path::Path>, PathRegistrationState>,
cooldown_until: Option<Instant>,
last_registration: WatcherRegistrationId,
}
impl WatcherState {
fn is_native_watch_limit_cooldown_active(&self) -> bool {
self.cooldown_until
.is_some_and(|cooldown_until| cooldown_until > Instant::now())
}
fn path_registrations(
&mut self,
mode: WatcherMode,
@ -565,15 +578,30 @@ impl GlobalWatcher {
path: Arc<std::path::Path>,
mode: WatcherMode,
cb: impl Fn(&notify::Event) + Send + Sync + 'static,
) -> anyhow::Result<WatcherRegistrationId> {
) -> anyhow::Result<Option<WatcherRegistrationId>> {
let mut state = self.state.lock();
let registrations_for_mode = state.path_registrations(mode);
let path_already_covered =
path_already_covered(path.as_ref(), registrations_for_mode, mode);
let (path_already_covered, path_already_registered) = {
let registrations_for_mode = state.path_registrations(mode);
(
path_already_covered(path.as_ref(), registrations_for_mode, mode),
registrations_for_mode.contains_key(&path),
)
};
if !path_already_covered && !path_already_registered {
if mode == WatcherMode::Native && state.is_native_watch_limit_cooldown_active() {
return Ok(None);
}
if !path_already_covered && !registrations_for_mode.contains_key(&path) {
drop(state);
self.watch(&path, mode)?;
match self.watch(&path, mode) {
Ok(()) => {}
Err(error) if mode == WatcherMode::Native && is_max_files_watch_error(&error) => {
self.start_native_watch_limit_cooldown(&path);
return Ok(None);
}
Err(error) => return Err(error),
}
state = self.state.lock();
}
@ -595,7 +623,20 @@ impl GlobalWatcher {
has_os_watcher: !path_already_covered,
});
Ok(id)
Ok(Some(id))
}
fn start_native_watch_limit_cooldown(&self, path: &Path) {
let mut state = self.state.lock();
let now = Instant::now();
let should_log = !state.is_native_watch_limit_cooldown_active();
state.cooldown_until = Some(now + *NATIVE_WATCH_LIMIT_COOLDOWN);
if should_log {
log::warn!(
"OS file watch limit reached while watching {path:?}; skipping new native file watcher registrations for {} seconds",
NATIVE_WATCH_LIMIT_COOLDOWN.as_secs()
);
}
}
pub fn remove(&self, id: WatcherRegistrationId) {
@ -688,6 +729,12 @@ fn path_already_covered(
.any(|ancestor| path_registrations.contains_key(ancestor))
}
fn is_max_files_watch_error(error: &anyhow::Error) -> bool {
error
.downcast_ref::<notify::Error>()
.is_some_and(|error| matches!(&error.kind, notify::ErrorKind::MaxFilesWatch))
}
static POLL_INTERVAL: LazyLock<Duration> = LazyLock::new(|| {
let poll_ms: u64 = std::env::var("ZED_FILE_WATCHER_POLL_MS")
.ok()
@ -697,6 +744,15 @@ static POLL_INTERVAL: LazyLock<Duration> = LazyLock::new(|| {
Duration::from_millis(poll_ms)
});
static NATIVE_WATCH_LIMIT_COOLDOWN: LazyLock<Duration> = LazyLock::new(|| {
let cooldown_seconds: u64 = std::env::var("ZED_NATIVE_WATCH_LIMIT_COOLDOWN_SECONDS")
.ok()
.and_then(|value| value.parse().ok())
.unwrap_or(5)
.clamp(0, 300);
Duration::from_secs(cooldown_seconds)
});
pub fn poll_interval() -> Duration {
*POLL_INTERVAL
}
@ -709,6 +765,7 @@ fn global_watcher() -> &'static GlobalWatcher {
watchers: Default::default(),
native_path_registrations: Default::default(),
poll_path_registrations: Default::default(),
cooldown_until: None,
last_registration: Default::default(),
}),
native_watcher: Mutex::new(None),
@ -789,6 +846,7 @@ mod tests {
watched_paths: HashSet<PathBuf>,
watch_calls: Vec<PathBuf>,
unwatch_calls: Vec<PathBuf>,
fail_with_watch_limit: bool,
}
struct SharedFakeWatchBackend(Arc<Mutex<FakeWatchBackend>>);
@ -798,6 +856,9 @@ mod tests {
let path = path.to_path_buf();
let mut backend = self.0.lock();
backend.watch_calls.push(path.clone());
if backend.fail_with_watch_limit {
return Err(notify::Error::new(notify::ErrorKind::MaxFilesWatch));
}
backend.watched_paths.insert(path);
Ok(())
}
@ -815,15 +876,31 @@ mod tests {
}
fn test_watcher(poll_watcher: Arc<Mutex<FakeWatchBackend>>) -> GlobalWatcher {
test_watcher_with_backends(None, Some(poll_watcher))
}
fn test_watcher_with_backends(
native_watcher: Option<Arc<Mutex<FakeWatchBackend>>>,
poll_watcher: Option<Arc<Mutex<FakeWatchBackend>>>,
) -> GlobalWatcher {
GlobalWatcher {
state: Mutex::new(WatcherState {
watchers: Default::default(),
native_path_registrations: Default::default(),
poll_path_registrations: Default::default(),
cooldown_until: None,
last_registration: Default::default(),
}),
native_watcher: Mutex::new(None),
poll_watcher: Mutex::new(Some(Box::new(SharedFakeWatchBackend(poll_watcher)))),
native_watcher: Mutex::new(
native_watcher.map(|watcher| {
Box::new(SharedFakeWatchBackend(watcher)) as Box<dyn WatchBackend>
}),
),
poll_watcher: Mutex::new(
poll_watcher.map(|watcher| {
Box::new(SharedFakeWatchBackend(watcher)) as Box<dyn WatchBackend>
}),
),
}
}
@ -844,10 +921,12 @@ mod tests {
let parent_registration = watcher
.add(parent.as_ref().into(), WatcherMode::Poll, |_| {})
.expect("add parent watch");
.expect("add parent watch")
.expect("parent watch registered");
let child_registration = watcher
.add(child.as_ref().into(), WatcherMode::Poll, |_| {})
.expect("add covered child watch");
.expect("add covered child watch")
.expect("child watch registered");
watcher.remove(parent_registration);
watcher.remove(child_registration);
@ -857,6 +936,31 @@ mod tests {
assert_eq!(backend.unwatch_calls, &[parent.to_path_buf()]);
}
#[test]
fn native_watch_limit_cools_down_subsequent_native_registrations() {
let native_backend = Arc::new(Mutex::new(FakeWatchBackend {
fail_with_watch_limit: true,
..Default::default()
}));
let poll_backend = Arc::new(Mutex::new(FakeWatchBackend::default()));
let watcher = test_watcher_with_backends(Some(native_backend.clone()), Some(poll_backend));
let first_path = Arc::<Path>::from(Path::new("/repo/first"));
let second_path = Arc::<Path>::from(Path::new("/repo/second"));
let first_registration = watcher
.add(first_path.clone(), WatcherMode::Native, |_| {})
.expect("native watch limit is handled");
let second_registration = watcher
.add(second_path, WatcherMode::Native, |_| {})
.expect("native watch limit backoff is handled");
assert!(first_registration.is_none());
assert!(second_registration.is_none());
let native_backend = native_backend.lock();
assert_eq!(native_backend.watch_calls, &[first_path.to_path_buf()]);
}
#[test]
fn test_coalesce_pending_rescans() {
let test_cases = [

View file

@ -179,9 +179,12 @@ impl FromStr for Oid {
type Err = anyhow::Error;
fn from_str(s: &str) -> std::prelude::v1::Result<Self, Self::Err> {
libgit::Oid::from_str(s)
.context("parsing git oid")
.map(Self)
let oid = if s.len() == 64 {
libgit::Oid::from_str_ext(s, libgit::ObjectFormat::Sha256)
} else {
libgit::Oid::from_str_ext(s, libgit::ObjectFormat::Sha1)
};
oid.context("parsing git oid").map(Self)
}
}
@ -218,7 +221,7 @@ impl<'de> Deserialize<'de> for Oid {
impl Default for Oid {
fn default() -> Self {
Self(libgit::Oid::zero())
Self(libgit::Oid::ZERO_SHA1)
}
}

View file

@ -1672,7 +1672,7 @@ impl GitRepository for RealGitRepository {
.spawn(async move {
let repo = repo.lock();
let remote = repo.find_remote(&name).ok()?;
remote.url().map(|url| url.to_string())
remote.url().ok().map(|url| url.to_string())
})
.boxed()
}

View file

@ -238,7 +238,11 @@ impl BlameRenderer for GitBlameRenderer {
let message = details
.as_ref()
.map(|_| MarkdownElement::new(markdown.clone(), markdown_style).into_any())
.map(|_| {
MarkdownElement::new(markdown.clone(), markdown_style)
.scroll_handle(scroll_handle.clone())
.into_any()
})
.unwrap_or("<no commit message>".into_any());
let pull_request = details

View file

@ -258,7 +258,11 @@ impl Render for CommitTooltip {
.commit
.message
.as_ref()
.map(|_| MarkdownElement::new(self.markdown.clone(), markdown_style).into_any())
.map(|_| {
MarkdownElement::new(self.markdown.clone(), markdown_style)
.scroll_handle(self.scroll_handle.clone())
.into_any()
})
.unwrap_or("<no commit message>".into_any());
let pull_request = self

View file

@ -9,7 +9,7 @@ use crate::{branch_picker, picker_prompt, render_remote_button};
use crate::{
git_panel_settings::GitPanelSettings, git_status_icon, repository_selector::RepositorySelector,
};
use agent_settings::AgentSettings;
use agent_settings::{AgentSettings, UserAgentsMd};
use alacritty_terminal::vte::ansi;
use anyhow::Context as _;
use askpass::AskPassDelegate;
@ -2699,6 +2699,40 @@ impl GitPanel {
.unwrap_or_else(|| BuiltInPrompt::CommitMessage.default_content().to_string())
}
fn build_commit_message_prompt(
prompt: &str,
user_agents_md: Option<&str>,
rules_content: Option<&str>,
subject: &str,
diff_text: &str,
) -> String {
let user_agents_md_section = match user_agents_md {
Some(user_agents_md) => format!(
"\n\nThe user has provided the following rules that you should follow when writing the commit message. Project-specific rules may override these instructions when they conflict:\n\
<rules>\n{user_agents_md}\n</rules>\n"
),
None => String::new(),
};
let rules_section = match rules_content {
Some(rules) => format!(
"\n\nThe user has provided the following rules specific to this project that you should follow when writing the commit message:\n\
<project_rules>\n{rules}\n</project_rules>\n"
),
None => String::new(),
};
let subject_section = if subject.trim().is_empty() {
String::new()
} else {
format!("\nHere is the user's subject line:\n{subject}")
};
format!(
"{prompt}{user_agents_md_section}{rules_section}{subject_section}\nHere are the changes in this commit:\n{diff_text}"
)
}
/// Generates a commit message using an LLM.
pub fn generate_commit_message(&mut self, cx: &mut Context<Self>) {
if !self.can_commit() || !AgentSettings::get_global(cx).enabled(cx) {
@ -2730,7 +2764,7 @@ impl GitPanel {
let repo_work_dir = repo.read(cx).work_directory_abs_path.clone();
self.generate_commit_message_task = Some(cx.spawn(async move |this, mut cx| {
async move {
async move {
let _defer = cx.on_drop(&this, |this, _cx| {
this.generate_commit_message_task.take();
});
@ -2762,32 +2796,33 @@ impl GitPanel {
const MAX_DIFF_BYTES: usize = 20_000;
diff_text = Self::compress_commit_diff(&diff_text, MAX_DIFF_BYTES);
let rules_content = Self::load_project_rules(&project, &repo_work_dir, &mut cx).await;
let rules_content =
Self::load_project_rules(&project, &repo_work_dir, &mut cx).await;
let user_agents_md = cx.update(|cx| {
UserAgentsMd::global(cx)
.and_then(|user_agents_md| user_agents_md.content().cloned())
});
let prompt = Self::load_commit_message_prompt(&mut cx).await;
let subject = this.update(cx, |this, cx| {
this.commit_editor.read(cx).text(cx).lines().next().map(ToOwned::to_owned).unwrap_or_default()
this.commit_editor
.read(cx)
.text(cx)
.lines()
.next()
.map(ToOwned::to_owned)
.unwrap_or_default()
})?;
let text_empty = subject.trim().is_empty();
let rules_section = match &rules_content {
Some(rules) => format!(
"\n\nThe user has provided the following project rules that you should follow when writing the commit message:\n\
<project_rules>\n{rules}\n</project_rules>\n"
),
None => String::new(),
};
let subject_section = if text_empty {
String::new()
} else {
format!("\nHere is the user's subject line:\n{subject}")
};
let content = format!(
"{prompt}{rules_section}{subject_section}\nHere are the changes in this commit:\n{diff_text}"
let content = Self::build_commit_message_prompt(
&prompt,
user_agents_md.as_deref(),
rules_content.as_deref(),
&subject,
&diff_text,
);
let request = LanguageModelRequest {
@ -2816,7 +2851,11 @@ impl GitPanel {
this.update(cx, |this, cx| {
this.commit_message_buffer(cx).update(cx, |buffer, cx| {
let insert_position = buffer.anchor_before(buffer.len());
buffer.edit([(insert_position..insert_position, "\n")], None, cx)
buffer.edit(
[(insert_position..insert_position, "\n")],
None,
cx,
)
});
})?;
}
@ -2826,8 +2865,13 @@ impl GitPanel {
Ok(text) => {
this.update(cx, |this, cx| {
this.commit_message_buffer(cx).update(cx, |buffer, cx| {
let insert_position = buffer.anchor_before(buffer.len());
buffer.edit([(insert_position..insert_position, text)], None, cx);
let insert_position =
buffer.anchor_before(buffer.len());
buffer.edit(
[(insert_position..insert_position, text)],
None,
cx,
);
});
})?;
}
@ -2845,7 +2889,8 @@ impl GitPanel {
anyhow::Ok(())
}
.log_err().await
.log_err()
.await
}));
}
@ -8726,6 +8771,26 @@ mod tests {
assert_eq!(result, expected);
}
#[test]
fn test_commit_message_prompt_includes_user_agents_md_before_project_rules() {
let prompt = GitPanel::build_commit_message_prompt(
"Write a commit message.",
Some("Use terse commit messages."),
Some("Use the git_ui prefix."),
"Update generated message",
"diff --git a/file b/file",
);
assert!(prompt.contains("Use terse commit messages."));
assert!(prompt.contains("Use the git_ui prefix."));
assert!(prompt.contains("Update generated message"));
assert!(prompt.contains("diff --git a/file b/file"));
let user_agents_md_index = prompt.find("<rules>").unwrap();
let project_rules_index = prompt.find("<project_rules>").unwrap();
assert!(user_agents_md_index < project_rules_index);
}
#[gpui::test]
async fn test_suggest_commit_message(cx: &mut TestAppContext) {
init_test(cx);

View file

@ -45,6 +45,7 @@ path = "src/gpui.rs"
doctest = false
[dependencies]
accesskit.workspace = true
anyhow.workspace = true
async-task = "4.7"
backtrace = { workspace = true, optional = true }
@ -175,6 +176,8 @@ cbindgen = { version = "0.28.0", default-features = false }
[[example]]
name = "hello_world"
path = "examples/hello_world.rs"
@ -250,3 +253,7 @@ path = "examples/list_example.rs"
[[example]]
name = "mouse_pressure"
path = "examples/mouse_pressure.rs"
[[example]]
name = "a11y"
path = "examples/a11y.rs"

View file

@ -12,6 +12,7 @@ gpui = { version = "*" }
```
- [Ownership and data flow](_ownership_and_data_flow)
- [Accessibility](_accessibility)
Everything in GPUI starts with an `Application`. You can create one with `Application::new()`, and kick off your application by passing a callback to `Application::run()`. Inside this callback, you can create a new window with `App::open_window()`, and register your first root view. See [gpui.rs](https://www.gpui.rs/) for a complete example.

View file

@ -0,0 +1,264 @@
//! Accessibility (AccessKit) demo app.
//!
//! Run with: `cargo run -p gpui --example a11y`
//!
//! Or on Linux: `cargo run -p gpui --features gpui_platform/wayland,gpui_platform/x11 --example a11y`
//!
//! This app uses GPUI's accessibility APIs to attach structured information to
//! the element tree, which allows assistive technology to see and interact with
//! the UI programmatically.
//!
//! The app behaves as follows:
//! - It opens a single window.
//! - The window's title is "GPUI Accessibility Demo".
//! - The window has a sequence of UI elements, stacked vertically:
//! - A heading with the text "Accessibility Demo".
//! - A row containing two elements:
//! - A spin button (role `SpinButton`) labelled "Counter: <n>", where
//! `<n>` is the current count. It supports `Increment` and `Decrement`
//! accessible actions, and also increments on click. The numeric value
//! is clamped to a minimum of 0.
//! - A button labelled "Reset counter" that resets the count to 0.
//! - A row containing two elements:
//! - A switch, that can be toggled, and starts disabled. Toggling the switch
//! does nothing.
//! - The text "Enable feature".
//! - A "to-do" list, with three items, each represented with a `Text` element:
//! - "1. Write code"
//! - "2. Run tests"
//! - "3. Ship it"
use gpui::{
AccessibleAction, App, Bounds, Context, FocusHandle, KeyBinding, Role, SharedString, Toggled,
Window, WindowBounds, WindowOptions, actions, div, prelude::*, px, rgb, size, text,
};
use gpui_platform::application;
actions!(a11y_example, [Tab, TabPrev]);
struct A11yDemo {
focus_handle: FocusHandle,
count: i32,
enabled: bool,
}
impl A11yDemo {
fn new(window: &mut Window, cx: &mut Context<Self>) -> Self {
let focus_handle = cx.focus_handle();
window.focus(&focus_handle, cx);
Self {
focus_handle,
count: 0,
enabled: false,
}
}
}
impl Render for A11yDemo {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
div()
.id("root")
.role(Role::Application)
.aria_label("Accessibility Demo")
.track_focus(&self.focus_handle)
.on_action(cx.listener(|_, _: &Tab, window, cx| window.focus_next(cx)))
.on_action(cx.listener(|_, _: &TabPrev, window, cx| window.focus_prev(cx)))
.size_full()
.flex()
.flex_col()
.gap_4()
.p_4()
.bg(rgb(0x1e1e2e))
.text_color(rgb(0xcdd6f4))
// Heading
.child(
div()
.id("heading")
.role(Role::Heading)
.aria_level(1)
.aria_label("Accessibility Demo")
.text_xl()
.font_weight(gpui::FontWeight::BOLD)
.child(text!("Accessibility Demo")),
)
// Counter — uses a SpinButton role with Increment/Decrement
// actions so screen readers can adjust the value directly.
// Click also works via the built-in handler.
.child(
div()
.flex()
.items_center()
.gap_3()
.child(
div()
.id("counter")
.focusable()
.tab_stop(true)
.role(Role::SpinButton)
.aria_label(SharedString::from(format!("Counter: {}", self.count)))
.aria_numeric_value(self.count as f64)
.aria_min_numeric_value(0.0)
.on_a11y_action(AccessibleAction::Increment, {
let this = cx.entity().downgrade();
move |_, _, cx| {
this.update(cx, |this, cx| {
this.count += 1;
cx.notify();
})
.ok();
}
})
.on_a11y_action(AccessibleAction::Decrement, {
let this = cx.entity().downgrade();
move |_, _, cx| {
this.update(cx, |this, cx| {
this.count = (this.count - 1).max(0);
cx.notify();
})
.ok();
}
})
.on_click(cx.listener(|this, _, _, cx| {
this.count += 1;
cx.notify();
}))
.px_3()
.py_1()
.rounded_md()
.bg(rgb(0x89b4fa))
.text_color(rgb(0x1e1e2e))
.cursor_pointer()
.child(text!(format!("Count: {}", self.count))),
)
.child(
div()
.id("reset")
.focusable()
.tab_stop(true)
.role(Role::Button)
.aria_label("Reset counter")
.px_3()
.py_1()
.rounded_md()
.bg(rgb(0x585b70))
.cursor_pointer()
.on_click(cx.listener(|this, _, _, cx| {
this.count = 0;
cx.notify();
}))
.child(text!("Reset")),
),
)
// A toggle switch
.child(
div()
.flex()
.items_center()
.gap_2()
.child(
div()
.id("toggle")
.focusable()
.tab_stop(true)
.role(Role::Switch)
.aria_label("Enable feature")
.aria_toggled(if self.enabled {
Toggled::True
} else {
Toggled::False
})
.w(px(44.))
.h(px(24.))
.rounded_full()
.cursor_pointer()
.when(self.enabled, |el| el.bg(rgb(0x89b4fa)))
.when(!self.enabled, |el| el.bg(rgb(0x585b70)))
.child(
div()
.size(px(20.))
.rounded_full()
.bg(gpui::white())
.mt(px(2.))
.when(self.enabled, |el| el.ml(px(22.)))
.when(!self.enabled, |el| el.ml(px(2.))),
)
.on_click(cx.listener(|this, _, _, cx| {
this.enabled = !this.enabled;
cx.notify();
})),
)
.child(text!("Enable feature")),
)
// A short list
.child(
div()
.id("task-list")
.role(Role::List)
.aria_label("Tasks")
.flex()
.flex_col()
.gap_1()
.children(
["Write code", "Run tests", "Ship it"]
.iter()
.enumerate()
.map(|(i, label)| {
div()
.id(("task", i))
.role(Role::ListItem)
.aria_label(SharedString::from(*label))
.aria_position_in_set(i + 1)
.aria_size_of_set(3)
.py_1()
.px_2()
// Note: even though this `text!` macro
// produces multiple elements, it doesn't
// need its own unique ID because the parent
// div has different IDs for each string.
.child(text!(format!("{}. {}", i + 1, label)))
}),
),
)
}
}
fn run_example() {
application().run(|cx: &mut App| {
cx.bind_keys([
KeyBinding::new("tab", Tab, None),
KeyBinding::new("shift-tab", TabPrev, None),
]);
let bounds = Bounds::centered(None, size(px(500.), px(400.0)), cx);
cx.open_window(
WindowOptions {
window_bounds: Some(WindowBounds::Windowed(bounds)),
titlebar: Some(gpui::TitlebarOptions {
title: Some("GPUI Accessibility Demo".into()),
..Default::default()
}),
..Default::default()
},
|window, cx| cx.new(|cx| A11yDemo::new(window, cx)),
)
.unwrap();
cx.activate(true);
});
}
#[cfg(not(target_family = "wasm"))]
fn main() {
env_logger::builder()
.filter_level(log::LevelFilter::Warn)
.filter_module("gpui", log::LevelFilter::Info)
.init();
run_example();
}
#[cfg(target_family = "wasm")]
#[wasm_bindgen::prelude::wasm_bindgen(start)]
pub fn start() {
gpui_platform::web_init();
run_example();
}

View file

@ -0,0 +1,243 @@
//! # Accessibility in GPUI
//!
//! "Accessibility" refers to the ability of your application to be used by all
//! users, regardless of disability status. There are many aspects, all important, including:
//! - Ensuring sufficient text contrast.
//! - Providing a mechanism to disable animations.
//! - Providing a mechanism to increase text sizes.
//! - etc.
//!
//! This guide is focused on **programmatic accessibility**. This allows
//! assistive technology, such as screen readers or Braille displays, to inspect
//! and interact with your app.
//!
//! GPUI integrates with [AccessKit] to provide programmatic accessibility
//! features (referred to as simply "accessibility" for the rest of this guide).
//!
//! A minimal example can be found in the `examples/a11y` directory.
//!
//! ## Background
//!
//! Accessibility support is based on two key capabilities:
//! - Exposing information about the current UI state to assistive technology.
//! - Responding to actions requested by assistive technology.
//!
//! For example, a screen reader might want to announce to the user that a new
//! button has appeared. The user may then want to use a voice control program
//! to press that button.
//!
//! ### IDs in GPUI - [`ElementId`] and [`GlobalElementId`]
//!
//! In GPUI, each [`Element`] can have an [`id`][Element::id]:
//! ```rust
//! # use gpui::*;
//! let div_with_id = div().id("my-id").child(text!("hello"));
//!
//! // IDs are optional
//! let div_without_id = div().child(text!("hello"));
//! ```
//!
//! [`Element`]s with IDs are also assigned a [`GlobalElementId`]. This global
//! ID is formed by composing all the non-`None` IDs of its ancestors. For
//! example:
//! ```rust
//! # use gpui::*;
//! let inner = div().id("inner-id");
//! let middle = div().child(inner); // no ID
//! let outer = div().id("outer-id").child(middle);
//! ```
//! In this example, `inner`s global ID is (roughly speaking) `["outer-id",
//! "inner-id"]`.
//!
//! Since `middle` doesn't have an ID itself, it has no global ID.
//!
//! [`GlobalElementId`]s should be unique per-frame. Duplicate global IDs in the
//! same frame will likely cause bugs.
//!
//! ### IDs and accessibility
//!
//! When GPUI renders a frame, it walks your UI tree, and finds nodes with
//! global IDs, and informs assistive technology about this node.
//!
//! In order for nodes to be reported, they must also have a non-`None`
//! [`role`][Element::a11y_role]. This is used to inform assistive technology
//! what *sort* of node it is (button, label, table, etc.). You can use
//! [`div().id(...).role()`][StatefulInteractiveElement::role] to set the role.
//!
//! Nodes with the same global ID *across frames* are considered to be "the
//! same" node. For example:
//! ```rust
//! # use gpui::*;
//! // The UI in frame 1
//! let frame_1 = div()
//! .id("parent")
//! .role(Role::Button)
//! .child(
//! div()
//! .id("id-1")
//! .role(Role::Label)
//! .child(text!("hello"))
//! );
//!
//! // The UI on the next frame
//! let frame_2 = div()
//! .id("parent")
//! .role(Role::Button)
//! .child(
//! div()
//! .id("id-2") // <- different ID
//! .role(Role::Label)
//! .child(text!("hello"))
//! );
//! ```
//! Logically, the UI has not changed. But the screen reader has no way of
//! knowing that both child [`div`]s are "the same". So assistive technology
//! will interpret this as one node being removed, and another node being added.
//! This can be very disorienting for users, since announcements typically only
//! happen when something has *meaningfully* changed.
//!
//! In other words, by controlling the ID of an element, you can control whether
//! a change to a UI element is considered meaningful. You can also control
//! whether elements are reported to assistive technology *at all* by setting
//! the [`role`][Element::a11y_role], since nodes with no role are not reported.
//!
//! #### IDs and text
//!
//! Special care must be taken when dealing with text.
//!
//! GPUI provides the [`text!`] macro, which wraps strings in the [`Text`] type,
//! but automatically derives an ID. Usually, this is what you want. However,
//! the way it generates its ID is subtle and perhaps surprising.
//!
//! The ID of an invocation of the [`text!`] macro is derived from the
//! **location in the source code of that invocation**. For example:
//!
//! ```rust
//! # use gpui::*;
//! let a = text!("a");
//! let b = text!("b");
//!
//! // Different source locations, different IDs
//! assert_ne!(a.id(), b.id());
//!
//! // However:
//!
//! fn make_text(s: &str) -> Text { text!(s) }
//!
//! let a = make_text("a");
//! let b = make_text("b");
//!
//! // Both `a` and `b` are produced by the same `text!` invocation, so the IDs
//! // are the same
//! assert_eq!(a.id(), b.id());
//! ```
//! This can produce surprising behaviour. For example, this footgun:
//! ```rust
//! # use gpui::*;
//! let todos = vec!["eat lunch", "drink water", "go to gym"];
//! let todo_divs = todos.into_iter().map(|todo| {
//! text!(todo)
//! });
//!
//! div()
//! .id("todo-list")
//! .role(Role::Document)
//! .children(todo_divs); // ERROR: multiple nodes with the same global ID
//! ```
//!
//! Here, when we map the iterator, since we have only written [`text!`] once,
//! there is only one ID. And since they have the same ancestors and the same
//! ID, they will have the same global ID. In release builds, this will mean
//! some nodes get silently dropped!
//!
//! To fix this, you can set an ID:
//! ```rust
//! # use gpui::*;
//! let todos = vec!["eat lunch", "drink water", "go to gym"];
//! let todo_divs = todos.into_iter().enumerate().map(|(index, todo)| {
//! text!(todo).with_id(index) // OR `text(id = index, todo)`
//! });
//!
//! div()
//! .id("todo-list")
//! .role(Role::Document)
//! .children(todo_divs);
//! ```
//! Another possible solution is to wrap the [`text!`] in another node that
//! *does* have a unique global ID. For example:
//! ```rust
//! # use gpui::*;
//! let todos = vec!["eat lunch", "drink water", "go to gym"];
//! let todo_divs = todos.into_iter().enumerate().map(|(index, todo)| {
//! div().id(index).child(text!(todo))
//! });
//!
//! div()
//! .id("todo-list")
//! .role(Role::Document)
//! .children(todo_divs);
//! ```
//! Since the AccessKit [`NodeId`][accesskit::NodeId] is derived from the global
//! ID, and the global ID takes into account the IDs of all ancestors, this
//! works too.
//!
//! Occasionally, you will need to create a [`Text`] element with *no* ID. You
//! can achieve this with [`Text::new_inaccessible`]. If you are creating a
//! custom UI component (e.g. a button), you may want this so that you can set a
//! label property on a parent [`div`] without duplicating the text in the
//! accessibility tree.
//!
//! ### Handling actions
//!
//! Assistive technology can dispatch actions to the UI. While many users of
//! assistive technology use traditional input devices (e.g. a keyboard), some
//! use more specialized systems. For example, users with limited mobility may
//! use voice control to interact with your app.
//!
//! When a user dispatches an action, it is dispatched *to a specific node*. It
//! is your responsibility to tell the UI elements how they should respond when
//! a request comes in.
//!
//! Note, these actions are **totally unrelated** to GPUI's [`Action`] trait.
//! AccessKit exposes [`accesskit::Action`]. In GPUI, this is re-exported as
//! [`AccessibleAction`].
//!
//! To respond to an accessible action, use
//! [`div().on_a11y_action()`][InteractiveElement::on_a11y_action]:
//! ```rust,ignore
//! div()
//! .id("my-slider")
//! .role(Role::Slider)
//! .on_a11y_action(AccessibleAction::Increment, |_extra, _window, _cx| {
//! position += 1;
//! cx.notify();
//! })
//! .child(my_cool_slider());
//! ```
//!
//! Note that some common actions are automatically registered. For example,
//! [`.on_click()`][StatefulInteractiveElement::on_click] adds an
//! [`AccessibleAction::Click`] handler that calls the click handler.
//!
//! ## Further reading
//!
//! Designing high-quality accessible interfaces can be challenging, in the same
//! way that designing high-quality traditional interfaces can be. The
//! following pages have useful information:
//!
//! - [AccessKit]: The cross-platform accessibility toolkit GPUI uses
//! internally.
//! - [MDN WAI-ARIA basics][mdn-aria]: Introduction to roles, properties, and
//! states.
//! - [ARIA Authoring Practices Guide][apg]: W3C patterns for accessible
//! widgets.
//!
//! Note that, while GPUI mimics web APIs, it doesn't necessarily behave
//! *exactly* as a web browser would with the same attributes.
//!
//! [AccessKit]: https://accesskit.dev/
//! [mdn-aria]: https://developer.mozilla.org/en-US/docs/Learn_web_development/Core/Accessibility/WAI-ARIA_basics
//! [apg]: https://www.w3.org/WAI/ARIA/apg/
#[cfg(doc)]
use crate::*; // so I don't have to qualify every type :)

View file

@ -103,6 +103,22 @@ pub trait Element: 'static + IntoElement {
cx: &mut App,
);
/// Returns the accessible role for this element, if any.
/// Elements that return `None` are not included in the accessibility tree.
///
/// Note: inclusion in accessibility tree requires non-`None` [`id`][Element::id].
///
/// See the [accessibility guide](crate::_accessibility) for an overview.
fn a11y_role(&self) -> Option<accesskit::Role> {
None
}
/// Write accessibility properties to the given node.
/// Called only when `a11y_role()` returns `Some`.
///
/// See the [accessibility guide](crate::_accessibility) for an overview.
fn write_a11y_info(&self, _node: &mut accesskit::Node) {}
/// Convert this element into a dynamically-typed [`AnyElement`].
fn into_any(self) -> AnyElement {
AnyElement::new(self)
@ -302,6 +318,15 @@ impl Display for GlobalElementId {
}
}
impl GlobalElementId {
pub(crate) fn accesskit_node_id(&self) -> accesskit::NodeId {
use std::hash::{Hash, Hasher};
let mut hasher = std::hash::DefaultHasher::default();
self.hash(&mut hasher);
accesskit::NodeId(hasher.finish())
}
}
trait ElementObject {
fn inner_element(&mut self) -> &mut dyn Any;
@ -431,6 +456,26 @@ impl<E: Element> Drawable<E> {
}
let bounds = window.layout_bounds(layout_id);
let mut pushed_a11y_node = false;
if window.a11y.is_active() {
if let Some(global_id) = global_id.as_ref() {
if let Some(role) = self.element.a11y_role() {
let node_id = global_id.accesskit_node_id();
let mut node = accesskit::Node::new(role);
let scale = window.scale_factor();
node.set_bounds(accesskit::Rect {
x0: (bounds.origin.x.0 * scale) as f64,
y0: (bounds.origin.y.0 * scale) as f64,
x1: ((bounds.origin.x.0 + bounds.size.width.0) * scale) as f64,
y1: ((bounds.origin.y.0 + bounds.size.height.0) * scale) as f64,
});
self.element.write_a11y_info(&mut node);
window.a11y.node_bounds.insert(node_id, bounds);
pushed_a11y_node = window.a11y.nodes.push(node_id, node);
}
}
}
let node_id = window.next_frame.dispatch_tree.push_node();
let prepaint = self.element.prepaint(
global_id.as_ref(),
@ -442,6 +487,10 @@ impl<E: Element> Drawable<E> {
);
window.next_frame.dispatch_tree.pop_node();
if pushed_a11y_node {
window.a11y.nodes.pop();
}
if global_id.is_some() {
window.element_id_stack.pop();
}

View file

@ -1175,6 +1175,124 @@ pub trait InteractiveElement: Sized {
/// A trait for elements that want to use the standard GPUI interactivity features
/// that require state.
pub trait StatefulInteractiveElement: InteractiveElement {
/// Set the accessible role for this element.
///
/// See the [accessibility guide](crate::_accessibility) for an overview.
fn role(mut self, role: accesskit::Role) -> Self {
debug_assert!(
role != accesskit::Role::GenericContainer,
"GenericContainer is filtered out of the a11y tree and has no effect"
);
self.interactivity().override_role = Some(role);
self
}
/// Set the accessible label for this element.
fn aria_label(mut self, label: impl Into<SharedString>) -> Self {
self.interactivity().aria_label = Some(label.into());
self
}
/// Set the selected state for this element.
fn aria_selected(mut self, selected: bool) -> Self {
self.interactivity().aria_selected = Some(selected);
self
}
/// Set the expanded state for this element.
fn aria_expanded(mut self, expanded: bool) -> Self {
self.interactivity().aria_expanded = Some(expanded);
self
}
/// Set the toggled state for this element.
fn aria_toggled(mut self, toggled: accesskit::Toggled) -> Self {
self.interactivity().aria_toggled = Some(toggled);
self
}
/// Set the numeric value for this element.
fn aria_numeric_value(mut self, value: f64) -> Self {
self.interactivity().aria_numeric_value = Some(value);
self
}
/// Set the minimum numeric value for this element.
fn aria_min_numeric_value(mut self, value: f64) -> Self {
self.interactivity().aria_min_numeric_value = Some(value);
self
}
/// Set the maximum numeric value for this element.
fn aria_max_numeric_value(mut self, value: f64) -> Self {
self.interactivity().aria_max_numeric_value = Some(value);
self
}
/// Set the orientation of this element.
fn aria_orientation(mut self, orientation: accesskit::Orientation) -> Self {
self.interactivity().aria_orientation = Some(orientation);
self
}
/// Set the heading level of this element.
fn aria_level(mut self, level: usize) -> Self {
self.interactivity().aria_level = Some(level);
self
}
/// Set the position in set of this element.
fn aria_position_in_set(mut self, position: usize) -> Self {
self.interactivity().aria_position_in_set = Some(position);
self
}
/// Set the size of set for this element.
fn aria_size_of_set(mut self, size: usize) -> Self {
self.interactivity().aria_size_of_set = Some(size);
self
}
/// Set the row index for this element.
fn aria_row_index(mut self, index: usize) -> Self {
self.interactivity().aria_row_index = Some(index);
self
}
/// Set the column index for this element.
fn aria_column_index(mut self, index: usize) -> Self {
self.interactivity().aria_column_index = Some(index);
self
}
/// Set the row count for this element.
fn aria_row_count(mut self, count: usize) -> Self {
self.interactivity().aria_row_count = Some(count);
self
}
/// Set the column count for this element.
fn aria_column_count(mut self, count: usize) -> Self {
self.interactivity().aria_column_count = Some(count);
self
}
/// Register a handler for an accessibility action on this element.
/// The handler is called when a screen reader requests the given action.
///
/// See the [accessibility guide](crate::_accessibility) for an overview.
fn on_a11y_action(
mut self,
action: accesskit::Action,
listener: impl FnMut(Option<&accesskit::ActionData>, &mut crate::Window, &mut crate::App)
+ 'static,
) -> Self {
self.interactivity()
.a11y_action_listeners
.push((action, Box::new(listener)));
self
}
/// Set this element to focusable.
fn focusable(mut self) -> Self {
self.interactivity().focusable = true;
@ -1474,6 +1592,18 @@ impl Element for Div {
self.interactivity.source_location()
}
fn a11y_role(&self) -> Option<accesskit::Role> {
// Nodes with `GenericContainer` should never be reported to accesskit.
// Equivalent to an HTML div with no role.
self.interactivity
.override_role
.filter(|role| *role != accesskit::Role::GenericContainer)
}
fn write_a11y_info(&self, node: &mut accesskit::Node) {
self.interactivity.write_a11y_info(node);
}
#[stacksafe]
fn request_layout(
&mut self,
@ -1710,6 +1840,25 @@ pub struct Interactivity {
pub(crate) tab_group: bool,
pub(crate) tab_stop: bool,
pub(crate) a11y_action_listeners:
Vec<(accesskit::Action, crate::window::a11y::A11yActionListener)>,
pub(crate) override_role: Option<accesskit::Role>,
pub(crate) aria_label: Option<SharedString>,
pub(crate) aria_selected: Option<bool>,
pub(crate) aria_expanded: Option<bool>,
pub(crate) aria_toggled: Option<accesskit::Toggled>,
pub(crate) aria_numeric_value: Option<f64>,
pub(crate) aria_min_numeric_value: Option<f64>,
pub(crate) aria_max_numeric_value: Option<f64>,
pub(crate) aria_orientation: Option<accesskit::Orientation>,
pub(crate) aria_level: Option<usize>,
pub(crate) aria_position_in_set: Option<usize>,
pub(crate) aria_size_of_set: Option<usize>,
pub(crate) aria_row_index: Option<usize>,
pub(crate) aria_column_index: Option<usize>,
pub(crate) aria_row_count: Option<usize>,
pub(crate) aria_column_count: Option<usize>,
#[cfg(any(feature = "inspector", debug_assertions))]
pub(crate) source_location: Option<&'static core::panic::Location<'static>>,
@ -1830,6 +1979,16 @@ impl Interactivity {
if let Some(focus_handle) = self.tracked_focus_handle.as_ref() {
window.set_focus_handle(focus_handle, cx);
if window.a11y.is_active() {
if let Some(global_id) = global_id {
let node_id = global_id.accesskit_node_id();
window.a11y.focus_ids.insert(node_id, focus_handle.id);
if focus_handle.is_focused(window) && window.a11y.nodes.has_node(node_id) {
window.a11y.nodes.set_focus(node_id);
}
}
}
}
window.with_optional_element_state::<InteractiveElementState, _>(
global_id,
@ -2054,6 +2213,22 @@ impl Interactivity {
}
self.paint_keyboard_listeners(window, cx);
if window.a11y.is_active() {
if let Some(global_id) = global_id {
if !self.a11y_action_listeners.is_empty() {
let node_id = global_id.accesskit_node_id();
for (action, listener) in
self.a11y_action_listeners.drain(..)
{
window.on_a11y_action(
node_id, action, listener,
);
}
}
}
}
f(&style, window, cx);
if let Some(_hitbox) = hitbox {
@ -2857,6 +3032,63 @@ impl Interactivity {
style
}
pub(crate) fn write_a11y_info(&self, node: &mut accesskit::Node) {
if let Some(label) = &self.aria_label {
node.set_label(label.to_string());
}
if let Some(selected) = self.aria_selected {
node.set_selected(selected);
}
if let Some(expanded) = self.aria_expanded {
node.set_expanded(expanded);
}
if let Some(toggled) = self.aria_toggled {
node.set_toggled(toggled);
}
if let Some(value) = self.aria_numeric_value {
node.set_numeric_value(value);
}
if let Some(value) = self.aria_min_numeric_value {
node.set_min_numeric_value(value);
}
if let Some(value) = self.aria_max_numeric_value {
node.set_max_numeric_value(value);
}
if let Some(orientation) = self.aria_orientation {
node.set_orientation(orientation);
}
if let Some(level) = self.aria_level {
node.set_level(level);
}
if let Some(position) = self.aria_position_in_set {
node.set_position_in_set(position);
}
if let Some(size) = self.aria_size_of_set {
node.set_size_of_set(size);
}
if let Some(index) = self.aria_row_index {
node.set_row_index(index);
}
if let Some(index) = self.aria_column_index {
node.set_column_index(index);
}
if let Some(count) = self.aria_row_count {
node.set_row_count(count);
}
if let Some(count) = self.aria_column_count {
node.set_column_count(count);
}
if !self.click_listeners.is_empty() {
node.add_action(accesskit::Action::Click);
}
if self.tracked_focus_handle.is_some() || self.focusable {
node.add_action(accesskit::Action::Focus);
}
for (action, _) in &self.a11y_action_listeners {
node.add_action(*action);
}
}
}
/// The per-frame state of an interactive element. Used for tracking stateful interactions like clicks
@ -3263,6 +3495,14 @@ where
self.element.source_location()
}
fn a11y_role(&self) -> Option<accesskit::Role> {
self.element.a11y_role()
}
fn write_a11y_info(&self, node: &mut accesskit::Node) {
self.element.write_a11y_info(node);
}
fn request_layout(
&mut self,
id: Option<&GlobalElementId>,

View file

@ -13,11 +13,244 @@ use std::{
borrow::Cow,
cell::{Cell, RefCell},
mem,
ops::Range,
ops::{Deref, DerefMut, Range},
rc::Rc,
sync::Arc,
};
/// An [`Element`] that renders text.
///
/// In general, [`Text`] objects should be created via the [`text`] macro:
/// ```rust
/// # use gpui::*;
/// # fn render() -> impl IntoElement {
/// div().child(text!("hello"))
/// # }
/// ```
/// ## IDs and Accessibility
///
/// [`Text`] elements have an ID. This ID is primarily used to produce nodes in
/// the accessibility tree, which allows the text to be visible to screen
/// readers and other assistive technologies.
///
/// This ID is stable across frames. If the same text, with the same ID, is
/// present in two consecutive frames, no updates are reported to the screen
/// reader. If the text changes, but the ID stays the same, then the screen
/// reader will be notified that a text node's content has changed. **However**,
/// if the ID changes, then the screen reader will be notified that a node has
/// been removed, and a new node has been added.
///
/// When using the [`text`] macro, each invocation of the macro will get a
/// unique ID, derived from its position in the source code (filename, line, and
/// column). For example:
/// ```rust
/// # use gpui::*;
/// let x = text!("hello");
/// let y = text!("hello");
/// // not equal, because different `text!` invocations produced them
/// assert_ne!(x.id(), y.id());
///
/// fn make_text(s: &str) -> Text { text!(s) }
/// let x = make_text("hello");
/// let y = make_text("hello");
/// // equal, because the same `text!` invocation produced them
/// assert_eq!(x.id(), y.id());
/// ```
/// When the contents of an invocation of [`text`] do not change, this
/// distinction is less relevant (with the caveat that you still need to take
/// care to ensure that duplicate IDs do not appear).
///
/// However, when a [`text`] invocation's argument *does* change, you should
/// consider whether this change should be reported as a node "updating its
/// contents", or an old node being destroyed and a new node being created.
#[derive(Debug, Clone)]
pub struct Text {
id: Option<ElementId>,
text: SharedString,
}
impl Text {
/// Create a new [`Text`] element with a specific ID.
///
/// If you want a unique ID to be assigned automatically, use the [`text`]
/// macro. The docs for [`Text`] have more detail about choosing IDs.
#[inline]
pub const fn new(id: ElementId, text: SharedString) -> Self {
Self { id: Some(id), text }
}
/// Create a new [`Text`] element that is inaccessible to screen readers.
///
/// In order for text to be accessible to screen readers, it must have an ID
/// provided. If you want text to be accessible, either use [`text`] to have
/// an ID automatically assigned, or use [`Text::new`] to manually assign an
/// ID.
///
/// This function is intended for use inside custom UI components, where
/// accessible properties may be set on parent containers.
#[inline]
pub const fn new_inaccessible(text: SharedString) -> Self {
Self { id: None, text }
}
/// The ID of this [`Text`] element.
#[inline]
pub const fn id(&self) -> Option<&ElementId> {
self.id.as_ref()
}
/// Produce a new [`Text`] with the given `id`.
pub fn with_id(mut self, id: impl Into<ElementId>) -> Self {
self.id = Some(id.into());
self
}
/// The text that this [`Text`] element will display.
#[inline]
pub const fn text(&self) -> &SharedString {
&self.text
}
}
impl Deref for Text {
type Target = SharedString;
fn deref(&self) -> &Self::Target {
&self.text
}
}
impl DerefMut for Text {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.text
}
}
/// Trivial hash function for the location information produced by the [`text`]
/// macro. Not covered by semver guarantees. Performance is not particularly
/// significant because it's only used on small strings in const contexts.
#[doc(hidden)]
pub const fn __hash_text_macro_location_unstable_do_not_use(s: &'static str) -> u64 {
const BASIS: u64 = 0xcbf29ce484222325;
const PRIME: u64 = 0x100000001b3;
let bytes = s.as_bytes();
let mut hash = BASIS;
let mut i = 0;
while i < bytes.len() {
hash ^= bytes[i] as u64;
hash = hash.wrapping_mul(PRIME);
i += 1;
}
hash
}
/// Create a new [`Text`] element.
///
/// ```rust
/// # use gpui::*;
/// let a = text!("hello");
/// let b = text!(id = "farewell-message", "hello");
///
/// ```
///
/// Text created with this macro is *accessible*. The macro generates an ID
/// based on the source location. See the docs for [`Text`] for a more in-depth
/// explanation of the significance of the ID of a [`Text`] element.
#[macro_export]
macro_rules! text {
(id = $id:expr, $text:expr) => {{ $crate::Text::new($id.into(), $text.into()) }};
($text:expr) => {{
const ID: &'static str = concat!(file!(), "/", line!(), ":", column!());
const HASH: u64 = $crate::__hash_text_macro_location_unstable_do_not_use(ID);
$crate::Text::new($crate::ElementId::Integer(HASH), $text.into())
}};
}
impl IntoElement for Text {
type Element = Self;
#[inline]
fn into_element(self) -> Self::Element {
self
}
}
impl Element for Text {
type RequestLayoutState = TextLayout;
type PrepaintState = ();
fn id(&self) -> Option<ElementId> {
self.id.clone()
}
fn source_location(&self) -> Option<&'static std::panic::Location<'static>> {
None
}
fn a11y_role(&self) -> Option<accesskit::Role> {
if self.id.is_some() {
Some(accesskit::Role::Label)
} else {
None
}
}
fn write_a11y_info(&self, node: &mut accesskit::Node) {
node.set_value(self.text.to_string());
}
fn request_layout(
&mut self,
id: Option<&GlobalElementId>,
inspector_id: Option<&InspectorElementId>,
window: &mut Window,
cx: &mut App,
) -> (LayoutId, Self::RequestLayoutState) {
<SharedString as Element>::request_layout(&mut self.text, id, inspector_id, window, cx)
}
fn prepaint(
&mut self,
id: Option<&GlobalElementId>,
inspector_id: Option<&InspectorElementId>,
bounds: Bounds<Pixels>,
request_layout: &mut Self::RequestLayoutState,
window: &mut Window,
cx: &mut App,
) -> Self::PrepaintState {
<SharedString as Element>::prepaint(
&mut self.text,
id,
inspector_id,
bounds,
request_layout,
window,
cx,
)
}
fn paint(
&mut self,
id: Option<&GlobalElementId>,
inspector_id: Option<&InspectorElementId>,
bounds: Bounds<Pixels>,
request_layout: &mut Self::RequestLayoutState,
prepaint: &mut Self::PrepaintState,
window: &mut Window,
cx: &mut App,
) {
<SharedString as Element>::paint(
&mut self.text,
id,
inspector_id,
bounds,
request_layout,
prepaint,
window,
cx,
);
}
}
impl Element for &'static str {
type RequestLayoutState = TextLayout;
type PrepaintState = ();
@ -807,6 +1040,14 @@ impl Element for InteractiveText {
None
}
fn a11y_role(&self) -> Option<accesskit::Role> {
Some(accesskit::Role::Label)
}
fn write_a11y_info(&self, node: &mut accesskit::Node) {
node.set_value(self.text.text.to_string());
}
fn request_layout(
&mut self,
_id: Option<&GlobalElementId>,
@ -1009,6 +1250,8 @@ impl IntoElement for InteractiveText {
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_into_element_for() {
use crate::{ParentElement as _, SharedString, div};
@ -1019,4 +1262,23 @@ mod tests {
let _ = div().child(Cow::Borrowed("Cow"));
let _ = div().child(SharedString::from("SharedString"));
}
#[test]
fn text_macro_id() {
// one call to `text!` = one id
fn make_text_stable_id(happy: bool) -> Text {
text!(if happy { "happy" } else { "sad" })
}
// two calls to `text!` = two ids
fn make_text_unstable_id(happy: bool) -> Text {
if happy { text!("happy") } else { text!("sad") }
}
assert_eq!(make_text_stable_id(false).id, make_text_stable_id(true).id);
assert_ne!(
make_text_unstable_id(false).id,
make_text_unstable_id(true).id
);
}
}

View file

@ -56,6 +56,8 @@ mod window;
#[cfg(any(test, feature = "test-support"))]
pub use proptest;
#[cfg(doc)]
pub mod _accessibility;
#[cfg(doc)]
pub mod _ownership_and_data_flow;
@ -75,6 +77,9 @@ mod seal {
pub trait Sealed {}
}
pub use accesskit;
pub use accesskit::Action as AccessibleAction;
pub use accesskit::{Orientation, Role, Toggled};
pub use action::*;
pub use anyhow::Result;
pub use app::*;

View file

@ -591,6 +591,16 @@ impl Tiling {
}
}
/// Callbacks for the accessibility adapter.
pub struct A11yCallbacks {
/// Called when the adapter is activated (a screen reader connects).
pub activation: Box<dyn Fn() -> Option<accesskit::TreeUpdate> + Send + 'static>,
/// Called when an action is requested by the screen reader.
pub action: Box<dyn Fn(accesskit::ActionRequest) + Send + 'static>,
/// Called when the adapter is deactivated (screen reader disconnects).
pub deactivation: Box<dyn Fn() + Send + 'static>,
}
#[derive(Debug, Copy, Clone, Eq, PartialEq, Default)]
#[expect(missing_docs)]
pub struct RequestFrameOptions {
@ -700,6 +710,15 @@ pub trait PlatformWindow: HasWindowHandle + HasDisplayHandle {
fn play_system_bell(&self) {}
/// Initialize the accessibility adapter with callbacks.
fn a11y_init(&self, _callbacks: A11yCallbacks) {}
/// Provide a TreeUpdate to the accessibility adapter.
fn a11y_tree_update(&self, _tree_update: accesskit::TreeUpdate) {}
/// Inform the adapter of updated window bounds.
fn a11y_update_window_bounds(&self) {}
#[cfg(any(test, feature = "test-support"))]
fn as_test(&mut self) -> Option<&mut TestWindow> {
None

View file

@ -52,14 +52,18 @@ use std::{
rc::Rc,
sync::{
Arc, Weak,
atomic::{AtomicUsize, Ordering::SeqCst},
atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
},
time::Duration,
};
use uuid::Uuid;
pub(crate) mod a11y;
mod prompts;
use self::a11y::A11y;
#[cfg(not(target_family = "wasm"))]
use self::a11y::ROOT_NODE_ID;
use crate::util::{
atomic_incr_if_not_zero, ceil_to_device_pixel, floor_to_device_pixel, round_half_toward_zero,
round_half_toward_zero_f64, round_stroke_to_device_pixel, round_to_device_pixel,
@ -1021,6 +1025,7 @@ pub struct Window {
captured_hitbox: Option<HitboxId>,
#[cfg(any(feature = "inspector", debug_assertions))]
inspector: Option<Entity<Inspector>>,
pub(crate) a11y: A11y,
}
#[derive(Clone, Debug, Default)]
@ -1325,6 +1330,85 @@ impl Window {
WindowBounds::Windowed(_) => {}
}
let a11y_active_flag = Arc::new(AtomicBool::new(false));
#[cfg(not(target_family = "wasm"))]
{
let initial_tree = accesskit::TreeUpdate {
nodes: vec![(ROOT_NODE_ID, accesskit::Node::new(accesskit::Role::Window))],
tree: Some(accesskit::Tree::new(ROOT_NODE_ID)),
tree_id: accesskit::TreeId::ROOT,
focus: ROOT_NODE_ID,
};
let (activation_sender, activation_receiver) = async_channel::unbounded::<()>();
let (deactivation_sender, deactivation_receiver) = async_channel::unbounded::<()>();
let (action_sender, action_receiver) =
async_channel::unbounded::<accesskit::ActionRequest>();
platform_window.a11y_init(crate::A11yCallbacks {
activation: {
let active_flag = a11y_active_flag.clone();
Box::new(move || {
log::info!("Accessibility activated");
active_flag.store(true, SeqCst);
activation_sender.send_blocking(()).log_err();
Some(initial_tree.clone())
})
},
action: Box::new(move |request| {
action_sender.send_blocking(request).log_err();
}),
deactivation: {
let active_flag = a11y_active_flag.clone();
Box::new(move || {
log::info!("Accessibility deactivated");
active_flag.store(false, SeqCst);
deactivation_sender.send_blocking(()).log_err();
})
},
});
// A11y can be activated at any time, and so we cannot compute a
// correct `TreeUpdate` on-demand. When this happens, we return a
// default empty `TreeUpdate`.
//
// So we force a new frame, which will then send a correct `TreeUpdate`.
let mut async_cx = cx.to_async();
cx.foreground_executor()
.spawn(async move {
while activation_receiver.recv().await.is_ok() {
handle
.update(&mut async_cx, |_, window, _| window.refresh())
.log_err();
}
})
.detach();
let mut async_cx = cx.to_async();
cx.foreground_executor()
.spawn(async move {
while deactivation_receiver.recv().await.is_ok() {
handle
.update(&mut async_cx, |_, window, _| window.refresh())
.log_err();
}
})
.detach();
let mut async_cx = cx.to_async();
cx.foreground_executor()
.spawn(async move {
while let Ok(request) = action_receiver.recv().await {
handle
.update(&mut async_cx, |_, window, cx| {
window.handle_a11y_action(request, cx);
})
.log_err();
}
})
.detach();
}
platform_window.on_close(Box::new({
let window_id = handle.window_id();
let mut cx = cx.to_async();
@ -1633,6 +1717,7 @@ impl Window {
captured_hitbox: None,
#[cfg(any(feature = "inspector", debug_assertions))]
inspector: None,
a11y: A11y::new(a11y_active_flag),
})
}
@ -2620,6 +2705,11 @@ impl Window {
self.invalidator.set_phase(DrawPhase::Prepaint);
self.tooltip_bounds.take();
self.a11y.sync_active_flag();
if self.a11y.is_active() {
self.a11y.begin_frame();
}
let _inspector_width: Pixels = rems(30.0).to_pixels(self.rem_size());
let root_size = {
#[cfg(any(feature = "inspector", debug_assertions))]
@ -2686,6 +2776,26 @@ impl Window {
#[cfg(any(feature = "inspector", debug_assertions))]
self.paint_inspector_hitbox(cx);
// a11y may have been activated/deactivated halfway through the frame
let a11y_active_start_of_frame = self.a11y.is_active();
self.a11y.sync_active_flag();
let a11y_active_end_of_frame = self.a11y.is_active();
let should_send_a11y_update = a11y_active_start_of_frame && a11y_active_end_of_frame;
if a11y_active_start_of_frame {
// clear the builder state regardless
let tree_update = self.a11y.end_frame();
if should_send_a11y_update {
log::debug!(
"Sending a11y tree update: {} nodes",
tree_update.nodes.len()
);
self.platform_window.a11y_tree_update(tree_update);
}
}
}
fn prepaint_tooltip(&mut self, cx: &mut App) -> Option<AnyElement> {
@ -5296,6 +5406,87 @@ impl Window {
self.platform_window.play_system_bell()
}
/// Register a listener for an accessibility action on a specific node.
/// The listener will be called when a screen reader requests the given
/// action on the node identified by `node_id`.
///
/// See the [accessibility guide](crate::_accessibility) for an overview.
pub fn on_a11y_action(
&mut self,
node_id: accesskit::NodeId,
action: accesskit::Action,
listener: impl FnMut(Option<&accesskit::ActionData>, &mut Window, &mut App) + 'static,
) {
self.a11y
.action_listeners
.entry(node_id)
.or_default()
.push((action, Box::new(listener)));
}
#[cfg(not(target_family = "wasm"))]
pub(crate) fn handle_a11y_action(&mut self, request: accesskit::ActionRequest, cx: &mut App) {
// Take listeners out temporarily so the closures can borrow Window
// mutably, then restore them afterward.
if let Some(mut listeners) = self.a11y.action_listeners.remove(&request.target_node) {
let extra_data = request.data.as_ref();
let mut matched = false;
for (action, listener) in &mut listeners {
if *action == request.action {
listener(extra_data, self, cx);
matched = true;
}
}
self.a11y
.action_listeners
.insert(request.target_node, listeners);
if matched {
return;
}
}
// Fall back to built-in action handling.
match request.action {
accesskit::Action::Click => {
if let Some(bounds) = self.a11y.node_bounds.get(&request.target_node).copied() {
let center = bounds.center();
let mouse_down = PlatformInput::MouseDown(crate::MouseDownEvent {
button: MouseButton::Left,
position: center,
modifiers: Modifiers::default(),
click_count: 1,
first_mouse: false,
});
let mouse_up = PlatformInput::MouseUp(MouseUpEvent {
button: MouseButton::Left,
position: center,
modifiers: Modifiers::default(),
click_count: 1,
});
self.dispatch_event(mouse_down, cx);
self.dispatch_event(mouse_up, cx);
}
}
accesskit::Action::Focus => {
if let Some(focus_id) = self.a11y.focus_ids.get(&request.target_node).copied()
&& let Some(handle) = FocusHandle::for_id(focus_id, &cx.focus_handles)
{
self.focus(&handle, cx);
}
}
accesskit::Action::Blur => {
self.blur();
}
_ => {
log::debug!(
"Unhandled a11y action: {:?} on {:?}",
request.action,
request.target_node
);
}
}
}
/// Toggles the inspector mode on this window.
#[cfg(any(feature = "inspector", debug_assertions))]
pub fn toggle_inspector(&mut self, cx: &mut App) {

View file

@ -0,0 +1,342 @@
//! Accessibility support, provided by [AccessKit][accesskit].
//!
//! There are user-facing guide-level docs [here](crate::_accessibility).
//!
//! ## Architecture
//!
//! ```text
//! ┌────────────────────────────────┐ ┌─────────────────────┐
//! ┌─▶│ AccessKit Adapter (MacOS) │◀─▶│ MacOS System APIs │
//! │ └────────────────────────────────┘ └─────────────────────┘
//! │
//! ┌──────┐ ┌───────────┐ │ ┌────────────────────────────────┐ ┌─────────────────────┐
//! │ GPUI │◀─▶│ AccessKit │◀─┼─▶│ AccessKit Adapter (Windows) │◀─▶│ Windows System APIs │
//! └──────┘ └───────────┘ │ └────────────────────────────────┘ └─────────────────────┘
//! │
//! │ ┌────────────────────────────────┐ ┌─────────────────────┐
//! └─▶│ AccessKit Adapter (Linux) │◀─▶│ dbus │
//! └────────────────────────────────┘ └─────────────────────┘
//! ```
//!
//! In order for GPUI apps to be usable for people using assistive technology,
//! we must do a few things:
//! - Inform the system when the UI changes meaningfully. This includes:
//! - Reporting new/removed/changed UI elements
//! - *Not* reporting irrelevant UI changes, e.g. an invisible `div()` being
//! added.
//! - Reporting the appearance and capabilities of each UI element. For example:
//! - What does this piece of text say?
//! - How far along is this progress bar?
//! - Can this node be focused?
//! - Can this node have a value directly assigned? (e.g. a slider)
//! - Allowing the system to interact with the UI by dispatching actions to
//! nodes. Note that AccessKit has its own [`Action`] type, which is not the
//! [`crate::Action`] trait.
//! - Activate and deactivate accessibility features when requested by the
//! system.
//!
//! Activating and deactivating at the right time is trivial, so I won't go into
//! detail here. The other two are almost orthogonal in implementation.
//!
//! The state for both lives in the [`A11y`] struct in this module.
//!
//! ### Reporting UI changes
//!
//! Every frame, we build a [`TreeUpdate`] and send it to the platform-specific
//! adapter. A [`TreeUpdate`] is a representation of a subset of the UI tree.
//! When the adapter receives the update, it diffs it against the previous
//! update, and calls platform-specific APIs to inform screen readers about the
//! changes. Nodes may have been created, destroyed, or updated.
//!
//! Each node has an ID, and this ID *should* be stable across frames. If a
//! node's ID changes, then, from AccessKit's point of view, it is a different
//! node.
//!
//! We derive the node ID from the [`GlobalElementId`] in
//! [`GlobalElementId::accesskit_node_id`]. Nodes without [`GlobalElementId`]s
//! cannot produce an AccessKit [`NodeId`], and so are not included in the
//! accessibility tree. We try to warn when using accessibility APIs on
//! [`div()`] without setting an ID.
//!
//! This all happens in [`Drawable::prepaint`]. The [`A11y`] struct maintains a
//! stack of nodes during prepainting, which we can use to calculate the
//! [`NodeId`]s, and record parent-child relationships. Once all [`Element`]s in
//! a frame have been prepainted, we send the resulting [`TreeUpdate`] object to
//! the adapter and the screen reader can announce the changes.
//!
//! ### Responding to actions
//!
//! On adapter creation, we provide a callback to the adapter, which can be used
//! to dispatch actions. This callback forwards to [`A11y::action_listeners`], a
//! mapping from [`NodeId`]s to action handlers (basically just `Box<dyn
//! Fn()>`).
//!
//! This is populated in:
//! - [`Window::on_a11y_action`], which is called by:
//! - [`Interactivity::paint`], which is called by:
//! - [`InteractiveElement::on_a11y_action`], which is a public-facing API
//!
//! These are cleared at the start of a frame, and re-populated during painting.
//!
//! [`Element`]: crate::Element
//! [`GlobalElementId`]: crate::GlobalElementId
//! [`div()`]: crate::div
//! [`Interactivity::paint`]: crate::Interactivity::paint
//! [`InteractiveElement::on_a11y_action`]: crate::InteractiveElement::on_a11y_action
//! [`NodeId`]: accesskit::NodeId
//! [`Drawable::prepaint`]: crate::Drawable::prepaint
use crate::{App, Bounds, FocusId, Pixels, Window};
use accesskit::{Action, NodeId, TreeUpdate};
use collections::{FxHashMap, FxHashSet};
use smallvec::SmallVec;
use std::sync::{
Arc,
atomic::{AtomicBool, Ordering},
};
/// The fixed AccessKit node ID used for the root of every window's a11y tree.
pub(crate) const ROOT_NODE_ID: NodeId = NodeId(0);
/// A listener for an accessibility action on a specific node.
pub(crate) type A11yActionListener =
Box<dyn FnMut(Option<&accesskit::ActionData>, &mut Window, &mut App) + 'static>;
/// Per-window accessibility state.
///
/// Manages the AccessKit tree that is built each frame and the mappings
/// needed to dispatch incoming action requests back to the right elements.
pub(crate) struct A11y {
/// Whether a11y features have been requested by the system.
///
/// Updated by AccessKit using callbacks provided to the adapter. Can change
/// halfway through a frame.
active_flag: Arc<AtomicBool>,
/// Whether a11y features are active for *this specific frame*.
///
/// At the start of each frame, we load [`Self::active_flag`] (using
/// [`Self::sync_active_flag`]) and use this to determine whether we
/// should construct a [`TreeUpdate`] for this frame. It's important that
/// this value is stable within a frame, because the builder API exposed by
/// this type maintains a stack of nodes and each must be pushed and popped
/// exactly once.
///
/// At the end of the frame, we re-call [`Self::sync_active_flag`] to
/// determine whether we should actually send the finished [`TreeUpdate`].
active_this_frame: bool,
pub(crate) nodes: A11yNodeBuilder,
pub(crate) focus_ids: FxHashMap<NodeId, FocusId>,
pub(crate) node_bounds: FxHashMap<NodeId, Bounds<Pixels>>,
pub(crate) action_listeners: FxHashMap<NodeId, Vec<(Action, A11yActionListener)>>,
}
impl A11y {
pub(crate) fn new(active_flag: Arc<AtomicBool>) -> Self {
Self {
active_flag,
active_this_frame: false,
nodes: A11yNodeBuilder::new(),
focus_ids: FxHashMap::default(),
node_bounds: FxHashMap::default(),
action_listeners: FxHashMap::default(),
}
}
/// Ensures that [`Self::is_active`] returns up to date information.
///
/// See the docs for [`Self::active_flag`] and [`Self::active_this_frame`]
/// for more commentary.
pub(crate) fn sync_active_flag(&mut self) {
self.active_this_frame = self.active_flag.load(Ordering::SeqCst);
}
pub(crate) fn is_active(&self) -> bool {
self.active_this_frame
}
/// Clear per-frame state and push the root node to start a new frame.
pub(crate) fn begin_frame(&mut self) {
self.focus_ids.clear();
self.node_bounds.clear();
self.action_listeners.clear();
self.nodes.begin_frame();
}
/// Finalize the tree and produce a [`TreeUpdate`] for the platform adapter.
pub(crate) fn end_frame(&mut self) -> TreeUpdate {
self.nodes.finalize()
}
}
pub(crate) struct A11yNodeBuilder {
ids_stack: SmallVec<[NodeId; 16]>,
nodes_stack: SmallVec<[accesskit::Node; 16]>,
/// This is the exact type required by accesskit, so we can't just make it a
/// `HashMap<NodeId, Node>` to remove the need for `seen_ids`
all_nodes: Vec<(NodeId, accesskit::Node)>,
seen_ids: FxHashSet<NodeId>,
focus: NodeId,
#[cfg(debug_assertions)]
has_set_focus: bool,
}
impl A11yNodeBuilder {
fn new() -> Self {
Self {
ids_stack: SmallVec::new(),
nodes_stack: SmallVec::new(),
all_nodes: Vec::new(),
seen_ids: FxHashSet::default(),
focus: ROOT_NODE_ID,
#[cfg(debug_assertions)]
has_set_focus: false,
}
}
/// Push a new node onto the stack. It becomes a child of the current
/// top-of-stack node.
///
/// Returns `true` if the node was successfully pushed.
pub(crate) fn push(&mut self, id: NodeId, node: accesskit::Node) -> bool {
debug_assert!(!self.ids_stack.is_empty(), "push called before push_root");
if !self.seen_ids.insert(id) {
debug_assert!(
false,
"Duplicate a11y node id: {id:?}. In a release build, this node would be silently discarded from the a11y tree."
);
// We need to return `false` here because inserting a duplicate
// node will cause a panic in accesskit
return false;
}
if let Some(parent) = self.nodes_stack.last_mut() {
parent.push_child(id);
}
self.ids_stack.push(id);
self.nodes_stack.push(node);
true
}
/// Pop the current node off the stack and finalize it into the all_nodes
/// list.
pub(crate) fn pop(&mut self) {
debug_assert!(self.ids_stack.len() > 1, "pop would remove the root node");
if let (Some(id), Some(node)) = (self.ids_stack.pop(), self.nodes_stack.pop()) {
self.all_nodes.push((id, node));
}
}
/// Push the root node to start a new frame.
fn begin_frame(&mut self) {
self.all_nodes.clear();
self.ids_stack.clear();
self.nodes_stack.clear();
self.seen_ids.clear();
#[cfg(debug_assertions)]
{
self.has_set_focus = false;
}
let root_node = accesskit::Node::new(accesskit::Role::Window);
self.ids_stack.push(ROOT_NODE_ID);
self.nodes_stack.push(root_node);
self.focus = ROOT_NODE_ID;
}
/// Returns whether a node with the given ID has been pushed in this frame.
pub(crate) fn has_node(&self, id: NodeId) -> bool {
id == ROOT_NODE_ID || self.seen_ids.contains(&id)
}
/// Set the focused node for this frame.
pub(crate) fn set_focus(&mut self, id: NodeId) {
#[cfg(debug_assertions)]
{
debug_assert!(
!self.has_set_focus,
"set_focus called more than once in a single frame"
);
self.has_set_focus = true;
}
self.focus = id;
}
fn finalize(&mut self) -> TreeUpdate {
// Stack should contain only the root node
debug_assert_eq!(self.ids_stack.len(), 1);
debug_assert_eq!(self.ids_stack[0], ROOT_NODE_ID);
if self.ids_stack.len() != 1 {
log::error!(
"a11y: Stack imbalance at end of frame: expected 1 (root), got {}. \
Some elements may have pushed without popping.",
self.ids_stack.len()
);
}
// Pop remaining nodes (should just be the root).
while !self.ids_stack.is_empty() {
if let (Some(id), Some(node)) = (self.ids_stack.pop(), self.nodes_stack.pop()) {
self.all_nodes.push((id, node));
}
}
let nodes = std::mem::take(&mut self.all_nodes);
let update = TreeUpdate {
nodes,
tree: Some(accesskit::Tree::new(ROOT_NODE_ID)),
tree_id: accesskit::TreeId::ROOT,
focus: self.focus,
};
Self::repair_tree_update(update)
}
/// Accesskit panics on invalid [`TreeUpdate`]s. This function defensively
/// checks invariants that accesskit panics on, and tries to fix them.
fn repair_tree_update(mut update: TreeUpdate) -> TreeUpdate {
let node_ids: FxHashSet<NodeId> = update.nodes.iter().map(|(id, _)| *id).collect();
// Focus must point to a node in the tree.
if !node_ids.contains(&update.focus) {
log::error!(
"a11y: Focused node {:?} is not in the tree ({} nodes). \
Falling back to root. This is a bug in the a11y tree builder.",
update.focus,
update.nodes.len()
);
update.focus = ROOT_NODE_ID;
}
// Every child reference must point to a node in the update.
for (id, node) in &mut update.nodes {
let has_invalid_child = node
.children()
.iter()
.any(|child_id| !node_ids.contains(child_id));
if has_invalid_child {
let children = node.children();
let invalid_count = children
.iter()
.filter(|child_id| !node_ids.contains(child_id))
.count();
log::error!(
"a11y: Node {:?} references {} children not present in the tree. \
Stripping invalid child references.",
id,
invalid_count
);
let valid: Vec<NodeId> = children
.iter()
.copied()
.filter(|child_id| node_ids.contains(child_id))
.collect();
node.set_children(valid);
}
}
update
}
}

View file

@ -51,6 +51,8 @@ screen-capture = [
[target.'cfg(any(target_os = "linux", target_os = "freebsd"))'.dependencies]
accesskit.workspace = true
accesskit_unix.workspace = true
anyhow.workspace = true
bytemuck = "1"
collections.workspace = true

View file

@ -123,6 +123,7 @@ pub struct WaylandWindowState {
in_progress_window_controls: Option<WindowControls>,
window_controls: WindowControls,
client_inset: Option<Pixels>,
accesskit_adapter: Option<accesskit_unix::Adapter>,
}
pub enum WaylandSurfaceState {
@ -398,6 +399,7 @@ impl WaylandWindowState {
in_progress_window_controls: None,
window_controls: WindowControls::default(),
client_inset: None,
accesskit_adapter: None,
})
}
@ -1047,6 +1049,9 @@ impl WaylandWindowStatePtr {
fun(focus);
self.callbacks.borrow_mut().active_status_change = Some(fun);
}
if let Some(adapter) = self.state.borrow_mut().accesskit_adapter.as_mut() {
adapter.update_window_focus_state(focus);
}
}
pub fn set_hovered(&self, focus: bool) {
@ -1519,6 +1524,60 @@ impl PlatformWindow for WaylandWindow {
bell.ring(surface);
}
}
fn a11y_init(&self, callbacks: gpui::A11yCallbacks) {
let activation_handler = TrivialActivationHandler {
callback: callbacks.activation,
};
let action_handler = TrivialActionHandler(callbacks.action);
let deactivation_handler = TrivialDeactivationHandler {
callback: callbacks.deactivation,
};
let adapter =
accesskit_unix::Adapter::new(activation_handler, action_handler, deactivation_handler);
self.borrow_mut().accesskit_adapter = Some(adapter);
}
fn a11y_tree_update(&self, tree_update: accesskit::TreeUpdate) {
let mut state = self.borrow_mut();
if let Some(adapter) = state.accesskit_adapter.as_mut() {
adapter.update_if_active(|| tree_update);
}
}
fn a11y_update_window_bounds(&self) {
// Wayland doesn't expose window position, so this is a no-op
}
}
struct TrivialActivationHandler {
callback: Box<dyn Fn() -> Option<accesskit::TreeUpdate> + Send + 'static>,
}
impl accesskit::ActivationHandler for TrivialActivationHandler {
fn request_initial_tree(&mut self) -> Option<accesskit::TreeUpdate> {
(self.callback)()
}
}
struct TrivialActionHandler(Box<dyn Fn(accesskit::ActionRequest) + Send + 'static>);
impl accesskit::ActionHandler for TrivialActionHandler {
fn do_action(&mut self, request: accesskit::ActionRequest) {
(self.0)(request);
}
}
struct TrivialDeactivationHandler {
callback: Box<dyn Fn() + Send + 'static>,
}
impl accesskit::DeactivationHandler for TrivialDeactivationHandler {
fn deactivate_accessibility(&mut self) {
(self.callback)();
}
}
fn update_window(mut state: RefMut<WaylandWindowState>) {

View file

@ -285,6 +285,7 @@ pub struct X11WindowState {
edge_constraints: Option<EdgeConstraints>,
pub handle: AnyWindowHandle,
last_insets: [u32; 4],
accesskit_adapter: Option<accesskit_unix::Adapter>,
}
impl X11WindowState {
@ -801,6 +802,7 @@ impl X11WindowState {
decorations: WindowDecorations::Server,
last_insets: [0, 0, 0, 0],
edge_constraints: None,
accesskit_adapter: None,
counter_id: sync_request_counter,
last_sync_counter: None,
})
@ -1277,6 +1279,9 @@ impl X11WindowStatePtr {
fun(focus);
self.callbacks.borrow_mut().active_status_change = Some(fun);
}
if let Some(adapter) = self.state.borrow_mut().accesskit_adapter.as_mut() {
adapter.update_window_focus_state(focus);
}
}
pub fn set_hovered(&self, focus: bool) {
@ -1886,4 +1891,84 @@ impl PlatformWindow for X11Window {
// Volume 0% means don't increase or decrease from system volume
let _ = self.0.xcb.bell(0);
}
fn a11y_init(&self, callbacks: gpui::A11yCallbacks) {
let activation_handler = TrivialActivationHandler {
callback: callbacks.activation,
};
let action_handler = TrivialActionHandler(callbacks.action);
let deactivation_handler = TrivialDeactivationHandler {
callback: callbacks.deactivation,
};
let adapter =
accesskit_unix::Adapter::new(activation_handler, action_handler, deactivation_handler);
self.0.state.borrow_mut().accesskit_adapter = Some(adapter);
}
fn a11y_tree_update(&self, tree_update: accesskit::TreeUpdate) {
let mut state = self.0.state.borrow_mut();
if let Some(adapter) = state.accesskit_adapter.as_mut() {
adapter.update_if_active(|| tree_update);
}
}
fn a11y_update_window_bounds(&self) {
let mut state = self.0.state.borrow_mut();
let scale = state.scale_factor;
let bounds = state.bounds;
let [left, right, top, bottom] = state.last_insets;
let x = f32::from(bounds.origin.x);
let y = f32::from(bounds.origin.y);
let width = f32::from(bounds.size.width);
let height = f32::from(bounds.size.height);
let outer = accesskit::Rect {
x0: (x * scale) as f64,
y0: (y * scale) as f64,
x1: ((x + width) * scale) as f64,
y1: ((y + height) * scale) as f64,
};
let inner = accesskit::Rect {
x0: (x * scale) as f64 + left as f64,
y0: (y * scale) as f64 + top as f64,
x1: ((x + width) * scale) as f64 - right as f64,
y1: ((y + height) * scale) as f64 - bottom as f64,
};
if let Some(adapter) = state.accesskit_adapter.as_mut() {
adapter.set_root_window_bounds(outer, inner);
}
}
}
struct TrivialActivationHandler {
callback: Box<dyn Fn() -> Option<accesskit::TreeUpdate> + Send + 'static>,
}
impl accesskit::ActivationHandler for TrivialActivationHandler {
fn request_initial_tree(&mut self) -> Option<accesskit::TreeUpdate> {
(self.callback)()
}
}
struct TrivialActionHandler(Box<dyn Fn(accesskit::ActionRequest) + Send + 'static>);
impl accesskit::ActionHandler for TrivialActionHandler {
fn do_action(&mut self, request: accesskit::ActionRequest) {
(self.0)(request);
}
}
struct TrivialDeactivationHandler {
callback: Box<dyn Fn() + Send + 'static>,
}
impl accesskit::DeactivationHandler for TrivialDeactivationHandler {
fn deactivate_accessibility(&mut self) {
(self.callback)();
}
}

View file

@ -22,6 +22,8 @@ screen-capture = ["gpui/screen-capture"]
gpui.workspace = true
[target.'cfg(target_os = "macos")'.dependencies]
accesskit.workspace = true
accesskit_macos.workspace = true
anyhow.workspace = true
async-task = "4.7"
block = "0.1"

View file

@ -500,6 +500,7 @@ struct MacWindowState {
toggle_tab_bar_callback: Option<Box<dyn FnMut()>>,
activated_least_once: bool,
closed: Arc<AtomicBool>,
accesskit_adapter: Option<accesskit_macos::SubclassingAdapter>,
// The parent window if this window is a sheet (Dialog kind)
sheet_parent: Option<id>,
}
@ -829,6 +830,7 @@ impl MacWindow {
toggle_tab_bar_callback: None,
activated_least_once: false,
closed: Arc::new(AtomicBool::new(false)),
accesskit_adapter: None,
sheet_parent: None,
})));
@ -1730,6 +1732,59 @@ impl PlatformWindow for MacWindow {
let mut this = self.0.lock();
this.renderer.render_to_image(scene)
}
fn a11y_init(&self, callbacks: gpui::A11yCallbacks) {
let mut lock = self.0.lock();
let activation_handler = A11yActivationHandler {
callback: callbacks.activation,
};
let action_handler = A11yActionHandler(callbacks.action);
let adapter = unsafe {
accesskit_macos::SubclassingAdapter::for_window(
lock.native_window as *mut c_void,
activation_handler,
action_handler,
)
};
lock.accesskit_adapter = Some(adapter);
}
fn a11y_tree_update(&self, tree_update: accesskit::TreeUpdate) {
let events = {
let mut lock = self.0.lock();
lock.accesskit_adapter
.as_mut()
.and_then(|adapter| adapter.update_if_active(|| tree_update))
};
if let Some(events) = events {
events.raise();
}
}
fn a11y_update_window_bounds(&self) {
// macOS handles window bounds tracking automatically via NSAccessibility.
}
}
struct A11yActivationHandler {
callback: Box<dyn Fn() -> Option<accesskit::TreeUpdate> + Send + 'static>,
}
impl accesskit::ActivationHandler for A11yActivationHandler {
fn request_initial_tree(&mut self) -> Option<accesskit::TreeUpdate> {
(self.callback)()
}
}
struct A11yActionHandler(Box<dyn Fn(accesskit::ActionRequest) + Send + 'static>);
impl accesskit::ActionHandler for A11yActionHandler {
fn do_action(&mut self, request: accesskit::ActionRequest) {
(self.0)(request);
}
}
impl rwh::HasWindowHandle for MacWindow {
@ -2341,6 +2396,16 @@ extern "C" fn window_did_change_key_status(this: &Object, selector: Sel, _: id)
let executor = lock.foreground_executor.clone();
drop(lock);
let a11y_events = {
let mut lock = window_state.lock();
lock.accesskit_adapter
.as_mut()
.and_then(|adapter| adapter.update_view_focus_state(is_active))
};
if let Some(events) = a11y_events {
events.raise();
}
// When a window becomes active, trigger an immediate synchronous frame request to prevent
// tab flicker when switching between windows in native tabs mode.
//

View file

@ -25,6 +25,8 @@ win-legacy-compat = []
gpui.workspace = true
[target.'cfg(target_os = "windows")'.dependencies]
accesskit.workspace = true
accesskit_windows.workspace = true
anyhow.workspace = true
collections.workspace = true
etagere = "0.2"

View file

@ -112,6 +112,7 @@ impl WindowsWindowInner {
WM_GPUI_FORCE_UPDATE_WINDOW => self.draw_window(handle, true),
WM_GPUI_GPU_DEVICE_LOST => self.handle_device_lost(lparam),
DM_POINTERHITTEST => self.handle_dm_pointer_hit_test(wparam),
WM_GETOBJECT => self.handle_wm_getobject(wparam, lparam),
_ => None,
};
if let Some(n) = handled {
@ -728,6 +729,17 @@ impl WindowsWindowInner {
fn handle_activate_msg(self: &Rc<Self>, wparam: WPARAM) -> Option<isize> {
let activated = wparam.loword() > 0;
let events = self
.state
.a11y
.try_borrow_mut()
.ok()
.and_then(|mut a11y| a11y.as_mut()?.adapter.update_window_focus_state(activated));
if let Some(events) = events {
events.raise();
}
let this = self.clone();
if !activated {
@ -764,6 +776,23 @@ impl WindowsWindowInner {
None
}
fn handle_wm_getobject(&self, wparam: WPARAM, lparam: LPARAM) -> Option<isize> {
let result = {
let mut a11y = self.state.a11y.borrow_mut();
let a11y = a11y.as_mut()?;
a11y.adapter.handle_wm_getobject(
accesskit_windows::WPARAM(wparam.0),
accesskit_windows::LPARAM(lparam.0),
&mut a11y.activation_handler,
)?
};
// The borrow above must be dropped before calling `.into()`, because
// it calls `UiaReturnRawElementProvider` which may send a nested
// `WM_GETOBJECT` back into this window procedure.
let lresult: accesskit_windows::LRESULT = result.into();
Some(lresult.0)
}
fn handle_create_msg(&self, handle: HWND) -> Option<isize> {
if self.hide_title_bar {
notify_frame_changed(handle);

View file

@ -83,6 +83,7 @@ pub struct WindowsWindowState {
fullscreen: Cell<Option<StyleAndBounds>>,
initial_placement: Cell<Option<WindowOpenStatus>>,
hwnd: HWND,
pub(crate) a11y: RefCell<Option<A11yState>>,
}
pub(crate) struct WindowsWindowInner {
@ -176,6 +177,7 @@ impl WindowsWindowState {
hwnd,
invalidate_devices,
direct_manipulation,
a11y: RefCell::new(None),
})
}
@ -972,6 +974,69 @@ impl PlatformWindow for WindowsWindow {
// MB_OK: The sound specified as the Windows Default Beep sound.
let _ = unsafe { MessageBeep(MB_OK) };
}
fn a11y_init(&self, callbacks: gpui::A11yCallbacks) {
let action_handler = A11yActionHandler(callbacks.action);
let is_focused = unsafe { GetForegroundWindow() } == self.0.hwnd;
let adapter = accesskit_windows::Adapter::new(
accesskit_windows::HWND(self.0.hwnd.0),
is_focused,
action_handler,
);
let activation_handler = A11yActivationHandler {
callback: callbacks.activation,
};
*self.state.a11y.borrow_mut() = Some(A11yState {
adapter,
activation_handler,
});
}
fn a11y_tree_update(&self, tree_update: accesskit::TreeUpdate) {
let events = {
let mut a11y = self.state.a11y.borrow_mut();
a11y.as_mut()
.and_then(|a11y| a11y.adapter.update_if_active(|| tree_update))
};
// The borrow must be dropped before raising events, because
// `events.raise()` calls `UiaRaiseAutomationPropertyChangedEvent`
// which may send a nested `WM_GETOBJECT` back into this window
// procedure, re-entering `handle_wm_getobject` which also borrows
// `self.state.a11y`.
if let Some(events) = events {
events.raise();
}
}
fn a11y_update_window_bounds(&self) {
// Windows UIA handles window bounds tracking automatically.
}
}
pub(crate) struct A11yState {
pub(crate) adapter: accesskit_windows::Adapter,
pub(crate) activation_handler: A11yActivationHandler,
}
pub(crate) struct A11yActivationHandler {
callback: Box<dyn Fn() -> Option<accesskit::TreeUpdate> + Send + 'static>,
}
impl accesskit::ActivationHandler for A11yActivationHandler {
fn request_initial_tree(&mut self) -> Option<accesskit::TreeUpdate> {
(self.callback)()
}
}
struct A11yActionHandler(Box<dyn Fn(accesskit::ActionRequest) + Send + 'static>);
impl accesskit::ActionHandler for A11yActionHandler {
fn do_action(&mut self, request: accesskit::ActionRequest) {
(self.0)(request);
}
}
#[implement(IDropTarget)]

View file

@ -14,16 +14,19 @@ use language_model::{
ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME,
};
use language_models_cloud::{CloudLlmTokenProvider, CloudModelProvider};
use rand::{Rng as _, SeedableRng as _, rngs::StdRng};
use release_channel::AppVersion;
use settings::SettingsStore;
pub use settings::ZedDotDevAvailableModel as AvailableModel;
pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
use std::sync::Arc;
use std::time::Duration;
use ui::{TintColor, prelude::*};
const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME;
const MODELS_REFRESH_DEBOUNCE: Duration = Duration::from_secs(5 * 60);
struct ClientTokenProvider {
client: Arc<Client>,
@ -84,10 +87,12 @@ pub struct State {
user_store: Entity<UserStore>,
status: client::Status,
provider: Entity<CloudModelProvider<ClientTokenProvider>>,
pending_models_refresh: Option<Task<()>>,
_user_store_subscription: Subscription,
_settings_subscription: Subscription,
_llm_token_subscription: Subscription,
_provider_subscription: Subscription,
_cloud_reconnect_task: Task<()>,
}
impl State {
@ -112,10 +117,32 @@ impl State {
)
});
let cloud_reconnect_task = cx.spawn({
let client = client.clone();
async move |this, cx| {
let mut connection_id_rx = client.cloud_connection_id();
while let Some(connection_id) = connection_id_rx.next().await {
// The initial value `0` means no connection has been
// established since this `Client` was created; only real
// reconnects trigger a refresh.
if connection_id == 0 {
continue;
}
if this
.update(cx, |this, cx| this.schedule_debounced_models_refresh(cx))
.is_err()
{
break;
}
}
}
});
Self {
client: client.clone(),
user_store: user_store.clone(),
status,
pending_models_refresh: None,
_provider_subscription: cx.observe(&provider, |_, _, cx| cx.notify()),
provider,
_user_store_subscription: cx.subscribe(
@ -141,6 +168,7 @@ impl State {
this.refresh_models(cx);
},
),
_cloud_reconnect_task: cloud_reconnect_task,
}
}
@ -167,6 +195,24 @@ impl State {
provider.refresh_models(cx).detach_and_log_err(cx);
});
}
/// Schedules a model list refresh, replacing any previously scheduled
/// refresh.
fn schedule_debounced_models_refresh(&mut self, cx: &mut Context<Self>) {
self.pending_models_refresh = Some(cx.spawn(async move |this, cx| {
#[cfg(any(test, feature = "test-support"))]
let mut rng = StdRng::seed_from_u64(0);
#[cfg(not(any(test, feature = "test-support")))]
let mut rng = StdRng::from_os_rng();
let jitter = Duration::from_millis(
rng.random_range(0..MODELS_REFRESH_DEBOUNCE.as_millis() as u64),
);
cx.background_executor()
.timer(MODELS_REFRESH_DEBOUNCE + jitter)
.await;
this.update(cx, |this, cx| this.refresh_models(cx)).ok();
}));
}
}
impl CloudLanguageModelProvider {

View file

@ -29,7 +29,7 @@ language.workspace = true
linkify.workspace = true
log.workspace = true
markup5ever_rcdom.workspace = true
mermaid-rs-renderer.workspace = true
mermaid_render = { path = "../mermaid_render" }
pulldown-cmark.workspace = true
settings.workspace = true
stacksafe.workspace = true

View file

@ -34,8 +34,8 @@ use gpui::{
FocusHandle, Focusable, FontStyle, FontWeight, GlobalElementId, Hitbox, Hsla, Image,
ImageFormat, ImageSource, KeyContext, Length, MouseButton, MouseDownEvent, MouseEvent,
MouseMoveEvent, MouseUpEvent, Point, ScrollHandle, Stateful, StrikethroughStyle,
StyleRefinement, StyledImage, StyledText, Task, TextAlign, TextLayout, TextRun, TextStyle,
TextStyleRefinement, actions, img, point, quad,
StyleRefinement, StyledImage, StyledText, Subscription, Task, TextAlign, TextLayout, TextRun,
TextStyle, TextStyleRefinement, actions, img, point, quad,
};
use language::{CharClassifier, Language, LanguageRegistry, Rope};
use parser::CodeBlockMetadata;
@ -333,6 +333,7 @@ pub struct Markdown {
fallback_code_block_language: Option<LanguageName>,
options: MarkdownOptions,
mermaid_state: MermaidState,
_mermaid_theme_subscription: Option<Subscription>,
mermaid_showing_code: HashSet<usize>,
copied_code_blocks: HashSet<ElementId>,
wrapped_code_blocks: HashSet<usize>,
@ -497,6 +498,16 @@ impl Markdown {
cx: &mut Context<Self>,
) -> Self {
let focus_handle = cx.focus_handle();
let theme_subscription = if options.render_mermaid_diagrams {
Some(
cx.observe_global::<theme::GlobalTheme>(|this: &mut Self, cx| {
this.invalidate_mermaid_cache(cx);
}),
)
} else {
None
};
let mut this = Self {
source,
selection: Selection::default(),
@ -513,6 +524,7 @@ impl Markdown {
fallback_code_block_language,
options,
mermaid_state: MermaidState::default(),
_mermaid_theme_subscription: theme_subscription,
mermaid_showing_code: HashSet::default(),
copied_code_blocks: HashSet::default(),
wrapped_code_blocks: HashSet::default(),
@ -561,15 +573,15 @@ impl Markdown {
.retain(|id, _| ids.contains(id));
}
/// Used in the agent panel to force a re-render when the theme changes
pub fn invalidate_mermaid_cache(&mut self, cx: &mut Context<Self>) {
if self.options.render_mermaid_diagrams && !self.parsed_markdown.mermaid_diagrams.is_empty()
if !self.options.render_mermaid_diagrams || self.parsed_markdown.mermaid_diagrams.is_empty()
{
self.mermaid_state.clear();
let parsed_markdown = self.parsed_markdown.clone();
self.mermaid_state.update(&parsed_markdown, cx);
cx.notify();
return;
}
self.mermaid_state.clear();
self.mermaid_state.update(&self.parsed_markdown, cx);
cx.notify();
}
pub(crate) fn is_mermaid_showing_code(&self, source_offset: usize) -> bool {

View file

@ -1,7 +1,7 @@
use collections::HashMap;
use gpui::{
Animation, AnimationExt, AnyElement, ClickEvent, ClipboardItem, Context, Entity, Hsla,
ImageSource, RenderImage, Rgba, StyledText, Task, img, pulsating_between,
Animation, AnimationExt, AnyElement, ClickEvent, ClipboardItem, Context, Entity, ImageSource,
RenderImage, StyledText, Task, img, pulsating_between,
};
use std::collections::BTreeMap;
use std::ops::Range;
@ -104,18 +104,12 @@ impl CachedMermaidDiagram {
let render_image_clone = render_image.clone();
let svg_renderer = cx.svg_renderer();
let mermaid_theme = build_mermaid_theme(cx);
let accent_classdefs = build_accent_classdefs(cx);
let task = cx.spawn(async move |this, cx| {
let value = cx
.background_spawn(async move {
let options = mermaid_rs_renderer::RenderOptions {
theme: mermaid_theme,
layout: mermaid_rs_renderer::LayoutConfig::default(),
};
let full_source = format!("{}\n{}", contents.contents, accent_classdefs);
let svg_string =
mermaid_rs_renderer::render_with_options(&full_source, options)?;
mermaid_render::render_to_svg(&contents.contents, &mermaid_theme)?;
let scale = contents.scale as f32 / 100.0;
svg_renderer
.render_single_frame(svg_string.as_bytes(), scale)
@ -153,128 +147,71 @@ impl CachedMermaidDiagram {
}
}
/// Converts an HSLA color to a CSS hex string (e.g. `#1a2b3c`).
fn hsla_to_hex(color: Hsla) -> String {
let rgba: Rgba = color.to_rgb();
let r = (rgba.r * 255.0).round() as u8;
let g = (rgba.g * 255.0).round() as u8;
let b = (rgba.b * 255.0).round() as u8;
format!("#{r:02x}{g:02x}{b:02x}")
/// Merman has somewhat limited text measurement capabilities.
///
/// When it doesn't have metrics for any of the specified fonts, it chooses a
/// fairly narrow width, which causes visible overflow. Adding `sans-serif`
/// allows it to fall back to a more conservative (i.e. wider) measurement.
///
/// This isn't perfect - very wide fonts will likely still cause overflow. A
/// proper fix would involve somehow piping `resvg`'s actual measurements into
/// `merman`, but that is a lot of work for a fairly uncommon edge case.
fn mermaid_font_family(font_family: &str) -> String {
let font_family = gpui::font_name_with_fallbacks(font_family, "system-ui");
if font_family
.split(',')
.any(|family| family.trim().eq_ignore_ascii_case("sans-serif"))
{
font_family.to_string()
} else {
format!("{font_family}, sans-serif")
}
}
fn mermaid_font_family(font_family: &str) -> &str {
gpui::font_name_with_fallbacks(font_family, "system-ui")
}
fn build_mermaid_theme(cx: &Context<Markdown>) -> mermaid_rs_renderer::Theme {
fn build_mermaid_theme(cx: &Context<Markdown>) -> mermaid_render::MermaidTheme {
let colors = cx.theme().colors();
let theme_settings = ThemeSettings::get_global(cx);
let mut theme = mermaid_rs_renderer::Theme::modern();
theme.font_family = mermaid_font_family(theme_settings.ui_font.family.as_ref()).to_string();
theme.background = hsla_to_hex(colors.editor_background);
theme.primary_color = hsla_to_hex(colors.surface_background);
theme.primary_text_color = hsla_to_hex(colors.text);
theme.primary_border_color = hsla_to_hex(colors.border);
theme.line_color = hsla_to_hex(colors.border);
theme.secondary_color = hsla_to_hex(colors.element_background);
theme.tertiary_color = hsla_to_hex(colors.ghost_element_hover);
theme.edge_label_background = hsla_to_hex(colors.editor_background);
theme.cluster_background = hsla_to_hex(colors.panel_background);
theme.cluster_border = hsla_to_hex(colors.border_variant);
theme.text_color = hsla_to_hex(colors.text);
let accents = cx.theme().accents();
let pie_colors: [String; 12] =
std::array::from_fn(|i| hsla_to_hex(accents.color_for_index(i as u32)));
theme.pie_colors = pie_colors;
theme.pie_title_text_color = hsla_to_hex(colors.text);
theme.pie_section_text_color = "#fff".to_string();
theme.pie_legend_text_color = hsla_to_hex(colors.text);
theme.pie_stroke_color = hsla_to_hex(colors.border);
theme.pie_outer_stroke_color = hsla_to_hex(colors.border);
theme.sequence_actor_fill = hsla_to_hex(colors.element_background);
theme.sequence_actor_border = hsla_to_hex(colors.border);
theme.sequence_actor_line = hsla_to_hex(colors.border);
theme.sequence_note_fill = hsla_to_hex(colors.surface_background);
theme.sequence_note_border = hsla_to_hex(colors.border_variant);
theme.sequence_activation_fill = hsla_to_hex(colors.ghost_element_hover);
theme.sequence_activation_border = hsla_to_hex(colors.border);
let is_dark = !cx.theme().appearance.is_light();
let players = cx.theme().players();
theme.git_colors = std::array::from_fn(|i| hsla_to_hex(players.0[i % players.0.len()].cursor));
theme.git_inv_colors =
std::array::from_fn(|i| hsla_to_hex(players.0[i % players.0.len()].background));
theme.git_branch_label_colors = std::array::from_fn(|_| "#fff".to_string());
theme.git_commit_label_color = hsla_to_hex(colors.text);
theme.git_commit_label_background = hsla_to_hex(colors.element_background);
theme.git_tag_label_color = hsla_to_hex(colors.text);
theme.git_tag_label_background = hsla_to_hex(colors.element_background);
theme.git_tag_label_border = hsla_to_hex(colors.border);
let git_branch_colors = std::array::from_fn(|i| players.0[i % players.0.len()].cursor);
let git_branch_label_colors = git_branch_colors.map(mermaid_render::text_color_for_background);
theme
}
fn build_accent_classdefs(cx: &Context<Markdown>) -> String {
use std::fmt::Write;
let players = &cx.theme().players();
let is_light = cx.theme().appearance.is_light();
let mut defs = String::new();
for (i, player) in players.0.iter().enumerate() {
let (fill, text_color) = accent_fill_and_text(player.background, is_light);
let fill = hsla_to_hex(fill);
let stroke = hsla_to_hex(player.cursor);
let text_color = hsla_to_hex(text_color);
writeln!(
defs,
"classDef accent{i} fill:{fill},stroke:{stroke},color:{text_color}"
)
.ok();
mermaid_render::MermaidTheme {
dark_mode: is_dark,
font_family: mermaid_font_family(theme_settings.ui_font.family.as_ref()),
background: colors.editor_background,
primary_color: colors.surface_background,
primary_text_color: colors.text,
primary_border_color: colors.border,
secondary_color: colors.element_background,
tertiary_color: colors.ghost_element_hover,
line_color: colors.border,
text_color: colors.text,
edge_label_background: colors.editor_background,
cluster_background: colors.panel_background,
cluster_border: colors.border_variant,
note_background: colors.surface_background,
note_border: colors.border_variant,
actor_background: colors.element_background,
actor_border: colors.border,
activation_background: colors.ghost_element_hover,
activation_border: colors.border,
git_branch_colors,
git_branch_label_colors,
er_attr_bg_odd: colors.surface_background,
er_attr_bg_even: colors.element_background,
error_color: cx.theme().status().error,
warning_color: cx.theme().status().warning,
accent_colors: players
.0
.iter()
.map(|player| mermaid_render::AccentColor {
foreground: player.cursor,
background: player.background,
})
.collect(),
}
defs
}
/// Adjusts an accent fill color to ensure readable text contrast.
///
/// On dark themes, darkens the fill and uses white text.
/// On light themes, lightens the fill and uses black text.
/// The fill is adjusted until it meets a minimum WCAG contrast ratio
/// of ~4.5:1 against the chosen text color.
fn accent_fill_and_text(color: Hsla, is_light: bool) -> (Hsla, Hsla) {
let mut fill = color;
if is_light {
// Lighten fill until luminance is high enough for black text.
// Target: relative luminance >= 0.35 → contrast ratio ~8:1 with black.
for _ in 0..50 {
if relative_luminance(fill) >= 0.35 {
break;
}
fill.l = (fill.l + 0.02).min(1.0);
}
(fill, gpui::black())
} else {
// Darken fill until luminance is low enough for white text.
// Target: relative luminance <= 0.18 → contrast ratio ~4.6:1 with white.
for _ in 0..50 {
if relative_luminance(fill) <= 0.18 {
break;
}
fill.l = (fill.l - 0.02).max(0.0);
}
(fill, gpui::white())
}
}
fn relative_luminance(color: Hsla) -> f32 {
let rgba: Rgba = color.to_rgb();
fn linearize(c: f32) -> f32 {
if c <= 0.04045 {
c / 12.92
} else {
((c + 0.055) / 1.055).powf(2.4)
}
}
0.2126 * linearize(rgba.r) + 0.7152 * linearize(rgba.g) + 0.0722 * linearize(rgba.b)
}
fn parse_mermaid_info(info: &str) -> Option<u32> {
@ -292,6 +229,38 @@ fn parse_mermaid_info(info: &str) -> Option<u32> {
)
}
/// We deliberately block rendering of some diagram types, even though `merman`
/// supports them, because we have not yet written custom CSS to ensure text is
/// readable.
fn is_supported_diagram_type(source: &str) -> bool {
/// If updating this list, also update the system prompt!
const SUPPORTED_PREFIXES: &[&str] = &[
"flowchart",
"graph",
"sequenceDiagram",
"classDiagram",
"stateDiagram",
"stateDiagram-v2",
"erDiagram",
"gantt",
"pie",
"gitGraph",
"mindmap",
"timeline",
"quadrantChart",
"xychart-beta",
"journey",
];
let first_token = source
.trim_start()
.split(|c: char| c.is_whitespace() || c == '\n')
.next()
.unwrap_or("");
SUPPORTED_PREFIXES
.iter()
.any(|prefix| first_token.eq_ignore_ascii_case(prefix))
}
pub(crate) fn extract_mermaid_diagrams(
source: &str,
events: &[(Range<usize>, MarkdownEvent)],
@ -324,6 +293,9 @@ pub(crate) fn extract_mermaid_diagrams(
.strip_suffix('\n')
.unwrap_or(&source[metadata.content_range.clone()])
.to_string();
if !is_supported_diagram_type(&contents) {
continue;
}
mermaid_diagrams.insert(
source_range.start,
ParsedMarkdownMermaidDiagram {
@ -588,24 +560,10 @@ mod tests {
MarkdownStyle, WrapButtonVisibility,
};
use collections::HashMap;
use gpui::{Context, Hsla, IntoElement, Render, RenderImage, TestAppContext, Window, size};
use gpui::{Context, IntoElement, Render, RenderImage, TestAppContext, Window, size};
use std::sync::Arc;
use ui::prelude::*;
#[gpui::property_test]
fn accent_fill_and_text_sufficient_contrast(
#[strategy = Hsla::opaque_strategy()] color: Hsla,
light_mode: bool,
) {
let (fill, text) = super::accent_fill_and_text(color, light_mode);
let fill_luminance = super::relative_luminance(fill);
let text_luminance = super::relative_luminance(text);
let lighter = fill_luminance.max(text_luminance);
let darker = fill_luminance.min(text_luminance);
let contrast_ratio = (lighter + 0.05) / (darker + 0.05);
assert!(contrast_ratio >= 4.5,);
}
fn ensure_theme_initialized(cx: &mut TestAppContext) {
cx.update(|cx| {
if !cx.has_global::<settings::SettingsStore>() {
@ -693,11 +651,27 @@ mod tests {
#[test]
fn test_mermaid_font_family_resolves_zed_virtual_fonts() {
assert_eq!(super::mermaid_font_family(".ZedSans"), "IBM Plex Sans");
assert_eq!(super::mermaid_font_family("Zed Plex Sans"), "IBM Plex Sans");
assert_eq!(super::mermaid_font_family(".ZedMono"), "Lilex");
assert_eq!(super::mermaid_font_family(".SystemUIFont"), "system-ui");
assert_eq!(super::mermaid_font_family("Custom Font"), "Custom Font");
assert_eq!(
super::mermaid_font_family(".ZedSans"),
"IBM Plex Sans, sans-serif"
);
assert_eq!(
super::mermaid_font_family("Zed Plex Sans"),
"IBM Plex Sans, sans-serif"
);
assert_eq!(super::mermaid_font_family(".ZedMono"), "Lilex, sans-serif");
assert_eq!(
super::mermaid_font_family(".SystemUIFont"),
"system-ui, sans-serif"
);
assert_eq!(
super::mermaid_font_family("Custom Font"),
"Custom Font, sans-serif"
);
assert_eq!(
super::mermaid_font_family("Custom Font, sans-serif"),
"Custom Font, sans-serif"
);
}
#[test]
@ -721,6 +695,27 @@ mod tests {
assert_eq!(diagram.contents.scale, 150);
}
#[test]
fn test_unsupported_diagram_types_are_skipped() {
let markdown = concat!(
"```mermaid\nsankey-beta\n```\n\n",
"```mermaid\nblock-beta\n```\n\n",
"```mermaid\nflowchart TD\n A --> B\n```",
);
let events = crate::parser::parse_markdown_with_options(markdown, false, false).events;
let diagrams = extract_mermaid_diagrams(markdown, &events);
assert_eq!(
diagrams.len(),
1,
"Only the flowchart should be extracted; sankey and block should be skipped"
);
let diagram = diagrams.values().next().unwrap();
assert!(
diagram.contents.contents.contains("flowchart"),
"The extracted diagram should be the flowchart"
);
}
#[gpui::test]
fn test_mermaid_fallback_on_edit(cx: &mut TestAppContext) {
let old_full_order = mermaid_sequence(&["graph A", "graph B", "graph C"]);

View file

@ -0,0 +1,27 @@
[package]
name = "mermaid_render"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/mermaid_render.rs"
doctest = false
[features]
test-support = []
[dependencies]
anyhow.workspace = true
gpui.workspace = true
merman = { version = "0.4", features = ["render"] }
quick-xml.workspace = true
serde_json.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }
mermaid_render = { path = ".", features = ["test-support"] }

View file

@ -0,0 +1 @@
../../LICENSE-GPL

View file

@ -0,0 +1,181 @@
// for a very big json! macro
#![recursion_limit = "256"]
//! Crate for rendering Mermaid diagram strings to SVG strings.
//!
//! The entrypoint to this crate is [`render_to_svg`].
//!
//! It takes a `&str` and a [`MermaidTheme`]. The output is an SVG with the
//! following properties:
//! - The style matches the provided theme
//! - Nodes are given accent colors, even if none are provided in the mermaid
//! source.
//! - The SVG has been tweaked based on the assumption that it will be rasterized
//! using `usvg`/`resvg`. Some bugs/quirks of `usvg`/`resvg` are accounted for
//! in this crate.
//!
//! This module uses the [`merman`] crate for rendering, rather than
//! `mermaid-rs`, which was used in the previous implementation of mermaid
//! rendering in Zed. Merman provides significantly more accurate rendering, and
//! seems to be somewhat faster, but by default has poor CSS, making diagrams
//! look weird without significant cleanup. This is made worse by the fact that
//! `usvg`/`resvg` doesn't support some features that [`merman`] relies on.
//!
//! As such, this crate is quite large. But the code is very self-contained, and
//! has few dependencies. In fact, the [`gpui`] dependency is only needed for
//! the [`Hsla`] and [`Rgba`] color types.
//!
//! The [`render_to_svg`] function operates in two stages:
//! - [`render`] the mermaid text to SVG using [`merman`].
//! - [`postprocess`] the SVG to clean incorrect output and add styling.
//!
//! The postprocessing is also split up into stages. We parse the generated SVG
//! using [`quick_xml`], which produces an iterator of
//! [`Event<'_>`](quick_xml::events::Event)s. This iterator is then repeatedly
//! transformed, and finally collected back into an SVG string.
//!
//! This approach:
//! - Avoids doing multiple expensive string insertions.
//! - Avoids parsing the SVG multiple times (without needing to put all the
//! logic in one huge function).
//! - But is quite a bit more complex.
//!
//! I think this complexity is justified because of the drastic performance
//! impact, as well as the low-risk nature; this code cannot panic, and errors
//! in the output just produce weird-looking diagrams.
//!
//! ## Color handling
//!
//! We try to match the users theme, and also apply accent colors to diagrams to
//! make them more visually interesting. Accent colors are derived from the
//! `player_colors` in the Zed theme.
//!
//! There are three parts to color handling:
//!
//! 1. A [`merman::MermaidConfig`] is passed when initially rendering the
//! diagram. This sets most "normal" colors (background, text, etc.). However,
//! it's not possible to color nodes individually, and not all parts of the
//! diagrams are correctly themed.
//! 2. `postprocess::accent_colors` injects custom CSS classes (e.g.
//! `zed-accent-0`) to specific elements, based on the diagram type and
//! node.
//! 3. `postprocess::inject_css` injects CSS rules for the classes applied by
//! `accent_colors`
mod postprocess;
mod render;
use anyhow::Result;
use gpui::{Hsla, Rgba};
#[derive(Debug, Clone, Copy)]
pub struct AccentColor {
pub foreground: Hsla,
pub background: Hsla,
}
#[derive(Debug, Clone)]
pub struct MermaidTheme {
pub dark_mode: bool,
pub font_family: String,
pub background: Hsla,
pub primary_color: Hsla,
pub primary_text_color: Hsla,
pub primary_border_color: Hsla,
pub secondary_color: Hsla,
pub tertiary_color: Hsla,
pub line_color: Hsla,
pub text_color: Hsla,
pub edge_label_background: Hsla,
pub cluster_background: Hsla,
pub cluster_border: Hsla,
pub note_background: Hsla,
pub note_border: Hsla,
pub actor_background: Hsla,
pub actor_border: Hsla,
pub activation_background: Hsla,
pub activation_border: Hsla,
pub git_branch_colors: [Hsla; 8],
pub git_branch_label_colors: [Hsla; 8],
pub er_attr_bg_odd: Hsla,
pub er_attr_bg_even: Hsla,
pub error_color: Hsla,
pub warning_color: Hsla,
pub accent_colors: Vec<AccentColor>,
}
/// Default theme for testing.
#[cfg(any(test, feature = "test-support"))]
impl Default for MermaidTheme {
fn default() -> Self {
use gpui::{hsla, rgb};
let git_branch_colors: [Hsla; 8] = [
hsla(240.0 / 360.0, 1.0, 0.462_745_1, 1.0),
hsla(60.0 / 360.0, 1.0, 0.435_294_12, 1.0),
hsla(80.0 / 360.0, 1.0, 0.462_745_1, 1.0),
hsla(210.0 / 360.0, 1.0, 0.462_745_1, 1.0),
hsla(180.0 / 360.0, 1.0, 0.462_745_1, 1.0),
hsla(150.0 / 360.0, 1.0, 0.462_745_1, 1.0),
hsla(300.0 / 360.0, 1.0, 0.462_745_1, 1.0),
hsla(0.0, 1.0, 0.462_745_1, 1.0),
];
let git_branch_label_colors: [Hsla; 8] =
git_branch_colors.map(crate::text_color_for_background);
Self {
dark_mode: false,
font_family: "Inter, ui-sans-serif, system-ui, -apple-system, \"Segoe UI\", \"DejaVu Sans\", \"Liberation Sans\", sans-serif, \"Noto Color Emoji\", \"Apple Color Emoji\", \"Segoe UI Emoji\"".to_string(),
background: rgb(0xFFFFFF).into(),
primary_color: rgb(0xF8FAFC).into(),
primary_text_color: rgb(0x0F172A).into(),
primary_border_color: rgb(0x94A3B8).into(),
secondary_color: rgb(0xE2E8F0).into(),
tertiary_color: rgb(0xFFFFFF).into(),
line_color: rgb(0x64748B).into(),
text_color: rgb(0x0F172A).into(),
edge_label_background: rgb(0xFFFFFF).into(),
cluster_background: rgb(0xF1F5F9).into(),
cluster_border: rgb(0xCBD5E1).into(),
note_background: rgb(0xFFF7ED).into(),
note_border: rgb(0xFDBA74).into(),
actor_background: rgb(0xF8FAFC).into(),
actor_border: rgb(0x94A3B8).into(),
activation_background: rgb(0xE2E8F0).into(),
activation_border: rgb(0x94A3B8).into(),
git_branch_colors,
git_branch_label_colors,
er_attr_bg_odd: rgb(0x94A3B8).into(),
er_attr_bg_even: rgb(0x0F172A).into(),
error_color: rgb(0xDC2626).into(),
warning_color: rgb(0xD97706).into(),
accent_colors: Vec::new(),
}
}
}
/// Formats a color as a CSS hex color for embedding in SVG/CSS.
///
/// Emits `#rrggbb` for fully opaque colors and `#rrggbbaa` when the input
/// has any transparency, so translucent theme colors (e.g. `ghost_element_hover`
/// from Zed's UI palette) round-trip without silently losing their alpha.
pub(crate) fn css_color(color: Hsla) -> String {
let rgba = Rgba::from(color);
let r = (rgba.r.clamp(0.0, 1.0) * 255.0).round() as u8;
let g = (rgba.g.clamp(0.0, 1.0) * 255.0).round() as u8;
let b = (rgba.b.clamp(0.0, 1.0) * 255.0).round() as u8;
let a = (rgba.a.clamp(0.0, 1.0) * 255.0).round() as u8;
if a == 0xff {
format!("#{r:02x}{g:02x}{b:02x}")
} else {
format!("#{r:02x}{g:02x}{b:02x}{a:02x}")
}
}
pub use postprocess::util::text_color_for_background;
/// See the [module-level docs][crate] for more info.
pub fn render_to_svg(source: &str, theme: &MermaidTheme) -> Result<String> {
let svg = render::render_mermaid(source, theme)?;
let svg = postprocess::postprocess(&svg, theme)?;
Ok(svg)
}

View file

@ -0,0 +1,136 @@
//! Post-processing of [`merman`]-produced SVGs for rasterization with `usvg`/`resvg`.
//!
//! Each submodule is a specific pass that tweaks the SVG event iterator in a particular way.
//!
//! We always produce and consume [`Event`]s with a short lifetime.
//! [`Event<'a>`] is backed internally by a [`Cow<'a, [u8]>`](std::borrow::Cow),
//! so we don't have lifetime issues when we need to mutate the text in an
//! [`Event`], but also don't force allocating a new [`String`] each time.
//!
//! Many modules contain internal structs that implement [`Iterator`] to make
//! reasoning about lifetimes simpler, but these are private implementation
//! details.
mod accent_colors;
mod element_fixup;
mod fallback_fixup;
mod foreignobject_wrap;
mod inject_css;
mod strip_foreignobject;
mod strip_invalid_css;
pub(crate) mod util;
use anyhow::{Context as _, Result};
use quick_xml::Reader;
use quick_xml::events::Event;
use crate::MermaidTheme;
pub(super) fn postprocess(svg: &str, theme: &MermaidTheme) -> Result<String> {
// Pass 1: foreignObject preparation (\n fix + word wrapping)
let svg = foreignobject_wrap::process(svg)?;
// Add <text> fallbacks alongside <foreignObject> elements
let svg = merman::render::foreign_object_label_fallback_svg_text(&svg);
// Extract SVG id for CSS scoping (quick scan of the first element)
let svg_id = extract_svg_id(&svg);
// Pass 2: themed post-processing pipeline.
// Each adapter takes an iterator of events and returns an iterator of events.
// Events borrow from the `svg` string — no .into_owned() per event.
let mut reader = Reader::from_str(&svg);
reader.config_mut().check_end_names = false;
let events = ReaderIter::new(reader);
let events = strip_foreignobject::process(events);
let events = fallback_fixup::process(events, theme);
let events = element_fixup::process(events, theme);
let events = accent_colors::process(events, theme);
let events = strip_invalid_css::process(events);
let events = inject_css::process(events, theme, &svg_id);
let mut writer = quick_xml::Writer::new(Vec::with_capacity(svg.len()));
for event in events {
writer.write_event(event?)?;
}
String::from_utf8(writer.into_inner()).context("SVG output is not valid UTF-8")
}
fn extract_svg_id(svg: &str) -> String {
let mut reader = Reader::from_str(svg);
reader.config_mut().check_end_names = false;
for event in ReaderIter::new(reader) {
let Ok(Event::Start(e) | Event::Empty(e)) = event else {
continue;
};
if e.name().as_ref() == b"svg" {
return e
.try_get_attribute("id")
.ok()
.flatten()
.and_then(|a| a.unescape_value().ok())
.map(|v| v.into_owned())
.unwrap_or_default();
}
}
String::new()
}
struct ReaderIter<'a> {
reader: Reader<&'a [u8]>,
done: bool,
}
impl<'a> ReaderIter<'a> {
fn new(reader: Reader<&'a [u8]>) -> Self {
Self {
reader,
done: false,
}
}
}
impl<'a> Iterator for ReaderIter<'a> {
type Item = Result<Event<'a>>;
fn next(&mut self) -> Option<Self::Item> {
if self.done {
return None;
}
match self.reader.read_event() {
Ok(Event::Eof) => {
self.done = true;
None
}
Ok(event) => Some(Ok(event)),
Err(e) => {
self.done = true;
Some(Err(e.into()))
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn default_theme() -> MermaidTheme {
MermaidTheme::default()
}
#[test]
fn strip_css_handles_style_element_with_attributes() {
let svg = r#"<svg id="test" xmlns="http://www.w3.org/2000/svg"><style type="text/css">@keyframes bounce { 0% { transform: scale(1); } 100% { transform: scale(1.1); } } .node rect { fill: red; }</style><rect width="10" height="10"/></svg>"#;
let result = postprocess(svg, &default_theme()).unwrap();
assert!(
!result.contains("@keyframes"),
"Unsupported @keyframes should be stripped from <style type=\"text/css\">, got: {result}"
);
assert!(
result.contains(".node rect"),
"Regular CSS rules should survive stripping, got: {result}"
);
}
}

View file

@ -0,0 +1,375 @@
//! Injects CSS classes to set accent colors. Mermaid broadly speaking does not
//! provide a mechanism to color individual nodes. A few diagram types support
//! `:::my-css-class` on nodes, but most don't.
//!
//! [`inject_css`](super::inject_css) then injects CSS rules for the classes
//! that this module injects.
mod class_diagram;
mod mindmap;
mod sequence_diagram;
use anyhow::Result;
use quick_xml::events::{BytesStart, Event};
use crate::MermaidTheme;
pub(crate) struct NodeRect {
pub cx: f64,
pub cy: f64,
pub half_height: f64,
pub accent_idx: usize,
}
#[derive(Default)]
pub(crate) struct NodeTracker {
rects: Vec<NodeRect>,
building: Option<NodeRect>,
}
impl NodeTracker {
pub fn start_node(&mut self, cx: f64, cy: f64, half_height: f64, accent_idx: usize) {
self.building = Some(NodeRect {
cx,
cy,
half_height,
accent_idx,
});
}
pub fn finish_node(&mut self) {
debug_assert!(self.building.is_some());
if let Some(rect) = self.building.take() {
self.rects.push(rect);
}
}
pub fn update_half_height(&mut self, e: &BytesStart<'_>) {
if let Some(node) = &mut self.building
&& let Some(hh) = parse_path_half_height(e)
&& hh > node.half_height
{
node.half_height = hh;
}
}
pub fn lookup_accent(&self, e: &BytesStart<'_>) -> Option<usize> {
lookup_position_accent(&self.rects, e)
}
}
pub(crate) fn parse_translate(e: &BytesStart<'_>) -> Option<(f64, f64)> {
let attr = e.try_get_attribute("transform").ok()??;
let val = attr.unescape_value().ok()?;
let inner = val.strip_prefix("translate(")?.strip_suffix(')')?;
let (x_str, y_str) = inner.split_once(',')?;
Some((x_str.trim().parse().ok()?, y_str.trim().parse().ok()?))
}
pub(crate) fn parse_path_half_height(e: &BytesStart<'_>) -> Option<f64> {
let attr = e.try_get_attribute("d").ok()??;
let d = attr.unescape_value().ok()?;
let rest = d.strip_prefix('M')?.trim_start();
let mut chars = rest.chars().peekable();
while chars.peek().is_some_and(|c| *c != ' ' && *c != ',') {
chars.next();
}
while chars.peek().is_some_and(|c| *c == ' ' || *c == ',') {
chars.next();
}
let y_str: String = chars.take_while(|c| *c != ' ' && *c != ',').collect();
let y: f64 = y_str.parse().ok()?;
Some(y.abs())
}
// These arrays are basically just optimized versions of `format!("zed-accent-{i}")`
const ACCENT_CLASSES: [&str; 8] = [
"zed-accent-0",
"zed-accent-1",
"zed-accent-2",
"zed-accent-3",
"zed-accent-4",
"zed-accent-5",
"zed-accent-6",
"zed-accent-7",
];
const CHART_COLOR_CLASSES: [&str; 8] = [
"zed-chart-0",
"zed-chart-1",
"zed-chart-2",
"zed-chart-3",
"zed-chart-4",
"zed-chart-5",
"zed-chart-6",
"zed-chart-7",
];
pub(crate) fn accent_class_name(index: usize) -> &'static str {
ACCENT_CLASSES[index % ACCENT_CLASSES.len()]
}
fn chart_color_class_name(index: usize) -> &'static str {
CHART_COLOR_CLASSES[index % CHART_COLOR_CLASSES.len()]
}
/// Wraps [`add_class`] and preserves the `Start`/`Empty` variant of the original event.
pub(crate) fn add_to_event<'a>(ev: &Event<'_>, e: &BytesStart<'_>, cl: &str) -> Result<Event<'a>> {
let new_elem = add_class(e, cl)?;
Ok(match ev {
Event::Start(_) => Event::Start(new_elem),
_ => Event::Empty(new_elem),
})
}
/// Adds a CSS class to an element, preserving any existing classes.
pub(crate) fn add_class<'a>(e: &BytesStart<'_>, class_to_add: &str) -> Result<BytesStart<'a>> {
let name = e.name();
let tag = std::str::from_utf8(name.as_ref())?;
let mut new_elem = BytesStart::new(tag.to_owned());
let mut class_found = false;
for attr in e.attributes() {
let attr = attr?;
if attr.key.local_name().as_ref() == b"class" {
let existing = attr.unescape_value()?;
let new_class = format!("{existing} {class_to_add}");
new_elem.push_attribute(("class", new_class.as_str()));
class_found = true;
} else {
new_elem.push_attribute(attr);
}
}
if !class_found {
new_elem.push_attribute(("class", class_to_add));
}
Ok(new_elem)
}
pub(crate) fn current_stack_accent(stack: &[Option<usize>]) -> Option<usize> {
stack.iter().rev().find_map(|entry| *entry)
}
pub(crate) fn lookup_position_accent(node_rects: &[NodeRect], e: &BytesStart<'_>) -> Option<usize> {
let parse_attr = |name| -> Option<f64> {
e.try_get_attribute(name)
.ok()??
.unescape_value()
.ok()?
.parse()
.ok()
};
let x: f64 = parse_attr("x")?;
let y: f64 = parse_attr("y")?;
node_rects.iter().find_map(|rect| {
let in_y = (y - rect.cy).abs() <= rect.half_height + 5.0;
let in_x = (x - rect.cx).abs() <= rect.half_height * 2.0;
(in_x && in_y).then_some(rect.accent_idx)
})
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DiagramType {
Flowchart,
Mindmap,
ClassDiagram,
StateDiagram,
SequenceDiagram,
Unhandled,
}
fn detect_diagram_type(e: &BytesStart<'_>) -> DiagramType {
let class = match e
.try_get_attribute("class")
.ok()
.flatten()
.and_then(|a| a.unescape_value().ok())
{
Some(c) => c,
None => return DiagramType::SequenceDiagram,
};
for token in class.split_whitespace() {
match token {
"flowchart" => return DiagramType::Flowchart,
"mindmap" => return DiagramType::Mindmap,
"classDiagram" => return DiagramType::ClassDiagram,
"statediagram" => return DiagramType::StateDiagram,
"journey" => return DiagramType::Unhandled,
_ => {}
}
}
DiagramType::SequenceDiagram
}
/// Different diagrams require different state when computing accent colors.
enum Handler {
/// Before we have identified the diagram type
Pending,
/// Diagram type doesn't require injecting classes.
Passthrough,
Flowchart(class_diagram::ClassDiagramAccents),
Mindmap(mindmap::MindmapAccents),
ClassDiagram(class_diagram::ClassDiagramAccents),
StateDiagram(class_diagram::ClassDiagramAccents),
Sequence(sequence_diagram::SequenceDiagramAccents),
}
struct AccentColors<I> {
inner: I,
theme: MermaidTheme,
handler: Handler,
in_legend: bool,
legend_color_idx: usize,
in_plot: bool,
plot_depth: usize,
plot_path_done: bool,
pie_color_idx: usize,
quadrant_point_idx: usize,
}
impl<'a, I: Iterator<Item = Result<Event<'a>>>> AccentColors<I> {
fn process_chart_colors(&mut self, event: Event<'a>) -> Result<Event<'a>> {
match &event {
Event::Start(e) | Event::Empty(e) if e.name().as_ref() == b"g" => {
if self.in_plot {
self.plot_depth += 1;
}
if let Some(class_attr) = e.try_get_attribute("class")? {
let class = class_attr.unescape_value()?;
if class.as_ref() == "plot" {
self.in_plot = true;
self.plot_depth = 1;
self.plot_path_done = false;
} else if class.as_ref() == "legend" {
self.in_legend = true;
} else if class.as_ref() == "data-point" {
let accent_count = self.theme.accent_colors.len();
if accent_count > 0 {
let idx = self.quadrant_point_idx % accent_count;
self.quadrant_point_idx += 1;
return add_to_event(&event, e, &accent_class_name(idx));
}
}
}
Ok(event)
}
Event::End(e) if e.name().as_ref() == b"g" => {
if self.in_plot {
self.plot_depth -= 1;
if self.plot_depth == 0 {
self.in_plot = false;
}
}
Ok(event)
}
Event::Start(e) | Event::Empty(e) if e.name().as_ref() == b"rect" => {
if self.in_legend && self.legend_color_idx < 8 {
let class = chart_color_class_name(self.legend_color_idx);
self.legend_color_idx += 1;
self.in_legend = false;
add_to_event(&event, e, &class)
} else if self.in_plot {
add_to_event(&event, e, &chart_color_class_name(0))
} else {
Ok(event)
}
}
Event::Start(e) | Event::Empty(e) if e.name().as_ref() == b"path" => {
let class_val = e
.try_get_attribute("class")?
.map(|a| a.unescape_value())
.transpose()?;
if class_val.as_deref() == Some("pieCircle") {
let class = chart_color_class_name(self.pie_color_idx % 8);
self.pie_color_idx += 1;
add_to_event(&event, e, &class)
} else if self.in_plot
&& !self.plot_path_done
&& e.try_get_attribute("stroke")?.is_some()
{
self.plot_path_done = true;
add_to_event(&event, e, &chart_color_class_name(1))
} else {
Ok(event)
}
}
_ => Ok(event),
}
}
}
impl<'a, I: Iterator<Item = Result<Event<'a>>>> Iterator for AccentColors<I> {
type Item = Result<Event<'a>>;
fn next(&mut self) -> Option<Self::Item> {
let event = match self.inner.next()? {
Ok(ev) => ev,
Err(e) => return Some(Err(e)),
};
if matches!(self.handler, Handler::Pending) {
if let Event::Start(e) | Event::Empty(e) = &event {
if e.name().as_ref() == b"svg" {
let diagram_type = detect_diagram_type(e);
let count = self.theme.accent_colors.len();
self.handler = match diagram_type {
DiagramType::Flowchart => {
Handler::Flowchart(class_diagram::ClassDiagramAccents::new(count))
}
DiagramType::Mindmap => Handler::Mindmap(mindmap::MindmapAccents::new()),
DiagramType::ClassDiagram => {
Handler::ClassDiagram(class_diagram::ClassDiagramAccents::new(count))
}
DiagramType::StateDiagram => {
Handler::StateDiagram(class_diagram::ClassDiagramAccents::new(count))
}
DiagramType::SequenceDiagram => {
Handler::Sequence(sequence_diagram::SequenceDiagramAccents::new(count))
}
DiagramType::Unhandled => Handler::Passthrough,
};
}
}
}
let event = match &mut self.handler {
Handler::Flowchart(h) | Handler::ClassDiagram(h) | Handler::StateDiagram(h) => {
h.process_event(event)
}
Handler::Mindmap(h) => h.process_event(event),
Handler::Sequence(h) => h.process_event(event),
Handler::Passthrough | Handler::Pending => Ok(event),
};
Some(match event {
Ok(event) => self.process_chart_colors(event),
err => err,
})
}
}
pub(super) fn process<'a>(
events: impl Iterator<Item = Result<Event<'a>>>,
theme: &MermaidTheme,
) -> impl Iterator<Item = Result<Event<'a>>> {
AccentColors {
inner: events,
theme: theme.clone(),
handler: Handler::Pending,
in_legend: false,
legend_color_idx: 0,
in_plot: false,
plot_depth: 0,
plot_path_done: false,
pie_color_idx: 0,
quadrant_point_idx: 0,
}
}

View file

@ -0,0 +1,112 @@
use anyhow::Result;
use quick_xml::events::Event;
use super::{NodeTracker, accent_class_name, add_class, add_to_event, parse_translate};
pub(crate) struct ClassDiagramAccents {
accent_count: usize,
accent_g_stack: Vec<Option<usize>>,
node_counter: usize,
nodes: NodeTracker,
current_text_accent: Option<usize>,
}
impl ClassDiagramAccents {
pub(super) fn new(accent_count: usize) -> Self {
Self {
accent_count,
accent_g_stack: Vec::new(),
node_counter: 0,
nodes: NodeTracker::default(),
current_text_accent: None,
}
}
pub(super) fn process_event<'a>(&mut self, event: Event<'a>) -> Result<Event<'a>> {
if self.accent_count == 0 {
return Ok(event);
}
match &event {
Event::Start(e) if e.name().as_ref() == b"g" => {
let is_node = if let Some(class_attr) = e.try_get_attribute("class")? {
let class = class_attr.unescape_value()?;
class
.split_whitespace()
.any(|token| token == "node" || token == "stateGroup")
} else {
false
};
if is_node {
let accent_idx = self.node_counter % self.accent_count;
self.node_counter += 1;
if let Some((cx, cy)) = parse_translate(e) {
self.nodes.start_node(cx, cy, 30.0, accent_idx);
}
self.accent_g_stack.push(Some(accent_idx));
let new_elem = add_class(e, &accent_class_name(accent_idx))?;
return Ok(Event::Start(new_elem));
}
self.accent_g_stack.push(None);
Ok(event)
}
Event::End(e) if e.name().as_ref() == b"g" => {
if let Some(entry) = self.accent_g_stack.pop() {
if entry.is_some() {
self.nodes.finish_node();
}
}
Ok(event)
}
Event::Start(e) | Event::Empty(e)
if matches!(
e.name().as_ref(),
b"rect" | b"path" | b"circle" | b"polygon" | b"ellipse"
) =>
{
if e.name().as_ref() == b"path" {
self.nodes.update_half_height(e);
}
Ok(event)
}
Event::Start(e) | Event::Empty(e)
if e.name().as_ref() == b"text" || e.name().as_ref() == b"tspan" =>
{
let is_start = matches!(event, Event::Start(_));
let is_text = e.name().as_ref() == b"text";
let accent_idx = if is_text {
self.nodes
.lookup_accent(e)
.or_else(|| super::current_stack_accent(&self.accent_g_stack))
} else {
self.current_text_accent
};
if let Some(idx) = accent_idx {
if is_text && is_start {
self.current_text_accent = Some(idx);
}
return add_to_event(&event, e, &accent_class_name(idx));
}
Ok(event)
}
Event::End(e) if e.name().as_ref() == b"text" => {
self.current_text_accent = None;
Ok(event)
}
_ => Ok(event),
}
}
}

View file

@ -0,0 +1,127 @@
use anyhow::Result;
use quick_xml::events::{BytesStart, Event};
use super::NodeTracker;
pub(super) struct MindmapAccents {
section_classes: Vec<String>,
section_g_stack: Vec<Option<usize>>,
nodes: NodeTracker,
current_text_section: Option<usize>,
}
impl MindmapAccents {
pub(super) fn new() -> Self {
Self {
section_classes: Vec::new(),
section_g_stack: Vec::new(),
nodes: NodeTracker::default(),
current_text_section: None,
}
}
pub(super) fn process_event<'a>(&mut self, event: Event<'a>) -> Result<Event<'a>> {
match &event {
Event::Start(e) if e.name().as_ref() == b"g" => {
let section_idx = self.parse_section_class(e)?;
if let Some(idx) = section_idx {
if let Some((tx, ty)) = super::parse_translate(e) {
self.nodes.start_node(tx, ty, 0.0, idx);
}
self.section_g_stack.push(Some(idx));
} else {
self.section_g_stack.push(None);
}
Ok(event)
}
Event::End(e) if e.name().as_ref() == b"g" => {
if let Some(maybe_section) = self.section_g_stack.pop() {
if maybe_section.is_some() {
self.nodes.finish_node();
}
}
Ok(event)
}
Event::Start(e) | Event::Empty(e)
if matches!(
e.name().as_ref(),
b"path" | b"rect" | b"circle" | b"polygon" | b"ellipse"
) =>
{
if e.name().as_ref() == b"path" {
self.nodes.update_half_height(e);
}
Ok(event)
}
Event::Start(e) | Event::Empty(e)
if e.name().as_ref() == b"text" || e.name().as_ref() == b"tspan" =>
{
let section_idx = self.current_section_accent().or_else(|| {
if e.name().as_ref() == b"text" {
self.nodes.lookup_accent(e)
} else {
None
}
});
if e.name().as_ref() == b"text" {
self.current_text_section = section_idx;
}
let idx = section_idx.or(self.current_text_section);
if let Some(idx) = idx {
if let Some(class_name) = self.section_class_name(idx) {
return super::add_to_event(&event, e, class_name);
}
}
Ok(event)
}
Event::End(e) if e.name().as_ref() == b"text" => {
self.current_text_section = None;
Ok(event)
}
_ => Ok(event),
}
}
fn parse_section_class(&mut self, e: &BytesStart<'_>) -> Result<Option<usize>> {
let class_attr = match e.try_get_attribute("class")? {
Some(attr) => attr,
None => return Ok(None),
};
let class = class_attr.unescape_value()?;
let tokens: Vec<&str> = class.split_whitespace().collect();
let is_root = tokens.contains(&"section-root");
for token in &tokens {
if let Some(rest) = token.strip_prefix("section-") {
if rest == "-1" || rest.parse::<u32>().is_ok() {
let class_name = if is_root {
"section-root section--1".to_string()
} else {
format!("section-{rest}")
};
let idx = self.section_classes.len();
self.section_classes.push(class_name);
return Ok(Some(idx));
}
}
}
Ok(None)
}
fn current_section_accent(&self) -> Option<usize> {
super::current_stack_accent(&self.section_g_stack)
}
fn section_class_name(&self, idx: usize) -> Option<&str> {
self.section_classes.get(idx).map(|s| s.as_str())
}
}

View file

@ -0,0 +1,98 @@
use anyhow::Result;
use quick_xml::events::{BytesStart, Event};
use super::{accent_class_name, add_to_event};
pub(super) struct SequenceDiagramAccents {
accent_count: usize,
actor_bottom_counter: usize,
actor_top_counter: usize,
last_actor_accent: Option<usize>,
current_text_accent: Option<usize>,
}
impl SequenceDiagramAccents {
pub(super) fn new(accent_count: usize) -> Self {
Self {
accent_count,
actor_bottom_counter: 0,
actor_top_counter: 0,
last_actor_accent: None,
current_text_accent: None,
}
}
pub(super) fn process_event<'a>(&mut self, event: Event<'a>) -> Result<Event<'a>> {
if self.accent_count == 0 {
return Ok(event);
}
match &event {
Event::Start(e) | Event::Empty(e) if e.name().as_ref() == b"rect" => {
if let Some(idx) = self.check_actor_rect(e)? {
add_to_event(&event, e, &accent_class_name(idx))
} else {
Ok(event)
}
}
Event::Start(e) | Event::Empty(e) if e.name().as_ref() == b"text" => {
if let Some(idx) = self.check_actor_text(e)? {
self.current_text_accent = Some(idx);
add_to_event(&event, e, &accent_class_name(idx))
} else {
Ok(event)
}
}
Event::Start(e) | Event::Empty(e) if e.name().as_ref() == b"tspan" => {
if let Some(idx) = self.current_text_accent {
add_to_event(&event, e, &accent_class_name(idx))
} else {
Ok(event)
}
}
Event::End(e) if e.name().as_ref() == b"text" => {
self.current_text_accent = None;
Ok(event)
}
_ => Ok(event),
}
}
fn check_actor_rect(&mut self, e: &BytesStart<'_>) -> Result<Option<usize>> {
let class_attr = match e.try_get_attribute("class")? {
Some(a) => a,
None => return Ok(None),
};
let class_val = class_attr.unescape_value()?;
if class_val.contains("actor-bottom") {
let idx = self.actor_bottom_counter % self.accent_count;
self.actor_bottom_counter += 1;
self.last_actor_accent = Some(idx);
Ok(Some(idx))
} else if class_val.contains("actor-top") {
let idx = self.actor_top_counter % self.accent_count;
self.actor_top_counter += 1;
self.last_actor_accent = Some(idx);
Ok(Some(idx))
} else {
Ok(None)
}
}
fn check_actor_text(&mut self, e: &BytesStart<'_>) -> Result<Option<usize>> {
let class_attr = match e.try_get_attribute("class")? {
Some(a) => a,
None => return Ok(None),
};
let class_val = class_attr.unescape_value()?;
if class_val.contains("actor") && class_val.contains("actor-box") {
Ok(self.last_actor_accent.take())
} else {
Ok(None)
}
}
}

View file

@ -0,0 +1,305 @@
//! Fixes various issues in merman's SVG output.
//!
//! Replaces hardcoded white backgrounds with the theme background:
//! ```xml
//! <!-- before --> <svg style="background-color: white">
//! <!-- after --> <svg style="background-color: #1e1e2e">
//! ```
//!
//! Removes `<rect>` elements with missing or invalid dimensions:
//! ```xml
//! <!-- before --> <rect width="NaN" height="10"/>
//! <!-- after --> (removed)
//! ```
//!
//! Replaces hardcoded text colors with the theme text color:
//! ```xml
//! <!-- before --> <text fill="#333">Hello</text>
//! <!-- after --> <text fill="#cdd6f4">Hello</text>
//! ```
use std::borrow::Cow;
use std::fmt::Write as _;
use anyhow::Result;
use quick_xml::events::{BytesStart, Event};
use crate::MermaidTheme;
struct ElementFixup<I> {
inner: I,
background_css: String,
text_color_css: String,
font_family_css: String,
svg_seen: bool,
skip_rect_depth: usize,
}
fn rewrite_attr<'a>(
e: &BytesStart<'_>,
attr_name: &[u8],
new_value: &str,
) -> Result<BytesStart<'a>> {
let name = e.name();
let tag = std::str::from_utf8(name.as_ref())?;
let mut new_elem = BytesStart::new(tag.to_owned());
for attr in e.attributes() {
let attr = attr?;
if attr.key.local_name().as_ref() == attr_name {
let local_name = attr.key.local_name();
let key = std::str::from_utf8(local_name.as_ref())?;
new_elem.push_attribute((key, new_value));
} else {
new_elem.push_attribute(attr);
}
}
Ok(new_elem)
}
fn rewrap<'a>(event: &Event<'_>, elem: BytesStart<'a>) -> Event<'a> {
match event {
Event::Start(_) => Event::Start(elem),
_ => Event::Empty(elem),
}
}
fn is_bad_rect(e: &BytesStart) -> Result<bool> {
for attr_name in ["width", "height"] {
match e.try_get_attribute(attr_name)? {
None => return Ok(true),
Some(attr) => {
let val = attr.unescape_value()?;
let trimmed = val.trim();
if trimmed.is_empty() {
return Ok(true);
}
if let Ok(n) = trimmed.parse::<f64>() {
if !n.is_finite() || n <= 0.0 {
return Ok(true);
}
}
}
}
}
Ok(false)
}
fn is_hardcoded_text_fill(val: &str) -> bool {
matches!(
val,
"" | "#333" | "black" | "#000" | "#000000" | "white" | "#fff" | "#ffffff"
)
}
fn push_font_style(style: &mut String, font_family: &str) {
write!(style, "font-family: {font_family};").expect("write to String cannot fail");
}
fn font_style(font_family: &str) -> String {
let mut style = String::with_capacity(font_family.len() + "font-family: ;".len());
push_font_style(&mut style, font_family);
style
}
fn rewrite_background_style<'a>(style: &'a str, background_css: &str) -> Cow<'a, str> {
const WHITE_BACKGROUND_STYLE: &str = "background-color: white";
let Some(background_start) = style.find(WHITE_BACKGROUND_STYLE) else {
return Cow::Borrowed(style);
};
let mut rewritten = String::with_capacity(
style
.len()
.saturating_sub(WHITE_BACKGROUND_STYLE.len())
.saturating_add("background-color: ".len())
.saturating_add(background_css.len()),
);
rewritten.push_str(&style[..background_start]);
write!(rewritten, "background-color: {background_css}").expect("write to String cannot fail");
rewritten.push_str(&style[background_start + WHITE_BACKGROUND_STYLE.len()..]);
Cow::Owned(rewritten)
}
fn font_family_declaration_value(declaration: &str) -> Option<&str> {
let (property, value) = declaration.split_once(':')?;
property
.trim()
.eq_ignore_ascii_case("font-family")
.then(|| value.trim())
}
fn rewrite_font_style<'a>(style: &'a str, font_family: &str) -> Cow<'a, str> {
let mut font_family_declaration_count = 0;
let mut has_target_font_family = false;
for declaration in style
.split(';')
.map(str::trim)
.filter(|declaration| !declaration.is_empty())
{
if let Some(value) = font_family_declaration_value(declaration) {
font_family_declaration_count += 1;
has_target_font_family = value == font_family;
}
}
if font_family_declaration_count == 1 && has_target_font_family {
return Cow::Borrowed(style);
}
let mut rewritten =
String::with_capacity(style.len() + font_family.len() + " font-family: ;".len());
for declaration in style.split(';') {
let declaration = declaration.trim();
if declaration.is_empty() || font_family_declaration_value(declaration).is_some() {
continue;
}
if !rewritten.is_empty() {
rewritten.push(' ');
}
rewritten.push_str(declaration);
rewritten.push(';');
}
if !rewritten.is_empty() {
rewritten.push(' ');
}
push_font_style(&mut rewritten, font_family);
Cow::Owned(rewritten)
}
impl<'a, I: Iterator<Item = Result<Event<'a>>>> ElementFixup<I> {
fn rewrite_svg_style(&self, e: &BytesStart<'_>) -> Result<Option<BytesStart<'a>>> {
let Some(style) = e
.try_get_attribute("style")?
.map(|a| a.unescape_value())
.transpose()?
else {
return Ok(None);
};
let new_style = rewrite_background_style(&style, &self.background_css);
if matches!(new_style, Cow::Borrowed(_)) {
return Ok(None);
}
Ok(Some(rewrite_attr(e, b"style", &new_style)?))
}
fn rewrite_text_element(&self, e: &BytesStart<'_>, fix_fill: bool) -> Result<BytesStart<'a>> {
let name = e.name();
let tag = std::str::from_utf8(name.as_ref())?;
let mut new_elem = BytesStart::new(tag.to_owned());
let mut has_font_family = false;
let mut has_style = false;
for attr in e.attributes() {
let attr = attr?;
match attr.key.local_name().as_ref() {
b"fill" if fix_fill => {
let val = attr.unescape_value()?;
if is_hardcoded_text_fill(&val) {
new_elem.push_attribute(("fill", self.text_color_css.as_str()));
} else {
new_elem.push_attribute(attr);
}
}
b"font-family" => {
has_font_family = true;
new_elem.push_attribute(("font-family", self.font_family_css.as_str()));
}
b"style" => {
has_style = true;
let style = attr.unescape_value()?;
let style = rewrite_font_style(&style, &self.font_family_css);
new_elem.push_attribute(("style", style.as_ref()));
}
_ => new_elem.push_attribute(attr),
}
}
if !has_font_family {
new_elem.push_attribute(("font-family", self.font_family_css.as_str()));
}
if !has_style {
let style = font_style(&self.font_family_css);
new_elem.push_attribute(("style", style.as_str()));
}
Ok(new_elem)
}
fn process_event(&mut self, event: Event<'a>) -> Result<Option<Event<'a>>> {
match &event {
Event::Start(e) | Event::Empty(e) if e.name().as_ref() == b"svg" && !self.svg_seen => {
self.svg_seen = true;
if let Some(new_elem) = self.rewrite_svg_style(e)? {
Ok(Some(rewrap(&event, new_elem)))
} else {
Ok(Some(event))
}
}
Event::Start(e) | Event::Empty(e) if e.name().as_ref() == b"rect" => {
if is_bad_rect(e)? {
if matches!(event, Event::Start(_)) {
self.skip_rect_depth = 1;
}
Ok(None)
} else {
Ok(Some(event))
}
}
Event::Start(e) | Event::Empty(e) if e.name().as_ref() == b"text" => {
Ok(Some(rewrap(&event, self.rewrite_text_element(e, true)?)))
}
Event::Start(e) | Event::Empty(e) if e.name().as_ref() == b"tspan" => {
Ok(Some(rewrap(&event, self.rewrite_text_element(e, false)?)))
}
_ => Ok(Some(event)),
}
}
}
impl<'a, I: Iterator<Item = Result<Event<'a>>>> Iterator for ElementFixup<I> {
type Item = Result<Event<'a>>;
fn next(&mut self) -> Option<Self::Item> {
loop {
let event = match self.inner.next()? {
Ok(ev) => ev,
Err(e) => return Some(Err(e)),
};
if self.skip_rect_depth > 0 {
match &event {
Event::Start(_) => self.skip_rect_depth += 1,
Event::End(_) => self.skip_rect_depth -= 1,
_ => {}
}
continue;
}
match self.process_event(event) {
Ok(Some(ev)) => return Some(Ok(ev)),
Ok(None) => continue,
Err(e) => return Some(Err(e)),
}
}
}
}
pub(super) fn process<'a>(
events: impl Iterator<Item = Result<Event<'a>>>,
theme: &MermaidTheme,
) -> impl Iterator<Item = Result<Event<'a>>> {
ElementFixup {
inner: events,
background_css: crate::css_color(theme.background),
text_color_css: crate::css_color(theme.text_color),
font_family_css: theme.font_family.clone(),
svg_seen: false,
skip_rect_depth: 0,
}
}

View file

@ -0,0 +1,223 @@
//! Fixes double-escaped HTML entities inside fallback `<text>` groups that
//! were generated as replacements for `<foreignObject>` content.
//!
//! ```xml
//! <!-- before -->
//! <g data-merman-foreignobject="fallback">
//! <text>List&amp;lt;T&amp;gt;</text>
//! </g>
//!
//! <!-- after -->
//! <g data-merman-foreignobject="fallback">
//! <text>List&lt;T&gt;</text>
//! </g>
//! ```
use std::collections::VecDeque;
use anyhow::Result;
use quick_xml::events::{BytesStart, BytesText, Event};
use crate::MermaidTheme;
struct FallbackFixup<'a, I> {
inner: I,
edge_label_bg: String,
fallback_depth: usize,
text_buffer: String,
output_queue: VecDeque<Event<'a>>,
}
impl<'a, I: Iterator<Item = Result<Event<'a>>>> Iterator for FallbackFixup<'a, I> {
type Item = Result<Event<'a>>;
fn next(&mut self) -> Option<Self::Item> {
if let Some(event) = self.output_queue.pop_front() {
return Some(Ok(event));
}
loop {
let event = match self.inner.next()? {
Ok(ev) => ev,
Err(e) => return Some(Err(e)),
};
match &event {
Event::Start(e) if e.name().as_ref() == b"g" => {
if self.fallback_depth > 0 {
self.fallback_depth += 1;
} else {
match e.try_get_attribute("data-merman-foreignobject") {
Ok(Some(attr)) if attr.value.as_ref() == b"fallback" => {
self.fallback_depth = 1;
}
Err(e) => return Some(Err(e.into())),
_ => {}
}
}
}
Event::End(e) if e.name().as_ref() == b"g" && self.fallback_depth > 0 => {
self.flush_text_buffer();
self.fallback_depth -= 1;
}
_ => {}
}
if self.fallback_depth == 0 {
return Some(Ok(event));
}
// Inside fallback group: accumulate text-like events, process others
match &event {
Event::Text(t) => {
match std::str::from_utf8(t.as_ref()) {
Ok(raw) => self.text_buffer.push_str(raw),
Err(e) => eprintln!("Invalid UTF-8 in fallback text: {e}"),
}
continue;
}
Event::GeneralRef(r) => {
self.text_buffer.push('&');
match std::str::from_utf8(r.as_ref()) {
Ok(name) => self.text_buffer.push_str(name),
Err(e) => eprintln!("Invalid UTF-8 in fallback entity ref: {e}"),
}
self.text_buffer.push(';');
continue;
}
_ => {}
}
self.flush_text_buffer();
match self.process_non_text_event(event) {
Ok(ev) => self.output_queue.push_back(ev),
Err(e) => return Some(Err(e)),
}
if let Some(event) = self.output_queue.pop_front() {
return Some(Ok(event));
}
}
}
}
impl<'a, I> FallbackFixup<'a, I> {
fn flush_text_buffer(&mut self) {
if self.text_buffer.is_empty() {
return;
}
let text = if self.text_buffer.contains("&amp;lt;") || self.text_buffer.contains("&amp;gt;")
{
let fixed = self
.text_buffer
.replace("&amp;lt;", "&lt;")
.replace("&amp;gt;", "&gt;");
self.text_buffer.clear();
fixed
} else {
std::mem::take(&mut self.text_buffer)
};
self.output_queue
.push_back(Event::Text(BytesText::from_escaped(text)));
}
fn process_non_text_event(&self, event: Event<'a>) -> Result<Event<'a>> {
let is_start = matches!(event, Event::Start(_));
match &event {
Event::Start(e) | Event::Empty(e) if e.name().as_ref() == b"rect" => {
let mut new_elem = BytesStart::new("rect");
for attr in e.attributes() {
let attr = attr?;
if attr.key.local_name().as_ref() == b"fill" {
new_elem.push_attribute(("fill", self.edge_label_bg.as_str()));
} else {
new_elem.push_attribute(attr);
}
}
Ok(if is_start {
Event::Start(new_elem)
} else {
Event::Empty(new_elem)
})
}
_ => Ok(event),
}
}
}
pub(super) fn process<'a>(
events: impl Iterator<Item = Result<Event<'a>>>,
theme: &MermaidTheme,
) -> impl Iterator<Item = Result<Event<'a>>> {
let edge_label_bg = crate::css_color(theme.edge_label_background);
FallbackFixup {
inner: events,
edge_label_bg,
fallback_depth: 0,
text_buffer: String::new(),
output_queue: VecDeque::new(),
}
}
#[cfg(test)]
mod tests {
use super::*;
use quick_xml::Reader;
fn run_fixup(svg: &str) -> String {
let reader = Reader::from_str(svg);
let events = std::iter::from_fn({
let mut reader = reader;
let mut done = false;
move || {
if done {
return None;
}
match reader.read_event() {
Ok(quick_xml::events::Event::Eof) => {
done = true;
None
}
Ok(ev) => Some(Ok(ev)),
Err(e) => {
done = true;
Some(Err(e.into()))
}
}
}
});
let theme = crate::MermaidTheme::default();
let fixed = process(events, &theme);
let mut writer = quick_xml::Writer::new(Vec::new());
for ev in fixed {
writer.write_event(ev.unwrap()).unwrap();
}
String::from_utf8(writer.into_inner()).unwrap()
}
#[test]
fn fixes_double_escaped_entities_in_fallback() {
let svg = r##"<g data-merman-foreignobject="fallback"><text fill="#333">-List&amp;lt;Animal&amp;gt; animals</text></g>"##;
let result = run_fixup(svg);
assert!(
!result.contains("&amp;lt;"),
"Should fix double-escaped entities, got: {result}"
);
assert!(
result.contains("&lt;"),
"Should contain single-escaped entity, got: {result}"
);
}
#[test]
fn preserves_text_outside_fallback_group() {
let svg = r##"<text>-List&amp;lt;Animal&amp;gt;</text>"##;
let result = run_fixup(svg);
assert!(
result.contains("&amp;lt;"),
"Should not fix entities outside fallback group, got: {result}"
);
}
}

View file

@ -0,0 +1,96 @@
//! Converts literal `\n` escape sequences inside `<foreignObject>` elements
//! into `<br/>` tags so that line breaks render correctly.
//!
//! ```xml
//! <!-- before -->
//! <foreignObject>Hello\nWorld</foreignObject>
//!
//! <!-- after -->
//! <foreignObject>Hello<br/>World</foreignObject>
//! ```
use anyhow::{Context as _, Result};
use quick_xml::escape;
use quick_xml::events::{BytesStart, BytesText, Event};
use quick_xml::{Reader, Writer};
pub(super) fn process(svg: &str) -> Result<String> {
let mut reader = Reader::from_str(svg);
reader.config_mut().check_end_names = false;
let mut writer = Writer::new(Vec::with_capacity(svg.len()));
let mut foreign_object_depth: usize = 0;
let mut buffer = Vec::new();
loop {
let event = match reader.read_event() {
Ok(Event::Eof) => break,
Ok(event) => event,
Err(e) => return Err(e).context("failed to parse SVG in foreignObject wrap pass"),
};
let is_fo_start =
matches!(&event, Event::Start(e) if e.name().as_ref() == b"foreignObject");
let is_fo_end = matches!(&event, Event::End(e) if e.name().as_ref() == b"foreignObject");
if is_fo_start {
if foreign_object_depth == 0 {
buffer.clear();
}
buffer.push(event);
foreign_object_depth += 1;
} else if is_fo_end {
foreign_object_depth = foreign_object_depth.saturating_sub(1);
buffer.push(event);
if foreign_object_depth == 0 {
emit_buffered(std::mem::take(&mut buffer), &mut writer)?;
}
} else if foreign_object_depth > 0 {
buffer.push(event);
} else {
writer.write_event(event)?;
}
}
String::from_utf8(writer.into_inner()).context("SVG output is not valid UTF-8")
}
fn emit_buffered(buffer: Vec<Event<'_>>, writer: &mut Writer<Vec<u8>>) -> Result<()> {
for event in buffer {
match event {
Event::Text(t) => {
let processed = {
let decoded = t.decode().unwrap_or_default();
let text = escape::unescape(&decoded).unwrap_or_else(|_| decoded.clone());
emit_text_content(&text, writer)?
};
if !processed {
writer.write_event(Event::Text(t))?;
}
}
other => {
writer.write_event(other)?;
}
}
}
Ok(())
}
fn emit_text_content(text: &str, writer: &mut Writer<Vec<u8>>) -> Result<bool> {
if !text.contains("\\n") {
return Ok(false);
}
let mut first_segment = true;
for segment in text.split("\\n") {
if !first_segment {
writer.write_event(Event::Empty(BytesStart::new("br")))?;
}
first_segment = false;
writer.write_event(Event::Text(BytesText::from_escaped(escape::escape(
segment,
))))?;
}
Ok(true)
}

View file

@ -0,0 +1,517 @@
//! Builds a theme-aware CSS stylesheet and appends it into the SVG's `<style>`
//! element. All selectors are scoped to the SVG's `id` to prevent leaking.
//!
//! ```xml
//! <!-- before -->
//! <style>.node rect { fill: white; }</style>
//!
//! <!-- after -->
//! <style>.node rect { fill: white; }
//! #mermaid-1 .node rect { fill: #89b4fa !important; }
//! /* ... theme rules ... */
//! </style>
//! ```
use std::collections::VecDeque;
use std::fmt::Write;
use anyhow::Result;
use quick_xml::events::{BytesText, Event};
use crate::MermaidTheme;
/// Morally equivalent to `format!(".section-{i}")`, but without allocating
const MINDMAP_SECTION_SELECTORS: [&str; 11] = [
".section-0",
".section-1",
".section-2",
".section-3",
".section-4",
".section-5",
".section-6",
".section-7",
".section-8",
".section-9",
".section-10",
];
struct InjectCss<'a, I> {
inner: I,
injected_css: String,
in_style: bool,
injected: bool,
pending: VecDeque<Event<'a>>,
}
impl<'a, I: Iterator<Item = Result<Event<'a>>>> Iterator for InjectCss<'a, I> {
type Item = Result<Event<'a>>;
fn next(&mut self) -> Option<Self::Item> {
if let Some(event) = self.pending.pop_front() {
return Some(Ok(event));
}
let event = match self.inner.next()? {
Ok(ev) => ev,
Err(e) => return Some(Err(e)),
};
match &event {
Event::Start(e) if e.name().as_ref() == b"style" => {
self.in_style = true;
return Some(Ok(event));
}
Event::End(e) if e.name().as_ref() == b"style" => {
self.in_style = false;
if !self.injected {
self.injected = true;
self.pending
.push_back(Event::Text(BytesText::from_escaped(std::mem::take(
&mut self.injected_css,
))));
self.pending.push_back(event);
return self.pending.pop_front().map(Ok);
}
return Some(Ok(event));
}
Event::Text(text) if self.in_style => {
self.injected = true;
let existing = match std::str::from_utf8(text.as_ref()) {
Ok(s) => s,
Err(e) => return Some(Err(e.into())),
};
let mut combined = String::with_capacity(existing.len() + self.injected_css.len());
combined.push_str(existing);
combined.push_str(&self.injected_css);
return Some(Ok(Event::Text(BytesText::from_escaped(combined))));
}
_ => {}
}
Some(Ok(event))
}
}
pub(super) fn process<'a>(
events: impl Iterator<Item = Result<Event<'a>>>,
theme: &MermaidTheme,
svg_id: &str,
) -> impl Iterator<Item = Result<Event<'a>>> {
let injected_css = build_injected_css(theme, svg_id);
InjectCss {
inner: events,
injected_css,
in_style: false,
injected: false,
pending: VecDeque::new(),
}
}
fn mindmap_section_css(theme: &MermaidTheme) -> String {
let colors: Vec<String> = theme
.git_branch_colors
.iter()
.map(|c| crate::css_color(*c))
.collect();
let fills: Vec<String> = theme
.git_branch_colors
.iter()
.map(|c| {
crate::css_color(blend_over_background(
*c,
theme.background,
ACCENT_FILL_OPACITY,
))
})
.collect();
let text = crate::css_color(theme.text_color);
let mut css = String::with_capacity(5_400);
let emit = |css: &mut String, selector: &str, color: &str, fill: &str, txt: &str| {
let section_index = selector
.trim_start_matches(".section-root.section-")
.trim_start_matches(".section-");
write!(
css,
"{selector} rect, {selector} path, {selector} circle, {selector} polygon \
{{ fill: {fill} !important; stroke: {color} !important; }}\n\
{selector} text, {selector} span, \
text{selector}, tspan{selector} \
{{ fill: {txt} !important; color: {txt} !important; }}\n\
{selector} foreignObject div, {selector} foreignObject span, {selector} foreignObject p \
{{ color: {txt} !important; }}\n\
.section-edge{section_index} {{ stroke: {color} !important; }}\n",
)
.expect("write to String cannot fail");
};
emit(
&mut css,
".section-root.section--1",
&colors[0],
&fills[0],
&text,
);
emit(&mut css, ".section--1", &colors[1], &fills[1], &text);
for (i, selector) in MINDMAP_SECTION_SELECTORS.iter().enumerate() {
let ci = 2 + (i % 6);
emit(&mut css, selector, &colors[ci], &fills[ci], &text);
}
css
}
fn git_branch_css(theme: &MermaidTheme) -> String {
let text = crate::css_color(theme.text_color);
let mut css = String::with_capacity(8 * 200);
for i in 0..8 {
let c = crate::css_color(theme.git_branch_colors[i]);
let label_fill = crate::css_color(blend_over_background(
theme.git_branch_colors[i],
theme.background,
ACCENT_FILL_OPACITY,
));
write!(
css,
".commit{i} {{ stroke: {c}; fill: {c}; }}\n\
.arrow{i} {{ stroke: {c}; }}\n\
.label{i} {{ fill: {label_fill}; stroke: {c}; }}\n\
.branch-label{i} {{ fill: {text}; }}\n"
)
.expect("write to String cannot fail");
}
css
}
fn adjust_lightness(color: &mut gpui::Hsla, dark_mode: bool) {
if dark_mode {
color.l = (color.l * 0.7).max(0.0);
} else {
color.l = (color.l * 1.3).min(1.0);
}
}
const ACCENT_FILL_OPACITY: f32 = 0.15;
fn blend_over_background(
foreground: gpui::Hsla,
background: gpui::Hsla,
opacity: f32,
) -> gpui::Hsla {
let fg = gpui::Rgba::from(foreground);
let bg = gpui::Rgba::from(background);
let blended = gpui::Rgba {
r: fg.r * opacity + bg.r * (1.0 - opacity),
g: fg.g * opacity + bg.g * (1.0 - opacity),
b: fg.b * opacity + bg.b * (1.0 - opacity),
a: 1.0,
};
gpui::Hsla::from(blended)
}
fn accent_css(theme: &MermaidTheme) -> String {
let mut css = String::with_capacity(theme.accent_colors.len() * 420);
let text = crate::css_color(theme.text_color);
for (i, accent) in theme.accent_colors.iter().enumerate() {
let stroke = crate::css_color(accent.foreground);
let fill = crate::css_color(blend_over_background(
accent.background,
theme.background,
ACCENT_FILL_OPACITY,
));
let class = format!(".zed-accent-{i}");
write!(
css,
"{class} rect, {class} path, {class} circle, {class} polygon, {class} ellipse, \
rect{class}, path{class}, circle{class}, polygon{class}, ellipse{class} \
{{ fill: {fill} !important; stroke: {stroke} !important; }}\n\
{class} text, {class} tspan, text{class}, tspan{class} \
{{ fill: {text} !important; }}\n",
)
.expect("write to String cannot fail");
}
css
}
fn chart_color_css(theme: &MermaidTheme) -> String {
// Each block is around 230 bytes, add some headroom
let mut css = String::with_capacity(8 * 250);
for i in 0..8 {
let color = crate::css_color(theme.git_branch_colors[i]);
let class = format!(".zed-chart-{i}");
write!(
css,
"path.pieCircle{class} {{ fill: {color} !important; }}\n\
.plot rect{class}, .legend rect{class} {{ fill: {color} !important; stroke: {color} !important; }}\n\
.plot path{class} {{ stroke: {color} !important; }}\n"
)
.expect("write to String cannot fail");
}
css
}
fn timeline_css(theme: &MermaidTheme) -> String {
let mut css = String::with_capacity(8 * 300);
let text = crate::css_color(theme.text_color);
for i in 0..8 {
let c = crate::css_color(theme.git_branch_colors[i]);
let fill = crate::css_color(blend_over_background(
theme.git_branch_colors[i],
theme.background,
ACCENT_FILL_OPACITY,
));
write!(
css,
"rect.task-type-{i}, rect.section-type-{i} {{ fill: {fill} !important; stroke: {c} !important; }}\n"
).expect("write to String cannot fail");
}
for i in 0..4 {
let c = crate::css_color(theme.git_branch_colors[i % 8]);
let fill = crate::css_color(blend_over_background(
theme.git_branch_colors[i % 8],
theme.background,
ACCENT_FILL_OPACITY,
));
write!(
css,
".section{i} {{ fill: {fill} !important; }}\n\
.task{i} {{ fill: {fill} !important; stroke: {c} !important; }}\n\
.taskText{i} {{ fill: {text} !important; }}\n\
.taskTextOutside{i} {{ fill: {text} !important; }}\n"
)
.expect("write to String cannot fail");
}
css
}
fn should_scope_css_line(trimmed: &str) -> bool {
!trimmed.is_empty()
&& (trimmed.starts_with('.')
|| trimmed.starts_with("foreignObject")
|| trimmed.starts_with("g.")
|| trimmed.starts_with("text")
|| trimmed.starts_with("tspan")
|| trimmed.starts_with("rect.")
|| trimmed.starts_with("path.")
|| trimmed.starts_with("defs")
|| trimmed.starts_with('#'))
}
fn scoped_selector_count(raw_css: &str) -> usize {
raw_css.lines().fold(0, |count, line| {
let trimmed = line.trim();
if !should_scope_css_line(trimmed) {
return count;
}
let Some((selectors, _)) = trimmed.split_once('{') else {
return count;
};
count.saturating_add(selectors.split(',').count())
})
}
fn scope_css(raw_css: &str, svg_id: &str) -> String {
let scoped_selector_prefix_len = svg_id.len().saturating_add(2);
let result_capacity = raw_css
.len()
.saturating_add(scoped_selector_count(raw_css).saturating_mul(scoped_selector_prefix_len));
let mut result = String::with_capacity(result_capacity);
for line in raw_css.lines() {
let trimmed = line.trim();
if should_scope_css_line(trimmed) {
if let Some(brace) = trimmed.find('{') {
let (selectors, rest) = trimmed.split_at(brace);
let mut first = true;
for selector in selectors.split(',') {
if !first {
result.push_str(", ");
}
first = false;
write!(result, "#{svg_id} {}", selector.trim())
.expect("write to String cannot fail");
}
writeln!(result, "{rest}").expect("write to String cannot fail");
continue;
}
}
writeln!(result, "{line}").expect("write to String cannot fail");
}
result
}
fn build_injected_css(theme: &MermaidTheme, svg_id: &str) -> String {
let font = &theme.font_family;
let text = crate::css_color(theme.text_color);
let line = crate::css_color(theme.line_color);
let primary = crate::css_color(theme.primary_color);
let border = crate::css_color(theme.primary_border_color);
let secondary = crate::css_color(theme.secondary_color);
let tertiary = crate::css_color(theme.tertiary_color);
let background = crate::css_color(theme.background);
let edge_label_bg = crate::css_color(theme.edge_label_background);
let actor_bg = crate::css_color(theme.actor_background);
let actor_border = crate::css_color(theme.actor_border);
let error_bg = {
let mut c = theme.error_color;
adjust_lightness(&mut c, theme.dark_mode);
c
};
let error = crate::css_color(error_bg);
let error_text = crate::css_color(crate::postprocess::util::text_color_for_background(
error_bg,
));
let warning_bg = {
let mut c = theme.warning_color;
adjust_lightness(&mut c, theme.dark_mode);
c
};
let warning = crate::css_color(warning_bg);
let warning_text = crate::css_color(crate::postprocess::util::text_color_for_background(
warning_bg,
));
let note_bg = crate::css_color(theme.note_background);
let note_border = crate::css_color(theme.note_border);
let er_odd = crate::css_color(theme.er_attr_bg_odd);
let er_even = crate::css_color(theme.er_attr_bg_even);
let actor_text = &text;
let note_text = &text;
let raw_css = format!(
r#"
text, tspan, foreignObject div, foreignObject span, foreignObject p {{ font-family: {font} !important; }}
foreignObject div, foreignObject span, foreignObject p {{ font-size: 16px; color: {text}; }}
foreignObject p {{ margin: 0; }}
foreignObject {{ overflow: visible; }}
foreignObject div {{ max-width: none !important; }}
.label-group foreignObject {{ font-weight: bold; }}
.node rect, .node path {{ fill: {primary}; stroke: {border}; }}
.node polygon {{ fill: {primary}; stroke: {border}; }}
.label-container path {{ fill: {primary}; stroke: {border}; }}
{mindmap_css}
.mindmap-node line, .timeline-node line {{ stroke: transparent !important; }}
g.stateGroup rect {{ fill: {primary} !important; stroke: {border} !important; }}
g.stateGroup text {{ fill: {text} !important; }}
g.stateGroup .state-title {{ fill: {text} !important; }}
.stateGroup .composit {{ fill: {background} !important; }}
.stateGroup .alt-composit {{ fill: {tertiary} !important; }}
.state-note {{ stroke: {note_border} !important; fill: {note_bg} !important; }}
.state-note text {{ fill: {note_text} !important; }}
.stateLabel .box {{ fill: {primary} !important; }}
.stateLabel text {{ fill: {text} !important; }}
.node circle.state-start {{ fill: {line} !important; stroke: {line} !important; }}
.node .fork-join {{ fill: {line} !important; stroke: {line} !important; }}
.node circle.state-end {{ fill: {border} !important; stroke: {background} !important; }}
.end-state-inner {{ fill: {background} !important; }}
.statediagram-cluster rect {{ fill: {primary} !important; stroke: {border} !important; }}
.statediagram-cluster.statediagram-cluster .inner {{ fill: {background} !important; }}
.statediagram-cluster.statediagram-cluster-alt .inner {{ fill: {tertiary} !important; }}
.statediagram-state rect.divider {{ fill: {tertiary} !important; }}
.statediagram-note rect {{ fill: {note_bg} !important; stroke: {note_border} !important; }}
.statediagram-note text {{ fill: {note_text} !important; }}
.statediagramTitleText {{ fill: {text} !important; }}
.transition {{ stroke: {line} !important; }}
.cluster-label, .nodeLabel {{ color: {text} !important; }}
defs #statediagram-barbEnd {{ fill: {line} !important; stroke: {line} !important; }}
#statediagram-barbEnd {{ fill: {line} !important; }}
.edgeLabel .label rect {{ fill: {primary} !important; }}
.edgeLabel rect {{ fill: {primary} !important; background-color: {primary} !important; }}
.edgeLabel .label text {{ fill: {text} !important; }}
.edgeLabel p {{ background-color: {primary} !important; }}
.edgeLabel {{ background-color: {primary} !important; }}
.actor {{ stroke: {actor_border}; fill: {actor_bg}; }}
text.actor {{ text-anchor: middle; }}
text.actor>tspan {{ fill: {actor_text} !important; stroke: none; }}
.labelText, .labelText>tspan {{ fill: {actor_text} !important; }}
.actor-line {{ stroke: {actor_border} !important; }}
.messageLine0 {{ stroke: {text} !important; }}
.messageLine1 {{ stroke: {text} !important; }}
#arrowhead path {{ fill: {text} !important; stroke: {text} !important; }}
#crosshead path {{ fill: {text} !important; stroke: {text} !important; }}
.messageText {{ fill: {text} !important; }}
.loopText, .loopText>tspan {{ fill: {text} !important; }}
.loopLine {{ stroke: {actor_border} !important; fill: {actor_border} !important; }}
.note {{ stroke: {note_border} !important; fill: {note_bg} !important; }}
.noteText, .noteText>tspan {{ fill: {note_text} !important; }}
.activation0, .activation1, .activation2 {{ fill: {secondary} !important; stroke: {border} !important; }}
.labelBox {{ stroke: {actor_border} !important; fill: {actor_bg} !important; }}
.actor-man line {{ stroke: {actor_border} !important; fill: {actor_bg} !important; }}
.actor-man circle {{ stroke: {actor_border} !important; fill: {actor_bg} !important; }}
.pieTitleText {{ fill: {text} !important; }}
.slice {{ fill: {text} !important; }}
.legend text {{ fill: {text} !important; }}
.pieOuterCircle {{ stroke: {border} !important; }}
.pieCircle {{ stroke: {border} !important; }}
{timeline_css}
text.journey-section, text.task {{ fill: {text} !important; }}
.relationshipLabelBox {{ fill: {tertiary} !important; opacity: 0.7; background-color: {tertiary} !important; }}
.labelBkg {{ background-color: {tertiary} !important; }}
.edgeLabel .label {{ fill: {border} !important; }}
.label {{ color: {text} !important; }}
.relationshipLine {{ stroke: {line} !important; fill: none !important; }}
.entityBox {{ fill: {primary}; stroke: {border}; }}
.node .row-rect-odd path {{ fill: {er_odd} !important; }}
.node .row-rect-even path {{ fill: {er_even} !important; }}
.edge-thickness-normal {{ stroke-width: 1px; }}
.relation {{ stroke: {line}; stroke-width: 1; fill: none; }}
.edgePaths path {{ fill: none; }}
.marker {{ fill: {line} !important; stroke: {line} !important; }}
.marker.er {{ fill: none !important; stroke: {line} !important; }}
.composition {{ fill: {line} !important; stroke: {line} !important; stroke-width: 1; }}
.extension {{ fill: transparent !important; stroke: {line} !important; stroke-width: 1; }}
.aggregation {{ fill: transparent !important; stroke: {line} !important; stroke-width: 1; }}
.dependency {{ fill: {line} !important; stroke: {line} !important; stroke-width: 1; }}
.lollipop {{ fill: {primary} !important; stroke: {line} !important; stroke-width: 1; }}
.sectionTitle0, .sectionTitle1, .sectionTitle2, .sectionTitle3 {{ fill: {text} !important; }}
.sectionTitle {{ font-family: {font} !important; }}
.taskTextOutsideRight {{ fill: {text} !important; font-family: {font} !important; }}
.taskTextOutsideLeft {{ fill: {text} !important; }}
.active0, .active1, .active2, .active3 {{ fill: {secondary} !important; stroke: {border} !important; }}
.activeText0, .activeText1, .activeText2, .activeText3 {{ fill: {text} !important; }}
.done0, .done1, .done2, .done3 {{ stroke: {border} !important; fill: {secondary} !important; stroke-width: 2; }}
.doneText0, .doneText1, .doneText2, .doneText3 {{ fill: {text} !important; }}
.crit0, .crit1, .crit2, .crit3 {{ fill: {error} !important; stroke: {error} !important; }}
.critText0, .critText1, .critText2, .critText3 {{ fill: {error_text} !important; }}
.activeCrit0, .activeCrit1, .activeCrit2, .activeCrit3 {{ fill: {warning} !important; stroke: {warning} !important; }}
.activeCritText0, .activeCritText1, .activeCritText2, .activeCritText3 {{ fill: {warning_text} !important; }}
.doneCrit0, .doneCrit1, .doneCrit2, .doneCrit3 {{ fill: {error} !important; stroke: {border} !important; stroke-width: 2; }}
.doneCritText0, .doneCritText1, .doneCritText2, .doneCritText3 {{ fill: {error_text} !important; }}
.titleText {{ fill: {text} !important; font-family: {font} !important; }}
.grid .tick text {{ fill: {text} !important; font-family: {font} !important; }}
.grid .tick {{ stroke: {border} !important; }}
{git_branch_css}
.commit-merge {{ stroke: {primary}; fill: {primary}; }}
.commit-reverse {{ stroke: {primary}; fill: {primary}; stroke-width: 3; }}
.commit-highlight-inner {{ stroke: {primary}; fill: {primary}; }}
.tag-label {{ font-size: 10px; }}
.tag-label-bkg {{ fill: {primary}; stroke: {border}; }}
.tag-hole {{ fill: {line}; }}
.commit-label {{ fill: {text}; }}
.commit-label-bkg {{ fill: {edge_label_bg}; }}
.commit-id, .commit-msg, .branch-label {{ fill: {text}; color: {text}; font-family: {font}; }}
{accent_css}
.data-point text {{ fill: {text} !important; }}
{chart_color_css}
"#,
mindmap_css = mindmap_section_css(theme),
git_branch_css = git_branch_css(theme),
accent_css = accent_css(theme),
chart_color_css = chart_color_css(theme),
timeline_css = timeline_css(theme),
);
scope_css(&raw_css, svg_id)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn scope_css_prefixes_selectors() {
let input = " .foo { color: red; }\n";
let result = scope_css(input, "my-svg");
assert!(result.contains("#my-svg .foo"), "got: {result}");
}
}

View file

@ -0,0 +1,114 @@
//! Strips `<foreignObject>` elements and their contents from the SVG, since
//! `usvg`/`resvg` does not support them.
//!
//! ```xml
//! <!-- before -->
//! <foreignObject><div>Hello</div></foreignObject>
//! <text class="nodeLabel">Hello</text>
//!
//! <!-- after -->
//! <text class="nodeLabel">Hello</text>
//! ```
use anyhow::Result;
use quick_xml::events::Event;
struct StripForeignObject<I> {
inner: I,
/// Depth inside a `<foreignObject>` element being stripped.
foreign_depth: usize,
/// Depth inside a `<g data-merman-foreignobject="fallback">` being stripped.
fallback_depth: usize,
/// Set to true once we see a `<text>` element outside of foreignObjects
/// and fallback groups. When true, fallback groups are redundant and
/// should be stripped.
has_native_text: bool,
}
impl<'a, I: Iterator<Item = Result<Event<'a>>>> Iterator for StripForeignObject<I> {
type Item = Result<Event<'a>>;
fn next(&mut self) -> Option<Self::Item> {
loop {
let event = self.inner.next()?;
let event = match event {
Ok(event) => event,
Err(e) => return Some(Err(e)),
};
// Strip foreignObject elements and their contents.
match &event {
Event::Start(e) if e.name().as_ref() == b"foreignObject" => {
self.foreign_depth += 1;
continue;
}
Event::Start(_) if self.foreign_depth > 0 => {
self.foreign_depth += 1;
continue;
}
Event::End(_) if self.foreign_depth > 0 => {
self.foreign_depth -= 1;
continue;
}
Event::Empty(e) if e.name().as_ref() == b"foreignObject" => {
continue;
}
_ if self.foreign_depth > 0 => {
continue;
}
_ => {}
}
// Strip fallback groups when native text exists.
match &event {
Event::Start(e) if e.name().as_ref() == b"g" && self.fallback_depth == 0 => {
if self.has_native_text {
if let Ok(Some(attr)) = e.try_get_attribute("data-merman-foreignobject") {
if attr.value.as_ref() == b"fallback" {
self.fallback_depth = 1;
continue;
}
}
}
}
Event::Start(_) if self.fallback_depth > 0 => {
self.fallback_depth += 1;
continue;
}
Event::End(_) if self.fallback_depth > 0 => {
self.fallback_depth -= 1;
continue;
}
_ if self.fallback_depth > 0 => {
continue;
}
_ => {}
}
// Track whether the diagram has native <text> elements.
if !self.has_native_text {
match &event {
Event::Start(e) | Event::Empty(e) if e.name().as_ref() == b"text" => {
if e.try_get_attribute("class").ok().flatten().is_some() {
self.has_native_text = true;
}
}
_ => {}
}
}
return Some(Ok(event));
}
}
}
pub(super) fn process<'a>(
inner: impl Iterator<Item = Result<Event<'a>>>,
) -> impl Iterator<Item = Result<Event<'a>>> {
StripForeignObject {
inner,
foreign_depth: 0,
fallback_depth: 0,
has_native_text: false,
}
}

View file

@ -0,0 +1,161 @@
//! Removes CSS constructs that `usvg`/`resvg` cannot handle.
//!
//! - `@keyframes` and `@-webkit-keyframes` blocks
//! - `:root { ... }` blocks (CSS custom properties)
//! - `:not(...)` pseudo-selectors
//! - `deg` angle units (e.g. `rotate(45deg)` → `rotate(45)`)
//!
//! Also removes `!important` declarations (so that our injected theme CSS
//! always wins).
use std::borrow::Cow;
use anyhow::Result;
use quick_xml::events::{BytesText, Event};
struct StripInvalidCss<I> {
inner: I,
in_style: bool,
}
impl<'a, I: Iterator<Item = Result<Event<'a>>>> Iterator for StripInvalidCss<I> {
type Item = Result<Event<'a>>;
fn next(&mut self) -> Option<Self::Item> {
let event = match self.inner.next()? {
Ok(ev) => ev,
Err(e) => return Some(Err(e)),
};
match &event {
Event::Start(e) if e.name().as_ref() == b"style" => {
self.in_style = true;
}
Event::End(e) if e.name().as_ref() == b"style" => {
self.in_style = false;
}
Event::Text(text) if self.in_style => {
let css_text = match std::str::from_utf8(text.as_ref()) {
Ok(s) => s,
Err(e) => return Some(Err(e.into())),
};
return Some(match strip_unsupported_css(css_text) {
Cow::Borrowed(_) => Ok(event),
Cow::Owned(processed) => Ok(Event::Text(BytesText::from_escaped(processed))),
});
}
_ => {}
}
Some(Ok(event))
}
}
pub(super) fn process<'a>(
events: impl Iterator<Item = Result<Event<'a>>>,
) -> impl Iterator<Item = Result<Event<'a>>> {
StripInvalidCss {
inner: events,
in_style: false,
}
}
fn strip_unsupported_css(css: &str) -> Cow<'_, str> {
let mut chars = css.char_indices().peekable();
let mut result = None;
let mut copied_until = 0;
while let Some((i, _)) = chars.next() {
let remaining = &css[i..];
if remaining.starts_with("@keyframes")
|| remaining.starts_with("@-webkit-keyframes")
|| remaining.starts_with(":root")
{
let result = result.get_or_insert_with(|| String::with_capacity(css.len()));
result.push_str(&css[copied_until..i]);
skip_css_block(&mut chars);
copied_until = chars.peek().map_or(css.len(), |&(i, _)| i);
}
}
let mut result = if let Some(mut result) = result {
result.push_str(&css[copied_until..]);
Cow::Owned(result)
} else {
Cow::Borrowed(css)
};
strip_css_angle_units(&mut result);
strip_css_important(&mut result);
result
}
fn skip_css_block(chars: &mut std::iter::Peekable<std::str::CharIndices>) {
for (_, c) in chars.by_ref() {
if c == '{' {
break;
}
}
let mut depth = 1u32;
for (_, c) in chars.by_ref() {
match c {
'{' => depth += 1,
'}' => {
depth -= 1;
if depth == 0 {
return;
}
}
_ => {}
}
}
}
fn replace_all_in_place(css: &mut Cow<'_, str>, needle: &str, replacement: &str) {
while let Some(pos) = css.as_ref().find(needle) {
css.to_mut()
.replace_range(pos..pos + needle.len(), replacement);
}
}
fn strip_css_angle_units(css: &mut Cow<'_, str>) {
replace_all_in_place(css, "deg)", ")");
}
/// Strip `!important` from mermaid's generated CSS so that our injected
/// theme CSS (which uses `!important`) always takes priority. This works
/// around a usvg cascade bug where competing `!important` rules are
/// resolved by first-wins rather than the CSS spec's last-wins.
fn strip_css_important(css: &mut Cow<'_, str>) {
replace_all_in_place(css, "!important", "");
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn strips_keyframes() {
let input = "@keyframes bounce { 0% { transform: scale(1); } 100% { transform: scale(1.1); } } .node rect { fill: red; }";
let result = strip_unsupported_css(input);
assert!(!result.contains("@keyframes"), "got: {result}");
assert!(result.contains(".node rect"), "got: {result}");
}
#[test]
fn strips_root_blocks() {
let input = ":root { --bg: white; } .foo { color: red; }";
let result = strip_unsupported_css(input);
assert!(!result.contains(":root"), "got: {result}");
assert!(result.contains(".foo"), "got: {result}");
}
#[test]
fn strips_deg_units() {
let input = ".foo { transform: rotate(45deg); }";
let result = strip_unsupported_css(input);
assert!(result.contains("rotate(45)"), "got: {result}");
assert!(!result.contains("deg"), "got: {result}");
}
}

View file

@ -0,0 +1,148 @@
use gpui::{Hsla, Rgba};
/// Produces a readable text color for a given background, subtly tinted by the
/// background's own hue using the OKLCH color space.
///
/// The result keeps ~15% of the background's chroma so the text feels
/// harmonious with its surroundings rather than a flat black or white.
/// Lightness is set to ensure readable contrast against the background.
pub fn text_color_for_background(background: Hsla) -> Hsla {
let rgba = Rgba::from(background);
let r_lin = srgb_to_linear(rgba.r);
let g_lin = srgb_to_linear(rgba.g);
let b_lin = srgb_to_linear(rgba.b);
let (_, ok_a, ok_b) = linear_rgb_to_oklab(r_lin, g_lin, b_lin);
let chroma = (ok_a * ok_a + ok_b * ok_b).sqrt();
let hue = ok_b.atan2(ok_a);
let bg_luminance = relative_luminance(rgba);
let text_l = if bg_luminance > 0.18 { 0.18 } else { 0.96 };
let text_c = chroma * 0.15;
let build = |c: f32| -> Rgba {
let (tr, tg, tb) = oklab_to_linear_rgb(text_l, c * hue.cos(), c * hue.sin());
Rgba {
r: linear_to_srgb(tr.clamp(0.0, 1.0)),
g: linear_to_srgb(tg.clamp(0.0, 1.0)),
b: linear_to_srgb(tb.clamp(0.0, 1.0)),
a: 1.0,
}
};
let meets_contrast =
|fg: Rgba| contrast_ratio_between(bg_luminance, relative_luminance(fg)) >= 4.5;
let candidate = build(text_c);
let result = if meets_contrast(candidate) {
candidate
} else {
// Binary search for the maximum chroma that still meets 4.5:1.
let mut lo = 0.0_f32;
let mut hi = text_c;
for _ in 0..16 {
let mid = (lo + hi) * 0.5;
if meets_contrast(build(mid)) {
lo = mid;
} else {
hi = mid;
}
}
let best = build(lo);
// Floating-point precision can leave the binary search result just
// below the 4.5:1 threshold. Fall back to pure black or white.
if meets_contrast(best) {
best
} else if bg_luminance > 0.18 {
Rgba {
r: 0.0,
g: 0.0,
b: 0.0,
a: 1.0,
}
} else {
Rgba {
r: 1.0,
g: 1.0,
b: 1.0,
a: 1.0,
}
}
};
Hsla::from(result)
}
fn srgb_to_linear(c: f32) -> f32 {
if c <= 0.04045 {
c / 12.92
} else {
((c + 0.055) / 1.055).powf(2.4)
}
}
fn linear_to_srgb(c: f32) -> f32 {
if c <= 0.0031308 {
c * 12.92
} else {
1.055 * c.powf(1.0 / 2.4) - 0.055
}
}
fn linear_rgb_to_oklab(r: f32, g: f32, b: f32) -> (f32, f32, f32) {
let l = (0.4122214708 * r + 0.5363325363 * g + 0.0514459929 * b).cbrt();
let m = (0.2119034982 * r + 0.6806995451 * g + 0.1073969566 * b).cbrt();
let s = (0.0883024619 * r + 0.2817188376 * g + 0.6299787005 * b).cbrt();
(
0.2104542553 * l + 0.7936177850 * m - 0.0040720468 * s,
1.9779984951 * l - 2.4285922050 * m + 0.4505937099 * s,
0.0259040371 * l + 0.7827717662 * m - 0.8086757660 * s,
)
}
fn oklab_to_linear_rgb(l: f32, a: f32, b: f32) -> (f32, f32, f32) {
let l_ = l + 0.3963377774 * a + 0.2158037573 * b;
let m_ = l - 0.1055613458 * a - 0.0638541728 * b;
let s_ = l - 0.0894841775 * a - 1.2914855480 * b;
(
4.0767416621 * l_ * l_ * l_ - 3.3077115913 * m_ * m_ * m_ + 0.2309699292 * s_ * s_ * s_,
-1.2684380046 * l_ * l_ * l_ + 2.6097574011 * m_ * m_ * m_ - 0.3413193965 * s_ * s_ * s_,
-0.0041960863 * l_ * l_ * l_ - 0.7034186147 * m_ * m_ * m_ + 1.7076147010 * s_ * s_ * s_,
)
}
fn relative_luminance(c: Rgba) -> f32 {
0.2126 * srgb_to_linear(c.r) + 0.7152 * srgb_to_linear(c.g) + 0.0722 * srgb_to_linear(c.b)
}
fn contrast_ratio_between(luminance_a: f32, luminance_b: f32) -> f32 {
let (lighter, darker) = if luminance_a > luminance_b {
(luminance_a, luminance_b)
} else {
(luminance_b, luminance_a)
};
(lighter + 0.05) / (darker + 0.05)
}
#[cfg(test)]
fn wcag_contrast_ratio(a: Rgba, b: Rgba) -> f32 {
contrast_ratio_between(relative_luminance(a), relative_luminance(b))
}
#[cfg(test)]
mod tests {
use super::*;
use gpui::proptest::prelude::*;
#[gpui::property_test]
fn sufficient_contrast_for_any_opaque_background(
#[strategy = Hsla::opaque_strategy()] bg: Hsla,
) -> Result<(), TestCaseError> {
let text = text_color_for_background(bg);
let ratio = wcag_contrast_ratio(Rgba::from(bg), Rgba::from(text));
prop_assert!(
ratio >= 4.5,
"WCAG AA contrast ratio {ratio:.2} < 4.5 for bg {bg:?} -> text {text:?}",
);
Ok(())
}
}

View file

@ -0,0 +1,122 @@
use std::sync::atomic::{AtomicU64, Ordering};
use anyhow::{Context as _, Result, anyhow};
use crate::{MermaidTheme, css_color};
pub(super) fn render_mermaid(source: &str, theme: &MermaidTheme) -> Result<String> {
static COUNTER: AtomicU64 = AtomicU64::new(0);
let id = COUNTER.fetch_add(1, Ordering::Relaxed);
let diagram_id = format!("merman-{id}");
let config = to_merman_config(theme);
let renderer = merman::render::HeadlessRenderer::new()
.with_site_config(config)
.with_vendored_text_measurer()
.with_diagram_id(&diagram_id);
let svg = renderer
.render_svg_sync(source)
.context("merman render failed")?
.ok_or_else(|| anyhow!("merman returned no SVG for the given input"))?;
Ok(svg)
}
fn to_merman_config(theme: &MermaidTheme) -> merman::MermaidConfig {
let primary = css_color(theme.primary_color);
let primary_text = css_color(theme.primary_text_color);
let primary_border = css_color(theme.primary_border_color);
let line = css_color(theme.line_color);
let secondary = css_color(theme.secondary_color);
let tertiary = css_color(theme.tertiary_color);
let background = css_color(theme.background);
let cluster_bg = css_color(theme.cluster_background);
let cluster_border = css_color(theme.cluster_border);
let edge_label_bg = css_color(theme.edge_label_background);
let text = css_color(theme.text_color);
let note_bg = css_color(theme.note_background);
let note_border = css_color(theme.note_border);
let actor_bg = css_color(theme.actor_background);
let actor_border = css_color(theme.actor_border);
let activation_bg = css_color(theme.activation_background);
let activation_border = css_color(theme.activation_border);
let er_odd = css_color(theme.er_attr_bg_odd);
let er_even = css_color(theme.er_attr_bg_even);
let git: [String; 8] = theme.git_branch_colors.map(css_color);
let git_lbl: [String; 8] = theme.git_branch_label_colors.map(css_color);
let mut theme_vars = serde_json::json!({
"primaryColor": primary,
"primaryTextColor": primary_text,
"primaryBorderColor": primary_border,
"lineColor": line,
"secondaryColor": secondary,
"secondaryTextColor": text,
"tertiaryColor": tertiary,
"tertiaryTextColor": text,
"background": background,
"mainBkg": primary,
"nodeBorder": primary_border,
"nodeTextColor": primary_text,
"clusterBkg": cluster_bg,
"clusterBorder": cluster_border,
"titleColor": text,
"edgeLabelBackground": edge_label_bg,
"textColor": text,
"fontFamily": theme.font_family,
"noteBkgColor": note_bg,
"noteBorderColor": note_border,
"noteTextColor": text,
"actorBkg": actor_bg,
"actorBorder": actor_border,
"actorTextColor": primary_text,
"labelTextColor": text,
"loopTextColor": text,
"signalColor": text,
"signalTextColor": text,
"activationBkgColor": activation_bg,
"activationBorderColor": activation_border,
"classText": text,
"labelColor": primary_text,
"attributeBackgroundColorOdd": er_odd,
"attributeBackgroundColorEven": er_even,
"pieTitleTextColor": text,
"pieSectionTextColor": text,
"pieLegendTextColor": text,
"pieStrokeColor": primary_border,
"pieOuterStrokeColor": primary_border,
"quadrant1Fill": primary,
"quadrant2Fill": primary,
"quadrant3Fill": primary,
"quadrant4Fill": primary,
"quadrant1TextFill": text,
"quadrant2TextFill": text,
"quadrant3TextFill": text,
"quadrant4TextFill": text,
"quadrantPointFill": line,
"quadrantPointTextFill": text,
"quadrantTitleFill": text,
"quadrantXAxisTextFill": text,
"quadrantYAxisTextFill": text,
"quadrantExternalBorderStrokeFill": primary_border,
"quadrantInternalBorderStrokeFill": primary_border,
});
let map = theme_vars.as_object_mut().expect("just created as object");
for i in 0..8 {
map.insert(format!("cScale{i}"), git[i].clone().into());
map.insert(format!("cScaleLabel{i}"), git_lbl[i].clone().into());
map.insert(format!("pie{}", i + 1), git[i].clone().into());
}
merman::MermaidConfig::from_value(serde_json::json!({
"theme": "base",
"darkMode": theme.dark_mode,
"fontFamily": theme.font_family,
"flowchart": {
"padding": 16,
},
"themeVariables": theme_vars,
}))
}

View file

@ -0,0 +1,394 @@
use gpui::Hsla;
use mermaid_render::MermaidTheme;
fn rgb(r: u8, g: u8, b: u8) -> Hsla {
gpui::Rgba {
r: r as f32 / 255.0,
g: g as f32 / 255.0,
b: b as f32 / 255.0,
a: 1.0,
}
.into()
}
const DIAGRAMS: &[(&str, &str)] = &[
(
"flowchart",
"flowchart TD\n A[Hello] --> B[World]\n B --> C{Decision}\n C -->|Yes| D[OK]\n C -->|No| E[Fail]",
),
(
"sequence",
"sequenceDiagram\n Alice->>Bob: Hello\n Bob-->>Alice: Hi\n Note over Alice,Bob: A note",
),
(
"state",
"stateDiagram-v2\n [*] --> Active\n Active --> [*]",
),
(
"er",
"erDiagram\n A { int id PK }\n B { int id PK }\n A ||--o{ B : has",
),
(
"class",
"classDiagram\n class Foo {\n +bar() void\n }",
),
("pie", "pie title Test\n \"A\" : 42\n \"B\" : 58"),
(
"gantt",
"gantt\n title Test\n dateFormat YYYY-MM-DD\n section S\n Task :a1, 2025-01-01, 7d",
),
("mindmap", "mindmap\n root((Root))\n Child1\n Child2"),
(
"journey",
"journey\n title Test\n section S\n Task: 5: Actor",
),
(
"gitgraph",
"gitGraph\n commit id: \"init\"\n branch dev\n commit id: \"feat\"\n checkout main\n merge dev",
),
(
"quadrant",
"quadrantChart\n title Test\n x-axis Low --> High\n y-axis Low --> High\n A: [0.3, 0.8]\n B: [0.7, 0.4]",
),
(
"timeline",
"timeline\n title Test\n section 2020s\n 2020 : Event A\n 2022 : Event B",
),
(
"xychart",
"xychart-beta\n title Test\n x-axis [\"A\", \"B\", \"C\"]\n y-axis \"Val\" 0 --> 10\n bar [3, 7, 5]",
),
];
fn rgb_theme() -> MermaidTheme {
MermaidTheme {
dark_mode: true,
font_family: "system-ui".to_string(),
background: rgb(40, 44, 51),
primary_color: rgb(47, 52, 62),
primary_text_color: rgb(220, 224, 229),
primary_border_color: rgb(70, 75, 87),
secondary_color: rgb(46, 52, 62),
tertiary_color: rgb(54, 60, 70),
line_color: rgb(70, 75, 87),
text_color: rgb(220, 224, 229),
edge_label_background: rgb(40, 44, 51),
cluster_background: rgb(47, 52, 62),
cluster_border: rgb(54, 60, 70),
note_background: rgb(47, 52, 62),
note_border: rgb(54, 60, 70),
actor_background: rgb(46, 52, 62),
actor_border: rgb(70, 75, 87),
activation_background: rgb(54, 60, 70),
activation_border: rgb(70, 75, 87),
git_branch_colors: [
rgb(116, 173, 232),
rgb(190, 80, 70),
rgb(191, 149, 106),
rgb(180, 119, 207),
rgb(110, 180, 191),
rgb(208, 114, 119),
rgb(222, 193, 132),
rgb(161, 193, 129),
],
git_branch_label_colors: [
rgb(116, 173, 232),
rgb(190, 80, 70),
rgb(191, 149, 106),
rgb(180, 119, 207),
rgb(110, 180, 191),
rgb(208, 114, 119),
rgb(222, 193, 132),
rgb(161, 193, 129),
]
.map(mermaid_render::text_color_for_background),
er_attr_bg_odd: rgb(47, 52, 62),
er_attr_bg_even: rgb(46, 52, 62),
error_color: rgb(220, 38, 38),
warning_color: rgb(217, 119, 6),
accent_colors: vec![
mermaid_render::AccentColor {
foreground: rgb(116, 173, 232),
background: rgb(116, 173, 232),
},
mermaid_render::AccentColor {
foreground: rgb(190, 80, 70),
background: rgb(190, 80, 70),
},
mermaid_render::AccentColor {
foreground: rgb(191, 149, 106),
background: rgb(191, 149, 106),
},
mermaid_render::AccentColor {
foreground: rgb(180, 119, 207),
background: rgb(180, 119, 207),
},
mermaid_render::AccentColor {
foreground: rgb(110, 180, 191),
background: rgb(110, 180, 191),
},
mermaid_render::AccentColor {
foreground: rgb(208, 114, 119),
background: rgb(208, 114, 119),
},
mermaid_render::AccentColor {
foreground: rgb(222, 193, 132),
background: rgb(222, 193, 132),
},
mermaid_render::AccentColor {
foreground: rgb(161, 193, 129),
background: rgb(161, 193, 129),
},
],
}
}
fn check_svg_issues(name: &str, svg: &str) -> Vec<String> {
let bad_patterns = [
"fill=\"\"",
"stroke=\"\"",
"width=\"\"",
"height=\"\"",
"NaN",
// Also check for empty values in style attributes
"fill: ;",
"fill:;",
"stroke: ;",
"stroke:;",
// Check for attributes with just whitespace
"fill=\" \"",
];
let mut issues = Vec::new();
for pattern in &bad_patterns {
let mut start = 0;
while let Some(pos) = svg[start..].find(pattern) {
let abs = start + pos;
let ctx_start = abs.saturating_sub(100);
let ctx_end = (abs + pattern.len() + 60).min(svg.len());
issues.push(format!(
"{name}: found `{pattern}` at byte {abs}:\n ...{}...\n",
&svg[ctx_start..ctx_end]
));
start = abs + pattern.len();
}
}
// Parse with quick-xml to find ANY empty attribute values on visual elements
use quick_xml::events::Event;
let mut reader = quick_xml::Reader::from_str(svg);
loop {
match reader.read_event() {
Ok(Event::Eof) => break,
Ok(Event::Start(e)) | Ok(Event::Empty(e)) => {
let tag = String::from_utf8_lossy(e.name().local_name().as_ref()).to_string();
for attr in e.attributes().flatten() {
let key = String::from_utf8_lossy(attr.key.local_name().as_ref()).to_string();
let val = attr.unescape_value().unwrap_or_default();
let visual_attr = matches!(
key.as_str(),
"fill"
| "stroke"
| "width"
| "height"
| "x"
| "y"
| "r"
| "cx"
| "cy"
| "rx"
| "ry"
| "stroke-width"
);
if visual_attr && val.is_empty() {
issues.push(format!("{name}: <{tag}> has empty {key}=\"\"\n"));
}
// Check for CSS length units that usvg can't parse
if visual_attr
&& matches!(key.as_str(), "width" | "height")
&& val.ends_with("px")
{
issues.push(format!("{name}: <{tag}> has {key}=\"{val}\" (px suffix)\n"));
}
}
}
Err(e) => {
issues.push(format!("{name}: XML parse error: {e}\n"));
break;
}
_ => {}
}
}
issues
}
#[test]
fn accent_colors_auto_applied_to_nodes() {
let theme = rgb_theme();
// A plain state diagram with no :::accent syntax should get
// automatic accent colors applied to its node groups.
let source = "stateDiagram-v2\n [*] --> Idle\n Idle --> Processing\n Processing --> Done\n Done --> [*]";
let svg = mermaid_render::render_to_svg(source, &theme).expect("render failed");
// accent_fill_and_text darkens the background color for dark mode.
// The stroke colors are direct hex conversions of the accent rgb values.
// With 3 states (Idle, Processing, Done), we expect at least accent0 and
// accent1 stroke colors to appear.
let accent0_stroke = "#74ade8"; // rgb(116, 173, 232) -> hex
let accent1_stroke = "#be5046"; // rgb(190, 80, 70) -> hex
assert!(
svg.contains(accent0_stroke),
"Expected accent0 stroke color ({accent0_stroke}) in auto-colored state diagram SVG.\n\
This means auto-coloring did not apply accent colors to node groups.\n\
SVG snippet: {}...",
&svg[..svg.len().min(2000)]
);
assert!(
svg.contains(accent1_stroke),
"Expected accent1 stroke color ({accent1_stroke}) in auto-colored state diagram SVG."
);
}
#[test]
fn generics_not_double_escaped() {
let theme = rgb_theme();
let source = "classDiagram\n class Shelter {\n -List~Animal~ animals\n +adopt(Animal a) bool\n }";
let svg = mermaid_render::render_to_svg(source, &theme).expect("render failed");
assert!(
!svg.contains("&amp;lt;"),
"Double-escaped &amp;lt; found in SVG"
);
assert!(
!svg.contains("&amp;gt;"),
"Double-escaped &amp;gt; found in SVG"
);
}
#[test]
fn backslash_n_converted_to_line_break() {
let theme = rgb_theme();
let source = r#"graph TD
L7["Layer 7\nHTTP, FTP"]
L6["Layer 6\nEncryption"]
L7 --> L6"#;
let svg = mermaid_render::render_to_svg(source, &theme).expect("render failed");
assert!(
!svg.contains(r"\n"),
"Literal \\n should not appear in SVG output"
);
assert!(
svg.contains(">Layer 7<") && svg.contains(">HTTP, FTP<"),
"Label lines should be split into separate <text> elements"
);
}
#[test]
fn class_diagram_fallback_text_uses_accent_classes() {
let theme = rgb_theme();
let source = r#"classDiagram
class Animal {
+String name
+makeSound() void
}
class Dog {
+String breed
+bark() void
}
Dog --|> Animal"#;
let svg = mermaid_render::render_to_svg(source, &theme).expect("render failed");
use quick_xml::events::Event;
let mut reader = quick_xml::Reader::from_str(&svg);
let mut in_fallback = false;
let mut accent_classes: Vec<String> = Vec::new();
loop {
match reader.read_event() {
Ok(Event::Eof) => break,
Ok(Event::Start(e)) => {
if e.name().as_ref() == b"g" {
if let Ok(Some(attr)) = e.try_get_attribute("data-merman-foreignobject") {
if attr.value.as_ref() == b"fallback" {
in_fallback = true;
}
}
}
if in_fallback && e.name().as_ref() == b"text" {
if let Ok(Some(class_attr)) = e.try_get_attribute("class") {
let class = class_attr.unescape_value().unwrap_or_default().to_string();
for token in class.split_whitespace() {
if token.starts_with("zed-accent-") {
accent_classes.push(token.to_string());
}
}
}
}
}
Ok(Event::End(e)) if e.name().as_ref() == b"g" => {
in_fallback = false;
}
_ => {}
}
}
assert!(
!accent_classes.is_empty(),
"expected zed-accent-N classes on text elements in fallback groups",
);
}
#[test]
fn sequence_diagram_tspan_uses_accent_classes() {
let theme = rgb_theme();
let source = "sequenceDiagram\n participant Database";
let svg = mermaid_render::render_to_svg(source, &theme).expect("render failed");
use quick_xml::events::Event;
let mut reader = quick_xml::Reader::from_str(&svg);
let mut accent_classes: Vec<String> = Vec::new();
loop {
match reader.read_event() {
Ok(Event::Eof) => break,
Ok(Event::Start(e)) if e.name().as_ref() == b"tspan" => {
if let Ok(Some(class_attr)) = e.try_get_attribute("class") {
let class = class_attr.unescape_value().unwrap_or_default().to_string();
for token in class.split_whitespace() {
if token.starts_with("zed-accent-") {
accent_classes.push(token.to_string());
}
}
}
}
_ => {}
}
}
assert!(
!accent_classes.is_empty(),
"expected zed-accent-N classes on tspan elements in sequence diagram",
);
}
#[test]
fn no_empty_attributes_or_nan_with_rgb_theme() {
let theme = rgb_theme();
let mut all_issues = Vec::new();
for (name, source) in DIAGRAMS {
match mermaid_render::render_to_svg(source, &theme) {
Ok(svg) => all_issues.extend(check_svg_issues(name, &svg)),
Err(e) => eprintln!("{name}: render failed (skipped): {e}"),
}
}
if !all_issues.is_empty() {
panic!(
"Found {} issues in merman SVG output (rgb theme):\n\n{}",
all_issues.len(),
all_issues.join("\n")
);
}
}

Some files were not shown because too many files have changed in this diff Show more