mirror of
https://github.com/quickwit-oss/tantivy.git
synced 2026-06-02 00:20:42 +00:00
Top collector (#413)
* Make TopCollector generic Make TopCollector take a generic type instead of only being tied to score. This will allow for sharing code between a TopCollector that sorts results by Score and a TopCollector that sorts documents by a fast field. This commit makes no functional changes to TopCollector. * Add TopFieldCollector and TopScoreCollector Create two new collectors that use the refactored TopCollector. TopFieldCollector has the same functionality that TopCollector originally had. TopFieldCollector allows for sorting results by a given fast field. Closes tantivy-search/tantivy#388 * Make TopCollector private Make TopCollector package private and export TopFieldCollector as TopCollector to maintain backwards compatibility. Mark TopCollector as deprecated to encourage use of the non-aliased TopFieldCollector. Remove Collector implementation for TopCollector since it is not longer used.
This commit is contained in:
@@ -15,7 +15,14 @@ mod multi_collector;
|
||||
pub use self::multi_collector::MultiCollector;
|
||||
|
||||
mod top_collector;
|
||||
pub use self::top_collector::TopCollector;
|
||||
|
||||
mod top_score_collector;
|
||||
pub use self::top_score_collector::TopScoreCollector;
|
||||
#[deprecated]
|
||||
pub use self::top_score_collector::TopScoreCollector as TopCollector;
|
||||
|
||||
mod top_field_collector;
|
||||
pub use self::top_field_collector::TopFieldCollector;
|
||||
|
||||
mod facet_collector;
|
||||
pub use self::facet_collector::FacetCollector;
|
||||
|
||||
@@ -100,11 +100,11 @@ impl<'a> Collector for MultiCollector<'a> {
|
||||
mod tests {
|
||||
|
||||
use super::*;
|
||||
use collector::{Collector, CountCollector, TopCollector};
|
||||
use collector::{Collector, CountCollector, TopScoreCollector};
|
||||
|
||||
#[test]
|
||||
fn test_multi_collector() {
|
||||
let mut top_collector = TopCollector::with_limit(2);
|
||||
let mut top_collector = TopScoreCollector::with_limit(2);
|
||||
let mut count_collector = CountCollector::default();
|
||||
{
|
||||
let mut collectors =
|
||||
|
||||
@@ -1,115 +1,61 @@
|
||||
use super::Collector;
|
||||
use std::cmp::Ordering;
|
||||
use std::collections::BinaryHeap;
|
||||
use DocAddress;
|
||||
use DocId;
|
||||
use Result;
|
||||
use Score;
|
||||
use SegmentLocalId;
|
||||
use SegmentReader;
|
||||
use std::cmp::Ordering;
|
||||
use std::collections::BinaryHeap;
|
||||
|
||||
// Rust heap is a max-heap and we need a min heap.
|
||||
/// Contains a feature (field, score, etc.) of a document along with the document address.
|
||||
///
|
||||
/// It has a custom implementation of `PartialOrd` that reverses the order. This is because the
|
||||
/// default Rust heap is a max heap, whereas a min heap is needed.
|
||||
#[derive(Clone, Copy)]
|
||||
struct GlobalScoredDoc {
|
||||
score: Score,
|
||||
pub struct ComparableDoc<T> {
|
||||
feature: T,
|
||||
doc_address: DocAddress,
|
||||
}
|
||||
|
||||
impl PartialOrd for GlobalScoredDoc {
|
||||
fn partial_cmp(&self, other: &GlobalScoredDoc) -> Option<Ordering> {
|
||||
impl<T: PartialOrd> PartialOrd for ComparableDoc<T> {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl Ord for GlobalScoredDoc {
|
||||
impl<T: PartialOrd> Ord for ComparableDoc<T> {
|
||||
#[inline]
|
||||
fn cmp(&self, other: &GlobalScoredDoc) -> Ordering {
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
other
|
||||
.score
|
||||
.partial_cmp(&self.score)
|
||||
.feature
|
||||
.partial_cmp(&self.feature)
|
||||
.unwrap_or_else(|| other.doc_address.cmp(&self.doc_address))
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for GlobalScoredDoc {
|
||||
fn eq(&self, other: &GlobalScoredDoc) -> bool {
|
||||
impl<T: PartialOrd> PartialEq for ComparableDoc<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.cmp(other) == Ordering::Equal
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for GlobalScoredDoc {}
|
||||
impl<T: PartialOrd> Eq for ComparableDoc<T> {}
|
||||
|
||||
/// The Top Collector keeps track of the K documents
|
||||
/// with the best scores.
|
||||
/// sorted by type `T`.
|
||||
///
|
||||
/// The implementation is based on a `BinaryHeap`.
|
||||
/// The theorical complexity for collecting the top `K` out of `n` documents
|
||||
/// is `O(n log K)`.
|
||||
///
|
||||
/// ```rust
|
||||
/// #[macro_use]
|
||||
/// extern crate tantivy;
|
||||
/// use tantivy::schema::{SchemaBuilder, TEXT};
|
||||
/// use tantivy::{Index, Result, DocId, Score};
|
||||
/// use tantivy::collector::TopCollector;
|
||||
/// use tantivy::query::QueryParser;
|
||||
///
|
||||
/// # fn main() { example().unwrap(); }
|
||||
/// fn example() -> Result<()> {
|
||||
/// let mut schema_builder = SchemaBuilder::new();
|
||||
/// let title = schema_builder.add_text_field("title", TEXT);
|
||||
/// let schema = schema_builder.build();
|
||||
/// let index = Index::create_in_ram(schema);
|
||||
/// {
|
||||
/// let mut index_writer = index.writer_with_num_threads(1, 3_000_000)?;
|
||||
/// index_writer.add_document(doc!(
|
||||
/// title => "The Name of the Wind",
|
||||
/// ));
|
||||
/// index_writer.add_document(doc!(
|
||||
/// title => "The Diary of Muadib",
|
||||
/// ));
|
||||
/// index_writer.add_document(doc!(
|
||||
/// title => "A Dairy Cow",
|
||||
/// ));
|
||||
/// index_writer.add_document(doc!(
|
||||
/// title => "The Diary of a Young Girl",
|
||||
/// ));
|
||||
/// index_writer.commit().unwrap();
|
||||
/// }
|
||||
///
|
||||
/// index.load_searchers()?;
|
||||
/// let searcher = index.searcher();
|
||||
///
|
||||
/// {
|
||||
/// let mut top_collector = TopCollector::with_limit(2);
|
||||
/// let query_parser = QueryParser::for_index(&index, vec![title]);
|
||||
/// let query = query_parser.parse_query("diary")?;
|
||||
/// searcher.search(&*query, &mut top_collector).unwrap();
|
||||
///
|
||||
/// let score_docs: Vec<(Score, DocId)> = top_collector
|
||||
/// .score_docs()
|
||||
/// .into_iter()
|
||||
/// .map(|(score, doc_address)| (score, doc_address.doc()))
|
||||
/// .collect();
|
||||
///
|
||||
/// assert_eq!(score_docs, vec![(0.7261542, 1), (0.6099695, 3)]);
|
||||
/// }
|
||||
///
|
||||
/// Ok(())
|
||||
/// }
|
||||
/// ```
|
||||
pub struct TopCollector {
|
||||
pub struct TopCollector<T> {
|
||||
limit: usize,
|
||||
heap: BinaryHeap<GlobalScoredDoc>,
|
||||
heap: BinaryHeap<ComparableDoc<T>>,
|
||||
segment_id: u32,
|
||||
}
|
||||
|
||||
impl TopCollector {
|
||||
impl<T: PartialOrd + Clone> TopCollector<T> {
|
||||
/// Creates a top collector, with a number of documents equal to "limit".
|
||||
///
|
||||
/// # Panics
|
||||
/// The method panics if limit is 0
|
||||
pub fn with_limit(limit: usize) -> TopCollector {
|
||||
pub fn with_limit(limit: usize) -> TopCollector<T> {
|
||||
if limit < 1 {
|
||||
panic!("Limit must be strictly greater than 0.");
|
||||
}
|
||||
@@ -125,22 +71,27 @@ impl TopCollector {
|
||||
/// Calling this method triggers the sort.
|
||||
/// The result of the sort is not cached.
|
||||
pub fn docs(&self) -> Vec<DocAddress> {
|
||||
self.score_docs()
|
||||
self.top_docs()
|
||||
.into_iter()
|
||||
.map(|score_doc| score_doc.1)
|
||||
.map(|(_feature, doc)| doc)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Returns K best ScoredDocument sorted in decreasing order.
|
||||
/// Returns K best FeatureDocuments sorted in decreasing order.
|
||||
///
|
||||
/// Calling this method triggers the sort.
|
||||
/// The result of the sort is not cached.
|
||||
pub fn score_docs(&self) -> Vec<(Score, DocAddress)> {
|
||||
let mut scored_docs: Vec<GlobalScoredDoc> = self.heap.iter().cloned().collect();
|
||||
scored_docs.sort();
|
||||
scored_docs
|
||||
pub fn top_docs(&self) -> Vec<(T, DocAddress)> {
|
||||
let mut feature_docs: Vec<ComparableDoc<T>> = self.heap.iter().cloned().collect();
|
||||
feature_docs.sort();
|
||||
feature_docs
|
||||
.into_iter()
|
||||
.map(|GlobalScoredDoc { score, doc_address }| (score, doc_address))
|
||||
.map(
|
||||
|ComparableDoc {
|
||||
feature,
|
||||
doc_address,
|
||||
}| (feature, doc_address),
|
||||
)
|
||||
.collect()
|
||||
}
|
||||
|
||||
@@ -150,48 +101,47 @@ impl TopCollector {
|
||||
pub fn at_capacity(&self) -> bool {
|
||||
self.heap.len() >= self.limit
|
||||
}
|
||||
}
|
||||
|
||||
impl Collector for TopCollector {
|
||||
fn set_segment(&mut self, segment_id: SegmentLocalId, _: &SegmentReader) -> Result<()> {
|
||||
/// Sets the segment local ID for the collector
|
||||
pub fn set_segment_id(&mut self, segment_id: SegmentLocalId) {
|
||||
self.segment_id = segment_id;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn collect(&mut self, doc: DocId, score: Score) {
|
||||
/// Collects a document scored by the given feature
|
||||
///
|
||||
/// It collects documents until it has reached the max capacity. Once it reaches capacity, it
|
||||
/// will compare the lowest scoring item with the given one and keep whichever is greater.
|
||||
pub fn collect(&mut self, doc: DocId, feature: T) {
|
||||
if self.at_capacity() {
|
||||
// It's ok to unwrap as long as a limit of 0 is forbidden.
|
||||
let limit_doc: GlobalScoredDoc = *self.heap
|
||||
let limit_doc: ComparableDoc<T> = self
|
||||
.heap
|
||||
.peek()
|
||||
.expect("Top collector with size 0 is forbidden");
|
||||
if limit_doc.score < score {
|
||||
let mut mut_head = self.heap
|
||||
.expect("Top collector with size 0 is forbidden")
|
||||
.clone();
|
||||
if limit_doc.feature < feature {
|
||||
let mut mut_head = self
|
||||
.heap
|
||||
.peek_mut()
|
||||
.expect("Top collector with size 0 is forbidden");
|
||||
mut_head.score = score;
|
||||
mut_head.feature = feature;
|
||||
mut_head.doc_address = DocAddress(self.segment_id, doc);
|
||||
}
|
||||
} else {
|
||||
let wrapped_doc = GlobalScoredDoc {
|
||||
score,
|
||||
let wrapped_doc = ComparableDoc {
|
||||
feature,
|
||||
doc_address: DocAddress(self.segment_id, doc),
|
||||
};
|
||||
self.heap.push(wrapped_doc);
|
||||
}
|
||||
}
|
||||
|
||||
fn requires_scoring(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use super::*;
|
||||
use collector::Collector;
|
||||
use DocId;
|
||||
use Score;
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_top_collector_not_at_capacity() {
|
||||
@@ -201,7 +151,7 @@ mod tests {
|
||||
top_collector.collect(5, 0.3);
|
||||
assert!(!top_collector.at_capacity());
|
||||
let score_docs: Vec<(Score, DocId)> = top_collector
|
||||
.score_docs()
|
||||
.top_docs()
|
||||
.into_iter()
|
||||
.map(|(score, doc_address)| (score, doc_address.doc()))
|
||||
.collect();
|
||||
@@ -219,7 +169,7 @@ mod tests {
|
||||
assert!(top_collector.at_capacity());
|
||||
{
|
||||
let score_docs: Vec<(Score, DocId)> = top_collector
|
||||
.score_docs()
|
||||
.top_docs()
|
||||
.into_iter()
|
||||
.map(|(score, doc_address)| (score, doc_address.doc()))
|
||||
.collect();
|
||||
@@ -238,7 +188,7 @@ mod tests {
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_top_0() {
|
||||
TopCollector::with_limit(0);
|
||||
let _collector: TopCollector<Score> = TopCollector::with_limit(0);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
263
src/collector/top_field_collector.rs
Normal file
263
src/collector/top_field_collector.rs
Normal file
@@ -0,0 +1,263 @@
|
||||
use collector::top_collector::TopCollector;
|
||||
use DocAddress;
|
||||
use DocId;
|
||||
use fastfield::FastFieldReader;
|
||||
use fastfield::FastValue;
|
||||
use Result;
|
||||
use Score;
|
||||
use SegmentReader;
|
||||
use super::Collector;
|
||||
use schema::Field;
|
||||
|
||||
/// The Top Field Collector keeps track of the K documents
|
||||
/// sorted by a fast field in the index
|
||||
///
|
||||
/// The implementation is based on a `BinaryHeap`.
|
||||
/// The theorical complexity for collecting the top `K` out of `n` documents
|
||||
/// is `O(n log K)`.
|
||||
///
|
||||
/// ```rust
|
||||
/// #[macro_use]
|
||||
/// extern crate tantivy;
|
||||
/// use tantivy::schema::{SchemaBuilder, TEXT, FAST};
|
||||
/// use tantivy::{Index, Result, DocId};
|
||||
/// use tantivy::collector::TopFieldCollector;
|
||||
/// use tantivy::query::QueryParser;
|
||||
///
|
||||
/// # fn main() { example().unwrap(); }
|
||||
/// fn example() -> Result<()> {
|
||||
/// let mut schema_builder = SchemaBuilder::new();
|
||||
/// let title = schema_builder.add_text_field("title", TEXT);
|
||||
/// let rating = schema_builder.add_u64_field("rating", FAST);
|
||||
/// let schema = schema_builder.build();
|
||||
/// let index = Index::create_in_ram(schema);
|
||||
/// {
|
||||
/// let mut index_writer = index.writer_with_num_threads(1, 3_000_000)?;
|
||||
/// index_writer.add_document(doc!(
|
||||
/// title => "The Name of the Wind",
|
||||
/// rating => 92u64,
|
||||
/// ));
|
||||
/// index_writer.add_document(doc!(
|
||||
/// title => "The Diary of Muadib",
|
||||
/// rating => 97u64,
|
||||
/// ));
|
||||
/// index_writer.add_document(doc!(
|
||||
/// title => "A Dairy Cow",
|
||||
/// rating => 63u64,
|
||||
/// ));
|
||||
/// index_writer.add_document(doc!(
|
||||
/// title => "The Diary of a Young Girl",
|
||||
/// rating => 80u64,
|
||||
/// ));
|
||||
/// index_writer.commit().unwrap();
|
||||
/// }
|
||||
///
|
||||
/// index.load_searchers()?;
|
||||
/// let searcher = index.searcher();
|
||||
///
|
||||
/// {
|
||||
/// let mut top_collector = TopFieldCollector::with_limit(rating, 2);
|
||||
/// let query_parser = QueryParser::for_index(&index, vec![title]);
|
||||
/// let query = query_parser.parse_query("diary")?;
|
||||
/// searcher.search(&*query, &mut top_collector).unwrap();
|
||||
///
|
||||
/// let score_docs: Vec<(u64, DocId)> = top_collector
|
||||
/// .top_docs()
|
||||
/// .into_iter()
|
||||
/// .map(|(field, doc_address)| (field, doc_address.doc()))
|
||||
/// .collect();
|
||||
///
|
||||
/// assert_eq!(score_docs, vec![(97u64, 1), (80, 3)]);
|
||||
/// }
|
||||
///
|
||||
/// Ok(())
|
||||
/// }
|
||||
/// ```
|
||||
pub struct TopFieldCollector<T: FastValue> {
|
||||
field: Field,
|
||||
collector: TopCollector<T>,
|
||||
fast_field: Option<FastFieldReader<T>>,
|
||||
}
|
||||
|
||||
impl<T: FastValue + PartialOrd + Clone> TopFieldCollector<T> {
|
||||
/// Creates a top field collector, with a number of documents equal to "limit".
|
||||
///
|
||||
/// The given field name must be a fast field, otherwise the collector have an error while
|
||||
/// collecting results.
|
||||
///
|
||||
/// # Panics
|
||||
/// The method panics if limit is 0
|
||||
pub fn with_limit(field: Field, limit: usize) -> Self {
|
||||
TopFieldCollector {
|
||||
field,
|
||||
collector: TopCollector::with_limit(limit),
|
||||
fast_field: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns K best documents sorted the given field name in decreasing order.
|
||||
///
|
||||
/// Calling this method triggers the sort.
|
||||
/// The result of the sort is not cached.
|
||||
pub fn docs(&self) -> Vec<DocAddress> {
|
||||
self.collector.docs()
|
||||
}
|
||||
|
||||
/// Returns K best FieldDocuments sorted in decreasing order.
|
||||
///
|
||||
/// Calling this method triggers the sort.
|
||||
/// The result of the sort is not cached.
|
||||
pub fn top_docs(&self) -> Vec<(T, DocAddress)> {
|
||||
self.collector.top_docs()
|
||||
}
|
||||
|
||||
/// Return true iff at least K documents have gone through
|
||||
/// the collector.
|
||||
#[inline]
|
||||
pub fn at_capacity(&self) -> bool {
|
||||
self.collector.at_capacity()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FastValue + PartialOrd + Clone> Collector for TopFieldCollector<T> {
|
||||
fn set_segment(&mut self, segment_id: u32, segment: &SegmentReader) -> Result<()> {
|
||||
self.collector.set_segment_id(segment_id);
|
||||
self.fast_field = Some(segment.fast_field_reader(self.field)?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn collect(&mut self, doc: DocId, _score: Score) {
|
||||
let field_value = self
|
||||
.fast_field
|
||||
.as_ref()
|
||||
.expect("collect() was called before set_segment. This should never happen.")
|
||||
.get(doc);
|
||||
self.collector.collect(doc, field_value);
|
||||
}
|
||||
|
||||
fn requires_scoring(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use Index;
|
||||
use IndexWriter;
|
||||
use TantivyError;
|
||||
use query::Query;
|
||||
use query::QueryParser;
|
||||
use schema::{FAST, SchemaBuilder, TEXT};
|
||||
use schema::Field;
|
||||
use schema::IntOptions;
|
||||
use schema::Schema;
|
||||
use super::*;
|
||||
|
||||
const TITLE: &str = "title";
|
||||
const SIZE: &str = "size";
|
||||
|
||||
#[test]
|
||||
fn test_top_collector_not_at_capacity() {
|
||||
let mut schema_builder = SchemaBuilder::new();
|
||||
let title = schema_builder.add_text_field(TITLE, TEXT);
|
||||
let size = schema_builder.add_u64_field(SIZE, FAST);
|
||||
let schema = schema_builder.build();
|
||||
let (index, query) = index("beer", title, schema, |index_writer| {
|
||||
index_writer.add_document(doc!(
|
||||
title => "bottle of beer",
|
||||
size => 12u64,
|
||||
));
|
||||
index_writer.add_document(doc!(
|
||||
title => "growler of beer",
|
||||
size => 64u64,
|
||||
));
|
||||
index_writer.add_document(doc!(
|
||||
title => "pint of beer",
|
||||
size => 16u64,
|
||||
));
|
||||
});
|
||||
let searcher = index.searcher();
|
||||
|
||||
let mut top_collector = TopFieldCollector::with_limit(size, 4);
|
||||
searcher.search(&*query, &mut top_collector).unwrap();
|
||||
assert!(!top_collector.at_capacity());
|
||||
|
||||
let score_docs: Vec<(u64, DocId)> = top_collector
|
||||
.top_docs()
|
||||
.into_iter()
|
||||
.map(|(field, doc_address)| (field, doc_address.doc()))
|
||||
.collect();
|
||||
assert_eq!(score_docs, vec![(64, 1), (16, 2), (12, 0)]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_field_does_not_exist() {
|
||||
let mut schema_builder = SchemaBuilder::new();
|
||||
let title = schema_builder.add_text_field(TITLE, TEXT);
|
||||
let size = schema_builder.add_u64_field(SIZE, FAST);
|
||||
let schema = schema_builder.build();
|
||||
let (index, _) = index("beer", title, schema, |index_writer| {
|
||||
index_writer.add_document(doc!(
|
||||
title => "bottle of beer",
|
||||
size => 12u64,
|
||||
));
|
||||
});
|
||||
let searcher = index.searcher();
|
||||
let segment = searcher.segment_reader(0);
|
||||
let mut top_collector: TopFieldCollector<u64> = TopFieldCollector::with_limit(Field(2), 4);
|
||||
let _ = top_collector.set_segment(0, segment);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_field_not_fast_field() {
|
||||
let mut schema_builder = SchemaBuilder::new();
|
||||
let title = schema_builder.add_text_field(TITLE, TEXT);
|
||||
let size = schema_builder.add_u64_field(SIZE, IntOptions::default());
|
||||
let schema = schema_builder.build();
|
||||
let (index, _) = index("beer", title, schema, |index_writer| {
|
||||
index_writer.add_document(doc!(
|
||||
title => "bottle of beer",
|
||||
size => 12u64,
|
||||
));
|
||||
});
|
||||
let searcher = index.searcher();
|
||||
let segment = searcher.segment_reader(0);
|
||||
let mut top_collector: TopFieldCollector<u64> = TopFieldCollector::with_limit(size, 4);
|
||||
assert_matches!(
|
||||
top_collector.set_segment(0, segment),
|
||||
Err(TantivyError::FastFieldError(_))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_collect_before_set_segment() {
|
||||
let mut top_collector: TopFieldCollector<u64> = TopFieldCollector::with_limit(Field(0), 4);
|
||||
top_collector.collect(0, 0f32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_top_0() {
|
||||
let _: TopFieldCollector<u64> = TopFieldCollector::with_limit(Field(0), 0);
|
||||
}
|
||||
|
||||
fn index(
|
||||
query: &str,
|
||||
query_field: Field,
|
||||
schema: Schema,
|
||||
mut doc_adder: impl FnMut(&mut IndexWriter) -> (),
|
||||
) -> (Index, Box<Query>) {
|
||||
let index = Index::create_in_ram(schema);
|
||||
|
||||
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
|
||||
doc_adder(&mut index_writer);
|
||||
index_writer.commit().unwrap();
|
||||
index.load_searchers().unwrap();
|
||||
|
||||
let query_parser = QueryParser::for_index(&index, vec![query_field]);
|
||||
let query = query_parser.parse_query(query).unwrap();
|
||||
(index, query)
|
||||
}
|
||||
}
|
||||
187
src/collector/top_score_collector.rs
Normal file
187
src/collector/top_score_collector.rs
Normal file
@@ -0,0 +1,187 @@
|
||||
use collector::top_collector::TopCollector;
|
||||
use DocAddress;
|
||||
use DocId;
|
||||
use Result;
|
||||
use Score;
|
||||
use SegmentLocalId;
|
||||
use SegmentReader;
|
||||
use super::Collector;
|
||||
|
||||
/// The Top Score Collector keeps track of the K documents
|
||||
/// sorted by their score.
|
||||
///
|
||||
/// The implementation is based on a `BinaryHeap`.
|
||||
/// The theorical complexity for collecting the top `K` out of `n` documents
|
||||
/// is `O(n log K)`.
|
||||
///
|
||||
/// ```rust
|
||||
/// #[macro_use]
|
||||
/// extern crate tantivy;
|
||||
/// use tantivy::schema::{SchemaBuilder, TEXT};
|
||||
/// use tantivy::{Index, Result, DocId, Score};
|
||||
/// use tantivy::collector::TopScoreCollector;
|
||||
/// use tantivy::query::QueryParser;
|
||||
///
|
||||
/// # fn main() { example().unwrap(); }
|
||||
/// fn example() -> Result<()> {
|
||||
/// let mut schema_builder = SchemaBuilder::new();
|
||||
/// let title = schema_builder.add_text_field("title", TEXT);
|
||||
/// let schema = schema_builder.build();
|
||||
/// let index = Index::create_in_ram(schema);
|
||||
/// {
|
||||
/// let mut index_writer = index.writer_with_num_threads(1, 3_000_000)?;
|
||||
/// index_writer.add_document(doc!(
|
||||
/// title => "The Name of the Wind",
|
||||
/// ));
|
||||
/// index_writer.add_document(doc!(
|
||||
/// title => "The Diary of Muadib",
|
||||
/// ));
|
||||
/// index_writer.add_document(doc!(
|
||||
/// title => "A Dairy Cow",
|
||||
/// ));
|
||||
/// index_writer.add_document(doc!(
|
||||
/// title => "The Diary of a Young Girl",
|
||||
/// ));
|
||||
/// index_writer.commit().unwrap();
|
||||
/// }
|
||||
///
|
||||
/// index.load_searchers()?;
|
||||
/// let searcher = index.searcher();
|
||||
///
|
||||
/// {
|
||||
/// let mut top_collector = TopScoreCollector::with_limit(2);
|
||||
/// let query_parser = QueryParser::for_index(&index, vec![title]);
|
||||
/// let query = query_parser.parse_query("diary")?;
|
||||
/// searcher.search(&*query, &mut top_collector).unwrap();
|
||||
///
|
||||
/// let score_docs: Vec<(Score, DocId)> = top_collector
|
||||
/// .top_docs()
|
||||
/// .into_iter()
|
||||
/// .map(|(score, doc_address)| (score, doc_address.doc()))
|
||||
/// .collect();
|
||||
///
|
||||
/// assert_eq!(score_docs, vec![(0.7261542, 1), (0.6099695, 3)]);
|
||||
/// }
|
||||
///
|
||||
/// Ok(())
|
||||
/// }
|
||||
/// ```
|
||||
pub struct TopScoreCollector {
|
||||
collector: TopCollector<Score>,
|
||||
}
|
||||
|
||||
impl TopScoreCollector {
|
||||
/// Creates a top score collector, with a number of documents equal to "limit".
|
||||
///
|
||||
/// # Panics
|
||||
/// The method panics if limit is 0
|
||||
pub fn with_limit(limit: usize) -> TopScoreCollector {
|
||||
TopScoreCollector {
|
||||
collector: TopCollector::with_limit(limit),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns K best scored documents sorted in decreasing order.
|
||||
///
|
||||
/// Calling this method triggers the sort.
|
||||
/// The result of the sort is not cached.
|
||||
pub fn docs(&self) -> Vec<DocAddress> {
|
||||
self.collector.docs()
|
||||
}
|
||||
|
||||
/// Returns K best ScoredDocuments sorted in decreasing order.
|
||||
///
|
||||
/// Calling this method triggers the sort.
|
||||
/// The result of the sort is not cached.
|
||||
pub fn top_docs(&self) -> Vec<(Score, DocAddress)> {
|
||||
self.collector.top_docs()
|
||||
}
|
||||
|
||||
/// Returns K best ScoredDocuments sorted in decreasing order.
|
||||
///
|
||||
/// Calling this method triggers the sort.
|
||||
/// The result of the sort is not cached.
|
||||
#[deprecated]
|
||||
pub fn score_docs(&self) -> Vec<(Score, DocAddress)> {
|
||||
self.collector.top_docs()
|
||||
}
|
||||
|
||||
/// Return true iff at least K documents have gone through
|
||||
/// the collector.
|
||||
#[inline]
|
||||
pub fn at_capacity(&self) -> bool {
|
||||
self.collector.at_capacity()
|
||||
}
|
||||
}
|
||||
|
||||
impl Collector for TopScoreCollector {
|
||||
fn set_segment(&mut self, segment_id: SegmentLocalId, _: &SegmentReader) -> Result<()> {
|
||||
self.collector.set_segment_id(segment_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn collect(&mut self, doc: DocId, score: Score) {
|
||||
self.collector.collect(doc, score);
|
||||
}
|
||||
|
||||
fn requires_scoring(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use collector::Collector;
|
||||
use DocId;
|
||||
use Score;
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_top_collector_not_at_capacity() {
|
||||
let mut top_collector = TopScoreCollector::with_limit(4);
|
||||
top_collector.collect(1, 0.8);
|
||||
top_collector.collect(3, 0.2);
|
||||
top_collector.collect(5, 0.3);
|
||||
assert!(!top_collector.at_capacity());
|
||||
let score_docs: Vec<(Score, DocId)> = top_collector
|
||||
.top_docs()
|
||||
.into_iter()
|
||||
.map(|(score, doc_address)| (score, doc_address.doc()))
|
||||
.collect();
|
||||
assert_eq!(score_docs, vec![(0.8, 1), (0.3, 5), (0.2, 3)]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_top_collector_at_capacity() {
|
||||
let mut top_collector = TopScoreCollector::with_limit(4);
|
||||
top_collector.collect(1, 0.8);
|
||||
top_collector.collect(3, 0.2);
|
||||
top_collector.collect(5, 0.3);
|
||||
top_collector.collect(7, 0.9);
|
||||
top_collector.collect(9, -0.2);
|
||||
assert!(top_collector.at_capacity());
|
||||
{
|
||||
let score_docs: Vec<(Score, DocId)> = top_collector
|
||||
.top_docs()
|
||||
.into_iter()
|
||||
.map(|(score, doc_address)| (score, doc_address.doc()))
|
||||
.collect();
|
||||
assert_eq!(score_docs, vec![(0.9, 7), (0.8, 1), (0.3, 5), (0.2, 3)]);
|
||||
}
|
||||
{
|
||||
let docs: Vec<DocId> = top_collector
|
||||
.docs()
|
||||
.into_iter()
|
||||
.map(|doc_address| doc_address.doc())
|
||||
.collect();
|
||||
assert_eq!(docs, vec![7, 1, 5, 3]);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_top_0() {
|
||||
TopScoreCollector::with_limit(0);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -153,7 +153,7 @@ mod test {
|
||||
|
||||
let fuzzy_query = FuzzyTermQuery::new(term, 1, true);
|
||||
searcher.search(&fuzzy_query, &mut collector).unwrap();
|
||||
let scored_docs = collector.score_docs();
|
||||
let scored_docs = collector.top_docs();
|
||||
assert_eq!(scored_docs.len(), 1, "Expected only 1 document");
|
||||
let (score, _) = scored_docs[0];
|
||||
assert_nearly_equals(1f32, score);
|
||||
|
||||
@@ -123,7 +123,7 @@ mod test {
|
||||
let mut collector = TopCollector::with_limit(2);
|
||||
let regex_query = RegexQuery::new("jap[ao]n".to_string(), country_field);
|
||||
searcher.search(®ex_query, &mut collector).unwrap();
|
||||
let scored_docs = collector.score_docs();
|
||||
let scored_docs = collector.top_docs();
|
||||
assert_eq!(scored_docs.len(), 1, "Expected only 1 document");
|
||||
let (score, _) = scored_docs[0];
|
||||
assert_nearly_equals(1f32, score);
|
||||
@@ -132,7 +132,7 @@ mod test {
|
||||
let mut collector = TopCollector::with_limit(2);
|
||||
let regex_query = RegexQuery::new("jap[A-Z]n".to_string(), country_field);
|
||||
searcher.search(®ex_query, &mut collector).unwrap();
|
||||
let scored_docs = collector.score_docs();
|
||||
let scored_docs = collector.top_docs();
|
||||
assert_eq!(scored_docs.len(), 0, "Expected ZERO document");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,7 +72,7 @@ mod tests {
|
||||
let term = Term::from_field_text(left_field, "left2");
|
||||
let term_query = TermQuery::new(term, IndexRecordOption::WithFreqs);
|
||||
searcher.search(&term_query, &mut collector).unwrap();
|
||||
let scored_docs = collector.score_docs();
|
||||
let scored_docs = collector.top_docs();
|
||||
assert_eq!(scored_docs.len(), 1);
|
||||
let (score, _) = scored_docs[0];
|
||||
assert_nearly_equals(0.77802235, score);
|
||||
@@ -82,7 +82,7 @@ mod tests {
|
||||
let term = Term::from_field_text(left_field, "left1");
|
||||
let term_query = TermQuery::new(term, IndexRecordOption::WithFreqs);
|
||||
searcher.search(&term_query, &mut collector).unwrap();
|
||||
let scored_docs = collector.score_docs();
|
||||
let scored_docs = collector.top_docs();
|
||||
assert_eq!(scored_docs.len(), 2);
|
||||
let (score1, _) = scored_docs[0];
|
||||
assert_nearly_equals(0.27101856, score1);
|
||||
@@ -94,7 +94,7 @@ mod tests {
|
||||
let query = query_parser.parse_query("left:left2 left:left1").unwrap();
|
||||
let mut collector = TopCollector::with_limit(2);
|
||||
searcher.search(&*query, &mut collector).unwrap();
|
||||
let scored_docs = collector.score_docs();
|
||||
let scored_docs = collector.top_docs();
|
||||
assert_eq!(scored_docs.len(), 2);
|
||||
let (score1, _) = scored_docs[0];
|
||||
assert_nearly_equals(0.9153879, score1);
|
||||
|
||||
Reference in New Issue
Block a user