edit_prediction_context: Minor optimization of text similarity + some renames (#38941)

Release Notes:

- N/A
This commit is contained in:
Michael Sloan 2025-09-26 01:57:28 -06:00 committed by GitHub
parent bcc8149263
commit da71465437
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 198 additions and 176 deletions

1
Cargo.lock generated
View file

@ -5171,6 +5171,7 @@ dependencies = [
"collections",
"futures 0.3.31",
"gpui",
"hashbrown 0.15.3",
"indoc",
"itertools 0.14.0",
"language",

View file

@ -511,6 +511,7 @@ futures-lite = "1.13"
git2 = { version = "0.20.1", default-features = false }
globset = "0.4"
handlebars = "4.3"
hashbrown = "0.15.3"
heck = "0.5"
heed = { version = "0.21.0", features = ["read-txn-no-tls"] }
hex = "0.4.3"

View file

@ -103,13 +103,13 @@ pub struct ReferencedDeclaration {
/// Index within `signatures`.
#[serde(skip_serializing_if = "Option::is_none", default)]
pub parent_index: Option<usize>,
pub score_components: ScoreComponents,
pub score_components: DeclarationScoreComponents,
pub signature_score: f32,
pub declaration_score: f32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoreComponents {
pub struct DeclarationScoreComponents {
pub is_same_file: bool,
pub is_referenced_nearby: bool,
pub is_referenced_in_breadcrumb: bool,
@ -119,12 +119,12 @@ pub struct ScoreComponents {
pub reference_line_distance: u32,
pub declaration_line_distance: u32,
pub declaration_line_distance_rank: usize,
pub containing_range_vs_item_jaccard: f32,
pub containing_range_vs_signature_jaccard: f32,
pub excerpt_vs_item_jaccard: f32,
pub excerpt_vs_signature_jaccard: f32,
pub adjacent_vs_item_jaccard: f32,
pub adjacent_vs_signature_jaccard: f32,
pub containing_range_vs_item_weighted_overlap: f32,
pub containing_range_vs_signature_weighted_overlap: f32,
pub excerpt_vs_item_weighted_overlap: f32,
pub excerpt_vs_signature_weighted_overlap: f32,
pub adjacent_vs_item_weighted_overlap: f32,
pub adjacent_vs_signature_weighted_overlap: f32,
}

View file

@ -70,7 +70,7 @@ pub struct PlannedSnippet<'a> {
}
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
pub enum SnippetStyle {
pub enum DeclarationStyle {
Signature,
Declaration,
}
@ -84,10 +84,10 @@ pub struct SectionLabels {
impl<'a> PlannedPrompt<'a> {
/// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following:
///
/// Initializes a priority queue by populating it with each snippet, finding the SnippetStyle
/// that minimizes `score_density = score / snippet.range(style).len()`. When a "signature"
/// snippet is popped, insert an entry for the "declaration" variant that reflects the cost of
/// upgrade.
/// Initializes a priority queue by populating it with each snippet, finding the
/// DeclarationStyle that maximizes `score_density = score / snippet.range(style).len()`. When a
/// "signature" snippet is popped, insert an entry for the "declaration" variant that reflects
/// the cost of upgrade.
///
/// TODO: Implement an early halting condition. One option might be to have another priority
/// queue where the score is the size, and update it accordingly. Another option might be to
@ -131,13 +131,13 @@ impl<'a> PlannedPrompt<'a> {
struct QueueEntry {
score_density: OrderedFloat<f32>,
declaration_index: usize,
style: SnippetStyle,
style: DeclarationStyle,
}
// Initialize priority queue with the best score for each snippet.
let mut queue: BinaryHeap<QueueEntry> = BinaryHeap::new();
for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() {
let (style, score_density) = SnippetStyle::iter()
let (style, score_density) = DeclarationStyle::iter()
.map(|style| {
(
style,
@ -186,7 +186,7 @@ impl<'a> PlannedPrompt<'a> {
this.budget_used += additional_bytes;
this.add_parents(&mut included_parents, additional_parents);
let planned_snippet = match queue_entry.style {
SnippetStyle::Signature => {
DeclarationStyle::Signature => {
let Some(text) = declaration.text.get(declaration.signature_range.clone())
else {
return Err(anyhow!(
@ -203,7 +203,7 @@ impl<'a> PlannedPrompt<'a> {
text_is_truncated: declaration.text_is_truncated,
}
}
SnippetStyle::Declaration => PlannedSnippet {
DeclarationStyle::Declaration => PlannedSnippet {
path: declaration.path.clone(),
range: declaration.range.clone(),
text: &declaration.text,
@ -213,11 +213,13 @@ impl<'a> PlannedPrompt<'a> {
this.snippets.push(planned_snippet);
// When a Signature is consumed, insert an entry for Declaration style.
if queue_entry.style == SnippetStyle::Signature {
let signature_size = declaration_size(&declaration, SnippetStyle::Signature);
let declaration_size = declaration_size(&declaration, SnippetStyle::Declaration);
let signature_score = declaration_score(&declaration, SnippetStyle::Signature);
let declaration_score = declaration_score(&declaration, SnippetStyle::Declaration);
if queue_entry.style == DeclarationStyle::Signature {
let signature_size = declaration_size(&declaration, DeclarationStyle::Signature);
let declaration_size =
declaration_size(&declaration, DeclarationStyle::Declaration);
let signature_score = declaration_score(&declaration, DeclarationStyle::Signature);
let declaration_score =
declaration_score(&declaration, DeclarationStyle::Declaration);
let score_diff = declaration_score - signature_score;
let size_diff = declaration_size.saturating_sub(signature_size);
@ -225,7 +227,7 @@ impl<'a> PlannedPrompt<'a> {
queue.push(QueueEntry {
declaration_index: queue_entry.declaration_index,
score_density: OrderedFloat(score_diff / (size_diff as f32)),
style: SnippetStyle::Declaration,
style: DeclarationStyle::Declaration,
});
}
}
@ -510,20 +512,20 @@ impl<'a> PlannedPrompt<'a> {
}
}
fn declaration_score_density(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 {
fn declaration_score_density(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
declaration_score(declaration, style) / declaration_size(declaration, style) as f32
}
fn declaration_score(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 {
fn declaration_score(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
match style {
SnippetStyle::Signature => declaration.signature_score,
SnippetStyle::Declaration => declaration.declaration_score,
DeclarationStyle::Signature => declaration.signature_score,
DeclarationStyle::Declaration => declaration.declaration_score,
}
}
fn declaration_size(declaration: &ReferencedDeclaration, style: SnippetStyle) -> usize {
fn declaration_size(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> usize {
match style {
SnippetStyle::Signature => declaration.signature_range.len(),
SnippetStyle::Declaration => declaration.text.len(),
DeclarationStyle::Signature => declaration.signature_range.len(),
DeclarationStyle::Declaration => declaration.text.len(),
}
}

View file

@ -18,6 +18,7 @@ cloud_llm_client.workspace = true
collections.workspace = true
futures.workspace = true
gpui.workspace = true
hashbrown.workspace = true
itertools.workspace = true
language.workspace = true
log.workspace = true

View file

@ -1,4 +1,4 @@
use cloud_llm_client::predict_edits_v3::ScoreComponents;
use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents;
use itertools::Itertools as _;
use language::BufferSnapshot;
use ordered_float::OrderedFloat;
@ -8,76 +8,67 @@ use strum::EnumIter;
use text::{Point, ToPoint};
use crate::{
Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier,
Declaration, EditPredictionExcerpt, Identifier,
reference::{Reference, ReferenceRegion},
syntax_index::SyntaxIndexState,
text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient},
text_similarity::{Occurrences, jaccard_similarity, weighted_overlap_coefficient},
};
const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
#[derive(Clone, Debug)]
pub struct ScoredSnippet {
pub struct ScoredDeclaration {
pub identifier: Identifier,
pub declaration: Declaration,
pub score_components: ScoreComponents,
pub scores: Scores,
pub score_components: DeclarationScoreComponents,
pub scores: DeclarationScores,
}
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum SnippetStyle {
pub enum DeclarationStyle {
Signature,
Declaration,
}
impl ScoredSnippet {
/// Returns the score for this snippet with the specified style.
pub fn score(&self, style: SnippetStyle) -> f32 {
impl ScoredDeclaration {
/// Returns the score for this declaration with the specified style.
pub fn score(&self, style: DeclarationStyle) -> f32 {
match style {
SnippetStyle::Signature => self.scores.signature,
SnippetStyle::Declaration => self.scores.declaration,
DeclarationStyle::Signature => self.scores.signature,
DeclarationStyle::Declaration => self.scores.declaration,
}
}
pub fn size(&self, style: SnippetStyle) -> usize {
pub fn size(&self, style: DeclarationStyle) -> usize {
match &self.declaration {
Declaration::File { declaration, .. } => match style {
SnippetStyle::Signature => declaration.signature_range.len(),
SnippetStyle::Declaration => declaration.text.len(),
DeclarationStyle::Signature => declaration.signature_range.len(),
DeclarationStyle::Declaration => declaration.text.len(),
},
Declaration::Buffer { declaration, .. } => match style {
SnippetStyle::Signature => declaration.signature_range.len(),
SnippetStyle::Declaration => declaration.item_range.len(),
DeclarationStyle::Signature => declaration.signature_range.len(),
DeclarationStyle::Declaration => declaration.item_range.len(),
},
}
}
pub fn score_density(&self, style: SnippetStyle) -> f32 {
pub fn score_density(&self, style: DeclarationStyle) -> f32 {
self.score(style) / (self.size(style)) as f32
}
}
pub fn scored_snippets(
pub fn scored_declarations(
index: &SyntaxIndexState,
excerpt: &EditPredictionExcerpt,
excerpt_text: &EditPredictionExcerptText,
excerpt_occurrences: &Occurrences,
adjacent_occurrences: &Occurrences,
identifier_to_references: HashMap<Identifier, Vec<Reference>>,
cursor_offset: usize,
current_buffer: &BufferSnapshot,
) -> Vec<ScoredSnippet> {
let containing_range_identifier_occurrences =
IdentifierOccurrences::within_string(&excerpt_text.body);
) -> Vec<ScoredDeclaration> {
let cursor_point = cursor_offset.to_point(&current_buffer);
let start_point = Point::new(cursor_point.row.saturating_sub(2), 0);
let end_point = Point::new(cursor_point.row + 1, 0);
let adjacent_identifier_occurrences = IdentifierOccurrences::within_string(
&current_buffer
.text_for_range(start_point..end_point)
.collect::<String>(),
);
let mut snippets = identifier_to_references
let mut declarations = identifier_to_references
.into_iter()
.flat_map(|(identifier, references)| {
let declarations =
@ -137,7 +128,7 @@ pub fn scored_snippets(
)| {
let same_file_declaration_count = index.file_declaration_count(declaration);
score_snippet(
score_declaration(
&identifier,
&references,
declaration.clone(),
@ -146,8 +137,8 @@ pub fn scored_snippets(
declaration_line_distance_rank,
same_file_declaration_count,
declaration_count,
&containing_range_identifier_occurrences,
&adjacent_identifier_occurrences,
&excerpt_occurrences,
&adjacent_occurrences,
cursor_point,
current_buffer,
)
@ -158,14 +149,14 @@ pub fn scored_snippets(
.flatten()
.collect::<Vec<_>>();
snippets.sort_unstable_by_key(|snippet| {
let score_density = snippet
.score_density(SnippetStyle::Declaration)
.max(snippet.score_density(SnippetStyle::Signature));
declarations.sort_unstable_by_key(|declaration| {
let score_density = declaration
.score_density(DeclarationStyle::Declaration)
.max(declaration.score_density(DeclarationStyle::Signature));
Reverse(OrderedFloat(score_density))
});
snippets
declarations
}
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
@ -178,7 +169,7 @@ fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Rang
}
}
fn score_snippet(
fn score_declaration(
identifier: &Identifier,
references: &[Reference],
declaration: Declaration,
@ -187,11 +178,11 @@ fn score_snippet(
declaration_line_distance_rank: usize,
same_file_declaration_count: usize,
declaration_count: usize,
containing_range_identifier_occurrences: &IdentifierOccurrences,
adjacent_identifier_occurrences: &IdentifierOccurrences,
excerpt_occurrences: &Occurrences,
adjacent_occurrences: &Occurrences,
cursor: Point,
current_buffer: &BufferSnapshot,
) -> Option<ScoredSnippet> {
) -> Option<ScoredDeclaration> {
let is_referenced_nearby = references
.iter()
.any(|r| r.region == ReferenceRegion::Nearby);
@ -208,37 +199,27 @@ fn score_snippet(
.min()
.unwrap();
let item_source_occurrences = IdentifierOccurrences::within_string(&declaration.item_text().0);
let item_signature_occurrences =
IdentifierOccurrences::within_string(&declaration.signature_text().0);
let containing_range_vs_item_jaccard = jaccard_similarity(
containing_range_identifier_occurrences,
&item_source_occurrences,
);
let containing_range_vs_signature_jaccard = jaccard_similarity(
containing_range_identifier_occurrences,
&item_signature_occurrences,
);
let item_source_occurrences = Occurrences::within_string(&declaration.item_text().0);
let item_signature_occurrences = Occurrences::within_string(&declaration.signature_text().0);
let excerpt_vs_item_jaccard = jaccard_similarity(excerpt_occurrences, &item_source_occurrences);
let excerpt_vs_signature_jaccard =
jaccard_similarity(excerpt_occurrences, &item_signature_occurrences);
let adjacent_vs_item_jaccard =
jaccard_similarity(adjacent_identifier_occurrences, &item_source_occurrences);
jaccard_similarity(adjacent_occurrences, &item_source_occurrences);
let adjacent_vs_signature_jaccard =
jaccard_similarity(adjacent_identifier_occurrences, &item_signature_occurrences);
jaccard_similarity(adjacent_occurrences, &item_signature_occurrences);
let containing_range_vs_item_weighted_overlap = weighted_overlap_coefficient(
containing_range_identifier_occurrences,
&item_source_occurrences,
);
let containing_range_vs_signature_weighted_overlap = weighted_overlap_coefficient(
containing_range_identifier_occurrences,
&item_signature_occurrences,
);
let excerpt_vs_item_weighted_overlap =
weighted_overlap_coefficient(excerpt_occurrences, &item_source_occurrences);
let excerpt_vs_signature_weighted_overlap =
weighted_overlap_coefficient(excerpt_occurrences, &item_signature_occurrences);
let adjacent_vs_item_weighted_overlap =
weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_source_occurrences);
weighted_overlap_coefficient(adjacent_occurrences, &item_source_occurrences);
let adjacent_vs_signature_weighted_overlap =
weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences);
weighted_overlap_coefficient(adjacent_occurrences, &item_signature_occurrences);
// TODO: Consider adding declaration_file_count
let score_components = ScoreComponents {
let score_components = DeclarationScoreComponents {
is_same_file,
is_referenced_nearby,
is_referenced_in_breadcrumb,
@ -248,32 +229,32 @@ fn score_snippet(
reference_count,
same_file_declaration_count,
declaration_count,
containing_range_vs_item_jaccard,
containing_range_vs_signature_jaccard,
excerpt_vs_item_jaccard,
excerpt_vs_signature_jaccard,
adjacent_vs_item_jaccard,
adjacent_vs_signature_jaccard,
containing_range_vs_item_weighted_overlap,
containing_range_vs_signature_weighted_overlap,
excerpt_vs_item_weighted_overlap,
excerpt_vs_signature_weighted_overlap,
adjacent_vs_item_weighted_overlap,
adjacent_vs_signature_weighted_overlap,
};
Some(ScoredSnippet {
Some(ScoredDeclaration {
identifier: identifier.clone(),
declaration: declaration,
scores: Scores::score(&score_components),
scores: DeclarationScores::score(&score_components),
score_components,
})
}
#[derive(Clone, Debug, Serialize)]
pub struct Scores {
pub struct DeclarationScores {
pub signature: f32,
pub declaration: f32,
}
impl Scores {
fn score(components: &ScoreComponents) -> Scores {
impl DeclarationScores {
fn score(components: &DeclarationScoreComponents) -> DeclarationScores {
// TODO: handle truncation
// Score related to how likely this is the correct declaration, range 0 to 1
@ -295,13 +276,11 @@ impl Scores {
// For now instead of linear combination, the scores are just multiplied together.
let combined_score = 10.0 * accuracy_score * distance_score;
Scores {
signature: combined_score * components.containing_range_vs_signature_weighted_overlap,
DeclarationScores {
signature: combined_score * components.excerpt_vs_signature_weighted_overlap,
// declaration score gets boosted both by being multiplied by 2 and by there being more
// weighted overlap.
declaration: 2.0
* combined_score
* components.containing_range_vs_item_weighted_overlap,
declaration: 2.0 * combined_score * components.excerpt_vs_item_weighted_overlap,
}
}
}

View file

@ -21,7 +21,7 @@ pub struct EditPredictionContext {
pub excerpt: EditPredictionExcerpt,
pub excerpt_text: EditPredictionExcerptText,
pub cursor_offset_in_excerpt: usize,
pub snippets: Vec<ScoredSnippet>,
pub declarations: Vec<ScoredDeclaration>,
}
impl EditPredictionContext {
@ -58,17 +58,28 @@ impl EditPredictionContext {
index_state,
)?;
let excerpt_text = excerpt.text(buffer);
let excerpt_occurrences = text_similarity::Occurrences::within_string(&excerpt_text.body);
let adjacent_start = Point::new(cursor_point.row.saturating_sub(2), 0);
let adjacent_end = Point::new(cursor_point.row + 1, 0);
let adjacent_occurrences = text_similarity::Occurrences::within_string(
&buffer
.text_for_range(adjacent_start..adjacent_end)
.collect::<String>(),
);
let cursor_offset_in_file = cursor_point.to_offset(buffer);
// TODO fix this to not need saturating_sub
let cursor_offset_in_excerpt = cursor_offset_in_file.saturating_sub(excerpt.range.start);
let snippets = if let Some(index_state) = index_state {
let declarations = if let Some(index_state) = index_state {
let references = references_in_excerpt(&excerpt, &excerpt_text, buffer);
scored_snippets(
scored_declarations(
&index_state,
&excerpt,
&excerpt_text,
&excerpt_occurrences,
&adjacent_occurrences,
references,
cursor_offset_in_file,
buffer,
@ -81,7 +92,7 @@ impl EditPredictionContext {
excerpt,
excerpt_text,
cursor_offset_in_excerpt,
snippets,
declarations,
})
}
}
@ -137,7 +148,7 @@ mod tests {
.unwrap();
let mut snippet_identifiers = context
.snippets
.declarations
.iter()
.map(|snippet| snippet.identifier.name.as_ref())
.collect::<Vec<_>>();

View file

@ -1,5 +1,9 @@
use hashbrown::HashTable;
use regex::Regex;
use std::{collections::HashMap, sync::LazyLock};
use std::{
hash::{Hash, Hasher as _},
sync::LazyLock,
};
use crate::reference::Reference;
@ -14,47 +18,74 @@ use crate::reference::Reference;
static IDENTIFIER_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap());
// TODO: use &str or Cow<str> keys?
#[derive(Debug)]
pub struct IdentifierOccurrences {
identifier_to_count: HashMap<String, usize>,
/// Multiset of text occurrences for text similarity that only stores hashes and counts.
#[derive(Debug, Default)]
pub struct Occurrences {
table: HashTable<OccurrenceEntry>,
total_count: usize,
}
impl IdentifierOccurrences {
pub fn within_string(code: &str) -> Self {
Self::from_iterator(IDENTIFIER_REGEX.find_iter(code).map(|mat| mat.as_str()))
/// One slot of the `Occurrences` table: a hash value together with how many times it was added.
#[derive(Debug)]
struct OccurrenceEntry {
/// Hash identifying this entry; `add_hash` and `get_count` use it both as the table hash and
/// as the equality key, so two inputs with the same hash share one entry.
hash: u64,
/// Number of times `hash` has been added; starts at 1 on insertion and is incremented on
/// every further `add_hash` call with the same hash.
count: usize,
}
impl Occurrences {
/// Builds the occurrence multiset from every `IDENTIFIER_REGEX` match found in `text`.
///
/// Each match is handed to `from_identifiers` as a borrowed `&str`.
pub fn within_string(text: &str) -> Self {
    let matched_identifiers = IDENTIFIER_REGEX
        .find_iter(text)
        .map(|found| found.as_str());
    Self::from_identifiers(matched_identifiers)
}
#[allow(dead_code)]
pub fn within_references(references: &[Reference]) -> Self {
Self::from_iterator(
Self::from_identifiers(
references
.iter()
.map(|reference| reference.identifier.name.as_ref()),
)
}
pub fn from_iterator<'a>(identifier_iterator: impl Iterator<Item = &'a str>) -> Self {
let mut identifier_to_count = HashMap::new();
let mut total_count = 0;
for identifier in identifier_iterator {
// TODO: Score matches that match case higher?
//
// TODO: Also include unsplit identifier?
pub fn from_identifiers<'a>(identifiers: impl IntoIterator<Item = &'a str>) -> Self {
let mut this = Self::default();
// TODO: Score matches that match case higher?
//
// TODO: Also include unsplit identifier?
for identifier in identifiers {
for identifier_part in split_identifier(identifier) {
identifier_to_count
.entry(identifier_part.to_lowercase())
.and_modify(|count| *count += 1)
.or_insert(1);
total_count += 1;
this.add_hash(fx_hash(&identifier_part.to_lowercase()));
}
}
IdentifierOccurrences {
identifier_to_count,
total_count,
}
this
}
/// Records one more occurrence of `hash`.
///
/// If an entry with the same hash already exists its count is incremented; otherwise a new
/// entry with a count of 1 is inserted. `total_count` is incremented in both cases.
fn add_hash(&mut self, hash: u64) {
    // Entries are keyed purely by their stored hash, so equality and re-hashing both read it.
    let same_hash = |candidate: &OccurrenceEntry| candidate.hash == hash;
    let stored_hash = |existing: &OccurrenceEntry| existing.hash;

    self.table
        .entry(hash, same_hash, stored_hash)
        .and_modify(|existing| existing.count += 1)
        .or_insert(OccurrenceEntry { hash, count: 1 });

    self.total_count += 1;
}
/// Returns `true` when `get_count` reports a non-zero count for `hash`.
fn contains_hash(&self, hash: u64) -> bool {
    let occurrences = self.get_count(hash);
    occurrences > 0
}
/// Returns how many times `hash` has been added, or 0 when the table has no entry for it.
fn get_count(&self, hash: u64) -> usize {
    match self.table.find(hash, |candidate| candidate.hash == hash) {
        Some(found) => found.count,
        None => 0,
    }
}
}
/// Hashes `data` with a default-constructed `collections::FxHasher` and returns the resulting
/// 64-bit value.
pub fn fx_hash<T>(data: &T) -> u64
where
    T: Hash + ?Sized,
{
    let mut state = collections::FxHasher::default();
    Hash::hash(data, &mut state);
    state.finish()
}
// Splits camelcase / snakecase / kebabcase / pascalcase
@ -115,54 +146,49 @@ fn split_identifier(identifier: &str) -> Vec<&str> {
parts.into_iter().filter(|s| !s.is_empty()).collect()
}
pub fn jaccard_similarity<'a>(
mut set_a: &'a IdentifierOccurrences,
mut set_b: &'a IdentifierOccurrences,
) -> f32 {
if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
pub fn jaccard_similarity<'a>(mut set_a: &'a Occurrences, mut set_b: &'a Occurrences) -> f32 {
if set_a.table.len() > set_b.table.len() {
std::mem::swap(&mut set_a, &mut set_b);
}
let intersection = set_a
.identifier_to_count
.keys()
.filter(|key| set_b.identifier_to_count.contains_key(*key))
.table
.iter()
.filter(|entry| set_b.contains_hash(entry.hash))
.count();
let union = set_a.identifier_to_count.len() + set_b.identifier_to_count.len() - intersection;
let union = set_a.table.len() + set_b.table.len() - intersection;
intersection as f32 / union as f32
}
// TODO
#[allow(dead_code)]
pub fn overlap_coefficient<'a>(
mut set_a: &'a IdentifierOccurrences,
mut set_b: &'a IdentifierOccurrences,
) -> f32 {
if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
pub fn overlap_coefficient<'a>(mut set_a: &'a Occurrences, mut set_b: &'a Occurrences) -> f32 {
if set_a.table.len() > set_b.table.len() {
std::mem::swap(&mut set_a, &mut set_b);
}
let intersection = set_a
.identifier_to_count
.keys()
.filter(|key| set_b.identifier_to_count.contains_key(*key))
.table
.iter()
.filter(|entry| set_b.contains_hash(entry.hash))
.count();
intersection as f32 / set_a.identifier_to_count.len() as f32
intersection as f32 / set_a.table.len() as f32
}
// TODO
#[allow(dead_code)]
pub fn weighted_jaccard_similarity<'a>(
mut set_a: &'a IdentifierOccurrences,
mut set_b: &'a IdentifierOccurrences,
mut set_a: &'a Occurrences,
mut set_b: &'a Occurrences,
) -> f32 {
if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
if set_a.table.len() > set_b.table.len() {
std::mem::swap(&mut set_a, &mut set_b);
}
let mut numerator = 0;
let mut denominator_a = 0;
let mut used_count_b = 0;
for (symbol, count_a) in set_a.identifier_to_count.iter() {
let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
for entry_a in set_a.table.iter() {
let count_a = entry_a.count;
let count_b = set_b.get_count(entry_a.hash);
numerator += count_a.min(count_b);
denominator_a += count_a.max(count_b);
used_count_b += count_b;
@ -177,16 +203,17 @@ pub fn weighted_jaccard_similarity<'a>(
}
pub fn weighted_overlap_coefficient<'a>(
mut set_a: &'a IdentifierOccurrences,
mut set_b: &'a IdentifierOccurrences,
mut set_a: &'a Occurrences,
mut set_b: &'a Occurrences,
) -> f32 {
if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
if set_a.table.len() > set_b.table.len() {
std::mem::swap(&mut set_a, &mut set_b);
}
let mut numerator = 0;
for (symbol, count_a) in set_a.identifier_to_count.iter() {
let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
for entry_a in set_a.table.iter() {
let count_a = entry_a.count;
let count_b = set_b.get_count(entry_a.hash);
numerator += count_a.min(count_b);
}
@ -215,12 +242,12 @@ mod test {
fn test_similarity_functions() {
// 10 identifier parts, 8 unique
// Repeats: 2 "outline", 2 "items"
let set_a = IdentifierOccurrences::within_string(
let set_a = Occurrences::within_string(
"let mut outline_items = query_outline_items(&language, &tree, &source);",
);
// 14 identifier parts, 11 unique
// Repeats: 2 "outline", 2 "language", 2 "tree"
let set_b = IdentifierOccurrences::within_string(
let set_b = Occurrences::within_string(
"pub fn query_outline_items(language: &Language, tree: &Tree, source: &str) -> Vec<OutlineItem> {",
);

View file

@ -733,7 +733,7 @@ fn make_cloud_request(
let mut declaration_to_signature_index = HashMap::default();
let mut referenced_declarations = Vec::new();
for snippet in context.snippets {
for snippet in context.declarations {
let project_entry_id = snippet.declaration.project_entry_id();
let Some(path) = worktrees.iter().find_map(|worktree| {
worktree.entry_for_id(project_entry_id).map(|entry| {

View file

@ -18,7 +18,7 @@ use util::{ResultExt, paths::PathStyle, rel_path::RelPath};
use workspace::{Item, SplitDirection, Workspace};
use zeta2::{Zeta, ZetaOptions};
use edit_prediction_context::{EditPredictionExcerptOptions, SnippetStyle};
use edit_prediction_context::{DeclarationStyle, EditPredictionExcerptOptions};
actions!(
dev,
@ -285,7 +285,7 @@ impl Zeta2Inspector {
let mut languages = HashMap::default();
for lang_id in prediction
.context
.snippets
.declarations
.iter()
.map(|snippet| snippet.declaration.identifier().language_id)
.chain(prediction.context.excerpt_text.language_id)
@ -334,7 +334,7 @@ impl Zeta2Inspector {
cx,
);
for snippet in &prediction.context.snippets {
for snippet in &prediction.context.declarations {
let path = this
.project
.read(cx)
@ -345,7 +345,7 @@ impl Zeta2Inspector {
"{} (Score density: {})",
path.map(|p| p.path.display(path_style).to_string())
.unwrap_or_else(|| "".to_string()),
snippet.score_density(SnippetStyle::Declaration)
snippet.score_density(DeclarationStyle::Declaration)
))
.unwrap()
.into(),