mirror of
https://github.com/zed-industries/zed.git
synced 2026-05-31 19:05:00 +07:00
eval_cli: Wait for model discovery (#57038)
Given that the model list is now dynamic, eval_cli needs to wait for model discovery before selecting a model.

Self-Review Checklist:
- [x] I've reviewed my own diff for quality, security, and reliability
- [x] Unsafe blocks (if any) have justifying comments
- [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)
- [x] Tests cover the new/changed behavior
- [x] Performance impact has been considered and is acceptable

Release Notes:
- N/A
This commit is contained in:
parent
bfe914beec
commit
5aeb8a7e0f
2 changed files with 203 additions and 26 deletions
|
|
@ -47,7 +47,10 @@ use feature_flags::FeatureFlagAppExt as _;
|
|||
|
||||
use futures::{FutureExt, select_biased};
|
||||
use gpui::{AppContext as _, AsyncApp, Entity, UpdateGlobal};
|
||||
use language_model::{LanguageModelRegistry, SelectedModel};
|
||||
use language_model::{
|
||||
ANTHROPIC_PROVIDER_ID, LanguageModel, LanguageModelId, LanguageModelProviderId,
|
||||
LanguageModelRegistry, SelectedModel,
|
||||
};
|
||||
use project::Project;
|
||||
use settings::SettingsStore;
|
||||
use util::path_list::PathList;
|
||||
|
|
@ -128,6 +131,8 @@ const EXIT_OK: i32 = 0;
|
|||
const EXIT_ERROR: i32 = 1;
|
||||
const EXIT_TIMEOUT: i32 = 2;
|
||||
const EXIT_INTERRUPTED: i32 = 3;
|
||||
/// Upper bound on how long `wait_for_model` keeps polling for the requested
/// model before it gives up and reports the model as not found.
const MODEL_DISCOVERY_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
/// Pause between two consecutive registry checks performed by `wait_for_model`.
const MODEL_DISCOVERY_POLL_INTERVAL: Duration = Duration::from_millis(100);
|
||||
|
||||
static TERMINATED: AtomicBool = AtomicBool::new(false);
|
||||
|
||||
|
|
@ -268,6 +273,182 @@ fn read_instruction(args: &Args) -> Result<String> {
|
|||
Ok(text)
|
||||
}
|
||||
|
||||
async fn wait_for_model(selected: &SelectedModel, cx: &mut AsyncApp) -> Result<()> {
|
||||
let started_at = Instant::now();
|
||||
|
||||
loop {
|
||||
let found = cx.update(|cx| find_available_model(selected, cx).is_some());
|
||||
if found {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
cx.update(|cx| ensure_provider_authenticated(selected, cx))?;
|
||||
|
||||
let selected_provider_has_models = cx.update(|cx| {
|
||||
LanguageModelRegistry::global(cx)
|
||||
.read(cx)
|
||||
.available_models(cx)
|
||||
.any(|model| model.provider_id() == selected.provider)
|
||||
});
|
||||
let should_wait_for_discovery =
|
||||
selected.provider == ANTHROPIC_PROVIDER_ID || !selected_provider_has_models;
|
||||
|
||||
if !should_wait_for_discovery || started_at.elapsed() >= MODEL_DISCOVERY_TIMEOUT {
|
||||
return Err(cx.update(|cx| model_not_found_error(&selected_model_name(selected), cx)));
|
||||
}
|
||||
|
||||
cx.background_executor()
|
||||
.timer(MODEL_DISCOVERY_POLL_INTERVAL)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
fn ensure_provider_authenticated(selected: &SelectedModel, cx: &gpui::App) -> Result<()> {
|
||||
let registry = LanguageModelRegistry::global(cx);
|
||||
let provider = registry
|
||||
.read(cx)
|
||||
.provider(&selected.provider)
|
||||
.ok_or_else(|| anyhow::anyhow!("Provider {} not found", selected.provider.0))?;
|
||||
|
||||
anyhow::ensure!(
|
||||
provider.is_authenticated(cx),
|
||||
"Provider {} is not authenticated",
|
||||
selected.provider.0
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn find_available_model(
|
||||
selected: &SelectedModel,
|
||||
cx: &gpui::App,
|
||||
) -> Option<Arc<dyn LanguageModel>> {
|
||||
let registry = LanguageModelRegistry::global(cx);
|
||||
let models = registry.read(cx).available_models(cx).collect::<Vec<_>>();
|
||||
|
||||
if let Some(model) = models
|
||||
.iter()
|
||||
.find(|model| model.provider_id() == selected.provider && model.id() == selected.model)
|
||||
{
|
||||
return Some(model.clone());
|
||||
}
|
||||
|
||||
models
|
||||
.into_iter()
|
||||
.filter(|model| {
|
||||
model.provider_id() == selected.provider
|
||||
&& model_id_matches_selected(&model.provider_id(), &model.id(), &selected.model)
|
||||
})
|
||||
.max_by(|left, right| left.id().0.to_string().cmp(&right.id().0.to_string()))
|
||||
}
|
||||
|
||||
fn model_id_matches_selected(
|
||||
provider_id: &LanguageModelProviderId,
|
||||
available: &LanguageModelId,
|
||||
selected: &LanguageModelId,
|
||||
) -> bool {
|
||||
if available == selected {
|
||||
return true;
|
||||
}
|
||||
|
||||
if provider_id != &ANTHROPIC_PROVIDER_ID {
|
||||
return false;
|
||||
}
|
||||
|
||||
anthropic_model_ids_match(available.0.as_ref(), selected.0.as_ref())
|
||||
}
|
||||
|
||||
fn anthropic_model_ids_match(available: &str, selected: &str) -> bool {
|
||||
let available = anthropic_model_alias_base(available);
|
||||
let selected = anthropic_model_alias_base(selected);
|
||||
|
||||
available == selected || anthropic_dated_model_id_matches_base(available, selected)
|
||||
}
|
||||
|
||||
/// Strips the alias suffixes `-latest`, `-thinking` and `-1m-context` from the
/// end of `model_id`.
///
/// Each suffix is tried exactly once, in that order, so an id such as
/// `...-1m-context-thinking-latest` is reduced completely while suffixes that
/// appear in a different order are only partially removed.
fn anthropic_model_alias_base(mut model_id: &str) -> &str {
    // Order matters: each removal can expose the next suffix in the list.
    for suffix in ["-latest", "-thinking", "-1m-context"] {
        model_id = model_id.strip_suffix(suffix).unwrap_or(model_id);
    }
    model_id
}
|
||||
|
||||
/// Returns `true` when `available` is exactly `selected`, followed by a `-`,
/// followed by eight ASCII digits (a `YYYYMMDD`-shaped stamp).
///
/// Only the shape of the stamp is checked; it is not validated as a calendar
/// date.
fn anthropic_dated_model_id_matches_base(available: &str, selected: &str) -> bool {
    available
        .strip_prefix(selected)
        .and_then(|rest| rest.strip_prefix('-'))
        .map_or(false, |stamp| {
            stamp.len() == 8 && stamp.bytes().all(|byte| byte.is_ascii_digit())
        })
}
|
||||
|
||||
fn selected_model_name(selected: &SelectedModel) -> String {
|
||||
format!("{}/{}", selected.provider.0, selected.model.0)
|
||||
}
|
||||
|
||||
fn model_not_found_error(model_name: &str, cx: &gpui::App) -> anyhow::Error {
|
||||
let available = LanguageModelRegistry::global(cx)
|
||||
.read(cx)
|
||||
.available_models(cx)
|
||||
.map(|model| format!("{}/{}", model.provider_id().0, model.id().0))
|
||||
.collect::<Vec<_>>();
|
||||
let available = if available.is_empty() {
|
||||
"(none)".to_string()
|
||||
} else {
|
||||
available.join(", ")
|
||||
};
|
||||
|
||||
anyhow::anyhow!("Model {model_name} not found. Available: {available}")
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps the raw id strings in `LanguageModelId` and runs the matcher
    /// under test, so each case reads as (provider, listed id, requested id).
    fn ids_match(
        provider: &LanguageModelProviderId,
        available: &'static str,
        selected: &'static str,
    ) -> bool {
        model_id_matches_selected(
            provider,
            &LanguageModelId(available.into()),
            &LanguageModelId(selected.into()),
        )
    }

    // A `-latest` alias resolves to the undated model that is listed.
    #[test]
    fn anthropic_latest_alias_matches_listed_base_model() {
        let matched = ids_match(
            &ANTHROPIC_PROVIDER_ID,
            "claude-sonnet-4-6",
            "claude-sonnet-4-6-latest",
        );
        assert!(matched);
    }

    // All three alias suffixes stacked together still reduce to the base id.
    #[test]
    fn anthropic_thinking_alias_matches_listed_base_model() {
        let matched = ids_match(
            &ANTHROPIC_PROVIDER_ID,
            "claude-sonnet-4-6",
            "claude-sonnet-4-6-1m-context-thinking-latest",
        );
        assert!(matched);
    }

    // A `-latest` alias also resolves to a listed id carrying a date stamp.
    #[test]
    fn anthropic_latest_alias_matches_listed_dated_model() {
        let matched = ids_match(
            &ANTHROPIC_PROVIDER_ID,
            "claude-sonnet-4-6-20260518",
            "claude-sonnet-4-6-latest",
        );
        assert!(matched);
    }

    // Alias handling is Anthropic-only; other providers need identical ids.
    #[test]
    fn non_anthropic_models_require_exact_ids() {
        let other_provider = LanguageModelProviderId("other".into());
        let matched = ids_match(
            &other_provider,
            "claude-sonnet-4-6",
            "claude-sonnet-4-6-latest",
        );
        assert!(!matched);
    }
}
|
||||
|
||||
async fn run_agent(
|
||||
app_state: &Arc<AgentCliAppState>,
|
||||
workdir: &std::path::Path,
|
||||
|
|
@ -279,37 +460,33 @@ async fn run_agent(
|
|||
output_dir: Option<&std::path::Path>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> (Result<AgentOutcome>, Option<language_model::TokenUsage>) {
|
||||
let selected = match SelectedModel::from_str(model_name).map_err(|e| anyhow::anyhow!("{e}")) {
|
||||
Ok(selected) => selected,
|
||||
Err(e) => return (Err(e), None),
|
||||
};
|
||||
|
||||
if let Err(e) = wait_for_model(&selected, cx).await {
|
||||
return (Err(e), None);
|
||||
}
|
||||
|
||||
let setup_result: Result<()> = cx.update(|cx| {
|
||||
let selected = SelectedModel::from_str(model_name).map_err(|e| anyhow::anyhow!("{e}"))?;
|
||||
let registry = LanguageModelRegistry::global(cx);
|
||||
let model = registry
|
||||
let model = find_available_model(&selected, cx)
|
||||
.ok_or_else(|| model_not_found_error(model_name, cx))?;
|
||||
let provider = registry
|
||||
.read(cx)
|
||||
.available_models(cx)
|
||||
.find(|m| m.id() == selected.model && m.provider_id() == selected.provider)
|
||||
.ok_or_else(|| {
|
||||
let available = registry
|
||||
.read(cx)
|
||||
.available_models(cx)
|
||||
.map(|m| format!("{}/{}", m.provider_id().0, m.id().0))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
anyhow::anyhow!("Model {model_name} not found. Available: {available}")
|
||||
})?;
|
||||
.provider(&model.provider_id())
|
||||
.context("Provider not found")?;
|
||||
|
||||
let supports_thinking = model.supports_thinking();
|
||||
let model_id = model.id().0.to_string();
|
||||
|
||||
registry.update(cx, |registry, cx| {
|
||||
registry.set_default_model(
|
||||
Some(language_model::ConfiguredModel {
|
||||
provider: registry
|
||||
.provider(&model.provider_id())
|
||||
.context("Provider not found")?,
|
||||
model,
|
||||
}),
|
||||
Some(language_model::ConfiguredModel { provider, model }),
|
||||
cx,
|
||||
);
|
||||
anyhow::Ok(())
|
||||
})?;
|
||||
});
|
||||
|
||||
let enable_thinking = thinking_override.unwrap_or(supports_thinking);
|
||||
let effort = if enable_thinking {
|
||||
|
|
@ -321,7 +498,6 @@ async fn run_agent(
|
|||
"null".to_string()
|
||||
};
|
||||
let provider_id = selected.provider.0.to_string();
|
||||
let model_id = selected.model.0.to_string();
|
||||
SettingsStore::update_global(cx, |store, cx| {
|
||||
let settings = format!(
|
||||
r#"{{
|
||||
|
|
@ -339,8 +515,9 @@ async fn run_agent(
|
|||
}}"
|
||||
"#
|
||||
);
|
||||
store.set_user_settings(&settings, cx).ok();
|
||||
});
|
||||
store.set_user_settings(&settings, cx).result()
|
||||
})
|
||||
.context("updating agent settings")?;
|
||||
|
||||
anyhow::Ok(())
|
||||
});
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ name = "zed-eval"
|
|||
version = "0.1.0"
|
||||
description = "Harbor agent wrapper for Zed's eval-cli"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = ["harbor==0.6.4"]
|
||||
dependencies = ["harbor==0.7.0"]
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools"]
|
||||
|
|
|
|||
Loading…
Reference in a new issue