From 9e27da8b4ee79726d5223e69580ec87dcc44968c Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Wed, 28 Oct 2020 11:30:53 +0900 Subject: [PATCH] Added CR comments. Added Unit tests. --- src/collector/docset_collector.rs | 61 ++++++++++++++++++ src/collector/mod.rs | 3 + src/lib.rs | 2 +- src/query/boolean_query/boolean_query.rs | 81 ++++++++++++++++++++++-- src/query/mod.rs | 2 +- src/query/query.rs | 2 + src/schema/term.rs | 1 + 7 files changed, 146 insertions(+), 6 deletions(-) create mode 100644 src/collector/docset_collector.rs diff --git a/src/collector/docset_collector.rs b/src/collector/docset_collector.rs new file mode 100644 index 000000000..26cf7556a --- /dev/null +++ b/src/collector/docset_collector.rs @@ -0,0 +1,61 @@ +use std::collections::HashSet; + +use crate::{DocAddress, DocId, Score}; + +use super::{Collector, SegmentCollector}; + +/// Collectors that returns the set of DocAddress that matches the query. +/// +/// This collector is mostly useful for tests. +pub struct DocSetCollector; + +impl Collector for DocSetCollector { + type Fruit = HashSet; + type Child = DocSetChildCollector; + + fn for_segment( + &self, + segment_local_id: crate::SegmentLocalId, + _segment: &crate::SegmentReader, + ) -> crate::Result { + Ok(DocSetChildCollector { + segment_local_id, + docs: HashSet::new(), + }) + } + + fn requires_scoring(&self) -> bool { + false + } + + fn merge_fruits( + &self, + segment_fruits: Vec<(u32, HashSet)>, + ) -> crate::Result { + let len: usize = segment_fruits.iter().map(|(_, docset)| docset.len()).sum(); + let mut result = HashSet::with_capacity(len); + for (segment_local_id, docs) in segment_fruits { + for doc in docs { + result.insert(DocAddress(segment_local_id, doc)); + } + } + Ok(result) + } +} + +pub struct DocSetChildCollector { + segment_local_id: u32, + docs: HashSet, +} + +impl SegmentCollector for DocSetChildCollector { + type Fruit = (u32, HashSet); + + fn collect(&mut self, doc: crate::DocId, _score: Score) { + self.docs.insert(doc); + } + + fn harvest(self) -> (u32, HashSet) { + (self.segment_local_id, self.docs) + } +} diff --git a/src/collector/mod.rs b/src/collector/mod.rs index b47118007..2533c1950 100644 --- a/src/collector/mod.rs +++ b/src/collector/mod.rs @@ -111,6 +111,9 @@ mod facet_collector; pub use self::facet_collector::FacetCollector; use crate::query::Weight; +mod docset_collector; +pub use self::docset_collector::DocSetCollector; + /// `Fruit` is the type for the result of our collection. /// e.g. `usize` for the `Count` collector. pub trait Fruit: Send + downcast_rs::Downcast {} diff --git a/src/lib.rs b/src/lib.rs index b55cf1d4b..a7a0670a0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -277,7 +277,7 @@ impl DocAddress { /// /// The id used for the segment is actually an ordinal /// in the list of `Segment`s held by a `Searcher`. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct DocAddress(pub SegmentLocalId, pub DocId); #[cfg(test)] diff --git a/src/query/boolean_query/boolean_query.rs b/src/query/boolean_query/boolean_query.rs index 80c7d36a2..8cde29bea 100644 --- a/src/query/boolean_query/boolean_query.rs +++ b/src/query/boolean_query/boolean_query.rs @@ -143,7 +143,7 @@ impl Clone for BooleanQuery { impl From)>> for BooleanQuery { fn from(subqueries: Vec<(Occur, Box)>) -> BooleanQuery { - BooleanQuery { subqueries } + BooleanQuery::new(subqueries) } } @@ -167,7 +167,6 @@ impl Query for BooleanQuery { } impl BooleanQuery { - /// Creates a new boolean query. pub fn new(subqueries: Vec<(Occur, Box)>) -> BooleanQuery { BooleanQuery { subqueries } @@ -176,13 +175,13 @@ impl BooleanQuery { /// Returns the intersection of the queries. pub fn intersection(queries: Vec>) -> BooleanQuery { let subqueries = queries.into_iter().map(|s| (Occur::Must, s)).collect(); - BooleanQuery { subqueries } + BooleanQuery::new(subqueries) } /// Returns the union of the queries. pub fn union(queries: Vec>) -> BooleanQuery { let subqueries = queries.into_iter().map(|s| (Occur::Should, s)).collect(); - BooleanQuery { subqueries } + BooleanQuery::new(subqueries) } /// Helper method to create a boolean query matching a given list of terms. @@ -204,3 +203,77 @@ impl BooleanQuery { &self.subqueries[..] } } + +#[cfg(test)] +mod tests { + use super::BooleanQuery; + use crate::collector::DocSetCollector; + use crate::query::{QueryClone, TermQuery}; + use crate::schema::{IndexRecordOption, Schema, TEXT}; + use crate::{DocAddress, Index, Term}; + + fn create_test_index() -> crate::Result { + let mut schema_builder = Schema::builder(); + let text = schema_builder.add_text_field("text", TEXT); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut writer = index.writer_for_tests().unwrap(); + writer.add_document(doc!(text=>"b c")); + writer.add_document(doc!(text=>"a c")); + writer.add_document(doc!(text=>"a b")); + writer.add_document(doc!(text=>"a d")); + writer.commit()?; + Ok(index) + } + + #[test] + fn test_union() -> crate::Result<()> { + let index = create_test_index()?; + let searcher = index.reader()?.searcher(); + let text = index.schema().get_field("text").unwrap(); + let term_a = TermQuery::new(Term::from_field_text(text, "a"), IndexRecordOption::Basic); + let term_d = TermQuery::new(Term::from_field_text(text, "d"), IndexRecordOption::Basic); + let union_ad = BooleanQuery::union(vec![term_a.box_clone(), term_d.box_clone()]); + let docs = searcher.search(&union_ad, &DocSetCollector)?; + assert_eq!( + docs, + vec![ + DocAddress(0u32, 1u32), + DocAddress(0u32, 2u32), + DocAddress(0u32, 3u32) + ] + .into_iter() + .collect() + ); + Ok(()) + } + + #[test] + fn test_intersection() -> crate::Result<()> { + let index = create_test_index()?; + let searcher = index.reader()?.searcher(); + let text = index.schema().get_field("text").unwrap(); + let term_a = TermQuery::new(Term::from_field_text(text, "a"), IndexRecordOption::Basic); + let term_b = TermQuery::new(Term::from_field_text(text, "b"), IndexRecordOption::Basic); + let term_c = TermQuery::new(Term::from_field_text(text, "c"), IndexRecordOption::Basic); + let intersection_ab = + BooleanQuery::intersection(vec![term_a.box_clone(), term_b.box_clone()]); + let intersection_ac = + BooleanQuery::intersection(vec![term_a.box_clone(), term_c.box_clone()]); + let intersection_bc = + BooleanQuery::intersection(vec![term_b.box_clone(), term_c.box_clone()]); + { + let docs = searcher.search(&intersection_ab, &DocSetCollector)?; + assert_eq!(docs, vec![DocAddress(0u32, 2u32)].into_iter().collect()); + } + { + let docs = searcher.search(&intersection_ac, &DocSetCollector)?; + assert_eq!(docs, vec![DocAddress(0u32, 1u32)].into_iter().collect()); + } + { + let docs = searcher.search(&intersection_bc, &DocSetCollector)?; + assert_eq!(docs, vec![DocAddress(0u32, 0u32)].into_iter().collect()); + } + Ok(()) + } +} diff --git a/src/query/mod.rs b/src/query/mod.rs index dfeeea290..d404c5642 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -46,7 +46,7 @@ pub(crate) use self::fuzzy_query::DFAWrapper; pub use self::fuzzy_query::FuzzyTermQuery; pub use self::intersection::intersect_scorers; pub use self::phrase_query::PhraseQuery; -pub use self::query::Query; +pub use self::query::{Query, QueryClone}; pub use self::query_parser::QueryParser; pub use self::query_parser::QueryParserError; pub use self::range_query::RangeQuery; diff --git a/src/query/query.rs b/src/query/query.rs index 2f561cd2b..25515c1d5 100644 --- a/src/query/query.rs +++ b/src/query/query.rs @@ -71,7 +71,9 @@ pub trait Query: QueryClone + Send + Sync + downcast_rs::Downcast + fmt::Debug { fn query_terms(&self, _term_set: &mut BTreeSet) {} } +/// Implements `box_clone`. pub trait QueryClone { + /// Returns a boxed clone of `self`. fn box_clone(&self) -> Box; } diff --git a/src/schema/term.rs b/src/schema/term.rs index 3bd590818..0662e5230 100644 --- a/src/schema/term.rs +++ b/src/schema/term.rs @@ -96,6 +96,7 @@ impl Term { term } + /// Builds a term bytes. pub fn from_field_bytes(field: Field, bytes: &[u8]) -> Term { let mut term = Term::for_field(field); term.set_bytes(bytes);