diff --git a/.cargo/config.toml b/.cargo/config.toml index e39f7b055b70a7db694f891193ff5b424514e9dd..4922804fa4a3e47d2afeb6d824918c607000013a 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -8,6 +8,9 @@ debug = "limited" opt-level = 3 debug = "line-tables-only" +[profile.dev.package.tantivy] +debug-assertions = false + [profile.release] lto = "thin" strip = "debuginfo" diff --git a/src/index/custom_combiner.rs b/src/index/custom_combiner.rs new file mode 100644 index 0000000000000000000000000000000000000000..8cab3fd2a3db08010643e8dba1f33079131fdeca --- /dev/null +++ b/src/index/custom_combiner.rs @@ -0,0 +1,74 @@ +use tantivy::{ + Result as TantivyResult, Score, Term, + query::{BooleanWeight, EnableScoring, Occur, Query, ScoreCombiner, Scorer, Weight}, +}; + +/** Score combiner that sums the scores of the scorers passed to `update` and, when read, multiplies that sum by the number of `update` calls since the last `clear`. */ #[derive(Default, Clone, Copy)] +pub struct CustomCombiner { + /** Sum of `scorer.score()` over all `update` calls since the last `clear`. */ score: Score, + /** Number of `update` calls since the last `clear`. */ num_matching_clauses: usize, +} + +impl ScoreCombiner for CustomCombiner { + /** Adds the scorer's current score to the running sum and counts one more clause. */ fn update<S>(&mut self, scorer: &mut S) + where + S: Scorer, + { + self.score += scorer.score(); + self.num_matching_clauses += 1; + } + + /** Resets both the running sum and the clause count to zero. */ fn clear(&mut self) { + self.score = 0.0; + self.num_matching_clauses = 0; + } + + /** Returns the running sum multiplied by the clause count (cast to `Score`). NOTE(review): presumably meant to favor documents that match more clauses; confirm. */ fn score(&self) -> Score { + self.score * self.num_matching_clauses as Score + } +} + +/** Query made of `(Occur, Box<dyn Query>)` clauses whose weight is a `BooleanWeight` built with `CustomCombiner`. */ #[derive(Debug)] +pub struct CustomBooleanQuery { + clauses: Vec<(Occur, Box<dyn Query>)>, +} + +impl CustomBooleanQuery { + /** Wraps the given clause list without modifying it. */ pub fn new(clauses: Vec<(Occur, Box<dyn Query>)>) -> Self { + Self { clauses } + } +} + +/* Written by hand: each boxed clause is duplicated through `box_clone`. */ impl Clone for CustomBooleanQuery { + fn clone(&self) -> Self { + let clauses = self + .clauses + .iter() + .map(|(occur, clause)| (*occur, clause.box_clone())) + .collect(); + + Self { clauses } + } +} + +impl Query for CustomBooleanQuery { + /** Builds one weight per clause (the first clause error is returned via `?`), then wraps them in a `BooleanWeight` that keeps each clause's `Occur`, forwards `enable_scoring.is_scoring_enabled()`, and is given `CustomCombiner::default`. */ fn weight(&self, enable_scoring: EnableScoring<'_>) -> TantivyResult<Box<dyn Weight>> { + let clauses = self + .clauses + .iter() + .map(|(occur, clause)| clause.weight(enable_scoring).map(|weight| (*occur, weight))) + .collect::<Result<Vec<_>, _>>()?; + + Ok(Box::new(BooleanWeight::new( 
clauses, + enable_scoring.is_scoring_enabled(), + /* passes the `CustomCombiner::default` function itself, not a constructed value */ Box::new(CustomCombiner::default), + ))) + } + + /** Forwards `visitor` to every clause, regardless of its `Occur`. */ fn query_terms<'a>(&'a self, visitor: &mut dyn FnMut(&'a Term, bool)) { + for (_occur, clause) in &self.clauses { + clause.query_terms(visitor); + } + } +} diff --git a/src/index/mod.rs b/src/index/mod.rs index a2e1557237ee7d4d832f13c80f6ba2ac2bf6608c..730d633503e21db5b3cb45875152386c2cf4e15b 100644 --- a/src/index/mod.rs +++ b/src/index/mod.rs @@ -1,5 +1,6 @@ mod bounding_box; pub(crate) mod collector; +mod custom_combiner; mod custom_tokenizer; mod indexer; mod scorer; diff --git a/src/index/searcher/mod.rs b/src/index/searcher/mod.rs index 9aa93f642fb6eead6c7d368705453ed3dc55a2a6..83be3f65066a4db7159af5c79769b69d2dfed0c0 100644 --- a/src/index/searcher/mod.rs +++ b/src/index/searcher/mod.rs @@ -24,7 +24,7 @@ use tantivy::{ ScoreTweaker as TantivyScoreTweaker, TopDocs, }, postings::TermInfo, - query::{BooleanQuery, QueryParser, TermQuery}, + query::{BooleanQuery, Query, QueryParser, TermQuery}, schema::{Facet, Field, IndexRecordOption, OwnedValue as Value, TantivyDocument as Document}, }; use tantivy_columnar::{ColumnValues, StrColumn}; @@ -38,6 +38,7 @@ use crate::{ Fields, bounding_box::BoundingBoxes, collector::{AllDocs, FirstDoc}, + custom_combiner::CustomBooleanQuery, index_reader, register_tokenizers, spatial_cluster::{SpatialCluster, SpatialClustersCollector}, split_compound_nouns::SplitCompoundNouns, @@ -190,20 +191,18 @@ impl Searcher { None }; - let mut ast = query.to_ast()?; + let ast = query.to_ast()?; let mut terms = Vec::new(); collect_terms(&mut terms, &ast); - if let Some(split_compound_nouns) = &*self.split_compound_nouns.load() { - split_compound_nouns.rewrite( - &self.index.schema(), - &default_fields(&self.fields), - &mut ast, - ); - } - - let query = self.parser.build_query_from_user_input_ast(ast)?; + let query = interpret_query( + &self.index, + &self.parser, + &self.split_compound_nouns, + &self.fields, + ast, + )?; let mut queries = 
Vec::with_capacity(8); queries.push(query); @@ -352,19 +351,15 @@ impl Searcher { } pub fn explain(&self, query: &dyn QueryRepr, source: &str, id: &str) -> Result<String> { - let mut ast = query.to_ast()?; - - if let Some(split_compound_nouns) = &*self.split_compound_nouns.load() { - split_compound_nouns.rewrite( - &self.index.schema(), - &default_fields(&self.fields), - &mut ast, - ); - } - let searcher = self.reader.searcher(); - let query = self.parser.build_query_from_user_input_ast(ast)?; + let query = interpret_query( + &self.index, + &self.parser, + &self.split_compound_nouns, + &self.fields, + query.to_ast()?, + )?; let doc = { let source_query = Box::new(TermQuery::new( @@ -416,17 +411,15 @@ impl Searcher { time_range: Option<TimeRange>, sampling_fraction: Option<f32>, ) -> Result<impl Stream<Item = Result<UniquelyIdentifiedDataset>> + use<'a>> { - let mut ast = query.to_ast()?; - - if let Some(split_compound_nouns) = &*self.split_compound_nouns.load() { - split_compound_nouns.rewrite( - &self.index.schema(), - &default_fields(&self.fields), - &mut ast, - ); - } + let searcher = self.reader.searcher(); - let query = self.parser.build_query_from_user_input_ast(ast)?; + let query = interpret_query( + &self.index, + &self.parser, + &self.split_compound_nouns, + &self.fields, + query.to_ast()?, + )?; let mut queries = Vec::with_capacity(3); queries.push(query); @@ -451,12 +444,10 @@ impl Searcher { queries.into_iter().next().unwrap() }; - let searcher = self.reader.searcher(); + let docs = searcher.search(&query, &AllDocs)?; let mut sampler = Sampler::new(sampling_fraction); - let docs = searcher.search(&query, &AllDocs)?; - Ok(try_stream! 
{ for mut segment in docs { loop { @@ -775,6 +766,32 @@ fn last_term(ast: UserInputAst) -> Option<(String, Option<String>)> { } } +/** Builds the executable query for a parsed user-input AST. If a `SplitCompoundNouns` value is currently loaded, `ast` is first rewritten in place using the index schema and the default fields. The AST is then handed to `parser`, and a parser error is propagated with `?`. When the parser's result is a `BooleanQuery`, its clauses are cloned into a `CustomBooleanQuery`; any other query is returned unchanged. Only the top-level query is inspected: clauses are copied with `box_clone` and not visited recursively. NOTE(review): only the clause list is carried over from the `BooleanQuery`; confirm that no other setting of it needs to be preserved. */ fn interpret_query( + index: &Index, + parser: &QueryParser, + split_compound_nouns: &ArcSwapOption<SplitCompoundNouns>, + fields: &Fields, + mut ast: UserInputAst, +) -> Result<Box<dyn Query>> { + if let Some(split_compound_nouns) = &*split_compound_nouns.load() { + split_compound_nouns.rewrite(&index.schema(), &default_fields(fields), &mut ast); + } + + let mut query = parser.build_query_from_user_input_ast(ast)?; + + /* swap a top-level `BooleanQuery` for the custom variant; other query types fall through untouched */ if let Some(boolean_query) = query.downcast_ref::<BooleanQuery>() { + let clauses = boolean_query + .clauses() + .iter() + .map(|(occur, clause)| (*occur, clause.box_clone())) + .collect::<Vec<_>>(); + + query = Box::new(CustomBooleanQuery::new(clauses)); + } + + Ok(query) +} + struct ScoreTweaker<'a> { origin_weights: &'a [FacetWeight], }