Safer interface for union_postings

This commit is contained in:
Paul Masurel
2016-08-06 18:18:45 +09:00
parent 0ee473f474
commit d4bbec6631
10 changed files with 231 additions and 208 deletions

View File

@@ -1,5 +1,4 @@
use Result;
use Error;
use std::path::{PathBuf, Path};
use schema::Schema;
use DocId;
@@ -250,16 +249,10 @@ impl Segment {
pub fn open_read(&self, component: SegmentComponent) -> Result<ReadOnlySource> {
let path = self.relative_path(component);
let directory_lock = self.index.directory.read();
match directory_lock {
Ok(directory) => {
directory.open_read(&path)
.map_err(From::from)
}
Err(_) => {
Err(Error::Poisoned)
}
}
let directory = try!(self.index.directory.read());
let source = try!(directory.open_read(&path));
Ok(source)
}
pub fn open_write(&self, component: SegmentComponent) -> Result<WritePtr> {

View File

@@ -5,8 +5,6 @@ use std::error;
use std::sync::PoisonError;
use directory::OpenError;
#[derive(Debug)]
pub enum Error {
OpenError(OpenError),
@@ -20,8 +18,6 @@ impl Error {
/// Wraps an arbitrary error value into the `Error::Other` variant.
///
/// The error `e` is moved into a `Box` and stored inside the returned `Error`.
pub fn make_other<E: error::Error + 'static>(e: E) -> Error {
Error::Other(Box::new(e))
}
}
impl From<io::Error> for Error {
@@ -31,7 +27,7 @@ impl From<io::Error> for Error {
}
impl<Guard> From<PoisonError<Guard>> for Error {
fn from(poison_error: PoisonError<Guard>) -> Error {
fn from(_: PoisonError<Guard>) -> Error {
Error::Poisoned
}
}

View File

@@ -66,6 +66,7 @@ impl<'a> DocSet for SegmentPostings<'a> {
// goes to the next element.
// next needs to be called a first time to point to the correct element.
#[inline(always)]
fn advance(&mut self,) -> bool {
self.cur += Wrapping(1);
if self.cur.0 >= self.len {
@@ -77,6 +78,7 @@ impl<'a> DocSet for SegmentPostings<'a> {
return true;
}
/// Returns the `DocId` that the block decoder outputs at the position
/// given by `index_within_block()`.
#[inline(always)]
fn doc(&self,) -> DocId {
self.block_decoder.output(self.index_within_block())
}

View File

@@ -22,7 +22,7 @@ impl Ord for HeapItem {
}
pub struct UnionPostings<TPostings: Postings, TAccumulator: MultiTermAccumulator> {
fieldnorms_readers: Vec<U32FastFieldReader>,
fieldnorm_readers: Vec<U32FastFieldReader>,
postings: Vec<TPostings>,
term_frequencies: Vec<u32>,
queue: BinaryHeap<HeapItem>,
@@ -31,14 +31,9 @@ pub struct UnionPostings<TPostings: Postings, TAccumulator: MultiTermAccumulator
}
impl<TPostings: Postings, TAccumulator: MultiTermAccumulator> UnionPostings<TPostings, TAccumulator> {
pub fn new(fieldnorms_reader: Vec<U32FastFieldReader>, mut postings: Vec<TPostings>, scorer: TAccumulator) -> UnionPostings<TPostings, TAccumulator> {
let num_postings = postings.len();
assert_eq!(fieldnorms_reader.len(), num_postings);
for posting in &mut postings {
assert!(posting.advance());
}
let mut term_frequencies: Vec<u32> = iter::repeat(0u32).take(num_postings).collect();
fn new_non_empty(fieldnorm_readers: Vec<U32FastFieldReader>, postings: Vec<TPostings>, scorer: TAccumulator) -> UnionPostings<TPostings, TAccumulator> {
let mut term_frequencies: Vec<u32> = iter::repeat(0u32).take(postings.len()).collect();
let heap_items: Vec<HeapItem> = postings
.iter()
.map(|posting| {
@@ -50,9 +45,8 @@ impl<TPostings: Postings, TAccumulator: MultiTermAccumulator> UnionPostings<TPos
HeapItem(doc, ord as u32)
})
.collect();
UnionPostings {
fieldnorms_readers: fieldnorms_reader,
fieldnorm_readers: fieldnorm_readers,
postings: postings,
term_frequencies: term_frequencies,
queue: BinaryHeap::from(heap_items),
@@ -60,6 +54,18 @@ impl<TPostings: Postings, TAccumulator: MultiTermAccumulator> UnionPostings<TPos
scorer: scorer
}
}
/// Builds a `UnionPostings` from `(postings, fieldnorm reader)` pairs.
///
/// Each postings object is advanced once. Pairs whose first `advance()`
/// returns `false` are discarded together with their field-norm reader;
/// the surviving pairs are handed, in their original order, to
/// `new_non_empty` along with `scorer`.
pub fn new(postings_and_fieldnorms: Vec<(TPostings, U32FastFieldReader)>, scorer: TAccumulator) -> UnionPostings<TPostings, TAccumulator> {
    let (postings, fieldnorm_readers): (Vec<TPostings>, Vec<U32FastFieldReader>) = postings_and_fieldnorms
        .into_iter()
        .filter_map(|(mut posting, fieldnorm_reader)| {
            // Keep the pair only when the postings could be advanced once.
            if posting.advance() {
                Some((posting, fieldnorm_reader))
            } else {
                None
            }
        })
        .unzip();
    UnionPostings::new_non_empty(fieldnorm_readers, postings, scorer)
}
pub fn scorer(&self,) -> &TAccumulator {
@@ -80,7 +86,7 @@ impl<TPostings: Postings, TAccumulator: MultiTermAccumulator> UnionPostings<TPos
}
fn get_field_norm(&self, ord:usize, doc:DocId) -> u32 {
self.fieldnorms_readers[ord].get(doc)
self.fieldnorm_readers[ord].get(doc)
}
}
@@ -166,8 +172,10 @@ mod tests {
let right = VecPostings::from(vec!(1, 3, 8));
let multi_term_scorer = TfIdfScorer::new(vec!(0f32, 1f32, 2f32), vec!(1f32, 4f32));
let mut union = UnionPostings::new(
vec!(left_fieldnorms, right_fieldnorms),
vec!(left, right),
vec!(
(left, left_fieldnorms),
(right, right_fieldnorms),
),
multi_term_scorer
);
assert_eq!(union.next(), Some(1u32));

View File

@@ -1,16 +1,22 @@
mod query;
mod multi_term_query;
mod multi_term_scorer;
mod multi_term_explainer;
mod scorer;
mod query_parser;
mod explanation;
mod tfidf;
pub use self::query::Query;
pub use self::multi_term_query::MultiTermQuery;
pub use self::multi_term_scorer::MultiTermScorer;
pub use self::multi_term_scorer::TfIdfScorer;
pub use self::multi_term_scorer::MultiTermExplainScorer;
pub use self::multi_term_explainer::MultiTermExplainer;
pub use self::tfidf::TfIdfScorer;
pub use self::scorer::Scorer;
pub use self::query_parser::QueryParser;
pub use self::explanation::Explanation;
pub use self::multi_term_scorer::MultiTermAccumulator;
pub use self::multi_term_scorer::MultiTermAccumulator;

View File

@@ -0,0 +1,36 @@
use super::MultiTermAccumulator;
use super::MultiTermScorer;
use super::Explanation;
/// Wraps a `MultiTermScorer` and records every `(term_ord, term_freq, fieldnorm)`
/// triple passed to `update`, so that the wrapped scorer can later be asked
/// to explain them via `explain_score`.
pub struct MultiTermExplainer<TScorer: MultiTermScorer + Sized> {
/// The wrapped scorer; every `update` and `clear` call is forwarded to it.
scorer: TScorer,
/// `(term_ord, term_freq, fieldnorm)` triples recorded since the last `clear`.
vals: Vec<(usize, u32, u32)>,
}
impl<TScorer> MultiTermExplainer<TScorer>
    where TScorer: MultiTermScorer + Sized
{
    /// Asks the wrapped scorer to explain the triples recorded so far.
    pub fn explain_score(&self) -> Explanation {
        let recorded = &self.vals;
        self.scorer.explain(recorded)
    }
}
impl<TScorer> From<TScorer> for MultiTermExplainer<TScorer>
    where TScorer: MultiTermScorer + Sized
{
    /// Wraps `multi_term_scorer`, starting with no recorded triples.
    fn from(multi_term_scorer: TScorer) -> MultiTermExplainer<TScorer> {
        let recorded = Vec::new();
        MultiTermExplainer {
            vals: recorded,
            scorer: multi_term_scorer,
        }
    }
}
impl<TScorer> MultiTermAccumulator for MultiTermExplainer<TScorer>
    where TScorer: MultiTermScorer + Sized
{
    /// Records the triple, then forwards the same values to the wrapped scorer.
    fn update(&mut self, term_ord: usize, term_freq: u32, fieldnorm: u32) {
        let observation = (term_ord, term_freq, fieldnorm);
        self.vals.push(observation);
        self.scorer.update(term_ord, term_freq, fieldnorm);
    }

    /// Forgets all recorded triples and clears the wrapped scorer.
    fn clear(&mut self) {
        self.vals.clear();
        self.scorer.clear();
    }
}

View File

@@ -7,13 +7,12 @@ use core::searcher::Searcher;
use collector::Collector;
use SegmentLocalId;
use core::SegmentReader;
use query::MultiTermExplainScorer;
use query::MultiTermExplainer;
use postings::SegmentPostings;
use postings::UnionPostings;
use postings::DocSet;
use query::TfIdfScorer;
use postings::SkipResult;
use fastfield::U32FastFieldReader;
use ScoredDoc;
use query::Scorer;
use query::MultiTermAccumulator;
@@ -33,12 +32,14 @@ impl Query for MultiTermQuery {
searcher: &Searcher,
doc_address: &DocAddress) -> Result<Explanation> {
let segment_reader = &searcher.segments()[doc_address.segment_ord() as usize];
let multi_term_scorer = MultiTermExplainScorer::from(self.scorer(searcher));
let multi_term_scorer = MultiTermExplainer::from(self.scorer(searcher));
let mut timer_tree = TimerTree::new();
let mut postings = self.search_segment(
let mut postings = try!(
self.search_segment(
segment_reader,
multi_term_scorer,
timer_tree.open("explain"));
timer_tree.open("explain"))
);
match postings.skip_next(doc_address.doc()) {
SkipResult::Reached => {
let scorer = postings.scorer();
@@ -67,10 +68,12 @@ impl Query for MultiTermQuery {
let _ = segment_search_timer.open("set_segment");
try!(collector.set_segment(segment_ord as SegmentLocalId, &segment_reader));
}
let mut postings = self.search_segment(
let mut postings = try!(
self.search_segment(
segment_reader,
multi_term_scorer.clone(),
segment_search_timer.open("get_postings"));
segment_search_timer.open("get_postings"))
);
{
let _collection_timer = segment_search_timer.open("collection");
while postings.advance() {
@@ -123,21 +126,26 @@ impl MultiTermQuery {
}
}
fn search_segment<'a, 'b, TScorer: MultiTermAccumulator>(&'b self, reader: &'b SegmentReader, multi_term_scorer: TScorer, mut timer: OpenTimer<'a>) -> UnionPostings<SegmentPostings, TScorer> {
let mut segment_postings: Vec<SegmentPostings> = Vec::with_capacity(self.terms.len());
let mut fieldnorms_readers: Vec<U32FastFieldReader> = Vec::with_capacity(self.terms.len());
fn search_segment<'a, 'b, TScorer: MultiTermAccumulator>(
&'b self,
reader: &'b SegmentReader,
multi_term_scorer: TScorer,
mut timer: OpenTimer<'a>) -> Result<UnionPostings<SegmentPostings, TScorer>> {
let mut postings_and_fieldnorms = Vec::with_capacity(self.num_terms());
{
let mut decode_timer = timer.open("decode_all");
for term in &self.terms {
let _decode_one_timer = decode_timer.open("decode_one");
reader.read_postings(term)
.map(|postings| {
match reader.read_postings(term) {
Some(postings) => {
let field = term.get_field();
fieldnorms_readers.push(reader.get_fieldnorms_reader(field).unwrap());
segment_postings.push(postings);
});
let fieldnorm_reader = try!(reader.get_fieldnorms_reader(field));
postings_and_fieldnorms.push((postings, fieldnorm_reader));
}
None => {}
}
}
}
UnionPostings::new(fieldnorms_readers, segment_postings, multi_term_scorer)
Ok(UnionPostings::new(postings_and_fieldnorms, multi_term_scorer))
}
}

View File

@@ -9,158 +9,3 @@ pub trait MultiTermAccumulator {
/// A scorer that is both a `Scorer` and a `MultiTermAccumulator`, and that can
/// additionally produce an `Explanation` for a set of recorded updates.
pub trait MultiTermScorer: Scorer + MultiTermAccumulator {
/// Builds an `Explanation` from `vals`.
///
/// Implementors and callers in this file treat each tuple as
/// `(term_ord, term_freq, fieldnorm)`, i.e. the arguments of one `update` call.
fn explain(&self, vals: &Vec<(usize, u32, u32)>) -> Explanation;
}
/// Accumulates a score from per-term `(term_ord, term_freq, fieldnorm)` updates.
///
/// The final score is the accumulated sum multiplied by a coordination
/// factor selected by the number of updates received since the last `clear`.
#[derive(Clone)]
pub struct TfIdfScorer {
/// Coordination factors, indexed by `num_fields`.
coords: Vec<f32>,
/// Per-term multipliers, indexed by term ordinal.
idf: Vec<f32>,
/// Sum of the term scores accumulated since the last `clear`.
score: f32,
/// Number of `update` calls since the last `clear`.
num_fields: usize,
/// Optional display names, indexed by term ordinal.
term_names: Option<Vec<String>>, //< only here for explain
}
/// Wraps a `MultiTermScorer` and records every `(term_ord, term_freq, fieldnorm)`
/// triple passed to `update`, so that the wrapped scorer can later be asked
/// to explain them via `explain_score`.
pub struct MultiTermExplainScorer<TScorer: MultiTermScorer + Sized> {
/// The wrapped scorer; every `update` and `clear` call is forwarded to it.
scorer: TScorer,
/// `(term_ord, term_freq, fieldnorm)` triples recorded since the last `clear`.
vals: Vec<(usize, u32, u32)>,
}
impl<TScorer> MultiTermExplainScorer<TScorer>
    where TScorer: MultiTermScorer + Sized
{
    /// Asks the wrapped scorer to explain the triples recorded so far.
    pub fn explain_score(&self) -> Explanation {
        let recorded = &self.vals;
        self.scorer.explain(recorded)
    }
}
impl<TScorer> From<TScorer> for MultiTermExplainScorer<TScorer>
    where TScorer: MultiTermScorer + Sized
{
    /// Wraps `multi_term_scorer`, starting with no recorded triples.
    fn from(multi_term_scorer: TScorer) -> MultiTermExplainScorer<TScorer> {
        let recorded = Vec::new();
        MultiTermExplainScorer {
            vals: recorded,
            scorer: multi_term_scorer,
        }
    }
}
impl<TScorer> MultiTermAccumulator for MultiTermExplainScorer<TScorer>
    where TScorer: MultiTermScorer + Sized
{
    /// Records the triple, then forwards the same values to the wrapped scorer.
    fn update(&mut self, term_ord: usize, term_freq: u32, fieldnorm: u32) {
        let observation = (term_ord, term_freq, fieldnorm);
        self.vals.push(observation);
        self.scorer.update(term_ord, term_freq, fieldnorm);
    }

    /// Forgets all recorded triples and clears the wrapped scorer.
    fn clear(&mut self) {
        self.vals.clear();
        self.scorer.clear();
    }
}
impl TfIdfScorer {
    /// Creates a scorer with a zero score, no updates counted, and no term names.
    ///
    /// `coords` is indexed by the number of updates received since the last
    /// `clear`; `idf` is indexed by term ordinal.
    pub fn new(coords: Vec<f32>, idf: Vec<f32>) -> TfIdfScorer {
        TfIdfScorer {
            term_names: None,
            num_fields: 0,
            score: 0f32,
            idf: idf,
            coords: coords,
        }
    }

    /// Coordination factor for the current number of updates.
    ///
    /// Panics if `num_fields` is not a valid index into `coords`.
    fn coord(&self) -> f32 {
        let matched = self.num_fields;
        self.coords[matched]
    }

    /// Installs display names (indexed by term ordinal) used when explaining.
    pub fn set_term_names(&mut self, term_names: Vec<String>) {
        self.term_names = Some(term_names);
    }

    /// Display name of term `ord`: the configured name when names were set,
    /// otherwise the placeholder `Field(<ord>)`.
    fn term_name(&self, ord: usize) -> String {
        if let Some(ref names) = self.term_names {
            names[ord].clone()
        } else {
            format!("Field({})", ord)
        }
    }

    /// Score contribution of one term: `sqrt(term_freq / field_norm) * idf[term_ord]`.
    fn term_score(&self, term_ord: usize, term_freq: u32, field_norm: u32) -> f32 {
        let ratio = term_freq as f32 / field_norm as f32;
        ratio.sqrt() * self.idf[term_ord]
    }
}
impl Scorer for TfIdfScorer {
    /// Accumulated term score multiplied by the current coordination factor.
    fn score(&self) -> f32 {
        let accumulated = self.score;
        accumulated * self.coord()
    }
}
impl MultiTermScorer for TfIdfScorer {
    /// Builds an `Explanation` of the current score.
    ///
    /// `vals` holds one `(term_ord, term_freq, field_norm)` triple per recorded
    /// update. The root explanation carries the current `score()` and a formula
    /// of the form `<coord> * (<score for (A)> + <score for (B)> ...)`; one
    /// child is added per triple with that term's individual score.
    fn explain(&self, vals: &Vec<(usize, u32, u32)>) -> Explanation {
        let mut explanation = Explanation::with_val(self.score());
        // One placeholder per term; the parentheses around the term name are
        // balanced (the previous text emitted "<score for (name>").
        let formula_components: Vec<String> = vals.iter()
            .map(|&(ord, _, _)| format!("<score for ({})>", self.term_name(ord)))
            .collect();
        let formula = format!("<coord> * ({})", formula_components.join(" + "));
        explanation.set_formula(&formula);
        for &(ord, term_freq, field_norm) in vals {
            let term_score = self.term_score(ord, term_freq, field_norm);
            let term_explanation = explanation.add_child(&self.term_name(ord), term_score);
            term_explanation.set_formula(" sqrt(<term_freq> / <field_norm>) * <idf>");
        }
        explanation
    }
}
impl MultiTermAccumulator for TfIdfScorer {
    /// Adds the contribution of one term and counts the update.
    ///
    /// Panics if `term_freq` is zero.
    fn update(&mut self, term_ord: usize, term_freq: u32, fieldnorm: u32) {
        assert!(term_freq != 0u32);
        let contribution = self.term_score(term_ord, term_freq, fieldnorm);
        self.score += contribution;
        self.num_fields += 1;
    }

    /// Resets the accumulated score and the update counter.
    fn clear(&mut self) {
        self.num_fields = 0;
        self.score = 0f32;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use query::Scorer;

    /// Absolute difference between two floats.
    fn abs_diff(left: f32, right: f32) -> f32 {
        (right - left).abs()
    }

    #[test]
    pub fn test_multiterm_scorer() {
        let coords = vec!(0f32, 1f32, 2f32);
        let idf = vec!(1f32, 4f32);
        let mut scorer = TfIdfScorer::new(coords, idf);

        // A single hit on term 0.
        scorer.update(0, 1, 1);
        assert!(abs_diff(scorer.score(), 1f32) < 0.001f32);
        scorer.clear();

        // A single hit on term 1.
        scorer.update(1, 1, 1);
        assert_eq!(scorer.score(), 4f32);
        scorer.clear();

        // Term 0 with a term frequency of two.
        scorer.update(0, 2, 1);
        assert!(abs_diff(scorer.score(), 1.4142135) < 0.001f32);
        scorer.clear();

        // Both terms hit once each.
        scorer.update(0, 1, 1);
        scorer.update(1, 1, 1);
        assert_eq!(scorer.score(), 10f32);
        scorer.clear();
    }
}

View File

@@ -15,8 +15,5 @@ pub trait Query {
fn explain(
&self,
searcher: &Searcher,
doc_address: &DocAddress) -> Result<Explanation> {
// TODO check that the document is there or return an error.
panic!("Not implemented");
}
doc_address: &DocAddress) -> Result<Explanation>;
}

132
src/query/tfidf.rs Normal file
View File

@@ -0,0 +1,132 @@
use super::MultiTermAccumulator;
use super::Scorer;
use super::MultiTermScorer;
use super::Explanation;
/// Accumulates a score from per-term `(term_ord, term_freq, fieldnorm)` updates.
///
/// The final score is the accumulated sum multiplied by a coordination
/// factor selected by the number of updates received since the last `clear`.
#[derive(Clone)]
pub struct TfIdfScorer {
/// Coordination factors, indexed by `num_fields`.
coords: Vec<f32>,
/// Per-term multipliers, indexed by term ordinal.
idf: Vec<f32>,
/// Sum of the term scores accumulated since the last `clear`.
score: f32,
/// Number of `update` calls since the last `clear`.
num_fields: usize,
/// Optional display names, indexed by term ordinal.
term_names: Option<Vec<String>>, //< only here for explain
}
impl MultiTermAccumulator for TfIdfScorer {
    /// Adds the contribution of one term and counts the update.
    ///
    /// Panics if `term_freq` is zero.
    #[inline(always)]
    fn update(&mut self, term_ord: usize, term_freq: u32, fieldnorm: u32) {
        assert!(term_freq != 0u32);
        let contribution = self.term_score(term_ord, term_freq, fieldnorm);
        self.score += contribution;
        self.num_fields += 1;
    }

    /// Resets the accumulated score and the update counter.
    #[inline(always)]
    fn clear(&mut self) {
        self.num_fields = 0;
        self.score = 0f32;
    }
}
impl TfIdfScorer {
    /// Creates a scorer with a zero score, no updates counted, and no term names.
    ///
    /// `coords` is indexed by the number of updates received since the last
    /// `clear`; `idf` is indexed by term ordinal.
    pub fn new(coords: Vec<f32>, idf: Vec<f32>) -> TfIdfScorer {
        TfIdfScorer {
            term_names: None,
            num_fields: 0,
            score: 0f32,
            idf: idf,
            coords: coords,
        }
    }

    /// Coordination factor for the current number of updates.
    ///
    /// Panics if `num_fields` is not a valid index into `coords`.
    #[inline(always)]
    fn coord(&self) -> f32 {
        let matched = self.num_fields;
        self.coords[matched]
    }

    /// Installs display names (indexed by term ordinal) used when explaining.
    pub fn set_term_names(&mut self, term_names: Vec<String>) {
        self.term_names = Some(term_names);
    }

    /// Display name of term `ord`: the configured name when names were set,
    /// otherwise the placeholder `Field(<ord>)`.
    fn term_name(&self, ord: usize) -> String {
        if let Some(ref names) = self.term_names {
            names[ord].clone()
        } else {
            format!("Field({})", ord)
        }
    }

    /// Score contribution of one term: `sqrt(term_freq / field_norm) * idf[term_ord]`.
    #[inline(always)]
    fn term_score(&self, term_ord: usize, term_freq: u32, field_norm: u32) -> f32 {
        let ratio = term_freq as f32 / field_norm as f32;
        ratio.sqrt() * self.idf[term_ord]
    }
}
impl Scorer for TfIdfScorer {
    /// Accumulated term score multiplied by the current coordination factor.
    #[inline(always)]
    fn score(&self) -> f32 {
        let accumulated = self.score;
        accumulated * self.coord()
    }
}
impl MultiTermScorer for TfIdfScorer {
    /// Builds an `Explanation` of the current score.
    ///
    /// `vals` holds one `(term_ord, term_freq, field_norm)` triple per recorded
    /// update. The root explanation carries the current `score()` and a formula
    /// of the form `<coord> * (<score for (A)> + <score for (B)> ...)`; one
    /// child is added per triple with that term's individual score.
    fn explain(&self, vals: &Vec<(usize, u32, u32)>) -> Explanation {
        let mut explanation = Explanation::with_val(self.score());
        // One placeholder per term; the parentheses around the term name are
        // balanced (the previous text emitted "<score for (name>").
        let formula_components: Vec<String> = vals.iter()
            .map(|&(ord, _, _)| format!("<score for ({})>", self.term_name(ord)))
            .collect();
        let formula = format!("<coord> * ({})", formula_components.join(" + "));
        explanation.set_formula(&formula);
        for &(ord, term_freq, field_norm) in vals {
            let term_score = self.term_score(ord, term_freq, field_norm);
            let term_explanation = explanation.add_child(&self.term_name(ord), term_score);
            term_explanation.set_formula(" sqrt(<term_freq> / <field_norm>) * <idf>");
        }
        explanation
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use query::Scorer;
    use query::MultiTermAccumulator;

    /// Absolute difference between two floats.
    fn abs_diff(left: f32, right: f32) -> f32 {
        (right - left).abs()
    }

    #[test]
    pub fn test_multiterm_scorer() {
        let coords = vec!(0f32, 1f32, 2f32);
        let idf = vec!(1f32, 4f32);
        let mut scorer = TfIdfScorer::new(coords, idf);

        // A single hit on term 0.
        scorer.update(0, 1, 1);
        assert!(abs_diff(scorer.score(), 1f32) < 0.001f32);
        scorer.clear();

        // A single hit on term 1.
        scorer.update(1, 1, 1);
        assert_eq!(scorer.score(), 4f32);
        scorer.clear();

        // Term 0 with a term frequency of two.
        scorer.update(0, 2, 1);
        assert!(abs_diff(scorer.score(), 1.4142135) < 0.001f32);
        scorer.clear();

        // Both terms hit once each.
        scorer.update(0, 1, 1);
        scorer.update(1, 1, 1);
        assert_eq!(scorer.score(), 10f32);
        scorer.clear();
    }
}