This commit is contained in:
barrotsteindev
2020-11-22 10:40:45 +02:00
parent 8782c0eada
commit 7c94dfdc15

View File

@@ -49,32 +49,43 @@ use crate::{Score, SegmentReader, TantivyError};
/// assert_eq!(top_docs.len(), 2);
/// assert_eq!(top_docs[0].1, DocAddress(0, 1));
/// assert_eq!(top_docs[1].1, DocAddress(0, 3));
///
///
/// let filter_all_collector = FilterCollector::new(price, &|value| false, TopDocs::with_limit(2));
/// let filtered_top_docs = searcher.search(&query, &filter_all_collector).unwrap();
///
///
/// assert_eq!(filtered_top_docs.len(), 0);
/// ```
pub struct FilterCollector<TCollector, TSegmentCollector> {
field: Field,
collector: TCollector,
predicate: &'static (dyn Fn(u64) -> bool + Send + Sync),
phantom: PhantomData<TSegmentCollector>
phantom: PhantomData<TSegmentCollector>,
}
impl<TCollector, TSegmentCollector> FilterCollector<TCollector, TSegmentCollector>
where
impl<TCollector, TSegmentCollector> FilterCollector<TCollector, TSegmentCollector>
where
TCollector: Collector<Child = TSegmentCollector> + Send + Sync,
TSegmentCollector: SegmentCollector + Send + Sync {
pub fn new(field: Field, predicate: &'static (dyn Fn(u64) -> bool + Send + Sync), collector: TCollector) -> FilterCollector<TCollector, TSegmentCollector> {
FilterCollector { field, predicate, collector, phantom: PhantomData }
TSegmentCollector: SegmentCollector + Send + Sync,
{
pub fn new(
field: Field,
predicate: &'static (dyn Fn(u64) -> bool + Send + Sync),
collector: TCollector,
) -> FilterCollector<TCollector, TSegmentCollector> {
FilterCollector {
field,
predicate,
collector,
phantom: PhantomData,
}
}
}
impl<TCollector, TSegmentCollector> Collector for FilterCollector<TCollector, TSegmentCollector>
where
impl<TCollector, TSegmentCollector> Collector for FilterCollector<TCollector, TSegmentCollector>
where
TSegmentCollector: SegmentCollector + Send + Sync,
TCollector: Collector<Child = TSegmentCollector> + Send + Sync {
TCollector: Collector<Child = TSegmentCollector> + Send + Sync,
{
// That's the type of our result.
// Our standard deviation will be a float.
type Fruit = TCollector::Fruit;
@@ -98,26 +109,25 @@ where
})?;
let child_segment_collector = self.collector.for_segment(segment_local_id, segment_reader);
match child_segment_collector {
Ok(segment_collector) => {
Ok(FilterSegmentCollector::<TSegmentCollector> {
fast_field_reader,
segment_collector: segment_collector,
predicate: self.predicate
})
},
Err(_) => {
Err(TantivyError::SystemError("Could not open segment: ".to_owned()))
},
Ok(segment_collector) => Ok(FilterSegmentCollector::<TSegmentCollector> {
fast_field_reader,
segment_collector: segment_collector,
predicate: self.predicate,
}),
Err(_) => Err(TantivyError::SystemError(
"Could not open segment: ".to_owned(),
)),
}
}
fn requires_scoring(&self) -> bool {
self.collector.requires_scoring()
}
fn merge_fruits(&self, segment_fruits: Vec<<TCollector::Child as SegmentCollector>::Fruit>) -> crate::Result<TCollector::Fruit> {
fn merge_fruits(
&self,
segment_fruits: Vec<<TCollector::Child as SegmentCollector>::Fruit>,
) -> crate::Result<TCollector::Fruit> {
self.collector.merge_fruits(segment_fruits)
}
}
@@ -125,12 +135,13 @@ where
pub struct FilterSegmentCollector<TSegmentCollector> {
fast_field_reader: FastFieldReader<u64>,
segment_collector: TSegmentCollector,
predicate: &'static (dyn Fn(u64) -> bool + Send + Sync)
predicate: &'static (dyn Fn(u64) -> bool + Send + Sync),
}
impl<TSegmentCollector> SegmentCollector for FilterSegmentCollector<TSegmentCollector>
where
TSegmentCollector: SegmentCollector + Send + Sync {
impl<TSegmentCollector> SegmentCollector for FilterSegmentCollector<TSegmentCollector>
where
TSegmentCollector: SegmentCollector + Send + Sync,
{
type Fruit = TSegmentCollector::Fruit;
fn collect(&mut self, doc: u32, score: Score) {