diff --git a/examples/simple_search.rs b/examples/simple_search.rs index 42b0600c8..2ca533f1d 100644 --- a/examples/simple_search.rs +++ b/examples/simple_search.rs @@ -8,7 +8,6 @@ use tantivy::Index; use tantivy::schema::*; use tantivy::collector::TopCollector; use tantivy::query::QueryParser; -use tantivy::query::Query; fn main() { // Let's create a temporary directory for the diff --git a/src/collector/chained_collector.rs b/src/collector/chained_collector.rs index e52164f6a..5840eb775 100644 --- a/src/collector/chained_collector.rs +++ b/src/collector/chained_collector.rs @@ -2,7 +2,8 @@ use collector::Collector; use SegmentLocalId; use SegmentReader; use std::io; -use ScoredDoc; +use DocId; +use Score; /// Collector that does nothing. @@ -15,7 +16,7 @@ impl Collector for DoNothingCollector { Ok(()) } #[inline] - fn collect(&mut self, _: ScoredDoc) {} + fn collect(&mut self, _doc: DocId, _score: Score) {} } /// Zero-cost abstraction used to collect on multiple collectors. @@ -43,9 +44,9 @@ impl Collector for ChainedCollector ChainedCollector { mod tests { use super::*; - use ScoredDoc; use collector::{Collector, CountCollector, TopCollector}; #[test] @@ -73,9 +73,9 @@ mod tests { let mut collectors = chain() .push(&mut top_collector) .push(&mut count_collector); - collectors.collect(ScoredDoc(0.2, 1)); - collectors.collect(ScoredDoc(0.1, 2)); - collectors.collect(ScoredDoc(0.5, 3)); + collectors.collect(1, 0.2); + collectors.collect(2, 0.1); + collectors.collect(3, 0.5); } assert_eq!(count_collector.count(), 3); assert!(top_collector.at_capacity()); diff --git a/src/collector/count_collector.rs b/src/collector/count_collector.rs index 44d547ec3..8a9014a25 100644 --- a/src/collector/count_collector.rs +++ b/src/collector/count_collector.rs @@ -1,6 +1,7 @@ use std::io; use super::Collector; -use ScoredDoc; +use DocId; +use Score; use SegmentReader; use SegmentLocalId; @@ -31,7 +32,7 @@ impl Collector for CountCollector { Ok(()) } - fn collect(&mut self, _: 
ScoredDoc) { + fn collect(&mut self, _: DocId, _: Score) { self.count += 1; } } @@ -41,16 +42,14 @@ mod tests { use super::*; use test::Bencher; - use ScoredDoc; use collector::Collector; #[bench] fn build_collector(b: &mut Bencher) { b.iter(|| { let mut count_collector = CountCollector::default(); - let docs: Vec = (0..1_000_000).collect(); - for doc in docs { - count_collector.collect(ScoredDoc(1f32, doc)); + for doc in 0..1_000_000 { + count_collector.collect(doc, 1f32); } count_collector.count() }); diff --git a/src/collector/mod.rs b/src/collector/mod.rs index 683b7eb1c..84bc38485 100644 --- a/src/collector/mod.rs +++ b/src/collector/mod.rs @@ -1,6 +1,7 @@ use SegmentReader; use SegmentLocalId; -use ScoredDoc; +use DocId; +use Score; use std::io; mod count_collector; @@ -49,7 +50,7 @@ pub trait Collector { /// on this segment. fn set_segment(&mut self, segment_local_id: SegmentLocalId, segment: &SegmentReader) -> io::Result<()>; /// The query pushes the scored document to the collector via this method. - fn collect(&mut self, scored_doc: ScoredDoc); + fn collect(&mut self, doc: DocId, score: Score); } @@ -58,8 +59,8 @@ impl<'a, C: Collector> Collector for &'a mut C { (*self).set_segment(segment_local_id, segment) } /// The query pushes the scored document to the collector via this method. 
- fn collect(&mut self, scored_doc: ScoredDoc) { - (*self).collect(scored_doc); + fn collect(&mut self, doc: DocId, score: Score) { + (*self).collect(doc, score); } } @@ -69,8 +70,8 @@ pub mod tests { use super::*; use test::Bencher; - use ScoredDoc; use DocId; + use Score; use core::SegmentReader; use std::io; use SegmentLocalId; @@ -112,8 +113,8 @@ pub mod tests { Ok(()) } - fn collect(&mut self, scored_doc: ScoredDoc) { - self.docs.push(scored_doc.doc() + self.offset); + fn collect(&mut self, doc: DocId, _score: Score) { + self.docs.push(doc + self.offset); } } @@ -150,8 +151,8 @@ pub mod tests { Ok(()) } - fn collect(&mut self, scored_doc: ScoredDoc) { - let val = self.ff_reader.as_ref().unwrap().get(scored_doc.doc()); + fn collect(&mut self, doc: DocId, _score: Score) { + let val = self.ff_reader.as_ref().unwrap().get(doc); self.vals.push(val); } } @@ -163,7 +164,7 @@ pub mod tests { let mut count_collector = CountCollector::default(); let docs: Vec = (0..1_000_000).collect(); for doc in docs { - count_collector.collect(ScoredDoc(1f32, doc)); + count_collector.collect(doc, 1f32); } count_collector.count() }); diff --git a/src/collector/multi_collector.rs b/src/collector/multi_collector.rs index 92958018d..6ce999e80 100644 --- a/src/collector/multi_collector.rs +++ b/src/collector/multi_collector.rs @@ -1,6 +1,7 @@ use std::io; use super::Collector; -use ScoredDoc; +use DocId; +use Score; use SegmentReader; use SegmentLocalId; @@ -31,9 +32,9 @@ impl<'a> Collector for MultiCollector<'a> { Ok(()) } - fn collect(&mut self, scored_doc: ScoredDoc) { + fn collect(&mut self, doc: DocId, score: Score) { for collector in &mut self.collectors { - collector.collect(scored_doc); + collector.collect(doc, score); } } } @@ -44,7 +45,6 @@ impl<'a> Collector for MultiCollector<'a> { mod tests { use super::*; - use ScoredDoc; use collector::{Collector, CountCollector, TopCollector}; #[test] @@ -53,9 +53,9 @@ mod tests { let mut count_collector = CountCollector::default(); { let 
mut collectors = MultiCollector::from(vec!(&mut top_collector, &mut count_collector)); - collectors.collect(ScoredDoc(0.2, 1)); - collectors.collect(ScoredDoc(0.1, 2)); - collectors.collect(ScoredDoc(0.5, 3)); + collectors.collect(1, 0.2); + collectors.collect(2, 0.1); + collectors.collect(3, 0.5); } assert_eq!(count_collector.count(), 3); assert!(top_collector.at_capacity()); diff --git a/src/collector/top_collector.rs b/src/collector/top_collector.rs index e7fd0d018..21c023caf 100644 --- a/src/collector/top_collector.rs +++ b/src/collector/top_collector.rs @@ -1,11 +1,11 @@ use std::io; use super::Collector; -use ScoredDoc; use SegmentReader; use SegmentLocalId; use DocAddress; use std::collections::BinaryHeap; use std::cmp::Ordering; +use DocId; use Score; // Rust heap is a max-heap and we need a min heap. @@ -13,6 +13,7 @@ use Score; struct GlobalScoredDoc { score: Score, doc_address: DocAddress + } impl PartialOrd for GlobalScoredDoc { @@ -109,20 +110,20 @@ impl Collector for TopCollector { Ok(()) } - fn collect(&mut self, scored_doc: ScoredDoc) { + fn collect(&mut self, doc: DocId, score: Score) { if self.at_capacity() { // It's ok to unwrap as long as a limit of 0 is forbidden. 
let limit_doc: GlobalScoredDoc = *self.heap.peek().expect("Top collector with size 0 is forbidden"); - if limit_doc.score < scored_doc.score() { + if limit_doc.score < score { let mut mut_head = self.heap.peek_mut().expect("Top collector with size 0 is forbidden"); - mut_head.score = scored_doc.score(); - mut_head.doc_address = DocAddress(self.segment_id, scored_doc.doc()); + mut_head.score = score; + mut_head.doc_address = DocAddress(self.segment_id, doc); } } else { let wrapped_doc = GlobalScoredDoc { - score: scored_doc.score(), - doc_address: DocAddress(self.segment_id, scored_doc.doc()) + score: score, + doc_address: DocAddress(self.segment_id, doc) }; self.heap.push(wrapped_doc); } @@ -135,7 +136,6 @@ impl Collector for TopCollector { mod tests { use super::*; - use ScoredDoc; use DocId; use Score; use collector::Collector; @@ -143,9 +143,9 @@ mod tests { #[test] fn test_top_collector_not_at_capacity() { let mut top_collector = TopCollector::with_limit(4); - top_collector.collect(ScoredDoc(0.8, 1)); - top_collector.collect(ScoredDoc(0.2, 3)); - top_collector.collect(ScoredDoc(0.3, 5)); + top_collector.collect(1, 0.8); + top_collector.collect(3, 0.2); + top_collector.collect(5, 0.3); assert!(!top_collector.at_capacity()); let score_docs: Vec<(Score, DocId)> = top_collector.score_docs() .into_iter() @@ -159,11 +159,11 @@ mod tests { #[test] fn test_top_collector_at_capacity() { let mut top_collector = TopCollector::with_limit(4); - top_collector.collect(ScoredDoc(0.8, 1)); - top_collector.collect(ScoredDoc(0.2, 3)); - top_collector.collect(ScoredDoc(0.3, 5)); - top_collector.collect(ScoredDoc(0.9, 7)); - top_collector.collect(ScoredDoc(-0.2, 9)); + top_collector.collect(1, 0.8); + top_collector.collect(3, 0.2); + top_collector.collect(5, 0.3); + top_collector.collect(7, 0.9); + top_collector.collect(9, -0.2); assert!(top_collector.at_capacity()); { let score_docs: Vec<(Score, DocId)> = top_collector diff --git a/src/common/mod.rs b/src/common/mod.rs index 
77ff67edc..bf38fee07 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -9,7 +9,6 @@ pub use self::timer::OpenTimer; pub use self::vint::VInt; use std::io; - pub fn make_io_err(msg: String) -> io::Error { io::Error::new(io::ErrorKind::Other, msg) } @@ -25,3 +24,4 @@ pub trait HasLen { self.len() == 0 } } + diff --git a/src/fastfield/mod.rs b/src/fastfield/mod.rs index d1e1c8b6d..39d68ebfb 100644 --- a/src/fastfield/mod.rs +++ b/src/fastfield/mod.rs @@ -39,9 +39,7 @@ fn compute_num_bits(amplitude: u32) -> u8 { mod tests { use super::compute_num_bits; - use super::U32FastFieldsReader; - use super::U32FastFieldsWriter; - use super::FastFieldSerializer; + use super::*; use schema::Field; use std::path::Path; use directory::{Directory, WritePtr, RAMDirectory}; @@ -81,6 +79,17 @@ mod tests { doc.add_u32(field, value); fast_field_writers.add_document(&doc); } + + #[test] + pub fn test_fastfield() { + let test_fastfield = U32FastFieldReader::from(vec!(100,200,300)); + println!("{}", test_fastfield.get(0)); + println!("{}", test_fastfield.get(1)); + println!("{}", test_fastfield.get(2)); + assert_eq!(test_fastfield.get(0), 100); + assert_eq!(test_fastfield.get(1), 200); + assert_eq!(test_fastfield.get(2), 300); + } #[test] fn test_intfastfield_small() { diff --git a/src/fastfield/reader.rs b/src/fastfield/reader.rs index f57aa9362..335799565 100644 --- a/src/fastfield/reader.rs +++ b/src/fastfield/reader.rs @@ -5,10 +5,22 @@ use std::ops::Deref; use directory::ReadOnlySource; use common::BinarySerializable; use DocId; -use schema::Field; - +use schema::{Field, SchemaBuilder}; +use std::path::Path; +use schema::FAST; +use directory::{WritePtr, RAMDirectory, Directory}; +use fastfield::FastFieldSerializer; +use fastfield::U32FastFieldsWriter; use super::compute_num_bits; + +lazy_static! 
{ + static ref U32_FAST_FIELD_EMPTY: ReadOnlySource = { + let u32_fast_field = U32FastFieldReader::from(Vec::new()); + u32_fast_field._data.clone() + }; +} + pub struct U32FastFieldReader { _data: ReadOnlySource, data_ptr: *const u8, @@ -20,6 +32,10 @@ pub struct U32FastFieldReader { impl U32FastFieldReader { + pub fn empty() -> U32FastFieldReader { + U32FastFieldReader::open(U32_FAST_FIELD_EMPTY.clone()).expect("should always work.") + } + pub fn min_val(&self,) -> u32 { self.min_val } @@ -62,6 +78,31 @@ impl U32FastFieldReader { } } + +impl From> for U32FastFieldReader { + fn from(vals: Vec) -> U32FastFieldReader { + let mut schema_builder = SchemaBuilder::default(); + let field = schema_builder.add_u32_field("field", FAST); + let schema = schema_builder.build(); + let path = Path::new("test"); + let mut directory: RAMDirectory = RAMDirectory::create(); + { + let write: WritePtr = directory.open_write(Path::new("test")).unwrap(); + let mut serializer = FastFieldSerializer::new(write).unwrap(); + let mut fast_field_writers = U32FastFieldsWriter::from_schema(&schema); + for val in vals { + let mut fast_field_writer = fast_field_writers.get_field_writer(field).unwrap(); + fast_field_writer.add_val(val); + } + fast_field_writers.serialize(&mut serializer).unwrap(); + serializer.close().unwrap(); + } + let source = directory.open_read(&path).unwrap(); + let fast_field_readers = U32FastFieldsReader::open(source).unwrap(); + fast_field_readers.get_field(field).unwrap() + } +} + pub struct U32FastFieldsReader { source: ReadOnlySource, field_offsets: HashMap, diff --git a/src/lib.rs b/src/lib.rs index d30bc6976..4a33b0213 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,7 @@ #![allow(unknown_lints)] // for the clippy lint options #![allow(module_inception)] +#![feature(box_syntax)] #![feature(optin_builtin_traits)] #![feature(conservative_impl_trait)] #![cfg_attr(test, feature(test))] @@ -58,11 +59,16 @@ mod macros { macro_rules! 
doc( ($($field:ident => $value:expr),*) => {{ - let mut document = Document::default(); - $( - document.add(FieldValue::new($field, $value.into())); - )* - document + #[allow(unused_mut)] // avoid emitting a warning for `doc!()` + { + let mut document = Document::default(); + $( + document.add(FieldValue::new($field, $value.into())); + )* + document + } + + }}; ); } @@ -138,22 +144,6 @@ impl DocAddress { } } -/// A scored doc is simply a couple `(score, doc_id)` -#[derive(Clone, Copy)] -pub struct ScoredDoc(Score, DocId); - -impl ScoredDoc { - - /// Returns the score - pub fn score(&self,) -> Score { - self.0 - } - - /// Returns the doc - pub fn doc(&self,) -> DocId { - self.1 - } -} /// `DocAddress` contains all the necessary information /// to identify a document given a `Searcher` object. @@ -188,18 +178,15 @@ mod tests { // writing the segment let mut index_writer = index.writer_with_num_threads(1, 40_000_000).unwrap(); { - let mut doc = Document::default(); - doc.add_text(text_field, "af b"); + let doc = doc!(text_field=>"af b"); index_writer.add_document(doc).unwrap(); } { - let mut doc = Document::default(); - doc.add_text(text_field, "a b c"); + let doc = doc!(text_field=>"a b c"); index_writer.add_document(doc).unwrap(); } { - let mut doc = Document::default(); - doc.add_text(text_field, "a b c d"); + let doc = doc!(text_field=>"a b c d"); index_writer.add_document(doc).unwrap(); } assert!(index_writer.commit().is_ok()); @@ -214,27 +201,22 @@ mod tests { let index = Index::create_in_ram(schema_builder.build()); let mut index_writer = index.writer_with_num_threads(1, 40_000_000).unwrap(); { - let mut doc = Document::default(); - doc.add_text(text_field, "a b c"); - index_writer.add_document(doc).unwrap(); + index_writer.add_document(doc!(text_field=>"a b c")).unwrap(); index_writer.commit().unwrap(); } { { - let mut doc = Document::default(); - doc.add_text(text_field, "a"); + let doc = doc!(text_field=>"a"); index_writer.add_document(doc).unwrap(); } { - 
let mut doc = Document::default(); - doc.add_text(text_field, "a a"); + let doc = doc!(text_field=>"a a"); index_writer.add_document(doc).unwrap(); } index_writer.commit().unwrap(); } { - let mut doc = Document::default(); - doc.add_text(text_field, "c"); + let doc = doc!(text_field=>"c"); index_writer.add_document(doc).unwrap(); index_writer.commit().unwrap(); } @@ -260,17 +242,15 @@ mod tests { { let mut index_writer = index.writer_with_num_threads(1, 40_000_000).unwrap(); { - let mut doc = Document::default(); - doc.add_text(text_field, "a b c"); + let doc = doc!(text_field=>"a b c"); index_writer.add_document(doc).unwrap(); } { - let doc = Document::default(); + let doc = doc!(); index_writer.add_document(doc).unwrap(); } { - let mut doc = Document::default(); - doc.add_text(text_field, "a b"); + let doc = doc!(text_field=>"a b"); index_writer.add_document(doc).unwrap(); } index_writer.commit().unwrap(); @@ -296,8 +276,7 @@ mod tests { // writing the segment let mut index_writer = index.writer_with_num_threads(1, 40_000_000).unwrap(); { - let mut doc = Document::default(); - doc.add_text(text_field, "af af af bc bc"); + let doc = doc!(text_field=>"af af af bc bc"); index_writer.add_document(doc).unwrap(); } index_writer.commit().unwrap(); @@ -325,18 +304,15 @@ mod tests { // writing the segment let mut index_writer = index.writer_with_num_threads(1, 40_000_000).unwrap(); { - let mut doc = Document::default(); - doc.add_text(text_field, "af af af b"); + let doc = doc!(text_field=>"af af af b"); index_writer.add_document(doc).unwrap(); } { - let mut doc = Document::default(); - doc.add_text(text_field, "a b c"); + let doc = doc!(text_field=>"a b c"); index_writer.add_document(doc).unwrap(); } { - let mut doc = Document::default(); - doc.add_text(text_field, "a b c d"); + let doc = doc!(text_field=>"a b c d"); index_writer.add_document(doc).unwrap(); } index_writer.commit().unwrap(); @@ -394,18 +370,15 @@ mod tests { // writing the segment let mut index_writer = 
index.writer_with_num_threads(1, 40_000_000).unwrap(); { - let mut doc = Document::default(); - doc.add_text(text_field, "af b"); + let doc = doc!(text_field=>"af b"); index_writer.add_document(doc).unwrap(); } { - let mut doc = Document::default(); - doc.add_text(text_field, "a b c"); + let doc = doc!(text_field=>"a b c"); index_writer.add_document(doc).unwrap(); } { - let mut doc = Document::default(); - doc.add_text(text_field, "a b c d"); + let doc = doc!(text_field=>"a b c d"); index_writer.add_document(doc).unwrap(); } index_writer.commit().unwrap(); diff --git a/src/postings/docset.rs b/src/postings/docset.rs index db40db619..24ae1ffe6 100644 --- a/src/postings/docset.rs +++ b/src/postings/docset.rs @@ -17,53 +17,56 @@ pub enum SkipResult { } -/// Represents an iterable set of sorted doc ids. +/// Represents an iterable set of sorted doc ids. pub trait DocSet { /// Goes to the next element. /// `.advance(...)` needs to be called a first time to point to the correct /// element. - fn advance(&mut self,) -> bool; - + fn advance(&mut self) -> bool; + /// After skipping, position the iterator in such a way that `.doc()` /// will return a value greater than or equal to target. - /// + /// /// SkipResult expresses whether the `target value` was reached, overstepped, /// or if the `DocSet` was entirely consumed without finding any value - /// greater or equal to the `target`. + /// greater or equal to the `target`. + /// + /// WARNING: Calling skip always advances the docset. + /// More specifically, if the docset is already positionned on the target + /// skipping will advance to the next position and return SkipResult::Overstep. 
+ /// fn skip_next(&mut self, target: DocId) -> SkipResult { + self.advance(); loop { match self.doc().cmp(&target) { Ordering::Less => { if !self.advance() { return SkipResult::End; } - }, - Ordering::Equal => { return SkipResult::Reached }, - Ordering::Greater => { return SkipResult::OverStep }, + } + Ordering::Equal => return SkipResult::Reached, + Ordering::Greater => return SkipResult::OverStep, } } } - + /// Returns the current document - fn doc(&self,) -> DocId; - + fn doc(&self) -> DocId; + /// Advances the cursor to the next document - /// None is returned if the iterator has `DocSet` - /// has already been entirely consumed. - fn next(&mut self,) -> Option { + /// None is returned if the iterator has `DocSet` + /// has already been entirely consumed. + fn next(&mut self) -> Option { if self.advance() { Some(self.doc()) - } - else { + } else { None } - } + } } - -impl DocSet for Box { - - fn advance(&mut self,) -> bool { +impl DocSet for Box { + fn advance(&mut self) -> bool { let unboxed: &mut TDocSet = self.borrow_mut(); unboxed.advance() } @@ -73,28 +76,25 @@ impl DocSet for Box { unboxed.skip_next(target) } - fn doc(&self,) -> DocId { + fn doc(&self) -> DocId { let unboxed: &TDocSet = self.borrow(); unboxed.doc() } } impl<'a, TDocSet: DocSet> DocSet for &'a mut TDocSet { - - fn advance(&mut self,) -> bool { + fn advance(&mut self) -> bool { let unref: &mut TDocSet = *self; unref.advance() } - + fn skip_next(&mut self, target: DocId) -> SkipResult { let unref: &mut TDocSet = *self; unref.skip_next(target) } - fn doc(&self,) -> DocId { + fn doc(&self) -> DocId { let unref: &TDocSet = *self; unref.doc() } } - - diff --git a/src/postings/freq_handler.rs b/src/postings/freq_handler.rs index ea9cb6ae6..70808798a 100644 --- a/src/postings/freq_handler.rs +++ b/src/postings/freq_handler.rs @@ -17,7 +17,7 @@ pub struct FreqHandler { fn read_positions(data: &[u8]) -> Vec { - let mut composite_reader = CompositeDecoder::new(); + let mut composite_reader = 
CompositeDecoder::new(); let mut readable: &[u8] = data; let uncompressed_len = VInt::deserialize(&mut readable).unwrap().0 as usize; composite_reader.uncompress_unsorted(readable, uncompressed_len); @@ -27,17 +27,16 @@ fn read_positions(data: &[u8]) -> Vec { impl FreqHandler { - /// Returns a `FreqHandler` that just decodes `DocId`s. pub fn new_without_freq() -> FreqHandler { FreqHandler { freq_decoder: SIMDBlockDecoder::with_val(1u32), - positions: Vec::new(), + positions: Vec::new(), option: SegmentPostingsOption::NoFreq, positions_offsets: [0; NUM_DOCS_PER_BLOCK + 1], } } - + /// Returns a `FreqHandler` that decodes `DocId`s and term frequencies. pub fn new_with_freq() -> FreqHandler { FreqHandler { @@ -54,15 +53,15 @@ impl FreqHandler { let positions = read_positions(position_data); FreqHandler { freq_decoder: SIMDBlockDecoder::new(), - positions: positions, + positions: positions, option: SegmentPostingsOption::FreqAndPositions, positions_offsets: [0; NUM_DOCS_PER_BLOCK + 1], } } - - fn fill_positions_offset(&mut self,) { + + fn fill_positions_offset(&mut self) { let mut cur_position: usize = self.positions_offsets[NUM_DOCS_PER_BLOCK]; - let mut i: usize = 0; + let mut i: usize = 0; self.positions_offsets[i] = cur_position; let mut last_cur_position = cur_position; for &doc_freq in self.freq_decoder.output_array() { @@ -78,16 +77,16 @@ impl FreqHandler { last_cur_position = cur_position; } } - - + + /// Accessor to term frequency /// /// idx is the offset of the current doc in the block. /// It takes value between 0 and 128. - pub fn freq(&self, idx: usize)-> u32 { + pub fn freq(&self, idx: usize) -> u32 { self.freq_decoder.output(idx) } - + /// Accessor to the positions /// /// idx is the offset of the current doc in the block. 
@@ -95,18 +94,14 @@ impl FreqHandler { pub fn positions(&self, idx: usize) -> &[u32] { let start = self.positions_offsets[idx]; let stop = self.positions_offsets[idx + 1]; - &self.positions[start..stop] + &self.positions[start..stop] } - + /// Decompresses a complete frequency block pub fn read_freq_block<'a>(&mut self, data: &'a [u8]) -> &'a [u8] { match self.option { - SegmentPostingsOption::NoFreq => { - data - } - SegmentPostingsOption::Freq => { - self.freq_decoder.uncompress_block_unsorted(data) - } + SegmentPostingsOption::NoFreq => data, + SegmentPostingsOption::Freq => self.freq_decoder.uncompress_block_unsorted(data), SegmentPostingsOption::FreqAndPositions => { let remaining: &'a [u8] = self.freq_decoder.uncompress_block_unsorted(data); self.fill_positions_offset(); @@ -114,7 +109,7 @@ impl FreqHandler { } } } - + /// Decompresses an incomplete frequency block pub fn read_freq_vint(&mut self, data: &[u8], num_els: usize) { match self.option { @@ -128,5 +123,4 @@ impl FreqHandler { } } } - } \ No newline at end of file diff --git a/src/postings/intersection.rs b/src/postings/intersection.rs index 68cb1ec3b..75699065c 100644 --- a/src/postings/intersection.rs +++ b/src/postings/intersection.rs @@ -1,90 +1,80 @@ use postings::DocSet; -use std::cmp::Ordering; +use postings::SkipResult; use DocId; // TODO Find a way to specialize `IntersectionDocSet` /// Creates a `DocSet` that iterator through the intersection of two `DocSet`s. 
-pub struct IntersectionDocSet<'a> { - left: Box, - right: Box, - finished: bool, +pub struct IntersectionDocSet { + docsets: Vec, + finished: bool, + doc: DocId, } -impl<'a> IntersectionDocSet<'a> { - - /// Intersect two `DocSet`s - fn from_pair(left: Box, right: Box) -> IntersectionDocSet<'a> { +impl From> for IntersectionDocSet { + fn from(docsets: Vec) -> IntersectionDocSet { + assert!(docsets.len() >= 2); IntersectionDocSet { - left: left, - right: right, + docsets: docsets, finished: false, - } + doc: DocId::max_value(), + } } - - /// Intersect a list of `DocSet`s - pub fn new(mut postings: Vec>) -> IntersectionDocSet<'a> { - let left = postings.pop().unwrap(); - let right = - if postings.len() == 1 { - postings.pop().unwrap() - } - else { - Box::new(IntersectionDocSet::new(postings)) - }; - IntersectionDocSet::from_pair(left, right) +} + +impl IntersectionDocSet { + /// Returns an array to the underlying `DocSet`s of the intersection. + /// These `DocSet` are in the same position as the `IntersectionDocSet`, + /// so that user can access their `docfreq` and `positions`. + pub fn docsets(&self) -> &[TDocSet] { + &self.docsets[..] 
} } -impl<'a> DocSet for IntersectionDocSet<'a> { - - fn advance(&mut self,) -> bool { +impl DocSet for IntersectionDocSet { + fn advance(&mut self) -> bool { if self.finished { return false; } - - if !self.left.advance() { - self.finished = true; - return false; - } - if !self.right.advance() { - self.finished = true; - return false; - } + let num_docsets = self.docsets.len(); + let mut count_matching = 1; + let mut doc_candidate = { + let mut first_docset = &mut self.docsets[0]; + if !first_docset.advance() { + self.finished = true; + return false; + } + first_docset.doc() + }; + let mut ord = 1; loop { - match self.left.doc().cmp(&self.right.doc()) { - Ordering::Equal => { - return true; - } - Ordering::Less => { - if !self.left.advance() { - self.finished = true; - return false; + let mut doc_set = &mut self.docsets[ord]; + match doc_set.skip_next(doc_candidate) { + SkipResult::Reached => { + count_matching += 1; + if count_matching == num_docsets { + self.doc = doc_candidate; + return true; } } - Ordering::Greater => { - if !self.right.advance() { - self.finished = true; - return false; - } + SkipResult::End => { + self.finished = true; + return false; } + SkipResult::OverStep => { + count_matching = 1; + doc_candidate = doc_set.doc(); + } + } + ord += 1; + if ord == num_docsets { + ord = 0; } } } - - fn doc(&self,) -> DocId { - self.left.doc() + + fn doc(&self) -> DocId { + self.doc } } - -/// Intersects a `Vec` of `DocSets` -pub fn intersection<'a, TDocSet: DocSet + 'a>(postings: Vec) -> IntersectionDocSet<'a> { - let boxed_postings: Vec> = postings - .into_iter() - .map(|postings: TDocSet| { - Box::new(postings) as Box - }) - .collect(); - IntersectionDocSet::new(boxed_postings) -} diff --git a/src/postings/mod.rs b/src/postings/mod.rs index 374f08c33..e6c07837a 100644 --- a/src/postings/mod.rs +++ b/src/postings/mod.rs @@ -29,12 +29,11 @@ pub use self::postings::Postings; #[cfg(test)] pub use self::vec_postings::VecPostings; - pub use 
self::chained_postings::ChainedPostings; pub use self::segment_postings::SegmentPostings; -pub use self::intersection::intersection; pub use self::intersection::IntersectionDocSet; pub use self::freq_handler::FreqHandler; + pub use self::segment_postings_option::SegmentPostingsOption; pub use common::HasLen; @@ -49,6 +48,7 @@ mod tests { use core::Index; use std::iter; use datastruct::stacker::Heap; + use query::TermQuery; #[test] @@ -72,7 +72,7 @@ mod tests { } #[test] - pub fn test_position_and_fieldnorm_write_fullstack() { + pub fn test_position_and_fieldnorm() { let mut schema_builder = SchemaBuilder::default(); let text_field = schema_builder.add_text_field("text", TEXT); let schema = schema_builder.build(); @@ -153,12 +153,43 @@ mod tests { } } + #[test] + pub fn test_position_and_fieldnorm2() { + let mut schema_builder = SchemaBuilder::default(); + let text_field = schema_builder.add_text_field("text", TEXT); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + { + let mut index_writer = index.writer_with_num_threads(1, 40_000_000).unwrap(); + { + let mut doc = Document::default(); + doc.add_text(text_field, "g b b d c g c"); + index_writer.add_document(doc).unwrap(); + } + { + let mut doc = Document::default(); + doc.add_text(text_field, "g a b b a d c g c"); + index_writer.add_document(doc).unwrap(); + } + assert!(index_writer.commit().is_ok()); + } + let term_query = TermQuery::from(Term::from_field_text(text_field, "a")); + let searcher = index.searcher(); + let mut term_weight = term_query.specialized_weight(&*searcher); + term_weight.segment_postings_options = SegmentPostingsOption::FreqAndPositions; + let segment_reader = &searcher.segment_readers()[0]; + let mut term_scorer = term_weight.specialized_scorer(segment_reader).unwrap(); + assert!(term_scorer.advance()); + assert_eq!(term_scorer.doc(), 1u32); + assert_eq!(term_scorer.postings().positions(), &[1u32, 4]); + } + #[test] fn test_intersection() { { - let left = 
Box::new(VecPostings::from(vec!(1, 3, 9))); - let right = Box::new(VecPostings::from(vec!(3, 4, 9, 18))); - let mut intersection = IntersectionDocSet::new(vec!(left, right)); + let left = VecPostings::from(vec!(1, 3, 9)); + let right = VecPostings::from(vec!(3, 4, 9, 18)); + let mut intersection = IntersectionDocSet::from(vec!(left, right)); assert!(intersection.advance()); assert_eq!(intersection.doc(), 3); assert!(intersection.advance()); @@ -166,10 +197,10 @@ mod tests { assert!(!intersection.advance()); } { - let a = Box::new(VecPostings::from(vec!(1, 3, 9))); - let b = Box::new(VecPostings::from(vec!(3, 4, 9, 18))); - let c = Box::new(VecPostings::from(vec!(1, 5, 9, 111))); - let mut intersection = IntersectionDocSet::new(vec!(a, b, c)); + let a = VecPostings::from(vec!(1, 3, 9)); + let b = VecPostings::from(vec!(3, 4, 9, 18)); + let c = VecPostings::from(vec!(1, 5, 9, 111)); + let mut intersection = IntersectionDocSet::from(vec!(a, b, c)); assert!(intersection.advance()); assert_eq!(intersection.doc(), 9); assert!(!intersection.advance()); diff --git a/src/postings/offset_postings.rs b/src/postings/offset_postings.rs index fe7ea453d..1410ef922 100644 --- a/src/postings/offset_postings.rs +++ b/src/postings/offset_postings.rs @@ -15,7 +15,6 @@ pub struct OffsetPostings<'a> { } impl<'a> OffsetPostings<'a> { - /// Constructor pub fn new(underlying: SegmentPostings<'a>, offset: DocId) -> OffsetPostings { OffsetPostings { @@ -26,38 +25,35 @@ impl<'a> OffsetPostings<'a> { } impl<'a> DocSet for OffsetPostings<'a> { - fn advance(&mut self,) -> bool { + fn advance(&mut self) -> bool { self.underlying.advance() } - - fn doc(&self,) -> DocId { + + fn doc(&self) -> DocId { self.underlying.doc() + self.offset } - + fn skip_next(&mut self, target: DocId) -> SkipResult { if target >= self.offset { SkipResult::OverStep - } - else { - self.underlying.skip_next(target - self.offset) + } else { + self.underlying.skip_next(target - self.offset) } } } impl<'a> HasLen for 
OffsetPostings<'a> { - fn len(&self,) -> usize { + fn len(&self) -> usize { self.underlying.len() } } impl<'a> Postings for OffsetPostings<'a> { - - fn term_freq(&self,) -> u32 { + fn term_freq(&self) -> u32 { self.underlying.term_freq() } - + fn positions(&self) -> &[u32] { self.underlying.positions() } - } \ No newline at end of file diff --git a/src/postings/postings.rs b/src/postings/postings.rs index 8b964d0a9..52f16198a 100644 --- a/src/postings/postings.rs +++ b/src/postings/postings.rs @@ -1,8 +1,5 @@ use std::borrow::Borrow; use postings::docset::DocSet; -use common::HasLen; - - /// Postings (also called inverted list) /// @@ -10,58 +7,38 @@ use common::HasLen; /// containing the term. Optionally, for each document, /// it may also give access to the term frequency /// as well as the list of term positions. -/// +/// /// Its main implementation is `SegmentPostings`, /// but other implementations mocking `SegmentPostings` exist, /// for merging segments or for testing. pub trait Postings: DocSet { /// Returns the term frequency - fn term_freq(&self,) -> u32; + fn term_freq(&self) -> u32; /// Returns the list of positions of the term, expressed as a list of /// token ordinals. 
fn positions(&self) -> &[u32]; } impl Postings for Box { - - fn term_freq(&self,) -> u32 { + fn term_freq(&self) -> u32 { let unboxed: &TPostings = self.borrow(); unboxed.term_freq() } - + fn positions(&self) -> &[u32] { let unboxed: &TPostings = self.borrow(); unboxed.positions() } - } impl<'a, TPostings: Postings> Postings for &'a mut TPostings { - - fn term_freq(&self,) -> u32 { + fn term_freq(&self) -> u32 { let unref: &TPostings = *self; unref.term_freq() } - + fn positions(&self) -> &[u32] { let unref: &TPostings = *self; unref.positions() } - -} - - - -impl HasLen for Box { - fn len(&self,) -> usize { - let unboxed: &THasLen = self.borrow(); - unboxed.borrow().len() - } -} - -impl<'a> HasLen for &'a HasLen { - fn len(&self,) -> usize { - let unref: &HasLen = *self; - unref.len() - } } diff --git a/src/postings/postings_writer.rs b/src/postings/postings_writer.rs index 3b6ddd440..c3d1f997f 100644 --- a/src/postings/postings_writer.rs +++ b/src/postings/postings_writer.rs @@ -9,12 +9,11 @@ use schema::Field; use analyzer::StreamingIterator; use datastruct::stacker::{HashMap, Heap}; -/// The `PostingsWriter` is in charge of receiving documenting +/// The `PostingsWriter` is in charge of receiving documenting /// and building a `Segment` in anonymous memory. /// /// `PostingsWriter` writes in a `Heap`. pub trait PostingsWriter { - /// Record that a document contains a term at a given position. /// /// * doc - the document id @@ -22,17 +21,22 @@ pub trait PostingsWriter { /// * term - the term /// * heap - heap used to store the postings informations as well as the terms /// in the hashmap. - fn suscribe(&mut self, doc: DocId, pos: u32, term: &Term, heap: &Heap); - + fn suscribe(&mut self, doc: DocId, pos: u32, term: &Term, heap: &Heap); + /// Serializes the postings on disk. /// The actual serialization format is handled by the `PostingsSerializer`. 
fn serialize(&self, serializer: &mut PostingsSerializer, heap: &Heap) -> io::Result<()>; - + /// Closes all of the currently open `Recorder`'s. fn close(&mut self, heap: &Heap); - + /// Tokenize a text and suscribe all of its token. - fn index_text<'a>(&mut self, doc_id: DocId, field: Field, field_values: &[&'a FieldValue], heap: &Heap) -> u32 { + fn index_text<'a>(&mut self, + doc_id: DocId, + field: Field, + field_values: &[&'a FieldValue], + heap: &Heap) + -> u32 { let mut pos = 0u32; let mut num_tokens: u32 = 0u32; let mut term = Term::allocate(field, 100); @@ -65,7 +69,7 @@ fn hashmap_size_in_bits(heap_capacity: u32) -> usize { let num_buckets_usable = heap_capacity / 100; let hash_table_size = num_buckets_usable * 2; let mut pow = 512; - for num_bits in 10 .. 32 { + for num_bits in 10..32 { pow <<= 1; if pow > hash_table_size { return num_bits; @@ -75,31 +79,26 @@ fn hashmap_size_in_bits(heap_capacity: u32) -> usize { } impl<'a, Rec: Recorder + 'static> SpecializedPostingsWriter<'a, Rec> { - /// constructor pub fn new(heap: &'a Heap) -> SpecializedPostingsWriter<'a, Rec> { let capacity = heap.capacity(); let hashmap_size = hashmap_size_in_bits(capacity); - SpecializedPostingsWriter { - term_index: HashMap::new(hashmap_size, heap), - } + SpecializedPostingsWriter { term_index: HashMap::new(hashmap_size, heap) } } - + /// Builds a `SpecializedPostingsWriter` storing its data in a heap. 
pub fn new_boxed(heap: &'a Heap) -> Box { Box::new(SpecializedPostingsWriter::::new(heap)) - } - + } } impl<'a, Rec: Recorder + 'static> PostingsWriter for SpecializedPostingsWriter<'a, Rec> { - fn close(&mut self, heap: &Heap) { for recorder in self.term_index.values_mut() { recorder.close_doc(heap); } } - + #[inline] fn suscribe(&mut self, doc: DocId, position: u32, term: &Term, heap: &Heap) { let mut recorder = self.term_index.get_or_create(term); @@ -112,9 +111,9 @@ impl<'a, Rec: Recorder + 'static> PostingsWriter for SpecializedPostingsWriter<' } recorder.record_position(position, heap); } - + fn serialize(&self, serializer: &mut PostingsSerializer, heap: &Heap) -> io::Result<()> { - let mut term_offsets: Vec<(&[u8], (u32, &Rec))> = self.term_index + let mut term_offsets: Vec<(&[u8], (u32, &Rec))> = self.term_index .iter() .collect(); term_offsets.sort_by_key(|&(k, _v)| k); @@ -128,8 +127,6 @@ impl<'a, Rec: Recorder + 'static> PostingsWriter for SpecializedPostingsWriter<' } Ok(()) } - - } diff --git a/src/postings/recorder.rs b/src/postings/recorder.rs index 095102a3d..94173720b 100644 --- a/src/postings/recorder.rs +++ b/src/postings/recorder.rs @@ -4,32 +4,36 @@ use postings::PostingsSerializer; use datastruct::stacker::{ExpUnrolledLinkedList, Heap, HeapAllocable}; const EMPTY_ARRAY: [u32; 0] = [0u32; 0]; -const POSITION_END: u32 = 4294967295; +const POSITION_END: u32 = 4294967295; /// Recorder is in charge of recording relevant information about /// the presence of a term in a document. 
/// -/// Depending on the `TextIndexingOptions` associated to the +/// Depending on the `TextIndexingOptions` associated to the /// field, the recorder may records /// * the document frequency -/// * the document id +/// * the document id /// * the term frequency /// * the term positions pub trait Recorder: HeapAllocable { /// Returns the current document - fn current_doc(&self,) -> u32; + fn current_doc(&self) -> u32; /// Starts recording information about a new document - /// This method shall only be called if the term is within the document. + /// This method shall only be called if the term is within the document. fn new_doc(&mut self, doc: DocId, heap: &Heap); - /// Record the position of a term. For each document, + /// Record the position of a term. For each document, /// this method will be called `term_freq` times. fn record_position(&mut self, position: u32, heap: &Heap); - /// Close the document. It will help record the term frequency. + /// Close the document. It will help record the term frequency. fn close_doc(&mut self, heap: &Heap); /// Returns the number of document that have been seen so far - fn doc_freq(&self,) -> u32; + fn doc_freq(&self) -> u32; /// Pushes the postings information to the serializer. 
- fn serialize(&self, self_addr: u32, serializer: &mut PostingsSerializer, heap: &Heap) -> io::Result<()>; + fn serialize(&self, + self_addr: u32, + serializer: &mut PostingsSerializer, + heap: &Heap) + -> io::Result<()>; } /// Only records the doc ids @@ -51,11 +55,10 @@ impl HeapAllocable for NothingRecorder { } impl Recorder for NothingRecorder { - - fn current_doc(&self,) -> DocId { + fn current_doc(&self) -> DocId { self.current_doc } - + fn new_doc(&mut self, doc: DocId, heap: &Heap) { self.current_doc = doc; self.stack.push(doc, heap); @@ -66,17 +69,20 @@ impl Recorder for NothingRecorder { fn close_doc(&mut self, _heap: &Heap) {} - fn doc_freq(&self,) -> u32 { + fn doc_freq(&self) -> u32 { self.doc_freq } - - fn serialize(&self, self_addr: u32, serializer: &mut PostingsSerializer, heap: &Heap) -> io::Result<()> { + + fn serialize(&self, + self_addr: u32, + serializer: &mut PostingsSerializer, + heap: &Heap) + -> io::Result<()> { for doc in self.stack.iter(self_addr, heap) { try!(serializer.write_doc(doc, 0u32, &EMPTY_ARRAY)); } Ok(()) } - } /// Recorder encoding document ids, and term frequencies @@ -94,16 +100,13 @@ impl HeapAllocable for TermFrequencyRecorder { stack: ExpUnrolledLinkedList::with_addr(addr), current_doc: u32::max_value(), current_tf: 0u32, - doc_freq: 0u32 - } + doc_freq: 0u32, + } } } impl Recorder for TermFrequencyRecorder { - - - - fn current_doc(&self,) -> DocId { + fn current_doc(&self) -> DocId { self.current_doc } @@ -112,22 +115,26 @@ impl Recorder for TermFrequencyRecorder { self.current_doc = doc; self.stack.push(doc, heap); } - + fn record_position(&mut self, _position: u32, _heap: &Heap) { self.current_tf += 1; } - + fn close_doc(&mut self, heap: &Heap) { debug_assert!(self.current_tf > 0); self.stack.push(self.current_tf, heap); self.current_tf = 0; } - - fn doc_freq(&self,) -> u32 { + + fn doc_freq(&self) -> u32 { self.doc_freq } - - fn serialize(&self, self_addr:u32, serializer: &mut PostingsSerializer, heap: &Heap) -> 
io::Result<()> { + + fn serialize(&self, + self_addr: u32, + serializer: &mut PostingsSerializer, + heap: &Heap) + -> io::Result<()> { let mut doc_iter = self.stack.iter(self_addr, heap); loop { if let Some(doc) = doc_iter.next() { @@ -140,7 +147,6 @@ impl Recorder for TermFrequencyRecorder { } Ok(()) } - } /// Recorder encoding term frequencies as well as positions. @@ -162,12 +168,10 @@ impl HeapAllocable for TFAndPositionRecorder { } impl Recorder for TFAndPositionRecorder { - - - fn current_doc(&self,) -> DocId { + fn current_doc(&self) -> DocId { self.current_doc } - + fn new_doc(&mut self, doc: DocId, heap: &Heap) { self.doc_freq += 1; self.current_doc = doc; @@ -177,16 +181,20 @@ impl Recorder for TFAndPositionRecorder { fn record_position(&mut self, position: u32, heap: &Heap) { self.stack.push(position, heap); } - + fn close_doc(&mut self, heap: &Heap) { self.stack.push(POSITION_END, heap); } - - fn doc_freq(&self,) -> u32 { + + fn doc_freq(&self) -> u32 { self.doc_freq } - - fn serialize(&self, self_addr: u32, serializer: &mut PostingsSerializer, heap: &Heap) -> io::Result<()> { + + fn serialize(&self, + self_addr: u32, + serializer: &mut PostingsSerializer, + heap: &Heap) + -> io::Result<()> { let mut doc_positions = Vec::with_capacity(100); let mut positions_iter = self.stack.iter(self_addr, heap); while let Some(doc) = positions_iter.next() { @@ -197,8 +205,7 @@ impl Recorder for TFAndPositionRecorder { Some(position) => { if position == POSITION_END { break; - } - else { + } else { doc_positions.push(position - prev_position); prev_position = position; } @@ -212,7 +219,4 @@ impl Recorder for TFAndPositionRecorder { } Ok(()) } - } - - diff --git a/src/postings/segment_postings.rs b/src/postings/segment_postings.rs index 8de872f90..0bb8af8e3 100644 --- a/src/postings/segment_postings.rs +++ b/src/postings/segment_postings.rs @@ -4,11 +4,11 @@ use postings::{Postings, FreqHandler, DocSet, HasLen}; use std::num::Wrapping; +const EMPTY_DATA: [u8; 0] = 
[0u8; 0]; - -/// `SegmentPostings` represents the inverted list or postings associated to +/// `SegmentPostings` represents the inverted list or postings associated to /// a term in a `Segment`. -/// +/// /// As we iterate through the `SegmentPostings`, the frequencies are optionally decoded. /// Positions on the other hand, are optionally entirely decoded upfront. pub struct SegmentPostings<'a> { @@ -21,16 +21,16 @@ pub struct SegmentPostings<'a> { } impl<'a> SegmentPostings<'a> { - - fn load_next_block(&mut self,) { + fn load_next_block(&mut self) { let num_remaining_docs = self.len - self.cur.0; if num_remaining_docs >= NUM_DOCS_PER_BLOCK { - self.remaining_data = self.block_decoder.uncompress_block_sorted(self.remaining_data, self.doc_offset); + self.remaining_data = self.block_decoder + .uncompress_block_sorted(self.remaining_data, self.doc_offset); self.remaining_data = self.freq_handler.read_freq_block(self.remaining_data); self.doc_offset = self.block_decoder.output(NUM_DOCS_PER_BLOCK - 1); - } - else { - self.remaining_data = self.block_decoder.uncompress_vint_sorted(self.remaining_data, self.doc_offset, num_remaining_docs); + } else { + self.remaining_data = self.block_decoder + .uncompress_vint_sorted(self.remaining_data, self.doc_offset, num_remaining_docs); self.freq_handler.read_freq_vint(self.remaining_data, num_remaining_docs); } } @@ -39,7 +39,7 @@ impl<'a> SegmentPostings<'a> { /// /// * `len` - number of document in the posting lists. /// * `data` - data array. The complete data is not necessarily used. 
- /// * `freq_handler` - the freq handler is in charge of decoding + /// * `freq_handler` - the freq handler is in charge of decoding /// frequencies and/or positions pub fn from_data(len: u32, data: &'a [u8], freq_handler: FreqHandler) -> SegmentPostings<'a> { SegmentPostings { @@ -51,22 +51,32 @@ impl<'a> SegmentPostings<'a> { cur: Wrapping(usize::max_value()), } } - - /// Index within a block is used as an address when - /// interacting with the `FreqHandler` - fn index_within_block(&self,) -> usize { - self.cur.0 % NUM_DOCS_PER_BLOCK + + /// Returns an empty segment postings object + pub fn empty() -> SegmentPostings<'static> { + SegmentPostings { + len: 0, + doc_offset: 0, + block_decoder: SIMDBlockDecoder::new(), + freq_handler: FreqHandler::new_without_freq(), + remaining_data: &EMPTY_DATA, + cur: Wrapping(usize::max_value()), + } } + /// Index within a block is used as an address when + /// interacting with the `FreqHandler` + fn index_within_block(&self) -> usize { + self.cur.0 % NUM_DOCS_PER_BLOCK + } } impl<'a> DocSet for SegmentPostings<'a> { - // goes to the next element. // next needs to be called a first time to point to the correct element. 
#[inline] - fn advance(&mut self,) -> bool { + fn advance(&mut self) -> bool { self.cur += Wrapping(1); if self.cur.0 >= self.len { return false; @@ -76,27 +86,25 @@ impl<'a> DocSet for SegmentPostings<'a> { } true } - + #[inline] - fn doc(&self,) -> DocId { + fn doc(&self) -> DocId { self.block_decoder.output(self.index_within_block()) } - } impl<'a> HasLen for SegmentPostings<'a> { - fn len(&self,) -> usize { + fn len(&self) -> usize { self.len } } impl<'a> Postings for SegmentPostings<'a> { - fn term_freq(&self,) -> u32 { + fn term_freq(&self) -> u32 { self.freq_handler.freq(self.index_within_block()) } - + fn positions(&self) -> &[u32] { self.freq_handler.positions(self.index_within_block()) } } - diff --git a/src/postings/segment_postings_option.rs b/src/postings/segment_postings_option.rs index 70dfd97e4..cf2f8b936 100644 --- a/src/postings/segment_postings_option.rs +++ b/src/postings/segment_postings_option.rs @@ -2,10 +2,11 @@ /// Object describing the amount of information required when reading a postings. /// -/// Since decoding information is not free, this makes it possible to +/// Since decoding information is not free, this makes it possible to /// avoid this extra cost when the information is not required. /// For instance, positions are useful when running phrase queries -/// but useless in other queries, +/// but useless in other queries. +#[derive(Clone, Copy)] pub enum SegmentPostingsOption { /// Only the doc ids are decoded NoFreq, @@ -13,4 +14,4 @@ pub enum SegmentPostingsOption { Freq, /// DocIds, term frequencies and positions will be decoded. 
FreqAndPositions, -} \ No newline at end of file +} diff --git a/src/postings/serializer.rs b/src/postings/serializer.rs index 5fb7c9505..3316d1f5e 100644 --- a/src/postings/serializer.rs +++ b/src/postings/serializer.rs @@ -19,14 +19,14 @@ use common::BinarySerializable; /// `PostingsSerializer` is in charge of serializing -/// postings on disk, in the +/// postings on disk, in the /// * `.idx` (inverted index) /// * `.pos` (positions file) /// * `.term` (term dictionary) -/// -/// `PostingsWriter` are in charge of pushing the data to the +/// +/// `PostingsWriter` are in charge of pushing the data to the /// serializer. -/// +/// /// The serializer expects to receive the following calls /// in this order : /// @@ -45,10 +45,10 @@ use common::BinarySerializable; /// Terms have to be pushed in a lexicographically-sorted order. /// Within a term, document have to be pushed in increasing order. /// -/// A description of the serialization format is -/// [available here](https://fulmicoton.gitbooks.io/tantivy-doc/content/inverted-index.html). +/// A description of the serialization format is +/// [available here](https://fulmicoton.gitbooks.io/tantivy-doc/content/inverted-index.html). 
pub struct PostingsSerializer { - terms_fst_builder: FstMapBuilder, // TODO find an alternative to work around the "move" + terms_fst_builder: FstMapBuilder, /* TODO find an alternative to work around the "move" */ postings_write: WritePtr, positions_write: WritePtr, written_bytes_postings: usize, @@ -65,14 +65,13 @@ pub struct PostingsSerializer { } impl PostingsSerializer { - - /// Open a new `PostingsSerializer` for the given segment - pub fn open(segment: &mut Segment) -> Result { - let terms_write = try!(segment.open_write(SegmentComponent::TERMS)); + /// Open a new `PostingsSerializer` for the given segment + pub fn new(terms_write: WritePtr, + postings_write: WritePtr, + positions_write: WritePtr, + schema: Schema) + -> Result { let terms_fst_builder = try!(FstMapBuilder::new(terms_write)); - let postings_write = try!(segment.open_write(SegmentComponent::POSTINGS)); - let positions_write = try!(segment.open_write(SegmentComponent::POSITIONS)); - let schema = segment.schema(); Ok(PostingsSerializer { terms_fst_builder: terms_fst_builder, postings_write: postings_write, @@ -90,27 +89,36 @@ impl PostingsSerializer { term_open: false, }) } - + + + /// Open a new `PostingsSerializer` for the given segment + pub fn open(segment: &mut Segment) -> Result { + let terms_write = try!(segment.open_write(SegmentComponent::TERMS)); + let postings_write = try!(segment.open_write(SegmentComponent::POSTINGS)); + let positions_write = try!(segment.open_write(SegmentComponent::POSITIONS)); + PostingsSerializer::new(terms_write, + postings_write, + positions_write, + segment.schema()) + } + fn load_indexing_options(&mut self, field: Field) { let field_entry: &FieldEntry = self.schema.get_field_entry(field); self.text_indexing_options = match *field_entry.field_type() { - FieldType::Str(ref text_options) => { - text_options.get_indexing_options() - } + FieldType::Str(ref text_options) => text_options.get_indexing_options(), FieldType::U32(ref u32_options) => { if 
u32_options.is_indexed() { TextIndexingOptions::Unindexed - } - else { - TextIndexingOptions::Untokenized + } else { + TextIndexingOptions::Untokenized } } }; } - + /// Starts the postings for a new term. /// * term - the term. It needs to come after the previous term according - /// to the lexicographical order. + /// to the lexicographical order. /// * doc_freq - return the number of document containing the term. pub fn new_term(&mut self, term: &Term, doc_freq: DocId) -> io::Result<()> { if self.term_open { @@ -130,31 +138,34 @@ impl PostingsSerializer { self.terms_fst_builder .insert(term.as_slice(), &term_info) } - + /// Finish the serialization for this term postings. /// /// If the current block is incomplete, it need to be encoded - /// using `VInt` encoding. - pub fn close_term(&mut self,) -> io::Result<()> { + /// using `VInt` encoding. + pub fn close_term(&mut self) -> io::Result<()> { if self.term_open { if !self.doc_ids.is_empty() { // we have doc ids waiting to be written - // this happens when the number of doc ids is + // this happens when the number of doc ids is // not a perfect multiple of our block size. // // In that case, the remaining part is encoded // using variable int encoding. { - let block_encoded = self.block_encoder.compress_vint_sorted(&self.doc_ids, self.last_doc_id_encoded); + let block_encoded = self.block_encoder + .compress_vint_sorted(&self.doc_ids, self.last_doc_id_encoded); self.written_bytes_postings += block_encoded.len(); try!(self.postings_write.write_all(block_encoded)); self.doc_ids.clear(); } - // ... Idem for term frequencies + // ... 
Idem for term frequencies if self.text_indexing_options.is_termfreq_enabled() { - let block_encoded = self.block_encoder.compress_vint_unsorted(&self.term_freqs[..]); + let block_encoded = self.block_encoder + .compress_vint_unsorted(&self.term_freqs[..]); for num in block_encoded { - self.written_bytes_postings += try!(num.serialize(&mut self.postings_write)); + self.written_bytes_postings += + try!(num.serialize(&mut self.postings_write)); } self.term_freqs.clear(); } @@ -162,8 +173,10 @@ impl PostingsSerializer { // On the other hand, positions are entirely buffered until the // end of the term, at which point they are compressed and written. if self.text_indexing_options.is_position_enabled() { - self.written_bytes_positions += try!(VInt(self.position_deltas.len() as u64).serialize(&mut self.positions_write)); - let positions_encoded: &[u8] = self.positions_encoder.compress_unsorted(&self.position_deltas[..]); + self.written_bytes_positions += try!(VInt(self.position_deltas.len() as u64) + .serialize(&mut self.positions_write)); + let positions_encoded: &[u8] = self.positions_encoder + .compress_unsorted(&self.position_deltas[..]); try!(self.positions_write.write_all(positions_encoded)); self.written_bytes_positions += positions_encoded.len(); self.position_deltas.clear(); @@ -172,8 +185,8 @@ impl PostingsSerializer { } Ok(()) } - - + + /// Serialize the information that a document contains the current term, /// its term frequency, and the position deltas. /// @@ -183,7 +196,11 @@ impl PostingsSerializer { /// /// Term frequencies and positions may be ignored by the serializer depending /// on the configuration of the field in the `Schema`. 
- pub fn write_doc(&mut self, doc_id: DocId, term_freq: u32, position_deltas: &[u32]) -> io::Result<()> { + pub fn write_doc(&mut self, + doc_id: DocId, + term_freq: u32, + position_deltas: &[u32]) + -> io::Result<()> { self.doc_ids.push(doc_id); if self.text_indexing_options.is_termfreq_enabled() { self.term_freqs.push(term_freq as u32); @@ -194,14 +211,16 @@ impl PostingsSerializer { if self.doc_ids.len() == NUM_DOCS_PER_BLOCK { { // encode the doc ids - let block_encoded: &[u8] = self.block_encoder.compress_block_sorted(&self.doc_ids, self.last_doc_id_encoded); + let block_encoded: &[u8] = self.block_encoder + .compress_block_sorted(&self.doc_ids, self.last_doc_id_encoded); self.last_doc_id_encoded = self.doc_ids[self.doc_ids.len() - 1]; try!(self.postings_write.write_all(block_encoded)); self.written_bytes_postings += block_encoded.len(); } if self.text_indexing_options.is_termfreq_enabled() { // encode the term_freqs - let block_encoded: &[u8] = self.block_encoder.compress_block_unsorted(&self.term_freqs); + let block_encoded: &[u8] = self.block_encoder + .compress_block_unsorted(&self.term_freqs); try!(self.postings_write.write_all(block_encoded)); self.written_bytes_postings += block_encoded.len(); self.term_freqs.clear(); @@ -210,9 +229,9 @@ impl PostingsSerializer { } Ok(()) } - + /// Closes the serializer. 
- pub fn close(mut self,) -> io::Result<()> { + pub fn close(mut self) -> io::Result<()> { try!(self.close_term()); try!(self.terms_fst_builder.finish()); try!(self.postings_write.flush()); diff --git a/src/postings/vec_postings.rs b/src/postings/vec_postings.rs index 8d4ba0d48..399307cff 100644 --- a/src/postings/vec_postings.rs +++ b/src/postings/vec_postings.rs @@ -1,11 +1,10 @@ #![allow(dead_code)] use DocId; -use postings::{Postings, DocSet, SkipResult, HasLen}; +use postings::{Postings, DocSet, HasLen}; use std::num::Wrapping; -use std::cmp::Ordering; -const EMPTY_ARRAY: [u32; 0] = []; +const EMPTY_ARRAY: [u32; 0] = []; /// Simulate a `Postings` objects from a `VecPostings`. /// `VecPostings` only exist for testing purposes. @@ -27,99 +26,43 @@ impl From> for VecPostings { } impl DocSet for VecPostings { - fn advance(&mut self,) -> bool { + fn advance(&mut self) -> bool { self.cursor += Wrapping(1); self.doc_ids.len() > self.cursor.0 } - - fn doc(&self,) -> DocId { + + fn doc(&self) -> DocId { self.doc_ids[self.cursor.0] } - - fn skip_next(&mut self, target: DocId) -> SkipResult { - let mut start: usize = self.cursor.0; - match self.doc_ids[start].cmp(&target) { - Ordering::Equal => { - return SkipResult::Reached; - } - Ordering::Greater => { - if self.cursor.0 < self.doc_ids.len() { - return SkipResult::OverStep; - } - else { - return SkipResult::End; - } - } - Ordering::Less => { - // see below - } - } - - let mut end = self.doc_ids.len(); - - while end - start > 1 { - // find an upper bound - let mut jump = 1; - loop { - let jump_dest = start + jump; - if jump_dest >= end { - // we jump out of bounds - break; - } - match self.doc_ids[jump_dest].cmp(&target) { - Ordering::Less => { - // still below the target, let's keep jumping. 
- start = jump_dest; - jump *= 2; - } - Ordering::Equal => { - self.cursor = Wrapping(jump_dest); - return SkipResult::Reached; - } - Ordering::Greater => { - end = jump_dest; - break; - } - } - } - } - self.cursor = Wrapping(start + 1); - if self.cursor.0 < self.doc_ids.len() { - SkipResult::OverStep - } - else { - SkipResult::End - } - } } impl HasLen for VecPostings { - fn len(&self,) -> usize { + fn len(&self) -> usize { self.doc_ids.len() } } impl Postings for VecPostings { - fn term_freq(&self,) -> u32 { + fn term_freq(&self) -> u32 { 1u32 } - + fn positions(&self) -> &[u32] { &EMPTY_ARRAY - } + } } #[cfg(test)] pub mod tests { - + use super::*; - use DocId; - use postings::{Postings, SkipResult, DocSet}; - - + use DocId; + use postings::{Postings, SkipResult, DocSet}; + + #[test] pub fn test_vec_postings() { - let doc_ids: Vec = (0u32..1024u32).map(|e| e*3).collect(); + let doc_ids: Vec = (0u32..1024u32).map(|e| e * 3).collect(); let mut postings = VecPostings::from(doc_ids); assert!(postings.advance()); assert_eq!(postings.doc(), 0u32); @@ -132,5 +75,5 @@ pub mod tests { assert_eq!(postings.doc(), 300u32); assert_eq!(postings.skip_next(6000u32), SkipResult::End); } -} +} diff --git a/src/query/boolean_query/boolean_query.rs b/src/query/boolean_query/boolean_query.rs new file mode 100644 index 000000000..2d660f7aa --- /dev/null +++ b/src/query/boolean_query/boolean_query.rs @@ -0,0 +1,48 @@ +use Result; +use std::any::Any; +use super::boolean_weight::BooleanWeight; +use query::Weight; +use Searcher; +use query::Query; +use query::Occur; +use query::OccurFilter; + +/// The boolean query combines a set of queries +/// +/// The documents matched by the boolean query are +/// those which +/// * match all of the sub queries associated with the +/// `Must` occurence +/// * match none of the sub queries associated with the +/// `MustNot` occurence. +/// * match at least one of the subqueries that is not +/// a `MustNot` occurence. 
+#[derive(Debug)] +pub struct BooleanQuery { + subqueries: Vec<(Occur, Box)> +} + +impl From)>> for BooleanQuery { + fn from(subqueries: Vec<(Occur, Box)>) -> BooleanQuery { + BooleanQuery { subqueries: subqueries } + } +} + +impl Query for BooleanQuery { + fn as_any(&self) -> &Any { + self + } + + fn weight(&self, searcher: &Searcher) -> Result> { + let sub_weights = try!(self.subqueries + .iter() + .map(|&(ref _occur, ref subquery)| subquery.weight(searcher)) + .collect()); + let occurs: Vec = self.subqueries + .iter() + .map(|&(ref occur, ref _subquery)| *occur) + .collect(); + let filter = OccurFilter::new(&occurs); + Ok(box BooleanWeight::new(sub_weights, filter)) + } +} \ No newline at end of file diff --git a/src/query/boolean_query/boolean_scorer.rs b/src/query/boolean_query/boolean_scorer.rs new file mode 100644 index 000000000..c24f67760 --- /dev/null +++ b/src/query/boolean_query/boolean_scorer.rs @@ -0,0 +1,149 @@ +use query::Scorer; +use DocId; +use std::collections::BinaryHeap; +use std::cmp::Ordering; +use postings::DocSet; +use query::OccurFilter; +use query::boolean_query::ScoreCombiner; + + +/// Each `HeapItem` represents the head of +/// one of scorer being merged. +/// +/// * `doc` - is the current doc id for the given segment postings +/// * `ord` - is the ordinal used to identify to which segment postings +/// this heap item belong to. 
+#[derive(Eq, PartialEq)] +struct HeapItem { + doc: DocId, + ord: u32, +} + +/// `HeapItem` are ordered by the document +impl PartialOrd for HeapItem { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for HeapItem { + fn cmp(&self, other:&Self) -> Ordering { + (other.doc).cmp(&self.doc) + } +} + +pub struct BooleanScorer { + scorers: Vec, + queue: BinaryHeap, + doc: DocId, + score_combiner: ScoreCombiner, + occur_filter: OccurFilter, +} + +impl BooleanScorer { + + pub fn scorers(&self) -> &[TScorer] { + &self.scorers + } + + pub fn new(scorers: Vec, + occur_filter: OccurFilter) -> BooleanScorer { + let score_combiner = ScoreCombiner::default_for_num_scorers(scorers.len()); + let mut non_empty_scorers: Vec = Vec::new(); + for mut posting in scorers { + let non_empty = posting.advance(); + if non_empty { + non_empty_scorers.push(posting); + } + } + let heap_items: Vec = non_empty_scorers + .iter() + .map(|posting| posting.doc()) + .enumerate() + .map(|(ord, doc)| { + HeapItem { + doc: doc, + ord: ord as u32 + } + }) + .collect(); + BooleanScorer { + scorers: non_empty_scorers, + queue: BinaryHeap::from(heap_items), + doc: 0u32, + score_combiner: score_combiner, + occur_filter: occur_filter, + + } + } + + /// Advances the head of our heap (the segment posting with the lowest doc) + /// It will also update the new current `DocId` as well as the term frequency + /// associated with the segment postings. + /// + /// After advancing the `SegmentPosting`, the postings is removed from the heap + /// if it has been entirely consumed, or pushed back into the heap. + /// + /// # Panics + /// This method will panic if the head `SegmentPostings` is not empty. 
+ fn advance_head(&mut self,) { + { + let mut mutable_head = self.queue.peek_mut().unwrap(); + let cur_scorers = &mut self.scorers[mutable_head.ord as usize]; + if cur_scorers.advance() { + mutable_head.doc = cur_scorers.doc(); + return; + } + } + self.queue.pop(); + } +} + +impl DocSet for BooleanScorer { + fn advance(&mut self,) -> bool { + loop { + self.score_combiner.clear(); + let mut ord_bitset = 0u64; + match self.queue.peek() { + Some(heap_item) => { + let ord = heap_item.ord as usize; + self.doc = heap_item.doc; + let score = self.scorers[ord].score(); + self.score_combiner.update(score); + ord_bitset |= 1 << ord; + } + None => { + return false; + } + } + self.advance_head(); + while let Some(&HeapItem {doc, ord}) = self.queue.peek() { + if doc == self.doc { + let ord = ord as usize; + let score = self.scorers[ord].score(); + self.score_combiner.update(score); + ord_bitset |= 1 << ord; + } + else { + break; + } + self.advance_head(); + } + if self.occur_filter.accept(ord_bitset) { + return true; + } + } + } + + fn doc(&self,) -> DocId { + self.doc + } +} + +impl Scorer for BooleanScorer { + + fn score(&self,) -> f32 { + self.score_combiner.score() + } +} + diff --git a/src/query/boolean_query/boolean_weight.rs b/src/query/boolean_query/boolean_weight.rs new file mode 100644 index 000000000..830f85edf --- /dev/null +++ b/src/query/boolean_query/boolean_weight.rs @@ -0,0 +1,32 @@ +use query::Weight; +use core::SegmentReader; +use query::Scorer; +use super::BooleanScorer; +use query::OccurFilter; +use Result; + +pub struct BooleanWeight { + weights: Vec>, + occur_filter: OccurFilter, +} + +impl BooleanWeight { + pub fn new(weights: Vec>, occur_filter: OccurFilter) -> BooleanWeight { + BooleanWeight { + weights: weights, + occur_filter: occur_filter, + } + } +} + + +impl Weight for BooleanWeight { + fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result> { + let sub_scorers: Vec> = try!(self.weights + .iter() + .map(|weight| weight.scorer(reader)) + 
.collect()); + let boolean_scorer = BooleanScorer::new(sub_scorers, self.occur_filter); + Ok(box boolean_scorer) + } +} diff --git a/src/query/boolean_query/mod.rs b/src/query/boolean_query/mod.rs new file mode 100644 index 000000000..36a113b7f --- /dev/null +++ b/src/query/boolean_query/mod.rs @@ -0,0 +1,158 @@ +mod boolean_query; +mod boolean_scorer; +mod boolean_weight; +mod score_combiner; + +pub use self::boolean_query::BooleanQuery; +pub use self::boolean_scorer::BooleanScorer; +pub use self::score_combiner::ScoreCombiner; + + + +#[cfg(test)] +mod tests { + + use super::*; + use postings::{DocSet, VecPostings}; + use query::Scorer; + use query::OccurFilter; + use query::term_query::TermScorer; + use query::Occur; + use query::Query; + use query::TermQuery; + use collector::tests::TestCollector; + use Index; + use schema::*; + use fastfield::{U32FastFieldReader}; + + fn abs_diff(left: f32, right: f32) -> f32 { + (right - left).abs() + } + + + #[test] + pub fn test_boolean_query() { + let mut schema_builder = SchemaBuilder::default(); + let text_field = schema_builder.add_text_field("text", TEXT); + let schema = schema_builder.build(); + let index = Index::create_from_tempdir(schema).unwrap(); + { + // writing the segment + let mut index_writer = index.writer_with_num_threads(1, 40_000_000).unwrap(); + { + let doc = doc!(text_field => "a b c"); + index_writer.add_document(doc).unwrap(); + } + { + let doc = doc!(text_field => "a c"); + index_writer.add_document(doc).unwrap(); + } + { + let doc = doc!(text_field => "b c"); + index_writer.add_document(doc).unwrap(); + } + { + let doc = doc!(text_field => "a b c d"); + index_writer.add_document(doc).unwrap(); + } + { + let doc = doc!(text_field => "d"); + index_writer.add_document(doc).unwrap(); + } + assert!(index_writer.commit().is_ok()); + } + let make_term_query = |text: &str| { + let term_query = TermQuery::from(Term::from_field_text(text_field, text)); + let query: Box = box term_query; + query + }; + + + let 
matching_docs = |boolean_query: &Query| { + let searcher = index.searcher(); + let mut test_collector = TestCollector::default(); + boolean_query.search(&*searcher, &mut test_collector).unwrap(); + test_collector.docs() + }; + { + let boolean_query = BooleanQuery::from(vec![(Occur::Must, make_term_query("a")) ]); + assert_eq!(matching_docs(&boolean_query), vec!(0, 1, 3)); + } + { + let boolean_query = BooleanQuery::from(vec![(Occur::Should, make_term_query("a")) ]); + assert_eq!(matching_docs(&boolean_query), vec!(0, 1, 3)); + } + { + let boolean_query = BooleanQuery::from(vec![(Occur::Should, make_term_query("a")), (Occur::Should, make_term_query("b"))]); + assert_eq!(matching_docs(&boolean_query), vec!(0, 1, 2, 3)); + } + { + let boolean_query = BooleanQuery::from(vec![(Occur::Must, make_term_query("a")), (Occur::Should, make_term_query("b"))]); + assert_eq!(matching_docs(&boolean_query), vec!(0, 1, 3)); + } + { + let boolean_query = BooleanQuery::from(vec![(Occur::Must, make_term_query("a")), + (Occur::Should, make_term_query("b")), + (Occur::MustNot, make_term_query("d")), + ]); + assert_eq!(matching_docs(&boolean_query), vec!(0, 1)); + } + { + let boolean_query = BooleanQuery::from(vec![(Occur::MustNot, make_term_query("d")),]); + // TODO optimize this use case : only MustNot subqueries... no need + // to read any postings. 
+ assert_eq!(matching_docs(&boolean_query), Vec::new()); + } + } + + #[test] + pub fn test_boolean_scorer() { + let occurs = vec!(Occur::Should, Occur::Should); + let occur_filter = OccurFilter::new(&occurs); + + let left_fieldnorms = U32FastFieldReader::from(vec!(100,200,300)); + + let left = VecPostings::from(vec!(1, 2, 3)); + let left_scorer = TermScorer { + idf: 1f32, + fieldnorm_reader: left_fieldnorms, + postings: left, + }; + + let right_fieldnorms = U32FastFieldReader::from(vec!(15,25,35)); + let right = VecPostings::from(vec!(1, 3, 8)); + + let right_scorer = TermScorer { + idf: 4f32, + fieldnorm_reader: right_fieldnorms, + postings: right, + }; + + let mut boolean_scorer = BooleanScorer::new(vec!(left_scorer, right_scorer), occur_filter); + assert_eq!(boolean_scorer.next(), Some(1u32)); + assert!(abs_diff(boolean_scorer.score(), 0.8707107) < 0.001); + assert_eq!(boolean_scorer.next(), Some(2u32)); + assert!(abs_diff(boolean_scorer.score(), 0.028867513) < 0.001f32); + assert_eq!(boolean_scorer.next(), Some(3u32)); + assert_eq!(boolean_scorer.next(), Some(8u32)); + assert!(abs_diff(boolean_scorer.score(), 0.5163978) < 0.001f32); + assert!(!boolean_scorer.advance()); + } + + + #[test] + pub fn test_term_scorer() { + let left_fieldnorms = U32FastFieldReader::from(vec!(10, 4)); + assert_eq!(left_fieldnorms.get(0), 10); + assert_eq!(left_fieldnorms.get(1), 4); + let left = VecPostings::from(vec!(1)); + let mut left_scorer = TermScorer { + idf: 0.30685282, + fieldnorm_reader: left_fieldnorms, + postings: left, + }; + left_scorer.advance(); + assert!(abs_diff(left_scorer.score(), 0.15342641) < 0.001f32); + } + +} diff --git a/src/query/boolean_query/score_combiner.rs b/src/query/boolean_query/score_combiner.rs new file mode 100644 index 000000000..204c57c23 --- /dev/null +++ b/src/query/boolean_query/score_combiner.rs @@ -0,0 +1,46 @@ +use Score; + +pub struct ScoreCombiner { + coords: Vec, + num_fields: usize, + score: Score, +} + +impl ScoreCombiner { + + pub 
fn update(&mut self, score: Score) { + self.score += score; + self.num_fields += 1; + } + + pub fn clear(&mut self,) { + self.score = 0f32; + self.num_fields = 0; + } + + /// Compute the coord term + fn coord(&self,) -> f32 { + self.coords[self.num_fields] + } + + pub fn score(&self, ) -> Score { + self.score * self.coord() + } + + pub fn default_for_num_scorers(num_scorers: usize) -> ScoreCombiner { + let query_coords: Vec = (0..num_scorers + 1) + .map(|i| (i as Score) / (num_scorers as Score)) + .collect(); + ScoreCombiner::from(query_coords) + } +} + +impl From> for ScoreCombiner { + fn from(coords: Vec) -> ScoreCombiner { + ScoreCombiner { + coords: coords, + num_fields: 0, + score: 0f32, + } + } +} \ No newline at end of file diff --git a/src/query/daat_multiterm_scorer.rs b/src/query/daat_multiterm_scorer.rs deleted file mode 100644 index b0ab1920c..000000000 --- a/src/query/daat_multiterm_scorer.rs +++ /dev/null @@ -1,285 +0,0 @@ -use DocId; -use postings::{Postings, DocSet}; -use std::cmp::Ordering; -use std::collections::BinaryHeap; -use query::MultiTermAccumulator; -use query::Similarity; -use fastfield::U32FastFieldReader; -use query::Occur; -use std::iter; -use super::Scorer; -use Score; - -/// Each `HeapItem` represents the head of -/// a segment postings being merged. -/// -/// * `doc` - is the current doc id for the given segment postings -/// * `ord` - is the ordinal used to identify to which segment postings -/// this heap item belong to. 
-#[derive(Eq, PartialEq)] -struct HeapItem { - doc: DocId, - ord: u32, -} - -/// `HeapItem` are ordered by the document -impl PartialOrd for HeapItem { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for HeapItem { - fn cmp(&self, other:&Self) -> Ordering { - (other.doc).cmp(&self.doc) - } -} - -struct Filter { - and_mask: u64, - result: u64, -} - -impl Filter { - fn accept(&self, ord_set: u64) -> bool { - (self.and_mask & ord_set) == self.result - } - - fn new(occurs: &[Occur]) -> Filter { - let mut and_mask = 0u64; - let mut result = 0u64; - for (i, occur) in occurs.iter().enumerate() { - let shift = 1 << i; - match *occur { - Occur::Must => { - and_mask |= shift; - result |= shift; - }, - Occur::MustNot => { - and_mask |= shift; - }, - Occur::Should => {}, - } - } - Filter { - and_mask: and_mask, - result: result - } - } -} - -/// Document-At-A-Time multi term scorer. -/// -/// The scorer merges multiple segment postings and pushes -/// term information to the score accumulator. 
-pub struct DAATMultiTermScorer { - fieldnorm_readers: Vec, - postings: Vec, - term_frequencies: Vec, - queue: BinaryHeap, - doc: DocId, - similarity: TAccumulator, - filter: Filter, -} - -impl DAATMultiTermScorer { - - fn new_non_empty( - - fieldnorm_readers: Vec, - postings: Vec, - similarity: TAccumulator, - filter: Filter - ) -> DAATMultiTermScorer { - let mut term_frequencies: Vec = iter::repeat(0u32).take(postings.len()).collect(); - let heap_items: Vec = postings - .iter() - .map(|posting| { - (posting.doc(), posting.term_freq()) - }) - .enumerate() - .map(|(ord, (doc, tf))| { - term_frequencies[ord] = tf; - HeapItem { - doc: doc, - ord: ord as u32 - } - }) - .collect(); - DAATMultiTermScorer { - fieldnorm_readers: fieldnorm_readers, - postings: postings, - term_frequencies: term_frequencies, - queue: BinaryHeap::from(heap_items), - doc: 0, - similarity: similarity, - filter: filter - } - } - - /// Constructor - pub fn new(postings_and_fieldnorms: Vec<(Occur, TPostings, U32FastFieldReader)>, similarity: TAccumulator) -> DAATMultiTermScorer { - let mut postings = Vec::new(); - let mut fieldnorm_readers = Vec::new(); - let mut occurs = Vec::new(); - for (occur, mut posting, fieldnorm_reader) in postings_and_fieldnorms { - if posting.advance() { - postings.push(posting); - fieldnorm_readers.push(fieldnorm_reader); - occurs.push(occur); - } - } - let filter = Filter::new(&occurs); - DAATMultiTermScorer::new_non_empty(fieldnorm_readers, postings, similarity, filter) - } - - /// Returns the scorer - pub fn scorer(&self,) -> &TAccumulator { - &self.similarity - } - - /// Advances the head of our heap (the segment postings with the lowest doc) - /// It will also update the new current `DocId` as well as the term frequency - /// associated with the segment postings. - /// - /// After advancing the `SegmentPosting`, the postings is removed from the heap - /// if it has been entirely consumed, or pushed back into the heap. 
- /// - /// # Panics - /// This method will panic if the head `SegmentPostings` is not empty. - fn advance_head(&mut self,) { - - { - let mut mutable_head = self.queue.peek_mut().unwrap(); - let cur_postings = &mut self.postings[mutable_head.ord as usize]; - if cur_postings.advance() { - let doc = cur_postings.doc(); - self.term_frequencies[mutable_head.ord as usize] = cur_postings.term_freq(); - mutable_head.doc = doc; - return; - } - - } - self.queue.pop(); - } - - /// Returns the field norm for the segment postings with the given ordinal, - /// and the given document. - fn get_field_norm(&self, ord:usize, doc:DocId) -> u32 { - self.fieldnorm_readers[ord].get(doc) - } - -} - -impl Scorer for DAATMultiTermScorer { - fn score(&self,) -> Score { - self.similarity.score() - } -} - -impl DocSet for DAATMultiTermScorer { - - fn advance(&mut self,) -> bool { - loop { - self.similarity.clear(); - let mut ord_bitset = 0u64; - match self.queue.peek() { - Some(heap_item) => { - self.doc = heap_item.doc; - let ord: usize = heap_item.ord as usize; - let fieldnorm = self.get_field_norm(ord, heap_item.doc); - let tf = self.term_frequencies[ord]; - self.similarity.update(ord, tf, fieldnorm); - ord_bitset |= 1 << ord; - } - None => { - return false; - } - } - self.advance_head(); - while let Some(&HeapItem {doc, ord}) = self.queue.peek() { - if doc == self.doc { - let peek_ord: usize = ord as usize; - let peek_tf = self.term_frequencies[peek_ord]; - let peek_fieldnorm = self.get_field_norm(peek_ord, doc); - self.similarity.update(peek_ord, peek_tf, peek_fieldnorm); - ord_bitset |= 1 << peek_ord; - } - else { - break; - } - self.advance_head(); - } - if self.filter.accept(ord_bitset) { - return true; - } - } - } - - fn doc(&self,) -> DocId { - self.doc - } -} - -#[cfg(test)] -mod tests { - - use super::*; - use postings::{DocSet, VecPostings}; - use query::TfIdf; - use query::Scorer; - use directory::Directory; - use directory::RAMDirectory; - use schema::Field; - use 
std::path::Path; - use query::Occur; - use fastfield::{U32FastFieldReader, U32FastFieldWriter, FastFieldSerializer}; - - - pub fn create_u32_fastfieldreader(field: Field, vals: Vec) -> U32FastFieldReader { - let mut u32_field_writer = U32FastFieldWriter::new(field); - for val in vals { - u32_field_writer.add_val(val); - } - let path = Path::new("some_path"); - let mut directory = RAMDirectory::create(); - { - let write = directory.open_write(&path).unwrap(); - let mut serializer = FastFieldSerializer::new(write).unwrap(); - u32_field_writer.serialize(&mut serializer).unwrap(); - serializer.close().unwrap(); - } - let read = directory.open_read(&path).unwrap(); - U32FastFieldReader::open(read).unwrap() - } - - fn abs_diff(left: f32, right: f32) -> f32 { - (right - left).abs() - } - - #[test] - pub fn test_daat_scorer() { - let left_fieldnorms = create_u32_fastfieldreader(Field(1), vec!(100,200,300)); - let right_fieldnorms = create_u32_fastfieldreader(Field(2), vec!(15,25,35)); - let left = VecPostings::from(vec!(1, 2, 3)); - let right = VecPostings::from(vec!(1, 3, 8)); - let tfidf = TfIdf::new(vec!(0f32, 1f32, 2f32), vec!(1f32, 4f32)); - let mut daat_scorer = DAATMultiTermScorer::new( - vec!( - (Occur::Should, left, left_fieldnorms), - (Occur::Should, right, right_fieldnorms), - ), - tfidf - ); - assert_eq!(daat_scorer.next(), Some(1u32)); - assert!(abs_diff(daat_scorer.score(), 2.182179f32) < 0.001); - assert_eq!(daat_scorer.next(), Some(2u32)); - assert!(abs_diff(daat_scorer.score(), 0.2236068) < 0.001f32); - assert_eq!(daat_scorer.next(), Some(3u32)); - assert_eq!(daat_scorer.next(), Some(8u32)); - assert!(abs_diff(daat_scorer.score(), 0.8944272f32) < 0.001f32); - assert!(!daat_scorer.advance()); - } - -} - diff --git a/src/query/mod.rs b/src/query/mod.rs index 12af7d35c..c494b8c68 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -4,29 +4,29 @@ /// mod query; +mod boolean_query; mod multi_term_query; -mod multi_term_accumulator; -mod 
similarity_explainer; +mod phrase_query; mod scorer; mod query_parser; mod explanation; -mod tfidf; mod occur; -mod daat_multiterm_scorer; -mod similarity; +mod weight; +mod occur_filter; +mod term_query; -pub use self::similarity::Similarity; - -pub use self::daat_multiterm_scorer::DAATMultiTermScorer; +pub use self::occur_filter::OccurFilter; +pub use self::boolean_query::BooleanQuery; pub use self::occur::Occur; pub use self::query::Query; +pub use self::term_query::TermQuery; +pub use self::phrase_query::PhraseQuery; pub use self::multi_term_query::MultiTermQuery; -pub use self::similarity_explainer::SimilarityExplainer; -pub use self::tfidf::TfIdf; - +pub use self::multi_term_query::MultiTermWeight; pub use self::scorer::Scorer; +pub use self::scorer::EmptyScorer; pub use self::query_parser::QueryParser; pub use self::explanation::Explanation; -pub use self::multi_term_accumulator::MultiTermAccumulator; -pub use self::query_parser::ParsingError; \ No newline at end of file +pub use self::query_parser::ParsingError; +pub use self::weight::Weight; diff --git a/src/query/multi_term_accumulator.rs b/src/query/multi_term_accumulator.rs deleted file mode 100644 index abbec0c56..000000000 --- a/src/query/multi_term_accumulator.rs +++ /dev/null @@ -1,14 +0,0 @@ - -/// Accumulator of the matching terms information -pub trait MultiTermAccumulator { - /// Update the accumulator given the information of a term. - /// - term_ord is the term_ordinal - /// - term_freq is the frequency of the term within the document - /// - fieldnorm is the number of tokens associated to the field for this document - /// - /// The term's update do not have to arrive in a specific order. - /// Terms that are not present in the document will not be updated. 
- fn update(&mut self, term_ord: usize, term_freq: u32, fieldnorm: u32); - /// Resets the accumulator - fn clear(&mut self,); -} diff --git a/src/query/multi_term_query.rs b/src/query/multi_term_query.rs deleted file mode 100644 index 24026e6f8..000000000 --- a/src/query/multi_term_query.rs +++ /dev/null @@ -1,175 +0,0 @@ -use Result; -use Error; -use schema::Term; -use query::Query; -use common::TimerTree; -use common::OpenTimer; -use core::searcher::Searcher; -use collector::Collector; -use SegmentLocalId; -use core::SegmentReader; -use query::SimilarityExplainer; -use postings::SegmentPostings; -use postings::DocSet; -use query::TfIdf; -use postings::SkipResult; -use ScoredDoc; -use query::Scorer; -use query::MultiTermAccumulator; -use DocAddress; -use query::Explanation; -use query::occur::Occur; -use postings::SegmentPostingsOption; -use query::DAATMultiTermScorer; - - -/// Query involving one or more terms. -#[derive(Eq, PartialEq, Debug)] -pub struct MultiTermQuery { - occur_terms: Vec<(Occur, Term)>, -} - -impl MultiTermQuery { - - /// Accessor for the number of terms - pub fn num_terms(&self,) -> usize { - self.occur_terms.len() - } - - /// Builds the similitude object - fn similitude(&self, searcher: &Searcher) -> TfIdf { - let num_terms = self.num_terms(); - let num_docs = searcher.num_docs() as f32; - let idfs: Vec = self.occur_terms - .iter() - .map(|&(_, ref term)| searcher.doc_freq(term)) - .map(|doc_freq| { - if doc_freq == 0 { - 1. - } - else { - 1. + ( num_docs / (doc_freq as f32) ).ln() - } - }) - .collect(); - let query_coords = (0..num_terms + 1) - .map(|i| (i as f32) / (num_terms as f32)) - .collect(); - // TODO have the actual terms in these names - let term_names = self.occur_terms - .iter() - .map(|&(_, ref term)| format!("{:?}", &term)) - .collect(); - let mut tfidf = TfIdf::new(query_coords, idfs); - tfidf.set_term_names(term_names); - tfidf - } - - - /// Search the segment. 
- fn search_segment<'a, 'b, TAccumulator: MultiTermAccumulator>( - &'b self, - reader: &'b SegmentReader, - accumulator: TAccumulator, - mut timer: OpenTimer<'a>) -> Result> { - let mut postings_and_fieldnorms = Vec::with_capacity(self.num_terms()); - { - let mut decode_timer = timer.open("decode_all"); - for &(occur, ref term) in &self.occur_terms { - let _decode_one_timer = decode_timer.open("decode_one"); - if let Some(postings) = reader.read_postings(term, SegmentPostingsOption::Freq) { - let field = term.field(); - let fieldnorm_reader = try!(reader.get_fieldnorms_reader(field)); - postings_and_fieldnorms.push((occur, postings, fieldnorm_reader)); - } - } - } - if postings_and_fieldnorms.len() > 64 { - // TODO putting the SHOULD at the end of the list should push the limit. - return Err(Error::InvalidArgument(String::from("Limit of 64 terms was exceeded."))); - } - Ok(DAATMultiTermScorer::new(postings_and_fieldnorms, accumulator)) - } -} - - -impl From> for MultiTermQuery { - fn from(occur_terms: Vec<(Occur, Term)>) -> MultiTermQuery { - MultiTermQuery { - occur_terms: occur_terms, - } - } -} - -impl From> for MultiTermQuery { - fn from(terms: Vec) -> MultiTermQuery { - let should_terms = terms - .into_iter() - .map(|term| (Occur::Should, term)) - .collect(); - MultiTermQuery { - occur_terms: should_terms, - } - } -} - -impl Query for MultiTermQuery { - - fn explain( - &self, - searcher: &Searcher, - doc_address: &DocAddress) -> Result { - let segment_reader = searcher.segment_reader(doc_address.segment_ord() as usize); - let similitude = SimilarityExplainer::from(self.similitude(searcher)); - let mut timer_tree = TimerTree::default(); - let mut postings = try!( - self.search_segment( - segment_reader, - similitude, - timer_tree.open("explain")) - ); - Ok(match postings.skip_next(doc_address.doc()) { - SkipResult::Reached => { - let scorer = postings.scorer(); - scorer.explain_score() - } - _ => { - let mut explanation = Explanation::with_val(0f32); - 
explanation.description(&format!("Failed to run explain: the document {:?} does not match", doc_address)); - explanation - } - }) - } - - fn search( - &self, - searcher: &Searcher, - collector: &mut C) -> Result { - let mut timer_tree = TimerTree::default(); - { - let mut search_timer = timer_tree.open("search"); - for (segment_ord, segment_reader) in searcher.segment_readers().iter().enumerate() { - let mut segment_search_timer = search_timer.open("segment_search"); - { - let _ = segment_search_timer.open("set_segment"); - try!(collector.set_segment(segment_ord as SegmentLocalId, &segment_reader)); - } - let mut postings = try!( - self.search_segment( - segment_reader, - self.similitude(searcher), - segment_search_timer.open("get_postings")) - ); - { - let _collection_timer = segment_search_timer.open("collection"); - while postings.advance() { - let scored_doc = ScoredDoc(postings.score(), postings.doc()); - collector.collect(scored_doc); - } - } - } - } - Ok(timer_tree) - } -} - diff --git a/src/query/multi_term_query/mod.rs b/src/query/multi_term_query/mod.rs new file mode 100644 index 000000000..38cd209a4 --- /dev/null +++ b/src/query/multi_term_query/mod.rs @@ -0,0 +1,5 @@ +mod multi_term_query; +mod multi_term_weight; + +pub use self::multi_term_query::MultiTermQuery; +pub use self::multi_term_weight::MultiTermWeight; \ No newline at end of file diff --git a/src/query/multi_term_query/multi_term_query.rs b/src/query/multi_term_query/multi_term_query.rs new file mode 100644 index 000000000..7b207d64f --- /dev/null +++ b/src/query/multi_term_query/multi_term_query.rs @@ -0,0 +1,77 @@ +use Result; +use query::Weight; +use std::any::Any; +use schema::Term; +use query::MultiTermWeight; +use query::Query; +use core::searcher::Searcher; +use query::occur::Occur; +use query::occur_filter::OccurFilter; +use query::term_query::TermQuery; +use postings::SegmentPostingsOption; + + +/// Query involving one or more terms. 
+#[derive(Eq, Clone, PartialEq, Debug)] +pub struct MultiTermQuery { + // TODO need a better Debug + occur_terms: Vec<(Occur, Term)>, +} + +impl MultiTermQuery { + /// Accessor for the number of terms + pub fn num_terms(&self) -> usize { + self.occur_terms.len() + } + + /// Same as `weight()`, except that rather than a boxed trait, + /// `specialized_weight` returns a specific type of the weight, allowing for + /// compile-time optimization. + pub fn specialized_weight(&self, searcher: &Searcher) -> MultiTermWeight { + let term_queries: Vec = self.occur_terms + .iter() + .map(|&(_, ref term)| TermQuery::from(term.clone())) + .collect(); + let occurs: Vec = self.occur_terms + .iter() + .map(|&(occur, _)| occur.clone()) + .collect(); + let occur_filter = OccurFilter::new(&occurs); + let weights = term_queries.iter() + .map(|term_query| { + let mut term_weight = term_query.specialized_weight(searcher); + term_weight.segment_postings_options = SegmentPostingsOption::FreqAndPositions; + term_weight + }) + .collect(); + MultiTermWeight::new(weights, occur_filter) + } +} + + + +impl Query for MultiTermQuery { + fn as_any(&self) -> &Any { + self + } + + fn weight(&self, searcher: &Searcher) -> Result> { + Ok(box self.specialized_weight(searcher)) + } +} + + +impl From> for MultiTermQuery { + fn from(occur_terms: Vec<(Occur, Term)>) -> MultiTermQuery { + MultiTermQuery { occur_terms: occur_terms } + } +} + +impl From> for MultiTermQuery { + fn from(terms: Vec) -> MultiTermQuery { + let should_terms: Vec<(Occur, Term)> = terms.into_iter() + .map(|term| (Occur::Should, term)) + .collect(); + MultiTermQuery::from(should_terms) + } +} \ No newline at end of file diff --git a/src/query/multi_term_query/multi_term_weight.rs b/src/query/multi_term_query/multi_term_weight.rs new file mode 100644 index 000000000..17e58d877 --- /dev/null +++ b/src/query/multi_term_query/multi_term_weight.rs @@ -0,0 +1,45 @@ +use Result; +use query::Weight; +use core::SegmentReader; +use 
query::Scorer; +use query::occur_filter::OccurFilter; +use postings::SegmentPostings; +use query::term_query::{TermWeight, TermScorer}; +use query::boolean_query::BooleanScorer; + +/// Weight object associated to a [`MultiTermQuery`](./struct.MultiTermQuery.html). +pub struct MultiTermWeight { + weights: Vec, + occur_filter: OccurFilter, +} + +impl MultiTermWeight { + /// MultiTermWeight constructor. + /// The `OccurFilter` is tied with the weights order. + pub fn new(weights: Vec, occur_filter: OccurFilter) -> MultiTermWeight { + MultiTermWeight { + weights: weights, + occur_filter: occur_filter, + } + } + + /// Same as `scorer()`, except that rather than a boxed trait, + /// `specialized_scorer` returns a specific type of the scorer, allowing for + /// compile-time optimization. + pub fn specialized_scorer<'a>(&'a self, + reader: &'a SegmentReader) + -> Result>>> { + let mut term_scorers: Vec> = Vec::new(); + for term_weight in &self.weights { + let term_scorer = try!(term_weight.specialized_scorer(reader)); + term_scorers.push(term_scorer); + } + Ok(BooleanScorer::new(term_scorers, self.occur_filter)) + } +} + +impl Weight for MultiTermWeight { + fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result> { + Ok(box try!(self.specialized_scorer(reader))) + } +} diff --git a/src/query/occur.rs b/src/query/occur.rs index 86bade98d..1f42b4c63 100644 --- a/src/query/occur.rs +++ b/src/query/occur.rs @@ -2,9 +2,9 @@ /// should be present or must not be present. #[derive(Debug, Clone, Copy, Eq, PartialEq)] pub enum Occur { - /// The term should be present in the document. - /// Document without the term will be considered - /// in scoring as well. + /// For a given document to be considered for scoring, + /// at least one of the terms with the Should or the Must + /// Occur constraint must be within the document. Should, /// Document without the term are excluded from the search.
Must, diff --git a/src/query/occur_filter.rs b/src/query/occur_filter.rs new file mode 100644 index 000000000..53280fa6c --- /dev/null +++ b/src/query/occur_filter.rs @@ -0,0 +1,44 @@ +use query::Occur; + + +/// An OccurFilter represents a filter over a bitset of +/// at most 64 elements. +/// +/// It wraps some simple bitmask to compute the filter +/// rapidly. +#[derive(Clone, Copy)] +pub struct OccurFilter { + and_mask: u64, + result: u64, +} + +impl OccurFilter { + + /// Returns true if the bitset is matching the occur list. + pub fn accept(&self, ord_set: u64) -> bool { + (self.and_mask & ord_set) == self.result + } + + /// Builds an `OccurFilter` from a list of `Occur`. + pub fn new(occurs: &[Occur]) -> OccurFilter { + let mut and_mask = 0u64; + let mut result = 0u64; + for (i, occur) in occurs.iter().enumerate() { + let shift = 1 << i; + match *occur { + Occur::Must => { + and_mask |= shift; + result |= shift; + }, + Occur::MustNot => { + and_mask |= shift; + }, + Occur::Should => {}, + } + } + OccurFilter { + and_mask: and_mask, + result: result + } + } +} diff --git a/src/query/phrase_query.rs b/src/query/phrase_query.rs deleted file mode 100644 index a0ce11748..000000000 --- a/src/query/phrase_query.rs +++ /dev/null @@ -1,86 +0,0 @@ -use schema::Term; -use query::Query; -use common::TimerTree; -use common::OpenTimer; -use std::io; -use core::searcher::Searcher; -use collector::Collector; -use core::searcher::SegmentLocalId; -use core::SegmentReader; -use postings::Postings; -use postings::SegmentPostings; -use postings::intersection; - -pub struct PhraseQuery { - terms: Vec, -} - -impl Query for PhraseQuery { - - fn search(&self, searcher: &Searcher, collector: &mut C) -> io::Result { - let mut timer_tree = TimerTree::default(); - { - let mut search_timer = timer_tree.open("search"); - for (segment_ord, segment_reader) in searcher.segments().iter().enumerate() { - let mut segment_search_timer = search_timer.open("segment_search"); - { - let _ =
segment_search_timer.open("set_segment"); - try!(collector.set_segment(segment_ord as SegmentLocalId, &segment_reader)); - } - let mut postings = self.search_segment(segment_reader, segment_search_timer.open("get_postings")); - { - let _collection_timer = segment_search_timer.open("collection"); - while postings.next() { - collector.collect(postings.doc()); - } - } - } - } - Ok(timer_tree) - } -} - -impl PhraseQuery { - pub fn new(terms: Vec) -> PhraseQuery { - PhraseQuery { - terms: terms, - } - } - - fn search_segment<'a, 'b>(&'b self, reader: &'b SegmentReader, mut timer: OpenTimer<'a>) -> Box { - if self.terms.len() == 1 { - match reader.get_term(&self.terms[0]) { - Some(term_info) => { - let postings: SegmentPostings<'b> = reader.read_postings(&term_info); - Box::new(postings) - }, - None => { - Box::new(SegmentPostings::empty()) - }, - } - } else { - let mut segment_postings: Vec = Vec::new(); - { - let mut decode_timer = timer.open("decode_all"); - for term in self.terms.iter() { - match reader.get_term(term) { - Some(term_info) => { - let _decode_one_timer = decode_timer.open("decode_one"); - let segment_posting = reader.read_postings_with_positions(&term_info); - segment_postings.push(segment_posting); - } - None => { - // currently this is a strict intersection. 
- return Box::new(SegmentPostings::empty()); - } - } - } - } - Box::new(intersection(segment_postings)) - } - } -} - - - - diff --git a/src/query/phrase_query/mod.rs b/src/query/phrase_query/mod.rs new file mode 100644 index 000000000..0500b8257 --- /dev/null +++ b/src/query/phrase_query/mod.rs @@ -0,0 +1,70 @@ +mod phrase_query; +mod phrase_weight; +mod phrase_scorer; + +pub use self::phrase_query::PhraseQuery; +pub use self::phrase_weight::PhraseWeight; +pub use self::phrase_scorer::PhraseScorer; + + +#[cfg(test)] +mod tests { + + use super::*; + use query::Query; + use core::Index; + use schema::FieldValue; + use schema::{Document, Term, SchemaBuilder, TEXT}; + use collector::tests::TestCollector; + + #[test] + pub fn test_phrase_query() { + + let mut schema_builder = SchemaBuilder::default(); + let text_field = schema_builder.add_text_field("text", TEXT); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + { + let mut index_writer = index.writer_with_num_threads(1, 40_000_000).unwrap(); + { // 0 + let doc = doc!(text_field=>"b b b d c g c"); + index_writer.add_document(doc).unwrap(); + } + { // 1 + let doc = doc!(text_field=>"a b b d c g c"); + index_writer.add_document(doc).unwrap(); + } + { // 2 + let doc = doc!(text_field=>"a b a b c"); + index_writer.add_document(doc).unwrap(); + } + { // 3 + let doc = doc!(text_field=>"c a b a d ga a"); + index_writer.add_document(doc).unwrap(); + } + { // 4 + let doc = doc!(text_field=>"a b c"); + index_writer.add_document(doc).unwrap(); + } + assert!(index_writer.commit().is_ok()); + } + + let searcher = index.searcher(); + let test_query = |texts: Vec<&str>| { + let mut test_collector = TestCollector::default(); + let terms: Vec = texts + .iter() + .map(|text| Term::from_field_text(text_field, text)) + .collect(); + let phrase_query = PhraseQuery::from(terms); + phrase_query.search(&*searcher, &mut test_collector).expect("search should succeed"); + test_collector.docs() + }; + 
assert_eq!(test_query(vec!("a", "b", "c")), vec!(2, 4)); + assert_eq!(test_query(vec!("a", "b")), vec!(1, 2, 3, 4)); + assert_eq!(test_query(vec!("b", "b")), vec!(0, 1)); + assert_eq!(test_query(vec!("g", "ewrwer")), vec!()); + assert_eq!(test_query(vec!("g", "a")), vec!()); + } + +} diff --git a/src/query/phrase_query/phrase_query.rs b/src/query/phrase_query/phrase_query.rs new file mode 100644 index 000000000..d44bdceb0 --- /dev/null +++ b/src/query/phrase_query/phrase_query.rs @@ -0,0 +1,54 @@ +use schema::Term; +use query::Query; +use core::searcher::Searcher; +use super::PhraseWeight; +use std::any::Any; +use query::Weight; +use Result; + + +/// `PhraseQuery` matches a specific sequence of words. +/// For instance the phrase query for `"part time"` will match +/// the sentence +/// +/// **Alan just got a part time job.** +/// +/// On the other hand it will not match the sentence. +/// +/// **This is my favorite part of the job.** +/// +/// Using a `PhraseQuery` on a field requires positions +/// to be indexed for this field. +/// +#[derive(Debug)] +pub struct PhraseQuery { + phrase_terms: Vec, +} + +impl Query for PhraseQuery { + + + /// Used to make it possible to cast `Box<Query>` + /// into a specific type. This is mostly useful for unit tests. + fn as_any(&self) -> &Any { + self + } + + /// Create the weight associated to a query. + /// + /// See [Weight](./trait.Weight.html).
+ fn weight(&self, _searcher: &Searcher) -> Result> { + Ok(box PhraseWeight::from(self.phrase_terms.clone())) + } + +} + + +impl From> for PhraseQuery { + fn from(phrase_terms: Vec) -> PhraseQuery { + assert!(phrase_terms.len() > 1); + PhraseQuery { + phrase_terms: phrase_terms, + } + } +} diff --git a/src/query/phrase_query/phrase_scorer.rs b/src/query/phrase_query/phrase_scorer.rs new file mode 100644 index 000000000..d2a6a645f --- /dev/null +++ b/src/query/phrase_query/phrase_scorer.rs @@ -0,0 +1,81 @@ +use query::Scorer; +use DocSet; +use postings::SegmentPostings; +use postings::Postings; +use postings::IntersectionDocSet; +use DocId; + +pub struct PhraseScorer<'a> { + pub intersection_docset: IntersectionDocSet>, +} + + +impl<'a> PhraseScorer<'a> { + fn phrase_match(&self) -> bool { + let mut positions_arr: Vec<&[u32]> = self.intersection_docset + .docsets() + .iter() + .map(|posting| { + posting.positions() + }) + .collect(); + + let num_postings = positions_arr.len() as u32; + + let mut ord = 1u32; + let mut pos_candidate = positions_arr[0][0]; + positions_arr[0] = &(positions_arr[0])[1..]; + let mut count_matching = 1; + + 'outer: loop { + let target = pos_candidate + ord; + let positions = positions_arr[ord as usize]; + for i in 0..positions.len() { + let pos_i = positions[i]; + if pos_i < target { + continue; + } + if pos_i == target { + count_matching += 1; + if count_matching == num_postings { + return true; + } + } + else if pos_i > target { + count_matching = 1; + pos_candidate = positions[i] - ord; + positions_arr[ord as usize] = &(positions_arr[ord as usize])[(i+1)..]; + } + ord += 1; + if ord == num_postings { + ord = 0; + } + continue 'outer; + } + return false; + } + } +} + +impl<'a> DocSet for PhraseScorer<'a> { + fn advance(&mut self,) -> bool { + while self.intersection_docset.advance() { + if self.phrase_match() { + return true; + } + } + false + } + + fn doc(&self,) -> DocId { + self.intersection_docset.doc() + } +} + + +impl<'a> Scorer for 
PhraseScorer<'a> { + fn score(&self,) -> f32 { + 1f32 + } + +} diff --git a/src/query/phrase_query/phrase_weight.rs b/src/query/phrase_query/phrase_weight.rs new file mode 100644 index 000000000..d2a384183 --- /dev/null +++ b/src/query/phrase_query/phrase_weight.rs @@ -0,0 +1,39 @@ +use query::Weight; +use query::Scorer; +use schema::Term; +use postings::SegmentPostingsOption; +use core::SegmentReader; +use super::PhraseScorer; +use postings::IntersectionDocSet; +use query::EmptyScorer; +use Result; + +pub struct PhraseWeight { + phrase_terms: Vec, +} + +impl From> for PhraseWeight { + fn from(phrase_terms: Vec) -> PhraseWeight { + PhraseWeight { + phrase_terms: phrase_terms + } + } +} + +impl Weight for PhraseWeight { + fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result> { + let mut term_postings_list = Vec::new(); + for term in &self.phrase_terms { + let term_postings_option = reader.read_postings(term, SegmentPostingsOption::FreqAndPositions); + if let Some(term_postings) = term_postings_option { + term_postings_list.push(term_postings); + } + else { + return Ok(box EmptyScorer); + } + } + Ok(box PhraseScorer { + intersection_docset: IntersectionDocSet::from(term_postings_list), + }) + } +} diff --git a/src/query/query.rs b/src/query/query.rs index 5ef8e5823..f4bcfa3c2 100644 --- a/src/query/query.rs +++ b/src/query/query.rs @@ -2,25 +2,84 @@ use Result; use collector::Collector; use core::searcher::Searcher; use common::TimerTree; -use DocAddress; -use query::Explanation; +use SegmentLocalId; +use super::Weight; +use std::fmt; +use std::any::Any; -/// Queries represent the query of the user, and are in charge -/// of the logic defining the set of documents that should be -/// sent to the collector, as well as the way to score the -/// documents. 
-pub trait Query { +/// The `Query` trait is in charge of defining: +/// +/// - a set of documents +/// - a way to score these documents +/// +/// When performing a [search](#method.search), these documents will then +/// be pushed to a [Collector](../collector/trait.Collector.html), +/// which will in turn be in charge of deciding what to do with them. +/// +/// Concretely, this scored docset is represented by the +/// [`Scorer`](./trait.Scorer.html) trait. +/// +/// Because our index is actually split into segments, the +/// query does not directly create a `DocSet` object. +/// Instead, the query creates a [`Weight`](./trait.Weight.html) +/// object for a given searcher. +/// +/// The weight object, in turn, makes it possible to create +/// a scorer for a specific [`SegmentReader`](../struct.SegmentReader.html). +/// +/// So to sum it up : +/// - a `Query` is a recipe to define a set of documents as well as the way to score them. +/// - a `Weight` is this recipe tied to a specific `Searcher`. It may for instance +/// hold statistics about the different terms of the query. It is created by the query. +/// - a `Scorer` is a cursor over the set of matching documents, for a specific +/// [`SegmentReader`](../struct.SegmentReader.html). It is created by the [`Weight`](./trait.Weight.html). +/// +/// When implementing a new type of `Query`, it is normal to implement a +/// dedicated `Query`, `Weight` and `Scorer`. +pub trait Query: fmt::Debug { - /// Perform the search operation - fn search( + /// Used to make it possible to cast `Box<Query>` + /// into a specific type. This is mostly useful for unit tests. + fn as_any(&self) -> &Any; + + /// Create the weight associated to a query. + /// + /// See [Weight](./trait.Weight.html). + fn weight(&self, searcher: &Searcher) -> Result>; + + /// Search works as follows : + /// + /// First the weight object associated to the query is created.
+ /// + /// Then, the query loops over the segments and for each segment : + /// - setup the collector and informs it that the segment being processed has changed. + /// - creates a `Scorer` object associated for this segment + /// - iterate throw the matched documents and push them to the collector. + /// + fn search( &self, searcher: &Searcher, - collector: &mut C) -> Result; + collector: &mut Collector) -> Result { + + let mut timer_tree = TimerTree::default(); + let weight = try!(self.weight(searcher)); - /// Explain the score of a specific document - fn explain( - &self, - searcher: &Searcher, - doc_address: &DocAddress) -> Result; + { + let mut search_timer = timer_tree.open("search"); + for (segment_ord, segment_reader) in searcher.segment_readers().iter().enumerate() { + let mut segment_search_timer = search_timer.open("segment_search"); + { + let _ = segment_search_timer.open("set_segment"); + try!(collector.set_segment(segment_ord as SegmentLocalId, &segment_reader)); + } + let mut scorer = try!(weight.scorer(segment_reader)); + { + let _collection_timer = segment_search_timer.open("collection"); + scorer.collect(collector); + } + } + } + Ok(timer_tree) + } } diff --git a/src/query/query_parser.rs b/src/query/query_parser.rs index 89876825d..3d0ee87ec 100644 --- a/src/query/query_parser.rs +++ b/src/query/query_parser.rs @@ -1,14 +1,8 @@ -use Result as tantivy_Error; use combine::*; -use collector::Collector; -use core::searcher::Searcher; -use common::TimerTree; use query::{Query, MultiTermQuery}; use schema::{Schema, FieldType, Term, Field}; use analyzer::SimpleTokenizer; use analyzer::StreamingIterator; -use DocAddress; -use query::Explanation; use query::Occur; @@ -61,23 +55,6 @@ pub struct QueryParser { } -/// The `QueryParser` returns a `StandardQuery`. -#[derive(Eq, PartialEq, Debug)] -pub enum StandardQuery { - MultiTerm(MultiTermQuery), -} - -impl StandardQuery { - /// Number of terms involved in the query. 
- pub fn num_terms(&self,) -> usize { - match *self { - StandardQuery::MultiTerm(ref q) => { - q.num_terms() - } - } - } -} - impl QueryParser { /// Creates a `QueryParser` @@ -142,7 +119,7 @@ impl QueryParser { /// /// Implementing a lenient mode for this query parser is tracked /// in [Issue 5](https://github.com/fulmicoton/tantivy/issues/5) - pub fn parse_query(&self, query: &str) -> Result { + pub fn parse_query(&self, query: &str) -> Result, ParsingError> { match parser(query_language).parse(query.trim()) { Ok(literals) => { let mut terms_result: Vec<(Occur, Term)> = Vec::new(); @@ -154,9 +131,7 @@ impl QueryParser { .map(|term| (occur, term) )); } Ok( - StandardQuery::MultiTerm( - MultiTermQuery::from(terms_result) - ) + box MultiTermQuery::from(terms_result) ) } Err(_) => { @@ -167,24 +142,6 @@ impl QueryParser { } -impl Query for StandardQuery { - fn search(&self, searcher: &Searcher, collector: &mut C) -> tantivy_Error { - match *self { - StandardQuery::MultiTerm(ref q) => { - q.search(searcher, collector) - } - } - } - - fn explain( - &self, - searcher: &Searcher, - doc_address: &DocAddress) -> tantivy_Error { - match *self { - StandardQuery::MultiTerm(ref q) => q.explain(searcher, doc_address) - } - } -} fn compute_terms(field: Field, text: &str) -> Vec { @@ -324,6 +281,10 @@ mod test { assert!(query_parser.parse("f:@e!e").is_err()); } + // fn extract(query_parser: &QueryParser, q: &str) -> T { + // query_parser.parse_query(q).unwrap().as_any().downcast_ref::().unwrap(), + // } + #[test] pub fn test_query_parser() { let mut schema_builder = SchemaBuilder::default(); @@ -334,9 +295,9 @@ mod test { assert!(query_parser.parse_query("a:b").is_err()); { let terms = vec!(Term::from_field_text(title_field, "abctitle")); - let query = StandardQuery::MultiTerm(MultiTermQuery::from(terms)); + let query = MultiTermQuery::from(terms); assert_eq!( - query_parser.parse_query("title:abctitle").unwrap(), + 
*query_parser.parse_query("title:abctitle").unwrap().as_any().downcast_ref::().unwrap(), query ); } @@ -345,21 +306,21 @@ mod test { Term::from_field_text(text_field, "abctitle"), Term::from_field_text(author_field, "abctitle"), ); - let query = StandardQuery::MultiTerm(MultiTermQuery::from(terms)); + let query = MultiTermQuery::from(terms); assert_eq!( - query_parser.parse_query("abctitle").unwrap(), + *query_parser.parse_query("abctitle").unwrap().as_any().downcast_ref::().unwrap(), query ); } { let terms = vec!(Term::from_field_text(title_field, "abctitle")); - let query = StandardQuery::MultiTerm(MultiTermQuery::from(terms)); + let query = MultiTermQuery::from(terms); assert_eq!( - query_parser.parse_query("title:abctitle ").unwrap(), + *query_parser.parse_query("title:abctitle ").unwrap().as_any().downcast_ref::().unwrap(), query ); assert_eq!( - query_parser.parse_query(" title:abctitle").unwrap(), + *query_parser.parse_query(" title:abctitle").unwrap().as_any().downcast_ref::().unwrap(), query ); } diff --git a/src/query/scorer.rs b/src/query/scorer.rs index 7e9f17424..ae8c33d66 100644 --- a/src/query/scorer.rs +++ b/src/query/scorer.rs @@ -1,13 +1,59 @@ use DocSet; +use DocId; +use Score; +use collector::Collector; +use std::ops::{Deref, DerefMut}; - -/// Scored `DocSet` +/// Scored set of documents matching a query within a specific segment. +/// +/// See [Query](./trait.Query.html). pub trait Scorer: DocSet { /// Returns the score. /// /// This method will perform a bit of computation and is not cached. - fn score(&self,) -> f32; + fn score(&self,) -> Score; + + /// Consumes the complete `DocSet` and + /// push the scored documents to the collector. 
+ fn collect(&mut self, collector: &mut Collector) { + while self.advance() { + collector.collect(self.doc(), self.score()); + } + } } +impl<'a> Scorer for Box { + fn score(&self,) -> Score { + self.deref().score() + } + + fn collect(&mut self, collector: &mut Collector) { + let scorer = self.deref_mut(); + while scorer.advance() { + collector.collect(scorer.doc(), scorer.score()); + } + } +} + +/// EmptyScorer is a dummy Scorer in which no document matches. +/// +/// It is useful for tests and handling edge cases. +pub struct EmptyScorer; + +impl DocSet for EmptyScorer { + fn advance(&mut self,) -> bool { + false + } + + fn doc(&self,) -> DocId { + DocId::max_value() + } +} + +impl Scorer for EmptyScorer { + fn score(&self,) -> Score { + 0f32 + } +} diff --git a/src/query/similarity.rs b/src/query/similarity.rs deleted file mode 100644 index 1ac371af7..000000000 --- a/src/query/similarity.rs +++ /dev/null @@ -1,19 +0,0 @@ -use Score; -use query::Explanation; -use query::MultiTermAccumulator; - -/// Similarity score -pub trait Similarity: MultiTermAccumulator { - - /// Compute and returns the similarity score, - /// - /// The results are not cached. - fn score(&self, ) -> Score; - - /// Explain the computation of this similarity given all of - /// terms information. - /// - /// `vals` is an array of `(term_ord, term_freq, field_norm)`. - /// Terms that are not present should not appear in the array. 
- fn explain(&self, vals: &[(usize, u32, u32)]) -> Explanation; -} diff --git a/src/query/similarity_explainer.rs b/src/query/similarity_explainer.rs deleted file mode 100644 index 996b778a7..000000000 --- a/src/query/similarity_explainer.rs +++ /dev/null @@ -1,50 +0,0 @@ -use Score; -use super::MultiTermAccumulator; -use super::Similarity; -use super::Explanation; - - -/// Wrapper over a similarity used to run `explain` -pub struct SimilarityExplainer { - scorer: TSimilarity, - vals: Vec<(usize, u32, u32)>, -} - -impl SimilarityExplainer { - /// Returns the underlying similary's score explanation. - pub fn explain_score(&self,) -> Explanation { - self.scorer.explain(&self.vals) - } -} - -impl From for SimilarityExplainer { - fn from(multi_term_scorer: TSimilarity) -> SimilarityExplainer { - SimilarityExplainer { - scorer: multi_term_scorer, - vals: Vec::new(), - } - } -} - -impl MultiTermAccumulator for SimilarityExplainer { - fn update(&mut self, term_ord: usize, term_freq: u32, fieldnorm: u32) { - self.vals.push((term_ord, term_freq, fieldnorm)); - self.scorer.update(term_ord, term_freq, fieldnorm); - } - - fn clear(&mut self,) { - self.vals.clear(); - self.scorer.clear(); - } -} - -impl Similarity for SimilarityExplainer { - - fn score(&self,) -> Score { - self.scorer.score() - } - - fn explain(&self, vals: &[(usize, u32, u32)]) -> Explanation { - self.scorer.explain(vals) - } -} diff --git a/src/query/term_query/mod.rs b/src/query/term_query/mod.rs new file mode 100644 index 000000000..e8be286c1 --- /dev/null +++ b/src/query/term_query/mod.rs @@ -0,0 +1,7 @@ +mod term_query; +mod term_weight; +mod term_scorer; + +pub use self::term_query::TermQuery; +pub use self::term_weight::TermWeight; +pub use self::term_scorer::TermScorer; diff --git a/src/query/term_query/term_query.rs b/src/query/term_query/term_query.rs new file mode 100644 index 000000000..3ec748c25 --- /dev/null +++ b/src/query/term_query/term_query.rs @@ -0,0 +1,58 @@ +use Term; +use Result; +use 
super::term_weight::TermWeight; +use query::Query; +use query::Weight; +use postings::SegmentPostingsOption; +use Searcher; +use std::any::Any; + +/// A Term query matches all of the documents +/// containing a specific term. +/// +/// The score associated is defined as +/// `idf` * sqrt(`term_freq` / `field norm`) +/// in which : +/// * idf - inverse document frequency. +/// * term_freq - number of occurrences of the term in the field +/// * field norm - number of tokens in the field. +#[derive(Debug)] +pub struct TermQuery { + term: Term, +} + +impl TermQuery { + + /// Returns a weight object. + /// + /// While `.weight(...)` returns a boxed trait object, + /// this method return a specific implementation. + /// This is useful for optimization purpose. + pub fn specialized_weight(&self, searcher: &Searcher) -> TermWeight { + TermWeight { + num_docs: searcher.num_docs(), + doc_freq: searcher.doc_freq(&self.term), + term: self.term.clone(), + segment_postings_options: SegmentPostingsOption::NoFreq, + } + } +} + +impl From for TermQuery { + fn from(term: Term) -> TermQuery { + TermQuery { + term: term + } + } +} + +impl Query for TermQuery { + fn as_any(&self) -> &Any { + self + } + + fn weight(&self, searcher: &Searcher) -> Result> { + Ok(box self.specialized_weight(searcher)) + } + +} diff --git a/src/query/term_query/term_scorer.rs b/src/query/term_query/term_scorer.rs new file mode 100644 index 000000000..c12b24174 --- /dev/null +++ b/src/query/term_query/term_scorer.rs @@ -0,0 +1,37 @@ +use Score; +use DocId; +use fastfield::U32FastFieldReader; +use postings::DocSet; +use query::Scorer; +use postings::Postings; + +pub struct TermScorer where TPostings: Postings { + pub idf: Score, + pub fieldnorm_reader: U32FastFieldReader, + pub postings: TPostings, +} + +impl TermScorer where TPostings: Postings { + pub fn postings(&self) -> &TPostings { + &self.postings + } +} + +impl DocSet for TermScorer where TPostings: Postings { + fn advance(&mut self,) -> bool { + 
self.postings.advance() + } + + fn doc(&self,) -> DocId { + self.postings.doc() + } +} + +impl Scorer for TermScorer where TPostings: Postings { + fn score(&self,) -> Score { + let doc = self.postings.doc(); + let field_norm = self.fieldnorm_reader.get(doc); + self.idf * (self.postings.term_freq() as f32 / field_norm as f32).sqrt() + } +} + diff --git a/src/query/term_query/term_weight.rs b/src/query/term_query/term_weight.rs new file mode 100644 index 000000000..9d7bac3ee --- /dev/null +++ b/src/query/term_query/term_weight.rs @@ -0,0 +1,56 @@ +use Term; +use query::Weight; +use core::SegmentReader; +use query::Scorer; +use postings::SegmentPostingsOption; +use postings::SegmentPostings; +use fastfield::U32FastFieldReader; +use super::term_scorer::TermScorer; +use Result; + +pub struct TermWeight { + pub num_docs: u32, + pub doc_freq: u32, + pub term: Term, + pub segment_postings_options: SegmentPostingsOption, +} + + +impl Weight for TermWeight { + + fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result> { + let specialized_scorer = try!(self.specialized_scorer(reader)); + Ok(box specialized_scorer) + } + +} + +impl TermWeight { + + fn idf(&self) -> f32 { + 1.0 + (self.num_docs as f32 / (self.doc_freq as f32 + 1.0)).ln() + } + + pub fn specialized_scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result>> { + let field = self.term.field(); + let fieldnorm_reader = try!(reader.get_fieldnorms_reader(field)); + Ok( + reader + .read_postings(&self.term, self.segment_postings_options) + .map(|segment_postings| + TermScorer { + idf: self.idf(), + fieldnorm_reader: fieldnorm_reader, + postings: segment_postings, + } + ) + .unwrap_or( + TermScorer { + idf: 1f32, + fieldnorm_reader: U32FastFieldReader::empty(), + postings: SegmentPostings::empty() + }) + ) + } + +} \ No newline at end of file diff --git a/src/query/tfidf.rs b/src/query/tfidf.rs deleted file mode 100644 index 749c4663e..000000000 --- a/src/query/tfidf.rs +++ /dev/null @@ -1,143 +0,0 @@ -use Score; 
-use super::MultiTermAccumulator; -use super::Explanation; -use super::Similarity; - - -/// `TfIdf` is the default pertinence score in tantivy. -/// -/// See [Tf-Idf in the global documentation](https://fulmicoton.gitbooks.io/tantivy-doc/content/tfidf.html) -#[derive(Clone)] -pub struct TfIdf { - coords: Vec, - idf: Vec, - score: f32, - num_fields: usize, - term_names: Option>, //< only here for explain -} - -impl MultiTermAccumulator for TfIdf { - - #[inline] - fn update(&mut self, term_ord: usize, term_freq: u32, fieldnorm: u32) { - assert!(term_freq != 0u32); - self.score += self.term_score(term_ord, term_freq, fieldnorm); - self.num_fields += 1; - } - - #[inline] - fn clear(&mut self,) { - self.score = 0f32; - self.num_fields = 0; - } -} - -impl TfIdf { - /// Constructor - /// * coords - Coords act as a boosting factor for queries - /// containing many terms. The coords must have a length - /// of `num_terms + 1` - /// * idf - idf value for each given term. `idf` must - /// have a length of `num_terms`. 
- pub fn new(coords: Vec, idf: Vec) -> TfIdf { - TfIdf { - coords: coords, - idf: idf, - score: 0f32, - num_fields: 0, - term_names: None, - } - } - - /// Compute the coord term - fn coord(&self,) -> f32 { - self.coords[self.num_fields] - } - - /// Set the term names for the explain function - pub fn set_term_names(&mut self, term_names: Vec) { - self.term_names = Some(term_names); - } - - /// Return the name for the ordinal `ord` - fn term_name(&self, ord: usize) -> String { - match self.term_names { - Some(ref term_names_vec) => term_names_vec[ord].clone(), - None => format!("Field({})", ord) - } - } - - #[inline] - fn term_score(&self, term_ord: usize, term_freq: u32, field_norm: u32) -> f32 { - (term_freq as f32 / field_norm as f32).sqrt() * self.idf[term_ord] - } -} - -impl Similarity for TfIdf { - - #[inline] - fn score(&self, ) -> Score { - self.score * self.coord() - } - - fn explain(&self, vals: &[(usize, u32, u32)]) -> Explanation { - let score = self.score(); - let mut explanation = Explanation::with_val(score); - let formula_components: Vec = vals.iter() - .map(|&(ord, _, _)| ord) - .map(|ord| format!("", self.term_name(ord))) - .collect(); - let formula = format!(" * ({})", formula_components.join(" + ")); - explanation.set_formula(&formula); - for &(ord, term_freq, field_norm) in vals { - let term_score = self.term_score(ord, term_freq, field_norm); - let term_explanation = explanation.add_child(&self.term_name(ord), term_score); - term_explanation.set_formula(" sqrt( / ) * "); - } - explanation - } -} - - - - -#[cfg(test)] -mod tests { - - use super::*; - use query::MultiTermAccumulator; - use query::Similarity; - - fn abs_diff(left: f32, right: f32) -> f32 { - (right - left).abs() - } - - #[test] - pub fn test_tfidf() { - let mut tfidf = TfIdf::new(vec!(0f32, 1f32, 2f32), vec!(1f32, 4f32)); - { - tfidf.update(0, 1, 1); - assert!(abs_diff(tfidf.score(), 1f32) < 0.001f32); - tfidf.clear(); - } - { - tfidf.update(1, 1, 1); - assert_eq!(tfidf.score(), 
4f32); - tfidf.clear(); - } - { - tfidf.update(0, 2, 1); - assert!(abs_diff(tfidf.score(), 1.4142135) < 0.001f32); - tfidf.clear(); - } - { - tfidf.update(0, 1, 1); - tfidf.update(1, 1, 1); - assert_eq!(tfidf.score(), 10f32); - tfidf.clear(); - } - - - } - -} \ No newline at end of file diff --git a/src/query/weight.rs b/src/query/weight.rs new file mode 100644 index 000000000..db583a3e4 --- /dev/null +++ b/src/query/weight.rs @@ -0,0 +1,16 @@ +use super::Scorer; +use Result; +use core::SegmentReader; + + +/// A Weight is the specialization of a Query +/// for a given set of segments. +/// +/// See [Query](./trait.Query.html). +pub trait Weight { + + /// Returns the scorer for the given segment. + /// See [Query](./trait.Query.html). + fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result>; + +}