diff --git a/src/postings/docset.rs b/src/postings/docset.rs
index 5194c2690..5a722f920 100644
--- a/src/postings/docset.rs
+++ b/src/postings/docset.rs
@@ -95,7 +95,7 @@ pub trait DocSet {
     /// length of the docset.
     fn size_hint(&self) -> u32;
 
-    fn to_doc_bitset(mut self, max_doc: DocId) -> DocBitSet {
+    fn to_doc_bitset(&mut self, max_doc: DocId) -> DocBitSet {
         let mut docs = DocBitSet::with_maxdoc(max_doc);
         while self.advance() {
             let doc = self.doc();
diff --git a/src/postings/segment_postings.rs b/src/postings/segment_postings.rs
index bebaac443..197758c26 100644
--- a/src/postings/segment_postings.rs
+++ b/src/postings/segment_postings.rs
@@ -251,11 +251,11 @@ impl DocSet for SegmentPostings {
         docs[self.cur]
     }
 
-    fn to_doc_bitset(mut self, max_doc: DocId) -> DocBitSet {
+    fn to_doc_bitset(&mut self, max_doc: DocId) -> DocBitSet {
         // finish the current block
         let mut docs = DocBitSet::with_maxdoc(max_doc);
         if self.advance() {
-            for &doc in self.block_cursor.docs()[self.cur..] {
+            for &doc in &self.block_cursor.docs()[self.cur..] {
                 docs.insert(doc);
             }
             // ... iterate through the remaining blocks.
diff --git a/src/query/mod.rs b/src/query/mod.rs
index aafeb0a7c..f22ef4ddb 100644
--- a/src/query/mod.rs
+++ b/src/query/mod.rs
@@ -13,6 +13,8 @@ mod query_parser;
 mod phrase_query;
 mod all_query;
 mod bitset;
+mod range_query;
+
 
 pub use self::bitset::BitSetDocSet;
 pub use self::boolean_query::BooleanQuery;
@@ -27,3 +29,6 @@ pub use self::scorer::Scorer;
 pub use self::term_query::TermQuery;
 pub use self::weight::Weight;
 pub use self::all_query::{AllQuery, AllWeight, AllScorer};
+pub use self::range_query::{RangeQuery, RangeWeight};
+pub use self::scorer::ConstScorer;
+
diff --git a/src/query/range_query.rs b/src/query/range_query.rs
new file mode 100644
index 000000000..d4727ff3b
--- /dev/null
+++ b/src/query/range_query.rs
@@ -0,0 +1,216 @@
+use schema::{Field, Term, IndexRecordOption};
+use query::{Query, Weight, Scorer};
+use termdict::{TermDictionary, TermStreamer, TermStreamerBuilder};
+use core::SegmentReader;
+use common::DocBitSet;
+use Result;
+use std::any::Any;
+use core::Searcher;
+use query::BitSetDocSet;
+use query::ConstScorer;
+
+/// One end of a term range, holding the raw value bytes of the bounding term.
+#[derive(Clone, Debug)]
+enum Boundary {
+    Included(Vec<u8>),
+    Excluded(Vec<u8>),
+    Unbounded,
+}
+
+/// Range of terms within a single field, delimited by a left and a right bound.
+#[derive(Clone, Debug)]
+pub struct RangeDefinition {
+    field: Field,
+    left_bound: Boundary,
+    right_bound: Boundary
+}
+
+impl RangeDefinition {
+    /// Creates a range over `field` that is unbounded on both sides.
+    fn for_field(field: Field) -> RangeDefinition{
+        RangeDefinition {
+            field,
+            left_bound: Boundary::Unbounded,
+            right_bound: Boundary::Unbounded
+        }
+    }
+
+    /// Sets an inclusive lower bound. Panics if `left` belongs to another field.
+    fn left_included(mut self, left: Term) -> RangeDefinition {
+        assert_eq!(left.field(), self.field);
+        self.left_bound = Boundary::Included(left.value_bytes().to_owned());
+        self
+    }
+
+    /// Sets an exclusive lower bound. Panics if `left` belongs to another field.
+    fn left_excluded(mut self, left: Term) -> RangeDefinition {
+        assert_eq!(left.field(), self.field);
+        self.left_bound = Boundary::Excluded(left.value_bytes().to_owned());
+        self
+    }
+
+    /// Sets an inclusive upper bound. Panics if `right` belongs to another field.
+    fn right_included(mut self, right: Term) -> RangeDefinition {
+        assert_eq!(right.field(), self.field);
+        self.right_bound = Boundary::Included(right.value_bytes().to_owned());
+        self
+    }
+
+    /// Sets an exclusive upper bound. Panics if `right` belongs to another field.
+    fn right_excluded(mut self, right: Term) -> RangeDefinition {
+        assert_eq!(right.field(), self.field);
+        self.right_bound = Boundary::Excluded(right.value_bytes().to_owned());
+        self
+    }
+
+    /// Streams the terms of `term_dict` that fall within this range.
+    fn term_range<'a, T>(&self, term_dict: &'a T) -> T::Streamer
+        where T: TermDictionary<'a> + 'a
+    {
+        use self::Boundary::*;
+        let mut term_stream_builder = term_dict.range();
+        term_stream_builder =
+            match &self.left_bound {
+                &Included(ref term_val) => term_stream_builder.ge(term_val),
+                &Excluded(ref term_val) => term_stream_builder.gt(term_val),
+                &Unbounded => term_stream_builder
+            };
+        term_stream_builder =
+            match &self.right_bound {
+                &Included(ref term_val) => term_stream_builder.le(term_val),
+                &Excluded(ref term_val) => term_stream_builder.lt(term_val),
+                &Unbounded => term_stream_builder
+            };
+        term_stream_builder.into_stream()
+    }
+}
+
+/// Query matching the documents that contain at least one term of a range.
+#[derive(Debug)]
+pub struct RangeQuery {
+    range_definition: RangeDefinition
+}
+
+impl RangeQuery {
+    fn new(range_definition: RangeDefinition) -> RangeQuery {
+        RangeQuery {
+            range_definition
+        }
+    }
+}
+
+impl Query for RangeQuery {
+    fn as_any(&self) -> &Any {
+        self
+    }
+
+    fn weight(&self, _searcher: &Searcher) -> Result<Box<Weight>> {
+        Ok(box RangeWeight {
+            range_definition: self.range_definition.clone()
+        })
+    }
+}
+
+
+/// `Weight` produced by `RangeQuery::weight`.
+pub struct RangeWeight {
+    range_definition: RangeDefinition
+}
+
+impl Weight for RangeWeight {
+    fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result<Box<Scorer + 'a>> {
+        let max_doc = reader.max_doc();
+        let mut doc_bitset = DocBitSet::with_maxdoc(max_doc);
+
+        let inverted_index = reader.inverted_index(self.range_definition.field);
+        let term_dict = inverted_index.terms();
+        let mut term_range = self.range_definition.term_range(term_dict);
+        // Union the postings of every term in the range into a single bitset.
+        while term_range.advance() {
+            let term_info = term_range.value();
+            let mut block_segment_postings = inverted_index.read_block_postings_from_terminfo(term_info,IndexRecordOption::Basic);
+            while block_segment_postings.advance() {
+                for &doc in block_segment_postings.docs() {
+                    doc_bitset.insert(doc);
+                }
+            }
+        }
+        let doc_bitset = BitSetDocSet::from(doc_bitset);
+        Ok(box ConstScorer::new(doc_bitset))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+
+    use Index;
+    use schema::{SchemaBuilder, Field, Document, INT_INDEXED};
+
+
+    #[test]
+    fn test_range_query() {
+        let int_field: Field;
+        let schema = {
+            let mut schema_builder = SchemaBuilder::new();
+            int_field = schema_builder.add_i64_field("intfield", INT_INDEXED);
+            schema_builder.build()
+        };
+
+        let index = Index::create_in_ram(schema);
+        {
+            let mut index_writer = index.writer_with_num_threads(2, 6_000_000).unwrap();
+
+            for i in 1..100 {
+                let mut doc = Document::new();
+                for j in 1..100 {
+                    if i % j == 0 {
+                        doc.add_i64(int_field, j as i64);
+                    }
+                }
+                index_writer.add_document(doc);
+            }
+
+            index_writer.commit().unwrap();
+        }
+        index.load_searchers().unwrap();
+        let searcher = index.searcher();
+        use collector::CountCollector;
+        use schema::Term;
+        use query::Query;
+        use super::{RangeQuery, RangeDefinition};
+
+        let count_multiples = |range: RangeDefinition| {
+            let mut count_collector = CountCollector::default();
+            let range_query = RangeQuery::new(range);
+            range_query.search(&*searcher, &mut count_collector).unwrap();
+            count_collector.count()
+        };
+
+        assert_eq!(
+            count_multiples(RangeDefinition::for_field(int_field)
+                .left_included(Term::from_field_i64(int_field, 10))
+                .right_excluded(Term::from_field_i64(int_field, 11)))
+            , 9
+        );
+        assert_eq!(
+            count_multiples(RangeDefinition::for_field(int_field)
+                .left_included(Term::from_field_i64(int_field, 10))
+                .right_included(Term::from_field_i64(int_field, 11)))
+            , 18
+        );
+        assert_eq!(
+            count_multiples(RangeDefinition::for_field(int_field)
+                .left_excluded(Term::from_field_i64(int_field, 9))
+                .right_included(Term::from_field_i64(int_field, 10)))
+            , 9
+        );
+        assert_eq!(
+            count_multiples(RangeDefinition::for_field(int_field)
+                .left_excluded(Term::from_field_i64(int_field, 9)))
+            , 90
+        );
+
+    }
+
+
+}
\ No newline at end of file
diff --git a/src/query/scorer.rs b/src/query/scorer.rs
index 04bd13619..22783967a 100644
--- a/src/query/scorer.rs
+++ b/src/query/scorer.rs
@@ -2,6 +2,8 @@ use DocSet;
 use DocId;
 use Score;
 use collector::Collector;
+use postings::SkipResult;
+use common::DocBitSet;
 use std::ops::{Deref, DerefMut};
 
 /// Scored set of documents matching a query within a specific segment.
@@ -59,3 +61,60 @@ impl Scorer for EmptyScorer {
         0f32
     }
 }
+
+/// Wraps a `DocSet` into a `Scorer` that gives every document the same score.
+///
+/// The score starts at `1f32` and can be changed with `set_score`.
+pub struct ConstScorer<TDocSet: DocSet> {
+    docset: TDocSet,
+    score: Score
+}
+
+impl<TDocSet: DocSet> ConstScorer<TDocSet> {
+    /// Creates a `ConstScorer` over `docset` with a score of `1f32`.
+    pub fn new(docset: TDocSet) -> ConstScorer<TDocSet> {
+        ConstScorer {
+            docset,
+            score: 1f32
+        }
+    }
+
+    /// Sets the score returned for every document.
+    pub fn set_score(&mut self, score: Score) {
+        self.score = score;
+    }
+}
+
+impl<TDocSet: DocSet> DocSet for ConstScorer<TDocSet> {
+    fn advance(&mut self) -> bool {
+        self.docset.advance()
+    }
+
+    fn skip_next(&mut self, target: DocId) -> SkipResult {
+        self.docset.skip_next(target)
+    }
+
+    fn fill_buffer(&mut self, buffer: &mut [DocId]) -> usize {
+        self.docset.fill_buffer(buffer)
+    }
+
+    fn doc(&self) -> DocId {
+        self.docset.doc()
+    }
+
+    fn size_hint(&self) -> u32 {
+        self.docset.size_hint()
+    }
+
+    fn to_doc_bitset(&mut self, max_doc: DocId) -> DocBitSet {
+        self.docset.to_doc_bitset(max_doc)
+    }
+}
+
+
+impl<TDocSet: DocSet> Scorer for ConstScorer<TDocSet> {
+    /// Returns the configured constant score rather than a hard-coded value.
+    fn score(&self) -> Score {
+        self.score
+    }
+}