Closes #245 = BM25. (#260)

* Closes #245 = BM25.

Scores are the same as Lucene.

* Fixing travis conf
This commit is contained in:
Paul Masurel
2018-03-22 15:06:56 +09:00
committed by GitHub
parent e22f767fda
commit b7f8884246
8 changed files with 173 additions and 58 deletions

View File

@@ -26,7 +26,6 @@ before_script:
- export PATH=$HOME/.cargo/bin:$PATH
- cargo install cargo-update || echo "cargo-update already installed"
- cargo install cargo-travis || echo "cargo-travis already installed"
- cargo install-update -a # update outdated cached binaries
script:
- cargo build
- cargo test

View File

@@ -1,7 +1,9 @@
Tantivy 0.5.2
==========================
- Removed C code. Tantivy is now pure Rust.
- Removed C code. Tantivy is now pure Rust.
- BM25
- Approximate field norms encoded over 1 byte.
Tantivy 0.5.1
==========================

View File

@@ -4,7 +4,6 @@ use schema::Document;
use collector::Collector;
use common::TimerTree;
use query::Query;
use DocId;
use DocAddress;
use schema::{Field, Term};
use termdict::{TermDictionary, TermMerger};
@@ -33,20 +32,20 @@ impl Searcher {
}
/// Returns the overall number of documents in the index.
pub fn num_docs(&self) -> DocId {
pub fn num_docs(&self) -> u64 {
self.segment_readers
.iter()
.map(|segment_reader| segment_reader.num_docs())
.sum::<u32>()
.map(|segment_reader| segment_reader.num_docs() as u64)
.sum::<u64>()
}
/// Return the overall number of documents containing
/// the given term.
pub fn doc_freq(&self, term: &Term) -> u32 {
pub fn doc_freq(&self, term: &Term) -> u64 {
self.segment_readers
.iter()
.map(|segment_reader| segment_reader.inverted_index(term.field()).doc_freq(term))
.sum::<u32>()
.map(|segment_reader| segment_reader.inverted_index(term.field()).doc_freq(term) as u64)
.sum::<u64>()
}
/// Return the list of segment readers

View File

@@ -293,6 +293,15 @@ mod tests {
use rand::{Rng, SeedableRng, XorShiftRng};
use rand::distributions::{IndependentSample, Range};
pub fn assert_nearly_equals(expected: f32, val: f32) {
assert!(nearly_equals(val, expected), "Got {}, expected {}.", val, expected);
}
pub fn nearly_equals(a: f32, b: f32) -> bool {
(a - b).abs() < 0.0005 * (a + b).abs()
}
fn generate_array_with_seed(n: usize, ratio: f32, seed_val: u32) -> Vec<u32> {
let seed: &[u32; 4] = &[1, 2, 3, seed_val];
let mut rng: XorShiftRng = XorShiftRng::from_seed(*seed);

View File

@@ -17,18 +17,13 @@ pub(crate) type TermScorerNoDeletes = TermScorer<SegmentPostings<NoDelete>>;
mod tests {
use docset::DocSet;
use postings::SegmentPostings;
use query::{Query, Scorer};
use query::term_query::TermScorer;
use query::TermQuery;
use Index;
use schema::*;
use fieldnorm::FieldNormReader;
use schema::IndexRecordOption;
use tests::assert_nearly_equals;
use Term;
use schema::{TEXT, STRING, SchemaBuilder, IndexRecordOption};
use collector::TopCollector;
use query::{TermQuery, QueryParser, Query, Scorer};
fn abs_diff(left: f32, right: f32) -> f32 {
(right - left).abs()
}
#[test]
pub fn test_term_query_no_freq() {
@@ -57,22 +52,65 @@ mod tests {
let mut term_scorer = term_weight.scorer(segment_reader).unwrap();
assert!(term_scorer.advance());
assert_eq!(term_scorer.doc(), 0);
assert_eq!(term_scorer.score(), 0.30685282);
assert_eq!(term_scorer.score(), 0.28768212);
}
#[test]
pub fn test_term_scorer() {
let left_fieldnorms = FieldNormReader::from(vec![10, 4]);
assert_eq!(left_fieldnorms.fieldnorm(0), 10);
assert_eq!(left_fieldnorms.fieldnorm(1), 4);
let left = SegmentPostings::create_from_docs(&[1]);
let mut left_scorer = TermScorer {
idf: 0.30685282,
fieldnorm_reader_opt: Some(left_fieldnorms),
postings: left,
};
left_scorer.advance();
assert!(abs_diff(left_scorer.score(), 0.15342641) < 0.001f32);
pub fn test_term_weight() {
let mut schema_builder = SchemaBuilder::new();
let left_field = schema_builder.add_text_field("left", TEXT);
let right_field = schema_builder.add_text_field("right", TEXT);
let large_field = schema_builder.add_text_field("large", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
{
let mut index_writer = index.writer_with_num_threads(1, 10_000_000).unwrap();
index_writer.add_document(doc!(
left_field => "left1 left2 left2 left2f2 left2f2 left3 abcde abcde abcde abcde abcde abcde abcde abcde abcde abcewde abcde abcde",
right_field => "right1 right2",
large_field => "large0 large1 large2 large3 large4 large5 large6 large7 large8 large9 large10 large11 large12 large13 large14 large15 large16 large17 large18 large19 large20 large21 large22 large23 large24 large25 large26 large27 large28 large29 large30 large31 large32 large33 large34 large35 large36 large37 large38 large39 large40 large41 large42 large43 large44 large45 large46 large47 large48 large49 large50 large51 large52 large53 large54 large55 large56 large57 large58 large59 large60 large61 large62 large63 large64 large65 large66 large67 large68 large69 large70 large71 large72 large73 large74 large75 large76 large77 large78 large79 large80 large81 large82 large83 large84 large85 large86 large87 large88 large89 large90 large91 large92 large93 large94 large95 large96 large97 large98 large99 large100 large101 large102 large103 large104 large105 large106 large107 large108 large109 large110 large111 large112 large113 large114 large115 large116 large117 large118 large119 large120 large121 large122 large123 large124 large125 large126 large127 large128 large129 large130 large131 large132 large133 large134 large135 large136 large137 large138 large139 large140 large141 large142 large143 large144 large145 large146 large147 large148 large149 large150 large151 large152 large153 large154 large155 large156 large157 large158 large159 large160 large161 large162 large163 large164 large165 large166 large167 large168 large169 large170 large171 large172 large173 large174 large175 large176 large177 large178 large179 large180 large181 large182 large183 large184 large185 large186 large187 large188 large189 large190 large191 large192 large193 large194 large195 large196 large197 large198 large199 large200 large201 large202 large203 large204 large205 large206 large207 large208 large209 large210 large211 large212 large213 large214 large215 large216 large217 large218 large219 large220 large221 large222 large223 large224 large225 large226 large227 large228 large229 large230 large231 large232 large233 large234 large235 large236 large237 large238 large239 large240 large241 large242 large243 large244 large245 large246 large247 large248 large249 large250 large251 large252 large253 large254 large255 large256 large257 large258 large259 large260 large261 large262 large263 large264 large265 large266 large267 large268 large269 large270 large271 large272 large273 large274 large275 large276 large277 large278 large279 large280 large281 large282 large283 large284 large285 large286"
));
index_writer.add_document(doc!(left_field => "left4 left1"));
index_writer.commit().unwrap();
}
index.load_searchers().unwrap();
let searcher = index.searcher();
{
let mut collector = TopCollector::with_limit(2);
let term = Term::from_field_text(left_field, "left2");
let term_query = TermQuery::new(term, IndexRecordOption::WithFreqs);
searcher.search(&term_query, &mut collector).unwrap();
let scored_docs = collector.score_docs();
assert_eq!(scored_docs.len(), 1);
let (score, _) = scored_docs[0];
assert_nearly_equals(0.77802235, score);
}
{
let mut collector = TopCollector::with_limit(2);
let term = Term::from_field_text(left_field, "left1");
let term_query = TermQuery::new(term, IndexRecordOption::WithFreqs);
searcher.search(&term_query, &mut collector).unwrap();
let scored_docs = collector.score_docs();
assert_eq!(scored_docs.len(), 2);
let (score1, _) = scored_docs[0];
assert_nearly_equals(0.27101856, score1);
let (score2, _) = scored_docs[1];
assert_nearly_equals(0.13736556, score2);
}
{
let query_parser = QueryParser::for_index(&index, vec![]);
let query = query_parser.parse_query("left:left2 left:left1").unwrap();
let mut collector = TopCollector::with_limit(2);
searcher.search(&*query, &mut collector).unwrap();
let scored_docs = collector.score_docs();
assert_eq!(scored_docs.len(), 2);
let (score1, _) = scored_docs[0];
assert_nearly_equals(0.9153879, score1);
let (score2, _) = scored_docs[1];
assert_nearly_equals(0.27101856, score2);
}
}
}

View File

@@ -36,17 +36,27 @@ impl TermQuery {
/// this method return a specific implementation.
/// This is useful for optimization purpose.
pub fn specialized_weight(&self, searcher: &Searcher, scoring_enabled: bool) -> TermWeight {
let mut total_num_tokens = 0;
let mut total_num_docs = 0;
for segment_reader in searcher.segment_readers() {
let inverted_index = segment_reader.inverted_index(self.term.field());
total_num_tokens += inverted_index.total_num_tokens();
total_num_docs += segment_reader.max_doc();
}
let average_field_norm = total_num_tokens as f32 / total_num_docs as f32;
let index_record_option = if scoring_enabled {
self.index_record_option
} else {
IndexRecordOption::Basic
};
TermWeight {
num_docs: searcher.num_docs(),
doc_freq: searcher.doc_freq(&self.term),
term: self.term.clone(),
index_record_option,
}
TermWeight::new(
searcher.doc_freq(&self.term),
searcher.num_docs(),
average_field_norm,
self.term.clone(),
index_record_option
)
}
}
@@ -55,3 +65,4 @@ impl Query for TermQuery {
Ok(box self.specialized_weight(searcher, scoring_enabled))
}
}

View File

@@ -7,9 +7,10 @@ use postings::Postings;
use fieldnorm::FieldNormReader;
pub struct TermScorer<TPostings: Postings> {
pub idf: Score,
pub fieldnorm_reader_opt: Option<FieldNormReader>,
pub postings: TPostings,
pub weight: f32,
pub cache: [f32; 256],
}
impl<TPostings: Postings> DocSet for TermScorer<TPostings> {
@@ -32,14 +33,16 @@ impl<TPostings: Postings> DocSet for TermScorer<TPostings> {
impl<TPostings: Postings> Scorer for TermScorer<TPostings> {
fn score(&mut self) -> Score {
let doc = self.postings.doc();
let tf = match self.fieldnorm_reader_opt {
Some(ref fieldnorm_reader) => {
let field_norm = fieldnorm_reader.fieldnorm(doc);
(self.postings.term_freq() as f32 / field_norm as f32)
}
None => self.postings.term_freq() as f32,
};
self.idf * tf.sqrt()
let doc = self.doc();
let fieldnorm_id = self.fieldnorm_reader_opt
.as_ref()
.map(|fieldnorm_reader| {
fieldnorm_reader.fieldnorm_id(doc)
})
.unwrap_or(0u8);
let norm = self.cache[fieldnorm_id as usize];
let term_freq = self.postings.term_freq() as f32;
self.weight * term_freq / (term_freq + norm)
}
}

View File

@@ -9,15 +9,18 @@ use super::term_scorer::TermScorer;
use fastfield::DeleteBitSet;
use postings::NoDelete;
use Result;
use fieldnorm::FieldNormReader;
use std::f32;
pub struct TermWeight {
pub(crate) num_docs: u32,
pub(crate) doc_freq: u32,
pub(crate) term: Term,
pub(crate) index_record_option: IndexRecordOption,
term: Term,
index_record_option: IndexRecordOption,
weight: f32,
cache: [f32; 256],
}
impl Weight for TermWeight {
fn scorer(&self, reader: &SegmentReader) -> Result<Box<Scorer>> {
let field = self.term.field();
let inverted_index = reader.inverted_index(field);
@@ -29,15 +32,17 @@ impl Weight for TermWeight {
scorer =
if let Some(segment_postings) = postings_opt {
box TermScorer {
idf: self.idf(),
fieldnorm_reader_opt,
postings: segment_postings,
weight: self.weight,
cache: self.cache
}
} else {
box TermScorer {
idf: 1f32,
fieldnorm_reader_opt: None,
postings: SegmentPostings::<NoDelete>::empty(),
weight: self.weight,
cache: self.cache
}
};
} else {
@@ -46,15 +51,17 @@ impl Weight for TermWeight {
scorer =
if let Some(segment_postings) = postings_opt {
box TermScorer {
idf: self.idf(),
fieldnorm_reader_opt,
postings: segment_postings,
weight: self.weight,
cache: self.cache
}
} else {
box TermScorer {
idf: 1f32,
fieldnorm_reader_opt: None,
postings: SegmentPostings::<NoDelete>::empty(),
weight: self.weight,
cache: self.cache
}
};
}
@@ -75,8 +82,55 @@ impl Weight for TermWeight {
}
}
const K1: f32 = 1.2;
const B: f32 = 0.75;
fn cached_tf_component(fieldnorm: u32, average_fieldnorm: f32) -> f32 {
K1 * (1f32 - B + B * fieldnorm as f32 / average_fieldnorm)
}
fn compute_tf_cache(average_fieldnorm: f32) -> [f32; 256] {
let mut cache = [0f32; 256];
for fieldnorm_id in 0..256 {
let fieldnorm = FieldNormReader::id_to_fieldnorm(fieldnorm_id as u8);
cache[fieldnorm_id] = cached_tf_component(fieldnorm, average_fieldnorm);
}
cache
}
fn idf(doc_freq: u64, doc_count: u64) -> f32 {
let x = ((doc_count - doc_freq) as f32 + 0.5) / (doc_freq as f32 + 0.5);
(1f32 + x).ln()
}
impl TermWeight {
fn idf(&self) -> f32 {
1.0 + (self.num_docs as f32 / (self.doc_freq as f32 + 1.0)).ln()
pub fn new(doc_freq: u64,
doc_count: u64,
average_fieldnorm: f32,
term: Term,
index_record_option: IndexRecordOption) -> TermWeight {
let idf = idf(doc_freq, doc_count);
TermWeight {
term,
index_record_option,
weight: idf * (1f32 + K1),
cache: compute_tf_cache(average_fieldnorm),
}
}
}
#[cfg(test)]
mod tests {
use tests::assert_nearly_equals;
use super::idf;
#[test]
fn test_idf() {
assert_nearly_equals(idf(1, 2), 0.6931472);
}
}