WIP generic filter collector

This commit is contained in:
barrotsteindev
2020-12-08 14:36:52 +02:00
parent c805871b92
commit ac704f2f22
2 changed files with 140 additions and 35 deletions

View File

@@ -9,9 +9,11 @@
// ---
// Importing tantivy...
use crate::collector::{Collector, SegmentCollector};
use crate::fastfield::FastFieldReader;
use std::marker::PhantomData;
use crate::fastfield::{FastValue, FastFieldReader};
use crate::schema::Field;
use crate::collector::{Collector, SegmentCollector};
use crate::{Score, SegmentReader, TantivyError};
/// The `FilterCollector` collector filters docs using a u64 fast field value and a predicate.
@@ -20,20 +22,23 @@ use crate::{Score, SegmentReader, TantivyError};
/// ```rust
/// use tantivy::collector::{TopDocs, FilterCollector};
/// use tantivy::query::QueryParser;
/// use tantivy::schema::{Schema, TEXT, INDEXED, FAST};
/// use tantivy::schema::{Schema, FAST, TEXT};
/// use tantivy::DateTime;
/// use std::str::FromStr;
/// use tantivy::{doc, DocAddress, Index};
///
/// let mut schema_builder = Schema::builder();
/// let title = schema_builder.add_text_field("title", TEXT);
/// let price = schema_builder.add_u64_field("price", INDEXED | FAST);
/// let price = schema_builder.add_u64_field("price", FAST);
/// let date = schema_builder.add_date_field("date", FAST);
/// let schema = schema_builder.build();
/// let index = Index::create_in_ram(schema);
///
/// let mut index_writer = index.writer_with_num_threads(1, 10_000_000).unwrap();
/// index_writer.add_document(doc!(title => "The Name of the Wind", price => 30_200u64));
/// index_writer.add_document(doc!(title => "The Diary of Muadib", price => 29_240u64));
/// index_writer.add_document(doc!(title => "A Dairy Cow", price => 21_240u64));
/// index_writer.add_document(doc!(title => "The Diary of a Young Girl", price => 20_120u64));
/// index_writer.add_document(doc!(title => "The Name of the Wind", price => 30_200u64, date => DateTime::from_str("1898-04-09T00:00:00+00:00").unwrap()));
/// index_writer.add_document(doc!(title => "The Diary of Muadib", price => 29_240u64, date => DateTime::from_str("2020-04-09T00:00:00+00:00").unwrap()));
/// index_writer.add_document(doc!(title => "A Dairy Cow", price => 21_240u64, date => DateTime::from_str("2019-04-09T00:00:00+00:00").unwrap()));
/// index_writer.add_document(doc!(title => "The Diary of a Young Girl", price => 20_120u64, date => DateTime::from_str("2018-04-09T00:00:00+00:00").unwrap()));
/// assert!(index_writer.commit().is_ok());
///
/// let reader = index.reader().unwrap();
@@ -41,78 +46,123 @@ use crate::{Score, SegmentReader, TantivyError};
///
/// let query_parser = QueryParser::for_index(&index, vec![title]);
/// let query = query_parser.parse_query("diary").unwrap();
/// let no_filter_collector = FilterCollector::new(price, &|value| value > 20_120u64, TopDocs::with_limit(2));
/// let top_docs = searcher.search(&query, &no_filter_collector).unwrap();
/// let filter_some_collector = FilterCollector::new(price, &|value: u64| value > 20_120u64, TopDocs::with_limit(2));
/// let top_docs = searcher.search(&query, &filter_some_collector).unwrap();
///
/// assert_eq!(top_docs.len(), 1);
/// assert_eq!(top_docs[0].1, DocAddress(0, 1));
///
/// let filter_all_collector = FilterCollector::new(price, &|value| value < 5u64, TopDocs::with_limit(2));
/// let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new(price, &|value| value < 5u64, TopDocs::with_limit(2));
/// let filtered_top_docs = searcher.search(&query, &filter_all_collector).unwrap();
///
/// assert_eq!(filtered_top_docs.len(), 0);
///
/// fn date_debug(value: DateTime) -> bool {
/// println!("date: {:?}", value);
/// assert_eq!(value, DateTime::from_str("1000-04-09T00:00:00+00:00").unwrap());
/// (value - DateTime::from_str("2019-04-09T00:00:00+00:00").unwrap()).num_weeks() > 0
/// }
///
/// let filter_dates_collector = FilterCollector::new(date, &date_debug, TopDocs::with_limit(2));
/// let filtered_date_docs = searcher.search(&query, &filter_all_collector).unwrap();
///
/// assert_eq!(filtered_date_docs.len(), 5);
/// ```
pub struct FilterCollector<TCollector, TPredicate>
pub struct FilterCollector<TCollector, TPredicate, TPredicateValue: FastValue>
where
TPredicate: 'static,
{
field: Field,
collector: TCollector,
predicate: &'static TPredicate,
t_predicate_value: PhantomData<TPredicateValue>,
}
impl<TCollector, TPredicate> FilterCollector<TCollector, TPredicate>
impl<TCollector, TPredicate, TPredicateValue: FastValue>
FilterCollector<TCollector, TPredicate, TPredicateValue>
where
TCollector: Collector + Send + Sync,
TPredicate: Fn(u64) -> bool + Send + Sync,
TPredicate: Fn(TPredicateValue) -> bool + Send + Sync,
{
/// Create a new FilterCollector.
pub fn new(
field: Field,
predicate: &'static TPredicate,
collector: TCollector,
) -> FilterCollector<TCollector, TPredicate> {
) -> FilterCollector<TCollector, TPredicate, TPredicateValue> {
FilterCollector {
field,
predicate,
collector,
t_predicate_value: PhantomData,
}
}
}
impl<TCollector, TPredicate> Collector for FilterCollector<TCollector, TPredicate>
impl<TCollector, TPredicate, TPredicateValue: FastValue> Collector
for FilterCollector<TCollector, TPredicate, TPredicateValue>
where
TCollector: Collector + Send + Sync,
TPredicate: 'static + Fn(u64) -> bool + Send + Sync,
TPredicate: 'static + Fn(TPredicateValue) -> bool + Send + Sync,
TPredicateValue: 'static + FastValue,
{
// That's the type of our result.
// Our standard deviation will be a float.
type Fruit = TCollector::Fruit;
type Child = FilterSegmentCollector<TCollector::Child, TPredicate>;
type Child = FilterSegmentCollector<TCollector::Child, TPredicate, TPredicateValue>;
fn for_segment(
&self,
segment_local_id: u32,
segment_reader: &SegmentReader,
) -> crate::Result<FilterSegmentCollector<TCollector::Child, TPredicate>> {
let fast_field_reader = segment_reader
.fast_fields()
.u64(self.field)
.ok_or_else(|| {
let field_name = segment_reader.schema().get_field_name(self.field);
TantivyError::SchemaError(format!(
"Field {:?} is not a u64 fast field.",
field_name
))
})?;
) -> crate::Result<FilterSegmentCollector<TCollector::Child, TPredicate, TPredicateValue>> {
let schema = segment_reader.schema();
let field_entry = schema.get_field_entry(self.field);
if !field_entry.is_fast() {
return Err(TantivyError::SchemaError(format!(
"Field {:?} is not a fast field.",
field_entry.name()
)));
}
let schema_type = TPredicateValue::to_type();
let requested_type = field_entry.field_type().value_type();
if schema_type != requested_type {
return Err(TantivyError::SchemaError(format!(
"Field {:?} is of type {:?}!={:?}",
field_entry.name(),
schema_type,
requested_type
)));
}
let err_closure = || {
let field_name = segment_reader.schema().get_field_name(self.field);
TantivyError::SchemaError(format!(
"Field {:?} is not a u64 fast field.",
field_name
))
};
let fast_fields = segment_reader.fast_fields();
let fast_filed_reader: crate::Result<FastFieldReader<TPredicateValue>> = match schema_type {
crate::schema::Type::U64 => {fast_fields.u64(self.field).ok_or_else(err_closure)}
crate::schema::Type::I64 => {fast_fields.i64(self.field).ok_or_else(err_closure)}
crate::schema::Type::F64 => {fast_fields.f64(self.field).ok_or_else(err_closure)}
crate::schema::Type::Date => {fast_fields.date(self.field).ok_or_else(err_closure)}
crate::schema::Type::Bytes => {fast_fields.bytes(self.field).ok_or_else(err_closure)}
crate::schema::Type::Str | crate::schema::Type::HierarchicalFacet => {Err(TantivyError::SchemaError(format!("Field {:?} uses an unsupported type", segment_reader.schema().get_field_name(self.field))))}
};
let segment_collector = self
.collector
.for_segment(segment_local_id, segment_reader)?;
let a = fast_filed_reader?;
Ok(FilterSegmentCollector {
fast_field_reader,
fast_field_reader: a,
segment_collector: segment_collector,
predicate: self.predicate,
t_predicate_value: PhantomData,
})
}
@@ -128,20 +178,23 @@ where
}
}
pub struct FilterSegmentCollector<TSegmentCollector, TPredicate>
pub struct FilterSegmentCollector<TSegmentCollector, TPredicate, TPredicateValue>
where
TPredicate: 'static,
TPredicateValue: 'static + FastValue,
{
fast_field_reader: FastFieldReader<u64>,
fast_field_reader: FastFieldReader<TPredicateValue>,
segment_collector: TSegmentCollector,
predicate: &'static TPredicate,
t_predicate_value: PhantomData<TPredicateValue>,
}
impl<TSegmentCollector, TPredicate> SegmentCollector
for FilterSegmentCollector<TSegmentCollector, TPredicate>
impl<TSegmentCollector, TPredicate, TPredicateValue> SegmentCollector
for FilterSegmentCollector<TSegmentCollector, TPredicate, TPredicateValue>
where
TSegmentCollector: SegmentCollector,
TPredicate: 'static + Fn(u64) -> bool + Send + Sync,
TPredicate: 'static + Fn(TPredicateValue) -> bool + Send + Sync,
TPredicateValue: 'static + FastValue,
{
type Fruit = TSegmentCollector::Fruit;

View File

@@ -8,6 +8,13 @@ use crate::DocId;
use crate::Score;
use crate::SegmentLocalId;
use crate::collector::{TopDocs, FilterCollector};
use crate::query::QueryParser;
use crate::schema::{Schema, FAST, TEXT};
use crate::DateTime;
use std::str::FromStr;
use crate::{doc, Index};
pub const TEST_COLLECTOR_WITH_SCORE: TestCollector = TestCollector {
compute_score: true,
};
@@ -16,6 +23,51 @@ pub const TEST_COLLECTOR_WITHOUT_SCORE: TestCollector = TestCollector {
compute_score: true,
};
#[test]
pub fn test_filter_collector() {
let mut schema_builder = Schema::builder();
let title = schema_builder.add_text_field("title", TEXT);
let price = schema_builder.add_u64_field("price", FAST);
let date = schema_builder.add_date_field("date", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_with_num_threads(1, 10_000_000).unwrap();
index_writer.add_document(doc!(title => "The Name of the Wind", price => 30_200u64, date => DateTime::from_str("1898-04-09T00:00:00+00:00").unwrap()));
index_writer.add_document(doc!(title => "The Diary of Muadib", price => 29_240u64, date => DateTime::from_str("2020-04-09T00:00:00+00:00").unwrap()));
index_writer.add_document(doc!(title => "A Dairy Cow", price => 21_240u64, date => DateTime::from_str("2019-04-09T00:00:00+00:00").unwrap()));
index_writer.add_document(doc!(title => "The Diary of a Young Girl", price => 20_120u64, date => DateTime::from_str("2018-04-09T00:00:00+00:00").unwrap()));
assert!(index_writer.commit().is_ok());
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let query_parser = QueryParser::for_index(&index, vec![title]);
let query = query_parser.parse_query("diary").unwrap();
let filter_some_collector = FilterCollector::new(price, &|value: u64| value > 20_120u64, TopDocs::with_limit(2));
let top_docs = searcher.search(&query, &filter_some_collector).unwrap();
assert_eq!(top_docs.len(), 1);
assert_eq!(top_docs[0].1, DocAddress(0, 1));
let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new(price, &|value| value < 5u64, TopDocs::with_limit(2));
let filtered_top_docs = searcher.search(&query, &filter_all_collector).unwrap();
assert_eq!(filtered_top_docs.len(), 0);
fn date_debug(value: DateTime) -> bool {
println!("date: {:?}", value);
assert_eq!(value, DateTime::from_str("1000-04-09T00:00:00+00:00").unwrap());
(value - DateTime::from_str("2019-04-09T00:00:00+00:00").unwrap()).num_weeks() > 0
}
let filter_dates_collector = FilterCollector::new(date, &date_debug, TopDocs::with_limit(2));
let filtered_date_docs = searcher.search(&query, &filter_dates_collector).unwrap();
assert_eq!(filtered_date_docs.len(), 5);
}
/// Stores all of the doc ids.
/// This collector is only used for tests.
/// It is unusable in pr