diff --git a/crates/eval_cli/src/main.rs b/crates/eval_cli/src/main.rs
index e77e75bc879..3a323776ef4 100644
--- a/crates/eval_cli/src/main.rs
+++ b/crates/eval_cli/src/main.rs
@@ -47,7 +47,10 @@
 use feature_flags::FeatureFlagAppExt as _;
 use futures::{FutureExt, select_biased};
 use gpui::{AppContext as _, AsyncApp, Entity, UpdateGlobal};
-use language_model::{LanguageModelRegistry, SelectedModel};
+use language_model::{
+    ANTHROPIC_PROVIDER_ID, LanguageModel, LanguageModelId, LanguageModelProviderId,
+    LanguageModelRegistry, SelectedModel,
+};
 use project::Project;
 use settings::SettingsStore;
 use util::path_list::PathList;
@@ -128,6 +131,8 @@ const EXIT_OK: i32 = 0;
 const EXIT_ERROR: i32 = 1;
 const EXIT_TIMEOUT: i32 = 2;
 const EXIT_INTERRUPTED: i32 = 3;
+const MODEL_DISCOVERY_TIMEOUT: Duration = Duration::from_secs(30);
+const MODEL_DISCOVERY_POLL_INTERVAL: Duration = Duration::from_millis(100);
 
 static TERMINATED: AtomicBool = AtomicBool::new(false);
 
@@ -268,6 +273,219 @@ fn read_instruction(args: &Args) -> Result {
     Ok(text)
 }
 
