Issue/range query (#242)

BitSet and RangeQuery
This commit is contained in:
Paul Masurel
2018-02-05 09:33:25 +09:00
committed by GitHub
parent 6a104e4f69
commit 1fc7afa90a
23 changed files with 1247 additions and 90 deletions

View File

@@ -59,8 +59,8 @@ impl DocSet for AllScorer {
self.doc
}
fn size_hint(&self) -> usize {
self.max_doc as usize
fn size_hint(&self) -> u32 {
self.max_doc
}
}

268
src/query/bitset/mod.rs Normal file
View File

@@ -0,0 +1,268 @@
use common::{BitSet, TinySet};
use DocId;
use postings::DocSet;
use postings::SkipResult;
use std::cmp::Ordering;
/// A `BitSetDocSet` makes it possible to iterate through a bitset as if it was a `DocSet`.
///
/// # Implementation detail
///
/// Skipping is relatively fast here as we can directly point to the
/// right tiny bitset bucket.
///
/// TODO: Consider implementing a `BitTreeSet` in order to advance faster
/// when the bitset is sparse
pub struct BitSetDocSet {
docs: BitSet,
cursor_bucket: u32, //< index associated to the current tiny bitset
cursor_tinybitset: TinySet,
doc: u32,
}
impl BitSetDocSet {
fn go_to_bucket(&mut self, bucket_addr: u32) {
self.cursor_bucket = bucket_addr;
self.cursor_tinybitset = self.docs.tinyset(bucket_addr);
}
}
impl From<BitSet> for BitSetDocSet {
fn from(docs: BitSet) -> BitSetDocSet {
let first_tiny_bitset = if docs.max_value() == 0 {
TinySet::empty()
} else {
docs.tinyset(0)
};
BitSetDocSet {
docs,
cursor_bucket: 0,
cursor_tinybitset: first_tiny_bitset,
doc: 0u32,
}
}
}
impl DocSet for BitSetDocSet {
fn advance(&mut self) -> bool {
if let Some(lower) = self.cursor_tinybitset.pop_lowest() {
self.doc = (self.cursor_bucket as u32 * 64u32) | lower;
return true;
}
if let Some(cursor_bucket) = self.docs.first_non_empty_bucket(self.cursor_bucket + 1) {
self.go_to_bucket(cursor_bucket);
let lower = self.cursor_tinybitset.pop_lowest().unwrap();
self.doc = (cursor_bucket * 64u32) | lower;
true
} else {
false
}
}
fn skip_next(&mut self, target: DocId) -> SkipResult {
// skip is required to advance.
if !self.advance() {
return SkipResult::End;
}
let target_bucket = target / 64u32;
// Mask for all of the bits greater or equal
// to our target document.
match target_bucket.cmp(&self.cursor_bucket) {
Ordering::Greater => {
self.go_to_bucket(target_bucket);
let greater_filter: TinySet = TinySet::range_greater_or_equal(target);
self.cursor_tinybitset = self.cursor_tinybitset.intersect(greater_filter);
if !self.advance() {
SkipResult::End
} else {
if self.doc() == target {
SkipResult::Reached
} else {
debug_assert!(self.doc() > target);
SkipResult::OverStep
}
}
}
Ordering::Equal => loop {
match self.doc().cmp(&target) {
Ordering::Less => {
if !self.advance() {
return SkipResult::End;
}
}
Ordering::Equal => {
return SkipResult::Reached;
}
Ordering::Greater => {
debug_assert!(self.doc() > target);
return SkipResult::OverStep;
}
}
},
Ordering::Less => {
debug_assert!(self.doc() > target);
SkipResult::OverStep
}
}
}
/// Returns the current document
fn doc(&self) -> DocId {
self.doc
}
/// Advances the cursor to the next document
/// None is returned if the iterator has `DocSet`
/// has already been entirely consumed.
fn next(&mut self) -> Option<DocId> {
if self.advance() {
Some(self.doc())
} else {
None
}
}
/// Returns half of the `max_doc`
/// This is quite a terrible heuristic,
/// but we don't have access to any better
/// value.
fn size_hint(&self) -> u32 {
self.docs.len() as u32
}
}
#[cfg(test)]
mod tests {
use DocId;
use common::BitSet;
use postings::{DocSet, SkipResult};
use super::BitSetDocSet;
extern crate test;
fn create_docbitset(docs: &[DocId], max_doc: DocId) -> BitSetDocSet {
let mut docset = BitSet::with_max_value(max_doc);
for &doc in docs {
docset.insert(doc);
}
BitSetDocSet::from(docset)
}
fn test_go_through_sequential(docs: &[DocId]) {
let mut docset = create_docbitset(docs, 1_000u32);
for &doc in docs {
assert!(docset.advance());
assert_eq!(doc, docset.doc());
}
assert!(!docset.advance());
assert!(!docset.advance());
}
#[test]
fn test_docbitset_sequential() {
test_go_through_sequential(&[]);
test_go_through_sequential(&[1, 2, 3]);
test_go_through_sequential(&[1, 2, 3, 4, 5, 63, 64, 65]);
test_go_through_sequential(&[63, 64, 65]);
test_go_through_sequential(&[1, 2, 3, 4, 95, 96, 97, 98, 99]);
}
#[test]
fn test_docbitset_skip() {
{
let mut docset = create_docbitset(&[1, 5, 6, 7, 5112], 10_000);
assert_eq!(docset.skip_next(7), SkipResult::Reached);
assert_eq!(docset.doc(), 7);
assert!(docset.advance(), 7);
assert_eq!(docset.doc(), 5112);
assert!(!docset.advance());
}
{
let mut docset = create_docbitset(&[1, 5, 6, 7, 5112], 10_000);
assert_eq!(docset.skip_next(3), SkipResult::OverStep);
assert_eq!(docset.doc(), 5);
assert!(docset.advance());
}
{
let mut docset = create_docbitset(&[5112], 10_000);
assert_eq!(docset.skip_next(5112), SkipResult::Reached);
assert_eq!(docset.doc(), 5112);
assert!(!docset.advance());
}
{
let mut docset = create_docbitset(&[5112], 10_000);
assert_eq!(docset.skip_next(5113), SkipResult::End);
assert!(!docset.advance());
}
{
let mut docset = create_docbitset(&[5112], 10_000);
assert_eq!(docset.skip_next(5111), SkipResult::OverStep);
assert_eq!(docset.doc(), 5112);
assert!(!docset.advance());
}
{
let mut docset = create_docbitset(&[1, 5, 6, 7, 5112, 5500, 6666], 10_000);
assert_eq!(docset.skip_next(5112), SkipResult::Reached);
assert_eq!(docset.doc(), 5112);
assert!(docset.advance());
assert_eq!(docset.doc(), 5500);
assert!(docset.advance());
assert_eq!(docset.doc(), 6666);
assert!(!docset.advance());
}
{
let mut docset = create_docbitset(&[1, 5, 6, 7, 5112, 5500, 6666], 10_000);
assert_eq!(docset.skip_next(5111), SkipResult::OverStep);
assert_eq!(docset.doc(), 5112);
assert!(docset.advance());
assert_eq!(docset.doc(), 5500);
assert!(docset.advance());
assert_eq!(docset.doc(), 6666);
assert!(!docset.advance());
}
{
let mut docset = create_docbitset(&[1, 5, 6, 7, 5112, 5513, 6666], 10_000);
assert_eq!(docset.skip_next(5111), SkipResult::OverStep);
assert_eq!(docset.doc(), 5112);
assert!(docset.advance());
assert_eq!(docset.doc(), 5513);
assert!(docset.advance());
assert_eq!(docset.doc(), 6666);
assert!(!docset.advance());
}
}
#[bench]
fn bench_bitset_1pct_insert(b: &mut test::Bencher) {
use tests;
let els = tests::generate_nonunique_unsorted(1_000_000u32, 10_000);
b.iter(|| {
let mut bitset = BitSet::with_max_value(1_000_000);
for el in els.iter().cloned() { bitset.insert(el); }
});
}
#[bench]
fn bench_bitset_1pct_clone(b: &mut test::Bencher) {
use tests;
let els = tests::generate_nonunique_unsorted(1_000_000u32, 10_000);
let mut bitset = BitSet::with_max_value(1_000_000);
for el in els { bitset.insert(el); }
b.iter(|| { bitset.clone() });
}
#[bench]
fn bench_bitset_1pct_clone_iterate(b: &mut test::Bencher) {
use tests;
use DocSet;
let els = tests::generate_nonunique_unsorted(1_000_000u32, 10_000);
let mut bitset = BitSet::with_max_value(1_000_000);
for el in els { bitset.insert(el); }
b.iter(|| {
let mut docset = BitSetDocSet::from(bitset.clone());
while docset.advance() {}
});
}
}

View File

@@ -8,7 +8,6 @@ use schema::Term;
use query::TermQuery;
use schema::IndexRecordOption;
use query::Occur;
use query::OccurFilter;
/// The boolean query combines a set of queries
///
@@ -39,14 +38,9 @@ impl Query for BooleanQuery {
fn weight(&self, searcher: &Searcher) -> Result<Box<Weight>> {
let sub_weights = self.subqueries
.iter()
.map(|&(ref _occur, ref subquery)| subquery.weight(searcher))
.map(|&(ref occur, ref subquery)| Ok((*occur, subquery.weight(searcher)?)))
.collect::<Result<_>>()?;
let occurs: Vec<Occur> = self.subqueries
.iter()
.map(|&(ref occur, ref _subquery)| *occur)
.collect();
let filter = OccurFilter::new(&occurs);
Ok(box BooleanWeight::new(sub_weights, filter))
Ok(box BooleanWeight::new(sub_weights))
}
}

View File

@@ -90,7 +90,7 @@ impl<TScorer: Scorer> BooleanScorer<TScorer> {
}
impl<TScorer: Scorer> DocSet for BooleanScorer<TScorer> {
fn size_hint(&self) -> usize {
fn size_hint(&self) -> u32 {
// TODO fix this. it should be the min
// of the MUST scorer
// and the max of the SHOULD scorers.

View File

@@ -1,31 +1,43 @@
use query::Weight;
use core::SegmentReader;
use query::EmptyScorer;
use query::Scorer;
use super::BooleanScorer;
use query::OccurFilter;
use query::Occur;
use Result;
pub struct BooleanWeight {
weights: Vec<Box<Weight>>,
occur_filter: OccurFilter,
weights: Vec<(Occur, Box<Weight>)>,
}
impl BooleanWeight {
pub fn new(weights: Vec<Box<Weight>>, occur_filter: OccurFilter) -> BooleanWeight {
BooleanWeight {
weights,
occur_filter,
}
pub fn new(weights: Vec<(Occur, Box<Weight>)>) -> BooleanWeight {
BooleanWeight { weights }
}
}
impl Weight for BooleanWeight {
fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result<Box<Scorer + 'a>> {
let sub_scorers: Vec<Box<Scorer + 'a>> = self.weights
.iter()
.map(|weight| weight.scorer(reader))
.collect::<Result<_>>()?;
let boolean_scorer = BooleanScorer::new(sub_scorers, self.occur_filter);
Ok(box boolean_scorer)
if self.weights.is_empty() {
Ok(box EmptyScorer)
} else if self.weights.len() == 1 {
let &(occur, ref weight) = &self.weights[0];
if occur == Occur::MustNot {
Ok(box EmptyScorer)
} else {
weight.scorer(reader)
}
} else {
let sub_scorers: Vec<Box<Scorer + 'a>> = self.weights
.iter()
.map(|&(_, ref weight)| weight)
.map(|weight| weight.scorer(reader))
.collect::<Result<_>>()?;
let occurs: Vec<Occur> = self.weights.iter().map(|&(ref occur, _)| *occur).collect();
let occur_filter = OccurFilter::new(&occurs);
let boolean_scorer = BooleanScorer::new(sub_scorers, occur_filter);
Ok(box boolean_scorer)
}
}
}

View File

@@ -12,7 +12,10 @@ mod term_query;
mod query_parser;
mod phrase_query;
mod all_query;
mod bitset;
mod range_query;
pub use self::bitset::BitSetDocSet;
pub use self::boolean_query::BooleanQuery;
pub use self::occur_filter::OccurFilter;
pub use self::occur::Occur;
@@ -24,4 +27,7 @@ pub use self::scorer::EmptyScorer;
pub use self::scorer::Scorer;
pub use self::term_query::TermQuery;
pub use self::weight::Weight;
pub use self::all_query::{AllQuery, AllScorer, AllWeight};
pub use self::range_query::RangeQuery;
pub use self::scorer::ConstScorer;

View File

@@ -35,7 +35,7 @@ impl DocSet for PostingsWithOffset {
self.segment_postings.doc()
}
fn size_hint(&self) -> usize {
fn size_hint(&self) -> u32 {
self.segment_postings.size_hint()
}
@@ -122,7 +122,7 @@ impl DocSet for PhraseScorer {
self.intersection_docset.doc()
}
fn size_hint(&self) -> usize {
fn size_hint(&self) -> u32 {
self.intersection_docset.size_hint()
}
}

292
src/query/range_query.rs Normal file
View File

@@ -0,0 +1,292 @@
use schema::{Field, IndexRecordOption, Term};
use query::{Query, Scorer, Weight};
use termdict::{TermDictionary, TermStreamer, TermStreamerBuilder};
use core::SegmentReader;
use common::BitSet;
use Result;
use std::any::Any;
use core::Searcher;
use query::BitSetDocSet;
use query::ConstScorer;
use std::collections::Bound;
use std::collections::range::RangeArgument;
fn map_bound<TFrom, Transform: Fn(TFrom)->Vec<u8> >(bound: Bound<TFrom>, transform: &Transform) -> Bound<Vec<u8>> {
use self::Bound::*;
match bound {
Excluded(from_val) => Excluded(transform(from_val)),
Included(from_val) => Included(transform(from_val)),
Unbounded => Unbounded
}
}
/// `RangeQuery` match all documents that have at least one term within a defined range.
///
/// Matched document will all get a constant `Score` of one.
///
/// # Implementation
///
/// The current implement will iterate over the terms within the range
/// and append all of the document cross into a `BitSet`.
///
/// # Example
///
/// ```rust
///
/// # #[macro_use]
/// # extern crate tantivy;
/// # use tantivy::Index;
/// # use tantivy::schema::{SchemaBuilder, INT_INDEXED};
/// # use tantivy::collector::CountCollector;
/// # use tantivy::query::Query;
/// # use tantivy::Result;
/// # use tantivy::query::RangeQuery;
/// #
/// # fn run() -> Result<()> {
/// # let mut schema_builder = SchemaBuilder::new();
/// # let year_field = schema_builder.add_u64_field("year", INT_INDEXED);
/// # let schema = schema_builder.build();
/// #
/// # let index = Index::create_in_ram(schema);
/// # {
/// # let mut index_writer = index.writer_with_num_threads(1, 6_000_000).unwrap();
/// # for year in 1950u64..2017u64 {
/// # let num_docs_within_year = 10 + (year - 1950) * (year - 1950);
/// # for _ in 0..num_docs_within_year {
/// # index_writer.add_document(doc!(year_field => year));
/// # }
/// # }
/// # index_writer.commit().unwrap();
/// # }
/// # index.load_searchers()?;
/// let searcher = index.searcher();
///
/// let docs_in_the_sixties = RangeQuery::new_u64(year_field, 1960..1970);
///
/// // ... or `1960..=1969` if inclusive range is enabled.
/// let mut count_collector = CountCollector::default();
/// docs_in_the_sixties.search(&*searcher, &mut count_collector)?;
///
/// let num_60s_books = count_collector.count();
///
/// # assert_eq!(num_60s_books, 2285);
/// # Ok(())
/// # }
/// #
/// # fn main() {
/// # run().unwrap()
/// # }
/// ```
#[derive(Debug)]
pub struct RangeQuery {
field: Field,
left_bound: Bound<Vec<u8>>,
right_bound: Bound<Vec<u8>>,
}
impl RangeQuery {
/// Create a new `RangeQuery` over a `i64` field.
pub fn new_i64<TRangeArgument: RangeArgument<i64>>(field: Field, range: TRangeArgument) -> RangeQuery {
let make_term_val = |val: &i64| {
Term::from_field_i64(field, *val).value_bytes().to_owned()
};
RangeQuery {
field,
left_bound: map_bound(range.start(), &make_term_val),
right_bound: map_bound(range.end(), &make_term_val)
}
}
/// Create a new `RangeQuery` over a `u64` field.
pub fn new_u64<TRangeArgument: RangeArgument<u64>>(field: Field, range: TRangeArgument) -> RangeQuery {
let make_term_val = |val: &u64| {
Term::from_field_u64(field, *val).value_bytes().to_owned()
};
RangeQuery {
field,
left_bound: map_bound(range.start(), &make_term_val),
right_bound: map_bound(range.end(), &make_term_val)
}
}
/// Create a new `RangeQuery` over a `Str` field.
pub fn new_str<'b, TRangeArgument: RangeArgument<&'b str>>(field: Field, range: TRangeArgument) -> RangeQuery {
let make_term_val = |val: &&str| {
val.as_bytes().to_vec()
};
RangeQuery {
field,
left_bound: map_bound(range.start(), &make_term_val),
right_bound: map_bound(range.end(), &make_term_val)
}
}
}
impl Query for RangeQuery {
fn as_any(&self) -> &Any {
self
}
fn weight(&self, _searcher: &Searcher) -> Result<Box<Weight>> {
Ok(box RangeWeight {
field: self.field,
left_bound: self.left_bound.clone(),
right_bound: self.right_bound.clone()
})
}
}
pub struct RangeWeight {
field: Field,
left_bound: Bound<Vec<u8>>,
right_bound: Bound<Vec<u8>>,
}
impl RangeWeight {
fn term_range<'a, T>(&self, term_dict: &'a T) -> T::Streamer
where
T: TermDictionary<'a> + 'a,
{
use std::collections::Bound::*;
let mut term_stream_builder = term_dict.range();
term_stream_builder = match &self.left_bound {
&Included(ref term_val) => term_stream_builder.ge(term_val),
&Excluded(ref term_val) => term_stream_builder.gt(term_val),
&Unbounded => term_stream_builder,
};
term_stream_builder = match &self.right_bound {
&Included(ref term_val) => term_stream_builder.le(term_val),
&Excluded(ref term_val) => term_stream_builder.lt(term_val),
&Unbounded => term_stream_builder,
};
term_stream_builder.into_stream()
}
}
impl Weight for RangeWeight {
fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result<Box<Scorer + 'a>> {
let max_doc = reader.max_doc();
let mut doc_bitset = BitSet::with_max_value(max_doc);
let inverted_index = reader.inverted_index(self.field);
let term_dict = inverted_index.terms();
let mut term_range = self.term_range(term_dict);
while term_range.advance() {
let term_info = term_range.value();
let mut block_segment_postings = inverted_index
.read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic);
while block_segment_postings.advance() {
for &doc in block_segment_postings.docs() {
doc_bitset.insert(doc);
}
}
}
let doc_bitset = BitSetDocSet::from(doc_bitset);
Ok(box ConstScorer::new(doc_bitset))
}
}
#[cfg(test)]
mod tests {
use Index;
use schema::{Document, Field, SchemaBuilder, INT_INDEXED};
use collector::CountCollector;
use std::collections::Bound;
use query::Query;
use Result;
use super::RangeQuery;
#[test]
fn test_range_query_simple() {
fn run() -> Result<()> {
let mut schema_builder = SchemaBuilder::new();
let year_field= schema_builder.add_u64_field("year", INT_INDEXED);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
{
let mut index_writer = index.writer_with_num_threads(1, 6_000_000).unwrap();
for year in 1950u64..2017u64 {
let num_docs_within_year = 10 + (year - 1950) * (year - 1950);
for _ in 0..num_docs_within_year {
index_writer.add_document(doc!(year_field => year));
}
}
index_writer.commit().unwrap();
}
index.load_searchers().unwrap();
let searcher = index.searcher();
let docs_in_the_sixties = RangeQuery::new_u64(year_field, 1960u64..1970u64);
// ... or `1960..=1969` if inclusive range is enabled.
let mut count_collector = CountCollector::default();
docs_in_the_sixties.search(&*searcher, &mut count_collector)?;
assert_eq!(count_collector.count(), 2285);
Ok(())
}
run().unwrap();
}
#[test]
fn test_range_query() {
let int_field: Field;
let schema = {
let mut schema_builder = SchemaBuilder::new();
int_field = schema_builder.add_i64_field("intfield", INT_INDEXED);
schema_builder.build()
};
let index = Index::create_in_ram(schema);
{
let mut index_writer = index.writer_with_num_threads(2, 6_000_000).unwrap();
for i in 1..100 {
let mut doc = Document::new();
for j in 1..100 {
if i % j == 0 {
doc.add_i64(int_field, j as i64);
}
}
index_writer.add_document(doc);
}
index_writer.commit().unwrap();
}
index.load_searchers().unwrap();
let searcher = index.searcher();
let count_multiples = |range_query: RangeQuery| {
let mut count_collector = CountCollector::default();
range_query
.search(&*searcher, &mut count_collector)
.unwrap();
count_collector.count()
};
assert_eq!(
count_multiples(RangeQuery::new_i64(int_field, 10..11)),
9
);
assert_eq!(
count_multiples(RangeQuery::new_i64(int_field, (Bound::Included(10), Bound::Included(11)) )),
18
);
assert_eq!(
count_multiples(RangeQuery::new_i64(int_field, (Bound::Excluded(9), Bound::Included(10)))),
9
);
assert_eq!(
count_multiples(RangeQuery::new_i64(int_field, 9..)),
91
);
}
}

View File

@@ -2,6 +2,8 @@ use DocSet;
use DocId;
use Score;
use collector::Collector;
use postings::SkipResult;
use common::BitSet;
use std::ops::{Deref, DerefMut};
/// Scored set of documents matching a query within a specific segment.
@@ -49,7 +51,7 @@ impl DocSet for EmptyScorer {
DocId::max_value()
}
fn size_hint(&self) -> usize {
fn size_hint(&self) -> u32 {
0
}
}
@@ -59,3 +61,63 @@ impl Scorer for EmptyScorer {
0f32
}
}
/// Wraps a `DocSet` and simply returns a constant `Scorer`.
/// The `ConstScorer` is useful if you have a `DocSet` where
/// you needed a scorer.
///
/// The `ConstScorer`'s constant score can be set
/// by calling `.set_score(...)`.
pub struct ConstScorer<TDocSet: DocSet> {
docset: TDocSet,
score: Score,
}
impl<TDocSet: DocSet> ConstScorer<TDocSet> {
/// Creates a new `ConstScorer`.
pub fn new(docset: TDocSet) -> ConstScorer<TDocSet> {
ConstScorer {
docset,
score: 1f32,
}
}
/// Sets the constant score to a different value.
pub fn set_score(&mut self, score: Score) {
self.score = score;
}
}
impl<TDocSet: DocSet> DocSet for ConstScorer<TDocSet> {
fn advance(&mut self) -> bool {
self.docset.advance()
}
fn skip_next(&mut self, target: DocId) -> SkipResult {
self.docset.skip_next(target)
}
fn fill_buffer(&mut self, buffer: &mut [DocId]) -> usize {
self.docset.fill_buffer(buffer)
}
fn doc(&self) -> DocId {
self.docset.doc()
}
fn size_hint(&self) -> u32 {
self.docset.size_hint()
}
fn append_to_bitset(&mut self, bitset: &mut BitSet) {
self.docset.append_to_bitset(bitset);
}
}
impl<TDocSet: DocSet> Scorer for ConstScorer<TDocSet> {
fn score(&self) -> Score {
1f32
}
}

View File

@@ -37,7 +37,7 @@ where
self.postings.doc()
}
fn size_hint(&self) -> usize {
fn size_hint(&self) -> u32 {
self.postings.size_hint()
}