diff --git a/src/collector/filter_collector_wrapper.rs b/src/collector/filter_collector_wrapper.rs
new file mode 100644
index 000000000..ddcf68cb7
--- /dev/null
+++ b/src/collector/filter_collector_wrapper.rs
@@ -0,0 +1,165 @@
+//! Filtering collector wrapper.
+//!
+//! `FilterCollector` wraps another collector and only forwards to it the
+//! documents whose value in a given `u64` fast field satisfies a predicate.
+
+use std::marker::PhantomData;
+
+use crate::collector::{Collector, SegmentCollector};
+use crate::fastfield::FastFieldReader;
+use crate::schema::Field;
+use crate::{Score, SegmentReader, TantivyError};
+
+/// The `FilterCollector` wraps another collector and only forwards to it the
+/// documents whose value in a given `u64` fast field satisfies a predicate.
+///
+/// For each collected document, the value of `field` is read from the
+/// segment's fast field reader and passed to the predicate; the document is
+/// handed to the wrapped collector only when the predicate returns `true`.
+/// The fruit is the fruit of the wrapped collector.
+///
+/// The predicate is borrowed for `'static`, e.g. a reference to a
+/// non-capturing closure as in the example below.
+///
+/// ```rust
+/// use tantivy::collector::{TopDocs, FilterCollector};
+/// use tantivy::query::QueryParser;
+/// use tantivy::schema::{Schema, TEXT, INDEXED, FAST};
+/// use tantivy::{doc, DocAddress, Index};
+///
+/// let mut schema_builder = Schema::builder();
+/// let title = schema_builder.add_text_field("title", TEXT);
+/// let price = schema_builder.add_u64_field("price", INDEXED | FAST);
+/// let schema = schema_builder.build();
+/// let index = Index::create_in_ram(schema);
+///
+/// let mut index_writer = index.writer_with_num_threads(1, 10_000_000).unwrap();
+/// index_writer.add_document(doc!(title => "The Name of the Wind", price => 30_200u64));
+/// index_writer.add_document(doc!(title => "The Diary of Muadib", price => 29_240u64));
+/// index_writer.add_document(doc!(title => "A Dairy Cow", price => 21_240u64));
+/// index_writer.add_document(doc!(title => "The Diary of a Young Girl", price => 20_120u64));
+/// assert!(index_writer.commit().is_ok());
+///
+/// let reader = index.reader().unwrap();
+/// let searcher = reader.searcher();
+///
+/// let query_parser = QueryParser::for_index(&index, vec![title]);
+/// let query = query_parser.parse_query("diary").unwrap();
+/// let no_filter_collector = FilterCollector::new(price, &|_value| true, TopDocs::with_limit(2));
+/// let top_docs = searcher.search(&query, &no_filter_collector).unwrap();
+///
+/// assert_eq!(top_docs.len(), 2);
+/// assert_eq!(top_docs[0].1, DocAddress(0, 1));
+/// assert_eq!(top_docs[1].1, DocAddress(0, 3));
+///
+/// let filter_all_collector = FilterCollector::new(price, &|_value| false, TopDocs::with_limit(2));
+/// let filtered_top_docs = searcher.search(&query, &filter_all_collector).unwrap();
+///
+/// assert_eq!(filtered_top_docs.len(), 0);
+/// ```
+pub struct FilterCollector<TCollector, TSegmentCollector> {
+    field: Field,
+    collector: TCollector,
+    predicate: &'static (dyn Fn(u64) -> bool + Send + Sync),
+    phantom: PhantomData<TSegmentCollector>,
+}
+
+impl<TCollector, TSegmentCollector> FilterCollector<TCollector, TSegmentCollector>
+where
+    TCollector: Collector<Child = TSegmentCollector> + Send + Sync,
+    TSegmentCollector: 'static + SegmentCollector + Send + Sync,
+{
+    /// Creates a `FilterCollector` that forwards to `collector` only the
+    /// documents for which `predicate`, applied to the document's value in
+    /// the `u64` fast field `field`, returns `true`.
+    pub fn new(
+        field: Field,
+        predicate: &'static (dyn Fn(u64) -> bool + Send + Sync),
+        collector: TCollector,
+    ) -> FilterCollector<TCollector, TSegmentCollector> {
+        FilterCollector {
+            field,
+            predicate,
+            collector,
+            phantom: PhantomData,
+        }
+    }
+}
+
+impl<TCollector, TSegmentCollector> Collector for FilterCollector<TCollector, TSegmentCollector>
+where
+    TSegmentCollector: 'static + SegmentCollector + Send + Sync,
+    TCollector: Collector<Child = TSegmentCollector> + Send + Sync,
+{
+    // Filtering does not change the result type: the fruit is the one
+    // produced by the wrapped collector.
+    type Fruit = TCollector::Fruit;
+
+    type Child = FilterSegmentCollector<TSegmentCollector>;
+
+    fn for_segment(
+        &self,
+        segment_local_id: u32,
+        segment_reader: &SegmentReader,
+    ) -> crate::Result<FilterSegmentCollector<TSegmentCollector>> {
+        let fast_field_reader = segment_reader
+            .fast_fields()
+            .u64(self.field)
+            .ok_or_else(|| {
+                let field_name = segment_reader.schema().get_field_name(self.field);
+                TantivyError::SchemaError(format!(
+                    "Field {:?} is not a u64 fast field.",
+                    field_name
+                ))
+            })?;
+        // Propagate the wrapped collector's own error with `?` rather than
+        // replacing it with a generic message that hides the cause.
+        let segment_collector = self
+            .collector
+            .for_segment(segment_local_id, segment_reader)?;
+        Ok(FilterSegmentCollector {
+            fast_field_reader,
+            segment_collector,
+            predicate: self.predicate,
+        })
+    }
+
+    fn requires_scoring(&self) -> bool {
+        self.collector.requires_scoring()
+    }
+
+    fn merge_fruits(
+        &self,
+        segment_fruits: Vec<<TSegmentCollector as SegmentCollector>::Fruit>,
+    ) -> crate::Result<TCollector::Fruit> {
+        self.collector.merge_fruits(segment_fruits)
+    }
+}
+
+/// Segment-level collector created by `FilterCollector`.
+///
+/// It forwards a document to the wrapped segment collector only when the
+/// predicate accepts the document's fast field value.
+pub struct FilterSegmentCollector<TSegmentCollector> {
+    fast_field_reader: FastFieldReader<u64>,
+    segment_collector: TSegmentCollector,
+    predicate: &'static (dyn Fn(u64) -> bool + Send + Sync),
+}
+
+impl<TSegmentCollector> SegmentCollector for FilterSegmentCollector<TSegmentCollector>
+where
+    TSegmentCollector: 'static + SegmentCollector + Send + Sync,
+{
+    type Fruit = TSegmentCollector::Fruit;
+
+    fn collect(&mut self, doc: u32, score: Score) {
+        let value = self.fast_field_reader.get(doc);
+        if (self.predicate)(value) {
+            self.segment_collector.collect(doc, score)
+        }
+    }
+
+    fn harvest(self) -> <TSegmentCollector as SegmentCollector>::Fruit {
+        self.segment_collector.harvest()
+    }
+}
diff --git a/src/collector/mod.rs b/src/collector/mod.rs
index 2533c1950..73fc432b7 100644
--- a/src/collector/mod.rs
+++ b/src/collector/mod.rs
@@ -114,6 +114,9 @@ use crate::query::Weight;
 mod docset_collector;
 pub use self::docset_collector::DocSetCollector;
 
+mod filter_collector_wrapper;
+pub use self::filter_collector_wrapper::FilterCollector;
+
 /// `Fruit` is the type for the result of our collection.
 /// e.g. `usize` for the `Count` collector.
 pub trait Fruit: Send + downcast_rs::Downcast {}