mirror of
https://github.com/quickwit-oss/tantivy.git
synced 2026-01-09 02:22:54 +00:00
More than one default field in query parser. fst version to 0.1.31.
@@ -9,7 +9,7 @@ byteorder = "0.4"
 memmap = "0.2"
 lazy_static = "0.1.*"
 regex = "0.1"
-fst = { git = "https://github.com/BurntSushi/fst.git", rev = "e6c7eec6" }
+fst = "0.1.31"
 atomicwrites = "0.0.14"
 tempfile = "2.0.0"
 rustc-serialize = "0.3.16"
@@ -6,6 +6,7 @@ use collector::Collector;
 use std::io;
 use common::TimerTree;
 use query::Query;
 use DocId;
 use DocAddress;
+use schema::Term;

@@ -22,6 +23,13 @@ impl Searcher {
         let segment_reader = &self.segments[segment_local_id as usize];
         segment_reader.doc(doc_id)
     }

+    pub fn num_docs(&self,) -> DocId {
+        self.segments
+            .iter()
+            .map(|segment_reader| segment_reader.num_docs())
+            .fold(0u32, |acc, val| acc + val)
+    }
+
     pub fn doc_freq(&self, term: &Term) -> u32 {
         self.segments
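Together these additions give the query layer its collection statistics: num_docs for the corpus size and doc_freq for per-term document counts, both summed across segments. A minimal usage sketch (the index and term bindings are assumed here for illustration; they are not part of the commit):

    let searcher = index.searcher().unwrap();
    let num_docs = searcher.num_docs();       // corpus size, summed over all segments
    let doc_freq = searcher.doc_freq(&term);  // docs containing `term`, summed the same way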
@@ -42,6 +42,10 @@ impl SegmentReader {
     pub fn max_doc(&self) -> DocId {
         self.segment_info.max_doc
     }

+    pub fn num_docs(&self) -> DocId {
+        self.segment_info.max_doc
+    }
+
     pub fn get_fast_field_reader(&self, field: Field) -> io::Result<U32FastFieldReader> {
         let field_entry = self.schema.get_field_entry(field);
@@ -35,10 +35,6 @@ fn create_fieldnorms_writer(schema: &Schema) -> U32FastFieldsWriter {
     U32FastFieldsWriter::new(u32_fields)
 }

-fn compute_field_norm(num_tokens: usize) -> u32 {
-    ((350f32 / (1f32 + num_tokens as f32).sqrt()) as u32)
-}
-
 impl SegmentWriter {

@@ -109,11 +105,10 @@ impl SegmentWriter {
             }
         }
-        let field_norm = compute_field_norm(num_tokens);
         self.fieldnorms_writer
             .get_field_writer(field)
             .map(|field_norms_writer| {
-                field_norms_writer.set_val(doc_id, field_norm)
+                field_norms_writer.set_val(doc_id, num_tokens as u32)
             });
     }
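The effect of these two hunks is to move length normalization from index time to query time: rather than persisting the quantized norm 350 / sqrt(1 + num_tokens), the writer stores the raw token count and lets the scorer normalize. A side-by-side sketch of the two behaviors (free functions written for illustration only, not part of the commit):

    // Before this commit: a quantized norm was baked into the fast field.
    fn old_field_norm(num_tokens: usize) -> u32 {
        (350f32 / (1f32 + num_tokens as f32).sqrt()) as u32
    }

    // After: the raw token count is stored; the scorer later computes
    // sqrt(term_freq / fieldnorm) on the fly (see MultiTermScorer::update below).
    fn tf_component(term_freq: u32, fieldnorm: u32) -> f32 {
        (term_freq as f32 / fieldnorm as f32).sqrt()
    }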
@@ -54,7 +54,7 @@ impl IndexWriter {
         let mut segment_writer = SegmentWriter::for_segment(segment.clone(), &schema_clone).unwrap();
         segment_writer.add_document(&*doc, &schema_clone).unwrap();
-        for _ in 0..(225_000 - 1) {
+        for _ in 0..100_000 {
             {
                 let queue = queue_output_clone.lock().unwrap();
                 match queue.recv() {
@@ -214,9 +214,9 @@ mod tests {
         let searcher = index.searcher().unwrap();
         let segment_reader: &SegmentReader = searcher.segments().iter().next().unwrap();
         let fieldnorms_reader = segment_reader.get_fieldnorms_reader(text_field).unwrap();
-        assert_eq!(fieldnorms_reader.get(0), 175);
+        assert_eq!(fieldnorms_reader.get(0), 3);
         assert_eq!(fieldnorms_reader.get(1), 0);
-        assert_eq!(fieldnorms_reader.get(2), 202);
+        assert_eq!(fieldnorms_reader.get(2), 2);
     }
 }
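The updated expectations are consistent with the removed formula: 350 / sqrt(1 + 3) = 175 and 350 / sqrt(1 + 2) ≈ 202, so documents 0 and 2 evidently contain three and two tokens, and the reader now surfaces those raw counts directly.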
@@ -141,7 +141,11 @@ mod tests {
         serializer.close().unwrap();
         U32FastFieldReader::open(ReadOnlySource::Anonymous(data.copy_vec())).unwrap()
     }

+    fn abs_diff(left: f32, right: f32) -> f32 {
+        (right - left).abs()
+    }
+
     #[test]
     pub fn test_union_postings() {
         let left_fieldnorms = create_u32_fastfieldreader(Field(1), vec!(100,200,300));
@@ -156,14 +160,14 @@
         );
         assert!(union.next());
         assert_eq!(union.doc(), 1);
-        assert_eq!(union.score(), 210f32);
+        assert!(abs_diff(union.score(), 2.182179f32) < 0.001);
         assert!(union.next());
         assert_eq!(union.doc(), 2);
-        assert_eq!(union.score(), 20f32);
+        assert!(abs_diff(union.score(), 0.2236068) < 0.001f32);
         assert!(union.next());
         assert_eq!(union.doc(), 3);
         assert!(union.next());
-        assert_eq!(union.score(), 80f32);
+        assert!(abs_diff(union.score(), 0.8944272f32) < 0.001f32);
         assert_eq!(union.doc(), 8);
         assert!(!union.next());
     }
@@ -55,14 +55,15 @@ impl Query for MultiTermQuery {
 impl MultiTermQuery {

     fn scorer(&self, searcher: &Searcher) -> MultiTermScorer {
+        let num_docs = searcher.num_docs() as f32;
         let idfs: Vec<f32> = self.terms.iter()
             .map(|term| searcher.doc_freq(term))
             .map(|doc_freq| {
                 if doc_freq == 0 {
-                    return 1.
+                    1.
                 }
                 else {
-                    1.0 / (doc_freq as f32)
+                    1. + ( num_docs / (doc_freq as f32) ).ln()
                 }
             })
             .collect();
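This swaps the earlier 1 / doc_freq weight for a smoothed inverse document frequency, idf(t) = 1 + ln(num_docs / doc_freq(t)), with a neutral weight of 1 for terms that match no document. Extracted as a standalone sketch (a free function for illustration; the commit computes this inline):

    fn idf(num_docs: f32, doc_freq: u32) -> f32 {
        if doc_freq == 0 {
            1.0 // term absent from the index: neutral weight, avoids ln(inf)
        } else {
            1.0 + (num_docs / doc_freq as f32).ln()
        }
    }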
@@ -20,8 +20,10 @@ impl MultiTermScorer {
     }

     pub fn update(&mut self, term_ord: usize, term_freq: u32, fieldnorm: u32) {
-        self.score += ((term_freq * fieldnorm) as f32) * self.idf[term_ord];
-        self.num_fields += 1;
+        if term_freq > 0 {
+            self.score += (term_freq as f32 / fieldnorm as f32).sqrt() * self.idf[term_ord];
+            self.num_fields += 1;
+        }
     }

     fn coord(&self,) -> f32 {
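Because fieldnorms now hold raw token counts, the scorer performs the normalization itself: each matching term contributes sqrt(term_freq / fieldnorm) * idf, and a zero-frequency term no longer bumps num_fields or the score. An illustrative call sequence (assuming, as the tests below suggest, that the constructor takes coord factors and idfs, and that a single matching field has a coord of 1):

    let mut scorer = MultiTermScorer::new(vec!(1f32, 2f32), vec!(1f32, 4f32));
    scorer.update(0, 2, 8);  // term 0: sqrt(2 / 8) * idf[0] = 0.5 * 1 = 0.5
    scorer.update(1, 0, 8);  // term 1 absent: contributes nothing
    assert!((scorer.score() - 0.5f32).abs() < 0.001f32);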
@@ -49,23 +51,29 @@ mod tests {

     use super::*;
     use query::Scorer;

+    fn abs_diff(left: f32, right: f32) -> f32 {
+        (right - left).abs()
+    }
+
     #[test]
     pub fn test_multiterm_scorer() {
         let mut multi_term_scorer = MultiTermScorer::new(vec!(1f32, 2f32), vec!(1f32, 4f32));
         {
             multi_term_scorer.update(0, 1, 1);
-            assert_eq!(multi_term_scorer.score(), 1f32);
+            assert!(abs_diff(multi_term_scorer.score(), 1f32) < 0.001f32);
             multi_term_scorer.clear();
         }
         {
             multi_term_scorer.update(1, 1, 1);
             assert_eq!(multi_term_scorer.score(), 4f32);
             multi_term_scorer.clear();
         }
         {
             multi_term_scorer.update(0, 2, 1);
-            assert_eq!(multi_term_scorer.score(), 2f32);
+            assert!(abs_diff(multi_term_scorer.score(), 1.4142135) < 0.001f32);
             multi_term_scorer.clear();
         }
         {
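The expected values follow directly from the new update rule: update(0, 1, 1) yields sqrt(1/1) * idf[0] = 1, update(1, 1, 1) yields sqrt(1/1) * idf[1] = 4 (still exactly representable, so the exact assertion stands), and update(0, 2, 1) yields sqrt(2/1) * 1 ≈ 1.4142135.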
@@ -15,7 +15,7 @@ pub enum ParsingError {

 pub struct QueryParser {
     schema: Schema,
-    default_field: Field,
+    default_fields: Vec<Field>,
 }

 pub enum StandardQuery {
@@ -35,24 +35,29 @@ impl Query for StandardQuery {

 impl QueryParser {
     pub fn new(schema: Schema,
-               default_field: Field) -> QueryParser {
+               default_fields: Vec<Field>) -> QueryParser {
         QueryParser {
             schema: schema,
-            default_field: default_field,
+            default_fields: default_fields,
         }
     }

     // TODO check that the term is str.
     // we only support str field for the moment
-    fn transform_literal(&self, literal: Literal) -> Result<Term, ParsingError> {
+    fn transform_literal(&self, literal: Literal) -> Result<Vec<Term>, ParsingError> {
         match literal {
             Literal::DefaultField(val) => {
-                Ok(Term::from_field_text(self.default_field, &val))
+                let terms = self.default_fields
+                    .iter()
+                    .cloned()
+                    .map(|field| Term::from_field_text(field, &val))
+                    .collect();
+                Ok(terms)
             },
             Literal::WithField(field_name, val) => {
                 match self.schema.get_field(&field_name) {
                     Some(field) => {
-                        Ok(Term::from_field_text(field, &val))
+                        Ok(vec!(Term::from_field_text(field, &val)))
                     }
                     None => {
                         Err(ParsingError::FieldDoesNotExist(field_name))
@@ -65,12 +70,16 @@ impl QueryParser {
     pub fn parse_query(&self, query: &str) -> Result<StandardQuery, ParsingError> {
         match parser(query_language).parse(query) {
             Ok(literals) => {
-                let terms_result: Result<Vec<Term>, ParsingError> = literals.0.into_iter()
-                    .map(|literal| self.transform_literal(literal))
-                    .collect();
-                terms_result
-                    .map(MultiTermQuery::new)
-                    .map(StandardQuery::MultiTerm)
+                let mut terms_result: Vec<Term> = Vec::new();
+                for literal in literals.0.into_iter() {
+                    let literal_terms = try!(self.transform_literal(literal));
+                    terms_result.extend_from_slice(&literal_terms);
+                }
+                Ok(
+                    StandardQuery::MultiTerm(
+                        MultiTermQuery::new(terms_result)
+                    )
+                )
             }
             Err(_) => {
                 Err(ParsingError::SyntaxError)
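With transform_literal returning a Vec<Term>, an unqualified literal now expands into one term per default field before the MultiTermQuery is assembled. A minimal usage sketch (the schema, title_field and body_field bindings are assumed for illustration):

    let query_parser = QueryParser::new(schema, vec!(title_field, body_field));
    // "diary" carries no field qualifier, so it becomes two terms,
    // Term(title_field, "diary") and Term(body_field, "diary"),
    // which the resulting MultiTermQuery matches across both fields.
    let query = query_parser.parse_query("diary").unwrap();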