mirror of
https://github.com/zed-industries/zed.git
synced 2026-06-01 03:14:56 +07:00
ep cli: Load captured examples from Snowflake (#46102)
Release Notes: - N/A
This commit is contained in:
parent
216933f0b8
commit
583a479f77
6 changed files with 376 additions and 36 deletions
|
|
@ -125,17 +125,9 @@ impl Example {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
|
||||
pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
|
||||
let mut examples = Vec::new();
|
||||
|
||||
let stdin_path: PathBuf = PathBuf::from("-");
|
||||
|
||||
let inputs = if inputs.is_empty() {
|
||||
&[stdin_path]
|
||||
} else {
|
||||
inputs
|
||||
};
|
||||
|
||||
for path in inputs {
|
||||
let is_stdin = path.as_path() == Path::new("-");
|
||||
let content = if is_stdin {
|
||||
|
|
@ -201,7 +193,6 @@ pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
|
|||
}
|
||||
}
|
||||
|
||||
sort_examples_by_repo_and_rev(&mut examples);
|
||||
examples
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -9,12 +9,12 @@ mod metrics;
|
|||
mod paths;
|
||||
mod predict;
|
||||
mod progress;
|
||||
mod pull_examples;
|
||||
mod reorder_patch;
|
||||
mod retrieve_context;
|
||||
mod score;
|
||||
mod split_commit;
|
||||
mod synthesize;
|
||||
|
||||
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
|
||||
use edit_prediction::EditPredictionStore;
|
||||
use gpui::Application;
|
||||
|
|
@ -24,7 +24,7 @@ use std::fmt::Display;
|
|||
use std::{path::PathBuf, sync::Arc};
|
||||
|
||||
use crate::distill::run_distill;
|
||||
use crate::example::{group_examples_by_repo, read_examples, write_examples};
|
||||
use crate::example::{Example, group_examples_by_repo, read_example_files, write_examples};
|
||||
use crate::format_prompt::run_format_prompt;
|
||||
use crate::load_project::run_load_project;
|
||||
use crate::paths::FAILED_EXAMPLES_DIR;
|
||||
|
|
@ -42,9 +42,11 @@ struct EpArgs {
|
|||
printenv: bool,
|
||||
#[clap(long, default_value_t = 10, global = true)]
|
||||
max_parallelism: usize,
|
||||
#[clap(long, global = true)]
|
||||
limit: Option<usize>,
|
||||
#[command(subcommand)]
|
||||
command: Option<Command>,
|
||||
#[clap(global = true)]
|
||||
#[clap(global = true, help = INPUTS_HELP)]
|
||||
inputs: Vec<PathBuf>,
|
||||
#[arg(long, short, global = true)]
|
||||
output: Option<PathBuf>,
|
||||
|
|
@ -54,7 +56,37 @@ struct EpArgs {
|
|||
failfast: bool,
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug)]
|
||||
const INPUTS_HELP: &str = r#"
|
||||
Inputs can be file paths or special specifiers:
|
||||
|
||||
path
|
||||
Path to an example(s) file (.md, .json, or .jsonl)
|
||||
|
||||
captured-after:{timestamp}
|
||||
Fetch captured examples from Snowflake after the given RFC3339 timestamp.
|
||||
|
||||
You can specify this multiple times and mix it with file inputs.
|
||||
|
||||
Required environment variables to connect to Snowflake:
|
||||
EP_SNOWFLAKE_API_KEY
|
||||
EP_SNOWFLAKE_BASE_URL
|
||||
|
||||
Optional:
|
||||
EP_SNOWFLAKE_ROLE
|
||||
|
||||
Examples:
|
||||
|
||||
# Predict from a file
|
||||
ep predict examples.jsonl
|
||||
|
||||
# Predict from captured examples after a timestamp
|
||||
ep predict captured-after:2025-01-01T00:00:00Z
|
||||
|
||||
# Mix file inputs and captured-after in the same invocation
|
||||
ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
|
||||
"#;
|
||||
|
||||
#[derive(Subcommand, Debug, Clone)]
|
||||
enum Command {
|
||||
/// Parse markdown examples and output a combined .jsonl file
|
||||
ParseExample,
|
||||
|
|
@ -137,7 +169,7 @@ impl Display for Command {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
#[derive(Debug, Args, Clone)]
|
||||
struct FormatPromptArgs {
|
||||
#[clap(long)]
|
||||
prompt_format: PromptFormat,
|
||||
|
|
@ -149,7 +181,7 @@ enum PromptFormat {
|
|||
Zeta2,
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
#[derive(Debug, Args, Clone)]
|
||||
struct PredictArgs {
|
||||
#[clap(long)]
|
||||
provider: PredictionProvider,
|
||||
|
|
@ -167,7 +199,7 @@ enum PredictionProvider {
|
|||
TeacherNonBatching,
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
#[derive(Debug, Args, Clone)]
|
||||
struct SynthesizeArgs {
|
||||
/// Repository URL (git@github.com:owner/repo or https://...)
|
||||
#[clap(long)]
|
||||
|
|
@ -200,6 +232,60 @@ impl EpArgs {
|
|||
}
|
||||
}
|
||||
|
||||
async fn load_examples(
|
||||
http_client: Arc<dyn http_client::HttpClient>,
|
||||
args: &EpArgs,
|
||||
) -> anyhow::Result<Vec<Example>> {
|
||||
let mut captured_after_timestamps = Vec::new();
|
||||
let mut file_inputs = Vec::new();
|
||||
|
||||
for input in &args.inputs {
|
||||
let input_string = input.to_string_lossy();
|
||||
if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
|
||||
captured_after_timestamps.push(timestamp.to_string());
|
||||
} else {
|
||||
file_inputs.push(input.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let mut examples = read_example_files(&file_inputs);
|
||||
let total_steps = examples.len() + captured_after_timestamps.len();
|
||||
Progress::global().set_total_steps(total_steps);
|
||||
|
||||
let remaining_limit_for_snowflake =
|
||||
args.limit.map(|limit| limit.saturating_sub(examples.len()));
|
||||
|
||||
if let Some(0) = remaining_limit_for_snowflake {
|
||||
log::info!(
|
||||
"skipping captured-after inputs because --limit is already satisfied by example files"
|
||||
);
|
||||
} else if !captured_after_timestamps.is_empty() {
|
||||
captured_after_timestamps.sort();
|
||||
|
||||
let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
|
||||
|
||||
let mut captured_examples = pull_examples::fetch_captured_examples_after(
|
||||
http_client,
|
||||
&captured_after_timestamps,
|
||||
max_rows_per_timestamp,
|
||||
)
|
||||
.await?;
|
||||
examples.append(&mut captured_examples);
|
||||
}
|
||||
|
||||
crate::example::sort_examples_by_repo_and_rev(&mut examples);
|
||||
|
||||
if let Some(limit) = args.limit {
|
||||
if examples.len() > limit {
|
||||
examples.truncate(limit);
|
||||
}
|
||||
}
|
||||
|
||||
Progress::global().set_total_steps(examples.len() + captured_after_timestamps.len());
|
||||
|
||||
Ok(examples)
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let args = EpArgs::parse();
|
||||
|
||||
|
|
@ -209,8 +295,8 @@ fn main() {
|
|||
}
|
||||
|
||||
let output = args.output_path();
|
||||
let command = match args.command {
|
||||
Some(cmd) => cmd,
|
||||
let command = match &args.command {
|
||||
Some(cmd) => cmd.clone(),
|
||||
None => {
|
||||
EpArgs::command().print_help().unwrap();
|
||||
return;
|
||||
|
|
@ -251,7 +337,6 @@ fn main() {
|
|||
_ => {}
|
||||
}
|
||||
|
||||
let mut examples = read_examples(&args.inputs);
|
||||
let http_client = Arc::new(ReqwestClient::new());
|
||||
let app = Application::headless().with_http_client(http_client);
|
||||
|
||||
|
|
@ -261,12 +346,13 @@ fn main() {
|
|||
|
||||
cx.spawn(async move |cx| {
|
||||
let result = async {
|
||||
let mut examples = load_examples(app_state.client.http_client(), &args).await?;
|
||||
|
||||
if let Command::Predict(args) = &command {
|
||||
predict::sync_batches(&args.provider).await?;
|
||||
}
|
||||
|
||||
let total_examples = examples.len();
|
||||
Progress::global().set_total_examples(total_examples);
|
||||
let failfast_on_single_example = examples.len() == 1;
|
||||
|
||||
let mut grouped_examples = group_examples_by_repo(&mut examples);
|
||||
let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
|
||||
|
|
@ -347,7 +433,7 @@ fn main() {
|
|||
|
||||
let msg = format!(
|
||||
indoc::indoc! {"
|
||||
While processing {}:
|
||||
While processing \"{}\":
|
||||
|
||||
{:?}
|
||||
|
||||
|
|
@ -366,7 +452,7 @@ fn main() {
|
|||
command,
|
||||
failed_example_path.display(),
|
||||
);
|
||||
if args.failfast || total_examples == 1 {
|
||||
if args.failfast || failfast_on_single_example {
|
||||
Progress::global().finalize();
|
||||
panic!("{}", msg);
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -19,9 +19,10 @@ struct ProgressInner {
|
|||
terminal_width: usize,
|
||||
max_example_name_len: usize,
|
||||
status_lines_displayed: usize,
|
||||
total_examples: usize,
|
||||
total_steps: usize,
|
||||
failed_examples: usize,
|
||||
last_line_is_logging: bool,
|
||||
ticker: Option<std::thread::JoinHandle<()>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
|
|
@ -47,6 +48,7 @@ pub enum Step {
|
|||
Predict,
|
||||
Score,
|
||||
Synthesize,
|
||||
PullExamples,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
|
|
@ -64,6 +66,7 @@ impl Step {
|
|||
Step::Predict => "Predict",
|
||||
Step::Score => "Score",
|
||||
Step::Synthesize => "Synthesize",
|
||||
Step::PullExamples => "Pull",
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -75,6 +78,7 @@ impl Step {
|
|||
Step::Predict => "\x1b[32m",
|
||||
Step::Score => "\x1b[31m",
|
||||
Step::Synthesize => "\x1b[36m",
|
||||
Step::PullExamples => "\x1b[36m",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -84,6 +88,7 @@ static LOGGER: ProgressLogger = ProgressLogger;
|
|||
|
||||
const MARGIN: usize = 4;
|
||||
const MAX_STATUS_LINES: usize = 10;
|
||||
const STATUS_TICK_INTERVAL: Duration = Duration::from_millis(300);
|
||||
|
||||
impl Progress {
|
||||
/// Returns the global Progress instance, initializing it if necessary.
|
||||
|
|
@ -98,9 +103,10 @@ impl Progress {
|
|||
terminal_width: get_terminal_width(),
|
||||
max_example_name_len: 0,
|
||||
status_lines_displayed: 0,
|
||||
total_examples: 0,
|
||||
total_steps: 0,
|
||||
failed_examples: 0,
|
||||
last_line_is_logging: false,
|
||||
ticker: None,
|
||||
}),
|
||||
});
|
||||
let _ = log::set_logger(&LOGGER);
|
||||
|
|
@ -110,9 +116,9 @@ impl Progress {
|
|||
.clone()
|
||||
}
|
||||
|
||||
pub fn set_total_examples(&self, total: usize) {
|
||||
pub fn set_total_steps(&self, total: usize) {
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
inner.total_examples = total;
|
||||
inner.total_steps = total;
|
||||
}
|
||||
|
||||
pub fn increment_failed(&self) {
|
||||
|
|
@ -142,7 +148,14 @@ impl Progress {
|
|||
|
||||
Self::clear_status_lines(&mut inner);
|
||||
|
||||
inner.max_example_name_len = inner.max_example_name_len.max(example_name.len());
|
||||
let max_name_width = inner
|
||||
.terminal_width
|
||||
.saturating_sub(MARGIN * 2)
|
||||
.saturating_div(3)
|
||||
.max(1);
|
||||
inner.max_example_name_len = inner
|
||||
.max_example_name_len
|
||||
.max(example_name.len().min(max_name_width));
|
||||
inner.in_progress.insert(
|
||||
example_name.to_string(),
|
||||
InProgressTask {
|
||||
|
|
@ -153,6 +166,23 @@ impl Progress {
|
|||
},
|
||||
);
|
||||
|
||||
if inner.is_tty && inner.ticker.is_none() {
|
||||
let progress = self.clone();
|
||||
inner.ticker = Some(std::thread::spawn(move || {
|
||||
loop {
|
||||
std::thread::sleep(STATUS_TICK_INTERVAL);
|
||||
|
||||
let mut inner = progress.inner.lock().unwrap();
|
||||
if inner.in_progress.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
Progress::clear_status_lines(&mut inner);
|
||||
Progress::print_status_lines(&mut inner);
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
Self::print_status_lines(&mut inner);
|
||||
|
||||
StepProgress {
|
||||
|
|
@ -179,7 +209,9 @@ impl Progress {
|
|||
|
||||
Self::clear_status_lines(&mut inner);
|
||||
Self::print_logging_closing_divider(&mut inner);
|
||||
Self::print_completed(&inner, inner.completed.last().unwrap());
|
||||
if let Some(last_completed) = inner.completed.last() {
|
||||
Self::print_completed(&inner, last_completed);
|
||||
}
|
||||
Self::print_status_lines(&mut inner);
|
||||
} else {
|
||||
inner.in_progress.insert(example_name.to_string(), task);
|
||||
|
|
@ -210,6 +242,7 @@ impl Progress {
|
|||
fn print_completed(inner: &ProgressInner, task: &CompletedTask) {
|
||||
let duration = format_duration(task.duration);
|
||||
let name_width = inner.max_example_name_len;
|
||||
let truncated_name = truncate_with_ellipsis(&task.example_name, name_width);
|
||||
|
||||
if inner.is_tty {
|
||||
let reset = "\x1b[0m";
|
||||
|
|
@ -233,7 +266,7 @@ impl Progress {
|
|||
"{bold}{color}{label:>12}{reset} {name:<name_width$} {dim}│{reset} {info_part}",
|
||||
color = task.step.color_code(),
|
||||
label = task.step.label(),
|
||||
name = task.example_name,
|
||||
name = truncated_name,
|
||||
);
|
||||
|
||||
let duration_with_margin = format!("{duration} ");
|
||||
|
|
@ -255,7 +288,7 @@ impl Progress {
|
|||
eprintln!(
|
||||
"{label:>12} {name:<name_width$}{info_part} {duration}",
|
||||
label = task.step.label(),
|
||||
name = task.example_name,
|
||||
name = truncate_with_ellipsis(&task.example_name, name_width),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -283,7 +316,7 @@ impl Progress {
|
|||
|
||||
let range_label = format!(
|
||||
" {}/{}/{} ",
|
||||
done_count, in_progress_count, inner.total_examples
|
||||
done_count, in_progress_count, inner.total_steps
|
||||
);
|
||||
|
||||
// Print a divider line with failed count on left, range label on right
|
||||
|
|
@ -318,10 +351,11 @@ impl Progress {
|
|||
let step_label = task.step.label();
|
||||
let step_color = task.step.color_code();
|
||||
let name_width = inner.max_example_name_len;
|
||||
let truncated_name = truncate_with_ellipsis(name, name_width);
|
||||
|
||||
let prefix = format!(
|
||||
"{bold}{step_color}{step_label:>12}{reset} {name:<name_width$} {dim}│{reset} {substatus_part}",
|
||||
name = name,
|
||||
name = truncated_name,
|
||||
);
|
||||
|
||||
let duration_with_margin = format!("{elapsed} ");
|
||||
|
|
@ -348,6 +382,15 @@ impl Progress {
|
|||
}
|
||||
|
||||
pub fn finalize(&self) {
|
||||
let ticker = {
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
inner.ticker.take()
|
||||
};
|
||||
|
||||
if let Some(ticker) = ticker {
|
||||
let _ = ticker.join();
|
||||
}
|
||||
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
Self::clear_status_lines(&mut inner);
|
||||
|
||||
|
|
|
|||
220
crates/edit_prediction_cli/src/pull_examples.rs
Normal file
220
crates/edit_prediction_cli/src/pull_examples.rs
Normal file
|
|
@ -0,0 +1,220 @@
|
|||
use anyhow::{Context as _, Result};
|
||||
use http_client::{AsyncBody, HttpClient, Method, Request};
|
||||
use indoc::indoc;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{Value as JsonValue, json};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
example::Example,
|
||||
progress::{InfoStyle, Progress, Step},
|
||||
};
|
||||
use edit_prediction::example_spec::ExampleSpec;
|
||||
|
||||
const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
|
||||
const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";
|
||||
|
||||
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
|
||||
|
||||
/// Parse an input token of the form `captured-after:{timestamp}`.
|
||||
pub fn parse_captured_after_input(input: &str) -> Option<&str> {
|
||||
input.strip_prefix("captured-after:")
|
||||
}
|
||||
|
||||
pub async fn fetch_captured_examples_after(
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
after_timestamps: &[String],
|
||||
max_rows_per_timestamp: usize,
|
||||
) -> Result<Vec<Example>> {
|
||||
if after_timestamps.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let progress = Progress::global();
|
||||
|
||||
let token = std::env::var("EP_SNOWFLAKE_API_KEY")
|
||||
.context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
|
||||
let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
|
||||
"missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
|
||||
)?;
|
||||
let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
|
||||
|
||||
let mut all_examples = Vec::new();
|
||||
|
||||
for after_date in after_timestamps.iter() {
|
||||
let step_progress_name = format!(">{after_date}");
|
||||
let step_progress = progress.start(Step::PullExamples, &step_progress_name);
|
||||
step_progress.set_substatus("querying");
|
||||
|
||||
let statement = indoc! {r#"
|
||||
SELECT
|
||||
event_properties:example AS example
|
||||
FROM events
|
||||
WHERE event_type = ?
|
||||
AND time > TRY_TO_TIMESTAMP_NTZ(?)
|
||||
ORDER BY time ASC
|
||||
LIMIT ?
|
||||
"#};
|
||||
|
||||
let request = json!({
|
||||
"statement": statement,
|
||||
"timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
|
||||
"database": "EVENTS",
|
||||
"schema": "PUBLIC",
|
||||
"warehouse": "DBT",
|
||||
"role": role,
|
||||
"bindings": {
|
||||
"1": { "type": "TEXT", "value": EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT },
|
||||
"2": { "type": "TEXT", "value": after_date },
|
||||
"3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
|
||||
}
|
||||
});
|
||||
|
||||
let response = run_sql(http_client.clone(), &base_url, &token, &request).await?;
|
||||
|
||||
step_progress.set_info(format!("{} rows", response.data.len()), InfoStyle::Normal);
|
||||
step_progress.set_substatus("parsing");
|
||||
|
||||
all_examples.extend(examples_from_response(&response)?);
|
||||
|
||||
step_progress.set_substatus("done");
|
||||
}
|
||||
|
||||
Ok(all_examples)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct SnowflakeStatementResponse {
|
||||
#[serde(default)]
|
||||
data: Vec<Vec<JsonValue>>,
|
||||
#[serde(default)]
|
||||
result_set_meta_data: Option<SnowflakeResultSetMetaData>,
|
||||
#[serde(default)]
|
||||
code: Option<String>,
|
||||
#[serde(default)]
|
||||
message: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct SnowflakeResultSetMetaData {
|
||||
#[serde(default, rename = "rowType")]
|
||||
row_type: Vec<SnowflakeColumnMeta>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct SnowflakeColumnMeta {
|
||||
#[serde(default)]
|
||||
name: String,
|
||||
}
|
||||
|
||||
fn examples_from_response(
|
||||
response: &SnowflakeStatementResponse,
|
||||
) -> Result<impl Iterator<Item = Example>> {
|
||||
if let Some(code) = &response.code {
|
||||
if code != SNOWFLAKE_SUCCESS_CODE {
|
||||
anyhow::bail!(
|
||||
"snowflake sql api returned error code={code} message={}",
|
||||
response.message.as_deref().unwrap_or("<no message>")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let example_index = response
|
||||
.result_set_meta_data
|
||||
.as_ref()
|
||||
.and_then(|m| {
|
||||
m.row_type.iter().enumerate().find_map(|(index, col)| {
|
||||
if col.name.eq_ignore_ascii_case("example") {
|
||||
Some(index)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
})
|
||||
.unwrap_or(0);
|
||||
|
||||
let iter = response.data.iter().enumerate().filter_map(move |(row_index, data_row)| {
|
||||
let Some(example_value) = data_row.get(example_index) else {
|
||||
return None;
|
||||
};
|
||||
if example_value.is_null() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let parse_result = match example_value {
|
||||
JsonValue::String(encoded_json) => serde_json::from_str::<ExampleSpec>(encoded_json),
|
||||
_ => serde_json::from_value::<ExampleSpec>(example_value.clone()),
|
||||
};
|
||||
|
||||
match parse_result {
|
||||
Ok(spec) => Some(Example {
|
||||
spec,
|
||||
buffer: None,
|
||||
context: None,
|
||||
prompt: None,
|
||||
predictions: Vec::new(),
|
||||
score: Vec::new(),
|
||||
state: None,
|
||||
}),
|
||||
Err(error) => {
|
||||
let raw_json = serde_json::to_string_pretty(example_value)
|
||||
.unwrap_or_else(|_| "<failed to serialize json>".to_string());
|
||||
log::error!(
|
||||
"failed to parse ExampleSpec for row {row_index}: {error:#}\nraw json:\n{raw_json}"
|
||||
);
|
||||
None
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(iter)
|
||||
}
|
||||
|
||||
async fn run_sql(
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
base_url: &str,
|
||||
token: &str,
|
||||
request: &serde_json::Value,
|
||||
) -> Result<SnowflakeStatementResponse> {
|
||||
let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
|
||||
|
||||
let request_body =
|
||||
serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
|
||||
|
||||
let http_request = Request::builder()
|
||||
.method(Method::POST)
|
||||
.uri(url.as_str())
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.header(
|
||||
"X-Snowflake-Authorization-Token-Type",
|
||||
"PROGRAMMATIC_ACCESS_TOKEN",
|
||||
)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Accept", "application/json")
|
||||
.body(AsyncBody::from(request_body.clone()))?;
|
||||
|
||||
let response = http_client
|
||||
.send(http_request)
|
||||
.await
|
||||
.context("failed to send request to Snowflake SQL API")?;
|
||||
|
||||
let status = response.status();
|
||||
let body_bytes = {
|
||||
use futures::AsyncReadExt as _;
|
||||
|
||||
let mut body = response.into_body();
|
||||
let mut bytes = Vec::new();
|
||||
body.read_to_end(&mut bytes)
|
||||
.await
|
||||
.context("failed to read Snowflake SQL API response body")?;
|
||||
bytes
|
||||
};
|
||||
|
||||
if !status.is_success() {
|
||||
let body_text = String::from_utf8_lossy(&body_bytes);
|
||||
anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
|
||||
}
|
||||
|
||||
serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
|
||||
.context("failed to parse Snowflake SQL API response JSON")
|
||||
}
|
||||
|
|
@ -16,7 +16,7 @@ use std::fs;
|
|||
use std::io::{self, Read};
|
||||
|
||||
/// `ep split-commit` CLI args.
|
||||
#[derive(Debug, Args)]
|
||||
#[derive(Debug, Args, Clone)]
|
||||
pub struct SplitCommitArgs {
|
||||
/// Path to the commit file (use "-" for stdin)
|
||||
#[arg(long, short = 'c')]
|
||||
|
|
|
|||
|
|
@ -108,7 +108,7 @@ pub async fn run_synthesize(config: SynthesizeConfig) -> Result<()> {
|
|||
std::os::windows::fs::symlink_dir(&*FAILED_EXAMPLES_DIR, &*LATEST_FAILED_EXAMPLES_DIR)?;
|
||||
|
||||
let progress = Progress::global();
|
||||
progress.set_total_examples(config.count);
|
||||
progress.set_total_steps(config.count);
|
||||
|
||||
let clone_progress = progress.start(Step::Synthesize, "clone");
|
||||
let repo_path = ensure_repo_cloned(&config.repo_url).await?;
|
||||
|
|
|
|||
Loading…
Reference in a new issue