+/// Waits until `selected` resolves to a model in the language model registry.
+///
+/// Fails as soon as a poll finds the provider missing or unauthenticated. Polling
+/// continues only while the provider is Anthropic or has not listed any models yet, and
+/// ends with a "model not found" error once `MODEL_DISCOVERY_TIMEOUT` has elapsed.
+async fn wait_for_model(selected: &SelectedModel, cx: &mut AsyncApp) -> Result<()> {
+    let started_at = Instant::now();
+
+    loop {
+        let found = cx.update(|cx| find_available_model(selected, cx).is_some());
+        if found {
+            return Ok(());
+        }
+
+        cx.update(|cx| ensure_provider_authenticated(selected, cx))?;
+
+        let selected_provider_has_models = cx.update(|cx| {
+            LanguageModelRegistry::global(cx)
+                .read(cx)
+                .available_models(cx)
+                .any(|model| model.provider_id() == selected.provider)
+        });
+        let should_wait_for_discovery =
+            selected.provider == ANTHROPIC_PROVIDER_ID || !selected_provider_has_models;
+
+        if !should_wait_for_discovery || started_at.elapsed() >= MODEL_DISCOVERY_TIMEOUT {
+            return Err(cx.update(|cx| model_not_found_error(&selected_model_name(selected), cx)));
+        }
+
+        cx.background_executor()
+            .timer(MODEL_DISCOVERY_POLL_INTERVAL)
+            .await;
+    }
+}
+
+/// Errors unless the selected provider is registered and reports itself authenticated.
+fn ensure_provider_authenticated(selected: &SelectedModel, cx: &gpui::App) -> Result<()> {
+    let registry = LanguageModelRegistry::global(cx);
+    let provider = registry
+        .read(cx)
+        .provider(&selected.provider)
+        .ok_or_else(|| anyhow::anyhow!("Provider {} not found", selected.provider.0))?;
+
+    anyhow::ensure!(
+        provider.is_authenticated(cx),
+        "Provider {} is not authenticated",
+        selected.provider.0
+    );
+
+    Ok(())
+}
+
+/// Resolves `selected` against the models currently available in the registry.
+///
+/// An exact provider/id match wins. Otherwise the candidates are the same-provider models
+/// accepted by `model_id_matches_selected`, and the one with the greatest id string is
+/// returned, so among dated ids sharing a base the latest date is preferred.
+fn find_available_model(
+    selected: &SelectedModel,
+    cx: &gpui::App,
+) -> Option<Arc<dyn LanguageModel>> {
+    let registry = LanguageModelRegistry::global(cx);
+    let models = registry.read(cx).available_models(cx).collect::<Vec<_>>();
+
+    if let Some(model) = models
+        .iter()
+        .find(|model| model.provider_id() == selected.provider && model.id() == selected.model)
+    {
+        return Some(model.clone());
+    }
+
+    models
+        .into_iter()
+        .filter(|model| {
+            model.provider_id() == selected.provider
+                && model_id_matches_selected(&model.provider_id(), &model.id(), &selected.model)
+        })
+        // Key on the id string so it is built once per model, not twice per comparison.
+        .max_by_key(|model| model.id().0.to_string())
+}
+
+/// Reports whether the listed id `available` satisfies the requested id `selected`.
+///
+/// Ids must be equal, except for the Anthropic provider, where alias and dated forms of
+/// the same base id are also accepted.
+fn model_id_matches_selected(
+    provider_id: &LanguageModelProviderId,
+    available: &LanguageModelId,
+    selected: &LanguageModelId,
+) -> bool {
+    if available == selected {
+        return true;
+    }
+
+    if provider_id != &ANTHROPIC_PROVIDER_ID {
+        return false;
+    }
+
+    anthropic_model_ids_match(available.0.as_ref(), selected.0.as_ref())
+}
+
+/// Compares two Anthropic ids after alias suffixes are stripped from both; `available`
+/// may additionally carry a trailing `-YYYYMMDD` date that `selected` lacks.
+fn anthropic_model_ids_match(available: &str, selected: &str) -> bool {
+    let available = anthropic_model_alias_base(available);
+    let selected = anthropic_model_alias_base(selected);
+
+    available == selected || anthropic_dated_model_id_matches_base(available, selected)
+}
+
+/// Strips the alias suffixes `-latest`, `-thinking` and `-1m-context` from the end of
+/// `model_id`.
+///
+/// Stripping repeats until no suffix is left, so the result does not depend on the order
+/// in which the suffixes were appended.
+fn anthropic_model_alias_base(model_id: &str) -> &str {
+    const ALIAS_SUFFIXES: [&str; 3] = ["-latest", "-thinking", "-1m-context"];
+
+    let mut base = model_id;
+    while let Some(stripped) = ALIAS_SUFFIXES
+        .iter()
+        .find_map(move |suffix| base.strip_suffix(*suffix))
+    {
+        base = stripped;
+    }
+    base
+}
+
+/// Reports whether `available` is exactly `selected` followed by `-` and eight ASCII
+/// digits, i.e. a dated release of the same base id.
+fn anthropic_dated_model_id_matches_base(available: &str, selected: &str) -> bool {
+    let Some(suffix) = available.strip_prefix(selected) else {
+        return false;
+    };
+    let Some(date) = suffix.strip_prefix('-') else {
+        return false;
+    };
+
+    date.len() == 8 && date.chars().all(|character| character.is_ascii_digit())
+}
+
+/// Formats the selection as `provider/model` for use in error messages.
+fn selected_model_name(selected: &SelectedModel) -> String {
+    format!("{}/{}", selected.provider.0, selected.model.0)
+}
+
+/// Builds the "model not found" error, listing every available `provider/model` pair, or
+/// `(none)` when the registry has no models.
+fn model_not_found_error(model_name: &str, cx: &gpui::App) -> anyhow::Error {
+    let available = LanguageModelRegistry::global(cx)
+        .read(cx)
+        .available_models(cx)
+        .map(|model| format!("{}/{}", model.provider_id().0, model.id().0))
+        .collect::<Vec<_>>();
+    let available = if available.is_empty() {
+        "(none)".to_string()
+    } else {
+        available.join(", ")
+    };
+
+    anyhow::anyhow!("Model {model_name} not found. Available: {available}")
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn anthropic_latest_alias_matches_listed_base_model() {
+        assert!(model_id_matches_selected(
+            &ANTHROPIC_PROVIDER_ID,
+            &LanguageModelId("claude-sonnet-4-6".into()),
+            &LanguageModelId("claude-sonnet-4-6-latest".into()),
+        ));
+    }
+
+    #[test]
+    fn anthropic_thinking_alias_matches_listed_base_model() {
+        assert!(model_id_matches_selected(
+            &ANTHROPIC_PROVIDER_ID,
+            &LanguageModelId("claude-sonnet-4-6".into()),
+            &LanguageModelId("claude-sonnet-4-6-1m-context-thinking-latest".into()),
+        ));
+    }
+
+    #[test]
+    fn anthropic_alias_suffixes_match_in_any_order() {
+        assert!(model_id_matches_selected(
+            &ANTHROPIC_PROVIDER_ID,
+            &LanguageModelId("claude-sonnet-4-6".into()),
+            &LanguageModelId("claude-sonnet-4-6-thinking-1m-context-latest".into()),
+        ));
+    }
+
+    #[test]
+    fn anthropic_latest_alias_matches_listed_dated_model() {
+        assert!(model_id_matches_selected(
+            &ANTHROPIC_PROVIDER_ID,
+            &LanguageModelId("claude-sonnet-4-6-20260518".into()),
+            &LanguageModelId("claude-sonnet-4-6-latest".into()),
+        ));
+    }
+
+    #[test]
+    fn non_anthropic_models_require_exact_ids() {
+        assert!(!model_id_matches_selected(
+            &LanguageModelProviderId("other".into()),
+            &LanguageModelId("claude-sonnet-4-6".into()),
+            &LanguageModelId("claude-sonnet-4-6-latest".into()),
+        ));
+    }
+}
+
 async fn run_agent(
     app_state: &Arc,
     workdir: &std::path::Path,
@@ -279,37 +497,33 @@ async fn run_agent(
     output_dir: Option<&std::path::Path>,
     cx: &mut AsyncApp,
 ) -> (Result, Option) {
+    let selected = match SelectedModel::from_str(model_name).map_err(|e| anyhow::anyhow!("{e}")) {
+        Ok(selected) => selected,
+        Err(e) => return (Err(e), None),
+    };
+
+    if let Err(e) = wait_for_model(&selected, cx).await {
+        return (Err(e), None);
+    }
+
     let setup_result: Result<()> = cx.update(|cx| {
-        let selected = SelectedModel::from_str(model_name).map_err(|e| anyhow::anyhow!("{e}"))?;
         let registry = LanguageModelRegistry::global(cx);
-        let model = registry
+        let model = find_available_model(&selected, cx)
+            .ok_or_else(|| model_not_found_error(model_name, cx))?;
+        let provider = registry
             .read(cx)
-            .available_models(cx)
-            .find(|m| m.id() == selected.model && m.provider_id() == selected.provider)
-            .ok_or_else(|| {
-                let available = registry
-                    .read(cx)
-                    .available_models(cx)
-                    .map(|m| format!("{}/{}", m.provider_id().0, m.id().0))
-                    .collect::<Vec<_>>()
-                    .join(", ");
-                anyhow::anyhow!("Model {model_name} not found. Available: {available}")
-            })?;
+            .provider(&model.provider_id())
+            .context("Provider not found")?;
 
         let supports_thinking = model.supports_thinking();
+        let model_id = model.id().0.to_string();
 
         registry.update(cx, |registry, cx| {
             registry.set_default_model(
-                Some(language_model::ConfiguredModel {
-                    provider: registry
-                        .provider(&model.provider_id())
-                        .context("Provider not found")?,
-                    model,
-                }),
+                Some(language_model::ConfiguredModel { provider, model }),
                 cx,
             );
-            anyhow::Ok(())
-        })?;
+        });
 
         let enable_thinking = thinking_override.unwrap_or(supports_thinking);
         let effort = if enable_thinking {
@@ -321,7 +535,6 @@ async fn run_agent(
             "null".to_string()
         };
         let provider_id = selected.provider.0.to_string();
-        let model_id = selected.model.0.to_string();
         SettingsStore::update_global(cx, |store, cx| {
             let settings = format!(
                 r#"{{
@@ -339,8 +552,9 @@ async fn run_agent(
                 }}"
             "#
             );
-            store.set_user_settings(&settings, cx).ok();
-        });
+            store.set_user_settings(&settings, cx).result()
+        })
+        .context("updating agent settings")?;
         anyhow::Ok(())
     });
 
diff --git a/crates/eval_cli/zed_eval/pyproject.toml b/crates/eval_cli/zed_eval/pyproject.toml
index 10e72028a5e..07a61d7a13b 100644
--- a/crates/eval_cli/zed_eval/pyproject.toml
+++ b/crates/eval_cli/zed_eval/pyproject.toml
@@ -3,7 +3,7 @@ name = "zed-eval"
 version = "0.1.0"
 description = "Harbor agent wrapper for Zed's eval-cli"
 requires-python = ">=3.12"
-dependencies = ["harbor==0.6.4"]
+dependencies = ["harbor==0.7.0"]
 
 [build-system]
 requires = ["setuptools"]