diff --git a/src/fastfield/mod.rs b/src/fastfield/mod.rs index d1e1c8b6d..39d68ebfb 100644 --- a/src/fastfield/mod.rs +++ b/src/fastfield/mod.rs @@ -39,9 +39,7 @@ fn compute_num_bits(amplitude: u32) -> u8 { mod tests { use super::compute_num_bits; - use super::U32FastFieldsReader; - use super::U32FastFieldsWriter; - use super::FastFieldSerializer; + use super::*; use schema::Field; use std::path::Path; use directory::{Directory, WritePtr, RAMDirectory}; @@ -81,6 +79,17 @@ mod tests { doc.add_u32(field, value); fast_field_writers.add_document(&doc); } + + #[test] + pub fn test_fastfield() { + let test_fastfield = U32FastFieldReader::from(vec!(100,200,300)); + println!("{}", test_fastfield.get(0)); + println!("{}", test_fastfield.get(1)); + println!("{}", test_fastfield.get(2)); + assert_eq!(test_fastfield.get(0), 100); + assert_eq!(test_fastfield.get(1), 200); + assert_eq!(test_fastfield.get(2), 300); + } #[test] fn test_intfastfield_small() { diff --git a/src/fastfield/reader.rs b/src/fastfield/reader.rs index f57aa9362..18046025a 100644 --- a/src/fastfield/reader.rs +++ b/src/fastfield/reader.rs @@ -5,8 +5,12 @@ use std::ops::Deref; use directory::ReadOnlySource; use common::BinarySerializable; use DocId; -use schema::Field; - +use schema::{Field, SchemaBuilder}; +use std::path::Path; +use schema::FAST; +use directory::{WritePtr, RAMDirectory, Directory}; +use fastfield::FastFieldSerializer; +use fastfield::U32FastFieldsWriter; use super::compute_num_bits; pub struct U32FastFieldReader { @@ -62,6 +66,31 @@ impl U32FastFieldReader { } } + +impl From> for U32FastFieldReader { + fn from(vals: Vec) -> U32FastFieldReader { + let mut schema_builder = SchemaBuilder::default(); + let field = schema_builder.add_u32_field("field", FAST); + let schema = schema_builder.build(); + let path = Path::new("test"); + let mut directory: RAMDirectory = RAMDirectory::create(); + { + let write: WritePtr = directory.open_write(Path::new("test")).unwrap(); + let mut serializer = FastFieldSerializer::new(write).unwrap(); + let mut fast_field_writers = U32FastFieldsWriter::from_schema(&schema); + for val in vals { + let mut fast_field_writer = fast_field_writers.get_field_writer(field).unwrap(); + fast_field_writer.add_val(val); + } + fast_field_writers.serialize(&mut serializer).unwrap(); + serializer.close().unwrap(); + } + let source = directory.open_read(&path).unwrap(); + let fast_field_readers = U32FastFieldsReader::open(source).unwrap(); + fast_field_readers.get_field(field).unwrap() + } +} + pub struct U32FastFieldsReader { source: ReadOnlySource, field_offsets: HashMap, diff --git a/src/postings/mod.rs b/src/postings/mod.rs index 374f08c33..2227353de 100644 --- a/src/postings/mod.rs +++ b/src/postings/mod.rs @@ -17,7 +17,9 @@ mod offset_postings; mod freq_handler; mod docset; mod segment_postings_option; +mod segment_postings_tester; +pub use self::segment_postings_tester::SegmentPostingsTester; pub use self::docset::{SkipResult, DocSet}; pub use self::offset_postings::OffsetPostings; pub use self::recorder::{Recorder, NothingRecorder, TermFrequencyRecorder, TFAndPositionRecorder}; diff --git a/src/postings/segment_postings.rs b/src/postings/segment_postings.rs index 8de872f90..ab12add7d 100644 --- a/src/postings/segment_postings.rs +++ b/src/postings/segment_postings.rs @@ -62,7 +62,6 @@ impl<'a> SegmentPostings<'a> { impl<'a> DocSet for SegmentPostings<'a> { - // goes to the next element. // next needs to be called a first time to point to the correct element. #[inline] diff --git a/src/postings/segment_postings_tester.rs b/src/postings/segment_postings_tester.rs new file mode 100644 index 000000000..b34e69003 --- /dev/null +++ b/src/postings/segment_postings_tester.rs @@ -0,0 +1,76 @@ +use super::FreqHandler; +use DocId; +use std::path::Path; +use super::SegmentPostings; +use super::serializer::PostingsSerializer; +use schema::{SchemaBuilder, STRING}; +use directory::{RAMDirectory, Directory}; +use schema::Term; + + +const EMPTY_POSITIONS: [DocId; 0] = [0u32; 0]; + +pub struct SegmentPostingsTester { + data: Vec, + len: u32, +} + +impl SegmentPostingsTester { + pub fn get(&self) -> SegmentPostings { + SegmentPostings::from_data(self.len, &self.data, FreqHandler::new_without_freq()) + } +} + +impl From> for SegmentPostingsTester { + + fn from(doc_ids: Vec) -> SegmentPostingsTester { + let mut directory = RAMDirectory::create(); + let mut schema_builder = SchemaBuilder::default(); + let field = schema_builder.add_text_field("text", STRING); + let schema = schema_builder.build(); + let mut postings_serializer = PostingsSerializer::new( + directory.open_write(Path::new("terms")).unwrap(), + directory.open_write(Path::new("postings")).unwrap(), + directory.open_write(Path::new("positions")).unwrap(), + schema + ).unwrap(); + let term = Term::from_field_text(field, "dummy"); + postings_serializer.new_term(&term, doc_ids.len() as u32); + for doc_id in &doc_ids { + postings_serializer.write_doc(*doc_id, 1u32, &EMPTY_POSITIONS); + } + postings_serializer.close_term(); + postings_serializer.close(); + let postings_data = directory.open_read(Path::new("postings")).unwrap(); + SegmentPostingsTester { + data: Vec::from(postings_data.as_slice()), + len: doc_ids.len() as u32, + } + } + +} + + + +#[cfg(test)] +mod tests { + + use super::*; + use postings::DocSet; + + #[test] + pub fn test_segment_postings_tester() { + let segment_postings_tester = SegmentPostingsTester::from(vec!(1,2,17,32)); + let mut postings = segment_postings_tester.get(); + assert!(postings.advance()); + assert_eq!(postings.doc(), 1); + assert!(postings.advance()); + assert_eq!(postings.doc(), 2); + assert!(postings.advance()); + assert_eq!(postings.doc(), 17); + assert!(postings.advance()); + assert_eq!(postings.doc(), 32); + assert!(!postings.advance()); + } + +} diff --git a/src/postings/serializer.rs b/src/postings/serializer.rs index 5fb7c9505..1cf1b5392 100644 --- a/src/postings/serializer.rs +++ b/src/postings/serializer.rs @@ -66,13 +66,16 @@ pub struct PostingsSerializer { impl PostingsSerializer { + + /// Open a new `PostingsSerializer` for the given segment - pub fn open(segment: &mut Segment) -> Result { - let terms_write = try!(segment.open_write(SegmentComponent::TERMS)); + pub fn new( + terms_write: WritePtr, + postings_write: WritePtr, + positions_write: WritePtr, + schema: Schema + ) -> Result { let terms_fst_builder = try!(FstMapBuilder::new(terms_write)); - let postings_write = try!(segment.open_write(SegmentComponent::POSTINGS)); - let positions_write = try!(segment.open_write(SegmentComponent::POSITIONS)); - let schema = segment.schema(); Ok(PostingsSerializer { terms_fst_builder: terms_fst_builder, postings_write: postings_write, @@ -91,6 +94,20 @@ impl PostingsSerializer { }) } + + /// Open a new `PostingsSerializer` for the given segment + pub fn open(segment: &mut Segment) -> Result { + let terms_write = try!(segment.open_write(SegmentComponent::TERMS)); + let postings_write = try!(segment.open_write(SegmentComponent::POSTINGS)); + let positions_write = try!(segment.open_write(SegmentComponent::POSITIONS)); + PostingsSerializer::new( + terms_write, + postings_write, + positions_write, + segment.schema() + ) + } + fn load_indexing_options(&mut self, field: Field) { let field_entry: &FieldEntry = self.schema.get_field_entry(field); self.text_indexing_options = match *field_entry.field_type() { diff --git a/src/query/boolean_query/boolean_scorer.rs b/src/query/boolean_query/boolean_scorer.rs index 559640230..91b31d192 100644 --- a/src/query/boolean_query/boolean_scorer.rs +++ b/src/query/boolean_query/boolean_scorer.rs @@ -5,46 +5,7 @@ use std::collections::BinaryHeap; use std::cmp::Ordering; use postings::DocSet; use query::OccurFilter; - - -struct ScoreCombiner { - coords: Vec, - num_fields: usize, - score: Score, -} - -impl ScoreCombiner { - - fn update(&mut self, score: Score) { - self.score += score; - self.num_fields += 1; - } - - fn clear(&mut self,) { - self.score = 0f32; - self.num_fields = 0; - } - - /// Compute the coord term - fn coord(&self,) -> f32 { - self.coords[self.num_fields] - } - - #[inline] - fn score(&self, ) -> Score { - self.score * self.coord() - } -} - -impl From> for ScoreCombiner { - fn from(coords: Vec) -> ScoreCombiner { - ScoreCombiner { - coords: coords, - num_fields: 0, - score: 0f32, - } - } -} +use query::boolean_query::ScoreCombiner; /// Each `HeapItem` represents the head of @@ -82,12 +43,13 @@ pub struct BooleanScorer { impl BooleanScorer { - pub fn new(postings: Vec, filter: OccurFilter) -> BooleanScorer { - let num_postings = postings.len(); - let query_coords: Vec = (0..num_postings + 1) - .map(|i| (i as Score) / (num_postings as Score)) - .collect(); - let score_combiner = ScoreCombiner::from(query_coords); + pub fn set_score_combiner(&mut self, score_combiner: ScoreCombiner) { + self.score_combiner = score_combiner; + } + + pub fn new(postings: Vec, + filter: OccurFilter) -> BooleanScorer { + let score_combiner = ScoreCombiner::default_for_num_scorers(postings.len()); let mut non_empty_postings: Vec = Vec::new(); for mut posting in postings { let non_empty = posting.advance(); @@ -131,10 +93,9 @@ impl BooleanScorer { let mut mutable_head = self.queue.peek_mut().unwrap(); let cur_postings = &mut self.postings[mutable_head.ord as usize]; if cur_postings.advance() { - mutable_head.doc = cur_postings.doc(); + mutable_head.doc = cur_postings.doc(); return; } - } self.queue.pop(); } @@ -188,3 +149,84 @@ impl Scorer for BooleanScorer { } } + + + +#[cfg(test)] +mod tests { + + use super::*; + use postings::{DocSet, VecPostings}; + use query::TfIdf; + use query::Scorer; + use query::OccurFilter; + use query::term_query::TermScorer; + use directory::Directory; + use directory::RAMDirectory; + use schema::Field; + use super::super::ScoreCombiner; + use std::path::Path; + use query::Occur; + use postings::SegmentPostingsTester; + use postings::Postings; + use fastfield::{U32FastFieldReader, U32FastFieldWriter, FastFieldSerializer}; + + + + fn abs_diff(left: f32, right: f32) -> f32 { + (right - left).abs() + } + + #[test] + pub fn test_boolean_scorer() { + let occurs = vec!(Occur::Should, Occur::Should); + let occur_filter = OccurFilter::new(&occurs); + + let left_fieldnorms = U32FastFieldReader::from(vec!(100,200,300)); + let left_tester = SegmentPostingsTester::from(vec!(1, 2, 3)); + let left = left_tester.get(); + let left_scorer = TermScorer { + idf: 1f32, + fieldnorm_reader: left_fieldnorms, + segment_postings: left, + }; + + let right_fieldnorms = U32FastFieldReader::from(vec!(15,25,35)); + let right_tester = SegmentPostingsTester::from(vec!(1, 3, 8)); + let right = right_tester.get(); + let mut right_scorer = TermScorer { + idf: 4f32, + fieldnorm_reader: right_fieldnorms, + segment_postings: right, + }; + let score_combiner = ScoreCombiner::from(vec!(0f32, 1f32, 2f32)); + let mut boolean_scorer = BooleanScorer::new(vec!(left_scorer, right_scorer), occur_filter); + boolean_scorer.set_score_combiner(score_combiner); + assert_eq!(boolean_scorer.next(), Some(1u32)); + assert!(abs_diff(boolean_scorer.score(), 1.7414213) < 0.001); + assert_eq!(boolean_scorer.next(), Some(2u32)); + assert!(abs_diff(boolean_scorer.score(), 0.057735026) < 0.001f32); + assert_eq!(boolean_scorer.next(), Some(3u32)); + assert_eq!(boolean_scorer.next(), Some(8u32)); + assert!(abs_diff(boolean_scorer.score(), 1.0327955) < 0.001f32); + assert!(!boolean_scorer.advance()); + } + + + #[test] + pub fn test_term_scorer() { + let left_fieldnorms = U32FastFieldReader::from(vec!(10, 4)); + assert_eq!(left_fieldnorms.get(0), 10); + assert_eq!(left_fieldnorms.get(1), 4); + let left_tester = SegmentPostingsTester::from(vec!(1)); + let left = left_tester.get(); + let mut left_scorer = TermScorer { + idf: 0.30685282, // 1f32, + fieldnorm_reader: left_fieldnorms, + segment_postings: left, + }; + left_scorer.advance(); + assert!(abs_diff(left_scorer.score(), 0.15342641) < 0.001f32); + } + +} diff --git a/src/query/boolean_query/mod.rs b/src/query/boolean_query/mod.rs index 9e7506c2e..3f19cb92e 100644 --- a/src/query/boolean_query/mod.rs +++ b/src/query/boolean_query/mod.rs @@ -2,7 +2,9 @@ mod boolean_clause; mod boolean_query; mod boolean_scorer; mod boolean_weight; +mod score_combiner; pub use self::boolean_query::BooleanQuery; pub use self::boolean_clause::BooleanClause; -pub use self::boolean_scorer::BooleanScorer; \ No newline at end of file +pub use self::boolean_scorer::BooleanScorer; +pub use self::score_combiner::ScoreCombiner; \ No newline at end of file diff --git a/src/query/boolean_query/score_combiner.rs b/src/query/boolean_query/score_combiner.rs new file mode 100644 index 000000000..204c57c23 --- /dev/null +++ b/src/query/boolean_query/score_combiner.rs @@ -0,0 +1,46 @@ +use Score; + +pub struct ScoreCombiner { + coords: Vec, + num_fields: usize, + score: Score, +} + +impl ScoreCombiner { + + pub fn update(&mut self, score: Score) { + self.score += score; + self.num_fields += 1; + } + + pub fn clear(&mut self,) { + self.score = 0f32; + self.num_fields = 0; + } + + /// Compute the coord term + fn coord(&self,) -> f32 { + self.coords[self.num_fields] + } + + pub fn score(&self, ) -> Score { + self.score * self.coord() + } + + pub fn default_for_num_scorers(num_scorers: usize) -> ScoreCombiner { + let query_coords: Vec = (0..num_scorers + 1) + .map(|i| (i as Score) / (num_scorers as Score)) + .collect(); + ScoreCombiner::from(query_coords) + } +} + +impl From> for ScoreCombiner { + fn from(coords: Vec) -> ScoreCombiner { + ScoreCombiner { + coords: coords, + num_fields: 0, + score: 0f32, + } + } +} \ No newline at end of file diff --git a/src/query/daat_multiterm_scorer.rs b/src/query/daat_multiterm_scorer.rs deleted file mode 100644 index 98f51f49f..000000000 --- a/src/query/daat_multiterm_scorer.rs +++ /dev/null @@ -1,253 +0,0 @@ -use DocId; -use postings::{Postings, DocSet}; -use std::cmp::Ordering; -use std::collections::BinaryHeap; -use query::MultiTermAccumulator; -use query::Similarity; -use fastfield::U32FastFieldReader; -use query::Occur; -use std::iter; -use query::OccurFilter; -use super::Scorer; -use Score; - -/// Each `HeapItem` represents the head of -/// a segment postings being merged. -/// -/// * `doc` - is the current doc id for the given segment postings -/// * `ord` - is the ordinal used to identify to which segment postings -/// this heap item belong to. -#[derive(Eq, PartialEq)] -struct HeapItem { - doc: DocId, - ord: u32, -} - -/// `HeapItem` are ordered by the document -impl PartialOrd for HeapItem { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for HeapItem { - fn cmp(&self, other:&Self) -> Ordering { - (other.doc).cmp(&self.doc) - } -} - - -/// Document-At-A-Time multi term scorer. -/// -/// The scorer merges multiple segment postings and pushes -/// term information to the score accumulator. -pub struct DAATMultiTermScorer { - fieldnorm_readers: Vec, - postings: Vec, - term_frequencies: Vec, - queue: BinaryHeap, - doc: DocId, - similarity: TAccumulator, - filter: OccurFilter, -} - -impl DAATMultiTermScorer { - - fn new_non_empty( - - fieldnorm_readers: Vec, - postings: Vec, - similarity: TAccumulator, - filter: OccurFilter - ) -> DAATMultiTermScorer { - let mut term_frequencies: Vec = iter::repeat(0u32).take(postings.len()).collect(); - let heap_items: Vec = postings - .iter() - .map(|posting| { - (posting.doc(), posting.term_freq()) - }) - .enumerate() - .map(|(ord, (doc, tf))| { - term_frequencies[ord] = tf; - HeapItem { - doc: doc, - ord: ord as u32 - } - }) - .collect(); - DAATMultiTermScorer { - fieldnorm_readers: fieldnorm_readers, - postings: postings, - term_frequencies: term_frequencies, - queue: BinaryHeap::from(heap_items), - doc: 0, - similarity: similarity, - filter: filter - } - } - - /// Constructor - pub fn new(postings_and_fieldnorms: Vec<(Occur, TPostings, U32FastFieldReader)>, similarity: TAccumulator) -> DAATMultiTermScorer { - let mut postings = Vec::new(); - let mut fieldnorm_readers = Vec::new(); - let mut occurs = Vec::new(); - for (occur, mut posting, fieldnorm_reader) in postings_and_fieldnorms { - if posting.advance() { - postings.push(posting); - fieldnorm_readers.push(fieldnorm_reader); - occurs.push(occur); - } - } - let filter = OccurFilter::new(&occurs); - DAATMultiTermScorer::new_non_empty(fieldnorm_readers, postings, similarity, filter) - } - - /// Returns the scorer - pub fn scorer(&self,) -> &TAccumulator { - &self.similarity - } - - /// Advances the head of our heap (the segment postings with the lowest doc) - /// It will also update the new current `DocId` as well as the term frequency - /// associated with the segment postings. - /// - /// After advancing the `SegmentPosting`, the postings is removed from the heap - /// if it has been entirely consumed, or pushed back into the heap. - /// - /// # Panics - /// This method will panic if the head `SegmentPostings` is not empty. - fn advance_head(&mut self,) { - { - let mut mutable_head = self.queue.peek_mut().unwrap(); - let cur_postings = &mut self.postings[mutable_head.ord as usize]; - if cur_postings.advance() { - let doc = cur_postings.doc(); - self.term_frequencies[mutable_head.ord as usize] = cur_postings.term_freq(); - mutable_head.doc = doc; - return; - } - - } - self.queue.pop(); - } - - /// Returns the field norm for the segment postings with the given ordinal, - /// and the given document. - fn get_field_norm(&self, ord:usize, doc:DocId) -> u32 { - self.fieldnorm_readers[ord].get(doc) - } - -} - -impl Scorer for DAATMultiTermScorer { - fn score(&self,) -> Score { - self.similarity.score() - } -} - -impl DocSet for DAATMultiTermScorer { - - fn advance(&mut self,) -> bool { - loop { - self.similarity.clear(); - let mut ord_bitset = 0u64; - match self.queue.peek() { - Some(heap_item) => { - self.doc = heap_item.doc; - let ord: usize = heap_item.ord as usize; - let fieldnorm = self.get_field_norm(ord, heap_item.doc); - let tf = self.term_frequencies[ord]; - self.similarity.update(ord, tf, fieldnorm); - ord_bitset |= 1 << ord; - } - None => { - return false; - } - } - self.advance_head(); - while let Some(&HeapItem {doc, ord}) = self.queue.peek() { - if doc == self.doc { - let peek_ord: usize = ord as usize; - let peek_tf = self.term_frequencies[peek_ord]; - let peek_fieldnorm = self.get_field_norm(peek_ord, doc); - self.similarity.update(peek_ord, peek_tf, peek_fieldnorm); - ord_bitset |= 1 << peek_ord; - } - else { - break; - } - self.advance_head(); - } - if self.filter.accept(ord_bitset) { - return true; - } - } - } - - fn doc(&self,) -> DocId { - self.doc - } -} - -#[cfg(test)] -mod tests { - - use super::*; - use postings::{DocSet, VecPostings}; - use query::TfIdf; - use query::Scorer; - use directory::Directory; - use directory::RAMDirectory; - use schema::Field; - use std::path::Path; - use query::Occur; - use fastfield::{U32FastFieldReader, U32FastFieldWriter, FastFieldSerializer}; - - - pub fn create_u32_fastfieldreader(field: Field, vals: Vec) -> U32FastFieldReader { - let mut u32_field_writer = U32FastFieldWriter::new(field); - for val in vals { - u32_field_writer.add_val(val); - } - let path = Path::new("some_path"); - let mut directory = RAMDirectory::create(); - { - let write = directory.open_write(&path).unwrap(); - let mut serializer = FastFieldSerializer::new(write).unwrap(); - u32_field_writer.serialize(&mut serializer).unwrap(); - serializer.close().unwrap(); - } - let read = directory.open_read(&path).unwrap(); - U32FastFieldReader::open(read).unwrap() - } - - fn abs_diff(left: f32, right: f32) -> f32 { - (right - left).abs() - } - - #[test] - pub fn test_daat_scorer() { - let left_fieldnorms = create_u32_fastfieldreader(Field(1), vec!(100,200,300)); - let right_fieldnorms = create_u32_fastfieldreader(Field(2), vec!(15,25,35)); - let left = VecPostings::from(vec!(1, 2, 3)); - let right = VecPostings::from(vec!(1, 3, 8)); - let tfidf = TfIdf::new(vec!(0f32, 1f32, 2f32), vec!(1f32, 4f32)); - let mut daat_scorer = DAATMultiTermScorer::new( - vec!( - (Occur::Should, left, left_fieldnorms), - (Occur::Should, right, right_fieldnorms), - ), - tfidf - ); - assert_eq!(daat_scorer.next(), Some(1u32)); - assert!(abs_diff(daat_scorer.score(), 2.182179f32) < 0.001); - assert_eq!(daat_scorer.next(), Some(2u32)); - assert!(abs_diff(daat_scorer.score(), 0.2236068) < 0.001f32); - assert_eq!(daat_scorer.next(), Some(3u32)); - assert_eq!(daat_scorer.next(), Some(8u32)); - assert!(abs_diff(daat_scorer.score(), 0.8944272f32) < 0.001f32); - assert!(!daat_scorer.advance()); - } - -} - diff --git a/src/query/mod.rs b/src/query/mod.rs index 161f30425..57bdeb1ba 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -13,7 +13,6 @@ mod query_parser; mod explanation; mod tfidf; mod occur; -mod daat_multiterm_scorer; mod similarity; mod weight; mod occur_filter; @@ -26,9 +25,6 @@ pub use self::empty_scorer::EmptyScorer; pub use self::occur_filter::OccurFilter; pub use self::similarity::Similarity; - -pub use self::daat_multiterm_scorer::DAATMultiTermScorer; - pub use self::boolean_query::BooleanQuery; pub use self::occur::Occur; pub use self::query::Query; diff --git a/src/query/term_query/term_scorer.rs b/src/query/term_query/term_scorer.rs index 6ad5c591d..e7c3bf644 100644 --- a/src/query/term_query/term_scorer.rs +++ b/src/query/term_query/term_scorer.rs @@ -27,6 +27,7 @@ impl<'a> Scorer for TermScorer<'a> { fn score(&self,) -> Score { let doc = self.segment_postings.doc(); let field_norm = self.fieldnorm_reader.get(doc); - self.idf * (self.segment_postings.term_freq() as f32 / field_norm as f32).sqrt() + self.idf * (self.segment_postings.term_freq() as f32 / field_norm as f32).sqrt() } } +