mirror of
https://github.com/zed-industries/zed.git
synced 2026-06-01 03:14:56 +07:00
Track additional metrics in settled (#52938)
Stacked on https://github.com/zed-industries/zed/pull/50566. Begin collecting kept chars rate, as well as the count of tree-sitter errors in the code before and after applying the prediction. Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [ ] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Closes #ISSUE Release Notes: - N/A or Added/Fixed/Improved ...
This commit is contained in:
parent
364ebfcc07
commit
7597666c08
16 changed files with 905 additions and 129 deletions
4
Cargo.lock
generated
4
Cargo.lock
generated
|
|
@ -5176,6 +5176,7 @@ dependencies = [
|
|||
"copilot",
|
||||
"copilot_ui",
|
||||
"credentials_provider",
|
||||
"criterion",
|
||||
"ctor",
|
||||
"db",
|
||||
"edit_prediction_context",
|
||||
|
|
@ -5189,9 +5190,11 @@ dependencies = [
|
|||
"itertools 0.14.0",
|
||||
"language",
|
||||
"language_model",
|
||||
"languages",
|
||||
"log",
|
||||
"lsp",
|
||||
"menu",
|
||||
"node_runtime",
|
||||
"open_ai",
|
||||
"parking_lot",
|
||||
"postage",
|
||||
|
|
@ -5235,7 +5238,6 @@ dependencies = [
|
|||
"client",
|
||||
"cloud_llm_client",
|
||||
"collections",
|
||||
"criterion",
|
||||
"db",
|
||||
"debug_adapter_extension",
|
||||
"dirs 4.0.0",
|
||||
|
|
|
|||
|
|
@ -72,10 +72,14 @@ zeta_prompt.workspace = true
|
|||
zstd.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
criterion.workspace = true
|
||||
fs = { workspace = true, features = ["test-support"] }
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
languages = { workspace = true, features = ["load-grammars"] }
|
||||
node_runtime.workspace = true
|
||||
clock = { workspace = true, features = ["test-support"] }
|
||||
cloud_llm_client = { workspace = true, features = ["test-support"] }
|
||||
ctor.workspace = true
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
indoc.workspace = true
|
||||
language = { workspace = true, features = ["test-support"] }
|
||||
language_model = { workspace = true, features = ["test-support"] }
|
||||
|
|
@ -86,3 +90,11 @@ settings = { workspace = true, features = ["test-support"] }
|
|||
workspace = { workspace = true, features = ["test-support"] }
|
||||
|
||||
zlog.workspace = true
|
||||
|
||||
[[bench]]
|
||||
name = "kept_rate"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "ts_error_count"
|
||||
harness = false
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
|
||||
use edit_prediction_cli::kept_rate::compute_kept_rate;
|
||||
use edit_prediction::metrics::compute_kept_rate;
|
||||
|
||||
fn repeated_function_lines(line_count: usize) -> String {
|
||||
let mut text = String::with_capacity(line_count * 32);
|
||||
454
crates/edit_prediction/benches/ts_error_count.rs
Normal file
454
crates/edit_prediction/benches/ts_error_count.rs
Normal file
|
|
@ -0,0 +1,454 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
|
||||
use edit_prediction::metrics::count_tree_sitter_errors;
|
||||
use fs::FakeFs;
|
||||
use gpui::{AppContext as _, TestAppContext};
|
||||
use language::{Buffer, BufferSnapshot, LanguageRegistry};
|
||||
use languages::init as init_languages;
|
||||
use node_runtime::NodeRuntime;
|
||||
use settings::SettingsStore;
|
||||
|
||||
struct ParsedCase {
|
||||
label: String,
|
||||
bytes: usize,
|
||||
error_count: usize,
|
||||
snapshot: BufferSnapshot,
|
||||
}
|
||||
|
||||
/// Rewrites `source` in place, swapping a periodic subset of the occurrences
/// of `needle` for `replacement`.
///
/// Occurrences are located left to right without overlapping and are numbered
/// from zero. Occurrence `k` is replaced when `k % every == 0`, until
/// `max_replacements` replacements have been made; every other occurrence is
/// copied through unchanged.
///
/// `source` is left untouched when `needle` is empty or `every` is zero. An
/// empty needle matches everywhere without consuming input, and a zero period
/// would otherwise be a division by zero.
fn replace_nth_occurrences(
    source: &mut String,
    needle: &str,
    replacement: &str,
    every: usize,
    max_replacements: usize,
) {
    if needle.is_empty() || every == 0 {
        return;
    }

    let mut rebuilt = String::with_capacity(source.len());
    let mut cursor = 0;
    let mut replaced = 0;

    // `match_indices` yields non-overlapping matches in order, which is the
    // same sequence a repeated `find` from the end of the previous match gives.
    for (seen, (start, matched)) in source.match_indices(needle).enumerate() {
        // Copy the untouched text between the previous match and this one.
        rebuilt.push_str(&source[cursor..start]);

        if seen % every == 0 && replaced < max_replacements {
            rebuilt.push_str(replacement);
            replaced += 1;
        } else {
            rebuilt.push_str(matched);
        }

        cursor = start + matched.len();
    }

    // Trailing text after the final match (or the whole input if none matched).
    rebuilt.push_str(&source[cursor..]);
    *source = rebuilt;
}
|
||||
|
||||
/// Generates a synthetic Rust file: a fixed `Counter` preamble followed by
/// `function_count` numbered `compute_value_{index}` functions.
fn rust_source(function_count: usize) -> String {
    let preamble = "pub struct Counter {\n value: usize,\n}\n\nimpl Counter {\n pub fn new() -> Self {\n Self { value: 0 }\n }\n}\n\n";

    let mut generated = String::from(preamble);
    generated.extend((0..function_count).map(|index| {
        format!(
            "pub fn compute_value_{index}(input: usize) -> usize {{\n let mut total = input;\n for offset in 0..32 {{\n total += offset + {index};\n }}\n if total % 2 == 0 {{\n total / 2\n }} else {{\n total * 3 + 1\n }}\n}}\n\n"
        )
    }));
    generated
}
|
||||
|
||||
fn rust_source_with_errors(function_count: usize) -> String {
|
||||
let mut source = rust_source(function_count);
|
||||
replace_nth_occurrences(
|
||||
&mut source,
|
||||
" if total % 2 == 0 {\n",
|
||||
" if total % 2 == 0 \n",
|
||||
17,
|
||||
48,
|
||||
);
|
||||
source
|
||||
}
|
||||
|
||||
/// Generates a synthetic Python file: a fixed `Counter` class followed by
/// `function_count` numbered `compute_value_{index}` functions.
fn python_source(function_count: usize) -> String {
    let preamble = "class Counter:\n def __init__(self) -> None:\n self.value = 0\n\n\n";

    let mut generated = String::from(preamble);
    generated.extend((0..function_count).map(|index| {
        format!(
            "def compute_value_{index}(input_value: int) -> int:\n total = input_value\n for offset in range(32):\n total += offset + {index}\n if total % 2 == 0:\n return total // 2\n return total * 3 + 1\n\n"
        )
    }));
    generated
}
|
||||
|
||||
fn python_source_with_errors(function_count: usize) -> String {
|
||||
let mut source = python_source(function_count);
|
||||
replace_nth_occurrences(
|
||||
&mut source,
|
||||
" if total % 2 == 0:\n",
|
||||
" if total % 2 == 0\n",
|
||||
19,
|
||||
48,
|
||||
);
|
||||
source
|
||||
}
|
||||
|
||||
/// Generates a synthetic Go file: a fixed `bench` package preamble followed by
/// `function_count` numbered `ComputeValue{index}` functions.
fn go_source(function_count: usize) -> String {
    let preamble = "package bench\n\ntype Counter struct {\n\tvalue int\n}\n\nfunc NewCounter() Counter {\n\treturn Counter{value: 0}\n}\n\n";

    let mut generated = String::from(preamble);
    generated.extend((0..function_count).map(|index| {
        format!(
            "func ComputeValue{index}(inputValue int) int {{\n\ttotal := inputValue\n\tfor offset := 0; offset < 32; offset++ {{\n\t\ttotal += offset + {index}\n\t}}\n\tif total%2 == 0 {{\n\t\treturn total / 2\n\t}}\n\treturn total*3 + 1\n}}\n\n"
        )
    }));
    generated
}
|
||||
|
||||
fn go_source_with_errors(function_count: usize) -> String {
|
||||
let mut source = go_source(function_count);
|
||||
replace_nth_occurrences(
|
||||
&mut source,
|
||||
"\tfor offset := 0; offset < 32; offset++ {\n",
|
||||
"\tfor offset := 0; offset < 32; offset++ \n",
|
||||
17,
|
||||
48,
|
||||
);
|
||||
source
|
||||
}
|
||||
|
||||
/// Generates a synthetic TypeScript file: a fixed `Counter` preamble followed
/// by `function_count` numbered `computeValue{index}` functions.
fn typescript_source(function_count: usize) -> String {
    let preamble = "export type Counter = { value: number };\n\nexport function newCounter(): Counter {\n return { value: 0 };\n}\n\n";

    let mut generated = String::from(preamble);
    generated.extend((0..function_count).map(|index| {
        format!(
            "export function computeValue{index}(inputValue: number): number {{\n let total = inputValue;\n for (let offset = 0; offset < 32; offset += 1) {{\n total += offset + {index};\n }}\n return total % 2 === 0 ? total / 2 : total * 3 + 1;\n}}\n\n"
        )
    }));
    generated
}
|
||||
|
||||
fn typescript_source_with_errors(function_count: usize) -> String {
|
||||
let mut source = typescript_source(function_count);
|
||||
replace_nth_occurrences(
|
||||
&mut source,
|
||||
" return total % 2 === 0 ? total / 2 : total * 3 + 1;\n",
|
||||
" return total % 2 === 0 ? total / 2 : ;\n",
|
||||
17,
|
||||
64,
|
||||
);
|
||||
source
|
||||
}
|
||||
|
||||
/// Generates a synthetic TSX file: fixed `Item` and `App` components followed
/// by `component_count` numbered `Widget{index}` components.
fn tsx_source(component_count: usize) -> String {
    let preamble = "type ItemProps = { index: number; label: string };\n\nfunction Item({ index, label }: ItemProps) {\n return <li data-index={index}>{label}</li>;\n}\n\nexport function App() {\n return <section><ul>{[0, 1, 2].map((value) => <Item key={value} index={value} label={`item-${value}`} />)}</ul></section>;\n}\n\n";

    let mut generated = String::from(preamble);
    generated.extend((0..component_count).map(|index| {
        format!(
            "export function Widget{index}(): JSX.Element {{\n const items = Array.from({{ length: 16 }}, (_, value) => value + {index});\n return (\n <div className=\"widget-{index}\">\n <h2>Widget {index}</h2>\n <ul>\n {{items.map((value) => (\n <Item key={{value}} index={{value}} label={{`widget-{index}-${{value}}`}} />\n ))}}\n </ul>\n </div>\n );\n}}\n\n"
        )
    }));
    generated
}
|
||||
|
||||
fn tsx_source_with_errors(component_count: usize) -> String {
|
||||
let mut source = tsx_source(component_count);
|
||||
replace_nth_occurrences(
|
||||
&mut source,
|
||||
" const items = Array.from({ length: 16 }, (_, value) => value + ",
|
||||
" const items = Array.from({ length: 16 }, (_, value) => ); // ",
|
||||
11,
|
||||
32,
|
||||
);
|
||||
source
|
||||
}
|
||||
|
||||
/// Generates a synthetic JSON document: an object whose `items` array holds
/// `object_count` entries, comma-separated except after the last one.
fn json_source(object_count: usize) -> String {
    let mut generated = String::from("{\n \"items\": [\n");

    generated.extend((0..object_count).map(|index| {
        let is_last = index + 1 == object_count;
        let suffix = if is_last { "" } else { "," };
        let count = index * 3 + 1;
        let ratio = index as f64 / 10.0;
        format!(
            " {{\n \"id\": {index},\n \"name\": \"item-{index}\",\n \"enabled\": true,\n \"tags\": [\"alpha\", \"beta\", \"gamma\"],\n \"metrics\": {{ \"count\": {count}, \"ratio\": {ratio} }}\n }}{suffix}\n"
        )
    }));

    generated.push_str(" ]\n}\n");
    generated
}
|
||||
|
||||
fn json_source_with_errors(object_count: usize) -> String {
|
||||
let mut source = json_source(object_count);
|
||||
replace_nth_occurrences(
|
||||
&mut source,
|
||||
" \"enabled\": true,\n",
|
||||
" \"enabled\": ,\n",
|
||||
23,
|
||||
64,
|
||||
);
|
||||
source
|
||||
}
|
||||
|
||||
/// Generates a synthetic YAML sequence with `document_count` entries, each
/// carrying an id, name, flag, tag list, and a small metrics mapping.
fn yaml_source(document_count: usize) -> String {
    (0..document_count)
        .map(|index| {
            let count = index * 3 + 1;
            let ratio = index as f64 / 10.0;
            format!(
                "- id: {index}\n name: item-{index}\n enabled: true\n tags:\n - alpha\n - beta\n - gamma\n metrics:\n count: {count}\n ratio: {ratio}\n"
            )
        })
        .collect()
}
|
||||
|
||||
fn yaml_source_with_errors(document_count: usize) -> String {
|
||||
let mut source = yaml_source(document_count);
|
||||
replace_nth_occurrences(&mut source, " count: ", " count ", 23, 64);
|
||||
source
|
||||
}
|
||||
|
||||
/// Generates a synthetic stylesheet with two rules per index: a
/// `.widget-{index}` rule and a `.widget-{index} > .item-{index}` rule.
fn css_source(rule_count: usize) -> String {
    (0..rule_count)
        .map(|index| {
            // The three `rgb(...)` components vary with the index.
            let red = (index * 17) % 255;
            let green = (index * 31) % 255;
            let blue = (index * 47) % 255;
            format!(
                ".widget-{index} {{\n display: grid;\n grid-template-columns: repeat(4, minmax(0, 1fr));\n gap: 12px;\n padding: 8px;\n color: rgb({red}, {green}, {blue});\n}}\n\n.widget-{index} > .item-{index} {{\n border: 1px solid rgba(0, 0, 0, 0.15);\n background: linear-gradient(90deg, #fff, #eef);\n}}\n\n"
            )
        })
        .collect()
}
|
||||
|
||||
fn css_source_with_errors(rule_count: usize) -> String {
|
||||
let mut source = css_source(rule_count);
|
||||
replace_nth_occurrences(&mut source, " gap: 12px;\n", " gap 12px;\n", 29, 64);
|
||||
source
|
||||
}
|
||||
|
||||
fn build_case(
|
||||
context: &mut TestAppContext,
|
||||
languages: &Arc<LanguageRegistry>,
|
||||
language_name: &'static str,
|
||||
variant_name: &'static str,
|
||||
source: String,
|
||||
expect_errors: bool,
|
||||
) -> ParsedCase {
|
||||
let language_task = context.background_spawn({
|
||||
let languages = languages.clone();
|
||||
async move { languages.language_for_name(language_name).await }
|
||||
});
|
||||
while !language_task.is_ready() {
|
||||
context.run_until_parked();
|
||||
}
|
||||
let language = futures::executor::block_on(language_task)
|
||||
.unwrap_or_else(|error| panic!("failed to load {language_name}: {error}"));
|
||||
|
||||
let buffer = context.new(|cx| Buffer::local(source, cx).with_language(language, cx));
|
||||
context.run_until_parked();
|
||||
while buffer.read_with(context, |buffer, _| buffer.is_parsing()) {
|
||||
context.run_until_parked();
|
||||
}
|
||||
|
||||
let snapshot = buffer.read_with(context, |buffer, _| buffer.snapshot());
|
||||
let full_range = 0..snapshot.text.len();
|
||||
let error_count = count_tree_sitter_errors(snapshot.syntax_layers());
|
||||
if expect_errors {
|
||||
assert!(
|
||||
error_count > 0,
|
||||
"expected tree-sitter errors for {language_name}/{variant_name}",
|
||||
);
|
||||
} else {
|
||||
assert_eq!(
|
||||
error_count, 0,
|
||||
"expected no tree-sitter errors for {language_name}/{variant_name}",
|
||||
);
|
||||
}
|
||||
|
||||
let label = format!(
|
||||
"{}/{}_{}kb_{}e",
|
||||
language_name.to_lowercase(),
|
||||
variant_name,
|
||||
full_range.end / 1024,
|
||||
error_count,
|
||||
);
|
||||
ParsedCase {
|
||||
label,
|
||||
bytes: full_range.end,
|
||||
error_count,
|
||||
snapshot,
|
||||
}
|
||||
}
|
||||
|
||||
fn parsed_cases() -> Vec<ParsedCase> {
|
||||
let mut context = TestAppContext::single();
|
||||
context.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
});
|
||||
|
||||
let languages = Arc::new(LanguageRegistry::new(context.executor()));
|
||||
let fs = FakeFs::new(context.executor());
|
||||
let node_runtime = NodeRuntime::unavailable();
|
||||
context.update(|cx| init_languages(languages.clone(), fs, node_runtime, cx));
|
||||
|
||||
vec![
|
||||
build_case(
|
||||
&mut context,
|
||||
&languages,
|
||||
"Rust",
|
||||
"valid",
|
||||
rust_source(900),
|
||||
false,
|
||||
),
|
||||
build_case(
|
||||
&mut context,
|
||||
&languages,
|
||||
"Rust",
|
||||
"error_heavy",
|
||||
rust_source_with_errors(900),
|
||||
true,
|
||||
),
|
||||
build_case(
|
||||
&mut context,
|
||||
&languages,
|
||||
"Python",
|
||||
"valid",
|
||||
python_source(1100),
|
||||
false,
|
||||
),
|
||||
build_case(
|
||||
&mut context,
|
||||
&languages,
|
||||
"Python",
|
||||
"error_heavy",
|
||||
python_source_with_errors(1100),
|
||||
true,
|
||||
),
|
||||
build_case(
|
||||
&mut context,
|
||||
&languages,
|
||||
"Go",
|
||||
"valid",
|
||||
go_source(1000),
|
||||
false,
|
||||
),
|
||||
build_case(
|
||||
&mut context,
|
||||
&languages,
|
||||
"Go",
|
||||
"error_heavy",
|
||||
go_source_with_errors(1000),
|
||||
true,
|
||||
),
|
||||
build_case(
|
||||
&mut context,
|
||||
&languages,
|
||||
"TypeScript",
|
||||
"valid",
|
||||
typescript_source(1000),
|
||||
false,
|
||||
),
|
||||
build_case(
|
||||
&mut context,
|
||||
&languages,
|
||||
"TypeScript",
|
||||
"error_heavy",
|
||||
typescript_source_with_errors(1000),
|
||||
true,
|
||||
),
|
||||
build_case(
|
||||
&mut context,
|
||||
&languages,
|
||||
"TSX",
|
||||
"valid",
|
||||
tsx_source(350),
|
||||
false,
|
||||
),
|
||||
build_case(
|
||||
&mut context,
|
||||
&languages,
|
||||
"TSX",
|
||||
"error_heavy",
|
||||
tsx_source_with_errors(350),
|
||||
true,
|
||||
),
|
||||
build_case(
|
||||
&mut context,
|
||||
&languages,
|
||||
"JSON",
|
||||
"valid",
|
||||
json_source(2200),
|
||||
false,
|
||||
),
|
||||
build_case(
|
||||
&mut context,
|
||||
&languages,
|
||||
"JSON",
|
||||
"error_heavy",
|
||||
json_source_with_errors(2200),
|
||||
true,
|
||||
),
|
||||
build_case(
|
||||
&mut context,
|
||||
&languages,
|
||||
"YAML",
|
||||
"valid",
|
||||
yaml_source(2200),
|
||||
false,
|
||||
),
|
||||
build_case(
|
||||
&mut context,
|
||||
&languages,
|
||||
"YAML",
|
||||
"error_heavy",
|
||||
yaml_source_with_errors(2200),
|
||||
true,
|
||||
),
|
||||
build_case(
|
||||
&mut context,
|
||||
&languages,
|
||||
"CSS",
|
||||
"valid",
|
||||
css_source(2400),
|
||||
false,
|
||||
),
|
||||
build_case(
|
||||
&mut context,
|
||||
&languages,
|
||||
"CSS",
|
||||
"error_heavy",
|
||||
css_source_with_errors(2400),
|
||||
true,
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
fn ts_error_count_benchmark(c: &mut Criterion) {
|
||||
let cases = parsed_cases();
|
||||
let mut group = c.benchmark_group("ts_error_count/full_file");
|
||||
|
||||
for case in &cases {
|
||||
group.bench_with_input(
|
||||
BenchmarkId::from_parameter(&case.label),
|
||||
case,
|
||||
|bench, case| {
|
||||
bench.iter(|| {
|
||||
black_box(case.bytes);
|
||||
black_box(case.error_count);
|
||||
black_box(count_tree_sitter_errors(case.snapshot.syntax_layers()))
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(benches, ts_error_count_benchmark);
|
||||
criterion_main!(benches);
|
||||
|
|
@ -30,7 +30,7 @@ use gpui::{
|
|||
};
|
||||
use heapless::Vec as ArrayVec;
|
||||
use language::language_settings::all_language_settings;
|
||||
use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
|
||||
use language::{Anchor, Buffer, EditPreview, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
|
||||
use language::{BufferSnapshot, OffsetRangeExt};
|
||||
use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
|
||||
use release_channel::AppVersion;
|
||||
|
|
@ -61,6 +61,7 @@ pub mod example_spec;
|
|||
pub mod fim;
|
||||
mod license_detection;
|
||||
pub mod mercury;
|
||||
pub mod metrics;
|
||||
pub mod ollama;
|
||||
mod onboarding_modal;
|
||||
pub mod open_ai_response;
|
||||
|
|
@ -80,6 +81,7 @@ use crate::cursor_excerpt::expand_context_syntactically_then_linewise;
|
|||
use crate::example_spec::ExampleSpec;
|
||||
use crate::license_detection::LicenseDetectionWatcher;
|
||||
use crate::mercury::Mercury;
|
||||
pub use crate::metrics::{KeptRateResult, compute_kept_rate};
|
||||
use crate::onboarding_modal::ZedPredictModal;
|
||||
pub use crate::prediction::EditPrediction;
|
||||
pub use crate::prediction::EditPredictionId;
|
||||
|
|
@ -478,10 +480,13 @@ impl std::ops::Deref for BufferEditPrediction<'_> {
|
|||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
|
||||
struct PendingSettledPrediction {
|
||||
request_id: EditPredictionId,
|
||||
editable_anchor_range: Range<Anchor>,
|
||||
editable_region_before_prediction: String,
|
||||
predicted_editable_region: String,
|
||||
ts_error_count_before_prediction: usize,
|
||||
ts_error_count_after_prediction: usize,
|
||||
example: Option<ExampleSpec>,
|
||||
enqueued_at: Instant,
|
||||
last_edit_at: Instant,
|
||||
|
|
@ -1603,63 +1608,100 @@ impl EditPredictionStore {
|
|||
};
|
||||
|
||||
let now = cx.background_executor().now();
|
||||
|
||||
let mut oldest_edited_at = None;
|
||||
let mut ready_predictions = Vec::new();
|
||||
|
||||
this.update(cx, |this, _| {
|
||||
for (_, project_state) in this.projects.iter_mut() {
|
||||
for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
|
||||
registered_buffer
|
||||
.pending_predictions
|
||||
.retain_mut(|pending_prediction| {
|
||||
let age =
|
||||
now.saturating_duration_since(pending_prediction.enqueued_at);
|
||||
if age >= EDIT_PREDICTION_SETTLED_TTL {
|
||||
return false;
|
||||
}
|
||||
let mut pending_index = 0;
|
||||
while pending_index < registered_buffer.pending_predictions.len() {
|
||||
let pending_prediction =
|
||||
®istered_buffer.pending_predictions[pending_index];
|
||||
let age = now.saturating_duration_since(pending_prediction.enqueued_at);
|
||||
if age >= EDIT_PREDICTION_SETTLED_TTL {
|
||||
registered_buffer.pending_predictions.remove(pending_index);
|
||||
continue;
|
||||
}
|
||||
|
||||
let quiet_for =
|
||||
now.saturating_duration_since(pending_prediction.last_edit_at);
|
||||
if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
|
||||
let settled_editable_region = registered_buffer
|
||||
.snapshot
|
||||
.text_for_range(
|
||||
pending_prediction.editable_anchor_range.clone(),
|
||||
)
|
||||
.collect::<String>();
|
||||
let quiet_for =
|
||||
now.saturating_duration_since(pending_prediction.last_edit_at);
|
||||
if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
|
||||
let pending_prediction =
|
||||
registered_buffer.pending_predictions.remove(pending_index);
|
||||
let settled_editable_region = registered_buffer
|
||||
.snapshot
|
||||
.text_for_range(
|
||||
pending_prediction.editable_anchor_range.clone(),
|
||||
)
|
||||
.collect::<String>();
|
||||
ready_predictions
|
||||
.push((pending_prediction, settled_editable_region));
|
||||
continue;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
if let Some(callback) = &this.settled_event_callback {
|
||||
callback(
|
||||
pending_prediction.request_id.clone(),
|
||||
settled_editable_region.clone(),
|
||||
);
|
||||
}
|
||||
|
||||
telemetry::event!(
|
||||
EDIT_PREDICTION_SETTLED_EVENT,
|
||||
request_id = pending_prediction.request_id.0.clone(),
|
||||
settled_editable_region,
|
||||
example = pending_prediction.example.take(),
|
||||
e2e_latency = pending_prediction.e2e_latency.as_millis(),
|
||||
);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
if oldest_edited_at
|
||||
.is_none_or(|t| pending_prediction.last_edit_at < t)
|
||||
{
|
||||
oldest_edited_at = Some(pending_prediction.last_edit_at);
|
||||
}
|
||||
|
||||
true
|
||||
});
|
||||
if oldest_edited_at
|
||||
.is_none_or(|time| pending_prediction.last_edit_at < time)
|
||||
{
|
||||
oldest_edited_at = Some(pending_prediction.last_edit_at);
|
||||
}
|
||||
pending_index += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
|
||||
for (pending_prediction, settled_editable_region) in ready_predictions {
|
||||
let PendingSettledPrediction {
|
||||
request_id,
|
||||
editable_region_before_prediction,
|
||||
predicted_editable_region,
|
||||
ts_error_count_before_prediction,
|
||||
ts_error_count_after_prediction,
|
||||
example,
|
||||
e2e_latency,
|
||||
..
|
||||
} = pending_prediction;
|
||||
let settled_editable_region_for_metrics = settled_editable_region.clone();
|
||||
let kept_rate_result = cx
|
||||
.background_spawn(async move {
|
||||
compute_kept_rate(
|
||||
&editable_region_before_prediction,
|
||||
&predicted_editable_region,
|
||||
&settled_editable_region_for_metrics,
|
||||
)
|
||||
})
|
||||
.await;
|
||||
|
||||
#[cfg(test)]
|
||||
{
|
||||
let request_id = request_id.clone();
|
||||
let settled_editable_region = settled_editable_region.clone();
|
||||
this.update(cx, |this, _| {
|
||||
if let Some(callback) = &this.settled_event_callback {
|
||||
callback(request_id, settled_editable_region);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
telemetry::event!(
|
||||
EDIT_PREDICTION_SETTLED_EVENT,
|
||||
request_id = request_id.0.clone(),
|
||||
settled_editable_region,
|
||||
ts_error_count_before_prediction,
|
||||
ts_error_count_after_prediction,
|
||||
edit_bytes_predicted_new = kept_rate_result.predicted_new_chars,
|
||||
edit_bytes_final_new = kept_rate_result.final_new_chars,
|
||||
edit_bytes_kept = kept_rate_result.kept_chars,
|
||||
edit_bytes_discarded = kept_rate_result.discarded_chars,
|
||||
edit_bytes_context = kept_rate_result.context_chars,
|
||||
edit_bytes_kept_rate = kept_rate_result.kept_rate,
|
||||
example,
|
||||
e2e_latency = e2e_latency.as_millis(),
|
||||
);
|
||||
}
|
||||
|
||||
next_wake_time = oldest_edited_at.map(|time| time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1670,28 +1712,58 @@ impl EditPredictionStore {
|
|||
edited_buffer: &Entity<Buffer>,
|
||||
edited_buffer_snapshot: &BufferSnapshot,
|
||||
editable_offset_range: Range<usize>,
|
||||
edit_preview: &EditPreview,
|
||||
example: Option<ExampleSpec>,
|
||||
e2e_latency: std::time::Duration,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let this = &mut *self;
|
||||
let project_state = this.get_or_init_project(project, cx);
|
||||
if let Some(buffer) = project_state
|
||||
let Some(registered_buffer) = project_state
|
||||
.registered_buffers
|
||||
.get_mut(&edited_buffer.entity_id())
|
||||
{
|
||||
let now = cx.background_executor().now();
|
||||
buffer.pending_predictions.push(PendingSettledPrediction {
|
||||
request_id: request_id,
|
||||
editable_anchor_range: edited_buffer_snapshot
|
||||
.anchor_range_inside(editable_offset_range),
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
let editable_region_before_prediction = edited_buffer_snapshot
|
||||
.text_for_range(editable_offset_range.clone())
|
||||
.collect::<String>();
|
||||
let editable_anchor_range_for_result =
|
||||
edited_buffer_snapshot.anchor_range_inside(editable_offset_range.clone());
|
||||
let predicted_editable_region = edit_preview
|
||||
.result_text_snapshot()
|
||||
.text_for_range(editable_anchor_range_for_result.clone())
|
||||
.collect();
|
||||
let ts_error_count_before_prediction = crate::metrics::count_tree_sitter_errors(
|
||||
edited_buffer_snapshot
|
||||
.syntax_layers_for_range(editable_anchor_range_for_result.clone(), true),
|
||||
);
|
||||
let ts_error_count_after_prediction = crate::metrics::count_tree_sitter_errors(
|
||||
edit_preview.result_syntax_snapshot().layers_for_range(
|
||||
editable_anchor_range_for_result,
|
||||
edit_preview.result_text_snapshot(),
|
||||
true,
|
||||
),
|
||||
);
|
||||
let editable_anchor_range =
|
||||
edited_buffer_snapshot.anchor_range_inside(editable_offset_range);
|
||||
let now = cx.background_executor().now();
|
||||
registered_buffer
|
||||
.pending_predictions
|
||||
.push(PendingSettledPrediction {
|
||||
request_id,
|
||||
editable_anchor_range,
|
||||
editable_region_before_prediction,
|
||||
predicted_editable_region,
|
||||
ts_error_count_before_prediction,
|
||||
ts_error_count_after_prediction,
|
||||
example,
|
||||
e2e_latency,
|
||||
enqueued_at: now,
|
||||
last_edit_at: now,
|
||||
});
|
||||
this.settled_predictions_tx.unbounded_send(now).ok();
|
||||
}
|
||||
this.settled_predictions_tx.unbounded_send(now).ok();
|
||||
}
|
||||
|
||||
fn reject_current_prediction(
|
||||
|
|
|
|||
|
|
@ -3252,6 +3252,12 @@ async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
|
|||
cx.run_until_parked();
|
||||
|
||||
let snapshot_a = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
let empty_edits: Arc<[(Range<Anchor>, Arc<str>)]> = Vec::new().into();
|
||||
let edit_preview_a = buffer
|
||||
.read_with(cx, |buffer, cx| {
|
||||
buffer.preview_edits(empty_edits.clone(), cx)
|
||||
})
|
||||
.await;
|
||||
|
||||
// Region A: first 10 lines of the buffer.
|
||||
let editable_region_a = 0..snapshot_a.point_to_offset(Point::new(10, 0));
|
||||
|
|
@ -3263,6 +3269,7 @@ async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
|
|||
&buffer,
|
||||
&snapshot_a,
|
||||
editable_region_a.clone(),
|
||||
&edit_preview_a,
|
||||
None,
|
||||
Duration::from_secs(0),
|
||||
cx,
|
||||
|
|
@ -3318,6 +3325,9 @@ async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
|
|||
cx.run_until_parked();
|
||||
|
||||
let snapshot_b2 = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
let edit_preview_b = buffer
|
||||
.read_with(cx, |buffer, cx| buffer.preview_edits(empty_edits, cx))
|
||||
.await;
|
||||
let editable_region_b = line_20_offset..snapshot_b2.point_to_offset(Point::new(25, 0));
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
|
|
@ -3327,6 +3337,7 @@ async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
|
|||
&buffer,
|
||||
&snapshot_b2,
|
||||
editable_region_b.clone(),
|
||||
&edit_preview_b,
|
||||
None,
|
||||
Duration::from_secs(0),
|
||||
cx,
|
||||
|
|
|
|||
10
crates/edit_prediction/src/metrics.rs
Normal file
10
crates/edit_prediction/src/metrics.rs
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
mod kept_rate;
|
||||
mod tokenize;
|
||||
mod tree_sitter;
|
||||
|
||||
pub use kept_rate::KeptRateResult;
|
||||
#[cfg(test)]
|
||||
pub use kept_rate::TokenAnnotation;
|
||||
pub use kept_rate::compute_kept_rate;
|
||||
pub(crate) use tokenize::tokenize;
|
||||
pub use tree_sitter::count_tree_sitter_errors;
|
||||
|
|
@ -1,4 +1,6 @@
|
|||
use crate::word_diff::tokenize;
|
||||
use crate::metrics::tokenize;
|
||||
|
||||
const MAX_DIRTY_LENGTH_DELTA_CHARS: usize = 512;
|
||||
|
||||
#[cfg(test)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
|
|
@ -25,20 +27,28 @@ fn dp_index(width: usize, row: usize, column: usize) -> usize {
|
|||
row * width + column
|
||||
}
|
||||
|
||||
/// Return masks over `a` and `b` using one-sided LCS tie-breaking for each
|
||||
/// side while sharing a single DP table construction.
|
||||
fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
|
||||
/// Fill masks over `a` and `b` using one-sided LCS tie-breaking for each side
|
||||
/// while sharing a single DP table construction.
|
||||
fn fill_lcs_keep_masks(
|
||||
a: &[&str],
|
||||
b: &[&str],
|
||||
mut keep_a: Option<&mut [bool]>,
|
||||
mut keep_b: Option<&mut [bool]>,
|
||||
) {
|
||||
if a.is_empty() || b.is_empty() {
|
||||
return (vec![false; a.len()], vec![false; b.len()]);
|
||||
return;
|
||||
}
|
||||
|
||||
if a == b {
|
||||
return (vec![true; a.len()], vec![true; b.len()]);
|
||||
if let Some(keep_a) = keep_a.as_mut() {
|
||||
keep_a.fill(true);
|
||||
}
|
||||
if let Some(keep_b) = keep_b.as_mut() {
|
||||
keep_b.fill(true);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
let mut keep_a = vec![false; a.len()];
|
||||
let mut keep_b = vec![false; b.len()];
|
||||
|
||||
let prefix_len = a
|
||||
.iter()
|
||||
.zip(b.iter())
|
||||
|
|
@ -61,22 +71,30 @@ fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
|
|||
};
|
||||
|
||||
for index in 0..prefix_len {
|
||||
keep_a[index] = true;
|
||||
keep_b[index] = true;
|
||||
if let Some(keep_a) = keep_a.as_mut() {
|
||||
keep_a[index] = true;
|
||||
}
|
||||
if let Some(keep_b) = keep_b.as_mut() {
|
||||
keep_b[index] = true;
|
||||
}
|
||||
}
|
||||
|
||||
for offset in 0..suffix_len {
|
||||
let a_index = a.len() - suffix_len + offset;
|
||||
let b_index = b.len() - suffix_len + offset;
|
||||
keep_a[a_index] = true;
|
||||
keep_b[b_index] = true;
|
||||
if let Some(keep_a) = keep_a.as_mut() {
|
||||
keep_a[a_index] = true;
|
||||
}
|
||||
if let Some(keep_b) = keep_b.as_mut() {
|
||||
keep_b[b_index] = true;
|
||||
}
|
||||
}
|
||||
|
||||
let a_mid = &a[prefix_len..a.len() - suffix_len];
|
||||
let b_mid = &b[prefix_len..b.len() - suffix_len];
|
||||
|
||||
if a_mid.is_empty() || b_mid.is_empty() {
|
||||
return (keep_a, keep_b);
|
||||
return;
|
||||
}
|
||||
|
||||
let row_count = a_mid.len() + 1;
|
||||
|
|
@ -97,44 +115,59 @@ fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
|
|||
}
|
||||
}
|
||||
|
||||
let mut i = a_mid.len();
|
||||
let mut j = b_mid.len();
|
||||
if let Some(keep_a) = keep_a.as_mut() {
|
||||
let mut i = a_mid.len();
|
||||
let mut j = b_mid.len();
|
||||
|
||||
while i > 0 && j > 0 {
|
||||
if a_mid[i - 1] == b_mid[j - 1] {
|
||||
keep_a[prefix_len + i - 1] = true;
|
||||
i -= 1;
|
||||
j -= 1;
|
||||
} else {
|
||||
let up = dp[dp_index(column_count, i - 1, j)];
|
||||
let left = dp[dp_index(column_count, i, j - 1)];
|
||||
if up >= left {
|
||||
while i > 0 && j > 0 {
|
||||
if a_mid[i - 1] == b_mid[j - 1] {
|
||||
keep_a[prefix_len + i - 1] = true;
|
||||
i -= 1;
|
||||
} else {
|
||||
j -= 1;
|
||||
} else {
|
||||
let up = dp[dp_index(column_count, i - 1, j)];
|
||||
let left = dp[dp_index(column_count, i, j - 1)];
|
||||
if up >= left {
|
||||
i -= 1;
|
||||
} else {
|
||||
j -= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut i = a_mid.len();
|
||||
let mut j = b_mid.len();
|
||||
if let Some(keep_b) = keep_b.as_mut() {
|
||||
let mut i = a_mid.len();
|
||||
let mut j = b_mid.len();
|
||||
|
||||
while i > 0 && j > 0 {
|
||||
if a_mid[i - 1] == b_mid[j - 1] {
|
||||
keep_b[prefix_len + j - 1] = true;
|
||||
i -= 1;
|
||||
j -= 1;
|
||||
} else {
|
||||
let up = dp[dp_index(column_count, i - 1, j)];
|
||||
let left = dp[dp_index(column_count, i, j - 1)];
|
||||
if left >= up {
|
||||
while i > 0 && j > 0 {
|
||||
if a_mid[i - 1] == b_mid[j - 1] {
|
||||
keep_b[prefix_len + j - 1] = true;
|
||||
i -= 1;
|
||||
j -= 1;
|
||||
} else {
|
||||
i -= 1;
|
||||
let up = dp[dp_index(column_count, i - 1, j)];
|
||||
let left = dp[dp_index(column_count, i, j - 1)];
|
||||
if left >= up {
|
||||
j -= 1;
|
||||
} else {
|
||||
i -= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn lcs_keep_mask(a: &[&str], b: &[&str]) -> Vec<bool> {
|
||||
let mut keep_a = vec![false; a.len()];
|
||||
fill_lcs_keep_masks(a, b, Some(&mut keep_a), None);
|
||||
keep_a
|
||||
}
|
||||
|
||||
fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
|
||||
let mut keep_a = vec![false; a.len()];
|
||||
let mut keep_b = vec![false; b.len()];
|
||||
fill_lcs_keep_masks(a, b, Some(&mut keep_a), Some(&mut keep_b));
|
||||
(keep_a, keep_b)
|
||||
}
|
||||
|
||||
|
|
@ -155,6 +188,12 @@ fn analyze_masked_tokens<'a>(tokens: &[&'a str], mask: &[bool]) -> (Vec<&'a str>
|
|||
(unmasked_tokens, unmasked_chars, masked_chars)
|
||||
}
|
||||
|
||||
fn should_bail_for_dirty_final(base: &str, predicted: &str, final_text: &str) -> bool {
|
||||
let predicted_delta_chars = predicted.len().abs_diff(base.len());
|
||||
let final_delta_chars = final_text.len().abs_diff(base.len());
|
||||
predicted_delta_chars.abs_diff(final_delta_chars) > MAX_DIRTY_LENGTH_DELTA_CHARS
|
||||
}
|
||||
|
||||
pub fn compute_kept_rate(base: &str, predicted: &str, final_text: &str) -> KeptRateResult {
|
||||
if base == predicted && predicted == final_text {
|
||||
let predicted_tokens = tokenize(predicted);
|
||||
|
|
@ -171,11 +210,26 @@ pub fn compute_kept_rate(base: &str, predicted: &str, final_text: &str) -> KeptR
|
|||
};
|
||||
}
|
||||
|
||||
if should_bail_for_dirty_final(base, predicted, final_text) {
|
||||
let predicted_new_chars = predicted.len().abs_diff(base.len());
|
||||
let final_new_chars = final_text.len().abs_diff(base.len());
|
||||
return KeptRateResult {
|
||||
predicted_new_chars,
|
||||
final_new_chars,
|
||||
kept_chars: 0,
|
||||
discarded_chars: predicted_new_chars,
|
||||
context_chars: 0,
|
||||
kept_rate: 0.0,
|
||||
#[cfg(test)]
|
||||
token_annotations: vec![TokenAnnotation::Discarded; tokenize(predicted).len()],
|
||||
};
|
||||
}
|
||||
|
||||
let base_tokens = tokenize(base);
|
||||
let predicted_tokens = tokenize(predicted);
|
||||
let final_tokens = tokenize(final_text);
|
||||
|
||||
let (pred_base_mask, _) = lcs_keep_masks(&predicted_tokens, &base_tokens);
|
||||
let pred_base_mask = lcs_keep_mask(&predicted_tokens, &base_tokens);
|
||||
let (pred_final_mask, final_pred_mask) = lcs_keep_masks(&predicted_tokens, &final_tokens);
|
||||
let context_mask: Vec<bool> = pred_base_mask
|
||||
.iter()
|
||||
|
|
@ -186,7 +240,7 @@ pub fn compute_kept_rate(base: &str, predicted: &str, final_text: &str) -> KeptR
|
|||
let (stripped_predicted, predicted_new_chars, context_chars) =
|
||||
analyze_masked_tokens(&predicted_tokens, &context_mask);
|
||||
|
||||
let (final_base_mask, _) = lcs_keep_masks(&final_tokens, &base_tokens);
|
||||
let final_base_mask = lcs_keep_mask(&final_tokens, &base_tokens);
|
||||
let final_context_mask: Vec<bool> = final_base_mask
|
||||
.iter()
|
||||
.zip(final_pred_mask.iter())
|
||||
|
|
@ -196,7 +250,7 @@ pub fn compute_kept_rate(base: &str, predicted: &str, final_text: &str) -> KeptR
|
|||
let (stripped_final, final_new_chars, _) =
|
||||
analyze_masked_tokens(&final_tokens, &final_context_mask);
|
||||
|
||||
let keep_mask = lcs_keep_masks(&stripped_predicted, &stripped_final).0;
|
||||
let keep_mask = lcs_keep_mask(&stripped_predicted, &stripped_final);
|
||||
|
||||
let kept_chars: usize = stripped_predicted
|
||||
.iter()
|
||||
|
|
@ -265,8 +319,8 @@ mod test_kept_rate {
|
|||
let a = ["x", "a", "x", "b"];
|
||||
let b = ["a", "x", "b", "x"];
|
||||
let (a_mask, b_mask) = lcs_keep_masks(&a, &b);
|
||||
assert_eq!(a_mask, lcs_keep_masks(&a, &b).0);
|
||||
assert_eq!(b_mask, lcs_keep_masks(&b, &a).0);
|
||||
assert_eq!(a_mask, lcs_keep_mask(&a, &b));
|
||||
assert_eq!(b_mask, lcs_keep_mask(&b, &a));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -342,6 +396,21 @@ mod test_kept_rate {
|
|||
assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
|
||||
}
|
||||
|
||||
#[test]
fn test_bails_for_dirty_final() {
    let base = "fn example() {\n work();\n}\n";
    let predicted = "fn example() {\n work();\n predicted();\n}\n";

    // Pad the final text with enough repeated lines to exceed the dirty-length threshold.
    let settled_lines = "settled();\n ".repeat(MAX_DIRTY_LENGTH_DELTA_CHARS / 8 + 64);
    let final_text = format!("fn example() {{\n work();\n {}\n}}\n", settled_lines);

    let outcome = compute_kept_rate(base, predicted, &final_text);

    // On bail-out nothing is kept and every predicted char counts as discarded.
    assert_eq!(outcome.kept_rate, 0.0);
    assert_eq!(outcome.kept_chars, 0);
    assert_eq!(outcome.discarded_chars, outcome.predicted_new_chars);
}
|
||||
|
||||
#[test]
|
||||
fn test_eprintln_token_alignment() {
|
||||
let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\n";
|
||||
54
crates/edit_prediction/src/metrics/tokenize.rs
Normal file
54
crates/edit_prediction/src/metrics/tokenize.rs
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
/// Maps a character to its token class:
/// `0` for alphanumerics and `_`, `1` for whitespace, `2` for everything else.
fn char_class(character: char) -> u8 {
    match character {
        c if c.is_alphanumeric() || c == '_' => 0,
        c if c.is_whitespace() => 1,
        _ => 2,
    }
}
|
||||
|
||||
pub(crate) fn tokenize(text: &str) -> Vec<&str> {
|
||||
let mut tokens = Vec::new();
|
||||
let mut characters = text.char_indices().peekable();
|
||||
|
||||
while let Some((start, character)) = characters.next() {
|
||||
let class = char_class(character);
|
||||
if class == 2 {
|
||||
tokens.push(&text[start..start + character.len_utf8()]);
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut end = start + character.len_utf8();
|
||||
while let Some(&(_, next_character)) = characters.peek() {
|
||||
if char_class(next_character) != class {
|
||||
break;
|
||||
}
|
||||
end += next_character.len_utf8();
|
||||
characters.next();
|
||||
}
|
||||
tokens.push(&text[start..end]);
|
||||
}
|
||||
|
||||
tokens
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::tokenize;

    // Table of (input, expected tokens) pairs covering words, whitespace,
    // underscores, digits and punctuation.
    #[test]
    fn tokenizes_code_like_text() {
        let cases: [(&str, &[&str]); 5] = [
            ("hello world", &["hello", " ", "world"]),
            ("foo_bar123 + baz", &["foo_bar123", " ", "+", " ", "baz"]),
            ("print(\"hello\")", &["print", "(", "\"", "hello", "\"", ")"]),
            ("hello_world", &["hello_world"]),
            ("fn();", &["fn", "(", ")", ";"]),
        ];

        for (input, expected) in cases {
            assert_eq!(tokenize(input), expected);
        }
    }
}
|
||||
88
crates/edit_prediction/src/metrics/tree_sitter.rs
Normal file
88
crates/edit_prediction/src/metrics/tree_sitter.rs
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
use language::SyntaxLayer;
|
||||
|
||||
pub fn count_tree_sitter_errors<'a>(layers: impl Iterator<Item = SyntaxLayer<'a>>) -> usize {
|
||||
let mut total_count: usize = 0;
|
||||
for layer in layers {
|
||||
let node = layer.node();
|
||||
let mut cursor = node.walk();
|
||||
'layer: loop {
|
||||
let current = cursor.node();
|
||||
if current.is_error() || current.is_missing() {
|
||||
total_count += 1;
|
||||
}
|
||||
if current.has_error() && cursor.goto_first_child() {
|
||||
continue;
|
||||
}
|
||||
if cursor.goto_next_sibling() {
|
||||
continue;
|
||||
}
|
||||
loop {
|
||||
if !cursor.goto_parent() {
|
||||
break 'layer;
|
||||
}
|
||||
if cursor.goto_next_sibling() {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
total_count
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::ops::Range;

    use super::count_tree_sitter_errors;
    use gpui::{AppContext as _, TestAppContext};
    use language::{Buffer, BufferSnapshot, rust_lang};

    // Counts errors in the syntax layers that `syntax_layers_for_range` yields for `range`.
    fn error_count_in_range(snapshot: &BufferSnapshot, range: Range<usize>) -> usize {
        count_tree_sitter_errors(snapshot.syntax_layers_for_range(range, true))
    }

    // Builds a Rust buffer from `text`, waits for parsing to finish, and snapshots it.
    fn rust_snapshot(text: &str, cx: &mut TestAppContext) -> BufferSnapshot {
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        loop {
            let still_parsing = buffer.read_with(cx, |buffer, _| buffer.is_parsing());
            if !still_parsing {
                break;
            }
            cx.run_until_parked();
        }
        buffer.read_with(cx, |buffer, _| buffer.snapshot())
    }

    #[gpui::test]
    async fn counts_no_errors_for_valid_rust(cx: &mut TestAppContext) {
        let source = "fn helper(value: usize) -> usize {\n value + 1\n}\n";
        let snapshot = rust_snapshot(source, cx);
        let whole_buffer = 0..snapshot.text.len();

        assert_eq!(error_count_in_range(&snapshot, whole_buffer), 0);
    }

    #[gpui::test]
    async fn counts_errors_for_invalid_rust(cx: &mut TestAppContext) {
        let source = "fn helper(value: usize) -> usize {\n let total = ;\n total\n}\n";
        let snapshot = rust_snapshot(source, cx);
        let whole_buffer = 0..snapshot.text.len();

        assert_eq!(error_count_in_range(&snapshot, whole_buffer), 1);
    }

    #[gpui::test]
    async fn counts_no_errors_for_subrange_of_valid_rust(cx: &mut TestAppContext) {
        let source = "fn first() -> usize {\n let value = 1;\n value + 1\n}\n";
        let snapshot = rust_snapshot(source, cx);
        let statement_start = source.find("let value").unwrap();
        let statement = statement_start..statement_start + "let value = 1;".len();

        assert_eq!(error_count_in_range(&snapshot, statement), 0);
    }

    #[gpui::test]
    async fn counts_errors_for_subrange_of_invalid_rust(cx: &mut TestAppContext) {
        let source = "fn second() -> usize {\n let broken = ;\n broken\n}\n";
        let snapshot = rust_snapshot(source, cx);
        let broken_statement = "let broken = ;";
        let statement_start = source.find(broken_statement).unwrap();
        let statement = statement_start..statement_start + broken_statement.len();

        assert_eq!(error_count_in_range(&snapshot, statement), 1);
    }
}
|
||||
|
|
@ -85,7 +85,6 @@ pub fn request_prediction_with_zeta(
|
|||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let client = store.client.clone();
|
||||
let llm_token = store.llm_token.clone();
|
||||
let organization_id = store
|
||||
|
|
@ -383,11 +382,26 @@ pub fn request_prediction_with_zeta(
|
|||
}));
|
||||
};
|
||||
|
||||
if can_collect_data {
|
||||
let result = EditPredictionResult::new(
|
||||
id,
|
||||
&edited_buffer,
|
||||
&edited_buffer_snapshot,
|
||||
edits.into(),
|
||||
cursor_position,
|
||||
inputs,
|
||||
model_version,
|
||||
request_duration,
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
|
||||
if can_collect_data && let Ok(prediction) = &result.prediction {
|
||||
let weak_this = this.clone();
|
||||
let id = id.clone();
|
||||
let request_id = prediction.id.clone();
|
||||
let edited_buffer = edited_buffer.clone();
|
||||
let edited_buffer_snapshot = edited_buffer_snapshot.clone();
|
||||
let editable_range_in_buffer = editable_range_in_buffer.clone();
|
||||
let edit_preview = prediction.edit_preview.clone();
|
||||
let example_task = capture_data.and_then(|stored_events| {
|
||||
cx.update(|cx| {
|
||||
crate::capture_example(
|
||||
|
|
@ -410,11 +424,12 @@ pub fn request_prediction_with_zeta(
|
|||
weak_this
|
||||
.update(cx, |this, cx| {
|
||||
this.enqueue_settled_prediction(
|
||||
id.clone(),
|
||||
request_id.clone(),
|
||||
&project,
|
||||
&edited_buffer,
|
||||
&edited_buffer_snapshot,
|
||||
editable_range_in_buffer,
|
||||
&edit_preview,
|
||||
example_spec,
|
||||
request_duration,
|
||||
cx,
|
||||
|
|
@ -425,20 +440,7 @@ pub fn request_prediction_with_zeta(
|
|||
.detach();
|
||||
}
|
||||
|
||||
Ok(Some(
|
||||
EditPredictionResult::new(
|
||||
id,
|
||||
&edited_buffer,
|
||||
&edited_buffer_snapshot,
|
||||
edits.into(),
|
||||
cursor_position,
|
||||
inputs,
|
||||
model_version,
|
||||
request_duration,
|
||||
cx,
|
||||
)
|
||||
.await,
|
||||
))
|
||||
Ok(Some(result))
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -83,7 +83,6 @@ dynamic_prompts = []
|
|||
ignored = ["wasmtime"]
|
||||
|
||||
[dev-dependencies]
|
||||
criterion.workspace = true
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
indoc.workspace = true
|
||||
pretty_assertions.workspace = true
|
||||
|
|
@ -91,6 +90,3 @@ project = { workspace = true, features = ["test-support"] }
|
|||
tempfile.workspace = true
|
||||
workspace = { workspace = true, features = ["test-support"] }
|
||||
|
||||
[[bench]]
|
||||
name = "kept_rate"
|
||||
harness = false
|
||||
|
|
|
|||
|
|
@ -1,4 +1,2 @@
|
|||
#[allow(dead_code)]
|
||||
mod word_diff;
|
||||
|
||||
pub mod kept_rate;
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ mod filter_languages;
|
|||
mod format_prompt;
|
||||
mod git;
|
||||
mod headless;
|
||||
mod kept_rate;
|
||||
|
||||
mod load_project;
|
||||
mod metrics;
|
||||
mod openai_client;
|
||||
|
|
|
|||
|
|
@ -1298,4 +1298,4 @@ index abc123..def456 100644
|
|||
}
|
||||
}
|
||||
|
||||
pub use crate::kept_rate::compute_kept_rate;
|
||||
pub use edit_prediction::metrics::compute_kept_rate;
|
||||
|
|
|
|||
|
|
@ -878,6 +878,14 @@ impl EditPreview {
|
|||
})
|
||||
}
|
||||
|
||||
/// Borrows the text snapshot stored in `applied_edits_snapshot`.
pub fn result_text_snapshot(&self) -> &text::BufferSnapshot {
    let applied = &self.applied_edits_snapshot;
    applied
}
|
||||
|
||||
/// Borrows the syntax snapshot stored in `syntax_snapshot`.
pub fn result_syntax_snapshot(&self) -> &SyntaxSnapshot {
    let syntax = &self.syntax_snapshot;
    syntax
}
|
||||
|
||||
pub fn anchor_to_offset_in_result(&self, anchor: Anchor) -> usize {
|
||||
anchor
|
||||
.bias_right(&self.old_snapshot)
|
||||
|
|
|
|||
Loading…
Reference in a new issue