New evals for inline assistant (#44431)

Also factor out some common code in the evals.

Release Notes:

- N/A

---------

Co-authored-by: Mikayla Maki <mikayla.c.maki@gmail.com>
This commit is contained in:
Michael Benfield 2025-12-14 22:55:41 -08:00 committed by GitHub
parent b8e40e6fdb
commit 0c47984a19
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 396 additions and 225 deletions

View file

@ -1343,6 +1343,7 @@ fn run_eval(eval: EvalInput) -> eval_utils::EvalOutput<EditEvalMetadata> {
let test = EditAgentTest::new(&mut cx).await;
test.eval(eval, &mut cx).await
});
cx.quit();
match result {
Ok(output) => eval_utils::EvalOutput {
data: output.to_string(),

View file

@ -13,7 +13,7 @@ path = "src/agent_ui.rs"
doctest = false
[features]
test-support = ["gpui/test-support", "language/test-support", "reqwest_client"]
test-support = ["assistant_text_thread/test-support", "eval_utils", "gpui/test-support", "language/test-support", "reqwest_client", "workspace/test-support"]
unit-eval = []
[dependencies]
@ -40,6 +40,7 @@ component.workspace = true
context_server.workspace = true
db.workspace = true
editor.workspace = true
eval_utils = { workspace = true, optional = true }
extension.workspace = true
extension_host.workspace = true
feature_flags.workspace = true
@ -71,6 +72,7 @@ postage.workspace = true
project.workspace = true
prompt_store.workspace = true
proto.workspace = true
rand.workspace = true
release_channel.workspace = true
rope.workspace = true
rules_library.workspace = true
@ -119,7 +121,6 @@ language_model = { workspace = true, "features" = ["test-support"] }
pretty_assertions.workspace = true
project = { workspace = true, features = ["test-support"] }
semver.workspace = true
rand.workspace = true
reqwest_client.workspace = true
tree-sitter-md.workspace = true
unindent.workspace = true

View file

@ -7,8 +7,6 @@ mod buffer_codegen;
mod completion_provider;
mod context;
mod context_server_configuration;
#[cfg(test)]
mod evals;
mod inline_assistant;
mod inline_prompt_editor;
mod language_model_selector;

View file

@ -41,7 +41,6 @@ use std::{
time::Instant,
};
use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
use ui::SharedString;
/// Use this tool to provide a message to the user when you're unable to complete a task.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
@ -56,16 +55,16 @@ pub struct FailureMessageInput {
/// Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RewriteSectionInput {
/// The text to replace the section with.
#[serde(default)]
pub replacement_text: String,
/// A brief description of the edit you have made.
///
/// The description may use markdown formatting if you wish.
/// This is optional - if the edit is simple or obvious, you should leave it empty.
#[serde(default)]
pub description: String,
/// The text to replace the section with.
#[serde(default)]
pub replacement_text: String,
}
pub struct BufferCodegen {
@ -287,8 +286,9 @@ pub struct CodegenAlternative {
completion: Option<String>,
selected_text: Option<String>,
pub message_id: Option<String>,
pub model_explanation: Option<SharedString>,
session_id: Uuid,
pub description: Option<String>,
pub failure: Option<String>,
}
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
@ -346,8 +346,9 @@ impl CodegenAlternative {
elapsed_time: None,
completion: None,
selected_text: None,
model_explanation: None,
session_id,
description: None,
failure: None,
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
}
}
@ -920,6 +921,16 @@ impl CodegenAlternative {
self.completion.clone()
}
/// Returns a clone of the `description` field, or `None` when no description has been stored.
///
/// Only compiled for tests or when the `test-support` feature is enabled.
#[cfg(any(test, feature = "test-support"))]
pub fn current_description(&self) -> Option<String> {
self.description.clone()
}
/// Returns a clone of the `failure` field, or `None` when no failure message has been stored.
///
/// Only compiled for tests or when the `test-support` feature is enabled.
#[cfg(any(test, feature = "test-support"))]
pub fn current_failure(&self) -> Option<String> {
self.failure.clone()
}
/// Borrows the `selected_text` field as a `&str`, or returns `None` when it is unset.
pub fn selected_text(&self) -> Option<&str> {
self.selected_text.as_deref()
}
@ -1133,32 +1144,69 @@ impl CodegenAlternative {
}
};
/// What a single tool-use event was decoded into.
enum ToolUseOutput {
/// Decoded from a `rewrite_section` tool use: `text` is the part of the replacement text that
/// had not been read yet, and `description` is only populated once the tool input is complete
/// and the description is non-empty.
Rewrite {
text: String,
description: Option<String>,
},
/// The message of a completed `failure_message` tool use.
Failure(String),
}
/// A value sent over the message channel and written into the codegen by the spawned task.
enum ModelUpdate {
/// Stored into the codegen's `description` field.
Description(String),
/// Stored into the codegen's `failure` field.
Failure(String),
}
let chars_read_so_far = Arc::new(Mutex::new(0usize));
let tool_to_text_and_message =
move |tool_use: LanguageModelToolUse| -> (Option<String>, Option<String>) {
let mut chars_read_so_far = chars_read_so_far.lock();
match tool_use.name.as_ref() {
"rewrite_section" => {
let Ok(mut input) =
serde_json::from_value::<RewriteSectionInput>(tool_use.input)
else {
return (None, None);
};
let value = input.replacement_text[*chars_read_so_far..].to_string();
*chars_read_so_far = input.replacement_text.len();
(Some(value), Some(std::mem::take(&mut input.description)))
}
"failure_message" => {
let Ok(mut input) =
serde_json::from_value::<FailureMessageInput>(tool_use.input)
else {
return (None, None);
};
(None, Some(std::mem::take(&mut input.message)))
}
_ => (None, None),
let process_tool_use = move |tool_use: LanguageModelToolUse| -> Option<ToolUseOutput> {
let mut chars_read_so_far = chars_read_so_far.lock();
let is_complete = tool_use.is_input_complete;
match tool_use.name.as_ref() {
"rewrite_section" => {
let Ok(mut input) =
serde_json::from_value::<RewriteSectionInput>(tool_use.input)
else {
return None;
};
let text = input.replacement_text[*chars_read_so_far..].to_string();
*chars_read_so_far = input.replacement_text.len();
let description = is_complete
.then(|| {
let desc = std::mem::take(&mut input.description);
if desc.is_empty() { None } else { Some(desc) }
})
.flatten();
Some(ToolUseOutput::Rewrite { text, description })
}
};
"failure_message" => {
if !is_complete {
return None;
}
let Ok(mut input) =
serde_json::from_value::<FailureMessageInput>(tool_use.input)
else {
return None;
};
Some(ToolUseOutput::Failure(std::mem::take(&mut input.message)))
}
_ => None,
}
};
let (message_tx, mut message_rx) = futures::channel::mpsc::unbounded::<ModelUpdate>();
cx.spawn({
let codegen = codegen.clone();
async move |cx| {
while let Some(update) = message_rx.next().await {
let _ = codegen.update(cx, |this, _cx| match update {
ModelUpdate::Description(d) => this.description = Some(d),
ModelUpdate::Failure(f) => this.failure = Some(f),
});
}
}
})
.detach();
let mut message_id = None;
let mut first_text = None;
@ -1171,24 +1219,23 @@ impl CodegenAlternative {
Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
message_id = Some(id);
}
Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
if matches!(
tool_use.name.as_ref(),
"rewrite_section" | "failure_message"
) =>
{
let is_complete = tool_use.is_input_complete;
let (text, message) = tool_to_text_and_message(tool_use);
// Only update the model explanation if the tool use is complete.
// Otherwise the UI element bounces around as it's updated.
if is_complete {
let _ = codegen.update(cx, |this, _cx| {
this.model_explanation = message.map(Into::into);
});
}
first_text = text;
if first_text.is_some() {
break;
Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
if let Some(output) = process_tool_use(tool_use) {
let (text, update) = match output {
ToolUseOutput::Rewrite { text, description } => {
(Some(text), description.map(ModelUpdate::Description))
}
ToolUseOutput::Failure(message) => {
(None, Some(ModelUpdate::Failure(message)))
}
};
if let Some(update) = update {
let _ = message_tx.unbounded_send(update);
}
first_text = text;
if first_text.is_some() {
break;
}
}
}
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
@ -1215,41 +1262,30 @@ impl CodegenAlternative {
return;
};
let (message_tx, mut message_rx) = futures::channel::mpsc::unbounded();
cx.spawn({
let codegen = codegen.clone();
async move |cx| {
while let Some(message) = message_rx.next().await {
let _ = codegen.update(cx, |this, _cx| {
this.model_explanation = message;
});
}
}
})
.detach();
let move_last_token_usage = last_token_usage.clone();
let text_stream = Box::pin(futures::stream::once(async { Ok(first_text) }).chain(
completion_events.filter_map(move |e| {
let tool_to_text_and_message = tool_to_text_and_message.clone();
let process_tool_use = process_tool_use.clone();
let last_token_usage = move_last_token_usage.clone();
let total_text = total_text.clone();
let mut message_tx = message_tx.clone();
async move {
match e {
Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
if matches!(
tool_use.name.as_ref(),
"rewrite_section" | "failure_message"
) =>
{
let is_complete = tool_use.is_input_complete;
let (text, message) = tool_to_text_and_message(tool_use);
if is_complete {
// Again only send the message when complete to not get a bouncing UI element.
let _ = message_tx.send(message.map(Into::into)).await;
Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
let Some(output) = process_tool_use(tool_use) else {
return None;
};
let (text, update) = match output {
ToolUseOutput::Rewrite { text, description } => {
(Some(text), description.map(ModelUpdate::Description))
}
ToolUseOutput::Failure(message) => {
(None, Some(ModelUpdate::Failure(message)))
}
};
if let Some(update) = update {
let _ = message_tx.send(update).await;
}
text.map(Ok)
}

View file

@ -1,89 +0,0 @@
use std::str::FromStr;
use crate::inline_assistant::test::run_inline_assistant_test;
use eval_utils::{EvalOutput, NoProcessor};
use gpui::TestAppContext;
use language_model::{LanguageModelRegistry, SelectedModel};
use rand::{SeedableRng as _, rngs::StdRng};
/// Eval: prompts the inline assistant to rename a struct field and passes only when the
/// resulting buffer text equals the expected text exactly.
///
/// Marked `ignore` unless the `unit-eval` feature is enabled.
// NOTE(review): `ˇ` in the buffer presumably marks the cursor position — confirm against
// `run_inline_assistant_test`.
#[test]
#[cfg_attr(not(feature = "unit-eval"), ignore)]
fn eval_single_cursor_edit() {
eval_utils::eval(20, 1.0, NoProcessor, move || {
run_eval(
&EvalInput {
prompt: "Rename this variable to buffer_text".to_string(),
buffer: indoc::indoc! {"
struct EvalExampleStruct {
text: Strˇing,
prompt: String,
}
"}
.to_string(),
},
// Judge: exact string comparison between the final buffer text and the expectation.
&|_, output| {
let expected = indoc::indoc! {"
struct EvalExampleStruct {
buffer_text: String,
prompt: String,
}
"};
if output == expected {
EvalOutput {
outcome: eval_utils::OutcomeKind::Passed,
data: "Passed!".to_string(),
metadata: (),
}
} else {
EvalOutput {
outcome: eval_utils::OutcomeKind::Failed,
data: format!("Failed to rename variable, output: {}", output),
metadata: (),
}
}
},
)
});
}
/// Inputs for a single eval run.
struct EvalInput {
/// Initial buffer contents passed to `run_inline_assistant_test`.
buffer: String,
/// Prompt passed to `run_inline_assistant_test`.
prompt: String,
}
/// Runs the inline assistant once against `input` and lets `judge` score the final buffer text.
///
/// A fresh `TestAppContext` is built for the run. The model is taken from the
/// `ZED_AGENT_MODEL` environment variable, falling back to `anthropic/claude-sonnet-4-latest`.
///
/// # Panics
///
/// Panics when the model name cannot be parsed as `provider/model-id`.
fn run_eval(
    input: &EvalInput,
    judge: &dyn Fn(&EvalInput, &str) -> eval_utils::EvalOutput<()>,
) -> eval_utils::EvalOutput<()> {
    let mut cx = TestAppContext::build(gpui::TestDispatcher::new(StdRng::from_os_rng()), None);
    cx.skip_drawing();

    let final_text = run_inline_assistant_test(
        input.buffer.clone(),
        input.prompt.clone(),
        |cx| {
            // Swap the fake model for a real one so the eval exercises an actual LLM.
            let requested_model = std::env::var("ZED_AGENT_MODEL")
                .unwrap_or_else(|_| "anthropic/claude-sonnet-4-latest".into());
            let selected_model = SelectedModel::from_str(&requested_model)
                .expect("Invalid model format. Use 'provider/model-id'");
            log::info!("Selected model: {selected_model:?}");
            cx.update(|_, cx| {
                LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                    registry.select_inline_assistant_model(Some(&selected_model), cx);
                });
            });
        },
        |_cx| {
            log::info!("Waiting for actual response from the LLM...");
        },
        &mut cx,
    );

    judge(input, &final_text)
}

View file

@ -117,14 +117,6 @@ impl InlineAssistant {
}
}
/// Stores `sender` in `_inline_assistant_completions`, replacing any previous value.
///
/// Only compiled for tests or when the `test-support` feature is enabled.
// NOTE(review): despite the name, this takes the sending half of the channel; presumably
// finished assists are reported through it — confirm against the code that reads the field.
#[cfg(any(test, feature = "test-support"))]
pub fn set_completion_receiver(
&mut self,
sender: mpsc::UnboundedSender<anyhow::Result<InlineAssistId>>,
) {
self._inline_assistant_completions = Some(sender);
}
pub fn register_workspace(
&mut self,
workspace: &Entity<Workspace>,
@ -1593,6 +1585,27 @@ impl InlineAssistant {
.map(InlineAssistTarget::Terminal)
}
}
/// Stores `sender` in `_inline_assistant_completions`, replacing any previous value.
///
/// Only compiled for tests or when the `test-support` feature is enabled.
// NOTE(review): despite the name, this takes the sending half of the channel; presumably
// finished assists are reported through it — confirm against the code that reads the field.
#[cfg(any(test, feature = "test-support"))]
pub fn set_completion_receiver(
&mut self,
sender: mpsc::UnboundedSender<anyhow::Result<InlineAssistId>>,
) {
self._inline_assistant_completions = Some(sender);
}
/// Looks up the assist registered under `assist_id` and returns a handle to the active
/// alternative of its codegen.
///
/// Returns `None` when no assist with that id is known. Only compiled for tests or when the
/// `test-support` feature is enabled.
#[cfg(any(test, feature = "test-support"))]
pub fn get_codegen(
    &mut self,
    assist_id: InlineAssistId,
    cx: &mut App,
) -> Option<Entity<CodegenAlternative>> {
    let inline_assist = self.assists.get(&assist_id)?;
    let active_alternative = inline_assist
        .codegen
        .update(cx, |codegen, _cx| codegen.active_alternative().clone());
    Some(active_alternative)
}
}
struct EditorInlineAssists {
@ -2014,8 +2027,10 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
}
}
#[cfg(any(test, feature = "test-support"))]
#[cfg(any(test, feature = "unit-eval"))]
#[cfg_attr(not(test), allow(dead_code))]
pub mod test {
use std::sync::Arc;
use agent::HistoryStore;
@ -2026,7 +2041,6 @@ pub mod test {
use futures::channel::mpsc;
use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
use language::Buffer;
use language_model::LanguageModelRegistry;
use project::Project;
use prompt_store::PromptBuilder;
use smol::stream::StreamExt as _;
@ -2035,13 +2049,43 @@ pub mod test {
use crate::InlineAssistant;
/// Classified result of one inline-assistant run.
#[derive(Debug)]
pub enum InlineAssistantOutput {
    /// The run finished without a failure message.
    Success {
        /// Completion text recorded by the codegen, if any.
        completion: Option<String>,
        /// Description recorded by the codegen, if any.
        description: Option<String>,
        /// Full text of the buffer after the run.
        full_buffer_text: String,
    },
    /// The run produced only a failure message.
    Failure {
        /// The failure message.
        failure: String,
    },
    /// The run produced a failure message together with a completion and/or a description.
    // These fields are used for logging
    #[allow(unused)]
    Malformed {
        completion: Option<String>,
        description: Option<String>,
        failure: Option<String>,
    },
}

impl InlineAssistantOutput {
    /// Returns the final buffer text for a `Success`, and an empty string for every other
    /// variant.
    pub fn buffer_text(&self) -> &str {
        match self {
            Self::Success {
                full_buffer_text, ..
            } => full_buffer_text.as_str(),
            Self::Failure { .. } | Self::Malformed { .. } => "",
        }
    }
}
pub fn run_inline_assistant_test<SetupF, TestF>(
base_buffer: String,
prompt: String,
setup: SetupF,
test: TestF,
cx: &mut TestAppContext,
) -> String
) -> InlineAssistantOutput
where
SetupF: FnOnce(&mut gpui::VisualTestContext),
TestF: FnOnce(&mut gpui::VisualTestContext),
@ -2133,39 +2177,198 @@ pub mod test {
test(cx);
cx.executor()
.block_test(async { completion_rx.next().await });
let assist_id = cx
.executor()
.block_test(async { completion_rx.next().await })
.unwrap()
.unwrap();
buffer.read_with(cx, |buffer, _| buffer.text())
}
let (completion, description, failure) = cx.update(|_, cx| {
InlineAssistant::update_global(cx, |inline_assistant, cx| {
let codegen = inline_assistant.get_codegen(assist_id, cx).unwrap();
#[allow(unused)]
pub fn test_inline_assistant(
base_buffer: &'static str,
llm_output: &'static str,
cx: &mut TestAppContext,
) -> String {
run_inline_assistant_test(
base_buffer.to_string(),
"Prompt doesn't matter because we're using a fake model".to_string(),
|cx| {
cx.update(|_, cx| LanguageModelRegistry::test(cx));
},
|cx| {
let fake_model = cx.update(|_, cx| {
LanguageModelRegistry::global(cx)
.update(cx, |registry, _| registry.fake_model())
});
let fake = fake_model.as_fake();
let completion = codegen.read(cx).current_completion();
let description = codegen.read(cx).current_description();
let failure = codegen.read(cx).current_failure();
// let fake = fake_model;
fake.send_last_completion_stream_text_chunk(llm_output.to_string());
fake.end_last_completion_stream();
(completion, description, failure)
})
});
// Run again to process the model's response
cx.run_until_parked();
},
cx,
)
if failure.is_some() && (completion.is_some() || description.is_some()) {
InlineAssistantOutput::Malformed {
completion,
description,
failure,
}
} else if let Some(failure) = failure {
InlineAssistantOutput::Failure { failure }
} else {
InlineAssistantOutput::Success {
completion,
description,
full_buffer_text: buffer.read_with(cx, |buffer, _| buffer.text()),
}
}
}
}
#[cfg(any(test, feature = "unit-eval"))]
#[cfg_attr(not(test), allow(dead_code))]
pub mod evals {
use std::str::FromStr;
use eval_utils::{EvalOutput, NoProcessor};
use gpui::TestAppContext;
use language_model::{LanguageModelRegistry, SelectedModel};
use rand::{SeedableRng as _, rngs::StdRng};
use crate::inline_assistant::test::{InlineAssistantOutput, run_inline_assistant_test};
/// Eval: prompts the inline assistant to rename a struct field and judges the run with
/// `exact_buffer_match` against the expected buffer text.
///
/// Passes 20 as `iterations` and 1.0 as `expected_pass_ratio` to `run_eval`. Marked `ignore`
/// unless the `unit-eval` feature is enabled.
// NOTE(review): `ˇ` in the buffer presumably marks the cursor position — confirm against
// `run_inline_assistant_test`.
#[test]
#[cfg_attr(not(feature = "unit-eval"), ignore)]
fn eval_single_cursor_edit() {
run_eval(
20,
1.0,
"Rename this variable to buffer_text".to_string(),
indoc::indoc! {"
struct EvalExampleStruct {
text: Strˇing,
prompt: String,
}
"}
.to_string(),
exact_buffer_match(indoc::indoc! {"
struct EvalExampleStruct {
buffer_text: String,
prompt: String,
}
"}),
);
}
/// Eval: prompts the inline assistant to rename the struct itself and judges the run with
/// `uncertain_output`, which passes on a failure message or on a description without a
/// completion.
///
/// Passes 20 as `iterations` and 1.0 as `expected_pass_ratio` to `run_eval`. Marked `ignore`
/// unless the `unit-eval` feature is enabled.
// NOTE(review): `ˇ` presumably marks the cursor position, which here sits on a field line
// rather than on the struct name — confirm against `run_inline_assistant_test`.
#[test]
#[cfg_attr(not(feature = "unit-eval"), ignore)]
fn eval_cant_do() {
run_eval(
20,
1.0,
"Rename the struct to EvalExampleStructNope",
indoc::indoc! {"
struct EvalExampleStruct {
text: Strˇing,
prompt: String,
}
"},
uncertain_output,
);
}
/// Eval: gives the inline assistant a prompt that names no concrete change and judges the run
/// with `uncertain_output`, which passes on a failure message or on a description without a
/// completion.
///
/// Passes 20 as `iterations` and 1.0 as `expected_pass_ratio` to `run_eval`. Marked `ignore`
/// unless the `unit-eval` feature is enabled.
#[test]
#[cfg_attr(not(feature = "unit-eval"), ignore)]
fn eval_unclear() {
run_eval(
20,
1.0,
"Make exactly the change I want you to make",
indoc::indoc! {"
struct EvalExampleStruct {
text: Strˇing,
prompt: String,
}
"},
uncertain_output,
);
}
/// Drives `eval_utils::eval` with a closure that runs the inline assistant once on `buffer`
/// with `prompt` and scores the result with `judge`.
///
/// `iterations` and `expected_pass_ratio` are forwarded unchanged to `eval_utils::eval`. Every
/// invocation of the closure builds its own `TestAppContext` and quits it before judging. The
/// model is taken from the `ZED_AGENT_MODEL` environment variable, falling back to
/// `anthropic/claude-sonnet-4-latest`.
///
/// # Panics
///
/// Panics when the model name cannot be parsed as `provider/model-id`.
fn run_eval(
    iterations: usize,
    expected_pass_ratio: f32,
    prompt: impl Into<String>,
    buffer: impl Into<String>,
    judge: impl Fn(InlineAssistantOutput) -> eval_utils::EvalOutput<()> + Send + Sync + 'static,
) {
    // Convert once up front; the closure below clones these for every run.
    let initial_buffer: String = buffer.into();
    let user_prompt: String = prompt.into();

    eval_utils::eval(iterations, expected_pass_ratio, NoProcessor, move || {
        let mut cx =
            TestAppContext::build(gpui::TestDispatcher::new(StdRng::from_os_rng()), None);
        cx.skip_drawing();

        let assistant_output = run_inline_assistant_test(
            initial_buffer.clone(),
            user_prompt.clone(),
            |cx| {
                // Swap the fake model for a real one so the eval exercises an actual LLM.
                let requested_model = std::env::var("ZED_AGENT_MODEL")
                    .unwrap_or_else(|_| "anthropic/claude-sonnet-4-latest".into());
                let selected_model = SelectedModel::from_str(&requested_model)
                    .expect("Invalid model format. Use 'provider/model-id'");
                log::info!("Selected model: {selected_model:?}");
                cx.update(|_, cx| {
                    LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                        registry.select_inline_assistant_model(Some(&selected_model), cx);
                    });
                });
            },
            |_cx| {
                log::info!("Waiting for actual response from the LLM...");
            },
            &mut cx,
        );

        cx.quit();

        judge(assistant_output)
    });
}
/// Judge for prompts the assistant is not expected to be able to carry out.
///
/// Passes when the assistant either reported a failure message, or succeeded with a
/// description but no completion. Fails when a `Success` does not have that shape, and when
/// the response was `Malformed`.
fn uncertain_output(output: InlineAssistantOutput) -> EvalOutput<()> {
    match &output {
        // Destructure the one passing `Success` shape directly instead of testing
        // `is_some()` / `is_none()` and unwrapping afterwards.
        InlineAssistantOutput::Success {
            completion: None,
            description: Some(description),
            ..
        } => EvalOutput::passed(format!(
            "Assistant produced no completion, but a description:\n{}",
            description
        )),
        InlineAssistantOutput::Success { .. } => {
            EvalOutput::failed(format!("Assistant produced a completion:\n{:?}", output))
        }
        InlineAssistantOutput::Failure { failure } => EvalOutput::passed(format!(
            "Assistant produced a failure message: {}",
            failure
        )),
        InlineAssistantOutput::Malformed { .. } => EvalOutput::failed(format!(
            "Assistant produced a malformed response:\n{:?}",
            output
        )),
    }
}
/// Builds a judge that passes only when the output's `buffer_text()` equals `correct_output`
/// exactly.
///
/// On a mismatch the failure message contains the `Debug` rendering of the whole output.
fn exact_buffer_match(
    correct_output: impl Into<String>,
) -> impl Fn(InlineAssistantOutput) -> EvalOutput<()> {
    let expected: String = correct_output.into();
    move |output| {
        if output.buffer_text() != expected {
            return EvalOutput::failed(format!(
                "Assistant output does not match expected output: {:?}",
                output
            ));
        }
        EvalOutput::passed("Assistant output matches")
    }
}
}

View file

@ -101,11 +101,11 @@ impl<T: 'static> Render for PromptEditor<T> {
let left_gutter_width = gutter.full_width() + (gutter.margin / 2.0);
let right_padding = editor_margins.right + RIGHT_PADDING;
let explanation = codegen
.active_alternative()
.read(cx)
.model_explanation
.clone();
let active_alternative = codegen.active_alternative().read(cx);
let explanation = active_alternative
.description
.clone()
.or_else(|| active_alternative.failure.clone());
(left_gutter_width, right_padding, explanation)
}
@ -139,7 +139,7 @@ impl<T: 'static> Render for PromptEditor<T> {
if let Some(explanation) = &explanation {
markdown.update(cx, |markdown, cx| {
markdown.reset(explanation.clone(), cx);
markdown.reset(SharedString::from(explanation), cx);
});
}

View file

@ -40,6 +40,24 @@ pub struct EvalOutput<M> {
pub metadata: M,
}
impl<M: Default> EvalOutput<M> {
pub fn passed(message: impl Into<String>) -> Self {
EvalOutput {
outcome: OutcomeKind::Passed,
data: message.into(),
metadata: M::default(),
}
}
pub fn failed(message: impl Into<String>) -> Self {
EvalOutput {
outcome: OutcomeKind::Failed,
data: message.into(),
metadata: M::default(),
}
}
}
pub struct NoProcessor;
impl EvalOutputProcessor for NoProcessor {
type Metadata = ();

View file

@ -18,6 +18,6 @@ impl FeatureFlag for InlineAssistantUseToolFeatureFlag {
const NAME: &'static str = "inline-assistant-use-tool";
fn enabled_for_staff() -> bool {
false
true
}
}

View file

@ -17,7 +17,7 @@ use settings::{Settings, SettingsStore};
use std::collections::HashMap;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::{Arc, LazyLock, OnceLock};
use std::sync::{Arc, LazyLock};
use strum::IntoEnumIterator;
use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
@ -31,7 +31,6 @@ static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
const CODESTRAL_API_KEY_ENV_VAR_NAME: &str = "CODESTRAL_API_KEY";
static CODESTRAL_API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(CODESTRAL_API_KEY_ENV_VAR_NAME);
static CODESTRAL_API_KEY: OnceLock<Entity<ApiKeyState>> = OnceLock::new();
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MistralSettings {
@ -49,14 +48,18 @@ pub struct State {
codestral_api_key_state: Entity<ApiKeyState>,
}
struct CodestralApiKey(Entity<ApiKeyState>);
impl Global for CodestralApiKey {}
pub fn codestral_api_key(cx: &mut App) -> Entity<ApiKeyState> {
return CODESTRAL_API_KEY
.get_or_init(|| {
cx.new(|_| {
ApiKeyState::new(CODESTRAL_API_URL.into(), CODESTRAL_API_KEY_ENV_VAR.clone())
})
})
.clone();
if cx.has_global::<CodestralApiKey>() {
cx.global::<CodestralApiKey>().0.clone()
} else {
let api_key_state = cx
.new(|_| ApiKeyState::new(CODESTRAL_API_URL.into(), CODESTRAL_API_KEY_ENV_VAR.clone()));
cx.set_global(CodestralApiKey(api_key_state.clone()));
api_key_state
}
}
impl State {