mirror of
https://github.com/zed-industries/zed.git
synced 2026-06-01 03:14:56 +07:00
New evals for inline assistant (#44431)
Also factor out some common code in the evals. Release Notes: - N/A --------- Co-authored-by: Mikayla Maki <mikayla.c.maki@gmail.com>
This commit is contained in:
parent
b8e40e6fdb
commit
0c47984a19
10 changed files with 396 additions and 225 deletions
|
|
@ -1343,6 +1343,7 @@ fn run_eval(eval: EvalInput) -> eval_utils::EvalOutput<EditEvalMetadata> {
|
|||
let test = EditAgentTest::new(&mut cx).await;
|
||||
test.eval(eval, &mut cx).await
|
||||
});
|
||||
cx.quit();
|
||||
match result {
|
||||
Ok(output) => eval_utils::EvalOutput {
|
||||
data: output.to_string(),
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ path = "src/agent_ui.rs"
|
|||
doctest = false
|
||||
|
||||
[features]
|
||||
test-support = ["gpui/test-support", "language/test-support", "reqwest_client"]
|
||||
test-support = ["assistant_text_thread/test-support", "eval_utils", "gpui/test-support", "language/test-support", "reqwest_client", "workspace/test-support"]
|
||||
unit-eval = []
|
||||
|
||||
[dependencies]
|
||||
|
|
@ -40,6 +40,7 @@ component.workspace = true
|
|||
context_server.workspace = true
|
||||
db.workspace = true
|
||||
editor.workspace = true
|
||||
eval_utils = { workspace = true, optional = true }
|
||||
extension.workspace = true
|
||||
extension_host.workspace = true
|
||||
feature_flags.workspace = true
|
||||
|
|
@ -71,6 +72,7 @@ postage.workspace = true
|
|||
project.workspace = true
|
||||
prompt_store.workspace = true
|
||||
proto.workspace = true
|
||||
rand.workspace = true
|
||||
release_channel.workspace = true
|
||||
rope.workspace = true
|
||||
rules_library.workspace = true
|
||||
|
|
@ -119,7 +121,6 @@ language_model = { workspace = true, "features" = ["test-support"] }
|
|||
pretty_assertions.workspace = true
|
||||
project = { workspace = true, features = ["test-support"] }
|
||||
semver.workspace = true
|
||||
rand.workspace = true
|
||||
reqwest_client.workspace = true
|
||||
tree-sitter-md.workspace = true
|
||||
unindent.workspace = true
|
||||
|
|
|
|||
|
|
@ -7,8 +7,6 @@ mod buffer_codegen;
|
|||
mod completion_provider;
|
||||
mod context;
|
||||
mod context_server_configuration;
|
||||
#[cfg(test)]
|
||||
mod evals;
|
||||
mod inline_assistant;
|
||||
mod inline_prompt_editor;
|
||||
mod language_model_selector;
|
||||
|
|
|
|||
|
|
@ -41,7 +41,6 @@ use std::{
|
|||
time::Instant,
|
||||
};
|
||||
use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
|
||||
use ui::SharedString;
|
||||
|
||||
/// Use this tool to provide a message to the user when you're unable to complete a task.
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
|
|
@ -56,16 +55,16 @@ pub struct FailureMessageInput {
|
|||
/// Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct RewriteSectionInput {
|
||||
/// The text to replace the section with.
|
||||
#[serde(default)]
|
||||
pub replacement_text: String,
|
||||
|
||||
/// A brief description of the edit you have made.
|
||||
///
|
||||
/// The description may use markdown formatting if you wish.
|
||||
/// This is optional - if the edit is simple or obvious, you should leave it empty.
|
||||
#[serde(default)]
|
||||
pub description: String,
|
||||
|
||||
/// The text to replace the section with.
|
||||
#[serde(default)]
|
||||
pub replacement_text: String,
|
||||
}
|
||||
|
||||
pub struct BufferCodegen {
|
||||
|
|
@ -287,8 +286,9 @@ pub struct CodegenAlternative {
|
|||
completion: Option<String>,
|
||||
selected_text: Option<String>,
|
||||
pub message_id: Option<String>,
|
||||
pub model_explanation: Option<SharedString>,
|
||||
session_id: Uuid,
|
||||
pub description: Option<String>,
|
||||
pub failure: Option<String>,
|
||||
}
|
||||
|
||||
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
|
||||
|
|
@ -346,8 +346,9 @@ impl CodegenAlternative {
|
|||
elapsed_time: None,
|
||||
completion: None,
|
||||
selected_text: None,
|
||||
model_explanation: None,
|
||||
session_id,
|
||||
description: None,
|
||||
failure: None,
|
||||
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
|
||||
}
|
||||
}
|
||||
|
|
@ -920,6 +921,16 @@ impl CodegenAlternative {
|
|||
self.completion.clone()
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn current_description(&self) -> Option<String> {
|
||||
self.description.clone()
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn current_failure(&self) -> Option<String> {
|
||||
self.failure.clone()
|
||||
}
|
||||
|
||||
pub fn selected_text(&self) -> Option<&str> {
|
||||
self.selected_text.as_deref()
|
||||
}
|
||||
|
|
@ -1133,32 +1144,69 @@ impl CodegenAlternative {
|
|||
}
|
||||
};
|
||||
|
||||
enum ToolUseOutput {
|
||||
Rewrite {
|
||||
text: String,
|
||||
description: Option<String>,
|
||||
},
|
||||
Failure(String),
|
||||
}
|
||||
|
||||
enum ModelUpdate {
|
||||
Description(String),
|
||||
Failure(String),
|
||||
}
|
||||
|
||||
let chars_read_so_far = Arc::new(Mutex::new(0usize));
|
||||
let tool_to_text_and_message =
|
||||
move |tool_use: LanguageModelToolUse| -> (Option<String>, Option<String>) {
|
||||
let mut chars_read_so_far = chars_read_so_far.lock();
|
||||
match tool_use.name.as_ref() {
|
||||
"rewrite_section" => {
|
||||
let Ok(mut input) =
|
||||
serde_json::from_value::<RewriteSectionInput>(tool_use.input)
|
||||
else {
|
||||
return (None, None);
|
||||
};
|
||||
let value = input.replacement_text[*chars_read_so_far..].to_string();
|
||||
*chars_read_so_far = input.replacement_text.len();
|
||||
(Some(value), Some(std::mem::take(&mut input.description)))
|
||||
}
|
||||
"failure_message" => {
|
||||
let Ok(mut input) =
|
||||
serde_json::from_value::<FailureMessageInput>(tool_use.input)
|
||||
else {
|
||||
return (None, None);
|
||||
};
|
||||
(None, Some(std::mem::take(&mut input.message)))
|
||||
}
|
||||
_ => (None, None),
|
||||
let process_tool_use = move |tool_use: LanguageModelToolUse| -> Option<ToolUseOutput> {
|
||||
let mut chars_read_so_far = chars_read_so_far.lock();
|
||||
let is_complete = tool_use.is_input_complete;
|
||||
match tool_use.name.as_ref() {
|
||||
"rewrite_section" => {
|
||||
let Ok(mut input) =
|
||||
serde_json::from_value::<RewriteSectionInput>(tool_use.input)
|
||||
else {
|
||||
return None;
|
||||
};
|
||||
let text = input.replacement_text[*chars_read_so_far..].to_string();
|
||||
*chars_read_so_far = input.replacement_text.len();
|
||||
let description = is_complete
|
||||
.then(|| {
|
||||
let desc = std::mem::take(&mut input.description);
|
||||
if desc.is_empty() { None } else { Some(desc) }
|
||||
})
|
||||
.flatten();
|
||||
Some(ToolUseOutput::Rewrite { text, description })
|
||||
}
|
||||
};
|
||||
"failure_message" => {
|
||||
if !is_complete {
|
||||
return None;
|
||||
}
|
||||
let Ok(mut input) =
|
||||
serde_json::from_value::<FailureMessageInput>(tool_use.input)
|
||||
else {
|
||||
return None;
|
||||
};
|
||||
Some(ToolUseOutput::Failure(std::mem::take(&mut input.message)))
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
};
|
||||
|
||||
let (message_tx, mut message_rx) = futures::channel::mpsc::unbounded::<ModelUpdate>();
|
||||
|
||||
cx.spawn({
|
||||
let codegen = codegen.clone();
|
||||
async move |cx| {
|
||||
while let Some(update) = message_rx.next().await {
|
||||
let _ = codegen.update(cx, |this, _cx| match update {
|
||||
ModelUpdate::Description(d) => this.description = Some(d),
|
||||
ModelUpdate::Failure(f) => this.failure = Some(f),
|
||||
});
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
let mut message_id = None;
|
||||
let mut first_text = None;
|
||||
|
|
@ -1171,24 +1219,23 @@ impl CodegenAlternative {
|
|||
Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
|
||||
message_id = Some(id);
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
|
||||
if matches!(
|
||||
tool_use.name.as_ref(),
|
||||
"rewrite_section" | "failure_message"
|
||||
) =>
|
||||
{
|
||||
let is_complete = tool_use.is_input_complete;
|
||||
let (text, message) = tool_to_text_and_message(tool_use);
|
||||
// Only update the model explanation if the tool use is complete.
|
||||
// Otherwise the UI element bounces around as it's updated.
|
||||
if is_complete {
|
||||
let _ = codegen.update(cx, |this, _cx| {
|
||||
this.model_explanation = message.map(Into::into);
|
||||
});
|
||||
}
|
||||
first_text = text;
|
||||
if first_text.is_some() {
|
||||
break;
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
|
||||
if let Some(output) = process_tool_use(tool_use) {
|
||||
let (text, update) = match output {
|
||||
ToolUseOutput::Rewrite { text, description } => {
|
||||
(Some(text), description.map(ModelUpdate::Description))
|
||||
}
|
||||
ToolUseOutput::Failure(message) => {
|
||||
(None, Some(ModelUpdate::Failure(message)))
|
||||
}
|
||||
};
|
||||
if let Some(update) = update {
|
||||
let _ = message_tx.unbounded_send(update);
|
||||
}
|
||||
first_text = text;
|
||||
if first_text.is_some() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
|
||||
|
|
@ -1215,41 +1262,30 @@ impl CodegenAlternative {
|
|||
return;
|
||||
};
|
||||
|
||||
let (message_tx, mut message_rx) = futures::channel::mpsc::unbounded();
|
||||
|
||||
cx.spawn({
|
||||
let codegen = codegen.clone();
|
||||
async move |cx| {
|
||||
while let Some(message) = message_rx.next().await {
|
||||
let _ = codegen.update(cx, |this, _cx| {
|
||||
this.model_explanation = message;
|
||||
});
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
let move_last_token_usage = last_token_usage.clone();
|
||||
|
||||
let text_stream = Box::pin(futures::stream::once(async { Ok(first_text) }).chain(
|
||||
completion_events.filter_map(move |e| {
|
||||
let tool_to_text_and_message = tool_to_text_and_message.clone();
|
||||
let process_tool_use = process_tool_use.clone();
|
||||
let last_token_usage = move_last_token_usage.clone();
|
||||
let total_text = total_text.clone();
|
||||
let mut message_tx = message_tx.clone();
|
||||
async move {
|
||||
match e {
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
|
||||
if matches!(
|
||||
tool_use.name.as_ref(),
|
||||
"rewrite_section" | "failure_message"
|
||||
) =>
|
||||
{
|
||||
let is_complete = tool_use.is_input_complete;
|
||||
let (text, message) = tool_to_text_and_message(tool_use);
|
||||
if is_complete {
|
||||
// Again only send the message when complete to not get a bouncing UI element.
|
||||
let _ = message_tx.send(message.map(Into::into)).await;
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
|
||||
let Some(output) = process_tool_use(tool_use) else {
|
||||
return None;
|
||||
};
|
||||
let (text, update) = match output {
|
||||
ToolUseOutput::Rewrite { text, description } => {
|
||||
(Some(text), description.map(ModelUpdate::Description))
|
||||
}
|
||||
ToolUseOutput::Failure(message) => {
|
||||
(None, Some(ModelUpdate::Failure(message)))
|
||||
}
|
||||
};
|
||||
if let Some(update) = update {
|
||||
let _ = message_tx.send(update).await;
|
||||
}
|
||||
text.map(Ok)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,89 +0,0 @@
|
|||
use std::str::FromStr;
|
||||
|
||||
use crate::inline_assistant::test::run_inline_assistant_test;
|
||||
|
||||
use eval_utils::{EvalOutput, NoProcessor};
|
||||
use gpui::TestAppContext;
|
||||
use language_model::{LanguageModelRegistry, SelectedModel};
|
||||
use rand::{SeedableRng as _, rngs::StdRng};
|
||||
|
||||
#[test]
|
||||
#[cfg_attr(not(feature = "unit-eval"), ignore)]
|
||||
fn eval_single_cursor_edit() {
|
||||
eval_utils::eval(20, 1.0, NoProcessor, move || {
|
||||
run_eval(
|
||||
&EvalInput {
|
||||
prompt: "Rename this variable to buffer_text".to_string(),
|
||||
buffer: indoc::indoc! {"
|
||||
struct EvalExampleStruct {
|
||||
text: Strˇing,
|
||||
prompt: String,
|
||||
}
|
||||
"}
|
||||
.to_string(),
|
||||
},
|
||||
&|_, output| {
|
||||
let expected = indoc::indoc! {"
|
||||
struct EvalExampleStruct {
|
||||
buffer_text: String,
|
||||
prompt: String,
|
||||
}
|
||||
"};
|
||||
if output == expected {
|
||||
EvalOutput {
|
||||
outcome: eval_utils::OutcomeKind::Passed,
|
||||
data: "Passed!".to_string(),
|
||||
metadata: (),
|
||||
}
|
||||
} else {
|
||||
EvalOutput {
|
||||
outcome: eval_utils::OutcomeKind::Failed,
|
||||
data: format!("Failed to rename variable, output: {}", output),
|
||||
metadata: (),
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
struct EvalInput {
|
||||
buffer: String,
|
||||
prompt: String,
|
||||
}
|
||||
|
||||
fn run_eval(
|
||||
input: &EvalInput,
|
||||
judge: &dyn Fn(&EvalInput, &str) -> eval_utils::EvalOutput<()>,
|
||||
) -> eval_utils::EvalOutput<()> {
|
||||
let dispatcher = gpui::TestDispatcher::new(StdRng::from_os_rng());
|
||||
let mut cx = TestAppContext::build(dispatcher, None);
|
||||
cx.skip_drawing();
|
||||
|
||||
let buffer_text = run_inline_assistant_test(
|
||||
input.buffer.clone(),
|
||||
input.prompt.clone(),
|
||||
|cx| {
|
||||
// Reconfigure to use a real model instead of the fake one
|
||||
let model_name = std::env::var("ZED_AGENT_MODEL")
|
||||
.unwrap_or("anthropic/claude-sonnet-4-latest".into());
|
||||
|
||||
let selected_model = SelectedModel::from_str(&model_name)
|
||||
.expect("Invalid model format. Use 'provider/model-id'");
|
||||
|
||||
log::info!("Selected model: {selected_model:?}");
|
||||
|
||||
cx.update(|_, cx| {
|
||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.select_inline_assistant_model(Some(&selected_model), cx);
|
||||
});
|
||||
});
|
||||
},
|
||||
|_cx| {
|
||||
log::info!("Waiting for actual response from the LLM...");
|
||||
},
|
||||
&mut cx,
|
||||
);
|
||||
|
||||
judge(input, &buffer_text)
|
||||
}
|
||||
|
|
@ -117,14 +117,6 @@ impl InlineAssistant {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn set_completion_receiver(
|
||||
&mut self,
|
||||
sender: mpsc::UnboundedSender<anyhow::Result<InlineAssistId>>,
|
||||
) {
|
||||
self._inline_assistant_completions = Some(sender);
|
||||
}
|
||||
|
||||
pub fn register_workspace(
|
||||
&mut self,
|
||||
workspace: &Entity<Workspace>,
|
||||
|
|
@ -1593,6 +1585,27 @@ impl InlineAssistant {
|
|||
.map(InlineAssistTarget::Terminal)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn set_completion_receiver(
|
||||
&mut self,
|
||||
sender: mpsc::UnboundedSender<anyhow::Result<InlineAssistId>>,
|
||||
) {
|
||||
self._inline_assistant_completions = Some(sender);
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn get_codegen(
|
||||
&mut self,
|
||||
assist_id: InlineAssistId,
|
||||
cx: &mut App,
|
||||
) -> Option<Entity<CodegenAlternative>> {
|
||||
self.assists.get(&assist_id).map(|inline_assist| {
|
||||
inline_assist
|
||||
.codegen
|
||||
.update(cx, |codegen, _cx| codegen.active_alternative().clone())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct EditorInlineAssists {
|
||||
|
|
@ -2014,8 +2027,10 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
#[cfg(any(test, feature = "unit-eval"))]
|
||||
#[cfg_attr(not(test), allow(dead_code))]
|
||||
pub mod test {
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use agent::HistoryStore;
|
||||
|
|
@ -2026,7 +2041,6 @@ pub mod test {
|
|||
use futures::channel::mpsc;
|
||||
use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
|
||||
use language::Buffer;
|
||||
use language_model::LanguageModelRegistry;
|
||||
use project::Project;
|
||||
use prompt_store::PromptBuilder;
|
||||
use smol::stream::StreamExt as _;
|
||||
|
|
@ -2035,13 +2049,43 @@ pub mod test {
|
|||
|
||||
use crate::InlineAssistant;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum InlineAssistantOutput {
|
||||
Success {
|
||||
completion: Option<String>,
|
||||
description: Option<String>,
|
||||
full_buffer_text: String,
|
||||
},
|
||||
Failure {
|
||||
failure: String,
|
||||
},
|
||||
// These fields are used for logging
|
||||
#[allow(unused)]
|
||||
Malformed {
|
||||
completion: Option<String>,
|
||||
description: Option<String>,
|
||||
failure: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
impl InlineAssistantOutput {
|
||||
pub fn buffer_text(&self) -> &str {
|
||||
match self {
|
||||
InlineAssistantOutput::Success {
|
||||
full_buffer_text, ..
|
||||
} => full_buffer_text,
|
||||
_ => "",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run_inline_assistant_test<SetupF, TestF>(
|
||||
base_buffer: String,
|
||||
prompt: String,
|
||||
setup: SetupF,
|
||||
test: TestF,
|
||||
cx: &mut TestAppContext,
|
||||
) -> String
|
||||
) -> InlineAssistantOutput
|
||||
where
|
||||
SetupF: FnOnce(&mut gpui::VisualTestContext),
|
||||
TestF: FnOnce(&mut gpui::VisualTestContext),
|
||||
|
|
@ -2133,39 +2177,198 @@ pub mod test {
|
|||
|
||||
test(cx);
|
||||
|
||||
cx.executor()
|
||||
.block_test(async { completion_rx.next().await });
|
||||
let assist_id = cx
|
||||
.executor()
|
||||
.block_test(async { completion_rx.next().await })
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
buffer.read_with(cx, |buffer, _| buffer.text())
|
||||
}
|
||||
let (completion, description, failure) = cx.update(|_, cx| {
|
||||
InlineAssistant::update_global(cx, |inline_assistant, cx| {
|
||||
let codegen = inline_assistant.get_codegen(assist_id, cx).unwrap();
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn test_inline_assistant(
|
||||
base_buffer: &'static str,
|
||||
llm_output: &'static str,
|
||||
cx: &mut TestAppContext,
|
||||
) -> String {
|
||||
run_inline_assistant_test(
|
||||
base_buffer.to_string(),
|
||||
"Prompt doesn't matter because we're using a fake model".to_string(),
|
||||
|cx| {
|
||||
cx.update(|_, cx| LanguageModelRegistry::test(cx));
|
||||
},
|
||||
|cx| {
|
||||
let fake_model = cx.update(|_, cx| {
|
||||
LanguageModelRegistry::global(cx)
|
||||
.update(cx, |registry, _| registry.fake_model())
|
||||
});
|
||||
let fake = fake_model.as_fake();
|
||||
let completion = codegen.read(cx).current_completion();
|
||||
let description = codegen.read(cx).current_description();
|
||||
let failure = codegen.read(cx).current_failure();
|
||||
|
||||
// let fake = fake_model;
|
||||
fake.send_last_completion_stream_text_chunk(llm_output.to_string());
|
||||
fake.end_last_completion_stream();
|
||||
(completion, description, failure)
|
||||
})
|
||||
});
|
||||
|
||||
// Run again to process the model's response
|
||||
cx.run_until_parked();
|
||||
},
|
||||
cx,
|
||||
)
|
||||
if failure.is_some() && (completion.is_some() || description.is_some()) {
|
||||
InlineAssistantOutput::Malformed {
|
||||
completion,
|
||||
description,
|
||||
failure,
|
||||
}
|
||||
} else if let Some(failure) = failure {
|
||||
InlineAssistantOutput::Failure { failure }
|
||||
} else {
|
||||
InlineAssistantOutput::Success {
|
||||
completion,
|
||||
description,
|
||||
full_buffer_text: buffer.read_with(cx, |buffer, _| buffer.text()),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "unit-eval"))]
|
||||
#[cfg_attr(not(test), allow(dead_code))]
|
||||
pub mod evals {
|
||||
use std::str::FromStr;
|
||||
|
||||
use eval_utils::{EvalOutput, NoProcessor};
|
||||
use gpui::TestAppContext;
|
||||
use language_model::{LanguageModelRegistry, SelectedModel};
|
||||
use rand::{SeedableRng as _, rngs::StdRng};
|
||||
|
||||
use crate::inline_assistant::test::{InlineAssistantOutput, run_inline_assistant_test};
|
||||
|
||||
#[test]
|
||||
#[cfg_attr(not(feature = "unit-eval"), ignore)]
|
||||
fn eval_single_cursor_edit() {
|
||||
run_eval(
|
||||
20,
|
||||
1.0,
|
||||
"Rename this variable to buffer_text".to_string(),
|
||||
indoc::indoc! {"
|
||||
struct EvalExampleStruct {
|
||||
text: Strˇing,
|
||||
prompt: String,
|
||||
}
|
||||
"}
|
||||
.to_string(),
|
||||
exact_buffer_match(indoc::indoc! {"
|
||||
struct EvalExampleStruct {
|
||||
buffer_text: String,
|
||||
prompt: String,
|
||||
}
|
||||
"}),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg_attr(not(feature = "unit-eval"), ignore)]
|
||||
fn eval_cant_do() {
|
||||
run_eval(
|
||||
20,
|
||||
1.0,
|
||||
"Rename the struct to EvalExampleStructNope",
|
||||
indoc::indoc! {"
|
||||
struct EvalExampleStruct {
|
||||
text: Strˇing,
|
||||
prompt: String,
|
||||
}
|
||||
"},
|
||||
uncertain_output,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg_attr(not(feature = "unit-eval"), ignore)]
|
||||
fn eval_unclear() {
|
||||
run_eval(
|
||||
20,
|
||||
1.0,
|
||||
"Make exactly the change I want you to make",
|
||||
indoc::indoc! {"
|
||||
struct EvalExampleStruct {
|
||||
text: Strˇing,
|
||||
prompt: String,
|
||||
}
|
||||
"},
|
||||
uncertain_output,
|
||||
);
|
||||
}
|
||||
|
||||
fn run_eval(
|
||||
iterations: usize,
|
||||
expected_pass_ratio: f32,
|
||||
prompt: impl Into<String>,
|
||||
buffer: impl Into<String>,
|
||||
judge: impl Fn(InlineAssistantOutput) -> eval_utils::EvalOutput<()> + Send + Sync + 'static,
|
||||
) {
|
||||
let buffer = buffer.into();
|
||||
let prompt = prompt.into();
|
||||
|
||||
eval_utils::eval(iterations, expected_pass_ratio, NoProcessor, move || {
|
||||
let dispatcher = gpui::TestDispatcher::new(StdRng::from_os_rng());
|
||||
let mut cx = TestAppContext::build(dispatcher, None);
|
||||
cx.skip_drawing();
|
||||
|
||||
let output = run_inline_assistant_test(
|
||||
buffer.clone(),
|
||||
prompt.clone(),
|
||||
|cx| {
|
||||
// Reconfigure to use a real model instead of the fake one
|
||||
let model_name = std::env::var("ZED_AGENT_MODEL")
|
||||
.unwrap_or("anthropic/claude-sonnet-4-latest".into());
|
||||
|
||||
let selected_model = SelectedModel::from_str(&model_name)
|
||||
.expect("Invalid model format. Use 'provider/model-id'");
|
||||
|
||||
log::info!("Selected model: {selected_model:?}");
|
||||
|
||||
cx.update(|_, cx| {
|
||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.select_inline_assistant_model(Some(&selected_model), cx);
|
||||
});
|
||||
});
|
||||
},
|
||||
|_cx| {
|
||||
log::info!("Waiting for actual response from the LLM...");
|
||||
},
|
||||
&mut cx,
|
||||
);
|
||||
|
||||
cx.quit();
|
||||
|
||||
judge(output)
|
||||
});
|
||||
}
|
||||
|
||||
fn uncertain_output(output: InlineAssistantOutput) -> EvalOutput<()> {
|
||||
match &output {
|
||||
o @ InlineAssistantOutput::Success {
|
||||
completion,
|
||||
description,
|
||||
..
|
||||
} => {
|
||||
if description.is_some() && completion.is_none() {
|
||||
EvalOutput::passed(format!(
|
||||
"Assistant produced no completion, but a description:\n{}",
|
||||
description.as_ref().unwrap()
|
||||
))
|
||||
} else {
|
||||
EvalOutput::failed(format!("Assistant produced a completion:\n{:?}", o))
|
||||
}
|
||||
}
|
||||
InlineAssistantOutput::Failure {
|
||||
failure: error_message,
|
||||
} => EvalOutput::passed(format!(
|
||||
"Assistant produced a failure message: {}",
|
||||
error_message
|
||||
)),
|
||||
o @ InlineAssistantOutput::Malformed { .. } => {
|
||||
EvalOutput::failed(format!("Assistant produced a malformed response:\n{:?}", o))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn exact_buffer_match(
|
||||
correct_output: impl Into<String>,
|
||||
) -> impl Fn(InlineAssistantOutput) -> EvalOutput<()> {
|
||||
let correct_output = correct_output.into();
|
||||
move |output| {
|
||||
if output.buffer_text() == correct_output {
|
||||
EvalOutput::passed("Assistant output matches")
|
||||
} else {
|
||||
EvalOutput::failed(format!(
|
||||
"Assistant output does not match expected output: {:?}",
|
||||
output
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -101,11 +101,11 @@ impl<T: 'static> Render for PromptEditor<T> {
|
|||
let left_gutter_width = gutter.full_width() + (gutter.margin / 2.0);
|
||||
let right_padding = editor_margins.right + RIGHT_PADDING;
|
||||
|
||||
let explanation = codegen
|
||||
.active_alternative()
|
||||
.read(cx)
|
||||
.model_explanation
|
||||
.clone();
|
||||
let active_alternative = codegen.active_alternative().read(cx);
|
||||
let explanation = active_alternative
|
||||
.description
|
||||
.clone()
|
||||
.or_else(|| active_alternative.failure.clone());
|
||||
|
||||
(left_gutter_width, right_padding, explanation)
|
||||
}
|
||||
|
|
@ -139,7 +139,7 @@ impl<T: 'static> Render for PromptEditor<T> {
|
|||
|
||||
if let Some(explanation) = &explanation {
|
||||
markdown.update(cx, |markdown, cx| {
|
||||
markdown.reset(explanation.clone(), cx);
|
||||
markdown.reset(SharedString::from(explanation), cx);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -40,6 +40,24 @@ pub struct EvalOutput<M> {
|
|||
pub metadata: M,
|
||||
}
|
||||
|
||||
impl<M: Default> EvalOutput<M> {
|
||||
pub fn passed(message: impl Into<String>) -> Self {
|
||||
EvalOutput {
|
||||
outcome: OutcomeKind::Passed,
|
||||
data: message.into(),
|
||||
metadata: M::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn failed(message: impl Into<String>) -> Self {
|
||||
EvalOutput {
|
||||
outcome: OutcomeKind::Failed,
|
||||
data: message.into(),
|
||||
metadata: M::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct NoProcessor;
|
||||
impl EvalOutputProcessor for NoProcessor {
|
||||
type Metadata = ();
|
||||
|
|
|
|||
|
|
@ -18,6 +18,6 @@ impl FeatureFlag for InlineAssistantUseToolFeatureFlag {
|
|||
const NAME: &'static str = "inline-assistant-use-tool";
|
||||
|
||||
fn enabled_for_staff() -> bool {
|
||||
false
|
||||
true
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ use settings::{Settings, SettingsStore};
|
|||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::str::FromStr;
|
||||
use std::sync::{Arc, LazyLock, OnceLock};
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use strum::IntoEnumIterator;
|
||||
use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
|
||||
use ui_input::InputField;
|
||||
|
|
@ -31,7 +31,6 @@ static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
|
|||
|
||||
const CODESTRAL_API_KEY_ENV_VAR_NAME: &str = "CODESTRAL_API_KEY";
|
||||
static CODESTRAL_API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(CODESTRAL_API_KEY_ENV_VAR_NAME);
|
||||
static CODESTRAL_API_KEY: OnceLock<Entity<ApiKeyState>> = OnceLock::new();
|
||||
|
||||
#[derive(Default, Clone, Debug, PartialEq)]
|
||||
pub struct MistralSettings {
|
||||
|
|
@ -49,14 +48,18 @@ pub struct State {
|
|||
codestral_api_key_state: Entity<ApiKeyState>,
|
||||
}
|
||||
|
||||
struct CodestralApiKey(Entity<ApiKeyState>);
|
||||
impl Global for CodestralApiKey {}
|
||||
|
||||
pub fn codestral_api_key(cx: &mut App) -> Entity<ApiKeyState> {
|
||||
return CODESTRAL_API_KEY
|
||||
.get_or_init(|| {
|
||||
cx.new(|_| {
|
||||
ApiKeyState::new(CODESTRAL_API_URL.into(), CODESTRAL_API_KEY_ENV_VAR.clone())
|
||||
})
|
||||
})
|
||||
.clone();
|
||||
if cx.has_global::<CodestralApiKey>() {
|
||||
cx.global::<CodestralApiKey>().0.clone()
|
||||
} else {
|
||||
let api_key_state = cx
|
||||
.new(|_| ApiKeyState::new(CODESTRAL_API_URL.into(), CODESTRAL_API_KEY_ENV_VAR.clone()));
|
||||
cx.set_global(CodestralApiKey(api_key_state.clone()));
|
||||
api_key_state
|
||||
}
|
||||
}
|
||||
|
||||
impl State {
|
||||
|
|
|
|||
Loading…
Reference in a new issue