From ac704f2f226018916e8d5647c7f602641449536d Mon Sep 17 00:00:00 2001 From: barrotsteindev Date: Tue, 8 Dec 2020 14:36:52 +0200 Subject: [PATCH] WIP generic filter collector --- src/collector/filter_collector_wrapper.rs | 123 ++++++++++++++++------ src/collector/tests.rs | 52 +++++++++ 2 files changed, 140 insertions(+), 35 deletions(-) diff --git a/src/collector/filter_collector_wrapper.rs b/src/collector/filter_collector_wrapper.rs index 0e3e33927..276823031 100644 --- a/src/collector/filter_collector_wrapper.rs +++ b/src/collector/filter_collector_wrapper.rs @@ -9,9 +9,11 @@ // --- // Importing tantivy... -use crate::collector::{Collector, SegmentCollector}; -use crate::fastfield::FastFieldReader; +use std::marker::PhantomData; + +use crate::fastfield::{FastValue, FastFieldReader}; use crate::schema::Field; +use crate::collector::{Collector, SegmentCollector}; use crate::{Score, SegmentReader, TantivyError}; /// The `FilterCollector` collector filters docs using a u64 fast field value and a predicate. @@ -20,20 +22,23 @@ use crate::{Score, SegmentReader, TantivyError}; /// ```rust /// use tantivy::collector::{TopDocs, FilterCollector}; /// use tantivy::query::QueryParser; -/// use tantivy::schema::{Schema, TEXT, INDEXED, FAST}; +/// use tantivy::schema::{Schema, FAST, TEXT}; +/// use tantivy::DateTime; +/// use std::str::FromStr; /// use tantivy::{doc, DocAddress, Index}; /// /// let mut schema_builder = Schema::builder(); /// let title = schema_builder.add_text_field("title", TEXT); -/// let price = schema_builder.add_u64_field("price", INDEXED | FAST); +/// let price = schema_builder.add_u64_field("price", FAST); +/// let date = schema_builder.add_date_field("date", FAST); /// let schema = schema_builder.build(); /// let index = Index::create_in_ram(schema); /// /// let mut index_writer = index.writer_with_num_threads(1, 10_000_000).unwrap(); -/// index_writer.add_document(doc!(title => "The Name of the Wind", price => 30_200u64)); -/// index_writer.add_document(doc!(title => "The Diary of Muadib", price => 29_240u64)); -/// index_writer.add_document(doc!(title => "A Dairy Cow", price => 21_240u64)); -/// index_writer.add_document(doc!(title => "The Diary of a Young Girl", price => 20_120u64)); +/// index_writer.add_document(doc!(title => "The Name of the Wind", price => 30_200u64, date => DateTime::from_str("1898-04-09T00:00:00+00:00").unwrap())); +/// index_writer.add_document(doc!(title => "The Diary of Muadib", price => 29_240u64, date => DateTime::from_str("2020-04-09T00:00:00+00:00").unwrap())); +/// index_writer.add_document(doc!(title => "A Dairy Cow", price => 21_240u64, date => DateTime::from_str("2019-04-09T00:00:00+00:00").unwrap())); +/// index_writer.add_document(doc!(title => "The Diary of a Young Girl", price => 20_120u64, date => DateTime::from_str("2018-04-09T00:00:00+00:00").unwrap())); /// assert!(index_writer.commit().is_ok()); /// /// let reader = index.reader().unwrap(); @@ -41,78 +46,123 @@ use crate::{Score, SegmentReader, TantivyError}; /// /// let query_parser = QueryParser::for_index(&index, vec![title]); /// let query = query_parser.parse_query("diary").unwrap(); -/// let no_filter_collector = FilterCollector::new(price, &|value| value > 20_120u64, TopDocs::with_limit(2)); -/// let top_docs = searcher.search(&query, &no_filter_collector).unwrap(); +/// let filter_some_collector = FilterCollector::new(price, &|value: u64| value > 20_120u64, TopDocs::with_limit(2)); +/// let top_docs = searcher.search(&query, &filter_some_collector).unwrap(); /// /// assert_eq!(top_docs.len(), 1); /// assert_eq!(top_docs[0].1, DocAddress(0, 1)); /// -/// let filter_all_collector = FilterCollector::new(price, &|value| value < 5u64, TopDocs::with_limit(2)); +/// let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new(price, &|value| value < 5u64, TopDocs::with_limit(2)); /// let filtered_top_docs = searcher.search(&query, &filter_all_collector).unwrap(); /// /// assert_eq!(filtered_top_docs.len(), 0); +/// +/// fn date_debug(value: DateTime) -> bool { +/// println!("date: {:?}", value); +/// assert_eq!(value, DateTime::from_str("1000-04-09T00:00:00+00:00").unwrap()); +/// (value - DateTime::from_str("2019-04-09T00:00:00+00:00").unwrap()).num_weeks() > 0 +/// } +/// +/// let filter_dates_collector = FilterCollector::new(date, &date_debug, TopDocs::with_limit(2)); +/// let filtered_date_docs = searcher.search(&query, &filter_all_collector).unwrap(); +/// +/// assert_eq!(filtered_date_docs.len(), 5); /// ``` -pub struct FilterCollector +pub struct FilterCollector where TPredicate: 'static, { field: Field, collector: TCollector, predicate: &'static TPredicate, + t_predicate_value: PhantomData, } -impl FilterCollector +impl + FilterCollector where TCollector: Collector + Send + Sync, - TPredicate: Fn(u64) -> bool + Send + Sync, + TPredicate: Fn(TPredicateValue) -> bool + Send + Sync, { /// Create a new FilterCollector. pub fn new( field: Field, predicate: &'static TPredicate, collector: TCollector, - ) -> FilterCollector { + ) -> FilterCollector { FilterCollector { field, predicate, collector, + t_predicate_value: PhantomData, } } } -impl Collector for FilterCollector +impl Collector + for FilterCollector where TCollector: Collector + Send + Sync, - TPredicate: 'static + Fn(u64) -> bool + Send + Sync, + TPredicate: 'static + Fn(TPredicateValue) -> bool + Send + Sync, + TPredicateValue: 'static + FastValue, { // That's the type of our result. // Our standard deviation will be a float. type Fruit = TCollector::Fruit; - type Child = FilterSegmentCollector; + type Child = FilterSegmentCollector; fn for_segment( &self, segment_local_id: u32, segment_reader: &SegmentReader, - ) -> crate::Result> { - let fast_field_reader = segment_reader - .fast_fields() - .u64(self.field) - .ok_or_else(|| { - let field_name = segment_reader.schema().get_field_name(self.field); - TantivyError::SchemaError(format!( - "Field {:?} is not a u64 fast field.", - field_name - )) - })?; + ) -> crate::Result> { + let schema = segment_reader.schema(); + let field_entry = schema.get_field_entry(self.field); + if !field_entry.is_fast() { + return Err(TantivyError::SchemaError(format!( + "Field {:?} is not a fast field.", + field_entry.name() + ))); + } + let schema_type = TPredicateValue::to_type(); + let requested_type = field_entry.field_type().value_type(); + if schema_type != requested_type { + return Err(TantivyError::SchemaError(format!( + "Field {:?} is of type {:?}!={:?}", + field_entry.name(), + schema_type, + requested_type + ))); + } + + let err_closure = || { + let field_name = segment_reader.schema().get_field_name(self.field); + TantivyError::SchemaError(format!( + "Field {:?} is not a u64 fast field.", + field_name + )) + }; + let fast_fields = segment_reader.fast_fields(); + let fast_filed_reader: crate::Result> = match schema_type { + crate::schema::Type::U64 => {fast_fields.u64(self.field).ok_or_else(err_closure)} + crate::schema::Type::I64 => {fast_fields.i64(self.field).ok_or_else(err_closure)} + crate::schema::Type::F64 => {fast_fields.f64(self.field).ok_or_else(err_closure)} + crate::schema::Type::Date => {fast_fields.date(self.field).ok_or_else(err_closure)} + crate::schema::Type::Bytes => {fast_fields.bytes(self.field).ok_or_else(err_closure)} + crate::schema::Type::Str | crate::schema::Type::HierarchicalFacet => {Err(TantivyError::SchemaError(format!("Field {:?} uses an unsupported type", segment_reader.schema().get_field_name(self.field))))} + }; + let segment_collector = self .collector .for_segment(segment_local_id, segment_reader)?; + + let a = fast_filed_reader?; Ok(FilterSegmentCollector { - fast_field_reader, + fast_field_reader: a, segment_collector: segment_collector, predicate: self.predicate, + t_predicate_value: PhantomData, }) } @@ -128,20 +178,23 @@ where } } -pub struct FilterSegmentCollector +pub struct FilterSegmentCollector where TPredicate: 'static, + TPredicateValue: 'static + FastValue, { - fast_field_reader: FastFieldReader, + fast_field_reader: FastFieldReader, segment_collector: TSegmentCollector, predicate: &'static TPredicate, + t_predicate_value: PhantomData, } -impl SegmentCollector - for FilterSegmentCollector +impl SegmentCollector + for FilterSegmentCollector where TSegmentCollector: SegmentCollector, - TPredicate: 'static + Fn(u64) -> bool + Send + Sync, + TPredicate: 'static + Fn(TPredicateValue) -> bool + Send + Sync, + TPredicateValue: 'static + FastValue, { type Fruit = TSegmentCollector::Fruit; diff --git a/src/collector/tests.rs b/src/collector/tests.rs index 144261be8..7cac6dea4 100644 --- a/src/collector/tests.rs +++ b/src/collector/tests.rs @@ -8,6 +8,13 @@ use crate::DocId; use crate::Score; use crate::SegmentLocalId; +use crate::collector::{TopDocs, FilterCollector}; +use crate::query::QueryParser; +use crate::schema::{Schema, FAST, TEXT}; +use crate::DateTime; +use std::str::FromStr; +use crate::{doc, Index}; + pub const TEST_COLLECTOR_WITH_SCORE: TestCollector = TestCollector { compute_score: true, }; @@ -16,6 +23,51 @@ pub const TEST_COLLECTOR_WITHOUT_SCORE: TestCollector = TestCollector { compute_score: true, }; +#[test] +pub fn test_filter_collector() { + + let mut schema_builder = Schema::builder(); + let title = schema_builder.add_text_field("title", TEXT); + let price = schema_builder.add_u64_field("price", FAST); + let date = schema_builder.add_date_field("date", FAST); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + + let mut index_writer = index.writer_with_num_threads(1, 10_000_000).unwrap(); + index_writer.add_document(doc!(title => "The Name of the Wind", price => 30_200u64, date => DateTime::from_str("1898-04-09T00:00:00+00:00").unwrap())); + index_writer.add_document(doc!(title => "The Diary of Muadib", price => 29_240u64, date => DateTime::from_str("2020-04-09T00:00:00+00:00").unwrap())); + index_writer.add_document(doc!(title => "A Dairy Cow", price => 21_240u64, date => DateTime::from_str("2019-04-09T00:00:00+00:00").unwrap())); + index_writer.add_document(doc!(title => "The Diary of a Young Girl", price => 20_120u64, date => DateTime::from_str("2018-04-09T00:00:00+00:00").unwrap())); + assert!(index_writer.commit().is_ok()); + + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + + let query_parser = QueryParser::for_index(&index, vec![title]); + let query = query_parser.parse_query("diary").unwrap(); + let filter_some_collector = FilterCollector::new(price, &|value: u64| value > 20_120u64, TopDocs::with_limit(2)); + let top_docs = searcher.search(&query, &filter_some_collector).unwrap(); + + assert_eq!(top_docs.len(), 1); + assert_eq!(top_docs[0].1, DocAddress(0, 1)); + + let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new(price, &|value| value < 5u64, TopDocs::with_limit(2)); + let filtered_top_docs = searcher.search(&query, &filter_all_collector).unwrap(); + + assert_eq!(filtered_top_docs.len(), 0); + + fn date_debug(value: DateTime) -> bool { + println!("date: {:?}", value); + assert_eq!(value, DateTime::from_str("1000-04-09T00:00:00+00:00").unwrap()); + (value - DateTime::from_str("2019-04-09T00:00:00+00:00").unwrap()).num_weeks() > 0 + } + + let filter_dates_collector = FilterCollector::new(date, &date_debug, TopDocs::with_limit(2)); + let filtered_date_docs = searcher.search(&query, &filter_dates_collector).unwrap(); + + assert_eq!(filtered_date_docs.len(), 5); +} + /// Stores all of the doc ids. /// This collector is only used for tests. /// It is unusable in pr