From 56b2e9731ff430bd294efa15cd21507265441eaa Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Sat, 12 May 2018 16:00:58 -0700 Subject: [PATCH] working. Chained collector is broken though --- src/collector/mod.rs | 17 +-- src/collector/multi_collector.rs | 173 +++++++++++++++++++++---------- 2 files changed, 126 insertions(+), 64 deletions(-) diff --git a/src/collector/mod.rs b/src/collector/mod.rs index aea822367..8121e1acc 100644 --- a/src/collector/mod.rs +++ b/src/collector/mod.rs @@ -9,6 +9,7 @@ use SegmentLocalId; use SegmentReader; use query::Query; use Searcher; +use downcast; mod count_collector; pub use self::count_collector::CountCollector; @@ -55,7 +56,7 @@ pub use self::chained_collector::chain; /// /// Segments are not guaranteed to be visited in any specific order. pub trait Collector { - type Child : SegmentCollector; + type Child : SegmentCollector + 'static; /// `set_segment` is called before beginning to enumerate /// on this segment. fn for_segment( @@ -96,16 +97,21 @@ pub trait Collector { } } -pub trait SegmentCollector { +pub trait SegmentCollector: downcast::Any + 'static { /// The query pushes the scored document to the collector via this method. fn collect(&mut self, doc: DocId, score: Score); } +#[allow(missing_docs)] +mod downcast_impl { + downcast!(super::SegmentCollector); +} + impl<'a, C: Collector> Collector for &'a mut C { type Child = C::Child; fn for_segment( - &mut self, + &mut self, // TODO Ask Jason : why &mut self here!? segment_local_id: SegmentLocalId, segment: &SegmentReader, ) -> Result { @@ -121,11 +127,6 @@ impl<'a, C: Collector> Collector for &'a mut C { } } -impl<'a, S: SegmentCollector> SegmentCollector for &'a mut S { - fn collect(&mut self, doc: u32, score: f32) { - (*self).collect(doc, score); - } -} #[cfg(test)] pub mod tests { diff --git a/src/collector/multi_collector.rs b/src/collector/multi_collector.rs index 23a6f05a6..33e88956b 100644 --- a/src/collector/multi_collector.rs +++ b/src/collector/multi_collector.rs @@ -1,96 +1,157 @@ use super::Collector; use super::SegmentCollector; use DocId; -use Result; use Score; +use Result; use SegmentLocalId; use SegmentReader; -use std::any::Any; +use downcast::Downcast; -pub trait AnyCollector: Collector + Any where ::Child : AnySegmentCollector { - fn merge_children_anys(&mut self, children: Vec>); + +pub struct CollectorWrapper<'a, TCollector: 'a + Collector>(&'a mut TCollector); + +trait UntypedCollector { + fn for_segment(&mut self, segment_local_id: u32, segment: &SegmentReader) -> Result>; + + fn requires_scoring(&self) -> bool; + + fn merge_children_anys(&mut self, childrens: Vec>); } -pub trait AnySegmentCollector: SegmentCollector + Any { -} -/// Multicollector makes it possible to collect on more than one collector. -/// It should only be used for use cases where the Collector types is unknown -/// at compile time. -/// If the type of the collectors is known, you should prefer to use `ChainedCollector`. -pub struct MultiCollector<'a> { - collectors: Vec<&'a mut AnyCollector>, -} - -impl<'a> MultiCollector<'a> { - /// Constructor - pub fn from(collectors: Vec<&'a mut AnyCollector>) -> MultiCollector { - MultiCollector { collectors } - } -} - -pub struct SegmentMultiCollector { - segment_collectors: Vec>, -} - -impl<'a> Collector for MultiCollector<'a> { - type Child = SegmentMultiCollector; - - fn for_segment(&mut self, segment_local_id: u32, segment: &SegmentReader) -> Result { - let segment_collectors = self.collectors.iter_mut() - .map(|x| x.for_segment(segment_local_id, segment)) - .collect::>>()?; - Ok(SegmentMultiCollector { segment_collectors }) +impl<'a, TCollector:'a + Collector> UntypedCollector for CollectorWrapper<'a, TCollector> { + fn for_segment(&mut self, segment_local_id: u32, segment: &SegmentReader) -> Result> { + let segment_collector = self.0.for_segment(segment_local_id, segment)?; + Ok(Box::new(segment_collector)) } fn requires_scoring(&self) -> bool { - self.collectors - .iter() - .any(|collector| collector.requires_scoring()) + self.0.requires_scoring() } - fn merge_children(&mut self, children: Vec) { - let mut per_collector_children = - (0..self.collectors.len()) + fn merge_children_anys(&mut self, childrens: Vec>) { + let typed_children: Vec = childrens.into_iter() + .map(|untyped_child_collector| { + *Downcast::::downcast(untyped_child_collector).unwrap() + }).collect(); + self.0.merge_children(typed_children); + } +} + +pub struct MultiCollector<'a> { + collector_wrappers: Vec> +} + +impl<'a> MultiCollector<'a> { + fn new() -> MultiCollector<'a> { + MultiCollector { + collector_wrappers: Vec::new() + } + } + + fn add_collector(&mut self, collector: &'a mut TCollector) { + let collector_wrapper = CollectorWrapper(collector); + self.collector_wrappers.push(Box::new(collector_wrapper)); + } +} + +impl<'a> Collector for MultiCollector<'a> { + + type Child = MultiCollectorChild; + + fn for_segment(&mut self, segment_local_id: SegmentLocalId, segment: &SegmentReader) -> Result { + let children = self.collector_wrappers + .iter_mut() + .map(|collector_wrapper| { + collector_wrapper.for_segment(segment_local_id, segment) + }) + .collect::>>()?; + Ok(MultiCollectorChild { + children + }) + } + + fn requires_scoring(&self) -> bool { + self.collector_wrappers + .iter() + .any(|c| c.requires_scoring()) + } + + fn merge_children(&mut self, children: Vec) { + let mut per_collector_children: Vec>> = + (0..self.collector_wrappers.len()) .map(|_| Vec::with_capacity(children.len())) .collect::>(); - for child in children.into_iter() { - for (idx, segment_collector) in child.segment_collectors.into_iter().enumerate() { + for child in children { + for (idx, segment_collector) in child.children.into_iter().enumerate() { per_collector_children[idx].push(segment_collector); } } - for (collector, children) in self.collectors.iter_mut().zip(per_collector_children) { + for (collector, children) in self.collector_wrappers.iter_mut().zip(per_collector_children) { collector.merge_children_anys(children); } } + +} + +trait UntypedSegmentCollector { + fn collect(); +} + +pub struct MultiCollectorChild { + children: Vec> +} + +impl SegmentCollector for MultiCollectorChild { + fn collect(&mut self, doc: DocId, score: Score) { + for child in &mut self.children { + child.collect(doc, score); + } + } } -impl SegmentCollector for SegmentMultiCollector { - fn collect(&mut self, doc: DocId, score: Score) { - for collector in &mut self.segment_collectors { - collector.collect(doc, score); - } - } -} #[cfg(test)] mod tests { use super::*; use collector::{Collector, CountCollector, TopCollector}; + use schema::{TEXT, SchemaBuilder}; + use query::TermQuery; + use Index; + use Term; + use schema::IndexRecordOption; #[test] fn test_multi_collector() { + let mut schema_builder = SchemaBuilder::new(); + let text = schema_builder.add_text_field("text", TEXT); + let schema = schema_builder.build(); + + let index = Index::create_in_ram(schema); + { + let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap(); + index_writer.add_document(doc!(text=>"abc")); + index_writer.add_document(doc!(text=>"abc abc abc")); + index_writer.add_document(doc!(text=>"abc abc")); + index_writer.commit().unwrap(); + index_writer.add_document(doc!(text=>"")); + index_writer.add_document(doc!(text=>"abc abc abc abc")); + index_writer.add_document(doc!(text=>"abc")); + index_writer.commit().unwrap(); + } + index.load_searchers().unwrap(); + let searcher = index.searcher(); + let term = Term::from_field_text(text, "abc"); + let query = TermQuery::new(term, IndexRecordOption::Basic); let mut top_collector = TopCollector::with_limit(2); let mut count_collector = CountCollector::default(); { - let mut collectors = - MultiCollector::from(vec![&mut top_collector, &mut count_collector]); - collectors.collect(1, 0.2); - collectors.collect(2, 0.1); - collectors.collect(3, 0.5); + let mut collectors = MultiCollector::new(); + collectors.add_collector(&mut top_collector); + collectors.add_collector(&mut count_collector); + collectors.search(&*searcher, &query).unwrap(); } - assert_eq!(count_collector.count(), 3); - assert!(top_collector.at_capacity()); + assert_eq!(count_collector.count(), 5); } }