diff --git a/Cargo.toml b/Cargo.toml
index cc0753c7f..6befcdf3f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -9,7 +9,7 @@ byteorder = "0.4"
 memmap = "0.2"
 lazy_static = "0.1.*"
 regex = "0.1"
-fst = { git = "https://github.com/BurntSushi/fst.git", rev = "e6c7eec6" }
+fst = "0.1.31"
 atomicwrites = "0.0.14"
 tempfile = "2.0.0"
 rustc-serialize = "0.3.16"
diff --git a/src/core/searcher.rs b/src/core/searcher.rs
index a22e59754..dfa0b1069 100644
--- a/src/core/searcher.rs
+++ b/src/core/searcher.rs
@@ -6,6 +6,7 @@
 use collector::Collector;
 use std::io;
 use common::TimerTree;
 use query::Query;
+use DocId;
 use DocAddress;
 use schema::Term;
@@ -22,6 +23,13 @@ impl Searcher {
         let segment_reader = &self.segments[segment_local_id as usize];
         segment_reader.doc(doc_id)
     }
+
+    pub fn num_docs(&self,) -> DocId {
+        self.segments
+            .iter()
+            .map(|segment_reader| segment_reader.num_docs())
+            .fold(0u32, |acc, val| acc + val)
+    }
 
     pub fn doc_freq(&self, term: &Term) -> u32 {
         self.segments
diff --git a/src/core/segment_reader.rs b/src/core/segment_reader.rs
index 52e6fa0b1..dee94ba37 100644
--- a/src/core/segment_reader.rs
+++ b/src/core/segment_reader.rs
@@ -42,6 +42,10 @@ impl SegmentReader {
     pub fn max_doc(&self) -> DocId {
         self.segment_info.max_doc
     }
+
+    pub fn num_docs(&self) -> DocId {
+        self.segment_info.max_doc
+    }
 
     pub fn get_fast_field_reader(&self, field: Field) -> io::Result<U32FastFieldReader> {
         let field_entry = self.schema.get_field_entry(field);
diff --git a/src/core/segment_writer.rs b/src/core/segment_writer.rs
index c5ecb865b..5ddd9d2ad 100644
--- a/src/core/segment_writer.rs
+++ b/src/core/segment_writer.rs
@@ -35,10 +35,6 @@ fn create_fieldnorms_writer(schema: &Schema) -> U32FastFieldsWriter {
     U32FastFieldsWriter::new(u32_fields)
 }
 
-fn compute_field_norm(num_tokens: usize) -> u32 {
-    ((350f32 / (1f32 + num_tokens as f32).sqrt()) as u32)
-}
-
 impl SegmentWriter {
 
@@ -109,11 +105,10 @@ impl SegmentWriter {
                 }
             }
         }
-        let field_norm = compute_field_norm(num_tokens);
         self.fieldnorms_writer
             .get_field_writer(field)
             .map(|field_norms_writer| {
-                field_norms_writer.set_val(doc_id, field_norm)
+                field_norms_writer.set_val(doc_id, num_tokens as u32)
             });
     }
diff --git a/src/core/writer.rs b/src/core/writer.rs
index 6119228be..b65241492 100644
--- a/src/core/writer.rs
+++ b/src/core/writer.rs
@@ -54,7 +54,7 @@ impl IndexWriter {
             let mut segment_writer = SegmentWriter::for_segment(segment.clone(), &schema_clone).unwrap();
             segment_writer.add_document(&*doc, &schema_clone).unwrap();
-            for _ in 0..(225_000 - 1) {
+            for _ in 0..100_000 {
                 {
                     let queue = queue_output_clone.lock().unwrap();
                     match queue.recv() {
diff --git a/src/lib.rs b/src/lib.rs
index de8df5453..a0b8b31b0 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -214,9 +214,9 @@ mod tests {
         let searcher = index.searcher().unwrap();
         let segment_reader: &SegmentReader = searcher.segments().iter().next().unwrap();
         let fieldnorms_reader = segment_reader.get_fieldnorms_reader(text_field).unwrap();
-        assert_eq!(fieldnorms_reader.get(0), 175);
+        assert_eq!(fieldnorms_reader.get(0), 3);
         assert_eq!(fieldnorms_reader.get(1), 0);
-        assert_eq!(fieldnorms_reader.get(2), 202);
+        assert_eq!(fieldnorms_reader.get(2), 2);
     }
 
 }
diff --git a/src/postings/union_postings.rs b/src/postings/union_postings.rs
index 35329df66..c1b248315 100644
--- a/src/postings/union_postings.rs
+++ b/src/postings/union_postings.rs
@@ -141,7 +141,11 @@ mod tests {
         serializer.close().unwrap();
         U32FastFieldReader::open(ReadOnlySource::Anonymous(data.copy_vec())).unwrap()
     }
-    
+
+    fn abs_diff(left: f32, right: f32) -> f32 {
+        (right - left).abs()
+    }
+
     #[test]
     pub fn test_union_postings() {
         let left_fieldnorms = create_u32_fastfieldreader(Field(1), vec!(100,200,300));
@@ -156,14 +160,14 @@
         );
         assert!(union.next());
         assert_eq!(union.doc(), 1);
-        assert_eq!(union.score(), 210f32);
+        assert!(abs_diff(union.score(), 2.182179f32) < 0.001);
         assert!(union.next());
         assert_eq!(union.doc(), 2);
-        assert_eq!(union.score(), 20f32);
+        assert!(abs_diff(union.score(), 0.2236068) < 0.001f32);
         assert!(union.next());
         assert_eq!(union.doc(), 3);
         assert!(union.next());
-        assert_eq!(union.score(), 80f32);
+        assert!(abs_diff(union.score(), 0.8944272f32) < 0.001f32);
         assert_eq!(union.doc(), 8);
         assert!(!union.next());
     }
diff --git a/src/query/multi_term_query.rs b/src/query/multi_term_query.rs
index 3f2c5b9b5..df2790d9d 100644
--- a/src/query/multi_term_query.rs
+++ b/src/query/multi_term_query.rs
@@ -55,14 +55,15 @@ impl Query for MultiTermQuery {
 impl MultiTermQuery {
 
     fn scorer(&self, searcher: &Searcher) -> MultiTermScorer {
+        let num_docs = searcher.num_docs() as f32;
         let idfs: Vec<f32> = self.terms.iter()
             .map(|term| searcher.doc_freq(term))
             .map(|doc_freq| {
                 if doc_freq == 0 {
-                    return 1.
+                    1.
                 }
                 else {
-                    1.0 / (doc_freq as f32)
+                    1. + ( num_docs / (doc_freq as f32) ).ln()
                 }
             })
             .collect();
diff --git a/src/query/multi_term_scorer.rs b/src/query/multi_term_scorer.rs
index 764a05a96..f9f298f7d 100644
--- a/src/query/multi_term_scorer.rs
+++ b/src/query/multi_term_scorer.rs
@@ -20,8 +20,10 @@ impl MultiTermScorer {
     }
 
     pub fn update(&mut self, term_ord: usize, term_freq: u32, fieldnorm: u32) {
-        self.score += ((term_freq * fieldnorm) as f32) * self.idf[term_ord];
-        self.num_fields += 1;
+        if term_freq > 0 {
+            self.score += (term_freq as f32 / fieldnorm as f32).sqrt() * self.idf[term_ord];
+            self.num_fields += 1;
+        }
     }
 
     fn coord(&self,) -> f32 {
@@ -49,23 +51,29 @@ mod tests {
 
     use super::*;
     use query::Scorer;
+
+
+    fn abs_diff(left: f32, right: f32) -> f32 {
+        (right - left).abs()
+    }
 
     #[test]
     pub fn test_multiterm_scorer() {
         let mut multi_term_scorer = MultiTermScorer::new(vec!(1f32, 2f32), vec!(1f32, 4f32));
         {
             multi_term_scorer.update(0, 1, 1);
-            assert_eq!(multi_term_scorer.score(), 1f32);
-            multi_term_scorer.clear();
+            assert!(abs_diff(multi_term_scorer.score(), 1f32) < 0.001f32);
+            multi_term_scorer.clear();
+        }
         {
             multi_term_scorer.update(1, 1, 1);
             assert_eq!(multi_term_scorer.score(), 4f32);
-            multi_term_scorer.clear();
+            multi_term_scorer.clear();
         }
         {
             multi_term_scorer.update(0, 2, 1);
-            assert_eq!(multi_term_scorer.score(), 2f32);
+            assert!(abs_diff(multi_term_scorer.score(), 1.4142135) < 0.001f32);
             multi_term_scorer.clear();
         }
         {
diff --git a/src/query/query_parser.rs b/src/query/query_parser.rs
index bb3c85008..e8fa9cdd3 100644
--- a/src/query/query_parser.rs
+++ b/src/query/query_parser.rs
@@ -15,7 +15,7 @@ pub enum ParsingError {
 
 pub struct QueryParser {
     schema: Schema,
-    default_field: Field,
+    default_fields: Vec<Field>,
 }
 
 pub enum StandardQuery {
@@ -35,24 +35,29 @@ impl Query for StandardQuery {
 
 impl QueryParser {
     pub fn new(schema: Schema,
-               default_field: Field) -> QueryParser {
+               default_fields: Vec<Field>) -> QueryParser {
         QueryParser {
             schema: schema,
-            default_field: default_field,
+            default_fields: default_fields,
         }
     }
 
     // TODO check that the term is str.
     // we only support str field for the moment
-    fn transform_literal(&self, literal: Literal) -> Result<Term, ParsingError> {
+    fn transform_literal(&self, literal: Literal) -> Result<Vec<Term>, ParsingError> {
         match literal {
             Literal::DefaultField(val) => {
-                Ok(Term::from_field_text(self.default_field, &val))
+                let terms = self.default_fields
+                    .iter()
+                    .cloned()
+                    .map(|field| Term::from_field_text(field, &val))
+                    .collect();
+                Ok(terms)
             },
             Literal::WithField(field_name, val) => {
                 match self.schema.get_field(&field_name) {
                     Some(field) => {
-                        Ok(Term::from_field_text(field, &val))
+                        Ok(vec!(Term::from_field_text(field, &val)))
                     }
                     None => {
                         Err(ParsingError::FieldDoesNotExist(field_name))
@@ -65,12 +70,16 @@
     pub fn parse_query(&self, query: &str) -> Result<StandardQuery, ParsingError> {
         match parser(query_language).parse(query) {
             Ok(literals) => {
-                let terms_result: Result<Vec<Term>, ParsingError> = literals.0.into_iter()
-                    .map(|literal| self.transform_literal(literal))
-                    .collect();
-                terms_result
-                    .map(MultiTermQuery::new)
-                    .map(StandardQuery::MultiTerm)
+                let mut terms_result: Vec<Term> = Vec::new();
+                for literal in literals.0.into_iter() {
+                    let literal_terms = try!(self.transform_literal(literal));
+                    terms_result.extend_from_slice(&literal_terms);
+                }
+                Ok(
+                    StandardQuery::MultiTerm(
+                        MultiTermQuery::new(terms_result)
+                    )
+                )
             }
             Err(_) => {
                 Err(ParsingError::SyntaxError)
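
Note (illustration only, not part of the patch): the scoring changes above replace the old term_freq * fieldnorm product with an idf of 1 + ln(num_docs / doc_freq) and a per-field contribution of sqrt(term_freq / fieldnorm) * idf, where the fieldnorm is now the raw token count stored by SegmentWriter. Below is a minimal standalone Rust sketch of that formula; the helper names are illustrative and do not exist in the crate.

fn idf(num_docs: u32, doc_freq: u32) -> f32 {
    // Mirrors MultiTermQuery::scorer above: a doc_freq of 0 falls back to an idf of 1.
    if doc_freq == 0 { 1.0 } else { 1.0 + (num_docs as f32 / doc_freq as f32).ln() }
}

fn field_contribution(term_freq: u32, fieldnorm: u32, idf: f32) -> f32 {
    // Mirrors MultiTermScorer::update above: fields without the term contribute nothing.
    if term_freq == 0 { 0.0 } else { (term_freq as f32 / fieldnorm as f32).sqrt() * idf }
}

fn main() {
    // Reproduces the updated test_multiterm_scorer expectation:
    // term_freq = 2, fieldnorm = 1, idf = 1.0  =>  sqrt(2.0) is about 1.4142135
    println!("{}", field_contribution(2, 1, 1.0));
}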