Support more than one default field in the query parser. Pin fst version to 0.1.31.

This commit is contained in:
Paul Masurel
2016-08-01 09:22:01 +09:00
parent 090973ff16
commit f94efcf5aa
10 changed files with 63 additions and 34 deletions

View File

@@ -9,7 +9,7 @@ byteorder = "0.4"
memmap = "0.2"
lazy_static = "0.1.*"
regex = "0.1"
fst = { git = "https://github.com/BurntSushi/fst.git", rev = "e6c7eec6" }
fst = "0.1.31"
atomicwrites = "0.0.14"
tempfile = "2.0.0"
rustc-serialize = "0.3.16"

View File

@@ -6,6 +6,7 @@ use collector::Collector;
use std::io;
use common::TimerTree;
use query::Query;
use DocId;
use DocAddress;
use schema::Term;
@@ -22,6 +23,13 @@ impl Searcher {
let segment_reader = &self.segments[segment_local_id as usize];
segment_reader.doc(doc_id)
}
pub fn num_docs(&self,) -> DocId {
self.segments
.iter()
.map(|segment_reader| segment_reader.num_docs())
.fold(0u32, |acc, val| acc + val)
}
pub fn doc_freq(&self, term: &Term) -> u32 {
self.segments

View File

@@ -42,6 +42,10 @@ impl SegmentReader {
pub fn max_doc(&self) -> DocId {
self.segment_info.max_doc
}
pub fn num_docs(&self) -> DocId {
self.segment_info.max_doc
}
pub fn get_fast_field_reader(&self, field: Field) -> io::Result<U32FastFieldReader> {
let field_entry = self.schema.get_field_entry(field);

View File

@@ -35,10 +35,6 @@ fn create_fieldnorms_writer(schema: &Schema) -> U32FastFieldsWriter {
U32FastFieldsWriter::new(u32_fields)
}
fn compute_field_norm(num_tokens: usize) -> u32 {
((350f32 / (1f32 + num_tokens as f32).sqrt()) as u32)
}
impl SegmentWriter {
@@ -109,11 +105,10 @@ impl SegmentWriter {
}
}
}
let field_norm = compute_field_norm(num_tokens);
self.fieldnorms_writer
.get_field_writer(field)
.map(|field_norms_writer| {
field_norms_writer.set_val(doc_id, field_norm)
field_norms_writer.set_val(doc_id, num_tokens as u32)
});
}

View File

@@ -54,7 +54,7 @@ impl IndexWriter {
let mut segment_writer = SegmentWriter::for_segment(segment.clone(), &schema_clone).unwrap();
segment_writer.add_document(&*doc, &schema_clone).unwrap();
for _ in 0..(225_000 - 1) {
for _ in 0..100_000 {
{
let queue = queue_output_clone.lock().unwrap();
match queue.recv() {

View File

@@ -214,9 +214,9 @@ mod tests {
let searcher = index.searcher().unwrap();
let segment_reader: &SegmentReader = searcher.segments().iter().next().unwrap();
let fieldnorms_reader = segment_reader.get_fieldnorms_reader(text_field).unwrap();
assert_eq!(fieldnorms_reader.get(0), 175);
assert_eq!(fieldnorms_reader.get(0), 3);
assert_eq!(fieldnorms_reader.get(1), 0);
assert_eq!(fieldnorms_reader.get(2), 202);
assert_eq!(fieldnorms_reader.get(2), 2);
}
}

View File

@@ -141,7 +141,11 @@ mod tests {
serializer.close().unwrap();
U32FastFieldReader::open(ReadOnlySource::Anonymous(data.copy_vec())).unwrap()
}
fn abs_diff(left: f32, right: f32) -> f32 {
(right - left).abs()
}
#[test]
pub fn test_union_postings() {
let left_fieldnorms = create_u32_fastfieldreader(Field(1), vec!(100,200,300));
@@ -156,14 +160,14 @@ mod tests {
);
assert!(union.next());
assert_eq!(union.doc(), 1);
assert_eq!(union.score(), 210f32);
assert!(abs_diff(union.score(), 2.182179f32) < 0.001);
assert!(union.next());
assert_eq!(union.doc(), 2);
assert_eq!(union.score(), 20f32);
assert!(abs_diff(union.score(), 0.2236068) < 0.001f32);
assert!(union.next());
assert_eq!(union.doc(), 3);
assert!(union.next());
assert_eq!(union.score(), 80f32);
assert!(abs_diff(union.score(), 0.8944272f32) < 0.001f32);
assert_eq!(union.doc(), 8);
assert!(!union.next());
}

View File

@@ -55,14 +55,15 @@ impl Query for MultiTermQuery {
impl MultiTermQuery {
fn scorer(&self, searcher: &Searcher) -> MultiTermScorer {
let num_docs = searcher.num_docs() as f32;
let idfs: Vec<f32> = self.terms.iter()
.map(|term| searcher.doc_freq(term))
.map(|doc_freq| {
if doc_freq == 0 {
return 1.
1.
}
else {
1.0 / (doc_freq as f32)
1. + ( num_docs / (doc_freq as f32) ).ln()
}
})
.collect();

View File

@@ -20,8 +20,10 @@ impl MultiTermScorer {
}
pub fn update(&mut self, term_ord: usize, term_freq: u32, fieldnorm: u32) {
self.score += ((term_freq * fieldnorm) as f32) * self.idf[term_ord];
self.num_fields += 1;
if term_freq > 0 {
self.score += (term_freq as f32 / fieldnorm as f32).sqrt() * self.idf[term_ord];
self.num_fields += 1;
}
}
fn coord(&self,) -> f32 {
@@ -49,23 +51,29 @@ mod tests {
use super::*;
use query::Scorer;
fn abs_diff(left: f32, right: f32) -> f32 {
(right - left).abs()
}
#[test]
pub fn test_multiterm_scorer() {
let mut multi_term_scorer = MultiTermScorer::new(vec!(1f32, 2f32), vec!(1f32, 4f32));
{
multi_term_scorer.update(0, 1, 1);
assert_eq!(multi_term_scorer.score(), 1f32);
multi_term_scorer.clear();
assert!(abs_diff(multi_term_scorer.score(), 1f32) < 0.001f32);
multi_term_scorer.clear();
}
{
multi_term_scorer.update(1, 1, 1);
assert_eq!(multi_term_scorer.score(), 4f32);
multi_term_scorer.clear();
multi_term_scorer.clear();
}
{
multi_term_scorer.update(0, 2, 1);
assert_eq!(multi_term_scorer.score(), 2f32);
assert!(abs_diff(multi_term_scorer.score(), 1.4142135) < 0.001f32);
multi_term_scorer.clear();
}
{

View File

@@ -15,7 +15,7 @@ pub enum ParsingError {
pub struct QueryParser {
schema: Schema,
default_field: Field,
default_fields: Vec<Field>,
}
pub enum StandardQuery {
@@ -35,24 +35,29 @@ impl Query for StandardQuery {
impl QueryParser {
pub fn new(schema: Schema,
default_field: Field) -> QueryParser {
default_fields: Vec<Field>) -> QueryParser {
QueryParser {
schema: schema,
default_field: default_field,
default_fields: default_fields,
}
}
// TODO check that the term is str.
// we only support str field for the moment
fn transform_literal(&self, literal: Literal) -> Result<Term, ParsingError> {
fn transform_literal(&self, literal: Literal) -> Result<Vec<Term>, ParsingError> {
match literal {
Literal::DefaultField(val) => {
Ok(Term::from_field_text(self.default_field, &val))
let terms = self.default_fields
.iter()
.cloned()
.map(|field| Term::from_field_text(field, &val))
.collect();
Ok(terms)
},
Literal::WithField(field_name, val) => {
match self.schema.get_field(&field_name) {
Some(field) => {
Ok(Term::from_field_text(field, &val))
Ok(vec!(Term::from_field_text(field, &val)))
}
None => {
Err(ParsingError::FieldDoesNotExist(field_name))
@@ -65,12 +70,16 @@ impl QueryParser {
pub fn parse_query(&self, query: &str) -> Result<StandardQuery, ParsingError> {
match parser(query_language).parse(query) {
Ok(literals) => {
let terms_result: Result<Vec<Term>, ParsingError> = literals.0.into_iter()
.map(|literal| self.transform_literal(literal))
.collect();
terms_result
.map(MultiTermQuery::new)
.map(StandardQuery::MultiTerm)
let mut terms_result: Vec<Term> = Vec::new();
for literal in literals.0.into_iter() {
let literal_terms = try!(self.transform_literal(literal));
terms_result.extend_from_slice(&literal_terms);
}
Ok(
StandardQuery::MultiTerm(
MultiTermQuery::new(terms_result)
)
)
}
Err(_) => {
Err(ParsingError::SyntaxError)