working. Chained collector is broken though

This commit is contained in:
Paul Masurel
2018-05-12 16:00:58 -07:00
parent c85668cabe
commit 56b2e9731f
2 changed files with 126 additions and 64 deletions

View File

@@ -9,6 +9,7 @@ use SegmentLocalId;
use SegmentReader;
use query::Query;
use Searcher;
use downcast;
mod count_collector;
pub use self::count_collector::CountCollector;
@@ -55,7 +56,7 @@ pub use self::chained_collector::chain;
///
/// Segments are not guaranteed to be visited in any specific order.
pub trait Collector {
type Child : SegmentCollector;
type Child : SegmentCollector + 'static;
/// `set_segment` is called before beginning to enumerate
/// on this segment.
fn for_segment(
@@ -96,16 +97,21 @@ pub trait Collector {
}
}
pub trait SegmentCollector {
pub trait SegmentCollector: downcast::Any + 'static {
/// The query pushes the scored document to the collector via this method.
fn collect(&mut self, doc: DocId, score: Score);
}
#[allow(missing_docs)]
mod downcast_impl {
downcast!(super::SegmentCollector);
}
impl<'a, C: Collector> Collector for &'a mut C {
type Child = C::Child;
fn for_segment(
&mut self,
&mut self, // TODO Ask Jason : why &mut self here!?
segment_local_id: SegmentLocalId,
segment: &SegmentReader,
) -> Result<C::Child> {
@@ -121,11 +127,6 @@ impl<'a, C: Collector> Collector for &'a mut C {
}
}
impl<'a, S: SegmentCollector> SegmentCollector for &'a mut S {
fn collect(&mut self, doc: u32, score: f32) {
(*self).collect(doc, score);
}
}
#[cfg(test)]
pub mod tests {

View File

@@ -1,96 +1,157 @@
use super::Collector;
use super::SegmentCollector;
use DocId;
use Result;
use Score;
use Result;
use SegmentLocalId;
use SegmentReader;
use std::any::Any;
use downcast::Downcast;
pub trait AnyCollector: Collector + Any where <Self as Collector>::Child : AnySegmentCollector {
fn merge_children_anys(&mut self, children: Vec<Box<AnySegmentCollector>>);
pub struct CollectorWrapper<'a, TCollector: 'a + Collector>(&'a mut TCollector);
trait UntypedCollector {
fn for_segment(&mut self, segment_local_id: u32, segment: &SegmentReader) -> Result<Box<SegmentCollector>>;
fn requires_scoring(&self) -> bool;
fn merge_children_anys(&mut self, childrens: Vec<Box<SegmentCollector>>);
}
pub trait AnySegmentCollector: SegmentCollector + Any {
}
/// Multicollector makes it possible to collect on more than one collector.
/// It should only be used for use cases where the Collector types is unknown
/// at compile time.
/// If the type of the collectors is known, you should prefer to use `ChainedCollector`.
pub struct MultiCollector<'a> {
collectors: Vec<&'a mut AnyCollector>,
}
impl<'a> MultiCollector<'a> {
/// Constructor
pub fn from(collectors: Vec<&'a mut AnyCollector>) -> MultiCollector {
MultiCollector { collectors }
}
}
pub struct SegmentMultiCollector {
segment_collectors: Vec<Box<AnySegmentCollector>>,
}
impl<'a> Collector for MultiCollector<'a> {
type Child = SegmentMultiCollector;
fn for_segment(&mut self, segment_local_id: u32, segment: &SegmentReader) -> Result<SegmentMultiCollector> {
let segment_collectors = self.collectors.iter_mut()
.map(|x| x.for_segment(segment_local_id, segment))
.collect::<Result<Vec<_>>>()?;
Ok(SegmentMultiCollector { segment_collectors })
impl<'a, TCollector:'a + Collector> UntypedCollector for CollectorWrapper<'a, TCollector> {
fn for_segment(&mut self, segment_local_id: u32, segment: &SegmentReader) -> Result<Box<SegmentCollector>> {
let segment_collector = self.0.for_segment(segment_local_id, segment)?;
Ok(Box::new(segment_collector))
}
fn requires_scoring(&self) -> bool {
self.collectors
.iter()
.any(|collector| collector.requires_scoring())
self.0.requires_scoring()
}
fn merge_children(&mut self, children: Vec<SegmentMultiCollector>) {
let mut per_collector_children =
(0..self.collectors.len())
fn merge_children_anys(&mut self, childrens: Vec<Box<SegmentCollector>>) {
let typed_children: Vec<TCollector::Child> = childrens.into_iter()
.map(|untyped_child_collector| {
*Downcast::<TCollector::Child>::downcast(untyped_child_collector).unwrap()
}).collect();
self.0.merge_children(typed_children);
}
}
pub struct MultiCollector<'a> {
collector_wrappers: Vec<Box<UntypedCollector + 'a>>
}
impl<'a> MultiCollector<'a> {
fn new() -> MultiCollector<'a> {
MultiCollector {
collector_wrappers: Vec::new()
}
}
fn add_collector<TCollector: 'a + Collector>(&mut self, collector: &'a mut TCollector) {
let collector_wrapper = CollectorWrapper(collector);
self.collector_wrappers.push(Box::new(collector_wrapper));
}
}
impl<'a> Collector for MultiCollector<'a> {
type Child = MultiCollectorChild;
fn for_segment(&mut self, segment_local_id: SegmentLocalId, segment: &SegmentReader) -> Result<MultiCollectorChild> {
let children = self.collector_wrappers
.iter_mut()
.map(|collector_wrapper| {
collector_wrapper.for_segment(segment_local_id, segment)
})
.collect::<Result<Vec<_>>>()?;
Ok(MultiCollectorChild {
children
})
}
fn requires_scoring(&self) -> bool {
self.collector_wrappers
.iter()
.any(|c| c.requires_scoring())
}
fn merge_children(&mut self, children: Vec<MultiCollectorChild>) {
let mut per_collector_children: Vec<Vec<Box<SegmentCollector>>> =
(0..self.collector_wrappers.len())
.map(|_| Vec::with_capacity(children.len()))
.collect::<Vec<_>>();
for child in children.into_iter() {
for (idx, segment_collector) in child.segment_collectors.into_iter().enumerate() {
for child in children {
for (idx, segment_collector) in child.children.into_iter().enumerate() {
per_collector_children[idx].push(segment_collector);
}
}
for (collector, children) in self.collectors.iter_mut().zip(per_collector_children) {
for (collector, children) in self.collector_wrappers.iter_mut().zip(per_collector_children) {
collector.merge_children_anys(children);
}
}
}
trait UntypedSegmentCollector {
fn collect();
}
pub struct MultiCollectorChild {
children: Vec<Box<SegmentCollector>>
}
impl SegmentCollector for MultiCollectorChild {
fn collect(&mut self, doc: DocId, score: Score) {
for child in &mut self.children {
child.collect(doc, score);
}
}
}
impl SegmentCollector for SegmentMultiCollector {
fn collect(&mut self, doc: DocId, score: Score) {
for collector in &mut self.segment_collectors {
collector.collect(doc, score);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use collector::{Collector, CountCollector, TopCollector};
use schema::{TEXT, SchemaBuilder};
use query::TermQuery;
use Index;
use Term;
use schema::IndexRecordOption;
#[test]
fn test_multi_collector() {
let mut schema_builder = SchemaBuilder::new();
let text = schema_builder.add_text_field("text", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
{
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
index_writer.add_document(doc!(text=>"abc"));
index_writer.add_document(doc!(text=>"abc abc abc"));
index_writer.add_document(doc!(text=>"abc abc"));
index_writer.commit().unwrap();
index_writer.add_document(doc!(text=>""));
index_writer.add_document(doc!(text=>"abc abc abc abc"));
index_writer.add_document(doc!(text=>"abc"));
index_writer.commit().unwrap();
}
index.load_searchers().unwrap();
let searcher = index.searcher();
let term = Term::from_field_text(text, "abc");
let query = TermQuery::new(term, IndexRecordOption::Basic);
let mut top_collector = TopCollector::with_limit(2);
let mut count_collector = CountCollector::default();
{
let mut collectors =
MultiCollector::from(vec![&mut top_collector, &mut count_collector]);
collectors.collect(1, 0.2);
collectors.collect(2, 0.1);
collectors.collect(3, 0.5);
let mut collectors = MultiCollector::new();
collectors.add_collector(&mut top_collector);
collectors.add_collector(&mut count_collector);
collectors.search(&*searcher, &query).unwrap();
}
assert_eq!(count_collector.count(), 3);
assert!(top_collector.at_capacity());
assert_eq!(count_collector.count(), 5);
}
}