mirror of
https://github.com/zed-industries/zed.git
synced 2026-06-01 03:14:56 +07:00
edit_prediction_context: Minor optimization of text similarity + some renames (#38941)
Release Notes: - N/A
This commit is contained in:
parent
bcc8149263
commit
da71465437
10 changed files with 198 additions and 176 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
|
@ -5171,6 +5171,7 @@ dependencies = [
|
|||
"collections",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"hashbrown 0.15.3",
|
||||
"indoc",
|
||||
"itertools 0.14.0",
|
||||
"language",
|
||||
|
|
|
|||
|
|
@ -511,6 +511,7 @@ futures-lite = "1.13"
|
|||
git2 = { version = "0.20.1", default-features = false }
|
||||
globset = "0.4"
|
||||
handlebars = "4.3"
|
||||
hashbrown = "0.15.3"
|
||||
heck = "0.5"
|
||||
heed = { version = "0.21.0", features = ["read-txn-no-tls"] }
|
||||
hex = "0.4.3"
|
||||
|
|
|
|||
|
|
@ -103,13 +103,13 @@ pub struct ReferencedDeclaration {
|
|||
/// Index within `signatures`.
|
||||
#[serde(skip_serializing_if = "Option::is_none", default)]
|
||||
pub parent_index: Option<usize>,
|
||||
pub score_components: ScoreComponents,
|
||||
pub score_components: DeclarationScoreComponents,
|
||||
pub signature_score: f32,
|
||||
pub declaration_score: f32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ScoreComponents {
|
||||
pub struct DeclarationScoreComponents {
|
||||
pub is_same_file: bool,
|
||||
pub is_referenced_nearby: bool,
|
||||
pub is_referenced_in_breadcrumb: bool,
|
||||
|
|
@ -119,12 +119,12 @@ pub struct ScoreComponents {
|
|||
pub reference_line_distance: u32,
|
||||
pub declaration_line_distance: u32,
|
||||
pub declaration_line_distance_rank: usize,
|
||||
pub containing_range_vs_item_jaccard: f32,
|
||||
pub containing_range_vs_signature_jaccard: f32,
|
||||
pub excerpt_vs_item_jaccard: f32,
|
||||
pub excerpt_vs_signature_jaccard: f32,
|
||||
pub adjacent_vs_item_jaccard: f32,
|
||||
pub adjacent_vs_signature_jaccard: f32,
|
||||
pub containing_range_vs_item_weighted_overlap: f32,
|
||||
pub containing_range_vs_signature_weighted_overlap: f32,
|
||||
pub excerpt_vs_item_weighted_overlap: f32,
|
||||
pub excerpt_vs_signature_weighted_overlap: f32,
|
||||
pub adjacent_vs_item_weighted_overlap: f32,
|
||||
pub adjacent_vs_signature_weighted_overlap: f32,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ pub struct PlannedSnippet<'a> {
|
|||
}
|
||||
|
||||
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
|
||||
pub enum SnippetStyle {
|
||||
pub enum DeclarationStyle {
|
||||
Signature,
|
||||
Declaration,
|
||||
}
|
||||
|
|
@ -84,10 +84,10 @@ pub struct SectionLabels {
|
|||
impl<'a> PlannedPrompt<'a> {
|
||||
/// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following:
|
||||
///
|
||||
/// Initializes a priority queue by populating it with each snippet, finding the SnippetStyle
|
||||
/// that minimizes `score_density = score / snippet.range(style).len()`. When a "signature"
|
||||
/// snippet is popped, insert an entry for the "declaration" variant that reflects the cost of
|
||||
/// upgrade.
|
||||
/// Initializes a priority queue by populating it with each snippet, finding the
|
||||
/// DeclarationStyle that minimizes `score_density = score / snippet.range(style).len()`. When a
|
||||
/// "signature" snippet is popped, insert an entry for the "declaration" variant that reflects
|
||||
/// the cost of upgrade.
|
||||
///
|
||||
/// TODO: Implement an early halting condition. One option might be to have another priority
|
||||
/// queue where the score is the size, and update it accordingly. Another option might be to
|
||||
|
|
@ -131,13 +131,13 @@ impl<'a> PlannedPrompt<'a> {
|
|||
struct QueueEntry {
|
||||
score_density: OrderedFloat<f32>,
|
||||
declaration_index: usize,
|
||||
style: SnippetStyle,
|
||||
style: DeclarationStyle,
|
||||
}
|
||||
|
||||
// Initialize priority queue with the best score for each snippet.
|
||||
let mut queue: BinaryHeap<QueueEntry> = BinaryHeap::new();
|
||||
for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() {
|
||||
let (style, score_density) = SnippetStyle::iter()
|
||||
let (style, score_density) = DeclarationStyle::iter()
|
||||
.map(|style| {
|
||||
(
|
||||
style,
|
||||
|
|
@ -186,7 +186,7 @@ impl<'a> PlannedPrompt<'a> {
|
|||
this.budget_used += additional_bytes;
|
||||
this.add_parents(&mut included_parents, additional_parents);
|
||||
let planned_snippet = match queue_entry.style {
|
||||
SnippetStyle::Signature => {
|
||||
DeclarationStyle::Signature => {
|
||||
let Some(text) = declaration.text.get(declaration.signature_range.clone())
|
||||
else {
|
||||
return Err(anyhow!(
|
||||
|
|
@ -203,7 +203,7 @@ impl<'a> PlannedPrompt<'a> {
|
|||
text_is_truncated: declaration.text_is_truncated,
|
||||
}
|
||||
}
|
||||
SnippetStyle::Declaration => PlannedSnippet {
|
||||
DeclarationStyle::Declaration => PlannedSnippet {
|
||||
path: declaration.path.clone(),
|
||||
range: declaration.range.clone(),
|
||||
text: &declaration.text,
|
||||
|
|
@ -213,11 +213,13 @@ impl<'a> PlannedPrompt<'a> {
|
|||
this.snippets.push(planned_snippet);
|
||||
|
||||
// When a Signature is consumed, insert an entry for Definition style.
|
||||
if queue_entry.style == SnippetStyle::Signature {
|
||||
let signature_size = declaration_size(&declaration, SnippetStyle::Signature);
|
||||
let declaration_size = declaration_size(&declaration, SnippetStyle::Declaration);
|
||||
let signature_score = declaration_score(&declaration, SnippetStyle::Signature);
|
||||
let declaration_score = declaration_score(&declaration, SnippetStyle::Declaration);
|
||||
if queue_entry.style == DeclarationStyle::Signature {
|
||||
let signature_size = declaration_size(&declaration, DeclarationStyle::Signature);
|
||||
let declaration_size =
|
||||
declaration_size(&declaration, DeclarationStyle::Declaration);
|
||||
let signature_score = declaration_score(&declaration, DeclarationStyle::Signature);
|
||||
let declaration_score =
|
||||
declaration_score(&declaration, DeclarationStyle::Declaration);
|
||||
|
||||
let score_diff = declaration_score - signature_score;
|
||||
let size_diff = declaration_size.saturating_sub(signature_size);
|
||||
|
|
@ -225,7 +227,7 @@ impl<'a> PlannedPrompt<'a> {
|
|||
queue.push(QueueEntry {
|
||||
declaration_index: queue_entry.declaration_index,
|
||||
score_density: OrderedFloat(score_diff / (size_diff as f32)),
|
||||
style: SnippetStyle::Declaration,
|
||||
style: DeclarationStyle::Declaration,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
@ -510,20 +512,20 @@ impl<'a> PlannedPrompt<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
fn declaration_score_density(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 {
|
||||
fn declaration_score_density(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
|
||||
declaration_score(declaration, style) / declaration_size(declaration, style) as f32
|
||||
}
|
||||
|
||||
fn declaration_score(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 {
|
||||
fn declaration_score(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
|
||||
match style {
|
||||
SnippetStyle::Signature => declaration.signature_score,
|
||||
SnippetStyle::Declaration => declaration.declaration_score,
|
||||
DeclarationStyle::Signature => declaration.signature_score,
|
||||
DeclarationStyle::Declaration => declaration.declaration_score,
|
||||
}
|
||||
}
|
||||
|
||||
fn declaration_size(declaration: &ReferencedDeclaration, style: SnippetStyle) -> usize {
|
||||
fn declaration_size(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> usize {
|
||||
match style {
|
||||
SnippetStyle::Signature => declaration.signature_range.len(),
|
||||
SnippetStyle::Declaration => declaration.text.len(),
|
||||
DeclarationStyle::Signature => declaration.signature_range.len(),
|
||||
DeclarationStyle::Declaration => declaration.text.len(),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ cloud_llm_client.workspace = true
|
|||
collections.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
hashbrown.workspace = true
|
||||
itertools.workspace = true
|
||||
language.workspace = true
|
||||
log.workspace = true
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use cloud_llm_client::predict_edits_v3::ScoreComponents;
|
||||
use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents;
|
||||
use itertools::Itertools as _;
|
||||
use language::BufferSnapshot;
|
||||
use ordered_float::OrderedFloat;
|
||||
|
|
@ -8,76 +8,67 @@ use strum::EnumIter;
|
|||
use text::{Point, ToPoint};
|
||||
|
||||
use crate::{
|
||||
Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier,
|
||||
Declaration, EditPredictionExcerpt, Identifier,
|
||||
reference::{Reference, ReferenceRegion},
|
||||
syntax_index::SyntaxIndexState,
|
||||
text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient},
|
||||
text_similarity::{Occurrences, jaccard_similarity, weighted_overlap_coefficient},
|
||||
};
|
||||
|
||||
const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ScoredSnippet {
|
||||
pub struct ScoredDeclaration {
|
||||
pub identifier: Identifier,
|
||||
pub declaration: Declaration,
|
||||
pub score_components: ScoreComponents,
|
||||
pub scores: Scores,
|
||||
pub score_components: DeclarationScoreComponents,
|
||||
pub scores: DeclarationScores,
|
||||
}
|
||||
|
||||
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
||||
pub enum SnippetStyle {
|
||||
pub enum DeclarationStyle {
|
||||
Signature,
|
||||
Declaration,
|
||||
}
|
||||
|
||||
impl ScoredSnippet {
|
||||
/// Returns the score for this snippet with the specified style.
|
||||
pub fn score(&self, style: SnippetStyle) -> f32 {
|
||||
impl ScoredDeclaration {
|
||||
/// Returns the score for this declaration with the specified style.
|
||||
pub fn score(&self, style: DeclarationStyle) -> f32 {
|
||||
match style {
|
||||
SnippetStyle::Signature => self.scores.signature,
|
||||
SnippetStyle::Declaration => self.scores.declaration,
|
||||
DeclarationStyle::Signature => self.scores.signature,
|
||||
DeclarationStyle::Declaration => self.scores.declaration,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn size(&self, style: SnippetStyle) -> usize {
|
||||
pub fn size(&self, style: DeclarationStyle) -> usize {
|
||||
match &self.declaration {
|
||||
Declaration::File { declaration, .. } => match style {
|
||||
SnippetStyle::Signature => declaration.signature_range.len(),
|
||||
SnippetStyle::Declaration => declaration.text.len(),
|
||||
DeclarationStyle::Signature => declaration.signature_range.len(),
|
||||
DeclarationStyle::Declaration => declaration.text.len(),
|
||||
},
|
||||
Declaration::Buffer { declaration, .. } => match style {
|
||||
SnippetStyle::Signature => declaration.signature_range.len(),
|
||||
SnippetStyle::Declaration => declaration.item_range.len(),
|
||||
DeclarationStyle::Signature => declaration.signature_range.len(),
|
||||
DeclarationStyle::Declaration => declaration.item_range.len(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn score_density(&self, style: SnippetStyle) -> f32 {
|
||||
pub fn score_density(&self, style: DeclarationStyle) -> f32 {
|
||||
self.score(style) / (self.size(style)) as f32
|
||||
}
|
||||
}
|
||||
|
||||
pub fn scored_snippets(
|
||||
pub fn scored_declarations(
|
||||
index: &SyntaxIndexState,
|
||||
excerpt: &EditPredictionExcerpt,
|
||||
excerpt_text: &EditPredictionExcerptText,
|
||||
excerpt_occurrences: &Occurrences,
|
||||
adjacent_occurrences: &Occurrences,
|
||||
identifier_to_references: HashMap<Identifier, Vec<Reference>>,
|
||||
cursor_offset: usize,
|
||||
current_buffer: &BufferSnapshot,
|
||||
) -> Vec<ScoredSnippet> {
|
||||
let containing_range_identifier_occurrences =
|
||||
IdentifierOccurrences::within_string(&excerpt_text.body);
|
||||
) -> Vec<ScoredDeclaration> {
|
||||
let cursor_point = cursor_offset.to_point(&current_buffer);
|
||||
|
||||
let start_point = Point::new(cursor_point.row.saturating_sub(2), 0);
|
||||
let end_point = Point::new(cursor_point.row + 1, 0);
|
||||
let adjacent_identifier_occurrences = IdentifierOccurrences::within_string(
|
||||
&current_buffer
|
||||
.text_for_range(start_point..end_point)
|
||||
.collect::<String>(),
|
||||
);
|
||||
|
||||
let mut snippets = identifier_to_references
|
||||
let mut declarations = identifier_to_references
|
||||
.into_iter()
|
||||
.flat_map(|(identifier, references)| {
|
||||
let declarations =
|
||||
|
|
@ -137,7 +128,7 @@ pub fn scored_snippets(
|
|||
)| {
|
||||
let same_file_declaration_count = index.file_declaration_count(declaration);
|
||||
|
||||
score_snippet(
|
||||
score_declaration(
|
||||
&identifier,
|
||||
&references,
|
||||
declaration.clone(),
|
||||
|
|
@ -146,8 +137,8 @@ pub fn scored_snippets(
|
|||
declaration_line_distance_rank,
|
||||
same_file_declaration_count,
|
||||
declaration_count,
|
||||
&containing_range_identifier_occurrences,
|
||||
&adjacent_identifier_occurrences,
|
||||
&excerpt_occurrences,
|
||||
&adjacent_occurrences,
|
||||
cursor_point,
|
||||
current_buffer,
|
||||
)
|
||||
|
|
@ -158,14 +149,14 @@ pub fn scored_snippets(
|
|||
.flatten()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
snippets.sort_unstable_by_key(|snippet| {
|
||||
let score_density = snippet
|
||||
.score_density(SnippetStyle::Declaration)
|
||||
.max(snippet.score_density(SnippetStyle::Signature));
|
||||
declarations.sort_unstable_by_key(|declaration| {
|
||||
let score_density = declaration
|
||||
.score_density(DeclarationStyle::Declaration)
|
||||
.max(declaration.score_density(DeclarationStyle::Signature));
|
||||
Reverse(OrderedFloat(score_density))
|
||||
});
|
||||
|
||||
snippets
|
||||
declarations
|
||||
}
|
||||
|
||||
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
|
||||
|
|
@ -178,7 +169,7 @@ fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Rang
|
|||
}
|
||||
}
|
||||
|
||||
fn score_snippet(
|
||||
fn score_declaration(
|
||||
identifier: &Identifier,
|
||||
references: &[Reference],
|
||||
declaration: Declaration,
|
||||
|
|
@ -187,11 +178,11 @@ fn score_snippet(
|
|||
declaration_line_distance_rank: usize,
|
||||
same_file_declaration_count: usize,
|
||||
declaration_count: usize,
|
||||
containing_range_identifier_occurrences: &IdentifierOccurrences,
|
||||
adjacent_identifier_occurrences: &IdentifierOccurrences,
|
||||
excerpt_occurrences: &Occurrences,
|
||||
adjacent_occurrences: &Occurrences,
|
||||
cursor: Point,
|
||||
current_buffer: &BufferSnapshot,
|
||||
) -> Option<ScoredSnippet> {
|
||||
) -> Option<ScoredDeclaration> {
|
||||
let is_referenced_nearby = references
|
||||
.iter()
|
||||
.any(|r| r.region == ReferenceRegion::Nearby);
|
||||
|
|
@ -208,37 +199,27 @@ fn score_snippet(
|
|||
.min()
|
||||
.unwrap();
|
||||
|
||||
let item_source_occurrences = IdentifierOccurrences::within_string(&declaration.item_text().0);
|
||||
let item_signature_occurrences =
|
||||
IdentifierOccurrences::within_string(&declaration.signature_text().0);
|
||||
let containing_range_vs_item_jaccard = jaccard_similarity(
|
||||
containing_range_identifier_occurrences,
|
||||
&item_source_occurrences,
|
||||
);
|
||||
let containing_range_vs_signature_jaccard = jaccard_similarity(
|
||||
containing_range_identifier_occurrences,
|
||||
&item_signature_occurrences,
|
||||
);
|
||||
let item_source_occurrences = Occurrences::within_string(&declaration.item_text().0);
|
||||
let item_signature_occurrences = Occurrences::within_string(&declaration.signature_text().0);
|
||||
let excerpt_vs_item_jaccard = jaccard_similarity(excerpt_occurrences, &item_source_occurrences);
|
||||
let excerpt_vs_signature_jaccard =
|
||||
jaccard_similarity(excerpt_occurrences, &item_signature_occurrences);
|
||||
let adjacent_vs_item_jaccard =
|
||||
jaccard_similarity(adjacent_identifier_occurrences, &item_source_occurrences);
|
||||
jaccard_similarity(adjacent_occurrences, &item_source_occurrences);
|
||||
let adjacent_vs_signature_jaccard =
|
||||
jaccard_similarity(adjacent_identifier_occurrences, &item_signature_occurrences);
|
||||
jaccard_similarity(adjacent_occurrences, &item_signature_occurrences);
|
||||
|
||||
let containing_range_vs_item_weighted_overlap = weighted_overlap_coefficient(
|
||||
containing_range_identifier_occurrences,
|
||||
&item_source_occurrences,
|
||||
);
|
||||
let containing_range_vs_signature_weighted_overlap = weighted_overlap_coefficient(
|
||||
containing_range_identifier_occurrences,
|
||||
&item_signature_occurrences,
|
||||
);
|
||||
let excerpt_vs_item_weighted_overlap =
|
||||
weighted_overlap_coefficient(excerpt_occurrences, &item_source_occurrences);
|
||||
let excerpt_vs_signature_weighted_overlap =
|
||||
weighted_overlap_coefficient(excerpt_occurrences, &item_signature_occurrences);
|
||||
let adjacent_vs_item_weighted_overlap =
|
||||
weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_source_occurrences);
|
||||
weighted_overlap_coefficient(adjacent_occurrences, &item_source_occurrences);
|
||||
let adjacent_vs_signature_weighted_overlap =
|
||||
weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences);
|
||||
weighted_overlap_coefficient(adjacent_occurrences, &item_signature_occurrences);
|
||||
|
||||
// TODO: Consider adding declaration_file_count
|
||||
let score_components = ScoreComponents {
|
||||
let score_components = DeclarationScoreComponents {
|
||||
is_same_file,
|
||||
is_referenced_nearby,
|
||||
is_referenced_in_breadcrumb,
|
||||
|
|
@ -248,32 +229,32 @@ fn score_snippet(
|
|||
reference_count,
|
||||
same_file_declaration_count,
|
||||
declaration_count,
|
||||
containing_range_vs_item_jaccard,
|
||||
containing_range_vs_signature_jaccard,
|
||||
excerpt_vs_item_jaccard,
|
||||
excerpt_vs_signature_jaccard,
|
||||
adjacent_vs_item_jaccard,
|
||||
adjacent_vs_signature_jaccard,
|
||||
containing_range_vs_item_weighted_overlap,
|
||||
containing_range_vs_signature_weighted_overlap,
|
||||
excerpt_vs_item_weighted_overlap,
|
||||
excerpt_vs_signature_weighted_overlap,
|
||||
adjacent_vs_item_weighted_overlap,
|
||||
adjacent_vs_signature_weighted_overlap,
|
||||
};
|
||||
|
||||
Some(ScoredSnippet {
|
||||
Some(ScoredDeclaration {
|
||||
identifier: identifier.clone(),
|
||||
declaration: declaration,
|
||||
scores: Scores::score(&score_components),
|
||||
scores: DeclarationScores::score(&score_components),
|
||||
score_components,
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize)]
|
||||
pub struct Scores {
|
||||
pub struct DeclarationScores {
|
||||
pub signature: f32,
|
||||
pub declaration: f32,
|
||||
}
|
||||
|
||||
impl Scores {
|
||||
fn score(components: &ScoreComponents) -> Scores {
|
||||
impl DeclarationScores {
|
||||
fn score(components: &DeclarationScoreComponents) -> DeclarationScores {
|
||||
// TODO: handle truncation
|
||||
|
||||
// Score related to how likely this is the correct declaration, range 0 to 1
|
||||
|
|
@ -295,13 +276,11 @@ impl Scores {
|
|||
// For now instead of linear combination, the scores are just multiplied together.
|
||||
let combined_score = 10.0 * accuracy_score * distance_score;
|
||||
|
||||
Scores {
|
||||
signature: combined_score * components.containing_range_vs_signature_weighted_overlap,
|
||||
DeclarationScores {
|
||||
signature: combined_score * components.excerpt_vs_signature_weighted_overlap,
|
||||
// declaration score gets boosted both by being multiplied by 2 and by there being more
|
||||
// weighted overlap.
|
||||
declaration: 2.0
|
||||
* combined_score
|
||||
* components.containing_range_vs_item_weighted_overlap,
|
||||
declaration: 2.0 * combined_score * components.excerpt_vs_item_weighted_overlap,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ pub struct EditPredictionContext {
|
|||
pub excerpt: EditPredictionExcerpt,
|
||||
pub excerpt_text: EditPredictionExcerptText,
|
||||
pub cursor_offset_in_excerpt: usize,
|
||||
pub snippets: Vec<ScoredSnippet>,
|
||||
pub declarations: Vec<ScoredDeclaration>,
|
||||
}
|
||||
|
||||
impl EditPredictionContext {
|
||||
|
|
@ -58,17 +58,28 @@ impl EditPredictionContext {
|
|||
index_state,
|
||||
)?;
|
||||
let excerpt_text = excerpt.text(buffer);
|
||||
let excerpt_occurrences = text_similarity::Occurrences::within_string(&excerpt_text.body);
|
||||
|
||||
let adjacent_start = Point::new(cursor_point.row.saturating_sub(2), 0);
|
||||
let adjacent_end = Point::new(cursor_point.row + 1, 0);
|
||||
let adjacent_occurrences = text_similarity::Occurrences::within_string(
|
||||
&buffer
|
||||
.text_for_range(adjacent_start..adjacent_end)
|
||||
.collect::<String>(),
|
||||
);
|
||||
|
||||
let cursor_offset_in_file = cursor_point.to_offset(buffer);
|
||||
// TODO fix this to not need saturating_sub
|
||||
let cursor_offset_in_excerpt = cursor_offset_in_file.saturating_sub(excerpt.range.start);
|
||||
|
||||
let snippets = if let Some(index_state) = index_state {
|
||||
let declarations = if let Some(index_state) = index_state {
|
||||
let references = references_in_excerpt(&excerpt, &excerpt_text, buffer);
|
||||
|
||||
scored_snippets(
|
||||
scored_declarations(
|
||||
&index_state,
|
||||
&excerpt,
|
||||
&excerpt_text,
|
||||
&excerpt_occurrences,
|
||||
&adjacent_occurrences,
|
||||
references,
|
||||
cursor_offset_in_file,
|
||||
buffer,
|
||||
|
|
@ -81,7 +92,7 @@ impl EditPredictionContext {
|
|||
excerpt,
|
||||
excerpt_text,
|
||||
cursor_offset_in_excerpt,
|
||||
snippets,
|
||||
declarations,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -137,7 +148,7 @@ mod tests {
|
|||
.unwrap();
|
||||
|
||||
let mut snippet_identifiers = context
|
||||
.snippets
|
||||
.declarations
|
||||
.iter()
|
||||
.map(|snippet| snippet.identifier.name.as_ref())
|
||||
.collect::<Vec<_>>();
|
||||
|
|
|
|||
|
|
@ -1,5 +1,9 @@
|
|||
use hashbrown::HashTable;
|
||||
use regex::Regex;
|
||||
use std::{collections::HashMap, sync::LazyLock};
|
||||
use std::{
|
||||
hash::{Hash, Hasher as _},
|
||||
sync::LazyLock,
|
||||
};
|
||||
|
||||
use crate::reference::Reference;
|
||||
|
||||
|
|
@ -14,47 +18,74 @@ use crate::reference::Reference;
|
|||
|
||||
static IDENTIFIER_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap());
|
||||
|
||||
// TODO: use &str or Cow<str> keys?
|
||||
#[derive(Debug)]
|
||||
pub struct IdentifierOccurrences {
|
||||
identifier_to_count: HashMap<String, usize>,
|
||||
/// Multiset of text occurrences for text similarity that only stores hashes and counts.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct Occurrences {
|
||||
table: HashTable<OccurrenceEntry>,
|
||||
total_count: usize,
|
||||
}
|
||||
|
||||
impl IdentifierOccurrences {
|
||||
pub fn within_string(code: &str) -> Self {
|
||||
Self::from_iterator(IDENTIFIER_REGEX.find_iter(code).map(|mat| mat.as_str()))
|
||||
#[derive(Debug)]
|
||||
struct OccurrenceEntry {
|
||||
hash: u64,
|
||||
count: usize,
|
||||
}
|
||||
|
||||
impl Occurrences {
|
||||
pub fn within_string(text: &str) -> Self {
|
||||
Self::from_identifiers(IDENTIFIER_REGEX.find_iter(text).map(|mat| mat.as_str()))
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn within_references(references: &[Reference]) -> Self {
|
||||
Self::from_iterator(
|
||||
Self::from_identifiers(
|
||||
references
|
||||
.iter()
|
||||
.map(|reference| reference.identifier.name.as_ref()),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn from_iterator<'a>(identifier_iterator: impl Iterator<Item = &'a str>) -> Self {
|
||||
let mut identifier_to_count = HashMap::new();
|
||||
let mut total_count = 0;
|
||||
for identifier in identifier_iterator {
|
||||
// TODO: Score matches that match case higher?
|
||||
//
|
||||
// TODO: Also include unsplit identifier?
|
||||
pub fn from_identifiers<'a>(identifiers: impl IntoIterator<Item = &'a str>) -> Self {
|
||||
let mut this = Self::default();
|
||||
// TODO: Score matches that match case higher?
|
||||
//
|
||||
// TODO: Also include unsplit identifier?
|
||||
for identifier in identifiers {
|
||||
for identifier_part in split_identifier(identifier) {
|
||||
identifier_to_count
|
||||
.entry(identifier_part.to_lowercase())
|
||||
.and_modify(|count| *count += 1)
|
||||
.or_insert(1);
|
||||
total_count += 1;
|
||||
this.add_hash(fx_hash(&identifier_part.to_lowercase()));
|
||||
}
|
||||
}
|
||||
IdentifierOccurrences {
|
||||
identifier_to_count,
|
||||
total_count,
|
||||
}
|
||||
this
|
||||
}
|
||||
|
||||
fn add_hash(&mut self, hash: u64) {
|
||||
self.table
|
||||
.entry(
|
||||
hash,
|
||||
|entry: &OccurrenceEntry| entry.hash == hash,
|
||||
|entry| entry.hash,
|
||||
)
|
||||
.and_modify(|entry| entry.count += 1)
|
||||
.or_insert(OccurrenceEntry { hash, count: 1 });
|
||||
self.total_count += 1;
|
||||
}
|
||||
|
||||
fn contains_hash(&self, hash: u64) -> bool {
|
||||
self.get_count(hash) != 0
|
||||
}
|
||||
|
||||
fn get_count(&self, hash: u64) -> usize {
|
||||
self.table
|
||||
.find(hash, |entry| entry.hash == hash)
|
||||
.map(|entry| entry.count)
|
||||
.unwrap_or(0)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fx_hash<T: Hash + ?Sized>(data: &T) -> u64 {
|
||||
let mut hasher = collections::FxHasher::default();
|
||||
data.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
|
||||
// Splits camelcase / snakecase / kebabcase / pascalcase
|
||||
|
|
@ -115,54 +146,49 @@ fn split_identifier(identifier: &str) -> Vec<&str> {
|
|||
parts.into_iter().filter(|s| !s.is_empty()).collect()
|
||||
}
|
||||
|
||||
pub fn jaccard_similarity<'a>(
|
||||
mut set_a: &'a IdentifierOccurrences,
|
||||
mut set_b: &'a IdentifierOccurrences,
|
||||
) -> f32 {
|
||||
if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
|
||||
pub fn jaccard_similarity<'a>(mut set_a: &'a Occurrences, mut set_b: &'a Occurrences) -> f32 {
|
||||
if set_a.table.len() > set_b.table.len() {
|
||||
std::mem::swap(&mut set_a, &mut set_b);
|
||||
}
|
||||
let intersection = set_a
|
||||
.identifier_to_count
|
||||
.keys()
|
||||
.filter(|key| set_b.identifier_to_count.contains_key(*key))
|
||||
.table
|
||||
.iter()
|
||||
.filter(|entry| set_b.contains_hash(entry.hash))
|
||||
.count();
|
||||
let union = set_a.identifier_to_count.len() + set_b.identifier_to_count.len() - intersection;
|
||||
let union = set_a.table.len() + set_b.table.len() - intersection;
|
||||
intersection as f32 / union as f32
|
||||
}
|
||||
|
||||
// TODO
|
||||
#[allow(dead_code)]
|
||||
pub fn overlap_coefficient<'a>(
|
||||
mut set_a: &'a IdentifierOccurrences,
|
||||
mut set_b: &'a IdentifierOccurrences,
|
||||
) -> f32 {
|
||||
if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
|
||||
pub fn overlap_coefficient<'a>(mut set_a: &'a Occurrences, mut set_b: &'a Occurrences) -> f32 {
|
||||
if set_a.table.len() > set_b.table.len() {
|
||||
std::mem::swap(&mut set_a, &mut set_b);
|
||||
}
|
||||
let intersection = set_a
|
||||
.identifier_to_count
|
||||
.keys()
|
||||
.filter(|key| set_b.identifier_to_count.contains_key(*key))
|
||||
.table
|
||||
.iter()
|
||||
.filter(|entry| set_b.contains_hash(entry.hash))
|
||||
.count();
|
||||
intersection as f32 / set_a.identifier_to_count.len() as f32
|
||||
intersection as f32 / set_a.table.len() as f32
|
||||
}
|
||||
|
||||
// TODO
|
||||
#[allow(dead_code)]
|
||||
pub fn weighted_jaccard_similarity<'a>(
|
||||
mut set_a: &'a IdentifierOccurrences,
|
||||
mut set_b: &'a IdentifierOccurrences,
|
||||
mut set_a: &'a Occurrences,
|
||||
mut set_b: &'a Occurrences,
|
||||
) -> f32 {
|
||||
if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
|
||||
if set_a.table.len() > set_b.table.len() {
|
||||
std::mem::swap(&mut set_a, &mut set_b);
|
||||
}
|
||||
|
||||
let mut numerator = 0;
|
||||
let mut denominator_a = 0;
|
||||
let mut used_count_b = 0;
|
||||
for (symbol, count_a) in set_a.identifier_to_count.iter() {
|
||||
let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
|
||||
for entry_a in set_a.table.iter() {
|
||||
let count_a = entry_a.count;
|
||||
let count_b = set_b.get_count(entry_a.hash);
|
||||
numerator += count_a.min(count_b);
|
||||
denominator_a += count_a.max(count_b);
|
||||
used_count_b += count_b;
|
||||
|
|
@ -177,16 +203,17 @@ pub fn weighted_jaccard_similarity<'a>(
|
|||
}
|
||||
|
||||
pub fn weighted_overlap_coefficient<'a>(
|
||||
mut set_a: &'a IdentifierOccurrences,
|
||||
mut set_b: &'a IdentifierOccurrences,
|
||||
mut set_a: &'a Occurrences,
|
||||
mut set_b: &'a Occurrences,
|
||||
) -> f32 {
|
||||
if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
|
||||
if set_a.table.len() > set_b.table.len() {
|
||||
std::mem::swap(&mut set_a, &mut set_b);
|
||||
}
|
||||
|
||||
let mut numerator = 0;
|
||||
for (symbol, count_a) in set_a.identifier_to_count.iter() {
|
||||
let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
|
||||
for entry_a in set_a.table.iter() {
|
||||
let count_a = entry_a.count;
|
||||
let count_b = set_b.get_count(entry_a.hash);
|
||||
numerator += count_a.min(count_b);
|
||||
}
|
||||
|
||||
|
|
@ -215,12 +242,12 @@ mod test {
|
|||
fn test_similarity_functions() {
|
||||
// 10 identifier parts, 8 unique
|
||||
// Repeats: 2 "outline", 2 "items"
|
||||
let set_a = IdentifierOccurrences::within_string(
|
||||
let set_a = Occurrences::within_string(
|
||||
"let mut outline_items = query_outline_items(&language, &tree, &source);",
|
||||
);
|
||||
// 14 identifier parts, 11 unique
|
||||
// Repeats: 2 "outline", 2 "language", 2 "tree"
|
||||
let set_b = IdentifierOccurrences::within_string(
|
||||
let set_b = Occurrences::within_string(
|
||||
"pub fn query_outline_items(language: &Language, tree: &Tree, source: &str) -> Vec<OutlineItem> {",
|
||||
);
|
||||
|
||||
|
|
|
|||
|
|
@ -733,7 +733,7 @@ fn make_cloud_request(
|
|||
let mut declaration_to_signature_index = HashMap::default();
|
||||
let mut referenced_declarations = Vec::new();
|
||||
|
||||
for snippet in context.snippets {
|
||||
for snippet in context.declarations {
|
||||
let project_entry_id = snippet.declaration.project_entry_id();
|
||||
let Some(path) = worktrees.iter().find_map(|worktree| {
|
||||
worktree.entry_for_id(project_entry_id).map(|entry| {
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ use util::{ResultExt, paths::PathStyle, rel_path::RelPath};
|
|||
use workspace::{Item, SplitDirection, Workspace};
|
||||
use zeta2::{Zeta, ZetaOptions};
|
||||
|
||||
use edit_prediction_context::{EditPredictionExcerptOptions, SnippetStyle};
|
||||
use edit_prediction_context::{DeclarationStyle, EditPredictionExcerptOptions};
|
||||
|
||||
actions!(
|
||||
dev,
|
||||
|
|
@ -285,7 +285,7 @@ impl Zeta2Inspector {
|
|||
let mut languages = HashMap::default();
|
||||
for lang_id in prediction
|
||||
.context
|
||||
.snippets
|
||||
.declarations
|
||||
.iter()
|
||||
.map(|snippet| snippet.declaration.identifier().language_id)
|
||||
.chain(prediction.context.excerpt_text.language_id)
|
||||
|
|
@ -334,7 +334,7 @@ impl Zeta2Inspector {
|
|||
cx,
|
||||
);
|
||||
|
||||
for snippet in &prediction.context.snippets {
|
||||
for snippet in &prediction.context.declarations {
|
||||
let path = this
|
||||
.project
|
||||
.read(cx)
|
||||
|
|
@ -345,7 +345,7 @@ impl Zeta2Inspector {
|
|||
"{} (Score density: {})",
|
||||
path.map(|p| p.path.display(path_style).to_string())
|
||||
.unwrap_or_else(|| "".to_string()),
|
||||
snippet.score_density(SnippetStyle::Declaration)
|
||||
snippet.score_density(DeclarationStyle::Declaration)
|
||||
))
|
||||
.unwrap()
|
||||
.into(),
|
||||
|
|
|
|||
Loading…
Reference in a new issue