ep cli: Load captured examples from Snowflake (#46102)

Release Notes:

- N/A
This commit is contained in:
Agus Zubiaga 2026-01-06 10:40:13 -03:00 committed by GitHub
parent 216933f0b8
commit 583a479f77
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 376 additions and 36 deletions

View file

@ -125,17 +125,9 @@ impl Example {
}
}
pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
let mut examples = Vec::new();
let stdin_path: PathBuf = PathBuf::from("-");
let inputs = if inputs.is_empty() {
&[stdin_path]
} else {
inputs
};
for path in inputs {
let is_stdin = path.as_path() == Path::new("-");
let content = if is_stdin {
@ -201,7 +193,6 @@ pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
}
}
sort_examples_by_repo_and_rev(&mut examples);
examples
}

View file

@ -9,12 +9,12 @@ mod metrics;
mod paths;
mod predict;
mod progress;
mod pull_examples;
mod reorder_patch;
mod retrieve_context;
mod score;
mod split_commit;
mod synthesize;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use edit_prediction::EditPredictionStore;
use gpui::Application;
@ -24,7 +24,7 @@ use std::fmt::Display;
use std::{path::PathBuf, sync::Arc};
use crate::distill::run_distill;
use crate::example::{group_examples_by_repo, read_examples, write_examples};
use crate::example::{Example, group_examples_by_repo, read_example_files, write_examples};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::FAILED_EXAMPLES_DIR;
@ -42,9 +42,11 @@ struct EpArgs {
printenv: bool,
#[clap(long, default_value_t = 10, global = true)]
max_parallelism: usize,
#[clap(long, global = true)]
limit: Option<usize>,
#[command(subcommand)]
command: Option<Command>,
#[clap(global = true)]
#[clap(global = true, help = INPUTS_HELP)]
inputs: Vec<PathBuf>,
#[arg(long, short, global = true)]
output: Option<PathBuf>,
@ -54,7 +56,37 @@ struct EpArgs {
failfast: bool,
}
#[derive(Subcommand, Debug)]
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:
path
Path to an example(s) file (.md, .json, or .jsonl)
captured-after:{timestamp}
Fetch captured examples from Snowflake after the given RFC3339 timestamp.
You can specify this multiple times and mix it with file inputs.
Required environment variables to connect to Snowflake:
EP_SNOWFLAKE_API_KEY
EP_SNOWFLAKE_BASE_URL
Optional:
EP_SNOWFLAKE_ROLE
Examples:
# Predict from a file
ep predict examples.jsonl
# Predict from captured examples after a timestamp
ep predict captured-after:2025-01-01T00:00:00Z
# Mix file inputs and captured-after in the same invocation
ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
#[derive(Subcommand, Debug, Clone)]
enum Command {
/// Parse markdown examples and output a combined .jsonl file
ParseExample,
@ -137,7 +169,7 @@ impl Display for Command {
}
}
#[derive(Debug, Args)]
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
#[clap(long)]
prompt_format: PromptFormat,
@ -149,7 +181,7 @@ enum PromptFormat {
Zeta2,
}
#[derive(Debug, Args)]
#[derive(Debug, Args, Clone)]
struct PredictArgs {
#[clap(long)]
provider: PredictionProvider,
@ -167,7 +199,7 @@ enum PredictionProvider {
TeacherNonBatching,
}
#[derive(Debug, Args)]
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
/// Repository URL (git@github.com:owner/repo or https://...)
#[clap(long)]
@ -200,6 +232,60 @@ impl EpArgs {
}
}
async fn load_examples(
http_client: Arc<dyn http_client::HttpClient>,
args: &EpArgs,
) -> anyhow::Result<Vec<Example>> {
let mut captured_after_timestamps = Vec::new();
let mut file_inputs = Vec::new();
for input in &args.inputs {
let input_string = input.to_string_lossy();
if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
captured_after_timestamps.push(timestamp.to_string());
} else {
file_inputs.push(input.clone());
}
}
let mut examples = read_example_files(&file_inputs);
let total_steps = examples.len() + captured_after_timestamps.len();
Progress::global().set_total_steps(total_steps);
let remaining_limit_for_snowflake =
args.limit.map(|limit| limit.saturating_sub(examples.len()));
if let Some(0) = remaining_limit_for_snowflake {
log::info!(
"skipping captured-after inputs because --limit is already satisfied by example files"
);
} else if !captured_after_timestamps.is_empty() {
captured_after_timestamps.sort();
let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
let mut captured_examples = pull_examples::fetch_captured_examples_after(
http_client,
&captured_after_timestamps,
max_rows_per_timestamp,
)
.await?;
examples.append(&mut captured_examples);
}
crate::example::sort_examples_by_repo_and_rev(&mut examples);
if let Some(limit) = args.limit {
if examples.len() > limit {
examples.truncate(limit);
}
}
Progress::global().set_total_steps(examples.len() + captured_after_timestamps.len());
Ok(examples)
}
fn main() {
let args = EpArgs::parse();
@ -209,8 +295,8 @@ fn main() {
}
let output = args.output_path();
let command = match args.command {
Some(cmd) => cmd,
let command = match &args.command {
Some(cmd) => cmd.clone(),
None => {
EpArgs::command().print_help().unwrap();
return;
@ -251,7 +337,6 @@ fn main() {
_ => {}
}
let mut examples = read_examples(&args.inputs);
let http_client = Arc::new(ReqwestClient::new());
let app = Application::headless().with_http_client(http_client);
@ -261,12 +346,13 @@ fn main() {
cx.spawn(async move |cx| {
let result = async {
let mut examples = load_examples(app_state.client.http_client(), &args).await?;
if let Command::Predict(args) = &command {
predict::sync_batches(&args.provider).await?;
}
let total_examples = examples.len();
Progress::global().set_total_examples(total_examples);
let failfast_on_single_example = examples.len() == 1;
let mut grouped_examples = group_examples_by_repo(&mut examples);
let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
@ -347,7 +433,7 @@ fn main() {
let msg = format!(
indoc::indoc! {"
While processing {}:
While processing \"{}\":
{:?}
@ -366,7 +452,7 @@ fn main() {
command,
failed_example_path.display(),
);
if args.failfast || total_examples == 1 {
if args.failfast || failfast_on_single_example {
Progress::global().finalize();
panic!("{}", msg);
} else {

View file

@ -19,9 +19,10 @@ struct ProgressInner {
terminal_width: usize,
max_example_name_len: usize,
status_lines_displayed: usize,
total_examples: usize,
total_steps: usize,
failed_examples: usize,
last_line_is_logging: bool,
ticker: Option<std::thread::JoinHandle<()>>,
}
#[derive(Clone)]
@ -47,6 +48,7 @@ pub enum Step {
Predict,
Score,
Synthesize,
PullExamples,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
@ -64,6 +66,7 @@ impl Step {
Step::Predict => "Predict",
Step::Score => "Score",
Step::Synthesize => "Synthesize",
Step::PullExamples => "Pull",
}
}
@ -75,6 +78,7 @@ impl Step {
Step::Predict => "\x1b[32m",
Step::Score => "\x1b[31m",
Step::Synthesize => "\x1b[36m",
Step::PullExamples => "\x1b[36m",
}
}
}
@ -84,6 +88,7 @@ static LOGGER: ProgressLogger = ProgressLogger;
const MARGIN: usize = 4;
const MAX_STATUS_LINES: usize = 10;
const STATUS_TICK_INTERVAL: Duration = Duration::from_millis(300);
impl Progress {
/// Returns the global Progress instance, initializing it if necessary.
@ -98,9 +103,10 @@ impl Progress {
terminal_width: get_terminal_width(),
max_example_name_len: 0,
status_lines_displayed: 0,
total_examples: 0,
total_steps: 0,
failed_examples: 0,
last_line_is_logging: false,
ticker: None,
}),
});
let _ = log::set_logger(&LOGGER);
@ -110,9 +116,9 @@ impl Progress {
.clone()
}
pub fn set_total_examples(&self, total: usize) {
pub fn set_total_steps(&self, total: usize) {
let mut inner = self.inner.lock().unwrap();
inner.total_examples = total;
inner.total_steps = total;
}
pub fn increment_failed(&self) {
@ -142,7 +148,14 @@ impl Progress {
Self::clear_status_lines(&mut inner);
inner.max_example_name_len = inner.max_example_name_len.max(example_name.len());
let max_name_width = inner
.terminal_width
.saturating_sub(MARGIN * 2)
.saturating_div(3)
.max(1);
inner.max_example_name_len = inner
.max_example_name_len
.max(example_name.len().min(max_name_width));
inner.in_progress.insert(
example_name.to_string(),
InProgressTask {
@ -153,6 +166,23 @@ impl Progress {
},
);
if inner.is_tty && inner.ticker.is_none() {
let progress = self.clone();
inner.ticker = Some(std::thread::spawn(move || {
loop {
std::thread::sleep(STATUS_TICK_INTERVAL);
let mut inner = progress.inner.lock().unwrap();
if inner.in_progress.is_empty() {
break;
}
Progress::clear_status_lines(&mut inner);
Progress::print_status_lines(&mut inner);
}
}));
}
Self::print_status_lines(&mut inner);
StepProgress {
@ -179,7 +209,9 @@ impl Progress {
Self::clear_status_lines(&mut inner);
Self::print_logging_closing_divider(&mut inner);
Self::print_completed(&inner, inner.completed.last().unwrap());
if let Some(last_completed) = inner.completed.last() {
Self::print_completed(&inner, last_completed);
}
Self::print_status_lines(&mut inner);
} else {
inner.in_progress.insert(example_name.to_string(), task);
@ -210,6 +242,7 @@ impl Progress {
fn print_completed(inner: &ProgressInner, task: &CompletedTask) {
let duration = format_duration(task.duration);
let name_width = inner.max_example_name_len;
let truncated_name = truncate_with_ellipsis(&task.example_name, name_width);
if inner.is_tty {
let reset = "\x1b[0m";
@ -233,7 +266,7 @@ impl Progress {
"{bold}{color}{label:>12}{reset} {name:<name_width$} {dim}│{reset} {info_part}",
color = task.step.color_code(),
label = task.step.label(),
name = task.example_name,
name = truncated_name,
);
let duration_with_margin = format!("{duration} ");
@ -255,7 +288,7 @@ impl Progress {
eprintln!(
"{label:>12} {name:<name_width$}{info_part} {duration}",
label = task.step.label(),
name = task.example_name,
name = truncate_with_ellipsis(&task.example_name, name_width),
);
}
}
@ -283,7 +316,7 @@ impl Progress {
let range_label = format!(
" {}/{}/{} ",
done_count, in_progress_count, inner.total_examples
done_count, in_progress_count, inner.total_steps
);
// Print a divider line with failed count on left, range label on right
@ -318,10 +351,11 @@ impl Progress {
let step_label = task.step.label();
let step_color = task.step.color_code();
let name_width = inner.max_example_name_len;
let truncated_name = truncate_with_ellipsis(name, name_width);
let prefix = format!(
"{bold}{step_color}{step_label:>12}{reset} {name:<name_width$} {dim}│{reset} {substatus_part}",
name = name,
name = truncated_name,
);
let duration_with_margin = format!("{elapsed} ");
@ -348,6 +382,15 @@ impl Progress {
}
pub fn finalize(&self) {
let ticker = {
let mut inner = self.inner.lock().unwrap();
inner.ticker.take()
};
if let Some(ticker) = ticker {
let _ = ticker.join();
}
let mut inner = self.inner.lock().unwrap();
Self::clear_status_lines(&mut inner);

View file

@ -0,0 +1,220 @@
use anyhow::{Context as _, Result};
use http_client::{AsyncBody, HttpClient, Method, Request};
use indoc::indoc;
use serde::Deserialize;
use serde_json::{Value as JsonValue, json};
use std::sync::Arc;
use crate::{
example::Example,
progress::{InfoStyle, Progress, Step},
};
use edit_prediction::example_spec::ExampleSpec;
const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
/// Parse an input token of the form `captured-after:{timestamp}`.
pub fn parse_captured_after_input(input: &str) -> Option<&str> {
input.strip_prefix("captured-after:")
}
pub async fn fetch_captured_examples_after(
http_client: Arc<dyn HttpClient>,
after_timestamps: &[String],
max_rows_per_timestamp: usize,
) -> Result<Vec<Example>> {
if after_timestamps.is_empty() {
return Ok(Vec::new());
}
let progress = Progress::global();
let token = std::env::var("EP_SNOWFLAKE_API_KEY")
.context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
"missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
)?;
let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
let mut all_examples = Vec::new();
for after_date in after_timestamps.iter() {
let step_progress_name = format!(">{after_date}");
let step_progress = progress.start(Step::PullExamples, &step_progress_name);
step_progress.set_substatus("querying");
let statement = indoc! {r#"
SELECT
event_properties:example AS example
FROM events
WHERE event_type = ?
AND time > TRY_TO_TIMESTAMP_NTZ(?)
ORDER BY time ASC
LIMIT ?
"#};
let request = json!({
"statement": statement,
"timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
"database": "EVENTS",
"schema": "PUBLIC",
"warehouse": "DBT",
"role": role,
"bindings": {
"1": { "type": "TEXT", "value": EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT },
"2": { "type": "TEXT", "value": after_date },
"3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
}
});
let response = run_sql(http_client.clone(), &base_url, &token, &request).await?;
step_progress.set_info(format!("{} rows", response.data.len()), InfoStyle::Normal);
step_progress.set_substatus("parsing");
all_examples.extend(examples_from_response(&response)?);
step_progress.set_substatus("done");
}
Ok(all_examples)
}
#[derive(Debug, Clone, Deserialize)]
struct SnowflakeStatementResponse {
#[serde(default)]
data: Vec<Vec<JsonValue>>,
#[serde(default)]
result_set_meta_data: Option<SnowflakeResultSetMetaData>,
#[serde(default)]
code: Option<String>,
#[serde(default)]
message: Option<String>,
}
#[derive(Debug, Clone, Deserialize)]
struct SnowflakeResultSetMetaData {
#[serde(default, rename = "rowType")]
row_type: Vec<SnowflakeColumnMeta>,
}
#[derive(Debug, Clone, Deserialize)]
struct SnowflakeColumnMeta {
#[serde(default)]
name: String,
}
fn examples_from_response(
response: &SnowflakeStatementResponse,
) -> Result<impl Iterator<Item = Example>> {
if let Some(code) = &response.code {
if code != SNOWFLAKE_SUCCESS_CODE {
anyhow::bail!(
"snowflake sql api returned error code={code} message={}",
response.message.as_deref().unwrap_or("<no message>")
);
}
}
let example_index = response
.result_set_meta_data
.as_ref()
.and_then(|m| {
m.row_type.iter().enumerate().find_map(|(index, col)| {
if col.name.eq_ignore_ascii_case("example") {
Some(index)
} else {
None
}
})
})
.unwrap_or(0);
let iter = response.data.iter().enumerate().filter_map(move |(row_index, data_row)| {
let Some(example_value) = data_row.get(example_index) else {
return None;
};
if example_value.is_null() {
return None;
}
let parse_result = match example_value {
JsonValue::String(encoded_json) => serde_json::from_str::<ExampleSpec>(encoded_json),
_ => serde_json::from_value::<ExampleSpec>(example_value.clone()),
};
match parse_result {
Ok(spec) => Some(Example {
spec,
buffer: None,
context: None,
prompt: None,
predictions: Vec::new(),
score: Vec::new(),
state: None,
}),
Err(error) => {
let raw_json = serde_json::to_string_pretty(example_value)
.unwrap_or_else(|_| "<failed to serialize json>".to_string());
log::error!(
"failed to parse ExampleSpec for row {row_index}: {error:#}\nraw json:\n{raw_json}"
);
None
}
}
});
Ok(iter)
}
async fn run_sql(
http_client: Arc<dyn HttpClient>,
base_url: &str,
token: &str,
request: &serde_json::Value,
) -> Result<SnowflakeStatementResponse> {
let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
let request_body =
serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
let http_request = Request::builder()
.method(Method::POST)
.uri(url.as_str())
.header("Authorization", format!("Bearer {token}"))
.header(
"X-Snowflake-Authorization-Token-Type",
"PROGRAMMATIC_ACCESS_TOKEN",
)
.header("Content-Type", "application/json")
.header("Accept", "application/json")
.body(AsyncBody::from(request_body.clone()))?;
let response = http_client
.send(http_request)
.await
.context("failed to send request to Snowflake SQL API")?;
let status = response.status();
let body_bytes = {
use futures::AsyncReadExt as _;
let mut body = response.into_body();
let mut bytes = Vec::new();
body.read_to_end(&mut bytes)
.await
.context("failed to read Snowflake SQL API response body")?;
bytes
};
if !status.is_success() {
let body_text = String::from_utf8_lossy(&body_bytes);
anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
}
serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
.context("failed to parse Snowflake SQL API response JSON")
}

View file

@ -16,7 +16,7 @@ use std::fs;
use std::io::{self, Read};
/// `ep split-commit` CLI args.
#[derive(Debug, Args)]
#[derive(Debug, Args, Clone)]
pub struct SplitCommitArgs {
/// Path to the commit file (use "-" for stdin)
#[arg(long, short = 'c')]

View file

@ -108,7 +108,7 @@ pub async fn run_synthesize(config: SynthesizeConfig) -> Result<()> {
std::os::windows::fs::symlink_dir(&*FAILED_EXAMPLES_DIR, &*LATEST_FAILED_EXAMPLES_DIR)?;
let progress = Progress::global();
progress.set_total_examples(config.count);
progress.set_total_steps(config.count);
let clone_progress = progress.start(Step::Synthesize, "clone");
let repo_path = ensure_repo_cloned(&config.repo_url).await?;