diff --git a/src/core/merger.rs b/src/core/merger.rs index 130df2ea8..f280729ab 100644 --- a/src/core/merger.rs +++ b/src/core/merger.rs @@ -223,6 +223,7 @@ mod tests { use core::searcher::DocAddress; use collector::FastFieldTestCollector; use collector::TestCollector; + use query::MultiTermQuery; use schema::TextIndexingOptions; #[test] @@ -286,7 +287,8 @@ mod tests { let searcher = index.searcher().unwrap(); let get_doc_ids = |terms: Vec| { let mut collector = TestCollector::new(); - assert!(searcher.search(&terms, &mut collector).is_ok()); + let query = MultiTermQuery::new(terms); + assert!(searcher.search(&query, &mut collector).is_ok()); collector.docs() }; { @@ -329,8 +331,9 @@ mod tests { } { let get_fast_vals = |terms: Vec| { + let query = MultiTermQuery::new(terms); let mut collector = FastFieldTestCollector::for_field(score_field); - assert!(searcher.search(&terms, &mut collector).is_ok()); + assert!(searcher.search(&query, &mut collector).is_ok()); collector.vals().clone() }; assert_eq!( diff --git a/src/core/reader.rs b/src/core/reader.rs index 53ba2c024..7ce768ae3 100644 --- a/src/core/reader.rs +++ b/src/core/reader.rs @@ -96,7 +96,7 @@ impl SegmentReader { SegmentPostings::from_data(term_info.doc_freq, &postings_data) } - fn get_term<'a>(&'a self, term: &Term) -> Option { + pub fn get_term<'a>(&'a self, term: &Term) -> Option { self.term_infos.get(term.as_slice()) } diff --git a/src/core/searcher.rs b/src/core/searcher.rs index 1bd0d1109..84d245ed3 100644 --- a/src/core/searcher.rs +++ b/src/core/searcher.rs @@ -2,11 +2,12 @@ use core::reader::SegmentReader; use core::index::Index; use core::index::Segment; use DocId; -use schema::{Document, Term}; +use schema::Document; use collector::Collector; use std::io; use common::TimerTree; use postings::Postings; +use query::Query; #[derive(Debug)] pub struct Searcher { @@ -40,6 +41,10 @@ impl Searcher { segments: Vec::new(), } } + + pub fn segments(&self,) -> &Vec { + &self.segments + } pub fn for_index(index: Index) -> io::Result { let mut searcher = Searcher::new(); @@ -48,27 +53,31 @@ impl Searcher { } Ok(searcher) } - - pub fn search(&self, terms: &Vec, collector: &mut C) -> io::Result { - let mut timer_tree = TimerTree::new(); - { - let mut search_timer = timer_tree.open("search"); - for (segment_ord, segment) in self.segments.iter().enumerate() { - let mut segment_search_timer = search_timer.open("segment_search"); - { - let _ = segment_search_timer.open("set_segment"); - try!(collector.set_segment(segment_ord as SegmentLocalId, &segment)); - } - let mut postings = segment.search(terms, segment_search_timer.open("get_postings")); - { - let _collection_timer = segment_search_timer.open("collection"); - while postings.next() { - collector.collect(postings.doc()); - } - } - } - } - Ok(timer_tree) + + pub fn search(&self, query: &Q, collector: &mut C) -> io::Result { + query.search(self, collector) } + + // pub fn search(&self, terms: &Vec, collector: &mut C) -> io::Result { + // let mut timer_tree = TimerTree::new(); + // { + // let mut search_timer = timer_tree.open("search"); + // for (segment_ord, segment) in self.segments.iter().enumerate() { + // let mut segment_search_timer = search_timer.open("segment_search"); + // { + // let _ = segment_search_timer.open("set_segment"); + // try!(collector.set_segment(segment_ord as SegmentLocalId, &segment)); + // } + // let mut postings = segment.search(terms, segment_search_timer.open("get_postings")); + // { + // let _collection_timer = segment_search_timer.open("collection"); + // while postings.next() { + // collector.collect(postings.doc()); + // } + // } + // } + // } + // Ok(timer_tree) + // } } diff --git a/src/lib.rs b/src/lib.rs index bf1d67119..979ab5ff4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -35,6 +35,8 @@ mod compression; mod fastfield; mod store; mod common; +pub mod query; + pub mod analyzer; pub mod collector; @@ -59,6 +61,7 @@ mod tests { use super::*; use collector::TestCollector; + use query::MultiTermQuery; #[test] fn test_indexing() { @@ -123,8 +126,9 @@ mod tests { { let searcher = index.searcher().unwrap(); let get_doc_ids = |terms: Vec| { + let query = MultiTermQuery::new(terms); let mut collector = TestCollector::new(); - assert!(searcher.search(&terms, &mut collector).is_ok()); + assert!(searcher.search(&query, &mut collector).is_ok()); collector.docs() }; { @@ -159,4 +163,33 @@ mod tests { } } } + + #[test] + fn test_searcher_2() { + let mut schema = schema::Schema::new(); + let text_field = schema.add_text_field("text", &schema::TEXT); + let index = Index::create_in_ram(schema); + + { + // writing the segment + let mut index_writer = index.writer_with_num_threads(1).unwrap(); + { + let mut doc = Document::new(); + doc.set(&text_field, "af b"); + index_writer.add_document(doc).unwrap(); + } + { + let mut doc = Document::new(); + doc.set(&text_field, "a b c"); + index_writer.add_document(doc).unwrap(); + } + { + let mut doc = Document::new(); + doc.set(&text_field, "a b c d"); + index_writer.add_document(doc).unwrap(); + } + index_writer.wait().unwrap(); + } + index.searcher().unwrap(); + } } diff --git a/src/postings/mod.rs b/src/postings/mod.rs index ac0f081e3..59af677e7 100644 --- a/src/postings/mod.rs +++ b/src/postings/mod.rs @@ -51,7 +51,7 @@ mod tests { let read = segment.open_read(SegmentComponent::POSITIONS).unwrap(); assert_eq!(read.len(), 12); } - + } diff --git a/src/query/mod.rs b/src/query/mod.rs new file mode 100644 index 000000000..5ce4e6c03 --- /dev/null +++ b/src/query/mod.rs @@ -0,0 +1,5 @@ +mod query; +mod multi_term_query; + +pub use self::query::Query; +pub use self::multi_term_query::MultiTermQuery; \ No newline at end of file diff --git a/src/query/multi_term_query.rs b/src/query/multi_term_query.rs new file mode 100644 index 000000000..36a673755 --- /dev/null +++ b/src/query/multi_term_query.rs @@ -0,0 +1,88 @@ +use schema::Term; +use query::Query; +use common::TimerTree; +use common::OpenTimer; +use std::io; +use core::searcher::Searcher; +use collector::Collector; +use core::searcher::SegmentLocalId; +use core::reader::SegmentReader; +use postings::Postings; +use postings::SegmentPostings; +use postings::intersection; + +pub struct MultiTermQuery { + terms: Vec, +} + +impl Query for MultiTermQuery { + + + fn search(&self, searcher: &Searcher, collector: &mut C) -> io::Result { + let mut timer_tree = TimerTree::new(); + { + let mut search_timer = timer_tree.open("search"); + for (segment_ord, segment_reader) in searcher.segments().iter().enumerate() { + let mut segment_search_timer = search_timer.open("segment_search"); + { + let _ = segment_search_timer.open("set_segment"); + try!(collector.set_segment(segment_ord as SegmentLocalId, &segment_reader)); + } + let mut postings = self.search_segment(segment_reader, segment_search_timer.open("get_postings")); + { + let _collection_timer = segment_search_timer.open("collection"); + while postings.next() { + collector.collect(postings.doc()); + } + } + } + } + Ok(timer_tree) + } +} + +impl MultiTermQuery { + pub fn new(terms: Vec) -> MultiTermQuery { + MultiTermQuery { + terms: terms, + } + } + + + fn search_segment<'a, 'b>(&'b self, reader: &'b SegmentReader, mut timer: OpenTimer<'a>) -> Box { + if self.terms.len() == 1 { + match reader.get_term(&self.terms[0]) { + Some(term_info) => { + let postings: SegmentPostings<'b> = reader.read_postings(&term_info); + Box::new(postings) + }, + None => { + Box::new(SegmentPostings::empty()) + }, + } + } else { + let mut segment_postings: Vec = Vec::new(); + { + let mut decode_timer = timer.open("decode_all"); + for term in self.terms.iter() { + match reader.get_term(term) { + Some(term_info) => { + let _decode_one_timer = decode_timer.open("decode_one"); + let segment_posting = reader.read_postings(&term_info); + segment_postings.push(segment_posting); + } + None => { + // currently this is a strict intersection. + return Box::new(SegmentPostings::empty()); + } + } + } + } + Box::new(intersection(segment_postings)) + } + } +} + + + + diff --git a/src/query/query.rs b/src/query/query.rs new file mode 100644 index 000000000..c6803c727 --- /dev/null +++ b/src/query/query.rs @@ -0,0 +1,8 @@ +use std::io; +use collector::Collector; +use core::searcher::Searcher; +use common::TimerTree; + +pub trait Query { + fn search(&self, searcher: &Searcher, collector: &mut C) -> io::Result; +}